progress
commit 73f36ddcea (parent 80c3e8d54b)

.gitignore (vendored): 1 line changed
@@ -5,7 +5,6 @@ __OLD__/
.history/
.cursorindexingignore
data
.vscode/
cvttpy
# SpecStory explanation file
.specstory/.what-is-this.md
__DELETE__/.vscode/launch.json (vendored, new file): 181 lines
@@ -0,0 +1,181 @@
{
    // Use IntelliSense to learn about possible attributes.
    // Hover to view descriptions of existing attributes.
    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
    "version": "0.2.0",
    "configurations": [

        {
            "name": "Python Debugger: Current File",
            "type": "debugpy",
            "request": "launch",
            "program": "${file}",
            "console": "integratedTerminal"
        },
        {
            "name": "-------- Z-Score (OLS) --------",
        },
        {
            "name": "CRYPTO z-score",
            "type": "debugpy",
            "request": "launch",
            "python": "/home/oleg/.pyenv/python3.12-venv/bin/python",
            "program": "research/pt_backtest.py",
            "args": [
                "--config=${workspaceFolder}/configuration/zscore.cfg",
                "--instruments=ADA-USDT:CRYPTO:BNBSPOT,SOL-USDT:CRYPTO:BNBSPOT",
                "--date_pattern=20250605",
                "--result_db=${workspaceFolder}/research/results/crypto/%T.z-score.ADA-SOL.20250602.crypto_results.db",
            ],
            "env": {
                "PYTHONPATH": "${workspaceFolder}/lib"
            },
            "console": "integratedTerminal"
        },
        {
            "name": "EQUITY z-score",
            "type": "debugpy",
            "request": "launch",
            "python": "/home/oleg/.pyenv/python3.12-venv/bin/python",
            "program": "research/pt_backtest.py",
            "args": [
                "--config=${workspaceFolder}/configuration/zscore.cfg",
                "--instruments=COIN:EQUITY:ALPACA,MSTR:EQUITY:ALPACA",
                "--date_pattern=2025060*",
                "--result_db=${workspaceFolder}/research/results/equity/%T.z-score.COIN-MSTR.20250602.equity_results.db",
            ],
            "env": {
                "PYTHONPATH": "${workspaceFolder}/lib"
            },
            "console": "integratedTerminal"
        },
        {
            "name": "EQUITY-CRYPTO z-score",
            "type": "debugpy",
            "request": "launch",
            "python": "/home/oleg/.pyenv/python3.12-venv/bin/python",
            "program": "research/pt_backtest.py",
            "args": [
                "--config=${workspaceFolder}/configuration/zscore.cfg",
                "--instruments=COIN:EQUITY:ALPACA,BTC-USDT:CRYPTO:BNBSPOT",
                "--date_pattern=2025060*",
                "--result_db=${workspaceFolder}/research/results/intermarket/%T.z-score.COIN-BTC.20250601.equity_results.db",
            ],
            "env": {
                "PYTHONPATH": "${workspaceFolder}/lib"
            },
            "console": "integratedTerminal"
        },
        {
            "name": "-------- VECM --------",
        },
        {
            "name": "CRYPTO vecm",
            "type": "debugpy",
            "request": "launch",
            "python": "/home/oleg/.pyenv/python3.12-venv/bin/python",
            "program": "research/pt_backtest.py",
            "args": [
                "--config=${workspaceFolder}/configuration/vecm.cfg",
                "--instruments=ADA-USDT:CRYPTO:BNBSPOT,SOL-USDT:CRYPTO:BNBSPOT",
                "--date_pattern=2025060*",
                "--result_db=${workspaceFolder}/research/results/crypto/%T.vecm.ADA-SOL.20250602.crypto_results.db",
            ],
            "env": {
                "PYTHONPATH": "${workspaceFolder}/lib"
            },
            "console": "integratedTerminal"
        },
        {
            "name": "EQUITY vecm",
            "type": "debugpy",
            "request": "launch",
            "python": "/home/oleg/.pyenv/python3.12-venv/bin/python",
            "program": "research/pt_backtest.py",
            "args": [
                "--config=${workspaceFolder}/configuration/vecm.cfg",
                "--instruments=COIN:EQUITY:ALPACA,MSTR:EQUITY:ALPACA",
                "--date_pattern=2025060*",
                "--result_db=${workspaceFolder}/research/results/equity/%T.vecm.COIN-MSTR.20250602.equity_results.db",
            ],
            "env": {
                "PYTHONPATH": "${workspaceFolder}/lib"
            },
            "console": "integratedTerminal"
        },
        {
            "name": "EQUITY-CRYPTO vecm",
            "type": "debugpy",
            "request": "launch",
            "python": "/home/oleg/.pyenv/python3.12-venv/bin/python",
            "program": "research/pt_backtest.py",
            "args": [
                "--config=${workspaceFolder}/configuration/vecm.cfg",
                "--instruments=COIN:EQUITY:ALPACA,BTC-USDT:CRYPTO:BNBSPOT",
                "--date_pattern=2025060*",
                "--result_db=${workspaceFolder}/research/results/intermarket/%T.vecm.COIN-BTC.20250601.equity_results.db",
            ],
            "env": {
                "PYTHONPATH": "${workspaceFolder}/lib"
            },
            "console": "integratedTerminal"
        },
        {
            "name": "-------- New ZSCORE --------",
        },
        {
            "name": "New CRYPTO z-score",
            "type": "debugpy",
            "request": "launch",
            "python": "/home/oleg/.pyenv/python3.12-venv/bin/python",
            "program": "${workspaceFolder}/research/backtest_new.py",
            "args": [
                "--config=${workspaceFolder}/configuration/new_zscore.cfg",
                "--instruments=ADA-USDT:CRYPTO:BNBSPOT,SOL-USDT:CRYPTO:BNBSPOT",
                "--date_pattern=2025060*",
                "--result_db=${workspaceFolder}/research/results/crypto/%T.new_zscore.ADA-SOL.2025060-.crypto_results.db",
            ],
            "env": {
                "PYTHONPATH": "${workspaceFolder}/lib"
            },
            "console": "integratedTerminal"
        },
        {
            "name": "New CRYPTO vecm",
            "type": "debugpy",
            "request": "launch",
            "python": "/home/oleg/.pyenv/python3.12-venv/bin/python",
            "program": "${workspaceFolder}/research/backtest_new.py",
            "args": [
                "--config=${workspaceFolder}/configuration/new_vecm.cfg",
                "--instruments=ADA-USDT:CRYPTO:BNBSPOT,SOL-USDT:CRYPTO:BNBSPOT",
                "--date_pattern=20250605",
                "--result_db=${workspaceFolder}/research/results/crypto/%T.vecm.ADA-SOL.20250605.crypto_results.db",
            ],
            "env": {
                "PYTHONPATH": "${workspaceFolder}/lib"
            },
            "console": "integratedTerminal"
        },
        {
            "name": "-------- Viz Test --------",
        },
        {
            "name": "Viz Test",
            "type": "debugpy",
            "request": "launch",
            "python": "/home/oleg/.pyenv/python3.12-venv/bin/python",
            "program": "${workspaceFolder}/research/viz_test.py",
            "args": [
                "--config=${workspaceFolder}/configuration/new_zscore.cfg",
                "--instruments=ADA-USDT:CRYPTO:BNBSPOT,SOL-USDT:CRYPTO:BNBSPOT",
                "--date_pattern=20250605",
            ],
            "env": {
                "PYTHONPATH": "${workspaceFolder}/lib"
            },
            "console": "integratedTerminal"
        }
    ]
}
__DELETE__/strategy/pair_strategy.py (new file): 101 lines
@@ -0,0 +1,101 @@
import argparse
import asyncio
import glob
import importlib
import os
from datetime import date, datetime
from typing import Any, Dict, List, Optional

import hjson
import pandas as pd

from tools.data_loader import get_available_instruments_from_db, load_market_data
from pt_trading.results import (
    BacktestResult,
    create_result_database,
    store_config_in_database,
    store_results_in_database,
)
from pt_trading.fit_methods import PairsTradingFitMethod
from pt_trading.trading_pair import TradingPair


def run_strategy(
    config: Dict,
    datafile: str,
    fit_method: PairsTradingFitMethod,
    instruments: List[str],
) -> BacktestResult:
    """
    Run backtest for all pairs using the specified instruments.
    """
    bt_result: BacktestResult = BacktestResult(config=config)

    def _create_pairs(config: Dict, instruments: List[str]) -> List[TradingPair]:
        nonlocal datafile
        all_indexes = range(len(instruments))
        unique_index_pairs = [(i, j) for i in all_indexes for j in all_indexes if i < j]
        pairs = []

        # Update config to use the specified instruments
        config_copy = config.copy()
        config_copy["instruments"] = instruments

        market_data_df = load_market_data(
            datafile=datafile,
            exchange_id=config_copy["exchange_id"],
            instruments=config_copy["instruments"],
            instrument_id_pfx=config_copy["instrument_id_pfx"],
            db_table_name=config_copy["db_table_name"],
            trading_hours=config_copy["trading_hours"],
        )

        for a_index, b_index in unique_index_pairs:
            pair = fit_method.create_trading_pair(
                market_data=market_data_df,
                symbol_a=instruments[a_index],
                symbol_b=instruments[b_index],
            )
            pairs.append(pair)
        return pairs

    pairs_trades = []
    for pair in _create_pairs(config, instruments):
        single_pair_trades = fit_method.run_pair(
            pair=pair, config=config, bt_result=bt_result
        )
        if single_pair_trades is not None and len(single_pair_trades) > 0:
            pairs_trades.append(single_pair_trades)

    # Check if result_list has any data before concatenating
    if len(pairs_trades) == 0:
        print("No trading signals found for any pairs")
        return bt_result

    result = pd.concat(pairs_trades, ignore_index=True)
    result["time"] = pd.to_datetime(result["time"])
    result = result.set_index("time").sort_index()

    bt_result.collect_single_day_results(result)
    return bt_result


def main() -> None:
    # Load config
    # Subscribe to CVTT market data
    # On snapshot (with historical data) - create trading strategy with market data dataframe

    async def on_message(message_type: MessageTypeT, subscr_id: SubscriptionIdT, message: Dict, instrument_id: str) -> None:
        print(f"{message_type=} {subscr_id=} {instrument_id}")
        if message_type == "md_aggregate":
            aggr = message.get("md_aggregate", [])
            print(f"[{aggr['tstmp'][:19]}] *** RLTM *** {message}")
        elif message_type == "historical_md_aggregate":
            for aggr in message.get("historical_data", []):
                print(f"[{aggr['tstmp'][:19]}] *** HIST *** {aggr}")
        else:
            print(f"Unknown message type: {message_type}")


if __name__ == "__main__":
    asyncio.run(main())
bin/trade_pair.py (new file): 69 lines
@@ -0,0 +1,69 @@

from functools import partial
from typing import Dict

from cvtt_client.mkt_data import (CvttPricerWebSockClient,
                                  CvttPricesSubscription, MessageTypeT,
                                  SubscriptionIdT)
from cvttpy_base.tools.app import App
from cvttpy_base.tools.base import NamedObject
from pt_strategy.live_strategy import PtLiveStrategy


class PairTradingRunner(NamedObject):
    def __init__(self) -> None:
        super().__init__()
        App.instance().add_call(App.Stage.Config, self._on_config())
        App.instance().add_call(App.Stage.Run, self.run())

    async def _on_config(self) -> None:
        pass

    async def run(self) -> None:
        pass

# async def main() -> None:
#     live_strategy = PtLiveStrategy(
#         config={},
#         instruments=[
#             {"exchange_config_name": "COINBASE_AT", "instrument_id": "PAIR-BTC-USD"},
#             {"exchange_config_name": "COINBASE_AT", "instrument_id": "PAIR-ETH-USD"},
#         ]
#     )
#     async def on_message(message_type: MessageTypeT, subscr_id: SubscriptionIdT, message: Dict, instrument_id: str) -> None:
#         print(f"{message_type=} {subscr_id=} {instrument_id}")
#         if message_type == "md_aggregate":
#             aggr = message.get("md_aggregate", [])
#             print(f"[{aggr['tstmp'][:19]}] *** RLTM *** {message}")
#         elif message_type == "historical_md_aggregate":
#             for aggr in message.get("historical_data", []):
#                 print(f"[{aggr['tstmp'][:19]}] *** HIST *** {aggr}")
#         else:
#             print(f"Unknown message type: {message_type}")

#     pricer_client = CvttPricerWebSockClient(
#         "ws://localhost:12346/ws"
#     )
#     await pricer_client.subscribe(CvttPricesSubscription(
#         exchange_config_name="COINBASE_AT",
#         instrument_id="PAIR-BTC-USD",
#         interval_sec=60,
#         history_depth_sec=60*60*24,
#         callback=partial(on_message, instrument_id="PAIR-BTC-USD")
#     ))
#     await pricer_client.subscribe(CvttPricesSubscription(
#         exchange_config_name="COINBASE_AT",
#         instrument_id="PAIR-ETH-USD",
#         interval_sec=60,
#         history_depth_sec=60*60*24,
#         callback=partial(on_message, instrument_id="PAIR-ETH-USD")
#     ))

#     await pricer_client.run()


if __name__ == "__main__":
    App()

    App.instance().run()
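Note (not part of the diff): the subscription flow in bin/trade_pair.py exists only as the commented-out sketch above; the new App-based PairTradingRunner is still empty. For reference, the commented block reads roughly as follows when assembled into a standalone entry point. The websocket URL, exchange name and instrument ids are the placeholder values from the comments, and the cvtt_client API is assumed to behave exactly as shown there.

# Hedged, standalone version of the commented-out sketch in bin/trade_pair.py.
# URL, exchange_config_name and instrument ids are placeholders taken from the comments.
import asyncio
from functools import partial
from typing import Dict

from cvtt_client.mkt_data import (CvttPricerWebSockClient, CvttPricesSubscription,
                                  MessageTypeT, SubscriptionIdT)


async def main() -> None:
    async def on_message(message_type: MessageTypeT, subscr_id: SubscriptionIdT,
                         message: Dict, instrument_id: str) -> None:
        # Print real-time aggregates and historical snapshots as they arrive.
        if message_type == "md_aggregate":
            aggr = message.get("md_aggregate", {})
            print(f"[{aggr['tstmp'][:19]}] *** RLTM *** {message}")
        elif message_type == "historical_md_aggregate":
            for aggr in message.get("historical_data", []):
                print(f"[{aggr['tstmp'][:19]}] *** HIST *** {aggr}")
        else:
            print(f"Unknown message type: {message_type}")

    pricer_client = CvttPricerWebSockClient("ws://localhost:12346/ws")
    for instrument_id in ("PAIR-BTC-USD", "PAIR-ETH-USD"):
        await pricer_client.subscribe(CvttPricesSubscription(
            exchange_config_name="COINBASE_AT",
            instrument_id=instrument_id,
            interval_sec=60,                 # 1-minute aggregates
            history_depth_sec=60 * 60 * 24,  # request one day of history
            callback=partial(on_message, instrument_id=instrument_id),
        ))
    await pricer_client.run()


if __name__ == "__main__":
    asyncio.run(main())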
@@ -24,8 +24,7 @@
    "dis-equilibrium_close_trshld": 0.9,
    "model_class": "pt_strategy.models.OLSModel",

    # "training_size": 120,
    # "model_data_policy_class": "pt_strategy.model_data_policy.RollingWindowDataPolicy",
    # "model_data_policy_class": "pt_strategy.model_data_policy.EGOptimizedWndDataPolicy",
    # "model_data_policy_class": "pt_strategy.model_data_policy.ADFOptimizedWndDataPolicy",
    "model_data_policy_class": "pt_strategy.model_data_policy.JohansenOptdWndDataPolicy",
    "min_training_size": 60,
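Note (not part of the diff): model_class, model_data_policy_class and (elsewhere) fit_method_class are dotted class paths resolved at runtime. A minimal sketch of how such a string can be turned into an instance, following the importlib pattern used in the old research entry point shown further down; the usage comment below is only an illustration, not the project's factory API.

# Illustrative dotted-path resolution, mirroring the importlib pattern used in the old main().
import importlib
from typing import Any


def instantiate_from_path(dotted_path: str, *args: Any, **kwargs: Any) -> Any:
    # Split "pkg.module.ClassName" into module path and class name, import, and construct.
    module_name, class_name = dotted_path.rsplit(".", 1)
    module = importlib.import_module(module_name)
    return getattr(module, class_name)(*args, **kwargs)


# Hypothetical usage with the policy named in the config hunk above (the project itself
# goes through ModelDataPolicy.create rather than this helper):
# policy = instantiate_from_path(
#     "pt_strategy.model_data_policy.JohansenOptdWndDataPolicy", config, pair=some_pair)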
lib/pt_strategy/live_strategy.py (new file): 331 lines
@@ -0,0 +1,331 @@
from __future__ import annotations

from functools import partial
from typing import Any, Dict, List, Optional

import pandas as pd
from cvttpy_base.settings.cvtt_types import JsonDictT
from cvttpy_base.tools.base import NamedObject
from cvttpy_base.tools.logger import Log

from cvtt_client.mkt_data import CvttPricerWebSockClient, CvttPricesSubscription, MessageTypeT, SubscriptionIdT
from pt_strategy.model_data_policy import ModelDataPolicy
from pt_strategy.pt_market_data import PtMarketData, RealTimeMarketData
from pt_strategy.pt_model import Prediction
from pt_strategy.trading_pair import PairState, TradingPair

'''
--config=pair.cfg
--pair=PAIR-BTC-USDT:COINBASE_AT,PAIR-ETH-USDT:COINBASE_AT
'''


class PtMktDataClient(NamedObject):
    live_strategy_: PtLiveStrategy
    pricer_client_: CvttPricerWebSockClient
    subscriptions_: List[CvttPricesSubscription]

    def __init__(self, live_strategy: PtLiveStrategy):
        self.live_strategy_ = live_strategy

    async def start(self, subscription: CvttPricesSubscription) -> None:
        pricer_url = self.live_strategy_.config_.get("pricer_url", None)  # , "ws://localhost:12346/ws")
        assert pricer_url is not None, "pricer_url is not found in config"
        self.pricer_client_ = CvttPricerWebSockClient(url=pricer_url)

        await self._subscribe()

    async def _subscribe(self) -> None:
        pair: TradingPair = self.live_strategy_.trading_pair_
        for instrument in pair.instruments_:
            await self.pricer_client_.subscribe(CvttPricesSubscription(
                exchange_config_name=instrument["exchange_config_name"],
                instrument_id=instrument["instrument_id"],
                interval_sec=60,
                history_depth_sec=60*60*24,
                callback=partial(self.on_message, instrument_id=instrument["instrument_id"])
            ))

    async def on_message(self, message_type: MessageTypeT, subscr_id: SubscriptionIdT, message: Dict, instrument_id: str) -> None:
        Log.info(f"{self.fname()}: {message_type=} {subscr_id=} {instrument_id}")
        aggr: JsonDictT
        if message_type == "md_aggregate":
            aggr = message.get("md_aggregate", {})
            await self.live_strategy_.on_mkt_data_update(aggr)
            # print(f"[{aggr['tstmp'][:19]}] *** RLTM *** {message}")
        elif message_type == "historical_md_aggregate":
            aggr = message.get("historical_data", {})
            await self.live_strategy_.on_mkt_data_hist_snapshot(aggr)
            # print(f"[{aggr['tstmp'][:19]}] *** HIST *** {aggr}")
        else:
            Log.info(f"Unknown message type: {message_type}")

    async def run(self) -> None:
        await self.pricer_client_.run()


class PtLiveStrategy(NamedObject):
    config_: Dict[str, Any]
    trading_pair_: TradingPair
    model_data_policy_: ModelDataPolicy
    pt_mkt_data_: RealTimeMarketData
    pt_mkt_data_client_: PtMktDataClient

    # for presentation: history of prediction values and trading signals
    predictions_: pd.DataFrame
    trading_signals_: pd.DataFrame

    def __init__(
        self,
        config: Dict[str, Any],
        instruments: List[Dict[str, str]],
    ):

        self.config_ = config
        self.trading_pair_ = TradingPair(config=config, instruments=instruments)
        self.predictions_ = pd.DataFrame()
        self.trading_signals_ = pd.DataFrame()

        import copy

        # modified config must be passed to PtMarketData
        config_copy = copy.deepcopy(config)
        config_copy["instruments"] = instruments
        self.pt_mkt_data_ = RealTimeMarketData(config=config_copy)
        self.model_data_policy_ = ModelDataPolicy.create(
            config, is_real_time=True, pair=self.trading_pair_
        )

    async def on_mkt_data_hist_snapshot(self, aggr: JsonDictT) -> None:
        Log.info(f"on_mkt_data_hist_snapshot: {aggr}")
        await self.pt_mkt_data_.on_mkt_data_hist_snapshot(snapshot=aggr)
        pass

    async def on_mkt_data_update(self, aggr: JsonDictT) -> None:
        market_data_df = await self.pt_mkt_data_.on_mkt_data_update(update=aggr)
        if market_data_df is not None:
            self.trading_pair_.market_data_ = market_data_df
            self.model_data_policy_.advance()
            prediction = self.trading_pair_.run(market_data_df, self.model_data_policy_.advance())
            self.predictions_ = pd.concat([self.predictions_, prediction.to_df()], ignore_index=True)
            trades = self._create_trades(prediction=prediction, last_row=market_data_df.iloc[-1])
            # URGENT implement this
            pass

    async def run(self) -> None:
        await self.pt_mkt_data_client_.run()

    def _create_trades(
        self, prediction: Prediction, last_row: pd.Series
    ) -> Optional[pd.DataFrame]:
        pair = self.trading_pair_
        trades = None

        open_threshold = self.config_["dis-equilibrium_open_trshld"]
        close_threshold = self.config_["dis-equilibrium_close_trshld"]
        scaled_disequilibrium = prediction.scaled_disequilibrium_
        abs_scaled_disequilibrium = abs(scaled_disequilibrium)

        if pair.user_data_["state"] in [
            PairState.INITIAL,
            PairState.CLOSE,
            PairState.CLOSE_POSITION,
            PairState.CLOSE_STOP_LOSS,
            PairState.CLOSE_STOP_PROFIT,
        ]:
            if abs_scaled_disequilibrium >= open_threshold:
                trades = self._create_open_trades(
                    pair, row=last_row, prediction=prediction
                )
                if trades is not None:
                    trades["status"] = PairState.OPEN.name
                    print(f"OPEN TRADES:\n{trades}")
                    pair.user_data_["state"] = PairState.OPEN
                    pair.on_open_trades(trades)

        elif pair.user_data_["state"] == PairState.OPEN:
            if abs_scaled_disequilibrium <= close_threshold:
                trades = self._create_close_trades(
                    pair, row=last_row, prediction=prediction
                )
                if trades is not None:
                    trades["status"] = PairState.CLOSE.name
                    print(f"CLOSE TRADES:\n{trades}")
                    pair.user_data_["state"] = PairState.CLOSE
                    pair.on_close_trades(trades)
            elif pair.to_stop_close_conditions(predicted_row=last_row):
                trades = self._create_close_trades(pair, row=last_row)
                if trades is not None:
                    trades["status"] = pair.user_data_["stop_close_state"].name
                    print(f"STOP CLOSE TRADES:\n{trades}")
                    pair.user_data_["state"] = pair.user_data_["stop_close_state"]
                    pair.on_close_trades(trades)

        return trades

    def _handle_outstanding_positions(self) -> Optional[pd.DataFrame]:
        trades = None
        pair = self.trading_pair_

        # Outstanding positions
        if pair.user_data_["state"] == PairState.OPEN:
            print(f"{pair}: *** Position is NOT CLOSED. ***")
            # outstanding positions
            if self.config_["close_outstanding_positions"]:
                close_position_row = pd.Series(pair.market_data_.iloc[-2])
                # close_position_row["disequilibrium"] = 0.0
                # close_position_row["scaled_disequilibrium"] = 0.0
                # close_position_row["signed_scaled_disequilibrium"] = 0.0

                trades = self._create_close_trades(
                    pair=pair, row=close_position_row, prediction=None
                )
                if trades is not None:
                    trades["status"] = PairState.CLOSE_POSITION.name
                    print(f"CLOSE_POSITION TRADES:\n{trades}")
                    pair.user_data_["state"] = PairState.CLOSE_POSITION
                    pair.on_close_trades(trades)
            else:
                pair.add_outstanding_position(
                    symbol=pair.symbol_a_,
                    open_side=pair.user_data_["open_side_a"],
                    open_px=pair.user_data_["open_px_a"],
                    open_tstamp=pair.user_data_["open_tstamp"],
                    last_mkt_data_row=pair.market_data_.iloc[-1],
                )
                pair.add_outstanding_position(
                    symbol=pair.symbol_b_,
                    open_side=pair.user_data_["open_side_b"],
                    open_px=pair.user_data_["open_px_b"],
                    open_tstamp=pair.user_data_["open_tstamp"],
                    last_mkt_data_row=pair.market_data_.iloc[-1],
                )
        return trades

    def _trades_df(self) -> pd.DataFrame:
        types = {
            "time": "datetime64[ns]",
            "action": "string",
            "symbol": "string",
            "side": "string",
            "price": "float64",
            "disequilibrium": "float64",
            "scaled_disequilibrium": "float64",
            "signed_scaled_disequilibrium": "float64",
            # "pair": "object",
        }
        columns = list(types.keys())
        return pd.DataFrame(columns=columns).astype(types)

    def _create_open_trades(
        self, pair: TradingPair, row: pd.Series, prediction: Prediction
    ) -> Optional[pd.DataFrame]:
        colname_a, colname_b = pair.exec_prices_colnames()

        tstamp = row["tstamp"]
        diseqlbrm = prediction.disequilibrium_
        scaled_disequilibrium = prediction.scaled_disequilibrium_
        px_a = row[f"{colname_a}"]
        px_b = row[f"{colname_b}"]

        # creating the trades
        df = self._trades_df()

        print(f"OPEN_TRADES: {row["tstamp"]} {scaled_disequilibrium=}")
        if diseqlbrm > 0:
            side_a = "SELL"
            side_b = "BUY"
        else:
            side_a = "BUY"
            side_b = "SELL"

        # save closing sides
        pair.user_data_["open_side_a"] = side_a  # used in oustanding positions
        pair.user_data_["open_side_b"] = side_b
        pair.user_data_["open_px_a"] = px_a
        pair.user_data_["open_px_b"] = px_b
        pair.user_data_["open_tstamp"] = tstamp

        pair.user_data_["close_side_a"] = side_b  # used for closing trades
        pair.user_data_["close_side_b"] = side_a

        # create opening trades
        df.loc[len(df)] = {
            "time": tstamp,
            "symbol": pair.symbol_a_,
            "side": side_a,
            "action": "OPEN",
            "price": px_a,
            "disequilibrium": diseqlbrm,
            "signed_scaled_disequilibrium": scaled_disequilibrium,
            "scaled_disequilibrium": abs(scaled_disequilibrium),
            # "pair": pair,
        }
        df.loc[len(df)] = {
            "time": tstamp,
            "symbol": pair.symbol_b_,
            "side": side_b,
            "action": "OPEN",
            "price": px_b,
            "disequilibrium": diseqlbrm,
            "scaled_disequilibrium": abs(scaled_disequilibrium),
            "signed_scaled_disequilibrium": scaled_disequilibrium,
            # "pair": pair,
        }
        return df

    def _create_close_trades(
        self, pair: TradingPair, row: pd.Series, prediction: Optional[Prediction] = None
    ) -> Optional[pd.DataFrame]:
        colname_a, colname_b = pair.exec_prices_colnames()

        tstamp = row["tstamp"]
        if prediction is not None:
            diseqlbrm = prediction.disequilibrium_
            signed_scaled_disequilibrium = prediction.scaled_disequilibrium_
            scaled_disequilibrium = abs(prediction.scaled_disequilibrium_)
        else:
            diseqlbrm = 0.0
            signed_scaled_disequilibrium = 0.0
            scaled_disequilibrium = 0.0
        px_a = row[f"{colname_a}"]
        px_b = row[f"{colname_b}"]

        # creating the trades
        df = self._trades_df()

        # create opening trades
        df.loc[len(df)] = {
            "time": tstamp,
            "symbol": pair.symbol_a_,
            "side": pair.user_data_["close_side_a"],
            "action": "CLOSE",
            "price": px_a,
            "disequilibrium": diseqlbrm,
            "scaled_disequilibrium": scaled_disequilibrium,
            "signed_scaled_disequilibrium": signed_scaled_disequilibrium,
            # "pair": pair,
        }
        df.loc[len(df)] = {
            "time": tstamp,
            "symbol": pair.symbol_b_,
            "side": pair.user_data_["close_side_b"],
            "action": "CLOSE",
            "price": px_b,
            "disequilibrium": diseqlbrm,
            "scaled_disequilibrium": scaled_disequilibrium,
            "signed_scaled_disequilibrium": signed_scaled_disequilibrium,
            # "pair": pair,
        }
        del pair.user_data_["close_side_a"]
        del pair.user_data_["close_side_b"]

        del pair.user_data_["open_tstamp"]
        del pair.user_data_["open_px_a"]
        del pair.user_data_["open_px_b"]
        del pair.user_data_["open_side_a"]
        del pair.user_data_["open_side_b"]
        return df
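Note (not part of the diff): _create_trades above is a small state machine driven by the scaled disequilibrium. The following is a condensed, illustrative restatement of just the open/close threshold decision; the PairState enum here is an assumed subset of the real one, the stop-loss/stop-profit branch is ignored, and the open threshold value in the usage comment is invented (only the 0.9 close threshold appears in the config hunk earlier).

# Illustrative restatement of the threshold decision in _create_trades; not the actual implementation.
from enum import Enum, auto


class PairState(Enum):  # assumed subset of pt_strategy.trading_pair.PairState
    INITIAL = auto()
    OPEN = auto()
    CLOSE = auto()


def next_action(state: PairState, scaled_diseq: float,
                open_threshold: float, close_threshold: float) -> str:
    """Return 'OPEN', 'CLOSE' or 'HOLD' from the current state and |scaled disequilibrium|."""
    if state != PairState.OPEN and abs(scaled_diseq) >= open_threshold:
        return "OPEN"    # spread is stretched enough to enter
    if state == PairState.OPEN and abs(scaled_diseq) <= close_threshold:
        return "CLOSE"   # spread has reverted enough to exit
    return "HOLD"


# Example (open threshold 2.0 is a made-up value; close threshold 0.9 is from the config hunk):
# next_action(PairState.OPEN, 0.5, open_threshold=2.0, close_threshold=0.9) -> 'CLOSE'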
@@ -3,7 +3,7 @@ from __future__ import annotations
import copy
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, cast
from typing import Any, Dict, Optional, cast

import numpy as np
import pandas as pd
@@ -19,17 +19,26 @@ class ModelDataPolicy(ABC):
    config_: Dict[str, Any]
    current_data_params_: DataWindowParams
    count_: int
    is_real_time_: bool

    def __init__(self, config: Dict[str, Any]):
    def __init__(self, config: Dict[str, Any], *args: Any, **kwargs: Any):
        self.config_ = config
        training_size = config.get("training_size", 120)
        training_start_index = 0
        if kwargs.get("is_real_time", False):
            training_size = 120
            training_start_index = 0
        else:
            training_size = config.get("training_size", 120)
        self.current_data_params_ = DataWindowParams(
            training_size=config.get("training_size", 120),
            training_start_index=0,
        )
        self.count_ = 0
        self.is_real_time_ = kwargs.get("is_real_time", False)

    @abstractmethod
    def advance(self) -> DataWindowParams:
    def advance(self, mkt_data_df: Optional[pd.DataFrame] = None) -> DataWindowParams:
        self.count_ += 1
        print(self.count_, end="\r")
        return self.current_data_params_
@@ -50,22 +59,15 @@ class ModelDataPolicy(ABC):

class RollingWindowDataPolicy(ModelDataPolicy):
    def __init__(self, config: Dict[str, Any], *args: Any, **kwargs: Any):
        super().__init__(config)
        super().__init__(config, *args, **kwargs)
        self.count_ = 1

    def advance(self) -> DataWindowParams:
        super().advance()
        self.current_data_params_.training_start_index += 1
        return self.current_data_params_


class ExpandingWindowDataPolicy(ModelDataPolicy):
    def __init__(self, config: Dict[str, Any], *args: Any, **kwargs: Any):
        super().__init__(config)

    def advance(self) -> DataWindowParams:
        super().advance()
        self.current_data_params_.training_size += 1
    def advance(self, mkt_data_df: Optional[pd.DataFrame] = None) -> DataWindowParams:
        super().advance(mkt_data_df)
        if self.is_real_time_:
            self.current_data_params_.training_start_index = -self.current_data_params_.training_size
        else:
            self.current_data_params_.training_start_index += 1
        return self.current_data_params_


@@ -79,34 +81,47 @@ class OptimizedWndDataPolicy(ModelDataPolicy, ABC):
    prices_b_: np.ndarray

    def __init__(self, config: Dict[str, Any], *args: Any, **kwargs: Any):
        super().__init__(config)
        super().__init__(config, *args, **kwargs)
        assert (
            kwargs.get("mkt_data") is not None and kwargs.get("pair") is not None
        ), "mkt_data and/or pair must be provided"
            kwargs.get("pair") is not None
        ), "pair must be provided"
        assert (
            "min_training_size" in config and "max_training_size" in config
        ), "min_training_size and max_training_size must be provided"
        self.min_training_size_ = cast(int, config.get("min_training_size"))
        self.max_training_size_ = cast(int, config.get("max_training_size"))
        assert self.min_training_size_ < self.max_training_size_

        from pt_strategy.trading_pair import TradingPair

        self.mkt_data_df_ = cast(pd.DataFrame, kwargs.get("mkt_data"))
        self.pair_ = cast(TradingPair, kwargs.get("pair"))

        self.end_index_ = (
            self.current_data_params_.training_start_index + self.max_training_size_
        )
        if "mkt_data" in kwargs:
            self.mkt_data_df_ = cast(pd.DataFrame, kwargs.get("mkt_data"))
            col_a, col_b = self.pair_.colnames()
            self.prices_a_ = np.array(self.mkt_data_df_[col_a])
            self.prices_b_ = np.array(self.mkt_data_df_[col_b])
        assert self.min_training_size_ < self.max_training_size_

    def advance(self, mkt_data_df: Optional[pd.DataFrame] = None) -> DataWindowParams:
        super().advance(mkt_data_df)
        if mkt_data_df is not None:
            self.mkt_data_df_ = mkt_data_df

        if self.is_real_time_:
            self.end_index_ = len(self.mkt_data_df_) - 1
        else:
            self.end_index_ = self.current_data_params_.training_start_index + self.max_training_size_
            if self.end_index_ > len(self.mkt_data_df_) - 1:
                self.end_index_ = len(self.mkt_data_df_) - 1
            self.current_data_params_.training_start_index = self.end_index_ - self.max_training_size_
            if self.current_data_params_.training_start_index < 0:
                self.current_data_params_.training_start_index = 0

        col_a, col_b = self.pair_.colnames()
        self.prices_a_ = np.array(self.mkt_data_df_[col_a])
        self.prices_b_ = np.array(self.mkt_data_df_[col_b])

    def advance(self) -> DataWindowParams:
        super().advance()
        self.current_data_params_ = self.optimize_window_size()
        self.end_index_ += 1
        return self.current_data_params_

    @abstractmethod
@@ -126,6 +141,9 @@ class EGOptimizedWndDataPolicy(OptimizedWndDataPolicy):
        last_pvalue = 1.0
        result = copy.copy(self.current_data_params_)
        for trn_size in range(self.min_training_size_, self.max_training_size_):
            if self.end_index_ - trn_size < 0:
                break

            from statsmodels.tsa.stattools import coint  # type: ignore

            start_index = self.end_index_ - trn_size
@@ -155,6 +173,8 @@ class ADFOptimizedWndDataPolicy(OptimizedWndDataPolicy):
        last_pvalue = 1.0
        result = copy.copy(self.current_data_params_)
        for trn_size in range(self.min_training_size_, self.max_training_size_):
            if self.end_index_ - trn_size < 0:
                break
            start_index = self.end_index_ - trn_size
            y = self.prices_a_[start_index : self.end_index_]
            x = self.prices_b_[start_index : self.end_index_]
@@ -201,6 +221,8 @@ class JohansenOptdWndDataPolicy(OptimizedWndDataPolicy):

        result = copy.copy(self.current_data_params_)
        for trn_size in range(self.min_training_size_, self.max_training_size_):
            if self.end_index_ - trn_size < 0:
                break
            start_index = self.end_index_ - trn_size
            series_a = self.prices_a_[start_index:self.end_index_]
            series_b = self.prices_b_[start_index:self.end_index_]
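Note (not part of the diff): the three optimized window policies share one idea — scan training sizes from min_training_size up to max_training_size for the window ending at end_index_ and keep the size whose cointegration test gives the best p-value. A self-contained sketch of that search using the Engle-Granger coint test the EG policy imports; the function and variable names are illustrative, not the module's API, and the ADF and Johansen variants would swap in their own test statistic.

# Illustrative window-size search (Engle-Granger variant); not the module's actual method.
import numpy as np
from statsmodels.tsa.stattools import coint  # the same test EGOptimizedWndDataPolicy imports


def best_training_window(prices_a: np.ndarray, prices_b: np.ndarray,
                         end_index: int, min_size: int, max_size: int) -> int:
    """Return the training size in [min_size, max_size) whose window ending at end_index
    has the lowest Engle-Granger cointegration p-value."""
    best_size, best_pvalue = min_size, 1.0
    for trn_size in range(min_size, max_size):
        start_index = end_index - trn_size
        if start_index < 0:
            break  # not enough history yet
        _, pvalue, _ = coint(prices_a[start_index:end_index], prices_b[start_index:end_index])
        if pvalue < best_pvalue:
            best_size, best_pvalue = trn_size, pvalue
    return best_size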
lib/pt_strategy/prediction.py (new file): 28 lines
@@ -0,0 +1,28 @@
from __future__ import annotations

from typing import Any, Dict

import pandas as pd


class Prediction:
    tstamp_: pd.Timestamp
    disequilibrium_: float
    scaled_disequilibrium_: float

    def __init__(self, tstamp: pd.Timestamp, disequilibrium: float, scaled_disequilibrium: float):
        self.tstamp_ = tstamp
        self.disequilibrium_ = disequilibrium
        self.scaled_disequilibrium_ = scaled_disequilibrium

    def to_dict(self) -> Dict[str, Any]:
        return {
            "tstamp": self.tstamp_,
            "disequilibrium": self.disequilibrium_,
            "signed_scaled_disequilibrium": self.scaled_disequilibrium_,
            "scaled_disequilibrium": abs(self.scaled_disequilibrium_),
            # "pair": self.pair_,
        }

    def to_df(self) -> pd.DataFrame:
        return pd.DataFrame([self.to_dict()])
@@ -1,14 +1,14 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Dict, List, Type

import pandas as pd

from cvttpy_base.settings.cvtt_types import JsonDictT

from tools.data_loader import load_market_data
from pt_strategy.trading_pair import TradingPair


class PtMarketData(ABC):
class PtMarketData():
    config_: Dict[str, Any]
    origin_mkt_data_df_: pd.DataFrame
    market_data_df_: pd.DataFrame
@@ -16,27 +16,10 @@ class PtMarketData(ABC):
    def __init__(self, config: Dict[str, Any]):
        self.config_ = config
        self.origin_mkt_data_df_ = pd.DataFrame()
        self.market_data_df_ = pd.DataFrame()

    @abstractmethod
    def load(self) -> None:
        ...

    @abstractmethod
    def has_next(self) -> bool:
        ...

    @abstractmethod
    def get_next(self) -> pd.Series:
        ...

    @staticmethod
    def create(config: Dict[str, Any], md_class: Type[PtMarketData]) -> PtMarketData:
        return md_class(config)

class ResearchMarketData(PtMarketData):
    config_: Dict[str, Any]
    current_index_: int

    is_execution_price_: bool
@@ -185,3 +168,25 @@ class ResearchMarketData(PtMarketData):
            f"exec_price_{self.symbol_b_}",
        ]

class RealTimeMarketData(PtMarketData):

    def __init__(self, config: Dict[str, Any], *args: Any, **kwargs: Any):
        super().__init__(config, *args, **kwargs)

    async def on_mkt_data_hist_snapshot(self, snapshot: JsonDictT) -> None:
        # URGENT
        # create origin_mkt_data_df_ from snapshot
        # transform it to market_data_df_ tstamp, close_symbolA, close_symbolB
        pass

    async def on_mkt_data_update(self, update: JsonDictT) -> Optional[pd.DataFrame]:
        # URGENT
        # make sure update has both instruments
        # create DataFrame tmp1 from update
        # transform tmp1 into temp. dataframe tmp2
        # add tmp1 to origin_mkt_data_df_
        # add tmp2 to market_data_df_
        # return market_data_df_

        return pd.DataFrame()
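Note (not part of the diff): RealTimeMarketData.on_mkt_data_update is still a stub; only the comments describe the intended transformation. A hedged pandas sketch of that plan follows. The field names ("tstmp", "instrument_id", "close") and the wide close_<symbol> layout are assumptions taken from the comments and the column naming used elsewhere; the real method would store the two frames on self rather than return them.

# Hedged sketch of the update handling described in the URGENT comments above.
from typing import Dict, List, Optional, Tuple

import pandas as pd


def apply_update(origin_df: pd.DataFrame, market_df: pd.DataFrame,
                 update: Dict, symbols: List[str]) -> Optional[Tuple[pd.DataFrame, pd.DataFrame]]:
    rows = update.get("md_aggregate", [])
    # make sure the update carries both instruments of the pair
    seen = {r.get("instrument_id") for r in rows}
    if not set(symbols).issubset(seen):
        return None

    tmp1 = pd.DataFrame(rows)                                   # raw rows, appended to the origin frame
    tmp2 = (tmp1.pivot(index="tstmp", columns="instrument_id", values="close")
                .rename(columns=lambda s: f"close_{s}")
                .reset_index()
                .rename(columns={"tstmp": "tstamp"}))           # wide row: tstamp, close_A, close_B

    new_origin_df = pd.concat([origin_df, tmp1], ignore_index=True)
    new_market_df = pd.concat([market_df, tmp2], ignore_index=True)
    return new_origin_df, new_market_df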
@@ -1,39 +1,15 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Optional, cast, Generator, List
from typing import Any, Dict, cast

import pandas as pd
from pt_strategy.prediction import Prediction

from pt_strategy.trading_pair import TradingPair

class Prediction:
    tstamp_: pd.Timestamp
    disequilibrium_: float
    scaled_disequilibrium_: float

    def __init__(self, tstamp: pd.Timestamp, disequilibrium: float, scaled_disequilibrium: float):
        self.tstamp_ = tstamp
        self.disequilibrium_ = disequilibrium
        self.scaled_disequilibrium_ = scaled_disequilibrium

    def to_dict(self) -> Dict[str, Any]:
        return {
            "tstamp": self.tstamp_,
            "disequilibrium": self.disequilibrium_,
            "signed_scaled_disequilibrium": self.scaled_disequilibrium_,
            "scaled_disequilibrium": abs(self.scaled_disequilibrium_),
            # "pair": self.pair_,
        }
    def to_df(self) -> pd.DataFrame:
        return pd.DataFrame([self.to_dict()])

class PairsTradingModel(ABC):

    @abstractmethod
    def predict(self, pair: TradingPair) -> Prediction:
    def predict(self, pair: TradingPair) -> Prediction:  # type: ignore[assignment]
        ...

    @staticmethod
@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional

import pandas as pd
from pt_strategy.model_data_policy import ModelDataPolicy
from pt_strategy.pt_market_data import PtMarketData
from pt_strategy.pt_market_data import ResearchMarketData
from pt_strategy.pt_model import Prediction
from pt_strategy.trading_pair import PairState, TradingPair

@@ -13,7 +13,7 @@ class PtResearchStrategy:
    config_: Dict[str, Any]
    trading_pair_: TradingPair
    model_data_policy_: ModelDataPolicy
    pt_mkt_data_: PtMarketData
    pt_mkt_data_: ResearchMarketData

    trades_: List[pd.DataFrame]
    predictions_: pd.DataFrame
@@ -25,7 +25,6 @@ class PtResearchStrategy:
        instruments: List[Dict[str, str]],
    ):
        from pt_strategy.model_data_policy import ModelDataPolicy
        from pt_strategy.pt_market_data import PtMarketData, ResearchMarketData
        from pt_strategy.trading_pair import TradingPair

        self.config_ = config
@@ -39,9 +38,7 @@ class PtResearchStrategy:
        config_copy = copy.deepcopy(config)
        config_copy["instruments"] = instruments
        config_copy["datafiles"] = datafiles
        self.pt_mkt_data_ = PtMarketData.create(
            config=config_copy, md_class=ResearchMarketData
        )
        self.pt_mkt_data_ = ResearchMarketData(config=config_copy)
        self.pt_mkt_data_.load()
        self.model_data_policy_ = ModelDataPolicy.create(
            config, mkt_data=self.pt_mkt_data_.market_data_df_, pair=self.trading_pair_
@@ -73,7 +70,7 @@ class PtResearchStrategy:
        market_data_df = pd.concat([market_data_df, new_row], ignore_index=True)

        prediction = self.trading_pair_.run(
            market_data_df, self.model_data_policy_.advance()
            market_data_df, self.model_data_policy_.advance(mkt_data_df=market_data_df)
        )
        self.predictions_ = pd.concat(
            [self.predictions_, prediction.to_df()], ignore_index=True
@@ -1,12 +1,13 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from datetime import datetime
from enum import Enum
from typing import Any, Dict, Generator, List, Optional, Type, cast
from typing import Any, Dict, List

import pandas as pd

from pt_strategy.model_data_policy import DataWindowParams
from pt_strategy.prediction import Prediction


class PairState(Enum):
@@ -20,11 +21,12 @@ class PairState(Enum):
class TradingPair:
    config_: Dict[str, Any]
    market_data_: pd.DataFrame
    instruments_: List[Dict[str, str]]
    symbol_a_: str
    symbol_b_: str

    stat_model_price_: str
    model_: PairsTradingModel  # type: ignore[assignment]
    model_: PairsTradingModel  # type: ignore[assignment]

    user_data_: Dict[str, Any]

@@ -34,11 +36,12 @@ class TradingPair:
        instruments: List[Dict[str, str]],
    ):

        from pt_strategy.model_data_policy import ModelDataPolicy
        from pt_strategy.pt_model import PairsTradingModel

        assert len(instruments) == 2, "Trading pair must have exactly 2 instruments"

        self.config_ = config
        self.instruments_ = instruments
        self.symbol_a_ = instruments[0]["symbol"]
        self.symbol_b_ = instruments[1]["symbol"]
        self.model_ = PairsTradingModel.create(config)
@@ -1,4 +1,4 @@
from pt_strategy.trading_strategy import PtResearchStrategy
from pt_strategy.research_strategy import PtResearchStrategy


def visualize_prices(strategy: PtResearchStrategy, trading_date: str) -> None:
@@ -5,7 +5,7 @@ from typing import Any, Dict

from pt_strategy.results import (PairResearchResult, create_result_database,
                                 store_config_in_database)
from pt_strategy.trading_strategy import PtResearchStrategy
from pt_strategy.research_strategy import PtResearchStrategy
from tools.filetools import resolve_datafiles
from tools.instruments import get_instruments
@@ -16,7 +16,8 @@
    "autoImportCompletions": true,
    "autoSearchPaths": true,
    "extraPaths": [
        "lib"
        "lib",
        ".."
    ],
    "stubPath": "./typings",
    "venvPath": ".",
@@ -8,7 +8,7 @@ from pt_strategy.results import (
    create_result_database,
    store_config_in_database,
)
from pt_strategy.trading_strategy import PtResearchStrategy
from pt_strategy.research_strategy import PtResearchStrategy
from tools.filetools import resolve_datafiles
from tools.instruments import get_instruments
@@ -288,7 +288,7 @@
   },
   {
    "cell_type": "code",
    "execution_count": 13,
    "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -302,7 +302,7 @@
    "\n",
    " \n",
    " from pt_strategy.trading_pair import TradingPair\n",
    " from pt_strategy.trading_strategy import PtResearchStrategy\n",
    " from pt_strategy.research_strategy import PtResearchStrategy\n",
    " from pt_strategy.results import PairResearchResult\n",
    "\n",
    " # Create trading pair\n",
@@ -1,221 +0,0 @@
import argparse
import asyncio
import glob
import importlib
import os
from datetime import date, datetime
from typing import Any, Dict, List, Optional

import hjson
import pandas as pd

from tools.data_loader import get_available_instruments_from_db, load_market_data
from pt_trading.results import (
    BacktestResult,
    create_result_database,
    store_config_in_database,
    store_results_in_database,
)
from pt_trading.fit_methods import PairsTradingFitMethod
from pt_trading.trading_pair import TradingPair


def run_strategy(
    config: Dict,
    datafile: str,
    fit_method: PairsTradingFitMethod,
    instruments: List[str],
) -> BacktestResult:
    """
    Run backtest for all pairs using the specified instruments.
    """
    bt_result: BacktestResult = BacktestResult(config=config)

    def _create_pairs(config: Dict, instruments: List[str]) -> List[TradingPair]:
        nonlocal datafile
        all_indexes = range(len(instruments))
        unique_index_pairs = [(i, j) for i in all_indexes for j in all_indexes if i < j]
        pairs = []

        # Update config to use the specified instruments
        config_copy = config.copy()
        config_copy["instruments"] = instruments

        market_data_df = load_market_data(
            datafile=datafile,
            exchange_id=config_copy["exchange_id"],
            instruments=config_copy["instruments"],
            instrument_id_pfx=config_copy["instrument_id_pfx"],
            db_table_name=config_copy["db_table_name"],
            trading_hours=config_copy["trading_hours"],
        )

        for a_index, b_index in unique_index_pairs:
            pair = fit_method.create_trading_pair(
                market_data=market_data_df,
                symbol_a=instruments[a_index],
                symbol_b=instruments[b_index],
            )
            pairs.append(pair)
        return pairs

    pairs_trades = []
    for pair in _create_pairs(config, instruments):
        single_pair_trades = fit_method.run_pair(
            pair=pair, config=config, bt_result=bt_result
        )
        if single_pair_trades is not None and len(single_pair_trades) > 0:
            pairs_trades.append(single_pair_trades)

    # Check if result_list has any data before concatenating
    if len(pairs_trades) == 0:
        print("No trading signals found for any pairs")
        return bt_result

    result = pd.concat(pairs_trades, ignore_index=True)
    result["time"] = pd.to_datetime(result["time"])
    result = result.set_index("time").sort_index()

    bt_result.collect_single_day_results(result)
    return bt_result


def main() -> None:
    parser = argparse.ArgumentParser(description="Run pairs trading backtest.")
    parser.add_argument(
        "--config", type=str, required=True, help="Path to the configuration file."
    )
    parser.add_argument(
        "--datafiles",
        type=str,
        required=False,
        help="Comma-separated list of data files (overrides config). No wildcards supported.",
    )
    parser.add_argument(
        "--instruments",
        type=str,
        required=False,
        help="Comma-separated list of instrument symbols (e.g., COIN,GBTC). If not provided, auto-detects from database.",
    )
    parser.add_argument(
        "--result_db",
        type=str,
        required=True,
        help="Path to SQLite database for storing results. Use 'NONE' to disable database output.",
    )

    args = parser.parse_args()

    config: Dict = load_config(args.config)

    # Dynamically instantiate fit method class
    fit_method_class_name = config.get("fit_method_class", None)
    assert fit_method_class_name is not None
    module_name, class_name = fit_method_class_name.rsplit(".", 1)
    module = importlib.import_module(module_name)
    fit_method = getattr(module, class_name)()

    # Resolve data files (CLI takes priority over config)
    datafiles = resolve_datafiles(config, args.datafiles)

    if not datafiles:
        print("No data files found to process.")
        return

    print(f"Found {len(datafiles)} data files to process:")
    for df in datafiles:
        print(f" - {df}")

    # Create result database if needed
    if args.result_db.upper() != "NONE":
        create_result_database(args.result_db)

    # Initialize a dictionary to store all trade results
    all_results: Dict[str, Dict[str, Any]] = {}

    # Store configuration in database for reference
    if args.result_db.upper() != "NONE":
        # Get list of all instruments for storage
        all_instruments = []
        for datafile in datafiles:
            if args.instruments:
                file_instruments = [
                    inst.strip() for inst in args.instruments.split(",")
                ]
            else:
                file_instruments = get_available_instruments_from_db(datafile, config)
            all_instruments.extend(file_instruments)

        # Remove duplicates while preserving order
        unique_instruments = list(dict.fromkeys(all_instruments))

        store_config_in_database(
            db_path=args.result_db,
            config_file_path=args.config,
            config=config,
            fit_method_class=fit_method_class_name,
            datafiles=datafiles,
            instruments=unique_instruments,
        )

    # Process each data file

    for datafile in datafiles:
        print(f"\n====== Processing {os.path.basename(datafile)} ======")

        # Determine instruments to use
        if args.instruments:
            # Use CLI-specified instruments
            instruments = [inst.strip() for inst in args.instruments.split(",")]
            print(f"Using CLI-specified instruments: {instruments}")
        else:
            # Auto-detect instruments from database
            instruments = get_available_instruments_from_db(datafile, config)
            print(f"Auto-detected instruments: {instruments}")

        if not instruments:
            print(f"No instruments found for {datafile}, skipping...")
            continue

        # Process data for this file
        try:
            fit_method.reset()

            bt_results = run_strategy(
                config=config,
                datafile=datafile,
                fit_method=fit_method,
                instruments=instruments,
            )

            # Store results with file name as key
            filename = os.path.basename(datafile)
            all_results[filename] = {"trades": bt_results.trades.copy()}

            # Store results in database
            if args.result_db.upper() != "NONE":
                store_results_in_database(args.result_db, datafile, bt_results)

            print(f"Successfully processed {filename}")

        except Exception as err:
            print(f"Error processing {datafile}: {str(err)}")
            import traceback

            traceback.print_exc()

    # Calculate and print results using a new BacktestResult instance for aggregation
    if all_results:
        aggregate_bt_results = BacktestResult(config=config)
        aggregate_bt_results.calculate_returns(all_results)
        aggregate_bt_results.print_grand_totals()
        aggregate_bt_results.print_outstanding_positions()

        if args.result_db.upper() != "NONE":
            print(f"\nResults stored in database: {args.result_db}")
    else:
        print("No results to display.")


if __name__ == "__main__":
    asyncio.run(main())
@@ -5,7 +5,7 @@ from typing import Any, Dict

from pt_strategy.results import (PairResearchResult, create_result_database,
                                 store_config_in_database)
from pt_strategy.trading_strategy import PtResearchStrategy
from pt_strategy.research_strategy import PtResearchStrategy
from tools.filetools import resolve_datafiles
from tools.instruments import get_instruments
from tools.viz.viz_trades import visualize_trades