import argparse
import asyncio
import glob
import importlib
import itertools
import os
import traceback
from datetime import date, datetime
from typing import Any, Dict, List, Optional

import hjson
import pandas as pd

from tools.data_loader import get_available_instruments_from_db, load_market_data
from pt_trading.results import (
    BacktestResult,
    create_result_database,
    store_config_in_database,
    store_results_in_database,
)
from pt_trading.fit_methods import PairsTradingFitMethod
from pt_trading.trading_pair import TradingPair

# NOTE(review): main() calls load_config() and resolve_datafiles(), which are
# neither defined nor imported in this file as shown. Confirm where they are
# meant to come from (the otherwise unused `hjson` and `glob` imports suggest
# they may have been local helpers).

# Value of --result_db (case-insensitive) that disables database output.
_NO_DB_SENTINEL = "NONE"


def _create_pairs(
    config: Dict,
    datafile: str,
    fit_method: PairsTradingFitMethod,
    instruments: List[str],
) -> List[TradingPair]:
    """Load market data once and build a trading pair for every unordered
    combination of two instruments.

    Pairs are produced in the order (0, 1), (0, 2), ..., (1, 2), ... with
    respect to positions in ``instruments``.
    """
    market_data_df = load_market_data(
        datafile=datafile,
        exchange_id=config["exchange_id"],
        instruments=instruments,
        instrument_id_pfx=config["instrument_id_pfx"],
        db_table_name=config["db_table_name"],
        trading_hours=config["trading_hours"],
    )
    return [
        fit_method.create_trading_pair(
            market_data=market_data_df,
            symbol_a=symbol_a,
            symbol_b=symbol_b,
        )
        for symbol_a, symbol_b in itertools.combinations(instruments, 2)
    ]


def run_strategy(
    config: Dict,
    datafile: str,
    fit_method: PairsTradingFitMethod,
    instruments: List[str],
) -> BacktestResult:
    """
    Run backtest for all pairs using the specified instruments.

    Args:
        config: Strategy configuration; must contain "exchange_id",
            "instrument_id_pfx", "db_table_name" and "trading_hours".
        datafile: Data file passed to ``load_market_data``.
        fit_method: Object used to create each trading pair and to run it.
        instruments: Symbols to combine into pairs.

    Returns:
        The ``BacktestResult`` created for this run. If no pair produced any
        trades, a message is printed and the result is returned without
        ``collect_single_day_results`` having been called.
    """
    bt_result: BacktestResult = BacktestResult(config=config)

    pairs_trades = []
    for pair in _create_pairs(config, datafile, fit_method, instruments):
        single_pair_trades = fit_method.run_pair(
            pair=pair, config=config, bt_result=bt_result
        )
        # Explicit len() check: the truth value of a DataFrame is ambiguous.
        if single_pair_trades is not None and len(single_pair_trades) > 0:
            pairs_trades.append(single_pair_trades)

    # pd.concat raises on an empty list, so bail out before concatenating.
    if not pairs_trades:
        print("No trading signals found for any pairs")
        return bt_result

    result = pd.concat(pairs_trades, ignore_index=True)
    result["time"] = pd.to_datetime(result["time"])
    result = result.set_index("time").sort_index()

    bt_result.collect_single_day_results(result)
    return bt_result


def _parse_instrument_list(raw: Optional[str]) -> Optional[List[str]]:
    """Split a comma-separated symbol string into a list of stripped symbols.

    Returns None when ``raw`` is missing or empty (meaning "auto-detect").
    Blank entries such as the middle one in "A,,B" are dropped.
    """
    if not raw:
        return None
    return [symbol.strip() for symbol in raw.split(",") if symbol.strip()]


def _load_fit_method(class_path: Optional[str]) -> PairsTradingFitMethod:
    """Import ``package.module.ClassName`` and return a new instance of it.

    Raises:
        ValueError: if ``class_path`` is missing or has no module part.
    """
    if not class_path:
        raise ValueError("Config is missing required key 'fit_method_class'")
    module_name, _, class_name = class_path.rpartition(".")
    if not module_name:
        raise ValueError(
            f"'fit_method_class' must be a dotted path 'module.ClassName', "
            f"got {class_path!r}"
        )
    module = importlib.import_module(module_name)
    return getattr(module, class_name)()


def main() -> None:
    """Command-line entry point: run the backtest over every data file,
    optionally store results in a SQLite database, and print grand totals.
    """
    parser = argparse.ArgumentParser(description="Run pairs trading backtest.")
    parser.add_argument(
        "--config", type=str, required=True, help="Path to the configuration file."
    )
    parser.add_argument(
        "--datafiles",
        type=str,
        required=False,
        help="Comma-separated list of data files (overrides config). No wildcards supported.",
    )
    parser.add_argument(
        "--instruments",
        type=str,
        required=False,
        help="Comma-separated list of instrument symbols (e.g., COIN,GBTC). If not provided, auto-detects from database.",
    )
    parser.add_argument(
        "--result_db",
        type=str,
        required=True,
        help="Path to SQLite database for storing results. Use 'NONE' to disable database output.",
    )

    args = parser.parse_args()

    config: Dict = load_config(args.config)

    # Dynamically instantiate fit method class
    fit_method_class_name = config.get("fit_method_class", None)
    fit_method = _load_fit_method(fit_method_class_name)

    # Resolve data files (CLI takes priority over config)
    datafiles = resolve_datafiles(config, args.datafiles)

    if not datafiles:
        print("No data files found to process.")
        return

    print(f"Found {len(datafiles)} data files to process:")
    for path in datafiles:
        print(f" - {path}")

    use_db = args.result_db.upper() != _NO_DB_SENTINEL

    # Parse the CLI override once; None means "auto-detect per data file".
    cli_instruments = _parse_instrument_list(args.instruments)

    # Auto-detected instruments per data file, so each file is queried at most
    # once even though the list is needed both for the stored configuration
    # and for the backtest itself.
    detected: Dict[str, List[str]] = {}

    def instruments_for(datafile: str) -> List[str]:
        """Return a fresh list of the instruments to use for ``datafile``."""
        if cli_instruments is not None:
            return list(cli_instruments)
        if datafile not in detected:
            detected[datafile] = get_available_instruments_from_db(datafile, config)
        return list(detected[datafile])

    # Create result database if needed
    if use_db:
        create_result_database(args.result_db)

    # Initialize a dictionary to store all trade results
    all_results: Dict[str, Dict[str, Any]] = {}

    # Store configuration in database for reference
    if use_db:
        all_instruments: List[str] = []
        for datafile in datafiles:
            all_instruments.extend(instruments_for(datafile))

        # Remove duplicates while preserving order
        unique_instruments = list(dict.fromkeys(all_instruments))

        store_config_in_database(
            db_path=args.result_db,
            config_file_path=args.config,
            config=config,
            fit_method_class=fit_method_class_name,
            datafiles=datafiles,
            instruments=unique_instruments,
        )

    # Process each data file
    for datafile in datafiles:
        # NOTE(review): results are keyed by base name, so two data files with
        # the same name in different directories overwrite each other.
        filename = os.path.basename(datafile)
        print(f"\n====== Processing {filename} ======")

        # Determine instruments to use
        instruments = instruments_for(datafile)
        if cli_instruments is not None:
            print(f"Using CLI-specified instruments: {instruments}")
        else:
            print(f"Auto-detected instruments: {instruments}")

        if not instruments:
            print(f"No instruments found for {datafile}, skipping...")
            continue

        # Process data for this file. A failure in one file is reported and
        # the remaining files are still processed.
        try:
            fit_method.reset()

            bt_results = run_strategy(
                config=config,
                datafile=datafile,
                fit_method=fit_method,
                instruments=instruments,
            )

            # Store results with file name as key
            all_results[filename] = {"trades": bt_results.trades.copy()}

            # Store results in database
            if use_db:
                store_results_in_database(args.result_db, datafile, bt_results)

            print(f"Successfully processed {filename}")

        except Exception as err:
            print(f"Error processing {datafile}: {err}")
            traceback.print_exc()

    # Calculate and print results using a new BacktestResult instance for aggregation
    if all_results:
        aggregate_bt_results = BacktestResult(config=config)
        aggregate_bt_results.calculate_returns(all_results)
        aggregate_bt_results.print_grand_totals()
        aggregate_bt_results.print_outstanding_positions()

        if use_db:
            print(f"\nResults stored in database: {args.result_db}")
    else:
        print("No results to display.")


if __name__ == "__main__":
    # main() is a regular function, not a coroutine. Wrapping it in
    # asyncio.run() would run the backtest and then raise
    # "ValueError: a coroutine was expected, got None".
    main()