refactoring the code for pairs, models, strategy

This commit is contained in:
Oleg Sheynin 2025-07-30 04:08:02 +00:00
parent 71822c64b0
commit ed0c0fecb2
18 changed files with 3004 additions and 287 deletions

View File

@ -0,0 +1,43 @@
{
"market_data_loading": {
"CRYPTO": {
"data_directory": "./data/crypto",
"db_table_name": "md_1min_bars",
"instrument_id_pfx": "PAIR-",
},
"EQUITY": {
"data_directory": "./data/equity",
"db_table_name": "md_1min_bars",
"instrument_id_pfx": "STOCK-",
}
},
# ====== Funding ======
"funding_per_pair": 2000.0,
# ====== Trading Parameters ======
"stat_model_price": "close",
"execution_price": {
"column": "vwap",
"shift": 1,
},
"dis-equilibrium_open_trshld": 2.0,
"dis-equilibrium_close_trshld": 0.5,
"training_size": 120,
"model_class": "pt_strategy.models.ZScoreOLSModel",
"model_data_policy_class": "pt_strategy.model_data_policy.RollingWindowDataPolicy",
# ====== Stop Conditions ======
"stop_close_conditions": {
"profit": 2.0,
"loss": -0.5
}
# ====== End of Session Closeout ======
# "close_outstanding_positions": true,
"close_outstanding_positions": false,
"trading_hours": {
"timezone": "America/New_York",
"begin_session": "7:30:00",
"end_session": "18:30:00",
}
}

View File

@ -23,7 +23,8 @@
}, },
"dis-equilibrium_open_trshld": 2.0, "dis-equilibrium_open_trshld": 2.0,
"dis-equilibrium_close_trshld": 1.0, "dis-equilibrium_close_trshld": 1.0,
"training_minutes": 120, "training_minutes": 120, # TODO Remove this
"training_size": 120,
"fit_method_class": "pt_trading.vecm_rolling_fit.VECMRollingFit", "fit_method_class": "pt_trading.vecm_rolling_fit.VECMRollingFit",
# ====== Stop Conditions ====== # ====== Stop Conditions ======
@ -37,7 +38,7 @@
# "close_outstanding_positions": false, # "close_outstanding_positions": false,
"trading_hours": { "trading_hours": {
"timezone": "America/New_York", "timezone": "America/New_York",
"begin_session": "9:30:00", "begin_session": "7:30:00",
"end_session": "18:30:00", "end_session": "18:30:00",
} }
} }

View File

@ -16,13 +16,14 @@
"funding_per_pair": 2000.0, "funding_per_pair": 2000.0,
# ====== Trading Parameters ====== # ====== Trading Parameters ======
"stat_model_price": "close", "stat_model_price": "close",
"execution_price": { # "execution_price": {
"column": "vwap", # "column": "vwap",
"shift": 1, # "shift": 1,
}, # },
"dis-equilibrium_open_trshld": 2.0, "dis-equilibrium_open_trshld": 2.0,
"dis-equilibrium_close_trshld": 0.5, "dis-equilibrium_close_trshld": 0.5,
"training_minutes": 120, "training_minutes": 120, # TODO Remove this
"training_size": 120,
"fit_method_class": "pt_trading.z-score_rolling_fit.ZScoreRollingFit", "fit_method_class": "pt_trading.z-score_rolling_fit.ZScoreRollingFit",
# ====== Stop Conditions ====== # ====== Stop Conditions ======
@ -36,7 +37,7 @@
# "close_outstanding_positions": false, # "close_outstanding_positions": false,
"trading_hours": { "trading_hours": {
"timezone": "America/New_York", "timezone": "America/New_York",
"begin_session": "9:30:00", "begin_session": "7:30:00",
"end_session": "18:30:00", "end_session": "18:30:00",
} }
} }

View File

@ -0,0 +1,43 @@
{
"market_data_loading": {
"CRYPTO": {
"data_directory": "./data/crypto",
"db_table_name": "md_1min_bars",
"instrument_id_pfx": "PAIR-",
},
"EQUITY": {
"data_directory": "./data/equity",
"db_table_name": "md_1min_bars",
"instrument_id_pfx": "STOCK-",
}
},
# ====== Funding ======
"funding_per_pair": 2000.0,
# ====== Trading Parameters ======
"stat_model_price": "close",
"execution_price": {
"column": "vwap",
"shift": 1,
},
"dis-equilibrium_open_trshld": 2.0,
"dis-equilibrium_close_trshld": 0.5,
"training_minutes": 120, # TODO Remove this
"training_size": 120,
"fit_method_class": "pt_trading.z-score_rolling_fit.ZScoreRollingFit",
# ====== Stop Conditions ======
"stop_close_conditions": {
"profit": 2.0,
"loss": -0.5
}
# ====== End of Session Closeout ======
"close_outstanding_positions": true,
# "close_outstanding_positions": false,
"trading_hours": {
"timezone": "America/New_York",
"begin_session": "9:30:00",
"end_session": "18:30:00",
}
}

View File

@ -0,0 +1,60 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Optional, cast, Generator, List
@dataclass
class DataParams:
    """Mutable description of the training window a model should use.

    The data policies in this module update an instance in place and hand
    the same object back from ``advance()``.
    """

    # Size of the training window (presumably a row/bar count -- confirm
    # against the consumer of this value).
    training_size: int
    # Index at which the training window starts.
    training_start_index: int
class ModelDataPolicy(ABC):
    """Base class for policies that produce the training-window parameters.

    Concrete policies implement ``advance()`` to update and return the
    ``DataParams`` held in ``current_data_params_``.
    """

    config_: Dict[str, Any]
    current_data_params_: DataParams

    def __init__(self, config: Dict[str, Any]):
        """Keep ``config`` and start with a window beginning at index 0.

        The initial window size is ``config["training_size"]``, or 120 when
        that key is absent.
        """
        initial_size = config.get("training_size", 120)
        self.config_ = config
        self.current_data_params_ = DataParams(
            training_start_index=0,
            training_size=initial_size,
        )

    @abstractmethod
    def advance(self) -> DataParams:
        """Update the data parameters for the next step and return them."""
        ...

    @staticmethod
    def create(config: Dict[str, Any]) -> ModelDataPolicy:
        """Build the policy named by ``config["model_data_policy_class"]``.

        The value is a dotted ``module.ClassName`` path; the module is
        imported and the class is called with ``config=config``.

        Raises:
            AssertionError: if the key is missing from ``config``.
        """
        from importlib import import_module

        dotted_path = config.get("model_data_policy_class", None)
        assert dotted_path is not None

        module_path, attr_name = dotted_path.rsplit(".", 1)
        policy_cls = getattr(import_module(module_path), attr_name)
        return cast(ModelDataPolicy, policy_cls(config=config))
class RollingWindowDataPolicy(ModelDataPolicy):
    """Policy that moves the window start forward by one on every step.

    ``training_size`` is left untouched, so the window keeps its length.
    """

    def __init__(self, config: Dict[str, Any]):
        """Initialise the base policy and a step counter starting at 1."""
        super().__init__(config)
        self.count_ = 1

    def advance(self) -> DataParams:
        """Shift the window start by one and return the shared parameters.

        The current step counter is written to stdout followed by a carriage
        return (so successive values overwrite each other on a terminal),
        then incremented.
        """
        params = self.current_data_params_
        params.training_start_index += 1

        print(self.count_, end="\r")
        self.count_ += 1
        return params
class ExpandingWindowDataPolicy(ModelDataPolicy):
    """Policy that grows the window by one on every step.

    ``training_start_index`` is left untouched, so the window start stays
    fixed while its size increases.
    """

    def __init__(self, config: Dict[str, Any]):
        """Initialise the base policy; no extra state is kept."""
        super().__init__(config)

    def advance(self) -> DataParams:
        """Increase ``training_size`` by one and return the shared parameters."""
        params = self.current_data_params_
        params.training_size += 1
        return params

53
lib/pt_strategy/models.py Normal file
View File

@ -0,0 +1,53 @@
from __future__ import annotations
from typing import Optional
import pandas as pd
import statsmodels.api as sm
from pt_strategy.pt_model import PairsTradingModel, Prediction
from pt_strategy.trading_pair import TradingPair
class ZScoreOLSModel(PairsTradingModel):
    """Pairs model that z-scores the OLS-hedged spread between two symbols."""

    zscore_model_: Optional[sm.regression.linear_model.RegressionResultsWrapper]
    pair_predict_result_: Optional[pd.DataFrame]
    zscore_df_: Optional[pd.DataFrame]

    def _fit_zscore(self, pair: TradingPair) -> pd.DataFrame:
        """Fit ``A ~ const + B`` on ``training_df_`` and return the spread z-score.

        The returned frame has a single column (label ``0``) holding
        ``(spread - mean(spread)) / std(spread)`` where
        ``spread = A - hedge_ratio * B`` and ``hedge_ratio`` is the fitted
        slope.
        """
        assert self.training_df_ is not None

        pair_prices = self.training_df_[pair.colnames()]
        px_a = pair_prices.iloc[:, 0].astype(float)
        px_b = pair_prices.iloc[:, 1].astype(float)
        px_a, px_b = px_a.align(px_b, axis=0)

        regressors = sm.add_constant(px_b)
        self.zscore_model_ = sm.OLS(px_a, regressors).fit()
        assert self.zscore_model_ is not None
        # params[0] is the constant added above, params[1] the slope on B.
        beta = self.zscore_model_.params.iloc[1]

        spread = px_a - beta * px_b
        centered = spread - spread.mean()
        return pd.DataFrame(centered / spread.std())

    def predict(self, pair: TradingPair) -> Prediction:
        """Refit on a copy of ``pair.market_data_`` and report the latest z-score.

        The z-score is stored in both the ``dis-equilibrium`` and
        ``scaled_dis-equilibrium`` columns of ``training_df_``; the last row
        of each feeds the returned ``Prediction``.
        """
        self.training_df_ = pair.market_data_.copy()
        zscore_df = self._fit_zscore(pair=pair)
        assert zscore_df is not None

        # The z-score serves as both the raw and the scaled dis-equilibrium.
        zscore = zscore_df[0]
        frame = self.training_df_
        frame["dis-equilibrium"] = zscore
        frame["scaled_dis-equilibrium"] = zscore

        return Prediction(
            tstamp_=pair.market_data_.index[-1],
            disequilibrium_=frame["dis-equilibrium"].iloc[-1],
            scaled_disequilibrium_=frame["scaled_dis-equilibrium"].iloc[-1],
            pair_=pair,
        )

View File

@ -0,0 +1,185 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Type
import pandas as pd
from tools.data_loader import load_market_data
class PtMarketData(ABC):
    """Abstract source of market-data rows for a trading pair.

    Subclasses load data with ``load()`` and expose it one row at a time
    through ``has_next()`` / ``get_next()``.
    """

    config_: Dict[str, Any]
    origin_mkt_data_df_: pd.DataFrame
    market_data_df_: pd.DataFrame

    def __init__(self, config: Dict[str, Any]):
        """Remember ``config`` and start with an empty raw-data frame."""
        self.origin_mkt_data_df_ = pd.DataFrame()
        self.config_ = config

    @abstractmethod
    def load(self) -> None:
        """Load the market data described by the configuration."""
        ...

    @abstractmethod
    def has_next(self) -> bool:
        """Return True while ``get_next()`` can still produce a row."""
        ...

    @abstractmethod
    def get_next(self) -> pd.Series:
        """Return the next market-data row."""
        ...

    @staticmethod
    def create(config: Dict[str, Any], md_class: Type[PtMarketData]) -> PtMarketData:
        """Instantiate ``md_class`` with ``config`` and return the instance."""
        instance = md_class(config)
        return instance
class ResearchMarketData(PtMarketData):
    """Replay pre-loaded bars for one instrument pair, one row at a time.

    ``load()`` reads every configured data file, pivots the per-symbol rows
    into one wide row per timestamp and derives ``exec_price_<symbol>``
    columns; ``has_next()`` / ``get_next()`` then walk the resulting frame.
    """

    config_: Dict[str, Any]
    current_index_: int
    is_execution_price_: bool

    def __init__(self, config: Dict[str, Any]):
        """Read the optional ``execution_price`` section of ``config``.

        When the section is absent the execution column is ``None`` and the
        shift is 0.
        """
        super().__init__(config)
        self.current_index_ = 0
        # Start with an empty frame so has_next() answers False instead of
        # raising AttributeError when it is called before load().
        self.market_data_df_ = pd.DataFrame()

        self.is_execution_price_ = "execution_price" in self.config_
        if self.is_execution_price_:
            self.execution_price_column_ = self.config_["execution_price"]["column"]
            self.execution_price_shift_ = self.config_["execution_price"]["shift"]
        else:
            self.execution_price_column_ = None
            self.execution_price_shift_ = 0

    def has_next(self) -> bool:
        """Return True while unread rows remain in ``market_data_df_``."""
        return self.current_index_ < len(self.market_data_df_)

    def get_next(self) -> pd.Series:
        """Return the row at the cursor (positional) and advance the cursor."""
        result = self.market_data_df_.iloc[self.current_index_]
        self.current_index_ += 1
        return result

    def load(self) -> None:
        """Load all ``datafiles`` for the configured pair and build the frame.

        Raises:
            AssertionError: if fewer than two instruments or no data files
                are configured.
        """
        datafiles: List[str] = self.config_.get("datafiles", [])
        instruments: List[Dict[str, str]] = self.config_.get("instruments", [])
        # instruments[0] and instruments[1] are both read below, so a single
        # entry is not enough.
        assert len(instruments) >= 2, "A pair needs two instruments in config"
        assert len(datafiles) > 0, "No datafiles found in config"

        self.symbol_a_ = instruments[0]["symbol"]
        self.symbol_b_ = instruments[1]["symbol"]
        self.stat_model_price_ = self.config_["stat_model_price"]

        # The table name is looked up by the first instrument's type.
        instrument_type = instruments[0]["instrument_type"]
        db_table_name = self.config_["market_data_loading"][instrument_type]["db_table_name"]
        trading_hours = self.config_["trading_hours"]
        # The execution-price shift is passed to the loader as extra minutes.
        extra_minutes: int = self.execution_price_shift_

        # Collect every frame first and concatenate once instead of
        # re-copying the accumulated data on each iteration.
        frames = [self.origin_mkt_data_df_]
        frames.extend(
            load_market_data(
                datafile=datafile,
                instruments=instruments,
                db_table_name=db_table_name,
                trading_hours=trading_hours,
                extra_minutes=extra_minutes,
            )
            for datafile in datafiles
        )
        self.origin_mkt_data_df_ = pd.concat(frames)

        self._set_market_data()

    def _set_market_data(self) -> None:
        """Build ``market_data_df_`` from the raw frame, sorted by ``tstamp``."""
        wanted = ["tstamp"] + self.colnames()
        if self.is_execution_price_:
            wanted += self.orig_exec_prices_colnames()
        # When the execution column equals the model-price column the two
        # name lists coincide; drop repeats (order preserved) so the selection
        # does not create duplicate column labels.
        wanted = list(dict.fromkeys(wanted))

        wide_df = self._transform_dataframe(self.origin_mkt_data_df_)
        market_df = pd.DataFrame(wide_df[wanted]).dropna().reset_index(drop=True)
        market_df["tstamp"] = pd.to_datetime(market_df["tstamp"])
        self.market_data_df_ = market_df.sort_values("tstamp")

        self._set_execution_price_data()

    def _transform_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
        """Pivot long ``(tstamp, symbol, price...)`` rows into one row per tstamp.

        Each symbol contributes a ``<price>_<symbol>`` column per selected
        price field. Rows with any missing value are dropped from the result.
        """
        value_columns: List[str] = [self.stat_model_price_]
        if self.is_execution_price_:
            execution_price_column = self.config_["execution_price"]["column"]
            # Avoid selecting the same source column twice.
            if execution_price_column not in value_columns:
                value_columns.append(execution_price_column)

        df_selected = pd.DataFrame(df[["tstamp", "symbol", *value_columns]])
        result_df = (
            pd.DataFrame(df_selected["tstamp"]).drop_duplicates().reset_index(drop=True)
        )

        for symbol in df_selected["symbol"].unique():
            df_symbol = df_selected[df_selected["symbol"] == symbol].reset_index(
                drop=True
            )
            # Columns are named like "close_COIN".
            symbol_columns = {"tstamp": df_symbol["tstamp"]}
            for column in value_columns:
                symbol_columns[f"{column}_{symbol}"] = df_symbol[column]

            # Left-join so timestamps missing for this symbol stay as NaN;
            # dropping inside the loop would let an irrelevant symbol shrink
            # the dataset.
            result_df = pd.merge(
                result_df, pd.DataFrame(symbol_columns), on="tstamp", how="left"
            )
            result_df = result_df.reset_index(drop=True)

        # NOTE(review): this final dropna() still removes rows where ANY
        # symbol is missing, which seems at odds with the intent stated in the
        # loop -- confirm the loader only returns the pair's two symbols.
        return result_df.dropna()

    def _set_execution_price_data(self) -> None:
        """Add ``exec_price_<symbol>`` columns to ``market_data_df_``.

        Without an ``execution_price`` config section the model-price columns
        are copied. Otherwise the configured column is shifted up by
        ``shift`` rows and rows left without a value are dropped.
        """
        if "execution_price" not in self.config_:
            for symbol in (self.symbol_a_, self.symbol_b_):
                self.market_data_df_[f"exec_price_{symbol}"] = self.market_data_df_[
                    f"{self.stat_model_price_}_{symbol}"
                ]
            return

        execution_price_column = self.config_["execution_price"]["column"]
        execution_price_shift = self.config_["execution_price"]["shift"]
        for symbol in (self.symbol_a_, self.symbol_b_):
            # A negative shift pulls later rows up: row i receives the value
            # from row i + shift.
            self.market_data_df_[f"exec_price_{symbol}"] = self.market_data_df_[
                f"{execution_price_column}_{symbol}"
            ].shift(-execution_price_shift)
        self.market_data_df_ = self.market_data_df_.dropna().reset_index(drop=True)

    def colnames(self) -> List[str]:
        """Return the model-price column names for symbols A and B."""
        return [
            f"{self.stat_model_price_}_{self.symbol_a_}",
            f"{self.stat_model_price_}_{self.symbol_b_}",
        ]

    def orig_exec_prices_colnames(self) -> List[str]:
        """Return the unshifted execution-price column names for A and B."""
        return [
            f"{self.execution_price_column_}_{self.symbol_a_}",
            f"{self.execution_price_column_}_{self.symbol_b_}",
        ]

    def exec_prices_colnames(self) -> List[str]:
        """Return the derived ``exec_price_*`` column names for A and B."""
        return [
            f"exec_price_{self.symbol_a_}",
            f"exec_price_{self.symbol_b_}",
        ]

View File

@ -0,0 +1,48 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Optional, cast, Generator, List
import pandas as pd
from pt_strategy.trading_pair import TradingPair
@dataclass
class Prediction:
    """One model output for a trading pair at a given timestamp."""

    tstamp_: pd.Timestamp
    disequilibrium_: float
    scaled_disequilibrium_: float
    pair_: TradingPair

    def to_dict(self) -> Dict[str, Any]:
        """Return the prediction as a plain dict.

        ``signed_scaled_disequilibrium`` carries the stored scaled value,
        while ``scaled_disequilibrium`` is its absolute value.
        """
        signed_scaled = self.scaled_disequilibrium_
        record: Dict[str, Any] = {}
        record["tstamp"] = self.tstamp_
        record["disequilibrium"] = self.disequilibrium_
        record["signed_scaled_disequilibrium"] = signed_scaled
        record["scaled_disequilibrium"] = abs(signed_scaled)
        record["pair"] = self.pair_
        return record

    def to_pd_series(self) -> pd.Series:
        """Return ``to_dict()`` as the first row of a one-row DataFrame."""
        single_row = pd.DataFrame([self.to_dict()])
        return single_row.iloc[0]
class PairsTradingModel(ABC):
    """Interface for models that produce a ``Prediction`` for a pair."""

    @abstractmethod
    def predict(self, pair: TradingPair) -> Prediction:
        """Return the model's prediction for ``pair``."""
        ...

    @staticmethod
    def create(config: Dict[str, Any]) -> PairsTradingModel:
        """Build the model named by ``config["model_class"]``.

        The value is a dotted ``module.ClassName`` path; the module is
        imported and the class is called with no arguments.

        Raises:
            AssertionError: if the key is missing from ``config``.
        """
        from importlib import import_module

        dotted_path = config.get("model_class", None)
        assert dotted_path is not None

        module_path, attr_name = dotted_path.rsplit(".", 1)
        model_cls = getattr(import_module(module_path), attr_name)
        return cast(PairsTradingModel, model_cls())

580
lib/pt_strategy/results.py Normal file
View File

@ -0,0 +1,580 @@
import os
import sqlite3
from datetime import date, datetime
from typing import Any, Dict, List, Optional, Tuple
import pandas as pd
from pt_trading.trading_pair import TradingPair
# ISO 8601 adapters/converters recommended as replacements for the sqlite3
# defaults on Python 3.12+.
# Recipe: https://docs.python.org/3/library/sqlite3.html#sqlite3-adapter-converter-recipes
def adapt_date_iso(val: date) -> str:
    """Serialise a ``datetime.date`` as an ISO 8601 string."""
    iso_text = val.isoformat()
    return iso_text


def adapt_datetime_iso(val: datetime) -> str:
    """Serialise a ``datetime.datetime`` as an ISO 8601 string."""
    iso_text = val.isoformat()
    return iso_text


def convert_date(val: bytes) -> date:
    """Parse ISO 8601 bytes and return the date part."""
    parsed = datetime.fromisoformat(val.decode())
    return parsed.date()


def convert_datetime(val: bytes) -> datetime:
    """Parse ISO 8601 bytes into a ``datetime.datetime``."""
    text = val.decode()
    return datetime.fromisoformat(text)


# Install the adapters (Python -> SQLite) and converters (SQLite -> Python).
sqlite3.register_adapter(date, adapt_date_iso)
sqlite3.register_adapter(datetime, adapt_datetime_iso)
sqlite3.register_converter("date", convert_date)
sqlite3.register_converter("datetime", convert_datetime)
def create_result_database(db_path: str) -> None:
    """Create the SQLite result database and reset its tables.

    The parent directory of ``db_path`` is created when missing. The tables
    ``pt_bt_results``, ``outstanding_positions`` and ``config`` are created
    if absent, and ALL existing rows in them are deleted, so each call
    starts from empty tables.

    Raises:
        Exception: any error is printed and then re-raised unchanged.
    """
    try:
        # Create directory if it doesn't exist
        db_dir = os.path.dirname(db_path)
        if db_dir and not os.path.exists(db_dir):
            os.makedirs(db_dir, exist_ok=True)
            print(f"Created directory: {db_dir}")

        conn = sqlite3.connect(db_path)
        # Close the connection even when a statement fails; previously an
        # error left the handle (and its file lock) open.
        try:
            cursor = conn.cursor()

            # Completed round-trip trades
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS pt_bt_results (
                    date DATE,
                    pair TEXT,
                    symbol TEXT,
                    open_time DATETIME,
                    open_side TEXT,
                    open_price REAL,
                    open_quantity INTEGER,
                    open_disequilibrium REAL,
                    close_time DATETIME,
                    close_side TEXT,
                    close_price REAL,
                    close_quantity INTEGER,
                    close_disequilibrium REAL,
                    symbol_return REAL,
                    pair_return REAL,
                    close_condition TEXT
                )
                """
            )
            cursor.execute("DELETE FROM pt_bt_results;")

            # Positions still open at the end of a run
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS outstanding_positions (
                    date DATE,
                    pair TEXT,
                    symbol TEXT,
                    position_quantity REAL,
                    last_price REAL,
                    unrealized_return REAL,
                    open_price REAL,
                    open_side TEXT
                )
                """
            )
            cursor.execute("DELETE FROM outstanding_positions;")

            # Configuration JSON stored for reference
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS config (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    run_timestamp DATETIME,
                    config_file_path TEXT,
                    config_json TEXT,
                    datafiles TEXT,
                    instruments TEXT
                )
                """
            )
            cursor.execute("DELETE FROM config;")

            conn.commit()
        finally:
            conn.close()
    except Exception as e:
        print(f"Error creating result database: {str(e)}")
        raise
def store_config_in_database(
    db_path: str,
    config_file_path: str,
    config: Dict,
    datafiles: List[Tuple[str, str]],
    instruments: List[Dict[str, str]],
) -> None:
    """Insert one row describing this run into the ``config`` table.

    Args:
        db_path: SQLite file to write to; the literal ``"NONE"`` (any case)
            disables storage and the function returns immediately.
        config_file_path: Path stored verbatim in the row.
        config: Configuration serialised to JSON (non-JSON values are
            stringified).
        datafiles: Pairs whose second element is stored, comma-separated.
        instruments: Dicts providing ``symbol``, ``instrument_type`` and
            ``exchange_id``; stored as ``symbol:type:exchange`` entries.

    Failures are best-effort: the error and traceback are printed and the
    function returns without raising.
    """
    import json

    if db_path.upper() == "NONE":
        return

    try:
        conn = sqlite3.connect(db_path)
        # Release the connection even if the insert fails.
        try:
            cursor = conn.cursor()

            config_json = json.dumps(config, indent=2, default=str)

            # Lists are flattened to comma-separated strings for storage.
            datafiles_str = ", ".join(f"{datafile}" for _, datafile in datafiles)
            instruments_str = ", ".join(
                f"{inst['symbol']}:{inst['instrument_type']}:{inst['exchange_id']}"
                for inst in instruments
            )

            cursor.execute(
                """
                INSERT INTO config (
                    run_timestamp, config_file_path, config_json, datafiles, instruments
                ) VALUES (?, ?, ?, ?, ?)
                """,
                (
                    datetime.now(),
                    config_file_path,
                    config_json,
                    datafiles_str,
                    instruments_str,
                ),
            )

            conn.commit()
        finally:
            conn.close()

        print("Configuration stored in database")

    except Exception as e:
        # Deliberately non-fatal: storing the config is for reference only.
        print(f"Error storing configuration in database: {str(e)}")
        import traceback

        traceback.print_exc()
def convert_timestamp(timestamp: Any) -> Optional[datetime]:
    """Convert a timestamp-like value to a Python ``datetime`` for SQLite.

    Accepted inputs:
        * ``None`` or ``pd.NaT`` -> ``None``
        * ``pd.Timestamp`` -> ``to_pydatetime()``
        * ``datetime`` -> returned unchanged
        * ``date`` -> midnight of that day
        * ``str`` -> ``"%Y-%m-%d %H:%M:%S"`` first, then any ISO 8601 form
          accepted by ``datetime.fromisoformat``
        * ``int`` / ``float`` -> epoch seconds via ``datetime.fromtimestamp``

    Raises:
        ValueError: for unsupported types or unparseable strings.
    """
    # pd.NaT passes the isinstance(..., datetime) check below, so it must be
    # filtered out first or it would be handed to SQLite as-is.
    if timestamp is None or timestamp is pd.NaT:
        return None
    # pd.Timestamp subclasses datetime, so it has to be tested first.
    if isinstance(timestamp, pd.Timestamp):
        return timestamp.to_pydatetime()
    if isinstance(timestamp, datetime):
        return timestamp
    if isinstance(timestamp, date):
        return datetime.combine(timestamp, datetime.min.time())
    if isinstance(timestamp, str):
        try:
            return datetime.strptime(timestamp, "%Y-%m-%d %H:%M:%S")
        except ValueError:
            # Fall back to ISO 8601 ("T" separator, fractional seconds, ...).
            return datetime.fromisoformat(timestamp)
    if isinstance(timestamp, (int, float)):
        return datetime.fromtimestamp(timestamp)
    raise ValueError(f"Unsupported timestamp type: {type(timestamp)}")
DayT = str
TradeT = Dict[str, Any]
OutstandingPositionT = List[Dict[str, Any]]


class PairResearchResult:
    """Collect and report research results for a single pair over many days.

    Simplified version of BacktestResult focused on single pair analysis.
    Day trades are stored as the DataFrames passed to ``add_day_results``;
    every consumer converts them to a list of per-trade dicts (one dict per
    row) before processing.
    """

    trades_: Dict[DayT, pd.DataFrame]
    outstanding_positions_: Dict[DayT, OutstandingPositionT]

    def __init__(self, config: Dict[str, Any]) -> None:
        """Start with no trades, no positions and zero realized PnL."""
        self.config_ = config
        self.trades_ = {}
        self.outstanding_positions_ = {}
        self.total_realized_pnl = 0.0
        self.symbol_roundtrip_trades_: Dict[str, List[Dict[str, Any]]] = {}

    def add_day_results(
        self,
        day: DayT,
        trades: pd.DataFrame,
        outstanding_positions: List[Dict[str, Any]],
    ) -> None:
        """Store one day's trades and outstanding positions under ``day``."""
        assert isinstance(trades, pd.DataFrame)
        self.trades_[day] = trades
        self.outstanding_positions_[day] = outstanding_positions

    @staticmethod
    def _trade_records(day_trades: Any) -> List[TradeT]:
        """Return a day's trades as a list of dicts, one per trade.

        Iterating a DataFrame yields its column labels, not its rows, so a
        DataFrame is converted with ``to_dict("records")``; any other
        iterable is copied into a list.
        """
        if isinstance(day_trades, pd.DataFrame):
            return day_trades.to_dict("records")
        return list(day_trades)

    @staticmethod
    def _has_timestamp(trade: TradeT) -> bool:
        """Return True when the trade carries a usable ``timestamp``."""
        stamp = trade["timestamp"]
        # Missing datetimes come out of a DataFrame as NaT rather than None.
        return stamp is not None and not pd.isna(stamp)

    @staticmethod
    def _symbol_return(
        trade1_side: str, trade1_px: float, trade2_side: str, trade2_px: float
    ) -> float:
        """Return the percent return of an open/close leg, 0.0 if not a round trip."""
        if trade1_side == "BUY" and trade2_side == "SELL":
            return (trade2_px - trade1_px) / trade1_px * 100
        if trade1_side == "SELL" and trade2_side == "BUY":
            return (trade1_px - trade2_px) / trade1_px * 100
        return 0.0

    @property
    def all_trades(self) -> List[TradeT]:
        """Get all trades across all days as a flat list of trade dicts."""
        all_trades_list: List[TradeT] = []
        for day_trades in self.trades_.values():
            all_trades_list.extend(self._trade_records(day_trades))
        return all_trades_list

    @property
    def outstanding_positions(self) -> OutstandingPositionT:
        """Get all outstanding positions across all days as a flat list."""
        all_positions: OutstandingPositionT = []
        for day_positions in self.outstanding_positions_.values():
            all_positions.extend(day_positions)
        return all_positions

    def calculate_returns(self) -> None:
        """Sum every round-trip ``symbol_return`` into ``total_realized_pnl``."""
        roundtrip_trades = self.extract_roundtrip_trades()
        self.total_realized_pnl = 0.0
        for day_trades in roundtrip_trades.values():
            for trade in day_trades:
                self.total_realized_pnl += trade["symbol_return"]

    def extract_roundtrip_trades(self) -> Dict[str, List[Dict[str, Any]]]:
        """Group each day's trades into completed round trips.

        Trades are sorted by timestamp (stable, missing timestamps first) and
        consumed in groups of four: open A, open B, close A, close B. Groups
        that do not match that OPEN/CLOSE pattern are skipped, as are days
        with fewer than four trades.

        Returns:
            Mapping of day -> list of round-trip records, two per group (one
            per symbol). Days without any round trip are omitted.
        """
        roundtrip_trades_by_day: Dict[str, List[Dict[str, Any]]] = {}

        def _sort_key(trade: TradeT) -> Any:
            return trade["timestamp"] if self._has_timestamp(trade) else pd.Timestamp.min

        def _roundtrip(open_trade: TradeT, close_trade: TradeT, symbol_return: float,
                       pair_return: float, funding_per_position: float) -> Dict[str, Any]:
            return {
                "symbol": open_trade["symbol"],
                "open_side": open_trade["side"],
                "open_price": open_trade["price"],
                "open_time": open_trade["timestamp"],
                "close_side": close_trade["side"],
                "close_price": close_trade["price"],
                "close_time": close_trade["timestamp"],
                "symbol_return": symbol_return,
                "pair_return": pair_return,
                "shares": funding_per_position / open_trade["price"],
                "close_condition": close_trade.get("status", "UNKNOWN"),
                "open_disequilibrium": open_trade.get("disequilibrium"),
                "close_disequilibrium": close_trade.get("disequilibrium"),
            }

        for day, day_trades in self.trades_.items():
            records = self._trade_records(day_trades)
            if len(records) < 4:
                continue

            sorted_trades = sorted(records, key=_sort_key)
            day_roundtrips: List[Dict[str, Any]] = []

            # Walk complete groups of 4 only; a trailing partial group is ignored.
            for idx in range(0, len(sorted_trades) - 3, 4):
                trade_a_1, trade_b_1, trade_a_2, trade_b_2 = sorted_trades[idx:idx + 4]

                # Validate trade sequence
                if not (trade_a_1["action"] == "OPEN" and trade_a_2["action"] == "CLOSE"):
                    continue
                if not (trade_b_1["action"] == "OPEN" and trade_b_2["action"] == "CLOSE"):
                    continue

                symbol_a_return = self._symbol_return(
                    trade_a_1["side"], trade_a_1["price"],
                    trade_a_2["side"], trade_a_2["price"],
                )
                symbol_b_return = self._symbol_return(
                    trade_b_1["side"], trade_b_1["price"],
                    trade_b_2["side"], trade_b_2["price"],
                )
                pair_return = symbol_a_return + symbol_b_return

                # Funding is split evenly between the two legs.
                funding_per_position = self.config_.get("funding_per_pair", 10000) / 2

                day_roundtrips.append(
                    _roundtrip(trade_a_1, trade_a_2, symbol_a_return, pair_return, funding_per_position)
                )
                day_roundtrips.append(
                    _roundtrip(trade_b_1, trade_b_2, symbol_b_return, pair_return, funding_per_position)
                )

            if day_roundtrips:
                roundtrip_trades_by_day[day] = day_roundtrips

        return roundtrip_trades_by_day

    @staticmethod
    def _print_roundtrip_leg(trade: Dict[str, Any]) -> None:
        """Print one symbol's open -> close line of a round trip."""
        print(f" {trade['symbol']}: {trade['open_side']} @ ${trade['open_price']:.2f}"
              f" -> {trade['close_side']} @ ${trade['close_price']:.2f} | "
              f"Return: {trade['symbol_return']:+.2f}% | Shares: {trade['shares']:.2f}")

    def print_returns_by_day(self) -> None:
        """Print every round trip grouped by day, with daily and overall totals."""
        roundtrip_trades = self.extract_roundtrip_trades()

        if not roundtrip_trades:
            print("\n====== NO ROUND-TRIP TRADES FOUND ======")
            return

        print("\n====== PAIR RESEARCH RETURNS BY DAY ======")
        total_return_all_days = 0.0

        for day in sorted(roundtrip_trades.keys()):
            day_trades = roundtrip_trades[day]
            print(f"\n--- {day} ---")

            day_total_return = 0.0
            pair_returns = []

            # Consecutive records (A, B) belong to the same pair trade.
            for i in range(0, len(day_trades) - 1, 2):
                trade_a = day_trades[i]
                trade_b = day_trades[i + 1]

                print(f" {trade_a['open_time'].time()}-{trade_a['close_time'].time()}")
                self._print_roundtrip_leg(trade_a)
                self._print_roundtrip_leg(trade_b)

                # Show disequilibrium info only when both values are present,
                # since None cannot be formatted as a float.
                open_dis = trade_a.get("open_disequilibrium")
                close_dis = trade_a.get("close_disequilibrium")
                if open_dis is not None and close_dis is not None:
                    print(f" Disequilibrium: Open: {open_dis:.4f}, "
                          f"Close: {close_dis:.4f}")

                pair_return = trade_a["pair_return"]
                print(f" Pair Return: {pair_return:+.2f}% | Close Condition: {trade_a['close_condition']}")
                print()

                pair_returns.append(pair_return)
                day_total_return += pair_return

            print(f" Day Total Return: {day_total_return:+.2f}% ({len(pair_returns)} pairs)")
            total_return_all_days += day_total_return

        print("\n====== TOTAL RETURN ACROSS ALL DAYS ======")
        print(f"Total Return: {total_return_all_days:+.2f}%")
        print(f"Total Days: {len(roundtrip_trades)}")
        if len(roundtrip_trades) > 0:
            print(f"Average Daily Return: {total_return_all_days / len(roundtrip_trades):+.2f}%")

    def get_return_summary(self) -> Dict[str, Any]:
        """Return aggregate return metrics across all days.

        Keys: ``total_return``, ``total_days``, ``total_pairs``,
        ``average_daily_return``, ``best_day``, ``worst_day`` (each a
        ``(day, data)`` tuple or ``None``) and ``daily_returns``.
        """
        roundtrip_trades = self.extract_roundtrip_trades()

        if not roundtrip_trades:
            return {
                "total_return": 0.0,
                "total_days": 0,
                "total_pairs": 0,
                "average_daily_return": 0.0,
                "best_day": None,
                "worst_day": None,
                "daily_returns": {}
            }

        daily_returns: Dict[str, Dict[str, Any]] = {}
        total_return = 0.0
        total_pairs = 0

        for day, day_trades in roundtrip_trades.items():
            day_return = 0.0
            day_pairs = len(day_trades) // 2  # Each pair has 2 symbol trades

            for trade in day_trades:
                day_return += trade["symbol_return"]

            daily_returns[day] = {
                "return": day_return,
                "pairs": day_pairs
            }
            total_return += day_return
            total_pairs += day_pairs

        best_day = max(daily_returns.items(), key=lambda x: x[1]["return"]) if daily_returns else None
        worst_day = min(daily_returns.items(), key=lambda x: x[1]["return"]) if daily_returns else None

        return {
            "total_return": total_return,
            "total_days": len(roundtrip_trades),
            "total_pairs": total_pairs,
            "average_daily_return": total_return / len(roundtrip_trades) if roundtrip_trades else 0.0,
            "best_day": best_day,
            "worst_day": worst_day,
            "daily_returns": daily_returns
        }

    def print_single_day_results(self) -> None:
        """Print every stored trade, grouped by the calendar date of its timestamp."""
        all_trades_list = self.all_trades
        if not all_trades_list:
            print("No trades found.")
            return

        print(f"Total trades processed: {len(all_trades_list)}")

        # Group trades by day; trades without a timestamp are left out.
        trades_by_day: Dict[date, List[TradeT]] = {}
        for trade in all_trades_list:
            if self._has_timestamp(trade):
                trades_by_day.setdefault(trade["timestamp"].date(), []).append(trade)

        for day, day_trades in sorted(trades_by_day.items()):
            print(f"\n--- {day} ---")
            for trade in day_trades:
                print(f" {trade['timestamp'].time() if self._has_timestamp(trade) else 'N/A'}: "
                      f"{trade['symbol']} {trade['side']} {trade['action']} @ ${trade['price']:.2f} "
                      f"({trade['status']})")

    def print_grand_totals(self) -> None:
        """Print overall totals and refresh ``total_realized_pnl`` from them."""
        summary = self.get_return_summary()

        print("\n====== PAIR RESEARCH GRAND TOTALS ======")
        print(f"Total Return: {summary['total_return']:+.2f}%")
        print(f"Total Days Traded: {summary['total_days']}")
        print(f"Total Pair Trades: {summary['total_pairs']}")

        if summary["total_days"] > 0:
            print(f"Average Daily Return: {summary['average_daily_return']:+.2f}%")

        if summary["best_day"]:
            best_day, best_data = summary["best_day"]
            print(f"Best Day: {best_day} ({best_data['return']:+.2f}%)")

        if summary["worst_day"]:
            worst_day, worst_data = summary["worst_day"]
            print(f"Worst Day: {worst_day} ({worst_data['return']:+.2f}%)")

        # Update the total_realized_pnl for backward compatibility
        self.total_realized_pnl = summary["total_return"]

    def analyze_pair_performance(self) -> None:
        """Run the full report: returns by day, open positions, totals, metrics."""
        print(f"\n{'='*60}")
        print("PAIR RESEARCH PERFORMANCE ANALYSIS")
        print(f"{'='*60}")

        self.calculate_returns()
        self.print_returns_by_day()
        self.print_outstanding_positions()
        self.print_grand_totals()
        self._print_additional_metrics()

    def _print_additional_metrics(self) -> None:
        """Print win rate, average symbol/pair returns and the daily range."""
        summary = self.get_return_summary()

        if summary["total_days"] == 0:
            return

        print("\n====== ADDITIONAL METRICS ======")

        # Win rate over days with a strictly positive return
        winning_days = sum(1 for day_data in summary["daily_returns"].values() if day_data["return"] > 0)
        win_rate = (winning_days / summary["total_days"]) * 100
        print(f"Winning Days: {winning_days}/{summary['total_days']} ({win_rate:.1f}%)")

        if summary["total_pairs"] > 0:
            # Each pair has 2 symbol trades, so total symbol trades = total_pairs * 2
            total_symbol_trades = summary["total_pairs"] * 2
            avg_symbol_return = summary["total_return"] / total_symbol_trades
            print(f"Average Symbol Return: {avg_symbol_return:+.2f}%")

            avg_pair_return = summary["total_return"] / summary["total_pairs"] / 2  # Divide by 2 since we sum both symbols
            print(f"Average Pair Return: {avg_pair_return:+.2f}%")

        daily_returns_list = [data["return"] for data in summary["daily_returns"].values()]
        if daily_returns_list:
            print(f"Daily Return Range: {min(daily_returns_list):+.2f}% to {max(daily_returns_list):+.2f}%")

    def print_outstanding_positions(self) -> None:
        """Print a table of all outstanding positions and their total value."""
        all_positions = self.outstanding_positions
        if not all_positions:
            print("\n====== NO OUTSTANDING POSITIONS ======")
            return

        print("\n====== OUTSTANDING POSITIONS ======")
        print(f"{'Symbol':<10} {'Side':<4} {'Shares':<10} {'Open $':<8} {'Current $':<10} {'Value $':<12}")
        print("-" * 70)

        total_value = 0.0
        for pos in all_positions:
            # A position without "last_value" contributes 0.0 to the total.
            current_value = pos.get("last_value", 0.0)
            print(f"{pos['symbol']:<10} {pos['open_side']:<4} {pos['shares']:<10.2f} "
                  f"{pos['open_px']:<8.2f} {pos['last_px']:<10.2f} {current_value:<12.2f}")
            total_value += current_value

        print("-" * 70)
        print(f"{'TOTAL VALUE':<60} ${total_value:<12.2f}")

    def get_total_realized_pnl(self) -> float:
        """Get total realized PnL."""
        return self.total_realized_pnl

    def get_outstanding_positions(self) -> List[Dict[str, Any]]:
        """Get outstanding positions."""
        return self.outstanding_positions

View File

@ -0,0 +1,170 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from datetime import datetime
from enum import Enum
from typing import Any, Dict, Optional, Type, cast, Generator, List
import pandas as pd
from pt_strategy.model_data_policy import DataParams
class PairState(Enum):
    """Lifecycle states recorded for a trading pair."""

    # State a freshly constructed pair starts in.
    INITIAL = 1
    OPEN = 2
    CLOSE = 3
    CLOSE_POSITION = 4
    # Set when the current return falls to or below the configured loss limit.
    CLOSE_STOP_LOSS = 5
    # Set when the current return reaches the configured profit target.
    CLOSE_STOP_PROFIT = 6
class TradingPair:
    """Two instruments traded against each other, plus their runtime state.

    ``user_data_`` is a free-form dict carrying the pair's state between
    calls.  Keys written by this class are ``"state"``, ``"open_trades"``,
    ``"close_trades"``, ``"stop_close_state"`` and
    ``"outstanding_positions"``; callers add further keys of their own.
    """

    config_: Dict[str, Any]
    market_data_: pd.DataFrame
    symbol_a_: str
    symbol_b_: str
    stat_model_price_: str
    model_: PairsTradingModel  # type: ignore[assignment]
    model_tdp_: ModelDataPolicy  # type: ignore[assignment]
    user_data_: Dict[str, Any]

    def __init__(
        self,
        config: Dict[str, Any],
        instruments: List[Dict[str, str]],
    ):
        """Create the pair, its model and its data policy from ``config``.

        Args:
            config: Strategy configuration; ``"stat_model_price"`` is read
                here, the rest is handed to the model / data-policy factories.
            instruments: Exactly two dicts, each with a ``"symbol"`` entry.
        """
        # Function-scope imports, kept as in the original module.
        from pt_strategy.model_data_policy import ModelDataPolicy
        from pt_strategy.pt_model import PairsTradingModel

        assert len(instruments) == 2, "Trading pair must have exactly 2 instruments"

        self.config_ = config
        self.symbol_a_ = instruments[0]["symbol"]
        self.symbol_b_ = instruments[1]["symbol"]
        self.model_ = PairsTradingModel.create(config)
        self.model_tdp_ = ModelDataPolicy.create(config)
        self.stat_model_price_ = config["stat_model_price"]
        self.user_data_ = {
            "state": PairState.INITIAL,
        }

    def colnames(self) -> List[str]:
        """Return the two ``<stat_model_price>_<symbol>`` column names (A, then B)."""
        return [
            f"{self.stat_model_price_}_{self.symbol_a_}",
            f"{self.stat_model_price_}_{self.symbol_b_}",
        ]

    def exec_prices_colnames(self) -> List[str]:
        """Return the two ``exec_price_<symbol>`` column names (A, then B)."""
        return [
            f"exec_price_{self.symbol_a_}",
            f"exec_price_{self.symbol_b_}",
        ]

    def to_stop_close_conditions(self, predicted_row: pd.Series) -> bool:
        """Tell whether a configured stop-profit / stop-loss level is reached.

        On a hit, the matching ``PairState`` is stored under
        ``user_data_["stop_close_state"]`` and ``True`` is returned.

        Args:
            predicted_row: Row holding the current ``<stat_model_price>_<symbol>``
                prices of both instruments.

        Returns:
            ``True`` when the current return (in percent) is at or above the
            ``"profit"`` level or at or below the ``"loss"`` level.
        """
        stop_conditions = self.config_.get("stop_close_conditions")
        if stop_conditions is None:
            return False

        has_profit = "profit" in stop_conditions
        has_loss = "loss" in stop_conditions
        if not (has_profit or has_loss):
            return False

        # Computed once, up front: the original only assigned it inside the
        # "profit" branch, so a loss-only configuration read an unbound name.
        current_return = self._current_return(predicted_row)

        if has_profit and current_return >= stop_conditions["profit"]:
            print(f"STOP PROFIT: {current_return}")
            self.user_data_["stop_close_state"] = PairState.CLOSE_STOP_PROFIT
            return True
        if has_loss and current_return <= stop_conditions["loss"]:
            print(f"STOP LOSS: {current_return}")
            self.user_data_["stop_close_state"] = PairState.CLOSE_STOP_LOSS
            return True
        return False

    def _current_return(self, predicted_row: pd.Series) -> float:
        """Return the summed percentage return of both legs since opening.

        Each leg's return is ``(current - open) / open * 100`` with the sign
        flipped for a leg opened with side ``"SELL"``.  The open price is the
        first ``"price"`` recorded for the symbol in ``user_data_["open_trades"]``.
        Returns 0.0 when no open trades are recorded.
        """
        open_trades = self.user_data_.get("open_trades")
        if open_trades is None or len(open_trades) == 0:
            return 0.0

        def _single_instrument_return(symbol: str) -> float:
            symbol_trades = open_trades[open_trades["symbol"] == symbol]
            open_price = symbol_trades["price"].iloc[0]
            sign = -1 if symbol_trades["side"].iloc[0] == "SELL" else 1
            current_price = predicted_row[f"{self.stat_model_price_}_{symbol}"]
            leg_return = sign * (current_price - open_price) / open_price
            return float(leg_return) * 100.0

        return _single_instrument_return(self.symbol_a_) + _single_instrument_return(
            self.symbol_b_
        )

    def on_open_trades(self, trades: pd.DataFrame) -> None:
        """Record ``trades`` as the open trades and forget any earlier close trades."""
        self.user_data_.pop("close_trades", None)
        self.user_data_["open_trades"] = trades

    def on_close_trades(self, trades: pd.DataFrame) -> None:
        """Record ``trades`` as the close trades and forget the open trades.

        Tolerates a missing ``"open_trades"`` entry, mirroring
        ``on_open_trades`` (the original raised ``KeyError`` in that case).
        """
        self.user_data_.pop("open_trades", None)
        self.user_data_["close_trades"] = trades

    def add_outstanding_position(
        self,
        symbol: str,
        open_side: str,
        open_px: float,
        open_tstamp: datetime,
        last_mkt_data_row: pd.Series,
    ) -> None:
        """Append a still-open position to ``user_data_["outstanding_positions"]``.

        The share count is half of ``config_["funding_per_pair"]`` divided by
        ``open_px``, negative for a ``"SELL"`` position.  The last price is
        taken from the symbol's ``exec_price_<symbol>`` column of
        ``last_mkt_data_row``.

        Args:
            symbol: One of the pair's two symbols.
            open_side: ``"BUY"`` or ``"SELL"``.
            open_px: Price at which the position was opened (> 0).
            open_tstamp: Time at which the position was opened.
            last_mkt_data_row: Row with ``"tstamp"`` and the exec-price columns.
        """
        assert symbol in [self.symbol_a_, self.symbol_b_], "Symbol must be one of the pair's symbols"
        assert open_side in ["BUY", "SELL"], "Open side must be either BUY or SELL"
        assert open_px > 0, "Open price must be greater than 0"
        assert open_tstamp is not None, "Open timestamp must be provided"
        assert last_mkt_data_row is not None, "Last market data row must be provided"

        exec_prices_col_a, exec_prices_col_b = self.exec_prices_colnames()
        price_column = exec_prices_col_a if symbol == self.symbol_a_ else exec_prices_col_b
        last_px = last_mkt_data_row[price_column]

        funding_per_position = self.config_["funding_per_pair"] / 2
        shares = funding_per_position / open_px
        if open_side == "SELL":
            shares = -shares

        self.user_data_.setdefault("outstanding_positions", []).append({
            "symbol": symbol,
            "open_side": open_side,
            "open_px": open_px,
            "shares": shares,
            "open_tstamp": open_tstamp,
            "last_px": last_px,
            "last_tstamp": last_mkt_data_row["tstamp"],
            "last_value": last_px * shares,
        })

    def run(self, market_data: pd.DataFrame, data_params: DataParams) -> Prediction:  # type: ignore[assignment]
        """Select the training window from ``market_data`` and run the model.

        The window is the row slice starting at
        ``data_params.training_start_index`` and spanning
        ``data_params.training_size`` rows; it is stored in ``market_data_``
        before ``model_.predict(pair=self)`` is called.

        Returns:
            Whatever ``model_.predict`` returns.
        """
        window_start = data_params.training_start_index
        window_end = window_start + data_params.training_size
        self.market_data_ = market_data[window_start:window_end]
        # An unreachable ``while self.model_tdp_.has_next_training_data()``
        # loop used to follow this return; it could never execute.
        return self.model_.predict(pair=self)

View File

@ -0,0 +1,416 @@
from __future__ import annotations
import os
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Dict, Generator, List, Optional, Type, cast
import pandas as pd
from pt_strategy.model_data_policy import ModelDataPolicy
from pt_strategy.pt_market_data import PtMarketData
from pt_strategy.pt_model import Prediction
from pt_strategy.results import (
PairResearchResult,
create_result_database,
store_config_in_database,
)
from pt_strategy.trading_pair import PairState, TradingPair
from tools.filetools import resolve_datafiles
from tools.instruments import get_instruments
class PtResearchStrategy:
    """Research (back-test) driver for one trading pair over one set of data files.

    Market-data rows are pulled one at a time from ``pt_mkt_data_``.  After a
    warm-up period every new row triggers a model prediction, which is turned
    into OPEN / CLOSE trades according to the configured dis-equilibrium
    thresholds and stop conditions.  The resulting trade frames are collected
    in ``trades_``.
    """

    config_: Dict[str, Any]
    trading_pair_: TradingPair
    model_data_policy_: ModelDataPolicy
    pt_mkt_data_: PtMarketData
    trades_: List[pd.DataFrame]

    def __init__(
        self,
        config: Dict[str, Any],
        datafiles: List[str],
        instruments: List[Dict[str, str]],
    ):
        """Build the trading pair, the data policy and load the market data.

        Args:
            config: Strategy configuration (left unmodified).
            datafiles: Market-data files to load.
            instruments: The two instrument descriptions of the pair.
        """
        import copy

        from pt_strategy.model_data_policy import ModelDataPolicy
        from pt_strategy.pt_market_data import PtMarketData, ResearchMarketData
        from pt_strategy.trading_pair import TradingPair

        self.config_ = config
        self.trades_ = []
        self.trading_pair_ = TradingPair(config=config, instruments=instruments)
        self.model_data_policy_ = ModelDataPolicy.create(config)

        # PtMarketData gets its own deep copy extended with the instruments
        # and data files, so the caller's config is not mutated.
        config_copy = copy.deepcopy(config)
        config_copy["instruments"] = instruments
        config_copy["datafiles"] = datafiles
        self.pt_mkt_data_ = PtMarketData.create(
            config=config_copy, md_class=ResearchMarketData
        )
        self.pt_mkt_data_.load()

    def outstanding_positions(self) -> List[Dict[str, Any]]:
        """Return a copy of the pair's outstanding positions (empty if none)."""
        return list(self.trading_pair_.user_data_.get("outstanding_positions", []))

    def run(self) -> None:
        """Replay the loaded market data and collect the trades in ``trades_``.

        Raises:
            AssertionError: If the data ends before the warm-up is complete.
        """
        # "training_size" is the current configuration key; the legacy
        # "training_minutes" key is still honoured for older config files.
        training_size = self.config_.get(
            "training_size", self.config_.get("training_minutes", 120)
        )

        market_data_series: pd.Series
        market_data_df = pd.DataFrame()

        # Warm-up: accumulate rows until the counter reaches the training size.
        idx = 0
        while self.pt_mkt_data_.has_next():
            market_data_series = self.pt_mkt_data_.get_next()
            market_data_df = pd.concat(
                [market_data_df, market_data_series.to_frame().T], ignore_index=True
            )
            if idx >= training_size:
                break
            idx += 1

        assert idx >= training_size, "Not enough training data"

        # Main loop: one prediction (and possibly one trade frame) per new row.
        while self.pt_mkt_data_.has_next():
            market_data_series = self.pt_mkt_data_.get_next()
            new_row = market_data_series.to_frame().T
            market_data_df = pd.concat([market_data_df, new_row], ignore_index=True)

            prediction = self.trading_pair_.run(
                market_data_df, self.model_data_policy_.advance()
            )
            assert prediction is not None

            trades = self._create_trades(
                prediction=prediction, last_row=market_data_df.iloc[-1]
            )
            if trades is not None:
                self.trades_.append(trades)

        # End of data: close or report whatever is still open.
        trades = self._handle_outstanding_positions()
        if trades is not None:
            self.trades_.append(trades)

    def _create_trades(
        self, prediction: Prediction, last_row: pd.Series
    ) -> Optional[pd.DataFrame]:
        """Turn one prediction into open/close trades, updating the pair state.

        Returns:
            The trades created for this row, or ``None`` when nothing is traded.
        """
        pair = self.trading_pair_
        trades = None

        open_threshold = self.config_["dis-equilibrium_open_trshld"]
        close_threshold = self.config_["dis-equilibrium_close_trshld"]
        abs_scaled_disequilibrium = abs(prediction.scaled_disequilibrium_)

        state = pair.user_data_["state"]
        if state in (
            PairState.INITIAL,
            PairState.CLOSE,
            PairState.CLOSE_POSITION,
            PairState.CLOSE_STOP_LOSS,
            PairState.CLOSE_STOP_PROFIT,
        ):
            # No position held: open one when the spread is wide enough.
            if abs_scaled_disequilibrium >= open_threshold:
                trades = self._create_open_trades(
                    pair, row=last_row, prediction=prediction
                )
                if trades is not None:
                    trades["status"] = PairState.OPEN.name
                    print(f"OPEN TRADES:\n{trades}")
                    pair.user_data_["state"] = PairState.OPEN
                    pair.on_open_trades(trades)
        elif state == PairState.OPEN:
            if abs_scaled_disequilibrium <= close_threshold:
                # Regular close: the spread has narrowed to the close threshold.
                trades = self._create_close_trades(
                    pair, row=last_row, prediction=prediction
                )
                if trades is not None:
                    trades["status"] = PairState.CLOSE.name
                    print(f"CLOSE TRADES:\n{trades}")
                    pair.user_data_["state"] = PairState.CLOSE
                    pair.on_close_trades(trades)
            elif pair.to_stop_close_conditions(predicted_row=last_row):
                # Stop close: no prediction is passed, so the dis-equilibrium
                # columns of these trades are recorded as 0.0.
                trades = self._create_close_trades(pair, row=last_row)
                if trades is not None:
                    trades["status"] = pair.user_data_["stop_close_state"].name
                    print(f"STOP CLOSE TRADES:\n{trades}")
                    pair.user_data_["state"] = pair.user_data_["stop_close_state"]
                    pair.on_close_trades(trades)

        return trades

    def _handle_outstanding_positions(self) -> Optional[pd.DataFrame]:
        """Deal with a position that is still open when the data ends.

        With ``config_["close_outstanding_positions"]`` set, closing trades
        are created from the second-to-last row of the pair's market-data
        window.  Otherwise both legs are recorded as outstanding positions.

        Returns:
            The closing trades, or ``None`` when none were created.
        """
        trades = None
        pair = self.trading_pair_

        if pair.user_data_["state"] == PairState.OPEN:
            print(f"{pair}: *** Position is NOT CLOSED. ***")
            if self.config_["close_outstanding_positions"]:
                close_position_row = pd.Series(pair.market_data_.iloc[-2])
                trades = self._create_close_trades(
                    pair=pair, row=close_position_row, prediction=None
                )
                if trades is not None:
                    trades["status"] = PairState.CLOSE_POSITION.name
                    print(f"CLOSE_POSITION TRADES:\n{trades}")
                    pair.user_data_["state"] = PairState.CLOSE_POSITION
                    pair.on_close_trades(trades)
            else:
                last_row = pair.market_data_.iloc[-1]
                pair.add_outstanding_position(
                    symbol=pair.symbol_a_,
                    open_side=pair.user_data_["open_side_a"],
                    open_px=pair.user_data_["open_px_a"],
                    open_tstamp=pair.user_data_["open_tstamp"],
                    last_mkt_data_row=last_row,
                )
                pair.add_outstanding_position(
                    symbol=pair.symbol_b_,
                    open_side=pair.user_data_["open_side_b"],
                    open_px=pair.user_data_["open_px_b"],
                    open_tstamp=pair.user_data_["open_tstamp"],
                    last_mkt_data_row=last_row,
                )
        return trades

    def _trades_df(self) -> pd.DataFrame:
        """Return an empty trades frame with the standard columns and dtypes."""
        types = {
            "time": "datetime64[ns]",
            "action": "string",
            "symbol": "string",
            "side": "string",
            "price": "float64",
            "disequilibrium": "float64",
            "scaled_disequilibrium": "float64",
            "signed_scaled_disequilibrium": "float64",
        }
        columns = list(types.keys())
        return pd.DataFrame(columns=columns).astype(types)

    @staticmethod
    def _append_trade_row(
        df: pd.DataFrame,
        *,
        tstamp: Any,
        symbol: str,
        side: str,
        action: str,
        price: Any,
        disequilibrium: Any,
        scaled_disequilibrium: Any,
        signed_scaled_disequilibrium: Any,
    ) -> None:
        """Append one trade record to ``df`` in place."""
        df.loc[len(df)] = {
            "time": tstamp,
            "symbol": symbol,
            "side": side,
            "action": action,
            "price": price,
            "disequilibrium": disequilibrium,
            "scaled_disequilibrium": scaled_disequilibrium,
            "signed_scaled_disequilibrium": signed_scaled_disequilibrium,
        }

    def _create_open_trades(
        self, pair: TradingPair, row: pd.Series, prediction: Prediction
    ) -> Optional[pd.DataFrame]:
        """Create the two opening trades and remember how to close them.

        A positive dis-equilibrium sells instrument A and buys instrument B;
        otherwise the sides are swapped.  Sides, prices and the timestamp are
        stored in ``pair.user_data_`` for later closing / reporting.
        """
        colname_a, colname_b = pair.exec_prices_colnames()

        tstamp = row["tstamp"]
        diseqlbrm = prediction.disequilibrium_
        scaled_disequilibrium = prediction.scaled_disequilibrium_
        px_a = row[f"{colname_a}"]
        px_b = row[f"{colname_b}"]

        df = self._trades_df()

        # ``tstamp`` is row["tstamp"]; using the local avoids nesting double
        # quotes inside the f-string, which only parses on Python 3.12+.
        print(f"OPEN_TRADES: {tstamp} {scaled_disequilibrium=}")
        if diseqlbrm > 0:
            side_a = "SELL"
            side_b = "BUY"
        else:
            side_a = "BUY"
            side_b = "SELL"

        # Opening details are read back when reporting outstanding positions.
        pair.user_data_["open_side_a"] = side_a
        pair.user_data_["open_side_b"] = side_b
        pair.user_data_["open_px_a"] = px_a
        pair.user_data_["open_px_b"] = px_b
        pair.user_data_["open_tstamp"] = tstamp
        # Each leg is closed with the side the other leg was opened with.
        pair.user_data_["close_side_a"] = side_b
        pair.user_data_["close_side_b"] = side_a

        for symbol, side, price in (
            (pair.symbol_a_, side_a, px_a),
            (pair.symbol_b_, side_b, px_b),
        ):
            self._append_trade_row(
                df,
                tstamp=tstamp,
                symbol=symbol,
                side=side,
                action="OPEN",
                price=price,
                disequilibrium=diseqlbrm,
                scaled_disequilibrium=abs(scaled_disequilibrium),
                signed_scaled_disequilibrium=scaled_disequilibrium,
            )
        return df

    def _create_close_trades(
        self, pair: TradingPair, row: pd.Series, prediction: Optional[Prediction] = None
    ) -> Optional[pd.DataFrame]:
        """Create the two closing trades and clear the stored opening details.

        Without a ``prediction`` the dis-equilibrium columns are set to 0.0.
        """
        colname_a, colname_b = pair.exec_prices_colnames()

        tstamp = row["tstamp"]
        if prediction is not None:
            diseqlbrm = prediction.disequilibrium_
            signed_scaled_disequilibrium = prediction.scaled_disequilibrium_
            scaled_disequilibrium = abs(prediction.scaled_disequilibrium_)
        else:
            diseqlbrm = 0.0
            signed_scaled_disequilibrium = 0.0
            scaled_disequilibrium = 0.0
        px_a = row[f"{colname_a}"]
        px_b = row[f"{colname_b}"]

        df = self._trades_df()

        # create closing trades
        for symbol, side, price in (
            (pair.symbol_a_, pair.user_data_["close_side_a"], px_a),
            (pair.symbol_b_, pair.user_data_["close_side_b"], px_b),
        ):
            self._append_trade_row(
                df,
                tstamp=tstamp,
                symbol=symbol,
                side=side,
                action="CLOSE",
                price=price,
                disequilibrium=diseqlbrm,
                scaled_disequilibrium=scaled_disequilibrium,
                signed_scaled_disequilibrium=signed_scaled_disequilibrium,
            )

        # The position is closed: drop everything stored when it was opened.
        for key in (
            "close_side_a",
            "close_side_b",
            "open_tstamp",
            "open_px_a",
            "open_px_b",
            "open_side_a",
            "open_side_b",
        ):
            del pair.user_data_[key]

        return df

    def day_trades(self) -> pd.DataFrame:
        """Return all collected trades as a single frame.

        When nothing was traded, an empty frame with the standard trade
        columns plus ``"status"`` is returned, because ``pd.concat`` raises
        ``ValueError`` on an empty list.
        """
        if not self.trades_:
            empty = self._trades_df()
            empty["status"] = pd.Series(dtype="string")
            return empty
        return pd.concat(self.trades_, ignore_index=True)
# Command-line entry point: parses the arguments, resolves the market-data
# files per day, runs a PtResearchStrategy for each day and prints the results.
# NOTE(review): the original indentation of this function is not visible here,
# so the comments below describe the statements in order without asserting
# which loop or branch each one belongs to.
def main() -> None:
import argparse
from tools.config import expand_filename, load_config
# All four command-line options are declared with required=True.
parser = argparse.ArgumentParser(description="Run pairs trading backtest.")
parser.add_argument(
"--config", type=str, required=True, help="Path to the configuration file."
)
parser.add_argument(
"--date_pattern",
type=str,
required=True,
help="Date YYYYMMDD, allows * and ? wildcards",
)
# NOTE(review): the help example shows SYMBOL:TYPE, while get_instruments()
# (lib/tools/instruments.py) also reads a third ':'-separated field as the
# exchange id -- confirm which format is intended.
parser.add_argument(
"--instruments",
type=str,
required=True,
help="Comma-separated list of instrument symbols (e.g., COIN:EQUITY,GBTC:CRYPTO)",
)
parser.add_argument(
"--result_db",
type=str,
required=True,
help="Path to SQLite database for storing results. Use 'NONE' to disable database output.",
)
args = parser.parse_args()
config: Dict = load_config(args.config)
# Resolve data files (CLI takes priority over config)
instruments = get_instruments(args, config)
# datafiles is a list of (day, file name) tuples; days holds the distinct days.
datafiles = resolve_datafiles(config, args.date_pattern, instruments)
days = list(set([day for day, _ in datafiles]))
print(f"Found {len(datafiles)} data files to process:")
for df in datafiles:
print(f" - {df}")
# Create result database if needed
if args.result_db.upper() != "NONE":
args.result_db = expand_filename(args.result_db)
create_result_database(args.result_db)
# Initialize a dictionary to store all trade results
# NOTE(review): all_results is assigned here but never read in the visible code.
all_results: Dict[str, Dict[str, Any]] = {}
is_config_stored = False
# Process each data file
results = PairResearchResult(config=config)
for day in sorted(days):
# A day is skipped (with a warning) when any of its data files is missing on disk.
md_datafiles = [datafile for md_day, datafile in datafiles if md_day == day]
if not all([os.path.exists(datafile) for datafile in md_datafiles]):
print(f"WARNING: insufficient data files: {md_datafiles}")
continue
print(f"\n====== Processing {day} ======")
# The configuration is stored once, guarded by is_config_stored.
# NOTE(review): this call is not guarded by the "NONE" check used for
# create_result_database above -- confirm store_config_in_database handles it.
if not is_config_stored:
store_config_in_database(
db_path=args.result_db,
config_file_path=args.config,
config=config,
datafiles=datafiles,
instruments=instruments,
)
is_config_stored = True
# A fresh strategy object is built for the day's data files and run.
pt_strategy = PtResearchStrategy(
config=config, datafiles=md_datafiles, instruments=instruments
)
pt_strategy.run()
results.add_day_results(
day=day,
trades=pt_strategy.day_trades(),
outstanding_positions=pt_strategy.outstanding_positions(),
)
# ADD RESULTS ANALYSIS
results.calculate_returns()
results.print_single_day_results()
# Store results with day name as key
# filename = os.path.basename(day)
# all_results[filename] = {
# "trades": pt_strategy.trades_.copy(),
# "outstanding_positions": pt_strategy.outstanding_positions_.copy(),
# }
# print(f"Successfully processed {filename}")
# Final summary: returns are calculated again before the totals are printed.
results.calculate_returns()
results.print_grand_totals()
results.print_outstanding_positions()
# NOTE(review): as written, the only preceding `if` this `else` can pair with
# is the result_db check, which would print "No results to display." whenever
# the database is disabled -- confirm the intended nesting.
if args.result_db.upper() != "NONE":
print(f"\nResults stored in database: {args.result_db}")
else:
print("No results to display.")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,304 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Dict, Optional, cast
import pandas as pd # type: ignore[import]
from pt_trading.fit_method import PairsTradingFitMethod
from pt_trading.results import BacktestResult
from pt_trading.trading_pair import PairState, TradingPair
# NOTE(review): despite the name, 1e9 is the number of nanoseconds in one
# second (one minute is 6e10 ns).  Nothing in the visible part of this module
# reads the constant -- confirm the intended value before relying on it.
NanoPerMin = 1e9
# Fit method that predicts on a training window which keeps the same start
# index and grows by one row per iteration, then derives open/close trades
# from the predicted rows.
# NOTE(review): the original indentation of this class is not visible here;
# comments describe the statements in order and flag nesting that is unclear.
class ExpandingWindowFit(PairsTradingFitMethod):
"""
N O T E:
=========
- This class remains to be abstract
- The following methods are to be implemented in the subclass:
- create_trading_pair()
=========
"""
def __init__(self) -> None:
super().__init__()
# Runs the whole expanding-window loop for one pair and returns pair.get_trades().
# A failing pair.predict() is re-raised as RuntimeError with the pair in the message.
def run_pair(
self, pair: TradingPair, bt_result: BacktestResult
) -> Optional[pd.DataFrame]:
print(f"***{pair}*** STARTING....")
config = pair.config_
start_idx = pair.get_begin_index()
end_index = pair.get_end_index()
pair.user_data_["state"] = PairState.INITIAL
# Initialize trades DataFrame with proper dtypes to avoid concatenation warnings
pair.user_data_["trades"] = pd.DataFrame(columns=self.TRADES_COLUMNS).astype(
{
"time": "datetime64[ns]",
"symbol": "string",
"side": "string",
"action": "string",
"price": "float64",
"disequilibrium": "float64",
"scaled_disequilibrium": "float64",
"pair": "object",
}
)
# start_idx is passed unchanged on every iteration while training_minutes is
# incremented by one, so the training window expands; one row is tested each time.
training_minutes = config["training_minutes"]
while training_minutes + 1 < end_index:
pair.get_datasets(
training_minutes=training_minutes,
training_start_index=start_idx,
testing_size=1,
)
# ================================ PREDICTION ================================
try:
self.pair_predict_result_ = pair.predict()
except Exception as e:
raise RuntimeError(
f"{pair}: TrainingPrediction failed: {str(e)}"
) from e
training_minutes += 1
# NOTE(review): cannot tell from here whether the next call runs on every
# loop iteration or once after the loop -- confirm against the original file.
self._create_trading_signals(pair, config, bt_result)
print(f"***{pair}*** FINISHED *** Num Trades:{len(pair.user_data_['trades'])}")
return pair.get_trades()
# Walks the rows of self.pair_predict_result_, opening a position when the
# row's "scaled_disequilibrium" is at or above the open threshold and closing
# it at or below the close threshold or when a stop condition is reported by
# the pair; afterwards handles a position that is still open.
def _create_trading_signals(
self, pair: TradingPair, config: Dict, bt_result: BacktestResult
) -> None:
predicted_df = self.pair_predict_result_
assert predicted_df is not None
open_threshold = config["dis-equilibrium_open_trshld"]
close_threshold = config["dis-equilibrium_close_trshld"]
for curr_predicted_row_idx in range(len(predicted_df)):
pred_row = predicted_df.iloc[curr_predicted_row_idx]
# The value is compared as-is (no abs() here); a separate
# "signed_scaled_disequilibrium" column is read elsewhere in this class,
# so this one is presumably already non-negative -- TODO confirm.
scaled_disequilibrium = pred_row["scaled_disequilibrium"]
if pair.user_data_["state"] in [
PairState.INITIAL,
PairState.CLOSE,
PairState.CLOSE_POSITION,
PairState.CLOSE_STOP_LOSS,
PairState.CLOSE_STOP_PROFIT,
]:
if scaled_disequilibrium >= open_threshold:
open_trades = self._get_open_trades(
pair, row=pred_row, open_threshold=open_threshold
)
if open_trades is not None:
open_trades["status"] = PairState.OPEN.name
print(f"OPEN TRADES:\n{open_trades}")
pair.add_trades(open_trades)
pair.user_data_["state"] = PairState.OPEN
pair.on_open_trades(open_trades)
elif pair.user_data_["state"] == PairState.OPEN:
if scaled_disequilibrium <= close_threshold:
close_trades = self._get_close_trades(
pair, row=pred_row, close_threshold=close_threshold
)
if close_trades is not None:
close_trades["status"] = PairState.CLOSE.name
print(f"CLOSE TRADES:\n{close_trades}")
pair.add_trades(close_trades)
pair.user_data_["state"] = PairState.CLOSE
pair.on_close_trades(close_trades)
elif pair.to_stop_close_conditions(predicted_row=pred_row):
close_trades = self._get_close_trades(
pair, row=pred_row, close_threshold=close_threshold
)
if close_trades is not None:
close_trades["status"] = pair.user_data_[
"stop_close_state"
].name
print(f"STOP CLOSE TRADES:\n{close_trades}")
pair.add_trades(close_trades)
pair.user_data_["state"] = pair.user_data_["stop_close_state"]
pair.on_close_trades(close_trades)
# Outstanding positions
if pair.user_data_["state"] == PairState.OPEN:
print(f"{pair}: *** Position is NOT CLOSED. ***")
# outstanding positions
if config["close_outstanding_positions"]:
# The close-out row is the second-to-last market-data row, with the three
# dis-equilibrium fields overwritten by 0.0 before the trades are built.
close_position_row = pd.Series(pair.market_data_.iloc[-2])
close_position_row["disequilibrium"] = 0.0
close_position_row["scaled_disequilibrium"] = 0.0
close_position_row["signed_scaled_disequilibrium"] = 0.0
close_position_trades = self._get_close_trades(
pair=pair, row=close_position_row, close_threshold=close_threshold
)
if close_position_trades is not None:
close_position_trades["status"] = PairState.CLOSE_POSITION.name
print(f"CLOSE_POSITION TRADES:\n{close_position_trades}")
pair.add_trades(close_position_trades)
pair.user_data_["state"] = PairState.CLOSE_POSITION
pair.on_close_trades(close_position_trades)
else:
# Otherwise the open position is handed to bt_result, using the opening
# details stored in pair.user_data_ by _get_open_trades().
if predicted_df is not None:
bt_result.handle_outstanding_position(
pair=pair,
pair_result_df=predicted_df,
last_row_index=0,
open_side_a=pair.user_data_["open_side_a"],
open_side_b=pair.user_data_["open_side_b"],
open_px_a=pair.user_data_["open_px_a"],
open_px_b=pair.user_data_["open_px_b"],
open_tstamp=pair.user_data_["open_tstamp"],
)
# Builds the two OPEN trade records for `row` and stores the opening and
# closing sides, opening prices and timestamp in pair.user_data_.
# A positive "disequilibrium" sells A and buys B; otherwise the sides are swapped.
# The open_threshold parameter is not used in the body.
def _get_open_trades(
self, pair: TradingPair, row: pd.Series, open_threshold: float
) -> Optional[pd.DataFrame]:
colname_a, colname_b = pair.exec_prices_colnames()
open_row = row
open_tstamp = open_row["tstamp"]
open_disequilibrium = open_row["disequilibrium"]
open_scaled_disequilibrium = open_row["scaled_disequilibrium"]
signed_scaled_disequilibrium = open_row["signed_scaled_disequilibrium"]
open_px_a = open_row[f"{colname_a}"]
open_px_b = open_row[f"{colname_b}"]
# creating the trades
# NOTE(review): double quotes nested inside a double-quoted f-string only
# parse on Python 3.12+ (PEP 701).
print(f"OPEN_TRADES: {row["tstamp"]} {open_scaled_disequilibrium=}")
if open_disequilibrium > 0:
open_side_a = "SELL"
open_side_b = "BUY"
close_side_a = "BUY"
close_side_b = "SELL"
else:
open_side_a = "BUY"
open_side_b = "SELL"
close_side_a = "SELL"
close_side_b = "BUY"
# save closing sides
pair.user_data_["open_side_a"] = open_side_a
pair.user_data_["open_side_b"] = open_side_b
pair.user_data_["open_px_a"] = open_px_a
pair.user_data_["open_px_b"] = open_px_b
pair.user_data_["open_tstamp"] = open_tstamp
pair.user_data_["close_side_a"] = close_side_a
pair.user_data_["close_side_b"] = close_side_b
# create opening trades
# Each tuple has nine fields; they are mapped onto self.TRADES_COLUMNS below.
trd_signal_tuples = [
(
open_tstamp,
pair.symbol_a_,
open_side_a,
"OPEN",
open_px_a,
open_disequilibrium,
open_scaled_disequilibrium,
signed_scaled_disequilibrium,
pair,
),
(
open_tstamp,
pair.symbol_b_,
open_side_b,
"OPEN",
open_px_b,
open_disequilibrium,
open_scaled_disequilibrium,
signed_scaled_disequilibrium,
pair,
),
]
# Create DataFrame with explicit dtypes to avoid concatenation warnings
df = pd.DataFrame(
trd_signal_tuples,
columns=self.TRADES_COLUMNS,
)
# Ensure consistent dtypes
# (the "side" column is not listed in this mapping, unlike in run_pair)
return df.astype(
{
"time": "datetime64[ns]",
"action": "string",
"symbol": "string",
"price": "float64",
"disequilibrium": "float64",
"scaled_disequilibrium": "float64",
"signed_scaled_disequilibrium": "float64",
"pair": "object",
}
)
# Builds the two CLOSE trade records for `row`, using the closing sides
# stored in pair.user_data_ by _get_open_trades().
# The close_threshold parameter is not used in the body.
def _get_close_trades(
self, pair: TradingPair, row: pd.Series, close_threshold: float
) -> Optional[pd.DataFrame]:
colname_a, colname_b = pair.exec_prices_colnames()
close_row = row
close_tstamp = close_row["tstamp"]
close_disequilibrium = close_row["disequilibrium"]
close_scaled_disequilibrium = close_row["scaled_disequilibrium"]
signed_scaled_disequilibrium = close_row["signed_scaled_disequilibrium"]
close_px_a = close_row[f"{colname_a}"]
close_px_b = close_row[f"{colname_b}"]
close_side_a = pair.user_data_["close_side_a"]
close_side_b = pair.user_data_["close_side_b"]
trd_signal_tuples = [
(
close_tstamp,
pair.symbol_a_,
close_side_a,
"CLOSE",
close_px_a,
close_disequilibrium,
close_scaled_disequilibrium,
signed_scaled_disequilibrium,
pair,
),
(
close_tstamp,
pair.symbol_b_,
close_side_b,
"CLOSE",
close_px_b,
close_disequilibrium,
close_scaled_disequilibrium,
signed_scaled_disequilibrium,
pair,
),
]
# Add tuples to data frame with explicit dtypes to avoid concatenation warnings
df = pd.DataFrame(
trd_signal_tuples,
columns=self.TRADES_COLUMNS,
)
# Ensure consistent dtypes
return df.astype(
{
"time": "datetime64[ns]",
"action": "string",
"symbol": "string",
"price": "float64",
"disequilibrium": "float64",
"scaled_disequilibrium": "float64",
"signed_scaled_disequilibrium": "float64",
"pair": "object",
}
)
# No per-run state to clear in this fit method.
def reset(self) -> None:
pass

View File

@ -195,6 +195,16 @@ def convert_timestamp(timestamp: Any) -> Optional[datetime]:
raise ValueError(f"Unsupported timestamp type: {type(timestamp)}") raise ValueError(f"Unsupported timestamp type: {type(timestamp)}")
class PairResarchResult:
    """Plain container bundling a pair with its trades and open positions.

    NOTE(review): the class name is spelled "Resarch", and the configuration
    is stored as ``config`` while the other attributes carry a trailing
    underscore.  Both are kept as-is because callers may depend on them.
    """

    pair_: TradingPair
    trades_: Dict[str, Dict[str, Any]]
    outstanding_positions_: List[Dict[str, Any]]

    def __init__(
        self,
        config: Dict[str, Any],
        pair: TradingPair,
        trades: Dict[str, Dict[str, Any]],
        outstanding_positions: List[Dict[str, Any]],
    ):
        """Store the four arguments on the instance without copying them."""
        self.outstanding_positions_ = outstanding_positions
        self.trades_ = trades
        self.pair_ = pair
        self.config = config
class BacktestResult: class BacktestResult:
""" """
@ -206,7 +216,7 @@ class BacktestResult:
self.trades: Dict[str, Dict[str, Any]] = {} self.trades: Dict[str, Dict[str, Any]] = {}
self.total_realized_pnl = 0.0 self.total_realized_pnl = 0.0
self.outstanding_positions: List[Dict[str, Any]] = [] self.outstanding_positions: List[Dict[str, Any]] = []
self.pairs_trades_: Dict[str, List[Dict[str, Any]]] = {} self.symbol_roundtrip_trades_: Dict[str, List[Dict[str, Any]]] = {}
def add_trade( def add_trade(
self, self,
@ -334,7 +344,7 @@ class BacktestResult:
for filename, data in all_results.items(): for filename, data in all_results.items():
pairs = list(data["trades"].keys()) pairs = list(data["trades"].keys())
for pair in pairs: for pair in pairs:
self.pairs_trades_[pair] = [] self.symbol_roundtrip_trades_[pair] = []
trades_dict = data["trades"][pair] trades_dict = data["trades"][pair]
for symbol in trades_dict.keys(): for symbol in trades_dict.keys():
trades.extend(trades_dict[symbol]) trades.extend(trades_dict[symbol])
@ -369,7 +379,7 @@ class BacktestResult:
pair_return = symbol_a_return + symbol_b_return pair_return = symbol_a_return + symbol_b_return
self.pairs_trades_[pair].append( self.symbol_roundtrip_trades_[pair].append(
{ {
"symbol": symbol_a, "symbol": symbol_a,
"open_side": trade_a_1["side"], "open_side": trade_a_1["side"],
@ -391,7 +401,7 @@ class BacktestResult:
"pair_return": pair_return "pair_return": pair_return
} }
) )
self.pairs_trades_[pair].append( self.symbol_roundtrip_trades_[pair].append(
{ {
"symbol": symbol_b, "symbol": symbol_b,
"open_side": trade_b_1["side"], "open_side": trade_b_1["side"],
@ -417,11 +427,11 @@ class BacktestResult:
# Print pair returns with disequilibrium information # Print pair returns with disequilibrium information
day_return = 0.0 day_return = 0.0
if pair in self.pairs_trades_: if pair in self.symbol_roundtrip_trades_:
print(f"{pair}:") print(f"{pair}:")
pair_return = 0.0 pair_return = 0.0
for trd in self.pairs_trades_[pair]: for trd in self.symbol_roundtrip_trades_[pair]:
disequil_info = "" disequil_info = ""
if ( if (
trd["open_scaled_disequilibrium"] is not None trd["open_scaled_disequilibrium"] is not None
@ -641,7 +651,7 @@ class BacktestResult:
for pair_name, _ in trades.items(): for pair_name, _ in trades.items():
# Second pass: insert completed trade records into database # Second pass: insert completed trade records into database
for trade_pair in sorted(self.pairs_trades_[pair_name], key=lambda x: x["open_time"]): for trade_pair in sorted(self.symbol_roundtrip_trades_[pair_name], key=lambda x: x["open_time"]):
# Only store completed trades in pt_bt_results table # Only store completed trades in pt_bt_results table
cursor.execute( cursor.execute(
""" """

View File

@ -316,4 +316,4 @@ class RollingFit(PairsTradingFitMethod):
) )
def reset(self) -> None: def reset(self) -> None:
curr_training_start_idx = 0 pass

View File

@ -119,8 +119,8 @@ class TradingPair(ABC):
return return
execution_price_column = self.config_["execution_price"]["column"] execution_price_column = self.config_["execution_price"]["column"]
execution_price_shift = self.config_["execution_price"]["shift"] execution_price_shift = self.config_["execution_price"]["shift"]
self.market_data_[f"exec_price_{self.symbol_a_}"] = self.market_data_[f"{self.stat_model_price_}_{self.symbol_a_}"].shift(-execution_price_shift) self.market_data_[f"exec_price_{self.symbol_a_}"] = self.market_data_[f"{execution_price_column}_{self.symbol_a_}"].shift(-execution_price_shift)
self.market_data_[f"exec_price_{self.symbol_b_}"] = self.market_data_[f"{self.stat_model_price_}_{self.symbol_b_}"].shift(-execution_price_shift) self.market_data_[f"exec_price_{self.symbol_b_}"] = self.market_data_[f"{execution_price_column}_{self.symbol_b_}"].shift(-execution_price_shift)
self.market_data_ = self.market_data_.dropna().reset_index(drop=True) self.market_data_ = self.market_data_.dropna().reset_index(drop=True)

33
lib/tools/filetools.py Normal file
View File

@ -0,0 +1,33 @@
import os
import glob
from typing import Dict, List, Tuple
# A trading day as the 8-digit ``YYYYMMDD`` string taken from the file name.
DayT = str
# Path of a market-data file.
DataFileNameT = str


def resolve_datafiles(
    config: Dict, date_pattern: str, instruments: List[Dict[str, str]]
) -> List[Tuple[DayT, DataFileNameT]]:
    """Resolve the market-data files of every instrument for a date pattern.

    For each instrument the data directory is looked up under
    ``config["market_data_loading"][<instrument_type>]["data_directory"]``.
    A relative ``date_pattern`` is expanded to
    ``<data_directory>/<date_pattern>.mktdata.ohlcv.db``; an absolute one is
    used as given.

    Args:
        config: Configuration holding the ``"market_data_loading"`` section.
        date_pattern: A date (``YYYYMMDD``) or a pattern with ``*`` / ``?``
            wildcards, which is expanded with ``glob``.
        instruments: Instrument dicts, each with an ``"instrument_type"`` key.

    Returns:
        Sorted, de-duplicated ``(day, file name)`` tuples.  With wildcards the
        day is the 8-digit group in front of ``.mktdata.ohlcv.db``; without
        them it is ``date_pattern`` itself and the file is not checked for
        existence.

    Raises:
        ValueError: If a globbed file name does not end in
            ``<8 digits>.mktdata.ohlcv.db``.  The original used ``assert``,
            which is stripped under ``python -O``.
    """
    import re

    day_regex = re.compile(r"(\d{8})\.mktdata\.ohlcv\.db$")
    has_wildcard = "*" in date_pattern or "?" in date_pattern

    resolved_files = set()
    for inst in instruments:
        inst_type = inst["instrument_type"]
        data_dir = config["market_data_loading"][inst_type]["data_directory"]

        pattern = date_pattern
        if not os.path.isabs(pattern):
            pattern = os.path.join(data_dir, f"{pattern}.mktdata.ohlcv.db")

        if not has_wildcard:
            # Explicit date or path: taken as-is.
            resolved_files.add((date_pattern, pattern))
            continue

        for matched_file in glob.glob(pattern):
            match = day_regex.search(matched_file)
            if match is None:
                raise ValueError(
                    f"Cannot extract YYYYMMDD day from data file name: {matched_file}"
                )
            resolved_files.add((match.group(1), matched_file))

    return sorted(resolved_files)

21
lib/tools/instruments.py Normal file
View File

@ -0,0 +1,21 @@
import argparse
from typing import Dict, List
def get_instruments(args: argparse.Namespace, config: Dict) -> List[Dict[str, str]]:
    """Parse the ``--instruments`` argument into instrument descriptions.

    ``args.instruments`` is a comma-separated list of
    ``SYMBOL:TYPE[:EXCHANGE_ID]`` specs.  The exchange id is optional so that
    the two-field form shown in the command-line help is accepted as well;
    it then defaults to an empty string.

    Args:
        args: Parsed command-line arguments with an ``instruments`` string.
        config: Configuration whose ``"market_data_loading"`` section maps
            each instrument type to its ``"instrument_id_pfx"`` and
            ``"db_table_name"``.

    Returns:
        One dict per spec with the keys ``"symbol"``, ``"instrument_type"``,
        ``"exchange_id"``, ``"instrument_id_pfx"`` and ``"db_table_name"``.

    Raises:
        ValueError: If a spec has fewer than two ``:``-separated fields.
        KeyError: If the instrument type is not configured.
    """
    md_loading = config["market_data_loading"]
    instruments: List[Dict[str, str]] = []
    for spec in args.instruments.split(","):
        # Split once per spec instead of once per field.
        fields = spec.split(":")
        if len(fields) < 2:
            raise ValueError(
                f"Invalid instrument '{spec}': expected SYMBOL:TYPE[:EXCHANGE_ID]"
            )
        symbol, instrument_type = fields[0], fields[1]
        exchange_id = fields[2] if len(fields) > 2 else ""
        type_config = md_loading[instrument_type]
        instruments.append(
            {
                "symbol": symbol,
                "instrument_type": instrument_type,
                "exchange_id": exchange_id,
                "instrument_id_pfx": type_config["instrument_id_pfx"],
                "db_table_name": type_config["db_table_name"],
            }
        )
    return instruments

File diff suppressed because one or more lines are too long