Compare commits


No commits in common. "6f845d32c62d7f2480094fcfb9bf3bbfa5a6012f" and "a04e8878fb592a439649ec9082cbad254a92d83e" have entirely different histories.

26 changed files with 6550 additions and 8420 deletions


@@ -11,7 +11,6 @@ The enhanced `pt_backtest.py` script now supports multi-day and multi-instrument
- Support for wildcard patterns in configuration files (see the expansion sketch after this list)
- CLI override for data file specification
### 2. Dynamic Instrument Selection
- Auto-detection of instruments from database
- CLI override for instrument specification
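A minimal sketch of how a wildcard pattern such as `2025*.mktdata.ohlcv.db` might be expanded against the configured `data_directory` (assuming Python's standard `glob`; the script's actual expansion code is not shown in this diff):

```python
import glob
import os

# Hypothetical values mirroring configuration/crypto.cfg
data_directory = "./data/crypto"
pattern = "2025*.mktdata.ohlcv.db"

# Expand the wildcard into concrete database files; lexical sort equals
# date order for YYYYMMDD-prefixed filenames.
datafiles = sorted(glob.glob(os.path.join(data_directory, pattern)))
print(datafiles)  # e.g. ['./data/crypto/20250528.mktdata.ohlcv.db', ...]
```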


@@ -38,12 +38,15 @@ CONFIG = EQT_CONFIG # For equity data
```
Each configuration dictionary specifies the following keys (a minimal loader sketch follows the list):
- `security_type`: "CRYPTO" or "EQUITY".
- `data_directory`: Path to the data files.
- `datafiles`: A list of database files to process. You can comment/uncomment specific files to include/exclude them from the backtest.
- `db_table_name`: The name of the table within the SQLite database.
- `instruments`: A list of symbols to consider for forming trading pairs.
- `trading_hours`: Defines the session start and end times, crucial for equity markets.
- `stat_model_price`: The column in the data to be used as the price (e.g., "close").
- `price_column`: The column in the data to be used as the price (e.g., "close").
- `min_required_points`: Minimum data points needed for statistical calculations.
- `zero_threshold`: A small value to handle potential division by zero.
- `dis-equilibrium_open_trshld`: The dis-equilibrium threshold (in standard deviations) for opening a trade.
- `dis-equilibrium_close_trshld`: The dis-equilibrium threshold (in standard deviations) for closing an open trade.
- `training_minutes`: The length of the rolling window (in minutes) used to train the model (e.g., calculate cointegration, mean, and standard deviation of the dis-equilibrium).
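The `.cfg` files added below are JSON with `#` comments (e.g. `configuration/crypto.cfg`). A minimal loader sketch, assuming comments and any trailing commas they leave behind are stripped before `json.loads`; the repository's actual loader is not shown in this diff:

```python
import json
import re
from pathlib import Path


def load_cfg(path: str) -> dict:
    """Hypothetical loader: strip '#' comments, drop trailing commas, parse JSON."""
    text = Path(path).read_text()
    # Naive comment stripping; assumes '#' never appears inside string values.
    text = re.sub(r"#.*", "", text)
    # Remove trailing commas left behind by commented-out entries.
    text = re.sub(r",(\s*[}\]])", r"\1", text)
    return json.loads(text)


config = load_cfg("configuration/crypto.cfg")
assert config["security_type"] == "CRYPTO"
```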

configuration/crypto.cfg (new file)

@@ -0,0 +1,31 @@
{
"security_type": "CRYPTO",
"data_directory": "./data/crypto",
"datafiles": [
"2025*.mktdata.ohlcv.db"
],
"db_table_name": "md_1min_bars",
"exchange_id": "BNBSPOT",
"instrument_id_pfx": "PAIR-",
"trading_hours": {
"begin_session": "00:00:00",
"end_session": "23:59:00",
"timezone": "UTC"
},
"price_column": "close",
"min_required_points": 30,
"zero_threshold": 1e-10,
"dis-equilibrium_open_trshld": 2.0,
"dis-equilibrium_close_trshld": 0.5,
"training_minutes": 120,
"funding_per_pair": 2000.0,
"fit_method_class": "pt_trading.sliding_fit.SlidingFit",
# "fit_method_class": "pt_trading.static_fit.StaticFit",
"close_outstanding_positions": true,
"trading_hours": {
"begin_session": "06:00:00",
"end_session": "16:00:00",
"timezone": "America/New_York"
}
}


@@ -0,0 +1,27 @@
{
"security_type": "EQUITY",
"data_directory": "./data/equity",
# "datafiles": [
# "20250604.mktdata.ohlcv.db",
# ],
"db_table_name": "md_1min_bars",
"exchange_id": "ALPACA",
"instrument_id_pfx": "STOCK-",
"trading_hours": {
"begin_session": "9:30:00",
"end_session": "16:00:00",
"timezone": "America/New_York"
},
"price_column": "close",
"min_required_points": 30,
"zero_threshold": 1e-10,
"dis-equilibrium_open_trshld": 2.0,
"dis-equilibrium_close_trshld": 1.0,
"training_minutes": 120,
"funding_per_pair": 2000.0,
"fit_method_class": "pt_trading.sliding_fit.SlidingFit",
# "fit_method_class": "pt_trading.static_fit.StaticFit",
"exclude_instruments": ["CAN"],
"close_outstanding_positions": false
}


@@ -1,43 +0,0 @@
{
"market_data_loading": {
"CRYPTO": {
"data_directory": "./data/crypto",
"db_table_name": "md_1min_bars",
"instrument_id_pfx": "PAIR-",
},
"EQUITY": {
"data_directory": "./data/equity",
"db_table_name": "md_1min_bars",
"instrument_id_pfx": "STOCK-",
}
},
# ====== Funding ======
"funding_per_pair": 2000.0,
# ====== Trading Parameters ======
"stat_model_price": "close", # "vwap"
"execution_price": {
"column": "vwap",
"shift": 1,
},
"dis-equilibrium_open_trshld": 2.0,
"dis-equilibrium_close_trshld": 1.0,
"training_minutes": 120,
"fit_method_class": "pt_trading.vecm_rolling_fit.VECMRollingFit",
# ====== Stop Conditions ======
"stop_close_conditions": {
"profit": 2.0,
"loss": -0.5
}
# ====== End of Session Closeout ======
"close_outstanding_positions": true,
# "close_outstanding_positions": false,
"trading_hours": {
"timezone": "America/New_York",
"begin_session": "9:30:00",
"end_session": "18:30:00",
}
}


@@ -1,42 +0,0 @@
{
"market_data_loading": {
"CRYPTO": {
"data_directory": "./data/crypto",
"db_table_name": "md_1min_bars",
"instrument_id_pfx": "PAIR-",
},
"EQUITY": {
"data_directory": "./data/equity",
"db_table_name": "md_1min_bars",
"instrument_id_pfx": "STOCK-",
}
},
# ====== Funding ======
"funding_per_pair": 2000.0,
# ====== Trading Parameters ======
"stat_model_price": "close",
"execution_price": {
"column": "vwap",
"shift": 1,
},
"dis-equilibrium_open_trshld": 2.0,
"dis-equilibrium_close_trshld": 0.5,
"training_minutes": 120,
"fit_method_class": "pt_trading.z-score_rolling_fit.ZScoreRollingFit",
# ====== Stop Conditions ======
"stop_close_conditions": {
"profit": 2.0,
"loss": -0.5
}
# ====== End of Session Closeout ======
"close_outstanding_positions": true,
# "close_outstanding_positions": false,
"trading_hours": {
"timezone": "America/New_York",
"begin_session": "9:30:00",
"end_session": "18:30:00",
}
}


@@ -1,10 +1,8 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from enum import Enum
from typing import Dict, Optional, cast
import pandas as pd
import pandas as pd # type: ignore[import]
from pt_trading.results import BacktestResult
from pt_trading.trading_pair import TradingPair
@@ -14,24 +12,13 @@ NanoPerMin = 1e9
class PairsTradingFitMethod(ABC):
TRADES_COLUMNS = [
"time",
"symbol",
"side",
"action",
"symbol",
"price",
"disequilibrium",
"scaled_disequilibrium",
"signed_scaled_disequilibrium",
"pair",
]
@staticmethod
def create(config: Dict) -> PairsTradingFitMethod:
import importlib
fit_method_class_name = config.get("fit_method_class", None)
assert fit_method_class_name is not None
module_name, class_name = fit_method_class_name.rsplit(".", 1)
module = importlib.import_module(module_name)
fit_method = getattr(module, class_name)()
return cast(PairsTradingFitMethod, fit_method)
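# Example (illustrative): with config["fit_method_class"] set to
# "pt_trading.sliding_fit.SlidingFit", rsplit(".", 1) yields module name
# "pt_trading.sliding_fit" and class name "SlidingFit", which create()
# imports and instantiates reflectively.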
@abstractmethod
def run_pair(
@@ -41,12 +28,9 @@ class PairsTradingFitMethod(ABC):
@abstractmethod
def reset(self) -> None: ...
@abstractmethod
def create_trading_pair(
self,
config: Dict,
market_data: pd.DataFrame,
symbol_a: str,
symbol_b: str,
) -> TradingPair: ...
class PairState(Enum):
INITIAL = 1
OPEN = 2
CLOSED = 3
CLOSED_POSITIONS = 4


@@ -68,8 +68,7 @@ def create_result_database(db_path: str) -> None:
close_quantity INTEGER,
close_disequilibrium REAL,
symbol_return REAL,
pair_return REAL,
close_condition TEXT
pair_return REAL
)
"""
)
@@ -121,8 +120,8 @@ def store_config_in_database(
config_file_path: str,
config: Dict,
fit_method_class: str,
datafiles: List[Tuple[str, str]],
instruments: List[Dict[str, str]],
datafiles: List[str],
instruments: List[str],
) -> None:
"""
Store configuration information in the database for reference.
@@ -140,13 +139,8 @@
config_json = json.dumps(config, indent=2, default=str)
# Convert lists to comma-separated strings for storage
datafiles_str = ", ".join([f"{datafile}" for _, datafile in datafiles])
instruments_str = ", ".join(
[
f"{inst['symbol']}:{inst['instrument_type']}:{inst['exchange_id']}"
for inst in instruments
]
)
datafiles_str = ", ".join(datafiles)
instruments_str = ", ".join(instruments)
# Insert configuration record
cursor.execute(
@@ -177,23 +171,251 @@
traceback.print_exc()
def convert_timestamp(timestamp: Any) -> Optional[datetime]:
"""Convert pandas Timestamp to Python datetime object for SQLite compatibility."""
if timestamp is None:
return None
if isinstance(timestamp, pd.Timestamp):
return timestamp.to_pydatetime()
elif isinstance(timestamp, datetime):
return timestamp
elif isinstance(timestamp, date):
return datetime.combine(timestamp, datetime.min.time())
elif isinstance(timestamp, str):
return datetime.strptime(timestamp, "%Y-%m-%d %H:%M:%S")
elif isinstance(timestamp, int):
return datetime.fromtimestamp(timestamp)
else:
raise ValueError(f"Unsupported timestamp type: {type(timestamp)}")
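# Example (illustrative): pd.Timestamp("2025-05-28 09:30:00") converts to
# datetime(2025, 5, 28, 9, 30); an int is treated as a Unix epoch in seconds.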
def store_results_in_database(
db_path: str, datafile: str, bt_result: "BacktestResult"
) -> None:
"""
Store backtest results in the SQLite database.
"""
if db_path.upper() == "NONE":
return
def convert_timestamp(timestamp: Any) -> Optional[datetime]:
"""Convert pandas Timestamp to Python datetime object for SQLite compatibility."""
if timestamp is None:
return None
if hasattr(timestamp, "to_pydatetime"):
return timestamp.to_pydatetime()
return timestamp
try:
# Extract date from datafile name (assuming format like 20250528.mktdata.ohlcv.db)
filename = os.path.basename(datafile)
date_str = filename.split(".")[0] # Extract date part
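# e.g. "20250528.mktdata.ohlcv.db" -> "20250528"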
# Convert to proper date format
try:
date_obj = datetime.strptime(date_str, "%Y%m%d").date()
except ValueError:
# If date parsing fails, use current date
date_obj = datetime.now().date()
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
# Process each trade from bt_result
trades = bt_result.get_trades()
for pair_name, symbols in trades.items():
# Calculate pair return for this pair
pair_return = 0.0
pair_trades = []
# First pass: collect all trades and calculate returns
for symbol, symbol_trades in symbols.items():
if len(symbol_trades) == 0: # No trades for this symbol
print(
f"Warning: No trades found for symbol {symbol} in pair {pair_name}"
)
continue
elif len(symbol_trades) >= 2: # Completed trades (entry + exit)
# Handle both old and new tuple formats
if len(symbol_trades[0]) == 2: # Old format: (action, price)
entry_action, entry_price = symbol_trades[0]
exit_action, exit_price = symbol_trades[1]
open_disequilibrium = 0.0 # Fallback for old format
open_scaled_disequilibrium = 0.0
close_disequilibrium = 0.0
close_scaled_disequilibrium = 0.0
open_time = datetime.now()
close_time = datetime.now()
else: # New format: (action, price, disequilibrium, scaled_disequilibrium, timestamp)
(
entry_action,
entry_price,
open_disequilibrium,
open_scaled_disequilibrium,
open_time,
) = symbol_trades[0]
(
exit_action,
exit_price,
close_disequilibrium,
close_scaled_disequilibrium,
close_time,
) = symbol_trades[1]
# Handle None values
open_disequilibrium = (
open_disequilibrium
if open_disequilibrium is not None
else 0.0
)
open_scaled_disequilibrium = (
open_scaled_disequilibrium
if open_scaled_disequilibrium is not None
else 0.0
)
close_disequilibrium = (
close_disequilibrium
if close_disequilibrium is not None
else 0.0
)
close_scaled_disequilibrium = (
close_scaled_disequilibrium
if close_scaled_disequilibrium is not None
else 0.0
)
# Convert pandas Timestamps to Python datetime objects
open_time = convert_timestamp(open_time) or datetime.now()
close_time = convert_timestamp(close_time) or datetime.now()
# Calculate actual share quantities based on funding per pair
# Split funding equally between the two positions
funding_per_position = bt_result.config["funding_per_pair"] / 2
shares = funding_per_position / entry_price
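# Worked example (illustrative): funding_per_pair = 2000.0 gives
# funding_per_position = 1000.0 per leg; at an entry price of 50.0 that
# is 1000.0 / 50.0 = 20 shares for this leg.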
# Calculate symbol return
symbol_return = 0.0
if entry_action == "BUY" and exit_action == "SELL":
symbol_return = (exit_price - entry_price) / entry_price * 100
elif entry_action == "SELL" and exit_action == "BUY":
symbol_return = (entry_price - exit_price) / entry_price * 100
pair_return += symbol_return
pair_trades.append(
{
"symbol": symbol,
"entry_action": entry_action,
"entry_price": entry_price,
"exit_action": exit_action,
"exit_price": exit_price,
"symbol_return": symbol_return,
"open_disequilibrium": open_disequilibrium,
"open_scaled_disequilibrium": open_scaled_disequilibrium,
"close_disequilibrium": close_disequilibrium,
"close_scaled_disequilibrium": close_scaled_disequilibrium,
"open_time": open_time,
"close_time": close_time,
"shares": shares,
"is_completed": True,
}
)
# Skip one-sided trades - they will be handled by outstanding_positions table
elif len(symbol_trades) == 1:
print(
f"Skipping one-sided trade for {symbol} in pair {pair_name} - will be stored in outstanding_positions table"
)
continue
else:
# This should not happen, but handle unexpected cases
print(
f"Warning: Unexpected number of trades ({len(symbol_trades)}) for symbol {symbol} in pair {pair_name}"
)
continue
# Second pass: insert completed trade records into database
for trade in pair_trades:
# Only store completed trades in pt_bt_results table
cursor.execute(
"""
INSERT INTO pt_bt_results (
date, pair, symbol, open_time, open_side, open_price,
open_quantity, open_disequilibrium, close_time, close_side,
close_price, close_quantity, close_disequilibrium,
symbol_return, pair_return
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
date_obj,
pair_name,
trade["symbol"],
trade["open_time"],
trade["entry_action"],
trade["entry_price"],
trade["shares"],
trade["open_scaled_disequilibrium"],
trade["close_time"],
trade["exit_action"],
trade["exit_price"],
trade["shares"],
trade["close_scaled_disequilibrium"],
trade["symbol_return"],
pair_return,
),
)
# Store outstanding positions in separate table
outstanding_positions = bt_result.get_outstanding_positions()
for pos in outstanding_positions:
# Calculate position quantity (negative for SELL positions)
position_qty_a = (
pos["shares_a"] if pos["side_a"] == "BUY" else -pos["shares_a"]
)
position_qty_b = (
pos["shares_b"] if pos["side_b"] == "BUY" else -pos["shares_b"]
)
# Calculate unrealized returns
# For symbol A: (current_price - open_price) / open_price * 100 * position_direction
unrealized_return_a = (
(pos["current_px_a"] - pos["open_px_a"]) / pos["open_px_a"] * 100
) * (1 if pos["side_a"] == "BUY" else -1)
unrealized_return_b = (
(pos["current_px_b"] - pos["open_px_b"]) / pos["open_px_b"] * 100
) * (1 if pos["side_b"] == "BUY" else -1)
# Store outstanding position for symbol A
cursor.execute(
"""
INSERT INTO outstanding_positions (
date, pair, symbol, position_quantity, last_price, unrealized_return, open_price, open_side
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""",
(
date_obj,
pos["pair"],
pos["symbol_a"],
position_qty_a,
pos["current_px_a"],
unrealized_return_a,
pos["open_px_a"],
pos["side_a"],
),
)
# Store outstanding position for symbol B
cursor.execute(
"""
INSERT INTO outstanding_positions (
date, pair, symbol, position_quantity, last_price, unrealized_return, open_price, open_side
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""",
(
date_obj,
pos["pair"],
pos["symbol_b"],
position_qty_b,
pos["current_px_b"],
unrealized_return_b,
pos["open_px_b"],
pos["side_b"],
),
)
conn.commit()
conn.close()
except Exception as e:
print(f"Error storing results in database: {str(e)}")
import traceback
traceback.print_exc()
class BacktestResult:
@@ -206,19 +428,16 @@
self.trades: Dict[str, Dict[str, Any]] = {}
self.total_realized_pnl = 0.0
self.outstanding_positions: List[Dict[str, Any]] = []
self.pairs_trades_: Dict[str, List[Dict[str, Any]]] = {}
def add_trade(
self,
pair_nm: str,
symbol: str,
side: str,
action: str,
price: Any,
disequilibrium: Optional[float] = None,
scaled_disequilibrium: Optional[float] = None,
timestamp: Optional[datetime] = None,
status: Optional[str] = None,
) -> None:
"""Add a trade to the results tracking."""
pair_nm = str(pair_nm)
@@ -228,16 +447,7 @@
if symbol not in self.trades[pair_nm]:
self.trades[pair_nm][symbol] = []
self.trades[pair_nm][symbol].append(
{
"symbol": symbol,
"side": side,
"action": action,
"price": price,
"disequilibrium": disequilibrium,
"scaled_disequilibrium": scaled_disequilibrium,
"timestamp": timestamp,
"status": status,
}
(action, price, disequilibrium, scaled_disequilibrium, timestamp)
)
def add_outstanding_position(self, position: Dict[str, Any]) -> None:
@@ -274,27 +484,20 @@
print(result)
for row in result.itertuples():
side = row.side
action = row.action
symbol = row.symbol
price = row.price
disequilibrium = getattr(row, "disequilibrium", None)
scaled_disequilibrium = getattr(row, "scaled_disequilibrium", None)
if hasattr(row, "time"):
timestamp = getattr(row, "time")
else:
timestamp = convert_timestamp(row.Index)
status = row.status
timestamp = getattr(row, "time", None)
self.add_trade(
pair_nm=str(row.pair),
symbol=str(symbol),
side=str(side),
action=str(action),
symbol=str(symbol),
price=float(str(price)),
disequilibrium=disequilibrium,
scaled_disequilibrium=scaled_disequilibrium,
timestamp=timestamp,
status=str(status) if status is not None else "?",
)
def print_single_day_results(self) -> None:
@@ -320,126 +523,105 @@
def calculate_returns(self, all_results: Dict[str, Dict[str, Any]]) -> None:
"""Calculate and print returns by day and pair."""
def _symbol_return(trade1_side: str, trade1_px: float, trade2_side: str, trade2_px: float) -> float:
if trade1_side == "BUY" and trade2_side == "SELL":
return (trade2_px - trade1_px) / trade1_px * 100
elif trade1_side == "SELL" and trade2_side == "BUY":
return (trade1_px - trade2_px) / trade1_px * 100
else:
return 0
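# Worked example (illustrative): BUY @ 100.0 then SELL @ 102.0 gives
# (102 - 100) / 100 * 100 = +2.00%; SELL @ 50.0 then BUY @ 50.5 gives
# (50 - 50.5) / 50 * 100 = -1.00%. Any other side combination returns 0.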
print("\n====== Returns By Day and Pair ======")
trades = []
for filename, data in all_results.items():
pairs = list(data["trades"].keys())
for pair in pairs:
self.pairs_trades_[pair] = []
trades_dict = data["trades"][pair]
for symbol in trades_dict.keys():
trades.extend(trades_dict[symbol])
trades = sorted(trades, key=lambda x: (x["timestamp"], x["symbol"]))
day_return = 0
print(f"\n--- {filename} ---")
self.outstanding_positions = data["outstanding_positions"]
day_return = 0.0
for idx in range(0, len(trades), 4):
symbol_a = trades[idx]["symbol"]
trade_a_1 = trades[idx]
trade_a_2 = trades[idx + 2]
# Process each pair
for pair, symbols in data["trades"].items():
pair_return = 0
pair_trades = []
symbol_b = trades[idx + 1]["symbol"]
trade_b_1 = trades[idx + 1]
trade_b_2 = trades[idx + 3]
# Calculate individual symbol returns in the pair
for symbol, trades in symbols.items():
if len(trades) == 0:
continue
symbol_return = 0
assert (
trade_a_1["timestamp"] < trade_a_2["timestamp"]
), f"Trade 1: {trade_a_1['timestamp']} is not less than Trade 2: {trade_a_2['timestamp']}"
assert (
trade_a_1["action"] == "OPEN" and trade_a_2["action"] == "CLOSE"
), f"Trade 1 action: {trade_a_1['action']} and Trade 2 action: {trade_a_2['action']} are not OPEN/CLOSE"
symbol_return = 0
symbol_trades = []
# Calculate return based on action combination
trade_return = 0
symbol_a_return = _symbol_return(trade_a_1["side"], trade_a_1["price"], trade_a_2["side"], trade_a_2["price"])
symbol_b_return = _symbol_return(trade_b_1["side"], trade_b_1["price"], trade_b_2["side"], trade_b_2["price"])
# Process all trades sequentially for this symbol
for i, trade in enumerate(trades):
# Handle both old and new tuple formats
if len(trade) == 2: # Old format: (action, price)
action, price = trade
disequilibrium = None
scaled_disequilibrium = None
timestamp = None
else: # New format: (action, price, disequilibrium, scaled_disequilibrium, timestamp)
action, price = trade[:2]
disequilibrium = trade[2] if len(trade) > 2 else None
scaled_disequilibrium = trade[3] if len(trade) > 3 else None
timestamp = trade[4] if len(trade) > 4 else None
pair_return = symbol_a_return + symbol_b_return
symbol_trades.append((action, price, disequilibrium, scaled_disequilibrium, timestamp))
self.pairs_trades_[pair].append(
{
"symbol": symbol_a,
"open_side": trade_a_1["side"],
"open_action": trade_a_1["action"],
"open_price": trade_a_1["price"],
"close_side": trade_a_2["side"],
"close_action": trade_a_2["action"],
"close_price": trade_a_2["price"],
"symbol_return": symbol_a_return,
"open_disequilibrium": trade_a_1["disequilibrium"],
"open_scaled_disequilibrium": trade_a_1["scaled_disequilibrium"],
"close_disequilibrium": trade_a_2["disequilibrium"],
"close_scaled_disequilibrium": trade_a_2["scaled_disequilibrium"],
"open_time": trade_a_1["timestamp"],
"close_time": trade_a_2["timestamp"],
"shares": self.config["funding_per_pair"] / 2 / trade_a_1["price"],
"is_completed": True,
"close_condition": trade_a_2["status"],
"pair_return": pair_return
}
)
self.pairs_trades_[pair].append(
{
"symbol": symbol_b,
"open_side": trade_b_1["side"],
"open_action": trade_b_1["action"],
"open_price": trade_b_1["price"],
"close_side": trade_b_2["side"],
"close_action": trade_b_2["action"],
"close_price": trade_b_2["price"],
"symbol_return": symbol_b_return,
"open_disequilibrium": trade_b_1["disequilibrium"],
"open_scaled_disequilibrium": trade_b_1["scaled_disequilibrium"],
"close_disequilibrium": trade_b_2["disequilibrium"],
"close_scaled_disequilibrium": trade_b_2["scaled_disequilibrium"],
"open_time": trade_b_1["timestamp"],
"close_time": trade_b_2["timestamp"],
"shares": self.config["funding_per_pair"] / 2 / trade_b_1["price"],
"is_completed": True,
"close_condition": trade_b_2["status"],
"pair_return": pair_return
}
)
# Calculate returns for all trade combinations
for i in range(len(symbol_trades) - 1):
trade1 = symbol_trades[i]
trade2 = symbol_trades[i + 1]
action1, price1, diseq1, scaled_diseq1, ts1 = trade1
action2, price2, diseq2, scaled_diseq2, ts2 = trade2
# Print pair returns with disequilibrium information
day_return = 0.0
if pair in self.pairs_trades_:
# Calculate return based on action combination
trade_return = 0
if action1 == "BUY" and action2 == "SELL":
# Long position
trade_return = (price2 - price1) / price1 * 100
elif action1 == "SELL" and action2 == "BUY":
# Short position
trade_return = (price1 - price2) / price1 * 100
print(f"{pair}:")
pair_return = 0.0
for trd in self.pairs_trades_[pair]:
disequil_info = ""
if (
trd["open_scaled_disequilibrium"] is not None
and trd["close_scaled_disequilibrium"] is not None
):
disequil_info = (
f" | Open Dis-eq: {trd['open_scaled_disequilibrium']:.2f},"
f" Close Dis-eq: {trd['close_scaled_disequilibrium']:.2f}"
)
symbol_return += trade_return
print(
f" {trd['open_time'].time()}-{trd['close_time'].time()} {trd['symbol']}: "
f" {trd['open_side']} @ ${trd['open_price']:.2f},"
f" {trd["close_side"]} @ ${trd["close_price"]:.2f},"
f" Return: {trd['symbol_return']:.2f}%{disequil_info}"
)
pair_return += trd["symbol_return"]
# Store trade details for reporting
pair_trades.append(
(
symbol,
action1,
price1,
action2,
price2,
trade_return,
scaled_diseq1,
scaled_diseq2,
i + 1, # Trade sequence number
)
)
print(f" Pair Total Return: {pair_return:.2f}%")
day_return += pair_return
pair_return += symbol_return
# Print pair returns with disequilibrium information
if pair_trades:
print(f" {pair}:")
for (
symbol,
action1,
price1,
action2,
price2,
trade_return,
scaled_diseq1,
scaled_diseq2,
trade_num,
) in pair_trades:
disequil_info = ""
if (
scaled_diseq1 is not None
and scaled_diseq2 is not None
):
disequil_info = f" | Open Dis-eq: {scaled_diseq1:.2f}, Close Dis-eq: {scaled_diseq2:.2f}"
print(
f" {symbol} (Trade #{trade_num}): {action1} @ ${price1:.2f}, {action2} @ ${price2:.2f}, Return: {trade_return:.2f}%{disequil_info}"
)
print(f" Pair Total Return: {pair_return:.2f}%")
day_return += pair_return
# Print day total return and add to global realized PnL
if day_return != 0:
@@ -552,7 +734,7 @@
last_row = pair_result_df.loc[last_row_index]
last_tstamp = last_row["tstamp"]
colname_a, colname_b = pair.exec_prices_colnames()
colname_a, colname_b = pair.colnames()
last_px_a = last_row[colname_a]
last_px_b = last_row[colname_b]
@@ -611,131 +793,3 @@
)
return current_value_a, current_value_b, total_current_value
def store_results_in_database(
self, db_path: str, day: str
) -> None:
"""
Store backtest results in the SQLite database.
"""
if db_path.upper() == "NONE":
return
try:
# Extract date from datafile name (assuming format like 20250528.mktdata.ohlcv.db)
date_str = day
# Convert to proper date format
try:
date_obj = datetime.strptime(date_str, "%Y%m%d").date()
except ValueError:
# If date parsing fails, use current date
date_obj = datetime.now().date()
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
# Process each trade from bt_result
trades = self.get_trades()
for pair_name, _ in trades.items():
# Second pass: insert completed trade records into database
for trade_pair in sorted(self.pairs_trades_[pair_name], key=lambda x: x["open_time"]):
# Only store completed trades in pt_bt_results table
cursor.execute(
"""
INSERT INTO pt_bt_results (
date, pair, symbol, open_time, open_side, open_price,
open_quantity, open_disequilibrium, close_time, close_side,
close_price, close_quantity, close_disequilibrium,
symbol_return, pair_return, close_condition
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
date_obj,
pair_name,
trade_pair["symbol"],
trade_pair["open_time"],
trade_pair["open_side"],
trade_pair["open_price"],
trade_pair["shares"],
trade_pair["open_scaled_disequilibrium"],
trade_pair["close_time"],
trade_pair["close_side"],
trade_pair["close_price"],
trade_pair["shares"],
trade_pair["close_scaled_disequilibrium"],
trade_pair["symbol_return"],
trade_pair["pair_return"],
trade_pair["close_condition"]
),
)
# Store outstanding positions in separate table
outstanding_positions = self.get_outstanding_positions()
for pos in outstanding_positions:
# Calculate position quantity (negative for SELL positions)
position_qty_a = (
pos["shares_a"] if pos["side_a"] == "BUY" else -pos["shares_a"]
)
position_qty_b = (
pos["shares_b"] if pos["side_b"] == "BUY" else -pos["shares_b"]
)
# Calculate unrealized returns
# For symbol A: (current_price - open_price) / open_price * 100 * position_direction
unrealized_return_a = (
(pos["current_px_a"] - pos["open_px_a"]) / pos["open_px_a"] * 100
) * (1 if pos["side_a"] == "BUY" else -1)
unrealized_return_b = (
(pos["current_px_b"] - pos["open_px_b"]) / pos["open_px_b"] * 100
) * (1 if pos["side_b"] == "BUY" else -1)
# Store outstanding position for symbol A
cursor.execute(
"""
INSERT INTO outstanding_positions (
date, pair, symbol, position_quantity, last_price, unrealized_return, open_price, open_side
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""",
(
date_obj,
pos["pair"],
pos["symbol_a"],
position_qty_a,
pos["current_px_a"],
unrealized_return_a,
pos["open_px_a"],
pos["side_a"],
),
)
# Store outstanding position for symbol B
cursor.execute(
"""
INSERT INTO outstanding_positions (
date, pair, symbol, position_quantity, last_price, unrealized_return, open_price, open_side
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""",
(
date_obj,
pos["pair"],
pos["symbol_b"],
position_qty_b,
pos["current_px_b"],
unrealized_return_b,
pos["open_px_b"],
pos["side_b"],
),
)
conn.commit()
conn.close()
except Exception as e:
print(f"Error storing results in database: {str(e)}")
import traceback
traceback.print_exc()


@@ -1,319 +0,0 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Dict, Optional, cast
import pandas as pd # type: ignore[import]
from pt_trading.fit_method import PairsTradingFitMethod
from pt_trading.results import BacktestResult
from pt_trading.trading_pair import PairState, TradingPair
from statsmodels.tsa.vector_ar.vecm import VECM, VECMResults
NanoPerMin = 1e9
class RollingFit(PairsTradingFitMethod):
"""
N O T E:
=========
- This class remains abstract.
- The following methods must be implemented in a subclass:
- create_trading_pair()
=========
"""
def __init__(self) -> None:
super().__init__()
def run_pair(
self, pair: TradingPair, bt_result: BacktestResult
) -> Optional[pd.DataFrame]:
print(f"***{pair}*** STARTING....")
config = pair.config_
curr_training_start_idx = pair.get_begin_index()
end_index = pair.get_end_index()
pair.user_data_["state"] = PairState.INITIAL
# Initialize trades DataFrame with proper dtypes to avoid concatenation warnings
pair.user_data_["trades"] = pd.DataFrame(columns=self.TRADES_COLUMNS).astype(
{
"time": "datetime64[ns]",
"symbol": "string",
"side": "string",
"action": "string",
"price": "float64",
"disequilibrium": "float64",
"scaled_disequilibrium": "float64",
"pair": "object",
}
)
training_minutes = config["training_minutes"]
curr_predicted_row_idx = 0
while True:
print(curr_training_start_idx, end="\r")
pair.get_datasets(
training_minutes=training_minutes,
training_start_index=curr_training_start_idx,
testing_size=1,
)
if len(pair.training_df_) < training_minutes:
print(
f"{pair}: current offset={curr_training_start_idx}"
f" * Training data length={len(pair.training_df_)} < {training_minutes}"
" * Not enough training data. Completing the job."
)
break
try:
# ================================ PREDICTION ================================
self.pair_predict_result_ = pair.predict()
except Exception as e:
raise RuntimeError(
f"{pair}: TrainingPrediction failed: {str(e)}"
) from e
# break
curr_training_start_idx += 1
if curr_training_start_idx > end_index:
break
curr_predicted_row_idx += 1
self._create_trading_signals(pair, config, bt_result)
print(f"***{pair}*** FINISHED *** Num Trades:{len(pair.user_data_['trades'])}")
return pair.get_trades()
def _create_trading_signals(
self, pair: TradingPair, config: Dict, bt_result: BacktestResult
) -> None:
predicted_df = self.pair_predict_result_
assert predicted_df is not None
open_threshold = config["dis-equilibrium_open_trshld"]
close_threshold = config["dis-equilibrium_close_trshld"]
for curr_predicted_row_idx in range(len(predicted_df)):
pred_row = predicted_df.iloc[curr_predicted_row_idx]
scaled_disequilibrium = pred_row["scaled_disequilibrium"]
if pair.user_data_["state"] in [
PairState.INITIAL,
PairState.CLOSE,
PairState.CLOSE_POSITION,
PairState.CLOSE_STOP_LOSS,
PairState.CLOSE_STOP_PROFIT,
]:
if scaled_disequilibrium >= open_threshold:
open_trades = self._get_open_trades(
pair, row=pred_row, open_threshold=open_threshold
)
if open_trades is not None:
open_trades["status"] = PairState.OPEN.name
print(f"OPEN TRADES:\n{open_trades}")
pair.add_trades(open_trades)
pair.user_data_["state"] = PairState.OPEN
pair.on_open_trades(open_trades)
elif pair.user_data_["state"] == PairState.OPEN:
if scaled_disequilibrium <= close_threshold:
close_trades = self._get_close_trades(
pair, row=pred_row, close_threshold=close_threshold
)
if close_trades is not None:
close_trades["status"] = PairState.CLOSE.name
print(f"CLOSE TRADES:\n{close_trades}")
pair.add_trades(close_trades)
pair.user_data_["state"] = PairState.CLOSE
pair.on_close_trades(close_trades)
elif pair.to_stop_close_conditions(predicted_row=pred_row):
close_trades = self._get_close_trades(
pair, row=pred_row, close_threshold=close_threshold
)
if close_trades is not None:
close_trades["status"] = pair.user_data_[
"stop_close_state"
].name
print(f"STOP CLOSE TRADES:\n{close_trades}")
pair.add_trades(close_trades)
pair.user_data_["state"] = pair.user_data_["stop_close_state"]
pair.on_close_trades(close_trades)
# Outstanding positions
if pair.user_data_["state"] == PairState.OPEN:
print(f"{pair}: *** Position is NOT CLOSED. ***")
# outstanding positions
if config["close_outstanding_positions"]:
close_position_row = pd.Series(pair.market_data_.iloc[-2])
close_position_row["disequilibrium"] = 0.0
close_position_row["scaled_disequilibrium"] = 0.0
close_position_row["signed_scaled_disequilibrium"] = 0.0
close_position_trades = self._get_close_trades(
pair=pair, row=close_position_row, close_threshold=close_threshold
)
if close_position_trades is not None:
close_position_trades["status"] = PairState.CLOSE_POSITION.name
print(f"CLOSE_POSITION TRADES:\n{close_position_trades}")
pair.add_trades(close_position_trades)
pair.user_data_["state"] = PairState.CLOSE_POSITION
pair.on_close_trades(close_position_trades)
else:
if predicted_df is not None:
bt_result.handle_outstanding_position(
pair=pair,
pair_result_df=predicted_df,
last_row_index=0,
open_side_a=pair.user_data_["open_side_a"],
open_side_b=pair.user_data_["open_side_b"],
open_px_a=pair.user_data_["open_px_a"],
open_px_b=pair.user_data_["open_px_b"],
open_tstamp=pair.user_data_["open_tstamp"],
)
def _get_open_trades(
self, pair: TradingPair, row: pd.Series, open_threshold: float
) -> Optional[pd.DataFrame]:
colname_a, colname_b = pair.exec_prices_colnames()
open_row = row
open_tstamp = open_row["tstamp"]
open_disequilibrium = open_row["disequilibrium"]
open_scaled_disequilibrium = open_row["scaled_disequilibrium"]
signed_scaled_disequilibrium = open_row["signed_scaled_disequilibrium"]
open_px_a = open_row[f"{colname_a}"]
open_px_b = open_row[f"{colname_b}"]
# creating the trades
print(f"OPEN_TRADES: {row["tstamp"]} {open_scaled_disequilibrium=}")
if open_disequilibrium > 0:
open_side_a = "SELL"
open_side_b = "BUY"
close_side_a = "BUY"
close_side_b = "SELL"
else:
open_side_a = "BUY"
open_side_b = "SELL"
close_side_a = "SELL"
close_side_b = "BUY"
# save closing sides
pair.user_data_["open_side_a"] = open_side_a
pair.user_data_["open_side_b"] = open_side_b
pair.user_data_["open_px_a"] = open_px_a
pair.user_data_["open_px_b"] = open_px_b
pair.user_data_["open_tstamp"] = open_tstamp
pair.user_data_["close_side_a"] = close_side_a
pair.user_data_["close_side_b"] = close_side_b
# create opening trades
trd_signal_tuples = [
(
open_tstamp,
pair.symbol_a_,
open_side_a,
"OPEN",
open_px_a,
open_disequilibrium,
open_scaled_disequilibrium,
signed_scaled_disequilibrium,
pair,
),
(
open_tstamp,
pair.symbol_b_,
open_side_b,
"OPEN",
open_px_b,
open_disequilibrium,
open_scaled_disequilibrium,
signed_scaled_disequilibrium,
pair,
),
]
# Create DataFrame with explicit dtypes to avoid concatenation warnings
df = pd.DataFrame(
trd_signal_tuples,
columns=self.TRADES_COLUMNS,
)
# Ensure consistent dtypes
return df.astype(
{
"time": "datetime64[ns]",
"action": "string",
"symbol": "string",
"price": "float64",
"disequilibrium": "float64",
"scaled_disequilibrium": "float64",
"signed_scaled_disequilibrium": "float64",
"pair": "object",
}
)
def _get_close_trades(
self, pair: TradingPair, row: pd.Series, close_threshold: float
) -> Optional[pd.DataFrame]:
colname_a, colname_b = pair.exec_prices_colnames()
close_row = row
close_tstamp = close_row["tstamp"]
close_disequilibrium = close_row["disequilibrium"]
close_scaled_disequilibrium = close_row["scaled_disequilibrium"]
signed_scaled_disequilibrium = close_row["signed_scaled_disequilibrium"]
close_px_a = close_row[f"{colname_a}"]
close_px_b = close_row[f"{colname_b}"]
close_side_a = pair.user_data_["close_side_a"]
close_side_b = pair.user_data_["close_side_b"]
trd_signal_tuples = [
(
close_tstamp,
pair.symbol_a_,
close_side_a,
"CLOSE",
close_px_a,
close_disequilibrium,
close_scaled_disequilibrium,
signed_scaled_disequilibrium,
pair,
),
(
close_tstamp,
pair.symbol_b_,
close_side_b,
"CLOSE",
close_px_b,
close_disequilibrium,
close_scaled_disequilibrium,
signed_scaled_disequilibrium,
pair,
),
]
# Add tuples to data frame with explicit dtypes to avoid concatenation warnings
df = pd.DataFrame(
trd_signal_tuples,
columns=self.TRADES_COLUMNS,
)
# Ensure consistent dtypes
return df.astype(
{
"time": "datetime64[ns]",
"action": "string",
"symbol": "string",
"price": "float64",
"disequilibrium": "float64",
"scaled_disequilibrium": "float64",
"signed_scaled_disequilibrium": "float64",
"pair": "object",
}
)
def reset(self) -> None:
pass  # run_pair tracks its training start index locally; nothing to reset here


@@ -0,0 +1,362 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Dict, Optional, cast
import pandas as pd # type: ignore[import]
from pt_trading.fit_method import PairState, PairsTradingFitMethod
from pt_trading.results import BacktestResult
from pt_trading.trading_pair import TradingPair
NanoPerMin = 1e9
class SlidingFit(PairsTradingFitMethod):
def __init__(self) -> None:
super().__init__()
def run_pair(
self, pair: TradingPair, bt_result: BacktestResult
) -> Optional[pd.DataFrame]:
print(f"***{pair}*** STARTING....")
config = pair.config_
curr_training_start_idx = pair.get_begin_index()
end_index = pair.get_end_index()
pair.user_data_["state"] = PairState.INITIAL
# Initialize trades DataFrame with proper dtypes to avoid concatenation warnings
pair.user_data_["trades"] = pd.DataFrame(columns=self.TRADES_COLUMNS).astype({
"time": "datetime64[ns]",
"action": "string",
"symbol": "string",
"price": "float64",
"disequilibrium": "float64",
"scaled_disequilibrium": "float64",
"pair": "object"
})
pair.user_data_["is_cointegrated"] = False
training_minutes = config["training_minutes"]
curr_predicted_row_idx = 0
while True:
print(curr_training_start_idx, end="\r")
pair.get_datasets(
training_minutes=training_minutes,
training_start_index=curr_training_start_idx,
testing_size=1,
)
if len(pair.training_df_) < training_minutes:
print(
f"{pair}: current offset={curr_training_start_idx}"
f" * Training data length={len(pair.training_df_)} < {training_minutes}"
" * Not enough training data. Completing the job."
)
break
try:
# ================================ TRAINING ================================
is_cointegrated = pair.train_pair()
except Exception as e:
raise RuntimeError(f"{pair}: Training failed: {str(e)}") from e
if pair.user_data_["is_cointegrated"] != is_cointegrated:
pair.user_data_["is_cointegrated"] = is_cointegrated
if not is_cointegrated:
if pair.user_data_["state"] == PairState.OPEN:
print(
f"{pair} {curr_training_start_idx} LOST COINTEGRATION. Consider closing positions..."
)
else:
print(
f"{pair} {curr_training_start_idx} IS NOT COINTEGRATED. Moving on"
)
else:
print("*" * 80)
print(
f"Pair {pair} ({curr_training_start_idx}) IS COINTEGRATED"
)
print("*" * 80)
if not is_cointegrated:
curr_training_start_idx += 1
continue
try:
# ================================ PREDICTION ================================
pair.predict()
except Exception as e:
raise RuntimeError(f"{pair}: Prediction failed: {str(e)}") from e
# break
curr_training_start_idx += 1
if curr_training_start_idx > end_index:
break
curr_predicted_row_idx += 1
self._create_trading_signals(pair, config, bt_result)
print(f"***{pair}*** FINISHED ... {len(pair.user_data_['trades'])}")
return pair.get_trades()
def _create_trading_signals(
self, pair: TradingPair, config: Dict, bt_result: BacktestResult
) -> None:
if pair.predicted_df_ is None:
print(f"{pair.market_data_.iloc[0]['tstamp']} {pair}: No predicted data")
return
open_threshold = config["dis-equilibrium_open_trshld"]
close_threshold = config["dis-equilibrium_close_trshld"]
for curr_predicted_row_idx in range(len(pair.predicted_df_)):
pred_row = pair.predicted_df_.iloc[curr_predicted_row_idx]
if pair.user_data_["state"] in [PairState.INITIAL, PairState.CLOSED, PairState.CLOSED_POSITIONS]:
open_trades = self._get_open_trades(
pair, row=pred_row, open_threshold=open_threshold
)
if open_trades is not None:
open_trades["status"] = "OPEN"
print(f"OPEN TRADES:\n{open_trades}")
pair.add_trades(open_trades)
pair.user_data_["state"] = PairState.OPEN
elif pair.user_data_["state"] == PairState.OPEN:
close_trades = self._get_close_trades(
pair, row=pred_row, close_threshold=close_threshold
)
if close_trades is not None:
close_trades["status"] = "CLOSE"
print(f"CLOSE TRADES:\n{close_trades}")
pair.add_trades(close_trades)
pair.user_data_["state"] = PairState.CLOSED
# Outstanding positions
if pair.user_data_["state"] == PairState.OPEN:
print(
f"{pair}: *** Position is NOT CLOSED. ***"
)
# outstanding positions
if config["close_outstanding_positions"]:
close_position_trades = self._get_close_position_trades(
pair=pair,
row=pred_row,
close_threshold=close_threshold,
)
if close_position_trades is not None:
close_position_trades["status"] = "CLOSE_POSITION"
print(f"CLOSE_POSITION TRADES:\n{close_position_trades}")
pair.add_trades(close_position_trades)
pair.user_data_["state"] = PairState.CLOSED_POSITIONS
else:
if pair.predicted_df_ is not None:
bt_result.handle_outstanding_position(
pair=pair,
pair_result_df=pair.predicted_df_,
last_row_index=0,
open_side_a=pair.user_data_["open_side_a"],
open_side_b=pair.user_data_["open_side_b"],
open_px_a=pair.user_data_["open_px_a"],
open_px_b=pair.user_data_["open_px_b"],
open_tstamp=pair.user_data_["open_tstamp"],
)
def _get_open_trades(
self, pair: TradingPair, row: pd.Series, open_threshold: float
) -> Optional[pd.DataFrame]:
colname_a, colname_b = pair.colnames()
assert pair.predicted_df_ is not None
predicted_df = pair.predicted_df_
# Check if we have any data to work with
if len(predicted_df) == 0:
return None
open_row = row
open_tstamp = open_row["tstamp"]
open_disequilibrium = open_row["disequilibrium"]
open_scaled_disequilibrium = open_row["scaled_disequilibrium"]
open_px_a = open_row[f"{colname_a}"]
open_px_b = open_row[f"{colname_b}"]
if open_scaled_disequilibrium < open_threshold:
return None
# creating the trades
print(f"OPEN_TRADES: {row["tstamp"]} {open_scaled_disequilibrium=}")
if open_disequilibrium > 0:
open_side_a = "SELL"
open_side_b = "BUY"
close_side_a = "BUY"
close_side_b = "SELL"
else:
open_side_a = "BUY"
open_side_b = "SELL"
close_side_a = "SELL"
close_side_b = "BUY"
# save closing sides
pair.user_data_["open_side_a"] = open_side_a
pair.user_data_["open_side_b"] = open_side_b
pair.user_data_["open_px_a"] = open_px_a
pair.user_data_["open_px_b"] = open_px_b
pair.user_data_["open_tstamp"] = open_tstamp
pair.user_data_["close_side_a"] = close_side_a
pair.user_data_["close_side_b"] = close_side_b
# create opening trades
trd_signal_tuples = [
(
open_tstamp,
open_side_a,
pair.symbol_a_,
open_px_a,
open_disequilibrium,
open_scaled_disequilibrium,
pair,
),
(
open_tstamp,
open_side_b,
pair.symbol_b_,
open_px_b,
open_disequilibrium,
open_scaled_disequilibrium,
pair,
),
]
# Create DataFrame with explicit dtypes to avoid concatenation warnings
df = pd.DataFrame(
trd_signal_tuples,
columns=self.TRADES_COLUMNS,
)
# Ensure consistent dtypes
return df.astype({
"time": "datetime64[ns]",
"action": "string",
"symbol": "string",
"price": "float64",
"disequilibrium": "float64",
"scaled_disequilibrium": "float64",
"pair": "object"
})
def _get_close_trades(
self, pair: TradingPair, row: pd.Series, close_threshold: float
) -> Optional[pd.DataFrame]:
colname_a, colname_b = pair.colnames()
assert pair.predicted_df_ is not None
if len(pair.predicted_df_) == 0:
return None
close_row = row
close_tstamp = close_row["tstamp"]
close_disequilibrium = close_row["disequilibrium"]
close_scaled_disequilibrium = close_row["scaled_disequilibrium"]
close_px_a = close_row[f"{colname_a}"]
close_px_b = close_row[f"{colname_b}"]
close_side_a = pair.user_data_["close_side_a"]
close_side_b = pair.user_data_["close_side_b"]
if close_scaled_disequilibrium > close_threshold:
return None
trd_signal_tuples = [
(
close_tstamp,
close_side_a,
pair.symbol_a_,
close_px_a,
close_disequilibrium,
close_scaled_disequilibrium,
pair,
),
(
close_tstamp,
close_side_b,
pair.symbol_b_,
close_px_b,
close_disequilibrium,
close_scaled_disequilibrium,
pair,
),
]
# Add tuples to data frame with explicit dtypes to avoid concatenation warnings
df = pd.DataFrame(
trd_signal_tuples,
columns=self.TRADES_COLUMNS,
)
# Ensure consistent dtypes
return df.astype({
"time": "datetime64[ns]",
"action": "string",
"symbol": "string",
"price": "float64",
"disequilibrium": "float64",
"scaled_disequilibrium": "float64",
"pair": "object"
})
def _get_close_position_trades(
self, pair: TradingPair, row: pd.Series, close_threshold: float
) -> Optional[pd.DataFrame]:
colname_a, colname_b = pair.colnames()
assert pair.predicted_df_ is not None
if len(pair.predicted_df_) == 0:
return None
close_position_row = row
close_position_tstamp = close_position_row["tstamp"]
close_position_disequilibrium = close_position_row["disequilibrium"]
close_position_scaled_disequilibrium = close_position_row["scaled_disequilibrium"]
close_position_px_a = close_position_row[f"{colname_a}"]
close_position_px_b = close_position_row[f"{colname_b}"]
close_position_side_a = pair.user_data_["close_side_a"]
close_position_side_b = pair.user_data_["close_side_b"]
trd_signal_tuples = [
(
close_position_tstamp,
close_position_side_a,
pair.symbol_a_,
close_position_px_a,
close_position_disequilibrium,
close_position_scaled_disequilibrium,
pair,
),
(
close_position_tstamp,
close_position_side_b,
pair.symbol_b_,
close_position_px_b,
close_position_disequilibrium,
close_position_scaled_disequilibrium,
pair,
),
]
# Add tuples to data frame with explicit dtypes to avoid concatenation warnings
df = pd.DataFrame(
trd_signal_tuples,
columns=self.TRADES_COLUMNS,
)
# Ensure consistent dtypes
return df.astype({
"time": "datetime64[ns]",
"action": "string",
"symbol": "string",
"price": "float64",
"disequilibrium": "float64",
"scaled_disequilibrium": "float64",
"pair": "object"
})
def reset(self) -> None:
pass  # run_pair tracks its training start index locally; nothing to reset here


@@ -0,0 +1,220 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Dict, Optional, cast
import pandas as pd # type: ignore[import]
from pt_trading.results import BacktestResult
from pt_trading.trading_pair import TradingPair
from pt_trading.fit_method import PairsTradingFitMethod
NanoPerMin = 1e9
class StaticFit(PairsTradingFitMethod):
def run_pair(
self, pair: TradingPair, bt_result: BacktestResult
) -> Optional[pd.DataFrame]: # abstractmethod
config = pair.config_
pair.get_datasets(training_minutes=config["training_minutes"])
try:
is_cointegrated = pair.train_pair()
if not is_cointegrated:
print(f"{pair} IS NOT COINTEGRATED")
return None
except Exception as e:
print(f"{pair}: Training failed: {str(e)}")
return None
try:
pair.predict()
except Exception as e:
print(f"{pair}: Prediction failed: {str(e)}")
return None
pair_trades = self.create_trading_signals(
pair=pair, config=config, result=bt_result
)
return pair_trades
def create_trading_signals(
self, pair: TradingPair, config: Dict, result: BacktestResult
) -> pd.DataFrame:
beta = pair.vecm_fit_.beta # type: ignore
colname_a, colname_b = pair.colnames()
predicted_df = pair.predicted_df_
if predicted_df is None:
# Return empty DataFrame with correct columns and dtypes
return pd.DataFrame(columns=self.TRADES_COLUMNS).astype({
"time": "datetime64[ns]",
"action": "string",
"symbol": "string",
"price": "float64",
"disequilibrium": "float64",
"scaled_disequilibrium": "float64",
"pair": "object"
})
open_threshold = config["dis-equilibrium_open_trshld"]
close_threshold = config["dis-equilibrium_close_trshld"]
# Iterate through the testing dataset to find the first trading opportunity
open_row_index = None
for row_idx in range(len(predicted_df)):
curr_disequilibrium = predicted_df["scaled_disequilibrium"][row_idx]
# Check if current row has sufficient disequilibrium (not near-zero)
if curr_disequilibrium >= open_threshold:
open_row_index = row_idx
break
# If no row with sufficient disequilibrium found, skip this pair
if open_row_index is None:
print(f"{pair}: Insufficient disequilibrium in testing dataset. Skipping.")
return pd.DataFrame()
# Look for close signal starting from the open position
trading_signals_df = (
predicted_df["scaled_disequilibrium"][open_row_index:] < close_threshold
)
# Adjust indices to account for the offset from open_row_index
close_row_index = None
for idx, value in trading_signals_df.items():
if value:
close_row_index = idx
break
open_row = predicted_df.loc[open_row_index]
open_px_a = predicted_df.at[open_row_index, f"{colname_a}"]
open_px_b = predicted_df.at[open_row_index, f"{colname_b}"]
open_tstamp = predicted_df.at[open_row_index, "tstamp"]
open_disequilibrium = open_row["disequilibrium"]
open_scaled_disequilibrium = open_row["scaled_disequilibrium"]
abs_beta = abs(beta[1])
pred_px_b = predicted_df.loc[open_row_index][f"{colname_b}_pred"]
pred_px_a = predicted_df.loc[open_row_index][f"{colname_a}_pred"]
if pred_px_b * abs_beta - pred_px_a > 0:
open_side_a = "BUY"
open_side_b = "SELL"
close_side_a = "SELL"
close_side_b = "BUY"
else:
open_side_b = "BUY"
open_side_a = "SELL"
close_side_b = "SELL"
close_side_a = "BUY"
# If no close signal found, print position and unrealized PnL
if close_row_index is None:
last_row_index = len(predicted_df) - 1
# Use the new method from BacktestResult to handle outstanding positions
result.handle_outstanding_position(
pair=pair,
pair_result_df=predicted_df,
last_row_index=last_row_index,
open_side_a=open_side_a,
open_side_b=open_side_b,
open_px_a=float(open_px_a),
open_px_b=float(open_px_b),
open_tstamp=pd.Timestamp(open_tstamp),
)
# Return only open trades (no close trades)
trd_signal_tuples = [
(
open_tstamp,
open_side_a,
pair.symbol_a_,
open_px_a,
open_disequilibrium,
open_scaled_disequilibrium,
pair,
),
(
open_tstamp,
open_side_b,
pair.symbol_b_,
open_px_b,
open_disequilibrium,
open_scaled_disequilibrium,
pair,
),
]
else:
# Close signal found - create complete trade
close_row = predicted_df.loc[close_row_index]
close_tstamp = close_row["tstamp"]
close_disequilibrium = close_row["disequilibrium"]
close_scaled_disequilibrium = close_row["scaled_disequilibrium"]
close_px_a = close_row[f"{colname_a}"]
close_px_b = close_row[f"{colname_b}"]
print(f"{pair}: Close signal found at index {close_row_index}")
trd_signal_tuples = [
(
open_tstamp,
open_side_a,
pair.symbol_a_,
open_px_a,
open_disequilibrium,
open_scaled_disequilibrium,
pair,
),
(
open_tstamp,
open_side_b,
pair.symbol_b_,
open_px_b,
open_disequilibrium,
open_scaled_disequilibrium,
pair,
),
(
close_tstamp,
close_side_a,
pair.symbol_a_,
close_px_a,
close_disequilibrium,
close_scaled_disequilibrium,
pair,
),
(
close_tstamp,
close_side_b,
pair.symbol_b_,
close_px_b,
close_disequilibrium,
close_scaled_disequilibrium,
pair,
),
]
# Add tuples to data frame with explicit dtypes to avoid concatenation warnings
df = pd.DataFrame(
trd_signal_tuples,
columns=self.TRADES_COLUMNS,
)
# Ensure consistent dtypes
return df.astype({
"time": "datetime64[ns]",
"action": "string",
"symbol": "string",
"price": "float64",
"disequilibrium": "float64",
"scaled_disequilibrium": "float64",
"pair": "object"
})
def reset(self) -> None:
pass


@@ -1,79 +1,14 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Dict, List, Optional
import pandas as pd # type:ignore
from statsmodels.tsa.vector_ar.vecm import VECM, VECMResults # type:ignore
class PairState(Enum):
INITIAL = 1
OPEN = 2
CLOSE = 3
CLOSE_POSITION = 4
CLOSE_STOP_LOSS = 5
CLOSE_STOP_PROFIT = 6
class CointegrationData:
EG_PVALUE_THRESHOLD = 0.05
tstamp_: pd.Timestamp
pair_: str
eg_pvalue_: float
johansen_lr1_: float
johansen_cvt_: float
eg_is_cointegrated_: bool
johansen_is_cointegrated_: bool
def __init__(self, pair: TradingPair):
training_df = pair.training_df_
assert training_df is not None
from statsmodels.tsa.vector_ar.vecm import coint_johansen
df = training_df[pair.colnames()].reset_index(drop=True)
# Run Johansen cointegration test
result = coint_johansen(df, det_order=0, k_ar_diff=1)
self.johansen_lr1_ = result.lr1[0]
self.johansen_cvt_ = result.cvt[0, 1]
self.johansen_is_cointegrated_ = self.johansen_lr1_ > self.johansen_cvt_
# Run Engle-Granger cointegration test
from statsmodels.tsa.stattools import coint # type: ignore
col1, col2 = pair.colnames()
assert training_df is not None
series1 = training_df[col1].reset_index(drop=True)
series2 = training_df[col2].reset_index(drop=True)
self.eg_pvalue_ = float(coint(series1, series2)[1])
self.eg_is_cointegrated_ = bool(self.eg_pvalue_ < self.EG_PVALUE_THRESHOLD)
self.tstamp_ = training_df.index[-1]
self.pair_ = pair.name()
def to_dict(self) -> Dict[str, Any]:
return {
"tstamp": self.tstamp_,
"pair": self.pair_,
"eg_pvalue": self.eg_pvalue_,
"johansen_lr1": self.johansen_lr1_,
"johansen_cvt": self.johansen_cvt_,
"eg_is_cointegrated": self.eg_is_cointegrated_,
"johansen_is_cointegrated": self.johansen_is_cointegrated_,
}
def __repr__(self) -> str:
return f"CointegrationData(tstamp={self.tstamp_}, pair={self.pair_}, eg_pvalue={self.eg_pvalue_}, johansen_lr1={self.johansen_lr1_}, johansen_cvt={self.johansen_cvt_}, eg_is_cointegrated={self.eg_is_cointegrated_}, johansen_is_cointegrated={self.johansen_is_cointegrated_})"
class TradingPair(ABC):
class TradingPair:
market_data_: pd.DataFrame
symbol_a_: str
symbol_b_: str
stat_model_price_: str
price_column_: str
training_mu_: float
training_std_: float
@@ -81,62 +16,39 @@ class TradingPair(ABC):
training_df_: pd.DataFrame
testing_df_: pd.DataFrame
vecm_fit_: VECMResults
user_data_: Dict[str, Any]
# predicted_df_: Optional[pd.DataFrame]
predicted_df_: Optional[pd.DataFrame]
def __init__(
self,
config: Dict[str, Any],
market_data: pd.DataFrame,
symbol_a: str,
symbol_b: str,
self, config: Dict[str, Any], market_data: pd.DataFrame, symbol_a: str, symbol_b: str, price_column: str
):
self.symbol_a_ = symbol_a
self.symbol_b_ = symbol_b
self.stat_model_price_ = config["stat_model_price"]
self.price_column_ = price_column
self.set_market_data(market_data)
self.user_data_ = {}
self.predicted_df_ = None
self.config_ = config
self._set_market_data(market_data)
def _set_market_data(self, market_data: pd.DataFrame) -> None:
def set_market_data(self, market_data: pd.DataFrame) -> None:
self.market_data_ = pd.DataFrame(
self._transform_dataframe(market_data)[["tstamp"] + self.colnames()]
)
self.market_data_ = self.market_data_.dropna().reset_index(drop=True)
self.market_data_["tstamp"] = pd.to_datetime(self.market_data_["tstamp"])
self.market_data_ = self.market_data_.sort_values("tstamp")
self._set_execution_price_data()
pass
def _set_execution_price_data(self) -> None:
if "execution_price" not in self.config_:
self.market_data_[f"exec_price_{self.symbol_a_}"] = self.market_data_[f"{self.stat_model_price_}_{self.symbol_a_}"]
self.market_data_[f"exec_price_{self.symbol_b_}"] = self.market_data_[f"{self.stat_model_price_}_{self.symbol_b_}"]
return
execution_price_column = self.config_["execution_price"]["column"]
execution_price_shift = self.config_["execution_price"]["shift"]
self.market_data_[f"exec_price_{self.symbol_a_}"] = self.market_data_[f"{self.stat_model_price_}_{self.symbol_a_}"].shift(-execution_price_shift)
self.market_data_[f"exec_price_{self.symbol_b_}"] = self.market_data_[f"{self.stat_model_price_}_{self.symbol_b_}"].shift(-execution_price_shift)
self.market_data_ = self.market_data_.dropna().reset_index(drop=True)
self.market_data_['tstamp'] = pd.to_datetime(self.market_data_['tstamp'])
self.market_data_ = self.market_data_.sort_values('tstamp')
def get_begin_index(self) -> int:
if "trading_hours" not in self.config_:
return 0
assert "timezone" in self.config_["trading_hours"]
assert "begin_session" in self.config_["trading_hours"]
start_time = (
pd.to_datetime(self.config_["trading_hours"]["begin_session"])
.tz_localize(self.config_["trading_hours"]["timezone"])
.time()
)
mask = self.market_data_["tstamp"].dt.time >= start_time
start_time = pd.to_datetime(self.config_["trading_hours"]["begin_session"]).tz_localize(self.config_["trading_hours"]["timezone"]).time()
mask = self.market_data_['tstamp'].dt.time >= start_time
return int(self.market_data_.index[mask].min())
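# Example (illustrative, assuming bar timestamps are in the session timezone):
# with begin_session "9:30:00", this returns the index of the first bar at or
# after 09:30; get_end_index below mirrors this for end_session.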
def get_end_index(self) -> int:
@@ -144,18 +56,14 @@ class TradingPair(ABC):
return 0
assert "timezone" in self.config_["trading_hours"]
assert "end_session" in self.config_["trading_hours"]
end_time = (
pd.to_datetime(self.config_["trading_hours"]["end_session"])
.tz_localize(self.config_["trading_hours"]["timezone"])
.time()
)
mask = self.market_data_["tstamp"].dt.time <= end_time
end_time = pd.to_datetime(self.config_["trading_hours"]["end_session"]).tz_localize(self.config_["trading_hours"]["timezone"]).time()
mask = self.market_data_['tstamp'].dt.time <= end_time
return int(self.market_data_.index[mask].max())
def _transform_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
# Select only the columns we need
df_selected: pd.DataFrame = pd.DataFrame(
df[["tstamp", "symbol", self.stat_model_price_]]
df[["tstamp", "symbol", self.price_column_]]
)
# Start with unique timestamps
@@ -173,13 +81,13 @@
)
# Create column name like "close-COIN"
new_price_column = f"{self.stat_model_price_}_{symbol}"
new_price_column = f"{self.price_column_}_{symbol}"
# Create temporary dataframe with timestamp and price
temp_df = pd.DataFrame(
{
"tstamp": df_symbol["tstamp"],
new_price_column: df_symbol[self.stat_model_price_],
new_price_column: df_symbol[self.price_column_],
}
)
@@ -200,7 +108,7 @@
testing_start_index = training_start_index + training_minutes
self.training_df_ = self.market_data_.iloc[
training_start_index:testing_start_index, :training_minutes
training_start_index:testing_start_index, : training_minutes
].copy()
assert self.training_df_ is not None
self.training_df_ = self.training_df_.dropna().reset_index(drop=True)
@@ -217,15 +125,82 @@
def colnames(self) -> List[str]:
return [
f"{self.stat_model_price_}_{self.symbol_a_}",
f"{self.stat_model_price_}_{self.symbol_b_}",
f"{self.price_column_}_{self.symbol_a_}",
f"{self.price_column_}_{self.symbol_b_}",
]
def exec_prices_colnames(self) -> List[str]:
return [
f"exec_price_{self.symbol_a_}",
f"exec_price_{self.symbol_b_}",
]
def fit_VECM(self) -> None:
assert self.training_df_ is not None
vecm_df = self.training_df_[self.colnames()].reset_index(drop=True)
vecm_model = VECM(vecm_df, coint_rank=1)
vecm_fit = vecm_model.fit()
assert vecm_fit is not None
# URGENT check beta and alpha
# Check if the model converged properly
if not hasattr(vecm_fit, "beta") or vecm_fit.beta is None:
print(f"{self}: VECM model failed to converge properly")
self.vecm_fit_ = vecm_fit
# print(f"{self}: beta={self.vecm_fit_.beta} alpha={self.vecm_fit_.alpha}" )
# print(f"{self}: {self.vecm_fit_.summary()}")
pass
def check_cointegration_johansen(self) -> bool:
assert self.training_df_ is not None
from statsmodels.tsa.vector_ar.vecm import coint_johansen
df = self.training_df_[self.colnames()].reset_index(drop=True)
result = coint_johansen(df, det_order=0, k_ar_diff=1)
# print(
# f"{self}: lr1={result.lr1[0]} > cvt={result.cvt[0, 1]}? {result.lr1[0] > result.cvt[0, 1]}"
# )
is_cointegrated: bool = bool(result.lr1[0] > result.cvt[0, 1])
return is_cointegrated
def check_cointegration_engle_granger(self) -> bool:
from statsmodels.tsa.stattools import coint
col1, col2 = self.colnames()
assert self.training_df_ is not None
series1 = self.training_df_[col1].reset_index(drop=True)
series2 = self.training_df_[col2].reset_index(drop=True)
# Run Engle-Granger cointegration test
pvalue = coint(series1, series2)[1]
# Define cointegration if p-value < 0.05 (i.e., reject null of no cointegration)
is_cointegrated: bool = bool(pvalue < 0.05)
# print(f"{self}: is_cointegrated={is_cointegrated} pvalue={pvalue}")
return is_cointegrated
def check_cointegration(self) -> bool:
is_cointegrated_johansen = self.check_cointegration_johansen()
is_cointegrated_engle_granger = self.check_cointegration_engle_granger()
result = is_cointegrated_johansen or is_cointegrated_engle_granger
return result or True # TODO: remove this
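# A minimal standalone sketch of the two tests combined above, on synthetic
# data where x tracks a random walk y plus small noise, so both tests should
# accept cointegration (series, seed and sizes are illustrative):
import numpy as np
import pandas as pd
from statsmodels.tsa.stattools import coint
from statsmodels.tsa.vector_ar.vecm import coint_johansen

rng = np.random.default_rng(0)
y = np.cumsum(rng.normal(size=500))
x = y + rng.normal(scale=0.1, size=500)
df = pd.DataFrame({"y": y, "x": x})

eg_pvalue = coint(df["y"], df["x"])[1]             # Engle-Granger p-value
johansen = coint_johansen(df, det_order=0, k_ar_diff=1)
print(eg_pvalue < 0.05, bool(johansen.lr1[0] > johansen.cvt[0, 1]))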
def train_pair(self) -> bool:
result = self.check_cointegration()
# print('*' * 80 + '\n' + f"**************** {self} IS COINTEGRATED ****************\n" + '*' * 80)
self.fit_VECM()
assert self.training_df_ is not None and self.vecm_fit_ is not None
diseq_series = self.training_df_[self.colnames()] @ self.vecm_fit_.beta
# print(diseq_series.shape)
self.training_mu_ = float(diseq_series[0].mean())
self.training_std_ = float(diseq_series[0].std())
self.training_df_["dis-equilibrium"] = (
self.training_df_[self.colnames()] @ self.vecm_fit_.beta
)
# Normalize the dis-equilibrium
self.training_df_["scaled_dis-equilibrium"] = (
diseq_series - self.training_mu_
) / self.training_std_
return result
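# A minimal numeric sketch of the scaling done in train_pair: the
# dis-equilibrium is the price matrix times the cointegration vector beta,
# then centered on the training mean and divided by the training standard
# deviation (the numbers and beta below are arbitrary):
import numpy as np

prices = np.array([[100.0, 50.0], [101.0, 50.4], [99.5, 49.8]])
beta = np.array([[1.0], [-2.0]])  # stand-in for vecm_fit_.beta
diseq = prices @ beta             # shape (3, 1)
scaled = (diseq - diseq.mean()) / diseq.std()
print(diseq.ravel(), scaled.ravel())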
def add_trades(self, trades: pd.DataFrame) -> None:
if self.user_data_["trades"] is None or len(self.user_data_["trades"]) == 0:
@ -248,11 +223,7 @@ class TradingPair(ABC):
trades[col] = pd.Timestamp.now()
elif col in ["action", "symbol"]:
trades[col] = ""
elif col in ["price", "disequilibrium", "scaled_disequilibrium"]:
trades[col] = 0.0
elif col == "pair":
trades[col] = None
@ -261,110 +232,59 @@ class TradingPair(ABC):
# Concatenate with explicit dtypes to avoid warnings
self.user_data_["trades"] = pd.concat(
[existing_trades, trades],
ignore_index=True,
copy=False
)
def get_trades(self) -> pd.DataFrame:
return self.user_data_["trades"] if "trades" in self.user_data_ else pd.DataFrame()
def predict(self) -> pd.DataFrame:
assert self.testing_df_ is not None
assert self.vecm_fit_ is not None
predicted_prices = self.vecm_fit_.predict(steps=len(self.testing_df_))
# Convert prediction to a DataFrame for readability
predicted_df = pd.DataFrame(
predicted_prices, columns=pd.Index(self.colnames()), dtype=float
)
predicted_df = pd.merge(
self.testing_df_.reset_index(drop=True),
pd.DataFrame(
predicted_prices, columns=pd.Index(self.colnames()), dtype=float
),
left_index=True,
right_index=True,
suffixes=("", "_pred"),
).dropna()
predicted_df["disequilibrium"] = (
predicted_df[self.colnames()] @ self.vecm_fit_.beta
)
predicted_df["scaled_disequilibrium"] = (
abs(predicted_df["disequilibrium"] - self.training_mu_)
/ self.training_std_
)
# print("*** PREDICTED DF")
# print(predicted_df)
# print("*" * 80)
# print("*** SELF.PREDICTED_DF")
# print(self.predicted_df_)
# print("*" * 80)
predicted_df = predicted_df.reset_index(drop=True)
if self.predicted_df_ is None:
self.predicted_df_ = predicted_df
else:
self.predicted_df_ = pd.concat([self.predicted_df_, predicted_df], ignore_index=True)
# Reset index to ensure proper indexing
self.predicted_df_ = self.predicted_df_.reset_index(drop=True)
return self.predicted_df_
def cointegration_check(self) -> Optional[pd.DataFrame]:
print(f"***{self}*** STARTING....")
config = self.config_
curr_training_start_idx = 0
COINTEGRATION_DATA_COLUMNS = {
"tstamp": "datetime64[ns]",
"pair": "string",
"eg_pvalue": "float64",
"johansen_lr1": "float64",
"johansen_cvt": "float64",
"eg_is_cointegrated": "bool",
"johansen_is_cointegrated": "bool",
}
# Initialize the result DataFrame with proper dtypes to avoid concatenation warnings
result: pd.DataFrame = pd.DataFrame(
columns=[col for col in COINTEGRATION_DATA_COLUMNS.keys()]
) # .astype(COINTEGRATION_DATA_COLUMNS)
training_minutes = config["training_minutes"]
while True:
print(curr_training_start_idx, end="\r")
self.get_datasets(
training_minutes=training_minutes,
training_start_index=curr_training_start_idx,
testing_size=1,
)
if len(self.training_df_) < training_minutes:
print(
f"{self}: current offset={curr_training_start_idx}"
f" * Training data length={len(self.training_df_)} < {training_minutes}"
" * Not enough training data. Completing the job."
)
break
new_row = pd.Series(CointegrationData(self).to_dict())
result.loc[len(result)] = new_row
curr_training_start_idx += 1
return result
def to_stop_close_conditions(self, predicted_row: pd.Series) -> bool:
config = self.config_
if (
"stop_close_conditions" not in config
or config["stop_close_conditions"] is None
):
return False
# Compute the running return once so both the profit and the loss checks can use it.
current_return = self._current_return(predicted_row)
#
# print(f"time={predicted_row['tstamp']} current_return={current_return}")
#
if "profit" in config["stop_close_conditions"]:
if current_return >= config["stop_close_conditions"]["profit"]:
print(f"STOP PROFIT: {current_return}")
self.user_data_["stop_close_state"] = PairState.CLOSE_STOP_PROFIT
return True
if "loss" in config["stop_close_conditions"]:
if current_return <= config["stop_close_conditions"]["loss"]:
print(f"STOP LOSS: {current_return}")
self.user_data_["stop_close_state"] = PairState.CLOSE_STOP_LOSS
return True
return False
def on_open_trades(self, trades: pd.DataFrame) -> None:
if "close_trades" in self.user_data_:
del self.user_data_["close_trades"]
self.user_data_["open_trades"] = trades
def on_close_trades(self, trades: pd.DataFrame) -> None:
del self.user_data_["open_trades"]
self.user_data_["close_trades"] = trades
def _current_return(self, predicted_row: pd.Series) -> float:
if "open_trades" in self.user_data_:
open_trades = self.user_data_["open_trades"]
if len(open_trades) == 0:
return 0.0
def _single_instrument_return(symbol: str) -> float:
instrument_open_trades = open_trades[open_trades["symbol"] == symbol]
instrument_open_price = instrument_open_trades["price"].iloc[0]
sign = -1 if instrument_open_trades["side"].iloc[0] == "SELL" else 1
instrument_price = predicted_row[f"{self.stat_model_price_}_{symbol}"]
instrument_return = (
sign
* (instrument_price - instrument_open_price)
/ instrument_open_price
)
return float(instrument_return) * 100.0
instrument_a_return = _single_instrument_return(self.symbol_a_)
instrument_b_return = _single_instrument_return(self.symbol_b_)
return instrument_a_return + instrument_b_return
return 0.0
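# A minimal standalone sketch of the two-leg percentage return computed by
# _current_return: each leg contributes its signed move relative to its open
# price, and the pair return is the sum (fills below are invented):
open_px = {"COIN": 200.0, "GBTC": 80.0}
side = {"COIN": "BUY", "GBTC": "SELL"}
last_px = {"COIN": 204.0, "GBTC": 79.0}

def leg_return(sym: str) -> float:
    sign = -1.0 if side[sym] == "SELL" else 1.0
    return sign * (last_px[sym] - open_px[sym]) / open_px[sym] * 100.0

print(leg_return("COIN") + leg_return("GBTC"))  # 2.0 + 1.25 = 3.25 (%)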
def __repr__(self) -> str:
return self.name()
@ -372,9 +292,3 @@ class TradingPair(ABC):
def name(self) -> str:
return f"{self.symbol_a_} & {self.symbol_b_}"
# return f"{self.symbol_a_} & {self.symbol_b_}"
@abstractmethod
def predict(self) -> pd.DataFrame: ...
# @abstractmethod
# def predicted_df(self) -> Optional[pd.DataFrame]: ...

View File

@ -1,122 +0,0 @@
from typing import Any, Dict, Optional, cast
import pandas as pd
from pt_trading.results import BacktestResult
from pt_trading.rolling_window_fit import RollingFit
from pt_trading.trading_pair import TradingPair
from statsmodels.tsa.vector_ar.vecm import VECM, VECMResults
NanoPerMin = 1e9
class VECMTradingPair(TradingPair):
vecm_fit_: Optional[VECMResults]
pair_predict_result_: Optional[pd.DataFrame]
def __init__(
self,
config: Dict[str, Any],
market_data: pd.DataFrame,
symbol_a: str,
symbol_b: str,
):
super().__init__(config, market_data, symbol_a, symbol_b)
self.vecm_fit_ = None
self.pair_predict_result_ = None
def _train_pair(self) -> None:
self._fit_VECM()
assert self.vecm_fit_ is not None
diseq_series = self.training_df_[self.colnames()] @ self.vecm_fit_.beta
# print(diseq_series.shape)
self.training_mu_ = float(diseq_series[0].mean())
self.training_std_ = float(diseq_series[0].std())
self.training_df_["dis-equilibrium"] = (
self.training_df_[self.colnames()] @ self.vecm_fit_.beta
)
# Normalize the dis-equilibrium
self.training_df_["scaled_dis-equilibrium"] = (
diseq_series - self.training_mu_
) / self.training_std_
def _fit_VECM(self) -> None:
assert self.training_df_ is not None
vecm_df = self.training_df_[self.colnames()].reset_index(drop=True)
vecm_model = VECM(vecm_df, coint_rank=1)
vecm_fit = vecm_model.fit()
assert vecm_fit is not None
# URGENT check beta and alpha
# Check if the model converged properly
if not hasattr(vecm_fit, "beta") or vecm_fit.beta is None:
print(f"{self}: VECM model failed to converge properly")
self.vecm_fit_ = vecm_fit
pass
def predict(self) -> pd.DataFrame:
self._train_pair()
assert self.testing_df_ is not None
assert self.vecm_fit_ is not None
predicted_prices = self.vecm_fit_.predict(steps=len(self.testing_df_))
# Convert prediction to a DataFrame for readability
predicted_df = pd.DataFrame(
predicted_prices, columns=pd.Index(self.colnames()), dtype=float
)
predicted_df = pd.merge(
self.testing_df_.reset_index(drop=True),
pd.DataFrame(
predicted_prices, columns=pd.Index(self.colnames()), dtype=float
),
left_index=True,
right_index=True,
suffixes=("", "_pred"),
).dropna()
predicted_df["disequilibrium"] = (
predicted_df[self.colnames()] @ self.vecm_fit_.beta
)
predicted_df["signed_scaled_disequilibrium"] = (
predicted_df["disequilibrium"] - self.training_mu_
) / self.training_std_
predicted_df["scaled_disequilibrium"] = abs(
predicted_df["signed_scaled_disequilibrium"]
)
predicted_df = predicted_df.reset_index(drop=True)
if self.pair_predict_result_ is None:
self.pair_predict_result_ = predicted_df
else:
self.pair_predict_result_ = pd.concat(
[self.pair_predict_result_, predicted_df], ignore_index=True
)
# Reset index to ensure proper indexing
self.pair_predict_result_ = self.pair_predict_result_.reset_index(drop=True)
return self.pair_predict_result_
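# A minimal standalone sketch of the statsmodels VECM calls used in this
# class: fit a rank-1 VECM on two synthetic cointegrated series, read the
# cointegration vector beta, and forecast a few steps ahead (series and
# sizes are illustrative):
import numpy as np
import pandas as pd
from statsmodels.tsa.vector_ar.vecm import VECM

rng = np.random.default_rng(1)
base = np.cumsum(rng.normal(size=300))
data = pd.DataFrame({
    "a": base + rng.normal(scale=0.2, size=300),
    "b": 0.5 * base + rng.normal(scale=0.2, size=300),
})
fit = VECM(data, coint_rank=1).fit()
print(fit.beta.flatten())    # cointegration vector
print(fit.predict(steps=3))  # 3-step forecast, shape (3, 2)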
class VECMRollingFit(RollingFit):
def __init__(self) -> None:
super().__init__()
def create_trading_pair(
self,
config: Dict,
market_data: pd.DataFrame,
symbol_a: str,
symbol_b: str,
) -> TradingPair:
return VECMTradingPair(
config=config,
market_data=market_data,
symbol_a=symbol_a,
symbol_b=symbol_b,
)

View File

@ -1,85 +0,0 @@
from typing import Any, Dict, Optional, cast
import pandas as pd
from pt_trading.results import BacktestResult
from pt_trading.rolling_window_fit import RollingFit
from pt_trading.trading_pair import TradingPair
import statsmodels.api as sm
NanoPerMin = 1e9
class ZScoreTradingPair(TradingPair):
zscore_model_: Optional[sm.regression.linear_model.RegressionResultsWrapper]
pair_predict_result_: Optional[pd.DataFrame]
zscore_df_: Optional[pd.DataFrame]
def __init__(
self,
config: Dict[str, Any],
market_data: pd.DataFrame,
symbol_a: str,
symbol_b: str,
):
super().__init__(config, market_data, symbol_a, symbol_b)
self.zscore_model_ = None
self.pair_predict_result_ = None
self.zscore_df_ = None
def _fit_zscore(self) -> None:
assert self.training_df_ is not None
symbol_a_px_series = self.training_df_[self.colnames()].iloc[:, 0]
symbol_b_px_series = self.training_df_[self.colnames()].iloc[:, 1]
symbol_a_px_series, symbol_b_px_series = symbol_a_px_series.align(
symbol_b_px_series, axis=0
)
X = sm.add_constant(symbol_b_px_series)
self.zscore_model_ = sm.OLS(symbol_a_px_series, X).fit()
assert self.zscore_model_ is not None
hedge_ratio = self.zscore_model_.params.iloc[1]
# Calculate spread and Z-score
spread = symbol_a_px_series - hedge_ratio * symbol_b_px_series
self.zscore_df_ = (spread - spread.mean()) / spread.std()
def predict(self) -> pd.DataFrame:
self._fit_zscore()
assert self.zscore_df_ is not None
self.training_df_["dis-equilibrium"] = self.zscore_df_
self.training_df_["scaled_dis-equilibrium"] = abs(self.zscore_df_)
assert self.testing_df_ is not None
assert self.zscore_df_ is not None
predicted_df = self.testing_df_
predicted_df["disequilibrium"] = self.zscore_df_
predicted_df["signed_scaled_disequilibrium"] = self.zscore_df_
predicted_df["scaled_disequilibrium"] = abs(self.zscore_df_)
predicted_df = predicted_df.reset_index(drop=True)
if self.pair_predict_result_ is None:
self.pair_predict_result_ = predicted_df
else:
self.pair_predict_result_ = pd.concat(
[self.pair_predict_result_, predicted_df], ignore_index=True
)
# Reset index to ensure proper indexing
self.pair_predict_result_ = self.pair_predict_result_.reset_index(drop=True)
return self.pair_predict_result_.dropna()
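# A minimal standalone sketch of the hedge-ratio / z-score construction in
# _fit_zscore: regress A on B with a constant, form the spread
# A - hedge_ratio * B, and standardize it (data below is synthetic):
import numpy as np
import pandas as pd
import statsmodels.api as sm

rng = np.random.default_rng(2)
b = pd.Series(np.cumsum(rng.normal(size=200)) + 50.0)
a = 1.5 * b + rng.normal(scale=0.3, size=200)
ols = sm.OLS(a, sm.add_constant(b)).fit()
hedge_ratio = ols.params.iloc[1]
spread = a - hedge_ratio * b
zscore = (spread - spread.mean()) / spread.std()
print(float(hedge_ratio), float(zscore.abs().max()))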
class ZScoreRollingFit(RollingFit):
def __init__(self) -> None:
super().__init__()
def create_trading_pair(
self, config: Dict, market_data: pd.DataFrame, symbol_a: str, symbol_b: str
) -> TradingPair:
return ZScoreTradingPair(
config=config,
market_data=market_data,
symbol_a=symbol_a,
symbol_b=symbol_b,
)

View File

@ -1,17 +1,10 @@
from __future__ import annotations
import sqlite3
from typing import Dict, List, cast
import pandas as pd
def load_sqlite_to_dataframe(db_path, query):
try:
conn = sqlite3.connect(db_path)
@ -28,14 +21,13 @@ def load_sqlite_to_dataframe(db_path:str, query:str) -> pd.DataFrame:
conn.close()
def convert_time_to_UTC(value: str, timezone: str) -> str:
from zoneinfo import ZoneInfo
from datetime import datetime
# Parse it to naive datetime object
local_dt = datetime.strptime(value, "%Y-%m-%d %H:%M:%S")
zinfo = ZoneInfo(timezone)
result: datetime = local_dt.replace(tzinfo=zinfo).astimezone(ZoneInfo("UTC"))
@ -43,28 +35,25 @@ def convert_time_to_UTC(value: str, timezone: str, extra_minutes: int = 0) -> st
return result.strftime("%Y-%m-%d %H:%M:%S")
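# Usage sketch for convert_time_to_UTC: a 09:30 New York session open on a
# May date maps to 13:30 UTC while Eastern Daylight Time (UTC-4) is in effect.
print(convert_time_to_UTC("2025-05-09 09:30:00", "America/New_York"))
# -> "2025-05-09 13:30:00"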
def load_market_data(datafile: str, config: Dict) -> pd.DataFrame:
from tools.data_loader import load_sqlite_to_dataframe
instrument_ids = [
'"' + config["instrument_id_pfx"] + instrument + '"'
for instrument in config["instruments"]
]
security_type = config["security_type"]
exchange_id = config["exchange_id"]
query = "select"
query += " tstamp"
query += ", tstamp_ns as time_ns"
if security_type == "CRYPTO":
query += " strftime('%Y-%m-%d %H:%M:%S', tstamp_ns/1000000000, 'unixepoch') as tstamp"
query += ", tstamp as time_ns"
else:
query += " tstamp"
query += ", tstamp_ns as time_ns"
query += f", substr(instrument_id, instr(instrument_id, '-') + 1) as symbol"
query += f", substr(instrument_id, {len(config['instrument_id_pfx']) + 1}) as symbol"
query += ", open"
query += ", high"
query += ", low"
@ -73,76 +62,74 @@ def load_market_data(
query += ", num_trades"
query += ", vwap"
query += f" from {db_table_name}"
query += f" where exchange_id in ({','.join(exchange_ids)})"
query += f" from {config['db_table_name']}"
query += f" where exchange_id ='{exchange_id}'"
query += f" and instrument_id in ({','.join(instrument_ids)})"
df = load_sqlite_to_dataframe(db_path=datafile, query=query)
# Trading Hours
date_str = df["tstamp"][0][0:10]
trading_hours = config["trading_hours"]
start_time = convert_time_to_UTC(
f"{date_str} {trading_hours['begin_session']}", trading_hours["timezone"]
)
end_time = convert_time_to_UTC(
f"{date_str} {trading_hours['end_session']}", trading_hours["timezone"]
)
# Perform boolean selection
df = df[(df["tstamp"] >= start_time) & (df["tstamp"] <= end_time)]
df["tstamp"] = pd.to_datetime(df["tstamp"])
return cast(pd.DataFrame, df)
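# A usage sketch for load_market_data, assuming an equity database laid out
# like the added equity.cfg (table md_1min_bars, "STOCK-" prefix). The file
# name and instrument list are placeholders:
example_config = {
    "security_type": "EQUITY",
    "db_table_name": "md_1min_bars",
    "exchange_id": "ALPACA",
    "instrument_id_pfx": "STOCK-",
    "instruments": ["COIN", "GBTC"],
    "trading_hours": {
        "begin_session": "9:30:00",
        "end_session": "16:00:00",
        "timezone": "America/New_York",
    },
}
# df = load_market_data("./data/equity/20250509.mktdata.ohlcv.db", example_config)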
def get_available_instruments_from_db(datafile: str, config: Dict) -> List[str]:
"""
Auto-detect available instruments from the database by querying distinct instrument_id values.
Returns instruments without the configured prefix.
"""
try:
conn = sqlite3.connect(datafile)
# Build exclusion list with full instrument_ids
exclude_instruments = config.get("exclude_instruments", [])
prefix = config.get("instrument_id_pfx", "")
exclude_instrument_ids = [f"{prefix}{inst}" for inst in exclude_instruments]
# Query to get distinct instrument_ids
query = f"""
SELECT DISTINCT instrument_id
FROM {config['db_table_name']}
WHERE exchange_id = ?
"""
# Add exclusion clause if there are instruments to exclude
if exclude_instrument_ids:
placeholders = ','.join(['?' for _ in exclude_instrument_ids])
query += f" AND instrument_id NOT IN ({placeholders})"
cursor = conn.execute(query, (config["exchange_id"],) + tuple(exclude_instrument_ids))
else:
cursor = conn.execute(query, (config["exchange_id"],))
instrument_ids = [row[0] for row in cursor.fetchall()]
conn.close()
# Remove the configured prefix to get instrument symbols
instruments = []
for instrument_id in instrument_ids:
if instrument_id.startswith(prefix):
symbol = instrument_id[len(prefix) :]
instruments.append(symbol)
else:
instruments.append(instrument_id)
return sorted(instruments)
except Exception as e:
print(f"Error auto-detecting instruments from {datafile}: {str(e)}")
return []
# if __name__ == "__main__":

View File

@ -74,7 +74,6 @@ PyYAML>=6.0
reportlab>=3.6.8
requests>=2.25.1
requests-file>=1.5.1
scipy<1.13.0
seaborn>=0.13.2
SecretStorage>=3.3.1
setproctitle>=1.2.2

View File

@ -1,126 +0,0 @@
import argparse
import glob
import importlib
import os
from datetime import date, datetime
from typing import Any, Dict, List, Optional
import pandas as pd
from tools.config import expand_filename, load_config
from tools.data_loader import get_available_instruments_from_db
from pt_trading.results import (
BacktestResult,
create_result_database,
store_config_in_database,
store_results_in_database,
)
from pt_trading.fit_method import PairsTradingFitMethod
from pt_trading.trading_pair import TradingPair
from research.research_tools import create_pairs, resolve_datafiles
def main() -> None:
parser = argparse.ArgumentParser(description="Run pairs trading backtest.")
parser.add_argument(
"--config", type=str, required=True, help="Path to the configuration file."
)
parser.add_argument(
"--datafile",
type=str,
required=False,
help="Market data file to process.",
)
parser.add_argument(
"--instruments",
type=str,
required=False,
help="Comma-separated list of instrument symbols (e.g., COIN,GBTC). If not provided, auto-detects from database.",
)
args = parser.parse_args()
config: Dict = load_config(args.config)
# Resolve data files (CLI takes priority over config)
datafiles = resolve_datafiles(config, args.datafile)
if not datafiles:
print("No data files found to process.")
return
datafile = datafiles[0]
print(f"Processing data file: {datafile}")
# # Create result database if needed
# if args.result_db.upper() != "NONE":
# args.result_db = expand_filename(args.result_db)
# create_result_database(args.result_db)
# # Initialize a dictionary to store all trade results
# all_results: Dict[str, Dict[str, Any]] = {}
# # Store configuration in database for reference
# if args.result_db.upper() != "NONE":
# # Get list of all instruments for storage
# all_instruments = []
# for datafile in datafiles:
# if args.instruments:
# file_instruments = [
# inst.strip() for inst in args.instruments.split(",")
# ]
# else:
# file_instruments = get_available_instruments_from_db(datafile, config)
# all_instruments.extend(file_instruments)
# # Remove duplicates while preserving order
# unique_instruments = list(dict.fromkeys(all_instruments))
# store_config_in_database(
# db_path=args.result_db,
# config_file_path=args.config,
# config=config,
# fit_method_class=fit_method_class_name,
# datafiles=datafiles,
# instruments=unique_instruments,
# )
# Process each data file
stat_model_price = config["stat_model_price"]
print(f"\n====== Processing {os.path.basename(datafile)} ======")
# Determine instruments to use
if args.instruments:
# Use CLI-specified instruments
instruments = [inst.strip() for inst in args.instruments.split(",")]
print(f"Using CLI-specified instruments: {instruments}")
else:
# Auto-detect instruments from database
instruments = get_available_instruments_from_db(datafile, config)
print(f"Auto-detected instruments: {instruments}")
if not instruments:
print(f"No instruments found in {datafile}...")
return
# Process data for this file
try:
cointegration_data: pd.DataFrame = pd.DataFrame()
for pair in create_pairs(datafile, stat_model_price, config, instruments):
cointegration_data = pd.concat([cointegration_data, pair.cointegration_check()])
pd.set_option('display.width', 400)
pd.set_option('display.max_colwidth', None)
pd.set_option('display.max_columns', None)
with pd.option_context('display.max_rows', None, 'display.max_columns', None):
print(f"cointegration_data:\n{cointegration_data}")
except Exception as err:
print(f"Error processing {datafile}: {str(err)}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,771 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Pairs Trading Visualization Notebook\n",
"\n",
"This notebook allows you to visualize pairs trading strategies on individual instrument pairs.\n",
"You can examine the relationship between two instruments, their dis-equilibrium, and trading signals."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 🎯 Key Features:\n",
"\n",
"1. **Interactive Configuration**: \n",
" - Easy switching between CRYPTO and EQUITY configurations\n",
" - Simple parameter adjustment for thresholds and training periods\n",
"\n",
"2. **Single Pair Focus**: \n",
" - Instead of running multiple pairs, focuses on one pair at a time\n",
" - Allows deep analysis of the relationship between two instruments\n",
"\n",
"3. **Step-by-Step Visualization**:\n",
" - **Raw price data**: Individual prices, normalized comparison, and price ratios\n",
" - **Training analysis**: Cointegration testing and VECM model fitting\n",
" - **Dis-equilibrium visualization**: Both raw and scaled dis-equilibrium with threshold lines\n",
" - **Strategy execution**: Trading signal generation and visualization\n",
" - **Prediction analysis**: Actual vs predicted prices with trading signals overlaid\n",
"\n",
"4. **Rich Analytics**:\n",
" - Cointegration status and VECM model details\n",
" - Statistical summaries for all stages\n",
" - Threshold crossing analysis\n",
" - Trading signal breakdown\n",
"\n",
"5. **Interactive Experimentation**:\n",
" - Easy parameter modification\n",
" - Re-run capabilities for different configurations\n",
" - Support for both StaticFitStrategy and SlidingFitStrategy\n",
"\n",
"### 🚀 How to Use:\n",
"\n",
"1. **Start Jupyter**:\n",
" ```bash\n",
" cd src/notebooks\n",
" jupyter notebook pairs_trading_visualization.ipynb\n",
" ```\n",
"\n",
"2. **Customize Your Analysis**:\n",
" - Change `SYMBOL_A` and `SYMBOL_B` to your desired trading pair\n",
" - Switch between `CRYPTO_CONFIG` and `EQT_CONFIG`\n",
" - Only **StaticFitStrategy** is supported. \n",
" - Adjust thresholds and parameters as needed\n",
"\n",
"3. **Run and Visualize**:\n",
" - Execute cells step by step to see the analysis unfold\n",
" - Rich matplotlib visualizations show relationships and signals\n",
" - Comprehensive summary at the end\n",
"\n",
"The notebook provides exactly what you requested - a way to visualize the relationship between two instruments and their scaled dis-equilibrium, with all the stages of your pairs trading strategy clearly displayed and analyzed.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup and Imports"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Setup complete!\n"
]
}
],
"source": [
"import sys\n",
"import os\n",
"sys.path.append('..')\n",
"\n",
"import pandas as pd\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"from typing import Dict, List, Optional\n",
"\n",
"# Import our modules\n",
"from pt_trading.fit_methods import StaticFit, SlidingFit\n",
"from tools.data_loader import load_market_data\n",
"from pt_trading.trading_pair import TradingPair\n",
"from pt_trading.results import BacktestResult\n",
"\n",
"# Set plotting style\n",
"plt.style.use('seaborn-v0_8')\n",
"sns.set_palette(\"husl\")\n",
"plt.rcParams['figure.figsize'] = (12, 8)\n",
"\n",
"print(\"Setup complete!\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Configuration"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using EQUITY configuration\n",
"Available instruments: ['COIN', 'GBTC', 'HOOD', 'MSTR', 'PYPL']\n"
]
}
],
"source": [
"# Configuration - Choose between CRYPTO_CONFIG or EQT_CONFIG\n",
"\n",
"CRYPTO_CONFIG = {\n",
" \"security_type\": \"CRYPTO\",\n",
" \"data_directory\": \"../../data/crypto\",\n",
" \"datafiles\": [\n",
" \"20250519.mktdata.ohlcv.db\",\n",
" ],\n",
" \"db_table_name\": \"bnbspot_ohlcv_1min\",\n",
" \"exchange_id\": \"BNBSPOT\",\n",
" \"instrument_id_pfx\": \"PAIR-\",\n",
" \"instruments\": [\n",
" \"BTC-USDT\",\n",
" \"BCH-USDT\",\n",
" \"ETH-USDT\",\n",
" \"LTC-USDT\",\n",
" \"XRP-USDT\",\n",
" \"ADA-USDT\",\n",
" \"SOL-USDT\",\n",
" \"DOT-USDT\",\n",
" ],\n",
" \"trading_hours\": {\n",
" \"begin_session\": \"00:00:00\",\n",
" \"end_session\": \"23:59:00\",\n",
" \"timezone\": \"UTC\",\n",
" },\n",
" \"price_column\": \"close\",\n",
" \"min_required_points\": 30,\n",
" \"zero_threshold\": 1e-10,\n",
" \"dis-equilibrium_open_trshld\": 2.0,\n",
" \"dis-equilibrium_close_trshld\": 0.5,\n",
" \"training_minutes\": 120,\n",
" \"funding_per_pair\": 2000.0,\n",
"}\n",
"\n",
"EQT_CONFIG = {\n",
" \"security_type\": \"EQUITY\",\n",
" \"data_directory\": \"../../data/equity\",\n",
" \"datafiles\": {\n",
" \"0508\": \"20250508.alpaca_sim_md.db\",\n",
" \"0509\": \"20250509.alpaca_sim_md.db\",\n",
" \"0510\": \"20250510.alpaca_sim_md.db\",\n",
" \"0511\": \"20250511.alpaca_sim_md.db\",\n",
" \"0512\": \"20250512.alpaca_sim_md.db\",\n",
" \"0513\": \"20250513.alpaca_sim_md.db\",\n",
" \"0514\": \"20250514.alpaca_sim_md.db\",\n",
" \"0515\": \"20250515.alpaca_sim_md.db\",\n",
" \"0516\": \"20250516.alpaca_sim_md.db\",\n",
" \"0517\": \"20250517.alpaca_sim_md.db\",\n",
" \"0518\": \"20250518.alpaca_sim_md.db\",\n",
" \"0519\": \"20250519.alpaca_sim_md.db\",\n",
" \"0520\": \"20250520.alpaca_sim_md.db\",\n",
" \"0521\": \"20250521.alpaca_sim_md.db\",\n",
" \"0522\": \"20250522.alpaca_sim_md.db\",\n",
" },\n",
" \"db_table_name\": \"md_1min_bars\",\n",
" \"exchange_id\": \"ALPACA\",\n",
" \"instrument_id_pfx\": \"STOCK-\",\n",
" \"instruments\": [\n",
" \"COIN\",\n",
" \"GBTC\",\n",
" \"HOOD\",\n",
" \"MSTR\",\n",
" \"PYPL\",\n",
" ],\n",
" \"trading_hours\": {\n",
" \"begin_session\": \"9:30:00\",\n",
" \"end_session\": \"16:00:00\",\n",
" \"timezone\": \"America/New_York\",\n",
" },\n",
" \"price_column\": \"close\",\n",
" \"min_required_points\": 30,\n",
" \"zero_threshold\": 1e-10,\n",
" \"dis-equilibrium_open_trshld\": 2.0,\n",
" \"dis-equilibrium_close_trshld\": 1.0, #0.5,\n",
" \"training_minutes\": 120,\n",
" \"funding_per_pair\": 2000.0,\n",
"}\n",
"\n",
"# Choose your configuration\n",
"CONFIG = EQT_CONFIG # Change to CRYPTO_CONFIG if you want to use crypto data\n",
"\n",
"print(f\"Using {CONFIG['security_type']} configuration\")\n",
"print(f\"Available instruments: {CONFIG['instruments']}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Select Trading Pair and Data File"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Selected pair: COIN & GBTC\n",
"Data file: 20250509.alpaca_sim_md.db\n",
"Strategy: StaticFitStrategy\n"
]
}
],
"source": [
"# Select your trading pair and strategy\n",
"SYMBOL_A = \"COIN\" # Change these to your desired symbols\n",
"SYMBOL_B = \"GBTC\"\n",
"DATA_FILE = CONFIG[\"datafiles\"][\"0509\"]\n",
"\n",
"# Choose strategy\n",
"FIT_METHOD = StaticFit()\n",
"\n",
"print(f\"Selected pair: {SYMBOL_A} & {SYMBOL_B}\")\n",
"print(f\"Data file: {DATA_FILE}\")\n",
"print(f\"Strategy: {type(FIT_METHOD).__name__}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load Market Data"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Current working directory: /home/oleg/devel/pairs_trading/src/notebooks\n",
"Loading data from: ../../data/equity/20250509.alpaca_sim_md.db\n",
"Error: Execution failed on sql 'select tstamp, tstamp_ns as time_ns, substr(instrument_id, 7) as symbol, open, high, low, close, volume, num_trades, vwap from md_1min_bars where exchange_id ='ALPACA' and instrument_id in (\"STOCK-COIN\",\"STOCK-GBTC\",\"STOCK-HOOD\",\"STOCK-MSTR\",\"STOCK-PYPL\")': no such table: md_1min_bars\n"
]
},
{
"ename": "Exception",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
"\u001b[31mOperationalError\u001b[39m Traceback (most recent call last)",
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/python3.12-venv/lib/python3.12/site-packages/pandas/io/sql.py:2664\u001b[39m, in \u001b[36mSQLiteDatabase.execute\u001b[39m\u001b[34m(self, sql, params)\u001b[39m\n\u001b[32m 2663\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m2664\u001b[39m \u001b[43mcur\u001b[49m\u001b[43m.\u001b[49m\u001b[43mexecute\u001b[49m\u001b[43m(\u001b[49m\u001b[43msql\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 2665\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m cur\n",
"\u001b[31mOperationalError\u001b[39m: no such table: md_1min_bars",
"\nThe above exception was the direct cause of the following exception:\n",
"\u001b[31mDatabaseError\u001b[39m Traceback (most recent call last)",
"\u001b[36mFile \u001b[39m\u001b[32m~/devel/pairs_trading/src/notebooks/../tools/data_loader.py:11\u001b[39m, in \u001b[36mload_sqlite_to_dataframe\u001b[39m\u001b[34m(db_path, query)\u001b[39m\n\u001b[32m 9\u001b[39m conn = sqlite3.connect(db_path)\n\u001b[32m---> \u001b[39m\u001b[32m11\u001b[39m df = \u001b[43mpd\u001b[49m\u001b[43m.\u001b[49m\u001b[43mread_sql_query\u001b[49m\u001b[43m(\u001b[49m\u001b[43mquery\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mconn\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 12\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m df\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/python3.12-venv/lib/python3.12/site-packages/pandas/io/sql.py:528\u001b[39m, in \u001b[36mread_sql_query\u001b[39m\u001b[34m(sql, con, index_col, coerce_float, params, parse_dates, chunksize, dtype, dtype_backend)\u001b[39m\n\u001b[32m 527\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m pandasSQL_builder(con) \u001b[38;5;28;01mas\u001b[39;00m pandas_sql:\n\u001b[32m--> \u001b[39m\u001b[32m528\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mpandas_sql\u001b[49m\u001b[43m.\u001b[49m\u001b[43mread_query\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 529\u001b[39m \u001b[43m \u001b[49m\u001b[43msql\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 530\u001b[39m \u001b[43m \u001b[49m\u001b[43mindex_col\u001b[49m\u001b[43m=\u001b[49m\u001b[43mindex_col\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 531\u001b[39m \u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[43m=\u001b[49m\u001b[43mparams\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 532\u001b[39m \u001b[43m \u001b[49m\u001b[43mcoerce_float\u001b[49m\u001b[43m=\u001b[49m\u001b[43mcoerce_float\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 533\u001b[39m \u001b[43m \u001b[49m\u001b[43mparse_dates\u001b[49m\u001b[43m=\u001b[49m\u001b[43mparse_dates\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 534\u001b[39m \u001b[43m \u001b[49m\u001b[43mchunksize\u001b[49m\u001b[43m=\u001b[49m\u001b[43mchunksize\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 535\u001b[39m \u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[43m=\u001b[49m\u001b[43mdtype\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 536\u001b[39m \u001b[43m \u001b[49m\u001b[43mdtype_backend\u001b[49m\u001b[43m=\u001b[49m\u001b[43mdtype_backend\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 537\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/python3.12-venv/lib/python3.12/site-packages/pandas/io/sql.py:2728\u001b[39m, in \u001b[36mSQLiteDatabase.read_query\u001b[39m\u001b[34m(self, sql, index_col, coerce_float, parse_dates, params, chunksize, dtype, dtype_backend)\u001b[39m\n\u001b[32m 2717\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mread_query\u001b[39m(\n\u001b[32m 2718\u001b[39m \u001b[38;5;28mself\u001b[39m,\n\u001b[32m 2719\u001b[39m sql,\n\u001b[32m (...)\u001b[39m\u001b[32m 2726\u001b[39m dtype_backend: DtypeBackend | Literal[\u001b[33m\"\u001b[39m\u001b[33mnumpy\u001b[39m\u001b[33m\"\u001b[39m] = \u001b[33m\"\u001b[39m\u001b[33mnumpy\u001b[39m\u001b[33m\"\u001b[39m,\n\u001b[32m 2727\u001b[39m ) -> DataFrame | Iterator[DataFrame]:\n\u001b[32m-> \u001b[39m\u001b[32m2728\u001b[39m cursor = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mexecute\u001b[49m\u001b[43m(\u001b[49m\u001b[43msql\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 2729\u001b[39m columns = [col_desc[\u001b[32m0\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m col_desc \u001b[38;5;129;01min\u001b[39;00m cursor.description]\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/python3.12-venv/lib/python3.12/site-packages/pandas/io/sql.py:2676\u001b[39m, in \u001b[36mSQLiteDatabase.execute\u001b[39m\u001b[34m(self, sql, params)\u001b[39m\n\u001b[32m 2675\u001b[39m ex = DatabaseError(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mExecution failed on sql \u001b[39m\u001b[33m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00msql\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m'\u001b[39m\u001b[33m: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mexc\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m-> \u001b[39m\u001b[32m2676\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m ex \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mexc\u001b[39;00m\n",
"\u001b[31mDatabaseError\u001b[39m: Execution failed on sql 'select tstamp, tstamp_ns as time_ns, substr(instrument_id, 7) as symbol, open, high, low, close, volume, num_trades, vwap from md_1min_bars where exchange_id ='ALPACA' and instrument_id in (\"STOCK-COIN\",\"STOCK-GBTC\",\"STOCK-HOOD\",\"STOCK-MSTR\",\"STOCK-PYPL\")': no such table: md_1min_bars",
"\nThe above exception was the direct cause of the following exception:\n",
"\u001b[31mException\u001b[39m Traceback (most recent call last)",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[5]\u001b[39m\u001b[32m, line 6\u001b[39m\n\u001b[32m 3\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mCurrent working directory: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mos.getcwd()\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m 4\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mLoading data from: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mdatafile_path\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m----> \u001b[39m\u001b[32m6\u001b[39m market_data_df = \u001b[43mload_market_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdatafile_path\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mconfig\u001b[49m\u001b[43m=\u001b[49m\u001b[43mCONFIG\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 8\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mLoaded \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(market_data_df)\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m rows of market data\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m 9\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mSymbols in data: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mmarket_data_df[\u001b[33m'\u001b[39m\u001b[33msymbol\u001b[39m\u001b[33m'\u001b[39m].unique()\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/devel/pairs_trading/src/notebooks/../tools/data_loader.py:69\u001b[39m, in \u001b[36mload_market_data\u001b[39m\u001b[34m(datafile, config)\u001b[39m\n\u001b[32m 66\u001b[39m query += \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33m where exchange_id =\u001b[39m\u001b[33m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mexchange_id\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m'\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 67\u001b[39m query += \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33m and instrument_id in (\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[33m'\u001b[39m\u001b[33m,\u001b[39m\u001b[33m'\u001b[39m.join(instrument_ids)\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m)\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m---> \u001b[39m\u001b[32m69\u001b[39m df = \u001b[43mload_sqlite_to_dataframe\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdb_path\u001b[49m\u001b[43m=\u001b[49m\u001b[43mdatafile\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mquery\u001b[49m\u001b[43m=\u001b[49m\u001b[43mquery\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 71\u001b[39m \u001b[38;5;66;03m# Trading Hours\u001b[39;00m\n\u001b[32m 72\u001b[39m date_str = df[\u001b[33m\"\u001b[39m\u001b[33mtstamp\u001b[39m\u001b[33m\"\u001b[39m][\u001b[32m0\u001b[39m][\u001b[32m0\u001b[39m:\u001b[32m10\u001b[39m]\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/devel/pairs_trading/src/notebooks/../tools/data_loader.py:18\u001b[39m, in \u001b[36mload_sqlite_to_dataframe\u001b[39m\u001b[34m(db_path, query)\u001b[39m\n\u001b[32m 16\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m excpt:\n\u001b[32m 17\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mError: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mexcpt\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m---> \u001b[39m\u001b[32m18\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m() \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mexcpt\u001b[39;00m\n\u001b[32m 19\u001b[39m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[32m 20\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[33m\"\u001b[39m\u001b[33mconn\u001b[39m\u001b[33m\"\u001b[39m \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mlocals\u001b[39m():\n",
"\u001b[31mException\u001b[39m: "
]
}
],
"source": [
"# Load market data\n",
"datafile_path = f\"{CONFIG['data_directory']}/{DATA_FILE}\"\n",
"print(f\"Current working directory: {os.getcwd()}\")\n",
"print(f\"Loading data from: {datafile_path}\")\n",
"\n",
"market_data_df = load_market_data(datafile_path, config=CONFIG)\n",
"\n",
"print(f\"Loaded {len(market_data_df)} rows of market data\")\n",
"print(f\"Symbols in data: {market_data_df['symbol'].unique()}\")\n",
"print(f\"Time range: {market_data_df['tstamp'].min()} to {market_data_df['tstamp'].max()}\")\n",
"\n",
"# Display first few rows\n",
"market_data_df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create Trading Pair and Analyze"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create trading pair\n",
"pair = TradingPair(\n",
" market_data=market_data_df,\n",
" symbol_a=SYMBOL_A,\n",
" symbol_b=SYMBOL_B,\n",
" price_column=CONFIG[\"price_column\"]\n",
")\n",
"\n",
"print(f\"Created trading pair: {pair}\")\n",
"print(f\"Market data shape: {pair.market_data_.shape}\")\n",
"print(f\"Column names: {pair.colnames()}\")\n",
"\n",
"# Display first few rows of pair data\n",
"pair.market_data_.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Split Data into Training and Testing"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Get training and testing datasets\n",
"training_minutes = CONFIG[\"training_minutes\"]\n",
"pair.get_datasets(training_minutes=training_minutes)\n",
"\n",
"print(f\"Training data: {len(pair.training_df_)} rows\")\n",
"print(f\"Testing data: {len(pair.testing_df_)} rows\")\n",
"print(f\"Training period: {pair.training_df_['tstamp'].iloc[0]} to {pair.training_df_['tstamp'].iloc[-1]}\")\n",
"print(f\"Testing period: {pair.testing_df_['tstamp'].iloc[0]} to {pair.testing_df_['tstamp'].iloc[-1]}\")\n",
"\n",
"# Check for any missing data\n",
"print(f\"Training data null values: {pair.training_df_.isnull().sum().sum()}\")\n",
"print(f\"Testing data null values: {pair.testing_df_.isnull().sum().sum()}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Visualize Raw Price Data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Plot raw price data\n",
"fig, axes = plt.subplots(3, 1, figsize=(15, 12))\n",
"\n",
"# Combined price plot\n",
"colname_a, colname_b = pair.colnames()\n",
"all_data = pd.concat([pair.training_df_, pair.testing_df_]).reset_index(drop=True)\n",
"\n",
"# Plot individual prices\n",
"axes[0].plot(all_data['tstamp'], all_data[colname_a], label=f'{SYMBOL_A}', alpha=0.8)\n",
"axes[0].plot(all_data['tstamp'], all_data[colname_b], label=f'{SYMBOL_B}', alpha=0.8)\n",
"axes[0].axvline(x=pair.training_df_['tstamp'].iloc[-1], color='red', linestyle='--', alpha=0.7, label='Train/Test Split')\n",
"axes[0].set_title(f'Price Comparison: {SYMBOL_A} vs {SYMBOL_B}')\n",
"axes[0].set_ylabel('Price')\n",
"axes[0].legend()\n",
"axes[0].grid(True)\n",
"\n",
"# Normalized prices for comparison\n",
"norm_a = all_data[colname_a] / all_data[colname_a].iloc[0]\n",
"norm_b = all_data[colname_b] / all_data[colname_b].iloc[0]\n",
"\n",
"axes[1].plot(all_data['tstamp'], norm_a, label=f'{SYMBOL_A} (normalized)', alpha=0.8)\n",
"axes[1].plot(all_data['tstamp'], norm_b, label=f'{SYMBOL_B} (normalized)', alpha=0.8)\n",
"axes[1].axvline(x=pair.training_df_['tstamp'].iloc[-1], color='red', linestyle='--', alpha=0.7, label='Train/Test Split')\n",
"axes[1].set_title('Normalized Price Comparison')\n",
"axes[1].set_ylabel('Normalized Price')\n",
"axes[1].legend()\n",
"axes[1].grid(True)\n",
"\n",
"# Price ratio\n",
"price_ratio = all_data[colname_a] / all_data[colname_b]\n",
"axes[2].plot(all_data['tstamp'], price_ratio, label=f'{SYMBOL_A}/{SYMBOL_B} Ratio', color='green', alpha=0.8)\n",
"axes[2].axvline(x=pair.training_df_['tstamp'].iloc[-1], color='red', linestyle='--', alpha=0.7, label='Train/Test Split')\n",
"axes[2].set_title('Price Ratio')\n",
"axes[2].set_ylabel('Ratio')\n",
"axes[2].set_xlabel('Time')\n",
"axes[2].legend()\n",
"axes[2].grid(True)\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train the Pair and Check Cointegration"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Train the pair and check cointegration\n",
"try:\n",
" is_cointegrated = pair.train_pair()\n",
" print(f\"Pair {pair} cointegration status: {is_cointegrated}\")\n",
"\n",
" if is_cointegrated:\n",
" print(f\"VECM Beta coefficients: {pair.vecm_fit_.beta.flatten()}\")\n",
" print(f\"Training dis-equilibrium mean: {pair.training_mu_:.6f}\")\n",
" print(f\"Training dis-equilibrium std: {pair.training_std_:.6f}\")\n",
"\n",
" # Display VECM summary\n",
" print(\"\\nVECM Model Summary:\")\n",
" print(pair.vecm_fit_.summary())\n",
" else:\n",
" print(\"Pair is not cointegrated. Cannot proceed with strategy.\")\n",
"\n",
"except Exception as e:\n",
" print(f\"Training failed: {str(e)}\")\n",
" is_cointegrated = False"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Visualize Training Period Dis-equilibrium"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if is_cointegrated:\n",
" # fig, axes = plt.subplots(, 1, figsize=(15, 10))\n",
"\n",
" # # Raw dis-equilibrium\n",
" # axes[0].plot(pair.training_df_['tstamp'], pair.training_df_['dis-equilibrium'],\n",
" # color='blue', alpha=0.8, label='Raw Dis-equilibrium')\n",
" # axes[0].axhline(y=pair.training_mu_, color='red', linestyle='--', alpha=0.7, label='Mean')\n",
" # axes[0].axhline(y=pair.training_mu_ + pair.training_std_, color='orange', linestyle='--', alpha=0.5, label='+1 Std')\n",
" # axes[0].axhline(y=pair.training_mu_ - pair.training_std_, color='orange', linestyle='--', alpha=0.5, label='-1 Std')\n",
" # axes[0].set_title('Training Period: Raw Dis-equilibrium')\n",
" # axes[0].set_ylabel('Dis-equilibrium')\n",
" # axes[0].legend()\n",
" # axes[0].grid(True)\n",
"\n",
" # Scaled dis-equilibrium\n",
" fig, axes = plt.subplots(1, 1, figsize=(15, 5))\n",
" axes.plot(pair.training_df_['tstamp'], pair.training_df_['scaled_dis-equilibrium'],\n",
" color='green', alpha=0.8, label='Scaled Dis-equilibrium')\n",
" axes.axhline(y=0, color='red', linestyle='--', alpha=0.7, label='Mean (0)')\n",
" axes.axhline(y=1, color='orange', linestyle='--', alpha=0.5, label='+1 Std')\n",
" axes.axhline(y=-1, color='orange', linestyle='--', alpha=0.5, label='-1 Std')\n",
" axes.axhline(y=CONFIG['dis-equilibrium_open_trshld'], color='purple',\n",
" linestyle=':', alpha=0.7, label=f\"Open Threshold ({CONFIG['dis-equilibrium_open_trshld']})\")\n",
" axes.axhline(y=CONFIG['dis-equilibrium_close_trshld'], color='brown',\n",
" linestyle=':', alpha=0.7, label=f\"Close Threshold ({CONFIG['dis-equilibrium_close_trshld']})\")\n",
" axes.set_title('Training Period: Scaled Dis-equilibrium')\n",
" axes.set_ylabel('Scaled Dis-equilibrium')\n",
" axes.set_xlabel('Time')\n",
" axes.legend()\n",
" axes.grid(True)\n",
"\n",
" plt.tight_layout()\n",
" plt.show()\n",
"\n",
" # Print statistics\n",
" print(f\"Training dis-equilibrium statistics:\")\n",
" print(f\" Mean: {pair.training_df_['dis-equilibrium'].mean():.6f}\")\n",
" print(f\" Std: {pair.training_df_['dis-equilibrium'].std():.6f}\")\n",
" print(f\" Min: {pair.training_df_['dis-equilibrium'].min():.6f}\")\n",
" print(f\" Max: {pair.training_df_['dis-equilibrium'].max():.6f}\")\n",
"\n",
" print(f\"\\nScaled dis-equilibrium statistics:\")\n",
" print(f\" Mean: {pair.training_df_['scaled_dis-equilibrium'].mean():.6f}\")\n",
" print(f\" Std: {pair.training_df_['scaled_dis-equilibrium'].std():.6f}\")\n",
" print(f\" Min: {pair.training_df_['scaled_dis-equilibrium'].min():.6f}\")\n",
" print(f\" Max: {pair.training_df_['scaled_dis-equilibrium'].max():.6f}\")\n",
"else:\n",
" print(\"The pair is not cointegrated\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Generate Predictions and Run Strategy"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if is_cointegrated:\n",
" try:\n",
" # Generate predictions\n",
" pair.predict()\n",
" print(f\"Generated predictions for {len(pair.predicted_df_)} rows\")\n",
"\n",
" # Display prediction data structure\n",
" print(f\"Prediction columns: {list(pair.predicted_df_.columns)}\")\n",
" print(f\"Prediction period: {pair.predicted_df_['tstamp'].iloc[0]} to {pair.predicted_df_['tstamp'].iloc[-1]}\")\n",
"\n",
" # Run strategy\n",
" bt_result = BacktestResult(config=CONFIG)\n",
" pair_trades = FIT_METHOD.run_pair(config=CONFIG, pair=pair, bt_result=bt_result)\n",
"\n",
" if pair_trades is not None and len(pair_trades) > 0:\n",
" print(f\"\\nGenerated {len(pair_trades)} trading signals:\")\n",
" print(pair_trades)\n",
" else:\n",
" print(\"\\nNo trading signals generated\")\n",
"\n",
" except Exception as e:\n",
" print(f\"Prediction/Strategy failed: {str(e)}\")\n",
" pair_trades = None"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Visualize Predictions and Dis-equilibrium"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if is_cointegrated and hasattr(pair, 'predicted_df_'):\n",
" fig, axes = plt.subplots(4, 1, figsize=(16, 16))\n",
"\n",
" # Actual vs Predicted Prices\n",
" colname_a, colname_b = pair.colnames()\n",
"\n",
" axes[0].plot(pair.predicted_df_['tstamp'], pair.predicted_df_[colname_a],\n",
" label=f'{SYMBOL_A} Actual', alpha=0.8)\n",
" axes[0].plot(pair.predicted_df_['tstamp'], pair.predicted_df_[f'{colname_a}_pred'],\n",
" label=f'{SYMBOL_A} Predicted', alpha=0.8, linestyle='--')\n",
" axes[0].set_title('Actual vs Predicted Prices - Symbol A')\n",
" axes[0].set_ylabel('Price')\n",
" axes[0].legend()\n",
" axes[0].grid(True)\n",
"\n",
" axes[1].plot(pair.predicted_df_['tstamp'], pair.predicted_df_[colname_b],\n",
" label=f'{SYMBOL_B} Actual', alpha=0.8)\n",
" axes[1].plot(pair.predicted_df_['tstamp'], pair.predicted_df_[f'{colname_b}_pred'],\n",
" label=f'{SYMBOL_B} Predicted', alpha=0.8, linestyle='--')\n",
" axes[1].set_title('Actual vs Predicted Prices - Symbol B')\n",
" axes[1].set_ylabel('Price')\n",
" axes[1].legend()\n",
" axes[1].grid(True)\n",
"\n",
" # Raw dis-equilibrium\n",
" axes[2].plot(pair.predicted_df_['tstamp'], pair.predicted_df_['disequilibrium'],\n",
" color='blue', alpha=0.8, label='Dis-equilibrium')\n",
" axes[2].axhline(y=pair.training_mu_, color='red', linestyle='--', alpha=0.7, label='Training Mean')\n",
" axes[2].set_title('Testing Period: Raw Dis-equilibrium')\n",
" axes[2].set_ylabel('Dis-equilibrium')\n",
" axes[2].legend()\n",
" axes[2].grid(True)\n",
"\n",
" # Scaled dis-equilibrium with trading signals\n",
" axes[3].plot(pair.predicted_df_['tstamp'], pair.predicted_df_['scaled_disequilibrium'],\n",
" color='green', alpha=0.8, label='Scaled Dis-equilibrium')\n",
"\n",
" # Add threshold lines\n",
" axes[3].axhline(y=CONFIG['dis-equilibrium_open_trshld'], color='purple',\n",
" linestyle=':', alpha=0.7, label=f\"Open Threshold ({CONFIG['dis-equilibrium_open_trshld']})\")\n",
" axes[3].axhline(y=CONFIG['dis-equilibrium_close_trshld'], color='brown',\n",
" linestyle=':', alpha=0.7, label=f\"Close Threshold ({CONFIG['dis-equilibrium_close_trshld']})\")\n",
"\n",
" # Add trading signals if they exist\n",
" if pair_trades is not None and len(pair_trades) > 0:\n",
" for _, trade in pair_trades.iterrows():\n",
" color = 'red' if 'BUY' in trade['action'] else 'blue'\n",
" marker = '^' if 'BUY' in trade['action'] else 'v'\n",
" axes[3].scatter(trade['time'], trade['scaled_disequilibrium'],\n",
" color=color, marker=marker, s=100, alpha=0.8,\n",
" label=f\"{trade['action']} {trade['symbol']}\" if _ < 2 else \"\")\n",
"\n",
" axes[3].set_title('Testing Period: Scaled Dis-equilibrium with Trading Signals')\n",
" axes[3].set_ylabel('Scaled Dis-equilibrium')\n",
" axes[3].set_xlabel('Time')\n",
" axes[3].legend()\n",
" axes[3].grid(True)\n",
"\n",
" plt.tight_layout()\n",
" plt.show()\n",
"\n",
" # Print prediction statistics\n",
" print(f\"\\nTesting dis-equilibrium statistics:\")\n",
" print(f\" Mean: {pair.predicted_df_['disequilibrium'].mean():.6f}\")\n",
" print(f\" Std: {pair.predicted_df_['disequilibrium'].std():.6f}\")\n",
" print(f\" Min: {pair.predicted_df_['disequilibrium'].min():.6f}\")\n",
" print(f\" Max: {pair.predicted_df_['disequilibrium'].max():.6f}\")\n",
"\n",
" print(f\"\\nTesting scaled dis-equilibrium statistics:\")\n",
" print(f\" Mean: {pair.predicted_df_['scaled_disequilibrium'].mean():.6f}\")\n",
" print(f\" Std: {pair.predicted_df_['scaled_disequilibrium'].std():.6f}\")\n",
" print(f\" Min: {pair.predicted_df_['scaled_disequilibrium'].min():.6f}\")\n",
" print(f\" Max: {pair.predicted_df_['scaled_disequilibrium'].max():.6f}\")\n",
"\n",
" # Count threshold crossings\n",
" open_crossings = (pair.predicted_df_['scaled_disequilibrium'] >= CONFIG['dis-equilibrium_open_trshld']).sum()\n",
" close_crossings = (pair.predicted_df_['scaled_disequilibrium'] <= CONFIG['dis-equilibrium_close_trshld']).sum()\n",
" print(f\"\\nThreshold crossings:\")\n",
" print(f\" Open threshold ({CONFIG['dis-equilibrium_open_trshld']}): {open_crossings} times\")\n",
" print(f\" Close threshold ({CONFIG['dis-equilibrium_close_trshld']}): {close_crossings} times\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Summary and Analysis"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(\"=\" * 60)\n",
"print(\"PAIRS TRADING ANALYSIS SUMMARY\")\n",
"print(\"=\" * 60)\n",
"\n",
"print(f\"\\nPair: {SYMBOL_A} & {SYMBOL_B}\")\n",
"print(f\"Strategy: {type(FIT_METHOD).__name__}\")\n",
"print(f\"Data file: {DATA_FILE}\")\n",
"print(f\"Training period: {training_minutes} minutes\")\n",
"\n",
"print(f\"\\nCointegration Status: {'✓ COINTEGRATED' if is_cointegrated else '✗ NOT COINTEGRATED'}\")\n",
"\n",
"if is_cointegrated:\n",
" print(f\"\\nVECM Model:\")\n",
" print(f\" Beta coefficients: {pair.vecm_fit_.beta.flatten()}\")\n",
" print(f\" Training mean: {pair.training_mu_:.6f}\")\n",
" print(f\" Training std: {pair.training_std_:.6f}\")\n",
"\n",
" if pair_trades is not None and len(pair_trades) > 0:\n",
" print(f\"\\nTrading Signals: {len(pair_trades)} generated\")\n",
" unique_times = pair_trades['time'].unique()\n",
" print(f\" Unique trade times: {len(unique_times)}\")\n",
"\n",
" # Group by time to see paired trades\n",
" for trade_time in unique_times:\n",
" trades_at_time = pair_trades[pair_trades['time'] == trade_time]\n",
" print(f\"\\n Trade at {trade_time}:\")\n",
" for _, trade in trades_at_time.iterrows():\n",
" print(f\" {trade['action']} {trade['symbol']} @ ${trade['price']:.2f} (dis-eq: {trade['scaled_disequilibrium']:.2f})\")\n",
" else:\n",
" print(f\"\\nTrading Signals: None generated\")\n",
" print(\" Possible reasons:\")\n",
" print(\" - Dis-equilibrium never exceeded open threshold\")\n",
" print(\" - Insufficient testing data\")\n",
" print(\" - Strategy-specific conditions not met\")\n",
"\n",
"else:\n",
" print(\"\\nCannot proceed with trading strategy - pair is not cointegrated\")\n",
" print(\"Consider:\")\n",
" print(\" - Trying different symbol pairs\")\n",
" print(\" - Adjusting training period length\")\n",
" print(\" - Using different data timeframe\")\n",
"\n",
"print(\"\\n\" + \"=\" * 60)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Interactive Analysis (Optional)\n",
"\n",
"You can modify the parameters below and re-run the analysis:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Interactive parameter adjustment\n",
"print(\"Current parameters:\")\n",
"print(f\" Open threshold: {CONFIG['dis-equilibrium_open_trshld']}\")\n",
"print(f\" Close threshold: {CONFIG['dis-equilibrium_close_trshld']}\")\n",
"print(f\" Training minutes: {CONFIG['training_minutes']}\")\n",
"\n",
"# Uncomment and modify these to experiment:\n",
"# CONFIG['dis-equilibrium_open_trshld'] = 1.5\n",
"# CONFIG['dis-equilibrium_close_trshld'] = 0.3\n",
"# CONFIG['training_minutes'] = 180\n",
"\n",
"print(\"\\nTo re-run with different parameters:\")\n",
"print(\"1. Modify the parameters above\")\n",
"print(\"2. Re-run from the 'Split Data into Training and Testing' cell\")\n",
"print(\"3. Or try different symbol pairs by changing SYMBOL_A and SYMBOL_B\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "python3.12-venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.9"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
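
For orientation: the `scaled_disequilibrium` column summarized above is the raw dis-equilibrium normalized by the training-window statistics, which is why the open/close thresholds are quoted in standard deviations. A minimal sketch of one plausible normalization (not necessarily the repository's exact formula, which lives in `TradingPair`), using the notebook's `training_mu_`/`training_std_` attributes and the config's `zero_threshold`:

```
import pandas as pd

def scale_disequilibrium(diseq: pd.Series, mu: float, std: float,
                         zero_threshold: float = 1e-10) -> pd.Series:
    # Guard against a near-constant training window before dividing.
    if abs(std) < zero_threshold:
        std = zero_threshold
    return (diseq - mu) / std

# Hypothetical usage against a fitted pair:
# scaled = scale_disequilibrium(pair.predicted_df_['disequilibrium'],
#                               pair.training_mu_, pair.training_std_,
#                               CONFIG['zero_threshold'])
```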

File diff suppressed because one or more lines are too long

View File

@ -3,100 +3,104 @@ import glob
import importlib
import os
from datetime import date, datetime
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional
import pandas as pd
from research.research_tools import create_pairs
from tools.config import expand_filename, load_config
from tools.data_loader import get_available_instruments_from_db, load_market_data
from pt_trading.results import (
BacktestResult,
create_result_database,
store_config_in_database,
store_results_in_database,
)
from pt_trading.fit_method import PairsTradingFitMethod
from pt_trading.trading_pair import TradingPair
DayT = str
DataFileNameT = str
def resolve_datafiles(
config: Dict, date_pattern: str, instruments: List[Dict[str, str]]
) -> List[Tuple[DayT, DataFileNameT]]:
resolved_files: List[Tuple[DayT, DataFileNameT]] = []
for inst in instruments:
pattern = date_pattern
inst_type = inst["instrument_type"]
data_dir = config["market_data_loading"][inst_type]["data_directory"]
def resolve_datafiles(config: Dict, cli_datafiles: Optional[str] = None) -> List[str]:
"""
Resolve the list of data files to process.
CLI datafiles take priority over config datafiles.
Supports wildcards in config but not in CLI.
"""
if cli_datafiles:
# CLI override - comma-separated list, no wildcards
datafiles = [f.strip() for f in cli_datafiles.split(",")]
# Make paths absolute relative to data directory
data_dir = config.get("data_directory", "./data")
resolved_files = []
for df in datafiles:
if not os.path.isabs(df):
df = os.path.join(data_dir, df)
resolved_files.append(df)
return resolved_files
# Use config datafiles with wildcard support
config_datafiles = config.get("datafiles", [])
data_dir = config.get("data_directory", "./data")
resolved_files = []
for pattern in config_datafiles:
if "*" in pattern or "?" in pattern:
# Handle wildcards
if not os.path.isabs(pattern):
pattern = os.path.join(data_dir, f"{pattern}.mktdata.ohlcv.db")
pattern = os.path.join(data_dir, pattern)
matched_files = glob.glob(pattern)
for matched_file in matched_files:
import re
match = re.search(r"(\d{8})\.mktdata\.ohlcv\.db$", matched_file)
assert match is not None
day = match.group(1)
resolved_files.append((day, matched_file))
resolved_files.extend(matched_files)
else:
# Handle explicit file path
if not os.path.isabs(pattern):
pattern = os.path.join(data_dir, f"{pattern}.mktdata.ohlcv.db")
resolved_files.append((date_pattern, pattern))
pattern = os.path.join(data_dir, pattern)
resolved_files.append(pattern)
return sorted(list(set(resolved_files))) # Remove duplicates and sort
def get_instruments(args: argparse.Namespace, config: Dict) -> List[Dict[str, str]]:
instruments = [
{
"symbol": inst.split(":")[0],
"instrument_type": inst.split(":")[1],
"exchange_id": inst.split(":")[2],
"instrument_id_pfx": config["market_data_loading"][inst.split(":")[1]][
"instrument_id_pfx"
],
"db_table_name": config["market_data_loading"][inst.split(":")[1]][
"db_table_name"
],
}
for inst in args.instruments.split(",")
]
return instruments
def run_backtest(
config: Dict,
datafiles: List[str],
datafile: str,
price_column: str,
fit_method: PairsTradingFitMethod,
instruments: List[Dict[str, str]],
instruments: List[str],
) -> BacktestResult:
"""
Run backtest for all pairs using the specified instruments.
"""
bt_result: BacktestResult = BacktestResult(config=config)
# if len(datafiles) < 2:
# print(f"WARNING: insufficient data files: {datafiles}")
# return bt_result
if not all([os.path.exists(datafile) for datafile in datafiles]):
print(f"WARNING: data file {datafiles} does not exist")
return bt_result
def _create_pairs(config: Dict, instruments: List[str]) -> List[TradingPair]:
nonlocal datafile
all_indexes = range(len(instruments))
unique_index_pairs = [(i, j) for i in all_indexes for j in all_indexes if i < j]
pairs = []
# Update config to use the specified instruments
config_copy = config.copy()
config_copy["instruments"] = instruments
market_data_df = load_market_data(datafile, config=config_copy)
for a_index, b_index in unique_index_pairs:
pair = TradingPair(
config=config_copy,
market_data=market_data_df,
symbol_a=instruments[a_index],
symbol_b=instruments[b_index],
price_column=price_column,
)
pairs.append(pair)
return pairs
pairs_trades = []
pairs = create_pairs(
datafiles=datafiles,
fit_method=fit_method,
config=config,
instruments=instruments,
)
for pair in pairs:
single_pair_trades = fit_method.run_pair(pair=pair, bt_result=bt_result)
for pair in _create_pairs(config, instruments):
single_pair_trades = fit_method.run_pair(
pair=pair, bt_result=bt_result
)
if single_pair_trades is not None and len(single_pair_trades) > 0:
pairs_trades.append(single_pair_trades)
print(f"pairs_trades:\n{pairs_trades}")
print(f"pairs_trades: {pairs_trades}")
# Check if result_list has any data before concatenating
if len(pairs_trades) == 0:
print("No trading signals found for any pairs")
@ -105,22 +109,23 @@ def run_backtest(
bt_result.collect_single_day_results(pairs_trades)
return bt_result
def main() -> None:
parser = argparse.ArgumentParser(description="Run pairs trading backtest.")
parser.add_argument(
"--config", type=str, required=True, help="Path to the configuration file."
)
parser.add_argument(
"--date_pattern",
"--datafiles",
type=str,
required=True,
help="Date YYYYMMDD, allows * and ? wildcards",
required=False,
help="Comma-separated list of data files (overrides config). No wildcards supported.",
)
parser.add_argument(
"--instruments",
type=str,
required=True,
help="Comma-separated list of instrument symbols (e.g., COIN:EQUITY,GBTC:CRYPTO)",
required=False,
help="Comma-separated list of instrument symbols (e.g., COIN,GBTC). If not provided, auto-detects from database.",
)
parser.add_argument(
"--result_db",
@ -134,13 +139,19 @@ def main() -> None:
config: Dict = load_config(args.config)
# Dynamically instantiate fit method class
fit_method = PairsTradingFitMethod.create(config)
fit_method_class_name = config.get("fit_method_class", None)
assert fit_method_class_name is not None
module_name, class_name = fit_method_class_name.rsplit(".", 1)
module = importlib.import_module(module_name)
fit_method = getattr(module, class_name)()
# Resolve data files (CLI takes priority over config)
instruments = get_instruments(args, config)
datafiles = resolve_datafiles(config, args.date_pattern, instruments)
datafiles = resolve_datafiles(config, args.datafiles)
if not datafiles:
print("No data files found to process.")
return
days = list(set([day for day, _ in datafiles]))
print(f"Found {len(datafiles)} data files to process:")
for df in datafiles:
print(f" - {df}")
@ -152,26 +163,51 @@ def main() -> None:
# Initialize a dictionary to store all trade results
all_results: Dict[str, Dict[str, Any]] = {}
is_config_stored = False
# Store configuration in database for reference
if args.result_db.upper() != "NONE":
# Get list of all instruments for storage
all_instruments = []
for datafile in datafiles:
if args.instruments:
file_instruments = [
inst.strip() for inst in args.instruments.split(",")
]
else:
file_instruments = get_available_instruments_from_db(datafile, config)
all_instruments.extend(file_instruments)
# Remove duplicates while preserving order
unique_instruments = list(dict.fromkeys(all_instruments))
store_config_in_database(
db_path=args.result_db,
config_file_path=args.config,
config=config,
fit_method_class=fit_method_class_name,
datafiles=datafiles,
instruments=unique_instruments,
)
# Process each data file
price_column = config["price_column"]
for day in sorted(days):
md_datafiles = [datafile for md_day, datafile in datafiles if md_day == day]
if not all([os.path.exists(datafile) for datafile in md_datafiles]):
print(f"WARNING: insufficient data files: {md_datafiles}")
for datafile in datafiles:
print(f"\n====== Processing {os.path.basename(datafile)} ======")
# Determine instruments to use
if args.instruments:
# Use CLI-specified instruments
instruments = [inst.strip() for inst in args.instruments.split(",")]
print(f"Using CLI-specified instruments: {instruments}")
else:
# Auto-detect instruments from database
instruments = get_available_instruments_from_db(datafile, config)
print(f"Auto-detected instruments: {instruments}")
if not instruments:
print(f"No instruments found for {datafile}, skipping...")
continue
print(f"\n====== Processing {day} ======")
if not is_config_stored:
store_config_in_database(
db_path=args.result_db,
config_file_path=args.config,
config=config,
fit_method_class=config["fit_method_class"],
datafiles=datafiles,
instruments=instruments,
)
is_config_stored = True
# Process data for this file
try:
@ -179,17 +215,14 @@ def main() -> None:
bt_results = run_backtest(
config=config,
datafiles=md_datafiles,
datafile=datafile,
price_column=price_column,
fit_method=fit_method,
instruments=instruments,
)
if bt_results.trades is None or len(bt_results.trades) == 0:
print(f"No trades found for {day}")
continue
# Store results with day name as key
filename = os.path.basename(day)
# Store results with file name as key
filename = os.path.basename(datafile)
all_results[filename] = {
"trades": bt_results.trades.copy(),
"outstanding_positions": bt_results.outstanding_positions.copy(),
@ -197,20 +230,12 @@ def main() -> None:
# Store results in database
if args.result_db.upper() != "NONE":
bt_results.calculate_returns(
{
filename: {
"trades": bt_results.trades.copy(),
"outstanding_positions": bt_results.outstanding_positions.copy(),
}
}
)
bt_results.store_results_in_database(db_path=args.result_db, day=day)
store_results_in_database(args.result_db, datafile, bt_results)
print(f"Successfully processed {filename}")
except Exception as err:
print(f"Error processing {day}: {str(err)}")
print(f"Error processing {datafile}: {str(err)}")
import traceback
traceback.print_exc()
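
The resolution rules encoded in the new `resolve_datafiles` above (a CLI list beats config `datafiles`, wildcards are honored only in config, and config results are deduplicated and sorted) are easiest to see with a concrete call. A minimal sketch with hypothetical file names:

```
# Sketch of the resolve_datafiles contract; file names are hypothetical.
config = {
    "data_directory": "./data/crypto",
    "datafiles": ["2025*.mktdata.ohlcv.db"],  # wildcards allowed in config only
}

# 1) CLI override wins; a plain comma-separated list, no wildcards:
files = resolve_datafiles(config, "20250603.mktdata.ohlcv.db,20250604.mktdata.ohlcv.db")
# -> ["./data/crypto/20250603.mktdata.ohlcv.db",
#     "./data/crypto/20250604.mktdata.ohlcv.db"]

# 2) No CLI value: config patterns are globbed, deduplicated, and sorted:
files = resolve_datafiles(config)
```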

View File

@ -1,94 +0,0 @@
import glob
import os
from typing import Dict, List, Optional
import pandas as pd
from pt_trading.fit_method import PairsTradingFitMethod
def resolve_datafiles(config: Dict, cli_datafiles: Optional[str] = None) -> List[str]:
"""
Resolve the list of data files to process.
CLI datafiles take priority over config datafiles.
Supports wildcards in config but not in CLI.
"""
if cli_datafiles:
# CLI override - comma-separated list, no wildcards
datafiles = [f.strip() for f in cli_datafiles.split(",")]
# Make paths absolute relative to data directory
data_dir = config.get("data_directory", "./data")
resolved_files = []
for df in datafiles:
if not os.path.isabs(df):
df = os.path.join(data_dir, df)
resolved_files.append(df)
return resolved_files
# Use config datafiles with wildcard support
config_datafiles = config.get("datafiles", [])
data_dir = config.get("data_directory", "./data")
resolved_files = []
for pattern in config_datafiles:
if "*" in pattern or "?" in pattern:
# Handle wildcards
if not os.path.isabs(pattern):
pattern = os.path.join(data_dir, pattern)
matched_files = glob.glob(pattern)
resolved_files.extend(matched_files)
else:
# Handle explicit file path
if not os.path.isabs(pattern):
pattern = os.path.join(data_dir, pattern)
resolved_files.append(pattern)
return sorted(list(set(resolved_files))) # Remove duplicates and sort
def create_pairs(
datafiles: List[str],
fit_method: PairsTradingFitMethod,
config: Dict,
instruments: List[Dict[str, str]],
) -> List:
from pt_trading.trading_pair import TradingPair
from tools.data_loader import load_market_data
all_indexes = range(len(instruments))
unique_index_pairs = [(i, j) for i in all_indexes for j in all_indexes if i < j]
pairs = []
# Update config to use the specified instruments
config_copy = config.copy()
config_copy["instruments"] = instruments
market_data_df = pd.DataFrame()
extra_minutes = 0
if "execution_price" in config_copy:
extra_minutes = config_copy["execution_price"]["shift"]
for datafile in datafiles:
md_df = load_market_data(
datafile=datafile,
instruments=instruments,
db_table_name=config_copy["market_data_loading"][instruments[0]["instrument_type"]]["db_table_name"],
trading_hours=config_copy["trading_hours"],
extra_minutes=extra_minutes,
)
market_data_df = pd.concat([market_data_df, md_df])
if len(set(market_data_df["symbol"])) != 2: # both symbols must be present for a pair
print(f"WARNING: insufficient data in files: {datafiles}")
return []
for a_index, b_index in unique_index_pairs:
symbol_a=instruments[a_index]["symbol"]
symbol_b=instruments[b_index]["symbol"]
pair = fit_method.create_trading_pair(
config=config_copy,
market_data=market_data_df,
symbol_a=symbol_a,
symbol_b=symbol_b,
)
pairs.append(pair)
return pairs
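
One detail worth keeping from the deleted helper: the `i < j` double loop enumerates each unordered symbol pair exactly once, which `itertools.combinations` expresses directly. A minimal equivalent sketch (the third symbol is hypothetical):

```
from itertools import combinations

instruments = ["COIN", "GBTC", "MSTR"]  # MSTR added here for illustration

# Equivalent to: [(i, j) for i in range(n) for j in range(n) if i < j]
for symbol_a, symbol_b in combinations(instruments, 2):
    print(symbol_a, symbol_b)
# COIN GBTC
# COIN MSTR
# GBTC MSTR
```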

View File

@ -16,12 +16,7 @@ cd $(realpath $(dirname $0))/..
mkdir -p ./data/crypto
pushd ./data/crypto
Files=$1
if [ -z "$Files" ]; then
Files="*.gz"
fi
Cmd="rsync -ahvv cvtt@hs01.cvtt.vpn:/works/cvtt/md_archive/crypto/sim/${Files} ./"
Cmd="rsync -ahvv cvtt@hs01.cvtt.vpn:/works/cvtt/md_archive/crypto/sim/*.gz ./"
echo $Cmd
eval $Cmd
# -------------------------------------

View File

@ -26,12 +26,8 @@ for srcfname in $(ls *.db.gz); do
tgtfile=${dt}.mktdata.ohlcv.db
echo "${srcfname} -> ${tgtfile}"
Cmd="gunzip -c $srcfname > temp.db && rm $srcfname"
echo ${Cmd}
eval ${Cmd}
Cmd="rm -f ${tgtfile} && sqlite3 temp.db '.dump md_1min_bars' | sqlite3 ${tgtfile}"
echo ${Cmd}
eval ${Cmd}
gunzip -c $srcfname > temp.db
rm -f ${tgtfile} && sqlite3 temp.db ".dump md_1min_bars" | sqlite3 ${tgtfile} && rm ${srcfname}
done
rm temp.db
popd
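
Each pass through the loop above should leave one `YYYYMMDD.mktdata.ohlcv.db` holding just the dumped `md_1min_bars` table. A quick sanity check of an output file, sketched with Python's standard `sqlite3` module (the file name is hypothetical):

```
import sqlite3

con = sqlite3.connect("20250604.mktdata.ohlcv.db")  # hypothetical output file
tables = [row[0] for row in con.execute(
    "SELECT name FROM sqlite_master WHERE type='table'")]
assert "md_1min_bars" in tables, "conversion did not produce md_1min_bars"
(n_rows,) = con.execute("SELECT COUNT(*) FROM md_1min_bars").fetchone()
print(f"md_1min_bars: {n_rows} rows")
con.close()
```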

View File

@ -20,9 +20,12 @@ from pt_trading.fit_methods import PairsTradingFitMethod
from pt_trading.trading_pair import TradingPair
def run_strategy(
config: Dict,
datafile: str,
price_column: str,
fit_method: PairsTradingFitMethod,
instruments: List[str],
) -> BacktestResult:
@ -41,20 +44,14 @@ def run_strategy(
config_copy = config.copy()
config_copy["instruments"] = instruments
market_data_df = load_market_data(
datafile=datafile,
exchange_id=config_copy["exchange_id"],
instruments=config_copy["instruments"],
instrument_id_pfx=config_copy["instrument_id_pfx"],
db_table_name=config_copy["db_table_name"],
trading_hours=config_copy["trading_hours"],
)
market_data_df = load_market_data(datafile, config=config_copy)
for a_index, b_index in unique_index_pairs:
pair = fit_method.create_trading_pair(
pair = TradingPair(
market_data=market_data_df,
symbol_a=instruments[a_index],
symbol_b=instruments[b_index],
price_column=price_column,
)
pairs.append(pair)
return pairs
@ -159,6 +156,7 @@ def main() -> None:
)
# Process each data file
price_column = config["price_column"]
for datafile in datafiles:
print(f"\n====== Processing {os.path.basename(datafile)} ======")
@ -184,6 +182,7 @@ def main() -> None:
bt_results = run_strategy(
config=config,
datafile=datafile,
price_column=price_column,
fit_method=fit_method,
instruments=instruments,
)
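
Taken together, the revised flow in this script is: load one day's market data through the config, then instantiate a `TradingPair` for each unordered symbol combination. A minimal sketch of that calling pattern (the data file and config path are hypothetical; argument names follow the diff above, and the exact signatures live in `tools.data_loader` and `pt_trading.trading_pair`):

```
from itertools import combinations

from tools.config import load_config
from tools.data_loader import load_market_data
from pt_trading.trading_pair import TradingPair

# Hypothetical inputs; in this script they arrive via CLI arguments.
datafile = "./data/equity/20250604.mktdata.ohlcv.db"
config_copy = load_config("configuration/equity.cfg")
config_copy["instruments"] = ["COIN", "GBTC"]

market_data_df = load_market_data(datafile, config=config_copy)

pairs = [
    TradingPair(
        market_data=market_data_df,
        symbol_a=symbol_a,
        symbol_b=symbol_b,
        price_column=config_copy["price_column"],
    )
    for symbol_a, symbol_b in combinations(config_copy["instruments"], 2)
]
```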