dev progress

This commit is contained in:
Oleg Sheynin 2025-12-31 08:03:26 +00:00
parent 121c85def0
commit 69a0b19e9f
5 changed files with 117 additions and 75 deletions

View File

@ -13,6 +13,7 @@ from cvttpy_tools.logger import Log
from cvttpy_trading.trading.instrument import ExchangeInstrument from cvttpy_trading.trading.instrument import ExchangeInstrument
from cvttpy_trading.trading.active_instruments import Instruments from cvttpy_trading.trading.active_instruments import Instruments
from cvttpy_trading.trading.mkt_data.md_summary import MdTradesAggregate from cvttpy_trading.trading.mkt_data.md_summary import MdTradesAggregate
from cvttpy_trading.trading.exchange_config import ExchangeAccounts
# --- # ---
from pairs_trading.lib.pt_strategy.live.live_strategy import PtLiveStrategy from pairs_trading.lib.pt_strategy.live.live_strategy import PtLiveStrategy
from pairs_trading.lib.live.mkt_data_client import CvttRestMktDataClient, MdSummary from pairs_trading.lib.live.mkt_data_client import CvttRestMktDataClient, MdSummary
@ -42,7 +43,7 @@ UpdateMdCbT = Callable[[MdTradesAggregate], Coroutine]
class PairsTrader(NamedObject): class PairsTrader(NamedObject):
config_: CvttAppConfig config_: CvttAppConfig
instruments_: List[JsonDictT] instruments_: List[ExchangeInstrument]
live_strategy_: PtLiveStrategy live_strategy_: PtLiveStrategy
pricer_client_: CvttRestMktDataClient pricer_client_: CvttRestMktDataClient
@ -72,20 +73,22 @@ class PairsTrader(NamedObject):
if not instr_str: if not instr_str:
raise ValueError("Pair is required") raise ValueError("Pair is required")
instr_list = instr_str.split(",") instr_list = instr_str.split(",")
assert len(instr_list) == 2, "Only two instruments are supported"
for instr in instr_list: for instr in instr_list:
instr_parts = instr.split(":") instr_parts = instr.split(":")
if len(instr_parts) != 2: if len(instr_parts) != 2:
raise ValueError(f"Invalid pair format: {instr}") raise ValueError(f"Invalid pair format: {instr}")
instrument_id = instr_parts[0] instrument_id = instr_parts[0]
exch_acct = instr_parts[1] exch_acct = instr_parts[1]
exch_inst = Instruments.instance() exch_inst = ExchangeAccounts.instance().get_exchange_instrument(exch_acct=exch_acct, instrument_id=instrument_id)
self.instruments_.append({
"exch_acct": exch_acct,
"instrument_id": instrument_id
})
assert len(self.instruments_) == 2, "Only two instruments are supported" assert exch_inst is not None, f"No ExchangeInstrument for {instr}"
Log.info(f"{self.fname()} Instruments: {self.instruments_}") exch_inst.user_data_["exch_acct"] = exch_acct
self.instruments_.append(exch_inst)
Log.info(f"{self.fname()} Instruments: {self.instruments_[0].details_short()} <==> {self.instruments_[1].details_short()}")
# ------- CREATE CVTT CLIENT ------- # ------- CREATE CVTT CLIENT -------
@ -95,7 +98,7 @@ class PairsTrader(NamedObject):
# ------- CREATE STRATEGY ------- # ------- CREATE STRATEGY -------
strategy_config = self.config_.get_value("strategy_config", {}) strategy_config = self.config_.get_subconfig("strategy_config", Config({}))
self.live_strategy_ = PtLiveStrategy( self.live_strategy_ = PtLiveStrategy(
config=strategy_config, config=strategy_config,
instruments=self.instruments_, instruments=self.instruments_,
@ -104,7 +107,7 @@ class PairsTrader(NamedObject):
Log.info(f"{self.fname()} Strategy created: {self.live_strategy_}") Log.info(f"{self.fname()} Strategy created: {self.live_strategy_}")
# # ------- CREATE PRICER CLIENT ------- # # ------- CREATE PRICER CLIENT -------
self.pricer_client_ = CvttRestMktDataClient(self.config_) self.pricer_client_ = CvttRestMktDataClient(config=self.config_)
# ------- CREATE TRADER CLIENT ------- # ------- CREATE TRADER CLIENT -------
# URGENT CREATE TRADER CLIENT # URGENT CREATE TRADER CLIENT
@ -118,9 +121,10 @@ class PairsTrader(NamedObject):
# URGENT CREATE REST SERVER for dashboard communications # URGENT CREATE REST SERVER for dashboard communications
async def subscribe_md(self) -> None: async def subscribe_md(self) -> None:
for inst in self.instruments_: for exch_inst in self.instruments_:
exch_acct = inst.get("exch_acct", "?exch_acct?") exch_acct = exch_inst.user_data_.get("exch_acct", "?exch_acct?")
instrument_id = inst.get("instrument_id", "?instrument_id?") instrument_id = exch_inst.instrument_id()
await self.pricer_client_.add_subscription( await self.pricer_client_.add_subscription(
exch_acct=exch_acct, exch_acct=exch_acct,
instrument_id=instrument_id, instrument_id=instrument_id,
@ -129,8 +133,10 @@ class PairsTrader(NamedObject):
callback=self._on_md_summary callback=self._on_md_summary
) )
def _on_md_summary(self, history: List[MdSummary]) -> None: async def _on_md_summary(self, history: List[MdSummary]) -> None:
pass # URGENT # depth = len(history)
# if depth < 2:
pass # URGENT
async def run(self) -> None: async def run(self) -> None:
Log.info(f"{self.fname()} ...") Log.info(f"{self.fname()} ...")

View File

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from typing import Callable, Dict, Any, List, Optional, Set from typing import Callable, Coroutine, Dict, Any, List, Optional, Set
import requests import requests
@ -60,7 +60,7 @@ class MdSummary(HistMdBar):
) )
return res return res
MdSummaryCallbackT = Callable[[List[MdSummary]], None] MdSummaryCallbackT = Callable[[List[MdSummary]], Coroutine]
class MdSummaryCollector(NamedObject): class MdSummaryCollector(NamedObject):
sender_: RESTSender sender_: RESTSender
@ -134,7 +134,7 @@ class MdSummaryCollector(NamedObject):
Log.error(f"{self.fname()}: Timer is already started") Log.error(f"{self.fname()}: Timer is already started")
return return
self.history_ = self.get_history() self.history_ = self.get_history()
self.run_callbacks() await self.run_callbacks()
self.set_timer() self.set_timer()
def set_timer(self): def set_timer(self):
@ -158,11 +158,11 @@ class MdSummaryCollector(NamedObject):
Log.info(f"{self.fname()}: Received {last}. Already Have: {self.history_[-1]}") Log.info(f"{self.fname()}: Received {last}. Already Have: {self.history_[-1]}")
else: else:
self.history_.append(last) self.history_.append(last)
self.run_callbacks() await self.run_callbacks()
self.set_timer() self.set_timer()
def run_callbacks(self) -> None: async def run_callbacks(self) -> None:
[cb(self.history_) for cb in self.callbacks_] [await cb(self.history_) for cb in self.callbacks_]
def stop(self) -> None: def stop(self) -> None:
if self.timer_: if self.timer_:
@ -203,7 +203,7 @@ if __name__ == "__main__":
config = Config(json_src={"cvtt_base_url": "http://cvtt-tester-01.cvtt.vpn:23456"}) config = Config(json_src={"cvtt_base_url": "http://cvtt-tester-01.cvtt.vpn:23456"})
# config = Config(json_src={"cvtt_base_url": "http://dev-server-02.cvtt.vpn:23456"}) # config = Config(json_src={"cvtt_base_url": "http://dev-server-02.cvtt.vpn:23456"})
def _calback(history: List[MdSummary]) -> None: async def _calback(history: List[MdSummary]) -> None:
Log.info(f"MdSummary Hist Length is {len(history)}. Last summary: {history[-1] if len(history) > 0 else '[]'}") Log.info(f"MdSummary Hist Length is {len(history)}. Last summary: {history[-1] if len(history) > 0 else '[]'}")

View File

@ -1,22 +1,28 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional, cast
from enum import Enum from enum import Enum
import pandas as pd import pandas as pd
# --- # ---
from cvttpy_tools.base import NamedObject from cvttpy_tools.base import NamedObject
from cvttpy_tools.app import App from cvttpy_tools.app import App
from cvttpy_tools.config import Config
from cvttpy_tools.settings.cvtt_types import IntervalSecT from cvttpy_tools.settings.cvtt_types import IntervalSecT
from cvttpy_tools.timeutils import SecPerHour
# --- # ---
from cvttpy_trading.trading.instrument import ExchangeInstrument from cvttpy_trading.trading.instrument import ExchangeInstrument
from cvttpy_trading.trading.mkt_data.md_summary import MdTradesAggregate from cvttpy_trading.trading.mkt_data.md_summary import MdTradesAggregate
# --- # ---
from pairs_trading.lib.pt_strategy.model_data_policy import ModelDataPolicy from pairs_trading.lib.pt_strategy.model_data_policy import ModelDataPolicy
from pairs_trading.lib.pt_strategy.pt_model import Prediction from pairs_trading.lib.pt_strategy.pt_model import Prediction
from pairs_trading.lib.pt_strategy.trading_pair import PairState, TradingPair from pairs_trading.lib.pt_strategy.trading_pair import PairState, TradingPair
from pairs_trading.apps.pairs_trader import PairsTrader from pairs_trading.apps.pairs_trader import PairsTrader
""" """
--config=pair.cfg --config=pair.cfg
--pair=PAIR-BTC-USDT:COINBASE_AT,PAIR-ETH-USDT:COINBASE_AT --pair=PAIR-BTC-USDT:COINBASE_AT,PAIR-ETH-USDT:COINBASE_AT
@ -26,6 +32,7 @@ from pairs_trading.apps.pairs_trader import PairsTrader
class TradingInstructionType(Enum): class TradingInstructionType(Enum):
TARGET_POSITION = "TARGET_POSITION" TARGET_POSITION = "TARGET_POSITION"
@dataclass @dataclass
class TradingInstruction(NamedObject): class TradingInstruction(NamedObject):
type_: TradingInstructionType type_: TradingInstructionType
@ -34,57 +41,84 @@ class TradingInstruction(NamedObject):
class PtLiveStrategy(NamedObject): class PtLiveStrategy(NamedObject):
config_: Dict[str, Any] config_: Config
instruments_: List[ExchangeInstrument]
interval_sec_: IntervalSecT
history_depth_sec_: IntervalSecT
open_threshold_: float
close_threshold_: float
trading_pair_: TradingPair trading_pair_: TradingPair
model_data_policy_: ModelDataPolicy model_data_policy_: ModelDataPolicy
pairs_trader_: PairsTrader pairs_trader_: PairsTrader
# ti_sender_: TradingInstructionsSender # ti_sender_: TradingInstructionsSender
# for presentation: history of prediction values and trading signals # for presentation: history of prediction values and trading signals
predictions_: pd.DataFrame predictions_df_: pd.DataFrame
trading_signals_: pd.DataFrame trading_signals_df_: pd.DataFrame
def __init__( def __init__(
self, self,
config: Dict[str, Any], config: Config,
instruments: List[Dict[str, str]], instruments: List[ExchangeInstrument],
pairs_trader: PairsTrader, pairs_trader: PairsTrader,
): ):
self.config_ = config self.trading_pair_ = TradingPair(
self.trading_pair_ = TradingPair(config=config, instruments=instruments) config=cast(Dict[str, Any], config.data()),
self.predictions_ = pd.DataFrame() instruments=[{"instrument_id": ei.instrument_id()} for ei in instruments],
self.trading_signals_ = pd.DataFrame() )
self.predictions_df_ = pd.DataFrame()
self.trading_signals_df_ = pd.DataFrame()
self.pairs_trader_ = pairs_trader self.pairs_trader_ = pairs_trader
import copy import copy
# modified config must be passed to PtMarketData # modified config must be passed to PtMarketData
config_copy = copy.deepcopy(config) self.config_ = Config(json_src=copy.deepcopy(config.data()))
config_copy["instruments"] = instruments
self.config_ = config_copy self.instruments_ = instruments
App.instance().add_call(stage=App.Stage.Config, func=self._on_config(), can_run_now=True) App.instance().add_call(
stage=App.Stage.Config, func=self._on_config(), can_run_now=True
)
async def _on_config(self) -> None: async def _on_config(self) -> None:
self.interval_sec_ = self.config_.get_value("interval_sec", 0)
self.history_depth_sec_ = (
self.config_.get_value("history_depth_hours", 0) * SecPerHour
)
await self.pairs_trader_.subscribe_md() await self.pairs_trader_.subscribe_md()
self.model_data_policy_ = ModelDataPolicy.create( self.model_data_policy_ = ModelDataPolicy.create(
self.config_, is_real_time=True, pair=self.trading_pair_ self.config_, is_real_time=True, pair=self.trading_pair_
) )
self.open_threshold_ = self.config_.get("dis-equilibrium_open_trshld", 0.0) self.open_threshold_ = self.config_.get_value(
assert self.open_threshold_ > 0, "open_threshold must be greater than 0" "dis-equilibrium_open_trshld", 0.0
self.close_threshold_ = self.config_.get("dis-equilibrium_close_trshld", 0.0) )
assert self.close_threshold_ > 0, "close_threshold must be greater than 0" self.close_threshold_ = self.config_.get_value(
"dis-equilibrium_close_trshld", 0.0
)
assert (
self.open_threshold_ > 0
), "dis-equilibrium_open_trshld must be greater than 0"
assert (
self.close_threshold_ > 0
), "dis-equilibrium_close_trshld must be greater than 0"
def __repr__(self) -> str: def __repr__(self) -> str:
return f"{self.classname()}: trading_pair={self.trading_pair_}, mdp={self.model_data_policy_.__class__.__name__}, " return f"{self.classname()}: trading_pair={self.trading_pair_}, mdp={self.model_data_policy_.__class__.__name__}, "
async def on_mkt_data_hist_snapshot(self, hist_aggr: List[MdTradesAggregate]) -> None: async def on_mkt_data_hist_snapshot(
self, hist_aggr: List[MdTradesAggregate]
) -> None:
# Log.info(f"on_mkt_data_hist_snapshot: {aggr}") # Log.info(f"on_mkt_data_hist_snapshot: {aggr}")
# await self.pt_mkt_data_.on_mkt_data_hist_snapshot(snapshot=aggr) # await self.pt_mkt_data_.on_mkt_data_hist_snapshot(snapshot=aggr)
pass # URGENT PtiveStrategy.on_mkt_data_hist_snapshot() pass # URGENT PtiveStrategy.on_mkt_data_hist_snapshot()
async def on_mkt_data_update(self, aggr: MdTradesAggregate) -> None: async def on_mkt_data_update(self, aggr: MdTradesAggregate) -> None:
# if market_data_df is not None: # if market_data_df is not None:
@ -105,18 +139,18 @@ class PtLiveStrategy(NamedObject):
# if len(trading_instructions) > 0: # if len(trading_instructions) > 0:
# await self._send_trading_instructions(trading_instructions) # await self._send_trading_instructions(trading_instructions)
# # trades = self._create_trades(prediction=prediction, last_row=market_data_df.iloc[-1]) # # trades = self._create_trades(prediction=prediction, last_row=market_data_df.iloc[-1])
pass # URGENT pass # URGENT
def interval_sec(self) -> IntervalSecT: def interval_sec(self) -> IntervalSecT:
return 60 # URGENT use config return self.interval_sec_
def history_depth_sec(self) -> IntervalSecT: def history_depth_sec(self) -> IntervalSecT:
return 3600 * 60 * 2 # URGENT use config return self.history_depth_sec_
async def _send_trading_instructions( async def _send_trading_instructions(
self, trading_instructions: List[TradingInstruction] self, trading_instructions: List[TradingInstruction]
) -> None: ) -> None:
pass # URGENT implement _send_trading_instructions pass # URGENT implement _send_trading_instructions
def _create_trading_instructions( def _create_trading_instructions(
self, prediction: Prediction, last_row: pd.Series self, prediction: Prediction, last_row: pd.Series
@ -135,7 +169,7 @@ class PtLiveStrategy(NamedObject):
elif pair.is_open(): elif pair.is_open():
if abs_scaled_disequilibrium <= self.close_threshold_: if abs_scaled_disequilibrium <= self.close_threshold_:
trd_instructions = self._create_close_trade_instructions( trd_instructions = self._create_close_trade_instructions(
pair, row=last_row #, prediction=prediction pair, row=last_row # , prediction=prediction
) )
elif pair.to_stop_close_conditions(predicted_row=last_row): elif pair.to_stop_close_conditions(predicted_row=last_row):
trd_instructions = self._create_close_trade_instructions( trd_instructions = self._create_close_trade_instructions(
@ -204,16 +238,15 @@ class PtLiveStrategy(NamedObject):
"signed_scaled_disequilibrium": scaled_disequilibrium, "signed_scaled_disequilibrium": scaled_disequilibrium,
# "pair": pair, # "pair": pair,
} }
ti: List[TradingInstruction] =self._create_trading_instructions( ti: List[TradingInstruction] = self._create_trading_instructions(
prediction=prediction, last_row=row prediction=prediction, last_row=row
) )
return ti return ti
def _create_close_trade_instructions( def _create_close_trade_instructions(
self, pair: TradingPair, row: pd.Series #, prediction: Prediction self, pair: TradingPair, row: pd.Series # , prediction: Prediction
) -> List[TradingInstruction]: ) -> List[TradingInstruction]:
return [] # URGENT implement _create_close_trade_instructions return [] # URGENT implement _create_close_trade_instructions
def _handle_outstanding_positions(self) -> Optional[pd.DataFrame]: def _handle_outstanding_positions(self) -> Optional[pd.DataFrame]:
trades = None trades = None
@ -223,7 +256,7 @@ class PtLiveStrategy(NamedObject):
if pair.user_data_["state"] == PairState.OPEN: if pair.user_data_["state"] == PairState.OPEN:
print(f"{pair}: *** Position is NOT CLOSED. ***") print(f"{pair}: *** Position is NOT CLOSED. ***")
# outstanding positions # outstanding positions
if self.config_["close_outstanding_positions"]: if self.config_.key_exists("close_outstanding_positions"):
close_position_row = pd.Series(pair.market_data_.iloc[-2]) close_position_row = pd.Series(pair.market_data_.iloc[-2])
# close_position_row["disequilibrium"] = 0.0 # close_position_row["disequilibrium"] = 0.0
# close_position_row["scaled_disequilibrium"] = 0.0 # close_position_row["scaled_disequilibrium"] = 0.0

View File

@ -8,6 +8,7 @@ from typing import Any, Dict, Optional, cast
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from cvttpy_tools.config import Config
@dataclass @dataclass
class DataWindowParams: class DataWindowParams:
@ -16,22 +17,22 @@ class DataWindowParams:
class ModelDataPolicy(ABC): class ModelDataPolicy(ABC):
config_: Dict[str, Any] config_: Config
current_data_params_: DataWindowParams current_data_params_: DataWindowParams
count_: int count_: int
is_real_time_: bool is_real_time_: bool
def __init__(self, config: Dict[str, Any], *args: Any, **kwargs: Any): def __init__(self, config: Config, *args: Any, **kwargs: Any):
self.config_ = config self.config_ = config
training_size = config.get("training_size", 120) training_size = config.get_value("training_size", 120)
training_start_index = 0 training_start_index = 0
if kwargs.get("is_real_time", False): if kwargs.get("is_real_time", False):
training_size = 120 training_size = 120
training_start_index = 0 training_start_index = 0
else: else:
training_size = config.get("training_size", 120) training_size = config.get_value("training_size", 120)
self.current_data_params_ = DataWindowParams( self.current_data_params_ = DataWindowParams(
training_size=config.get("training_size", 120), training_size=config.get_value("training_size", 120),
training_start_index=0, training_start_index=0,
) )
self.count_ = 0 self.count_ = 0
@ -44,10 +45,10 @@ class ModelDataPolicy(ABC):
return self.current_data_params_ return self.current_data_params_
@staticmethod @staticmethod
def create(config: Dict[str, Any], *args: Any, **kwargs: Any) -> ModelDataPolicy: def create(config: Config, *args: Any, **kwargs: Any) -> ModelDataPolicy:
import importlib import importlib
model_data_policy_class_name = config.get("model_data_policy_class", None) model_data_policy_class_name = config.get_value("model_data_policy_class", None)
assert model_data_policy_class_name is not None assert model_data_policy_class_name is not None
module_name, class_name = model_data_policy_class_name.rsplit(".", 1) module_name, class_name = model_data_policy_class_name.rsplit(".", 1)
module = importlib.import_module(module_name) module = importlib.import_module(module_name)
@ -58,7 +59,7 @@ class ModelDataPolicy(ABC):
class RollingWindowDataPolicy(ModelDataPolicy): class RollingWindowDataPolicy(ModelDataPolicy):
def __init__(self, config: Dict[str, Any], *args: Any, **kwargs: Any): def __init__(self, config: Config, *args: Any, **kwargs: Any):
super().__init__(config, *args, **kwargs) super().__init__(config, *args, **kwargs)
self.count_ = 1 self.count_ = 1
@ -80,16 +81,16 @@ class OptimizedWndDataPolicy(ModelDataPolicy, ABC):
prices_a_: np.ndarray prices_a_: np.ndarray
prices_b_: np.ndarray prices_b_: np.ndarray
def __init__(self, config: Dict[str, Any], *args: Any, **kwargs: Any): def __init__(self, config: Config, *args: Any, **kwargs: Any):
super().__init__(config, *args, **kwargs) super().__init__(config, *args, **kwargs)
assert ( assert (
kwargs.get("pair") is not None kwargs.get("pair") is not None
), "pair must be provided" ), "pair must be provided"
assert ( assert (
"min_training_size" in config and "max_training_size" in config "min_training_size" in config.data() and "max_training_size" in config.data()
), "min_training_size and max_training_size must be provided" ), "min_training_size and max_training_size must be provided"
self.min_training_size_ = cast(int, config.get("min_training_size")) self.min_training_size_ = cast(int, config.get_value("min_training_size"))
self.max_training_size_ = cast(int, config.get("max_training_size")) self.max_training_size_ = cast(int, config.get_value("max_training_size"))
from pairs_trading.lib.pt_strategy.trading_pair import TradingPair from pairs_trading.lib.pt_strategy.trading_pair import TradingPair
self.pair_ = cast(TradingPair, kwargs.get("pair")) self.pair_ = cast(TradingPair, kwargs.get("pair"))
@ -133,7 +134,7 @@ class EGOptimizedWndDataPolicy(OptimizedWndDataPolicy):
# Engle-Granger cointegration test # Engle-Granger cointegration test
*** VERY SLOW *** *** VERY SLOW ***
''' '''
def __init__(self, config: Dict[str, Any], *args: Any, **kwargs: Any): def __init__(self, config: Config, *args: Any, **kwargs: Any):
super().__init__(config, *args, **kwargs) super().__init__(config, *args, **kwargs)
def optimize_window_size(self) -> DataWindowParams: def optimize_window_size(self) -> DataWindowParams:
@ -162,7 +163,7 @@ class EGOptimizedWndDataPolicy(OptimizedWndDataPolicy):
class ADFOptimizedWndDataPolicy(OptimizedWndDataPolicy): class ADFOptimizedWndDataPolicy(OptimizedWndDataPolicy):
# Augmented Dickey-Fuller test # Augmented Dickey-Fuller test
def __init__(self, config: Dict[str, Any], *args: Any, **kwargs: Any): def __init__(self, config: Config, *args: Any, **kwargs: Any):
super().__init__(config, *args, **kwargs) super().__init__(config, *args, **kwargs)
def optimize_window_size(self) -> DataWindowParams: def optimize_window_size(self) -> DataWindowParams:
@ -208,7 +209,7 @@ class ADFOptimizedWndDataPolicy(OptimizedWndDataPolicy):
class JohansenOptdWndDataPolicy(OptimizedWndDataPolicy): class JohansenOptdWndDataPolicy(OptimizedWndDataPolicy):
# Johansen test # Johansen test
def __init__(self, config: Dict[str, Any], *args: Any, **kwargs: Any): def __init__(self, config: Config, *args: Any, **kwargs: Any):
super().__init__(config, *args, **kwargs) super().__init__(config, *args, **kwargs)
def optimize_window_size(self) -> DataWindowParams: def optimize_window_size(self) -> DataWindowParams:

View File

@ -3,6 +3,8 @@ from __future__ import annotations
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import pandas as pd import pandas as pd
from cvttpy_tools.config import Config
from pairs_trading.lib.pt_strategy.model_data_policy import ModelDataPolicy from pairs_trading.lib.pt_strategy.model_data_policy import ModelDataPolicy
from pairs_trading.lib.pt_strategy.pt_market_data import ResearchMarketData from pairs_trading.lib.pt_strategy.pt_market_data import ResearchMarketData
from pairs_trading.lib.pt_strategy.pt_model import Prediction from pairs_trading.lib.pt_strategy.pt_model import Prediction
@ -41,7 +43,7 @@ class PtResearchStrategy:
self.pt_mkt_data_ = ResearchMarketData(config=config_copy) self.pt_mkt_data_ = ResearchMarketData(config=config_copy)
self.pt_mkt_data_.load() self.pt_mkt_data_.load()
self.model_data_policy_ = ModelDataPolicy.create( self.model_data_policy_ = ModelDataPolicy.create(
config, mkt_data=self.pt_mkt_data_.market_data_df_, pair=self.trading_pair_ Config(config_copy), mkt_data=self.pt_mkt_data_.market_data_df_, pair=self.trading_pair_
) )
def outstanding_positions(self) -> List[Dict[str, Any]]: def outstanding_positions(self) -> List[Dict[str, Any]]: