fixes to notebook

This commit is contained in:
Oleg Sheynin 2025-07-24 22:45:21 +00:00
parent af0a6f62a9
commit 24f1f82d1f
4 changed files with 5494 additions and 80 deletions

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from enum import Enum
from typing import Dict, Optional, cast
@ -21,6 +23,15 @@ class PairsTradingFitMethod(ABC):
"signed_scaled_disequilibrium",
"pair",
]
@staticmethod
def create(config: Dict) -> PairsTradingFitMethod:
import importlib
fit_method_class_name = config.get("fit_method_class", None)
assert fit_method_class_name is not None
module_name, class_name = fit_method_class_name.rsplit(".", 1)
module = importlib.import_module(module_name)
fit_method = getattr(module, class_name)()
return cast(PairsTradingFitMethod, fit_method)
@abstractmethod
def run_pair(

File diff suppressed because one or more lines are too long

View File

@ -99,7 +99,6 @@ def run_backtest(
bt_result.collect_single_day_results(pairs_trades)
return bt_result
def main() -> None:
parser = argparse.ArgumentParser(description="Run pairs trading backtest.")
parser.add_argument(
@ -129,11 +128,7 @@ def main() -> None:
config: Dict = load_config(args.config)
# Dynamically instantiate fit method class
fit_method_class_name = config.get("fit_method_class", None)
assert fit_method_class_name is not None
module_name, class_name = fit_method_class_name.rsplit(".", 1)
module = importlib.import_module(module_name)
fit_method = getattr(module, class_name)()
fit_method = PairsTradingFitMethod.create(config)
# Resolve data files (CLI takes priority over config)
instruments = get_instruments(args, config)
@ -166,7 +161,7 @@ def main() -> None:
db_path=args.result_db,
config_file_path=args.config,
config=config,
fit_method_class=fit_method_class_name,
fit_method_class=config["fit_method_class"],
datafiles=datafiles,
instruments=instruments,
)

View File

@ -1,8 +1,8 @@
import glob
import os
from typing import Dict, List, Optional
import pandas as pd
import pandas as pd
from pt_trading.fit_method import PairsTradingFitMethod
@ -52,8 +52,8 @@ def create_pairs(
config: Dict,
instruments: List[Dict[str, str]],
) -> List:
from tools.data_loader import load_market_data
from pt_trading.trading_pair import TradingPair
from tools.data_loader import load_market_data
all_indexes = range(len(instruments))
unique_index_pairs = [(i, j) for i in all_indexes for j in all_indexes if i < j]
@ -74,8 +74,6 @@ def create_pairs(
market_data_df = pd.concat([market_data_df, md_df])
for a_index, b_index in unique_index_pairs:
from research.pt_backtest import TradingPair
pair = fit_method.create_trading_pair(
config=config_copy,
market_data=market_data_df,