# pairs_trading/lib/pt_strategy/live/pricer_md_client.py.md
# Snapshot: 2025-12-22 23:58:41 +00:00 (3.3 KiB)

from __future__ import annotations

from functools import partial
from typing import Dict, List

# from cvtt_client.mkt_data import (CvttPricerWebSockClient,
#                                   CvttPricesSubscription, MessageTypeT,
#                                   SubscriptionIdT)
from cvttpy_tools.app import App
from cvttpy_tools.base import NamedObject
from cvttpy_tools.config import Config
from cvttpy_tools.logger import Log
from cvttpy_tools.settings.cvtt_types import JsonDictT
from pairs_trading.lib.pt_strategy.live.live_strategy import PtLiveStrategy
from pairs_trading.lib.pt_strategy.trading_pair import TradingPair

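"""
Example command-line arguments for running this client (the entry point that
parses them is not in this file):
  --config=pair.cfg
  --pair=PAIR-BTC-USDT:COINBASE_AT,PAIR-ETH-USDT:COINBASE_AT
"""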


class PtMktDataClient(NamedObject):
    """Feed pricer market data into a live pairs-trading strategy.

    At the App ``Start`` stage this creates a ``CvttPricerWebSockClient`` for
    the ``pricer_url`` config value. At the ``Run`` stage it checks the
    connection, subscribes to price aggregates for every instrument of the
    strategy's trading pair, and runs the client. Incoming messages are routed
    to ``PtLiveStrategy.on_mkt_data_update`` (live aggregates) or
    ``PtLiveStrategy.on_mkt_data_hist_snapshot`` (historical snapshots).

    NOTE(review): the import of ``CvttPricerWebSockClient``,
    ``CvttPricesSubscription``, ``MessageTypeT`` and ``SubscriptionIdT`` is
    commented out at the top of this file. Annotations are lazy
    (``from __future__ import annotations``), but ``_on_start``,
    ``_subscribe`` and ``run`` use these names at runtime and will raise
    ``NameError`` until that import is restored or replaced. Confirm where
    these classes live now.
    """

    # Pricer config section; read keys: pricer_url, history_depth_sec, interval_sec.
    config_: Config
    # Strategy that receives every market-data message.
    live_strategy_: PtLiveStrategy
    # Websocket client; created in _on_start, used by _subscribe and run.
    pricer_client_: CvttPricerWebSockClient
    # Subscriptions successfully sent to the pricer client by _subscribe.
    subscriptions_: List[CvttPricesSubscription]

    def __init__(self, live_strategy: PtLiveStrategy, pricer_config: Config) -> None:
        """Store dependencies and register the start/run coroutines with the App.

        Args:
            live_strategy: strategy whose trading pair is subscribed and which
                receives the market-data callbacks.
            pricer_config: config section holding the pricer settings.
        """
        self.config_ = pricer_config
        self.live_strategy_ = live_strategy
        self.subscriptions_ = []

        # The methods are *called* here, so coroutine objects (not functions)
        # are registered. Presumably each stage awaits them once; confirm
        # against App.add_call.
        App.instance().add_call(App.Stage.Start, self._on_start())
        App.instance().add_call(App.Stage.Run, self.run())

    async def _on_start(self) -> None:
        """Create the pricer websocket client from the ``pricer_url`` config value.

        Raises:
            ValueError: if ``pricer_url`` is missing from the config.
        """
        pricer_url = self.config_.get_value("pricer_url")
        # Use an explicit check instead of ``assert``: asserts are stripped
        # under ``python -O``, which would let a missing URL reach the client.
        if pricer_url is None:
            raise ValueError("pricer_url is not found in config")
        self.pricer_client_ = CvttPricerWebSockClient(url=pricer_url)

    async def _subscribe(self) -> None:
        """Subscribe to price aggregates for every instrument of the trading pair.

        Reads ``history_depth_sec`` and ``interval_sec`` from the config. The
        second argument to ``get_value`` (86400 and 60) is presumably the
        fallback used when a key is absent. Each subscription is appended to
        ``subscriptions_`` once the client has accepted it.
        """
        history_depth_sec = self.config_.get_value("history_depth_sec", 86400)
        interval_sec = self.config_.get_value("interval_sec", 60)

        pair: TradingPair = self.live_strategy_.trading_pair_
        subscriptions = [
            CvttPricesSubscription(
                exchange_config_name=instrument["exchange_config_name"],
                instrument_id=instrument["instrument_id"],
                interval_sec=interval_sec,
                history_depth_sec=history_depth_sec,
                # partial binds instrument_id eagerly, so every callback
                # reports its own instrument (no late-binding closure issue).
                callback=partial(
                    self.on_message, instrument_id=instrument["instrument_id"]
                ),
            )
            for instrument in pair.instruments_
        ]

        for subscription in subscriptions:
            Log.info(f"{self.fname()} Subscribing to {subscription}")
            await self.pricer_client_.subscribe(subscription)
            # Record only what was actually sent, so a failure partway through
            # leaves subscriptions_ accurate.
            self.subscriptions_.append(subscription)

    async def on_message(
        self,
        message_type: MessageTypeT,
        subscr_id: SubscriptionIdT,
        message: Dict,
        instrument_id: str,
    ) -> None:
        """Route one pricer message to the live strategy.

        Args:
            message_type: ``"md_aggregate"`` for a live aggregate (payload under
                ``message["md_aggregate"]``) or ``"historical_md_aggregate"``
                for a history snapshot (payload under
                ``message["historical_data"]``). Any other type is logged and
                ignored.
            subscr_id: id of the subscription that produced the message
                (only logged here).
            message: raw message dict delivered by the pricer client.
            instrument_id: bound per subscription via ``functools.partial``
                in ``_subscribe``.
        """
        Log.info(f"{self.fname()}: {message_type=} {subscr_id=} {instrument_id}")
        aggr: JsonDictT
        if message_type == "md_aggregate":
            # A missing payload key is forwarded as an empty dict, not raised.
            aggr = message.get("md_aggregate", {})
            await self.live_strategy_.on_mkt_data_update(aggr)
        elif message_type == "historical_md_aggregate":
            aggr = message.get("historical_data", {})
            await self.live_strategy_.on_mkt_data_hist_snapshot(aggr)
        else:
            Log.info(f"Unknown message type: {message_type}")

    async def run(self) -> None:
        """Verify connectivity, subscribe, then run the pricer client.

        Relies on ``_on_start`` having already created ``pricer_client_``.

        Raises:
            ConnectionError: if the pricer websocket endpoint is unreachable.
        """
        ws_url = self.pricer_client_.ws_url_
        if not await CvttPricerWebSockClient.check_connection(ws_url):
            Log.error(f"Unable to connect to {ws_url}")
            # ConnectionError is a subclass of Exception, so existing
            # ``except Exception`` handlers still catch it.
            raise ConnectionError(f"Unable to connect to {ws_url}")
        await self._subscribe()

        await self.pricer_client_.run()