pairs_trading/lib/pt_strategy/live/live_strategy.py
Oleg Sheynin c0fabcb429 progress
2026-01-12 21:26:15 +00:00

329 lines
12 KiB
Python

from __future__ import annotations
from typing import Any, Dict, List, Optional
import pandas as pd
# ---
from cvttpy_tools.base import NamedObject
from cvttpy_tools.app import App
from cvttpy_tools.config import Config
from cvttpy_tools.settings.cvtt_types import IntervalSecT
from cvttpy_tools.timeutils import SecPerHour, current_nanoseconds, NanoPerSec
from cvttpy_tools.logger import Log
# ---
from cvttpy_trading.trading.instrument import ExchangeInstrument
from cvttpy_trading.trading.mkt_data.md_summary import MdTradesAggregate
from cvttpy_trading.trading.trading_instructions import TradingInstructions
from cvttpy_trading.trading.trading_instructions import TargetPositionSignal
# ---
from pairs_trading.lib.pt_strategy.model_data_policy import ModelDataPolicy
from pairs_trading.lib.pt_strategy.pt_model import Prediction
from pairs_trading.lib.pt_strategy.trading_pair import LiveTradingPair
from pairs_trading.apps.pair_trader import PairTrader
from pairs_trading.lib.pt_strategy.pt_market_data import LiveMarketData
class PtLiveStrategy(NamedObject):
    """Live pairs-trading strategy driven by historical market-data snapshots.

    On every snapshot (``on_mkt_data_hist_snapshot``) the strategy checks
    that the data is fresh, converts the aggregates into a market-data frame,
    runs the pair model, and turns the prediction into TARGET_POSITION
    trading instructions for both legs of the pair:

    * ``|scaled disequilibrium| >= open_threshold_`` -> open
      (leg strengths of opposite sign);
    * ``|scaled disequilibrium| <= close_threshold_`` or the pair's stop
      conditions fire -> close (both legs to strength 0);
    * otherwise no instructions are sent.
    """

    config_: Config
    instruments_: List[ExchangeInstrument]
    interval_sec_: IntervalSecT
    history_depth_sec_: IntervalSecT
    open_threshold_: float
    close_threshold_: float
    trading_pair_: LiveTradingPair
    model_data_policy_: ModelDataPolicy
    pairs_trader_: PairTrader
    # for presentation: history of prediction values and trading signals
    predictions_df_: pd.DataFrame
    trading_signals_df_: pd.DataFrame

    def __init__(
        self,
        config: Config,
        pairs_trader: PairTrader,
    ):
        """Build the trading pair and model-data policy, and schedule config.

        Args:
            config: strategy configuration.
            pairs_trader: owning trader; supplies instruments, book id,
                market-data subscription and the trading-instruction sender.
        """
        # NOTE: ``config`` is shared by reference (not deep-copied) with the
        # trading pair and the model data policy created below.
        self.config_ = config
        self.pairs_trader_ = pairs_trader
        self.trading_pair_ = LiveTradingPair(
            config=config,
            instruments=self.pairs_trader_.instruments_,
        )
        self.model_data_policy_ = ModelDataPolicy.create(
            self.config_,
            is_real_time=True,
            pair=self.trading_pair_,
        )
        assert (
            self.model_data_policy_ is not None
        ), f"{self.fname()}: Unable to create ModelDataPolicy"
        self.predictions_df_ = pd.DataFrame()
        self.trading_signals_df_ = pd.DataFrame()
        self.instruments_ = self.pairs_trader_.instruments_
        # The coroutine object (not the function) is handed to the App;
        # presumably it is awaited during the Config stage -- confirm
        # against App.add_call.
        App.instance().add_call(
            stage=App.Stage.Config, func=self._on_config(), can_run_now=True
        )

    async def _on_config(self) -> None:
        """Read and validate strategy settings, then subscribe to market data.

        Every setting is read and validated *before* subscribing, so a
        snapshot that arrives while ``subscribe_md()`` is being awaited cannot
        observe a half-configured strategy (e.g. missing thresholds).

        Raises:
            AssertionError: if a required setting is missing or not positive.
        """
        self.interval_sec_ = self.config_.get_value("interval_sec", 0)
        assert self.interval_sec_ > 0, "interval_sec cannot be 0"
        self.history_depth_sec_ = (
            self.config_.get_value("history_depth_hours", 0) * SecPerHour
        )
        assert self.history_depth_sec_ > 0, "history_depth_hours cannot be 0"

        self.open_threshold_ = self.config_.get_value(
            "model/disequilibrium/open_trshld", 0.0
        )
        self.close_threshold_ = self.config_.get_value(
            "model/disequilibrium/close_trshld", 0.0
        )
        assert (
            self.open_threshold_ > 0
        ), "disequilibrium/open_trshld must be greater than 0"
        assert (
            self.close_threshold_ > 0
        ), "disequilibrium/close_trshld must be greater than 0"

        # Only start receiving market data once the strategy is fully set up.
        await self.pairs_trader_.subscribe_md()

    def __repr__(self) -> str:
        return (
            f"{self.classname()}: trading_pair={self.trading_pair_}, "
            f"mdp={self.model_data_policy_.__class__.__name__}"
        )

    async def on_mkt_data_hist_snapshot(
        self, hist_aggr: List[MdTradesAggregate]
    ) -> None:
        """Run the model on a fresh snapshot and send resulting instructions.

        Empty or stale snapshots, and snapshots that cannot be converted into
        a market-data frame, are skipped with a warning. Each prediction is
        appended to ``predictions_df_``.
        """
        if not self._is_md_actual(hist_aggr=hist_aggr):
            return
        market_data_df: Optional[pd.DataFrame] = self._create_md_df(
            hist_aggr=hist_aggr
        )
        if market_data_df is None or len(market_data_df) == 0:
            Log.warning(f"{self.fname()} Unable to create market data df")
            return

        self.trading_pair_.market_data_ = market_data_df
        Log.info(f"{self.fname()}: Running prediction for pair: {self.trading_pair_}")
        prediction = self.trading_pair_.run(
            market_data_df, self.model_data_policy_.advance()
        )
        self.predictions_df_ = pd.concat(
            [self.predictions_df_, prediction.to_df()], ignore_index=True
        )

        trading_instructions: List[TradingInstructions] = (
            self._create_trading_instructions(
                prediction=prediction, last_row=market_data_df.iloc[-1]
            )
        )
        # The builder always returns a list; an empty list means "no action".
        if trading_instructions:
            await self._send_trading_instructions(trading_instructions)

    def _is_md_actual(self, hist_aggr: List[MdTradesAggregate]) -> bool:
        """Return True if ``hist_aggr`` is non-empty and its last aggregate is
        no more than 5 seconds older than the current time."""
        lag_threshold_ns = 5 * NanoPerSec
        if len(hist_aggr) == 0:
            Log.warning(f"{self.fname()} list of aggregates IS EMPTY")
            return False
        # MAYBE check market data length
        lag_ns = current_nanoseconds() - hist_aggr[-1].time_ns_
        if lag_ns > lag_threshold_ns:
            Log.warning(f"{self.fname()} {hist_aggr[-1].exch_inst_.details_short()} Lagging {int(lag_ns/NanoPerSec)} seconds")
            return False
        return True

    def _create_md_df(self, hist_aggr: List[MdTradesAggregate]) -> pd.DataFrame:
        """Convert trade aggregates into the pair's market-data frame.

        The aggregates are first flattened into a "source" frame, stably
        sorted by (time_ns, symbol), and then handed to ``LiveMarketData``,
        whose ``market_data_df_`` is returned. Example of the source frame
        (exchange_id / instrument_id columns omitted):

           tstamp               time_ns              symbol    open      high      low       close     volume     num_trades  vwap
        0  2025-09-10 11:30:00  1757503800000000000  ADA-USDT    0.8750    0.8750    0.8743    0.8743  50710.500  0             0.874489
        1  2025-09-10 11:30:00  1757503800000000000  SOL-USDT  219.9700  219.9800  219.6600  219.7000   2648.582  0           219.787847
        2  2025-09-10 11:31:00  1757503860000000000  SOL-USDT  219.7000  219.7300  219.6200  219.6200   1134.886  0           219.663460
        """
        columns = [
            "tstamp",
            "time_ns",
            "symbol",
            "exchange_id",
            "instrument_id",
            "open",
            "high",
            "low",
            "close",
            "volume",
            "num_trades",
            "vwap",
        ]
        rows: List[Dict[str, Any]] = [
            self._aggregate_to_row(aggr) for aggr in hist_aggr
        ]
        source_md_df = (
            pd.DataFrame(rows, columns=columns)
            .sort_values(
                by=["time_ns", "symbol"],
                ascending=True,
                kind="mergesort",  # stable sort keeps input order on ties
            )
            .reset_index(drop=True)
        )
        pt_mkt_data = LiveMarketData(config=self.config_, instruments=self.instruments_)
        pt_mkt_data.origin_mkt_data_df_ = source_md_df
        pt_mkt_data.set_market_data()
        return pt_mkt_data.market_data_df_

    @staticmethod
    def _aggregate_to_row(aggr: MdTradesAggregate) -> Dict[str, Any]:
        """Flatten one aggregate into a source-frame row (prices/quantities
        converted through the aggregate's exchange instrument)."""
        exch_inst = aggr.exch_inst_
        instrument_id = exch_inst.instrument_id()
        return {
            # convert nanoseconds -> tz-aware (UTC) pandas timestamp
            "tstamp": pd.to_datetime(aggr.time_ns_, unit="ns", utc=True),
            "time_ns": aggr.time_ns_,
            # everything after the first '-'; assumes ids look like
            # "<prefix>-<BASE>-<QUOTE>" -- NOTE(review): confirm format
            "symbol": instrument_id.split("-", 1)[1],
            "exchange_id": exch_inst.exchange_id_,
            "instrument_id": instrument_id,
            "open": exch_inst.get_price(aggr.open_),
            "high": exch_inst.get_price(aggr.high_),
            "low": exch_inst.get_price(aggr.low_),
            "close": exch_inst.get_price(aggr.close_),
            "volume": exch_inst.get_quantity(aggr.volume_),
            "num_trades": aggr.num_trades_,
            "vwap": exch_inst.get_price(aggr.vwap_),
        }

    def interval_sec(self) -> IntervalSecT:
        return self.interval_sec_

    def history_depth_sec(self) -> IntervalSecT:
        return self.history_depth_sec_

    async def _send_trading_instructions(
        self, trading_instructions: List[TradingInstructions]
    ) -> None:
        """Send each instruction, in order, through the trader's TI sender."""
        for ti in trading_instructions:
            Log.info(f"{self.fname()} Sending trading instructions {ti}")
            await self.pairs_trader_.ti_sender_.send_trading_instructions(ti)

    def _create_trading_instructions(
        self, prediction: Prediction, last_row: pd.Series
    ) -> List[TradingInstructions]:
        """Map a prediction to open/close instructions (possibly none).

        Returns an empty list when the absolute scaled disequilibrium lies
        strictly between the close and open thresholds and the pair's stop
        conditions do not fire.
        """
        pair = self.trading_pair_
        abs_scaled_disequilibrium = abs(prediction.scaled_disequilibrium_)

        if abs_scaled_disequilibrium >= self.open_threshold_:
            return self._create_open_trade_instructions(
                pair, row=last_row, prediction=prediction
            )
        # NOTE(review): stop conditions are only evaluated when the open
        # threshold is not reached; confirm this precedence is intended.
        if abs_scaled_disequilibrium <= self.close_threshold_ or pair.to_stop_close_conditions(predicted_row=last_row):
            return self._create_close_trade_instructions(pair, row=last_row)
        return []

    def _strength(self, scaled_disequilibrium: float) -> float:
        # TODO PtLiveStrategy._strength()
        return 1.0

    def _create_open_trade_instructions(
        self, pair: LiveTradingPair, row: pd.Series, prediction: Prediction
    ) -> List[TradingInstructions]:
        """Target opposite-signed positions on the two legs.

        A positive disequilibrium gives leg A sign -1 and leg B sign +1; a
        non-positive one gives the reverse. ``row`` is currently unused.
        """
        side_a, side_b = (-1, 1) if prediction.disequilibrium_ > 0 else (1, -1)
        strength = self._strength(prediction.scaled_disequilibrium_)
        return self._create_pair_instructions(
            pair, strength_a=side_a * strength, strength_b=side_b * strength
        )

    def _create_close_trade_instructions(
        self, pair: LiveTradingPair, row: pd.Series
    ) -> List[TradingInstructions]:
        """Target a zero position on both legs. ``row`` is currently unused."""
        return self._create_pair_instructions(pair, strength_a=0, strength_b=0)

    def _create_pair_instructions(
        self, pair: LiveTradingPair, strength_a: float, strength_b: float
    ) -> List[TradingInstructions]:
        """Build one TARGET_POSITION instruction per leg (A first, then B).

        Both legs share a single issue timestamp. Returns an empty list if
        either instruction is falsy (defensive check kept from the original).
        """
        issued_ts_ns = current_nanoseconds()
        legs: List[TradingInstructions] = []
        for instrument, strength in (
            (pair.get_instrument_a(), strength_a),
            (pair.get_instrument_b(), strength_b),
        ):
            ti = self._create_target_position_ti(instrument, strength, issued_ts_ns)
            if not ti:
                return []
            legs.append(ti)
        return legs

    def _create_target_position_ti(
        self, instrument: ExchangeInstrument, strength: float, issued_ts_ns: int
    ) -> TradingInstructions:
        """Build a single TARGET_POSITION instruction for ``instrument``."""
        return TradingInstructions(
            book=self.pairs_trader_.book_id_,
            strategy_id=self.__class__.__name__,
            ti_type=TradingInstructions.Type.TARGET_POSITION,
            issued_ts_ns=issued_ts_ns,
            data=TargetPositionSignal(
                strength=strength,
                base_asset=instrument.base_asset_id_,
                quote_asset=instrument.quote_asset_id_,
                user_data={},
            ),
        )