diff --git a/configuration/equity.cfg b/configuration/equity.cfg index e53c8e9..a018f27 100644 --- a/configuration/equity.cfg +++ b/configuration/equity.cfg @@ -2,15 +2,7 @@ "security_type": "EQUITY", "data_directory": "./data/equity", "datafiles": [ - # "20250508.alpaca_sim_md.db", - "20250509.alpaca_sim_md.db", - # "20250512.alpaca_sim_md.db", - # "20250513.alpaca_sim_md.db", - # "20250514.alpaca_sim_md.db", - # "20250515.alpaca_sim_md.db", - # "20250516.alpaca_sim_md.db", - # "20250519.alpaca_sim_md.db", - # "20250520.alpaca_sim_md.db" + "202505*.alpaca_sim_md.db", ], "db_table_name": "md_1min_bars", "exchange_id": "ALPACA", @@ -29,4 +21,6 @@ "funding_per_pair": 2000.0, # "strategy_class": "strategies.StaticFitStrategy" "strategy_class": "strategies.SlidingFitStrategy" + "exclude_instruments": ["CAN"] + } \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 0537291..eb1e396 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,6 +19,7 @@ eyeD3>=0.8.10 filelock>=3.6.0 frozenlist>=1.3.3 grpcio>=1.30.2 +hjson>=3.0.2 html5lib>=1.1 httplib2>=0.20.2 idna>=3.3 @@ -39,7 +40,7 @@ multidict>=6.0.4 mypy>=0.942 mypy-extensions>=0.4.3 netaddr>=0.8.0 -netifaces>=0.11.0 +######### netifaces>=0.11.0 oauthlib>=3.2.0 packaging>=23.1 pathspec>=0.11.1 @@ -72,7 +73,7 @@ statsmodels>=0.14.4 texttable>=1.6.4 tldextract>=3.1.2 tomli>=1.2.2 -typed-ast>=1.4.3 +######## typed-ast>=1.4.3 types-aiofiles>=0.1 types-annoy>=1.17 types-appdirs>=1.4 diff --git a/src/pt_backtest.py b/src/pt_backtest.py index ee2b380..7139d44 100644 --- a/src/pt_backtest.py +++ b/src/pt_backtest.py @@ -10,7 +10,7 @@ from typing import Any, Dict, List, Optional import pandas as pd -from tools.data_loader import load_market_data +from tools.data_loader import get_available_instruments_from_db, load_market_data from tools.trading_pair import TradingPair from results import BacktestResult, create_result_database, store_results_in_database, store_config_in_database @@ -21,40 +21,40 @@ def load_config(config_path: str) -> Dict: return config -def get_available_instruments_from_db(datafile: str, config: Dict) -> List[str]: - """ - Auto-detect available instruments from the database by querying distinct instrument_id values. - Returns instruments without the configured prefix. - """ - try: - conn = sqlite3.connect(datafile) +# def get_available_instruments_from_db(datafile: str, config: Dict) -> List[str]: +# """ +# Auto-detect available instruments from the database by querying distinct instrument_id values. +# Returns instruments without the configured prefix. +# """ +# try: +# conn = sqlite3.connect(datafile) - # Query to get distinct instrument_ids - query = f""" - SELECT DISTINCT instrument_id - FROM {config['db_table_name']} - WHERE exchange_id = ? - """ +# # Query to get distinct instrument_ids +# query = f""" +# SELECT DISTINCT instrument_id +# FROM {config['db_table_name']} +# WHERE exchange_id = ? +# """ - cursor = conn.execute(query, (config["exchange_id"],)) - instrument_ids = [row[0] for row in cursor.fetchall()] - conn.close() +# cursor = conn.execute(query, (config["exchange_id"],)) +# instrument_ids = [row[0] for row in cursor.fetchall()] +# conn.close() - # Remove the configured prefix to get instrument symbols - prefix = config.get("instrument_id_pfx", "") - instruments = [] - for instrument_id in instrument_ids: - if instrument_id.startswith(prefix): - symbol = instrument_id[len(prefix) :] - instruments.append(symbol) - else: - instruments.append(instrument_id) +# # Remove the configured prefix to get instrument symbols +# prefix = config.get("instrument_id_pfx", "") +# instruments = [] +# for instrument_id in instrument_ids: +# if instrument_id.startswith(prefix): +# symbol = instrument_id[len(prefix) :] +# instruments.append(symbol) +# else: +# instruments.append(instrument_id) - return sorted(instruments) +# return sorted(instruments) - except Exception as e: - print(f"Error auto-detecting instruments from {datafile}: {str(e)}") - return [] +# except Exception as e: +# print(f"Error auto-detecting instruments from {datafile}: {str(e)}") +# return [] def resolve_datafiles(config: Dict, cli_datafiles: Optional[str] = None) -> List[str]: @@ -100,13 +100,13 @@ def run_backtest( config: Dict, datafile: str, price_column: str, - bt_result: BacktestResult, strategy, instruments: List[str], -) -> None: +) -> BacktestResult: """ Run backtest for all pairs using the specified instruments. """ + bt_result: BacktestResult = BacktestResult(config=config) def _create_pairs(config: Dict, instruments: List[str]) -> List[TradingPair]: nonlocal datafile @@ -141,13 +141,14 @@ def run_backtest( # Check if result_list has any data before concatenating if len(pairs_trades) == 0: print("No trading signals found for any pairs") - return None + return bt_result result = pd.concat(pairs_trades, ignore_index=True) result["time"] = pd.to_datetime(result["time"]) result = result.set_index("time").sort_index() bt_result.collect_single_day_results(result) + return bt_result def main() -> None: @@ -201,7 +202,6 @@ def main() -> None: # Initialize a dictionary to store all trade results all_results: Dict[str, Dict[str, Any]] = {} - bt_results = BacktestResult(config=config) # Store configuration in database for reference if args.result_db.upper() != "NONE": @@ -232,9 +232,6 @@ def main() -> None: for datafile in datafiles: print(f"\n====== Processing {os.path.basename(datafile)} ======") - # Clear the trades for the new file - bt_results.clear_trades() - # Determine instruments to use if args.instruments: # Use CLI-specified instruments @@ -251,11 +248,12 @@ def main() -> None: # Process data for this file try: - run_backtest( + strategy.reset() + + bt_results = run_backtest( config=config, datafile=datafile, price_column=price_column, - bt_result=bt_results, strategy=strategy, instruments=instruments, ) @@ -270,17 +268,18 @@ def main() -> None: print(f"Successfully processed {filename}") - except Exception as e: - print(f"Error processing {datafile}: {str(e)}") + except Exception as err: + print(f"Error processing {datafile}: {str(err)}") import traceback traceback.print_exc() - # Calculate and print results + # Calculate and print results using a new BacktestResult instance for aggregation if all_results: - bt_results.calculate_returns(all_results) - bt_results.print_grand_totals() - bt_results.print_outstanding_positions() + aggregate_bt_results = BacktestResult(config=config) + aggregate_bt_results.calculate_returns(all_results) + aggregate_bt_results.print_grand_totals() + aggregate_bt_results.print_outstanding_positions() if args.result_db.upper() != "NONE": print(f"\nResults stored in database: {args.result_db}") diff --git a/src/results.py b/src/results.py index 67480e5..56e34d9 100644 --- a/src/results.py +++ b/src/results.py @@ -596,6 +596,9 @@ class BacktestResult: open_px_a, open_px_b: Opening prices for symbols A and B open_tstamp: Opening timestamp """ + if pair_result_df is None or pair_result_df.empty: + return 0, 0, 0 + last_row = pair_result_df.loc[last_row_index] last_tstamp = last_row["tstamp"] colname_a, colname_b = pair.colnames() diff --git a/src/strategies.py b/src/strategies.py index 1683cd7..ad85526 100644 --- a/src/strategies.py +++ b/src/strategies.py @@ -1,10 +1,9 @@ from abc import ABC, abstractmethod from enum import Enum -import sys -from typing import Dict, Optional +from typing import Dict, Optional, cast -import pandas as pd # type: ignore +import pandas as pd # type: ignore[import] from tools.trading_pair import TradingPair from results import BacktestResult @@ -24,6 +23,10 @@ class PairsTradingStrategy(ABC): @abstractmethod def run_pair(self, config: Dict, pair: TradingPair, bt_result: BacktestResult) -> Optional[pd.DataFrame]: ... + + @abstractmethod + def reset(self): + ... class StaticFitStrategy(PairsTradingStrategy): @@ -197,8 +200,11 @@ class StaticFitStrategy(PairsTradingStrategy): # Add tuples to data frame return pd.DataFrame( trd_signal_tuples, - columns=self.TRADES_COLUMNS, + columns=self.TRADES_COLUMNS, # type: ignore ) + + def reset(self): + pass class PairState(Enum): INITIAL = 1 @@ -214,7 +220,7 @@ class SlidingFitStrategy(PairsTradingStrategy): print(f"***{pair}*** STARTING....") pair.user_data_['state'] = PairState.INITIAL - pair.user_data_["trades"] = pd.DataFrame(columns=self.TRADES_COLUMNS) + pair.user_data_["trades"] = pd.DataFrame(columns=self.TRADES_COLUMNS) # type: ignore pair.user_data_["is_cointegrated"] = False open_threshold = config["dis-equilibrium_open_trshld"] @@ -357,7 +363,7 @@ class SlidingFitStrategy(PairsTradingStrategy): ] return pd.DataFrame( trd_signal_tuples, - columns=self.TRADES_COLUMNS, + columns=self.TRADES_COLUMNS, # type: ignore ) def _get_close_trades(self, pair: TradingPair, close_threshold: float) -> Optional[pd.DataFrame]: @@ -404,9 +410,11 @@ class SlidingFitStrategy(PairsTradingStrategy): # Add tuples to data frame return pd.DataFrame( trd_signal_tuples, - columns=self.TRADES_COLUMNS, + columns=self.TRADES_COLUMNS, # type: ignore ) + def reset(self): + self.curr_training_start_idx_ = 0 diff --git a/src/tools/data_loader.py b/src/tools/data_loader.py index b6f7fcb..08a5f28 100644 --- a/src/tools/data_loader.py +++ b/src/tools/data_loader.py @@ -1,5 +1,5 @@ import sqlite3 -from typing import Dict +from typing import Dict, List, cast import pandas as pd @@ -83,9 +83,53 @@ def load_market_data(datafile: str, config: Dict) -> pd.DataFrame: df = df[(df["tstamp"] >= start_time) & (df["tstamp"] <= end_time)] df["tstamp"] = pd.to_datetime(df["tstamp"]) - return df + return cast(pd.DataFrame, df) +def get_available_instruments_from_db(datafile: str, config: Dict) -> List[str]: + """ + Auto-detect available instruments from the database by querying distinct instrument_id values. + Returns instruments without the configured prefix. + """ + try: + conn = sqlite3.connect(datafile) + + # Build exclusion list with full instrument_ids + exclude_instruments = config.get("exclude_instruments", []) + prefix = config.get("instrument_id_pfx", "") + exclude_instrument_ids = [f"{prefix}{inst}" for inst in exclude_instruments] + + # Query to get distinct instrument_ids + query = f""" + SELECT DISTINCT instrument_id + FROM {config['db_table_name']} + WHERE exchange_id = ? + """ + + # Add exclusion clause if there are instruments to exclude + if exclude_instrument_ids: + placeholders = ','.join(['?' for _ in exclude_instrument_ids]) + query += f" AND instrument_id NOT IN ({placeholders})" + cursor = conn.execute(query, (config["exchange_id"],) + tuple(exclude_instrument_ids)) + else: + cursor = conn.execute(query, (config["exchange_id"],)) + instrument_ids = [row[0] for row in cursor.fetchall()] + conn.close() + + # Remove the configured prefix to get instrument symbols + instruments = [] + for instrument_id in instrument_ids: + if instrument_id.startswith(prefix): + symbol = instrument_id[len(prefix) :] + instruments.append(symbol) + else: + instruments.append(instrument_id) + + return sorted(instruments) + + except Exception as e: + print(f"Error auto-detecting instruments from {datafile}: {str(e)}") + return [] # if __name__ == "__main__":