From 95b25eddd7594044d2199d6bdc00aff1cad73a8b Mon Sep 17 00:00:00 2001 From: Oleg Sheynin Date: Wed, 18 Jun 2025 14:32:11 -0400 Subject: [PATCH] progress --- configuration/equity.cfg | 7 +- src/pt_backtest.py | 25 +++++- src/results.py | 65 ++++++++++++++- src/strategies.py | 24 ++++-- src/utils/db_inspector.py | 169 ++++++++++++++++++++++++++++++++++++++ 5 files changed, 274 insertions(+), 16 deletions(-) create mode 100644 src/utils/db_inspector.py diff --git a/configuration/equity.cfg b/configuration/equity.cfg index 1c90b9d..e53c8e9 100644 --- a/configuration/equity.cfg +++ b/configuration/equity.cfg @@ -15,10 +15,6 @@ "db_table_name": "md_1min_bars", "exchange_id": "ALPACA", "instrument_id_pfx": "STOCK-", - # "instruments": [ - # "COIN", - # "GBTC" - # ], "trading_hours": { "begin_session": "9:30:00", "end_session": "16:00:00", @@ -31,5 +27,6 @@ "dis-equilibrium_close_trshld": 1.0, "training_minutes": 120, "funding_per_pair": 2000.0, - "strategy_class": "strategies.StaticFitStrategy" + # "strategy_class": "strategies.StaticFitStrategy" + "strategy_class": "strategies.SlidingFitStrategy" } \ No newline at end of file diff --git a/src/pt_backtest.py b/src/pt_backtest.py index a28a235..ee2b380 100644 --- a/src/pt_backtest.py +++ b/src/pt_backtest.py @@ -12,7 +12,7 @@ import pandas as pd from tools.data_loader import load_market_data from tools.trading_pair import TradingPair -from results import BacktestResult, create_result_database, store_results_in_database +from results import BacktestResult, create_result_database, store_results_in_database, store_config_in_database def load_config(config_path: str) -> Dict: @@ -202,6 +202,29 @@ def main() -> None: # Initialize a dictionary to store all trade results all_results: Dict[str, Dict[str, Any]] = {} bt_results = BacktestResult(config=config) + + # Store configuration in database for reference + if args.result_db.upper() != "NONE": + # Get list of all instruments for storage + all_instruments = [] + for datafile in datafiles: + if args.instruments: + file_instruments = [inst.strip() for inst in args.instruments.split(",")] + else: + file_instruments = get_available_instruments_from_db(datafile, config) + all_instruments.extend(file_instruments) + + # Remove duplicates while preserving order + unique_instruments = list(dict.fromkeys(all_instruments)) + + store_config_in_database( + db_path=args.result_db, + config_file_path=args.config, + config=config, + strategy_class=strategy_class_name, + datafiles=datafiles, + instruments=unique_instruments + ) # Process each data file price_column = config["price_column"] diff --git a/src/results.py b/src/results.py index 143a439..67480e5 100644 --- a/src/results.py +++ b/src/results.py @@ -32,7 +32,7 @@ sqlite3.register_converter("datetime", convert_datetime) def create_result_database(db_path: str) -> None: """ - Create the SQLite database and pt_bt_results table if they don't exist. + Create the SQLite database and required tables if they don't exist. """ try: conn = sqlite3.connect(db_path) @@ -58,7 +58,8 @@ def create_result_database(db_path: str) -> None: pair_return REAL ) ''') - + cursor.execute("DELETE FROM pt_bt_results;") + # Create the outstanding_positions table for open positions cursor.execute(''' CREATE TABLE IF NOT EXISTS outstanding_positions ( @@ -72,6 +73,21 @@ def create_result_database(db_path: str) -> None: open_side TEXT ) ''') + cursor.execute("DELETE FROM outstanding_positions;") + + # Create the config table for storing configuration JSON for reference + cursor.execute(''' + CREATE TABLE IF NOT EXISTS config ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + run_timestamp DATETIME, + config_file_path TEXT, + config_json TEXT, + strategy_class TEXT, + datafiles TEXT, + instruments TEXT + ) + ''') + cursor.execute("DELETE FROM config;") conn.commit() conn.close() @@ -81,6 +97,51 @@ def create_result_database(db_path: str) -> None: raise +def store_config_in_database(db_path: str, config_file_path: str, config: Dict, strategy_class: str, datafiles: List[str], instruments: List[str]) -> None: + """ + Store configuration information in the database for reference. + """ + import json + + if db_path.upper() == "NONE": + return + + try: + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + + # Convert config to JSON string + config_json = json.dumps(config, indent=2, default=str) + + # Convert lists to comma-separated strings for storage + datafiles_str = ', '.join(datafiles) + instruments_str = ', '.join(instruments) + + # Insert configuration record + cursor.execute(''' + INSERT INTO config ( + run_timestamp, config_file_path, config_json, strategy_class, datafiles, instruments + ) VALUES (?, ?, ?, ?, ?, ?) + ''', ( + datetime.now(), + config_file_path, + config_json, + strategy_class, + datafiles_str, + instruments_str + )) + + conn.commit() + conn.close() + + print(f"Configuration stored in database") + + except Exception as e: + print(f"Error storing configuration in database: {str(e)}") + import traceback + traceback.print_exc() + + def store_results_in_database(db_path: str, datafile: str, bt_result: 'BacktestResult') -> None: """ Store backtest results in the SQLite database. diff --git a/src/strategies.py b/src/strategies.py index 0ef87e0..1683cd7 100644 --- a/src/strategies.py +++ b/src/strategies.py @@ -4,7 +4,7 @@ import sys from typing import Dict, Optional -import pandas as pd +import pandas as pd # type: ignore from tools.trading_pair import TradingPair from results import BacktestResult @@ -22,7 +22,7 @@ class PairsTradingStrategy(ABC): "pair", ] @abstractmethod - def run_pair(self, pair: TradingPair, bt_result: BacktestResult) -> Optional[pd.DataFrame]: + def run_pair(self, config: Dict, pair: TradingPair, bt_result: BacktestResult) -> Optional[pd.DataFrame]: ... class StaticFitStrategy(PairsTradingStrategy): @@ -49,7 +49,7 @@ class StaticFitStrategy(PairsTradingStrategy): return pair_trades def create_trading_signals(self, pair: TradingPair, config: Dict, result: BacktestResult) -> pd.DataFrame: - beta = pair.vecm_fit_.beta + beta = pair.vecm_fit_.beta # type: ignore colname_a, colname_b = pair.colnames() predicted_df = pair.predicted_df_ @@ -229,7 +229,7 @@ class SlidingFitStrategy(PairsTradingStrategy): testing_size=1 ) - if len(pair.training_df_) < training_minutes: + if len(pair.training_df_) < training_minutes: # type: ignore print(f"{pair}: {self.curr_training_start_idx_} Not enough training data. Completing the job.") if pair.user_data_["state"] == PairState.OPEN: print(f"{pair}: {self.curr_training_start_idx_} Position is not closed.") @@ -251,7 +251,7 @@ class SlidingFitStrategy(PairsTradingStrategy): try: is_cointegrated = pair.train_pair() except Exception as e: - raise Exception(f"{pair}: Training failed: {str(e)}") from e + raise RuntimeError(f"{pair}: Training failed: {str(e)}") from e if pair.user_data_["is_cointegrated"] != is_cointegrated: pair.user_data_["is_cointegrated"] = is_cointegrated @@ -271,7 +271,7 @@ class SlidingFitStrategy(PairsTradingStrategy): try: pair.predict() except Exception as e: - raise Exception(f"{pair}: Prediction failed: {str(e)}") from e + raise RuntimeError(f"{pair}: Prediction failed: {str(e)}") from e if pair.user_data_["state"] == PairState.INITIAL: @@ -295,8 +295,12 @@ class SlidingFitStrategy(PairsTradingStrategy): colname_a, colname_b = pair.colnames() predicted_df = pair.predicted_df_ + + # Check if we have any data to work with + if len(predicted_df) == 0: + return None - open_row = predicted_df.loc[0] + open_row = predicted_df.iloc[0] open_tstamp = open_row["tstamp"] open_disequilibrium = open_row["disequilibrium"] open_scaled_disequilibrium = open_row["scaled_disequilibrium"] @@ -359,7 +363,11 @@ class SlidingFitStrategy(PairsTradingStrategy): def _get_close_trades(self, pair: TradingPair, close_threshold: float) -> Optional[pd.DataFrame]: colname_a, colname_b = pair.colnames() - close_row = pair.predicted_df_.loc[0] + # Check if we have any data to work with + if len(pair.predicted_df_) == 0: + return None + + close_row = pair.predicted_df_.iloc[0] close_tstamp = close_row["tstamp"] close_disequilibrium = close_row["disequilibrium"] close_scaled_disequilibrium = close_row["scaled_disequilibrium"] diff --git a/src/utils/db_inspector.py b/src/utils/db_inspector.py new file mode 100644 index 0000000..3267920 --- /dev/null +++ b/src/utils/db_inspector.py @@ -0,0 +1,169 @@ +#!/usr/bin/env python3 +""" +Database inspector utility for pairs trading results database. +Provides functionality to view all tables and their contents. +""" + +import sqlite3 +import sys +import json +import os +from typing import List, Dict, Any + +def list_tables(db_path: str) -> List[str]: + """List all tables in the database.""" + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + + cursor.execute(""" + SELECT name FROM sqlite_master + WHERE type='table' + ORDER BY name + """) + + tables = [row[0] for row in cursor.fetchall()] + conn.close() + return tables + +def view_table_schema(db_path: str, table_name: str): + """View the schema of a specific table.""" + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + + cursor.execute(f"PRAGMA table_info({table_name})") + columns = cursor.fetchall() + + print(f"\nTable: {table_name}") + print("-" * 50) + print("Column Name".ljust(20) + "Type".ljust(15) + "Not Null".ljust(10) + "Default") + print("-" * 50) + + for col in columns: + cid, name, type_, not_null, default_value, pk = col + print(f"{name}".ljust(20) + f"{type_}".ljust(15) + f"{bool(not_null)}".ljust(10) + f"{default_value or ''}") + + conn.close() + +def view_config_table(db_path: str, limit: int = 10): + """View entries from the config table.""" + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + + cursor.execute(f""" + SELECT id, run_timestamp, config_file_path, strategy_class, + datafiles, instruments, config_json + FROM config + ORDER BY run_timestamp DESC + LIMIT {limit} + """) + + rows = cursor.fetchall() + + if not rows: + print("No configuration entries found.") + return + + print(f"\nMost recent {len(rows)} configuration entries:") + print("=" * 80) + + for row in rows: + id, run_timestamp, config_file_path, strategy_class, datafiles, instruments, config_json = row + + print(f"ID: {id} | {run_timestamp}") + print(f"Config: {config_file_path} | Strategy: {strategy_class}") + print(f"Files: {datafiles}") + print(f"Instruments: {instruments}") + print("-" * 80) + + conn.close() + +def view_results_summary(db_path: str): + """View summary of trading results.""" + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + + # Get results summary + cursor.execute(""" + SELECT date, COUNT(*) as trade_count, + ROUND(SUM(symbol_return), 2) as total_return + FROM pt_bt_results + GROUP BY date + ORDER BY date DESC + """) + + results = cursor.fetchall() + + if not results: + print("No trading results found.") + return + + print(f"\nTrading Results Summary:") + print("-" * 50) + print("Date".ljust(15) + "Trades".ljust(10) + "Total Return %") + print("-" * 50) + + for date, trade_count, total_return in results: + print(f"{date}".ljust(15) + f"{trade_count}".ljust(10) + f"{total_return}") + + # Get outstanding positions summary + cursor.execute(""" + SELECT COUNT(*) as position_count, + ROUND(SUM(unrealized_return), 2) as total_unrealized + FROM outstanding_positions + """) + + outstanding = cursor.fetchone() + if outstanding and outstanding[0] > 0: + print(f"\nOutstanding Positions: {outstanding[0]} positions") + print(f"Total Unrealized Return: {outstanding[1]}%") + + conn.close() + +def main(): + if len(sys.argv) < 2: + print("Usage: python db_inspector.py [command]") + print("Commands:") + print(" tables - List all tables") + print(" schema - Show schema for all tables") + print(" config - View configuration entries") + print(" results - View trading results summary") + print(" all - Show everything (default)") + print("\nExample: python db_inspector.py results/equity.db config") + sys.exit(1) + + db_path = sys.argv[1] + command = sys.argv[2] if len(sys.argv) > 2 else "all" + + if not os.path.exists(db_path): + print(f"Database file not found: {db_path}") + sys.exit(1) + + try: + if command in ["tables", "all"]: + tables = list_tables(db_path) + print(f"Tables in database: {', '.join(tables)}") + + if command in ["schema", "all"]: + tables = list_tables(db_path) + for table in tables: + view_table_schema(db_path, table) + + if command in ["config", "all"]: + if "config" in list_tables(db_path): + view_config_table(db_path) + else: + print("Config table not found.") + + if command in ["results", "all"]: + if "pt_bt_results" in list_tables(db_path): + view_results_summary(db_path) + else: + print("Results table not found.") + + except Exception as e: + print(f"Error inspecting database: {str(e)}") + import traceback + traceback.print_exc() + +if __name__ == "__main__": + main() \ No newline at end of file