165 lines
5.9 KiB
Python
165 lines
5.9 KiB
Python
from __future__ import annotations
|
|
|
|
from abc import ABC, abstractmethod
|
|
from datetime import datetime
|
|
from enum import Enum
|
|
from typing import Any, Dict, Generator, List, Optional, Type, cast
|
|
|
|
import pandas as pd
|
|
from pt_strategy.model_data_policy import DataWindowParams
|
|
|
|
|
|
class PairState(Enum):
    """Lifecycle states a trading pair can be in.

    Members are declared in ascending value order (1..6); iteration order and
    the integer values are part of the contract and must not be reshuffled.
    """

    INITIAL = 1  # freshly constructed pair, nothing traded yet
    OPEN = 2
    CLOSE = 3
    CLOSE_POSITION = 4
    CLOSE_STOP_LOSS = 5  # closed because the configured loss threshold was hit
    CLOSE_STOP_PROFIT = 6  # closed because the configured profit threshold was hit
|
|
|
|
class TradingPair:
    """A pair of two instruments plus the bookkeeping state used while trading it.

    Holds the strategy configuration, the two symbols, the statistical model
    created from the configuration, and a free-form ``user_data_`` dict that
    tracks trade state ("state", "open_trades", "close_trades",
    "stop_close_state", "outstanding_positions").
    """

    config_: Dict[str, Any]
    market_data_: pd.DataFrame  # training window assigned by ``run``
    symbol_a_: str
    symbol_b_: str

    # Prefix of the price columns the model reads: "<stat_model_price_>_<symbol>".
    stat_model_price_: str
    model_: PairsTradingModel  # type: ignore[assignment]

    user_data_: Dict[str, Any]

    def __init__(
        self,
        config: Dict[str, Any],
        instruments: List[Dict[str, str]],
    ):
        """Build a pair from the strategy config and exactly two instruments.

        Args:
            config: Strategy configuration; must contain "stat_model_price" and
                whatever ``PairsTradingModel.create`` needs.
            instruments: Two dicts, each with a "symbol" key; the first becomes
                leg A, the second leg B.
        """
        # Imported lazily, presumably to avoid a circular import with other
        # pt_strategy modules — TODO confirm. ModelDataPolicy is kept in case
        # importing it has side effects.
        from pt_strategy.model_data_policy import ModelDataPolicy  # noqa: F401
        from pt_strategy.pt_model import PairsTradingModel

        assert len(instruments) == 2, "Trading pair must have exactly 2 instruments"

        self.config_ = config
        self.symbol_a_ = instruments[0]["symbol"]
        self.symbol_b_ = instruments[1]["symbol"]
        self.model_ = PairsTradingModel.create(config)
        self.stat_model_price_ = config["stat_model_price"]
        self.user_data_ = {
            "state": PairState.INITIAL,
        }

    def colnames(self) -> List[str]:
        """Return the model price column names for legs A and B."""
        return [
            f"{self.stat_model_price_}_{self.symbol_a_}",
            f"{self.stat_model_price_}_{self.symbol_b_}",
        ]

    def exec_prices_colnames(self) -> List[str]:
        """Return the execution price column names for legs A and B."""
        return [
            f"exec_price_{self.symbol_a_}",
            f"exec_price_{self.symbol_b_}",
        ]

    def to_stop_close_conditions(self, predicted_row: pd.Series) -> bool:
        """Return True when a configured stop-profit or stop-loss threshold is hit.

        ``config_["stop_close_conditions"]`` is an optional mapping with optional
        "profit" and "loss" thresholds, compared against ``_current_return`` (a
        percentage). Profit is checked before loss. On a hit, a message is
        printed and the matching ``PairState`` is stored in
        ``user_data_["stop_close_state"]``.

        Args:
            predicted_row: Row holding the current model prices, indexed by
                ``colnames()``.

        Returns:
            True if a stop condition fired, False otherwise.
        """
        conditions = self.config_.get("stop_close_conditions")
        if conditions is None:
            return False

        has_profit = "profit" in conditions
        has_loss = "loss" in conditions
        if not (has_profit or has_loss):
            return False

        # Compute once for both checks; previously it was only computed inside
        # the "profit" branch, so a loss-only config raised UnboundLocalError.
        current_return = self._current_return(predicted_row)

        if has_profit and current_return >= conditions["profit"]:
            print(f"STOP PROFIT: {current_return}")
            self.user_data_["stop_close_state"] = PairState.CLOSE_STOP_PROFIT
            return True

        if has_loss and current_return <= conditions["loss"]:
            print(f"STOP LOSS: {current_return}")
            self.user_data_["stop_close_state"] = PairState.CLOSE_STOP_LOSS
            return True

        return False

    def _current_return(self, predicted_row: pd.Series) -> float:
        """Return the pair's unrealised return in percent, summed over both legs.

        Each leg's return is ``sign * (price - open_price) / open_price * 100``.
        ``sign`` is -1 for a "SELL" open and +1 otherwise. ``price`` is read from
        ``predicted_row`` and ``open_price`` from the first open trade of that
        symbol. Returns 0.0 when there are no open trades.

        Assumes ``user_data_["open_trades"]`` is a DataFrame with "symbol",
        "price" and "side" columns holding at least one row per leg; a leg with
        no row raises IndexError.
        """
        open_trades = self.user_data_.get("open_trades")
        if open_trades is None or len(open_trades) == 0:
            return 0.0

        def _leg_return(symbol: str) -> float:
            leg_trades = open_trades[open_trades["symbol"] == symbol]
            open_price = leg_trades["price"].iloc[0]
            sign = -1 if leg_trades["side"].iloc[0] == "SELL" else 1
            price = predicted_row[f"{self.stat_model_price_}_{symbol}"]
            return float(sign * (price - open_price) / open_price) * 100.0

        return _leg_return(self.symbol_a_) + _leg_return(self.symbol_b_)

    def on_open_trades(self, trades: pd.DataFrame) -> None:
        """Record newly opened trades and forget any previous close trades."""
        self.user_data_.pop("close_trades", None)
        self.user_data_["open_trades"] = trades

    def on_close_trades(self, trades: pd.DataFrame) -> None:
        """Record closing trades and forget the open trades, if any.

        Tolerates being called without prior open trades; the original ``del``
        raised KeyError in that case, unlike ``on_open_trades``.
        """
        self.user_data_.pop("open_trades", None)
        self.user_data_["close_trades"] = trades

    def add_outstanding_position(
        self,
        symbol: str,
        open_side: str,
        open_px: float,
        open_tstamp: datetime,
        last_mkt_data_row: pd.Series,
    ) -> None:
        """Append an outstanding position for one leg to ``user_data_``.

        Each leg is funded with half of ``config_["funding_per_pair"]``. The share
        count is that funding divided by ``open_px``, negated for a "SELL". The
        position is marked at the leg's exec price from ``last_mkt_data_row``.

        Args:
            symbol: One of the pair's two symbols.
            open_side: "BUY" or "SELL".
            open_px: Opening price; must be positive.
            open_tstamp: Opening timestamp.
            last_mkt_data_row: Latest market data row with the exec price columns
                and a "tstamp" entry.
        """
        assert symbol in [self.symbol_a_, self.symbol_b_], "Symbol must be one of the pair's symbols"
        assert open_side in ["BUY", "SELL"], "Open side must be either BUY or SELL"
        assert open_px > 0, "Open price must be greater than 0"
        assert open_tstamp is not None, "Open timestamp must be provided"
        assert last_mkt_data_row is not None, "Last market data row must be provided"

        exec_prices_col_a, exec_prices_col_b = self.exec_prices_colnames()
        last_px_col = exec_prices_col_a if symbol == self.symbol_a_ else exec_prices_col_b
        last_px = last_mkt_data_row[last_px_col]

        # Pair funding is split equally between the two legs.
        funding_per_position = self.config_["funding_per_pair"] / 2
        shares = funding_per_position / open_px
        if open_side == "SELL":
            shares = -shares  # short positions carry negative share counts

        self.user_data_.setdefault("outstanding_positions", []).append({
            "symbol": symbol,
            "open_side": open_side,
            "open_px": open_px,
            "shares": shares,
            "open_tstamp": open_tstamp,
            "last_px": last_px,
            "last_tstamp": last_mkt_data_row["tstamp"],
            "last_value": last_px * shares,
        })

    def run(self, market_data: pd.DataFrame, data_params: DataWindowParams) -> Prediction:  # type: ignore[assignment]
        """Slice the training window, store it, and return the model's prediction.

        The window is the ``data_params.training_size`` rows starting at
        ``data_params.training_start_index``, selected positionally.
        """
        start = data_params.training_start_index
        # Explicit positional slicing: ``df[a:b]`` is positional only by
        # convention and becomes ambiguous on non-integer indexes.
        self.market_data_ = market_data.iloc[start:start + data_params.training_size]
        return self.model_.predict(pair=self)
|
|
|
|
|
|
|