226 lines
7.4 KiB
Python
226 lines
7.4 KiB
Python
import argparse
|
|
import asyncio
|
|
import glob
|
|
import importlib
|
|
import os
|
|
from datetime import date, datetime
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
import hjson
|
|
import pandas as pd
|
|
|
|
from tools.data_loader import get_available_instruments_from_db, load_market_data
|
|
from pt_trading.results import (
|
|
BacktestResult,
|
|
create_result_database,
|
|
store_config_in_database,
|
|
store_results_in_database,
|
|
)
|
|
from pt_trading.fit_methods import PairsTradingFitMethod
|
|
from pt_trading.trading_pair import TradingPair
|
|
|
|
|
|
def run_strategy(
    config: Dict,
    datafile: str,
    price_column: str,
    fit_method: PairsTradingFitMethod,
    instruments: List[str],
) -> BacktestResult:
    """
    Run the backtest for every unordered pair of ``instruments`` in one data file.

    Market data for all instruments is loaded once from ``datafile``; a
    TradingPair is then built for each unique (symbol_a, symbol_b) combination
    via ``fit_method.create_trading_pair`` and executed with
    ``fit_method.run_pair``. Non-empty per-pair trade frames are concatenated,
    indexed by their "time" column (converted to datetime), sorted
    chronologically and passed to ``BacktestResult.collect_single_day_results``.

    Args:
        config: Strategy configuration. Must contain "exchange_id",
            "instrument_id_pfx", "db_table_name" and "trading_hours"; it is
            also forwarded to BacktestResult and to ``fit_method.run_pair``.
        datafile: Path of the market-data file to load.
        price_column: Price column used when building each pair.
        fit_method: Fit method that creates and runs the trading pairs.
        instruments: Symbols to pair up. List order decides which symbol is
            ``symbol_a`` (earlier) and which is ``symbol_b`` (later).

    Returns:
        The BacktestResult for this file. When no pair produced trades it is
        returned as-is, without calling ``collect_single_day_results``.
    """
    from itertools import combinations

    bt_result: BacktestResult = BacktestResult(config=config)

    def _create_pairs() -> List[TradingPair]:
        # Load the data once; every pair is built from the same frame.
        market_data_df = load_market_data(
            datafile=datafile,
            exchange_id=config["exchange_id"],
            instruments=instruments,
            instrument_id_pfx=config["instrument_id_pfx"],
            db_table_name=config["db_table_name"],
            trading_hours=config["trading_hours"],
        )
        # combinations() yields (instruments[i], instruments[j]) for i < j, in
        # the same order as a nested index loop over i < j.
        return [
            fit_method.create_trading_pair(
                market_data=market_data_df,
                symbol_a=symbol_a,
                symbol_b=symbol_b,
                price_column=price_column,
            )
            for symbol_a, symbol_b in combinations(instruments, 2)
        ]

    pairs_trades = []
    for pair in _create_pairs():
        single_pair_trades = fit_method.run_pair(
            pair=pair, config=config, bt_result=bt_result
        )
        if single_pair_trades is not None and len(single_pair_trades) > 0:
            pairs_trades.append(single_pair_trades)

    # pd.concat cannot be called on an empty list: bail out early.
    if not pairs_trades:
        print("No trading signals found for any pairs")
        return bt_result

    result = pd.concat(pairs_trades, ignore_index=True)
    result["time"] = pd.to_datetime(result["time"])
    # A stable sort keeps trades that share a timestamp in their concatenation
    # (per-pair, per-leg) order; the default quicksort may shuffle them.
    result = result.set_index("time").sort_index(kind="stable")

    bt_result.collect_single_day_results(result)
    return bt_result
|
|
|
|
|
|
def _parse_args() -> argparse.Namespace:
    """Build the backtest CLI parser and parse the process arguments."""
    parser = argparse.ArgumentParser(description="Run pairs trading backtest.")
    parser.add_argument(
        "--config", type=str, required=True, help="Path to the configuration file."
    )
    parser.add_argument(
        "--datafiles",
        type=str,
        required=False,
        help="Comma-separated list of data files (overrides config). No wildcards supported.",
    )
    parser.add_argument(
        "--instruments",
        type=str,
        required=False,
        help="Comma-separated list of instrument symbols (e.g., COIN,GBTC). If not provided, auto-detects from database.",
    )
    parser.add_argument(
        "--result_db",
        type=str,
        required=True,
        help="Path to SQLite database for storing results. Use 'NONE' to disable database output.",
    )
    return parser.parse_args()


def _instantiate_fit_method(fit_method_class_name: Any) -> "PairsTradingFitMethod":
    """
    Import and instantiate a fit-method class from its dotted path.

    Args:
        fit_method_class_name: "package.module.ClassName" as read from config.

    Raises:
        ValueError: if the name is missing or has no module part.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not isinstance(fit_method_class_name, str) or "." not in fit_method_class_name:
        raise ValueError(
            "config must define 'fit_method_class' as 'module.ClassName', "
            f"got {fit_method_class_name!r}"
        )
    module_name, class_name = fit_method_class_name.rsplit(".", 1)
    module = importlib.import_module(module_name)
    return getattr(module, class_name)()


def _parse_instrument_list(raw: str) -> List[str]:
    """Split a comma-separated symbol list, trimming whitespace and dropping empty entries."""
    return [inst.strip() for inst in raw.split(",") if inst.strip()]


def main() -> None:
    """
    CLI entry point: run the pairs-trading backtest over one or more data files.

    Loads the config, instantiates the configured fit method, resolves the
    data files (CLI ``--datafiles`` takes priority over config), runs
    ``run_strategy`` per file, optionally persists config and results to the
    SQLite database given by ``--result_db`` (``NONE`` disables it), and
    finally prints aggregated totals and outstanding positions.
    """
    import traceback

    args = _parse_args()

    config: Dict = load_config(args.config)

    # Dynamically instantiate fit method class
    fit_method_class_name = config.get("fit_method_class", None)
    fit_method = _instantiate_fit_method(fit_method_class_name)

    # Resolve data files (CLI takes priority over config)
    datafiles = resolve_datafiles(config, args.datafiles)
    if not datafiles:
        print("No data files found to process.")
        return

    print(f"Found {len(datafiles)} data files to process:")
    for path in datafiles:
        print(f" - {path}")

    use_db = args.result_db.upper() != "NONE"
    cli_instruments: Optional[List[str]] = (
        _parse_instrument_list(args.instruments) if args.instruments else None
    )
    # Per-file cache so the DB is queried at most once per data file, even
    # though instruments are needed both for the config record and the run.
    detected_instruments: Dict[str, List[str]] = {}

    def instruments_for(datafile: str) -> List[str]:
        """Return a fresh list of instruments for ``datafile`` (CLI list or auto-detected)."""
        if cli_instruments is not None:
            return list(cli_instruments)
        if datafile not in detected_instruments:
            detected_instruments[datafile] = get_available_instruments_from_db(
                datafile, config
            )
        return list(detected_instruments[datafile])

    if use_db:
        create_result_database(args.result_db)

        # Union of instruments over all files, de-duplicated in first-seen order.
        all_instruments: List[str] = []
        for datafile in datafiles:
            all_instruments.extend(instruments_for(datafile))

        store_config_in_database(
            db_path=args.result_db,
            config_file_path=args.config,
            config=config,
            fit_method_class=fit_method_class_name,
            datafiles=datafiles,
            instruments=list(dict.fromkeys(all_instruments)),
        )

    price_column = config["price_column"]
    # Trade results keyed by data-file basename.
    all_results: Dict[str, Dict[str, Any]] = {}

    for datafile in datafiles:
        filename = os.path.basename(datafile)
        print(f"\n====== Processing {filename} ======")

        instruments = instruments_for(datafile)
        if cli_instruments is not None:
            print(f"Using CLI-specified instruments: {instruments}")
        else:
            print(f"Auto-detected instruments: {instruments}")

        if not instruments:
            print(f"No instruments found for {datafile}, skipping...")
            continue

        try:
            fit_method.reset()

            bt_results = run_strategy(
                config=config,
                datafile=datafile,
                price_column=price_column,
                fit_method=fit_method,
                instruments=instruments,
            )

            # Results are keyed by basename; surface collisions instead of
            # silently discarding an earlier file's trades.
            if filename in all_results:
                print(
                    f"Warning: results for {datafile} replace earlier results "
                    f"stored under the same file name {filename}"
                )
            all_results[filename] = {"trades": bt_results.trades.copy()}

            if use_db:
                store_results_in_database(args.result_db, datafile, bt_results)

            print(f"Successfully processed {filename}")

        except Exception as err:
            # Top-level boundary: report this file's failure and keep going.
            print(f"Error processing {datafile}: {str(err)}")
            traceback.print_exc()

    # Calculate and print results using a new BacktestResult instance for aggregation
    if all_results:
        aggregate_bt_results = BacktestResult(config=config)
        aggregate_bt_results.calculate_returns(all_results)
        aggregate_bt_results.print_grand_totals()
        aggregate_bt_results.print_outstanding_positions()

        if use_db:
            print(f"\nResults stored in database: {args.result_db}")
    else:
        print("No results to display.")
|
|
|
|
|
if __name__ == "__main__":
    # main() is a plain (synchronous) function; wrapping it in asyncio.run()
    # would run the backtest and then fail with
    # "ValueError: a coroutine was expected, got None".
    main()
|