pairs_trading/lib/pt_trading/vecm_rolling_fit.py

193 lines
6.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# original script moved to vecm_rolling_fit_01.py
# 09.09.25 Added GARCH model - predicting volatility
# Rule of thumb:
# alpha + beta ≈ 1 → strong volatility clustering, persistence.
# If much lower → volatility mean reverts quickly.
# If > 1 → model is unstable / non-stationary (bad).
# Two signals are produced:
#   - the VECM disequilibrium (mean-reversion signal), and
#   - the GARCH volatility forecast (risk measure).
# Combine them, e.g. only enter trades when both conditions below hold.
#   high_volatility = 1 → persistence > 0.95 or volatility > 2 (rule of thumb: unstable / risky regime).
#   high_volatility = 0 → stable regime.
#   Condition 1: the VECM disequilibrium z-score > threshold, and
#   Condition 2: the GARCH-forecasted volatility is not too high (avoid noise-driven signals).
# This creates a volatility-adjusted pairs trading strategy, more robust than plain VECM.
# The pair_predict_result_ DataFrame now includes:
#   disequilibrium, signed_scaled_disequilibrium, scaled_disequilibrium (z-scores),
#   garch_alpha, garch_beta, garch_persistence (α+β rule-of-thumb),
#   garch_vol_forecast (1-step volatility forecast),
#   high_volatility (warning flag: 1 if persistence > 0.95 or vol forecast > 2),
#   so that unstable regimes are easy to detect.
# VECM/GARCH
# vecm_rolling_fit.py:
from typing import Any, Dict, Optional, cast
import numpy as np
import pandas as pd
from typing import Any, Dict, Optional
from pt_trading.results import BacktestResult
from pt_trading.rolling_window_fit import RollingFit
from pt_trading.trading_pair import TradingPair
from statsmodels.tsa.vector_ar.vecm import VECM, VECMResults
from arch import arch_model
# NOTE(review): 1e9 is the number of nanoseconds in one *second*; one minute
# would be 6e10. The constant is not referenced in the visible part of this
# file — confirm the intended value against its users before relying on the name.
NanoPerMin = 1e9
class VECMTradingPair(TradingPair):
    """Trading pair modelled with a rank-1 VECM plus a GARCH(1,1) on the spread.

    Every call to :meth:`predict` refits the VECM on ``training_df_``, fits a
    GARCH(1,1) on the differenced error-correction term, and appends the
    predictions for ``testing_df_`` to ``pair_predict_result_``.

    The GARCH attributes always describe the *most recent* fit; they are all
    ``None`` when that fit was skipped (too few observations) or failed.
    """

    # Minimum number of differenced spread observations required to fit GARCH.
    _GARCH_MIN_OBS = 30
    # Spread differences with a std below this are multiplied by _GARCH_RESCALE
    # before fitting, to keep the optimizer away from tiny variances.
    _GARCH_SMALL_STD = 0.1
    _GARCH_RESCALE = 1000
    # Rule-of-thumb limits used to raise the high-volatility flag.
    _PERSISTENCE_LIMIT = 0.95
    _SIGMA_LIMIT = 2

    vecm_fit_: Optional[VECMResults]
    pair_predict_result_: Optional[pd.DataFrame]
    garch_fit_: Optional[Any]
    sigma_spread_forecast_: Optional[float]
    garch_alpha_: Optional[float]
    garch_beta_: Optional[float]
    garch_persistence_: Optional[float]
    high_volatility_flag_: Optional[int]

    def __init__(
        self,
        config: Dict[str, Any],
        market_data: pd.DataFrame,
        symbol_a: str,
        symbol_b: str,
    ):
        """Forward the arguments to ``TradingPair`` and clear all fit state."""
        super().__init__(config, market_data, symbol_a, symbol_b)
        self.vecm_fit_ = None
        self.pair_predict_result_ = None
        self._reset_garch_state()

    def _reset_garch_state(self) -> None:
        """Set every GARCH-derived attribute back to ``None``."""
        self.garch_fit_ = None
        self.sigma_spread_forecast_ = None
        self.garch_alpha_ = None
        self.garch_beta_ = None
        self.garch_persistence_ = None
        self.high_volatility_flag_ = None

    def _train_pair(self) -> None:
        """Fit the models and store the training disequilibrium statistics.

        Writes ``training_mu_`` / ``training_std_`` and adds the
        ``disequilibrium`` and ``scaled_disequilibrium`` columns to
        ``training_df_``.
        """
        self._fit_VECM()
        assert self.vecm_fit_ is not None
        # Project the training prices onto the cointegration vector.
        diseq_series = self.training_df_[self.colnames()] @ self.vecm_fit_.beta
        self.training_mu_ = float(diseq_series[0].mean())
        self.training_std_ = float(diseq_series[0].std())
        self.training_df_["disequilibrium"] = diseq_series
        self.training_df_["scaled_disequilibrium"] = (
            diseq_series - self.training_mu_
        ) / self.training_std_

    def _fit_VECM(self) -> None:
        """Fit the rank-1 VECM on the training window, then the spread GARCH."""
        assert self.training_df_ is not None
        # Drop the GARCH results of any previous window up front, so that a
        # skipped or failed fit below can never leave stale values behind for
        # predict() to copy into the current window's rows.
        self._reset_garch_state()

        vecm_df = self.training_df_[self.colnames()].reset_index(drop=True)
        vecm_fit = VECM(vecm_df, coint_rank=1).fit()
        self.vecm_fit_ = vecm_fit

        # Error Correction Term (spread)
        ect_series = (vecm_df @ vecm_fit.beta).iloc[:, 0]
        self._fit_garch(ect_series)

    def _fit_garch(self, ect_series: pd.Series) -> None:
        """Fit a zero-mean GARCH(1,1) on the differenced spread.

        On success all GARCH attributes are set together; when the fit is
        skipped or raises, they are all left as / reset to ``None``.
        """
        # Difference the spread for stationarity
        dz = ect_series.diff().dropna()
        if len(dz) < self._GARCH_MIN_OBS:
            print("Not enough data for GARCH fitting.")
            return

        # Rescale if variance too small
        if dz.std() < self._GARCH_SMALL_STD:
            # NOTE(review): the volatility forecast below is then expressed in
            # these rescaled units (x1000) and compared against _SIGMA_LIMIT
            # unchanged — confirm whether it should be scaled back.
            dz = dz * self._GARCH_RESCALE

        try:
            garch = arch_model(dz, vol="GARCH", p=1, q=1, mean="Zero", dist="normal")
            garch_fit = garch.fit(disp="off")
            params = garch_fit.params
            alpha = params.get("alpha[1]", np.nan)
            beta = params.get("beta[1]", np.nan)
            persistence = alpha + beta
            # One-step-ahead volatility forecast
            forecast = garch_fit.forecast(horizon=1)
            sigma_next = float(np.sqrt(forecast.variance.iloc[-1, 0]))
        except Exception as e:
            # Deliberately best-effort: a failed GARCH fit must not abort the
            # VECM prediction, it only leaves the GARCH outputs empty.
            print(f"GARCH fit failed: {e}")
            self._reset_garch_state()
            return

        # Publish only after everything succeeded, so the state is never
        # half-updated.
        self.garch_fit_ = garch_fit
        self.garch_alpha_ = alpha
        self.garch_beta_ = beta
        self.garch_persistence_ = persistence
        self.sigma_spread_forecast_ = sigma_next
        # Rule of thumb: persistence close to 1 or large volatility forecast
        self.high_volatility_flag_ = int(
            persistence > self._PERSISTENCE_LIMIT or sigma_next > self._SIGMA_LIMIT
        )

    def predict(self) -> pd.DataFrame:
        """Refit on the training window and predict over the testing window.

        Returns:
            The accumulated ``pair_predict_result_``: the rows produced by this
            call appended to those of all earlier calls on this object.
        """
        self._train_pair()
        assert self.testing_df_ is not None
        assert self.vecm_fit_ is not None

        # VECM predictions
        predicted_prices = self.vecm_fit_.predict(steps=len(self.testing_df_))
        predicted_df = pd.merge(
            self.testing_df_.reset_index(drop=True),
            pd.DataFrame(predicted_prices, columns=pd.Index(self.colnames()), dtype=float),
            left_index=True,
            right_index=True,
            suffixes=("", "_pred"),
        ).dropna()

        # Disequilibrium and z-scores, standardised with the training statistics
        predicted_df["disequilibrium"] = (
            predicted_df[self.colnames()] @ self.vecm_fit_.beta
        )
        predicted_df["signed_scaled_disequilibrium"] = (
            predicted_df["disequilibrium"] - self.training_mu_
        ) / self.training_std_
        predicted_df["scaled_disequilibrium"] = abs(
            predicted_df["signed_scaled_disequilibrium"]
        )

        # Add GARCH parameters + volatility forecast (None when no GARCH fit
        # is available for this window)
        predicted_df["garch_alpha"] = self.garch_alpha_
        predicted_df["garch_beta"] = self.garch_beta_
        predicted_df["garch_persistence"] = self.garch_persistence_
        predicted_df["garch_vol_forecast"] = self.sigma_spread_forecast_
        predicted_df["high_volatility"] = self.high_volatility_flag_

        # Save results
        if self.pair_predict_result_ is None:
            self.pair_predict_result_ = predicted_df
        else:
            self.pair_predict_result_ = pd.concat(
                [self.pair_predict_result_, predicted_df], ignore_index=True
            )
        return self.pair_predict_result_
class VECMRollingFit(RollingFit):
    """Rolling-window fitter whose trading pairs are ``VECMTradingPair`` objects."""

    def __init__(self) -> None:
        """Initialise the base ``RollingFit``; no extra state is added here."""
        super().__init__()

    def create_trading_pair(
        self,
        config: Dict,
        market_data: pd.DataFrame,
        symbol_a: str,
        symbol_b: str,
    ) -> TradingPair:
        """Build the VECM/GARCH trading pair for the two given symbols.

        All four arguments are handed to ``VECMTradingPair`` unchanged.
        """
        pair: TradingPair = VECMTradingPair(config, market_data, symbol_a, symbol_b)
        return pair