pairs_trading/research/cointegration_test.py
Oleg Sheynin c2f701e3a2 progress
2025-07-25 20:20:23 +00:00

127 lines
4.1 KiB
Python

import argparse
import glob
import importlib
import os
from datetime import date, datetime
from typing import Any, Dict, List, Optional
import pandas as pd
from tools.config import expand_filename, load_config
from tools.data_loader import get_available_instruments_from_db
from pt_trading.results import (
BacktestResult,
create_result_database,
store_config_in_database,
store_results_in_database,
)
from pt_trading.fit_method import PairsTradingFitMethod
from pt_trading.trading_pair import TradingPair
from research.research_tools import create_pairs, resolve_datafiles
def main() -> None:
    """Run a cointegration check for every instrument pair in one data file.

    Command-line arguments:
        --config       Path to the configuration file (required).
        --datafile     Market data file to process (optional; handed to
                       ``resolve_datafiles`` together with the config).
        --instruments  Comma-separated instrument symbols (optional). When
                       omitted, instruments are obtained from
                       ``get_available_instruments_from_db``.

    Only the first data file returned by ``resolve_datafiles`` is processed.
    The per-pair results of ``cointegration_check()`` are concatenated into a
    single DataFrame and printed without row/column truncation. Exceptions
    raised while building or printing that table are reported together with
    a traceback instead of being propagated.
    """
    parser = argparse.ArgumentParser(
        description="Run cointegration checks on pairs of instruments."
    )
    parser.add_argument(
        "--config", type=str, required=True, help="Path to the configuration file."
    )
    parser.add_argument(
        "--datafile",
        type=str,
        required=False,
        help="Market data file to process.",
    )
    parser.add_argument(
        "--instruments",
        type=str,
        required=False,
        help="Comma-separated list of instrument symbols (e.g., COIN,GBTC). If not provided, auto-detects from database.",
    )
    args = parser.parse_args()

    config: Dict = load_config(args.config)

    # Resolve data files from the CLI argument and the config. Check for an
    # empty result BEFORE indexing, otherwise an empty list raises IndexError
    # and the friendly message below is never shown.
    datafiles = resolve_datafiles(config, args.datafile)
    if not datafiles or not datafiles[0]:
        print("No data files found to process.")
        return
    datafile = datafiles[0]
    print(f"Found data file to process: {datafile}")
    if len(datafiles) > 1:
        # This script handles a single file; make the truncation visible.
        print(
            f"Note: {len(datafiles)} data files resolved; only the first one is processed."
        )

    # TODO: result-database persistence (create_result_database /
    # store_config_in_database) is not wired into this script; there is no
    # --result_db argument yet.

    stat_model_price = config["stat_model_price"]

    print(f"\n====== Processing {os.path.basename(datafile)} ======")

    # Determine instruments to use
    if args.instruments:
        # Use CLI-specified instruments; drop empty tokens produced by stray
        # or trailing commas (e.g. "COIN,,GBTC,").
        instruments = [
            symbol
            for symbol in (part.strip() for part in args.instruments.split(","))
            if symbol
        ]
        print(f"Using CLI-specified instruments: {instruments}")
    else:
        # Auto-detect instruments from database
        instruments = get_available_instruments_from_db(datafile, config)
        print(f"Auto-detected instruments: {instruments}")

    if not instruments:
        print(f"No instruments found in {datafile}...")
        return

    # Process data for this file
    try:
        # Collect the per-pair frames and concatenate once: calling
        # pd.concat inside the loop re-copies the accumulated rows on every
        # iteration.
        frames: List[pd.DataFrame] = []
        for pair in create_pairs(datafile, stat_model_price, config, instruments):
            check_result = pair.cointegration_check()
            if check_result is not None:
                frames.append(check_result)

        # pd.concat raises on an empty list, so fall back to an empty frame.
        cointegration_data: pd.DataFrame = (
            pd.concat(frames) if frames else pd.DataFrame()
        )

        # Scope the display settings to this print instead of mutating the
        # process-wide pandas options.
        with pd.option_context(
            "display.width", 400,
            "display.max_colwidth", None,
            "display.max_rows", None,
            "display.max_columns", None,
        ):
            print(f"cointegration_data:\n{cointegration_data}")

    except Exception as err:
        # Top-level boundary of a CLI tool: report the failure with full
        # context rather than crashing with a bare stack trace.
        print(f"Error processing {datafile}: {str(err)}")
        import traceback

        traceback.print_exc()
# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()