pairs_trading/research/pt_backtest.py
Oleg Sheynin c2f701e3a2 progress
2025-07-25 20:20:23 +00:00

233 lines
7.8 KiB
Python

import argparse
import glob
import importlib
import os
from datetime import date, datetime
from typing import Any, Dict, List, Optional, Tuple
import pandas as pd
from research.research_tools import create_pairs
from tools.config import expand_filename, load_config
from pt_trading.results import (
BacktestResult,
create_result_database,
store_config_in_database,
)
from pt_trading.fit_method import PairsTradingFitMethod
from pt_trading.trading_pair import TradingPair
DayT = str
DataFileNameT = str
def resolve_datafiles(
    config: Dict, date_pattern: str, instruments: List[Dict[str, str]]
) -> "List[Tuple[DayT, DataFileNameT]]":
    """Resolve a date pattern to concrete market-data file paths.

    For each instrument, *date_pattern* is resolved against the data
    directory configured for that instrument's type.  Patterns containing
    ``*`` or ``?`` are expanded with :mod:`glob`, and the trading day is
    recovered from the 8-digit date embedded in each matched file name.
    A plain pattern is treated as an explicit file and paired with
    *date_pattern* itself as the day.

    Args:
        config: Loaded configuration; ``market_data_loading`` must map each
            instrument type to a ``data_directory``.
        date_pattern: ``YYYYMMDD`` string, optionally with wildcards; may
            also be an absolute path, which is used as-is.
        instruments: Instrument descriptors (see ``get_instruments``).

    Returns:
        Sorted, de-duplicated list of ``(day, data file path)`` tuples.
    """
    import re

    # Data files are named <YYYYMMDD>.mktdata.ohlcv.db; compile once
    # instead of re-importing/re-compiling inside the glob loop.
    day_regex = re.compile(r"(\d{8})\.mktdata\.ohlcv\.db$")
    has_wildcard = "*" in date_pattern or "?" in date_pattern
    resolved_files: List[Tuple[DayT, DataFileNameT]] = []
    for inst in instruments:
        inst_type = inst["instrument_type"]
        data_dir = config["market_data_loading"][inst_type]["data_directory"]
        pattern = date_pattern
        # Relative patterns are anchored in the instrument's data directory.
        if not os.path.isabs(pattern):
            pattern = os.path.join(data_dir, f"{pattern}.mktdata.ohlcv.db")
        if has_wildcard:
            for matched_file in glob.glob(pattern):
                match = day_regex.search(matched_file)
                # The glob suffix guarantees our naming scheme, so the
                # day component must be present.
                assert match is not None
                resolved_files.append((match.group(1), matched_file))
        else:
            resolved_files.append((date_pattern, pattern))
    return sorted(set(resolved_files))  # remove duplicates and sort
def get_instruments(args: argparse.Namespace, config: Dict) -> List[Dict[str, str]]:
    """Parse the ``--instruments`` CLI argument into instrument descriptors.

    Each comma-separated entry has the form ``SYMBOL:TYPE:EXCHANGE``
    (e.g. ``COIN:EQUITY:NASDAQ``); the per-type section of
    ``config["market_data_loading"]`` supplies the instrument id prefix
    and database table name.

    Args:
        args: Parsed CLI namespace with an ``instruments`` string.
        config: Loaded configuration with a ``market_data_loading`` section.

    Returns:
        One descriptor dict per entry, in CLI order.
    """
    instruments: List[Dict[str, str]] = []
    for entry in args.instruments.split(","):
        # Split once per entry instead of re-splitting for every field.
        parts = entry.split(":")
        instrument_type = parts[1]
        md_config = config["market_data_loading"][instrument_type]
        instruments.append(
            {
                "symbol": parts[0],
                "instrument_type": instrument_type,
                "exchange_id": parts[2],
                "instrument_id_pfx": md_config["instrument_id_pfx"],
                "db_table_name": md_config["db_table_name"],
            }
        )
    return instruments
def run_backtest(
    config: Dict,
    datafiles: List[str],
    fit_method: PairsTradingFitMethod,
    instruments: List[Dict[str, str]],
) -> BacktestResult:
    """
    Run the backtest for all pairs built from the given instruments.

    Args:
        config: Loaded backtest configuration.
        datafiles: Market-data DB files covering a single trading day.
        fit_method: Strategy object that fits and trades each pair.
        instruments: Instrument descriptors (see ``get_instruments``).

    Returns:
        A ``BacktestResult``; returned empty (no collected results) when
        any data file is missing or no pair produced trades.
    """
    bt_result: BacktestResult = BacktestResult(config=config)
    # Report exactly which files are missing rather than the whole list.
    missing = [datafile for datafile in datafiles if not os.path.exists(datafile)]
    if missing:
        print(f"WARNING: data file(s) {missing} do not exist")
        return bt_result
    pairs = create_pairs(
        datafiles=datafiles,
        fit_method=fit_method,
        config=config,
        instruments=instruments,
    )
    # Keep only pairs that actually generated trades.
    pairs_trades = []
    for pair in pairs:
        single_pair_trades = fit_method.run_pair(pair=pair, bt_result=bt_result)
        if single_pair_trades is not None and len(single_pair_trades) > 0:
            pairs_trades.append(single_pair_trades)
    print(f"pairs_trades:\n{pairs_trades}")
    if not pairs_trades:
        print("No trading signals found for any pairs")
        return bt_result
    bt_result.collect_single_day_results(pairs_trades)
    return bt_result
def main() -> None:
    """CLI entry point for the pairs-trading backtest.

    Parses arguments, resolves the per-day market-data files, runs the
    backtest day by day, optionally persists per-day results to a SQLite
    database, and finally prints aggregated totals across all days.
    """
    parser = argparse.ArgumentParser(description="Run pairs trading backtest.")
    parser.add_argument(
        "--config", type=str, required=True, help="Path to the configuration file."
    )
    parser.add_argument(
        "--date_pattern",
        type=str,
        required=True,
        help="Date YYYYMMDD, allows * and ? wildcards",
    )
    parser.add_argument(
        "--instruments",
        type=str,
        required=True,
        # BUGFIX: the example must include the exchange field —
        # get_instruments() requires SYMBOL:TYPE:EXCHANGE.
        help=(
            "Comma-separated list of SYMBOL:TYPE:EXCHANGE instruments "
            "(e.g., COIN:EQUITY:NASDAQ,GBTC:CRYPTO:NASDAQ)"
        ),
    )
    parser.add_argument(
        "--result_db",
        type=str,
        required=True,
        help="Path to SQLite database for storing results. Use 'NONE' to disable database output.",
    )
    args = parser.parse_args()
    config: Dict = load_config(args.config)
    # Dynamically instantiate the configured fit-method class.
    fit_method = PairsTradingFitMethod.create(config)
    # Resolve data files (CLI takes priority over config).
    instruments = get_instruments(args, config)
    datafiles = resolve_datafiles(config, args.date_pattern, instruments)
    days = sorted({day for day, _ in datafiles})
    print(f"Found {len(datafiles)} data files to process:")
    for df in datafiles:
        print(f" - {df}")
    # Database output is disabled with the sentinel value 'NONE'.
    use_db = args.result_db.upper() != "NONE"
    if use_db:
        args.result_db = expand_filename(args.result_db)
        create_result_database(args.result_db)
    # Trade results for all processed days, keyed by day.
    all_results: Dict[str, Dict[str, Any]] = {}
    is_config_stored = False
    # Process each trading day.
    for day in days:
        md_datafiles = [datafile for md_day, datafile in datafiles if md_day == day]
        if not all(os.path.exists(datafile) for datafile in md_datafiles):
            print(f"WARNING: insufficient data files: {md_datafiles}")
            continue
        print(f"\n====== Processing {day} ======")
        # BUGFIX: only store the config when DB output is enabled;
        # previously this wrote to the literal path 'NONE'.
        if use_db and not is_config_stored:
            store_config_in_database(
                db_path=args.result_db,
                config_file_path=args.config,
                config=config,
                fit_method_class=config["fit_method_class"],
                datafiles=datafiles,
                instruments=instruments,
            )
            is_config_stored = True
        try:
            fit_method.reset()
            bt_results = run_backtest(
                config=config,
                datafiles=md_datafiles,
                fit_method=fit_method,
                instruments=instruments,
            )
            if bt_results.trades is None or len(bt_results.trades) == 0:
                print(f"No trades found for {day}")
                continue
            # Key results by day (basename is a no-op for a plain
            # YYYYMMDD string; kept for safety with path-like days).
            filename = os.path.basename(day)
            all_results[filename] = {
                "trades": bt_results.trades.copy(),
                "outstanding_positions": bt_results.outstanding_positions.copy(),
            }
            if use_db:
                bt_results.calculate_returns(
                    {
                        filename: {
                            "trades": bt_results.trades.copy(),
                            "outstanding_positions": bt_results.outstanding_positions.copy(),
                        }
                    }
                )
                bt_results.store_results_in_database(db_path=args.result_db, day=day)
            # BUGFIX: report the processed day (message was a literal
            # "(unknown)" placeholder).
            print(f"Successfully processed {day}")
        except Exception as err:
            # Keep going on per-day failures; log the full traceback.
            print(f"Error processing {day}: {str(err)}")
            import traceback

            traceback.print_exc()
    # Aggregate and print results with a fresh BacktestResult instance.
    if all_results:
        aggregate_bt_results = BacktestResult(config=config)
        aggregate_bt_results.calculate_returns(all_results)
        aggregate_bt_results.print_grand_totals()
        aggregate_bt_results.print_outstanding_positions()
        if use_db:
            print(f"\nResults stored in database: {args.result_db}")
    else:
        print("No results to display.")
# Script entry point: run the backtest only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()