fixes to notebook

This commit is contained in:
Oleg Sheynin 2025-07-24 22:45:21 +00:00
parent af0a6f62a9
commit 24f1f82d1f
4 changed files with 5494 additions and 80 deletions

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from enum import Enum from enum import Enum
from typing import Dict, Optional, cast from typing import Dict, Optional, cast
@ -21,6 +23,15 @@ class PairsTradingFitMethod(ABC):
"signed_scaled_disequilibrium", "signed_scaled_disequilibrium",
"pair", "pair",
] ]
@staticmethod
def create(config: Dict) -> PairsTradingFitMethod:
import importlib
fit_method_class_name = config.get("fit_method_class", None)
assert fit_method_class_name is not None
module_name, class_name = fit_method_class_name.rsplit(".", 1)
module = importlib.import_module(module_name)
fit_method = getattr(module, class_name)()
return cast(PairsTradingFitMethod, fit_method)
@abstractmethod @abstractmethod
def run_pair( def run_pair(

File diff suppressed because one or more lines are too long

View File

@ -99,7 +99,6 @@ def run_backtest(
bt_result.collect_single_day_results(pairs_trades) bt_result.collect_single_day_results(pairs_trades)
return bt_result return bt_result
def main() -> None: def main() -> None:
parser = argparse.ArgumentParser(description="Run pairs trading backtest.") parser = argparse.ArgumentParser(description="Run pairs trading backtest.")
parser.add_argument( parser.add_argument(
@ -129,11 +128,7 @@ def main() -> None:
config: Dict = load_config(args.config) config: Dict = load_config(args.config)
# Dynamically instantiate fit method class # Dynamically instantiate fit method class
fit_method_class_name = config.get("fit_method_class", None) fit_method = PairsTradingFitMethod.create(config)
assert fit_method_class_name is not None
module_name, class_name = fit_method_class_name.rsplit(".", 1)
module = importlib.import_module(module_name)
fit_method = getattr(module, class_name)()
# Resolve data files (CLI takes priority over config) # Resolve data files (CLI takes priority over config)
instruments = get_instruments(args, config) instruments = get_instruments(args, config)
@ -166,7 +161,7 @@ def main() -> None:
db_path=args.result_db, db_path=args.result_db,
config_file_path=args.config, config_file_path=args.config,
config=config, config=config,
fit_method_class=fit_method_class_name, fit_method_class=config["fit_method_class"],
datafiles=datafiles, datafiles=datafiles,
instruments=instruments, instruments=instruments,
) )

View File

@ -1,8 +1,8 @@
import glob import glob
import os import os
from typing import Dict, List, Optional from typing import Dict, List, Optional
import pandas as pd
import pandas as pd
from pt_trading.fit_method import PairsTradingFitMethod from pt_trading.fit_method import PairsTradingFitMethod
@ -52,8 +52,8 @@ def create_pairs(
config: Dict, config: Dict,
instruments: List[Dict[str, str]], instruments: List[Dict[str, str]],
) -> List: ) -> List:
from tools.data_loader import load_market_data
from pt_trading.trading_pair import TradingPair from pt_trading.trading_pair import TradingPair
from tools.data_loader import load_market_data
all_indexes = range(len(instruments)) all_indexes = range(len(instruments))
unique_index_pairs = [(i, j) for i in all_indexes for j in all_indexes if i < j] unique_index_pairs = [(i, j) for i in all_indexes for j in all_indexes if i < j]
@ -74,8 +74,6 @@ def create_pairs(
market_data_df = pd.concat([market_data_df, md_df]) market_data_df = pd.concat([market_data_df, md_df])
for a_index, b_index in unique_index_pairs: for a_index, b_index in unique_index_pairs:
from research.pt_backtest import TradingPair
pair = fit_method.create_trading_pair( pair = fit_method.create_trading_pair(
config=config_copy, config=config_copy,
market_data=market_data_df, market_data=market_data_df,