progress. not ready, changing live_strategy.py to use ExchangeInstrument and new TradingInstruction class

This commit is contained in:
Oleg Sheynin 2025-08-05 21:48:23 +00:00
parent 7ab09669b4
commit 0423a7d34f
11 changed files with 538 additions and 196 deletions

24
.vscode/launch.json vendored
View File

@ -10,8 +10,30 @@
"name": "Python Debugger: Current File",
"type": "debugpy",
"request": "launch",
"python": "/home/oleg/.pyenv/python3.12-venv/bin/python",
"program": "${file}",
"console": "integratedTerminal"
"console": "integratedTerminal",
"env": {
"PYTHONPATH": "${workspaceFolder}/lib:${workspaceFolder}/.."
},
},
{
"name": "-------- Live Pair Trading --------",
},
{
"name": "PAIRS TRADER",
"type": "debugpy",
"request": "launch",
"python": "/home/oleg/.pyenv/python3.12-venv/bin/python",
"program": "${workspaceFolder}/bin/pairs_trader.py",
"console": "integratedTerminal",
"env": {
"PYTHONPATH": "${workspaceFolder}/lib:${workspaceFolder}/.."
},
"args": [
"--config=${workspaceFolder}/configuration/pairs_trader.cfg",
"--pair=PAIR-ADA-USDT:BNBSPOT,PAIR-SOL-USDT:BNBSPOT",
],
},
{
"name": "-------- OLS --------",

View File

@ -1,9 +1,12 @@
{
"folders": [
{
"path": ".."
}
],
{
"path": ".."
},
{
"path": "../../cvttpy_base"
}
],
"settings": {
"workbench.colorTheme": "Dracula Theme"
}

104
bin/pairs_trader.py Normal file
View File

@ -0,0 +1,104 @@
from __future__ import annotations
from functools import partial
from typing import Dict, List
# import sys
# print("PYTHONPATH directories:")
# for path in sys.path:
# print(path)
from cvttpy_base.tools.app import App
from cvttpy_base.tools.base import NamedObject
from cvttpy_base.tools.logger import Log
from cvttpy_base.tools.config import CvttAppConfig
from cvttpy_base.settings.cvtt_types import JsonDictT
from pt_strategy.live.live_strategy import PtLiveStrategy
from pt_strategy.live.pricer_md_client import PtMktDataClient
from pt_strategy.live.ti_sender import TradingInstructionsSender
# from cvtt_client.mkt_data import (CvttPricerWebSockClient,
# CvttPricesSubscription, MessageTypeT,
# SubscriptionIdT)
class PairTradingRunner(NamedObject):
config_: CvttAppConfig
instruments_: List[JsonDictT]
live_strategy_: PtLiveStrategy
pricer_client_: PtMktDataClient
def __init__(self) -> None:
self.instruments_ = []
App.instance().add_cmdline_arg(
"--pair",
type=str,
required=True,
help=(
"Comma-separated pair of instrument symbols"
" with exchange config name"
" (e.g., PAIR-BTC-USD:BNBSPOT,PAIR-ETH-USD:BNBSPOT)"
),
)
App.instance().add_call(App.Stage.Config, self._on_config())
App.instance().add_call(App.Stage.Run, self.run())
async def _on_config(self) -> None:
self.config_ = CvttAppConfig.instance()
# ------- PARSE INSTRUMENTS -------
instr_str = App.instance().get_argument("pair", "")
if not instr_str:
raise ValueError("Pair is required")
instr_list = instr_str.split(",")
for instr in instr_list:
instr_parts = instr.split(":")
if len(instr_parts) != 2:
raise ValueError(f"Invalid pair format: {instr}")
instrument_id = instr_parts[0]
exchange_config_name = instr_parts[1]
self.instruments_.append({
"exchange_config_name": exchange_config_name,
"instrument_id": instrument_id
})
assert len(self.instruments_) == 2, "Only two instruments are supported"
Log.info(f"{self.fname()} Instruments: {self.instruments_}")
# ------- CREATE TI (trading instructions) CLIENT -------
ti_config = self.config_.get_subconfig("ti_config", {})
self.ti_sender_ = TradingInstructionsSender(config=ti_config)
Log.info(f"{self.fname()} TI client created: {self.ti_sender_}")
# ------- CREATE STRATEGY -------
strategy_config = self.config_.get_value("strategy_config", {})
self.live_strategy_ = PtLiveStrategy(
config=strategy_config,
instruments=self.instruments_,
ti_sender=self.ti_sender_
)
Log.info(f"{self.fname()} Strategy created: {self.live_strategy_}")
# ------- CREATE PRICER CLIENT -------
pricer_config = self.config_.get_subconfig("pricer_config", {})
self.pricer_client_ = PtMktDataClient(
live_strategy=self.live_strategy_,
pricer_config=pricer_config
)
Log.info(f"{self.fname()} CVTT Pricer client created: {self.pricer_client_}")
async def run(self) -> None:
Log.info(f"{self.fname()} ...")
pass
if __name__ == "__main__":
App()
CvttAppConfig()
PairTradingRunner()
App.instance().run()

View File

@ -1,69 +0,0 @@
from functools import partial
from typing import Dict
from cvtt_client.mkt_data import (CvttPricerWebSockClient,
CvttPricesSubscription, MessageTypeT,
SubscriptionIdT)
from cvttpy_base.tools.app import App
from cvttpy_base.tools.base import NamedObject
from pt_strategy.live_strategy import PtLiveStrategy
class PairTradingRunner(NamedObject):
def __init__(self) -> None:
super().__init__()
App.instance().add_call(App.Stage.Config, self._on_config())
App.instance().add_call(App.Stage.Run, self.run())
async def _on_config(self) -> None:
pass
async def run(self) -> None:
pass
# async def main() -> None:
# live_strategy = PtLiveStrategy(
# config={},
# instruments=[
# {"exchange_config_name": "COINBASE_AT", "instrument_id": "PAIR-BTC-USD"},
# {"exchange_config_name": "COINBASE_AT", "instrument_id": "PAIR-ETH-USD"},
# ]
# )
# async def on_message(message_type: MessageTypeT, subscr_id: SubscriptionIdT, message: Dict, instrument_id: str) -> None:
# print(f"{message_type=} {subscr_id=} {instrument_id}")
# if message_type == "md_aggregate":
# aggr = message.get("md_aggregate", [])
# print(f"[{aggr['tstamp'][:19]}] *** RLTM *** {message}")
# elif message_type == "historical_md_aggregate":
# for aggr in message.get("historical_data", []):
# print(f"[{aggr['tstamp'][:19]}] *** HIST *** {aggr}")
# else:
# print(f"Unknown message type: {message_type}")
# pricer_client = CvttPricerWebSockClient(
# "ws://localhost:12346/ws"
# )
# await pricer_client.subscribe(CvttPricesSubscription(
# exchange_config_name="COINBASE_AT",
# instrument_id="PAIR-BTC-USD",
# interval_sec=60,
# history_depth_sec=60*60*24,
# callback=partial(on_message, instrument_id="PAIR-BTC-USD")
# ))
# await pricer_client.subscribe(CvttPricesSubscription(
# exchange_config_name="COINBASE_AT",
# instrument_id="PAIR-ETH-USD",
# interval_sec=60,
# history_depth_sec=60*60*24,
# callback=partial(on_message, instrument_id="PAIR-ETH-USD")
# ))
# await pricer_client.run()
if __name__ == "__main__":
App()
App.instance().run()

View File

@ -0,0 +1,21 @@
{
"strategy_config": @inc=file:///home/oleg/develop/pairs_trading/configuration/ols.cfg
"pricer_config": {
"pricer_url": "ws://localhost:12346/ws",
"history_depth_sec": 86400 #"60*60*24", # use simpleeval
"interval_sec": 60
},
"ti_config": {
"cvtt_base_url": "http://localhost:23456"
"book_id": "XXXXXXXXX",
"strategy_id": "XXXXXXXXX",
"ti_endpoint": {
"method": "POST",
"url": "/trading_instructions"
},
"health_check_endpoint": {
"method": "GET",
"url": "/ping"
}
}
}

View File

@ -10,10 +10,12 @@ import uuid
from dataclasses import dataclass
from typing import Callable, Coroutine, Dict, List, Optional
from numpy.strings import str_len
import websockets
from websockets.asyncio.client import ClientConnection
from cvttpy_base.settings.cvtt_types import JsonDictT
from cvttpy_base.tools.logger import Log
MessageTypeT = str
SubscriptionIdT = str
MessageT = Dict
@ -48,22 +50,56 @@ class CvttPricesSubscription:
self.is_subscribed_ = False
self.is_historical_ = history_depth_sec > 0
class CvttPricerWebSockClient:
# Class members with type hints
class CvttWebSockClient:
ws_url_: UrlT
websocket_: Optional[ClientConnection]
subscriptions_: Dict[SubscriptionIdT, CvttPricesSubscription]
is_connected_: bool
logger_: logging.Logger
def __init__(self, url: str):
self.ws_url_ = url
self.websocket_ = None
self.is_connected_ = False
async def connect(self) -> None:
self.websocket_ = await websockets.connect(self.ws_url_)
self.is_connected_ = True
async def close(self) -> None:
if self.websocket_ is not None:
await self.websocket_.close()
self.is_connected_ = False
async def receive_message(self) -> JsonDictT:
assert self.websocket_ is not None
assert self.is_connected_
message = await self.websocket_.recv()
message_str = (
message.decode("utf-8")
if isinstance(message, bytes)
else message
)
res = json.loads(message_str)
assert res is not None
assert isinstance(res, dict)
return res
@classmethod
async def check_connection(cls, url: str) -> bool:
try:
async with websockets.connect(url) as websocket:
result = True
except Exception as e:
Log.error(f"Unable to connect to {url}: {str(e)}")
result = False
return result
class CvttPricerWebSockClient(CvttWebSockClient):
# Class members with type hints
subscriptions_: Dict[SubscriptionIdT, CvttPricesSubscription]
def __init__(self, url: str):
super().__init__(url)
self.subscriptions_ = {}
self.logger_ = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
async def subscribe(
self, subscription: CvttPricesSubscription
@ -71,11 +107,10 @@ class CvttPricerWebSockClient:
if not self.is_connected_:
try:
self.logger_.info(f"Connecting to {self.ws_url_}")
self.websocket_ = await websockets.connect(self.ws_url_)
self.is_connected_ = True
Log.info(f"Connecting to {self.ws_url_}")
await self.connect()
except Exception as e:
self.logger_.error(f"Unable to connect to {self.ws_url_}: {str(e)}")
Log.error(f"Unable to connect to {self.ws_url_}: {str(e)}")
raise e
subscr_msg = {
@ -109,10 +144,10 @@ class CvttPricerWebSockClient:
return False
if response.get("status") == "success":
self.logger_.info(f"Subscription successful: {json.dumps(response)}")
Log.info(f"Subscription successful: {json.dumps(response)}")
return True
elif response.get("status") == "error":
self.logger_.error(f"Subscription failed: {response.get('reason')}")
Log.error(f"Subscription failed: {response.get('reason')}")
return False
return False
@ -121,19 +156,20 @@ class CvttPricerWebSockClient:
try:
while self.is_connected_:
try:
message = await self.websocket_.recv()
message_str = (
message.decode("utf-8")
if isinstance(message, bytes)
else message
)
await self.process_message(json.loads(message_str))
msg_dict: JsonDictT = await self.receive_message()
except websockets.ConnectionClosed:
self.logger_.warning("Connection closed")
Log.warning("Connection closed")
self.is_connected_ = False
break
except Exception as e:
Log.error(f"Error occurred: {str(e)}")
self.is_connected_ = False
await asyncio.sleep(5) # Wait before reconnecting
await self.process_message(msg_dict)
except Exception as e:
self.logger_.error(f"Error occurred: {str(e)}")
Log.error(f"Error occurred: {str(e)}")
self.is_connected_ = False
await asyncio.sleep(5) # Wait before reconnecting
@ -142,13 +178,13 @@ class CvttPricerWebSockClient:
if message_type in ["md_aggregate", "historical_md_aggregate"]:
subscription_id = message.get("subscr_id")
if subscription_id not in self.subscriptions_:
self.logger_.warning(f"Unknown subscription id: {subscription_id}")
Log.warning(f"Unknown subscription id: {subscription_id}")
return
subscription = self.subscriptions_[subscription_id]
await subscription.callback_(message_type, subscription_id, message)
else:
self.logger_.warning(f"Unknown message type: {message.get('type')}")
Log.warning(f"Unknown message type: {message.get('type')}")
async def main() -> None:

View File

@ -1,5 +1,6 @@
from __future__ import annotations
from dataclasses import dataclass
from functools import partial
from typing import Any, Dict, List, Optional
@ -8,61 +9,26 @@ from cvttpy_base.settings.cvtt_types import JsonDictT
from cvttpy_base.tools.base import NamedObject
from cvttpy_base.tools.logger import Log
from cvtt_client.mkt_data import CvttPricerWebSockClient, CvttPricesSubscription, MessageTypeT, SubscriptionIdT
from pt_strategy.live.ti_sender import TradingInstructionsSender
from pt_strategy.model_data_policy import ModelDataPolicy
from pt_strategy.pt_market_data import PtMarketData, RealTimeMarketData
from pt_strategy.pt_market_data import RealTimeMarketData
from pt_strategy.pt_model import Prediction
from pt_strategy.trading_pair import PairState, TradingPair
'''
"""
--config=pair.cfg
--pair=PAIR-BTC-USDT:COINBASE_AT,PAIR-ETH-USDT:COINBASE_AT
'''
"""
class PtMktDataClient(NamedObject):
live_strategy_: PtLiveStrategy
pricer_client_: CvttPricerWebSockClient
subscriptions_: List[CvttPricesSubscription]
def __init__(self, live_strategy: PtLiveStrategy):
self.live_strategy_ = live_strategy
async def start(self, subscription: CvttPricesSubscription) -> None:
pricer_url = self.live_strategy_.config_.get("pricer_url", None) #, "ws://localhost:12346/ws")
assert pricer_url is not None, "pricer_url is not found in config"
self.pricer_client_ = CvttPricerWebSockClient(url=pricer_url)
await self._subscribe()
async def _subscribe(self) -> None:
pair: TradingPair = self.live_strategy_.trading_pair_
for instrument in pair.instruments_:
await self.pricer_client_.subscribe(CvttPricesSubscription(
exchange_config_name=instrument["exchange_config_name"],
instrument_id=instrument["instrument_id"],
interval_sec=60,
history_depth_sec=60*60*24,
callback=partial(self.on_message, instrument_id=instrument["instrument_id"])
))
async def on_message(self, message_type: MessageTypeT, subscr_id: SubscriptionIdT, message: Dict, instrument_id: str) -> None:
Log.info(f"{self.fname()}: {message_type=} {subscr_id=} {instrument_id}")
aggr: JsonDictT
if message_type == "md_aggregate":
aggr = message.get("md_aggregate", {})
await self.live_strategy_.on_mkt_data_update(aggr)
# print(f"[{aggr['tstamp'][:19]}] *** RLTM *** {message}")
elif message_type == "historical_md_aggregate":
aggr = message.get("historical_data", {})
await self.live_strategy_.on_mkt_data_hist_snapshot(aggr)
# print(f"[{aggr['tstamp'][:19]}] *** HIST *** {aggr}")
else:
Log.info(f"Unknown message type: {message_type}")
async def run(self) -> None:
await self.pricer_client_.run()
class TradingInstructionType(Enum):
TARGET_POSITION = "TARGET_POSITION"
@dataclass
class TradingInstruction(NamedObject):
type_: TradingInstructionType
exch_instr_: ExchangeInstrument
specifics_: Dict[str, Any]
class PtLiveStrategy(NamedObject):
@ -70,22 +36,24 @@ class PtLiveStrategy(NamedObject):
trading_pair_: TradingPair
model_data_policy_: ModelDataPolicy
pt_mkt_data_: RealTimeMarketData
pt_mkt_data_client_: PtMktDataClient
# for presentation: history of prediction values and trading signals
ti_sender_: TradingInstructionsSender
# for presentation: history of prediction values and trading signals
predictions_: pd.DataFrame
trading_signals_: pd.DataFrame
trading_signals_: pd.DataFrame
def __init__(
self,
config: Dict[str, Any],
instruments: List[Dict[str, str]],
ti_sender: TradingInstructionsSender,
):
self.config_ = config
self.trading_pair_ = TradingPair(config=config, instruments=instruments)
self.predictions_ = pd.DataFrame()
self.trading_signals_ = pd.DataFrame()
self.ti_sender_ = ti_sender
import copy
@ -94,9 +62,16 @@ class PtLiveStrategy(NamedObject):
config_copy["instruments"] = instruments
self.pt_mkt_data_ = RealTimeMarketData(config=config_copy)
self.model_data_policy_ = ModelDataPolicy.create(
config, is_real_time=True,pair=self.trading_pair_
config, is_real_time=True, pair=self.trading_pair_
)
self.open_threshold_ = self.config_.get("dis-equilibrium_open_trshld", 0.0)
assert self.open_threshold_ > 0, "open_threshold must be greater than 0"
self.close_threshold_ = self.config_.get("dis-equilibrium_close_trshld", 0.0)
assert self.close_threshold_ > 0, "close_threshold must be greater than 0"
def __repr__(self) -> str:
return f"{self.classname()}: trading_pair={self.trading_pair_}, mdp={self.model_data_policy_.__class__.__name__}, "
async def on_mkt_data_hist_snapshot(self, aggr: JsonDictT) -> None:
Log.info(f"on_mkt_data_hist_snapshot: {aggr}")
await self.pt_mkt_data_.on_mkt_data_hist_snapshot(snapshot=aggr)
@ -107,63 +82,106 @@ class PtLiveStrategy(NamedObject):
if market_data_df is not None:
self.trading_pair_.market_data_ = market_data_df
self.model_data_policy_.advance()
prediction = self.trading_pair_.run(market_data_df, self.model_data_policy_.advance())
self.predictions_ = pd.concat([self.predictions_, prediction.to_df()], ignore_index=True)
trades = self._create_trades(prediction=prediction, last_row=market_data_df.iloc[-1])
prediction = self.trading_pair_.run(
market_data_df, self.model_data_policy_.advance()
)
self.predictions_ = pd.concat(
[self.predictions_, prediction.to_df()], ignore_index=True
)
trading_instructions: List[TradingInstruction] = (
self._create_trading_instructions(
prediction=prediction, last_row=market_data_df.iloc[-1]
)
)
if len(trading_instructions) > 0:
await self._send_trading_instructions(trading_instructions)
# trades = self._create_trades(prediction=prediction, last_row=market_data_df.iloc[-1])
# URGENT implement this
pass
async def run(self) -> None:
await self.pt_mkt_data_client_.run()
async def _send_trading_instructions(
self, trading_instructions: pd.DataFrame
) -> None:
pass
def _create_trades(
def _create_trading_instructions(
self, prediction: Prediction, last_row: pd.Series
) -> Optional[pd.DataFrame]:
) -> List[TradingInstruction]:
pair = self.trading_pair_
trades = None
trd_instructions: List[TradingInstruction] = []
open_threshold = self.config_["dis-equilibrium_open_trshld"]
close_threshold = self.config_["dis-equilibrium_close_trshld"]
scaled_disequilibrium = prediction.scaled_disequilibrium_
abs_scaled_disequilibrium = abs(scaled_disequilibrium)
if pair.user_data_["state"] in [
PairState.INITIAL,
PairState.CLOSE,
PairState.CLOSE_POSITION,
PairState.CLOSE_STOP_LOSS,
PairState.CLOSE_STOP_PROFIT,
]:
if abs_scaled_disequilibrium >= open_threshold:
trades = self._create_open_trades(
if pair.is_closed():
if abs_scaled_disequilibrium >= self.open_threshold_:
trd_instructions = self._create_open_trade_instructions(
pair, row=last_row, prediction=prediction
)
if trades is not None:
trades["status"] = PairState.OPEN.name
print(f"OPEN TRADES:\n{trades}")
pair.user_data_["state"] = PairState.OPEN
pair.on_open_trades(trades)
elif pair.user_data_["state"] == PairState.OPEN:
if abs_scaled_disequilibrium <= close_threshold:
trades = self._create_close_trades(
elif pair.is_open():
if abs_scaled_disequilibrium <= self.close_threshold_:
trd_instructions = self._create_close_trade_instructions(
pair, row=last_row, prediction=prediction
)
if trades is not None:
trades["status"] = PairState.CLOSE.name
print(f"CLOSE TRADES:\n{trades}")
pair.user_data_["state"] = PairState.CLOSE
pair.on_close_trades(trades)
elif pair.to_stop_close_conditions(predicted_row=last_row):
trades = self._create_close_trades(pair, row=last_row)
if trades is not None:
trades["status"] = pair.user_data_["stop_close_state"].name
print(f"STOP CLOSE TRADES:\n{trades}")
pair.user_data_["state"] = pair.user_data_["stop_close_state"]
pair.on_close_trades(trades)
trd_instructions = self._create_close_trade_instructions(
pair, row=last_row
)
return trades
return trd_instructions
def _create_open_trade_instructions(
self, pair: TradingPair, row: pd.Series, prediction: Prediction
) -> List[TradingInstruction]:
scaled_disequilibrium = prediction.scaled_disequilibrium_
if scaled_disequilibrium > 0:
side_a = "SELL"
trd_inst_a = TradingInstruction(
type=TradingInstructionType.TARGET_POSITION,
exch_instr=pair.get_instrument_a(),
specifics={"side": "SELL", "strength": -1},
)
side_b = "BUY"
else:
side_a = "BUY"
side_b = "SELL"
# save closing sides
pair.user_data_["open_side_a"] = side_a # used in oustanding positions
pair.user_data_["open_side_b"] = side_b
pair.user_data_["open_px_a"] = px_a
pair.user_data_["open_px_b"] = px_b
pair.user_data_["open_tstamp"] = tstamp
pair.user_data_["close_side_a"] = side_b # used for closing trades
pair.user_data_["close_side_b"] = side_a
# create opening trades
df.loc[len(df)] = {
"time": tstamp,
"symbol": pair.symbol_a_,
"side": side_a,
"action": "OPEN",
"price": px_a,
"disequilibrium": diseqlbrm,
"signed_scaled_disequilibrium": scaled_disequilibrium,
"scaled_disequilibrium": abs(scaled_disequilibrium),
# "pair": pair,
}
df.loc[len(df)] = {
"time": tstamp,
"symbol": pair.symbol_b_,
"side": side_b,
"action": "OPEN",
"price": px_b,
"disequilibrium": diseqlbrm,
"scaled_disequilibrium": abs(scaled_disequilibrium),
"signed_scaled_disequilibrium": scaled_disequilibrium,
# "pair": pair,
}
return df
def _handle_outstanding_positions(self) -> Optional[pd.DataFrame]:
trades = None
@ -328,4 +346,3 @@ class PtLiveStrategy(NamedObject):
del pair.user_data_["open_side_a"]
del pair.user_data_["open_side_b"]
return df

View File

@ -0,0 +1,90 @@
from __future__ import annotations
from functools import partial
from typing import Any, Dict, List, Optional
import pandas as pd
from cvttpy_base.tools.app import App
from cvttpy_base.tools.config import Config
from cvttpy_base.settings.cvtt_types import JsonDictT
from cvttpy_base.tools.base import NamedObject
from cvttpy_base.tools.logger import Log
from cvtt_client.mkt_data import (
CvttPricerWebSockClient,
CvttPricesSubscription,
MessageTypeT,
SubscriptionIdT,
)
from pt_strategy.live.live_strategy import PtLiveStrategy
from pt_strategy.trading_pair import TradingPair
"""
--config=pair.cfg
--pair=PAIR-BTC-USDT:COINBASE_AT,PAIR-ETH-USDT:COINBASE_AT
"""
class PtMktDataClient(NamedObject):
config_: Config
live_strategy_: PtLiveStrategy
pricer_client_: CvttPricerWebSockClient
subscriptions_: List[CvttPricesSubscription]
def __init__(self, live_strategy: PtLiveStrategy, pricer_config: Config):
self.config_ = pricer_config
self.live_strategy_ = live_strategy
App.instance().add_call(App.Stage.Start, self._on_start())
App.instance().add_call(App.Stage.Run, self.run())
async def _on_start(self) -> None:
pricer_url = self.config_.get_value("pricer_url")
assert pricer_url is not None, "pricer_url is not found in config"
self.pricer_client_ = CvttPricerWebSockClient(url=pricer_url)
async def _subscribe(self) -> None:
history_depth_sec = self.config_.get_value("history_depth_sec", 86400)
interval_sec = self.config_.get_value("interval_sec", 60)
pair: TradingPair = self.live_strategy_.trading_pair_
subscriptions = [CvttPricesSubscription(
exchange_config_name=instrument["exchange_config_name"],
instrument_id=instrument["instrument_id"],
interval_sec=interval_sec,
history_depth_sec=history_depth_sec,
callback=partial(
self.on_message, instrument_id=instrument["instrument_id"]
),
) for instrument in pair.instruments_]
for subscription in subscriptions:
Log.info(f"{self.fname()} Subscribing to {subscription}")
await self.pricer_client_.subscribe(subscription)
async def on_message(
self,
message_type: MessageTypeT,
subscr_id: SubscriptionIdT,
message: Dict,
instrument_id: str,
) -> None:
Log.info(f"{self.fname()}: {message_type=} {subscr_id=} {instrument_id}")
aggr: JsonDictT
if message_type == "md_aggregate":
aggr = message.get("md_aggregate", {})
await self.live_strategy_.on_mkt_data_update(aggr)
elif message_type == "historical_md_aggregate":
aggr = message.get("historical_data", {})
await self.live_strategy_.on_mkt_data_hist_snapshot(aggr)
else:
Log.info(f"Unknown message type: {message_type}")
async def run(self) -> None:
if not await CvttPricerWebSockClient.check_connection(self.pricer_client_.ws_url_):
Log.error(f"Unable to connect to {self.pricer_client_.ws_url_}")
raise Exception(f"Unable to connect to {self.pricer_client_.ws_url_}")
await self._subscribe()
await self.pricer_client_.run()

View File

@ -0,0 +1,85 @@
from enum import Enum
from typing import Dict, Any, Tuple
import time
# import aiohttp
from cvttpy_base.tools.app import App
from cvttpy_base.tools.base import NamedObject
from cvttpy_base.tools.config import Config
from cvttpy_base.tools.logger import Log
from cvttpy_base.tools.web.rest_client import REST_RequestProcessor
from cvttpy_base.tools.timeutils import NanoPerSec
from cvttpy_base.tools.timer import Timer
class TradingInstructionsSender(NamedObject):
class TradingInstType(str, Enum):
TARGET_POSITION = "TARGET_POSITION"
DIRECT_ORDER = "DIRECT_ORDER"
MARKET_MAKING = "MARKET_MAKING"
NONE = "NONE"
config_: Config
ti_method_: str
ti_url_: str
health_check_method_: str
health_check_url_: str
def __init__(self, config: Config):
self.config_ = config
base_url = config.get_value("url", "ws://localhost:12346/ws")
self.book_id_ = config.get_value("book_id", "")
assert self.book_id_, "book_id is required"
self.strategy_id_ = config.get_value("strategy_id", "")
assert self.strategy_id_, "strategy_id is required"
endpoint_uri = config.get_value("ti_endpoint/url", "/trading_instructions")
endpoint_method = config.get_value("ti_endpoint/method", "POST")
health_check_uri = config.get_value("health_check_endpoint/url", "/ping")
health_check_method = config.get_value("health_check_endpoint/method", "GET")
self.ti_method_ = endpoint_method
self.ti_url_ = f"{base_url}{endpoint_uri}"
self.health_check_method_ = health_check_method
self.health_check_url_ = f"{base_url}{health_check_uri}"
App.instance().add_call(App.Stage.Start, self._set_health_check_timer(), can_run_now=True)
async def _set_health_check_timer(self) -> None:
# TODO: configurable interval
self.health_check_timer_ = Timer(is_periodic=True, period_interval=15, start_in_sec=0, func=self._health_check)
Log.info(f"{self.fname()} Health check timer set to 15 seconds")
async def _health_check(self) -> None:
rqst = REST_RequestProcessor(method=self.health_check_method_, url=self.health_check_url_)
async with rqst as (status, msg, headers):
if status != 200:
Log.error(f"{self.fname()} CVTT Service is not responding")
async def send_tgt_positions(self, strength: float, base_asset: str, quote_asset: str) -> Tuple[int, str]:
instr = {
"type": self.TradingInstType.TARGET_POSITION.value,
"book_id": self.book_id_,
"strategy_id": self.strategy_id_,
"issued_ts_ns": int(time.time() * NanoPerSec),
"data": {
"strength": strength,
"base_asset": base_asset,
"quote_asset": quote_asset,
"user_data": {},
},
}
rqst = REST_RequestProcessor(method=self.ti_method_, url=self.ti_url_, params=instr)
async with rqst as (status, msg, headers):
if status != 200:
raise ConnectionError(f"Failed to send trading instructions: {msg}")
return (status, msg)

View File

@ -18,6 +18,20 @@ class PairState(Enum):
CLOSE_STOP_LOSS = 5
CLOSE_STOP_PROFIT = 6
def get_symbol(instrument: Dict[str, str]) -> str:
if "symbol" in instrument:
return instrument["symbol"]
elif "instrument_id" in instrument:
instrument_id = instrument["instrument_id"]
instrument_pfx = instrument_id[:instrument_id.find("-") + 1]
symbol = instrument_id[len(instrument_pfx):]
instrument["symbol"] = symbol
instrument["instrument_id_pfx"] = instrument_pfx
return symbol
else:
raise ValueError(f"Invalid instrument: {instrument}, missing symbol or instrument_id")
class TradingPair:
config_: Dict[str, Any]
market_data_: pd.DataFrame
@ -42,14 +56,32 @@ class TradingPair:
self.config_ = config
self.instruments_ = instruments
self.symbol_a_ = instruments[0]["symbol"]
self.symbol_b_ = instruments[1]["symbol"]
self.symbol_a_ = get_symbol(instruments[0])
self.symbol_b_ = get_symbol(instruments[1])
self.model_ = PairsTradingModel.create(config)
self.stat_model_price_ = config["stat_model_price"]
self.user_data_ = {
"state": PairState.INITIAL,
}
def __repr__(self) -> str:
return (
f"{self.__class__.__name__}:"
f" symbol_a={self.symbol_a_},"
f" symbol_b={self.symbol_b_},"
f" model={self.model_.__class__.__name__}"
)
def is_closed(self) -> bool:
return self.user_data_["state"] in [
PairState.CLOSE,
PairState.CLOSE_POSITION,
PairState.CLOSE_STOP_LOSS,
PairState.CLOSE_STOP_PROFIT,
]
def is_open(self) -> bool:
return self.user_data_["state"] == PairState.OPEN
def colnames(self) -> List[str]:
return [
f"{self.stat_model_price_}_{self.symbol_a_}",

View File

@ -78,6 +78,7 @@ scipy<1.13.0
seaborn>=0.13.2
SecretStorage>=3.3.1
setproctitle>=1.2.2
simpleeval>=1.0.3
six>=1.16.0
soupsieve>=2.3.1
ssh-import-id>=5.11