Compare commits
4 Commits
a86cdb2c8f
...
c3526bb9f6
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c3526bb9f6 | ||
|
|
4b1b542430 | ||
|
|
9d6641a5f2 | ||
|
|
30f2fd2d1f |
711
gru_sac_predictor/src/pipeline_stages/data_processing.py
Normal file
711
gru_sac_predictor/src/pipeline_stages/data_processing.py
Normal file
@ -0,0 +1,711 @@
|
||||
# Stage functions for loading, initial preprocessing, feature engineering, label generation, and splitting
|
||||
import logging
|
||||
import sys # Added for sys.exit
|
||||
from datetime import datetime, timezone # Added for datetime
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from typing import Tuple, Optional, Any, List, Dict # Added List and Dict
|
||||
import matplotlib.pyplot as plt # Added for plotting
|
||||
import seaborn as sns # Added for plotting
|
||||
|
||||
# --- Component Imports --- #
|
||||
# Assuming DataLoader is in the parent directory's src
|
||||
# This might need adjustment based on actual project structure
|
||||
# Using relative import assuming pipeline_stages is sibling to other src modules
|
||||
from ..data_loader import DataLoader, fill_missing_bars
|
||||
from ..feature_engineer import FeatureEngineer # Added FeatureEngineer import
|
||||
from ..io_manager import IOManager # Added IOManager import
|
||||
from ..metrics import calculate_sharpe_ratio # For potential baseline comparison
|
||||
|
||||
# --- Local Imports --- #
|
||||
# Import the label generation function we moved here
|
||||
# Removed duplicate import: from .data_processing import generate_direction_labels
|
||||
|
||||
# Assuming tensorflow is installed and available
|
||||
try:
|
||||
from tensorflow.keras.utils import to_categorical
|
||||
except ImportError:
|
||||
logging.warning("TensorFlow/Keras not found. Ternary label one-hot encoding will fail.")
|
||||
# Define a placeholder if keras is not available
|
||||
def to_categorical(*args, **kwargs):
|
||||
raise NotImplementedError("Keras 'to_categorical' is unavailable.")
|
||||
|
||||
logger = logging.getLogger(__name__) # Use module-level logger
|
||||
|
||||
# --- Refactored Label Generation Logic (Moved from trading_pipeline.py) --- #
|
||||
def generate_direction_labels(df: pd.DataFrame, config: dict) -> Tuple[pd.DataFrame, str, pd.Series, Optional[pd.Series]]:
|
||||
"""
|
||||
Calculates forward returns and generates binary, soft binary, or ternary direction labels.
|
||||
Also returns the raw forward returns and the epsilon series used for ternary flat definition.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): DataFrame containing at least a 'close' column and DatetimeIndex.
|
||||
config (dict): Pipeline configuration dictionary, expecting keys under 'gru' and 'data'.
|
||||
|
||||
Returns:
|
||||
tuple[pd.DataFrame, str, pd.Series, Optional[pd.Series]]:
|
||||
- DataFrame with added forward return and direction label columns (and NaNs dropped based on labels).
|
||||
- Name of the generated direction label column.
|
||||
- Series containing the calculated forward log returns (`fwd_log_ret`).
|
||||
- Series containing the calculated epsilon (`eps`) threshold if ternary, else None.
|
||||
"""
|
||||
if 'close' not in df.columns:
|
||||
raise ValueError("'close' column missing in input DataFrame for label generation.")
|
||||
|
||||
gru_cfg = config.get('gru', {})
|
||||
data_cfg = config.get('data', {})
|
||||
horizon = gru_cfg.get('prediction_horizon', 5)
|
||||
use_ternary = gru_cfg.get('use_ternary', False) # Check if ternary flag is set
|
||||
|
||||
target_ret_col = f'fwd_log_ret_{horizon}'
|
||||
eps_series: Optional[pd.Series] = None # Initialize eps
|
||||
|
||||
# --- Calculate Forward Log Return --- #
|
||||
shifted_close = df['close'].shift(-horizon)
|
||||
fwd_returns = np.log(shifted_close / df['close'])
|
||||
df[target_ret_col] = fwd_returns
|
||||
|
||||
# --- Generate Direction Label (Binary/Soft or Ternary) --- #
|
||||
if use_ternary:
|
||||
k = gru_cfg.get('flat_sigma_multiplier', 0.25)
|
||||
target_dir_col = f'direction_label3_{horizon}'
|
||||
logger.info(f"Generating ternary labels ({target_dir_col}) with k={k}...")
|
||||
|
||||
sigma_n = fwd_returns.rolling(window=horizon, min_periods=max(1, horizon//2)).std()
|
||||
eps = k * sigma_n
|
||||
eps_series = eps # Store the calculated eps series
|
||||
|
||||
conditions = [fwd_returns > eps, fwd_returns < -eps]
|
||||
choices = [2, 0] # 2=up, 0=down
|
||||
ordinal_labels = np.select(conditions, choices, default=1).astype(int) # 1=flat
|
||||
|
||||
# --- Log Distribution & Check Balance --- #
|
||||
df['_ordinal_label_temp'] = ordinal_labels
|
||||
valid_mask_for_dist = ~np.isnan(eps) & ~np.isnan(fwd_returns)
|
||||
ordinal_labels_valid = df.loc[valid_mask_for_dist, '_ordinal_label_temp']
|
||||
|
||||
if not ordinal_labels_valid.empty:
|
||||
counts = np.bincount(ordinal_labels_valid, minlength=3)
|
||||
total_valid = len(ordinal_labels_valid)
|
||||
if total_valid > 0: # Avoid division by zero
|
||||
dist_pct = counts / total_valid * 100
|
||||
log_msg = (f"Label dist (n={total_valid}): "
|
||||
f"Down(0)={dist_pct[0]:.1f}%, Flat(1)={dist_pct[1]:.1f}%, Up(2)={dist_pct[2]:.1f}%")
|
||||
logger.info(log_msg)
|
||||
|
||||
min_pct_threshold = 10.0 # As per implementation
|
||||
if any(p < min_pct_threshold for p in dist_pct):
|
||||
error_msg = f"Label imbalance detected! Min class percentage is {np.min(dist_pct):.1f}% (Threshold: {min_pct_threshold}%). Check data or flat_sigma_multiplier (k={k})."
|
||||
logger.error(error_msg)
|
||||
print(f"ERROR: {error_msg}") # Also print for visibility
|
||||
else:
|
||||
logger.warning("Label distribution check skipped: total valid labels is zero.")
|
||||
else:
|
||||
logger.warning("Could not calculate label distribution (no valid sigma or returns).")
|
||||
# --- End Distribution Check --- #
|
||||
|
||||
# --- One-hot encode --- #
|
||||
try:
|
||||
y_cat_full = np.full((len(df), 3), np.nan, dtype=np.float32)
|
||||
if not ordinal_labels_valid.empty:
|
||||
y_cat_valid = to_categorical(ordinal_labels_valid, num_classes=3)
|
||||
y_cat_full[valid_mask_for_dist] = y_cat_valid.astype(np.float32)
|
||||
else:
|
||||
logger.warning("No valid ordinal labels to one-hot encode.")
|
||||
|
||||
# Assign the list of arrays (or NaNs) - using list avoids mixed type issues later
|
||||
df[target_dir_col] = [list(row) if not np.all(np.isnan(row)) else np.nan for row in y_cat_full]
|
||||
|
||||
except NotImplementedError as nie:
|
||||
logger.error(f"Ternary label generation failed: {nie}. Keras 'to_categorical' is unavailable. Please install tensorflow.", exc_info=True)
|
||||
raise # Re-raise exception to halt pipeline
|
||||
except Exception as e:
|
||||
logger.error(f"Error during one-hot encoding: {e}", exc_info=True)
|
||||
raise # Re-raise exception to halt pipeline if encoding fails
|
||||
finally:
|
||||
if '_ordinal_label_temp' in df.columns:
|
||||
df.drop(columns=['_ordinal_label_temp'], inplace=True)
|
||||
# --- End One-hot Encoding --- #
|
||||
|
||||
else: # Binary / Soft Binary
|
||||
target_dir_col = f'direction_label_{horizon}'
|
||||
label_smoothing = data_cfg.get('label_smoothing', 0.0)
|
||||
if not (0.0 <= label_smoothing < 1.0):
|
||||
logger.warning(f"Invalid label_smoothing value ({label_smoothing}). Must be in [0.0, 1.0). Disabling smoothing.")
|
||||
label_smoothing = 0.0
|
||||
|
||||
if label_smoothing > 0.0:
|
||||
high_label = 1.0 - label_smoothing / 2.0
|
||||
low_label = label_smoothing / 2.0
|
||||
logger.info(f"Applying label smoothing: {label_smoothing:.2f} -> labels [{low_label:.2f}, {high_label:.2f}] for {target_dir_col}")
|
||||
df[target_dir_col] = np.where(fwd_returns > 0, high_label, low_label).astype(np.float32)
|
||||
else:
|
||||
logger.info(f"Using hard binary labels (0.0 / 1.0) for {target_dir_col}")
|
||||
df[target_dir_col] = (fwd_returns > 0).astype(np.float32)
|
||||
|
||||
# --- Drop Rows with NaN Targets --- #
|
||||
initial_rows = len(df)
|
||||
|
||||
# Create mask for NaNs in the direction column
|
||||
if use_ternary:
|
||||
# Check if elements are np.nan (since we assign np.nan for rows with no valid labels)
|
||||
nan_mask_dir = df[target_dir_col].isna()
|
||||
else:
|
||||
nan_mask_dir = df[target_dir_col].isna()
|
||||
|
||||
nan_mask_combined = df[target_ret_col].isna() | nan_mask_dir
|
||||
|
||||
df_clean = df[~nan_mask_combined].copy()
|
||||
|
||||
final_rows = len(df_clean)
|
||||
if final_rows < initial_rows:
|
||||
logger.info(f"Dropped {initial_rows - final_rows} rows due to NaN targets (horizon={horizon}).")
|
||||
|
||||
if df_clean.empty:
|
||||
logger.error("DataFrame is empty after defining labels and dropping NaNs. Exiting.")
|
||||
# Returning empty DataFrame, caller should handle exit
|
||||
return pd.DataFrame(), target_dir_col, pd.Series(dtype=float), None # Return empty series/None on failure
|
||||
|
||||
# Return the cleaned df, target col name, and the *original* full fwd_returns and eps series
|
||||
# Need to return the original series aligned with the original df index *before* cleaning
|
||||
# So the caller can align them with the features *after* cleaning df_clean
|
||||
return df_clean, target_dir_col, fwd_returns, eps_series
|
||||
# --- End Label Generation --- #
|
||||
|
||||
# --- Stage 1: Load and Preprocess Data (Moved from TradingPipeline.load_and_preprocess_data) --- #
|
||||
def load_and_preprocess(
|
||||
data_loader: DataLoader,
|
||||
io: Optional[IOManager],
|
||||
run_id: str,
|
||||
config: Dict[str, Any]
|
||||
) -> Tuple[Optional[pd.DataFrame], Optional[Dict[str, Any]]]:
|
||||
"""
|
||||
Loads the full raw dataset using DataLoader and performs initial checks.
|
||||
|
||||
Args:
|
||||
data_loader: Initialized DataLoader instance.
|
||||
io: IOManager instance (optional).
|
||||
run_id: Current run ID.
|
||||
config: Pipeline configuration dictionary.
|
||||
|
||||
Returns:
|
||||
Tuple containing:
|
||||
- DataFrame with raw loaded data, or None on failure.
|
||||
- Dictionary summarizing the loading process, or None on failure.
|
||||
"""
|
||||
logger.info("--- Stage: Loading and Preprocessing Data ---")
|
||||
data_cfg = config.get('data', {})
|
||||
|
||||
# --- Extract necessary parameters from config --- #
|
||||
ticker = data_cfg.get('ticker')
|
||||
exchange = data_cfg.get('exchange')
|
||||
start_date = data_cfg.get('start_date')
|
||||
end_date = data_cfg.get('end_date')
|
||||
interval = data_cfg.get('interval', '1min') # Default to 1min
|
||||
vol_sampling = data_cfg.get('volatility_sampling', {}).get('enabled', False)
|
||||
vol_window = data_cfg.get('volatility_sampling', {}).get('window', 30)
|
||||
vol_quantile = data_cfg.get('volatility_sampling', {}).get('quantile', 0.5)
|
||||
|
||||
# Validate required parameters
|
||||
if not all([ticker, exchange, start_date, end_date]):
|
||||
logger.error("Missing required data parameters in config: ticker, exchange, start_date, end_date")
|
||||
return None, None
|
||||
# --- End Parameter Extraction --- #
|
||||
|
||||
load_summary = {
|
||||
'ticker': ticker,
|
||||
'exchange': exchange,
|
||||
'start_date_req': start_date,
|
||||
'end_date_req': end_date,
|
||||
'interval_req': interval,
|
||||
'vol_sampling_enabled': vol_sampling,
|
||||
'vol_window': vol_window,
|
||||
'vol_quantile': vol_quantile,
|
||||
}
|
||||
|
||||
try:
|
||||
logger.info(f"Loading data for {ticker} ({exchange}) from {start_date} to {end_date}, interval {interval}")
|
||||
# --- Pass extracted parameters to load_data --- #
|
||||
df_raw = data_loader.load_data(
|
||||
ticker=ticker,
|
||||
exchange=exchange,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
interval=interval,
|
||||
vol_sampling=vol_sampling,
|
||||
vol_window=vol_window,
|
||||
vol_quantile=vol_quantile
|
||||
)
|
||||
# --- End Pass Parameters --- #
|
||||
|
||||
if df_raw is None or df_raw.empty:
|
||||
logger.error("Data loading returned empty DataFrame or failed.")
|
||||
return None, load_summary
|
||||
|
||||
# --- Fill Missing Bars (Step 2.5 from prompts/missing_data.txt) --- #
|
||||
if io is None:
|
||||
logger.error("IOManager is required for fill_missing_bars reporting. Cannot proceed.")
|
||||
return None, load_summary
|
||||
try:
|
||||
df_filled = fill_missing_bars(df_raw, config, io, logger)
|
||||
if df_filled is None or df_filled.empty:
|
||||
logger.error("fill_missing_bars returned empty DataFrame or failed.")
|
||||
return None, load_summary
|
||||
df_raw = df_filled # Replace df_raw with the filled version
|
||||
logger.info("Missing bars handled successfully.")
|
||||
except ValueError as e:
|
||||
logger.error(f"Error during missing bar handling: {e}. Halting processing.")
|
||||
return None, load_summary
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error during missing bar handling: {e}. Halting processing.", exc_info=True)
|
||||
return None, load_summary
|
||||
# --- End Fill Missing Bars --- #
|
||||
|
||||
# Calculate memory usage and log info
|
||||
mem_usage = df_raw.memory_usage(deep=True).sum() / (1024**2)
|
||||
if load_summary:
|
||||
logger.info(f"Data loading summary: {load_summary}")
|
||||
else:
|
||||
logger.warning("No load summary returned by DataLoader.")
|
||||
logger.info(f"Loaded data: {df_raw.shape[0]} rows, {df_raw.shape[1]} columns. Memory: {mem_usage:.2f} MB")
|
||||
logger.info(f"Time range: {df_raw.index.min()} to {df_raw.index.max()}")
|
||||
|
||||
# --- V3 Output Contract: Stage 1 Artifacts --- #
|
||||
if io:
|
||||
if load_summary:
|
||||
save_summary = load_summary.copy() # Don't modify original
|
||||
save_summary['run_id'] = run_id
|
||||
save_summary['timestamp_utc'] = datetime.now(timezone.utc).isoformat()
|
||||
# TODO: Finalize summary content (add counts, NaN info etc.)
|
||||
logger.info("Saving preprocess summary...")
|
||||
io.save_json(save_summary, "preprocess_summary", use_txt=True) # Spec wants .txt
|
||||
|
||||
# Save head of preprocessed data
|
||||
if df_raw is not None and not df_raw.empty:
|
||||
logger.info("Saving head of preprocessed data (first 20 rows)...")
|
||||
io.save_df(df_raw.head(20), "head_preprocessed")
|
||||
else:
|
||||
logger.warning("Skipping saving head_preprocessed: DataFrame is empty or None.")
|
||||
|
||||
else:
|
||||
logger.warning("IOManager not available, skipping saving of Stage 1 artifacts (preprocess_summary, head_preprocessed).")
|
||||
# --- End V3 Output Contract ---
|
||||
|
||||
# --- V3 Output Contract: Stage 2 Artifact (Label Histogram) --- #
|
||||
# TODO: Move this plotting logic to evaluation stage or after split, needs y_train.
|
||||
# if io and config.get('control', {}).get('generate_plots', True):
|
||||
# logger.info("Generating training label distribution histogram... [SKIPPED IN CURRENT STAGE]")
|
||||
# ... (Original plotting code removed from here)
|
||||
# --- End V3 Output Contract ---
|
||||
|
||||
return df_raw, load_summary
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during data loading: {e}", exc_info=True)
|
||||
return None, None
|
||||
|
||||
# --- Stage 2: Engineer Features (Moved from TradingPipeline.engineer_features) --- #
|
||||
def engineer_features_for_fold(
|
||||
df: pd.DataFrame,
|
||||
feature_engineer: FeatureEngineer,
|
||||
io: Optional[IOManager], # Added IOManager for saving figure
|
||||
config: Dict[str, Any], # Added config for plot settings
|
||||
target_col: Optional[str] = None # Added target column name for sorting correlation
|
||||
) -> pd.DataFrame:
|
||||
"""Adds features using FeatureEngineer, handles NaNs, and saves correlation heatmap for a fold.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): Input DataFrame for the fold (typically raw data).
|
||||
feature_engineer (FeatureEngineer): Initialized FeatureEngineer instance.
|
||||
io (Optional[IOManager]): IOManager instance for saving artifacts.
|
||||
config (Dict[str, Any]): Pipeline configuration dictionary.
|
||||
target_col (Optional[str]): Name of the target column to sort correlations by (e.g., 'fwd_log_ret_5').
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: DataFrame with engineered features, NaNs dropped.
|
||||
Returns an empty DataFrame if input is empty or result is empty.
|
||||
"""
|
||||
logger.info("--- Stage: Engineering Features --- ")
|
||||
if df is None or df.empty:
|
||||
logger.error("Input DataFrame is empty. Cannot engineer features.")
|
||||
return pd.DataFrame() # Return empty DataFrame to indicate failure
|
||||
|
||||
if feature_engineer is None:
|
||||
logger.error("FeatureEngineer not initialized. Cannot engineer features.")
|
||||
# Or raise an error? For now return empty
|
||||
return pd.DataFrame()
|
||||
|
||||
# Add base features (cyclical, imbalance, TA)
|
||||
df_engineered = feature_engineer.add_base_features(df.copy())
|
||||
|
||||
# --- V3 Output Contract: Feature Correlation Heatmap --- #
|
||||
# Generate heatmap *before* dropping NaNs to capture full feature set correlations
|
||||
# if io and config.get('control', {}).get('generate_plots', True): # Check if plotting is enabled
|
||||
if io: # Assume generate_plots is implicitly true if io is provided
|
||||
try:
|
||||
logger.info("Generating feature correlation heatmap...")
|
||||
numeric_cols = df_engineered.select_dtypes(include=np.number).columns
|
||||
if len(numeric_cols) < 2:
|
||||
logger.warning("Skipping correlation heatmap: Less than 2 numeric columns found.")
|
||||
else:
|
||||
corr_matrix = df_engineered[numeric_cols].corr(method='pearson')
|
||||
|
||||
# Get plot settings from config
|
||||
output_cfg = config.get('output', {})
|
||||
fig_size = output_cfg.get('figure_size', [16, 9])
|
||||
plot_style = output_cfg.get('plot_style', 'seaborn-v0_8-darkgrid')
|
||||
annot_threshold = output_cfg.get('corr_annot_threshold', 0.5)
|
||||
plot_footer = output_cfg.get('plot_footer', "© GRU-SAC v3")
|
||||
|
||||
plt.style.use(plot_style)
|
||||
fig, ax = plt.subplots(figsize=fig_size)
|
||||
|
||||
sort_features = False
|
||||
if target_col and target_col in corr_matrix.columns:
|
||||
# Sort by absolute correlation with the target
|
||||
target_corr = corr_matrix[target_col].abs().sort_values(ascending=False)
|
||||
sorted_cols = target_corr.index.tolist()
|
||||
corr_matrix_sorted = corr_matrix.loc[sorted_cols, sorted_cols]
|
||||
sort_features = True
|
||||
else:
|
||||
if target_col:
|
||||
logger.warning(f"Target column '{target_col}' not found in correlation matrix. Heatmap will not be sorted by target correlation.")
|
||||
corr_matrix_sorted = corr_matrix # Use original matrix if no target or not found
|
||||
|
||||
sns.heatmap(
|
||||
corr_matrix_sorted,
|
||||
annot=False, # Annotations can be messy; spec only requires > threshold
|
||||
cmap='coolwarm', # Diverging palette centered at 0
|
||||
center=0,
|
||||
linewidths=0.5,
|
||||
cbar=True,
|
||||
square=True, # Ensure square cells
|
||||
ax=ax
|
||||
)
|
||||
|
||||
# Annotate cells where absolute correlation > threshold (from config)
|
||||
for i in range(corr_matrix_sorted.shape[0]):
|
||||
for j in range(corr_matrix_sorted.shape[1]):
|
||||
if abs(corr_matrix_sorted.iloc[i, j]) > annot_threshold and i != j:
|
||||
ax.text(j + 0.5, i + 0.5, f'{corr_matrix_sorted.iloc[i, j]:.2f}',
|
||||
ha='center', va='center', color='black', fontsize=8)
|
||||
|
||||
title = "Feature Correlation Heatmap (Pearson)"
|
||||
if sort_features:
|
||||
title += f" - Sorted by |ρ| vs '{target_col}'"
|
||||
ax.set_title(title, fontsize=14)
|
||||
plt.xticks(rotation=90, fontsize=8)
|
||||
plt.yticks(rotation=0, fontsize=8)
|
||||
|
||||
# Add footer (from config)
|
||||
if plot_footer: # Only add if footer is not empty
|
||||
plt.figtext(0.99, 0.01, plot_footer, ha="right", va="bottom", fontsize=8, color='gray')
|
||||
|
||||
# Save figure using IOManager
|
||||
io.save_figure(fig, "feature_corr_heatmap", section='figures') # Saved to results/<run_id>/figures/
|
||||
plt.close(fig) # Close figure after saving
|
||||
logger.info("Saved feature correlation heatmap.")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate or save feature correlation heatmap: {e}", exc_info=True)
|
||||
else:
|
||||
logger.warning("IOManager not provided or plotting disabled, skipping feature correlation heatmap.")
|
||||
# --- End V3 Output Contract --- #
|
||||
|
||||
# --- REMOVE Aggressive DropNA --- #
|
||||
# Dropping all rows with any NaN here is too aggressive, especially with long lookback features.
|
||||
# NaN handling should occur within feature calculation methods (bfill/ffill/fillna(0))
|
||||
# and critically during label definition (where rows without valid labels are dropped).
|
||||
# initial_rows = len(df_engineered)
|
||||
# df_engineered.dropna(inplace=True)
|
||||
# rows_dropped = initial_rows - len(df_engineered)
|
||||
# if rows_dropped > 0:
|
||||
# logger.warning(f"Dropped {rows_dropped} rows with NaN values after feature engineering.")
|
||||
# --- End REMOVE --- #
|
||||
|
||||
# Check if dataframe became empty *after feature calculation and internal NaN handling*
|
||||
# (Though ideally internal handling should prevent this)
|
||||
if df_engineered.empty:
|
||||
logger.error("DataFrame is empty after feature engineering (check internal NaN handling in FeatureEngineer)." )
|
||||
return pd.DataFrame() # Return empty DataFrame
|
||||
|
||||
logger.info(f"Feature engineering complete. Shape: {df_engineered.shape}")
|
||||
return df_engineered
|
||||
|
||||
# --- Stage 3: Define Labels and Align (Moved from TradingPipeline.define_labels_and_align) --- #
|
||||
def define_labels_and_align_fold(
|
||||
df_engineered: pd.DataFrame,
|
||||
config: dict
|
||||
) -> Tuple[pd.DataFrame, str, List[str], pd.Series, Optional[pd.Series]]:
|
||||
"""Defines prediction labels, aligns with features, and separates targets for a fold.
|
||||
Also returns the raw forward returns and epsilon series used for filtering baselines.
|
||||
|
||||
Args:
|
||||
df_engineered (pd.DataFrame): DataFrame with engineered features for the fold.
|
||||
config (dict): Pipeline configuration dictionary.
|
||||
|
||||
Returns:
|
||||
Tuple[pd.DataFrame, str, List[str], pd.Series, Optional[pd.Series]]:
|
||||
- df_labeled_aligned: DataFrame with labels generated and features/targets aligned (NaNs dropped).
|
||||
- target_dir_col: Name of the direction label column.
|
||||
- target_cols: List containing names of all target columns (ret + dir).
|
||||
- fwd_returns_aligned: Series of forward returns aligned with df_labeled_aligned.
|
||||
- eps_aligned: Series of epsilon threshold aligned with df_labeled_aligned (or None).
|
||||
Returns (pd.DataFrame(), "", [], pd.Series(), None) on failure or empty input.
|
||||
"""
|
||||
logger.info("--- Stage: Defining Labels and Aligning --- ")
|
||||
if df_engineered is None or df_engineered.empty:
|
||||
logger.error("Engineered data (DataFrame) is empty. Cannot define labels.")
|
||||
return pd.DataFrame(), "", [], pd.Series(dtype=float), None
|
||||
|
||||
# --- Call the label generation function (already in this module) --- #
|
||||
try:
|
||||
# generate_direction_labels modifies the DataFrame in place and returns it
|
||||
# It also returns the original fwd_returns and eps series (aligned with df_engineered)
|
||||
df_clean, target_dir_col, fwd_returns_orig, eps_orig = generate_direction_labels(
|
||||
df_engineered.copy(), # Pass a copy to avoid modifying original outside this scope if needed
|
||||
config
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Label generation failed: {e}.", exc_info=True)
|
||||
return pd.DataFrame(), "", [], pd.Series(dtype=float), None
|
||||
|
||||
if df_clean.empty:
|
||||
logger.error("Label generation resulted in an empty DataFrame.")
|
||||
return pd.DataFrame(), "", [], pd.Series(dtype=float), None
|
||||
# --- End Label Generation Call --- #
|
||||
|
||||
# --- Determine Target Columns --- #
|
||||
horizon = config.get('gru', {}).get('prediction_horizon', 5)
|
||||
target_ret_col = f'fwd_log_ret_{horizon}'
|
||||
# target_dir_col is returned by generate_direction_labels
|
||||
target_cols = [target_ret_col, target_dir_col]
|
||||
|
||||
# Ensure the columns actually exist after generation and cleaning
|
||||
if not all(col in df_clean.columns for col in target_cols):
|
||||
# Log which columns are actually present for debugging
|
||||
present_cols = df_clean.columns.tolist()
|
||||
logger.error(f"Generated label/return columns ({target_cols}) not found in DataFrame after label generation. Present columns: {present_cols}")
|
||||
return pd.DataFrame(), "", [], pd.Series(dtype=float), None
|
||||
# --- End Determine Target Columns --- #
|
||||
|
||||
# --- Align fwd_returns_orig and eps_orig with the cleaned DataFrame --- #
|
||||
fwd_returns_aligned = fwd_returns_orig.loc[df_clean.index]
|
||||
eps_aligned = eps_orig.loc[df_clean.index] if eps_orig is not None else None
|
||||
# --- End Alignment --- #
|
||||
|
||||
# Note: Separation of X and y happens in the splitting function now
|
||||
# We just need to return the fully labeled/aligned DataFrame and target column names.
|
||||
logger.info(f"Labels defined and aligned. Shape: {df_clean.shape}")
|
||||
|
||||
# Return the aligned DataFrame and the aligned supplementary series
|
||||
return df_clean, target_dir_col, target_cols, fwd_returns_aligned, eps_aligned
|
||||
|
||||
# --- Stage 4: Split Data (Moved from TradingPipeline.split_data) --- #
|
||||
def split_data_fold(
|
||||
df_labeled_aligned: pd.DataFrame,
|
||||
fwd_returns_aligned: pd.Series,
|
||||
eps_aligned: Optional[pd.Series],
|
||||
config: dict,
|
||||
target_columns: List[str],
|
||||
target_dir_col: str,
|
||||
fold_dates: Optional[Tuple] = None,
|
||||
current_fold: Optional[int] = None # For logging
|
||||
) -> Tuple[
|
||||
# Features
|
||||
pd.DataFrame, pd.DataFrame, pd.DataFrame,
|
||||
# Original Targets
|
||||
pd.DataFrame, pd.DataFrame, pd.DataFrame,
|
||||
# Original Full DataFrames
|
||||
pd.DataFrame, pd.DataFrame, pd.DataFrame,
|
||||
# Ordinal Direction Target (Train only)
|
||||
pd.Series,
|
||||
# Forward Returns (Train/Val)
|
||||
pd.Series, Optional[pd.Series],
|
||||
# Epsilon (Train/Val)
|
||||
Optional[pd.Series], Optional[pd.Series],
|
||||
# Ordinal Direction Labels (Val)
|
||||
Optional[pd.Series]
|
||||
]:
|
||||
"""Splits features, targets, fwd returns, and epsilon for a given fold.
|
||||
|
||||
Args:
|
||||
df_labeled_aligned (pd.DataFrame): Labeled and aligned data for the entire fold period.
|
||||
fwd_returns_aligned (pd.Series): Forward returns aligned with df_labeled_aligned.
|
||||
eps_aligned (Optional[pd.Series]): Epsilon threshold aligned with df_labeled_aligned.
|
||||
config (dict): Pipeline configuration.
|
||||
target_columns (List[str]): Names of all target columns (e.g., ['fwd_log_ret_5', 'direction_label_5']).
|
||||
target_dir_col (str): Name of the specific direction target column.
|
||||
fold_dates (Optional[Tuple]): Tuple of (train_start, train_end, val_start, val_end, test_start, test_end) for WF.
|
||||
current_fold (Optional[int]): Fold number for logging.
|
||||
|
||||
Returns:
|
||||
Tuple containing the split dataframes/series:
|
||||
(X_train_raw, X_val_raw, X_test_raw, # Features
|
||||
y_train, y_val, y_test, # Targets
|
||||
df_train_original, df_val_original, df_test_original, # Original DFs
|
||||
y_dir_train_ordinal, # Ordinal Direction Labels (Train)
|
||||
fwd_ret_train, fwd_ret_val, # Forward Returns (Train/Val)
|
||||
eps_train, eps_val, # Epsilon (Train/Val, Optional)
|
||||
y_dir_val_ordinal) # Ordinal Direction Labels (Val, Optional)
|
||||
Returns tuple of Nones if splitting fails.
|
||||
"""
|
||||
fold_label = f"Fold {current_fold}" if current_fold is not None else "Split"
|
||||
logger.info(f"--- {fold_label}: Stage: Splitting Data --- ")
|
||||
|
||||
if df_labeled_aligned is None or df_labeled_aligned.empty:
|
||||
logger.error(f"Fold {fold_label}: Input data for splitting is empty.")
|
||||
# Return Nones to indicate failure (update count based on new returns)
|
||||
return (None,) * 14
|
||||
|
||||
# --- Temporarily add fwd_ret and eps to DataFrame for easier splitting --- #
|
||||
temp_fwd_ret_col = '__temp_fwd_ret__'
|
||||
temp_eps_col = '__temp_eps__'
|
||||
df_split_input = df_labeled_aligned.copy()
|
||||
df_split_input[temp_fwd_ret_col] = fwd_returns_aligned
|
||||
if eps_aligned is not None:
|
||||
df_split_input[temp_eps_col] = eps_aligned
|
||||
# --- End Temp Add --- #
|
||||
|
||||
if not isinstance(df_split_input.index, pd.DatetimeIndex):
|
||||
logger.error(f"{fold_label}: Data index must be DatetimeIndex for splitting. Aborting.")
|
||||
raise SystemExit(f"{fold_label}: Index is not DatetimeIndex in split_data.")
|
||||
|
||||
if not target_columns:
|
||||
logger.error(f"{fold_label}: Target columns list is empty. Aborting.")
|
||||
raise SystemExit(f"{fold_label}: Target columns missing in split_data.")
|
||||
if not target_dir_col:
|
||||
logger.error(f"{fold_label}: Target direction column name is empty. Aborting.")
|
||||
raise SystemExit(f"{fold_label}: Target direction column missing in split_data.")
|
||||
|
||||
# Ensure target columns exist before trying to drop/select them
|
||||
cols_to_drop = [col for col in target_columns if col in df_split_input.columns]
|
||||
if len(cols_to_drop) != len(target_columns):
|
||||
missing_targets = set(target_columns) - set(cols_to_drop)
|
||||
logger.error(f"{fold_label}: Expected target columns {missing_targets} not found in input DataFrame. Aborting.")
|
||||
raise SystemExit(f"{fold_label}: Missing target columns in split_data input.")
|
||||
|
||||
# Exclude temporary columns from feature_cols
|
||||
feature_cols = df_split_input.columns.difference(cols_to_drop + [temp_fwd_ret_col, temp_eps_col])
|
||||
if feature_cols.empty:
|
||||
logger.error(f"{fold_label}: No feature columns remain after excluding targets and temp cols. Aborting.")
|
||||
raise SystemExit(f"{fold_label}: No feature columns found in split_data.")
|
||||
|
||||
# --- Determine if ternary mode is active --- #
|
||||
use_ternary = config.get('gru', {}).get('use_ternary', False)
|
||||
# --- End Determine Ternary --- #
|
||||
|
||||
# Initialize split results
|
||||
X_train_raw, X_val_raw, X_test_raw = pd.DataFrame(), pd.DataFrame(), pd.DataFrame()
|
||||
y_train, y_val, y_test = pd.DataFrame(), pd.DataFrame(), pd.DataFrame()
|
||||
df_train_original, df_val_original, df_test_original = pd.DataFrame(), pd.DataFrame(), pd.DataFrame()
|
||||
fwd_ret_train, fwd_ret_val = pd.Series(dtype=float), pd.Series(dtype=float)
|
||||
eps_train, eps_val = None, None
|
||||
y_dir_train_raw_format, y_dir_val_raw_format = pd.Series(dtype=object), pd.Series(dtype=object) # Store raw labels before converting
|
||||
|
||||
# Split based on Walk-Forward dates or ratios
|
||||
if fold_dates and len(fold_dates) == 6 and all(fold_dates): # Check for valid WF tuple
|
||||
train_start, train_end, val_start, val_end, test_start, test_end = fold_dates
|
||||
logger.info(f" Splitting using Walk-Forward dates: Train=[{train_start}, {train_end}), Val=[{val_start}, {val_end}), Test=[{test_start}, {test_end})")
|
||||
|
||||
# Slicing logic
|
||||
df_train_original = df_split_input.loc[train_start:train_end]
|
||||
df_val_original = df_split_input.loc[val_start:val_end]
|
||||
df_test_original = df_split_input.loc[test_start:test_end] if test_start else pd.DataFrame()
|
||||
|
||||
else: # Single split using ratios
|
||||
split_cfg = config.get('split_ratios', {})
|
||||
train_ratio = split_cfg.get('train', 0.7)
|
||||
val_ratio = split_cfg.get('validation', 0.15)
|
||||
test_ratio = round(1.0 - train_ratio - val_ratio, 2)
|
||||
logger.info(f" Splitting using ratios: Train={train_ratio:.2f}, Val={val_ratio:.2f}, Test={test_ratio:.2f}")
|
||||
|
||||
total_len = len(df_split_input)
|
||||
train_end_idx = int(total_len * train_ratio)
|
||||
val_end_idx = int(total_len * (train_ratio + val_ratio))
|
||||
|
||||
df_train_original = df_split_input.iloc[:train_end_idx]
|
||||
df_val_original = df_split_input.iloc[train_end_idx:val_end_idx]
|
||||
df_test_original = df_split_input.iloc[val_end_idx:]
|
||||
|
||||
# --- Extract components from split DataFrames --- #
|
||||
if not df_train_original.empty:
|
||||
X_train_raw = df_train_original[feature_cols]
|
||||
y_train = df_train_original[target_columns]
|
||||
y_dir_train_raw_format = df_train_original[target_dir_col]
|
||||
fwd_ret_train = df_train_original[temp_fwd_ret_col]
|
||||
if temp_eps_col in df_train_original:
|
||||
eps_train = df_train_original[temp_eps_col]
|
||||
|
||||
if not df_val_original.empty:
|
||||
X_val_raw = df_val_original[feature_cols]
|
||||
y_val = df_val_original[target_columns]
|
||||
fwd_ret_val = df_val_original[temp_fwd_ret_col]
|
||||
if temp_eps_col in df_val_original:
|
||||
eps_val = df_val_original[temp_eps_col]
|
||||
|
||||
if not df_test_original.empty:
|
||||
X_test_raw = df_test_original[feature_cols]
|
||||
y_test = df_test_original[target_columns]
|
||||
# --- End Extraction --- #
|
||||
|
||||
# --- Extract Ordinal Labels if Ternary --- #
|
||||
y_dir_train_ordinal = None
|
||||
if not y_train.empty: # Check if training data exists
|
||||
if use_ternary:
|
||||
valid_mask = y_dir_train_raw_format.notna() & y_dir_train_raw_format.apply(lambda x: isinstance(x, list) and len(x) == 3)
|
||||
if valid_mask.any():
|
||||
ordinal_values = y_dir_train_raw_format[valid_mask].apply(np.argmax)
|
||||
y_dir_train_ordinal = pd.Series(np.nan, index=y_dir_train_raw_format.index)
|
||||
y_dir_train_ordinal[valid_mask] = ordinal_values
|
||||
logger.info(f"{fold_label}: Extracted ordinal labels (0, 1, 2) for feature selection. Count: {valid_mask.sum()}")
|
||||
else:
|
||||
logger.warning(f"{fold_label}: No valid list-based ternary labels found in y_dir_train_raw_format to convert to ordinal.")
|
||||
y_dir_train_ordinal = pd.Series(dtype=np.float64) # Return empty series
|
||||
else:
|
||||
y_dir_train_ordinal = y_dir_train_raw_format.astype(int) # Ensure integer type
|
||||
else:
|
||||
y_dir_train_ordinal = pd.Series(dtype=int) # Empty series if no train data
|
||||
# --- End Extract Ordinal Labels --- #
|
||||
|
||||
# --- Extract Ordinal Validation Labels if Ternary --- #
|
||||
y_dir_val_ordinal = None
|
||||
if not y_val.empty: # Check if validation data exists
|
||||
if use_ternary:
|
||||
# Use y_dir_val_raw_format which holds the lists/NaNs
|
||||
valid_mask_val = y_dir_val_raw_format.notna() & y_dir_val_raw_format.apply(lambda x: isinstance(x, list) and len(x) == 3)
|
||||
if valid_mask_val.any():
|
||||
ordinal_values_val = y_dir_val_raw_format[valid_mask_val].apply(np.argmax)
|
||||
y_dir_val_ordinal = pd.Series(np.nan, index=y_dir_val_raw_format.index)
|
||||
y_dir_val_ordinal[valid_mask_val] = ordinal_values_val
|
||||
logger.info(f"{fold_label}: Extracted ordinal validation labels. Count: {valid_mask_val.sum()}")
|
||||
else:
|
||||
logger.warning(f"{fold_label}: No valid ternary labels found in y_dir_val_raw_format.")
|
||||
y_dir_val_ordinal = pd.Series(dtype=np.float64)
|
||||
else:
|
||||
# Use y_dir_val_raw_format which holds 0.0/1.0
|
||||
y_dir_val_ordinal = y_dir_val_raw_format.astype(int)
|
||||
else:
|
||||
y_dir_val_ordinal = pd.Series(dtype=int) # Empty series if no validation data
|
||||
# --- End Extract Ordinal Validation Labels --- #
|
||||
|
||||
# Log split shapes and check for empty splits
|
||||
logger.info(f"Data split complete for {fold_label}:")
|
||||
logger.info(f" Train: X={X_train_raw.shape}, y={y_train.shape}, fwd_ret={fwd_ret_train.shape}, eps={eps_train.shape if eps_train is not None else 'None'} ({X_train_raw.index.min()} to {X_train_raw.index.max()})" if not X_train_raw.empty else " Train: EMPTY")
|
||||
logger.info(f" Val: X={X_val_raw.shape}, y={y_val.shape}, fwd_ret={fwd_ret_val.shape}, eps={eps_val.shape if eps_val is not None else 'None'} ({X_val_raw.index.min()} to {X_val_raw.index.max()})" if not X_val_raw.empty else " Val: EMPTY")
|
||||
logger.info(f" Test: X=({X_test_raw.shape if X_test_raw is not None else 'None'}), y=({y_test.shape if y_test is not None else 'None'}) ({df_test_original.index.min() if df_test_original is not None and not df_test_original.empty else 'N/A'} to {df_test_original.index.max() if df_test_original is not None and not df_test_original.empty else 'N/A'})" )
|
||||
|
||||
# Check required splits are non-empty
|
||||
if X_train_raw.empty or X_val_raw.empty:
|
||||
logger.error(f"Fold {current_fold}: Data splitting resulted in empty train or validation set. Aborting fold.")
|
||||
raise SystemExit(f"Fold {current_fold}: Empty train or validation split detected.")
|
||||
|
||||
return (
|
||||
X_train_raw, X_val_raw, X_test_raw, # Features
|
||||
y_train, y_val, y_test, # Targets
|
||||
df_train_original, df_val_original, df_test_original, # Original DFs
|
||||
y_dir_train_ordinal, # Ordinal Direction Labels (Train)
|
||||
fwd_ret_train, fwd_ret_val, # Forward Returns (Train/Val)
|
||||
eps_train, eps_val, # Epsilon (Train/Val, Optional)
|
||||
y_dir_val_ordinal # Ordinal Direction Labels (Val, Optional)
|
||||
)
|
||||
182
gru_sac_predictor/src/pipeline_stages/sequence_creation.py
Normal file
182
gru_sac_predictor/src/pipeline_stages/sequence_creation.py
Normal file
@ -0,0 +1,182 @@
|
||||
# Stage functions for creating GRU input sequences
|
||||
import logging
|
||||
import sys
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Tuple, Dict, Optional, List
|
||||
import json # Added for saving artefact
|
||||
|
||||
# Assuming IOManager is importable from parent src directory
|
||||
from ..io_manager import IOManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def create_sequences_fold(
|
||||
X_data: pd.DataFrame,
|
||||
y_data: pd.DataFrame,
|
||||
target_names: List[str], # e.g., ['mu', 'dir3'] or ['mu', 'dir']
|
||||
lookback: int,
|
||||
name: str, # e.g., "Train", "Validation", "Test"
|
||||
config: dict, # For gru.drop_imputed_sequences
|
||||
io: Optional[IOManager] # For saving artefact
|
||||
) -> Tuple[Optional[np.ndarray], Optional[Dict], Optional[pd.Index], int]:
|
||||
"""
|
||||
Transforms pruned, scaled feature DataFrame into 3D sequences for GRU input
|
||||
and extracts corresponding targets for a specific data split (Train/Val/Test).
|
||||
Handles dropping sequences containing imputed bars based on config.
|
||||
|
||||
Args:
|
||||
X_data (pd.DataFrame): Pruned, scaled features for the split.
|
||||
y_data (pd.DataFrame): Targets for the split.
|
||||
target_names (List[str]): List of target column names in y_data (e.g., ['mu', 'dir3']).
|
||||
lookback (int): Sequence length.
|
||||
name (str): Name of the split (e.g., "Train", "Validation", "Test") for logging.
|
||||
config (dict): Pipeline configuration dictionary.
|
||||
io (Optional[IOManager]): IOManager instance for saving artefacts.
|
||||
|
||||
Returns:
|
||||
Tuple containing:
|
||||
- X_seq (np.ndarray or None): 3D feature sequences.
|
||||
- y_seq_dict (Dict or None): Dictionary of target sequences.
|
||||
- target_indices (pd.Index or None): Timestamps corresponding to the targets.
|
||||
- dropped_count (int): Number of sequences dropped due to imputed bars.
|
||||
Returns (None, None, None, 0) if sequence creation fails or data is insufficient.
|
||||
Raises SystemExit on critical errors (e.g., misalignment).
|
||||
"""
|
||||
logger.info(f"--- Creating {name} Sequences ---")
|
||||
use_ternary = config.get('gru', {}).get('use_ternary', False)
|
||||
drop_imputed = config.get('gru', {}).get('drop_imputed_sequences', False)
|
||||
imputed_col_name = 'bar_imputed' # Assuming this is the column name
|
||||
|
||||
# --- Input Validation --- #
|
||||
if X_data is None or y_data is None or X_data.empty or y_data.empty:
|
||||
logger.error(f"{name}: Missing or empty features/targets for sequence creation.")
|
||||
return None, None, None, 0
|
||||
|
||||
# Check for bar_imputed column
|
||||
if imputed_col_name not in X_data.columns:
|
||||
logger.error(f"{name}: Required column '{imputed_col_name}' not found in features. Cannot handle imputed sequences.")
|
||||
# Decide whether to proceed without it or raise error - raising for now
|
||||
raise SystemExit(f"{name}: '{imputed_col_name}' column missing. Sequence creation halted.")
|
||||
|
||||
# Strict Anti-Leakage Check
|
||||
try:
|
||||
assert X_data.index.equals(y_data.index), \
|
||||
f"{name}: Features and targets indices misaligned!"
|
||||
except AssertionError as e:
|
||||
logger.error(f"Data alignment check failed: {e}. Potential data leakage. Aborting.")
|
||||
raise SystemExit(f"{name}: {e}")
|
||||
|
||||
# Check target columns exist
|
||||
if not all(col in y_data.columns for col in target_names):
|
||||
missing_targets = set(target_names) - set(y_data.columns)
|
||||
logger.error(f"{name}: Target columns {missing_targets} not found in y_data. Aborting.")
|
||||
raise SystemExit(f"{name}: Missing target columns for sequencing.")
|
||||
# --- End Input Validation --- #
|
||||
|
||||
# Convert DataFrames to numpy for potential speedup, keep index access
|
||||
features_np = X_data.values
|
||||
imputed_flag_np = X_data[imputed_col_name].values.astype(bool) # Ensure boolean type
|
||||
# Extract targets based on target_names
|
||||
targets_dict_np = {name: y_data[name].values for name in target_names}
|
||||
|
||||
X_seq_list, y_seq_dict_list = [], {name: [] for name in target_names}
|
||||
mask_seq_list = [] # To store the imputed flag sequences
|
||||
target_indices = []
|
||||
|
||||
if len(X_data) <= lookback:
|
||||
logger.warning(f"{name}: DataFrame length ({len(X_data)}) is not greater than lookback ({lookback}). Cannot create sequences.")
|
||||
return None, None, None, 0
|
||||
|
||||
for i in range(lookback, len(features_np)):
|
||||
# Feature window: [i-lookback, i)
|
||||
X_seq_list.append(features_np[i - lookback : i])
|
||||
mask_seq_list.append(imputed_flag_np[i - lookback : i])
|
||||
|
||||
# Targets correspond to index i
|
||||
for t_name in target_names:
|
||||
target_val = targets_dict_np[t_name][i]
|
||||
# Special handling for potential list/array type in ternary labels
|
||||
if use_ternary and 'dir' in t_name and isinstance(target_val, list):
|
||||
target_val = np.array(target_val, dtype=np.float32)
|
||||
y_seq_dict_list[t_name].append(target_val)
|
||||
|
||||
target_indices.append(y_data.index[i]) # Get index corresponding to target
|
||||
|
||||
if not X_seq_list: # Check if any sequences were created
|
||||
logger.warning(f"{name}: No sequences were generated (length <= lookback?).")
|
||||
return None, None, None, 0
|
||||
|
||||
# Convert lists to numpy arrays
|
||||
X_seq = np.array(X_seq_list, dtype=np.float32)
|
||||
mask_seq = np.array(mask_seq_list, dtype=bool)
|
||||
target_indices_pd = pd.Index(target_indices)
|
||||
y_seq_dict_np = {}
|
||||
for t_name in target_names:
|
||||
try:
|
||||
# Attempt to stack; requires consistent shapes
|
||||
if use_ternary and 'dir' in t_name:
|
||||
y_seq_dict_np[t_name] = np.stack(y_seq_dict_list[t_name]).astype(np.float32)
|
||||
else: # Assuming other targets are scalar
|
||||
y_seq_dict_np[t_name] = np.array(y_seq_dict_list[t_name], dtype=np.float32)
|
||||
except ValueError as e:
|
||||
logger.error(f"{name}: Error stacking target '{t_name}': {e}. Check target consistency (especially ternary).", exc_info=True)
|
||||
shapes = [getattr(item, 'shape', type(item)) for item in y_seq_dict_list[t_name]]
|
||||
from collections import Counter
|
||||
logger.error(f"Target shapes/types found: {Counter(shapes)}")
|
||||
raise SystemExit(f"{name}: Inconsistent target shapes for '{t_name}' during sequence creation.") from e
|
||||
|
||||
orig_n = X_seq.shape[0]
|
||||
dropped_count = 0
|
||||
|
||||
# Conditionally drop sequences containing imputed bars
|
||||
if drop_imputed:
|
||||
logger.info(f"{name}: Dropping sequences containing imputed bars (drop_imputed_sequences=True)...")
|
||||
valid_mask = ~mask_seq.any(axis=1)
|
||||
X_seq = X_seq[valid_mask]
|
||||
mask_seq = mask_seq[valid_mask] # Keep mask aligned, though not explicitly used later
|
||||
for t_name in target_names:
|
||||
y_seq_dict_np[t_name] = y_seq_dict_np[t_name][valid_mask]
|
||||
target_indices_pd = target_indices_pd[valid_mask]
|
||||
|
||||
dropped_count = orig_n - X_seq.shape[0]
|
||||
logger.info(f"{name}: Generated {orig_n} sequences, dropped {dropped_count} containing imputed bars. Remaining: {X_seq.shape[0]}")
|
||||
|
||||
# Save summary artifact
|
||||
if io:
|
||||
summary_data = {
|
||||
"split_name": name,
|
||||
"total_sequences_generated": orig_n,
|
||||
"sequences_dropped_imputed": dropped_count,
|
||||
"sequences_remaining": X_seq.shape[0],
|
||||
"drop_imputed_sequences_config": drop_imputed
|
||||
}
|
||||
try:
|
||||
filename = f"imputed_sequence_summary_{name.lower()}.json"
|
||||
io.save_json(summary_data, filename, section='results', indent=4)
|
||||
logger.info(f"Saved imputed sequence summary to results/{filename}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save imputed sequence summary for {name}: {e}")
|
||||
else:
|
||||
logger.warning(f"IOManager not available, cannot save imputed sequence summary for {name}.")
|
||||
|
||||
else:
|
||||
logger.info(f"{name}: Generated {orig_n} sequences. Keeping sequences with imputed bars (drop_imputed_sequences=False).")
|
||||
|
||||
# Final checks
|
||||
if X_seq.shape[0] == 0:
|
||||
logger.error(f"{name}: No valid sequences remaining after potential filtering. Aborting.")
|
||||
return None, None, None, dropped_count # Return 0 count if no sequences left
|
||||
|
||||
# --- REMOVE: Final Dictionary Mapping (Let GRU handler manage this) --- #
|
||||
# final_y_seq_dict = {
|
||||
# 'mu': y_seq_dict_np['ret'], # Map 'ret' to 'mu'
|
||||
# 'dir3': y_seq_dict_np['dir3'] # Keep 'dir3' as is
|
||||
# }
|
||||
# --- END REMOVE --- #
|
||||
|
||||
# Log final shapes
|
||||
logger.info(f"Sequence shapes created for {name}:")
|
||||
logger.info(f" X={X_seq.shape}, y_keys={list(y_seq_dict_np.keys())}, indices={len(target_indices_pd)}")
|
||||
|
||||
return X_seq, y_seq_dict_np, target_indices_pd, dropped_count
|
||||
@ -6,6 +6,8 @@ Uses pre-calculated GRU predictions (mu, sigma, p_cal) and actual returns.
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import logging
|
||||
import gymnasium as gym
|
||||
from omegaconf import DictConfig # Added for config typing
|
||||
|
||||
env_logger = logging.getLogger(__name__)
|
||||
|
||||
@ -15,6 +17,8 @@ class TradingEnv:
|
||||
sigma_predictions: np.ndarray,
|
||||
p_cal_predictions: np.ndarray,
|
||||
actual_returns: np.ndarray,
|
||||
bar_imputed_flags: np.ndarray, # Added imputed flags
|
||||
config: DictConfig, # Added config
|
||||
initial_capital: float = 10000.0,
|
||||
transaction_cost: float = 0.0005,
|
||||
reward_scale: float = 100.0,
|
||||
@ -27,18 +31,22 @@ class TradingEnv:
|
||||
sigma_predictions: Predicted volatility (σ̂ = exp(log σ̂)).
|
||||
p_cal_predictions: Calibrated probability of price increase (p_cal).
|
||||
actual_returns: Actual log returns (y_ret).
|
||||
bar_imputed_flags: Boolean array indicating if a bar was imputed.
|
||||
config: OmegaConf configuration object.
|
||||
initial_capital: Starting capital for simulation (used notionally in reward).
|
||||
transaction_cost: Fractional cost per trade.
|
||||
reward_scale: Multiplier for the reward signal.
|
||||
action_penalty_lambda: Coefficient for the action magnitude penalty (λ).
|
||||
"""
|
||||
assert len(mu_predictions) == len(sigma_predictions) == len(p_cal_predictions) == len(actual_returns), \
|
||||
"All input arrays must have the same length"
|
||||
assert len(mu_predictions) == len(sigma_predictions) == len(p_cal_predictions) == len(actual_returns) == len(bar_imputed_flags), \
|
||||
"All input arrays (predictions, returns, imputed_flags) must have the same length"
|
||||
|
||||
self.mu = mu_predictions
|
||||
self.sigma = sigma_predictions
|
||||
self.p_cal = p_cal_predictions
|
||||
self.actual_returns = actual_returns
|
||||
self.bar_imputed = bar_imputed_flags.astype(bool) # Store imputed flags
|
||||
self.config = config # Store config
|
||||
|
||||
self.initial_capital = initial_capital
|
||||
self.transaction_cost = transaction_cost
|
||||
@ -65,20 +73,36 @@ class TradingEnv:
|
||||
self.state_dim = 5
|
||||
self.action_dim = 1
|
||||
|
||||
# --- Define Gym Spaces ---
|
||||
self.action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(self.action_dim,), dtype=np.float32)
|
||||
self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.state_dim,), dtype=np.float32)
|
||||
# --- End Define Gym Spaces ---
|
||||
|
||||
env_logger.info(f"TradingEnv initialized with {self.n_steps} steps.")
|
||||
|
||||
def _get_state(self) -> np.ndarray:
|
||||
"""Construct the state vector for the current step."""
|
||||
if self.current_step >= self.n_steps:
|
||||
# Handle episode end - return a dummy state or zeros
|
||||
return np.zeros(self.state_dim, dtype=np.float32)
|
||||
|
||||
mu_t = self.mu[self.current_step]
|
||||
sigma_t = self.sigma[self.current_step]
|
||||
p_cal_t = self.p_cal[self.current_step]
|
||||
|
||||
edge_t = 2 * p_cal_t - 1
|
||||
z_score_t = np.abs(mu_t) / (sigma_t + 1e-9)
|
||||
# Calculate edge based on p_cal shape (binary vs ternary)
|
||||
if isinstance(p_cal_t, (np.ndarray, list)) and len(p_cal_t) == 3:
|
||||
# Ternary: edge = max(P(up), P(down)) - P(flat)
|
||||
# Assuming order [Down, Flat, Up] for p_cal_t
|
||||
edge_t = max(p_cal_t[2], p_cal_t[0]) - p_cal_t[1]
|
||||
elif isinstance(p_cal_t, (float, np.number)):
|
||||
# Binary: edge = 2 * P(up) - 1
|
||||
edge_t = 2 * p_cal_t - 1
|
||||
else:
|
||||
env_logger.error(f"Unexpected type/shape for p_cal_t at step {self.current_step}: {p_cal_t}. Using edge=0.")
|
||||
edge_t = 0.0
|
||||
|
||||
_EPS = 1e-9 # Define epsilon locally
|
||||
z_score_t = np.abs(mu_t) / (sigma_t + _EPS)
|
||||
|
||||
# State uses position *before* the action for this step is taken
|
||||
state = np.array([
|
||||
@ -108,11 +132,48 @@ class TradingEnv:
|
||||
Returns:
|
||||
tuple: (next_state, reward, done, info_dict)
|
||||
"""
|
||||
info = {'capital': self.current_capital, 'position': self.current_position, 'is_imputed_step_skipped': False}
|
||||
|
||||
if self.current_step >= self.n_steps:
|
||||
# Should not happen if 'done' is handled correctly, but as safeguard
|
||||
env_logger.warning("Step called after environment finished.")
|
||||
return self._get_state(), 0.0, True, {}
|
||||
return self._get_state(), 0.0, True, info
|
||||
|
||||
# --- Handle Imputed Bar --- #
|
||||
imputed = self.bar_imputed[self.current_step]
|
||||
if imputed:
|
||||
mode = self.config.sac.imputed_handling
|
||||
env_logger.debug(f"SAC step {self.current_step} on imputed bar: handling={mode}")
|
||||
if mode == "skip":
|
||||
self.current_step += 1
|
||||
next_state = self._get_state() # Get state for the *next* actual step
|
||||
# Return 0 reward, not done, but indicate skip for buffer handling
|
||||
info['is_imputed_step_skipped'] = True
|
||||
return next_state, 0.0, False, info
|
||||
elif mode == "hold":
|
||||
# Action is forced to maintain current position
|
||||
action = self.current_position
|
||||
elif mode == "penalty":
|
||||
# Calculate reward penalty based on config
|
||||
target_position_penalty = np.clip(action, -1.0, 1.0)
|
||||
reward = -self.config.sac.action_penalty * (target_position_penalty - self.current_position)**2
|
||||
# Update position based on agent's intended action (clipped)
|
||||
self.current_position = target_position_penalty
|
||||
# Update capital notionally (no actual return, only cost if implemented)
|
||||
# Cost is implicitly 0 here as there's no trade size if pos doesn't change
|
||||
# If penalty mode allowed position change, cost would apply.
|
||||
# For simplicity, we don't add cost here for the penalty step.
|
||||
self.current_step += 1
|
||||
next_state = self._get_state()
|
||||
scaled_reward = reward * self.reward_scale # Scale the penalty
|
||||
done = self.current_step >= self.n_steps
|
||||
info['capital'] = self.current_capital
|
||||
info['position'] = self.current_position
|
||||
return next_state, scaled_reward, done, info
|
||||
# else: default behavior (treat as normal bar) - implicitly handled by falling through
|
||||
# --- End Handle Imputed Bar --- #
|
||||
|
||||
# --- Normal Step Logic (if not imputed or handling mode allows fallthrough like 'hold') --- #
|
||||
# Action is the TARGET position for the *end* of this step
|
||||
target_position = np.clip(action, -1.0, 1.0)
|
||||
trade_size = target_position - self.current_position
|
||||
@ -150,7 +211,9 @@ class TradingEnv:
|
||||
done = self.current_step >= self.n_steps or self.current_capital <= 0
|
||||
|
||||
next_state = self._get_state()
|
||||
info = {'capital': self.current_capital, 'position': self.current_position}
|
||||
# Update info dict (capital/position might have changed in normal step)
|
||||
info['capital'] = self.current_capital
|
||||
info['position'] = self.current_position
|
||||
|
||||
# Log step details periodically
|
||||
# if self.current_step % 1000 == 0:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
186
gru_sac_predictor/tests/test_sequence_creation.py
Normal file
186
gru_sac_predictor/tests/test_sequence_creation.py
Normal file
@ -0,0 +1,186 @@
|
||||
import pytest
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from omegaconf import OmegaConf
|
||||
from unittest.mock import MagicMock
|
||||
import os
|
||||
import tempfile
|
||||
import json
|
||||
|
||||
# Adjust the import path based on your project structure
|
||||
from gru_sac_predictor.src.pipeline_stages.sequence_creation import create_sequences_fold
|
||||
from gru_sac_predictor.src.io_manager import IOManager # Adjust path if needed
|
||||
|
||||
# --- Test Fixtures ---
|
||||
|
||||
@pytest.fixture
|
||||
def sample_data_with_imputed():
|
||||
"""Creates sample X and y dataframes with a 'bar_imputed' column."""
|
||||
dates = pd.to_datetime(pd.date_range('2023-01-01', periods=20, freq='T'))
|
||||
lookback = 5
|
||||
n_features_orig = 3
|
||||
n_samples = len(dates)
|
||||
|
||||
# Features (including bar_imputed)
|
||||
X_data = pd.DataFrame(
|
||||
np.random.randn(n_samples, n_features_orig),
|
||||
index=dates,
|
||||
columns=[f'feat_{i}' for i in range(n_features_orig)]
|
||||
)
|
||||
# Add bar_imputed column - mark some bars as imputed
|
||||
imputed_flags = np.zeros(n_samples, dtype=bool)
|
||||
imputed_flags[2] = True # Imputed within first potential sequence
|
||||
imputed_flags[8] = True # Imputed within a later potential sequence
|
||||
imputed_flags[15] = True # Imputed near the end
|
||||
X_data['bar_imputed'] = imputed_flags
|
||||
|
||||
# Targets (mu and dir3)
|
||||
y_data = pd.DataFrame({
|
||||
'mu': np.random.randn(n_samples),
|
||||
'dir3': [list(row) for row in np.eye(3)[np.random.randint(0, 3, n_samples)]] # Example one-hot
|
||||
}, index=dates)
|
||||
|
||||
return X_data, y_data
|
||||
|
||||
@pytest.fixture
|
||||
def base_config():
|
||||
"""Creates a base OmegaConf config for testing sequence creation."""
|
||||
conf = OmegaConf.create({
|
||||
'gru': {
|
||||
'lookback': 5,
|
||||
'use_ternary': True, # Matches sample_data_with_imputed
|
||||
'drop_imputed_sequences': True # Default to True for testing dropping
|
||||
},
|
||||
# Add other necessary sections if needed
|
||||
})
|
||||
return conf
|
||||
|
||||
@pytest.fixture
|
||||
def mock_io_manager():
|
||||
"""Creates a mock IOManager for testing artefact saving."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
mock_io = MagicMock(spec=IOManager)
|
||||
mock_io.results_dir = tmpdir
|
||||
saved_jsons = {}
|
||||
def mock_save_json(data, filename, **kwargs):
|
||||
filepath = os.path.join(tmpdir, filename)
|
||||
saved_jsons[filename] = data
|
||||
with open(filepath, 'w') as f:
|
||||
json.dump(data, f, **kwargs)
|
||||
mock_io.save_json.side_effect = mock_save_json
|
||||
mock_io.get_artifact_path.side_effect = lambda filename: os.path.join(tmpdir, filename)
|
||||
mock_io._saved_jsons = saved_jsons
|
||||
yield mock_io
|
||||
|
||||
# --- Test Functions ---
|
||||
|
||||
def test_sequence_creation_shapes(sample_data_with_imputed, base_config, mock_io_manager):
|
||||
X_data, y_data = sample_data_with_imputed
|
||||
lookback = base_config.gru.lookback
|
||||
n_features = X_data.shape[1]
|
||||
n_samples = len(X_data)
|
||||
expected_n_seq = n_samples - lookback
|
||||
|
||||
# Test without dropping imputed
|
||||
cfg_no_drop = base_config.copy()
|
||||
cfg_no_drop.gru.drop_imputed_sequences = False
|
||||
|
||||
X_seq, y_seq_dict, indices, dropped_count = create_sequences_fold(
|
||||
X_data=X_data, y_data=y_data, target_names=['mu', 'dir3'],
|
||||
lookback=lookback, name="TestSplit", config=cfg_no_drop, io=mock_io_manager
|
||||
)
|
||||
|
||||
assert dropped_count == 0
|
||||
assert X_seq is not None
|
||||
assert y_seq_dict is not None
|
||||
assert indices is not None
|
||||
assert X_seq.shape == (expected_n_seq, lookback, n_features)
|
||||
assert 'mu' in y_seq_dict and y_seq_dict['mu'].shape == (expected_n_seq,)
|
||||
assert 'dir3' in y_seq_dict and y_seq_dict['dir3'].shape == (expected_n_seq, 3)
|
||||
assert len(indices) == expected_n_seq
|
||||
# Check first target index corresponds to lookback-th original index
|
||||
assert indices[0] == X_data.index[lookback]
|
||||
# Check last target index corresponds to last original index
|
||||
assert indices[-1] == X_data.index[-1]
|
||||
|
||||
def test_sequence_dropping_imputed(sample_data_with_imputed, base_config, mock_io_manager):
|
||||
X_data, y_data = sample_data_with_imputed
|
||||
lookback = base_config.gru.lookback
|
||||
n_samples = len(X_data)
|
||||
expected_n_seq_orig = n_samples - lookback
|
||||
|
||||
# Config with dropping enabled (default in fixture)
|
||||
cfg_drop = base_config
|
||||
|
||||
X_seq, y_seq_dict, indices, dropped_count = create_sequences_fold(
|
||||
X_data=X_data.copy(), y_data=y_data.copy(), target_names=['mu', 'dir3'],
|
||||
lookback=lookback, name="TestDrop", config=cfg_drop, io=mock_io_manager
|
||||
)
|
||||
|
||||
assert X_seq is not None
|
||||
assert y_seq_dict is not None
|
||||
assert indices is not None
|
||||
|
||||
# Determine which original sequences should have been dropped
|
||||
# A sequence starting at index i uses data from [i, i+lookback-1]
|
||||
# The target corresponds to index i+lookback
|
||||
# We need to check the imputed flag in the range [i, i+lookback-1] for each potential sequence target index i+lookback
|
||||
|
||||
# Original target indices range from index `lookback` to `n_samples - 1`
|
||||
should_be_dropped_mask = np.zeros(expected_n_seq_orig, dtype=bool)
|
||||
imputed_flags_np = X_data['bar_imputed'].values
|
||||
for seq_idx in range(expected_n_seq_orig):
|
||||
# The features for this sequence are from original indices [seq_idx, seq_idx + lookback - 1]
|
||||
feature_indices_range = slice(seq_idx, seq_idx + lookback)
|
||||
if np.any(imputed_flags_np[feature_indices_range]):
|
||||
should_be_dropped_mask[seq_idx] = True
|
||||
|
||||
expected_dropped_count = np.sum(should_be_dropped_mask)
|
||||
expected_remaining_count = expected_n_seq_orig - expected_dropped_count
|
||||
|
||||
assert dropped_count == expected_dropped_count
|
||||
assert X_seq.shape[0] == expected_remaining_count
|
||||
assert y_seq_dict['mu'].shape[0] == expected_remaining_count
|
||||
assert y_seq_dict['dir3'].shape[0] == expected_remaining_count
|
||||
assert len(indices) == expected_remaining_count
|
||||
|
||||
# Check that the remaining indices are correct (weren't marked for dropping)
|
||||
original_indices = X_data.index[lookback:]
|
||||
expected_remaining_indices = original_indices[~should_be_dropped_mask]
|
||||
pd.testing.assert_index_equal(indices, expected_remaining_indices)
|
||||
|
||||
# Check artifact saving
|
||||
assert 'imputed_sequence_summary_testdrop.json' in mock_io_manager._saved_jsons
|
||||
report_data = mock_io_manager._saved_jsons['imputed_sequence_summary_testdrop.json']
|
||||
assert report_data['total_sequences_generated'] == expected_n_seq_orig
|
||||
assert report_data['sequences_dropped_imputed'] == expected_dropped_count
|
||||
assert report_data['sequences_remaining'] == expected_remaining_count
|
||||
|
||||
def test_sequence_creation_no_imputed_col(sample_data_with_imputed, base_config, mock_io_manager):
|
||||
X_data, y_data = sample_data_with_imputed
|
||||
X_data_no_imputed = X_data.drop(columns=['bar_imputed'])
|
||||
lookback = base_config.gru.lookback
|
||||
|
||||
with pytest.raises(SystemExit) as excinfo:
|
||||
create_sequences_fold(
|
||||
X_data=X_data_no_imputed, y_data=y_data, target_names=['mu', 'dir3'],
|
||||
lookback=lookback, name="TestNoImputedCol", config=base_config, io=mock_io_manager
|
||||
)
|
||||
assert "'bar_imputed' column missing" in str(excinfo.value)
|
||||
|
||||
def test_sequence_creation_insufficient_data(sample_data_with_imputed, base_config, mock_io_manager):
|
||||
X_data, y_data = sample_data_with_imputed
|
||||
lookback = base_config.gru.lookback
|
||||
# Create data shorter than lookback
|
||||
X_short = X_data.iloc[:lookback-1]
|
||||
y_short = y_data.iloc[:lookback-1]
|
||||
|
||||
X_seq, y_seq_dict, indices, dropped_count = create_sequences_fold(
|
||||
X_data=X_short, y_data=y_short, target_names=['mu', 'dir3'],
|
||||
lookback=lookback, name="TestShort", config=base_config, io=mock_io_manager
|
||||
)
|
||||
|
||||
assert X_seq is None
|
||||
assert y_seq_dict is None
|
||||
assert indices is None
|
||||
assert dropped_count == 0
|
||||
166
gru_sac_predictor/tests/test_trading_env.py
Normal file
166
gru_sac_predictor/tests/test_trading_env.py
Normal file
@ -0,0 +1,166 @@
|
||||
import pytest
|
||||
import numpy as np
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
# Adjust import path based on structure
|
||||
from gru_sac_predictor.src.trading_env import TradingEnv
|
||||
|
||||
# --- Test Fixtures ---
|
||||
|
||||
@pytest.fixture
|
||||
def sample_env_data():
|
||||
"""Provides sample data for initializing the TradingEnv."""
|
||||
n_steps = 10
|
||||
data = {
|
||||
'mu_predictions': np.random.randn(n_steps) * 0.001,
|
||||
'sigma_predictions': np.abs(np.random.randn(n_steps) * 0.002 + 0.005),
|
||||
'p_cal_predictions': np.random.rand(n_steps),
|
||||
'actual_returns': np.random.randn(n_steps) * 0.0015,
|
||||
'bar_imputed_flags': np.array([False, False, True, False, True, True, False, False, True, False], dtype=bool)
|
||||
}
|
||||
return data
|
||||
|
||||
@pytest.fixture
|
||||
def base_env_config():
|
||||
"""Base configuration for the environment."""
|
||||
return OmegaConf.create({
|
||||
'sac': {
|
||||
'imputed_handling': 'skip', # Default test mode
|
||||
'action_penalty': 0.05
|
||||
},
|
||||
'environment': {
|
||||
'initial_capital': 10000.0,
|
||||
'transaction_cost': 0.0005,
|
||||
'reward_scale': 100.0,
|
||||
'action_penalty_lambda': 0.0 # Usually overridden by transaction_cost calc
|
||||
}
|
||||
})
|
||||
|
||||
@pytest.fixture
|
||||
def trading_env_instance(sample_env_data, base_env_config):
|
||||
"""Creates a TradingEnv instance with default 'skip' mode."""
|
||||
return TradingEnv(**sample_env_data, config=base_env_config)
|
||||
|
||||
# --- Test Functions ---
|
||||
|
||||
def test_env_initialization(trading_env_instance, sample_env_data):
|
||||
assert trading_env_instance.n_steps == len(sample_env_data['actual_returns'])
|
||||
assert trading_env_instance.current_step == 0
|
||||
assert trading_env_instance.current_position == 0.0
|
||||
assert np.array_equal(trading_env_instance.bar_imputed, sample_env_data['bar_imputed_flags'])
|
||||
|
||||
def test_env_reset(trading_env_instance):
|
||||
# Take a few steps
|
||||
trading_env_instance.step(0.5)
|
||||
trading_env_instance.step(-0.2)
|
||||
assert trading_env_instance.current_step > 0
|
||||
# Reset
|
||||
initial_state = trading_env_instance.reset()
|
||||
assert trading_env_instance.current_step == 0
|
||||
assert trading_env_instance.current_position == 0.0
|
||||
assert initial_state is not None
|
||||
assert initial_state.shape == (trading_env_instance.state_dim,)
|
||||
|
||||
def test_env_step_normal(trading_env_instance):
|
||||
# Test a normal step (step 0 is not imputed)
|
||||
initial_pos = trading_env_instance.current_position
|
||||
action = 0.7
|
||||
next_state, reward, done, info = trading_env_instance.step(action)
|
||||
|
||||
assert trading_env_instance.current_step == 1
|
||||
assert trading_env_instance.current_position == action # Position updates to action
|
||||
assert not info['is_imputed_step_skipped']
|
||||
assert not done
|
||||
assert next_state is not None
|
||||
# Reward calculation is complex, just check type/sign if needed
|
||||
assert isinstance(reward, float)
|
||||
|
||||
def test_env_step_imputed_skip(trading_env_instance, sample_env_data):
|
||||
# Step 2 is imputed in sample_env_data
|
||||
trading_env_instance.step(0.5) # Step 0
|
||||
trading_env_instance.step(0.6) # Step 1
|
||||
assert trading_env_instance.current_step == 2
|
||||
initial_pos_before_imputed = trading_env_instance.current_position
|
||||
|
||||
# Action for the imputed step (should be ignored by 'skip')
|
||||
action_imputed = 0.9
|
||||
next_state, reward, done, info = trading_env_instance.step(action_imputed)
|
||||
|
||||
# Should skip step 2 and now be at step 3
|
||||
assert trading_env_instance.current_step == 3
|
||||
# Position should NOT have changed from step 1
|
||||
assert trading_env_instance.current_position == initial_pos_before_imputed
|
||||
assert reward == 0.0 # Skip gives 0 reward
|
||||
assert not done
|
||||
assert info['is_imputed_step_skipped'] == True # Crucial check for buffer
|
||||
# Check that the returned state is for step 3
|
||||
expected_state_step_3 = trading_env_instance._get_state() # Get state now that we are at step 3
|
||||
np.testing.assert_array_almost_equal(next_state, expected_state_step_3)
|
||||
|
||||
def test_env_step_imputed_hold(sample_env_data, base_env_config):
|
||||
cfg = base_env_config.copy()
|
||||
cfg.sac.imputed_handling = 'hold'
|
||||
env = TradingEnv(**sample_env_data, config=cfg)
|
||||
|
||||
# Step 2 is imputed
|
||||
env.step(0.5) # Step 0
|
||||
env.step(0.6) # Step 1
|
||||
assert env.current_step == 2
|
||||
position_before_imputed = env.current_position
|
||||
|
||||
# Action for the imputed step (should be overridden by 'hold')
|
||||
action_imputed = -0.5
|
||||
next_state, reward, done, info = env.step(action_imputed)
|
||||
|
||||
# Should process step 2 and move to step 3
|
||||
assert env.current_step == 3
|
||||
# Position should be the same as before the step
|
||||
assert env.current_position == position_before_imputed
|
||||
assert not info['is_imputed_step_skipped']
|
||||
assert not done
|
||||
# Reward should be calculated based on holding the position
|
||||
expected_pnl = position_before_imputed * (np.exp(sample_env_data['actual_returns'][2]) - 1)
|
||||
expected_cost = 0 # No trade size if holding
|
||||
expected_penalty = 0 # No penalty in hold mode
|
||||
expected_raw_reward = expected_pnl - expected_cost - expected_penalty
|
||||
expected_scaled_reward = expected_raw_reward * cfg.environment.reward_scale
|
||||
assert np.isclose(reward, expected_scaled_reward)
|
||||
|
||||
def test_env_step_imputed_penalty(sample_env_data, base_env_config):
|
||||
cfg = base_env_config.copy()
|
||||
cfg.sac.imputed_handling = 'penalty'
|
||||
cfg.sac.action_penalty = 0.1 # Use a specific penalty for testing
|
||||
env = TradingEnv(**sample_env_data, config=cfg)
|
||||
|
||||
# Step 2 is imputed
|
||||
env.step(0.5) # Step 0
|
||||
env.step(0.6) # Step 1
|
||||
assert env.current_step == 2
|
||||
position_before_imputed = env.current_position # Should be 0.6
|
||||
|
||||
# Action for the imputed step
|
||||
action_imputed = -0.2
|
||||
next_state, reward, done, info = env.step(action_imputed)
|
||||
|
||||
# Should process step 2 and move to step 3
|
||||
assert env.current_step == 3
|
||||
# Position should update to the *agent's* action
|
||||
assert env.current_position == np.clip(action_imputed, -1.0, 1.0)
|
||||
assert not info['is_imputed_step_skipped']
|
||||
assert not done
|
||||
|
||||
# Reward calculation is ONLY the penalty
|
||||
expected_raw_reward = -cfg.sac.action_penalty * (action_imputed - position_before_imputed)**2
|
||||
expected_scaled_reward = expected_raw_reward * cfg.environment.reward_scale
|
||||
assert np.isclose(reward, expected_scaled_reward)
|
||||
|
||||
def test_env_done_condition(trading_env_instance, sample_env_data):
|
||||
n_steps = len(sample_env_data['actual_returns'])
|
||||
# Step through the environment
|
||||
done = False
|
||||
for i in range(n_steps):
|
||||
_, _, done, _ = trading_env_instance.step(np.random.uniform(-1, 1))
|
||||
if i < n_steps - 1:
|
||||
assert not done
|
||||
else:
|
||||
assert done # Should be done on the last step
|
||||
Loading…
x
Reference in New Issue
Block a user