This commit is contained in:
Oleg Sheynin 2025-06-20 18:06:04 -04:00
parent 95b25eddd7
commit 6cd82b3621
6 changed files with 114 additions and 65 deletions

View File

@ -2,15 +2,7 @@
"security_type": "EQUITY",
"data_directory": "./data/equity",
"datafiles": [
# "20250508.alpaca_sim_md.db",
"20250509.alpaca_sim_md.db",
# "20250512.alpaca_sim_md.db",
# "20250513.alpaca_sim_md.db",
# "20250514.alpaca_sim_md.db",
# "20250515.alpaca_sim_md.db",
# "20250516.alpaca_sim_md.db",
# "20250519.alpaca_sim_md.db",
# "20250520.alpaca_sim_md.db"
"202505*.alpaca_sim_md.db",
],
"db_table_name": "md_1min_bars",
"exchange_id": "ALPACA",
@ -29,4 +21,6 @@
"funding_per_pair": 2000.0,
# "strategy_class": "strategies.StaticFitStrategy"
"strategy_class": "strategies.SlidingFitStrategy"
"exclude_instruments": ["CAN"]
}

View File

@ -19,6 +19,7 @@ eyeD3>=0.8.10
filelock>=3.6.0
frozenlist>=1.3.3
grpcio>=1.30.2
hjson>=3.0.2
html5lib>=1.1
httplib2>=0.20.2
idna>=3.3
@ -39,7 +40,7 @@ multidict>=6.0.4
mypy>=0.942
mypy-extensions>=0.4.3
netaddr>=0.8.0
netifaces>=0.11.0
######### netifaces>=0.11.0
oauthlib>=3.2.0
packaging>=23.1
pathspec>=0.11.1
@ -72,7 +73,7 @@ statsmodels>=0.14.4
texttable>=1.6.4
tldextract>=3.1.2
tomli>=1.2.2
typed-ast>=1.4.3
######## typed-ast>=1.4.3
types-aiofiles>=0.1
types-annoy>=1.17
types-appdirs>=1.4

View File

@ -10,7 +10,7 @@ from typing import Any, Dict, List, Optional
import pandas as pd
from tools.data_loader import load_market_data
from tools.data_loader import get_available_instruments_from_db, load_market_data
from tools.trading_pair import TradingPair
from results import BacktestResult, create_result_database, store_results_in_database, store_config_in_database
@ -21,40 +21,40 @@ def load_config(config_path: str) -> Dict:
return config
def get_available_instruments_from_db(datafile: str, config: Dict) -> List[str]:
"""
Auto-detect available instruments from the database by querying distinct instrument_id values.
Returns instruments without the configured prefix.
"""
try:
conn = sqlite3.connect(datafile)
# def get_available_instruments_from_db(datafile: str, config: Dict) -> List[str]:
# """
# Auto-detect available instruments from the database by querying distinct instrument_id values.
# Returns instruments without the configured prefix.
# """
# try:
# conn = sqlite3.connect(datafile)
# Query to get distinct instrument_ids
query = f"""
SELECT DISTINCT instrument_id
FROM {config['db_table_name']}
WHERE exchange_id = ?
"""
# # Query to get distinct instrument_ids
# query = f"""
# SELECT DISTINCT instrument_id
# FROM {config['db_table_name']}
# WHERE exchange_id = ?
# """
cursor = conn.execute(query, (config["exchange_id"],))
instrument_ids = [row[0] for row in cursor.fetchall()]
conn.close()
# cursor = conn.execute(query, (config["exchange_id"],))
# instrument_ids = [row[0] for row in cursor.fetchall()]
# conn.close()
# Remove the configured prefix to get instrument symbols
prefix = config.get("instrument_id_pfx", "")
instruments = []
for instrument_id in instrument_ids:
if instrument_id.startswith(prefix):
symbol = instrument_id[len(prefix) :]
instruments.append(symbol)
else:
instruments.append(instrument_id)
# # Remove the configured prefix to get instrument symbols
# prefix = config.get("instrument_id_pfx", "")
# instruments = []
# for instrument_id in instrument_ids:
# if instrument_id.startswith(prefix):
# symbol = instrument_id[len(prefix) :]
# instruments.append(symbol)
# else:
# instruments.append(instrument_id)
return sorted(instruments)
# return sorted(instruments)
except Exception as e:
print(f"Error auto-detecting instruments from {datafile}: {str(e)}")
return []
# except Exception as e:
# print(f"Error auto-detecting instruments from {datafile}: {str(e)}")
# return []
def resolve_datafiles(config: Dict, cli_datafiles: Optional[str] = None) -> List[str]:
@ -100,13 +100,13 @@ def run_backtest(
config: Dict,
datafile: str,
price_column: str,
bt_result: BacktestResult,
strategy,
instruments: List[str],
) -> None:
) -> BacktestResult:
"""
Run backtest for all pairs using the specified instruments.
"""
bt_result: BacktestResult = BacktestResult(config=config)
def _create_pairs(config: Dict, instruments: List[str]) -> List[TradingPair]:
nonlocal datafile
@ -141,13 +141,14 @@ def run_backtest(
# Check if result_list has any data before concatenating
if len(pairs_trades) == 0:
print("No trading signals found for any pairs")
return None
return bt_result
result = pd.concat(pairs_trades, ignore_index=True)
result["time"] = pd.to_datetime(result["time"])
result = result.set_index("time").sort_index()
bt_result.collect_single_day_results(result)
return bt_result
def main() -> None:
@ -201,7 +202,6 @@ def main() -> None:
# Initialize a dictionary to store all trade results
all_results: Dict[str, Dict[str, Any]] = {}
bt_results = BacktestResult(config=config)
# Store configuration in database for reference
if args.result_db.upper() != "NONE":
@ -232,9 +232,6 @@ def main() -> None:
for datafile in datafiles:
print(f"\n====== Processing {os.path.basename(datafile)} ======")
# Clear the trades for the new file
bt_results.clear_trades()
# Determine instruments to use
if args.instruments:
# Use CLI-specified instruments
@ -251,11 +248,12 @@ def main() -> None:
# Process data for this file
try:
run_backtest(
strategy.reset()
bt_results = run_backtest(
config=config,
datafile=datafile,
price_column=price_column,
bt_result=bt_results,
strategy=strategy,
instruments=instruments,
)
@ -270,17 +268,18 @@ def main() -> None:
print(f"Successfully processed {filename}")
except Exception as e:
print(f"Error processing {datafile}: {str(e)}")
except Exception as err:
print(f"Error processing {datafile}: {str(err)}")
import traceback
traceback.print_exc()
# Calculate and print results
# Calculate and print results using a new BacktestResult instance for aggregation
if all_results:
bt_results.calculate_returns(all_results)
bt_results.print_grand_totals()
bt_results.print_outstanding_positions()
aggregate_bt_results = BacktestResult(config=config)
aggregate_bt_results.calculate_returns(all_results)
aggregate_bt_results.print_grand_totals()
aggregate_bt_results.print_outstanding_positions()
if args.result_db.upper() != "NONE":
print(f"\nResults stored in database: {args.result_db}")

View File

@ -596,6 +596,9 @@ class BacktestResult:
open_px_a, open_px_b: Opening prices for symbols A and B
open_tstamp: Opening timestamp
"""
if pair_result_df is None or pair_result_df.empty:
return 0, 0, 0
last_row = pair_result_df.loc[last_row_index]
last_tstamp = last_row["tstamp"]
colname_a, colname_b = pair.colnames()

View File

@ -1,10 +1,9 @@
from abc import ABC, abstractmethod
from enum import Enum
import sys
from typing import Dict, Optional
from typing import Dict, Optional, cast
import pandas as pd # type: ignore
import pandas as pd # type: ignore[import]
from tools.trading_pair import TradingPair
from results import BacktestResult
@ -25,6 +24,10 @@ class PairsTradingStrategy(ABC):
def run_pair(self, config: Dict, pair: TradingPair, bt_result: BacktestResult) -> Optional[pd.DataFrame]:
...
# Abstract hook: every concrete strategy must provide its own reset().
# The backtest driver calls strategy.reset() before each run_backtest() call.
@abstractmethod
def reset(self):
...
class StaticFitStrategy(PairsTradingStrategy):
def run_pair(self, config: Dict, pair: TradingPair, bt_result: BacktestResult) -> Optional[pd.DataFrame]: # abstractmethod
@ -197,9 +200,12 @@ class StaticFitStrategy(PairsTradingStrategy):
# Add tuples to data frame
return pd.DataFrame(
trd_signal_tuples,
columns=self.TRADES_COLUMNS,
columns=self.TRADES_COLUMNS, # type: ignore
)
# No-op implementation of the abstract reset() hook: nothing is cleared here.
def reset(self):
pass
class PairState(Enum):
INITIAL = 1
OPEN = 2
@ -214,7 +220,7 @@ class SlidingFitStrategy(PairsTradingStrategy):
print(f"***{pair}*** STARTING....")
pair.user_data_['state'] = PairState.INITIAL
pair.user_data_["trades"] = pd.DataFrame(columns=self.TRADES_COLUMNS)
pair.user_data_["trades"] = pd.DataFrame(columns=self.TRADES_COLUMNS) # type: ignore
pair.user_data_["is_cointegrated"] = False
open_threshold = config["dis-equilibrium_open_trshld"]
@ -357,7 +363,7 @@ class SlidingFitStrategy(PairsTradingStrategy):
]
return pd.DataFrame(
trd_signal_tuples,
columns=self.TRADES_COLUMNS,
columns=self.TRADES_COLUMNS, # type: ignore
)
def _get_close_trades(self, pair: TradingPair, close_threshold: float) -> Optional[pd.DataFrame]:
@ -404,9 +410,11 @@ class SlidingFitStrategy(PairsTradingStrategy):
# Add tuples to data frame
return pd.DataFrame(
trd_signal_tuples,
columns=self.TRADES_COLUMNS,
columns=self.TRADES_COLUMNS, # type: ignore
)
# Implementation of the abstract reset() hook: sets curr_training_start_idx_ back to 0.
# NOTE(review): presumably this rewinds the sliding training window for the next run - confirm.
def reset(self):
self.curr_training_start_idx_ = 0

View File

@ -1,5 +1,5 @@
import sqlite3
from typing import Dict
from typing import Dict, List, cast
import pandas as pd
@ -83,9 +83,53 @@ def load_market_data(datafile: str, config: Dict) -> pd.DataFrame:
df = df[(df["tstamp"] >= start_time) & (df["tstamp"] <= end_time)]
df["tstamp"] = pd.to_datetime(df["tstamp"])
return df
return cast(pd.DataFrame, df)
def get_available_instruments_from_db(datafile: str, config: Dict) -> List[str]:
    """
    Auto-detect available instruments from the database by querying distinct instrument_id values.

    Args:
        datafile: Path of the SQLite database file to inspect.
        config: Configuration mapping. Reads the required keys "db_table_name" and
            "exchange_id", and the optional keys "instrument_id_pfx" (default "")
            and "exclude_instruments" (default: nothing excluded). An explicit
            None for either optional key is treated the same as a missing key.

    Returns:
        Sorted list of instruments with the configured prefix removed. Ids that do
        not start with the prefix are returned unchanged. Any error (unreadable
        database, missing table, missing config key, ...) is reported on stdout
        and yields an empty list.
    """
    conn = None
    try:
        conn = sqlite3.connect(datafile)

        # Build exclusion list with full instrument_ids
        prefix = config.get("instrument_id_pfx") or ""
        exclude_instruments = config.get("exclude_instruments") or []
        exclude_instrument_ids = [f"{prefix}{inst}" for inst in exclude_instruments]

        # Query to get distinct instrument_ids. The table name comes from the
        # configuration and cannot be bound as a SQL parameter, so it is
        # interpolated; all values are bound through placeholders.
        query = f"""
        SELECT DISTINCT instrument_id
        FROM {config['db_table_name']}
        WHERE exchange_id = ?
        """
        params = [config["exchange_id"]]

        # Add exclusion clause if there are instruments to exclude
        if exclude_instrument_ids:
            placeholders = ",".join("?" for _ in exclude_instrument_ids)
            query += f" AND instrument_id NOT IN ({placeholders})"
            params.extend(exclude_instrument_ids)

        instrument_ids = [row[0] for row in conn.execute(query, params).fetchall()]

        # Remove the configured prefix to get instrument symbols
        return sorted(
            instrument_id[len(prefix):] if instrument_id.startswith(prefix) else instrument_id
            for instrument_id in instrument_ids
        )

    except Exception as e:
        print(f"Error auto-detecting instruments from {datafile}: {str(e)}")
        return []
    finally:
        # Always release the connection, including when the query itself failed.
        if conn is not None:
            conn.close()
# if __name__ == "__main__":