dev progress
This commit is contained in:
parent
121c85def0
commit
69a0b19e9f
@ -13,6 +13,7 @@ from cvttpy_tools.logger import Log
|
|||||||
from cvttpy_trading.trading.instrument import ExchangeInstrument
|
from cvttpy_trading.trading.instrument import ExchangeInstrument
|
||||||
from cvttpy_trading.trading.active_instruments import Instruments
|
from cvttpy_trading.trading.active_instruments import Instruments
|
||||||
from cvttpy_trading.trading.mkt_data.md_summary import MdTradesAggregate
|
from cvttpy_trading.trading.mkt_data.md_summary import MdTradesAggregate
|
||||||
|
from cvttpy_trading.trading.exchange_config import ExchangeAccounts
|
||||||
# ---
|
# ---
|
||||||
from pairs_trading.lib.pt_strategy.live.live_strategy import PtLiveStrategy
|
from pairs_trading.lib.pt_strategy.live.live_strategy import PtLiveStrategy
|
||||||
from pairs_trading.lib.live.mkt_data_client import CvttRestMktDataClient, MdSummary
|
from pairs_trading.lib.live.mkt_data_client import CvttRestMktDataClient, MdSummary
|
||||||
@ -42,7 +43,7 @@ UpdateMdCbT = Callable[[MdTradesAggregate], Coroutine]
|
|||||||
|
|
||||||
class PairsTrader(NamedObject):
|
class PairsTrader(NamedObject):
|
||||||
config_: CvttAppConfig
|
config_: CvttAppConfig
|
||||||
instruments_: List[JsonDictT]
|
instruments_: List[ExchangeInstrument]
|
||||||
|
|
||||||
live_strategy_: PtLiveStrategy
|
live_strategy_: PtLiveStrategy
|
||||||
pricer_client_: CvttRestMktDataClient
|
pricer_client_: CvttRestMktDataClient
|
||||||
@ -72,20 +73,22 @@ class PairsTrader(NamedObject):
|
|||||||
if not instr_str:
|
if not instr_str:
|
||||||
raise ValueError("Pair is required")
|
raise ValueError("Pair is required")
|
||||||
instr_list = instr_str.split(",")
|
instr_list = instr_str.split(",")
|
||||||
|
|
||||||
|
assert len(instr_list) == 2, "Only two instruments are supported"
|
||||||
|
|
||||||
for instr in instr_list:
|
for instr in instr_list:
|
||||||
instr_parts = instr.split(":")
|
instr_parts = instr.split(":")
|
||||||
if len(instr_parts) != 2:
|
if len(instr_parts) != 2:
|
||||||
raise ValueError(f"Invalid pair format: {instr}")
|
raise ValueError(f"Invalid pair format: {instr}")
|
||||||
instrument_id = instr_parts[0]
|
instrument_id = instr_parts[0]
|
||||||
exch_acct = instr_parts[1]
|
exch_acct = instr_parts[1]
|
||||||
exch_inst = Instruments.instance()
|
exch_inst = ExchangeAccounts.instance().get_exchange_instrument(exch_acct=exch_acct, instrument_id=instrument_id)
|
||||||
self.instruments_.append({
|
|
||||||
"exch_acct": exch_acct,
|
|
||||||
"instrument_id": instrument_id
|
|
||||||
})
|
|
||||||
|
|
||||||
assert len(self.instruments_) == 2, "Only two instruments are supported"
|
assert exch_inst is not None, f"No ExchangeInstrument for {instr}"
|
||||||
Log.info(f"{self.fname()} Instruments: {self.instruments_}")
|
exch_inst.user_data_["exch_acct"] = exch_acct
|
||||||
|
self.instruments_.append(exch_inst)
|
||||||
|
|
||||||
|
Log.info(f"{self.fname()} Instruments: {self.instruments_[0].details_short()} <==> {self.instruments_[1].details_short()}")
|
||||||
|
|
||||||
|
|
||||||
# ------- CREATE CVTT CLIENT -------
|
# ------- CREATE CVTT CLIENT -------
|
||||||
@ -95,7 +98,7 @@ class PairsTrader(NamedObject):
|
|||||||
|
|
||||||
|
|
||||||
# ------- CREATE STRATEGY -------
|
# ------- CREATE STRATEGY -------
|
||||||
strategy_config = self.config_.get_value("strategy_config", {})
|
strategy_config = self.config_.get_subconfig("strategy_config", Config({}))
|
||||||
self.live_strategy_ = PtLiveStrategy(
|
self.live_strategy_ = PtLiveStrategy(
|
||||||
config=strategy_config,
|
config=strategy_config,
|
||||||
instruments=self.instruments_,
|
instruments=self.instruments_,
|
||||||
@ -104,7 +107,7 @@ class PairsTrader(NamedObject):
|
|||||||
Log.info(f"{self.fname()} Strategy created: {self.live_strategy_}")
|
Log.info(f"{self.fname()} Strategy created: {self.live_strategy_}")
|
||||||
|
|
||||||
# # ------- CREATE PRICER CLIENT -------
|
# # ------- CREATE PRICER CLIENT -------
|
||||||
self.pricer_client_ = CvttRestMktDataClient(self.config_)
|
self.pricer_client_ = CvttRestMktDataClient(config=self.config_)
|
||||||
|
|
||||||
# ------- CREATE TRADER CLIENT -------
|
# ------- CREATE TRADER CLIENT -------
|
||||||
# URGENT CREATE TRADER CLIENT
|
# URGENT CREATE TRADER CLIENT
|
||||||
@ -118,9 +121,10 @@ class PairsTrader(NamedObject):
|
|||||||
# URGENT CREATE REST SERVER for dashboard communications
|
# URGENT CREATE REST SERVER for dashboard communications
|
||||||
|
|
||||||
async def subscribe_md(self) -> None:
|
async def subscribe_md(self) -> None:
|
||||||
for inst in self.instruments_:
|
for exch_inst in self.instruments_:
|
||||||
exch_acct = inst.get("exch_acct", "?exch_acct?")
|
exch_acct = exch_inst.user_data_.get("exch_acct", "?exch_acct?")
|
||||||
instrument_id = inst.get("instrument_id", "?instrument_id?")
|
instrument_id = exch_inst.instrument_id()
|
||||||
|
|
||||||
await self.pricer_client_.add_subscription(
|
await self.pricer_client_.add_subscription(
|
||||||
exch_acct=exch_acct,
|
exch_acct=exch_acct,
|
||||||
instrument_id=instrument_id,
|
instrument_id=instrument_id,
|
||||||
@ -129,7 +133,9 @@ class PairsTrader(NamedObject):
|
|||||||
callback=self._on_md_summary
|
callback=self._on_md_summary
|
||||||
)
|
)
|
||||||
|
|
||||||
def _on_md_summary(self, history: List[MdSummary]) -> None:
|
async def _on_md_summary(self, history: List[MdSummary]) -> None:
|
||||||
|
# depth = len(history)
|
||||||
|
# if depth < 2:
|
||||||
pass # URGENT
|
pass # URGENT
|
||||||
|
|
||||||
async def run(self) -> None:
|
async def run(self) -> None:
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Callable, Dict, Any, List, Optional, Set
|
from typing import Callable, Coroutine, Dict, Any, List, Optional, Set
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
@ -60,7 +60,7 @@ class MdSummary(HistMdBar):
|
|||||||
)
|
)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
MdSummaryCallbackT = Callable[[List[MdSummary]], None]
|
MdSummaryCallbackT = Callable[[List[MdSummary]], Coroutine]
|
||||||
|
|
||||||
class MdSummaryCollector(NamedObject):
|
class MdSummaryCollector(NamedObject):
|
||||||
sender_: RESTSender
|
sender_: RESTSender
|
||||||
@ -134,7 +134,7 @@ class MdSummaryCollector(NamedObject):
|
|||||||
Log.error(f"{self.fname()}: Timer is already started")
|
Log.error(f"{self.fname()}: Timer is already started")
|
||||||
return
|
return
|
||||||
self.history_ = self.get_history()
|
self.history_ = self.get_history()
|
||||||
self.run_callbacks()
|
await self.run_callbacks()
|
||||||
self.set_timer()
|
self.set_timer()
|
||||||
|
|
||||||
def set_timer(self):
|
def set_timer(self):
|
||||||
@ -158,11 +158,11 @@ class MdSummaryCollector(NamedObject):
|
|||||||
Log.info(f"{self.fname()}: Received {last}. Already Have: {self.history_[-1]}")
|
Log.info(f"{self.fname()}: Received {last}. Already Have: {self.history_[-1]}")
|
||||||
else:
|
else:
|
||||||
self.history_.append(last)
|
self.history_.append(last)
|
||||||
self.run_callbacks()
|
await self.run_callbacks()
|
||||||
self.set_timer()
|
self.set_timer()
|
||||||
|
|
||||||
def run_callbacks(self) -> None:
|
async def run_callbacks(self) -> None:
|
||||||
[cb(self.history_) for cb in self.callbacks_]
|
[await cb(self.history_) for cb in self.callbacks_]
|
||||||
|
|
||||||
def stop(self) -> None:
|
def stop(self) -> None:
|
||||||
if self.timer_:
|
if self.timer_:
|
||||||
@ -203,7 +203,7 @@ if __name__ == "__main__":
|
|||||||
config = Config(json_src={"cvtt_base_url": "http://cvtt-tester-01.cvtt.vpn:23456"})
|
config = Config(json_src={"cvtt_base_url": "http://cvtt-tester-01.cvtt.vpn:23456"})
|
||||||
# config = Config(json_src={"cvtt_base_url": "http://dev-server-02.cvtt.vpn:23456"})
|
# config = Config(json_src={"cvtt_base_url": "http://dev-server-02.cvtt.vpn:23456"})
|
||||||
|
|
||||||
def _calback(history: List[MdSummary]) -> None:
|
async def _calback(history: List[MdSummary]) -> None:
|
||||||
Log.info(f"MdSummary Hist Length is {len(history)}. Last summary: {history[-1] if len(history) > 0 else '[]'}")
|
Log.info(f"MdSummary Hist Length is {len(history)}. Last summary: {history[-1] if len(history) > 0 else '[]'}")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,22 +1,28 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional, cast
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
# ---
|
# ---
|
||||||
from cvttpy_tools.base import NamedObject
|
from cvttpy_tools.base import NamedObject
|
||||||
from cvttpy_tools.app import App
|
from cvttpy_tools.app import App
|
||||||
|
from cvttpy_tools.config import Config
|
||||||
from cvttpy_tools.settings.cvtt_types import IntervalSecT
|
from cvttpy_tools.settings.cvtt_types import IntervalSecT
|
||||||
|
from cvttpy_tools.timeutils import SecPerHour
|
||||||
|
|
||||||
# ---
|
# ---
|
||||||
from cvttpy_trading.trading.instrument import ExchangeInstrument
|
from cvttpy_trading.trading.instrument import ExchangeInstrument
|
||||||
from cvttpy_trading.trading.mkt_data.md_summary import MdTradesAggregate
|
from cvttpy_trading.trading.mkt_data.md_summary import MdTradesAggregate
|
||||||
|
|
||||||
# ---
|
# ---
|
||||||
from pairs_trading.lib.pt_strategy.model_data_policy import ModelDataPolicy
|
from pairs_trading.lib.pt_strategy.model_data_policy import ModelDataPolicy
|
||||||
from pairs_trading.lib.pt_strategy.pt_model import Prediction
|
from pairs_trading.lib.pt_strategy.pt_model import Prediction
|
||||||
from pairs_trading.lib.pt_strategy.trading_pair import PairState, TradingPair
|
from pairs_trading.lib.pt_strategy.trading_pair import PairState, TradingPair
|
||||||
from pairs_trading.apps.pairs_trader import PairsTrader
|
from pairs_trading.apps.pairs_trader import PairsTrader
|
||||||
|
|
||||||
"""
|
"""
|
||||||
--config=pair.cfg
|
--config=pair.cfg
|
||||||
--pair=PAIR-BTC-USDT:COINBASE_AT,PAIR-ETH-USDT:COINBASE_AT
|
--pair=PAIR-BTC-USDT:COINBASE_AT,PAIR-ETH-USDT:COINBASE_AT
|
||||||
@ -26,6 +32,7 @@ from pairs_trading.apps.pairs_trader import PairsTrader
|
|||||||
class TradingInstructionType(Enum):
|
class TradingInstructionType(Enum):
|
||||||
TARGET_POSITION = "TARGET_POSITION"
|
TARGET_POSITION = "TARGET_POSITION"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TradingInstruction(NamedObject):
|
class TradingInstruction(NamedObject):
|
||||||
type_: TradingInstructionType
|
type_: TradingInstructionType
|
||||||
@ -34,7 +41,14 @@ class TradingInstruction(NamedObject):
|
|||||||
|
|
||||||
|
|
||||||
class PtLiveStrategy(NamedObject):
|
class PtLiveStrategy(NamedObject):
|
||||||
config_: Dict[str, Any]
|
config_: Config
|
||||||
|
instruments_: List[ExchangeInstrument]
|
||||||
|
|
||||||
|
interval_sec_: IntervalSecT
|
||||||
|
history_depth_sec_: IntervalSecT
|
||||||
|
open_threshold_: float
|
||||||
|
close_threshold_: float
|
||||||
|
|
||||||
trading_pair_: TradingPair
|
trading_pair_: TradingPair
|
||||||
model_data_policy_: ModelDataPolicy
|
model_data_policy_: ModelDataPolicy
|
||||||
pairs_trader_: PairsTrader
|
pairs_trader_: PairsTrader
|
||||||
@ -42,49 +56,69 @@ class PtLiveStrategy(NamedObject):
|
|||||||
# ti_sender_: TradingInstructionsSender
|
# ti_sender_: TradingInstructionsSender
|
||||||
|
|
||||||
# for presentation: history of prediction values and trading signals
|
# for presentation: history of prediction values and trading signals
|
||||||
predictions_: pd.DataFrame
|
predictions_df_: pd.DataFrame
|
||||||
trading_signals_: pd.DataFrame
|
trading_signals_df_: pd.DataFrame
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: Dict[str, Any],
|
config: Config,
|
||||||
instruments: List[Dict[str, str]],
|
instruments: List[ExchangeInstrument],
|
||||||
pairs_trader: PairsTrader,
|
pairs_trader: PairsTrader,
|
||||||
):
|
):
|
||||||
|
|
||||||
self.config_ = config
|
self.trading_pair_ = TradingPair(
|
||||||
self.trading_pair_ = TradingPair(config=config, instruments=instruments)
|
config=cast(Dict[str, Any], config.data()),
|
||||||
self.predictions_ = pd.DataFrame()
|
instruments=[{"instrument_id": ei.instrument_id()} for ei in instruments],
|
||||||
self.trading_signals_ = pd.DataFrame()
|
)
|
||||||
|
self.predictions_df_ = pd.DataFrame()
|
||||||
|
self.trading_signals_df_ = pd.DataFrame()
|
||||||
self.pairs_trader_ = pairs_trader
|
self.pairs_trader_ = pairs_trader
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
|
|
||||||
# modified config must be passed to PtMarketData
|
# modified config must be passed to PtMarketData
|
||||||
config_copy = copy.deepcopy(config)
|
self.config_ = Config(json_src=copy.deepcopy(config.data()))
|
||||||
config_copy["instruments"] = instruments
|
|
||||||
self.config_ = config_copy
|
|
||||||
|
|
||||||
App.instance().add_call(stage=App.Stage.Config, func=self._on_config(), can_run_now=True)
|
self.instruments_ = instruments
|
||||||
|
|
||||||
|
App.instance().add_call(
|
||||||
|
stage=App.Stage.Config, func=self._on_config(), can_run_now=True
|
||||||
|
)
|
||||||
|
|
||||||
async def _on_config(self) -> None:
|
async def _on_config(self) -> None:
|
||||||
|
self.interval_sec_ = self.config_.get_value("interval_sec", 0)
|
||||||
|
self.history_depth_sec_ = (
|
||||||
|
self.config_.get_value("history_depth_hours", 0) * SecPerHour
|
||||||
|
)
|
||||||
|
|
||||||
await self.pairs_trader_.subscribe_md()
|
await self.pairs_trader_.subscribe_md()
|
||||||
|
|
||||||
self.model_data_policy_ = ModelDataPolicy.create(
|
self.model_data_policy_ = ModelDataPolicy.create(
|
||||||
self.config_, is_real_time=True, pair=self.trading_pair_
|
self.config_, is_real_time=True, pair=self.trading_pair_
|
||||||
)
|
)
|
||||||
self.open_threshold_ = self.config_.get("dis-equilibrium_open_trshld", 0.0)
|
self.open_threshold_ = self.config_.get_value(
|
||||||
assert self.open_threshold_ > 0, "open_threshold must be greater than 0"
|
"dis-equilibrium_open_trshld", 0.0
|
||||||
self.close_threshold_ = self.config_.get("dis-equilibrium_close_trshld", 0.0)
|
)
|
||||||
assert self.close_threshold_ > 0, "close_threshold must be greater than 0"
|
self.close_threshold_ = self.config_.get_value(
|
||||||
|
"dis-equilibrium_close_trshld", 0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
self.open_threshold_ > 0
|
||||||
|
), "dis-equilibrium_open_trshld must be greater than 0"
|
||||||
|
assert (
|
||||||
|
self.close_threshold_ > 0
|
||||||
|
), "dis-equilibrium_close_trshld must be greater than 0"
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"{self.classname()}: trading_pair={self.trading_pair_}, mdp={self.model_data_policy_.__class__.__name__}, "
|
return f"{self.classname()}: trading_pair={self.trading_pair_}, mdp={self.model_data_policy_.__class__.__name__}, "
|
||||||
|
|
||||||
async def on_mkt_data_hist_snapshot(self, hist_aggr: List[MdTradesAggregate]) -> None:
|
async def on_mkt_data_hist_snapshot(
|
||||||
|
self, hist_aggr: List[MdTradesAggregate]
|
||||||
|
) -> None:
|
||||||
# Log.info(f"on_mkt_data_hist_snapshot: {aggr}")
|
# Log.info(f"on_mkt_data_hist_snapshot: {aggr}")
|
||||||
# await self.pt_mkt_data_.on_mkt_data_hist_snapshot(snapshot=aggr)
|
# await self.pt_mkt_data_.on_mkt_data_hist_snapshot(snapshot=aggr)
|
||||||
pass # URGENT PtiveStrategy.on_mkt_data_hist_snapshot()
|
pass # URGENT PtiveStrategy.on_mkt_data_hist_snapshot()
|
||||||
|
|
||||||
async def on_mkt_data_update(self, aggr: MdTradesAggregate) -> None:
|
async def on_mkt_data_update(self, aggr: MdTradesAggregate) -> None:
|
||||||
# if market_data_df is not None:
|
# if market_data_df is not None:
|
||||||
@ -105,18 +139,18 @@ class PtLiveStrategy(NamedObject):
|
|||||||
# if len(trading_instructions) > 0:
|
# if len(trading_instructions) > 0:
|
||||||
# await self._send_trading_instructions(trading_instructions)
|
# await self._send_trading_instructions(trading_instructions)
|
||||||
# # trades = self._create_trades(prediction=prediction, last_row=market_data_df.iloc[-1])
|
# # trades = self._create_trades(prediction=prediction, last_row=market_data_df.iloc[-1])
|
||||||
pass # URGENT
|
pass # URGENT
|
||||||
|
|
||||||
def interval_sec(self) -> IntervalSecT:
|
def interval_sec(self) -> IntervalSecT:
|
||||||
return 60 # URGENT use config
|
return self.interval_sec_
|
||||||
|
|
||||||
def history_depth_sec(self) -> IntervalSecT:
|
def history_depth_sec(self) -> IntervalSecT:
|
||||||
return 3600 * 60 * 2 # URGENT use config
|
return self.history_depth_sec_
|
||||||
|
|
||||||
async def _send_trading_instructions(
|
async def _send_trading_instructions(
|
||||||
self, trading_instructions: List[TradingInstruction]
|
self, trading_instructions: List[TradingInstruction]
|
||||||
) -> None:
|
) -> None:
|
||||||
pass # URGENT implement _send_trading_instructions
|
pass # URGENT implement _send_trading_instructions
|
||||||
|
|
||||||
def _create_trading_instructions(
|
def _create_trading_instructions(
|
||||||
self, prediction: Prediction, last_row: pd.Series
|
self, prediction: Prediction, last_row: pd.Series
|
||||||
@ -135,7 +169,7 @@ class PtLiveStrategy(NamedObject):
|
|||||||
elif pair.is_open():
|
elif pair.is_open():
|
||||||
if abs_scaled_disequilibrium <= self.close_threshold_:
|
if abs_scaled_disequilibrium <= self.close_threshold_:
|
||||||
trd_instructions = self._create_close_trade_instructions(
|
trd_instructions = self._create_close_trade_instructions(
|
||||||
pair, row=last_row #, prediction=prediction
|
pair, row=last_row # , prediction=prediction
|
||||||
)
|
)
|
||||||
elif pair.to_stop_close_conditions(predicted_row=last_row):
|
elif pair.to_stop_close_conditions(predicted_row=last_row):
|
||||||
trd_instructions = self._create_close_trade_instructions(
|
trd_instructions = self._create_close_trade_instructions(
|
||||||
@ -204,16 +238,15 @@ class PtLiveStrategy(NamedObject):
|
|||||||
"signed_scaled_disequilibrium": scaled_disequilibrium,
|
"signed_scaled_disequilibrium": scaled_disequilibrium,
|
||||||
# "pair": pair,
|
# "pair": pair,
|
||||||
}
|
}
|
||||||
ti: List[TradingInstruction] =self._create_trading_instructions(
|
ti: List[TradingInstruction] = self._create_trading_instructions(
|
||||||
prediction=prediction, last_row=row
|
prediction=prediction, last_row=row
|
||||||
)
|
)
|
||||||
return ti
|
return ti
|
||||||
|
|
||||||
|
|
||||||
def _create_close_trade_instructions(
|
def _create_close_trade_instructions(
|
||||||
self, pair: TradingPair, row: pd.Series #, prediction: Prediction
|
self, pair: TradingPair, row: pd.Series # , prediction: Prediction
|
||||||
) -> List[TradingInstruction]:
|
) -> List[TradingInstruction]:
|
||||||
return [] # URGENT implement _create_close_trade_instructions
|
return [] # URGENT implement _create_close_trade_instructions
|
||||||
|
|
||||||
def _handle_outstanding_positions(self) -> Optional[pd.DataFrame]:
|
def _handle_outstanding_positions(self) -> Optional[pd.DataFrame]:
|
||||||
trades = None
|
trades = None
|
||||||
@ -223,7 +256,7 @@ class PtLiveStrategy(NamedObject):
|
|||||||
if pair.user_data_["state"] == PairState.OPEN:
|
if pair.user_data_["state"] == PairState.OPEN:
|
||||||
print(f"{pair}: *** Position is NOT CLOSED. ***")
|
print(f"{pair}: *** Position is NOT CLOSED. ***")
|
||||||
# outstanding positions
|
# outstanding positions
|
||||||
if self.config_["close_outstanding_positions"]:
|
if self.config_.key_exists("close_outstanding_positions"):
|
||||||
close_position_row = pd.Series(pair.market_data_.iloc[-2])
|
close_position_row = pd.Series(pair.market_data_.iloc[-2])
|
||||||
# close_position_row["disequilibrium"] = 0.0
|
# close_position_row["disequilibrium"] = 0.0
|
||||||
# close_position_row["scaled_disequilibrium"] = 0.0
|
# close_position_row["scaled_disequilibrium"] = 0.0
|
||||||
|
|||||||
@ -8,6 +8,7 @@ from typing import Any, Dict, Optional, cast
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
|
from cvttpy_tools.config import Config
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DataWindowParams:
|
class DataWindowParams:
|
||||||
@ -16,22 +17,22 @@ class DataWindowParams:
|
|||||||
|
|
||||||
|
|
||||||
class ModelDataPolicy(ABC):
|
class ModelDataPolicy(ABC):
|
||||||
config_: Dict[str, Any]
|
config_: Config
|
||||||
current_data_params_: DataWindowParams
|
current_data_params_: DataWindowParams
|
||||||
count_: int
|
count_: int
|
||||||
is_real_time_: bool
|
is_real_time_: bool
|
||||||
|
|
||||||
def __init__(self, config: Dict[str, Any], *args: Any, **kwargs: Any):
|
def __init__(self, config: Config, *args: Any, **kwargs: Any):
|
||||||
self.config_ = config
|
self.config_ = config
|
||||||
training_size = config.get("training_size", 120)
|
training_size = config.get_value("training_size", 120)
|
||||||
training_start_index = 0
|
training_start_index = 0
|
||||||
if kwargs.get("is_real_time", False):
|
if kwargs.get("is_real_time", False):
|
||||||
training_size = 120
|
training_size = 120
|
||||||
training_start_index = 0
|
training_start_index = 0
|
||||||
else:
|
else:
|
||||||
training_size = config.get("training_size", 120)
|
training_size = config.get_value("training_size", 120)
|
||||||
self.current_data_params_ = DataWindowParams(
|
self.current_data_params_ = DataWindowParams(
|
||||||
training_size=config.get("training_size", 120),
|
training_size=config.get_value("training_size", 120),
|
||||||
training_start_index=0,
|
training_start_index=0,
|
||||||
)
|
)
|
||||||
self.count_ = 0
|
self.count_ = 0
|
||||||
@ -44,10 +45,10 @@ class ModelDataPolicy(ABC):
|
|||||||
return self.current_data_params_
|
return self.current_data_params_
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create(config: Dict[str, Any], *args: Any, **kwargs: Any) -> ModelDataPolicy:
|
def create(config: Config, *args: Any, **kwargs: Any) -> ModelDataPolicy:
|
||||||
import importlib
|
import importlib
|
||||||
|
|
||||||
model_data_policy_class_name = config.get("model_data_policy_class", None)
|
model_data_policy_class_name = config.get_value("model_data_policy_class", None)
|
||||||
assert model_data_policy_class_name is not None
|
assert model_data_policy_class_name is not None
|
||||||
module_name, class_name = model_data_policy_class_name.rsplit(".", 1)
|
module_name, class_name = model_data_policy_class_name.rsplit(".", 1)
|
||||||
module = importlib.import_module(module_name)
|
module = importlib.import_module(module_name)
|
||||||
@ -58,7 +59,7 @@ class ModelDataPolicy(ABC):
|
|||||||
|
|
||||||
|
|
||||||
class RollingWindowDataPolicy(ModelDataPolicy):
|
class RollingWindowDataPolicy(ModelDataPolicy):
|
||||||
def __init__(self, config: Dict[str, Any], *args: Any, **kwargs: Any):
|
def __init__(self, config: Config, *args: Any, **kwargs: Any):
|
||||||
super().__init__(config, *args, **kwargs)
|
super().__init__(config, *args, **kwargs)
|
||||||
self.count_ = 1
|
self.count_ = 1
|
||||||
|
|
||||||
@ -80,16 +81,16 @@ class OptimizedWndDataPolicy(ModelDataPolicy, ABC):
|
|||||||
prices_a_: np.ndarray
|
prices_a_: np.ndarray
|
||||||
prices_b_: np.ndarray
|
prices_b_: np.ndarray
|
||||||
|
|
||||||
def __init__(self, config: Dict[str, Any], *args: Any, **kwargs: Any):
|
def __init__(self, config: Config, *args: Any, **kwargs: Any):
|
||||||
super().__init__(config, *args, **kwargs)
|
super().__init__(config, *args, **kwargs)
|
||||||
assert (
|
assert (
|
||||||
kwargs.get("pair") is not None
|
kwargs.get("pair") is not None
|
||||||
), "pair must be provided"
|
), "pair must be provided"
|
||||||
assert (
|
assert (
|
||||||
"min_training_size" in config and "max_training_size" in config
|
"min_training_size" in config.data() and "max_training_size" in config.data()
|
||||||
), "min_training_size and max_training_size must be provided"
|
), "min_training_size and max_training_size must be provided"
|
||||||
self.min_training_size_ = cast(int, config.get("min_training_size"))
|
self.min_training_size_ = cast(int, config.get_value("min_training_size"))
|
||||||
self.max_training_size_ = cast(int, config.get("max_training_size"))
|
self.max_training_size_ = cast(int, config.get_value("max_training_size"))
|
||||||
|
|
||||||
from pairs_trading.lib.pt_strategy.trading_pair import TradingPair
|
from pairs_trading.lib.pt_strategy.trading_pair import TradingPair
|
||||||
self.pair_ = cast(TradingPair, kwargs.get("pair"))
|
self.pair_ = cast(TradingPair, kwargs.get("pair"))
|
||||||
@ -133,7 +134,7 @@ class EGOptimizedWndDataPolicy(OptimizedWndDataPolicy):
|
|||||||
# Engle-Granger cointegration test
|
# Engle-Granger cointegration test
|
||||||
*** VERY SLOW ***
|
*** VERY SLOW ***
|
||||||
'''
|
'''
|
||||||
def __init__(self, config: Dict[str, Any], *args: Any, **kwargs: Any):
|
def __init__(self, config: Config, *args: Any, **kwargs: Any):
|
||||||
super().__init__(config, *args, **kwargs)
|
super().__init__(config, *args, **kwargs)
|
||||||
|
|
||||||
def optimize_window_size(self) -> DataWindowParams:
|
def optimize_window_size(self) -> DataWindowParams:
|
||||||
@ -162,7 +163,7 @@ class EGOptimizedWndDataPolicy(OptimizedWndDataPolicy):
|
|||||||
|
|
||||||
class ADFOptimizedWndDataPolicy(OptimizedWndDataPolicy):
|
class ADFOptimizedWndDataPolicy(OptimizedWndDataPolicy):
|
||||||
# Augmented Dickey-Fuller test
|
# Augmented Dickey-Fuller test
|
||||||
def __init__(self, config: Dict[str, Any], *args: Any, **kwargs: Any):
|
def __init__(self, config: Config, *args: Any, **kwargs: Any):
|
||||||
super().__init__(config, *args, **kwargs)
|
super().__init__(config, *args, **kwargs)
|
||||||
|
|
||||||
def optimize_window_size(self) -> DataWindowParams:
|
def optimize_window_size(self) -> DataWindowParams:
|
||||||
@ -208,7 +209,7 @@ class ADFOptimizedWndDataPolicy(OptimizedWndDataPolicy):
|
|||||||
|
|
||||||
class JohansenOptdWndDataPolicy(OptimizedWndDataPolicy):
|
class JohansenOptdWndDataPolicy(OptimizedWndDataPolicy):
|
||||||
# Johansen test
|
# Johansen test
|
||||||
def __init__(self, config: Dict[str, Any], *args: Any, **kwargs: Any):
|
def __init__(self, config: Config, *args: Any, **kwargs: Any):
|
||||||
super().__init__(config, *args, **kwargs)
|
super().__init__(config, *args, **kwargs)
|
||||||
|
|
||||||
def optimize_window_size(self) -> DataWindowParams:
|
def optimize_window_size(self) -> DataWindowParams:
|
||||||
|
|||||||
@ -3,6 +3,8 @@ from __future__ import annotations
|
|||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
from cvttpy_tools.config import Config
|
||||||
|
|
||||||
from pairs_trading.lib.pt_strategy.model_data_policy import ModelDataPolicy
|
from pairs_trading.lib.pt_strategy.model_data_policy import ModelDataPolicy
|
||||||
from pairs_trading.lib.pt_strategy.pt_market_data import ResearchMarketData
|
from pairs_trading.lib.pt_strategy.pt_market_data import ResearchMarketData
|
||||||
from pairs_trading.lib.pt_strategy.pt_model import Prediction
|
from pairs_trading.lib.pt_strategy.pt_model import Prediction
|
||||||
@ -41,7 +43,7 @@ class PtResearchStrategy:
|
|||||||
self.pt_mkt_data_ = ResearchMarketData(config=config_copy)
|
self.pt_mkt_data_ = ResearchMarketData(config=config_copy)
|
||||||
self.pt_mkt_data_.load()
|
self.pt_mkt_data_.load()
|
||||||
self.model_data_policy_ = ModelDataPolicy.create(
|
self.model_data_policy_ = ModelDataPolicy.create(
|
||||||
config, mkt_data=self.pt_mkt_data_.market_data_df_, pair=self.trading_pair_
|
Config(config_copy), mkt_data=self.pt_mkt_data_.market_data_df_, pair=self.trading_pair_
|
||||||
)
|
)
|
||||||
|
|
||||||
def outstanding_positions(self) -> List[Dict[str, Any]]:
|
def outstanding_positions(self) -> List[Dict[str, Any]]:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user