From 0423a7d34f6530a27b2709882cb0d581c5435852 Mon Sep 17 00:00:00 2001 From: Oleg Sheynin Date: Tue, 5 Aug 2025 21:48:23 +0000 Subject: [PATCH] progress. not ready, changing live_strategy.py to use ExchangeInstrument and new TradingInstruction class --- .vscode/launch.json | 24 ++- .vscode/pairs_trading.code-workspace | 11 +- bin/pairs_trader.py | 104 ++++++++++ bin/trade_pair.py | 69 ------- configuration/pairs_trader.cfg | 21 ++ lib/cvtt_client/mkt_data.py | 88 ++++++--- lib/pt_strategy/{ => live}/live_strategy.py | 205 +++++++++++--------- lib/pt_strategy/live/pricer_md_client.py | 90 +++++++++ lib/pt_strategy/live/ti_sender.py | 85 ++++++++ lib/pt_strategy/trading_pair.py | 36 +++- requirements.txt | 1 + 11 files changed, 538 insertions(+), 196 deletions(-) create mode 100644 bin/pairs_trader.py delete mode 100644 bin/trade_pair.py create mode 100644 configuration/pairs_trader.cfg rename lib/pt_strategy/{ => live}/live_strategy.py (65%) create mode 100644 lib/pt_strategy/live/pricer_md_client.py create mode 100644 lib/pt_strategy/live/ti_sender.py diff --git a/.vscode/launch.json b/.vscode/launch.json index b2d32bf..e8b3680 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -10,8 +10,30 @@ "name": "Python Debugger: Current File", "type": "debugpy", "request": "launch", + "python": "/home/oleg/.pyenv/python3.12-venv/bin/python", "program": "${file}", - "console": "integratedTerminal" + "console": "integratedTerminal", + "env": { + "PYTHONPATH": "${workspaceFolder}/lib:${workspaceFolder}/.." + }, + }, + { + "name": "-------- Live Pair Trading --------", + }, + { + "name": "PAIRS TRADER", + "type": "debugpy", + "request": "launch", + "python": "/home/oleg/.pyenv/python3.12-venv/bin/python", + "program": "${workspaceFolder}/bin/pairs_trader.py", + "console": "integratedTerminal", + "env": { + "PYTHONPATH": "${workspaceFolder}/lib:${workspaceFolder}/.." + }, + "args": [ + "--config=${workspaceFolder}/configuration/pairs_trader.cfg", + "--pair=PAIR-ADA-USDT:BNBSPOT,PAIR-SOL-USDT:BNBSPOT", + ], }, { "name": "-------- OLS --------", diff --git a/.vscode/pairs_trading.code-workspace b/.vscode/pairs_trading.code-workspace index 6553107..53d1337 100644 --- a/.vscode/pairs_trading.code-workspace +++ b/.vscode/pairs_trading.code-workspace @@ -1,9 +1,12 @@ { "folders": [ - { - "path": ".." - } - ], + { + "path": ".." + }, + { + "path": "../../cvttpy_base" + } + ], "settings": { "workbench.colorTheme": "Dracula Theme" } diff --git a/bin/pairs_trader.py b/bin/pairs_trader.py new file mode 100644 index 0000000..bb84a29 --- /dev/null +++ b/bin/pairs_trader.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +from functools import partial +from typing import Dict, List + +# import sys +# print("PYTHONPATH directories:") +# for path in sys.path: +# print(path) + +from cvttpy_base.tools.app import App +from cvttpy_base.tools.base import NamedObject +from cvttpy_base.tools.logger import Log +from cvttpy_base.tools.config import CvttAppConfig +from cvttpy_base.settings.cvtt_types import JsonDictT + +from pt_strategy.live.live_strategy import PtLiveStrategy +from pt_strategy.live.pricer_md_client import PtMktDataClient +from pt_strategy.live.ti_sender import TradingInstructionsSender + +# from cvtt_client.mkt_data import (CvttPricerWebSockClient, +# CvttPricesSubscription, MessageTypeT, +# SubscriptionIdT) + +class PairTradingRunner(NamedObject): + config_: CvttAppConfig + instruments_: List[JsonDictT] + + live_strategy_: PtLiveStrategy + pricer_client_: PtMktDataClient + + def __init__(self) -> None: + self.instruments_ = [] + + App.instance().add_cmdline_arg( + "--pair", + type=str, + required=True, + help=( + "Comma-separated pair of instrument symbols" + " with exchange config name" + " (e.g., PAIR-BTC-USD:BNBSPOT,PAIR-ETH-USD:BNBSPOT)" + ), + ) + + App.instance().add_call(App.Stage.Config, self._on_config()) + App.instance().add_call(App.Stage.Run, self.run()) + + async def _on_config(self) -> None: + self.config_ = CvttAppConfig.instance() + + # ------- PARSE INSTRUMENTS ------- + instr_str = App.instance().get_argument("pair", "") + if not instr_str: + raise ValueError("Pair is required") + instr_list = instr_str.split(",") + for instr in instr_list: + instr_parts = instr.split(":") + if len(instr_parts) != 2: + raise ValueError(f"Invalid pair format: {instr}") + instrument_id = instr_parts[0] + exchange_config_name = instr_parts[1] + self.instruments_.append({ + "exchange_config_name": exchange_config_name, + "instrument_id": instrument_id + }) + + assert len(self.instruments_) == 2, "Only two instruments are supported" + Log.info(f"{self.fname()} Instruments: {self.instruments_}") + + # ------- CREATE TI (trading instructions) CLIENT ------- + ti_config = self.config_.get_subconfig("ti_config", {}) + self.ti_sender_ = TradingInstructionsSender(config=ti_config) + Log.info(f"{self.fname()} TI client created: {self.ti_sender_}") + + # ------- CREATE STRATEGY ------- + strategy_config = self.config_.get_value("strategy_config", {}) + self.live_strategy_ = PtLiveStrategy( + config=strategy_config, + instruments=self.instruments_, + ti_sender=self.ti_sender_ + ) + Log.info(f"{self.fname()} Strategy created: {self.live_strategy_}") + + # ------- CREATE PRICER CLIENT ------- + pricer_config = self.config_.get_subconfig("pricer_config", {}) + self.pricer_client_ = PtMktDataClient( + live_strategy=self.live_strategy_, + pricer_config=pricer_config + ) + Log.info(f"{self.fname()} CVTT Pricer client created: {self.pricer_client_}") + + async def run(self) -> None: + Log.info(f"{self.fname()} ...") + pass + + + + +if __name__ == "__main__": + App() + CvttAppConfig() + PairTradingRunner() + App.instance().run() diff --git a/bin/trade_pair.py b/bin/trade_pair.py deleted file mode 100644 index 883e69b..0000000 --- a/bin/trade_pair.py +++ /dev/null @@ -1,69 +0,0 @@ - -from functools import partial -from typing import Dict - -from cvtt_client.mkt_data import (CvttPricerWebSockClient, - CvttPricesSubscription, MessageTypeT, - SubscriptionIdT) -from cvttpy_base.tools.app import App -from cvttpy_base.tools.base import NamedObject -from pt_strategy.live_strategy import PtLiveStrategy - - -class PairTradingRunner(NamedObject): - def __init__(self) -> None: - super().__init__() - App.instance().add_call(App.Stage.Config, self._on_config()) - App.instance().add_call(App.Stage.Run, self.run()) - - async def _on_config(self) -> None: - pass - - async def run(self) -> None: - pass - -# async def main() -> None: -# live_strategy = PtLiveStrategy( -# config={}, -# instruments=[ -# {"exchange_config_name": "COINBASE_AT", "instrument_id": "PAIR-BTC-USD"}, -# {"exchange_config_name": "COINBASE_AT", "instrument_id": "PAIR-ETH-USD"}, -# ] -# ) -# async def on_message(message_type: MessageTypeT, subscr_id: SubscriptionIdT, message: Dict, instrument_id: str) -> None: -# print(f"{message_type=} {subscr_id=} {instrument_id}") -# if message_type == "md_aggregate": -# aggr = message.get("md_aggregate", []) -# print(f"[{aggr['tstamp'][:19]}] *** RLTM *** {message}") -# elif message_type == "historical_md_aggregate": -# for aggr in message.get("historical_data", []): -# print(f"[{aggr['tstamp'][:19]}] *** HIST *** {aggr}") -# else: -# print(f"Unknown message type: {message_type}") - -# pricer_client = CvttPricerWebSockClient( -# "ws://localhost:12346/ws" -# ) -# await pricer_client.subscribe(CvttPricesSubscription( -# exchange_config_name="COINBASE_AT", -# instrument_id="PAIR-BTC-USD", -# interval_sec=60, -# history_depth_sec=60*60*24, -# callback=partial(on_message, instrument_id="PAIR-BTC-USD") -# )) -# await pricer_client.subscribe(CvttPricesSubscription( -# exchange_config_name="COINBASE_AT", -# instrument_id="PAIR-ETH-USD", -# interval_sec=60, -# history_depth_sec=60*60*24, -# callback=partial(on_message, instrument_id="PAIR-ETH-USD") -# )) - -# await pricer_client.run() - - - -if __name__ == "__main__": - App() - - App.instance().run() diff --git a/configuration/pairs_trader.cfg b/configuration/pairs_trader.cfg new file mode 100644 index 0000000..d23e312 --- /dev/null +++ b/configuration/pairs_trader.cfg @@ -0,0 +1,21 @@ +{ + "strategy_config": @inc=file:///home/oleg/develop/pairs_trading/configuration/ols.cfg + "pricer_config": { + "pricer_url": "ws://localhost:12346/ws", + "history_depth_sec": 86400 #"60*60*24", # use simpleeval + "interval_sec": 60 + }, + "ti_config": { + "cvtt_base_url": "http://localhost:23456" + "book_id": "XXXXXXXXX", + "strategy_id": "XXXXXXXXX", + "ti_endpoint": { + "method": "POST", + "url": "/trading_instructions" + }, + "health_check_endpoint": { + "method": "GET", + "url": "/ping" + } + } +} \ No newline at end of file diff --git a/lib/cvtt_client/mkt_data.py b/lib/cvtt_client/mkt_data.py index 8dade62..53284f2 100644 --- a/lib/cvtt_client/mkt_data.py +++ b/lib/cvtt_client/mkt_data.py @@ -10,10 +10,12 @@ import uuid from dataclasses import dataclass from typing import Callable, Coroutine, Dict, List, Optional -from numpy.strings import str_len import websockets from websockets.asyncio.client import ClientConnection +from cvttpy_base.settings.cvtt_types import JsonDictT +from cvttpy_base.tools.logger import Log + MessageTypeT = str SubscriptionIdT = str MessageT = Dict @@ -48,22 +50,56 @@ class CvttPricesSubscription: self.is_subscribed_ = False self.is_historical_ = history_depth_sec > 0 - -class CvttPricerWebSockClient: - # Class members with type hints +class CvttWebSockClient: ws_url_: UrlT websocket_: Optional[ClientConnection] - subscriptions_: Dict[SubscriptionIdT, CvttPricesSubscription] is_connected_: bool - logger_: logging.Logger - + def __init__(self, url: str): self.ws_url_ = url self.websocket_ = None self.is_connected_ = False + + async def connect(self) -> None: + self.websocket_ = await websockets.connect(self.ws_url_) + self.is_connected_ = True + + async def close(self) -> None: + if self.websocket_ is not None: + await self.websocket_.close() + self.is_connected_ = False + + async def receive_message(self) -> JsonDictT: + assert self.websocket_ is not None + assert self.is_connected_ + message = await self.websocket_.recv() + message_str = ( + message.decode("utf-8") + if isinstance(message, bytes) + else message + ) + res = json.loads(message_str) + assert res is not None + assert isinstance(res, dict) + return res + + @classmethod + async def check_connection(cls, url: str) -> bool: + try: + async with websockets.connect(url) as websocket: + result = True + except Exception as e: + Log.error(f"Unable to connect to {url}: {str(e)}") + result = False + return result + +class CvttPricerWebSockClient(CvttWebSockClient): + # Class members with type hints + subscriptions_: Dict[SubscriptionIdT, CvttPricesSubscription] + + def __init__(self, url: str): + super().__init__(url) self.subscriptions_ = {} - self.logger_ = logging.getLogger(__name__) - logging.basicConfig(level=logging.INFO) async def subscribe( self, subscription: CvttPricesSubscription @@ -71,11 +107,10 @@ class CvttPricerWebSockClient: if not self.is_connected_: try: - self.logger_.info(f"Connecting to {self.ws_url_}") - self.websocket_ = await websockets.connect(self.ws_url_) - self.is_connected_ = True + Log.info(f"Connecting to {self.ws_url_}") + await self.connect() except Exception as e: - self.logger_.error(f"Unable to connect to {self.ws_url_}: {str(e)}") + Log.error(f"Unable to connect to {self.ws_url_}: {str(e)}") raise e subscr_msg = { @@ -109,10 +144,10 @@ class CvttPricerWebSockClient: return False if response.get("status") == "success": - self.logger_.info(f"Subscription successful: {json.dumps(response)}") + Log.info(f"Subscription successful: {json.dumps(response)}") return True elif response.get("status") == "error": - self.logger_.error(f"Subscription failed: {response.get('reason')}") + Log.error(f"Subscription failed: {response.get('reason')}") return False return False @@ -121,19 +156,20 @@ class CvttPricerWebSockClient: try: while self.is_connected_: try: - message = await self.websocket_.recv() - message_str = ( - message.decode("utf-8") - if isinstance(message, bytes) - else message - ) - await self.process_message(json.loads(message_str)) + msg_dict: JsonDictT = await self.receive_message() except websockets.ConnectionClosed: - self.logger_.warning("Connection closed") + Log.warning("Connection closed") self.is_connected_ = False break + except Exception as e: + Log.error(f"Error occurred: {str(e)}") + self.is_connected_ = False + await asyncio.sleep(5) # Wait before reconnecting + + await self.process_message(msg_dict) + except Exception as e: - self.logger_.error(f"Error occurred: {str(e)}") + Log.error(f"Error occurred: {str(e)}") self.is_connected_ = False await asyncio.sleep(5) # Wait before reconnecting @@ -142,13 +178,13 @@ class CvttPricerWebSockClient: if message_type in ["md_aggregate", "historical_md_aggregate"]: subscription_id = message.get("subscr_id") if subscription_id not in self.subscriptions_: - self.logger_.warning(f"Unknown subscription id: {subscription_id}") + Log.warning(f"Unknown subscription id: {subscription_id}") return subscription = self.subscriptions_[subscription_id] await subscription.callback_(message_type, subscription_id, message) else: - self.logger_.warning(f"Unknown message type: {message.get('type')}") + Log.warning(f"Unknown message type: {message.get('type')}") async def main() -> None: diff --git a/lib/pt_strategy/live_strategy.py b/lib/pt_strategy/live/live_strategy.py similarity index 65% rename from lib/pt_strategy/live_strategy.py rename to lib/pt_strategy/live/live_strategy.py index 67d1960..a20ea47 100644 --- a/lib/pt_strategy/live_strategy.py +++ b/lib/pt_strategy/live/live_strategy.py @@ -1,5 +1,6 @@ from __future__ import annotations +from dataclasses import dataclass from functools import partial from typing import Any, Dict, List, Optional @@ -8,61 +9,26 @@ from cvttpy_base.settings.cvtt_types import JsonDictT from cvttpy_base.tools.base import NamedObject from cvttpy_base.tools.logger import Log -from cvtt_client.mkt_data import CvttPricerWebSockClient, CvttPricesSubscription, MessageTypeT, SubscriptionIdT +from pt_strategy.live.ti_sender import TradingInstructionsSender from pt_strategy.model_data_policy import ModelDataPolicy -from pt_strategy.pt_market_data import PtMarketData, RealTimeMarketData +from pt_strategy.pt_market_data import RealTimeMarketData from pt_strategy.pt_model import Prediction from pt_strategy.trading_pair import PairState, TradingPair -''' +""" --config=pair.cfg --pair=PAIR-BTC-USDT:COINBASE_AT,PAIR-ETH-USDT:COINBASE_AT -''' +""" -class PtMktDataClient(NamedObject): - live_strategy_: PtLiveStrategy - pricer_client_: CvttPricerWebSockClient - subscriptions_: List[CvttPricesSubscription] - - def __init__(self, live_strategy: PtLiveStrategy): - self.live_strategy_ = live_strategy - - async def start(self, subscription: CvttPricesSubscription) -> None: - pricer_url = self.live_strategy_.config_.get("pricer_url", None) #, "ws://localhost:12346/ws") - assert pricer_url is not None, "pricer_url is not found in config" - self.pricer_client_ = CvttPricerWebSockClient(url=pricer_url) - - await self._subscribe() - - async def _subscribe(self) -> None: - pair: TradingPair = self.live_strategy_.trading_pair_ - for instrument in pair.instruments_: - await self.pricer_client_.subscribe(CvttPricesSubscription( - exchange_config_name=instrument["exchange_config_name"], - instrument_id=instrument["instrument_id"], - interval_sec=60, - history_depth_sec=60*60*24, - callback=partial(self.on_message, instrument_id=instrument["instrument_id"]) - )) - - async def on_message(self, message_type: MessageTypeT, subscr_id: SubscriptionIdT, message: Dict, instrument_id: str) -> None: - Log.info(f"{self.fname()}: {message_type=} {subscr_id=} {instrument_id}") - aggr: JsonDictT - if message_type == "md_aggregate": - aggr = message.get("md_aggregate", {}) - await self.live_strategy_.on_mkt_data_update(aggr) - # print(f"[{aggr['tstamp'][:19]}] *** RLTM *** {message}") - elif message_type == "historical_md_aggregate": - aggr = message.get("historical_data", {}) - await self.live_strategy_.on_mkt_data_hist_snapshot(aggr) - # print(f"[{aggr['tstamp'][:19]}] *** HIST *** {aggr}") - else: - Log.info(f"Unknown message type: {message_type}") - - async def run(self) -> None: - await self.pricer_client_.run() +class TradingInstructionType(Enum): + TARGET_POSITION = "TARGET_POSITION" +@dataclass +class TradingInstruction(NamedObject): + type_: TradingInstructionType + exch_instr_: ExchangeInstrument + specifics_: Dict[str, Any] class PtLiveStrategy(NamedObject): @@ -70,22 +36,24 @@ class PtLiveStrategy(NamedObject): trading_pair_: TradingPair model_data_policy_: ModelDataPolicy pt_mkt_data_: RealTimeMarketData - pt_mkt_data_client_: PtMktDataClient - - # for presentation: history of prediction values and trading signals + ti_sender_: TradingInstructionsSender + + # for presentation: history of prediction values and trading signals predictions_: pd.DataFrame - trading_signals_: pd.DataFrame + trading_signals_: pd.DataFrame def __init__( self, config: Dict[str, Any], instruments: List[Dict[str, str]], + ti_sender: TradingInstructionsSender, ): self.config_ = config self.trading_pair_ = TradingPair(config=config, instruments=instruments) self.predictions_ = pd.DataFrame() self.trading_signals_ = pd.DataFrame() + self.ti_sender_ = ti_sender import copy @@ -94,9 +62,16 @@ class PtLiveStrategy(NamedObject): config_copy["instruments"] = instruments self.pt_mkt_data_ = RealTimeMarketData(config=config_copy) self.model_data_policy_ = ModelDataPolicy.create( - config, is_real_time=True,pair=self.trading_pair_ + config, is_real_time=True, pair=self.trading_pair_ ) - + self.open_threshold_ = self.config_.get("dis-equilibrium_open_trshld", 0.0) + assert self.open_threshold_ > 0, "open_threshold must be greater than 0" + self.close_threshold_ = self.config_.get("dis-equilibrium_close_trshld", 0.0) + assert self.close_threshold_ > 0, "close_threshold must be greater than 0" + + def __repr__(self) -> str: + return f"{self.classname()}: trading_pair={self.trading_pair_}, mdp={self.model_data_policy_.__class__.__name__}, " + async def on_mkt_data_hist_snapshot(self, aggr: JsonDictT) -> None: Log.info(f"on_mkt_data_hist_snapshot: {aggr}") await self.pt_mkt_data_.on_mkt_data_hist_snapshot(snapshot=aggr) @@ -107,63 +82,106 @@ class PtLiveStrategy(NamedObject): if market_data_df is not None: self.trading_pair_.market_data_ = market_data_df self.model_data_policy_.advance() - prediction = self.trading_pair_.run(market_data_df, self.model_data_policy_.advance()) - self.predictions_ = pd.concat([self.predictions_, prediction.to_df()], ignore_index=True) - trades = self._create_trades(prediction=prediction, last_row=market_data_df.iloc[-1]) + prediction = self.trading_pair_.run( + market_data_df, self.model_data_policy_.advance() + ) + self.predictions_ = pd.concat( + [self.predictions_, prediction.to_df()], ignore_index=True + ) + + trading_instructions: List[TradingInstruction] = ( + self._create_trading_instructions( + prediction=prediction, last_row=market_data_df.iloc[-1] + ) + ) + if len(trading_instructions) > 0: + await self._send_trading_instructions(trading_instructions) + # trades = self._create_trades(prediction=prediction, last_row=market_data_df.iloc[-1]) # URGENT implement this pass - - async def run(self) -> None: - await self.pt_mkt_data_client_.run() + async def _send_trading_instructions( + self, trading_instructions: pd.DataFrame + ) -> None: + pass - def _create_trades( + def _create_trading_instructions( self, prediction: Prediction, last_row: pd.Series - ) -> Optional[pd.DataFrame]: + ) -> List[TradingInstruction]: pair = self.trading_pair_ - trades = None + trd_instructions: List[TradingInstruction] = [] - open_threshold = self.config_["dis-equilibrium_open_trshld"] - close_threshold = self.config_["dis-equilibrium_close_trshld"] scaled_disequilibrium = prediction.scaled_disequilibrium_ abs_scaled_disequilibrium = abs(scaled_disequilibrium) - if pair.user_data_["state"] in [ - PairState.INITIAL, - PairState.CLOSE, - PairState.CLOSE_POSITION, - PairState.CLOSE_STOP_LOSS, - PairState.CLOSE_STOP_PROFIT, - ]: - if abs_scaled_disequilibrium >= open_threshold: - trades = self._create_open_trades( + if pair.is_closed(): + if abs_scaled_disequilibrium >= self.open_threshold_: + trd_instructions = self._create_open_trade_instructions( pair, row=last_row, prediction=prediction ) - if trades is not None: - trades["status"] = PairState.OPEN.name - print(f"OPEN TRADES:\n{trades}") - pair.user_data_["state"] = PairState.OPEN - pair.on_open_trades(trades) - - elif pair.user_data_["state"] == PairState.OPEN: - if abs_scaled_disequilibrium <= close_threshold: - trades = self._create_close_trades( + elif pair.is_open(): + if abs_scaled_disequilibrium <= self.close_threshold_: + trd_instructions = self._create_close_trade_instructions( pair, row=last_row, prediction=prediction ) - if trades is not None: - trades["status"] = PairState.CLOSE.name - print(f"CLOSE TRADES:\n{trades}") - pair.user_data_["state"] = PairState.CLOSE - pair.on_close_trades(trades) elif pair.to_stop_close_conditions(predicted_row=last_row): - trades = self._create_close_trades(pair, row=last_row) - if trades is not None: - trades["status"] = pair.user_data_["stop_close_state"].name - print(f"STOP CLOSE TRADES:\n{trades}") - pair.user_data_["state"] = pair.user_data_["stop_close_state"] - pair.on_close_trades(trades) + trd_instructions = self._create_close_trade_instructions( + pair, row=last_row + ) - return trades + return trd_instructions + + def _create_open_trade_instructions( + self, pair: TradingPair, row: pd.Series, prediction: Prediction + ) -> List[TradingInstruction]: + scaled_disequilibrium = prediction.scaled_disequilibrium_ + + if scaled_disequilibrium > 0: + side_a = "SELL" + trd_inst_a = TradingInstruction( + type=TradingInstructionType.TARGET_POSITION, + exch_instr=pair.get_instrument_a(), + specifics={"side": "SELL", "strength": -1}, + ) + side_b = "BUY" + else: + side_a = "BUY" + side_b = "SELL" + + # save closing sides + pair.user_data_["open_side_a"] = side_a # used in oustanding positions + pair.user_data_["open_side_b"] = side_b + pair.user_data_["open_px_a"] = px_a + pair.user_data_["open_px_b"] = px_b + pair.user_data_["open_tstamp"] = tstamp + + pair.user_data_["close_side_a"] = side_b # used for closing trades + pair.user_data_["close_side_b"] = side_a + + # create opening trades + df.loc[len(df)] = { + "time": tstamp, + "symbol": pair.symbol_a_, + "side": side_a, + "action": "OPEN", + "price": px_a, + "disequilibrium": diseqlbrm, + "signed_scaled_disequilibrium": scaled_disequilibrium, + "scaled_disequilibrium": abs(scaled_disequilibrium), + # "pair": pair, + } + df.loc[len(df)] = { + "time": tstamp, + "symbol": pair.symbol_b_, + "side": side_b, + "action": "OPEN", + "price": px_b, + "disequilibrium": diseqlbrm, + "scaled_disequilibrium": abs(scaled_disequilibrium), + "signed_scaled_disequilibrium": scaled_disequilibrium, + # "pair": pair, + } + return df def _handle_outstanding_positions(self) -> Optional[pd.DataFrame]: trades = None @@ -328,4 +346,3 @@ class PtLiveStrategy(NamedObject): del pair.user_data_["open_side_a"] del pair.user_data_["open_side_b"] return df - diff --git a/lib/pt_strategy/live/pricer_md_client.py b/lib/pt_strategy/live/pricer_md_client.py new file mode 100644 index 0000000..3a853a9 --- /dev/null +++ b/lib/pt_strategy/live/pricer_md_client.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +from functools import partial +from typing import Any, Dict, List, Optional + +import pandas as pd +from cvttpy_base.tools.app import App +from cvttpy_base.tools.config import Config +from cvttpy_base.settings.cvtt_types import JsonDictT +from cvttpy_base.tools.base import NamedObject +from cvttpy_base.tools.logger import Log + +from cvtt_client.mkt_data import ( + CvttPricerWebSockClient, + CvttPricesSubscription, + MessageTypeT, + SubscriptionIdT, +) +from pt_strategy.live.live_strategy import PtLiveStrategy +from pt_strategy.trading_pair import TradingPair + +""" + --config=pair.cfg + --pair=PAIR-BTC-USDT:COINBASE_AT,PAIR-ETH-USDT:COINBASE_AT +""" + + +class PtMktDataClient(NamedObject): + config_: Config + live_strategy_: PtLiveStrategy + pricer_client_: CvttPricerWebSockClient + subscriptions_: List[CvttPricesSubscription] + + def __init__(self, live_strategy: PtLiveStrategy, pricer_config: Config): + self.config_ = pricer_config + self.live_strategy_ = live_strategy + + App.instance().add_call(App.Stage.Start, self._on_start()) + App.instance().add_call(App.Stage.Run, self.run()) + + async def _on_start(self) -> None: + pricer_url = self.config_.get_value("pricer_url") + assert pricer_url is not None, "pricer_url is not found in config" + self.pricer_client_ = CvttPricerWebSockClient(url=pricer_url) + + + async def _subscribe(self) -> None: + history_depth_sec = self.config_.get_value("history_depth_sec", 86400) + interval_sec = self.config_.get_value("interval_sec", 60) + + pair: TradingPair = self.live_strategy_.trading_pair_ + subscriptions = [CvttPricesSubscription( + exchange_config_name=instrument["exchange_config_name"], + instrument_id=instrument["instrument_id"], + interval_sec=interval_sec, + history_depth_sec=history_depth_sec, + callback=partial( + self.on_message, instrument_id=instrument["instrument_id"] + ), + ) for instrument in pair.instruments_] + + for subscription in subscriptions: + Log.info(f"{self.fname()} Subscribing to {subscription}") + await self.pricer_client_.subscribe(subscription) + + async def on_message( + self, + message_type: MessageTypeT, + subscr_id: SubscriptionIdT, + message: Dict, + instrument_id: str, + ) -> None: + Log.info(f"{self.fname()}: {message_type=} {subscr_id=} {instrument_id}") + aggr: JsonDictT + if message_type == "md_aggregate": + aggr = message.get("md_aggregate", {}) + await self.live_strategy_.on_mkt_data_update(aggr) + elif message_type == "historical_md_aggregate": + aggr = message.get("historical_data", {}) + await self.live_strategy_.on_mkt_data_hist_snapshot(aggr) + else: + Log.info(f"Unknown message type: {message_type}") + + async def run(self) -> None: + if not await CvttPricerWebSockClient.check_connection(self.pricer_client_.ws_url_): + Log.error(f"Unable to connect to {self.pricer_client_.ws_url_}") + raise Exception(f"Unable to connect to {self.pricer_client_.ws_url_}") + await self._subscribe() + + await self.pricer_client_.run() diff --git a/lib/pt_strategy/live/ti_sender.py b/lib/pt_strategy/live/ti_sender.py new file mode 100644 index 0000000..2435c21 --- /dev/null +++ b/lib/pt_strategy/live/ti_sender.py @@ -0,0 +1,85 @@ +from enum import Enum +from typing import Dict, Any, Tuple +import time +# import aiohttp +from cvttpy_base.tools.app import App +from cvttpy_base.tools.base import NamedObject +from cvttpy_base.tools.config import Config +from cvttpy_base.tools.logger import Log +from cvttpy_base.tools.web.rest_client import REST_RequestProcessor +from cvttpy_base.tools.timeutils import NanoPerSec +from cvttpy_base.tools.timer import Timer + + +class TradingInstructionsSender(NamedObject): + + class TradingInstType(str, Enum): + TARGET_POSITION = "TARGET_POSITION" + DIRECT_ORDER = "DIRECT_ORDER" + MARKET_MAKING = "MARKET_MAKING" + NONE = "NONE" + + config_: Config + ti_method_: str + ti_url_: str + health_check_method_: str + health_check_url_: str + + def __init__(self, config: Config): + self.config_ = config + base_url = config.get_value("url", "ws://localhost:12346/ws") + + self.book_id_ = config.get_value("book_id", "") + assert self.book_id_, "book_id is required" + + self.strategy_id_ = config.get_value("strategy_id", "") + assert self.strategy_id_, "strategy_id is required" + + endpoint_uri = config.get_value("ti_endpoint/url", "/trading_instructions") + endpoint_method = config.get_value("ti_endpoint/method", "POST") + + health_check_uri = config.get_value("health_check_endpoint/url", "/ping") + health_check_method = config.get_value("health_check_endpoint/method", "GET") + + + + self.ti_method_ = endpoint_method + self.ti_url_ = f"{base_url}{endpoint_uri}" + + self.health_check_method_ = health_check_method + self.health_check_url_ = f"{base_url}{health_check_uri}" + + App.instance().add_call(App.Stage.Start, self._set_health_check_timer(), can_run_now=True) + + async def _set_health_check_timer(self) -> None: + # TODO: configurable interval + self.health_check_timer_ = Timer(is_periodic=True, period_interval=15, start_in_sec=0, func=self._health_check) + Log.info(f"{self.fname()} Health check timer set to 15 seconds") + + async def _health_check(self) -> None: + rqst = REST_RequestProcessor(method=self.health_check_method_, url=self.health_check_url_) + async with rqst as (status, msg, headers): + if status != 200: + Log.error(f"{self.fname()} CVTT Service is not responding") + + async def send_tgt_positions(self, strength: float, base_asset: str, quote_asset: str) -> Tuple[int, str]: + instr = { + "type": self.TradingInstType.TARGET_POSITION.value, + "book_id": self.book_id_, + "strategy_id": self.strategy_id_, + "issued_ts_ns": int(time.time() * NanoPerSec), + "data": { + "strength": strength, + "base_asset": base_asset, + "quote_asset": quote_asset, + "user_data": {}, + }, + } + + rqst = REST_RequestProcessor(method=self.ti_method_, url=self.ti_url_, params=instr) + async with rqst as (status, msg, headers): + if status != 200: + raise ConnectionError(f"Failed to send trading instructions: {msg}") + return (status, msg) + + diff --git a/lib/pt_strategy/trading_pair.py b/lib/pt_strategy/trading_pair.py index 9b96adb..db24b9e 100644 --- a/lib/pt_strategy/trading_pair.py +++ b/lib/pt_strategy/trading_pair.py @@ -18,6 +18,20 @@ class PairState(Enum): CLOSE_STOP_LOSS = 5 CLOSE_STOP_PROFIT = 6 + +def get_symbol(instrument: Dict[str, str]) -> str: + if "symbol" in instrument: + return instrument["symbol"] + elif "instrument_id" in instrument: + instrument_id = instrument["instrument_id"] + instrument_pfx = instrument_id[:instrument_id.find("-") + 1] + symbol = instrument_id[len(instrument_pfx):] + instrument["symbol"] = symbol + instrument["instrument_id_pfx"] = instrument_pfx + return symbol + else: + raise ValueError(f"Invalid instrument: {instrument}, missing symbol or instrument_id") + class TradingPair: config_: Dict[str, Any] market_data_: pd.DataFrame @@ -42,14 +56,32 @@ class TradingPair: self.config_ = config self.instruments_ = instruments - self.symbol_a_ = instruments[0]["symbol"] - self.symbol_b_ = instruments[1]["symbol"] + self.symbol_a_ = get_symbol(instruments[0]) + self.symbol_b_ = get_symbol(instruments[1]) self.model_ = PairsTradingModel.create(config) self.stat_model_price_ = config["stat_model_price"] self.user_data_ = { "state": PairState.INITIAL, } + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}:" + f" symbol_a={self.symbol_a_}," + f" symbol_b={self.symbol_b_}," + f" model={self.model_.__class__.__name__}" + ) + + def is_closed(self) -> bool: + return self.user_data_["state"] in [ + PairState.CLOSE, + PairState.CLOSE_POSITION, + PairState.CLOSE_STOP_LOSS, + PairState.CLOSE_STOP_PROFIT, + ] + def is_open(self) -> bool: + return self.user_data_["state"] == PairState.OPEN + def colnames(self) -> List[str]: return [ f"{self.stat_model_price_}_{self.symbol_a_}", diff --git a/requirements.txt b/requirements.txt index 57f7220..99f9cc2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -78,6 +78,7 @@ scipy<1.13.0 seaborn>=0.13.2 SecretStorage>=3.3.1 setproctitle>=1.2.2 +simpleeval>=1.0.3 six>=1.16.0 soupsieve>=2.3.1 ssh-import-id>=5.11