diff --git a/.gitignore b/.gitignore index ce10c9f..66aaf0a 100644 --- a/.gitignore +++ b/.gitignore @@ -5,7 +5,6 @@ __OLD__/ .history/ .cursorindexingignore data -.vscode/ cvttpy # SpecStory explanation file .specstory/.what-is-this.md diff --git a/__DELETE__/.vscode/launch.json b/__DELETE__/.vscode/launch.json new file mode 100644 index 0000000..04fddcf --- /dev/null +++ b/__DELETE__/.vscode/launch.json @@ -0,0 +1,181 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + + + { + "name": "Python Debugger: Current File", + "type": "debugpy", + "request": "launch", + "program": "${file}", + "console": "integratedTerminal" + }, + { + "name": "-------- Z-Score (OLS) --------", + }, + { + "name": "CRYPTO z-score", + "type": "debugpy", + "request": "launch", + "python": "/home/oleg/.pyenv/python3.12-venv/bin/python", + "program": "research/pt_backtest.py", + "args": [ + "--config=${workspaceFolder}/configuration/zscore.cfg", + "--instruments=ADA-USDT:CRYPTO:BNBSPOT,SOL-USDT:CRYPTO:BNBSPOT", + "--date_pattern=20250605", + "--result_db=${workspaceFolder}/research/results/crypto/%T.z-score.ADA-SOL.20250602.crypto_results.db", + ], + "env": { + "PYTHONPATH": "${workspaceFolder}/lib" + }, + "console": "integratedTerminal" + }, + { + "name": "EQUITY z-score", + "type": "debugpy", + "request": "launch", + "python": "/home/oleg/.pyenv/python3.12-venv/bin/python", + "program": "research/pt_backtest.py", + "args": [ + "--config=${workspaceFolder}/configuration/zscore.cfg", + "--instruments=COIN:EQUITY:ALPACA,MSTR:EQUITY:ALPACA", + "--date_pattern=2025060*", + "--result_db=${workspaceFolder}/research/results/equity/%T.z-score.COIN-MSTR.20250602.equity_results.db", + ], + "env": { + "PYTHONPATH": "${workspaceFolder}/lib" + }, + "console": "integratedTerminal" + }, + { + "name": "EQUITY-CRYPTO z-score", + "type": "debugpy", + "request": "launch", + "python": "/home/oleg/.pyenv/python3.12-venv/bin/python", + "program": "research/pt_backtest.py", + "args": [ + "--config=${workspaceFolder}/configuration/zscore.cfg", + "--instruments=COIN:EQUITY:ALPACA,BTC-USDT:CRYPTO:BNBSPOT", + "--date_pattern=2025060*", + "--result_db=${workspaceFolder}/research/results/intermarket/%T.z-score.COIN-BTC.20250601.equity_results.db", + ], + "env": { + "PYTHONPATH": "${workspaceFolder}/lib" + }, + "console": "integratedTerminal" + }, + { + "name": "-------- VECM --------", + }, + { + "name": "CRYPTO vecm", + "type": "debugpy", + "request": "launch", + "python": "/home/oleg/.pyenv/python3.12-venv/bin/python", + "program": "research/pt_backtest.py", + "args": [ + "--config=${workspaceFolder}/configuration/vecm.cfg", + "--instruments=ADA-USDT:CRYPTO:BNBSPOT,SOL-USDT:CRYPTO:BNBSPOT", + "--date_pattern=2025060*", + "--result_db=${workspaceFolder}/research/results/crypto/%T.vecm.ADA-SOL.20250602.crypto_results.db", + ], + "env": { + "PYTHONPATH": "${workspaceFolder}/lib" + }, + "console": "integratedTerminal" + }, + { + "name": "EQUITY vecm", + "type": "debugpy", + "request": "launch", + "python": "/home/oleg/.pyenv/python3.12-venv/bin/python", + "program": "research/pt_backtest.py", + "args": [ + "--config=${workspaceFolder}/configuration/vecm.cfg", + "--instruments=COIN:EQUITY:ALPACA,MSTR:EQUITY:ALPACA", + "--date_pattern=2025060*", + "--result_db=${workspaceFolder}/research/results/equity/%T.vecm.COIN-MSTR.20250602.equity_results.db", + ], + "env": { + "PYTHONPATH": "${workspaceFolder}/lib" + }, + "console": "integratedTerminal" + }, + { + "name": "EQUITY-CRYPTO vecm", + "type": "debugpy", + "request": "launch", + "python": "/home/oleg/.pyenv/python3.12-venv/bin/python", + "program": "research/pt_backtest.py", + "args": [ + "--config=${workspaceFolder}/configuration/vecm.cfg", + "--instruments=COIN:EQUITY:ALPACA,BTC-USDT:CRYPTO:BNBSPOT", + "--date_pattern=2025060*", + "--result_db=${workspaceFolder}/research/results/intermarket/%T.vecm.COIN-BTC.20250601.equity_results.db", + ], + "env": { + "PYTHONPATH": "${workspaceFolder}/lib" + }, + "console": "integratedTerminal" + }, + { + "name": "-------- New ZSCORE --------", + }, + { + "name": "New CRYPTO z-score", + "type": "debugpy", + "request": "launch", + "python": "/home/oleg/.pyenv/python3.12-venv/bin/python", + "program": "${workspaceFolder}/research/backtest_new.py", + "args": [ + "--config=${workspaceFolder}/configuration/new_zscore.cfg", + "--instruments=ADA-USDT:CRYPTO:BNBSPOT,SOL-USDT:CRYPTO:BNBSPOT", + "--date_pattern=2025060*", + "--result_db=${workspaceFolder}/research/results/crypto/%T.new_zscore.ADA-SOL.2025060-.crypto_results.db", + ], + "env": { + "PYTHONPATH": "${workspaceFolder}/lib" + }, + "console": "integratedTerminal" + }, + { + "name": "New CRYPTO vecm", + "type": "debugpy", + "request": "launch", + "python": "/home/oleg/.pyenv/python3.12-venv/bin/python", + "program": "${workspaceFolder}/research/backtest_new.py", + "args": [ + "--config=${workspaceFolder}/configuration/new_vecm.cfg", + "--instruments=ADA-USDT:CRYPTO:BNBSPOT,SOL-USDT:CRYPTO:BNBSPOT", + "--date_pattern=20250605", + "--result_db=${workspaceFolder}/research/results/crypto/%T.vecm.ADA-SOL.20250605.crypto_results.db", + ], + "env": { + "PYTHONPATH": "${workspaceFolder}/lib" + }, + "console": "integratedTerminal" + }, + { + "name": "-------- Viz Test --------", + }, + { + "name": "Viz Test", + "type": "debugpy", + "request": "launch", + "python": "/home/oleg/.pyenv/python3.12-venv/bin/python", + "program": "${workspaceFolder}/research/viz_test.py", + "args": [ + "--config=${workspaceFolder}/configuration/new_zscore.cfg", + "--instruments=ADA-USDT:CRYPTO:BNBSPOT,SOL-USDT:CRYPTO:BNBSPOT", + "--date_pattern=20250605", + ], + "env": { + "PYTHONPATH": "${workspaceFolder}/lib" + }, + "console": "integratedTerminal" + } + ] +} \ No newline at end of file diff --git a/__DELETE__/strategy/pair_strategy.py b/__DELETE__/strategy/pair_strategy.py new file mode 100644 index 0000000..fbed69f --- /dev/null +++ b/__DELETE__/strategy/pair_strategy.py @@ -0,0 +1,101 @@ +import argparse +import asyncio +import glob +import importlib +import os +from datetime import date, datetime +from typing import Any, Dict, List, Optional + +import hjson +import pandas as pd + +from tools.data_loader import get_available_instruments_from_db, load_market_data +from pt_trading.results import ( + BacktestResult, + create_result_database, + store_config_in_database, + store_results_in_database, +) +from pt_trading.fit_methods import PairsTradingFitMethod +from pt_trading.trading_pair import TradingPair + + +def run_strategy( + config: Dict, + datafile: str, + fit_method: PairsTradingFitMethod, + instruments: List[str], +) -> BacktestResult: + """ + Run backtest for all pairs using the specified instruments. + """ + bt_result: BacktestResult = BacktestResult(config=config) + + def _create_pairs(config: Dict, instruments: List[str]) -> List[TradingPair]: + nonlocal datafile + all_indexes = range(len(instruments)) + unique_index_pairs = [(i, j) for i in all_indexes for j in all_indexes if i < j] + pairs = [] + + # Update config to use the specified instruments + config_copy = config.copy() + config_copy["instruments"] = instruments + + market_data_df = load_market_data( + datafile=datafile, + exchange_id=config_copy["exchange_id"], + instruments=config_copy["instruments"], + instrument_id_pfx=config_copy["instrument_id_pfx"], + db_table_name=config_copy["db_table_name"], + trading_hours=config_copy["trading_hours"], + ) + + for a_index, b_index in unique_index_pairs: + pair = fit_method.create_trading_pair( + market_data=market_data_df, + symbol_a=instruments[a_index], + symbol_b=instruments[b_index], + ) + pairs.append(pair) + return pairs + + pairs_trades = [] + for pair in _create_pairs(config, instruments): + single_pair_trades = fit_method.run_pair( + pair=pair, config=config, bt_result=bt_result + ) + if single_pair_trades is not None and len(single_pair_trades) > 0: + pairs_trades.append(single_pair_trades) + + # Check if result_list has any data before concatenating + if len(pairs_trades) == 0: + print("No trading signals found for any pairs") + return bt_result + + result = pd.concat(pairs_trades, ignore_index=True) + result["time"] = pd.to_datetime(result["time"]) + result = result.set_index("time").sort_index() + + bt_result.collect_single_day_results(result) + return bt_result + + +def main() -> None: + # Load config + # Subscribe to CVTT market data + # On snapshot (with historical data) - create trading strategy with market data dateframe + + async def on_message(message_type: MessageTypeT, subscr_id: SubscriptionIdT, message: Dict, instrument_id: str) -> None: + print(f"{message_type=} {subscr_id=} {instrument_id}") + if message_type == "md_aggregate": + aggr = message.get("md_aggregate", []) + print(f"[{aggr['tstmp'][:19]}] *** RLTM *** {message}") + elif message_type == "historical_md_aggregate": + for aggr in message.get("historical_data", []): + print(f"[{aggr['tstmp'][:19]}] *** HIST *** {aggr}") + else: + print(f"Unknown message type: {message_type}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/bin/trade_pair.py b/bin/trade_pair.py new file mode 100644 index 0000000..89ae114 --- /dev/null +++ b/bin/trade_pair.py @@ -0,0 +1,69 @@ + +from functools import partial +from typing import Dict + +from cvtt_client.mkt_data import (CvttPricerWebSockClient, + CvttPricesSubscription, MessageTypeT, + SubscriptionIdT) +from cvttpy_base.tools.app import App +from cvttpy_base.tools.base import NamedObject +from pt_strategy.live_strategy import PtLiveStrategy + + +class PairTradingRunner(NamedObject): + def __init__(self) -> None: + super().__init__() + App.instance().add_call(App.Stage.Config, self._on_config()) + App.instance().add_call(App.Stage.Run, self.run()) + + async def _on_config(self) -> None: + pass + + async def run(self) -> None: + pass + +# async def main() -> None: +# live_strategy = PtLiveStrategy( +# config={}, +# instruments=[ +# {"exchange_config_name": "COINBASE_AT", "instrument_id": "PAIR-BTC-USD"}, +# {"exchange_config_name": "COINBASE_AT", "instrument_id": "PAIR-ETH-USD"}, +# ] +# ) +# async def on_message(message_type: MessageTypeT, subscr_id: SubscriptionIdT, message: Dict, instrument_id: str) -> None: +# print(f"{message_type=} {subscr_id=} {instrument_id}") +# if message_type == "md_aggregate": +# aggr = message.get("md_aggregate", []) +# print(f"[{aggr['tstmp'][:19]}] *** RLTM *** {message}") +# elif message_type == "historical_md_aggregate": +# for aggr in message.get("historical_data", []): +# print(f"[{aggr['tstmp'][:19]}] *** HIST *** {aggr}") +# else: +# print(f"Unknown message type: {message_type}") + +# pricer_client = CvttPricerWebSockClient( +# "ws://localhost:12346/ws" +# ) +# await pricer_client.subscribe(CvttPricesSubscription( +# exchange_config_name="COINBASE_AT", +# instrument_id="PAIR-BTC-USD", +# interval_sec=60, +# history_depth_sec=60*60*24, +# callback=partial(on_message, instrument_id="PAIR-BTC-USD") +# )) +# await pricer_client.subscribe(CvttPricesSubscription( +# exchange_config_name="COINBASE_AT", +# instrument_id="PAIR-ETH-USD", +# interval_sec=60, +# history_depth_sec=60*60*24, +# callback=partial(on_message, instrument_id="PAIR-ETH-USD") +# )) + +# await pricer_client.run() + + + +if __name__ == "__main__": + App() + + App.instance().run() diff --git a/configuration/ols-opt.cfg b/configuration/ols-opt.cfg index 567f80a..e40a8a3 100644 --- a/configuration/ols-opt.cfg +++ b/configuration/ols-opt.cfg @@ -24,8 +24,7 @@ "dis-equilibrium_close_trshld": 0.9, "model_class": "pt_strategy.models.OLSModel", - # "training_size": 120, - # "model_data_policy_class": "pt_strategy.model_data_policy.RollingWindowDataPolicy", + # "model_data_policy_class": "pt_strategy.model_data_policy.EGOptimizedWndDataPolicy", # "model_data_policy_class": "pt_strategy.model_data_policy.ADFOptimizedWndDataPolicy", "model_data_policy_class": "pt_strategy.model_data_policy.JohansenOptdWndDataPolicy", "min_training_size": 60, diff --git a/lib/cvtt/mkt_data.py b/lib/cvtt_client/mkt_data.py similarity index 100% rename from lib/cvtt/mkt_data.py rename to lib/cvtt_client/mkt_data.py diff --git a/lib/pt_strategy/live_strategy.py b/lib/pt_strategy/live_strategy.py new file mode 100644 index 0000000..68e23a4 --- /dev/null +++ b/lib/pt_strategy/live_strategy.py @@ -0,0 +1,331 @@ +from __future__ import annotations + +from functools import partial +from typing import Any, Dict, List, Optional + +import pandas as pd +from cvttpy_base.settings.cvtt_types import JsonDictT +from cvttpy_base.tools.base import NamedObject +from cvttpy_base.tools.logger import Log + +from cvtt_client.mkt_data import CvttPricerWebSockClient, CvttPricesSubscription, MessageTypeT, SubscriptionIdT +from pt_strategy.model_data_policy import ModelDataPolicy +from pt_strategy.pt_market_data import PtMarketData, RealTimeMarketData +from pt_strategy.pt_model import Prediction +from pt_strategy.trading_pair import PairState, TradingPair + +''' + --config=pair.cfg + --pair=PAIR-BTC-USDT:COINBASE_AT,PAIR-ETH-USDT:COINBASE_AT +''' + + +class PtMktDataClient(NamedObject): + live_strategy_: PtLiveStrategy + pricer_client_: CvttPricerWebSockClient + subscriptions_: List[CvttPricesSubscription] + + def __init__(self, live_strategy: PtLiveStrategy): + self.live_strategy_ = live_strategy + + async def start(self, subscription: CvttPricesSubscription) -> None: + pricer_url = self.live_strategy_.config_.get("pricer_url", None) #, "ws://localhost:12346/ws") + assert pricer_url is not None, "pricer_url is not found in config" + self.pricer_client_ = CvttPricerWebSockClient(url=pricer_url) + + await self._subscribe() + + async def _subscribe(self) -> None: + pair: TradingPair = self.live_strategy_.trading_pair_ + for instrument in pair.instruments_: + await self.pricer_client_.subscribe(CvttPricesSubscription( + exchange_config_name=instrument["exchange_config_name"], + instrument_id=instrument["instrument_id"], + interval_sec=60, + history_depth_sec=60*60*24, + callback=partial(self.on_message, instrument_id=instrument["instrument_id"]) + )) + + async def on_message(self, message_type: MessageTypeT, subscr_id: SubscriptionIdT, message: Dict, instrument_id: str) -> None: + Log.info(f"{self.fname()}: {message_type=} {subscr_id=} {instrument_id}") + aggr: JsonDictT + if message_type == "md_aggregate": + aggr = message.get("md_aggregate", {}) + await self.live_strategy_.on_mkt_data_update(aggr) + # print(f"[{aggr['tstmp'][:19]}] *** RLTM *** {message}") + elif message_type == "historical_md_aggregate": + aggr = message.get("historical_data", {}) + await self.live_strategy_.on_mkt_data_hist_snapshot(aggr) + # print(f"[{aggr['tstmp'][:19]}] *** HIST *** {aggr}") + else: + Log.info(f"Unknown message type: {message_type}") + + async def run(self) -> None: + await self.pricer_client_.run() + + + +class PtLiveStrategy(NamedObject): + config_: Dict[str, Any] + trading_pair_: TradingPair + model_data_policy_: ModelDataPolicy + pt_mkt_data_: RealTimeMarketData + pt_mkt_data_client_: PtMktDataClient + + # for presentation: history of prediction values and trading signals + predictions_: pd.DataFrame + trading_signals_: pd.DataFrame + + def __init__( + self, + config: Dict[str, Any], + instruments: List[Dict[str, str]], + ): + + self.config_ = config + self.trading_pair_ = TradingPair(config=config, instruments=instruments) + self.predictions_ = pd.DataFrame() + self.trading_signals_ = pd.DataFrame() + + import copy + + # modified config must be passed to PtMarketData + config_copy = copy.deepcopy(config) + config_copy["instruments"] = instruments + self.pt_mkt_data_ = RealTimeMarketData(config=config_copy) + self.model_data_policy_ = ModelDataPolicy.create( + config, is_real_time=True,pair=self.trading_pair_ + ) + + async def on_mkt_data_hist_snapshot(self, aggr: JsonDictT) -> None: + Log.info(f"on_mkt_data_hist_snapshot: {aggr}") + await self.pt_mkt_data_.on_mkt_data_hist_snapshot(snapshot=aggr) + pass + + async def on_mkt_data_update(self, aggr: JsonDictT) -> None: + market_data_df = await self.pt_mkt_data_.on_mkt_data_update(update=aggr) + if market_data_df is not None: + self.trading_pair_.market_data_ = market_data_df + self.model_data_policy_.advance() + prediction = self.trading_pair_.run(market_data_df, self.model_data_policy_.advance()) + self.predictions_ = pd.concat([self.predictions_, prediction.to_df()], ignore_index=True) + trades = self._create_trades(prediction=prediction, last_row=market_data_df.iloc[-1]) + # URGENT implement this + pass + + + async def run(self) -> None: + await self.pt_mkt_data_client_.run() + + def _create_trades( + self, prediction: Prediction, last_row: pd.Series + ) -> Optional[pd.DataFrame]: + pair = self.trading_pair_ + trades = None + + open_threshold = self.config_["dis-equilibrium_open_trshld"] + close_threshold = self.config_["dis-equilibrium_close_trshld"] + scaled_disequilibrium = prediction.scaled_disequilibrium_ + abs_scaled_disequilibrium = abs(scaled_disequilibrium) + + if pair.user_data_["state"] in [ + PairState.INITIAL, + PairState.CLOSE, + PairState.CLOSE_POSITION, + PairState.CLOSE_STOP_LOSS, + PairState.CLOSE_STOP_PROFIT, + ]: + if abs_scaled_disequilibrium >= open_threshold: + trades = self._create_open_trades( + pair, row=last_row, prediction=prediction + ) + if trades is not None: + trades["status"] = PairState.OPEN.name + print(f"OPEN TRADES:\n{trades}") + pair.user_data_["state"] = PairState.OPEN + pair.on_open_trades(trades) + + elif pair.user_data_["state"] == PairState.OPEN: + if abs_scaled_disequilibrium <= close_threshold: + trades = self._create_close_trades( + pair, row=last_row, prediction=prediction + ) + if trades is not None: + trades["status"] = PairState.CLOSE.name + print(f"CLOSE TRADES:\n{trades}") + pair.user_data_["state"] = PairState.CLOSE + pair.on_close_trades(trades) + elif pair.to_stop_close_conditions(predicted_row=last_row): + trades = self._create_close_trades(pair, row=last_row) + if trades is not None: + trades["status"] = pair.user_data_["stop_close_state"].name + print(f"STOP CLOSE TRADES:\n{trades}") + pair.user_data_["state"] = pair.user_data_["stop_close_state"] + pair.on_close_trades(trades) + + return trades + + def _handle_outstanding_positions(self) -> Optional[pd.DataFrame]: + trades = None + pair = self.trading_pair_ + + # Outstanding positions + if pair.user_data_["state"] == PairState.OPEN: + print(f"{pair}: *** Position is NOT CLOSED. ***") + # outstanding positions + if self.config_["close_outstanding_positions"]: + close_position_row = pd.Series(pair.market_data_.iloc[-2]) + # close_position_row["disequilibrium"] = 0.0 + # close_position_row["scaled_disequilibrium"] = 0.0 + # close_position_row["signed_scaled_disequilibrium"] = 0.0 + + trades = self._create_close_trades( + pair=pair, row=close_position_row, prediction=None + ) + if trades is not None: + trades["status"] = PairState.CLOSE_POSITION.name + print(f"CLOSE_POSITION TRADES:\n{trades}") + pair.user_data_["state"] = PairState.CLOSE_POSITION + pair.on_close_trades(trades) + else: + pair.add_outstanding_position( + symbol=pair.symbol_a_, + open_side=pair.user_data_["open_side_a"], + open_px=pair.user_data_["open_px_a"], + open_tstamp=pair.user_data_["open_tstamp"], + last_mkt_data_row=pair.market_data_.iloc[-1], + ) + pair.add_outstanding_position( + symbol=pair.symbol_b_, + open_side=pair.user_data_["open_side_b"], + open_px=pair.user_data_["open_px_b"], + open_tstamp=pair.user_data_["open_tstamp"], + last_mkt_data_row=pair.market_data_.iloc[-1], + ) + return trades + + def _trades_df(self) -> pd.DataFrame: + types = { + "time": "datetime64[ns]", + "action": "string", + "symbol": "string", + "side": "string", + "price": "float64", + "disequilibrium": "float64", + "scaled_disequilibrium": "float64", + "signed_scaled_disequilibrium": "float64", + # "pair": "object", + } + columns = list(types.keys()) + return pd.DataFrame(columns=columns).astype(types) + + def _create_open_trades( + self, pair: TradingPair, row: pd.Series, prediction: Prediction + ) -> Optional[pd.DataFrame]: + colname_a, colname_b = pair.exec_prices_colnames() + + tstamp = row["tstamp"] + diseqlbrm = prediction.disequilibrium_ + scaled_disequilibrium = prediction.scaled_disequilibrium_ + px_a = row[f"{colname_a}"] + px_b = row[f"{colname_b}"] + + # creating the trades + df = self._trades_df() + + print(f"OPEN_TRADES: {row["tstamp"]} {scaled_disequilibrium=}") + if diseqlbrm > 0: + side_a = "SELL" + side_b = "BUY" + else: + side_a = "BUY" + side_b = "SELL" + + # save closing sides + pair.user_data_["open_side_a"] = side_a # used in oustanding positions + pair.user_data_["open_side_b"] = side_b + pair.user_data_["open_px_a"] = px_a + pair.user_data_["open_px_b"] = px_b + pair.user_data_["open_tstamp"] = tstamp + + pair.user_data_["close_side_a"] = side_b # used for closing trades + pair.user_data_["close_side_b"] = side_a + + # create opening trades + df.loc[len(df)] = { + "time": tstamp, + "symbol": pair.symbol_a_, + "side": side_a, + "action": "OPEN", + "price": px_a, + "disequilibrium": diseqlbrm, + "signed_scaled_disequilibrium": scaled_disequilibrium, + "scaled_disequilibrium": abs(scaled_disequilibrium), + # "pair": pair, + } + df.loc[len(df)] = { + "time": tstamp, + "symbol": pair.symbol_b_, + "side": side_b, + "action": "OPEN", + "price": px_b, + "disequilibrium": diseqlbrm, + "scaled_disequilibrium": abs(scaled_disequilibrium), + "signed_scaled_disequilibrium": scaled_disequilibrium, + # "pair": pair, + } + return df + + def _create_close_trades( + self, pair: TradingPair, row: pd.Series, prediction: Optional[Prediction] = None + ) -> Optional[pd.DataFrame]: + colname_a, colname_b = pair.exec_prices_colnames() + + tstamp = row["tstamp"] + if prediction is not None: + diseqlbrm = prediction.disequilibrium_ + signed_scaled_disequilibrium = prediction.scaled_disequilibrium_ + scaled_disequilibrium = abs(prediction.scaled_disequilibrium_) + else: + diseqlbrm = 0.0 + signed_scaled_disequilibrium = 0.0 + scaled_disequilibrium = 0.0 + px_a = row[f"{colname_a}"] + px_b = row[f"{colname_b}"] + + # creating the trades + df = self._trades_df() + + # create opening trades + df.loc[len(df)] = { + "time": tstamp, + "symbol": pair.symbol_a_, + "side": pair.user_data_["close_side_a"], + "action": "CLOSE", + "price": px_a, + "disequilibrium": diseqlbrm, + "scaled_disequilibrium": scaled_disequilibrium, + "signed_scaled_disequilibrium": signed_scaled_disequilibrium, + # "pair": pair, + } + df.loc[len(df)] = { + "time": tstamp, + "symbol": pair.symbol_b_, + "side": pair.user_data_["close_side_b"], + "action": "CLOSE", + "price": px_b, + "disequilibrium": diseqlbrm, + "scaled_disequilibrium": scaled_disequilibrium, + "signed_scaled_disequilibrium": signed_scaled_disequilibrium, + # "pair": pair, + } + del pair.user_data_["close_side_a"] + del pair.user_data_["close_side_b"] + + del pair.user_data_["open_tstamp"] + del pair.user_data_["open_px_a"] + del pair.user_data_["open_px_b"] + del pair.user_data_["open_side_a"] + del pair.user_data_["open_side_b"] + return df + diff --git a/lib/pt_strategy/model_data_policy.py b/lib/pt_strategy/model_data_policy.py index fd0f39b..8c1923e 100644 --- a/lib/pt_strategy/model_data_policy.py +++ b/lib/pt_strategy/model_data_policy.py @@ -3,7 +3,7 @@ from __future__ import annotations import copy from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Dict, cast +from typing import Any, Dict, Optional, cast import numpy as np import pandas as pd @@ -19,17 +19,26 @@ class ModelDataPolicy(ABC): config_: Dict[str, Any] current_data_params_: DataWindowParams count_: int + is_real_time_: bool - def __init__(self, config: Dict[str, Any]): + def __init__(self, config: Dict[str, Any], *args: Any, **kwargs: Any): self.config_ = config + training_size = config.get("training_size", 120) + training_start_index = 0 + if kwargs.get("is_real_time", False): + training_size = 120 + training_start_index = 0 + else: + training_size = config.get("training_size", 120) self.current_data_params_ = DataWindowParams( training_size=config.get("training_size", 120), training_start_index=0, ) self.count_ = 0 + self.is_real_time_ = kwargs.get("is_real_time", False) @abstractmethod - def advance(self) -> DataWindowParams: + def advance(self, mkt_data_df: Optional[pd.DataFrame] = None) -> DataWindowParams: self.count_ += 1 print(self.count_, end="\r") return self.current_data_params_ @@ -50,22 +59,15 @@ class ModelDataPolicy(ABC): class RollingWindowDataPolicy(ModelDataPolicy): def __init__(self, config: Dict[str, Any], *args: Any, **kwargs: Any): - super().__init__(config) + super().__init__(config, *args, **kwargs) self.count_ = 1 - def advance(self) -> DataWindowParams: - super().advance() - self.current_data_params_.training_start_index += 1 - return self.current_data_params_ - - -class ExpandingWindowDataPolicy(ModelDataPolicy): - def __init__(self, config: Dict[str, Any], *args: Any, **kwargs: Any): - super().__init__(config) - - def advance(self) -> DataWindowParams: - super().advance() - self.current_data_params_.training_size += 1 + def advance(self, mkt_data_df: Optional[pd.DataFrame] = None) -> DataWindowParams: + super().advance(mkt_data_df) + if self.is_real_time_: + self.current_data_params_.training_start_index = -self.current_data_params_.training_size + else: + self.current_data_params_.training_start_index += 1 return self.current_data_params_ @@ -79,34 +81,47 @@ class OptimizedWndDataPolicy(ModelDataPolicy, ABC): prices_b_: np.ndarray def __init__(self, config: Dict[str, Any], *args: Any, **kwargs: Any): - super().__init__(config) + super().__init__(config, *args, **kwargs) assert ( - kwargs.get("mkt_data") is not None and kwargs.get("pair") is not None - ), "mkt_data and/or pair must be provided" + kwargs.get("pair") is not None + ), "pair must be provided" assert ( "min_training_size" in config and "max_training_size" in config ), "min_training_size and max_training_size must be provided" self.min_training_size_ = cast(int, config.get("min_training_size")) self.max_training_size_ = cast(int, config.get("max_training_size")) - assert self.min_training_size_ < self.max_training_size_ from pt_strategy.trading_pair import TradingPair + self.pair_ = cast(TradingPair, kwargs.get("pair")) + + if "mkt_data" in kwargs: + self.mkt_data_df_ = cast(pd.DataFrame, kwargs.get("mkt_data")) + col_a, col_b = self.pair_.colnames() + self.prices_a_ = np.array(self.mkt_data_df_[col_a]) + self.prices_b_ = np.array(self.mkt_data_df_[col_b]) + assert self.min_training_size_ < self.max_training_size_ - self.mkt_data_df_ = cast(pd.DataFrame, kwargs.get("mkt_data")) - self.pair_ = cast(TradingPair, kwargs.get("pair")) - self.end_index_ = ( - self.current_data_params_.training_start_index + self.max_training_size_ - ) + def advance(self, mkt_data_df: Optional[pd.DataFrame] = None) -> DataWindowParams: + super().advance(mkt_data_df) + if mkt_data_df is not None: + self.mkt_data_df_ = mkt_data_df + + if self.is_real_time_: + self.end_index_ = len(self.mkt_data_df_) - 1 + else: + self.end_index_ = self.current_data_params_.training_start_index + self.max_training_size_ + if self.end_index_ > len(self.mkt_data_df_) - 1: + self.end_index_ = len(self.mkt_data_df_) - 1 + self.current_data_params_.training_start_index = self.end_index_ - self.max_training_size_ + if self.current_data_params_.training_start_index < 0: + self.current_data_params_.training_start_index = 0 + col_a, col_b = self.pair_.colnames() self.prices_a_ = np.array(self.mkt_data_df_[col_a]) self.prices_b_ = np.array(self.mkt_data_df_[col_b]) - - def advance(self) -> DataWindowParams: - super().advance() self.current_data_params_ = self.optimize_window_size() - self.end_index_ += 1 return self.current_data_params_ @abstractmethod @@ -126,6 +141,9 @@ class EGOptimizedWndDataPolicy(OptimizedWndDataPolicy): last_pvalue = 1.0 result = copy.copy(self.current_data_params_) for trn_size in range(self.min_training_size_, self.max_training_size_): + if self.end_index_ - trn_size < 0: + break + from statsmodels.tsa.stattools import coint # type: ignore start_index = self.end_index_ - trn_size @@ -155,6 +173,8 @@ class ADFOptimizedWndDataPolicy(OptimizedWndDataPolicy): last_pvalue = 1.0 result = copy.copy(self.current_data_params_) for trn_size in range(self.min_training_size_, self.max_training_size_): + if self.end_index_ - trn_size < 0: + break start_index = self.end_index_ - trn_size y = self.prices_a_[start_index : self.end_index_] x = self.prices_b_[start_index : self.end_index_] @@ -201,6 +221,8 @@ class JohansenOptdWndDataPolicy(OptimizedWndDataPolicy): result = copy.copy(self.current_data_params_) for trn_size in range(self.min_training_size_, self.max_training_size_): + if self.end_index_ - trn_size < 0: + break start_index = self.end_index_ - trn_size series_a = self.prices_a_[start_index:self.end_index_] series_b = self.prices_b_[start_index:self.end_index_] diff --git a/lib/pt_strategy/prediction.py b/lib/pt_strategy/prediction.py new file mode 100644 index 0000000..8ae838f --- /dev/null +++ b/lib/pt_strategy/prediction.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from typing import Any, Dict + +import pandas as pd + + +class Prediction: + tstamp_: pd.Timestamp + disequilibrium_: float + scaled_disequilibrium_: float + + def __init__(self, tstamp: pd.Timestamp, disequilibrium: float, scaled_disequilibrium: float): + self.tstamp_ = tstamp + self.disequilibrium_ = disequilibrium + self.scaled_disequilibrium_ = scaled_disequilibrium + + def to_dict(self) -> Dict[str, Any]: + return { + "tstamp": self.tstamp_, + "disequilibrium": self.disequilibrium_, + "signed_scaled_disequilibrium": self.scaled_disequilibrium_, + "scaled_disequilibrium": abs(self.scaled_disequilibrium_), + # "pair": self.pair_, + } + def to_df(self) -> pd.DataFrame: + return pd.DataFrame([self.to_dict()]) + \ No newline at end of file diff --git a/lib/pt_strategy/pt_market_data.py b/lib/pt_strategy/pt_market_data.py index 1709f1b..ba41597 100644 --- a/lib/pt_strategy/pt_market_data.py +++ b/lib/pt_strategy/pt_market_data.py @@ -1,14 +1,14 @@ from __future__ import annotations -from abc import ABC, abstractmethod from typing import Any, Dict, List, Type - import pandas as pd +from cvttpy_base.settings.cvtt_types import JsonDictT + from tools.data_loader import load_market_data +from pt_strategy.trading_pair import TradingPair - -class PtMarketData(ABC): +class PtMarketData(): config_: Dict[str, Any] origin_mkt_data_df_: pd.DataFrame market_data_df_: pd.DataFrame @@ -16,27 +16,10 @@ class PtMarketData(ABC): def __init__(self, config: Dict[str, Any]): self.config_ = config self.origin_mkt_data_df_ = pd.DataFrame() + self.market_data_df_ = pd.DataFrame() - @abstractmethod - def load(self) -> None: - ... - - - @abstractmethod - def has_next(self) -> bool: - ... - - @abstractmethod - def get_next(self) -> pd.Series: - ... - - - @staticmethod - def create(config: Dict[str, Any], md_class: Type[PtMarketData]) -> PtMarketData: - return md_class(config) class ResearchMarketData(PtMarketData): - config_: Dict[str, Any] current_index_: int is_execution_price_: bool @@ -185,3 +168,25 @@ class ResearchMarketData(PtMarketData): f"exec_price_{self.symbol_b_}", ] +class RealTimeMarketData(PtMarketData): + + def __init__(self, config: Dict[str, Any], *args: Any, **kwargs: Any): + super().__init__(config, *args, **kwargs) + + async def on_mkt_data_hist_snapshot(self, snapshot: JsonDictT) -> None: + # URGENT + # create origin_mkt_data_df_ from snapshot + # transform it to market_data_df_ tstamp, close_symbolA, close_symbolB + pass + + async def on_mkt_data_update(self, update: JsonDictT) -> Optional[pd.DataFrame]: + # URGENT + # make sure update has both instruments + # create DataFrame tmp1 from update + # transform tmp1 into temp. datframe tmp2 + # add tmp1 to origin_mkt_data_df_ + # add tmp2 to market_data_df_ + # return market_data_df_ + + + return pd.DataFrame() \ No newline at end of file diff --git a/lib/pt_strategy/pt_model.py b/lib/pt_strategy/pt_model.py index 1f817f9..daf193d 100644 --- a/lib/pt_strategy/pt_model.py +++ b/lib/pt_strategy/pt_model.py @@ -1,39 +1,15 @@ from __future__ import annotations from abc import ABC, abstractmethod -from dataclasses import dataclass -from enum import Enum -from typing import Any, Dict, Optional, cast, Generator, List +from typing import Any, Dict, cast -import pandas as pd +from pt_strategy.prediction import Prediction -from pt_strategy.trading_pair import TradingPair - -class Prediction: - tstamp_: pd.Timestamp - disequilibrium_: float - scaled_disequilibrium_: float - - def __init__(self, tstamp: pd.Timestamp, disequilibrium: float, scaled_disequilibrium: float): - self.tstamp_ = tstamp - self.disequilibrium_ = disequilibrium - self.scaled_disequilibrium_ = scaled_disequilibrium - - def to_dict(self) -> Dict[str, Any]: - return { - "tstamp": self.tstamp_, - "disequilibrium": self.disequilibrium_, - "signed_scaled_disequilibrium": self.scaled_disequilibrium_, - "scaled_disequilibrium": abs(self.scaled_disequilibrium_), - # "pair": self.pair_, - } - def to_df(self) -> pd.DataFrame: - return pd.DataFrame([self.to_dict()]) class PairsTradingModel(ABC): @abstractmethod - def predict(self, pair: TradingPair) -> Prediction: + def predict(self, pair: TradingPair) -> Prediction: # type: ignore[assignment] ... @staticmethod diff --git a/lib/pt_strategy/trading_strategy.py b/lib/pt_strategy/research_strategy.py similarity index 97% rename from lib/pt_strategy/trading_strategy.py rename to lib/pt_strategy/research_strategy.py index fa0a8e8..4f4fafc 100644 --- a/lib/pt_strategy/trading_strategy.py +++ b/lib/pt_strategy/research_strategy.py @@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional import pandas as pd from pt_strategy.model_data_policy import ModelDataPolicy -from pt_strategy.pt_market_data import PtMarketData +from pt_strategy.pt_market_data import ResearchMarketData from pt_strategy.pt_model import Prediction from pt_strategy.trading_pair import PairState, TradingPair @@ -13,7 +13,7 @@ class PtResearchStrategy: config_: Dict[str, Any] trading_pair_: TradingPair model_data_policy_: ModelDataPolicy - pt_mkt_data_: PtMarketData + pt_mkt_data_: ResearchMarketData trades_: List[pd.DataFrame] predictions_: pd.DataFrame @@ -25,7 +25,6 @@ class PtResearchStrategy: instruments: List[Dict[str, str]], ): from pt_strategy.model_data_policy import ModelDataPolicy - from pt_strategy.pt_market_data import PtMarketData, ResearchMarketData from pt_strategy.trading_pair import TradingPair self.config_ = config @@ -39,9 +38,7 @@ class PtResearchStrategy: config_copy = copy.deepcopy(config) config_copy["instruments"] = instruments config_copy["datafiles"] = datafiles - self.pt_mkt_data_ = PtMarketData.create( - config=config_copy, md_class=ResearchMarketData - ) + self.pt_mkt_data_ = ResearchMarketData(config=config_copy) self.pt_mkt_data_.load() self.model_data_policy_ = ModelDataPolicy.create( config, mkt_data=self.pt_mkt_data_.market_data_df_, pair=self.trading_pair_ @@ -73,7 +70,7 @@ class PtResearchStrategy: market_data_df = pd.concat([market_data_df, new_row], ignore_index=True) prediction = self.trading_pair_.run( - market_data_df, self.model_data_policy_.advance() + market_data_df, self.model_data_policy_.advance(mkt_data_df=market_data_df) ) self.predictions_ = pd.concat( [self.predictions_, prediction.to_df()], ignore_index=True diff --git a/lib/pt_strategy/trading_pair.py b/lib/pt_strategy/trading_pair.py index 89da29b..9b96adb 100644 --- a/lib/pt_strategy/trading_pair.py +++ b/lib/pt_strategy/trading_pair.py @@ -1,12 +1,13 @@ from __future__ import annotations -from abc import ABC, abstractmethod from datetime import datetime from enum import Enum -from typing import Any, Dict, Generator, List, Optional, Type, cast +from typing import Any, Dict, List import pandas as pd + from pt_strategy.model_data_policy import DataWindowParams +from pt_strategy.prediction import Prediction class PairState(Enum): @@ -20,11 +21,12 @@ class PairState(Enum): class TradingPair: config_: Dict[str, Any] market_data_: pd.DataFrame + instruments_: List[Dict[str, str]] symbol_a_: str symbol_b_: str stat_model_price_: str - model_: PairsTradingModel # type: ignore[assignment] + model_: PairsTradingModel # type: ignore[assignment] user_data_: Dict[str, Any] @@ -34,11 +36,12 @@ class TradingPair: instruments: List[Dict[str, str]], ): - from pt_strategy.model_data_policy import ModelDataPolicy from pt_strategy.pt_model import PairsTradingModel + assert len(instruments) == 2, "Trading pair must have exactly 2 instruments" self.config_ = config + self.instruments_ = instruments self.symbol_a_ = instruments[0]["symbol"] self.symbol_b_ = instruments[1]["symbol"] self.model_ = PairsTradingModel.create(config) diff --git a/lib/tools/viz/viz_prices.py b/lib/tools/viz/viz_prices.py index 2575176..418426e 100644 --- a/lib/tools/viz/viz_prices.py +++ b/lib/tools/viz/viz_prices.py @@ -1,4 +1,4 @@ -from pt_strategy.trading_strategy import PtResearchStrategy +from pt_strategy.research_strategy import PtResearchStrategy def visualize_prices(strategy: PtResearchStrategy, trading_date: str) -> None: diff --git a/lib/tools/viz/viz_trades.py b/lib/tools/viz/viz_trades.py index 209bb39..274fbd1 100644 --- a/lib/tools/viz/viz_trades.py +++ b/lib/tools/viz/viz_trades.py @@ -5,7 +5,7 @@ from typing import Any, Dict from pt_strategy.results import (PairResearchResult, create_result_database, store_config_in_database) -from pt_strategy.trading_strategy import PtResearchStrategy +from pt_strategy.research_strategy import PtResearchStrategy from tools.filetools import resolve_datafiles from tools.instruments import get_instruments diff --git a/pyrightconfig.json b/pyrightconfig.json index 9bf3138..c2d10fd 100644 --- a/pyrightconfig.json +++ b/pyrightconfig.json @@ -16,7 +16,8 @@ "autoImportCompletions": true, "autoSearchPaths": true, "extraPaths": [ - "lib" + "lib", + ".." ], "stubPath": "./typings", "venvPath": ".", diff --git a/research/backtest.py b/research/backtest.py index daad2fb..4e4f202 100644 --- a/research/backtest.py +++ b/research/backtest.py @@ -8,7 +8,7 @@ from pt_strategy.results import ( create_result_database, store_config_in_database, ) -from pt_strategy.trading_strategy import PtResearchStrategy +from pt_strategy.research_strategy import PtResearchStrategy from tools.filetools import resolve_datafiles from tools.instruments import get_instruments diff --git a/research/notebooks/pair_trading_test.ipynb b/research/notebooks/pair_trading_test.ipynb index 04e81c1..c08187b 100644 --- a/research/notebooks/pair_trading_test.ipynb +++ b/research/notebooks/pair_trading_test.ipynb @@ -288,7 +288,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -302,7 +302,7 @@ "\n", " \n", " from pt_strategy.trading_pair import TradingPair\n", - " from pt_strategy.trading_strategy import PtResearchStrategy\n", + " from pt_strategy.research_strategy import PtResearchStrategy\n", " from pt_strategy.results import PairResearchResult\n", "\n", " # Create trading pair\n", diff --git a/research/research_tools.py b/research/tools/research_tools.py similarity index 100% rename from research/research_tools.py rename to research/tools/research_tools.py diff --git a/strategy/pair_strategy.py b/strategy/pair_strategy.py deleted file mode 100644 index a7f7604..0000000 --- a/strategy/pair_strategy.py +++ /dev/null @@ -1,221 +0,0 @@ -import argparse -import asyncio -import glob -import importlib -import os -from datetime import date, datetime -from typing import Any, Dict, List, Optional - -import hjson -import pandas as pd - -from tools.data_loader import get_available_instruments_from_db, load_market_data -from pt_trading.results import ( - BacktestResult, - create_result_database, - store_config_in_database, - store_results_in_database, -) -from pt_trading.fit_methods import PairsTradingFitMethod -from pt_trading.trading_pair import TradingPair - - -def run_strategy( - config: Dict, - datafile: str, - fit_method: PairsTradingFitMethod, - instruments: List[str], -) -> BacktestResult: - """ - Run backtest for all pairs using the specified instruments. - """ - bt_result: BacktestResult = BacktestResult(config=config) - - def _create_pairs(config: Dict, instruments: List[str]) -> List[TradingPair]: - nonlocal datafile - all_indexes = range(len(instruments)) - unique_index_pairs = [(i, j) for i in all_indexes for j in all_indexes if i < j] - pairs = [] - - # Update config to use the specified instruments - config_copy = config.copy() - config_copy["instruments"] = instruments - - market_data_df = load_market_data( - datafile=datafile, - exchange_id=config_copy["exchange_id"], - instruments=config_copy["instruments"], - instrument_id_pfx=config_copy["instrument_id_pfx"], - db_table_name=config_copy["db_table_name"], - trading_hours=config_copy["trading_hours"], - ) - - for a_index, b_index in unique_index_pairs: - pair = fit_method.create_trading_pair( - market_data=market_data_df, - symbol_a=instruments[a_index], - symbol_b=instruments[b_index], - ) - pairs.append(pair) - return pairs - - pairs_trades = [] - for pair in _create_pairs(config, instruments): - single_pair_trades = fit_method.run_pair( - pair=pair, config=config, bt_result=bt_result - ) - if single_pair_trades is not None and len(single_pair_trades) > 0: - pairs_trades.append(single_pair_trades) - - # Check if result_list has any data before concatenating - if len(pairs_trades) == 0: - print("No trading signals found for any pairs") - return bt_result - - result = pd.concat(pairs_trades, ignore_index=True) - result["time"] = pd.to_datetime(result["time"]) - result = result.set_index("time").sort_index() - - bt_result.collect_single_day_results(result) - return bt_result - - -def main() -> None: - parser = argparse.ArgumentParser(description="Run pairs trading backtest.") - parser.add_argument( - "--config", type=str, required=True, help="Path to the configuration file." - ) - parser.add_argument( - "--datafiles", - type=str, - required=False, - help="Comma-separated list of data files (overrides config). No wildcards supported.", - ) - parser.add_argument( - "--instruments", - type=str, - required=False, - help="Comma-separated list of instrument symbols (e.g., COIN,GBTC). If not provided, auto-detects from database.", - ) - parser.add_argument( - "--result_db", - type=str, - required=True, - help="Path to SQLite database for storing results. Use 'NONE' to disable database output.", - ) - - args = parser.parse_args() - - config: Dict = load_config(args.config) - - # Dynamically instantiate fit method class - fit_method_class_name = config.get("fit_method_class", None) - assert fit_method_class_name is not None - module_name, class_name = fit_method_class_name.rsplit(".", 1) - module = importlib.import_module(module_name) - fit_method = getattr(module, class_name)() - - # Resolve data files (CLI takes priority over config) - datafiles = resolve_datafiles(config, args.datafiles) - - if not datafiles: - print("No data files found to process.") - return - - print(f"Found {len(datafiles)} data files to process:") - for df in datafiles: - print(f" - {df}") - - # Create result database if needed - if args.result_db.upper() != "NONE": - create_result_database(args.result_db) - - # Initialize a dictionary to store all trade results - all_results: Dict[str, Dict[str, Any]] = {} - - # Store configuration in database for reference - if args.result_db.upper() != "NONE": - # Get list of all instruments for storage - all_instruments = [] - for datafile in datafiles: - if args.instruments: - file_instruments = [ - inst.strip() for inst in args.instruments.split(",") - ] - else: - file_instruments = get_available_instruments_from_db(datafile, config) - all_instruments.extend(file_instruments) - - # Remove duplicates while preserving order - unique_instruments = list(dict.fromkeys(all_instruments)) - - store_config_in_database( - db_path=args.result_db, - config_file_path=args.config, - config=config, - fit_method_class=fit_method_class_name, - datafiles=datafiles, - instruments=unique_instruments, - ) - - # Process each data file - - for datafile in datafiles: - print(f"\n====== Processing {os.path.basename(datafile)} ======") - - # Determine instruments to use - if args.instruments: - # Use CLI-specified instruments - instruments = [inst.strip() for inst in args.instruments.split(",")] - print(f"Using CLI-specified instruments: {instruments}") - else: - # Auto-detect instruments from database - instruments = get_available_instruments_from_db(datafile, config) - print(f"Auto-detected instruments: {instruments}") - - if not instruments: - print(f"No instruments found for {datafile}, skipping...") - continue - - # Process data for this file - try: - fit_method.reset() - - bt_results = run_strategy( - config=config, - datafile=datafile, - fit_method=fit_method, - instruments=instruments, - ) - - # Store results with file name as key - filename = os.path.basename(datafile) - all_results[filename] = {"trades": bt_results.trades.copy()} - - # Store results in database - if args.result_db.upper() != "NONE": - store_results_in_database(args.result_db, datafile, bt_results) - - print(f"Successfully processed {filename}") - - except Exception as err: - print(f"Error processing {datafile}: {str(err)}") - import traceback - - traceback.print_exc() - - # Calculate and print results using a new BacktestResult instance for aggregation - if all_results: - aggregate_bt_results = BacktestResult(config=config) - aggregate_bt_results.calculate_returns(all_results) - aggregate_bt_results.print_grand_totals() - aggregate_bt_results.print_outstanding_positions() - - if args.result_db.upper() != "NONE": - print(f"\nResults stored in database: {args.result_db}") - else: - print("No results to display.") - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/tests/viz_test.py b/tests/viz_test.py index e10d31b..e35f5b7 100644 --- a/tests/viz_test.py +++ b/tests/viz_test.py @@ -5,7 +5,7 @@ from typing import Any, Dict from pt_strategy.results import (PairResearchResult, create_result_database, store_config_in_database) -from pt_strategy.trading_strategy import PtResearchStrategy +from pt_strategy.research_strategy import PtResearchStrategy from tools.filetools import resolve_datafiles from tools.instruments import get_instruments from tools.viz.viz_trades import visualize_trades