pairs_trading/research/pt_backtest.py
2025-07-13 22:33:48 +00:00

258 lines
8.7 KiB
Python

import argparse
import glob
import importlib
import os
from datetime import date, datetime
from typing import Any, Dict, List, Optional
import pandas as pd
from tools.config import expand_filename, load_config
from tools.data_loader import get_available_instruments_from_db, load_market_data
from pt_trading.results import (
BacktestResult,
create_result_database,
store_config_in_database,
store_results_in_database,
)
from pt_trading.fit_methods import PairsTradingFitMethod
from pt_trading.trading_pair import TradingPair
def resolve_datafiles(config: Dict, cli_datafiles: Optional[str] = None) -> List[str]:
    """
    Resolve the list of data files to process.

    CLI datafiles take priority over config datafiles. Wildcards (``*`` and
    ``?``) are expanded for config entries only; CLI entries are taken
    literally.

    Args:
        config: Loaded configuration. Reads ``data_directory`` (default
            ``"./data"``) and ``datafiles`` (default: no files).
        cli_datafiles: Optional comma-separated list of file names from the
            command line. When non-empty it overrides the config entries.

    Returns:
        CLI case: the paths in the order given (duplicates are kept).
        Config case: a sorted list of unique paths.
        In both cases relative entries are joined onto ``data_directory``
        and blank entries are ignored.
    """
    data_dir = config.get("data_directory", "./data")

    def _anchor(path: str) -> str:
        # Relative names are interpreted relative to the data directory.
        return path if os.path.isabs(path) else os.path.join(data_dir, path)

    if cli_datafiles:
        # CLI override - comma-separated list, no wildcards.
        # Blank entries (trailing or doubled commas) are dropped: joining ""
        # onto data_dir would otherwise yield the directory itself.
        names = (name.strip() for name in cli_datafiles.split(","))
        return [_anchor(name) for name in names if name]

    # Use config datafiles with wildcard support.
    resolved = set()
    # `or []` tolerates an explicit null "datafiles" entry in the config.
    for pattern in config.get("datafiles") or []:
        if not pattern:
            continue
        full_path = _anchor(pattern)
        if "*" in pattern or "?" in pattern:
            # Wildcard entry: keep only the files that actually exist.
            resolved.update(glob.glob(full_path))
        else:
            # Explicit file path: kept even if the file does not exist.
            resolved.add(full_path)
    return sorted(resolved)  # unique and sorted
def run_backtest(
    config: Dict,
    datafile: str,
    price_column: str,
    fit_method: PairsTradingFitMethod,
    instruments: List[str],
) -> BacktestResult:
    """
    Backtest every unordered pair of ``instruments`` on a single data file.

    Market data is loaded once for ``datafile`` (with the config's
    "instruments" entry overridden by ``instruments``), a TradingPair is
    built for each combination (i < j), and ``fit_method.run_pair`` is called
    on each pair. Non-empty per-pair results are concatenated, indexed and
    sorted by their "time" column, and handed to the BacktestResult.

    Returns the BacktestResult; when no pair produced any trades it is
    returned without collected results.
    """
    bt_result: BacktestResult = BacktestResult(config=config)

    # Load the data with a (shallow) config copy so the caller's config
    # keeps its own "instruments" entry.
    loader_config = config.copy()
    loader_config["instruments"] = instruments
    market_data_df = load_market_data(datafile, config=loader_config)

    # All pairs are constructed up front, in (i, j) order with i < j.
    count = len(instruments)
    pairs: List[TradingPair] = [
        TradingPair(
            market_data=market_data_df,
            symbol_a=instruments[first],
            symbol_b=instruments[second],
            price_column=price_column,
        )
        for first in range(count)
        for second in range(first + 1, count)
    ]

    pairs_trades = []
    for pair in pairs:
        trades = fit_method.run_pair(pair=pair, config=config, bt_result=bt_result)
        # Skip pairs that yielded nothing (None or zero-length result).
        if trades is None or len(trades) == 0:
            continue
        pairs_trades.append(trades)

    print(f"pairs_trades: {pairs_trades}")

    # Nothing to concatenate: report and return the untouched result.
    if not pairs_trades:
        print("No trading signals found for any pairs")
        return bt_result

    combined = pd.concat(pairs_trades, ignore_index=True)
    combined["time"] = pd.to_datetime(combined["time"])
    bt_result.collect_single_day_results(combined.set_index("time").sort_index())
    return bt_result
def main() -> None:
    """
    Command-line entry point for the pairs-trading backtest.

    Steps:
      1. Parse ``--config``, ``--datafiles``, ``--instruments`` and
         ``--result_db``.
      2. Load the config and instantiate the class named by its
         ``fit_method_class`` entry (a dotted ``module.ClassName`` path).
      3. Resolve the data files (CLI overrides config).
      4. Unless ``--result_db`` is 'NONE', create the result database and
         store the run configuration in it.
      5. Run the backtest per data file, storing per-file results; an error
         in one file is printed with its traceback and does not stop the
         remaining files.
      6. Aggregate and print totals over all successfully processed files.

    Raises:
        ValueError: if ``fit_method_class`` is missing from the config or is
            not a dotted ``module.ClassName`` string.
    """
    import traceback

    parser = argparse.ArgumentParser(description="Run pairs trading backtest.")
    parser.add_argument(
        "--config", type=str, required=True, help="Path to the configuration file."
    )
    parser.add_argument(
        "--datafiles",
        type=str,
        required=False,
        help="Comma-separated list of data files (overrides config). No wildcards supported.",
    )
    parser.add_argument(
        "--instruments",
        type=str,
        required=False,
        help="Comma-separated list of instrument symbols (e.g., COIN,GBTC). If not provided, auto-detects from database.",
    )
    parser.add_argument(
        "--result_db",
        type=str,
        required=True,
        help="Path to SQLite database for storing results. Use 'NONE' to disable database output.",
    )
    args = parser.parse_args()

    config: Dict = load_config(args.config)

    # Dynamically instantiate fit method class.
    # Validated explicitly: an `assert` would be stripped under `python -O`.
    fit_method_class_name = config.get("fit_method_class", None)
    if not isinstance(fit_method_class_name, str) or "." not in fit_method_class_name:
        raise ValueError(
            "Config entry 'fit_method_class' must be a dotted "
            f"'module.ClassName' path, got {fit_method_class_name!r}"
        )
    module_name, class_name = fit_method_class_name.rsplit(".", 1)
    module = importlib.import_module(module_name)
    fit_method = getattr(module, class_name)()

    # Resolve data files (CLI takes priority over config)
    datafiles = resolve_datafiles(config, args.datafiles)

    if not datafiles:
        print("No data files found to process.")
        return

    print(f"Found {len(datafiles)} data files to process:")
    for df in datafiles:
        print(f"  - {df}")

    # Decide once whether database output is enabled, before the path is
    # expanded, so the 'NONE' sentinel is only ever tested on the raw argument.
    use_db = args.result_db.upper() != "NONE"
    result_db = args.result_db
    if use_db:
        result_db = expand_filename(args.result_db)
        create_result_database(result_db)

    # Parse the CLI instrument list once; blank tokens (e.g. "A,,B") are dropped.
    cli_instruments: Optional[List[str]] = None
    if args.instruments:
        cli_instruments = [
            inst.strip() for inst in args.instruments.split(",") if inst.strip()
        ]

    # Auto-detected instruments are cached per data file so each file is
    # queried at most once, even when they are needed both for the stored
    # configuration and for the backtest itself.
    detected_instruments: Dict[str, Any] = {}

    def _instruments_for(path: str) -> Any:
        """Return the instruments to use for *path* (CLI list or auto-detected)."""
        if cli_instruments is not None:
            # Fresh copy per file so downstream code cannot alias the list.
            return list(cli_instruments)
        if path not in detected_instruments:
            detected_instruments[path] = get_available_instruments_from_db(
                path, config
            )
        return detected_instruments[path]

    # Initialize a dictionary to store all trade results
    # NOTE(review): keyed by file basename, so two data files with the same
    # name in different directories would overwrite each other here.
    all_results: Dict[str, Dict[str, Any]] = {}

    # Store configuration in database for reference
    if use_db:
        # Get list of all instruments for storage
        all_instruments: List[str] = []
        for datafile in datafiles:
            all_instruments.extend(_instruments_for(datafile))

        # Remove duplicates while preserving order
        unique_instruments = list(dict.fromkeys(all_instruments))

        store_config_in_database(
            db_path=result_db,
            config_file_path=args.config,
            config=config,
            fit_method_class=fit_method_class_name,
            datafiles=datafiles,
            instruments=unique_instruments,
        )

    # Process each data file
    price_column = config["price_column"]

    for datafile in datafiles:
        print(f"\n====== Processing {os.path.basename(datafile)} ======")

        # Determine instruments to use
        instruments = _instruments_for(datafile)
        if cli_instruments is not None:
            print(f"Using CLI-specified instruments: {instruments}")
        else:
            print(f"Auto-detected instruments: {instruments}")

        if not instruments:
            print(f"No instruments found for {datafile}, skipping...")
            continue

        # Process data for this file. This is the top-level boundary for
        # per-file failures: report with traceback and move on to the next file.
        try:
            fit_method.reset()

            bt_results = run_backtest(
                config=config,
                datafile=datafile,
                price_column=price_column,
                fit_method=fit_method,
                instruments=instruments,
            )

            # Store results with file name as key
            filename = os.path.basename(datafile)
            all_results[filename] = {"trades": bt_results.trades.copy()}

            # Store results in database
            if use_db:
                store_results_in_database(result_db, datafile, bt_results)

            print(f"Successfully processed {filename}")

        except Exception as err:
            print(f"Error processing {datafile}: {str(err)}")
            traceback.print_exc()

    # Calculate and print results using a new BacktestResult instance for aggregation
    if all_results:
        aggregate_bt_results = BacktestResult(config=config)
        aggregate_bt_results.calculate_returns(all_results)
        aggregate_bt_results.print_grand_totals()
        aggregate_bt_results.print_outstanding_positions()

        if use_db:
            print(f"\nResults stored in database: {result_db}")
    else:
        print("No results to display.")
# Run the backtest CLI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()