pairs_trading/lib/pt_strategy/trading_pair.py
Oleg Sheynin c0fabcb429 progress
2026-01-12 21:26:15 +00:00

227 lines
7.8 KiB
Python

from __future__ import annotations
from abc import ABC, abstractmethod
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List
import pandas as pd
# ---
from cvttpy_tools.base import NamedObject
from cvttpy_tools.config import Config
# ---
from cvttpy_trading.trading.instrument import ExchangeInstrument
# ---
from pairs_trading.lib.pt_strategy.model_data_policy import DataWindowParams
from pairs_trading.lib.pt_strategy.prediction import Prediction
class PairState(Enum):
    """Lifecycle states a trading pair can be in.

    ``CLOSE``, ``CLOSE_POSITION``, ``CLOSE_STOP_LOSS`` and
    ``CLOSE_STOP_PROFIT`` are the members ResearchTradingPair.is_closed()
    treats as closed; every other member counts as open there.
    """

    # State a ResearchTradingPair is given when it is constructed.
    INITIAL = 1
    OPEN = 2

    # --- closed states ---
    CLOSE = 3
    CLOSE_POSITION = 4
    # Written to user_data_["stop_close_state"] when the loss threshold hits.
    CLOSE_STOP_LOSS = 5
    # Written to user_data_["stop_close_state"] when the profit threshold hits.
    CLOSE_STOP_PROFIT = 6
class TradingPair(NamedObject, ABC):
    """Base class for a two-instrument pair driven by a pairs-trading model.

    The pair owns a model created from ``config`` via
    ``PairsTradingModel.create`` and exposes the instrument symbols and the
    price-column names (``colnames``) that identify the pair's columns in
    ``market_data_``.
    """

    config_: Config
    model_: Any  # "PairsTradingModel" (imported lazily in __init__)
    market_data_: pd.DataFrame
    user_data_: Dict[str, Any]
    stat_model_price_: str
    instruments_: List[ExchangeInstrument]

    def __init__(
        self,
        config: Config,
        instruments: List[ExchangeInstrument],
    ):
        """Build the pair and its model.

        Args:
            config: Strategy configuration. ``model/stat_model_price`` is the
                prefix used to build the price column names in ``colnames``.
            instruments: Exactly two instruments, in (A, B) order. Each one's
                ``user_data_["symbol"]`` is set in place to the part of its
                instrument id after the first ``-``.

        Raises:
            AssertionError: if ``instruments`` does not hold exactly two items.
            ValueError: if an instrument id contains no ``-`` separator.
        """
        # Function-scope import, presumably to avoid an import cycle with
        # pt_model — TODO confirm.
        from pairs_trading.lib.pt_strategy.pt_model import PairsTradingModel

        # Everything below indexes instruments[0] / instruments[1] only.
        assert len(instruments) == 2, "Trading pair must have exactly 2 instruments"
        self.config_ = config
        self.model_ = PairsTradingModel.create(config)
        self.user_data_ = {}
        self.instruments_ = instruments
        for instrument in self.instruments_:
            instrument.user_data_["symbol"] = self._symbol_from_instrument_id(
                instrument.instrument_id()
            )
        self.stat_model_price_ = config.get_value("model/stat_model_price")
        # Replaced by run(); initialised so reading it before the first run()
        # gives an empty frame instead of AttributeError.
        self.market_data_ = pd.DataFrame()

    @staticmethod
    def _symbol_from_instrument_id(instrument_id: str) -> str:
        """Return the part of ``instrument_id`` after its first ``-``.

        For example ``"X-Y-Z"`` -> ``"Y-Z"``.

        Raises:
            ValueError: if ``instrument_id`` contains no ``-``.
        """
        _, sep, symbol = instrument_id.partition("-")
        if not sep:
            raise ValueError(
                f"Cannot derive symbol from instrument id {instrument_id!r}: "
                "expected a '-' separator"
            )
        return symbol

    def run(self, market_data: pd.DataFrame, data_params: DataWindowParams) -> Prediction:  # type: ignore[assignment]
        """Store the training window of ``market_data`` and run the model.

        The window is positional: rows
        ``[training_start_index_, training_start_index_ + training_size_)``.
        It is assigned to ``market_data_`` before ``model_.predict`` is called
        with this pair, and the model's prediction is returned.
        """
        start = data_params.training_start_index_
        stop = start + data_params.training_size_
        # .iloc keeps the slice positional for every index dtype; plain
        # df[a:b] is label-based on a float index.
        self.market_data_ = market_data.iloc[start:stop]
        return self.model_.predict(pair=self)

    def colnames(self) -> List[str]:
        """Return the ``<stat_model_price>_<symbol>`` column names for A and B."""
        return [
            f"{self.stat_model_price_}_{self.symbol_a()}",
            f"{self.stat_model_price_}_{self.symbol_b()}",
        ]

    def symbol_a(self) -> str:
        """Return the symbol of the first instrument."""
        return self.get_instrument_a().user_data_["symbol"]

    def symbol_b(self) -> str:
        """Return the symbol of the second instrument."""
        return self.get_instrument_b().user_data_["symbol"]

    def get_instrument_a(self) -> ExchangeInstrument:
        """Return the first instrument of the pair."""
        return self.instruments_[0]

    def get_instrument_b(self) -> ExchangeInstrument:
        """Return the second instrument of the pair."""
        return self.instruments_[1]

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}:"
            f" symbol_a={self.symbol_a()},"
            f" symbol_b={self.symbol_b()},"
            f" model={self.model_.__class__.__name__}"
        )
class ResearchTradingPair(TradingPair):
    """Trading pair that tracks its own lifecycle and trade bookkeeping.

    ``user_data_`` holds:
      * ``"state"``: a :class:`PairState`, ``INITIAL`` after construction;
      * ``"open_trades"`` / ``"close_trades"``: the most recent trade frames
        (see ``on_open_trades`` / ``on_close_trades``);
      * ``"stop_close_state"``: set when a stop threshold fires;
      * ``"outstanding_positions"``: list of dicts from
        ``add_outstanding_position``.
    """

    # States for which is_closed() returns True.
    _CLOSED_STATES = frozenset(
        {
            PairState.CLOSE,
            PairState.CLOSE_POSITION,
            PairState.CLOSE_STOP_LOSS,
            PairState.CLOSE_STOP_PROFIT,
        }
    )

    def __init__(
        self,
        config: Config,
        instruments: List[ExchangeInstrument],
    ):
        assert len(instruments) == 2, "Trading pair must have exactly 2 instruments"
        super().__init__(config=config, instruments=instruments)
        # Replaces the empty dict created by the base class.
        self.user_data_ = {
            "state": PairState.INITIAL,
        }

    def is_closed(self) -> bool:
        """Return True if ``user_data_["state"]`` is one of the close states."""
        return self.user_data_["state"] in self._CLOSED_STATES

    def is_open(self) -> bool:
        """Return True when not closed (``INITIAL`` also counts as open)."""
        return not self.is_closed()

    def exec_prices_colnames(self) -> List[str]:
        """Return the ``exec_price_<symbol>`` column names for A and B."""
        return [
            f"exec_price_{self.symbol_a()}",
            f"exec_price_{self.symbol_b()}",
        ]

    def to_stop_close_conditions(self, predicted_row: pd.Series) -> bool:
        """Check the configured stop thresholds against the current return.

        ``stop_close_conditions`` is read from the config; its optional
        ``"profit"`` / ``"loss"`` entries are compared against
        ``_current_return`` (a percentage). Profit is checked first. When a
        threshold fires, the matching ``PairState`` is stored in
        ``user_data_["stop_close_state"]`` and True is returned.

        Returns:
            True if a stop condition fired, else False (including when no
            conditions are configured).
        """
        config = self.config_
        if not config.key_exists("stop_close_conditions"):
            return False
        conditions = config.get_value("stop_close_conditions")
        if conditions is None:
            return False

        has_profit = "profit" in conditions
        has_loss = "loss" in conditions
        if not (has_profit or has_loss):
            return False

        # Computed once for both checks; previously it was only computed in
        # the profit branch, so a loss-only config raised UnboundLocalError.
        current_return = self._current_return(predicted_row)

        if has_profit and current_return >= conditions["profit"]:
            print(f"STOP PROFIT: {current_return}")
            self.user_data_["stop_close_state"] = PairState.CLOSE_STOP_PROFIT
            return True
        if has_loss and current_return <= conditions["loss"]:
            print(f"STOP LOSS: {current_return}")
            self.user_data_["stop_close_state"] = PairState.CLOSE_STOP_LOSS
            return True
        return False

    def _current_return(self, predicted_row: pd.Series) -> float:
        """Return the summed percentage return of both open legs.

        Each leg uses the first open trade for its symbol and is valued at
        ``predicted_row["<stat_model_price>_<symbol>"]``:
        ``sign * (price - open_price) / open_price * 100``, where ``sign`` is
        -1 for a ``"SELL"`` open and +1 otherwise. Returns 0.0 when no open
        trades are recorded.

        Assumes ``open_trades`` has ``symbol``, ``price`` and ``side`` columns
        and at least one row per symbol (``.iloc[0]`` raises IndexError
        otherwise).
        """
        open_trades = self.user_data_.get("open_trades")
        if open_trades is None or len(open_trades) == 0:
            return 0.0

        def _single_instrument_return(symbol: str) -> float:
            leg_trades = open_trades[open_trades["symbol"] == symbol]
            open_price = leg_trades["price"].iloc[0]
            sign = -1 if leg_trades["side"].iloc[0] == "SELL" else 1
            price = predicted_row[f"{self.stat_model_price_}_{symbol}"]
            return float(sign * (price - open_price) / open_price) * 100.0

        return _single_instrument_return(self.symbol_a()) + _single_instrument_return(
            self.symbol_b()
        )

    def on_open_trades(self, trades: pd.DataFrame) -> None:
        """Record ``trades`` as the open trades, dropping any close trades."""
        self.user_data_.pop("close_trades", None)
        self.user_data_["open_trades"] = trades

    def on_close_trades(self, trades: pd.DataFrame) -> None:
        """Record ``trades`` as the close trades, dropping any open trades."""
        # pop() rather than del: a close without a recorded open no longer
        # raises KeyError.
        self.user_data_.pop("open_trades", None)
        self.user_data_["close_trades"] = trades

    def add_outstanding_position(
        self,
        symbol: str,
        open_side: str,
        open_px: float,
        open_tstamp: datetime,
        last_mkt_data_row: pd.Series,
    ) -> None:
        """Append a position record to ``user_data_["outstanding_positions"]``.

        The config value ``funding_per_pair`` is split evenly across the two
        legs; ``shares`` is that half divided by ``open_px``, negated for a
        ``"SELL"`` open. ``last_px`` comes from the symbol's ``exec_price_*``
        column in ``last_mkt_data_row``, and ``last_value`` is
        ``last_px * shares`` (signed).

        Raises:
            AssertionError: on an unknown symbol, a side other than
                BUY/SELL, a non-positive price, or a missing timestamp/row.
        """
        assert symbol in [
            self.symbol_a(),
            self.symbol_b(),
        ], "Symbol must be one of the pair's symbols"
        assert open_side in ["BUY", "SELL"], "Open side must be either BUY or SELL"
        assert open_px > 0, "Open price must be greater than 0"
        assert open_tstamp is not None, "Open timestamp must be provided"
        assert last_mkt_data_row is not None, "Last market data row must be provided"

        exec_col_a, exec_col_b = self.exec_prices_colnames()
        last_px = last_mkt_data_row[
            exec_col_a if symbol == self.symbol_a() else exec_col_b
        ]

        funding_per_position = self.config_.get_value("funding_per_pair") / 2
        shares = funding_per_position / open_px
        if open_side == "SELL":
            shares = -shares

        self.user_data_.setdefault("outstanding_positions", []).append(
            {
                "symbol": symbol,
                "open_side": open_side,
                "open_px": open_px,
                "shares": shares,
                "open_tstamp": open_tstamp,
                "last_px": last_px,
                "last_tstamp": last_mkt_data_row["tstamp"],
                "last_value": last_px * shares,
            }
        )
class LiveTradingPair(TradingPair):
    """Trading pair used on the live path; stop conditions are not yet implemented."""

    def __init__(self, config: Config, instruments: List[ExchangeInstrument]):
        # Same precondition ResearchTradingPair enforces: the base class only
        # reads instruments[0] and instruments[1], so extra instruments would
        # otherwise be silently ignored.
        assert len(instruments) == 2, "Trading pair must have exactly 2 instruments"
        super().__init__(config=config, instruments=instruments)

    def to_stop_close_conditions(self, predicted_row: pd.Series) -> bool:
        """Return whether a stop condition fired; currently always False."""
        # TODO LiveTradingPair.to_stop_close_conditions()
        return False