dev progress
This commit is contained in:
parent
2e32b26fad
commit
121c85def0
5
.vscode/pairs_trading.code-workspace
vendored
5
.vscode/pairs_trading.code-workspace
vendored
@ -3,5 +3,8 @@
|
||||
{
|
||||
"path": ".."
|
||||
}
|
||||
]
|
||||
],
|
||||
"settings": {
|
||||
"workbench.colorTheme": "Dracula Theme"
|
||||
}
|
||||
}
|
||||
@ -15,7 +15,7 @@ from cvttpy_trading.trading.active_instruments import Instruments
|
||||
from cvttpy_trading.trading.mkt_data.md_summary import MdTradesAggregate
|
||||
# ---
|
||||
from pairs_trading.lib.pt_strategy.live.live_strategy import PtLiveStrategy
|
||||
# from pairs_trading.lib.pt_strategy.live.pricer_md_client import PtMktDataClient
|
||||
from pairs_trading.lib.live.mkt_data_client import CvttRestMktDataClient, MdSummary
|
||||
from pairs_trading.lib.pt_strategy.live.ti_sender import TradingInstructionsSender
|
||||
|
||||
# import sys
|
||||
@ -23,6 +23,20 @@ from pairs_trading.lib.pt_strategy.live.ti_sender import TradingInstructionsSend
|
||||
# for path in sys.path:
|
||||
# print(path)
|
||||
|
||||
'''
|
||||
Config
|
||||
=======
|
||||
{
|
||||
"cvtt_base_url": "http://cvtt-tester-01.cvtt.vpn:23456",
|
||||
"ti_config": {
|
||||
TODO
|
||||
},
|
||||
"strategy_config": {
|
||||
TODO
|
||||
}
|
||||
}
|
||||
'''
|
||||
|
||||
HistMdCbT = Callable[[List[MdTradesAggregate]], Coroutine]
|
||||
UpdateMdCbT = Callable[[MdTradesAggregate], Coroutine]
|
||||
|
||||
@ -31,7 +45,7 @@ class PairsTrader(NamedObject):
|
||||
instruments_: List[JsonDictT]
|
||||
|
||||
live_strategy_: PtLiveStrategy
|
||||
# pricer_client_: PtMktDataClient
|
||||
pricer_client_: CvttRestMktDataClient
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.instruments_ = []
|
||||
@ -90,16 +104,10 @@ class PairsTrader(NamedObject):
|
||||
Log.info(f"{self.fname()} Strategy created: {self.live_strategy_}")
|
||||
|
||||
# # ------- CREATE PRICER CLIENT -------
|
||||
# URGENT
|
||||
# pricer_config = self.config_.get_subconfig("pricer_config", {})
|
||||
# self.pricer_client_ = PtMktDataClient(
|
||||
# live_strategy=self.live_strategy_,
|
||||
# pricer_config=pricer_config
|
||||
# )
|
||||
# Log.info(f"{self.fname()} CVTT Pricer client created: {self.pricer_client_}")
|
||||
self.pricer_client_ = CvttRestMktDataClient(self.config_)
|
||||
|
||||
# ------- CREATE TRADER CLIENT -------
|
||||
# URGENT
|
||||
# URGENT CREATE TRADER CLIENT
|
||||
# (send TradingInstructions)
|
||||
# ti_config = self.config_.get_subconfig("ti_config", {})
|
||||
# self.ti_sender_ = TradingInstructionsSender(config=ti_config)
|
||||
@ -107,11 +115,22 @@ class PairsTrader(NamedObject):
|
||||
|
||||
|
||||
# # ------- CREATE REST SERVER -------
|
||||
# for dashboard communications
|
||||
# URGENT CREATE REST SERVER for dashboard communications
|
||||
|
||||
async def subscribe_md(self) -> None:
|
||||
pass
|
||||
# URGENT implement PairsTrader.subscribe_md()
|
||||
for inst in self.instruments_:
|
||||
exch_acct = inst.get("exch_acct", "?exch_acct?")
|
||||
instrument_id = inst.get("instrument_id", "?instrument_id?")
|
||||
await self.pricer_client_.add_subscription(
|
||||
exch_acct=exch_acct,
|
||||
instrument_id=instrument_id,
|
||||
interval_sec=self.live_strategy_.interval_sec(),
|
||||
history_depth_sec=self.live_strategy_.history_depth_sec(),
|
||||
callback=self._on_md_summary
|
||||
)
|
||||
|
||||
def _on_md_summary(self, history: List[MdSummary]) -> None:
|
||||
pass # URGENT
|
||||
|
||||
async def run(self) -> None:
|
||||
Log.info(f"{self.fname()} ...")
|
||||
|
||||
@ -1,220 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Callable, Coroutine, Dict, Optional
|
||||
|
||||
import websockets
|
||||
from cvttpy_tools.logger import Log
|
||||
from cvttpy_tools.settings.cvtt_types import JsonDictT
|
||||
from websockets.asyncio.client import ClientConnection
|
||||
|
||||
MessageTypeT = str
|
||||
SubscriptionIdT = str
|
||||
MessageT = Dict
|
||||
UrlT = str
|
||||
CallbackT = Callable[[MessageTypeT, SubscriptionIdT, MessageT], Coroutine[None, str, None]]
|
||||
|
||||
@dataclass
|
||||
class CvttPricesSubscription:
|
||||
id_: str
|
||||
exchange_config_name_: str
|
||||
instrument_id_: str
|
||||
interval_sec_: int
|
||||
history_depth_sec_: int
|
||||
is_subscribed_: bool
|
||||
is_historical_: bool
|
||||
callback_: CallbackT
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
exchange_config_name: str,
|
||||
instrument_id: str,
|
||||
interval_sec: int,
|
||||
history_depth_sec: int,
|
||||
callback: CallbackT,
|
||||
):
|
||||
self.exchange_config_name_ = exchange_config_name
|
||||
self.instrument_id_ = instrument_id
|
||||
self.interval_sec_ = interval_sec
|
||||
self.history_depth_sec_ = history_depth_sec
|
||||
self.callback_ = callback
|
||||
self.id_ = str(uuid.uuid4())
|
||||
self.is_subscribed_ = False
|
||||
self.is_historical_ = history_depth_sec > 0
|
||||
|
||||
class CvttWebSockClient:
|
||||
ws_url_: UrlT
|
||||
websocket_: Optional[ClientConnection]
|
||||
is_connected_: bool
|
||||
|
||||
def __init__(self, url: str):
|
||||
self.ws_url_ = url
|
||||
self.websocket_ = None
|
||||
self.is_connected_ = False
|
||||
|
||||
async def connect(self) -> None:
|
||||
self.websocket_ = await websockets.connect(self.ws_url_)
|
||||
self.is_connected_ = True
|
||||
|
||||
async def close(self) -> None:
|
||||
if self.websocket_ is not None:
|
||||
await self.websocket_.close()
|
||||
self.is_connected_ = False
|
||||
|
||||
async def receive_message(self) -> JsonDictT:
|
||||
assert self.websocket_ is not None
|
||||
assert self.is_connected_
|
||||
message = await self.websocket_.recv()
|
||||
message_str = (
|
||||
message.decode("utf-8")
|
||||
if isinstance(message, bytes)
|
||||
else message
|
||||
)
|
||||
res = json.loads(message_str)
|
||||
assert res is not None
|
||||
assert isinstance(res, dict)
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
async def check_connection(cls, url: str) -> bool:
|
||||
try:
|
||||
async with websockets.connect(url) as websocket:
|
||||
result = True
|
||||
except Exception as e:
|
||||
Log.error(f"Unable to connect to {url}: {str(e)}")
|
||||
result = False
|
||||
return result
|
||||
|
||||
class CvttPricerWebSockClient(CvttWebSockClient):
|
||||
# Class members with type hints
|
||||
subscriptions_: Dict[SubscriptionIdT, CvttPricesSubscription]
|
||||
|
||||
def __init__(self, url: str):
|
||||
super().__init__(url)
|
||||
self.subscriptions_ = {}
|
||||
|
||||
async def subscribe(
|
||||
self, subscription: CvttPricesSubscription
|
||||
) -> str: # returns subscription id
|
||||
|
||||
if not self.is_connected_:
|
||||
try:
|
||||
Log.info(f"Connecting to {self.ws_url_}")
|
||||
await self.connect()
|
||||
except Exception as e:
|
||||
Log.error(f"Unable to connect to {self.ws_url_}: {str(e)}")
|
||||
raise e
|
||||
|
||||
subscr_msg = {
|
||||
"type": "subscr",
|
||||
"id": subscription.id_,
|
||||
"subscr_type": "MD_AGGREGATE",
|
||||
"exchange_config_name": subscription.exchange_config_name_,
|
||||
"instrument_id": subscription.instrument_id_,
|
||||
"interval_sec": subscription.interval_sec_,
|
||||
}
|
||||
if subscription.is_historical_:
|
||||
subscr_msg["history_depth_sec"] = subscription.history_depth_sec_
|
||||
|
||||
assert self.websocket_ is not None
|
||||
await self.websocket_.send(json.dumps(subscr_msg))
|
||||
|
||||
response = await self.websocket_.recv()
|
||||
response_data = json.loads(response)
|
||||
if not await self.handle_subscription_response(subscription, response_data):
|
||||
await self.websocket_.close()
|
||||
self.is_connected_ = False
|
||||
raise Exception(f"Subscription failed: {str(response)}")
|
||||
|
||||
self.subscriptions_[subscription.id_] = subscription
|
||||
return subscription.id_
|
||||
|
||||
async def handle_subscription_response(
|
||||
self, subscription: CvttPricesSubscription, response: dict
|
||||
) -> bool:
|
||||
if response.get("type") != "subscr" or response.get("id") != subscription.id_:
|
||||
return False
|
||||
|
||||
if response.get("status") == "success":
|
||||
Log.info(f"Subscription successful: {json.dumps(response)}")
|
||||
return True
|
||||
elif response.get("status") == "error":
|
||||
Log.error(f"Subscription failed: {response.get('reason')}")
|
||||
return False
|
||||
return False
|
||||
|
||||
async def run(self) -> None:
|
||||
assert self.websocket_
|
||||
try:
|
||||
while self.is_connected_:
|
||||
try:
|
||||
msg_dict: JsonDictT = await self.receive_message()
|
||||
except websockets.ConnectionClosed:
|
||||
Log.warning("Connection closed")
|
||||
self.is_connected_ = False
|
||||
break
|
||||
except Exception as e:
|
||||
Log.error(f"Error occurred: {str(e)}")
|
||||
self.is_connected_ = False
|
||||
await asyncio.sleep(5) # Wait before reconnecting
|
||||
|
||||
await self.process_message(msg_dict)
|
||||
|
||||
except Exception as e:
|
||||
Log.error(f"Error occurred: {str(e)}")
|
||||
self.is_connected_ = False
|
||||
await asyncio.sleep(5) # Wait before reconnecting
|
||||
|
||||
async def process_message(self, message: Dict) -> None:
|
||||
message_type = message.get("type")
|
||||
if message_type in ["md_aggregate", "historical_md_aggregate"]:
|
||||
subscription_id = message.get("subscr_id")
|
||||
if subscription_id not in self.subscriptions_:
|
||||
Log.warning(f"Unknown subscription id: {subscription_id}")
|
||||
return
|
||||
|
||||
subscription = self.subscriptions_[subscription_id]
|
||||
await subscription.callback_(message_type, subscription_id, message)
|
||||
else:
|
||||
Log.warning(f"Unknown message type: {message.get('type')}")
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
async def on_message(message_type: MessageTypeT, subscr_id: SubscriptionIdT, message: Dict, instrument_id: str) -> None:
|
||||
print(f"{message_type=} {subscr_id=} {instrument_id}")
|
||||
if message_type == "md_aggregate":
|
||||
aggr = message.get("md_aggregate", [])
|
||||
print(f"[{aggr['tstamp'][:19]}] *** RLTM *** {message}")
|
||||
elif message_type == "historical_md_aggregate":
|
||||
for aggr in message.get("historical_data", []):
|
||||
print(f"[{aggr['tstamp'][:19]}] *** HIST *** {aggr}")
|
||||
else:
|
||||
print(f"Unknown message type: {message_type}")
|
||||
|
||||
pricer_client = CvttPricerWebSockClient(
|
||||
"ws://localhost:12346/ws"
|
||||
)
|
||||
await pricer_client.subscribe(CvttPricesSubscription(
|
||||
exchange_config_name="COINBASE_AT",
|
||||
instrument_id="PAIR-BTC-USD",
|
||||
interval_sec=60,
|
||||
history_depth_sec=60*60*24,
|
||||
callback=partial(on_message, instrument_id="PAIR-BTC-USD")
|
||||
))
|
||||
await pricer_client.subscribe(CvttPricesSubscription(
|
||||
exchange_config_name="COINBASE_AT",
|
||||
instrument_id="PAIR-ETH-USD",
|
||||
interval_sec=60,
|
||||
history_depth_sec=60*60*24,
|
||||
callback=partial(on_message, instrument_id="PAIR-ETH-USD")
|
||||
))
|
||||
|
||||
await pricer_client.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@ -1,70 +1,24 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Callable, Dict, Any, List, Optional
|
||||
import time
|
||||
from typing import Callable, Dict, Any, List, Optional, Set
|
||||
|
||||
import requests
|
||||
|
||||
from cvttpy_tools.base import NamedObject
|
||||
from cvttpy_tools.app import App
|
||||
from cvttpy_tools.logger import Log
|
||||
from cvttpy_tools.config import Config
|
||||
from cvttpy_tools.timer import Timer
|
||||
|
||||
from cvttpy_tools.timeutils import NanoPerSec, NanosT, current_nanoseconds, current_seconds
|
||||
from cvttpy_tools.timeutils import NanosT, current_seconds
|
||||
from cvttpy_tools.settings.cvtt_types import InstrumentIdT, IntervalSecT
|
||||
# ---
|
||||
from cvttpy_trading.trading.mkt_data.historical_md import HistMdBar
|
||||
from cvttpy_trading.trading.accounting.exch_account import ExchangeAccountNameT
|
||||
# ---
|
||||
from pairs_trading.lib.live.rest_client import RESTSender
|
||||
|
||||
|
||||
class RESTSender(NamedObject):
|
||||
session_: requests.Session
|
||||
base_url_: str
|
||||
|
||||
def __init__(self, base_url: str) -> None:
|
||||
self.base_url_ = base_url
|
||||
self.session_ = requests.Session()
|
||||
|
||||
def is_ready(self) -> bool:
|
||||
"""Checks if the server is up and responding"""
|
||||
url = f"{self.base_url_}/ping"
|
||||
try:
|
||||
response = self.session_.get(url)
|
||||
response.raise_for_status()
|
||||
return True
|
||||
except requests.exceptions.RequestException:
|
||||
return False
|
||||
|
||||
def send_post(self, endpoint: str, post_body: Dict) -> requests.Response:
|
||||
|
||||
while not self.is_ready():
|
||||
print("Waiting for FrontGateway to start...")
|
||||
time.sleep(5)
|
||||
|
||||
url = f"{self.base_url_}/{endpoint}"
|
||||
try:
|
||||
return self.session_.request(
|
||||
method="POST",
|
||||
url=url,
|
||||
json=post_body,
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
except requests.exceptions.RequestException as excpt:
|
||||
raise ConnectionError(
|
||||
f"Failed to send status={excpt.response.status_code} {excpt.response.text}" # type: ignore
|
||||
) from excpt
|
||||
|
||||
def send_get(self, endpoint: str) -> requests.Response:
|
||||
while not self.is_ready():
|
||||
print("Waiting for FrontGateway to start...")
|
||||
time.sleep(5)
|
||||
|
||||
url = f"{self.base_url_}/{endpoint}"
|
||||
try:
|
||||
return self.session_.request(method="GET", url=url)
|
||||
except requests.exceptions.RequestException as excpt:
|
||||
raise ConnectionError(
|
||||
f"Failed to send status={excpt.response.status_code} {excpt.response.text}" # type: ignore
|
||||
) from excpt
|
||||
|
||||
class MdSummary(HistMdBar):
|
||||
def __init__(
|
||||
self,
|
||||
@ -110,10 +64,10 @@ MdSummaryCallbackT = Callable[[List[MdSummary]], None]
|
||||
|
||||
class MdSummaryCollector(NamedObject):
|
||||
sender_: RESTSender
|
||||
exch_acct_: str
|
||||
instrument_id_: str
|
||||
interval_sec_: int
|
||||
history_depth_sec_: int
|
||||
exch_acct_: ExchangeAccountNameT
|
||||
instrument_id_: InstrumentIdT
|
||||
interval_sec_: IntervalSecT
|
||||
history_depth_sec_: IntervalSecT
|
||||
|
||||
history_: List[MdSummary]
|
||||
callbacks_: List[MdSummaryCallbackT]
|
||||
@ -122,10 +76,10 @@ class MdSummaryCollector(NamedObject):
|
||||
def __init__(
|
||||
self,
|
||||
sender: RESTSender,
|
||||
exch_acct: str,
|
||||
instrument_id: str,
|
||||
interval_sec: int,
|
||||
history_depth_sec: int,
|
||||
exch_acct: ExchangeAccountNameT,
|
||||
instrument_id: InstrumentIdT,
|
||||
interval_sec: IntervalSecT,
|
||||
history_depth_sec: IntervalSecT,
|
||||
) -> None:
|
||||
self.sender_ = sender
|
||||
self.exch_acct_ = exch_acct
|
||||
@ -140,6 +94,9 @@ class MdSummaryCollector(NamedObject):
|
||||
def add_callback(self, cb: MdSummaryCallbackT) -> None:
|
||||
self.callbacks_.append(cb)
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.exch_acct_, self.instrument_id_, self.interval_sec_, self.history_depth_sec_))
|
||||
|
||||
def rqst_data(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"exch_acct": self.exch_acct_,
|
||||
@ -192,7 +149,6 @@ class MdSummaryCollector(NamedObject):
|
||||
curr_sec = int(current_seconds())
|
||||
return (curr_sec - curr_sec % self.interval_sec_) + self.interval_sec_ + 2
|
||||
|
||||
|
||||
async def _load_new(self) -> None:
|
||||
|
||||
last: Optional[MdSummary] = self.get_last()
|
||||
@ -213,39 +169,54 @@ class MdSummaryCollector(NamedObject):
|
||||
self.timer_.cancel()
|
||||
self.timer_ = None
|
||||
|
||||
class CvttRESTClient(NamedObject):
|
||||
class CvttRestMktDataClient(NamedObject):
|
||||
config_: Config
|
||||
sender_: RESTSender
|
||||
collectors_: Set[MdSummaryCollector]
|
||||
|
||||
def __init__(self, config: Config) -> None:
|
||||
self.config_ = config
|
||||
base_url = self.config_.get_value("cvtt_base_url", default="")
|
||||
assert base_url
|
||||
self.sender_ = RESTSender(base_url=base_url)
|
||||
self.collectors_ = set()
|
||||
|
||||
async def add_subscription(self,
|
||||
exch_acct: ExchangeAccountNameT,
|
||||
instrument_id: InstrumentIdT,
|
||||
interval_sec: IntervalSecT,
|
||||
history_depth_sec: IntervalSecT,
|
||||
callback: MdSummaryCallbackT
|
||||
) -> None:
|
||||
mdsc = MdSummaryCollector(
|
||||
sender=self.sender_,
|
||||
exch_acct=exch_acct,
|
||||
instrument_id=instrument_id,
|
||||
interval_sec=interval_sec,
|
||||
history_depth_sec=history_depth_sec,
|
||||
)
|
||||
mdsc.add_callback(callback)
|
||||
self.collectors_.add(mdsc)
|
||||
await mdsc.start()
|
||||
|
||||
if __name__ == "__main__":
|
||||
config = Config(json_src={"cvtt_base_url": "http://cvtt-tester-01.cvtt.vpn:23456"})
|
||||
# config = Config(json_src={"cvtt_base_url": "http://dev-server-02.cvtt.vpn:23456"})
|
||||
|
||||
cvtt_client = CvttRESTClient(config)
|
||||
def _calback(history: List[MdSummary]) -> None:
|
||||
Log.info(f"MdSummary Hist Length is {len(history)}. Last summary: {history[-1] if len(history) > 0 else '[]'}")
|
||||
|
||||
mdsc = MdSummaryCollector(
|
||||
sender=cvtt_client.sender_,
|
||||
|
||||
async def __run() -> None:
|
||||
Log.info("Starting...")
|
||||
cvtt_client = CvttRestMktDataClient(config)
|
||||
await cvtt_client.add_subscription(
|
||||
exch_acct="COINBASE_AT",
|
||||
instrument_id="PAIR-BTC-USD",
|
||||
interval_sec=60,
|
||||
history_depth_sec=24 * 3600,
|
||||
callback=_calback
|
||||
)
|
||||
|
||||
def _calback(history: List[MdSummary]) -> None:
|
||||
Log.info(f"MdSummary Hist Length is {len(history)}. Last summary: {history[-1] if len(history) > 0 else '[]'}")
|
||||
|
||||
mdsc.add_callback(_calback)
|
||||
|
||||
async def __run() -> None:
|
||||
Log.info("Starting...")
|
||||
await mdsc.start()
|
||||
while True:
|
||||
await asyncio.sleep(5)
|
||||
|
||||
67
lib/live/rest_client.py
Normal file
67
lib/live/rest_client.py
Normal file
@ -0,0 +1,67 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Callable, Dict, Any, List, Optional
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
from cvttpy_tools.base import NamedObject
|
||||
from cvttpy_tools.logger import Log
|
||||
from cvttpy_tools.config import Config
|
||||
from cvttpy_tools.timer import Timer
|
||||
|
||||
from cvttpy_tools.timeutils import NanoPerSec, NanosT, current_nanoseconds, current_seconds
|
||||
from cvttpy_trading.trading.mkt_data.historical_md import HistMdBar
|
||||
|
||||
|
||||
class RESTSender(NamedObject):
|
||||
session_: requests.Session
|
||||
base_url_: str
|
||||
|
||||
def __init__(self, base_url: str) -> None:
|
||||
self.base_url_ = base_url
|
||||
self.session_ = requests.Session()
|
||||
|
||||
def is_ready(self) -> bool:
|
||||
"""Checks if the server is up and responding"""
|
||||
url = f"{self.base_url_}/ping"
|
||||
try:
|
||||
response = self.session_.get(url)
|
||||
response.raise_for_status()
|
||||
return True
|
||||
except requests.exceptions.RequestException:
|
||||
return False
|
||||
|
||||
def send_post(self, endpoint: str, post_body: Dict) -> requests.Response:
|
||||
|
||||
while not self.is_ready():
|
||||
print("Waiting for FrontGateway to start...")
|
||||
time.sleep(5)
|
||||
|
||||
url = f"{self.base_url_}/{endpoint}"
|
||||
try:
|
||||
return self.session_.request(
|
||||
method="POST",
|
||||
url=url,
|
||||
json=post_body,
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
except requests.exceptions.RequestException as excpt:
|
||||
raise ConnectionError(
|
||||
f"Failed to send status={excpt.response.status_code} {excpt.response.text}" # type: ignore
|
||||
) from excpt
|
||||
|
||||
def send_get(self, endpoint: str) -> requests.Response:
|
||||
while not self.is_ready():
|
||||
print("Waiting for FrontGateway to start...")
|
||||
time.sleep(5)
|
||||
|
||||
url = f"{self.base_url_}/{endpoint}"
|
||||
try:
|
||||
return self.session_.request(method="GET", url=url)
|
||||
except requests.exceptions.RequestException as excpt:
|
||||
raise ConnectionError(
|
||||
f"Failed to send status={excpt.response.status_code} {excpt.response.text}" # type: ignore
|
||||
) from excpt
|
||||
|
||||
@ -8,8 +8,7 @@ import pandas as pd
|
||||
# ---
|
||||
from cvttpy_tools.base import NamedObject
|
||||
from cvttpy_tools.app import App
|
||||
from cvttpy_tools.logger import Log
|
||||
from cvttpy_tools.settings.cvtt_types import JsonDictT
|
||||
from cvttpy_tools.settings.cvtt_types import IntervalSecT
|
||||
# ---
|
||||
from cvttpy_trading.trading.instrument import ExchangeInstrument
|
||||
from cvttpy_trading.trading.mkt_data.md_summary import MdTradesAggregate
|
||||
@ -18,7 +17,6 @@ from pairs_trading.lib.pt_strategy.model_data_policy import ModelDataPolicy
|
||||
from pairs_trading.lib.pt_strategy.pt_model import Prediction
|
||||
from pairs_trading.lib.pt_strategy.trading_pair import PairState, TradingPair
|
||||
from pairs_trading.apps.pairs_trader import PairsTrader
|
||||
from pairs_trading.lib.pt_strategy.pt_market_data import RealTimeMarketData
|
||||
"""
|
||||
--config=pair.cfg
|
||||
--pair=PAIR-BTC-USDT:COINBASE_AT,PAIR-ETH-USDT:COINBASE_AT
|
||||
@ -41,7 +39,6 @@ class PtLiveStrategy(NamedObject):
|
||||
model_data_policy_: ModelDataPolicy
|
||||
pairs_trader_: PairsTrader
|
||||
|
||||
pt_mkt_data_: RealTimeMarketData
|
||||
# ti_sender_: TradingInstructionsSender
|
||||
|
||||
# for presentation: history of prediction values and trading signals
|
||||
@ -90,26 +87,31 @@ class PtLiveStrategy(NamedObject):
|
||||
pass # URGENT PtiveStrategy.on_mkt_data_hist_snapshot()
|
||||
|
||||
async def on_mkt_data_update(self, aggr: MdTradesAggregate) -> None:
|
||||
market_data_df = await self.pt_mkt_data_.on_mkt_data_update(update=aggr)
|
||||
if market_data_df is not None:
|
||||
self.trading_pair_.market_data_ = market_data_df
|
||||
self.model_data_policy_.advance()
|
||||
prediction = self.trading_pair_.run(
|
||||
market_data_df, self.model_data_policy_.advance()
|
||||
)
|
||||
self.predictions_ = pd.concat(
|
||||
[self.predictions_, prediction.to_df()], ignore_index=True
|
||||
)
|
||||
# if market_data_df is not None:
|
||||
# self.trading_pair_.market_data_ = market_data_df
|
||||
# self.model_data_policy_.advance()
|
||||
# prediction = self.trading_pair_.run(
|
||||
# market_data_df, self.model_data_policy_.advance()
|
||||
# )
|
||||
# self.predictions_ = pd.concat(
|
||||
# [self.predictions_, prediction.to_df()], ignore_index=True
|
||||
# )
|
||||
|
||||
trading_instructions: List[TradingInstruction] = (
|
||||
self._create_trading_instructions(
|
||||
prediction=prediction, last_row=market_data_df.iloc[-1]
|
||||
)
|
||||
)
|
||||
if len(trading_instructions) > 0:
|
||||
await self._send_trading_instructions(trading_instructions)
|
||||
# trades = self._create_trades(prediction=prediction, last_row=market_data_df.iloc[-1])
|
||||
pass
|
||||
# trading_instructions: List[TradingInstruction] = (
|
||||
# self._create_trading_instructions(
|
||||
# prediction=prediction, last_row=market_data_df.iloc[-1]
|
||||
# )
|
||||
# )
|
||||
# if len(trading_instructions) > 0:
|
||||
# await self._send_trading_instructions(trading_instructions)
|
||||
# # trades = self._create_trades(prediction=prediction, last_row=market_data_df.iloc[-1])
|
||||
pass # URGENT
|
||||
|
||||
def interval_sec(self) -> IntervalSecT:
|
||||
return 60 # URGENT use config
|
||||
|
||||
def history_depth_sec(self) -> IntervalSecT:
|
||||
return 3600 * 60 * 2 # URGENT use config
|
||||
|
||||
async def _send_trading_instructions(
|
||||
self, trading_instructions: List[TradingInstruction]
|
||||
|
||||
@ -1,87 +0,0 @@
|
||||
```python
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import partial
|
||||
from typing import Dict, List
|
||||
|
||||
# from cvtt_client.mkt_data import (CvttPricerWebSockClient,
|
||||
# CvttPricesSubscription, MessageTypeT,
|
||||
# SubscriptionIdT)
|
||||
from cvttpy_tools.app import App
|
||||
from cvttpy_tools.base import NamedObject
|
||||
from cvttpy_tools.config import Config
|
||||
from cvttpy_tools.logger import Log
|
||||
from cvttpy_tools.settings.cvtt_types import JsonDictT
|
||||
from pairs_trading.lib.pt_strategy.live.live_strategy import PtLiveStrategy
|
||||
from pairs_trading.lib.pt_strategy.trading_pair import TradingPair
|
||||
|
||||
"""
|
||||
--config=pair.cfg
|
||||
--pair=PAIR-BTC-USDT:COINBASE_AT,PAIR-ETH-USDT:COINBASE_AT
|
||||
"""
|
||||
|
||||
|
||||
class PtMktDataClient(NamedObject):
|
||||
config_: Config
|
||||
live_strategy_: PtLiveStrategy
|
||||
pricer_client_: CvttPricerWebSockClient
|
||||
subscriptions_: List[CvttPricesSubscription]
|
||||
|
||||
def __init__(self, live_strategy: PtLiveStrategy, pricer_config: Config):
|
||||
self.config_ = pricer_config
|
||||
self.live_strategy_ = live_strategy
|
||||
|
||||
App.instance().add_call(App.Stage.Start, self._on_start())
|
||||
App.instance().add_call(App.Stage.Run, self.run())
|
||||
|
||||
async def _on_start(self) -> None:
|
||||
pricer_url = self.config_.get_value("pricer_url")
|
||||
assert pricer_url is not None, "pricer_url is not found in config"
|
||||
self.pricer_client_ = CvttPricerWebSockClient(url=pricer_url)
|
||||
|
||||
|
||||
async def _subscribe(self) -> None:
|
||||
history_depth_sec = self.config_.get_value("history_depth_sec", 86400)
|
||||
interval_sec = self.config_.get_value("interval_sec", 60)
|
||||
|
||||
pair: TradingPair = self.live_strategy_.trading_pair_
|
||||
subscriptions = [CvttPricesSubscription(
|
||||
exchange_config_name=instrument["exchange_config_name"],
|
||||
instrument_id=instrument["instrument_id"],
|
||||
interval_sec=interval_sec,
|
||||
history_depth_sec=history_depth_sec,
|
||||
callback=partial(
|
||||
self.on_message, instrument_id=instrument["instrument_id"]
|
||||
),
|
||||
) for instrument in pair.instruments_]
|
||||
|
||||
for subscription in subscriptions:
|
||||
Log.info(f"{self.fname()} Subscribing to {subscription}")
|
||||
await self.pricer_client_.subscribe(subscription)
|
||||
|
||||
async def on_message(
|
||||
self,
|
||||
message_type: MessageTypeT,
|
||||
subscr_id: SubscriptionIdT,
|
||||
message: Dict,
|
||||
instrument_id: str,
|
||||
) -> None:
|
||||
Log.info(f"{self.fname()}: {message_type=} {subscr_id=} {instrument_id}")
|
||||
aggr: JsonDictT
|
||||
if message_type == "md_aggregate":
|
||||
aggr = message.get("md_aggregate", {})
|
||||
await self.live_strategy_.on_mkt_data_update(aggr)
|
||||
elif message_type == "historical_md_aggregate":
|
||||
aggr = message.get("historical_data", {})
|
||||
await self.live_strategy_.on_mkt_data_hist_snapshot(aggr)
|
||||
else:
|
||||
Log.info(f"Unknown message type: {message_type}")
|
||||
|
||||
async def run(self) -> None:
|
||||
if not await CvttPricerWebSockClient.check_connection(self.pricer_client_.ws_url_):
|
||||
Log.error(f"Unable to connect to {self.pricer_client_.ws_url_}")
|
||||
raise Exception(f"Unable to connect to {self.pricer_client_.ws_url_}")
|
||||
await self._subscribe()
|
||||
|
||||
await self.pricer_client_.run()
|
||||
```
|
||||
@ -91,7 +91,7 @@ class OptimizedWndDataPolicy(ModelDataPolicy, ABC):
|
||||
self.min_training_size_ = cast(int, config.get("min_training_size"))
|
||||
self.max_training_size_ = cast(int, config.get("max_training_size"))
|
||||
|
||||
from pt_strategy.trading_pair import TradingPair
|
||||
from pairs_trading.lib.pt_strategy.trading_pair import TradingPair
|
||||
self.pair_ = cast(TradingPair, kwargs.get("pair"))
|
||||
|
||||
if "mkt_data" in kwargs:
|
||||
|
||||
@ -6,8 +6,8 @@ import statsmodels.api as sm
|
||||
|
||||
|
||||
|
||||
from pt_strategy.pt_model import PairsTradingModel, Prediction
|
||||
from pt_strategy.trading_pair import TradingPair
|
||||
from pairs_trading.lib.pt_strategy.pt_model import PairsTradingModel, Prediction
|
||||
from pairs_trading.lib.pt_strategy.trading_pair import TradingPair
|
||||
|
||||
|
||||
class OLSModel(PairsTradingModel):
|
||||
|
||||
@ -171,63 +171,3 @@ class ResearchMarketData(PtMarketData):
|
||||
f"exec_price_{self.symbol_b_}",
|
||||
]
|
||||
|
||||
class RealTimeMarketData(PtMarketData):
|
||||
|
||||
def __init__(self, config: Dict[str, Any], *args: Any, **kwargs: Any):
|
||||
super().__init__(config, *args, **kwargs)
|
||||
|
||||
async def on_mkt_data_hist_snapshot(self, snapshot: JsonDictT) -> None:
|
||||
# URGENT
|
||||
# create origin_mkt_data_df_ from snapshot
|
||||
# verify that the data for both instruments are present
|
||||
|
||||
# transform it to market_data_df_ tstamp, close_symbolA, close_symbolB
|
||||
'''
|
||||
# from cvttpy/exchanges/binance/spot/mkt_data.py
|
||||
values = {
|
||||
"time_ns": time_ns,
|
||||
"tstamp": format_nanos_utc(time_ns),
|
||||
"exchange_id": exch_inst.exchange_id_,
|
||||
"instrument_id": exch_inst.instrument_id(),
|
||||
"interval_ns": interval_sec * 1_000_000_000,
|
||||
"open": float(kline[1]),
|
||||
"high": float(kline[2]),
|
||||
"low": float(kline[3]),
|
||||
"close": float(kline[4]),
|
||||
"volume": float(kline[5]),
|
||||
"num_trades": kline[8],
|
||||
"vwap": float(kline[7]) / float(kline[5]) if float(kline[5]) > 0 else 0.0 # Calculate VWAP
|
||||
}
|
||||
'''
|
||||
|
||||
|
||||
pass
|
||||
|
||||
async def on_mkt_data_update(self, update: MdTradesAggregate) -> Optional[pd.DataFrame]:
|
||||
# URGENT
|
||||
# make sure update has both instruments
|
||||
# create DataFrame tmp1 from update
|
||||
# transform tmp1 into temp. datframe tmp2
|
||||
# add tmp1 to origin_mkt_data_df_
|
||||
# add tmp2 to market_data_df_
|
||||
# return market_data_df_
|
||||
'''
|
||||
class MdTradesAggregate(NamedObject):
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"time_ns": self.time_ns_,
|
||||
"tstamp": format_nanos_utc(self.time_ns_),
|
||||
"exchange_id": self.exch_inst_.exchange_id_,
|
||||
"instrument_id": self.exch_inst_.instrument_id(),
|
||||
"interval_ns": self.interval_ns_,
|
||||
"open": self.exch_inst_.get_price(self.open_),
|
||||
"high": self.exch_inst_.get_price(self.high_),
|
||||
"low": self.exch_inst_.get_price(self.low_),
|
||||
"close": self.exch_inst_.get_price(self.close_),
|
||||
"volume": self.exch_inst_.get_quantity(self.volume_),
|
||||
"vwap": self.exch_inst_.get_price(self.vwap_),
|
||||
"num_trades": self.exch_inst_.get_quantity(self.num_trades_),
|
||||
}
|
||||
'''
|
||||
|
||||
return pd.DataFrame()
|
||||
@ -3,10 +3,10 @@ from __future__ import annotations
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import pandas as pd
|
||||
from pt_strategy.model_data_policy import ModelDataPolicy
|
||||
from pt_strategy.pt_market_data import ResearchMarketData
|
||||
from pt_strategy.pt_model import Prediction
|
||||
from pt_strategy.trading_pair import PairState, TradingPair
|
||||
from pairs_trading.lib.pt_strategy.model_data_policy import ModelDataPolicy
|
||||
from pairs_trading.lib.pt_strategy.pt_market_data import ResearchMarketData
|
||||
from pairs_trading.lib.pt_strategy.pt_model import Prediction
|
||||
from pairs_trading.lib.pt_strategy.trading_pair import PairState, TradingPair
|
||||
|
||||
|
||||
class PtResearchStrategy:
|
||||
@ -24,8 +24,8 @@ class PtResearchStrategy:
|
||||
datafiles: List[str],
|
||||
instruments: List[Dict[str, str]],
|
||||
):
|
||||
from pt_strategy.model_data_policy import ModelDataPolicy
|
||||
from pt_strategy.trading_pair import TradingPair
|
||||
from pairs_trading.lib.pt_strategy.model_data_policy import ModelDataPolicy
|
||||
from pairs_trading.lib.pt_strategy.trading_pair import TradingPair
|
||||
|
||||
self.config_ = config
|
||||
self.trades_ = []
|
||||
|
||||
@ -4,7 +4,7 @@ from datetime import date, datetime
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import pandas as pd
|
||||
from pt_strategy.trading_pair import TradingPair
|
||||
from pairs_trading.lib.pt_strategy.trading_pair import TradingPair
|
||||
|
||||
|
||||
# Recommended replacement adapters and converters for Python 3.12+
|
||||
|
||||
@ -57,7 +57,7 @@ class TradingPair:
|
||||
instruments: List[Dict[str, str]],
|
||||
):
|
||||
|
||||
from pt_strategy.pt_model import PairsTradingModel
|
||||
from pairs_trading.lib.pt_strategy.pt_model import PairsTradingModel
|
||||
|
||||
assert len(instruments) == 2, "Trading pair must have exactly 2 instruments"
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from pt_strategy.research_strategy import PtResearchStrategy
|
||||
from pairs_trading.lib.pt_strategy.research_strategy import PtResearchStrategy
|
||||
|
||||
|
||||
def visualize_prices(strategy: PtResearchStrategy, trading_date: str) -> None:
|
||||
|
||||
@ -3,11 +3,11 @@ from __future__ import annotations
|
||||
import os
|
||||
from typing import Any, Dict
|
||||
|
||||
from pt_strategy.results import (PairResearchResult, create_result_database,
|
||||
from pairs_trading.lib.pairs_trading.lib.tegy.results import (PairResearchResult, create_result_database,
|
||||
store_config_in_database)
|
||||
from pt_strategy.research_strategy import PtResearchStrategy
|
||||
from tools.filetools import resolve_datafiles
|
||||
from tools.instruments import get_instruments
|
||||
from pairs_trading.lib.pairs_trading.lib.t_strategy.research_strategy import PtResearchStrategy
|
||||
from pairs_trading.lib.tools.filetools import resolve_datafiles
|
||||
from pairs_trading.lib.tools.instruments import get_instruments
|
||||
|
||||
|
||||
def visualize_trades(strategy: PtResearchStrategy, results: PairResearchResult, trading_date: str) -> None:
|
||||
|
||||
@ -170,6 +170,7 @@ types-PyYAML>=5.4
|
||||
types-redis>=3.5
|
||||
types-requests>=2.25
|
||||
types-retry>=0.9
|
||||
types-seaborn>0.13.2
|
||||
types-selenium>=3.141
|
||||
types-Send2Trash>=1.8
|
||||
types-setuptools>=57.4
|
||||
|
||||
@ -3,20 +3,20 @@ from __future__ import annotations
|
||||
import os
|
||||
from typing import Any, Dict
|
||||
|
||||
from pt_strategy.results import (
|
||||
from pairs_trading.lib.pt_strategy.results import (
|
||||
PairResearchResult,
|
||||
create_result_database,
|
||||
store_config_in_database,
|
||||
)
|
||||
from pt_strategy.research_strategy import PtResearchStrategy
|
||||
from tools.filetools import resolve_datafiles
|
||||
from tools.instruments import get_instruments
|
||||
from pairs_trading.lib.pt_strategy.research_strategy import PtResearchStrategy
|
||||
from pairs_trading.lib.tools.filetools import resolve_datafiles
|
||||
from pairs_trading.lib.tools.instruments import get_instruments
|
||||
|
||||
|
||||
def main() -> None:
|
||||
import argparse
|
||||
|
||||
from tools.config import expand_filename, load_config
|
||||
from pairs_trading.lib.tools.config import expand_filename, load_config
|
||||
|
||||
parser = argparse.ArgumentParser(description="Run pairs trading backtest.")
|
||||
parser.add_argument(
|
||||
|
||||
@ -117,7 +117,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -135,8 +135,8 @@
|
||||
" from IPython.display import clear_output\n",
|
||||
"\n",
|
||||
" # Import our modules\n",
|
||||
" from pt_strategy.trading_pair import TradingPair, PairState\n",
|
||||
" from pt_strategy.results import PairResearchResult\n",
|
||||
" from pairs_trading.lib.pairs_trading.lib.pairs_trading.lib.pt_strategy.trading_pair import TradingPair, PairState\n",
|
||||
" from pairs_trading.lib.pairs_trading.lib.pairs_trading.lib.pt_strategy.results import PairResearchResult\n",
|
||||
"\n",
|
||||
" pd.set_option('display.width', 400)\n",
|
||||
" pd.set_option('display.max_colwidth', None)\n",
|
||||
@ -301,9 +301,9 @@
|
||||
" global PT_RESULTS\n",
|
||||
"\n",
|
||||
" \n",
|
||||
" from pt_strategy.trading_pair import TradingPair\n",
|
||||
" from pt_strategy.research_strategy import PtResearchStrategy\n",
|
||||
" from pt_strategy.results import PairResearchResult\n",
|
||||
" from pairs_trading.lib.pt_strategy.trading_pair import TradingPair\n",
|
||||
" from pairs_trading.lib.pt_strategy.research_strategy import PtResearchStrategy\n",
|
||||
" from pairs_trading.lib.pt_strategy.results import PairResearchResult\n",
|
||||
"\n",
|
||||
" # Create trading pair\n",
|
||||
" PT_RESULTS = PairResearchResult(config=PT_BT_CONFIG)\n",
|
||||
|
||||
@ -3,7 +3,7 @@ import os
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import pandas as pd
|
||||
from pt_trading.fit_method import PairsTradingFitMethod
|
||||
from pairs_trading.lib.pt_trading.fit_method import PairsTradingFitMethod
|
||||
|
||||
|
||||
def resolve_datafiles(config: Dict, cli_datafiles: Optional[str] = None) -> List[str]:
|
||||
|
||||
@ -3,18 +3,18 @@ from __future__ import annotations
|
||||
import os
|
||||
from typing import Any, Dict
|
||||
|
||||
from pt_strategy.results import (PairResearchResult, create_result_database,
|
||||
from pairs_trading.lib.pt_strategy.results import (PairResearchResult, create_result_database,
|
||||
store_config_in_database)
|
||||
from pt_strategy.research_strategy import PtResearchStrategy
|
||||
from tools.filetools import resolve_datafiles
|
||||
from tools.instruments import get_instruments
|
||||
from tools.viz.viz_trades import visualize_trades
|
||||
from pairs_trading.lib.pt_strategy.research_strategy import PtResearchStrategy
|
||||
from pairs_trading.lib.tools.filetools import resolve_datafiles
|
||||
from pairs_trading.lib.tools.instruments import get_instruments
|
||||
from pairs_trading.lib.tools.viz.viz_trades import visualize_trades
|
||||
|
||||
|
||||
def main() -> None:
|
||||
import argparse
|
||||
|
||||
from tools.config import expand_filename, load_config
|
||||
from pairs_trading.lib.tools.config import expand_filename, load_config
|
||||
|
||||
parser = argparse.ArgumentParser(description="Run pairs trading backtest.")
|
||||
parser.add_argument(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user