from __future__ import annotations import os from abc import ABC, abstractmethod from enum import Enum from typing import Any, Dict, Generator, List, Optional, Type, cast import pandas as pd from pt_strategy.model_data_policy import ModelDataPolicy from pt_strategy.pt_market_data import PtMarketData from pt_strategy.pt_model import Prediction from pt_strategy.results import ( PairResearchResult, create_result_database, store_config_in_database, ) from pt_strategy.trading_pair import PairState, TradingPair from tools.filetools import resolve_datafiles from tools.instruments import get_instruments class PtResearchStrategy: config_: Dict[str, Any] trading_pair_: TradingPair model_data_policy_: ModelDataPolicy pt_mkt_data_: PtMarketData trades_: List[pd.DataFrame] def __init__( self, config: Dict[str, Any], datafiles: List[str], instruments: List[Dict[str, str]], ): from pt_strategy.model_data_policy import ModelDataPolicy from pt_strategy.pt_market_data import PtMarketData, ResearchMarketData from pt_strategy.trading_pair import TradingPair self.config_ = config self.trades_ = [] self.trading_pair_ = TradingPair(config=config, instruments=instruments) self.model_data_policy_ = ModelDataPolicy.create(config) import copy # modified config must be passed to PtMarketData config_copy = copy.deepcopy(config) config_copy["instruments"] = instruments config_copy["datafiles"] = datafiles self.pt_mkt_data_ = PtMarketData.create( config=config_copy, md_class=ResearchMarketData ) self.pt_mkt_data_.load() def outstanding_positions(self) -> List[Dict[str, Any]]: return list(self.trading_pair_.user_data_.get("outstanding_positions", [])) def run(self) -> None: training_minutes = self.config_.get("training_minutes", 120) market_data_series: pd.Series market_data_df = pd.DataFrame() idx = 0 while self.pt_mkt_data_.has_next(): market_data_series = self.pt_mkt_data_.get_next() market_data_df = pd.concat( [market_data_df, market_data_series.to_frame().T], ignore_index=True ) if idx >= training_minutes: break idx += 1 assert idx >= training_minutes, "Not enough training data" while self.pt_mkt_data_.has_next(): market_data_series = self.pt_mkt_data_.get_next() new_row = market_data_series.to_frame().T market_data_df = pd.concat([market_data_df, new_row], ignore_index=True) prediction = self.trading_pair_.run( market_data_df, self.model_data_policy_.advance() ) assert prediction is not None trades = self._create_trades( prediction=prediction, last_row=market_data_df.iloc[-1] ) if trades is not None: self.trades_.append(trades) trades = self._handle_outstanding_positions() if trades is not None: self.trades_.append(trades) def _create_trades( self, prediction: Prediction, last_row: pd.Series ) -> Optional[pd.DataFrame]: pair = self.trading_pair_ trades = None open_threshold = self.config_["dis-equilibrium_open_trshld"] close_threshold = self.config_["dis-equilibrium_close_trshld"] scaled_disequilibrium = prediction.scaled_disequilibrium_ abs_scaled_disequilibrium = abs(scaled_disequilibrium) if pair.user_data_["state"] in [ PairState.INITIAL, PairState.CLOSE, PairState.CLOSE_POSITION, PairState.CLOSE_STOP_LOSS, PairState.CLOSE_STOP_PROFIT, ]: if abs_scaled_disequilibrium >= open_threshold: trades = self._create_open_trades( pair, row=last_row, prediction=prediction ) if trades is not None: trades["status"] = PairState.OPEN.name print(f"OPEN TRADES:\n{trades}") pair.user_data_["state"] = PairState.OPEN pair.on_open_trades(trades) elif pair.user_data_["state"] == PairState.OPEN: if abs_scaled_disequilibrium <= close_threshold: trades = self._create_close_trades( pair, row=last_row, prediction=prediction ) if trades is not None: trades["status"] = PairState.CLOSE.name print(f"CLOSE TRADES:\n{trades}") pair.user_data_["state"] = PairState.CLOSE pair.on_close_trades(trades) elif pair.to_stop_close_conditions(predicted_row=last_row): trades = self._create_close_trades(pair, row=last_row) if trades is not None: trades["status"] = pair.user_data_["stop_close_state"].name print(f"STOP CLOSE TRADES:\n{trades}") pair.user_data_["state"] = pair.user_data_["stop_close_state"] pair.on_close_trades(trades) return trades def _handle_outstanding_positions(self) -> Optional[pd.DataFrame]: trades = None pair = self.trading_pair_ # Outstanding positions if pair.user_data_["state"] == PairState.OPEN: print(f"{pair}: *** Position is NOT CLOSED. ***") # outstanding positions if self.config_["close_outstanding_positions"]: close_position_row = pd.Series(pair.market_data_.iloc[-2]) # close_position_row["disequilibrium"] = 0.0 # close_position_row["scaled_disequilibrium"] = 0.0 # close_position_row["signed_scaled_disequilibrium"] = 0.0 trades = self._create_close_trades( pair=pair, row=close_position_row, prediction=None ) if trades is not None: trades["status"] = PairState.CLOSE_POSITION.name print(f"CLOSE_POSITION TRADES:\n{trades}") pair.user_data_["state"] = PairState.CLOSE_POSITION pair.on_close_trades(trades) else: pair.add_outstanding_position( symbol=pair.symbol_a_, open_side=pair.user_data_["open_side_a"], open_px=pair.user_data_["open_px_a"], open_tstamp=pair.user_data_["open_tstamp"], last_mkt_data_row=pair.market_data_.iloc[-1], ) pair.add_outstanding_position( symbol=pair.symbol_b_, open_side=pair.user_data_["open_side_b"], open_px=pair.user_data_["open_px_b"], open_tstamp=pair.user_data_["open_tstamp"], last_mkt_data_row=pair.market_data_.iloc[-1], ) return trades def _trades_df(self) -> pd.DataFrame: types = { "time": "datetime64[ns]", "action": "string", "symbol": "string", "side": "string", "price": "float64", "disequilibrium": "float64", "scaled_disequilibrium": "float64", "signed_scaled_disequilibrium": "float64", # "pair": "object", } columns = list(types.keys()) return pd.DataFrame(columns=columns).astype(types) def _create_open_trades( self, pair: TradingPair, row: pd.Series, prediction: Prediction ) -> Optional[pd.DataFrame]: colname_a, colname_b = pair.exec_prices_colnames() tstamp = row["tstamp"] diseqlbrm = prediction.disequilibrium_ scaled_disequilibrium = prediction.scaled_disequilibrium_ px_a = row[f"{colname_a}"] px_b = row[f"{colname_b}"] # creating the trades df = self._trades_df() print(f"OPEN_TRADES: {row["tstamp"]} {scaled_disequilibrium=}") if diseqlbrm > 0: side_a = "SELL" side_b = "BUY" else: side_a = "BUY" side_b = "SELL" # save closing sides pair.user_data_["open_side_a"] = side_a # used in oustanding positions pair.user_data_["open_side_b"] = side_b pair.user_data_["open_px_a"] = px_a pair.user_data_["open_px_b"] = px_b pair.user_data_["open_tstamp"] = tstamp pair.user_data_["close_side_a"] = side_b # used for closing trades pair.user_data_["close_side_b"] = side_a # create opening trades df.loc[len(df)] = { "time": tstamp, "symbol": pair.symbol_a_, "side": side_a, "action": "OPEN", "price": px_a, "disequilibrium": diseqlbrm, "signed_scaled_disequilibrium": scaled_disequilibrium, "scaled_disequilibrium": abs(scaled_disequilibrium), # "pair": pair, } df.loc[len(df)] = { "time": tstamp, "symbol": pair.symbol_b_, "side": side_b, "action": "OPEN", "price": px_b, "disequilibrium": diseqlbrm, "scaled_disequilibrium": abs(scaled_disequilibrium), "signed_scaled_disequilibrium": scaled_disequilibrium, # "pair": pair, } return df def _create_close_trades( self, pair: TradingPair, row: pd.Series, prediction: Optional[Prediction] = None ) -> Optional[pd.DataFrame]: colname_a, colname_b = pair.exec_prices_colnames() tstamp = row["tstamp"] if prediction is not None: diseqlbrm = prediction.disequilibrium_ signed_scaled_disequilibrium = prediction.scaled_disequilibrium_ scaled_disequilibrium = abs(prediction.scaled_disequilibrium_) else: diseqlbrm = 0.0 signed_scaled_disequilibrium = 0.0 scaled_disequilibrium = 0.0 px_a = row[f"{colname_a}"] px_b = row[f"{colname_b}"] # creating the trades df = self._trades_df() # create opening trades df.loc[len(df)] = { "time": tstamp, "symbol": pair.symbol_a_, "side": pair.user_data_["close_side_a"], "action": "CLOSE", "price": px_a, "disequilibrium": diseqlbrm, "scaled_disequilibrium": scaled_disequilibrium, "signed_scaled_disequilibrium": signed_scaled_disequilibrium, # "pair": pair, } df.loc[len(df)] = { "time": tstamp, "symbol": pair.symbol_b_, "side": pair.user_data_["close_side_b"], "action": "CLOSE", "price": px_b, "disequilibrium": diseqlbrm, "scaled_disequilibrium": scaled_disequilibrium, "signed_scaled_disequilibrium": signed_scaled_disequilibrium, # "pair": pair, } del pair.user_data_["close_side_a"] del pair.user_data_["close_side_b"] del pair.user_data_["open_tstamp"] del pair.user_data_["open_px_a"] del pair.user_data_["open_px_b"] del pair.user_data_["open_side_a"] del pair.user_data_["open_side_b"] return df def day_trades(self) -> pd.DataFrame: return pd.concat(self.trades_, ignore_index=True) def main() -> None: import argparse from tools.config import expand_filename, load_config parser = argparse.ArgumentParser(description="Run pairs trading backtest.") parser.add_argument( "--config", type=str, required=True, help="Path to the configuration file." ) parser.add_argument( "--date_pattern", type=str, required=True, help="Date YYYYMMDD, allows * and ? wildcards", ) parser.add_argument( "--instruments", type=str, required=True, help="Comma-separated list of instrument symbols (e.g., COIN:EQUITY,GBTC:CRYPTO)", ) parser.add_argument( "--result_db", type=str, required=True, help="Path to SQLite database for storing results. Use 'NONE' to disable database output.", ) args = parser.parse_args() config: Dict = load_config(args.config) # Resolve data files (CLI takes priority over config) instruments = get_instruments(args, config) datafiles = resolve_datafiles(config, args.date_pattern, instruments) days = list(set([day for day, _ in datafiles])) print(f"Found {len(datafiles)} data files to process:") for df in datafiles: print(f" - {df}") # Create result database if needed if args.result_db.upper() != "NONE": args.result_db = expand_filename(args.result_db) create_result_database(args.result_db) # Initialize a dictionary to store all trade results all_results: Dict[str, Dict[str, Any]] = {} is_config_stored = False # Process each data file results = PairResearchResult(config=config) for day in sorted(days): md_datafiles = [datafile for md_day, datafile in datafiles if md_day == day] if not all([os.path.exists(datafile) for datafile in md_datafiles]): print(f"WARNING: insufficient data files: {md_datafiles}") continue print(f"\n====== Processing {day} ======") if not is_config_stored: store_config_in_database( db_path=args.result_db, config_file_path=args.config, config=config, datafiles=datafiles, instruments=instruments, ) is_config_stored = True pt_strategy = PtResearchStrategy( config=config, datafiles=md_datafiles, instruments=instruments ) pt_strategy.run() results.add_day_results( day=day, trades=pt_strategy.day_trades(), outstanding_positions=pt_strategy.outstanding_positions(), ) # ADD RESULTS ANALYSIS results.calculate_returns() results.print_single_day_results() # Store results with day name as key # filename = os.path.basename(day) # all_results[filename] = { # "trades": pt_strategy.trades_.copy(), # "outstanding_positions": pt_strategy.outstanding_positions_.copy(), # } # print(f"Successfully processed {filename}") results.calculate_returns() results.print_grand_totals() results.print_outstanding_positions() if args.result_db.upper() != "NONE": print(f"\nResults stored in database: {args.result_db}") else: print("No results to display.") if __name__ == "__main__": main()