fixes to notebook
This commit is contained in:
parent
af0a6f62a9
commit
24f1f82d1f
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Dict, Optional, cast
|
from typing import Dict, Optional, cast
|
||||||
@ -21,6 +23,15 @@ class PairsTradingFitMethod(ABC):
|
|||||||
"signed_scaled_disequilibrium",
|
"signed_scaled_disequilibrium",
|
||||||
"pair",
|
"pair",
|
||||||
]
|
]
|
||||||
|
@staticmethod
|
||||||
|
def create(config: Dict) -> PairsTradingFitMethod:
|
||||||
|
import importlib
|
||||||
|
fit_method_class_name = config.get("fit_method_class", None)
|
||||||
|
assert fit_method_class_name is not None
|
||||||
|
module_name, class_name = fit_method_class_name.rsplit(".", 1)
|
||||||
|
module = importlib.import_module(module_name)
|
||||||
|
fit_method = getattr(module, class_name)()
|
||||||
|
return cast(PairsTradingFitMethod, fit_method)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def run_pair(
|
def run_pair(
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@ -99,7 +99,6 @@ def run_backtest(
|
|||||||
bt_result.collect_single_day_results(pairs_trades)
|
bt_result.collect_single_day_results(pairs_trades)
|
||||||
return bt_result
|
return bt_result
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
parser = argparse.ArgumentParser(description="Run pairs trading backtest.")
|
parser = argparse.ArgumentParser(description="Run pairs trading backtest.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -129,11 +128,7 @@ def main() -> None:
|
|||||||
config: Dict = load_config(args.config)
|
config: Dict = load_config(args.config)
|
||||||
|
|
||||||
# Dynamically instantiate fit method class
|
# Dynamically instantiate fit method class
|
||||||
fit_method_class_name = config.get("fit_method_class", None)
|
fit_method = PairsTradingFitMethod.create(config)
|
||||||
assert fit_method_class_name is not None
|
|
||||||
module_name, class_name = fit_method_class_name.rsplit(".", 1)
|
|
||||||
module = importlib.import_module(module_name)
|
|
||||||
fit_method = getattr(module, class_name)()
|
|
||||||
|
|
||||||
# Resolve data files (CLI takes priority over config)
|
# Resolve data files (CLI takes priority over config)
|
||||||
instruments = get_instruments(args, config)
|
instruments = get_instruments(args, config)
|
||||||
@ -166,7 +161,7 @@ def main() -> None:
|
|||||||
db_path=args.result_db,
|
db_path=args.result_db,
|
||||||
config_file_path=args.config,
|
config_file_path=args.config,
|
||||||
config=config,
|
config=config,
|
||||||
fit_method_class=fit_method_class_name,
|
fit_method_class=config["fit_method_class"],
|
||||||
datafiles=datafiles,
|
datafiles=datafiles,
|
||||||
instruments=instruments,
|
instruments=instruments,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,8 +1,8 @@
|
|||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
import pandas as pd
|
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
from pt_trading.fit_method import PairsTradingFitMethod
|
from pt_trading.fit_method import PairsTradingFitMethod
|
||||||
|
|
||||||
|
|
||||||
@ -52,8 +52,8 @@ def create_pairs(
|
|||||||
config: Dict,
|
config: Dict,
|
||||||
instruments: List[Dict[str, str]],
|
instruments: List[Dict[str, str]],
|
||||||
) -> List:
|
) -> List:
|
||||||
from tools.data_loader import load_market_data
|
|
||||||
from pt_trading.trading_pair import TradingPair
|
from pt_trading.trading_pair import TradingPair
|
||||||
|
from tools.data_loader import load_market_data
|
||||||
|
|
||||||
all_indexes = range(len(instruments))
|
all_indexes = range(len(instruments))
|
||||||
unique_index_pairs = [(i, j) for i in all_indexes for j in all_indexes if i < j]
|
unique_index_pairs = [(i, j) for i in all_indexes for j in all_indexes if i < j]
|
||||||
@ -74,8 +74,6 @@ def create_pairs(
|
|||||||
market_data_df = pd.concat([market_data_df, md_df])
|
market_data_df = pd.concat([market_data_df, md_df])
|
||||||
|
|
||||||
for a_index, b_index in unique_index_pairs:
|
for a_index, b_index in unique_index_pairs:
|
||||||
from research.pt_backtest import TradingPair
|
|
||||||
|
|
||||||
pair = fit_method.create_trading_pair(
|
pair = fit_method.create_trading_pair(
|
||||||
config=config_copy,
|
config=config_copy,
|
||||||
market_data=market_data_df,
|
market_data=market_data_df,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user