from __future__ import annotations

from abc import ABC, abstractmethod
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List

import pandas as pd

# ---
from cvttpy_tools.base import NamedObject
from cvttpy_tools.config import Config

# ---
from cvttpy_trading.trading.instrument import ExchangeInstrument

# ---
from pairs_trading.lib.pt_strategy.model_data_policy import DataWindowParams
from pairs_trading.lib.pt_strategy.prediction import Prediction


class PairState(Enum):
    """Lifecycle state of a trading pair.

    ``INITIAL`` and ``OPEN`` are non-closed states; every ``CLOSE*`` member is
    treated as closed by ``ResearchTradingPair.is_closed``.
    """

    INITIAL = 1
    OPEN = 2
    CLOSE = 3
    CLOSE_POSITION = 4
    CLOSE_STOP_LOSS = 5
    CLOSE_STOP_PROFIT = 6


class TradingPair(NamedObject, ABC):
    """Base class for a pair of two instruments evaluated by a pairs-trading model.

    The model is built from the config via ``PairsTradingModel.create`` and is
    asked for a prediction in ``run``.
    """

    config_: Config
    model_: Any  # "PairsTradingModel"

    market_data_: pd.DataFrame
    user_data_: Dict[str, Any]

    stat_model_price_: str

    instruments_: List[ExchangeInstrument]

    def __init__(
        self,
        config: Config,
        instruments: List[ExchangeInstrument],
    ):
        """Create the pair and its model.

        Args:
            config: Strategy config; ``model/stat_model_price`` is read here.
            instruments: The pair's instruments; only indices 0 and 1 are used.
                Each instrument's ``user_data_["symbol"]`` is set (mutating the
                caller's objects) to the part of its ``instrument_id()`` after
                the first ``"-"``.
        """
        # Imported locally, presumably to avoid a circular import with pt_model.
        from pairs_trading.lib.pt_strategy.pt_model import PairsTradingModel

        self.config_ = config
        self.model_ = PairsTradingModel.create(config)
        self.user_data_ = {}
        self.instruments_ = instruments
        self.instruments_[0].user_data_["symbol"] = instruments[0].instrument_id().split("-", 1)[1]
        self.instruments_[1].user_data_["symbol"] = instruments[1].instrument_id().split("-", 1)[1]
        self.stat_model_price_ = config.get_value("model/stat_model_price")

    def run(self, market_data: pd.DataFrame, data_params: DataWindowParams) -> Prediction:  # type: ignore[assignment]
        """Restrict market data to the training window and return the model's prediction.

        The window is rows ``[training_start_index_, training_start_index_ +
        training_size_)`` of ``market_data``; it is stored in ``market_data_``
        so the model can read it from this pair.
        """
        self.market_data_ = market_data[
            data_params.training_start_index_ : data_params.training_start_index_
            + data_params.training_size_
        ]
        return self.model_.predict(pair=self)

    def colnames(self) -> List[str]:
        """Return the ``{stat_model_price}_{symbol}`` column names for legs A and B."""
        return [
            f"{self.stat_model_price_}_{self.symbol_a()}",
            f"{self.stat_model_price_}_{self.symbol_b()}",
        ]

    def symbol_a(self) -> str:
        """Return the symbol stored on instrument A."""
        return self.get_instrument_a().user_data_["symbol"]

    def symbol_b(self) -> str:
        """Return the symbol stored on instrument B."""
        return self.get_instrument_b().user_data_["symbol"]

    def get_instrument_a(self) -> ExchangeInstrument:
        """Return the first instrument of the pair."""
        return self.instruments_[0]

    def get_instrument_b(self) -> ExchangeInstrument:
        """Return the second instrument of the pair."""
        return self.instruments_[1]

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}:"
            f" symbol_a={self.symbol_a()},"
            f" symbol_b={self.symbol_b()},"
            f" model={self.model_.__class__.__name__}"
        )


class ResearchTradingPair(TradingPair):
    """Trading pair used in research/backtesting, tracking state and trades in ``user_data_``."""

    def __init__(
        self,
        config: Config,
        instruments: List[ExchangeInstrument],
    ):
        assert len(instruments) == 2, "Trading pair must have exactly 2 instruments"
        super().__init__(config=config, instruments=instruments)
        self.user_data_ = {
            "state": PairState.INITIAL,
        }

    def is_closed(self) -> bool:
        """Return True when ``user_data_["state"]`` is any ``CLOSE*`` state."""
        return self.user_data_["state"] in [
            PairState.CLOSE,
            PairState.CLOSE_POSITION,
            PairState.CLOSE_STOP_LOSS,
            PairState.CLOSE_STOP_PROFIT,
        ]

    def is_open(self) -> bool:
        """Return True for any non-closed state (note: this includes ``INITIAL``)."""
        return not self.is_closed()

    def exec_prices_colnames(self) -> List[str]:
        """Return the ``exec_price_{symbol}`` column names for legs A and B."""
        return [
            f"exec_price_{self.symbol_a()}",
            f"exec_price_{self.symbol_b()}",
        ]

    def to_stop_close_conditions(self, predicted_row: pd.Series) -> bool:
        """Check whether a configured stop-profit / stop-loss threshold has been hit.

        Reads the ``stop_close_conditions`` config value. If it contains
        ``"profit"``, the stop fires when the current return is >= that value;
        if it contains ``"loss"``, the stop fires when the current return is
        <= that value. Profit is checked first. On a hit, the matching
        ``PairState`` is stored in ``user_data_["stop_close_state"]``.

        Args:
            predicted_row: Row holding ``{stat_model_price}_{symbol}`` prices.

        Returns:
            True if a stop condition fired; False otherwise, including when no
            stop conditions are configured.
        """
        config = self.config_
        if not config.key_exists("stop_close_conditions"):
            return False
        stop_conditions = config.get_value("stop_close_conditions")
        if stop_conditions is None:
            return False

        has_profit = "profit" in stop_conditions
        has_loss = "loss" in stop_conditions
        if not (has_profit or has_loss):
            return False

        # Computed once, before either check, so a loss-only configuration
        # works (it must not depend on the "profit" branch having run).
        current_return = self._current_return(predicted_row)

        if has_profit and current_return >= stop_conditions["profit"]:
            print(f"STOP PROFIT: {current_return}")
            self.user_data_["stop_close_state"] = PairState.CLOSE_STOP_PROFIT
            return True
        if has_loss and current_return <= stop_conditions["loss"]:
            print(f"STOP LOSS: {current_return}")
            self.user_data_["stop_close_state"] = PairState.CLOSE_STOP_LOSS
            return True
        return False

    def _current_return(self, predicted_row: pd.Series) -> float:
        """Return the pair's current return, in percent, at ``predicted_row``.

        For each leg, the return runs from the ``price`` of its first row in
        ``user_data_["open_trades"]`` to ``predicted_row[f"{stat_model_price}_{symbol}"]``.
        The sign is flipped when that row's ``side`` is ``"SELL"``. The two legs
        are summed. Returns 0.0 when there are no open trades.
        """
        if "open_trades" not in self.user_data_:
            return 0.0
        open_trades = self.user_data_["open_trades"]
        if len(open_trades) == 0:
            return 0.0

        def _single_instrument_return(symbol: str) -> float:
            # NOTE(review): raises IndexError if open_trades has no row for symbol.
            instrument_open_trades = open_trades[open_trades["symbol"] == symbol]
            instrument_open_price = instrument_open_trades["price"].iloc[0]

            sign = -1 if instrument_open_trades["side"].iloc[0] == "SELL" else 1
            instrument_price = predicted_row[f"{self.stat_model_price_}_{symbol}"]
            instrument_return = (
                sign * (instrument_price - instrument_open_price) / instrument_open_price
            )
            return float(instrument_return) * 100.0

        instrument_a_return = _single_instrument_return(self.symbol_a())
        instrument_b_return = _single_instrument_return(self.symbol_b())
        return instrument_a_return + instrument_b_return

    def on_open_trades(self, trades: pd.DataFrame) -> None:
        """Record opening trades, discarding any previously recorded closing trades."""
        self.user_data_.pop("close_trades", None)
        self.user_data_["open_trades"] = trades

    def on_close_trades(self, trades: pd.DataFrame) -> None:
        """Record closing trades and drop the open trades.

        Tolerates a missing ``open_trades`` entry, mirroring how
        ``on_open_trades`` treats ``close_trades``.
        """
        self.user_data_.pop("open_trades", None)
        self.user_data_["close_trades"] = trades

    def add_outstanding_position(
        self,
        symbol: str,
        open_side: str,
        open_px: float,
        open_tstamp: datetime,
        last_mkt_data_row: pd.Series,
    ) -> None:
        """Append a position that is still open to ``user_data_["outstanding_positions"]``.

        The position size is ``funding_per_pair / 2`` divided by ``open_px``,
        negated for ``"SELL"``. The last price comes from the leg's
        ``exec_price_{symbol}`` column in ``last_mkt_data_row``.

        Raises:
            AssertionError: If any argument fails validation.
        """
        assert symbol in [
            self.symbol_a(),
            self.symbol_b(),
        ], "Symbol must be one of the pair's symbols"
        assert open_side in ["BUY", "SELL"], "Open side must be either BUY or SELL"
        assert open_px > 0, "Open price must be greater than 0"
        assert open_tstamp is not None, "Open timestamp must be provided"
        assert last_mkt_data_row is not None, "Last market data row must be provided"

        exec_prices_col_a, exec_prices_col_b = self.exec_prices_colnames()
        if symbol == self.symbol_a():
            last_px = last_mkt_data_row[exec_prices_col_a]
        else:
            last_px = last_mkt_data_row[exec_prices_col_b]

        # The pair's funding is split evenly between its two legs.
        funding_per_position = self.config_.get_value("funding_per_pair") / 2
        shares = funding_per_position / open_px
        if open_side == "SELL":
            shares = -shares

        if "outstanding_positions" not in self.user_data_:
            self.user_data_["outstanding_positions"] = []

        self.user_data_["outstanding_positions"].append(
            {
                "symbol": symbol,
                "open_side": open_side,
                "open_px": open_px,
                "shares": shares,
                "open_tstamp": open_tstamp,
                "last_px": last_px,
                "last_tstamp": last_mkt_data_row["tstamp"],
                "last_value": last_px * shares,
            }
        )


class LiveTradingPair(TradingPair):
    """Trading pair used in live trading; stop conditions are not implemented yet."""

    def __init__(self, config: Config, instruments: List[ExchangeInstrument]):
        super().__init__(config, instruments)

    def to_stop_close_conditions(self, predicted_row: pd.Series) -> bool:
        """Always return False (no live stop conditions yet)."""
        # TODO LiveTradingPair.to_stop_close_conditions()
        return False