This commit is contained in:
Oleg Sheynin 2025-08-02 00:12:31 +00:00
parent 80c3e8d54b
commit 73f36ddcea
21 changed files with 812 additions and 321 deletions

1
.gitignore vendored
View File

@ -5,7 +5,6 @@ __OLD__/
.history/ .history/
.cursorindexingignore .cursorindexingignore
data data
.vscode/
cvttpy cvttpy
# SpecStory explanation file # SpecStory explanation file
.specstory/.what-is-this.md .specstory/.what-is-this.md

181
__DELETE__/.vscode/launch.json vendored Normal file
View File

@ -0,0 +1,181 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python Debugger: Current File",
"type": "debugpy",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal"
},
{
"name": "-------- Z-Score (OLS) --------",
},
{
"name": "CRYPTO z-score",
"type": "debugpy",
"request": "launch",
"python": "/home/oleg/.pyenv/python3.12-venv/bin/python",
"program": "research/pt_backtest.py",
"args": [
"--config=${workspaceFolder}/configuration/zscore.cfg",
"--instruments=ADA-USDT:CRYPTO:BNBSPOT,SOL-USDT:CRYPTO:BNBSPOT",
"--date_pattern=20250605",
"--result_db=${workspaceFolder}/research/results/crypto/%T.z-score.ADA-SOL.20250602.crypto_results.db",
],
"env": {
"PYTHONPATH": "${workspaceFolder}/lib"
},
"console": "integratedTerminal"
},
{
"name": "EQUITY z-score",
"type": "debugpy",
"request": "launch",
"python": "/home/oleg/.pyenv/python3.12-venv/bin/python",
"program": "research/pt_backtest.py",
"args": [
"--config=${workspaceFolder}/configuration/zscore.cfg",
"--instruments=COIN:EQUITY:ALPACA,MSTR:EQUITY:ALPACA",
"--date_pattern=2025060*",
"--result_db=${workspaceFolder}/research/results/equity/%T.z-score.COIN-MSTR.20250602.equity_results.db",
],
"env": {
"PYTHONPATH": "${workspaceFolder}/lib"
},
"console": "integratedTerminal"
},
{
"name": "EQUITY-CRYPTO z-score",
"type": "debugpy",
"request": "launch",
"python": "/home/oleg/.pyenv/python3.12-venv/bin/python",
"program": "research/pt_backtest.py",
"args": [
"--config=${workspaceFolder}/configuration/zscore.cfg",
"--instruments=COIN:EQUITY:ALPACA,BTC-USDT:CRYPTO:BNBSPOT",
"--date_pattern=2025060*",
"--result_db=${workspaceFolder}/research/results/intermarket/%T.z-score.COIN-BTC.20250601.equity_results.db",
],
"env": {
"PYTHONPATH": "${workspaceFolder}/lib"
},
"console": "integratedTerminal"
},
{
"name": "-------- VECM --------",
},
{
"name": "CRYPTO vecm",
"type": "debugpy",
"request": "launch",
"python": "/home/oleg/.pyenv/python3.12-venv/bin/python",
"program": "research/pt_backtest.py",
"args": [
"--config=${workspaceFolder}/configuration/vecm.cfg",
"--instruments=ADA-USDT:CRYPTO:BNBSPOT,SOL-USDT:CRYPTO:BNBSPOT",
"--date_pattern=2025060*",
"--result_db=${workspaceFolder}/research/results/crypto/%T.vecm.ADA-SOL.20250602.crypto_results.db",
],
"env": {
"PYTHONPATH": "${workspaceFolder}/lib"
},
"console": "integratedTerminal"
},
{
"name": "EQUITY vecm",
"type": "debugpy",
"request": "launch",
"python": "/home/oleg/.pyenv/python3.12-venv/bin/python",
"program": "research/pt_backtest.py",
"args": [
"--config=${workspaceFolder}/configuration/vecm.cfg",
"--instruments=COIN:EQUITY:ALPACA,MSTR:EQUITY:ALPACA",
"--date_pattern=2025060*",
"--result_db=${workspaceFolder}/research/results/equity/%T.vecm.COIN-MSTR.20250602.equity_results.db",
],
"env": {
"PYTHONPATH": "${workspaceFolder}/lib"
},
"console": "integratedTerminal"
},
{
"name": "EQUITY-CRYPTO vecm",
"type": "debugpy",
"request": "launch",
"python": "/home/oleg/.pyenv/python3.12-venv/bin/python",
"program": "research/pt_backtest.py",
"args": [
"--config=${workspaceFolder}/configuration/vecm.cfg",
"--instruments=COIN:EQUITY:ALPACA,BTC-USDT:CRYPTO:BNBSPOT",
"--date_pattern=2025060*",
"--result_db=${workspaceFolder}/research/results/intermarket/%T.vecm.COIN-BTC.20250601.equity_results.db",
],
"env": {
"PYTHONPATH": "${workspaceFolder}/lib"
},
"console": "integratedTerminal"
},
{
"name": "-------- New ZSCORE --------",
},
{
"name": "New CRYPTO z-score",
"type": "debugpy",
"request": "launch",
"python": "/home/oleg/.pyenv/python3.12-venv/bin/python",
"program": "${workspaceFolder}/research/backtest_new.py",
"args": [
"--config=${workspaceFolder}/configuration/new_zscore.cfg",
"--instruments=ADA-USDT:CRYPTO:BNBSPOT,SOL-USDT:CRYPTO:BNBSPOT",
"--date_pattern=2025060*",
"--result_db=${workspaceFolder}/research/results/crypto/%T.new_zscore.ADA-SOL.2025060-.crypto_results.db",
],
"env": {
"PYTHONPATH": "${workspaceFolder}/lib"
},
"console": "integratedTerminal"
},
{
"name": "New CRYPTO vecm",
"type": "debugpy",
"request": "launch",
"python": "/home/oleg/.pyenv/python3.12-venv/bin/python",
"program": "${workspaceFolder}/research/backtest_new.py",
"args": [
"--config=${workspaceFolder}/configuration/new_vecm.cfg",
"--instruments=ADA-USDT:CRYPTO:BNBSPOT,SOL-USDT:CRYPTO:BNBSPOT",
"--date_pattern=20250605",
"--result_db=${workspaceFolder}/research/results/crypto/%T.vecm.ADA-SOL.20250605.crypto_results.db",
],
"env": {
"PYTHONPATH": "${workspaceFolder}/lib"
},
"console": "integratedTerminal"
},
{
"name": "-------- Viz Test --------",
},
{
"name": "Viz Test",
"type": "debugpy",
"request": "launch",
"python": "/home/oleg/.pyenv/python3.12-venv/bin/python",
"program": "${workspaceFolder}/research/viz_test.py",
"args": [
"--config=${workspaceFolder}/configuration/new_zscore.cfg",
"--instruments=ADA-USDT:CRYPTO:BNBSPOT,SOL-USDT:CRYPTO:BNBSPOT",
"--date_pattern=20250605",
],
"env": {
"PYTHONPATH": "${workspaceFolder}/lib"
},
"console": "integratedTerminal"
}
]
}

View File

@ -0,0 +1,101 @@
import argparse
import asyncio
import glob
import importlib
import os
from datetime import date, datetime
from typing import Any, Dict, List, Optional
import hjson
import pandas as pd
from tools.data_loader import get_available_instruments_from_db, load_market_data
from pt_trading.results import (
BacktestResult,
create_result_database,
store_config_in_database,
store_results_in_database,
)
from pt_trading.fit_methods import PairsTradingFitMethod
from pt_trading.trading_pair import TradingPair
def run_strategy(
config: Dict,
datafile: str,
fit_method: PairsTradingFitMethod,
instruments: List[str],
) -> BacktestResult:
"""
Run backtest for all pairs using the specified instruments.
"""
bt_result: BacktestResult = BacktestResult(config=config)
def _create_pairs(config: Dict, instruments: List[str]) -> List[TradingPair]:
nonlocal datafile
all_indexes = range(len(instruments))
unique_index_pairs = [(i, j) for i in all_indexes for j in all_indexes if i < j]
pairs = []
# Update config to use the specified instruments
config_copy = config.copy()
config_copy["instruments"] = instruments
market_data_df = load_market_data(
datafile=datafile,
exchange_id=config_copy["exchange_id"],
instruments=config_copy["instruments"],
instrument_id_pfx=config_copy["instrument_id_pfx"],
db_table_name=config_copy["db_table_name"],
trading_hours=config_copy["trading_hours"],
)
for a_index, b_index in unique_index_pairs:
pair = fit_method.create_trading_pair(
market_data=market_data_df,
symbol_a=instruments[a_index],
symbol_b=instruments[b_index],
)
pairs.append(pair)
return pairs
pairs_trades = []
for pair in _create_pairs(config, instruments):
single_pair_trades = fit_method.run_pair(
pair=pair, config=config, bt_result=bt_result
)
if single_pair_trades is not None and len(single_pair_trades) > 0:
pairs_trades.append(single_pair_trades)
# Check if result_list has any data before concatenating
if len(pairs_trades) == 0:
print("No trading signals found for any pairs")
return bt_result
result = pd.concat(pairs_trades, ignore_index=True)
result["time"] = pd.to_datetime(result["time"])
result = result.set_index("time").sort_index()
bt_result.collect_single_day_results(result)
return bt_result
def main() -> None:
# Load config
# Subscribe to CVTT market data
# On snapshot (with historical data) - create trading strategy with market data dateframe
async def on_message(message_type: MessageTypeT, subscr_id: SubscriptionIdT, message: Dict, instrument_id: str) -> None:
print(f"{message_type=} {subscr_id=} {instrument_id}")
if message_type == "md_aggregate":
aggr = message.get("md_aggregate", [])
print(f"[{aggr['tstmp'][:19]}] *** RLTM *** {message}")
elif message_type == "historical_md_aggregate":
for aggr in message.get("historical_data", []):
print(f"[{aggr['tstmp'][:19]}] *** HIST *** {aggr}")
else:
print(f"Unknown message type: {message_type}")
if __name__ == "__main__":
asyncio.run(main())

69
bin/trade_pair.py Normal file
View File

@ -0,0 +1,69 @@
from functools import partial
from typing import Dict
from cvtt_client.mkt_data import (CvttPricerWebSockClient,
CvttPricesSubscription, MessageTypeT,
SubscriptionIdT)
from cvttpy_base.tools.app import App
from cvttpy_base.tools.base import NamedObject
from pt_strategy.live_strategy import PtLiveStrategy
class PairTradingRunner(NamedObject):
def __init__(self) -> None:
super().__init__()
App.instance().add_call(App.Stage.Config, self._on_config())
App.instance().add_call(App.Stage.Run, self.run())
async def _on_config(self) -> None:
pass
async def run(self) -> None:
pass
# async def main() -> None:
# live_strategy = PtLiveStrategy(
# config={},
# instruments=[
# {"exchange_config_name": "COINBASE_AT", "instrument_id": "PAIR-BTC-USD"},
# {"exchange_config_name": "COINBASE_AT", "instrument_id": "PAIR-ETH-USD"},
# ]
# )
# async def on_message(message_type: MessageTypeT, subscr_id: SubscriptionIdT, message: Dict, instrument_id: str) -> None:
# print(f"{message_type=} {subscr_id=} {instrument_id}")
# if message_type == "md_aggregate":
# aggr = message.get("md_aggregate", [])
# print(f"[{aggr['tstmp'][:19]}] *** RLTM *** {message}")
# elif message_type == "historical_md_aggregate":
# for aggr in message.get("historical_data", []):
# print(f"[{aggr['tstmp'][:19]}] *** HIST *** {aggr}")
# else:
# print(f"Unknown message type: {message_type}")
# pricer_client = CvttPricerWebSockClient(
# "ws://localhost:12346/ws"
# )
# await pricer_client.subscribe(CvttPricesSubscription(
# exchange_config_name="COINBASE_AT",
# instrument_id="PAIR-BTC-USD",
# interval_sec=60,
# history_depth_sec=60*60*24,
# callback=partial(on_message, instrument_id="PAIR-BTC-USD")
# ))
# await pricer_client.subscribe(CvttPricesSubscription(
# exchange_config_name="COINBASE_AT",
# instrument_id="PAIR-ETH-USD",
# interval_sec=60,
# history_depth_sec=60*60*24,
# callback=partial(on_message, instrument_id="PAIR-ETH-USD")
# ))
# await pricer_client.run()
if __name__ == "__main__":
App()
App.instance().run()

View File

@ -24,8 +24,7 @@
"dis-equilibrium_close_trshld": 0.9, "dis-equilibrium_close_trshld": 0.9,
"model_class": "pt_strategy.models.OLSModel", "model_class": "pt_strategy.models.OLSModel",
# "training_size": 120, # "model_data_policy_class": "pt_strategy.model_data_policy.EGOptimizedWndDataPolicy",
# "model_data_policy_class": "pt_strategy.model_data_policy.RollingWindowDataPolicy",
# "model_data_policy_class": "pt_strategy.model_data_policy.ADFOptimizedWndDataPolicy", # "model_data_policy_class": "pt_strategy.model_data_policy.ADFOptimizedWndDataPolicy",
"model_data_policy_class": "pt_strategy.model_data_policy.JohansenOptdWndDataPolicy", "model_data_policy_class": "pt_strategy.model_data_policy.JohansenOptdWndDataPolicy",
"min_training_size": 60, "min_training_size": 60,

View File

@ -0,0 +1,331 @@
from __future__ import annotations
from functools import partial
from typing import Any, Dict, List, Optional
import pandas as pd
from cvttpy_base.settings.cvtt_types import JsonDictT
from cvttpy_base.tools.base import NamedObject
from cvttpy_base.tools.logger import Log
from cvtt_client.mkt_data import CvttPricerWebSockClient, CvttPricesSubscription, MessageTypeT, SubscriptionIdT
from pt_strategy.model_data_policy import ModelDataPolicy
from pt_strategy.pt_market_data import PtMarketData, RealTimeMarketData
from pt_strategy.pt_model import Prediction
from pt_strategy.trading_pair import PairState, TradingPair
'''
--config=pair.cfg
--pair=PAIR-BTC-USDT:COINBASE_AT,PAIR-ETH-USDT:COINBASE_AT
'''
class PtMktDataClient(NamedObject):
live_strategy_: PtLiveStrategy
pricer_client_: CvttPricerWebSockClient
subscriptions_: List[CvttPricesSubscription]
def __init__(self, live_strategy: PtLiveStrategy):
self.live_strategy_ = live_strategy
async def start(self, subscription: CvttPricesSubscription) -> None:
pricer_url = self.live_strategy_.config_.get("pricer_url", None) #, "ws://localhost:12346/ws")
assert pricer_url is not None, "pricer_url is not found in config"
self.pricer_client_ = CvttPricerWebSockClient(url=pricer_url)
await self._subscribe()
async def _subscribe(self) -> None:
pair: TradingPair = self.live_strategy_.trading_pair_
for instrument in pair.instruments_:
await self.pricer_client_.subscribe(CvttPricesSubscription(
exchange_config_name=instrument["exchange_config_name"],
instrument_id=instrument["instrument_id"],
interval_sec=60,
history_depth_sec=60*60*24,
callback=partial(self.on_message, instrument_id=instrument["instrument_id"])
))
async def on_message(self, message_type: MessageTypeT, subscr_id: SubscriptionIdT, message: Dict, instrument_id: str) -> None:
Log.info(f"{self.fname()}: {message_type=} {subscr_id=} {instrument_id}")
aggr: JsonDictT
if message_type == "md_aggregate":
aggr = message.get("md_aggregate", {})
await self.live_strategy_.on_mkt_data_update(aggr)
# print(f"[{aggr['tstmp'][:19]}] *** RLTM *** {message}")
elif message_type == "historical_md_aggregate":
aggr = message.get("historical_data", {})
await self.live_strategy_.on_mkt_data_hist_snapshot(aggr)
# print(f"[{aggr['tstmp'][:19]}] *** HIST *** {aggr}")
else:
Log.info(f"Unknown message type: {message_type}")
async def run(self) -> None:
await self.pricer_client_.run()
class PtLiveStrategy(NamedObject):
config_: Dict[str, Any]
trading_pair_: TradingPair
model_data_policy_: ModelDataPolicy
pt_mkt_data_: RealTimeMarketData
pt_mkt_data_client_: PtMktDataClient
# for presentation: history of prediction values and trading signals
predictions_: pd.DataFrame
trading_signals_: pd.DataFrame
def __init__(
self,
config: Dict[str, Any],
instruments: List[Dict[str, str]],
):
self.config_ = config
self.trading_pair_ = TradingPair(config=config, instruments=instruments)
self.predictions_ = pd.DataFrame()
self.trading_signals_ = pd.DataFrame()
import copy
# modified config must be passed to PtMarketData
config_copy = copy.deepcopy(config)
config_copy["instruments"] = instruments
self.pt_mkt_data_ = RealTimeMarketData(config=config_copy)
self.model_data_policy_ = ModelDataPolicy.create(
config, is_real_time=True,pair=self.trading_pair_
)
async def on_mkt_data_hist_snapshot(self, aggr: JsonDictT) -> None:
Log.info(f"on_mkt_data_hist_snapshot: {aggr}")
await self.pt_mkt_data_.on_mkt_data_hist_snapshot(snapshot=aggr)
pass
async def on_mkt_data_update(self, aggr: JsonDictT) -> None:
market_data_df = await self.pt_mkt_data_.on_mkt_data_update(update=aggr)
if market_data_df is not None:
self.trading_pair_.market_data_ = market_data_df
self.model_data_policy_.advance()
prediction = self.trading_pair_.run(market_data_df, self.model_data_policy_.advance())
self.predictions_ = pd.concat([self.predictions_, prediction.to_df()], ignore_index=True)
trades = self._create_trades(prediction=prediction, last_row=market_data_df.iloc[-1])
# URGENT implement this
pass
async def run(self) -> None:
await self.pt_mkt_data_client_.run()
def _create_trades(
self, prediction: Prediction, last_row: pd.Series
) -> Optional[pd.DataFrame]:
pair = self.trading_pair_
trades = None
open_threshold = self.config_["dis-equilibrium_open_trshld"]
close_threshold = self.config_["dis-equilibrium_close_trshld"]
scaled_disequilibrium = prediction.scaled_disequilibrium_
abs_scaled_disequilibrium = abs(scaled_disequilibrium)
if pair.user_data_["state"] in [
PairState.INITIAL,
PairState.CLOSE,
PairState.CLOSE_POSITION,
PairState.CLOSE_STOP_LOSS,
PairState.CLOSE_STOP_PROFIT,
]:
if abs_scaled_disequilibrium >= open_threshold:
trades = self._create_open_trades(
pair, row=last_row, prediction=prediction
)
if trades is not None:
trades["status"] = PairState.OPEN.name
print(f"OPEN TRADES:\n{trades}")
pair.user_data_["state"] = PairState.OPEN
pair.on_open_trades(trades)
elif pair.user_data_["state"] == PairState.OPEN:
if abs_scaled_disequilibrium <= close_threshold:
trades = self._create_close_trades(
pair, row=last_row, prediction=prediction
)
if trades is not None:
trades["status"] = PairState.CLOSE.name
print(f"CLOSE TRADES:\n{trades}")
pair.user_data_["state"] = PairState.CLOSE
pair.on_close_trades(trades)
elif pair.to_stop_close_conditions(predicted_row=last_row):
trades = self._create_close_trades(pair, row=last_row)
if trades is not None:
trades["status"] = pair.user_data_["stop_close_state"].name
print(f"STOP CLOSE TRADES:\n{trades}")
pair.user_data_["state"] = pair.user_data_["stop_close_state"]
pair.on_close_trades(trades)
return trades
def _handle_outstanding_positions(self) -> Optional[pd.DataFrame]:
trades = None
pair = self.trading_pair_
# Outstanding positions
if pair.user_data_["state"] == PairState.OPEN:
print(f"{pair}: *** Position is NOT CLOSED. ***")
# outstanding positions
if self.config_["close_outstanding_positions"]:
close_position_row = pd.Series(pair.market_data_.iloc[-2])
# close_position_row["disequilibrium"] = 0.0
# close_position_row["scaled_disequilibrium"] = 0.0
# close_position_row["signed_scaled_disequilibrium"] = 0.0
trades = self._create_close_trades(
pair=pair, row=close_position_row, prediction=None
)
if trades is not None:
trades["status"] = PairState.CLOSE_POSITION.name
print(f"CLOSE_POSITION TRADES:\n{trades}")
pair.user_data_["state"] = PairState.CLOSE_POSITION
pair.on_close_trades(trades)
else:
pair.add_outstanding_position(
symbol=pair.symbol_a_,
open_side=pair.user_data_["open_side_a"],
open_px=pair.user_data_["open_px_a"],
open_tstamp=pair.user_data_["open_tstamp"],
last_mkt_data_row=pair.market_data_.iloc[-1],
)
pair.add_outstanding_position(
symbol=pair.symbol_b_,
open_side=pair.user_data_["open_side_b"],
open_px=pair.user_data_["open_px_b"],
open_tstamp=pair.user_data_["open_tstamp"],
last_mkt_data_row=pair.market_data_.iloc[-1],
)
return trades
def _trades_df(self) -> pd.DataFrame:
types = {
"time": "datetime64[ns]",
"action": "string",
"symbol": "string",
"side": "string",
"price": "float64",
"disequilibrium": "float64",
"scaled_disequilibrium": "float64",
"signed_scaled_disequilibrium": "float64",
# "pair": "object",
}
columns = list(types.keys())
return pd.DataFrame(columns=columns).astype(types)
def _create_open_trades(
self, pair: TradingPair, row: pd.Series, prediction: Prediction
) -> Optional[pd.DataFrame]:
colname_a, colname_b = pair.exec_prices_colnames()
tstamp = row["tstamp"]
diseqlbrm = prediction.disequilibrium_
scaled_disequilibrium = prediction.scaled_disequilibrium_
px_a = row[f"{colname_a}"]
px_b = row[f"{colname_b}"]
# creating the trades
df = self._trades_df()
print(f"OPEN_TRADES: {row["tstamp"]} {scaled_disequilibrium=}")
if diseqlbrm > 0:
side_a = "SELL"
side_b = "BUY"
else:
side_a = "BUY"
side_b = "SELL"
# save closing sides
pair.user_data_["open_side_a"] = side_a # used in oustanding positions
pair.user_data_["open_side_b"] = side_b
pair.user_data_["open_px_a"] = px_a
pair.user_data_["open_px_b"] = px_b
pair.user_data_["open_tstamp"] = tstamp
pair.user_data_["close_side_a"] = side_b # used for closing trades
pair.user_data_["close_side_b"] = side_a
# create opening trades
df.loc[len(df)] = {
"time": tstamp,
"symbol": pair.symbol_a_,
"side": side_a,
"action": "OPEN",
"price": px_a,
"disequilibrium": diseqlbrm,
"signed_scaled_disequilibrium": scaled_disequilibrium,
"scaled_disequilibrium": abs(scaled_disequilibrium),
# "pair": pair,
}
df.loc[len(df)] = {
"time": tstamp,
"symbol": pair.symbol_b_,
"side": side_b,
"action": "OPEN",
"price": px_b,
"disequilibrium": diseqlbrm,
"scaled_disequilibrium": abs(scaled_disequilibrium),
"signed_scaled_disequilibrium": scaled_disequilibrium,
# "pair": pair,
}
return df
def _create_close_trades(
self, pair: TradingPair, row: pd.Series, prediction: Optional[Prediction] = None
) -> Optional[pd.DataFrame]:
colname_a, colname_b = pair.exec_prices_colnames()
tstamp = row["tstamp"]
if prediction is not None:
diseqlbrm = prediction.disequilibrium_
signed_scaled_disequilibrium = prediction.scaled_disequilibrium_
scaled_disequilibrium = abs(prediction.scaled_disequilibrium_)
else:
diseqlbrm = 0.0
signed_scaled_disequilibrium = 0.0
scaled_disequilibrium = 0.0
px_a = row[f"{colname_a}"]
px_b = row[f"{colname_b}"]
# creating the trades
df = self._trades_df()
# create opening trades
df.loc[len(df)] = {
"time": tstamp,
"symbol": pair.symbol_a_,
"side": pair.user_data_["close_side_a"],
"action": "CLOSE",
"price": px_a,
"disequilibrium": diseqlbrm,
"scaled_disequilibrium": scaled_disequilibrium,
"signed_scaled_disequilibrium": signed_scaled_disequilibrium,
# "pair": pair,
}
df.loc[len(df)] = {
"time": tstamp,
"symbol": pair.symbol_b_,
"side": pair.user_data_["close_side_b"],
"action": "CLOSE",
"price": px_b,
"disequilibrium": diseqlbrm,
"scaled_disequilibrium": scaled_disequilibrium,
"signed_scaled_disequilibrium": signed_scaled_disequilibrium,
# "pair": pair,
}
del pair.user_data_["close_side_a"]
del pair.user_data_["close_side_b"]
del pair.user_data_["open_tstamp"]
del pair.user_data_["open_px_a"]
del pair.user_data_["open_px_b"]
del pair.user_data_["open_side_a"]
del pair.user_data_["open_side_b"]
return df

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import copy import copy
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, cast from typing import Any, Dict, Optional, cast
import numpy as np import numpy as np
import pandas as pd import pandas as pd
@ -19,17 +19,26 @@ class ModelDataPolicy(ABC):
config_: Dict[str, Any] config_: Dict[str, Any]
current_data_params_: DataWindowParams current_data_params_: DataWindowParams
count_: int count_: int
is_real_time_: bool
def __init__(self, config: Dict[str, Any]): def __init__(self, config: Dict[str, Any], *args: Any, **kwargs: Any):
self.config_ = config self.config_ = config
training_size = config.get("training_size", 120)
training_start_index = 0
if kwargs.get("is_real_time", False):
training_size = 120
training_start_index = 0
else:
training_size = config.get("training_size", 120)
self.current_data_params_ = DataWindowParams( self.current_data_params_ = DataWindowParams(
training_size=config.get("training_size", 120), training_size=config.get("training_size", 120),
training_start_index=0, training_start_index=0,
) )
self.count_ = 0 self.count_ = 0
self.is_real_time_ = kwargs.get("is_real_time", False)
@abstractmethod @abstractmethod
def advance(self) -> DataWindowParams: def advance(self, mkt_data_df: Optional[pd.DataFrame] = None) -> DataWindowParams:
self.count_ += 1 self.count_ += 1
print(self.count_, end="\r") print(self.count_, end="\r")
return self.current_data_params_ return self.current_data_params_
@ -50,25 +59,18 @@ class ModelDataPolicy(ABC):
class RollingWindowDataPolicy(ModelDataPolicy): class RollingWindowDataPolicy(ModelDataPolicy):
def __init__(self, config: Dict[str, Any], *args: Any, **kwargs: Any): def __init__(self, config: Dict[str, Any], *args: Any, **kwargs: Any):
super().__init__(config) super().__init__(config, *args, **kwargs)
self.count_ = 1 self.count_ = 1
def advance(self) -> DataWindowParams: def advance(self, mkt_data_df: Optional[pd.DataFrame] = None) -> DataWindowParams:
super().advance() super().advance(mkt_data_df)
if self.is_real_time_:
self.current_data_params_.training_start_index = -self.current_data_params_.training_size
else:
self.current_data_params_.training_start_index += 1 self.current_data_params_.training_start_index += 1
return self.current_data_params_ return self.current_data_params_
class ExpandingWindowDataPolicy(ModelDataPolicy):
def __init__(self, config: Dict[str, Any], *args: Any, **kwargs: Any):
super().__init__(config)
def advance(self) -> DataWindowParams:
super().advance()
self.current_data_params_.training_size += 1
return self.current_data_params_
class OptimizedWndDataPolicy(ModelDataPolicy, ABC): class OptimizedWndDataPolicy(ModelDataPolicy, ABC):
mkt_data_df_: pd.DataFrame mkt_data_df_: pd.DataFrame
pair_: TradingPair # type: ignore pair_: TradingPair # type: ignore
@ -79,34 +81,47 @@ class OptimizedWndDataPolicy(ModelDataPolicy, ABC):
prices_b_: np.ndarray prices_b_: np.ndarray
def __init__(self, config: Dict[str, Any], *args: Any, **kwargs: Any): def __init__(self, config: Dict[str, Any], *args: Any, **kwargs: Any):
super().__init__(config) super().__init__(config, *args, **kwargs)
assert ( assert (
kwargs.get("mkt_data") is not None and kwargs.get("pair") is not None kwargs.get("pair") is not None
), "mkt_data and/or pair must be provided" ), "pair must be provided"
assert ( assert (
"min_training_size" in config and "max_training_size" in config "min_training_size" in config and "max_training_size" in config
), "min_training_size and max_training_size must be provided" ), "min_training_size and max_training_size must be provided"
self.min_training_size_ = cast(int, config.get("min_training_size")) self.min_training_size_ = cast(int, config.get("min_training_size"))
self.max_training_size_ = cast(int, config.get("max_training_size")) self.max_training_size_ = cast(int, config.get("max_training_size"))
assert self.min_training_size_ < self.max_training_size_
from pt_strategy.trading_pair import TradingPair from pt_strategy.trading_pair import TradingPair
self.mkt_data_df_ = cast(pd.DataFrame, kwargs.get("mkt_data"))
self.pair_ = cast(TradingPair, kwargs.get("pair")) self.pair_ = cast(TradingPair, kwargs.get("pair"))
self.end_index_ = ( if "mkt_data" in kwargs:
self.current_data_params_.training_start_index + self.max_training_size_ self.mkt_data_df_ = cast(pd.DataFrame, kwargs.get("mkt_data"))
) col_a, col_b = self.pair_.colnames()
self.prices_a_ = np.array(self.mkt_data_df_[col_a])
self.prices_b_ = np.array(self.mkt_data_df_[col_b])
assert self.min_training_size_ < self.max_training_size_
def advance(self, mkt_data_df: Optional[pd.DataFrame] = None) -> DataWindowParams:
super().advance(mkt_data_df)
if mkt_data_df is not None:
self.mkt_data_df_ = mkt_data_df
if self.is_real_time_:
self.end_index_ = len(self.mkt_data_df_) - 1
else:
self.end_index_ = self.current_data_params_.training_start_index + self.max_training_size_
if self.end_index_ > len(self.mkt_data_df_) - 1:
self.end_index_ = len(self.mkt_data_df_) - 1
self.current_data_params_.training_start_index = self.end_index_ - self.max_training_size_
if self.current_data_params_.training_start_index < 0:
self.current_data_params_.training_start_index = 0
col_a, col_b = self.pair_.colnames() col_a, col_b = self.pair_.colnames()
self.prices_a_ = np.array(self.mkt_data_df_[col_a]) self.prices_a_ = np.array(self.mkt_data_df_[col_a])
self.prices_b_ = np.array(self.mkt_data_df_[col_b]) self.prices_b_ = np.array(self.mkt_data_df_[col_b])
def advance(self) -> DataWindowParams:
super().advance()
self.current_data_params_ = self.optimize_window_size() self.current_data_params_ = self.optimize_window_size()
self.end_index_ += 1
return self.current_data_params_ return self.current_data_params_
@abstractmethod @abstractmethod
@ -126,6 +141,9 @@ class EGOptimizedWndDataPolicy(OptimizedWndDataPolicy):
last_pvalue = 1.0 last_pvalue = 1.0
result = copy.copy(self.current_data_params_) result = copy.copy(self.current_data_params_)
for trn_size in range(self.min_training_size_, self.max_training_size_): for trn_size in range(self.min_training_size_, self.max_training_size_):
if self.end_index_ - trn_size < 0:
break
from statsmodels.tsa.stattools import coint # type: ignore from statsmodels.tsa.stattools import coint # type: ignore
start_index = self.end_index_ - trn_size start_index = self.end_index_ - trn_size
@ -155,6 +173,8 @@ class ADFOptimizedWndDataPolicy(OptimizedWndDataPolicy):
last_pvalue = 1.0 last_pvalue = 1.0
result = copy.copy(self.current_data_params_) result = copy.copy(self.current_data_params_)
for trn_size in range(self.min_training_size_, self.max_training_size_): for trn_size in range(self.min_training_size_, self.max_training_size_):
if self.end_index_ - trn_size < 0:
break
start_index = self.end_index_ - trn_size start_index = self.end_index_ - trn_size
y = self.prices_a_[start_index : self.end_index_] y = self.prices_a_[start_index : self.end_index_]
x = self.prices_b_[start_index : self.end_index_] x = self.prices_b_[start_index : self.end_index_]
@ -201,6 +221,8 @@ class JohansenOptdWndDataPolicy(OptimizedWndDataPolicy):
result = copy.copy(self.current_data_params_) result = copy.copy(self.current_data_params_)
for trn_size in range(self.min_training_size_, self.max_training_size_): for trn_size in range(self.min_training_size_, self.max_training_size_):
if self.end_index_ - trn_size < 0:
break
start_index = self.end_index_ - trn_size start_index = self.end_index_ - trn_size
series_a = self.prices_a_[start_index:self.end_index_] series_a = self.prices_a_[start_index:self.end_index_]
series_b = self.prices_b_[start_index:self.end_index_] series_b = self.prices_b_[start_index:self.end_index_]

View File

@ -0,0 +1,28 @@
from __future__ import annotations
from typing import Any, Dict
import pandas as pd
class Prediction:
tstamp_: pd.Timestamp
disequilibrium_: float
scaled_disequilibrium_: float
def __init__(self, tstamp: pd.Timestamp, disequilibrium: float, scaled_disequilibrium: float):
self.tstamp_ = tstamp
self.disequilibrium_ = disequilibrium
self.scaled_disequilibrium_ = scaled_disequilibrium
def to_dict(self) -> Dict[str, Any]:
return {
"tstamp": self.tstamp_,
"disequilibrium": self.disequilibrium_,
"signed_scaled_disequilibrium": self.scaled_disequilibrium_,
"scaled_disequilibrium": abs(self.scaled_disequilibrium_),
# "pair": self.pair_,
}
def to_df(self) -> pd.DataFrame:
return pd.DataFrame([self.to_dict()])

View File

@ -1,14 +1,14 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Type from typing import Any, Dict, List, Type
import pandas as pd import pandas as pd
from cvttpy_base.settings.cvtt_types import JsonDictT
from tools.data_loader import load_market_data from tools.data_loader import load_market_data
from pt_strategy.trading_pair import TradingPair
class PtMarketData():
class PtMarketData(ABC):
config_: Dict[str, Any] config_: Dict[str, Any]
origin_mkt_data_df_: pd.DataFrame origin_mkt_data_df_: pd.DataFrame
market_data_df_: pd.DataFrame market_data_df_: pd.DataFrame
@ -16,27 +16,10 @@ class PtMarketData(ABC):
def __init__(self, config: Dict[str, Any]): def __init__(self, config: Dict[str, Any]):
self.config_ = config self.config_ = config
self.origin_mkt_data_df_ = pd.DataFrame() self.origin_mkt_data_df_ = pd.DataFrame()
self.market_data_df_ = pd.DataFrame()
@abstractmethod
def load(self) -> None:
...
@abstractmethod
def has_next(self) -> bool:
...
@abstractmethod
def get_next(self) -> pd.Series:
...
@staticmethod
def create(config: Dict[str, Any], md_class: Type[PtMarketData]) -> PtMarketData:
return md_class(config)
class ResearchMarketData(PtMarketData): class ResearchMarketData(PtMarketData):
config_: Dict[str, Any]
current_index_: int current_index_: int
is_execution_price_: bool is_execution_price_: bool
@ -185,3 +168,25 @@ class ResearchMarketData(PtMarketData):
f"exec_price_{self.symbol_b_}", f"exec_price_{self.symbol_b_}",
] ]
class RealTimeMarketData(PtMarketData):
def __init__(self, config: Dict[str, Any], *args: Any, **kwargs: Any):
super().__init__(config, *args, **kwargs)
async def on_mkt_data_hist_snapshot(self, snapshot: JsonDictT) -> None:
# URGENT
# create origin_mkt_data_df_ from snapshot
# transform it to market_data_df_ tstamp, close_symbolA, close_symbolB
pass
async def on_mkt_data_update(self, update: JsonDictT) -> Optional[pd.DataFrame]:
# URGENT
# make sure update has both instruments
# create DataFrame tmp1 from update
# transform tmp1 into temp. datframe tmp2
# add tmp1 to origin_mkt_data_df_
# add tmp2 to market_data_df_
# return market_data_df_
return pd.DataFrame()

View File

@ -1,39 +1,15 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from typing import Any, Dict, cast
from enum import Enum
from typing import Any, Dict, Optional, cast, Generator, List
import pandas as pd from pt_strategy.prediction import Prediction
from pt_strategy.trading_pair import TradingPair
class Prediction:
tstamp_: pd.Timestamp
disequilibrium_: float
scaled_disequilibrium_: float
def __init__(self, tstamp: pd.Timestamp, disequilibrium: float, scaled_disequilibrium: float):
self.tstamp_ = tstamp
self.disequilibrium_ = disequilibrium
self.scaled_disequilibrium_ = scaled_disequilibrium
def to_dict(self) -> Dict[str, Any]:
return {
"tstamp": self.tstamp_,
"disequilibrium": self.disequilibrium_,
"signed_scaled_disequilibrium": self.scaled_disequilibrium_,
"scaled_disequilibrium": abs(self.scaled_disequilibrium_),
# "pair": self.pair_,
}
def to_df(self) -> pd.DataFrame:
return pd.DataFrame([self.to_dict()])
class PairsTradingModel(ABC): class PairsTradingModel(ABC):
@abstractmethod @abstractmethod
def predict(self, pair: TradingPair) -> Prediction: def predict(self, pair: TradingPair) -> Prediction: # type: ignore[assignment]
... ...
@staticmethod @staticmethod

View File

@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional
import pandas as pd import pandas as pd
from pt_strategy.model_data_policy import ModelDataPolicy from pt_strategy.model_data_policy import ModelDataPolicy
from pt_strategy.pt_market_data import PtMarketData from pt_strategy.pt_market_data import ResearchMarketData
from pt_strategy.pt_model import Prediction from pt_strategy.pt_model import Prediction
from pt_strategy.trading_pair import PairState, TradingPair from pt_strategy.trading_pair import PairState, TradingPair
@ -13,7 +13,7 @@ class PtResearchStrategy:
config_: Dict[str, Any] config_: Dict[str, Any]
trading_pair_: TradingPair trading_pair_: TradingPair
model_data_policy_: ModelDataPolicy model_data_policy_: ModelDataPolicy
pt_mkt_data_: PtMarketData pt_mkt_data_: ResearchMarketData
trades_: List[pd.DataFrame] trades_: List[pd.DataFrame]
predictions_: pd.DataFrame predictions_: pd.DataFrame
@ -25,7 +25,6 @@ class PtResearchStrategy:
instruments: List[Dict[str, str]], instruments: List[Dict[str, str]],
): ):
from pt_strategy.model_data_policy import ModelDataPolicy from pt_strategy.model_data_policy import ModelDataPolicy
from pt_strategy.pt_market_data import PtMarketData, ResearchMarketData
from pt_strategy.trading_pair import TradingPair from pt_strategy.trading_pair import TradingPair
self.config_ = config self.config_ = config
@ -39,9 +38,7 @@ class PtResearchStrategy:
config_copy = copy.deepcopy(config) config_copy = copy.deepcopy(config)
config_copy["instruments"] = instruments config_copy["instruments"] = instruments
config_copy["datafiles"] = datafiles config_copy["datafiles"] = datafiles
self.pt_mkt_data_ = PtMarketData.create( self.pt_mkt_data_ = ResearchMarketData(config=config_copy)
config=config_copy, md_class=ResearchMarketData
)
self.pt_mkt_data_.load() self.pt_mkt_data_.load()
self.model_data_policy_ = ModelDataPolicy.create( self.model_data_policy_ = ModelDataPolicy.create(
config, mkt_data=self.pt_mkt_data_.market_data_df_, pair=self.trading_pair_ config, mkt_data=self.pt_mkt_data_.market_data_df_, pair=self.trading_pair_
@ -73,7 +70,7 @@ class PtResearchStrategy:
market_data_df = pd.concat([market_data_df, new_row], ignore_index=True) market_data_df = pd.concat([market_data_df, new_row], ignore_index=True)
prediction = self.trading_pair_.run( prediction = self.trading_pair_.run(
market_data_df, self.model_data_policy_.advance() market_data_df, self.model_data_policy_.advance(mkt_data_df=market_data_df)
) )
self.predictions_ = pd.concat( self.predictions_ = pd.concat(
[self.predictions_, prediction.to_df()], ignore_index=True [self.predictions_, prediction.to_df()], ignore_index=True

View File

@ -1,12 +1,13 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from typing import Any, Dict, Generator, List, Optional, Type, cast from typing import Any, Dict, List
import pandas as pd import pandas as pd
from pt_strategy.model_data_policy import DataWindowParams from pt_strategy.model_data_policy import DataWindowParams
from pt_strategy.prediction import Prediction
class PairState(Enum): class PairState(Enum):
@ -20,6 +21,7 @@ class PairState(Enum):
class TradingPair: class TradingPair:
config_: Dict[str, Any] config_: Dict[str, Any]
market_data_: pd.DataFrame market_data_: pd.DataFrame
instruments_: List[Dict[str, str]]
symbol_a_: str symbol_a_: str
symbol_b_: str symbol_b_: str
@ -34,11 +36,12 @@ class TradingPair:
instruments: List[Dict[str, str]], instruments: List[Dict[str, str]],
): ):
from pt_strategy.model_data_policy import ModelDataPolicy
from pt_strategy.pt_model import PairsTradingModel from pt_strategy.pt_model import PairsTradingModel
assert len(instruments) == 2, "Trading pair must have exactly 2 instruments" assert len(instruments) == 2, "Trading pair must have exactly 2 instruments"
self.config_ = config self.config_ = config
self.instruments_ = instruments
self.symbol_a_ = instruments[0]["symbol"] self.symbol_a_ = instruments[0]["symbol"]
self.symbol_b_ = instruments[1]["symbol"] self.symbol_b_ = instruments[1]["symbol"]
self.model_ = PairsTradingModel.create(config) self.model_ = PairsTradingModel.create(config)

View File

@ -1,4 +1,4 @@
from pt_strategy.trading_strategy import PtResearchStrategy from pt_strategy.research_strategy import PtResearchStrategy
def visualize_prices(strategy: PtResearchStrategy, trading_date: str) -> None: def visualize_prices(strategy: PtResearchStrategy, trading_date: str) -> None:

View File

@ -5,7 +5,7 @@ from typing import Any, Dict
from pt_strategy.results import (PairResearchResult, create_result_database, from pt_strategy.results import (PairResearchResult, create_result_database,
store_config_in_database) store_config_in_database)
from pt_strategy.trading_strategy import PtResearchStrategy from pt_strategy.research_strategy import PtResearchStrategy
from tools.filetools import resolve_datafiles from tools.filetools import resolve_datafiles
from tools.instruments import get_instruments from tools.instruments import get_instruments

View File

@ -16,7 +16,8 @@
"autoImportCompletions": true, "autoImportCompletions": true,
"autoSearchPaths": true, "autoSearchPaths": true,
"extraPaths": [ "extraPaths": [
"lib" "lib",
".."
], ],
"stubPath": "./typings", "stubPath": "./typings",
"venvPath": ".", "venvPath": ".",

View File

@ -8,7 +8,7 @@ from pt_strategy.results import (
create_result_database, create_result_database,
store_config_in_database, store_config_in_database,
) )
from pt_strategy.trading_strategy import PtResearchStrategy from pt_strategy.research_strategy import PtResearchStrategy
from tools.filetools import resolve_datafiles from tools.filetools import resolve_datafiles
from tools.instruments import get_instruments from tools.instruments import get_instruments

View File

@ -288,7 +288,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 13, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -302,7 +302,7 @@
"\n", "\n",
" \n", " \n",
" from pt_strategy.trading_pair import TradingPair\n", " from pt_strategy.trading_pair import TradingPair\n",
" from pt_strategy.trading_strategy import PtResearchStrategy\n", " from pt_strategy.research_strategy import PtResearchStrategy\n",
" from pt_strategy.results import PairResearchResult\n", " from pt_strategy.results import PairResearchResult\n",
"\n", "\n",
" # Create trading pair\n", " # Create trading pair\n",

View File

@ -1,221 +0,0 @@
import argparse
import asyncio
import glob
import importlib
import os
from datetime import date, datetime
from typing import Any, Dict, List, Optional
import hjson
import pandas as pd
from tools.data_loader import get_available_instruments_from_db, load_market_data
from pt_trading.results import (
BacktestResult,
create_result_database,
store_config_in_database,
store_results_in_database,
)
from pt_trading.fit_methods import PairsTradingFitMethod
from pt_trading.trading_pair import TradingPair
def run_strategy(
config: Dict,
datafile: str,
fit_method: PairsTradingFitMethod,
instruments: List[str],
) -> BacktestResult:
"""
Run backtest for all pairs using the specified instruments.
"""
bt_result: BacktestResult = BacktestResult(config=config)
def _create_pairs(config: Dict, instruments: List[str]) -> List[TradingPair]:
nonlocal datafile
all_indexes = range(len(instruments))
unique_index_pairs = [(i, j) for i in all_indexes for j in all_indexes if i < j]
pairs = []
# Update config to use the specified instruments
config_copy = config.copy()
config_copy["instruments"] = instruments
market_data_df = load_market_data(
datafile=datafile,
exchange_id=config_copy["exchange_id"],
instruments=config_copy["instruments"],
instrument_id_pfx=config_copy["instrument_id_pfx"],
db_table_name=config_copy["db_table_name"],
trading_hours=config_copy["trading_hours"],
)
for a_index, b_index in unique_index_pairs:
pair = fit_method.create_trading_pair(
market_data=market_data_df,
symbol_a=instruments[a_index],
symbol_b=instruments[b_index],
)
pairs.append(pair)
return pairs
pairs_trades = []
for pair in _create_pairs(config, instruments):
single_pair_trades = fit_method.run_pair(
pair=pair, config=config, bt_result=bt_result
)
if single_pair_trades is not None and len(single_pair_trades) > 0:
pairs_trades.append(single_pair_trades)
# Check if result_list has any data before concatenating
if len(pairs_trades) == 0:
print("No trading signals found for any pairs")
return bt_result
result = pd.concat(pairs_trades, ignore_index=True)
result["time"] = pd.to_datetime(result["time"])
result = result.set_index("time").sort_index()
bt_result.collect_single_day_results(result)
return bt_result
def main() -> None:
parser = argparse.ArgumentParser(description="Run pairs trading backtest.")
parser.add_argument(
"--config", type=str, required=True, help="Path to the configuration file."
)
parser.add_argument(
"--datafiles",
type=str,
required=False,
help="Comma-separated list of data files (overrides config). No wildcards supported.",
)
parser.add_argument(
"--instruments",
type=str,
required=False,
help="Comma-separated list of instrument symbols (e.g., COIN,GBTC). If not provided, auto-detects from database.",
)
parser.add_argument(
"--result_db",
type=str,
required=True,
help="Path to SQLite database for storing results. Use 'NONE' to disable database output.",
)
args = parser.parse_args()
config: Dict = load_config(args.config)
# Dynamically instantiate fit method class
fit_method_class_name = config.get("fit_method_class", None)
assert fit_method_class_name is not None
module_name, class_name = fit_method_class_name.rsplit(".", 1)
module = importlib.import_module(module_name)
fit_method = getattr(module, class_name)()
# Resolve data files (CLI takes priority over config)
datafiles = resolve_datafiles(config, args.datafiles)
if not datafiles:
print("No data files found to process.")
return
print(f"Found {len(datafiles)} data files to process:")
for df in datafiles:
print(f" - {df}")
# Create result database if needed
if args.result_db.upper() != "NONE":
create_result_database(args.result_db)
# Initialize a dictionary to store all trade results
all_results: Dict[str, Dict[str, Any]] = {}
# Store configuration in database for reference
if args.result_db.upper() != "NONE":
# Get list of all instruments for storage
all_instruments = []
for datafile in datafiles:
if args.instruments:
file_instruments = [
inst.strip() for inst in args.instruments.split(",")
]
else:
file_instruments = get_available_instruments_from_db(datafile, config)
all_instruments.extend(file_instruments)
# Remove duplicates while preserving order
unique_instruments = list(dict.fromkeys(all_instruments))
store_config_in_database(
db_path=args.result_db,
config_file_path=args.config,
config=config,
fit_method_class=fit_method_class_name,
datafiles=datafiles,
instruments=unique_instruments,
)
# Process each data file
for datafile in datafiles:
print(f"\n====== Processing {os.path.basename(datafile)} ======")
# Determine instruments to use
if args.instruments:
# Use CLI-specified instruments
instruments = [inst.strip() for inst in args.instruments.split(",")]
print(f"Using CLI-specified instruments: {instruments}")
else:
# Auto-detect instruments from database
instruments = get_available_instruments_from_db(datafile, config)
print(f"Auto-detected instruments: {instruments}")
if not instruments:
print(f"No instruments found for {datafile}, skipping...")
continue
# Process data for this file
try:
fit_method.reset()
bt_results = run_strategy(
config=config,
datafile=datafile,
fit_method=fit_method,
instruments=instruments,
)
# Store results with file name as key
filename = os.path.basename(datafile)
all_results[filename] = {"trades": bt_results.trades.copy()}
# Store results in database
if args.result_db.upper() != "NONE":
store_results_in_database(args.result_db, datafile, bt_results)
print(f"Successfully processed {filename}")
except Exception as err:
print(f"Error processing {datafile}: {str(err)}")
import traceback
traceback.print_exc()
# Calculate and print results using a new BacktestResult instance for aggregation
if all_results:
aggregate_bt_results = BacktestResult(config=config)
aggregate_bt_results.calculate_returns(all_results)
aggregate_bt_results.print_grand_totals()
aggregate_bt_results.print_outstanding_positions()
if args.result_db.upper() != "NONE":
print(f"\nResults stored in database: {args.result_db}")
else:
print("No results to display.")
if __name__ == "__main__":
asyncio.run(main())

View File

@ -5,7 +5,7 @@ from typing import Any, Dict
from pt_strategy.results import (PairResearchResult, create_result_database, from pt_strategy.results import (PairResearchResult, create_result_database,
store_config_in_database) store_config_in_database)
from pt_strategy.trading_strategy import PtResearchStrategy from pt_strategy.research_strategy import PtResearchStrategy
from tools.filetools import resolve_datafiles from tools.filetools import resolve_datafiles
from tools.instruments import get_instruments from tools.instruments import get_instruments
from tools.viz.viz_trades import visualize_trades from tools.viz.viz_trades import visualize_trades