221 lines
7.4 KiB
Python
221 lines
7.4 KiB
Python
#!/usr/bin/env python3
|
|
|
|
import asyncio
import json
import uuid
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Coroutine, Dict, Optional

import websockets
from websockets.asyncio.client import ClientConnection

from cvttpy_tools.settings.cvtt_types import JsonDictT
from cvttpy_tools.tools.logger import Log
|
|
|
|
# Value of a message's "type" field, e.g. "md_aggregate".
MessageTypeT = str
# Subscription identifier; CvttPricesSubscription generates it as str(uuid.uuid4()).
SubscriptionIdT = str
# A decoded JSON message.
MessageT = Dict
UrlT = str

# Subscription callback, awaited as ``await callback(message_type, subscr_id, message)``.
# ``Coroutine[Any, Any, None]`` is the type produced by calling an
# ``async def ... -> None`` function; nothing sends values into these coroutines.
CallbackT = Callable[[MessageTypeT, SubscriptionIdT, MessageT], Coroutine[Any, Any, None]]
|
|
|
|
@dataclass
class CvttPricesSubscription:
    """One ``MD_AGGREGATE`` subscription request plus its delivery callback.

    ``@dataclass`` contributes ``__repr__`` and ``__eq__`` over the annotated
    fields below; the hand-written ``__init__`` is kept because dataclass does
    not replace an ``__init__`` that the class body already defines.
    """

    id_: str  # random UUID4 string generated per instance
    exchange_config_name_: str
    instrument_id_: str
    interval_sec_: int
    history_depth_sec_: int  # values <= 0 make is_historical_ False
    is_subscribed_: bool  # starts False; not updated elsewhere in this module
    is_historical_: bool  # derived: history_depth_sec_ > 0
    callback_: CallbackT

    def __init__(
        self,
        exchange_config_name: str,
        instrument_id: str,
        interval_sec: int,
        history_depth_sec: int,
        callback: CallbackT,
    ):
        # Identity and derived/status flags first ...
        self.id_ = str(uuid.uuid4())
        self.is_historical_ = history_depth_sec > 0
        self.is_subscribed_ = False

        # ... then the caller-supplied request parameters.
        self.exchange_config_name_ = exchange_config_name
        self.instrument_id_ = instrument_id
        self.interval_sec_ = interval_sec
        self.history_depth_sec_ = history_depth_sec
        self.callback_ = callback
|
|
|
|
class CvttWebSockClient:
    """Minimal async websocket client holding one connection to ``ws_url_``.

    ``is_connected_`` is a client-side flag: it is set by :meth:`connect`,
    cleared by :meth:`close`, and may also be cleared by subclasses when they
    detect a dropped connection. It is not synchronised with the socket state
    automatically.
    """

    ws_url_: UrlT
    websocket_: Optional[ClientConnection]
    is_connected_: bool

    def __init__(self, url: str):
        self.ws_url_ = url
        self.websocket_ = None  # created by connect()
        self.is_connected_ = False

    async def connect(self) -> None:
        """Open a connection to ``ws_url_`` and mark the client as connected.

        A connection object left over from an earlier ``connect()`` (for
        example one abandoned after a receive error) is closed first so its
        socket is not leaked.

        Raises:
            Any exception raised by ``websockets.connect``.
        """
        if self.websocket_ is not None:
            await self.close()
        self.websocket_ = await websockets.connect(self.ws_url_)
        self.is_connected_ = True

    async def close(self) -> None:
        """Close the current connection (if any) and clear ``is_connected_``."""
        if self.websocket_ is not None:
            await self.websocket_.close()
        self.is_connected_ = False

    async def receive_message(self) -> JsonDictT:
        """Receive one frame and decode it as a JSON object.

        Binary frames are decoded as UTF-8 before parsing.

        Raises:
            websockets.ConnectionClosed: propagated from ``recv()``.
            json.JSONDecodeError: the frame is not valid JSON.
            ValueError: the frame is valid JSON but not a JSON object.
        """
        # Programmer-error checks: callers must connect() first.
        assert self.websocket_ is not None
        assert self.is_connected_

        message = await self.websocket_.recv()
        message_str = message.decode("utf-8") if isinstance(message, bytes) else message
        res = json.loads(message_str)
        if not isinstance(res, dict):
            # Server data is validated explicitly: asserts vanish under ``python -O``.
            raise ValueError(
                f"Expected a JSON object from {self.ws_url_}, got {type(res).__name__}"
            )
        return res

    @classmethod
    async def check_connection(cls, url: str) -> bool:
        """Return True if a connection to *url* can be opened and closed again.

        Any failure is logged and reported as False rather than raised.
        """
        try:
            async with websockets.connect(url):
                pass
        except Exception as e:
            Log.error(f"Unable to connect to {url}: {str(e)}")
            return False
        return True
|
|
|
|
class CvttPricerWebSockClient(CvttWebSockClient):
    """Client for the pricer's ``MD_AGGREGATE`` subscription protocol.

    Typical use (as in ``main``): call :meth:`subscribe` for each instrument
    (it connects on first use), then await :meth:`run`, which dispatches
    incoming data messages to the matching subscription's callback until the
    connection is lost or an error occurs.
    """

    # Message types that carry market data for a registered subscription.
    _DATA_MESSAGE_TYPES = ("md_aggregate", "historical_md_aggregate")

    # Active subscriptions keyed by their client-generated id.
    subscriptions_: Dict[SubscriptionIdT, CvttPricesSubscription]

    def __init__(self, url: str):
        super().__init__(url)
        self.subscriptions_ = {}

    async def subscribe(
        self, subscription: CvttPricesSubscription
    ) -> str:  # returns subscription id
        """Send a subscription request for *subscription* and wait for the reply.

        Connects first if the client is not connected. While waiting for the
        reply, data messages (which earlier subscriptions on the same
        connection may still be streaming) are dispatched through
        :meth:`process_message` instead of being mistaken for the reply.

        Returns:
            ``subscription.id_`` once the server acknowledges the request.

        Raises:
            Exception: the connection attempt failed (the original exception
                is re-raised), or the server's reply was not a matching
                success acknowledgement (the connection is then closed).
        """
        if not self.is_connected_:
            try:
                Log.info(f"Connecting to {self.ws_url_}")
                await self.connect()
            except Exception as e:
                Log.error(f"Unable to connect to {self.ws_url_}: {str(e)}")
                raise  # bare raise re-raises with the traceback untouched

        subscr_msg = {
            "type": "subscr",
            "id": subscription.id_,
            "subscr_type": "MD_AGGREGATE",
            "exchange_config_name": subscription.exchange_config_name_,
            "instrument_id": subscription.instrument_id_,
            "interval_sec": subscription.interval_sec_,
        }
        if subscription.is_historical_:
            subscr_msg["history_depth_sec"] = subscription.history_depth_sec_

        assert self.websocket_ is not None
        await self.websocket_.send(json.dumps(subscr_msg))

        while True:
            response = await self.websocket_.recv()
            response_data = json.loads(response)
            if (
                isinstance(response_data, dict)
                and response_data.get("type") in self._DATA_MESSAGE_TYPES
            ):
                # Data for an already-registered subscription arrived ahead of
                # our reply; deliver it and keep waiting. (Data for the pending
                # subscription itself is not registered yet and gets dropped.)
                await self.process_message(response_data)
                continue
            break

        if not await self.handle_subscription_response(subscription, response_data):
            await self.websocket_.close()
            self.is_connected_ = False
            raise Exception(f"Subscription failed: {str(response)}")

        self.subscriptions_[subscription.id_] = subscription
        return subscription.id_

    async def handle_subscription_response(
        self, subscription: CvttPricesSubscription, response: dict
    ) -> bool:
        """Return True iff *response* is a success reply to *subscription*.

        A reply must have ``type == "subscr"`` and echo ``subscription.id_``.
        An explicit ``status == "error"`` reply is logged with its ``reason``.
        """
        if response.get("type") != "subscr" or response.get("id") != subscription.id_:
            return False

        status = response.get("status")
        if status == "success":
            Log.info(f"Subscription successful: {json.dumps(response)}")
            return True
        if status == "error":
            Log.error(f"Subscription failed: {response.get('reason')}")
        return False

    async def run(self) -> None:
        """Receive and dispatch messages while ``is_connected_`` stays set.

        The loop ends, clearing ``is_connected_``, when the server closes the
        connection, when receiving/decoding a message fails, or when
        dispatching raises. In the latter two cases the error is logged and
        the method sleeps 5 seconds before returning. No reconnection is
        attempted here.
        """
        assert self.websocket_
        try:
            while self.is_connected_:
                try:
                    msg_dict: JsonDictT = await self.receive_message()
                except websockets.ConnectionClosed:
                    Log.warning("Connection closed")
                    self.is_connected_ = False
                    break
                except Exception as e:
                    Log.error(f"Error occurred: {str(e)}")
                    self.is_connected_ = False
                    await asyncio.sleep(5)  # Wait before reconnecting
                    # No message was received on this iteration: msg_dict is
                    # unbound or stale, so it must not be dispatched.
                    break

                await self.process_message(msg_dict)

        except Exception as e:
            Log.error(f"Error occurred: {str(e)}")
            self.is_connected_ = False
            await asyncio.sleep(5)  # Wait before reconnecting

    async def process_message(self, message: Dict) -> None:
        """Route a data message to its subscription's callback.

        Messages whose type is not a data type, or whose ``subscr_id`` is not
        a registered subscription, are logged and dropped.
        """
        message_type = message.get("type")
        if message_type not in self._DATA_MESSAGE_TYPES:
            Log.warning(f"Unknown message type: {message_type}")
            return

        subscription_id = message.get("subscr_id")
        subscription = self.subscriptions_.get(subscription_id)
        if subscription is None:
            Log.warning(f"Unknown subscription id: {subscription_id}")
            return

        await subscription.callback_(message_type, subscription_id, message)
|
|
|
|
|
|
async def main() -> None:
    """Demo: stream 60-second aggregates for BTC-USD and ETH-USD on COINBASE_AT.

    Connects to a pricer at ``ws://localhost:12346/ws``, requests 24 hours of
    history plus live updates for each instrument, and prints every message.
    The client is closed when ``run()`` returns or anything raises.
    """

    def tstamp_prefix(aggr: Dict) -> str:
        # First 19 characters of the record's timestamp (presumably ISO-8601,
        # i.e. "YYYY-MM-DDTHH:MM:SS" -- unverified); empty if it is missing.
        return str(aggr.get("tstamp", ""))[:19]

    async def on_message(message_type: MessageTypeT, subscr_id: SubscriptionIdT, message: Dict, instrument_id: str) -> None:
        print(f"{message_type=} {subscr_id=} {instrument_id}")
        if message_type == "md_aggregate":
            # Default to a dict: the payload is read by key, never by index.
            aggr = message.get("md_aggregate", {})
            print(f"[{tstamp_prefix(aggr)}] *** RLTM *** {message}")
        elif message_type == "historical_md_aggregate":
            for aggr in message.get("historical_data", []):
                print(f"[{tstamp_prefix(aggr)}] *** HIST *** {aggr}")
        else:
            print(f"Unknown message type: {message_type}")

    pricer_client = CvttPricerWebSockClient(
        "ws://localhost:12346/ws"
    )
    try:
        for instrument_id in ("PAIR-BTC-USD", "PAIR-ETH-USD"):
            await pricer_client.subscribe(CvttPricesSubscription(
                exchange_config_name="COINBASE_AT",
                instrument_id=instrument_id,
                interval_sec=60,
                history_depth_sec=60 * 60 * 24,
                # partial binds the current instrument_id immediately.
                callback=partial(on_message, instrument_id=instrument_id),
            ))

        await pricer_client.run()
    finally:
        await pricer_client.close()


if __name__ == "__main__":
    asyncio.run(main())
|