248 lines
8.5 KiB
Python
248 lines
8.5 KiB
Python
from __future__ import annotations
|
|
|
|
from abc import ABC, abstractmethod
|
|
from datetime import datetime
|
|
from enum import Enum
|
|
from typing import Any, Dict, List
|
|
|
|
import pandas as pd
|
|
|
|
# ---
|
|
from cvttpy_tools.base import NamedObject
|
|
from cvttpy_tools.config import Config
|
|
# ---
|
|
from cvttpy_trading.trading.instrument import ExchangeInstrument
|
|
# ---
|
|
from pairs_trading.lib.pt_strategy.model_data_policy import DataWindowParams
|
|
from pairs_trading.lib.pt_strategy.prediction import Prediction
|
|
|
|
|
|
|
|
class PairState(Enum):
    """Lifecycle state of a trading pair.

    ``INITIAL`` and ``OPEN`` are the non-closed states; each ``CLOSE*``
    member is one of the reasons a pair can end up closed.
    """

    # Not-yet-closed states.
    INITIAL = 1
    OPEN = 2

    # Closed states (generic close, position close, stop-loss, stop-profit).
    CLOSE = 3
    CLOSE_POSITION = 4
    CLOSE_STOP_LOSS = 5
    CLOSE_STOP_PROFIT = 6
|
|
|
|
|
|
# def get_symbol(instrument: Dict[str, str]) -> str:
|
|
# if "symbol" in instrument:
|
|
# return instrument["symbol"]
|
|
# elif "instrument_id" in instrument:
|
|
# instrument_id = instrument["instrument_id"]
|
|
# instrument_pfx = instrument_id[: instrument_id.find("-") + 1]
|
|
# symbol = instrument_id[len(instrument_pfx) :]
|
|
# instrument["symbol"] = symbol
|
|
# instrument["instrument_id_pfx"] = instrument_pfx
|
|
# return symbol
|
|
# else:
|
|
# raise ValueError(
|
|
# f"Invalid instrument: {instrument}, missing symbol or instrument_id"
|
|
# )
|
|
|
|
|
|
class TradingPair(NamedObject, ABC):
    """Common state and symbol helpers for a two-leg trading pair.

    Holds the configuration, the model created from it, the two exchange
    instruments (leg A = ``instruments_[0]``, leg B = ``instruments_[1]``)
    and a free-form ``user_data_`` dict that subclasses use for bookkeeping.
    """

    config_: Config
    model_: Any  # "PairsTradingModel"
    market_data_: pd.DataFrame

    user_data_: Dict[str, Any]
    # Price-column prefix used by colnames(); NOT assigned in this class,
    # subclasses must set it before colnames() is called.
    stat_model_price_: str

    instruments_: List[ExchangeInstrument]

    def __init__(
        self,
        config: Config,
        instruments: List[ExchangeInstrument],
    ):
        """Create the pair's model from ``config`` and derive leg symbols.

        Args:
            config: strategy configuration; passed to ``PairsTradingModel.create``.
            instruments: at least two instruments; the first two are the legs.
                The list is stored by reference and each of the first two
                instruments gets ``user_data_["symbol"]`` written (the caller's
                objects are mutated).

        Raises:
            ValueError: if fewer than two instruments are given.
        """
        # Function-scope import; presumably avoids a circular import with
        # pt_model -- TODO confirm.
        from pairs_trading.lib.pt_strategy.pt_model import PairsTradingModel

        if len(instruments) < 2:
            raise ValueError(
                f"Trading pair needs 2 instruments, got {len(instruments)}"
            )

        self.config_ = config
        self.model_ = PairsTradingModel.create(config)
        self.user_data_ = {}
        self.instruments_ = instruments
        for instrument in self.instruments_[:2]:
            instrument.user_data_["symbol"] = self._symbol_from_instrument_id(
                instrument.instrument_id()
            )

    @staticmethod
    def _symbol_from_instrument_id(instrument_id: str) -> str:
        """Strip the prefix up to and including the first ``-``.

        ``"PFX-SYM"`` -> ``"SYM"``, ``"PFX-A-B"`` -> ``"A-B"``. An id without
        any ``-`` is returned unchanged instead of raising ``IndexError``.
        """
        _prefix, sep, rest = instrument_id.partition("-")
        return rest if sep else instrument_id

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}:"
            f" symbol_a={self.symbol_a()},"
            f" symbol_b={self.symbol_b()},"
            f" model={self.model_.__class__.__name__}"
        )

    def colnames(self) -> List[str]:
        """Return the ``{stat_model_price_}_{symbol}`` column names for legs A and B."""
        return [
            f"{self.stat_model_price_}_{self.symbol_a()}",
            f"{self.stat_model_price_}_{self.symbol_b()}",
        ]

    def symbol_a(self) -> str:
        """Return the symbol derived for leg A in ``__init__``."""
        return self.get_instrument_a().user_data_["symbol"]

    def symbol_b(self) -> str:
        """Return the symbol derived for leg B in ``__init__``."""
        return self.get_instrument_b().user_data_["symbol"]

    def get_instrument_a(self) -> ExchangeInstrument:
        """Return leg A (the first instrument)."""
        return self.instruments_[0]

    def get_instrument_b(self) -> ExchangeInstrument:
        """Return leg B (the second instrument)."""
        return self.instruments_[1]
|
|
|
|
|
|
|
|
class ResearchTradingPair(TradingPair):
    """Trading pair used in research runs.

    Keeps its lifecycle in ``user_data_["state"]`` (a ``PairState``), the
    current open / last close trades, and any outstanding positions, and can
    evaluate the configured stop-profit / stop-loss thresholds.
    """

    def __init__(
        self,
        config: Config,
        instruments: List[ExchangeInstrument],
    ):
        assert len(instruments) == 2, "Trading pair must have exactly 2 instruments"
        super().__init__(config=config, instruments=instruments)

        self.stat_model_price_ = config.get_value("stat_model_price")
        # Replaces the empty dict created by the base class.
        self.user_data_ = {
            "state": PairState.INITIAL,
        }

        # URGENT set exchange instruments for the pair

    def is_closed(self) -> bool:
        """Return True when the current state is any of the CLOSE* states."""
        return self.user_data_["state"] in (
            PairState.CLOSE,
            PairState.CLOSE_POSITION,
            PairState.CLOSE_STOP_LOSS,
            PairState.CLOSE_STOP_PROFIT,
        )

    def is_open(self) -> bool:
        """Return True when not closed (note: INITIAL also counts as open)."""
        return not self.is_closed()

    def exec_prices_colnames(self) -> List[str]:
        """Return the ``exec_price_{symbol}`` column names for legs A and B."""
        return [
            f"exec_price_{self.symbol_a()}",
            f"exec_price_{self.symbol_b()}",
        ]

    def to_stop_close_conditions(self, predicted_row: pd.Series) -> bool:
        """Check the configured stop-profit / stop-loss thresholds.

        Reads ``stop_close_conditions`` from the config (presumably a dict
        with optional ``"profit"`` and ``"loss"`` keys -- TODO confirm). On a
        hit, records the reason in ``user_data_["stop_close_state"]`` and
        returns True; otherwise returns False.
        """
        config = self.config_
        if not config.key_exists("stop_close_conditions"):
            return False
        conditions = config.get_value("stop_close_conditions")
        if conditions is None:
            return False

        has_profit = "profit" in conditions
        has_loss = "loss" in conditions
        if not (has_profit or has_loss):
            return False

        # Computed once for both checks: a loss-only configuration used to hit
        # an unbound ``current_return`` because it was only set in the profit branch.
        current_return = self._current_return(predicted_row)

        if has_profit and current_return >= conditions["profit"]:
            print(f"STOP PROFIT: {current_return}")
            self.user_data_["stop_close_state"] = PairState.CLOSE_STOP_PROFIT
            return True

        if has_loss and current_return <= conditions["loss"]:
            print(f"STOP LOSS: {current_return}")
            self.user_data_["stop_close_state"] = PairState.CLOSE_STOP_LOSS
            return True

        return False

    def _current_return(self, predicted_row: pd.Series) -> float:
        """Return the summed percentage return of both legs since opening.

        Each leg's entry is the first row of ``open_trades`` for its symbol
        (columns ``symbol``, ``price``, ``side``); the current price is the
        ``{stat_model_price_}_{symbol}`` value in ``predicted_row``. SELL legs
        have their sign flipped. Returns 0.0 when there are no open trades.
        """
        if "open_trades" not in self.user_data_:
            return 0.0
        open_trades = self.user_data_["open_trades"]
        # len() rather than truthiness: a DataFrame's truth value is ambiguous.
        if len(open_trades) == 0:
            return 0.0

        def _single_instrument_return(symbol: str) -> float:
            instrument_open_trades = open_trades[open_trades["symbol"] == symbol]
            instrument_open_price = instrument_open_trades["price"].iloc[0]

            sign = -1 if instrument_open_trades["side"].iloc[0] == "SELL" else 1
            instrument_price = predicted_row[f"{self.stat_model_price_}_{symbol}"]
            instrument_return = (
                sign
                * (instrument_price - instrument_open_price)
                / instrument_open_price
            )
            return float(instrument_return) * 100.0

        return _single_instrument_return(self.symbol_a()) + _single_instrument_return(
            self.symbol_b()
        )

    def on_open_trades(self, trades: pd.DataFrame) -> None:
        """Record ``trades`` as the open trades and forget any prior close trades."""
        self.user_data_.pop("close_trades", None)
        self.user_data_["open_trades"] = trades

    def on_close_trades(self, trades: pd.DataFrame) -> None:
        """Record ``trades`` as the close trades and forget the open trades.

        Tolerates a missing ``open_trades`` entry, mirroring ``on_open_trades``
        (previously this raised ``KeyError``).
        """
        self.user_data_.pop("open_trades", None)
        self.user_data_["close_trades"] = trades

    def add_outstanding_position(
        self,
        symbol: str,
        open_side: str,
        open_px: float,
        open_tstamp: datetime,
        last_mkt_data_row: pd.Series,
    ) -> None:
        """Append a position record to ``user_data_["outstanding_positions"]``.

        Size is half of ``funding_per_pair`` divided by ``open_px`` (negative
        for SELL). It is marked at the ``exec_price_{symbol}`` value of
        ``last_mkt_data_row``, which must also contain a ``"tstamp"`` entry.
        """
        assert symbol in [
            self.symbol_a(),
            self.symbol_b(),
        ], "Symbol must be one of the pair's symbols"
        assert open_side in ["BUY", "SELL"], "Open side must be either BUY or SELL"
        assert open_px > 0, "Open price must be greater than 0"
        assert open_tstamp is not None, "Open timestamp must be provided"
        assert last_mkt_data_row is not None, "Last market data row must be provided"

        exec_prices_col_a, exec_prices_col_b = self.exec_prices_colnames()
        exec_col = exec_prices_col_a if symbol == self.symbol_a() else exec_prices_col_b
        last_px = last_mkt_data_row[exec_col]

        # Pair funding is split evenly between the two legs.
        funding_per_position = self.config_.get_value("funding_per_pair") / 2
        shares = funding_per_position / open_px
        if open_side == "SELL":
            shares = -shares

        self.user_data_.setdefault("outstanding_positions", []).append(
            {
                "symbol": symbol,
                "open_side": open_side,
                "open_px": open_px,
                "shares": shares,
                "open_tstamp": open_tstamp,
                "last_px": last_px,
                "last_tstamp": last_mkt_data_row["tstamp"],
                "last_value": last_px * shares,
            }
        )

    def run(self, market_data: pd.DataFrame, data_params: DataWindowParams) -> Prediction:  # type: ignore[assignment]
        """Slice the training window out of ``market_data`` and run the model.

        The window is positional: ``training_size_`` rows starting at row
        ``training_start_index_``. It is stored in ``market_data_`` before
        ``model_.predict(pair=self)`` is called.
        """
        start = data_params.training_start_index_
        # .iloc makes the slice positional regardless of the index dtype
        # (plain df[a:b] is label-based on a float index).
        self.market_data_ = market_data.iloc[start : start + data_params.training_size_]
        return self.model_.predict(pair=self)
|
|
|
|
class LiveTradingPair(TradingPair):
    """Trading pair used for live trading.

    Stop-close handling is not implemented yet, so this pair never asks to be
    stop-closed.
    """

    def __init__(self, config: Config, instruments: List[ExchangeInstrument]):
        # No live-specific setup yet: defer entirely to TradingPair.
        super().__init__(config=config, instruments=instruments)

    def to_stop_close_conditions(self, predicted_row: pd.Series) -> bool:
        """Report whether a stop close is required; currently always False."""
        # TODO LiveTradingPair.to_stop_close_conditions()
        return False
|
|
|
|
|