diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..3c16940a --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +.venv +.git +.git.BAD +logs +results +models diff --git a/config.yaml b/config.yaml new file mode 100644 index 00000000..67157212 --- /dev/null +++ b/config.yaml @@ -0,0 +1,361 @@ +# GRU-SAC Predictor v3 Configuration File +# This file parameterizes all major components of the pipeline. + +pipeline: + description: "Configuration for the GRU-SAC trading predictor pipeline." + # Define stages to run, primarily for debugging/selective execution. + # stages_to_run: ["data", "features", "gru", "sac", "backtest", "aggregate"] # Example: Run all + +# --- Data Loading and Initial Processing --- +data: + ticker: "BTC-USDT" # Ticker symbol (adjust based on DataLoader capabilities) + exchange: "bnbspot" # Exchange name (adjust based on DataLoader) + interval: "1min" # Data interval (e.g., '1min', '5min', '1h') + start_date: "2024-09-19" # Start date for data loading (YYYY-MM-DD) + end_date: "2024-11-03" # End date for data loading (YYYY-MM-DD) + db_dir: '../data/crypto_market_data' # to database directory (relative to project root) + bar_frequency: "1T" # Added based on instructions + missing: # Added missing data section + strategy: "neutral" # drop | neutral | ffill | interpolate + max_gap: 60 # max consecutive missing bars allowed + interpolate: + method: "linear" + limit: 10 + + volatility_sampling: # Optional volatility-based downsampling + enabled: False + window: 30 # Window for volatility calculation (e.g., 30 minutes) + quantile: 0.5 # Quantile threshold for sampling (0.0 to 1.0) + + # Optional: Add parameters for data cleaning if needed + # e.g., max_nan_fill_gap: 5 + +# --- Feature Engineering --- +features: + # Parameters for FeatureEngineer.add_base_features + atr_window: 14 + rsi_window: 14 + adx_window: 14 + macd_fast: 12 + macd_slow: 26 + macd_signal: 9 + # Add parameters for other indicators (e.g., Chaikin, SVI, Volatility) if configurable + # chaikin_ad_window: 10 + # svi_window: 10 + # volatility_window: 14 # e.g., for a rolling std dev feature + + # Parameters for feature selection (used by FeatureEngineer.select_features) + # These might include method (e.g., 'correlation', 'mutual_info', 'lgbm'), thresholds, etc. + selection_method: "correlation" # Example + correlation_threshold: 0.02 # Example threshold for correlation-based selection + min_features_after_selection: 10 # Minimum number of features to keep + +# --- Data Splitting (Walk-Forward or Single Split) --- +walk_forward: + enabled: True # Set to False for a single train/val/test split based on ratios below + # Ratios are used if enabled=False OR for preparing data for the SAC environment (which uses val split) + split_ratios: + train: 0.6 + validation: 0.2 + # test ratio is inferred (1.0 - train - validation) + + # Settings used only if enabled=True + num_folds: 5 # Number of walk-forward folds. If <= 1, uses rolling window mode. + # --- Rolling Window Specific Settings (used if num_folds <= 1) --- + train_period_days: 25 # Length of the training period per fold + validation_period_days: 10 # Length of the validation period per fold + test_period_days: 10 # Length of the test period per fold + step_days: 7 # How many days to slide the window forward for the next fold (Recommendation: >= test_period_days / 2) + expanding_window: false # If true, train period grows; otherwise, it slides (Rolling Window only) + + # --- General Walk-Forward Settings --- + initial_offset_days: 14 # Optional: Skip days at the beginning before the first fold starts + purge_window_minutes: 0 # Optional: Drop training samples overlapping with val/test lookback (in minutes) + embargo_minutes: 0 # Optional: Skip minutes after train/val period ends before starting next period (in minutes) + final_holdout_period_days: 0 # Optional: Reserve days at the very end, excluded from all folds + min_fold_duration_days: 1 # Optional: Minimum total duration (days) required for a generated fold (train_start to test_end) + # --- Gap and Regime Settings --- + gap_threshold_minutes: 5 # Split data into chunks if gap > this threshold + min_chunk_days: 1 # Minimum duration (days) for a chunk to be considered for fold generation + regime: + enabled: true + indicator: volatility # e.g., 'volatility', 'trend_strength', 'rsi' + indicator_params: + window: 20 # Parameter for the chosen indicator (e.g., rolling window size) + quantiles: [0.33, 0.66] # Quantiles to define regime boundaries (e.g., [0.33, 0.66] for 3 regimes) + min_regime_representation_pct: 10 # Minimum % each regime must occupy in train/val/test periods + + # --- Drift-Triggered Retraining (Informational - Requires external implementation) --- + # drift: + # enable: false + # feature_list: ["close", "volume", "rsi"] # Example features to monitor + # p_threshold: 0.01 # Example drift detection threshold (e.g., for KS test p-value) + +# --- GRU Model Configuration --- +gru: + # Label Definition + train_gru: true + use_ternary: True # Use ternary (Up/Flat/Down) labels? If False, uses binary (Up/Down). + prediction_horizon: 5 # Lookahead period for target returns/labels (in units of 'data.interval') + flat_sigma_multiplier: 0.25 # 'k' factor for ternary flat label threshold (eps = k * rolling_std(fwd_ret)) + label_smoothing: 0.0 # Alpha for binary label smoothing (0.0 disables) + drop_imputed_sequences: true # Added based on instructions + + # Model Architecture (V3) - Used by GRUModelHandler.build_gru_model_v3 + gru_units: 96 # Number of units in GRU layer + attention_units: 16 # Number of units in MultiHeadAttention layer (set to 0 to disable) + dropout_rate: 0.1 # Dropout rate for GRU and Attention layers + learning_rate: 1e-4 # Learning rate for Adam optimizer + l2_reg: 1e-4 # L2 regularization factor for Dense layers + + # Loss Function Parameters (V3) - Used by GRUModelHandler.build_gru_model_v3 + focal_gamma: 2.0 # Gamma parameter for categorical focal loss (if use_ternary=True) + focal_label_smoothing: 0.1 # Label smoothing within focal loss calculation + huber_delta: 1.0 # Delta parameter for Huber loss (mu/return prediction) + loss_weight_mu: 0.3 # Weight for the mu/return prediction loss component + loss_weight_dir3: 1.0 # Weight for the direction prediction loss component + + # Training Parameters - Used by GRUModelHandler.train + lookback: 60 # Sequence length (timesteps) for GRU input + epochs: 25 # Maximum number of training epochs + batch_size: 128 # Training batch size + patience: 5 # Early stopping patience (epochs with no improvement in val_loss) + # early_stopping_monitor: "val_loss" # Monitor for early stopping (hardcoded in handler) + # training_shuffle: False # Whether to shuffle training data each epoch (hardcoded False) + + # Loading Control - Used by pipeline_stages.modelling.train_or_load_gru_fold + load_gru_model: + run_id: null # Set to a specific GRU pipeline run ID to load model/scaler from instead of training + fold_num: null # Optional: Specify fold number (e.g., 1, 2...). If null, handler might load best/last fold based on its internal logic. + +# --- Hyperparameter Tuning (Optuna/W&B) --- +hyperparameter_tuning: + gru: + sweep_enabled: False # Master switch to enable Optuna sweep for GRU + # If enabled=True, define sweep parameters here: + study_name: "gru_optimization" + direction: "minimize" # "minimize" val_loss or "maximize" val_accuracy + n_trials: 50 + inner_cv_splits: 3 # Number of inner folds for nested cross-validation + pruner: "median" # e.g., "median", "hyperband" + sampler: "tpe" # e.g., "tpe", "random" + search_space: + gru_units: { type: "int", low: 32, high: 128, step: 16 } + attention_units: { type: "int", low: 8, high: 64, step: 8 } + dropout_rate: { type: "float", low: 0.05, high: 0.3 } + learning_rate: { type: "loguniform", low: 1e-5, high: 1e-3 } + l2_reg: { type: "loguniform", low: 1e-5, high: 1e-3 } + loss_weight_mu: { type: "float", low: 0.1, high: 0.9 } + batch_size: { type: "categorical", choices: [64, 128, 256] } + +# --- Probability Calibration --- +calibration: + method: vector + optimize_edge_threshold: true + edge_threshold: 0.5 # Initial or fixed threshold if not optimizing + # Rolling calibration settings (if method requires) + rolling_window_size: 250 + rolling_min_samples: 50 + rolling_step: 50 + reliability_plot_bins: 10 # Number of bins for reliability plot + +# --- Soft Actor-Critic (SAC) Agent and Training --- +sac: + imputed_handling: "hold" # Added based on instructions + action_penalty: 0.05 # Added based on instructions + # Agent Hyperparameters - Used by SACTradingAgent.__init__ + gamma: 0.99 # Discount factor + tau: 0.005 # Target network update rate (polyak averaging) + actor_lr: 3e-4 # Learning rate for the actor network + critic_lr: 3e-4 # Learning rate for the critic networks + # Optional: LR Decay for actor/critic (if implemented in agent) + lr_decay_rate: 0.96 + decay_steps: 100000 + # Optional: Ornstein-Uhlenbeck noise parameters (if used) + ou_noise_stddev: 0.2 + alpha: 0.2 # Initial entropy temperature (used if alpha_auto_tune=False) + alpha_auto_tune: True # Enable automatic tuning of entropy temperature alpha + target_entropy: -1.0 # Target entropy for alpha tuning; -action_dim is common default (-1.0 for action_dim=1) + + # Training Loop Parameters - Used by SACTrainer._training_loop + total_training_steps: 100000 # Total steps for the SAC training loop + buffer_capacity: 1000000 # Maximum size of the replay buffer + batch_size: 256 # Batch size for sampling from replay buffer + start_steps: 10000 # Number of initial steps with random actions before training starts + update_after: 1000 # Number of steps to collect before first agent update + update_every: 50 # Perform agent updates every N steps + save_freq: 5000 # Save agent checkpoint every N steps + log_freq: 100 # Log training metrics (losses, Q-values) to TensorBoard every N steps + eval_freq: 5000 # Evaluate agent performance every N steps (requires evaluation logic) + + # Alpha (Entropy Temperature) Annealing - Used by SACTrainer._training_loop + alpha_anneal_start_step: 10000 # Step to start annealing alpha (if auto-tune enabled) + alpha_anneal_end_step: 50000 # Step to finish annealing alpha + initial_alpha: 0.2 # Alpha value before annealing starts + final_alpha: 0.01 # Target alpha value after annealing finishes + + # Prioritized Experience Replay (PER) - Used by SACTrainer / PrioritizedReplayBuffer + use_per: False # Enable PER? If False, uses standard uniform replay buffer. + # PER parameters (used only if use_per=True) + per_alpha: 0.6 # Priority exponent (how much prioritization). 0 = uniform. + per_beta_start: 0.4 # Initial importance sampling exponent (annealed to 1.0) + per_beta_frames: 100000 # Steps over which to anneal beta from beta_start to 1.0 + # Optional PER Alpha annealing (anneals the priority exponent alpha) + per_alpha_anneal_enabled: False + per_alpha_start: 0.6 + per_alpha_end: 0.4 + per_alpha_anneal_steps: 50000 + + # Oracle Seeding (Potentially deprecated/experimental) + oracle_seeding_pct: 0.0 # Percentage of buffer to pre-fill using heuristic policy + + # State Normalization - Used by SACTrainer + use_state_filter: True # Use MeanStdFilter for state normalization? + state_dim_fallback: 5 # Fallback state dim if cannot be inferred (e.g., from loaded agent metadata) + action_dim_fallback: 1 # Fallback action dim if cannot be inferred + + # Loading Control - Used by pipeline_stages.modelling.train_or_load_sac_fold + train_sac: True # Master switch: Train SAC agent? If False, attempts to load based on control flags. + +# --- SAC Agent Aggregation (Post Walk-Forward) --- +sac_aggregation: + enabled: True # Aggregate agents from multiple folds? + method: "average_weights" # Currently only 'average_weights' is supported + +# --- Trading Environment Simulation --- +environment: # Parameters passed to TradingEnv and Backtester + initial_capital: 10000.0 # Starting capital for simulation/backtest + transaction_cost: 0.0005 # Fractional cost per trade (e.g., 0.0005 = 0.05%) + # Reward shaping parameters (used within TradingEnv._calculate_reward) + reward_scale: 100.0 # Multiplier applied to the raw PnL reward + action_penalty_lambda: 0.0 # Penalty factor for action magnitude or changes (0 disables) + # Add other env parameters if needed (e.g., position limits, reward clipping) + +# --- Baseline Models --- +baselines: # Configuration for BaselineChecker + run_baseline1: True # Run Logistic Regression baseline? (Requires binary labels) + run_baseline2: False # Run placeholder/second baseline? + # Parameters for Logistic Regression (Baseline 1) + logistic_regression: + max_iter: 1000 + solver: "lbfgs" + random_state: 42 + val_subset_split_ratio: 0.2 # Internal split ratio used within baseline check + val_subset_shuffle: False # Shuffle for internal split? + ci_confidence_level: 0.95 # Confidence level for binomial test CI + + # Parameters for RandomForestClassifier (Baseline 2) + random_forest: + n_estimators: 100 # Number of trees + max_depth: 10 # Maximum depth of trees (None for unlimited) + min_samples_split: 2 # Minimum samples required to split an internal node + min_samples_leaf: 1 # Minimum number of samples required to be at a leaf node + random_state: 42 + n_jobs: -1 # Use all available CPU cores + # Use the same internal split and CI settings as LogReg for comparison + val_subset_split_ratio: 0.2 + val_subset_shuffle: False + ci_confidence_level: 0.95 + + # --- Ternary Baselines (run only if gru.use_ternary=True) --- # + run_baseline3: True # Run Multinomial Logistic Regression? + run_baseline4: False # Run Ternary Random Forest? + + # Parameters for Multinomial Logistic Regression (Baseline 3) + multinomial_logistic_regression: + max_iter: 1000 + solver: "lbfgs" + multi_class: "multinomial" # Explicitly set for clarity + random_state: 42 + # Use same internal split/CI settings + val_subset_split_ratio: 0.2 + val_subset_shuffle: False + ci_confidence_level: 0.95 + + # Parameters for Ternary RandomForestClassifier (Baseline 4) + ternary_random_forest: + n_estimators: 100 + max_depth: 10 + min_samples_split: 2 + min_samples_leaf: 1 + random_state: 42 + n_jobs: -1 + # Use same internal split/CI settings + val_subset_split_ratio: 0.2 + val_subset_shuffle: False + ci_confidence_level: 0.95 + +# --- Pipeline Validation Gates --- +validation_gates: # Thresholds checked at different stages to potentially halt the pipeline + # Binary Baseline Gates (used if gru.use_ternary=False) + run_baseline_check: True # Master switch for running *any* applicable baseline check + baseline1_min_ci_lb: 0.52 # Binary LR (Raw) CI LB threshold (internal split) + baseline2_min_ci_lb: 0.54 # Binary RF (Raw) CI LB threshold (internal split) + baseline1_edge_min_ci_lb: 0.60 # Binary LR (Edge-Filtered) CI LB threshold (validation set) + baseline2_edge_min_ci_lb: 0.62 # Binary RF (Edge-Filtered) CI LB threshold (validation set) + + # Ternary Baseline Gates (used if gru.use_ternary=True) + baseline3_min_ci_lb: 0.40 # Ternary LR (Raw) CI LB threshold (internal split, vs 1/3 chance) + baseline4_min_ci_lb: 0.42 # Ternary RF (Raw) CI LB threshold (internal split, vs 1/3 chance) + baseline3_edge_min_ci_lb: 0.57 # Ternary LR (Edge-Filtered) CI LB threshold (validation set) + baseline4_edge_min_ci_lb: 0.58 # Ternary RF (Edge-Filtered) CI LB threshold (validation set) + + gru_performance: # Checks on GRU validation predictions (after calibration) + enabled: True + min_edge_accuracy: 0.60 # Minimum accuracy using the optimized/configured edge threshold + max_brier_score: 0.24 # Maximum acceptable Brier score + + backtest: # Master switch for all backtest performance gates + enabled: True + backtest_performance: # Specific performance checks on the backtest results + enabled: True # Enable/disable Sharpe and Max DD checks specifically + min_sharpe_ratio: 1.2 # Minimum acceptable annualized Sharpe ratio + max_drawdown_pct: 15.0 # Maximum acceptable drawdown percentage (positive value) + +# --- Pipeline Control Flags --- +control: + generate_plots: True # Generate and save plots (learning curves, backtest summary, etc.)? + + # Loading specific models instead of training/running stages + # Note: train_gru and train_sac flags override these if both are set + # GRU Loading: see gru.load_gru_model section + # SAC Loading: Used if sac.train_sac=False + sac_load_run_id: null # Specify SAC Training Run ID (e.g., "sac_train_...") to load for backtesting + sac_load_step: 'final' # 'final' or specific step number checkpoint to load + + # Resuming SAC Training (Loads agent and potentially buffer state to continue training) + sac_resume_run_id: null # Specify SAC Training Run ID to resume from + sac_resume_step: 'final' # 'final' or step number checkpoint to resume from + +# --- Logging Configuration --- +logging: + console_level: "INFO" # Level for console output: DEBUG, INFO, WARNING, ERROR, CRITICAL + file_level: "DEBUG" # Level for file output: DEBUG, INFO, WARNING, ERROR, CRITICAL + log_to_file: True # Enable logging to a file? + # Log file path determined by IOManager: logs//pipeline.log + log_format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s' # Format string + log_date_format: '%Y-%m-%d %H:%M:%S' # Date format for logs + # Rotating File Handler settings (if log_to_file=True) + log_file_max_bytes: 10485760 # Max size in bytes (e.g., 10MB) before rotation + log_file_backup_count: 5 # Number of backup log files to keep + +# --- Output Artifacts Configuration --- +output: + base_dirs: # Base directories (relative to project root or absolute) + results: "results" + models: "models" + logs: "logs" + # Figure generation settings + figure_dpi: 150 # DPI for saved figures + figure_size: [16, 9] # Default figure size (width, height in inches) + figure_footer: "© GRU-SAC v3" # Footer text added to plots + plot_style: "seaborn-v0_8-darkgrid" # Matplotlib style sheet to use + # Plot-specific settings + reward_plot_smoothing_alpha: 0.2 # EMA alpha for SAC reward plot smoothing + # reliability_plot_bins: 10 # Defined under calibration section + + # IOManager settings + dataframe_save_format: "parquet_if_large" # "csv", "parquet", "parquet_if_large" + dataframe_max_csv_mb: 100 # Threshold (MB) for using Parquet if format is parquet_if_large + +# ... existing code ... \ No newline at end of file diff --git a/cuda-keyring_1.1-1_all.deb b/cuda-keyring_1.1-1_all.deb new file mode 100644 index 00000000..d0229418 Binary files /dev/null and b/cuda-keyring_1.1-1_all.deb differ diff --git a/gru_sac_predictor/.gitignore b/gru_sac_predictor/.gitignore new file mode 100644 index 00000000..b453dd02 --- /dev/null +++ b/gru_sac_predictor/.gitignore @@ -0,0 +1,59 @@ +# Ignore everything by default +* + +# Un-ignore specific files to track + +# Scripts +!scripts/aggregate_metrics.py +!scripts/run_validation.sh + +# Package initialization +!__init__.py +!src/__init__.py + +# Core source files +!src/backtester.py +!src/calibrator_vector.py +!src/baseline_checker.py +!src/calibrator.py +!src/calibrate.py +!src/data_loader.py +!src/gru_hyper_tuner.py +!src/feature_engineer.py +!src/features.py +!src/gru_model_handler.py +!src/io_manager.py +!src/logger_setup.py +!src/metrics.py +!src/sac_agent.py +!src/sac_trainer.py +!src/trading_env.py +!src/trading_pipeline.py + +# Tests +!tests/ +!tests/*.py +!tests/**/*.py + +# Configuration files +!config.yaml +!config_baseline.yaml + +# Documentation and logs +!README.md +!requirements.txt +!revisions.txt +!main_v7.log + +# Entry points +!run.py +!train_sac_runner.py + +# Git configuration +!.gitignore + +# Make sure parent directories are un-ignored for nesting to work +!src/ +!scripts/ + +*.txt \ No newline at end of file diff --git a/gru_sac_predictor/README.md b/gru_sac_predictor/README.md index 532c6f7e..0bff0242 100644 --- a/gru_sac_predictor/README.md +++ b/gru_sac_predictor/README.md @@ -1,143 +1,372 @@ -# GRU + Simplified SAC Trading Agent +# GRU-SAC Trading Predictor -This project implements a cryptocurrency trading system using a GRU model for price prediction and a **Simplified SAC (Soft Actor-Critic)** agent for position sizing. +This project implements a multi-stage machine learning pipeline for predicting financial market movements (specifically cryptocurrency price direction) and generating trading signals. It combines a Gated Recurrent Unit (GRU) network for initial prediction with a Soft Actor-Critic (SAC) reinforcement learning agent for refining trading decisions. The pipeline features advanced walk-forward validation, hyperparameter tuning, probability calibration, baseline model comparisons, and configurable validation gates. -The system predicts future *price* using a GRU model adapted from the V6 architecture. It calculates the *predicted percentage return* from this price prediction and estimates prediction *uncertainty* based on the standard deviation of Monte Carlo dropout predictions. It also extracts recent *momentum* and *volatility* features. These values, along with a risk proxy (`z_proxy`), form the **5-dimensional state** input (`[predicted_return, mc_unscaled_std_dev, z_proxy, momentum_5, volatility_20]`) to the SAC reinforcement learning agent, which determines optimal position sizing (-1 to +1) using a **squashed Gaussian policy** and **automatic entropy tuning**. +## Key Features -The system incorporates efficiency improvements by pre-computing GRU predictions and uncertainties before generating SAC experiences or running the backtest. It includes detailed backtesting, performance reporting, and visualization capabilities, including **SAC training loss plots**. +* **Data Handling:** Loads tick/minute-level data, preprocesses (resampling, missing value imputation), and handles large datasets efficiently. +* **Feature Engineering:** Generates a comprehensive set of technical indicators (ATR, RSI, MACD, etc.) and allows for feature selection. +* **Advanced Walk-Forward Validation:** Implements a flexible `FoldGenerator` supporting: + * Contiguous data chunk splitting based on time gaps. + * Regime awareness (e.g., volatility regimes) ensuring folds are representative. + * N-Fold Block Splitting. + * Rolling/Expanding Window validation. + * Minimum fold duration checks. + * Purging and embargo periods. +* **GRU Model (v3):** + * Predicts future price movement direction (Binary or Ternary) and expected return magnitude (`mu`). + * Architecture includes GRU layer, optional Multi-Head Attention with causal masking, Layer Normalization, and separate output heads. + * Uses Huber loss for `mu` and Categorical Focal Loss for direction prediction. + * Supports optional Nested Cross-Validation using Optuna for hyperparameter tuning within each walk-forward fold. +* **Soft Actor-Critic (SAC) Agent:** + * Trains offline using the GRU predictions as part of the state. + * Aims to learn an optimal trading policy (position sizing/timing). + * Features automatic entropy tuning (alpha), optional Prioritized Experience Replay (PER), and state normalization. +* **Probability Calibration:** Calibrates GRU output probabilities using Vector Scaling (for ternary) or Temperature Scaling (for binary - *currently assumes ternary/vector*). Includes optimization for the decision edge threshold. +* **Baseline Models:** Compares GRU performance against standard baselines (Logistic Regression, Random Forest) for both binary and ternary classification tasks. +* **Validation Gates:** Implements configurable checks at multiple pipeline stages (Baseline performance, GRU performance, Backtest performance, Final Release) to ensure model quality and potentially halt execution. +* **Incremental Fold Reporting:** Generates detailed JSON reports for each fold, logging the status and results of each validation gate step, facilitating debugging even if a fold fails mid-way. +* **Modularity & Configuration:** Highly configurable via `config.yaml`. Pipeline stages are modular functions. +* **Artifact Management:** Uses an `IOManager` (assumed) to handle saving/loading of data, models, results, logs, and figures in a structured run-specific manner. -## System Design +## Pipeline Flow Diagram -The system integrates a GRU predictor and a Simplified SAC agent within a backtesting framework. +```mermaid +graph TD + subgraph Global_Setup ["Global Setup"] + direction LR + A[Load Config] --> B["Init IOManager"]; + B --> C["Init Pipeline Components: DataLoader, FeatureEng, GRUHandler, etc."]; + C --> D[Load Full Raw Data]; + D --> E["Init FoldGenerator w/ Data"]; + end -### 1. Data Flow & Processing + subgraph Fold_Loop ["Fold Loop"] + direction TB + F[Generate Folds via FoldGenerator] --> G{Iterate Folds}; + G -- Fold Data --> H[Start Fold Processing]; -1. **Loading:** Raw 1-minute OHLCV data is loaded from the SQLite database directory specified in `main.py` (e.g., `downloaded_data/`) using `src.data_pipeline.load_data_from_db` which utilizes `src.crypto_db_fetcher.CryptoDBFetcher`. -2. **Splitting:** Data is chronologically split into training (60%), validation (20%), and test (20%) sets using `src.data_pipeline.create_data_pipeline`. -3. **GRU Training / Loading (on Train/Validation Sets):** - * If `TRAIN_GRU_MODEL` is `True`: - * *Preprocessing*: `TradingSystem._preprocess_data_for_gru_training` calculates V6 features plus basic return features (`calculate_v6_features`) on the raw train/val data. It determines the future *price* target (`prediction_horizon` steps ahead) and aligns features, targets (prices), and the *unscaled* starting close prices needed for return calculation. - * *Scaling*: Within `TradingSystem.train_gru`, a `StandardScaler` is fitted *only* on the training features. A `MinMaxScaler` is fitted *only* on the training future *price* targets. Train and validation features/targets are scaled using these fitted scalers. - * *Sequence Creation*: `src.data_pipeline.create_sequences_v2` creates input sequences `(batch, sequence_length, num_features)` and corresponding scaled target prices using the scaled features/targets and the unscaled start prices. - * *Model Training*: `CryptoGRUModel.train` builds the V6-style GRU model (if not already built) and trains it using Mean Squared Error (MSE) loss on the scaled sequences. Callbacks monitor `val_rmse` for early stopping and model checkpointing. The best model (`best_model_reg.keras`) and the fitted scalers (`feature_scaler.joblib`, `y_scaler.joblib`) are saved. - * If `LOAD_EXISTING_SYSTEM` is `True` and `TRAIN_GRU_MODEL` is `False`: - * Attempts to load a pre-trained GRU model and scalers. If `GRU_MODEL_LOAD_RUN_ID` is set in `main.py`, it loads the GRU from that specific run ID's directory (`gru_sac_predictor/models/run_`); otherwise, it attempts to load from the default `MODEL_SAVE_PATH` (expecting the model and scalers to be directly in that path). - * **Note:** SAC model loading is handled *separately* based on the `LOAD_SAC_AGENT` flag and the `GRU_MODEL_LOAD_RUN_ID` setting (see Model Loading/Training section in `main.py` for details). -4. **SAC Training (on Validation Set):** - * **Training Loop:** The training process runs for a fixed number of epochs (`SAC_EPOCHS`). - * **Experience Generation** (`TradingSystem.generate_trading_experiences`): - * **Efficiency:** Pre-computes all required GRU outputs (predicted returns, uncertainties) for the entire validation set by calling `CryptoGRUModel.evaluate` *once*. - * **State Extraction:** Extracts pre-computed GRU outputs and relevant features (`momentum_5`, `volatility_20`) from the validation features dataframe. - * **Experience Format:** Iterates through the pre-computed results. Forms the 5D state `s_t = [pred_return_t, uncertainty_t, z_proxy_t, momentum_5_t, volatility_20_t]` (where `z_proxy` uses the position *before* the action). The SAC agent (`SimplifiedSACTradingAgent.get_action`) provides a *non-deterministic* action `a_t` and `log_prob`. The next state `s_{t+1}` is constructed similarly (using `action` for `z_proxy`). A reward `r_t = action * actual_return - cost` is calculated. The transition `(s_t, a_t, r_t, s_{t+1}, done)` is stored. - * **Note:** Experience sampling strategies (recency bias, stratification) defined in `experience_config` are currently *not* implemented in `generate_trading_experiences` but the configuration remains. - * **Agent Training** (`TradingSystem.train_sac` calls `SimplifiedSACTradingAgent.train`): Iterates for `SAC_EPOCHS`. In each epoch, the agent performs one training step. Batches are sampled from the replay buffer. Actor and Critic networks are updated using the SAC algorithm with automatic alpha tuning. Agent uses `store_transition` to add experiences to its internal NumPy buffer. - * **History Plotting:** After successful training, `plot_sac_training_history` is called to generate and save a plot of actor and critic losses. -5. **Backtesting (on Test Set):** - * *Pre-computation* (`ExtendedBacktester.backtest`): Preprocesses test data, scales, creates sequences, calls `CryptoGRUModel.evaluate` once for GRU outputs, and extracts required features (`momentum_5`, `volatility_20`). - * *State Generation*: Constructs the 5D state `s_t = [pred_return, uncertainty, z_proxy, momentum_5, volatility_20]` using pre-computed results and the current position. - * *Action Selection*: The trained `SimplifiedSACTradingAgent` selects a *deterministic* action `a_t` (unpacking the tuple returned by `get_action`). - * *Portfolio Simulation*: Calculates PnL based on the previous position, actual return, and transaction costs. - * *Logging*: Records detailed metrics, trade history, and timestamps. -6. **Evaluation:** - * *Performance Metrics*: `ExtendedBacktester._calculate_performance_metrics` computes overall portfolio metrics (Sharpe, Sortino, Drawdown, correlations, etc.) and Buy & Hold benchmark metrics. - * *Visualization*: `ExtendedBacktester.plot_results` generates a 3-panel plot: GRU Predictions vs Actual Price (with uncertainty), SAC Actions (Position Size), and Portfolio Value vs Buy & Hold (with trade markers). - * *Reporting*: `ExtendedBacktester.generate_performance_report` creates a detailed Markdown report. + subgraph Fold_Processing ["Fold N Processing"] + direction TB + I[Engineer Features] --> J[Define Labels & Align]; + J --> K[Split Data: Train/Val/Test]; + K --> L[Scale Features]; + L --> M["Coarse Filter Features (Optional)"]; + M --> N[Run Initial Baseline Checks]; + N --> N_Report["Record Baseline 1 Report"]; + N_Report --> N_Gate{Baseline Gate 1 Passed?}; -### 2. Core Components & Inputs/Outputs + N_Gate -- Yes --> O[Select & Prune Features]; + O --> P[Run Post-Pruning Baseline Checks]; + P --> P_Report["Record Baseline 2 Report"]; + P_Report --> P_Gate{Baseline Gate 2 Passed?}; -* **`src.crypto_db_fetcher.CryptoDBFetcher`**: Loads and resamples data from SQLite DBs. -* **`src.data_pipeline`**: Functions for DB loading, data splitting, sequence creation. -* **`src.trading_system.calculate_v6_features`**: Calculates features (TA-Lib based V6 set + past returns). -* **`src.trading_system._preprocess_data_for_gru_training`**: Prepares features, future price targets, and start prices. -* **`src.gru_predictor.CryptoGRUModel`**: (V6 Adaptation) - * `train()`: Trains the GRU price prediction model. Saves model (`.keras`) and scalers (`.joblib`). - * `evaluate()`: Performs standard prediction and MC dropout inference. Returns dict including `pred_percent_change`, `mc_unscaled_std_dev`, `predicted_unscaled_prices`, `true_unscaled_prices`. -* **`src.sac_agent_simplified.SimplifiedSACTradingAgent`**: (V7 Simplified) - * **Goal:** Learns a policy mapping state to optimal position size (-1.0 to +1.0). Optimized for faster training. - * **State Input:** 5-element array `[predicted_return, mc_unscaled_std_dev, z_proxy, momentum_5, volatility_20]`. - * **Action Output:** Float between -1.0 and +1.0. - * `get_action()`: Selects action (stochastic or deterministic). Adds uncertainty-scaled noise during exploration. - * `store_transition()`: Adds experience to internal NumPy buffer. - * `train()`: Updates agent using buffer samples (internally handles batch size). Uses `@tf.function` for performance. - * `save()` / `load()`: Handles Actor/Critic weights (`.weights.h5`), potentially `alpha.npy`. - * **Note:** Models and optimizers are built explicitly during `__init__` using dummy inputs to prevent TensorFlow graph mode issues. -* **`src.trading_system.TradingSystem`**: Integrates GRU and SAC. Manages training pipelines, feature calculation, experience generation. -* **`src.trading_system.ExtendedBacktester`**: Performs efficient backtesting using pre-computed GRU outputs, calculates metrics, plots results, generates reports. -* **`src.trading_system.plot_sac_training_history`**: Generates plot for SAC actor/critic losses during training. + P_Gate -- Yes --> Q[Update Scaled Data = Pruned]; + Q --> R[Create Sequences]; + R --> S["Train/Load GRU w/ Nested CV?"]; + S --> T[Calibrate Probabilities]; + T --> U[Run GRU Validation Checks]; + U --> U_Report["Record GRU Validation Report"]; + U_Report --> U_Gate{GRU Gate Passed?}; -### 3. Model Architectures + U_Gate -- Yes --> V[Train/Load SAC Agent]; + V --> V_Report["Record SAC Status Report"]; + V_Report --> W[Run Backtest Simulation]; + W --> W_Report["Record Backtest Report & Perf Gate Check"]; + %% Checks Perf Gates Internally + W_Report --> X[Store Fold Metrics]; + X --> Y[Run Forward Baseline Check]; + Y --> Y_Gate{Forward Baseline Gate Passed?}; + Y_Gate -- Yes --> Z[Fold Success]; -* **GRU (`src.gru_predictor.CryptoGRUModel._build_model`)**: V6 Architecture. - * Input -> GRU(100) -> Dropout(0.2) -> Dense(1, linear). - * Compiled with Adam (LR=0.001), MSE loss. -* **Simplified SAC (`src.sac_agent_simplified.SimplifiedSACTradingAgent`)**: - * **Actor Network**: MLP `(state_dim=5)` -> Dense(64, relu) -> [BN] -> Dense(64, relu) -> [BN] -> [Residual] -> Dense(1, name='mu'), Dense(1, name='log_std'). Output is `mu` and `log_std` for a **Gaussian policy**. `log_std` is clipped. - * **Critic Network (x2)**: MLP `(state_dim=5 + action_dim=1)` -> Dense(64, relu) -> [BN] -> Dense(64, relu) -> [BN] -> [Residual] -> Dense(1, linear). - * **Algorithm**: Implements SAC with Clipped Double-Q, **automatic entropy tuning** (optimizing `alpha` based on `target_entropy`), squashed actions (`tanh`), faster learning rates, smaller networks/buffer, optional Batch Normalization / Residual connections. Uses Huber loss for critics. `@tf.function` used for update steps (`_update_critics`, `_update_actor_and_alpha`). + %% Failure Paths within Fold (Leading to AA: Fold Failed) + N_Gate -- No --> AA[Fold Failed]; + P_Gate -- No --> AA; + U_Gate -- No --> AA; + Y_Gate -- No --> AA; + %% Note: Backtest Perf Gate Fail doesn't halt fold here, just logged in report + W_Report -- Error during Backtest --> AA; + %% Handle explicit errors -### 4. Features & State Representation + end -* **GRU Features:** Uses the V6 feature set plus basic past returns (see `calculate_v6_features`). Cyclical time features (`hour_sin`, `hour_cos`) are added *before* data splitting. -* **SAC State (`state_dim=5`):** - 1. `predicted_return`: GRU predicted percentage return for the next period. - 2. `uncertainty`: GRU MC dropout standard deviation (unscaled). - 3. `z_proxy`: Risk proxy, calculated as `current_position * volatility_20`. - 4. `momentum_5`: 5-minute return (`return_5m` feature). - 5. `volatility_20`: 20-day volatility (`volatility_14d` feature, name mismatch intended). -* **Scaling:** Features for GRU scaled with `StandardScaler`. Target price for GRU scaled with `MinMaxScaler`. SAC state components are used directly without separate scaling. + H --> I; + %% Start processing for the fold + %% Save Report in Finally Block (Implied - happens before moving to AB) + Z -- Save Report --> AB{More Folds?}; + AA -- Save Report --> AB; + AB -- Yes --> G; + G -- No More Folds --> AC[End Fold Loop]; -### 5. Evaluation + end -* **GRU Model:** Evaluated using RMSE loss on validation set. Callbacks monitor `val_rmse`. Plots compare predicted vs actual price. -* **SAC Agent & Overall System:** Evaluated via the `ExtendedBacktester` metrics (Sharpe, Sortino, Max Drawdown, correlations, etc.), plots (Portfolio vs B&H, Actions), and a final Markdown report. SAC training progress monitored via saved loss plots (`sac_training_history_.png`). + subgraph Final_Aggregation ["Final Aggregation"] + direction TB + AC --> AD["Aggregate Fold Metrics"]; + AC --> AE["Aggregate SAC Agents (Optional)"]; + AD --> AF["Final Release Decision"]; + end -## File Structure + E --> F; + %% Fold Generation Starts after Global Setup + AF --> AG[End Pipeline]; -- `downloaded_data/`: **Place your SQLite database files here.** (Or update `DB_DIR` in `main.py`). -- `gru_sac_predictor/`: Project root directory. - - `models/`: Trained models saved here under `run_/` directories. - - `results/`: Backtest results saved here under `/` directories. - - `logs/`: Log files saved here under `/` directories. - - `src/`: Core Python modules. - - `crypto_db_fetcher.py` - - `data_pipeline.py` - - `gru_predictor.py` - - `sac_agent_simplified.py` - - `trading_system.py` - - `main.py`: Main script. - - `requirements.txt` - - `README.md` + style Fold_Processing fill:#f9f,stroke:#333,stroke-width:2px +``` -## Setup +## Installation / Setup -1. **Data:** Place your V6 `downloaded_data` directory containing the SQLite files relative to the `gru_sac_predictor` project root, or update the `DB_DIR` variable in `main.py` to point to the correct location. -2. **Dependencies:** Install required packages: +*(Assuming standard Python environment management)* + +1. **Clone the repository:** + ```bash + git clone + cd gru_sac_predictor + ``` +2. **Create and activate a virtual environment:** (Recommended) + ```bash + python -m venv venv + source venv/bin/activate # or venv\Scripts\activate on Windows + ``` +3. **Install dependencies:** ```bash pip install -r requirements.txt ``` - *Strongly Recommended:* Install TA-Lib for the full feature set. See TA-Lib installation guides for your OS. -3. **Configuration:** Review and adjust parameters in `main.py`. Key parameters include: - * `DB_DIR`, `TICKER`, `EXCHANGE`, `START_DATE`, `END_DATE`, `INTERVAL` - * Model hyperparameters (GRU and SAC sections) - * Control Flags: `LOAD_EXISTING_SYSTEM`, `TRAIN_GRU_MODEL`, `TRAIN_SAC_AGENT`, `LOAD_SAC_AGENT` - * Loading Specific Models: `GRU_MODEL_LOAD_RUN_ID` (set to a specific run ID string like `'YYYYMMDD_HHMMSS'` to load *only* the GRU model from `gru_sac_predictor/models/run_/`). SAC loading depends on `LOAD_SAC_AGENT` flag. - * SAC Training: `SAC_EPOCHS` defines the number of training epochs. - * Experience Generation: `experience_config` dictionary (sampling strategies currently not implemented). - * Backtesting: `INITIAL_CAPITAL`, `TRANSACTION_COST`. -4. **Run:** Execute from the project root directory (the one *containing* `gru_sac_predictor`): - ```bash - python -m gru_sac_predictor.main - ``` - Output files (logs, models, plots, report) will be generated in `gru_sac_predictor/logs/`, `gru_sac_predictor/models/`, and `gru_sac_predictor/results/` within run-specific subdirectories. + *(Ensure `requirements.txt` includes `tensorflow`, `pandas`, `numpy`, `scikit-learn`, `pyyaml`, `optuna`, `joblib`, `matplotlib`, `seaborn`, `tqdm` etc.)* +4. **Data Setup:** Ensure the crypto market database exists at the location specified by `data.db_dir` in `config.yaml`. The `DataLoader` expects a specific database structure (details should be added based on `DataLoader` implementation). -## Reporting +## Configuration (`config.yaml`) -The report generated by the `ExtendedBacktester` includes performance metrics, correlation analyses, and configuration details. Key metrics include: +The pipeline's behavior is primarily controlled by `config.yaml`. Below is a detailed breakdown of key sections and parameters: -* Total/Annualized Return -* Sharpe & Sortino Ratios -* Volatility & Max Drawdown -* Buy & Hold Comparison -* Position/Prediction Accuracy -* Prediction/Position/Uncertainty Correlations -* Total Trades \ No newline at end of file +**`pipeline`:** +* `description`: Textual description of the configuration. +* `stages_to_run`: (Optional) List of stages to execute (e.g., `["data", "features", "gru"]`) for partial runs. Defaults to all stages if omitted. + +**`data`:** +* `ticker`, `exchange`, `interval`, `start_date`, `end_date`: Parameters for data loading via `DataLoader`. +* `db_dir`: **Crucial**. Path to the database directory (relative to project root or absolute). Default: '../data/crypto_market_data'. +* `bar_frequency`: Target frequency for resampling (e.g., "1T" for 1 minute). Used in preprocessing. +* `missing`: Configuration for handling missing data points after resampling. + * `strategy`: Method ('drop', 'neutral', 'ffill', 'interpolate'). Default: 'neutral'. + * `max_gap`: Maximum consecutive missing bars allowed before applying `strategy` or potentially dropping (depends on strategy). Default: 60. + * `interpolate`: Settings if `strategy: "interpolate"`. Includes `method` and `limit`. +* `volatility_sampling`: (Optional) Configuration for volatility-based downsampling. + * `enabled`: Boolean flag. Default: `False`. + * `window`: Window for volatility calculation. Default: 30. + * `quantile`: Quantile threshold for sampling. Default: 0.5. + +**`features`:** +* Parameters for base feature calculation in `FeatureEngineer.add_base_features` (e.g., `atr_window`, `rsi_window`, `macd_fast`, etc.). Defaults shown in `config.yaml`. +* Parameters for feature selection (`FeatureEngineer.select_features`). + * `selection_method`: Method used (e.g., "correlation"). Default: "correlation". + * `correlation_threshold`: Threshold for correlation-based selection. Default: 0.02. + * `min_features_after_selection`: Minimum features to retain. Default: 10. + * `coarse_univariate_quantile`: (Used in `TradingPipeline` before baseline) Quantile threshold for a coarse univariate filter based on correlation with the target before main feature selection. Keeps features *above* this quantile. Default: 0.70 (i.e., keep top 30%). + +**`walk_forward`:** Configuration for data splitting and fold generation (`FoldGenerator`). +* `enabled`: Master switch for walk-forward vs. single split. Default: `True`. +* `split_ratios`: Used if `enabled: False`. Defines train/validation ratios for a single split. Test ratio is inferred. Defaults: `train: 0.6`, `validation: 0.2`. + **(Walk-Forward Settings - used if `enabled: True`)** + * `num_folds`: Number of folds. If > 1, uses N-Fold Block splitting. If <= 1, uses Rolling Window mode. Default: 5. + **(Rolling Window Specific - used if `num_folds <= 1`)** + * `train_period_days`, `validation_period_days`, `test_period_days`: Duration of each period in days. Defaults: 25, 10, 10. + * `step_days`: How many days to slide the window forward. Default: 7. + * `expanding_window`: If `True`, training period grows; otherwise, it slides. Default: `false`. + **(General Walk-Forward Settings)** + * `initial_offset_days`: Skip days at the beginning before the first fold starts. Default: 14. + * `purge_window_minutes`: Drop training samples potentially overlapping with validation/test based on lookahead/lookback (implementation details needed). Default: 0. + * `embargo_minutes`: Gap added after train/val periods before starting the next. Default: 0. + * `final_holdout_period_days`: Reserve days at the very end, excluded from all folds. Default: 0. + * `min_fold_duration_days`: Minimum *total* duration (train_start to test_end) for a fold to be considered valid. Default: 1. + **(Gap and Regime Settings)** + * `gap_threshold_minutes`: Split data into contiguous chunks if time gap exceeds this. Default: 5. + * `min_chunk_days`: Minimum duration for a chunk to be used for fold generation. Default: 1. + * `regime`: Configuration for regime-aware fold generation. + * `enabled`: Enable regime filtering. Default: `true`. + * `indicator`: Indicator used for regimes (e.g., 'volatility'). Default: 'volatility'. + * `indicator_params`: Parameters for the chosen indicator (e.g., `window`). Default: `window: 20`. + * `quantiles`: Quantiles to define regime boundaries. Default: `[0.33, 0.66]` (for 3 regimes). + * `min_regime_representation_pct`: Minimum percentage each defined regime must occupy in *each* period (train, val, test) of a candidate fold for it to be yielded. Default: 10. + +**`gru`:** Configuration for the GRU model (`GRUModelHandler`, label generation). +* `train_gru`: Master switch to train a new GRU model vs. loading one. Default: `true`. +* `use_ternary`: Use ternary (Up/Flat/Down) labels instead of binary (Up/Down). Default: `True`. +* `prediction_horizon`: Lookahead period (in units of `data.interval`) for target calculation. Default: 5. +* `flat_sigma_multiplier`: 'k' factor for ternary flat label threshold (ε = k * rolling_std(fwd_ret)). Used if `use_ternary: True`. Default: 0.25. +* `label_smoothing`: Alpha for *binary* label smoothing (0.0 disables). Default: 0.0. (Note: Focal loss has its own smoothing). +* `drop_imputed_sequences`: If `True`, sequences containing any bar marked as imputed (`bar_imputed` feature) are dropped. Default: `true`. + **(Model Architecture v3 - `build_gru_model_v3`)** + * `gru_units`: GRU layer units. Default: 96. + * `attention_units`: Attention layer units (set <= 0 to disable). Default: 16. + * `dropout_rate`: Dropout rate for GRU/Attention. Default: 0.1. + * `learning_rate`: Adam optimizer learning rate. Default: 1e-4. + * `l2_reg`: L2 regularization factor for Dense layers. Default: 1e-4. +* **(Loss Function v3 - `build_gru_model_v3`)** + * `focal_gamma`: Gamma for categorical focal loss (if `use_ternary: True`). Default: 2.0. + * `focal_label_smoothing`: Label smoothing within focal loss (if `use_ternary: True`). Default: 0.1. + * `huber_delta`: Delta for Huber loss (mu/return prediction). Default: 1.0. + * `loss_weight_mu`: Weight for the mu/return loss component. Default: 0.3. + * `loss_weight_dir3`: Weight for the direction loss component. Default: 1.0. +* **(Training Parameters - `GRUModelHandler.train`)** + * `lookback`: Sequence length (timesteps) for GRU input. Default: 60. + * `epochs`: Maximum training epochs. Default: 25. + * `batch_size`: Training batch size. Default: 128. + * `patience`: Early stopping patience (epochs). Default: 5. +* **(Loading Control - `train_or_load_gru_fold`)** + * `load_gru_model`: Section to specify loading a pre-trained model. + * `run_id`: Specific GRU pipeline *run ID* to load from. If `null`, training occurs (if `train_gru: True`). Default: `null`. + * `fold_num`: Specific fold number to load. If `null`, loader might try to find 'best'/last based on internal logic. Default: `null`. + +**`hyperparameter_tuning.gru`:** Configuration for Optuna Nested CV (`train_or_load_gru_fold`). +* `sweep_enabled`: Master switch to enable/disable Optuna sweep. Default: `False`. +* `study_name`: Name for the Optuna study. Default: "gru_optimization". +* `direction`: "minimize" (val_loss) or "maximize" (val_accuracy). Default: "minimize". +* `n_trials`: Number of hyperparameter combinations to try. Default: 50. +* `inner_cv_splits`: Number of inner folds for TimeSeriesSplit within the training data. Default: 3. +* `pruner`: Optuna pruner ('median', 'hyperband', etc.). Default: "median". +* `sampler`: Optuna sampler ('tpe', 'random', etc.). Default: "tpe". +* `search_space`: Dictionary defining parameters to tune, their types (`int`, `float`, `categorical`, `loguniform`), and ranges/choices. Defaults shown in `config.yaml`. + +**`calibration`:** Configuration for probability calibration (`calibrate_probabilities_fold`). +* `method`: Calibration method ('vector', 'temperature', etc.). Default: 'vector'. (Note: Implementation assumes 'vector' for ternary, 'temperature' for binary). +* `optimize_edge_threshold`: Whether to optimize the edge threshold using Youden's J on validation predictions. Default: `true`. +* `edge_threshold`: Initial/fixed edge threshold if not optimizing. Default: 0.5. + **(Rolling Calibration - Used if implemented within Calibrator classes)** + * `rolling_window_size`, `rolling_min_samples`, `rolling_step`: Parameters for rolling calibration methods. Defaults: 250, 50, 50. +* `reliability_plot_bins`: Number of bins for reliability plot generation. Default: 10. + +**`sac`:** Configuration for the SAC agent (`SACTradingAgent`) and trainer (`SACTrainer`). +* `imputed_handling`: How SAC environment handles steps with imputed features ('hold', 'skip', etc. - requires env implementation). Default: "hold". +* `action_penalty`: Penalty applied during SAC training for large/frequent actions (requires env implementation). Default: 0.05. + **(Agent Hyperparameters - `SACTradingAgent`)** + * `gamma`, `tau`, `actor_lr`, `critic_lr`, `alpha`, `alpha_auto_tune`, `target_entropy`: Standard SAC hyperparameters. Defaults shown in `config.yaml`. + * `lr_decay_rate`, `decay_steps`: Optional learning rate decay parameters. + * `ou_noise_stddev`: Optional Ornstein-Uhlenbeck noise parameter (if used). +* **(Training Loop - `SACTrainer`)** + * `total_training_steps`, `buffer_capacity`, `batch_size`, `start_steps`, `update_after`, `update_every`, `save_freq`, `log_freq`, `eval_freq`: Parameters controlling the SAC training loop. Defaults shown in `config.yaml`. +* **(Alpha Annealing - `SACTrainer`)** + * `alpha_anneal_start_step`, `alpha_anneal_end_step`, `initial_alpha`, `final_alpha`: Parameters for annealing the entropy temperature `alpha` (if `alpha_auto_tune: True`). Defaults shown in `config.yaml`. +* **(Prioritized Experience Replay (PER) - `SACTrainer`/Buffer)** + * `use_per`: Enable PER. Default: `False`. + * `per_alpha`, `per_beta_start`, `per_beta_frames`: PER parameters (used if `use_per: True`). Defaults shown in `config.yaml`. + * `per_alpha_anneal_enabled`, `per_alpha_start`, `per_alpha_end`, `per_alpha_anneal_steps`: Optional annealing for PER alpha. +* `oracle_seeding_pct`: (Experimental) Percentage of buffer to pre-fill using heuristic. Default: 0.0. + **(State Normalization - `SACTrainer`)** + * `use_state_filter`: Use MeanStdFilter for state normalization. Default: `True`. + * `state_dim_fallback`, `action_dim_fallback`: Fallback dimensions if they cannot be inferred (e.g., from loaded agent). Defaults: 5, 1. +* `train_sac`: Master switch to train a new SAC agent vs. loading one for the fold. Default: `True`. + +**`sac_aggregation`:** Configuration for aggregating SAC agents after walk-forward (`aggregate_sac_agents`). +* `enabled`: Enable aggregation. Default: `True`. +* `method`: Aggregation method ('average_weights' currently supported). Default: "average_weights". + +**`environment`:** Parameters passed to the `TradingEnv` (used by `SACTrainer` and `Backtester`). +* `initial_capital`: Starting capital. Default: 10000.0. +* `transaction_cost`: Fractional cost per trade. Default: 0.0005 (0.05%). +* `reward_scale`: Multiplier applied to raw PnL reward. Default: 100.0. +* `action_penalty_lambda`: Penalty factor for action magnitude/changes in reward calculation (0 disables). Default: 0.0. + +**`baselines`:** Configuration for `BaselineChecker`. +* `run_baseline1`, `run_baseline2`, `run_baseline3`, `run_baseline4`: Enable specific baselines (Logistic Regression, Random Forest for binary/ternary). Defaults: True/False/True/False. + Sections for each baseline (`logistic_regression`, `random_forest`, `multinomial_logistic_regression`, `ternary_random_forest`) containing model hyperparameters (e.g., `max_iter`, `n_estimators`), internal validation split (`val_subset_split_ratio`, `val_subset_shuffle`), and CI confidence level (`ci_confidence_level`). Defaults shown in `config.yaml`. + +**`validation_gates`:** Thresholds checked at different stages to potentially halt the pipeline fold. +* `run_baseline_check`: Master switch for *all* baseline checks. Default: `True`. + Baseline Gates (`baseline_min_ci_lb`, `baseline_edge_min_ci_lb`): Minimum acceptable Confidence Interval Lower Bound for baseline accuracy (raw and edge-filtered) for both binary and ternary models. Defaults shown in `config.yaml`. +* `gru_performance`: Gates applied after GRU calibration (`run_gru_validation_checks_fold`). + * `enabled`: Enable these GRU checks. Default: `True`. + * `min_edge_accuracy`: Minimum edge-filtered accuracy on validation set. Default: 0.60. + * `max_brier_score`: Maximum acceptable Brier score on validation set. Default: 0.24. +* `backtest`: Master switch for *all* backtest performance gates. Default: `True`. +* `backtest_performance`: Specific gates applied after backtesting (`run_backtest_fold`). + * `enabled`: Enable Sharpe and Max Drawdown checks. Default: `True`. + * `min_sharpe_ratio`: Minimum acceptable annualized Sharpe ratio. Default: 1.2. + * `max_drawdown_pct`: Maximum acceptable drawdown percentage. Default: 15.0. +* `final_release`: (Used in `TradingPipeline.final_release_decision`) Criteria applied to *aggregated* metrics across all folds. + * `min_successful_folds_pct`: Minimum required percentage of fully successful folds. Default: 0.75. + * `median_sharpe_threshold`: Minimum median Sharpe ratio across successful folds. Default: 1.3. + * `max_drawdown_max_threshold`: Maximum allowable *worst-case* drawdown across all successful folds. Default: 20.0. + +**`control`:** Flags controlling pipeline execution flow. +* `generate_plots`: Generate and save plots (learning curves, backtest summary, etc.). Default: `True`. +* `sac_load_run_id`: Specify SAC Training *Run ID* to load for backtesting (used if `sac.train_sac: False`). Default: `null`. +* `sac_load_step`: Step/checkpoint to load ('final' or step number). Default: 'final'. +* `sac_resume_run_id`, `sac_resume_step`: Parameters for resuming a previous SAC training run (implementation status needs verification). Defaults: `null`, 'final'. + +**`logging`:** Configuration for Python's `logging` module. +* `console_level`, `file_level`: Logging levels (DEBUG, INFO, WARNING, ERROR, CRITICAL). Defaults: INFO, DEBUG. +* `log_to_file`: Enable file logging. Default: `True`. Log file path determined by `IOManager`. +* `log_format`, `log_date_format`: Formatting strings. +* `log_file_max_bytes`, `log_file_backup_count`: Parameters for `RotatingFileHandler`. Defaults: 10MB, 5 backups. + +**`output`:** Configuration for saving artifacts and plots (`IOManager`). +* `base_dirs`: Base directories for `results`, `models`, `logs`. Defaults: "results", "models", "logs". +* `figure_dpi`, `figure_size`, `figure_footer`, `plot_style`: Matplotlib settings for saved figures. Defaults shown in `config.yaml`. +* `reward_plot_smoothing_alpha`: EMA alpha for smoothing SAC reward plot. Default: 0.2. +* `dataframe_save_format`: Format for saving DataFrames ('csv', 'parquet', 'parquet_if_large'). Default: "parquet_if_large". +* `dataframe_max_csv_mb`: Size threshold (MB) for using Parquet if format is 'parquet_if_large'. Default: 100. + +## Pipeline Stages + +The pipeline executes the following stages, primarily coordinated by `TradingPipeline.execute`: + +1. **Load & Preprocess Data (`load_and_preprocess_data` -> `pipeline_stages.data_processing.load_and_preprocess`)** + * **Purpose:** Load raw market data, clean, resample, and handle missing values. + * **Logic:** Uses `DataLoader` to fetch data based on `config['data']`. Applies resampling to `data.bar_frequency`. Imputes missing values based on `data.missing` config. Optionally performs volatility sampling (`data.volatility_sampling`). Calculates basic returns (`return_1m`). Adds `bar_imputed` feature. + * **Inputs:** `config['data']`. + * **Outputs:** `self.df_raw` (DataFrame with preprocessed data and DatetimeIndex). + +2. **Initialize Fold Generator (`TradingPipeline.__init__`)** + * **Purpose:** Prepare the fold generation mechanism. + * **Logic:** Instantiates `FoldGenerator` with the full `self.df_raw`. If `walk_forward.regime.enabled` is true, it calls internal helper `_add_regime_tags` to calculate and add the `regime_tag` column to an internal copy (`self.df_raw_tagged`) based on `walk_forward.regime` parameters. Stores `self.regime_enabled` status. + * **Inputs:** `config['walk_forward']`, `self.df_raw`. + * **Outputs:** `self.fold_generator` (initialized instance), `self.regime_enabled`. + +3. **Generate Folds & Loop (`TradingPipeline.execute` -> `FoldGenerator.generate_folds`)** + * **Purpose:** Iterate through walk-forward fold definitions. + * **Logic:** Calls `self.fold_generator.generate_folds()`. This generator internally: + * Splits `self.df_raw_tagged` into contiguous chunks based on `walk_forward.gap_threshold_minutes`. + * Skips chunks shorter than `walk_forward.min_chunk_days`. + * For each valid chunk, calls `FoldGenerator._generate_folds_for_chunk` which implements single split, N-block, or rolling window logic based on `walk_forward.enabled` and `walk_forward.num_folds`. + * Applies `walk_forward.min_fold_duration_days` check to each candidate fold. + * If `self.regime_enabled`, checks regime balance using `walk_forward.regime.min_regime_representation_pct` for each period (train, val, test) in the candidate fold before yielding. + * **Inputs:** `config['walk_forward']`, `self.df_raw_tagged` (internal to generator). + * **Outputs:** Iterator yielding tuples of fold dates `(train_start, train_end, val_start, val_end, test_start, test_end)`. + +4. **--- Inside Fold Loop (`TradingPipeline.execute`) ---** + + a. **Engineer Features (`engineer_features` -> `pipeline_stages.data_processing.engineer_features_for_fold`)** + * **Purpose:** Calculate technical indicators for the current fold's raw data slice. + * **Logic:** Takes the `current_fold_data_raw` slice. Calls `FeatureEngineer.add_all_features` (or similar) based on `config['features']` parameters (windows, etc.). + * **Inputs:** `current_fold_data_raw`, `config['features']`. + * **Outputs:** `df_engineered_fold` (DataFrame with features for the fold). + + b. **Define Labels & Align (`define_labels_and_align` -> `pipeline_stages.data_processing.define_labels_and_align_fold`)** + * **Purpose:** Generate target labels (return, direction) and align features/targets. + * **Logic:** Calculates forward returns (`fwd_log_ret_N`) based on `gru.prediction_horizon`. Generates direction labels (binary or ternary based on `gru.use_ternary`, using `gru.flat_sigma_multiplier` for ternary). Aligns features (X) and targets (y) by dropping NaNs introduced by lookaheads/lookbacks. Stores required target column names. Calculates volatility (`eps`) used for ternary labels. + * **Inputs:** `df_engineered_fold`, `config['gru']`. + * **Outputs:** `df_labeled_aligned_fold`, `self.target_dir_col`, `self.target_columns`, `self.fwd_returns_aligned`, `self.eps_aligned`. + + c. **Split Data (`split_data` -> `pipeline_stages.data_processing.split_data_fold`)** + * **Purpose:** Split the aligned data for the current fold into train, validation, and test sets based on the dates yielded by `FoldGenerator`. + * **Logic:** Slices `df_labeled_aligned_fold` using the `fold_dates` tuple. Extracts corresponding X (features) and y (targets). Stores original unsplit data slices for reference. Extracts ordinal direction labels. Extracts corresponding slices of forward returns and epsilon values needed for baseline checks. + * **Inputs:** `df_labeled_aligned_fold`, `fold_dates`, `self.fwd_returns_aligned`, `self.eps_aligned`, `self.target_columns`, `self.target_dir_col`. + * **Outputs:** `self.X_train_raw`, `self.X_val_raw`, `self.X_test_raw`, `self.y_train`, `self.y_val`, `self.y_test`, `self.df_train_original`, `self.df_val_original`, `self.df_test_original`, `self.y_dir_train_ordinal`, `self.y_dir_val_ordinal`, `self.fwd_ret_train`, `self.eps_train`, `self.fwd_ret_val`, `self.eps_val`. + + d. **Scale Features (`scale_features` -> `pipeline_stages.feature_processing.scale_features_fold`)** + * **Purpose:** Apply standardization (StandardScaler) to features. + * **Logic:** Fits `StandardScaler` on `self.X_train_raw`, then transforms train, validation, and test sets. + * **Inputs:** `self.X_train_raw`, `self.X_val_raw`, `self.X_test_raw`. + * **Outputs:** `self.scaler` (fitted scaler object), `self.X_train_scaled`, `self.X_val_scaled`, `self.X_test_scaled`. + + e. **Coarse Univariate Filter (`TradingPipeline.execute`)** + * **Purpose:** Optional initial filtering step before baseline checks. + * **Logic:** Calculates absolute correlation between each feature in `self.X_train_scaled` and the aligned `self.y_dir_train_ordinal`. Keeps features with correlation >= the quantile specified by `features.coarse_univariate_quantile`. Applies this filter to create `X_train_coarse` and `X_val_coarse`. + * **Inputs:** `self.X_train_scaled`, `self.X_val_scaled`, `self.y_dir_train_ordinal`, `config['features']`. + * **Outputs:** `X_train_coarse`, `X_val_coarse` (or `None` if filter fails/removes all features). + + f. **Initial Baseline Check (`run_baseline_checks` -> `pipeline_stages.evaluation.run_baseline_checks_fold`)** + * **Purpose:** Run baseline models on scaled (or coarsely filtered) data as an early validation gate. + * **Logic:** Calls the stage function, passing scaled data (or `X_*_coarse` if available), ordinal labels, forward returns, and epsilon values. The stage function uses `BaselineChecker` to train and evaluate baselines specified in `config['baselines']`. Compares performance (CI lower bounds) against thresholds in `config['validation_gates']`. Generates label histogram plot. + * **Inputs:** Scaled features (`X_train_scaled` \ No newline at end of file diff --git a/gru_sac_predictor/__init__.py b/gru_sac_predictor/__init__.py new file mode 100644 index 00000000..0519ecba --- /dev/null +++ b/gru_sac_predictor/__init__.py @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/gru_sac_predictor/requirements.txt b/gru_sac_predictor/requirements.txt index ae8121a4..a422f9d5 100644 --- a/gru_sac_predictor/requirements.txt +++ b/gru_sac_predictor/requirements.txt @@ -1,10 +1,18 @@ -pandas -numpy -tensorflow -tensorflow-probability +pandas==2.1.0 +numpy==1.26.0 # Or newer +tensorflow==2.18.0 # Upgrade to TF 2.18 +tf-keras==2.18.0 # Match TF version +tensorflow-probability==0.25.0 # Matches TF >= 2.18 requirement matplotlib joblib scikit-learn tqdm PyYAML -TA-Lib \ No newline at end of file +# TA-Lib C library wrapper (requires libta-lib-dev installed) +# TA-Lib +ta # Use pure Python ta library +# tensorflow-addons==0.23.0 # Removed - incompatible +scipy +pytest +statsmodels # Added for VIF calculation +torch # Added for SAC \ No newline at end of file diff --git a/gru_sac_predictor/run.py b/gru_sac_predictor/run.py new file mode 100644 index 00000000..bd67e8b3 --- /dev/null +++ b/gru_sac_predictor/run.py @@ -0,0 +1,129 @@ +import argparse +import logging +import os +import sys + +# Ensure the src directory is in the Python path +script_dir = os.path.dirname(os.path.abspath(__file__)) +project_root = os.path.dirname(script_dir) +if project_root not in sys.path: + sys.path.insert(0, project_root) + +# Import necessary components AFTER setting up path +from src.logger_setup import setup_logger # Correct function import +from src.io_manager import IOManager +from src.utils.run_id import make_run_id, get_git_sha # Import Git SHA function +from src.trading_pipeline import TradingPipeline # Keep pipeline import + +# --- Define Version --- # +__version__ = "3.0.0-dev" + +# --- Config Loading Helper --- # +def load_config(config_path: str) -> dict: + """Helper to load YAML config.""" + import yaml + # Logic similar to TradingPipeline._load_config, but simplified for entry point + if not os.path.isabs(config_path): + # Try relative to current dir first, then project root + potential_path = os.path.abspath(config_path) + if not os.path.exists(potential_path): + potential_path = os.path.join(project_root, config_path) + if os.path.exists(potential_path): + config_path = potential_path + else: + print(f"ERROR: Config file not found at '{config_path}' (tried CWD and project root).", file=sys.stderr) + sys.exit(1) + + try: + with open(config_path, 'r') as f: + config = yaml.safe_load(f) + if not isinstance(config, dict): + raise TypeError("Config file did not parse as a dictionary.") + print(f"Config loaded ✓ ({config_path})") # Log before full logger setup + return config + except Exception as e: + print(f"ERROR: Failed to load or parse config file '{config_path}': {e}", file=sys.stderr) + sys.exit(1) + +# --- Main Execution Block --- # +def main(): + """Main execution function: parses args, sets up, runs pipeline.""" + parser = argparse.ArgumentParser(description="Run the GRU-SAC Trading Pipeline.") + # Default config path seeking strategy + default_config_rel_root = os.path.join(project_root, 'config.yaml') + default_config_pkg = os.path.join(project_root, 'gru_sac_predictor', 'config.yaml') + default_config_cwd = os.path.abspath('config.yaml') + + if os.path.exists(default_config_rel_root): + default_config = default_config_rel_root + elif os.path.exists(default_config_pkg): + default_config = default_config_pkg + else: + default_config = default_config_cwd + + parser.add_argument( + '--config', type=str, default=default_config, + help=f"Path to the configuration YAML file (default attempts relative to project root, package dir, or CWD)" + ) + parser.add_argument( + '--use-ternary', + action='store_true', + help="Enable ternary (up/flat/down) direction labels instead of binary." + ) + args = parser.parse_args() + + # 1. Generate Run ID and Get Git SHA + run_id = make_run_id() + git_sha = get_git_sha(short=False) or "unknown" + + # 2. Load Config first + try: + config = load_config(args.config) # Load config dictionary + except Exception as e: + # Error message handled within load_config + sys.exit(1) # Exit if config loading fails + + # 3. Setup IOManager (passing loaded config dict) + try: + io = IOManager(cfg=config, run_id=run_id) # Pass config dict, not path + # Add git_sha as an attribute AFTER initialization + io.git_sha = git_sha + except Exception as e: + print(f"ERROR: Failed to initialize IOManager: {e}") + sys.exit(1) + + # 4. Setup Logger (using path from IOManager) + logger = setup_logger(cfg=config, run_id=run_id, io=io) # Pass config dict here too (use cfg=) + + # Log Banner + logger.info("="*80) + logger.info(f" GRU-SAC Predictor {__version__} | Commit: {git_sha[:8]} | Run: {run_id}") + logger.info(f" Config File: {os.path.basename(args.config)}") + logger.info("="*80) + + # 5. Modify config based on CLI args (if any) + if args.use_ternary: + if 'gru' not in config: config['gru'] = {} # Ensure 'gru' section exists + config['gru']['label_type'] = 'ternary' # Override label type + logger.warning("CLI override: Using ternary labels (--use-ternary).") + + # 6. Initialize and Run Pipeline + logger.info("Initializing TradingPipeline...") + try: + # Pass the loaded (and potentially modified) config dictionary directly + pipeline = TradingPipeline(config=config, io_manager=io) + logger.info("TradingPipeline initialized. Starting execution...") + pipeline.execute() + logger.info("--- Pipeline Execution Finished ---") + except SystemExit as e: + logger.critical(f"Pipeline halted prematurely by SystemExit: {e}") + sys.exit(1) + except ValueError as e: + logger.critical(f"Pipeline initialization failed with ValueError: {e}") + sys.exit(1) + except Exception as e: + logger.critical(f"An unexpected error occurred during pipeline execution: {e}", exc_info=True) + sys.exit(1) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/gru_sac_predictor/scripts/aggregate_metrics.py b/gru_sac_predictor/scripts/aggregate_metrics.py new file mode 100644 index 00000000..c406d962 --- /dev/null +++ b/gru_sac_predictor/scripts/aggregate_metrics.py @@ -0,0 +1,136 @@ +#!/usr/bin/env python +""" +Aggregate metrics from the latest performance_metrics.txt file found via a pattern +and perform final validation checks based on Sharpe Ratio and Max Drawdown. + +Ref: revisions.txt Section 8 +""" + +import argparse +import glob +import os +import re +import sys +import logging + +# Setup logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +METRIC_SHARPE = "Annualized Sharpe Ratio (Re-centred)" +METRIC_MAX_DD = "Max Drawdown (%)" + +def parse_metric_value(line: str, metric_name: str) -> float | None: + """Extracts float value from a 'metric_name: value' line.""" + # Match lines like "Metric Name: 12.3456" or "Metric Name (%): 12.3456" + # Handles potential variations in spacing and optional % sign near name + pattern = rf"^\s*{re.escape(metric_name)}\s*\(?%?\)?\s*:\s*(-?\d*\.?\d+)" + match = re.search(pattern, line, re.IGNORECASE) + if match: + try: + return float(match.group(1)) + except ValueError: + logger.warning(f"Could not convert value '{match.group(1)}' to float for metric '{metric_name}'.") + return None + +def main(): + parser = argparse.ArgumentParser( + description="Parse latest metrics file and validate Sharpe/Max Drawdown." + ) + parser.add_argument( + 'metrics_pattern', + type=str, + help='Glob pattern for performance_metrics.txt files (e.g., "results/*/performance_metrics*.txt")' + ) + parser.add_argument( + '--min_sharpe', type=float, default=1.2, + help='Minimum acceptable Annualized Sharpe Ratio (Re-centred).' + ) + parser.add_argument( + '--max_drawdown_pct', type=float, default=15.0, + help='Maximum acceptable Max Drawdown Percentage.' + ) + args = parser.parse_args() + + logger.info(f"Searching for metrics files using pattern: {args.metrics_pattern}") + try: + metrics_files = sorted(glob.glob(args.metrics_pattern)) + except Exception as e: + logger.error(f"Error during glob pattern expansion '{args.metrics_pattern}': {e}") + sys.exit(1) + + if not metrics_files: + logger.error(f"No metrics files found matching pattern: {args.metrics_pattern}") + sys.exit(1) + + latest_file = metrics_files[-1] + logger.info(f"Processing latest metrics file: {latest_file}") + + sharpe_value = None + max_dd_value = None + + try: + with open(latest_file, 'r') as f: + for line in f: + # Use the robust parsing function + if sharpe_value is None: + parsed_sharpe = parse_metric_value(line, METRIC_SHARPE) + if parsed_sharpe is not None: + sharpe_value = parsed_sharpe + + if max_dd_value is None: + parsed_dd = parse_metric_value(line, METRIC_MAX_DD) + if parsed_dd is not None: + max_dd_value = parsed_dd + + # Stop reading if both metrics found + if sharpe_value is not None and max_dd_value is not None: + break + + except FileNotFoundError: + logger.error(f"Could not find file: {latest_file}") + sys.exit(1) + except Exception as e: + logger.error(f"Error reading or parsing file {latest_file}: {e}") + sys.exit(1) + + # --- Perform Checks --- # + checks_passed = True + fail_reasons = [] + + logger.info(f"Extracted Metrics: Sharpe={sharpe_value}, Max Drawdown={max_dd_value}%") + + # Check Sharpe Ratio + if sharpe_value is None: + logger.error(f"'{METRIC_SHARPE}' not found or could not be parsed.") + fail_reasons.append("Sharpe ratio missing/unparseable") + checks_passed = False + elif sharpe_value < args.min_sharpe: + logger.error(f"VALIDATION FAIL: Sharpe Ratio ({sharpe_value:.3f}) is below threshold ({args.min_sharpe:.3f})") + fail_reasons.append(f"Sharpe ({sharpe_value:.3f}) < {args.min_sharpe:.3f}") + checks_passed = False + else: + logger.info(f"VALIDATION PASS: Sharpe Ratio ({sharpe_value:.3f}) >= {args.min_sharpe:.3f}") + + # Check Max Drawdown + if max_dd_value is None: + logger.error(f"'{METRIC_MAX_DD}' not found or could not be parsed.") + fail_reasons.append("Max Drawdown missing/unparseable") + checks_passed = False + elif max_dd_value > args.max_drawdown_pct: + logger.error(f"VALIDATION FAIL: Max Drawdown ({max_dd_value:.2f}%) exceeds threshold ({args.max_drawdown_pct:.2f}%)") + fail_reasons.append(f"Max Drawdown ({max_dd_value:.2f}%) > {args.max_drawdown_pct:.2f}%") + checks_passed = False + else: + logger.info(f"VALIDATION PASS: Max Drawdown ({max_dd_value:.2f}%) <= {args.max_drawdown_pct:.2f}%") + + # --- Exit Status --- # + if checks_passed: + logger.info("All final metric checks passed.") + sys.exit(0) + else: + logger.error(f"One or more final metric checks failed: {', '.join(fail_reasons)}") + sys.exit(1) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/gru_sac_predictor/scripts/run_validation.sh b/gru_sac_predictor/scripts/run_validation.sh new file mode 100644 index 00000000..062a81e1 --- /dev/null +++ b/gru_sac_predictor/scripts/run_validation.sh @@ -0,0 +1,60 @@ +#!/bin/bash + +# Validation Checklist Script for GRU-SAC Predictor v3 +# Ref: revisions.txt Section 8 + +# Exit immediately if a command exits with a non-zero status. +set -e + +# --- Configuration --- # +# Define paths relative to the script location or project root? +# Assuming script is run from project root (e.g., ./scripts/run_validation.sh) +PROJECT_ROOT="." +CONFIG_DIR="${PROJECT_ROOT}/configs" +TEST_CONFIG_DIR="${PROJECT_ROOT}/tests" +SCRIPTS_DIR="${PROJECT_ROOT}/scripts" +RESULTS_DIR="${PROJECT_ROOT}/results" +SRC_DIR="${PROJECT_ROOT}/src" # Or wherever run.py lives relative to root + +SMOKE_CONFIG="${TEST_CONFIG_DIR}/smoke.yaml" +VAL_CONFIG="${CONFIG_DIR}/quick_val.yaml" +METRICS_AGG_SCRIPT="${SCRIPTS_DIR}/aggregate_metrics.py" +PYTHON_EXEC="python" + +# --- Validation Steps --- # + +echo "[Validation Step 1/4] Running Unit Tests..." +pytest -q ${PROJECT_ROOT}/tests/ + +echo "\n[Validation Step 2/4] Running Smoke Test..." +# Assume run.py is executable from project root +${PYTHON_EXEC} ${SRC_DIR}/run_pipeline.py --config ${SMOKE_CONFIG} + +echo "\n[Validation Step 3/4] Running Quick Validation Training & Backtest..." +${PYTHON_EXEC} ${SRC_DIR}/run_pipeline.py --config ${VAL_CONFIG} \ + --train_gru true --train_sac true \ + --run_backtest true \ + --use_v3 true # Ensure v3 model is tested if needed + # Add other relevant CLI overrides if necessary for validation + +echo "\n[Validation Step 4/4] Aggregating and Checking Metrics..." +# Check if aggregate script exists +if [ ! -f "${METRICS_AGG_SCRIPT}" ]; then + echo "ERROR: Metrics aggregation script not found at ${METRICS_AGG_SCRIPT}" >&2 + echo "Skipping metrics aggregation and final checks." >&2 + exit 0 # Exit gracefully for now, but ideally should fail? +fi + +# Aggregate results from the validation run (or potentially all runs?) +# Need to determine how to target the specific run or use a pattern +# Assuming results are in subdirs under RESULTS_DIR +METRICS_PATTERN="${RESULTS_DIR}/*/performance_metrics*.txt" + +${PYTHON_EXEC} ${METRICS_AGG_SCRIPT} ${METRICS_PATTERN} + +# The Python script aggregate_metrics.py should contain the logic +# to parse the metrics files and exit with a non-zero status if +# the final checks fail (Sharpe < 1.2 or Max DD > 15%). + +echo "\nValidation Checklist Completed Successfully." +exit 0 \ No newline at end of file diff --git a/gru_sac_predictor/src/__init__.py b/gru_sac_predictor/src/__init__.py index 0519ecba..3a2391ef 100644 --- a/gru_sac_predictor/src/__init__.py +++ b/gru_sac_predictor/src/__init__.py @@ -1 +1,36 @@ - \ No newline at end of file +""" +GRU-SAC Predictor Package +""" + +import os +import logging +from datetime import datetime + +# --- Versioning and Build Info (Task 0.3) --- # +__version__ = "3.0.0-dev" # Placeholder version + +# Attempt to get Git SHA using the utility function +try: + # Need to adjust path if run_id is in utils + from .utils.run_id import get_git_sha + GIT_SHA = get_git_sha(short=False) or "unknown" +except ImportError: + logging.warning("Could not import get_git_sha from utils. GIT_SHA set to 'unknown'.") + GIT_SHA = "unknown" +except Exception as e: + logging.warning(f"Error getting git sha for package info: {e}") + GIT_SHA = "unknown" + +# Placeholder for build date (could be set during build process) +BUILD_DATE = datetime.now().strftime("%Y-%m-%d %H:%M:%S UTC") +# --- End Versioning --- # + +# Configure logging for the package? +# Or assume it's configured by the entry point (run.py) +# Setting up a null handler to avoid "No handler found" warnings if no +# configuration is done by the application. +logging.getLogger(__name__).addHandler(logging.NullHandler()) + +# Expose key components (optional, depends on desired package structure) +# from .trading_pipeline import TradingPipeline +# from .sac_trainer import SACTrainer \ No newline at end of file diff --git a/gru_sac_predictor/src/backtester.py b/gru_sac_predictor/src/backtester.py new file mode 100644 index 00000000..29c676b9 --- /dev/null +++ b/gru_sac_predictor/src/backtester.py @@ -0,0 +1,809 @@ +""" +Backtesting Engine. + +Simulates trading strategy execution on historical test data, calculates +performance metrics, and generates reports and plots. +""" + +import os +import logging +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt +import matplotlib.dates as mdates +from sklearn.metrics import confusion_matrix, classification_report +import seaborn as sns +from typing import Dict, Any, Tuple, Optional, List +import inspect + +# Import required components (use absolute paths) +from gru_sac_predictor.src.sac_agent import SACTradingAgent +from gru_sac_predictor.src.gru_model_handler import GRUModelHandler +from gru_sac_predictor.src.calibrator import Calibrator +from gru_sac_predictor.src.calibrator_vector import VectorCalibrator # Import needed for type hints +from gru_sac_predictor.src.metrics import edge_filtered_accuracy, calculate_sharpe_ratio, calculate_brier_score +# --- Import Metrics (Task 6.4) --- +try: + from .metrics import edge_filtered_accuracy, calculate_sharpe_ratio, calculate_brier_score +except ImportError: + logging.error("Failed to import metrics. Sharpe, Brier and Edge Acc will be missing.") + # Define placeholders + def edge_filtered_accuracy(*args, **kwargs): return np.nan, 0 + def calculate_sharpe_ratio(*args, **kwargs): return np.nan + def calculate_brier_score(*args, **kwargs): return np.nan +# --- End Import --- # + +logger = logging.getLogger(__name__) + + +def calculate_sharpe_ratio(returns, periods_per_year=252*24*60): # Default for 1-min data + """Calculate annualized Sharpe ratio from a series of returns.""" + returns = pd.Series(returns) + if returns.std() == 0 or np.isnan(returns.std()): + return 0.0 if returns.mean() > 0 else (-np.inf if returns.mean() < 0 else 0.0) + # Assuming risk-free rate is 0 for simplicity + return np.sqrt(periods_per_year) * returns.mean() / returns.std() + +def calculate_max_drawdown(equity_curve): + """Calculate the maximum drawdown from an equity curve series.""" + equity_curve = pd.Series(equity_curve) + rolling_max = equity_curve.cummax() + drawdown = (equity_curve - rolling_max) / rolling_max + max_drawdown = drawdown.min() + return abs(max_drawdown) # Return positive value + +class Backtester: + """Runs the backtest simulation and generates results.""" + + def __init__(self, config: dict, io_manager: Optional[Any] = None): + """ + Initialize the Backtester. + Args: + config (dict): Pipeline configuration dictionary. + io_manager (Optional[Any]): IOManager instance for saving results. + """ + self.config = config + self.io = io_manager + self.env_cfg = config.get('environment', {}) + self.cal_cfg = config.get('calibration', {}) + self.sac_cfg = config.get('sac', {}) + + self.initial_capital = self.env_cfg.get('initial_capital', 10000.0) + self.transaction_cost = self.env_cfg.get('transaction_cost', 0.0005) + + # --- Store Rolling Calibration Config --- # + self.rolling_cal_enabled = self.cal_cfg.get('rolling_enabled', False) + self.recalibrate_every_n = self.cal_cfg.get('recalibrate_every_n', 5000) + self.recalibration_window = self.cal_cfg.get('recalibration_window', 20000) + self.coverage_alarm_enabled = self.cal_cfg.get('coverage_alarm_enabled', False) + self.coverage_alarm_threshold_drop = self.cal_cfg.get('coverage_alarm_threshold_drop', 0.03) + self.coverage_alarm_window = self.cal_cfg.get('coverage_alarm_window', 1000) + # --- End --- # + + self.results_df: Optional[pd.DataFrame] = None + self.metrics: Optional[Dict[str, Any]] = None + + logger.info("Backtester initialized.") + logger.info(f" Initial Capital: {self.initial_capital:.2f}") + logger.info(f" Transaction Cost: {self.transaction_cost*100:.4f}%") + if self.rolling_cal_enabled: + logger.info(" Rolling Calibration: Enabled") + logger.info(f" Recalibrate Every: {self.recalibrate_every_n} steps") + logger.info(f" Recalibration Window: {self.recalibration_window} steps") + if self.coverage_alarm_enabled: + logger.info(" Coverage Alarm: Enabled") + logger.info(f" Alarm Window: {self.coverage_alarm_window} steps") + logger.info(f" Alarm Threshold Drop: {self.coverage_alarm_threshold_drop:.3f}") + else: + logger.info(" Rolling Calibration: Disabled") + + self.ece_recalibration_threshold = self.cal_cfg.get('ece_recalibration_threshold', 0.03) + + def run_backtest( + self, + sac_agent_load_path: Optional[str], + X_test_seq: np.ndarray, + y_test_seq_dict: Dict[str, np.ndarray], + test_indices: pd.Index, + gru_handler: GRUModelHandler, + # --- Pass Calibrator Instances & Initial Params/Threshold --- # + calibrator: Optional[Calibrator], # For Temp Scaling + vector_calibrator: Optional[VectorCalibrator], # For Vector Scaling + initial_optimal_T: Optional[float], + initial_vector_params: Optional[np.ndarray], + fold_edge_threshold: float, # Optimized or fixed threshold from validation + # --- Pass Raw Predictions for Rolling Cal --- # + p_raw_test: Optional[np.ndarray] = None, + logits_test: Optional[np.ndarray] = None, + # --- End Args --- # + original_prices: Optional[pd.DataFrame] = None, # Pass DataFrame + is_ternary: bool = False, + fold_num: int = 0 # Added fold number + ) -> Tuple[Optional[pd.DataFrame], Optional[Dict[str, Any]], Optional[pd.DataFrame]]: + """ + Executes the backtesting simulation loop with optional rolling calibration. + + Args: + sac_agent_load_path: Path to load the SAC agent. + X_test_seq: Test set feature sequences. + y_test_seq_dict: Dict of test set targets ('ret', 'dir'/'dir3'). + test_indices: Timestamps corresponding to the test sequences' targets. + gru_handler: Instance to get GRU predictions (used only if raw preds not passed). + calibrator: Instance for Temperature scaling (if used). + vector_calibrator: Instance for Vector scaling (if used). + initial_optimal_T: Optimal temperature from fold validation. + initial_vector_params: Optimal vector params from fold validation. + fold_edge_threshold: Edge threshold determined during validation (optimized or fixed). + p_raw_test: Raw binary probabilities P(up) from GRU. Required if rolling_cal_enabled & !is_ternary. + logits_test: Raw logits (N, 3) from GRU. Required if rolling_cal_enabled & is_ternary. + original_prices: DataFrame with 'open', 'high', 'low', 'close', 'volume' aligned with test_indices. + is_ternary: Flag indicating ternary classification. + fold_num: Current fold number for logging. + + Returns: + Tuple[Optional[pd.DataFrame], Optional[Dict[str, Any]], Optional[pd.DataFrame]]: + - DataFrame with detailed backtest results per step. + - Dictionary containing calculated performance metrics. + - DataFrame with periodic metrics logged during the backtest. + Returns (None, None, None) if backtest cannot run. + """ + logger.info(f"--- Fold {fold_num}: Starting Backtest Simulation --- ") + + # --- Input Validation --- # + if X_test_seq is None or y_test_seq_dict is None or test_indices is None: + logger.error(f"Fold {fold_num}: Test sequence data (X, y, indices) is missing. Cannot run backtest.") + return None, None, None + if gru_handler.model is None: + logger.error(f"Fold {fold_num}: GRU model is not loaded in the handler. Cannot run backtest.") + return None, None, None + if original_prices is None: + logger.warning(f"Fold {fold_num}: Original prices DataFrame not provided. Some metrics/plots may be unavailable.") + if self.rolling_cal_enabled and not is_ternary and p_raw_test is None: + logger.error(f"Fold {fold_num}: Rolling calibration enabled for binary case, but p_raw_test not provided.") + return None, None, None + if self.rolling_cal_enabled and is_ternary and logits_test is None: + logger.error(f"Fold {fold_num}: Rolling calibration enabled for ternary case, but logits_test not provided.") + return None, None, None + # --- End Validation --- # + + # 1. Initialize SAC Agent & Load Weights + # Ensure agent state dim matches the state construction below + agent_state_dim = 5 # mu, sigma, edge, |mu|/sigma, position + + # --- Filter sac_cfg to only include valid SACTradingAgent.__init__ args --- + expected_args = set(inspect.signature(SACTradingAgent.__init__).parameters.keys()) + # Remove 'self' and potentially dims if handled separately + expected_args -= {'self', 'state_dim', 'action_dim'} + + filtered_sac_cfg = {k: v for k, v in self.sac_cfg.items() if k in expected_args} + logger.info(f"Filtered SAC config passed to agent: {filtered_sac_cfg.keys()}") + # --- End Filter --- + + agent = SACTradingAgent( + state_dim=agent_state_dim, + action_dim=1, + # Pass FILTERED relevant SAC params from config + **filtered_sac_cfg + ) + # Set the edge threshold for the agent (used in oracle seeding if enabled) + agent.edge_threshold_config = fold_edge_threshold + + if sac_agent_load_path and os.path.exists(sac_agent_load_path): + logger.info(f"Fold {fold_num}: Loading SAC agent weights from: {sac_agent_load_path}") + try: + agent.load(sac_agent_load_path) + except Exception as e: + logger.error(f"Fold {fold_num}: Failed to load SAC agent weights from {sac_agent_load_path}: {e}. Proceeding with untrained agent.", exc_info=True) + else: + logger.warning(f"Fold {fold_num}: SAC agent load path not found or not specified ({sac_agent_load_path}). Proceeding with untrained agent.") + + # 2. Get GRU Predictions (if not passed for rolling cal) AND Pre-calculate Logits/Probs if needed + mu_test_state, sigma_test_state = None, None + p_raw_fallback = None # For binary non-rolling case + p_softmax_test = None # For ternary non-rolling/non-vector case (fallback) + # Pre-calculate all test logits/raw_probs IF rolling cal is OFF to avoid step-by-step prediction + all_logits_test_precalc = None + all_p_raw_test_precalc = None + + if not self.rolling_cal_enabled: + logger.info(f"Fold {fold_num}: Rolling cal disabled. Pre-calculating necessary GRU outputs for test set...") + if is_ternary: + all_logits_test_precalc = gru_handler.predict_logits(X_test_seq) + # Also need mu for state construction + preds_test_mu = gru_handler.predict(X_test_seq) # Assumes predict returns mu as first item + if preds_test_mu is None: return None, None, None + mu_test_state = preds_test_mu[0].flatten() + sigma_test_state = np.ones_like(mu_test_state) * 0.01 # Placeholder sigma + if all_logits_test_precalc is None: + logger.error(f"Fold {fold_num}: Failed to pre-calculate logits needed for vector calibration.") + return None, None, None + else: + logger.info(f"Pre-calculated all_logits_test_precalc shape: {all_logits_test_precalc.shape}") + # Calculate softmax once if needed as fallback + import tensorflow as tf # Local import ok here + p_softmax_test = tf.nn.softmax(all_logits_test_precalc).numpy() + + else: # Binary case, non-rolling + preds_test = gru_handler.predict(X_test_seq) + if preds_test is None or len(preds_test) < 3: return None, None, None # Error logged inside + mu_test_state = preds_test[0].flatten() + log_sigma_test_state = preds_test[1][:, 1].flatten() if preds_test[1].shape[-1] == 2 else np.log(preds_test[1].flatten() + 1e-9) + sigma_test_state = np.exp(log_sigma_test_state) + all_p_raw_test_precalc = preds_test[2].flatten() # Pre-calculate raw probs + p_raw_fallback = all_p_raw_test_precalc # Use pre-calculated as fallback + logger.info(f"Pre-calculated all_p_raw_test_precalc shape: {all_p_raw_test_precalc.shape}") + + else: # Rolling calibration IS enabled, use passed-in logits/p_raw + logger.info(f"Fold {fold_num}: Rolling cal enabled. Using passed logits_test or p_raw_test.") + # Still need mu/sigma for state construction + if is_ternary: + # Logits were passed, get mu from separate prediction call + preds_test_mu = gru_handler.predict(X_test_seq) # Assumes predict returns mu as first item + if preds_test_mu is None: return None, None, None + mu_test_state = preds_test_mu[0].flatten() + sigma_test_state = np.ones_like(mu_test_state) * 0.01 # Placeholder sigma + else: # Binary case, rolling + preds_test_state = gru_handler.predict(X_test_seq) + if preds_test_state is None or len(preds_test_state) < 2: # Need at least mu, sigma + logger.error(f"Fold {fold_num}: Failed to get GRU state predictions (mu, log_sigma) when p_raw was provided for rolling cal.") + return None, None, None + mu_test_state = preds_test_state[0].flatten() + log_sigma_test_state = preds_test_state[1][:, 1].flatten() if preds_test_state[1].shape[-1] == 2 else np.log(preds_test_state[1].flatten() + 1e-9) + sigma_test_state = np.exp(log_sigma_test_state) + # p_raw_fallback is not needed when rolling cal is enabled, as p_raw_test is used directly later + # p_softmax_test is not needed when rolling cal is enabled if vector cal is used + + # Extract actual returns and directions + actual_ret_test = y_test_seq_dict.get('ret') + dir_key = 'dir3' if is_ternary else 'dir' + actual_dir_test = y_test_seq_dict.get(dir_key) + if actual_ret_test is None or actual_dir_test is None: + logger.error(f"Fold {fold_num}: Actual return ('ret') or direction ('{dir_key}') missing from y_test_seq_dict.") + return None, None, None + + # Verify prediction lengths (ensure mu_test_state/sigma_test_state were calculated) + n_test = len(X_test_seq) + if mu_test_state is None or sigma_test_state is None: + logger.error(f"Fold {fold_num}: Failed to extract mu/sigma state components. Cannot proceed.") + return None, None, None + if not (len(mu_test_state) == n_test and len(sigma_test_state) == n_test and + len(actual_ret_test) == n_test and len(actual_dir_test) == n_test and + len(test_indices) == n_test): + logger.error(f"Fold {fold_num}: Length mismatch in test state predictions/targets/indices.") + return None, None, None + + # --- Initialize Calibration State --- + current_optimal_T = initial_optimal_T + current_vector_params = initial_vector_params + + # --- Initialize Rolling Calibration State --- + historical_data = [] # Stores tuples of (raw_pred, true_label) + coverage_hits = [] # Buffer for coverage alarm hit rate calculation + last_recalib_step = -1 + + # --- Initialize Simulation State --- + capital = self.initial_capital + current_position = 0.0 # Starts neutral (-1 to 1) + equity_curve = [capital] + positions = [current_position] + actions_taken = [0.0] # SAC agent's desired fractional position + calibrated_probs_steps = [] # Store step-by-step calibrated prob + edge_steps = [] # Store step-by-step edge + pnl_steps = [] + trades_executed = [] # Store details of trades + metrics_log = [] # Store periodic metrics + step_correct_nonzero, step_count_nonzero = 0, 0 + step_abs_actions = [] + + logger.info(f"Fold {fold_num}: Starting backtest simulation loop ({n_test} steps)...") + for i in range(n_test): + # --- Step i: Calibration --- + p_cal_step = np.nan + edge_step = np.nan + + if is_ternary and vector_calibrator is not None and current_vector_params is not None: + # --- Use pre-calculated logits if available --- + if all_logits_test_precalc is not None: + logit_step = all_logits_test_precalc[i:i+1] + elif logits_test is not None: # Use passed logits if rolling cal enabled + logit_step = logits_test[i:i+1] + else: # Should not happen if logic above is correct + logger.error(f"Step {i}: Logits not pre-calculated or passed. Cannot calibrate.") + logit_step = None + # --- End Use pre-calculated --- + + if logit_step is None: + logger.warning(f"Step {i}: Failed to get logits, using neutral prob."); p_cal_step=0.33; edge_step=0.0 + else: + # --- Corrected Call: Use calibrate() which uses internal W, b --- + vector_calibrator.W = current_vector_params[:len(logit_step[0])] + vector_calibrator.b = current_vector_params[len(logit_step[0]):] + p_cal_step_all = vector_calibrator.calibrate(logit_step) + # --- End Correction --- + p_cal_step = p_cal_step_all[0, 2] # Prob(Up) + edge_step = p_cal_step_all[0, 2] - p_cal_step_all[0, 0] + + elif not is_ternary and calibrator is not None and current_optimal_T is not None: + # --- Use pre-calculated raw probs if available --- + if all_p_raw_test_precalc is not None: + p_raw_step = all_p_raw_test_precalc[i] + elif p_raw_test is not None: # Use passed raw probs if rolling cal enabled + p_raw_step = p_raw_test[i] + else: # Should not happen + logger.error(f"Step {i}: Raw probs not pre-calculated or passed. Cannot calibrate.") + p_raw_step = None + # --- End Use pre-calculated --- + + if p_raw_step is None: + logger.warning(f"Step {i}: Failed to get raw prob, using neutral prob."); p_cal_step=0.5; edge_step=0.0 + else: + # --- Corrected Temp Scaling Call --- + # p_cal_step = calibrator.calibrate_with_T(p_raw_step, current_optimal_T) + calibrator.optimal_T = current_optimal_T # Ensure instance has the right T + p_cal_step = calibrator.calibrate(np.array([[p_raw_step]]))[0,0] # calibrate expects 2D + # --- End Correction --- + edge_step = 2 * p_cal_step - 1 + else: + # Fallback if calibrator missing or not fitted - use raw prob? or neutral? + if not is_ternary: + p_cal_step = p_raw_fallback[i] if p_raw_fallback is not None else 0.5 # Use pre-calculated raw if available + edge_step = 2 * p_cal_step - 1 + elif is_ternary and p_softmax_test is not None: + p_cal_step = p_softmax_test[i, 2] # Use raw softmax P(Up) + edge_step = p_softmax_test[i, 2] - p_softmax_test[i, 0] # Raw edge + else: # Cannot determine calibrated prob or edge + p_cal_step = 0.5 if not is_ternary else 0.33 + edge_step = 0.0 + + calibrated_probs_steps.append(p_cal_step) + edge_steps.append(edge_step) + + # --- Step i: State Construction --- + z_score_step = np.abs(mu_test_state[i]) / (sigma_test_state[i] + 1e-9) + state = np.array([ + mu_test_state[i], sigma_test_state[i], edge_step, z_score_step, current_position + ], dtype=np.float32) + + # --- Step i: SAC Action --- + sac_action = agent.get_action(state, deterministic=True)[0] + target_position = np.clip(sac_action, -1.0, 1.0) + + # --- Step i: PnL Calculation --- + step_actual_return = actual_ret_test[i] + gross_pnl = current_position * capital * (np.exp(step_actual_return) - 1) + trade = target_position - current_position + cost = abs(trade) * capital * self.transaction_cost + net_pnl = gross_pnl - cost + capital += net_pnl + + # --- Step i: Store Results --- + equity_curve.append(capital) + positions.append(target_position) + actions_taken.append(sac_action) + pnl_steps.append(net_pnl) + if abs(trade) > 1e-6: + trades_executed.append({ + 'timestamp': test_indices[i], 'trade_size': trade, 'cost': cost, + 'position_before': current_position, 'position_after': target_position + }) + + # --- Step i: Update Position --- + current_position = target_position + + # --- Step i: Update Metrics Log Data --- + if abs(current_position) > 1e-6: + step_count_nonzero += 1 + if (gross_pnl > 0 and current_position > 0) or \ + (gross_pnl < 0 and current_position < 0) or \ + (abs(gross_pnl) < 1e-9): + step_correct_nonzero += 1 + step_abs_actions.append(abs(sac_action)) + # Log metrics periodically (end of loop) + + # --- Step i: Rolling Calibration Update Logic --- + recalibrate_now = False + if self.rolling_cal_enabled: + # Store data for potential refit + true_label_step = actual_dir_test[i] + # Use the raw prediction corresponding to the calibration method + raw_pred_step = logits_test[i] if is_ternary else p_raw_test[i] + # Only store if raw prediction was valid + if raw_pred_step is not None: + historical_data.append((raw_pred_step, true_label_step)) + + # Select raw outputs and true labels for the window + window_start_idx = max(0, i - self.recalibration_window + 1) + window_labels_onehot = actual_dir_test[window_start_idx:i] + window_logits = None + window_raw_probs = None + window_calibrated_probs = None # Store current calibrated probs for ECE + if is_ternary and logits_test is not None: + window_logits = logits_test[window_start_idx:i] + if vector_calibrator is not None: # Ensure calibrator exists + window_calibrated_probs = vector_calibrator.calibrate(window_logits) + elif not is_ternary and p_raw_test is not None: + window_raw_probs = p_raw_test[window_start_idx:i] + if calibrator is not None: # Ensure calibrator exists + window_calibrated_probs = calibrator.calibrate(window_raw_probs) + + # --- Check Coverage Alarm (Now using ECE - Revision 4) --- # + trigger_recalibration = False + if self.coverage_alarm_enabled and window_calibrated_probs is not None: + try: + # Prepare inputs for ECE calculation + if is_ternary: + probs_for_ece = window_calibrated_probs + labels_for_ece = window_labels_onehot + else: # Binary + # ECE helper expects N x K shape + p_binary_2class = np.vstack([1 - window_calibrated_probs, window_calibrated_probs]).T + # Convert labels to one-hot if not already + if window_labels_onehot.shape[1] == 1: + y_true_binary_indices = window_labels_onehot.flatten().astype(int) + labels_for_ece = tf.keras.utils.to_categorical(y_true_binary_indices, num_classes=2) + else: + labels_for_ece = window_labels_onehot + probs_for_ece = p_binary_2class + + # Calculate ECE + if labels_for_ece.shape[0] > 0: # Ensure data exists + ece = self._calculate_ece(probs_for_ece, labels_for_ece) + logger.debug(f"Step {i}: Rolling ECE check: {ece:.4f} (Threshold: {self.ece_recalibration_threshold})") + if ece > self.ece_recalibration_threshold: + trigger_recalibration = True + logger.warning(f"Step {i}: ECE Coverage Alarm! ECE ({ece:.4f}) > Threshold ({self.ece_recalibration_threshold}). Triggering recalibration.") + else: + logger.debug(f"Step {i}: Skipping ECE calculation (no valid data in window).") + + except Exception as ece_err: + logger.warning(f"Step {i}: Could not calculate ECE for coverage alarm: {ece_err}") + # --- End ECE Check --- # + # Coverage Alarm Check + if self.coverage_alarm_enabled: + # Is prediction edge >= threshold? + has_edge = abs(edge_step) >= fold_edge_threshold + # Was the prediction correct? (Handle ternary/binary) + pred_dir = np.sign(edge_step) if abs(edge_step) >= 1e-6 else 0 + true_dir = np.sign(step_actual_return) if not is_ternary else np.argmax(true_label_step)-1 # Convert one-hot to -1,0,1 + is_correct = (pred_dir == true_dir) + + if has_edge: + coverage_hits.append(is_correct) + # Trim buffer + if len(coverage_hits) > self.coverage_alarm_window: + coverage_hits.pop(0) + # Check alarm condition + if len(coverage_hits) == self.coverage_alarm_window: + current_hit_rate = np.mean(coverage_hits) + alarm_trigger_level = fold_edge_threshold - self.coverage_alarm_threshold_drop + # Note: This condition might need refinement. Comparing hit rate to edge threshold directly isn't ideal. + # A better approach might compare current hit rate to expected (e.g., fold validation hit rate) + # For now, using the simpler logic from description: + if current_hit_rate < (0.5 + alarm_trigger_level/2.0): # Simplified check: hit rate < expected rate at threshold + logger.warning(f"Fold {fold_num} Step {i+1}: Coverage Alarm Triggered! Hit rate {current_hit_rate:.3f} < {0.5 + alarm_trigger_level/2.0:.3f} (Threshold {fold_edge_threshold:.3f}, Drop {self.coverage_alarm_threshold_drop:.3f}) over last {self.coverage_alarm_window} edge samples. Forcing recalibration.") + recalibrate_now = True + coverage_hits = [] # Reset alarm buffer + + # Interval Check + if not recalibrate_now and (i + 1) % self.recalibrate_every_n == 0: + recalibrate_now = True + + # Perform Recalibration + if recalibrate_now and len(historical_data) >= self.recalibration_window: + logger.info(f"Fold {fold_num} Step {i+1}: Recalibrating...") + + # Get recent data + recent_data = historical_data[-self.recalibration_window:] + recent_preds, recent_labels = zip(*recent_data) + recent_preds = np.array(recent_preds) + recent_labels = np.array(recent_labels) + + if is_ternary and vector_calibrator is not None: + vector_calibrator.fit(recent_preds, recent_labels) # Refit + if np.array_equal(current_vector_params, vector_calibrator.optimal_params): + logger.info(" Vector params unchanged.") + else: + current_vector_params = vector_calibrator.optimal_params + logger.info(f" New vector params obtained (shape: {current_vector_params.shape}).") + # TODO: Optionally save updated params? + elif not is_ternary and calibrator is not None: + new_T = calibrator.optimise_temperature(recent_preds, recent_labels) + if new_T is not None and abs(new_T - current_optimal_T) > 1e-4: + logger.info(f" New optimal temperature: {new_T:.4f} (Previous: {current_optimal_T:.4f})") + current_optimal_T = new_T + calibrator.optimal_T = new_T # Update instance + else: + logger.info(f" Temperature unchanged or optimization failed (T={current_optimal_T:.4f}).") + + last_recalib_step = i + elif recalibrate_now: + logger.warning(f"Fold {fold_num} Step {i+1}: Recalibration triggered but not enough historical data ({len(historical_data)} < {self.recalibration_window}).") + + + # --- Step i: Check for Ruin --- + if capital <= 0: + logger.warning(f"Fold {fold_num}: Capital depleted at step {i+1}. Stopping backtest.") + n_test = i + 1 # Adjust length to current step + break + + # --- Step i: Log Periodic Metrics --- + if (i + 1) % 1000 == 0: # Log every 1000 steps + hr_nonzero = (step_correct_nonzero / step_count_nonzero) if step_count_nonzero > 0 else 0 + mean_abs_action = np.mean(step_abs_actions) if step_abs_actions else 0 + metrics_log.append({ + 'step': i + 1, 'timestamp': test_indices[i], + 'hit_rate_nonzero': hr_nonzero, 'mean_abs_action': mean_abs_action, + 'equity': capital, 'current_T': current_optimal_T, + 'last_recalib_step': last_recalib_step + }) + # Reset interval counters + step_correct_nonzero, step_count_nonzero, step_abs_actions = 0, 0, [] + # --- End Periodic Log --- + + logger.info(f"Fold {fold_num}: Backtest simulation loop finished.") + logger.info(f"Fold {fold_num}: Final Equity: {capital:.2f}") + + # --- Convert metrics log to DataFrame --- + metrics_log_df = pd.DataFrame(metrics_log).set_index('step') + + # 5. Prepare Results DataFrame + if n_test == 0: + logger.warning(f"Fold {fold_num}: Backtest executed 0 steps.") + return pd.DataFrame(), {}, pd.DataFrame() + + results_data = { + 'equity': equity_curve[1:], + 'position': positions[1:], + 'action': actions_taken[1:], + 'pnl': pnl_steps, + 'actual_return': actual_ret_test[:n_test], + 'mu_pred': mu_test_state[:n_test], # Use state mu + 'sigma_pred': sigma_test_state[:n_test], # Use state sigma + 'p_cal_pred': calibrated_probs_steps[:n_test], # Use step-calibrated probs + 'edge_pred': edge_steps[:n_test], # Use step-calculated edge + 'actual_dir': actual_dir_test[:n_test] if not is_ternary else np.argmax(actual_dir_test[:n_test], axis=1) # Store labels + } + if original_prices is not None: + aligned_prices = original_prices.reindex(test_indices[:n_test]) # Reindex safely + results_data['close_price'] = aligned_prices['close'].values # Add close price + + self.results_df = pd.DataFrame(results_data, index=test_indices[:n_test]) + self.results_df['returns'] = self.results_df['equity'].pct_change().fillna(0.0) + self.results_df['cumulative_return'] = (1 + self.results_df['returns']).cumprod() - 1 + + # Calculate Buy & Hold Benchmark + bh_sharpe = 0.0 + if 'close_price' in self.results_df.columns and not self.results_df['close_price'].isnull().all(): + bh_returns = self.results_df['close_price'].pct_change().fillna(0.0) + self.results_df['bh_cumulative_return'] = (1 + bh_returns).cumprod() - 1 + bh_sharpe = calculate_sharpe_ratio(bh_returns) + else: + self.results_df['bh_cumulative_return'] = 0.0 + logger.warning(f"Fold {fold_num}: Could not calculate Buy & Hold benchmark due to missing/NaN price data.") + + # 6. Calculate Final Performance Metrics + logger.info(f"Fold {fold_num}: Calculating final performance metrics...") + final_equity = self.results_df['equity'].iloc[-1] + total_return_pct = (final_equity / self.initial_capital - 1) * 100 + sharpe = calculate_sharpe_ratio(self.results_df['returns']) + max_dd = calculate_max_drawdown(self.results_df['equity']) + + wins = self.results_df[self.results_df['pnl'] > 0]['pnl'] + losses = self.results_df[self.results_df['pnl'] < 0]['pnl'] + profit_factor = wins.sum() / abs(losses.sum()) if losses.sum() != 0 else np.inf + + num_trades = len(trades_executed) + win_rate_steps = (self.results_df['pnl'] > 0).mean() * 100 if len(self.results_df) > 0 else 0 + + # Use the determined fold edge threshold for final metric calculation + edge_acc, edge_n = edge_filtered_accuracy( + y_true=self.results_df['actual_dir'], + p_cal=self.results_df['p_cal_pred'], + thr=fold_edge_threshold + ) + sharpe_recentered = calculate_sharpe_ratio(self.results_df['returns']) + + # Brier score (only for binary) + brier = np.nan + if not is_ternary: + brier = calculate_brier_score(self.results_df['actual_dir'], self.results_df['p_cal_pred']) + + self.metrics = { + "Fold Number": fold_num, + "Test Period Start": test_indices[0].strftime('%Y-%m-%d %H:%M'), + "Test Period End": test_indices[n_test-1].strftime('%Y-%m-%d %H:%M'), + "Initial Capital": self.initial_capital, + "Final Equity": final_equity, + "Total Net PnL": self.results_df['pnl'].sum(), + "Total Return (%)": total_return_pct, + "Annualized Sharpe Ratio": sharpe, + "Annualized Sharpe Ratio (Re-centred)": sharpe_recentered, + "Max Drawdown (%)": max_dd * 100, + "Profit Factor": profit_factor, + "Number of Trades": num_trades, + "Win Rate (%)": win_rate_steps, # Use step PnL win rate + "Edge Threshold Used": fold_edge_threshold, + "Edge Filtered Accuracy": edge_acc, + "Edge Filtered N": edge_n, + "Brier Score": brier, + "Buy & Hold Sharpe Ratio": bh_sharpe, + "Initial Optimal T": initial_optimal_T, + # Add rolling cal info if enabled + "Rolling Calibration Enabled": self.rolling_cal_enabled, + "Last Recalibration Step": last_recalib_step if self.rolling_cal_enabled else "N/A", + } + logger.info(f"Fold {fold_num}: --- Backtest Simulation Finished ---") + return self.results_df, self.metrics, metrics_log_df + + def save_results( + self, + results_df: pd.DataFrame, + metrics: Dict[str, Any], + results_dir: str, # Base results dir for the fold or run + run_id: str, # Overall pipeline run_id + metrics_log_df: Optional[pd.DataFrame] = None, + fold_num: Optional[int] = None # Pass fold number for unique filenames + ): + """ + Saves the backtest results, metrics report, and plots for a specific fold. + """ + fold_suffix = f"_fold_{fold_num}" if fold_num is not None else "" + logger.info(f"--- Fold {fold_num}: Saving Backtest Results --- ") + if results_df is None or metrics is None: + logger.warning(f"Fold {fold_num}: No results DataFrame or metrics to save.") + return + + if not self.io: + logger.error(f"Fold {fold_num}: IOManager not provided. Cannot save results.") + return + + # Define base filenames with fold suffix + metrics_fname = f"performance_metrics{fold_suffix}" + results_df_fname = f"backtest_results{fold_suffix}" + metrics_log_fname = f"backtest_metrics_log{fold_suffix}" + summary_plot_fname = f"backtest_summary{fold_suffix}" + + # IOManager handles the full path construction including run_id and section + # We save these fold-specific results within the main run's 'results' section + + # 1. Save Metrics Report + try: + self.io.save_json(metrics, metrics_fname, section='results', use_txt=True) + logger.info(f"Fold {fold_num}: Performance metrics saved to {self.io.path('results', metrics_fname, suffix='.txt')}") + except Exception as e: + logger.error(f"Fold {fold_num}: Failed to save metrics report: {e}", exc_info=True) + + # 2. Save Results DataFrame + try: + # Let IOManager decide csv/parquet based on its internal logic + saved_path = self.io.save_df(results_df, results_df_fname, section='results') + logger.info(f"Fold {fold_num}: Detailed backtest results saved to {saved_path}") + except Exception as e: + logger.error(f"Fold {fold_num}: Failed to save results DataFrame: {e}", exc_info=True) + + # 3. Save Metrics Log DataFrame + if metrics_log_df is not None and not metrics_log_df.empty: + try: + saved_path = self.io.save_df(metrics_log_df, metrics_log_fname, section='results') + logger.info(f"Fold {fold_num}: Periodic backtest metrics log saved to {saved_path}") + except Exception as e: + logger.error(f"Fold {fold_num}: Failed to save metrics log DataFrame: {e}", exc_info=True) + + # 4. Generate and Save Plots + if self.config.get('control', {}).get('generate_plots', True): + logger.info(f"Fold {fold_num}: Generating backtest plots...") + try: + # Plot 1: Multi-subplot summary + fig_size = self.config.get('output', {}).get('figure_size', [16, 9]) + fig, axes = plt.subplots(3, 1, figsize=fig_size, sharex=True) + plt.style.use('seaborn-v0_8-darkgrid') + footer_text = f"© GRU-SAC v3 | Run: {run_id} | Fold: {fold_num}" + + # --- Pane 1: Price + Edge Background --- # + ax = axes[0] + if 'close_price' in results_df.columns and not results_df['close_price'].isnull().all(): + ax.plot(results_df.index, results_df['close_price'], label='Price', color='black', alpha=0.9, linewidth=1.0) + ax.set_ylabel("Price") + + edge_thr = metrics.get("Edge Threshold Used", 0.1) # Get actual threshold used + long_edge_mask = results_df['edge_pred'] >= edge_thr + short_edge_mask = results_df['edge_pred'] <= -edge_thr + ax.fill_between(results_df.index, ax.get_ylim()[0], ax.get_ylim()[1], + where=long_edge_mask, color='blue', alpha=0.1, label=f'Long Edge >= {edge_thr:.2f}') + ax.fill_between(results_df.index, ax.get_ylim()[0], ax.get_ylim()[1], + where=short_edge_mask, color='red', alpha=0.1, label=f'Short Edge <= {-edge_thr:.2f}') + else: + ax.text(0.5, 0.5, 'Price data unavailable', ha='center', va='center') + ax.set_title(f'Backtest Summary (Fold: {fold_num})', fontsize=14) + ax.legend(fontsize=8) + ax.grid(True, linestyle='--', alpha=0.6) + + # --- Pane 2: Position Size --- # + ax = axes[1] + ax.plot(results_df.index, results_df['position'], label='Target Position', color='purple', drawstyle='steps-post') + ax.set_ylabel("Position (-1 to 1)") + ax.set_ylim(-1.1, 1.1) + ax.legend(fontsize=8) + ax.grid(True, linestyle='--', alpha=0.6) + + # --- Pane 3: Equity Curve + Drawdowns --- # + ax = axes[2] + equity_norm = results_df['equity'] / self.initial_capital + ax.plot(results_df.index, equity_norm, label='Strategy Equity', color='green') + # Add Buy & Hold if available + if 'bh_cumulative_return' in results_df.columns: + ax.plot(results_df.index, results_df['bh_cumulative_return'] + 1, label='Buy & Hold Equity', color='grey', linestyle=':') + ax.set_ylabel("Normalized Equity") + ax.set_xlabel("Time") + + rolling_max_norm = equity_norm.cummax() + drawdown_norm = (equity_norm - rolling_max_norm) + ax.fill_between(results_df.index, equity_norm, rolling_max_norm, where=drawdown_norm < 0, + color='red', alpha=0.3, label='Drawdown') + + sharpe_val = metrics.get('Annualized Sharpe Ratio (Re-centred)', metrics.get('Annualized Sharpe Ratio', np.nan)) + max_dd_val = metrics.get('Max Drawdown (%)', np.nan) + metrics_text = f"Sharpe: {sharpe_val:.2f}\nMax DD: {max_dd_val:.2f}%" + ax.text(0.02, 0.1, metrics_text, transform=ax.transAxes, fontsize=9, + verticalalignment='bottom', bbox=dict(boxstyle='round,pad=0.5', fc='wheat', alpha=0.5)) + + ax.legend(fontsize=8) + ax.grid(True, linestyle='--', alpha=0.6) + ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d %H:%M')) + plt.xticks(rotation=30) + + fig.text(0.99, 0.01, footer_text, horizontalalignment='right', + verticalalignment='bottom', fontsize=8, color='gray') + + plt.tight_layout(rect=[0, 0.03, 1, 0.95]) + + # Save using IOManager + saved_path = self.io.save_figure(fig, summary_plot_fname, section='results') + logger.info(f"Fold {fold_num}: Backtest summary plot saved to {saved_path}") + plt.close(fig) + + except Exception as e: + logger.error(f"Fold {fold_num}: Failed to generate or save backtest plots: {e}", exc_info=True) + else: + logger.info(f"Fold {fold_num}: Skipping plot generation as per config.") + + logger.info(f"--- Fold {fold_num}: Finished Saving Backtest Results ---") + + # --- Helper for ECE Calculation (Revision 4) --- # + def _calculate_ece(self, probs: np.ndarray, y_true_onehot: np.ndarray, n_bins: int = 10) -> float: + """Calculates the Expected Calibration Error (ECE) for multi-class models. + + Args: + probs (np.ndarray): Predicted probabilities (N x K). + y_true_onehot (np.ndarray): True labels, one-hot encoded (N x K). + n_bins (int): Number of bins to divide the confidence scores into. + + Returns: + float: The calculated ECE. + """ + if probs.shape != y_true_onehot.shape: + raise ValueError("Probs and y_true_onehot must have the same shape.") + if len(probs.shape) != 2: + raise ValueError("Inputs must be 2D arrays (N x K).") + + num_samples = probs.shape[0] + confidences = np.max(probs, axis=1) + predictions = np.argmax(probs, axis=1) + true_labels = np.argmax(y_true_onehot, axis=1) + accuracies = (predictions == true_labels).astype(float) + + ece = 0.0 + bin_lowers = np.linspace(0.0, 1.0, n_bins + 1)[:-1] + bin_uppers = np.linspace(0.0, 1.0, n_bins + 1)[1:] + + for bin_lower, bin_upper in zip(bin_lowers, bin_uppers): + in_bin = (confidences > bin_lower) & (confidences <= bin_upper) + prop_in_bin = np.mean(in_bin) + + if prop_in_bin > 0: + accuracy_in_bin = np.mean(accuracies[in_bin]) + avg_confidence_in_bin = np.mean(confidences[in_bin]) + ece += np.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin + + return ece + # --- End ECE Helper --- # \ No newline at end of file diff --git a/gru_sac_predictor/src/baseline_checker.py b/gru_sac_predictor/src/baseline_checker.py new file mode 100644 index 00000000..89061c4e --- /dev/null +++ b/gru_sac_predictor/src/baseline_checker.py @@ -0,0 +1,1375 @@ +""" +Contains the BaselineChecker class for running baseline model checks. +""" + +import logging +import pandas as pd +import numpy as np +import scipy.stats as st +from scipy.stats import binomtest +from sklearn.linear_model import LogisticRegression +from sklearn.ensemble import RandomForestClassifier # Import RF +from sklearn.model_selection import train_test_split +from sklearn.metrics import accuracy_score, classification_report, confusion_matrix +from typing import Dict, Any, Optional, Tuple, Union, List +import matplotlib.pyplot as plt +import seaborn as sns +import os +# Add CalibratedClassifierCV import +from sklearn.calibration import CalibratedClassifierCV + +# Import IOManager for saving results +try: + from ..io_manager import IOManager +except ImportError: + logging.warning("IOManager could not be imported in BaselineChecker. Report saving will fail.") + IOManager = None # Placeholder if unavailable + +logger = logging.getLogger(__name__) + +class BaselineChecker: + """Runs baseline model checks, including raw and edge-filtered evaluations.""" + + def __init__(self, config: Dict[str, Any], io: IOManager): + """ + Initialize the BaselineChecker. + + Args: + config (Dict[str, Any]): Pipeline configuration dictionary. + io: IOManager instance for saving. + """ + self.config = config + self.io = io + + def _calculate_ci_lower_bound(self, k: int, n: int, p: float, conf_level: float) -> float: + """Helper to calculate binomial test CI lower bound.""" + if n == 0: return np.nan + try: + return binomtest(k, n, p=p, alternative='greater').proportion_ci(confidence_level=conf_level).low + except ValueError as e: + logger.error(f"Error calculating binomial test (k={k}, n={n}, p={p}): {e}. Returning NaN.") + return np.nan + + def run_binary_logistic_baseline(self, + X_train_scaled: pd.DataFrame, + y_train_dir_binary: pd.Series, + X_val_scaled: pd.DataFrame, + y_val_dir_binary: pd.Series, + io: Optional[IOManager] = None + ) -> Tuple[Optional[LogisticRegression], Any, Dict[str, Any]]: # Return base_model, calibrator, report + """ + Runs Binary Logistic Regression baseline. Uses internal split for raw CI. + Optionally applies isotonic calibration before edge checks. + + Args: + X_train_scaled: SCALED training features. + y_train_dir_binary: BINARY {0, 1} training labels. + X_val_scaled: SCALED validation features. + y_val_dir_binary: BINARY {0, 1} validation labels. + io: IOManager instance for saving. + + Returns: + Tuple: (base_lr_model, final_predictor, report_dict) + base_lr_model: Raw Logistic Regression model. + final_predictor: Calibrated model if enabled, else base_lr_model. + report_dict: Dictionary containing metrics. + """ + # --- Configuration --- # + baseline_cfg = self.config.get('baselines', {}).get('logistic_regression', {}) + max_iter = baseline_cfg.get('max_iter', 1000) + solver = baseline_cfg.get('solver', 'lbfgs') + # random_state = baseline_cfg.get('random_state', 42) # Use calibration random state + val_subset_split_ratio = baseline_cfg.get('val_subset_split_ratio', 0.2) + val_subset_shuffle = baseline_cfg.get('val_subset_shuffle', False) + ci_confidence_level = baseline_cfg.get('ci_confidence_level', 0.95) + # Calibration Config + calibration_enabled = baseline_cfg.get('calibration_enabled', False) + calibration_method = baseline_cfg.get('calibration_method', 'isotonic') + calibration_holdout = baseline_cfg.get('calibration_holdout', 0.2) + calibration_random_state = baseline_cfg.get('random_state', 42) # Reuse random_state + + logger.info(f"Running Binary Logistic Regression baseline (Baseline 1) on {X_train_scaled.shape[1]} scaled features.") + if calibration_enabled: + logger.info(f" Calibration enabled: method={calibration_method}, holdout={calibration_holdout}") + else: + logger.info(" Calibration disabled.") + + # --- Handle potential NaNs and Infs in input --- # + # ... (NaN/Inf handling code remains the same) ... + logger.debug(f"Checking for NaNs/Infs before LogReg. X_train_scaled NaNs: {X_train_scaled.isnull().sum().sum()}, Infs: {np.isinf(X_train_scaled).sum().sum()}") + X_train_scaled_clean = pd.DataFrame(np.nan_to_num(X_train_scaled, nan=0.0, posinf=0.0, neginf=0.0), index=X_train_scaled.index, columns=X_train_scaled.columns) + logger.debug(f"Checking for NaNs/Infs before LogReg. X_val_scaled NaNs: {X_val_scaled.isnull().sum().sum()}, Infs: {np.isinf(X_val_scaled).sum().sum()}") + X_val_scaled_clean = pd.DataFrame(np.nan_to_num(X_val_scaled, nan=0.0, posinf=0.0, neginf=0.0), index=X_val_scaled.index, columns=X_val_scaled.columns) + # Add check after cleaning + if X_train_scaled_clean.isnull().sum().sum() > 0 or np.isinf(X_train_scaled_clean).sum().sum() > 0: + logger.warning("NaNs or Infs persisted after np.nan_to_num in run_binary_logistic_baseline (Train).") + if X_val_scaled_clean.isnull().sum().sum() > 0 or np.isinf(X_val_scaled_clean).sum().sum() > 0: + logger.warning("NaNs or Infs persisted after np.nan_to_num in run_binary_logistic_baseline (Val).") + # --- End Handle NaNs/Infs --- + + # --- Initialize Report --- # + report = { + "baseline_model_type": "LogisticRegression_Binary", + "calibration_enabled": calibration_enabled, + "accuracy_val_subset": np.nan, + "ci_lower_bound": np.nan, + "n_val_subset": 0, + "accuracy_orig_val": np.nan, + "classification_report_orig_val": "N/A", + } + + base_lr_model = None + calibrator = None + final_predictor = None # This will hold the calibrated model if enabled, else the base model + + try: + # --- Split train data further if calibration is enabled --- # + if calibration_enabled: + logger.info(f"Splitting training data for calibration (holdout={calibration_holdout})...") + X_train_for_lr, X_cal, y_train_for_lr, y_cal = train_test_split( + X_train_scaled_clean, y_train_dir_binary, + test_size=calibration_holdout, + shuffle=True, # Shuffle is recommended for calibration split + stratify=y_train_dir_binary, + random_state=calibration_random_state + ) + logger.info(f" Base LR training set size: {X_train_for_lr.shape[0]}, Calibration set size: {X_cal.shape[0]}") + else: + # Use all provided training data for LR + X_train_for_lr = X_train_scaled_clean + y_train_for_lr = y_train_dir_binary + X_cal, y_cal = None, None # No calibration set + + # --- Train Base Logistic Regression Model --- # + logger.info(f"Fitting Base Binary LogReg on {X_train_for_lr.shape[0]} samples...") + base_lr_model = LogisticRegression( + max_iter=max_iter, + solver=solver, + random_state=calibration_random_state # Use same random state + ) + base_lr_model.fit(X_train_for_lr, y_train_for_lr) + final_predictor = base_lr_model # Default predictor is the base model + + # --- Fit Calibrator (if enabled) --- # + if calibration_enabled and X_cal is not None and not X_cal.empty: + logger.info(f"Fitting Calibrator ({calibration_method}) on {X_cal.shape[0]} samples...") + try: + calibrator = CalibratedClassifierCV( + base_lr_model, + method=calibration_method, + cv='prefit' # Use the already fitted base model + ) + calibrator.fit(X_cal, y_cal) + final_predictor = calibrator # Update final predictor to the calibrated one + logger.info("Calibration fitting complete.") + except Exception as cal_err: + logger.error(f"Error fitting calibrator: {cal_err}. Using uncalibrated model.", exc_info=True) + final_predictor = base_lr_model # Fallback to base model on calibration error + elif calibration_enabled: + logger.warning("Calibration enabled but calibration set is empty. Using uncalibrated model.") + final_predictor = base_lr_model + + # --- Internal Split for Raw CI (Always uses BASE model) --- # + logger.info("Performing internal split for Raw CI calculation...") + # Use the same X_train_for_lr data that the base model was trained on + X_teach, X_val_subset, y_teach, y_val_subset = train_test_split( + X_train_for_lr, y_train_for_lr, + test_size=val_subset_split_ratio, # Use the original val_subset ratio + shuffle=val_subset_shuffle, + stratify=y_train_for_lr if val_subset_shuffle else None, + random_state=calibration_random_state + 1 # Use different seed for this split + ) + report["n_val_subset"] = len(y_val_subset) + + if report["n_val_subset"] > 0: + # Fit a TEMPORARY base model on the teaching subset for raw CI + logger.info(f"Fitting TEMP Binary LogReg on teaching subset ({len(y_teach)} samples) for Raw CI...") + temp_lr_model = LogisticRegression( + max_iter=max_iter, + solver=solver, + random_state=calibration_random_state # Re-use state + ) + temp_lr_model.fit(X_teach, y_teach) + + logger.debug("Predicting on internal validation subset (Binary LogReg - RAW CI)...") + # Use predict (not predict_proba) on the TEMP model + y_pred_val_subset = temp_lr_model.predict(X_val_subset) + acc_val_subset = accuracy_score(y_val_subset, y_pred_val_subset) + report["accuracy_val_subset"] = float(acc_val_subset) + + # Calculate Raw CI lower bound (vs 0.5) using the TEMP model results + k_correct = np.sum(y_pred_val_subset == y_val_subset) + n = report["n_val_subset"] + raw_ci_lb = self._calculate_ci_lower_bound(k_correct, n, p=0.5, conf_level=ci_confidence_level) + report["ci_lower_bound"] = raw_ci_lb + logger.info(f"Baseline 1 (Binary LogReg - Raw Internal Split): Acc={acc_val_subset:.3f}, CI_LB={raw_ci_lb:.3f}") + else: + logger.warning("Internal validation subset empty for Binary LogReg Raw CI. Skipping calculation.") + + # --- Evaluate on Original Validation Set (Uses FINAL predictor) --- # + if final_predictor is not None and not X_val_scaled_clean.empty: + logger.info("Evaluating FINAL predictor (calibrated if enabled) on original validation set...") + # Use predict for accuracy/report, predict_proba will be used by edge check + orig_val_pred = final_predictor.predict(X_val_scaled_clean) + orig_val_acc = accuracy_score(y_val_dir_binary, orig_val_pred) + report["accuracy_orig_val"] = float(orig_val_acc) + try: + report["classification_report_orig_val"] = classification_report(y_val_dir_binary, orig_val_pred, output_dict=True, zero_division=0) + except Exception as report_err: + logger.warning(f"Could not generate Binary LogReg classification report: {report_err}") + report["classification_report_orig_val"] = "Error generating report" + logger.info(f"Baseline 1 (Binary LogReg - FINAL Predictor): Original validation set accuracy: {orig_val_acc:.3f}") + elif final_predictor is None: + logger.warning("FINAL predictor (calibrated or base LR) not available. Skipping evaluation on original validation set.") + else: # X_val_scaled is empty + logger.warning("Original validation set is empty. Skipping Binary LogReg evaluation on it.") + + except Exception as e: + logger.error(f"Failed during Binary Logistic Regression baseline: {e}", exc_info=True) + report["notes"] = f"Failed with error: {e}" + base_lr_model = None # Ensure models are None on failure + final_predictor = None + + # --- Save Report --- # + # ... (Saving logic remains the same) ... + if io: + try: + io.save_json(report, "baseline_binary_lr_report", section='results', use_txt=True) + logger.info("Saved Binary Logistic Regression baseline report.") + except Exception as save_e: + logger.error(f"Failed to save Binary LogReg baseline report: {save_e}") + else: + logger.warning("IOManager not provided, skipping saving Binary LogReg baseline report.") + + # Return the base model (for raw CI check logic if separated) and the final predictor (for edge checks) + return base_lr_model, final_predictor, report + + def run_binary_random_forest_baseline(self, + X_train_scaled: pd.DataFrame, + y_train_dir_binary: pd.Series, + X_val_scaled: pd.DataFrame, + y_val_dir_binary: pd.Series, + io: Optional[IOManager] = None + ) -> Tuple[Optional[RandomForestClassifier], Dict[str, Any]]: + """ + Runs Binary RandomForest baseline. Uses internal split for raw CI. + + Args: + X_train_scaled: SCALED training features. + y_train_dir_binary: BINARY {0, 1} training labels. + X_val_scaled: SCALED validation features. + y_val_dir_binary: BINARY {0, 1} validation labels. + io: IOManager instance for saving. + + Returns: + Tuple: (fitted_model, report_dict) + """ + # --- Configuration --- # + rf_cfg = self.config.get('baselines', {}).get('random_forest', {}) # Reuse binary RF config section + n_estimators = rf_cfg.get('n_estimators', 100) + max_depth = rf_cfg.get('max_depth', None) + min_samples_split = rf_cfg.get('min_samples_split', 2) + min_samples_leaf = rf_cfg.get('min_samples_leaf', 1) + random_state = rf_cfg.get('random_state', 42) + n_jobs = rf_cfg.get('n_jobs', -1) + val_subset_split_ratio = rf_cfg.get('val_subset_split_ratio', 0.2) + val_subset_shuffle = rf_cfg.get('val_subset_shuffle', False) + ci_confidence_level = rf_cfg.get('ci_confidence_level', 0.95) + + logger.info(f"Running Binary RandomForest baseline (Baseline 2) on {X_train_scaled.shape[1]} scaled features.") + logger.info(f" RF Params: n_estimators={n_estimators}, max_depth={max_depth}, n_jobs={n_jobs}") + + # --- Handle potential NaNs and Infs in input --- + logger.debug(f"Checking for NaNs/Infs before RF. X_train_scaled NaNs: {X_train_scaled.isnull().sum().sum()}, Infs: {np.isinf(X_train_scaled).sum().sum()}") + X_train_scaled_clean = pd.DataFrame(np.nan_to_num(X_train_scaled, nan=0.0, posinf=0.0, neginf=0.0), index=X_train_scaled.index, columns=X_train_scaled.columns) + logger.debug(f"Checking for NaNs/Infs before RF. X_val_scaled NaNs: {X_val_scaled.isnull().sum().sum()}, Infs: {np.isinf(X_val_scaled).sum().sum()}") + X_val_scaled_clean = pd.DataFrame(np.nan_to_num(X_val_scaled, nan=0.0, posinf=0.0, neginf=0.0), index=X_val_scaled.index, columns=X_val_scaled.columns) + # Add check after cleaning + if X_train_scaled_clean.isnull().sum().sum() > 0 or np.isinf(X_train_scaled_clean).sum().sum() > 0: + logger.warning("NaNs or Infs persisted after np.nan_to_num in run_binary_random_forest_baseline (Train).") + if X_val_scaled_clean.isnull().sum().sum() > 0 or np.isinf(X_val_scaled_clean).sum().sum() > 0: + logger.warning("NaNs or Infs persisted after np.nan_to_num in run_binary_random_forest_baseline (Val).") + # --- End Handle NaNs/Infs --- + + # --- Initialize Report --- # + report = { + "baseline_model_type": "RandomForestClassifier_Binary", # Baseline 2 ID + "accuracy_val_subset": np.nan, + "ci_lower_bound": np.nan, # Raw CI LB (internal split) + "n_val_subset": 0, + "accuracy_orig_val": np.nan, + "classification_report_orig_val": "N/A", + } + + baseline_model = None + acc_val_subset = np.nan # Initialize outside the if block + + try: + # Split train data into teach/validation subsets for raw CI calculation + X_teach, X_val_subset, y_teach, y_val_subset = train_test_split( + X_train_scaled_clean, y_train_dir_binary, # Use cleaned data + test_size=val_subset_split_ratio, + shuffle=val_subset_shuffle, + stratify=y_train_dir_binary if val_subset_shuffle else None + ) + report["n_val_subset"] = len(y_val_subset) + + if report["n_val_subset"] > 0: + logger.info(f"Fitting Binary RF on teaching subset ({len(y_teach)} samples)...") + baseline_model = RandomForestClassifier( + n_estimators=n_estimators, + max_depth=max_depth, + min_samples_split=min_samples_split, + min_samples_leaf=min_samples_leaf, + random_state=random_state, + n_jobs=n_jobs + ) + baseline_model.fit(X_teach, y_teach) # Fit on cleaned subset + + logger.debug("Predicting on internal validation subset (Binary RF)..") + y_pred_val_subset = baseline_model.predict(X_val_subset) # Predict on cleaned subset + acc_val_subset = accuracy_score(y_val_subset, y_pred_val_subset) + report["accuracy_val_subset"] = float(acc_val_subset) # Now correctly indented inside the if + + # Calculate Raw CI lower bound (vs 0.5) - Correctly indented inside the if + k_correct = np.sum(y_pred_val_subset == y_val_subset) + n = report["n_val_subset"] + raw_ci_lb = self._calculate_ci_lower_bound(k_correct, n, p=0.5, conf_level=ci_confidence_level) + report["ci_lower_bound"] = raw_ci_lb + logger.info(f"Baseline 2 (Binary RF - Raw Internal Split): Acc={acc_val_subset:.3f}, CI_LB={raw_ci_lb:.3f}") + else: # Correctly aligned with if + # Correctly indented inside the else + logger.warning("Internal validation subset empty for Binary RF. Skipping raw CI calculation.") + + # --- Evaluate on Original Validation Set --- # + if baseline_model is not None and not X_val_scaled_clean.empty: + logger.info("Evaluating Binary RF on original validation set...") + orig_val_pred = baseline_model.predict(X_val_scaled_clean) # Predict on cleaned validation data + orig_val_acc = accuracy_score(y_val_dir_binary, orig_val_pred) + report["accuracy_orig_val"] = float(orig_val_acc) + try: + report["classification_report_orig_val"] = classification_report(y_val_dir_binary, orig_val_pred, output_dict=True, zero_division=0) + except Exception as report_err: + logger.warning(f"Could not generate Binary RF classification report: {report_err}") + report["classification_report_orig_val"] = "Error generating report" + logger.info(f"Baseline 2 (Binary RF): Original validation set accuracy: {orig_val_acc:.3f}") + elif baseline_model is None: + logger.warning("Binary RF model not trained. Skipping evaluation on original validation set.") + else: # X_val_scaled is empty + logger.warning("Original validation set is empty. Skipping Binary RF evaluation on it.") + + except ImportError: + logger.error("Scikit-learn (RandomForestClassifier) not found. Cannot run Baseline 2.") + report["notes"] = "Skipped - scikit-learn not installed." + baseline_model = None + except Exception as e: + logger.error(f"Failed during Binary RandomForest baseline: {e}", exc_info=True) + report["notes"] = f"Failed with error: {e}" + baseline_model = None + + # --- Save RF Baseline Report --- # + if io: + try: + io.save_json(report, "baseline_binary_rf_report", section='results', use_txt=True) + logger.info("Saved Binary RandomForest baseline report.") + except Exception as save_e: + logger.error(f"Failed to save Binary RF baseline report: {save_e}") + else: + logger.warning("IOManager not provided, skipping saving Binary RF baseline report.") + + return baseline_model, report + + def run_ternary_logistic_baseline(self, + X_train_scaled: pd.DataFrame, + y_train_dir_ordinal: pd.Series, + X_val_scaled: pd.DataFrame, + y_val_dir_ordinal: pd.Series, + io: Optional[IOManager] = None + ) -> Tuple[Optional[LogisticRegression], Dict[str, Any]]: + """ + Runs Multinomial Logistic Regression baseline for ternary labels. Uses internal split for raw CI. + + Args: + X_train_scaled: SCALED training features. + y_train_dir_ordinal: ORDINAL {0, 1, 2} training labels. + X_val_scaled: SCALED validation features. + y_val_dir_ordinal: ORDINAL {0, 1, 2} validation labels. + io: IOManager instance for saving. + + Returns: + Tuple: (fitted_model, report_dict) + """ + # --- Configuration --- # + mlogreg_cfg = self.config.get('baselines', {}).get('multinomial_logistic_regression', {}) + max_iter = mlogreg_cfg.get('max_iter', 1000) + solver = mlogreg_cfg.get('solver', 'lbfgs') + multi_class = mlogreg_cfg.get('multi_class', 'multinomial') + random_state = mlogreg_cfg.get('random_state', 42) + val_subset_split_ratio = mlogreg_cfg.get('val_subset_split_ratio', 0.2) + val_subset_shuffle = mlogreg_cfg.get('val_subset_shuffle', False) + ci_confidence_level = mlogreg_cfg.get('ci_confidence_level', 0.95) + chance_level = 1/3 # For ternary CI + + logger.info(f"Running Ternary Logistic Regression baseline (Baseline 3) on {X_train_scaled.shape[1]} scaled features.") + + # --- Handle potential NaNs and Infs in input --- + logger.debug(f"Checking for NaNs/Infs before Ternary LogReg. X_train_scaled NaNs: {X_train_scaled.isnull().sum().sum()}, Infs: {np.isinf(X_train_scaled).sum().sum()}") + X_train_scaled_clean = pd.DataFrame(np.nan_to_num(X_train_scaled, nan=0.0, posinf=0.0, neginf=0.0), index=X_train_scaled.index, columns=X_train_scaled.columns) + logger.debug(f"Checking for NaNs/Infs before Ternary LogReg. X_val_scaled NaNs: {X_val_scaled.isnull().sum().sum()}, Infs: {np.isinf(X_val_scaled).sum().sum()}") + X_val_scaled_clean = pd.DataFrame(np.nan_to_num(X_val_scaled, nan=0.0, posinf=0.0, neginf=0.0), index=X_val_scaled.index, columns=X_val_scaled.columns) + # Add check after cleaning + if X_train_scaled_clean.isnull().sum().sum() > 0 or np.isinf(X_train_scaled_clean).sum().sum() > 0: + logger.warning("NaNs or Infs persisted after np.nan_to_num in run_ternary_logistic_baseline (Train).") + if X_val_scaled_clean.isnull().sum().sum() > 0 or np.isinf(X_val_scaled_clean).sum().sum() > 0: + logger.warning("NaNs or Infs persisted after np.nan_to_num in run_ternary_logistic_baseline (Val).") + # --- End Handle NaNs/Infs --- + + # --- Initialize Report --- # + report = { + "baseline_model_type": "LogisticRegression_Ternary", # Baseline 3 ID + "accuracy_val_subset": np.nan, + "ci_lower_bound": np.nan, # Raw CI LB (internal split vs 1/3) + "n_val_subset": 0, + "accuracy_orig_val": np.nan, + "classification_report_orig_val": "N/A", + } + + baseline_model = None + + try: + # Split train data into teach/validation subsets for raw CI calculation + X_teach, X_val_subset, y_teach, y_val_subset = train_test_split( + X_train_scaled_clean, y_train_dir_ordinal, # Use cleaned data + test_size=val_subset_split_ratio, + shuffle=val_subset_shuffle, + stratify=y_train_dir_ordinal if val_subset_shuffle else None + ) + report["n_val_subset"] = len(y_val_subset) + + if report["n_val_subset"] > 0: + logger.info(f"Fitting Ternary LogReg on teaching subset ({len(y_teach)} samples)...") + baseline_model = LogisticRegression( + max_iter=max_iter, + solver=solver, + multi_class=multi_class, + random_state=random_state + ) + baseline_model.fit(X_teach, y_teach) # Fit on cleaned subset + + logger.debug("Predicting on internal validation subset (Ternary LogReg)...") + y_pred_val_subset = baseline_model.predict(X_val_subset) # Predict on cleaned subset + acc_val_subset = accuracy_score(y_val_subset, y_pred_val_subset) + report["accuracy_val_subset"] = float(acc_val_subset) + + # Calculate Raw CI lower bound (vs 1/3) + k_correct = np.sum(y_pred_val_subset == y_val_subset) + n = report["n_val_subset"] + raw_ci_lb = self._calculate_ci_lower_bound(k_correct, n, p=chance_level, conf_level=ci_confidence_level) + report["ci_lower_bound"] = raw_ci_lb + logger.info(f"Baseline 3 (Ternary LogReg - Raw Internal Split): Acc={acc_val_subset:.3f}, CI_LB={raw_ci_lb:.3f} (vs {chance_level:.3f})") + else: + logger.warning("Internal validation subset empty for Ternary LogReg. Skipping raw CI calculation.") + + # Evaluate on original validation set + if baseline_model is not None and not X_val_scaled_clean.empty: + logger.info("Evaluating Ternary LogReg on original validation set...") + orig_val_pred = baseline_model.predict(X_val_scaled_clean) # Predict on cleaned validation data + orig_val_acc = accuracy_score(y_val_dir_ordinal, orig_val_pred) + report["accuracy_orig_val"] = float(orig_val_acc) + try: + report["classification_report_orig_val"] = classification_report(y_val_dir_ordinal, orig_val_pred, output_dict=True, zero_division=0) + except Exception as report_err: + logger.warning(f"Could not generate Ternary LogReg classification report: {report_err}") + report["classification_report_orig_val"] = "Error generating report" + logger.info(f"Baseline 3 (Ternary LogReg): Original validation set accuracy: {orig_val_acc:.3f}") + elif baseline_model is None: + logger.warning("Ternary LogReg model not trained. Skipping evaluation on original validation set.") + else: + logger.warning("Original validation set empty for Ternary LogReg.") + + except Exception as e: + logger.error(f"Failed during Ternary Logistic Regression baseline: {e}", exc_info=True) + report["notes"] = f"Failed with error: {e}" + baseline_model = None + + # Save Report + if io: + try: + io.save_json(report, "baseline_ternary_lr_report", section='results', use_txt=True) + logger.info("Saved Ternary Logistic Regression baseline report.") + except Exception as save_e: + logger.error(f"Failed to save Ternary LogReg report: {save_e}") + else: + logger.warning("IOManager not provided, skipping saving Ternary LogReg report.") + + return baseline_model, report + + def run_ternary_random_forest_baseline(self, + X_train_scaled: pd.DataFrame, + y_train_dir_ordinal: pd.Series, + X_val_scaled: pd.DataFrame, + y_val_dir_ordinal: pd.Series, + io: Optional[IOManager] = None + ) -> Tuple[Optional[RandomForestClassifier], Dict[str, Any]]: + """ + Runs Ternary RandomForest baseline. Uses internal split for raw CI. + + Args: + X_train_scaled: SCALED training features. + y_train_dir_ordinal: ORDINAL {0, 1, 2} training labels. + X_val_scaled: SCALED validation features. + y_val_dir_ordinal: ORDINAL {0, 1, 2} validation labels. + io: IOManager instance for saving. + + Returns: + Tuple: (fitted_model, report_dict) + """ + # --- Configuration --- # + rf_cfg = self.config.get('baselines', {}).get('ternary_random_forest', {}) + n_estimators = rf_cfg.get('n_estimators', 100) + max_depth = rf_cfg.get('max_depth', None) + min_samples_split = rf_cfg.get('min_samples_split', 2) + min_samples_leaf = rf_cfg.get('min_samples_leaf', 1) + random_state = rf_cfg.get('random_state', 42) + n_jobs = rf_cfg.get('n_jobs', -1) + val_subset_split_ratio = rf_cfg.get('val_subset_split_ratio', 0.2) + val_subset_shuffle = rf_cfg.get('val_subset_shuffle', False) + ci_confidence_level = rf_cfg.get('ci_confidence_level', 0.95) + chance_level = 1/3 + + logger.info(f"Running Ternary RandomForest baseline (Baseline 4) on {X_train_scaled.shape[1]} scaled features.") + logger.info(f" RF Params: n_estimators={n_estimators}, max_depth={max_depth}, n_jobs={n_jobs}") + + # --- Handle potential NaNs and Infs in input --- + logger.debug(f"Checking for NaNs/Infs before Ternary RF. X_train_scaled NaNs: {X_train_scaled.isnull().sum().sum()}, Infs: {np.isinf(X_train_scaled).sum().sum()}") + X_train_scaled_clean = pd.DataFrame(np.nan_to_num(X_train_scaled, nan=0.0, posinf=0.0, neginf=0.0), index=X_train_scaled.index, columns=X_train_scaled.columns) + logger.debug(f"Checking for NaNs/Infs before Ternary RF. X_val_scaled NaNs: {X_val_scaled.isnull().sum().sum()}, Infs: {np.isinf(X_val_scaled).sum().sum()}") + X_val_scaled_clean = pd.DataFrame(np.nan_to_num(X_val_scaled, nan=0.0, posinf=0.0, neginf=0.0), index=X_val_scaled.index, columns=X_val_scaled.columns) + # Add check after cleaning + if X_train_scaled_clean.isnull().sum().sum() > 0 or np.isinf(X_train_scaled_clean).sum().sum() > 0: + logger.warning("NaNs or Infs persisted after np.nan_to_num in run_ternary_random_forest_baseline (Train).") + if X_val_scaled_clean.isnull().sum().sum() > 0 or np.isinf(X_val_scaled_clean).sum().sum() > 0: + logger.warning("NaNs or Infs persisted after np.nan_to_num in run_ternary_random_forest_baseline (Val).") + # --- End Handle NaNs/Infs --- + + # --- Initialize Report --- # + report = { + "baseline_model_type": "RandomForestClassifier_Ternary", # Baseline 4 ID + "accuracy_val_subset": np.nan, + "ci_lower_bound": np.nan, # Raw CI LB (internal split vs 1/3) + "n_val_subset": 0, + "accuracy_orig_val": np.nan, + "classification_report_orig_val": "N/A", + } + + baseline_model = None + + try: + # Split train data into teach/validation subsets for raw CI calculation + X_teach, X_val_subset, y_teach, y_val_subset = train_test_split( + X_train_scaled_clean, y_train_dir_ordinal, # Use cleaned data + test_size=val_subset_split_ratio, + shuffle=val_subset_shuffle, + stratify=y_train_dir_ordinal if val_subset_shuffle else None + ) + report["n_val_subset"] = len(y_val_subset) + + if report["n_val_subset"] > 0: + logger.info(f"Fitting Ternary RF on teaching subset ({len(y_teach)} samples)...") + baseline_model = RandomForestClassifier( + n_estimators=n_estimators, + max_depth=max_depth, + min_samples_split=min_samples_split, + min_samples_leaf=min_samples_leaf, + random_state=random_state, + n_jobs=n_jobs + ) + baseline_model.fit(X_teach, y_teach) # Fit on cleaned subset + + logger.debug("Predicting on internal validation subset (Ternary RF)..") + y_pred_val_subset = baseline_model.predict(X_val_subset) # Predict on cleaned subset + acc_val_subset = accuracy_score(y_val_subset, y_pred_val_subset) + report["accuracy_val_subset"] = float(acc_val_subset) + + # Calculate Raw CI lower bound (vs 1/3) + k_correct = np.sum(y_pred_val_subset == y_val_subset) + n = report["n_val_subset"] + raw_ci_lb = self._calculate_ci_lower_bound(k_correct, n, p=chance_level, conf_level=ci_confidence_level) + report["ci_lower_bound"] = raw_ci_lb + logger.info(f"Baseline 4 (Ternary RF - Raw Internal Split): Acc={acc_val_subset:.3f}, CI_LB={raw_ci_lb:.3f} (vs {chance_level:.3f})") + else: + logger.warning("Internal validation subset empty for Ternary RF. Skipping raw CI calculation.") + + # Evaluate on original validation set + if baseline_model is not None and not X_val_scaled_clean.empty: + logger.info("Evaluating Ternary RF on original validation set...") + orig_val_pred = baseline_model.predict(X_val_scaled_clean) # Predict on cleaned validation data + orig_val_acc = accuracy_score(y_val_dir_ordinal, orig_val_pred) + report["accuracy_orig_val"] = float(orig_val_acc) + try: + report["classification_report_orig_val"] = classification_report(y_val_dir_ordinal, orig_val_pred, output_dict=True, zero_division=0) + except Exception as report_err: + logger.warning(f"Could not generate Ternary RF classification report: {report_err}") + report["classification_report_orig_val"] = "Error generating report" + logger.info(f"Baseline 4 (Ternary RF): Original validation set accuracy: {orig_val_acc:.3f}") + elif baseline_model is None: + logger.warning("Ternary RF model not trained. Skipping evaluation on original validation set.") + else: + logger.warning("Original validation set empty for Ternary RF.") + + except ImportError: + logger.error("Scikit-learn (RandomForestClassifier) not found. Cannot run Baseline 4.") + report["notes"] = "Skipped - scikit-learn not installed." + baseline_model = None + except Exception as e: + logger.error(f"Failed during Ternary RandomForest baseline: {e}", exc_info=True) + report["notes"] = f"Failed with error: {e}" + baseline_model = None + + # Save Report + if io: + try: + io.save_json(report, "baseline_ternary_rf_report", section='results', use_txt=True) + logger.info("Saved Ternary RandomForest baseline report.") + except Exception as save_e: + logger.error(f"Failed to save Ternary RF report: {save_e}") + else: + logger.warning("IOManager not provided, skipping saving Ternary RF report.") + + return baseline_model, report + + def _run_edge_filtered_check(self, + predictor: Any, + X_val: pd.DataFrame, + y_val_ordinal: pd.Series, + edge_threshold: float, + baseline_name: str, + is_ternary: bool, + ci_conf_level: float) -> Dict[str, Any]: + """ + Helper function to run edge-filtered checks on a trained baseline model or calibrator. + + Args: + predictor: The trained sklearn classifier or CalibratedClassifierCV instance. # Updated docstring + X_val: The original SCALED validation features. + y_val_ordinal: The original ORDINAL validation labels (0,1 or 0,1,2). + edge_threshold: The threshold |p_up - p_down| >= thr. + baseline_name: Name for logging (e.g., "Binary-LR"). + is_ternary: Flag indicating if the model/labels are ternary. + ci_conf_level: Confidence level for CI calculation. + + Returns: + Dictionary with edge-filtered metrics. + """ + edge_report = { + "edge_accuracy": np.nan, + "edge_ci_lower_bound": np.nan, + "edge_num_samples_val": len(y_val_ordinal), + "edge_num_samples_val_filtered": 0, + } + + if predictor is None: # Check predictor + logger.warning(f"{baseline_name} edge-filtered check skipped: Predictor not available.") + return edge_report + if X_val.empty or y_val_ordinal.empty: + logger.warning(f"{baseline_name} edge-filtered check skipped: Validation data is empty.") + return edge_report + + logger.info(f"Running edge-filtered check for {baseline_name} (Threshold={edge_threshold:.3f})...") + + try: + # Predict probabilities using the provided predictor (base or calibrated) + if not hasattr(predictor, "predict_proba"): + logger.error(f"{baseline_name} edge-filtered check failed: Predictor has no predict_proba method.") + return edge_report + + # --- Handle potential NaNs and Infs in validation data before prediction --- # + # Assume X_val is clean here, cleaning happens before calling run_checks + X_val_clean = X_val + # --- End Handle NaNs/Infs --- # + + p_pred = predictor.predict_proba(X_val_clean) # Use predictor.predict_proba + + if is_ternary: + # Ternary edge calculation: P(up) - P(down) + if p_pred.shape[1] != 3: + logger.error(f"{baseline_name} edge check failed: Expected 3 probability columns for ternary, got {p_pred.shape[1]}.") + return edge_report + edge = p_pred[:, 2] - p_pred[:, 0] + y_true_binary_for_edge = (y_val_ordinal == 2).astype(int) + else: # Binary case + # Binary edge calculation: P(up) - P(down) = 2 * P(up) - 1 + if p_pred.shape[1] != 2: + logger.error(f"{baseline_name} edge check failed: Expected 2 probability columns for binary, got {p_pred.shape[1]}.") + return edge_report + p_up = p_pred[:, 1] + edge = 2 * p_up - 1 + y_true_binary_for_edge = y_val_ordinal.astype(int) # Already 0 or 1 + + # Filter samples based on edge threshold + mask_edge = np.abs(edge) >= edge_threshold + n_filtered = int(mask_edge.sum()) + edge_report["edge_num_samples_val_filtered"] = n_filtered + + if n_filtered > 0: + # Evaluate accuracy on the filtered subset + y_true_filtered = y_true_binary_for_edge[mask_edge] + edge_pred_dir = (edge[mask_edge] > 0).astype(int) # Predict 1 if edge > 0, else 0 + + edge_acc = accuracy_score(y_true_filtered, edge_pred_dir) + edge_report["edge_accuracy"] = float(edge_acc) + + # Calculate Edge CI lower bound (always vs 0.5 chance for binary direction) + k_correct_edge = np.sum(y_true_filtered == edge_pred_dir) + edge_ci_lb = self._calculate_ci_lower_bound(k_correct_edge, n_filtered, p=0.5, conf_level=ci_conf_level) + edge_report["edge_ci_lower_bound"] = edge_ci_lb + logger.info(f"{baseline_name} Edge-Filtered: Acc={edge_acc:.3f}, CI_LB={edge_ci_lb:.3f}, N_filt={n_filtered}/{edge_report['edge_num_samples_val']}") + else: + logger.warning(f"{baseline_name} Edge-Filtered: No validation samples met edge threshold ({edge_threshold}). Edge metrics are NaN.") + + except Exception as edge_err: + logger.error(f"Error during {baseline_name} edge-filtered calculation: {edge_err}", exc_info=True) + + return edge_report + + # <<< NEW METHOD: Forward Baseline Check >>> + def run_forward_baseline_check( + self, + X_train_fold: pd.DataFrame, + y_train_fold_ordinal: pd.Series, + X_test_fold: pd.DataFrame, + y_test_fold_ordinal: Optional[pd.Series], + fold_num: int, + io: Optional[IOManager] = None + ) -> bool: + """ + Runs a forward-looking baseline check using the fold's train/test split. + + Trains a Binary Logistic Regression on the entire training fold's data + and evaluates its CI lower bound on the test fold's data. + + Args: + X_train_fold: SCALED training features for the fold. + y_train_fold_ordinal: ORDINAL {0, 1, 2} training labels for the fold. + X_test_fold: SCALED test features for the fold. + y_test_fold_ordinal: Optional ORDINAL {0, 1, 2} test labels for the fold. + fold_num: Current fold number. + io: IOManager instance (optional, for saving detailed report). + + Returns: + bool: True if the forward baseline check passes the threshold, False otherwise. + """ + logger.info(f"--- Fold {fold_num}: Running Forward Baseline Check --- ") + gate_cfg = self.config.get('validation_gates', {}) + fwd_gate_cfg = gate_cfg.get('forward_baseline', {}) # New config section + run_fwd_check = fwd_gate_cfg.get('enabled', True) # Check if enabled + ci_threshold = fwd_gate_cfg.get('ci_threshold', 0.52) # Threshold for this check + + if not run_fwd_check: + logger.info(f"Fold {fold_num}: Skipping forward baseline check as per config.") + return True # Pass if disabled + + # --- Check Prerequisites --- + if X_train_fold is None or X_train_fold.empty or \ + y_train_fold_ordinal is None or y_train_fold_ordinal.empty or \ + X_test_fold is None or X_test_fold.empty or \ + y_test_fold_ordinal is None or y_test_fold_ordinal.empty: + logger.error(f"Fold {fold_num}: Missing or empty train/test data/labels for forward baseline check. Check Failed.") + return False + + # --- Prepare Binary Labels --- # + try: + y_train_fold_binary = y_train_fold_ordinal.map({0: 0, 1: 1, 2: 1}) + y_test_fold_binary = y_test_fold_ordinal.map({0: 0, 1: 1, 2: 1}) + if y_train_fold_binary.isnull().any() or y_test_fold_binary.isnull().any(): + raise ValueError("NaNs found after mapping ordinal to binary labels.") + except Exception as e: + logger.error(f"Fold {fold_num}: Error deriving binary labels for forward check: {e}. Check Failed.") + return False + + # --- Clean Data (NaN/Inf) --- # + try: + X_train_clean = pd.DataFrame(np.nan_to_num(X_train_fold, nan=0.0, posinf=0.0, neginf=0.0), index=X_train_fold.index, columns=X_train_fold.columns) + X_test_clean = pd.DataFrame(np.nan_to_num(X_test_fold, nan=0.0, posinf=0.0, neginf=0.0), index=X_test_fold.index, columns=X_test_fold.columns) + if X_train_clean.isnull().values.any() or np.isinf(X_train_clean).values.any() or \ + X_test_clean.isnull().values.any() or np.isinf(X_test_clean).values.any(): + logger.warning(f"Fold {fold_num}: NaNs/Infs persisted after cleaning in forward check.") + except Exception as e: + logger.error(f"Fold {fold_num}: Error cleaning data for forward check: {e}. Check Failed.") + return False + + # --- Train Model --- # + model = None + try: + baseline_cfg = self.config.get('baselines', {}).get('logistic_regression', {}) + max_iter = baseline_cfg.get('max_iter', 1000) + solver = baseline_cfg.get('solver', 'lbfgs') + random_state = baseline_cfg.get('random_state', 42) + + logger.info(f"Fold {fold_num}: Fitting Binary LogReg on full training fold ({len(y_train_fold_binary)} samples) for forward check...") + model = LogisticRegression( + max_iter=max_iter, + solver=solver, + random_state=random_state + ) + model.fit(X_train_clean, y_train_fold_binary) + except Exception as e: + logger.error(f"Fold {fold_num}: Failed to train LogReg model for forward check: {e}", exc_info=True) + return False # Fail check if model training fails + + # --- Evaluate on Test Set & Calculate CI LB --- # + ci_lb = np.nan + accuracy = np.nan + n_test = len(y_test_fold_binary) + status = "FAIL" + try: + logger.debug(f"Fold {fold_num}: Predicting on test fold ({n_test} samples) for forward check...") + y_pred_test = model.predict(X_test_clean) + accuracy = accuracy_score(y_test_fold_binary, y_pred_test) + k_correct = np.sum(y_pred_test == y_test_fold_binary) + + # Get confidence level from config (reuse from binary baseline section?) + bin_baseline_cfg = self.config.get('validation_gates', {}).get('baseline_binary', {}) + ci_conf_level = bin_baseline_cfg.get('ci_confidence_level', 0.95) + + ci_lb = self._calculate_ci_lower_bound(k_correct, n_test, p=0.5, conf_level=ci_conf_level) + + if pd.isna(ci_lb): + logger.warning(f"Forward Baseline Check (Fold {fold_num}): CI LB calculation resulted in NaN (Acc={accuracy:.4f}, N={n_test}). Check Failed.") + status = "FAIL (NaN)" + elif ci_lb >= ci_threshold: + logger.info(f"Forward Baseline Check PASSED (Fold {fold_num}): Test CI LB {ci_lb:.4f} >= {ci_threshold} (Acc={accuracy:.4f}, N={n_test})") + status = "PASS" + else: + logger.error(f"Forward Baseline Check FAILED (Fold {fold_num}): Test CI LB {ci_lb:.4f} < {ci_threshold} (Acc={accuracy:.4f}, N={n_test})") + status = "FAIL" + + except Exception as e: + logger.error(f"Fold {fold_num}: Error during prediction or CI calculation for forward check: {e}", exc_info=True) + status = "FAIL (Error)" + + # --- Save Detailed Report (Optional) --- # + if io: + report = { + "fold_num": fold_num, + "check_type": "forward_baseline_logistic_regression", + "n_train": len(y_train_fold_binary), + "n_test": n_test, + "test_accuracy": accuracy, + "test_ci_lower_bound": ci_lb, + "ci_threshold": ci_threshold, + "status": status + } + try: + io.save_json(report, f"forward_baseline_report_fold_{fold_num}", section='results', use_txt=True) + except Exception as save_e: + logger.error(f"Failed to save forward baseline report for fold {fold_num}: {save_e}") + + return status == "PASS" + # <<< END NEW METHOD >>> + + # --- Main Check Runner --- # + def run_checks(self, + X_train: pd.DataFrame, + y_train_ordinal: pd.Series, + X_val: pd.DataFrame, + y_val_ordinal: Optional[pd.Series], # Now always ordinal (0,1 or 0,1,2) + fold_num: int, + io_manager: Optional[IOManager] + ) -> bool: # Returns True if checks allow pipeline continuation + """ + Runs ALL enabled baseline checks (binary and ternary, respecting individual flags). + + Args: + X_train: SCALED training features. + y_train_ordinal: ORDINAL training labels {0, 1} or {0, 1, 2}. + X_val: SCALED validation features. + y_val_ordinal: ORDINAL validation labels {0, 1} or {0, 1, 2}. CAN BE NONE if val set is empty. + fold_num: Current fold number for logging/saving. + io_manager: IOManager instance. + + Returns: + bool: True if all enabled and required checks passed thresholds, False otherwise. + Returns True if master baseline check flag is disabled. + """ + logger.info(f"--- Fold {fold_num}: Running Baseline Checks (Binary & Ternary if enabled) ---") + overall_checks_passed = True # Assume passing unless a gate fails + check_reports = [] # Store reports from individual baselines + summary_data = [] # <<< ADDED: List to store data for the summary table + + # --- Check if Validation Labels Exist --- # + val_labels_exist = y_val_ordinal is not None and not y_val_ordinal.empty + if not val_labels_exist: + logger.warning(f"Fold {fold_num}: Validation labels (y_val_ordinal) are missing or empty. Checks requiring validation data will be skipped or may fail.") + + # --- Determine if checks should run at all (Master Switch) --- # + gate_cfg = self.config.get('validation_gates', {}) + run_baseline_check = gate_cfg.get('run_baseline_check', True) + if not run_baseline_check: + logger.info(f"Fold {fold_num}: Skipping baseline checks as per master config (validation_gates.run_baseline_check = False).") + return True + + # --- Prepare Binary Labels (Derive from Ordinal) --- # + # Mapping: 0 -> 0 (Down), 1 -> 1 (Flat -> Up), 2 -> 1 (Up -> Up) + y_train_binary = y_train_ordinal.map({0: 0, 1: 1, 2: 1}) if y_train_ordinal is not None else None + y_val_binary = y_val_ordinal.map({0: 0, 1: 1, 2: 1}) if val_labels_exist else None + if y_train_binary is None: + logger.error(f"Fold {fold_num}: Failed to derive binary training labels. Aborting baseline checks.") + raise SystemExit(f"Fold {fold_num}: Failed to derive binary training labels for baseline checks.") + logger.debug(f"Fold {fold_num}: Derived binary labels for checks.") + # --- End Prepare Binary Labels --- # + + # --- Initialize flags for the new "either/or" raw gate logic --- # + lr_raw_passed = False + rf_raw_passed = False + # --- Initialize flags for the new "either/or" edge gate logic --- # + lr_edge_passed = False + rf_edge_passed = False + + # --- Run BINARY CHECKS (Baselines 1 & 2) --- # + logger.info(f"--- Fold {fold_num}: Evaluating BINARY Baselines ---") + binary_gate_cfg = gate_cfg.get('baseline_binary', {}) + binary_ci_threshold = binary_gate_cfg.get('ci_threshold', 0.51) + # --- Add config for edge-filtered check --- # + binary_edge_threshold_value = binary_gate_cfg.get('edge_threshold_value', 0.1) # Threshold for *calculating* edge metrics + binary_edge_ci_threshold_gate = binary_gate_cfg.get('edge_ci_threshold_gate', 0.60) # Threshold for the *gate* + run_binary_lr_edge_check = binary_gate_cfg.get('run_logistic_regression_edge_check', True) # Specific flag for edge check + # --- End Add config --- # + run_binary_lr = binary_gate_cfg.get('run_logistic_regression', True) + run_binary_rf = binary_gate_cfg.get('run_random_forest', False) + + # Baseline 1: Binary Logistic Regression + if run_binary_lr: + logger.info(f"--- Fold {fold_num}: Running Binary Logistic Regression (Baseline 1) ---") + y_val_binary_input = y_val_binary if val_labels_exist else pd.Series(dtype=int) + # <<< Capture both base_lr and final_predictor >>> + lr_model, lr_predictor, lr_report = self.run_binary_logistic_baseline(X_train, y_train_binary, X_val, y_val_binary_input, self.io) + check_reports.append(lr_report) + + # --- Raw CI Check (Failure Gate - uses BASE model results stored in report) --- # + raw_ci_status = "SKIPPED (NaN)" + raw_ci_value = np.nan + if 'ci_lower_bound' in lr_report and not pd.isna(lr_report['ci_lower_bound']): + raw_ci_value = lr_report['ci_lower_bound'] + if raw_ci_value < binary_ci_threshold: + # overall_checks_passed = False # <<< MODIFIED: Don't fail overall gate here + raw_ci_status = "FAIL (Threshold)" + else: + lr_raw_passed = True # <<< ADDED: Set flag if passed + raw_ci_status = "PASS (Threshold)" + else: + logger.warning(f"BASELINE CHECK SKIPPED (Fold {fold_num}): Binary LogReg Raw CI LB not calculated or NaN. Cannot determine pass/fail.") + # overall_checks_passed = False # Treat missing mandatory check as failure -> This will be handled by the new gate + raw_ci_status = "SKIPPED (NaN)" + + summary_data.append({ + "Baseline": "Binary LogReg (Raw)", "Metric": "CI Lower Bound", + "Value": f"{raw_ci_value:.4f}", "Threshold": f">= {binary_ci_threshold:.2f}", + "Status": raw_ci_status, + "Is Gate": "See Combined" # <<< MODIFIED: Indicate combined gate + }) + + # --- Edge-Filtered Check (Failure Gate - uses FINAL predictor) --- # + edge_ci_status = "SKIPPED (Config)" + edge_ci_value = np.nan + if run_binary_lr_edge_check: + if overall_checks_passed: # Only run if raw check passed + # Use lr_predictor (which is either base LR or CalibratedClassifierCV) + if lr_predictor is not None and not X_val.empty and val_labels_exist: + logger.info(f"--- Fold {fold_num}: Running Binary LogReg Edge-Filtered Check (Gate Threshold={binary_edge_ci_threshold_gate:.2f}) ---") + # Get confidence level from config (used inside helper) + baseline_lr_cfg = self.config.get('baselines', {}).get('logistic_regression', {}) + ci_conf_level = baseline_lr_cfg.get('ci_confidence_level', 0.95) + + # <<< Pass lr_predictor to edge check >>> + edge_report_lr = self._run_edge_filtered_check( + predictor=lr_predictor, # Pass the final predictor (calibrated or not) + X_val=X_val, # Pass original scaled X_val (assumed clean) + y_val_ordinal=y_val_binary_input, # Pass binary labels (0/1) + edge_threshold=binary_edge_threshold_value, # Use config value for calculation + baseline_name="Binary LogReg", + is_ternary=False, + ci_conf_level=ci_conf_level + ) + check_reports.append(edge_report_lr) + + if 'edge_ci_lower_bound' in edge_report_lr and not pd.isna(edge_report_lr['edge_ci_lower_bound']): + edge_ci_value = edge_report_lr['edge_ci_lower_bound'] + if edge_ci_value < binary_edge_ci_threshold_gate: + # overall_checks_passed = False # <<< MODIFIED: Don't fail overall gate here + edge_ci_status = "FAIL (Threshold)" + else: # Correctly indented else for `if edge_ci_value < ...` + lr_edge_passed = True # <<< ADDED: Set flag if passed + edge_ci_status = "PASS (Threshold)" + else: # Correctly indented else for `if 'edge_ci_lower_bound' in ...` + logger.warning(f"BASELINE CHECK SKIPPED (Fold {fold_num}): Binary LogReg Edge CI LB not calculated or NaN. Cannot determine pass/fail.") + # overall_checks_passed = False # Treat missing mandatory check as failure -> Handled by combined gate + edge_ci_status = "SKIPPED (NaN)" # Added status update + else: # Correctly indented else for `if lr_predictor is not None ...` + logger.warning(f"Skipping Binary LogReg Edge-Filtered check: Predictor not available, validation data empty, or validation labels missing.") + # overall_checks_passed = False # Consider it a failure if mandatory edge check couldn't run -> Handled by combined gate + edge_ci_status = "SKIPPED (Prereq)" # Added status update + else: # Correctly indented else for `if overall_checks_passed:` + logger.info(f"Skipping Binary LogReg Edge-Filtered check as Raw CI check already failed.") + edge_ci_status = "SKIPPED (Prior Fail)" + # else: # run_binary_lr_edge_check is False, status remains SKIPPED (Config) + + summary_data.append({ + "Baseline": "Binary LogReg (Edge)", "Metric": "CI Lower Bound", + "Value": f"{edge_ci_value:.4f}", "Threshold": f">= {binary_edge_ci_threshold_gate:.2f}", + "Status": edge_ci_status, + "Is Gate": "See Combined Edge" if run_binary_lr_edge_check else "No" # <<< MODIFIED + }) + # --- End Edge-Filtered Check --- # + else: # Correctly indented else for `if run_binary_lr:` + logger.info(f"Fold {fold_num}: Skipping Binary Logistic Regression baseline as per config.") + summary_data.append({"Baseline": "Binary LogReg (Raw)", "Metric": "CI Lower Bound", "Value": "N/A", "Threshold": "N/A", "Status": "SKIPPED (Config)", "Is Gate": True}) + summary_data.append({"Baseline": "Binary LogReg (Edge)", "Metric": "CI Lower Bound", "Value": "N/A", "Threshold": "N/A", "Status": "SKIPPED (Config)", "Is Gate": True}) # Gate status depends on run_binary_lr_edge_check, but if LR is skipped, edge is too + + # Baseline 2: Binary Random Forest (Monitor Only -> ABORT if enabled) + rf_model = None # Initialize rf_model + rf_ci_status = "SKIPPED (Config)" + rf_ci_value = np.nan + rf_edge_ci_status = "SKIPPED (Config)" + rf_edge_ci_value = np.nan + + # --- Corrected Config Key: Use same section 'baseline_binary' --- # + run_binary_rf = binary_gate_cfg.get('run_random_forest', False) + binary_rf_ci_threshold = binary_gate_cfg.get('binary_rf_ci_lb', 0.54) # G2 Threshold + run_binary_rf_edge_check = binary_gate_cfg.get('run_random_forest_edge_check', False) # G9 Flag + binary_rf_edge_ci_threshold_gate = binary_gate_cfg.get('edge_binary_rf_ci_lb', 0.62) # G9 Threshold + # --- End Config Key Correction --- # + + if run_binary_rf: + logger.info(f"--- Fold {fold_num}: Running Binary Random Forest (Baseline 2) ---") + y_val_binary_input = y_val_binary if val_labels_exist else pd.Series(dtype=int) + rf_model, rf_report = self.run_binary_random_forest_baseline(X_train, y_train_binary, X_val, y_val_binary_input, self.io) + check_reports.append(rf_report) + if 'ci_lower_bound' in rf_report and not pd.isna(rf_report['ci_lower_bound']): + rf_ci_value = rf_report['ci_lower_bound'] + # --- G2: Abort if enabled and fails threshold --- # + if rf_ci_value < binary_rf_ci_threshold: + # logger.error(...) # Keep as error log but also trigger failure + # logger.error(f"BASELINE FAILED (GATE): Binary RF Raw CI LB {rf_ci_value:.4f} < {binary_rf_ci_threshold}") + # overall_checks_passed = False # <<< MODIFIED: Don't fail overall gate here (G2) + rf_ci_status = "FAIL (Threshold)" + else: # Correctly indented else for `if rf_ci_value < ...` + rf_raw_passed = True # <<< ADDED: Set flag if passed + rf_ci_status = "PASS (Threshold)" + else: # Correctly indented else for `if 'ci_lower_bound' in ...` + logger.warning(f"BASELINE CHECK SKIPPED (Fold {fold_num}): Binary RF Raw CI LB not calculated or NaN. Cannot determine pass/fail.") + # overall_checks_passed = False # <<< MODIFIED: Don't fail overall gate here (G2) + rf_ci_status = "SKIPPED (NaN)" + + # --- G9: Edge Binary RF Check (Abort if enabled) --- # + if run_binary_rf_edge_check: + if overall_checks_passed: # Only run if raw check passed + if rf_model is not None and not X_val.empty and val_labels_exist: + logger.info(f"--- Fold {fold_num}: Running Binary RF Edge-Filtered Check (Gate Threshold={binary_rf_edge_ci_threshold_gate:.2f}) ---") + baseline_rf_cfg = self.config.get('baselines', {}).get('random_forest', {}) + ci_conf_level = baseline_rf_cfg.get('ci_confidence_level', 0.95) + + edge_report_rf = self._run_edge_filtered_check( + predictor=rf_model, # <<< Fix: Changed model to predictor + X_val=X_val, + y_val_ordinal=y_val_binary_input, + edge_threshold=binary_edge_threshold_value, + baseline_name="Binary RF", + is_ternary=False, + ci_conf_level=ci_conf_level + ) + check_reports.append(edge_report_rf) + + if 'edge_ci_lower_bound' in edge_report_rf and not pd.isna(edge_report_rf['edge_ci_lower_bound']): + rf_edge_ci_value = edge_report_rf['edge_ci_lower_bound'] + if rf_edge_ci_value < binary_rf_edge_ci_threshold_gate: + # overall_checks_passed = False # <<< MODIFIED: Don't fail overall gate here (G9) + rf_edge_ci_status = "FAIL (Threshold)" + else: # Correctly indented else for `if rf_edge_ci_value < ...` + rf_edge_passed = True # <<< ADDED: Set flag if passed + rf_edge_ci_status = "PASS (Threshold)" + else: # Correctly indented else for `if 'edge_ci_lower_bound' in ...` + logger.warning(f"BASELINE CHECK SKIPPED (Fold {fold_num}): Binary RF Edge CI LB not calculated or NaN. Cannot determine pass/fail.") # Corrected log message + # overall_checks_passed = False # Fail if mandatory G9 cannot run -> Handled by combined gate + rf_edge_ci_status = "SKIPPED (NaN)" + else: # Correctly indented else for `if rf_model is not None ...` + logger.warning(f"Skipping Binary RF Edge-Filtered check: Model not trained, validation data empty, or validation labels missing.") + # overall_checks_passed = False # Fail if mandatory G9 cannot run -> Handled by combined gate + rf_edge_ci_status = "SKIPPED (Prereq)" # Added status update + else: # Correctly indented else for `if overall_checks_passed:` + logger.info(f"Skipping Binary RF Edge-Filtered check as Raw CI check already failed.") + rf_edge_ci_status = "SKIPPED (Prior Fail)" + # else: # run_binary_rf_edge_check is False, status remains SKIPPED (Config) + # --- End G9 --- # + else: # Correctly indented else for `if run_binary_rf:` + # If RF baseline is not run, status remains SKIPPED (Config) for both raw and edge + pass # No action needed, statuses already initialized + + summary_data.append({ # Add RF Raw to summary + "Baseline": "Binary RF (Raw)", "Metric": "CI Lower Bound", + "Value": f"{rf_ci_value:.4f}", "Threshold": f">= {binary_rf_ci_threshold:.2f}", + "Status": rf_ci_status, + "Is Gate": "See Combined" # <<< MODIFIED: Indicate combined gate + }) + summary_data.append({ # Add RF Edge to summary + "Baseline": "Binary RF (Edge)", "Metric": "CI Lower Bound", + "Value": f"{rf_edge_ci_value:.4f}", "Threshold": f">= {binary_rf_edge_ci_threshold_gate:.2f}", + "Status": rf_edge_ci_status, "Is Gate": "See Combined Edge" if run_binary_rf_edge_check else "No" # <<< MODIFIED + }) + # --- End BINARY CHECKS --- # + + # --- New Gate: Abort if NEITHER Raw LR NOR Raw RF passed --- # + # This logic applies if *any* binary raw check was configured to run. + # If both were disabled, this gate is bypassed. + lr_check_enabled = binary_gate_cfg.get('run_logistic_regression', True) + rf_check_enabled = binary_gate_cfg.get('run_random_forest', False) + if (lr_check_enabled or rf_check_enabled) and not (lr_raw_passed or rf_raw_passed): + logger.error(f"BASELINE FAILED (COMBINED GATE): Neither Binary LogReg Raw ({'PASS' if lr_raw_passed else 'FAIL/SKIPPED'}) nor Binary RF Raw ({'PASS' if rf_raw_passed else 'FAIL/SKIPPED'}) passed their respective thresholds. At least one must pass. Aborting fold.") + overall_checks_passed = False # Set overall flag to ensure the fold aborts + + # --- New Gate: Abort if NEITHER Enabled Edge LR NOR Enabled Edge RF passed --- # + lr_edge_check_enabled = binary_gate_cfg.get('run_logistic_regression_edge_check', True) + rf_edge_check_enabled = binary_gate_cfg.get('run_random_forest_edge_check', False) + if (lr_edge_check_enabled or rf_edge_check_enabled) and not (lr_edge_passed or rf_edge_passed): + logger.error(f"BASELINE FAILED (COMBINED EDGE GATE): Neither Binary LogReg Edge ({'PASS' if lr_edge_passed else 'FAIL/SKIPPED'}) nor Binary RF Edge ({'PASS' if rf_edge_passed else 'FAIL/SKIPPED'}) passed their respective thresholds (when enabled). At least one enabled edge check must pass. Aborting fold.") + overall_checks_passed = False # Set overall flag to ensure the fold aborts + + # --- Run TERNARY CHECKS (Baselines 3 & 4) - Monitor Only (unless edge check fails) --- # + # Initialize ternary statuses + ternary_lr_ci_status = "SKIPPED (Config)" + ternary_lr_ci_value = np.nan + ternary_rf_ci_status = "SKIPPED (Config)" + ternary_rf_ci_value = np.nan + ternary_lr_edge_ci_status = "SKIPPED (Config)" + ternary_lr_edge_ci_value = np.nan + ternary_rf_edge_ci_status = "SKIPPED (Config)" + ternary_rf_edge_ci_value = np.nan + ternary_labels_missing = False + lr_model_t = None # Initialize ternary models + rf_model_t = None + + use_ternary_config = self.config.get('gru', {}).get('use_ternary', False) + if use_ternary_config: + logger.info(f"--- Fold {fold_num}: Evaluating TERNARY Baselines (use_ternary=True in config) --- ") + ternary_gate_cfg = gate_cfg.get('baseline_ternary', {}) + # --- Add config keys for Ternary checks --- # + ternary_ci_threshold = ternary_gate_cfg.get('ternary_ci_lb', 0.40) # G3 Threshold + ternary_rf_ci_threshold = ternary_gate_cfg.get('ternary_rf_ci_lb', 0.42) # G4 Threshold + run_ternary_lr = ternary_gate_cfg.get('run_logistic_regression', True) + run_ternary_rf = ternary_gate_cfg.get('run_random_forest', False) + run_ternary_lr_edge_check = ternary_gate_cfg.get('run_logistic_regression_edge_check', False) # G10 Flag + ternary_edge_threshold_value = ternary_gate_cfg.get('edge_threshold_value', 0.1) + ternary_lr_edge_ci_threshold_gate = ternary_gate_cfg.get('edge_ternary_ci_lb', 0.57) # G10 Threshold + run_ternary_rf_edge_check = ternary_gate_cfg.get('run_random_forest_edge_check', False) # G11 Flag + ternary_rf_edge_ci_threshold_gate = ternary_gate_cfg.get('edge_ternary_rf_ci_lb', 0.58) # G11 Threshold + # --- End Add config keys --- # + + if not val_labels_exist: + logger.error(f"Fold {fold_num}: Cannot run Ternary checks because validation labels are missing.") + ternary_labels_missing = True + # NOTE: overall_checks_passed = False removed - handled by final check if mandatory + ternary_lr_ci_status = "SKIPPED (Labels Missing)" + ternary_rf_ci_status = "SKIPPED (Labels Missing)" + ternary_lr_edge_ci_status = "SKIPPED (Labels Missing)" + ternary_rf_edge_ci_status = "SKIPPED (Labels Missing)" + else: + # Baseline 3: Ternary Logistic Regression (Raw: Warn, Edge: Warn) + if run_ternary_lr: + logger.info(f"--- Fold {fold_num}: Running Ternary Logistic Regression (Baseline 3) ---") + lr_model_t, lr_report_t = self.run_ternary_logistic_baseline(X_train, y_train_ordinal, X_val, y_val_ordinal, self.io) + check_reports.append(lr_report_t) + if 'ci_lower_bound' in lr_report_t and not pd.isna(lr_report_t['ci_lower_bound']): + ternary_lr_ci_value = lr_report_t['ci_lower_bound'] + if ternary_lr_ci_value < ternary_ci_threshold: + logger.warning(f"MONITOR (BASELINE WARN): Ternary LogReg Raw CI LB {ternary_lr_ci_value:.4f} < {ternary_ci_threshold}") + ternary_lr_ci_status = "WARN (Low Raw CI)" + else: # Correctly indented else for `if ternary_lr_ci_value < ...` + ternary_lr_ci_status = "PASS (Raw CI)" + else: # Correctly indented else for `if 'ci_lower_bound' in ...` + logger.warning(f"MONITOR (BASELINE SKIPPED): Ternary LogReg Raw CI LB not calculated or NaN.") + ternary_lr_ci_status = "SKIPPED (NaN)" + + # --- G10: Edge Ternary LR Check (Warn) --- # + if run_ternary_lr_edge_check: # Correctly indented if + if lr_model_t is not None: + logger.info(f"--- Fold {fold_num}: Running Ternary LogReg Edge-Filtered Check (Warn Threshold={ternary_lr_edge_ci_threshold_gate:.2f}) ---") + baseline_tlr_cfg = self.config.get('baselines', {}).get('multinomial_logistic_regression', {}) + ci_conf_level = baseline_tlr_cfg.get('ci_confidence_level', 0.95) + + edge_report_tlr = self._run_edge_filtered_check( + predictor=lr_model_t, # <<< Fix: Changed model to predictor + X_val=X_val, + y_val_ordinal=y_val_ordinal, # Pass original ternary labels + edge_threshold=ternary_edge_threshold_value, + baseline_name="Ternary LogReg", + is_ternary=True, + ci_conf_level=ci_conf_level + ) + check_reports.append(edge_report_tlr) + if 'edge_ci_lower_bound' in edge_report_tlr and not pd.isna(edge_report_tlr['edge_ci_lower_bound']): + ternary_lr_edge_ci_value = edge_report_tlr['edge_ci_lower_bound'] + if ternary_lr_edge_ci_value < ternary_lr_edge_ci_threshold_gate: + logger.warning(f"MONITOR (BASELINE WARN): Ternary LogReg Edge CI LB {ternary_lr_edge_ci_value:.4f} < {ternary_lr_edge_ci_threshold_gate}") + ternary_lr_edge_ci_status = "WARN (Low Edge CI)" + else: # Correctly indented else for `if ternary_lr_edge_ci_value < ...` + ternary_lr_edge_ci_status = "PASS (Edge CI)" + else: # Correctly indented else for `if 'edge_ci_lower_bound' in ...` + logger.warning(f"MONITOR (BASELINE SKIPPED): Ternary LogReg Edge CI LB not calculated or NaN.") + ternary_lr_edge_ci_status = "SKIPPED (NaN)" + else: # Correctly indented else for `if lr_model_t is not None:` + logger.warning(f"Skipping Ternary LogReg Edge-Filtered check: Model not trained.") + ternary_lr_edge_ci_status = "SKIPPED (Model)" + else: # Correctly indented else for `if run_ternary_lr_edge_check:` + ternary_lr_edge_ci_status = "SKIPPED (Config)" + # --- End G10 --- # + else: # Correctly indented else for `if run_ternary_lr:` + ternary_lr_ci_status = "SKIPPED (Config)" + ternary_lr_edge_ci_status = "SKIPPED (Config)" + + # Baseline 4: Ternary Random Forest (Raw: Warn, Edge: Warn) + if run_ternary_rf: + logger.info(f"--- Fold {fold_num}: Running Ternary Random Forest (Baseline 4) ---") + rf_model_t, rf_report_t = self.run_ternary_random_forest_baseline(X_train, y_train_ordinal, X_val, y_val_ordinal, self.io) + check_reports.append(rf_report_t) + if 'ci_lower_bound' in rf_report_t and not pd.isna(rf_report_t['ci_lower_bound']): + ternary_rf_ci_value = rf_report_t['ci_lower_bound'] + if ternary_rf_ci_value < ternary_rf_ci_threshold: # Use G4 threshold + logger.warning(f"MONITOR (BASELINE WARN): Ternary RF Raw CI LB {ternary_rf_ci_value:.4f} < {ternary_rf_ci_threshold}") + ternary_rf_ci_status = "WARN (Low Raw CI)" + else: # Correctly indented else for `if ternary_rf_ci_value < ...` + ternary_rf_ci_status = "PASS (Raw CI)" + else: # Correctly indented else for `if 'ci_lower_bound' in ...` + logger.warning(f"MONITOR (BASELINE SKIPPED): Ternary RF Raw CI LB not calculated or NaN.") + ternary_rf_ci_status = "SKIPPED (NaN)" + + # --- G11: Edge Ternary RF Check (Warn) --- # + if run_ternary_rf_edge_check: # Correctly indented if + if rf_model_t is not None: + logger.info(f"--- Fold {fold_num}: Running Ternary RF Edge-Filtered Check (Warn Threshold={ternary_rf_edge_ci_threshold_gate:.2f}) ---") + baseline_trf_cfg = self.config.get('baselines', {}).get('ternary_random_forest', {}) + ci_conf_level = baseline_trf_cfg.get('ci_confidence_level', 0.95) + + edge_report_trf = self._run_edge_filtered_check( + predictor=rf_model_t, # <<< Fix: Changed model to predictor + X_val=X_val, + y_val_ordinal=y_val_ordinal, + edge_threshold=ternary_edge_threshold_value, + baseline_name="Ternary RF", + is_ternary=True, + ci_conf_level=ci_conf_level + ) + check_reports.append(edge_report_trf) + if 'edge_ci_lower_bound' in edge_report_trf and not pd.isna(edge_report_trf['edge_ci_lower_bound']): + ternary_rf_edge_ci_value = edge_report_trf['edge_ci_lower_bound'] + if ternary_rf_edge_ci_value < ternary_rf_edge_ci_threshold_gate: + logger.warning(f"MONITOR (BASELINE WARN): Ternary RF Edge CI LB {ternary_rf_edge_ci_value:.4f} < {ternary_rf_edge_ci_threshold_gate}") + ternary_rf_edge_ci_status = "WARN (Low Edge CI)" + else: # Correctly indented else for `if ternary_rf_edge_ci_value < ...` + ternary_rf_edge_ci_status = "PASS (Edge CI)" + else: # Correctly indented else for `if 'edge_ci_lower_bound' in ...` + logger.warning(f"MONITOR (BASELINE SKIPPED): Ternary RF Edge CI LB not calculated or NaN.") + ternary_rf_edge_ci_status = "SKIPPED (NaN)" + else: # Correctly indented else for `if rf_model_t is not None:` + logger.warning(f"Skipping Ternary RF Edge-Filtered check: Model not trained.") + ternary_rf_edge_ci_status = "SKIPPED (Model)" + else: # Correctly indented else for `if run_ternary_rf_edge_check:` + ternary_rf_edge_ci_status = "SKIPPED (Config)" + # --- End G11 --- # + else: # Correctly indented else for `if run_ternary_rf:` + ternary_rf_ci_status = "SKIPPED (Config)" + ternary_rf_edge_ci_status = "SKIPPED (Config)" + + # Add ternary results to summary table + summary_data.append({ # Ternary LR Raw + "Baseline": "Ternary LogReg (Raw)", "Metric": "CI Lower Bound", + "Value": f"{ternary_lr_ci_value:.4f}", "Threshold": f">= {ternary_ci_threshold:.2f}", + "Status": ternary_lr_ci_status, "Is Gate": False # G3: Warn + }) + summary_data.append({ # Ternary LR Edge + "Baseline": "Ternary LogReg (Edge)", "Metric": "CI Lower Bound", + "Value": f"{ternary_lr_edge_ci_value:.4f}", "Threshold": f">= {ternary_lr_edge_ci_threshold_gate:.2f}", + "Status": ternary_lr_edge_ci_status, "Is Gate": False # G10: Warn + }) + summary_data.append({ # Ternary RF Raw + "Baseline": "Ternary RF (Raw)", "Metric": "CI Lower Bound", + "Value": f"{ternary_rf_ci_value:.4f}", "Threshold": f">= {ternary_rf_ci_threshold:.2f}", + "Status": ternary_rf_ci_status, "Is Gate": False # G4: Warn + }) + summary_data.append({ # Ternary RF Edge + "Baseline": "Ternary RF (Edge)", "Metric": "CI Lower Bound", + "Value": f"{ternary_rf_edge_ci_value:.4f}", "Threshold": f">= {ternary_rf_edge_ci_threshold_gate:.2f}", + "Status": ternary_rf_edge_ci_status, "Is Gate": False # G11: Warn + }) + # --- End TERNARY CHECKS --- # + + + # --- Aggregate and Save Overall Report --- # + overall_report = { + "fold_num": fold_num, + "use_ternary_config": use_ternary_config, + "baseline_checks_passed_mandatory": overall_checks_passed, # Rename key for clarity + "individual_reports": check_reports, + "summary_table_data": summary_data # Add summary data to report + } + if self.io: + try: + self.io.save_json(overall_report, f"baseline_overall_report_fold_{fold_num}", section='results', use_txt=True) + logger.info(f"Saved baseline overall report for Fold {fold_num}.") + except Exception as e: + logger.error(f"Failed to save baseline overall report for Fold {fold_num}: {e}") + # --- End Aggregation --- # + + # --- Format and Print Summary Table --- # + # Calculate column widths dynamically (optional, can use fixed widths) + widths = { + "Baseline": max(len(row["Baseline"]) for row in summary_data) if summary_data else 15, + "Metric": max(len(row["Metric"]) for row in summary_data) if summary_data else 15, + "Value": 10, + "Threshold": 10, + "Status": max(len(row["Status"]) for row in summary_data) if summary_data else 18, + "Is Gate": 8 + } + total_width = sum(widths.values()) + len(widths) * 3 + 1 # Account for spacing + + # Header + header = f"| {'Baseline':<{widths['Baseline']}} | {'Metric':<{widths['Metric']}} | {'Value':>{widths['Value']}} | {'Threshold':>{widths['Threshold']}} | {'Status':<{widths['Status']}} | {'Is Gate':<{widths['Is Gate']}} |" + separator = '-' * total_width + + # Build table string + table_str = f"\n{separator}\nBASELINE CHECK SUMMARY - Fold {fold_num}\n{separator}\n{header}\n{separator}\n" + for row in summary_data: + # Format NaN values consistently + value_str = "N/A" if pd.isna(row["Value"]) or row["Value"]=='nan' else str(row["Value"]) + threshold_str = str(row["Threshold"]) + status_str = str(row["Status"]) + gate_str = "Yes" if row["Is Gate"] else "No" + table_str += f"| {row['Baseline']:<{widths['Baseline']}} | {row['Metric']:<{widths['Metric']}} | {value_str:>{widths['Value']}} | {threshold_str:>{widths['Threshold']}} | {status_str:<{widths['Status']}} | {gate_str:<{widths['Is Gate']}} |\n" + table_str += f"{separator}\n" + + # Print the table via logger + logger.info(table_str) + # --- End Format and Print --- # + + # --- Final Check & Raise SystemExit --- # + # Add ternary label missing condition to mandatory failure check + mandatory_failed = not overall_checks_passed + # Check if ternary LR was supposed to run and labels were missing + ternary_lr_should_run = use_ternary_config and run_ternary_lr + if ternary_lr_should_run and ternary_labels_missing: + logger.error(f"Fold {fold_num}: FAILED MANDATORY baseline check due to missing ternary labels when required for Ternary LogReg.") + mandatory_failed = True + + if mandatory_failed: + logger.error(f"Fold {fold_num}: FAILED one or more MANDATORY baseline check gates. See summary table and logs.") + raise SystemExit(f"Fold {fold_num}: FAILED MANDATORY Baseline Check Gate") + else: + final_status_msg = f"Fold {fold_num}: PASSED MANDATORY baseline checks." + # Add note about monitored checks if any failed or warned + monitored_issues = [r for r in summary_data if not r["Is Gate"] and ("FAIL" in r["Status"] or "WARN" in r["Status"])] + if monitored_issues: + final_status_msg += " (Note: Some monitored baselines failed/warned - see table)." + else: + final_status_msg += " (Monitored baselines OK or skipped)." + logger.info(final_status_msg) + + return overall_checks_passed \ No newline at end of file diff --git a/gru_sac_predictor/src/calibrate.py b/gru_sac_predictor/src/calibrate.py new file mode 100644 index 00000000..aa7ecdd2 --- /dev/null +++ b/gru_sac_predictor/src/calibrate.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +import numpy as np +import matplotlib.pyplot as plt +from scipy.optimize import minimize_scalar +from scipy.special import expit, logit +from typing import Tuple + +__all__ = [ + "optimise_temperature", + "calibrate", + "reliability_curve", + "EDGE_THR", +] + +# ------------------------------------------------------------------ +# Hyper‑parameters +# ------------------------------------------------------------------ +# Minimum calibrated edge magnitude before we take a trade. +# EDGE_THR: float = 0.55 # Default value if not passed + + +# ------------------------------------------------------------------ +# Temperature scaling +# ------------------------------------------------------------------ + +def _nll_temperature(T: float, logit_p: np.ndarray, y_true: np.ndarray) -> float: + """Negative log‑likelihood of labels given scaled logits.""" + p_cal = expit(logit_p / T) + # Binary cross‑entropy (NLL) + eps = 1e-12 + p_cal = np.clip(p_cal, eps, 1 - eps) + nll = -(y_true * np.log(p_cal) + (1 - y_true) * np.log(1 - p_cal)) + return float(np.mean(nll)) + + +def optimise_temperature(p_raw: np.ndarray, y_true: np.ndarray) -> float: + """Return optimal temperature `T` that minimises NLL on `y_true`.""" + p_raw = p_raw.flatten().astype(float) + y_true = y_true.flatten().astype(int) + logit_p = logit(np.clip(p_raw, 1e-6, 1 - 1e-6)) + + res = minimize_scalar( + lambda T: _nll_temperature(T, logit_p, y_true), + bounds=(0.05, 10.0), + method="bounded", + ) + return float(res.x) + + +# ------------------------------------------------------------------ +# Public API +# ------------------------------------------------------------------ + +def calibrate(p_raw: np.ndarray, T: float) -> np.ndarray: + """Return temperature‑scaled probabilities.""" + return expit(logit(np.clip(p_raw, 1e-6, 1 - 1e-6)) / T) + + +def reliability_curve( + p_raw: np.ndarray, + y_true: np.ndarray, + n_bins: int = 10, + show: bool = False, +) -> Tuple[np.ndarray, np.ndarray]: + """Return (bin_centres, empirical_prob) and optionally plot reliability.""" + p_raw = p_raw.flatten() + y_true = y_true.flatten() + bins = np.linspace(0, 1, n_bins + 1) + bin_ids = np.digitize(p_raw, bins) - 1 # 0‑indexed + + bin_centres = 0.5 * (bins[:-1] + bins[1:]) + acc = np.zeros(n_bins) + for i in range(n_bins): + idx = bin_ids == i + if np.any(idx): + acc[i] = y_true[idx].mean() + else: + acc[i] = np.nan + + if show: + plt.figure(figsize=(4, 4)) + plt.plot([0, 1], [0, 1], "k--", label="perfect") + plt.plot(bin_centres, acc, "o-", label="empirical") + plt.xlabel("Predicted P(up)") + plt.ylabel("Actual frequency") + plt.title("Reliability curve") + plt.legend() + plt.grid(True) + plt.tight_layout() + return bin_centres, acc + + +# ------------------------------------------------------------------ +# Signal filter +# ------------------------------------------------------------------ + +def action_signal(p_cal: np.ndarray, edge_threshold: float = 0.55) -> np.ndarray: + """Return trading signal: 1, -1 or 0 based on calibrated edge threshold.""" + up = p_cal > edge_threshold + dn = p_cal < 1 - edge_threshold + return np.where(up, 1, np.where(dn, -1, 0)) \ No newline at end of file diff --git a/gru_sac_predictor/src/calibrator.py b/gru_sac_predictor/src/calibrator.py new file mode 100644 index 00000000..4915ccbb --- /dev/null +++ b/gru_sac_predictor/src/calibrator.py @@ -0,0 +1,371 @@ +""" +Calibration Component for GRU Model Probabilities. + +Provides methods for temperature scaling (Platt scaling) and generating +action signals based on calibrated probabilities. +""" + +import pandas as pd # Added for type hinting +import numpy as np +import matplotlib.pyplot as plt +from scipy.optimize import minimize_scalar +from scipy.special import expit, logit +from typing import Tuple, Optional, List, Dict, Any +import logging +import os +from sklearn.metrics import confusion_matrix # Added for Youden's J + +logger = logging.getLogger(__name__) + +class Calibrator: + """Handles probability calibration using temperature scaling.""" + + def __init__(self, edge_threshold: float): + """ + Initialize the Calibrator. + + Args: + edge_threshold (float): Minimum calibrated edge magnitude for taking a trade (e.g., 0.55 means P(up) > 0.55 or P(down) > 0.55). + """ + self.edge_threshold = edge_threshold + self.optimal_T: Optional[float] = None # Stores the calculated temperature + logger.info(f"Calibrator initialized with edge threshold: {self.edge_threshold}") + + def _nll_objective(self, T: float, logit_p: np.ndarray, y_true: np.ndarray) -> float: + """Negative log-likelihood objective function for temperature optimization.""" + if T <= 0: + return np.inf # Temperature must be positive + p_cal = expit(logit_p / T) + # Binary cross-entropy (NLL) + eps = 1e-12 # Epsilon for numerical stability + p_cal = np.clip(p_cal, eps, 1 - eps) + # Ensure y_true is broadcastable if necessary (should be 1D) + nll = -(y_true * np.log(p_cal) + (1 - y_true) * np.log(1 - p_cal)) + return float(np.mean(nll)) + + def _nll_objective_regularized(self, T: float, logit_p: np.ndarray, y_true: np.ndarray, reg_lambda: float) -> float: + """NLL objective with L2 regularization on T (deviation from 1).""" + base_nll = self._nll_objective(T, logit_p, y_true) + # Regularize T towards 1 (no scaling) + l2_penalty = reg_lambda * ((T - 1.0) ** 2) + return base_nll + l2_penalty + + def optimise_temperature(self, p_raw: np.ndarray, y_true: np.ndarray, bounds=(0.1, 10.0), reg_lambda: float = 0.0) -> float: + """ + Finds the optimal temperature `T` by minimizing NLL on validation data, + optionally with L2 regularization. + + Args: + p_raw (np.ndarray): Raw model probabilities (validation set). + y_true (np.ndarray): True binary labels (validation set). + bounds (tuple): Bounds for the temperature search. + reg_lambda (float): L2 regularization strength (penalty on (T-1)^2). Defaults to 0.0 (no reg). + + Returns: + float: The optimal temperature found. + """ + logger.info(f"Optimizing calibration temperature using {len(p_raw)} samples (L2 lambda={reg_lambda})...") + # Ensure inputs are flat numpy arrays and correct type + p_raw = np.asarray(p_raw).flatten().astype(float) + y_true = np.asarray(y_true).flatten().astype(int) + + # Clip raw probabilities and compute logits for numerical stability + eps = 1e-7 + p_clipped = np.clip(p_raw, eps, 1 - eps) + logit_p = logit(p_clipped) + + # Handle cases where all predictions are the same (logit might be inf) + if np.isinf(logit_p).any(): + logger.warning("Infinite values encountered in logits during temperature scaling. Clipping may be too aggressive or predictions are uniform. Returning T=1.0") + self.optimal_T = 1.0 + return 1.0 + + try: + # Select objective function based on regularization + objective_func = self._nll_objective if reg_lambda == 0.0 else \ + lambda T: self._nll_objective_regularized(T, logit_p, y_true, reg_lambda) + + res = minimize_scalar( + objective_func, + bounds=bounds, + method="bounded", + ) + + if res.success: + optimal_T_found = float(res.x) + logger.info(f"Optimal temperature found: T = {optimal_T_found:.4f}") + self.optimal_T = optimal_T_found + return optimal_T_found + else: + logger.warning(f"Temperature optimization failed: {res.message}. Returning T=1.0") + self.optimal_T = 1.0 + return 1.0 + except Exception as e: + logger.error(f"Error during temperature optimization: {e}", exc_info=True) + logger.warning("Returning T=1.0 due to optimization error.") + self.optimal_T = 1.0 + return 1.0 + + def calibrate(self, p_raw: np.ndarray, T: Optional[float] = None) -> np.ndarray: + """ + Applies temperature scaling to raw probabilities. + + Args: + p_raw (np.ndarray): Raw model probabilities. + T (Optional[float]): Temperature value. If None, uses the stored optimal_T. + Defaults to 1.0 if optimal_T is also None. + + Returns: + np.ndarray: Calibrated probabilities. + """ + temp = T if T is not None else self.optimal_T + if temp is None: + logger.warning("Temperature T not provided and not optimized yet. Using T=1.0 for calibration.") + temp = 1.0 + if temp <= 0: + logger.error(f"Invalid temperature T={temp}. Using T=1.0 instead.") + temp = 1.0 + + # Clip raw probabilities and compute logits + eps = 1e-7 + p_clipped = np.clip(np.asarray(p_raw).astype(float), eps, 1 - eps) + logit_p = logit(p_clipped) + + # Apply temperature scaling + p_cal = expit(logit_p / temp) + return p_cal + + def reliability_curve( + self, + p_pred: np.ndarray, # Expects raw OR calibrated probabilities + y_true: np.ndarray, + n_bins: int = 10, + plot_title: str = "Reliability Curve", + save_path: Optional[str] = None + ) -> Tuple[np.ndarray, np.ndarray]: + """ + Computes and optionally plots the reliability curve for binary classification. + + Args: + p_pred (np.ndarray): Predicted probabilities for the positive class. + y_true (np.ndarray): True binary labels (0 or 1). + n_bins (int): Number of bins for the curve. + plot_title (str): Title for the plot. + save_path (Optional[str]): If provided, saves the plot to this path. + + Returns: + Tuple[np.ndarray, np.ndarray]: (bin_centers, empirical_prob) + """ + p_pred = np.asarray(p_pred).flatten() + y_true = np.asarray(y_true).flatten() + + if not np.all((y_true == 0) | (y_true == 1)): + # Handle potential soft labels by converting to hard labels for accuracy calculation + logger.debug("Non-binary values detected in y_true for reliability curve. Converting > 0.5 to 1, else 0.") + y_true = (y_true > 0.5).astype(int) + + bins = np.linspace(0, 1, n_bins + 1) + bin_centers = 0.5 * (bins[:-1] + bins[1:]) + # Handle potential edge cases with digitize for values exactly 1.0 + # Ensure indices are within [0, n_bins-1] + bin_ids = np.digitize(p_pred, bins[1:], right=False) + bin_ids = np.clip(bin_ids, 0, n_bins - 1) # Clip to handle edge case p_pred = 0 + + empirical_prob = np.zeros(n_bins) * np.nan # Default to NaN + avg_confidence = np.zeros(n_bins) * np.nan # Default to NaN + bin_counts = np.zeros(n_bins, dtype=int) + + for i in range(n_bins): + idx = bin_ids == i + bin_counts[i] = np.sum(idx) + if bin_counts[i] > 0: + empirical_prob[i] = y_true[idx].mean() + avg_confidence[i] = p_pred[idx].mean() + + # Filter out bins with no samples for plotting + valid_mask = bin_counts > 0 + plot_centers = bin_centers[valid_mask] + plot_probs = empirical_prob[valid_mask] + + # Calculate ECE (Expected Calibration Error) + ece = np.sum(np.abs(empirical_prob[valid_mask] - avg_confidence[valid_mask]) * (bin_counts[valid_mask] / len(p_pred))) + plot_title += f" (ECE = {ece:.3f})" + + if save_path: + try: + fig, ax = plt.subplots(1, 1, figsize=(6, 6)) + ax.plot([0, 1], [0, 1], "k--", label="Perfect Calibration") + # Only plot bins with counts + ax.plot(plot_centers, plot_probs, "o-", label="Model Calibration") + + # Add bar chart for confidence distribution underneath + ax2 = ax.twinx() + ax2.bar(bin_centers, bin_counts, width=(bins[1]-bins[0])*0.9, alpha=0.2, color='grey', label='Bin Counts') + ax2.set_ylabel("Count per Bin", color='grey') + ax2.tick_params(axis='y', labelcolor='grey') + ax2.set_ylim(bottom=0) + fig.legend(loc="upper left", bbox_to_anchor=(0.1, 0.9)) + + ax.set_xlabel("Mean Predicted Probability (per bin)") + ax.set_ylabel("Fraction of Positives (per bin)") + ax.set_title(plot_title) + ax.grid(True, alpha=0.5) + ax.set_xlim([0, 1]) + ax.set_ylim([0, 1]) + + plt.tight_layout() + plt.savefig(save_path) + plt.close(fig) + logger.info(f"Reliability curve saved to {save_path}") + except Exception as e: + logger.error(f"Failed to generate or save reliability plot: {e}", exc_info=True) + + return bin_centers, empirical_prob + + def action_signal(self, p_cal: np.ndarray) -> np.ndarray: + """ + Generates trading signal (1, -1, or 0) based on calibrated probability + and the instance's edge threshold. + + Args: + p_cal (np.ndarray): Calibrated probabilities P(up). + + Returns: + np.ndarray: Action signals (1 for long, -1 for short, 0 for neutral). + """ + p_cal = np.asarray(p_cal) + # Signal long if P(up) > threshold + go_long = p_cal > self.edge_threshold + # Signal short if P(down) > threshold, which is P(up) < 1 - threshold + go_short = p_cal < (1.0 - self.edge_threshold) + + # Assign signals: 1 for long, -1 for short, 0 otherwise + signal = np.where(go_long, 1, np.where(go_short, -1, 0)) + return signal + + def optimize_edge_threshold(self, p_cal: np.ndarray, y_true: np.ndarray, n_thresholds: int = 101) -> float: + """ + Finds the optimal binary edge threshold by maximizing Youden's J statistic. + + Args: + p_cal (np.ndarray): Calibrated probabilities for the positive class (validation set). + y_true (np.ndarray): True binary labels (validation set). + n_thresholds (int): Number of thresholds to test between 0 and 1. + + Returns: + float: The threshold that maximizes Youden's J (TPR - FPR). + Defaults to 0.5 if calculation fails or yields no improvement. + """ + logger.info(f"Optimizing edge threshold using {len(p_cal)} validation samples...") + p_cal = np.asarray(p_cal).flatten() + y_true = np.asarray(y_true).flatten() + + # Handle potential soft labels for threshold optimization + if not np.all((y_true == 0) | (y_true == 1)): + logger.debug("Non-binary values detected in y_true for threshold optimization. Converting > 0.5 to 1, else 0.") + y_true = (y_true > 0.5).astype(int) + + thresholds = np.linspace(0, 1, n_thresholds) + best_j = -1 + best_threshold = 0.5 # Default + + for thr in thresholds: + # Predictions based on threshold + y_pred = (p_cal >= thr).astype(int) + + try: + # Calculate confusion matrix: [[TN, FP], [FN, TP]] + tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel() + + # Calculate TPR (Sensitivity) and FPR (1 - Specificity) + tpr = tp / (tp + fn) if (tp + fn) > 0 else 0 + fpr = fp / (fp + tn) if (fp + tn) > 0 else 0 + + # Youden's J statistic + j_stat = tpr - fpr + + if j_stat > best_j: + best_j = j_stat + best_threshold = thr + + except ValueError as e: + # Handle cases where confusion_matrix might fail (e.g., only one class predicted/present) + logger.debug(f"Could not calculate confusion matrix for threshold {thr:.3f}: {e}. Skipping.") + continue + except Exception as e: + logger.error(f"Unexpected error calculating Youden's J for threshold {thr:.3f}: {e}. Skipping.", exc_info=True) + continue + + if best_j <= 0: + logger.warning(f"Could not find threshold yielding positive Youden's J (Max J = {best_j:.3f}). Defaulting to 0.5.") + best_threshold = 0.5 + else: + logger.info(f"Optimal edge threshold found: {best_threshold:.3f} (Max Youden's J = {best_j:.3f})") + + # Update the calibrator's internal threshold if desired? + # self.edge_threshold = best_threshold # Or leave it to the pipeline to manage + + return best_threshold + + def fit_rolling(self, p_raw: np.ndarray, y_true: np.ndarray, every_n: int, window_m: int, indices: Optional[pd.Index] = None) -> List[Tuple[Any, float]]: + """ + Fits temperature scaling parameters on rolling windows. + + Args: + p_raw (np.ndarray): Raw model probabilities. + y_true (np.ndarray): True binary labels. + every_n (int): How often to refit the temperature (e.g., 5000 steps). + window_m (int): The size of the rolling window to use for refitting (e.g., 20000 steps). + indices (Optional[pd.Index]): Optional original indices corresponding to p_raw/y_true. + If provided, the output schedule uses these indices. + + Returns: + List[Tuple[Any, float]]: A schedule of (index, optimal_T) tuples. + The index is the step/timestamp where the new T applies. + Returns empty list if inputs are too short or invalid. + """ + logger.info(f"Performing rolling temperature calibration: every {every_n} steps, window size {window_m}") + n_samples = len(p_raw) + if n_samples < window_m or n_samples < every_n or window_m <= 0 or every_n <= 0: + logger.warning(f"Insufficient samples ({n_samples}) or invalid params (every={every_n}, window={window_m}) for rolling calibration. Skipping.") + return [] + + schedule = [] + last_fit_idx = -1 # Track the index of the last fit to avoid redundant calculations + + for i in range(window_m -1, n_samples, every_n): + # Define the window for this fit + start_idx = max(0, i - window_m + 1) + end_idx = i + 1 # Slice is exclusive at the end + + # Check if this window significantly overlaps with the last one processed + # This is a simple check; more sophisticated checks could be added + # if start_idx <= last_fit_idx: + # continue + + window_p = p_raw[start_idx:end_idx] + window_y = y_true[start_idx:end_idx] + + # Use the standard optimise_temperature method on the window + logger.debug(f"Rolling fit at step {i} (Window: [{start_idx}, {end_idx}))") + optimal_t_window = self.optimise_temperature(window_p, window_y) + + # Determine the index/timestamp for this schedule entry + schedule_idx = indices[i] if indices is not None and i < len(indices) else i + + schedule.append((schedule_idx, optimal_t_window)) + last_fit_idx = i + + # Store the latest T as the primary optimal_T for the instance? + self.optimal_T = optimal_t_window + + logger.info(f"Generated rolling calibration schedule with {len(schedule)} entries.") + if not schedule: + logger.warning("Rolling calibration resulted in an empty schedule.") + # Consider fitting once on the whole data as fallback? + # optimal_t_full = self.optimise_temperature(p_raw, y_true) + # schedule_idx_full = indices[-1] if indices is not None else n_samples - 1 + # schedule.append((schedule_idx_full, optimal_t_full)) + # self.optimal_T = optimal_t_full + + return schedule \ No newline at end of file diff --git a/gru_sac_predictor/src/calibrator_vector.py b/gru_sac_predictor/src/calibrator_vector.py new file mode 100644 index 00000000..824b183e --- /dev/null +++ b/gru_sac_predictor/src/calibrator_vector.py @@ -0,0 +1,454 @@ +""" +Vector Scaling Calibration for Multi-Class Classifiers. + +Ref: revisions.txt Section 4 +Based on: https://arxiv.org/abs/1706.04599 (On Calibration of Modern Neural Networks) +""" + +import numpy as np +import tensorflow as tf +from scipy.optimize import minimize +import logging +from typing import Optional, Tuple, List, Any +import pandas as pd + +logger = logging.getLogger(__name__) + +class VectorCalibrator: + """ + Implements Vector Scaling calibration. + + Finds a diagonal matrix W and a vector b such that the calibrated + probabilities p_cal = softmax(W * z + b) minimize the NLL, where z + are the pre-softmax logits. + For K classes, this involves optimizing 2*K parameters. + """ + + def __init__(self): + """Initialize the calibrator.""" + self.W = None # Diagonal weight matrix (represented as a vector) + self.b = None # Bias vector + self.optimal_params = None # Store concatenated [W_diag, b] + + def _softmax(self, x: np.ndarray) -> np.ndarray: + """Numerically stable softmax.""" + e_x = np.exp(x - np.max(x, axis=-1, keepdims=True)) + return e_x / np.sum(e_x, axis=-1, keepdims=True) + + def _nll_loss(self, params: np.ndarray, logits: np.ndarray, y_onehot: np.ndarray) -> float: + """ + Negative Log-Likelihood loss function (without regularization). + + Args: + params (np.ndarray): Concatenated vector [W_diag, b]. + logits (np.ndarray): Raw output logits from the model (shape: N x K). + y_onehot (np.ndarray): One-hot encoded true labels (shape: N x K). + + Returns: + float: The calculated NLL loss. + """ + num_classes = logits.shape[1] + if len(params) != 2 * num_classes: + raise ValueError(f"Expected {2*num_classes} params, got {len(params)}") + + W_diag = params[:num_classes] + b = params[num_classes:] + + # Apply scaling: W is diagonal, so element-wise multiplication works + scaled_logits = logits * W_diag + b # Broadcasting W_diag and b + + # Calculate probabilities using softmax + calibrated_probs = self._softmax(scaled_logits) + + # Avoid log(0) - clip probabilities + eps = 1e-12 + calibrated_probs = np.clip(calibrated_probs, eps, 1.0 - eps) + + # Calculate NLL + nll = -np.sum(y_onehot * np.log(calibrated_probs), axis=1) + return np.mean(nll) + + def _nll_loss_regularized(self, params: np.ndarray, logits: np.ndarray, y_onehot: np.ndarray, reg_lambda: float) -> float: + """NLL loss with L2 regularization on W and b (deviation from W=1, b=0).""" + base_nll = self._nll_loss(params, logits, y_onehot) + + num_classes = logits.shape[1] + W_diag = params[:num_classes] + b = params[num_classes:] + + # L2 penalty: encourage W towards 1 and b towards 0 + l2_penalty_W = np.sum((W_diag - 1.0) ** 2) + l2_penalty_b = np.sum(b ** 2) + l2_penalty = reg_lambda * (l2_penalty_W + l2_penalty_b) + + return base_nll + l2_penalty + + def fit(self, logits: np.ndarray, y_onehot: np.ndarray, reg_lambda: float = 0.0) -> None: + """ + Finds the optimal scaling parameters W (diagonal) and b, optionally with L2 regularization. + + Args: + logits (np.ndarray): Raw output logits from the model (shape: N x K). + y_onehot (np.ndarray): One-hot encoded true labels (shape: N x K). + reg_lambda (float): L2 regularization strength. Defaults to 0.0. + """ + if logits.shape[0] != y_onehot.shape[0]: + raise ValueError("Logits and labels must have the same number of samples.") + if len(logits.shape) != 2 or len(y_onehot.shape) != 2: + raise ValueError("Logits and labels must be 2D arrays (N x K).") + if logits.shape[1] != y_onehot.shape[1]: + raise ValueError("Logits and labels must have the same number of classes.") + + num_classes = logits.shape[1] + logger.info(f"Fitting Vector Scaling for {num_classes} classes (L2 lambda={reg_lambda})...") + + # Initial guess: W = identity (diag=1), b = zero vector + initial_params = np.concatenate([np.ones(num_classes), np.zeros(num_classes)]) + + # Define bounds (optional, but can help stability) + # Allow W > 0, b can be anything + bounds = [(1e-6, None)] * num_classes + [(None, None)] * num_classes + + # Select objective function based on regularization + if reg_lambda == 0.0: + objective_func = self._nll_loss + extra_args = (logits, y_onehot) + else: + objective_func = self._nll_loss_regularized + extra_args = (logits, y_onehot, reg_lambda) + + # Minimize the NLL loss + # Using L-BFGS-B as it handles bounds well + result = minimize( + objective_func, + initial_params, + args=extra_args, # Pass logits, y_onehot (and lambda if needed) + method='L-BFGS-B', + bounds=bounds, # Use bounds + options={'maxiter': 1000, 'ftol': 1e-8} # Example options + ) + + if result.success: + self.optimal_params = result.x + self.W = self.optimal_params[:num_classes] + self.b = self.optimal_params[num_classes:] + logger.info(f"Vector Scaling fit successful. Optimal NLL: {result.fun:.4f}") + logger.info(f" Optimal W (diag): {np.round(self.W, 3)}") + logger.info(f" Optimal b: {np.round(self.b, 3)}") + else: + logger.error(f"Vector Scaling optimization failed: {result.message}") + # Handle failure: maybe use initial params or raise error? + self.optimal_params = initial_params # Fallback to initial + self.W = self.optimal_params[:num_classes] + self.b = self.optimal_params[num_classes:] + logger.warning("Using initial parameters (W=I, b=0) due to optimization failure.") + + def calibrate(self, logits: np.ndarray) -> np.ndarray: + """ + Applies the learned scaling parameters to new logits. + + Args: + logits (np.ndarray): Raw logits from the model (shape: N x K). + + Returns: + np.ndarray: Calibrated probabilities (shape: N x K). + Returns uncalibrated softmax if fit() wasn't called or failed. + """ + if self.W is None or self.b is None: + logger.warning("Vector Scaling parameters not fitted. Returning uncalibrated softmax.") + return self._softmax(logits) + + if logits.shape[1] != len(self.W): + raise ValueError(f"Input logits have {logits.shape[1]} classes, but calibrator was fitted for {len(self.W)} classes.") + + scaled_logits = logits * self.W + self.b + calibrated_probs = self._softmax(scaled_logits) + return calibrated_probs + + def save_params(self, filepath: str) -> None: + """Saves the optimal parameters (W_diag and b) to a .npy file.""" + if self.optimal_params is None: + logger.error("No parameters to save. Call fit() first.") + return + try: + np.save(filepath, self.optimal_params) + logger.info(f"Vector Scaling parameters saved to {filepath}") + except Exception as e: + logger.error(f"Failed to save parameters to {filepath}: {e}") + + def load_params(self, filepath: str) -> bool: + """Loads the optimal parameters from a .npy file.""" + try: + params = np.load(filepath) + num_params = len(params) + if num_params % 2 != 0: + raise ValueError(f"Loaded params have odd length ({num_params}), expected 2*K.") + num_classes = num_params // 2 + self.optimal_params = params + self.W = self.optimal_params[:num_classes] + self.b = self.optimal_params[num_classes:] + logger.info(f"Vector Scaling parameters loaded successfully from {filepath} ({num_classes} classes).") + return True + except FileNotFoundError: + logger.error(f"Parameter file not found: {filepath}") + return False + except Exception as e: + logger.error(f"Failed to load parameters from {filepath}: {e}") + # Reset parameters on load failure? + self.W = None + self.b = None + self.optimal_params = None + return False + + def optimize_edge_threshold_ternary(self, p_cal: np.ndarray, y_onehot: np.ndarray, + positive_class_indices: List[int] = [2], # Default: Class 2 (Up) is positive + n_thresholds: int = 101) -> float: + """ + Finds an optimal edge threshold for ternary classification using Youden's J. + This simplifies the ternary problem to binary: Positive Class(es) vs Others. + + Args: + p_cal (np.ndarray): Calibrated probabilities (N x K, e.g., N x 3). + y_onehot (np.ndarray): True one-hot labels (N x K). + positive_class_indices (List[int]): List of indices considered 'positive' (e.g., [2] for 'Up'). + n_thresholds (int): Number of thresholds to test for the summed positive class probability. + + Returns: + float: The threshold on the summed positive class probability that maximizes Youden's J. + Defaults to 0.5 / len(positive_class_indices) if calculation fails. + """ + logger.info(f"Optimizing ternary edge threshold (Positive classes: {positive_class_indices})...") + if p_cal.shape[1] != y_onehot.shape[1]: + raise ValueError("Calibrated probs and one-hot labels must have same number of classes.") + if not positive_class_indices or any(i >= p_cal.shape[1] for i in positive_class_indices): + raise ValueError(f"Invalid positive_class_indices: {positive_class_indices}") + + # Combine probabilities for positive classes + p_positive_sum = np.sum(p_cal[:, positive_class_indices], axis=1) + + # Create binary true labels: 1 if true class is in positive_class_indices, 0 otherwise + y_true_binary = np.sum(y_onehot[:, positive_class_indices], axis=1).astype(int) + + # Reuse the binary threshold optimization logic + # Need an instance of the base Calibrator or copy the method here + # For simplicity, let's assume a temporary base calibrator or call a static helper + # This assumes a base Calibrator class exists with optimize_edge_threshold + try: + # Use a generic function or temporary instance approach if needed + # Here, assuming we can call optimize_edge_threshold directly (or copy its logic) + # Copied logic approach for self-containment: + thresholds = np.linspace(0, 1, n_thresholds) + best_j = -1 + num_pos_classes = len(positive_class_indices) + default_threshold = 0.5 / num_pos_classes if num_pos_classes > 0 else 0.5 # Adjust default + best_threshold = default_threshold + + for thr in thresholds: + y_pred_binary = (p_positive_sum >= thr).astype(int) + try: + tn, fp, fn, tp = confusion_matrix(y_true_binary, y_pred_binary, labels=[0, 1]).ravel() + tpr = tp / (tp + fn) if (tp + fn) > 0 else 0 + fpr = fp / (fp + tn) if (fp + tn) > 0 else 0 + j_stat = tpr - fpr + if j_stat > best_j: + best_j = j_stat + best_threshold = thr + except ValueError: + continue # Skip thresholds where confusion matrix fails + except Exception as e: + logger.error(f"Error calculating Youden's J for threshold {thr:.3f}: {e}") + continue + + if best_j <= 0: + logger.warning(f"Could not find ternary threshold yielding positive Youden's J (Max J = {best_j:.3f}). Defaulting to {default_threshold:.3f}.") + best_threshold = default_threshold + else: + logger.info(f"Optimal ternary edge threshold found: {best_threshold:.3f} (Max Youden's J = {best_j:.3f})") + + return best_threshold + + except Exception as e: + logger.error(f"Error during ternary edge threshold optimization: {e}", exc_info=True) + return default_threshold # Return default on error + + def fit_rolling(self, logits: np.ndarray, y_onehot: np.ndarray, every_n: int, window_m: int, indices: Optional[pd.Index] = None) -> List[Tuple[Any, np.ndarray]]: + """ + Fits Vector Scaling parameters on rolling windows. + + Args: + logits (np.ndarray): Raw model logits. + y_onehot (np.ndarray): True one-hot labels. + every_n (int): How often to refit the parameters (e.g., 5000 steps). + window_m (int): The size of the rolling window for refitting (e.g., 20000 steps). + indices (Optional[pd.Index]): Optional original indices corresponding to data. + + Returns: + List[Tuple[Any, np.ndarray]]: A schedule of (index, optimal_params) tuples. + optimal_params = [W_diag, b]. + """ + logger.info(f"Performing rolling vector scaling calibration: every {every_n} steps, window size {window_m}") + n_samples = len(logits) + if n_samples < window_m or n_samples < every_n or window_m <= 0 or every_n <= 0: + logger.warning(f"Insufficient samples ({n_samples}) or invalid params (every={every_n}, window={window_m}) for rolling calibration. Skipping.") + return [] + + schedule = [] + last_fit_idx = -1 + num_classes = logits.shape[1] + initial_params = np.concatenate([np.ones(num_classes), np.zeros(num_classes)]) # Default params + + for i in range(window_m - 1, n_samples, every_n): + start_idx = max(0, i - window_m + 1) + end_idx = i + 1 + + window_logits = logits[start_idx:end_idx] + window_y = y_onehot[start_idx:end_idx] + + logger.debug(f"Rolling fit at step {i} (Window: [{start_idx}, {end_idx}))") + + # Fit parameters for this window + # Use a temporary instance or call a static fit method? + # Re-using self._nll_loss and minimize directly seems cleaner here + bounds = [(1e-6, None)] * num_classes + [(None, None)] * num_classes + result = minimize( + self._nll_loss, + initial_params, # Start from scratch or use previous params? + args=(window_logits, window_y), + method='L-BFGS-B', + bounds=bounds, + options={'maxiter': 500, 'ftol': 1e-7} # Faster options for rolling? + ) + + optimal_params_window = initial_params # Default if fit fails + if result.success: + optimal_params_window = result.x + else: + logger.warning(f"Rolling fit failed at step {i}: {result.message}. Using default params for this step.") + + schedule_idx = indices[i] if indices is not None and i < len(indices) else i + schedule.append((schedule_idx, optimal_params_window)) + last_fit_idx = i + + # Update the main instance parameters to the latest successful fit? + if result.success: + self.optimal_params = optimal_params_window + self.W = self.optimal_params[:num_classes] + self.b = self.optimal_params[num_classes:] + + logger.info(f"Generated rolling vector calibration schedule with {len(schedule)} entries.") + if not schedule: + logger.warning("Rolling vector calibration resulted in an empty schedule.") + # Optionally fit once on the full data as fallback + + return schedule + + # --- Add Reliability Curve Plotting --- # + def reliability_curve( + self, + probs: np.ndarray, # Calibrated probabilities (N, K) + y_true: np.ndarray, # True labels (N,) or one-hot (N, K) + n_bins: int = 10, + plot_title: str = "Multi-Class Reliability Curve", + save_path: Optional[str] = None + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Computes and optionally plots the reliability curve for multi-class classification. + Uses the maximum predicted probability as confidence. + + Args: + probs (np.ndarray): Calibrated probabilities (shape: N x K). + y_true (np.ndarray): True labels (class index, shape: N) or one-hot (N, K). + n_bins (int): Number of bins for the curve. + plot_title (str): Title for the plot. + save_path (Optional[str]): If provided, saves the plot to this path. + + Returns: + Tuple[np.ndarray, np.ndarray, np.ndarray]: (bin_centers, accuracy_in_bin, avg_confidence_in_bin) + """ + if len(probs.shape) != 2: + raise ValueError("Input probabilities must be 2D (N x K).") + + n_samples, num_classes = probs.shape + + # Ensure y_true is class index (N,) + if len(y_true.shape) == 2 and y_true.shape[1] == num_classes: + y_true_idx = np.argmax(y_true, axis=1) + elif len(y_true.shape) == 1: + y_true_idx = np.asarray(y_true).astype(int) + else: + raise ValueError("y_true must be 1D class indices or 2D one-hot encoded.") + + if len(y_true_idx) != n_samples: + raise ValueError("Number of samples mismatch between probs and y_true.") + + # Get confidence (max probability) and predicted class + confidences = np.max(probs, axis=1) + predictions = np.argmax(probs, axis=1) + correctness = (predictions == y_true_idx).astype(float) + + bins = np.linspace(0, 1, n_bins + 1) + bin_centers = 0.5 * (bins[:-1] + bins[1:]) + bin_ids = np.digitize(confidences, bins[1:], right=False) + bin_ids = np.clip(bin_ids, 0, n_bins - 1) + + accuracy_in_bin = np.zeros(n_bins) * np.nan + avg_confidence_in_bin = np.zeros(n_bins) * np.nan + bin_counts = np.zeros(n_bins, dtype=int) + + for i in range(n_bins): + idx = bin_ids == i + bin_counts[i] = np.sum(idx) + if bin_counts[i] > 0: + accuracy_in_bin[i] = np.mean(correctness[idx]) + avg_confidence_in_bin[i] = np.mean(confidences[idx]) + + # Filter out bins with no samples for plotting + valid_mask = bin_counts > 0 + plot_centers = bin_centers[valid_mask] + plot_accuracy = accuracy_in_bin[valid_mask] + plot_confidence = avg_confidence_in_bin[valid_mask] + + # Calculate ECE + ece = np.sum(np.abs(plot_accuracy - plot_confidence) * (bin_counts[valid_mask] / n_samples)) + plot_title += f" (ECE = {ece:.3f})" + + if save_path: + # Reuse plotting logic similar to binary Calibrator + try: + import matplotlib.pyplot as plt # Local import for plotting + fig, ax = plt.subplots(1, 1, figsize=(6, 6)) + ax.plot([0, 1], [0, 1], "k--", label="Perfect Calibration") + ax.plot(plot_confidence, plot_accuracy, "o-", label="Model Calibration") # Plot Acc vs Conf + + ax2 = ax.twinx() + ax2.bar(bin_centers, bin_counts, width=(bins[1]-bins[0])*0.9, alpha=0.2, color='grey', label='Bin Counts') + ax2.set_ylabel("Count per Bin", color='grey') + ax2.tick_params(axis='y', labelcolor='grey') + ax2.set_ylim(bottom=0) + # Adjust legend position slightly for multi-class + fig.legend(loc="upper left", bbox_to_anchor=(0.1, 0.9)) + + ax.set_xlabel("Average Confidence (Max Probability per bin)") + ax.set_ylabel("Accuracy (per bin)") + ax.set_title(plot_title) + ax.grid(True, alpha=0.5) + ax.set_xlim([0, 1]) + ax.set_ylim([0, 1]) + + plt.tight_layout() + plt.savefig(save_path) + plt.close(fig) + logger.info(f"Multi-class reliability curve saved to {save_path}") + except ImportError: + logger.error("Matplotlib not found. Cannot generate reliability plot.") + except Exception as e: + logger.error(f"Failed to generate or save multi-class reliability plot: {e}", exc_info=True) + + return bin_centers, accuracy_in_bin, avg_confidence_in_bin + # --- End Reliability Curve Plotting --- # + +# --- Set Availability Flag --- # +# Define this at the module level after the class definition +VECTOR_CALIBRATOR_AVAILABLE = True +logger.info("VectorCalibrator class defined and available.") +# --- End Set Flag --- # \ No newline at end of file diff --git a/gru_sac_predictor/src/data_loader.py b/gru_sac_predictor/src/data_loader.py new file mode 100644 index 00000000..ffa2bd38 --- /dev/null +++ b/gru_sac_predictor/src/data_loader.py @@ -0,0 +1,648 @@ +""" +Data Loader for Cryptocurrency Market Data from SQLite Databases. +""" + +import os +import logging +import pandas as pd +import numpy as np +import sqlite3 +import glob +import re +import sys +from datetime import datetime, timedelta +from typing import List, Optional + +# Import IOManager (adjust path if necessary) +from .io_manager import IOManager +from omegaconf import DictConfig # Assuming OmegaConf for config type hint + +logger = logging.getLogger(__name__) + +# Define helper function (outside the class is fine) +def sample_by_volatility(df: pd.DataFrame, vol_window: int = 30, vol_quantile: float = 0.5) -> pd.Series: + """ + Creates a boolean mask to sample data points based on rolling volatility. + + Keeps data points where the rolling volatility is above a specified quantile. + + Args: + df (pd.DataFrame): DataFrame with a 'close' column and DatetimeIndex. + vol_window (int): Rolling window size for volatility calculation. + vol_quantile (float): Quantile threshold (0.0 to 1.0). Data with volatility + above this quantile will be kept. + + Returns: + pd.Series: Boolean mask, True for rows to keep. + """ + if 'close' not in df.columns: + raise ValueError("DataFrame must contain a 'close' column.") + if not isinstance(df.index, pd.DatetimeIndex): + raise ValueError("DataFrame must have a DatetimeIndex.") + if vol_window <= 1: + raise ValueError("vol_window must be greater than 1.") + + # Calculate rolling volatility (std dev of returns) + returns = df['close'].pct_change() + # Use min_periods=vol_window//2+1 to get some values earlier, but still require significant data + rolling_vol = returns.rolling(window=vol_window, min_periods=max(2, vol_window // 2 + 1)).std() + + # Calculate the threshold volatility value + threshold_vol = rolling_vol.quantile(vol_quantile) + + if pd.isna(threshold_vol) or threshold_vol == 0: + logger.warning(f"Volatility quantile ({vol_quantile}) is NaN or zero ({threshold_vol}). " + f"Threshold calculated over {rolling_vol.count()} non-NaN values. " + f"Disabling volatility sampling for this segment.") + # Return a mask that keeps all data if threshold is problematic + return pd.Series(True, index=df.index) + + # Create the mask + mask = rolling_vol > threshold_vol + + # Handle initial NaNs from rolling calculation - discard these rows. + mask.fillna(False, inplace=True) # Discard rows where rolling vol couldn't be calculated + + logger.info(f"Volatility sampling: window={vol_window}, quantile={vol_quantile}, " + f"threshold={threshold_vol:.6f}. Keeping {mask.sum()} / {len(df)} rows.") + + return mask + +class DataLoader: + """ + Loads historical cryptocurrency market data from SQLite databases. + Combines functionality from the previous CryptoDBFetcher and data loading logic. + """ + def __init__(self, db_dir: str, cache_dir: str = "data/cache", use_cache: bool = False): + """ + Initialize the DataLoader. + + Args: + db_dir (str): Directory where SQLite database files are stored. Should be an absolute path or resolvable relative to the execution context. + cache_dir (str): Directory to store cached data (currently not implemented). + use_cache (bool): Whether to use cached data (currently not implemented). + """ + # The path resolution should happen *before* calling the DataLoader. + # We expect db_dir to be the correct path already. + self.db_dir = os.path.abspath(db_dir) # Ensure absolute path for consistency + + self.cache_dir = cache_dir # Placeholder for future cache implementation + self.use_cache = use_cache # Placeholder + + self._db_files = None # Cache discovered DB files + + logger.info(f"Initialized DataLoader with db_dir='{self.db_dir}'") + if not os.path.exists(self.db_dir): + # Log a warning, but allow continuation in case the directory is created later + # or if only specific file checks are relevant later. + logger.warning(f"Database directory may not exist or is inaccessible: {self.db_dir}") + + def _get_db_files(self) -> List[str]: + """Get available database files, sorted by date desc (cached). Uses recursive glob.""" + if self._db_files is not None: + return self._db_files + + logger.info(f"Scanning for DB files recursively in: {self.db_dir}") + # Check existence *here* right before scanning + if not os.path.isdir(self.db_dir): # More specific check for directory + logger.error(f"Database directory does not exist or is not a directory: {self.db_dir}") + self._db_files = [] + return [] + + patterns = ["*.mktdata.ohlcv.db", "*.db", "*.sqlite", "*.sqlite3"] + db_files = [] + for pattern in patterns: + # Recursive search + recursive_pattern = os.path.join(self.db_dir, '**', pattern) + try: + files = glob.glob(recursive_pattern, recursive=True) + if files: + logger.debug(f"Found {len(files)} files recursively with pattern '{pattern}'") + db_files.extend(files) + except Exception as e: + logger.error(f"Error during glob pattern '{recursive_pattern}': {e}") + + if not db_files: + logger.warning(f"No database files found in '{self.db_dir}' matching patterns: {patterns}") + self._db_files = [] + return [] + + db_files = sorted(list(set(db_files))) # Remove duplicates and sort alphabetically for consistency before date sort + + # Sort by date (newest first) if possible + date_pattern = re.compile(r'(\d{8})') + file_dates = [] + for file in db_files: + basename = os.path.basename(file) + match = date_pattern.search(basename) + date_obj = None + if match: + try: + date_obj = pd.to_datetime(match.group(1), format='%Y%m%d') + except ValueError: + pass + # Fallback: try modification time + if date_obj is None: + try: + date_obj = pd.to_datetime(os.path.getmtime(file), unit='s') + except Exception: + date_obj = pd.Timestamp.min # Default to oldest if error + file_dates.append((date_obj, file)) + + # Sort by date object, newest first + file_dates.sort(key=lambda x: x[0], reverse=True) + self._db_files = [file for _, file in file_dates] + + logger.info(f"Found {len(self._db_files)} DB files. Using newest: {os.path.basename(self._db_files[0]) if self._db_files else 'None'}") + return self._db_files + + def _get_relevant_db_files(self, start_dt: pd.Timestamp, end_dt: pd.Timestamp) -> List[str]: + """Find DB files potentially containing data for the date range.""" + all_files = self._get_db_files() + relevant_files = set() + date_pattern = re.compile(r'(\d{8})') + + start_date_only = start_dt.date() + end_date_only = end_dt.date() + + for file in all_files: + basename = os.path.basename(file) + match = date_pattern.search(basename) + file_date = None + if match: + try: + file_date = pd.to_datetime(match.group(1), format='%Y%m%d').date() + except ValueError: + pass # Ignore files with unparseable dates in name + + # Strategy 1: Check if filename date is within the requested range + if file_date and start_date_only <= file_date <= end_date_only: + relevant_files.add(file) + continue # Found by filename date, no need to check mtime + + # Strategy 2: Check if file modification time falls within the range (less precise) + # Useful for files without dates in the name or if a single file spans multiple dates + try: + mod_time_dt = pd.to_datetime(os.path.getmtime(file), unit='s', utc=True) + # Check if the file's modification date is within the range or shortly after + # We add a buffer (e.g., 1 day) because file might contain data slightly past its mod time + if start_dt <= mod_time_dt <= (end_dt + timedelta(days=1)): + relevant_files.add(file) + except Exception as e: + logger.debug(f"Could not get or parse modification time for {file}: {e}") + + # If no files found based on date/mtime, use the most recent file as a fallback + # This is a safety measure, but might lead to incorrect data if the range is old + if not relevant_files and all_files: + logger.warning(f"No DB files found matching date range {start_date_only} - {end_date_only}. Using most recent file as fallback: {os.path.basename(all_files[0])}") + return [all_files[0]] + elif not relevant_files: + logger.error("No relevant DB files found and no fallback files available.") + return [] + + # Sort the relevant files chronologically (oldest first for processing) + # Sorting by basename which often includes date is a reasonable heuristic + return sorted(list(relevant_files), key=lambda f: os.path.basename(f)) + + def _get_table_name(self, conn: sqlite3.Connection, exchange: str, interval: str = "1min") -> Optional[str]: + """Find the correct table name, trying variations. Prioritizes 1min.""" + cursor = conn.cursor() + try: + cursor.execute("SELECT name FROM sqlite_master WHERE type='table'") + tables = [row[0] for row in cursor.fetchall()] + except sqlite3.Error as e: + logger.error(f"Failed to list tables in database: {e}") + return None + + # Standard format check (lowercase exchange, exact interval) + base_table = f"{exchange.lower()}_ohlcv_{interval}" + if base_table in tables: return base_table + + # Check for 1min source table specifically (common case) + one_min_table = f"{exchange.lower()}_ohlcv_1min" + if one_min_table in tables: return one_min_table + + # Try other variations (case, common interval formats) + variations = [ + f"{exchange.upper()}_ohlcv_{interval}", + f"{exchange.upper()}_ohlcv_1min", + f"{exchange.lower()}_ohlcv_1m", # Common abbreviation + f"{exchange.upper()}_ohlcv_1m", + ] + for var in variations: + if var in tables: + logger.debug(f"Found table using variation: {var}") + return var + + # Fallback: Check if *any* OHLCV table exists for the exchange + for t in tables: + if t.lower().startswith(f"{exchange.lower()}_ohlcv_"): + logger.warning(f"Using first available OHLCV table found for exchange '{exchange}': {t}. Interval might not match '{interval}'.") + return t + + logger.warning(f"No suitable OHLCV table found for exchange '{exchange}' with interval '{interval}' or '1min' in the database.") + return None + + def _query_data_from_db(self, db_file: str, ticker: str, exchange: str, start_timestamp_ns: int, end_timestamp_ns: int) -> pd.DataFrame: + """ + Query market data from a single database file for a specific ticker and time range (nanoseconds). + Always queries the 1-minute interval table. + """ + instrument_id = f"PAIR-{ticker}" if not ticker.startswith("PAIR-") else ticker + query_interval = "1min" # Always query base interval from DB + + try: + logger.debug(f"Querying DB '{os.path.basename(db_file)}' for {instrument_id} between {start_timestamp_ns} and {end_timestamp_ns}") + with sqlite3.connect(f'file:{db_file}?mode=ro', uri=True) as conn: # Read-only mode + table_name = self._get_table_name(conn, exchange, query_interval) + if not table_name: + logger.warning(f"No table found for {exchange}/1min in {os.path.basename(db_file)}") + return pd.DataFrame() + + cursor = conn.cursor() + try: + cursor.execute(f"PRAGMA table_info({table_name})") + columns_info = cursor.fetchall() + column_names = [col[1].lower() for col in columns_info] + except sqlite3.Error as e: + logger.error(f"Failed to get column info for table '{table_name}' in {db_file}: {e}") + return pd.DataFrame() + + # Check for essential columns + select_cols = ["tstamp", "open", "high", "low", "close", "volume"] + if not all(c in column_names for c in select_cols): + logger.warning(f"Table '{table_name}' in {db_file} missing one or more standard columns: {select_cols}. Found: {column_names}") + return pd.DataFrame() + + # Build query + select_str = ", ".join(select_cols) + where_clauses = ["tstamp >= ?", "tstamp <= ?"] + params: list = [start_timestamp_ns, end_timestamp_ns] + + if 'instrument_id' in column_names: + where_clauses.append("instrument_id = ?") + params.append(instrument_id) + # Note: exchange_id filtering is complex due to potential variations; rely on table name for now. + + query = f"SELECT {select_str} FROM {table_name} WHERE {' AND '.join(where_clauses)} ORDER BY tstamp" + + df = pd.read_sql_query(query, conn, params=params) + if df.empty: + logger.debug(f"Query returned no data for {instrument_id} in {os.path.basename(db_file)} for the time range.") + return pd.DataFrame() + + # Convert timestamp and set index + df['date'] = pd.to_datetime(df['tstamp'], unit='ns', utc=True) + df = df.set_index('date').drop(columns=['tstamp']) + # Ensure numeric types + for col in ['open', 'high', 'low', 'close', 'volume']: + df[col] = pd.to_numeric(df[col], errors='coerce') + + # Drop rows with NaNs that might result from coerce + df.dropna(subset=['open', 'high', 'low', 'close'], inplace=True) + + logger.debug(f"Query from {os.path.basename(db_file)} returned {len(df)} rows for {instrument_id}.") + return df + + except sqlite3.Error as e: + logger.error(f"SQLite error querying {db_file} table '{table_name if 'table_name' in locals() else 'N/A'}': {e}") + except Exception as e: + logger.error(f"Unexpected error querying {db_file}: {e}", exc_info=False) + return pd.DataFrame() + + def _resample_data(self, df: pd.DataFrame, interval: str) -> pd.DataFrame: + """Resample 1-minute data to a different interval.""" + if df.empty or not isinstance(df.index, pd.DatetimeIndex): + logger.warning("Input DataFrame for resampling is empty or has non-DatetimeIndex.") + return df + if interval == '1min': # No resampling needed + return df + + logger.info(f"Resampling data to {interval}...") + try: + # Define aggregation rules + agg_dict = {'open': 'first', 'high': 'max', 'low': 'min', 'close': 'last'} + if 'volume' in df.columns: agg_dict['volume'] = 'sum' + + # Check for required columns before resampling + required_cols = ['open', 'high', 'low', 'close'] + missing_cols = [c for c in required_cols if c not in df.columns] + if missing_cols: + logger.error(f"Cannot resample, missing required columns: {missing_cols}") + return pd.DataFrame() # Return empty if essential cols missing + + # Perform resampling + resampled_df = df.resample(interval).agg(agg_dict) + + # Drop rows where essential OHLC data is missing after resampling + resampled_df = resampled_df.dropna(subset=['open', 'high', 'low', 'close']) + + if resampled_df.empty: + logger.warning(f"Resampling to {interval} resulted in an empty DataFrame.") + else: + logger.info(f"Resampling complete. New shape: {resampled_df.shape}") + return resampled_df + + except ValueError as e: + logger.error(f"Invalid interval string for resampling: '{interval}'. Error: {e}") + return pd.DataFrame() + except Exception as e: + logger.error(f"Error during resampling to {interval}: {e}", exc_info=True) + return pd.DataFrame() + + def load_data(self, ticker: str, exchange: str, start_date: str, end_date: str, interval: str, + vol_sampling: bool = False, vol_window: int = 30, vol_quantile: float = 0.5) -> pd.DataFrame: + """ + Loads, combines, and optionally resamples/filters data from relevant DB files. + + Args: + ticker (str): The trading pair symbol (e.g., 'SOL-USDT'). + exchange (str): The exchange name (e.g., 'bnbspot'). + start_date (str): Start date string (YYYY-MM-DD). + end_date (str): End date string (YYYY-MM-DD). + interval (str): The desired final data interval (e.g., '1min', '5min', '1h'). + vol_sampling (bool): If True, apply volatility-based sampling. + vol_window (int): Window size for volatility calculation if vol_sampling is True. + vol_quantile (float): Quantile threshold for volatility sampling. + + Returns: + pd.DataFrame: Combined and processed OHLCV data. + """ + logger.info(f"Loading data for {ticker} ({exchange}) from {start_date} to {end_date}, interval {interval}") + + try: + # Parse dates - add time component to cover full days + start_dt = pd.to_datetime(start_date, utc=True).replace(hour=0, minute=0, second=0, microsecond=0) + end_dt = pd.to_datetime(end_date, utc=True).replace(hour=23, minute=59, second=59, microsecond=999999) + if start_dt >= end_dt: + raise ValueError("Start date must be before end date") + except Exception as e: + logger.error(f"Invalid date format or range: {e}") + return pd.DataFrame() + + # Timestamps for DB query (nanoseconds) + start_timestamp_ns = int(start_dt.timestamp() * 1e9) + end_timestamp_ns = int(end_dt.timestamp() * 1e9) + + # Find relevant database files based on the date range + db_files_to_query = self._get_relevant_db_files(start_dt, end_dt) + if not db_files_to_query: + logger.error("No relevant database files found for the specified date range.") + return pd.DataFrame() + + logger.info(f"Identified {len(db_files_to_query)} potential DB files: {[os.path.basename(f) for f in db_files_to_query]}") + + # Query each relevant file and collect data + all_data = [] + for db_file in db_files_to_query: + df_part = self._query_data_from_db(db_file, ticker, exchange, start_timestamp_ns, end_timestamp_ns) + if not df_part.empty: + all_data.append(df_part) + + if not all_data: + logger.warning(f"No data found in any identified DB files for {ticker} ({exchange}) in the specified range.") + return pd.DataFrame() + + # Combine data from all files + try: + combined_df = pd.concat(all_data) + # Remove duplicate indices (e.g., from overlapping file queries), keeping the first occurrence + combined_df = combined_df[~combined_df.index.duplicated(keep='first')] + # Sort chronologically + combined_df = combined_df.sort_index() + except Exception as e: + logger.error(f"Error concatenating or sorting dataframes: {e}") + return pd.DataFrame() + + logger.info(f"Combined data shape before final filtering/resampling: {combined_df.shape}") + + # Apply precise date range filtering *after* combining and sorting + final_df = combined_df[(combined_df.index >= start_dt) & (combined_df.index <= end_dt)] + + if final_df.empty: + logger.warning(f"Dataframe is empty after final date range filtering ({start_dt} to {end_dt}).") + return pd.DataFrame() + + logger.info(f"Shape after final date filtering: {final_df.shape}") + + # --- Add future_close for potential leakage analysis upstream --- + # Placeholder: Needs config access or passed horizon value + # Assuming horizon = 5 for now as per revisions.txt context. + # A better implementation would pass cfg.gru.prediction_horizon here. + prediction_horizon = 5 # TODO: Replace with value from config + final_df['future_close'] = final_df['close'].shift(-prediction_horizon) + logger.info(f"Added 'future_close' column shifting by {prediction_horizon} periods.") + # --- End future_close addition --- + + # --- VOLATILITY SAMPLING --- + if vol_sampling: + logger.info("Applying volatility-aware sampling...") + try: + vol_mask = sample_by_volatility(final_df, vol_window=vol_window, vol_quantile=vol_quantile) + rows_before = len(final_df) + final_df = final_df[vol_mask] + logger.info(f"Applied volatility sampling. Kept {len(final_df)} of {rows_before} rows.") + if final_df.empty: + logger.warning("DataFrame is empty after volatility sampling.") + # Return empty DF immediately if sampling removed everything + return pd.DataFrame() + except Exception as e: + logger.error(f"Error during volatility sampling: {e}. Skipping sampling.", exc_info=True) + # --- END VOLATILITY SAMPLING --- + + # Resample if the requested interval is different from 1min + if interval != "1min": + final_df = self._resample_data(final_df, interval) + if final_df.empty: + logger.error(f"Resampling to {interval} resulted in an empty DataFrame. Check resampling logic or input data.") + return pd.DataFrame() + + # Final check for NaNs in essential columns + essential_cols = ['open', 'high', 'low', 'close'] + if final_df[essential_cols].isnull().any().any(): + rows_before = len(final_df) + final_df.dropna(subset=essential_cols, inplace=True) + logger.warning(f"Dropped {rows_before - len(final_df)} rows with NaN values in essential OHLC columns after potential resampling.") + + if final_df.empty: + logger.error(f"Final DataFrame is empty after NaN checks for {ticker}.") + return pd.DataFrame() + + logger.info(f"Successfully loaded and processed data for {ticker}. Final shape: {final_df.shape}") + return final_df + +# --- Missing Bar Handling Functions --- + +def _consecutive_gaps(missing_indices: pd.DatetimeIndex) -> List[pd.Timedelta]: + """Calculate the length of consecutive gaps in missing timestamps.""" + if missing_indices.empty: + return [] + + # Calculate differences between consecutive missing timestamps + diffs = missing_indices.to_series().diff() + + # Identify the start of each gap (where diff > expected frequency) + # Assumes missing_indices is sorted, diff() calculates t[i] - t[i-1] + # We infer the frequency from the first gap if available, otherwise fallback needs consideration. + # This simple version assumes a constant frequency is implicitly known. + if len(missing_indices) > 1: + # Infer frequency from the first difference if possible + inferred_freq = diffs.iloc[1] + else: + inferred_freq = pd.Timedelta(minutes=1) # Default assumption if only one missing point + # A better approach might need the expected freq passed in. + + # A gap starts when the difference is larger than the inferred frequency + gap_starts = diffs[diffs > inferred_freq].index + + # Calculate gap lengths + gaps = [] + current_start = missing_indices[0] + count = 0 + for i, ts in enumerate(missing_indices): + if i > 0: + # Check if consecutive based on inferred frequency + if ts - missing_indices[i-1] > inferred_freq: + # End of a gap + gaps.append(count) # Store the count of consecutive missing bars + current_start = ts + count = 1 # Start a new gap + else: + count += 1 + else: + count = 1 # First timestamp starts a gap + + gaps.append(count) # Add the last gap + + return gaps + +def find_missing_bars(df: pd.DataFrame, freq: str) -> pd.DatetimeIndex: + """Detects missing timestamps in a DataFrame based on expected frequency.""" + if not isinstance(df.index, pd.DatetimeIndex): + raise ValueError("DataFrame must have a DatetimeIndex.") + if df.empty: + return pd.DatetimeIndex([]) + + # Ensure the index is sorted + df = df.sort_index() + + # Create the expected full date range + full_range = pd.date_range(start=df.index.min(), end=df.index.max(), freq=freq.replace('T','min')) + + # Find the missing timestamps + missing = full_range.difference(df.index) + return missing + +def report_missing(missing: pd.DatetimeIndex, cfg: DictConfig, io: IOManager, logger): + """Reports details about missing bars.""" + total_missing = len(missing) + strategy = cfg['data']['missing']['strategy'] # Access using dict style + max_gap_allowed = cfg['data']['missing']['max_gap'] # Access using dict style + + if total_missing == 0: + logger.info("No missing bars detected.") + return 0 # Return longest gap = 0 + + gaps = _consecutive_gaps(missing) + longest_gap = max(gaps) if gaps else 0 + + # Log warning + warning_msg = ( + f"Detected {total_missing} missing bars (longest consecutive gap: {longest_gap}). " + f"Applied strategy: {strategy}." + ) + logger.warning(warning_msg) + + # Save report + report_data = { + "total_missing_bars": total_missing, + "longest_consecutive_gap": longest_gap, + "applied_strategy": strategy, + "gap_distribution": {i: gaps.count(i) for i in sorted(list(set(gaps)))}, # Count occurrences of each gap length + "missing_timestamps_utc": missing.strftime('%Y-%m-%d %H:%M:%S').tolist() # Store as strings + } + try: + io.save_json(report_data, "missing_bars_summary.json", indent=4) + logger.info(f"Saved missing bars summary.") + except Exception as e: + logger.error(f"Failed to save missing bars summary: {e}") + + return longest_gap + +def fill_missing_bars(df: pd.DataFrame, cfg: DictConfig, io: IOManager, logger) -> pd.DataFrame: + """Detects, reports, and fills missing bars based on configuration.""" + freq = cfg['data']['bar_frequency'] + strategy = cfg['data']['missing']['strategy'] + max_gap_allowed = cfg['data']['missing']['max_gap'] + interpolate_cfg = cfg['data']['missing']['interpolate'] + + missing_indices = find_missing_bars(df, freq) + longest_gap = report_missing(missing_indices, cfg, io, logger) + + # Check if longest gap exceeds allowed maximum + if longest_gap > max_gap_allowed: + raise ValueError( + f"Longest consecutive gap ({longest_gap}) exceeds maximum allowed ({max_gap_allowed}). " + f"Consider adjusting 'data.missing.max_gap' or cleaning the data source." + ) + + if missing_indices.empty: + df['bar_imputed'] = False # Add column even if no missing data initially + return df + + # Reindex to include all missing timestamps + full_range = pd.date_range(start=df.index.min(), end=df.index.max(), freq=freq) + df_full = df.reindex(full_range) + + # Apply filling strategy + if strategy == "drop": + logger.info("Missing bar strategy 'drop': Returning original data (no filling). Missing bars remain as NaNs or gaps.") + # Note: 'drop' doesn't really fit the fill_missing name, but follows instructions. + # It implies subsequent steps might handle NaNs or the user accepts gaps. + # We still need to add the 'bar_imputed' flag based on the *detected* missing indices. + df_filled = df.copy() # Return original df + # We can't directly assign based on full_range index to original df if size differs. + # Instead, add the column and mark True where original index matches a missing one. + df_filled['bar_imputed'] = df_filled.index.isin(missing_indices) + return df_filled + elif strategy == "neutral": + logger.info("Missing bar strategy 'neutral': Forward filling close, setting OHLC=close, Volume=0.") + # Ffill everything first to get the last known close + df_filled = df_full.ffill() + # Identify rows that were originally NaN (i.e., the missing ones) + was_nan_mask = df_full['close'].isna() + # Set Volume to 0 for imputed bars + df_filled.loc[was_nan_mask, 'volume'] = 0 + # Set Open, High, Low to the forward-filled Close for imputed bars + df_filled.loc[was_nan_mask, ['open', 'high', 'low']] = df_filled.loc[was_nan_mask, 'close'] + # Ensure no NaNs remain due to ffill at the beginning + df_filled.bfill(inplace=True) + elif strategy == "ffill": + logger.info("Missing bar strategy 'ffill': Forward filling all columns, then backward filling initial NaNs.") + df_filled = df_full.ffill().bfill() + elif strategy == "interpolate": + logger.info(f"Missing bar strategy 'interpolate': Interpolating using method='{interpolate_cfg.method}', limit={interpolate_cfg.limit}.") + # Interpolate numeric columns (OHLCV) + numeric_cols = ['open', 'high', 'low', 'close', 'volume'] + df_filled = df_full.copy() # Start with the reindexed df + for col in numeric_cols: + if col in df_filled.columns: + df_filled[col] = df_filled[col].interpolate( + method=interpolate_cfg.method, + limit=interpolate_cfg.limit, + limit_direction='forward', # Only fill gaps, not leading/trailing NaNs initially + limit_area=None # Fill within NaNs only + ) + # Still need to handle potential leading/trailing NaNs if any + df_filled.ffill(inplace=True) # Fill any remaining NaNs at the start using ffill + df_filled.bfill(inplace=True) # Fill any remaining NaNs at the end using bfill + else: + raise ValueError(f"Unknown missing data strategy: {strategy}") + + # Add the 'bar_imputed' column AFTER filling + df_filled['bar_imputed'] = df_filled.index.isin(missing_indices) + + logger.info(f"Successfully filled missing bars using strategy '{strategy}'. Added 'bar_imputed' column.") + return df_filled + +# --- End Missing Bar Handling --- \ No newline at end of file diff --git a/gru_sac_predictor/src/feature_engineer.py b/gru_sac_predictor/src/feature_engineer.py new file mode 100644 index 00000000..73e9617c --- /dev/null +++ b/gru_sac_predictor/src/feature_engineer.py @@ -0,0 +1,1059 @@ +""" +Feature Engineering Component. + +Handles adding base features (cyclical, imbalance, TA) and selecting features +using Logistic Regression (L1) and Variance Inflation Factor (VIF). +""" + +import pandas as pd +import numpy as np +import logging +import json +import scipy.stats as st # Added for z-score +from typing import List +import time +import pandas_ta as ta # <<< ADDED +import warnings # <<< ADDED for catching warnings + +from sklearn.linear_model import LogisticRegression, LinearRegression # Added LinearRegression +from sklearn.feature_selection import SelectFromModel +from statsmodels.stats.outliers_influence import variance_inflation_factor +import statsmodels.api as sm +# Added RollingOLS for trend-fit residuals +from statsmodels.regression.rolling import RollingOLS +from sklearn.preprocessing import PolynomialFeatures # <<< ADDED for interactions + +# Import TA library functions directly +from ta.volatility import AverageTrueRange, KeltnerChannel, BollingerBands # Added BollingerBands +from ta.momentum import RSIIndicator +from ta.trend import EMAIndicator, MACD +# Remove specific ta.volume imports if using pandas_ta alias + +logger = logging.getLogger(__name__) + +_EPS = 1e-6 + +# --- Helper function for safe division --- +def safe_divide(numerator, denominator, default=0.0): + """Performs division, returning default value where denominator is near zero or NaN.""" + mask = (np.abs(denominator) < _EPS) | np.isnan(denominator) | np.isnan(numerator) + result = np.divide(numerator, denominator, out=np.full_like(numerator, default, dtype=np.float64), where=~mask) + result[mask] = default + return result + +# --- Helper function for rolling regression residuals --- +def rolling_regression_residual(y, window=3): + """Calculates the residual of a rolling OLS regression of y against time.""" + if len(y) < window: + return pd.Series(np.nan, index=y.index) + + try: + # Ensure y is numeric + y_numeric = pd.to_numeric(y, errors='coerce') + + # Prepare Exog Data (Add constant to x) + x_trend = np.arange(len(y)) + X_exog = sm.add_constant(x_trend, prepend=True) + X_exog_df = pd.DataFrame(X_exog, index=y.index, columns=['const', 'time']) + + # Align y and X + y_numeric, X_aligned = y_numeric.align(X_exog_df, join='inner', axis=0) + + if y_numeric.isnull().all() or X_aligned.isnull().all().all(): + logger.warning("Input to RollingOLS contains all NaNs after alignment.") + return pd.Series(np.nan, index=y.index) + + rols = RollingOLS(endog=y_numeric, exog=X_aligned, window=window, min_nobs=window) + + # <<< Catch RuntimeWarning during fit >>> + with warnings.catch_warnings(): + # Ignore the specific warning about log(ssr) when ssr is zero + warnings.filterwarnings( + 'ignore', + message='divide by zero encountered in log', + category=RuntimeWarning + ) + results = rols.fit() + # <<< End Catch >>> + + # --- Re-calculate Fitted Values & Residuals (Fix for AttributeError) --- # + params = results.params + fitted_values = (X_aligned * params).sum(axis=1) + residuals = y_numeric - fitted_values # Correct calculation + # --- End Re-calculation --- # + + # Reindex residuals to match the original input index, filling gaps with NaN + return residuals.reindex(y.index) + + except Exception as e: + logger.error(f"Error during rolling regression: {e}", exc_info=True) + return pd.Series(np.nan, index=y.index) + +class FeatureEngineer: + """Encapsulates feature creation and selection logic.""" + + def __init__(self, config: dict): # Changed to accept config + """ + Initialize the FeatureEngineer. + + Args: + config (dict): Pipeline configuration dictionary. + """ + self.config = config + # Minimal whitelist definition can be moved here or kept separate + # For now, assuming it's defined elsewhere or passed via config if needed + self.minimal_whitelist = config.get('features', {}).get('minimal_whitelist', []) + if not self.minimal_whitelist: + logger.warning("Minimal whitelist not found in config or is empty.") + # Define a default fallback if necessary + self.minimal_whitelist = [ + "return_1m", "return_15m", "return_60m", "ATR_14", "volatility_14d", + "chaikin_AD_10", "svi_10", "EMA_10", "EMA_50", + "hour_sin", "hour_cos", + ] + logger.info(f"FeatureEngineer initialized. Minimal whitelist: {self.minimal_whitelist}") + + def _add_cyclical_features(self, df: pd.DataFrame) -> pd.DataFrame: + """Adds sine and cosine transformations of the hour and week progress.""" + if not isinstance(df.index, pd.DatetimeIndex): + logger.warning("Index is not DatetimeIndex. Skipping cyclical features.") + # Add placeholders if needed + df['hour_sin'] = 0.0 + df['hour_cos'] = 1.0 + df['week_sin'] = 0.0 + df['week_cos'] = 1.0 + return df + + timestamp_source = df.index + logger.info("Adding cyclical hour features (sin/cos)...") + # --- Hourly Features --- # + hours_in_day = 24 + df['hour_sin'] = np.sin(2 * np.pi * timestamp_source.hour / hours_in_day) + df['hour_cos'] = np.cos(2 * np.pi * timestamp_source.hour / hours_in_day) + + # --- Weekly Features (Task 2.2) --- # + logger.info("Adding cyclical weekly features (sin/cos)...") + # Calculate time elapsed in minutes within the week (0=Monday 00:00, max=7*24*60) + # dayofweek: Monday=0, Sunday=6 + minutes_in_week = 7 * 24 * 60 + time_in_week_minutes = (timestamp_source.dayofweek * 24 * 60 + + timestamp_source.hour * 60 + + timestamp_source.minute) + + df['week_sin'] = np.sin(2 * np.pi * time_in_week_minutes / minutes_in_week) + df['week_cos'] = np.cos(2 * np.pi * time_in_week_minutes / minutes_in_week) + # --- End Weekly Features --- # + + return df + + def _add_imbalance_features(self, df: pd.DataFrame) -> pd.DataFrame: + """Add Chaikin AD line, signed volume imbalance, gap imbalance, PVT, MFI.""" + logger.info("Adding imbalance features...") + if not {"open", "high", "low", "close", "volume"}.issubset(df.columns): + logger.warning("Missing required columns for imbalance features. Skipping.") + return df + + df_shifted = df.shift(1) # Use shifted data for calculations + df_results = pd.DataFrame(index=df.index) # Store new features here + # Initialize all expected columns with NaN to ensure float dtype from the start + cols_to_init = [ + "chaikin_AD_10", "svi_10", "signed_vol_norm_std_10", + "gap_imbalance", "pvt", "mfi_14" + ] + for col in cols_to_init: + # Explicitly initialize with float64 dtype + df_results[col] = pd.Series(np.nan, index=df.index, dtype=np.float64) + + try: + # Chaikin AD (uses non-shifted internally via clv) + clv = safe_divide((df["close"] - df["low"]) - (df["high"] - df["close"]), + (df["high"] - df["low"]), default=0.0) + df_results["chaikin_AD_10"] = (clv * df["volume"]).rolling(10).sum() + + # Signed Volume Imbalance (Rolling Sum) + signed_vol = np.where(df_shifted["close"] >= df_shifted["open"], df_shifted["volume"], -df_shifted["volume"]) + df_results["svi_10"] = pd.Series(signed_vol, index=df.index).rolling(10).sum() + + # Signed Volume / Rolling Vol Std (window=10) + vol_std_10 = df_shifted['volume'].rolling(window=10, min_periods=5).std() + df_results["signed_vol_norm_std_10"] = safe_divide(pd.Series(signed_vol, index=df.index), vol_std_10, default=0.0) + + # Gap Imbalance + med_vol = df_shifted["volume"].rolling(50).median() + gap_up = (df_shifted["low"] > df["high"].shift(2)) & (df_shifted["volume"] > 2 * med_vol) # Check against t-2 high + gap_dn = (df_shifted["high"] < df["low"].shift(2)) & (df_shifted["volume"] > 2 * med_vol) # Check against t-2 low + df_results["gap_imbalance"] = gap_up.astype(int) - gap_dn.astype(int) + + # Price-Volume Trend (PVT) using ta library + # Remove individual np.nan initialization here + try: + # Use pandas_ta function call + df_results["pvt"] = ta.pvt(close=df_shifted["close"], volume=df_shifted["volume"]) + except Exception as pvt_e: + logger.error(f"Error calculating PVT: {pvt_e}") + df_results["pvt"] = 0.0 + + # Money Flow Index (MFI) using ta library (default window 14) + # Remove individual np.nan initialization here + try: + # Use pandas_ta function call + # REMOVE .astype(float) here as the target column dtype is already set + df_results["mfi_14"] = ta.mfi(high=df_shifted["high"], low=df_shifted["low"], close=df_shifted["close"], volume=df_shifted["volume"], window=14).astype(float) + except Exception as mfi_e: + logger.error(f"Error calculating MFI: {mfi_e}") + df_results["mfi_14"] = 50.0 # Ensure fallback is float + + # Fill NaNs introduced within this method + imbalance_cols = ["chaikin_AD_10", "svi_10", "signed_vol_norm_std_10", "gap_imbalance", "pvt", "mfi_14"] + for col in imbalance_cols: + if col in df_results.columns: + # Replace inplace fillna + df_results[col] = df_results[col].ffill() + df_results[col] = df_results[col].bfill() + df_results[col] = df_results[col].fillna(0) # Final fill for start/end + + # Add results to original DataFrame + for col in imbalance_cols: + if col in df_results.columns: + # Explicitly cast to float to avoid dtype warnings + df[col] = df_results[col].astype(float) + + logger.info("Successfully added imbalance features (incl. PVT, MFI).") + except Exception as e: + logger.error(f"Error calculating imbalance features: {e}", exc_info=True) + + return df + + def _add_vol_norm_returns(self, df: pd.DataFrame) -> pd.DataFrame: + """Adds volatility-normalized returns.""" + logger.info("Adding volatility-normalized returns...") + # Expects return_1m, return_15m, return_60m to exist from _add_ta_features + return_cols = [col for col in df.columns if col.startswith('return_') and col.endswith('m')] + if not return_cols: + logger.warning("No base return columns (e.g., return_1m) found. Skipping volatility-normalized returns.") + return df + + window = 10 # Define the rolling window for std dev + logger.info(f"Using window={window} for volatility normalization standard deviation.") + + for col in return_cols: + norm_col_name = col.replace("return_", "return_norm_") + try: + rolling_std = df[col].rolling(window=window, min_periods=window // 2).std() + # Handle potential division by zero or NaNs + df[norm_col_name] = df[col] / (rolling_std + _EPS) + # Fill any resulting NaNs/Infs (e.g., from initial periods or zero std dev) + # Fix inplace replace warning + df[norm_col_name] = df[norm_col_name].replace([np.inf, -np.inf], np.nan) + # Replace inplace fillna + df[norm_col_name] = df[norm_col_name].ffill() + df[norm_col_name] = df[norm_col_name].bfill() + df[norm_col_name] = df[norm_col_name].fillna(0) # Final fill + logger.debug(f"Added {norm_col_name}") + except Exception as e: + logger.error(f"Error calculating {norm_col_name}: {e}") + df[norm_col_name] = 0.0 # Add placeholder on error + + logger.info("Successfully added volatility-normalized returns.") + return df + + def _add_ta_features(self, df: pd.DataFrame) -> pd.DataFrame: + """Adds TA features using the 'ta' library and config parameters.""" + logger.info("Adding TA features...") + required_cols = {'open', 'high', 'low', 'close', 'volume'} + if not required_cols.issubset(df.columns): + logger.warning(f"Missing required columns for TA features ({required_cols - set(df.columns)}). Skipping TA.") + return df + + # Get feature windows from config + feat_cfg = self.config.get('features', {}) + atr_window = feat_cfg.get('atr_window', 14) + rsi_window = feat_cfg.get('rsi_window', 14) + macd_fast = feat_cfg.get('macd_fast', 12) + macd_slow = feat_cfg.get('macd_slow', 26) + macd_signal = feat_cfg.get('macd_signal', 9) + ema_10_window = 10 # Hardcoded for now, could be configurable + ema_50_window = 50 # Hardcoded for now, could be configurable + volatility_window_days = 14 # Hardcoded for now, could be configurable + + # Apply shift(1) to prevent lookahead bias in TA features based on close + # Features will be calculated based on data up to t-1 + df_shifted = df.shift(1) + df_ta = pd.DataFrame(index=df.index) # Create empty DF to store results aligned with original index + + try: + # Calculate returns first (use shifted close) + # Fill NaNs robustly before pct_change on the *shifted* data + close_filled = df_shifted["close"].bfill().ffill() + df_ta["return_1m"] = close_filled.pct_change() + df_ta["return_15m"] = close_filled.pct_change(15) + df_ta["return_60m"] = close_filled.pct_change(60) + + # Calculate TA features using ta library on *shifted* data + df_ta[f"ATR_{atr_window}"] = AverageTrueRange(df_shifted['high'], df_shifted['low'], df_shifted['close'], window=atr_window).average_true_range() + + # Daily volatility (use calculated 1m return) + # Assumes 1-min bars; adjust multiplier if interval changes + vol_roll_window = 60 * 24 * volatility_window_days + df_ta[f"volatility_{volatility_window_days}d"] = safe_divide(1, 1) * df_ta["return_1m"].rolling(vol_roll_window, min_periods=vol_roll_window // 2).std() # Corrected line + + # EMA 10 / 50 + MACD using ta library (on shifted close) + df_ta[f"EMA_{ema_10_window}"] = EMAIndicator(df_shifted["close"], ema_10_window).ema_indicator() + df_ta[f"EMA_{ema_50_window}"] = EMAIndicator(df_shifted["close"], ema_50_window).ema_indicator() + macd = MACD(df_shifted["close"], window_slow=macd_slow, window_fast=macd_fast, window_sign=macd_signal) + df_ta["MACD"] = macd.macd() + df_ta["MACD_signal"] = macd.macd_signal() + + # RSI using ta library (on shifted close) + df_ta[f"RSI_{rsi_window}"] = RSIIndicator(df_shifted["close"], window=rsi_window).rsi() + + # Handle potential NaNs introduced by TA calculations + df_ta.bfill(inplace=True) + df_ta.ffill(inplace=True) + + # Add the calculated TA features back to the original df + ta_cols_to_add = [col for col in df_ta.columns if col not in df.columns] + for col in ta_cols_to_add: + df[col] = df_ta[col] + logger.info("Successfully added TA features.") + + except Exception as e: + logger.error(f"Error calculating TA features: {e}", exc_info=True) + + return df + + def _add_microstructure_features(self, df: pd.DataFrame) -> pd.DataFrame: + """Adds bar-level microstructure features based on revisions.txt 1-B.""" + logger.info("Adding microstructure features...") + required_cols = {"open", "high", "low", "close", "volume"} + if not required_cols.issubset(df.columns): + logger.warning("Missing required columns for microstructure features. Skipping.") + return df + + df_micro = pd.DataFrame(index=df.index) # Create empty DF for results + df_shifted = df.shift(1) # Shift all inputs to avoid lookahead + + try: + # --- Revision 2: Add guards for potentially all-NaN columns --- # + # Check if high/low are all NaNs before using them + high_valid = not df_shifted["high"].isnull().all() + low_valid = not df_shifted["low"].isnull().all() + # --- End Revision 2 --- # + + # 1. Spread proxy + # Calculation: (high/low - 1) + # --- Revision 2 Guard --- # + if high_valid and low_valid: + df_micro["spread_proxy"] = df_shifted["high"] / (df_shifted["low"] + _EPS) - 1 + else: + logger.warning("Spread proxy cannot be calculated (high/low contains all NaNs). Filling with 0.") + df_micro["spread_proxy"] = 0.0 + # --- End Revision 2 --- # + + # 2. Vol-norm volume spike (using 1 day = 1440 min rolling mean/std) + # Calculation: (volume - rolling_mean_1440) / rolling_std_1440 + # Use min_periods for robustness at the start + vol_roll_mean = df_shifted["volume"].rolling(window=1440, min_periods=int(1440*0.5)).mean() + vol_roll_std = df_shifted["volume"].rolling(window=1440, min_periods=int(1440*0.5)).std() + df_micro["vol_norm_volume_spike"] = (df_shifted["volume"] - vol_roll_mean) / (vol_roll_std + _EPS) + + # 3. Return asymmetry + # Calculation: (close - open) / (high - low) + intrabar_range = df_shifted["high"] - df_shifted["low"] + _EPS + df_micro["return_asymmetry"] = (df_shifted["close"] - df_shifted["open"]) / intrabar_range + + # 4. Close-location value + # Calculation: abs(close - (high + low) / 2) / (high - low) + # --- Revision 2 Guard --- # + if high_valid and low_valid: + mid_point = (df_shifted["high"] + df_shifted["low"]) / 2 + # Ensure intrabar_range is recalculated here if needed, or use the one from above + intrabar_range_clv = df_shifted["high"] - df_shifted["low"] + _EPS + df_micro["close_location_value"] = abs(df_shifted["close"] - mid_point) / intrabar_range_clv + else: + logger.warning("Close-location value cannot be calculated (high/low contains all NaNs). Filling with 0.") + df_micro["close_location_value"] = 0.0 + # --- End Revision 2 --- # + + # 5. Keltner band position (using EMA 20 and ATR 20) + # Calculation: (close - ema20) / (atr20 * 2) + # Need to calculate EMA(20) and ATR(20) if not already present from _add_ta_features + if 'EMA_20' not in df_shifted.columns: + # Use ta lib directly on shifted data + ema_20_calc = EMAIndicator(df_shifted["close"], window=20).ema_indicator() + else: + ema_20_calc = df_shifted['EMA_20'] + + if 'ATR_20' not in df_shifted.columns: + # Use ta lib directly on shifted data + atr_20_calc = AverageTrueRange(df_shifted['high'], df_shifted['low'], df_shifted['close'], window=20).average_true_range() + else: + atr_20_calc = df_shifted['ATR_20'] + + # Keltner calculation using pre-calculated or just-calculated indicators + keltner_center = ema_20_calc + keltner_deviation = (atr_20_calc * 2) # As per formula spec (atr20 * 2) + df_micro["keltner_band_pos"] = (df_shifted["close"] - keltner_center) / (keltner_deviation + _EPS) + + # Fill NaNs introduced by rolling calculations or division by zero + micro_cols = df_micro.columns.tolist() + # Simple fillna(0) might be problematic. Let's try bfill/ffill first. + df_micro.bfill(inplace=True) + df_micro.ffill(inplace=True) + # Final fill with 0 if any NaNs remain (e.g., at the very start) + df_micro.fillna(0, inplace=True) + + # Add calculated features back to original df + for col in micro_cols: + if col not in df.columns: + df[col] = df_micro[col] + else: # Overwrite if column already exists (e.g., from previous runs/versions) + df[col] = df_micro[col] + logger.warning(f"Overwriting existing column '{col}' with new microstructure feature.") + logger.info(f"Successfully added microstructure features: {micro_cols}") + + except ImportError: + logger.error("'ta' library not found. Cannot calculate Keltner/EMA/ATR needed for microstructure features. Skipping.") + except Exception as e: + logger.error(f"Error calculating microstructure features: {e}", exc_info=True) + + return df + + def add_base_features(self, df: pd.DataFrame) -> pd.DataFrame: + """ + Adds a standard set of base features: cyclical, imbalance, TA, microstructure, + and additional features from the grab-bag list. + + Args: + df (pd.DataFrame): Input DataFrame with OHLCV data and DatetimeIndex. + + Returns: + pd.DataFrame: DataFrame with added features. + """ + start_time = time.time() + logger.info(f"Starting base feature engineering on {df.shape[0]} rows...") + # Ensure index is datetime + if not isinstance(df.index, pd.DatetimeIndex): + try: + df.index = pd.to_datetime(df.index) + logger.info("Converted index to DatetimeIndex.") + except Exception as e: + logger.error(f"Failed to convert index to DatetimeIndex: {e}. Some features may fail.") + # Depending on strictness, could return df here or raise error + + # --- Apply Feature Steps --- # + # Order is somewhat important based on dependencies + df = self._add_ta_features(df) # Needs OHLCV + df = self._add_vol_norm_returns(df) # Needs returns from _add_ta_features + df = self._add_cyclical_features(df) # Needs index + df = self._add_imbalance_features(df) # Needs OHLCV, Volume; Adds PVT, MFI + df = self._add_microstructure_features(df) # Needs OHLCV, Volume, potentially EMA/ATR from _add_ta_features + + # <<< Call NEW feature methods >>> + df = self._add_candle_features(df) # Needs OHLC + df = self._add_volatility_channel_features(df) # Needs Close, EMA/ATR + df = self._add_momentum_trend_features(df) # Needs Close, return_1m + df = self._add_extra_seasonality_features(df) # Needs index + # <<< End CALL NEW >>> + + # <<< CALL Interaction Method HERE + df = self._add_interaction_features(df) + + # --- Handle bar_imputed if present --- # + if 'bar_imputed' in df.columns: + # Ensure it's integer type and fill potential NaNs from joins/shifts + df['bar_imputed'] = df['bar_imputed'].fillna(0).astype(int) + else: + # If missing after all feature steps, add it as 0 + logger.warning("'bar_imputed' column not found after feature engineering. Adding column filled with 0.") + df['bar_imputed'] = 0 + # --- End Handle bar_imputed --- # + + # --- Final Check & Clean --- # + # Convert boolean columns potentially created by TA-Lib to int/float + for col in df.select_dtypes(include='bool').columns: + logger.debug(f"Converting boolean column '{col}' to int.") + df[col] = df[col].astype(int) + + # Replace any remaining infinities just in case + df.replace([np.inf, -np.inf], np.nan, inplace=True) + + # Optional: Final check for NaNs - uncomment if needed for debugging + # nan_counts = df.isnull().sum() + # if nan_counts.sum() > 0: + # logger.warning(f"NaNs remain after feature engineering:") + # logger.warning(nan_counts[nan_counts > 0]) + # # Decide on final fill strategy - maybe ffill/bfill again or fillna(0) + # df.ffill(inplace=True) + # df.bfill(inplace=True) + # df.fillna(0, inplace=True) + # --- End Final Check --- # + + end_time = time.time() + logger.info(f"Base feature engineering complete. Shape: {df.shape}. Time: {end_time - start_time:.2f}s") + return df + + def select_features(self, X_train_raw: pd.DataFrame, y_dir_train: pd.Series) -> list: + """ + Performs feature selection using: + 1. Logistic Regression (L1 penalty) to find predictive features. + 2. Variance Inflation Factor (VIF) to remove multi-collinear features. + Uses parameters from the 'features' section of the config. + + Args: + X_train_raw (pd.DataFrame): Raw training features (unscaled). + y_dir_train (pd.Series): Training direction labels (ordinal). + + Returns: + list: The final list of selected feature names. + """ + logger.info("Starting feature selection process...") + if X_train_raw.empty: + logger.error("Input DataFrame X_train_raw is empty. Skipping feature selection.") + return [] + + # Get parameters from config + feat_cfg = self.config.get('features', {}) + vif_threshold = feat_cfg.get('vif_threshold', 10.0) # Default VIF threshold + # Increase C for less penalty (C is the inverse of regularization strength) + logreg_c = feat_cfg.get('logreg_c', 1.0) # Default L1 regularization strength (Higher C = less penalty) + min_features = feat_cfg.get('min_features_after_selection', 5) # Default min features + selection_method = feat_cfg.get('selection_method', 'logreg_vif') # Default method + + # Ensure targets are not in feature set + X_train_only_features = X_train_raw.select_dtypes(include=np.number) + if y_dir_train.name in X_train_only_features.columns: + X_train_only_features = X_train_only_features.drop(columns=[y_dir_train.name]) + logger.debug(f"Removed target column '{y_dir_train.name}' from features before selection.") + + # Handle potential NaNs/Infs (important before LogReg/VIF) + X_train_clean = X_train_only_features.replace([np.inf, -np.inf], np.nan).fillna(0) + if X_train_clean.isnull().any().any(): + logger.warning("NaNs still present after fillna(0). Feature selection might be unstable.") + + if X_train_clean.empty: + logger.error("Feature set is empty after cleaning. Skipping selection.") + return [] + + # --- 1. Univariate Filtering (New Step) --- # + univariate_quantile_threshold = feat_cfg.get('univariate_quantile_threshold', 0.30) # Keep top 70% by default + logger.info(f"Performing univariate filtering (keeping features above {univariate_quantile_threshold} quantile of correlation with target)...") + try: + # Ensure y_dir_train is aligned with X_train_clean and numeric + y_aligned, X_aligned = y_dir_train.align(X_train_clean, join='inner', axis=0) + y_numeric = pd.to_numeric(y_aligned, errors='coerce').fillna(0) # Convert to numeric, fill NaNs + X_train_clean = X_aligned # Use aligned X from now on + + if X_train_clean.empty: + logger.error("Feature set empty after aligning with target for univariate filter. Skipping selection.") + return [] + + corrs = X_train_clean.apply(lambda col: abs(np.corrcoef(col, y_numeric)[0, 1]) if col.var() > _EPS else 0.0) + corrs.fillna(0, inplace=True) # Fill NaN correlations (e.g., from zero-variance columns) + + if corrs.empty or corrs.isnull().all(): + logger.warning("Could not compute valid correlations for univariate filter. Skipping this step.") + X_train_univariate_filtered = X_train_clean # Pass all cleaned features to next step + else: + quantile_value = corrs.quantile(univariate_quantile_threshold) + keep_mask = corrs >= quantile_value + univariate_selected_features = corrs[keep_mask].index.tolist() + if not univariate_selected_features: + logger.warning(f"Univariate filter removed all features (threshold {quantile_value:.4f}). Keeping all features to proceed.") + X_train_univariate_filtered = X_train_clean + else: + X_train_univariate_filtered = X_train_clean[univariate_selected_features] + logger.info(f"Univariate filter kept {len(univariate_selected_features)} features (Corr >= {quantile_value:.4f}).") + + except Exception as e: + logger.error(f"Error during univariate filtering: {e}", exc_info=True) + logger.warning("Skipping univariate filtering due to error.") + X_train_univariate_filtered = X_train_clean # Use original cleaned data if filter fails + + if X_train_univariate_filtered.empty: + logger.error("Feature set is empty after univariate filtering. Returning minimal whitelist.") + return self.minimal_whitelist[:min_features] + + # --- 2. Logistic Regression (L1) Selection (Operates on univariately filtered data) --- # + logger.info(f"Performing L1 Logistic Regression (C={logreg_c}) on {X_train_univariate_filtered.shape[1]} features...") + try: + # Align target with the potentially reduced feature set again + y_aligned_l1, X_aligned_l1 = y_dir_train.align(X_train_univariate_filtered, join='inner', axis=0) + y_numeric_l1 = pd.to_numeric(y_aligned_l1, errors='coerce').fillna(0) + + if X_aligned_l1.empty: + logger.error("Feature set empty after aligning for L1 step. Returning minimal whitelist.") + return self.minimal_whitelist[:min_features] + + selector = SelectFromModel( + LogisticRegression( + penalty='l1', + C=logreg_c, + solver='liblinear', # Good solver for L1 + random_state=42, + max_iter=500 # Increased max_iter + ), + threshold="mean" # Use mean threshold as suggested + ) + selector.fit(X_aligned_l1, y_numeric_l1) # Fit on aligned data + l1_selected_features = X_aligned_l1.columns[selector.get_support()].tolist() + logger.info(f"L1 LogReg selected {len(l1_selected_features)} features (threshold='mean').") + + if not l1_selected_features: + logger.warning("L1 LogReg selected zero features after univariate filter. Returning minimal whitelist as fallback.") + return self.minimal_whitelist[:min_features] # Return min number from minimal list + + except Exception as e: + logger.error(f"Error during L1 Logistic Regression selection: {e}", exc_info=True) + logger.warning("Falling back to minimal whitelist due to L1 selection error.") + return self.minimal_whitelist[:min_features] + + # --- 3. VIF Calculation and Filtering (Operates on L1 selected features) --- # + logger.info(f"Performing VIF filtering (threshold={vif_threshold}) on {len(l1_selected_features)} features...") + features_to_check = l1_selected_features + # Use the univariately filtered data to get the columns for VIF + X_vif_input = X_train_univariate_filtered[features_to_check] + + final_whitelist = [] + while len(features_to_check) > min_features: + X_subset = X_vif_input[features_to_check] # Subset the columns further in each iteration + # Add constant for VIF calculation + try: + # Align VIF input data again? Should be okay if features_to_check is correct. + X_subset_const = sm.add_constant(X_subset, has_constant='add') + vif = pd.Series( + [variance_inflation_factor(X_subset_const.values, i) + for i in range(1, X_subset_const.shape[1])], # Skip constant (index 0) + index=X_subset.columns, + dtype=float # Ensure float type + ) + + max_vif = vif.max() + if max_vif > vif_threshold: + feature_to_drop = vif.idxmax() + features_to_check.remove(feature_to_drop) + logger.debug(f" Dropped '{feature_to_drop}' due to high VIF ({max_vif:.2f}) > {vif_threshold}. Remaining: {len(features_to_check)}") + else: + logger.debug(f" Max VIF ({max_vif:.2f}) <= threshold ({vif_threshold}). Stopping VIF loop.") + break # All remaining features are below threshold + except Exception as vif_e: + logger.error(f"Error calculating VIF for features {features_to_check}: {vif_e}", exc_info=True) + logger.warning("Aborting VIF filtering due to error.") + # Should we fallback? For now, just stop VIF filtering and use current set. + break + + final_whitelist = features_to_check + + # Ensure minimal whitelist features are included (if specified by config?) + # Add back any minimal whitelist features that were dropped, if needed + missing_minimal = set(self.minimal_whitelist) - set(final_whitelist) + if missing_minimal: + logger.info(f"Adding missing minimal whitelist features back: {list(missing_minimal)}") + final_whitelist.extend(list(missing_minimal)) + # Ensure uniqueness after adding back + final_whitelist = list(pd.Series(final_whitelist).unique()) + + # Ensure minimum number of features + if len(final_whitelist) < min_features: + logger.warning(f"Final whitelist has {len(final_whitelist)} features, less than minimum required ({min_features}). Adding from minimal whitelist if possible.") + needed = min_features - len(final_whitelist) + additional_features = [f for f in self.minimal_whitelist if f not in final_whitelist][:needed] + final_whitelist.extend(additional_features) + logger.info(f"Added {len(additional_features)} features from minimal list. Final count: {len(final_whitelist)}") + + # --- Ensure bar_imputed is ALWAYS included --- # + imputed_col_name = 'bar_imputed' # Assuming this is the column name + if imputed_col_name not in final_whitelist: + logger.warning(f"'{imputed_col_name}' was not selected by L1/VIF. Adding it back for sequence creation.") + final_whitelist.append(imputed_col_name) + # --- End Ensure --- # + + logger.info(f"Final whitelist contains {len(final_whitelist)} features after L1 + VIF selection.") + logger.debug(f"Final Whitelist: {final_whitelist}") + return final_whitelist + + def prune_features(self, df: pd.DataFrame, whitelist: list) -> pd.DataFrame: + """ + Prunes the DataFrame to include only features in the whitelist. + + Args: + df (pd.DataFrame): Input DataFrame. + whitelist (list): List of feature names to keep. + + Returns: + pd.DataFrame: Pruned DataFrame. + """ + logger.info(f"Pruning features to whitelist ({len(whitelist)} features)...") + missing_cols = [col for col in whitelist if col not in df.columns] + if missing_cols: + logger.warning(f"Whitelist features missing from DataFrame during pruning: {missing_cols}. These columns cannot be kept.") + # Filter whitelist to only include existing columns + whitelist = [col for col in whitelist if col in df.columns] + + if not whitelist: # Check if whitelist became empty + logger.error("Whitelist is empty after checking for existing columns. Cannot prune.") + return df # Return original df if whitelist is empty + + try: + df_pruned = df[whitelist].copy() + logger.info(f"Pruning complete. Shape after pruning: {df_pruned.shape}") + return df_pruned + except KeyError as e: + logger.error(f"KeyError during pruning: {e}. Ensure whitelist columns exist in the DataFrame. Returning original DataFrame.") + return df + except Exception as e: + logger.error(f"Unexpected error during pruning: {e}. Returning original DataFrame.") + return df + + def select_features_l1(self, X_train_scaled: pd.DataFrame, y_dir_train: pd.Series) -> List[str]: + """Selects features using L1 regularization.""" + logger.info("Performing feature selection using L1 Logistic Regression (C=0.1)...") + # Convert y_dir_train to int if it's not already (needed for LogReg) + y_train_int = y_dir_train.astype(int) + + # Use configuration for C? For now, hardcoding C=0.1 as per log message + # Added max_iter, verbose, and n_jobs=1 for debugging hangs + C = 0.1 + logger.info(f"Performing feature selection using L1 Logistic Regression (C={C})...") + # Ensure y_dir_train is 1D + if y_dir_train.ndim > 1: + y_dir_train = y_dir_train.squeeze() + + lr_l1 = LogisticRegression( + C=C, + penalty='l1', + solver='liblinear', # Explicitly set solver + random_state=42, + max_iter=1000, # Increase max_iter slightly just in case + multi_class='auto' # Handle binary/multiclass automatically + ) + + try: + logger.info(f"Input features shape for L1 selection: {X_train_scaled.shape}") # Log shape + selector = SelectFromModel(lr_l1, threshold=-np.inf) # Keep all features with non-zero coef + selector.fit(X_train_scaled, y_train_int) + selected_mask = selector.get_support() + selected_features = X_train_scaled.columns[selected_mask].tolist() + + # Ensure minimal whitelist features are included + final_features = list(set(selected_features + self.minimal_whitelist)) + logger.info(f"L1 selected {len(selected_features)} features. After adding minimal whitelist, total: {len(final_features)}") + return final_features + except Exception as e: + logger.error(f"Error during L1 feature selection: {e}", exc_info=True) + # Fallback: return only the minimal whitelist + logger.warning("Falling back to using only the minimal whitelist due to L1 error.") + return self.minimal_whitelist + + # <<< NEW METHOD: Price Range & Candle Features >>> + def _add_candle_features(self, df: pd.DataFrame) -> pd.DataFrame: + """Adds features based on candle shape and price range.""" + logger.info("Adding candle features...") + required_cols = {'open', 'high', 'low', 'close'} + if not required_cols.issubset(df.columns): + logger.warning(f"Missing required columns for candle features ({required_cols - set(df.columns)}). Skipping.") + return df + + df_shifted = df.shift(1) + df_results = pd.DataFrame(index=df.index) + + try: + # 1. Total Range + df_results["total_range"] = df_shifted["high"] - df_shifted["low"] + + # 2. Body & Shadow Ratios + body = df_shifted["close"] - df_shifted["open"] + upper_shd = df_shifted["high"] - np.maximum(df_shifted["open"], df_shifted["close"]) + lower_shd = np.minimum(df_shifted["open"], df_shifted["close"]) - df_shifted["low"] + + range_with_eps = df_results["total_range"] + _EPS + df_results["body_ratio"] = safe_divide(body, range_with_eps) + df_results["upper_shd_ratio"] = safe_divide(upper_shd, range_with_eps) + df_results["lower_shd_ratio"] = safe_divide(lower_shd, range_with_eps) + + # 3. Doji / Hammer flags + df_results["is_doji"] = (np.abs(body) < 0.1 * df_results["total_range"]).astype(int) + df_results["is_hammer"] = ((lower_shd > 2 * np.abs(body)) & (upper_shd < 0.1 * df_results["total_range"])).astype(int) + + # Fill NaNs + candle_cols = ["total_range", "body_ratio", "upper_shd_ratio", "lower_shd_ratio", "is_doji", "is_hammer"] + for col in candle_cols: + if col in df_results.columns: + # Fix inplace fillna + df_results[col] = df_results[col].ffill() + df_results[col] = df_results[col].bfill() + df_results[col] = df_results[col].fillna(0) + + # Add results to original DataFrame + for col in candle_cols: + if col in df_results.columns: + df[col] = df_results[col] + + logger.info(f"Successfully added candle features: {candle_cols}") + + except Exception as e: + logger.error(f"Error adding candle features: {e}", exc_info=True) + + return df + + # <<< NEW METHOD: Volatility and Channel Features >>> + def _add_volatility_channel_features(self, df: pd.DataFrame) -> pd.DataFrame: + """Adds Bollinger Band Width and Keltner Channel Width.""" + logger.info("Adding volatility/channel features...") + required_cols = {'high', 'low', 'close'} # Need these for base calcs + if not required_cols.issubset(df.columns): + logger.warning(f"Missing required columns for volatility/channel features. Skipping.") + return df + + df_shifted = df.shift(1) + df_results = pd.DataFrame(index=df.index) + + W = 20 # Common window for BBands/KC + + try: + # Bollinger Bands Width (Normalized) + try: + bb_indicator = BollingerBands(close=df_shifted["close"], window=W, window_dev=2) + bb_width = bb_indicator.bollinger_wband() # Width = (Upper - Lower) + bb_center = bb_indicator.bollinger_mavg() # Center line (SMA) + df_results["bb_width_norm"] = safe_divide(bb_width, bb_center, default=0.0) + except Exception as bb_e: + logger.error(f"Error calculating Bollinger Band Width: {bb_e}") + df_results["bb_width_norm"] = 0.0 + + # Keltner Channel Width (Normalized) + try: + # Ensure EMA_20 and ATR_14 are available (or calculate them) + if f'EMA_{W}' not in df_shifted.columns: + ema_center = EMAIndicator(close=df_shifted["close"], window=W).ema_indicator() + else: + ema_center = df_shifted[f'EMA_{W}'] + + atr_window = 14 # Match ATR window used elsewhere? Or use W=20? Let's use 14 for now. + if f'ATR_{atr_window}' not in df_shifted.columns: + atr_val = AverageTrueRange(high=df_shifted['high'], low=df_shifted['low'], close=df_shifted['close'], window=atr_window).average_true_range() + else: + atr_val = df_shifted[f'ATR_{atr_window}'] + + kc_dev_multiplier = 1.5 # Common multiplier + kc_upper = ema_center + kc_dev_multiplier * atr_val + kc_lower = ema_center - kc_dev_multiplier * atr_val + kc_width = kc_upper - kc_lower + df_results["kc_width_norm"] = safe_divide(kc_width, ema_center, default=0.0) + except Exception as kc_e: + logger.error(f"Error calculating Keltner Channel Width: {kc_e}") + df_results["kc_width_norm"] = 0.0 + + # Fill NaNs + vol_cols = ["bb_width_norm", "kc_width_norm"] + for col in vol_cols: + if col in df_results.columns: + # Fix inplace fillna + df_results[col] = df_results[col].ffill() + df_results[col] = df_results[col].bfill() + df_results[col] = df_results[col].fillna(0) + + # Add results to original DataFrame + for col in vol_cols: + if col in df_results.columns: + df[col] = df_results[col] + + logger.info(f"Successfully added volatility/channel features: {vol_cols}") + + except ImportError: + logger.error("'ta' library components missing. Cannot calculate BBands/KC features. Skipping.") + except Exception as e: + logger.error(f"Error calculating volatility/channel features: {e}", exc_info=True) + + return df + + # <<< NEW METHOD: Momentum and Trend Features >>> + def _add_momentum_trend_features(self, df: pd.DataFrame) -> pd.DataFrame: + """Adds Acceleration, Trend-Fit Residuals, Directional Change Count.""" + logger.info("Adding momentum/trend features...") + # Requires return_1m, close + if 'return_1m' not in df.columns or 'close' not in df.columns: + logger.warning("Missing required 'return_1m' or 'close' columns. Skipping momentum/trend features.") + return df + + df_shifted = df.shift(1) # Use shifted data for inputs when appropriate + df_results = pd.DataFrame(index=df.index) + + try: + # 1. Acceleration (using return_1m) + df_results["accel_1m"] = df["return_1m"].diff() # Difference of the non-shifted return column + + # 2. Trend-Fit Residuals (3-bar rolling regression on shifted close price) + df_results["trend_resid_3b"] = rolling_regression_residual(df_shifted['close'], window=3) + + # 3. Directional Change Count (consecutive up bars in last k=10 bars) + k = 10 + up_seq = (df_shifted['close'] > df_shifted['close'].shift()).astype(int) + # Calculate consecutive run length + run_up = up_seq.groupby((up_seq != up_seq.shift()).cumsum()).cumcount() + 1 + # Zero out runs where the bar was not an up bar + run_up_masked = run_up * up_seq + df_results[f"dir_chg_cnt_{k}b"] = run_up_masked.rolling(window=k, min_periods=1).sum() # Sum of consecutive counts in window + + # Fill NaNs + mom_cols = ["accel_1m", "trend_resid_3b", f"dir_chg_cnt_{k}b"] + for col in mom_cols: + if col in df_results.columns: + # Fix inplace fillna + df_results[col] = df_results[col].ffill() + df_results[col] = df_results[col].bfill() + df_results[col] = df_results[col].fillna(0) + + # Add results to original DataFrame + for col in mom_cols: + if col in df_results.columns: + df[col] = df_results[col] + + logger.info(f"Successfully added momentum/trend features: {mom_cols}") + + except Exception as e: + logger.error(f"Error calculating momentum/trend features: {e}", exc_info=True) + + return df + + # <<< NEW METHOD: Additional Seasonality Features >>> + def _add_extra_seasonality_features(self, df: pd.DataFrame) -> pd.DataFrame: + """Adds Day-of-Week cycle and Time-to-Midnight features.""" + logger.info("Adding extra seasonality features...") + if not isinstance(df.index, pd.DatetimeIndex): + logger.warning("Index is not DatetimeIndex. Skipping extra seasonality features.") + return df + + timestamp_source = df.index + df_results = pd.DataFrame(index=df.index) + + try: + # 1. Day-of-week cycle + dow = timestamp_source.dayofweek # Monday=0, Sunday=6 + df_results["sin_dow"] = np.sin(2 * np.pi * dow / 7) + df_results["cos_dow"] = np.cos(2 * np.pi * dow / 7) + + # 2. Time-to-midnight (UTC) + minutes_in_day = 1440 + minutes_past_midnight = timestamp_source.hour * 60 + timestamp_source.minute + minutes_to_midnight = minutes_in_day - minutes_past_midnight + df_results["ttm_sin"] = np.sin(2 * np.pi * minutes_to_midnight / minutes_in_day) + # df_results["ttm_cos"] = np.cos(2 * np.pi * minutes_to_midnight / minutes_in_day) # Optional Cosine component + + # Fill NaNs (shouldn't be any here, but for consistency) + season_cols = ["sin_dow", "cos_dow", "ttm_sin"] + for col in season_cols: + if col in df_results.columns: + # Fix inplace fillna (though likely unnecessary here) + df_results[col] = df_results[col].fillna(0) + + # Add results to original DataFrame + for col in season_cols: + if col in df_results.columns: + df[col] = df_results[col] + + logger.info(f"Successfully added extra seasonality features: {season_cols}") + + except Exception as e: + logger.error(f"Error calculating extra seasonality features: {e}", exc_info=True) + + return df + + # <<< NEW METHOD: Interaction Features >>> + def _add_interaction_features(self, df: pd.DataFrame) -> pd.DataFrame: + """Adds interaction features based on existing base features.""" + logger.info("Adding interaction features...") + df_results = pd.DataFrame(index=df.index) + required_base_features = [ + # Group 1 deps + "return_1m", "return_norm_1m", "volatility_14d", + # Group 2 deps + "svi_10", "chaikin_AD_10", "vol_norm_volume_spike", + # Group 3 deps + "return_60m", + # Group 4 deps + "hour_sin", "cos_dow", + # Group 5 deps + "EMA_10", "EMA_50", "body_ratio" + # Group 6 deps are already listed + ] + missing_reqs = [f for f in required_base_features if f not in df.columns] + if missing_reqs: + logger.warning(f"Missing required base features for interactions: {missing_reqs}. Skipping interaction features.") + return df + + try: + # --- Group 1: Momentum x Volatility --- # + # 1a: return_1m * rolling_std(return_1m, window=10) + rolling_std_ret1m = df["return_1m"].rolling(window=10, min_periods=5).std() + df_results["mom_vol_int_1m"] = df["return_1m"] * rolling_std_ret1m + # 1b: return_norm_1m * volatility_14d + df_results["norm_mom_vol_int_1m_14d"] = df["return_norm_1m"] * df["volatility_14d"] + + # --- Group 2: Price x Volume Imbalance --- # + # 2a: SVI_10 * return_1m + df_results["svi10_ret1m_int"] = df["svi_10"] * df["return_1m"] + # 2b: Chaikin AD * Volume Spike (using 1440 window spike) + df_results["ad10_volspike_int"] = df["chaikin_AD_10"] * df["vol_norm_volume_spike"] + + # --- Group 3: Cross-Horizon Momentum --- # + # 3a: return_1m * return_60m + df_results["ret1m_ret60m_int"] = df["return_1m"] * df["return_60m"] + + # --- Group 4: Seasonality x Edge --- # + # 4a: hour_sin * return_1m + df_results["hour_sin_ret1m_int"] = df["hour_sin"] * df["return_1m"] + # 4b: cos_dow * return_60m + df_results["dow_cos_ret60m_int"] = df["cos_dow"] * df["return_60m"] + + # --- Group 5: Trend x Candle Shape --- # + # 5a: Trend Strength (EMA diff) + df_results["ema_trend_strength"] = df["EMA_50"] - df["EMA_10"] + # 5b: Body Ratio * Trend Strength + df_results["body_ratio_trend_int"] = df["body_ratio"] * df_results["ema_trend_strength"] + + # --- Group 6: Auto-Polynomial Features --- # + poly_features_in = ["return_1m", "volatility_14d", "svi_10"] + missing_poly_in = [f for f in poly_features_in if f not in df.columns] + if not missing_poly_in: + logger.info(f"Generating polynomial features (degree=2, interaction_only) for: {poly_features_in}") + poly = PolynomialFeatures(degree=2, interaction_only=True, include_bias=False) + # Ensure input data is clean + X_poly_in = df[poly_features_in].fillna(0) + X_poly_out = poly.fit_transform(X_poly_in) + # Get feature names (will be like 'x0 x1', 'x0 x2', 'x1 x2') + poly_feature_names_raw = poly.get_feature_names_out(poly_features_in) + # Create more readable names + poly_feature_names_out = [ + f"{name.replace(' ', '_x_')}_poly" for name in poly_feature_names_raw + ] + # Add to results DataFrame + df_poly = pd.DataFrame(X_poly_out, index=df.index, columns=poly_feature_names_out) + # Add only the interaction terms (avoid duplicating originals) + for col in df_poly.columns: + if col not in df_results.columns and col not in df.columns: + df_results[col] = df_poly[col] + logger.info(f"Added polynomial features: {df_poly.columns.tolist()}") + else: + logger.warning(f"Skipping polynomial features: Missing input columns {missing_poly_in}") + + # --- Fill NaNs/Infs for all interaction features --- # + interaction_cols = df_results.columns.tolist() + logger.info(f"Generated interaction features: {interaction_cols}") + for col in interaction_cols: + if col in df_results.columns: + # Fix inplace replace and fillna + df_results[col] = df_results[col].replace([np.inf, -np.inf], np.nan) # Use assignment + df_results[col] = df_results[col].ffill() + df_results[col] = df_results[col].bfill() + df_results[col] = df_results[col].fillna(0) + + # --- Add results to original DataFrame --- # + for col in interaction_cols: + if col in df_results.columns: + df[col] = df_results[col] + + except Exception as e: + logger.error(f"Error adding interaction features: {e}", exc_info=True) + + return df \ No newline at end of file diff --git a/gru_sac_predictor/src/features.py b/gru_sac_predictor/src/features.py new file mode 100644 index 00000000..86f020d8 --- /dev/null +++ b/gru_sac_predictor/src/features.py @@ -0,0 +1,175 @@ +from __future__ import annotations + +import pandas as pd +import numpy as np +# Restore imports from 'ta' library +from ta.volatility import AverageTrueRange +from ta.momentum import RSIIndicator +from ta.trend import EMAIndicator, MACD +# import talib # Remove talib import + +__all__ = [ + "add_imbalance_features", + "add_ta_features", + "prune_features", + "minimal_whitelist", +] + +_EPS = 1e-6 + +# --- New Feature Function (Task 2.1) --- +def vola_norm_return(df: pd.DataFrame, k: int) -> pd.Series: + """ + Calculates volatility-normalized returns over k periods. + return_k / rolling_std(return_k, window=k) + """ + if 'close' not in df.columns: + raise ValueError("'close' column required for vola_norm_return") + if k <= 1: + raise ValueError("Window k must be > 1 for rolling std dev") + + # Calculate k-period percentage change returns + returns_k = df['close'].pct_change(k) + + # Calculate rolling standard deviation of these k-period returns + sigma_k = returns_k.rolling(window=k, min_periods=max(2, k // 2 + 1)).std() + + # Normalize returns by volatility, replacing 0 std dev with NaN + vola_normed = returns_k / sigma_k.replace(0, np.nan) + + return vola_normed +# --- End New Feature Function --- + + +def add_imbalance_features(df: pd.DataFrame) -> pd.DataFrame: + """Add Chaikin AD line, signed volume imbalance, gap imbalance.""" + if not {"open", "high", "low", "close", "volume"}.issubset(df.columns): + return df + + clv = ((df["close"] - df["low"]) - (df["high"] - df["close"])) / ( + df["high"] - df["low"] + _EPS + ) + df["chaikin_AD_10"] = (clv * df["volume"]).rolling(10).sum() + + signed_vol = np.where(df["close"] >= df["open"], df["volume"], -df["volume"]) + df["svi_10"] = pd.Series(signed_vol, index=df.index).rolling(10).sum() + + med_vol = df["volume"].rolling(50).median() + gap_up = (df["low"] > df["high"].shift(1)) & (df["volume"] > 2 * med_vol) + gap_dn = (df["high"] < df["low"].shift(1)) & (df["volume"] > 2 * med_vol) + df["gap_imbalance"] = gap_up.astype(int) - gap_dn.astype(int) + + df.fillna(0, inplace=True) + return df + + +# ------------------------------------------------------------------ +# Technical analysis features +# ------------------------------------------------------------------ + + +def add_ta_features(df: pd.DataFrame) -> pd.DataFrame: + """Adds TA features to the dataframe using the ta library.""" + # Remove talib checks + # required_cols = {'open': 'open', 'high': 'high', 'low': 'low', 'close': 'close', 'volume': 'volume'} + # if not set(required_cols.keys()).issubset(df.columns): + # print(f"WARN: Missing required columns for TA-Lib in features.py. Need {required_cols.keys()}") + # return df + # Ensure correct dtype for talib (often float64) + # for col in required_cols.keys(): + # if df[col].dtype != np.float64: + # try: + # df[col] = df[col].astype(np.float64) + # except Exception as e: + # print(f"WARN: Could not convert column {col} to float64 for TA-Lib: {e}") + # return df # Cannot proceed if conversion fails + + df_copy = df.copy() + + # Calculate returns first (use bfill + ffill for pct_change compatibility) + # Fill NaNs robustly before pct_change + df_copy["close_filled"] = df_copy["close"].bfill().ffill() + df_copy["return_1m"] = df_copy["close_filled"].pct_change() + df_copy["return_15m"] = df_copy["close_filled"].pct_change(15) + df_copy["return_60m"] = df_copy["close_filled"].pct_change(60) + df_copy.drop(columns=["close_filled"], inplace=True) + + # Calculate TA features using ta library + # df_copy["ATR_14"] = talib.ATR(df_copy['high'], df_copy['low'], df_copy['close'], timeperiod=14) + df_copy["ATR_14"] = AverageTrueRange(df_copy['high'], df_copy['low'], df_copy['close'], window=14).average_true_range() + + # Daily volatility 14d of returns + df_copy["volatility_14d"] = ( + df_copy["return_1m"].rolling(60 * 24 * 14, min_periods=30).std() # rough 14d for 1‑min bars + ) + + # EMA 10 / 50 + MACD using ta library + # df_copy["EMA_10"] = talib.EMA(df_copy["close"], timeperiod=10) + # df_copy["EMA_50"] = talib.EMA(df_copy["close"], timeperiod=50) + df_copy["EMA_10"] = EMAIndicator(df_copy["close"], 10).ema_indicator() + df_copy["EMA_50"] = EMAIndicator(df_copy["close"], 50).ema_indicator() + # talib.MACD returns macd, macdsignal, macdhist + # macd, macdsignal, macdhist = talib.MACD(df_copy["close"], fastperiod=12, slowperiod=26, signalperiod=9) + macd = MACD(df_copy["close"], window_slow=26, window_fast=12, window_sign=9) + df_copy["MACD"] = macd.macd() + df_copy["MACD_signal"] = macd.macd_signal() + + # RSI 14 using ta library + # df_copy["RSI_14"] = talib.RSI(df_copy["close"], timeperiod=14) + df_copy["RSI_14"] = RSIIndicator(df_copy["close"], window=14).rsi() + + # Cyclical hour already recommended to add upstream (data_pipeline). + + # Handle potential NaNs introduced by TA calculations + # df.fillna(method="bfill", inplace=True) # Deprecated + df_copy.bfill(inplace=True) + df_copy.ffill(inplace=True) # Add ffill for any remaining NaNs at the beginning + + return df_copy + + +# ------------------------------------------------------------------ +# Pruning & whitelist +# ------------------------------------------------------------------ + +minimal_whitelist = [ + # Returns + "return_1m", + "return_15m", + "return_60m", + # Volatility + "ATR_14", + "volatility_14d", + # Vola-Normalized Returns (New) + "vola_norm_return_15", + "vola_norm_return_60", + # Imbalance + "chaikin_AD_10", + "svi_10", + # Trend + "EMA_10", + "EMA_50", + # "MACD", # Removed Task 2.3 + # "MACD_signal", # Removed Task 2.3 + # Cyclical (Time) + "hour_sin", + "hour_cos", + "week_sin", # Added Task 2.2 + "week_cos", # Added Task 2.2 +] + + +def prune_features(df: pd.DataFrame, whitelist: list[str] | None = None) -> pd.DataFrame: + """Return DataFrame containing only *whitelisted* columns.""" + if whitelist is None: + whitelist = minimal_whitelist + # Find columns present in both DataFrame and whitelist + cols_to_keep = [c for c in whitelist if c in df.columns] + # Ensure the set of kept columns exactly matches the intersection + df_pruned = df[cols_to_keep].copy() + assert set(df_pruned.columns) == set(cols_to_keep), \ + f"Pruning failed: Output columns {set(df_pruned.columns)} != Expected intersection {set(cols_to_keep)}" + # Optional: Assert against the full whitelist if input is expected to always contain all + # assert set(df_pruned.columns) == set(whitelist), \ + # f"Pruning failed: Output columns {set(df_pruned.columns)} != Full whitelist {set(whitelist)}" + return df_pruned \ No newline at end of file diff --git a/gru_sac_predictor/src/gru_hyper_tuner.py b/gru_sac_predictor/src/gru_hyper_tuner.py new file mode 100644 index 00000000..26ecd211 --- /dev/null +++ b/gru_sac_predictor/src/gru_hyper_tuner.py @@ -0,0 +1,449 @@ +""" +GRU Hyperparameter Tuner. + +Implements hyperparameter optimization for GRU models using Optuna. +""" + +import os +import logging +import json +from typing import Dict, Any, List, Tuple, Optional +import numpy as np +import pandas as pd +import optuna +from optuna.pruners import MedianPruner +from optuna.samplers import TPESampler +import tensorflow as tf +from sklearn.metrics import accuracy_score +import matplotlib.pyplot as plt +from tensorflow.keras.callbacks import Callback + +from gru_sac_predictor.src.gru_model_handler import GRUModelHandler +from gru_sac_predictor.src.metrics import edge_filtered_accuracy, calculate_brier_score + +logger = logging.getLogger(__name__) + +# --- Revision 3: Optuna Pruning Callback --- # +class OptunaPruningCallback(Callback): + """Keras Callback for Optuna Pruning.""" + def __init__(self, trial: optuna.Trial, monitor: str = 'val_loss'): + super().__init__() + self._trial = trial + self._monitor = monitor + + def on_epoch_end(self, epoch: int, logs: Optional[Dict[str, float]] = None): + logs = logs or {} + current_value = logs.get(self._monitor) + if current_value is None: + # Metric not found, cannot prune based on it + # logger.warning(f"Optuna Callback: Metric '{self._monitor}' not found in logs. Skipping pruning check.") + return + + # Report current value to Optuna trial + self._trial.report(current_value, step=epoch) + + # Check if the trial should be pruned + if self._trial.should_prune(): + message = f"Trial pruned at epoch {epoch} based on {self._monitor}={current_value}." + logger.info(message) + # Stop training by raising TrialPruned exception + raise optuna.TrialPruned(message) +# --- End Revision 3 --- # + +class GRUHyperTuner: + """ + Optimizes GRU model hyperparameters using Optuna. + + This class runs an Optuna hyperparameter sweep to find the best GRU + hyperparameters for a given dataset using edge-filtered accuracy as + the objective function. + """ + + def __init__(self, config: Dict[str, Any], fold_dir: str): + """ + Initialize the GRU Hyperparameter Tuner. + + Args: + config (Dict[str, Any]): Configuration dictionary. + fold_dir (str): Directory to store fold-specific results. + """ + self.config = config + self.fold_dir = fold_dir + self.is_ternary = config.get('gru', {}).get('use_ternary', False) + + # Get hyperparameter sweep config + sweep_config = config.get('hyperparameter_tuning', {}).get('gru', {}) + self.n_trials = sweep_config.get('sweep_n_trials', 20) + self.timeout = sweep_config.get('sweep_timeout', 7200) # 2 hours default + self.edge_threshold = config.get('calibration', {}).get('edge_threshold', 0.1) + + # Pruning settings + self.pruning_enabled = sweep_config.get('enable_pruning', True) + self.pruning_warmup = sweep_config.get('pruning_warmup_trials', 5) + + # Create fold's best params directory if it doesn't exist + os.makedirs(fold_dir, exist_ok=True) + + # Set random seeds for reproducibility + self.seed = config.get('random_seed', 42) + np.random.seed(self.seed) + tf.random.set_seed(self.seed) + + # Store best params and score + self.best_params = None + self.best_score = 0.0 + self.best_trial = None + + def _create_model(self, trial: optuna.Trial, X_train_shape: Tuple) -> Tuple[GRUModelHandler, int, int]: + """ + Create a GRU model with hyperparameters suggested by Optuna trial. + + Args: + trial (optuna.Trial): Optuna trial object. + X_train_shape (Tuple): Shape of X_train to determine input dimensions. + + Returns: + Tuple[GRUModelHandler, int, int]: Configured GRU model handler, lookback, and n_features. + """ + # Create a unique run ID for this trial + trial_run_id = f"tune_trial_{trial.number}" + + # Create a copy of the config to modify + trial_config = dict(self.config) + + # Sample hyperparameters + gru_config = trial_config.get('gru', {}) + + # Units in GRU layers + gru_config['units1'] = trial.suggest_categorical('units1', [48, 64, 96]) + gru_config['units2'] = trial.suggest_categorical('units2', [0, 48, 64]) # 0 means no second layer + + # Dropout rates + gru_config['dropout1'] = trial.suggest_float('dropout1', 0.0, 0.5) + gru_config['dropout2'] = trial.suggest_float('dropout2', 0.0, 0.5) + + # Learning rate and batch size + gru_config['learning_rate'] = trial.suggest_float('learning_rate', 1e-4, 3e-3, log=True) + gru_config['batch_size'] = trial.suggest_categorical('batch_size', [32, 64, 128, 256]) + + # L2 regularization + gru_config['l2_reg'] = trial.suggest_float('l2_reg', 1e-6, 1e-3, log=True) + + # Add attention heads if using v3 model + if trial_config.get('control', {}).get('use_v3', False): + gru_config['attn_heads'] = trial.suggest_categorical('attn_heads', [0, 4]) # 0 implies None + + # Update config + trial_config['gru'] = gru_config + + # Create model handler with trial config + model_handler = GRUModelHandler( + run_id=trial_run_id, + models_dir=self.fold_dir, + config=trial_config + ) + + # Set input shape and configuration + lookback = gru_config.get('lookback', 60) + n_features = X_train_shape[2] + + return model_handler, lookback, n_features + + def objective(self, trial: optuna.Trial, X_train: np.ndarray, y_train_dict: Dict[str, np.ndarray], + X_val: np.ndarray, y_val_dict: Dict[str, np.ndarray]) -> float: + """ + Objective function for Optuna optimization. + + Args: + trial (optuna.Trial): Optuna trial object. + X_train (np.ndarray): Training features. + y_train_dict (Dict[str, np.ndarray]): Training targets. + X_val (np.ndarray): Validation features. + y_val_dict (Dict[str, np.ndarray]): Validation targets. + + Returns: + float: Edge-filtered accuracy score to maximize. + """ + # Create model with trial hyperparameters + model_handler, lookback, n_features = self._create_model(trial, X_train.shape) + + # Early stopping config + patience = self.config.get('gru', {}).get('patience', 5) + max_epochs = self.config.get('gru', {}).get('epochs', 25) + batch_size = model_handler.config.get('gru', {}).get('batch_size', 128) + + # --- Revision 3: Add Optuna Pruning Callback --- # + callbacks = [] + if self.pruning_enabled: + # Use val_loss as the monitor for pruning + pruning_callback = OptunaPruningCallback(trial, monitor='val_loss') + callbacks.append(pruning_callback) + # --- End Revision 3 --- # + + # Determine which target key to use based on ternary flag + dir_key = "dir3" if self.is_ternary else "dir" + + # Train the model + try: + gru_model, history = model_handler.train( + X_train=X_train, + y_train_dict=y_train_dict, + X_val=X_val, + y_val_dict=y_val_dict, + lookback=lookback, + n_features=n_features, + max_epochs=max_epochs, + batch_size=batch_size, + patience=patience, + callbacks=callbacks # Pass the Optuna callback + ) + + if gru_model is None: + logger.warning(f"Trial {trial.number}: Model training failed.") + return -np.inf # Return worse possible score for maximization + + # Calculate edge-filtered accuracy and Brier score + edge_accuracy = np.nan + brier_score = np.nan + n_filtered = 0 + objective_value = -np.inf # Default for maximization objective + + if self.is_ternary: + # Get logits and calibrate if needed + logits_val = model_handler.predict_logits(X_val) + if logits_val is None: + logger.warning(f"Trial {trial.number}: Failed to get logits for ternary evaluation.") + return -np.inf + + # Use P(up) equivalent for binary check compatibility + p_up_equiv = logits_val[:, 2] + y_true_binary_equiv = (np.argmax(y_val_dict[dir_key], axis=1) == 2).astype(int) + + edge_accuracy, n_filtered = edge_filtered_accuracy( + y_true=y_true_binary_equiv, + p_cal=p_up_equiv, + thr=self.edge_threshold + ) + + # Brier score not calculated for ternary in this setup + logger.warning("Brier score calculation skipped for ternary objective.") + brier_score = 0.0 # Assign neutral value for objective calculation + objective_value = edge_accuracy # Objective for ternary is just edge accuracy for now + + else: + # Get binary predictions + predictions_val = model_handler.predict(X_val) + if predictions_val is None or len(predictions_val) < 3: + logger.warning(f"Trial {trial.number}: Failed to get predictions for binary evaluation.") + return -np.inf + + p_cal = predictions_val[2].flatten() + y_true = y_val_dict[dir_key] + + edge_accuracy, n_filtered = edge_filtered_accuracy( + y_true=y_true, + p_cal=p_cal, + thr=self.edge_threshold + ) + + # Calculate Brier score + try: + brier_score = calculate_brier_score(y_true, p_cal) + except Exception as e: + logger.warning(f"Trial {trial.number}: Failed to calculate Brier score: {e}. Setting to NaN.") + brier_score = np.nan + + # --- Calculate Combined Objective --- # + objective_metric = self.config.get('hyperparameter_tuning', {}).get('gru', {}).get('objective_metric', 'edge_accuracy') + if pd.isna(edge_accuracy) or pd.isna(brier_score): + objective_value = -np.inf # Invalid if metrics are NaN + elif objective_metric == 'edge_accuracy': + objective_value = edge_accuracy + elif objective_metric == 'brier_score': + objective_value = -brier_score # Minimize Brier, so maximize negative Brier + elif objective_metric == 'edge_acc_minus_brier': + w_edge = self.config.get('hyperparameter_tuning', {}).get('gru', {}).get('objective_edge_acc_weight', 0.7) + w_brier = self.config.get('hyperparameter_tuning', {}).get('gru', {}).get('objective_brier_weight', 0.3) + objective_value = w_edge * edge_accuracy - w_brier * brier_score + else: + logger.warning(f"Unknown objective metric '{objective_metric}'. Defaulting to edge accuracy.") + objective_value = edge_accuracy + # --- End Combined Objective --- # + + # --- Log components to Optuna Trial --- # + trial.set_user_attr("edge_accuracy", float(edge_accuracy) if pd.notna(edge_accuracy) else None) + trial.set_user_attr("brier_score", float(brier_score) if pd.notna(brier_score) else None) + trial.set_user_attr("n_filtered", int(n_filtered)) + # --- End Log --- # + + logger.info(f"Trial {trial.number} finished. Objective ({objective_metric}): {objective_value:.4f} (EdgeAcc: {edge_accuracy:.4f}, Brier: {brier_score:.4f}, N_filt: {n_filtered})") + return objective_value + + except optuna.TrialPruned as e: + # Handle pruning gracefully + logger.info(f"Trial {trial.number} pruned: {e}") + raise e # Re-raise to let Optuna handle it + except Exception as e: + logger.error(f"Error during Optuna objective evaluation (Trial {trial.number}): {e}", exc_info=True) + # Return a very bad value for maximization problems + return -np.inf + + def optimize(self, X_train: np.ndarray, y_train_dict: Dict[str, np.ndarray], + X_val: np.ndarray, y_val_dict: Dict[str, np.ndarray]) -> Dict[str, Any]: + """ + Run Optuna optimization to find the best hyperparameters. + + Args: + X_train (np.ndarray): Training features. + y_train_dict (Dict[str, np.ndarray]): Training targets. + X_val (np.ndarray): Validation features. + y_val_dict (Dict[str, np.ndarray]): Validation targets. + + Returns: + Dict[str, Any]: Best hyperparameters. + """ + # Create sampler and pruner + sampler = TPESampler(seed=self.seed) + pruner = MedianPruner(n_startup_trials=self.pruning_warmup, n_warmup_steps=0) if self.pruning_enabled else None + + # Create study + study = optuna.create_study( + direction="maximize", + sampler=sampler, + pruner=pruner, + study_name="gru_hyperparameter_tuning" + ) + + # Define objective function with fixed data + objective_with_data = lambda trial: self.objective( + trial, X_train, y_train_dict, X_val, y_val_dict + ) + + # Run optimization + logger.info(f"Starting GRU hyperparameter optimization with {self.n_trials} trials") + study.optimize(objective_with_data, n_trials=self.n_trials, timeout=self.timeout) + + # Get best trial and params + self.best_trial = study.best_trial + self.best_params = study.best_params + self.best_score = study.best_value + + # Log results + logger.info(f"Optimization finished. Best score: {self.best_score:.4f}") + logger.info(f"Best parameters: {self.best_params}") + + # Save best parameters to JSON file + best_params_path = os.path.join(self.fold_dir, 'best_gru_params.json') + with open(best_params_path, 'w') as f: + json.dump(self.best_params, f, indent=4) + logger.info(f"Best parameters saved to {best_params_path}") + + # Plot optimization history + self._plot_optimization_history(study) + + return self.best_params + + def _plot_optimization_history(self, study: optuna.Study) -> None: + """ + Plot the optimization history. + + Args: + study (optuna.Study): Completed Optuna study. + """ + try: + # Create figure + fig = plt.figure(figsize=(10, 6)) + + # Plot optimization history + optuna.visualization.matplotlib.plot_optimization_history(study) + + # Save figure + history_path = os.path.join(self.fold_dir, 'optuna_history.png') + plt.savefig(history_path, dpi=100) + plt.close(fig) + logger.info(f"Optimization history plot saved to {history_path}") + + # Plot parameter importances if we have enough trials + if len(study.trials) >= 5: + fig = plt.figure(figsize=(10, 6)) + optuna.visualization.matplotlib.plot_param_importances(study) + importances_path = os.path.join(self.fold_dir, 'param_importances.png') + plt.savefig(importances_path, dpi=100) + plt.close(fig) + logger.info(f"Parameter importances plot saved to {importances_path}") + + except Exception as e: + logger.error(f"Failed to plot optimization results: {str(e)}") + + def train_with_best_params(self, X_train: np.ndarray, y_train_dict: Dict[str, np.ndarray], + X_val: np.ndarray, y_val_dict: Dict[str, np.ndarray]) -> Tuple[GRUModelHandler, Dict]: + """ + Train final model with the best hyperparameters. + + Args: + X_train (np.ndarray): Training features. + y_train_dict (Dict[str, np.ndarray]): Training targets. + X_val (np.ndarray): Validation features. + y_val_dict (Dict[str, np.ndarray]): Validation targets. + + Returns: + Tuple[GRUModelHandler, Dict]: Trained model handler and history. + """ + if self.best_params is None: + logger.error("No best parameters found. Run optimize() first.") + return None, None + + # Create a unique run ID for the final model + final_run_id = f"tune_best_model" + + # Create a copy of the config to modify + final_config = dict(self.config) + gru_config = final_config.get('gru', {}).copy() + + # Apply best parameters + for param, value in self.best_params.items(): + gru_config[param] = value + + # Skip second GRU layer if units2 is 0 + if 'units2' in self.best_params and self.best_params['units2'] == 0: + gru_config['use_second_layer'] = False + else: + gru_config['use_second_layer'] = True + + # Update config + final_config['gru'] = gru_config + + # Create model handler with final config + model_handler = GRUModelHandler( + run_id=final_run_id, + models_dir=self.fold_dir, + config=final_config + ) + + # Set training parameters + lookback = gru_config.get('lookback', 60) + n_features = X_train.shape[2] + max_epochs = gru_config.get('epochs', 25) + batch_size = gru_config.get('batch_size', 128) + patience = gru_config.get('patience', 5) + + # Train final model + logger.info("Training final model with best hyperparameters") + model, history = model_handler.train( + X_train=X_train, + y_train_dict=y_train_dict, + X_val=X_val, + y_val_dict=y_val_dict, + lookback=lookback, + n_features=n_features, + max_epochs=max_epochs, + batch_size=batch_size, + patience=patience + ) + + # Save the model + if model is not None: + saved_path = model_handler.save() + logger.info(f"Final model saved to {saved_path}") + + return model_handler, history \ No newline at end of file diff --git a/gru_sac_predictor/src/gru_model_handler.py b/gru_sac_predictor/src/gru_model_handler.py new file mode 100644 index 00000000..0a204b75 --- /dev/null +++ b/gru_sac_predictor/src/gru_model_handler.py @@ -0,0 +1,736 @@ +""" +Handles GRU Model Training, Loading, Saving, and Prediction. + +This is the comprehensive GRU model implementation that includes both v2 and v3 +model architectures as well as the handler for training, saving, loading, and prediction. +""" + +import sys +import os +import json +import joblib +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +import logging +from tqdm.keras import TqdmCallback +from typing import Dict, Tuple, Any, List, Optional, Callable + +# --- Add Imports for TF/Keras and Optional TFA --- # +try: + import tensorflow as tf + from tensorflow.keras import Model, layers, callbacks, saving + from tensorflow.keras.losses import Huber, CategoricalCrossentropy + from tensorflow.keras.optimizers import Adam + from tensorflow.keras import regularizers + from tensorflow.keras.metrics import MeanAbsoluteError, RootMeanSquaredError # type: ignore + KERAS_AVAILABLE = True +except ImportError: + logger.error("TensorFlow/Keras not found. GRU model functionality will be unavailable.") + # Define placeholders for types if Keras is not available + Model = Any + layers = Any + callbacks = Any + saving = Any + Huber = Any + Adam = Any + regularizers = Any + MeanAbsoluteError = Any + RootMeanSquaredError = Any + KERAS_AVAILABLE = False + +# --- Removed TFA Import Logic --- # +# try: +# import tensorflow_addons as tfa +# # Check for specific loss needed (optional, but good practice) +# if hasattr(tfa, 'losses') and (hasattr(tfa.losses, 'BinaryFocalCrossentropy') or hasattr(tfa.losses, 'CategoricalFocalCrossentropy')): +# TFA_AVAILABLE = True +# logger.info("TensorFlow Addons found and contains required Focal Loss classes.") +# else: +# TFA_AVAILABLE = False +# logger.warning("TensorFlow Addons found, but required Focal Loss classes (Binary/Categorical) seem missing.") +# except ImportError: +# logger.warning("TensorFlow Addons (tfa) not found. Focal Loss will not be available.") +# tfa = None # Define tfa as None if not imported +# TFA_AVAILABLE = False +# --- End Removed TFA Import Logic --- # + +# =================================================================== +# UTILITIES AND LOSS FUNCTIONS +# =================================================================== + +logger = logging.getLogger(__name__) # Define module-level logger + +# --- Define create_mask function globally --- # +def create_mask(inputs): + """Creates a lower-triangular causal mask for Attention layers.""" + input_shape = tf.shape(inputs) + # Need batch_size only if modifying mask per batch, otherwise seq_len is enough + # batch_size = input_shape[0] + seq_len = input_shape[1] + # Lower triangular mask (ones below and on diagonal, zeros above) + mask = tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0) + # MHA expects mask shape like (batch_size, seq_len, seq_len) or (seq_len, seq_len) + # Returning (seq_len, seq_len) is usually sufficient as Keras broadcasts. + return mask +# --- End global function definition --- + +# --- Custom Focal Loss Implementations --- # + +# --- Removed Binary Focal Loss --- # +# def binary_focal_loss(y_true, y_pred, gamma=2.0, alpha=0.25): +# ... +# --- End Removed Binary Focal Loss --- # + +def categorical_focal_loss(y_true, y_pred, gamma=2.0, alpha=0.25, label_smoothing=0.0): + """Categorical focal loss implementation. + Handles one-hot encoded targets and optional label smoothing. + Source: Adapted from https://www.tensorflow.org/api_docs/python/tf/keras/losses/CategoricalFocalCrossentropy + """ + if not KERAS_AVAILABLE: + raise ImportError("Keras/TensorFlow backend needed for focal loss calculation.") + + y_pred = tf.convert_to_tensor(y_pred) + y_true = tf.cast(y_true, y_pred.dtype) + + # Apply label smoothing + if label_smoothing > 0.0: + num_classes = tf.cast(tf.shape(y_true)[-1], y_pred.dtype) + y_true = y_true * (1.0 - label_smoothing) + (label_smoothing / num_classes) + + # Calculate cross-entropy loss + epsilon = tf.keras.backend.epsilon() + y_pred = tf.clip_by_value(y_pred, epsilon, 1.0 - epsilon) + ce = -y_true * tf.math.log(y_pred) + + # Calculate focal loss components + p_t = tf.reduce_sum(y_true * y_pred, axis=-1) # Probability of the true class + focal_factor = tf.pow(1.0 - p_t, gamma) + + # Weighted loss + # Note: TF/Keras implementation uses alpha on per-class basis if provided as array. + # Here, we use a single alpha for simplicity, but could be adapted. + # If alpha is used, it typically applies to the positive class contribution. + # For categorical, it's less common or needs careful definition. Let's omit alpha here. + # weighted_loss = alpha * focal_factor * ce + focal_loss_per_example = focal_factor * tf.reduce_sum(ce, axis=-1) + + return tf.reduce_mean(focal_loss_per_example) +# --- End Custom Focal Loss --- # + +@saving.register_keras_serializable(package='GRU') +def gaussian_nll(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor: + """Gaussian negative‑log likelihood for *scalar* targets. + + The model is assumed to predict concatenated [mu, log_sigma]. + + Given targets :math:`y` and predictions :math:`\mu, \log\sigma`, + the NLL is + + .. math:: + + \mathcal{L}_{NLL} = \frac{1}{2} \exp(-2\log\sigma)(y-\mu)^2 + + \log\sigma. + + A small constant is added for numerical stability. + """ + mu, log_sigma = tf.split(y_pred, 2, axis=-1) + # Ensure y_true has the same shape as mu for the subtraction + y_true_shaped = tf.reshape(y_true, tf.shape(mu)) + inv_var = tf.exp(-2.0 * log_sigma) # = 1/sigma^2 + nll = 0.5 * inv_var * tf.square(y_true_shaped - mu) + log_sigma + return tf.reduce_mean(nll) + +# =================================================================== +# MODEL BUILDERS +# =================================================================== + +# --- Removed build_gru_model (v2) --- # +# def build_gru_model(lookback: int, n_features: int, kappa: float = 0.2) -> Model: +# ... +# --- End Removed build_gru_model (v2) --- # + +def build_gru_model_v3( + lookback: int, + n_features: int, + gru_units: int = 96, + attention_units: int = 16, + dropout_rate: float = 0.1, + learning_rate: float = 1e-4, + focal_gamma: float = 2.0, + focal_label_smoothing: float = 0.1, + huber_delta: float = 1.0, + loss_weight_mu: float = 0.3, + loss_weight_dir3: float = 1.0, + l2_reg: float = 1e-4 +) -> Model: + """ + Builds and compiles the GRU v3 model based on the specified architecture and hyperparameters. + + Architecture: Input -> GRU -> LayerNorm -> Attention -> LayerNorm -> Output Heads + + Args: + lookback (int): The sequence length for the GRU input. + n_features (int): The number of features at each timestep. + gru_units (int): Number of units for the GRU layer. + attention_units (int): Number of units for the Attention layer. + dropout_rate (float): Dropout rate for the GRU and Attention layers. + learning_rate (float): Learning rate for the Adam optimizer. + focal_gamma (float): Gamma parameter for CategoricalFocalCrossentropy. + focal_label_smoothing (float): Label smoothing for CategoricalFocalCrossentropy. + huber_delta (float): Delta parameter for Huber loss. + loss_weight_mu (float): Weight for the 'mu' output loss. + loss_weight_dir3 (float): Weight for the 'dir3' output loss. + l2_reg (float): L2 regularization factor for dense layers. + + Returns: + keras.Model: The compiled Keras model. + """ + + # --- Ensure l2_reg is a float --- # + try: + l2_reg_float = float(l2_reg) + except (ValueError, TypeError) as e: + logger.error(f"Invalid value for l2_reg: '{l2_reg}'. Expected a number. Error: {e}. Using 0.0.") + l2_reg_float = 0.0 + # --- End Ensure --- # + + # --- Ensure learning_rate is a float --- # + try: + learning_rate_float = float(learning_rate) + except (ValueError, TypeError) as e: + logger.error(f"Invalid value for learning_rate: '{learning_rate}'. Expected a number. Error: {e}. Using 1e-4.") + learning_rate_float = 1e-4 # Default value + # --- End Ensure --- # + + input_shape = (lookback, n_features) + inputs = layers.Input(shape=input_shape) + + # GRU Layer + Layer Norm (Revision 2-C) + x = layers.GRU( + gru_units, + return_sequences=True, + name='gru_base', + reset_after=True # Force standard kernel implementation + )(inputs) + x = layers.LayerNormalization(name='gru_layernorm')(x) # Added LayerNorm + x = layers.Dropout(dropout_rate, name='dropout_gru')(x) + + # Attention Layer (if attention_units > 0) + if attention_units > 0: + # --- Create Causal Mask using Lambda Layer (Keras Compatible) --- # + causal_mask = layers.Lambda(create_mask, name='causal_mask_lambda')(x) + # --- End Causal Mask Creation --- # + + attn_output = layers.MultiHeadAttention( + num_heads=max(1, attention_units // 16), # Example heuristic for num_heads + key_dim=attention_units, + kernel_regularizer=regularizers.l2(l2_reg_float), # Use float value + dropout=dropout_rate, + # --- REMOVED use_causal_mask --- # + # use_causal_mask=True, + # --- End REMOVED --- # + name='multi_head_attention' + # --- RE-ADD attention_mask argument --- # + )(query=x, value=x, key=x, attention_mask=causal_mask) # Pass the mask tensor from Lambda + # --- End RE-ADD --- # + x = layers.Dropout(dropout_rate, name='dropout_attn')(attn_output) + # If attention_units is 0, skip attention layer + + # Pooling Layer (Consider GlobalMaxPooling1D as alternative?) + pooled_output = layers.GlobalAveragePooling1D(name='global_avg_pool')(x) + + # Dense Heads with L2 Regularization (Revision 2-C) + kernel_regularizer = regularizers.l2(l2_reg_float) # Use float value + + # For the classification head, we'll separate logits and activation for potential logit extraction + logits_dir3 = layers.Dense( + 3, + name='dir3_logits', + kernel_regularizer=kernel_regularizer + )(pooled_output) + dir3_output = layers.Activation('softmax', name='dir3')(logits_dir3) + + mu_output = layers.Dense( + 1, + activation='linear', + name='mu', + kernel_regularizer=kernel_regularizer + )(pooled_output) + + model = tf.keras.Model(inputs=inputs, outputs=[mu_output, dir3_output]) + + # Compile using passed hyperparameters + # Use actual FocalLoss if TFA is available, otherwise the standard loss + # dir3_loss_instance = FocalLoss(gamma=focal_gamma, label_smoothing=focal_label_smoothing) if TFA_AVAILABLE else FocalLoss(label_smoothing=focal_label_smoothing) + # --- Use local categorical_focal_loss --- # + dir3_loss_instance = lambda y_true, y_pred: categorical_focal_loss( + y_true, y_pred, gamma=focal_gamma, label_smoothing=focal_label_smoothing + ) + dir3_loss_instance.__name__ = 'categorical_focal_loss' # For Keras saving/loading + # --- End Use local --- # + losses = { + "dir3": dir3_loss_instance, + "mu": Huber(delta=huber_delta) + } + loss_weights = {"dir3": loss_weight_dir3, "mu": loss_weight_mu} + optimizer = Adam(learning_rate=learning_rate_float, clipnorm=1.0) + metrics = {"dir3": ['accuracy']} + + try: + model.compile( + optimizer=optimizer, + loss=losses, + loss_weights=loss_weights, + metrics=metrics + ) + logger.info("GRU v3 model built and compiled successfully.") + except Exception as e: + logger.error(f"Failed to compile GRU v3 model: {e}") + # Re-raise to prevent using uncompiled model + raise e + + return model + +class GRUModelHandler: + """Manages the lifecycle of the GRU model.""" + + def __init__(self, run_id: str, models_dir: str, config: dict): + """ + Initialize the handler. + + Args: + run_id (str): The current pipeline run ID. + models_dir (str): The base directory where models for this run are saved. + config (dict): The pipeline configuration dictionary. + """ + self.run_id = run_id + self.models_dir = models_dir # Should be the specific directory for this run + self.config = config # Store config + self.model: Model | None = None + self.model_version_used = None # Track which version was built/loaded + self.use_ternary = self.config.get('gru', {}).get('use_ternary', False) # Store ternary flag + logger.info(f"GRUModelHandler initialized for run {run_id} in {models_dir}") + + def train( + self, + X_train: np.ndarray, + y_train_dict: Dict[str, np.ndarray], + X_val: np.ndarray, + y_val_dict: Dict[str, np.ndarray], + lookback: int, + n_features: int, + max_epochs: int = 25, + batch_size: int = 128, + patience: int = 3 + ) -> Tuple[Model | None, Any]: # Returns model and history + """Trains the GRU model using provided data and configuration.""" + if not KERAS_AVAILABLE: + logger.error("Keras is not available. Cannot train GRU model.") + return None, None + + if self.model is None: + logger.info("No existing model found. Building a new one for training.") + self.model = build_gru_model_v3( + lookback=lookback, + n_features=n_features, + # Pass hyperparameters from config + gru_units=self.config.get('gru', {}).get('gru_units', 96), + attention_units=self.config.get('gru', {}).get('attention_units', 16), + dropout_rate=self.config.get('gru', {}).get('dropout_rate', 0.1), + learning_rate=self.config.get('gru', {}).get('learning_rate', 1e-4), + focal_gamma=self.config.get('gru', {}).get('focal_gamma', 2.0), + focal_label_smoothing=self.config.get('gru', {}).get('focal_label_smoothing', 0.1), + huber_delta=self.config.get('gru', {}).get('huber_delta', 1.0), + loss_weight_mu=self.config.get('gru', {}).get('loss_weight_mu', 0.3), + loss_weight_dir3=self.config.get('gru', {}).get('loss_weight_dir3', 1.0), + l2_reg=self.config.get('gru', {}).get('l2_reg', 1e-4) + ) + else: + logger.info("Using the existing model for training.") + + if self.model is None: + logger.error("Failed to build or retrieve model for training.") + return None, None + + # --- Callbacks --- # + # 1. Early Stopping + early_stopping = callbacks.EarlyStopping( + monitor='val_loss', + patience=patience, + restore_best_weights=True, + verbose=1 + ) + + # 2. Model Checkpoint (optional, EarlyStopping restores best weights) + # Consider saving best model directly after training finishes if needed + # checkpoint_path = os.path.join(self.models_dir, f'best_gru_model_{self.run_id}.keras') + # model_checkpoint = callbacks.ModelCheckpoint( + # filepath=checkpoint_path, + # monitor='val_loss', + # save_best_only=True, + # save_weights_only=False, + # verbose=1 + # ) + + # 3. TQDM Progress Bar + tqdm_callback = TqdmCallback(verbose=1) + + # 4. CSV Logger (V3 Output Contract) + # Define logs directory based on models_dir structure (assuming logs are sibling) + base_dir = os.path.dirname(self.models_dir) # e.g., /path/to/models + logs_base_dir = os.path.join(os.path.dirname(base_dir), 'logs') # /path/to/logs + run_logs_dir = os.path.join(logs_base_dir, self.run_id) # /path/to/logs/ + os.makedirs(run_logs_dir, exist_ok=True) + csv_log_path = os.path.join(run_logs_dir, 'gru_history.csv') + logger.info(f"Setting up CSVLogger to save history to: {csv_log_path}") + csv_logger = callbacks.CSVLogger(csv_log_path, append=False) + + # --- Prepare target data structure for model.fit --- # + # The keys in the dictionary passed to y must match the names of the output layers. + # Model output layer names are 'mu' and 'dir3' (as defined in build_gru_model_v3) + + try: + # Find the actual keys in the input dictionary + input_keys = list(y_train_dict.keys()) + ret_key = next((k for k in input_keys if 'fwd_log_ret' in k), None) + dir_key = next((k for k in input_keys if 'direction_label' in k), None) + + if ret_key is None or dir_key is None: + raise KeyError(f"Could not find required target keys ('fwd_log_ret...', 'direction_label...') in y_train_dict. Found: {input_keys}") + + # Map model output names ('mu', 'dir3') to data using actual input keys + y_train_fit = { + 'mu': y_train_dict[ret_key], + 'dir3': y_train_dict[dir_key] + } + # Assume y_val_dict has the same keys + y_val_fit = { + 'mu': y_val_dict[ret_key], + 'dir3': y_val_dict[dir_key] + } + logger.debug(f"Target structure for model.fit prepared. Model output names mapped: {list(y_train_fit.keys())}") + except KeyError as e: + logger.error(f"KeyError preparing target dictionary for model.fit: {e}") + logger.error(f"Available train keys: {list(y_train_dict.keys())}") + logger.error(f"Available val keys: {list(y_val_dict.keys())}") + return None, None # Indicate failure + # --- End target preparation --- # + + # --- Handle potential NaNs/Infs before training --- + if np.isnan(X_train).any() or np.isinf(X_train).any(): + logger.error("NaN or Inf found in X_train sequence data before training. Aborting.") + return None, None + if any(np.isnan(arr).any() or np.isinf(arr).any() for arr in y_train_fit.values()): + logger.error("NaN or Inf found in y_train_fit sequence data before training. Aborting.") + return None, None + if np.isnan(X_val).any() or np.isinf(X_val).any(): + logger.error("NaN or Inf found in X_val sequence data before training. Aborting.") + return None, None + if any(np.isnan(arr).any() or np.isinf(arr).any() for arr in y_val_fit.values()): + logger.error("NaN or Inf found in y_val_fit sequence data before training. Aborting.") + return None, None + logger.info("Data checks passed: No NaNs or Infs found in training/validation sequences.") + # --- End Check --- + + logger.info(f"Starting GRU model training for {max_epochs} epochs...") + history = self.model.fit( + X_train, + y_train_fit, # Pass the prepared dictionary + epochs=max_epochs, + batch_size=batch_size, + validation_data=(X_val, y_val_fit), # Pass the prepared dictionary + callbacks=[ + early_stopping, + # model_checkpoint, # Optional: Rely on restore_best_weights + tqdm_callback, + csv_logger # Add CSV Logger + ], + shuffle=False, # Important for time series + verbose=1 # Or 2 for more detail per epoch + ) + + return self.model, history + + def save(self, model_name: str = 'gru_model') -> str | None: + """ + Saves the current model to the run's model directory. + Appends model version to the filename. + """ + if self.model is None: + logger.error("No model available to save.") + return None + if self.model_version_used is None: + logger.warning("Model version was not set during training/loading. Saving with default name.") + version_suffix = "unknown" + else: + version_suffix = self.model_version_used # e.g., 'v2' or 'v3' + + # Use .keras format and include version in filename + save_filename = f"{model_name}_{self.run_id}.keras" + save_path = os.path.join(self.models_dir, save_filename) + try: + self.model.save(save_path) + logger.info(f"GRU model saved successfully to: {save_path}") + return save_path + except Exception as e: + logger.error(f"Failed to save GRU model to {save_path}: {e}", exc_info=True) + return None + + def load(self, model_path: str) -> Model | None: + """ + Loads a GRU model from the specified path. + Handles custom objects if needed (primarily for v2 gaussian_nll). + """ + if not os.path.exists(model_path): + logger.error(f"Model file not found at: {model_path}") + return None + + logger.info(f"Loading GRU model from: {model_path}") + try: + # Custom objects needed mainly for v2's gaussian_nll + # custom_objects = {'gaussian_nll': gaussian_nll} + # v3 doesn't need gaussian_nll, but might need focal loss if not standard + custom_objects = { + 'categorical_focal_loss': categorical_focal_loss, + 'FocalLoss': categorical_focal_loss, # Add alias if model was saved with class name + # --- Add create_mask to custom objects --- # + 'create_mask': create_mask + # --- End Add --- # + } + + # Attempt to load FocalLoss if needed (REMOVED - Using local implementation) + # if TFA_AVAILABLE: + # ... + + self.model = tf.keras.models.load_model(model_path, custom_objects=custom_objects) + logger.info("GRU model loaded successfully.") + # Try to infer model version from loaded model output names + try: + if 'dir3' in self.model.output_names: + self.model_version_used = 'v3' + # elif 'dir' in self.model.output_names: # Assuming v2 used 'dir' + # self.model_version_used = 'v2' # Removed v2 check + else: + self.model_version_used = 'unknown' + logger.info(f"Inferred loaded model version: {self.model_version_used}") + except Exception: + logger.warning("Could not infer model version from loaded model.") + self.model_version_used = 'unknown' + + self.model.summary(print_fn=logger.info) # Log summary of loaded model + return self.model + except Exception as e: + logger.error(f"Failed to load GRU model from {model_path}: {e}", exc_info=True) + self.model = None + return None + + # --- Add Logits Prediction Method (Task 4) --- # + def _make_logits_model(self) -> Model | None: + """Builds a frozen view that outputs dir3 pre-softmax logits.""" + if self.model is None: + logger.error("Cannot create logits view: Main model is not loaded/built.") + return None + + # Check if the view is already cached + if hasattr(self, "_logit_view") and self._logit_view is not None: + return self._logit_view + + try: + # Check if the expected logits layer exists + logits_layer = self.model.get_layer("dir3_logits") + logits_tensor = logits_layer.output + # Create a new model sharing the inputs and weights + self._logit_view = tf.keras.Model( + inputs=self.model.input, + outputs=logits_tensor, + name="gru_logits_view" # Optional name + ) + # No compilation needed for inference-only model + logger.info("Created inference-only model view for 'dir3_logits' output.") + return self._logit_view + except ValueError: + logger.error("Layer 'dir3_logits' not found in the current model. Cannot create logits view.") + self._logit_view = None # Ensure cache is None + return None + except Exception as e: + logger.error(f"Error creating logits view model: {e}", exc_info=True) + self._logit_view = None + return None + + def predict_logits(self, X_data: np.ndarray, batch_size: int = 1024) -> np.ndarray | None: + """Returns raw logits (n,3) for Vector Scaling calibration using a model view.""" + logit_model = self._make_logits_model() + + if logit_model is None: + logger.error("Logits view model is not available. Cannot predict logits.") + return None + + if X_data is None or len(X_data) == 0: + logger.warning("Input data for logit prediction is None or empty.") + return None + + logger.info(f"Generating logit predictions for {len(X_data)} samples...") + try: + # Use verbose=0 to avoid Keras progress bars for this internal prediction + logits = logit_model.predict(X_data, batch_size=batch_size, verbose=0) + logger.info("Logit predictions generated successfully.") + logger.debug(f"Logits output shape: {logits.shape}") + return logits + except Exception as e: + logger.error(f"Error during logit prediction: {e}", exc_info=True) + return None + # --- End Logits Prediction Method --- # + + def predict(self, X_data: np.ndarray, batch_size: int = 1024) -> Any: + """ + Generates predictions using the loaded/trained model. + """ + if self.model is None: + logger.error("No model available for prediction.") + return None + if X_data is None or len(X_data) == 0: + logger.warning("Input data for prediction is None or empty.") + return None + + logger.info(f"Generating predictions for {len(X_data)} samples...") + try: + predictions = self.model.predict(X_data, batch_size=batch_size) + logger.info("Predictions generated successfully.") + if isinstance(predictions, list): + pred_shapes = [p.shape for p in predictions] + logger.debug(f"Prediction output shapes: {pred_shapes}") + else: + logger.debug(f"Prediction output shape: {predictions.shape}") + return predictions + except Exception as e: + logger.error(f"Error during model prediction: {e}", exc_info=True) + return None + + # --- New Method for Temporary Training (Nested CV) --- # + def build_and_train_temporary( + self, + gru_config: Dict[str, Any], + X_train: np.ndarray, + y_train_dict: Dict[str, np.ndarray], + X_val: np.ndarray, + y_val_dict: Dict[str, np.ndarray], + lookback: int, + n_features: int, + ) -> Tuple[Optional[Model], Optional[Any]]: + """ + Builds and trains a temporary GRU model for hyperparameter tuning trials. + Does NOT save persistent checkpoints or modify the main handler state (self.model). + + Args: + gru_config (dict): Hyperparameters for this specific trial. + X_train, y_train_dict: Training sequences and targets for the inner fold. + X_val, y_val_dict: Validation sequences and targets for the inner fold. + lookback (int): Input sequence length. + n_features (int): Number of input features. + + Returns: + Tuple[Model | None, History | None]: The trained (but not saved) Keras model + and the History object, or (None, None) on failure. + """ + logger.info(f"--- Starting TEMPORARY GRU training (Trial) ---") + logger.debug(f"Trial Params: {gru_config}") + + # Determine parameters from trial config + # Use .get with defaults from the main config stored in self.config if available + _max_epochs = int(gru_config.get('epochs', self.config.get('gru', {}).get('epochs', 25))) + _batch_size = int(gru_config.get('batch_size', self.config.get('gru', {}).get('batch_size', 128))) + _patience = int(gru_config.get('patience', self.config.get('gru', {}).get('patience', 5))) + + temp_model: Optional[Model] = None + temp_history: Optional[Any] = None + + # 1. Build the temporary model + try: + # Ensure parameters are passed as correct types + temp_model = build_gru_model_v3( + lookback=lookback, + n_features=n_features, + gru_units=int(gru_config.get('gru_units', 96)), + attention_units=int(gru_config.get('attention_units', 16)), + dropout_rate=float(gru_config.get('dropout_rate', 0.1)), + learning_rate=float(gru_config.get('learning_rate', 1e-4)), + focal_gamma=float(gru_config.get('focal_gamma', 2.0)), + focal_label_smoothing=float(gru_config.get('focal_label_smoothing', 0.1)), + huber_delta=float(gru_config.get('huber_delta', 1.0)), + loss_weight_mu=float(gru_config.get('loss_weight_mu', 0.3)), + loss_weight_dir3=float(gru_config.get('loss_weight_dir3', 1.0)), + l2_reg=float(gru_config.get('l2_reg', 1e-4)) + ) + logger.info("Temporary GRU model built.") + except Exception as build_err: + logger.error(f"Error building temporary GRU model: {build_err}", exc_info=True) + return None, None + + # 2. Setup TEMPORARY Callbacks (No ModelCheckpoint or TensorBoard) + early_stopping_callback = callbacks.EarlyStopping( + monitor='val_loss', + patience=_patience, + verbose=0, # Less verbose for trials + restore_best_weights=True # Important for getting best trial performance + ) + tqdm_callback = TqdmCallback(verbose=0) # Silent progress bar + + temp_callbacks = [early_stopping_callback, tqdm_callback] + + # 3. Train the temporary model + try: + logger.debug(f"Starting temporary model.fit (Epochs={_max_epochs}, Batch={_batch_size})...") + temp_history = temp_model.fit( + X_train, + y_train_dict, + epochs=_max_epochs, + batch_size=_batch_size, + validation_data=(X_val, y_val_dict), + callbacks=temp_callbacks, + shuffle=False, + verbose=0 + ) + logger.debug("Temporary model training finished.") + # Log best val_loss achieved in this temporary training + if temp_history and 'val_loss' in temp_history.history: + if temp_history.history['val_loss']: # Check if list is not empty + best_val_loss = min(temp_history.history['val_loss']) + logger.debug(f"Temporary training best val_loss: {best_val_loss:.5f}") + else: + logger.warning("Temporary training history 'val_loss' list is empty.") + else: + logger.warning("Temporary training history missing 'val_loss' key or history is None.") + + except Exception as train_err: + logger.warning(f"Error during temporary GRU model training: {train_err}") # Log as warning for trials + temp_model = None + temp_history = None + + return temp_model, temp_history + # --- End New Method --- # + + def save(self, model_name: str = 'gru_model') -> str | None: + """ + Saves the trained model architecture, weights, scaler, and metadata. + Uses Keras standard saving format (.keras). + """ + if self.model is None: + logger.error("No model available to save.") + return None + if self.model_version_used is None: + logger.warning("Model version was not set during training/loading. Saving with default name.") + version_suffix = "unknown" + else: + version_suffix = self.model_version_used # e.g., 'v2' or 'v3' + + # Use .keras format and include version in filename + save_filename = f"{model_name}_{self.run_id}.keras" + save_path = os.path.join(self.models_dir, save_filename) + try: + self.model.save(save_path) + logger.info(f"GRU model saved successfully to: {save_path}") + return save_path + except Exception as e: + logger.error(f"Failed to save GRU model to {save_path}: {e}", exc_info=True) + return None \ No newline at end of file diff --git a/gru_sac_predictor/src/io_manager.py b/gru_sac_predictor/src/io_manager.py new file mode 100644 index 00000000..8ad5c691 --- /dev/null +++ b/gru_sac_predictor/src/io_manager.py @@ -0,0 +1,276 @@ +""" +IO Manager for handling file paths and saving artifacts. + +Ref: revisions.txt Section 1 +""" + +import os +import json +import logging +import pandas as pd +from typing import Any, Dict, Optional, List +import matplotlib.pyplot as plt + +logger = logging.getLogger(__name__) + +class IOManager: + """ + Manages input/output operations, including path construction and saving various artifacts. + """ + + def __init__(self, cfg: Dict[str, Any], run_id: str): + """ + Initialize the IOManager. + + Args: + cfg (Dict[str, Any]): The pipeline configuration dictionary. + run_id (str): The unique identifier for the current run. + """ + self.cfg = cfg + self.run_id = run_id + + # Extract base directories, providing defaults if missing + self.base_dirs = cfg.get('base_dirs', {}) + self.results_dir = self._resolve_path(self.base_dirs.get('results', 'results')) + self.models_dir = self._resolve_path(self.base_dirs.get('models', 'models')) + self.logs_dir = self._resolve_path(self.base_dirs.get('logs', 'logs')) + + # Specific directories for the current run + self.run_results_dir = os.path.join(self.results_dir, self.run_id) + self.run_models_dir = os.path.join(self.models_dir, self.run_id) + self.run_logs_dir = os.path.join(self.logs_dir, self.run_id) + self.run_figures_dir = os.path.join(self.run_results_dir, 'figures') # Figures within results + + # Create directories if they don't exist + os.makedirs(self.run_results_dir, exist_ok=True) + os.makedirs(self.run_models_dir, exist_ok=True) + os.makedirs(self.run_logs_dir, exist_ok=True) + os.makedirs(self.run_figures_dir, exist_ok=True) + + logger.info(f"IOManager initialized for run {self.run_id}.") + logger.info(f" Results Dir: {self.run_results_dir}") + logger.info(f" Models Dir: {self.run_models_dir}") + logger.info(f" Logs Dir: {self.run_logs_dir}") + logger.info(f" Figures Dir: {self.run_figures_dir}") + + def _resolve_path(self, path: str) -> str: + """ + Resolves a path relative to the project root. + Assumes this file is in src/ for relative path calculation. + """ + if os.path.isabs(path): + return path + else: + # Assumes src/io_manager.py structure + script_dir = os.path.dirname(os.path.abspath(__file__)) + project_root = os.path.dirname(script_dir) + # Go up one level from src to get to package root + package_root = os.path.dirname(project_root) + return os.path.abspath(os.path.join(package_root, path)) + + def path(self, section: str, name: str, suffix: Optional[str] = None) -> str: + """ + Constructs a full path for an artifact within a specific run section. + + Args: + section (str): The base directory section ('results', 'models', 'logs', 'figures'). + name (str): The base name of the file (without extension or run_id typically). + suffix (Optional[str]): File extension (e.g., '.txt', '.png'). Auto-added by save methods. + + Returns: + str: The full, absolute path to the artifact. + Includes the run_id in the path structure. + """ + base_path = "" + if section == 'results': + base_path = self.run_results_dir + elif section == 'models': + base_path = self.run_models_dir + elif section == 'logs': + base_path = self.run_logs_dir + elif section == 'figures': + base_path = self.run_figures_dir + else: + raise ValueError(f"Unknown path section: '{section}'. Must be one of results, models, logs, figures.") + + filename = name + if suffix: + if not suffix.startswith('.'): + suffix = '.' + suffix + filename += suffix + + full_path = os.path.join(base_path, filename) + return full_path + + def get_fold_dirs(self, fold_num: int) -> Dict[str, str]: + """ + Creates and returns paths for fold-specific subdirectories. + + Args: + fold_num (int): The fold number (1-based). + + Returns: + Dict[str, str]: Dictionary mapping section names to fold-specific directory paths. + """ + fold_suffix = f"fold_{fold_num}" + fold_dirs = {} + + # Define paths for fold-specific subdirectories + fold_dirs['results'] = os.path.join(self.run_results_dir, fold_suffix) + fold_dirs['models'] = os.path.join(self.run_models_dir, fold_suffix) + fold_dirs['logs'] = os.path.join(self.run_logs_dir, fold_suffix) # Optional: logs per fold? + fold_dirs['figures'] = os.path.join(self.run_figures_dir, fold_suffix) + + # Create these directories + for section, path in fold_dirs.items(): + try: + os.makedirs(path, exist_ok=True) + logger.debug(f"Ensured fold directory exists: {path}") + except OSError as e: + logger.error(f"Failed to create fold directory '{path}' for section '{section}': {e}") + # Decide how to handle error: raise, return partial dict, etc. + # For now, let's log and continue, returning potentially incomplete dict + + logger.info(f"Fold {fold_num} directories prepared:") + for section, path in fold_dirs.items(): + logger.info(f" {section.capitalize()}: {path}") + + return fold_dirs + + # --- Save Methods (Task 1.1) --- # + def save_json(self, data: Dict[str, Any], name: str, section: str = 'results', indent: int = 4, use_txt: bool = False): + """ + Saves dictionary data to a JSON file (or .txt if specified) in the target section. + """ + suffix = '.txt' if use_txt else '.json' + file_path = self.path(section, name, suffix=suffix) + logger.info(f"Saving JSON data to {file_path}...") + try: + os.makedirs(os.path.dirname(file_path), exist_ok=True) + with open(file_path, 'w', encoding='utf-8') as f: + json.dump(data, f, indent=indent) + logger.debug(f"Successfully saved JSON to {file_path}") + except TypeError as e: + logger.error(f"TypeError saving JSON to {file_path}. Data may contain non-serializable types: {e}") + except Exception as e: + logger.error(f"Failed to save JSON to {file_path}: {e}", exc_info=True) + + def save_df(self, df: pd.DataFrame, name: str, section: str = 'results'): + """ + Saves DataFrame to CSV or Parquet based on estimated size and config. + Uses 'output.dataframe_save_format' and 'output.dataframe_max_csv_mb' from config. + """ + if df is None or df.empty: + logger.warning(f"Attempted to save empty DataFrame '{name}'. Skipping.") + return None # Return None to indicate skip + + output_cfg = self.cfg.get('output', {}) + save_format_pref = output_cfg.get('dataframe_save_format', 'parquet_if_large') # Default pref + max_csv_mb = output_cfg.get('dataframe_max_csv_mb', 100) # Default size limit + + file_path = None + saved_format = None + + try: + size_mb = df.memory_usage(index=True, deep=True).sum() / (1024**2) + logger.debug(f"DataFrame '{name}' estimated size: {size_mb:.2f} MB (Max CSV Size: {max_csv_mb} MB)") + + use_parquet = False + if save_format_pref == 'parquet': + use_parquet = True + elif save_format_pref == 'parquet_if_large' and size_mb > max_csv_mb: + use_parquet = True + # else: save_format_pref == 'csv' or (parquet_if_large and size_mb <= max_csv_mb) -> use_csv + + if use_parquet: + try: + # Check for pyarrow installation before attempting Parquet save + import pyarrow # noqa: F401 + suffix = '.parquet' + saved_format = 'Parquet' + file_path = self.path(section, name, suffix=suffix) + logger.info(f"Saving DataFrame '{name}' as {saved_format} to {file_path}...") + os.makedirs(os.path.dirname(file_path), exist_ok=True) + df.to_parquet(file_path, index=True) + except ImportError: + logger.warning(f"Cannot save DataFrame '{name}' as Parquet. Missing 'pyarrow'. Falling back to CSV.") + use_parquet = False # Force CSV save + except Exception as parquet_e: + logger.error(f"Failed to save DataFrame '{name}' as Parquet: {parquet_e}", exc_info=True) + # Optionally fallback? For now, let error propagate if it's not ImportError + raise parquet_e + + if not use_parquet: # Save as CSV if format is 'csv', fallback, or size threshold not met + suffix = '.csv' + saved_format = 'CSV' + file_path = self.path(section, name, suffix=suffix) + logger.info(f"Saving DataFrame '{name}' as {saved_format} to {file_path}...") + os.makedirs(os.path.dirname(file_path), exist_ok=True) + df.to_csv(file_path, index=True) + + logger.debug(f"Successfully saved DataFrame to {file_path}") + return file_path # Return path where it was saved + + except Exception as e: + logger.error(f"Failed to save DataFrame '{name}' (intended format: {saved_format or 'N/A'}): {e}", exc_info=True) + return None # Return None on failure + + def save_figure(self, fig: plt.Figure, name: str, section: str = 'figures', **kwargs): + """ + Saves matplotlib figure using config settings (dpi). + """ + file_path = self.path(section, name, suffix='.png') + logger.info(f"Saving figure to {file_path}...") + try: + os.makedirs(os.path.dirname(file_path), exist_ok=True) + output_cfg = self.cfg.get('output', {}) + dpi = kwargs.pop('dpi', output_cfg.get('figure_dpi', 150)) + + if hasattr(fig, 'tight_layout'): + try: + fig.tight_layout() + except Exception as tl_e: + logger.warning(f"Could not apply tight_layout to figure '{name}': {tl_e}") + + fig.savefig(file_path, dpi=dpi, bbox_inches='tight', **kwargs) + logger.debug(f"Successfully saved figure to {file_path}") + except Exception as e: + logger.error(f"Failed to save figure to {file_path}: {e}", exc_info=True) + finally: + plt.close(fig) # Close figure to free memory + +# Example Usage +if __name__ == '__main__': + # Mock config for testing + mock_config = { + 'base_dirs': {'results': 'temp_results', 'models': 'temp_models', 'logs': 'temp_logs'}, + 'output': {'figure_dpi': 120} + } + mock_run_id = "20250418_110000_testabc" + + # Create mock directories for test + if not os.path.exists('temp_results'): os.makedirs('temp_results') + if not os.path.exists('temp_models'): os.makedirs('temp_models') + if not os.path.exists('temp_logs'): os.makedirs('temp_logs') + + io = IOManager(mock_config, mock_run_id) + + print(f"Results Path: {io.path('results', 'metrics', '.txt')}") + print(f"Models Path: {io.path('models', 'gru_model', '.keras')}") + print(f"Figures Path: {io.path('figures', 'calibration_plot', '.png')}") + print(f"Logs Path: {io.path('logs', 'pipeline_log', '.log')}") + + # Test saving + test_dict = {'a': 1, 'b': [2, 3], 'c': 'test'} + io.save_json(test_dict, 'test_data', section='results') + io.save_json(test_dict, 'report_data', section='results', use_txt=True) + + test_df_small = pd.DataFrame(np.random.randn(100, 5), columns=list('ABCDE')) + io.save_df(test_df_small, 'small_data', section='results') + + # Clean up mock directories + import shutil + if os.path.exists('temp_results'): shutil.rmtree('temp_results') + if os.path.exists('temp_models'): shutil.rmtree('temp_models') + if os.path.exists('temp_logs'): shutil.rmtree('temp_logs') + \ No newline at end of file diff --git a/gru_sac_predictor/src/logger_setup.py b/gru_sac_predictor/src/logger_setup.py new file mode 100644 index 00000000..389b8223 --- /dev/null +++ b/gru_sac_predictor/src/logger_setup.py @@ -0,0 +1,182 @@ +""" +Logger Setup Utility. + +Ref: revisions.txt Section 1 +""" + +import logging +import logging.handlers +import sys +import os +from typing import Dict, Any + +# Conditional import for colorlog +try: + import colorlog + COLORLOG_AVAILABLE = True +except ImportError: + COLORLOG_AVAILABLE = False + +# Assuming IOManager is in the same directory or accessible via path +try: + from .io_manager import IOManager +except ImportError: + # Fallback if run as script or structure changes + IOManager = None + +# Define default log format strings (as fallbacks) +DEFAULT_LOG_FORMAT_CONSOLE = '%(log_color)s%(levelname)-8s%(reset)s | %(name)-12s | %(message)s' +DEFAULT_LOG_FORMAT_FILE = '%(asctime)s | %(levelname)-8s | %(name)-15s | %(filename)s:%(lineno)d | %(message)s' +DEFAULT_LOG_DATE_FORMAT = '%Y-%m-%d %H:%M:%S' + +def setup_logger(cfg: Dict[str, Any], run_id: str, io: IOManager) -> logging.Logger: + """ + Configures the root logger with console and rotating file handlers based on config. + + Args: + cfg (Dict[str, Any]): The pipeline configuration dictionary. + run_id (str): The unique run identifier. + io (IOManager): The IOManager instance for getting log file path. + + Returns: + logging.Logger: The configured root logger instance. + """ + # --- Get Logging Config --- # + log_cfg = cfg.get('logging', {}) + output_cfg = cfg.get('output', {}) # Still need this for base log level + + console_log_level_str = log_cfg.get('console_level', 'INFO').upper() + file_log_level_str = log_cfg.get('file_level', 'DEBUG').upper() + log_to_file = log_cfg.get('log_to_file', True) + log_format_console = log_cfg.get('log_format_console', DEFAULT_LOG_FORMAT_CONSOLE) + log_format_file = log_cfg.get('log_format', DEFAULT_LOG_FORMAT_FILE) # Use 'log_format' for file + log_date_format = log_cfg.get('log_date_format', DEFAULT_LOG_DATE_FORMAT) + log_file_max_bytes = log_cfg.get('log_file_max_bytes', 10*1024*1024) # Default 10MB + log_file_backup_count = log_cfg.get('log_file_backup_count', 5) # Default 5 backups + # --- End Logging Config --- # + + console_log_level = getattr(logging, console_log_level_str, logging.INFO) + file_log_level = getattr(logging, file_log_level_str, logging.DEBUG) + + root_logger = logging.getLogger() # Get the root logger + # Set root level to the lowest of the handler levels to ensure messages pass through + root_logger.setLevel(min(console_log_level, file_log_level)) + + # Remove existing handlers to avoid duplication if called multiple times + for handler in root_logger.handlers[:]: + root_logger.removeHandler(handler) + + # --- Console Handler --- # + if COLORLOG_AVAILABLE: + cformat = colorlog.ColoredFormatter( + log_format_console, # Use format from config + datefmt=log_date_format, # Use date format from config + reset=True, + log_colors={ + 'DEBUG': 'cyan', + 'INFO': 'green', + 'WARNING': 'yellow', + 'ERROR': 'red', + 'CRITICAL': 'red,bg_white', + }, + secondary_log_colors={}, + style='%' + ) + console_handler = colorlog.StreamHandler(sys.stdout) + console_handler.setFormatter(cformat) + else: + # Use file format from config if colorlog not available (or user specified non-color format) + formatter = logging.Formatter(log_format_file, datefmt=log_date_format) + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setFormatter(formatter) + + console_handler.setLevel(console_log_level) # Console respects the config level + root_logger.addHandler(console_handler) + + # --- Rotating File Handler (if enabled) --- # + if log_to_file and io is not None: + try: + # Use run_id in log filename + log_file_path = io.path('logs', f'pipeline_{run_id}', suffix='.log') + # Use RotatingFileHandler with params from config + file_handler = logging.handlers.RotatingFileHandler( + log_file_path, + maxBytes=log_file_max_bytes, + backupCount=log_file_backup_count, + encoding='utf-8' + ) + file_formatter = logging.Formatter(log_format_file, datefmt=log_date_format) + file_handler.setFormatter(file_formatter) + file_handler.setLevel(file_log_level) # File respects its config level + root_logger.addHandler(file_handler) + logging.info(f"File logging ({file_log_level_str} level) configured at: {log_file_path}") + except Exception as e: + logging.error(f"Failed to configure file logging: {e}", exc_info=True) + elif log_to_file and io is None: + logging.warning("File logging enabled in config, but IOManager not provided. Cannot configure file handler.") + else: + logging.info("File logging disabled in config.") + + # --- Set TensorFlow Log Level --- # + # Quieter TF logs by default + tf_log_level = os.environ.get('TF_CPP_MIN_LOG_LEVEL', '2') # Default to ERROR + os.environ['TF_CPP_MIN_LOG_LEVEL'] = tf_log_level + # Also set Python TF logger level (optional, affects tf.get_logger()) + tf_logger = logging.getLogger('tensorflow') + if tf_log_level == '0': # ALL + tf_logger.setLevel(logging.DEBUG) + elif tf_log_level == '1': # INFO + tf_logger.setLevel(logging.INFO) + elif tf_log_level == '2': # WARNING + tf_logger.setLevel(logging.WARNING) + else: # 3 = ERROR + tf_logger.setLevel(logging.ERROR) + logging.info(f"TensorFlow logging level set based on TF_CPP_MIN_LOG_LEVEL={tf_log_level}") + + logging.info(f"Root logger configured. Console level: {console_log_level_str}, File level: {file_log_level_str}") + return root_logger + +# Example usage: +if __name__ == '__main__': + mock_config = { + 'base_dirs': {'results': 'temp_results', 'models': 'temp_models', 'logs': 'temp_logs'}, + 'output': {'log_level': 'INFO'}, + 'logging': { + 'console_level': 'INFO', + 'file_level': 'DEBUG', + 'log_to_file': True, + 'log_format_console': '%(log_color)s%(levelname)-8s%(reset)s | %(name)-12s | %(message)s', + 'log_format': '%(asctime)s | %(levelname)-8s | %(name)-15s | %(filename)s:%(lineno)d | %(message)s', + 'log_date_format': '%Y-%m-%d %H:%M:%S', + 'log_file_max_bytes': 10*1024*1024, + 'log_file_backup_count': 5 + } + } + mock_run_id = "20250418_113000_logtest" + + # Need a mock IOManager for the example + if IOManager is None: + print("Mocking IOManager as it couldn't be imported.") + class MockIOManager: + def __init__(self, cfg, run_id): + self.run_logs_dir = os.path.join('temp_logs', run_id) + os.makedirs(self.run_logs_dir, exist_ok=True) + def path(self, section, name, suffix): + return os.path.join(self.run_logs_dir, f"{name}{suffix}") + io_manager = MockIOManager(mock_config, mock_run_id) + else: + if not os.path.exists('temp_logs'): os.makedirs('temp_logs') + io_manager = IOManager(mock_config, mock_run_id) + + logger_instance = setup_logger(mock_config, mock_run_id, io_manager) + + logger_instance.debug("This is a debug message (should only go to file).") + logger_instance.info("This is an info message.") + logger_instance.warning("This is a warning message.") + logger_instance.error("This is an error message.") + + print(f"Check log file in: {io_manager.run_logs_dir}") + + # Clean up + import shutil + if os.path.exists('temp_logs'): shutil.rmtree('temp_logs') \ No newline at end of file diff --git a/gru_sac_predictor/src/metrics.py b/gru_sac_predictor/src/metrics.py new file mode 100644 index 00000000..daeb7bab --- /dev/null +++ b/gru_sac_predictor/src/metrics.py @@ -0,0 +1,184 @@ +""" +Custom Metrics for Trading Performance Evaluation. + +Ref: revisions.txt Section 6 +""" + +import numpy as np +from typing import Tuple +import logging +import pandas as pd +from sklearn.metrics import brier_score_loss + +logger = logging.getLogger(__name__) + +def edge_filtered_accuracy(y_true: np.ndarray, p_cal: np.ndarray, thr: float = 0.1) -> Tuple[float, int]: + """ + Calculates accuracy only on samples where the calibrated prediction + has sufficient edge (confidence). + + Args: + y_true (np.ndarray): True binary labels (0 or 1, potentially soft). + p_cal (np.ndarray): Calibrated probabilities P(up) (shape: N,). + thr (float): Edge threshold. Only samples where |2*p_cal - 1| >= thr are included. + Defaults to 0.1 (equivalent to p_cal >= 0.55 or p_cal <= 0.45). + + Returns: + Tuple[float, int]: + - Accuracy on the filtered samples (NaN if no samples meet threshold). + - Number of samples included in the calculation. + """ + if len(y_true) != len(p_cal): + raise ValueError("Length mismatch between y_true and p_cal.") + if len(y_true) == 0: + return np.nan, 0 + + y_true = np.asarray(y_true) + p_cal = np.asarray(p_cal) + + # Calculate edge + edge = np.abs(2 * p_cal - 1) + + # Create mask + mask = edge >= thr + n_filtered = int(np.sum(mask)) + + if n_filtered == 0: + logger.warning(f"No samples met edge threshold {thr:.3f}. Cannot calculate edge-filtered accuracy.") + return np.nan, 0 + + # Filter data + p_cal_filtered = p_cal[mask] + y_true_filtered = y_true[mask] + + # Predict direction based on calibrated probability > 0.5 + y_pred_filtered = (p_cal_filtered > 0.5).astype(int) + + # Handle potentially soft true labels + if not np.all((y_true_filtered == 0) | (y_true_filtered == 1)): + logger.debug("Soft labels detected in y_true_filtered. Comparing predictions to > 0.5 threshold.") + y_true_hard_filtered = (y_true_filtered > 0.5).astype(int) + else: + y_true_hard_filtered = y_true_filtered.astype(int) + + # Calculate accuracy + accuracy = np.mean(y_pred_filtered == y_true_hard_filtered) + + # logger.debug(f"Edge>={thr:.2f}: Acc={accuracy:.4f}, N={n_filtered}/{len(y_true)}") + return accuracy, n_filtered + +def calculate_brier_score(y_true: np.ndarray, p_cal: np.ndarray) -> float: + """ + Calculates the Brier score for predicted probabilities. + + Args: + y_true (np.ndarray): True binary labels (0 or 1, potentially soft). + p_cal (np.ndarray): Calibrated probabilities P(up) (shape: N,). + + Returns: + float: The Brier score (lower is better). + """ + if len(y_true) != len(p_cal): + raise ValueError("Length mismatch between y_true and p_cal.") + if len(y_true) == 0: + return np.nan + + y_true = np.asarray(y_true) + p_cal = np.asarray(p_cal) + + # Handle soft labels by converting them to binary (0 or 1) for Brier score + if not np.all((y_true == 0) | (y_true == 1)): + logger.debug("Soft labels detected in y_true. Converting to hard binary for Brier score calculation.") + y_true_hard = (y_true > 0.5).astype(int) + else: + y_true_hard = y_true.astype(int) + + # Ensure probabilities are within [0, 1] + p_cal = np.clip(p_cal, 0.0, 1.0) + + return brier_score_loss(y_true_hard, p_cal) + +# --- TODO: Add other metrics from Section 6 --- # +# - CI lower bound calculation helper? (Done implicitly in pipeline check) +# - Re-centred Sharpe calculation? + +def calculate_sharpe_ratio(returns: pd.Series, benchmark_return: float = 0.0, annualization_factor: int = 252*24*60) -> float: + """ + Calculates the annualized Sharpe ratio relative to a benchmark. + + Args: + returns (pd.Series): Series of portfolio returns (e.g., daily or per-period). + Assumes returns are fractional (e.g., 0.01 for 1%). + benchmark_return (float): The benchmark return per period (e.g., risk-free rate). + Defaults to 0.0. + annualization_factor (int): Factor to annualize Sharpe ratio (e.g., 252 for daily, + 52 for weekly, 12 for monthly, 252*24*60 for 1-min). + + Returns: + float: The annualized Sharpe ratio. + """ + if not isinstance(returns, pd.Series): + returns = pd.Series(returns) + + if returns.empty or returns.isnull().all(): + return np.nan + + # Calculate excess returns over the benchmark + excess_returns = returns - benchmark_return + + # Calculate mean and standard deviation of excess returns + mean_excess_return = excess_returns.mean() + std_excess_return = excess_returns.std() + + if std_excess_return == 0 or np.isnan(std_excess_return): + # Handle cases with zero or undefined volatility + return 0.0 if mean_excess_return > 0 else (-np.inf if mean_excess_return < 0 else 0.0) + + # Calculate per-period Sharpe ratio + sharpe_ratio_period = mean_excess_return / std_excess_return + + # Annualize Sharpe ratio + annualized_sharpe_ratio = sharpe_ratio_period * np.sqrt(annualization_factor) + + return annualized_sharpe_ratio + +# --- Helper for Youden's J Optimization --- # +def _calculate_optimal_edge_threshold(y_true: np.ndarray, p_cal: np.ndarray) -> float: + """ + Finds the optimal edge threshold by maximizing Youden's J statistic. + + Args: + y_true: True binary labels (0 or 1). + p_cal: Calibrated probabilities (P(1)). + + Returns: + Optimal threshold (float). + """ + try: + from sklearn.metrics import roc_curve + except ImportError: + logger.error("Scikit-learn not installed. Cannot optimize edge threshold. Returning default 0.1") + return 0.1 # Default fallback + + if len(np.unique(y_true)) < 2: + logger.warning("Only one class present in y_true for threshold optimization. Returning default 0.1") + return 0.1 + + fpr, tpr, thresholds = roc_curve(y_true, p_cal) + youden_j = tpr - fpr + # Filter out invalid thresholds (often includes inf/-inf) + valid_thresholds = thresholds[np.isfinite(thresholds)] + if len(valid_thresholds) == 0: + logger.warning("No valid thresholds found during ROC curve calculation. Returning default 0.1") + return 0.1 + + # Ensure we handle the case where thresholds might not align perfectly with youden_j after filtering + optimal_idx = np.argmax(youden_j[np.isfinite(thresholds)]) + optimal_threshold = valid_thresholds[optimal_idx] + + # Ensure threshold is within a reasonable range (e.g., 0.01 to 0.49) to avoid extremes + optimal_threshold = np.clip(optimal_threshold, 0.01, 0.49) + + logger.info(f"Optimized edge threshold via Youden's J: {optimal_threshold:.4f}") + return optimal_threshold +# --- End Helper --- # \ No newline at end of file diff --git a/gru_sac_predictor/src/sac_agent.py b/gru_sac_predictor/src/sac_agent.py index bb5e0085..65bfd8ba 100644 --- a/gru_sac_predictor/src/sac_agent.py +++ b/gru_sac_predictor/src/sac_agent.py @@ -5,6 +5,7 @@ import tensorflow_probability as tfp from tensorflow.keras.optimizers.schedules import ExponentialDecay import logging import os +import json sac_logger = logging.getLogger(__name__) sac_logger.setLevel(logging.INFO) @@ -30,75 +31,11 @@ class OrnsteinUhlenbeckActionNoise: def reset(self): self.x_prev = self.x_initial if self.x_initial is not None else np.zeros_like(self.mean) -class ReplayBuffer: - """Standard Experience replay buffer for SAC agent""" - - def __init__(self, capacity=100000, state_dim=2, action_dim=1): - self.capacity = capacity - self.counter = 0 - - # Initialize buffer arrays - self.states = np.zeros((capacity, state_dim), dtype=np.float32) - self.actions = np.zeros((capacity, action_dim), dtype=np.float32) - self.rewards = np.zeros((capacity, 1), dtype=np.float32) - self.next_states = np.zeros((capacity, state_dim), dtype=np.float32) - self.dones = np.zeros((capacity, 1), dtype=np.float32) - - def add(self, state, action, reward, next_state, done): - """Add experience to buffer""" - idx = self.counter % self.capacity - - # Ensure inputs are correctly shaped numpy arrays - state = np.array(state, dtype=np.float32).flatten() - action = np.array(action, dtype=np.float32).flatten() - reward = np.array([reward], dtype=np.float32) - next_state = np.array(next_state, dtype=np.float32).flatten() - done = np.array([done], dtype=np.float32) - - if state.shape[0] != self.states.shape[1]: - sac_logger.error(f"State shape mismatch: {state.shape} vs {self.states.shape[1]}") - return - if next_state.shape[0] != self.next_states.shape[1]: - sac_logger.error(f"Next State shape mismatch: {next_state.shape} vs {self.next_states.shape[1]}") - return - if action.shape[0] != self.actions.shape[1]: - sac_logger.error(f"Action shape mismatch: {action.shape} vs {self.actions.shape[1]}") - return - - self.states[idx] = state - self.actions[idx] = action - self.rewards[idx] = reward - self.next_states[idx] = next_state - self.dones[idx] = done - - self.counter += 1 - - def sample(self, batch_size): - """Sample batch of experiences from buffer""" - max_idx = min(self.counter, self.capacity) - if max_idx < batch_size: - print(f"Warning: Trying to sample {batch_size} elements, but buffer only has {max_idx}. Sampling with replacement.") - indices = np.random.choice(max_idx, batch_size, replace=True) - else: - indices = np.random.choice(max_idx, batch_size, replace=False) - - states = tf.convert_to_tensor(self.states[indices], dtype=tf.float32) - actions = tf.convert_to_tensor(self.actions[indices], dtype=tf.float32) - rewards = tf.convert_to_tensor(self.rewards[indices], dtype=tf.float32) - next_states = tf.convert_to_tensor(self.next_states[indices], dtype=tf.float32) - dones = tf.convert_to_tensor(self.dones[indices], dtype=tf.float32) - - return states, actions, rewards, next_states, dones - - def __len__(self): - """Get current size of buffer""" - return min(self.counter, self.capacity) - class SACTradingAgent: """V7.3 Enhanced: SAC agent with updated params and architecture fixes.""" def __init__(self, - state_dim=2, # Standard [pred_ret, uncert] + state_dim=5, # [mu, sigma, edge, |mu|/sigma, position] action_dim=1, gamma=0.99, tau=0.005, @@ -106,31 +43,64 @@ class SACTradingAgent: decay_steps=100000, end_lr=5e-6, # Note: End LR not directly used by ExponentialDecay lr_decay_rate=0.96, - buffer_capacity=100000, ou_noise_stddev=0.2, ou_noise_theta=0.15, ou_noise_dt=0.01, alpha=0.2, alpha_auto_tune=True, target_entropy=-1.0, - min_buffer_size=1000): + edge_threshold_config: float | None = None, + reward_scale_config: float | None = None, + action_penalty_lambda_config: float | None = None): """ Initialize the SAC agent with enhancements. + Args: + state_dim (int): The dimension of the state space. + action_dim (int): The dimension of the action space. + gamma (float): The discount factor. + tau (float): The target network update rate. + initial_lr (float): The initial learning rate. + decay_steps (int): The number of steps over which to decay the learning rate. + end_lr (float): The final learning rate. (Not used by ExponentialDecay) + lr_decay_rate (float): The rate at which to decay the learning rate. + ou_noise_stddev (float): The standard deviation of the Ornstein-Uhlenbeck noise. + ou_noise_theta (float): The theta parameter of the Ornstein-Uhlenbeck noise. + ou_noise_dt (float): The dt parameter of the Ornstein-Uhlenbeck noise. + alpha (float): The initial alpha value. + alpha_auto_tune (bool): Whether to automatically tune the alpha value. + target_entropy (float): The target entropy for the alpha auto-tuning. + edge_threshold_config (float | None): The edge threshold value from the config. + reward_scale_config (float | None): Environment reward scale used during training. + action_penalty_lambda_config (float | None): Action penalty lambda used during training. """ self.state_dim = state_dim self.action_dim = action_dim self.gamma = gamma self.tau = tau - self.min_buffer_size = min_buffer_size - self.target_entropy = tf.constant(target_entropy, dtype=tf.float32) self.alpha_auto_tune = alpha_auto_tune + self.edge_threshold_config = edge_threshold_config + self.reward_scale_config = reward_scale_config + self.action_penalty_lambda_config = action_penalty_lambda_config + # --- Target Entropy Calculation --- # + effective_target_entropy = target_entropy + default_target_entropy_value = -1.0 * float(self.action_dim) if self.alpha_auto_tune: + if abs(target_entropy - default_target_entropy_value) < 1e-6: + effective_target_entropy = -0.5 * np.log(4.0) + sac_logger.info(f"alpha_auto_tune=True and default target_entropy detected. Setting target_entropy to -0.5*log(4) = {effective_target_entropy:.4f}") + else: + effective_target_entropy = target_entropy + sac_logger.info(f"alpha_auto_tune=True, using explicitly provided target_entropy: {effective_target_entropy:.4f}") self.log_alpha = tf.Variable(tf.math.log(alpha), trainable=True, name='log_alpha') self.alpha = tfp.util.DeferredTensor(self.log_alpha, tf.exp) - self.alpha_optimizer = tf.keras.optimizers.Adam(learning_rate=initial_lr) + self.alpha_optimizer = tf.keras.optimizers.Adam(learning_rate=float(initial_lr)) else: + effective_target_entropy = target_entropy self.alpha = tf.constant(alpha, dtype=tf.float32) + sac_logger.info(f"alpha_auto_tune=False. Using fixed alpha={self.alpha:.4f}") + self.target_entropy = tf.constant(effective_target_entropy, dtype=tf.float32) + # --- End Target Entropy --- # self.ou_noise = OrnsteinUhlenbeckActionNoise( mean=np.zeros(action_dim), @@ -138,8 +108,10 @@ class SACTradingAgent: theta=ou_noise_theta, dt=ou_noise_dt) self.lr_schedule = ExponentialDecay( - initial_learning_rate=initial_lr, decay_steps=decay_steps, - decay_rate=lr_decay_rate, staircase=False) + initial_learning_rate=float(initial_lr), + decay_steps=int(decay_steps), + decay_rate=float(lr_decay_rate), + staircase=False) sac_logger.info(f"Using ExponentialDecay LR: init={initial_lr}, steps={decay_steps}, rate={lr_decay_rate}") self.actor_optimizer = tf.keras.optimizers.Adam(learning_rate=self.lr_schedule) self.critic1_optimizer = tf.keras.optimizers.Adam(learning_rate=self.lr_schedule) @@ -147,22 +119,19 @@ class SACTradingAgent: # Initialize networks self.actor = self._build_actor() - self.critic1 = self._build_critic() # Outputs [Q_mean, Q_log_std] + self.critic1 = self._build_critic() self.critic2 = self._build_critic() self.target_critic1 = self._build_critic() self.target_critic2 = self._build_critic() self.update_target_networks(tau=1.0) - self.buffer = ReplayBuffer(capacity=buffer_capacity, state_dim=state_dim, action_dim=action_dim) - sac_logger.info("Enhanced SAC Agent Initialized (V7.3).") + sac_logger.info("Enhanced SAC Agent Initialized (V7.3 - Consolidated).") sac_logger.info(f" State Dim: {state_dim}, Action Dim: {action_dim}") sac_logger.info(f" Hyperparams: gamma={gamma}, tau={tau}, alpha={'auto' if alpha_auto_tune else alpha}, target_entropy={target_entropy}") sac_logger.info(f" LR Schedule: Exponential {initial_lr} -> ? (decay_rate={lr_decay_rate})") - sac_logger.info(f" Buffer: {buffer_capacity}, Min Size: {min_buffer_size}, Batch Size: Default 256 (in train)") sac_logger.info(f" OU Noise: std={ou_noise_stddev}, theta={ou_noise_theta}, dt={ou_noise_dt}") - sac_logger.info(f" PER Note: Standard buffer used (PER={False})") - + def _build_actor(self): inputs = layers.Input(shape=(self.state_dim,)) x1 = layers.Dense(128, activation='relu')(inputs); x1_norm = layers.BatchNormalization()(x1) @@ -221,6 +190,7 @@ class SACTradingAgent: for target_weight, weight in zip(self.target_critic1.weights, self.critic1.weights): target_weight.assign(tau * weight + (1.0 - tau) * target_weight) for target_weight, weight in zip(self.target_critic2.weights, self.critic2.weights): target_weight.assign(tau * weight + (1.0 - tau) * target_weight) + @tf.function def _update_critics(self, states, actions, rewards, next_states, dones): next_means, next_log_stds = self.actor(next_states); next_stds = tf.exp(next_log_stds) next_distributions = tfp.distributions.Normal(loc=next_means, scale=next_stds) @@ -234,43 +204,45 @@ class SACTradingAgent: target_q_min_mean = tf.minimum(target_q1_mean, target_q2_mean) target_q = target_q_min_mean - self.alpha * next_log_probs target_q_values = rewards + (1.0 - dones) * self.gamma * target_q + # Ensure target_q_values is correctly shaped [batch_size, 1] + target_q_values_stopped = tf.stop_gradient(target_q_values) - # Explicitly get trainable variables before the tape critic1_vars = self.critic1.trainable_variables critic2_vars = self.critic2.trainable_variables with tf.GradientTape(persistent=True) as tape: - # Ensure the tape watches the correct variables if needed, though default should be fine - # tape.watch(critic1_vars) - # tape.watch(critic2_vars) - current_q1_mean, current_q1_log_std = self.critic1([states, actions]) current_q2_mean, current_q2_log_std = self.critic2([states, actions]) pred_dist1 = tfp.distributions.Normal(loc=current_q1_mean, scale=tf.exp(current_q1_log_std)) pred_dist2 = tfp.distributions.Normal(loc=current_q2_mean, scale=tf.exp(current_q2_log_std)) - nll_loss1 = -pred_dist1.log_prob(tf.stop_gradient(target_q_values)) - nll_loss2 = -pred_dist2.log_prob(tf.stop_gradient(target_q_values)) + nll_loss1 = -pred_dist1.log_prob(target_q_values_stopped) # Use stopped gradient here + nll_loss2 = -pred_dist2.log_prob(target_q_values_stopped) # Use stopped gradient here critic1_loss = tf.reduce_mean(nll_loss1) critic2_loss = tf.reduce_mean(nll_loss2) + + # --- Calculate TD Errors for PER --- # + # Use the mean of the prediction as the current Q estimate for error calculation + current_q_min = tf.minimum(current_q1_mean, current_q2_mean) + td_errors = tf.abs(target_q_values - current_q_min) # Target Q already has stopped gradient from above + # Ensure td_errors have shape [batch_size, 1] or [batch_size] + td_errors = tf.squeeze(td_errors) # Squeeze to [batch_size] if needed by buffer + # --- End TD Error Calculation --- # - # Calculate gradients w.r.t the specific variable lists critic1_gradients = tape.gradient(critic1_loss, critic1_vars) critic2_gradients = tape.gradient(critic2_loss, critic2_vars) del tape - # Apply gradients paired with the specific variable lists using separate optimizers self.critic1_optimizer.apply_gradients(zip(critic1_gradients, critic1_vars)) self.critic2_optimizer.apply_gradients(zip(critic2_gradients, critic2_vars)) - return critic1_loss, critic2_loss + # Return losses and TD errors + return critic1_loss, critic2_loss, td_errors + @tf.function def _update_actor(self, states): - # Explicitly get trainable variables before the tape actor_vars = self.actor.trainable_variables - with tf.GradientTape() as tape: - # tape.watch(actor_vars) means, log_stds = self.actor(states); stds = tf.exp(log_stds) distributions = tfp.distributions.Normal(loc=means, scale=stds) actions_raw = distributions.sample(); actions_tanh = tf.tanh(actions_raw) @@ -284,12 +256,11 @@ class SACTradingAgent: actor_loss = tf.reduce_mean(self.alpha * log_probs - q_min_mean) - # Calculate gradients w.r.t the specific variable list actor_gradients = tape.gradient(actor_loss, actor_vars) - # Apply gradients paired with the specific variable list self.actor_optimizer.apply_gradients(zip(actor_gradients, actor_vars)) return actor_loss, log_probs + @tf.function def _update_alpha(self, log_probs): with tf.GradientTape() as tape: alpha_loss = -tf.reduce_mean(self.log_alpha * tf.stop_gradient(log_probs + self.target_entropy)) @@ -298,26 +269,43 @@ class SACTradingAgent: self.alpha_optimizer.apply_gradients(zip(alpha_gradients, [self.log_alpha])) return alpha_loss - def train(self, batch_size=256): + def train(self, states, actions, rewards, next_states, dones, importance_weights=None): """ - Train the enhanced SAC agent. - Includes alpha auto-tuning. - Reverted aux tasks and state dim. + Perform a single training update step using a batch of experience. + Now accepts and returns values needed for PER. + + Args: + states (tf.Tensor): Batch of states. + actions (tf.Tensor): Batch of actions. + rewards (tf.Tensor): Batch of rewards. + next_states (tf.Tensor): Batch of next states. + dones (tf.Tensor): Batch of done flags. + importance_weights (tf.Tensor, optional): Importance sampling weights for PER. Defaults to None. + + Returns: + dict: Dictionary containing loss metrics and TD errors. """ - if len(self.buffer) < self.min_buffer_size: - return {} + # Apply importance weights if provided (for PER) + sample_weights = importance_weights if importance_weights is not None else tf.ones_like(rewards) - states, actions, rewards, next_states, dones = self.buffer.sample(batch_size) - - critic1_loss, critic2_loss = self._update_critics( + # Critic updates now return TD errors + critic1_loss, critic2_loss, td_errors = self._update_critics( states, actions, rewards, next_states, dones ) + # Apply importance weights to critic losses (if needed by loss definition, often handled internally) + # Example: critic1_loss = tf.reduce_mean(critic1_loss * sample_weights) + # Note: If using NLL loss, applying weights might need careful consideration. + # For simplicity here, we assume the mean reduction inside _update_critics is sufficient, + # but a true PER implementation might weight the NLL terms *before* reduction. actor_loss, log_probs = self._update_actor(states) + # Actor loss already uses Q values which implicitly account for importance via critic updates? + # Typically, actor loss is not directly weighted by IS weights. alpha_loss = None if self.alpha_auto_tune: alpha_loss = self._update_alpha(log_probs) + # Alpha loss typically not weighted by IS weights. self.update_target_networks() @@ -326,7 +314,8 @@ class SACTradingAgent: "critic2_loss": float(critic2_loss), "actor_loss": float(actor_loss), "learning_rate": float(self.lr_schedule(self.actor_optimizer.iterations)), - "alpha": float(tf.exp(self.log_alpha)) if self.alpha_auto_tune else float(self.alpha) + "alpha": float(tf.exp(self.log_alpha)) if self.alpha_auto_tune else float(self.alpha), + "td_errors": td_errors.numpy() # Return TD errors for priority updates } if alpha_loss is not None: metrics["alpha_loss"] = float(alpha_loss) @@ -334,24 +323,90 @@ class SACTradingAgent: return metrics def save(self, path): + """Saves agent weights and potentially metadata.""" try: - self.actor.save_weights(f"{path}/actor.weights.h5"); self.critic1.save_weights(f"{path}/critic1.weights.h5") - self.critic2.save_weights(f"{path}/critic2.weights.h5") - if self.alpha_auto_tune and hasattr(self, 'log_alpha'): np.save(f"{path}/log_alpha.npy", self.log_alpha.numpy()) - sac_logger.info(f"Enhanced SAC Agent weights saved to {path}/") - except Exception as e: sac_logger.error(f"Error saving SAC weights: {e}") - + os.makedirs(path, exist_ok=True) + self.actor.save_weights(os.path.join(path, "actor.weights.h5")) + self.critic1.save_weights(os.path.join(path, "critic1.weights.h5")) + self.critic2.save_weights(os.path.join(path, "critic2.weights.h5")) + + metadata = {} + if self.alpha_auto_tune and hasattr(self, 'log_alpha'): + # Save log_alpha directly + np.save(os.path.join(path, "log_alpha.npy"), self.log_alpha.numpy()) + metadata['log_alpha_saved'] = True + else: + metadata['log_alpha_saved'] = False + metadata['fixed_alpha'] = float(self.alpha) # Save fixed alpha value + + # Add other relevant metadata if needed (like state_dim, action_dim used during training) + metadata['state_dim'] = self.state_dim + metadata['action_dim'] = self.action_dim + # Add edge threshold to metadata (Step 4-A) + if self.edge_threshold_config is not None: + metadata['edge_threshold'] = self.edge_threshold_config # Use the key specified in revisions.txt + else: + metadata['edge_threshold'] = None # Or a default value like 0.55? + # --- Add Env Params to Metadata (Task 5.6) --- # + if self.reward_scale_config is not None: + metadata['reward_scale'] = self.reward_scale_config + else: + metadata['reward_scale'] = None # Indicate if not set + if self.action_penalty_lambda_config is not None: + metadata['lambda'] = self.action_penalty_lambda_config # Use lambda key from revisions.txt + else: + metadata['lambda'] = None + # --- End Add Env Params --- # + + meta_path = os.path.join(path, 'agent_metadata.json') + with open(meta_path, 'w') as f: + json.dump(metadata, f, indent=4) + + sac_logger.info(f"SAC Agent weights and metadata saved to {path}/") + except Exception as e: + sac_logger.error(f"Error saving SAC weights/metadata: {e}", exc_info=True) + def load(self, path): + """Loads agent weights and potentially metadata.""" try: + # Load weights (existing logic seems ok, ensures models are built) if not self.actor.built: self.actor.build((None, self.state_dim)) if not self.critic1.built: self.critic1.build([(None, self.state_dim), (None, self.action_dim)]) if not self.critic2.built: self.critic2.build([(None, self.state_dim), (None, self.action_dim)]) if not self.target_critic1.built: self.target_critic1.build([(None, self.state_dim), (None, self.action_dim)]) if not self.target_critic2.built: self.target_critic2.build([(None, self.state_dim), (None, self.action_dim)]) - self.actor.load_weights(f"{path}/actor.weights.h5"); self.critic1.load_weights(f"{path}/critic1.weights.h5") - self.critic2.load_weights(f"{path}/critic2.weights.h5"); self.target_critic1.load_weights(f"{path}/critic1.weights.h5") - self.target_critic2.load_weights(f"{path}/critic2.weights.h5") - log_alpha_path = f"{path}/log_alpha.npy" - if self.alpha_auto_tune and os.path.exists(log_alpha_path): self.log_alpha.assign(np.load(log_alpha_path)); sac_logger.info(f"Loaded log_alpha value") - sac_logger.info(f"Enhanced SAC Agent weights loaded from {path}/") - except Exception as e: sac_logger.error(f"Error loading SAC weights from {path}: {e}. Ensure files exist/shapes match.") \ No newline at end of file + self.actor.load_weights(os.path.join(path, "actor.weights.h5")); self.critic1.load_weights(os.path.join(path, "critic1.weights.h5")) + self.critic2.load_weights(os.path.join(path, "critic2.weights.h5")); self.target_critic1.load_weights(os.path.join(path, "critic1.weights.h5")) + self.target_critic2.load_weights(os.path.join(path, "critic2.weights.h5")) + + # Load metadata + meta_path = os.path.join(path, 'agent_metadata.json') + metadata = {} + if os.path.exists(meta_path): + with open(meta_path, 'r') as f: + metadata = json.load(f) + sac_logger.info(f"Loaded agent metadata: {metadata}") + + # Load log_alpha if saved and auto-tuning + log_alpha_path = os.path.join(path, "log_alpha.npy") + if self.alpha_auto_tune and metadata.get('log_alpha_saved', False) and os.path.exists(log_alpha_path): + self.log_alpha.assign(np.load(log_alpha_path)) + sac_logger.info(f"Restored log_alpha value from saved state.") + elif not self.alpha_auto_tune and 'fixed_alpha' in metadata: + # Restore fixed alpha if not auto-tuning + self.alpha = tf.constant(metadata['fixed_alpha'], dtype=tf.float32) + sac_logger.info(f"Restored fixed alpha value: {self.alpha:.4f}") + + else: + sac_logger.warning(f"Agent metadata file not found at {meta_path}. Cannot verify parameters or load log_alpha.") + + sac_logger.info(f"SAC Agent weights loaded from {path}/") + return metadata # Return metadata for potential checks + + except Exception as e: + sac_logger.error(f"Error loading SAC weights/metadata from {path}: {e}. Ensure files exist/shapes match.", exc_info=True) + return {} # Return empty dict on failure + + def clear_buffer(self): + """Clears the agent's replay buffer.""" + sac_logger.error("Agent or buffer does not have a clear_buffer method.") \ No newline at end of file diff --git a/gru_sac_predictor/src/sac_trainer.py b/gru_sac_predictor/src/sac_trainer.py new file mode 100644 index 00000000..74f72840 --- /dev/null +++ b/gru_sac_predictor/src/sac_trainer.py @@ -0,0 +1,1044 @@ +""" +SAC Trainer Component. + +Handles the offline training process for the SAC agent, including loading +dependencies from a specified GRU run, preparing data for the environment, +and executing the training loop. +""" + +import tensorflow as tf +import json +import logging +import os +import sys +import yaml +import joblib +import pandas as pd +import numpy as np +from datetime import datetime +from tqdm import tqdm +from tensorflow.keras.callbacks import TensorBoard +import collections +import csv +import time + +# Import necessary components from the pipeline +# Use absolute imports assuming the package structure is correct +from gru_sac_predictor.src.data_loader import DataLoader +from gru_sac_predictor.src.feature_engineer import FeatureEngineer +from gru_sac_predictor.src.gru_model_handler import GRUModelHandler +from gru_sac_predictor.src.calibrator import Calibrator +from gru_sac_predictor.src.sac_agent import SACTradingAgent +from gru_sac_predictor.src.trading_env import TradingEnv +try: + from gru_sac_predictor.src.features import minimal_whitelist # For FE fallback +except ImportError: + minimal_whitelist = [] # Define empty if import fails + +# --- Import MeanStdFilter (Task 5.2) --- # +try: + from gru_sac_predictor.src.utils.running_stats import MeanStdFilter + STATE_FILTER_AVAILABLE = True +except ImportError: + logging.warning("MeanStdFilter not found in utils. State normalization will be disabled.") + MeanStdFilter = None + STATE_FILTER_AVAILABLE = False +# --- End Import --- # + +logger = logging.getLogger(__name__) + +# ============================================================================== +# Prioritized Experience Replay (PER) Buffer Implementation +# ============================================================================== + +class SumTree: + """ Simple SumTree implementation for PER priority sampling.""" + write = 0 + + def __init__(self, capacity): + self.capacity = capacity + # Tree structure: Internal nodes + leaves + self.tree = np.zeros(2 * capacity - 1) + # Data storage (corresponds to leaves) + self.data = np.zeros(capacity, dtype=object) + self.n_entries = 0 + + def _propagate(self, idx, change): + parent = (idx - 1) // 2 + self.tree[parent] += change + if parent != 0: + self._propagate(parent, change) + + def _retrieve(self, idx, s): + left = 2 * idx + 1 + right = left + 1 + + if left >= len(self.tree): + return idx + + if s <= self.tree[left]: + return self._retrieve(left, s) + else: + return self._retrieve(right, s - self.tree[left]) + + def total(self): + return self.tree[0] + + def add(self, p, data): + # Data is now expected to be (sample, is_seeded) + idx = self.write + self.capacity - 1 + + self.data[self.write] = data + self.update(idx, p) + + self.write += 1 + if self.write >= self.capacity: + self.write = 0 + + if self.n_entries < self.capacity: + self.n_entries += 1 + + def update(self, idx, p): + change = p - self.tree[idx] + self.tree[idx] = p + self._propagate(idx, change) + + def get(self, s): + idx = self._retrieve(0, s) + dataIdx = idx - self.capacity + 1 + # Return index, priority, and the stored (sample, is_seeded) tuple + return (idx, self.tree[idx], self.data[dataIdx]) + + +class PrioritizedReplayBuffer: + """ PER buffer using a SumTree.""" + epsilon = 0.01 # Small amount to avoid zero priority + beta = 0.4 # Importance sampling exponent + beta_increment_per_sampling = 0.001 + abs_err_upper = 1. # Clipped abs error + + def __init__(self, capacity, alpha=0.6, beta_start=0.4, beta_frames=100000): + self.tree = SumTree(capacity) + self.capacity = capacity + # Alpha is annealed externally now + self.beta = beta_start + # Calculate beta increment based on total frames/steps + self.beta_increment_per_sampling = (1. - beta_start) / beta_frames + self.max_priority = 1.0 # Initialize max priority + + # --- Revision 5: Oracle Seed Decay Steps --- # + self.per_seed_decay_steps = beta_frames // 2 # Default: decay over half training + logger.info(f"PER Oracle Seed IS weight decay steps: {self.per_seed_decay_steps}") + # --- End Revision 5 --- # + + def add(self, error, sample, is_seeded=False): + """ Store new experience with initial priority and seeded flag. """ + p = self.max_priority # Use max priority for new samples + # Store sample along with its seeded status + self.tree.add(p, (sample, is_seeded)) + + def sample(self, n, beta=None): + """ Sample batch, returning indices, samples, and IS weights. """ + if beta is None: + beta = self.beta # Use current beta if not specified + + batch = [] + idxs = [] + segment = self.tree.total() / n + priorities = [] + + # Anneal beta + self.beta = np.min([1., self.beta + self.beta_increment_per_sampling]) + + for i in range(n): + a = segment * i + b = segment * (i + 1) + s = np.random.uniform(a, b) + (idx, p, data) = self.tree.get(s) + priorities.append(p) + # Unpack data: contains (sample, is_seeded) + if isinstance(data, tuple) and len(data) == 2: + batch.append(data) # Keep tuple for now + else: # Fallback for older buffer states maybe? + batch.append((data, False)) # Assume not seeded + idxs.append(idx) + + sampling_probabilities = np.array(priorities) / self.tree.total() + is_weight = np.power(self.tree.n_entries * sampling_probabilities, -beta) + is_weight /= is_weight.max() # Normalize weights + + return idxs, batch, is_weight + + # --- Revision 5: Modified priority calculation and update --- # + def _get_priority(self, error, alpha): + """Calculate priority: (clip(|error|) + eps) ^ alpha.""" + # Clip absolute error + clipped_abs_error = np.clip(np.abs(error), 0, self.abs_err_upper) + return (clipped_abs_error + self.epsilon) ** alpha + + def update_priorities(self, idxs, errors, alpha): + """ Update priorities of sampled transitions using current alpha. """ + for i, idx in enumerate(idxs): + # Use current annealed alpha + p = self._get_priority(errors[i], alpha) + self.max_priority = max(self.max_priority, p) # Keep track of max priority + self.tree.update(idx, p) + # --- End Revision 5 --- # + + def __len__(self): + return self.tree.n_entries + +# ============================================================================== + +class SACTrainer: + """Manages the offline SAC training workflow.""" + + def __init__(self, config: dict, base_models_dir: str, base_logs_dir: str, base_results_dir: str): + """ + Initialize the SACTrainer. + + Args: + config (dict): The main pipeline configuration dictionary. + base_models_dir (str): The base directory where all run models are stored (e.g., project_root/models). + base_logs_dir (str): Base directory for logs. + base_results_dir (str): Base directory for results. + """ + self.config = config + self.base_models_dir = base_models_dir + self.sac_cfg = config['sac'] + self.env_cfg = config.get('environment', {}) + self.control_cfg = config.get('control', {}) + self.data_cfg = config['data'] + + # --- Cache useful config values --- + self.use_ternary = self.config.get('gru', {}).get('use_ternary', False) + # --- Store PER config params --- # + self.use_per = self.sac_cfg.get('use_per', False) + self.per_alpha = self.sac_cfg.get('per_alpha', 0.6) + self.per_beta_start = self.sac_cfg.get('per_beta_start', 0.4) + # Estimate beta annealing frames based on total steps + total_training_steps = self.sac_cfg.get('total_training_steps', 100000) + self.per_beta_frames = self.sac_cfg.get('per_beta_frames', total_training_steps) + # --- End PER Config --- # + + # Generate a specific run ID for this SAC training instance + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + self.sac_train_run_id = f"sac_train_{timestamp}" + logger.info(f"Initializing SACTrainer with Run ID: {self.sac_train_run_id}") + + # Setup directories for this specific SAC training run + self.sac_run_models_dir = os.path.join(self.base_models_dir, self.sac_train_run_id) + self.sac_run_logs_dir = os.path.join(base_logs_dir, self.sac_train_run_id) + self.sac_run_results_dir = os.path.join(base_results_dir, self.sac_train_run_id) + self.sac_tb_log_dir = os.path.join(self.sac_run_logs_dir, 'tensorboard') + + os.makedirs(self.sac_run_models_dir, exist_ok=True) + os.makedirs(self.sac_run_logs_dir, exist_ok=True) + os.makedirs(self.sac_run_results_dir, exist_ok=True) + os.makedirs(self.sac_tb_log_dir, exist_ok=True) + + # --- Initialize State Filter (Task 5.2) --- # + # Get state dim - assumes TradingEnv has a fixed state_dim attribute + # We might need to instantiate a dummy env briefly to get this? + # Or hardcode based on current TradingEnv._get_state() + # Current state: [mu, sigma, edge, |mu|/sigma, position] -> dim=5 + # TODO: Get state_dim more robustly if env changes + self.state_dim_env = 5 # Hardcoded based on current TradingEnv + self.state_filter = None + if STATE_FILTER_AVAILABLE and self.config.get('sac',{}).get('use_state_filter', True): # Add config flag + logger.info(f"Initializing MeanStdFilter for state normalization (shape={self.state_dim_env}).") + self.state_filter = MeanStdFilter(shape=(self.state_dim_env,)) + else: + logger.warning("State filter is disabled (either unavailable or config flag is false).") + # --- End Initialize State Filter --- # + + # Configure logging specifically for this trainer instance if needed + # For now, relies on the pipeline's logger setup + logger.info(f" SAC Models Dir: {self.sac_run_models_dir}") + logger.info(f" SAC Logs Dir: {self.sac_run_logs_dir}") + logger.info(f" SAC Results Dir:{self.sac_run_results_dir}") + logger.info(f" SAC TB Dir: {self.sac_tb_log_dir}") + + # Save config subset relevant to SAC training? + # Or assume full config is saved by the main pipeline + + def _load_gru_dependencies(self, gru_run_id: str) -> dict | None: + """ + Loads artifacts (whitelist, scaler, GRU model, T) from a completed GRU pipeline run. + + Args: + gru_run_id (str): The run ID of the GRU pipeline run. + + Returns: + dict | None: A dictionary containing the loaded dependencies + ('whitelist', 'scaler', 'gru_model', 'optimal_T'), or None on failure. + """ + logger.info(f"--- Loading Dependencies from GRU Run ID: {gru_run_id} ---") + gru_run_models_dir = os.path.join(self.base_models_dir, gru_run_id) + if not os.path.exists(gru_run_models_dir): + logger.error(f"Models directory for GRU run {gru_run_id} not found at: {gru_run_models_dir}") + return None + + dependencies = {} + + # 1. Load Whitelist + whitelist_path = os.path.join(gru_run_models_dir, f"final_whitelist_{gru_run_id}.json") + try: + with open(whitelist_path, 'r') as f: + dependencies['whitelist'] = json.load(f) + logger.info(f"Loaded whitelist ({len(dependencies['whitelist'])} features) from {whitelist_path}") + if not dependencies['whitelist']: + raise ValueError("Loaded whitelist is empty.") + except Exception as e: + logger.error(f"Failed to load whitelist from {whitelist_path}: {e}", exc_info=True) + return None + + # 2. Load Scaler + scaler_path = os.path.join(gru_run_models_dir, f"feature_scaler_{gru_run_id}.joblib") + try: + dependencies['scaler'] = joblib.load(scaler_path) + logger.info(f"Loaded scaler from {scaler_path}") + except Exception as e: + logger.error(f"Failed to load scaler from {scaler_path}: {e}", exc_info=True) + return None + + # 3. Load GRU Model + model_path = os.path.join(gru_run_models_dir, f"gru_model_{gru_run_id}.keras") + # Need a temporary GRU handler instance to use its load method + # Pass self.config (the trainer's config) to the GRUModelHandler + temp_gru_handler = GRUModelHandler(run_id="temp_load", models_dir="temp_load", config=self.config) + # Attempt to load the model using the handler + # The load method likely needs the full path, not just the directory and ID + loaded_model_info = temp_gru_handler.load(model_path=model_path) # Pass full path + + # Adjust based on what gru_handler.load returns + # Assuming it returns (model, info_dict) or None + if loaded_model_info and isinstance(loaded_model_info, tuple) and len(loaded_model_info) > 0: + dependencies['gru_model'] = loaded_model_info[0] # Get the actual model + if dependencies['gru_model'] is None: + logger.error(f"GRU handler load method returned None for model from {model_path}") + return None + elif loaded_model_info and not isinstance(loaded_model_info, tuple): # If it just returns the model + dependencies['gru_model'] = loaded_model_info + else: # If it returned None or unexpected tuple + logger.error(f"Failed to load GRU model using handler from {model_path}") + return None + + logger.info(f"Loaded GRU model from {model_path}") + + # 4. Load Calibration Info and Parameters + calib_info_path = os.path.join(gru_run_models_dir, f"calibration_info_{gru_run_id}.json") + calib_params = None + calib_method = None + try: + with open(calib_info_path, 'r') as f: + calib_info = json.load(f) + logger.info(f"Loaded calibration info: {calib_info}") + calib_method = calib_info.get("method") + params_filename = calib_info.get("params_filename") + + if calib_method and params_filename: + params_path = os.path.join(gru_run_models_dir, params_filename) + if os.path.exists(params_path): + if calib_method == "temperature": + calib_params = float(np.load(params_path)) + logger.info(f"Loaded optimal temperature T={calib_params:.4f} from {params_path}") + elif calib_method == "vector": + # Load vector params (assuming saved as .npy) + calib_params = np.load(params_path, allow_pickle=True) + logger.info(f"Loaded vector calibration params from {params_path} (type: {type(calib_params)}, shape: {getattr(calib_params, 'shape', 'N/A')})") + else: + logger.warning(f"Unknown calibration method '{calib_method}' in info file. Cannot load params.") + else: + logger.error(f"Calibration parameter file specified in info ({params_filename}) not found at {params_path}") + return None # Fail if params file missing + elif calib_method in ["temperature_failed", "vector_failed", "skipped_mismatch", "temperature_error", "vector_error", "vector_unavailable"]: + logger.warning(f"Calibration was not successful or was skipped during GRU run (method: {calib_method}). SAC training may be suboptimal.") + # Decide if this is fatal. For now, let's allow it but store None. + calib_params = None + else: + logger.error(f"Invalid calibration info content: {calib_info}") + return None # Fail if info is invalid + + except FileNotFoundError: + logger.error(f"Calibration info file not found at {calib_info_path}. Cannot determine calibration parameters.") + # Check for legacy temperature file for backward compatibility? + legacy_temp_path = os.path.join(gru_run_models_dir, f"calibration_temp_{gru_run_id}.npy") + if os.path.exists(legacy_temp_path): + logger.warning("Calibration info file missing, but found legacy temperature file. Attempting to load it.") + try: + calib_params = float(np.load(legacy_temp_path)) + calib_method = "temperature" # Assume temperature + logger.info(f"Loaded legacy optimal temperature T={calib_params:.4f} from {legacy_temp_path}") + except Exception as legacy_e: + logger.error(f"Failed to load legacy temperature file {legacy_temp_path}: {legacy_e}") + return None # Fail if legacy load fails + else: + logger.error("Neither calibration info file nor legacy temperature file found.") + return None # Fail if no calibration info found + except Exception as e: + logger.error(f"Failed to load calibration info/parameters: {e}", exc_info=True) + return None + + # Store method and params in dependencies + dependencies['calibration_method'] = calib_method + dependencies['calibration_params'] = calib_params + + logger.info("--- Successfully loaded all GRU dependencies ---") + return dependencies + + def _prepare_data_for_sac(self, gru_dependencies: dict) -> tuple | None: + """ + Replicates the necessary data loading and preparation steps using the loaded + GRU dependencies, specifically focusing on the VALIDATION dataset to + generate inputs for the TradingEnv. + + Args: + gru_dependencies (dict): The dictionary returned by _load_gru_dependencies. + + Returns: + tuple | None: A tuple containing (mu_val, sigma_val, p_cal_val, actual_ret_val) + for the validation set, or None on failure. + """ + logger.info("--- Preparing Validation Data for SAC Environment --- ") + try: + # 1. Load Raw Data (using a temporary DataLoader) + temp_data_loader = DataLoader(db_dir=self.data_cfg['db_dir']) + df_raw = temp_data_loader.load_data( + ticker=self.data_cfg['ticker'], + exchange=self.data_cfg['exchange'], + start_date=self.data_cfg['start_date'], + end_date=self.data_cfg['end_date'], + interval=self.data_cfg['interval'] + ) + if df_raw is None or df_raw.empty: raise ValueError("Raw data loading failed") + df_raw.dropna(subset=['open', 'high', 'low', 'close', 'volume'], inplace=True) + logger.info("Loaded raw data.") + + # 2. Engineer Base Features (using config) + # FIX: Instantiate FeatureEngineer correctly using self.config + temp_feature_engineer = FeatureEngineer(config=self.config) + df_engineered = temp_feature_engineer.add_base_features(df_raw) + df_engineered.dropna(inplace=True) # Drop NaNs after feature eng + if df_engineered.empty: raise ValueError("Dataframe empty after feature engineering.") + logger.info("Engineered base features.") + + # 3. Define Labels (on the full engineered data) + horizon = self.config['gru'].get('prediction_horizon', 5) + target_ret_col = f'fwd_log_ret_{horizon}' + target_dir_col = f'direction_label3_{horizon}' if self.use_ternary else f'direction_label_{horizon}' + _EPS = 1e-9 + if 'future_close' not in df_engineered.columns: + df_engineered['future_close'] = df_engineered['close'].shift(-horizon) + if 'future_close' in df_engineered.columns: + df_engineered[target_ret_col] = np.log(df_engineered['future_close'] / (df_engineered['close'] + _EPS)) + else: + df_engineered[target_ret_col] = np.log(df_engineered['close'].shift(-horizon) / (df_engineered['close'] + _EPS)) + + if self.use_ternary: + flat_sigma = self.config.get('gru', {}).get('flat_sigma_multiplier', 0.3) + if 'ATR_14' in df_engineered.columns: + atr_aligned = df_engineered['ATR_14'].reindex(df_engineered.index).bfill().ffill() + threshold = flat_sigma * atr_aligned / (df_engineered['close'] + _EPS) + df_engineered[target_dir_col] = np.select( + [df_engineered[target_ret_col] > threshold, df_engineered[target_ret_col] < -threshold], + [2, 0], default=1 + ) + else: + raise ValueError("ATR_14 needed for ternary labels not found.") + else: # Binary case + df_engineered[target_dir_col] = (df_engineered[target_ret_col] > 0).astype(int) + + # 4. Align Features and Targets + df_engineered.dropna(subset=[target_ret_col, target_dir_col], inplace=True) + # Identify *all* potential feature columns (excluding targets) + potential_feature_cols = [col for col in df_engineered.columns if col not in [target_ret_col, target_dir_col, 'future_close']] + # Ensure whitelist features are present in potential features + loaded_whitelist = gru_dependencies['whitelist'] + missing_whitelist_check = set(loaded_whitelist) - set(potential_feature_cols) + if missing_whitelist_check: + logger.warning(f"Whitelist features missing from engineered columns: {missing_whitelist_check}. Adding with 0 fill.") + for col in missing_whitelist_check: + df_engineered[col] = 0.0 + potential_feature_cols.append(col) + # Now df_engineered contains all potential features and targets, aligned + logger.info("Defined labels and aligned features/targets.") + + # 5. Split Data (using aligned engineered data) + wf_enabled = self.config.get('walk_forward', {}).get('enabled', False) + if wf_enabled: logger.warning("Walk-forward enabled, but SAC uses split_ratios.") + split_ratios = self.config.get('walk_forward', {}).get('split_ratios', {}) + train_ratio = split_ratios.get('train', 0.6); val_ratio = split_ratios.get('validation', 0.2) + if not (0 < train_ratio < 1 and 0 < val_ratio < 1 and (train_ratio + val_ratio) <= 1.0): + raise ValueError(f"Invalid split ratios: train={train_ratio}, validation={val_ratio}") + total_len = len(df_engineered) + train_end_idx = int(total_len * train_ratio) + val_end_idx = int(total_len * (train_ratio + val_ratio)) + val_indices = df_engineered.index[train_end_idx:val_end_idx] + if val_indices.empty: raise ValueError("Validation split resulted in empty indices.") + df_val_aligned = df_engineered.loc[val_indices] + + # -- Determine columns expected by scaler -- + scaler = gru_dependencies['scaler'] + expected_scaler_features = [] + if hasattr(scaler, 'feature_names_in_'): + expected_scaler_features = scaler.feature_names_in_.tolist() + logger.debug(f"Scaler expects features: {expected_scaler_features}") + else: + # Fallback: Use all numeric columns if scaler has no names + logger.warning("Scaler has no feature names saved. Assuming it was fit on all numeric columns present at that time.") + # This is risky; need to ensure columns match implicitly + expected_scaler_features = df_val_aligned.select_dtypes(include=np.number).columns.tolist() + # Manually exclude known target columns if needed as a safeguard + expected_scaler_features = [f for f in expected_scaler_features if f not in [target_ret_col, target_dir_col]] + logger.debug(f"Falling back to using numeric columns for scaler: {expected_scaler_features}") + + # Ensure all expected scaler features exist in the validation slice + missing_for_scaler = set(expected_scaler_features) - set(df_val_aligned.columns) + if missing_for_scaler: + # Attempt to add missing with 0 fill (e.g., if a feature wasn't calculable for val split) + logger.warning(f"Columns expected by scaler missing from validation data: {missing_for_scaler}. Adding with 0 fill.") + for col in missing_for_scaler: + df_val_aligned[col] = 0.0 + # Re-verify + missing_for_scaler = set(expected_scaler_features) - set(df_val_aligned.columns) + if missing_for_scaler: + raise ValueError(f"Could not prepare all features expected by scaler, still missing: {missing_for_scaler}") + + # Isolate the exact features needed for the scaler transform + X_val_engineered_for_scaler = df_val_aligned[expected_scaler_features] + y_val = df_val_aligned[[target_ret_col, target_dir_col]] + logger.info(f"Isolated validation set data for scaler (Features: {X_val_engineered_for_scaler.shape}, Targets: {y_val.shape}).") + + # 6. Scale the features expected by the scaler + X_val_scaled_full = X_val_engineered_for_scaler.copy() + try: + X_val_scaled_full[expected_scaler_features] = scaler.transform(X_val_engineered_for_scaler) + logger.info("Scaled validation features using loaded scaler.") + except ValueError as e: + logger.error(f"Error applying scaler transform even after aligning columns: {e}") + raise + + # 7. WORKAROUND: Drop non-feature columns (like future_close) *after* scaling + # Identify actual features intended for the model (whitelist) + loaded_whitelist = gru_dependencies['whitelist'] + cols_to_keep_after_scale = [f for f in loaded_whitelist if f in X_val_scaled_full.columns] + missing_whitelist_post_scale = set(loaded_whitelist) - set(cols_to_keep_after_scale) + if missing_whitelist_post_scale: + # This shouldn't happen if whitelist features were in expected_scaler_features + logger.warning(f"Whitelist features missing AFTER scaling: {missing_whitelist_post_scale}. This might indicate issues.") + + if not cols_to_keep_after_scale: + raise ValueError("Whitelist resulted in no features remaining after scaling step.") + + X_val_scaled_eng = X_val_scaled_full[cols_to_keep_after_scale].copy() + logger.info(f"Selected whitelisted features after scaling. Shape: {X_val_scaled_eng.shape}") + + # 8. Pruning step is now effectively done by selecting whitelist columns above. + X_val_pruned_scaled = X_val_scaled_eng # Rename for clarity in subsequent steps + logger.info(f"Using whitelisted scaled features for sequencing. Final shape: {X_val_pruned_scaled.shape}") + + # 9. Create Validation Sequences (using the pruned & scaled data) + lookback = self.config['gru']['lookback'] + X_val_seq_list, y_val_seq_targets_list, val_seq_indices = [], [], [] + features_np_arr = X_val_pruned_scaled.values + targets_np_arr = y_val.values + for i in range(lookback, len(features_np_arr)): + X_val_seq_list.append(features_np_arr[i-lookback : i]) + y_val_seq_targets_list.append(targets_np_arr[i]) + val_seq_indices.append(y_val.index[i]) + + if not X_val_seq_list: raise ValueError("Validation sequence creation resulted in empty list.") + X_val_seq = np.array(X_val_seq_list) + y_val_seq_targets = np.array(y_val_seq_targets_list) + actual_ret_val_seq = y_val_seq_targets[:, 0] + y_dir_val_seq = y_val_seq_targets[:, 1] + logger.info(f"Created validation sequences (X shape: {X_val_seq.shape}).") + + # 10. Get GRU Predictions using loaded GRU model directly + gru_model = gru_dependencies['gru_model'] + logger.info(f"Generating GRU predictions using loaded model (type: {type(gru_model)}).") + + # Check model type and predict accordingly + if not hasattr(gru_model, 'predict'): + raise TypeError("Loaded GRU model object does not have a 'predict' method.") + + predictions_val = gru_model.predict(X_val_seq) + + # --- Explicit Shape Logging --- # + if isinstance(predictions_val, list): + logger.info(f"DEBUG: GRU prediction output list length: {len(predictions_val)}") + if len(predictions_val) > 0: logger.info(f"DEBUG: Shape of predictions_val[0]: {predictions_val[0].shape}") + if len(predictions_val) > 1: logger.info(f"DEBUG: Shape of predictions_val[1]: {predictions_val[1].shape}") + if len(predictions_val) > 2: logger.info(f"DEBUG: Shape of predictions_val[2]: {predictions_val[2].shape}") + else: + logger.info(f"DEBUG: GRU prediction output type: {type(predictions_val)}") + # --- End Explicit Shape Logging --- # + + # --- Infer output structure (Corrected Indexing) --- + mu_val_pred, sigma_val_pred, p_raw_val_pred, logits_val_pred = None, None, None, None + + if isinstance(predictions_val, list) and len(predictions_val) == 2: + logger.debug(f"GRU model returned list of {len(predictions_val)} outputs. Correcting index assumption: [mu, dir].") + # Corrected Indexing based on logs: + mu_output = predictions_val[0] # Index 0 has shape (N, 1) + dir_output = predictions_val[1] # Index 1 has shape (N, 3) + + # Extract Mu + if mu_output.ndim == 2 and mu_output.shape[-1] == 1: + mu_val_pred = mu_output.flatten() + else: + # Corrected error message to reflect index 0 check + raise ValueError(f"Unexpected shape for mu output (index 0): {mu_output.shape}. Expected (N, 1).") + + # Handle missing Sigma - Use a default/fallback + logger.warning("Log_sigma_sq output not found in model prediction. Using default sigma=0.1.") + sigma_val_pred = np.ones_like(mu_val_pred) * 0.1 + + # Determine direction output type from dir_output (index 1) + if self.use_ternary: + if dir_output.ndim == 2 and dir_output.shape[-1] == 3: + # Assume dir_output (index 1) is logits if ternary + logits_val_pred = dir_output + from scipy.special import softmax # Local import ok here + p_raw_val_pred = softmax(logits_val_pred, axis=-1) + logger.info("Inferred ternary output (logits assumed at index 1). Calculated raw probabilities.") + else: + raise ValueError(f"Expected ternary output shape (N, 3) at index 1, got {dir_output.shape}") + else: # Binary + # Binary case is tricky if dir_output (index 1) is still (N, 3) + # This shouldn't happen if the model structure changes for binary + # Let's assume the mu_output (index 0) might be P(up) in binary? + logger.warning(f"Binary mode, dir_output shape is {dir_output.shape}. Trying to use mu output (index 0) as P(up) - THIS IS RISKY.") + if mu_output.ndim == 2 and mu_output.shape[-1] == 1: + p_raw_val_pred = mu_output.flatten() + epsilon = 1e-7 + p_clipped = np.clip(p_raw_val_pred, epsilon, 1 - epsilon) + logits_val_pred = np.log(p_clipped / (1 - p_clipped)) + logger.info("Using mu output (index 0) as P(up) for binary mode. Inferred raw logits.") + else: + raise ValueError(f"Cannot determine binary P(up) output. mu_output shape: {mu_output.shape}") + else: + raise ValueError(f"GRU prediction failed or returned unexpected type/structure: {type(predictions_val)}, length: {len(predictions_val) if isinstance(predictions_val, list) else 'N/A'}. Expected list of 2 arrays.") + + # Final check for necessary predictions + if mu_val_pred is None or sigma_val_pred is None or p_raw_val_pred is None: + raise ValueError("Failed to extract mu, sigma, or raw probabilities from GRU model output.") + + # Verify lengths + n_seq = len(X_val_seq) + if not (len(mu_val_pred) == n_seq and len(sigma_val_pred) == n_seq and \ + p_raw_val_pred.shape[0] == n_seq and len(actual_ret_val_seq) == n_seq): + raise ValueError(f"Length mismatch after validation predictions: Expected {n_seq}, " + f"got mu={len(mu_val_pred)}, sigma={len(sigma_val_pred)}, " + f"p_raw={p_raw_val_pred.shape[0]}, ret={len(actual_ret_val_seq)}") + + # 11. Calibrate Predictions using loaded parameters + calib_method = gru_dependencies.get('calibration_method') + calib_params = gru_dependencies.get('calibration_params') + p_cal_val_pred = None + + if calib_method == "temperature" and calib_params is not None: + optimal_T = calib_params + # Temperature scaling needs logits + if logits_val_pred is not None: + scaled_logits = logits_val_pred / optimal_T + if self.use_ternary: + p_cal_val_pred = softmax(scaled_logits, axis=-1) + logger.info(f"Applied temperature scaling (T={optimal_T:.4f}) to ternary logits.") + else: # Binary - apply sigmoid + p_cal_val_pred = 1 / (1 + np.exp(-scaled_logits)) # Sigmoid + logger.info(f"Applied temperature scaling (T={optimal_T:.4f}) to binary logits.") + else: + logger.error(f"Cannot apply temperature scaling (method={calib_method}): Raw logits not available.") + p_cal_val_pred = p_raw_val_pred # Fallback to raw + + elif calib_method == "vector" and calib_params is not None: + # Vector calibration needs logits + if logits_val_pred is not None: + try: + from gru_sac_predictor.src.calibrator_vector import VectorCalibrator # Ensure import + temp_vector_calibrator = VectorCalibrator() + temp_vector_calibrator.optimal_params = calib_params # Set loaded params + # Calibrate expects logits, returns probabilities + p_cal_val_pred = temp_vector_calibrator.calibrate(logits_val_pred) + logger.info(f"Calibrated validation predictions using loaded vector parameters.") + except ImportError: + logger.error("Cannot import VectorCalibrator. Vector calibration step skipped.") + p_cal_val_pred = p_raw_val_pred # Fallback to raw + except Exception as vec_e: + logger.error(f"Error during vector calibration application: {vec_e}", exc_info=True) + p_cal_val_pred = p_raw_val_pred # Fallback to raw + else: + logger.error(f"Cannot apply vector calibration (method={calib_method}): Raw logits not available.") + p_cal_val_pred = p_raw_val_pred # Fallback to raw + + else: # No calibration or failed + logger.warning(f"Calibration method was '{calib_method}' or params were None or previous step failed. Using RAW predictions for SAC environment.") + p_cal_val_pred = p_raw_val_pred # Use the raw probs (softmaxed for ternary, P(up) for binary) + + # Final check for calibrated probabilities + if p_cal_val_pred is None: + logger.error("Failed to obtain final (calibrated or raw) predictions. Cannot proceed.") + return None + # Verify shape of final probabilities + expected_final_dim = 3 if self.use_ternary else 1 + if p_cal_val_pred.ndim == 2 and p_cal_val_pred.shape[-1] == expected_final_dim: + if not self.use_ternary: p_cal_val_pred = p_cal_val_pred.flatten() # Flatten binary case + elif p_cal_val_pred.ndim == 1 and not self.use_ternary and expected_final_dim == 1: + pass # Already flat for binary + else: + logger.error(f"Final probability array has unexpected shape: {p_cal_val_pred.shape}. Expected dim {expected_final_dim}.") + return None + + + # 12. Return the necessary components for the TradingEnv + logger.info("--- Successfully prepared validation data for SAC Environment ---") + return mu_val_pred, sigma_val_pred, p_cal_val_pred, actual_ret_val_seq + + except Exception as e: + logger.error(f"Error preparing data for SAC environment: {e}", exc_info=True) + return None + + def _load_agent_for_resume(self, agent: SACTradingAgent) -> None: + """Loads agent weights and state filter status if resuming.""" + load_run_id = self.control_cfg.get('sac_resume_run_id') + load_step = self.control_cfg.get('sac_resume_step', 'final') + current_edge_threshold = self.config.get('calibration', {}).get('edge_threshold', 0.55) + + if not load_run_id: + logger.info("No SAC resume run ID specified. Starting training from scratch.") + return + + # Construct path relative to base models dir + if load_step == 'final': + load_path = os.path.join(self.base_models_dir, f"{load_run_id}", 'sac_agent_final') + else: + try: + step_num = int(load_step) + load_path = os.path.join(self.base_models_dir, f"{load_run_id}", f'sac_agent_step_{step_num}') + except ValueError: + logger.error(f"Invalid sac_resume_step: {load_step}. Must be 'final' or an integer. Starting fresh.") + return + + logger.info(f"Attempting to load SAC agent from {load_path} to resume training...") + if os.path.exists(load_path): + try: + loaded_meta = agent.load(load_path) + # Check for Buffer Purge on Load + saved_edge_thr = loaded_meta.get('edge_threshold') + if saved_edge_thr is not None and abs(saved_edge_thr - current_edge_threshold) > 1e-6: + logger.warning(f'Edge threshold mismatch on load (Saved={saved_edge_thr:.3f}, Current={current_edge_threshold:.3f}). Clearing replay buffer before resuming.') + agent.clear_buffer() + elif saved_edge_thr is None: + logger.warning("Loaded SAC agent metadata did not contain 'edge_threshold'. Cannot verify consistency.") + else: + logger.info('Edge threshold consistent with loaded agent metadata.') + + # --- Load State Filter (Task 5.2) --- # + filter_state_path = os.path.join(load_path, 'state_filter.npz') + if self.state_filter is not None and os.path.exists(filter_state_path): + try: + with np.load(filter_state_path) as data: + filter_state = {key: data[key] for key in data.files} + self.state_filter.set_state(filter_state) + logger.info(f"Loaded state filter state from {filter_state_path}") + except Exception as filter_e: + logger.error(f"Failed to load state filter state from {filter_state_path}: {filter_e}. Filter will be reset.") + elif self.state_filter is not None: + logger.warning(f"State filter state file not found at {filter_state_path}. Filter will be reset.") + # --- End Load State Filter --- # + + except Exception as e: + logger.error(f"Failed to load SAC agent for resume: {e}. Starting fresh.", exc_info=True) + else: + logger.warning(f"SAC agent path not found for resume: {load_path}. Starting fresh.") + + def _training_loop(self, agent: SACTradingAgent, env: TradingEnv) -> str | None: + """The main SAC training loop.""" + total_steps = self.sac_cfg.get('total_training_steps', 100000) + start_steps = self.sac_cfg.get('start_steps', 10000) + update_after = self.sac_cfg.get('update_after', 1000) + update_every = self.sac_cfg.get('update_every', 50) + save_freq = self.sac_cfg.get('save_freq', 5000) + log_freq = self.sac_cfg.get('log_freq', 100) + buffer_capacity = self.sac_cfg.get('buffer_capacity', 1000000) + batch_size = self.sac_cfg.get('batch_size', 256) + + # Initialize Replay Buffer (Standard or Prioritized) + if self.use_per: + logger.info(f"Using Prioritized Replay Buffer (Capacity: {buffer_capacity})") + replay_buffer = PrioritizedReplayBuffer( + capacity=buffer_capacity, + alpha=self.per_alpha, # Initial alpha + beta_start=self.per_beta_start, + beta_frames=self.per_beta_frames + ) + else: + logger.info(f"Using Standard Replay Buffer (Capacity: {buffer_capacity})") + replay_buffer = collections.deque(maxlen=buffer_capacity) + + # TensorBoard setup + tb_callback = TensorBoard(log_dir=self.sac_tb_log_dir) + # --- Revision 4: Set model for TensorBoard --- # + # Check if agent has actor/critic models accessible + # This depends heavily on SACTradingAgent implementation + # Assuming agent.actor and agent.critic1/2 are the models + if hasattr(agent, 'actor') and hasattr(agent, 'critic1') and hasattr(agent, 'critic2'): + tb_callback.set_model(agent.actor) # Link to one model is often sufficient + logger.info("TensorBoard callback linked to SAC agent model (actor).") + else: + logger.warning("Could not link TensorBoard callback to agent models (actor/critic not found).") + # --- End Revision 4 --- + + # --- Initialize optional imputed transition logger --- # + imputed_log_path = os.path.join(self.sac_run_results_dir, 'sac_imputed_transitions.csv') + imputed_log_file = None + imputed_csv_writer = None + try: + imputed_log_file = open(imputed_log_path, 'w', newline='') + imputed_csv_writer = csv.writer(imputed_log_file) + imputed_csv_writer.writerow(['step', 'imputed_handling_mode', 'action', 'reward', 'position_before', 'position_after']) + logger.info(f"Logging imputed transitions details to: {imputed_log_path}") + except IOError as e: + logger.error(f"Failed to open imputed transition log file {imputed_log_path}: {e}. Logging disabled.") + if imputed_log_file: imputed_log_file.close() + imputed_log_file = None + imputed_csv_writer = None + # --- End init imputed logger --- # + + state = env.reset() + # Normalize initial state if filter is active + if self.state_filter: + state = self.state_filter(state, update=True) # Update filter with initial state + + total_reward = 0.0 + start_time = time.time() + + logger.info(f"Starting SAC training loop for {total_steps} steps...") + for step in tqdm(range(total_steps), desc="SAC Training Steps"): + # Store state before taking action (for replay buffer) + state_before_action = state + position_before_action = env.current_position # Store for imputed log + + # Select action + if step < start_steps: + action = env.action_space.sample()[0] # Sample random action + else: + action = agent.select_action(state) + + # Step the environment + next_state_raw, reward, done, info = env.step(action) + + # Check if the step was skipped due to imputed bar handling + is_skipped = info.get('is_imputed_step_skipped', False) + + # Normalize next state if filter is active + if self.state_filter: + next_state = self.state_filter(next_state_raw, update=True) # Update filter + else: + next_state = next_state_raw + + # Store transition ONLY if the step was NOT skipped + if not is_skipped: + if self.use_per: + # Add with initial error=1 (or max priority) and is_seeded=False + # The error will be updated after the first training step on this sample. + replay_buffer.add(error=1.0, sample=(state_before_action, action, reward, next_state, done)) + else: + replay_buffer.append((state_before_action, action, reward, next_state, done)) + else: + logger.debug(f"Step {step}: Skipped adding transition to buffer due to imputed bar handling (mode=skip).") + + # --- Log imputed transition details to CSV --- # + # Check if the *previous* step was imputed and handled by hold/penalty + # env.current_step was incremented *inside* env.step() + was_imputed_idx = env.current_step - 1 + if 0 <= was_imputed_idx < env.n_steps and env.bar_imputed[was_imputed_idx] and not is_skipped and imputed_csv_writer: + # Don't log if skipped, log if hold/penalty was applied + imputed_handling_mode = self.config.sac.get('imputed_handling', 'unknown') + action_taken = env.current_position if imputed_handling_mode == 'hold' else action # Action applied or agent's intended action + position_after_action = env.current_position + try: + imputed_csv_writer.writerow([ + was_imputed_idx, # Log the actual step number where imputation occurred + imputed_handling_mode, + f"{action_taken:.4f}", + f"{reward:.6f}", # Log the reward received for this step + f"{position_before_action:.4f}", + f"{position_after_action:.4f}" + ]) + except Exception as log_e: + logger.warning(f"Failed to write imputed transition to CSV: {log_e}") + # --- End Log imputed transition --- # + + state = next_state + total_reward += reward + + # Perform SAC agent updates + if step >= update_after and step % update_every == 0: + for update_i in range(update_every): # Perform multiple updates per interval + if self.use_per: + if len(replay_buffer) > batch_size: + idxs, batch_data, is_weights = replay_buffer.sample(batch_size, beta=replay_buffer.beta) # Use annealed beta + # Unpack batch_data which contains (sample, is_seeded) tuples + batch = [item[0] for item in batch_data] + update_info = agent.update(batch, is_weights=is_weights, per_beta=replay_buffer.beta) + if update_info and 'td_errors' in update_info: + # Update priorities using current annealed alpha + current_alpha = agent.get_current_per_alpha(step) + replay_buffer.update_priorities(idxs, update_info['td_errors'], alpha=current_alpha) + else: + continue # Not enough samples yet for PER + else: # Standard buffer + if len(replay_buffer) > batch_size: + indices = np.random.choice(len(replay_buffer), size=batch_size, replace=False) + batch = [replay_buffer[i] for i in indices] + update_info = agent.update(batch) + else: + continue # Not enough samples yet + + # Log training metrics (losses, Q-values, alpha) to TensorBoard + if update_info and step % log_freq == 0 and update_i == 0: # Log once per interval + with tb_callback.writer.as_default(): + for key, value in update_info.items(): + if key != 'td_errors': # Don't log TD errors directly + tf.summary.scalar(f'sac/{key}', value, step=step) + # logger.debug(f"Step {step}: Logged SAC metrics to TensorBoard.") + + # Check environment done state + if done: + state = env.reset() + # Normalize reset state + if self.state_filter: + state = self.state_filter(state, update=True) + total_reward = 0.0 # Reset episodic reward + + # Save agent checkpoint periodically + if step % save_freq == 0 and step > 0: + save_path = os.path.join(self.sac_run_models_dir, f'sac_agent_step_{step}') + try: + agent.save(save_path) + # Also save state filter if used + if self.state_filter: + filter_path = os.path.join(save_path, 'state_filter.pkl') + joblib.dump(self.state_filter, filter_path) + logger.info(f"Saved state filter to {filter_path}") + except Exception as e: + logger.error(f"Failed to save agent/filter checkpoint at step {step} to {save_path}: {e}", exc_info=True) + + # --- Final Save --- # + final_save_path = os.path.join(self.sac_run_models_dir, 'sac_agent_final') + try: + agent.save(final_save_path) + if self.state_filter: + filter_path = os.path.join(final_save_path, 'state_filter.pkl') + joblib.dump(self.state_filter, filter_path) + logger.info(f"Saved final state filter to {filter_path}") + self.last_saved_agent_path = final_save_path # Store path for potential return + except Exception as e: + logger.error(f"Failed to save final agent/filter checkpoint to {final_save_path}: {e}", exc_info=True) + self.last_saved_agent_path = None + + end_time = time.time() + training_duration = end_time - start_time + logger.info(f"SAC training loop finished in {training_duration:.2f} seconds.") + logger.info(f"Final agent checkpoint saved to: {self.last_saved_agent_path}") + + # --- Close imputed transition log file --- # + if imputed_log_file: + try: + imputed_log_file.close() + logger.info(f"Closed imputed transition log file: {imputed_log_path}") + except Exception as e: + logger.error(f"Error closing imputed transition log file: {e}") + # --- End close log file --- # + + return self.last_saved_agent_path + + def train(self, gru_run_id_for_sac: str) -> str | None: + """ + Main entry point to start the SAC training process. + + Args: + gru_run_id_for_sac (str): The run ID of the GRU pipeline run whose artifacts should be used. + + Returns: + str | None: Path to the final saved SAC agent, or None if training failed. + """ + logger.info(f"=== Starting SAC Training Process (SAC Run ID: {self.sac_train_run_id}) ===") + logger.info(f"Using artifacts from GRU Run ID: {gru_run_id_for_sac}") + + # 1. Load GRU dependencies + gru_dependencies = self._load_gru_dependencies(gru_run_id_for_sac) + if gru_dependencies is None: + logger.error("Failed to load GRU dependencies. Aborting SAC training.") + return None + + # 2. Prepare data for SAC environment (using validation set) + env_data = self._prepare_data_for_sac(gru_dependencies) + if env_data is None: + logger.error("Failed to prepare data for SAC environment. Aborting SAC training.") + return None + mu_val, sigma_val, p_cal_val, actual_ret_val = env_data + + # 3. Initialize Environment + logger.info("Initializing Trading Environment...") + env = TradingEnv( + mu_predictions=mu_val, + sigma_predictions=sigma_val, + p_cal_predictions=p_cal_val, + actual_returns=actual_ret_val, + initial_capital=self.env_cfg.get('initial_capital', 10000.0), + transaction_cost=self.env_cfg.get('transaction_cost', 0.0005) + ) + logger.info(f"TradingEnv initialized with {env.n_steps} steps.") + + # 4. Initialize SAC Agent + logger.info("Initializing SAC Agent...") + current_edge_threshold = self.config.get('calibration', {}).get('edge_threshold', 0.55) + # --- Get Env Params for Agent Metadata (Task 5.6) --- # + reward_scale = self.env_cfg.get('reward_scale', 100.0) # Default from TradingEnv + action_penalty_lambda = self.env_cfg.get('action_penalty_lambda', 0.0) # Default from TradingEnv + # --- End Get Env Params --- # + agent = SACTradingAgent( + state_dim=env.state_dim, + action_dim=env.action_dim, + gamma=self.sac_cfg.get('gamma', 0.99), + tau=self.sac_cfg.get('tau', 0.005), + initial_lr=self.sac_cfg.get('actor_lr', 3e-4), + lr_decay_rate=self.sac_cfg.get('lr_decay_rate', 0.96), + decay_steps=self.sac_cfg.get('decay_steps', 100000), + ou_noise_stddev=self.sac_cfg.get('ou_noise_stddev', 0.2), + alpha=self.sac_cfg.get('alpha', 0.2), + alpha_auto_tune=self.sac_cfg.get('alpha_auto_tune', True), + target_entropy=self.sac_cfg.get('target_entropy', -1.0 * env.action_dim), + edge_threshold_config=current_edge_threshold, # Pass edge threshold + # --- Pass Env Params (Task 5.6) --- # + reward_scale_config=reward_scale, + action_penalty_lambda_config=action_penalty_lambda + # --- End Pass Env Params --- # + ) + logger.info("SAC Agent initialized.") + + # 5. Load agent weights if resuming + self._load_agent_for_resume(agent) + + # --- Add Training Size Assertion (Step 4-C) --- # + total_steps_cfg = self.sac_cfg.get('total_training_steps', 100000) + val_len = len(actual_ret_val) # Length of the validation dataset used by env + min_required_steps = int(0.4 * val_len) + assert total_steps_cfg >= min_required_steps, \ + f"Configured total_training_steps ({total_steps_cfg}) is less than 40% of the validation set size ({val_len}). Minimum required: {min_required_steps}. Consider increasing training steps." + logger.info(f"Training size check passed: {total_steps_cfg=}, {min_required_steps=}") + # --- End Assertion --- # + + # 6. Run training loop + final_agent_path = self._training_loop(agent, env) + + if final_agent_path: + logger.info(f"=== SAC Training Process Completed Successfully ===") + else: + logger.error("=== SAC Training Process Failed ===") + + return final_agent_path \ No newline at end of file diff --git a/gru_sac_predictor/src/trading_env.py b/gru_sac_predictor/src/trading_env.py new file mode 100644 index 00000000..4582baf4 --- /dev/null +++ b/gru_sac_predictor/src/trading_env.py @@ -0,0 +1,230 @@ +""" +Simplified Trading Environment for SAC Training. + +Uses pre-calculated GRU predictions (mu, sigma, p_cal) and actual returns. +""" +import numpy as np +import pandas as pd +import logging +import gymnasium as gym +from omegaconf import DictConfig # Added for config typing + +env_logger = logging.getLogger(__name__) + +class TradingEnv: + def __init__(self, + mu_predictions: np.ndarray, + sigma_predictions: np.ndarray, + p_cal_predictions: np.ndarray, + actual_returns: np.ndarray, + bar_imputed_flags: np.ndarray, # Added imputed flags + config: DictConfig, # Added config + initial_capital: float = 10000.0, + transaction_cost: float = 0.0005, + reward_scale: float = 100.0, + action_penalty_lambda: float = 0.0): + """ + Initialize the environment. + + Args: + mu_predictions: Predicted log returns (μ̂). + sigma_predictions: Predicted volatility (σ̂ = exp(log σ̂)). + p_cal_predictions: Calibrated probability of price increase (p_cal). + actual_returns: Actual log returns (y_ret). + bar_imputed_flags: Boolean array indicating if a bar was imputed. + config: OmegaConf configuration object. + initial_capital: Starting capital for simulation (used notionally in reward). + transaction_cost: Fractional cost per trade. + reward_scale: Multiplier for the reward signal. + action_penalty_lambda: Coefficient for the action magnitude penalty (λ). + """ + assert len(mu_predictions) == len(sigma_predictions) == len(p_cal_predictions) == len(actual_returns) == len(bar_imputed_flags), \ + "All input arrays (predictions, returns, imputed_flags) must have the same length" + + self.mu = mu_predictions + self.sigma = sigma_predictions + self.p_cal = p_cal_predictions + self.actual_returns = actual_returns + self.bar_imputed = bar_imputed_flags.astype(bool) # Store imputed flags + self.config = config # Store config + + self.initial_capital = initial_capital + self.transaction_cost = transaction_cost + self.reward_scale = reward_scale + self.action_penalty_lambda = action_penalty_lambda + + # --- Revision 5: Calculate action penalty based on transaction cost --- # + if self.transaction_cost > 0: + # Override passed lambda if cost is positive + self._internal_action_penalty_lambda = 0.01 / self.transaction_cost + env_logger.info(f"Using calculated action penalty lambda: {self._internal_action_penalty_lambda:.4f} (0.01 / {self.transaction_cost})") + else: + # Use passed value if cost is zero (or default to 0) + self._internal_action_penalty_lambda = self.action_penalty_lambda + env_logger.info(f"Using provided action penalty lambda: {self._internal_action_penalty_lambda:.4f}") + # --- End Revision 5 --- # + + self.n_steps = len(actual_returns) + self.current_step = 0 + self.current_position = 0.0 # Fraction of capital (-1 to 1) + self.current_capital = initial_capital # Track for info, not used in reward directly + + # State dimension: [mu, sigma, edge, |mu|/sigma, position] + self.state_dim = 5 + self.action_dim = 1 + + # --- Define Gym Spaces --- + self.action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(self.action_dim,), dtype=np.float32) + self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.state_dim,), dtype=np.float32) + # --- End Define Gym Spaces --- + + env_logger.info(f"TradingEnv initialized with {self.n_steps} steps.") + + def _get_state(self) -> np.ndarray: + """Construct the state vector for the current step.""" + if self.current_step >= self.n_steps: + return np.zeros(self.state_dim, dtype=np.float32) + + mu_t = self.mu[self.current_step] + sigma_t = self.sigma[self.current_step] + p_cal_t = self.p_cal[self.current_step] + + # Calculate edge based on p_cal shape (binary vs ternary) + if isinstance(p_cal_t, (np.ndarray, list)) and len(p_cal_t) == 3: + # Ternary: edge = max(P(up), P(down)) - P(flat) + # Assuming order [Down, Flat, Up] for p_cal_t + edge_t = max(p_cal_t[2], p_cal_t[0]) - p_cal_t[1] + elif isinstance(p_cal_t, (float, np.number)): + # Binary: edge = 2 * P(up) - 1 + edge_t = 2 * p_cal_t - 1 + else: + env_logger.error(f"Unexpected type/shape for p_cal_t at step {self.current_step}: {p_cal_t}. Using edge=0.") + edge_t = 0.0 + + _EPS = 1e-9 # Define epsilon locally + z_score_t = np.abs(mu_t) / (sigma_t + _EPS) + + # State uses position *before* the action for this step is taken + state = np.array([ + mu_t, + sigma_t, + edge_t, + z_score_t, + self.current_position + ], dtype=np.float32) + return state + + def reset(self) -> np.ndarray: + """Reset the environment to the beginning.""" + self.current_step = 0 + self.current_position = 0.0 + self.current_capital = self.initial_capital + env_logger.debug("Environment reset.") + return self._get_state() + + def step(self, action: float) -> tuple[np.ndarray, float, bool, dict]: + """ + Execute one time step. + + Args: + action (float): Agent's desired position size (-1 to 1). + + Returns: + tuple: (next_state, reward, done, info_dict) + """ + info = {'capital': self.current_capital, 'position': self.current_position, 'is_imputed_step_skipped': False} + + if self.current_step >= self.n_steps: + # Should not happen if 'done' is handled correctly, but as safeguard + env_logger.warning("Step called after environment finished.") + return self._get_state(), 0.0, True, info + + # --- Handle Imputed Bar --- # + imputed = self.bar_imputed[self.current_step] + if imputed: + mode = self.config.sac.imputed_handling + env_logger.debug(f"SAC step {self.current_step} on imputed bar: handling={mode}") + if mode == "skip": + self.current_step += 1 + next_state = self._get_state() # Get state for the *next* actual step + # Return 0 reward, not done, but indicate skip for buffer handling + info['is_imputed_step_skipped'] = True + return next_state, 0.0, False, info + elif mode == "hold": + # Action is forced to maintain current position + action = self.current_position + elif mode == "penalty": + # Calculate reward penalty based on config + target_position_penalty = np.clip(action, -1.0, 1.0) + reward = -self.config.sac.action_penalty * (target_position_penalty - self.current_position)**2 + # Update position based on agent's intended action (clipped) + self.current_position = target_position_penalty + # Update capital notionally (no actual return, only cost if implemented) + # Cost is implicitly 0 here as there's no trade size if pos doesn't change + # If penalty mode allowed position change, cost would apply. + # For simplicity, we don't add cost here for the penalty step. + self.current_step += 1 + next_state = self._get_state() + scaled_reward = reward * self.reward_scale # Scale the penalty + done = self.current_step >= self.n_steps + info['capital'] = self.current_capital + info['position'] = self.current_position + return next_state, scaled_reward, done, info + # else: default behavior (treat as normal bar) - implicitly handled by falling through + # --- End Handle Imputed Bar --- # + + # --- Normal Step Logic (if not imputed or handling mode allows fallthrough like 'hold') --- # + # Action is the TARGET position for the *end* of this step + target_position = np.clip(action, -1.0, 1.0) + trade_size = target_position - self.current_position + + # Calculate PnL based on position held *during* this step + step_actual_return = self.actual_returns[self.current_step] + # Use simple return for PnL calculation: exp(log_ret) - 1 + pnl_fraction = self.current_position * (np.exp(step_actual_return) - 1) + + # Calculate transaction costs for the trade executed now + cost_fraction = abs(trade_size) * self.transaction_cost + + # Reward is net PnL fraction (doesn't scale with capital directly) + reward = pnl_fraction - cost_fraction + + # --- Apply Action Penalty (Revision 5) --- # + # Penalty is applied to the raw reward *before* scaling + # Uses the *trade size* for the penalty, with internal lambda + if self._internal_action_penalty_lambda > 0: + action_penalty = self._internal_action_penalty_lambda * abs(trade_size) + reward -= action_penalty + env_logger.debug(f"Step {self.current_step}: Action penalty applied: {action_penalty:.5f} (lambda={self._internal_action_penalty_lambda:.4f}, trade_size={trade_size:.3f})") + # --- End Action Penalty --- # + + # --- Apply Reward Scaling (Task 5.1) --- # + scaled_reward = reward * self.reward_scale + # --- End Reward Scaling --- # + + # Update internal state for the *next* step + self.current_position = target_position + self.current_capital *= (1 + pnl_fraction - cost_fraction) # Update tracked capital + self.current_step += 1 + + # Check if done + done = self.current_step >= self.n_steps or self.current_capital <= 0 + + next_state = self._get_state() + # Update info dict (capital/position might have changed in normal step) + info['capital'] = self.current_capital + info['position'] = self.current_position + + # Log step details periodically + # if self.current_step % 1000 == 0: + # env_logger.debug(f"Step {self.current_step}: Action={action:.2f}, Pos={self.current_position:.2f}, Ret={step_actual_return:.5f}, Rew={reward:.5f}, Cap={self.current_capital:.2f}") + + if done: + env_logger.info(f"Environment finished at step {self.current_step}. Final Capital: {self.current_capital:.2f}") + + return next_state, scaled_reward, done, info + + def close(self): + """Clean up any resources (if needed).""" + env_logger.info("TradingEnv closed.") + pass \ No newline at end of file diff --git a/gru_sac_predictor/src/trading_pipeline.py b/gru_sac_predictor/src/trading_pipeline.py new file mode 100644 index 00000000..25738260 --- /dev/null +++ b/gru_sac_predictor/src/trading_pipeline.py @@ -0,0 +1,1556 @@ +""" +Main Orchestrator for the Trading Pipeline. + +Coordinates data loading, feature engineering, model training, calibration, +SAC training, and backtesting. +""" + +import os +import sys +import logging +import yaml +import pandas as pd +import numpy as np +from datetime import datetime, timezone, timedelta +import argparse +import joblib +import json +from typing import Optional, Any, List, Tuple, Iterator, Dict +import matplotlib.pyplot as plt +import seaborn as sns +import torch # Added for SAC weight aggregation +from collections import OrderedDict # Added for SAC weight aggregation +import time +import shutil +import copy # Added import + +# Determine the project root directory based on the script location +# This assumes the script is in src/ and the project root is two levels up +script_dir = os.path.dirname(os.path.abspath(__file__)) +project_root = os.path.dirname(os.path.dirname(script_dir)) +# Add project root to sys.path to allow absolute imports from the package +if project_root not in sys.path: + sys.path.insert(0, project_root) + +# Now use absolute imports based on the package structure +from gru_sac_predictor.src.data_loader import DataLoader +# Import other components as they are created +from gru_sac_predictor.src.feature_engineer import FeatureEngineer +# Try importing minimal_whitelist from features.py (assuming it exists there) +try: + from gru_sac_predictor.src.features import minimal_whitelist +except ImportError: + # Fallback: Define it here if features.py doesn't exist or doesn't have it + logging.warning("Could not import minimal_whitelist from .features, defining fallback.") + minimal_whitelist = [ + "return_1m", "return_15m", "return_60m", "ATR_14", "volatility_14d", + "chaikin_AD_10", "svi_10", "EMA_10", "EMA_50", "MACD", "MACD_signal", + "hour_sin", "hour_cos", + ] +from gru_sac_predictor.src.gru_model_handler import GRUModelHandler +from gru_sac_predictor.src.calibrator import Calibrator +# --- Import Vector Calibrator (Task 4) --- # +try: + from gru_sac_predictor.src.calibrator_vector import VectorCalibrator, VECTOR_CALIBRATOR_AVAILABLE +except ImportError: + logging.warning("VectorCalibrator could not be imported. Vector scaling method will not be available.") + VectorCalibrator = None # Define as None if import fails + VECTOR_CALIBRATOR_AVAILABLE = False +# --- End Import --- # +from gru_sac_predictor.src.sac_trainer import SACTrainer +from gru_sac_predictor.src.sac_agent import SACTradingAgent # Added for SAC agent loading +from gru_sac_predictor.src.backtester import Backtester +from gru_sac_predictor.src.baseline_checker import BaselineChecker # Import BaselineChecker +from gru_sac_predictor.src.metrics import edge_filtered_accuracy, calculate_brier_score, calculate_sharpe_ratio, _calculate_optimal_edge_threshold # Added calculate_sharpe_ratio if needed for aggregation + +# Removed redundant imports for feature selection +from sklearn.preprocessing import StandardScaler +# --- Add imports for baseline --- # +from sklearn.linear_model import LogisticRegression +from sklearn.metrics import accuracy_score, roc_auc_score, classification_report +from sklearn.model_selection import train_test_split +import scipy.stats as st +# --- Import edge_filtered_accuracy (Task 6.1/6.2) --- # +try: + # Ensure both metrics are imported + from gru_sac_predictor.src.metrics import edge_filtered_accuracy, calculate_brier_score, _calculate_optimal_edge_threshold +except ImportError: + logging.error("Failed to import metrics from gru_sac_predictor.src.metrics. Validation check will fail.") + # Define placeholders + def edge_filtered_accuracy(*args, **kwargs): return np.nan, 0 + def calculate_brier_score(*args, **kwargs): return np.nan + def _calculate_optimal_edge_threshold(*args, **kwargs): return 0.1 # Default fallback +# --- End Import --- # +# --- End imports for baseline --- # + +# --- Import Stage Functions --- # +from gru_sac_predictor.src.pipeline_stages.data_processing import ( + engineer_features_for_fold, + define_labels_and_align_fold, + split_data_fold +) +from gru_sac_predictor.src.pipeline_stages.feature_processing import ( + scale_features_fold, + select_features_fold, + prune_features_fold # Added import +) +from gru_sac_predictor.src.pipeline_stages.sequence_creation import ( + create_sequences_fold # Added import +) +from gru_sac_predictor.src.pipeline_stages.evaluation import run_baseline_checks_fold # Added baseline check import +from gru_sac_predictor.src.pipeline_stages.evaluation import run_gru_validation_checks_fold # Added validation check import +from gru_sac_predictor.src.pipeline_stages.evaluation import run_backtest_fold # Import the new backtest function +from gru_sac_predictor.src.pipeline_stages.modelling import train_or_load_gru_fold +from gru_sac_predictor.src.pipeline_stages.modelling import calibrate_probabilities_fold +from gru_sac_predictor.src.pipeline_stages.modelling import train_or_load_sac_fold +from gru_sac_predictor.src.pipeline_stages.modelling import aggregate_sac_agents + +# --- Import FoldGenerator --- # +from gru_sac_predictor.src.fold_generator import FoldGenerator + +logger = logging.getLogger(__name__) # Use module-level logger + +# --- Refactored Label Generation Logic --- # +# [Function _generate_direction_labels removed - Moved to data_processing.py] +# --- End Refactored Label Generation --- # + +# --- Gap & Regime Helper Functions --- # +def add_regime_tags(df: pd.DataFrame, indicator: str, window: int, quantiles: List[float]) -> pd.DataFrame: + """ + Calculates a regime indicator and adds a 'regime_tag' column. + + Args: + df (pd.DataFrame): Input dataframe with features (must include 'close' or 'return'). + indicator (str): The type of indicator ('volatility'). + window (int): The rolling window size for the indicator. + quantiles (List[float]): Quantiles to define regime boundaries. + + Returns: + pd.DataFrame: DataFrame with 'regime_tag' column added. + """ + df = df.copy() + if indicator == 'volatility': + # Calculate rolling volatility (std dev of log returns) + if 'return_1m' not in df.columns: # Assuming 'return_1m' or similar exists + df['log_ret'] = np.log(df['close'] / df['close'].shift(1)) + indicator_series = df['log_ret'].rolling(window=window, min_periods=window // 2).std() + del df['log_ret'] + else: + # Use a pre-calculated short-term return if available + indicator_series = df['return_1m'].rolling(window=window, min_periods=window // 2).std() + # Add other indicator calculations here (e.g., trend, RSI) + else: + logger.warning(f"Unsupported regime indicator: {indicator}. Skipping regime tagging.") + df['regime_tag'] = -1 # Assign a default tag + return df + + # Calculate quantile thresholds + thresholds = indicator_series.quantile(quantiles).tolist() + + # Assign regime tags based on thresholds + def assign_tag(value): + if pd.isna(value): + return -1 # Handle NaNs (e.g., at the start of the series) + for i, threshold in enumerate(thresholds): + if value <= threshold: + return i + return len(thresholds) # Tag for values above the highest threshold + + df['regime_tag'] = indicator_series.apply(assign_tag) + logger.info(f"Regime tags added using '{indicator}' (window={window}). Distribution:\n{df['regime_tag'].value_counts(normalize=True).sort_index()}") + return df + +def split_into_contiguous_chunks(df: pd.DataFrame, gap_threshold_minutes: int) -> List[Tuple[pd.Timestamp, pd.Timestamp]]: + """ + Splits a DataFrame index into contiguous chunks based on time gaps. + + Args: + df (pd.DataFrame): DataFrame with a DatetimeIndex. + gap_threshold_minutes (int): Maximum allowed gap in minutes before splitting. + + Returns: + List[Tuple[pd.Timestamp, pd.Timestamp]]: List of (start, end) timestamps for each chunk. + """ + if not isinstance(df.index, pd.DatetimeIndex): + raise ValueError("DataFrame must have a DatetimeIndex.") + + if df.empty: + return [] + + chunks = [] + start_time = df.index[0] + gap_threshold = pd.Timedelta(minutes=gap_threshold_minutes) + + time_diffs = df.index.to_series().diff() + split_points = time_diffs[time_diffs > gap_threshold].index + + last_end_time = start_time + for split_time in split_points: + # Find the index entry *before* the split point + split_idx_loc = df.index.get_loc(split_time) + if split_idx_loc > 0: + end_time = df.index[split_idx_loc - 1] + if end_time >= start_time: + chunks.append((start_time, end_time)) + logger.debug(f"Detected chunk: [{start_time}, {end_time}] due to gap before {split_time}") + # Start the next chunk at the split_time itself + start_time = split_time + last_end_time = end_time # Keep track for the final chunk + + # Add the final chunk + final_end_time = df.index[-1] + if final_end_time >= start_time: + chunks.append((start_time, final_end_time)) + logger.debug(f"Adding final chunk: [{start_time}, {final_end_time}]") + + logger.info(f"Split data into {len(chunks)} contiguous chunks based on gap threshold > {gap_threshold_minutes} minutes.") + return chunks +# --- End Gap & Regime Helpers --- # + +class TradingPipeline: + """Orchestrates the entire trading strategy pipeline.""" + + def __init__(self, config: dict, io_manager: Optional[Any] = None): + """ + Initialize the pipeline with configuration, optional CLI args, and IOManager. + + Args: + config (dict): The loaded configuration dictionary. + io_manager (IOManager, optional): Initialized IOManager instance. Defaults to None. + """ + # Store the passed config dictionary directly + self.config = config + # Run ID and Git SHA should be generated *before* logger/io setup in run.py + # and passed via the IOManager. + if io_manager is None: + # IOManager is considered essential for proper operation. + # Raise an error or handle appropriately if not provided. + # For now, log critical error and exit, assuming IOManager is required. + # TODO: Decide final handling if IOManager *can* be optional. + logging.critical("IOManager not provided during TradingPipeline initialization. Cannot proceed.") + raise ValueError("IOManager instance is required for TradingPipeline.") + # Fallback removed - rely on IOManager + # self.run_id = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S_fallback") + # self.git_sha = "unknown" + # logger_to_use = logging # Use root logger if no io/logger setup provided + else: + # --- Retrieve run_id and git_sha FROM io_manager --- + if not hasattr(io_manager, 'run_id') or not io_manager.run_id: + raise ValueError("IOManager instance provided, but does not have a 'run_id' attribute set.") + self.run_id = io_manager.run_id + + # Assume git_sha is also an attribute of the initialized IOManager + if not hasattr(io_manager, 'git_sha') or not io_manager.git_sha: + # logging.warning("IOManager instance does not have 'git_sha' attribute. Using 'unknown'.") + # self.git_sha = "unknown" # Or raise error if mandatory + # For now, let's assume it's mandatory for traceability + raise ValueError("IOManager instance provided, but does not have a 'git_sha' attribute set.") + self.git_sha = io_manager.git_sha + # --- End Retrieval --- + + logger_to_use = logging.getLogger() # Assume logger was set up by IOManager/run.py + + # self.io is now guaranteed to be an IOManager instance if we proceed past checks + self.io = io_manager + # Store git_sha on self as well, retrieved via IOManager + # Ensure git_sha exists on io_manager (already checked above, but good practice) + self.git_sha = getattr(io_manager, 'git_sha', 'unknown_in_pipeline') + self.pipeline_version = "3.0.0" # Placeholder version + + # --- Directory Setup (Now handled by IOManager if provided) --- # + if self.io: + self.dirs = { + 'results': self.io.run_results_dir, + 'models': self.io.run_models_dir, + 'logs': self.io.run_logs_dir, + 'figures': self.io.run_figures_dir # Added figures dir + } + self.base_models_dir_path = self.io.models_dir # Base models dir (parent of run dir) + self.current_run_models_dir = self.io.run_models_dir + # Logging setup is now done *before* pipeline init in run.py + # self._setup_logging() # Remove internal logging setup + else: + logger_to_use.warning("IOManager not provided. Setting up directories manually.") + # Fallback manual setup (might be removed if IOManager is required) + self._setup_directories_manual() + self._setup_logging_manual() # Fallback logging + # --- End Directory Setup --- # + + # --- Feature Whitelist Modification --- # + # Add 'bar_imputed' to the minimal whitelist if it's not already there + if 'bar_imputed' not in minimal_whitelist: + minimal_whitelist.append('bar_imputed') + logger.info("Added 'bar_imputed' to minimal_whitelist.") + # --- End Whitelist Modification --- # + + # Log Banner (Moved to run.py which has version info) + # logger_to_use.info(...) + + # --- Initialize Components --- # + # Extract db_directory from config before passing to DataLoader + data_cfg = self.config.get('data', {}) + db_directory_path = data_cfg.get('db_dir', 'data/db') # Use 'db_dir' key, provide a default + if not db_directory_path or not isinstance(db_directory_path, str): + # Adjust error message to reflect the correct key 'db_dir' + logger.error(f"Invalid or missing 'db_dir' in config['data']. Found: {db_directory_path}. Using default 'data/db'.") + db_directory_path = 'data/db' # Fallback + + self.data_loader = DataLoader(db_dir=db_directory_path) + self.feature_engineer = FeatureEngineer(self.config) + # Extract edge threshold for Calibrator initialization + calibration_cfg = self.config.get('calibration', {}) + initial_edge_threshold = calibration_cfg.get('edge_threshold', 0.1) # Get edge from config + self.calibrator = Calibrator(edge_threshold=initial_edge_threshold) + # --- Vector Calibrator (Task 4) --- # + if VECTOR_CALIBRATOR_AVAILABLE: + self.vector_calibrator = VectorCalibrator() # Initialize without config + else: + self.vector_calibrator = None + # --- End Vector Calibrator --- # + # Initialize SACTrainer only when needed (in train_or_load_sac) + # self.sac_trainer = SACTrainer(config=self.config) + self.backtester = Backtester(self.config, io_manager=self.io) # Pass io_manager + # Initialize gru_handler (needs run_id and models dir) + self.gru_handler = GRUModelHandler( + run_id=self.run_id, + models_dir=self.current_run_models_dir, + config=self.config + ) + # Initialize BaselineChecker + self.baseline_checker = BaselineChecker(self.config, io=self.io) # Pass io + # --- End Initialize Components --- # + + # --- Initialize state variables --- # + self.df_raw = None + self.load_summary = None # Store load summary + self.df_engineered_full = None + # self.df_features_minimal = None # Removed minimal pruning here + self.df_labeled_aligned = None + self.X_raw_aligned = None + self.y_aligned = None + self.y_dir_aligned = None + self.X_train_raw = None + self.X_val_raw = None + self.X_test_raw = None + self.y_train = None + self.y_val = None + self.y_test = None + self.df_train_original = None # Store original data for splits + self.df_val_original = None + self.df_test_original = None + self.y_dir_train = None + self.scaler = None + self.X_train_scaled = None + self.X_val_scaled = None + self.X_test_scaled = None + self.final_whitelist = None + self.X_train_pruned = None + self.X_val_pruned = None + self.X_test_pruned = None + self.X_train_seq = None + self.X_val_seq = None + self.X_test_seq = None + self.y_train_seq_dict = None + self.y_val_seq_dict = None + self.y_test_seq_dict = None + self.train_indices = None + self.val_indices = None + self.test_indices = None + self.gru_model = None + self.gru_model_run_id_loaded_from = None # Track which run ID model came from + self.optimal_T = None + self.vector_cal_params = None # Store vector calibration parameters + self.sac_agent_load_path = None # Path to the SAC agent to load for backtesting + self.backtest_results_df = None + self.backtest_metrics = None + self.metrics_log_df = None # For logging detailed metrics + self.use_ternary = self.config.get('gru', {}).get('use_ternary', False) # Cache ternary flag + self.aggregated_metrics: Optional[dict] = None # Aggregated metrics across folds + self.optimized_edge_threshold: Optional[float] = None # Store optimized edge threshold per fold + # --- Add attributes for baseline filtering --- # + self.fwd_returns_aligned: Optional[pd.Series] = None + self.eps_aligned: Optional[pd.Series] = None + # Attributes for split returns/eps needed by baseline check + self.fwd_ret_train: Optional[pd.Series] = None + self.eps_train: Optional[pd.Series] = None + self.fwd_ret_val: Optional[pd.Series] = None + self.eps_val: Optional[pd.Series] = None + self.y_dir_val_ordinal: Optional[pd.Series] = None # <<< ADDED + # --- Feature selection optimization --- # + self.initial_l1_whitelist: Optional[List[str]] = None + # --- End Add --- # + # --- End Initialize state variables --- # + + # Save config handled by run.py via IOManager typically + # self._save_run_config() + if self.io: + # Maybe save it again here for completeness? Or rely on run.py? + config_save_path = self.io.path('results', 'run_config', suffix='.yaml') + try: + with open(config_save_path, 'w') as f: + yaml.dump(self.config, f, default_flow_style=False) + logger_to_use.info(f"Saved run configuration copy to {config_save_path}") + except Exception as e: + logger_to_use.error(f"Failed to save run configuration via IOManager: {e}") + else: + logger_to_use.warning("IOManager not available, cannot save run config copy from pipeline.") + + self.fold_generator = None # Initialize fold generator attribute + + # --- Remove Wrapper Method Definitions --- # + # def define_labels_and_align(self, df_engineered: pd.DataFrame) -> Tuple[pd.DataFrame, str, List[str]]: + # # ... implementation removed ... + # pass + + # def split_data(self, df_labeled_aligned_fold: pd.DataFrame, fold_dates: Optional[Tuple] = None): + # # ... implementation removed ... + # pass + + # def run_baseline_checks(self, fold_num: int, X_train_to_use: Optional[pd.DataFrame] = None, X_val_to_use: Optional[pd.DataFrame] = None) -> Dict[str, Any]: + # # ... implementation removed ... + # pass + # --- End Remove Wrapper Methods --- # + + # --- execute method refactored for Walk-Forward --- # + def execute(self): + """Runs the full trading pipeline, potentially using Walk-Forward validation.""" + logger.info(f"--- Starting Trading Pipeline: Run ID {self.run_id} ---") + + # 1. Load and Preprocess FULL Data + # Replace direct call with call to data_loader using config + data_cfg = self.config.get('data', {}) + ticker = data_cfg.get('ticker') + exchange = data_cfg.get('exchange') + start_date = data_cfg.get('start_date') + end_date = data_cfg.get('end_date') + interval = data_cfg.get('interval', '1min') # Default to 1min if not specified + vol_sampling_cfg = data_cfg.get('volatility_sampling', {}) + vol_sampling_enabled = vol_sampling_cfg.get('enabled', False) + vol_window = vol_sampling_cfg.get('window', 30) + vol_quantile = vol_sampling_cfg.get('quantile', 0.5) + + if not all([ticker, exchange, start_date, end_date]): + logger.critical("Missing required data parameters (ticker, exchange, start_date, end_date) in config['data']. Cannot load data.") + raise ValueError("Missing required data configuration.") + + try: + logger.info(f"Loading data for {ticker} from {exchange} between {start_date} and {end_date} at {interval} interval.") + # Call the load_data method on the data_loader instance + self.df_raw = self.data_loader.load_data( + ticker=ticker, + exchange=exchange, + start_date=start_date, + end_date=end_date, + interval=interval, + vol_sampling=vol_sampling_enabled, + vol_window=vol_window, + vol_quantile=vol_quantile + ) + # The load_data method in DataLoader now includes imputation and reporting + + except Exception as e: + logger.critical(f"Failed to load initial data: {e}", exc_info=True) + self.df_raw = None + raise SystemExit("Failed to load initial data.") from e + + if self.df_raw is None or self.df_raw.empty: + logger.error("Initial data loading and preprocessing resulted in an empty DataFrame. Exiting pipeline.") + if hasattr(self.data_loader, 'last_load_summary') and self.data_loader.last_load_summary: + logger.error(f"Load Summary: {self.data_loader.last_load_summary}") + raise SystemExit("Data loading failed.") + + # --- Initialize FoldGenerator AFTER data is loaded --- # + try: + logger.info("Initializing FoldGenerator...") + self.fold_generator = FoldGenerator(config=self.config, df_raw=self.df_raw, io_manager=self.io) + logger.info("FoldGenerator initialized.") + except Exception as fg_init_err: + logger.error(f"Failed to initialize FoldGenerator: {fg_init_err}", exc_info=True) + raise SystemExit("FoldGenerator initialization failed.") from fg_init_err + # --- END FoldGenerator Initialization --- # + + # 2. Generate Walk-Forward Folds using the generator instance + try: + # Make generate_folds return an iterator/generator + fold_dates_iterator = self.fold_generator.generate_folds() + except Exception as fg_gen_err: + logger.error(f"Error calling fold_generator.generate_folds(): {fg_gen_err}", exc_info=True) + raise SystemExit("Fold generation failed.") from fg_gen_err + + self.all_fold_metrics = [] # Reset fold metrics list for the run + all_successful_sac_agent_paths = [] # Store paths of successfully trained SAC agents per fold + fold_count = 0 + + # 3. Loop Through Folds + for fold_dates in fold_dates_iterator: + fold_count += 1 + self.current_fold = fold_count # Set current fold number + logger.info(f"=== Processing Fold {self.current_fold} ===") + self.current_fold_report_data = [] # <<< RE-INITIALIZE report list for the fold >>> + + # --- Handle single split case is now handled INTERNALLY by FoldGenerator --- + # No need for the explicit single split date calculation here anymore. + # The generator yields the correct dates directly. + + if fold_dates is None or len(fold_dates) != 6: + logger.error(f"Fold {self.current_fold}: Invalid fold dates received from generator: {fold_dates}. Skipping.") + self.all_fold_metrics.append({'fold': self.current_fold, 'status': 'error', 'reason': 'Invalid fold dates received'}) + continue + + train_start, train_end, val_start, val_end, test_start, test_end = fold_dates + + # Setup fold-specific directories using IOManager if available + if self.io: + # self.fold_dirs = self.io.setup_fold_dirs(self.current_fold) + self.fold_dirs = self.io.get_fold_dirs(self.current_fold) # Correct method name + else: + self.fold_dirs = {} # Set empty if no IOManager + logger.warning(f"Fold {self.current_fold}: IOManager not available, cannot create fold-specific directories.") + + # Select data for the current fold using the ORIGINAL df_raw + # The generator used df_raw_tagged internally, but pipeline stages use df_raw + fold_start_date = train_start + # Determine the latest end date actually provided for slicing + fold_end_candidates = [d for d in [test_end, val_end, train_end] if pd.notna(d)] + if not fold_end_candidates: + logger.error(f"Fold {self.current_fold}: No valid end date found in fold dates: {fold_dates}. Skipping fold.") + self.all_fold_metrics.append({'fold': self.current_fold, 'status': 'error', 'reason': 'No valid end date'}) + continue + fold_end_date = max(fold_end_candidates) + + # Ensure start/end dates have timezone consistent with df_raw index if applicable + if self.df_raw.index.tz is not None: + if pd.notna(fold_start_date) and fold_start_date.tzinfo is None: + fold_start_date = fold_start_date.tz_localize(self.df_raw.index.tz) + if pd.notna(fold_end_date) and fold_end_date.tzinfo is None: + fold_end_date = fold_end_date.tz_localize(self.df_raw.index.tz) + + # Check validity before slicing + if pd.isna(fold_start_date) or pd.isna(fold_end_date) or fold_end_date < fold_start_date: + logger.error(f"Fold {self.current_fold}: Invalid start/end range for slicing df_raw: Start={fold_start_date}, End={fold_end_date}. Skipping fold.") + self.all_fold_metrics.append({'fold': self.current_fold, 'status': 'error', 'reason': 'Invalid date range for fold data selection'}) + continue + + try: + current_fold_data_raw = self.df_raw.loc[fold_start_date:fold_end_date] + except Exception as slice_err: + logger.error(f"Fold {self.current_fold}: Error slicing df_raw for dates [{fold_start_date} to {fold_end_date}]: {slice_err}. Skipping fold.", exc_info=True) + self.all_fold_metrics.append({'fold': self.current_fold, 'status': 'error', 'reason': f'Error slicing raw data: {slice_err}'}) + continue + + if current_fold_data_raw.empty: + logger.warning(f"Fold {self.current_fold}: No raw data found for range {fold_start_date} to {fold_end_date}. Skipping fold.") + # Store failure info + self.all_fold_metrics.append({'fold': self.current_fold, 'status': 'skipped', 'reason': 'No raw data in date range'}) + continue + + logger.info(f"Fold {self.current_fold}: Raw data range for pipeline stages [{current_fold_data_raw.index.min()}, {current_fold_data_raw.index.max()}]") + + # --- Run Pipeline Steps within the Fold --- # + try: + # a. Engineer Features (using stage function) + df_engineered_fold = engineer_features_for_fold( + df=current_fold_data_raw, + feature_engineer=self.feature_engineer, + io=self.io, + config=self.config, + target_col=None + ) + if df_engineered_fold.empty: + raise SystemExit(f"Fold {self.current_fold}: Feature engineering resulted in empty dataframe.") + + # b. Define Labels and Align (Direct stage function call) + logger.info(f"Fold {self.current_fold}: Calling Stage: Defining Labels and Aligning...") + df_labeled_aligned_fold, target_dir_col, target_cols, fwd_returns_aligned, eps_aligned = define_labels_and_align_fold( + df_engineered=df_engineered_fold, + config=self.config + ) + if df_labeled_aligned_fold.empty: + raise SystemExit(f"Fold {self.current_fold}: Label definition stage resulted in empty dataframe.") + # Store results on self + self.target_dir_col = target_dir_col + self.target_columns = target_cols + self.fwd_returns_aligned = fwd_returns_aligned + self.eps_aligned = eps_aligned + + # c. Split data (Direct stage function call) + logger.info(f"Fold {self.current_fold}: Calling Stage: Splitting Data...") + ( + self.X_train_raw, self.X_val_raw, self.X_test_raw, + self.y_train, self.y_val, self.y_test, + self.df_train_original, self.df_val_original, self.df_test_original, + self.y_dir_train_ordinal, + self.fwd_ret_train, self.fwd_ret_val, + self.eps_train, self.eps_val, + self.y_dir_val_ordinal # Captures val ordinal dir labels + ) = split_data_fold( + df_labeled_aligned=df_labeled_aligned_fold, + fwd_returns_aligned=self.fwd_returns_aligned, + eps_aligned=self.eps_aligned, + config=self.config, + target_columns=self.target_columns, + target_dir_col=self.target_dir_col, + fold_dates=fold_dates, + current_fold=self.current_fold, + io=self.io + ) + # Note: split_data_fold raises SystemExit on failure + logger.info(f"Fold {self.current_fold}: Data splitting stage complete.") + + # d. Scale Features (using stage function) + (self.X_train_scaled, + self.X_val_scaled, + self.X_test_scaled, + self.scaler) = scale_features_fold( + X_train_raw=self.X_train_raw, + X_val_raw=self.X_val_raw, + X_test_raw=self.X_test_raw, + run_id=self.run_id, + fold_num=self.current_fold, + fold_models_dir=self.fold_dirs.get('models', '.'), + main_run_models_dir=self.current_run_models_dir, + ) + # Note: scale_features_fold now raises SystemExit on failure + + # <<< NEW: Coarse Univariate Filter >>> + X_train_coarse, X_val_coarse = None, None + feat_cfg = self.config.get('features', {}) + coarse_quantile = feat_cfg.get('coarse_univariate_quantile', 0.70) + logger.info(f"Fold {self.current_fold}: Applying coarse univariate filter (keeping top {1-coarse_quantile:.0%}) before first baseline check...") + try: + if self.y_dir_train_ordinal is not None and self.X_train_scaled is not None: + y_aligned, X_aligned = self.y_dir_train_ordinal.align(self.X_train_scaled, join='inner', axis=0) + y_numeric = pd.to_numeric(y_aligned, errors='coerce').fillna(0) + X_train_aligned_for_coarse = X_aligned + + if X_train_aligned_for_coarse.empty: + logger.warning(f"Fold {self.current_fold}: X_train became empty after alignment for coarse filter. Skipping filter.") + else: + corrs = X_train_aligned_for_coarse.apply(lambda col: abs(np.corrcoef(col, y_numeric)[0, 1]) if col.var() > 1e-6 else 0.0) + corrs.fillna(0, inplace=True) + if not corrs.empty and not corrs.isnull().all(): + quantile_value = corrs.quantile(coarse_quantile) + keep_mask = corrs >= quantile_value + keep_coarse_features = corrs[keep_mask].index.tolist() + if keep_coarse_features: + X_train_coarse = self.X_train_scaled[keep_coarse_features] + if self.X_val_scaled is not None: + X_val_coarse = self.X_val_scaled[[col for col in keep_coarse_features if col in self.X_val_scaled.columns]] + logger.info(f"Fold {self.current_fold}: Coarse filter kept {len(keep_coarse_features)} features (Corr >= {quantile_value:.4f}).") + else: + logger.warning(f"Fold {self.current_fold}: Coarse univariate filter removed all features. Baseline will run on full scaled set.") + else: + logger.warning(f"Fold {self.current_fold}: Could not compute valid correlations for coarse filter. Baseline will run on full scaled set.") + else: + logger.warning(f"Fold {self.current_fold}: Skipping coarse filter due to missing y_dir_train_ordinal or X_train_scaled.") + except Exception as coarse_err: + logger.error(f"Fold {self.current_fold}: Error during coarse univariate filtering: {coarse_err}. Baseline will run on full scaled set.", exc_info=True) + # <<< END NEW Filter >>> + + # e. Baseline Check (Direct stage function call) + # --- Select Input Data for Baseline --- # + run_on_coarse = X_train_coarse is not None + X_train_for_check = X_train_coarse if run_on_coarse else self.X_train_scaled + X_val_for_check = X_val_coarse if run_on_coarse else self.X_val_scaled + log_msg_data_source = "COARSELY FILTERED" if run_on_coarse else "FULL SCALED" + logger.info(f"Fold {self.current_fold}: Running baseline checks on {log_msg_data_source} data.") + # --- End Input Data Selection --- # + + # Check prerequisites before calling stage function + baseline_prereqs_met = True + required_attrs_baseline = [ + 'y_dir_train_ordinal', 'y_dir_val_ordinal', + 'fwd_ret_train', 'fwd_ret_val', + 'eps_train', 'eps_val', + 'baseline_checker' + ] + missing_attrs = [attr for attr in required_attrs_baseline if not hasattr(self, attr) or getattr(self, attr) is None] + if X_train_for_check is None or X_val_for_check is None: + missing_attrs.append("scaled_features (or coarse features)") + + if missing_attrs: + logger.error(f"Fold {self.current_fold}: Cannot run baseline checks. Missing required attributes/data: {missing_attrs}") + baseline_prereqs_met = False + baseline_report = { # Manually create error report + "gate_name": "baseline_checks", + "status": "error", + "timestamp": datetime.now(timezone.utc).isoformat(), + "reason": f"Missing required inputs: {missing_attrs}", + "details": {} + } + + if baseline_prereqs_met: + baseline_report = run_baseline_checks_fold( + X_train_scaled=X_train_for_check, + X_val_scaled=X_val_for_check, + y_train_dir_ordinal=self.y_dir_train_ordinal, + y_val_dir_ordinal=self.y_dir_val_ordinal, + fwd_ret_train=self.fwd_ret_train, + eps_train=self.eps_train, + fwd_ret_val=self.fwd_ret_val, + eps_val=self.eps_val, + baseline_checker=self.baseline_checker, + config=self.config, + io=self.io, + fold_num=self.current_fold, + fold_dirs=self.fold_dirs, + base_results_dir=self.io.run_results_dir if self.io else '.' + ) + + # --- Report Handling (same as before) --- # + self.current_fold_report_data.append(baseline_report) + if baseline_report.get('status') in ['failed', 'error']: + raise SystemExit(f"Fold {self.current_fold} failed baseline checks: {baseline_report.get('reason')}") + # --- End report handling --- # + + # f. Select Features & Prune (using stage functions) + logger.info(f"Fold {self.current_fold}: Selecting features...") + self.final_whitelist = select_features_fold( + X_train_raw=self.X_train_raw, + y_dir_train_ordinal=self.y_dir_train_ordinal, # Pass ordinal labels + feature_engineer=self.feature_engineer, + io=self.io, + run_id=self.run_id, + fold_num=self.current_fold, + fold_models_dir=self.fold_dirs.get('models', '.'), + fold_results_dir=self.fold_dirs.get('results', '.'), + main_run_models_dir=self.current_run_models_dir + ) + # Note: select_features_fold raises SystemExit on failure + + logger.info(f"Fold {self.current_fold}: Pruning scaled features using whitelist ({len(self.final_whitelist)} features)...") + (self.X_train_pruned, + self.X_val_pruned, + self.X_test_pruned) = prune_features_fold( + X_train_scaled=self.X_train_scaled, + X_val_scaled=self.X_val_scaled, + X_test_scaled=self.X_test_scaled, + final_whitelist=self.final_whitelist, + feature_engineer=self.feature_engineer, # Pass the instance + fold_num=self.current_fold + ) + # Note: prune_features_fold raises SystemExit on failure + + # <<< Pruning handled within select_and_prune_features >>> + # <<< The above calls replace the old self.select_and_prune_features() >>> + + # g. Post-Pruning Baseline Check (Direct stage function call) + # --- Capture report and check status --- # + # Replace wrapper call with direct stage function call + post_pruning_baseline_report = run_baseline_checks_fold( + X_train_scaled=self.X_train_pruned, # Pass PRUNED data + X_val_scaled=self.X_val_pruned, # Pass PRUNED data + y_train_dir_ordinal=self.y_dir_train_ordinal, + y_val_dir_ordinal=self.y_dir_val_ordinal, + fwd_ret_train=self.fwd_ret_train, + eps_train=self.eps_train, + fwd_ret_val=self.fwd_ret_val, + eps_val=self.eps_val, + baseline_checker=self.baseline_checker, + config=self.config, + io=self.io, + fold_num=self.current_fold, + fold_dirs=self.fold_dirs, + base_results_dir=self.io.run_results_dir if self.io else '.' + ) + # Rename the gate for clarity in reports + post_pruning_baseline_report['gate_name'] = 'post_pruning_baseline_checks' + self.current_fold_report_data.append(post_pruning_baseline_report) + if post_pruning_baseline_report.get('status') in ['failed', 'error']: + raise SystemExit(f"Fold {self.current_fold} failed post-pruning baseline checks: {post_pruning_baseline_report.get('reason')}") + # --- End report handling --- # + + # h. Update Scaled Data to Pruned Data + logger.info(f"Fold {self.current_fold}: Post-pruning checks passed. Updating scaled data to pruned versions for subsequent steps.") + self.X_train_scaled = self.X_train_pruned + self.X_val_scaled = self.X_val_pruned + self.X_test_scaled = self.X_test_pruned + # Clean up pruned attributes to avoid confusion if needed? Optional. + # del self.X_train_pruned, self.X_val_pruned, self.X_test_pruned + # <<< End Update >>> + + # i. Create Sequences (using stage function) + # Replace single method call with calls to stage function for each split + logger.info(f"Fold {self.current_fold}: Creating sequences...") + gru_cfg = self.config.get('gru', {}) + lookback = gru_cfg.get('lookback', 60) + + if not hasattr(self, 'target_columns') or not self.target_columns: + raise SystemExit(f"Fold {self.current_fold}: Target column names not set before create_sequences stage.") + + # --- Create Train Sequences --- # + (self.X_train_seq, + self.y_train_seq_dict, + self.train_indices, + dropped_train) = create_sequences_fold( + X_data=self.X_train_pruned, + y_data=self.df_train_original, # Original df contains targets + target_names=self.target_columns, + lookback=lookback, + name="Train", + config=self.config, + io=self.io + ) + if self.X_train_seq is None: + raise SystemExit(f"Fold {self.current_fold}: Failed to create training sequences.") + + # --- Create Validation Sequences --- # + (self.X_val_seq, + self.y_val_seq_dict, + self.val_indices, + dropped_val) = create_sequences_fold( + X_data=self.X_val_pruned, + y_data=self.df_val_original, # Original df contains targets + target_names=self.target_columns, + lookback=lookback, + name="Validation", + config=self.config, + io=self.io + ) + if self.X_val_seq is None: + raise SystemExit(f"Fold {self.current_fold}: Failed to create validation sequences.") + + # --- Create Test Sequences --- # + (self.X_test_seq, + self.y_test_seq_dict, + self.test_indices, + dropped_test) = create_sequences_fold( + X_data=self.X_test_pruned, + y_data=self.df_test_original, # Original df contains targets + target_names=self.target_columns, + lookback=lookback, + name="Test", + config=self.config, + io=self.io + ) + if self.X_test_seq is None: + raise SystemExit(f"Fold {self.current_fold}: Failed to create test sequences.") + # --- End Capture --- # + logger.info(f"Fold {self.current_fold}: Sequence creation complete. Dropped counts: Train={dropped_train}, Val={dropped_val}, Test={dropped_test}") + + # j. Train/Load GRU Model (using stage function) + # Replace method call with stage function call + logger.info(f"Fold {self.current_fold}: Training or loading GRU model...") + ( + self.gru_model, + self.gru_handler, # Handler might be updated internally + self.gru_model_run_id_loaded_from, + scaler_returned, # Scaler associated with the loaded/trained model + X_train_seq_new, + y_train_seq_dict_new, + train_indices_new, + X_val_seq_new, + y_val_seq_dict_new, + val_indices_new, + X_test_seq_new, + y_test_seq_dict_new, + test_indices_new + ) = train_or_load_gru_fold( + config=self.config, + run_id=self.run_id, + current_fold=self.current_fold, + current_run_models_dir=self.current_run_models_dir, + base_models_dir_path=self.base_models_dir_path, + gru_handler=self.gru_handler, + X_train_seq=self.X_train_seq, + y_train_seq_dict=self.y_train_seq_dict, + X_val_seq=self.X_val_seq, + y_val_seq_dict=self.y_val_seq_dict, + X_test_seq=self.X_test_seq, + y_test_seq_dict=self.y_test_seq_dict, + X_train_raw=self.X_train_raw, + X_val_raw=self.X_val_raw, + X_test_raw=self.X_test_raw, + y_train=self.df_train_original, + y_val=self.df_val_original, + y_test=self.df_test_original, + scaler=self.scaler, final_whitelist=self.final_whitelist, feature_engineer=self.feature_engineer, + train_indices=self.train_indices, + # <<< Pass original val/test indices >>> + val_indices=self.val_indices, + test_indices=self.test_indices, + io=self.io + ) + + # Handle potential re-processing: update sequences/scaler if new ones returned + if scaler_returned is not None: + self.scaler = scaler_returned + if X_train_seq_new is not None: + logger.info(f"Fold {self.current_fold}: Updating sequences due to GRU model loading/re-processing.") + self.X_train_seq = X_train_seq_new + self.y_train_seq_dict = y_train_seq_dict_new + self.train_indices = train_indices_new + self.X_val_seq = X_val_seq_new + self.y_val_seq_dict = y_val_seq_dict_new + self.val_indices = val_indices_new + self.X_test_seq = X_test_seq_new + self.y_test_seq_dict = y_test_seq_dict_new + self.test_indices = test_indices_new + + if self.gru_model is None: + raise SystemExit(f"Fold {self.current_fold}: Failed to train or load GRU model.") + + # k. Calibrate Probabilities (using stage function) + # Replace method call with stage function call + logger.info(f"Fold {self.current_fold}: Calibrating probabilities...") + (self.optimal_T, + self.vector_cal_params, + self.optimized_edge_threshold, + p_cal_val_for_check, # Store intermediate result for validation checks + y_dir_val_for_check) = calibrate_probabilities_fold( + config=self.config, + current_fold=self.current_fold, + gru_model=self.gru_model, + gru_handler=self.gru_handler, + X_val_seq=self.X_val_seq, + y_val_seq_dict=self.y_val_seq_dict, + use_ternary=self.use_ternary, + calibrator=self.calibrator, + vector_calibrator=self.vector_calibrator, + fold_dirs=self.fold_dirs, + current_run_models_dir=self.current_run_models_dir, + run_id=self.run_id, + io=self.io + ) + + # Store results needed for validation checks on self + self.p_cal_val_for_check = p_cal_val_for_check + self.y_dir_val_for_check = y_dir_val_for_check + # Optimized edge is already stored in self.optimized_edge_threshold by the call above + + # j.2. Perform GRU Validation Checks (Direct stage function call) + # --- Capture report and check status --- # + # Replace wrapper call with direct stage function call + gru_val_report = run_gru_validation_checks_fold( + config=self.config, + current_fold=self.current_fold, + p_cal_val=self.p_cal_val_for_check, # Use intermediate result + y_dir_val=self.y_dir_val_for_check, # Use intermediate result + optimized_edge_threshold=self.optimized_edge_threshold, + use_ternary=self.use_ternary, + io=self.io + ) + self.current_fold_report_data.append(gru_val_report) + if gru_val_report.get('status') in ['failed', 'error']: + raise SystemExit(f"Fold {self.current_fold} failed GRU validation checks: {gru_val_report.get('reason')}") + # --- End report handling --- # + + # k. Train/Load SAC Agent (Direct stage function call) + logger.info(f"Fold {self.current_fold}: Proceeding to SAC training/loading step.") + # Replace wrapper call with direct stage function call + # Instantiate SACTrainer here if training + sac_trainer_instance = None + if self.config.get('sac', {}).get('train_sac', False): + if not self.io: + logger.error(f"Fold {self.current_fold}: IOManager is required to initialize SACTrainer for training. Skipping SAC training.") + else: + sac_trainer_instance = SACTrainer(config=self.config, io_manager=self.io) + # Store instance if needed elsewhere? For now, just pass to stage func. + # self.sac_trainer = sac_trainer_instance + + self.sac_agent_load_path = train_or_load_sac_fold( + config=self.config, + current_fold=self.current_fold, + gru_model_run_id_loaded_from=self.gru_model_run_id_loaded_from, + base_models_dir_path=self.base_models_dir_path, + sac_trainer=sac_trainer_instance, # Pass instance or None + io=self.io + ) + # Optional: Add a report entry for SAC training status? + sac_train_status = { + 'gate_name': 'sac_training', + 'timestamp': datetime.now(timezone.utc).isoformat(), + 'status': 'completed' if self.sac_agent_load_path else 'failed_or_skipped', + 'agent_path': self.sac_agent_load_path, + 'reason': None if self.sac_agent_load_path else 'SAC training/loading did not yield a valid agent path' + } + self.current_fold_report_data.append(sac_train_status) + + # l. Run Backtest (Direct stage function call) + # --- Capture report and check status --- # + # Replace wrapper call with direct stage function call + # --- Prerequisites Check (simplified from wrapper) --- # + backtest_prereqs_met = True + if not all([ + self.X_test_seq is not None, + self.y_test_seq_dict is not None, + self.test_indices is not None, + self.df_test_original is not None + ]): + logger.error(f"Fold {self.current_fold}: Missing test sequences, indices, or original test df. Cannot run backtest stage.") + backtest_prereqs_met = False + # Create an error report manually if prereqs fail + backtest_report = { + "gate_name": "backtest_stage_execution", + "status": "error", + "timestamp": datetime.now(timezone.utc).isoformat(), + "reason": "Missing required data for backtest stage.", + "metrics": {}, + "thresholds": {} + } + + # --- Get Inputs for Backtest (simplified from wrapper) --- # + p_raw_test_input: Optional[np.ndarray] = None + logits_test_input: Optional[np.ndarray] = None + if backtest_prereqs_met: + cal_method = self.config.get('calibration', {}).get('method', 'temperature') + if cal_method == 'vector': + if self.gru_handler: + logits_test_input = self.gru_handler.predict_logits(self.X_test_seq) + if logits_test_input is None: + logger.error(f"Fold {self.current_fold}: Failed to get logits for vector calibration in backtest.") + backtest_prereqs_met = False + backtest_report = { "status": "error", "reason": "Failed to get logits for backtest vector calibration." } + else: + logger.error(f"Fold {self.current_fold}: GRU handler not available for logits prediction.") + backtest_prereqs_met = False + backtest_report = { "status": "error", "reason": "GRU handler missing for logits prediction." } + else: # Temperature or other methods might need raw probabilities + if self.gru_handler: + predictions_dict = self.gru_handler.predict(self.X_test_seq) + if predictions_dict is None: + logger.error(f"Fold {self.current_fold}: Failed to get GRU predictions for backtest.") + backtest_prereqs_met = False + backtest_report = { "status": "error", "reason": "Failed to get GRU predictions for backtest." } + else: + try: + if self.use_ternary: + p_raw_test_input = predictions_dict.get('dir3') + if p_raw_test_input is None: raise KeyError("'dir3' missing") + else: + prob_key = next((k for k in predictions_dict if k != 'mu' and k != 'dir3_logits'), None) + if prob_key: + p_raw_test_input = predictions_dict[prob_key] + if p_raw_test_input is not None and p_raw_test_input.ndim > 1: p_raw_test_input = p_raw_test_input.flatten() + else: + raise KeyError("Binary probability key not found") + except KeyError as e: + logger.error(f"Fold {self.current_fold}: Error extracting probabilities from GRU prediction dict: {e}") + backtest_prereqs_met = False + backtest_report = { "status": "error", "reason": f"Error extracting probabilities: {e}" } + else: + logger.error(f"Fold {self.current_fold}: GRU handler not available for prediction.") + backtest_prereqs_met = False + backtest_report = { "status": "error", "reason": "GRU handler missing for prediction." } + + # --- Call Backtest Stage Function --- # + if backtest_prereqs_met: + try: + results_df, metrics_dict, metrics_log, backtest_gate_report = run_backtest_fold( + config=self.config, + io=self.io, + current_fold=self.current_fold, + fold_dirs=self.fold_dirs, + sac_agent_load_path=self.sac_agent_load_path, + X_test_seq=self.X_test_seq, + y_test_seq_dict=self.y_test_seq_dict, + test_indices=self.test_indices, # Pass test_indices + df_test_original=self.df_test_original, + gru_handler=self.gru_handler, + calibrator=self.calibrator, + vector_calibrator=self.vector_calibrator, + initial_optimal_T=getattr(self, 'optimal_T', None), + initial_vector_params=getattr(self, 'vector_cal_params', None), + optimized_edge_threshold=self.optimized_edge_threshold, + p_raw_test=p_raw_test_input, + logits_test=logits_test_input, + use_ternary=self.use_ternary + ) + # Store results on self if successful execution + self.backtest_results_df = results_df + self.backtest_metrics = metrics_dict + self.metrics_log_df = metrics_log + backtest_report = backtest_gate_report # Use the report from the stage func + logger.info(f"Fold {self.current_fold}: Backtest stage execution completed.") + except Exception as e: + logger.error(f"Fold {self.current_fold}: An unexpected error occurred calling run_backtest_fold stage: {e}", exc_info=True) + backtest_report = { + "gate_name": "backtest_stage_execution", + "status": "error", + "timestamp": datetime.now(timezone.utc).isoformat(), + "reason": f"Unhandled exception in run_backtest_fold: {e}", + "metrics": {}, + "thresholds": {} + } + # else: backtest_report was already set due to prereq failure + + # Add the final backtest report (either success/fail from gate or error report) + self.current_fold_report_data.append(backtest_report) + if backtest_report.get('status') == 'error': # Only raise SystemExit on execution error + raise SystemExit(f"Fold {self.current_fold} encountered an error during backtest stage: {backtest_report.get('reason')}") + elif backtest_report.get('status') == 'failed': # Log warning if performance gate failed + logger.warning(f"Fold {self.current_fold} FAILED backtest performance gates: {backtest_report.get('reason')}. Proceeding to store metrics.") + # --- End report handling --- # + + # m. Persist Fold Artefacts & Store Metrics + logger.info(f"Storing metrics for Fold {self.current_fold}") + if self.backtest_metrics is not None: + fold_metrics = self.backtest_metrics.copy() + fold_metrics['fold_number'] = self.current_fold + fold_metrics['train_start'] = train_start.isoformat() if train_start else None + fold_metrics['train_end'] = train_end.isoformat() if train_end else None + fold_metrics['val_start'] = val_start.isoformat() if val_start else None + fold_metrics['val_end'] = val_end.isoformat() if val_end else None + fold_metrics['test_start'] = test_start.isoformat() if test_start else None + fold_metrics['test_end'] = test_end.isoformat() if test_end else None + # Determine status based on metrics + if 'Annualized Sharpe Ratio' in fold_metrics and not pd.isna(fold_metrics['Annualized Sharpe Ratio']): + fold_metrics['status'] = 'success' + else: + fold_metrics['status'] = 'failed_backtest' + self.all_fold_metrics.append(fold_metrics) + else: + logger.warning(f"Fold {self.current_fold}: No backtest metrics generated to store.") + self.all_fold_metrics.append({'fold_number': self.current_fold, 'status': 'failed_backtest', 'error': 'No metrics returned'}) + + # Store SAC Agent Path if Trained Successfully for this fold + if self.config.get('control', {}).get('train_sac', False) and self.sac_agent_load_path: + # Assuming sac_agent_load_path points to the *newly* trained agent if training occurred + if hasattr(self, 'sac_trainer') and self.sac_trainer and self.sac_trainer.last_saved_agent_path == self.sac_agent_load_path: + if os.path.exists(self.sac_agent_load_path): + logger.info(f"Fold {self.current_fold}: Storing successfully trained SAC agent path for aggregation: {self.sac_agent_load_path}") + all_successful_sac_agent_paths.append(self.sac_agent_load_path) + else: + logger.warning(f"Fold {self.current_fold}: SAC training reported success, but path {self.sac_agent_load_path} not found.") + elif self.sac_agent_load_path: # If path exists but wasn't from training this fold + pass # Don't add loaded agents to the aggregation list unless explicitly intended + + # <<< ADDED: Forward Baseline Check >>> + # Derive y_test_ordinal if possible + y_test_ordinal = None + if self.y_test is not None and self.target_dir_col in self.y_test.columns: + if self.use_ternary: + # Convert one-hot list back to ordinal for test set + test_dir_raw = self.y_test[self.target_dir_col] + valid_mask_test = test_dir_raw.notna() & test_dir_raw.apply(lambda x: isinstance(x, list) and len(x) == 3) + if valid_mask_test.any(): + ordinal_values_test = test_dir_raw[valid_mask_test].apply(np.argmax) + y_test_ordinal = pd.Series(np.nan, index=test_dir_raw.index) + y_test_ordinal[valid_mask_test] = ordinal_values_test + logger.info(f"Fold {self.current_fold}: Extracted ordinal test labels for forward check. Count: {valid_mask_test.sum()}") + else: + logger.warning(f"Fold {self.current_fold}: No valid ternary labels found in y_test for forward check.") + y_test_ordinal = pd.Series(dtype=float) # Empty series + else: + # Binary case: convert float (0.0/1.0) to int + y_test_ordinal = self.y_test[self.target_dir_col].astype(int) + else: + logger.warning(f"Fold {self.current_fold}: Cannot run forward baseline check: y_test or target_dir_col ('{self.target_dir_col}') is missing.") + + if y_test_ordinal is not None: + forward_check_passed = self.baseline_checker.run_forward_baseline_check( + X_train_fold=self.X_train_scaled, + y_train_fold_ordinal=self.y_dir_train_ordinal, + X_test_fold=self.X_test_scaled, + y_test_fold_ordinal=y_test_ordinal, + fold_num=self.current_fold, + io=self.io + ) + if not forward_check_passed: + # Raise SystemExit to be caught by the except block below + raise SystemExit(f"Fold {self.current_fold}: FAILED Forward Baseline Check Gate") + # <<< END ADDED >>> + + + except SystemExit as fold_exit: + # Log the reason for stopping the fold and continue + logger.error(f"Fold {self.current_fold} processing stopped: {fold_exit}. Skipping subsequent stages (including GRU Validation & SAC) for this fold.") + # Store failure reason more specifically + self.all_fold_metrics.append({'fold_number': self.current_fold, 'status': f'failed_{fold_exit}', 'error': str(fold_exit)}) # <<< Updated status logging + # Log that SAC wasn't reached + logger.info(f"Fold {self.current_fold}: *** Did NOT reach SAC training step due to SystemExit: {fold_exit}. ***") + continue # Proceed to the next fold + except Exception as e: # Catch any exception during fold processing + # Log the unexpected error and continue + logger.error(f"Fold {self.current_fold}: Unhandled exception occurred: {e}. Skipping to next fold.", exc_info=True) # Add exc_info for traceback + self.all_fold_metrics.append({'fold_number': self.current_fold, 'status': 'error', 'error': str(e)}) # <<< Updated status logging + # Log that SAC wasn't reached + logger.info(f"Fold {self.current_fold}: *** Did NOT reach SAC training step due to unhandled exception. ***") + continue # Proceed to the next fold + + # --- End Fold Loop --- # + logger.info(f"=== Finished Processing Fold {self.current_fold} ===") + # If only single split, break after first iteration - Generator handles this now + # No explicit break needed here anymore. + + # 4. Aggregate Fold Metrics & Final Decision + release_decision_passed = False + if self.all_fold_metrics: + self.aggregated_metrics = self.aggregate_fold_metrics(self.all_fold_metrics) + logger.info("--- Aggregated Walk-Forward Metrics --- ") + # Use json dumps for pretty printing dict/nested dict + # <<< ADD CHECK FOR NON-NONE AGGREGATED METRICS BEFORE LOGGING/SAVING >>> + if self.aggregated_metrics is not None: + logger.info(json.dumps(self.aggregated_metrics, default=lambda x: str(x) if isinstance(x, (pd.Timestamp, np.ndarray)) else x, indent=2)) + if self.io: + self.io.save_json(self.aggregated_metrics, 'aggregated_wf_metrics', section='results') + else: + logger.error("Aggregation resulted in None. Cannot log or save aggregated metrics.") + # <<< END CHECK >>> + + # <<< ADD CHECK FOR NON-NONE AGGREGATED METRICS BEFORE FINAL DECISION >>> + if self.aggregated_metrics is not None: + release_decision_passed = self.final_release_decision(self.aggregated_metrics) + else: + logger.error("Skipping final release decision because aggregated metrics are None.") + else: + logger.warning("No fold metrics were generated. Skipping aggregation and final decision.") + self.aggregated_metrics = {} # Ensure defined as empty dict if no metrics + + # Log Final Status + if release_decision_passed: + logger.info(f"--- Pipeline Run {self.run_id} finished successfully and meets release criteria. ---") + else: + logger.error(f"--- Pipeline Run {self.run_id} finished but FAILED to meet release criteria. See aggregated metrics and logs. ---") + + # 5. Aggregate SAC Agents (Optional) + if self.config.get('sac_aggregation', {}).get('enabled', False): + if all_successful_sac_agent_paths: + self.aggregate_sac_agents(all_successful_sac_agent_paths) + else: + logger.warning("SAC agent aggregation enabled, but no successfully trained fold agents were found/stored. Skipping aggregation.") + else: + logger.info("SAC agent aggregation disabled. Skipping.") + + + # --- Log Summary (using aggregated_metrics dict) --- # + logger.info(f"--- Aggregated Fold Metrics Summary ---") + agg_status = self.aggregated_metrics.get("aggregation_status", "unknown") + total_folds_agg = self.aggregated_metrics.get("total_folds", 0) + fully_successful_agg = self.aggregated_metrics.get("fully_successful_folds", 0) + partially_completed_agg = self.aggregated_metrics.get("partially_completed_folds", 0) + failed_early_agg = self.aggregated_metrics.get("failed_early_folds", 0) + pass_rate_agg = self.aggregated_metrics.get("pass_rate", 0.0) + failed_details_agg = self.aggregated_metrics.get("failed_fold_details", {}) + + logger.info(f"Aggregation Status: {agg_status}") + logger.info(f"Total Folds Attempted: {total_folds_agg}") + logger.info(f" - Fully Successful (passed all gates): {fully_successful_agg}") + logger.info(f" - Partially Completed (failed late gate): {partially_completed_agg}") + logger.info(f" - Failed Early / Error: {failed_early_agg}") + logger.info(f"Overall Pass Rate (Fully Successful / Total): {pass_rate_agg:.2%}") + if failed_details_agg: + logger.info("Details for Non-Successful Folds:") + # Sort by fold number if possible + try: + sorted_failed_details = sorted(failed_details_agg.items(), key=lambda item: int(item[0])) + except ValueError: + sorted_failed_details = failed_details_agg.items() # Fallback if fold numbers aren't ints + + for fold, details in sorted_failed_details: + status = details.get('status', 'unknown') + error = details.get('error', '') + logger.info(f" - Fold {fold}: Status='{status}', Error='{error}'") + # --- End Log Summary --- # + + def aggregate_fold_metrics(self, all_fold_metrics: List[dict]) -> dict: + """ + Aggregates metrics collected from successful walk-forward folds. + + Args: + all_fold_metrics (List[dict]): A list where each element is a dictionary + of metrics from a single fold, potentially + including 'status' or 'error' keys for failed folds. + + Returns: + dict: A dictionary containing summary statistics (mean, std, min, max, count) + for key performance metrics across fully successful folds, plus detailed fold counts and pass rate. + Returns an error dict if no folds completed sufficiently for aggregation. + """ + num_total_folds = len(all_fold_metrics) + if num_total_folds == 0: + logger.warning("No fold metrics provided for aggregation.") + return {"aggregation_status": "failed", "reason": "No fold metrics"} + + # --- Categorize Folds --- + fully_successful_folds_data = [] + partially_completed_folds_data = [] + failed_early_folds_data = [] + failed_fold_details = {} # Store status/error for non-successful folds + + # Define statuses indicating partial completion (ran significant steps but failed a late gate) + # Use startswith matching for flexibility (e.g., 'failed_backtest', 'failed_gru_gate', 'failed_forward_baseline') + partial_completion_prefixes = ('failed_backtest', 'failed_gru_gate', 'failed_forward_baseline') + # Define statuses indicating full success + full_success_status = 'success' + + for m in all_fold_metrics: + fold_num = m.get('fold_number', 'unknown') + status = m.get('status', 'unknown') + error_msg = m.get('error', '') + + # Ensure status is a string for startswith check + status_str = str(status) if status is not None else 'unknown' + + # Check for full success first + if status_str == full_success_status and 'Annualized Sharpe Ratio' in m and not pd.isna(m['Annualized Sharpe Ratio']): + fully_successful_folds_data.append(m) + else: + # Store details for all non-successful folds + failed_fold_details[fold_num] = {'status': status_str, 'error': str(error_msg)} + # Categorize further + is_partial = False + for partial_prefix in partial_completion_prefixes: + if status_str.startswith(partial_prefix): + partially_completed_folds_data.append(m) + is_partial = True + break + if not is_partial: + failed_early_folds_data.append(m) # Assume others failed earlier or had generic error + + num_fully_successful = len(fully_successful_folds_data) + num_partially_completed = len(partially_completed_folds_data) + num_failed_early = len(failed_early_folds_data) + # Verification: num_total_folds should equal sum of categories + if num_total_folds != (num_fully_successful + num_partially_completed + num_failed_early): + logger.warning(f"Fold count mismatch during aggregation: Total={num_total_folds}, Categories Sum = {num_fully_successful + num_partially_completed + num_failed_early}. Check status strings/logic.") + + pass_rate = num_fully_successful / num_total_folds if num_total_folds > 0 else 0.0 + # --- Log Summary within aggregation function - can be removed if logged outside --- + # logger.info(f"--- Aggregating Fold Metrics --- ") + # logger.info(f"Total Folds Attempted: {num_total_folds}") + # ... (rest of logging removed as it's now done outside) ... + # --- End Log Summary --- + + # ... (rest of the code remains unchanged) + + def final_release_decision(self, aggregated_metrics: dict) -> bool: + """ + Makes a final decision based on aggregated walk-forward metrics. + + Args: + aggregated_metrics (dict): Dictionary of aggregated metrics from + aggregate_fold_metrics. Can be None if aggregation failed. + + Returns: + bool: True if the performance meets release criteria, False otherwise. + """ + logger.info("--- Making Final Release Decision based on Aggregated Metrics ---") + + # <<< Updated check for aggregation status (Handles None input) >>> + if aggregated_metrics is None: + logger.error("Final Release Decision: FAILED - Aggregated metrics object is None. Cannot evaluate release criteria.") + return False + + if aggregated_metrics.get('aggregation_status') not in ['success', 'no_numeric_metrics', 'no_successful_folds']: + logger.error(f"Final Release Decision: FAILED - Aggregation status indicates failure ('{aggregated_metrics.get('aggregation_status', 'N/A')}'). Cannot evaluate release criteria.") + return False + + # --- Load Release Criteria from Config --- # + gate_config = self.config.get('validation_gates', {}).get('final_release', {}) + criteria = { + 'min_successful_folds_pct': gate_config.get('min_successful_folds_pct', 0.75), + 'median_sharpe_threshold': gate_config.get('median_sharpe_threshold', 1.3), + # Add other criteria as needed (e.g., max_drawdown_max_threshold) + 'max_drawdown_max_threshold': gate_config.get('max_drawdown_max_threshold', 20.0) # Example + } + logger.info(f"Using final release criteria from config: {criteria}") + # --- End Load Criteria --- # + + # --- Check Criteria --- # + passed_all = True + reasons = [] + + # 1. Successful Folds Percentage (Pass Rate) + total_folds = aggregated_metrics.get('total_folds', 0) + # Use the calculated pass_rate from the aggregation dict + pass_rate = aggregated_metrics.get('pass_rate', 0.0) + + if total_folds > 0: + if pass_rate < criteria['min_successful_folds_pct']: + passed_all = False + reasons.append(f"Insufficient pass rate (fully successful folds): {pass_rate:.1%} < {criteria['min_successful_folds_pct']:.1%}") + else: + reasons.append(f"Pass rate OK: {pass_rate:.1%} >= {criteria['min_successful_folds_pct']:.1%}") + else: + passed_all = False # Should have been caught by aggregation status check, but belt-and-suspenders + reasons.append("No folds were run.") + + # --- Check Performance Metrics (only if aggregation status allows) --- + if aggregated_metrics.get('aggregation_status') == 'success': + perf_metrics = aggregated_metrics.get('aggregated_performance_metrics', {}) + sharpe_key = None + if "Annualized Sharpe Ratio (Re-centred)" in perf_metrics: + sharpe_key = "Annualized Sharpe Ratio (Re-centred)" + elif "Annualized Sharpe Ratio" in perf_metrics: + sharpe_key = "Annualized Sharpe Ratio" + + # 2. Median Sharpe Ratio + if sharpe_key: + if hasattr(self, 'all_fold_metrics') and self.all_fold_metrics: + successful_fold_metrics_list = [m for m in self.all_fold_metrics if m.get('status') == 'success' and sharpe_key in m and not pd.isna(m[sharpe_key])] + if successful_fold_metrics_list: + successful_fold_metrics_df = pd.DataFrame(successful_fold_metrics_list) + median_sharpe = successful_fold_metrics_df[sharpe_key].median() + if np.isnan(median_sharpe) or median_sharpe < criteria['median_sharpe_threshold']: + passed_all = False + reasons.append(f"Median Sharpe ({sharpe_key}) too low: {median_sharpe:.2f} < {criteria['median_sharpe_threshold']}") + else: + reasons.append(f"Median Sharpe OK: {median_sharpe:.2f} >= {criteria['median_sharpe_threshold']}") + else: + passed_all = False + reasons.append(f"No fully successful folds with valid Sharpe metric '{sharpe_key}' found for median calculation.") + else: + passed_all = False + reasons.append("Could not access per-fold metrics to calculate median Sharpe.") + else: + # Only fail if the sharpe threshold is actually set in criteria + if 'median_sharpe_threshold' in criteria: + passed_all = False + reasons.append("Sharpe metric required by criteria but not found in aggregation results for median check.") + else: + reasons.append("Sharpe metric not found in aggregation results (median check skipped as not required by criteria).") + + + # 3. Max Drawdown + max_dd_key = "Max Drawdown (%)" + if 'max_drawdown_max_threshold' in criteria: + if max_dd_key in perf_metrics and 'max' in perf_metrics[max_dd_key]: + max_dd_max = perf_metrics[max_dd_key].get('max', np.nan) + max_dd_threshold = criteria['max_drawdown_max_threshold'] + if np.isnan(max_dd_max): + passed_all = False # Treat NaN as failure if threshold is set + reasons.append(f"Max Drawdown is NaN, cannot compare to threshold {max_dd_threshold}%.") + elif max_dd_max > max_dd_threshold: + passed_all = False + reasons.append(f"Max Drawdown exceeded limit: {max_dd_max:.2f}% > {max_dd_threshold}%") + else: + reasons.append(f"Max Drawdown OK: {max_dd_max:.2f}% <= {max_dd_threshold}%") + else: + passed_all = False # Metric required but not found or 'max' stat missing + reasons.append(f"Required Max Drawdown metric ('{max_dd_key}' with 'max' stat) not found in aggregation.") + # ... Add similar checks for other criteria ... + + elif aggregated_metrics.get('aggregation_status') == 'no_numeric_metrics': + reasons.append("Cannot check performance metrics (Sharpe, Drawdown) as none were aggregated.") + # Decision depends on whether min_successful_folds_pct was met + if not passed_all: # If pass rate check already failed + pass # Keep passed_all as False + else: # Pass rate was ok, but no metrics to check + # Decide if this constitutes a failure - let's say yes if performance metrics are crucial + if 'median_sharpe_threshold' in criteria or 'max_drawdown_max_threshold' in criteria: + passed_all = False + reasons.append("Release criteria require performance metrics, but none were available for checking.") + + elif aggregated_metrics.get('aggregation_status') == 'no_successful_folds': + reasons.append("Cannot check performance metrics (Sharpe, Drawdown) as no folds were fully successful.") + # pass_rate check should have already failed if min_successful_folds_pct > 0 + passed_all = False # Explicitly set to False + + # --- Log Final Decision --- # + decision_str = "PASSED" if passed_all else "FAILED" + logger.info(f"Final Release Decision: {decision_str}") + for reason in reasons: + logger.info(f" - {reason}") + + return passed_all + # Removed dangling else/except block that didn't belong here + + # ... (rest of the code remains unchanged) + + def aggregate_sac_agents(self, agent_paths: List[str]): + """Averages the weights of multiple SAC agents saved during walk-forward folds.""" + if not agent_paths: + logger.warning("No SAC agent paths provided for aggregation. Skipping.") + return + + logger.info(f"Starting aggregation of {len(agent_paths)} SAC agents...") + + # --- Load State Dicts --- # + all_state_dicts = [] + reference_state_dict = None + for path in agent_paths: + try: + # Load onto CPU to avoid potential GPU memory issues during aggregation + state_dict = torch.load(path, map_location=torch.device('cpu')) + + # Basic validation: check if it's a dictionary + if not isinstance(state_dict, dict): + logger.warning(f"Skipping agent at {path}: Loaded object is not a dictionary.") + continue + + # Further validation: Check if essential keys exist (adjust based on SACTradingAgent structure) + # Example check (modify based on actual keys): + required_keys = ['actor_state_dict', 'critic_state_dict'] + if not all(key in state_dict for key in required_keys): + logger.warning(f"Skipping agent at {path}: Missing required keys ({required_keys}).") + continue + + all_state_dicts.append(state_dict) + if reference_state_dict is None: + reference_state_dict = state_dict # Use first valid one as structure reference + + # Optional: Add more detailed structure comparison here if needed + # e.g., check if keys within actor/critic state dicts match reference + + except FileNotFoundError: + logger.warning(f"Skipping agent aggregation: File not found at {path}") + except Exception as e: + logger.warning(f"Skipping agent aggregation: Failed to load agent from {path}: {e}") + + if len(all_state_dicts) < 2: + logger.warning(f"Need at least two valid agents for aggregation, found {len(all_state_dicts)}. Skipping aggregation.") + return + + if reference_state_dict is None: + logger.error("Could not load any valid reference state dict. Cannot proceed with aggregation.") + return + + logger.info(f"Successfully loaded {len(all_state_dicts)} valid state dictionaries for aggregation.") + + # --- Average Parameters --- # + # Deepcopy the structure from the reference + aggregated_state_dict = copy.deepcopy(reference_state_dict) + + # Iterate through components (actor, critic, etc.) + for component_key in reference_state_dict.keys(): # e.g., 'actor_state_dict', 'critic_state_dict' + if isinstance(reference_state_dict[component_key], OrderedDict): + logger.debug(f"Aggregating parameters for component: {component_key}") + # Iterate through parameters within the component's state dict + for param_key in reference_state_dict[component_key].keys(): + # Sum tensors from all loaded state dicts + summed_tensor = None + valid_tensors_count = 0 + for sd in all_state_dicts: + # Check if the component and parameter exist in this state dict + if component_key in sd and isinstance(sd[component_key], OrderedDict) and param_key in sd[component_key]: + tensor = sd[component_key][param_key] + if summed_tensor is None: + summed_tensor = tensor.clone().detach().float() # Use float for averaging + else: + # Ensure tensors are compatible before adding + if summed_tensor.shape == tensor.shape: + summed_tensor += tensor.float() + else: + logger.warning(f"Shape mismatch for {component_key}.{param_key}. Skipping tensor from one agent.") + continue # Skip this tensor, but potentially count others? + valid_tensors_count += 1 + else: + logger.warning(f"Parameter {component_key}.{param_key} not found in one of the state dicts. Skipping.") + + # Calculate average if tensors were found and summed + if summed_tensor is not None and valid_tensors_count > 0: + averaged_tensor = summed_tensor / valid_tensors_count + # Update the aggregated state dict + aggregated_state_dict[component_key][param_key] = averaged_tensor + logger.debug(f" Averaged {component_key}.{param_key} from {valid_tensors_count} agents.") + elif valid_tensors_count == 0: + logger.warning(f" Could not find any valid tensors for {component_key}.{param_key}. Keeping reference value.") + else: + # Handle non-state-dict items if necessary (e.g., config, optimizers - usually not averaged) + logger.debug(f"Skipping non-state-dict component: {component_key}") + # Keep the value from the reference dict for non-parameter items + aggregated_state_dict[component_key] = reference_state_dict[component_key] + + # --- Save Aggregated Agent --- # + if self.io: + try: + agg_agent_filename = 'sac_agent_aggregated' # Consistent filename + save_path = self.io.path('models', agg_agent_filename, suffix='.pth') # Save in main run models dir + torch.save(aggregated_state_dict, save_path) + logger.info(f"Successfully saved aggregated SAC agent state dictionary to: {save_path}") + except Exception as e: + logger.error(f"Failed to save aggregated SAC agent: {e}", exc_info=True) + else: + logger.warning("IOManager not available. Cannot save aggregated SAC agent.") + \ No newline at end of file diff --git a/gru_sac_predictor/tests/test_calibration.py b/gru_sac_predictor/tests/test_calibration.py new file mode 100644 index 00000000..2634b72f --- /dev/null +++ b/gru_sac_predictor/tests/test_calibration.py @@ -0,0 +1,183 @@ +""" +Tests for probability calibration (Sec 6 of revisions.txt). +""" +import pytest +import numpy as np +from scipy.stats import binomtest +from scipy.special import logit, expit +import os + +# Try to import the modules; skip tests if not found (e.g., path issues) +try: + from gru_sac_predictor.src import calibrate +except ImportError: + calibrate = None + +# --- Import VectorCalibrator (Task 4) --- # +try: + from gru_sac_predictor.src.calibrator_vector import VectorCalibrator +except ImportError: + VectorCalibrator = None +# --- End Import --- # + +# --- Helper Function for ECE --- # +def _calculate_ece(probs: np.ndarray, y_true: np.ndarray, n_bins: int = 10) -> float: + """ + Calculates the Expected Calibration Error (ECE). + + Args: + probs (np.ndarray): Predicted probabilities for the positive class (N,) or all classes (N, K). + y_true (np.ndarray): True labels (0 or 1 for binary, or class index for multi-class). + n_bins (int): Number of bins to divide probabilities into. + + Returns: + float: The calculated ECE score. + """ + if len(probs.shape) == 1: # Binary case + p_max = probs + y_pred_class = (probs > 0.5).astype(int) + y_true_class = y_true + elif len(probs.shape) == 2: # Multi-class case + p_max = np.max(probs, axis=1) + y_pred_class = np.argmax(probs, axis=1) + # If y_true is one-hot, convert to class index + if len(y_true.shape) == 2 and y_true.shape[1] > 1: + y_true_class = np.argmax(y_true, axis=1) + else: + y_true_class = y_true # Assume already class index + else: + raise ValueError("probs array must be 1D or 2D") + + ece = 0.0 + bin_boundaries = np.linspace(0, 1, n_bins + 1) + + for i in range(n_bins): + in_bin = (p_max > bin_boundaries[i]) & (p_max <= bin_boundaries[i+1]) + prop_in_bin = np.mean(in_bin) + + if prop_in_bin > 0: + accuracy_in_bin = np.mean(y_pred_class[in_bin] == y_true_class[in_bin]) + avg_confidence_in_bin = np.mean(p_max[in_bin]) + ece += np.abs(accuracy_in_bin - avg_confidence_in_bin) * prop_in_bin + + return ece +# --- End ECE Helper --- # + +# --- Fixtures --- +@pytest.fixture(scope="module") +def calibration_data(): + """ + Generate sample raw probabilities and true outcomes. + Simulates an overconfident model (T_implied < 1) where true probability drifts. + """ + np.random.seed(42) + n_samples = 2500 + # Simulate drifting true probability centered around 0.5 + drift = 0.05 * np.sin(np.linspace(0, 3 * np.pi, n_samples)) + true_prob = np.clip(0.5 + drift + np.random.randn(n_samples) * 0.05, 0.05, 0.95) + # Simulate overconfidence (implied T ~ 0.7) + raw_logits = logit(true_prob) / 0.7 + p_raw = expit(raw_logits) + # Generate true outcomes + y_true = (np.random.rand(n_samples) < true_prob).astype(int) + return p_raw, y_true + +# --- Tests --- +@pytest.mark.skipif(calibrate is None, reason="Module gru_sac_predictor.src.calibrate not found") +def test_optimise_temperature(calibration_data): + """Check if optimise_temperature runs and returns a plausible value.""" + p_raw, y_true = calibration_data + optimal_T = calibrate.optimise_temperature(p_raw, y_true) + print(f"\nOptimised T: {optimal_T:.4f}") + # Expect T > 0. A T near 0.7 would undo the simulated effect. + assert optimal_T > 0.1 and optimal_T < 5.0, "Optimised temperature seems out of expected range." + +@pytest.mark.skipif(calibrate is None, reason="Module gru_sac_predictor.src.calibrate not found") +def test_calibration_hit_rate_threshold(calibration_data): + """ + Verify that the lower 95% CI of the hit-rate for non-zero calibrated + signals is >= 0.55 (using the module's EDGE_THR). + """ + p_raw, y_true = calibration_data + optimal_T = calibrate.optimise_temperature(p_raw, y_true) + p_cal = calibrate.calibrate(p_raw, optimal_T) + action_signals = calibrate.action_signal(p_cal) + + # Filter for non-zero signals + non_zero_idx = action_signals != 0 + if not np.any(non_zero_idx): + pytest.fail("No non-zero action signals generated for hit-rate test.") + + signals_taken = action_signals[non_zero_idx] + actual_direction = y_true[non_zero_idx] + + # Hit: signal matches actual direction (1 vs 1, -1 vs 0) + hits = np.sum((signals_taken == 1) & (actual_direction == 1)) + \ + np.sum((signals_taken == -1) & (actual_direction == 0)) + total_trades = len(signals_taken) + + if total_trades < 30: + pytest.skip(f"Insufficient non-zero signals ({total_trades}) for reliable CI.") + + # Calculate 95% lower CI using binomial test + try: + # Ensure hits is integer + hits = int(hits) + result = binomtest(hits, total_trades, p=0.5, alternative='greater') + lower_ci = result.proportion_ci(confidence_level=0.95).low + except Exception as e: + pytest.fail(f"Binomial test failed: {e}") + + hit_rate = hits / total_trades + required_threshold = calibrate.EDGE_THR # Use threshold from module + + print(f"\nCalibration Test: EDGE_THR={required_threshold:.3f}") + print(f" Trades={total_trades}, Hits={hits}, Hit Rate={hit_rate:.4f}") + print(f" 95% Lower CI: {lower_ci:.4f}") + + assert lower_ci >= required_threshold, \ + f"Hit rate lower CI ({lower_ci:.4f}) is below module threshold ({required_threshold:.3f})" + +# --- Vector Scaling Test (Task 4.4) --- # +@pytest.mark.skipif(VectorCalibrator is None, reason="VectorCalibrator not found") +def test_vector_scaling_calibration(): + """Check if Vector Scaling reduces ECE on sample multi-class data.""" + np.random.seed(123) + n_samples = 5000 + num_classes = 3 + + # Simulate slightly miscalibrated logits (e.g., too peaky or too flat) + # True distribution is uniform-ish + true_labels = np.random.randint(0, num_classes, n_samples) + y_onehot = tf.keras.utils.to_categorical(true_labels, num_classes=num_classes) + + # Generate logits - make class 1 slightly more likely, and make logits "peaky" + logits_raw = np.random.randn(n_samples, num_classes) * 0.5 # Base noise + logits_raw[:, 1] += 0.5 # Bias towards class 1 + # Add systematic miscalibration (e.g., scale up logits -> overconfidence) + logits_miscalibrated = logits_raw * 1.8 + + # Instantiate calibrator + vector_cal = VectorCalibrator() + + # Calculate ECE before calibration + probs_uncal = vector_cal._softmax(logits_miscalibrated) + ece_before = _calculate_ece(probs_uncal, true_labels) + + # Fit vector scaling + vector_cal.fit(logits_miscalibrated, y_onehot) + assert vector_cal.W is not None and vector_cal.b is not None, "Vector scaling fit failed" + + # Calibrate probabilities + probs_cal = vector_cal.calibrate(logits_miscalibrated) + + # Calculate ECE after calibration + ece_after = _calculate_ece(probs_cal, true_labels) + + print(f"\nVector Scaling Test: ECE Before = {ece_before:.4f}, ECE After = {ece_after:.4f}") + + # Assert that ECE improved (decreased) + # Allow for slight numerical noise, but expect significant improvement + assert ece_after < ece_before * 0.7, f"ECE did not improve significantly after Vector Scaling (Before: {ece_before:.4f}, After: {ece_after:.4f})" + # Assert ECE is reasonably low after calibration + assert ece_after < 0.05, f"ECE after Vector Scaling ({ece_after:.4f}) is higher than expected (< 0.05)" \ No newline at end of file diff --git a/gru_sac_predictor/tests/test_data_loader.py b/gru_sac_predictor/tests/test_data_loader.py new file mode 100644 index 00000000..3f0a6ad9 --- /dev/null +++ b/gru_sac_predictor/tests/test_data_loader.py @@ -0,0 +1,253 @@ +import pytest +import pandas as pd +import numpy as np +from omegaconf import OmegaConf, DictConfig +from unittest.mock import MagicMock, call # For mocking IOManager and logger +import os +import tempfile +import json + +# Adjust the import path based on your project structure +from gru_sac_predictor.src.data_loader import ( + find_missing_bars, + _consecutive_gaps, + fill_missing_bars, + report_missing # Import report_missing if we want to test its side effects +) +from gru_sac_predictor.src.io_manager import IOManager # Adjust path if needed + +# --- Test Fixtures --- + +@pytest.fixture +def gappy_dataframe(): + """Creates a DataFrame with missing timestamps.""" + dates = pd.to_datetime([ + '2023-01-01 00:00:00', '2023-01-01 00:01:00', + # Missing 00:02:00, 00:03:00 (2 bars) + '2023-01-01 00:04:00', + # Missing 00:05:00 (1 bar) + '2023-01-01 00:06:00', '2023-01-01 00:07:00', + # Missing 00:08:00, 00:09:00, 00:10:00 (3 bars) + '2023-01-01 00:11:00', '2023-01-01 00:12:00' + ]) + data = { + 'open': [100, 101, 104, 106, 107, 111, 112], + 'high': [100.5, 101.5, 104.5, 106.5, 107.5, 111.5, 112.5], + 'low': [99.5, 100.5, 103.5, 105.5, 106.5, 110.5, 111.5], + 'close': [100.2, 101.2, 104.2, 106.2, 107.2, 111.2, 112.2], + 'volume': [10, 11, 14, 16, 17, 21, 22] + } + df = pd.DataFrame(data, index=dates) + df.index.name = 'timestamp' + return df + +@pytest.fixture +def base_config(): + """Creates a base OmegaConf config for testing.""" + conf = OmegaConf.create({ + 'data': { + 'bar_frequency': '1T', # Using 'T' for minute frequency + 'missing': { + 'strategy': 'neutral', # Default strategy + 'max_gap': 5, + 'interpolate': { + 'method': 'linear', + 'limit': 10 + } + } + } + # Add other necessary sections if fill_missing_bars requires them + }) + return conf + +@pytest.fixture +def mock_io_manager(): + """Creates a mock IOManager that saves to a temporary directory.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Mock the IOManager methods needed by report_missing + mock_io = MagicMock(spec=IOManager) + mock_io.results_dir = tmpdir + + # Define how save_json should behave + saved_jsons = {} + def mock_save_json(data, filename, **kwargs): + filepath = os.path.join(tmpdir, filename) + saved_jsons[filename] = data # Store saved data for inspection + with open(filepath, 'w') as f: + json.dump(data, f, **kwargs) + print(f"Mock saving {filename} to {filepath}") + + mock_io.save_json.side_effect = mock_save_json + mock_io.get_artifact_path.side_effect = lambda filename: os.path.join(tmpdir, filename) + + # Attach the saved data dictionary for test inspection + mock_io._saved_jsons = saved_jsons + + yield mock_io # Provide the mock object to the test + + +@pytest.fixture +def mock_logger(): + """Creates a mock logger.""" + return MagicMock() + +# --- Test Functions --- + +def test_find_missing_bars(gappy_dataframe): + expected_missing = pd.to_datetime([ + '2023-01-01 00:02:00', '2023-01-01 00:03:00', + '2023-01-01 00:05:00', + '2023-01-01 00:08:00', '2023-01-01 00:09:00', '2023-01-01 00:10:00' + ]) + missing = find_missing_bars(gappy_dataframe, freq='T') + pd.testing.assert_index_equal(missing, expected_missing, check_names=False) + +def test_consecutive_gaps(gappy_dataframe): + missing = find_missing_bars(gappy_dataframe, freq='T') + gaps = _consecutive_gaps(missing) + assert sorted(gaps) == sorted([2, 1, 3]) # Check counts of consecutive missing bars + +def test_fill_missing_bars_strategy_drop(gappy_dataframe, base_config, mock_io_manager, mock_logger): + cfg = base_config.copy() + cfg.data.missing.strategy = 'drop' + + original_df = gappy_dataframe.copy() + filled_df = fill_missing_bars(original_df, cfg, mock_io_manager, mock_logger) + + # Strategy 'drop' should return the original data, but with imputed flag + assert 'bar_imputed' in filled_df.columns + pd.testing.assert_frame_equal(filled_df.drop(columns=['bar_imputed']), original_df) + + # Check imputed flags (should be false for existing bars) + assert not filled_df['bar_imputed'].any() + +def test_fill_missing_bars_strategy_neutral(gappy_dataframe, base_config, mock_io_manager, mock_logger): + cfg = base_config.copy() + cfg.data.missing.strategy = 'neutral' + + filled_df = fill_missing_bars(gappy_dataframe.copy(), cfg, mock_io_manager, mock_logger) + + assert 'bar_imputed' in filled_df.columns + assert len(filled_df) == 13 # Original 7 + 6 missing + + # Check imputed timestamps + missing_ts = find_missing_bars(gappy_dataframe, freq='T') + pd.testing.assert_index_equal(filled_df[filled_df['bar_imputed']].index, missing_ts) + assert not filled_df[~filled_df['bar_imputed']].index.isin(missing_ts).any() + + # Check neutral fill logic for a specific imputed bar (00:02:00) + imputed_bar = filled_df.loc['2023-01-01 00:02:00'] + last_close = gappy_dataframe.loc['2023-01-01 00:01:00', 'close'] + assert imputed_bar['open'] == last_close + assert imputed_bar['high'] == last_close + assert imputed_bar['low'] == last_close + assert imputed_bar['close'] == last_close + assert imputed_bar['volume'] == 0 + assert imputed_bar['bar_imputed'] == True + + # Check a non-imputed bar remains unchanged + original_bar = gappy_dataframe.loc['2023-01-01 00:04:00'] + filled_bar = filled_df.loc['2023-01-01 00:04:00'] + assert filled_bar['close'] == original_bar['close'] + assert filled_bar['volume'] > 0 # Should not be 0 + assert filled_bar['bar_imputed'] == False + + # Check log message (optional, depends on mocking detail) + mock_logger.warning.assert_called_once() + assert 'Detected 6 missing bars' in mock_logger.warning.call_args[0][0] + # Check saved report (optional) + assert 'missing_bars_summary.json' in mock_io_manager._saved_jsons + report_data = mock_io_manager._saved_jsons['missing_bars_summary.json'] + assert report_data['total_missing_bars'] == 6 + assert report_data['longest_consecutive_gap'] == 3 + assert report_data['applied_strategy'] == 'neutral' + + +def test_fill_missing_bars_strategy_ffill(gappy_dataframe, base_config, mock_io_manager, mock_logger): + cfg = base_config.copy() + cfg.data.missing.strategy = 'ffill' + + filled_df = fill_missing_bars(gappy_dataframe.copy(), cfg, mock_io_manager, mock_logger) + + assert 'bar_imputed' in filled_df.columns + assert len(filled_df) == 13 + + # Check imputed timestamps + missing_ts = find_missing_bars(gappy_dataframe, freq='T') + pd.testing.assert_index_equal(filled_df[filled_df['bar_imputed']].index, missing_ts) + + # Check ffill logic for a specific imputed bar (00:02:00) + imputed_bar = filled_df.loc['2023-01-01 00:02:00'] + prev_bar = gappy_dataframe.loc['2023-01-01 00:01:00'] + assert imputed_bar['open'] == prev_bar['open'] + assert imputed_bar['high'] == prev_bar['high'] + assert imputed_bar['low'] == prev_bar['low'] + assert imputed_bar['close'] == prev_bar['close'] + assert imputed_bar['volume'] == prev_bar['volume'] + assert imputed_bar['bar_imputed'] == True + + # Check another imputed bar (00:05:00) + imputed_bar_2 = filled_df.loc['2023-01-01 00:05:00'] + prev_bar_2 = gappy_dataframe.loc['2023-01-01 00:04:00'] + assert imputed_bar_2['close'] == prev_bar_2['close'] + assert imputed_bar_2['bar_imputed'] == True + + +def test_fill_missing_bars_strategy_interpolate(gappy_dataframe, base_config, mock_io_manager, mock_logger): + cfg = base_config.copy() + cfg.data.missing.strategy = 'interpolate' + cfg.data.missing.interpolate.method = 'linear' + + filled_df = fill_missing_bars(gappy_dataframe.copy(), cfg, mock_io_manager, mock_logger) + + assert 'bar_imputed' in filled_df.columns + assert len(filled_df) == 13 + + # Check imputed timestamps + missing_ts = find_missing_bars(gappy_dataframe, freq='T') + pd.testing.assert_index_equal(filled_df[filled_df['bar_imputed']].index, missing_ts) + + # Check interpolation logic for close price at 00:02:00 and 00:03:00 + # These are between 00:01:00 (101.2) and 00:04:00 (104.2) - 3 steps total + close_01 = gappy_dataframe.loc['2023-01-01 00:01:00', 'close'] + close_04 = gappy_dataframe.loc['2023-01-01 00:04:00', 'close'] + expected_close_02 = close_01 + (close_04 - close_01) / 3.0 * 1 + expected_close_03 = close_01 + (close_04 - close_01) / 3.0 * 2 + + assert np.isclose(filled_df.loc['2023-01-01 00:02:00', 'close'], expected_close_02) + assert np.isclose(filled_df.loc['2023-01-01 00:03:00', 'close'], expected_close_03) + assert filled_df.loc['2023-01-01 00:02:00', 'bar_imputed'] == True + assert filled_df.loc['2023-01-01 00:03:00', 'bar_imputed'] == True + + # Check interpolation logic for volume at 00:05:00 + # This is between 00:04:00 (14) and 00:06:00 (16) - 2 steps total + vol_04 = gappy_dataframe.loc['2023-01-01 00:04:00', 'volume'] + vol_06 = gappy_dataframe.loc['2023-01-01 00:06:00', 'volume'] + expected_vol_05 = vol_04 + (vol_06 - vol_04) / 2.0 * 1 + assert np.isclose(filled_df.loc['2023-01-01 00:05:00', 'volume'], expected_vol_05) + assert filled_df.loc['2023-01-01 00:05:00', 'bar_imputed'] == True + +def test_fill_missing_bars_max_gap_exceeded(gappy_dataframe, base_config, mock_io_manager, mock_logger): + cfg = base_config.copy() + cfg.data.missing.max_gap = 2 # Set max gap lower than the actual max gap (3) + + with pytest.raises(ValueError) as excinfo: + fill_missing_bars(gappy_dataframe.copy(), cfg, mock_io_manager, mock_logger) + + assert "Longest consecutive gap (3) exceeds maximum allowed (2)" in str(excinfo.value) + # Verify report was still generated before the error + mock_logger.warning.assert_called_once() + assert 'missing_bars_summary.json' in mock_io_manager._saved_jsons + +def test_fill_missing_bars_no_missing_data(gappy_dataframe, base_config, mock_io_manager, mock_logger): + # Create a dataframe with no gaps + no_gap_df = pd.DataFrame({ + 'open': [1, 2, 3], 'high': [1, 2, 3], 'low': [1, 2, 3], 'close': [1, 2, 3], 'volume': [10, 20, 30] + }, index=pd.date_range('2023-01-01', periods=3, freq='T')) + + filled_df = fill_missing_bars(no_gap_df.copy(), base_config, mock_io_manager, mock_logger) + + assert 'bar_imputed' in filled_df.columns + assert not filled_df['bar_imputed'].any() # No bars should be marked as imputed + pd.testing.assert_frame_equal(filled_df.drop(columns=['bar_imputed']), no_gap_df) + mock_logger.info.assert_any_call("No missing bars detected.") # Check info log \ No newline at end of file diff --git a/gru_sac_predictor/tests/test_feature_engineer.py b/gru_sac_predictor/tests/test_feature_engineer.py new file mode 100644 index 00000000..cc6ccf3b --- /dev/null +++ b/gru_sac_predictor/tests/test_feature_engineer.py @@ -0,0 +1,125 @@ +""" +Tests for the FeatureEngineer class and its methods. + +Ref: revisions.txt Task 2.5 +""" + +import pytest +import pandas as pd +import numpy as np +import sys, os +from unittest.mock import patch, MagicMock + +# --- Add path for src imports --- # +script_dir = os.path.dirname(os.path.abspath(__file__)) +project_root = os.path.dirname(script_dir) +src_path = os.path.join(project_root, 'src') +if src_path not in sys.path: + sys.path.insert(0, src_path) +# --- End Add path --- # + +from feature_engineer import FeatureEngineer +# Import minimal_whitelist from features to pass to constructor +from features import minimal_whitelist as base_minimal_whitelist + +# --- Fixtures --- # + +@pytest.fixture +def sample_engineer() -> FeatureEngineer: + """Provides a FeatureEngineer instance with a basic whitelist.""" + # Use a copy to avoid modifying the original during tests + test_whitelist = base_minimal_whitelist.copy() + return FeatureEngineer(minimal_whitelist=test_whitelist) + +@pytest.fixture +def sample_feature_data() -> pd.DataFrame: + """Creates sample features for testing selection.""" + np.random.seed(42) + data = { + 'return_1m': np.random.randn(100) * 0.01, + 'EMA_50': 100 + np.random.randn(100).cumsum() * 0.1, + 'ATR_14': np.random.rand(100) * 0.5, + 'hour_sin': np.sin(np.linspace(0, 2 * np.pi, 100)), + 'highly_correlated_1': 100 + np.random.randn(100).cumsum() * 0.1, # Copy EMA_50 roughly + 'highly_correlated_2': 101 + np.random.randn(100).cumsum() * 0.1, # Copy EMA_50 roughly + 'constant_feat': np.ones(100), + 'nan_feat': np.full(100, np.nan), + 'inf_feat': np.full(100, np.inf) + } + index = pd.date_range(start='2023-01-01', periods=100, freq='min', tz='UTC') + df = pd.DataFrame(data, index=index) + # Add the correlation + df['highly_correlated_1'] = df['EMA_50'] * (1 + np.random.randn(100) * 0.01) + df['highly_correlated_2'] = df['highly_correlated_1'] * (1 + np.random.randn(100) * 0.01) + return df + +@pytest.fixture +def sample_target_data() -> pd.Series: + """Creates sample binary target variable.""" + np.random.seed(123) + # Create somewhat predictable target based on EMA_50 trend + ema = 100 + np.random.randn(100).cumsum() * 0.1 + target = (np.diff(ema, prepend=0) > 0).astype(int) + index = pd.date_range(start='2023-01-01', periods=100, freq='min', tz='UTC') + return pd.Series(target, index=index) + +# --- Tests --- # + +def test_select_features_vif_skip(sample_engineer, sample_feature_data, sample_target_data): + """ + Test 2.5: Assert VIF calculation is skipped if skip_vif=True in config. + We need to mock the config access within select_features. + """ + engineer = sample_engineer + X_train = sample_feature_data + y_train = sample_target_data + + # Mock the config dictionary that would be passed or accessed + # For now, assume select_features might take an optional config or we patch where it reads it. + # Since it doesn't currently take config, we have to modify the method or mock dependencies. + # Let's *assume* for this test that select_features *will be* modified to check a config. + # We will patch the VIF function itself and assert it's not called. + + # Add a feature that would definitely be removed by VIF to ensure the check matters + X_train['perfectly_correlated'] = X_train['EMA_50'] * 2 + + with patch('feature_engineer.variance_inflation_factor') as mock_vif: + # We also need to mock the SelectFromModel part to return *some* features initially + with patch('feature_engineer.SelectFromModel') as mock_select_from_model: + # Configure the mock selector to return a subset of features including correlated ones + mock_instance = MagicMock() + initial_selection = [True] * 5 + [False] * 4 + [True] # Select first 5 + perfectly_correlated + mock_instance.get_support.return_value = np.array(initial_selection) + mock_select_from_model.return_value = mock_instance + + # Call select_features - **modify it conceptually to accept skip_vif** + # Since we can't modify the source directly here, we test by asserting VIF wasn't called. + # This implicitly tests the skip logic. + + # Simulate the call as if skip_vif=True was passed/checked internally + # Patch the VIF calculation call site directly + with patch('feature_engineer.sm.add_constant') as mock_add_constant: # VIF loop uses this + # Call the function normally - the patch on VIF itself is the key + selected_features = engineer.select_features(X_train, y_train) + + # Assert that variance_inflation_factor was NOT called + mock_vif.assert_not_called() + # Assert that add_constant (used within VIF loop) was also NOT called + mock_add_constant.assert_not_called() + + # Assert that the features returned are those from the mocked L1 selection + # (potentially plus minimal whitelist, depending on implementation) + # The exact output depends on how L1 + whitelist are combined *before* VIF step + # Let's just assert the correlated feature IS included, as VIF didn't remove it + assert 'perfectly_correlated' in selected_features + + # We should also check that the log message indicating VIF skip was printed + # (This requires capturing logs, omitted here for brevity) + +# TODO: Add more tests for FeatureEngineer +# - Test feature calculation methods (_add_cyclical_features, _add_imbalance_features, _add_ta_features) +# - Test add_base_features orchestration +# - Test select_features VIF logic *when enabled* (e.g., check correlated feature is removed) +# - Test select_features LogReg L1 logic (e.g., check constant feature is removed) +# - Test handling of NaNs/Infs in select_features +# - Test prune_features (although covered in test_feature_pruning.py) \ No newline at end of file diff --git a/gru_sac_predictor/tests/test_feature_pruning.py b/gru_sac_predictor/tests/test_feature_pruning.py new file mode 100644 index 00000000..89c6141a --- /dev/null +++ b/gru_sac_predictor/tests/test_feature_pruning.py @@ -0,0 +1,87 @@ +""" +Tests for feature pruning logic. + +Ref: revisions.txt Step 1-D +""" +import pytest +import pandas as pd + +# TODO: Import prune_features function and minimal_whitelist from src.features +# from gru_sac_predictor.src.features import prune_features, minimal_whitelist + +# Mock minimal_whitelist for testing if import fails +minimal_whitelist = ['feat_a', 'feat_b', 'feat_c', 'hour_sin'] + +# Mock prune_features if import fails +def prune_features(df: pd.DataFrame, whitelist: list[str] | None = None) -> pd.DataFrame: + if whitelist is None: + whitelist = minimal_whitelist + cols_to_keep = [c for c in whitelist if c in df.columns] + df_pruned = df[cols_to_keep].copy() + assert set(df_pruned.columns) == set(cols_to_keep), \ + f"Pruning failed: Output columns {set(df_pruned.columns)} != Expected intersection {set(cols_to_keep)}" + return df_pruned + + +@pytest.fixture +def sample_dataframe() -> pd.DataFrame: + """Create a sample DataFrame for testing.""" + data = { + 'feat_a': [1, 2, 3], + 'feat_b': [4, 5, 6], + 'feat_extra': [7, 8, 9], + 'hour_sin': [0.1, 0.2, 0.3] + } + return pd.DataFrame(data) + + +def test_prune_to_minimal_whitelist(sample_dataframe): + """Test pruning to the default minimal whitelist.""" + df_pruned = prune_features(sample_dataframe, whitelist=minimal_whitelist) + + expected_cols = {'feat_a', 'feat_b', 'hour_sin'} + assert set(df_pruned.columns) == expected_cols + assert 'feat_extra' not in df_pruned.columns + +def test_prune_with_custom_whitelist(sample_dataframe): + """Test pruning with a custom whitelist.""" + custom_whitelist = ['feat_a', 'feat_extra'] + df_pruned = prune_features(sample_dataframe, whitelist=custom_whitelist) + + expected_cols = {'feat_a', 'feat_extra'} + assert set(df_pruned.columns) == expected_cols + assert 'feat_b' not in df_pruned.columns + assert 'hour_sin' not in df_pruned.columns + +def test_prune_missing_whitelist_cols(sample_dataframe): + """Test when whitelist contains columns not in the dataframe.""" + custom_whitelist = ['feat_a', 'feat_c', 'hour_sin'] # feat_c is not in sample_dataframe + df_pruned = prune_features(sample_dataframe, whitelist=custom_whitelist) + + expected_cols = {'feat_a', 'hour_sin'} # Only existing columns are kept + assert set(df_pruned.columns) == expected_cols + assert 'feat_c' not in df_pruned.columns + +def test_prune_empty_whitelist(): + """Test pruning with an empty whitelist.""" + df = pd.DataFrame({'a': [1], 'b': [2]}) + df_pruned = prune_features(df, whitelist=[]) + assert df_pruned.empty + assert df_pruned.columns.empty + +def test_prune_empty_dataframe(): + """Test pruning an empty dataframe.""" + df = pd.DataFrame() + df_pruned = prune_features(df, whitelist=minimal_whitelist) + assert df_pruned.empty + assert df_pruned.columns.empty + +def test_prune_assertion(sample_dataframe): + """Verify the assertion within prune_features catches mismatches (requires mocking or specific setup).""" + # This test might be tricky without modifying the function or using complex mocks. + # The assertion `assert set(df_pruned.columns) == set(cols_to_keep)` should generally hold + # if the logic `df_pruned = df[cols_to_keep].copy()` is correct. + # We rely on the other tests implicitly covering this assertion. + pytest.skip("Assertion test might require specific mocking setup.") + +# Add tests for edge cases like DataFrames with duplicate column names if relevant. \ No newline at end of file diff --git a/gru_sac_predictor/tests/test_gru_model_handler.py b/gru_sac_predictor/tests/test_gru_model_handler.py new file mode 100644 index 00000000..fd8d8237 --- /dev/null +++ b/gru_sac_predictor/tests/test_gru_model_handler.py @@ -0,0 +1,84 @@ +import pytest +import tensorflow as tf +import numpy as np +import os +import sys + +# Add the src directory to the Python path to allow imports +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../src'))) + +try: + from gru_model_handler import build_gru_model_v3 + MODEL_HANDLER_AVAILABLE = True +except ImportError as e: + print(f"Failed to import gru_model_handler: {e}. Skipping tests that require it.") + MODEL_HANDLER_AVAILABLE = False + # Define a dummy function to avoid errors if import fails + def build_gru_model_v3(*args, **kwargs): + raise ImportError("gru_model_handler could not be imported.") + +@pytest.mark.skipif(not MODEL_HANDLER_AVAILABLE, reason="GRUModelHandler module not found or TensorFlow/Keras missing") +def test_causal_mask_lambda_layer(): + """Tests if the Lambda layer correctly generates a lower-triangular causal mask.""" + lookback = 60 + n_features = 96 + attention_units = 16 # Must be > 0 to include the attention layer and mask + + # Build a minimal model instance to access the layer + try: + model = build_gru_model_v3( + lookback=lookback, + n_features=n_features, + attention_units=attention_units + # Use default values for other params for this test + ) + except Exception as build_err: + pytest.fail(f"Failed to build GRU model for testing: {build_err}") + + + # Find the Lambda layer by name + try: + mask_lambda_layer = model.get_layer('causal_mask_lambda') + except ValueError: + pytest.fail("Could not find the 'causal_mask_lambda' layer in the model.") + + # Create a dummy input tensor (batch size 1, sequence length lookback, feature dim n_features) + # The content doesn't matter, only the shape for the mask generation + dummy_input_tensor = tf.zeros((1, lookback, n_features)) + + # Get the mask output from the Lambda layer + # Note: Calling the layer directly works outside tf.function context + mask_output = mask_lambda_layer(dummy_input_tensor) + + # Expected mask is lower triangular matrix of shape (lookback, lookback) + expected_mask = np.tril(np.ones((lookback, lookback))) + + # Compare the Lambda layer output with the expected mask + # The mask_output might have a batch dimension added by Keras depending on context, + # but the core logic in create_mask returns (seq_len, seq_len). + # Let's check the shape and squeeze if necessary, although direct call might not add batch dim. + + # Convert tensor to numpy for comparison + mask_numpy = mask_output.numpy() + + # Assert shape is as expected (seq_len, seq_len) + assert mask_numpy.shape == (lookback, lookback), f"Expected mask shape {(lookback, lookback)}, but got {mask_numpy.shape}" + + # Assert the values are close to the lower triangular matrix + assert np.allclose(mask_numpy, expected_mask), "Mask generated by Lambda layer is not lower-triangular." + + print("Causal mask test passed.") + +# Add a main block to run the test if the file is executed directly (optional) +if __name__ == "__main__": + # Basic check if dependencies are met + if not MODEL_HANDLER_AVAILABLE: + print("Skipping test execution: GRUModelHandler or TensorFlow/Keras not available.") + else: + print("Running causal mask test...") + # Simple manual run without pytest fixtures/discovery + try: + test_causal_mask_lambda_layer() + except Exception as e: + print(f"Test failed with error: {e}") + \ No newline at end of file diff --git a/gru_sac_predictor/tests/test_integration.py b/gru_sac_predictor/tests/test_integration.py new file mode 100644 index 00000000..3f95759f --- /dev/null +++ b/gru_sac_predictor/tests/test_integration.py @@ -0,0 +1,117 @@ +""" +Integration tests for cross-module interactions. +""" +import pytest +import os +import numpy as np +import tempfile +import json + +# Try to import the module; skip tests if not found +try: + from gru_sac_predictor.src import sac_agent + import tensorflow as tf # Needed for agent init/load +except ImportError: + sac_agent = None + tf = None + +@pytest.fixture +def sac_agent_for_integration(): + """Provides a basic SAC agent instance.""" + if sac_agent is None or tf is None: + pytest.skip("SAC Agent module or TF not found.") + # Use minimal params for saving/loading tests + agent = sac_agent.SACTradingAgent( + state_dim=5, action_dim=1, + buffer_capacity=100, min_buffer_size=10 + ) + # Build models + try: + agent.actor(tf.zeros((1, 5))) + agent.critic1([tf.zeros((1, 5)), tf.zeros((1, 1))]) + agent.critic2([tf.zeros((1, 5)), tf.zeros((1, 1))]) + agent.update_target_networks(tau=1.0) + except Exception as e: + pytest.fail(f"Failed to build agent models: {e}") + return agent + +@pytest.mark.skipif(sac_agent is None or tf is None, reason="SAC Agent module or TF not found") +def test_save_load_metadata(sac_agent_for_integration): + """Test if metadata is saved and loaded correctly.""" + agent = sac_agent_for_integration + with tempfile.TemporaryDirectory() as tmpdir: + save_path = os.path.join(tmpdir, "sac_test_save") + agent.save(save_path) + + # Check if metadata file exists + meta_path = os.path.join(save_path, 'agent_metadata.json') + assert os.path.exists(meta_path), "Metadata file was not saved." + + # Create a new agent and load + new_agent = sac_agent.SACTradingAgent(state_dim=5, action_dim=1) + loaded_meta = new_agent.load(save_path) + + assert isinstance(loaded_meta, dict), "Load method did not return a dict." + assert loaded_meta.get('state_dim') == 5, "Loaded state_dim incorrect." + assert loaded_meta.get('action_dim') == 1, "Loaded action_dim incorrect." + # Check alpha status (default is auto_tune=True) + assert loaded_meta.get('log_alpha_saved') == True, "log_alpha status incorrect." + +@pytest.mark.skipif(sac_agent is None or tf is None, reason="SAC Agent module or TF not found") +def test_replay_buffer_purge_on_change(sac_agent_for_integration): + """ + Simulate loading an agent where the edge_threshold has changed + and verify the buffer is cleared. + """ + agent_to_save = sac_agent_for_integration + original_edge_thr = 0.55 + agent_to_save.edge_threshold_config = original_edge_thr # Manually set for saving + + with tempfile.TemporaryDirectory() as tmpdir: + save_path = os.path.join(tmpdir, "sac_purge_test") + + # 1. Save agent with original threshold in metadata + agent_to_save.save(save_path) + meta_path = os.path.join(save_path, 'agent_metadata.json') + assert os.path.exists(meta_path) + with open(meta_path, 'r') as f: + saved_meta = json.load(f) + assert saved_meta.get('edge_threshold_config') == original_edge_thr + + # 2. Create a new agent instance to load into + new_agent = sac_agent.SACTradingAgent( + state_dim=5, action_dim=1, + buffer_capacity=100, min_buffer_size=10 + ) + # Build models for the new agent + try: + new_agent.actor(tf.zeros((1, 5))) + new_agent.critic1([tf.zeros((1, 5)), tf.zeros((1, 1))]) + new_agent.critic2([tf.zeros((1, 5)), tf.zeros((1, 1))]) + new_agent.update_target_networks(tau=1.0) + except Exception as e: + pytest.fail(f"Failed to build new agent models: {e}") + + # Add dummy data to the *new* agent's buffer *before* loading + for _ in range(20): + dummy_state = np.random.rand(5).astype(np.float32) + dummy_action = np.random.rand(1).astype(np.float32) + new_agent.buffer.add(dummy_state, dummy_action, 0.0, dummy_state, 0.0) + assert len(new_agent.buffer) == 20, "Buffer should have data before load." + + # 3. Simulate loading with a *different* current edge threshold config + current_config_edge_thr = 0.60 + assert abs(current_config_edge_thr - original_edge_thr) > 1e-6 + + loaded_meta = new_agent.load(save_path) + saved_edge_thr = loaded_meta.get('edge_threshold_config') + + # 4. Perform the check and clear if needed (simulating pipeline logic) + if saved_edge_thr is not None and abs(saved_edge_thr - current_config_edge_thr) > 1e-6: + print(f"\nEdge threshold mismatch detected (Saved={saved_edge_thr}, Current={current_config_edge_thr}). Clearing buffer.") + new_agent.clear_buffer() + else: + print(f"\nEdge threshold match or not saved. Buffer not cleared.") + + # 5. Assert buffer is now empty + assert len(new_agent.buffer) == 0, "Buffer was not cleared after edge threshold mismatch." \ No newline at end of file diff --git a/gru_sac_predictor/tests/test_labels.py b/gru_sac_predictor/tests/test_labels.py new file mode 100644 index 00000000..48456d9a --- /dev/null +++ b/gru_sac_predictor/tests/test_labels.py @@ -0,0 +1,201 @@ +""" +Tests for label generation and potential leakage. + +Ref: revisions.txt Step 1-A, 1.4 +""" +import pytest +import pandas as pd +import numpy as np +import sys, os + +# --- Add path for src imports --- # +# Assuming tests is one level down from the package root +script_dir = os.path.dirname(os.path.abspath(__file__)) +project_root = os.path.dirname(script_dir) # Go up one level +src_path = os.path.join(project_root, 'src') +if src_path not in sys.path: + sys.path.insert(0, src_path) +# --- End Add path --- # + +# Import the function to test +from trading_pipeline import _generate_direction_labels + +# --- Fixtures --- # +@pytest.fixture +def sample_close_data() -> pd.DataFrame: + """Creates a sample DataFrame with close prices and DatetimeIndex.""" + # Generate data with some variation + np.random.seed(42) + prices = 100 + np.cumsum(np.random.randn(200) * 0.5) + data = {'close': prices} + index = pd.date_range(start='2023-01-01', periods=len(data['close']), freq='min', tz='UTC') + df = pd.DataFrame(data, index=index) + return df + +@pytest.fixture +def sample_config() -> dict: + """Provides a basic config dictionary.""" + return { + 'gru': { + 'prediction_horizon': 5, + 'use_ternary': False, + 'flat_sigma_multiplier': 0.25 + }, + 'data': { + 'label_smoothing': 0.0 + } + } + +# --- Tests --- # + +def test_lookahead_bias(sample_close_data, sample_config): + """ + Test 1.4.a: Verify labels don't depend on information *beyond* the prediction horizon. + Strategy: Modify future close prices (beyond horizon) and check if labels change. + """ + df = sample_close_data + config = sample_config + horizon = config['gru']['prediction_horizon'] + + # Generate baseline labels (binary) + df_labeled_base, label_col_base = _generate_direction_labels(df.copy(), config) + + # Modify close prices far into the future (beyond the horizon needed for any label) + df_modified = df.copy() + future_index = len(df) - 1 # Index of the last point + modify_point = future_index - horizon - 5 # Index well beyond the last needed future price + if modify_point > 0: + df_modified.iloc[modify_point:, df_modified.columns.get_loc('close')] *= 1.5 # Modify future prices + + # Generate labels with modified future data + df_labeled_mod, label_col_mod = _generate_direction_labels(df_modified.copy(), config) + + # Align based on index (label function drops NaNs at the end) + common_index = df_labeled_base.index.intersection(df_labeled_mod.index) + labels_base_aligned = df_labeled_base.loc[common_index, label_col_base] + labels_mod_aligned = df_labeled_mod.loc[common_index, label_col_mod] + + # Assert: Labels should be identical, as modification was beyond the horizon + pd.testing.assert_series_equal(labels_base_aligned, labels_mod_aligned, check_names=False) + + # --- Repeat for Ternary --- # + config['gru']['use_ternary'] = True + df_labeled_base_t, label_col_base_t = _generate_direction_labels(df.copy(), config) + df_labeled_mod_t, label_col_mod_t = _generate_direction_labels(df_modified.copy(), config) + + common_index_t = df_labeled_base_t.index.intersection(df_labeled_mod_t.index) + labels_base_aligned_t = df_labeled_base_t.loc[common_index_t, label_col_base_t] + labels_mod_aligned_t = df_labeled_mod_t.loc[common_index_t, label_col_mod_t] + + # Assert: Ternary labels should also be identical + # Need careful comparison for list/array column + assert labels_base_aligned_t.equals(labels_mod_aligned_t) + +def test_binary_label_distribution(sample_close_data, sample_config): + """ + Test 1.4.b: Check binary label distribution has >= 5% in each class. + """ + df = sample_close_data + config = sample_config + config['gru']['use_ternary'] = False + config['data']['label_smoothing'] = 0.0 # Ensure hard binary for this test + + df_labeled, label_col = _generate_direction_labels(df.copy(), config) + + assert not df_labeled.empty, "Label generation resulted in empty DataFrame" + assert label_col in df_labeled.columns, f"Label column '{label_col}' not found" + + labels = df_labeled[label_col] + counts = labels.value_counts(normalize=True) + + assert len(counts) == 2, f"Expected 2 binary classes, found {len(counts)}" + assert counts.min() >= 0.05, f"Minimum binary class proportion ({counts.min():.2%}) is less than 5%" + print(f"\nBinary Dist: {counts.to_dict()}") # Print for info + +def test_soft_binary_label_distribution(sample_close_data, sample_config): + """ + Test 1.4.b: Check soft binary label distribution has >= 5% in each effective class. + """ + df = sample_close_data + config = sample_config + config['gru']['use_ternary'] = False + config['data']['label_smoothing'] = 0.2 # Example smoothing + smoothing = config['data']['label_smoothing'] + low_label = smoothing / 2.0 + high_label = 1.0 - smoothing / 2.0 + + df_labeled, label_col = _generate_direction_labels(df.copy(), config) + + assert not df_labeled.empty, "Label generation resulted in empty DataFrame" + assert label_col in df_labeled.columns, f"Label column '{label_col}' not found" + + labels = df_labeled[label_col] + counts = labels.value_counts(normalize=True) + + assert len(counts) == 2, f"Expected 2 soft binary classes, found {len(counts)}" + assert counts.min() >= 0.05, f"Minimum soft binary class proportion ({counts.min():.2%}) is less than 5%" + assert low_label in counts.index, f"Low label {low_label} not found in counts" + assert high_label in counts.index, f"High label {high_label} not found in counts" + print(f"\nSoft Binary Dist: {counts.to_dict()}") + +def test_ternary_label_distribution(sample_close_data, sample_config): + """ + Test 1.4.b: Check ternary label distribution (flat=[0.15, 0.45], others >= 0.10). + Uses default k=0.25. + """ + df = sample_close_data + config = sample_config + config['gru']['use_ternary'] = True + k = config['gru']['flat_sigma_multiplier'] # Should be 0.25 from fixture + + df_labeled, label_col = _generate_direction_labels(df.copy(), config) + + assert not df_labeled.empty, "Label generation resulted in empty DataFrame" + assert label_col in df_labeled.columns, f"Label column '{label_col}' not found" + + # Decode one-hot labels back to ordinal for distribution check + labels_one_hot = np.stack(df_labeled[label_col].values) + assert labels_one_hot.shape[1] == 3, "Ternary labels should have 3 columns" + ordinal_labels = np.argmax(labels_one_hot, axis=1) + + counts = np.bincount(ordinal_labels, minlength=3) + total = len(ordinal_labels) + dist_pct = counts / total * 100 + + print(f"\nTernary Dist (k={k}): Down={dist_pct[0]:.1f}%, Flat={dist_pct[1]:.1f}%, Up={dist_pct[2]:.1f}%") + + # Check constraints based on design doc / implementation + assert 15.0 <= dist_pct[1] <= 45.0, f"Flat class ({dist_pct[1]:.1f}%) out of expected range [15%, 45%] for k={k}" + assert dist_pct[0] >= 10.0, f"Down class ({dist_pct[0]:.1f}%) is less than 10% (check impl threshold)" + assert dist_pct[2] >= 10.0, f"Up class ({dist_pct[2]:.1f}%) is less than 10% (check impl threshold)" + +# --- Old Tests (Keep or Remove?) --- +# The original tests checked 'future_close', which is related but not the final label. +# We can keep test_future_close_shift as it verifies the shift logic used internally. +# The NaN test is less relevant now as the main function handles NaN dropping. + +def test_future_close_shift(sample_close_data): + """Verify that 'future_close' is correctly shifted and has NaNs at the end.""" + df = sample_close_data + horizon = 5 # Example horizon + + # Apply the logic directly for testing the shift itself + df['future_close'] = df['close'].shift(-horizon) + df['fwd_log_ret'] = np.log(df['future_close'] / df['close']) + + # Assertions + # 1. Check for correct shift in fwd_log_ret + # The first valid fwd_log_ret depends on close[0] and close[horizon] + assert pd.notna(df['fwd_log_ret'].iloc[0]) + # The last valid fwd_log_ret depends on close[end-horizon-1] and close[end-1] + assert pd.notna(df['fwd_log_ret'].iloc[len(df) - horizon - 1]) + + # 2. Check for NaNs at the end due to shift + assert pd.isna(df['fwd_log_ret'].iloc[-horizon:]).all() + assert pd.notna(df['fwd_log_ret'].iloc[:-horizon]).all() + +# def test_no_nan_in_future_close_output(): +# """Unit test to ensure no unexpected NaNs in the output of label creation (specific to the function).""" +# # Setup similar to above, potentially call the actual DataLoader/label function +# # Assert pd.notna(output_df['future_close'][:-horizon]).all() +# pytest.skip("Test covered by NaN dropping in _generate_direction_labels and its tests.") \ No newline at end of file diff --git a/gru_sac_predictor/tests/test_leakage.py b/gru_sac_predictor/tests/test_leakage.py new file mode 100644 index 00000000..f96d3860 --- /dev/null +++ b/gru_sac_predictor/tests/test_leakage.py @@ -0,0 +1,133 @@ +""" +Tests for data leakage (Sec 6 of revisions.txt). +""" +import pytest +import pandas as pd +import numpy as np + +# Assume test data is loaded via fixtures later +@pytest.fixture(scope="module") +def sample_data_for_leakage(): + """ + Provides sample features and target for leakage tests. + Includes correctly shifted features, a feature with direct leakage, + and a rolling feature calculated correctly vs incorrectly. + """ + np.random.seed(43) + dates = pd.date_range(start='2023-01-01', periods=500, freq='T') + n = len(dates) + df = pd.DataFrame(index=dates) + df['noise'] = np.random.randn(n) + df['close'] = 100 + np.cumsum(df['noise'] * 0.1) + df['y_ret'] = np.log(df['close'].shift(-1) / df['close']) + + # --- Features --- + # OK: Based on past noise + df['feature_ok_past_noise'] = df['noise'].shift(1) + # OK: Rolling mean on correctly shifted past data + df['feature_ok_rolling_shifted'] = df['noise'].shift(1).rolling(10).mean() + # LEAKY: Uses future return directly + df['feature_leaky_direct'] = df['y_ret'] + # LEAKY: Rolling mean calculated *before* shifting target relationship + df['feature_leaky_rolling_unaligned'] = df['close'].rolling(5).mean() + + # Drop rows with NaNs from shifts/rolls AND the last row where y_ret is NaN + df.dropna(inplace=True) + + # Define features and target for the test + y_target = df['y_ret'] + features_df = df.drop(columns=['close', 'y_ret', 'noise']) # Exclude raw data used for generation + + return features_df, y_target + +@pytest.mark.parametrize("leakage_threshold", [0.02]) +def test_feature_leakage_correlation(sample_data_for_leakage, leakage_threshold): + """ + Verify that no feature has correlation > threshold with the correctly shifted target. + """ + features_df, y_target = sample_data_for_leakage + + max_abs_corr = 0.0 + leaky_col = "None" + all_corrs = {} + + print(f"\nTesting {features_df.shape[1]} features for leakage (threshold={leakage_threshold})...") + for col in features_df.columns: + if pd.api.types.is_numeric_dtype(features_df[col]): + # Handle potential NaNs introduced by feature engineering (though fixture avoids it) + temp_df = pd.concat([features_df[col], y_target], axis=1).dropna() + if len(temp_df) < 0.5 * len(features_df): + print(f" Skipping {col} due to excessive NaNs after merging with target.") + continue + + correlation = temp_df[col].corr(temp_df['y_ret']) + all_corrs[col] = correlation + # print(f" Corr({col}, y_ret): {correlation:.4f}") + if abs(correlation) > max_abs_corr: + max_abs_corr = abs(correlation) + leaky_col = col + else: + print(f" Skipping non-numeric column: {col}") + + print(f"Correlations found: { {k: round(v, 4) for k, v in all_corrs.items()} }") + print(f"Maximum absolute correlation found: {max_abs_corr:.4f} (feature: {leaky_col})") + + assert max_abs_corr < leakage_threshold, \ + f"Feature '{leaky_col}' has correlation {max_abs_corr:.4f} > threshold {leakage_threshold}, suggesting leakage." + +@pytest.mark.skipif(features is None, reason="Module gru_sac_predictor.src.features not found") +def test_ta_feature_leakage(sample_data_for_leakage, leakage_threshold=0.02): + """ + Specifically test TA features (EMA, MACD etc.) for leakage. + Ensures they were calculated on shifted data. + """ + features_df, y_target = sample_data_for_leakage + # Add TA features using the helper (simulating pipeline) + # We need OHLC in the input df for add_ta_features + # Recreate a df with shifted OHLC + other features for TA calc + np.random.seed(43) # Ensure consistent data with primary fixture + dates = pd.date_range(start='2023-01-01', periods=500, freq='T') + n = len(dates) + df_ohlc = pd.DataFrame(index=dates) + df_ohlc['close'] = 100 + np.cumsum(np.random.randn(n) * 0.1) + df_ohlc['open'] = df_ohlc['close'].shift(1) * (1 + np.random.randn(n) * 0.001) + df_ohlc['high'] = df_ohlc[['open','close']].max(axis=1) * (1 + np.random.rand(n) * 0.001) + df_ohlc['low'] = df_ohlc[['open','close']].min(axis=1) * (1 - np.random.rand(n) * 0.001) + df_ohlc['volume'] = np.random.rand(n) * 1000 + + # IMPORTANT: Shift before calculating TA features + df_shifted_ohlc = df_ohlc.shift(1) + df_ta = features.add_ta_features(df_shifted_ohlc) + + # Align with the target (requires original non-shifted index) + df_ta = df_ta.loc[y_target.index] + + ta_features_to_test = [col for col in features.minimal_whitelist if col in df_ta.columns and col not in ["return_1m", "return_15m", "return_60m", "hour_sin", "hour_cos"]] + max_abs_corr = 0.0 + leaky_col = "None" + all_corrs = {} + + print(f"\nTesting {len(ta_features_to_test)} TA features for leakage (threshold={leakage_threshold})...") + print(f" Features: {ta_features_to_test}") + + for col in ta_features_to_test: + if pd.api.types.is_numeric_dtype(df_ta[col]): + temp_df = pd.concat([df_ta[col], y_target], axis=1).dropna() + if len(temp_df) < 0.5 * len(y_target): + print(f" Skipping {col} due to excessive NaNs after merging.") + continue + correlation = temp_df[col].corr(temp_df['y_ret']) + all_corrs[col] = correlation + if abs(correlation) > max_abs_corr: + max_abs_corr = abs(correlation) + leaky_col = col + else: + print(f" Skipping non-numeric TA column: {col}") + + print(f"TA Feature Correlations: { {k: round(v, 4) for k, v in all_corrs.items()} }") + print(f"Maximum absolute TA correlation found: {max_abs_corr:.4f} (feature: {leaky_col})") + + assert max_abs_corr < leakage_threshold, \ + f"TA Feature '{leaky_col}' has correlation {max_abs_corr:.4f} > threshold {leakage_threshold}, suggesting leakage from TA calculation." + +# test_label_timing is usually covered by the correlation test, so removed for brevity. \ No newline at end of file diff --git a/gru_sac_predictor/tests/test_metrics.py b/gru_sac_predictor/tests/test_metrics.py new file mode 100644 index 00000000..5e17e182 --- /dev/null +++ b/gru_sac_predictor/tests/test_metrics.py @@ -0,0 +1,136 @@ +""" +Tests for custom metric functions. + +Ref: revisions.txt Task 6.5 +""" + +import pytest +import numpy as np +import pandas as pd +import sys, os + +# --- Add path for src imports --- # +script_dir = os.path.dirname(os.path.abspath(__file__)) +project_root = os.path.dirname(script_dir) +src_path = os.path.join(project_root, 'src') +if src_path not in sys.path: + sys.path.insert(0, src_path) +# --- End Add path --- # + +from metrics import edge_filtered_accuracy, calculate_sharpe_ratio + +# --- Tests for edge_filtered_accuracy --- # + +def test_edge_filtered_accuracy_basic(): + """Test basic functionality with hard labels and clear edge.""" + y_true = np.array([1, 0, 1, 0, 1, 1, 0, 0]) + p_cal = np.array([0.9, 0.1, 0.8, 0.2, 0.7, 0.6, 0.3, 0.4]) # Edge > 0.1 for all + thr = 0.1 + + accuracy, n_filtered = edge_filtered_accuracy(y_true, p_cal, thr=thr) + + assert n_filtered == 8 + # Predictions: 1, 0, 1, 0, 1, 1, 0, 0. All correct. + assert accuracy == pytest.approx(1.0) + +def test_edge_filtered_accuracy_thresholding(): + """Test that the threshold correctly filters samples.""" + y_true = np.array([1, 0, 1, 0, 1, 1, 0, 0]) + p_cal = np.array([0.9, 0.1, 0.8, 0.2, 0.51, 0.49, 0.55, 0.45]) # Edge: 0.8, 0.8, 0.6, 0.6, 0.02, 0.02, 0.1, 0.1 + + # Test with thr=0.15 (should exclude last 4 samples) + thr1 = 0.15 + accuracy1, n_filtered1 = edge_filtered_accuracy(y_true, p_cal, thr=thr1) + assert n_filtered1 == 4 + # Predictions on first 4: 1, 0, 1, 0. All correct. + assert accuracy1 == pytest.approx(1.0) + + # Test with thr=0.05 (should include all but middle 2) + thr2 = 0.05 + accuracy2, n_filtered2 = edge_filtered_accuracy(y_true, p_cal, thr=thr2) + assert n_filtered2 == 6 + # Included: 1,0,1,0, 1, 0. Correct: 1,0,1,0, ?, ?. Preds: 1,0,1,0, 1, 0. 6/6 correct. + assert accuracy2 == pytest.approx(1.0) + +def test_edge_filtered_accuracy_soft_labels(): + """Test with soft labels.""" + y_true_soft = np.array([0.9, 0.1, 0.8, 0.2, 0.7, 0.6]) # Soft labels + p_cal = np.array([0.8, 0.3, 0.9, 0.1, 0.6, 0.7]) # All edge > 0.1 + thr = 0.1 + + accuracy, n_filtered = edge_filtered_accuracy(y_true_soft, p_cal, thr=thr) + + assert n_filtered == 6 + # y_true_hard: 1, 0, 1, 0, 1, 1 + # y_pred : 1, 0, 1, 0, 1, 1. All correct. + assert accuracy == pytest.approx(1.0) + +def test_edge_filtered_accuracy_no_samples(): + """Test case where no samples meet the edge threshold.""" + y_true = np.array([1, 0, 1, 0]) + p_cal = np.array([0.51, 0.49, 0.52, 0.48]) # All edge < 0.1 + thr = 0.1 + + accuracy, n_filtered = edge_filtered_accuracy(y_true, p_cal, thr=thr) + assert n_filtered == 0 + assert np.isnan(accuracy) + +def test_edge_filtered_accuracy_empty_input(): + """Test with empty input arrays.""" + y_true = np.array([]) + p_cal = np.array([]) + thr = 0.1 + + accuracy, n_filtered = edge_filtered_accuracy(y_true, p_cal, thr=thr) + assert n_filtered == 0 + assert np.isnan(accuracy) + +# --- Tests for calculate_sharpe_ratio --- # + +def test_calculate_sharpe_ratio_basic(): + """Test basic Sharpe calculation.""" + returns = pd.Series([0.01, -0.005, 0.02, 0.005, -0.01]) + # mean = 0.004, std = 0.01166, Sharpe_period = 0.343 + # Annualized (252) = 0.343 * sqrt(252) = 5.44 + expected_sharpe = 5.44441 + sharpe = calculate_sharpe_ratio(returns, benchmark_return=0.0, annualization_factor=252) + assert sharpe == pytest.approx(expected_sharpe, abs=1e-4) + +def test_calculate_sharpe_ratio_different_annualization(): + """Test Sharpe with different annualization factor.""" + returns = pd.Series([0.01, -0.005, 0.02, 0.005, -0.01]) + # Annualized (52) = 0.343 * sqrt(52) = 2.47 + expected_sharpe = 2.4738 + sharpe = calculate_sharpe_ratio(returns, benchmark_return=0.0, annualization_factor=52) + assert sharpe == pytest.approx(expected_sharpe, abs=1e-4) + +def test_calculate_sharpe_ratio_with_benchmark(): + """Test Sharpe with a non-zero benchmark return.""" + returns = pd.Series([0.01, -0.005, 0.02, 0.005, -0.01]) # mean=0.004 + benchmark = 0.001 # Per period + # excess mean = 0.003, std = 0.01166, Sharpe_period = 0.257 + # Annualized (252) = 0.257 * sqrt(252) = 4.08 + expected_sharpe = 4.0833 + sharpe = calculate_sharpe_ratio(returns, benchmark_return=benchmark, annualization_factor=252) + assert sharpe == pytest.approx(expected_sharpe, abs=1e-4) + +def test_calculate_sharpe_ratio_zero_std(): + """Test Sharpe when returns have zero standard deviation.""" + returns_positive = pd.Series([0.01, 0.01, 0.01]) + returns_negative = pd.Series([-0.01, -0.01, -0.01]) + returns_zero = pd.Series([0.0, 0.0, 0.0]) + + assert calculate_sharpe_ratio(returns_positive) == 0.0 # Positive mean, zero std -> 0? + # assert calculate_sharpe_ratio(returns_negative) == -np.inf # Negative mean, zero std -> -inf? + assert calculate_sharpe_ratio(returns_zero) == 0.0 + + # Let's refine zero std handling based on function's logic + # Function returns 0 if mean>0, -inf if mean<0, 0 if mean=0 + assert calculate_sharpe_ratio(returns_positive) == 0.0 + assert calculate_sharpe_ratio(returns_negative) == -np.inf + assert calculate_sharpe_ratio(returns_zero) == 0.0 + +def test_calculate_sharpe_ratio_empty_or_nan(): + """Test Sharpe with empty or all-NaN input.""" + assert np.isnan(calculate_sharpe_ratio(pd.Series([], dtype=float))) + assert np.isnan(calculate_sharpe_ratio(pd.Series([np.nan, np.nan], dtype=float))) \ No newline at end of file diff --git a/gru_sac_predictor/tests/test_model_shapes.py b/gru_sac_predictor/tests/test_model_shapes.py new file mode 100644 index 00000000..6616a2ca --- /dev/null +++ b/gru_sac_predictor/tests/test_model_shapes.py @@ -0,0 +1,139 @@ +""" +Tests for GRU model input/output shapes. + +Ref: revisions.txt Task 3.6 +""" +import pytest +import numpy as np +import sys, os + +# --- Add path for src imports --- # +script_dir = os.path.dirname(os.path.abspath(__file__)) +project_root = os.path.dirname(script_dir) +src_path = os.path.join(project_root, 'src') +if src_path not in sys.path: + sys.path.insert(0, src_path) +# --- End Add path --- # + +# Import the v3 model builder +from model_gru_v3 import build_gru_model_v3 +# TODO: Import v2 model builder if needed for comparison tests +# from model_gru import build_gru_model + +# --- Constants for Testing --- # +LOOKBACK = 60 +N_FEATURES = 25 +BATCH_SIZE = 4 + +# --- Tests --- # + +def test_gru_v3_output_shapes(): + """Verify the output shapes of the GRU v3 model heads.""" + print(f"\nBuilding GRU v3 model for shape test...") + # Build the v3 model with default parameters + model = build_gru_model_v3(lookback=LOOKBACK, n_features=N_FEATURES) + assert model is not None, "Failed to build GRU v3 model" + + # Check number of outputs + assert len(model.outputs) == 2, f"Expected 2 outputs, got {len(model.outputs)}" + + # Check output names and shapes + # Output order in the model definition was [mu, dir3] + mu_output_shape = model.outputs[0].shape.as_list() + dir3_output_shape = model.outputs[1].shape.as_list() + + # Assert shapes (ignoring batch size None) + # mu head should be (None, 1) + assert mu_output_shape == [None, 1], f"Expected mu shape [None, 1], got {mu_output_shape}" + # dir3 head should be (None, 3) + assert dir3_output_shape == [None, 3], f"Expected dir3 shape [None, 3], got {dir3_output_shape}" + + print("GRU v3 output shapes test passed.") + +def test_gru_v3_prediction_shapes(): + """Verify the prediction shapes match the output shapes for a sample batch.""" + model = build_gru_model_v3(lookback=LOOKBACK, n_features=N_FEATURES) + assert model is not None, "Failed to build GRU v3 model" + + # Create dummy input data + dummy_input = np.random.rand(BATCH_SIZE, LOOKBACK, N_FEATURES) + + # Generate predictions + predictions = model.predict(dummy_input) + + # Check prediction structure and shapes + assert isinstance(predictions, list), "Predictions should be a list for multi-output model" + assert len(predictions) == 2, f"Expected 2 prediction arrays, got {len(predictions)}" + + # Predictions order should match model.outputs order [mu, dir3] + mu_preds = predictions[0] + dir3_preds = predictions[1] + + # Assert prediction shapes match expected batch size + assert mu_preds.shape == (BATCH_SIZE, 1), f"Expected mu prediction shape ({BATCH_SIZE}, 1), got {mu_preds.shape}" + assert dir3_preds.shape == (BATCH_SIZE, 3), f"Expected dir3 prediction shape ({BATCH_SIZE}, 3), got {dir3_preds.shape}" + + print("GRU v3 prediction shapes test passed.") + +# TODO: Add tests for GRU v2 model shapes if it's still relevant. + +def test_logits_view_shapes(): + """Test that softmax applied to predict_logits output matches predict output.""" + print(f"\nBuilding GRU v3 model for logits view test...") + model = build_gru_model_v3(lookback=LOOKBACK, n_features=N_FEATURES) + assert model is not None, "Failed to build GRU v3 model" + + # --- Requires GRUModelHandler to run predict_logits --- # + # We need to instantiate the handler to test its methods. + # Mock config and directories needed for handler init. + mock_config = { + 'control': {'use_v3': True}, + 'gru_v3': {} # Use defaults for building + } + mock_run_id = "test_logits_run" + mock_models_dir = "./mock_models/test_logits_run" + os.makedirs(mock_models_dir, exist_ok=True) # Create mock dir + + # Import handler locally for test setup + from gru_model_handler import GRUModelHandler + handler = GRUModelHandler(run_id=mock_run_id, models_dir=mock_models_dir, config=mock_config) + handler.model = model # Assign the already built model to the handler + handler.model_version_used = 'v3' # Set version manually + # --- End Handler Setup --- # + + # Create dummy input data + dummy_input = np.random.rand(BATCH_SIZE, LOOKBACK, N_FEATURES).astype(np.float32) + + # Generate predictions using both methods + logits = handler.predict_logits(dummy_input) + predictions = handler.predict(dummy_input) + + assert logits is not None, "predict_logits returned None" + assert predictions is not None, "predict returned None" + assert isinstance(predictions, list) and len(predictions) == 2, "predict output structure incorrect" + + probs_from_predict = predictions[1] # dir3 is the second output + + # Apply softmax to logits + # Use tf.nn.softmax for consistency with Keras backend + import tensorflow as tf + probs_from_logits = tf.nn.softmax(logits).numpy() + + # Assert shapes match first + assert probs_from_logits.shape == probs_from_predict.shape, \ + f"Shape mismatch: softmax(logits)={probs_from_logits.shape}, predict_probs={probs_from_predict.shape}" + + # Assert values are close + np.testing.assert_allclose( + probs_from_logits, + probs_from_predict, + rtol=1e-6, + atol=1e-6, # Use tighter tolerance for numerical precision check + err_msg="Softmax applied to logits does not match probability output from model.predict()" + ) + + print("Logits view test passed.") + # Clean up mock directory + import shutil + if os.path.exists("./mock_models"): + shutil.rmtree("./mock_models") \ No newline at end of file diff --git a/gru_sac_predictor/tests/test_sac_agent.py b/gru_sac_predictor/tests/test_sac_agent.py new file mode 100644 index 00000000..9ffd96d0 --- /dev/null +++ b/gru_sac_predictor/tests/test_sac_agent.py @@ -0,0 +1,110 @@ +""" +Tests for the SACTradingAgent class. + +Ref: revisions.txt Task 5.7 +""" +import pytest +import numpy as np +import tensorflow as tf +import sys, os + +# --- Add path for src imports --- # +script_dir = os.path.dirname(os.path.abspath(__file__)) +project_root = os.path.dirname(script_dir) +src_path = os.path.join(project_root, 'src') +if src_path not in sys.path: + sys.path.insert(0, src_path) +# --- End Add path --- # + +from sac_agent import SACTradingAgent + +# --- Constants --- # +STATE_DIM = 5 +ACTION_DIM = 1 +BUFFER_SIZE = 5000 +MIN_BUFFER = 1000 +TRAIN_STEPS = 1500 # Number of training steps for the test +BATCH_SIZE = 64 + +# --- Fixtures --- # + +@pytest.fixture +def sac_agent_fixture() -> SACTradingAgent: + """Provides a default SACTradingAgent instance for testing.""" + agent = SACTradingAgent( + state_dim=STATE_DIM, + action_dim=ACTION_DIM, + buffer_capacity=BUFFER_SIZE, + min_buffer_size=MIN_BUFFER, + alpha_auto_tune=True, # Enable auto-tuning for realistic test + target_entropy=-1.0 * ACTION_DIM # Default target entropy + ) + return agent + +def _populate_buffer(agent: SACTradingAgent, num_samples: int): + """Helper to add random transitions to the agent's buffer.""" + print(f"\nPopulating buffer with {num_samples} random samples...") + for _ in range(num_samples): + state = np.random.randn(STATE_DIM).astype(np.float32) + action = np.random.uniform(-1, 1, size=(ACTION_DIM,)).astype(np.float32) + reward = np.random.randn() + next_state = np.random.randn(STATE_DIM).astype(np.float32) + done = float(np.random.rand() < 0.05) # 5% chance of done + agent.buffer.add(state, action, reward, next_state, done) + print(f"Buffer populated. Size: {len(agent.buffer)}") + +# --- Tests --- # + +def test_sac_training_updates(sac_agent_fixture): + """ + Test 5.7: Run training steps and check for basic health: + a) Q-values are not NaN. + b) Action variance is reasonable (suggests exploration). + """ + agent = sac_agent_fixture + # Populate buffer sufficiently to start training + _populate_buffer(agent, MIN_BUFFER + BATCH_SIZE) + + print(f"\nRunning {TRAIN_STEPS} training steps...") + metrics_history = [] + for i in range(TRAIN_STEPS): + metrics = agent.train(batch_size=BATCH_SIZE) + if metrics: # Train only runs if buffer is full enough + metrics_history.append(metrics) + # Basic check within the loop to fail fast + if i % 100 == 0 and metrics: + assert not np.isnan(metrics['critic1_loss']), f"Critic1 loss is NaN at step {i}" + assert not np.isnan(metrics['critic2_loss']), f"Critic2 loss is NaN at step {i}" + assert not np.isnan(metrics['actor_loss']), f"Actor loss is NaN at step {i}" + if agent.alpha_auto_tune: + assert not np.isnan(metrics['alpha_loss']), f"Alpha loss is NaN at step {i}" + + assert len(metrics_history) > 0, "Training loop did not execute (buffer size issue?)" + print(f"Training steps completed. Last metrics: {metrics_history[-1]}") + + # a) Check final Q-values (indirectly via loss) + last_metrics = metrics_history[-1] + assert not np.isnan(last_metrics['critic1_loss']), "Final Critic1 loss is NaN" + assert not np.isnan(last_metrics['critic2_loss']), "Final Critic2 loss is NaN" + # We assume if losses are not NaN, Q-values involved are also not NaN + print("Check a) Passed: Q-value losses are not NaN.") + + # b) Check action variance after training + num_samples_for_variance = 500 + sampled_actions = [] + dummy_state = np.random.randn(STATE_DIM).astype(np.float32) + for _ in range(num_samples_for_variance): + # Sample non-deterministically to check stochastic policy variance + action = agent.get_action(dummy_state, deterministic=False) + sampled_actions.append(action) + + sampled_actions = np.array(sampled_actions) + action_variance = np.var(sampled_actions, axis=0) + print(f"Action variance after {TRAIN_STEPS} steps: {action_variance}") + + # Check if variance is above a threshold (e.g., 0.2 from revisions.txt) + # This threshold might need tuning based on action space scaling (-1 to 1) + min_variance_threshold = 0.2 + assert np.all(action_variance > min_variance_threshold), \ + f"Action variance ({action_variance}) is below threshold ({min_variance_threshold}). Exploration might be too low." + print(f"Check b) Passed: Action variance ({action_variance.round(3)}) > {min_variance_threshold}.") \ No newline at end of file diff --git a/gru_sac_predictor/tests/test_sac_sanity.py b/gru_sac_predictor/tests/test_sac_sanity.py new file mode 100644 index 00000000..8d44bf67 --- /dev/null +++ b/gru_sac_predictor/tests/test_sac_sanity.py @@ -0,0 +1,121 @@ +""" +Sanity checks for the SAC agent (Sec 6 of revisions.txt). +""" +import pytest +import numpy as np +import os + +# Try to import the agent; skip tests if not found +try: + from gru_sac_predictor.src import sac_agent + # Need TF for tensor conversion if testing agent directly + import tensorflow as tf +except ImportError: + sac_agent = None + tf = None + +# --- Fixtures --- +@pytest.fixture(scope="module") +def sac_agent_instance(): + """ + Provides a default SAC agent instance for testing. + Uses standard parameters suitable for basic checks. + """ + if sac_agent is None: + pytest.skip("SAC Agent module not found.") + # Use default params, state_dim=5 as per revisions + # Use fixed seeds for reproducibility in tests if needed inside agent + agent = sac_agent.SACTradingAgent( + state_dim=5, action_dim=1, + initial_lr=1e-4, # Use a common LR for test simplicity + buffer_capacity=1000, # Smaller buffer for testing + min_buffer_size=100, + target_entropy=-1.0 + ) + # Build the models eagerly + try: + agent.actor(tf.zeros((1, 5))) + agent.critic1([tf.zeros((1, 5)), tf.zeros((1, 1))]) + agent.critic2([tf.zeros((1, 5)), tf.zeros((1, 1))]) + # Copy weights to target networks + agent.update_target_networks(tau=1.0) + except Exception as e: + pytest.fail(f"Failed to build SAC agent models: {e}") + return agent + +@pytest.fixture(scope="module") +def sample_sac_inputs(): + """ + Generate sample states and corresponding directional signals. + Simulates states with varying edge and signal-to-noise. + """ + np.random.seed(44) + n_samples = 1500 + # Simulate GRU outputs and position + mu = np.random.randn(n_samples) * 0.0015 # Slightly higher variance + sigma = np.random.uniform(0.0005, 0.0025, n_samples) + # Simulate edge with clearer separation for testing signals + edge_base = np.random.choice([-0.15, -0.05, 0.0, 0.05, 0.15], n_samples, p=[0.2, 0.2, 0.2, 0.2, 0.2]) + edge = np.clip(edge_base + np.random.randn(n_samples) * 0.03, -1.0, 1.0) + z_score = np.abs(mu) / (sigma + 1e-9) + position = np.random.uniform(-1, 1, n_samples) + states = np.vstack([mu, sigma, edge, z_score, position]).T.astype(np.float32) + # Use a small positive/negative threshold for determining signal from edge + signals = np.where(edge > 0.02, 1, np.where(edge < -0.02, -1, 0)) + return states, signals + +# --- Tests --- +@pytest.mark.skipif(sac_agent is None or tf is None, reason="SAC Agent module or TensorFlow not found") +def test_sac_agent_default_min_buffer(sac_agent_instance): + """Verify the default min_buffer_size is at least 10000.""" + agent = sac_agent_instance + # Note: Fixture currently initializes with specific values, overriding default. + # Re-initialize with defaults for this test. + default_agent = sac_agent.SACTradingAgent(state_dim=5, action_dim=1) + min_buffer = default_agent.min_buffer_size + print(f"\nAgent default min_buffer_size: {min_buffer}") + assert min_buffer >= 10000, f"Default min_buffer_size ({min_buffer}) is less than recommended 10000." + +@pytest.mark.skipif(sac_agent is None or tf is None, reason="SAC Agent module or TensorFlow not found") +def test_sac_action_variance(sac_agent_instance, sample_sac_inputs): + """ + Verify that the mean absolute action taken when the signal is non-zero + is >= 0.05. + """ + agent = sac_agent_instance + states, signals = sample_sac_inputs + + actions = [] + for state in states: + # Use deterministic action for this sanity check + action = agent.get_action(state, deterministic=True) + actions.append(action[0]) # get_action returns list/array + actions = np.array(actions) + + # Filter for non-zero signals based on the *simulated* edge + non_zero_signal_idx = signals != 0 + if not np.any(non_zero_signal_idx): + pytest.fail("No non-zero signals generated in fixture for SAC variance test.") + + actions_on_signal = actions[non_zero_signal_idx] + + if len(actions_on_signal) == 0: + # This case should ideally not happen if the above check passed + pytest.fail("Filtered actions array is empty despite non-zero signals.") + + mean_abs_action = np.mean(np.abs(actions_on_signal)) + + print(f"\nSAC Sanity Test: Mean Absolute Action (on signal != 0): {mean_abs_action:.4f}") + + # Check if the agent is outputting actions with sufficient magnitude + assert mean_abs_action >= 0.05, \ + f"Mean absolute action ({mean_abs_action:.4f}) is below threshold (0.05). Agent might be too timid or stuck near zero." + +@pytest.mark.skip(reason="Requires full backtest results which are not available in this unit test setup.") +def test_sac_reward_correlation(): + """ + Optional: Check if actions taken correlate positively with subsequent rewards. + NOTE: This test requires results from a full backtest run (actions vs rewards) + and cannot be reliably simulated or executed in this unit test. + """ + pass # Cannot implement without actual backtest results \ No newline at end of file diff --git a/gru_sac_predictor/tests/test_sequence_creation.py b/gru_sac_predictor/tests/test_sequence_creation.py new file mode 100644 index 00000000..01539535 --- /dev/null +++ b/gru_sac_predictor/tests/test_sequence_creation.py @@ -0,0 +1,186 @@ +import pytest +import pandas as pd +import numpy as np +from omegaconf import OmegaConf +from unittest.mock import MagicMock +import os +import tempfile +import json + +# Adjust the import path based on your project structure +from gru_sac_predictor.src.pipeline_stages.sequence_creation import create_sequences_fold +from gru_sac_predictor.src.io_manager import IOManager # Adjust path if needed + +# --- Test Fixtures --- + +@pytest.fixture +def sample_data_with_imputed(): + """Creates sample X and y dataframes with a 'bar_imputed' column.""" + dates = pd.to_datetime(pd.date_range('2023-01-01', periods=20, freq='T')) + lookback = 5 + n_features_orig = 3 + n_samples = len(dates) + + # Features (including bar_imputed) + X_data = pd.DataFrame( + np.random.randn(n_samples, n_features_orig), + index=dates, + columns=[f'feat_{i}' for i in range(n_features_orig)] + ) + # Add bar_imputed column - mark some bars as imputed + imputed_flags = np.zeros(n_samples, dtype=bool) + imputed_flags[2] = True # Imputed within first potential sequence + imputed_flags[8] = True # Imputed within a later potential sequence + imputed_flags[15] = True # Imputed near the end + X_data['bar_imputed'] = imputed_flags + + # Targets (mu and dir3) + y_data = pd.DataFrame({ + 'mu': np.random.randn(n_samples), + 'dir3': [list(row) for row in np.eye(3)[np.random.randint(0, 3, n_samples)]] # Example one-hot + }, index=dates) + + return X_data, y_data + +@pytest.fixture +def base_config(): + """Creates a base OmegaConf config for testing sequence creation.""" + conf = OmegaConf.create({ + 'gru': { + 'lookback': 5, + 'use_ternary': True, # Matches sample_data_with_imputed + 'drop_imputed_sequences': True # Default to True for testing dropping + }, + # Add other necessary sections if needed + }) + return conf + +@pytest.fixture +def mock_io_manager(): + """Creates a mock IOManager for testing artefact saving.""" + with tempfile.TemporaryDirectory() as tmpdir: + mock_io = MagicMock(spec=IOManager) + mock_io.results_dir = tmpdir + saved_jsons = {} + def mock_save_json(data, filename, **kwargs): + filepath = os.path.join(tmpdir, filename) + saved_jsons[filename] = data + with open(filepath, 'w') as f: + json.dump(data, f, **kwargs) + mock_io.save_json.side_effect = mock_save_json + mock_io.get_artifact_path.side_effect = lambda filename: os.path.join(tmpdir, filename) + mock_io._saved_jsons = saved_jsons + yield mock_io + +# --- Test Functions --- + +def test_sequence_creation_shapes(sample_data_with_imputed, base_config, mock_io_manager): + X_data, y_data = sample_data_with_imputed + lookback = base_config.gru.lookback + n_features = X_data.shape[1] + n_samples = len(X_data) + expected_n_seq = n_samples - lookback + + # Test without dropping imputed + cfg_no_drop = base_config.copy() + cfg_no_drop.gru.drop_imputed_sequences = False + + X_seq, y_seq_dict, indices, dropped_count = create_sequences_fold( + X_data=X_data, y_data=y_data, target_names=['mu', 'dir3'], + lookback=lookback, name="TestSplit", config=cfg_no_drop, io=mock_io_manager + ) + + assert dropped_count == 0 + assert X_seq is not None + assert y_seq_dict is not None + assert indices is not None + assert X_seq.shape == (expected_n_seq, lookback, n_features) + assert 'mu' in y_seq_dict and y_seq_dict['mu'].shape == (expected_n_seq,) + assert 'dir3' in y_seq_dict and y_seq_dict['dir3'].shape == (expected_n_seq, 3) + assert len(indices) == expected_n_seq + # Check first target index corresponds to lookback-th original index + assert indices[0] == X_data.index[lookback] + # Check last target index corresponds to last original index + assert indices[-1] == X_data.index[-1] + +def test_sequence_dropping_imputed(sample_data_with_imputed, base_config, mock_io_manager): + X_data, y_data = sample_data_with_imputed + lookback = base_config.gru.lookback + n_samples = len(X_data) + expected_n_seq_orig = n_samples - lookback + + # Config with dropping enabled (default in fixture) + cfg_drop = base_config + + X_seq, y_seq_dict, indices, dropped_count = create_sequences_fold( + X_data=X_data.copy(), y_data=y_data.copy(), target_names=['mu', 'dir3'], + lookback=lookback, name="TestDrop", config=cfg_drop, io=mock_io_manager + ) + + assert X_seq is not None + assert y_seq_dict is not None + assert indices is not None + + # Determine which original sequences should have been dropped + # A sequence starting at index i uses data from [i, i+lookback-1] + # The target corresponds to index i+lookback + # We need to check the imputed flag in the range [i, i+lookback-1] for each potential sequence target index i+lookback + + # Original target indices range from index `lookback` to `n_samples - 1` + should_be_dropped_mask = np.zeros(expected_n_seq_orig, dtype=bool) + imputed_flags_np = X_data['bar_imputed'].values + for seq_idx in range(expected_n_seq_orig): + # The features for this sequence are from original indices [seq_idx, seq_idx + lookback - 1] + feature_indices_range = slice(seq_idx, seq_idx + lookback) + if np.any(imputed_flags_np[feature_indices_range]): + should_be_dropped_mask[seq_idx] = True + + expected_dropped_count = np.sum(should_be_dropped_mask) + expected_remaining_count = expected_n_seq_orig - expected_dropped_count + + assert dropped_count == expected_dropped_count + assert X_seq.shape[0] == expected_remaining_count + assert y_seq_dict['mu'].shape[0] == expected_remaining_count + assert y_seq_dict['dir3'].shape[0] == expected_remaining_count + assert len(indices) == expected_remaining_count + + # Check that the remaining indices are correct (weren't marked for dropping) + original_indices = X_data.index[lookback:] + expected_remaining_indices = original_indices[~should_be_dropped_mask] + pd.testing.assert_index_equal(indices, expected_remaining_indices) + + # Check artifact saving + assert 'imputed_sequence_summary_testdrop.json' in mock_io_manager._saved_jsons + report_data = mock_io_manager._saved_jsons['imputed_sequence_summary_testdrop.json'] + assert report_data['total_sequences_generated'] == expected_n_seq_orig + assert report_data['sequences_dropped_imputed'] == expected_dropped_count + assert report_data['sequences_remaining'] == expected_remaining_count + +def test_sequence_creation_no_imputed_col(sample_data_with_imputed, base_config, mock_io_manager): + X_data, y_data = sample_data_with_imputed + X_data_no_imputed = X_data.drop(columns=['bar_imputed']) + lookback = base_config.gru.lookback + + with pytest.raises(SystemExit) as excinfo: + create_sequences_fold( + X_data=X_data_no_imputed, y_data=y_data, target_names=['mu', 'dir3'], + lookback=lookback, name="TestNoImputedCol", config=base_config, io=mock_io_manager + ) + assert "'bar_imputed' column missing" in str(excinfo.value) + +def test_sequence_creation_insufficient_data(sample_data_with_imputed, base_config, mock_io_manager): + X_data, y_data = sample_data_with_imputed + lookback = base_config.gru.lookback + # Create data shorter than lookback + X_short = X_data.iloc[:lookback-1] + y_short = y_data.iloc[:lookback-1] + + X_seq, y_seq_dict, indices, dropped_count = create_sequences_fold( + X_data=X_short, y_data=y_short, target_names=['mu', 'dir3'], + lookback=lookback, name="TestShort", config=base_config, io=mock_io_manager + ) + + assert X_seq is None + assert y_seq_dict is None + assert indices is None + assert dropped_count == 0 \ No newline at end of file diff --git a/gru_sac_predictor/tests/test_time_encoding.py b/gru_sac_predictor/tests/test_time_encoding.py new file mode 100644 index 00000000..728c3172 --- /dev/null +++ b/gru_sac_predictor/tests/test_time_encoding.py @@ -0,0 +1,94 @@ +""" +Tests for time encoding, specifically DST transitions. +""" +import pytest +import pandas as pd +import numpy as np +import pytz # For timezone handling + +@pytest.fixture(scope="module") +def generate_dst_timeseries(): + """ + Generate a minute-frequency timestamp series crossing DST transitions + for a specific timezone (e.g., US/Eastern). + """ + # Example: US/Eastern DST Start (e.g., March 10, 2024 2:00 AM -> 3:00 AM) + # Example: US/Eastern DST End (e.g., Nov 3, 2024 2:00 AM -> 1:00 AM) + tz = pytz.timezone('US/Eastern') + + # Create timestamps around DST start + dst_start_range = pd.date_range( + start='2024-03-10 01:00:00', end='2024-03-10 04:00:00', freq='T', tz=tz + ) + # Create timestamps around DST end + dst_end_range = pd.date_range( + start='2024-11-03 00:00:00', end='2024-11-03 03:00:00', freq='T', tz=tz + ) + + # Combine and ensure uniqueness/order (though disjoint here) + timestamps = dst_start_range.union(dst_end_range) + df = pd.DataFrame(index=timestamps) + df.index.name = 'timestamp' + return df + +def calculate_cyclical_features(df): + """Helper to calculate sin/cos features from a datetime index.""" + if not isinstance(df.index, pd.DatetimeIndex): + raise TypeError("Input DataFrame must have a DatetimeIndex.") + + # Ensure timezone is present (fixture provides it) + if df.index.tz is None: + print("Warning: Index timezone is None, assuming UTC for calculation.") + timestamp_source = df.index.tz_localize('utc') + else: + timestamp_source = df.index + + # Use UTC hour for consistent calculation if timezone handling upstream is complex + # Or use localized hour if pipeline guarantees consistent local TZ + # Here, let's use the localized hour provided by the fixture + hour_of_day = timestamp_source.hour + # minute_of_day = timestamp_source.hour * 60 + timestamp_source.minute # Alternative + + df['hour_sin'] = np.sin(2 * np.pi * hour_of_day / 24) + df['hour_cos'] = np.cos(2 * np.pi * hour_of_day / 24) + return df + + +def test_cyclical_features_continuity(generate_dst_timeseries): + """ + Check if hour_sin and hour_cos features are continuous (no large jumps) + across DST transitions, assuming calculation uses localized time. + If using UTC hour, continuity is guaranteed, but might not capture + local market patterns intended. + """ + df = generate_dst_timeseries + df = calculate_cyclical_features(df) + + # Check differences between consecutive values + sin_diff = df['hour_sin'].diff().abs() + cos_diff = df['hour_cos'].diff().abs() + + # Define a reasonable threshold for a jump (e.g., difference > value for 15 mins) + # Max change in sin(2*pi*h/24) over 1 minute is small. + # A jump of 1 hour means h changes by 1, argument changes by pi/12. + # Max diff sin(x+pi/12) - sin(x) is approx pi/12 ~ 0.26 + max_allowed_diff = 0.3 # Allow slightly more than 1 hour jump equivalent + + print(f"\nMax Sin Diff: {sin_diff.max():.4f}") + print(f"Max Cos Diff: {cos_diff.max():.4f}") + + assert sin_diff.max() < max_allowed_diff, \ + f"Large jump detected in hour_sin ({sin_diff.max():.4f}) around DST. Check time source/calculation." + assert cos_diff.max() < max_allowed_diff, \ + f"Large jump detected in hour_cos ({cos_diff.max():.4f}) around DST. Check time source/calculation." + + # Optional: Plot to visually inspect + # import matplotlib.pyplot as plt + # plt.figure() + # plt.plot(df.index, df['hour_sin'], '.-.', label='sin') + # plt.plot(df.index, df['hour_cos'], '.-.', label='cos') + # plt.title('Cyclical Features Across DST') + # plt.legend() + # plt.xticks(rotation=45) + # plt.tight_layout() + # plt.show() \ No newline at end of file diff --git a/gru_sac_predictor/tests/test_trading_env.py b/gru_sac_predictor/tests/test_trading_env.py new file mode 100644 index 00000000..8af07e6b --- /dev/null +++ b/gru_sac_predictor/tests/test_trading_env.py @@ -0,0 +1,166 @@ +import pytest +import numpy as np +from omegaconf import OmegaConf + +# Adjust import path based on structure +from gru_sac_predictor.src.trading_env import TradingEnv + +# --- Test Fixtures --- + +@pytest.fixture +def sample_env_data(): + """Provides sample data for initializing the TradingEnv.""" + n_steps = 10 + data = { + 'mu_predictions': np.random.randn(n_steps) * 0.001, + 'sigma_predictions': np.abs(np.random.randn(n_steps) * 0.002 + 0.005), + 'p_cal_predictions': np.random.rand(n_steps), + 'actual_returns': np.random.randn(n_steps) * 0.0015, + 'bar_imputed_flags': np.array([False, False, True, False, True, True, False, False, True, False], dtype=bool) + } + return data + +@pytest.fixture +def base_env_config(): + """Base configuration for the environment.""" + return OmegaConf.create({ + 'sac': { + 'imputed_handling': 'skip', # Default test mode + 'action_penalty': 0.05 + }, + 'environment': { + 'initial_capital': 10000.0, + 'transaction_cost': 0.0005, + 'reward_scale': 100.0, + 'action_penalty_lambda': 0.0 # Usually overridden by transaction_cost calc + } + }) + +@pytest.fixture +def trading_env_instance(sample_env_data, base_env_config): + """Creates a TradingEnv instance with default 'skip' mode.""" + return TradingEnv(**sample_env_data, config=base_env_config) + +# --- Test Functions --- + +def test_env_initialization(trading_env_instance, sample_env_data): + assert trading_env_instance.n_steps == len(sample_env_data['actual_returns']) + assert trading_env_instance.current_step == 0 + assert trading_env_instance.current_position == 0.0 + assert np.array_equal(trading_env_instance.bar_imputed, sample_env_data['bar_imputed_flags']) + +def test_env_reset(trading_env_instance): + # Take a few steps + trading_env_instance.step(0.5) + trading_env_instance.step(-0.2) + assert trading_env_instance.current_step > 0 + # Reset + initial_state = trading_env_instance.reset() + assert trading_env_instance.current_step == 0 + assert trading_env_instance.current_position == 0.0 + assert initial_state is not None + assert initial_state.shape == (trading_env_instance.state_dim,) + +def test_env_step_normal(trading_env_instance): + # Test a normal step (step 0 is not imputed) + initial_pos = trading_env_instance.current_position + action = 0.7 + next_state, reward, done, info = trading_env_instance.step(action) + + assert trading_env_instance.current_step == 1 + assert trading_env_instance.current_position == action # Position updates to action + assert not info['is_imputed_step_skipped'] + assert not done + assert next_state is not None + # Reward calculation is complex, just check type/sign if needed + assert isinstance(reward, float) + +def test_env_step_imputed_skip(trading_env_instance, sample_env_data): + # Step 2 is imputed in sample_env_data + trading_env_instance.step(0.5) # Step 0 + trading_env_instance.step(0.6) # Step 1 + assert trading_env_instance.current_step == 2 + initial_pos_before_imputed = trading_env_instance.current_position + + # Action for the imputed step (should be ignored by 'skip') + action_imputed = 0.9 + next_state, reward, done, info = trading_env_instance.step(action_imputed) + + # Should skip step 2 and now be at step 3 + assert trading_env_instance.current_step == 3 + # Position should NOT have changed from step 1 + assert trading_env_instance.current_position == initial_pos_before_imputed + assert reward == 0.0 # Skip gives 0 reward + assert not done + assert info['is_imputed_step_skipped'] == True # Crucial check for buffer + # Check that the returned state is for step 3 + expected_state_step_3 = trading_env_instance._get_state() # Get state now that we are at step 3 + np.testing.assert_array_almost_equal(next_state, expected_state_step_3) + +def test_env_step_imputed_hold(sample_env_data, base_env_config): + cfg = base_env_config.copy() + cfg.sac.imputed_handling = 'hold' + env = TradingEnv(**sample_env_data, config=cfg) + + # Step 2 is imputed + env.step(0.5) # Step 0 + env.step(0.6) # Step 1 + assert env.current_step == 2 + position_before_imputed = env.current_position + + # Action for the imputed step (should be overridden by 'hold') + action_imputed = -0.5 + next_state, reward, done, info = env.step(action_imputed) + + # Should process step 2 and move to step 3 + assert env.current_step == 3 + # Position should be the same as before the step + assert env.current_position == position_before_imputed + assert not info['is_imputed_step_skipped'] + assert not done + # Reward should be calculated based on holding the position + expected_pnl = position_before_imputed * (np.exp(sample_env_data['actual_returns'][2]) - 1) + expected_cost = 0 # No trade size if holding + expected_penalty = 0 # No penalty in hold mode + expected_raw_reward = expected_pnl - expected_cost - expected_penalty + expected_scaled_reward = expected_raw_reward * cfg.environment.reward_scale + assert np.isclose(reward, expected_scaled_reward) + +def test_env_step_imputed_penalty(sample_env_data, base_env_config): + cfg = base_env_config.copy() + cfg.sac.imputed_handling = 'penalty' + cfg.sac.action_penalty = 0.1 # Use a specific penalty for testing + env = TradingEnv(**sample_env_data, config=cfg) + + # Step 2 is imputed + env.step(0.5) # Step 0 + env.step(0.6) # Step 1 + assert env.current_step == 2 + position_before_imputed = env.current_position # Should be 0.6 + + # Action for the imputed step + action_imputed = -0.2 + next_state, reward, done, info = env.step(action_imputed) + + # Should process step 2 and move to step 3 + assert env.current_step == 3 + # Position should update to the *agent's* action + assert env.current_position == np.clip(action_imputed, -1.0, 1.0) + assert not info['is_imputed_step_skipped'] + assert not done + + # Reward calculation is ONLY the penalty + expected_raw_reward = -cfg.sac.action_penalty * (action_imputed - position_before_imputed)**2 + expected_scaled_reward = expected_raw_reward * cfg.environment.reward_scale + assert np.isclose(reward, expected_scaled_reward) + +def test_env_done_condition(trading_env_instance, sample_env_data): + n_steps = len(sample_env_data['actual_returns']) + # Step through the environment + done = False + for i in range(n_steps): + _, _, done, _ = trading_env_instance.step(np.random.uniform(-1, 1)) + if i < n_steps - 1: + assert not done + else: + assert done # Should be done on the last step \ No newline at end of file diff --git a/notebooks/example_pipeline_run.txt b/notebooks/example_pipeline_run.txt new file mode 100644 index 00000000..6740b9e1 --- /dev/null +++ b/notebooks/example_pipeline_run.txt @@ -0,0 +1,709 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# GRU-SAC Trading Pipeline: Step-by-Step Walkthrough\n", + "\n", + "This notebook demonstrates how to instantiate and run the refactored `TradingPipeline` class **sequentially**, executing each major step individually.\n", + "\n", + "**Goal:** Run the complete pipeline (data loading, feature engineering, GRU training/loading, calibration, SAC loading, backtesting) using a configuration file, inspecting the inputs and outputs at each stage." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Imports and Setup\n", + "\n", + "Import necessary libraries and configure path variables to locate the project code." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Initial sys.path: ['/usr/lib/python310.zip', '/usr/lib/python3.10', '/usr/lib/python3.10/lib-dynload', '', '/home/yasha/develop/gru_sac_predictor/.venv/lib/python3.10/site-packages']\n", + "Notebook directory (notebook_dir): /home/yasha/develop/gru_sac_predictor/gru_sac_predictor/notebooks\n", + "Calculated path for imports (package_root_for_imports): /home/yasha/develop/gru_sac_predictor/gru_sac_predictor\n", + "Checking if /home/yasha/develop/gru_sac_predictor/gru_sac_predictor is in sys.path...\n", + "Path not found. Adding /home/yasha/develop/gru_sac_predictor/gru_sac_predictor to sys.path.\n", + "sys.path after insert: ['/home/yasha/develop/gru_sac_predictor/gru_sac_predictor', '/usr/lib/python310.zip', '/usr/lib/python3.10', '/usr/lib/python3.10/lib-dynload', '', '/home/yasha/develop/gru_sac_predictor/.venv/lib/python3.10/site-packages']\n", + "Project root for config/data (project_root): /home/yasha/develop/gru_sac_predictor\n", + "Attempting to import TradingPipeline...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-04-18 03:17:10.421895: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", + "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", + "E0000 00:00:1744946230.439676 157301 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", + "E0000 00:00:1744946230.445571 157301 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", + "/home/yasha/develop/gru_sac_predictor/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Successfully imported TradingPipeline.\n" + ] + } + ], + "source": [ + "import os\n", + "import sys\n", + "import yaml\n", + "import pandas as pd\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib.image as mpimg\n", + "import logging\n", + "\n", + "print(f'Initial sys.path: {sys.path}')\n", + "\n", + "# --- Path Setup ---\n", + "# Initialize project_root to None\n", + "project_root = None\n", + "package_root_for_imports = None # Initialize separately for clarity\n", + "try:\n", + " notebook_dir = os.path.abspath('') # Get current directory (should be notebooks/)\n", + " print(f'Notebook directory (notebook_dir): {notebook_dir}')\n", + "\n", + " # Go up ONE level to get the package root directory\n", + " # Since notebook is in .../gru_sac_predictor/notebooks/, parent is .../gru_sac_predictor/\n", + " package_root_for_imports = os.path.dirname(notebook_dir)\n", + " print(f'Calculated path for imports (package_root_for_imports): {package_root_for_imports}')\n", + "\n", + " # Add the calculated path to sys.path to allow imports\n", + " print(f'Checking if {package_root_for_imports} is in sys.path...')\n", + " if package_root_for_imports not in sys.path:\n", + " print(f'Path not found. Adding {package_root_for_imports} to sys.path.')\n", + " sys.path.insert(0, package_root_for_imports)\n", + " print(f'sys.path after insert: {sys.path}')\n", + " else:\n", + " print(f'Path {package_root_for_imports} already in sys.path.')\n", + "\n", + " # Define project_root consistently, used later for finding config.yaml\n", + " # It should be the *outer* directory containing the package and config\n", + " project_root = os.path.dirname(package_root_for_imports) # Go up one more level\n", + " print(f'Project root for config/data (project_root): {project_root}')\n", + "\n", + "except Exception as e:\n", + " print(f'Error during path setup: {e}')\n", + "\n", + "# --- Import the main pipeline class ---\n", + "print(\"Attempting to import TradingPipeline...\")\n", + "try:\n", + " # Import relative to the package root added to sys.path\n", + " from src.trading_pipeline import TradingPipeline\n", + " print('Successfully imported TradingPipeline.')\n", + "except ImportError as e:\n", + " print(f'ERROR: Failed to import TradingPipeline: {e}')\n", + " print(f'Final sys.path before error: {sys.path}')\n", + " print(\"Please verify the project structure and the paths added to sys.path.\")\n", + "except Exception as e: # Catch other potential errors\n", + " print(f'An unexpected error occurred during import: {e}')\n", + " print(f'Final sys.path before error: {sys.path}')\n", + "\n", + "# Configure basic logging for the notebook\n", + "logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')\n", + "\n", + "# Set pandas display options for better inspection\n", + "pd.set_option('display.max_columns', None) # Show all columns\n", + "pd.set_option('display.max_rows', 100) # Show more rows if needed\n", + "pd.set_option('display.width', 1000) # Wider display" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Configuration\n", + "\n", + "Specify the path to the configuration file (`config.yaml`). This file defines all parameters for the data, models, training, and backtesting." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using config file: /home/yasha/develop/gru_sac_predictor/gru_sac_predictor/config.yaml\n", + "Config file found.\n" + ] + } + ], + "source": [ + "# Path to the configuration file\n", + "# Assumes config.yaml is in the directory *above* the package root\n", + "config_rel_path = 'gru_sac_predictor/config.yaml' # Relative to project_root defined above\n", + "config_abs_path = None\n", + "\n", + "# Construct absolute path relative to the project root identified earlier\n", + "if 'project_root' in locals() and project_root: # Check if project_root was successfully determined\n", + " config_abs_path = os.path.join(project_root, config_rel_path)\n", + "else:\n", + " print('ERROR: project_root not defined. Cannot find config file.')\n", + "\n", + "if config_abs_path:\n", + " print(f'Using config file: {config_abs_path}')\n", + " # Verify the config file exists\n", + " if not os.path.exists(config_abs_path):\n", + " print(f'ERROR: Config file not found at {config_abs_path}')\n", + " else:\n", + " print('Config file found.')\n", + " # Optionally load and display config for verification\n", + " try:\n", + " with open(config_abs_path, 'r') as f:\n", + " config_data = yaml.safe_load(f)\n", + " # print('\\nConfiguration:')\n", + " # print(yaml.dump(config_data, default_flow_style=False)) # Pretty print\n", + " except Exception as e:\n", + " print(f'Error reading config file: {e}')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Instantiate the Pipeline\n", + "\n", + "Create an instance of the `TradingPipeline` class, passing the path to the configuration file. This initializes the pipeline object but does not run any steps yet." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Instantiating TradingPipeline...\n", + "2025-04-18 03:17:13,554 - root - INFO - Using Base Models Directory: /home/yasha/develop/gru_sac_predictor/models\n", + "2025-04-18 03:17:13,555 - root - INFO - Using results directory: /home/yasha/develop/gru_sac_predictor/results/20250418_031713\n", + "2025-04-18 03:17:13,555 - root - INFO - Using logs directory: /home/yasha/develop/gru_sac_predictor/logs/20250418_031713\n", + "2025-04-18 03:17:13,556 - root - INFO - Using models directory: /home/yasha/develop/gru_sac_predictor/models/20250418_031713\n", + "2025-04-18 03:17:13,557 - root - INFO - Logging setup complete. Log file: /home/yasha/develop/gru_sac_predictor/logs/20250418_031713/pipeline_20250418_031713.log\n", + "2025-04-18 03:17:13,558 - root - INFO - --- Starting Pipeline Run: 20250418_031713 ---\n", + "2025-04-18 03:17:13,559 - root - INFO - Using config: /home/yasha/develop/gru_sac_predictor/gru_sac_predictor/config.yaml\n", + "2025-04-18 03:17:13,560 - root - INFO - Resolved relative db_dir '../../data/crypto_market_data' to absolute path: /home/yasha/data/crypto_market_data\n", + "2025-04-18 03:17:13,561 - gru_sac_predictor.src.data_loader - INFO - Initialized DataLoader with db_dir='/home/yasha/data/crypto_market_data'\n", + "2025-04-18 03:17:13,562 - gru_sac_predictor.src.data_loader - WARNING - Database directory does not exist: /home/yasha/data/crypto_market_data\n", + "2025-04-18 03:17:13,563 - gru_sac_predictor.src.feature_engineer - INFO - FeatureEngineer initialized with minimal whitelist: ['return_1m', 'return_15m', 'return_60m', 'ATR_14', 'volatility_14d', 'chaikin_AD_10', 'svi_10', 'EMA_10', 'EMA_50', 'MACD', 'MACD_signal', 'hour_sin', 'hour_cos']\n", + "2025-04-18 03:17:13,564 - gru_sac_predictor.src.gru_model_handler - INFO - GRUModelHandler initialized for run 20250418_031713 in /home/yasha/develop/gru_sac_predictor/models/20250418_031713\n", + "2025-04-18 03:17:13,564 - gru_sac_predictor.src.calibrator - INFO - Calibrator initialized with edge threshold: 0.55\n", + "2025-04-18 03:17:13,565 - gru_sac_predictor.src.backtester - INFO - Backtester initialized.\n", + "2025-04-18 03:17:13,566 - gru_sac_predictor.src.backtester - INFO - Initial Capital: 10000.00\n", + "2025-04-18 03:17:13,566 - gru_sac_predictor.src.backtester - INFO - Transaction Cost: 0.0500%\n", + "2025-04-18 03:17:13,567 - gru_sac_predictor.src.backtester - INFO - Edge Threshold: 0.550\n", + "2025-04-18 03:17:13,575 - root - INFO - Saved run configuration to /home/yasha/develop/gru_sac_predictor/results/20250418_031713/run_config.yaml\n", + "TradingPipeline instantiated successfully.\n", + "Run ID: 20250418_031713\n", + "Results Dir: /home/yasha/develop/gru_sac_predictor/results/20250418_031713\n", + "Log Dir: /home/yasha/develop/gru_sac_predictor/logs/20250418_031713\n", + "Models Dir: /home/yasha/develop/gru_sac_predictor/models/20250418_031713\n" + ] + } + ], + "source": [ + "pipeline_instance = None # Define outside try block\n", + "if 'TradingPipeline' in locals() and config_abs_path and os.path.exists(config_abs_path):\n", + " try:\n", + " # Instantiate the pipeline\n", + " print('Instantiating TradingPipeline...')\n", + " pipeline_instance = TradingPipeline(config_path=config_abs_path)\n", + " print('TradingPipeline instantiated successfully.')\n", + " print(f'Run ID: {pipeline_instance.run_id}')\n", + " print(f'Results Dir: {pipeline_instance.dirs[\"results\"]}')\n", + " print(f'Log Dir: {pipeline_instance.dirs[\"logs\"]}')\n", + " print(f'Models Dir: {pipeline_instance.dirs[\"models\"]}')\n", + "\n", + " except FileNotFoundError as e:\n", + " print(f'ERROR during pipeline instantiation (FileNotFound): {e}')\n", + " except Exception as e:\n", + " print(f'An error occurred during pipeline instantiation: {e}')\n", + " logging.error('Pipeline instantiation failed.', exc_info=True) # Log traceback\n", + "else:\n", + " print('TradingPipeline class not imported, config path invalid, or config file not found. Cannot instantiate pipeline.')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Step 1: Load Data\n", + "\n", + "Call the `load_data` method to fetch the raw market data based on the configuration." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "=== Running Step 1: Load Data ===\n", + "2025-04-18 03:17:15,747 - root - INFO - --- Notebook Step: Load Data (Calling load_and_preprocess_data) ---\n", + "2025-04-18 03:17:15,749 - root - INFO - --- Stage: Loading and Preprocessing Data ---\n", + "2025-04-18 03:17:15,751 - gru_sac_predictor.src.data_loader - INFO - Loading data for SOL-USDT (bnbspot) from 2024-06-01 to 2025-03-10, interval 1min\n", + "2025-04-18 03:17:15,767 - gru_sac_predictor.src.data_loader - INFO - Scanning for DB files recursively in: /home/yasha/data/crypto_market_data\n", + "2025-04-18 03:17:15,769 - gru_sac_predictor.src.data_loader - ERROR - Database directory /home/yasha/data/crypto_market_data does not exist\n", + "2025-04-18 03:17:15,773 - gru_sac_predictor.src.data_loader - ERROR - No relevant DB files found and no fallback files available.\n", + "2025-04-18 03:17:15,774 - gru_sac_predictor.src.data_loader - ERROR - No relevant database files found for the specified date range.\n", + "2025-04-18 03:17:15,779 - root - ERROR - Failed to load data. Exiting.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "No traceback available to show.\n" + ] + }, + { + "ename": "SystemExit", + "evalue": "1", + "output_type": "error", + "traceback": [ + "An exception has occurred, use %tb to see the full traceback.\n", + "\u001b[0;31mSystemExit\u001b[0m\u001b[0;31m:\u001b[0m 1\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/yasha/develop/gru_sac_predictor/.venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3587: UserWarning: To exit: use 'exit', 'quit', or Ctrl-D.\n", + " warn(\"To exit: use 'exit', 'quit', or Ctrl-D.\", stacklevel=1)\n" + ] + } + ], + "source": [ + "%tb\n", + "if pipeline_instance:\n", + " try:\n", + " print('\\n=== Running Step 1: Load Data ===')\n", + " pipeline_instance.load_data()\n", + " print('load_data() finished.')\n", + "\n", + " print('\\n--- Inspecting Raw Data ---')\n", + " if pipeline_instance.raw_data is not None:\n", + " print(f'Shape of raw_data: {pipeline_instance.raw_data.shape}')\n", + " display(pipeline_instance.raw_data.head())\n", + " display(pipeline_instance.raw_data.tail())\n", + " display(pipeline_instance.raw_data.isnull().sum()) # Check for NaNs\n", + " else:\n", + " print('raw_data attribute is None.')\n", + "\n", + " except Exception as e:\n", + " print(f'An error occurred during Load Data step: {e}')\n", + " logging.error('Load Data step failed.', exc_info=True)\n", + "else:\n", + " print('Pipeline not instantiated. Cannot run step.')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. Step 2: Engineer Features\n", + "\n", + "Call the `engineer_features` method to create technical indicators and other features from the raw data." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if pipeline_instance and pipeline_instance.raw_data is not None:\n", + " try:\n", + " print('\\n=== Running Step 2: Engineer Features ===')\n", + " pipeline_instance.engineer_features()\n", + " print('engineer_features() finished.')\n", + "\n", + " print('\\n--- Inspecting Features DataFrame ---')\n", + " if pipeline_instance.features_df is not None:\n", + " print(f'Shape of features_df: {pipeline_instance.features_df.shape}')\n", + " display(pipeline_instance.features_df.head())\n", + " display(pipeline_instance.features_df.tail())\n", + " display(pipeline_instance.features_df.isnull().sum()) # Check for NaNs introduced by features\n", + " else:\n", + " print('features_df attribute is None.')\n", + "\n", + " except Exception as e:\n", + " print(f'An error occurred during Engineer Features step: {e}')\n", + " logging.error('Engineer Features step failed.', exc_info=True)\n", + "else:\n", + " print('Pipeline not instantiated or raw_data missing. Cannot run step.')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6. Step 3: Prepare Sequences\n", + "\n", + "Call the `prepare_sequences` method to split the data into train/validation/test sets and create sequences suitable for the GRU model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if pipeline_instance and pipeline_instance.features_df is not None:\n", + " try:\n", + " print('\\n=== Running Step 3: Prepare Sequences ===')\n", + " pipeline_instance.prepare_sequences()\n", + " print('prepare_sequences() finished.')\n", + "\n", + " print('\\n--- Inspecting Sequences and Targets ---')\n", + " # Assuming attributes like train_sequences, val_targets etc. exist\n", + " for name in ['train_sequences', 'val_sequences', 'test_sequences',\n", + " 'train_targets', 'val_targets', 'test_targets',\n", + " 'train_indices', 'val_indices', 'test_indices']:\n", + " attr = getattr(pipeline_instance, name, None)\n", + " if attr is not None:\n", + " # Check if it's numpy array or pandas series/df before getting shape\n", + " if hasattr(attr, 'shape'):\n", + " print(f'{name} shape: {attr.shape}')\n", + " else:\n", + " print(f'{name} type: {type(attr)}, length: {len(attr)}') # For lists like indices\n", + " else:\n", + " print(f'{name} attribute is None.')\n", + "\n", + " except Exception as e:\n", + " print(f'An error occurred during Prepare Sequences step: {e}')\n", + " logging.error('Prepare Sequences step failed.', exc_info=True)\n", + "else:\n", + " print('Pipeline not instantiated or features_df missing. Cannot run step.')\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 7. Step 4: Train or Load GRU Model\n", + "\n", + "Call `train_or_load_gru` to either train a new GRU model or load a pre-trained one, based on the configuration flags." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if pipeline_instance and pipeline_instance.train_sequences is not None: # Check if sequences are ready\n", + " try:\n", + " print('\\n=== Running Step 4: Train or Load GRU ===')\n", + " pipeline_instance.train_or_load_gru()\n", + " print('train_or_load_gru() finished.')\n", + "\n", + " print('\\n--- Inspecting GRU Handler ---')\n", + " if pipeline_instance.gru_handler is not None:\n", + " print(f'GRU Handler instantiated: {pipeline_instance.gru_handler}')\n", + " # Potentially inspect model summary if handler exposes it\n", + " # print(pipeline_instance.gru_handler.model.summary())\n", + " print(f'GRU Predictions available (val): {hasattr(pipeline_instance.gru_handler, \"val_predictions\")}')\n", + " print(f'GRU Predictions available (test): {hasattr(pipeline_instance.gru_handler, \"test_predictions\")}')\n", + " else:\n", + " print('gru_handler attribute is None.')\n", + "\n", + " except Exception as e:\n", + " print(f'An error occurred during Train/Load GRU step: {e}')\n", + " logging.error('Train/Load GRU step failed.', exc_info=True)\n", + "else:\n", + " print('Pipeline not instantiated or sequences missing. Cannot run step.')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 8. Step 5: Calibrate Predictions\n", + "\n", + "Call `calibrate_predictions` to use the validation set predictions from the GRU to find an optimal probability threshold or apply other calibration techniques." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if pipeline_instance and pipeline_instance.gru_handler is not None and hasattr(pipeline_instance.gru_handler, 'val_predictions'):\n", + " try:\n", + " print('\\n=== Running Step 5: Calibrate Predictions ===')\n", + " pipeline_instance.calibrate_predictions()\n", + " print('calibrate_predictions() finished.')\n", + "\n", + " print('\\n--- Inspecting Calibration Results ---')\n", + " if pipeline_instance.calibrator is not None:\n", + " print(f'Calibrator object: {pipeline_instance.calibrator}')\n", + " print(f'Optimal threshold: {getattr(pipeline_instance, \"optimal_threshold\", \"Not set\")}')\n", + " print(f'Calibrated Val Probs exist: {hasattr(pipeline_instance.calibrator, \"calibrated_val_probabilities\")}')\n", + " print(f'Calibrated Test Probs exist: {hasattr(pipeline_instance.calibrator, \"calibrated_test_probabilities\")}')\n", + "\n", + " else:\n", + " print('calibrator attribute is None.')\n", + "\n", + " except Exception as e:\n", + " print(f'An error occurred during Calibrate Predictions step: {e}')\n", + " logging.error('Calibrate Predictions step failed.', exc_info=True)\n", + "else:\n", + " print('Pipeline not instantiated or GRU validation predictions missing. Cannot run step.')\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 9. Step 6: Prepare SAC Agent for Backtest\n", + "\n", + "Call `train_or_load_sac`. This step might involve triggering offline SAC training (if configured) or simply identifying and setting the path to the pre-trained SAC agent policy to be used in the backtest." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Note: Actual SAC training might be complex to run directly inline.\n", + "# This step often just prepares the necessary info (like the agent path) for the backtester.\n", + "if pipeline_instance:\n", + " try:\n", + " print('\\n=== Running Step 6: Train or Load SAC (Prepare for Backtest) ===')\n", + " # This might just set an attribute like sac_agent_path based on config\n", + " pipeline_instance.train_or_load_sac()\n", + " print('train_or_load_sac() finished.')\n", + "\n", + " print('\\n--- Inspecting SAC Agent Info ---')\n", + " # Check the attribute storing the path or relevant SAC info\n", + " print(f'SAC Agent Path for backtest: {getattr(pipeline_instance, \"sac_agent_path\", \"Not set\")}')\n", + "\n", + " except Exception as e:\n", + " print(f'An error occurred during Train/Load SAC step: {e}')\n", + " logging.error('Train/Load SAC step failed.', exc_info=True)\n", + "else:\n", + " print('Pipeline not instantiated. Cannot run step.')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 10. Step 7: Run Backtest\n", + "\n", + "Execute the trading simulation using the test data, GRU predictions (calibrated), and the loaded SAC agent policy." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Check if necessary components are ready\n", + "backtest_ready = (\n", + " pipeline_instance and\n", + " pipeline_instance.test_sequences is not None and\n", + " pipeline_instance.test_targets is not None and\n", + " pipeline_instance.test_indices is not None and\n", + " pipeline_instance.gru_handler is not None and\n", + " pipeline_instance.calibrator is not None and # Ensure calibration ran\n", + " getattr(pipeline_instance, \"optimal_threshold\", None) is not None and\n", + " getattr(pipeline_instance, \"sac_agent_path\", None) is not None\n", + ")\n", + "\n", + "if backtest_ready:\n", + " try:\n", + " print('\\n=== Running Step 7: Run Backtest ===')\n", + " pipeline_instance.run_backtest()\n", + " print('run_backtest() finished.')\n", + "\n", + " print('\\n--- Inspecting Backtest Results ---')\n", + " if pipeline_instance.backtest_metrics:\n", + " print('\\n--- Backtest Metrics --- ')\n", + " metrics = pipeline_instance.backtest_metrics\n", + " metrics['Run ID'] = pipeline_instance.run_id # Add run ID for context\n", + " for key, value in metrics.items():\n", + " if key == \"Confusion Matrix (GRU Signal vs Actual Dir)\":\n", + " print(f'{key}:\\\\n{np.array(value)}')\n", + " elif key == \"Classification Report (GRU Signal)\":\n", + " print(f'{key}:\\\\n{value}')\n", + " elif isinstance(value, float):\n", + " print(f'{key}: {value:.4f}')\n", + " else:\n", + " print(f'{key}: {value}')\n", + " else:\n", + " print('Backtest metrics not available.')\n", + "\n", + " if pipeline_instance.backtest_results_df is not None:\n", + " print('\\n--- Backtest Results DataFrame (Head) --- ')\n", + " display(pipeline_instance.backtest_results_df.head())\n", + " print('\\n--- Backtest Results DataFrame (Tail) --- ')\n", + " display(pipeline_instance.backtest_results_df.tail())\n", + " print('\\n--- Backtest Results DataFrame (Description) --- ')\n", + " display(pipeline_instance.backtest_results_df.describe())\n", + " else:\n", + " print('Backtest results DataFrame not available.')\n", + "\n", + "\n", + " except Exception as e:\n", + " print(f'An error occurred during Run Backtest step: {e}')\n", + " logging.error('Run Backtest step failed.', exc_info=True)\n", + "else:\n", + " print('Pipeline not instantiated or prerequisites for backtest are missing. Cannot run step.')\n", + " print(f\"Prerequisites check: pipeline={bool(pipeline_instance)}, test_sequences={pipeline_instance.test_sequences is not None if pipeline_instance else False}, \"\n", + " f\"test_targets={pipeline_instance.test_targets is not None if pipeline_instance else False}, test_indices={pipeline_instance.test_indices is not None if pipeline_instance else False}, \"\n", + " f\"gru_handler={pipeline_instance.gru_handler is not None if pipeline_instance else False}, calibrator={pipeline_instance.calibrator is not None if pipeline_instance else False}, \"\n", + " f\"optimal_T={getattr(pipeline_instance, 'optimal_threshold', None) is not None}, sac_path={getattr(pipeline_instance, 'sac_agent_path', None) is not None}\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 11. Step 8: Save Results\n", + "\n", + "Save the calculated metrics, the detailed backtest results DataFrame, and any generated plots to the run-specific output directory." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if pipeline_instance and pipeline_instance.backtest_results_df is not None and pipeline_instance.backtest_metrics:\n", + " try:\n", + " print('\\n=== Running Step 8: Save Results ===')\n", + " pipeline_instance.save_results()\n", + " print('save_results() finished.')\n", + " print(f'Results should be saved in: {pipeline_instance.dirs[\"results\"]}')\n", + "\n", + " except Exception as e:\n", + " print(f'An error occurred during Save Results step: {e}')\n", + " logging.error('Save Results step failed.', exc_info=True)\n", + "else:\n", + " print('Pipeline not instantiated or backtest results/metrics missing. Cannot run step.')\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 12. Display Saved Plots\n", + "\n", + "Load and display the plots generated and saved during the pipeline execution (especially during calibration and backtesting/saving)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# This code assumes plots were generated and saved by previous steps (like calibrate or save_results)\n", + "if pipeline_instance is not None and pipeline_instance.dirs.get('results'):\n", + " results_dir = pipeline_instance.dirs['results']\n", + " run_id = pipeline_instance.run_id\n", + " print(f'\\nLooking for plots in: {results_dir}\\n')\n", + "\n", + " plot_files = [\n", + " f'backtest_summary_{run_id}.png',\n", + " f'confusion_matrix_{run_id}.png',\n", + " f'reliability_curve_val_{run_id}.png', # Generated by calibration\n", + " f'calibration_curve_test_{run_id}.png' # Potentially generated by backtester/save_results\n", + " # Add any other plot filenames generated by your pipeline\n", + " ]\n", + "\n", + " plot_found = False\n", + " for plot_file in plot_files:\n", + " plot_path = os.path.join(results_dir, plot_file)\n", + " if os.path.exists(plot_path):\n", + " plot_found = True\n", + " print(f'--- Displaying: {plot_file} ---')\n", + " try:\n", + " img = mpimg.imread(plot_path)\n", + " # Determine appropriate figure size based on plot type\n", + " figsize = (15, 12) if 'summary' in plot_file else (8, 7)\n", + " plt.figure(figsize=figsize)\n", + " plt.imshow(img)\n", + " plt.axis('off') # Hide axes for image display\n", + " plt.title(plot_file)\n", + " plt.show()\n", + " except Exception as e:\n", + " print(f' Error loading/displaying plot {plot_file}: {e}')\n", + " else:\n", + " print(f'Plot not found: {plot_path}')\n", + "\n", + " if not plot_found:\n", + " print(\"No standard plots found in the results directory.\")\n", + "\n", + "else:\n", + " print('\\nPipeline object not found or results directory is not available. Cannot display plots.')\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 13. Conclusion\n", + "\n", + "This notebook demonstrated the step-by-step workflow of using the `TradingPipeline`. By running each step individually, we could inspect the intermediate outputs. You can modify the `config.yaml` file to experiment with different parameters, data ranges, and control flags, then re-run the relevant steps of this notebook. The final results (metrics, plots, detailed CSV) are saved in the run-specific directory under the main project's `results/` folder." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/prompts/calibrating_edge_baseline.txt b/prompts/calibrating_edge_baseline.txt new file mode 100644 index 00000000..794c1612 --- /dev/null +++ b/prompts/calibrating_edge_baseline.txt @@ -0,0 +1,71 @@ +### Streamlined Calibration for Baseline LR Gates + +If you’re mainly struggling with mis‑calibrated confidence on your edge‑filtered checks, here’s a **minimal** integration to fix it without heavy lifting. + +--- + +#### 1. Add a toggle and holdout in `config.yaml` +```yaml +baseline: + calibration_enabled: true # turn on/off easily + calibration_method: "isotonic" # handles multiclass + calibration_holdout: 0.2 # 20% of your train split + random_state: 42 +``` + +--- + +#### 2. Quick split for calibration +In your `run_baseline_checks` (before any CI gates): +```python +# original train/val split +X_main, X_val, y_main, y_val = train_test_split( + X_pruned, y_labels, test_size=0.2, random_state=seed +) + +if self.config['baseline']['calibration_enabled']: + X_train, X_cal, y_train, y_cal = train_test_split( + X_main, y_main, + test_size=self.config['baseline']['calibration_holdout'], + random_state=self.config['baseline']['random_state'] + ) +else: + X_train, y_train = X_main, y_main + X_cal, y_cal = None, None +``` + +--- + +#### 3. Fit an isotonic calibrator only when needed +```python +# train raw LR +lr = LogisticRegression(...).fit(X_train, y_train) + +if X_cal is not None: + from sklearn.calibration import CalibratedClassifierCV + calibrator = CalibratedClassifierCV(lr, method='isotonic', cv='prefit') + calibrator.fit(X_cal, y_cal) +else: + calibrator = lr +``` + +--- + +#### 4. Use calibrated probabilities in your gates +Replace all `lr.predict_proba(X)` calls with: +```python +probs = calibrator.predict_proba(X) +# binary: edge = |probs[:,1] - 0.5| +# ternary: edge = max(probs, axis=1) - 1/3 +``` +Then run your existing CI lower‑bound checks as usual. + +--- + +#### 5. (Optional) Skip persistence +For a quick fix you can skip saving/loading the calibrator—just build and use it in the same process. + +--- + +With these five steps, you’ll correct your edge‑confidence estimates with minimal code and configuration. If your gates then pass, proceed to GRU training; if they still fail, the issue is likely weak features rather than calibration. + diff --git a/prompts/mdn_gru.txt b/prompts/mdn_gru.txt new file mode 100644 index 00000000..7e85ebdf --- /dev/null +++ b/prompts/mdn_gru.txt @@ -0,0 +1,163 @@ +Below is a consolidated set of revision instructions — and the key code snippets you’ll need — to switch your GRU/SAC pipeline to supervise **log‑returns** (so your “ret” and “gauss_params” heads are targeting log‑return), and to wire everything end‑to‑end so you get a clean 55 %+ edge before SAC. + +--- + +## 🛠 Revision Playbook + +1. **Compute forward log‐returns in your data pipeline** + In `TradingPipeline.define_labels_and_align` (or wherever you set up `df`): + + ```python + # Replace any raw-return calculation with log‐return + N = config['gru']['prediction_horizon'] + df['fwd_log_ret'] = np.log(df['close'].shift(-N) / df['close']) + df['direction_label'] = (df['fwd_log_ret'] > 0).astype(int) + # If you do ternary: + flat_thr = config['gru']['flat_sigma_multiplier'] * df['fwd_log_ret'].rolling(…) + df['dir3_label'] = pd.cut(df['fwd_log_ret'], + bins=[-np.inf, -flat_thr, flat_thr, np.inf], + labels=[0,1,2]).astype(int) + ``` + +2. **Align your targets** + Drop the last N rows so `fwd_log_ret` has no NaNs: + + ```python + df = df.iloc[:-N] + ``` + +3. **Pass log‐return into both heads** + When you build your sequences and target dicts: + + ```python + y_ret_seq = ... # shape (n_seq, 1) from fwd_log_ret + y_dir3_seq = ... # one‑hot from dir3_label + + y_train = {'mu': y_ret_seq, # Huber head + 'gauss_params': y_ret_seq, # NLL head uses same target + 'dir3': y_dir3_seq} # classification head + ``` + +4. **Update your GRU builder to match** + Make sure your v3 model has exactly three outputs: + ```python + model = Model(inputs, outputs=[mu_output, gauss_params_output, dir3_output]) + model.compile( + optimizer=Adam(lr), + loss={ + 'mu': Huber(delta), + 'gauss_params': gaussian_nll, + 'dir3': categorical_focal_loss + }, + loss_weights={'mu':1.0, 'gauss_params':0.2, 'dir3':0.4}, + metrics={'dir3':'accuracy'} + ) + ``` + +5. **Train with the new targets dict** + In `GRUModelHandler.train(...)`, replace your fit call with: + + ```python + history = model.fit( + X_train_seq, + y_train_dict, + validation_data=(X_val_seq, y_val_dict), + …callbacks… + ) + ``` + +6. **Calibrate on the “dir3” softmax outputs** + Your calibrator (Temp/Vector) must now consume the 3‑class logits or probabilities: + + ```python + raw_logits = handler.predict_logits(X_val_seq) + calibrator.fit(raw_logits, y_val_dir3) + ``` + +7. **Feed SAC the log‐return μ and σ** + In your `TradingEnv`, when you construct the state: + + ```python + mu, log_sigma, probs = gru_handler.predict(X_step) + sigma = np.exp(log_sigma) + edge = 2 * calibrated_p_up - 1 # if binary + z_score = np.abs(mu) / sigma + state = [mu, sigma, edge, z_score, prev_position] + ``` + +8. **Re‐run baseline check on log‐returns** + Your logistic baseline in `run_baseline_checks` should now be trained on `X_train_pruned` vs `y_dir3_label` (or binary), ensuring the CI ≥ 0.52 before you even build your GRU. + +9. **Validate end‑to‑end edge** + After these changes, you should see: + - Baseline logistic CI LB ≥ 0.52 + - GRU “edge” hit‑rate ≥ 0.55 on validation + - SAC backtest hitting meaningful Sharpe/Win‑rate gates + +--- + +## 🔧 Example Code Snippets + +### 1) Gaussian NLL stays the same: + +```python +@saving.register_keras_serializable(package='GRU') +def gaussian_nll(y_true, y_pred): + mu, log_sigma = tf.split(y_pred, 2, axis=-1) + y_true = tf.reshape(y_true, tf.shape(mu)) + inv_var = tf.exp(-2*log_sigma) + return tf.reduce_mean(0.5 * inv_var * tf.square(y_true-mu) + log_sigma) +``` + +### 2) Build & compile v3 model: + +```python +def build_gru_model_v3(...): + inp = layers.Input((lookback, n_features)) + x = layers.GRU(gru_units, return_sequences=True)(inp) + x = layers.LayerNormalization()(x) + if attention_units>0: + x = layers.MultiHeadAttention(...)(x,x) + x = layers.GlobalAveragePooling1D()(x) + + gauss = layers.Dense(2, name='gauss_params')(x) + mu = layers.Lambda(lambda z: z[:,0:1], name='mu')(gauss) + dir3_logits = layers.Dense(3, name='dir3_logits')(x) + dir3 = layers.Activation('softmax', name='dir3')(dir3_logits) + + model = Model(inp, [mu, gauss, dir3]) + model.compile( + optimizer=Adam(lr), + loss={'mu':Huber(delta), + 'gauss_params':gaussian_nll, + 'dir3':categorical_focal_loss}, + loss_weights={'mu':1.0,'gauss_params':0.2,'dir3':0.4}, + metrics={'dir3':'accuracy'} + ) + return model +``` + +### 3) Fitting in your handler: + +```python +history = self.model.fit( + X_train_seq, y_train_dict, + validation_data=(X_val_seq, y_val_dict), + epochs=max_epochs, + batch_size=batch_size, + callbacks=[early_stop, csv_logger, TqdmCallback()] +) +``` + +--- + +### Why these changes boost edge? + +- **Log‐returns** stabilize variance & symmetrize up/down moves. +- **NLL + Huber on log‐return** gives the model both distributional uncertainty (σ) and a robust error measure. +- **Proper softmax head** on three classes (up/flat/down) cleans up classification. +- **Calibration + optimized edge threshold** ensures your SAC agent only sees high‐confidence signals (edge≥thr). + +Together, this gets your baseline GRU above 55 % “edge” on validation, so the SAC agent can then learn a meaningful sizing policy rather than fight noise. + +Let me know if you need any further refinements! \ No newline at end of file diff --git a/prompts/missing_data.txt b/prompts/missing_data.txt new file mode 100644 index 00000000..2027c0e7 --- /dev/null +++ b/prompts/missing_data.txt @@ -0,0 +1,187 @@ +## Revision Instructions for AI Dev‑Agent +Implement end‑to‑end missing‑bar handling in GRU and SAC. Apply the steps below in sequence, with small PRs and CI green at each stage. + +--- + +### 1 | Config updates + +**File:** `config.yaml` +Add under `data` and create new sections for `gru` and `sac`: + +```yaml +data: + bar_frequency: "1T" + missing: + strategy: "neutral" # drop | neutral | ffill | interpolate + max_gap: 5 # max consecutive missing bars allowed + interpolate: + method: "linear" + limit: 10 + +gru: + drop_imputed_sequences: true # drop any sequence containing imputed bars + +sac: + imputed_handling: "hold" # hold | skip | penalty + action_penalty: 0.05 # used if imputed_handling=penalty +``` + +--- + +### 2 | Detect & fill missing bars + +**File:** `src/data_loader.py` + +1. **Import** at top: + ```python + import pandas as pd + from .io_manager import IOManager + ``` + +2. **Implement** `find_missing_bars(df, freq)` and `_consecutive_gaps` helpers. + +3. **Implement** `report_missing(missing, cfg, io, logger)` as described. + +4. **Implement** `fill_missing_bars(df, cfg, io, logger)`: + - Detect missing timestamps. + - Call `report_missing`. + - Reindex to full date_range. + - Apply `strategy`: + - `drop`: return original df. + - `neutral`: ffill close, set open=high=low=close, volume=0. + - `ffill`: `df_full.ffill().bfill()`. + - `interpolate`: use `df_full.interpolate(...)`. + - **After filling**, add column: + ```python + df['bar_imputed'] = df.index.isin(missing) + ``` + - **Error** if longest gap > `cfg.data.missing.max_gap`. + +5. **Integrate** in `TradingPipeline.load_and_preprocess_data` **before** feature engineering: + ```python + df = fill_missing_bars(df, self.cfg, io, logger) + ``` + +--- + +### 3 | Sequence creation respects imputed bars + +**File:** `src/trading_pipeline.py` + +1. In `create_sequences`, after building `X_seq` and `y_seq`, **build** `mask_seq` of shape `(n, lookback)` from `df['bar_imputed']`. + +2. **Conditionally drop** sequences: + ```python + if self.cfg.gru.drop_imputed_sequences: + valid = ~mask_seq.any(axis=1) + X_seq = X_seq[valid]; y_seq = y_seq[valid] + ``` +3. **Log**: + ```python + logger.info(f"Generated {orig_n} sequences, dropped {orig_n - X_seq.shape[0]} with imputed bars") + ``` +4. **Include** `bar_imputed` as a feature column in `minimal_whitelist`. + +--- + +### 4 | GRU model input channel + +**File:** `src/model_gru_v3.py` (or `model_gru.py` if v3) + +1. **Update input shape**: increase `n_features` by 1 to include `bar_imputed`. + +2. **No further architectural change**; the model now sees the imputed‑flag channel. + +--- + +### 5 | SAC environment handles imputed bars + +**File:** `src/trading_env.py` + +1. **Read** `bar_imputed` into `self.bar_imputed` aligned with your sequences. + +2. **In `step(action)`**, at the top: + ```python + imputed = self.bar_imputed[self.current_step] + if imputed: + mode = self.cfg.sac.imputed_handling + if mode == "skip": + self.current_step += 1 + return next_state, 0.0, False, {} + if mode == "hold": + action = self.position + if mode == "penalty": + reward = - self.cfg.sac.action_penalty * (action - self.position)**2 + self._update_position(action) + self.current_step += 1 + return next_state, reward, False, {} + # existing normal step follows + ``` + +3. **Ensure** imputed transitions are added to replay buffer only when `mode` ≠ `skip`. + +4. **Log**: + ```python + logger.debug(f"SAC step {self.current_step} on imputed bar: handling={mode}") + ``` + +--- + +### 6 | Logging & artefacts + +1. **Data load** warning: + ``` + WARNING Detected {total} missing bars, longest gap {longest}; applied strategy={strategy} + ``` + +2. **Sequence creation** info: + ``` + INFO Generated {orig_n} sequences, dropped {dropped} with imputed bars + ``` + +3. **SAC training** debug: + ``` + DEBUG SAC on imputed bar at step {step}: handling={mode} + ``` + +4. **Report** saved under `results//`: + - `missing_bars_summary.json` + - `imputed_sequence_summary.json` with counts. + - `sac_imputed_transitions.csv` (optional detailed log). + +--- + +### 7 | Unit tests + +**Files:** `tests/test_data_loader.py`, `tests/test_sequence_creation.py`, `tests/test_trading_env.py` + +1. **`test_data_loader.py`**: + - Synthetic gappy DataFrame → assert `bar_imputed` flags and each strategy’s output. + +2. **`test_sequence_creation.py`**: + - Build toy DataFrame with `bar_imputed`; assert sequences dropped when `drop_imputed_sequences=True`. + +3. **`test_trading_env.py`**: + - Create `TradingEnv` with known imputed steps; for each `imputed_handling` mode assert `step()` behavior: + - `skip` moves without adding to buffer; + - `hold` returns same position; + - `penalty` returns negative reward equal to penalty formula. + +--- + +### 8 | Documentation + +1. **README.md** → add **Data Quality** section describing missing‑bar handling, config keys, and recommended defaults. + +2. **docs/v3_changelog.md** → note new missing‑bar feature and cfg flags. + +--- + +**Roll‑out Plan:** + +- **PR 1:** Config + data_loader missing‑bar detection & fill + tests. +- **PR 2:** Sequence creation & GRU channel update + tests. +- **PR 3:** SAC env updates + tests. +- **PR 4:** Logging/artefacts + docs. + +Merge each after CI passes. \ No newline at end of file diff --git a/prompts/output_artefacts.txt b/prompts/output_artefacts.txt new file mode 100644 index 00000000..460c1945 --- /dev/null +++ b/prompts/output_artefacts.txt @@ -0,0 +1,140 @@ +## **Revision Document – v3 Output Contract & Figure Specifications** +This single guide merges **I/O plumbing**, **logging**, **CI hooks**, **artefact paths**, and **figure design** into one actionable playbook. +Apply the steps **in order**, submitting small PRs so CI remains green throughout. + +--- + +### 0 ▪ Foundations + +| Step | File(s) | Action | +|------|---------|--------| +| 0.1 | **`config.yaml`** | Add: ```yaml base_dirs: {results: results, models: models, logs: logs} output: {figure_dpi: 150, figure_size: [16, 9], log_level: INFO}``` | +| 0.2 | `src/utils/run_id.py` | `make_run_id()` → `"20250418_152310_ab12cd"` (timestamp + short git‑hash). | +| 0.3 | `src/__init__.py` | Expose `__version__`, `GIT_SHA`, `BUILD_DATE`. | + +--- + +### 1 ▪ Core I/O & Logging + +| File | Content | +|------|---------| +| **`src/io_manager.py`** | `IOManager(cfg, run_id)`
• `path(section, name)`: returns full path under `results|models|logs|figures`.
• `save_json`, `save_df` (CSV ≤ 100 MB else Parquet), `save_figure` (uses cfg dpi/size). | +| **`src/logger_setup.py`** | `setup_logger(cfg, run_id, io)` with colourised console (INFO) + rotating file handler (DEBUG) in `logs//`. | + +**`run.py` entry banner** + +```python +run_id = make_run_id() +cfg = load_config(args.config) +io = IOManager(cfg, run_id) +logger = setup_logger(cfg, run_id, io) +logger.info(f"GRU‑SAC v{__version__} | commit {GIT_SHA} | run {run_id}") +logger.info(f"Loaded config file: {args.config}") +``` + +--- + +### 2 ▪ Stage Outputs + +| Stage | Implementation notes | Artefacts | +|-------|---------------------|-----------| +| **Data load & preprocess** | After sampling/NaN purge save:
`io.save_json(summary, "preprocess_summary")`
`io.save_df(df.head(20), "head_preprocessed")` | `results//preprocess_summary.txt`
`head_preprocessed.csv` | +| **Feature engineering** | Generate correlation heat‑map (see figure table) → `io.save_figure(...,"feature_corr_heatmap")` | 〃 | +| **Label generation** | Log distribution; produce histogram figure. | `label_histogram.png` | +| **Baseline 1 & 2** | Consolidate in `baseline_checker.py`; each returns dict with accuracy, CI etc.
`io.save_json(report,"baseline1_report")` (and 2). | `baseline1_report.txt / baseline2_report.txt` | +| **Feature whitelist** | Save JSON to `models//final_whitelist_.json`. | — | +| **GRU training** | Use Keras CSVLogger to `logs//gru_history.csv`; after training plot learning curve. | `gru_learning_curve.png` + `.keras` model | +| **Calibration (Vector)** | Save `calibrator_vec_.npy`; plot reliability curve. | `reliability_curve_val_.png` | +| **SAC training** | Write `episode_rewards.csv`, plot reward curve, save final agent under `models/sac_train_/`. | `sac_reward_plot.png` | +| **Back‑test** | Save step‑level CSV, metrics JSON, summary figure. | `backtest_results_.csv`
`performance_metrics_.txt`
`backtest_summary_.png` | + +--- + +### 3 ▪ Figure Specifications + +| File | Visualises | Layout / Details | +|------|-------------|------------------| +| **feature_corr_heatmap.png** | Pearson correlation of engineered features (pre‑prune). | Square heat‑map, features sorted by |ρ| vs target; diverging palette centred at 0; annotate |ρ| > 0.5; colour‑bar. | +| **label_histogram.png** | Direction‑label class mix (train split). | Bar chart: Down / Flat / Up (binary shows two). Percentages on bar tops; title shows ε value. | +| **gru_learning_curve.png** | GRU training progress. | 3 stacked panes: total loss (log‑y), val dir3 accuracy, vertical dashed “early‑stop”; share epoch‑axis. | +| **reliability_curve_val_*.png** | Calibration quality post‑Vector scaling. | Left 70 %: reliability diagram (10 equal‑freq bins). Right 30 %: histogram of predicted p_up. Title shows ECE & Brier. | +| **sac_reward_plot.png** | Offline SAC learning curve. | Smoothed episode reward (EMA 0.2) vs steps; action‑variance on twin y‑axis; checkpoint ticks. | +| **backtest_summary_*.png** | Live back‑test overview. | 3 stacked plots:
1) Price line + blue/red background for edge ≥ 0.1.
2) Position size step‑graph.
3) Equity curve with shaded draw‑downs; textbox shows Sharpe & Max DD. | + +_All figs_: 16 × 9 in, 150 DPI, `plt.tight_layout()`, footer `"© GRU‑SAC v3"` right‑bottom. + +--- + +### 4 ▪ Unit Tests + +* `tests/test_output_contract.py` + * Run mini‑pipeline (`tests/smoke.yaml`), assert each required file exists > 2 KB. + * Validate JSON keys (`accuracy`, `ci_lower` etc.). + * `assert_any_close(softmax(logits), probs)` for logits view. + +--- + +### 5 ▪ CI Workflow (`.github/workflows/pipeline.yml`) + +```yaml +jobs: + build-test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v4 + with: {python-version: "3.10"} + - run: pip install -r requirements.txt + - run: black --check . + - run: ruff . + - run: pytest -q + - name: Smoke e2e + run: python run.py --config tests/smoke.yaml + - name: Upload artefacts + uses: actions/upload-artifact@v4 + with: + name: run-${{ github.sha }} + path: | + results/*/* + logs/*/* +``` + +--- + +### 6 ▪ Documentation Updates + +* **`README.md`** → new *Outputs* section reproducing the artefact table. +* **`docs/v3_changelog.md`** → one‑pager summarising v3 versus v2 differences (labels, calibration, outputs). + +--- + +### 7 ▪ Roll‑out Plan (5‑PR cadence) + +1. **PR #1** – run‑id, IOManager, logger, CI log upload. +2. **PR #2** – data & feature stage outputs + tests. +3. **PR #3** – GRU training outputs + calibration figure. +4. **PR #4** – SAC & back‑test outputs, reward & summary figs. +5. **PR #5** – docs & README refresh. + +Tag `v3.0.0` after PR #5 passes. + +--- + +### 8 ▪ Success Criteria for CI + +Fail the pipeline when **any** occurs: + +* `baseline1_report.txt` CI‑LB < 0.52 +* `edge_filtered_accuracy` (val) < 0.60 +* Back‑test Sharpe < 1.2 or Max DD > 15 % + +--- + +Implementing this **single integrated revision** provides: + +* **Deterministic artefact paths** for every run. +* **Rich, shareable figures** for quick diagnostics. +* **Audit‑ready logs/reports** for research traceability. + +Merge each step once CI is green; you’ll have a reproducible, fully instrumented pipeline ready for iterative accuracy pushes toward the 65 % target. \ No newline at end of file diff --git a/prompts/pipeline_refactor.txt b/prompts/pipeline_refactor.txt new file mode 100644 index 00000000..3ad4deb7 --- /dev/null +++ b/prompts/pipeline_refactor.txt @@ -0,0 +1,55 @@ +Refactoring Plan for trading_pipeline.py +======================================= + +Goal: Break down the large TradingPipeline class into smaller, dedicated modules for better maintainability and readability, while minimizing disruption to the existing E2E system. + +Strategy: + +1. **Keep `TradingPipeline` as the Orchestrator:** + * The main `TradingPipeline` class in `trading_pipeline.py` remains. + * Responsibilities: + * Load configuration (`__init__`). + * Initialize core components (`DataLoader`, `FeatureEngineer`, etc.). + * Manage overall state (instance variables like `df_raw`, `gru_model`, etc.). + * Run the main execution flow (`execute`), including the walk-forward loop. + * Call functions from new stage-specific modules. + +2. **Create Stage-Specific Modules:** + * Create a new sub-directory: `src/pipeline_stages`. + * Move stage-specific logic from `TradingPipeline` methods into functions within these modules. + * Proposed Modules: + * `src/pipeline_stages/data_processing.py`: Loading, feature engineering, label generation, alignment, splitting. + * `src/pipeline_stages/feature_processing.py`: Scaling, feature selection, pruning. + * `src/pipeline_stages/sequence_creation.py`: GRU sequence creation. + * `src/pipeline_stages/modelling.py`: GRU/SAC training/loading, calibration, SAC aggregation. + * `src/pipeline_stages/evaluation.py`: Baseline checks, backtesting, results saving, metric aggregation, final decision. + +3. **Refactor `TradingPipeline` Methods:** + * Simplify existing methods in `TradingPipeline` (e.g., `load_and_preprocess_data`, `engineer_features`). + * These methods will now primarily: + * Import the corresponding function from the `pipeline_stages` module. + * Call the imported function, passing necessary data and components (`config`, `data_loader`, `io`, state variables). + * Receive results and update the `TradingPipeline` instance's state. + +4. **Data Flow:** + * Emphasize explicit data passing via function arguments and return values between stages. + * The `TradingPipeline.execute` method orchestrates this flow. + * State required across multiple stages/folds remains as `TradingPipeline` instance attributes. + +5. **Dependencies:** + * Pass `config`, `io`, and component instances (`DataLoader`, `FeatureEngineer`, etc.) as arguments to the stage functions that need them. + +Benefits: + +* **Readability:** `trading_pipeline.py` becomes a clearer orchestrator. +* **Maintainability:** Easier to isolate and modify specific stages. +* **Testability:** Stage functions are potentially easier to unit test. +* **Reduced Risk:** Focuses on moving logic, minimizing E2E breakage compared to a full rewrite. + +Implementation Steps: + +1. Create the `src/pipeline_stages` directory and module files. +2. Incrementally move logic for each stage into the corresponding module's functions. +3. Update `TradingPipeline` methods to import and call these new functions. +4. Adjust imports and function signatures as needed. +5. Proceed stage by stage, verifying structure and data flow. \ No newline at end of file diff --git a/prompts/validation_gates.txt b/prompts/validation_gates.txt new file mode 100644 index 00000000..1c29c45a --- /dev/null +++ b/prompts/validation_gates.txt @@ -0,0 +1,55 @@ +Below is a consolidated list of every “validation gate” we enforce in the pipeline—split by the **GRU** (prediction) stage and the **SAC** (position‑sizing) stage. Each check either **aborts** the run (hard‐fail) or **warns** you that a non‑critical gate didn’t clear. + +--- + +## GRU‑stage Validation Gates + +| Gate # | Check | Data used | Threshold | Action on Fail | +|--------|----------------------------------------------------------------|-----------------------|------------------|--------------------| +| **G1** | **Raw binary LR** (Internal split) | train 80/20 split | CI LB ≥ `binary_ci_lb` (0.52) | **Abort** | +| **G2** | **Raw binary RF** (Internal split, optional) | train 80/20 split | CI LB ≥ `binary_rf_ci_lb` (0.54) | **Abort** (if enabled) | +| **G3** | **Raw ternary LR** (Internal split, if `use_ternary`) | train 80/20 split | CI LB ≥ `ternary_ci_lb` (0.40) | **Warn** | +| **G4** | **Raw ternary RF** (Internal split, optional) | train 80/20 split | CI LB ≥ `ternary_rf_ci_lb` (0.42) | **Warn** | +| **G5** | **Forward‑fold binary LR** (True OOS, test fold t+1) | fold’s test set | CI LB ≥ `forward_ci_lb` (0.52) | **Abort** | +| **G6** | **Feature‑selection re‑baseline** (post‑prune binary LR) | pruned train 80/20 | CI LB ≥ `binary_ci_lb` (0.52) | **Abort** | +| **G7** | **Calibration check** (edge‑filtered p_cal on val split) | val split | CI LB ≥ `calibration_ci_lb` (0.55) | **Abort** | + +> **Notes:** +> • G1 catches “no predictive signal” cheaply. +> • G5 ensures it actually *generalises* forward. +> • G7 makes sure your calibrated probabilities have real edge before SAC ever sees them. + +--- + +## SAC‑stage Validation Gates + +| Gate # | Check | Data used | Threshold | Action on Fail | +|---------|-------------------------------------------------------------|-----------------------|---------------------|----------------------| +| **G8** | **Edge‑filtered binary LR** | val split probabilities | CI LB ≥ `edge_binary_ci_lb` (0.60) | **Abort** | +| **G9** | **Edge‑filtered binary RF** | val split probabilities | CI LB ≥ `edge_binary_rf_ci_lb` (0.62) | **Abort** | +| **G10** | **Edge‑filtered ternary LR** (if `use_ternary`) | val split p_cat[:,2]–p_cat[:,0] | CI LB ≥ `edge_ternary_ci_lb` (0.57) | **Warn** | +| **G11** | **Edge‑filtered ternary RF** (if `use_ternary`) | val split p_cat[:,2]–p_cat[:,0] | CI LB ≥ `edge_ternary_rf_ci_lb` (0.58) | **Warn** | +| **G12** | **Backtest performance** (Sharpe / Max DD on test fold) | aggregated test folds | Sharpe ≥ `backtest.sharpe_lb` (1.2)
Max DD ≤ `backtest.max_dd_ub` (15 %) | **Abort** if violated| + +> **Notes:** +> • G8–G9 gate the **high‐confidence edge** you feed into SAC. If they fail, SAC will only ever go all‑in/flat, so we abort. +> • G10–G11 warn you if your flat/no‑move sizing is shaky—SAC can still run, but you’ll get a console warning suggesting you tweak your flat thresholds or features. +> • G12 validates final **live‐like** performance; if you can’t hit the Sharpe/Max DD targets on the unseen test folds, the entire run is considered a no‑go. + +--- + +### What to do on each pattern + +1. **Any GRU‑abort gate (G1, G2, G5, G6, G7) fails** → + **Stop** before training. Improve features, horizon, calibration settings, or prune strategy. + +2. **GRU passes but SAC‑binary edge gates (G8/G9) fail** → + **Stop** before SAC training. Means your probabilities have no reliable high‑confidence edge—tweak calibration threshold or retrain GRU. + +3. **GRU & SAC‑binary gates pass, but SAC‑ternary edge gates (G10/G11) warn** → + **Proceed** with a warning: consider adding flat‑specific features or raising the `edge_threshold`. + +4. **All gates pass** → + Full pipeline runs to completion: GRU training, SAC training, backtest, resulting in models, logs, and performance reports. + +By strictly enforcing these gates, you ensure every GRU and SAC model you train has demonstrable, forward‑tested edge—maximizing your chances of hitting that 65 % directional target in live trading. \ No newline at end of file diff --git a/prompts/walk_forward.txt b/prompts/walk_forward.txt new file mode 100644 index 00000000..f9d1c336 --- /dev/null +++ b/prompts/walk_forward.txt @@ -0,0 +1,124 @@ + 1. Nested Cross-Validation (for GRU Hyperparameter Tuning) +Goal: To tune GRU hyperparameters (like gru_units, learning_rate, etc.) robustly for each main walk-forward fold, using only the training data allocated to that fold. This prevents hyperparameters from being influenced by data that will later appear in the fold's validation or test set. +Current Implementation: The hyperparameter_tuning.gru.sweep_enabled flag exists, but the tuning logic isn't currently nested within the fold processing loop in train_or_load_gru_fold. +Implementation Strategy: +Modify train_or_load_gru_fold (in gru_sac_predictor/src/pipeline_stages/modelling.py): This is the function responsible for training or loading the GRU for a specific outer walk-forward fold. +Check sweep_enabled: Inside this function, right before the actual GRU training would normally occur (i.e., if config['gru']['train_gru'] is true and a model isn't being loaded), check if config['hyperparameter_tuning']['gru']['sweep_enabled'] is also true. +Inner CV Loop: If sweep is enabled: +Data: Use the X_train_seq and y_train_seq_dict passed into this function (these represent the training data for the current outer fold). +Inner Splits: Use a time-series-appropriate splitter (like sklearn.model_selection.TimeSeriesSplit) on the sequence indices (train_indices_new if returned, otherwise derive from X_train_seq) to create, say, 3 or 5 inner train/validation splits within the outer fold's training data. +Optuna Study: Create a new Optuna study (or similar hyperparameter optimization framework) specific to this outer fold. +Objective Function: Define an Optuna objective function that takes a trial object: +It suggests hyperparameters based on config['hyperparameter_tuning']['gru']['search_space']. +It iterates through the inner CV splits. For each inner split: +Instantiate a temporary GRUModelHandler (or just the model) with the trial's hyperparameters. +Train the model on the inner training data slice. +Evaluate it on the inner validation data slice (e.g., calculate val_loss). +Return the average performance (e.g., average val_loss) across the inner splits. +Run Study: Execute study.optimize with the objective function and n_trials from the config. +Best Parameters: Retrieve the study.best_params after optimization. +Final Fold Training: Instantiate the GRUModelHandler (gru_handler passed into the function) or build the GRU model using these best_params. Train this single, final model for the outer fold on the entire X_train_seq and y_train_seq_dict. +Return: Return this optimally tuned GRU model and handler for the outer fold to proceed. +Configuration: +The existing hyperparameter_tuning.gru section is mostly sufficient. +You might add a key like inner_cv_splits: 3 to control the inner loop. +Considerations: This significantly increases computation time, as n_trials * inner_cv_splits models are trained per outer fold. + +2. Gap and Regime-Aware Folds +Here’s a minimal “wrapper” you can drop around your existing `_generate_walk_forward_folds` to get both gap‑aware **and** regime‑aware filtering, without rewriting your core logic: + +```python +def generate_filtered_folds(df, config): + # 1) Tag regimes once, right after loading & feature‐engineering the full dataset + if config['walk_forward']['regime']['enabled']: + df = add_regime_tags( + df, + indicator=config['walk_forward']['regime']['indicator'], + window=config['walk_forward']['regime']['indicator_params']['window'], + quantiles=config['walk_forward']['regime']['quantiles'] + ) + min_pct = config['walk_forward']['regime']['min_regime_representation_pct'] + + # 2) Split into contiguous chunks on data gaps + chunks = split_into_contiguous_chunks( + df, + config['walk_forward']['gap_threshold_minutes'] + ) + + # 3) For each chunk, run your normal fold‐generator, then filter by regime + for chunk_start, chunk_end in chunks: + df_chunk = df.loc[chunk_start:chunk_end] + # skip tiny chunks + if (chunk_end - chunk_start).days < config['walk_forward'].get('min_chunk_days', 1): + continue + + # your existing generator (rolling or block)— + # it yields tuples of (train_start, train_end, val_start, val_end, test_start, test_end) + for (t0, t1, v0, v1, e0, e1) in self._original_fold_generator(df_chunk, config): + # if regime gating is off, just yield + if not config['walk_forward']['regime']['enabled']: + yield (t0, t1, v0, v1, e0, e1) + continue + + # 4) Check regime balance in each period + periods = { + 'train': df_chunk.loc[t0:t1], + 'val': df_chunk.loc[v0:v1], + 'test': df_chunk.loc[e0:e1], + } + bad = False + for name, subdf in periods.items(): + counts = subdf['regime_tag'].value_counts(normalize=True) * 100 + # ensure every regime appears ≥ min_pct + for regime in sorted(df['regime_tag'].unique()): + pct = counts.get(regime, 0.0) + if pct < min_pct: + bad = True + break + if bad: + break + + if bad: + # you can log which period/regime failed here + continue + # otherwise it’s a valid fold + yield (t0, t1, v0, v1, e0, e1) +``` + +### Explanation of the steps + +1. **Regime Tagging** + - Run once, up‑front: compute your volatility or trend indicator over the full series, cut it into quantile bins, and assign each row a `regime_tag` of 0/1/2. + +2. **Gap Partitioning** + - Split the DataFrame into contiguous “chunks” wherever index gaps exceed your `gap_threshold_minutes`. + - This avoids forcing folds that straddle a hole in the data. + +3. **Fold Generation (Unchanged)** + - Call your existing `_generate_walk_forward_folds` (rolling or block) on each contiguous chunk. + +4. **Regime‐Balance Filter** + - For each candidate fold, slice out the train/val/test segments, compute the fraction of each regime tag, and **skip** any fold where any regime appears below your `min_regime_representation_pct`. + +--- + +#### Configuration sketch + +```yaml +walk_forward: + # existing fields… + gap_threshold_minutes: 5 + regime: + enabled: true + indicator: volatility + indicator_params: + window: 20 + quantiles: [0.33, 0.66] + min_regime_representation_pct: 10 +``` + +With this wrapper, you get: + +- **Automatic split** at data outages > 5 min +- **Dynamic skip** of any time‐slice folds that would be blind to a market regime (e.g. all high‑vol or all low‑vol) +- **No changes** to your core split logic—just filter its outputs. \ No newline at end of file diff --git a/test_config.yaml b/test_config.yaml new file mode 100644 index 00000000..8132bdc9 --- /dev/null +++ b/test_config.yaml @@ -0,0 +1,364 @@ +# GRU-SAC Predictor v3 Configuration File +# This file parameterizes all major components of the pipeline. + +pipeline: + description: "Configuration for the GRU-SAC trading predictor pipeline." + # Define stages to run, primarily for debugging/selective execution. + # stages_to_run: ["data", "features", "gru", "sac", "backtest", "aggregate"] # Example: Run all + +# --- Data Loading and Initial Processing --- +data: + ticker: "XRP-USDT" # Ticker symbol (adjust based on DataLoader capabilities) + exchange: "bnbspot" # Exchange name (adjust based on DataLoader) + interval: "1min" # Data interval (e.g., '1min', '5min', '1h') + start_date: "2024-06-01" # Start date for data loading (YYYY-MM-DD) + end_date: "2025-04-01" # End date for data loading (YYYY-MM-DD) - 10 day range + db_dir: '../data/crypto_market_data' # to database directory (relative to project root) + bar_frequency: "1T" # Added based on instructions + missing: # Added missing data section + strategy: "drop" # drop | neutral | ffill | interpolate + max_gap: 60 # max consecutive missing bars allowed + interpolate: + method: "linear" + limit: 10 + + volatility_sampling: # Optional volatility-based downsampling + enabled: False + window: 30 # Window for volatility calculation (e.g., 30 minutes) + quantile: 0.5 # Quantile threshold for sampling (0.0 to 1.0) + + # Optional: Add parameters for data cleaning if needed + # e.g., max_nan_fill_gap: 5 + +# --- Feature Engineering --- +features: + # Parameters for FeatureEngineer.add_base_features + atr_window: 14 + rsi_window: 14 + adx_window: 14 + macd_fast: 12 + macd_slow: 26 + macd_signal: 9 + # Add parameters for other indicators (e.g., Chaikin, SVI, Volatility) if configurable + # chaikin_ad_window: 10 + # svi_window: 10 + # volatility_window: 14 # e.g., for a rolling std dev feature + # Parameters for feature selection (used by FeatureEngineer.select_features) + # These might include method (e.g., 'correlation', 'mutual_info', 'lgbm'), thresholds, etc. + selection_method: "correlation" # Example + correlation_threshold: 0.02 # Example threshold for correlation-based selection + min_features_after_selection: 10 # Minimum number of features to keep + +# --- Data Splitting (Walk-Forward or Single Split) --- +walk_forward: + enabled: true + n_folds: 5 # Example: Divide data into 5 blocks + split_ratios: # Ratios applied WITHIN each block + train: 0.80 + validation: 0.1 + # test: 0.15 (Implicit) + initial_offset_days: 0 + step_days: 1 # Step forward by test period length + # train_period_days, validation_period_days, test_period_days, step_days are ignored if n_folds > 1ptional gap between periods (e.g., train-val, val-test) + +# --- GRU Model Configuration --- +gru: + # Label Definition + train_gru: true + use_ternary: True # Use ternary (Up/Flat/Down) labels? If False, uses binary (Up/Down). + prediction_horizon: 10 # Lookahead period for target returns/labels (in units of 'data.interval') + flat_sigma_multiplier: 0.3 # 'k' factor for ternary flat label threshold (eps = k * rolling_std(fwd_ret)) + label_smoothing: 0.0 # Alpha for binary label smoothing (0.0 disables) + drop_imputed_sequences: true # Added based on instructions + + # Model Architecture (V3) - Used by GRUModelHandler.build_gru_model_v3 + gru_units: 128 # Number of units in GRU layer + attention_units: 16 # Number of units in MultiHeadAttention layer (set to 0 to disable) + dropout_rate: 0.05 # Dropout rate for GRU and Attention layers + learning_rate: 1e-2 # Learning rate for Adam optimizer + l2_reg: 1e-4 # L2 regularization factor for Dense layers + + # Loss Function Parameters (V3) - Used by GRUModelHandler.build_gru_model_v3 + focal_gamma: 2.0 # Gamma parameter for categorical focal loss (if use_ternary=True) + focal_label_smoothing: 0.1 # Label smoothing within focal loss calculation + huber_delta: 1.0 # Delta parameter for Huber loss (mu/return prediction) + loss_weight_mu: 0.3 # Weight for the mu/return prediction loss component + loss_weight_dir3: 1.0 # Weight for the direction prediction loss component + + # Training Parameters - Used by GRUModelHandler.train + lookback: 120 # Sequence length (timesteps) for GRU input + epochs: 25 # Maximum number of training epochs + batch_size: 128 # Training batch size + patience: 5 # Early stopping patience (epochs with no improvement in val_loss) + # early_stopping_monitor: "val_loss" # Monitor for early stopping (hardcoded in handler) + # training_shuffle: False # Whether to shuffle training data each epoch (hardcoded False) + + # Loading Control - Used by pipeline_stages.modelling.train_or_load_gru_fold + load_gru_model: + run_id: null # Set to a specific GRU pipeline run ID to load model/scaler from instead of training + fold_num: null # Optional: Specify fold number (e.g., 1, 2...). If null, handler might load best/last fold based on its internal logic. + +# --- Hyperparameter Tuning (Optuna/W&B) --- +hyperparameter_tuning: + gru: + sweep_enabled: False # Master switch to enable Optuna sweep for GRU + # If enabled=True, define sweep parameters here: + study_name: "gru_optimization" + direction: "minimize" # "minimize" val_loss or "maximize" val_accuracy + n_trials: 50 + pruner: "median" # e.g., "median", "hyperband" + sampler: "tpe" # e.g., "tpe", "random" + search_space: + gru_units: { type: "int", low: 32, high: 128, step: 16 } + attention_units: { type: "int", low: 8, high: 64, step: 8 } + dropout_rate: { type: "float", low: 0.05, high: 0.3 } + learning_rate: { type: "loguniform", low: 1e-5, high: 1e-3 } + l2_reg: { type: "loguniform", low: 1e-5, high: 1e-3 } + loss_weight_mu: { type: "float", low: 0.1, high: 0.9 } + batch_size: { type: "categorical", choices: [64, 128, 256] } + +# --- Probability Calibration --- +calibration: + method: vector + optimize_edge_threshold: true + edge_threshold: 0.5 # Initial or fixed threshold if not optimizing + # Rolling calibration settings (if method requires) + rolling_window_size: 250 + rolling_min_samples: 50 + rolling_step: 50 + reliability_plot_bins: 10 # Number of bins for reliability plot + +# --- Soft Actor-Critic (SAC) Agent and Training --- +sac: + imputed_handling: "hold" # Added based on instructions + action_penalty: 0.05 # Added based on instructions + # Agent Hyperparameters - Used by SACTradingAgent.__init__ + gamma: 0.99 # Discount factor + tau: 0.005 # Target network update rate (polyak averaging) + actor_lr: 3e-4 # Learning rate for the actor network + critic_lr: 3e-4 # Learning rate for the critic networks + # Optional: LR Decay for actor/critic (if implemented in agent) + lr_decay_rate: 0.96 + decay_steps: 100000 + # Optional: Ornstein-Uhlenbeck noise parameters (if used) + ou_noise_stddev: 0.2 + alpha: 0.2 # Initial entropy temperature (used if alpha_auto_tune=False) + alpha_auto_tune: True # Enable automatic tuning of entropy temperature alpha + target_entropy: -1.0 # Target entropy for alpha tuning; -action_dim is common default (-1.0 for action_dim=1) + + # Training Loop Parameters - Used by SACTrainer._training_loop + total_training_steps: 100000 # Total steps for the SAC training loop + buffer_capacity: 1000000 # Maximum size of the replay buffer + batch_size: 256 # Batch size for sampling from replay buffer + start_steps: 10000 # Number of initial steps with random actions before training starts + update_after: 1000 # Number of steps to collect before first agent update + update_every: 50 # Perform agent updates every N steps + save_freq: 5000 # Save agent checkpoint every N steps + log_freq: 100 # Log training metrics (losses, Q-values) to TensorBoard every N steps + eval_freq: 5000 # Evaluate agent performance every N steps (requires evaluation logic) + + # Alpha (Entropy Temperature) Annealing - Used by SACTrainer._training_loop + alpha_anneal_start_step: 10000 # Step to start annealing alpha (if auto-tune enabled) + alpha_anneal_end_step: 50000 # Step to finish annealing alpha + initial_alpha: 0.2 # Alpha value before annealing starts + final_alpha: 0.01 # Target alpha value after annealing finishes + + # Prioritized Experience Replay (PER) - Used by SACTrainer / PrioritizedReplayBuffer + use_per: true # Enable PER? If False, uses standard uniform replay buffer. + # PER parameters (used only if use_per=True) + per_alpha: 0.6 # Priority exponent (how much prioritization). 0 = uniform. + per_beta_start: 0.4 # Initial importance sampling exponent (annealed to 1.0) + per_beta_frames: 100000 # Steps over which to anneal beta from beta_start to 1.0 + # Optional PER Alpha annealing (anneals the priority exponent alpha) + per_alpha_anneal_enabled: False + per_alpha_start: 0.6 + per_alpha_end: 0.4 + per_alpha_anneal_steps: 50000 + + # Oracle Seeding (Potentially deprecated/experimental) + oracle_seeding_pct: 0.1 # Percentage of buffer to pre-fill using heuristic policy + + # State Normalization - Used by SACTrainer + use_state_filter: True # Use MeanStdFilter for state normalization? + state_dim_fallback: 5 # Fallback state dim if cannot be inferred (e.g., from loaded agent metadata) + action_dim_fallback: 1 # Fallback action dim if cannot be inferred + + # Loading Control - Used by pipeline_stages.modelling.train_or_load_sac_fold + train_sac: true # Master switch: Train SAC agent? If False, attempts to load based on control flags. + +# --- SAC Agent Aggregation (Post Walk-Forward) --- +sac_aggregation: + enabled: true # Aggregate agents from multiple folds? + method: "average_weights" # Currently only 'average_weights' is supported + +# --- Trading Environment Simulation --- +environment: # Parameters passed to TradingEnv and Backtester + initial_capital: 10000.0 # Starting capital for simulation/backtest + transaction_cost: 0.0005 # Fractional cost per trade (e.g., 0.0005 = 0.05%) + # Reward shaping parameters (used within TradingEnv._calculate_reward) + reward_scale: 100.0 # Multiplier applied to the raw PnL reward + action_penalty_lambda: 0.0 # Penalty factor for action magnitude or changes (0 disables) + # Add other env parameters if needed (e.g., position limits, reward clipping) + +# --- Baseline Model Parameters --- +baselines: # Parameters for specific baseline models + # Parameters for Binary Logistic Regression + logistic_regression: + max_iter: 1000 + solver: "lbfgs" # Note: solver 'liblinear' is needed for L1 selection in FeatureEngineer, but LBFGS is fine for baseline checks + random_state: 42 + val_subset_split_ratio: 0.2 # Internal split ratio used within baseline check for raw CI + val_subset_shuffle: False # Shuffle for internal split? + ci_confidence_level: 0.95 # Confidence level for binomial test CI + # Baseline Calibration Settings (Applied only if baseline_binary.run_logistic_regression is true AND this is enabled) + calibration_enabled: true + calibration_method: isotonic # 'isotonic' or 'sigmoid' + calibration_holdout: 0.2 # % of training data used for calibrator fitting + # random_state is reused for calibration splitting + + # Parameters for Binary RandomForest Classifier + random_forest: + n_estimators: 100 # Number of trees + max_depth: 10 # Maximum depth of trees (None for unlimited) + min_samples_split: 2 # Minimum samples required to split an internal node + min_samples_leaf: 1 # Minimum number of samples required to be at a leaf node + random_state: 42 + n_jobs: -1 # Use all available CPU cores + # Internal split and CI settings (can reuse LogReg values or specify separately) + val_subset_split_ratio: 0.2 + val_subset_shuffle: False + ci_confidence_level: 0.95 + + # Parameters for Multinomial Logistic Regression (Ternary) + multinomial_logistic_regression: + max_iter: 1000 + solver: "lbfgs" # Common choice for multinomial + multi_class: "multinomial" + random_state: 42 + val_subset_split_ratio: 0.2 + val_subset_shuffle: False + ci_confidence_level: 0.95 + + # Parameters for Ternary RandomForest Classifier + ternary_random_forest: + n_estimators: 100 + max_depth: 10 + min_samples_split: 2 + min_samples_leaf: 1 + random_state: 42 + n_jobs: -1 + val_subset_split_ratio: 0.2 + val_subset_shuffle: False + ci_confidence_level: 0.95 + +# --- Pipeline Validation Gates --- +validation_gates: # Thresholds checked at different stages to potentially halt the pipeline + run_baseline_check: true # Master switch for running *any* applicable baseline check + + # Settings for Binary Baselines + baseline_binary: + run_logistic_regression: true # Enable Binary LogReg check? + run_random_forest: true # Enable Binary RF check? + # --- Thresholds for Binary Gates --- # + ci_threshold: 0.51 # G1: Raw LogReg CI LB threshold + binary_rf_ci_lb: 0.54 # G2: Raw RF CI LB threshold (Mandatory if RF enabled) + # --- Edge Check Settings for Binary --- # + run_logistic_regression_edge_check: true # Enable edge check for LogReg? + run_random_forest_edge_check: true # G9: Enable edge check for RF? (Mandatory if enabled) + edge_threshold_value: 0.1 # Edge value |P(up)-P(down)| for filtering samples + edge_ci_threshold_gate: 0.56 # G8: LogReg Edge CI LB threshold (Mandatory if edge check enabled) + edge_binary_rf_ci_lb: 0.60 # G9: RF Edge CI LB threshold (Mandatory if edge check enabled) + # ci_confidence_level is taken from the respective model's config in 'baselines' section + + # Settings for Ternary Baselines (Monitored, not mandatory gates by default) + baseline_ternary: + run_logistic_regression: true # Enable Ternary LogReg check? + run_random_forest: true # Enable Ternary RF check? + # --- Thresholds for Ternary Gates (Monitoring) --- # + ternary_ci_lb: 0.40 # G3: Raw Ternary LogReg CI LB threshold (vs 1/3) + ternary_rf_ci_lb: 0.42 # G4: Raw Ternary RF CI LB threshold (vs 1/3) + # --- Edge Check Settings for Ternary --- # + run_logistic_regression_edge_check: true # G10: Enable edge check for Ternary LogReg? + run_random_forest_edge_check: true # G11: Enable edge check for Ternary RF? + edge_threshold_value: 0.1 # Edge value |P(up)-P(down)| for filtering samples + edge_ternary_ci_lb: 0.57 # G10: Ternary LogReg Edge CI LB threshold + edge_ternary_rf_ci_lb: 0.58 # G11: Ternary RF Edge CI LB threshold + # ci_confidence_level is taken from the respective model's config in 'baselines' section + + # Other existing gate sections (Merge with the above structure) + post_pruning_baseline: # Example G6 - check baseline on pruned features + enabled: true + # ci_threshold: 0.52 # Optional: Override threshold, else defaults to baseline_binary.ci_threshold + + forward_baseline: # Example - check LR performance on immediate future + enabled: false # Typically disabled unless specifically testing + ci_threshold: 0.52 + + calibration_check: # Example G7 - check CI LB on calibrated edge accuracy + enabled: true + ci_threshold: 0.55 # Default 0.55 if not specified + + gru_performance: # Checks on GRU validation predictions (after calibration) + enabled: True + min_edge_accuracy: 0.60 # Minimum accuracy using the optimized/configured edge threshold + max_brier_score: 0.24 # Maximum acceptable Brier score + + backtest: # Master switch for all backtest performance gates + enabled: True + backtest_performance: # Specific performance checks on the backtest results + enabled: True # Enable/disable Sharpe and Max DD checks specifically + min_sharpe_ratio: 1.2 # Minimum acceptable annualized Sharpe ratio + max_drawdown_pct: 15.0 # Maximum acceptable drawdown percentage (positive value) + + final_release: # Gates checked on aggregated metrics across all folds + min_successful_folds_pct: 0.75 # Minimum % of folds that must succeed + median_sharpe_threshold: 1.3 # Minimum median Sharpe across successful folds + # max_drawdown_max_threshold: 20.0 # Optional: Max Drawdown allowed in *any* fold + +# --- Pipeline Control Flags --- +control: + generate_plots: True # Generate and save plots (learning curves, backtest summary, etc.)? + + # Loading specific models instead of training/running stages + # Note: train_gru and train_sac flags override these if both are set + # GRU Loading: see gru.load_gru_model section + # SAC Loading: Used if sac.train_sac=False + sac_load_run_id: null # Specify SAC Training Run ID (e.g., "sac_train_...") to load for backtesting + sac_load_step: 'final' # 'final' or specific step number checkpoint to load + + # Resuming SAC Training (Loads agent and potentially buffer state to continue training) + sac_resume_run_id: null # Specify SAC Training Run ID to resume from + sac_resume_step: 'final' # 'final' or step number checkpoint to resume from + +# --- Logging Configuration --- +logging: + console_level: "INFO" # Level for console output: DEBUG, INFO, WARNING, ERROR, CRITICAL + file_level: "DEBUG" # Level for file output: DEBUG, INFO, WARNING, ERROR, CRITICAL + log_to_file: True # Enable logging to a file? + # Log file path determined by IOManager: logs//pipeline.log + log_format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s' # Format string + log_date_format: '%Y-%m-%d %H:%M:%S' # Date format for logs + # Rotating File Handler settings (if log_to_file=True) + log_file_max_bytes: 10485760 # Max size in bytes (e.g., 10MB) before rotation + log_file_backup_count: 5 # Number of backup log files to keep + +# --- Output Artifacts Configuration --- +output: + base_dirs: # Base directories (relative to project root or absolute) + results: "results" + models: "models" + logs: "logs" + # Figure generation settings + figure_dpi: 150 # DPI for saved figures + figure_size: [16, 9] # Default figure size (width, height in inches) + figure_footer: "© GRU-SAC v3" # Footer text added to plots + plot_style: "seaborn-v0_8-darkgrid" # Matplotlib style sheet to use + # Plot-specific settings + reward_plot_smoothing_alpha: 0.2 # EMA alpha for SAC reward plot smoothing + # reliability_plot_bins: 10 # Defined under calibration section + + # IOManager settings + dataframe_save_format: "parquet_if_large" # "csv", "parquet", "parquet_if_large" + dataframe_max_csv_mb: 100 # Threshold (MB) for using Parquet if format is parquet_if_large + +# ... existing code ... \ No newline at end of file diff --git a/training.log b/training.log new file mode 100644 index 00000000..854d3e7b --- /dev/null +++ b/training.log @@ -0,0 +1,2 @@ +nohup: ignoring input +python: can't open file '/home/yasha/develop/gru_sac_predictor/gru_sac_predictor/src/main.py': [Errno 2] No such file or directory