This commit is contained in:
Oleg Sheynin 2025-06-18 14:32:11 -04:00
parent 9240d20e16
commit 95b25eddd7
5 changed files with 274 additions and 16 deletions

View File

@ -15,10 +15,6 @@
"db_table_name": "md_1min_bars", "db_table_name": "md_1min_bars",
"exchange_id": "ALPACA", "exchange_id": "ALPACA",
"instrument_id_pfx": "STOCK-", "instrument_id_pfx": "STOCK-",
# "instruments": [
# "COIN",
# "GBTC"
# ],
"trading_hours": { "trading_hours": {
"begin_session": "9:30:00", "begin_session": "9:30:00",
"end_session": "16:00:00", "end_session": "16:00:00",
@ -31,5 +27,6 @@
"dis-equilibrium_close_trshld": 1.0, "dis-equilibrium_close_trshld": 1.0,
"training_minutes": 120, "training_minutes": 120,
"funding_per_pair": 2000.0, "funding_per_pair": 2000.0,
"strategy_class": "strategies.StaticFitStrategy" # "strategy_class": "strategies.StaticFitStrategy"
"strategy_class": "strategies.SlidingFitStrategy"
} }

View File

@ -12,7 +12,7 @@ import pandas as pd
from tools.data_loader import load_market_data from tools.data_loader import load_market_data
from tools.trading_pair import TradingPair from tools.trading_pair import TradingPair
from results import BacktestResult, create_result_database, store_results_in_database from results import BacktestResult, create_result_database, store_results_in_database, store_config_in_database
def load_config(config_path: str) -> Dict: def load_config(config_path: str) -> Dict:
@ -203,6 +203,29 @@ def main() -> None:
all_results: Dict[str, Dict[str, Any]] = {} all_results: Dict[str, Dict[str, Any]] = {}
bt_results = BacktestResult(config=config) bt_results = BacktestResult(config=config)
# Store configuration in database for reference
if args.result_db.upper() != "NONE":
# Get list of all instruments for storage
all_instruments = []
for datafile in datafiles:
if args.instruments:
file_instruments = [inst.strip() for inst in args.instruments.split(",")]
else:
file_instruments = get_available_instruments_from_db(datafile, config)
all_instruments.extend(file_instruments)
# Remove duplicates while preserving order
unique_instruments = list(dict.fromkeys(all_instruments))
store_config_in_database(
db_path=args.result_db,
config_file_path=args.config,
config=config,
strategy_class=strategy_class_name,
datafiles=datafiles,
instruments=unique_instruments
)
# Process each data file # Process each data file
price_column = config["price_column"] price_column = config["price_column"]

View File

@ -32,7 +32,7 @@ sqlite3.register_converter("datetime", convert_datetime)
def create_result_database(db_path: str) -> None: def create_result_database(db_path: str) -> None:
""" """
Create the SQLite database and pt_bt_results table if they don't exist. Create the SQLite database and required tables if they don't exist.
""" """
try: try:
conn = sqlite3.connect(db_path) conn = sqlite3.connect(db_path)
@ -58,6 +58,7 @@ def create_result_database(db_path: str) -> None:
pair_return REAL pair_return REAL
) )
''') ''')
cursor.execute("DELETE FROM pt_bt_results;")
# Create the outstanding_positions table for open positions # Create the outstanding_positions table for open positions
cursor.execute(''' cursor.execute('''
@ -72,6 +73,21 @@ def create_result_database(db_path: str) -> None:
open_side TEXT open_side TEXT
) )
''') ''')
cursor.execute("DELETE FROM outstanding_positions;")
# Create the config table for storing configuration JSON for reference
cursor.execute('''
CREATE TABLE IF NOT EXISTS config (
id INTEGER PRIMARY KEY AUTOINCREMENT,
run_timestamp DATETIME,
config_file_path TEXT,
config_json TEXT,
strategy_class TEXT,
datafiles TEXT,
instruments TEXT
)
''')
cursor.execute("DELETE FROM config;")
conn.commit() conn.commit()
conn.close() conn.close()
@ -81,6 +97,51 @@ def create_result_database(db_path: str) -> None:
raise raise
def store_config_in_database(db_path: str, config_file_path: str, config: Dict, strategy_class: str, datafiles: List[str], instruments: List[str]) -> None:
    """
    Store one configuration record in the ``config`` table for reference.

    Does nothing when ``db_path`` is "NONE" (case-insensitive). Any failure is
    reported on stdout together with a traceback and is NOT re-raised, so a
    problem recording the configuration never aborts the caller.

    Args:
        db_path: Path of the SQLite results database, or "NONE" to skip.
        config_file_path: Path of the configuration file used for the run.
        config: Configuration dictionary; stored as indented JSON. Values that
            are not JSON-serializable are stringified (``default=str``).
        strategy_class: Strategy class name recorded with the run.
        datafiles: Data files of the run; stored as one ", "-joined string.
        instruments: Instruments of the run; stored as one ", "-joined string.
    """
    import json
    import traceback

    if db_path.upper() == "NONE":
        return

    conn = None
    try:
        conn = sqlite3.connect(db_path)
        cursor = conn.cursor()

        # Convert config to JSON string
        config_json = json.dumps(config, indent=2, default=str)

        # Convert lists to comma-separated strings for storage
        datafiles_str = ', '.join(datafiles)
        instruments_str = ', '.join(instruments)

        # Insert configuration record
        cursor.execute('''
            INSERT INTO config (
                run_timestamp, config_file_path, config_json, strategy_class, datafiles, instruments
            ) VALUES (?, ?, ?, ?, ?, ?)
        ''', (
            datetime.now(),
            config_file_path,
            config_json,
            strategy_class,
            datafiles_str,
            instruments_str
        ))

        conn.commit()
        print("Configuration stored in database")

    except Exception as e:
        # Best-effort: report the problem but let the caller carry on.
        print(f"Error storing configuration in database: {str(e)}")
        traceback.print_exc()
    finally:
        # Release the connection on the failure path too; previously it was
        # only closed when the insert succeeded.
        if conn is not None:
            conn.close()
def store_results_in_database(db_path: str, datafile: str, bt_result: 'BacktestResult') -> None: def store_results_in_database(db_path: str, datafile: str, bt_result: 'BacktestResult') -> None:
""" """
Store backtest results in the SQLite database. Store backtest results in the SQLite database.

View File

@ -4,7 +4,7 @@ import sys
from typing import Dict, Optional from typing import Dict, Optional
import pandas as pd import pandas as pd # type: ignore
from tools.trading_pair import TradingPair from tools.trading_pair import TradingPair
from results import BacktestResult from results import BacktestResult
@ -22,7 +22,7 @@ class PairsTradingStrategy(ABC):
"pair", "pair",
] ]
@abstractmethod @abstractmethod
def run_pair(self, pair: TradingPair, bt_result: BacktestResult) -> Optional[pd.DataFrame]: def run_pair(self, config: Dict, pair: TradingPair, bt_result: BacktestResult) -> Optional[pd.DataFrame]:
... ...
class StaticFitStrategy(PairsTradingStrategy): class StaticFitStrategy(PairsTradingStrategy):
@ -49,7 +49,7 @@ class StaticFitStrategy(PairsTradingStrategy):
return pair_trades return pair_trades
def create_trading_signals(self, pair: TradingPair, config: Dict, result: BacktestResult) -> pd.DataFrame: def create_trading_signals(self, pair: TradingPair, config: Dict, result: BacktestResult) -> pd.DataFrame:
beta = pair.vecm_fit_.beta beta = pair.vecm_fit_.beta # type: ignore
colname_a, colname_b = pair.colnames() colname_a, colname_b = pair.colnames()
predicted_df = pair.predicted_df_ predicted_df = pair.predicted_df_
@ -229,7 +229,7 @@ class SlidingFitStrategy(PairsTradingStrategy):
testing_size=1 testing_size=1
) )
if len(pair.training_df_) < training_minutes: if len(pair.training_df_) < training_minutes: # type: ignore
print(f"{pair}: {self.curr_training_start_idx_} Not enough training data. Completing the job.") print(f"{pair}: {self.curr_training_start_idx_} Not enough training data. Completing the job.")
if pair.user_data_["state"] == PairState.OPEN: if pair.user_data_["state"] == PairState.OPEN:
print(f"{pair}: {self.curr_training_start_idx_} Position is not closed.") print(f"{pair}: {self.curr_training_start_idx_} Position is not closed.")
@ -251,7 +251,7 @@ class SlidingFitStrategy(PairsTradingStrategy):
try: try:
is_cointegrated = pair.train_pair() is_cointegrated = pair.train_pair()
except Exception as e: except Exception as e:
raise Exception(f"{pair}: Training failed: {str(e)}") from e raise RuntimeError(f"{pair}: Training failed: {str(e)}") from e
if pair.user_data_["is_cointegrated"] != is_cointegrated: if pair.user_data_["is_cointegrated"] != is_cointegrated:
pair.user_data_["is_cointegrated"] = is_cointegrated pair.user_data_["is_cointegrated"] = is_cointegrated
@ -271,7 +271,7 @@ class SlidingFitStrategy(PairsTradingStrategy):
try: try:
pair.predict() pair.predict()
except Exception as e: except Exception as e:
raise Exception(f"{pair}: Prediction failed: {str(e)}") from e raise RuntimeError(f"{pair}: Prediction failed: {str(e)}") from e
if pair.user_data_["state"] == PairState.INITIAL: if pair.user_data_["state"] == PairState.INITIAL:
@ -296,7 +296,11 @@ class SlidingFitStrategy(PairsTradingStrategy):
predicted_df = pair.predicted_df_ predicted_df = pair.predicted_df_
open_row = predicted_df.loc[0] # Check if we have any data to work with
if len(predicted_df) == 0:
return None
open_row = predicted_df.iloc[0]
open_tstamp = open_row["tstamp"] open_tstamp = open_row["tstamp"]
open_disequilibrium = open_row["disequilibrium"] open_disequilibrium = open_row["disequilibrium"]
open_scaled_disequilibrium = open_row["scaled_disequilibrium"] open_scaled_disequilibrium = open_row["scaled_disequilibrium"]
@ -359,7 +363,11 @@ class SlidingFitStrategy(PairsTradingStrategy):
def _get_close_trades(self, pair: TradingPair, close_threshold: float) -> Optional[pd.DataFrame]: def _get_close_trades(self, pair: TradingPair, close_threshold: float) -> Optional[pd.DataFrame]:
colname_a, colname_b = pair.colnames() colname_a, colname_b = pair.colnames()
close_row = pair.predicted_df_.loc[0] # Check if we have any data to work with
if len(pair.predicted_df_) == 0:
return None
close_row = pair.predicted_df_.iloc[0]
close_tstamp = close_row["tstamp"] close_tstamp = close_row["tstamp"]
close_disequilibrium = close_row["disequilibrium"] close_disequilibrium = close_row["disequilibrium"]
close_scaled_disequilibrium = close_row["scaled_disequilibrium"] close_scaled_disequilibrium = close_row["scaled_disequilibrium"]

169
src/utils/db_inspector.py Normal file
View File

@ -0,0 +1,169 @@
#!/usr/bin/env python3
"""
Database inspector utility for pairs trading results database.
Provides functionality to view all tables and their contents.
"""
import sqlite3
import sys
import json
import os
from typing import List, Dict, Any
def list_tables(db_path: str) -> List[str]:
    """Return the names of all tables in the database, sorted by name.

    Args:
        db_path: Path of the SQLite database file to inspect.

    Returns:
        Table names taken from ``sqlite_master`` (rows with ``type='table'``).
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute("""
            SELECT name FROM sqlite_master
            WHERE type='table'
            ORDER BY name
        """)
        return [row[0] for row in cursor.fetchall()]
    finally:
        # Close the connection even when the query fails.
        conn.close()
def view_table_schema(db_path: str, table_name: str):
    """Print the column layout of one table.

    For each column reported by ``PRAGMA table_info`` the name, declared
    type, NOT NULL flag and default value are printed in fixed-width columns.

    Args:
        db_path: Path of the SQLite database file to inspect.
        table_name: Name of the table whose schema is printed.
    """
    # PRAGMA arguments cannot be bound as parameters, so quote the identifier
    # explicitly; this keeps names containing spaces or quotes working.
    quoted_name = '"' + table_name.replace('"', '""') + '"'

    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute(f"PRAGMA table_info({quoted_name})")
        columns = cursor.fetchall()
    finally:
        # Close the connection even when the PRAGMA fails.
        conn.close()

    print(f"\nTable: {table_name}")
    print("-" * 50)
    print("Column Name".ljust(20) + "Type".ljust(15) + "Not Null".ljust(10) + "Default")
    print("-" * 50)

    for _cid, name, type_, not_null, default_value, _pk in columns:
        print(f"{name}".ljust(20) + f"{type_}".ljust(15) + f"{bool(not_null)}".ljust(10) + f"{default_value or ''}")
def view_config_table(db_path: str, limit: int = 10):
    """Print the most recent entries of the ``config`` table.

    Entries are ordered by ``run_timestamp`` descending. For each entry the
    id, timestamp, config file path, strategy class, data files and
    instruments are printed; the stored JSON itself is not shown.

    Args:
        db_path: Path of the SQLite database file to inspect.
        limit: Maximum number of entries to print.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        # LIMIT is bound as a parameter rather than formatted into the SQL.
        cursor.execute("""
            SELECT id, run_timestamp, config_file_path, strategy_class,
                   datafiles, instruments
            FROM config
            ORDER BY run_timestamp DESC
            LIMIT ?
        """, (limit,))
        rows = cursor.fetchall()
    finally:
        # Close on every path; the empty-result return below used to skip it.
        conn.close()

    if not rows:
        print("No configuration entries found.")
        return

    print(f"\nMost recent {len(rows)} configuration entries:")
    print("=" * 80)

    for entry_id, run_timestamp, config_file_path, strategy_class, datafiles, instruments in rows:
        print(f"ID: {entry_id} | {run_timestamp}")
        print(f"Config: {config_file_path} | Strategy: {strategy_class}")
        print(f"Files: {datafiles}")
        print(f"Instruments: {instruments}")
        print("-" * 80)
def view_results_summary(db_path: str):
    """Print a summary of trading results and outstanding positions.

    First prints, per date (newest first), the number of rows in
    ``pt_bt_results`` and the rounded sum of ``symbol_return``. If that table
    has no rows a notice is printed and nothing else is shown. Otherwise, when
    ``outstanding_positions`` has rows, their count and the rounded sum of
    ``unrealized_return`` are printed as well.

    Args:
        db_path: Path of the SQLite database file to inspect.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()

        # Get results summary
        cursor.execute("""
            SELECT date, COUNT(*) as trade_count,
                   ROUND(SUM(symbol_return), 2) as total_return
            FROM pt_bt_results
            GROUP BY date
            ORDER BY date DESC
        """)
        results = cursor.fetchall()

        if not results:
            print("No trading results found.")
            return

        print("\nTrading Results Summary:")
        print("-" * 50)
        print("Date".ljust(15) + "Trades".ljust(10) + "Total Return %")
        print("-" * 50)

        for date, trade_count, total_return in results:
            print(f"{date}".ljust(15) + f"{trade_count}".ljust(10) + f"{total_return}")

        # Get outstanding positions summary
        cursor.execute("""
            SELECT COUNT(*) as position_count,
                   ROUND(SUM(unrealized_return), 2) as total_unrealized
            FROM outstanding_positions
        """)
        outstanding = cursor.fetchone()

        if outstanding and outstanding[0] > 0:
            print(f"\nOutstanding Positions: {outstanding[0]} positions")
            print(f"Total Unrealized Return: {outstanding[1]}%")
    finally:
        # Close on every path, including the early return and query errors.
        conn.close()
# Commands accepted as the optional second command-line argument.
_COMMANDS = ("tables", "schema", "config", "results", "all")


def _print_usage() -> None:
    """Print the command-line usage text."""
    print("Usage: python db_inspector.py <database_path> [command]")
    print("Commands:")
    print("  tables    - List all tables")
    print("  schema    - Show schema for all tables")
    print("  config    - View configuration entries")
    print("  results   - View trading results summary")
    print("  all       - Show everything (default)")
    print("\nExample: python db_inspector.py results/equity.db config")


def main():
    """Command-line entry point: inspect a results database.

    Reads the database path from ``sys.argv[1]`` and an optional command from
    ``sys.argv[2]`` (default "all"). Exits with status 1 when the path is
    missing from the command line, the command is not recognized, or the
    database file does not exist. Errors raised while inspecting are printed
    with a traceback rather than propagated.
    """
    if len(sys.argv) < 2:
        _print_usage()
        sys.exit(1)

    db_path = sys.argv[1]
    command = sys.argv[2] if len(sys.argv) > 2 else "all"

    # Reject typos explicitly; previously an unknown command printed nothing
    # and exited successfully.
    if command not in _COMMANDS:
        print(f"Unknown command: {command}")
        _print_usage()
        sys.exit(1)

    if not os.path.exists(db_path):
        print(f"Database file not found: {db_path}")
        sys.exit(1)

    try:
        # The viewers below only read, so one table listing serves all of them.
        tables = list_tables(db_path)

        if command in ["tables", "all"]:
            print(f"Tables in database: {', '.join(tables)}")

        if command in ["schema", "all"]:
            for table in tables:
                view_table_schema(db_path, table)

        if command in ["config", "all"]:
            if "config" in tables:
                view_config_table(db_path)
            else:
                print("Config table not found.")

        if command in ["results", "all"]:
            if "pt_bt_results" in tables:
                view_results_summary(db_path)
            else:
                print("Results table not found.")

    except Exception as e:
        print(f"Error inspecting database: {str(e)}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()