pairs_trading/lib/pt_strategy/trading_strategy.py
2025-07-30 17:08:06 +00:00

405 lines
14 KiB
Python

from __future__ import annotations
import os
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Dict, Generator, List, Optional, Type, cast
import pandas as pd
from pt_strategy.model_data_policy import ModelDataPolicy
from pt_strategy.pt_market_data import PtMarketData
from pt_strategy.pt_model import Prediction
from pt_strategy.results import (
PairResearchResult,
create_result_database,
store_config_in_database,
)
from pt_strategy.trading_pair import PairState, TradingPair
from tools.filetools import resolve_datafiles
from tools.instruments import get_instruments
class PtResearchStrategy:
    """Replay research market data through one trading pair and collect trades.

    ``main`` builds one instance per trading day.

    - The constructor loads the market data up front.
    - :meth:`run` walks the data row by row and appends every generated
      trades DataFrame to ``trades_``.
    - :meth:`day_trades` returns those trades concatenated.
    """

    config_: Dict[str, Any]
    trading_pair_: TradingPair
    model_data_policy_: ModelDataPolicy
    pt_mkt_data_: PtMarketData
    trades_: List[pd.DataFrame]  # one frame per open/close event, in order

    def __init__(
        self,
        config: Dict[str, Any],
        datafiles: List[str],
        instruments: List[Dict[str, str]],
    ):
        """Create the trading pair, model-data policy and market-data feed.

        Args:
            config: Strategy configuration. It is not mutated; a deep copy is
                passed to the market-data feed.
            datafiles: Market-data files to replay.
            instruments: Instrument descriptors for the pair.

        The market data is loaded eagerly, at construction time.
        """
        import copy

        # NOTE(review): ResearchMarketData is only imported locally, presumably
        # to avoid an import cycle — confirm before hoisting it.
        from pt_strategy.pt_market_data import ResearchMarketData

        self.config_ = config
        self.trades_ = []
        self.trading_pair_ = TradingPair(config=config, instruments=instruments)
        self.model_data_policy_ = ModelDataPolicy.create(config)

        # PtMarketData must receive a config that carries the instruments and
        # datafiles; use a deep copy so the caller's dict stays unchanged.
        config_copy = copy.deepcopy(config)
        config_copy["instruments"] = instruments
        config_copy["datafiles"] = datafiles
        self.pt_mkt_data_ = PtMarketData.create(
            config=config_copy, md_class=ResearchMarketData
        )
        self.pt_mkt_data_.load()

    def outstanding_positions(self) -> List[Dict[str, Any]]:
        """Return a shallow copy of the pair's recorded outstanding positions.

        Returns an empty list if ``user_data_`` has no
        ``"outstanding_positions"`` entry.
        """
        return list(self.trading_pair_.user_data_.get("outstanding_positions", []))

    @staticmethod
    def _append_row(df: pd.DataFrame, row: pd.Series) -> pd.DataFrame:
        """Return ``df`` with ``row`` appended as its last row (index reset)."""
        return pd.concat([df, pd.DataFrame([row])], ignore_index=True)

    def run(self) -> None:
        """Replay the loaded market data and record the resulting trades.

        The run has three phases:

        1. The first rows only warm up the accumulated frame. The count comes
           from ``training_minutes`` in the config (default 120).
        2. Every later row is fed, together with the accumulated frame, to
           ``trading_pair_.run``. The returned prediction drives
           :meth:`_create_trades`.
        3. Once the data is exhausted, a still-open position is passed to
           :meth:`_handle_outstanding_positions`.

        Raises:
            AssertionError: Fewer than ``training_minutes`` rows are
                available, or the pair returns no prediction (only when
                assertions are enabled).
        """
        training_minutes = self.config_.get("training_minutes", 120)
        market_data_df = pd.DataFrame()

        # Warm-up phase.
        # NOTE(review): the break happens *after* appending, so this consumes
        # training_minutes + 1 rows when enough data exists — confirm intended.
        idx = 0
        while self.pt_mkt_data_.has_next():
            market_data_df = self._append_row(
                market_data_df, self.pt_mkt_data_.get_next()
            )
            if idx >= training_minutes:
                break
            idx += 1
        assert idx >= training_minutes, "Not enough training data"

        # Trading phase: one prediction (and possibly one trade set) per row.
        # The whole accumulated frame is re-sent each step, so per-row cost
        # grows with the number of rows seen so far.
        while self.pt_mkt_data_.has_next():
            market_data_df = self._append_row(
                market_data_df, self.pt_mkt_data_.get_next()
            )
            prediction = self.trading_pair_.run(
                market_data_df, self.model_data_policy_.advance()
            )
            assert prediction is not None
            trades = self._create_trades(
                prediction=prediction, last_row=market_data_df.iloc[-1]
            )
            if trades is not None:
                self.trades_.append(trades)

        trades = self._handle_outstanding_positions()
        if trades is not None:
            self.trades_.append(trades)

    def _create_trades(
        self, prediction: Prediction, last_row: pd.Series
    ) -> Optional[pd.DataFrame]:
        """Advance the pair's position state machine by one bar.

        - **Flat** (INITIAL or any CLOSE* state): open when
          ``|scaled disequilibrium| >= config["dis-equilibrium_open_trshld"]``.
        - **OPEN**: close when
          ``|scaled disequilibrium| <= config["dis-equilibrium_close_trshld"]``.
        - **OPEN, stop condition**: otherwise, if
          ``pair.to_stop_close_conditions`` is true, close with the state
          stored in ``user_data_["stop_close_state"]``.

        Returns:
            The trades created on this bar, with a ``status`` column, or
            ``None`` if nothing was traded.
        """
        pair = self.trading_pair_
        trades = None
        open_threshold = self.config_["dis-equilibrium_open_trshld"]
        close_threshold = self.config_["dis-equilibrium_close_trshld"]
        abs_scaled_disequilibrium = abs(prediction.scaled_disequilibrium_)
        state = pair.user_data_["state"]

        flat_states = [
            PairState.INITIAL,
            PairState.CLOSE,
            PairState.CLOSE_POSITION,
            PairState.CLOSE_STOP_LOSS,
            PairState.CLOSE_STOP_PROFIT,
        ]
        if state in flat_states:
            if abs_scaled_disequilibrium >= open_threshold:
                trades = self._create_open_trades(
                    pair, row=last_row, prediction=prediction
                )
                if trades is not None:
                    trades["status"] = PairState.OPEN.name
                    print(f"OPEN TRADES:\n{trades}")
                    pair.user_data_["state"] = PairState.OPEN
                    pair.on_open_trades(trades)
        elif state == PairState.OPEN:
            if abs_scaled_disequilibrium <= close_threshold:
                trades = self._create_close_trades(
                    pair, row=last_row, prediction=prediction
                )
                if trades is not None:
                    trades["status"] = PairState.CLOSE.name
                    print(f"CLOSE TRADES:\n{trades}")
                    pair.user_data_["state"] = PairState.CLOSE
                    pair.on_close_trades(trades)
            elif pair.to_stop_close_conditions(predicted_row=last_row):
                # NOTE(review): unlike the regular close, no prediction is passed,
                # so stop-close trades record zero disequilibrium — confirm intended.
                trades = self._create_close_trades(pair, row=last_row)
                if trades is not None:
                    trades["status"] = pair.user_data_["stop_close_state"].name
                    print(f"STOP CLOSE TRADES:\n{trades}")
                    pair.user_data_["state"] = pair.user_data_["stop_close_state"]
                    pair.on_close_trades(trades)
        return trades

    def _handle_outstanding_positions(self) -> Optional[pd.DataFrame]:
        """Deal with a position that is still OPEN after the last bar.

        - If ``config["close_outstanding_positions"]`` is truthy, the position
          is closed with zero disequilibrium. The CLOSE_POSITION trades are
          returned.
        - Otherwise both legs are recorded with
          ``pair.add_outstanding_position`` and ``None`` is returned.
        - If the pair is not OPEN, nothing happens and ``None`` is returned.
        """
        trades = None
        pair = self.trading_pair_
        if pair.user_data_["state"] == PairState.OPEN:
            print(f"{pair}: *** Position is NOT CLOSED. ***")
            if self.config_["close_outstanding_positions"]:
                # NOTE(review): closes at the second-to-last row, not the last —
                # confirm why; this needs at least two rows of market data.
                close_position_row = pd.Series(pair.market_data_.iloc[-2])
                trades = self._create_close_trades(
                    pair=pair, row=close_position_row, prediction=None
                )
                if trades is not None:
                    trades["status"] = PairState.CLOSE_POSITION.name
                    print(f"CLOSE_POSITION TRADES:\n{trades}")
                    pair.user_data_["state"] = PairState.CLOSE_POSITION
                    pair.on_close_trades(trades)
            else:
                for symbol, side_key, px_key in (
                    (pair.symbol_a_, "open_side_a", "open_px_a"),
                    (pair.symbol_b_, "open_side_b", "open_px_b"),
                ):
                    pair.add_outstanding_position(
                        symbol=symbol,
                        open_side=pair.user_data_[side_key],
                        open_px=pair.user_data_[px_key],
                        open_tstamp=pair.user_data_["open_tstamp"],
                        last_mkt_data_row=pair.market_data_.iloc[-1],
                    )
        return trades

    def _trades_df(self) -> pd.DataFrame:
        """Return an empty trades DataFrame with the standard columns and dtypes."""
        types = {
            "time": "datetime64[ns]",
            "action": "string",
            "symbol": "string",
            "side": "string",
            "price": "float64",
            "disequilibrium": "float64",
            "scaled_disequilibrium": "float64",
            "signed_scaled_disequilibrium": "float64",
        }
        return pd.DataFrame(columns=list(types)).astype(types)

    def _create_open_trades(
        self, pair: TradingPair, row: pd.Series, prediction: Prediction
    ) -> Optional[pd.DataFrame]:
        """Build the two OPEN trades for ``pair`` at ``row``.

        Sides depend on the sign of the disequilibrium:

        - positive: SELL leg A and BUY leg B;
        - otherwise: BUY leg A and SELL leg B.

        The open sides, prices and timestamp, plus the opposite (closing)
        sides, are stored in ``pair.user_data_`` for the later close or
        outstanding-position handling.
        """
        colname_a, colname_b = pair.exec_prices_colnames()
        tstamp = row["tstamp"]
        diseqlbrm = prediction.disequilibrium_
        scaled_disequilibrium = prediction.scaled_disequilibrium_
        px_a = row[colname_a]
        px_b = row[colname_b]

        df = self._trades_df()
        # Use the bound `tstamp` rather than nesting quotes inside the f-string:
        # reusing the same quote type there is a SyntaxError before Python 3.12.
        print(f"OPEN_TRADES: {tstamp} {scaled_disequilibrium=}")
        if diseqlbrm > 0:
            side_a, side_b = "SELL", "BUY"
        else:
            side_a, side_b = "BUY", "SELL"

        # Opening state, used for outstanding positions.
        pair.user_data_["open_side_a"] = side_a
        pair.user_data_["open_side_b"] = side_b
        pair.user_data_["open_px_a"] = px_a
        pair.user_data_["open_px_b"] = px_b
        pair.user_data_["open_tstamp"] = tstamp
        # Closing sides are the reverse of the opening sides.
        pair.user_data_["close_side_a"] = side_b
        pair.user_data_["close_side_b"] = side_a

        for symbol, side, px in (
            (pair.symbol_a_, side_a, px_a),
            (pair.symbol_b_, side_b, px_b),
        ):
            df.loc[len(df)] = {
                "time": tstamp,
                "symbol": symbol,
                "side": side,
                "action": "OPEN",
                "price": px,
                "disequilibrium": diseqlbrm,
                "scaled_disequilibrium": abs(scaled_disequilibrium),
                "signed_scaled_disequilibrium": scaled_disequilibrium,
            }
        return df

    def _create_close_trades(
        self, pair: TradingPair, row: pd.Series, prediction: Optional[Prediction] = None
    ) -> Optional[pd.DataFrame]:
        """Build the two CLOSE trades for ``pair`` at ``row``.

        Sides come from the ``close_side_a`` / ``close_side_b`` values stored
        by :meth:`_create_open_trades`. When ``prediction`` is ``None``, all
        disequilibrium fields are recorded as 0.0.

        The open/close bookkeeping keys are then removed from
        ``pair.user_data_``. A ``KeyError`` is raised if no open position was
        recorded.
        """
        colname_a, colname_b = pair.exec_prices_colnames()
        tstamp = row["tstamp"]
        if prediction is not None:
            diseqlbrm = prediction.disequilibrium_
            signed_scaled_disequilibrium = prediction.scaled_disequilibrium_
            scaled_disequilibrium = abs(prediction.scaled_disequilibrium_)
        else:
            diseqlbrm = 0.0
            signed_scaled_disequilibrium = 0.0
            scaled_disequilibrium = 0.0
        px_a = row[colname_a]
        px_b = row[colname_b]

        # Create the closing trades.
        df = self._trades_df()
        for symbol, side, px in (
            (pair.symbol_a_, pair.user_data_["close_side_a"], px_a),
            (pair.symbol_b_, pair.user_data_["close_side_b"], px_b),
        ):
            df.loc[len(df)] = {
                "time": tstamp,
                "symbol": symbol,
                "side": side,
                "action": "CLOSE",
                "price": px,
                "disequilibrium": diseqlbrm,
                "scaled_disequilibrium": scaled_disequilibrium,
                "signed_scaled_disequilibrium": signed_scaled_disequilibrium,
            }

        # The position is flat again: drop its bookkeeping.
        for key in (
            "close_side_a",
            "close_side_b",
            "open_tstamp",
            "open_px_a",
            "open_px_b",
            "open_side_a",
            "open_side_b",
        ):
            del pair.user_data_[key]
        return df

    def day_trades(self) -> pd.DataFrame:
        """Return all trades recorded by :meth:`run` as one DataFrame.

        If no trades were made, an empty frame with the standard trade columns
        plus ``status`` is returned. ``pd.concat`` would otherwise raise
        ``ValueError`` on the empty list and abort a multi-day backtest.
        """
        if not self.trades_:
            empty = self._trades_df()
            empty["status"] = pd.Series(dtype="string")
            return empty
        return pd.concat(self.trades_, ignore_index=True)
def main() -> None:
    """Command-line entry point: backtest a pair over every matching day.

    Data files are resolved as ``(day, path)`` pairs. For each day whose files
    all exist:

    - a fresh :class:`PtResearchStrategy` is run;
    - its trades and outstanding positions are added to a shared
      :class:`PairResearchResult`.

    Pair performance is analysed once all days are processed.

    Unless ``--result_db`` is ``NONE`` (case-insensitive), the result database
    is created and the configuration is stored in it once. With ``NONE``, no
    database call is made at all.
    """
    import argparse

    from tools.config import expand_filename, load_config

    parser = argparse.ArgumentParser(description="Run pairs trading backtest.")
    parser.add_argument(
        "--config", type=str, required=True, help="Path to the configuration file."
    )
    parser.add_argument(
        "--date_pattern",
        type=str,
        required=True,
        help="Date YYYYMMDD, allows * and ? wildcards",
    )
    parser.add_argument(
        "--instruments",
        type=str,
        required=True,
        help="Comma-separated list of instrument symbols (e.g., COIN:EQUITY,GBTC:CRYPTO)",
    )
    parser.add_argument(
        "--result_db",
        type=str,
        required=True,
        help="Path to SQLite database for storing results. Use 'NONE' to disable database output.",
    )
    args = parser.parse_args()
    use_db = args.result_db.upper() != "NONE"

    config: Dict = load_config(args.config)
    # Resolve instruments and data files (CLI takes priority over config).
    instruments = get_instruments(args, config)
    datafiles = resolve_datafiles(config, args.date_pattern, instruments)
    days = sorted({day for day, _ in datafiles})
    print(f"Found {len(datafiles)} data files to process:")
    for datafile in datafiles:
        print(f" - {datafile}")

    if use_db:
        args.result_db = expand_filename(args.result_db)
        create_result_database(args.result_db)

    is_config_stored = False
    results = PairResearchResult(config=config)
    for day in days:
        md_datafiles = [datafile for md_day, datafile in datafiles if md_day == day]
        if not all(os.path.exists(datafile) for datafile in md_datafiles):
            print(f"WARNING: insufficient data files: {md_datafiles}")
            continue
        print(f"\n====== Processing {day} ======")
        # Store the configuration once, and only when database output is
        # enabled (previously this also ran with db_path="NONE").
        if use_db and not is_config_stored:
            store_config_in_database(
                db_path=args.result_db,
                config_file_path=args.config,
                config=config,
                datafiles=datafiles,
                instruments=instruments,
            )
            is_config_stored = True

        pt_strategy = PtResearchStrategy(
            config=config, datafiles=md_datafiles, instruments=instruments
        )
        pt_strategy.run()
        results.add_day_results(
            day=day,
            trades=pt_strategy.day_trades(),
            outstanding_positions=pt_strategy.outstanding_positions(),
        )

    results.analyze_pair_performance()
    if use_db:
        print(f"\nResults stored in database: {args.result_db}")
    else:
        print("No results to display.")
if __name__ == "__main__":
    # Run the backtest CLI only when this module is executed as a script.
    main()