#!/usr/bin/env python3
"""Websocket client for CVTT pricer MD_AGGREGATE subscriptions.

Run as a script, it connects to a pricer websocket endpoint, subscribes to
60-second aggregates with one day of history for two instruments on the
COINBASE_AT exchange config, and prints every message it receives.
"""
import argparse
from ast import Sub  # NOTE(review): appears unused in this file
import asyncio
from functools import partial
import json
import logging
import uuid
from dataclasses import dataclass
from typing import Any, Callable, Coroutine, Dict, List, Optional, Tuple, Union

# NOTE(review): appears unused in this file; numpy.strings only exists in
# NumPy >= 2.0, so this import fails on older NumPy installs.
from numpy.strings import str_len
import websockets
from websockets.asyncio.client import ClientConnection

MessageTypeT = str
SubscriptionIdT = str
MessageT = Dict
UrlT = str
# Async callback invoked as callback(message_type, subscription_id, message).
CallbackT = Callable[[MessageTypeT, SubscriptionIdT, MessageT], Coroutine[Any, Any, None]]


@dataclass
class CvttPricesSubscription:
    """One MD_AGGREGATE subscription request plus the callback for its data.

    A random UUID is generated as the subscription id. The client matches the
    server's acknowledgement ("id") and data messages ("subscr_id") against it.
    """

    id_: str  # client-generated UUID identifying this subscription
    exchange_config_name_: str
    instrument_id_: str
    interval_sec_: int  # sent to the server as "interval_sec"
    history_depth_sec_: int  # sent as "history_depth_sec" only when > 0
    is_subscribed_: bool  # set True by CvttPricerWebSockClient.subscribe() once acknowledged
    is_historical_: bool  # True when history_depth_sec_ > 0
    callback_: CallbackT

    # Explicit __init__: callers pass un-suffixed keyword names, and
    # id_/is_subscribed_/is_historical_ are derived. @dataclass keeps a
    # user-defined __init__ and only generates __repr__/__eq__.
    def __init__(
        self,
        exchange_config_name: str,
        instrument_id: str,
        interval_sec: int,
        history_depth_sec: int,
        callback: CallbackT,
    ):
        self.exchange_config_name_ = exchange_config_name
        self.instrument_id_ = instrument_id
        self.interval_sec_ = interval_sec
        self.history_depth_sec_ = history_depth_sec
        self.callback_ = callback
        self.id_ = str(uuid.uuid4())
        self.is_subscribed_ = False
        self.is_historical_ = history_depth_sec > 0


class CvttPricerWebSockClient:
    """Asyncio client multiplexing CvttPricesSubscriptions over one websocket.

    Typical use: ``await client.subscribe(...)`` one or more times, then
    ``await client.run()`` to dispatch incoming data messages to the
    subscriptions' callbacks until the connection closes.

    subscribe() and run() both call recv() on the same connection, so
    subscribe() should not be called while run() is active.
    """

    # Class members with type hints
    ws_url_: UrlT
    websocket_: Optional[ClientConnection]
    subscriptions_: Dict[SubscriptionIdT, CvttPricesSubscription]
    is_connected_: bool
    logger_: logging.Logger

    def __init__(self, url: str):
        self.ws_url_ = url
        self.websocket_ = None
        self.is_connected_ = False
        self.subscriptions_ = {}
        # Logging is configured by the application (see main()), not here:
        # calling logging.basicConfig from a library constructor would set up
        # the root logger on behalf of whatever program imports this class.
        self.logger_ = logging.getLogger(__name__)

    async def _connect(self) -> None:
        """Open the websocket connection unless one is already open."""
        if self.is_connected_:
            return
        self.logger_.info(f"Connecting to {self.ws_url_}")
        try:
            self.websocket_ = await websockets.connect(self.ws_url_)
        except Exception as e:
            self.logger_.error(f"Unable to connect to {self.ws_url_}: {str(e)}")
            raise
        self.is_connected_ = True

    async def subscribe(
        self, subscription: CvttPricesSubscription
    ) -> str:  # returns subscription id
        """Send a subscription request and wait for the server's acknowledgement.

        Connects first if needed. Data messages for already-active
        subscriptions that arrive before the acknowledgement are dispatched to
        their callbacks instead of being mistaken for the reply.

        Returns:
            The subscription id (``subscription.id_``).

        Raises:
            Exception: if the server rejects the subscription or replies with
                an acknowledgement that does not match it; the connection is
                closed in that case. Errors from connecting are logged and
                re-raised unchanged.
            websockets.ConnectionClosed: if the connection drops while
                sending or waiting for the acknowledgement.
        """
        await self._connect()
        assert self.websocket_ is not None

        subscr_msg: Dict[str, Any] = {
            "type": "subscr",
            "id": subscription.id_,
            "subscr_type": "MD_AGGREGATE",
            "exchange_config_name": subscription.exchange_config_name_,
            "instrument_id": subscription.instrument_id_,
            "interval_sec": subscription.interval_sec_,
        }
        if subscription.is_historical_:
            subscr_msg["history_depth_sec"] = subscription.history_depth_sec_

        # Register up front so any data the server streams for this id is
        # routed to its callback rather than reported as unknown; undone below
        # if the subscription does not succeed.
        self.subscriptions_[subscription.id_] = subscription
        ok = False
        try:
            await self.websocket_.send(json.dumps(subscr_msg))
            response, response_data = await self._recv_subscription_response()
            ok = await self.handle_subscription_response(subscription, response_data)
        except websockets.ConnectionClosed:
            self.is_connected_ = False
            raise
        finally:
            if not ok:
                self.subscriptions_.pop(subscription.id_, None)

        if not ok:
            await self.websocket_.close()
            self.is_connected_ = False
            raise Exception(f"Subscription failed: {str(response)}")

        subscription.is_subscribed_ = True
        return subscription.id_

    async def _recv_subscription_response(self) -> Tuple[Union[str, bytes], Dict]:
        """Read frames until a "subscr" message arrives; dispatch all others.

        Returns the raw frame and its decoded JSON object.
        """
        assert self.websocket_ is not None
        while True:
            raw = await self.websocket_.recv()
            message = self._parse_message(raw)
            if message is None:
                continue
            if message.get("type") == "subscr":
                return raw, message
            # Data for earlier subscriptions can precede this acknowledgement.
            await self._dispatch(message)

    async def handle_subscription_response(
        self, subscription: CvttPricesSubscription, response: dict
    ) -> bool:
        """Return True iff *response* acknowledges *subscription* with status "success"."""
        if response.get("type") != "subscr" or response.get("id") != subscription.id_:
            return False
        if response.get("status") == "success":
            self.logger_.info(f"Subscription successful: {json.dumps(response)}")
            return True
        elif response.get("status") == "error":
            self.logger_.error(f"Subscription failed: {response.get('reason')}")
            return False
        return False

    def _parse_message(self, raw: Union[str, bytes]) -> Optional[Dict]:
        """Decode one websocket frame as a JSON object; log and return None on failure."""
        try:
            text = raw.decode("utf-8") if isinstance(raw, bytes) else raw
            message = json.loads(text)
        except ValueError as e:  # covers UnicodeDecodeError and JSONDecodeError
            self.logger_.error(f"Unable to decode message: {str(e)}")
            return None
        if not isinstance(message, dict):
            self.logger_.warning(f"Ignoring non-object message: {text}")
            return None
        return message

    async def _dispatch(self, message: Dict) -> None:
        """Process one message, logging instead of propagating any error.

        Keeps a single failing callback or malformed message from ending the stream.
        """
        try:
            await self.process_message(message)
        except Exception as e:
            self.logger_.exception(f"Error occurred: {str(e)}")

    async def run(self) -> None:
        """Receive and dispatch messages until the connection closes.

        Errors raised while processing an individual message are logged and
        the loop continues. When the loop ends, ``is_connected_`` is cleared
        and the websocket is closed.

        Raises:
            RuntimeError: if called before subscribe() opened a connection.
        """
        if self.websocket_ is None:
            raise RuntimeError("run() called before subscribe() opened a connection")
        websocket = self.websocket_
        try:
            while self.is_connected_:
                try:
                    raw = await websocket.recv()
                except websockets.ConnectionClosed:
                    self.logger_.warning("Connection closed")
                    break
                message = self._parse_message(raw)
                if message is not None:
                    await self._dispatch(message)
        except Exception as e:
            self.logger_.error(f"Error occurred: {str(e)}")
        finally:
            self.is_connected_ = False
            await websocket.close()

    async def process_message(self, message: Dict) -> None:
        """Route a market-data message to its subscription's callback.

        Handles "md_aggregate" and "historical_md_aggregate" messages, looked
        up by their "subscr_id"; anything else is logged and ignored.
        """
        message_type = message.get("type")
        if message_type not in ("md_aggregate", "historical_md_aggregate"):
            self.logger_.warning(f"Unknown message type: {message.get('type')}")
            return
        subscription_id = message.get("subscr_id")
        subscription = self.subscriptions_.get(subscription_id)
        if subscription is None:
            self.logger_.warning(f"Unknown subscription id: {subscription_id}")
            return
        await subscription.callback_(message_type, subscription_id, message)


async def main(url: str = "ws://localhost:12346/ws") -> None:
    """Demo: subscribe to BTC-USD and ETH-USD aggregates at *url* and print them."""
    logging.basicConfig(level=logging.INFO)

    async def on_message(
        message_type: MessageTypeT,
        subscr_id: SubscriptionIdT,
        message: Dict,
        instrument_id: str,
    ) -> None:
        print(f"{message_type=} {subscr_id=} {instrument_id}")
        if message_type == "md_aggregate":
            # Fall back to {} (not []) so a missing/null key does not crash
            # the string lookup below.
            aggr = message.get("md_aggregate") or {}
            print(f"[{aggr.get('tstamp', '')[:19]}] *** RLTM *** {message}")
        elif message_type == "historical_md_aggregate":
            for aggr in message.get("historical_data", []):
                print(f"[{aggr.get('tstamp', '')[:19]}] *** HIST *** {aggr}")
        else:
            print(f"Unknown message type: {message_type}")

    pricer_client = CvttPricerWebSockClient(url)
    for instrument_id in ("PAIR-BTC-USD", "PAIR-ETH-USD"):
        await pricer_client.subscribe(
            CvttPricesSubscription(
                exchange_config_name="COINBASE_AT",
                instrument_id=instrument_id,
                interval_sec=60,
                history_depth_sec=60 * 60 * 24,  # one day of history
                callback=partial(on_message, instrument_id=instrument_id),
            )
        )
    await pricer_client.run()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="CVTT pricer websocket client demo")
    parser.add_argument(
        "--url",
        default="ws://localhost:12346/ws",
        help="pricer websocket endpoint (default: %(default)s)",
    )
    args = parser.parse_args()
    asyncio.run(main(args.url))