pairs_trading/lib/pt_strategy/trading_pair.py
2025-07-30 04:08:02 +00:00

171 lines
6.1 KiB
Python

from __future__ import annotations
from abc import ABC, abstractmethod
from datetime import datetime
from enum import Enum
from typing import Any, Dict, Optional, Type, cast, Generator, List
import pandas as pd
from pt_strategy.model_data_policy import DataParams
class PairState(Enum):
    """Lifecycle states tracked for a trading pair.

    Members carry explicit integer values 1..6 in declaration order.
    ``TradingPair`` stores its current state under ``user_data_["state"]``
    and a stop-triggered close reason under ``user_data_["stop_close_state"]``.
    """

    INITIAL = 1  # value assigned to user_data_["state"] when a pair is constructed
    OPEN = 2
    CLOSE = 3
    CLOSE_POSITION = 4
    CLOSE_STOP_LOSS = 5  # recorded when the configured "loss" threshold is hit
    CLOSE_STOP_PROFIT = 6  # recorded when the configured "profit" threshold is hit
class TradingPair:
config_: Dict[str, Any]
market_data_: pd.DataFrame
symbol_a_: str
symbol_b_: str
stat_model_price_: str
model_: PairsTradingModel # type: ignore[assignment]
model_tdp_: ModelDataPolicy # type: ignore[assignment]
user_data_: Dict[str, Any]
def __init__(
    self,
    config: Dict[str, Any],
    instruments: List[Dict[str, str]],
):
    """Build a pair from a strategy config and exactly two instrument specs.

    Args:
        config: Strategy configuration. It is handed to
            ``PairsTradingModel.create`` and ``ModelDataPolicy.create`` and
            must contain ``"stat_model_price"``, the prefix used to build the
            model-price column names.
        instruments: Two mappings, each with a ``"symbol"`` key; the first
            becomes leg A (``symbol_a_``), the second leg B (``symbol_b_``).

    Raises:
        AssertionError: if ``instruments`` does not hold exactly two entries
            (only while assertions are enabled).
    """
    # Imported inside the constructor rather than at module level;
    # presumably to sidestep an import cycle -- TODO confirm.
    from pt_strategy.model_data_policy import ModelDataPolicy
    from pt_strategy.pt_model import PairsTradingModel

    assert len(instruments) == 2, "Trading pair must have exactly 2 instruments"

    self.config_ = config
    self.symbol_a_ = instruments[0]["symbol"]
    self.symbol_b_ = instruments[1]["symbol"]

    # Model and training-data policy are both factory-built from the config.
    self.model_ = PairsTradingModel.create(config)
    self.model_tdp_ = ModelDataPolicy.create(config)

    self.stat_model_price_ = config["stat_model_price"]
    self.user_data_ = dict(state=PairState.INITIAL)
def colnames(self) -> List[str]:
    """Return the model-price column names for leg A then leg B.

    Each name has the form ``"<stat_model_price_>_<symbol>"``.
    """
    price_prefix = self.stat_model_price_
    return [f"{price_prefix}_{sym}" for sym in (self.symbol_a_, self.symbol_b_)]
def exec_prices_colnames(self) -> List[str]:
    """Return the execution-price column names for leg A then leg B.

    Each name has the form ``"exec_price_<symbol>"``.
    """
    legs = (self.symbol_a_, self.symbol_b_)
    return [f"exec_price_{leg}" for leg in legs]
def to_stop_close_conditions(self, predicted_row: pd.Series) -> bool:
    """Decide whether the open position must be force-closed by a stop rule.

    ``config_["stop_close_conditions"]`` is an optional mapping that may hold
    a ``"profit"`` threshold, a ``"loss"`` threshold, or both. Each is
    compared against ``self._current_return(predicted_row)``.

    Args:
        predicted_row: Row forwarded to ``_current_return`` to value the
            open legs.

    Returns:
        True when a threshold is reached. In that case
        ``user_data_["stop_close_state"]`` is set to
        ``PairState.CLOSE_STOP_PROFIT`` (return >= profit) or
        ``PairState.CLOSE_STOP_LOSS`` (return <= loss); profit is checked
        first. False otherwise, including when no conditions are configured.
    """
    conditions = self.config_.get("stop_close_conditions")
    if conditions is None:
        return False

    has_profit = "profit" in conditions
    has_loss = "loss" in conditions
    if not (has_profit or has_loss):
        return False

    # Computed once, independently of either branch, so that a loss-only
    # configuration works. Binding it inside the profit branch would leave
    # it undefined when only "loss" is configured.
    current_return = self._current_return(predicted_row)

    if has_profit and current_return >= conditions["profit"]:
        print(f"STOP PROFIT: {current_return}")
        self.user_data_["stop_close_state"] = PairState.CLOSE_STOP_PROFIT
        return True
    if has_loss and current_return <= conditions["loss"]:
        print(f"STOP LOSS: {current_return}")
        self.user_data_["stop_close_state"] = PairState.CLOSE_STOP_LOSS
        return True
    return False
def _current_return(self, predicted_row: pd.Series) -> float:
if "open_trades" in self.user_data_:
open_trades = self.user_data_["open_trades"]
if len(open_trades) == 0:
return 0.0
def _single_instrument_return(symbol: str) -> float:
instrument_open_trades = open_trades[open_trades["symbol"] == symbol]
instrument_open_price = instrument_open_trades["price"].iloc[0]
sign = -1 if instrument_open_trades["side"].iloc[0] == "SELL" else 1
instrument_price = predicted_row[f"{self.stat_model_price_}_{symbol}"]
instrument_return = (
sign
* (instrument_price - instrument_open_price)
/ instrument_open_price
)
return float(instrument_return) * 100.0
instrument_a_return = _single_instrument_return(self.symbol_a_)
instrument_b_return = _single_instrument_return(self.symbol_b_)
return instrument_a_return + instrument_b_return
return 0.0
def on_open_trades(self, trades: pd.DataFrame) -> None:
    """Store ``trades`` as the pair's open trades.

    Any previously stored ``"close_trades"`` entry is discarded first, so
    after this call only ``"open_trades"`` is present.
    """
    self.user_data_.pop("close_trades", None)
    self.user_data_["open_trades"] = trades
def on_close_trades(self, trades: pd.DataFrame) -> None:
    """Replace the stored open trades with the closing ``trades``.

    Raises:
        KeyError: if no ``"open_trades"`` entry is currently stored; in that
            case ``user_data_`` is left unchanged.
    """
    self.user_data_.pop("open_trades")
    self.user_data_["close_trades"] = trades
def add_outstanding_position(
    self,
    symbol: str,
    open_side: str,
    open_px: float,
    open_tstamp: datetime,
    last_mkt_data_row: pd.Series,
) -> None:
    """Append a position record for one leg to ``user_data_["outstanding_positions"]``.

    The leg is allotted half of ``config_["funding_per_pair"]``. Its share
    count is that amount divided by ``open_px``, negated for ``"SELL"``. It
    is marked at the ``exec_price_<symbol>`` entry of ``last_mkt_data_row``,
    and that row's ``"tstamp"`` entry is recorded as ``last_tstamp``.

    Raises:
        AssertionError: for a symbol outside the pair, a side other than
            BUY/SELL, a non-positive ``open_px``, or a missing timestamp or
            market-data row (only while assertions are enabled).
    """
    assert symbol in [self.symbol_a_, self.symbol_b_], "Symbol must be one of the pair's symbols"
    assert open_side in ["BUY", "SELL"], "Open side must be either BUY or SELL"
    assert open_px > 0, "Open price must be greater than 0"
    assert open_tstamp is not None, "Open timestamp must be provided"
    assert last_mkt_data_row is not None, "Last market data row must be provided"

    col_a, col_b = self.exec_prices_colnames()
    mark_px = last_mkt_data_row[col_a if symbol == self.symbol_a_ else col_b]

    leg_funding = self.config_["funding_per_pair"] / 2
    unsigned_qty = leg_funding / open_px
    signed_qty = -unsigned_qty if open_side == "SELL" else unsigned_qty

    # The list is created before the record is built, matching the original
    # side-effect order even if the "tstamp" lookup below raises.
    positions = self.user_data_.setdefault("outstanding_positions", [])
    positions.append(
        {
            "symbol": symbol,
            "open_side": open_side,
            "open_px": open_px,
            "shares": signed_qty,
            "open_tstamp": open_tstamp,
            "last_px": mark_px,
            "last_tstamp": last_mkt_data_row["tstamp"],
            "last_value": mark_px * signed_qty,
        }
    )
def run(self, market_data: pd.DataFrame, data_params: DataParams) -> Prediction: # type: ignore[assignment]
self.market_data_ = market_data[data_params.training_start_index:data_params.training_start_index + data_params.training_size]
return self.model_.predict(pair=self)
while self.model_tdp_.has_next_training_data():
training_data = self.model_tdp_.get_next_training_data()