2025-04-20 17:52:49 +00:00

175 lines
6.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from __future__ import annotations
import pandas as pd
import numpy as np
# Restore imports from 'ta' library
from ta.volatility import AverageTrueRange
from ta.momentum import RSIIndicator
from ta.trend import EMAIndicator, MACD
# import talib # Remove talib import
__all__ = [
"add_imbalance_features",
"add_ta_features",
"prune_features",
"minimal_whitelist",
]
_EPS = 1e-6
# --- New Feature Function (Task 2.1) ---
def vola_norm_return(df: pd.DataFrame, k: int) -> pd.Series:
    """Volatility-normalized k-period return.

    Computes ``return_k / rolling_std(return_k, window=k)`` on the ``close``
    column.

    Raises
    ------
    ValueError
        If *df* has no ``close`` column, or if ``k <= 1`` (a rolling std dev
        needs a window of at least 2).
    """
    if "close" not in df.columns:
        raise ValueError("'close' column required for vola_norm_return")
    if k <= 1:
        raise ValueError("Window k must be > 1 for rolling std dev")

    # k-period percentage-change returns of the close.
    ret_k = df["close"].pct_change(k)
    # Rolling dispersion of those returns; accept a window once at least
    # half of it (and no fewer than 2 points) is populated.
    sigma = ret_k.rolling(window=k, min_periods=max(2, k // 2 + 1)).std()
    # A zero std dev would blow up to +/-inf — map it to NaN before dividing.
    return ret_k / sigma.replace(0, np.nan)
# --- End New Feature Function ---
def add_imbalance_features(df: pd.DataFrame) -> pd.DataFrame:
    """Attach order-flow imbalance features: Chaikin AD line, signed volume
    imbalance, and a volume-confirmed gap flag.

    Mutates *df* in place and returns it; a no-op when any OHLCV column is
    missing. All remaining NaNs in the frame are filled with 0 at the end.
    """
    required = {"open", "high", "low", "close", "volume"}
    if not required.issubset(df.columns):
        return df

    # Close-location value: where the close sits inside the bar's range,
    # in [-1, 1]; _EPS guards against zero-range bars.
    bar_range = df["high"] - df["low"] + _EPS
    clv = ((df["close"] - df["low"]) - (df["high"] - df["close"])) / bar_range
    df["chaikin_AD_10"] = (clv * df["volume"]).rolling(10).sum()

    # Volume signed by bar direction, summed over the last 10 bars.
    directed_vol = np.where(df["close"] >= df["open"], df["volume"], -df["volume"])
    df["svi_10"] = pd.Series(directed_vol, index=df.index).rolling(10).sum()

    # Gap flags only count when volume exceeds twice its 50-bar median.
    med_vol = df["volume"].rolling(50).median()
    heavy = df["volume"] > 2 * med_vol
    gap_up = (df["low"] > df["high"].shift(1)) & heavy
    gap_dn = (df["high"] < df["low"].shift(1)) & heavy
    df["gap_imbalance"] = gap_up.astype(int) - gap_dn.astype(int)

    df.fillna(0, inplace=True)
    return df
# ------------------------------------------------------------------
# Technical analysis features
# ------------------------------------------------------------------
def add_ta_features(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of *df* enriched with return, volatility, trend and
    momentum features computed via the ``ta`` library.

    Expects ``high``/``low``/``close`` columns. Leading NaNs from indicator
    warm-up windows are backfilled, then any stragglers forward-filled.
    """
    out = df.copy()

    # Returns over 1/15/60 bars, computed on a NaN-free close series
    # (local temporary — never added as a column).
    close_filled = out["close"].bfill().ffill()
    out["return_1m"] = close_filled.pct_change()
    out["return_15m"] = close_filled.pct_change(15)
    out["return_60m"] = close_filled.pct_change(60)

    # Volatility: 14-bar ATR, plus the std of 1-bar returns over a window
    # of 60*24*14 bars (roughly 14 days of 1-minute data).
    out["ATR_14"] = AverageTrueRange(
        out["high"], out["low"], out["close"], window=14
    ).average_true_range()
    out["volatility_14d"] = (
        out["return_1m"].rolling(60 * 24 * 14, min_periods=30).std()
    )

    # Trend: fast/slow EMAs and MACD with its signal line.
    out["EMA_10"] = EMAIndicator(out["close"], 10).ema_indicator()
    out["EMA_50"] = EMAIndicator(out["close"], 50).ema_indicator()
    macd_ind = MACD(out["close"], window_slow=26, window_fast=12, window_sign=9)
    out["MACD"] = macd_ind.macd()
    out["MACD_signal"] = macd_ind.macd_signal()

    # Momentum: 14-bar RSI.
    out["RSI_14"] = RSIIndicator(out["close"], window=14).rsi()

    # Cyclical hour features are expected to be added upstream (data_pipeline).
    # NOTE: bfill propagates later values into warm-up rows — kept as-is to
    # match existing behaviour.
    out.bfill(inplace=True)
    out.ffill(inplace=True)
    return out
# ------------------------------------------------------------------
# Pruning & whitelist
# ------------------------------------------------------------------
# Default column whitelist for prune_features(). Names must match the columns
# produced by add_ta_features / add_imbalance_features / vola_norm_return and
# the cyclical time columns added upstream in the data pipeline.
minimal_whitelist = [
    # Returns
    "return_1m",
    "return_15m",
    "return_60m",
    # Volatility
    "ATR_14",
    "volatility_14d",
    # Vola-Normalized Returns (New)
    "vola_norm_return_15",
    "vola_norm_return_60",
    # Imbalance
    "chaikin_AD_10",
    "svi_10",
    # Trend
    "EMA_10",
    "EMA_50",
    # "MACD", # Removed Task 2.3
    # "MACD_signal", # Removed Task 2.3
    # Cyclical (Time)
    "hour_sin",
    "hour_cos",
    "week_sin", # Added Task 2.2
    "week_cos", # Added Task 2.2
]
def prune_features(df: pd.DataFrame, whitelist: list[str] | None = None) -> pd.DataFrame:
"""Return DataFrame containing only *whitelisted* columns."""
if whitelist is None:
whitelist = minimal_whitelist
# Find columns present in both DataFrame and whitelist
cols_to_keep = [c for c in whitelist if c in df.columns]
# Ensure the set of kept columns exactly matches the intersection
df_pruned = df[cols_to_keep].copy()
assert set(df_pruned.columns) == set(cols_to_keep), \
f"Pruning failed: Output columns {set(df_pruned.columns)} != Expected intersection {set(cols_to_keep)}"
# Optional: Assert against the full whitelist if input is expected to always contain all
# assert set(df_pruned.columns) == set(whitelist), \
# f"Pruning failed: Output columns {set(df_pruned.columns)} != Full whitelist {set(whitelist)}"
return df_pruned