This commit is contained in:
Oleg Sheynin 2025-07-10 18:14:37 +00:00
parent 46072e03a2
commit 85c9d2ab93
15 changed files with 578 additions and 227 deletions

View File

@ -29,5 +29,5 @@
"dis-equilibrium_close_trshld": 0.5,
"training_minutes": 120,
"funding_per_pair": 2000.0,
"strategy_class": "trading.strategies.StaticFitStrategy"
"fit_method_class": "pt_trading.fit_methods.StaticFit"
}

View File

@ -19,8 +19,7 @@
"dis-equilibrium_close_trshld": 1.0,
"training_minutes": 120,
"funding_per_pair": 2000.0,
# "strategy_class": "strategies.StaticFitStrategy"
"strategy_class": "trading.strategies.SlidingFitStrategy"
"fit_method_class": "pt_trading.fit_methods.SlidingFit",
"exclude_instruments": ["CAN"]
}

View File

@ -1,16 +1,15 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Dict, Optional, cast
import pandas as pd # type: ignore[import]
from trading.trading_pair import TradingPair
from trading.results import BacktestResult
from pt_trading.results import BacktestResult
from pt_trading.trading_pair import TradingPair
NanoPerMin = 1e9
class PairsTradingStrategy(ABC):
class PairsTradingFitMethod(ABC):
TRADES_COLUMNS = [
"time",
"action",
@ -28,7 +27,7 @@ class PairsTradingStrategy(ABC):
def reset(self):
...
class StaticFitStrategy(PairsTradingStrategy):
class StaticFit(PairsTradingFitMethod):
def run_pair(self, config: Dict, pair: TradingPair, bt_result: BacktestResult) -> Optional[pd.DataFrame]: # abstractmethod
pair.get_datasets(training_minutes=config["training_minutes"])
@ -203,7 +202,7 @@ class StaticFitStrategy(PairsTradingStrategy):
columns=self.TRADES_COLUMNS, # type: ignore
)
def reset(self):
def reset(self) -> None:
pass
class PairState(Enum):
@ -211,8 +210,8 @@ class PairState(Enum):
OPEN = 2
CLOSED = 3
class SlidingFitStrategy(PairsTradingStrategy):
def __init__(self):
class SlidingFit(PairsTradingFitMethod):
def __init__(self) -> None:
super().__init__()
self.curr_training_start_idx_ = 0
@ -235,7 +234,7 @@ class SlidingFitStrategy(PairsTradingStrategy):
testing_size=1
)
if len(pair.training_df_) < training_minutes: # type: ignore
if len(pair.training_df_) < training_minutes:
print(f"{pair}: {self.curr_training_start_idx_} Not enough training data. Completing the job.")
if pair.user_data_["state"] == PairState.OPEN:
print(f"{pair}: {self.curr_training_start_idx_} Position is not closed.")

View File

@ -11,18 +11,22 @@ def adapt_date_iso(val):
"""Adapt datetime.date to ISO 8601 date."""
return val.isoformat()
def adapt_datetime_iso(val):
"""Adapt datetime.datetime to timezone-naive ISO 8601 date."""
return val.isoformat()
def convert_date(val):
"""Convert ISO 8601 date to datetime.date object."""
return datetime.fromisoformat(val.decode()).date()
def convert_datetime(val):
"""Convert ISO 8601 datetime to datetime.datetime object."""
return datetime.fromisoformat(val.decode())
# Register the adapters and converters
sqlite3.register_adapter(date, adapt_date_iso)
sqlite3.register_adapter(datetime, adapt_datetime_iso)
@ -39,7 +43,8 @@ def create_result_database(db_path: str) -> None:
cursor = conn.cursor()
# Create the pt_bt_results table for completed trades
cursor.execute('''
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS pt_bt_results (
date DATE,
pair TEXT,
@ -57,11 +62,13 @@ def create_result_database(db_path: str) -> None:
symbol_return REAL,
pair_return REAL
)
''')
"""
)
cursor.execute("DELETE FROM pt_bt_results;")
# Create the outstanding_positions table for open positions
cursor.execute('''
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS outstanding_positions (
date DATE,
pair TEXT,
@ -72,21 +79,24 @@ def create_result_database(db_path: str) -> None:
open_price REAL,
open_side TEXT
)
''')
"""
)
cursor.execute("DELETE FROM outstanding_positions;")
# Create the config table for storing configuration JSON for reference
cursor.execute('''
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS config (
id INTEGER PRIMARY KEY AUTOINCREMENT,
run_timestamp DATETIME,
config_file_path TEXT,
config_json TEXT,
strategy_class TEXT,
fit_method_class TEXT,
datafiles TEXT,
instruments TEXT
)
''')
"""
)
cursor.execute("DELETE FROM config;")
conn.commit()
@ -97,7 +107,14 @@ def create_result_database(db_path: str) -> None:
raise
def store_config_in_database(db_path: str, config_file_path: str, config: Dict, strategy_class: str, datafiles: List[str], instruments: List[str]) -> None:
def store_config_in_database(
db_path: str,
config_file_path: str,
config: Dict,
fit_method_class: str,
datafiles: List[str],
instruments: List[str],
) -> None:
"""
Store configuration information in the database for reference.
"""
@ -114,22 +131,25 @@ def store_config_in_database(db_path: str, config_file_path: str, config: Dict,
config_json = json.dumps(config, indent=2, default=str)
# Convert lists to comma-separated strings for storage
datafiles_str = ', '.join(datafiles)
instruments_str = ', '.join(instruments)
datafiles_str = ", ".join(datafiles)
instruments_str = ", ".join(instruments)
# Insert configuration record
cursor.execute('''
cursor.execute(
"""
INSERT INTO config (
run_timestamp, config_file_path, config_json, strategy_class, datafiles, instruments
run_timestamp, config_file_path, config_json, fit_method_class, datafiles, instruments
) VALUES (?, ?, ?, ?, ?, ?)
''', (
datetime.now(),
config_file_path,
config_json,
strategy_class,
datafiles_str,
instruments_str
))
""",
(
datetime.now(),
config_file_path,
config_json,
fit_method_class,
datafiles_str,
instruments_str,
),
)
conn.commit()
conn.close()
@ -139,10 +159,13 @@ def store_config_in_database(db_path: str, config_file_path: str, config: Dict,
except Exception as e:
print(f"Error storing configuration in database: {str(e)}")
import traceback
traceback.print_exc()
def store_results_in_database(db_path: str, datafile: str, bt_result: 'BacktestResult') -> None:
def store_results_in_database(
db_path: str, datafile: str, bt_result: "BacktestResult"
) -> None:
"""
Store backtest results in the SQLite database.
"""
@ -153,18 +176,18 @@ def store_results_in_database(db_path: str, datafile: str, bt_result: 'BacktestR
"""Convert pandas Timestamp to Python datetime object for SQLite compatibility."""
if timestamp is None:
return None
if hasattr(timestamp, 'to_pydatetime'):
if hasattr(timestamp, "to_pydatetime"):
return timestamp.to_pydatetime()
return timestamp
try:
# Extract date from datafile name (assuming format like 20250528.mktdata.ohlcv.db)
filename = os.path.basename(datafile)
date_str = filename.split('.')[0] # Extract date part
date_str = filename.split(".")[0] # Extract date part
# Convert to proper date format
try:
date_obj = datetime.strptime(date_str, '%Y%m%d').date()
date_obj = datetime.strptime(date_str, "%Y%m%d").date()
except ValueError:
# If date parsing fails, use current date
date_obj = datetime.now().date()
@ -183,7 +206,9 @@ def store_results_in_database(db_path: str, datafile: str, bt_result: 'BacktestR
# First pass: collect all trades and calculate returns
for symbol, symbol_trades in symbols.items():
if len(symbol_trades) == 0: # No trades for this symbol
print(f"Warning: No trades found for symbol {symbol} in pair {pair_name}")
print(
f"Warning: No trades found for symbol {symbol} in pair {pair_name}"
)
continue
elif len(symbol_trades) >= 2: # Completed trades (entry + exit)
@ -198,14 +223,42 @@ def store_results_in_database(db_path: str, datafile: str, bt_result: 'BacktestR
open_time = datetime.now()
close_time = datetime.now()
else: # New format: (action, price, disequilibrium, scaled_disequilibrium, timestamp)
entry_action, entry_price, open_disequilibrium, open_scaled_disequilibrium, open_time = symbol_trades[0]
exit_action, exit_price, close_disequilibrium, close_scaled_disequilibrium, close_time = symbol_trades[1]
(
entry_action,
entry_price,
open_disequilibrium,
open_scaled_disequilibrium,
open_time,
) = symbol_trades[0]
(
exit_action,
exit_price,
close_disequilibrium,
close_scaled_disequilibrium,
close_time,
) = symbol_trades[1]
# Handle None values
open_disequilibrium = open_disequilibrium if open_disequilibrium is not None else 0.0
open_scaled_disequilibrium = open_scaled_disequilibrium if open_scaled_disequilibrium is not None else 0.0
close_disequilibrium = close_disequilibrium if close_disequilibrium is not None else 0.0
close_scaled_disequilibrium = close_scaled_disequilibrium if close_scaled_disequilibrium is not None else 0.0
open_disequilibrium = (
open_disequilibrium
if open_disequilibrium is not None
else 0.0
)
open_scaled_disequilibrium = (
open_scaled_disequilibrium
if open_scaled_disequilibrium is not None
else 0.0
)
close_disequilibrium = (
close_disequilibrium
if close_disequilibrium is not None
else 0.0
)
close_scaled_disequilibrium = (
close_scaled_disequilibrium
if close_scaled_disequilibrium is not None
else 0.0
)
# Convert pandas Timestamps to Python datetime objects
open_time = convert_timestamp(open_time) or datetime.now()
@ -225,104 +278,127 @@ def store_results_in_database(db_path: str, datafile: str, bt_result: 'BacktestR
pair_return += symbol_return
pair_trades.append({
'symbol': symbol,
'entry_action': entry_action,
'entry_price': entry_price,
'exit_action': exit_action,
'exit_price': exit_price,
'symbol_return': symbol_return,
'open_disequilibrium': open_disequilibrium,
'open_scaled_disequilibrium': open_scaled_disequilibrium,
'close_disequilibrium': close_disequilibrium,
'close_scaled_disequilibrium': close_scaled_disequilibrium,
'open_time': open_time,
'close_time': close_time,
'shares': shares,
'is_completed': True
})
pair_trades.append(
{
"symbol": symbol,
"entry_action": entry_action,
"entry_price": entry_price,
"exit_action": exit_action,
"exit_price": exit_price,
"symbol_return": symbol_return,
"open_disequilibrium": open_disequilibrium,
"open_scaled_disequilibrium": open_scaled_disequilibrium,
"close_disequilibrium": close_disequilibrium,
"close_scaled_disequilibrium": close_scaled_disequilibrium,
"open_time": open_time,
"close_time": close_time,
"shares": shares,
"is_completed": True,
}
)
# Skip one-sided trades - they will be handled by outstanding_positions table
elif len(symbol_trades) == 1:
print(f"Skipping one-sided trade for {symbol} in pair {pair_name} - will be stored in outstanding_positions table")
print(
f"Skipping one-sided trade for {symbol} in pair {pair_name} - will be stored in outstanding_positions table"
)
continue
else:
# This should not happen, but handle unexpected cases
print(f"Warning: Unexpected number of trades ({len(symbol_trades)}) for symbol {symbol} in pair {pair_name}")
print(
f"Warning: Unexpected number of trades ({len(symbol_trades)}) for symbol {symbol} in pair {pair_name}"
)
continue
# Second pass: insert completed trade records into database
for trade in pair_trades:
# Only store completed trades in pt_bt_results table
cursor.execute('''
cursor.execute(
"""
INSERT INTO pt_bt_results (
date, pair, symbol, open_time, open_side, open_price,
open_quantity, open_disequilibrium, close_time, close_side,
close_price, close_quantity, close_disequilibrium,
symbol_return, pair_return
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
''', (
date_obj,
pair_name,
trade['symbol'],
trade['open_time'],
trade['entry_action'],
trade['entry_price'],
trade['shares'],
trade['open_scaled_disequilibrium'],
trade['close_time'],
trade['exit_action'],
trade['exit_price'],
trade['shares'],
trade['close_scaled_disequilibrium'],
trade['symbol_return'],
pair_return
))
""",
(
date_obj,
pair_name,
trade["symbol"],
trade["open_time"],
trade["entry_action"],
trade["entry_price"],
trade["shares"],
trade["open_scaled_disequilibrium"],
trade["close_time"],
trade["exit_action"],
trade["exit_price"],
trade["shares"],
trade["close_scaled_disequilibrium"],
trade["symbol_return"],
pair_return,
),
)
# Store outstanding positions in separate table
outstanding_positions = bt_result.get_outstanding_positions()
for pos in outstanding_positions:
# Calculate position quantity (negative for SELL positions)
position_qty_a = pos['shares_a'] if pos['side_a'] == 'BUY' else -pos['shares_a']
position_qty_b = pos['shares_b'] if pos['side_b'] == 'BUY' else -pos['shares_b']
position_qty_a = (
pos["shares_a"] if pos["side_a"] == "BUY" else -pos["shares_a"]
)
position_qty_b = (
pos["shares_b"] if pos["side_b"] == "BUY" else -pos["shares_b"]
)
# Calculate unrealized returns
# For symbol A: (current_price - open_price) / open_price * 100 * position_direction
unrealized_return_a = ((pos['current_px_a'] - pos['open_px_a']) / pos['open_px_a'] * 100) * (1 if pos['side_a'] == 'BUY' else -1)
unrealized_return_b = ((pos['current_px_b'] - pos['open_px_b']) / pos['open_px_b'] * 100) * (1 if pos['side_b'] == 'BUY' else -1)
unrealized_return_a = (
(pos["current_px_a"] - pos["open_px_a"]) / pos["open_px_a"] * 100
) * (1 if pos["side_a"] == "BUY" else -1)
unrealized_return_b = (
(pos["current_px_b"] - pos["open_px_b"]) / pos["open_px_b"] * 100
) * (1 if pos["side_b"] == "BUY" else -1)
# Store outstanding position for symbol A
cursor.execute('''
cursor.execute(
"""
INSERT INTO outstanding_positions (
date, pair, symbol, position_quantity, last_price, unrealized_return, open_price, open_side
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
''', (
date_obj,
pos['pair'],
pos['symbol_a'],
position_qty_a,
pos['current_px_a'],
unrealized_return_a,
pos['open_px_a'],
pos['side_a']
))
""",
(
date_obj,
pos["pair"],
pos["symbol_a"],
position_qty_a,
pos["current_px_a"],
unrealized_return_a,
pos["open_px_a"],
pos["side_a"],
),
)
# Store outstanding position for symbol B
cursor.execute('''
cursor.execute(
"""
INSERT INTO outstanding_positions (
date, pair, symbol, position_quantity, last_price, unrealized_return, open_price, open_side
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
''', (
date_obj,
pos['pair'],
pos['symbol_b'],
position_qty_b,
pos['current_px_b'],
unrealized_return_b,
pos['open_px_b'],
pos['side_b']
))
""",
(
date_obj,
pos["pair"],
pos["symbol_b"],
position_qty_b,
pos["current_px_b"],
unrealized_return_b,
pos["open_px_b"],
pos["side_b"],
),
)
conn.commit()
conn.close()
@ -330,6 +406,7 @@ def store_results_in_database(db_path: str, datafile: str, bt_result: 'BacktestR
except Exception as e:
print(f"Error storing results in database: {str(e)}")
import traceback
traceback.print_exc()
@ -344,7 +421,16 @@ class BacktestResult:
self.total_realized_pnl = 0.0
self.outstanding_positions: List[Dict[str, Any]] = []
def add_trade(self, pair_nm, symbol, action, price, disequilibrium=None, scaled_disequilibrium=None, timestamp=None):
def add_trade(
self,
pair_nm,
symbol,
action,
price,
disequilibrium=None,
scaled_disequilibrium=None,
timestamp=None,
):
"""Add a trade to the results tracking."""
pair_nm = str(pair_nm)
@ -352,7 +438,9 @@ class BacktestResult:
self.trades[pair_nm] = {symbol: []}
if symbol not in self.trades[pair_nm]:
self.trades[pair_nm][symbol] = []
self.trades[pair_nm][symbol].append((action, price, disequilibrium, scaled_disequilibrium, timestamp))
self.trades[pair_nm][symbol].append(
(action, price, disequilibrium, scaled_disequilibrium, timestamp)
)
def add_outstanding_position(self, position: Dict[str, Any]):
"""Add an outstanding position to tracking."""
@ -390,13 +478,17 @@ class BacktestResult:
action = row.action
symbol = row.symbol
price = row.price
disequilibrium = getattr(row, 'disequilibrium', None)
scaled_disequilibrium = getattr(row, 'scaled_disequilibrium', None)
timestamp = getattr(row, 'time', None)
disequilibrium = getattr(row, "disequilibrium", None)
scaled_disequilibrium = getattr(row, "scaled_disequilibrium", None)
timestamp = getattr(row, "time", None)
self.add_trade(
pair_nm=row.pair, action=action, symbol=symbol, price=price,
disequilibrium=disequilibrium, scaled_disequilibrium=scaled_disequilibrium,
timestamp=timestamp
pair_nm=row.pair,
action=action,
symbol=symbol,
price=price,
disequilibrium=disequilibrium,
scaled_disequilibrium=scaled_disequilibrium,
timestamp=timestamp,
)
def print_single_day_results(self):
@ -447,19 +539,31 @@ class BacktestResult:
else: # New format: (action, price, disequilibrium, scaled_disequilibrium, timestamp)
entry_action, entry_price = trades[0][:2]
exit_action, exit_price = trades[1][:2]
open_disequilibrium = trades[0][2] if len(trades[0]) > 2 else None
open_scaled_disequilibrium = trades[0][3] if len(trades[0]) > 3 else None
close_disequilibrium = trades[1][2] if len(trades[1]) > 2 else None
close_scaled_disequilibrium = trades[1][3] if len(trades[1]) > 3 else None
open_disequilibrium = (
trades[0][2] if len(trades[0]) > 2 else None
)
open_scaled_disequilibrium = (
trades[0][3] if len(trades[0]) > 3 else None
)
close_disequilibrium = (
trades[1][2] if len(trades[1]) > 2 else None
)
close_scaled_disequilibrium = (
trades[1][3] if len(trades[1]) > 3 else None
)
# Calculate return based on action
symbol_return = 0
if entry_action == "BUY" and exit_action == "SELL":
# Long position
symbol_return = (exit_price - entry_price) / entry_price * 100
symbol_return = (
(exit_price - entry_price) / entry_price * 100
)
elif entry_action == "SELL" and exit_action == "BUY":
# Short position
symbol_return = (entry_price - exit_price) / entry_price * 100
symbol_return = (
(entry_price - exit_price) / entry_price * 100
)
pair_trades.append(
(
@ -489,7 +593,10 @@ class BacktestResult:
close_scaled_disequilibrium,
) in pair_trades:
disequil_info = ""
if open_scaled_disequilibrium is not None and close_scaled_disequilibrium is not None:
if (
open_scaled_disequilibrium is not None
and close_scaled_disequilibrium is not None
):
disequil_info = f" | Open Dis-eq: {open_scaled_disequilibrium:.2f}, Close Dis-eq: {close_scaled_disequilibrium:.2f}"
print(
@ -582,9 +689,17 @@ class BacktestResult:
print(f"\n====== GRAND TOTALS ACROSS ALL PAIRS ======")
print(f"Total Realized PnL: {self.get_total_realized_pnl():.2f}%")
def handle_outstanding_position(self, pair, pair_result_df, last_row_index,
open_side_a, open_side_b, open_px_a, open_px_b,
open_tstamp):
def handle_outstanding_position(
self,
pair,
pair_result_df,
last_row_index,
open_side_a,
open_side_b,
open_px_a,
open_px_b,
open_tstamp,
):
"""
Handle calculation and tracking of outstanding positions when no close signal is found.
@ -648,9 +763,15 @@ class BacktestResult:
# Print position details
print(f"{pair}: NO CLOSE SIGNAL FOUND - Position held until end of session")
print(f" Open: {open_tstamp} | Last: {last_tstamp}")
print(f" {pair.symbol_a_}: {open_side_a} {shares_a:.2f} shares @ ${open_px_a:.2f} -> ${last_px_a:.2f} | Value: ${current_value_a:.2f}")
print(f" {pair.symbol_b_}: {open_side_b} {shares_b:.2f} shares @ ${open_px_b:.2f} -> ${last_px_b:.2f} | Value: ${current_value_b:.2f}")
print(
f" {pair.symbol_a_}: {open_side_a} {shares_a:.2f} shares @ ${open_px_a:.2f} -> ${last_px_a:.2f} | Value: ${current_value_a:.2f}"
)
print(
f" {pair.symbol_b_}: {open_side_b} {shares_b:.2f} shares @ ${open_px_b:.2f} -> ${last_px_b:.2f} | Value: ${current_value_b:.2f}"
)
print(f" Total Value: ${total_current_value:.2f}")
print(f" Disequilibrium: {current_disequilibrium:.4f} | Scaled: {current_scaled_disequilibrium:.4f}")
print(
f" Disequilibrium: {current_disequilibrium:.4f} | Scaled: {current_scaled_disequilibrium:.4f}"
)
return current_value_a, current_value_b, total_current_value

17
lib/tools/config.py Normal file
View File

@ -0,0 +1,17 @@
import hjson
from typing import Dict
from datetime import datetime
def load_config(config_path: str) -> Dict:
with open(config_path, "r") as f:
config = hjson.load(f)
return dict(config)
def expand_filename(filename: str) -> str:
# expand %T
res = filename.replace("%T", datetime.now().strftime("%Y%m%d_%H%M%S"))
# expand %D
return res.replace("%D", datetime.now().strftime("%Y%m%d"))

View File

@ -25,7 +25,7 @@ def list_tables(db_path: str) -> List[str]:
conn.close()
return tables
def view_table_schema(db_path: str, table_name: str):
def view_table_schema(db_path: str, table_name: str) -> None:
"""View the schema of a specific table."""
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
@ -44,13 +44,13 @@ def view_table_schema(db_path: str, table_name: str):
conn.close()
def view_config_table(db_path: str, limit: int = 10):
def view_config_table(db_path: str, limit: int = 10) -> None:
"""View entries from the config table."""
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
cursor.execute(f"""
SELECT id, run_timestamp, config_file_path, strategy_class,
SELECT id, run_timestamp, config_file_path, fit_method_class,
datafiles, instruments, config_json
FROM config
ORDER BY run_timestamp DESC
@ -67,17 +67,17 @@ def view_config_table(db_path: str, limit: int = 10):
print("=" * 80)
for row in rows:
id, run_timestamp, config_file_path, strategy_class, datafiles, instruments, config_json = row
id, run_timestamp, config_file_path, fit_method_class, datafiles, instruments, config_json = row
print(f"ID: {id} | {run_timestamp}")
print(f"Config: {config_file_path} | Strategy: {strategy_class}")
print(f"Config: {config_file_path} | Strategy: {fit_method_class}")
print(f"Files: {datafiles}")
print(f"Instruments: {instruments}")
print("-" * 80)
conn.close()
def view_results_summary(db_path: str):
def view_results_summary(db_path: str) -> None:
"""View summary of trading results."""
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
@ -119,7 +119,7 @@ def view_results_summary(db_path: str):
conn.close()
def main():
def main() -> None:
if len(sys.argv) < 2:
print("Usage: python db_inspector.py <database_path> [command]")
print("Commands:")

View File

@ -1,6 +1,6 @@
{
"include": [
"src"
"lib"
],
"exclude": [
"**/node_modules",
@ -16,7 +16,7 @@
"autoImportCompletions": true,
"autoSearchPaths": true,
"extraPaths": [
"src"
"lib"
],
"stubPath": "./typings",
"venvPath": ".",

View File

@ -62,7 +62,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [
{
@ -87,10 +87,10 @@
"from IPython.display import clear_output\n",
"\n",
"# Import our modules\n",
"from strategies import StaticFitStrategy, SlidingFitStrategy, PairState\n",
"from pt_trading.fit_methods import StaticFit, SlidingFit, PairState\n",
"from tools.data_loader import load_market_data\n",
"from trading.trading_pair import TradingPair\n",
"from trading.results import BacktestResult\n",
"from pt_trading.trading_pair import TradingPair\n",
"from pt_trading.results import BacktestResult\n",
"\n",
"# Set plotting style\n",
"plt.style.use('seaborn-v0_8')\n",
@ -113,7 +113,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@ -149,34 +149,34 @@
" print(f\"Unexpected error loading config from {config_file}: {e}\")\n",
" return None\n",
"\n",
"def instantiate_strategy_from_config(config: Dict):\n",
"def instantiate_fit_method_from_config(config: Dict):\n",
" \"\"\"Dynamically instantiate strategy from config\"\"\"\n",
" strategy_class_name = config.get(\"strategy_class\", \"strategies.StaticFitStrategy\")\n",
" \n",
" fit_method_class_name = config.get(\"fit_method_class\", None)\n",
" assert fit_method_class_name is not None\n",
" try:\n",
" # Split module and class name\n",
" if '.' in strategy_class_name:\n",
" module_name, class_name = strategy_class_name.rsplit('.', 1)\n",
" if '.' in fit_method_class_name:\n",
" module_name, class_name = fit_method_class_name.rsplit('.', 1)\n",
" else:\n",
" module_name = \"strategies\"\n",
" class_name = strategy_class_name\n",
" module_name = \"fit_methods\"\n",
" class_name = fit_method_class_name\n",
" \n",
" # Import module and get class\n",
" module = importlib.import_module(module_name)\n",
" strategy_class = getattr(module, class_name)\n",
" fit_method_class = getattr(module, class_name)\n",
" \n",
" # Instantiate strategy\n",
" return strategy_class()\n",
" return fit_method_class()\n",
" \n",
" except Exception as e:\n",
" print(f\"Error instantiating strategy {strategy_class_name}: {e}\")\n",
" print(f\"Error instantiating strategy {fit_method_class_name}: {e}\")\n",
" print(\"Falling back to StaticFitStrategy\")\n",
" return StaticFitStrategy()\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [
{
@ -230,7 +230,7 @@
" print(f\" Close threshold: {pt_bt_config['dis-equilibrium_close_trshld']}\")\n",
" \n",
" # Instantiate strategy from config\n",
" STRATEGY = instantiate_strategy_from_config(pt_bt_config)\n",
" STRATEGY = instantiate_fit_method_from_config(pt_bt_config)\n",
" print(f\" Strategy: {type(STRATEGY).__name__}\")\n",
" \n",
" # Automatically construct data file name based on date and config type\n",
@ -576,7 +576,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"metadata": {},
"outputs": [
{
@ -831,12 +831,12 @@
" max_demo_iterations = min(200, max_iterations)\n",
" print(f\"Processing first {max_demo_iterations} iterations for demonstration...\")\n",
" \n",
" # Initialize pair state for sliding strategy\n",
" # Initialize pair state for sliding fit method\n",
" pair.user_data_['state'] = PairState.INITIAL\n",
" pair.user_data_[\"trades\"] = pd.DataFrame(columns=pd.Index(STRATEGY.TRADES_COLUMNS, dtype=str))\n",
" pair.user_data_[\"is_cointegrated\"] = False\n",
" \n",
" # Run the sliding strategy\n",
" # Run the sliding fit method\n",
" pair_trades = STRATEGY.run_pair(config=pt_bt_config, pair=pair, bt_result=bt_result)\n",
" \n",
" if pair_trades is not None and len(pair_trades) > 0:\n",

View File

@ -111,10 +111,10 @@
"from IPython.display import clear_output\n",
"\n",
"# Import our modules\n",
"from strategies import SlidingFitStrategy, PairState\n",
"from pt_trading.fit_methods import SlidingFit, PairState\n",
"from tools.data_loader import load_market_data\n",
"from trading.trading_pair import TradingPair\n",
"from trading.results import BacktestResult\n",
"from pt_trading.trading_pair import TradingPair\n",
"from pt_trading.results import BacktestResult\n",
"\n",
"# Set plotting style\n",
"plt.style.use('seaborn-v0_8')\n",

View File

@ -73,7 +73,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [
{
@ -96,10 +96,10 @@
"from typing import Dict, List, Optional\n",
"\n",
"# Import our modules\n",
"from strategies import StaticFitStrategy, SlidingFitStrategy\n",
"from pt_trading.fit_methods import StaticFit, SlidingFit\n",
"from tools.data_loader import load_market_data\n",
"from trading.trading_pair import TradingPair\n",
"from trading.results import BacktestResult\n",
"from pt_trading.trading_pair import TradingPair\n",
"from pt_trading.results import BacktestResult\n",
"\n",
"# Set plotting style\n",
"plt.style.use('seaborn-v0_8')\n",
@ -226,7 +226,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [
{
@ -246,11 +246,11 @@
"DATA_FILE = CONFIG[\"datafiles\"][\"0509\"]\n",
"\n",
"# Choose strategy\n",
"STRATEGY = StaticFitStrategy()\n",
"FIT_METHOD = StaticFit()\n",
"\n",
"print(f\"Selected pair: {SYMBOL_A} & {SYMBOL_B}\")\n",
"print(f\"Data file: {DATA_FILE}\")\n",
"print(f\"Strategy: {type(STRATEGY).__name__}\")"
"print(f\"Strategy: {type(FIT_METHOD).__name__}\")"
]
},
{
@ -548,7 +548,7 @@
"\n",
" # Run strategy\n",
" bt_result = BacktestResult(config=CONFIG)\n",
" pair_trades = STRATEGY.run_pair(config=CONFIG, pair=pair, bt_result=bt_result)\n",
" pair_trades = FIT_METHOD.run_pair(config=CONFIG, pair=pair, bt_result=bt_result)\n",
"\n",
" if pair_trades is not None and len(pair_trades) > 0:\n",
" print(f\"\\nGenerated {len(pair_trades)} trading signals:\")\n",
@ -674,7 +674,7 @@
"print(\"=\" * 60)\n",
"\n",
"print(f\"\\nPair: {SYMBOL_A} & {SYMBOL_B}\")\n",
"print(f\"Strategy: {type(STRATEGY).__name__}\")\n",
"print(f\"Strategy: {type(FIT_METHOD).__name__}\")\n",
"print(f\"Data file: {DATA_FILE}\")\n",
"print(f\"Training period: {training_minutes} minutes\")\n",
"\n",

View File

@ -1,29 +1,22 @@
import argparse
import hjson
import importlib
import glob
import importlib
import os
from datetime import datetime, date
from datetime import date, datetime
from typing import Any, Dict, List, Optional
import pandas as pd
from tools.config import expand_filename, load_config
from tools.data_loader import get_available_instruments_from_db, load_market_data
from trading.strategies import PairsTradingStrategy
from trading.trading_pair import TradingPair
from trading.results import (
from pt_trading.results import (
BacktestResult,
create_result_database,
store_results_in_database,
store_config_in_database,
store_results_in_database,
)
def load_config(config_path: str) -> Dict:
with open(config_path, "r") as f:
config = hjson.load(f)
return dict(config)
from pt_trading.fit_methods import PairsTradingFitMethod
from pt_trading.trading_pair import TradingPair
def resolve_datafiles(config: Dict, cli_datafiles: Optional[str] = None) -> List[str]:
@ -69,7 +62,7 @@ def run_backtest(
config: Dict,
datafile: str,
price_column: str,
strategy: PairsTradingStrategy,
fit_method: PairsTradingFitMethod,
instruments: List[str],
) -> BacktestResult:
"""
@ -101,7 +94,7 @@ def run_backtest(
pairs_trades = []
for pair in _create_pairs(config, instruments):
single_pair_trades = strategy.run_pair(
single_pair_trades = fit_method.run_pair(
pair=pair, config=config, bt_result=bt_result
)
if single_pair_trades is not None and len(single_pair_trades) > 0:
@ -148,11 +141,12 @@ def main() -> None:
config: Dict = load_config(args.config)
# Dynamically instantiate strategy class
strategy_class_name = config.get("strategy_class", "strategies.StaticFitStrategy")
module_name, class_name = strategy_class_name.rsplit(".", 1)
# Dynamically instantiate fit method class
fit_method_class_name = config.get("fit_method_class", None)
assert fit_method_class_name is not None
module_name, class_name = fit_method_class_name.rsplit(".", 1)
module = importlib.import_module(module_name)
strategy = getattr(module, class_name)()
fit_method = getattr(module, class_name)()
# Resolve data files (CLI takes priority over config)
datafiles = resolve_datafiles(config, args.datafiles)
@ -167,6 +161,7 @@ def main() -> None:
# Create result database if needed
if args.result_db.upper() != "NONE":
args.result_db = expand_filename(args.result_db)
create_result_database(args.result_db)
# Initialize a dictionary to store all trade results
@ -192,7 +187,7 @@ def main() -> None:
db_path=args.result_db,
config_file_path=args.config,
config=config,
strategy_class=strategy_class_name,
fit_method_class=fit_method_class_name,
datafiles=datafiles,
instruments=unique_instruments,
)
@ -219,13 +214,13 @@ def main() -> None:
# Process data for this file
try:
strategy.reset()
fit_method.reset()
bt_results = run_backtest(
config=config,
datafile=datafile,
price_column=price_column,
strategy=strategy,
fit_method=fit_method,
instruments=instruments,
)

220
strategy/pair_strategy.py Normal file
View File

@ -0,0 +1,220 @@
import argparse
import asyncio
import glob
import importlib
import os
from datetime import date, datetime
from typing import Any, Dict, List, Optional
import hjson
import pandas as pd
from tools.data_loader import get_available_instruments_from_db, load_market_data
from pt_trading.results import (
BacktestResult,
create_result_database,
store_config_in_database,
store_results_in_database,
)
from pt_trading.fit_methods import PairsTradingFitMethod
from pt_trading.trading_pair import TradingPair
def run_strategy(
config: Dict,
datafile: str,
price_column: str,
fit_method: PairsTradingFitMethod,
instruments: List[str],
) -> BacktestResult:
"""
Run backtest for all pairs using the specified instruments.
"""
bt_result: BacktestResult = BacktestResult(config=config)
def _create_pairs(config: Dict, instruments: List[str]) -> List[TradingPair]:
nonlocal datafile
all_indexes = range(len(instruments))
unique_index_pairs = [(i, j) for i in all_indexes for j in all_indexes if i < j]
pairs = []
# Update config to use the specified instruments
config_copy = config.copy()
config_copy["instruments"] = instruments
market_data_df = load_market_data(datafile, config=config_copy)
for a_index, b_index in unique_index_pairs:
pair = TradingPair(
market_data=market_data_df,
symbol_a=instruments[a_index],
symbol_b=instruments[b_index],
price_column=price_column,
)
pairs.append(pair)
return pairs
pairs_trades = []
for pair in _create_pairs(config, instruments):
single_pair_trades = fit_method.run_pair(
pair=pair, config=config, bt_result=bt_result
)
if single_pair_trades is not None and len(single_pair_trades) > 0:
pairs_trades.append(single_pair_trades)
# Check if result_list has any data before concatenating
if len(pairs_trades) == 0:
print("No trading signals found for any pairs")
return bt_result
result = pd.concat(pairs_trades, ignore_index=True)
result["time"] = pd.to_datetime(result["time"])
result = result.set_index("time").sort_index()
bt_result.collect_single_day_results(result)
return bt_result
def main() -> None:
parser = argparse.ArgumentParser(description="Run pairs trading backtest.")
parser.add_argument(
"--config", type=str, required=True, help="Path to the configuration file."
)
parser.add_argument(
"--datafiles",
type=str,
required=False,
help="Comma-separated list of data files (overrides config). No wildcards supported.",
)
parser.add_argument(
"--instruments",
type=str,
required=False,
help="Comma-separated list of instrument symbols (e.g., COIN,GBTC). If not provided, auto-detects from database.",
)
parser.add_argument(
"--result_db",
type=str,
required=True,
help="Path to SQLite database for storing results. Use 'NONE' to disable database output.",
)
args = parser.parse_args()
config: Dict = load_config(args.config)
# Dynamically instantiate fit method class
fit_method_class_name = config.get("fit_method_class", None)
assert fit_method_class_name is not None
module_name, class_name = fit_method_class_name.rsplit(".", 1)
module = importlib.import_module(module_name)
fit_method = getattr(module, class_name)()
# Resolve data files (CLI takes priority over config)
datafiles = resolve_datafiles(config, args.datafiles)
if not datafiles:
print("No data files found to process.")
return
print(f"Found {len(datafiles)} data files to process:")
for df in datafiles:
print(f" - {df}")
# Create result database if needed
if args.result_db.upper() != "NONE":
create_result_database(args.result_db)
# Initialize a dictionary to store all trade results
all_results: Dict[str, Dict[str, Any]] = {}
# Store configuration in database for reference
if args.result_db.upper() != "NONE":
# Get list of all instruments for storage
all_instruments = []
for datafile in datafiles:
if args.instruments:
file_instruments = [
inst.strip() for inst in args.instruments.split(",")
]
else:
file_instruments = get_available_instruments_from_db(datafile, config)
all_instruments.extend(file_instruments)
# Remove duplicates while preserving order
unique_instruments = list(dict.fromkeys(all_instruments))
store_config_in_database(
db_path=args.result_db,
config_file_path=args.config,
config=config,
fit_method_class=fit_method_class_name,
datafiles=datafiles,
instruments=unique_instruments,
)
# Process each data file
price_column = config["price_column"]
for datafile in datafiles:
print(f"\n====== Processing {os.path.basename(datafile)} ======")
# Determine instruments to use
if args.instruments:
# Use CLI-specified instruments
instruments = [inst.strip() for inst in args.instruments.split(",")]
print(f"Using CLI-specified instruments: {instruments}")
else:
# Auto-detect instruments from database
instruments = get_available_instruments_from_db(datafile, config)
print(f"Auto-detected instruments: {instruments}")
if not instruments:
print(f"No instruments found for {datafile}, skipping...")
continue
# Process data for this file
try:
fit_method.reset()
bt_results = run_strategy(
config=config,
datafile=datafile,
price_column=price_column,
fit_method=fit_method,
instruments=instruments,
)
# Store results with file name as key
filename = os.path.basename(datafile)
all_results[filename] = {"trades": bt_results.trades.copy()}
# Store results in database
if args.result_db.upper() != "NONE":
store_results_in_database(args.result_db, datafile, bt_results)
print(f"Successfully processed {filename}")
except Exception as err:
print(f"Error processing {datafile}: {str(err)}")
import traceback
traceback.print_exc()
# Calculate and print results using a new BacktestResult instance for aggregation
if all_results:
aggregate_bt_results = BacktestResult(config=config)
aggregate_bt_results.calculate_returns(all_results)
aggregate_bt_results.print_grand_totals()
aggregate_bt_results.print_outstanding_positions()
if args.result_db.upper() != "NONE":
print(f"\nResults stored in database: {args.result_db}")
else:
print("No results to display.")
if __name__ == "__main__":
asyncio.run(main())