# Pairs trading backtest runner (CLI).
import argparse
import glob
import importlib
import os
import re
from datetime import date, datetime
from typing import Any, Dict, List, Optional, Tuple

import pandas as pd

from pt_trading.fit_method import PairsTradingFitMethod
from pt_trading.results import (
    BacktestResult,
    create_result_database,
    store_config_in_database,
)
from pt_trading.trading_pair import TradingPair
from research.research_tools import create_pairs
from tools.config import expand_filename, load_config
# Type aliases for readability: a trading day string ("YYYYMMDD" — or the raw
# date pattern for explicit paths) and the path to a market-data file.
DayT = str
DataFileNameT = str
def resolve_datafiles(
    config: Dict, date_pattern: str, instruments: List[Dict[str, str]]
) -> List[Tuple[DayT, DataFileNameT]]:
    """Map a date pattern to concrete market-data files for each instrument.

    Args:
        config: Loaded configuration; ``market_data_loading.<type>.data_directory``
            supplies the per-instrument-type data directory.
        date_pattern: A YYYYMMDD date, optionally containing ``*``/``?``
            glob wildcards, or an (absolute) file path.
        instruments: Instrument dicts as produced by ``get_instruments``;
            only ``instrument_type`` is read here.

    Returns:
        Sorted, de-duplicated list of ``(day, datafile_path)`` tuples.  For
        wildcard patterns ``day`` is the YYYYMMDD string extracted from each
        matched file name; otherwise it is ``date_pattern`` itself.
    """
    suffix = ".mktdata.ohlcv.db"
    # Compile once, outside the loops (the original re-imported `re` and
    # re-built this pattern for every matched file).
    day_re = re.compile(r"(\d{8})\.mktdata\.ohlcv\.db$")
    resolved_files: List[Tuple[DayT, DataFileNameT]] = []
    for inst in instruments:
        pattern = date_pattern
        inst_type = inst["instrument_type"]
        data_dir = config["market_data_loading"][inst_type]["data_directory"]
        if "*" in pattern or "?" in pattern:
            # Wildcard pattern: expand via glob and derive the day from each
            # matched file name.
            if not os.path.isabs(pattern):
                pattern = os.path.join(data_dir, f"{pattern}{suffix}")
            for matched_file in glob.glob(pattern):
                match = day_re.search(matched_file)
                if match is None:
                    # Not a daily OHLCV database; skip rather than assert
                    # (asserts are stripped under ``python -O``).
                    print(f"WARNING: skipping unrecognized data file {matched_file}")
                    continue
                resolved_files.append((match.group(1), matched_file))
        else:
            # Explicit date / file path: no filesystem lookup needed.
            if not os.path.isabs(pattern):
                pattern = os.path.join(data_dir, f"{pattern}{suffix}")
            resolved_files.append((date_pattern, pattern))
    return sorted(set(resolved_files))  # remove duplicates and sort
def get_instruments(args: argparse.Namespace, config: Dict) -> List[Dict[str, str]]:
    """Parse the ``--instruments`` CLI argument into instrument descriptors.

    Each comma-separated entry of ``args.instruments`` is a colon-separated
    triple ``SYMBOL:INSTRUMENT_TYPE:EXCHANGE_ID`` (e.g. ``COIN:EQUITY:NASDAQ``).

    Args:
        args: Parsed CLI namespace; only ``args.instruments`` is read.
        config: Loaded configuration; ``market_data_loading.<type>`` must
            provide ``instrument_id_pfx`` and ``db_table_name``.

    Returns:
        One dict per instrument with keys ``symbol``, ``instrument_type``,
        ``exchange_id``, ``instrument_id_pfx`` and ``db_table_name``.
    """
    instruments: List[Dict[str, str]] = []
    for spec in args.instruments.split(","):
        # Split once per spec (the original re-split the string five times)
        # and look the per-type config section up once.
        parts = spec.split(":")
        inst_type = parts[1]
        md_config = config["market_data_loading"][inst_type]
        instruments.append(
            {
                "symbol": parts[0],
                "instrument_type": inst_type,
                "exchange_id": parts[2],
                "instrument_id_pfx": md_config["instrument_id_pfx"],
                "db_table_name": md_config["db_table_name"],
            }
        )
    return instruments
def run_backtest(
    config: Dict,
    datafiles: List[str],
    fit_method: PairsTradingFitMethod,
    instruments: List[Dict[str, str]],
) -> BacktestResult:
    """
    Run backtest for all pairs using the specified instruments.

    Args:
        config: Loaded configuration, passed through to pair construction
            and to the result object.
        datafiles: Market-data database files for a single trading day.
        fit_method: Strategy object whose ``run_pair`` produces the trades
            for one pair.
        instruments: Instrument descriptors (see ``get_instruments``).

    Returns:
        A ``BacktestResult``; empty when data files are missing or no pair
        produced any trades.
    """
    bt_result: BacktestResult = BacktestResult(config=config)

    # Bail out early — and report only the files that are actually absent
    # (the original printed the whole list as "does not exist").
    missing = [datafile for datafile in datafiles if not os.path.exists(datafile)]
    if missing:
        print(f"WARNING: data file(s) {missing} do not exist")
        return bt_result

    pairs = create_pairs(
        datafiles=datafiles,
        fit_method=fit_method,
        config=config,
        instruments=instruments,
    )

    pairs_trades = []
    for pair in pairs:
        single_pair_trades = fit_method.run_pair(pair=pair, bt_result=bt_result)
        if single_pair_trades is not None and len(single_pair_trades) > 0:
            pairs_trades.append(single_pair_trades)

    print(f"pairs_trades:\n{pairs_trades}")
    # Check if result_list has any data before concatenating
    if len(pairs_trades) == 0:
        print("No trading signals found for any pairs")
        return bt_result

    bt_result.collect_single_day_results(pairs_trades)
    return bt_result
def main() -> None:
    """CLI entry point: run a pairs-trading backtest over one or more days.

    Parses arguments, resolves the per-day market-data files, runs the
    backtest day by day, optionally stores results in a SQLite database,
    and prints aggregate totals.
    """
    parser = argparse.ArgumentParser(description="Run pairs trading backtest.")
    parser.add_argument(
        "--config", type=str, required=True, help="Path to the configuration file."
    )
    parser.add_argument(
        "--date_pattern",
        type=str,
        required=True,
        help="Date YYYYMMDD, allows * and ? wildcards",
    )
    parser.add_argument(
        "--instruments",
        type=str,
        required=True,
        help="Comma-separated list of instrument symbols (e.g., COIN:EQUITY,GBTC:CRYPTO)",
    )
    parser.add_argument(
        "--result_db",
        type=str,
        required=True,
        help="Path to SQLite database for storing results. Use 'NONE' to disable database output.",
    )

    args = parser.parse_args()

    config: Dict = load_config(args.config)

    # Dynamically instantiate fit method class
    fit_method = PairsTradingFitMethod.create(config)

    # Resolve data files (CLI takes priority over config)
    instruments = get_instruments(args, config)
    datafiles = resolve_datafiles(config, args.date_pattern, instruments)

    days = list({day for day, _ in datafiles})
    print(f"Found {len(datafiles)} data files to process:")
    for df in datafiles:
        print(f" - {df}")

    # Database output is enabled only when a real path was requested.
    use_result_db = args.result_db.upper() != "NONE"
    if use_result_db:
        args.result_db = expand_filename(args.result_db)
        create_result_database(args.result_db)

    # day-key -> {"trades": ..., "outstanding_positions": ...} for aggregation.
    all_results: Dict[str, Dict[str, Any]] = {}
    is_config_stored = False

    # Process each trading day independently.
    for day in sorted(days):
        md_datafiles = [datafile for md_day, datafile in datafiles if md_day == day]
        if not all(os.path.exists(datafile) for datafile in md_datafiles):
            print(f"WARNING: insufficient data files: {md_datafiles}")
            continue
        print(f"\n====== Processing {day} ======")

        # BUGFIX: store the config snapshot only when a result DB is in use;
        # previously this was attempted even with --result_db NONE.
        if use_result_db and not is_config_stored:
            store_config_in_database(
                db_path=args.result_db,
                config_file_path=args.config,
                config=config,
                fit_method_class=config["fit_method_class"],
                datafiles=datafiles,
                instruments=instruments,
            )
            is_config_stored = True

        # Process data for this day.
        try:
            fit_method.reset()

            bt_results = run_backtest(
                config=config,
                datafiles=md_datafiles,
                fit_method=fit_method,
                instruments=instruments,
            )

            if bt_results.trades is None or len(bt_results.trades) == 0:
                print(f"No trades found for {day}")
                continue

            # Store results keyed by day (basename in case an explicit file
            # path was passed as the date pattern).
            filename = os.path.basename(day)
            all_results[filename] = {
                "trades": bt_results.trades.copy(),
                "outstanding_positions": bt_results.outstanding_positions.copy(),
            }

            # Store this day's results in the database.
            if use_result_db:
                bt_results.calculate_returns(
                    {
                        filename: {
                            "trades": bt_results.trades.copy(),
                            "outstanding_positions": bt_results.outstanding_positions.copy(),
                        }
                    }
                )
                bt_results.store_results_in_database(db_path=args.result_db, day=day)

            # BUGFIX: the original printed a literal "(unknown)" here.
            print(f"Successfully processed {day}")

        except Exception as err:
            # Per-day boundary: report the failure and continue with the
            # next day rather than aborting the whole run.
            print(f"Error processing {day}: {str(err)}")
            import traceback

            traceback.print_exc()

    # Calculate and print results using a new BacktestResult instance for aggregation
    if all_results:
        aggregate_bt_results = BacktestResult(config=config)
        aggregate_bt_results.calculate_returns(all_results)
        aggregate_bt_results.print_grand_totals()
        aggregate_bt_results.print_outstanding_positions()

        if use_result_db:
            print(f"\nResults stored in database: {args.result_db}")
    else:
        print("No results to display.")
if __name__ == "__main__":
    # Script entry point: run the backtest CLI.
    main()