From 121c85def01021ee5a2eaebabd8a1cb02ac40dec Mon Sep 17 00:00:00 2001 From: Oleg Sheynin Date: Tue, 30 Dec 2025 10:52:33 +0000 Subject: [PATCH] dev progress --- .vscode/pairs_trading.code-workspace | 5 +- apps/pairs_trader.py | 47 ++-- lib/client/mkt_data.py | 220 ------------------ .../mkt_data_client.py} | 123 ++++------ lib/live/rest_client.py | 67 ++++++ lib/pt_strategy/live/live_strategy.py | 48 ++-- lib/pt_strategy/live/pricer_md_client.py.md | 87 ------- lib/pt_strategy/model_data_policy.py | 2 +- lib/pt_strategy/models.py | 4 +- lib/pt_strategy/pt_market_data.py | 60 ----- lib/pt_strategy/research_strategy.py | 12 +- lib/pt_strategy/results.py | 2 +- lib/pt_strategy/trading_pair.py | 2 +- lib/tools/viz/viz_prices.py | 2 +- lib/tools/viz/viz_trades.py | 8 +- requirements.txt | 1 + research/backtest.py | 10 +- research/notebooks/pair_trading_test.ipynb | 12 +- research/tools/research_tools.py | 2 +- tests/viz_test.py | 12 +- 20 files changed, 211 insertions(+), 515 deletions(-) delete mode 100644 lib/client/mkt_data.py rename lib/{client/cvtt_client.py => live/mkt_data_client.py} (69%) create mode 100644 lib/live/rest_client.py delete mode 100644 lib/pt_strategy/live/pricer_md_client.py.md diff --git a/.vscode/pairs_trading.code-workspace b/.vscode/pairs_trading.code-workspace index 9e68e72..c96e644 100644 --- a/.vscode/pairs_trading.code-workspace +++ b/.vscode/pairs_trading.code-workspace @@ -3,5 +3,8 @@ { "path": ".." } - ] + ], + "settings": { + "workbench.colorTheme": "Dracula Theme" + } } \ No newline at end of file diff --git a/apps/pairs_trader.py b/apps/pairs_trader.py index 1a48e55..a686f71 100644 --- a/apps/pairs_trader.py +++ b/apps/pairs_trader.py @@ -15,7 +15,7 @@ from cvttpy_trading.trading.active_instruments import Instruments from cvttpy_trading.trading.mkt_data.md_summary import MdTradesAggregate # --- from pairs_trading.lib.pt_strategy.live.live_strategy import PtLiveStrategy -# from pairs_trading.lib.pt_strategy.live.pricer_md_client import PtMktDataClient +from pairs_trading.lib.live.mkt_data_client import CvttRestMktDataClient, MdSummary from pairs_trading.lib.pt_strategy.live.ti_sender import TradingInstructionsSender # import sys @@ -23,6 +23,20 @@ from pairs_trading.lib.pt_strategy.live.ti_sender import TradingInstructionsSend # for path in sys.path: # print(path) +''' +Config +======= +{ + "cvtt_base_url": "http://cvtt-tester-01.cvtt.vpn:23456", + "ti_config": { + TODO + }, + "strategy_config": { + TODO + } +} +''' + HistMdCbT = Callable[[List[MdTradesAggregate]], Coroutine] UpdateMdCbT = Callable[[MdTradesAggregate], Coroutine] @@ -31,7 +45,7 @@ class PairsTrader(NamedObject): instruments_: List[JsonDictT] live_strategy_: PtLiveStrategy - # pricer_client_: PtMktDataClient + pricer_client_: CvttRestMktDataClient def __init__(self) -> None: self.instruments_ = [] @@ -90,16 +104,10 @@ class PairsTrader(NamedObject): Log.info(f"{self.fname()} Strategy created: {self.live_strategy_}") # # ------- CREATE PRICER CLIENT ------- - # URGENT - # pricer_config = self.config_.get_subconfig("pricer_config", {}) - # self.pricer_client_ = PtMktDataClient( - # live_strategy=self.live_strategy_, - # pricer_config=pricer_config - # ) - # Log.info(f"{self.fname()} CVTT Pricer client created: {self.pricer_client_}") + self.pricer_client_ = CvttRestMktDataClient(self.config_) # ------- CREATE TRADER CLIENT ------- - # URGENT + # URGENT CREATE TRADER CLIENT # (send TradingInstructions) # ti_config = self.config_.get_subconfig("ti_config", {}) # self.ti_sender_ = TradingInstructionsSender(config=ti_config) @@ -107,12 +115,23 @@ class PairsTrader(NamedObject): # # ------- CREATE REST SERVER ------- - # for dashboard communications + # URGENT CREATE REST SERVER for dashboard communications async def subscribe_md(self) -> None: - pass - # URGENT implement PairsTrader.subscribe_md() - + for inst in self.instruments_: + exch_acct = inst.get("exch_acct", "?exch_acct?") + instrument_id = inst.get("instrument_id", "?instrument_id?") + await self.pricer_client_.add_subscription( + exch_acct=exch_acct, + instrument_id=instrument_id, + interval_sec=self.live_strategy_.interval_sec(), + history_depth_sec=self.live_strategy_.history_depth_sec(), + callback=self._on_md_summary + ) + + def _on_md_summary(self, history: List[MdSummary]) -> None: + pass # URGENT + async def run(self) -> None: Log.info(f"{self.fname()} ...") pass diff --git a/lib/client/mkt_data.py b/lib/client/mkt_data.py deleted file mode 100644 index 4ac7e2b..0000000 --- a/lib/client/mkt_data.py +++ /dev/null @@ -1,220 +0,0 @@ -#!/usr/bin/env python3 - -import asyncio -import json -import uuid -from dataclasses import dataclass -from functools import partial -from typing import Callable, Coroutine, Dict, Optional - -import websockets -from cvttpy_tools.logger import Log -from cvttpy_tools.settings.cvtt_types import JsonDictT -from websockets.asyncio.client import ClientConnection - -MessageTypeT = str -SubscriptionIdT = str -MessageT = Dict -UrlT = str -CallbackT = Callable[[MessageTypeT, SubscriptionIdT, MessageT], Coroutine[None, str, None]] - -@dataclass -class CvttPricesSubscription: - id_: str - exchange_config_name_: str - instrument_id_: str - interval_sec_: int - history_depth_sec_: int - is_subscribed_: bool - is_historical_: bool - callback_: CallbackT - - def __init__( - self, - exchange_config_name: str, - instrument_id: str, - interval_sec: int, - history_depth_sec: int, - callback: CallbackT, - ): - self.exchange_config_name_ = exchange_config_name - self.instrument_id_ = instrument_id - self.interval_sec_ = interval_sec - self.history_depth_sec_ = history_depth_sec - self.callback_ = callback - self.id_ = str(uuid.uuid4()) - self.is_subscribed_ = False - self.is_historical_ = history_depth_sec > 0 - -class CvttWebSockClient: - ws_url_: UrlT - websocket_: Optional[ClientConnection] - is_connected_: bool - - def __init__(self, url: str): - self.ws_url_ = url - self.websocket_ = None - self.is_connected_ = False - - async def connect(self) -> None: - self.websocket_ = await websockets.connect(self.ws_url_) - self.is_connected_ = True - - async def close(self) -> None: - if self.websocket_ is not None: - await self.websocket_.close() - self.is_connected_ = False - - async def receive_message(self) -> JsonDictT: - assert self.websocket_ is not None - assert self.is_connected_ - message = await self.websocket_.recv() - message_str = ( - message.decode("utf-8") - if isinstance(message, bytes) - else message - ) - res = json.loads(message_str) - assert res is not None - assert isinstance(res, dict) - return res - - @classmethod - async def check_connection(cls, url: str) -> bool: - try: - async with websockets.connect(url) as websocket: - result = True - except Exception as e: - Log.error(f"Unable to connect to {url}: {str(e)}") - result = False - return result - -class CvttPricerWebSockClient(CvttWebSockClient): - # Class members with type hints - subscriptions_: Dict[SubscriptionIdT, CvttPricesSubscription] - - def __init__(self, url: str): - super().__init__(url) - self.subscriptions_ = {} - - async def subscribe( - self, subscription: CvttPricesSubscription - ) -> str: # returns subscription id - - if not self.is_connected_: - try: - Log.info(f"Connecting to {self.ws_url_}") - await self.connect() - except Exception as e: - Log.error(f"Unable to connect to {self.ws_url_}: {str(e)}") - raise e - - subscr_msg = { - "type": "subscr", - "id": subscription.id_, - "subscr_type": "MD_AGGREGATE", - "exchange_config_name": subscription.exchange_config_name_, - "instrument_id": subscription.instrument_id_, - "interval_sec": subscription.interval_sec_, - } - if subscription.is_historical_: - subscr_msg["history_depth_sec"] = subscription.history_depth_sec_ - - assert self.websocket_ is not None - await self.websocket_.send(json.dumps(subscr_msg)) - - response = await self.websocket_.recv() - response_data = json.loads(response) - if not await self.handle_subscription_response(subscription, response_data): - await self.websocket_.close() - self.is_connected_ = False - raise Exception(f"Subscription failed: {str(response)}") - - self.subscriptions_[subscription.id_] = subscription - return subscription.id_ - - async def handle_subscription_response( - self, subscription: CvttPricesSubscription, response: dict - ) -> bool: - if response.get("type") != "subscr" or response.get("id") != subscription.id_: - return False - - if response.get("status") == "success": - Log.info(f"Subscription successful: {json.dumps(response)}") - return True - elif response.get("status") == "error": - Log.error(f"Subscription failed: {response.get('reason')}") - return False - return False - - async def run(self) -> None: - assert self.websocket_ - try: - while self.is_connected_: - try: - msg_dict: JsonDictT = await self.receive_message() - except websockets.ConnectionClosed: - Log.warning("Connection closed") - self.is_connected_ = False - break - except Exception as e: - Log.error(f"Error occurred: {str(e)}") - self.is_connected_ = False - await asyncio.sleep(5) # Wait before reconnecting - - await self.process_message(msg_dict) - - except Exception as e: - Log.error(f"Error occurred: {str(e)}") - self.is_connected_ = False - await asyncio.sleep(5) # Wait before reconnecting - - async def process_message(self, message: Dict) -> None: - message_type = message.get("type") - if message_type in ["md_aggregate", "historical_md_aggregate"]: - subscription_id = message.get("subscr_id") - if subscription_id not in self.subscriptions_: - Log.warning(f"Unknown subscription id: {subscription_id}") - return - - subscription = self.subscriptions_[subscription_id] - await subscription.callback_(message_type, subscription_id, message) - else: - Log.warning(f"Unknown message type: {message.get('type')}") - - -async def main() -> None: - async def on_message(message_type: MessageTypeT, subscr_id: SubscriptionIdT, message: Dict, instrument_id: str) -> None: - print(f"{message_type=} {subscr_id=} {instrument_id}") - if message_type == "md_aggregate": - aggr = message.get("md_aggregate", []) - print(f"[{aggr['tstamp'][:19]}] *** RLTM *** {message}") - elif message_type == "historical_md_aggregate": - for aggr in message.get("historical_data", []): - print(f"[{aggr['tstamp'][:19]}] *** HIST *** {aggr}") - else: - print(f"Unknown message type: {message_type}") - - pricer_client = CvttPricerWebSockClient( - "ws://localhost:12346/ws" - ) - await pricer_client.subscribe(CvttPricesSubscription( - exchange_config_name="COINBASE_AT", - instrument_id="PAIR-BTC-USD", - interval_sec=60, - history_depth_sec=60*60*24, - callback=partial(on_message, instrument_id="PAIR-BTC-USD") - )) - await pricer_client.subscribe(CvttPricesSubscription( - exchange_config_name="COINBASE_AT", - instrument_id="PAIR-ETH-USD", - interval_sec=60, - history_depth_sec=60*60*24, - callback=partial(on_message, instrument_id="PAIR-ETH-USD") - )) - - await pricer_client.run() - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/lib/client/cvtt_client.py b/lib/live/mkt_data_client.py similarity index 69% rename from lib/client/cvtt_client.py rename to lib/live/mkt_data_client.py index d14d805..04d7304 100644 --- a/lib/client/cvtt_client.py +++ b/lib/live/mkt_data_client.py @@ -1,70 +1,24 @@ from __future__ import annotations import asyncio -from typing import Callable, Dict, Any, List, Optional -import time +from typing import Callable, Dict, Any, List, Optional, Set import requests from cvttpy_tools.base import NamedObject +from cvttpy_tools.app import App from cvttpy_tools.logger import Log from cvttpy_tools.config import Config from cvttpy_tools.timer import Timer - -from cvttpy_tools.timeutils import NanoPerSec, NanosT, current_nanoseconds, current_seconds +from cvttpy_tools.timeutils import NanosT, current_seconds +from cvttpy_tools.settings.cvtt_types import InstrumentIdT, IntervalSecT +# --- from cvttpy_trading.trading.mkt_data.historical_md import HistMdBar +from cvttpy_trading.trading.accounting.exch_account import ExchangeAccountNameT +# --- +from pairs_trading.lib.live.rest_client import RESTSender -class RESTSender(NamedObject): - session_: requests.Session - base_url_: str - - def __init__(self, base_url: str) -> None: - self.base_url_ = base_url - self.session_ = requests.Session() - - def is_ready(self) -> bool: - """Checks if the server is up and responding""" - url = f"{self.base_url_}/ping" - try: - response = self.session_.get(url) - response.raise_for_status() - return True - except requests.exceptions.RequestException: - return False - - def send_post(self, endpoint: str, post_body: Dict) -> requests.Response: - - while not self.is_ready(): - print("Waiting for FrontGateway to start...") - time.sleep(5) - - url = f"{self.base_url_}/{endpoint}" - try: - return self.session_.request( - method="POST", - url=url, - json=post_body, - headers={"Content-Type": "application/json"}, - ) - except requests.exceptions.RequestException as excpt: - raise ConnectionError( - f"Failed to send status={excpt.response.status_code} {excpt.response.text}" # type: ignore - ) from excpt - - def send_get(self, endpoint: str) -> requests.Response: - while not self.is_ready(): - print("Waiting for FrontGateway to start...") - time.sleep(5) - - url = f"{self.base_url_}/{endpoint}" - try: - return self.session_.request(method="GET", url=url) - except requests.exceptions.RequestException as excpt: - raise ConnectionError( - f"Failed to send status={excpt.response.status_code} {excpt.response.text}" # type: ignore - ) from excpt - class MdSummary(HistMdBar): def __init__( self, @@ -110,10 +64,10 @@ MdSummaryCallbackT = Callable[[List[MdSummary]], None] class MdSummaryCollector(NamedObject): sender_: RESTSender - exch_acct_: str - instrument_id_: str - interval_sec_: int - history_depth_sec_: int + exch_acct_: ExchangeAccountNameT + instrument_id_: InstrumentIdT + interval_sec_: IntervalSecT + history_depth_sec_: IntervalSecT history_: List[MdSummary] callbacks_: List[MdSummaryCallbackT] @@ -122,10 +76,10 @@ class MdSummaryCollector(NamedObject): def __init__( self, sender: RESTSender, - exch_acct: str, - instrument_id: str, - interval_sec: int, - history_depth_sec: int, + exch_acct: ExchangeAccountNameT, + instrument_id: InstrumentIdT, + interval_sec: IntervalSecT, + history_depth_sec: IntervalSecT, ) -> None: self.sender_ = sender self.exch_acct_ = exch_acct @@ -140,6 +94,9 @@ class MdSummaryCollector(NamedObject): def add_callback(self, cb: MdSummaryCallbackT) -> None: self.callbacks_.append(cb) + def __hash__(self): + return hash((self.exch_acct_, self.instrument_id_, self.interval_sec_, self.history_depth_sec_)) + def rqst_data(self) -> Dict[str, Any]: return { "exch_acct": self.exch_acct_, @@ -191,7 +148,6 @@ class MdSummaryCollector(NamedObject): def next_load_time(self) -> NanosT: curr_sec = int(current_seconds()) return (curr_sec - curr_sec % self.interval_sec_) + self.interval_sec_ + 2 - async def _load_new(self) -> None: @@ -213,39 +169,54 @@ class MdSummaryCollector(NamedObject): self.timer_.cancel() self.timer_ = None -class CvttRESTClient(NamedObject): +class CvttRestMktDataClient(NamedObject): config_: Config sender_: RESTSender + collectors_: Set[MdSummaryCollector] def __init__(self, config: Config) -> None: self.config_ = config base_url = self.config_.get_value("cvtt_base_url", default="") assert base_url self.sender_ = RESTSender(base_url=base_url) + self.collectors_ = set() + async def add_subscription(self, + exch_acct: ExchangeAccountNameT, + instrument_id: InstrumentIdT, + interval_sec: IntervalSecT, + history_depth_sec: IntervalSecT, + callback: MdSummaryCallbackT + ) -> None: + mdsc = MdSummaryCollector( + sender=self.sender_, + exch_acct=exch_acct, + instrument_id=instrument_id, + interval_sec=interval_sec, + history_depth_sec=history_depth_sec, + ) + mdsc.add_callback(callback) + self.collectors_.add(mdsc) + await mdsc.start() if __name__ == "__main__": config = Config(json_src={"cvtt_base_url": "http://cvtt-tester-01.cvtt.vpn:23456"}) # config = Config(json_src={"cvtt_base_url": "http://dev-server-02.cvtt.vpn:23456"}) - cvtt_client = CvttRESTClient(config) - - mdsc = MdSummaryCollector( - sender=cvtt_client.sender_, - exch_acct="COINBASE_AT", - instrument_id="PAIR-BTC-USD", - interval_sec=60, - history_depth_sec=24 * 3600, - ) - def _calback(history: List[MdSummary]) -> None: Log.info(f"MdSummary Hist Length is {len(history)}. Last summary: {history[-1] if len(history) > 0 else '[]'}") - mdsc.add_callback(_calback) async def __run() -> None: Log.info("Starting...") - await mdsc.start() + cvtt_client = CvttRestMktDataClient(config) + await cvtt_client.add_subscription( + exch_acct="COINBASE_AT", + instrument_id="PAIR-BTC-USD", + interval_sec=60, + history_depth_sec=24 * 3600, + callback=_calback + ) while True: await asyncio.sleep(5) diff --git a/lib/live/rest_client.py b/lib/live/rest_client.py new file mode 100644 index 0000000..a5ed6c1 --- /dev/null +++ b/lib/live/rest_client.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +import asyncio +from typing import Callable, Dict, Any, List, Optional +import time + +import requests + +from cvttpy_tools.base import NamedObject +from cvttpy_tools.logger import Log +from cvttpy_tools.config import Config +from cvttpy_tools.timer import Timer + +from cvttpy_tools.timeutils import NanoPerSec, NanosT, current_nanoseconds, current_seconds +from cvttpy_trading.trading.mkt_data.historical_md import HistMdBar + + +class RESTSender(NamedObject): + session_: requests.Session + base_url_: str + + def __init__(self, base_url: str) -> None: + self.base_url_ = base_url + self.session_ = requests.Session() + + def is_ready(self) -> bool: + """Checks if the server is up and responding""" + url = f"{self.base_url_}/ping" + try: + response = self.session_.get(url) + response.raise_for_status() + return True + except requests.exceptions.RequestException: + return False + + def send_post(self, endpoint: str, post_body: Dict) -> requests.Response: + + while not self.is_ready(): + print("Waiting for FrontGateway to start...") + time.sleep(5) + + url = f"{self.base_url_}/{endpoint}" + try: + return self.session_.request( + method="POST", + url=url, + json=post_body, + headers={"Content-Type": "application/json"}, + ) + except requests.exceptions.RequestException as excpt: + raise ConnectionError( + f"Failed to send status={excpt.response.status_code} {excpt.response.text}" # type: ignore + ) from excpt + + def send_get(self, endpoint: str) -> requests.Response: + while not self.is_ready(): + print("Waiting for FrontGateway to start...") + time.sleep(5) + + url = f"{self.base_url_}/{endpoint}" + try: + return self.session_.request(method="GET", url=url) + except requests.exceptions.RequestException as excpt: + raise ConnectionError( + f"Failed to send status={excpt.response.status_code} {excpt.response.text}" # type: ignore + ) from excpt + diff --git a/lib/pt_strategy/live/live_strategy.py b/lib/pt_strategy/live/live_strategy.py index c63b4e6..2221ca8 100644 --- a/lib/pt_strategy/live/live_strategy.py +++ b/lib/pt_strategy/live/live_strategy.py @@ -8,8 +8,7 @@ import pandas as pd # --- from cvttpy_tools.base import NamedObject from cvttpy_tools.app import App -from cvttpy_tools.logger import Log -from cvttpy_tools.settings.cvtt_types import JsonDictT +from cvttpy_tools.settings.cvtt_types import IntervalSecT # --- from cvttpy_trading.trading.instrument import ExchangeInstrument from cvttpy_trading.trading.mkt_data.md_summary import MdTradesAggregate @@ -18,7 +17,6 @@ from pairs_trading.lib.pt_strategy.model_data_policy import ModelDataPolicy from pairs_trading.lib.pt_strategy.pt_model import Prediction from pairs_trading.lib.pt_strategy.trading_pair import PairState, TradingPair from pairs_trading.apps.pairs_trader import PairsTrader -from pairs_trading.lib.pt_strategy.pt_market_data import RealTimeMarketData """ --config=pair.cfg --pair=PAIR-BTC-USDT:COINBASE_AT,PAIR-ETH-USDT:COINBASE_AT @@ -41,7 +39,6 @@ class PtLiveStrategy(NamedObject): model_data_policy_: ModelDataPolicy pairs_trader_: PairsTrader - pt_mkt_data_: RealTimeMarketData # ti_sender_: TradingInstructionsSender # for presentation: history of prediction values and trading signals @@ -90,27 +87,32 @@ class PtLiveStrategy(NamedObject): pass # URGENT PtiveStrategy.on_mkt_data_hist_snapshot() async def on_mkt_data_update(self, aggr: MdTradesAggregate) -> None: - market_data_df = await self.pt_mkt_data_.on_mkt_data_update(update=aggr) - if market_data_df is not None: - self.trading_pair_.market_data_ = market_data_df - self.model_data_policy_.advance() - prediction = self.trading_pair_.run( - market_data_df, self.model_data_policy_.advance() - ) - self.predictions_ = pd.concat( - [self.predictions_, prediction.to_df()], ignore_index=True - ) + # if market_data_df is not None: + # self.trading_pair_.market_data_ = market_data_df + # self.model_data_policy_.advance() + # prediction = self.trading_pair_.run( + # market_data_df, self.model_data_policy_.advance() + # ) + # self.predictions_ = pd.concat( + # [self.predictions_, prediction.to_df()], ignore_index=True + # ) - trading_instructions: List[TradingInstruction] = ( - self._create_trading_instructions( - prediction=prediction, last_row=market_data_df.iloc[-1] - ) - ) - if len(trading_instructions) > 0: - await self._send_trading_instructions(trading_instructions) - # trades = self._create_trades(prediction=prediction, last_row=market_data_df.iloc[-1]) - pass + # trading_instructions: List[TradingInstruction] = ( + # self._create_trading_instructions( + # prediction=prediction, last_row=market_data_df.iloc[-1] + # ) + # ) + # if len(trading_instructions) > 0: + # await self._send_trading_instructions(trading_instructions) + # # trades = self._create_trades(prediction=prediction, last_row=market_data_df.iloc[-1]) + pass # URGENT + def interval_sec(self) -> IntervalSecT: + return 60 # URGENT use config + + def history_depth_sec(self) -> IntervalSecT: + return 3600 * 60 * 2 # URGENT use config + async def _send_trading_instructions( self, trading_instructions: List[TradingInstruction] ) -> None: diff --git a/lib/pt_strategy/live/pricer_md_client.py.md b/lib/pt_strategy/live/pricer_md_client.py.md deleted file mode 100644 index c60d37b..0000000 --- a/lib/pt_strategy/live/pricer_md_client.py.md +++ /dev/null @@ -1,87 +0,0 @@ -```python -from __future__ import annotations - -from functools import partial -from typing import Dict, List - -# from cvtt_client.mkt_data import (CvttPricerWebSockClient, -# CvttPricesSubscription, MessageTypeT, -# SubscriptionIdT) -from cvttpy_tools.app import App -from cvttpy_tools.base import NamedObject -from cvttpy_tools.config import Config -from cvttpy_tools.logger import Log -from cvttpy_tools.settings.cvtt_types import JsonDictT -from pairs_trading.lib.pt_strategy.live.live_strategy import PtLiveStrategy -from pairs_trading.lib.pt_strategy.trading_pair import TradingPair - -""" - --config=pair.cfg - --pair=PAIR-BTC-USDT:COINBASE_AT,PAIR-ETH-USDT:COINBASE_AT -""" - - -class PtMktDataClient(NamedObject): - config_: Config - live_strategy_: PtLiveStrategy - pricer_client_: CvttPricerWebSockClient - subscriptions_: List[CvttPricesSubscription] - - def __init__(self, live_strategy: PtLiveStrategy, pricer_config: Config): - self.config_ = pricer_config - self.live_strategy_ = live_strategy - - App.instance().add_call(App.Stage.Start, self._on_start()) - App.instance().add_call(App.Stage.Run, self.run()) - - async def _on_start(self) -> None: - pricer_url = self.config_.get_value("pricer_url") - assert pricer_url is not None, "pricer_url is not found in config" - self.pricer_client_ = CvttPricerWebSockClient(url=pricer_url) - - - async def _subscribe(self) -> None: - history_depth_sec = self.config_.get_value("history_depth_sec", 86400) - interval_sec = self.config_.get_value("interval_sec", 60) - - pair: TradingPair = self.live_strategy_.trading_pair_ - subscriptions = [CvttPricesSubscription( - exchange_config_name=instrument["exchange_config_name"], - instrument_id=instrument["instrument_id"], - interval_sec=interval_sec, - history_depth_sec=history_depth_sec, - callback=partial( - self.on_message, instrument_id=instrument["instrument_id"] - ), - ) for instrument in pair.instruments_] - - for subscription in subscriptions: - Log.info(f"{self.fname()} Subscribing to {subscription}") - await self.pricer_client_.subscribe(subscription) - - async def on_message( - self, - message_type: MessageTypeT, - subscr_id: SubscriptionIdT, - message: Dict, - instrument_id: str, - ) -> None: - Log.info(f"{self.fname()}: {message_type=} {subscr_id=} {instrument_id}") - aggr: JsonDictT - if message_type == "md_aggregate": - aggr = message.get("md_aggregate", {}) - await self.live_strategy_.on_mkt_data_update(aggr) - elif message_type == "historical_md_aggregate": - aggr = message.get("historical_data", {}) - await self.live_strategy_.on_mkt_data_hist_snapshot(aggr) - else: - Log.info(f"Unknown message type: {message_type}") - - async def run(self) -> None: - if not await CvttPricerWebSockClient.check_connection(self.pricer_client_.ws_url_): - Log.error(f"Unable to connect to {self.pricer_client_.ws_url_}") - raise Exception(f"Unable to connect to {self.pricer_client_.ws_url_}") - await self._subscribe() - - await self.pricer_client_.run() -``` \ No newline at end of file diff --git a/lib/pt_strategy/model_data_policy.py b/lib/pt_strategy/model_data_policy.py index 8c1923e..b7de83d 100644 --- a/lib/pt_strategy/model_data_policy.py +++ b/lib/pt_strategy/model_data_policy.py @@ -91,7 +91,7 @@ class OptimizedWndDataPolicy(ModelDataPolicy, ABC): self.min_training_size_ = cast(int, config.get("min_training_size")) self.max_training_size_ = cast(int, config.get("max_training_size")) - from pt_strategy.trading_pair import TradingPair + from pairs_trading.lib.pt_strategy.trading_pair import TradingPair self.pair_ = cast(TradingPair, kwargs.get("pair")) if "mkt_data" in kwargs: diff --git a/lib/pt_strategy/models.py b/lib/pt_strategy/models.py index de343ca..ccc9264 100644 --- a/lib/pt_strategy/models.py +++ b/lib/pt_strategy/models.py @@ -6,8 +6,8 @@ import statsmodels.api as sm -from pt_strategy.pt_model import PairsTradingModel, Prediction -from pt_strategy.trading_pair import TradingPair +from pairs_trading.lib.pt_strategy.pt_model import PairsTradingModel, Prediction +from pairs_trading.lib.pt_strategy.trading_pair import TradingPair class OLSModel(PairsTradingModel): diff --git a/lib/pt_strategy/pt_market_data.py b/lib/pt_strategy/pt_market_data.py index da13a2e..edf0dd6 100644 --- a/lib/pt_strategy/pt_market_data.py +++ b/lib/pt_strategy/pt_market_data.py @@ -171,63 +171,3 @@ class ResearchMarketData(PtMarketData): f"exec_price_{self.symbol_b_}", ] -class RealTimeMarketData(PtMarketData): - - def __init__(self, config: Dict[str, Any], *args: Any, **kwargs: Any): - super().__init__(config, *args, **kwargs) - - async def on_mkt_data_hist_snapshot(self, snapshot: JsonDictT) -> None: - # URGENT - # create origin_mkt_data_df_ from snapshot - # verify that the data for both instruments are present - - # transform it to market_data_df_ tstamp, close_symbolA, close_symbolB - ''' - # from cvttpy/exchanges/binance/spot/mkt_data.py - values = { - "time_ns": time_ns, - "tstamp": format_nanos_utc(time_ns), - "exchange_id": exch_inst.exchange_id_, - "instrument_id": exch_inst.instrument_id(), - "interval_ns": interval_sec * 1_000_000_000, - "open": float(kline[1]), - "high": float(kline[2]), - "low": float(kline[3]), - "close": float(kline[4]), - "volume": float(kline[5]), - "num_trades": kline[8], - "vwap": float(kline[7]) / float(kline[5]) if float(kline[5]) > 0 else 0.0 # Calculate VWAP - } - ''' - - - pass - - async def on_mkt_data_update(self, update: MdTradesAggregate) -> Optional[pd.DataFrame]: - # URGENT - # make sure update has both instruments - # create DataFrame tmp1 from update - # transform tmp1 into temp. datframe tmp2 - # add tmp1 to origin_mkt_data_df_ - # add tmp2 to market_data_df_ - # return market_data_df_ - ''' - class MdTradesAggregate(NamedObject): - def to_dict(self) -> Dict[str, Any]: - return { - "time_ns": self.time_ns_, - "tstamp": format_nanos_utc(self.time_ns_), - "exchange_id": self.exch_inst_.exchange_id_, - "instrument_id": self.exch_inst_.instrument_id(), - "interval_ns": self.interval_ns_, - "open": self.exch_inst_.get_price(self.open_), - "high": self.exch_inst_.get_price(self.high_), - "low": self.exch_inst_.get_price(self.low_), - "close": self.exch_inst_.get_price(self.close_), - "volume": self.exch_inst_.get_quantity(self.volume_), - "vwap": self.exch_inst_.get_price(self.vwap_), - "num_trades": self.exch_inst_.get_quantity(self.num_trades_), - } - ''' - - return pd.DataFrame() \ No newline at end of file diff --git a/lib/pt_strategy/research_strategy.py b/lib/pt_strategy/research_strategy.py index 4f4fafc..f521c34 100644 --- a/lib/pt_strategy/research_strategy.py +++ b/lib/pt_strategy/research_strategy.py @@ -3,10 +3,10 @@ from __future__ import annotations from typing import Any, Dict, List, Optional import pandas as pd -from pt_strategy.model_data_policy import ModelDataPolicy -from pt_strategy.pt_market_data import ResearchMarketData -from pt_strategy.pt_model import Prediction -from pt_strategy.trading_pair import PairState, TradingPair +from pairs_trading.lib.pt_strategy.model_data_policy import ModelDataPolicy +from pairs_trading.lib.pt_strategy.pt_market_data import ResearchMarketData +from pairs_trading.lib.pt_strategy.pt_model import Prediction +from pairs_trading.lib.pt_strategy.trading_pair import PairState, TradingPair class PtResearchStrategy: @@ -24,8 +24,8 @@ class PtResearchStrategy: datafiles: List[str], instruments: List[Dict[str, str]], ): - from pt_strategy.model_data_policy import ModelDataPolicy - from pt_strategy.trading_pair import TradingPair + from pairs_trading.lib.pt_strategy.model_data_policy import ModelDataPolicy + from pairs_trading.lib.pt_strategy.trading_pair import TradingPair self.config_ = config self.trades_ = [] diff --git a/lib/pt_strategy/results.py b/lib/pt_strategy/results.py index 9e58042..af64c5e 100644 --- a/lib/pt_strategy/results.py +++ b/lib/pt_strategy/results.py @@ -4,7 +4,7 @@ from datetime import date, datetime from typing import Any, Dict, List, Optional, Tuple import pandas as pd -from pt_strategy.trading_pair import TradingPair +from pairs_trading.lib.pt_strategy.trading_pair import TradingPair # Recommended replacement adapters and converters for Python 3.12+ diff --git a/lib/pt_strategy/trading_pair.py b/lib/pt_strategy/trading_pair.py index 2c60801..4a243c2 100644 --- a/lib/pt_strategy/trading_pair.py +++ b/lib/pt_strategy/trading_pair.py @@ -57,7 +57,7 @@ class TradingPair: instruments: List[Dict[str, str]], ): - from pt_strategy.pt_model import PairsTradingModel + from pairs_trading.lib.pt_strategy.pt_model import PairsTradingModel assert len(instruments) == 2, "Trading pair must have exactly 2 instruments" diff --git a/lib/tools/viz/viz_prices.py b/lib/tools/viz/viz_prices.py index 418426e..e84bae5 100644 --- a/lib/tools/viz/viz_prices.py +++ b/lib/tools/viz/viz_prices.py @@ -1,4 +1,4 @@ -from pt_strategy.research_strategy import PtResearchStrategy +from pairs_trading.lib.pt_strategy.research_strategy import PtResearchStrategy def visualize_prices(strategy: PtResearchStrategy, trading_date: str) -> None: diff --git a/lib/tools/viz/viz_trades.py b/lib/tools/viz/viz_trades.py index 274fbd1..147c278 100644 --- a/lib/tools/viz/viz_trades.py +++ b/lib/tools/viz/viz_trades.py @@ -3,11 +3,11 @@ from __future__ import annotations import os from typing import Any, Dict -from pt_strategy.results import (PairResearchResult, create_result_database, +from pairs_trading.lib.pairs_trading.lib.tegy.results import (PairResearchResult, create_result_database, store_config_in_database) -from pt_strategy.research_strategy import PtResearchStrategy -from tools.filetools import resolve_datafiles -from tools.instruments import get_instruments +from pairs_trading.lib.pairs_trading.lib.t_strategy.research_strategy import PtResearchStrategy +from pairs_trading.lib.tools.filetools import resolve_datafiles +from pairs_trading.lib.tools.instruments import get_instruments def visualize_trades(strategy: PtResearchStrategy, results: PairResearchResult, trading_date: str) -> None: diff --git a/requirements.txt b/requirements.txt index 99f9cc2..61e2b68 100644 --- a/requirements.txt +++ b/requirements.txt @@ -170,6 +170,7 @@ types-PyYAML>=5.4 types-redis>=3.5 types-requests>=2.25 types-retry>=0.9 +types-seaborn>0.13.2 types-selenium>=3.141 types-Send2Trash>=1.8 types-setuptools>=57.4 diff --git a/research/backtest.py b/research/backtest.py index 4e4f202..d023a0e 100644 --- a/research/backtest.py +++ b/research/backtest.py @@ -3,20 +3,20 @@ from __future__ import annotations import os from typing import Any, Dict -from pt_strategy.results import ( +from pairs_trading.lib.pt_strategy.results import ( PairResearchResult, create_result_database, store_config_in_database, ) -from pt_strategy.research_strategy import PtResearchStrategy -from tools.filetools import resolve_datafiles -from tools.instruments import get_instruments +from pairs_trading.lib.pt_strategy.research_strategy import PtResearchStrategy +from pairs_trading.lib.tools.filetools import resolve_datafiles +from pairs_trading.lib.tools.instruments import get_instruments def main() -> None: import argparse - from tools.config import expand_filename, load_config + from pairs_trading.lib.tools.config import expand_filename, load_config parser = argparse.ArgumentParser(description="Run pairs trading backtest.") parser.add_argument( diff --git a/research/notebooks/pair_trading_test.ipynb b/research/notebooks/pair_trading_test.ipynb index c08187b..0bcbff6 100644 --- a/research/notebooks/pair_trading_test.ipynb +++ b/research/notebooks/pair_trading_test.ipynb @@ -117,7 +117,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -135,8 +135,8 @@ " from IPython.display import clear_output\n", "\n", " # Import our modules\n", - " from pt_strategy.trading_pair import TradingPair, PairState\n", - " from pt_strategy.results import PairResearchResult\n", + " from pairs_trading.lib.pairs_trading.lib.pairs_trading.lib.pt_strategy.trading_pair import TradingPair, PairState\n", + " from pairs_trading.lib.pairs_trading.lib.pairs_trading.lib.pt_strategy.results import PairResearchResult\n", "\n", " pd.set_option('display.width', 400)\n", " pd.set_option('display.max_colwidth', None)\n", @@ -301,9 +301,9 @@ " global PT_RESULTS\n", "\n", " \n", - " from pt_strategy.trading_pair import TradingPair\n", - " from pt_strategy.research_strategy import PtResearchStrategy\n", - " from pt_strategy.results import PairResearchResult\n", + " from pairs_trading.lib.pt_strategy.trading_pair import TradingPair\n", + " from pairs_trading.lib.pt_strategy.research_strategy import PtResearchStrategy\n", + " from pairs_trading.lib.pt_strategy.results import PairResearchResult\n", "\n", " # Create trading pair\n", " PT_RESULTS = PairResearchResult(config=PT_BT_CONFIG)\n", diff --git a/research/tools/research_tools.py b/research/tools/research_tools.py index a85aef3..176faf7 100644 --- a/research/tools/research_tools.py +++ b/research/tools/research_tools.py @@ -3,7 +3,7 @@ import os from typing import Dict, List, Optional import pandas as pd -from pt_trading.fit_method import PairsTradingFitMethod +from pairs_trading.lib.pt_trading.fit_method import PairsTradingFitMethod def resolve_datafiles(config: Dict, cli_datafiles: Optional[str] = None) -> List[str]: diff --git a/tests/viz_test.py b/tests/viz_test.py index e35f5b7..87abc4b 100644 --- a/tests/viz_test.py +++ b/tests/viz_test.py @@ -3,18 +3,18 @@ from __future__ import annotations import os from typing import Any, Dict -from pt_strategy.results import (PairResearchResult, create_result_database, +from pairs_trading.lib.pt_strategy.results import (PairResearchResult, create_result_database, store_config_in_database) -from pt_strategy.research_strategy import PtResearchStrategy -from tools.filetools import resolve_datafiles -from tools.instruments import get_instruments -from tools.viz.viz_trades import visualize_trades +from pairs_trading.lib.pt_strategy.research_strategy import PtResearchStrategy +from pairs_trading.lib.tools.filetools import resolve_datafiles +from pairs_trading.lib.tools.instruments import get_instruments +from pairs_trading.lib.tools.viz.viz_trades import visualize_trades def main() -> None: import argparse - from tools.config import expand_filename, load_config + from pairs_trading.lib.tools.config import expand_filename, load_config parser = argparse.ArgumentParser(description="Run pairs trading backtest.") parser.add_argument(