Compare commits
4 Commits: a86cdb2c8f ... c3526bb9f6

| Author | SHA1 | Date |
|---|---|---|
| | c3526bb9f6 | |
| | 4b1b542430 | |
| | 9d6641a5f2 | |
| | 30f2fd2d1f | |
gru_sac_predictor/src/pipeline_stages/data_processing.py (new file, 711 lines)

@@ -0,0 +1,711 @@
# Stage functions for loading, initial preprocessing, feature engineering, label generation, and splitting

import logging
import sys  # Added for sys.exit
from datetime import datetime, timezone  # Added for datetime
import pandas as pd
import numpy as np
from typing import Tuple, Optional, Any, List, Dict  # Added List and Dict
import matplotlib.pyplot as plt  # Added for plotting
import seaborn as sns  # Added for plotting

# --- Component Imports --- #
# Assuming DataLoader is in the parent directory's src.
# This might need adjustment based on the actual project structure.
# Using relative imports, assuming pipeline_stages is a sibling of the other src modules.
from ..data_loader import DataLoader, fill_missing_bars
from ..feature_engineer import FeatureEngineer  # Added FeatureEngineer import
from ..io_manager import IOManager  # Added IOManager import
from ..metrics import calculate_sharpe_ratio  # For potential baseline comparison

# --- Local Imports --- #
# The label generation function was moved into this module.
# Removed duplicate import: from .data_processing import generate_direction_labels

# Assuming TensorFlow is installed and available
try:
    from tensorflow.keras.utils import to_categorical
except ImportError:
    logging.warning("TensorFlow/Keras not found. Ternary label one-hot encoding will fail.")

    # Define a placeholder if Keras is not available
    def to_categorical(*args, **kwargs):
        raise NotImplementedError("Keras 'to_categorical' is unavailable.")

logger = logging.getLogger(__name__)  # Use module-level logger
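For reference, a minimal sketch of what the real Keras helper returns (the fallback above only raises); integer class ids map to one-hot rows:

    # to_categorical([0, 2, 1], num_classes=3) ->
    # array([[1., 0., 0.],
    #        [0., 0., 1.],
    #        [0., 1., 0.]], dtype=float32)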
# --- Refactored Label Generation Logic (Moved from trading_pipeline.py) --- #
def generate_direction_labels(df: pd.DataFrame, config: dict) -> Tuple[pd.DataFrame, str, pd.Series, Optional[pd.Series]]:
    """
    Calculates forward returns and generates binary, soft binary, or ternary direction labels.
    Also returns the raw forward returns and the epsilon series used for the ternary flat definition.

    Args:
        df (pd.DataFrame): DataFrame containing at least a 'close' column and a DatetimeIndex.
        config (dict): Pipeline configuration dictionary, expecting keys under 'gru' and 'data'.

    Returns:
        tuple[pd.DataFrame, str, pd.Series, Optional[pd.Series]]:
            - DataFrame with added forward return and direction label columns (NaNs dropped based on labels).
            - Name of the generated direction label column.
            - Series containing the calculated forward log returns (`fwd_log_ret`).
            - Series containing the calculated epsilon (`eps`) threshold if ternary, else None.
    """
    if 'close' not in df.columns:
        raise ValueError("'close' column missing in input DataFrame for label generation.")

    gru_cfg = config.get('gru', {})
    data_cfg = config.get('data', {})
    horizon = gru_cfg.get('prediction_horizon', 5)
    use_ternary = gru_cfg.get('use_ternary', False)  # Check if ternary flag is set

    target_ret_col = f'fwd_log_ret_{horizon}'
    eps_series: Optional[pd.Series] = None  # Initialize eps

    # --- Calculate Forward Log Return --- #
    shifted_close = df['close'].shift(-horizon)
    fwd_returns = np.log(shifted_close / df['close'])
    df[target_ret_col] = fwd_returns

    # --- Generate Direction Label (Binary/Soft or Ternary) --- #
    if use_ternary:
        k = gru_cfg.get('flat_sigma_multiplier', 0.25)
        target_dir_col = f'direction_label3_{horizon}'
        logger.info(f"Generating ternary labels ({target_dir_col}) with k={k}...")

        sigma_n = fwd_returns.rolling(window=horizon, min_periods=max(1, horizon // 2)).std()
        eps = k * sigma_n
        eps_series = eps  # Store the calculated eps series

        conditions = [fwd_returns > eps, fwd_returns < -eps]
        choices = [2, 0]  # 2=up, 0=down
        ordinal_labels = np.select(conditions, choices, default=1).astype(int)  # 1=flat

        # --- Log Distribution & Check Balance --- #
        df['_ordinal_label_temp'] = ordinal_labels
        valid_mask_for_dist = ~np.isnan(eps) & ~np.isnan(fwd_returns)
        ordinal_labels_valid = df.loc[valid_mask_for_dist, '_ordinal_label_temp']

        if not ordinal_labels_valid.empty:
            counts = np.bincount(ordinal_labels_valid, minlength=3)
            total_valid = len(ordinal_labels_valid)
            if total_valid > 0:  # Avoid division by zero
                dist_pct = counts / total_valid * 100
                log_msg = (f"Label dist (n={total_valid}): "
                           f"Down(0)={dist_pct[0]:.1f}%, Flat(1)={dist_pct[1]:.1f}%, Up(2)={dist_pct[2]:.1f}%")
                logger.info(log_msg)

                min_pct_threshold = 10.0  # As per implementation
                if any(p < min_pct_threshold for p in dist_pct):
                    error_msg = (f"Label imbalance detected! Min class percentage is {np.min(dist_pct):.1f}% "
                                 f"(Threshold: {min_pct_threshold}%). Check data or flat_sigma_multiplier (k={k}).")
                    logger.error(error_msg)
                    print(f"ERROR: {error_msg}")  # Also print for visibility
            else:
                logger.warning("Label distribution check skipped: total valid labels is zero.")
        else:
            logger.warning("Could not calculate label distribution (no valid sigma or returns).")
        # --- End Distribution Check --- #

        # --- One-hot encode --- #
        try:
            y_cat_full = np.full((len(df), 3), np.nan, dtype=np.float32)
            if not ordinal_labels_valid.empty:
                y_cat_valid = to_categorical(ordinal_labels_valid, num_classes=3)
                y_cat_full[valid_mask_for_dist] = y_cat_valid.astype(np.float32)
            else:
                logger.warning("No valid ordinal labels to one-hot encode.")

            # Assign the list of arrays (or NaNs) - using a list avoids mixed-type issues later
            df[target_dir_col] = [list(row) if not np.all(np.isnan(row)) else np.nan for row in y_cat_full]

        except NotImplementedError as nie:
            logger.error(f"Ternary label generation failed: {nie}. Keras 'to_categorical' is unavailable. "
                         f"Please install tensorflow.", exc_info=True)
            raise  # Re-raise exception to halt pipeline
        except Exception as e:
            logger.error(f"Error during one-hot encoding: {e}", exc_info=True)
            raise  # Re-raise exception to halt pipeline if encoding fails
        finally:
            if '_ordinal_label_temp' in df.columns:
                df.drop(columns=['_ordinal_label_temp'], inplace=True)
        # --- End One-hot Encoding --- #

    else:  # Binary / Soft Binary
        target_dir_col = f'direction_label_{horizon}'
        label_smoothing = data_cfg.get('label_smoothing', 0.0)
        if not (0.0 <= label_smoothing < 1.0):
            logger.warning(f"Invalid label_smoothing value ({label_smoothing}). Must be in [0.0, 1.0). Disabling smoothing.")
            label_smoothing = 0.0

        if label_smoothing > 0.0:
            high_label = 1.0 - label_smoothing / 2.0
            low_label = label_smoothing / 2.0
            logger.info(f"Applying label smoothing: {label_smoothing:.2f} -> labels [{low_label:.2f}, {high_label:.2f}] for {target_dir_col}")
            df[target_dir_col] = np.where(fwd_returns > 0, high_label, low_label).astype(np.float32)
        else:
            logger.info(f"Using hard binary labels (0.0 / 1.0) for {target_dir_col}")
            df[target_dir_col] = (fwd_returns > 0).astype(np.float32)

    # --- Drop Rows with NaN Targets --- #
    initial_rows = len(df)

    # Create mask for NaNs in the direction column. For ternary labels, rows without a
    # valid one-hot list were assigned np.nan above, so isna() covers both cases.
    nan_mask_dir = df[target_dir_col].isna()

    nan_mask_combined = df[target_ret_col].isna() | nan_mask_dir

    df_clean = df[~nan_mask_combined].copy()

    final_rows = len(df_clean)
    if final_rows < initial_rows:
        logger.info(f"Dropped {initial_rows - final_rows} rows due to NaN targets (horizon={horizon}).")

    if df_clean.empty:
        logger.error("DataFrame is empty after defining labels and dropping NaNs. Exiting.")
        # Returning an empty DataFrame; the caller should handle exit
        return pd.DataFrame(), target_dir_col, pd.Series(dtype=float), None  # Return empty series/None on failure

    # Return the cleaned df, target col name, and the *original* full fwd_returns and eps series.
    # The original series stay aligned with the pre-cleaning df index, so the caller can
    # align them with the features *after* cleaning df_clean.
    return df_clean, target_dir_col, fwd_returns, eps_series
# --- End Label Generation --- #

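A minimal usage sketch of generate_direction_labels; the config keys are the ones read in the body above, but the values and the toy price series are hypothetical:

    import pandas as pd

    toy_cfg = {
        'gru': {'prediction_horizon': 5, 'use_ternary': False},
        'data': {'label_smoothing': 0.1},
    }
    toy_df = pd.DataFrame(
        {'close': [100.0, 101.0, 100.5, 102.0, 101.5, 103.0, 102.5, 104.0]},
        index=pd.date_range('2023-01-01', periods=8, freq='1min'),
    )
    df_out, dir_col, fwd_ret, eps = generate_direction_labels(toy_df, toy_cfg)
    # fwd_log_ret_5[t] = ln(close[t+5] / close[t]); the last 5 rows have NaN
    # targets and are dropped. With label_smoothing=0.1 the binary labels
    # become 0.05 / 0.95 instead of 0.0 / 1.0; eps is None in binary mode.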
# --- Stage 1: Load and Preprocess Data (Moved from TradingPipeline.load_and_preprocess_data) --- #
def load_and_preprocess(
    data_loader: DataLoader,
    io: Optional[IOManager],
    run_id: str,
    config: Dict[str, Any]
) -> Tuple[Optional[pd.DataFrame], Optional[Dict[str, Any]]]:
    """
    Loads the full raw dataset using DataLoader and performs initial checks.

    Args:
        data_loader: Initialized DataLoader instance.
        io: IOManager instance (optional).
        run_id: Current run ID.
        config: Pipeline configuration dictionary.

    Returns:
        Tuple containing:
            - DataFrame with raw loaded data, or None on failure.
            - Dictionary summarizing the loading process, or None on failure.
    """
    logger.info("--- Stage: Loading and Preprocessing Data ---")
    data_cfg = config.get('data', {})

    # --- Extract necessary parameters from config --- #
    ticker = data_cfg.get('ticker')
    exchange = data_cfg.get('exchange')
    start_date = data_cfg.get('start_date')
    end_date = data_cfg.get('end_date')
    interval = data_cfg.get('interval', '1min')  # Default to 1min
    vol_sampling = data_cfg.get('volatility_sampling', {}).get('enabled', False)
    vol_window = data_cfg.get('volatility_sampling', {}).get('window', 30)
    vol_quantile = data_cfg.get('volatility_sampling', {}).get('quantile', 0.5)

    # Validate required parameters
    if not all([ticker, exchange, start_date, end_date]):
        logger.error("Missing required data parameters in config: ticker, exchange, start_date, end_date")
        return None, None
    # --- End Parameter Extraction --- #

    load_summary = {
        'ticker': ticker,
        'exchange': exchange,
        'start_date_req': start_date,
        'end_date_req': end_date,
        'interval_req': interval,
        'vol_sampling_enabled': vol_sampling,
        'vol_window': vol_window,
        'vol_quantile': vol_quantile,
    }

    try:
        logger.info(f"Loading data for {ticker} ({exchange}) from {start_date} to {end_date}, interval {interval}")
        # --- Pass extracted parameters to load_data --- #
        df_raw = data_loader.load_data(
            ticker=ticker,
            exchange=exchange,
            start_date=start_date,
            end_date=end_date,
            interval=interval,
            vol_sampling=vol_sampling,
            vol_window=vol_window,
            vol_quantile=vol_quantile
        )
        # --- End Pass Parameters --- #

        if df_raw is None or df_raw.empty:
            logger.error("Data loading returned empty DataFrame or failed.")
            return None, load_summary

        # --- Fill Missing Bars (Step 2.5 from prompts/missing_data.txt) --- #
        if io is None:
            logger.error("IOManager is required for fill_missing_bars reporting. Cannot proceed.")
            return None, load_summary
        try:
            df_filled = fill_missing_bars(df_raw, config, io, logger)
            if df_filled is None or df_filled.empty:
                logger.error("fill_missing_bars returned empty DataFrame or failed.")
                return None, load_summary
            df_raw = df_filled  # Replace df_raw with the filled version
            logger.info("Missing bars handled successfully.")
        except ValueError as e:
            logger.error(f"Error during missing bar handling: {e}. Halting processing.")
            return None, load_summary
        except Exception as e:
            logger.error(f"Unexpected error during missing bar handling: {e}. Halting processing.", exc_info=True)
            return None, load_summary
        # --- End Fill Missing Bars --- #

        # Calculate memory usage and log info
        mem_usage = df_raw.memory_usage(deep=True).sum() / (1024**2)
        if load_summary:
            logger.info(f"Data loading summary: {load_summary}")
        else:
            logger.warning("No load summary returned by DataLoader.")
        logger.info(f"Loaded data: {df_raw.shape[0]} rows, {df_raw.shape[1]} columns. Memory: {mem_usage:.2f} MB")
        logger.info(f"Time range: {df_raw.index.min()} to {df_raw.index.max()}")

        # --- V3 Output Contract: Stage 1 Artifacts --- #
        if io:
            if load_summary:
                save_summary = load_summary.copy()  # Don't modify the original
                save_summary['run_id'] = run_id
                save_summary['timestamp_utc'] = datetime.now(timezone.utc).isoformat()
                # TODO: Finalize summary content (add counts, NaN info etc.)
                logger.info("Saving preprocess summary...")
                io.save_json(save_summary, "preprocess_summary", use_txt=True)  # Spec wants .txt

            # Save head of preprocessed data
            if df_raw is not None and not df_raw.empty:
                logger.info("Saving head of preprocessed data (first 20 rows)...")
                io.save_df(df_raw.head(20), "head_preprocessed")
            else:
                logger.warning("Skipping saving head_preprocessed: DataFrame is empty or None.")

        else:
            logger.warning("IOManager not available, skipping saving of Stage 1 artifacts (preprocess_summary, head_preprocessed).")
        # --- End V3 Output Contract ---

        # --- V3 Output Contract: Stage 2 Artifact (Label Histogram) --- #
        # TODO: Move this plotting logic to the evaluation stage or after the split; it needs y_train.
        # if io and config.get('control', {}).get('generate_plots', True):
        #     logger.info("Generating training label distribution histogram... [SKIPPED IN CURRENT STAGE]")
        #     ... (Original plotting code removed from here)
        # --- End V3 Output Contract ---

        return df_raw, load_summary

    except Exception as e:
        logger.error(f"Error during data loading: {e}", exc_info=True)
        return None, None

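A hedged usage sketch for the stage above; DataLoader/IOManager construction is project-specific, and the values below are hypothetical (only the key names are taken from the function body):

    config = {
        'data': {
            'ticker': 'BTCUSD', 'exchange': 'binance',  # hypothetical values
            'start_date': '2023-01-01', 'end_date': '2023-06-30',
            'interval': '1min',
            'volatility_sampling': {'enabled': False, 'window': 30, 'quantile': 0.5},
        }
    }
    df_raw, load_summary = load_and_preprocess(data_loader, io, run_id='run_001', config=config)
    if df_raw is None:
        sys.exit(1)  # the stage signals failure by returning None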
# --- Stage 2: Engineer Features (Moved from TradingPipeline.engineer_features) --- #
def engineer_features_for_fold(
    df: pd.DataFrame,
    feature_engineer: FeatureEngineer,
    io: Optional[IOManager],  # Added IOManager for saving figure
    config: Dict[str, Any],  # Added config for plot settings
    target_col: Optional[str] = None  # Added target column name for sorting correlation
) -> pd.DataFrame:
    """Adds features using FeatureEngineer, handles NaNs, and saves a correlation heatmap for a fold.

    Args:
        df (pd.DataFrame): Input DataFrame for the fold (typically raw data).
        feature_engineer (FeatureEngineer): Initialized FeatureEngineer instance.
        io (Optional[IOManager]): IOManager instance for saving artifacts.
        config (Dict[str, Any]): Pipeline configuration dictionary.
        target_col (Optional[str]): Name of the target column to sort correlations by (e.g., 'fwd_log_ret_5').

    Returns:
        pd.DataFrame: DataFrame with engineered features, NaNs dropped.
            Returns an empty DataFrame if input is empty or result is empty.
    """
    logger.info("--- Stage: Engineering Features ---")
    if df is None or df.empty:
        logger.error("Input DataFrame is empty. Cannot engineer features.")
        return pd.DataFrame()  # Return empty DataFrame to indicate failure

    if feature_engineer is None:
        logger.error("FeatureEngineer not initialized. Cannot engineer features.")
        # Or raise an error? For now return empty
        return pd.DataFrame()

    # Add base features (cyclical, imbalance, TA)
    df_engineered = feature_engineer.add_base_features(df.copy())

    # --- V3 Output Contract: Feature Correlation Heatmap --- #
    # Generate the heatmap *before* dropping NaNs to capture full feature-set correlations.
    # if io and config.get('control', {}).get('generate_plots', True):  # Check if plotting is enabled
    if io:  # Assume generate_plots is implicitly true if io is provided
        try:
            logger.info("Generating feature correlation heatmap...")
            numeric_cols = df_engineered.select_dtypes(include=np.number).columns
            if len(numeric_cols) < 2:
                logger.warning("Skipping correlation heatmap: Less than 2 numeric columns found.")
            else:
                corr_matrix = df_engineered[numeric_cols].corr(method='pearson')

                # Get plot settings from config
                output_cfg = config.get('output', {})
                fig_size = output_cfg.get('figure_size', [16, 9])
                plot_style = output_cfg.get('plot_style', 'seaborn-v0_8-darkgrid')
                annot_threshold = output_cfg.get('corr_annot_threshold', 0.5)
                plot_footer = output_cfg.get('plot_footer', "© GRU-SAC v3")

                plt.style.use(plot_style)
                fig, ax = plt.subplots(figsize=fig_size)

                sort_features = False
                if target_col and target_col in corr_matrix.columns:
                    # Sort by absolute correlation with the target
                    target_corr = corr_matrix[target_col].abs().sort_values(ascending=False)
                    sorted_cols = target_corr.index.tolist()
                    corr_matrix_sorted = corr_matrix.loc[sorted_cols, sorted_cols]
                    sort_features = True
                else:
                    if target_col:
                        logger.warning(f"Target column '{target_col}' not found in correlation matrix. Heatmap will not be sorted by target correlation.")
                    corr_matrix_sorted = corr_matrix  # Use the original matrix if no target or not found

                sns.heatmap(
                    corr_matrix_sorted,
                    annot=False,  # Annotations can be messy; spec only requires > threshold
                    cmap='coolwarm',  # Diverging palette centered at 0
                    center=0,
                    linewidths=0.5,
                    cbar=True,
                    square=True,  # Ensure square cells
                    ax=ax
                )

                # Annotate cells where absolute correlation > threshold (from config)
                for i in range(corr_matrix_sorted.shape[0]):
                    for j in range(corr_matrix_sorted.shape[1]):
                        if abs(corr_matrix_sorted.iloc[i, j]) > annot_threshold and i != j:
                            ax.text(j + 0.5, i + 0.5, f'{corr_matrix_sorted.iloc[i, j]:.2f}',
                                    ha='center', va='center', color='black', fontsize=8)

                title = "Feature Correlation Heatmap (Pearson)"
                if sort_features:
                    title += f" - Sorted by |ρ| vs '{target_col}'"
                ax.set_title(title, fontsize=14)
                plt.xticks(rotation=90, fontsize=8)
                plt.yticks(rotation=0, fontsize=8)

                # Add footer (from config)
                if plot_footer:  # Only add if the footer is not empty
                    plt.figtext(0.99, 0.01, plot_footer, ha="right", va="bottom", fontsize=8, color='gray')

                # Save the figure using IOManager
                io.save_figure(fig, "feature_corr_heatmap", section='figures')  # Saved to results/<run_id>/figures/
                plt.close(fig)  # Close the figure after saving
                logger.info("Saved feature correlation heatmap.")

        except Exception as e:
            logger.error(f"Failed to generate or save feature correlation heatmap: {e}", exc_info=True)
    else:
        logger.warning("IOManager not provided or plotting disabled, skipping feature correlation heatmap.")
    # --- End V3 Output Contract --- #

    # --- REMOVE Aggressive DropNA --- #
    # Dropping all rows with any NaN here is too aggressive, especially with long-lookback features.
    # NaN handling should occur within feature calculation methods (bfill/ffill/fillna(0))
    # and, critically, during label definition (where rows without valid labels are dropped).
    # initial_rows = len(df_engineered)
    # df_engineered.dropna(inplace=True)
    # rows_dropped = initial_rows - len(df_engineered)
    # if rows_dropped > 0:
    #     logger.warning(f"Dropped {rows_dropped} rows with NaN values after feature engineering.")
    # --- End REMOVE --- #

    # Check whether the dataframe became empty *after feature calculation and internal NaN handling*
    # (though ideally internal handling should prevent this)
    if df_engineered.empty:
        logger.error("DataFrame is empty after feature engineering (check internal NaN handling in FeatureEngineer).")
        return pd.DataFrame()  # Return empty DataFrame

    logger.info(f"Feature engineering complete. Shape: {df_engineered.shape}")
    return df_engineered

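The sorting step above reorders both axes of the correlation matrix by absolute correlation with the target; a standalone sketch of that idiom on random data:

    import numpy as np
    import pandas as pd

    rng = np.random.default_rng(0)
    df = pd.DataFrame(rng.normal(size=(100, 4)), columns=['a', 'b', 'c', 'fwd_log_ret_5'])
    corr = df.corr(method='pearson')
    order = corr['fwd_log_ret_5'].abs().sort_values(ascending=False).index.tolist()
    corr_sorted = corr.loc[order, order]  # the target row/column moves to the top-left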
# --- Stage 3: Define Labels and Align (Moved from TradingPipeline.define_labels_and_align) --- #
def define_labels_and_align_fold(
    df_engineered: pd.DataFrame,
    config: dict
) -> Tuple[pd.DataFrame, str, List[str], pd.Series, Optional[pd.Series]]:
    """Defines prediction labels, aligns them with features, and separates targets for a fold.

    Also returns the raw forward returns and epsilon series used for filtering baselines.

    Args:
        df_engineered (pd.DataFrame): DataFrame with engineered features for the fold.
        config (dict): Pipeline configuration dictionary.

    Returns:
        Tuple[pd.DataFrame, str, List[str], pd.Series, Optional[pd.Series]]:
            - df_labeled_aligned: DataFrame with labels generated and features/targets aligned (NaNs dropped).
            - target_dir_col: Name of the direction label column.
            - target_cols: List containing the names of all target columns (ret + dir).
            - fwd_returns_aligned: Series of forward returns aligned with df_labeled_aligned.
            - eps_aligned: Series of the epsilon threshold aligned with df_labeled_aligned (or None).
        Returns (pd.DataFrame(), "", [], pd.Series(), None) on failure or empty input.
    """
    logger.info("--- Stage: Defining Labels and Aligning ---")
    if df_engineered is None or df_engineered.empty:
        logger.error("Engineered data (DataFrame) is empty. Cannot define labels.")
        return pd.DataFrame(), "", [], pd.Series(dtype=float), None

    # --- Call the label generation function (already in this module) --- #
    try:
        # generate_direction_labels adds label columns to the DataFrame and returns a cleaned copy.
        # It also returns the original fwd_returns and eps series (aligned with df_engineered).
        df_clean, target_dir_col, fwd_returns_orig, eps_orig = generate_direction_labels(
            df_engineered.copy(),  # Pass a copy to avoid modifying the original outside this scope
            config
        )
    except Exception as e:
        logger.error(f"Label generation failed: {e}.", exc_info=True)
        return pd.DataFrame(), "", [], pd.Series(dtype=float), None

    if df_clean.empty:
        logger.error("Label generation resulted in an empty DataFrame.")
        return pd.DataFrame(), "", [], pd.Series(dtype=float), None
    # --- End Label Generation Call --- #

    # --- Determine Target Columns --- #
    horizon = config.get('gru', {}).get('prediction_horizon', 5)
    target_ret_col = f'fwd_log_ret_{horizon}'
    # target_dir_col is returned by generate_direction_labels
    target_cols = [target_ret_col, target_dir_col]

    # Ensure the columns actually exist after generation and cleaning
    if not all(col in df_clean.columns for col in target_cols):
        # Log which columns are actually present for debugging
        present_cols = df_clean.columns.tolist()
        logger.error(f"Generated label/return columns ({target_cols}) not found in DataFrame after label generation. Present columns: {present_cols}")
        return pd.DataFrame(), "", [], pd.Series(dtype=float), None
    # --- End Determine Target Columns --- #

    # --- Align fwd_returns_orig and eps_orig with the cleaned DataFrame --- #
    fwd_returns_aligned = fwd_returns_orig.loc[df_clean.index]
    eps_aligned = eps_orig.loc[df_clean.index] if eps_orig is not None else None
    # --- End Alignment --- #

    # Note: separation of X and y happens in the splitting function now;
    # we just return the fully labeled/aligned DataFrame and the target column names.
    logger.info(f"Labels defined and aligned. Shape: {df_clean.shape}")

    # Return the aligned DataFrame and the aligned supplementary series
    return df_clean, target_dir_col, target_cols, fwd_returns_aligned, eps_aligned

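The alignment above is plain index-based selection: the supplementary series were computed on the pre-cleaning index, so .loc[df_clean.index] keeps exactly the rows that survived the NaN drop. A tiny sketch:

    import pandas as pd

    full = pd.Series([0.1, 0.2, 0.3, 0.4], index=[0, 1, 2, 3])
    kept = full.loc[pd.Index([0, 2])]  # values 0.1 and 0.3, in that order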
# --- Stage 4: Split Data (Moved from TradingPipeline.split_data) --- #
def split_data_fold(
    df_labeled_aligned: pd.DataFrame,
    fwd_returns_aligned: pd.Series,
    eps_aligned: Optional[pd.Series],
    config: dict,
    target_columns: List[str],
    target_dir_col: str,
    fold_dates: Optional[Tuple] = None,
    current_fold: Optional[int] = None  # For logging
) -> Tuple[
    # Features
    pd.DataFrame, pd.DataFrame, pd.DataFrame,
    # Original Targets
    pd.DataFrame, pd.DataFrame, pd.DataFrame,
    # Original Full DataFrames
    pd.DataFrame, pd.DataFrame, pd.DataFrame,
    # Ordinal Direction Target (Train only)
    pd.Series,
    # Forward Returns (Train/Val)
    pd.Series, Optional[pd.Series],
    # Epsilon (Train/Val)
    Optional[pd.Series], Optional[pd.Series],
    # Ordinal Direction Labels (Val)
    Optional[pd.Series]
]:
    """Splits features, targets, forward returns, and epsilon for a given fold.

    Args:
        df_labeled_aligned (pd.DataFrame): Labeled and aligned data for the entire fold period.
        fwd_returns_aligned (pd.Series): Forward returns aligned with df_labeled_aligned.
        eps_aligned (Optional[pd.Series]): Epsilon threshold aligned with df_labeled_aligned.
        config (dict): Pipeline configuration.
        target_columns (List[str]): Names of all target columns (e.g., ['fwd_log_ret_5', 'direction_label_5']).
        target_dir_col (str): Name of the specific direction target column.
        fold_dates (Optional[Tuple]): Tuple of (train_start, train_end, val_start, val_end, test_start, test_end) for WF.
        current_fold (Optional[int]): Fold number for logging.

    Returns:
        Tuple containing the split dataframes/series:
            (X_train_raw, X_val_raw, X_test_raw,                    # Features
             y_train, y_val, y_test,                                # Targets
             df_train_original, df_val_original, df_test_original,  # Original DFs
             y_dir_train_ordinal,                                   # Ordinal Direction Labels (Train)
             fwd_ret_train, fwd_ret_val,                            # Forward Returns (Train/Val)
             eps_train, eps_val,                                    # Epsilon (Train/Val, Optional)
             y_dir_val_ordinal)                                     # Ordinal Direction Labels (Val, Optional)
        Returns a tuple of Nones if splitting fails.
    """
    fold_label = f"Fold {current_fold}" if current_fold is not None else "Split"
    logger.info(f"--- {fold_label}: Stage: Splitting Data ---")

    if df_labeled_aligned is None or df_labeled_aligned.empty:
        logger.error(f"{fold_label}: Input data for splitting is empty.")
        # Return Nones to indicate failure (15 elements, matching the success return)
        return (None,) * 15

    # --- Temporarily add fwd_ret and eps to the DataFrame for easier splitting --- #
    temp_fwd_ret_col = '__temp_fwd_ret__'
    temp_eps_col = '__temp_eps__'
    df_split_input = df_labeled_aligned.copy()
    df_split_input[temp_fwd_ret_col] = fwd_returns_aligned
    if eps_aligned is not None:
        df_split_input[temp_eps_col] = eps_aligned
    # --- End Temp Add --- #

    if not isinstance(df_split_input.index, pd.DatetimeIndex):
        logger.error(f"{fold_label}: Data index must be DatetimeIndex for splitting. Aborting.")
        raise SystemExit(f"{fold_label}: Index is not DatetimeIndex in split_data.")

    if not target_columns:
        logger.error(f"{fold_label}: Target columns list is empty. Aborting.")
        raise SystemExit(f"{fold_label}: Target columns missing in split_data.")
    if not target_dir_col:
        logger.error(f"{fold_label}: Target direction column name is empty. Aborting.")
        raise SystemExit(f"{fold_label}: Target direction column missing in split_data.")

    # Ensure target columns exist before trying to drop/select them
    cols_to_drop = [col for col in target_columns if col in df_split_input.columns]
    if len(cols_to_drop) != len(target_columns):
        missing_targets = set(target_columns) - set(cols_to_drop)
        logger.error(f"{fold_label}: Expected target columns {missing_targets} not found in input DataFrame. Aborting.")
        raise SystemExit(f"{fold_label}: Missing target columns in split_data input.")

    # Exclude temporary columns from feature_cols
    feature_cols = df_split_input.columns.difference(cols_to_drop + [temp_fwd_ret_col, temp_eps_col])
    if feature_cols.empty:
        logger.error(f"{fold_label}: No feature columns remain after excluding targets and temp cols. Aborting.")
        raise SystemExit(f"{fold_label}: No feature columns found in split_data.")

    # --- Determine if ternary mode is active --- #
    use_ternary = config.get('gru', {}).get('use_ternary', False)
    # --- End Determine Ternary --- #

    # Initialize split results
    X_train_raw, X_val_raw, X_test_raw = pd.DataFrame(), pd.DataFrame(), pd.DataFrame()
    y_train, y_val, y_test = pd.DataFrame(), pd.DataFrame(), pd.DataFrame()
    df_train_original, df_val_original, df_test_original = pd.DataFrame(), pd.DataFrame(), pd.DataFrame()
    fwd_ret_train, fwd_ret_val = pd.Series(dtype=float), pd.Series(dtype=float)
    eps_train, eps_val = None, None
    y_dir_train_raw_format, y_dir_val_raw_format = pd.Series(dtype=object), pd.Series(dtype=object)  # Store raw labels before converting

    # Split based on Walk-Forward dates or ratios
    if fold_dates and len(fold_dates) == 6 and all(fold_dates):  # Check for a valid WF tuple
        train_start, train_end, val_start, val_end, test_start, test_end = fold_dates
        logger.info(f"  Splitting using Walk-Forward dates: Train=[{train_start}, {train_end}), Val=[{val_start}, {val_end}), Test=[{test_start}, {test_end})")

        # Slicing logic
        df_train_original = df_split_input.loc[train_start:train_end]
        df_val_original = df_split_input.loc[val_start:val_end]
        df_test_original = df_split_input.loc[test_start:test_end] if test_start else pd.DataFrame()

    else:  # Single split using ratios
        split_cfg = config.get('split_ratios', {})
        train_ratio = split_cfg.get('train', 0.7)
        val_ratio = split_cfg.get('validation', 0.15)
        test_ratio = round(1.0 - train_ratio - val_ratio, 2)
        logger.info(f"  Splitting using ratios: Train={train_ratio:.2f}, Val={val_ratio:.2f}, Test={test_ratio:.2f}")

        total_len = len(df_split_input)
        train_end_idx = int(total_len * train_ratio)
        val_end_idx = int(total_len * (train_ratio + val_ratio))

        df_train_original = df_split_input.iloc[:train_end_idx]
        df_val_original = df_split_input.iloc[train_end_idx:val_end_idx]
        df_test_original = df_split_input.iloc[val_end_idx:]

    # --- Extract components from the split DataFrames --- #
    if not df_train_original.empty:
        X_train_raw = df_train_original[feature_cols]
        y_train = df_train_original[target_columns]
        y_dir_train_raw_format = df_train_original[target_dir_col]
        fwd_ret_train = df_train_original[temp_fwd_ret_col]
        if temp_eps_col in df_train_original:
            eps_train = df_train_original[temp_eps_col]

    if not df_val_original.empty:
        X_val_raw = df_val_original[feature_cols]
        y_val = df_val_original[target_columns]
        y_dir_val_raw_format = df_val_original[target_dir_col]  # Needed below; without this the val ordinal labels stay empty
        fwd_ret_val = df_val_original[temp_fwd_ret_col]
        if temp_eps_col in df_val_original:
            eps_val = df_val_original[temp_eps_col]

    if not df_test_original.empty:
        X_test_raw = df_test_original[feature_cols]
        y_test = df_test_original[target_columns]
    # --- End Extraction --- #

    # --- Extract Ordinal Labels if Ternary --- #
    y_dir_train_ordinal = None
    if not y_train.empty:  # Check if training data exists
        if use_ternary:
            valid_mask = y_dir_train_raw_format.notna() & y_dir_train_raw_format.apply(lambda x: isinstance(x, list) and len(x) == 3)
            if valid_mask.any():
                ordinal_values = y_dir_train_raw_format[valid_mask].apply(np.argmax)
                y_dir_train_ordinal = pd.Series(np.nan, index=y_dir_train_raw_format.index)
                y_dir_train_ordinal[valid_mask] = ordinal_values
                logger.info(f"{fold_label}: Extracted ordinal labels (0, 1, 2) for feature selection. Count: {valid_mask.sum()}")
            else:
                logger.warning(f"{fold_label}: No valid list-based ternary labels found in y_dir_train_raw_format to convert to ordinal.")
                y_dir_train_ordinal = pd.Series(dtype=np.float64)  # Return an empty series
        else:
            y_dir_train_ordinal = y_dir_train_raw_format.astype(int)  # Ensure integer type
    else:
        y_dir_train_ordinal = pd.Series(dtype=int)  # Empty series if no train data
    # --- End Extract Ordinal Labels --- #

    # --- Extract Ordinal Validation Labels if Ternary --- #
    y_dir_val_ordinal = None
    if not y_val.empty:  # Check if validation data exists
        if use_ternary:
            # Use y_dir_val_raw_format, which holds the lists/NaNs
            valid_mask_val = y_dir_val_raw_format.notna() & y_dir_val_raw_format.apply(lambda x: isinstance(x, list) and len(x) == 3)
            if valid_mask_val.any():
                ordinal_values_val = y_dir_val_raw_format[valid_mask_val].apply(np.argmax)
                y_dir_val_ordinal = pd.Series(np.nan, index=y_dir_val_raw_format.index)
                y_dir_val_ordinal[valid_mask_val] = ordinal_values_val
                logger.info(f"{fold_label}: Extracted ordinal validation labels. Count: {valid_mask_val.sum()}")
            else:
                logger.warning(f"{fold_label}: No valid ternary labels found in y_dir_val_raw_format.")
                y_dir_val_ordinal = pd.Series(dtype=np.float64)
        else:
            # Use y_dir_val_raw_format, which holds 0.0/1.0
            y_dir_val_ordinal = y_dir_val_raw_format.astype(int)
    else:
        y_dir_val_ordinal = pd.Series(dtype=int)  # Empty series if no validation data
    # --- End Extract Ordinal Validation Labels --- #

    # Log split shapes and check for empty splits
    logger.info(f"Data split complete for {fold_label}:")
    logger.info(f"  Train: X={X_train_raw.shape}, y={y_train.shape}, fwd_ret={fwd_ret_train.shape}, eps={eps_train.shape if eps_train is not None else 'None'} ({X_train_raw.index.min()} to {X_train_raw.index.max()})" if not X_train_raw.empty else "  Train: EMPTY")
    logger.info(f"  Val:   X={X_val_raw.shape}, y={y_val.shape}, fwd_ret={fwd_ret_val.shape}, eps={eps_val.shape if eps_val is not None else 'None'} ({X_val_raw.index.min()} to {X_val_raw.index.max()})" if not X_val_raw.empty else "  Val: EMPTY")
    logger.info(f"  Test:  X=({X_test_raw.shape if X_test_raw is not None else 'None'}), y=({y_test.shape if y_test is not None else 'None'}) ({df_test_original.index.min() if df_test_original is not None and not df_test_original.empty else 'N/A'} to {df_test_original.index.max() if df_test_original is not None and not df_test_original.empty else 'N/A'})")

    # Check that the required splits are non-empty
    if X_train_raw.empty or X_val_raw.empty:
        logger.error(f"{fold_label}: Data splitting resulted in an empty train or validation set. Aborting fold.")
        raise SystemExit(f"{fold_label}: Empty train or validation split detected.")

    return (
        X_train_raw, X_val_raw, X_test_raw,                    # Features
        y_train, y_val, y_test,                                # Targets
        df_train_original, df_val_original, df_test_original,  # Original DFs
        y_dir_train_ordinal,                                   # Ordinal Direction Labels (Train)
        fwd_ret_train, fwd_ret_val,                            # Forward Returns (Train/Val)
        eps_train, eps_val,                                    # Epsilon (Train/Val, Optional)
        y_dir_val_ordinal                                      # Ordinal Direction Labels (Val, Optional)
    )
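Two of the mechanics above in isolation, with toy numbers: the ratio-based index arithmetic, and recovering an ordinal class from a one-hot ternary label via argmax:

    import numpy as np

    total_len = 1000
    train_ratio, val_ratio = 0.7, 0.15
    train_end_idx = int(total_len * train_ratio)              # 700
    val_end_idx = int(total_len * (train_ratio + val_ratio))  # 850
    # iloc[:700] -> train, iloc[700:850] -> val, iloc[850:] -> test

    one_hot = [0.0, 0.0, 1.0]          # order [Down, Flat, Up]
    ordinal = int(np.argmax(one_hot))  # 2 (Up)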
gru_sac_predictor/src/pipeline_stages/sequence_creation.py (new file, 182 lines)

@@ -0,0 +1,182 @@
# Stage functions for creating GRU input sequences
import logging
import sys
import numpy as np
import pandas as pd
from typing import Tuple, Dict, Optional, List
import json  # Added for saving the artefact

# Assuming IOManager is importable from the parent src directory
from ..io_manager import IOManager

logger = logging.getLogger(__name__)

def create_sequences_fold(
    X_data: pd.DataFrame,
    y_data: pd.DataFrame,
    target_names: List[str],  # e.g., ['mu', 'dir3'] or ['mu', 'dir']
    lookback: int,
    name: str,  # e.g., "Train", "Validation", "Test"
    config: dict,  # For gru.drop_imputed_sequences
    io: Optional[IOManager]  # For saving the artefact
) -> Tuple[Optional[np.ndarray], Optional[Dict], Optional[pd.Index], int]:
    """
    Transforms a pruned, scaled feature DataFrame into 3D sequences for GRU input
    and extracts the corresponding targets for a specific data split (Train/Val/Test).
    Handles dropping sequences containing imputed bars based on config.

    Args:
        X_data (pd.DataFrame): Pruned, scaled features for the split.
        y_data (pd.DataFrame): Targets for the split.
        target_names (List[str]): List of target column names in y_data (e.g., ['mu', 'dir3']).
        lookback (int): Sequence length.
        name (str): Name of the split (e.g., "Train", "Validation", "Test") for logging.
        config (dict): Pipeline configuration dictionary.
        io (Optional[IOManager]): IOManager instance for saving artefacts.

    Returns:
        Tuple containing:
            - X_seq (np.ndarray or None): 3D feature sequences.
            - y_seq_dict (Dict or None): Dictionary of target sequences.
            - target_indices (pd.Index or None): Timestamps corresponding to the targets.
            - dropped_count (int): Number of sequences dropped due to imputed bars.
        Returns (None, None, None, 0) if sequence creation fails or data is insufficient.
        Raises SystemExit on critical errors (e.g., misalignment).
    """
    logger.info(f"--- Creating {name} Sequences ---")
    use_ternary = config.get('gru', {}).get('use_ternary', False)
    drop_imputed = config.get('gru', {}).get('drop_imputed_sequences', False)
    imputed_col_name = 'bar_imputed'  # Assuming this is the column name

    # --- Input Validation --- #
    if X_data is None or y_data is None or X_data.empty or y_data.empty:
        logger.error(f"{name}: Missing or empty features/targets for sequence creation.")
        return None, None, None, 0

    # Check for the bar_imputed column
    if imputed_col_name not in X_data.columns:
        logger.error(f"{name}: Required column '{imputed_col_name}' not found in features. Cannot handle imputed sequences.")
        # Decide whether to proceed without it or raise an error - raising for now
        raise SystemExit(f"{name}: '{imputed_col_name}' column missing. Sequence creation halted.")

    # Strict Anti-Leakage Check
    try:
        assert X_data.index.equals(y_data.index), \
            f"{name}: Features and targets indices misaligned!"
    except AssertionError as e:
        logger.error(f"Data alignment check failed: {e}. Potential data leakage. Aborting.")
        raise SystemExit(f"{name}: {e}")

    # Check that the target columns exist
    if not all(col in y_data.columns for col in target_names):
        missing_targets = set(target_names) - set(y_data.columns)
        logger.error(f"{name}: Target columns {missing_targets} not found in y_data. Aborting.")
        raise SystemExit(f"{name}: Missing target columns for sequencing.")
    # --- End Input Validation --- #

    # Convert DataFrames to numpy for a potential speedup; keep index access.
    # Note: t_name is used below to avoid shadowing the `name` (split name) parameter.
    features_np = X_data.values
    imputed_flag_np = X_data[imputed_col_name].values.astype(bool)  # Ensure boolean type
    # Extract targets based on target_names
    targets_dict_np = {t_name: y_data[t_name].values for t_name in target_names}

    X_seq_list, y_seq_dict_list = [], {t_name: [] for t_name in target_names}
    mask_seq_list = []  # To store the imputed-flag sequences
    target_indices = []

    if len(X_data) <= lookback:
        logger.warning(f"{name}: DataFrame length ({len(X_data)}) is not greater than lookback ({lookback}). Cannot create sequences.")
        return None, None, None, 0

    for i in range(lookback, len(features_np)):
        # Feature window: [i-lookback, i)
        X_seq_list.append(features_np[i - lookback : i])
        mask_seq_list.append(imputed_flag_np[i - lookback : i])

        # Targets correspond to index i
        for t_name in target_names:
            target_val = targets_dict_np[t_name][i]
            # Special handling for a potential list/array type in ternary labels
            if use_ternary and 'dir' in t_name and isinstance(target_val, list):
                target_val = np.array(target_val, dtype=np.float32)
            y_seq_dict_list[t_name].append(target_val)

        target_indices.append(y_data.index[i])  # Get the index corresponding to the target

    if not X_seq_list:  # Check whether any sequences were created
        logger.warning(f"{name}: No sequences were generated (length <= lookback?).")
        return None, None, None, 0

    # Convert lists to numpy arrays
    X_seq = np.array(X_seq_list, dtype=np.float32)
    mask_seq = np.array(mask_seq_list, dtype=bool)
    target_indices_pd = pd.Index(target_indices)
    y_seq_dict_np = {}
    for t_name in target_names:
        try:
            # Attempt to stack; requires consistent shapes
            if use_ternary and 'dir' in t_name:
                y_seq_dict_np[t_name] = np.stack(y_seq_dict_list[t_name]).astype(np.float32)
            else:  # Assuming other targets are scalar
                y_seq_dict_np[t_name] = np.array(y_seq_dict_list[t_name], dtype=np.float32)
        except ValueError as e:
            logger.error(f"{name}: Error stacking target '{t_name}': {e}. Check target consistency (especially ternary).", exc_info=True)
            shapes = [getattr(item, 'shape', type(item)) for item in y_seq_dict_list[t_name]]
            from collections import Counter
            logger.error(f"Target shapes/types found: {Counter(shapes)}")
            raise SystemExit(f"{name}: Inconsistent target shapes for '{t_name}' during sequence creation.") from e

    orig_n = X_seq.shape[0]
    dropped_count = 0

    # Conditionally drop sequences containing imputed bars
    if drop_imputed:
        logger.info(f"{name}: Dropping sequences containing imputed bars (drop_imputed_sequences=True)...")
        valid_mask = ~mask_seq.any(axis=1)
        X_seq = X_seq[valid_mask]
        mask_seq = mask_seq[valid_mask]  # Keep the mask aligned, though it is not explicitly used later
        for t_name in target_names:
            y_seq_dict_np[t_name] = y_seq_dict_np[t_name][valid_mask]
        target_indices_pd = target_indices_pd[valid_mask]

        dropped_count = orig_n - X_seq.shape[0]
        logger.info(f"{name}: Generated {orig_n} sequences, dropped {dropped_count} containing imputed bars. Remaining: {X_seq.shape[0]}")

        # Save summary artifact
        if io:
            summary_data = {
                "split_name": name,
                "total_sequences_generated": orig_n,
                "sequences_dropped_imputed": dropped_count,
                "sequences_remaining": X_seq.shape[0],
                "drop_imputed_sequences_config": drop_imputed
            }
            try:
                filename = f"imputed_sequence_summary_{name.lower()}.json"
                io.save_json(summary_data, filename, section='results', indent=4)
                logger.info(f"Saved imputed sequence summary to results/{filename}")
            except Exception as e:
                logger.error(f"Failed to save imputed sequence summary for {name}: {e}")
        else:
            logger.warning(f"IOManager not available, cannot save imputed sequence summary for {name}.")

    else:
        logger.info(f"{name}: Generated {orig_n} sequences. Keeping sequences with imputed bars (drop_imputed_sequences=False).")

    # Final checks
    if X_seq.shape[0] == 0:
        logger.error(f"{name}: No valid sequences remaining after potential filtering. Aborting.")
        return None, None, None, dropped_count  # Still report the dropped count even with no sequences left

    # --- REMOVE: Final Dictionary Mapping (Let the GRU handler manage this) --- #
    # final_y_seq_dict = {
    #     'mu': y_seq_dict_np['ret'],   # Map 'ret' to 'mu'
    #     'dir3': y_seq_dict_np['dir3'] # Keep 'dir3' as is
    # }
    # --- END REMOVE --- #

    # Log final shapes
    logger.info(f"Sequence shapes created for {name}:")
    logger.info(f"  X={X_seq.shape}, y_keys={list(y_seq_dict_np.keys())}, indices={len(target_indices_pd)}")

    return X_seq, y_seq_dict_np, target_indices_pd, dropped_count
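A shape check for the windowing loop above: with N rows and lookback L, the loop yields N - L sequences, each pairing the feature window [i-L, i) with the target at i.

    import numpy as np

    N, L, F = 10, 4, 3  # rows, lookback, features
    features = np.arange(N * F, dtype=np.float32).reshape(N, F)
    X_seq = np.stack([features[i - L:i] for i in range(L, N)])
    assert X_seq.shape == (N - L, L, F)  # (6, 4, 3)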
@@ -6,6 +6,8 @@ Uses pre-calculated GRU predictions (mu, sigma, p_cal) and actual returns.
 import numpy as np
 import pandas as pd
 import logging
+import gymnasium as gym
+from omegaconf import DictConfig  # Added for config typing
 
 env_logger = logging.getLogger(__name__)
 
@@ -15,6 +17,8 @@ class TradingEnv:
         sigma_predictions: np.ndarray,
         p_cal_predictions: np.ndarray,
         actual_returns: np.ndarray,
+        bar_imputed_flags: np.ndarray,  # Added imputed flags
+        config: DictConfig,  # Added config
         initial_capital: float = 10000.0,
         transaction_cost: float = 0.0005,
         reward_scale: float = 100.0,
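A construction sketch for the updated signature; the arrays are random stand-ins, and the OmegaConf content is hypothetical apart from the sac.imputed_handling / sac.action_penalty keys read later in step():

    import numpy as np
    from omegaconf import OmegaConf

    n = 500
    cfg = OmegaConf.create({'sac': {'imputed_handling': 'skip', 'action_penalty': 0.1}})
    env = TradingEnv(
        mu_predictions=np.random.randn(n) * 1e-3,
        sigma_predictions=np.abs(np.random.randn(n)) * 1e-3 + 1e-4,
        p_cal_predictions=np.random.uniform(0.0, 1.0, n),
        actual_returns=np.random.randn(n) * 1e-3,
        bar_imputed_flags=np.zeros(n, dtype=bool),
        config=cfg,
    )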
@@ -27,18 +31,22 @@ class TradingEnv:
             sigma_predictions: Predicted volatility (σ̂ = exp(log σ̂)).
             p_cal_predictions: Calibrated probability of price increase (p_cal).
             actual_returns: Actual log returns (y_ret).
+            bar_imputed_flags: Boolean array indicating if a bar was imputed.
+            config: OmegaConf configuration object.
             initial_capital: Starting capital for simulation (used notionally in reward).
             transaction_cost: Fractional cost per trade.
             reward_scale: Multiplier for the reward signal.
             action_penalty_lambda: Coefficient for the action magnitude penalty (λ).
         """
-        assert len(mu_predictions) == len(sigma_predictions) == len(p_cal_predictions) == len(actual_returns), \
-            "All input arrays must have the same length"
+        assert len(mu_predictions) == len(sigma_predictions) == len(p_cal_predictions) == len(actual_returns) == len(bar_imputed_flags), \
+            "All input arrays (predictions, returns, imputed_flags) must have the same length"
 
         self.mu = mu_predictions
         self.sigma = sigma_predictions
         self.p_cal = p_cal_predictions
         self.actual_returns = actual_returns
+        self.bar_imputed = bar_imputed_flags.astype(bool)  # Store imputed flags
+        self.config = config  # Store config
 
         self.initial_capital = initial_capital
         self.transaction_cost = transaction_cost
@@ -65,20 +73,36 @@ class TradingEnv:
         self.state_dim = 5
         self.action_dim = 1
 
+        # --- Define Gym Spaces ---
+        self.action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(self.action_dim,), dtype=np.float32)
+        self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.state_dim,), dtype=np.float32)
+        # --- End Define Gym Spaces ---
 
         env_logger.info(f"TradingEnv initialized with {self.n_steps} steps.")
 
     def _get_state(self) -> np.ndarray:
         """Construct the state vector for the current step."""
         if self.current_step >= self.n_steps:
-            # Handle episode end - return a dummy state or zeros
             return np.zeros(self.state_dim, dtype=np.float32)
 
         mu_t = self.mu[self.current_step]
         sigma_t = self.sigma[self.current_step]
         p_cal_t = self.p_cal[self.current_step]
 
-        edge_t = 2 * p_cal_t - 1
-        z_score_t = np.abs(mu_t) / (sigma_t + 1e-9)
+        # Calculate edge based on p_cal shape (binary vs ternary)
+        if isinstance(p_cal_t, (np.ndarray, list)) and len(p_cal_t) == 3:
+            # Ternary: edge = max(P(up), P(down)) - P(flat)
+            # Assuming order [Down, Flat, Up] for p_cal_t
+            edge_t = max(p_cal_t[2], p_cal_t[0]) - p_cal_t[1]
+        elif isinstance(p_cal_t, (float, np.number)):
+            # Binary: edge = 2 * P(up) - 1
+            edge_t = 2 * p_cal_t - 1
+        else:
+            env_logger.error(f"Unexpected type/shape for p_cal_t at step {self.current_step}: {p_cal_t}. Using edge=0.")
+            edge_t = 0.0
+
+        _EPS = 1e-9  # Define epsilon locally
+        z_score_t = np.abs(mu_t) / (sigma_t + _EPS)
 
         # State uses position *before* the action for this step is taken
         state = np.array([
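A numeric check of the two edge definitions above, with hypothetical probabilities:

    p_cal_ternary = [0.2, 0.3, 0.5]   # [Down, Flat, Up]
    edge = max(0.5, 0.2) - 0.3        # = 0.2
    p_cal_binary = 0.65
    edge = 2 * 0.65 - 1               # = 0.3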
@ -108,11 +132,48 @@ class TradingEnv:
        Returns:
            tuple: (next_state, reward, done, info_dict)
        """
        info = {'capital': self.current_capital, 'position': self.current_position, 'is_imputed_step_skipped': False}

        if self.current_step >= self.n_steps:
            # Should not happen if 'done' is handled correctly, but as a safeguard
            env_logger.warning("Step called after environment finished.")
            return self._get_state(), 0.0, True, info

        # --- Handle Imputed Bar --- #
        imputed = self.bar_imputed[self.current_step]
        if imputed:
            mode = self.config.sac.imputed_handling
            env_logger.debug(f"SAC step {self.current_step} on imputed bar: handling={mode}")
            if mode == "skip":
                self.current_step += 1
                next_state = self._get_state()  # Get state for the *next* actual step
                # Return 0 reward, not done, but flag the skip for replay-buffer handling
                info['is_imputed_step_skipped'] = True
                return next_state, 0.0, False, info
            elif mode == "hold":
                # Action is forced to maintain the current position
                action = self.current_position
            elif mode == "penalty":
                # Calculate reward penalty based on config
                target_position_penalty = np.clip(action, -1.0, 1.0)
                reward = -self.config.sac.action_penalty * (target_position_penalty - self.current_position)**2
                # Update position based on the agent's intended action (clipped)
                self.current_position = target_position_penalty
                # No PnL or transaction cost is applied on the penalty step; if this mode
                # charged for the position change, a cost would apply, but for simplicity
                # none is added here.
                self.current_step += 1
                next_state = self._get_state()
                scaled_reward = reward * self.reward_scale  # Scale the penalty
                done = self.current_step >= self.n_steps
                info['capital'] = self.current_capital
                info['position'] = self.current_position
                return next_state, scaled_reward, done, info
            # else: default behavior (treat as a normal bar) - handled implicitly by falling through
        # --- End Handle Imputed Bar --- #
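
To make the 'penalty' branch concrete, a worked example with assumed numbers (they mirror the unit test further down this page, not any production config):

import numpy as np

action_penalty = 0.1        # config.sac.action_penalty (assumed)
current_position = 0.6      # position entering the imputed bar
intended_action = -0.2      # agent's raw action

target = np.clip(intended_action, -1.0, 1.0)
raw_reward = -action_penalty * (target - current_position) ** 2  # -0.1 * (-0.8)**2 = -0.064
scaled_reward = raw_reward * 100.0                               # reward_scale of 100 (assumed) -> -6.4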
        # --- Normal Step Logic (if not imputed, or the handling mode allows fallthrough like 'hold') --- #
        # Action is the TARGET position for the *end* of this step
        target_position = np.clip(action, -1.0, 1.0)
        trade_size = target_position - self.current_position
@ -150,7 +211,9 @@ class TradingEnv:
        done = self.current_step >= self.n_steps or self.current_capital <= 0

        next_state = self._get_state()
        # Update info dict (capital/position may have changed in the normal step)
        info['capital'] = self.current_capital
        info['position'] = self.current_position

        # Log step details periodically
        # if self.current_step % 1000 == 0:
File diff suppressed because it is too large
186
gru_sac_predictor/tests/test_sequence_creation.py
Normal file
@ -0,0 +1,186 @@
import pytest
import pandas as pd
import numpy as np
from omegaconf import OmegaConf
from unittest.mock import MagicMock
import os
import tempfile
import json

# Adjust the import path based on your project structure
from gru_sac_predictor.src.pipeline_stages.sequence_creation import create_sequences_fold
from gru_sac_predictor.src.io_manager import IOManager  # Adjust path if needed

# --- Test Fixtures ---

@pytest.fixture
def sample_data_with_imputed():
    """Creates sample X and y dataframes with a 'bar_imputed' column."""
    dates = pd.to_datetime(pd.date_range('2023-01-01', periods=20, freq='T'))
    lookback = 5
    n_features_orig = 3
    n_samples = len(dates)

    # Features (including bar_imputed)
    X_data = pd.DataFrame(
        np.random.randn(n_samples, n_features_orig),
        index=dates,
        columns=[f'feat_{i}' for i in range(n_features_orig)]
    )
    # Add bar_imputed column - mark some bars as imputed
    imputed_flags = np.zeros(n_samples, dtype=bool)
    imputed_flags[2] = True   # Imputed within the first potential sequence
    imputed_flags[8] = True   # Imputed within a later potential sequence
    imputed_flags[15] = True  # Imputed near the end
    X_data['bar_imputed'] = imputed_flags

    # Targets (mu and dir3)
    y_data = pd.DataFrame({
        'mu': np.random.randn(n_samples),
        'dir3': [list(row) for row in np.eye(3)[np.random.randint(0, 3, n_samples)]]  # Example one-hot
    }, index=dates)

    return X_data, y_data
@pytest.fixture
def base_config():
    """Creates a base OmegaConf config for testing sequence creation."""
    conf = OmegaConf.create({
        'gru': {
            'lookback': 5,
            'use_ternary': True,  # Matches sample_data_with_imputed
            'drop_imputed_sequences': True  # Default to True for testing dropping
        },
        # Add other necessary sections if needed
    })
    return conf

@pytest.fixture
def mock_io_manager():
    """Creates a mock IOManager for testing artefact saving."""
    with tempfile.TemporaryDirectory() as tmpdir:
        mock_io = MagicMock(spec=IOManager)
        mock_io.results_dir = tmpdir
        saved_jsons = {}

        def mock_save_json(data, filename, **kwargs):
            filepath = os.path.join(tmpdir, filename)
            saved_jsons[filename] = data
            with open(filepath, 'w') as f:
                json.dump(data, f, **kwargs)

        mock_io.save_json.side_effect = mock_save_json
        mock_io.get_artifact_path.side_effect = lambda filename: os.path.join(tmpdir, filename)
        mock_io._saved_jsons = saved_jsons
        yield mock_io

# --- Test Functions ---
def test_sequence_creation_shapes(sample_data_with_imputed, base_config, mock_io_manager):
    X_data, y_data = sample_data_with_imputed
    lookback = base_config.gru.lookback
    n_features = X_data.shape[1]
    n_samples = len(X_data)
    expected_n_seq = n_samples - lookback

    # Test without dropping imputed
    cfg_no_drop = base_config.copy()
    cfg_no_drop.gru.drop_imputed_sequences = False

    X_seq, y_seq_dict, indices, dropped_count = create_sequences_fold(
        X_data=X_data, y_data=y_data, target_names=['mu', 'dir3'],
        lookback=lookback, name="TestSplit", config=cfg_no_drop, io=mock_io_manager
    )

    assert dropped_count == 0
    assert X_seq is not None
    assert y_seq_dict is not None
    assert indices is not None
    assert X_seq.shape == (expected_n_seq, lookback, n_features)
    assert 'mu' in y_seq_dict and y_seq_dict['mu'].shape == (expected_n_seq,)
    assert 'dir3' in y_seq_dict and y_seq_dict['dir3'].shape == (expected_n_seq, 3)
    assert len(indices) == expected_n_seq
    # Check the first target index corresponds to the lookback-th original index
    assert indices[0] == X_data.index[lookback]
    # Check the last target index corresponds to the last original index
    assert indices[-1] == X_data.index[-1]
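
As a quick sanity check of those shape assertions (using the fixture above): 20 rows with a lookback of 5 yield 20 - 5 = 15 sequences; window 0 covers rows [0, 4] and predicts the target at row 5, so the first target timestamp is '2023-01-01 00:05:00' on the minute-frequency index.

n_samples, lookback = 20, 5
expected_n_seq = n_samples - lookback   # 15
first_target_row = lookback             # row 5 -> '2023-01-01 00:05:00'
last_target_row = n_samples - 1         # row 19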
def test_sequence_dropping_imputed(sample_data_with_imputed, base_config, mock_io_manager):
    X_data, y_data = sample_data_with_imputed
    lookback = base_config.gru.lookback
    n_samples = len(X_data)
    expected_n_seq_orig = n_samples - lookback

    # Config with dropping enabled (default in fixture)
    cfg_drop = base_config

    X_seq, y_seq_dict, indices, dropped_count = create_sequences_fold(
        X_data=X_data.copy(), y_data=y_data.copy(), target_names=['mu', 'dir3'],
        lookback=lookback, name="TestDrop", config=cfg_drop, io=mock_io_manager
    )

    assert X_seq is not None
    assert y_seq_dict is not None
    assert indices is not None

    # Determine which original sequences should have been dropped.
    # A sequence starting at index i uses data from [i, i+lookback-1];
    # its target corresponds to index i+lookback. So for each potential
    # target index i+lookback, check the imputed flags in [i, i+lookback-1].
    # Original target indices range from index `lookback` to `n_samples - 1`.
    should_be_dropped_mask = np.zeros(expected_n_seq_orig, dtype=bool)
    imputed_flags_np = X_data['bar_imputed'].values
    for seq_idx in range(expected_n_seq_orig):
        # The features for this sequence come from original indices [seq_idx, seq_idx + lookback - 1]
        feature_indices_range = slice(seq_idx, seq_idx + lookback)
        if np.any(imputed_flags_np[feature_indices_range]):
            should_be_dropped_mask[seq_idx] = True

    expected_dropped_count = np.sum(should_be_dropped_mask)
    expected_remaining_count = expected_n_seq_orig - expected_dropped_count

    assert dropped_count == expected_dropped_count
    assert X_seq.shape[0] == expected_remaining_count
    assert y_seq_dict['mu'].shape[0] == expected_remaining_count
    assert y_seq_dict['dir3'].shape[0] == expected_remaining_count
    assert len(indices) == expected_remaining_count

    # Check that the remaining indices are correct (weren't marked for dropping)
    original_indices = X_data.index[lookback:]
    expected_remaining_indices = original_indices[~should_be_dropped_mask]
    pd.testing.assert_index_equal(indices, expected_remaining_indices)

    # Check artifact saving
    assert 'imputed_sequence_summary_testdrop.json' in mock_io_manager._saved_jsons
    report_data = mock_io_manager._saved_jsons['imputed_sequence_summary_testdrop.json']
    assert report_data['total_sequences_generated'] == expected_n_seq_orig
    assert report_data['sequences_dropped_imputed'] == expected_dropped_count
    assert report_data['sequences_remaining'] == expected_remaining_count
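
Working that mask out by hand for the fixture (imputed rows {2, 8, 15}, lookback 5, 15 candidate windows) is a useful cross-check; a minimal sketch:

import numpy as np

imputed = np.zeros(20, dtype=bool)
imputed[[2, 8, 15]] = True
lookback = 5
dropped = [i for i in range(20 - lookback) if imputed[i:i + lookback].any()]
# dropped == [0, 1, 2, 4, 5, 6, 7, 8, 11, 12, 13, 14]
# -> 12 windows dropped, 3 kept (those starting at rows 3, 9, and 10)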
def test_sequence_creation_no_imputed_col(sample_data_with_imputed, base_config, mock_io_manager):
    X_data, y_data = sample_data_with_imputed
    X_data_no_imputed = X_data.drop(columns=['bar_imputed'])
    lookback = base_config.gru.lookback

    with pytest.raises(SystemExit) as excinfo:
        create_sequences_fold(
            X_data=X_data_no_imputed, y_data=y_data, target_names=['mu', 'dir3'],
            lookback=lookback, name="TestNoImputedCol", config=base_config, io=mock_io_manager
        )
    assert "'bar_imputed' column missing" in str(excinfo.value)

def test_sequence_creation_insufficient_data(sample_data_with_imputed, base_config, mock_io_manager):
    X_data, y_data = sample_data_with_imputed
    lookback = base_config.gru.lookback
    # Create data shorter than the lookback
    X_short = X_data.iloc[:lookback-1]
    y_short = y_data.iloc[:lookback-1]

    X_seq, y_seq_dict, indices, dropped_count = create_sequences_fold(
        X_data=X_short, y_data=y_short, target_names=['mu', 'dir3'],
        lookback=lookback, name="TestShort", config=base_config, io=mock_io_manager
    )

    assert X_seq is None
    assert y_seq_dict is None
    assert indices is None
    assert dropped_count == 0
166
gru_sac_predictor/tests/test_trading_env.py
Normal file
@ -0,0 +1,166 @@
import pytest
import numpy as np
from omegaconf import OmegaConf

# Adjust import path based on structure
from gru_sac_predictor.src.trading_env import TradingEnv

# --- Test Fixtures ---

@pytest.fixture
def sample_env_data():
    """Provides sample data for initializing the TradingEnv."""
    n_steps = 10
    data = {
        'mu_predictions': np.random.randn(n_steps) * 0.001,
        'sigma_predictions': np.abs(np.random.randn(n_steps) * 0.002 + 0.005),
        'p_cal_predictions': np.random.rand(n_steps),
        'actual_returns': np.random.randn(n_steps) * 0.0015,
        'bar_imputed_flags': np.array([False, False, True, False, True, True, False, False, True, False], dtype=bool)
    }
    return data

@pytest.fixture
def base_env_config():
    """Base configuration for the environment."""
    return OmegaConf.create({
        'sac': {
            'imputed_handling': 'skip',  # Default test mode
            'action_penalty': 0.05
        },
        'environment': {
            'initial_capital': 10000.0,
            'transaction_cost': 0.0005,
            'reward_scale': 100.0,
            'action_penalty_lambda': 0.0  # Usually overridden by transaction_cost calc
        }
    })

@pytest.fixture
def trading_env_instance(sample_env_data, base_env_config):
    """Creates a TradingEnv instance with the default 'skip' mode."""
    return TradingEnv(**sample_env_data, config=base_env_config)

# --- Test Functions ---
def test_env_initialization(trading_env_instance, sample_env_data):
    assert trading_env_instance.n_steps == len(sample_env_data['actual_returns'])
    assert trading_env_instance.current_step == 0
    assert trading_env_instance.current_position == 0.0
    assert np.array_equal(trading_env_instance.bar_imputed, sample_env_data['bar_imputed_flags'])

def test_env_reset(trading_env_instance):
    # Take a few steps
    trading_env_instance.step(0.5)
    trading_env_instance.step(-0.2)
    assert trading_env_instance.current_step > 0
    # Reset
    initial_state = trading_env_instance.reset()
    assert trading_env_instance.current_step == 0
    assert trading_env_instance.current_position == 0.0
    assert initial_state is not None
    assert initial_state.shape == (trading_env_instance.state_dim,)

def test_env_step_normal(trading_env_instance):
    # Test a normal step (step 0 is not imputed)
    initial_pos = trading_env_instance.current_position
    action = 0.7
    next_state, reward, done, info = trading_env_instance.step(action)

    assert trading_env_instance.current_step == 1
    assert trading_env_instance.current_position == action  # Position updates to the action
    assert not info['is_imputed_step_skipped']
    assert not done
    assert next_state is not None
    # Reward calculation is complex; just check the type (and sign if needed)
    assert isinstance(reward, float)

def test_env_step_imputed_skip(trading_env_instance, sample_env_data):
    # Step 2 is imputed in sample_env_data
    trading_env_instance.step(0.5)  # Step 0
    trading_env_instance.step(0.6)  # Step 1
    assert trading_env_instance.current_step == 2
    initial_pos_before_imputed = trading_env_instance.current_position

    # Action for the imputed step (should be ignored by 'skip')
    action_imputed = 0.9
    next_state, reward, done, info = trading_env_instance.step(action_imputed)

    # Should skip step 2 and now be at step 3
    assert trading_env_instance.current_step == 3
    # Position should NOT have changed from step 1
    assert trading_env_instance.current_position == initial_pos_before_imputed
    assert reward == 0.0  # Skip gives 0 reward
    assert not done
    assert info['is_imputed_step_skipped'] == True  # Crucial check for the buffer
    # Check that the returned state is for step 3
    expected_state_step_3 = trading_env_instance._get_state()  # Get state now that we are at step 3
    np.testing.assert_array_almost_equal(next_state, expected_state_step_3)
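
The 'is_imputed_step_skipped' flag exists so the trainer can avoid storing skipped transitions. A sketch of how a SAC training loop might use it (the loop structure and the agent/replay_buffer names are illustrative, not from this repo):

state = env.reset()
done = False
while not done:
    action = agent.act(state)
    next_state, reward, done, info = env.step(action)
    if not info.get('is_imputed_step_skipped', False):
        # Only real (non-imputed) transitions enter the replay buffer
        replay_buffer.add(state, action, reward, next_state, done)
    state = next_state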
def test_env_step_imputed_hold(sample_env_data, base_env_config):
    cfg = base_env_config.copy()
    cfg.sac.imputed_handling = 'hold'
    env = TradingEnv(**sample_env_data, config=cfg)

    # Step 2 is imputed
    env.step(0.5)  # Step 0
    env.step(0.6)  # Step 1
    assert env.current_step == 2
    position_before_imputed = env.current_position

    # Action for the imputed step (should be overridden by 'hold')
    action_imputed = -0.5
    next_state, reward, done, info = env.step(action_imputed)

    # Should process step 2 and move to step 3
    assert env.current_step == 3
    # Position should be the same as before the step
    assert env.current_position == position_before_imputed
    assert not info['is_imputed_step_skipped']
    assert not done
    # Reward should be calculated based on holding the position
    expected_pnl = position_before_imputed * (np.exp(sample_env_data['actual_returns'][2]) - 1)
    expected_cost = 0  # No trade size when holding
    expected_penalty = 0  # No penalty in hold mode
    expected_raw_reward = expected_pnl - expected_cost - expected_penalty
    expected_scaled_reward = expected_raw_reward * cfg.environment.reward_scale
    assert np.isclose(reward, expected_scaled_reward)
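
Plugging in assumed numbers for the hold case: with a held position of 0.6 and a log return of 0.001 on the imputed bar, the PnL is 0.6 * (e^0.001 - 1) ≈ 0.00060, which a reward_scale of 100 turns into a reward of roughly 0.060.

import numpy as np

position = 0.6
log_ret = 0.001                          # assumed actual return on the imputed bar
pnl = position * (np.exp(log_ret) - 1)   # ~0.00060
reward = pnl * 100.0                     # reward_scale (from the fixture) -> ~0.060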
def test_env_step_imputed_penalty(sample_env_data, base_env_config):
    cfg = base_env_config.copy()
    cfg.sac.imputed_handling = 'penalty'
    cfg.sac.action_penalty = 0.1  # Use a specific penalty for testing
    env = TradingEnv(**sample_env_data, config=cfg)

    # Step 2 is imputed
    env.step(0.5)  # Step 0
    env.step(0.6)  # Step 1
    assert env.current_step == 2
    position_before_imputed = env.current_position  # Should be 0.6

    # Action for the imputed step
    action_imputed = -0.2
    next_state, reward, done, info = env.step(action_imputed)

    # Should process step 2 and move to step 3
    assert env.current_step == 3
    # Position should update to the *agent's* action
    assert env.current_position == np.clip(action_imputed, -1.0, 1.0)
    assert not info['is_imputed_step_skipped']
    assert not done

    # Reward is ONLY the penalty
    expected_raw_reward = -cfg.sac.action_penalty * (action_imputed - position_before_imputed)**2
    expected_scaled_reward = expected_raw_reward * cfg.environment.reward_scale
    assert np.isclose(reward, expected_scaled_reward)

def test_env_done_condition(trading_env_instance, sample_env_data):
    n_steps = len(sample_env_data['actual_returns'])
    # Step through the environment
    done = False
    for i in range(n_steps):
        _, _, done, _ = trading_env_instance.step(np.random.uniform(-1, 1))
        if i < n_steps - 1:
            assert not done
        else:
            assert done  # Should be done on the last step