This commit is contained in:
Oleg Sheynin 2025-07-07 22:00:57 +00:00
parent 191feb341d
commit 352f7df269
14 changed files with 311 additions and 26 deletions

View File

@ -29,5 +29,5 @@
"dis-equilibrium_close_trshld": 0.5, "dis-equilibrium_close_trshld": 0.5,
"training_minutes": 120, "training_minutes": 120,
"funding_per_pair": 2000.0, "funding_per_pair": 2000.0,
"strategy_class": "strategies.StaticFitStrategy" "strategy_class": "trading.strategies.StaticFitStrategy"
} }

View File

@ -20,7 +20,7 @@
"training_minutes": 120, "training_minutes": 120,
"funding_per_pair": 2000.0, "funding_per_pair": 2000.0,
# "strategy_class": "strategies.StaticFitStrategy" # "strategy_class": "strategies.StaticFitStrategy"
"strategy_class": "strategies.SlidingFitStrategy" "strategy_class": "trading.strategies.SlidingFitStrategy"
"exclude_instruments": ["CAN"] "exclude_instruments": ["CAN"]
} }

66
pyproject.toml Normal file
View File

@ -0,0 +1,66 @@
[build-system]
requires = ["setuptools>=45", "wheel"]
build-backend = "setuptools.build_meta"
[project]
name = "pairs-trading"
version = "0.1.0"
description = "Pairs Trading Backtesting Framework"
requires-python = ">=3.8"
[tool.black]
line-length = 88
target-version = ['py38']
include = '\.pyi?$'
extend-exclude = '''
/(
# directories
\.eggs
| \.git
| \.hg
| \.mypy_cache
| \.tox
| \.venv
| build
| dist
)/
'''
[tool.flake8]
max-line-length = 88
extend-ignore = ["E203", "W503"]
exclude = [
".git",
"__pycache__",
"build",
"dist",
".venv",
".mypy_cache",
".tox"
]
[tool.mypy]
python_version = "3.8"
warn_return_any = true
warn_unused_configs = true
disallow_untyped_defs = true
disallow_incomplete_defs = true
check_untyped_defs = true
disallow_untyped_decorators = true
no_implicit_optional = true
warn_redundant_casts = true
warn_unused_ignores = true
warn_no_return = true
warn_unreachable = true
strict_equality = true
[[tool.mypy.overrides]]
module = [
"numpy.*",
"pandas.*",
"matplotlib.*",
"seaborn.*",
"scipy.*",
"sklearn.*"
]
ignore_missing_imports = true

24
pyrightconfig.json Normal file
View File

@ -0,0 +1,24 @@
{
"include": [
"src"
],
"exclude": [
"**/node_modules",
"**/__pycache__",
"**/.*",
"results",
"data"
],
"ignore": [],
"defineConstant": {},
"typeCheckingMode": "basic",
"useLibraryCodeForTypes": true,
"autoImportCompletions": true,
"autoSearchPaths": true,
"extraPaths": [
"src"
],
"stubPath": "./typings",
"venvPath": ".",
"venv": "python3.12-venv"
}

View File

@ -4,6 +4,7 @@ async-timeout>=4.0.2
attrs>=21.2.0 attrs>=21.2.0
beautifulsoup4>=4.10.0 beautifulsoup4>=4.10.0
black>=23.3.0 black>=23.3.0
flake8>=6.0.0
certifi>=2020.6.20 certifi>=2020.6.20
chardet>=4.0.0 chardet>=4.0.0
charset-normalizer>=3.1.0 charset-normalizer>=3.1.0

View File

@ -12,11 +12,13 @@
# ------------------------------------- # -------------------------------------
# --- Current month - all files # --- Current month - all files
# ------------------------------------- # -------------------------------------
cd $(realpath $(dirname $0)) cd $(realpath $(dirname $0))/..
mkdir -p ./data/crypto mkdir -p ./data/crypto
pushd ./data/crypto pushd ./data/crypto
rsync -ahvv cvtt@hs01.cvtt.vpn:/works/cvtt/md_archive/crypto/sim/*.gz ./ Cmd="rsync -ahvv cvtt@hs01.cvtt.vpn:/works/cvtt/md_archive/crypto/sim/*.gz ./"
echo $Cmd
eval $Cmd
# ------------------------------------- # -------------------------------------
for srcfname in $(ls *.db.gz); do for srcfname in $(ls *.db.gz); do
@ -24,8 +26,12 @@ for srcfname in $(ls *.db.gz); do
tgtfile=${dt}.mktdata.ohlcv.db tgtfile=${dt}.mktdata.ohlcv.db
echo "${srcfname} -> ${tgtfile}" echo "${srcfname} -> ${tgtfile}"
gunzip -c $srcfname > temp.db Cmd="gunzip -c $srcfname > temp.db"
rm -f ${tgtfile} && sqlite3 temp.db ".dump md_1min_bars" | sqlite3 ${tgtfile} && rm ${srcfname} echo $Cmd
eval $Cmd
Cmd="rm -f ${tgtfile} && sqlite3 temp.db \".dump md_1min_bars\" | sqlite3 ${tgtfile} && rm ${srcfname}"
echo $Cmd
eval $Cmd
done done
rm temp.db rm temp.db
popd popd

182
src/cvtt/mkt_data.py Normal file
View File

@ -0,0 +1,182 @@
#!/usr/bin/env python3
import argparse
import asyncio
import json
import logging
import uuid
from dataclasses import dataclass
from typing import Callable, Coroutine, Dict, List, Optional
from numpy.strings import str_len
import websockets
from websockets.asyncio.client import ClientConnection
SubscriptionIdT = str
UrlT = str
CallbackT = Callable[[Dict], Coroutine[None, str, None]]
@dataclass
class CvttPricesSubscription:
id_: str
exchange_config_name_: str
instrument_id_: str
interval_sec_: int
history_depth_sec_: int
is_subscribed_: bool
is_historical_: bool
callback_: CallbackT
def __init__(
self,
exchange_config_name: str,
instrument_id: str,
interval_sec: int,
history_depth_sec: int,
callback: CallbackT,
):
self.exchange_config_name_ = exchange_config_name
self.instrument_id_ = instrument_id
self.interval_sec_ = interval_sec
self.history_depth_sec_ = history_depth_sec
self.callback_ = callback
self.id_ = str(uuid.uuid4())
self.is_subscribed_ = False
self.is_historical_ = history_depth_sec > 0
class CvttPricerWebSockClient:
# Class members with type hints
ws_url_: UrlT
websocket_: Optional[ClientConnection]
subscriptions_: Dict[SubscriptionIdT, CvttPricesSubscription]
is_connected_: bool
logger_: logging.Logger
def __init__(self, url: str):
self.ws_url_ = url
self.websocket_ = None
self.is_connected_ = False
self.subscriptions_ = {}
self.logger_ = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
async def subscribe(
self, subscription: CvttPricesSubscription
) -> str: # returns subscription id
if not self.is_connected_:
self.logger_.info(f"Connecting to {self.ws_url_}")
self.websocket_ = await websockets.connect(self.ws_url_)
self.is_connected_ = True
else:
raise Exception(f"Unable to connect to {self.ws_url_}")
subscr_msg = {
"type": "subscr",
"id": subscription.id_,
"subscr_type": "MD_AGGREGATE",
"exchange_config_name": subscription.exchange_config_name_,
"instrument_id": subscription.instrument_id_,
"interval_sec": subscription.interval_sec_,
}
if subscription.is_historical_:
subscr_msg["history_depth_sec"] = subscription.history_depth_sec_
await self.websocket_.send(json.dumps(subscr_msg))
response = await self.websocket_.recv()
response_data = json.loads(response)
if not await self.handle_subscription_response(subscription, response_data):
await self.websocket_.close()
self.is_connected_ = False
raise Exception(f"Subscription failed: {str(response)}")
self.subscriptions_[subscription.id_] = subscription
return subscription.id_
async def handle_subscription_response(
self, subscription: CvttPricesSubscription, response: dict
) -> bool:
if response.get("type") != "subscr" or response.get("id") != subscription.id_:
return False
if response.get("status") == "success":
self.logger_.info(f"Subscription successful: {json.dumps(response)}")
return True
elif response.get("status") == "error":
self.logger_.error(f"Subscription failed: {response.get('reason')}")
return False
return False
async def connect_and_subscribe(self) -> None:
assert self.websocket_
try:
while self.is_connected_:
try:
message = await self.websocket_.recv()
message_str = (
message.decode("utf-8")
if isinstance(message, bytes)
else message
)
await self.process_message(json.loads(message_str))
except websockets.ConnectionClosed:
self.logger_.warning("Connection closed")
self.is_connected_ = False
break
except Exception as e:
self.logger_.error(f"Error occurred: {str(e)}")
self.is_connected_ = False
await asyncio.sleep(5) # Wait before reconnecting
async def process_message(self, message: Dict) -> None:
if message.get("type") in ["md_aggregate", "historical_md_aggregate"]:
subscription_id = message.get("id")
if subscription_id not in self.subscriptions_:
self.logger_.warning(f"Unknown subscription id: {subscription_id}")
return
subscription = self.subscriptions_[subscription_id]
await subscription.callback_(message)
else:
self.logger_.warning(f"Unknown message type: {message.get('type')}")
async def main() -> None:
pass
# parser = argparse.ArgumentParser(description="WebSocket API Testing Tool")
# parser.add_argument("--url", required=True, help="WebSocket API URL")
# parser.add_argument(
# "--exchange_config_name", required=True, help="Exchange config name"
# )
# parser.add_argument(
# "--instrument_ids", required=True, help="Comma separated Instrument IDs"
# )
# parser.add_argument(
# "--interval_sec", type=int, required=True, help="Interval in seconds"
# )
# parser.add_argument(
# "--history_depth_sec",
# default=0,
# type=int,
# required=False,
# help="History depth in seconds",
# )
# args = parser.parse_args()
# config = PricerClientConfig(
# url_=args.url,
# exchange_config_name_=args.exchange_config_name,
# instrument_ids_=args.instrument_ids.split(","),
# interval_sec_=args.interval_sec,
# history_depth_sec_=args.history_depth_sec,
# )
# client = CvttPricerWebSockClient(config)
# await client.connect_and_subscribe()
if __name__ == "__main__":
asyncio.run(main())

View File

@ -89,8 +89,8 @@
"# Import our modules\n", "# Import our modules\n",
"from strategies import StaticFitStrategy, SlidingFitStrategy, PairState\n", "from strategies import StaticFitStrategy, SlidingFitStrategy, PairState\n",
"from tools.data_loader import load_market_data\n", "from tools.data_loader import load_market_data\n",
"from tools.trading_pair import TradingPair\n", "from trading.trading_pair import TradingPair\n",
"from results import BacktestResult\n", "from trading.results import BacktestResult\n",
"\n", "\n",
"# Set plotting style\n", "# Set plotting style\n",
"plt.style.use('seaborn-v0_8')\n", "plt.style.use('seaborn-v0_8')\n",

View File

@ -113,8 +113,8 @@
"# Import our modules\n", "# Import our modules\n",
"from strategies import SlidingFitStrategy, PairState\n", "from strategies import SlidingFitStrategy, PairState\n",
"from tools.data_loader import load_market_data\n", "from tools.data_loader import load_market_data\n",
"from tools.trading_pair import TradingPair\n", "from trading.trading_pair import TradingPair\n",
"from results import BacktestResult\n", "from trading.results import BacktestResult\n",
"\n", "\n",
"# Set plotting style\n", "# Set plotting style\n",
"plt.style.use('seaborn-v0_8')\n", "plt.style.use('seaborn-v0_8')\n",

View File

@ -98,8 +98,8 @@
"# Import our modules\n", "# Import our modules\n",
"from strategies import StaticFitStrategy, SlidingFitStrategy\n", "from strategies import StaticFitStrategy, SlidingFitStrategy\n",
"from tools.data_loader import load_market_data\n", "from tools.data_loader import load_market_data\n",
"from tools.trading_pair import TradingPair\n", "from trading.trading_pair import TradingPair\n",
"from results import BacktestResult\n", "from trading.results import BacktestResult\n",
"\n", "\n",
"# Set plotting style\n", "# Set plotting style\n",
"plt.style.use('seaborn-v0_8')\n", "plt.style.use('seaborn-v0_8')\n",

View File

@ -3,7 +3,6 @@ import hjson
import importlib import importlib
import glob import glob
import os import os
import sqlite3
from datetime import datetime, date from datetime import datetime, date
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
@ -11,15 +10,20 @@ from typing import Any, Dict, List, Optional
import pandas as pd import pandas as pd
from tools.data_loader import get_available_instruments_from_db, load_market_data from tools.data_loader import get_available_instruments_from_db, load_market_data
from tools.trading_pair import TradingPair from trading.strategies import PairsTradingStrategy
from results import BacktestResult, create_result_database, store_results_in_database, store_config_in_database from trading.trading_pair import TradingPair
from trading.results import (
BacktestResult,
create_result_database,
store_results_in_database,
store_config_in_database,
)
def load_config(config_path: str) -> Dict: def load_config(config_path: str) -> Dict:
with open(config_path, "r") as f: with open(config_path, "r") as f:
config = hjson.load(f) config = hjson.load(f)
return config return dict(config)
def resolve_datafiles(config: Dict, cli_datafiles: Optional[str] = None) -> List[str]: def resolve_datafiles(config: Dict, cli_datafiles: Optional[str] = None) -> List[str]:
@ -65,7 +69,7 @@ def run_backtest(
config: Dict, config: Dict,
datafile: str, datafile: str,
price_column: str, price_column: str,
strategy, strategy: PairsTradingStrategy,
instruments: List[str], instruments: List[str],
) -> BacktestResult: ) -> BacktestResult:
""" """
@ -167,28 +171,30 @@ def main() -> None:
# Initialize a dictionary to store all trade results # Initialize a dictionary to store all trade results
all_results: Dict[str, Dict[str, Any]] = {} all_results: Dict[str, Dict[str, Any]] = {}
# Store configuration in database for reference # Store configuration in database for reference
if args.result_db.upper() != "NONE": if args.result_db.upper() != "NONE":
# Get list of all instruments for storage # Get list of all instruments for storage
all_instruments = [] all_instruments = []
for datafile in datafiles: for datafile in datafiles:
if args.instruments: if args.instruments:
file_instruments = [inst.strip() for inst in args.instruments.split(",")] file_instruments = [
inst.strip() for inst in args.instruments.split(",")
]
else: else:
file_instruments = get_available_instruments_from_db(datafile, config) file_instruments = get_available_instruments_from_db(datafile, config)
all_instruments.extend(file_instruments) all_instruments.extend(file_instruments)
# Remove duplicates while preserving order # Remove duplicates while preserving order
unique_instruments = list(dict.fromkeys(all_instruments)) unique_instruments = list(dict.fromkeys(all_instruments))
store_config_in_database( store_config_in_database(
db_path=args.result_db, db_path=args.result_db,
config_file_path=args.config, config_file_path=args.config,
config=config, config=config,
strategy_class=strategy_class_name, strategy_class=strategy_class_name,
datafiles=datafiles, datafiles=datafiles,
instruments=unique_instruments instruments=unique_instruments,
) )
# Process each data file # Process each data file
@ -214,7 +220,7 @@ def main() -> None:
# Process data for this file # Process data for this file
try: try:
strategy.reset() strategy.reset()
bt_results = run_backtest( bt_results = run_backtest(
config=config, config=config,
datafile=datafile, datafile=datafile,

View File

@ -5,8 +5,8 @@ from typing import Dict, Optional, cast
import pandas as pd # type: ignore[import] import pandas as pd # type: ignore[import]
from tools.trading_pair import TradingPair from trading.trading_pair import TradingPair
from results import BacktestResult from trading.results import BacktestResult
NanoPerMin = 1e9 NanoPerMin = 1e9