From 352f7df2692daade504587055d171b2703199636 Mon Sep 17 00:00:00 2001 From: Oleg Sheynin Date: Mon, 7 Jul 2025 22:00:57 +0000 Subject: [PATCH] progress --- configuration/crypto.cfg | 2 +- configuration/equity.cfg | 2 +- pyproject.toml | 66 +++++++++ pyrightconfig.json | 24 ++++ requirements.txt | 1 + scripts/load_crypto_1min.sh | 14 +- src/cvtt/mkt_data.py | 182 +++++++++++++++++++++++++ src/notebooks/pt_pair_backtest.ipynb | 4 +- src/notebooks/pt_sliding.ipynb | 4 +- src/notebooks/pt_static.ipynb | 4 +- src/pt_backtest.py | 30 ++-- src/{ => trading}/results.py | 0 src/{ => trading}/strategies.py | 4 +- src/{tools => trading}/trading_pair.py | 0 14 files changed, 311 insertions(+), 26 deletions(-) create mode 100644 pyproject.toml create mode 100644 pyrightconfig.json create mode 100644 src/cvtt/mkt_data.py rename src/{ => trading}/results.py (100%) rename src/{ => trading}/strategies.py (99%) rename src/{tools => trading}/trading_pair.py (100%) diff --git a/configuration/crypto.cfg b/configuration/crypto.cfg index b782692..5adcead 100644 --- a/configuration/crypto.cfg +++ b/configuration/crypto.cfg @@ -29,5 +29,5 @@ "dis-equilibrium_close_trshld": 0.5, "training_minutes": 120, "funding_per_pair": 2000.0, - "strategy_class": "strategies.StaticFitStrategy" + "strategy_class": "trading.strategies.StaticFitStrategy" } \ No newline at end of file diff --git a/configuration/equity.cfg b/configuration/equity.cfg index c17476c..dffd077 100644 --- a/configuration/equity.cfg +++ b/configuration/equity.cfg @@ -20,7 +20,7 @@ "training_minutes": 120, "funding_per_pair": 2000.0, # "strategy_class": "strategies.StaticFitStrategy" - "strategy_class": "strategies.SlidingFitStrategy" + "strategy_class": "trading.strategies.SlidingFitStrategy" "exclude_instruments": ["CAN"] } \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..dfec06d --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,66 @@ +[build-system] +requires = 
["setuptools>=45", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "pairs-trading" +version = "0.1.0" +description = "Pairs Trading Backtesting Framework" +requires-python = ">=3.8" + +[tool.black] +line-length = 88 +target-version = ['py38'] +include = '\.pyi?$' +extend-exclude = ''' +/( + # directories + \.eggs + | \.git + | \.hg + | \.mypy_cache + | \.tox + | \.venv + | build + | dist +)/ +''' + +[tool.flake8] +max-line-length = 88 +extend-ignore = ["E203", "W503"] +exclude = [ + ".git", + "__pycache__", + "build", + "dist", + ".venv", + ".mypy_cache", + ".tox" +] + +[tool.mypy] +python_version = "3.8" +warn_return_any = true +warn_unused_configs = true +disallow_untyped_defs = true +disallow_incomplete_defs = true +check_untyped_defs = true +disallow_untyped_decorators = true +no_implicit_optional = true +warn_redundant_casts = true +warn_unused_ignores = true +warn_no_return = true +warn_unreachable = true +strict_equality = true + +[[tool.mypy.overrides]] +module = [ + "numpy.*", + "pandas.*", + "matplotlib.*", + "seaborn.*", + "scipy.*", + "sklearn.*" +] +ignore_missing_imports = true \ No newline at end of file diff --git a/pyrightconfig.json b/pyrightconfig.json new file mode 100644 index 0000000..d22e8a7 --- /dev/null +++ b/pyrightconfig.json @@ -0,0 +1,24 @@ +{ + "include": [ + "src" + ], + "exclude": [ + "**/node_modules", + "**/__pycache__", + "**/.*", + "results", + "data" + ], + "ignore": [], + "defineConstant": {}, + "typeCheckingMode": "basic", + "useLibraryCodeForTypes": true, + "autoImportCompletions": true, + "autoSearchPaths": true, + "extraPaths": [ + "src" + ], + "stubPath": "./typings", + "venvPath": ".", + "venv": "python3.12-venv" +} \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 3988cf6..9d5302d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,7 @@ async-timeout>=4.0.2 attrs>=21.2.0 beautifulsoup4>=4.10.0 black>=23.3.0 +flake8>=6.0.0 certifi>=2020.6.20 
chardet>=4.0.0 charset-normalizer>=3.1.0 diff --git a/scripts/load_crypto_1min.sh b/scripts/load_crypto_1min.sh index 0a29451..44ef3a2 100755 --- a/scripts/load_crypto_1min.sh +++ b/scripts/load_crypto_1min.sh @@ -12,11 +12,13 @@ # ------------------------------------- # --- Current month - all files # ------------------------------------- -cd $(realpath $(dirname $0)) +cd $(realpath $(dirname $0))/.. mkdir -p ./data/crypto pushd ./data/crypto -rsync -ahvv cvtt@hs01.cvtt.vpn:/works/cvtt/md_archive/crypto/sim/*.gz ./ +Cmd="rsync -ahvv cvtt@hs01.cvtt.vpn:/works/cvtt/md_archive/crypto/sim/*.gz ./" +echo $Cmd +eval $Cmd # ------------------------------------- for srcfname in $(ls *.db.gz); do @@ -24,8 +26,12 @@ for srcfname in $(ls *.db.gz); do tgtfile=${dt}.mktdata.ohlcv.db echo "${srcfname} -> ${tgtfile}" - gunzip -c $srcfname > temp.db - rm -f ${tgtfile} && sqlite3 temp.db ".dump md_1min_bars" | sqlite3 ${tgtfile} && rm ${srcfname} + Cmd="gunzip -c $srcfname > temp.db" + echo $Cmd + eval $Cmd + Cmd="rm -f ${tgtfile} && sqlite3 temp.db \".dump md_1min_bars\" | sqlite3 ${tgtfile} && rm ${srcfname}" + echo $Cmd + eval $Cmd done rm temp.db popd
diff --git a/src/cvtt/mkt_data.py b/src/cvtt/mkt_data.py
new file mode 100644
index 0000000..9d2c653
--- /dev/null
+++ b/src/cvtt/mkt_data.py
@@ -0,0 +1,241 @@
+#!/usr/bin/env python3
+"""Asyncio WebSocket client for CVTT aggregated market-data subscriptions."""
+
+import argparse
+import asyncio
+import json
+import logging
+import uuid
+from dataclasses import dataclass
+from typing import Callable, Coroutine, Dict, List, Optional
+
+# NOTE(review): str_len is not used anywhere in this module, and numpy.strings
+# only exists in NumPy >= 2.0. This looks like an accidental editor
+# auto-import; confirm and remove.
+from numpy.strings import str_len
+import websockets
+from websockets.asyncio.client import ClientConnection, connect
+
+SubscriptionIdT = str
+UrlT = str
+CallbackT = Callable[[Dict], Coroutine[None, str, None]]
+
+# Message types that carry price data for an already registered subscription.
+_DATA_MESSAGE_TYPES = ("md_aggregate", "historical_md_aggregate")
+
+
+@dataclass
+class CvttPricesSubscription:
+    """Parameters and state of one aggregated market-data subscription.
+
+    ``id_`` is a UUID4 string generated per instance; the client uses it to
+    match server messages to this subscription. ``is_historical_`` is set
+    when ``history_depth_sec`` is greater than zero.
+    """
+
+    id_: str
+    exchange_config_name_: str
+    instrument_id_: str
+    interval_sec_: int
+    history_depth_sec_: int
+    is_subscribed_: bool
+    is_historical_: bool
+    callback_: CallbackT
+
+    # @dataclass keeps an explicitly defined __init__, so the id and the two
+    # flags are derived here instead of being constructor arguments.
+    def __init__(
+        self,
+        exchange_config_name: str,
+        instrument_id: str,
+        interval_sec: int,
+        history_depth_sec: int,
+        callback: CallbackT,
+    ):
+        self.exchange_config_name_ = exchange_config_name
+        self.instrument_id_ = instrument_id
+        self.interval_sec_ = interval_sec
+        self.history_depth_sec_ = history_depth_sec
+        self.callback_ = callback
+        self.id_ = str(uuid.uuid4())
+        self.is_subscribed_ = False
+        self.is_historical_ = history_depth_sec > 0
+
+
+class CvttPricerWebSockClient:
+    """WebSocket client that subscribes to CVTT price aggregates.
+
+    Call :meth:`subscribe` once per subscription, then run
+    :meth:`connect_and_subscribe` to receive messages and pass them to the
+    subscriptions' callbacks.
+    """
+
+    # Class members with type hints
+    ws_url_: UrlT
+    websocket_: Optional[ClientConnection]
+    subscriptions_: Dict[SubscriptionIdT, CvttPricesSubscription]
+    is_connected_: bool
+    logger_: logging.Logger
+
+    def __init__(self, url: str):
+        self.ws_url_ = url
+        self.websocket_ = None
+        self.is_connected_ = False
+        self.subscriptions_ = {}
+        self.logger_ = logging.getLogger(__name__)
+        # NOTE(review): this configures the root logger from library code. It
+        # is a no-op when the application has already configured logging.
+        logging.basicConfig(level=logging.INFO)
+
+    async def subscribe(
+        self, subscription: CvttPricesSubscription
+    ) -> str:  # returns subscription id
+        """Register ``subscription`` with the server and return its id.
+
+        The connection is opened by the first call and reused by later calls,
+        so several subscriptions can share one socket.
+
+        Raises:
+            Exception: the server did not acknowledge the subscription; the
+                connection is closed before raising.
+        """
+        if not self.is_connected_ or self.websocket_ is None:
+            self.logger_.info(f"Connecting to {self.ws_url_}")
+            self.websocket_ = await connect(self.ws_url_)
+            self.is_connected_ = True
+        websocket = self.websocket_
+
+        subscr_msg = {
+            "type": "subscr",
+            "id": subscription.id_,
+            "subscr_type": "MD_AGGREGATE",
+            "exchange_config_name": subscription.exchange_config_name_,
+            "instrument_id": subscription.instrument_id_,
+            "interval_sec": subscription.interval_sec_,
+        }
+        if subscription.is_historical_:
+            subscr_msg["history_depth_sec"] = subscription.history_depth_sec_
+
+        await websocket.send(json.dumps(subscr_msg))
+
+        response = await websocket.recv()
+        response_data = json.loads(response)
+        # On a reused connection, data for earlier subscriptions can arrive
+        # before the acknowledgement: dispatch it and keep waiting.
+        while (
+            response_data.get("type") in _DATA_MESSAGE_TYPES
+            and response_data.get("id") in self.subscriptions_
+        ):
+            await self.process_message(response_data)
+            response = await websocket.recv()
+            response_data = json.loads(response)
+
+        if not await self.handle_subscription_response(subscription, response_data):
+            await websocket.close()
+            self.is_connected_ = False
+            raise Exception(f"Subscription failed: {str(response)}")
+
+        subscription.is_subscribed_ = True
+        self.subscriptions_[subscription.id_] = subscription
+        return subscription.id_
+
+    async def handle_subscription_response(
+        self, subscription: CvttPricesSubscription, response: dict
+    ) -> bool:
+        """Return True only for a "success" reply matching ``subscription``."""
+        if response.get("type") != "subscr" or response.get("id") != subscription.id_:
+            return False
+
+        status = response.get("status")
+        if status == "success":
+            self.logger_.info(f"Subscription successful: {json.dumps(response)}")
+            return True
+        if status == "error":
+            self.logger_.error(f"Subscription failed: {response.get('reason')}")
+        return False
+
+    async def connect_and_subscribe(self) -> None:
+        """Receive messages and dispatch them until the connection ends.
+
+        Despite the name, this method neither connects nor subscribes; it
+        reads from the socket opened by :meth:`subscribe`.
+
+        Raises:
+            RuntimeError: :meth:`subscribe` has not opened a connection yet.
+        """
+        if self.websocket_ is None:
+            # Explicit check instead of ``assert``, which is stripped under -O.
+            raise RuntimeError("subscribe() must be called before receiving")
+        try:
+            while self.is_connected_:
+                message = await self.websocket_.recv()
+                message_str = (
+                    message.decode("utf-8")
+                    if isinstance(message, bytes)
+                    else message
+                )
+                await self.process_message(json.loads(message_str))
+        except websockets.ConnectionClosed:
+            self.logger_.warning("Connection closed")
+            self.is_connected_ = False
+        except Exception as e:
+            self.logger_.error(f"Error occurred: {str(e)}")
+            self.is_connected_ = False
+            # Back off before returning; no reconnect is attempted here.
+            await asyncio.sleep(5)
+
+    async def process_message(self, message: Dict) -> None:
+        """Pass a price message to the callback of its subscription."""
+        if message.get("type") not in _DATA_MESSAGE_TYPES:
+            self.logger_.warning(f"Unknown message type: {message.get('type')}")
+            return
+
+        subscription_id = message.get("id")
+        if subscription_id not in self.subscriptions_:
+            self.logger_.warning(f"Unknown subscription id: {subscription_id}")
+            return
+
+        subscription = self.subscriptions_[subscription_id]
+        await subscription.callback_(message)
+
+
+async def main() -> None:
+    """Placeholder entry point; the CLI below is still commented out."""
+    # NOTE(review): the draft below refers to PricerClientConfig, which is not
+    # defined in this module, and passes a config object to a constructor
+    # that takes a URL. Update it before enabling.
+    # parser = argparse.ArgumentParser(description="WebSocket API Testing Tool")
+    # parser.add_argument("--url", required=True, help="WebSocket API URL")
+    # parser.add_argument(
+    #     "--exchange_config_name", required=True, help="Exchange config name"
+    # )
+    # parser.add_argument(
+    #     "--instrument_ids", required=True, help="Comma separated Instrument IDs"
+    # )
+    # parser.add_argument(
+    #     "--interval_sec", type=int, required=True, help="Interval in seconds"
+    # )
+    # parser.add_argument(
+    #     "--history_depth_sec",
+    #     default=0,
+    #     type=int,
+    #     required=False,
+    #     help="History depth in seconds",
+    # )
+
+    # args = parser.parse_args()
+
+    # config = PricerClientConfig(
+    #     url_=args.url,
+    #     exchange_config_name_=args.exchange_config_name,
+    #     instrument_ids_=args.instrument_ids.split(","),
+    #     interval_sec_=args.interval_sec,
+    #     history_depth_sec_=args.history_depth_sec,
+    # )
+
+    # client = CvttPricerWebSockClient(config)
+    # await client.connect_and_subscribe()
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/src/notebooks/pt_pair_backtest.ipynb b/src/notebooks/pt_pair_backtest.ipynb index 0185c84..f732329 100644 --- a/src/notebooks/pt_pair_backtest.ipynb +++ b/src/notebooks/pt_pair_backtest.ipynb @@ -89,8 +89,8 @@ "# Import our modules\n", "from strategies import StaticFitStrategy, SlidingFitStrategy, PairState\n", "from tools.data_loader import load_market_data\n", - "from tools.trading_pair import TradingPair\n", - "from results import BacktestResult\n", + "from trading.trading_pair import TradingPair\n", + "from trading.results import BacktestResult\n", "\n", "# Set plotting style\n", "plt.style.use('seaborn-v0_8')\n", diff --git a/src/notebooks/pt_sliding.ipynb b/src/notebooks/pt_sliding.ipynb index 89326f1..4e539a1 100644 --- a/src/notebooks/pt_sliding.ipynb +++ b/src/notebooks/pt_sliding.ipynb @@ -113,8 +113,8 @@ "# Import our modules\n", "from strategies import SlidingFitStrategy, PairState\n", "from tools.data_loader import load_market_data\n", - "from tools.trading_pair import TradingPair\n", - "from results import BacktestResult\n", + "from trading.trading_pair import TradingPair\n", + "from trading.results import BacktestResult\n", "\n", "# Set plotting style\n", "plt.style.use('seaborn-v0_8')\n", diff --git a/src/notebooks/pt_static.ipynb b/src/notebooks/pt_static.ipynb index c09961b..201152d 100644 --- a/src/notebooks/pt_static.ipynb +++ 
b/src/notebooks/pt_static.ipynb @@ -98,8 +98,8 @@ "# Import our modules\n", "from strategies import StaticFitStrategy, SlidingFitStrategy\n", "from tools.data_loader import load_market_data\n", - "from tools.trading_pair import TradingPair\n", - "from results import BacktestResult\n", + "from trading.trading_pair import TradingPair\n", + "from trading.results import BacktestResult\n", "\n", "# Set plotting style\n", "plt.style.use('seaborn-v0_8')\n", diff --git a/src/pt_backtest.py b/src/pt_backtest.py index 80adc7d..9e734c6 100644 --- a/src/pt_backtest.py +++ b/src/pt_backtest.py @@ -3,7 +3,6 @@ import hjson import importlib import glob import os -import sqlite3 from datetime import datetime, date from typing import Any, Dict, List, Optional @@ -11,15 +10,20 @@ from typing import Any, Dict, List, Optional import pandas as pd from tools.data_loader import get_available_instruments_from_db, load_market_data -from tools.trading_pair import TradingPair -from results import BacktestResult, create_result_database, store_results_in_database, store_config_in_database +from trading.strategies import PairsTradingStrategy +from trading.trading_pair import TradingPair +from trading.results import ( + BacktestResult, + create_result_database, + store_results_in_database, + store_config_in_database, +) def load_config(config_path: str) -> Dict: with open(config_path, "r") as f: config = hjson.load(f) - return config - + return dict(config) def resolve_datafiles(config: Dict, cli_datafiles: Optional[str] = None) -> List[str]: @@ -65,7 +69,7 @@ def run_backtest( config: Dict, datafile: str, price_column: str, - strategy, + strategy: PairsTradingStrategy, instruments: List[str], ) -> BacktestResult: """ @@ -167,28 +171,30 @@ def main() -> None: # Initialize a dictionary to store all trade results all_results: Dict[str, Dict[str, Any]] = {} - + # Store configuration in database for reference if args.result_db.upper() != "NONE": # Get list of all instruments for storage 
all_instruments = [] for datafile in datafiles: if args.instruments: - file_instruments = [inst.strip() for inst in args.instruments.split(",")] + file_instruments = [ + inst.strip() for inst in args.instruments.split(",") + ] else: file_instruments = get_available_instruments_from_db(datafile, config) all_instruments.extend(file_instruments) - + # Remove duplicates while preserving order unique_instruments = list(dict.fromkeys(all_instruments)) - + store_config_in_database( db_path=args.result_db, config_file_path=args.config, config=config, strategy_class=strategy_class_name, datafiles=datafiles, - instruments=unique_instruments + instruments=unique_instruments, ) # Process each data file @@ -214,7 +220,7 @@ def main() -> None: # Process data for this file try: strategy.reset() - + bt_results = run_backtest( config=config, datafile=datafile, diff --git a/src/results.py b/src/trading/results.py similarity index 100% rename from src/results.py rename to src/trading/results.py diff --git a/src/strategies.py b/src/trading/strategies.py similarity index 99% rename from src/strategies.py rename to src/trading/strategies.py index ad85526..a327fd5 100644 --- a/src/strategies.py +++ b/src/trading/strategies.py @@ -5,8 +5,8 @@ from typing import Dict, Optional, cast import pandas as pd # type: ignore[import] -from tools.trading_pair import TradingPair -from results import BacktestResult +from trading.trading_pair import TradingPair +from trading.results import BacktestResult NanoPerMin = 1e9 diff --git a/src/tools/trading_pair.py b/src/trading/trading_pair.py similarity index 100% rename from src/tools/trading_pair.py rename to src/trading/trading_pair.py