fixes to notebook
This commit is contained in:
parent
af0a6f62a9
commit
24f1f82d1f
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Dict, Optional, cast
|
||||
@ -21,6 +23,15 @@ class PairsTradingFitMethod(ABC):
|
||||
"signed_scaled_disequilibrium",
|
||||
"pair",
|
||||
]
|
||||
@staticmethod
|
||||
def create(config: Dict) -> PairsTradingFitMethod:
|
||||
import importlib
|
||||
fit_method_class_name = config.get("fit_method_class", None)
|
||||
assert fit_method_class_name is not None
|
||||
module_name, class_name = fit_method_class_name.rsplit(".", 1)
|
||||
module = importlib.import_module(module_name)
|
||||
fit_method = getattr(module, class_name)()
|
||||
return cast(PairsTradingFitMethod, fit_method)
|
||||
|
||||
@abstractmethod
|
||||
def run_pair(
|
||||
|
||||
File diff suppressed because one or more lines are too long
@ -99,7 +99,6 @@ def run_backtest(
|
||||
bt_result.collect_single_day_results(pairs_trades)
|
||||
return bt_result
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Run pairs trading backtest.")
|
||||
parser.add_argument(
|
||||
@ -129,11 +128,7 @@ def main() -> None:
|
||||
config: Dict = load_config(args.config)
|
||||
|
||||
# Dynamically instantiate fit method class
|
||||
fit_method_class_name = config.get("fit_method_class", None)
|
||||
assert fit_method_class_name is not None
|
||||
module_name, class_name = fit_method_class_name.rsplit(".", 1)
|
||||
module = importlib.import_module(module_name)
|
||||
fit_method = getattr(module, class_name)()
|
||||
fit_method = PairsTradingFitMethod.create(config)
|
||||
|
||||
# Resolve data files (CLI takes priority over config)
|
||||
instruments = get_instruments(args, config)
|
||||
@ -166,7 +161,7 @@ def main() -> None:
|
||||
db_path=args.result_db,
|
||||
config_file_path=args.config,
|
||||
config=config,
|
||||
fit_method_class=fit_method_class_name,
|
||||
fit_method_class=config["fit_method_class"],
|
||||
datafiles=datafiles,
|
||||
instruments=instruments,
|
||||
)
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
import glob
|
||||
import os
|
||||
from typing import Dict, List, Optional
|
||||
import pandas as pd
|
||||
|
||||
import pandas as pd
|
||||
from pt_trading.fit_method import PairsTradingFitMethod
|
||||
|
||||
|
||||
@ -52,8 +52,8 @@ def create_pairs(
|
||||
config: Dict,
|
||||
instruments: List[Dict[str, str]],
|
||||
) -> List:
|
||||
from tools.data_loader import load_market_data
|
||||
from pt_trading.trading_pair import TradingPair
|
||||
from tools.data_loader import load_market_data
|
||||
|
||||
all_indexes = range(len(instruments))
|
||||
unique_index_pairs = [(i, j) for i in all_indexes for j in all_indexes if i < j]
|
||||
@ -74,8 +74,6 @@ def create_pairs(
|
||||
market_data_df = pd.concat([market_data_df, md_df])
|
||||
|
||||
for a_index, b_index in unique_index_pairs:
|
||||
from research.pt_backtest import TradingPair
|
||||
|
||||
pair = fit_method.create_trading_pair(
|
||||
config=config_copy,
|
||||
market_data=market_data_df,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user