This commit is contained in:
Oleg Sheynin 2025-07-10 18:14:37 +00:00
parent 46072e03a2
commit 85c9d2ab93
15 changed files with 578 additions and 227 deletions

View File

@ -29,5 +29,5 @@
"dis-equilibrium_close_trshld": 0.5, "dis-equilibrium_close_trshld": 0.5,
"training_minutes": 120, "training_minutes": 120,
"funding_per_pair": 2000.0, "funding_per_pair": 2000.0,
"strategy_class": "trading.strategies.StaticFitStrategy" "fit_method_class": "pt_trading.fit_methods.StaticFit"
} }

View File

@ -19,8 +19,7 @@
"dis-equilibrium_close_trshld": 1.0, "dis-equilibrium_close_trshld": 1.0,
"training_minutes": 120, "training_minutes": 120,
"funding_per_pair": 2000.0, "funding_per_pair": 2000.0,
# "strategy_class": "strategies.StaticFitStrategy" "fit_method_class": "pt_trading.fit_methods.SlidingFit",
"strategy_class": "trading.strategies.SlidingFitStrategy"
"exclude_instruments": ["CAN"] "exclude_instruments": ["CAN"]
} }

View File

@ -1,16 +1,15 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from enum import Enum from enum import Enum
from typing import Dict, Optional, cast from typing import Dict, Optional, cast
import pandas as pd # type: ignore[import] import pandas as pd # type: ignore[import]
from trading.trading_pair import TradingPair from pt_trading.results import BacktestResult
from trading.results import BacktestResult from pt_trading.trading_pair import TradingPair
NanoPerMin = 1e9 NanoPerMin = 1e9
class PairsTradingStrategy(ABC): class PairsTradingFitMethod(ABC):
TRADES_COLUMNS = [ TRADES_COLUMNS = [
"time", "time",
"action", "action",
@ -28,7 +27,7 @@ class PairsTradingStrategy(ABC):
def reset(self): def reset(self):
... ...
class StaticFitStrategy(PairsTradingStrategy): class StaticFit(PairsTradingFitMethod):
def run_pair(self, config: Dict, pair: TradingPair, bt_result: BacktestResult) -> Optional[pd.DataFrame]: # abstractmethod def run_pair(self, config: Dict, pair: TradingPair, bt_result: BacktestResult) -> Optional[pd.DataFrame]: # abstractmethod
pair.get_datasets(training_minutes=config["training_minutes"]) pair.get_datasets(training_minutes=config["training_minutes"])
@ -203,7 +202,7 @@ class StaticFitStrategy(PairsTradingStrategy):
columns=self.TRADES_COLUMNS, # type: ignore columns=self.TRADES_COLUMNS, # type: ignore
) )
def reset(self): def reset(self) -> None:
pass pass
class PairState(Enum): class PairState(Enum):
@ -211,8 +210,8 @@ class PairState(Enum):
OPEN = 2 OPEN = 2
CLOSED = 3 CLOSED = 3
class SlidingFitStrategy(PairsTradingStrategy): class SlidingFit(PairsTradingFitMethod):
def __init__(self): def __init__(self) -> None:
super().__init__() super().__init__()
self.curr_training_start_idx_ = 0 self.curr_training_start_idx_ = 0
@ -235,7 +234,7 @@ class SlidingFitStrategy(PairsTradingStrategy):
testing_size=1 testing_size=1
) )
if len(pair.training_df_) < training_minutes: # type: ignore if len(pair.training_df_) < training_minutes:
print(f"{pair}: {self.curr_training_start_idx_} Not enough training data. Completing the job.") print(f"{pair}: {self.curr_training_start_idx_} Not enough training data. Completing the job.")
if pair.user_data_["state"] == PairState.OPEN: if pair.user_data_["state"] == PairState.OPEN:
print(f"{pair}: {self.curr_training_start_idx_} Position is not closed.") print(f"{pair}: {self.curr_training_start_idx_} Position is not closed.")

View File

@ -11,18 +11,22 @@ def adapt_date_iso(val):
"""Adapt datetime.date to ISO 8601 date.""" """Adapt datetime.date to ISO 8601 date."""
return val.isoformat() return val.isoformat()
def adapt_datetime_iso(val): def adapt_datetime_iso(val):
"""Adapt datetime.datetime to timezone-naive ISO 8601 date.""" """Adapt datetime.datetime to timezone-naive ISO 8601 date."""
return val.isoformat() return val.isoformat()
def convert_date(val): def convert_date(val):
"""Convert ISO 8601 date to datetime.date object.""" """Convert ISO 8601 date to datetime.date object."""
return datetime.fromisoformat(val.decode()).date() return datetime.fromisoformat(val.decode()).date()
def convert_datetime(val): def convert_datetime(val):
"""Convert ISO 8601 datetime to datetime.datetime object.""" """Convert ISO 8601 datetime to datetime.datetime object."""
return datetime.fromisoformat(val.decode()) return datetime.fromisoformat(val.decode())
# Register the adapters and converters # Register the adapters and converters
sqlite3.register_adapter(date, adapt_date_iso) sqlite3.register_adapter(date, adapt_date_iso)
sqlite3.register_adapter(datetime, adapt_datetime_iso) sqlite3.register_adapter(datetime, adapt_datetime_iso)
@ -37,9 +41,10 @@ def create_result_database(db_path: str) -> None:
try: try:
conn = sqlite3.connect(db_path) conn = sqlite3.connect(db_path)
cursor = conn.cursor() cursor = conn.cursor()
# Create the pt_bt_results table for completed trades # Create the pt_bt_results table for completed trades
cursor.execute(''' cursor.execute(
"""
CREATE TABLE IF NOT EXISTS pt_bt_results ( CREATE TABLE IF NOT EXISTS pt_bt_results (
date DATE, date DATE,
pair TEXT, pair TEXT,
@ -57,11 +62,13 @@ def create_result_database(db_path: str) -> None:
symbol_return REAL, symbol_return REAL,
pair_return REAL pair_return REAL
) )
''') """
)
cursor.execute("DELETE FROM pt_bt_results;") cursor.execute("DELETE FROM pt_bt_results;")
# Create the outstanding_positions table for open positions # Create the outstanding_positions table for open positions
cursor.execute(''' cursor.execute(
"""
CREATE TABLE IF NOT EXISTS outstanding_positions ( CREATE TABLE IF NOT EXISTS outstanding_positions (
date DATE, date DATE,
pair TEXT, pair TEXT,
@ -72,120 +79,138 @@ def create_result_database(db_path: str) -> None:
open_price REAL, open_price REAL,
open_side TEXT open_side TEXT
) )
''') """
)
cursor.execute("DELETE FROM outstanding_positions;") cursor.execute("DELETE FROM outstanding_positions;")
# Create the config table for storing configuration JSON for reference # Create the config table for storing configuration JSON for reference
cursor.execute(''' cursor.execute(
"""
CREATE TABLE IF NOT EXISTS config ( CREATE TABLE IF NOT EXISTS config (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
run_timestamp DATETIME, run_timestamp DATETIME,
config_file_path TEXT, config_file_path TEXT,
config_json TEXT, config_json TEXT,
strategy_class TEXT, fit_method_class TEXT,
datafiles TEXT, datafiles TEXT,
instruments TEXT instruments TEXT
) )
''') """
)
cursor.execute("DELETE FROM config;") cursor.execute("DELETE FROM config;")
conn.commit() conn.commit()
conn.close() conn.close()
except Exception as e: except Exception as e:
print(f"Error creating result database: {str(e)}") print(f"Error creating result database: {str(e)}")
raise raise
def store_config_in_database(db_path: str, config_file_path: str, config: Dict, strategy_class: str, datafiles: List[str], instruments: List[str]) -> None: def store_config_in_database(
db_path: str,
config_file_path: str,
config: Dict,
fit_method_class: str,
datafiles: List[str],
instruments: List[str],
) -> None:
""" """
Store configuration information in the database for reference. Store configuration information in the database for reference.
""" """
import json import json
if db_path.upper() == "NONE": if db_path.upper() == "NONE":
return return
try: try:
conn = sqlite3.connect(db_path) conn = sqlite3.connect(db_path)
cursor = conn.cursor() cursor = conn.cursor()
# Convert config to JSON string # Convert config to JSON string
config_json = json.dumps(config, indent=2, default=str) config_json = json.dumps(config, indent=2, default=str)
# Convert lists to comma-separated strings for storage # Convert lists to comma-separated strings for storage
datafiles_str = ', '.join(datafiles) datafiles_str = ", ".join(datafiles)
instruments_str = ', '.join(instruments) instruments_str = ", ".join(instruments)
# Insert configuration record # Insert configuration record
cursor.execute(''' cursor.execute(
"""
INSERT INTO config ( INSERT INTO config (
run_timestamp, config_file_path, config_json, strategy_class, datafiles, instruments run_timestamp, config_file_path, config_json, fit_method_class, datafiles, instruments
) VALUES (?, ?, ?, ?, ?, ?) ) VALUES (?, ?, ?, ?, ?, ?)
''', ( """,
datetime.now(), (
config_file_path, datetime.now(),
config_json, config_file_path,
strategy_class, config_json,
datafiles_str, fit_method_class,
instruments_str datafiles_str,
)) instruments_str,
),
)
conn.commit() conn.commit()
conn.close() conn.close()
print(f"Configuration stored in database") print(f"Configuration stored in database")
except Exception as e: except Exception as e:
print(f"Error storing configuration in database: {str(e)}") print(f"Error storing configuration in database: {str(e)}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
def store_results_in_database(db_path: str, datafile: str, bt_result: 'BacktestResult') -> None: def store_results_in_database(
db_path: str, datafile: str, bt_result: "BacktestResult"
) -> None:
""" """
Store backtest results in the SQLite database. Store backtest results in the SQLite database.
""" """
if db_path.upper() == "NONE": if db_path.upper() == "NONE":
return return
def convert_timestamp(timestamp): def convert_timestamp(timestamp):
"""Convert pandas Timestamp to Python datetime object for SQLite compatibility.""" """Convert pandas Timestamp to Python datetime object for SQLite compatibility."""
if timestamp is None: if timestamp is None:
return None return None
if hasattr(timestamp, 'to_pydatetime'): if hasattr(timestamp, "to_pydatetime"):
return timestamp.to_pydatetime() return timestamp.to_pydatetime()
return timestamp return timestamp
try: try:
# Extract date from datafile name (assuming format like 20250528.mktdata.ohlcv.db) # Extract date from datafile name (assuming format like 20250528.mktdata.ohlcv.db)
filename = os.path.basename(datafile) filename = os.path.basename(datafile)
date_str = filename.split('.')[0] # Extract date part date_str = filename.split(".")[0] # Extract date part
# Convert to proper date format # Convert to proper date format
try: try:
date_obj = datetime.strptime(date_str, '%Y%m%d').date() date_obj = datetime.strptime(date_str, "%Y%m%d").date()
except ValueError: except ValueError:
# If date parsing fails, use current date # If date parsing fails, use current date
date_obj = datetime.now().date() date_obj = datetime.now().date()
conn = sqlite3.connect(db_path) conn = sqlite3.connect(db_path)
cursor = conn.cursor() cursor = conn.cursor()
# Process each trade from bt_result # Process each trade from bt_result
trades = bt_result.get_trades() trades = bt_result.get_trades()
for pair_name, symbols in trades.items(): for pair_name, symbols in trades.items():
# Calculate pair return for this pair # Calculate pair return for this pair
pair_return = 0.0 pair_return = 0.0
pair_trades = [] pair_trades = []
# First pass: collect all trades and calculate returns # First pass: collect all trades and calculate returns
for symbol, symbol_trades in symbols.items(): for symbol, symbol_trades in symbols.items():
if len(symbol_trades) == 0: # No trades for this symbol if len(symbol_trades) == 0: # No trades for this symbol
print(f"Warning: No trades found for symbol {symbol} in pair {pair_name}") print(
f"Warning: No trades found for symbol {symbol} in pair {pair_name}"
)
continue continue
elif len(symbol_trades) >= 2: # Completed trades (entry + exit) elif len(symbol_trades) >= 2: # Completed trades (entry + exit)
# Handle both old and new tuple formats # Handle both old and new tuple formats
if len(symbol_trades[0]) == 2: # Old format: (action, price) if len(symbol_trades[0]) == 2: # Old format: (action, price)
@ -198,138 +223,190 @@ def store_results_in_database(db_path: str, datafile: str, bt_result: 'BacktestR
open_time = datetime.now() open_time = datetime.now()
close_time = datetime.now() close_time = datetime.now()
else: # New format: (action, price, disequilibrium, scaled_disequilibrium, timestamp) else: # New format: (action, price, disequilibrium, scaled_disequilibrium, timestamp)
entry_action, entry_price, open_disequilibrium, open_scaled_disequilibrium, open_time = symbol_trades[0] (
exit_action, exit_price, close_disequilibrium, close_scaled_disequilibrium, close_time = symbol_trades[1] entry_action,
entry_price,
open_disequilibrium,
open_scaled_disequilibrium,
open_time,
) = symbol_trades[0]
(
exit_action,
exit_price,
close_disequilibrium,
close_scaled_disequilibrium,
close_time,
) = symbol_trades[1]
# Handle None values # Handle None values
open_disequilibrium = open_disequilibrium if open_disequilibrium is not None else 0.0 open_disequilibrium = (
open_scaled_disequilibrium = open_scaled_disequilibrium if open_scaled_disequilibrium is not None else 0.0 open_disequilibrium
close_disequilibrium = close_disequilibrium if close_disequilibrium is not None else 0.0 if open_disequilibrium is not None
close_scaled_disequilibrium = close_scaled_disequilibrium if close_scaled_disequilibrium is not None else 0.0 else 0.0
)
open_scaled_disequilibrium = (
open_scaled_disequilibrium
if open_scaled_disequilibrium is not None
else 0.0
)
close_disequilibrium = (
close_disequilibrium
if close_disequilibrium is not None
else 0.0
)
close_scaled_disequilibrium = (
close_scaled_disequilibrium
if close_scaled_disequilibrium is not None
else 0.0
)
# Convert pandas Timestamps to Python datetime objects # Convert pandas Timestamps to Python datetime objects
open_time = convert_timestamp(open_time) or datetime.now() open_time = convert_timestamp(open_time) or datetime.now()
close_time = convert_timestamp(close_time) or datetime.now() close_time = convert_timestamp(close_time) or datetime.now()
# Calculate actual share quantities based on funding per pair # Calculate actual share quantities based on funding per pair
# Split funding equally between the two positions # Split funding equally between the two positions
funding_per_position = bt_result.config["funding_per_pair"] / 2 funding_per_position = bt_result.config["funding_per_pair"] / 2
shares = funding_per_position / entry_price shares = funding_per_position / entry_price
# Calculate symbol return # Calculate symbol return
symbol_return = 0.0 symbol_return = 0.0
if entry_action == "BUY" and exit_action == "SELL": if entry_action == "BUY" and exit_action == "SELL":
symbol_return = (exit_price - entry_price) / entry_price * 100 symbol_return = (exit_price - entry_price) / entry_price * 100
elif entry_action == "SELL" and exit_action == "BUY": elif entry_action == "SELL" and exit_action == "BUY":
symbol_return = (entry_price - exit_price) / entry_price * 100 symbol_return = (entry_price - exit_price) / entry_price * 100
pair_return += symbol_return pair_return += symbol_return
pair_trades.append({ pair_trades.append(
'symbol': symbol, {
'entry_action': entry_action, "symbol": symbol,
'entry_price': entry_price, "entry_action": entry_action,
'exit_action': exit_action, "entry_price": entry_price,
'exit_price': exit_price, "exit_action": exit_action,
'symbol_return': symbol_return, "exit_price": exit_price,
'open_disequilibrium': open_disequilibrium, "symbol_return": symbol_return,
'open_scaled_disequilibrium': open_scaled_disequilibrium, "open_disequilibrium": open_disequilibrium,
'close_disequilibrium': close_disequilibrium, "open_scaled_disequilibrium": open_scaled_disequilibrium,
'close_scaled_disequilibrium': close_scaled_disequilibrium, "close_disequilibrium": close_disequilibrium,
'open_time': open_time, "close_scaled_disequilibrium": close_scaled_disequilibrium,
'close_time': close_time, "open_time": open_time,
'shares': shares, "close_time": close_time,
'is_completed': True "shares": shares,
}) "is_completed": True,
}
)
# Skip one-sided trades - they will be handled by outstanding_positions table # Skip one-sided trades - they will be handled by outstanding_positions table
elif len(symbol_trades) == 1: elif len(symbol_trades) == 1:
print(f"Skipping one-sided trade for {symbol} in pair {pair_name} - will be stored in outstanding_positions table") print(
f"Skipping one-sided trade for {symbol} in pair {pair_name} - will be stored in outstanding_positions table"
)
continue continue
else: else:
# This should not happen, but handle unexpected cases # This should not happen, but handle unexpected cases
print(f"Warning: Unexpected number of trades ({len(symbol_trades)}) for symbol {symbol} in pair {pair_name}") print(
f"Warning: Unexpected number of trades ({len(symbol_trades)}) for symbol {symbol} in pair {pair_name}"
)
continue continue
# Second pass: insert completed trade records into database # Second pass: insert completed trade records into database
for trade in pair_trades: for trade in pair_trades:
# Only store completed trades in pt_bt_results table # Only store completed trades in pt_bt_results table
cursor.execute(''' cursor.execute(
"""
INSERT INTO pt_bt_results ( INSERT INTO pt_bt_results (
date, pair, symbol, open_time, open_side, open_price, date, pair, symbol, open_time, open_side, open_price,
open_quantity, open_disequilibrium, close_time, close_side, open_quantity, open_disequilibrium, close_time, close_side,
close_price, close_quantity, close_disequilibrium, close_price, close_quantity, close_disequilibrium,
symbol_return, pair_return symbol_return, pair_return
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
''', ( """,
date_obj, (
pair_name, date_obj,
trade['symbol'], pair_name,
trade['open_time'], trade["symbol"],
trade['entry_action'], trade["open_time"],
trade['entry_price'], trade["entry_action"],
trade['shares'], trade["entry_price"],
trade['open_scaled_disequilibrium'], trade["shares"],
trade['close_time'], trade["open_scaled_disequilibrium"],
trade['exit_action'], trade["close_time"],
trade['exit_price'], trade["exit_action"],
trade['shares'], trade["exit_price"],
trade['close_scaled_disequilibrium'], trade["shares"],
trade['symbol_return'], trade["close_scaled_disequilibrium"],
pair_return trade["symbol_return"],
)) pair_return,
),
)
# Store outstanding positions in separate table # Store outstanding positions in separate table
outstanding_positions = bt_result.get_outstanding_positions() outstanding_positions = bt_result.get_outstanding_positions()
for pos in outstanding_positions: for pos in outstanding_positions:
# Calculate position quantity (negative for SELL positions) # Calculate position quantity (negative for SELL positions)
position_qty_a = pos['shares_a'] if pos['side_a'] == 'BUY' else -pos['shares_a'] position_qty_a = (
position_qty_b = pos['shares_b'] if pos['side_b'] == 'BUY' else -pos['shares_b'] pos["shares_a"] if pos["side_a"] == "BUY" else -pos["shares_a"]
)
position_qty_b = (
pos["shares_b"] if pos["side_b"] == "BUY" else -pos["shares_b"]
)
# Calculate unrealized returns # Calculate unrealized returns
# For symbol A: (current_price - open_price) / open_price * 100 * position_direction # For symbol A: (current_price - open_price) / open_price * 100 * position_direction
unrealized_return_a = ((pos['current_px_a'] - pos['open_px_a']) / pos['open_px_a'] * 100) * (1 if pos['side_a'] == 'BUY' else -1) unrealized_return_a = (
unrealized_return_b = ((pos['current_px_b'] - pos['open_px_b']) / pos['open_px_b'] * 100) * (1 if pos['side_b'] == 'BUY' else -1) (pos["current_px_a"] - pos["open_px_a"]) / pos["open_px_a"] * 100
) * (1 if pos["side_a"] == "BUY" else -1)
unrealized_return_b = (
(pos["current_px_b"] - pos["open_px_b"]) / pos["open_px_b"] * 100
) * (1 if pos["side_b"] == "BUY" else -1)
# Store outstanding position for symbol A # Store outstanding position for symbol A
cursor.execute(''' cursor.execute(
"""
INSERT INTO outstanding_positions ( INSERT INTO outstanding_positions (
date, pair, symbol, position_quantity, last_price, unrealized_return, open_price, open_side date, pair, symbol, position_quantity, last_price, unrealized_return, open_price, open_side
) VALUES (?, ?, ?, ?, ?, ?, ?, ?) ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
''', ( """,
date_obj, (
pos['pair'], date_obj,
pos['symbol_a'], pos["pair"],
position_qty_a, pos["symbol_a"],
pos['current_px_a'], position_qty_a,
unrealized_return_a, pos["current_px_a"],
pos['open_px_a'], unrealized_return_a,
pos['side_a'] pos["open_px_a"],
)) pos["side_a"],
),
)
# Store outstanding position for symbol B # Store outstanding position for symbol B
cursor.execute(''' cursor.execute(
"""
INSERT INTO outstanding_positions ( INSERT INTO outstanding_positions (
date, pair, symbol, position_quantity, last_price, unrealized_return, open_price, open_side date, pair, symbol, position_quantity, last_price, unrealized_return, open_price, open_side
) VALUES (?, ?, ?, ?, ?, ?, ?, ?) ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
''', ( """,
date_obj, (
pos['pair'], date_obj,
pos['symbol_b'], pos["pair"],
position_qty_b, pos["symbol_b"],
pos['current_px_b'], position_qty_b,
unrealized_return_b, pos["current_px_b"],
pos['open_px_b'], unrealized_return_b,
pos['side_b'] pos["open_px_b"],
)) pos["side_b"],
),
)
conn.commit() conn.commit()
conn.close() conn.close()
except Exception as e: except Exception as e:
print(f"Error storing results in database: {str(e)}") print(f"Error storing results in database: {str(e)}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
@ -344,7 +421,16 @@ class BacktestResult:
self.total_realized_pnl = 0.0 self.total_realized_pnl = 0.0
self.outstanding_positions: List[Dict[str, Any]] = [] self.outstanding_positions: List[Dict[str, Any]] = []
def add_trade(self, pair_nm, symbol, action, price, disequilibrium=None, scaled_disequilibrium=None, timestamp=None): def add_trade(
self,
pair_nm,
symbol,
action,
price,
disequilibrium=None,
scaled_disequilibrium=None,
timestamp=None,
):
"""Add a trade to the results tracking.""" """Add a trade to the results tracking."""
pair_nm = str(pair_nm) pair_nm = str(pair_nm)
@ -352,7 +438,9 @@ class BacktestResult:
self.trades[pair_nm] = {symbol: []} self.trades[pair_nm] = {symbol: []}
if symbol not in self.trades[pair_nm]: if symbol not in self.trades[pair_nm]:
self.trades[pair_nm][symbol] = [] self.trades[pair_nm][symbol] = []
self.trades[pair_nm][symbol].append((action, price, disequilibrium, scaled_disequilibrium, timestamp)) self.trades[pair_nm][symbol].append(
(action, price, disequilibrium, scaled_disequilibrium, timestamp)
)
def add_outstanding_position(self, position: Dict[str, Any]): def add_outstanding_position(self, position: Dict[str, Any]):
"""Add an outstanding position to tracking.""" """Add an outstanding position to tracking."""
@ -390,13 +478,17 @@ class BacktestResult:
action = row.action action = row.action
symbol = row.symbol symbol = row.symbol
price = row.price price = row.price
disequilibrium = getattr(row, 'disequilibrium', None) disequilibrium = getattr(row, "disequilibrium", None)
scaled_disequilibrium = getattr(row, 'scaled_disequilibrium', None) scaled_disequilibrium = getattr(row, "scaled_disequilibrium", None)
timestamp = getattr(row, 'time', None) timestamp = getattr(row, "time", None)
self.add_trade( self.add_trade(
pair_nm=row.pair, action=action, symbol=symbol, price=price, pair_nm=row.pair,
disequilibrium=disequilibrium, scaled_disequilibrium=scaled_disequilibrium, action=action,
timestamp=timestamp symbol=symbol,
price=price,
disequilibrium=disequilibrium,
scaled_disequilibrium=scaled_disequilibrium,
timestamp=timestamp,
) )
def print_single_day_results(self): def print_single_day_results(self):
@ -447,19 +539,31 @@ class BacktestResult:
else: # New format: (action, price, disequilibrium, scaled_disequilibrium, timestamp) else: # New format: (action, price, disequilibrium, scaled_disequilibrium, timestamp)
entry_action, entry_price = trades[0][:2] entry_action, entry_price = trades[0][:2]
exit_action, exit_price = trades[1][:2] exit_action, exit_price = trades[1][:2]
open_disequilibrium = trades[0][2] if len(trades[0]) > 2 else None open_disequilibrium = (
open_scaled_disequilibrium = trades[0][3] if len(trades[0]) > 3 else None trades[0][2] if len(trades[0]) > 2 else None
close_disequilibrium = trades[1][2] if len(trades[1]) > 2 else None )
close_scaled_disequilibrium = trades[1][3] if len(trades[1]) > 3 else None open_scaled_disequilibrium = (
trades[0][3] if len(trades[0]) > 3 else None
)
close_disequilibrium = (
trades[1][2] if len(trades[1]) > 2 else None
)
close_scaled_disequilibrium = (
trades[1][3] if len(trades[1]) > 3 else None
)
# Calculate return based on action # Calculate return based on action
symbol_return = 0 symbol_return = 0
if entry_action == "BUY" and exit_action == "SELL": if entry_action == "BUY" and exit_action == "SELL":
# Long position # Long position
symbol_return = (exit_price - entry_price) / entry_price * 100 symbol_return = (
(exit_price - entry_price) / entry_price * 100
)
elif entry_action == "SELL" and exit_action == "BUY": elif entry_action == "SELL" and exit_action == "BUY":
# Short position # Short position
symbol_return = (entry_price - exit_price) / entry_price * 100 symbol_return = (
(entry_price - exit_price) / entry_price * 100
)
pair_trades.append( pair_trades.append(
( (
@ -489,9 +593,12 @@ class BacktestResult:
close_scaled_disequilibrium, close_scaled_disequilibrium,
) in pair_trades: ) in pair_trades:
disequil_info = "" disequil_info = ""
if open_scaled_disequilibrium is not None and close_scaled_disequilibrium is not None: if (
open_scaled_disequilibrium is not None
and close_scaled_disequilibrium is not None
):
disequil_info = f" | Open Dis-eq: {open_scaled_disequilibrium:.2f}, Close Dis-eq: {close_scaled_disequilibrium:.2f}" disequil_info = f" | Open Dis-eq: {open_scaled_disequilibrium:.2f}, Close Dis-eq: {close_scaled_disequilibrium:.2f}"
print( print(
f" {symbol}: {entry_action} @ ${entry_price:.2f}, {exit_action} @ ${exit_price:.2f}, Return: {symbol_return:.2f}%{disequil_info}" f" {symbol}: {entry_action} @ ${entry_price:.2f}, {exit_action} @ ${exit_price:.2f}, Return: {symbol_return:.2f}%{disequil_info}"
) )
@ -582,9 +689,17 @@ class BacktestResult:
print(f"\n====== GRAND TOTALS ACROSS ALL PAIRS ======") print(f"\n====== GRAND TOTALS ACROSS ALL PAIRS ======")
print(f"Total Realized PnL: {self.get_total_realized_pnl():.2f}%") print(f"Total Realized PnL: {self.get_total_realized_pnl():.2f}%")
def handle_outstanding_position(self, pair, pair_result_df, last_row_index, def handle_outstanding_position(
open_side_a, open_side_b, open_px_a, open_px_b, self,
open_tstamp): pair,
pair_result_df,
last_row_index,
open_side_a,
open_side_b,
open_px_a,
open_px_b,
open_tstamp,
):
""" """
Handle calculation and tracking of outstanding positions when no close signal is found. Handle calculation and tracking of outstanding positions when no close signal is found.
@ -648,9 +763,15 @@ class BacktestResult:
# Print position details # Print position details
print(f"{pair}: NO CLOSE SIGNAL FOUND - Position held until end of session") print(f"{pair}: NO CLOSE SIGNAL FOUND - Position held until end of session")
print(f" Open: {open_tstamp} | Last: {last_tstamp}") print(f" Open: {open_tstamp} | Last: {last_tstamp}")
print(f" {pair.symbol_a_}: {open_side_a} {shares_a:.2f} shares @ ${open_px_a:.2f} -> ${last_px_a:.2f} | Value: ${current_value_a:.2f}") print(
print(f" {pair.symbol_b_}: {open_side_b} {shares_b:.2f} shares @ ${open_px_b:.2f} -> ${last_px_b:.2f} | Value: ${current_value_b:.2f}") f" {pair.symbol_a_}: {open_side_a} {shares_a:.2f} shares @ ${open_px_a:.2f} -> ${last_px_a:.2f} | Value: ${current_value_a:.2f}"
)
print(
f" {pair.symbol_b_}: {open_side_b} {shares_b:.2f} shares @ ${open_px_b:.2f} -> ${last_px_b:.2f} | Value: ${current_value_b:.2f}"
)
print(f" Total Value: ${total_current_value:.2f}") print(f" Total Value: ${total_current_value:.2f}")
print(f" Disequilibrium: {current_disequilibrium:.4f} | Scaled: {current_scaled_disequilibrium:.4f}") print(
f" Disequilibrium: {current_disequilibrium:.4f} | Scaled: {current_scaled_disequilibrium:.4f}"
)
return current_value_a, current_value_b, total_current_value return current_value_a, current_value_b, total_current_value

17
lib/tools/config.py Normal file
View File

@ -0,0 +1,17 @@
import hjson
from typing import Dict
from datetime import datetime
def load_config(config_path: str) -> Dict:
with open(config_path, "r") as f:
config = hjson.load(f)
return dict(config)
def expand_filename(filename: str) -> str:
# expand %T
res = filename.replace("%T", datetime.now().strftime("%Y%m%d_%H%M%S"))
# expand %D
return res.replace("%D", datetime.now().strftime("%Y%m%d"))

View File

@ -25,7 +25,7 @@ def list_tables(db_path: str) -> List[str]:
conn.close() conn.close()
return tables return tables
def view_table_schema(db_path: str, table_name: str): def view_table_schema(db_path: str, table_name: str) -> None:
"""View the schema of a specific table.""" """View the schema of a specific table."""
conn = sqlite3.connect(db_path) conn = sqlite3.connect(db_path)
cursor = conn.cursor() cursor = conn.cursor()
@ -44,13 +44,13 @@ def view_table_schema(db_path: str, table_name: str):
conn.close() conn.close()
def view_config_table(db_path: str, limit: int = 10): def view_config_table(db_path: str, limit: int = 10) -> None:
"""View entries from the config table.""" """View entries from the config table."""
conn = sqlite3.connect(db_path) conn = sqlite3.connect(db_path)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute(f""" cursor.execute(f"""
SELECT id, run_timestamp, config_file_path, strategy_class, SELECT id, run_timestamp, config_file_path, fit_method_class,
datafiles, instruments, config_json datafiles, instruments, config_json
FROM config FROM config
ORDER BY run_timestamp DESC ORDER BY run_timestamp DESC
@ -67,17 +67,17 @@ def view_config_table(db_path: str, limit: int = 10):
print("=" * 80) print("=" * 80)
for row in rows: for row in rows:
id, run_timestamp, config_file_path, strategy_class, datafiles, instruments, config_json = row id, run_timestamp, config_file_path, fit_method_class, datafiles, instruments, config_json = row
print(f"ID: {id} | {run_timestamp}") print(f"ID: {id} | {run_timestamp}")
print(f"Config: {config_file_path} | Strategy: {strategy_class}") print(f"Config: {config_file_path} | Strategy: {fit_method_class}")
print(f"Files: {datafiles}") print(f"Files: {datafiles}")
print(f"Instruments: {instruments}") print(f"Instruments: {instruments}")
print("-" * 80) print("-" * 80)
conn.close() conn.close()
def view_results_summary(db_path: str): def view_results_summary(db_path: str) -> None:
"""View summary of trading results.""" """View summary of trading results."""
conn = sqlite3.connect(db_path) conn = sqlite3.connect(db_path)
cursor = conn.cursor() cursor = conn.cursor()
@ -119,7 +119,7 @@ def view_results_summary(db_path: str):
conn.close() conn.close()
def main(): def main() -> None:
if len(sys.argv) < 2: if len(sys.argv) < 2:
print("Usage: python db_inspector.py <database_path> [command]") print("Usage: python db_inspector.py <database_path> [command]")
print("Commands:") print("Commands:")

View File

@ -1,6 +1,6 @@
{ {
"include": [ "include": [
"src" "lib"
], ],
"exclude": [ "exclude": [
"**/node_modules", "**/node_modules",
@ -16,7 +16,7 @@
"autoImportCompletions": true, "autoImportCompletions": true,
"autoSearchPaths": true, "autoSearchPaths": true,
"extraPaths": [ "extraPaths": [
"src" "lib"
], ],
"stubPath": "./typings", "stubPath": "./typings",
"venvPath": ".", "venvPath": ".",

View File

@ -62,7 +62,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@ -87,10 +87,10 @@
"from IPython.display import clear_output\n", "from IPython.display import clear_output\n",
"\n", "\n",
"# Import our modules\n", "# Import our modules\n",
"from strategies import StaticFitStrategy, SlidingFitStrategy, PairState\n", "from pt_trading.fit_methods import StaticFit, SlidingFit, PairState\n",
"from tools.data_loader import load_market_data\n", "from tools.data_loader import load_market_data\n",
"from trading.trading_pair import TradingPair\n", "from pt_trading.trading_pair import TradingPair\n",
"from trading.results import BacktestResult\n", "from pt_trading.results import BacktestResult\n",
"\n", "\n",
"# Set plotting style\n", "# Set plotting style\n",
"plt.style.use('seaborn-v0_8')\n", "plt.style.use('seaborn-v0_8')\n",
@ -113,7 +113,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -149,34 +149,34 @@
" print(f\"Unexpected error loading config from {config_file}: {e}\")\n", " print(f\"Unexpected error loading config from {config_file}: {e}\")\n",
" return None\n", " return None\n",
"\n", "\n",
"def instantiate_strategy_from_config(config: Dict):\n", "def instantiate_fit_method_from_config(config: Dict):\n",
" \"\"\"Dynamically instantiate strategy from config\"\"\"\n", " \"\"\"Dynamically instantiate strategy from config\"\"\"\n",
" strategy_class_name = config.get(\"strategy_class\", \"strategies.StaticFitStrategy\")\n", " fit_method_class_name = config.get(\"fit_method_class\", None)\n",
" \n", " assert fit_method_class_name is not None\n",
" try:\n", " try:\n",
" # Split module and class name\n", " # Split module and class name\n",
" if '.' in strategy_class_name:\n", " if '.' in fit_method_class_name:\n",
" module_name, class_name = strategy_class_name.rsplit('.', 1)\n", " module_name, class_name = fit_method_class_name.rsplit('.', 1)\n",
" else:\n", " else:\n",
" module_name = \"strategies\"\n", " module_name = \"fit_methods\"\n",
" class_name = strategy_class_name\n", " class_name = fit_method_class_name\n",
" \n", " \n",
" # Import module and get class\n", " # Import module and get class\n",
" module = importlib.import_module(module_name)\n", " module = importlib.import_module(module_name)\n",
" strategy_class = getattr(module, class_name)\n", " fit_method_class = getattr(module, class_name)\n",
" \n", " \n",
" # Instantiate strategy\n", " # Instantiate strategy\n",
" return strategy_class()\n", " return fit_method_class()\n",
" \n", " \n",
" except Exception as e:\n", " except Exception as e:\n",
" print(f\"Error instantiating strategy {strategy_class_name}: {e}\")\n", " print(f\"Error instantiating strategy {fit_method_class_name}: {e}\")\n",
" print(\"Falling back to StaticFitStrategy\")\n", " print(\"Falling back to StaticFitStrategy\")\n",
" return StaticFitStrategy()\n" " return StaticFitStrategy()\n"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@ -230,7 +230,7 @@
" print(f\" Close threshold: {pt_bt_config['dis-equilibrium_close_trshld']}\")\n", " print(f\" Close threshold: {pt_bt_config['dis-equilibrium_close_trshld']}\")\n",
" \n", " \n",
" # Instantiate strategy from config\n", " # Instantiate strategy from config\n",
" STRATEGY = instantiate_strategy_from_config(pt_bt_config)\n", " STRATEGY = instantiate_fit_method_from_config(pt_bt_config)\n",
" print(f\" Strategy: {type(STRATEGY).__name__}\")\n", " print(f\" Strategy: {type(STRATEGY).__name__}\")\n",
" \n", " \n",
" # Automatically construct data file name based on date and config type\n", " # Automatically construct data file name based on date and config type\n",
@ -576,7 +576,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@ -831,12 +831,12 @@
" max_demo_iterations = min(200, max_iterations)\n", " max_demo_iterations = min(200, max_iterations)\n",
" print(f\"Processing first {max_demo_iterations} iterations for demonstration...\")\n", " print(f\"Processing first {max_demo_iterations} iterations for demonstration...\")\n",
" \n", " \n",
" # Initialize pair state for sliding strategy\n", " # Initialize pair state for sliding fit method\n",
" pair.user_data_['state'] = PairState.INITIAL\n", " pair.user_data_['state'] = PairState.INITIAL\n",
" pair.user_data_[\"trades\"] = pd.DataFrame(columns=pd.Index(STRATEGY.TRADES_COLUMNS, dtype=str))\n", " pair.user_data_[\"trades\"] = pd.DataFrame(columns=pd.Index(STRATEGY.TRADES_COLUMNS, dtype=str))\n",
" pair.user_data_[\"is_cointegrated\"] = False\n", " pair.user_data_[\"is_cointegrated\"] = False\n",
" \n", " \n",
" # Run the sliding strategy\n", " # Run the sliding fit method\n",
" pair_trades = STRATEGY.run_pair(config=pt_bt_config, pair=pair, bt_result=bt_result)\n", " pair_trades = STRATEGY.run_pair(config=pt_bt_config, pair=pair, bt_result=bt_result)\n",
" \n", " \n",
" if pair_trades is not None and len(pair_trades) > 0:\n", " if pair_trades is not None and len(pair_trades) > 0:\n",

View File

@ -111,10 +111,10 @@
"from IPython.display import clear_output\n", "from IPython.display import clear_output\n",
"\n", "\n",
"# Import our modules\n", "# Import our modules\n",
"from strategies import SlidingFitStrategy, PairState\n", "from pt_trading.fit_methods import SlidingFit, PairState\n",
"from tools.data_loader import load_market_data\n", "from tools.data_loader import load_market_data\n",
"from trading.trading_pair import TradingPair\n", "from pt_trading.trading_pair import TradingPair\n",
"from trading.results import BacktestResult\n", "from pt_trading.results import BacktestResult\n",
"\n", "\n",
"# Set plotting style\n", "# Set plotting style\n",
"plt.style.use('seaborn-v0_8')\n", "plt.style.use('seaborn-v0_8')\n",

View File

@ -73,7 +73,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@ -96,10 +96,10 @@
"from typing import Dict, List, Optional\n", "from typing import Dict, List, Optional\n",
"\n", "\n",
"# Import our modules\n", "# Import our modules\n",
"from strategies import StaticFitStrategy, SlidingFitStrategy\n", "from pt_trading.fit_methods import StaticFit, SlidingFit\n",
"from tools.data_loader import load_market_data\n", "from tools.data_loader import load_market_data\n",
"from trading.trading_pair import TradingPair\n", "from pt_trading.trading_pair import TradingPair\n",
"from trading.results import BacktestResult\n", "from pt_trading.results import BacktestResult\n",
"\n", "\n",
"# Set plotting style\n", "# Set plotting style\n",
"plt.style.use('seaborn-v0_8')\n", "plt.style.use('seaborn-v0_8')\n",
@ -226,7 +226,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@ -246,11 +246,11 @@
"DATA_FILE = CONFIG[\"datafiles\"][\"0509\"]\n", "DATA_FILE = CONFIG[\"datafiles\"][\"0509\"]\n",
"\n", "\n",
"# Choose strategy\n", "# Choose strategy\n",
"STRATEGY = StaticFitStrategy()\n", "FIT_METHOD = StaticFit()\n",
"\n", "\n",
"print(f\"Selected pair: {SYMBOL_A} & {SYMBOL_B}\")\n", "print(f\"Selected pair: {SYMBOL_A} & {SYMBOL_B}\")\n",
"print(f\"Data file: {DATA_FILE}\")\n", "print(f\"Data file: {DATA_FILE}\")\n",
"print(f\"Strategy: {type(STRATEGY).__name__}\")" "print(f\"Strategy: {type(FIT_METHOD).__name__}\")"
] ]
}, },
{ {
@ -548,7 +548,7 @@
"\n", "\n",
" # Run strategy\n", " # Run strategy\n",
" bt_result = BacktestResult(config=CONFIG)\n", " bt_result = BacktestResult(config=CONFIG)\n",
" pair_trades = STRATEGY.run_pair(config=CONFIG, pair=pair, bt_result=bt_result)\n", " pair_trades = FIT_METHOD.run_pair(config=CONFIG, pair=pair, bt_result=bt_result)\n",
"\n", "\n",
" if pair_trades is not None and len(pair_trades) > 0:\n", " if pair_trades is not None and len(pair_trades) > 0:\n",
" print(f\"\\nGenerated {len(pair_trades)} trading signals:\")\n", " print(f\"\\nGenerated {len(pair_trades)} trading signals:\")\n",
@ -674,7 +674,7 @@
"print(\"=\" * 60)\n", "print(\"=\" * 60)\n",
"\n", "\n",
"print(f\"\\nPair: {SYMBOL_A} & {SYMBOL_B}\")\n", "print(f\"\\nPair: {SYMBOL_A} & {SYMBOL_B}\")\n",
"print(f\"Strategy: {type(STRATEGY).__name__}\")\n", "print(f\"Strategy: {type(FIT_METHOD).__name__}\")\n",
"print(f\"Data file: {DATA_FILE}\")\n", "print(f\"Data file: {DATA_FILE}\")\n",
"print(f\"Training period: {training_minutes} minutes\")\n", "print(f\"Training period: {training_minutes} minutes\")\n",
"\n", "\n",

View File

@ -1,29 +1,22 @@
import argparse import argparse
import hjson
import importlib
import glob import glob
import importlib
import os import os
from datetime import datetime, date from datetime import date, datetime
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import pandas as pd import pandas as pd
from tools.config import expand_filename, load_config
from tools.data_loader import get_available_instruments_from_db, load_market_data from tools.data_loader import get_available_instruments_from_db, load_market_data
from trading.strategies import PairsTradingStrategy from pt_trading.results import (
from trading.trading_pair import TradingPair
from trading.results import (
BacktestResult, BacktestResult,
create_result_database, create_result_database,
store_results_in_database,
store_config_in_database, store_config_in_database,
store_results_in_database,
) )
from pt_trading.fit_methods import PairsTradingFitMethod
from pt_trading.trading_pair import TradingPair
def load_config(config_path: str) -> Dict:
with open(config_path, "r") as f:
config = hjson.load(f)
return dict(config)
def resolve_datafiles(config: Dict, cli_datafiles: Optional[str] = None) -> List[str]: def resolve_datafiles(config: Dict, cli_datafiles: Optional[str] = None) -> List[str]:
@ -69,7 +62,7 @@ def run_backtest(
config: Dict, config: Dict,
datafile: str, datafile: str,
price_column: str, price_column: str,
strategy: PairsTradingStrategy, fit_method: PairsTradingFitMethod,
instruments: List[str], instruments: List[str],
) -> BacktestResult: ) -> BacktestResult:
""" """
@ -101,7 +94,7 @@ def run_backtest(
pairs_trades = [] pairs_trades = []
for pair in _create_pairs(config, instruments): for pair in _create_pairs(config, instruments):
single_pair_trades = strategy.run_pair( single_pair_trades = fit_method.run_pair(
pair=pair, config=config, bt_result=bt_result pair=pair, config=config, bt_result=bt_result
) )
if single_pair_trades is not None and len(single_pair_trades) > 0: if single_pair_trades is not None and len(single_pair_trades) > 0:
@ -148,11 +141,12 @@ def main() -> None:
config: Dict = load_config(args.config) config: Dict = load_config(args.config)
# Dynamically instantiate strategy class # Dynamically instantiate fit method class
strategy_class_name = config.get("strategy_class", "strategies.StaticFitStrategy") fit_method_class_name = config.get("fit_method_class", None)
module_name, class_name = strategy_class_name.rsplit(".", 1) assert fit_method_class_name is not None
module_name, class_name = fit_method_class_name.rsplit(".", 1)
module = importlib.import_module(module_name) module = importlib.import_module(module_name)
strategy = getattr(module, class_name)() fit_method = getattr(module, class_name)()
# Resolve data files (CLI takes priority over config) # Resolve data files (CLI takes priority over config)
datafiles = resolve_datafiles(config, args.datafiles) datafiles = resolve_datafiles(config, args.datafiles)
@ -167,6 +161,7 @@ def main() -> None:
# Create result database if needed # Create result database if needed
if args.result_db.upper() != "NONE": if args.result_db.upper() != "NONE":
args.result_db = expand_filename(args.result_db)
create_result_database(args.result_db) create_result_database(args.result_db)
# Initialize a dictionary to store all trade results # Initialize a dictionary to store all trade results
@ -192,7 +187,7 @@ def main() -> None:
db_path=args.result_db, db_path=args.result_db,
config_file_path=args.config, config_file_path=args.config,
config=config, config=config,
strategy_class=strategy_class_name, fit_method_class=fit_method_class_name,
datafiles=datafiles, datafiles=datafiles,
instruments=unique_instruments, instruments=unique_instruments,
) )
@ -219,13 +214,13 @@ def main() -> None:
# Process data for this file # Process data for this file
try: try:
strategy.reset() fit_method.reset()
bt_results = run_backtest( bt_results = run_backtest(
config=config, config=config,
datafile=datafile, datafile=datafile,
price_column=price_column, price_column=price_column,
strategy=strategy, fit_method=fit_method,
instruments=instruments, instruments=instruments,
) )

220
strategy/pair_strategy.py Normal file
View File

@ -0,0 +1,220 @@
import argparse
import asyncio
import glob
import importlib
import os
from datetime import date, datetime
from typing import Any, Dict, List, Optional
import hjson
import pandas as pd
from tools.data_loader import get_available_instruments_from_db, load_market_data
from pt_trading.results import (
BacktestResult,
create_result_database,
store_config_in_database,
store_results_in_database,
)
from pt_trading.fit_methods import PairsTradingFitMethod
from pt_trading.trading_pair import TradingPair
def run_strategy(
config: Dict,
datafile: str,
price_column: str,
fit_method: PairsTradingFitMethod,
instruments: List[str],
) -> BacktestResult:
"""
Run backtest for all pairs using the specified instruments.
"""
bt_result: BacktestResult = BacktestResult(config=config)
def _create_pairs(config: Dict, instruments: List[str]) -> List[TradingPair]:
nonlocal datafile
all_indexes = range(len(instruments))
unique_index_pairs = [(i, j) for i in all_indexes for j in all_indexes if i < j]
pairs = []
# Update config to use the specified instruments
config_copy = config.copy()
config_copy["instruments"] = instruments
market_data_df = load_market_data(datafile, config=config_copy)
for a_index, b_index in unique_index_pairs:
pair = TradingPair(
market_data=market_data_df,
symbol_a=instruments[a_index],
symbol_b=instruments[b_index],
price_column=price_column,
)
pairs.append(pair)
return pairs
pairs_trades = []
for pair in _create_pairs(config, instruments):
single_pair_trades = fit_method.run_pair(
pair=pair, config=config, bt_result=bt_result
)
if single_pair_trades is not None and len(single_pair_trades) > 0:
pairs_trades.append(single_pair_trades)
# Check if result_list has any data before concatenating
if len(pairs_trades) == 0:
print("No trading signals found for any pairs")
return bt_result
result = pd.concat(pairs_trades, ignore_index=True)
result["time"] = pd.to_datetime(result["time"])
result = result.set_index("time").sort_index()
bt_result.collect_single_day_results(result)
return bt_result
def main() -> None:
parser = argparse.ArgumentParser(description="Run pairs trading backtest.")
parser.add_argument(
"--config", type=str, required=True, help="Path to the configuration file."
)
parser.add_argument(
"--datafiles",
type=str,
required=False,
help="Comma-separated list of data files (overrides config). No wildcards supported.",
)
parser.add_argument(
"--instruments",
type=str,
required=False,
help="Comma-separated list of instrument symbols (e.g., COIN,GBTC). If not provided, auto-detects from database.",
)
parser.add_argument(
"--result_db",
type=str,
required=True,
help="Path to SQLite database for storing results. Use 'NONE' to disable database output.",
)
args = parser.parse_args()
config: Dict = load_config(args.config)
# Dynamically instantiate fit method class
fit_method_class_name = config.get("fit_method_class", None)
assert fit_method_class_name is not None
module_name, class_name = fit_method_class_name.rsplit(".", 1)
module = importlib.import_module(module_name)
fit_method = getattr(module, class_name)()
# Resolve data files (CLI takes priority over config)
datafiles = resolve_datafiles(config, args.datafiles)
if not datafiles:
print("No data files found to process.")
return
print(f"Found {len(datafiles)} data files to process:")
for df in datafiles:
print(f" - {df}")
# Create result database if needed
if args.result_db.upper() != "NONE":
create_result_database(args.result_db)
# Initialize a dictionary to store all trade results
all_results: Dict[str, Dict[str, Any]] = {}
# Store configuration in database for reference
if args.result_db.upper() != "NONE":
# Get list of all instruments for storage
all_instruments = []
for datafile in datafiles:
if args.instruments:
file_instruments = [
inst.strip() for inst in args.instruments.split(",")
]
else:
file_instruments = get_available_instruments_from_db(datafile, config)
all_instruments.extend(file_instruments)
# Remove duplicates while preserving order
unique_instruments = list(dict.fromkeys(all_instruments))
store_config_in_database(
db_path=args.result_db,
config_file_path=args.config,
config=config,
fit_method_class=fit_method_class_name,
datafiles=datafiles,
instruments=unique_instruments,
)
# Process each data file
price_column = config["price_column"]
for datafile in datafiles:
print(f"\n====== Processing {os.path.basename(datafile)} ======")
# Determine instruments to use
if args.instruments:
# Use CLI-specified instruments
instruments = [inst.strip() for inst in args.instruments.split(",")]
print(f"Using CLI-specified instruments: {instruments}")
else:
# Auto-detect instruments from database
instruments = get_available_instruments_from_db(datafile, config)
print(f"Auto-detected instruments: {instruments}")
if not instruments:
print(f"No instruments found for {datafile}, skipping...")
continue
# Process data for this file
try:
fit_method.reset()
bt_results = run_strategy(
config=config,
datafile=datafile,
price_column=price_column,
fit_method=fit_method,
instruments=instruments,
)
# Store results with file name as key
filename = os.path.basename(datafile)
all_results[filename] = {"trades": bt_results.trades.copy()}
# Store results in database
if args.result_db.upper() != "NONE":
store_results_in_database(args.result_db, datafile, bt_results)
print(f"Successfully processed {filename}")
except Exception as err:
print(f"Error processing {datafile}: {str(err)}")
import traceback
traceback.print_exc()
# Calculate and print results using a new BacktestResult instance for aggregation
if all_results:
aggregate_bt_results = BacktestResult(config=config)
aggregate_bt_results.calculate_returns(all_results)
aggregate_bt_results.print_grand_totals()
aggregate_bt_results.print_outstanding_positions()
if args.result_db.upper() != "NONE":
print(f"\nResults stored in database: {args.result_db}")
else:
print("No results to display.")
if __name__ == "__main__":
asyncio.run(main())