progress
This commit is contained in:
parent
191feb341d
commit
352f7df269
@ -29,5 +29,5 @@
|
||||
"dis-equilibrium_close_trshld": 0.5,
|
||||
"training_minutes": 120,
|
||||
"funding_per_pair": 2000.0,
|
||||
"strategy_class": "strategies.StaticFitStrategy"
|
||||
"strategy_class": "trading.strategies.StaticFitStrategy"
|
||||
}
|
||||
@ -20,7 +20,7 @@
|
||||
"training_minutes": 120,
|
||||
"funding_per_pair": 2000.0,
|
||||
# "strategy_class": "strategies.StaticFitStrategy"
|
||||
"strategy_class": "strategies.SlidingFitStrategy"
|
||||
"strategy_class": "trading.strategies.SlidingFitStrategy"
|
||||
"exclude_instruments": ["CAN"]
|
||||
|
||||
}
|
||||
66
pyproject.toml
Normal file
66
pyproject.toml
Normal file
@ -0,0 +1,66 @@
|
||||
[build-system]
|
||||
requires = ["setuptools>=45", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "pairs-trading"
|
||||
version = "0.1.0"
|
||||
description = "Pairs Trading Backtesting Framework"
|
||||
requires-python = ">=3.8"
|
||||
|
||||
[tool.black]
|
||||
line-length = 88
|
||||
target-version = ['py38']
|
||||
include = '\.pyi?$'
|
||||
extend-exclude = '''
|
||||
/(
|
||||
# directories
|
||||
\.eggs
|
||||
| \.git
|
||||
| \.hg
|
||||
| \.mypy_cache
|
||||
| \.tox
|
||||
| \.venv
|
||||
| build
|
||||
| dist
|
||||
)/
|
||||
'''
|
||||
|
||||
[tool.flake8]
|
||||
max-line-length = 88
|
||||
extend-ignore = ["E203", "W503"]
|
||||
exclude = [
|
||||
".git",
|
||||
"__pycache__",
|
||||
"build",
|
||||
"dist",
|
||||
".venv",
|
||||
".mypy_cache",
|
||||
".tox"
|
||||
]
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.8"
|
||||
warn_return_any = true
|
||||
warn_unused_configs = true
|
||||
disallow_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
check_untyped_defs = true
|
||||
disallow_untyped_decorators = true
|
||||
no_implicit_optional = true
|
||||
warn_redundant_casts = true
|
||||
warn_unused_ignores = true
|
||||
warn_no_return = true
|
||||
warn_unreachable = true
|
||||
strict_equality = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = [
|
||||
"numpy.*",
|
||||
"pandas.*",
|
||||
"matplotlib.*",
|
||||
"seaborn.*",
|
||||
"scipy.*",
|
||||
"sklearn.*"
|
||||
]
|
||||
ignore_missing_imports = true
|
||||
24
pyrightconfig.json
Normal file
24
pyrightconfig.json
Normal file
@ -0,0 +1,24 @@
|
||||
{
|
||||
"include": [
|
||||
"src"
|
||||
],
|
||||
"exclude": [
|
||||
"**/node_modules",
|
||||
"**/__pycache__",
|
||||
"**/.*",
|
||||
"results",
|
||||
"data"
|
||||
],
|
||||
"ignore": [],
|
||||
"defineConstant": {},
|
||||
"typeCheckingMode": "basic",
|
||||
"useLibraryCodeForTypes": true,
|
||||
"autoImportCompletions": true,
|
||||
"autoSearchPaths": true,
|
||||
"extraPaths": [
|
||||
"src"
|
||||
],
|
||||
"stubPath": "./typings",
|
||||
"venvPath": ".",
|
||||
"venv": "python3.12-venv"
|
||||
}
|
||||
@ -4,6 +4,7 @@ async-timeout>=4.0.2
|
||||
attrs>=21.2.0
|
||||
beautifulsoup4>=4.10.0
|
||||
black>=23.3.0
|
||||
flake8>=6.0.0
|
||||
certifi>=2020.6.20
|
||||
chardet>=4.0.0
|
||||
charset-normalizer>=3.1.0
|
||||
|
||||
@ -12,11 +12,13 @@
|
||||
# -------------------------------------
|
||||
# --- Current month - all files
|
||||
# -------------------------------------
|
||||
cd $(realpath $(dirname $0))
|
||||
cd $(realpath $(dirname $0))/..
|
||||
mkdir -p ./data/crypto
|
||||
pushd ./data/crypto
|
||||
|
||||
rsync -ahvv cvtt@hs01.cvtt.vpn:/works/cvtt/md_archive/crypto/sim/*.gz ./
|
||||
Cmd="rsync -ahvv cvtt@hs01.cvtt.vpn:/works/cvtt/md_archive/crypto/sim/*.gz ./"
|
||||
echo $Cmd
|
||||
eval $Cmd
|
||||
# -------------------------------------
|
||||
|
||||
for srcfname in $(ls *.db.gz); do
|
||||
@ -24,8 +26,12 @@ for srcfname in $(ls *.db.gz); do
|
||||
tgtfile=${dt}.mktdata.ohlcv.db
|
||||
echo "${srcfname} -> ${tgtfile}"
|
||||
|
||||
gunzip -c $srcfname > temp.db
|
||||
rm -f ${tgtfile} && sqlite3 temp.db ".dump md_1min_bars" | sqlite3 ${tgtfile} && rm ${srcfname}
|
||||
Cmd="gunzip -c $srcfname > temp.db"
|
||||
echo $Cmd
|
||||
eval $Cmd
|
||||
Cmd="rm -f ${tgtfile} && sqlite3 temp.db \".dump md_1min_bars\" | sqlite3 ${tgtfile} && rm ${srcfname}"
|
||||
echo $Cmd
|
||||
eval $Cmd
|
||||
done
|
||||
rm temp.db
|
||||
popd
|
||||
|
||||
182
src/cvtt/mkt_data.py
Normal file
182
src/cvtt/mkt_data.py
Normal file
@ -0,0 +1,182 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Coroutine, Dict, List, Optional
|
||||
|
||||
from numpy.strings import str_len
|
||||
import websockets
|
||||
from websockets.asyncio.client import ClientConnection
|
||||
|
||||
SubscriptionIdT = str
|
||||
UrlT = str
|
||||
CallbackT = Callable[[Dict], Coroutine[None, str, None]]
|
||||
|
||||
@dataclass
|
||||
class CvttPricesSubscription:
|
||||
id_: str
|
||||
exchange_config_name_: str
|
||||
instrument_id_: str
|
||||
interval_sec_: int
|
||||
history_depth_sec_: int
|
||||
is_subscribed_: bool
|
||||
is_historical_: bool
|
||||
callback_: CallbackT
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
exchange_config_name: str,
|
||||
instrument_id: str,
|
||||
interval_sec: int,
|
||||
history_depth_sec: int,
|
||||
callback: CallbackT,
|
||||
):
|
||||
self.exchange_config_name_ = exchange_config_name
|
||||
self.instrument_id_ = instrument_id
|
||||
self.interval_sec_ = interval_sec
|
||||
self.history_depth_sec_ = history_depth_sec
|
||||
self.callback_ = callback
|
||||
self.id_ = str(uuid.uuid4())
|
||||
self.is_subscribed_ = False
|
||||
self.is_historical_ = history_depth_sec > 0
|
||||
|
||||
|
||||
class CvttPricerWebSockClient:
|
||||
# Class members with type hints
|
||||
ws_url_: UrlT
|
||||
websocket_: Optional[ClientConnection]
|
||||
subscriptions_: Dict[SubscriptionIdT, CvttPricesSubscription]
|
||||
is_connected_: bool
|
||||
logger_: logging.Logger
|
||||
|
||||
def __init__(self, url: str):
|
||||
self.ws_url_ = url
|
||||
self.websocket_ = None
|
||||
self.is_connected_ = False
|
||||
self.subscriptions_ = {}
|
||||
self.logger_ = logging.getLogger(__name__)
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
async def subscribe(
|
||||
self, subscription: CvttPricesSubscription
|
||||
) -> str: # returns subscription id
|
||||
|
||||
if not self.is_connected_:
|
||||
self.logger_.info(f"Connecting to {self.ws_url_}")
|
||||
self.websocket_ = await websockets.connect(self.ws_url_)
|
||||
self.is_connected_ = True
|
||||
else:
|
||||
raise Exception(f"Unable to connect to {self.ws_url_}")
|
||||
|
||||
subscr_msg = {
|
||||
"type": "subscr",
|
||||
"id": subscription.id_,
|
||||
"subscr_type": "MD_AGGREGATE",
|
||||
"exchange_config_name": subscription.exchange_config_name_,
|
||||
"instrument_id": subscription.instrument_id_,
|
||||
"interval_sec": subscription.interval_sec_,
|
||||
}
|
||||
if subscription.is_historical_:
|
||||
subscr_msg["history_depth_sec"] = subscription.history_depth_sec_
|
||||
|
||||
await self.websocket_.send(json.dumps(subscr_msg))
|
||||
|
||||
response = await self.websocket_.recv()
|
||||
response_data = json.loads(response)
|
||||
if not await self.handle_subscription_response(subscription, response_data):
|
||||
await self.websocket_.close()
|
||||
self.is_connected_ = False
|
||||
raise Exception(f"Subscription failed: {str(response)}")
|
||||
|
||||
self.subscriptions_[subscription.id_] = subscription
|
||||
return subscription.id_
|
||||
|
||||
async def handle_subscription_response(
|
||||
self, subscription: CvttPricesSubscription, response: dict
|
||||
) -> bool:
|
||||
if response.get("type") != "subscr" or response.get("id") != subscription.id_:
|
||||
return False
|
||||
|
||||
if response.get("status") == "success":
|
||||
self.logger_.info(f"Subscription successful: {json.dumps(response)}")
|
||||
return True
|
||||
elif response.get("status") == "error":
|
||||
self.logger_.error(f"Subscription failed: {response.get('reason')}")
|
||||
return False
|
||||
return False
|
||||
|
||||
async def connect_and_subscribe(self) -> None:
|
||||
assert self.websocket_
|
||||
try:
|
||||
while self.is_connected_:
|
||||
try:
|
||||
message = await self.websocket_.recv()
|
||||
message_str = (
|
||||
message.decode("utf-8")
|
||||
if isinstance(message, bytes)
|
||||
else message
|
||||
)
|
||||
await self.process_message(json.loads(message_str))
|
||||
except websockets.ConnectionClosed:
|
||||
self.logger_.warning("Connection closed")
|
||||
self.is_connected_ = False
|
||||
break
|
||||
except Exception as e:
|
||||
self.logger_.error(f"Error occurred: {str(e)}")
|
||||
self.is_connected_ = False
|
||||
await asyncio.sleep(5) # Wait before reconnecting
|
||||
|
||||
async def process_message(self, message: Dict) -> None:
|
||||
if message.get("type") in ["md_aggregate", "historical_md_aggregate"]:
|
||||
subscription_id = message.get("id")
|
||||
if subscription_id not in self.subscriptions_:
|
||||
self.logger_.warning(f"Unknown subscription id: {subscription_id}")
|
||||
return
|
||||
|
||||
subscription = self.subscriptions_[subscription_id]
|
||||
await subscription.callback_(message)
|
||||
else:
|
||||
self.logger_.warning(f"Unknown message type: {message.get('type')}")
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
pass
|
||||
# parser = argparse.ArgumentParser(description="WebSocket API Testing Tool")
|
||||
# parser.add_argument("--url", required=True, help="WebSocket API URL")
|
||||
# parser.add_argument(
|
||||
# "--exchange_config_name", required=True, help="Exchange config name"
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--instrument_ids", required=True, help="Comma separated Instrument IDs"
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--interval_sec", type=int, required=True, help="Interval in seconds"
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--history_depth_sec",
|
||||
# default=0,
|
||||
# type=int,
|
||||
# required=False,
|
||||
# help="History depth in seconds",
|
||||
# )
|
||||
|
||||
# args = parser.parse_args()
|
||||
|
||||
# config = PricerClientConfig(
|
||||
# url_=args.url,
|
||||
# exchange_config_name_=args.exchange_config_name,
|
||||
# instrument_ids_=args.instrument_ids.split(","),
|
||||
# interval_sec_=args.interval_sec,
|
||||
# history_depth_sec_=args.history_depth_sec,
|
||||
# )
|
||||
|
||||
# client = CvttPricerWebSockClient(config)
|
||||
# await client.connect_and_subscribe()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@ -89,8 +89,8 @@
|
||||
"# Import our modules\n",
|
||||
"from strategies import StaticFitStrategy, SlidingFitStrategy, PairState\n",
|
||||
"from tools.data_loader import load_market_data\n",
|
||||
"from tools.trading_pair import TradingPair\n",
|
||||
"from results import BacktestResult\n",
|
||||
"from trading.trading_pair import TradingPair\n",
|
||||
"from trading.results import BacktestResult\n",
|
||||
"\n",
|
||||
"# Set plotting style\n",
|
||||
"plt.style.use('seaborn-v0_8')\n",
|
||||
|
||||
@ -113,8 +113,8 @@
|
||||
"# Import our modules\n",
|
||||
"from strategies import SlidingFitStrategy, PairState\n",
|
||||
"from tools.data_loader import load_market_data\n",
|
||||
"from tools.trading_pair import TradingPair\n",
|
||||
"from results import BacktestResult\n",
|
||||
"from trading.trading_pair import TradingPair\n",
|
||||
"from trading.results import BacktestResult\n",
|
||||
"\n",
|
||||
"# Set plotting style\n",
|
||||
"plt.style.use('seaborn-v0_8')\n",
|
||||
|
||||
@ -98,8 +98,8 @@
|
||||
"# Import our modules\n",
|
||||
"from strategies import StaticFitStrategy, SlidingFitStrategy\n",
|
||||
"from tools.data_loader import load_market_data\n",
|
||||
"from tools.trading_pair import TradingPair\n",
|
||||
"from results import BacktestResult\n",
|
||||
"from trading.trading_pair import TradingPair\n",
|
||||
"from trading.results import BacktestResult\n",
|
||||
"\n",
|
||||
"# Set plotting style\n",
|
||||
"plt.style.use('seaborn-v0_8')\n",
|
||||
|
||||
@ -3,7 +3,6 @@ import hjson
|
||||
import importlib
|
||||
import glob
|
||||
import os
|
||||
import sqlite3
|
||||
from datetime import datetime, date
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
@ -11,15 +10,20 @@ from typing import Any, Dict, List, Optional
|
||||
import pandas as pd
|
||||
|
||||
from tools.data_loader import get_available_instruments_from_db, load_market_data
|
||||
from tools.trading_pair import TradingPair
|
||||
from results import BacktestResult, create_result_database, store_results_in_database, store_config_in_database
|
||||
from trading.strategies import PairsTradingStrategy
|
||||
from trading.trading_pair import TradingPair
|
||||
from trading.results import (
|
||||
BacktestResult,
|
||||
create_result_database,
|
||||
store_results_in_database,
|
||||
store_config_in_database,
|
||||
)
|
||||
|
||||
|
||||
def load_config(config_path: str) -> Dict:
|
||||
with open(config_path, "r") as f:
|
||||
config = hjson.load(f)
|
||||
return config
|
||||
|
||||
return dict(config)
|
||||
|
||||
|
||||
def resolve_datafiles(config: Dict, cli_datafiles: Optional[str] = None) -> List[str]:
|
||||
@ -65,7 +69,7 @@ def run_backtest(
|
||||
config: Dict,
|
||||
datafile: str,
|
||||
price_column: str,
|
||||
strategy,
|
||||
strategy: PairsTradingStrategy,
|
||||
instruments: List[str],
|
||||
) -> BacktestResult:
|
||||
"""
|
||||
@ -167,28 +171,30 @@ def main() -> None:
|
||||
|
||||
# Initialize a dictionary to store all trade results
|
||||
all_results: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
|
||||
# Store configuration in database for reference
|
||||
if args.result_db.upper() != "NONE":
|
||||
# Get list of all instruments for storage
|
||||
all_instruments = []
|
||||
for datafile in datafiles:
|
||||
if args.instruments:
|
||||
file_instruments = [inst.strip() for inst in args.instruments.split(",")]
|
||||
file_instruments = [
|
||||
inst.strip() for inst in args.instruments.split(",")
|
||||
]
|
||||
else:
|
||||
file_instruments = get_available_instruments_from_db(datafile, config)
|
||||
all_instruments.extend(file_instruments)
|
||||
|
||||
|
||||
# Remove duplicates while preserving order
|
||||
unique_instruments = list(dict.fromkeys(all_instruments))
|
||||
|
||||
|
||||
store_config_in_database(
|
||||
db_path=args.result_db,
|
||||
config_file_path=args.config,
|
||||
config=config,
|
||||
strategy_class=strategy_class_name,
|
||||
datafiles=datafiles,
|
||||
instruments=unique_instruments
|
||||
instruments=unique_instruments,
|
||||
)
|
||||
|
||||
# Process each data file
|
||||
@ -214,7 +220,7 @@ def main() -> None:
|
||||
# Process data for this file
|
||||
try:
|
||||
strategy.reset()
|
||||
|
||||
|
||||
bt_results = run_backtest(
|
||||
config=config,
|
||||
datafile=datafile,
|
||||
|
||||
@ -5,8 +5,8 @@ from typing import Dict, Optional, cast
|
||||
|
||||
import pandas as pd # type: ignore[import]
|
||||
|
||||
from tools.trading_pair import TradingPair
|
||||
from results import BacktestResult
|
||||
from trading.trading_pair import TradingPair
|
||||
from trading.results import BacktestResult
|
||||
|
||||
NanoPerMin = 1e9
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user