Compare commits

..

No commits in common. "2c08b6f1a98da44aac4727ef02743ce59e807e83" and "a7b4777f76821527e7a3322ef40386197cb29f0b" have entirely different histories.

10 changed files with 1808 additions and 737 deletions

View File

@ -1,38 +0,0 @@
{
"market_data_loading": {
"CRYPTO": {
"data_directory": "./data/crypto",
"db_table_name": "md_1min_bars",
"instrument_id_pfx": "PAIR-",
},
"EQUITY": {
"data_directory": "./data/equity",
"db_table_name": "md_1min_bars",
"instrument_id_pfx": "STOCK-",
}
},
# ====== Funding ======
"funding_per_pair": 2000.0,
# ====== Trading Parameters ======
"price_column": "close",
"dis-equilibrium_open_trshld": 2.0,
"dis-equilibrium_close_trshld": 1.0,
"training_minutes": 120,
"fit_method_class": "pt_trading.vecm_rolling_fit.VECMRollingFit",
# ====== Stop Conditions ======
"stop_close_conditions": {
"profit": 2.0,
"loss": -0.5
}
# ====== End of Session Closeout ======
"close_outstanding_positions": true,
# "close_outstanding_positions": false,
"trading_hours": {
"timezone": "America/New_York",
"begin_session": "9:30:00",
"end_session": "18:30:00",
}
}

View File

@ -1,38 +0,0 @@
{
"market_data_loading": {
"CRYPTO": {
"data_directory": "./data/crypto",
"db_table_name": "md_1min_bars",
"instrument_id_pfx": "PAIR-",
},
"EQUITY": {
"data_directory": "./data/equity",
"db_table_name": "md_1min_bars",
"instrument_id_pfx": "STOCK-",
}
},
# ====== Funding ======
"funding_per_pair": 2000.0,
# ====== Trading Parameters ======
"price_column": "close",
"dis-equilibrium_open_trshld": 2.0,
"dis-equilibrium_close_trshld": 0.5,
"training_minutes": 120,
"fit_method_class": "pt_trading.z-score_rolling_fit.ZScoreRollingFit",
# ====== Stop Conditions ======
"stop_close_conditions": {
"profit": 2.0,
"loss": -0.5
}
# ====== End of Session Closeout ======
"close_outstanding_positions": true,
# "close_outstanding_positions": false,
"trading_hours": {
"timezone": "America/New_York",
"begin_session": "9:30:00",
"end_session": "15:30:00",
}
}

View File

@ -13,7 +13,7 @@
}, },
# ====== Funding ====== # ====== Funding ======
"funding_per_pair": 2000.0, "funding_per_pair": 2000.0,
# ====== Trading Parameters ====== # ====== Trading Parameters ======
"price_column": "close", "price_column": "close",
"dis-equilibrium_open_trshld": 2.0, "dis-equilibrium_open_trshld": 2.0,
@ -31,8 +31,8 @@
"close_outstanding_positions": true, "close_outstanding_positions": true,
# "close_outstanding_positions": false, # "close_outstanding_positions": false,
"trading_hours": { "trading_hours": {
"timezone": "America/New_York",
"begin_session": "9:30:00", "begin_session": "9:30:00",
"end_session": "18:30:00", "end_session": "22:30:00",
"timezone": "America/New_York"
} }
} }

View File

@ -1,5 +1,3 @@
from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from enum import Enum from enum import Enum
from typing import Dict, Optional, cast from typing import Dict, Optional, cast
@ -23,15 +21,6 @@ class PairsTradingFitMethod(ABC):
"signed_scaled_disequilibrium", "signed_scaled_disequilibrium",
"pair", "pair",
] ]
@staticmethod
def create(config: Dict) -> PairsTradingFitMethod:
import importlib
fit_method_class_name = config.get("fit_method_class", None)
assert fit_method_class_name is not None
module_name, class_name = fit_method_class_name.rsplit(".", 1)
module = importlib.import_module(module_name)
fit_method = getattr(module, class_name)()
return cast(PairsTradingFitMethod, fit_method)
@abstractmethod @abstractmethod
def run_pair( def run_pair(

View File

@ -121,7 +121,7 @@ def store_config_in_database(
config_file_path: str, config_file_path: str,
config: Dict, config: Dict,
fit_method_class: str, fit_method_class: str,
datafiles: List[Tuple[str, str]], datafiles: List[str],
instruments: List[Dict[str, str]], instruments: List[Dict[str, str]],
) -> None: ) -> None:
""" """
@ -140,7 +140,7 @@ def store_config_in_database(
config_json = json.dumps(config, indent=2, default=str) config_json = json.dumps(config, indent=2, default=str)
# Convert lists to comma-separated strings for storage # Convert lists to comma-separated strings for storage
datafiles_str = ", ".join([f"{datafile}" for _, datafile in datafiles]) datafiles_str = ", ".join(datafiles)
instruments_str = ", ".join( instruments_str = ", ".join(
[ [
f"{inst['symbol']}:{inst['instrument_type']}:{inst['exchange_id']}" f"{inst['symbol']}:{inst['instrument_type']}:{inst['exchange_id']}"
@ -613,7 +613,7 @@ class BacktestResult:
return current_value_a, current_value_b, total_current_value return current_value_a, current_value_b, total_current_value
def store_results_in_database( def store_results_in_database(
self, db_path: str, day: str self, db_path: str, datafile: str
) -> None: ) -> None:
""" """
Store backtest results in the SQLite database. Store backtest results in the SQLite database.
@ -623,7 +623,8 @@ class BacktestResult:
try: try:
# Extract date from datafile name (assuming format like 20250528.mktdata.ohlcv.db) # Extract date from datafile name (assuming format like 20250528.mktdata.ohlcv.db)
date_str = day filename = os.path.basename(datafile)
date_str = filename.split(".")[0] # Extract date part
# Convert to proper date format # Convert to proper date format
try: try:

View File

@ -17,7 +17,7 @@ class PairState(Enum):
class CointegrationData: class CointegrationData:
EG_PVALUE_THRESHOLD = 0.05 EG_PVALUE_THRESHOLD = 0.05
tstamp_: pd.Timestamp tstamp_: pd.Timestamp
pair_: str pair_: str
eg_pvalue_: float eg_pvalue_: float
@ -63,7 +63,7 @@ class CointegrationData:
"johansen_cvt": self.johansen_cvt_, "johansen_cvt": self.johansen_cvt_,
"eg_is_cointegrated": self.eg_is_cointegrated_, "eg_is_cointegrated": self.eg_is_cointegrated_,
"johansen_is_cointegrated": self.johansen_is_cointegrated_, "johansen_is_cointegrated": self.johansen_is_cointegrated_,
} }
def __repr__(self) -> str: def __repr__(self) -> str:
return f"CointegrationData(tstamp={self.tstamp_}, pair={self.pair_}, eg_pvalue={self.eg_pvalue_}, johansen_lr1={self.johansen_lr1_}, johansen_cvt={self.johansen_cvt_}, eg_is_cointegrated={self.eg_is_cointegrated_}, johansen_is_cointegrated={self.johansen_is_cointegrated_})" return f"CointegrationData(tstamp={self.tstamp_}, pair={self.pair_}, eg_pvalue={self.eg_pvalue_}, johansen_lr1={self.johansen_lr1_}, johansen_cvt={self.johansen_cvt_}, eg_is_cointegrated={self.eg_is_cointegrated_}, johansen_is_cointegrated={self.johansen_is_cointegrated_})"
@ -86,12 +86,7 @@ class TradingPair(ABC):
# predicted_df_: Optional[pd.DataFrame] # predicted_df_: Optional[pd.DataFrame]
def __init__( def __init__(
self, self, config: Dict[str, Any], market_data: pd.DataFrame, symbol_a: str, symbol_b: str, price_column: str
config: Dict[str, Any],
market_data: pd.DataFrame,
symbol_a: str,
symbol_b: str,
price_column: str,
): ):
self.symbol_a_ = symbol_a self.symbol_a_ = symbol_a
self.symbol_b_ = symbol_b self.symbol_b_ = symbol_b
@ -107,33 +102,25 @@ class TradingPair(ABC):
) )
self.market_data_ = self.market_data_.dropna().reset_index(drop=True) self.market_data_ = self.market_data_.dropna().reset_index(drop=True)
self.market_data_["tstamp"] = pd.to_datetime(self.market_data_["tstamp"]) self.market_data_['tstamp'] = pd.to_datetime(self.market_data_['tstamp'])
self.market_data_ = self.market_data_.sort_values("tstamp") self.market_data_ = self.market_data_.sort_values('tstamp')
def get_begin_index(self) -> int: def get_begin_index(self) -> int:
if "trading_hours" not in self.config_: if "trading_hours" not in self.config_:
return 0 return 0
assert "timezone" in self.config_["trading_hours"] assert "timezone" in self.config_["trading_hours"]
assert "begin_session" in self.config_["trading_hours"] assert "begin_session" in self.config_["trading_hours"]
start_time = ( start_time = pd.to_datetime(self.config_["trading_hours"]["begin_session"]).tz_localize(self.config_["trading_hours"]["timezone"]).time()
pd.to_datetime(self.config_["trading_hours"]["begin_session"]) mask = self.market_data_['tstamp'].dt.time >= start_time
.tz_localize(self.config_["trading_hours"]["timezone"])
.time()
)
mask = self.market_data_["tstamp"].dt.time >= start_time
return int(self.market_data_.index[mask].min()) return int(self.market_data_.index[mask].min())
def get_end_index(self) -> int: def get_end_index(self) -> int:
if "trading_hours" not in self.config_: if "trading_hours" not in self.config_:
return 0 return 0
assert "timezone" in self.config_["trading_hours"] assert "timezone" in self.config_["trading_hours"]
assert "end_session" in self.config_["trading_hours"] assert "end_session" in self.config_["trading_hours"]
end_time = ( end_time = pd.to_datetime(self.config_["trading_hours"]["end_session"]).tz_localize(self.config_["trading_hours"]["timezone"]).time()
pd.to_datetime(self.config_["trading_hours"]["end_session"]) mask = self.market_data_['tstamp'].dt.time <= end_time
.tz_localize(self.config_["trading_hours"]["timezone"])
.time()
)
mask = self.market_data_["tstamp"].dt.time <= end_time
return int(self.market_data_.index[mask].max()) return int(self.market_data_.index[mask].max())
def _transform_dataframe(self, df: pd.DataFrame) -> pd.DataFrame: def _transform_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
@ -184,7 +171,7 @@ class TradingPair(ABC):
testing_start_index = training_start_index + training_minutes testing_start_index = training_start_index + training_minutes
self.training_df_ = self.market_data_.iloc[ self.training_df_ = self.market_data_.iloc[
training_start_index:testing_start_index, :training_minutes training_start_index:testing_start_index, : training_minutes
].copy() ].copy()
assert self.training_df_ is not None assert self.training_df_ is not None
self.training_df_ = self.training_df_.dropna().reset_index(drop=True) self.training_df_ = self.training_df_.dropna().reset_index(drop=True)
@ -212,7 +199,7 @@ class TradingPair(ABC):
else: else:
# Ensure both DataFrames have the same columns and dtypes before concatenation # Ensure both DataFrames have the same columns and dtypes before concatenation
existing_trades = self.user_data_["trades"] existing_trades = self.user_data_["trades"]
# If existing trades is empty, just assign the new trades # If existing trades is empty, just assign the new trades
if len(existing_trades) == 0: if len(existing_trades) == 0:
self.user_data_["trades"] = trades.copy() self.user_data_["trades"] = trades.copy()
@ -226,26 +213,22 @@ class TradingPair(ABC):
trades[col] = pd.Timestamp.now() trades[col] = pd.Timestamp.now()
elif col in ["action", "symbol"]: elif col in ["action", "symbol"]:
trades[col] = "" trades[col] = ""
elif col in [ elif col in ["price", "disequilibrium", "scaled_disequilibrium"]:
"price",
"disequilibrium",
"scaled_disequilibrium",
]:
trades[col] = 0.0 trades[col] = 0.0
elif col == "pair": elif col == "pair":
trades[col] = None trades[col] = None
else: else:
trades[col] = None trades[col] = None
# Concatenate with explicit dtypes to avoid warnings # Concatenate with explicit dtypes to avoid warnings
self.user_data_["trades"] = pd.concat( self.user_data_["trades"] = pd.concat(
[existing_trades, trades], ignore_index=True, copy=False [existing_trades, trades],
ignore_index=True,
copy=False
) )
def get_trades(self) -> pd.DataFrame: def get_trades(self) -> pd.DataFrame:
return ( return self.user_data_["trades"] if "trades" in self.user_data_ else pd.DataFrame()
self.user_data_["trades"] if "trades" in self.user_data_ else pd.DataFrame()
)
def cointegration_check(self) -> Optional[pd.DataFrame]: def cointegration_check(self) -> Optional[pd.DataFrame]:
print(f"***{self}*** STARTING....") print(f"***{self}*** STARTING....")
@ -254,19 +237,17 @@ class TradingPair(ABC):
curr_training_start_idx = 0 curr_training_start_idx = 0
COINTEGRATION_DATA_COLUMNS = { COINTEGRATION_DATA_COLUMNS = {
"tstamp": "datetime64[ns]", "tstamp" : "datetime64[ns]",
"pair": "string", "pair" : "string",
"eg_pvalue": "float64", "eg_pvalue" : "float64",
"johansen_lr1": "float64", "johansen_lr1" : "float64",
"johansen_cvt": "float64", "johansen_cvt" : "float64",
"eg_is_cointegrated": "bool", "eg_is_cointegrated" : "bool",
"johansen_is_cointegrated": "bool", "johansen_is_cointegrated" : "bool",
} }
# Initialize trades DataFrame with proper dtypes to avoid concatenation warnings # Initialize trades DataFrame with proper dtypes to avoid concatenation warnings
result: pd.DataFrame = pd.DataFrame( result: pd.DataFrame = pd.DataFrame(columns=[col for col in COINTEGRATION_DATA_COLUMNS.keys()]) #.astype(COINTEGRATION_DATA_COLUMNS)
columns=[col for col in COINTEGRATION_DATA_COLUMNS.keys()]
) # .astype(COINTEGRATION_DATA_COLUMNS)
training_minutes = config["training_minutes"] training_minutes = config["training_minutes"]
while True: while True:
print(curr_training_start_idx, end="\r") print(curr_training_start_idx, end="\r")
@ -290,16 +271,13 @@ class TradingPair(ABC):
def to_stop_close_conditions(self, predicted_row: pd.Series) -> bool: def to_stop_close_conditions(self, predicted_row: pd.Series) -> bool:
config = self.config_ config = self.config_
if ( if ("stop_close_conditions" not in config or config["stop_close_conditions"] is None) :
"stop_close_conditions" not in config
or config["stop_close_conditions"] is None
):
return False return False
if "profit" in config["stop_close_conditions"]: if "profit" in config["stop_close_conditions"]:
current_return = self._current_return(predicted_row) current_return = self._current_return(predicted_row)
# #
# print(f"time={predicted_row['tstamp']} current_return={current_return}") # print(f"time={predicted_row['tstamp']} current_return={current_return}")
# #
if current_return >= config["stop_close_conditions"]["profit"]: if current_return >= config["stop_close_conditions"]["profit"]:
print(f"STOP PROFIT: {current_return}") print(f"STOP PROFIT: {current_return}")
self.user_data_["stop_close_state"] = PairState.CLOSE_STOP_PROFIT self.user_data_["stop_close_state"] = PairState.CLOSE_STOP_PROFIT
@ -310,10 +288,9 @@ class TradingPair(ABC):
self.user_data_["stop_close_state"] = PairState.CLOSE_STOP_LOSS self.user_data_["stop_close_state"] = PairState.CLOSE_STOP_LOSS
return True return True
return False return False
def on_open_trades(self, trades: pd.DataFrame) -> None: def on_open_trades(self, trades: pd.DataFrame) -> None:
if "close_trades" in self.user_data_: if "close_trades" in self.user_data_: del self.user_data_["close_trades"]
del self.user_data_["close_trades"]
self.user_data_["open_trades"] = trades self.user_data_["open_trades"] = trades
def on_close_trades(self, trades: pd.DataFrame) -> None: def on_close_trades(self, trades: pd.DataFrame) -> None:
@ -325,25 +302,20 @@ class TradingPair(ABC):
open_trades = self.user_data_["open_trades"] open_trades = self.user_data_["open_trades"]
if len(open_trades) == 0: if len(open_trades) == 0:
return 0.0 return 0.0
def _single_instrument_return(symbol: str) -> float: def _single_instrument_return(symbol: str) -> float:
instrument_open_trades = open_trades[open_trades["symbol"] == symbol] instrument_open_trades = open_trades[open_trades["symbol"] == symbol]
instrument_open_price = instrument_open_trades["price"].iloc[0] instrument_open_price = instrument_open_trades["price"].iloc[0]
sign = -1 if instrument_open_trades["side"].iloc[0] == "SELL" else 1 sign = -1 if instrument_open_trades["side"].iloc[0] == "SELL" else 1
instrument_price = predicted_row[f"{self.price_column_}_{symbol}"] instrument_price = predicted_row[f"{self.price_column_}_{symbol}"]
instrument_return = ( instrument_return = sign * (instrument_price - instrument_open_price) / instrument_open_price
sign
* (instrument_price - instrument_open_price)
/ instrument_open_price
)
return float(instrument_return) * 100.0 return float(instrument_return) * 100.0
instrument_a_return = _single_instrument_return(self.symbol_a_) instrument_a_return = _single_instrument_return(self.symbol_a_)
instrument_b_return = _single_instrument_return(self.symbol_b_) instrument_b_return = _single_instrument_return(self.symbol_b_)
return instrument_a_return + instrument_b_return return (instrument_a_return + instrument_b_return)
return 0.0 return 0.0
def __repr__(self) -> str: def __repr__(self) -> str:
return self.name() return self.name()
@ -356,3 +328,4 @@ class TradingPair(ABC):
# @abstractmethod # @abstractmethod
# def predicted_df(self) -> Optional[pd.DataFrame]: ... # def predicted_df(self) -> Optional[pd.DataFrame]: ...

View File

@ -5,7 +5,7 @@ from typing import Dict, List, cast
import pandas as pd import pandas as pd
def load_sqlite_to_dataframe(db_path:str, query:str) -> pd.DataFrame: def load_sqlite_to_dataframe(db_path, query):
try: try:
conn = sqlite3.connect(db_path) conn = sqlite3.connect(db_path)

File diff suppressed because one or more lines are too long

View File

@ -3,7 +3,7 @@ import glob
import importlib import importlib
import os import os
from datetime import date, datetime from datetime import date, datetime
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional
import pandas as pd import pandas as pd
@ -17,13 +17,11 @@ from pt_trading.results import (
from pt_trading.fit_method import PairsTradingFitMethod from pt_trading.fit_method import PairsTradingFitMethod
from pt_trading.trading_pair import TradingPair from pt_trading.trading_pair import TradingPair
DayT = str
DataFileNameT = str
def resolve_datafiles( def resolve_datafiles(
config: Dict, date_pattern: str, instruments: List[Dict[str, str]] config: Dict, date_pattern: str, instruments: List[Dict[str, str]]
) -> List[Tuple[DayT, DataFileNameT]]: ) -> List[str]:
resolved_files: List[Tuple[DayT, DataFileNameT]] = [] resolved_files = []
for inst in instruments: for inst in instruments:
pattern = date_pattern pattern = date_pattern
inst_type = inst["instrument_type"] inst_type = inst["instrument_type"]
@ -33,17 +31,12 @@ def resolve_datafiles(
if not os.path.isabs(pattern): if not os.path.isabs(pattern):
pattern = os.path.join(data_dir, f"{pattern}.mktdata.ohlcv.db") pattern = os.path.join(data_dir, f"{pattern}.mktdata.ohlcv.db")
matched_files = glob.glob(pattern) matched_files = glob.glob(pattern)
for matched_file in matched_files: resolved_files.extend(matched_files)
import re
match = re.search(r"(\d{8})\.mktdata\.ohlcv\.db$", matched_file)
assert match is not None
day = match.group(1)
resolved_files.append((day, matched_file))
else: else:
# Handle explicit file path # Handle explicit file path
if not os.path.isabs(pattern): if not os.path.isabs(pattern):
pattern = os.path.join(data_dir, f"{pattern}.mktdata.ohlcv.db") pattern = os.path.join(data_dir, f"{pattern}.mktdata.ohlcv.db")
resolved_files.append((date_pattern, pattern)) resolved_files.append(pattern)
return sorted(list(set(resolved_files))) # Remove duplicates and sort return sorted(list(set(resolved_files))) # Remove duplicates and sort
@ -68,7 +61,7 @@ def get_instruments(args: argparse.Namespace, config: Dict) -> List[Dict[str, st
def run_backtest( def run_backtest(
config: Dict, config: Dict,
datafiles: List[str], datafile: str,
price_column: str, price_column: str,
fit_method: PairsTradingFitMethod, fit_method: PairsTradingFitMethod,
instruments: List[Dict[str, str]], instruments: List[Dict[str, str]],
@ -77,14 +70,10 @@ def run_backtest(
Run backtest for all pairs using the specified instruments. Run backtest for all pairs using the specified instruments.
""" """
bt_result: BacktestResult = BacktestResult(config=config) bt_result: BacktestResult = BacktestResult(config=config)
if len(datafiles) < 2:
print(f"WARNING: insufficient data files: {datafiles}")
return bt_result
pairs_trades = [] pairs_trades = []
pairs = create_pairs( pairs = create_pairs(
datafiles=datafiles, datafile=datafile,
fit_method=fit_method, fit_method=fit_method,
price_column=price_column, price_column=price_column,
config=config, config=config,
@ -103,6 +92,7 @@ def run_backtest(
bt_result.collect_single_day_results(pairs_trades) bt_result.collect_single_day_results(pairs_trades)
return bt_result return bt_result
def main() -> None: def main() -> None:
parser = argparse.ArgumentParser(description="Run pairs trading backtest.") parser = argparse.ArgumentParser(description="Run pairs trading backtest.")
parser.add_argument( parser.add_argument(
@ -132,7 +122,11 @@ def main() -> None:
config: Dict = load_config(args.config) config: Dict = load_config(args.config)
# Dynamically instantiate fit method class # Dynamically instantiate fit method class
fit_method = PairsTradingFitMethod.create(config) fit_method_class_name = config.get("fit_method_class", None)
assert fit_method_class_name is not None
module_name, class_name = fit_method_class_name.rsplit(".", 1)
module = importlib.import_module(module_name)
fit_method = getattr(module, class_name)()
# Resolve data files (CLI takes priority over config) # Resolve data files (CLI takes priority over config)
instruments = get_instruments(args, config) instruments = get_instruments(args, config)
@ -142,7 +136,6 @@ def main() -> None:
print("No data files found to process.") print("No data files found to process.")
return return
days = list(set([day for day, _ in datafiles]))
print(f"Found {len(datafiles)} data files to process:") print(f"Found {len(datafiles)} data files to process:")
for df in datafiles: for df in datafiles:
print(f" - {df}") print(f" - {df}")
@ -165,7 +158,7 @@ def main() -> None:
db_path=args.result_db, db_path=args.result_db,
config_file_path=args.config, config_file_path=args.config,
config=config, config=config,
fit_method_class=config["fit_method_class"], fit_method_class=fit_method_class_name,
datafiles=datafiles, datafiles=datafiles,
instruments=instruments, instruments=instruments,
) )
@ -173,9 +166,8 @@ def main() -> None:
# Process each data file # Process each data file
price_column = config["price_column"] price_column = config["price_column"]
for day in sorted(days): for datafile in datafiles:
md_datafiles = [datafile for md_day, datafile in datafiles if md_day == day] print(f"\n====== Processing {os.path.basename(datafile)} ======")
print(f"\n====== Processing {day} ======")
# Process data for this file # Process data for this file
try: try:
@ -183,14 +175,14 @@ def main() -> None:
bt_results = run_backtest( bt_results = run_backtest(
config=config, config=config,
datafiles=md_datafiles, datafile=datafile,
price_column=price_column, price_column=price_column,
fit_method=fit_method, fit_method=fit_method,
instruments=instruments, instruments=instruments,
) )
# Store results with day name as key # Store results with file name as key
filename = os.path.basename(day) filename = os.path.basename(datafile)
all_results[filename] = { all_results[filename] = {
"trades": bt_results.trades.copy(), "trades": bt_results.trades.copy(),
"outstanding_positions": bt_results.outstanding_positions.copy(), "outstanding_positions": bt_results.outstanding_positions.copy(),
@ -206,12 +198,12 @@ def main() -> None:
} }
} }
) )
bt_results.store_results_in_database(db_path=args.result_db, day=day) bt_results.store_results_in_database(args.result_db, datafile)
print(f"Successfully processed {filename}") print(f"Successfully processed {filename}")
except Exception as err: except Exception as err:
print(f"Error processing {day}: {str(err)}") print(f"Error processing {datafile}: {str(err)}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()

View File

@ -2,7 +2,6 @@ import glob
import os import os
from typing import Dict, List, Optional from typing import Dict, List, Optional
import pandas as pd
from pt_trading.fit_method import PairsTradingFitMethod from pt_trading.fit_method import PairsTradingFitMethod
@ -46,14 +45,14 @@ def resolve_datafiles(config: Dict, cli_datafiles: Optional[str] = None) -> List
def create_pairs( def create_pairs(
datafiles: List[str], datafile: str,
fit_method: PairsTradingFitMethod, fit_method: PairsTradingFitMethod,
price_column: str, price_column: str,
config: Dict, config: Dict,
instruments: List[Dict[str, str]], instruments: List[Dict[str, str]],
) -> List: ) -> List:
from pt_trading.trading_pair import TradingPair
from tools.data_loader import load_market_data from tools.data_loader import load_market_data
from pt_trading.trading_pair import TradingPair
all_indexes = range(len(instruments)) all_indexes = range(len(instruments))
unique_index_pairs = [(i, j) for i in all_indexes for j in all_indexes if i < j] unique_index_pairs = [(i, j) for i in all_indexes for j in all_indexes if i < j]
@ -62,18 +61,17 @@ def create_pairs(
# Update config to use the specified instruments # Update config to use the specified instruments
config_copy = config.copy() config_copy = config.copy()
config_copy["instruments"] = instruments config_copy["instruments"] = instruments
market_data_df = pd.DataFrame() market_data_df = load_market_data(
for datafile in datafiles: datafile=datafile,
md_df = load_market_data( instruments=instruments,
datafile=datafile, db_table_name=config_copy["market_data_loading"][instruments[0]["instrument_type"]]["db_table_name"],
instruments=instruments, trading_hours=config_copy["trading_hours"],
db_table_name=config_copy["market_data_loading"][instruments[0]["instrument_type"]]["db_table_name"], )
trading_hours=config_copy["trading_hours"],
)
market_data_df = pd.concat([market_data_df, md_df])
for a_index, b_index in unique_index_pairs: for a_index, b_index in unique_index_pairs:
from research.pt_backtest import TradingPair
pair = fit_method.create_trading_pair( pair = fit_method.create_trading_pair(
config=config_copy, config=config_copy,
market_data=market_data_df, market_data=market_data_df,