diff --git a/src/cvtt/mkt_data.py b/src/cvtt/mkt_data.py index 9d2c653..0c8aff0 100644 --- a/src/cvtt/mkt_data.py +++ b/src/cvtt/mkt_data.py @@ -1,7 +1,9 @@ #!/usr/bin/env python3 import argparse +from ast import Sub import asyncio +from functools import partial import json import logging import uuid @@ -12,9 +14,11 @@ from numpy.strings import str_len import websockets from websockets.asyncio.client import ClientConnection +MessageTypeT = str SubscriptionIdT = str +MessageT = Dict UrlT = str -CallbackT = Callable[[Dict], Coroutine[None, str, None]] +CallbackT = Callable[[MessageTypeT, SubscriptionIdT, MessageT], Coroutine[None, str, None]] @dataclass class CvttPricesSubscription: @@ -66,11 +70,13 @@ class CvttPricerWebSockClient: ) -> str: # returns subscription id if not self.is_connected_: - self.logger_.info(f"Connecting to {self.ws_url_}") - self.websocket_ = await websockets.connect(self.ws_url_) - self.is_connected_ = True - else: - raise Exception(f"Unable to connect to {self.ws_url_}") + try: + self.logger_.info(f"Connecting to {self.ws_url_}") + self.websocket_ = await websockets.connect(self.ws_url_) + self.is_connected_ = True + except Exception as e: + self.logger_.error(f"Unable to connect to {self.ws_url_}: {str(e)}") + raise e subscr_msg = { "type": "subscr", @@ -83,6 +89,7 @@ class CvttPricerWebSockClient: if subscription.is_historical_: subscr_msg["history_depth_sec"] = subscription.history_depth_sec_ + assert self.websocket_ is not None await self.websocket_.send(json.dumps(subscr_msg)) response = await self.websocket_.recv() @@ -109,7 +116,7 @@ class CvttPricerWebSockClient: return False return False - async def connect_and_subscribe(self) -> None: + async def run(self) -> None: assert self.websocket_ try: while self.is_connected_: @@ -131,51 +138,50 @@ class CvttPricerWebSockClient: await asyncio.sleep(5) # Wait before reconnecting async def process_message(self, message: Dict) -> None: - if message.get("type") in ["md_aggregate", "historical_md_aggregate"]: - subscription_id = message.get("id") + message_type = message.get("type") + if message_type in ["md_aggregate", "historical_md_aggregate"]: + subscription_id = message.get("subscr_id") if subscription_id not in self.subscriptions_: self.logger_.warning(f"Unknown subscription id: {subscription_id}") return subscription = self.subscriptions_[subscription_id] - await subscription.callback_(message) + await subscription.callback_(message_type, subscription_id, message) else: self.logger_.warning(f"Unknown message type: {message.get('type')}") async def main() -> None: - pass - # parser = argparse.ArgumentParser(description="WebSocket API Testing Tool") - # parser.add_argument("--url", required=True, help="WebSocket API URL") - # parser.add_argument( - # "--exchange_config_name", required=True, help="Exchange config name" - # ) - # parser.add_argument( - # "--instrument_ids", required=True, help="Comma separated Instrument IDs" - # ) - # parser.add_argument( - # "--interval_sec", type=int, required=True, help="Interval in seconds" - # ) - # parser.add_argument( - # "--history_depth_sec", - # default=0, - # type=int, - # required=False, - # help="History depth in seconds", - # ) + async def on_message(message_type: MessageTypeT, subscr_id: SubscriptionIdT, message: Dict, instrument_id: str) -> None: + print(f"{message_type=} {subscr_id=} {instrument_id}") + if message_type == "md_aggregate": + aggr = message.get("md_aggregate", []) + print(f"[{aggr['tstmp'][:19]}] *** RLTM *** {message}") + elif message_type == "historical_md_aggregate": + for aggr in message.get("historical_data", []): + print(f"[{aggr['tstmp'][:19]}] *** HIST *** {aggr}") + else: + print(f"Unknown message type: {message_type}") + + pricer_client = CvttPricerWebSockClient( + "ws://localhost:12346/ws" + ) + await pricer_client.subscribe(CvttPricesSubscription( + exchange_config_name="COINBASE_AT", + instrument_id="PAIR-BTC-USD", + interval_sec=60, + history_depth_sec=60*60*24, + callback=partial(on_message, instrument_id="PAIR-BTC-USD") + )) + await pricer_client.subscribe(CvttPricesSubscription( + exchange_config_name="COINBASE_AT", + instrument_id="PAIR-ETH-USD", + interval_sec=60, + history_depth_sec=60*60*24, + callback=partial(on_message, instrument_id="PAIR-ETH-USD") + )) - # args = parser.parse_args() - - # config = PricerClientConfig( - # url_=args.url, - # exchange_config_name_=args.exchange_config_name, - # instrument_ids_=args.instrument_ids.split(","), - # interval_sec_=args.interval_sec, - # history_depth_sec_=args.history_depth_sec, - # ) - - # client = CvttPricerWebSockClient(config) - # await client.connect_and_subscribe() + await pricer_client.run() if __name__ == "__main__":