progress. not ready, changing live_strategy.py to use ExchangeInstrument and new TradingInstruction class
This commit is contained in:
parent
7ab09669b4
commit
0423a7d34f
24
.vscode/launch.json
vendored
24
.vscode/launch.json
vendored
@ -10,8 +10,30 @@
|
||||
"name": "Python Debugger: Current File",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"python": "/home/oleg/.pyenv/python3.12-venv/bin/python",
|
||||
"program": "${file}",
|
||||
"console": "integratedTerminal"
|
||||
"console": "integratedTerminal",
|
||||
"env": {
|
||||
"PYTHONPATH": "${workspaceFolder}/lib:${workspaceFolder}/.."
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "-------- Live Pair Trading --------",
|
||||
},
|
||||
{
|
||||
"name": "PAIRS TRADER",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"python": "/home/oleg/.pyenv/python3.12-venv/bin/python",
|
||||
"program": "${workspaceFolder}/bin/pairs_trader.py",
|
||||
"console": "integratedTerminal",
|
||||
"env": {
|
||||
"PYTHONPATH": "${workspaceFolder}/lib:${workspaceFolder}/.."
|
||||
},
|
||||
"args": [
|
||||
"--config=${workspaceFolder}/configuration/pairs_trader.cfg",
|
||||
"--pair=PAIR-ADA-USDT:BNBSPOT,PAIR-SOL-USDT:BNBSPOT",
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "-------- OLS --------",
|
||||
|
||||
3
.vscode/pairs_trading.code-workspace
vendored
3
.vscode/pairs_trading.code-workspace
vendored
@ -2,6 +2,9 @@
|
||||
"folders": [
|
||||
{
|
||||
"path": ".."
|
||||
},
|
||||
{
|
||||
"path": "../../cvttpy_base"
|
||||
}
|
||||
],
|
||||
"settings": {
|
||||
|
||||
104
bin/pairs_trader.py
Normal file
104
bin/pairs_trader.py
Normal file
@ -0,0 +1,104 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import partial
|
||||
from typing import Dict, List
|
||||
|
||||
# import sys
|
||||
# print("PYTHONPATH directories:")
|
||||
# for path in sys.path:
|
||||
# print(path)
|
||||
|
||||
from cvttpy_base.tools.app import App
|
||||
from cvttpy_base.tools.base import NamedObject
|
||||
from cvttpy_base.tools.logger import Log
|
||||
from cvttpy_base.tools.config import CvttAppConfig
|
||||
from cvttpy_base.settings.cvtt_types import JsonDictT
|
||||
|
||||
from pt_strategy.live.live_strategy import PtLiveStrategy
|
||||
from pt_strategy.live.pricer_md_client import PtMktDataClient
|
||||
from pt_strategy.live.ti_sender import TradingInstructionsSender
|
||||
|
||||
# from cvtt_client.mkt_data import (CvttPricerWebSockClient,
|
||||
# CvttPricesSubscription, MessageTypeT,
|
||||
# SubscriptionIdT)
|
||||
|
||||
class PairTradingRunner(NamedObject):
|
||||
config_: CvttAppConfig
|
||||
instruments_: List[JsonDictT]
|
||||
|
||||
live_strategy_: PtLiveStrategy
|
||||
pricer_client_: PtMktDataClient
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.instruments_ = []
|
||||
|
||||
App.instance().add_cmdline_arg(
|
||||
"--pair",
|
||||
type=str,
|
||||
required=True,
|
||||
help=(
|
||||
"Comma-separated pair of instrument symbols"
|
||||
" with exchange config name"
|
||||
" (e.g., PAIR-BTC-USD:BNBSPOT,PAIR-ETH-USD:BNBSPOT)"
|
||||
),
|
||||
)
|
||||
|
||||
App.instance().add_call(App.Stage.Config, self._on_config())
|
||||
App.instance().add_call(App.Stage.Run, self.run())
|
||||
|
||||
async def _on_config(self) -> None:
|
||||
self.config_ = CvttAppConfig.instance()
|
||||
|
||||
# ------- PARSE INSTRUMENTS -------
|
||||
instr_str = App.instance().get_argument("pair", "")
|
||||
if not instr_str:
|
||||
raise ValueError("Pair is required")
|
||||
instr_list = instr_str.split(",")
|
||||
for instr in instr_list:
|
||||
instr_parts = instr.split(":")
|
||||
if len(instr_parts) != 2:
|
||||
raise ValueError(f"Invalid pair format: {instr}")
|
||||
instrument_id = instr_parts[0]
|
||||
exchange_config_name = instr_parts[1]
|
||||
self.instruments_.append({
|
||||
"exchange_config_name": exchange_config_name,
|
||||
"instrument_id": instrument_id
|
||||
})
|
||||
|
||||
assert len(self.instruments_) == 2, "Only two instruments are supported"
|
||||
Log.info(f"{self.fname()} Instruments: {self.instruments_}")
|
||||
|
||||
# ------- CREATE TI (trading instructions) CLIENT -------
|
||||
ti_config = self.config_.get_subconfig("ti_config", {})
|
||||
self.ti_sender_ = TradingInstructionsSender(config=ti_config)
|
||||
Log.info(f"{self.fname()} TI client created: {self.ti_sender_}")
|
||||
|
||||
# ------- CREATE STRATEGY -------
|
||||
strategy_config = self.config_.get_value("strategy_config", {})
|
||||
self.live_strategy_ = PtLiveStrategy(
|
||||
config=strategy_config,
|
||||
instruments=self.instruments_,
|
||||
ti_sender=self.ti_sender_
|
||||
)
|
||||
Log.info(f"{self.fname()} Strategy created: {self.live_strategy_}")
|
||||
|
||||
# ------- CREATE PRICER CLIENT -------
|
||||
pricer_config = self.config_.get_subconfig("pricer_config", {})
|
||||
self.pricer_client_ = PtMktDataClient(
|
||||
live_strategy=self.live_strategy_,
|
||||
pricer_config=pricer_config
|
||||
)
|
||||
Log.info(f"{self.fname()} CVTT Pricer client created: {self.pricer_client_}")
|
||||
|
||||
async def run(self) -> None:
|
||||
Log.info(f"{self.fname()} ...")
|
||||
pass
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
App()
|
||||
CvttAppConfig()
|
||||
PairTradingRunner()
|
||||
App.instance().run()
|
||||
@ -1,69 +0,0 @@
|
||||
|
||||
from functools import partial
|
||||
from typing import Dict
|
||||
|
||||
from cvtt_client.mkt_data import (CvttPricerWebSockClient,
|
||||
CvttPricesSubscription, MessageTypeT,
|
||||
SubscriptionIdT)
|
||||
from cvttpy_base.tools.app import App
|
||||
from cvttpy_base.tools.base import NamedObject
|
||||
from pt_strategy.live_strategy import PtLiveStrategy
|
||||
|
||||
|
||||
class PairTradingRunner(NamedObject):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
App.instance().add_call(App.Stage.Config, self._on_config())
|
||||
App.instance().add_call(App.Stage.Run, self.run())
|
||||
|
||||
async def _on_config(self) -> None:
|
||||
pass
|
||||
|
||||
async def run(self) -> None:
|
||||
pass
|
||||
|
||||
# async def main() -> None:
|
||||
# live_strategy = PtLiveStrategy(
|
||||
# config={},
|
||||
# instruments=[
|
||||
# {"exchange_config_name": "COINBASE_AT", "instrument_id": "PAIR-BTC-USD"},
|
||||
# {"exchange_config_name": "COINBASE_AT", "instrument_id": "PAIR-ETH-USD"},
|
||||
# ]
|
||||
# )
|
||||
# async def on_message(message_type: MessageTypeT, subscr_id: SubscriptionIdT, message: Dict, instrument_id: str) -> None:
|
||||
# print(f"{message_type=} {subscr_id=} {instrument_id}")
|
||||
# if message_type == "md_aggregate":
|
||||
# aggr = message.get("md_aggregate", [])
|
||||
# print(f"[{aggr['tstamp'][:19]}] *** RLTM *** {message}")
|
||||
# elif message_type == "historical_md_aggregate":
|
||||
# for aggr in message.get("historical_data", []):
|
||||
# print(f"[{aggr['tstamp'][:19]}] *** HIST *** {aggr}")
|
||||
# else:
|
||||
# print(f"Unknown message type: {message_type}")
|
||||
|
||||
# pricer_client = CvttPricerWebSockClient(
|
||||
# "ws://localhost:12346/ws"
|
||||
# )
|
||||
# await pricer_client.subscribe(CvttPricesSubscription(
|
||||
# exchange_config_name="COINBASE_AT",
|
||||
# instrument_id="PAIR-BTC-USD",
|
||||
# interval_sec=60,
|
||||
# history_depth_sec=60*60*24,
|
||||
# callback=partial(on_message, instrument_id="PAIR-BTC-USD")
|
||||
# ))
|
||||
# await pricer_client.subscribe(CvttPricesSubscription(
|
||||
# exchange_config_name="COINBASE_AT",
|
||||
# instrument_id="PAIR-ETH-USD",
|
||||
# interval_sec=60,
|
||||
# history_depth_sec=60*60*24,
|
||||
# callback=partial(on_message, instrument_id="PAIR-ETH-USD")
|
||||
# ))
|
||||
|
||||
# await pricer_client.run()
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
App()
|
||||
|
||||
App.instance().run()
|
||||
21
configuration/pairs_trader.cfg
Normal file
21
configuration/pairs_trader.cfg
Normal file
@ -0,0 +1,21 @@
|
||||
{
|
||||
"strategy_config": @inc=file:///home/oleg/develop/pairs_trading/configuration/ols.cfg
|
||||
"pricer_config": {
|
||||
"pricer_url": "ws://localhost:12346/ws",
|
||||
"history_depth_sec": 86400 #"60*60*24", # use simpleeval
|
||||
"interval_sec": 60
|
||||
},
|
||||
"ti_config": {
|
||||
"cvtt_base_url": "http://localhost:23456"
|
||||
"book_id": "XXXXXXXXX",
|
||||
"strategy_id": "XXXXXXXXX",
|
||||
"ti_endpoint": {
|
||||
"method": "POST",
|
||||
"url": "/trading_instructions"
|
||||
},
|
||||
"health_check_endpoint": {
|
||||
"method": "GET",
|
||||
"url": "/ping"
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -10,10 +10,12 @@ import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Coroutine, Dict, List, Optional
|
||||
|
||||
from numpy.strings import str_len
|
||||
import websockets
|
||||
from websockets.asyncio.client import ClientConnection
|
||||
|
||||
from cvttpy_base.settings.cvtt_types import JsonDictT
|
||||
from cvttpy_base.tools.logger import Log
|
||||
|
||||
MessageTypeT = str
|
||||
SubscriptionIdT = str
|
||||
MessageT = Dict
|
||||
@ -48,22 +50,56 @@ class CvttPricesSubscription:
|
||||
self.is_subscribed_ = False
|
||||
self.is_historical_ = history_depth_sec > 0
|
||||
|
||||
|
||||
class CvttPricerWebSockClient:
|
||||
# Class members with type hints
|
||||
class CvttWebSockClient:
|
||||
ws_url_: UrlT
|
||||
websocket_: Optional[ClientConnection]
|
||||
subscriptions_: Dict[SubscriptionIdT, CvttPricesSubscription]
|
||||
is_connected_: bool
|
||||
logger_: logging.Logger
|
||||
|
||||
def __init__(self, url: str):
|
||||
self.ws_url_ = url
|
||||
self.websocket_ = None
|
||||
self.is_connected_ = False
|
||||
|
||||
async def connect(self) -> None:
|
||||
self.websocket_ = await websockets.connect(self.ws_url_)
|
||||
self.is_connected_ = True
|
||||
|
||||
async def close(self) -> None:
|
||||
if self.websocket_ is not None:
|
||||
await self.websocket_.close()
|
||||
self.is_connected_ = False
|
||||
|
||||
async def receive_message(self) -> JsonDictT:
|
||||
assert self.websocket_ is not None
|
||||
assert self.is_connected_
|
||||
message = await self.websocket_.recv()
|
||||
message_str = (
|
||||
message.decode("utf-8")
|
||||
if isinstance(message, bytes)
|
||||
else message
|
||||
)
|
||||
res = json.loads(message_str)
|
||||
assert res is not None
|
||||
assert isinstance(res, dict)
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
async def check_connection(cls, url: str) -> bool:
|
||||
try:
|
||||
async with websockets.connect(url) as websocket:
|
||||
result = True
|
||||
except Exception as e:
|
||||
Log.error(f"Unable to connect to {url}: {str(e)}")
|
||||
result = False
|
||||
return result
|
||||
|
||||
class CvttPricerWebSockClient(CvttWebSockClient):
|
||||
# Class members with type hints
|
||||
subscriptions_: Dict[SubscriptionIdT, CvttPricesSubscription]
|
||||
|
||||
def __init__(self, url: str):
|
||||
super().__init__(url)
|
||||
self.subscriptions_ = {}
|
||||
self.logger_ = logging.getLogger(__name__)
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
async def subscribe(
|
||||
self, subscription: CvttPricesSubscription
|
||||
@ -71,11 +107,10 @@ class CvttPricerWebSockClient:
|
||||
|
||||
if not self.is_connected_:
|
||||
try:
|
||||
self.logger_.info(f"Connecting to {self.ws_url_}")
|
||||
self.websocket_ = await websockets.connect(self.ws_url_)
|
||||
self.is_connected_ = True
|
||||
Log.info(f"Connecting to {self.ws_url_}")
|
||||
await self.connect()
|
||||
except Exception as e:
|
||||
self.logger_.error(f"Unable to connect to {self.ws_url_}: {str(e)}")
|
||||
Log.error(f"Unable to connect to {self.ws_url_}: {str(e)}")
|
||||
raise e
|
||||
|
||||
subscr_msg = {
|
||||
@ -109,10 +144,10 @@ class CvttPricerWebSockClient:
|
||||
return False
|
||||
|
||||
if response.get("status") == "success":
|
||||
self.logger_.info(f"Subscription successful: {json.dumps(response)}")
|
||||
Log.info(f"Subscription successful: {json.dumps(response)}")
|
||||
return True
|
||||
elif response.get("status") == "error":
|
||||
self.logger_.error(f"Subscription failed: {response.get('reason')}")
|
||||
Log.error(f"Subscription failed: {response.get('reason')}")
|
||||
return False
|
||||
return False
|
||||
|
||||
@ -121,19 +156,20 @@ class CvttPricerWebSockClient:
|
||||
try:
|
||||
while self.is_connected_:
|
||||
try:
|
||||
message = await self.websocket_.recv()
|
||||
message_str = (
|
||||
message.decode("utf-8")
|
||||
if isinstance(message, bytes)
|
||||
else message
|
||||
)
|
||||
await self.process_message(json.loads(message_str))
|
||||
msg_dict: JsonDictT = await self.receive_message()
|
||||
except websockets.ConnectionClosed:
|
||||
self.logger_.warning("Connection closed")
|
||||
Log.warning("Connection closed")
|
||||
self.is_connected_ = False
|
||||
break
|
||||
except Exception as e:
|
||||
self.logger_.error(f"Error occurred: {str(e)}")
|
||||
Log.error(f"Error occurred: {str(e)}")
|
||||
self.is_connected_ = False
|
||||
await asyncio.sleep(5) # Wait before reconnecting
|
||||
|
||||
await self.process_message(msg_dict)
|
||||
|
||||
except Exception as e:
|
||||
Log.error(f"Error occurred: {str(e)}")
|
||||
self.is_connected_ = False
|
||||
await asyncio.sleep(5) # Wait before reconnecting
|
||||
|
||||
@ -142,13 +178,13 @@ class CvttPricerWebSockClient:
|
||||
if message_type in ["md_aggregate", "historical_md_aggregate"]:
|
||||
subscription_id = message.get("subscr_id")
|
||||
if subscription_id not in self.subscriptions_:
|
||||
self.logger_.warning(f"Unknown subscription id: {subscription_id}")
|
||||
Log.warning(f"Unknown subscription id: {subscription_id}")
|
||||
return
|
||||
|
||||
subscription = self.subscriptions_[subscription_id]
|
||||
await subscription.callback_(message_type, subscription_id, message)
|
||||
else:
|
||||
self.logger_.warning(f"Unknown message type: {message.get('type')}")
|
||||
Log.warning(f"Unknown message type: {message.get('type')}")
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
@ -8,61 +9,26 @@ from cvttpy_base.settings.cvtt_types import JsonDictT
|
||||
from cvttpy_base.tools.base import NamedObject
|
||||
from cvttpy_base.tools.logger import Log
|
||||
|
||||
from cvtt_client.mkt_data import CvttPricerWebSockClient, CvttPricesSubscription, MessageTypeT, SubscriptionIdT
|
||||
from pt_strategy.live.ti_sender import TradingInstructionsSender
|
||||
from pt_strategy.model_data_policy import ModelDataPolicy
|
||||
from pt_strategy.pt_market_data import PtMarketData, RealTimeMarketData
|
||||
from pt_strategy.pt_market_data import RealTimeMarketData
|
||||
from pt_strategy.pt_model import Prediction
|
||||
from pt_strategy.trading_pair import PairState, TradingPair
|
||||
|
||||
'''
|
||||
"""
|
||||
--config=pair.cfg
|
||||
--pair=PAIR-BTC-USDT:COINBASE_AT,PAIR-ETH-USDT:COINBASE_AT
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
class PtMktDataClient(NamedObject):
|
||||
live_strategy_: PtLiveStrategy
|
||||
pricer_client_: CvttPricerWebSockClient
|
||||
subscriptions_: List[CvttPricesSubscription]
|
||||
|
||||
def __init__(self, live_strategy: PtLiveStrategy):
|
||||
self.live_strategy_ = live_strategy
|
||||
|
||||
async def start(self, subscription: CvttPricesSubscription) -> None:
|
||||
pricer_url = self.live_strategy_.config_.get("pricer_url", None) #, "ws://localhost:12346/ws")
|
||||
assert pricer_url is not None, "pricer_url is not found in config"
|
||||
self.pricer_client_ = CvttPricerWebSockClient(url=pricer_url)
|
||||
|
||||
await self._subscribe()
|
||||
|
||||
async def _subscribe(self) -> None:
|
||||
pair: TradingPair = self.live_strategy_.trading_pair_
|
||||
for instrument in pair.instruments_:
|
||||
await self.pricer_client_.subscribe(CvttPricesSubscription(
|
||||
exchange_config_name=instrument["exchange_config_name"],
|
||||
instrument_id=instrument["instrument_id"],
|
||||
interval_sec=60,
|
||||
history_depth_sec=60*60*24,
|
||||
callback=partial(self.on_message, instrument_id=instrument["instrument_id"])
|
||||
))
|
||||
|
||||
async def on_message(self, message_type: MessageTypeT, subscr_id: SubscriptionIdT, message: Dict, instrument_id: str) -> None:
|
||||
Log.info(f"{self.fname()}: {message_type=} {subscr_id=} {instrument_id}")
|
||||
aggr: JsonDictT
|
||||
if message_type == "md_aggregate":
|
||||
aggr = message.get("md_aggregate", {})
|
||||
await self.live_strategy_.on_mkt_data_update(aggr)
|
||||
# print(f"[{aggr['tstamp'][:19]}] *** RLTM *** {message}")
|
||||
elif message_type == "historical_md_aggregate":
|
||||
aggr = message.get("historical_data", {})
|
||||
await self.live_strategy_.on_mkt_data_hist_snapshot(aggr)
|
||||
# print(f"[{aggr['tstamp'][:19]}] *** HIST *** {aggr}")
|
||||
else:
|
||||
Log.info(f"Unknown message type: {message_type}")
|
||||
|
||||
async def run(self) -> None:
|
||||
await self.pricer_client_.run()
|
||||
class TradingInstructionType(Enum):
|
||||
TARGET_POSITION = "TARGET_POSITION"
|
||||
|
||||
@dataclass
|
||||
class TradingInstruction(NamedObject):
|
||||
type_: TradingInstructionType
|
||||
exch_instr_: ExchangeInstrument
|
||||
specifics_: Dict[str, Any]
|
||||
|
||||
|
||||
class PtLiveStrategy(NamedObject):
|
||||
@ -70,7 +36,7 @@ class PtLiveStrategy(NamedObject):
|
||||
trading_pair_: TradingPair
|
||||
model_data_policy_: ModelDataPolicy
|
||||
pt_mkt_data_: RealTimeMarketData
|
||||
pt_mkt_data_client_: PtMktDataClient
|
||||
ti_sender_: TradingInstructionsSender
|
||||
|
||||
# for presentation: history of prediction values and trading signals
|
||||
predictions_: pd.DataFrame
|
||||
@ -80,12 +46,14 @@ class PtLiveStrategy(NamedObject):
|
||||
self,
|
||||
config: Dict[str, Any],
|
||||
instruments: List[Dict[str, str]],
|
||||
ti_sender: TradingInstructionsSender,
|
||||
):
|
||||
|
||||
self.config_ = config
|
||||
self.trading_pair_ = TradingPair(config=config, instruments=instruments)
|
||||
self.predictions_ = pd.DataFrame()
|
||||
self.trading_signals_ = pd.DataFrame()
|
||||
self.ti_sender_ = ti_sender
|
||||
|
||||
import copy
|
||||
|
||||
@ -94,8 +62,15 @@ class PtLiveStrategy(NamedObject):
|
||||
config_copy["instruments"] = instruments
|
||||
self.pt_mkt_data_ = RealTimeMarketData(config=config_copy)
|
||||
self.model_data_policy_ = ModelDataPolicy.create(
|
||||
config, is_real_time=True,pair=self.trading_pair_
|
||||
config, is_real_time=True, pair=self.trading_pair_
|
||||
)
|
||||
self.open_threshold_ = self.config_.get("dis-equilibrium_open_trshld", 0.0)
|
||||
assert self.open_threshold_ > 0, "open_threshold must be greater than 0"
|
||||
self.close_threshold_ = self.config_.get("dis-equilibrium_close_trshld", 0.0)
|
||||
assert self.close_threshold_ > 0, "close_threshold must be greater than 0"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.classname()}: trading_pair={self.trading_pair_}, mdp={self.model_data_policy_.__class__.__name__}, "
|
||||
|
||||
async def on_mkt_data_hist_snapshot(self, aggr: JsonDictT) -> None:
|
||||
Log.info(f"on_mkt_data_hist_snapshot: {aggr}")
|
||||
@ -107,63 +82,106 @@ class PtLiveStrategy(NamedObject):
|
||||
if market_data_df is not None:
|
||||
self.trading_pair_.market_data_ = market_data_df
|
||||
self.model_data_policy_.advance()
|
||||
prediction = self.trading_pair_.run(market_data_df, self.model_data_policy_.advance())
|
||||
self.predictions_ = pd.concat([self.predictions_, prediction.to_df()], ignore_index=True)
|
||||
trades = self._create_trades(prediction=prediction, last_row=market_data_df.iloc[-1])
|
||||
prediction = self.trading_pair_.run(
|
||||
market_data_df, self.model_data_policy_.advance()
|
||||
)
|
||||
self.predictions_ = pd.concat(
|
||||
[self.predictions_, prediction.to_df()], ignore_index=True
|
||||
)
|
||||
|
||||
trading_instructions: List[TradingInstruction] = (
|
||||
self._create_trading_instructions(
|
||||
prediction=prediction, last_row=market_data_df.iloc[-1]
|
||||
)
|
||||
)
|
||||
if len(trading_instructions) > 0:
|
||||
await self._send_trading_instructions(trading_instructions)
|
||||
# trades = self._create_trades(prediction=prediction, last_row=market_data_df.iloc[-1])
|
||||
# URGENT implement this
|
||||
pass
|
||||
|
||||
async def _send_trading_instructions(
|
||||
self, trading_instructions: pd.DataFrame
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
async def run(self) -> None:
|
||||
await self.pt_mkt_data_client_.run()
|
||||
|
||||
def _create_trades(
|
||||
def _create_trading_instructions(
|
||||
self, prediction: Prediction, last_row: pd.Series
|
||||
) -> Optional[pd.DataFrame]:
|
||||
) -> List[TradingInstruction]:
|
||||
pair = self.trading_pair_
|
||||
trades = None
|
||||
trd_instructions: List[TradingInstruction] = []
|
||||
|
||||
open_threshold = self.config_["dis-equilibrium_open_trshld"]
|
||||
close_threshold = self.config_["dis-equilibrium_close_trshld"]
|
||||
scaled_disequilibrium = prediction.scaled_disequilibrium_
|
||||
abs_scaled_disequilibrium = abs(scaled_disequilibrium)
|
||||
|
||||
if pair.user_data_["state"] in [
|
||||
PairState.INITIAL,
|
||||
PairState.CLOSE,
|
||||
PairState.CLOSE_POSITION,
|
||||
PairState.CLOSE_STOP_LOSS,
|
||||
PairState.CLOSE_STOP_PROFIT,
|
||||
]:
|
||||
if abs_scaled_disequilibrium >= open_threshold:
|
||||
trades = self._create_open_trades(
|
||||
if pair.is_closed():
|
||||
if abs_scaled_disequilibrium >= self.open_threshold_:
|
||||
trd_instructions = self._create_open_trade_instructions(
|
||||
pair, row=last_row, prediction=prediction
|
||||
)
|
||||
if trades is not None:
|
||||
trades["status"] = PairState.OPEN.name
|
||||
print(f"OPEN TRADES:\n{trades}")
|
||||
pair.user_data_["state"] = PairState.OPEN
|
||||
pair.on_open_trades(trades)
|
||||
|
||||
elif pair.user_data_["state"] == PairState.OPEN:
|
||||
if abs_scaled_disequilibrium <= close_threshold:
|
||||
trades = self._create_close_trades(
|
||||
elif pair.is_open():
|
||||
if abs_scaled_disequilibrium <= self.close_threshold_:
|
||||
trd_instructions = self._create_close_trade_instructions(
|
||||
pair, row=last_row, prediction=prediction
|
||||
)
|
||||
if trades is not None:
|
||||
trades["status"] = PairState.CLOSE.name
|
||||
print(f"CLOSE TRADES:\n{trades}")
|
||||
pair.user_data_["state"] = PairState.CLOSE
|
||||
pair.on_close_trades(trades)
|
||||
elif pair.to_stop_close_conditions(predicted_row=last_row):
|
||||
trades = self._create_close_trades(pair, row=last_row)
|
||||
if trades is not None:
|
||||
trades["status"] = pair.user_data_["stop_close_state"].name
|
||||
print(f"STOP CLOSE TRADES:\n{trades}")
|
||||
pair.user_data_["state"] = pair.user_data_["stop_close_state"]
|
||||
pair.on_close_trades(trades)
|
||||
trd_instructions = self._create_close_trade_instructions(
|
||||
pair, row=last_row
|
||||
)
|
||||
|
||||
return trades
|
||||
return trd_instructions
|
||||
|
||||
def _create_open_trade_instructions(
|
||||
self, pair: TradingPair, row: pd.Series, prediction: Prediction
|
||||
) -> List[TradingInstruction]:
|
||||
scaled_disequilibrium = prediction.scaled_disequilibrium_
|
||||
|
||||
if scaled_disequilibrium > 0:
|
||||
side_a = "SELL"
|
||||
trd_inst_a = TradingInstruction(
|
||||
type=TradingInstructionType.TARGET_POSITION,
|
||||
exch_instr=pair.get_instrument_a(),
|
||||
specifics={"side": "SELL", "strength": -1},
|
||||
)
|
||||
side_b = "BUY"
|
||||
else:
|
||||
side_a = "BUY"
|
||||
side_b = "SELL"
|
||||
|
||||
# save closing sides
|
||||
pair.user_data_["open_side_a"] = side_a # used in oustanding positions
|
||||
pair.user_data_["open_side_b"] = side_b
|
||||
pair.user_data_["open_px_a"] = px_a
|
||||
pair.user_data_["open_px_b"] = px_b
|
||||
pair.user_data_["open_tstamp"] = tstamp
|
||||
|
||||
pair.user_data_["close_side_a"] = side_b # used for closing trades
|
||||
pair.user_data_["close_side_b"] = side_a
|
||||
|
||||
# create opening trades
|
||||
df.loc[len(df)] = {
|
||||
"time": tstamp,
|
||||
"symbol": pair.symbol_a_,
|
||||
"side": side_a,
|
||||
"action": "OPEN",
|
||||
"price": px_a,
|
||||
"disequilibrium": diseqlbrm,
|
||||
"signed_scaled_disequilibrium": scaled_disequilibrium,
|
||||
"scaled_disequilibrium": abs(scaled_disequilibrium),
|
||||
# "pair": pair,
|
||||
}
|
||||
df.loc[len(df)] = {
|
||||
"time": tstamp,
|
||||
"symbol": pair.symbol_b_,
|
||||
"side": side_b,
|
||||
"action": "OPEN",
|
||||
"price": px_b,
|
||||
"disequilibrium": diseqlbrm,
|
||||
"scaled_disequilibrium": abs(scaled_disequilibrium),
|
||||
"signed_scaled_disequilibrium": scaled_disequilibrium,
|
||||
# "pair": pair,
|
||||
}
|
||||
return df
|
||||
|
||||
def _handle_outstanding_positions(self) -> Optional[pd.DataFrame]:
|
||||
trades = None
|
||||
@ -328,4 +346,3 @@ class PtLiveStrategy(NamedObject):
|
||||
del pair.user_data_["open_side_a"]
|
||||
del pair.user_data_["open_side_b"]
|
||||
return df
|
||||
|
||||
90
lib/pt_strategy/live/pricer_md_client.py
Normal file
90
lib/pt_strategy/live/pricer_md_client.py
Normal file
@ -0,0 +1,90 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import partial
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import pandas as pd
|
||||
from cvttpy_base.tools.app import App
|
||||
from cvttpy_base.tools.config import Config
|
||||
from cvttpy_base.settings.cvtt_types import JsonDictT
|
||||
from cvttpy_base.tools.base import NamedObject
|
||||
from cvttpy_base.tools.logger import Log
|
||||
|
||||
from cvtt_client.mkt_data import (
|
||||
CvttPricerWebSockClient,
|
||||
CvttPricesSubscription,
|
||||
MessageTypeT,
|
||||
SubscriptionIdT,
|
||||
)
|
||||
from pt_strategy.live.live_strategy import PtLiveStrategy
|
||||
from pt_strategy.trading_pair import TradingPair
|
||||
|
||||
"""
|
||||
--config=pair.cfg
|
||||
--pair=PAIR-BTC-USDT:COINBASE_AT,PAIR-ETH-USDT:COINBASE_AT
|
||||
"""
|
||||
|
||||
|
||||
class PtMktDataClient(NamedObject):
|
||||
config_: Config
|
||||
live_strategy_: PtLiveStrategy
|
||||
pricer_client_: CvttPricerWebSockClient
|
||||
subscriptions_: List[CvttPricesSubscription]
|
||||
|
||||
def __init__(self, live_strategy: PtLiveStrategy, pricer_config: Config):
|
||||
self.config_ = pricer_config
|
||||
self.live_strategy_ = live_strategy
|
||||
|
||||
App.instance().add_call(App.Stage.Start, self._on_start())
|
||||
App.instance().add_call(App.Stage.Run, self.run())
|
||||
|
||||
async def _on_start(self) -> None:
|
||||
pricer_url = self.config_.get_value("pricer_url")
|
||||
assert pricer_url is not None, "pricer_url is not found in config"
|
||||
self.pricer_client_ = CvttPricerWebSockClient(url=pricer_url)
|
||||
|
||||
|
||||
async def _subscribe(self) -> None:
|
||||
history_depth_sec = self.config_.get_value("history_depth_sec", 86400)
|
||||
interval_sec = self.config_.get_value("interval_sec", 60)
|
||||
|
||||
pair: TradingPair = self.live_strategy_.trading_pair_
|
||||
subscriptions = [CvttPricesSubscription(
|
||||
exchange_config_name=instrument["exchange_config_name"],
|
||||
instrument_id=instrument["instrument_id"],
|
||||
interval_sec=interval_sec,
|
||||
history_depth_sec=history_depth_sec,
|
||||
callback=partial(
|
||||
self.on_message, instrument_id=instrument["instrument_id"]
|
||||
),
|
||||
) for instrument in pair.instruments_]
|
||||
|
||||
for subscription in subscriptions:
|
||||
Log.info(f"{self.fname()} Subscribing to {subscription}")
|
||||
await self.pricer_client_.subscribe(subscription)
|
||||
|
||||
async def on_message(
|
||||
self,
|
||||
message_type: MessageTypeT,
|
||||
subscr_id: SubscriptionIdT,
|
||||
message: Dict,
|
||||
instrument_id: str,
|
||||
) -> None:
|
||||
Log.info(f"{self.fname()}: {message_type=} {subscr_id=} {instrument_id}")
|
||||
aggr: JsonDictT
|
||||
if message_type == "md_aggregate":
|
||||
aggr = message.get("md_aggregate", {})
|
||||
await self.live_strategy_.on_mkt_data_update(aggr)
|
||||
elif message_type == "historical_md_aggregate":
|
||||
aggr = message.get("historical_data", {})
|
||||
await self.live_strategy_.on_mkt_data_hist_snapshot(aggr)
|
||||
else:
|
||||
Log.info(f"Unknown message type: {message_type}")
|
||||
|
||||
async def run(self) -> None:
|
||||
if not await CvttPricerWebSockClient.check_connection(self.pricer_client_.ws_url_):
|
||||
Log.error(f"Unable to connect to {self.pricer_client_.ws_url_}")
|
||||
raise Exception(f"Unable to connect to {self.pricer_client_.ws_url_}")
|
||||
await self._subscribe()
|
||||
|
||||
await self.pricer_client_.run()
|
||||
85
lib/pt_strategy/live/ti_sender.py
Normal file
85
lib/pt_strategy/live/ti_sender.py
Normal file
@ -0,0 +1,85 @@
|
||||
from enum import Enum
|
||||
from typing import Dict, Any, Tuple
|
||||
import time
|
||||
# import aiohttp
|
||||
from cvttpy_base.tools.app import App
|
||||
from cvttpy_base.tools.base import NamedObject
|
||||
from cvttpy_base.tools.config import Config
|
||||
from cvttpy_base.tools.logger import Log
|
||||
from cvttpy_base.tools.web.rest_client import REST_RequestProcessor
|
||||
from cvttpy_base.tools.timeutils import NanoPerSec
|
||||
from cvttpy_base.tools.timer import Timer
|
||||
|
||||
|
||||
class TradingInstructionsSender(NamedObject):
|
||||
|
||||
class TradingInstType(str, Enum):
|
||||
TARGET_POSITION = "TARGET_POSITION"
|
||||
DIRECT_ORDER = "DIRECT_ORDER"
|
||||
MARKET_MAKING = "MARKET_MAKING"
|
||||
NONE = "NONE"
|
||||
|
||||
config_: Config
|
||||
ti_method_: str
|
||||
ti_url_: str
|
||||
health_check_method_: str
|
||||
health_check_url_: str
|
||||
|
||||
def __init__(self, config: Config):
|
||||
self.config_ = config
|
||||
base_url = config.get_value("url", "ws://localhost:12346/ws")
|
||||
|
||||
self.book_id_ = config.get_value("book_id", "")
|
||||
assert self.book_id_, "book_id is required"
|
||||
|
||||
self.strategy_id_ = config.get_value("strategy_id", "")
|
||||
assert self.strategy_id_, "strategy_id is required"
|
||||
|
||||
endpoint_uri = config.get_value("ti_endpoint/url", "/trading_instructions")
|
||||
endpoint_method = config.get_value("ti_endpoint/method", "POST")
|
||||
|
||||
health_check_uri = config.get_value("health_check_endpoint/url", "/ping")
|
||||
health_check_method = config.get_value("health_check_endpoint/method", "GET")
|
||||
|
||||
|
||||
|
||||
self.ti_method_ = endpoint_method
|
||||
self.ti_url_ = f"{base_url}{endpoint_uri}"
|
||||
|
||||
self.health_check_method_ = health_check_method
|
||||
self.health_check_url_ = f"{base_url}{health_check_uri}"
|
||||
|
||||
App.instance().add_call(App.Stage.Start, self._set_health_check_timer(), can_run_now=True)
|
||||
|
||||
async def _set_health_check_timer(self) -> None:
|
||||
# TODO: configurable interval
|
||||
self.health_check_timer_ = Timer(is_periodic=True, period_interval=15, start_in_sec=0, func=self._health_check)
|
||||
Log.info(f"{self.fname()} Health check timer set to 15 seconds")
|
||||
|
||||
async def _health_check(self) -> None:
|
||||
rqst = REST_RequestProcessor(method=self.health_check_method_, url=self.health_check_url_)
|
||||
async with rqst as (status, msg, headers):
|
||||
if status != 200:
|
||||
Log.error(f"{self.fname()} CVTT Service is not responding")
|
||||
|
||||
async def send_tgt_positions(self, strength: float, base_asset: str, quote_asset: str) -> Tuple[int, str]:
|
||||
instr = {
|
||||
"type": self.TradingInstType.TARGET_POSITION.value,
|
||||
"book_id": self.book_id_,
|
||||
"strategy_id": self.strategy_id_,
|
||||
"issued_ts_ns": int(time.time() * NanoPerSec),
|
||||
"data": {
|
||||
"strength": strength,
|
||||
"base_asset": base_asset,
|
||||
"quote_asset": quote_asset,
|
||||
"user_data": {},
|
||||
},
|
||||
}
|
||||
|
||||
rqst = REST_RequestProcessor(method=self.ti_method_, url=self.ti_url_, params=instr)
|
||||
async with rqst as (status, msg, headers):
|
||||
if status != 200:
|
||||
raise ConnectionError(f"Failed to send trading instructions: {msg}")
|
||||
return (status, msg)
|
||||
|
||||
|
||||
@ -18,6 +18,20 @@ class PairState(Enum):
|
||||
CLOSE_STOP_LOSS = 5
|
||||
CLOSE_STOP_PROFIT = 6
|
||||
|
||||
|
||||
def get_symbol(instrument: Dict[str, str]) -> str:
|
||||
if "symbol" in instrument:
|
||||
return instrument["symbol"]
|
||||
elif "instrument_id" in instrument:
|
||||
instrument_id = instrument["instrument_id"]
|
||||
instrument_pfx = instrument_id[:instrument_id.find("-") + 1]
|
||||
symbol = instrument_id[len(instrument_pfx):]
|
||||
instrument["symbol"] = symbol
|
||||
instrument["instrument_id_pfx"] = instrument_pfx
|
||||
return symbol
|
||||
else:
|
||||
raise ValueError(f"Invalid instrument: {instrument}, missing symbol or instrument_id")
|
||||
|
||||
class TradingPair:
|
||||
config_: Dict[str, Any]
|
||||
market_data_: pd.DataFrame
|
||||
@ -42,14 +56,32 @@ class TradingPair:
|
||||
|
||||
self.config_ = config
|
||||
self.instruments_ = instruments
|
||||
self.symbol_a_ = instruments[0]["symbol"]
|
||||
self.symbol_b_ = instruments[1]["symbol"]
|
||||
self.symbol_a_ = get_symbol(instruments[0])
|
||||
self.symbol_b_ = get_symbol(instruments[1])
|
||||
self.model_ = PairsTradingModel.create(config)
|
||||
self.stat_model_price_ = config["stat_model_price"]
|
||||
self.user_data_ = {
|
||||
"state": PairState.INITIAL,
|
||||
}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"{self.__class__.__name__}:"
|
||||
f" symbol_a={self.symbol_a_},"
|
||||
f" symbol_b={self.symbol_b_},"
|
||||
f" model={self.model_.__class__.__name__}"
|
||||
)
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
return self.user_data_["state"] in [
|
||||
PairState.CLOSE,
|
||||
PairState.CLOSE_POSITION,
|
||||
PairState.CLOSE_STOP_LOSS,
|
||||
PairState.CLOSE_STOP_PROFIT,
|
||||
]
|
||||
def is_open(self) -> bool:
|
||||
return self.user_data_["state"] == PairState.OPEN
|
||||
|
||||
def colnames(self) -> List[str]:
|
||||
return [
|
||||
f"{self.stat_model_price_}_{self.symbol_a_}",
|
||||
|
||||
@ -78,6 +78,7 @@ scipy<1.13.0
|
||||
seaborn>=0.13.2
|
||||
SecretStorage>=3.3.1
|
||||
setproctitle>=1.2.2
|
||||
simpleeval>=1.0.3
|
||||
six>=1.16.0
|
||||
soupsieve>=2.3.1
|
||||
ssh-import-id>=5.11
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user