Compare commits


26 Commits

Author SHA1 Message Date
Cryptoval Trading Technologies
6f845d32c6 . 2025-07-25 22:13:49 +00:00
Oleg Sheynin
71822c64b0 progress 2025-07-25 20:39:59 +00:00
Oleg Sheynin
c2f701e3a2 progress 2025-07-25 20:20:23 +00:00
Oleg Sheynin
21a473a4c2 fix close position trades 2025-07-25 18:21:52 +00:00
Oleg Sheynin
98a15d301a bug fix - multiple dates 2025-07-25 07:04:44 +00:00
Oleg Sheynin
bcf4447cb6 bug fixes 2025-07-25 06:39:17 +00:00
Oleg Sheynin
1af35000ab cleaning 2025-07-25 01:28:59 +00:00
Oleg Sheynin
2c08b6f1a9 intermarket fix for weekends 2025-07-25 00:47:19 +00:00
Oleg Sheynin
24f1f82d1f fixes to notebook 2025-07-24 22:45:21 +00:00
Oleg Sheynin
af0a6f62a9 progress 2025-07-24 21:09:13 +00:00
Oleg Sheynin
a7b4777f76 bug fix 2025-07-24 07:44:33 +00:00
Oleg Sheynin
e30b0df4db progress and result.py fixes 2025-07-24 06:51:46 +00:00
Oleg Sheynin
577fb5c109 notebook progress 2025-07-23 03:32:43 +00:00
Oleg Sheynin
e0138907be progress 2025-07-23 02:56:00 +00:00
Oleg Sheynin
b7292c11f3 notebook cleaning 2025-07-23 02:11:02 +00:00
Oleg Sheynin
aac8b9dc50 fixes 2025-07-22 18:04:23 +00:00
Oleg Sheynin
9bb36dddd7 notebook fixes 2025-07-22 17:42:14 +00:00
Oleg Sheynin
31eb9f800c bug fix 2025-07-22 17:25:16 +00:00
Oleg Sheynin
0e83142d0a progress: added zscore fit 2025-07-22 00:20:14 +00:00
Oleg Sheynin
b87b40a6ed progress 2025-07-21 05:15:33 +00:00
Oleg Sheynin
28386cdf12 fix trading pair, loading scripts 2025-07-20 18:11:45 +00:00
Oleg Sheynin
fb3dc68a1d minor: rename 2025-07-19 01:49:46 +00:00
Oleg Sheynin
c776c95d69 progress: stop signals 2025-07-19 01:04:09 +00:00
Oleg Sheynin
ca9fff8d88 progress 2025-07-18 23:13:11 +00:00
Oleg Sheynin
705330a9f7 progress 2025-07-18 22:51:29 +00:00
Oleg Sheynin
2272a31765 cointegration test initial 2025-07-17 00:19:49 +00:00
26 changed files with 8419 additions and 6549 deletions


@@ -11,6 +11,7 @@ The enhanced `pt_backtest.py` script now supports multi-day and multi-instrument
 - Support for wildcard patterns in configuration files
+- CLI override for data file specification

 ### 2. Dynamic Instrument Selection
 - Auto-detection of instruments from database
 - CLI override for instrument specification
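The wildcard support mentioned above amounts to shell-style pattern expansion over the data directory. A minimal sketch, assuming `glob` semantics and the directory/pattern layout used in the configs below; the helper name is ours:

```python
# Hypothetical sketch of resolving wildcard datafile patterns such as
# "2025*.mktdata.ohlcv.db" (from the crypto config below) into concrete paths.
import glob
import os

def resolve_datafiles(data_directory: str, patterns: list[str]) -> list[str]:
    """Expand each wildcard pattern relative to data_directory, sorted by name."""
    matches: list[str] = []
    for pattern in patterns:
        matches.extend(sorted(glob.glob(os.path.join(data_directory, pattern))))
    return matches

# e.g. resolve_datafiles("./data/crypto", ["2025*.mktdata.ohlcv.db"])
```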


@@ -38,15 +38,12 @@ CONFIG = EQT_CONFIG # For equity data
 ```
 Each configuration dictionary specifies:
-- `security_type`: "CRYPTO" or "EQUITY".
 - `data_directory`: Path to the data files.
 - `datafiles`: A list of database files to process. You can comment/uncomment specific files to include/exclude them from the backtest.
 - `db_table_name`: The name of the table within the SQLite database.
 - `instruments`: A list of symbols to consider for forming trading pairs.
 - `trading_hours`: Defines the session start and end times, crucial for equity markets.
-- `price_column`: The column in the data to be used as the price (e.g., "close").
-- `min_required_points`: Minimum data points needed for statistical calculations.
-- `zero_threshold`: A small value to handle potential division by zero.
+- `stat_model_price`: The column in the data to be used as the price (e.g., "close").
 - `dis-equilibrium_open_trshld`: The threshold (in standard deviations) of the dis-equilibrium for opening a trade.
 - `dis-equilibrium_close_trshld`: The threshold (in standard deviations) of the dis-equilibrium for closing an open trade.
 - `training_minutes`: The length of the rolling window (in minutes) used to train the model (e.g., calculate cointegration, mean, and standard deviation of the dis-equilibrium).
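The `.cfg` files shown below are JSON-like but carry `#` comments and trailing commas, so strict `json.loads` would reject them. The repo's actual loader is not part of this diff; a minimal sketch under the assumption that comments and trailing commas are the only deviations from strict JSON:

```python
# Hedged loader sketch for the commented, JSON-like .cfg files shown below.
# Note: the '#' stripping would also hit '#' inside string values; none of the
# configs in this diff contain one, so the simplification is acceptable here.
import json
import re

def load_cfg(path: str) -> dict:
    with open(path) as f:
        text = f.read()
    text = re.sub(r"#.*", "", text)             # drop '#' comments
    text = re.sub(r",(\s*[}\]])", r"\1", text)  # drop trailing commas
    return json.loads(text)

# config = load_cfg("configuration/vecm.cfg")
# config["training_minutes"]  -> 120
```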

@@ -1,31 +0,0 @@
{
"security_type": "CRYPTO",
"data_directory": "./data/crypto",
"datafiles": [
"2025*.mktdata.ohlcv.db"
],
"db_table_name": "md_1min_bars",
"exchange_id": "BNBSPOT",
"instrument_id_pfx": "PAIR-",
"trading_hours": {
"begin_session": "00:00:00",
"end_session": "23:59:00",
"timezone": "UTC"
},
"price_column": "close",
"min_required_points": 30,
"zero_threshold": 1e-10,
"dis-equilibrium_open_trshld": 2.0,
"dis-equilibrium_close_trshld": 0.5,
"training_minutes": 120,
"funding_per_pair": 2000.0,
"fit_method_class": "pt_trading.sliding_fit.SlidingFit",
# "fit_method_class": "pt_trading.static_fit.StaticFit",
"close_outstanding_positions": true,
"trading_hours": {
"begin_session": "06:00:00",
"end_session": "16:00:00",
"timezone": "America/New_York"
}
}

@@ -1,27 +0,0 @@
{
"security_type": "EQUITY",
"data_directory": "./data/equity",
# "datafiles": [
# "20250604.mktdata.ohlcv.db",
# ],
"db_table_name": "md_1min_bars",
"exchange_id": "ALPACA",
"instrument_id_pfx": "STOCK-",
"trading_hours": {
"begin_session": "9:30:00",
"end_session": "16:00:00",
"timezone": "America/New_York"
},
"price_column": "close",
"min_required_points": 30,
"zero_threshold": 1e-10,
"dis-equilibrium_open_trshld": 2.0,
"dis-equilibrium_close_trshld": 1.0,
"training_minutes": 120,
"funding_per_pair": 2000.0,
"fit_method_class": "pt_trading.sliding_fit.SlidingFit",
# "fit_method_class": "pt_trading.static_fit.StaticFit",
"exclude_instruments": ["CAN"],
"close_outstanding_positions": false
}
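For equity sessions like the one above, the `trading_hours` block implies filtering bars to a timezone-local window. A hedged sketch; the column name `tstamp` matches the code later in this diff, the helper itself is illustrative:

```python
# Illustrative session filter: keep only 1-minute bars whose local time falls
# inside [begin_session, end_session] for the configured timezone.
import pandas as pd

def filter_session(bars: pd.DataFrame, trading_hours: dict) -> pd.DataFrame:
    tz = trading_hours["timezone"]  # e.g. "America/New_York"
    ts = pd.to_datetime(bars["tstamp"])
    if ts.dt.tz is None:
        ts = ts.dt.tz_localize("UTC")  # assumption: stored timestamps are UTC
    local = ts.dt.tz_convert(tz)
    begin = pd.to_datetime(trading_hours["begin_session"]).time()
    end = pd.to_datetime(trading_hours["end_session"]).time()
    mask = (local.dt.time >= begin) & (local.dt.time <= end)
    return bars.loc[mask]
```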

configuration/vecm.cfg Normal file

@@ -0,0 +1,43 @@
{
"market_data_loading": {
"CRYPTO": {
"data_directory": "./data/crypto",
"db_table_name": "md_1min_bars",
"instrument_id_pfx": "PAIR-",
},
"EQUITY": {
"data_directory": "./data/equity",
"db_table_name": "md_1min_bars",
"instrument_id_pfx": "STOCK-",
}
},
# ====== Funding ======
"funding_per_pair": 2000.0,
# ====== Trading Parameters ======
"stat_model_price": "close", # "vwap"
"execution_price": {
"column": "vwap",
"shift": 1,
},
"dis-equilibrium_open_trshld": 2.0,
"dis-equilibrium_close_trshld": 1.0,
"training_minutes": 120,
"fit_method_class": "pt_trading.vecm_rolling_fit.VECMRollingFit",
# ====== Stop Conditions ======
"stop_close_conditions": {
"profit": 2.0,
"loss": -0.5
},
# ====== End of Session Closeout ======
"close_outstanding_positions": true,
# "close_outstanding_positions": false,
"trading_hours": {
"timezone": "America/New_York",
"begin_session": "9:30:00",
"end_session": "18:30:00",
}
}
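The `execution_price` block separates the price used to fit the model (`stat_model_price`) from the price used to fill trades, taken `shift` bars later, as `TradingPair._set_execution_price_data()` at the end of this diff does via a negative pandas shift. A tiny illustration of that shift semantics:

```python
# exec_price at bar t is the price at bar t+1: a negative shift pulls future
# values back, so signals generated on bar t are filled on the next bar.
import pandas as pd

px = pd.Series([100.0, 101.0, 102.0], name="close_AAA")  # symbol name ours
exec_price = px.shift(-1)  # "shift": 1 in the config
print(exec_price.tolist())  # [101.0, 102.0, nan] - the last bar has no fill
```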

configuration/zscore.cfg Normal file

@@ -0,0 +1,42 @@
{
"market_data_loading": {
"CRYPTO": {
"data_directory": "./data/crypto",
"db_table_name": "md_1min_bars",
"instrument_id_pfx": "PAIR-",
},
"EQUITY": {
"data_directory": "./data/equity",
"db_table_name": "md_1min_bars",
"instrument_id_pfx": "STOCK-",
}
},
# ====== Funding ======
"funding_per_pair": 2000.0,
# ====== Trading Parameters ======
"stat_model_price": "close",
"execution_price": {
"column": "vwap",
"shift": 1,
},
"dis-equilibrium_open_trshld": 2.0,
"dis-equilibrium_close_trshld": 0.5,
"training_minutes": 120,
"fit_method_class": "pt_trading.z-score_rolling_fit.ZScoreRollingFit",
# ====== Stop Conditions ======
"stop_close_conditions": {
"profit": 2.0,
"loss": -0.5
},
# ====== End of Session Closeout ======
"close_outstanding_positions": true,
# "close_outstanding_positions": false,
"trading_hours": {
"timezone": "America/New_York",
"begin_session": "9:30:00",
"end_session": "18:30:00",
}
}
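`ZScoreRollingFit` itself is not included in this diff, but the thresholds above are expressed in standard deviations, i.e. a z-score of the pair's dis-equilibrium over the `training_minutes` window. An illustrative sketch, with the spread definition assumed; the unsigned value is what gets compared against the open/close thresholds, matching the `scaled_disequilibrium` / `signed_scaled_disequilibrium` split in the code below:

```python
# Hedged z-score sketch: deviation of the spread from its rolling training-
# window mean, in units of the rolling standard deviation.
import pandas as pd

def scaled_disequilibrium(spread: pd.Series, training_minutes: int) -> pd.Series:
    mu = spread.rolling(training_minutes).mean()
    std = spread.rolling(training_minutes).std()
    return (spread - mu).abs() / std  # open when >= 2.0, close when <= 0.5
```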

@@ -1,8 +1,10 @@
+from __future__ import annotations
 from abc import ABC, abstractmethod
 from enum import Enum
 from typing import Dict, Optional, cast
-import pandas as pd  # type: ignore[import]
+import pandas as pd
 from pt_trading.results import BacktestResult
 from pt_trading.trading_pair import TradingPair
@@ -12,13 +14,24 @@ NanoPerMin = 1e9
 class PairsTradingFitMethod(ABC):
     TRADES_COLUMNS = [
         "time",
-        "action",
         "symbol",
+        "side",
+        "action",
         "price",
         "disequilibrium",
         "scaled_disequilibrium",
+        "signed_scaled_disequilibrium",
         "pair",
     ]

+    @staticmethod
+    def create(config: Dict) -> PairsTradingFitMethod:
+        import importlib
+
+        fit_method_class_name = config.get("fit_method_class", None)
+        assert fit_method_class_name is not None
+        module_name, class_name = fit_method_class_name.rsplit(".", 1)
+        module = importlib.import_module(module_name)
+        fit_method = getattr(module, class_name)()
+        return cast(PairsTradingFitMethod, fit_method)

     @abstractmethod
     def run_pair(
@@ -28,9 +41,12 @@ class PairsTradingFitMethod(ABC):
     @abstractmethod
     def reset(self) -> None: ...

+    @abstractmethod
+    def create_trading_pair(
+        self,
+        config: Dict,
+        market_data: pd.DataFrame,
+        symbol_a: str,
+        symbol_b: str,
+    ) -> TradingPair: ...

-class PairState(Enum):
-    INITIAL = 1
-    OPEN = 2
-    CLOSED = 3
-    CLOSED_POSITIONS = 4
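Usage sketch for the new `PairsTradingFitMethod.create()` factory above, assuming the package is on the import path; the dotted `fit_method_class` string from the `.cfg` files selects the concrete class:

```python
from pt_trading.fit_method import PairsTradingFitMethod

config = {"fit_method_class": "pt_trading.vecm_rolling_fit.VECMRollingFit"}
fit_method = PairsTradingFitMethod.create(config)  # importlib resolves the dotted path
fit_method.reset()
print(type(fit_method).__name__)  # -> "VECMRollingFit"
```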

View File

@@ -46,7 +46,7 @@ def create_result_database(db_path: str) -> None:
     if db_dir and not os.path.exists(db_dir):
         os.makedirs(db_dir, exist_ok=True)
         print(f"Created directory: {db_dir}")
     conn = sqlite3.connect(db_path)
     cursor = conn.cursor()
@@ -68,7 +68,8 @@ def create_result_database(db_path: str) -> None:
             close_quantity INTEGER,
             close_disequilibrium REAL,
             symbol_return REAL,
-            pair_return REAL
+            pair_return REAL,
+            close_condition TEXT
         )
         """
     )
@@ -120,8 +121,8 @@ def store_config_in_database(
     config_file_path: str,
     config: Dict,
     fit_method_class: str,
-    datafiles: List[str],
-    instruments: List[str],
+    datafiles: List[Tuple[str, str]],
+    instruments: List[Dict[str, str]],
 ) -> None:
     """
     Store configuration information in the database for reference.
@@ -139,8 +140,13 @@ def store_config_in_database(
         config_json = json.dumps(config, indent=2, default=str)
         # Convert lists to comma-separated strings for storage
-        datafiles_str = ", ".join(datafiles)
-        instruments_str = ", ".join(instruments)
+        datafiles_str = ", ".join([f"{datafile}" for _, datafile in datafiles])
+        instruments_str = ", ".join(
+            [
+                f"{inst['symbol']}:{inst['instrument_type']}:{inst['exchange_id']}"
+                for inst in instruments
+            ]
+        )
         # Insert configuration record
         cursor.execute(
@@ -171,251 +177,23 @@ def store_config_in_database(
         traceback.print_exc()

+def convert_timestamp(timestamp: Any) -> Optional[datetime]:
+    """Convert pandas Timestamp to Python datetime object for SQLite compatibility."""
+    if timestamp is None:
+        return None
+    if isinstance(timestamp, pd.Timestamp):
+        return timestamp.to_pydatetime()
+    elif isinstance(timestamp, datetime):
+        return timestamp
+    elif isinstance(timestamp, date):
+        return datetime.combine(timestamp, datetime.min.time())
+    elif isinstance(timestamp, str):
+        return datetime.strptime(timestamp, "%Y-%m-%d %H:%M:%S")
+    elif isinstance(timestamp, int):
+        return datetime.fromtimestamp(timestamp)
+    else:
+        raise ValueError(f"Unsupported timestamp type: {type(timestamp)}")

def store_results_in_database(
db_path: str, datafile: str, bt_result: "BacktestResult"
) -> None:
"""
Store backtest results in the SQLite database.
"""
if db_path.upper() == "NONE":
return
def convert_timestamp(timestamp: Any) -> Optional[datetime]:
"""Convert pandas Timestamp to Python datetime object for SQLite compatibility."""
if timestamp is None:
return None
if hasattr(timestamp, "to_pydatetime"):
return timestamp.to_pydatetime()
return timestamp
try:
# Extract date from datafile name (assuming format like 20250528.mktdata.ohlcv.db)
filename = os.path.basename(datafile)
date_str = filename.split(".")[0] # Extract date part
# Convert to proper date format
try:
date_obj = datetime.strptime(date_str, "%Y%m%d").date()
except ValueError:
# If date parsing fails, use current date
date_obj = datetime.now().date()
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
# Process each trade from bt_result
trades = bt_result.get_trades()
for pair_name, symbols in trades.items():
# Calculate pair return for this pair
pair_return = 0.0
pair_trades = []
# First pass: collect all trades and calculate returns
for symbol, symbol_trades in symbols.items():
if len(symbol_trades) == 0: # No trades for this symbol
print(
f"Warning: No trades found for symbol {symbol} in pair {pair_name}"
)
continue
elif len(symbol_trades) >= 2: # Completed trades (entry + exit)
# Handle both old and new tuple formats
if len(symbol_trades[0]) == 2: # Old format: (action, price)
entry_action, entry_price = symbol_trades[0]
exit_action, exit_price = symbol_trades[1]
open_disequilibrium = 0.0 # Fallback for old format
open_scaled_disequilibrium = 0.0
close_disequilibrium = 0.0
close_scaled_disequilibrium = 0.0
open_time = datetime.now()
close_time = datetime.now()
else: # New format: (action, price, disequilibrium, scaled_disequilibrium, timestamp)
(
entry_action,
entry_price,
open_disequilibrium,
open_scaled_disequilibrium,
open_time,
) = symbol_trades[0]
(
exit_action,
exit_price,
close_disequilibrium,
close_scaled_disequilibrium,
close_time,
) = symbol_trades[1]
# Handle None values
open_disequilibrium = (
open_disequilibrium
if open_disequilibrium is not None
else 0.0
)
open_scaled_disequilibrium = (
open_scaled_disequilibrium
if open_scaled_disequilibrium is not None
else 0.0
)
close_disequilibrium = (
close_disequilibrium
if close_disequilibrium is not None
else 0.0
)
close_scaled_disequilibrium = (
close_scaled_disequilibrium
if close_scaled_disequilibrium is not None
else 0.0
)
# Convert pandas Timestamps to Python datetime objects
open_time = convert_timestamp(open_time) or datetime.now()
close_time = convert_timestamp(close_time) or datetime.now()
# Calculate actual share quantities based on funding per pair
# Split funding equally between the two positions
funding_per_position = bt_result.config["funding_per_pair"] / 2
shares = funding_per_position / entry_price
# Calculate symbol return
symbol_return = 0.0
if entry_action == "BUY" and exit_action == "SELL":
symbol_return = (exit_price - entry_price) / entry_price * 100
elif entry_action == "SELL" and exit_action == "BUY":
symbol_return = (entry_price - exit_price) / entry_price * 100
pair_return += symbol_return
pair_trades.append(
{
"symbol": symbol,
"entry_action": entry_action,
"entry_price": entry_price,
"exit_action": exit_action,
"exit_price": exit_price,
"symbol_return": symbol_return,
"open_disequilibrium": open_disequilibrium,
"open_scaled_disequilibrium": open_scaled_disequilibrium,
"close_disequilibrium": close_disequilibrium,
"close_scaled_disequilibrium": close_scaled_disequilibrium,
"open_time": open_time,
"close_time": close_time,
"shares": shares,
"is_completed": True,
}
)
# Skip one-sided trades - they will be handled by outstanding_positions table
elif len(symbol_trades) == 1:
print(
f"Skipping one-sided trade for {symbol} in pair {pair_name} - will be stored in outstanding_positions table"
)
continue
else:
# This should not happen, but handle unexpected cases
print(
f"Warning: Unexpected number of trades ({len(symbol_trades)}) for symbol {symbol} in pair {pair_name}"
)
continue
# Second pass: insert completed trade records into database
for trade in pair_trades:
# Only store completed trades in pt_bt_results table
cursor.execute(
"""
INSERT INTO pt_bt_results (
date, pair, symbol, open_time, open_side, open_price,
open_quantity, open_disequilibrium, close_time, close_side,
close_price, close_quantity, close_disequilibrium,
symbol_return, pair_return
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
date_obj,
pair_name,
trade["symbol"],
trade["open_time"],
trade["entry_action"],
trade["entry_price"],
trade["shares"],
trade["open_scaled_disequilibrium"],
trade["close_time"],
trade["exit_action"],
trade["exit_price"],
trade["shares"],
trade["close_scaled_disequilibrium"],
trade["symbol_return"],
pair_return,
),
)
# Store outstanding positions in separate table
outstanding_positions = bt_result.get_outstanding_positions()
for pos in outstanding_positions:
# Calculate position quantity (negative for SELL positions)
position_qty_a = (
pos["shares_a"] if pos["side_a"] == "BUY" else -pos["shares_a"]
)
position_qty_b = (
pos["shares_b"] if pos["side_b"] == "BUY" else -pos["shares_b"]
)
# Calculate unrealized returns
# For symbol A: (current_price - open_price) / open_price * 100 * position_direction
unrealized_return_a = (
(pos["current_px_a"] - pos["open_px_a"]) / pos["open_px_a"] * 100
) * (1 if pos["side_a"] == "BUY" else -1)
unrealized_return_b = (
(pos["current_px_b"] - pos["open_px_b"]) / pos["open_px_b"] * 100
) * (1 if pos["side_b"] == "BUY" else -1)
# Store outstanding position for symbol A
cursor.execute(
"""
INSERT INTO outstanding_positions (
date, pair, symbol, position_quantity, last_price, unrealized_return, open_price, open_side
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""",
(
date_obj,
pos["pair"],
pos["symbol_a"],
position_qty_a,
pos["current_px_a"],
unrealized_return_a,
pos["open_px_a"],
pos["side_a"],
),
)
# Store outstanding position for symbol B
cursor.execute(
"""
INSERT INTO outstanding_positions (
date, pair, symbol, position_quantity, last_price, unrealized_return, open_price, open_side
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""",
(
date_obj,
pos["pair"],
pos["symbol_b"],
position_qty_b,
pos["current_px_b"],
unrealized_return_b,
pos["open_px_b"],
pos["side_b"],
),
)
conn.commit()
conn.close()
except Exception as e:
print(f"Error storing results in database: {str(e)}")
import traceback
traceback.print_exc()
 class BacktestResult:
@@ -428,16 +206,19 @@ class BacktestResult:
         self.trades: Dict[str, Dict[str, Any]] = {}
         self.total_realized_pnl = 0.0
         self.outstanding_positions: List[Dict[str, Any]] = []
+        self.pairs_trades_: Dict[str, List[Dict[str, Any]]] = {}

     def add_trade(
         self,
         pair_nm: str,
         symbol: str,
+        side: str,
         action: str,
         price: Any,
         disequilibrium: Optional[float] = None,
         scaled_disequilibrium: Optional[float] = None,
         timestamp: Optional[datetime] = None,
+        status: Optional[str] = None,
     ) -> None:
         """Add a trade to the results tracking."""
         pair_nm = str(pair_nm)
@@ -447,7 +228,16 @@ class BacktestResult:
         if symbol not in self.trades[pair_nm]:
             self.trades[pair_nm][symbol] = []
         self.trades[pair_nm][symbol].append(
-            (action, price, disequilibrium, scaled_disequilibrium, timestamp)
+            {
+                "symbol": symbol,
+                "side": side,
+                "action": action,
+                "price": price,
+                "disequilibrium": disequilibrium,
+                "scaled_disequilibrium": scaled_disequilibrium,
+                "timestamp": timestamp,
+                "status": status,
+            }
         )

     def add_outstanding_position(self, position: Dict[str, Any]) -> None:
@@ -484,20 +274,27 @@ class BacktestResult:
         print(result)
         for row in result.itertuples():
+            side = row.side
             action = row.action
             symbol = row.symbol
             price = row.price
             disequilibrium = getattr(row, "disequilibrium", None)
             scaled_disequilibrium = getattr(row, "scaled_disequilibrium", None)
-            timestamp = getattr(row, "time", None)
+            if hasattr(row, "time"):
+                timestamp = getattr(row, "time")
+            else:
+                timestamp = convert_timestamp(row.Index)
+            status = row.status
             self.add_trade(
                 pair_nm=str(row.pair),
-                action=str(action),
                 symbol=str(symbol),
+                side=str(side),
+                action=str(action),
                 price=float(str(price)),
                 disequilibrium=disequilibrium,
                 scaled_disequilibrium=scaled_disequilibrium,
                 timestamp=timestamp,
+                status=str(status) if status is not None else "?",
             )
     def print_single_day_results(self) -> None:
@@ -523,105 +320,126 @@ class BacktestResult:
     def calculate_returns(self, all_results: Dict[str, Dict[str, Any]]) -> None:
         """Calculate and print returns by day and pair."""

+        def _symbol_return(trade1_side: str, trade1_px: float, trade2_side: str, trade2_px: float) -> float:
+            if trade1_side == "BUY" and trade2_side == "SELL":
+                return (trade2_px - trade1_px) / trade1_px * 100
+            elif trade1_side == "SELL" and trade2_side == "BUY":
+                return (trade1_px - trade2_px) / trade1_px * 100
+            else:
+                return 0

         print("\n====== Returns By Day and Pair ======")
+        trades = []
         for filename, data in all_results.items():
-            day_return = 0
+            pairs = list(data["trades"].keys())
+            for pair in pairs:
+                self.pairs_trades_[pair] = []
+                trades_dict = data["trades"][pair]
+                for symbol in trades_dict.keys():
+                    trades.extend(trades_dict[symbol])
+            trades = sorted(trades, key=lambda x: (x["timestamp"], x["symbol"]))
             print(f"\n--- {filename} ---")
             self.outstanding_positions = data["outstanding_positions"]
+            day_return = 0.0
+            for idx in range(0, len(trades), 4):
+                symbol_a = trades[idx]["symbol"]
+                trade_a_1 = trades[idx]
+                trade_a_2 = trades[idx + 2]
+                symbol_b = trades[idx + 1]["symbol"]
+                trade_b_1 = trades[idx + 1]
+                trade_b_2 = trades[idx + 3]
+                symbol_return = 0
+                assert (
+                    trade_a_1["timestamp"] < trade_a_2["timestamp"]
+                ), f"Trade 1: {trade_a_1['timestamp']} is not less than Trade 2: {trade_a_2['timestamp']}"
+                assert (
+                    trade_a_1["action"] == "OPEN" and trade_a_2["action"] == "CLOSE"
+                ), f"Trade 1: {trade_a_1['action']} and Trade 2: {trade_a_2['action']} are the same"
-            # Process each pair
-            for pair, symbols in data["trades"].items():
-                pair_return = 0
-                pair_trades = []
-                # Calculate individual symbol returns in the pair
-                for symbol, trades in symbols.items():
-                    if len(trades) == 0:
-                        continue
-                    symbol_return = 0
-                    symbol_trades = []
-                    # Process all trades sequentially for this symbol
-                    for i, trade in enumerate(trades):
-                        # Handle both old and new tuple formats
-                        if len(trade) == 2:  # Old format: (action, price)
-                            action, price = trade
-                            disequilibrium = None
-                            scaled_disequilibrium = None
-                            timestamp = None
-                        else:  # New format: (action, price, disequilibrium, scaled_disequilibrium, timestamp)
-                            action, price = trade[:2]
-                            disequilibrium = trade[2] if len(trade) > 2 else None
-                            scaled_disequilibrium = trade[3] if len(trade) > 3 else None
-                            timestamp = trade[4] if len(trade) > 4 else None
-                        symbol_trades.append((action, price, disequilibrium, scaled_disequilibrium, timestamp))
-                    # Calculate returns for all trade combinations
-                    for i in range(len(symbol_trades) - 1):
-                        trade1 = symbol_trades[i]
-                        trade2 = symbol_trades[i + 1]
-                        action1, price1, diseq1, scaled_diseq1, ts1 = trade1
-                        action2, price2, diseq2, scaled_diseq2, ts2 = trade2
-                        # Calculate return based on action combination
-                        trade_return = 0
-                        if action1 == "BUY" and action2 == "SELL":
-                            # Long position
-                            trade_return = (price2 - price1) / price1 * 100
-                        elif action1 == "SELL" and action2 == "BUY":
-                            # Short position
-                            trade_return = (price1 - price2) / price1 * 100
-                        symbol_return += trade_return
-                        # Store trade details for reporting
-                        pair_trades.append(
-                            (
-                                symbol,
-                                action1,
-                                price1,
-                                action2,
-                                price2,
-                                trade_return,
-                                scaled_diseq1,
-                                scaled_diseq2,
-                                i + 1,  # Trade sequence number
-                            )
-                        )
-                    pair_return += symbol_return
-                # Print pair returns with disequilibrium information
-                if pair_trades:
-                    print(f"  {pair}:")
-                    for (
-                        symbol,
-                        action1,
-                        price1,
-                        action2,
-                        price2,
-                        trade_return,
-                        scaled_diseq1,
-                        scaled_diseq2,
-                        trade_num,
-                    ) in pair_trades:
-                        disequil_info = ""
-                        if (
-                            scaled_diseq1 is not None
-                            and scaled_diseq2 is not None
-                        ):
-                            disequil_info = f" | Open Dis-eq: {scaled_diseq1:.2f}, Close Dis-eq: {scaled_diseq2:.2f}"
-                        print(
-                            f"    {symbol} (Trade #{trade_num}): {action1} @ ${price1:.2f}, {action2} @ ${price2:.2f}, Return: {trade_return:.2f}%{disequil_info}"
-                        )
-                    print(f"  Pair Total Return: {pair_return:.2f}%")
-                day_return += pair_return
+                # Calculate return based on action combination
+                trade_return = 0
+                symbol_a_return = _symbol_return(trade_a_1["side"], trade_a_1["price"], trade_a_2["side"], trade_a_2["price"])
+                symbol_b_return = _symbol_return(trade_b_1["side"], trade_b_1["price"], trade_b_2["side"], trade_b_2["price"])
+                pair_return = symbol_a_return + symbol_b_return
+                self.pairs_trades_[pair].append(
+                    {
+                        "symbol": symbol_a,
+                        "open_side": trade_a_1["side"],
+                        "open_action": trade_a_1["action"],
+                        "open_price": trade_a_1["price"],
+                        "close_side": trade_a_2["side"],
+                        "close_action": trade_a_2["action"],
+                        "close_price": trade_a_2["price"],
+                        "symbol_return": symbol_a_return,
+                        "open_disequilibrium": trade_a_1["disequilibrium"],
+                        "open_scaled_disequilibrium": trade_a_1["scaled_disequilibrium"],
+                        "close_disequilibrium": trade_a_2["disequilibrium"],
+                        "close_scaled_disequilibrium": trade_a_2["scaled_disequilibrium"],
+                        "open_time": trade_a_1["timestamp"],
+                        "close_time": trade_a_2["timestamp"],
+                        "shares": self.config["funding_per_pair"] / 2 / trade_a_1["price"],
+                        "is_completed": True,
+                        "close_condition": trade_a_2["status"],
+                        "pair_return": pair_return,
+                    }
+                )
+                self.pairs_trades_[pair].append(
+                    {
+                        "symbol": symbol_b,
+                        "open_side": trade_b_1["side"],
+                        "open_action": trade_b_1["action"],
+                        "open_price": trade_b_1["price"],
+                        "close_side": trade_b_2["side"],
+                        "close_action": trade_b_2["action"],
+                        "close_price": trade_b_2["price"],
+                        "symbol_return": symbol_b_return,
+                        "open_disequilibrium": trade_b_1["disequilibrium"],
+                        "open_scaled_disequilibrium": trade_b_1["scaled_disequilibrium"],
+                        "close_disequilibrium": trade_b_2["disequilibrium"],
+                        "close_scaled_disequilibrium": trade_b_2["scaled_disequilibrium"],
+                        "open_time": trade_b_1["timestamp"],
+                        "close_time": trade_b_2["timestamp"],
+                        "shares": self.config["funding_per_pair"] / 2 / trade_b_1["price"],
+                        "is_completed": True,
+                        "close_condition": trade_b_2["status"],
+                        "pair_return": pair_return,
+                    }
+                )
+                # Print pair returns with disequilibrium information
+                if pair in self.pairs_trades_:
+                    print(f"{pair}:")
+                    pair_return = 0.0
+                    for trd in self.pairs_trades_[pair]:
+                        disequil_info = ""
+                        if (
+                            trd["open_scaled_disequilibrium"] is not None
+                            and trd["close_scaled_disequilibrium"] is not None
+                        ):
+                            disequil_info = (
+                                f" | Open Dis-eq: {trd['open_scaled_disequilibrium']:.2f},"
+                                f" Close Dis-eq: {trd['close_scaled_disequilibrium']:.2f}"
+                            )
+                        print(
+                            f"  {trd['open_time'].time()}-{trd['close_time'].time()} {trd['symbol']}: "
+                            f" {trd['open_side']} @ ${trd['open_price']:.2f},"
+                            f" {trd['close_side']} @ ${trd['close_price']:.2f},"
+                            f" Return: {trd['symbol_return']:.2f}%{disequil_info}"
+                        )
+                        pair_return += trd["symbol_return"]
+                    print(f"  Pair Total Return: {pair_return:.2f}%")
+                    day_return += pair_return
             # Print day total return and add to global realized PnL
             if day_return != 0:
@@ -698,7 +516,7 @@ class BacktestResult:
         print("-" * 100)
         total_value += pos["total_current_value"]
         print(f"{'TOTAL OUTSTANDING VALUE':<80} ${total_value:<12.2f}")
@@ -734,7 +552,7 @@ class BacktestResult:
         last_row = pair_result_df.loc[last_row_index]
         last_tstamp = last_row["tstamp"]
-        colname_a, colname_b = pair.colnames()
+        colname_a, colname_b = pair.exec_prices_colnames()
         last_px_a = last_row[colname_a]
         last_px_b = last_row[colname_b]
@@ -793,3 +611,131 @@ class BacktestResult:
         )
         return current_value_a, current_value_b, total_current_value
def store_results_in_database(
self, db_path: str, day: str
) -> None:
"""
Store backtest results in the SQLite database.
"""
if db_path.upper() == "NONE":
return
try:
# Extract date from datafile name (assuming format like 20250528.mktdata.ohlcv.db)
date_str = day
# Convert to proper date format
try:
date_obj = datetime.strptime(date_str, "%Y%m%d").date()
except ValueError:
# If date parsing fails, use current date
date_obj = datetime.now().date()
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
# Process each trade from bt_result
trades = self.get_trades()
for pair_name, _ in trades.items():
# Second pass: insert completed trade records into database
for trade_pair in sorted(self.pairs_trades_[pair_name], key=lambda x: x["open_time"]):
# Only store completed trades in pt_bt_results table
cursor.execute(
"""
INSERT INTO pt_bt_results (
date, pair, symbol, open_time, open_side, open_price,
open_quantity, open_disequilibrium, close_time, close_side,
close_price, close_quantity, close_disequilibrium,
symbol_return, pair_return, close_condition
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
date_obj,
pair_name,
trade_pair["symbol"],
trade_pair["open_time"],
trade_pair["open_side"],
trade_pair["open_price"],
trade_pair["shares"],
trade_pair["open_scaled_disequilibrium"],
trade_pair["close_time"],
trade_pair["close_side"],
trade_pair["close_price"],
trade_pair["shares"],
trade_pair["close_scaled_disequilibrium"],
trade_pair["symbol_return"],
trade_pair["pair_return"],
trade_pair["close_condition"]
),
)
# Store outstanding positions in separate table
outstanding_positions = self.get_outstanding_positions()
for pos in outstanding_positions:
# Calculate position quantity (negative for SELL positions)
position_qty_a = (
pos["shares_a"] if pos["side_a"] == "BUY" else -pos["shares_a"]
)
position_qty_b = (
pos["shares_b"] if pos["side_b"] == "BUY" else -pos["shares_b"]
)
# Calculate unrealized returns
# For symbol A: (current_price - open_price) / open_price * 100 * position_direction
unrealized_return_a = (
(pos["current_px_a"] - pos["open_px_a"]) / pos["open_px_a"] * 100
) * (1 if pos["side_a"] == "BUY" else -1)
unrealized_return_b = (
(pos["current_px_b"] - pos["open_px_b"]) / pos["open_px_b"] * 100
) * (1 if pos["side_b"] == "BUY" else -1)
# Store outstanding position for symbol A
cursor.execute(
"""
INSERT INTO outstanding_positions (
date, pair, symbol, position_quantity, last_price, unrealized_return, open_price, open_side
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""",
(
date_obj,
pos["pair"],
pos["symbol_a"],
position_qty_a,
pos["current_px_a"],
unrealized_return_a,
pos["open_px_a"],
pos["side_a"],
),
)
# Store outstanding position for symbol B
cursor.execute(
"""
INSERT INTO outstanding_positions (
date, pair, symbol, position_quantity, last_price, unrealized_return, open_price, open_side
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""",
(
date_obj,
pos["pair"],
pos["symbol_b"],
position_qty_b,
pos["current_px_b"],
unrealized_return_b,
pos["open_px_b"],
pos["side_b"],
),
)
conn.commit()
conn.close()
except Exception as e:
print(f"Error storing results in database: {str(e)}")
import traceback
traceback.print_exc()
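A hedged sketch of reading back what `store_results_in_database()` writes; the database path is illustrative, the column names come from the INSERT statements above:

```python
# Query the per-trade results table produced by the backtest run.
import sqlite3
import pandas as pd

conn = sqlite3.connect("results/backtest_results.db")  # hypothetical path
df = pd.read_sql_query(
    "SELECT date, pair, symbol, open_side, open_price, close_side, close_price,"
    " symbol_return, pair_return, close_condition"
    " FROM pt_bt_results ORDER BY date, pair, open_time",
    conn,
)
conn.close()
# pair_return is repeated on both legs of a pair, so taking one value per
# (date, pair) group gives the per-pair result.
print(df.groupby(["date", "pair"])["pair_return"].last())
```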


@@ -0,0 +1,319 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Dict, Optional, cast
import pandas as pd # type: ignore[import]
from pt_trading.fit_method import PairsTradingFitMethod
from pt_trading.results import BacktestResult
from pt_trading.trading_pair import PairState, TradingPair
from statsmodels.tsa.vector_ar.vecm import VECM, VECMResults
NanoPerMin = 1e9
class RollingFit(PairsTradingFitMethod):
"""
N O T E:
=========
- This class remains abstract
- The following methods are to be implemented in the subclass:
- create_trading_pair()
=========
"""
def __init__(self) -> None:
super().__init__()
def run_pair(
self, pair: TradingPair, bt_result: BacktestResult
) -> Optional[pd.DataFrame]:
print(f"***{pair}*** STARTING....")
config = pair.config_
curr_training_start_idx = pair.get_begin_index()
end_index = pair.get_end_index()
pair.user_data_["state"] = PairState.INITIAL
# Initialize trades DataFrame with proper dtypes to avoid concatenation warnings
pair.user_data_["trades"] = pd.DataFrame(columns=self.TRADES_COLUMNS).astype(
{
"time": "datetime64[ns]",
"symbol": "string",
"side": "string",
"action": "string",
"price": "float64",
"disequilibrium": "float64",
"scaled_disequilibrium": "float64",
"pair": "object",
}
)
training_minutes = config["training_minutes"]
curr_predicted_row_idx = 0
while True:
print(curr_training_start_idx, end="\r")
pair.get_datasets(
training_minutes=training_minutes,
training_start_index=curr_training_start_idx,
testing_size=1,
)
if len(pair.training_df_) < training_minutes:
print(
f"{pair}: current offset={curr_training_start_idx}"
f" * Training data length={len(pair.training_df_)} < {training_minutes}"
" * Not enough training data. Completing the job."
)
break
try:
# ================================ PREDICTION ================================
self.pair_predict_result_ = pair.predict()
except Exception as e:
raise RuntimeError(
f"{pair}: TrainingPrediction failed: {str(e)}"
) from e
# break
curr_training_start_idx += 1
if curr_training_start_idx > end_index:
break
curr_predicted_row_idx += 1
self._create_trading_signals(pair, config, bt_result)
print(f"***{pair}*** FINISHED *** Num Trades:{len(pair.user_data_['trades'])}")
return pair.get_trades()
def _create_trading_signals(
self, pair: TradingPair, config: Dict, bt_result: BacktestResult
) -> None:
predicted_df = self.pair_predict_result_
assert predicted_df is not None
open_threshold = config["dis-equilibrium_open_trshld"]
close_threshold = config["dis-equilibrium_close_trshld"]
for curr_predicted_row_idx in range(len(predicted_df)):
pred_row = predicted_df.iloc[curr_predicted_row_idx]
scaled_disequilibrium = pred_row["scaled_disequilibrium"]
if pair.user_data_["state"] in [
PairState.INITIAL,
PairState.CLOSE,
PairState.CLOSE_POSITION,
PairState.CLOSE_STOP_LOSS,
PairState.CLOSE_STOP_PROFIT,
]:
if scaled_disequilibrium >= open_threshold:
open_trades = self._get_open_trades(
pair, row=pred_row, open_threshold=open_threshold
)
if open_trades is not None:
open_trades["status"] = PairState.OPEN.name
print(f"OPEN TRADES:\n{open_trades}")
pair.add_trades(open_trades)
pair.user_data_["state"] = PairState.OPEN
pair.on_open_trades(open_trades)
elif pair.user_data_["state"] == PairState.OPEN:
if scaled_disequilibrium <= close_threshold:
close_trades = self._get_close_trades(
pair, row=pred_row, close_threshold=close_threshold
)
if close_trades is not None:
close_trades["status"] = PairState.CLOSE.name
print(f"CLOSE TRADES:\n{close_trades}")
pair.add_trades(close_trades)
pair.user_data_["state"] = PairState.CLOSE
pair.on_close_trades(close_trades)
elif pair.to_stop_close_conditions(predicted_row=pred_row):
close_trades = self._get_close_trades(
pair, row=pred_row, close_threshold=close_threshold
)
if close_trades is not None:
close_trades["status"] = pair.user_data_[
"stop_close_state"
].name
print(f"STOP CLOSE TRADES:\n{close_trades}")
pair.add_trades(close_trades)
pair.user_data_["state"] = pair.user_data_["stop_close_state"]
pair.on_close_trades(close_trades)
# Outstanding positions
if pair.user_data_["state"] == PairState.OPEN:
print(f"{pair}: *** Position is NOT CLOSED. ***")
# outstanding positions
if config["close_outstanding_positions"]:
close_position_row = pd.Series(pair.market_data_.iloc[-2])
close_position_row["disequilibrium"] = 0.0
close_position_row["scaled_disequilibrium"] = 0.0
close_position_row["signed_scaled_disequilibrium"] = 0.0
close_position_trades = self._get_close_trades(
pair=pair, row=close_position_row, close_threshold=close_threshold
)
if close_position_trades is not None:
close_position_trades["status"] = PairState.CLOSE_POSITION.name
print(f"CLOSE_POSITION TRADES:\n{close_position_trades}")
pair.add_trades(close_position_trades)
pair.user_data_["state"] = PairState.CLOSE_POSITION
pair.on_close_trades(close_position_trades)
else:
if predicted_df is not None:
bt_result.handle_outstanding_position(
pair=pair,
pair_result_df=predicted_df,
last_row_index=0,
open_side_a=pair.user_data_["open_side_a"],
open_side_b=pair.user_data_["open_side_b"],
open_px_a=pair.user_data_["open_px_a"],
open_px_b=pair.user_data_["open_px_b"],
open_tstamp=pair.user_data_["open_tstamp"],
)
def _get_open_trades(
self, pair: TradingPair, row: pd.Series, open_threshold: float
) -> Optional[pd.DataFrame]:
colname_a, colname_b = pair.exec_prices_colnames()
open_row = row
open_tstamp = open_row["tstamp"]
open_disequilibrium = open_row["disequilibrium"]
open_scaled_disequilibrium = open_row["scaled_disequilibrium"]
signed_scaled_disequilibrium = open_row["signed_scaled_disequilibrium"]
open_px_a = open_row[f"{colname_a}"]
open_px_b = open_row[f"{colname_b}"]
# creating the trades
print(f"OPEN_TRADES: {row["tstamp"]} {open_scaled_disequilibrium=}")
if open_disequilibrium > 0:
open_side_a = "SELL"
open_side_b = "BUY"
close_side_a = "BUY"
close_side_b = "SELL"
else:
open_side_a = "BUY"
open_side_b = "SELL"
close_side_a = "SELL"
close_side_b = "BUY"
# save closing sides
pair.user_data_["open_side_a"] = open_side_a
pair.user_data_["open_side_b"] = open_side_b
pair.user_data_["open_px_a"] = open_px_a
pair.user_data_["open_px_b"] = open_px_b
pair.user_data_["open_tstamp"] = open_tstamp
pair.user_data_["close_side_a"] = close_side_a
pair.user_data_["close_side_b"] = close_side_b
# create opening trades
trd_signal_tuples = [
(
open_tstamp,
pair.symbol_a_,
open_side_a,
"OPEN",
open_px_a,
open_disequilibrium,
open_scaled_disequilibrium,
signed_scaled_disequilibrium,
pair,
),
(
open_tstamp,
pair.symbol_b_,
open_side_b,
"OPEN",
open_px_b,
open_disequilibrium,
open_scaled_disequilibrium,
signed_scaled_disequilibrium,
pair,
),
]
# Create DataFrame with explicit dtypes to avoid concatenation warnings
df = pd.DataFrame(
trd_signal_tuples,
columns=self.TRADES_COLUMNS,
)
# Ensure consistent dtypes
return df.astype(
{
"time": "datetime64[ns]",
"action": "string",
"symbol": "string",
"price": "float64",
"disequilibrium": "float64",
"scaled_disequilibrium": "float64",
"signed_scaled_disequilibrium": "float64",
"pair": "object",
}
)
def _get_close_trades(
self, pair: TradingPair, row: pd.Series, close_threshold: float
) -> Optional[pd.DataFrame]:
colname_a, colname_b = pair.exec_prices_colnames()
close_row = row
close_tstamp = close_row["tstamp"]
close_disequilibrium = close_row["disequilibrium"]
close_scaled_disequilibrium = close_row["scaled_disequilibrium"]
signed_scaled_disequilibrium = close_row["signed_scaled_disequilibrium"]
close_px_a = close_row[f"{colname_a}"]
close_px_b = close_row[f"{colname_b}"]
close_side_a = pair.user_data_["close_side_a"]
close_side_b = pair.user_data_["close_side_b"]
trd_signal_tuples = [
(
close_tstamp,
pair.symbol_a_,
close_side_a,
"CLOSE",
close_px_a,
close_disequilibrium,
close_scaled_disequilibrium,
signed_scaled_disequilibrium,
pair,
),
(
close_tstamp,
pair.symbol_b_,
close_side_b,
"CLOSE",
close_px_b,
close_disequilibrium,
close_scaled_disequilibrium,
signed_scaled_disequilibrium,
pair,
),
]
# Add tuples to data frame with explicit dtypes to avoid concatenation warnings
df = pd.DataFrame(
trd_signal_tuples,
columns=self.TRADES_COLUMNS,
)
# Ensure consistent dtypes
return df.astype(
{
"time": "datetime64[ns]",
"action": "string",
"symbol": "string",
"price": "float64",
"disequilibrium": "float64",
"scaled_disequilibrium": "float64",
"signed_scaled_disequilibrium": "float64",
"pair": "object",
}
)
def reset(self) -> None:
curr_training_start_idx = 0
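Per the NOTE in `RollingFit` above, a concrete fit method only has to supply `create_trading_pair()`; the `create()` factory plus `run_pair()` then drive it. A wiring sketch under assumed names, with the `BacktestResult` constructor signature being our assumption:

```python
from typing import Optional

import pandas as pd

from pt_trading.fit_method import PairsTradingFitMethod
from pt_trading.results import BacktestResult

def run_one_pair(
    config: dict,
    market_data: pd.DataFrame,
    symbol_a: str,
    symbol_b: str,
) -> Optional[pd.DataFrame]:
    # Factory resolves e.g. "pt_trading.vecm_rolling_fit.VECMRollingFit"
    fit_method = PairsTradingFitMethod.create(config)
    # Subclass hook: builds the concrete TradingPair for this fit method
    pair = fit_method.create_trading_pair(config, market_data, symbol_a, symbol_b)
    bt_result = BacktestResult(config)  # assumed constructor signature
    return fit_method.run_pair(pair, bt_result)
```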


@@ -1,362 +0,0 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Dict, Optional, cast
import pandas as pd # type: ignore[import]
from pt_trading.fit_method import PairState, PairsTradingFitMethod
from pt_trading.results import BacktestResult
from pt_trading.trading_pair import TradingPair
NanoPerMin = 1e9
class SlidingFit(PairsTradingFitMethod):
def __init__(self) -> None:
super().__init__()
def run_pair(
self, pair: TradingPair, bt_result: BacktestResult
) -> Optional[pd.DataFrame]:
print(f"***{pair}*** STARTING....")
config = pair.config_
curr_training_start_idx = pair.get_begin_index()
end_index = pair.get_end_index()
pair.user_data_["state"] = PairState.INITIAL
# Initialize trades DataFrame with proper dtypes to avoid concatenation warnings
pair.user_data_["trades"] = pd.DataFrame(columns=self.TRADES_COLUMNS).astype({
"time": "datetime64[ns]",
"action": "string",
"symbol": "string",
"price": "float64",
"disequilibrium": "float64",
"scaled_disequilibrium": "float64",
"pair": "object"
})
pair.user_data_["is_cointegrated"] = False
training_minutes = config["training_minutes"]
curr_predicted_row_idx = 0
while True:
print(curr_training_start_idx, end="\r")
pair.get_datasets(
training_minutes=training_minutes,
training_start_index=curr_training_start_idx,
testing_size=1,
)
if len(pair.training_df_) < training_minutes:
print(
f"{pair}: current offset={curr_training_start_idx}"
f" * Training data length={len(pair.training_df_)} < {training_minutes}"
" * Not enough training data. Completing the job."
)
break
try:
# ================================ TRAINING ================================
is_cointegrated = pair.train_pair()
except Exception as e:
raise RuntimeError(f"{pair}: Training failed: {str(e)}") from e
if pair.user_data_["is_cointegrated"] != is_cointegrated:
pair.user_data_["is_cointegrated"] = is_cointegrated
if not is_cointegrated:
if pair.user_data_["state"] == PairState.OPEN:
print(
f"{pair} {curr_training_start_idx} LOST COINTEGRATION. Consider closing positions..."
)
else:
print(
f"{pair} {curr_training_start_idx} IS NOT COINTEGRATED. Moving on"
)
else:
print("*" * 80)
print(
f"Pair {pair} ({curr_training_start_idx}) IS COINTEGRATED"
)
print("*" * 80)
if not is_cointegrated:
curr_training_start_idx += 1
continue
try:
# ================================ PREDICTION ================================
pair.predict()
except Exception as e:
raise RuntimeError(f"{pair}: Prediction failed: {str(e)}") from e
# break
curr_training_start_idx += 1
if curr_training_start_idx > end_index:
break
curr_predicted_row_idx += 1
self._create_trading_signals(pair, config, bt_result)
print(f"***{pair}*** FINISHED ... {len(pair.user_data_['trades'])}")
return pair.get_trades()
def _create_trading_signals(
self, pair: TradingPair, config: Dict, bt_result: BacktestResult
) -> None:
if pair.predicted_df_ is None:
print(f"{pair.market_data_.iloc[0]['tstamp']} {pair}: No predicted data")
return
open_threshold = config["dis-equilibrium_open_trshld"]
close_threshold = config["dis-equilibrium_close_trshld"]
for curr_predicted_row_idx in range(len(pair.predicted_df_)):
pred_row = pair.predicted_df_.iloc[curr_predicted_row_idx]
if pair.user_data_["state"] in [PairState.INITIAL, PairState.CLOSED, PairState.CLOSED_POSITIONS]:
open_trades = self._get_open_trades(
pair, row=pred_row, open_threshold=open_threshold
)
if open_trades is not None:
open_trades["status"] = "OPEN"
print(f"OPEN TRADES:\n{open_trades}")
pair.add_trades(open_trades)
pair.user_data_["state"] = PairState.OPEN
elif pair.user_data_["state"] == PairState.OPEN:
close_trades = self._get_close_trades(
pair, row=pred_row, close_threshold=close_threshold
)
if close_trades is not None:
close_trades["status"] = "CLOSE"
print(f"CLOSE TRADES:\n{close_trades}")
pair.add_trades(close_trades)
pair.user_data_["state"] = PairState.CLOSED
# Outstanding positions
if pair.user_data_["state"] == PairState.OPEN:
print(
f"{pair}: *** Position is NOT CLOSED. ***"
)
# outstanding positions
if config["close_outstanding_positions"]:
close_position_trades = self._get_close_position_trades(
pair=pair,
row=pred_row,
close_threshold=close_threshold,
)
if close_position_trades is not None:
close_position_trades["status"] = "CLOSE_POSITION"
print(f"CLOSE_POSITION TRADES:\n{close_position_trades}")
pair.add_trades(close_position_trades)
pair.user_data_["state"] = PairState.CLOSED_POSITIONS
else:
if pair.predicted_df_ is not None:
bt_result.handle_outstanding_position(
pair=pair,
pair_result_df=pair.predicted_df_,
last_row_index=0,
open_side_a=pair.user_data_["open_side_a"],
open_side_b=pair.user_data_["open_side_b"],
open_px_a=pair.user_data_["open_px_a"],
open_px_b=pair.user_data_["open_px_b"],
open_tstamp=pair.user_data_["open_tstamp"],
)
def _get_open_trades(
self, pair: TradingPair, row: pd.Series, open_threshold: float
) -> Optional[pd.DataFrame]:
colname_a, colname_b = pair.colnames()
assert pair.predicted_df_ is not None
predicted_df = pair.predicted_df_
# Check if we have any data to work with
if len(predicted_df) == 0:
return None
open_row = row
open_tstamp = open_row["tstamp"]
open_disequilibrium = open_row["disequilibrium"]
open_scaled_disequilibrium = open_row["scaled_disequilibrium"]
open_px_a = open_row[f"{colname_a}"]
open_px_b = open_row[f"{colname_b}"]
if open_scaled_disequilibrium < open_threshold:
return None
# creating the trades
print(f"OPEN_TRADES: {row["tstamp"]} {open_scaled_disequilibrium=}")
if open_disequilibrium > 0:
open_side_a = "SELL"
open_side_b = "BUY"
close_side_a = "BUY"
close_side_b = "SELL"
else:
open_side_a = "BUY"
open_side_b = "SELL"
close_side_a = "SELL"
close_side_b = "BUY"
# save closing sides
pair.user_data_["open_side_a"] = open_side_a
pair.user_data_["open_side_b"] = open_side_b
pair.user_data_["open_px_a"] = open_px_a
pair.user_data_["open_px_b"] = open_px_b
pair.user_data_["open_tstamp"] = open_tstamp
pair.user_data_["close_side_a"] = close_side_a
pair.user_data_["close_side_b"] = close_side_b
# create opening trades
trd_signal_tuples = [
(
open_tstamp,
open_side_a,
pair.symbol_a_,
open_px_a,
open_disequilibrium,
open_scaled_disequilibrium,
pair,
),
(
open_tstamp,
open_side_b,
pair.symbol_b_,
open_px_b,
open_disequilibrium,
open_scaled_disequilibrium,
pair,
),
]
# Create DataFrame with explicit dtypes to avoid concatenation warnings
df = pd.DataFrame(
trd_signal_tuples,
columns=self.TRADES_COLUMNS,
)
# Ensure consistent dtypes
return df.astype({
"time": "datetime64[ns]",
"action": "string",
"symbol": "string",
"price": "float64",
"disequilibrium": "float64",
"scaled_disequilibrium": "float64",
"pair": "object"
})
def _get_close_trades(
self, pair: TradingPair, row: pd.Series, close_threshold: float
) -> Optional[pd.DataFrame]:
colname_a, colname_b = pair.colnames()
assert pair.predicted_df_ is not None
if len(pair.predicted_df_) == 0:
return None
close_row = row
close_tstamp = close_row["tstamp"]
close_disequilibrium = close_row["disequilibrium"]
close_scaled_disequilibrium = close_row["scaled_disequilibrium"]
close_px_a = close_row[f"{colname_a}"]
close_px_b = close_row[f"{colname_b}"]
close_side_a = pair.user_data_["close_side_a"]
close_side_b = pair.user_data_["close_side_b"]
if close_scaled_disequilibrium > close_threshold:
return None
trd_signal_tuples = [
(
close_tstamp,
close_side_a,
pair.symbol_a_,
close_px_a,
close_disequilibrium,
close_scaled_disequilibrium,
pair,
),
(
close_tstamp,
close_side_b,
pair.symbol_b_,
close_px_b,
close_disequilibrium,
close_scaled_disequilibrium,
pair,
),
]
# Add tuples to data frame with explicit dtypes to avoid concatenation warnings
df = pd.DataFrame(
trd_signal_tuples,
columns=self.TRADES_COLUMNS,
)
# Ensure consistent dtypes
return df.astype({
"time": "datetime64[ns]",
"action": "string",
"symbol": "string",
"price": "float64",
"disequilibrium": "float64",
"scaled_disequilibrium": "float64",
"pair": "object"
})
def _get_close_position_trades(
self, pair: TradingPair, row: pd.Series, close_threshold: float
) -> Optional[pd.DataFrame]:
colname_a, colname_b = pair.colnames()
assert pair.predicted_df_ is not None
if len(pair.predicted_df_) == 0:
return None
close_position_row = row
close_position_tstamp = close_position_row["tstamp"]
close_position_disequilibrium = close_position_row["disequilibrium"]
close_position_scaled_disequilibrium = close_position_row["scaled_disequilibrium"]
close_position_px_a = close_position_row[f"{colname_a}"]
close_position_px_b = close_position_row[f"{colname_b}"]
close_position_side_a = pair.user_data_["close_side_a"]
close_position_side_b = pair.user_data_["close_side_b"]
trd_signal_tuples = [
(
close_position_tstamp,
close_position_side_a,
pair.symbol_a_,
close_position_px_a,
close_position_disequilibrium,
close_position_scaled_disequilibrium,
pair,
),
(
close_position_tstamp,
close_position_side_b,
pair.symbol_b_,
close_position_px_b,
close_position_disequilibrium,
close_position_scaled_disequilibrium,
pair,
),
]
# Add tuples to data frame with explicit dtypes to avoid concatenation warnings
df = pd.DataFrame(
trd_signal_tuples,
columns=self.TRADES_COLUMNS,
)
# Ensure consistent dtypes
return df.astype({
"time": "datetime64[ns]",
"action": "string",
"symbol": "string",
"price": "float64",
"disequilibrium": "float64",
"scaled_disequilibrium": "float64",
"pair": "object"
})
def reset(self) -> None:
curr_training_start_idx = 0


@@ -1,220 +0,0 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Dict, Optional, cast
import pandas as pd # type: ignore[import]
from pt_trading.results import BacktestResult
from pt_trading.trading_pair import TradingPair
from pt_trading.fit_method import PairsTradingFitMethod
NanoPerMin = 1e9
class StaticFit(PairsTradingFitMethod):
def run_pair(
self, pair: TradingPair, bt_result: BacktestResult
) -> Optional[pd.DataFrame]: # abstractmethod
config = pair.config_
pair.get_datasets(training_minutes=config["training_minutes"])
try:
is_cointegrated = pair.train_pair()
if not is_cointegrated:
print(f"{pair} IS NOT COINTEGRATED")
return None
except Exception as e:
print(f"{pair}: Training failed: {str(e)}")
return None
try:
pair.predict()
except Exception as e:
print(f"{pair}: Prediction failed: {str(e)}")
return None
pair_trades = self.create_trading_signals(
pair=pair, config=config, result=bt_result
)
return pair_trades
def create_trading_signals(
self, pair: TradingPair, config: Dict, result: BacktestResult
) -> pd.DataFrame:
beta = pair.vecm_fit_.beta # type: ignore
colname_a, colname_b = pair.colnames()
predicted_df = pair.predicted_df_
if predicted_df is None:
# Return empty DataFrame with correct columns and dtypes
return pd.DataFrame(columns=self.TRADES_COLUMNS).astype({
"time": "datetime64[ns]",
"action": "string",
"symbol": "string",
"price": "float64",
"disequilibrium": "float64",
"scaled_disequilibrium": "float64",
"pair": "object"
})
open_threshold = config["dis-equilibrium_open_trshld"]
close_threshold = config["dis-equilibrium_close_trshld"]
# Iterate through the testing dataset to find the first trading opportunity
open_row_index = None
for row_idx in range(len(predicted_df)):
curr_disequilibrium = predicted_df["scaled_disequilibrium"][row_idx]
# Check if current row has sufficient disequilibrium (not near-zero)
if curr_disequilibrium >= open_threshold:
open_row_index = row_idx
break
# If no row with sufficient disequilibrium found, skip this pair
if open_row_index is None:
print(f"{pair}: Insufficient disequilibrium in testing dataset. Skipping.")
return pd.DataFrame()
# Look for close signal starting from the open position
trading_signals_df = (
predicted_df["scaled_disequilibrium"][open_row_index:] < close_threshold
)
# Adjust indices to account for the offset from open_row_index
close_row_index = None
for idx, value in trading_signals_df.items():
if value:
close_row_index = idx
break
open_row = predicted_df.loc[open_row_index]
open_px_a = predicted_df.at[open_row_index, f"{colname_a}"]
open_px_b = predicted_df.at[open_row_index, f"{colname_b}"]
open_tstamp = predicted_df.at[open_row_index, "tstamp"]
open_disequilibrium = open_row["disequilibrium"]
open_scaled_disequilibrium = open_row["scaled_disequilibrium"]
abs_beta = abs(beta[1])
pred_px_b = predicted_df.loc[open_row_index][f"{colname_b}_pred"]
pred_px_a = predicted_df.loc[open_row_index][f"{colname_a}_pred"]
if pred_px_b * abs_beta - pred_px_a > 0:
open_side_a = "BUY"
open_side_b = "SELL"
close_side_a = "SELL"
close_side_b = "BUY"
else:
open_side_b = "BUY"
open_side_a = "SELL"
close_side_b = "SELL"
close_side_a = "BUY"
# If no close signal found, print position and unrealized PnL
if close_row_index is None:
last_row_index = len(predicted_df) - 1
# Use the new method from BacktestResult to handle outstanding positions
result.handle_outstanding_position(
pair=pair,
pair_result_df=predicted_df,
last_row_index=last_row_index,
open_side_a=open_side_a,
open_side_b=open_side_b,
open_px_a=float(open_px_a),
open_px_b=float(open_px_b),
open_tstamp=pd.Timestamp(open_tstamp),
)
# Return only open trades (no close trades)
trd_signal_tuples = [
(
open_tstamp,
open_side_a,
pair.symbol_a_,
open_px_a,
open_disequilibrium,
open_scaled_disequilibrium,
pair,
),
(
open_tstamp,
open_side_b,
pair.symbol_b_,
open_px_b,
open_disequilibrium,
open_scaled_disequilibrium,
pair,
),
]
else:
# Close signal found - create complete trade
close_row = predicted_df.loc[close_row_index]
close_tstamp = close_row["tstamp"]
close_disequilibrium = close_row["disequilibrium"]
close_scaled_disequilibrium = close_row["scaled_disequilibrium"]
close_px_a = close_row[f"{colname_a}"]
close_px_b = close_row[f"{colname_b}"]
print(f"{pair}: Close signal found at index {close_row_index}")
trd_signal_tuples = [
(
open_tstamp,
open_side_a,
pair.symbol_a_,
open_px_a,
open_disequilibrium,
open_scaled_disequilibrium,
pair,
),
(
open_tstamp,
open_side_b,
pair.symbol_b_,
open_px_b,
open_disequilibrium,
open_scaled_disequilibrium,
pair,
),
(
close_tstamp,
close_side_a,
pair.symbol_a_,
close_px_a,
close_disequilibrium,
close_scaled_disequilibrium,
pair,
),
(
close_tstamp,
close_side_b,
pair.symbol_b_,
close_px_b,
close_disequilibrium,
close_scaled_disequilibrium,
pair,
),
]
# Add tuples to data frame with explicit dtypes to avoid concatenation warnings
df = pd.DataFrame(
trd_signal_tuples,
columns=self.TRADES_COLUMNS,
)
# Ensure consistent dtypes
return df.astype({
"time": "datetime64[ns]",
"action": "string",
"symbol": "string",
"price": "float64",
"disequilibrium": "float64",
"scaled_disequilibrium": "float64",
"pair": "object"
})
def reset(self) -> None:
pass


@ -1,14 +1,79 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import pandas as pd # type:ignore import pandas as pd # type:ignore
from statsmodels.tsa.vector_ar.vecm import VECM, VECMResults # type:ignore
class TradingPair: class PairState(Enum):
INITIAL = 1
OPEN = 2
CLOSE = 3
CLOSE_POSITION = 4
CLOSE_STOP_LOSS = 5
CLOSE_STOP_PROFIT = 6
class CointegrationData:
EG_PVALUE_THRESHOLD = 0.05
tstamp_: pd.Timestamp
pair_: str
eg_pvalue_: float
johansen_lr1_: float
johansen_cvt_: float
eg_is_cointegrated_: bool
johansen_is_cointegrated_: bool
def __init__(self, pair: TradingPair):
training_df = pair.training_df_
assert training_df is not None
from statsmodels.tsa.vector_ar.vecm import coint_johansen
df = training_df[pair.colnames()].reset_index(drop=True)
# Run Johansen cointegration test
result = coint_johansen(df, det_order=0, k_ar_diff=1)
self.johansen_lr1_ = result.lr1[0]
self.johansen_cvt_ = result.cvt[0, 1]
self.johansen_is_cointegrated_ = self.johansen_lr1_ > self.johansen_cvt_
# Run Engle-Granger cointegration test
from statsmodels.tsa.stattools import coint # type: ignore
col1, col2 = pair.colnames()
assert training_df is not None
series1 = training_df[col1].reset_index(drop=True)
series2 = training_df[col2].reset_index(drop=True)
self.eg_pvalue_ = float(coint(series1, series2)[1])
self.eg_is_cointegrated_ = bool(self.eg_pvalue_ < self.EG_PVALUE_THRESHOLD)
self.tstamp_ = training_df.index[-1]
self.pair_ = pair.name()
def to_dict(self) -> Dict[str, Any]:
return {
"tstamp": self.tstamp_,
"pair": self.pair_,
"eg_pvalue": self.eg_pvalue_,
"johansen_lr1": self.johansen_lr1_,
"johansen_cvt": self.johansen_cvt_,
"eg_is_cointegrated": self.eg_is_cointegrated_,
"johansen_is_cointegrated": self.johansen_is_cointegrated_,
}
def __repr__(self) -> str:
return f"CointegrationData(tstamp={self.tstamp_}, pair={self.pair_}, eg_pvalue={self.eg_pvalue_}, johansen_lr1={self.johansen_lr1_}, johansen_cvt={self.johansen_cvt_}, eg_is_cointegrated={self.eg_is_cointegrated_}, johansen_is_cointegrated={self.johansen_is_cointegrated_})"
class TradingPair(ABC):
market_data_: pd.DataFrame
symbol_a_: str
symbol_b_: str
stat_model_price_: str
training_mu_: float
training_std_: float
@@ -16,54 +81,81 @@ class TradingPair:
training_df_: pd.DataFrame
testing_df_: pd.DataFrame
user_data_: Dict[str, Any]
# predicted_df_: Optional[pd.DataFrame]
def __init__(
self,
config: Dict[str, Any],
market_data: pd.DataFrame,
symbol_a: str,
symbol_b: str,
):
self.symbol_a_ = symbol_a
self.symbol_b_ = symbol_b
self.stat_model_price_ = config["stat_model_price"]
self.user_data_ = {}
self.predicted_df_ = None
self.config_ = config
self._set_market_data(market_data)
def _set_market_data(self, market_data: pd.DataFrame) -> None:
self.market_data_ = pd.DataFrame(
self._transform_dataframe(market_data)[["tstamp"] + self.colnames()]
)
self.market_data_ = self.market_data_.dropna().reset_index(drop=True)
self.market_data_["tstamp"] = pd.to_datetime(self.market_data_["tstamp"])
self.market_data_ = self.market_data_.sort_values("tstamp")
self._set_execution_price_data()
def _set_execution_price_data(self) -> None:
if "execution_price" not in self.config_:
self.market_data_[f"exec_price_{self.symbol_a_}"] = self.market_data_[f"{self.stat_model_price_}_{self.symbol_a_}"]
self.market_data_[f"exec_price_{self.symbol_b_}"] = self.market_data_[f"{self.stat_model_price_}_{self.symbol_b_}"]
return
execution_price_column = self.config_["execution_price"]["column"]
execution_price_shift = self.config_["execution_price"]["shift"]
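# shift(-N) takes the price N bars ahead as the fill price, emulating delayed execution;
# note the configured "column" is read above but the shifted series below still comes from the stat-model price column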
self.market_data_[f"exec_price_{self.symbol_a_}"] = self.market_data_[f"{self.stat_model_price_}_{self.symbol_a_}"].shift(-execution_price_shift)
self.market_data_[f"exec_price_{self.symbol_b_}"] = self.market_data_[f"{self.stat_model_price_}_{self.symbol_b_}"].shift(-execution_price_shift)
self.market_data_ = self.market_data_.dropna().reset_index(drop=True)
def get_begin_index(self) -> int:
if "trading_hours" not in self.config_:
return 0
assert "timezone" in self.config_["trading_hours"]
assert "begin_session" in self.config_["trading_hours"]
start_time = (
pd.to_datetime(self.config_["trading_hours"]["begin_session"])
.tz_localize(self.config_["trading_hours"]["timezone"])
.time()
)
mask = self.market_data_["tstamp"].dt.time >= start_time
return int(self.market_data_.index[mask].min())
def get_end_index(self) -> int:
if "trading_hours" not in self.config_:
return 0
assert "timezone" in self.config_["trading_hours"]
assert "end_session" in self.config_["trading_hours"]
end_time = (
pd.to_datetime(self.config_["trading_hours"]["end_session"])
.tz_localize(self.config_["trading_hours"]["timezone"])
.time()
)
mask = self.market_data_["tstamp"].dt.time <= end_time
return int(self.market_data_.index[mask].max())
def _transform_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
# Select only the columns we need
df_selected: pd.DataFrame = pd.DataFrame(
df[["tstamp", "symbol", self.stat_model_price_]]
)
# Start with unique timestamps
@@ -81,13 +173,13 @@ class TradingPair:
)
# Create column name like "close_COIN"
new_price_column = f"{self.stat_model_price_}_{symbol}"
# Create temporary dataframe with timestamp and price
temp_df = pd.DataFrame(
{
"tstamp": df_symbol["tstamp"],
new_price_column: df_symbol[self.stat_model_price_],
}
)
@@ -108,7 +200,7 @@ class TradingPair:
testing_start_index = training_start_index + training_minutes
self.training_df_ = self.market_data_.iloc[
training_start_index:testing_start_index, :training_minutes
].copy()
assert self.training_df_ is not None
self.training_df_ = self.training_df_.dropna().reset_index(drop=True)
@@ -125,82 +217,15 @@ class TradingPair:
def colnames(self) -> List[str]:
return [
f"{self.stat_model_price_}_{self.symbol_a_}",
f"{self.stat_model_price_}_{self.symbol_b_}",
]
def exec_prices_colnames(self) -> List[str]:
return [
f"exec_price_{self.symbol_a_}",
f"exec_price_{self.symbol_b_}",
]
def add_trades(self, trades: pd.DataFrame) -> None:
if self.user_data_["trades"] is None or len(self.user_data_["trades"]) == 0:
@@ -209,7 +234,7 @@ class TradingPair:
else:
# Ensure both DataFrames have the same columns and dtypes before concatenation
existing_trades = self.user_data_["trades"]
# If existing trades is empty, just assign the new trades
if len(existing_trades) == 0:
self.user_data_["trades"] = trades.copy()
@@ -223,68 +248,123 @@ class TradingPair:
trades[col] = pd.Timestamp.now()
elif col in ["action", "symbol"]:
trades[col] = ""
elif col in [
"price",
"disequilibrium",
"scaled_disequilibrium",
]:
trades[col] = 0.0
elif col == "pair":
trades[col] = None
else:
trades[col] = None
# Concatenate with explicit dtypes to avoid warnings
self.user_data_["trades"] = pd.concat(
[existing_trades, trades], ignore_index=True, copy=False
)
def get_trades(self) -> pd.DataFrame:
return (
self.user_data_["trades"] if "trades" in self.user_data_ else pd.DataFrame()
)
def cointegration_check(self) -> Optional[pd.DataFrame]:
print(f"***{self}*** STARTING....")
config = self.config_
curr_training_start_idx = 0
COINTEGRATION_DATA_COLUMNS = {
"tstamp": "datetime64[ns]",
"pair": "string",
"eg_pvalue": "float64",
"johansen_lr1": "float64",
"johansen_cvt": "float64",
"eg_is_cointegrated": "bool",
"johansen_is_cointegrated": "bool",
}
# Initialize the results DataFrame with the expected columns
result: pd.DataFrame = pd.DataFrame(
columns=list(COINTEGRATION_DATA_COLUMNS.keys())
)  # .astype(COINTEGRATION_DATA_COLUMNS)
training_minutes = config["training_minutes"]
while True:
print(curr_training_start_idx, end="\r")
self.get_datasets(
training_minutes=training_minutes,
training_start_index=curr_training_start_idx,
testing_size=1,
)
if len(self.training_df_) < training_minutes:
print(
f"{self}: current offset={curr_training_start_idx}"
f" * Training data length={len(self.training_df_)} < {training_minutes}"
" * Not enough training data. Completing the job."
)
break
new_row = pd.Series(CointegrationData(self).to_dict())
result.loc[len(result)] = new_row
curr_training_start_idx += 1
return result
def to_stop_close_conditions(self, predicted_row: pd.Series) -> bool:
config = self.config_
if (
"stop_close_conditions" not in config
or config["stop_close_conditions"] is None
):
return False
# Compute once so both the profit and the loss checks can use it
current_return = self._current_return(predicted_row)
if "profit" in config["stop_close_conditions"]:
if current_return >= config["stop_close_conditions"]["profit"]:
print(f"STOP PROFIT: {current_return}")
self.user_data_["stop_close_state"] = PairState.CLOSE_STOP_PROFIT
return True
if "loss" in config["stop_close_conditions"]:
if current_return <= config["stop_close_conditions"]["loss"]:
print(f"STOP LOSS: {current_return}")
self.user_data_["stop_close_state"] = PairState.CLOSE_STOP_LOSS
return True
return False
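# Hypothetical config stanza consumed above (keys inferred from this method; values illustrative, in percent):
#   "stop_close_conditions": {"profit": 0.5, "loss": -1.0}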
def on_open_trades(self, trades: pd.DataFrame) -> None:
self.user_data_.pop("close_trades", None)
self.user_data_["open_trades"] = trades
def on_close_trades(self, trades: pd.DataFrame) -> None:
self.user_data_.pop("open_trades", None)
self.user_data_["close_trades"] = trades
def _current_return(self, predicted_row: pd.Series) -> float:
if "open_trades" in self.user_data_:
open_trades = self.user_data_["open_trades"]
if len(open_trades) == 0:
return 0.0
def _single_instrument_return(symbol: str) -> float:
instrument_open_trades = open_trades[open_trades["symbol"] == symbol]
instrument_open_price = instrument_open_trades["price"].iloc[0]
# Trades carry an "action" column (see TRADES_COLUMNS), not "side"
sign = -1 if instrument_open_trades["action"].iloc[0] == "SELL" else 1
instrument_price = predicted_row[f"{self.stat_model_price_}_{symbol}"]
instrument_return = (
sign
* (instrument_price - instrument_open_price)
/ instrument_open_price
)
return float(instrument_return) * 100.0
instrument_a_return = _single_instrument_return(self.symbol_a_)
instrument_b_return = _single_instrument_return(self.symbol_b_)
return instrument_a_return + instrument_b_return
return 0.0
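# Worked example: long A filled at 100.0 now 101.0 (+1.0%), short B filled at 50.0 now 49.5 (+1.0%)
# -> pair return = 1.0 + 1.0 = 2.0 (percent)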
def __repr__(self) -> str:
return self.name()
@@ -292,3 +372,9 @@ class TradingPair:
def name(self) -> str:
return f"{self.symbol_a_} & {self.symbol_b_}"
# return f"{self.symbol_a_} & {self.symbol_b_}"
@abstractmethod
def predict(self) -> pd.DataFrame: ...
# @abstractmethod
# def predicted_df(self) -> Optional[pd.DataFrame]: ...


@@ -0,0 +1,122 @@
from typing import Any, Dict, Optional, cast
import pandas as pd
from pt_trading.results import BacktestResult
from pt_trading.rolling_window_fit import RollingFit
from pt_trading.trading_pair import TradingPair
from statsmodels.tsa.vector_ar.vecm import VECM, VECMResults
NanoPerMin = 1e9
class VECMTradingPair(TradingPair):
vecm_fit_: Optional[VECMResults]
pair_predict_result_: Optional[pd.DataFrame]
def __init__(
self,
config: Dict[str, Any],
market_data: pd.DataFrame,
symbol_a: str,
symbol_b: str,
):
super().__init__(config, market_data, symbol_a, symbol_b)
self.vecm_fit_ = None
self.pair_predict_result_ = None
def _train_pair(self) -> None:
self._fit_VECM()
assert self.vecm_fit_ is not None
diseq_series = self.training_df_[self.colnames()] @ self.vecm_fit_.beta
# print(diseq_series.shape)
self.training_mu_ = float(diseq_series[0].mean())
self.training_std_ = float(diseq_series[0].std())
self.training_df_["dis-equilibrium"] = (
self.training_df_[self.colnames()] @ self.vecm_fit_.beta
)
# Normalize the dis-equilibrium
self.training_df_["scaled_dis-equilibrium"] = (
diseq_series - self.training_mu_
) / self.training_std_
def _fit_VECM(self) -> None:
assert self.training_df_ is not None
vecm_df = self.training_df_[self.colnames()].reset_index(drop=True)
vecm_model = VECM(vecm_df, coint_rank=1)
vecm_fit = vecm_model.fit()
assert vecm_fit is not None
# URGENT check beta and alpha
# Check if the model converged properly
if not hasattr(vecm_fit, "beta") or vecm_fit.beta is None:
print(f"{self}: VECM model failed to converge properly")
self.vecm_fit_ = vecm_fit
def predict(self) -> pd.DataFrame:
self._train_pair()
assert self.testing_df_ is not None
assert self.vecm_fit_ is not None
predicted_prices = self.vecm_fit_.predict(steps=len(self.testing_df_))
# Align actual and predicted prices; predicted columns get the "_pred" suffix
predicted_df = pd.merge(
self.testing_df_.reset_index(drop=True),
pd.DataFrame(
predicted_prices, columns=pd.Index(self.colnames()), dtype=float
),
left_index=True,
right_index=True,
suffixes=("", "_pred"),
).dropna()
predicted_df["disequilibrium"] = (
predicted_df[self.colnames()] @ self.vecm_fit_.beta
)
predicted_df["signed_scaled_disequilibrium"] = (
predicted_df["disequilibrium"] - self.training_mu_
) / self.training_std_
predicted_df["scaled_disequilibrium"] = abs(
predicted_df["signed_scaled_disequilibrium"]
)
predicted_df = predicted_df.reset_index(drop=True)
if self.pair_predict_result_ is None:
self.pair_predict_result_ = predicted_df
else:
self.pair_predict_result_ = pd.concat(
[self.pair_predict_result_, predicted_df], ignore_index=True
)
# Reset index to ensure proper indexing
self.pair_predict_result_ = self.pair_predict_result_.reset_index(drop=True)
return self.pair_predict_result_
class VECMRollingFit(RollingFit):
def __init__(self) -> None:
super().__init__()
def create_trading_pair(
self,
config: Dict,
market_data: pd.DataFrame,
symbol_a: str,
symbol_b: str,
) -> TradingPair:
return VECMTradingPair(
config=config,
market_data=market_data,
symbol_a=symbol_a,
symbol_b=symbol_b,
)
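# Minimal usage sketch (CONFIG and df are placeholders; symbols match the notebook example below):
#   fit_method = VECMRollingFit()
#   pair = fit_method.create_trading_pair(config=CONFIG, market_data=df, symbol_a="COIN", symbol_b="GBTC")
#   predicted = pair.predict()  # fits the VECM on training_df_ and scores testing_df_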

View File

@@ -0,0 +1,85 @@
from typing import Any, Dict, Optional, cast
import pandas as pd
from pt_trading.results import BacktestResult
from pt_trading.rolling_window_fit import RollingFit
from pt_trading.trading_pair import TradingPair
import statsmodels.api as sm
NanoPerMin = 1e9
class ZScoreTradingPair(TradingPair):
zscore_model_: Optional[sm.regression.linear_model.RegressionResultsWrapper]
pair_predict_result_: Optional[pd.DataFrame]
zscore_df_: Optional[pd.DataFrame]
hedge_ratio_: float
spread_mean_: float
spread_std_: float
def __init__(
self,
config: Dict[str, Any],
market_data: pd.DataFrame,
symbol_a: str,
symbol_b: str,
):
super().__init__(config, market_data, symbol_a, symbol_b)
self.zscore_model_ = None
self.pair_predict_result_ = None
self.zscore_df_ = None
def _fit_zscore(self) -> None:
assert self.training_df_ is not None
symbol_a_px_series = self.training_df_[self.colnames()].iloc[:, 0]
symbol_b_px_series = self.training_df_[self.colnames()].iloc[:, 1]
symbol_a_px_series, symbol_b_px_series = symbol_a_px_series.align(
symbol_b_px_series, axis=0
)
X = sm.add_constant(symbol_b_px_series)
self.zscore_model_ = sm.OLS(symbol_a_px_series, X).fit()
assert self.zscore_model_ is not None
self.hedge_ratio_ = float(self.zscore_model_.params.iloc[1])
# Calculate spread and Z-score over the training window; keep the stats for out-of-sample scoring
spread = symbol_a_px_series - self.hedge_ratio_ * symbol_b_px_series
self.spread_mean_ = float(spread.mean())
self.spread_std_ = float(spread.std())
self.zscore_df_ = (spread - self.spread_mean_) / self.spread_std_
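# Worked example: hedge_ratio_=0.8, spread_mean_=2.0, spread_std_=0.5, prices a=103.5, b=125.0
# spread = 103.5 - 0.8 * 125.0 = 3.5 -> z = (3.5 - 2.0) / 0.5 = 3.0 (three stds rich)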
def predict(self) -> pd.DataFrame:
self._fit_zscore()
assert self.zscore_df_ is not None
self.training_df_["dis-equilibrium"] = self.zscore_df_
self.training_df_["scaled_dis-equilibrium"] = abs(self.zscore_df_)
assert self.testing_df_ is not None
predicted_df = self.testing_df_.reset_index(drop=True)
# Score the testing window with the hedge ratio and spread stats fitted on the training window
# (the training-indexed z-score series cannot be assigned to the testing rows directly)
col_a, col_b = self.colnames()
testing_spread = predicted_df[col_a] - self.hedge_ratio_ * predicted_df[col_b]
testing_zscore = (testing_spread - self.spread_mean_) / self.spread_std_
predicted_df["disequilibrium"] = testing_zscore
predicted_df["signed_scaled_disequilibrium"] = testing_zscore
predicted_df["scaled_disequilibrium"] = abs(testing_zscore)
if self.pair_predict_result_ is None:
self.pair_predict_result_ = predicted_df
else:
self.pair_predict_result_ = pd.concat(
[self.pair_predict_result_, predicted_df], ignore_index=True
)
# Reset index to ensure proper indexing
self.pair_predict_result_ = self.pair_predict_result_.reset_index(drop=True)
return self.pair_predict_result_.dropna()
class ZScoreRollingFit(RollingFit):
def __init__(self) -> None:
super().__init__()
def create_trading_pair(
self, config: Dict, market_data: pd.DataFrame, symbol_a: str, symbol_b: str
) -> TradingPair:
return ZScoreTradingPair(
config=config,
market_data=market_data,
symbol_a=symbol_a,
symbol_b=symbol_b,
)


@@ -1,10 +1,17 @@
from __future__ import annotations
import sqlite3
from typing import Dict, List, cast
import pandas as pd
def load_sqlite_to_dataframe(db_path: str, query: str) -> pd.DataFrame:
df: pd.DataFrame = pd.DataFrame()
import os
if not os.path.exists(db_path):
print(f"WARNING: database file {db_path} does not exist")
return df
try:
conn = sqlite3.connect(db_path)
@@ -21,13 +28,14 @@ def load_sqlite_to_dataframe(db_path, query):
conn.close()
def convert_time_to_UTC(value: str, timezone: str, extra_minutes: int = 0) -> str:
from zoneinfo import ZoneInfo
from datetime import datetime, timedelta
# Parse it to a naive datetime object
local_dt = datetime.strptime(value, "%Y-%m-%d %H:%M:%S")
local_dt = local_dt + timedelta(minutes=extra_minutes)
zinfo = ZoneInfo(timezone)
result: datetime = local_dt.replace(tzinfo=zinfo).astimezone(ZoneInfo("UTC"))
@@ -35,25 +43,28 @@ def convert_time_to_UTC(value: str, timezone: str) -> str:
return result.strftime("%Y-%m-%d %H:%M:%S")
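# Example: convert_time_to_UTC("2025-05-09 09:30:00", "America/New_York") -> "2025-05-09 13:30:00" (EDT is UTC-4);
# with extra_minutes=5 the result is "2025-05-09 13:35:00"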
def load_market_data(
datafile: str,
instruments: List[Dict[str, str]],
db_table_name: str,
trading_hours: Dict = {},
extra_minutes: int = 0,
) -> pd.DataFrame:
insts = [
'"' + instrument["instrument_id_pfx"] + instrument["symbol"] + '"'
for instrument in instruments
]
instrument_ids = list(set(insts))
exchange_ids = list(
set(['"' + instrument["exchange_id"] + '"' for instrument in instruments])
)
query = "select"
query += " tstamp"
query += ", tstamp_ns as time_ns"
query += ", substr(instrument_id, instr(instrument_id, '-') + 1) as symbol"
query += ", open"
query += ", high"
query += ", low"
@@ -62,74 +73,76 @@ def load_market_data(datafile: str, config: Dict) -> pd.DataFrame:
query += ", num_trades"
query += ", vwap"
query += f" from {db_table_name}"
query += f" where exchange_id in ({','.join(exchange_ids)})"
query += f" and instrument_id in ({','.join(instrument_ids)})"
df = load_sqlite_to_dataframe(db_path=datafile, query=query)
# Trading Hours
if len(df) > 0 and len(trading_hours) > 0:
date_str = df["tstamp"][0][0:10]
start_time = convert_time_to_UTC(
f"{date_str} {trading_hours['begin_session']}", trading_hours["timezone"]
)
end_time = convert_time_to_UTC(
f"{date_str} {trading_hours['end_session']}", trading_hours["timezone"], extra_minutes=extra_minutes  # to get execution price
)
# Perform boolean selection
df = df[(df["tstamp"] >= start_time) & (df["tstamp"] <= end_time)]
df["tstamp"] = pd.to_datetime(df["tstamp"])
return cast(pd.DataFrame, df)
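# Hypothetical call with the new signature (paths and symbols are illustrative, taken from the configs below):
#   df = load_market_data(
#       "20250509.alpaca_sim_md.db",
#       instruments=[
#           {"instrument_id_pfx": "STOCK-", "symbol": "COIN", "exchange_id": "ALPACA"},
#           {"instrument_id_pfx": "STOCK-", "symbol": "GBTC", "exchange_id": "ALPACA"},
#       ],
#       db_table_name="md_1min_bars",
#       trading_hours={"begin_session": "9:30:00", "end_session": "16:00:00", "timezone": "America/New_York"},
#       extra_minutes=1,
#   )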
# def get_available_instruments_from_db(datafile: str, config: Dict) -> List[str]:
# """
# Auto-detect available instruments from the database by querying distinct instrument_id values.
# Returns instruments without the configured prefix.
# """
# try:
# conn = sqlite3.connect(datafile)
# # Build exclusion list with full instrument_ids
# exclude_instruments = config.get("exclude_instruments", [])
# prefix = config.get("instrument_id_pfx", "")
# exclude_instrument_ids = [f"{prefix}{inst}" for inst in exclude_instruments]
# # Query to get distinct instrument_ids
# query = f"""
# SELECT DISTINCT instrument_id
# FROM {config['db_table_name']}
# WHERE exchange_id = ?
# """
# # Add exclusion clause if there are instruments to exclude
# if exclude_instrument_ids:
# placeholders = ",".join(["?" for _ in exclude_instrument_ids])
# query += f" AND instrument_id NOT IN ({placeholders})"
# cursor = conn.execute(
# query, (config["exchange_id"],) + tuple(exclude_instrument_ids)
# )
# else:
# cursor = conn.execute(query, (config["exchange_id"],))
# instrument_ids = [row[0] for row in cursor.fetchall()]
# conn.close()
# # Remove the configured prefix to get instrument symbols
# instruments = []
# for instrument_id in instrument_ids:
# if instrument_id.startswith(prefix):
# symbol = instrument_id[len(prefix) :]
# instruments.append(symbol)
# else:
# instruments.append(instrument_id)
# return sorted(instruments)
# except Exception as e:
# print(f"Error auto-detecting instruments from {datafile}: {str(e)}")
# return []
# if __name__ == "__main__":


@@ -74,6 +74,7 @@ PyYAML>=6.0
reportlab>=3.6.8
requests>=2.25.1
requests-file>=1.5.1
scipy<1.13.0
seaborn>=0.13.2
SecretStorage>=3.3.1
setproctitle>=1.2.2


@@ -0,0 +1,126 @@
import argparse
import glob
import importlib
import os
from datetime import date, datetime
from typing import Any, Dict, List, Optional
import pandas as pd
from tools.config import expand_filename, load_config
from tools.data_loader import get_available_instruments_from_db
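# NOTE: get_available_instruments_from_db is commented out in tools/data_loader in this change;
# the auto-detection path below assumes it is restored or provided elsewhere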
from pt_trading.results import (
BacktestResult,
create_result_database,
store_config_in_database,
store_results_in_database,
)
from pt_trading.fit_method import PairsTradingFitMethod
from pt_trading.trading_pair import TradingPair
from research.research_tools import create_pairs, resolve_datafiles
def main() -> None:
parser = argparse.ArgumentParser(description="Run pairs trading backtest.")
parser.add_argument(
"--config", type=str, required=True, help="Path to the configuration file."
)
parser.add_argument(
"--datafile",
type=str,
required=False,
help="Market data file to process.",
)
parser.add_argument(
"--instruments",
type=str,
required=False,
help="Comma-separated list of instrument symbols (e.g., COIN,GBTC). If not provided, auto-detects from database.",
)
args = parser.parse_args()
config: Dict = load_config(args.config)
# Resolve data files (CLI takes priority over config)
datafiles = resolve_datafiles(config, args.datafile)
if not datafiles:
print("No data files found to process.")
return
datafile = datafiles[0]
print(f"Processing data file: {datafile}")
# # Create result database if needed
# if args.result_db.upper() != "NONE":
# args.result_db = expand_filename(args.result_db)
# create_result_database(args.result_db)
# # Initialize a dictionary to store all trade results
# all_results: Dict[str, Dict[str, Any]] = {}
# # Store configuration in database for reference
# if args.result_db.upper() != "NONE":
# # Get list of all instruments for storage
# all_instruments = []
# for datafile in datafiles:
# if args.instruments:
# file_instruments = [
# inst.strip() for inst in args.instruments.split(",")
# ]
# else:
# file_instruments = get_available_instruments_from_db(datafile, config)
# all_instruments.extend(file_instruments)
# # Remove duplicates while preserving order
# unique_instruments = list(dict.fromkeys(all_instruments))
# store_config_in_database(
# db_path=args.result_db,
# config_file_path=args.config,
# config=config,
# fit_method_class=fit_method_class_name,
# datafiles=datafiles,
# instruments=unique_instruments,
# )
# Process each data file
stat_model_price = config["stat_model_price"]
print(f"\n====== Processing {os.path.basename(datafile)} ======")
# Determine instruments to use
if args.instruments:
# Use CLI-specified instruments
instruments = [inst.strip() for inst in args.instruments.split(",")]
print(f"Using CLI-specified instruments: {instruments}")
else:
# Auto-detect instruments from database
instruments = get_available_instruments_from_db(datafile, config)
print(f"Auto-detected instruments: {instruments}")
if not instruments:
print(f"No instruments found in {datafile}...")
return
# Process data for this file
try:
cointegration_data: pd.DataFrame = pd.DataFrame()
for pair in create_pairs(datafile, stat_model_price, config, instruments):
cointegration_data = pd.concat([cointegration_data, pair.cointegration_check()])
pd.set_option('display.width', 400)
pd.set_option('display.max_colwidth', None)
pd.set_option('display.max_columns', None)
with pd.option_context('display.max_rows', None, 'display.max_columns', None):
print(f"cointegration_data:\n{cointegration_data}")
except Exception as err:
print(f"Error processing {datafile}: {str(err)}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()
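# Hypothetical invocation (script and config names are illustrative):
#   python cointegration_check.py --config configs/equity.yaml --instruments COIN,GBTC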

File diff suppressed because one or more lines are too long


@@ -1,771 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Pairs Trading Visualization Notebook\n",
"\n",
"This notebook allows you to visualize pairs trading strategies on individual instrument pairs.\n",
"You can examine the relationship between two instruments, their dis-equilibrium, and trading signals."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 🎯 Key Features:\n",
"\n",
"1. **Interactive Configuration**: \n",
" - Easy switching between CRYPTO and EQUITY configurations\n",
" - Simple parameter adjustment for thresholds and training periods\n",
"\n",
"2. **Single Pair Focus**: \n",
" - Instead of running multiple pairs, focuses on one pair at a time\n",
" - Allows deep analysis of the relationship between two instruments\n",
"\n",
"3. **Step-by-Step Visualization**:\n",
" - **Raw price data**: Individual prices, normalized comparison, and price ratios\n",
" - **Training analysis**: Cointegration testing and VECM model fitting\n",
" - **Dis-equilibrium visualization**: Both raw and scaled dis-equilibrium with threshold lines\n",
" - **Strategy execution**: Trading signal generation and visualization\n",
" - **Prediction analysis**: Actual vs predicted prices with trading signals overlaid\n",
"\n",
"4. **Rich Analytics**:\n",
" - Cointegration status and VECM model details\n",
" - Statistical summaries for all stages\n",
" - Threshold crossing analysis\n",
" - Trading signal breakdown\n",
"\n",
"5. **Interactive Experimentation**:\n",
" - Easy parameter modification\n",
" - Re-run capabilities for different configurations\n",
" - Support for both StaticFitStrategy and SlidingFitStrategy\n",
"\n",
"### 🚀 How to Use:\n",
"\n",
"1. **Start Jupyter**:\n",
" ```bash\n",
" cd src/notebooks\n",
" jupyter notebook pairs_trading_visualization.ipynb\n",
" ```\n",
"\n",
"2. **Customize Your Analysis**:\n",
" - Change `SYMBOL_A` and `SYMBOL_B` to your desired trading pair\n",
" - Switch between `CRYPTO_CONFIG` and `EQT_CONFIG`\n",
" - Only **StaticFitStrategy** is supported. \n",
" - Adjust thresholds and parameters as needed\n",
"\n",
"3. **Run and Visualize**:\n",
" - Execute cells step by step to see the analysis unfold\n",
" - Rich matplotlib visualizations show relationships and signals\n",
" - Comprehensive summary at the end\n",
"\n",
"The notebook provides exactly what you requested - a way to visualize the relationship between two instruments and their scaled dis-equilibrium, with all the stages of your pairs trading strategy clearly displayed and analyzed.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup and Imports"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Setup complete!\n"
]
}
],
"source": [
"import sys\n",
"import os\n",
"sys.path.append('..')\n",
"\n",
"import pandas as pd\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"from typing import Dict, List, Optional\n",
"\n",
"# Import our modules\n",
"from pt_trading.fit_methods import StaticFit, SlidingFit\n",
"from tools.data_loader import load_market_data\n",
"from pt_trading.trading_pair import TradingPair\n",
"from pt_trading.results import BacktestResult\n",
"\n",
"# Set plotting style\n",
"plt.style.use('seaborn-v0_8')\n",
"sns.set_palette(\"husl\")\n",
"plt.rcParams['figure.figsize'] = (12, 8)\n",
"\n",
"print(\"Setup complete!\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Configuration"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using EQUITY configuration\n",
"Available instruments: ['COIN', 'GBTC', 'HOOD', 'MSTR', 'PYPL']\n"
]
}
],
"source": [
"# Configuration - Choose between CRYPTO_CONFIG or EQT_CONFIG\n",
"\n",
"CRYPTO_CONFIG = {\n",
" \"security_type\": \"CRYPTO\",\n",
" \"data_directory\": \"../../data/crypto\",\n",
" \"datafiles\": [\n",
" \"20250519.mktdata.ohlcv.db\",\n",
" ],\n",
" \"db_table_name\": \"bnbspot_ohlcv_1min\",\n",
" \"exchange_id\": \"BNBSPOT\",\n",
" \"instrument_id_pfx\": \"PAIR-\",\n",
" \"instruments\": [\n",
" \"BTC-USDT\",\n",
" \"BCH-USDT\",\n",
" \"ETH-USDT\",\n",
" \"LTC-USDT\",\n",
" \"XRP-USDT\",\n",
" \"ADA-USDT\",\n",
" \"SOL-USDT\",\n",
" \"DOT-USDT\",\n",
" ],\n",
" \"trading_hours\": {\n",
" \"begin_session\": \"00:00:00\",\n",
" \"end_session\": \"23:59:00\",\n",
" \"timezone\": \"UTC\",\n",
" },\n",
" \"price_column\": \"close\",\n",
" \"min_required_points\": 30,\n",
" \"zero_threshold\": 1e-10,\n",
" \"dis-equilibrium_open_trshld\": 2.0,\n",
" \"dis-equilibrium_close_trshld\": 0.5,\n",
" \"training_minutes\": 120,\n",
" \"funding_per_pair\": 2000.0,\n",
"}\n",
"\n",
"EQT_CONFIG = {\n",
" \"security_type\": \"EQUITY\",\n",
" \"data_directory\": \"../../data/equity\",\n",
" \"datafiles\": {\n",
" \"0508\": \"20250508.alpaca_sim_md.db\",\n",
" \"0509\": \"20250509.alpaca_sim_md.db\",\n",
" \"0510\": \"20250510.alpaca_sim_md.db\",\n",
" \"0511\": \"20250511.alpaca_sim_md.db\",\n",
" \"0512\": \"20250512.alpaca_sim_md.db\",\n",
" \"0513\": \"20250513.alpaca_sim_md.db\",\n",
" \"0514\": \"20250514.alpaca_sim_md.db\",\n",
" \"0515\": \"20250515.alpaca_sim_md.db\",\n",
" \"0516\": \"20250516.alpaca_sim_md.db\",\n",
" \"0517\": \"20250517.alpaca_sim_md.db\",\n",
" \"0518\": \"20250518.alpaca_sim_md.db\",\n",
" \"0519\": \"20250519.alpaca_sim_md.db\",\n",
" \"0520\": \"20250520.alpaca_sim_md.db\",\n",
" \"0521\": \"20250521.alpaca_sim_md.db\",\n",
" \"0522\": \"20250522.alpaca_sim_md.db\",\n",
" },\n",
" \"db_table_name\": \"md_1min_bars\",\n",
" \"exchange_id\": \"ALPACA\",\n",
" \"instrument_id_pfx\": \"STOCK-\",\n",
" \"instruments\": [\n",
" \"COIN\",\n",
" \"GBTC\",\n",
" \"HOOD\",\n",
" \"MSTR\",\n",
" \"PYPL\",\n",
" ],\n",
" \"trading_hours\": {\n",
" \"begin_session\": \"9:30:00\",\n",
" \"end_session\": \"16:00:00\",\n",
" \"timezone\": \"America/New_York\",\n",
" },\n",
" \"price_column\": \"close\",\n",
" \"min_required_points\": 30,\n",
" \"zero_threshold\": 1e-10,\n",
" \"dis-equilibrium_open_trshld\": 2.0,\n",
" \"dis-equilibrium_close_trshld\": 1.0, #0.5,\n",
" \"training_minutes\": 120,\n",
" \"funding_per_pair\": 2000.0,\n",
"}\n",
"\n",
"# Choose your configuration\n",
"CONFIG = EQT_CONFIG # Change to CRYPTO_CONFIG if you want to use crypto data\n",
"\n",
"print(f\"Using {CONFIG['security_type']} configuration\")\n",
"print(f\"Available instruments: {CONFIG['instruments']}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Select Trading Pair and Data File"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Selected pair: COIN & GBTC\n",
"Data file: 20250509.alpaca_sim_md.db\n",
"Strategy: StaticFitStrategy\n"
]
}
],
"source": [
"# Select your trading pair and strategy\n",
"SYMBOL_A = \"COIN\" # Change these to your desired symbols\n",
"SYMBOL_B = \"GBTC\"\n",
"DATA_FILE = CONFIG[\"datafiles\"][\"0509\"]\n",
"\n",
"# Choose strategy\n",
"FIT_METHOD = StaticFit()\n",
"\n",
"print(f\"Selected pair: {SYMBOL_A} & {SYMBOL_B}\")\n",
"print(f\"Data file: {DATA_FILE}\")\n",
"print(f\"Strategy: {type(FIT_METHOD).__name__}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load Market Data"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Current working directory: /home/oleg/devel/pairs_trading/src/notebooks\n",
"Loading data from: ../../data/equity/20250509.alpaca_sim_md.db\n",
"Error: Execution failed on sql 'select tstamp, tstamp_ns as time_ns, substr(instrument_id, 7) as symbol, open, high, low, close, volume, num_trades, vwap from md_1min_bars where exchange_id ='ALPACA' and instrument_id in (\"STOCK-COIN\",\"STOCK-GBTC\",\"STOCK-HOOD\",\"STOCK-MSTR\",\"STOCK-PYPL\")': no such table: md_1min_bars\n"
]
},
{
"ename": "Exception",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
"\u001b[31mOperationalError\u001b[39m Traceback (most recent call last)",
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/python3.12-venv/lib/python3.12/site-packages/pandas/io/sql.py:2664\u001b[39m, in \u001b[36mSQLiteDatabase.execute\u001b[39m\u001b[34m(self, sql, params)\u001b[39m\n\u001b[32m 2663\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m2664\u001b[39m \u001b[43mcur\u001b[49m\u001b[43m.\u001b[49m\u001b[43mexecute\u001b[49m\u001b[43m(\u001b[49m\u001b[43msql\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 2665\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m cur\n",
"\u001b[31mOperationalError\u001b[39m: no such table: md_1min_bars",
"\nThe above exception was the direct cause of the following exception:\n",
"\u001b[31mDatabaseError\u001b[39m Traceback (most recent call last)",
"\u001b[36mFile \u001b[39m\u001b[32m~/devel/pairs_trading/src/notebooks/../tools/data_loader.py:11\u001b[39m, in \u001b[36mload_sqlite_to_dataframe\u001b[39m\u001b[34m(db_path, query)\u001b[39m\n\u001b[32m 9\u001b[39m conn = sqlite3.connect(db_path)\n\u001b[32m---> \u001b[39m\u001b[32m11\u001b[39m df = \u001b[43mpd\u001b[49m\u001b[43m.\u001b[49m\u001b[43mread_sql_query\u001b[49m\u001b[43m(\u001b[49m\u001b[43mquery\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mconn\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 12\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m df\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/python3.12-venv/lib/python3.12/site-packages/pandas/io/sql.py:528\u001b[39m, in \u001b[36mread_sql_query\u001b[39m\u001b[34m(sql, con, index_col, coerce_float, params, parse_dates, chunksize, dtype, dtype_backend)\u001b[39m\n\u001b[32m 527\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m pandasSQL_builder(con) \u001b[38;5;28;01mas\u001b[39;00m pandas_sql:\n\u001b[32m--> \u001b[39m\u001b[32m528\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mpandas_sql\u001b[49m\u001b[43m.\u001b[49m\u001b[43mread_query\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 529\u001b[39m \u001b[43m \u001b[49m\u001b[43msql\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 530\u001b[39m \u001b[43m \u001b[49m\u001b[43mindex_col\u001b[49m\u001b[43m=\u001b[49m\u001b[43mindex_col\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 531\u001b[39m \u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[43m=\u001b[49m\u001b[43mparams\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 532\u001b[39m \u001b[43m \u001b[49m\u001b[43mcoerce_float\u001b[49m\u001b[43m=\u001b[49m\u001b[43mcoerce_float\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 533\u001b[39m \u001b[43m \u001b[49m\u001b[43mparse_dates\u001b[49m\u001b[43m=\u001b[49m\u001b[43mparse_dates\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 534\u001b[39m \u001b[43m \u001b[49m\u001b[43mchunksize\u001b[49m\u001b[43m=\u001b[49m\u001b[43mchunksize\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 535\u001b[39m \u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[43m=\u001b[49m\u001b[43mdtype\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 536\u001b[39m \u001b[43m \u001b[49m\u001b[43mdtype_backend\u001b[49m\u001b[43m=\u001b[49m\u001b[43mdtype_backend\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 537\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/python3.12-venv/lib/python3.12/site-packages/pandas/io/sql.py:2728\u001b[39m, in \u001b[36mSQLiteDatabase.read_query\u001b[39m\u001b[34m(self, sql, index_col, coerce_float, parse_dates, params, chunksize, dtype, dtype_backend)\u001b[39m\n\u001b[32m 2717\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mread_query\u001b[39m(\n\u001b[32m 2718\u001b[39m \u001b[38;5;28mself\u001b[39m,\n\u001b[32m 2719\u001b[39m sql,\n\u001b[32m (...)\u001b[39m\u001b[32m 2726\u001b[39m dtype_backend: DtypeBackend | Literal[\u001b[33m\"\u001b[39m\u001b[33mnumpy\u001b[39m\u001b[33m\"\u001b[39m] = \u001b[33m\"\u001b[39m\u001b[33mnumpy\u001b[39m\u001b[33m\"\u001b[39m,\n\u001b[32m 2727\u001b[39m ) -> DataFrame | Iterator[DataFrame]:\n\u001b[32m-> \u001b[39m\u001b[32m2728\u001b[39m cursor = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mexecute\u001b[49m\u001b[43m(\u001b[49m\u001b[43msql\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 2729\u001b[39m columns = [col_desc[\u001b[32m0\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m col_desc \u001b[38;5;129;01min\u001b[39;00m cursor.description]\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/.pyenv/python3.12-venv/lib/python3.12/site-packages/pandas/io/sql.py:2676\u001b[39m, in \u001b[36mSQLiteDatabase.execute\u001b[39m\u001b[34m(self, sql, params)\u001b[39m\n\u001b[32m 2675\u001b[39m ex = DatabaseError(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mExecution failed on sql \u001b[39m\u001b[33m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00msql\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m'\u001b[39m\u001b[33m: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mexc\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m-> \u001b[39m\u001b[32m2676\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m ex \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mexc\u001b[39;00m\n",
"\u001b[31mDatabaseError\u001b[39m: Execution failed on sql 'select tstamp, tstamp_ns as time_ns, substr(instrument_id, 7) as symbol, open, high, low, close, volume, num_trades, vwap from md_1min_bars where exchange_id ='ALPACA' and instrument_id in (\"STOCK-COIN\",\"STOCK-GBTC\",\"STOCK-HOOD\",\"STOCK-MSTR\",\"STOCK-PYPL\")': no such table: md_1min_bars",
"\nThe above exception was the direct cause of the following exception:\n",
"\u001b[31mException\u001b[39m Traceback (most recent call last)",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[5]\u001b[39m\u001b[32m, line 6\u001b[39m\n\u001b[32m 3\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mCurrent working directory: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mos.getcwd()\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m 4\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mLoading data from: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mdatafile_path\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m----> \u001b[39m\u001b[32m6\u001b[39m market_data_df = \u001b[43mload_market_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdatafile_path\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mconfig\u001b[49m\u001b[43m=\u001b[49m\u001b[43mCONFIG\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 8\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mLoaded \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(market_data_df)\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m rows of market data\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m 9\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mSymbols in data: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mmarket_data_df[\u001b[33m'\u001b[39m\u001b[33msymbol\u001b[39m\u001b[33m'\u001b[39m].unique()\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/devel/pairs_trading/src/notebooks/../tools/data_loader.py:69\u001b[39m, in \u001b[36mload_market_data\u001b[39m\u001b[34m(datafile, config)\u001b[39m\n\u001b[32m 66\u001b[39m query += \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33m where exchange_id =\u001b[39m\u001b[33m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mexchange_id\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m'\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 67\u001b[39m query += \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33m and instrument_id in (\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[33m'\u001b[39m\u001b[33m,\u001b[39m\u001b[33m'\u001b[39m.join(instrument_ids)\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m)\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m---> \u001b[39m\u001b[32m69\u001b[39m df = \u001b[43mload_sqlite_to_dataframe\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdb_path\u001b[49m\u001b[43m=\u001b[49m\u001b[43mdatafile\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mquery\u001b[49m\u001b[43m=\u001b[49m\u001b[43mquery\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 71\u001b[39m \u001b[38;5;66;03m# Trading Hours\u001b[39;00m\n\u001b[32m 72\u001b[39m date_str = df[\u001b[33m\"\u001b[39m\u001b[33mtstamp\u001b[39m\u001b[33m\"\u001b[39m][\u001b[32m0\u001b[39m][\u001b[32m0\u001b[39m:\u001b[32m10\u001b[39m]\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/devel/pairs_trading/src/notebooks/../tools/data_loader.py:18\u001b[39m, in \u001b[36mload_sqlite_to_dataframe\u001b[39m\u001b[34m(db_path, query)\u001b[39m\n\u001b[32m 16\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m excpt:\n\u001b[32m 17\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mError: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mexcpt\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m---> \u001b[39m\u001b[32m18\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m() \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mexcpt\u001b[39;00m\n\u001b[32m 19\u001b[39m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[32m 20\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[33m\"\u001b[39m\u001b[33mconn\u001b[39m\u001b[33m\"\u001b[39m \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mlocals\u001b[39m():\n",
"\u001b[31mException\u001b[39m: "
]
}
],
"source": [
"# Load market data\n",
"datafile_path = f\"{CONFIG['data_directory']}/{DATA_FILE}\"\n",
"print(f\"Current working directory: {os.getcwd()}\")\n",
"print(f\"Loading data from: {datafile_path}\")\n",
"\n",
"market_data_df = load_market_data(datafile_path, config=CONFIG)\n",
"\n",
"print(f\"Loaded {len(market_data_df)} rows of market data\")\n",
"print(f\"Symbols in data: {market_data_df['symbol'].unique()}\")\n",
"print(f\"Time range: {market_data_df['tstamp'].min()} to {market_data_df['tstamp'].max()}\")\n",
"\n",
"# Display first few rows\n",
"market_data_df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create Trading Pair and Analyze"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create trading pair\n",
"pair = TradingPair(\n",
" market_data=market_data_df,\n",
" symbol_a=SYMBOL_A,\n",
" symbol_b=SYMBOL_B,\n",
" price_column=CONFIG[\"price_column\"]\n",
")\n",
"\n",
"print(f\"Created trading pair: {pair}\")\n",
"print(f\"Market data shape: {pair.market_data_.shape}\")\n",
"print(f\"Column names: {pair.colnames()}\")\n",
"\n",
"# Display first few rows of pair data\n",
"pair.market_data_.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Split Data into Training and Testing"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Get training and testing datasets\n",
"training_minutes = CONFIG[\"training_minutes\"]\n",
"pair.get_datasets(training_minutes=training_minutes)\n",
"\n",
"print(f\"Training data: {len(pair.training_df_)} rows\")\n",
"print(f\"Testing data: {len(pair.testing_df_)} rows\")\n",
"print(f\"Training period: {pair.training_df_['tstamp'].iloc[0]} to {pair.training_df_['tstamp'].iloc[-1]}\")\n",
"print(f\"Testing period: {pair.testing_df_['tstamp'].iloc[0]} to {pair.testing_df_['tstamp'].iloc[-1]}\")\n",
"\n",
"# Check for any missing data\n",
"print(f\"Training data null values: {pair.training_df_.isnull().sum().sum()}\")\n",
"print(f\"Testing data null values: {pair.testing_df_.isnull().sum().sum()}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Visualize Raw Price Data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Plot raw price data\n",
"fig, axes = plt.subplots(3, 1, figsize=(15, 12))\n",
"\n",
"# Combined price plot\n",
"colname_a, colname_b = pair.colnames()\n",
"all_data = pd.concat([pair.training_df_, pair.testing_df_]).reset_index(drop=True)\n",
"\n",
"# Plot individual prices\n",
"axes[0].plot(all_data['tstamp'], all_data[colname_a], label=f'{SYMBOL_A}', alpha=0.8)\n",
"axes[0].plot(all_data['tstamp'], all_data[colname_b], label=f'{SYMBOL_B}', alpha=0.8)\n",
"axes[0].axvline(x=pair.training_df_['tstamp'].iloc[-1], color='red', linestyle='--', alpha=0.7, label='Train/Test Split')\n",
"axes[0].set_title(f'Price Comparison: {SYMBOL_A} vs {SYMBOL_B}')\n",
"axes[0].set_ylabel('Price')\n",
"axes[0].legend()\n",
"axes[0].grid(True)\n",
"\n",
"# Normalized prices for comparison\n",
"norm_a = all_data[colname_a] / all_data[colname_a].iloc[0]\n",
"norm_b = all_data[colname_b] / all_data[colname_b].iloc[0]\n",
"\n",
"axes[1].plot(all_data['tstamp'], norm_a, label=f'{SYMBOL_A} (normalized)', alpha=0.8)\n",
"axes[1].plot(all_data['tstamp'], norm_b, label=f'{SYMBOL_B} (normalized)', alpha=0.8)\n",
"axes[1].axvline(x=pair.training_df_['tstamp'].iloc[-1], color='red', linestyle='--', alpha=0.7, label='Train/Test Split')\n",
"axes[1].set_title('Normalized Price Comparison')\n",
"axes[1].set_ylabel('Normalized Price')\n",
"axes[1].legend()\n",
"axes[1].grid(True)\n",
"\n",
"# Price ratio\n",
"price_ratio = all_data[colname_a] / all_data[colname_b]\n",
"axes[2].plot(all_data['tstamp'], price_ratio, label=f'{SYMBOL_A}/{SYMBOL_B} Ratio', color='green', alpha=0.8)\n",
"axes[2].axvline(x=pair.training_df_['tstamp'].iloc[-1], color='red', linestyle='--', alpha=0.7, label='Train/Test Split')\n",
"axes[2].set_title('Price Ratio')\n",
"axes[2].set_ylabel('Ratio')\n",
"axes[2].set_xlabel('Time')\n",
"axes[2].legend()\n",
"axes[2].grid(True)\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train the Pair and Check Cointegration"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Train the pair and check cointegration\n",
"try:\n",
" is_cointegrated = pair.train_pair()\n",
" print(f\"Pair {pair} cointegration status: {is_cointegrated}\")\n",
"\n",
" if is_cointegrated:\n",
" print(f\"VECM Beta coefficients: {pair.vecm_fit_.beta.flatten()}\")\n",
" print(f\"Training dis-equilibrium mean: {pair.training_mu_:.6f}\")\n",
" print(f\"Training dis-equilibrium std: {pair.training_std_:.6f}\")\n",
"\n",
" # Display VECM summary\n",
" print(\"\\nVECM Model Summary:\")\n",
" print(pair.vecm_fit_.summary())\n",
" else:\n",
" print(\"Pair is not cointegrated. Cannot proceed with strategy.\")\n",
"\n",
"except Exception as e:\n",
" print(f\"Training failed: {str(e)}\")\n",
" is_cointegrated = False"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Visualize Training Period Dis-equilibrium"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if is_cointegrated:\n",
" # fig, axes = plt.subplots(, 1, figsize=(15, 10))\n",
"\n",
" # # Raw dis-equilibrium\n",
" # axes[0].plot(pair.training_df_['tstamp'], pair.training_df_['dis-equilibrium'],\n",
" # color='blue', alpha=0.8, label='Raw Dis-equilibrium')\n",
" # axes[0].axhline(y=pair.training_mu_, color='red', linestyle='--', alpha=0.7, label='Mean')\n",
" # axes[0].axhline(y=pair.training_mu_ + pair.training_std_, color='orange', linestyle='--', alpha=0.5, label='+1 Std')\n",
" # axes[0].axhline(y=pair.training_mu_ - pair.training_std_, color='orange', linestyle='--', alpha=0.5, label='-1 Std')\n",
" # axes[0].set_title('Training Period: Raw Dis-equilibrium')\n",
" # axes[0].set_ylabel('Dis-equilibrium')\n",
" # axes[0].legend()\n",
" # axes[0].grid(True)\n",
"\n",
" # Scaled dis-equilibrium\n",
" fig, axes = plt.subplots(1, 1, figsize=(15, 5))\n",
" axes.plot(pair.training_df_['tstamp'], pair.training_df_['scaled_dis-equilibrium'],\n",
" color='green', alpha=0.8, label='Scaled Dis-equilibrium')\n",
" axes.axhline(y=0, color='red', linestyle='--', alpha=0.7, label='Mean (0)')\n",
" axes.axhline(y=1, color='orange', linestyle='--', alpha=0.5, label='+1 Std')\n",
" axes.axhline(y=-1, color='orange', linestyle='--', alpha=0.5, label='-1 Std')\n",
" axes.axhline(y=CONFIG['dis-equilibrium_open_trshld'], color='purple',\n",
" linestyle=':', alpha=0.7, label=f\"Open Threshold ({CONFIG['dis-equilibrium_open_trshld']})\")\n",
" axes.axhline(y=CONFIG['dis-equilibrium_close_trshld'], color='brown',\n",
" linestyle=':', alpha=0.7, label=f\"Close Threshold ({CONFIG['dis-equilibrium_close_trshld']})\")\n",
" axes.set_title('Training Period: Scaled Dis-equilibrium')\n",
" axes.set_ylabel('Scaled Dis-equilibrium')\n",
" axes.set_xlabel('Time')\n",
" axes.legend()\n",
" axes.grid(True)\n",
"\n",
" plt.tight_layout()\n",
" plt.show()\n",
"\n",
" # Print statistics\n",
" print(f\"Training dis-equilibrium statistics:\")\n",
" print(f\" Mean: {pair.training_df_['dis-equilibrium'].mean():.6f}\")\n",
" print(f\" Std: {pair.training_df_['dis-equilibrium'].std():.6f}\")\n",
" print(f\" Min: {pair.training_df_['dis-equilibrium'].min():.6f}\")\n",
" print(f\" Max: {pair.training_df_['dis-equilibrium'].max():.6f}\")\n",
"\n",
" print(f\"\\nScaled dis-equilibrium statistics:\")\n",
" print(f\" Mean: {pair.training_df_['scaled_dis-equilibrium'].mean():.6f}\")\n",
" print(f\" Std: {pair.training_df_['scaled_dis-equilibrium'].std():.6f}\")\n",
" print(f\" Min: {pair.training_df_['scaled_dis-equilibrium'].min():.6f}\")\n",
" print(f\" Max: {pair.training_df_['scaled_dis-equilibrium'].max():.6f}\")\n",
"else:\n",
" print(\"The pair is not cointegrated\")"
]
},
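{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sanity check, the scaled series can be recomputed by hand: it should be the raw dis-equilibrium standardized by the training mean and standard deviation stored on the pair. A minimal sketch, assuming the column names used above:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if is_cointegrated:\n",
"    # Recompute the z-score from the raw series and the stored training\n",
"    # mean/std; it should match the 'scaled_dis-equilibrium' column.\n",
"    manual_scaled = (pair.training_df_['dis-equilibrium'] - pair.training_mu_) / pair.training_std_\n",
"    max_diff = (manual_scaled - pair.training_df_['scaled_dis-equilibrium']).abs().max()\n",
"    print(f\"Max abs difference vs stored column: {max_diff:.2e}\")"
]
},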
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Generate Predictions and Run Strategy"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if is_cointegrated:\n",
" try:\n",
" # Generate predictions\n",
" pair.predict()\n",
" print(f\"Generated predictions for {len(pair.predicted_df_)} rows\")\n",
"\n",
" # Display prediction data structure\n",
" print(f\"Prediction columns: {list(pair.predicted_df_.columns)}\")\n",
" print(f\"Prediction period: {pair.predicted_df_['tstamp'].iloc[0]} to {pair.predicted_df_['tstamp'].iloc[-1]}\")\n",
"\n",
" # Run strategy\n",
" bt_result = BacktestResult(config=CONFIG)\n",
" pair_trades = FIT_METHOD.run_pair(config=CONFIG, pair=pair, bt_result=bt_result)\n",
"\n",
" if pair_trades is not None and len(pair_trades) > 0:\n",
" print(f\"\\nGenerated {len(pair_trades)} trading signals:\")\n",
" print(pair_trades)\n",
" else:\n",
" print(\"\\nNo trading signals generated\")\n",
"\n",
" except Exception as e:\n",
" print(f\"Prediction/Strategy failed: {str(e)}\")\n",
" pair_trades = None"
]
},
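{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before plotting, a quick look at the raw signal table. This is a sketch that assumes the columns referenced elsewhere in this notebook ('time', 'action', 'symbol', 'price', 'scaled_disequilibrium'):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if is_cointegrated and pair_trades is not None and len(pair_trades) > 0:\n",
"    # Signals in time order, plus a per-side action count\n",
"    print(pair_trades.sort_values('time').to_string(index=False))\n",
"    print()\n",
"    print(pair_trades['action'].value_counts())"
]
},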
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Visualize Predictions and Dis-equilibrium"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if is_cointegrated and hasattr(pair, 'predicted_df_'):\n",
" fig, axes = plt.subplots(4, 1, figsize=(16, 16))\n",
"\n",
" # Actual vs Predicted Prices\n",
" colname_a, colname_b = pair.colnames()\n",
"\n",
" axes[0].plot(pair.predicted_df_['tstamp'], pair.predicted_df_[colname_a],\n",
" label=f'{SYMBOL_A} Actual', alpha=0.8)\n",
" axes[0].plot(pair.predicted_df_['tstamp'], pair.predicted_df_[f'{colname_a}_pred'],\n",
" label=f'{SYMBOL_A} Predicted', alpha=0.8, linestyle='--')\n",
" axes[0].set_title('Actual vs Predicted Prices - Symbol A')\n",
" axes[0].set_ylabel('Price')\n",
" axes[0].legend()\n",
" axes[0].grid(True)\n",
"\n",
" axes[1].plot(pair.predicted_df_['tstamp'], pair.predicted_df_[colname_b],\n",
" label=f'{SYMBOL_B} Actual', alpha=0.8)\n",
" axes[1].plot(pair.predicted_df_['tstamp'], pair.predicted_df_[f'{colname_b}_pred'],\n",
" label=f'{SYMBOL_B} Predicted', alpha=0.8, linestyle='--')\n",
" axes[1].set_title('Actual vs Predicted Prices - Symbol B')\n",
" axes[1].set_ylabel('Price')\n",
" axes[1].legend()\n",
" axes[1].grid(True)\n",
"\n",
" # Raw dis-equilibrium\n",
" axes[2].plot(pair.predicted_df_['tstamp'], pair.predicted_df_['disequilibrium'],\n",
" color='blue', alpha=0.8, label='Dis-equilibrium')\n",
" axes[2].axhline(y=pair.training_mu_, color='red', linestyle='--', alpha=0.7, label='Training Mean')\n",
" axes[2].set_title('Testing Period: Raw Dis-equilibrium')\n",
" axes[2].set_ylabel('Dis-equilibrium')\n",
" axes[2].legend()\n",
" axes[2].grid(True)\n",
"\n",
" # Scaled dis-equilibrium with trading signals\n",
" axes[3].plot(pair.predicted_df_['tstamp'], pair.predicted_df_['scaled_disequilibrium'],\n",
" color='green', alpha=0.8, label='Scaled Dis-equilibrium')\n",
"\n",
" # Add threshold lines\n",
" axes[3].axhline(y=CONFIG['dis-equilibrium_open_trshld'], color='purple',\n",
" linestyle=':', alpha=0.7, label=f\"Open Threshold ({CONFIG['dis-equilibrium_open_trshld']})\")\n",
" axes[3].axhline(y=CONFIG['dis-equilibrium_close_trshld'], color='brown',\n",
" linestyle=':', alpha=0.7, label=f\"Close Threshold ({CONFIG['dis-equilibrium_close_trshld']})\")\n",
"\n",
" # Add trading signals if they exist\n",
" if pair_trades is not None and len(pair_trades) > 0:\n",
" for _, trade in pair_trades.iterrows():\n",
" color = 'red' if 'BUY' in trade['action'] else 'blue'\n",
" marker = '^' if 'BUY' in trade['action'] else 'v'\n",
" axes[3].scatter(trade['time'], trade['scaled_disequilibrium'],\n",
" color=color, marker=marker, s=100, alpha=0.8,\n",
" label=f\"{trade['action']} {trade['symbol']}\" if _ < 2 else \"\")\n",
"\n",
" axes[3].set_title('Testing Period: Scaled Dis-equilibrium with Trading Signals')\n",
" axes[3].set_ylabel('Scaled Dis-equilibrium')\n",
" axes[3].set_xlabel('Time')\n",
" axes[3].legend()\n",
" axes[3].grid(True)\n",
"\n",
" plt.tight_layout()\n",
" plt.show()\n",
"\n",
" # Print prediction statistics\n",
" print(f\"\\nTesting dis-equilibrium statistics:\")\n",
" print(f\" Mean: {pair.predicted_df_['disequilibrium'].mean():.6f}\")\n",
" print(f\" Std: {pair.predicted_df_['disequilibrium'].std():.6f}\")\n",
" print(f\" Min: {pair.predicted_df_['disequilibrium'].min():.6f}\")\n",
" print(f\" Max: {pair.predicted_df_['disequilibrium'].max():.6f}\")\n",
"\n",
" print(f\"\\nTesting scaled dis-equilibrium statistics:\")\n",
" print(f\" Mean: {pair.predicted_df_['scaled_disequilibrium'].mean():.6f}\")\n",
" print(f\" Std: {pair.predicted_df_['scaled_disequilibrium'].std():.6f}\")\n",
" print(f\" Min: {pair.predicted_df_['scaled_disequilibrium'].min():.6f}\")\n",
" print(f\" Max: {pair.predicted_df_['scaled_disequilibrium'].max():.6f}\")\n",
"\n",
" # Count threshold crossings\n",
" open_crossings = (pair.predicted_df_['scaled_disequilibrium'] >= CONFIG['dis-equilibrium_open_trshld']).sum()\n",
" close_crossings = (pair.predicted_df_['scaled_disequilibrium'] <= CONFIG['dis-equilibrium_close_trshld']).sum()\n",
" print(f\"\\nThreshold crossings:\")\n",
" print(f\" Open threshold ({CONFIG['dis-equilibrium_open_trshld']}): {open_crossings} times\")\n",
" print(f\" Close threshold ({CONFIG['dis-equilibrium_close_trshld']}): {close_crossings} times\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Summary and Analysis"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(\"=\" * 60)\n",
"print(\"PAIRS TRADING ANALYSIS SUMMARY\")\n",
"print(\"=\" * 60)\n",
"\n",
"print(f\"\\nPair: {SYMBOL_A} & {SYMBOL_B}\")\n",
"print(f\"Strategy: {type(FIT_METHOD).__name__}\")\n",
"print(f\"Data file: {DATA_FILE}\")\n",
"print(f\"Training period: {training_minutes} minutes\")\n",
"\n",
"print(f\"\\nCointegration Status: {'✓ COINTEGRATED' if is_cointegrated else '✗ NOT COINTEGRATED'}\")\n",
"\n",
"if is_cointegrated:\n",
" print(f\"\\nVECM Model:\")\n",
" print(f\" Beta coefficients: {pair.vecm_fit_.beta.flatten()}\")\n",
" print(f\" Training mean: {pair.training_mu_:.6f}\")\n",
" print(f\" Training std: {pair.training_std_:.6f}\")\n",
"\n",
" if pair_trades is not None and len(pair_trades) > 0:\n",
" print(f\"\\nTrading Signals: {len(pair_trades)} generated\")\n",
" unique_times = pair_trades['time'].unique()\n",
" print(f\" Unique trade times: {len(unique_times)}\")\n",
"\n",
" # Group by time to see paired trades\n",
" for trade_time in unique_times:\n",
" trades_at_time = pair_trades[pair_trades['time'] == trade_time]\n",
" print(f\"\\n Trade at {trade_time}:\")\n",
" for _, trade in trades_at_time.iterrows():\n",
" print(f\" {trade['action']} {trade['symbol']} @ ${trade['price']:.2f} (dis-eq: {trade['scaled_disequilibrium']:.2f})\")\n",
" else:\n",
" print(f\"\\nTrading Signals: None generated\")\n",
" print(\" Possible reasons:\")\n",
" print(\" - Dis-equilibrium never exceeded open threshold\")\n",
" print(\" - Insufficient testing data\")\n",
" print(\" - Strategy-specific conditions not met\")\n",
"\n",
"else:\n",
" print(\"\\nCannot proceed with trading strategy - pair is not cointegrated\")\n",
" print(\"Consider:\")\n",
" print(\" - Trying different symbol pairs\")\n",
" print(\" - Adjusting training period length\")\n",
" print(\" - Using different data timeframe\")\n",
"\n",
"print(\"\\n\" + \"=\" * 60)"
]
},
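{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough cross-check on the signals above, the cell below sums signed cash flows per symbol. This is only a sketch: it assumes one unit per signal and a flat final position, which is not the backtester's own accounting."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if is_cointegrated and pair_trades is not None and len(pair_trades) > 0:\n",
"    flows = pair_trades.copy()\n",
"    # BUY spends cash, SELL receives it (one unit per signal assumed)\n",
"    flows['cash'] = flows.apply(\n",
"        lambda t: -t['price'] if 'BUY' in t['action'] else t['price'], axis=1\n",
"    )\n",
"    print(flows.groupby('symbol')['cash'].sum())"
]
},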
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Interactive Analysis (Optional)\n",
"\n",
"You can modify the parameters below and re-run the analysis:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Interactive parameter adjustment\n",
"print(\"Current parameters:\")\n",
"print(f\" Open threshold: {CONFIG['dis-equilibrium_open_trshld']}\")\n",
"print(f\" Close threshold: {CONFIG['dis-equilibrium_close_trshld']}\")\n",
"print(f\" Training minutes: {CONFIG['training_minutes']}\")\n",
"\n",
"# Uncomment and modify these to experiment:\n",
"# CONFIG['dis-equilibrium_open_trshld'] = 1.5\n",
"# CONFIG['dis-equilibrium_close_trshld'] = 0.3\n",
"# CONFIG['training_minutes'] = 180\n",
"\n",
"print(\"\\nTo re-run with different parameters:\")\n",
"print(\"1. Modify the parameters above\")\n",
"print(\"2. Re-run from the 'Split Data into Training and Testing' cell\")\n",
"print(\"3. Or try different symbol pairs by changing SYMBOL_A and SYMBOL_B\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "python3.12-venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.9"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

File diff suppressed because one or more lines are too long

View File

@@ -3,104 +3,100 @@ import glob
 import importlib
 import os
 from datetime import date, datetime
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple
 
 import pandas as pd
 
+from research.research_tools import create_pairs
 from tools.config import expand_filename, load_config
-from tools.data_loader import get_available_instruments_from_db, load_market_data
 from pt_trading.results import (
     BacktestResult,
     create_result_database,
     store_config_in_database,
-    store_results_in_database,
 )
 from pt_trading.fit_method import PairsTradingFitMethod
 from pt_trading.trading_pair import TradingPair
 
+DayT = str
+DataFileNameT = str
+
 
-def resolve_datafiles(config: Dict, cli_datafiles: Optional[str] = None) -> List[str]:
-    """
-    Resolve the list of data files to process.
-    CLI datafiles take priority over config datafiles.
-    Supports wildcards in config but not in CLI.
-    """
-    if cli_datafiles:
-        # CLI override - comma-separated list, no wildcards
-        datafiles = [f.strip() for f in cli_datafiles.split(",")]
-        # Make paths absolute relative to data directory
-        data_dir = config.get("data_directory", "./data")
-        resolved_files = []
-        for df in datafiles:
-            if not os.path.isabs(df):
-                df = os.path.join(data_dir, df)
-            resolved_files.append(df)
-        return resolved_files
-
-    # Use config datafiles with wildcard support
-    config_datafiles = config.get("datafiles", [])
-    data_dir = config.get("data_directory", "./data")
-    resolved_files = []
-    for pattern in config_datafiles:
+def resolve_datafiles(
+    config: Dict, date_pattern: str, instruments: List[Dict[str, str]]
+) -> List[Tuple[DayT, DataFileNameT]]:
+    resolved_files: List[Tuple[DayT, DataFileNameT]] = []
+    for inst in instruments:
+        pattern = date_pattern
+        inst_type = inst["instrument_type"]
+        data_dir = config["market_data_loading"][inst_type]["data_directory"]
         if "*" in pattern or "?" in pattern:
             # Handle wildcards
             if not os.path.isabs(pattern):
-                pattern = os.path.join(data_dir, pattern)
+                pattern = os.path.join(data_dir, f"{pattern}.mktdata.ohlcv.db")
             matched_files = glob.glob(pattern)
-            resolved_files.extend(matched_files)
+            for matched_file in matched_files:
+                import re
+                match = re.search(r"(\d{8})\.mktdata\.ohlcv\.db$", matched_file)
+                assert match is not None
+                day = match.group(1)
+                resolved_files.append((day, matched_file))
         else:
             # Handle explicit file path
             if not os.path.isabs(pattern):
-                pattern = os.path.join(data_dir, pattern)
-            resolved_files.append(pattern)
+                pattern = os.path.join(data_dir, f"{pattern}.mktdata.ohlcv.db")
+            resolved_files.append((date_pattern, pattern))
     return sorted(list(set(resolved_files)))  # Remove duplicates and sort
 
 
+def get_instruments(args: argparse.Namespace, config: Dict) -> List[Dict[str, str]]:
+    instruments = [
+        {
+            "symbol": inst.split(":")[0],
+            "instrument_type": inst.split(":")[1],
+            "exchange_id": inst.split(":")[2],
+            "instrument_id_pfx": config["market_data_loading"][inst.split(":")[1]][
+                "instrument_id_pfx"
+            ],
+            "db_table_name": config["market_data_loading"][inst.split(":")[1]][
+                "db_table_name"
+            ],
+        }
+        for inst in args.instruments.split(",")
+    ]
+    return instruments
+
+
 def run_backtest(
     config: Dict,
-    datafile: str,
-    price_column: str,
+    datafiles: List[str],
     fit_method: PairsTradingFitMethod,
-    instruments: List[str],
+    instruments: List[Dict[str, str]],
 ) -> BacktestResult:
     """
     Run backtest for all pairs using the specified instruments.
     """
     bt_result: BacktestResult = BacktestResult(config=config)
+    # if len(datafiles) < 2:
+    #     print(f"WARNING: insufficient data files: {datafiles}")
+    #     return bt_result
 
-    def _create_pairs(config: Dict, instruments: List[str]) -> List[TradingPair]:
-        nonlocal datafile
-        all_indexes = range(len(instruments))
-        unique_index_pairs = [(i, j) for i in all_indexes for j in all_indexes if i < j]
-        pairs = []
-        # Update config to use the specified instruments
-        config_copy = config.copy()
-        config_copy["instruments"] = instruments
-        market_data_df = load_market_data(datafile, config=config_copy)
-        for a_index, b_index in unique_index_pairs:
-            pair = TradingPair(
-                config=config_copy,
-                market_data=market_data_df,
-                symbol_a=instruments[a_index],
-                symbol_b=instruments[b_index],
-                price_column=price_column,
-            )
-            pairs.append(pair)
-        return pairs
+    if not all([os.path.exists(datafile) for datafile in datafiles]):
+        print(f"WARNING: data file {datafiles} does not exist")
+        return bt_result
 
     pairs_trades = []
-    for pair in _create_pairs(config, instruments):
-        single_pair_trades = fit_method.run_pair(
-            pair=pair, bt_result=bt_result
-        )
+
+    pairs = create_pairs(
+        datafiles=datafiles,
+        fit_method=fit_method,
+        config=config,
+        instruments=instruments,
+    )
+    for pair in pairs:
+        single_pair_trades = fit_method.run_pair(pair=pair, bt_result=bt_result)
         if single_pair_trades is not None and len(single_pair_trades) > 0:
             pairs_trades.append(single_pair_trades)
-    print(f"pairs_trades: {pairs_trades}")
+    print(f"pairs_trades:\n{pairs_trades}")
     # Check if result_list has any data before concatenating
     if len(pairs_trades) == 0:
         print("No trading signals found for any pairs")
@@ -109,23 +105,22 @@ def run_backtest(
     bt_result.collect_single_day_results(pairs_trades)
     return bt_result
 
 
 def main() -> None:
     parser = argparse.ArgumentParser(description="Run pairs trading backtest.")
     parser.add_argument(
         "--config", type=str, required=True, help="Path to the configuration file."
     )
     parser.add_argument(
-        "--datafiles",
+        "--date_pattern",
         type=str,
-        required=False,
-        help="Comma-separated list of data files (overrides config). No wildcards supported.",
+        required=True,
+        help="Date YYYYMMDD, allows * and ? wildcards",
     )
     parser.add_argument(
         "--instruments",
         type=str,
-        required=False,
-        help="Comma-separated list of instrument symbols (e.g., COIN,GBTC). If not provided, auto-detects from database.",
+        required=True,
+        help="Comma-separated list of instrument symbols (e.g., COIN:EQUITY,GBTC:CRYPTO)",
    )
     parser.add_argument(
         "--result_db",
@@ -139,19 +134,13 @@ def main() -> None:
     config: Dict = load_config(args.config)
 
     # Dynamically instantiate fit method class
-    fit_method_class_name = config.get("fit_method_class", None)
-    assert fit_method_class_name is not None
-    module_name, class_name = fit_method_class_name.rsplit(".", 1)
-    module = importlib.import_module(module_name)
-    fit_method = getattr(module, class_name)()
+    fit_method = PairsTradingFitMethod.create(config)
 
     # Resolve data files (CLI takes priority over config)
-    datafiles = resolve_datafiles(config, args.datafiles)
-
-    if not datafiles:
-        print("No data files found to process.")
-        return
+    instruments = get_instruments(args, config)
+    datafiles = resolve_datafiles(config, args.date_pattern, instruments)
+    days = list(set([day for day, _ in datafiles]))
 
     print(f"Found {len(datafiles)} data files to process:")
     for df in datafiles:
         print(f"  - {df}")
@@ -163,51 +152,26 @@ def main() -> None:
     # Initialize a dictionary to store all trade results
     all_results: Dict[str, Dict[str, Any]] = {}
 
-    # Store configuration in database for reference
-    if args.result_db.upper() != "NONE":
-        # Get list of all instruments for storage
-        all_instruments = []
-        for datafile in datafiles:
-            if args.instruments:
-                file_instruments = [
-                    inst.strip() for inst in args.instruments.split(",")
-                ]
-            else:
-                file_instruments = get_available_instruments_from_db(datafile, config)
-            all_instruments.extend(file_instruments)
-
-        # Remove duplicates while preserving order
-        unique_instruments = list(dict.fromkeys(all_instruments))
-
-        store_config_in_database(
-            db_path=args.result_db,
-            config_file_path=args.config,
-            config=config,
-            fit_method_class=fit_method_class_name,
-            datafiles=datafiles,
-            instruments=unique_instruments,
-        )
+    is_config_stored = False
 
     # Process each data file
-    price_column = config["price_column"]
-    for datafile in datafiles:
-        print(f"\n====== Processing {os.path.basename(datafile)} ======")
-
-        # Determine instruments to use
-        if args.instruments:
-            # Use CLI-specified instruments
-            instruments = [inst.strip() for inst in args.instruments.split(",")]
-            print(f"Using CLI-specified instruments: {instruments}")
-        else:
-            # Auto-detect instruments from database
-            instruments = get_available_instruments_from_db(datafile, config)
-            print(f"Auto-detected instruments: {instruments}")
-
-        if not instruments:
-            print(f"No instruments found for {datafile}, skipping...")
+    for day in sorted(days):
+        md_datafiles = [datafile for md_day, datafile in datafiles if md_day == day]
+        if not all([os.path.exists(datafile) for datafile in md_datafiles]):
+            print(f"WARNING: insufficient data files: {md_datafiles}")
             continue
+        print(f"\n====== Processing {day} ======")
+        if not is_config_stored:
+            store_config_in_database(
+                db_path=args.result_db,
+                config_file_path=args.config,
+                config=config,
+                fit_method_class=config["fit_method_class"],
+                datafiles=datafiles,
+                instruments=instruments,
+            )
+            is_config_stored = True
 
         # Process data for this file
         try:
@@ -215,14 +179,17 @@ def main() -> None:
             bt_results = run_backtest(
                 config=config,
-                datafile=datafile,
-                price_column=price_column,
+                datafiles=md_datafiles,
                 fit_method=fit_method,
                 instruments=instruments,
             )
+            if bt_results.trades is None or len(bt_results.trades) == 0:
+                print(f"No trades found for {day}")
+                continue
 
-            # Store results with file name as key
-            filename = os.path.basename(datafile)
+            # Store results with day name as key
+            filename = os.path.basename(day)
             all_results[filename] = {
                 "trades": bt_results.trades.copy(),
                 "outstanding_positions": bt_results.outstanding_positions.copy(),
@@ -230,12 +197,20 @@ def main() -> None:
 
             # Store results in database
             if args.result_db.upper() != "NONE":
-                store_results_in_database(args.result_db, datafile, bt_results)
+                bt_results.calculate_returns(
+                    {
+                        filename: {
+                            "trades": bt_results.trades.copy(),
+                            "outstanding_positions": bt_results.outstanding_positions.copy(),
+                        }
+                    }
+                )
+                bt_results.store_results_in_database(db_path=args.result_db, day=day)
 
             print(f"Successfully processed {filename}")
         except Exception as err:
-            print(f"Error processing {datafile}: {str(err)}")
+            print(f"Error processing {day}: {str(err)}")
             import traceback
 
             traceback.print_exc()
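
The refactored `resolve_datafiles` above returns `(day, datafile)` tuples keyed by the YYYYMMDD stamp embedded in each `*.mktdata.ohlcv.db` filename, and `main` then runs one backtest per day over that day's files. A minimal sketch of the grouping contract, with illustrative paths that are not from the repository:

    # Hypothetical (day, datafile) tuples in the shape resolve_datafiles returns
    resolved = [
        ("20250721", "./data/crypto/20250721.mktdata.ohlcv.db"),
        ("20250722", "./data/crypto/20250722.mktdata.ohlcv.db"),
        ("20250722", "./data/equity/20250722.mktdata.ohlcv.db"),
    ]
    for day in sorted({d for d, _ in resolved}):
        md_datafiles = [f for d, f in resolved if d == day]
        print(day, md_datafiles)  # one run_backtest call per day over that day's files

Note also that `get_instruments` indexes `inst.split(":")[2]`, so each `--instruments` entry appears to need three colon-separated fields (SYMBOL:TYPE:EXCHANGE), even though the argparse help string shows only two.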

View File

@@ -0,0 +1,94 @@
import glob
import os
from typing import Dict, List, Optional

import pandas as pd

from pt_trading.fit_method import PairsTradingFitMethod


def resolve_datafiles(config: Dict, cli_datafiles: Optional[str] = None) -> List[str]:
    """
    Resolve the list of data files to process.
    CLI datafiles take priority over config datafiles.
    Supports wildcards in config but not in CLI.
    """
    if cli_datafiles:
        # CLI override - comma-separated list, no wildcards
        datafiles = [f.strip() for f in cli_datafiles.split(",")]
        # Make paths absolute relative to data directory
        data_dir = config.get("data_directory", "./data")
        resolved_files = []
        for df in datafiles:
            if not os.path.isabs(df):
                df = os.path.join(data_dir, df)
            resolved_files.append(df)
        return resolved_files

    # Use config datafiles with wildcard support
    config_datafiles = config.get("datafiles", [])
    data_dir = config.get("data_directory", "./data")
    resolved_files = []
    for pattern in config_datafiles:
        if "*" in pattern or "?" in pattern:
            # Handle wildcards
            if not os.path.isabs(pattern):
                pattern = os.path.join(data_dir, pattern)
            matched_files = glob.glob(pattern)
            resolved_files.extend(matched_files)
        else:
            # Handle explicit file path
            if not os.path.isabs(pattern):
                pattern = os.path.join(data_dir, pattern)
            resolved_files.append(pattern)
    return sorted(list(set(resolved_files)))  # Remove duplicates and sort


def create_pairs(
    datafiles: List[str],
    fit_method: PairsTradingFitMethod,
    config: Dict,
    instruments: List[Dict[str, str]],
) -> List:
    from pt_trading.trading_pair import TradingPair
    from tools.data_loader import load_market_data

    all_indexes = range(len(instruments))
    unique_index_pairs = [(i, j) for i in all_indexes for j in all_indexes if i < j]
    pairs = []
    # Update config to use the specified instruments
    config_copy = config.copy()
    config_copy["instruments"] = instruments

    market_data_df = pd.DataFrame()
    extra_minutes = 0
    if "execution_price" in config_copy:
        extra_minutes = config_copy["execution_price"]["shift"]
    for datafile in datafiles:
        md_df = load_market_data(
            datafile=datafile,
            instruments=instruments,
            db_table_name=config_copy["market_data_loading"][
                instruments[0]["instrument_type"]
            ]["db_table_name"],
            trading_hours=config_copy["trading_hours"],
            extra_minutes=extra_minutes,
        )
        market_data_df = pd.concat([market_data_df, md_df])

    if len(set(market_data_df["symbol"])) != 2:  # both symbols must be present for a pair
        print(f"WARNING: insufficient data in files: {datafiles}")
        return []

    for a_index, b_index in unique_index_pairs:
        symbol_a = instruments[a_index]["symbol"]
        symbol_b = instruments[b_index]["symbol"]
        pair = fit_method.create_trading_pair(
            config=config_copy,
            market_data=market_data_df,
            symbol_a=symbol_a,
            symbol_b=symbol_b,
        )
        pairs.append(pair)
    return pairs
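
The pair construction in `create_pairs` enumerates every unordered pair of instruments by index. A small illustration with hypothetical symbols (note that the `len(set(...)) != 2` guard above means only two-symbol datasets actually proceed):

    # Hypothetical three-instrument list; shows the pairs that would be enumerated
    instruments = [{"symbol": "COIN"}, {"symbol": "GBTC"}, {"symbol": "MSTR"}]
    all_indexes = range(len(instruments))
    unique_index_pairs = [(i, j) for i in all_indexes for j in all_indexes if i < j]
    print([(instruments[i]["symbol"], instruments[j]["symbol"])
           for i, j in unique_index_pairs])
    # [('COIN', 'GBTC'), ('COIN', 'MSTR'), ('GBTC', 'MSTR')]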

View File

@@ -16,7 +16,12 @@ cd $(realpath $(dirname $0))/..
 mkdir -p ./data/crypto
 pushd ./data/crypto
 
-Cmd="rsync -ahvv cvtt@hs01.cvtt.vpn:/works/cvtt/md_archive/crypto/sim/*.gz ./"
+Files=$1
+if [ -z "$Files" ]; then
+    Files="*.gz"
+fi
+
+Cmd="rsync -ahvv cvtt@hs01.cvtt.vpn:/works/cvtt/md_archive/crypto/sim/${Files} ./"
 echo $Cmd
 eval $Cmd
 
 # -------------------------------------

View File

@@ -26,8 +26,12 @@ for srcfname in $(ls *.db.gz); do
     tgtfile=${dt}.mktdata.ohlcv.db
     echo "${srcfname} -> ${tgtfile}"
 
-    gunzip -c $srcfname > temp.db
-    rm -f ${tgtfile} && sqlite3 temp.db ".dump md_1min_bars" | sqlite3 ${tgtfile} && rm ${srcfname}
+    Cmd="gunzip -c $srcfname > temp.db && rm $srcfname"
+    echo ${Cmd}
+    eval ${Cmd}
+    Cmd="rm -f ${tgtfile} && sqlite3 temp.db '.dump md_1min_bars' | sqlite3 ${tgtfile}"
+    echo ${Cmd}
+    eval ${Cmd}
 done
 
 rm temp.db
 popd

View File

@@ -20,12 +20,9 @@ from pt_trading.fit_methods import PairsTradingFitMethod
 from pt_trading.trading_pair import TradingPair
 
 def run_strategy(
     config: Dict,
     datafile: str,
-    price_column: str,
     fit_method: PairsTradingFitMethod,
     instruments: List[str],
 ) -> BacktestResult:
@@ -44,14 +41,20 @@ def run_strategy(
         config_copy = config.copy()
         config_copy["instruments"] = instruments
 
-        market_data_df = load_market_data(datafile, config=config_copy)
+        market_data_df = load_market_data(
+            datafile=datafile,
+            exchange_id=config_copy["exchange_id"],
+            instruments=config_copy["instruments"],
+            instrument_id_pfx=config_copy["instrument_id_pfx"],
+            db_table_name=config_copy["db_table_name"],
+            trading_hours=config_copy["trading_hours"],
+        )
         for a_index, b_index in unique_index_pairs:
-            pair = TradingPair(
+            pair = fit_method.create_trading_pair(
                 market_data=market_data_df,
                 symbol_a=instruments[a_index],
                 symbol_b=instruments[b_index],
-                price_column=price_column,
             )
             pairs.append(pair)
         return pairs
@@ -156,7 +159,6 @@ def main() -> None:
     )
 
     # Process each data file
-    price_column = config["price_column"]
     for datafile in datafiles:
         print(f"\n====== Processing {os.path.basename(datafile)} ======")
@@ -182,7 +184,6 @@ def main() -> None:
             bt_results = run_strategy(
                 config=config,
                 datafile=datafile,
-                price_column=price_column,
                 fit_method=fit_method,
                 instruments=instruments,
             )
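
For reference, the keyword-only call shape that `run_strategy` now uses for `load_market_data`. This is a sketch with placeholder values, not the repository's actual configuration:

    from tools.data_loader import load_market_data

    # Placeholder values; only the keyword names mirror the call in the diff above
    market_data_df = load_market_data(
        datafile="./data/equity/20250722.mktdata.ohlcv.db",
        exchange_id="NASDAQ",
        instruments=["COIN", "GBTC"],
        instrument_id_pfx="EQT",      # hypothetical prefix
        db_table_name="md_1min_bars",
        trading_hours={"start": "09:30", "end": "16:00"},  # hypothetical format
    )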