This commit is contained in:
Oleg Sheynin 2025-06-20 18:06:04 -04:00
parent 95b25eddd7
commit 6cd82b3621
6 changed files with 114 additions and 65 deletions

View File

@ -2,15 +2,7 @@
"security_type": "EQUITY", "security_type": "EQUITY",
"data_directory": "./data/equity", "data_directory": "./data/equity",
"datafiles": [ "datafiles": [
# "20250508.alpaca_sim_md.db", "202505*.alpaca_sim_md.db",
"20250509.alpaca_sim_md.db",
# "20250512.alpaca_sim_md.db",
# "20250513.alpaca_sim_md.db",
# "20250514.alpaca_sim_md.db",
# "20250515.alpaca_sim_md.db",
# "20250516.alpaca_sim_md.db",
# "20250519.alpaca_sim_md.db",
# "20250520.alpaca_sim_md.db"
], ],
"db_table_name": "md_1min_bars", "db_table_name": "md_1min_bars",
"exchange_id": "ALPACA", "exchange_id": "ALPACA",
@ -29,4 +21,6 @@
"funding_per_pair": 2000.0, "funding_per_pair": 2000.0,
# "strategy_class": "strategies.StaticFitStrategy" # "strategy_class": "strategies.StaticFitStrategy"
"strategy_class": "strategies.SlidingFitStrategy" "strategy_class": "strategies.SlidingFitStrategy"
"exclude_instruments": ["CAN"]
} }

View File

@ -19,6 +19,7 @@ eyeD3>=0.8.10
filelock>=3.6.0 filelock>=3.6.0
frozenlist>=1.3.3 frozenlist>=1.3.3
grpcio>=1.30.2 grpcio>=1.30.2
hjson>=3.0.2
html5lib>=1.1 html5lib>=1.1
httplib2>=0.20.2 httplib2>=0.20.2
idna>=3.3 idna>=3.3
@ -39,7 +40,7 @@ multidict>=6.0.4
mypy>=0.942 mypy>=0.942
mypy-extensions>=0.4.3 mypy-extensions>=0.4.3
netaddr>=0.8.0 netaddr>=0.8.0
netifaces>=0.11.0 ######### netifaces>=0.11.0
oauthlib>=3.2.0 oauthlib>=3.2.0
packaging>=23.1 packaging>=23.1
pathspec>=0.11.1 pathspec>=0.11.1
@ -72,7 +73,7 @@ statsmodels>=0.14.4
texttable>=1.6.4 texttable>=1.6.4
tldextract>=3.1.2 tldextract>=3.1.2
tomli>=1.2.2 tomli>=1.2.2
typed-ast>=1.4.3 ######## typed-ast>=1.4.3
types-aiofiles>=0.1 types-aiofiles>=0.1
types-annoy>=1.17 types-annoy>=1.17
types-appdirs>=1.4 types-appdirs>=1.4

View File

@ -10,7 +10,7 @@ from typing import Any, Dict, List, Optional
import pandas as pd import pandas as pd
from tools.data_loader import load_market_data from tools.data_loader import get_available_instruments_from_db, load_market_data
from tools.trading_pair import TradingPair from tools.trading_pair import TradingPair
from results import BacktestResult, create_result_database, store_results_in_database, store_config_in_database from results import BacktestResult, create_result_database, store_results_in_database, store_config_in_database
@ -21,40 +21,40 @@ def load_config(config_path: str) -> Dict:
return config return config
def get_available_instruments_from_db(datafile: str, config: Dict) -> List[str]: # def get_available_instruments_from_db(datafile: str, config: Dict) -> List[str]:
""" # """
Auto-detect available instruments from the database by querying distinct instrument_id values. # Auto-detect available instruments from the database by querying distinct instrument_id values.
Returns instruments without the configured prefix. # Returns instruments without the configured prefix.
""" # """
try: # try:
conn = sqlite3.connect(datafile) # conn = sqlite3.connect(datafile)
# Query to get distinct instrument_ids # # Query to get distinct instrument_ids
query = f""" # query = f"""
SELECT DISTINCT instrument_id # SELECT DISTINCT instrument_id
FROM {config['db_table_name']} # FROM {config['db_table_name']}
WHERE exchange_id = ? # WHERE exchange_id = ?
""" # """
cursor = conn.execute(query, (config["exchange_id"],)) # cursor = conn.execute(query, (config["exchange_id"],))
instrument_ids = [row[0] for row in cursor.fetchall()] # instrument_ids = [row[0] for row in cursor.fetchall()]
conn.close() # conn.close()
# Remove the configured prefix to get instrument symbols # # Remove the configured prefix to get instrument symbols
prefix = config.get("instrument_id_pfx", "") # prefix = config.get("instrument_id_pfx", "")
instruments = [] # instruments = []
for instrument_id in instrument_ids: # for instrument_id in instrument_ids:
if instrument_id.startswith(prefix): # if instrument_id.startswith(prefix):
symbol = instrument_id[len(prefix) :] # symbol = instrument_id[len(prefix) :]
instruments.append(symbol) # instruments.append(symbol)
else: # else:
instruments.append(instrument_id) # instruments.append(instrument_id)
return sorted(instruments) # return sorted(instruments)
except Exception as e: # except Exception as e:
print(f"Error auto-detecting instruments from {datafile}: {str(e)}") # print(f"Error auto-detecting instruments from {datafile}: {str(e)}")
return [] # return []
def resolve_datafiles(config: Dict, cli_datafiles: Optional[str] = None) -> List[str]: def resolve_datafiles(config: Dict, cli_datafiles: Optional[str] = None) -> List[str]:
@ -100,13 +100,13 @@ def run_backtest(
config: Dict, config: Dict,
datafile: str, datafile: str,
price_column: str, price_column: str,
bt_result: BacktestResult,
strategy, strategy,
instruments: List[str], instruments: List[str],
) -> None: ) -> BacktestResult:
""" """
Run backtest for all pairs using the specified instruments. Run backtest for all pairs using the specified instruments.
""" """
bt_result: BacktestResult = BacktestResult(config=config)
def _create_pairs(config: Dict, instruments: List[str]) -> List[TradingPair]: def _create_pairs(config: Dict, instruments: List[str]) -> List[TradingPair]:
nonlocal datafile nonlocal datafile
@ -141,13 +141,14 @@ def run_backtest(
# Check if result_list has any data before concatenating # Check if result_list has any data before concatenating
if len(pairs_trades) == 0: if len(pairs_trades) == 0:
print("No trading signals found for any pairs") print("No trading signals found for any pairs")
return None return bt_result
result = pd.concat(pairs_trades, ignore_index=True) result = pd.concat(pairs_trades, ignore_index=True)
result["time"] = pd.to_datetime(result["time"]) result["time"] = pd.to_datetime(result["time"])
result = result.set_index("time").sort_index() result = result.set_index("time").sort_index()
bt_result.collect_single_day_results(result) bt_result.collect_single_day_results(result)
return bt_result
def main() -> None: def main() -> None:
@ -201,7 +202,6 @@ def main() -> None:
# Initialize a dictionary to store all trade results # Initialize a dictionary to store all trade results
all_results: Dict[str, Dict[str, Any]] = {} all_results: Dict[str, Dict[str, Any]] = {}
bt_results = BacktestResult(config=config)
# Store configuration in database for reference # Store configuration in database for reference
if args.result_db.upper() != "NONE": if args.result_db.upper() != "NONE":
@ -232,9 +232,6 @@ def main() -> None:
for datafile in datafiles: for datafile in datafiles:
print(f"\n====== Processing {os.path.basename(datafile)} ======") print(f"\n====== Processing {os.path.basename(datafile)} ======")
# Clear the trades for the new file
bt_results.clear_trades()
# Determine instruments to use # Determine instruments to use
if args.instruments: if args.instruments:
# Use CLI-specified instruments # Use CLI-specified instruments
@ -251,11 +248,12 @@ def main() -> None:
# Process data for this file # Process data for this file
try: try:
run_backtest( strategy.reset()
bt_results = run_backtest(
config=config, config=config,
datafile=datafile, datafile=datafile,
price_column=price_column, price_column=price_column,
bt_result=bt_results,
strategy=strategy, strategy=strategy,
instruments=instruments, instruments=instruments,
) )
@ -270,17 +268,18 @@ def main() -> None:
print(f"Successfully processed {filename}") print(f"Successfully processed {filename}")
except Exception as e: except Exception as err:
print(f"Error processing {datafile}: {str(e)}") print(f"Error processing {datafile}: {str(err)}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
# Calculate and print results # Calculate and print results using a new BacktestResult instance for aggregation
if all_results: if all_results:
bt_results.calculate_returns(all_results) aggregate_bt_results = BacktestResult(config=config)
bt_results.print_grand_totals() aggregate_bt_results.calculate_returns(all_results)
bt_results.print_outstanding_positions() aggregate_bt_results.print_grand_totals()
aggregate_bt_results.print_outstanding_positions()
if args.result_db.upper() != "NONE": if args.result_db.upper() != "NONE":
print(f"\nResults stored in database: {args.result_db}") print(f"\nResults stored in database: {args.result_db}")

View File

@ -596,6 +596,9 @@ class BacktestResult:
open_px_a, open_px_b: Opening prices for symbols A and B open_px_a, open_px_b: Opening prices for symbols A and B
open_tstamp: Opening timestamp open_tstamp: Opening timestamp
""" """
if pair_result_df is None or pair_result_df.empty:
return 0, 0, 0
last_row = pair_result_df.loc[last_row_index] last_row = pair_result_df.loc[last_row_index]
last_tstamp = last_row["tstamp"] last_tstamp = last_row["tstamp"]
colname_a, colname_b = pair.colnames() colname_a, colname_b = pair.colnames()

View File

@ -1,10 +1,9 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from enum import Enum from enum import Enum
import sys
from typing import Dict, Optional from typing import Dict, Optional, cast
import pandas as pd # type: ignore import pandas as pd # type: ignore[import]
from tools.trading_pair import TradingPair from tools.trading_pair import TradingPair
from results import BacktestResult from results import BacktestResult
@ -25,6 +24,10 @@ class PairsTradingStrategy(ABC):
def run_pair(self, config: Dict, pair: TradingPair, bt_result: BacktestResult) -> Optional[pd.DataFrame]: def run_pair(self, config: Dict, pair: TradingPair, bt_result: BacktestResult) -> Optional[pd.DataFrame]:
... ...
@abstractmethod
def reset(self):
...
class StaticFitStrategy(PairsTradingStrategy): class StaticFitStrategy(PairsTradingStrategy):
def run_pair(self, config: Dict, pair: TradingPair, bt_result: BacktestResult) -> Optional[pd.DataFrame]: # abstractmethod def run_pair(self, config: Dict, pair: TradingPair, bt_result: BacktestResult) -> Optional[pd.DataFrame]: # abstractmethod
@ -197,9 +200,12 @@ class StaticFitStrategy(PairsTradingStrategy):
# Add tuples to data frame # Add tuples to data frame
return pd.DataFrame( return pd.DataFrame(
trd_signal_tuples, trd_signal_tuples,
columns=self.TRADES_COLUMNS, columns=self.TRADES_COLUMNS, # type: ignore
) )
def reset(self):
pass
class PairState(Enum): class PairState(Enum):
INITIAL = 1 INITIAL = 1
OPEN = 2 OPEN = 2
@ -214,7 +220,7 @@ class SlidingFitStrategy(PairsTradingStrategy):
print(f"***{pair}*** STARTING....") print(f"***{pair}*** STARTING....")
pair.user_data_['state'] = PairState.INITIAL pair.user_data_['state'] = PairState.INITIAL
pair.user_data_["trades"] = pd.DataFrame(columns=self.TRADES_COLUMNS) pair.user_data_["trades"] = pd.DataFrame(columns=self.TRADES_COLUMNS) # type: ignore
pair.user_data_["is_cointegrated"] = False pair.user_data_["is_cointegrated"] = False
open_threshold = config["dis-equilibrium_open_trshld"] open_threshold = config["dis-equilibrium_open_trshld"]
@ -357,7 +363,7 @@ class SlidingFitStrategy(PairsTradingStrategy):
] ]
return pd.DataFrame( return pd.DataFrame(
trd_signal_tuples, trd_signal_tuples,
columns=self.TRADES_COLUMNS, columns=self.TRADES_COLUMNS, # type: ignore
) )
def _get_close_trades(self, pair: TradingPair, close_threshold: float) -> Optional[pd.DataFrame]: def _get_close_trades(self, pair: TradingPair, close_threshold: float) -> Optional[pd.DataFrame]:
@ -404,9 +410,11 @@ class SlidingFitStrategy(PairsTradingStrategy):
# Add tuples to data frame # Add tuples to data frame
return pd.DataFrame( return pd.DataFrame(
trd_signal_tuples, trd_signal_tuples,
columns=self.TRADES_COLUMNS, columns=self.TRADES_COLUMNS, # type: ignore
) )
def reset(self):
self.curr_training_start_idx_ = 0

View File

@ -1,5 +1,5 @@
import sqlite3 import sqlite3
from typing import Dict from typing import Dict, List, cast
import pandas as pd import pandas as pd
@ -83,9 +83,53 @@ def load_market_data(datafile: str, config: Dict) -> pd.DataFrame:
df = df[(df["tstamp"] >= start_time) & (df["tstamp"] <= end_time)] df = df[(df["tstamp"] >= start_time) & (df["tstamp"] <= end_time)]
df["tstamp"] = pd.to_datetime(df["tstamp"]) df["tstamp"] = pd.to_datetime(df["tstamp"])
return df return cast(pd.DataFrame, df)
def get_available_instruments_from_db(datafile: str, config: Dict) -> List[str]:
"""
Auto-detect available instruments from the database by querying distinct instrument_id values.
Returns instruments without the configured prefix.
"""
try:
conn = sqlite3.connect(datafile)
# Build exclusion list with full instrument_ids
exclude_instruments = config.get("exclude_instruments", [])
prefix = config.get("instrument_id_pfx", "")
exclude_instrument_ids = [f"{prefix}{inst}" for inst in exclude_instruments]
# Query to get distinct instrument_ids
query = f"""
SELECT DISTINCT instrument_id
FROM {config['db_table_name']}
WHERE exchange_id = ?
"""
# Add exclusion clause if there are instruments to exclude
if exclude_instrument_ids:
placeholders = ','.join(['?' for _ in exclude_instrument_ids])
query += f" AND instrument_id NOT IN ({placeholders})"
cursor = conn.execute(query, (config["exchange_id"],) + tuple(exclude_instrument_ids))
else:
cursor = conn.execute(query, (config["exchange_id"],))
instrument_ids = [row[0] for row in cursor.fetchall()]
conn.close()
# Remove the configured prefix to get instrument symbols
instruments = []
for instrument_id in instrument_ids:
if instrument_id.startswith(prefix):
symbol = instrument_id[len(prefix) :]
instruments.append(symbol)
else:
instruments.append(instrument_id)
return sorted(instruments)
except Exception as e:
print(f"Error auto-detecting instruments from {datafile}: {str(e)}")
return []
# if __name__ == "__main__": # if __name__ == "__main__":