Compare commits

...

4 Commits

6 changed files with 2295 additions and 2219 deletions

View File

@ -0,0 +1,711 @@
# Stage functions for loading, initial preprocessing, feature engineering, label generation, and splitting
import logging
import sys # Added for sys.exit
from datetime import datetime, timezone # Added for datetime
import pandas as pd
import numpy as np
from typing import Tuple, Optional, Any, List, Dict # Added List and Dict
import matplotlib.pyplot as plt # Added for plotting
import seaborn as sns # Added for plotting
# --- Component Imports --- #
# Assuming DataLoader is in the parent directory's src
# This might need adjustment based on actual project structure
# Using relative import assuming pipeline_stages is sibling to other src modules
from ..data_loader import DataLoader, fill_missing_bars
from ..feature_engineer import FeatureEngineer # Added FeatureEngineer import
from ..io_manager import IOManager # Added IOManager import
from ..metrics import calculate_sharpe_ratio # For potential baseline comparison
# --- Local Imports --- #
# Import the label generation function we moved here
# Removed duplicate import: from .data_processing import generate_direction_labels
# Assuming tensorflow is installed and available
try:
from tensorflow.keras.utils import to_categorical
except ImportError:
logging.warning("TensorFlow/Keras not found. Ternary label one-hot encoding will fail.")
# Define a placeholder if keras is not available
def to_categorical(*args, **kwargs):
raise NotImplementedError("Keras 'to_categorical' is unavailable.")
logger = logging.getLogger(__name__) # Use module-level logger
# --- Refactored Label Generation Logic (Moved from trading_pipeline.py) --- #
def generate_direction_labels(df: pd.DataFrame, config: dict) -> Tuple[pd.DataFrame, str, pd.Series, Optional[pd.Series]]:
"""
Calculates forward returns and generates binary, soft binary, or ternary direction labels.
Also returns the raw forward returns and the epsilon series used for ternary flat definition.
Args:
df (pd.DataFrame): DataFrame containing at least a 'close' column and DatetimeIndex.
config (dict): Pipeline configuration dictionary, expecting keys under 'gru' and 'data'.
Returns:
tuple[pd.DataFrame, str, pd.Series, Optional[pd.Series]]:
- DataFrame with added forward return and direction label columns (and NaNs dropped based on labels).
- Name of the generated direction label column.
- Series containing the calculated forward log returns (`fwd_log_ret`).
- Series containing the calculated epsilon (`eps`) threshold if ternary, else None.
"""
if 'close' not in df.columns:
raise ValueError("'close' column missing in input DataFrame for label generation.")
gru_cfg = config.get('gru', {})
data_cfg = config.get('data', {})
horizon = gru_cfg.get('prediction_horizon', 5)
use_ternary = gru_cfg.get('use_ternary', False) # Check if ternary flag is set
target_ret_col = f'fwd_log_ret_{horizon}'
eps_series: Optional[pd.Series] = None # Initialize eps
# --- Calculate Forward Log Return --- #
shifted_close = df['close'].shift(-horizon)
fwd_returns = np.log(shifted_close / df['close'])
df[target_ret_col] = fwd_returns
# --- Generate Direction Label (Binary/Soft or Ternary) --- #
if use_ternary:
k = gru_cfg.get('flat_sigma_multiplier', 0.25)
target_dir_col = f'direction_label3_{horizon}'
logger.info(f"Generating ternary labels ({target_dir_col}) with k={k}...")
sigma_n = fwd_returns.rolling(window=horizon, min_periods=max(1, horizon//2)).std()
eps = k * sigma_n
eps_series = eps # Store the calculated eps series
conditions = [fwd_returns > eps, fwd_returns < -eps]
choices = [2, 0] # 2=up, 0=down
ordinal_labels = np.select(conditions, choices, default=1).astype(int) # 1=flat
# --- Log Distribution & Check Balance --- #
df['_ordinal_label_temp'] = ordinal_labels
valid_mask_for_dist = ~np.isnan(eps) & ~np.isnan(fwd_returns)
ordinal_labels_valid = df.loc[valid_mask_for_dist, '_ordinal_label_temp']
if not ordinal_labels_valid.empty:
counts = np.bincount(ordinal_labels_valid, minlength=3)
total_valid = len(ordinal_labels_valid)
if total_valid > 0: # Avoid division by zero
dist_pct = counts / total_valid * 100
log_msg = (f"Label dist (n={total_valid}): "
f"Down(0)={dist_pct[0]:.1f}%, Flat(1)={dist_pct[1]:.1f}%, Up(2)={dist_pct[2]:.1f}%")
logger.info(log_msg)
min_pct_threshold = 10.0 # As per implementation
if any(p < min_pct_threshold for p in dist_pct):
error_msg = f"Label imbalance detected! Min class percentage is {np.min(dist_pct):.1f}% (Threshold: {min_pct_threshold}%). Check data or flat_sigma_multiplier (k={k})."
logger.error(error_msg)
print(f"ERROR: {error_msg}") # Also print for visibility
else:
logger.warning("Label distribution check skipped: total valid labels is zero.")
else:
logger.warning("Could not calculate label distribution (no valid sigma or returns).")
# --- End Distribution Check --- #
# --- One-hot encode --- #
try:
y_cat_full = np.full((len(df), 3), np.nan, dtype=np.float32)
if not ordinal_labels_valid.empty:
y_cat_valid = to_categorical(ordinal_labels_valid, num_classes=3)
y_cat_full[valid_mask_for_dist] = y_cat_valid.astype(np.float32)
else:
logger.warning("No valid ordinal labels to one-hot encode.")
# Assign the list of arrays (or NaNs) - using list avoids mixed type issues later
df[target_dir_col] = [list(row) if not np.all(np.isnan(row)) else np.nan for row in y_cat_full]
except NotImplementedError as nie:
logger.error(f"Ternary label generation failed: {nie}. Keras 'to_categorical' is unavailable. Please install tensorflow.", exc_info=True)
raise # Re-raise exception to halt pipeline
except Exception as e:
logger.error(f"Error during one-hot encoding: {e}", exc_info=True)
raise # Re-raise exception to halt pipeline if encoding fails
finally:
if '_ordinal_label_temp' in df.columns:
df.drop(columns=['_ordinal_label_temp'], inplace=True)
# --- End One-hot Encoding --- #
else: # Binary / Soft Binary
target_dir_col = f'direction_label_{horizon}'
label_smoothing = data_cfg.get('label_smoothing', 0.0)
if not (0.0 <= label_smoothing < 1.0):
logger.warning(f"Invalid label_smoothing value ({label_smoothing}). Must be in [0.0, 1.0). Disabling smoothing.")
label_smoothing = 0.0
if label_smoothing > 0.0:
high_label = 1.0 - label_smoothing / 2.0
low_label = label_smoothing / 2.0
logger.info(f"Applying label smoothing: {label_smoothing:.2f} -> labels [{low_label:.2f}, {high_label:.2f}] for {target_dir_col}")
df[target_dir_col] = np.where(fwd_returns > 0, high_label, low_label).astype(np.float32)
else:
logger.info(f"Using hard binary labels (0.0 / 1.0) for {target_dir_col}")
df[target_dir_col] = (fwd_returns > 0).astype(np.float32)
# --- Drop Rows with NaN Targets --- #
initial_rows = len(df)
# Create mask for NaNs in the direction column
if use_ternary:
# Check if elements are np.nan (since we assign np.nan for rows with no valid labels)
nan_mask_dir = df[target_dir_col].isna()
else:
nan_mask_dir = df[target_dir_col].isna()
nan_mask_combined = df[target_ret_col].isna() | nan_mask_dir
df_clean = df[~nan_mask_combined].copy()
final_rows = len(df_clean)
if final_rows < initial_rows:
logger.info(f"Dropped {initial_rows - final_rows} rows due to NaN targets (horizon={horizon}).")
if df_clean.empty:
logger.error("DataFrame is empty after defining labels and dropping NaNs. Exiting.")
# Returning empty DataFrame, caller should handle exit
return pd.DataFrame(), target_dir_col, pd.Series(dtype=float), None # Return empty series/None on failure
# Return the cleaned df, target col name, and the *original* full fwd_returns and eps series
# Need to return the original series aligned with the original df index *before* cleaning
# So the caller can align them with the features *after* cleaning df_clean
return df_clean, target_dir_col, fwd_returns, eps_series
# --- End Label Generation --- #
# --- Stage 1: Load and Preprocess Data (Moved from TradingPipeline.load_and_preprocess_data) --- #
def load_and_preprocess(
data_loader: DataLoader,
io: Optional[IOManager],
run_id: str,
config: Dict[str, Any]
) -> Tuple[Optional[pd.DataFrame], Optional[Dict[str, Any]]]:
"""
Loads the full raw dataset using DataLoader and performs initial checks.
Args:
data_loader: Initialized DataLoader instance.
io: IOManager instance (optional).
run_id: Current run ID.
config: Pipeline configuration dictionary.
Returns:
Tuple containing:
- DataFrame with raw loaded data, or None on failure.
- Dictionary summarizing the loading process, or None on failure.
"""
logger.info("--- Stage: Loading and Preprocessing Data ---")
data_cfg = config.get('data', {})
# --- Extract necessary parameters from config --- #
ticker = data_cfg.get('ticker')
exchange = data_cfg.get('exchange')
start_date = data_cfg.get('start_date')
end_date = data_cfg.get('end_date')
interval = data_cfg.get('interval', '1min') # Default to 1min
vol_sampling = data_cfg.get('volatility_sampling', {}).get('enabled', False)
vol_window = data_cfg.get('volatility_sampling', {}).get('window', 30)
vol_quantile = data_cfg.get('volatility_sampling', {}).get('quantile', 0.5)
# Validate required parameters
if not all([ticker, exchange, start_date, end_date]):
logger.error("Missing required data parameters in config: ticker, exchange, start_date, end_date")
return None, None
# --- End Parameter Extraction --- #
load_summary = {
'ticker': ticker,
'exchange': exchange,
'start_date_req': start_date,
'end_date_req': end_date,
'interval_req': interval,
'vol_sampling_enabled': vol_sampling,
'vol_window': vol_window,
'vol_quantile': vol_quantile,
}
try:
logger.info(f"Loading data for {ticker} ({exchange}) from {start_date} to {end_date}, interval {interval}")
# --- Pass extracted parameters to load_data --- #
df_raw = data_loader.load_data(
ticker=ticker,
exchange=exchange,
start_date=start_date,
end_date=end_date,
interval=interval,
vol_sampling=vol_sampling,
vol_window=vol_window,
vol_quantile=vol_quantile
)
# --- End Pass Parameters --- #
if df_raw is None or df_raw.empty:
logger.error("Data loading returned empty DataFrame or failed.")
return None, load_summary
# --- Fill Missing Bars (Step 2.5 from prompts/missing_data.txt) --- #
if io is None:
logger.error("IOManager is required for fill_missing_bars reporting. Cannot proceed.")
return None, load_summary
try:
df_filled = fill_missing_bars(df_raw, config, io, logger)
if df_filled is None or df_filled.empty:
logger.error("fill_missing_bars returned empty DataFrame or failed.")
return None, load_summary
df_raw = df_filled # Replace df_raw with the filled version
logger.info("Missing bars handled successfully.")
except ValueError as e:
logger.error(f"Error during missing bar handling: {e}. Halting processing.")
return None, load_summary
except Exception as e:
logger.error(f"Unexpected error during missing bar handling: {e}. Halting processing.", exc_info=True)
return None, load_summary
# --- End Fill Missing Bars --- #
# Calculate memory usage and log info
mem_usage = df_raw.memory_usage(deep=True).sum() / (1024**2)
if load_summary:
logger.info(f"Data loading summary: {load_summary}")
else:
logger.warning("No load summary returned by DataLoader.")
logger.info(f"Loaded data: {df_raw.shape[0]} rows, {df_raw.shape[1]} columns. Memory: {mem_usage:.2f} MB")
logger.info(f"Time range: {df_raw.index.min()} to {df_raw.index.max()}")
# --- V3 Output Contract: Stage 1 Artifacts --- #
if io:
if load_summary:
save_summary = load_summary.copy() # Don't modify original
save_summary['run_id'] = run_id
save_summary['timestamp_utc'] = datetime.now(timezone.utc).isoformat()
# TODO: Finalize summary content (add counts, NaN info etc.)
logger.info("Saving preprocess summary...")
io.save_json(save_summary, "preprocess_summary", use_txt=True) # Spec wants .txt
# Save head of preprocessed data
if df_raw is not None and not df_raw.empty:
logger.info("Saving head of preprocessed data (first 20 rows)...")
io.save_df(df_raw.head(20), "head_preprocessed")
else:
logger.warning("Skipping saving head_preprocessed: DataFrame is empty or None.")
else:
logger.warning("IOManager not available, skipping saving of Stage 1 artifacts (preprocess_summary, head_preprocessed).")
# --- End V3 Output Contract ---
# --- V3 Output Contract: Stage 2 Artifact (Label Histogram) --- #
# TODO: Move this plotting logic to evaluation stage or after split, needs y_train.
# if io and config.get('control', {}).get('generate_plots', True):
# logger.info("Generating training label distribution histogram... [SKIPPED IN CURRENT STAGE]")
# ... (Original plotting code removed from here)
# --- End V3 Output Contract ---
return df_raw, load_summary
except Exception as e:
logger.error(f"Error during data loading: {e}", exc_info=True)
return None, None
# --- Stage 2: Engineer Features (Moved from TradingPipeline.engineer_features) --- #
def engineer_features_for_fold(
df: pd.DataFrame,
feature_engineer: FeatureEngineer,
io: Optional[IOManager], # Added IOManager for saving figure
config: Dict[str, Any], # Added config for plot settings
target_col: Optional[str] = None # Added target column name for sorting correlation
) -> pd.DataFrame:
"""Adds features using FeatureEngineer, handles NaNs, and saves correlation heatmap for a fold.
Args:
df (pd.DataFrame): Input DataFrame for the fold (typically raw data).
feature_engineer (FeatureEngineer): Initialized FeatureEngineer instance.
io (Optional[IOManager]): IOManager instance for saving artifacts.
config (Dict[str, Any]): Pipeline configuration dictionary.
target_col (Optional[str]): Name of the target column to sort correlations by (e.g., 'fwd_log_ret_5').
Returns:
pd.DataFrame: DataFrame with engineered features, NaNs dropped.
Returns an empty DataFrame if input is empty or result is empty.
"""
logger.info("--- Stage: Engineering Features --- ")
if df is None or df.empty:
logger.error("Input DataFrame is empty. Cannot engineer features.")
return pd.DataFrame() # Return empty DataFrame to indicate failure
if feature_engineer is None:
logger.error("FeatureEngineer not initialized. Cannot engineer features.")
# Or raise an error? For now return empty
return pd.DataFrame()
# Add base features (cyclical, imbalance, TA)
df_engineered = feature_engineer.add_base_features(df.copy())
# --- V3 Output Contract: Feature Correlation Heatmap --- #
# Generate heatmap *before* dropping NaNs to capture full feature set correlations
# if io and config.get('control', {}).get('generate_plots', True): # Check if plotting is enabled
if io: # Assume generate_plots is implicitly true if io is provided
try:
logger.info("Generating feature correlation heatmap...")
numeric_cols = df_engineered.select_dtypes(include=np.number).columns
if len(numeric_cols) < 2:
logger.warning("Skipping correlation heatmap: Less than 2 numeric columns found.")
else:
corr_matrix = df_engineered[numeric_cols].corr(method='pearson')
# Get plot settings from config
output_cfg = config.get('output', {})
fig_size = output_cfg.get('figure_size', [16, 9])
plot_style = output_cfg.get('plot_style', 'seaborn-v0_8-darkgrid')
annot_threshold = output_cfg.get('corr_annot_threshold', 0.5)
plot_footer = output_cfg.get('plot_footer', "© GRU-SAC v3")
plt.style.use(plot_style)
fig, ax = plt.subplots(figsize=fig_size)
sort_features = False
if target_col and target_col in corr_matrix.columns:
# Sort by absolute correlation with the target
target_corr = corr_matrix[target_col].abs().sort_values(ascending=False)
sorted_cols = target_corr.index.tolist()
corr_matrix_sorted = corr_matrix.loc[sorted_cols, sorted_cols]
sort_features = True
else:
if target_col:
logger.warning(f"Target column '{target_col}' not found in correlation matrix. Heatmap will not be sorted by target correlation.")
corr_matrix_sorted = corr_matrix # Use original matrix if no target or not found
sns.heatmap(
corr_matrix_sorted,
annot=False, # Annotations can be messy; spec only requires > threshold
cmap='coolwarm', # Diverging palette centered at 0
center=0,
linewidths=0.5,
cbar=True,
square=True, # Ensure square cells
ax=ax
)
# Annotate cells where absolute correlation > threshold (from config)
for i in range(corr_matrix_sorted.shape[0]):
for j in range(corr_matrix_sorted.shape[1]):
if abs(corr_matrix_sorted.iloc[i, j]) > annot_threshold and i != j:
ax.text(j + 0.5, i + 0.5, f'{corr_matrix_sorted.iloc[i, j]:.2f}',
ha='center', va='center', color='black', fontsize=8)
title = "Feature Correlation Heatmap (Pearson)"
if sort_features:
title += f" - Sorted by |ρ| vs '{target_col}'"
ax.set_title(title, fontsize=14)
plt.xticks(rotation=90, fontsize=8)
plt.yticks(rotation=0, fontsize=8)
# Add footer (from config)
if plot_footer: # Only add if footer is not empty
plt.figtext(0.99, 0.01, plot_footer, ha="right", va="bottom", fontsize=8, color='gray')
# Save figure using IOManager
io.save_figure(fig, "feature_corr_heatmap", section='figures') # Saved to results/<run_id>/figures/
plt.close(fig) # Close figure after saving
logger.info("Saved feature correlation heatmap.")
except Exception as e:
logger.error(f"Failed to generate or save feature correlation heatmap: {e}", exc_info=True)
else:
logger.warning("IOManager not provided or plotting disabled, skipping feature correlation heatmap.")
# --- End V3 Output Contract --- #
# --- REMOVE Aggressive DropNA --- #
# Dropping all rows with any NaN here is too aggressive, especially with long lookback features.
# NaN handling should occur within feature calculation methods (bfill/ffill/fillna(0))
# and critically during label definition (where rows without valid labels are dropped).
# initial_rows = len(df_engineered)
# df_engineered.dropna(inplace=True)
# rows_dropped = initial_rows - len(df_engineered)
# if rows_dropped > 0:
# logger.warning(f"Dropped {rows_dropped} rows with NaN values after feature engineering.")
# --- End REMOVE --- #
# Check if dataframe became empty *after feature calculation and internal NaN handling*
# (Though ideally internal handling should prevent this)
if df_engineered.empty:
logger.error("DataFrame is empty after feature engineering (check internal NaN handling in FeatureEngineer)." )
return pd.DataFrame() # Return empty DataFrame
logger.info(f"Feature engineering complete. Shape: {df_engineered.shape}")
return df_engineered
# --- Stage 3: Define Labels and Align (Moved from TradingPipeline.define_labels_and_align) --- #
def define_labels_and_align_fold(
df_engineered: pd.DataFrame,
config: dict
) -> Tuple[pd.DataFrame, str, List[str], pd.Series, Optional[pd.Series]]:
"""Defines prediction labels, aligns with features, and separates targets for a fold.
Also returns the raw forward returns and epsilon series used for filtering baselines.
Args:
df_engineered (pd.DataFrame): DataFrame with engineered features for the fold.
config (dict): Pipeline configuration dictionary.
Returns:
Tuple[pd.DataFrame, str, List[str], pd.Series, Optional[pd.Series]]:
- df_labeled_aligned: DataFrame with labels generated and features/targets aligned (NaNs dropped).
- target_dir_col: Name of the direction label column.
- target_cols: List containing names of all target columns (ret + dir).
- fwd_returns_aligned: Series of forward returns aligned with df_labeled_aligned.
- eps_aligned: Series of epsilon threshold aligned with df_labeled_aligned (or None).
Returns (pd.DataFrame(), "", [], pd.Series(), None) on failure or empty input.
"""
logger.info("--- Stage: Defining Labels and Aligning --- ")
if df_engineered is None or df_engineered.empty:
logger.error("Engineered data (DataFrame) is empty. Cannot define labels.")
return pd.DataFrame(), "", [], pd.Series(dtype=float), None
# --- Call the label generation function (already in this module) --- #
try:
# generate_direction_labels modifies the DataFrame in place and returns it
# It also returns the original fwd_returns and eps series (aligned with df_engineered)
df_clean, target_dir_col, fwd_returns_orig, eps_orig = generate_direction_labels(
df_engineered.copy(), # Pass a copy to avoid modifying original outside this scope if needed
config
)
except Exception as e:
logger.error(f"Label generation failed: {e}.", exc_info=True)
return pd.DataFrame(), "", [], pd.Series(dtype=float), None
if df_clean.empty:
logger.error("Label generation resulted in an empty DataFrame.")
return pd.DataFrame(), "", [], pd.Series(dtype=float), None
# --- End Label Generation Call --- #
# --- Determine Target Columns --- #
horizon = config.get('gru', {}).get('prediction_horizon', 5)
target_ret_col = f'fwd_log_ret_{horizon}'
# target_dir_col is returned by generate_direction_labels
target_cols = [target_ret_col, target_dir_col]
# Ensure the columns actually exist after generation and cleaning
if not all(col in df_clean.columns for col in target_cols):
# Log which columns are actually present for debugging
present_cols = df_clean.columns.tolist()
logger.error(f"Generated label/return columns ({target_cols}) not found in DataFrame after label generation. Present columns: {present_cols}")
return pd.DataFrame(), "", [], pd.Series(dtype=float), None
# --- End Determine Target Columns --- #
# --- Align fwd_returns_orig and eps_orig with the cleaned DataFrame --- #
fwd_returns_aligned = fwd_returns_orig.loc[df_clean.index]
eps_aligned = eps_orig.loc[df_clean.index] if eps_orig is not None else None
# --- End Alignment --- #
# Note: Separation of X and y happens in the splitting function now
# We just need to return the fully labeled/aligned DataFrame and target column names.
logger.info(f"Labels defined and aligned. Shape: {df_clean.shape}")
# Return the aligned DataFrame and the aligned supplementary series
return df_clean, target_dir_col, target_cols, fwd_returns_aligned, eps_aligned
# --- Stage 4: Split Data (Moved from TradingPipeline.split_data) --- #
def split_data_fold(
df_labeled_aligned: pd.DataFrame,
fwd_returns_aligned: pd.Series,
eps_aligned: Optional[pd.Series],
config: dict,
target_columns: List[str],
target_dir_col: str,
fold_dates: Optional[Tuple] = None,
current_fold: Optional[int] = None # For logging
) -> Tuple[
# Features
pd.DataFrame, pd.DataFrame, pd.DataFrame,
# Original Targets
pd.DataFrame, pd.DataFrame, pd.DataFrame,
# Original Full DataFrames
pd.DataFrame, pd.DataFrame, pd.DataFrame,
# Ordinal Direction Target (Train only)
pd.Series,
# Forward Returns (Train/Val)
pd.Series, Optional[pd.Series],
# Epsilon (Train/Val)
Optional[pd.Series], Optional[pd.Series],
# Ordinal Direction Labels (Val)
Optional[pd.Series]
]:
"""Splits features, targets, fwd returns, and epsilon for a given fold.
Args:
df_labeled_aligned (pd.DataFrame): Labeled and aligned data for the entire fold period.
fwd_returns_aligned (pd.Series): Forward returns aligned with df_labeled_aligned.
eps_aligned (Optional[pd.Series]): Epsilon threshold aligned with df_labeled_aligned.
config (dict): Pipeline configuration.
target_columns (List[str]): Names of all target columns (e.g., ['fwd_log_ret_5', 'direction_label_5']).
target_dir_col (str): Name of the specific direction target column.
fold_dates (Optional[Tuple]): Tuple of (train_start, train_end, val_start, val_end, test_start, test_end) for WF.
current_fold (Optional[int]): Fold number for logging.
Returns:
Tuple containing the split dataframes/series:
(X_train_raw, X_val_raw, X_test_raw, # Features
y_train, y_val, y_test, # Targets
df_train_original, df_val_original, df_test_original, # Original DFs
y_dir_train_ordinal, # Ordinal Direction Labels (Train)
fwd_ret_train, fwd_ret_val, # Forward Returns (Train/Val)
eps_train, eps_val, # Epsilon (Train/Val, Optional)
y_dir_val_ordinal) # Ordinal Direction Labels (Val, Optional)
Returns tuple of Nones if splitting fails.
"""
fold_label = f"Fold {current_fold}" if current_fold is not None else "Split"
logger.info(f"--- {fold_label}: Stage: Splitting Data --- ")
if df_labeled_aligned is None or df_labeled_aligned.empty:
logger.error(f"Fold {fold_label}: Input data for splitting is empty.")
# Return Nones to indicate failure (update count based on new returns)
return (None,) * 14
# --- Temporarily add fwd_ret and eps to DataFrame for easier splitting --- #
temp_fwd_ret_col = '__temp_fwd_ret__'
temp_eps_col = '__temp_eps__'
df_split_input = df_labeled_aligned.copy()
df_split_input[temp_fwd_ret_col] = fwd_returns_aligned
if eps_aligned is not None:
df_split_input[temp_eps_col] = eps_aligned
# --- End Temp Add --- #
if not isinstance(df_split_input.index, pd.DatetimeIndex):
logger.error(f"{fold_label}: Data index must be DatetimeIndex for splitting. Aborting.")
raise SystemExit(f"{fold_label}: Index is not DatetimeIndex in split_data.")
if not target_columns:
logger.error(f"{fold_label}: Target columns list is empty. Aborting.")
raise SystemExit(f"{fold_label}: Target columns missing in split_data.")
if not target_dir_col:
logger.error(f"{fold_label}: Target direction column name is empty. Aborting.")
raise SystemExit(f"{fold_label}: Target direction column missing in split_data.")
# Ensure target columns exist before trying to drop/select them
cols_to_drop = [col for col in target_columns if col in df_split_input.columns]
if len(cols_to_drop) != len(target_columns):
missing_targets = set(target_columns) - set(cols_to_drop)
logger.error(f"{fold_label}: Expected target columns {missing_targets} not found in input DataFrame. Aborting.")
raise SystemExit(f"{fold_label}: Missing target columns in split_data input.")
# Exclude temporary columns from feature_cols
feature_cols = df_split_input.columns.difference(cols_to_drop + [temp_fwd_ret_col, temp_eps_col])
if feature_cols.empty:
logger.error(f"{fold_label}: No feature columns remain after excluding targets and temp cols. Aborting.")
raise SystemExit(f"{fold_label}: No feature columns found in split_data.")
# --- Determine if ternary mode is active --- #
use_ternary = config.get('gru', {}).get('use_ternary', False)
# --- End Determine Ternary --- #
# Initialize split results
X_train_raw, X_val_raw, X_test_raw = pd.DataFrame(), pd.DataFrame(), pd.DataFrame()
y_train, y_val, y_test = pd.DataFrame(), pd.DataFrame(), pd.DataFrame()
df_train_original, df_val_original, df_test_original = pd.DataFrame(), pd.DataFrame(), pd.DataFrame()
fwd_ret_train, fwd_ret_val = pd.Series(dtype=float), pd.Series(dtype=float)
eps_train, eps_val = None, None
y_dir_train_raw_format, y_dir_val_raw_format = pd.Series(dtype=object), pd.Series(dtype=object) # Store raw labels before converting
# Split based on Walk-Forward dates or ratios
if fold_dates and len(fold_dates) == 6 and all(fold_dates): # Check for valid WF tuple
train_start, train_end, val_start, val_end, test_start, test_end = fold_dates
logger.info(f" Splitting using Walk-Forward dates: Train=[{train_start}, {train_end}), Val=[{val_start}, {val_end}), Test=[{test_start}, {test_end})")
# Slicing logic
df_train_original = df_split_input.loc[train_start:train_end]
df_val_original = df_split_input.loc[val_start:val_end]
df_test_original = df_split_input.loc[test_start:test_end] if test_start else pd.DataFrame()
else: # Single split using ratios
split_cfg = config.get('split_ratios', {})
train_ratio = split_cfg.get('train', 0.7)
val_ratio = split_cfg.get('validation', 0.15)
test_ratio = round(1.0 - train_ratio - val_ratio, 2)
logger.info(f" Splitting using ratios: Train={train_ratio:.2f}, Val={val_ratio:.2f}, Test={test_ratio:.2f}")
total_len = len(df_split_input)
train_end_idx = int(total_len * train_ratio)
val_end_idx = int(total_len * (train_ratio + val_ratio))
df_train_original = df_split_input.iloc[:train_end_idx]
df_val_original = df_split_input.iloc[train_end_idx:val_end_idx]
df_test_original = df_split_input.iloc[val_end_idx:]
# --- Extract components from split DataFrames --- #
if not df_train_original.empty:
X_train_raw = df_train_original[feature_cols]
y_train = df_train_original[target_columns]
y_dir_train_raw_format = df_train_original[target_dir_col]
fwd_ret_train = df_train_original[temp_fwd_ret_col]
if temp_eps_col in df_train_original:
eps_train = df_train_original[temp_eps_col]
if not df_val_original.empty:
X_val_raw = df_val_original[feature_cols]
y_val = df_val_original[target_columns]
fwd_ret_val = df_val_original[temp_fwd_ret_col]
if temp_eps_col in df_val_original:
eps_val = df_val_original[temp_eps_col]
if not df_test_original.empty:
X_test_raw = df_test_original[feature_cols]
y_test = df_test_original[target_columns]
# --- End Extraction --- #
# --- Extract Ordinal Labels if Ternary --- #
y_dir_train_ordinal = None
if not y_train.empty: # Check if training data exists
if use_ternary:
valid_mask = y_dir_train_raw_format.notna() & y_dir_train_raw_format.apply(lambda x: isinstance(x, list) and len(x) == 3)
if valid_mask.any():
ordinal_values = y_dir_train_raw_format[valid_mask].apply(np.argmax)
y_dir_train_ordinal = pd.Series(np.nan, index=y_dir_train_raw_format.index)
y_dir_train_ordinal[valid_mask] = ordinal_values
logger.info(f"{fold_label}: Extracted ordinal labels (0, 1, 2) for feature selection. Count: {valid_mask.sum()}")
else:
logger.warning(f"{fold_label}: No valid list-based ternary labels found in y_dir_train_raw_format to convert to ordinal.")
y_dir_train_ordinal = pd.Series(dtype=np.float64) # Return empty series
else:
y_dir_train_ordinal = y_dir_train_raw_format.astype(int) # Ensure integer type
else:
y_dir_train_ordinal = pd.Series(dtype=int) # Empty series if no train data
# --- End Extract Ordinal Labels --- #
# --- Extract Ordinal Validation Labels if Ternary --- #
y_dir_val_ordinal = None
if not y_val.empty: # Check if validation data exists
if use_ternary:
# Use y_dir_val_raw_format which holds the lists/NaNs
valid_mask_val = y_dir_val_raw_format.notna() & y_dir_val_raw_format.apply(lambda x: isinstance(x, list) and len(x) == 3)
if valid_mask_val.any():
ordinal_values_val = y_dir_val_raw_format[valid_mask_val].apply(np.argmax)
y_dir_val_ordinal = pd.Series(np.nan, index=y_dir_val_raw_format.index)
y_dir_val_ordinal[valid_mask_val] = ordinal_values_val
logger.info(f"{fold_label}: Extracted ordinal validation labels. Count: {valid_mask_val.sum()}")
else:
logger.warning(f"{fold_label}: No valid ternary labels found in y_dir_val_raw_format.")
y_dir_val_ordinal = pd.Series(dtype=np.float64)
else:
# Use y_dir_val_raw_format which holds 0.0/1.0
y_dir_val_ordinal = y_dir_val_raw_format.astype(int)
else:
y_dir_val_ordinal = pd.Series(dtype=int) # Empty series if no validation data
# --- End Extract Ordinal Validation Labels --- #
# Log split shapes and check for empty splits
logger.info(f"Data split complete for {fold_label}:")
logger.info(f" Train: X={X_train_raw.shape}, y={y_train.shape}, fwd_ret={fwd_ret_train.shape}, eps={eps_train.shape if eps_train is not None else 'None'} ({X_train_raw.index.min()} to {X_train_raw.index.max()})" if not X_train_raw.empty else " Train: EMPTY")
logger.info(f" Val: X={X_val_raw.shape}, y={y_val.shape}, fwd_ret={fwd_ret_val.shape}, eps={eps_val.shape if eps_val is not None else 'None'} ({X_val_raw.index.min()} to {X_val_raw.index.max()})" if not X_val_raw.empty else " Val: EMPTY")
logger.info(f" Test: X=({X_test_raw.shape if X_test_raw is not None else 'None'}), y=({y_test.shape if y_test is not None else 'None'}) ({df_test_original.index.min() if df_test_original is not None and not df_test_original.empty else 'N/A'} to {df_test_original.index.max() if df_test_original is not None and not df_test_original.empty else 'N/A'})" )
# Check required splits are non-empty
if X_train_raw.empty or X_val_raw.empty:
logger.error(f"Fold {current_fold}: Data splitting resulted in empty train or validation set. Aborting fold.")
raise SystemExit(f"Fold {current_fold}: Empty train or validation split detected.")
return (
X_train_raw, X_val_raw, X_test_raw, # Features
y_train, y_val, y_test, # Targets
df_train_original, df_val_original, df_test_original, # Original DFs
y_dir_train_ordinal, # Ordinal Direction Labels (Train)
fwd_ret_train, fwd_ret_val, # Forward Returns (Train/Val)
eps_train, eps_val, # Epsilon (Train/Val, Optional)
y_dir_val_ordinal # Ordinal Direction Labels (Val, Optional)
)

View File

@ -0,0 +1,182 @@
# Stage functions for creating GRU input sequences
import logging
import sys
import numpy as np
import pandas as pd
from typing import Tuple, Dict, Optional, List
import json # Added for saving artefact
# Assuming IOManager is importable from parent src directory
from ..io_manager import IOManager
logger = logging.getLogger(__name__)
def create_sequences_fold(
X_data: pd.DataFrame,
y_data: pd.DataFrame,
target_names: List[str], # e.g., ['mu', 'dir3'] or ['mu', 'dir']
lookback: int,
name: str, # e.g., "Train", "Validation", "Test"
config: dict, # For gru.drop_imputed_sequences
io: Optional[IOManager] # For saving artefact
) -> Tuple[Optional[np.ndarray], Optional[Dict], Optional[pd.Index], int]:
"""
Transforms pruned, scaled feature DataFrame into 3D sequences for GRU input
and extracts corresponding targets for a specific data split (Train/Val/Test).
Handles dropping sequences containing imputed bars based on config.
Args:
X_data (pd.DataFrame): Pruned, scaled features for the split.
y_data (pd.DataFrame): Targets for the split.
target_names (List[str]): List of target column names in y_data (e.g., ['mu', 'dir3']).
lookback (int): Sequence length.
name (str): Name of the split (e.g., "Train", "Validation", "Test") for logging.
config (dict): Pipeline configuration dictionary.
io (Optional[IOManager]): IOManager instance for saving artefacts.
Returns:
Tuple containing:
- X_seq (np.ndarray or None): 3D feature sequences.
- y_seq_dict (Dict or None): Dictionary of target sequences.
- target_indices (pd.Index or None): Timestamps corresponding to the targets.
- dropped_count (int): Number of sequences dropped due to imputed bars.
Returns (None, None, None, 0) if sequence creation fails or data is insufficient.
Raises SystemExit on critical errors (e.g., misalignment).
"""
logger.info(f"--- Creating {name} Sequences ---")
use_ternary = config.get('gru', {}).get('use_ternary', False)
drop_imputed = config.get('gru', {}).get('drop_imputed_sequences', False)
imputed_col_name = 'bar_imputed' # Assuming this is the column name
# --- Input Validation --- #
if X_data is None or y_data is None or X_data.empty or y_data.empty:
logger.error(f"{name}: Missing or empty features/targets for sequence creation.")
return None, None, None, 0
# Check for bar_imputed column
if imputed_col_name not in X_data.columns:
logger.error(f"{name}: Required column '{imputed_col_name}' not found in features. Cannot handle imputed sequences.")
# Decide whether to proceed without it or raise error - raising for now
raise SystemExit(f"{name}: '{imputed_col_name}' column missing. Sequence creation halted.")
# Strict Anti-Leakage Check
try:
assert X_data.index.equals(y_data.index), \
f"{name}: Features and targets indices misaligned!"
except AssertionError as e:
logger.error(f"Data alignment check failed: {e}. Potential data leakage. Aborting.")
raise SystemExit(f"{name}: {e}")
# Check target columns exist
if not all(col in y_data.columns for col in target_names):
missing_targets = set(target_names) - set(y_data.columns)
logger.error(f"{name}: Target columns {missing_targets} not found in y_data. Aborting.")
raise SystemExit(f"{name}: Missing target columns for sequencing.")
# --- End Input Validation --- #
# Convert DataFrames to numpy for potential speedup, keep index access
features_np = X_data.values
imputed_flag_np = X_data[imputed_col_name].values.astype(bool) # Ensure boolean type
# Extract targets based on target_names
targets_dict_np = {name: y_data[name].values for name in target_names}
X_seq_list, y_seq_dict_list = [], {name: [] for name in target_names}
mask_seq_list = [] # To store the imputed flag sequences
target_indices = []
if len(X_data) <= lookback:
logger.warning(f"{name}: DataFrame length ({len(X_data)}) is not greater than lookback ({lookback}). Cannot create sequences.")
return None, None, None, 0
for i in range(lookback, len(features_np)):
# Feature window: [i-lookback, i)
X_seq_list.append(features_np[i - lookback : i])
mask_seq_list.append(imputed_flag_np[i - lookback : i])
# Targets correspond to index i
for t_name in target_names:
target_val = targets_dict_np[t_name][i]
# Special handling for potential list/array type in ternary labels
if use_ternary and 'dir' in t_name and isinstance(target_val, list):
target_val = np.array(target_val, dtype=np.float32)
y_seq_dict_list[t_name].append(target_val)
target_indices.append(y_data.index[i]) # Get index corresponding to target
if not X_seq_list: # Check if any sequences were created
logger.warning(f"{name}: No sequences were generated (length <= lookback?).")
return None, None, None, 0
# Convert lists to numpy arrays
X_seq = np.array(X_seq_list, dtype=np.float32)
mask_seq = np.array(mask_seq_list, dtype=bool)
target_indices_pd = pd.Index(target_indices)
y_seq_dict_np = {}
for t_name in target_names:
try:
# Attempt to stack; requires consistent shapes
if use_ternary and 'dir' in t_name:
y_seq_dict_np[t_name] = np.stack(y_seq_dict_list[t_name]).astype(np.float32)
else: # Assuming other targets are scalar
y_seq_dict_np[t_name] = np.array(y_seq_dict_list[t_name], dtype=np.float32)
except ValueError as e:
logger.error(f"{name}: Error stacking target '{t_name}': {e}. Check target consistency (especially ternary).", exc_info=True)
shapes = [getattr(item, 'shape', type(item)) for item in y_seq_dict_list[t_name]]
from collections import Counter
logger.error(f"Target shapes/types found: {Counter(shapes)}")
raise SystemExit(f"{name}: Inconsistent target shapes for '{t_name}' during sequence creation.") from e
orig_n = X_seq.shape[0]
dropped_count = 0
# Conditionally drop sequences containing imputed bars
if drop_imputed:
logger.info(f"{name}: Dropping sequences containing imputed bars (drop_imputed_sequences=True)...")
valid_mask = ~mask_seq.any(axis=1)
X_seq = X_seq[valid_mask]
mask_seq = mask_seq[valid_mask] # Keep mask aligned, though not explicitly used later
for t_name in target_names:
y_seq_dict_np[t_name] = y_seq_dict_np[t_name][valid_mask]
target_indices_pd = target_indices_pd[valid_mask]
dropped_count = orig_n - X_seq.shape[0]
logger.info(f"{name}: Generated {orig_n} sequences, dropped {dropped_count} containing imputed bars. Remaining: {X_seq.shape[0]}")
# Save summary artifact
if io:
summary_data = {
"split_name": name,
"total_sequences_generated": orig_n,
"sequences_dropped_imputed": dropped_count,
"sequences_remaining": X_seq.shape[0],
"drop_imputed_sequences_config": drop_imputed
}
try:
filename = f"imputed_sequence_summary_{name.lower()}.json"
io.save_json(summary_data, filename, section='results', indent=4)
logger.info(f"Saved imputed sequence summary to results/{filename}")
except Exception as e:
logger.error(f"Failed to save imputed sequence summary for {name}: {e}")
else:
logger.warning(f"IOManager not available, cannot save imputed sequence summary for {name}.")
else:
logger.info(f"{name}: Generated {orig_n} sequences. Keeping sequences with imputed bars (drop_imputed_sequences=False).")
# Final checks
if X_seq.shape[0] == 0:
logger.error(f"{name}: No valid sequences remaining after potential filtering. Aborting.")
return None, None, None, dropped_count # Return 0 count if no sequences left
# --- REMOVE: Final Dictionary Mapping (Let GRU handler manage this) --- #
# final_y_seq_dict = {
# 'mu': y_seq_dict_np['ret'], # Map 'ret' to 'mu'
# 'dir3': y_seq_dict_np['dir3'] # Keep 'dir3' as is
# }
# --- END REMOVE --- #
# Log final shapes
logger.info(f"Sequence shapes created for {name}:")
logger.info(f" X={X_seq.shape}, y_keys={list(y_seq_dict_np.keys())}, indices={len(target_indices_pd)}")
return X_seq, y_seq_dict_np, target_indices_pd, dropped_count

View File

@ -6,6 +6,8 @@ Uses pre-calculated GRU predictions (mu, sigma, p_cal) and actual returns.
import numpy as np
import pandas as pd
import logging
import gymnasium as gym
from omegaconf import DictConfig # Added for config typing
env_logger = logging.getLogger(__name__)
@ -15,6 +17,8 @@ class TradingEnv:
sigma_predictions: np.ndarray,
p_cal_predictions: np.ndarray,
actual_returns: np.ndarray,
bar_imputed_flags: np.ndarray, # Added imputed flags
config: DictConfig, # Added config
initial_capital: float = 10000.0,
transaction_cost: float = 0.0005,
reward_scale: float = 100.0,
@ -27,18 +31,22 @@ class TradingEnv:
sigma_predictions: Predicted volatility (σ̂ = exp(log σ̂)).
p_cal_predictions: Calibrated probability of price increase (p_cal).
actual_returns: Actual log returns (y_ret).
bar_imputed_flags: Boolean array indicating if a bar was imputed.
config: OmegaConf configuration object.
initial_capital: Starting capital for simulation (used notionally in reward).
transaction_cost: Fractional cost per trade.
reward_scale: Multiplier for the reward signal.
action_penalty_lambda: Coefficient for the action magnitude penalty (λ).
"""
assert len(mu_predictions) == len(sigma_predictions) == len(p_cal_predictions) == len(actual_returns), \
"All input arrays must have the same length"
assert len(mu_predictions) == len(sigma_predictions) == len(p_cal_predictions) == len(actual_returns) == len(bar_imputed_flags), \
"All input arrays (predictions, returns, imputed_flags) must have the same length"
self.mu = mu_predictions
self.sigma = sigma_predictions
self.p_cal = p_cal_predictions
self.actual_returns = actual_returns
self.bar_imputed = bar_imputed_flags.astype(bool) # Store imputed flags
self.config = config # Store config
self.initial_capital = initial_capital
self.transaction_cost = transaction_cost
@ -65,20 +73,36 @@ class TradingEnv:
self.state_dim = 5
self.action_dim = 1
# --- Define Gym Spaces ---
self.action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(self.action_dim,), dtype=np.float32)
self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.state_dim,), dtype=np.float32)
# --- End Define Gym Spaces ---
env_logger.info(f"TradingEnv initialized with {self.n_steps} steps.")
def _get_state(self) -> np.ndarray:
"""Construct the state vector for the current step."""
if self.current_step >= self.n_steps:
# Handle episode end - return a dummy state or zeros
return np.zeros(self.state_dim, dtype=np.float32)
mu_t = self.mu[self.current_step]
sigma_t = self.sigma[self.current_step]
p_cal_t = self.p_cal[self.current_step]
# Calculate edge based on p_cal shape (binary vs ternary)
if isinstance(p_cal_t, (np.ndarray, list)) and len(p_cal_t) == 3:
# Ternary: edge = max(P(up), P(down)) - P(flat)
# Assuming order [Down, Flat, Up] for p_cal_t
edge_t = max(p_cal_t[2], p_cal_t[0]) - p_cal_t[1]
elif isinstance(p_cal_t, (float, np.number)):
# Binary: edge = 2 * P(up) - 1
edge_t = 2 * p_cal_t - 1
z_score_t = np.abs(mu_t) / (sigma_t + 1e-9)
else:
env_logger.error(f"Unexpected type/shape for p_cal_t at step {self.current_step}: {p_cal_t}. Using edge=0.")
edge_t = 0.0
_EPS = 1e-9 # Define epsilon locally
z_score_t = np.abs(mu_t) / (sigma_t + _EPS)
# State uses position *before* the action for this step is taken
state = np.array([
@ -108,11 +132,48 @@ class TradingEnv:
Returns:
tuple: (next_state, reward, done, info_dict)
"""
info = {'capital': self.current_capital, 'position': self.current_position, 'is_imputed_step_skipped': False}
if self.current_step >= self.n_steps:
# Should not happen if 'done' is handled correctly, but as safeguard
env_logger.warning("Step called after environment finished.")
return self._get_state(), 0.0, True, {}
return self._get_state(), 0.0, True, info
# --- Handle Imputed Bar --- #
imputed = self.bar_imputed[self.current_step]
if imputed:
mode = self.config.sac.imputed_handling
env_logger.debug(f"SAC step {self.current_step} on imputed bar: handling={mode}")
if mode == "skip":
self.current_step += 1
next_state = self._get_state() # Get state for the *next* actual step
# Return 0 reward, not done, but indicate skip for buffer handling
info['is_imputed_step_skipped'] = True
return next_state, 0.0, False, info
elif mode == "hold":
# Action is forced to maintain current position
action = self.current_position
elif mode == "penalty":
# Calculate reward penalty based on config
target_position_penalty = np.clip(action, -1.0, 1.0)
reward = -self.config.sac.action_penalty * (target_position_penalty - self.current_position)**2
# Update position based on agent's intended action (clipped)
self.current_position = target_position_penalty
# Update capital notionally (no actual return, only cost if implemented)
# Cost is implicitly 0 here as there's no trade size if pos doesn't change
# If penalty mode allowed position change, cost would apply.
# For simplicity, we don't add cost here for the penalty step.
self.current_step += 1
next_state = self._get_state()
scaled_reward = reward * self.reward_scale # Scale the penalty
done = self.current_step >= self.n_steps
info['capital'] = self.current_capital
info['position'] = self.current_position
return next_state, scaled_reward, done, info
# else: default behavior (treat as normal bar) - implicitly handled by falling through
# --- End Handle Imputed Bar --- #
# --- Normal Step Logic (if not imputed or handling mode allows fallthrough like 'hold') --- #
# Action is the TARGET position for the *end* of this step
target_position = np.clip(action, -1.0, 1.0)
trade_size = target_position - self.current_position
@ -150,7 +211,9 @@ class TradingEnv:
done = self.current_step >= self.n_steps or self.current_capital <= 0
next_state = self._get_state()
info = {'capital': self.current_capital, 'position': self.current_position}
# Update info dict (capital/position might have changed in normal step)
info['capital'] = self.current_capital
info['position'] = self.current_position
# Log step details periodically
# if self.current_step % 1000 == 0:

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,186 @@
import pytest
import pandas as pd
import numpy as np
from omegaconf import OmegaConf
from unittest.mock import MagicMock
import os
import tempfile
import json
# Adjust the import path based on your project structure
from gru_sac_predictor.src.pipeline_stages.sequence_creation import create_sequences_fold
from gru_sac_predictor.src.io_manager import IOManager # Adjust path if needed
# --- Test Fixtures ---
@pytest.fixture
def sample_data_with_imputed():
"""Creates sample X and y dataframes with a 'bar_imputed' column."""
dates = pd.to_datetime(pd.date_range('2023-01-01', periods=20, freq='T'))
lookback = 5
n_features_orig = 3
n_samples = len(dates)
# Features (including bar_imputed)
X_data = pd.DataFrame(
np.random.randn(n_samples, n_features_orig),
index=dates,
columns=[f'feat_{i}' for i in range(n_features_orig)]
)
# Add bar_imputed column - mark some bars as imputed
imputed_flags = np.zeros(n_samples, dtype=bool)
imputed_flags[2] = True # Imputed within first potential sequence
imputed_flags[8] = True # Imputed within a later potential sequence
imputed_flags[15] = True # Imputed near the end
X_data['bar_imputed'] = imputed_flags
# Targets (mu and dir3)
y_data = pd.DataFrame({
'mu': np.random.randn(n_samples),
'dir3': [list(row) for row in np.eye(3)[np.random.randint(0, 3, n_samples)]] # Example one-hot
}, index=dates)
return X_data, y_data
@pytest.fixture
def base_config():
"""Creates a base OmegaConf config for testing sequence creation."""
conf = OmegaConf.create({
'gru': {
'lookback': 5,
'use_ternary': True, # Matches sample_data_with_imputed
'drop_imputed_sequences': True # Default to True for testing dropping
},
# Add other necessary sections if needed
})
return conf
@pytest.fixture
def mock_io_manager():
"""Creates a mock IOManager for testing artefact saving."""
with tempfile.TemporaryDirectory() as tmpdir:
mock_io = MagicMock(spec=IOManager)
mock_io.results_dir = tmpdir
saved_jsons = {}
def mock_save_json(data, filename, **kwargs):
filepath = os.path.join(tmpdir, filename)
saved_jsons[filename] = data
with open(filepath, 'w') as f:
json.dump(data, f, **kwargs)
mock_io.save_json.side_effect = mock_save_json
mock_io.get_artifact_path.side_effect = lambda filename: os.path.join(tmpdir, filename)
mock_io._saved_jsons = saved_jsons
yield mock_io
# --- Test Functions ---
def test_sequence_creation_shapes(sample_data_with_imputed, base_config, mock_io_manager):
X_data, y_data = sample_data_with_imputed
lookback = base_config.gru.lookback
n_features = X_data.shape[1]
n_samples = len(X_data)
expected_n_seq = n_samples - lookback
# Test without dropping imputed
cfg_no_drop = base_config.copy()
cfg_no_drop.gru.drop_imputed_sequences = False
X_seq, y_seq_dict, indices, dropped_count = create_sequences_fold(
X_data=X_data, y_data=y_data, target_names=['mu', 'dir3'],
lookback=lookback, name="TestSplit", config=cfg_no_drop, io=mock_io_manager
)
assert dropped_count == 0
assert X_seq is not None
assert y_seq_dict is not None
assert indices is not None
assert X_seq.shape == (expected_n_seq, lookback, n_features)
assert 'mu' in y_seq_dict and y_seq_dict['mu'].shape == (expected_n_seq,)
assert 'dir3' in y_seq_dict and y_seq_dict['dir3'].shape == (expected_n_seq, 3)
assert len(indices) == expected_n_seq
# Check first target index corresponds to lookback-th original index
assert indices[0] == X_data.index[lookback]
# Check last target index corresponds to last original index
assert indices[-1] == X_data.index[-1]
def test_sequence_dropping_imputed(sample_data_with_imputed, base_config, mock_io_manager):
X_data, y_data = sample_data_with_imputed
lookback = base_config.gru.lookback
n_samples = len(X_data)
expected_n_seq_orig = n_samples - lookback
# Config with dropping enabled (default in fixture)
cfg_drop = base_config
X_seq, y_seq_dict, indices, dropped_count = create_sequences_fold(
X_data=X_data.copy(), y_data=y_data.copy(), target_names=['mu', 'dir3'],
lookback=lookback, name="TestDrop", config=cfg_drop, io=mock_io_manager
)
assert X_seq is not None
assert y_seq_dict is not None
assert indices is not None
# Determine which original sequences should have been dropped
# A sequence starting at index i uses data from [i, i+lookback-1]
# The target corresponds to index i+lookback
# We need to check the imputed flag in the range [i, i+lookback-1] for each potential sequence target index i+lookback
# Original target indices range from index `lookback` to `n_samples - 1`
should_be_dropped_mask = np.zeros(expected_n_seq_orig, dtype=bool)
imputed_flags_np = X_data['bar_imputed'].values
for seq_idx in range(expected_n_seq_orig):
# The features for this sequence are from original indices [seq_idx, seq_idx + lookback - 1]
feature_indices_range = slice(seq_idx, seq_idx + lookback)
if np.any(imputed_flags_np[feature_indices_range]):
should_be_dropped_mask[seq_idx] = True
expected_dropped_count = np.sum(should_be_dropped_mask)
expected_remaining_count = expected_n_seq_orig - expected_dropped_count
assert dropped_count == expected_dropped_count
assert X_seq.shape[0] == expected_remaining_count
assert y_seq_dict['mu'].shape[0] == expected_remaining_count
assert y_seq_dict['dir3'].shape[0] == expected_remaining_count
assert len(indices) == expected_remaining_count
# Check that the remaining indices are correct (weren't marked for dropping)
original_indices = X_data.index[lookback:]
expected_remaining_indices = original_indices[~should_be_dropped_mask]
pd.testing.assert_index_equal(indices, expected_remaining_indices)
# Check artifact saving
assert 'imputed_sequence_summary_testdrop.json' in mock_io_manager._saved_jsons
report_data = mock_io_manager._saved_jsons['imputed_sequence_summary_testdrop.json']
assert report_data['total_sequences_generated'] == expected_n_seq_orig
assert report_data['sequences_dropped_imputed'] == expected_dropped_count
assert report_data['sequences_remaining'] == expected_remaining_count
def test_sequence_creation_no_imputed_col(sample_data_with_imputed, base_config, mock_io_manager):
X_data, y_data = sample_data_with_imputed
X_data_no_imputed = X_data.drop(columns=['bar_imputed'])
lookback = base_config.gru.lookback
with pytest.raises(SystemExit) as excinfo:
create_sequences_fold(
X_data=X_data_no_imputed, y_data=y_data, target_names=['mu', 'dir3'],
lookback=lookback, name="TestNoImputedCol", config=base_config, io=mock_io_manager
)
assert "'bar_imputed' column missing" in str(excinfo.value)
def test_sequence_creation_insufficient_data(sample_data_with_imputed, base_config, mock_io_manager):
X_data, y_data = sample_data_with_imputed
lookback = base_config.gru.lookback
# Create data shorter than lookback
X_short = X_data.iloc[:lookback-1]
y_short = y_data.iloc[:lookback-1]
X_seq, y_seq_dict, indices, dropped_count = create_sequences_fold(
X_data=X_short, y_data=y_short, target_names=['mu', 'dir3'],
lookback=lookback, name="TestShort", config=base_config, io=mock_io_manager
)
assert X_seq is None
assert y_seq_dict is None
assert indices is None
assert dropped_count == 0

View File

@ -0,0 +1,166 @@
import pytest
import numpy as np
from omegaconf import OmegaConf
# Adjust import path based on structure
from gru_sac_predictor.src.trading_env import TradingEnv
# --- Test Fixtures ---
@pytest.fixture
def sample_env_data():
"""Provides sample data for initializing the TradingEnv."""
n_steps = 10
data = {
'mu_predictions': np.random.randn(n_steps) * 0.001,
'sigma_predictions': np.abs(np.random.randn(n_steps) * 0.002 + 0.005),
'p_cal_predictions': np.random.rand(n_steps),
'actual_returns': np.random.randn(n_steps) * 0.0015,
'bar_imputed_flags': np.array([False, False, True, False, True, True, False, False, True, False], dtype=bool)
}
return data
@pytest.fixture
def base_env_config():
"""Base configuration for the environment."""
return OmegaConf.create({
'sac': {
'imputed_handling': 'skip', # Default test mode
'action_penalty': 0.05
},
'environment': {
'initial_capital': 10000.0,
'transaction_cost': 0.0005,
'reward_scale': 100.0,
'action_penalty_lambda': 0.0 # Usually overridden by transaction_cost calc
}
})
@pytest.fixture
def trading_env_instance(sample_env_data, base_env_config):
"""Creates a TradingEnv instance with default 'skip' mode."""
return TradingEnv(**sample_env_data, config=base_env_config)
# --- Test Functions ---
def test_env_initialization(trading_env_instance, sample_env_data):
assert trading_env_instance.n_steps == len(sample_env_data['actual_returns'])
assert trading_env_instance.current_step == 0
assert trading_env_instance.current_position == 0.0
assert np.array_equal(trading_env_instance.bar_imputed, sample_env_data['bar_imputed_flags'])
def test_env_reset(trading_env_instance):
# Take a few steps
trading_env_instance.step(0.5)
trading_env_instance.step(-0.2)
assert trading_env_instance.current_step > 0
# Reset
initial_state = trading_env_instance.reset()
assert trading_env_instance.current_step == 0
assert trading_env_instance.current_position == 0.0
assert initial_state is not None
assert initial_state.shape == (trading_env_instance.state_dim,)
def test_env_step_normal(trading_env_instance):
# Test a normal step (step 0 is not imputed)
initial_pos = trading_env_instance.current_position
action = 0.7
next_state, reward, done, info = trading_env_instance.step(action)
assert trading_env_instance.current_step == 1
assert trading_env_instance.current_position == action # Position updates to action
assert not info['is_imputed_step_skipped']
assert not done
assert next_state is not None
# Reward calculation is complex, just check type/sign if needed
assert isinstance(reward, float)
def test_env_step_imputed_skip(trading_env_instance, sample_env_data):
# Step 2 is imputed in sample_env_data
trading_env_instance.step(0.5) # Step 0
trading_env_instance.step(0.6) # Step 1
assert trading_env_instance.current_step == 2
initial_pos_before_imputed = trading_env_instance.current_position
# Action for the imputed step (should be ignored by 'skip')
action_imputed = 0.9
next_state, reward, done, info = trading_env_instance.step(action_imputed)
# Should skip step 2 and now be at step 3
assert trading_env_instance.current_step == 3
# Position should NOT have changed from step 1
assert trading_env_instance.current_position == initial_pos_before_imputed
assert reward == 0.0 # Skip gives 0 reward
assert not done
assert info['is_imputed_step_skipped'] == True # Crucial check for buffer
# Check that the returned state is for step 3
expected_state_step_3 = trading_env_instance._get_state() # Get state now that we are at step 3
np.testing.assert_array_almost_equal(next_state, expected_state_step_3)
def test_env_step_imputed_hold(sample_env_data, base_env_config):
cfg = base_env_config.copy()
cfg.sac.imputed_handling = 'hold'
env = TradingEnv(**sample_env_data, config=cfg)
# Step 2 is imputed
env.step(0.5) # Step 0
env.step(0.6) # Step 1
assert env.current_step == 2
position_before_imputed = env.current_position
# Action for the imputed step (should be overridden by 'hold')
action_imputed = -0.5
next_state, reward, done, info = env.step(action_imputed)
# Should process step 2 and move to step 3
assert env.current_step == 3
# Position should be the same as before the step
assert env.current_position == position_before_imputed
assert not info['is_imputed_step_skipped']
assert not done
# Reward should be calculated based on holding the position
expected_pnl = position_before_imputed * (np.exp(sample_env_data['actual_returns'][2]) - 1)
expected_cost = 0 # No trade size if holding
expected_penalty = 0 # No penalty in hold mode
expected_raw_reward = expected_pnl - expected_cost - expected_penalty
expected_scaled_reward = expected_raw_reward * cfg.environment.reward_scale
assert np.isclose(reward, expected_scaled_reward)
def test_env_step_imputed_penalty(sample_env_data, base_env_config):
cfg = base_env_config.copy()
cfg.sac.imputed_handling = 'penalty'
cfg.sac.action_penalty = 0.1 # Use a specific penalty for testing
env = TradingEnv(**sample_env_data, config=cfg)
# Step 2 is imputed
env.step(0.5) # Step 0
env.step(0.6) # Step 1
assert env.current_step == 2
position_before_imputed = env.current_position # Should be 0.6
# Action for the imputed step
action_imputed = -0.2
next_state, reward, done, info = env.step(action_imputed)
# Should process step 2 and move to step 3
assert env.current_step == 3
# Position should update to the *agent's* action
assert env.current_position == np.clip(action_imputed, -1.0, 1.0)
assert not info['is_imputed_step_skipped']
assert not done
# Reward calculation is ONLY the penalty
expected_raw_reward = -cfg.sac.action_penalty * (action_imputed - position_before_imputed)**2
expected_scaled_reward = expected_raw_reward * cfg.environment.reward_scale
assert np.isclose(reward, expected_scaled_reward)
def test_env_done_condition(trading_env_instance, sample_env_data):
n_steps = len(sample_env_data['actual_returns'])
# Step through the environment
done = False
for i in range(n_steps):
_, _, done, _ = trading_env_instance.step(np.random.uniform(-1, 1))
if i < n_steps - 1:
assert not done
else:
assert done # Should be done on the last step