# pairs_trading/lib/cvtt/mkt_data.py
# Oleg Sheynin 85c9d2ab93 progress
# 2025-07-10 18:14:37 +00:00
#
# 189 lines
# 6.5 KiB
# Python

#!/usr/bin/env python3
import argparse
import asyncio
import json
import logging
import uuid
from ast import Sub
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Coroutine, Dict, List, Optional

import websockets
from numpy.strings import str_len
from websockets.asyncio.client import ClientConnection
# Type aliases shared by the client and the user-supplied callbacks.
MessageTypeT = str  # value of a message's "type" field, e.g. "md_aggregate"
SubscriptionIdT = str  # subscription id (CvttPricesSubscription uses str(uuid4()))
MessageT = Dict  # a decoded JSON message object
UrlT = str
# A callback is an ``async def`` called as (message_type, subscription_id, message).
# Coroutine's parameters are [YieldType, SendType, ReturnType]; coroutines
# produced by ``async def`` are Coroutine[Any, Any, R], so only the return
# type (None) carries information here.
CallbackT = Callable[
    [MessageTypeT, SubscriptionIdT, MessageT], Coroutine[Any, Any, None]
]
@dataclass
class CvttPricesSubscription:
    """A single MD_AGGREGATE subscription request plus its local state.

    ``@dataclass`` supplies ``__repr__`` and ``__eq__`` over the fields below,
    in declaration order. It does not replace the explicit ``__init__``, whose
    parameter names omit the trailing underscore used by the attributes.
    """

    id_: str  # str(uuid.uuid4()), generated per instance
    exchange_config_name_: str
    instrument_id_: str
    interval_sec_: int
    history_depth_sec_: int
    is_subscribed_: bool  # False at construction
    is_historical_: bool  # True iff history_depth_sec > 0
    callback_: CallbackT

    def __init__(
        self,
        exchange_config_name: str,
        instrument_id: str,
        interval_sec: int,
        history_depth_sec: int,
        callback: CallbackT,
    ):
        # Identity: a fresh random id for this request.
        self.id_ = str(uuid.uuid4())

        # Request parameters, stored verbatim.
        self.exchange_config_name_ = exchange_config_name
        self.instrument_id_ = instrument_id
        self.interval_sec_ = interval_sec
        self.history_depth_sec_ = history_depth_sec
        self.callback_ = callback

        # Derived flag and initial state.
        self.is_historical_ = history_depth_sec > 0
        self.is_subscribed_ = False
class CvttPricerWebSockClient:
    """Websocket client for CVTT pricer market-data aggregates.

    The first :meth:`subscribe` call opens one connection, and every later
    subscription shares it. :meth:`run` then reads the socket and routes each
    ``md_aggregate`` / ``historical_md_aggregate`` message to the callback of
    the subscription named by the message's ``subscr_id``.
    """

    # Class members with type hints
    ws_url_: UrlT
    websocket_: Optional[ClientConnection]
    subscriptions_: Dict[SubscriptionIdT, CvttPricesSubscription]
    is_connected_: bool
    logger_: logging.Logger

    def __init__(self, url: str):
        self.ws_url_ = url
        self.websocket_ = None
        self.is_connected_ = False
        self.subscriptions_ = {}
        self.logger_ = logging.getLogger(__name__)
        # NOTE(review): basicConfig() configures the *root* logger, which is a
        # process-wide side effect for a library class. It is kept because the
        # demo main() relies on it for INFO output; consider moving it to the
        # script entry point.
        logging.basicConfig(level=logging.INFO)

    async def subscribe(
        self, subscription: CvttPricesSubscription
    ) -> str:  # returns subscription id
        """Send a subscription request and wait for the server's reply.

        Opens the connection first if none is open. The subscription is
        registered before the request is sent. Any non-``subscr`` message that
        arrives while waiting for the reply is dispatched through
        :meth:`process_message` instead of being mistaken for the reply. Such
        a message may be data for a subscription made earlier on the same
        socket, or early data for this one.

        Returns:
            ``subscription.id_``. On success ``subscription.is_subscribed_``
            is set to True.

        Raises:
            Exception: the connection error if connecting fails; or a
                "Subscription failed" Exception if the first ``subscr`` reply
                is not a success for this subscription. In that case the
                connection is closed.
        """
        if not self.is_connected_:
            try:
                self.logger_.info(f"Connecting to {self.ws_url_}")
                self.websocket_ = await websockets.connect(self.ws_url_)
                self.is_connected_ = True
            except Exception as e:
                self.logger_.error(f"Unable to connect to {self.ws_url_}: {str(e)}")
                raise
        subscr_msg = {
            "type": "subscr",
            "id": subscription.id_,
            "subscr_type": "MD_AGGREGATE",
            "exchange_config_name": subscription.exchange_config_name_,
            "instrument_id": subscription.instrument_id_,
            "interval_sec": subscription.interval_sec_,
        }
        if subscription.is_historical_:
            subscr_msg["history_depth_sec"] = subscription.history_depth_sec_
        assert self.websocket_ is not None

        # Register up front so that data the server pushes for this
        # subscription ahead of its ack is delivered, not dropped as "unknown".
        self.subscriptions_[subscription.id_] = subscription
        try:
            await self.websocket_.send(json.dumps(subscr_msg))
            while True:
                response = await self.websocket_.recv()
                response_data = json.loads(response)
                if response_data.get("type") == "subscr":
                    break
                # Other subscriptions' traffic shares this socket; route it
                # normally instead of treating it as a failed reply.
                await self.process_message(response_data)
            accepted = await self.handle_subscription_response(
                subscription, response_data
            )
        except BaseException:
            self.subscriptions_.pop(subscription.id_, None)
            raise

        if not accepted:
            self.subscriptions_.pop(subscription.id_, None)
            await self.websocket_.close()
            self.is_connected_ = False
            raise Exception(f"Subscription failed: {str(response)}")
        subscription.is_subscribed_ = True
        return subscription.id_

    async def handle_subscription_response(
        self, subscription: CvttPricesSubscription, response: dict
    ) -> bool:
        """Return True iff *response* is a successful ack for *subscription*.

        Returns False in three cases: the reply is for another request (its
        ``type`` is not ``"subscr"`` or its ``id`` differs), it has status
        ``"error"`` (after logging the server's ``reason``), or its status is
        unrecognised.
        """
        if response.get("type") != "subscr" or response.get("id") != subscription.id_:
            return False
        status = response.get("status")
        if status == "success":
            self.logger_.info(f"Subscription successful: {json.dumps(response)}")
            return True
        if status == "error":
            self.logger_.error(f"Subscription failed: {response.get('reason')}")
        return False

    async def run(self) -> None:
        """Receive and dispatch messages until the connection ends.

        Requires a connection opened by a prior :meth:`subscribe`. Returns when
        the server closes the connection or an unexpected error occurs (the
        socket is then closed). Nothing here reconnects. A frame that is not
        valid JSON is logged and skipped rather than ending the stream.
        """
        assert self.websocket_
        try:
            while self.is_connected_:
                try:
                    message = await self.websocket_.recv()
                except websockets.ConnectionClosed:
                    self.logger_.warning("Connection closed")
                    self.is_connected_ = False
                    break
                message_str = (
                    message.decode("utf-8") if isinstance(message, bytes) else message
                )
                try:
                    payload = json.loads(message_str)
                except json.JSONDecodeError as e:
                    # One bad frame should not tear down the whole feed.
                    self.logger_.error(
                        f"Skipping malformed message ({e}): {message_str[:200]!r}"
                    )
                    continue
                await self.process_message(payload)
        except Exception as e:
            self.logger_.error(f"Error occurred: {str(e)}")
            self.is_connected_ = False
            # Release the socket instead of leaking it. There is no reconnect
            # logic, so there is no reason to delay before returning.
            await self.websocket_.close()

    async def process_message(self, message: Dict) -> None:
        """Route a market-data message to its subscription's callback.

        A message with an unrecognised ``type``, or one whose ``subscr_id`` is
        not in ``subscriptions_``, is logged as a warning and dropped.
        """
        message_type = message.get("type")
        if message_type not in ("md_aggregate", "historical_md_aggregate"):
            self.logger_.warning(f"Unknown message type: {message_type}")
            return
        subscription_id = message.get("subscr_id")
        subscription = self.subscriptions_.get(subscription_id)
        if subscription is None:
            self.logger_.warning(f"Unknown subscription id: {subscription_id}")
            return
        await subscription.callback_(message_type, subscription_id, message)
async def main() -> None:
    """Demo: stream 60s aggregates (with 24h of history) for BTC-USD and ETH-USD."""

    async def on_message(
        message_type: MessageTypeT,
        subscr_id: SubscriptionIdT,
        message: Dict,
        instrument_id: str,
    ) -> None:
        print(f"{message_type=} {subscr_id=} {instrument_id}")
        if message_type == "md_aggregate":
            # Default to a dict (not a list) and use .get so a message missing
            # the payload or its timestamp is printed instead of raising
            # TypeError/KeyError inside the callback.
            aggr = message.get("md_aggregate", {})
            print(f"[{aggr.get('tstmp', '')[:19]}] *** RLTM *** {message}")
        elif message_type == "historical_md_aggregate":
            for aggr in message.get("historical_data", []):
                print(f"[{aggr.get('tstmp', '')[:19]}] *** HIST *** {aggr}")
        else:
            print(f"Unknown message type: {message_type}")

    pricer_client = CvttPricerWebSockClient("ws://localhost:12346/ws")
    for instrument_id in ("PAIR-BTC-USD", "PAIR-ETH-USD"):
        await pricer_client.subscribe(
            CvttPricesSubscription(
                exchange_config_name="COINBASE_AT",
                instrument_id=instrument_id,
                interval_sec=60,
                history_depth_sec=60 * 60 * 24,
                # partial binds the current instrument_id value immediately,
                # so each callback keeps its own instrument.
                callback=partial(on_message, instrument_id=instrument_id),
            )
        )
    await pricer_client.run()


if __name__ == "__main__":
    asyncio.run(main())