diff --git a/gru_sac_predictor/README.md b/gru_sac_predictor/README.md index c78d3897..98262582 100644 --- a/gru_sac_predictor/README.md +++ b/gru_sac_predictor/README.md @@ -1,6 +1,6 @@ -# GRU + SAC Crypto Trading System (v3 - Consolidated & Enhanced) +# GRU + SAC Crypto Trading System (v3 - Refactored & Enhanced) -This project implements a cryptocurrency trading system using a GRU model for market prediction and a Soft Actor-Critic (SAC) agent for position sizing. This version reflects significant refactoring, feature additions, and the consolidation of model logic. +This project implements a cryptocurrency trading system using a GRU model for market prediction and a Soft Actor-Critic (SAC) agent for position sizing. This version reflects significant refactoring for modularity, feature additions, and the consolidation of model logic. The core idea is to decouple prediction and action: 1. A **GRU model** (v2 or v3 architecture, selected via config) forecasts future log-returns (μ̂) and class probabilities (binary p(up) or ternary p(down, flat, up)). @@ -11,32 +11,71 @@ This approach aims for a robust system where the RL agent focuses solely on risk ## Key Features & Enhancements -* **Consolidated GRU Logic:** Both v2 and v3 GRU model architectures are now implemented and managed within `src/gru_model_handler.py`. -* **Walk-Forward Validation:** Replaces static train/val/test splits with a robust walk-forward validation framework (`TradingPipeline.execute`, `_generate_walk_forward_folds`) for more realistic performance estimation. -* **Hyperparameter Optimization (Optuna):** Integrated Optuna sweep for GRU hyperparameters (`src/gru_hyper_tuner.py`) with restricted search space, configurable objective (`edge_acc - brier`), and Keras callback for efficient pruning based on `val_loss`. +* **Modular Pipeline Structure:** Pipeline logic is refactored into stage-specific functions within `src/pipeline_stages/` for improved readability, maintainability, and testability (see Project Structure). +* **Consolidated GRU Logic:** Both v2 and v3 GRU model architectures are implemented and managed within `src/gru_model_handler.py`. +* **Walk-Forward Validation:** Robust walk-forward validation framework (`TradingPipeline.execute`, `_generate_walk_forward_folds`) for realistic performance estimation. +* **Hyperparameter Optimization (Optuna):** Integrated Optuna sweep for GRU hyperparameters (`src/gru_hyper_tuner.py`) with configurable search space, objective, and pruning. * **Advanced Calibration:** * Supports Temperature and Vector Scaling (`calibration.method`) with optional L2 regularization (`calibration.l2_lambda`). - * Optimizes edge threshold via Youden's J on validation data (`calibration.optimize_edge_threshold`). Saved per fold (`optimized_edge_threshold_fold_N.txt`) and used consistently. - * **Rolling Calibration (Experimental):** Implemented within `Backtester` to refit the calibrator periodically during the backtest (`calibration.rolling_enabled`, `recalibrate_every_n`, `recalibration_window`). Uses the **static** calibration from training time if SAC training is active to prevent lookahead. + * Optimizes edge threshold via Youden's J on validation data (`calibration.optimize_edge_threshold`). + * **Rolling Calibration (Experimental):** Implemented within `Backtester` to refit the calibrator periodically during the backtest (`calibration.rolling_enabled`, `recalibrate_every_n`, `recalibration_window`). 
* **Coverage Alarm (ECE-based):** Optional alarm triggers early recalibration if **Expected Calibration Error (ECE)** exceeds a threshold (`calibration.coverage_alarm_enabled`, `ece_recalibration_threshold`). -* **Prioritized Experience Replay (PER):** Implemented for SAC training (`sac.use_per`) with **TD-error clipping** and **alpha annealing** (linear decay). Logs TD error distribution statistics. -* **SAC Enhancements:** Reward scaling, state normalization (`MeanStdFilter`), configurable **action penalty** (default: `0.01 / transaction_cost`), oracle seeding with **Importance Sampling (IS) weight decay** (`per_seed_decay_steps`). -* **Refined Validation Gates:** Configurable thresholds (`validation_gates`) for: - * **Baseline Gate:** Checks Logistic Regression CI on **raw/engineered** training features before scaling. - * **GRU Gate:** Checks Edge Acc CI and Brier score on validation set after calibration, using the fold's **determined edge threshold**. - * **Final Release Decision:** Checks aggregated metrics (e.g., **median Sharpe ≥ 1.3**, **≥ 75% successful folds**) across all successful folds. Backtest gate failures *per fold* are logged but do not halt the entire pipeline. -* **Micro-structure Features:** Added bar-level features (`FeatureEngineer._add_microstructure_features`) with **NaN guards** for robustness. -* **Leakage Guard:** Feature calculations use `shift(1)`. Selection occurs on **raw/engineered features** and includes correlation check (`corr(ret+h, feat_t-1)`) against future returns. Minimal whitelist applied before VIF. -* **Configuration:** Centralized and expanded `config.yaml` with annotations for mutually exclusive options (e.g., walk-forward vs static split). +* **Prioritized Experience Replay (PER):** Implemented for SAC training (`sac.use_per`) with TD-error clipping and alpha annealing. +* **SAC Enhancements:** Reward scaling, state normalization (`MeanStdFilter`), configurable action penalty, oracle seeding with Importance Sampling (IS) weight decay. +* **Refined Validation Gates:** Configurable thresholds (`validation_gates`) for baseline checks, GRU validation, fold backtest performance, and final release decisions. +* **Micro-structure Features:** Added bar-level features (`FeatureEngineer._add_microstructure_features`) with NaN guards. +* **Leakage Guard:** Feature calculations use `shift(1)`. Selection includes correlation check against future returns. Minimal whitelist applied before VIF. +* **Configuration:** Centralized and expanded `config.yaml`. * **Output Management:** Standardized output structure via `IOManager` and `LoggerSetup`. -* **SAC Agent Aggregation:** Optional post-processing step to average weights from agents trained across successful folds (`TradingPipeline.aggregate_sac_agents`, `sac_aggregation.enabled`). +* **SAC Agent Aggregation:** Optional post-processing step to average weights from agents trained across successful folds. + +## Data Quality + +This system includes mechanisms to handle potential missing data points (bars) in the input time series. This is crucial for maintaining data integrity and preventing errors during feature engineering and model training. + +**Handling Missing Bars:** + +* **Detection:** The pipeline automatically detects missing bars based on the expected `data.bar_frequency` (e.g., "1T" for 1 minute) after initial data loading. +* **Reporting:** A warning is logged detailing the total number of missing bars found and the length of the longest consecutive gap. 
A summary report (`missing_bars_summary.json`) is saved in the run's results directory. +* **Filling Strategies:** Several strategies are available, configured via `data.missing.strategy`: + * `"drop"`: No filling is performed. Missing bars remain gaps or NaNs. (Use with caution). + * `"neutral"`: Forward-fills the 'close' price, sets 'open', 'high', 'low' equal to the filled 'close', and sets 'volume' to 0 for imputed bars. + * `"ffill"`: Forward-fills all OHLCV columns, then back-fills any remaining NaNs at the beginning. + * `"interpolate"`: Interpolates missing values using the method specified in `data.missing.interpolate.method` (e.g., 'linear') up to a `limit` defined in `data.missing.interpolate.limit`. +* **Imputed Flag:** After filling, a boolean column `bar_imputed` is added to the DataFrame, marking rows that were originally missing. +* **Max Gap Check:** The pipeline will raise an error if the longest detected consecutive gap exceeds `data.missing.max_gap`. + +**Impact on Downstream Components:** + +* **Feature Engineering:** Features are calculated on the potentially gap-filled data. +* **Sequence Creation:** Sequences containing imputed bars can be optionally dropped before GRU training, controlled by `gru.drop_imputed_sequences`. +* **GRU Model:** The `bar_imputed` flag is included as a feature input to the GRU model, allowing it to potentially learn patterns related to imputed data. +* **SAC Environment:** The `TradingEnv` is aware of imputed bars. The behavior during an imputed step is controlled by `sac.imputed_handling`: + * `"skip"`: The environment skips the step, no action is taken, no reward is given, and the transition is not added to the replay buffer. + * `"hold"`: The agent's action is overridden to maintain its current position. The step proceeds normally otherwise (reward calculated based on held position). + * `"penalty"`: The agent's chosen action is taken, but a penalty reward (based on `sac.action_penalty`) is applied instead of the normal PnL reward. + +**Recommended Defaults:** + +Using `"neutral"` or `"ffill"` for `strategy` is generally recommended for continuous time series. `max_gap` should be set to a reasonably small number (e.g., 5-10) to avoid filling excessively long gaps with potentially inaccurate data. For the SAC environment, `"hold"` or `"skip"` are common choices, depending on whether you want the agent to explicitly learn from imputed steps (or lack thereof). ## System Design & Workflow The system is orchestrated by `run.py`, which sets up logging and I/O via `LoggerSetup` and `IOManager`, then instantiates and executes the `TradingPipeline` class (`src/trading_pipeline.py`). The pipeline follows a sequence of steps, potentially looped for Walk-Forward validation. +### Pipeline Stages (Refactored) + +The core logic for each step in the pipeline has been moved into dedicated functions within the `src/pipeline_stages/` directory. The `TradingPipeline` class now acts primarily as an orchestrator, calling these stage functions in sequence and managing the overall state and data flow. + +* **`src/pipeline_stages/data_processing.py`**: Handles loading, initial preprocessing, feature engineering, labeling, and data splitting logic for each fold. +* **`src/pipeline_stages/feature_processing.py`**: Manages feature scaling, selection (L1+VIF), and pruning based on the selected whitelist. +* **`src/pipeline_stages/sequence_creation.py`**: Creates input sequences suitable for the GRU model from the processed feature data. 
+* **`src/pipeline_stages/modelling.py`**: Contains functions for training/loading the GRU model (including hyperparameter tuning), calibrating probabilities (Temperature/Vector Scaling, edge threshold optimization), training/loading the SAC agent, and aggregating SAC agents. +* **`src/pipeline_stages/evaluation.py`**: Includes functions for running baseline checks (Logistic Regression), performing GRU validation checks (Edge Accuracy, Brier Score), and executing the main backtest simulation (instantiating and running the `Backtester`). + +### Workflow Diagram + ```mermaid -%%{init: {'themeVariables': { 'fontSize': '26px' }}}%% graph TD A[run.py: Init Logger/IOManager/Config] --> B(TradingPipeline); @@ -47,21 +86,21 @@ graph TD D --> F[Select Fold Data]; E --> F; - subgraph Fold Processing [Fold Processing] + subgraph Fold Processing [Fold Processing - Calls Stage Functions] direction TB - F --> G[Engineer Features]; - G --> H[Define Labels & Align]; - H --> I[Split Fold Data]; - I --> J1[Baseline Check]; - J1 -- Pass --> L[Select Features]; - L --> K[Scale Features]; - K --> M[Prune Scaled Features]; - M --> N[Create Sequences]; - N --> O[Train/Load GRU]; - O --> P[Calibrate Probabilities]; - P --> R[GRU Validation Gate]; - R -- Pass --> S[Train/Load SAC Agent]; - S --> T[Run Backtest]; + F --> G[data_processing: Engineer Features]; + G --> H[data_processing: Define Labels & Align]; + H --> I[data_processing: Split Fold Data]; + I --> J1[evaluation: Baseline Check]; + J1 -- Pass --> L[feature_processing: Select Features]; + L --> K[feature_processing: Scale Features]; + K --> M[feature_processing: Prune Scaled Features]; + M --> N[sequence_creation: Create Sequences]; + N --> O[modelling: Train/Load GRU]; + O --> P[modelling: Calibrate Probabilities]; + P --> R[evaluation: GRU Validation Gate]; + R -- Pass --> S[modelling: Train/Load SAC Agent]; + S --> T[evaluation: Run Backtest]; T --> U[Record Fold Results]; end @@ -73,172 +112,144 @@ graph TD B --> Walk-ForwardLoop; X1 --> Y[Aggregate Fold Metrics]; - Y --> Z[Aggregate SAC Agents]; + Y --> Z[modelling: Aggregate SAC Agents]; Z --> Z1[Final Release Decision]; Z1 --> Z_End([End Pipeline Run]); ``` -*Diagram outlines the consolidated v3 pipeline flow after `revisions.txt` modifications.* +*Diagram outlines the consolidated v3 pipeline flow, highlighting calls to stage functions.* ### Detailed Steps (Walk-Forward Enabled) -1. **Initialization (`run.py`):** - * Parses args (`--config`, etc.). - * Loads `config.yaml`. - * Initializes `IOManager`, `LoggerSetup`. - * Instantiates `TradingPipeline`. -2. **Data Loading (`TradingPipeline.load_and_preprocess_data`):** - * Loads the *entire* raw dataset specified in the config. -3. **Fold Generation (`TradingPipeline._generate_walk_forward_folds`):** - * Based on `walk_forward` config (train/val/test/step days), yields date ranges for each fold. -4. **Fold Loop (`TradingPipeline.execute`):** Iterates through generated folds. - * **Select Fold Data:** Extracts raw data corresponding to the current fold's (Train+Val+Test) date range. - * **Feature Engineering (`engineer_features`):** Computes base, TA, and micro-structure features (with NaN guards) on the fold's raw data (using `shift(1)` for time-dependent features). - * **Labeling (`define_labels_and_align`):** Calculates forward returns and target labels (binary/ternary) for the fold's engineered data. - * **Split Fold Data (`split_data`):** Splits the fold's labeled data into `train`, `val`, and `test` sets based on the fold's date ranges. 
Stores results like `self.X_train_raw`, `self.y_val`, `self.df_test_original`. - * **Baseline Gate (`run_baseline_checks`):** Trains/validates Logistic Regression on fold's ***raw/engineered*** training features. Exits fold if CI lower bound < config threshold (`validation_gates.baseline`). Saves `baseline_report_fold_N.txt`. - * **Select Features (`select_and_prune_features` - Selection Part):** Performs leakage check (`corr(ret+h, feat_t-1)`) and L1 + VIF selection (applying minimal whitelist *before* VIF) on fold's *raw/engineered* training features. Saves `final_whitelist_fold_N.json`. - * **Scale Features (`scale_features`):** Fits `StandardScaler` on fold's *raw* training features (numeric only). Scales train, val, test features (`X_train_scaled`, etc.). Saves `feature_scaler_fold_N.joblib`. - * **Prune Features (`select_and_prune_features` - Pruning Part):** Prunes the *scaled* data splits (`X_train_scaled` -> `X_train_pruned`) using the `final_whitelist` determined earlier. - * **Create Sequences (`create_sequences`):** Converts the fold's *pruned, scaled* train/val/test sets into sequences (`X_train_seq`, etc.). - * **Train/Load GRU (`train_or_load_gru`):** - * If `sweep_enabled`, runs `GRUHyperTuner` (with updated objective, restricted search space, Keras callback for pruning on `val_loss`, logging objective components) to find best hyperparameters using fold's train/val sequences. Trains final model with best params. Saves best params JSON and Optuna plots. - * If not sweeping, trains/loads GRU using config defaults. - * Saves the final fold GRU model (`gru_model_fold_N.keras`), history, and learning curve plot. - * **Calibrate Probabilities (`calibrate_probabilities`):** - * Fits the selected calibrator (Temp/Vector with optional L2 reg) on the fold's validation sequences. - * If `optimize_edge_threshold`, calculates optimal threshold using Youden's J, stores it internally (`self.optimized_edge_threshold`), and saves it (`optimized_edge_threshold_fold_N.txt`). - * Saves fold calibration parameters (`calibration_{temp/vector}_fold_N.npy`). - * **GRU Validation Gate (`_perform_gru_validation_checks`):** Checks edge-filtered accuracy CI and Brier score against updated config thresholds (`validation_gates.gru`) using the fold's determined `optimized_edge_threshold`. Exits fold if failed. - * **Train/Load SAC (`train_or_load_sac`):** - * If `train_sac`, initializes `SACTrainer` using a config copy (passing the fold's `optimized_edge_threshold` and disabling rolling calibration if active). Trains agent handling PER (with clipping, alpha annealing), Oracle Seeding (with IS weight decay), State Normalization, Action Penalty (`0.01/cost`). Saves agent, filter state, logs, plots in a fold-specific `sac_train_...` dir. - * If loading, determines path from config. - * **Run Backtest (`run_backtest`):** - * Initializes `Backtester`. - * Passes the fold's SAC agent path, test sequences, GRU handler, *initial* calibration state, the fold's `optimized_edge_threshold`, and original test prices. - * If `rolling_enabled`, the backtester uses raw predictions to refit calibrator, potentially triggered early by **ECE Coverage Alarm** (`ECE > config threshold`). - * Logs backtest performance (Sharpe, MDD, Win Rate) and gate pass/fail status. Fold failure here does *not* halt the pipeline immediately but is recorded. - * **Store Fold Results:** Appends the `backtest_metrics` dict (including status) to `all_fold_metrics`. Stores the `sac_agent_load_path` if SAC was trained successfully. -5. 
**Aggregate Metrics (`aggregate_fold_metrics`):** Calculates summary statistics (mean, std, min, max, median) across metrics from all *successful* folds. Saves `aggregated_wf_metrics.json`. -6. **Aggregate SAC Agents (`aggregate_sac_agents`):** (Optional: if `sac_aggregation.enabled`) - * Loads SAC agents from the stored paths of successful folds. - * Averages the weights of the loaded agents. - * Saves the aggregated agent to the main run's model dir (`models/run_.../sac_agent_aggregated/`). Saves `sac_aggregation_info.txt`. -7. **Final Release Decision (`final_release_decision`):** Evaluates aggregated metrics against overall release criteria defined in `validation_gates.final_release` (e.g., min % successful folds, median Sharpe). -8. **Log Final Status:** Logs whether the pipeline passed or failed the final release criteria. +1. **Initialization (`run.py`):** Sets up infrastructure (config, logging, IO). Instantiates `TradingPipeline`. +2. **Data Loading (`TradingPipeline` calls `data_processing.load_and_preprocess`):** Loads the *entire* raw dataset. +3. **Fold Generation (`TradingPipeline._generate_walk_forward_folds`):** Yields date ranges for each fold. +4. **Fold Loop (`TradingPipeline.execute`):** Iterates through folds. + * **Select Fold Data:** Extracts raw data for the current fold range. + * **Feature Engineering (`TradingPipeline` calls `data_processing.engineer_features_for_fold`):** Computes features on fold data. + * **Labeling (`TradingPipeline` calls `data_processing.define_labels_and_align_fold`):** Calculates labels. + * **Split Fold Data (`TradingPipeline` calls `data_processing.split_data_fold`):** Splits into `train`, `val`, `test`. + * **Baseline Gate (`TradingPipeline` calls `evaluation.run_baseline_checks_fold`):** Runs Logistic Regression check. Halts fold on failure. + * **Select Features (`TradingPipeline` calls `feature_processing.select_features_fold`):** Performs selection. Saves whitelist. + * **Scale Features (`TradingPipeline` calls `feature_processing.scale_features_fold`):** Fits/applies scaler. Saves scaler. + * **Prune Features (`TradingPipeline` calls `feature_processing.prune_features_fold`):** Prunes scaled data using whitelist. + * **Create Sequences (`TradingPipeline` calls `sequence_creation.create_sequences_fold`):** Creates GRU input sequences. + * **Train/Load GRU (`TradingPipeline` calls `modelling.train_or_load_gru_fold`):** Handles training, Optuna sweep, or loading. Handles re-processing (scale/prune/sequence) internally if loaded scaler differs. Saves model/params. + * **Calibrate Probabilities (`TradingPipeline` calls `modelling.calibrate_probabilities_fold`):** Fits calibrator, optimizes edge threshold. Saves parameters. + * **GRU Validation Gate (`TradingPipeline` calls `evaluation.run_gru_validation_checks_fold`):** Checks calibrated validation predictions. Halts fold on failure. + * **Train/Load SAC (`TradingPipeline` calls `modelling.train_or_load_sac_fold`):** Handles SAC training (calling `SACTrainer`) or determines load path. + * **Run Backtest (`TradingPipeline` calls `evaluation.run_backtest_fold`):** Instantiates `Backtester`, runs simulation, handles rolling calibration, performs backtest validation checks. Halts fold on failure. + * **Store Fold Results:** Appends metrics and SAC agent path (if trained) for aggregation. +5. **Aggregate Metrics (`TradingPipeline.aggregate_fold_metrics`):** Calculates summary statistics across successful folds. +6. 
**Aggregate SAC Agents (`TradingPipeline` calls `modelling.aggregate_sac_agents`):** (Optional) Averages weights of successful fold agents.
+7. **Final Release Decision (`TradingPipeline.final_release_decision`):** Evaluates aggregated metrics against the final criteria.
+8. **Log Final Status:** Reports overall pipeline success/failure.
 
-*(Note: If Walk-Forward is disabled (`walk_forward.enabled=false`), the pipeline runs steps D-T once using static splits based on `split_ratios` config).*
+*(Note: If Walk-Forward is disabled, the pipeline runs the fold processing steps once using static splits).*
+
+## Project Structure
+
+```
+gru_sac_predictor/
+├── config/
+│   └── config.yaml          # Main configuration file
+├── data/                    # Data storage (e.g., parquet files)
+├── logs/                    # Log output directory (run-specific subdirs)
+├── models/                  # Saved models/scalers etc. (run-specific subdirs)
+├── results/                 # Output results, metrics, plots (run-specific subdirs)
+├── src/
+│   ├── pipeline_stages/     # **Refactored stage-specific logic**
+│   │   ├── data_processing.py
+│   │   ├── evaluation.py
+│   │   ├── feature_processing.py
+│   │   ├── modelling.py
+│   │   └── sequence_creation.py
+│   ├── __init__.py
+│   ├── backtester.py        # Backtesting simulation engine
+│   ├── baseline_checker.py  # Baseline logistic regression check
+│   ├── calibrator.py        # Temperature scaling
+│   ├── calibrator_vector.py # Vector scaling
+│   ├── data_loader.py       # Loads raw data
+│   ├── feature_engineer.py  # Feature creation logic
+│   ├── features.py          # Feature lists/definitions (optional)
+│   ├── gru_hyper_tuner.py   # Optuna hyperparameter tuner for GRU
+│   ├── gru_model_handler.py # GRU model building, training, loading (v2/v3)
+│   ├── io_manager.py        # Handles file I/O structure
+│   ├── logger_setup.py      # Configures logging
+│   ├── metrics.py           # Performance metrics calculation
+│   ├── sac_agent.py         # SAC agent network definitions
+│   ├── sac_trainer.py       # Offline SAC training orchestration
+│   ├── trading_env.py       # Gym-like environment for SAC training
+│   ├── trading_pipeline.py  # Main pipeline orchestrator class
+│   └── utils/               # Utility functions (e.g., run_id generation)
+├── tests/                   # Unit/integration tests (optional)
+├── run.py                   # Main execution script
+├── requirements.txt         # Python dependencies
+└── README.md                # This file
+```
 
 ## Core Components Architecture
 
-This section details the architecture and purpose of key modules.
+* **`run.py`**: Entry point, sets up IO/logging, runs the pipeline.
+* **`TradingPipeline` (`src/trading_pipeline.py`):** Orchestrates the workflow by calling stage functions, manages overall state, and handles the walk-forward loop and validation gates (a minimal sketch of this call pattern follows the list below).
+* **`src/pipeline_stages/*.py`**: Contain the core implementation logic for each distinct step of the pipeline (data processing, feature processing, sequencing, modelling, evaluation).
+* **`IOManager`, `LoggerSetup`**: Utilities for managing outputs and logging.
+* **`DataLoader`, `FeatureEngineer`**: Data loading and feature generation.
+* **`GRUModelHandler`, `GRUHyperTuner`**: GRU model implementation (v2/v3), training, loading, and Optuna tuning.
+* **`Calibrator`, `VectorCalibrator`**: Probability calibration logic.
+* **`SACTradingAgent`, `TradingEnv`, `SACTrainer`**: SAC agent definition, training environment, and offline training orchestration (including PER, normalization, seeding).
+* **`Backtester`, `BaselineChecker`, `metrics.py`**: Backtesting simulation, baseline checks, and performance metric calculations.
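+
+To make the orchestrator/stage split concrete, here is a minimal sketch of the call pattern. The stage-function names echo the modules above, but the class shape and signatures are illustrative assumptions, not the real API:
+
+```python
+# Illustrative stand-ins for the real src/pipeline_stages functions.
+def engineer_features_for_fold(raw_fold):   # data_processing
+    return {"features": raw_fold}
+
+def run_baseline_checks_fold(fold_data):    # evaluation (validation gate)
+    return True
+
+def train_or_load_gru_fold(fold_data):      # modelling
+    return "gru_model"
+
+class TradingPipelineSketch:
+    """Orchestrator: owns the fold loop and gates, delegates work to stage functions."""
+
+    def __init__(self, folds):
+        self.folds = folds
+        self.all_fold_metrics = []
+
+    def execute(self):
+        for i, raw_fold in enumerate(self.folds, start=1):
+            fold_data = engineer_features_for_fold(raw_fold)
+            if not run_baseline_checks_fold(fold_data):
+                # A failed gate halts the current fold, not the whole run
+                self.all_fold_metrics.append({"fold": i, "status": "baseline_failed"})
+                continue
+            model = train_or_load_gru_fold(fold_data)
+            self.all_fold_metrics.append({"fold": i, "status": "ok", "gru": model})
+        return self.all_fold_metrics
+
+if __name__ == "__main__":
+    print(TradingPipelineSketch(folds=[{}, {}]).execute())
+```
+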
-### 1. Orchestration & Utilities (`TradingPipeline`, `IOManager`, `LoggerSetup`)
-* **`TradingPipeline` (`src/trading_pipeline.py`):** The main class coordinating the workflow (data prep, feature eng, model training, calibration, backtesting, aggregation) potentially within a walk-forward loop. Reads config, manages data flow, implements validation gates, and calls other components.
-* **`IOManager` (`src/io_manager.py`):** Handles standardized file I/O (saving/loading models, dataframes, scalers, configs, plots, reports) within a run-specific directory structure.
-* **`LoggerSetup` (`src/logger_setup.py`):** Configures Python's `logging` for console and file output with standardized formats and levels.
-
-### 2. Data Handling (`DataLoader`, `FeatureEngineer`)
-* **`DataLoader` (`src/data_loader.py`):** Loads raw OHLCV data from specified database files/sources.
-* **`FeatureEngineer` (`src/feature_engineer.py`):** Generates features:
-    * **TA:** Returns, ATR, EMA, RSI (MACD removed).
-    * **Cyclical:** Hour, Week (sin/cos).
-    * **Imbalance:** Chaikin AD, SVI, Gap Imbalance.
-    * **Micro-structure:** Spread Proxy, Vol-Norm Volume Spike, Return Asymmetry, Close-Location Value, Keltner Band Position (with NaN guards).
-    * **Leakage Guard:** Uses `shift(1)` on inputs for time-dependent features.
-    * **Target Definition:** Calculates forward returns and binary/ternary labels.
-    * **Selection/Pruning:** Performs **leakage check** (`corr(ret+h, feat_t-1)`), then L1+VIF selection on *raw/engineered* data (applying minimal whitelist before VIF), then prunes *scaled* data based on the selection.
-
-### 3. GRU Predictor (`gru_model_handler.py`, `gru_hyper_tuner.py`)
-* **`gru_model_handler.py`:** **Consolidated GRU implementation.**
-    * Contains builders for both v2 (`build_gru_model`) and v3 (`build_gru_model_v3`) architectures.
-    * **v3 Architecture:** GRU -> LayerNorm -> Optional MultiHeadAttention -> GlobalAvgPool -> Dense Heads (`mu`, `dir3_logits`). Includes L2 regularization. Uses Huber loss for `mu`, Focal loss for `dir3`.
-    * Manages training (with early stopping, CSV logging), saving (.keras format), loading (handles custom loss `gaussian_nll`), and prediction (`predict`, `predict_logits`).
-    * Selects v2/v3 based on `control.use_v3` flag.
-* **`gru_hyper_tuner.py`:** Implements Optuna hyperparameter sweep for GRU.
-    * Called by `TradingPipeline` if sweep is enabled.
-    * Uses **restricted search space** based on config (`hyperparameter_tuning.gru`).
-    * Uses combined objective based on config (`objective_metric`, `objective_edge_acc_weight`, `objective_brier_weight`). Logs components to trial attributes.
-    * Supports pruning via **Keras callback** reporting `val_loss` each epoch.
-    * Trains final fold model using best found parameters.
-    * Saves best parameters JSON and Optuna plots per fold.
-
-### 4. Probability Calibration (`Calibrator`, `VectorCalibrator`, `metrics.py`)
-* **`Calibrator` (`src/calibrator.py`):** Implements Temperature Scaling (learns scalar `T`) with optional L2 regularization.
-* **`VectorCalibrator` (`src/calibrator_vector.py`):** Implements Vector Scaling (learns matrix `W`, bias `b`) with optional L2 regularization. Preferred for ternary.
-* **Integration:** `TradingPipeline.calibrate_probabilities` fits chosen calibrator per fold (with L2 reg). If `optimize_edge_threshold=true`, calculates and stores optimal edge for the fold. 
`Backtester` applies calibration step-by-step, handles rolling recalibration if enabled (with **ECE-based coverage alarm**), using the **static fold calibration** if SAC training was active for the fold. -* **Edge Threshold Optimization:** `metrics._calculate_optimal_edge_threshold` finds best threshold using Youden's J. Called by `TradingPipeline` if enabled. - -### 5. SAC Agent (`sac_agent.py`, `sac_trainer.py`, `trading_env.py`) -* **`sac_agent.py`:** Defines the SAC agent networks (Actor, Critic) and update logic. - * Actor outputs squashed Gaussian distribution parameters. - * Uses twin Q-critics. - * Handles automatic entropy tuning (alpha). - * `train` method accepts a batch and returns losses + TD errors (for PER). -* **`trading_env.py`:** Gym-style environment using GRU predictions. - * **State:** `[mu, sigma, edge, |mu|/sigma, position]`. - * Takes action (-1 to +1). - * Calculates reward based on PnL, potentially scaled (`reward_scale`) and penalized (**action penalty**, default: `0.01 / transaction_cost`). -* **`sac_trainer.py`:** Orchestrates offline SAC training. - * Loads GRU dependencies for a specific run. - * Prepares validation data for the `TradingEnv` (uses fold's static calibration). - * Initializes `TradingEnv` and `SACTradingAgent`. - * Manages the Replay Buffer: - * Implements `PrioritizedReplayBuffer` if `sac.use_per` is true (with **TD error clipping** and **alpha annealing**). Logs TD error distribution stats. - * Uses `collections.deque` for standard uniform replay. - * Handles **Oracle Seeding** into the buffer (with **IS weight decay** for seeded samples). - * Runs the training loop: interacts with env, stores transitions, samples batches (uniform or PER), updates agent, updates PER priorities. - * Handles **State Normalization** using `MeanStdFilter` if `sac.use_state_filter` is true (saves/loads filter state). - * Saves agent checkpoints, final agent, state filter, and logs (rewards, TensorBoard). - -### 6. Evaluation (`Backtester`, `BaselineChecker`, `metrics.py`) -* **`Backtester` (`src/backtester.py`):** Evaluates the full system on test data. - * Takes trained GRU, SAC agent, initial calibration state, and the fold's edge threshold. - * Simulates step-by-step trading. - * Applies **rolling calibration** logic if enabled (with ECE check). - * Calculates PnL, equity curve, standard performance metrics. - * Saves detailed results dataframe, metrics summary, and plots per fold. - * Logs performance and whether fold backtest gates passed/failed, but does **not** halt the fold on failure (decision made during final aggregation). -* **`BaselineChecker` (`src/baseline_checker.py`):** Performs initial logistic regression check on **raw/engineered** training features. -* **`metrics.py`:** Contains calculation functions for Sharpe, Brier, Edge-Filtered Accuracy, ECE, and the Youden's J optimization helper. ## Configuration (`config.yaml`) -The `config.yaml` file centrally controls the pipeline's behavior. Key sections include: +The `config.yaml` file centrally controls the pipeline's behavior. See comments within the default `config.yaml` for detailed explanations of each parameter. Key sections include: -* `base_dirs`: Output directories. -* `output`: Figure DPI, size, logging level. -* `data`: Data source details, label smoothing. -* `features`: Minimal feature whitelist, leakage threshold. -* `walk_forward`: Settings for WF validation (enable, days, step). **Note:** `walk_forward.enabled=true` overrides `split_ratios`. 
-* `split_ratios`: Used only if `walk_forward.enabled=false`. -* `gru`: General GRU settings (horizon, lookback, ternary flag, flat sigma mult). -* `gru_v3`: Specific hyperparameters for the v3 architecture (units, attention, losses, reg). -* `hyperparameter_tuning`: Controls Optuna sweep for GRU (enable, trials, timeout, pruning, objective metric/weights). -* `calibration`: Method (temp/vector), L2 lambda, optimize edge threshold flag, rolling calibration settings (enable, freq, window, ECE alarm). -* `validation_gates`: Thresholds for baseline, GRU gates, and final release decision (median Sharpe, % success). -* `sac`: SAC hyperparameters (gamma, tau, LR, alpha, PER settings, oracle seeding, IS weight decay steps, state filter). -* `sac_aggregation`: Controls post-run agent averaging (enable, method). -* `environment`: Trading env parameters (capital, costs, reward scale, action penalty lambda). -* `control`: Flags to enable/disable major stages (train GRU, train SAC, run backtest, use v3, plots), model loading/resuming IDs. +* `base_dirs`, `output`: Directory and output settings. +* `data`, `features`: Data sources, labeling, feature selection controls. +* `walk_forward`, `split_ratios`: Controls walk-forward vs. static splits. +* `gru`, `gru_v3`: GRU architecture, training parameters. +* `hyperparameter_tuning`: Optuna sweep settings for GRU. +* `calibration`: Calibration method, parameters, rolling calibration, ECE alarm. +* `validation_gates`: Thresholds for baseline, GRU, backtest, and final release checks. +* `sac`, `environment`: SAC agent hyperparameters, PER, seeding, environment settings. +* `sac_aggregation`: Agent averaging settings. +* `control`: High-level flags (train/load models, enable plots, use v3 GRU). + +## Installation + +1. Clone the repository. +2. Ensure you have Python 3.8+ installed. +3. Set up a virtual environment (recommended): + ```bash + python -m venv .venv + source .venv/bin/activate # On Windows use `.venv\\Scripts\\activate` + ``` +4. Install dependencies: + ```bash + pip install -r requirements.txt + ``` + *Note: This installs necessary libraries like TensorFlow, PyTorch, Optuna, scikit-learn, pandas, etc.* +5. Prepare your data according to the expected format and update paths in `config.yaml`. ## Usage -1. **Setup:** Install requirements (`pip install -r requirements.txt`), prepare data in the specified format/location. -2. **Configure:** Edit `config.yaml` (data paths, feature lists, model params, control flags, walk-forward settings, calibration, validation thresholds, tuning, aggregation). -3. **Run Pipeline:** - ```bash - # From project root (develop/gru_sac_predictor/) - python gru_sac_predictor/run.py --config path/to/your_config.yaml +1. **Configure:** Edit `config.yaml` to set data paths, feature lists, model parameters, control flags (e.g., `train_gru`, `train_sac`), walk-forward settings, validation thresholds, etc. +2. **Run Pipeline:** Execute `run.py` from the project root directory (`gru_sac_predictor/`), specifying the configuration file: + ```bash + # Example execution from the parent directory 'develop/gru_sac_predictor/' + python gru_sac_predictor/run.py --config gru_sac_predictor/config/config.yaml ``` -4. **Outputs:** Check `logs/`, `models/`, `results/` directories for run-specific outputs, including fold-specific artifacts if WF is enabled. + * You can use other command-line arguments like `--use-ternary` if implemented in `run.py`. +3. 
**Outputs:** Check the directories specified in `config.yaml` (typically subdirectories within `logs/`, `models/`, `results/`) for run-specific outputs. If Walk-Forward is enabled, you will find fold-specific subdirectories containing models, scalers, plots, and results for each fold.
 
 ## Output Artifacts (Walk-Forward Enabled Example)
 
 * **Main Run Dirs:** `logs/run_/`, `models/run_/`, `results/run_/`
     * `run_config.yaml`, `pipeline_.log`
-    * (Post-run) `aggregated_wf_metrics.json`, `sac_aggregation_info.txt`
-    * (Post-run) `models/.../sac_agent_aggregated/`
+    * (Post-run) `aggregated_wf_metrics.json`, `sac_aggregation_info.txt` (if enabled)
+    * (Post-run) `models/.../sac_agent_aggregated/` (if enabled)
 * **Fold Dirs (within main run dirs):** e.g., `models/run_/fold_1/`
     * `models/run_/fold_N/models/`: `gru_model_fold_N.keras`, `calibration_{...}_fold_N.npy`, `feature_scaler_fold_N.joblib`, `final_whitelist_fold_N.json`
     * `models/run_/fold_N/hypertuning/`: (If sweep enabled) `best_gru_params.json`, Optuna plots.
@@ -248,4 +259,17 @@
 ## Dependencies
 
-See `requirements.txt`. Key libraries include: TensorFlow, NumPy, Pandas, PyYAML, Scikit-learn, Statsmodels, TA-Lib (via `ta` wrapper), Matplotlib, Seaborn, Optuna, PyTorch (for SAC aggregation). Note `tensorflow-addons` is required for optimal focal loss / attention layers.
+All major Python dependencies are listed in `requirements.txt`. Key libraries include:
+
+* TensorFlow (for GRU)
+* PyTorch (for SAC)
+* Optuna (for hyperparameter tuning)
+* scikit-learn (for scaling, metrics, baseline)
+* pandas, numpy
+* pyyaml
+* matplotlib, seaborn
+
+Install them using:
+```bash
+pip install -r requirements.txt
+```
diff --git a/gru_sac_predictor/docs/v3_changelog.md b/gru_sac_predictor/docs/v3_changelog.md
new file mode 100644
index 00000000..b00dee04
--- /dev/null
+++ b/gru_sac_predictor/docs/v3_changelog.md
@@ -0,0 +1,26 @@
+# GRU-SAC Predictor v3 Changelog
+
+This document summarizes the major changes and new configuration options introduced in the v3 revisions (as outlined in `revisions.txt`).
+
+## Key Changes & New Features
+
+### 1. Data & Labeling (`config.data`, `config.gru`)
+
+* **Volatility-Aware Sampling (Task 1.1):**
+    * Added optional sampling in `DataLoader` to focus on higher volatility periods.
+    * Config: `data.vol_sampling` (bool), `data.vol_window` (int), `data.vol_quantile` (float).
+* **Soft Binary Labels (Task 1.2):**
+    * Option to use smoothed labels (e.g., [0.1, 0.9]) instead of hard {0, 1} for binary classification.
+    * Config: `data.label_smoothing` (float, 0.0 to disable).
+* **Ternary Direction Labels (Task 1.3):**
+    * Added option for "up" / "flat" / "down" classification (a labeling sketch follows below).
+    * "Flat" defined dynamically based on forward return volatility.
+    * Config: `gru.use_ternary` (bool), `gru.flat_sigma_multiplier` (float).
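+
+To make the flat-band rule concrete, here is a minimal labeling sketch. It mirrors the ATR-based reconstruction used in `SACTrainer` (threshold = multiplier × ATR / close); the function name and defaults are illustrative, not the actual API:
+
+```python
+import numpy as np
+import pandas as pd
+
+def ternary_labels(close: pd.Series, atr: pd.Series, horizon: int = 5,
+                   flat_sigma_multiplier: float = 0.3) -> pd.Series:
+    """Sketch of the ternary rule: 2 = up, 1 = flat, 0 = down."""
+    eps = 1e-9
+    fwd_log_ret = np.log(close.shift(-horizon) / (close + eps))
+    threshold = flat_sigma_multiplier * atr / (close + eps)  # volatility-scaled flat band
+    labels = np.select(
+        [fwd_log_ret > threshold, fwd_log_ret < -threshold],
+        [2, 0],
+        default=1,  # inside the band -> flat
+    )
+    return pd.Series(labels, index=close.index)
+
+# Example with a random-walk price series and a simple ATR proxy
+prices = pd.Series(100 * np.exp(np.random.randn(500).cumsum() * 0.001))
+atr14 = prices.diff().abs().rolling(14).mean().bfill()
+print(ternary_labels(prices, atr14).value_counts())
+```
+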
+### 2. Feature Engineering (`config.features` - conceptual)
+
+* **Volatility-Normalized Return (Task 2.1):**
+    * Added `vola_norm_return(df, k)` function.
+    * Calculated for k=15, k=60 and added to default features (`vola_norm_return_15`, `vola_norm_return_60`).
+* **Weekly Fourier Features (Task 2.2):**
+    * Added `week_sin`, `week_cos` to capture weekly seasonality.
+    * Added to default features.
+* **MACD Removal (Task 2.3):**
+    * Removed the `MACD` and `MACD_signal` calculations and dropped them from `minimal_whitelist`.
+* **VIF Skip Logic (Task 2.5):**
+    * Conceptual: tests added assuming a `config.features.skip_vif` flag could be implemented in `FeatureEngineer.select_features`.
+
+### 3. GRU v3 Model (`config.gru_v3`, `config.control.use_v3`)
+
+* **New Architecture (Task 3.1):**
+    * Implemented `model_gru_v3.py` with a `GRU(units) -> Attention -> LayerNorm` structure.
+* **New Output Heads (Task 3.2):**
+    * `dir3`: Dense(3, softmax) for ternary classification.
+    * `mu`: Dense(1, linear) for return prediction.
+* **New Loss Configuration (Task 3.3):**
+    * Uses `CategoricalFocalCrossentropy` for `dir3` and `Huber` for `mu`.
+    * Loss weights configurable.
+* **Configurable Hyperparameters (Task 3.4):**
+    * New `gru_v3` section in `config.yaml` exposes `gru_units`, `attention_units`, `learning_rate`, loss parameters (`focal_gamma`, `focal_label_smoothing`, `huber_delta`), and loss weights (`loss_weight_mu`, `loss_weight_dir3`).
+* **Model Selection (Task 3.5):**
+    * Added `control.use_v3` (bool) flag to switch between GRU v2 and v3 logic within `GRUModelHandler`.
+
+### 4. Vector Scaling Calibration (`config.calibration`)
+
+* **New Calibrator (Task 4.1):**
+    * Added `calibrator_vector.py` with a `VectorCalibrator` class implementing vector scaling (optimizes diagonal matrix `W` and bias `b`).
+* **Method Selection (Task 4.2):**
+    * Added `calibration.method` config option (`temperature` or `vector`). `TradingPipeline` routes to the appropriate calibrator.
+* **Parameter Handling (Task 4.3):**
+    * `VectorCalibrator` saves/loads its parameters (`[W_diag, b]`) to `.npy` files.
+* **Logits Requirement:**
+    * Vector scaling requires pre-softmax logits. Added a `GRUModelHandler.predict_logits` method using an inference-only model view to retrieve these without altering the main model structure.
+
+### 5. SAC Stabilisation (`config.sac`, `config.environment`)
+
+* **Reward Scaling (Task 5.1):**
+    * Environment reward is multiplied by a scaling factor.
+    * Config: `environment.reward_scale` (float).
+* **State Normalization (Task 5.2):**
+    * Added `utils.running_stats.MeanStdFilter` (a minimal sketch follows this section).
+    * `SACTrainer` optionally normalizes environment states using this filter.
+    * Config: `sac.use_state_filter` (bool).
+    * Filter state is saved/loaded with agent checkpoints.
+* **Target Entropy Calculation (Task 5.3):**
+    * `SACTradingAgent` automatically calculates target entropy as `-0.5 * log(4)` if `alpha_auto_tune` is true and the default `target_entropy` (`-action_dim`) is used.
+    * Config: `sac.target_entropy` (float or null).
+* **Action Penalty (Task 5.4):**
+    * Added a quadratic penalty to the environment reward based on action magnitude.
+    * Config: `environment.action_penalty_lambda` (float).
+* **Oracle Buffer Seeding (Task 5.5):**
+    * `SACTrainer` can pre-populate a percentage of the replay buffer using a heuristic policy based on GRU predictions.
+    * Config: `sac.oracle_seeding_pct` (float).
+* **Metadata Update (Task 5.6):**
+    * `reward_scale` and `lambda` (action penalty) are now saved in `agent_metadata.json`.
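+
+As an illustration of the state-normalization step, here is a minimal running mean/std filter in the spirit of `utils.running_stats.MeanStdFilter`, using Welford's online algorithm. The class name and API below are assumptions; the real implementation may differ:
+
+```python
+import numpy as np
+
+class MeanStdFilterSketch:
+    """Running mean/std state normalizer (Welford's online algorithm)."""
+
+    def __init__(self, shape):
+        self.n = 0
+        self.mean = np.zeros(shape)
+        self.m2 = np.zeros(shape)  # running sum of squared deviations
+
+    def __call__(self, x, update=True):
+        x = np.asarray(x, dtype=np.float64)
+        if update:  # freeze statistics at evaluation time via update=False
+            self.n += 1
+            delta = x - self.mean
+            self.mean += delta / self.n
+            self.m2 += delta * (x - self.mean)
+        std = np.sqrt(self.m2 / max(self.n - 1, 1)) + 1e-8
+        return (x - self.mean) / std
+
+# Normalize SAC environment states as they stream in during training,
+# e.g., the 5-dim state [mu, sigma, edge, |mu|/sigma, position]
+f = MeanStdFilterSketch(shape=(5,))
+for _ in range(1000):
+    norm_state = f(np.random.randn(5))
+```
+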
+### 6. Metrics & Validation (`config.calibration`, `src/metrics.py`)
+
+* **Edge-Filtered Accuracy (Task 6.1):**
+    * Added `metrics.edge_filtered_accuracy` function.
+* **Validation Check (Task 6.2):**
+    * Added a check in `TradingPipeline` after calibration. Calculates edge-filtered accuracy on the validation set and computes the 95% CI lower bound.
+    * Pipeline fails if the CI lower bound < 0.60.
+* **Re-centred Sharpe Ratio (Task 6.3):**
+    * Added `metrics.calculate_sharpe_ratio` function allowing a custom benchmark return (defaults to 0).
+* **Backtester Reporting (Task 6.4):**
+    * `Backtester` now calculates and saves edge-filtered accuracy and the re-centred Sharpe ratio to the metrics file.
+
+### 7. Missing Data Handling (`config.data`, `config.gru`, `config.sac`)
+
+* **Detection & Filling (Task PR1):**
+    * Pipeline now detects missing bars based on `data.bar_frequency`.
+    * Configurable filling strategies (`data.missing.strategy`: `drop`, `neutral`, `ffill`, `interpolate`) implemented in `data_loader.py` (a fill sketch follows below).
+    * Warning logged with missing count and longest gap.
+    * `missing_bars_summary.json` artifact saved.
+    * Error raised if the longest gap exceeds `data.missing.max_gap`.
+    * `bar_imputed` boolean column added to data after filling.
+* **Sequence Handling (Task PR2):**
+    * `bar_imputed` added as a feature input to the GRU model (`minimal_whitelist`).
+    * Sequences containing any imputed bars can be dropped via `gru.drop_imputed_sequences` (boolean flag).
+    * Logs number of sequences dropped.
+    * `imputed_sequence_summary.json` artifact saved.
+* **SAC Environment Handling (Task PR3):**
+    * `TradingEnv` now aware of imputed steps.
+    * Behavior configured by `sac.imputed_handling` (`skip`, `hold`, `penalty`).
+        * `skip`: Skips step, 0 reward, transition ignored by replay buffer.
+        * `hold`: Action forced to current position, normal reward calc.
+        * `penalty`: Agent action taken, penalty reward applied (config `sac.action_penalty`).
+    * Debug log message added for SAC imputed steps.
+    * Optional `sac_imputed_transitions.csv` artifact logs details of non-skipped imputed steps.
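+
+To illustrate the detection and `"neutral"` fill behaviour described above, here is a minimal pandas sketch. The function name and exact semantics are illustrative; the real logic lives in `data_loader.py`:
+
+```python
+import numpy as np
+import pandas as pd
+
+def fill_missing_bars_neutral(df: pd.DataFrame, freq: str = "1T",
+                              max_gap: int = 10) -> pd.DataFrame:
+    """Sketch of the 'neutral' strategy: reindex to a complete bar grid,
+    flag imputed rows, ffill close, copy it to open/high/low, zero volume."""
+    full_index = pd.date_range(df.index[0], df.index[-1], freq=freq)
+    out = df.reindex(full_index)
+    out["bar_imputed"] = out["close"].isna()  # flag before filling
+
+    # Longest consecutive run of missing bars (max-gap guard)
+    gap_groups = (~out["bar_imputed"]).cumsum()[out["bar_imputed"]]
+    longest_gap = gap_groups.value_counts().max() if out["bar_imputed"].any() else 0
+    if longest_gap > max_gap:
+        raise ValueError(f"Longest gap ({longest_gap} bars) exceeds max_gap={max_gap}")
+
+    out["close"] = out["close"].ffill()
+    for col in ("open", "high", "low"):
+        out[col] = out[col].fillna(out["close"])
+    out["volume"] = out["volume"].fillna(0.0)
+    return out
+
+# Example: a 1-minute series with two bars deleted
+idx = pd.date_range("2024-01-01", periods=10, freq="1T").delete([3, 4])
+bars = pd.DataFrame({"open": 1.0, "high": 1.0, "low": 1.0,
+                     "close": np.arange(len(idx), dtype=float),
+                     "volume": 1.0}, index=idx)
+print(fill_missing_bars_neutral(bars)["bar_imputed"].sum())  # -> 2
+```
+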
+
+## Configuration Summary
+
+See the updated `config.yaml` for details on the following new/modified sections and parameters:
+
+* `data`: `vol_sampling`, `vol_window`, `vol_quantile`, `label_smoothing`, `bar_frequency` (new), `missing` (new section: `strategy`, `max_gap`, `interpolate`)
+* `gru`: `use_ternary`, `flat_sigma_multiplier`, `drop_imputed_sequences` (new)
+* `gru_v3`: (New section with architecture, training, and compilation parameters)
+* `calibration`: `method`
+* `sac`: `use_state_filter`, `target_entropy` (updated behaviour), `oracle_seeding_pct`, `imputed_handling` (new), `action_penalty` (new)
+* `environment`: `reward_scale`, `action_penalty_lambda`
+* `control`: `use_v3`
+
+*(Note: Some parameters under `gru` like epochs/batch_size/patience primarily apply when `control.use_v3` is false.)*
\ No newline at end of file
diff --git a/gru_sac_predictor/src/sac_trainer.py b/gru_sac_predictor/src/sac_trainer.py
index 68437e29..74f72840 100644
--- a/gru_sac_predictor/src/sac_trainer.py
+++ b/gru_sac_predictor/src/sac_trainer.py
@@ -19,6 +19,8 @@ from datetime import datetime
 from tqdm import tqdm
 from tensorflow.keras.callbacks import TensorBoard
 import collections
+import csv
+import time
 
 # Import necessary components from the pipeline
 # Use absolute imports assuming the package structure is correct
@@ -208,6 +210,8 @@ class SACTrainer:
         self.control_cfg = config.get('control', {})
         self.data_cfg = config['data']
 
+        # --- Cache useful config values ---
+        self.use_ternary = self.config.get('gru', {}).get('use_ternary', False)
         # --- Store PER config params --- #
         self.use_per = self.sac_cfg.get('use_per', False)
         self.per_alpha = self.sac_cfg.get('per_alpha', 0.6)
@@ -301,22 +305,84 @@ class SACTrainer:
         # 3. Load GRU Model
         model_path = os.path.join(gru_run_models_dir, f"gru_model_{gru_run_id}.keras")
         # Need a temporary GRU handler instance to use its load method
-        temp_gru_handler = GRUModelHandler(run_id="temp_load", models_dir="temp_load")
-        dependencies['gru_model'] = temp_gru_handler.load(model_path)
-        if dependencies['gru_model'] is None:
-            logger.error(f"Failed to load GRU model from {model_path}")
+        # Pass self.config (the trainer's config) to the GRUModelHandler
+        temp_gru_handler = GRUModelHandler(run_id="temp_load", models_dir="temp_load", config=self.config)
+        # Attempt to load the model using the handler
+        # The load method likely needs the full path, not just the directory and ID
+        loaded_model_info = temp_gru_handler.load(model_path=model_path)  # Pass full path
+
+        # Adjust based on what gru_handler.load returns
+        # Assuming it returns (model, info_dict) or None
+        if loaded_model_info and isinstance(loaded_model_info, tuple) and len(loaded_model_info) > 0:
+            dependencies['gru_model'] = loaded_model_info[0]  # Get the actual model
+            if dependencies['gru_model'] is None:
+                logger.error(f"GRU handler load method returned None for model from {model_path}")
+                return None
+        elif loaded_model_info and not isinstance(loaded_model_info, tuple):  # If it just returns the model
+            dependencies['gru_model'] = loaded_model_info
+        else:  # If it returned None or unexpected tuple
+            logger.error(f"Failed to load GRU model using handler from {model_path}")
             return None
+
         logger.info(f"Loaded GRU model from {model_path}")
 
-        # 4. Load Optimal Temperature
-        temp_path = os.path.join(gru_run_models_dir, f"calibration_temp_{gru_run_id}.npy")
+        # 4. 
Load Calibration Info and Parameters + calib_info_path = os.path.join(gru_run_models_dir, f"calibration_info_{gru_run_id}.json") + calib_params = None + calib_method = None try: - dependencies['optimal_T'] = float(np.load(temp_path)) - logger.info(f"Loaded optimal temperature T={dependencies['optimal_T']:.4f} from {temp_path}") + with open(calib_info_path, 'r') as f: + calib_info = json.load(f) + logger.info(f"Loaded calibration info: {calib_info}") + calib_method = calib_info.get("method") + params_filename = calib_info.get("params_filename") + + if calib_method and params_filename: + params_path = os.path.join(gru_run_models_dir, params_filename) + if os.path.exists(params_path): + if calib_method == "temperature": + calib_params = float(np.load(params_path)) + logger.info(f"Loaded optimal temperature T={calib_params:.4f} from {params_path}") + elif calib_method == "vector": + # Load vector params (assuming saved as .npy) + calib_params = np.load(params_path, allow_pickle=True) + logger.info(f"Loaded vector calibration params from {params_path} (type: {type(calib_params)}, shape: {getattr(calib_params, 'shape', 'N/A')})") + else: + logger.warning(f"Unknown calibration method '{calib_method}' in info file. Cannot load params.") + else: + logger.error(f"Calibration parameter file specified in info ({params_filename}) not found at {params_path}") + return None # Fail if params file missing + elif calib_method in ["temperature_failed", "vector_failed", "skipped_mismatch", "temperature_error", "vector_error", "vector_unavailable"]: + logger.warning(f"Calibration was not successful or was skipped during GRU run (method: {calib_method}). SAC training may be suboptimal.") + # Decide if this is fatal. For now, let's allow it but store None. + calib_params = None + else: + logger.error(f"Invalid calibration info content: {calib_info}") + return None # Fail if info is invalid + + except FileNotFoundError: + logger.error(f"Calibration info file not found at {calib_info_path}. Cannot determine calibration parameters.") + # Check for legacy temperature file for backward compatibility? + legacy_temp_path = os.path.join(gru_run_models_dir, f"calibration_temp_{gru_run_id}.npy") + if os.path.exists(legacy_temp_path): + logger.warning("Calibration info file missing, but found legacy temperature file. Attempting to load it.") + try: + calib_params = float(np.load(legacy_temp_path)) + calib_method = "temperature" # Assume temperature + logger.info(f"Loaded legacy optimal temperature T={calib_params:.4f} from {legacy_temp_path}") + except Exception as legacy_e: + logger.error(f"Failed to load legacy temperature file {legacy_temp_path}: {legacy_e}") + return None # Fail if legacy load fails + else: + logger.error("Neither calibration info file nor legacy temperature file found.") + return None # Fail if no calibration info found except Exception as e: - logger.error(f"Failed to load optimal temperature from {temp_path}: {e}", exc_info=True) - # Allow continuation without T? Or require it? Let's require it for now. + logger.error(f"Failed to load calibration info/parameters: {e}", exc_info=True) return None + + # Store method and params in dependencies + dependencies['calibration_method'] = calib_method + dependencies['calibration_params'] = calib_params logger.info("--- Successfully loaded all GRU dependencies ---") return dependencies @@ -349,107 +415,284 @@ class SACTrainer: df_raw.dropna(subset=['open', 'high', 'low', 'close', 'volume'], inplace=True) logger.info("Loaded raw data.") - # 2. 
Engineer Base Features (using a temporary FeatureEngineer) - # Pass the *minimal* whitelist as a fallback if the loaded one causes issues - temp_feature_engineer = FeatureEngineer(minimal_whitelist=minimal_whitelist) + # 2. Engineer Base Features (using config) + # FIX: Instantiate FeatureEngineer correctly using self.config + temp_feature_engineer = FeatureEngineer(config=self.config) df_engineered = temp_feature_engineer.add_base_features(df_raw) df_engineered.dropna(inplace=True) # Drop NaNs after feature eng if df_engineered.empty: raise ValueError("Dataframe empty after feature engineering.") logger.info("Engineered base features.") - # 3. Prune Features using *loaded* whitelist - loaded_whitelist = gru_dependencies['whitelist'] - missing_in_eng = [f for f in loaded_whitelist if f not in df_engineered.columns] - if missing_in_eng: - raise ValueError(f"Features from loaded whitelist missing in engineered data: {missing_in_eng}") - df_features = df_engineered[loaded_whitelist] - logger.info("Pruned features using loaded whitelist.") - - # 4. Define Labels + # 3. Define Labels (on the full engineered data) horizon = self.config['gru'].get('prediction_horizon', 5) target_ret_col = f'fwd_log_ret_{horizon}' - target_dir_col = f'direction_label_{horizon}' - df_engineered[target_ret_col] = np.log(df_engineered['close'].shift(-horizon) / df_engineered['close']) - df_engineered[target_dir_col] = (df_engineered[target_ret_col] > 0).astype(int) - # Align by dropping NaNs in targets AND ensuring indices match features + target_dir_col = f'direction_label3_{horizon}' if self.use_ternary else f'direction_label_{horizon}' + _EPS = 1e-9 + if 'future_close' not in df_engineered.columns: + df_engineered['future_close'] = df_engineered['close'].shift(-horizon) + if 'future_close' in df_engineered.columns: + df_engineered[target_ret_col] = np.log(df_engineered['future_close'] / (df_engineered['close'] + _EPS)) + else: + df_engineered[target_ret_col] = np.log(df_engineered['close'].shift(-horizon) / (df_engineered['close'] + _EPS)) + + if self.use_ternary: + flat_sigma = self.config.get('gru', {}).get('flat_sigma_multiplier', 0.3) + if 'ATR_14' in df_engineered.columns: + atr_aligned = df_engineered['ATR_14'].reindex(df_engineered.index).bfill().ffill() + threshold = flat_sigma * atr_aligned / (df_engineered['close'] + _EPS) + df_engineered[target_dir_col] = np.select( + [df_engineered[target_ret_col] > threshold, df_engineered[target_ret_col] < -threshold], + [2, 0], default=1 + ) + else: + raise ValueError("ATR_14 needed for ternary labels not found.") + else: # Binary case + df_engineered[target_dir_col] = (df_engineered[target_ret_col] > 0).astype(int) + + # 4. 
Align Features and Targets df_engineered.dropna(subset=[target_ret_col, target_dir_col], inplace=True) - common_index = df_features.index.intersection(df_engineered.index) - if common_index.empty: - raise ValueError("No common index between features and targets after label definition.") - df_features = df_features.loc[common_index] - df_targets = df_engineered.loc[common_index, [target_ret_col, target_dir_col]] + # Identify *all* potential feature columns (excluding targets) + potential_feature_cols = [col for col in df_engineered.columns if col not in [target_ret_col, target_dir_col, 'future_close']] + # Ensure whitelist features are present in potential features + loaded_whitelist = gru_dependencies['whitelist'] + missing_whitelist_check = set(loaded_whitelist) - set(potential_feature_cols) + if missing_whitelist_check: + logger.warning(f"Whitelist features missing from engineered columns: {missing_whitelist_check}. Adding with 0 fill.") + for col in missing_whitelist_check: + df_engineered[col] = 0.0 + potential_feature_cols.append(col) + # Now df_engineered contains all potential features and targets, aligned logger.info("Defined labels and aligned features/targets.") - - # 5. Split Data (to get validation set indices) - split_cfg = self.config['split_ratios'] - train_ratio, val_ratio = split_cfg['train'], split_cfg['validation'] - total_len = len(df_features) + + # 5. Split Data (using aligned engineered data) + wf_enabled = self.config.get('walk_forward', {}).get('enabled', False) + if wf_enabled: logger.warning("Walk-forward enabled, but SAC uses split_ratios.") + split_ratios = self.config.get('walk_forward', {}).get('split_ratios', {}) + train_ratio = split_ratios.get('train', 0.6); val_ratio = split_ratios.get('validation', 0.2) + if not (0 < train_ratio < 1 and 0 < val_ratio < 1 and (train_ratio + val_ratio) <= 1.0): + raise ValueError(f"Invalid split ratios: train={train_ratio}, validation={val_ratio}") + total_len = len(df_engineered) train_end_idx = int(total_len * train_ratio) val_end_idx = int(total_len * (train_ratio + val_ratio)) - val_indices = df_features.index[train_end_idx:val_end_idx] + val_indices = df_engineered.index[train_end_idx:val_end_idx] if val_indices.empty: raise ValueError("Validation split resulted in empty indices.") - X_val_pruned = df_features.loc[val_indices] - y_val = df_targets.loc[val_indices] - logger.info(f"Isolated validation set data (Features: {X_val_pruned.shape}, Targets: {y_val.shape}).") + df_val_aligned = df_engineered.loc[val_indices] - # 6. Scale Validation Features using *loaded* scaler + # -- Determine columns expected by scaler -- scaler = gru_dependencies['scaler'] - numeric_cols = X_val_pruned.select_dtypes(include=np.number).columns - X_val_scaled = X_val_pruned.copy() - if not numeric_cols.empty: - X_val_scaled[numeric_cols] = scaler.transform(X_val_pruned[numeric_cols]) - logger.info("Scaled validation features using loaded scaler.") + expected_scaler_features = [] + if hasattr(scaler, 'feature_names_in_'): + expected_scaler_features = scaler.feature_names_in_.tolist() + logger.debug(f"Scaler expects features: {expected_scaler_features}") + else: + # Fallback: Use all numeric columns if scaler has no names + logger.warning("Scaler has no feature names saved. 
Assuming it was fit on all numeric columns present at that time.") + # This is risky; need to ensure columns match implicitly + expected_scaler_features = df_val_aligned.select_dtypes(include=np.number).columns.tolist() + # Manually exclude known target columns if needed as a safeguard + expected_scaler_features = [f for f in expected_scaler_features if f not in [target_ret_col, target_dir_col]] + logger.debug(f"Falling back to using numeric columns for scaler: {expected_scaler_features}") - # 7. Create Validation Sequences + # Ensure all expected scaler features exist in the validation slice + missing_for_scaler = set(expected_scaler_features) - set(df_val_aligned.columns) + if missing_for_scaler: + # Attempt to add missing with 0 fill (e.g., if a feature wasn't calculable for val split) + logger.warning(f"Columns expected by scaler missing from validation data: {missing_for_scaler}. Adding with 0 fill.") + for col in missing_for_scaler: + df_val_aligned[col] = 0.0 + # Re-verify + missing_for_scaler = set(expected_scaler_features) - set(df_val_aligned.columns) + if missing_for_scaler: + raise ValueError(f"Could not prepare all features expected by scaler, still missing: {missing_for_scaler}") + + # Isolate the exact features needed for the scaler transform + X_val_engineered_for_scaler = df_val_aligned[expected_scaler_features] + y_val = df_val_aligned[[target_ret_col, target_dir_col]] + logger.info(f"Isolated validation set data for scaler (Features: {X_val_engineered_for_scaler.shape}, Targets: {y_val.shape}).") + + # 6. Scale the features expected by the scaler + X_val_scaled_full = X_val_engineered_for_scaler.copy() + try: + X_val_scaled_full[expected_scaler_features] = scaler.transform(X_val_engineered_for_scaler) + logger.info("Scaled validation features using loaded scaler.") + except ValueError as e: + logger.error(f"Error applying scaler transform even after aligning columns: {e}") + raise + + # 7. WORKAROUND: Drop non-feature columns (like future_close) *after* scaling + # Identify actual features intended for the model (whitelist) + loaded_whitelist = gru_dependencies['whitelist'] + cols_to_keep_after_scale = [f for f in loaded_whitelist if f in X_val_scaled_full.columns] + missing_whitelist_post_scale = set(loaded_whitelist) - set(cols_to_keep_after_scale) + if missing_whitelist_post_scale: + # This shouldn't happen if whitelist features were in expected_scaler_features + logger.warning(f"Whitelist features missing AFTER scaling: {missing_whitelist_post_scale}. This might indicate issues.") + + if not cols_to_keep_after_scale: + raise ValueError("Whitelist resulted in no features remaining after scaling step.") + + X_val_scaled_eng = X_val_scaled_full[cols_to_keep_after_scale].copy() + logger.info(f"Selected whitelisted features after scaling. Shape: {X_val_scaled_eng.shape}") + + # 8. Pruning step is now effectively done by selecting whitelist columns above. + X_val_pruned_scaled = X_val_scaled_eng # Rename for clarity in subsequent steps + logger.info(f"Using whitelisted scaled features for sequencing. Final shape: {X_val_pruned_scaled.shape}") + + # 9. 
Create Validation Sequences (using the pruned & scaled data) lookback = self.config['gru']['lookback'] - X_val_seq = [] - y_val_seq_targets = [] # Store corresponding targets - val_seq_indices = [] # Store corresponding indices - features_np = X_val_scaled.values - targets_np = y_val.values # Contains both ret and dir - for i in range(lookback, len(features_np)): - X_val_seq.append(features_np[i-lookback : i]) - y_val_seq_targets.append(targets_np[i]) # Target corresponds to end of sequence + X_val_seq_list, y_val_seq_targets_list, val_seq_indices = [], [], [] + features_np_arr = X_val_pruned_scaled.values + targets_np_arr = y_val.values + for i in range(lookback, len(features_np_arr)): + X_val_seq_list.append(features_np_arr[i-lookback : i]) + y_val_seq_targets_list.append(targets_np_arr[i]) val_seq_indices.append(y_val.index[i]) - - if not X_val_seq: - raise ValueError("Validation sequence creation resulted in empty list.") - - X_val_seq = np.array(X_val_seq) - y_val_seq_targets = np.array(y_val_seq_targets) - actual_ret_val_seq = y_val_seq_targets[:, 0] # First column is return - y_dir_val_seq = y_val_seq_targets[:, 1] # Second column is direction + + if not X_val_seq_list: raise ValueError("Validation sequence creation resulted in empty list.") + X_val_seq = np.array(X_val_seq_list) + y_val_seq_targets = np.array(y_val_seq_targets_list) + actual_ret_val_seq = y_val_seq_targets[:, 0] + y_dir_val_seq = y_val_seq_targets[:, 1] logger.info(f"Created validation sequences (X shape: {X_val_seq.shape}).") - # 8. Get GRU Predictions on Validation Sequences using *loaded* GRU model + # 10. Get GRU Predictions using loaded GRU model directly gru_model = gru_dependencies['gru_model'] - # Use a temporary handler instance with the loaded model - temp_gru_handler = GRUModelHandler(run_id="temp_predict", models_dir="temp") - temp_gru_handler.model = gru_model # Assign the loaded model - predictions_val = temp_gru_handler.predict(X_val_seq) - if predictions_val is None or len(predictions_val) < 3: - raise ValueError("GRU prediction on validation sequences failed.") - mu_val_pred = predictions_val[0].flatten() - log_sigma_val_pred = predictions_val[1][:, 1].flatten() - p_raw_val_pred = predictions_val[2].flatten() - sigma_val_pred = np.exp(log_sigma_val_pred) - logger.info("Generated GRU predictions on validation sequences.") - + logger.info(f"Generating GRU predictions using loaded model (type: {type(gru_model)}).") + + # Check model type and predict accordingly + if not hasattr(gru_model, 'predict'): + raise TypeError("Loaded GRU model object does not have a 'predict' method.") + + predictions_val = gru_model.predict(X_val_seq) + + # --- Explicit Shape Logging --- # + if isinstance(predictions_val, list): + logger.info(f"DEBUG: GRU prediction output list length: {len(predictions_val)}") + if len(predictions_val) > 0: logger.info(f"DEBUG: Shape of predictions_val[0]: {predictions_val[0].shape}") + if len(predictions_val) > 1: logger.info(f"DEBUG: Shape of predictions_val[1]: {predictions_val[1].shape}") + if len(predictions_val) > 2: logger.info(f"DEBUG: Shape of predictions_val[2]: {predictions_val[2].shape}") + else: + logger.info(f"DEBUG: GRU prediction output type: {type(predictions_val)}") + # --- End Explicit Shape Logging --- # + + # --- Infer output structure (Corrected Indexing) --- + mu_val_pred, sigma_val_pred, p_raw_val_pred, logits_val_pred = None, None, None, None + + if isinstance(predictions_val, list) and len(predictions_val) == 2: + logger.debug(f"GRU model returned list of 
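# --- Editor's note: illustrative sketch, not part of the original patch. --- #
# The lookback windowing used above: each target at row i is paired with the
# preceding `lookback` feature rows [i - lookback, i), so no window contains
# information from its own target bar or later.
import numpy as np

def make_sequences(features: np.ndarray, targets: np.ndarray, lookback: int):
    X, y = [], []
    for i in range(lookback, len(features)):
        X.append(features[i - lookback:i])
        y.append(targets[i])
    if not X:
        raise ValueError("Not enough rows to form a single sequence.")
    # features of shape (N, F) -> X of shape (N - lookback, lookback, F)
    return np.asarray(X), np.asarray(y)
# --- End sketch --- #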
{len(predictions_val)} outputs. Correcting index assumption: [mu, dir].") + # Corrected Indexing based on logs: + mu_output = predictions_val[0] # Index 0 has shape (N, 1) + dir_output = predictions_val[1] # Index 1 has shape (N, 3) + + # Extract Mu + if mu_output.ndim == 2 and mu_output.shape[-1] == 1: + mu_val_pred = mu_output.flatten() + else: + # Corrected error message to reflect index 0 check + raise ValueError(f"Unexpected shape for mu output (index 0): {mu_output.shape}. Expected (N, 1).") + + # Handle missing Sigma - Use a default/fallback + logger.warning("Log_sigma_sq output not found in model prediction. Using default sigma=0.1.") + sigma_val_pred = np.ones_like(mu_val_pred) * 0.1 + + # Determine direction output type from dir_output (index 1) + if self.use_ternary: + if dir_output.ndim == 2 and dir_output.shape[-1] == 3: + # Assume dir_output (index 1) is logits if ternary + logits_val_pred = dir_output + from scipy.special import softmax # Local import ok here + p_raw_val_pred = softmax(logits_val_pred, axis=-1) + logger.info("Inferred ternary output (logits assumed at index 1). Calculated raw probabilities.") + else: + raise ValueError(f"Expected ternary output shape (N, 3) at index 1, got {dir_output.shape}") + else: # Binary + # Binary case is tricky if dir_output (index 1) is still (N, 3) + # This shouldn't happen if the model structure changes for binary + # Let's assume the mu_output (index 0) might be P(up) in binary? + logger.warning(f"Binary mode, dir_output shape is {dir_output.shape}. Trying to use mu output (index 0) as P(up) - THIS IS RISKY.") + if mu_output.ndim == 2 and mu_output.shape[-1] == 1: + p_raw_val_pred = mu_output.flatten() + epsilon = 1e-7 + p_clipped = np.clip(p_raw_val_pred, epsilon, 1 - epsilon) + logits_val_pred = np.log(p_clipped / (1 - p_clipped)) + logger.info("Using mu output (index 0) as P(up) for binary mode. Inferred raw logits.") + else: + raise ValueError(f"Cannot determine binary P(up) output. mu_output shape: {mu_output.shape}") + else: + raise ValueError(f"GRU prediction failed or returned unexpected type/structure: {type(predictions_val)}, length: {len(predictions_val) if isinstance(predictions_val, list) else 'N/A'}. Expected list of 2 arrays.") + + # Final check for necessary predictions + if mu_val_pred is None or sigma_val_pred is None or p_raw_val_pred is None: + raise ValueError("Failed to extract mu, sigma, or raw probabilities from GRU model output.") + # Verify lengths n_seq = len(X_val_seq) if not (len(mu_val_pred) == n_seq and len(sigma_val_pred) == n_seq and \ - len(p_raw_val_pred) == n_seq and len(actual_ret_val_seq) == n_seq): - raise ValueError(f"Length mismatch after validation predictions: Expected {n_seq}, got mu={len(mu_val_pred)}, sigma={len(sigma_val_pred)}, p_raw={len(p_raw_val_pred)}, ret={len(actual_ret_val_seq)}") + p_raw_val_pred.shape[0] == n_seq and len(actual_ret_val_seq) == n_seq): + raise ValueError(f"Length mismatch after validation predictions: Expected {n_seq}, " + f"got mu={len(mu_val_pred)}, sigma={len(sigma_val_pred)}, " + f"p_raw={p_raw_val_pred.shape[0]}, ret={len(actual_ret_val_seq)}") - # 9. 
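# --- Editor's note: illustrative sketch, not part of the original patch. --- #
# The logit-recovery step used by the binary fallback above: clip
# probabilities away from {0, 1}, then invert the sigmoid via the log-odds.
import numpy as np

def probs_to_logits(p: np.ndarray, epsilon: float = 1e-7) -> np.ndarray:
    p = np.clip(p, epsilon, 1.0 - epsilon)  # avoid log(0) at the extremes
    return np.log(p / (1.0 - p))            # inverse of sigmoid(x) = 1/(1+e^-x)
# --- End sketch --- #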
Calibrate Predictions using *loaded* optimal_T
-        optimal_T = gru_dependencies['optimal_T']
-        # Use a temporary calibrator instance
-        temp_calibrator = Calibrator(edge_threshold=0.5) # Edge threshold doesn't matter here
-        temp_calibrator.optimal_T = optimal_T
-        p_cal_val_pred = temp_calibrator.calibrate(p_raw_val_pred)
-        logger.info(f"Calibrated validation predictions using loaded T={optimal_T:.4f}.")
+        # 11. Calibrate Predictions using loaded parameters
+        calib_method = gru_dependencies.get('calibration_method')
+        calib_params = gru_dependencies.get('calibration_params')
+        p_cal_val_pred = None
-        # 10. Return the necessary components for the TradingEnv
+        if calib_method == "temperature" and calib_params is not None:
+            optimal_T = calib_params
+            # Temperature scaling needs logits
+            if logits_val_pred is not None:
+                scaled_logits = logits_val_pred / optimal_T
+                if self.use_ternary:
+                    # softmax was imported above on the ternary inference path
+                    p_cal_val_pred = softmax(scaled_logits, axis=-1)
+                    logger.info(f"Applied temperature scaling (T={optimal_T:.4f}) to ternary logits.")
+                else:  # Binary - apply sigmoid
+                    p_cal_val_pred = 1 / (1 + np.exp(-scaled_logits))  # Sigmoid
+                    logger.info(f"Applied temperature scaling (T={optimal_T:.4f}) to binary logits.")
+            else:
+                logger.error(f"Cannot apply temperature scaling (method={calib_method}): Raw logits not available.")
+                p_cal_val_pred = p_raw_val_pred  # Fallback to raw
+
+        elif calib_method == "vector" and calib_params is not None:
+            # Vector calibration needs logits
+            if logits_val_pred is not None:
+                try:
+                    from gru_sac_predictor.src.calibrator_vector import VectorCalibrator  # Ensure import
+                    temp_vector_calibrator = VectorCalibrator()
+                    temp_vector_calibrator.optimal_params = calib_params  # Set loaded params
+                    # Calibrate expects logits, returns probabilities
+                    p_cal_val_pred = temp_vector_calibrator.calibrate(logits_val_pred)
+                    logger.info("Calibrated validation predictions using loaded vector parameters.")
+                except ImportError:
+                    logger.error("Cannot import VectorCalibrator. Vector calibration step skipped.")
+                    p_cal_val_pred = p_raw_val_pred  # Fallback to raw
+                except Exception as vec_e:
+                    logger.error(f"Error during vector calibration application: {vec_e}", exc_info=True)
+                    p_cal_val_pred = p_raw_val_pred  # Fallback to raw
+            else:
+                logger.error(f"Cannot apply vector calibration (method={calib_method}): Raw logits not available.")
+                p_cal_val_pred = p_raw_val_pred  # Fallback to raw
+
+        else:  # No usable calibration
+            logger.warning(f"Calibration method '{calib_method}' is unrecognised or its parameters were not loaded. Using RAW predictions for the SAC environment.")
+            p_cal_val_pred = p_raw_val_pred  # Raw probs (softmaxed for ternary, P(up) for binary)
+
+        # Final check for calibrated probabilities
+        if p_cal_val_pred is None:
+            logger.error("Failed to obtain final (calibrated or raw) predictions. Cannot proceed.")
+            return None
+        # Verify shape of final probabilities
+        expected_final_dim = 3 if self.use_ternary else 1
+        if p_cal_val_pred.ndim == 2 and p_cal_val_pred.shape[-1] == expected_final_dim:
+            if not self.use_ternary:
+                p_cal_val_pred = p_cal_val_pred.flatten()  # Flatten binary case
+        elif p_cal_val_pred.ndim == 1 and not self.use_ternary and expected_final_dim == 1:
+            pass  # Already flat for binary
+        else:
+            logger.error(f"Final probability array has unexpected shape: {p_cal_val_pred.shape}. Expected dim {expected_final_dim}.")
+            return None
+
+        # 12. 
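# --- Editor's note: illustrative sketch, not part of the original patch. --- #
# Temperature scaling as applied above: one scalar T > 0 divides the logits
# before the final activation. T > 1 softens over-confident probabilities and
# T < 1 sharpens them, without changing the predicted class.
import numpy as np
from scipy.special import softmax

def temperature_scale(logits: np.ndarray, T: float, ternary: bool) -> np.ndarray:
    scaled = logits / T
    if ternary:
        return softmax(scaled, axis=-1)      # (N, 3) class probabilities
    return 1.0 / (1.0 + np.exp(-scaled))     # binary P(up) via sigmoid
# --- End sketch --- #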
Return the necessary components for the TradingEnv logger.info("--- Successfully prepared validation data for SAC Environment ---") return mu_val_pred, sigma_val_pred, p_cal_val_pred, actual_ret_val_seq @@ -512,358 +755,206 @@ class SACTrainer: logger.warning(f"SAC agent path not found for resume: {load_path}. Starting fresh.") def _training_loop(self, agent: SACTradingAgent, env: TradingEnv) -> str | None: - """Runs the main SAC training loop.""" - buffer_max_size = self.sac_cfg.get('buffer_max_size', 100000) - min_buffer_size = self.sac_cfg.get('min_buffer_size', 10000) + """The main SAC training loop.""" + total_steps = self.sac_cfg.get('total_training_steps', 100000) + start_steps = self.sac_cfg.get('start_steps', 10000) + update_after = self.sac_cfg.get('update_after', 1000) + update_every = self.sac_cfg.get('update_every', 50) + save_freq = self.sac_cfg.get('save_freq', 5000) + log_freq = self.sac_cfg.get('log_freq', 100) + buffer_capacity = self.sac_cfg.get('buffer_capacity', 1000000) batch_size = self.sac_cfg.get('batch_size', 256) - total_training_steps = self.sac_cfg.get('total_training_steps', 100000) - - # --- Initialize Replay Buffer (Potentially PER) --- # + + # Initialize Replay Buffer (Standard or Prioritized) if self.use_per: - logger.info(f"Initializing Prioritized Replay Buffer (Capacity={buffer_max_size}, alpha={self.per_alpha}, beta_start={self.per_beta_start}, beta_frames={self.per_beta_frames})") + logger.info(f"Using Prioritized Replay Buffer (Capacity: {buffer_capacity})") replay_buffer = PrioritizedReplayBuffer( - buffer_max_size, - alpha=self.per_alpha, - beta_start=self.per_beta_start, + capacity=buffer_capacity, + alpha=self.per_alpha, # Initial alpha + beta_start=self.per_beta_start, beta_frames=self.per_beta_frames ) else: - logger.info(f"Initializing Standard Replay Buffer (Deque, Capacity={buffer_max_size})") - replay_buffer = collections.deque(maxlen=buffer_max_size) - replay_buffer.counter = 0 # Add counter for uniform sampling logic - # --- End Buffer Init --- # + logger.info(f"Using Standard Replay Buffer (Capacity: {buffer_capacity})") + replay_buffer = collections.deque(maxlen=buffer_capacity) - # --- Oracle Seeding (Revision 4-B) --- # - oracle_seeding_pct = self.sac_cfg.get('oracle_seeding_pct', 0.0) - num_existing_samples = len(replay_buffer) # Count samples potentially loaded during resume - target_seed_steps = int(buffer_max_size * oracle_seeding_pct) - actual_seed_steps = max(0, target_seed_steps - num_existing_samples) - - if actual_seed_steps > 0: - logger.info(f"Performing Oracle Seeding: Adding ~{actual_seed_steps} steps ({oracle_seeding_pct * 100:.1f}% target) with heuristic policy...") - - # Need edge threshold from config (used in the heuristic action) - edge_threshold_heuristic = self.config.get('calibration', {}).get('edge_threshold') - if edge_threshold_heuristic is None: - logger.error("Cannot perform oracle seeding: 'calibration.edge_threshold' not found in config.") - elif edge_threshold_heuristic <= 0: - logger.warning(f"Edge threshold for heuristic is {edge_threshold_heuristic:.3f}. Oracle seeding action may be ill-defined. 
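# --- Editor's note: illustrative sketch, not part of the original patch. --- #
# The proportional prioritisation behind a PER buffer like the one built
# above: sample index i with probability P(i) = p_i^alpha / sum_k p_k^alpha
# and correct the sampling bias with importance weights
# w_i = (N * P(i))^(-beta), normalised by max(w). The actual
# PrioritizedReplayBuffer presumably uses a sum-tree; this flat-array version
# is for exposition only.
import numpy as np

def per_sample(priorities: np.ndarray, batch_size: int, alpha: float, beta: float):
    scaled = priorities ** alpha
    probs = scaled / scaled.sum()
    idxs = np.random.choice(len(priorities), size=batch_size, p=probs)
    weights = (len(priorities) * probs[idxs]) ** (-beta)
    return idxs, weights / weights.max()  # normalise to keep updates bounded
# --- End sketch --- #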
Using 0.01 instead.") - edge_threshold_heuristic = 0.01 - else: - state = env.reset() # Start seeding from the beginning of the env data - n_seeded = 0 - try: - for _ in tqdm(range(actual_seed_steps), desc="Oracle Seeding", file=sys.stdout, leave=False): - if len(replay_buffer) >= buffer_max_size: - logger.warning("Buffer full during oracle seeding.") - break - - current_state_for_seeding = state # Keep original state before potential filtering - if self.state_filter: - state = self.state_filter(state, update=False) # Apply filter, don't update filter stats during seeding - - # Heuristic action based on edge - # State: [mu, sigma, edge, |mu|/sigma, position] (assuming this order) - try: - # Ensure state has enough elements before indexing - if len(state) < 3: - logger.error(f"State vector too short ({len(state)} elements) for oracle seeding. Stopping seeding.") - break - edge = state[2] # Assuming edge is the 3rd element (index 2) - except IndexError: - logger.error(f"IndexError accessing edge (state[2]) during oracle seeding. State shape: {state.shape if hasattr(state, 'shape') else type(state)}. Stopping seeding.") - break - - oracle_action = np.clip(edge / edge_threshold_heuristic, -1.0, 1.0) - - next_state, reward, done, _ = env.step(oracle_action) - - # Store experience with the state *before* filtering (as agent expects raw state) - experience = (current_state_for_seeding, oracle_action, reward, next_state, done) - if self.use_per: - # Mark as seeded - replay_buffer.add(replay_buffer.max_priority, experience, is_seeded=True) - elif hasattr(replay_buffer, 'counter'): - replay_buffer.append(experience) - replay_buffer.counter += 1 - n_seeded += 1 - - state = next_state - if done: - # Reset env if it finishes during seeding, continue seeding if needed - logger.info("Environment finished during oracle seeding. Resetting.") - state = env.reset() - logger.info(f"Oracle seeding completed. Added {n_seeded} experiences.") - except Exception as seed_err: - logger.error(f"Error during oracle seeding loop: {seed_err}. Proceeding with {n_seeded} seeded samples.", exc_info=True) - + # TensorBoard setup + tb_callback = TensorBoard(log_dir=self.sac_tb_log_dir) + # --- Revision 4: Set model for TensorBoard --- # + # Check if agent has actor/critic models accessible + # This depends heavily on SACTradingAgent implementation + # Assuming agent.actor and agent.critic1/2 are the models + if hasattr(agent, 'actor') and hasattr(agent, 'critic1') and hasattr(agent, 'critic2'): + tb_callback.set_model(agent.actor) # Link to one model is often sufficient + logger.info("TensorBoard callback linked to SAC agent model (actor).") else: - logger.info("Oracle seeding skipped (percentage=0 or buffer already seeded enough). 
Existing samples: {num_existing_samples}") - # --- End Oracle Seeding --- # - - # --- Training Loop Setup --- # - start_learning_after_steps = self.sac_cfg.get('start_learning_after_steps', 1000) - save_checkpoint_freq = self.sac_cfg.get('save_checkpoint_freq_steps', 50000) - total_steps = self.sac_cfg.get('total_training_steps', 100000) - - # Revision 5: Alpha annealing params - alpha_start = self.sac_cfg.get('per_alpha_start', 0.6) - alpha_end = self.sac_cfg.get('per_alpha_end', 0.4) - current_alpha = alpha_start # Initialize alpha + logger.warning("Could not link TensorBoard callback to agent models (actor/critic not found).") + # --- End Revision 4 --- - summary_writer = tf.summary.create_file_writer(self.sac_tb_log_dir) - episode_reward = 0 - episode_steps = 0 - episode_rewards_log = [] # For saving reward history - best_eval_score = -np.inf # Placeholder for potential periodic evaluation - final_agent_path = None - - logger.info(f"Starting SAC training loop for {total_training_steps} steps...") - logger.info(f" Min buffer size: {min_buffer_size}, Batch size: {batch_size}, Updates/step: {updates_per_step}") - logger.info(f" Log interval: {log_interval}, Checkpoint interval: {checkpoint_interval}") - - summary_writer = tf.summary.create_file_writer(self.sac_tb_log_dir) - pbar = tqdm(range(1, total_training_steps + 1), desc="SAC Training", file=sys.stdout) - last_saved_path = None # Track last successful save + # --- Initialize optional imputed transition logger --- # + imputed_log_path = os.path.join(self.sac_run_results_dir, 'sac_imputed_transitions.csv') + imputed_log_file = None + imputed_csv_writer = None + try: + imputed_log_file = open(imputed_log_path, 'w', newline='') + imputed_csv_writer = csv.writer(imputed_log_file) + imputed_csv_writer.writerow(['step', 'imputed_handling_mode', 'action', 'reward', 'position_before', 'position_after']) + logger.info(f"Logging imputed transitions details to: {imputed_log_path}") + except IOError as e: + logger.error(f"Failed to open imputed transition log file {imputed_log_path}: {e}. 
Logging disabled.") + if imputed_log_file: imputed_log_file.close() + imputed_log_file = None + imputed_csv_writer = None + # --- End init imputed logger --- # - # --- Main Training Loop --- # - for step in pbar: - # --- Revision 5: Update Annealed Alpha --- # - if self.use_per: - alpha_fraction = min(1.0, step / total_steps) - current_alpha = alpha_start + alpha_fraction * (alpha_end - alpha_start) - # Calculate Seed Decay factor for IS weights - seed_decay_factor = max(0.0, 1.0 - step / self.per_seed_decay_steps) - # --- End Revision 5 --- # - - original_state = state # Keep original state for buffer - if self.state_filter: - state = self.state_filter(state, update=True) # Apply filter and update running stats - - if step < start_learning_after_steps: - # Use random actions during warmup to explore - action = np.random.uniform(env.action_space.low, env.action_space.high, size=env.action_space.shape) - else: - # Get action from agent - action, _ = agent.select_action(state) - - # Environment step - next_state, reward, done, info = env.step(action[0]) # Env expects single float action - - # Store experience in buffer (use original state) - experience = (original_state, action, reward, next_state, done) - if self.use_per: - # Mark as seeded - replay_buffer.add(replay_buffer.max_priority, experience, is_seeded=True) - elif hasattr(replay_buffer, 'counter'): - replay_buffer.append(experience) - replay_buffer.counter += 1 # Increment counter for deque - - state = next_state - current_episode_reward += reward - current_episode_steps += 1 - - # Perform SAC updates - if len(replay_buffer) >= min_buffer_size: - for _ in range(updates_per_step): - sample_indices, batch_with_seed_flags, importance_weights = None, None, None # Initialize - # --- Sample from Trainer's Buffer --- # - if self.use_per: - sample_indices, batch_with_seed_flags, importance_weights = replay_buffer.sample(batch_size) - # --- Revision 5: Apply Seed Decay to IS Weights --- # - batch = [] # Store only the samples - seed_mask = np.zeros_like(importance_weights, dtype=bool) - for i, (sample, is_seeded) in enumerate(batch_with_seed_flags): - batch.append(sample) - if is_seeded: - seed_mask[i] = True - # Apply decay factor to IS weights of seeded samples - # --- Revision 6: Log IS weight decay --- # - num_seeded_in_batch = np.sum(seed_mask) - if num_seeded_in_batch > 0: - original_weights_seeded = importance_weights[seed_mask].copy() - importance_weights[seed_mask] *= seed_decay_factor - weights_after_decay = importance_weights[seed_mask] - # Log periodically or if decay factor changes significantly - if step % self.log_interval == 0: # Align with other logging - logger.info(f"Step {step}: Applied IS decay factor {seed_decay_factor:.4f} to {num_seeded_in_batch} seeded samples in batch.") - # Optional: Log weight changes - # logger.debug(f" Weights before: {np.round(original_weights_seeded, 3)}") - # logger.debug(f" Weights after: {np.round(weights_after_decay, 3)}") - # --- End Revision 6 --- # - importance_weights_tensor = tf.convert_to_tensor(importance_weights, dtype=tf.float32) - else: # Uniform sampling from deque - current_buffer_size = min(replay_buffer.counter, replay_buffer.maxlen) - if current_buffer_size < batch_size: - sample_indices = np.random.choice(current_buffer_size, batch_size, replace=True) - else: - sample_indices = np.random.choice(current_buffer_size, batch_size, replace=False) - batch = [replay_buffer[i] for i in sample_indices] - importance_weights_tensor = None # No IS weights for uniform - # --- End 
Sampling --- # - - state_batch, action_batch, reward_batch, next_state_batch, done_batch = map(np.stack, zip(*batch)) - - # Apply state filter to sampled states if enabled - if self.state_filter: - state_batch_filtered = self.state_filter(state_batch, update=False) - next_state_batch_filtered = self.state_filter(next_state_batch, update=False) - else: - state_batch_filtered = state_batch - next_state_batch_filtered = next_state_batch - - # Convert batch to tensors - state_tensor = tf.convert_to_tensor(state_batch_filtered, dtype=tf.float32) - action_tensor = tf.convert_to_tensor(action_batch, dtype=tf.float32) - reward_tensor = tf.convert_to_tensor(reward_batch, dtype=tf.float32) - next_state_tensor = tf.convert_to_tensor(next_state_batch_filtered, dtype=tf.float32) - done_tensor = tf.convert_to_tensor(done_batch, dtype=tf.float32) - - # --- Call agent's train method with the batch --- # - loss_info = agent.train( - state_tensor, - action_tensor, - reward_tensor, - next_state_tensor, - done_tensor, - importance_weights=importance_weights_tensor # Pass potentially decayed weights for PER - ) - # --- End Agent Update Call --- # - - # --- Revision 5: Update PER priorities with current alpha --- # - if self.use_per and loss_info is not None: - td_errors = loss_info.get('td_errors') # Get TD errors from loss_info dict - if td_errors is not None: - replay_buffer.update_priorities(sample_indices, td_errors, current_alpha) - # --- Revision 6: Log TD Error Distribution --- # - # Convert to numpy if it's a tensor - if tf.is_tensor(td_errors): - td_errors = td_errors.numpy() - td_error_abs = np.abs(td_errors) - log_hist_interval = self.sac_cfg.get('log_hist_interval', 5000) - # Log percentiles every step where update happens - step_info_for_logging = { - 'td_error_abs_p25': np.percentile(td_error_abs, 25), - 'td_error_abs_p50': np.percentile(td_error_abs, 50), - 'td_error_abs_p75': np.percentile(td_error_abs, 75), - 'td_error_abs_p95': np.percentile(td_error_abs, 95), - 'td_error_abs_max': np.max(td_error_abs), - 'td_error_abs_mean': np.mean(td_error_abs) - } - # Log histogram to TensorBoard periodically - if log_hist_interval > 0 and step % log_hist_interval == 0: - try: - with summary_writer.as_default(step=step): - tf.summary.histogram('td_error_abs_distribution', td_error_abs, description='Absolute TD Errors') - summary_writer.flush() - logger.debug(f"Step {step}: Logged TD error histogram to TensorBoard.") - except Exception as hist_err: - logger.warning(f"Failed to log TD error histogram to TensorBoard at step {step}: {hist_err}") - # --- End Revision 6 --- # - else: - logger.warning(f"Step {step}: TD errors not found in loss_info from agent.train(). 
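# --- Editor's note: illustrative sketch, not part of the original patch. --- #
# The two linear schedules referenced above: PER alpha anneals from
# per_alpha_start to per_alpha_end over the run, and the IS weights of
# oracle-seeded samples decay to zero over per_seed_decay_steps.
def annealed_alpha(step: int, total_steps: int, alpha_start: float = 0.6, alpha_end: float = 0.4) -> float:
    frac = min(1.0, step / total_steps)
    return alpha_start + frac * (alpha_end - alpha_start)

def seed_decay_factor(step: int, seed_decay_steps: int) -> float:
    return max(0.0, 1.0 - step / seed_decay_steps)
# --- End sketch --- #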
Cannot update PER priorities or log distribution.") - # --- End Revision 5 --- # - - # Logging and Checkpointing - if step % log_interval == 0 and len(replay_buffer) >= min_buffer_size: - with summary_writer.as_default(): # Use context manager - # Log loss_info if available - if 'loss_info' in locals() and loss_info: - for k, v in loss_info.items(): - # Skip logging raw TD errors tensor here - if k != 'td_errors': - tf.summary.scalar(f'loss/{k}', v, step=step) - tf.summary.scalar('alpha', agent.log_alpha.numpy().item(), step=step) - tf.summary.scalar('buffer/beta', replay_buffer.beta, step=step) - tf.summary.scalar('buffer/per_alpha', current_alpha, step=step) - # Log TD error percentiles calculated in Revision 6 - if 'step_info_for_logging' in locals(): - for k, v in step_info_for_logging.items(): - tf.summary.scalar(f'td_error/{k}', v, step=step) - logger.debug(f"Step {step}: Logged losses, alpha, PER params, TD error stats.") - - if step % checkpoint_interval == 0: - # Save checkpoint - chkpt_path = os.path.join(self.sac_run_models_dir, f'sac_agent_step_{step}') - meta_data = {'step': step, 'edge_threshold': edge_threshold_heuristic if 'edge_threshold_heuristic' in locals() else self.config.get('calibration',{}).get('edge_threshold')} - agent.save(chkpt_path, meta_data=meta_data) - # --- Save State Filter (Task 5.2) --- # - if self.state_filter: - filter_state = self.state_filter.get_state() - np.savez(os.path.join(chkpt_path, 'state_filter.npz'), **filter_state) - # --- End Save State Filter --- # - logger.info(f"Saved SAC checkpoint at step {step} to {chkpt_path}") - - # Episode end handling - if done: - logger.debug(f"Episode {episode_count} finished after {current_episode_steps} steps. Reward: {current_episode_reward:.2f}. Buffer size: {len(replay_buffer)}.") - episode_rewards_log.append({'episode': episode_count, 'total_step': step, 'episode_reward': current_episode_reward, 'episode_steps': current_episode_steps}) - with summary_writer.as_default(): # Use context manager - tf.summary.scalar('reward/episode', current_episode_reward, step=step) - tf.summary.scalar('steps/episode', current_episode_steps, step=step) - - # Reset for next episode - state = env.reset() - current_episode_reward = 0.0 - current_episode_steps = 0 - else: - state = next_state - - # Update progress bar description - if step % 100 == 0 and len(replay_buffer) >= min_buffer_size: - last_reward = episode_rewards_log[-1]['episode_reward'] if episode_rewards_log else np.nan - pbar.set_description(f"SAC Training | Ep Reward (Last): {last_reward:.2f} | Buffer: {len(replay_buffer)}") - - # Save agent checkpoint - if (step + 1) % checkpoint_interval == 0 or step == total_training_steps - 1: - save_path = os.path.join(self.sac_run_models_dir, f'sac_agent_step_{step + 1}') - agent.save_weights(save_path) - logger.info(f"SAC agent weights saved at step {step + 1} to {save_path}") - # --- Save State Filter (Task 5.2) --- # - if self.state_filter: - state_filter_path = os.path.join(self.sac_run_models_dir, f'state_filter_step_{step + 1}.npz') - try: - self.state_filter.save_npz(state_filter_path) - logger.info(f"State filter saved to {state_filter_path}") - except Exception as e: - logger.error(f"Failed to save state filter: {e}") - # --- End Save State Filter --- # - last_saved_path = save_path # Update last saved path - # Also save the reward log periodically - if episode_rewards_log: - rewards_df = pd.DataFrame(episode_rewards_log) - rewards_log_path = os.path.join(self.sac_run_logs_dir, 'episode_rewards.csv') - try: - 
rewards_df.to_csv(rewards_log_path, index=False) - except Exception as e: - logger.error(f"Failed to save episode rewards log: {e}") - - # --- Final Save --- # - pbar.close() - summary_writer.close() - final_save_path = os.path.join(self.sac_run_models_dir, 'sac_agent_final') - agent.save_weights(final_save_path) - logger.info(f"Final SAC agent weights saved to {final_save_path}") - # Save final state filter + state = env.reset() + # Normalize initial state if filter is active if self.state_filter: - final_state_filter_path = os.path.join(self.sac_run_models_dir, 'state_filter_final.npz') - try: - self.state_filter.save_npz(final_state_filter_path) - logger.info(f"Final state filter saved to {final_state_filter_path}") - except Exception as e: - logger.error(f"Failed to save final state filter: {e}") - - # Save final rewards log - if episode_rewards_log: - rewards_df = pd.DataFrame(episode_rewards_log) - rewards_log_path = os.path.join(self.sac_run_logs_dir, 'episode_rewards.csv') - try: - rewards_df.to_csv(rewards_log_path, index=False) - logger.info(f"Final episode rewards log saved to {rewards_log_path}") - except Exception as e: - logger.error(f"Failed to save final episode rewards log: {e}") + state = self.state_filter(state, update=True) # Update filter with initial state - return final_save_path if os.path.exists(final_save_path) else last_saved_path + total_reward = 0.0 + start_time = time.time() + + logger.info(f"Starting SAC training loop for {total_steps} steps...") + for step in tqdm(range(total_steps), desc="SAC Training Steps"): + # Store state before taking action (for replay buffer) + state_before_action = state + position_before_action = env.current_position # Store for imputed log + + # Select action + if step < start_steps: + action = env.action_space.sample()[0] # Sample random action + else: + action = agent.select_action(state) + + # Step the environment + next_state_raw, reward, done, info = env.step(action) + + # Check if the step was skipped due to imputed bar handling + is_skipped = info.get('is_imputed_step_skipped', False) + + # Normalize next state if filter is active + if self.state_filter: + next_state = self.state_filter(next_state_raw, update=True) # Update filter + else: + next_state = next_state_raw + + # Store transition ONLY if the step was NOT skipped + if not is_skipped: + if self.use_per: + # Add with initial error=1 (or max priority) and is_seeded=False + # The error will be updated after the first training step on this sample. 
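# --- Editor's note: illustrative sketch, not part of the original patch. --- #
# A running mean/std normaliser in the spirit of the MeanStdFilter used above
# (the real implementation may differ). Welford's online update keeps
# numerically stable moments; update=False normalises without mutating the
# statistics, matching how the filter is called during seeding/evaluation.
import numpy as np

class RunningMeanStd:
    def __init__(self, shape):
        self.n = 0
        self.mean = np.zeros(shape)
        self.m2 = np.zeros(shape)

    def __call__(self, x: np.ndarray, update: bool = True) -> np.ndarray:
        if update:
            self.n += 1
            delta = x - self.mean
            self.mean = self.mean + delta / self.n
            self.m2 = self.m2 + delta * (x - self.mean)
        std = np.sqrt(self.m2 / max(self.n - 1, 1)) + 1e-8
        return (x - self.mean) / std
# --- End sketch --- #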
+                    replay_buffer.add(error=1.0, sample=(state_before_action, action, reward, next_state, done))
+                else:
+                    replay_buffer.append((state_before_action, action, reward, next_state, done))
+            else:
+                logger.debug(f"Step {step}: Skipped adding transition to buffer due to imputed bar handling (mode=skip).")
+
+            # --- Log imputed transition details to CSV --- #
+            # Check if the *previous* step was imputed and handled by hold/penalty
+            # env.current_step was incremented *inside* env.step()
+            was_imputed_idx = env.current_step - 1
+            if 0 <= was_imputed_idx < env.n_steps and env.bar_imputed[was_imputed_idx] and not is_skipped and imputed_csv_writer:
+                # Don't log if skipped; log if hold/penalty was applied
+                imputed_handling_mode = self.config.get('sac', {}).get('imputed_handling', 'unknown')  # dict-style access, consistent with config usage elsewhere
+                action_taken = env.current_position if imputed_handling_mode == 'hold' else action  # Action applied or agent's intended action
+                position_after_action = env.current_position
+                try:
+                    imputed_csv_writer.writerow([
+                        was_imputed_idx,  # Log the actual step number where imputation occurred
+                        imputed_handling_mode,
+                        f"{action_taken:.4f}",
+                        f"{reward:.6f}",  # Log the reward received for this step
+                        f"{position_before_action:.4f}",
+                        f"{position_after_action:.4f}"
+                    ])
+                except Exception as log_e:
+                    logger.warning(f"Failed to write imputed transition to CSV: {log_e}")
+            # --- End Log imputed transition --- #
+
+            state = next_state
+            total_reward += reward
+
+            # Perform SAC agent updates
+            if step >= update_after and step % update_every == 0:
+                for update_i in range(update_every):  # Perform multiple updates per interval
+                    if self.use_per:
+                        if len(replay_buffer) > batch_size:
+                            idxs, batch_data, is_weights = replay_buffer.sample(batch_size, beta=replay_buffer.beta)  # Use annealed beta
+                            # Unpack batch_data which contains (sample, is_seeded) tuples
+                            batch = [item[0] for item in batch_data]
+                            update_info = agent.update(batch, is_weights=is_weights, per_beta=replay_buffer.beta)
+                            if update_info and 'td_errors' in update_info:
+                                # Update priorities using current annealed alpha
+                                current_alpha = agent.get_current_per_alpha(step)
+                                replay_buffer.update_priorities(idxs, update_info['td_errors'], alpha=current_alpha)
+                        else:
+                            continue  # Not enough samples yet for PER
+                    else:  # Standard buffer
+                        if len(replay_buffer) > batch_size:
+                            indices = np.random.choice(len(replay_buffer), size=batch_size, replace=False)
+                            batch = [replay_buffer[i] for i in indices]
+                            update_info = agent.update(batch)
+                        else:
+                            continue  # Not enough samples yet
+
+                    # Log training metrics (losses, Q-values, alpha) to TensorBoard
+                    if update_info and step % log_freq == 0 and update_i == 0:  # Log once per interval
+                        with tb_callback.writer.as_default():
+                            for key, value in update_info.items():
+                                if key != 'td_errors':  # Don't log TD errors directly
+                                    tf.summary.scalar(f'sac/{key}', value, step=step)
+                            # logger.debug(f"Step {step}: Logged SAC metrics to TensorBoard.")
+
+            # Check environment done state
+            if done:
+                state = env.reset()
+                # Normalize reset state
+                if self.state_filter:
+                    state = self.state_filter(state, update=True)
+                total_reward = 0.0  # Reset episodic reward
+
+            # Save agent checkpoint periodically
+            if step % save_freq == 0 and step > 0:
+                save_path = os.path.join(self.sac_run_models_dir, f'sac_agent_step_{step}')
+                try:
+                    agent.save(save_path)
+                    # Also save state filter if used
+                    if self.state_filter:
+                        filter_path = os.path.join(save_path, 'state_filter.pkl')
+                        joblib.dump(self.state_filter, filter_path)
+                        logger.info(f"Saved state filter to 
{filter_path}") + except Exception as e: + logger.error(f"Failed to save agent/filter checkpoint at step {step} to {save_path}: {e}", exc_info=True) + + # --- Final Save --- # + final_save_path = os.path.join(self.sac_run_models_dir, 'sac_agent_final') + try: + agent.save(final_save_path) + if self.state_filter: + filter_path = os.path.join(final_save_path, 'state_filter.pkl') + joblib.dump(self.state_filter, filter_path) + logger.info(f"Saved final state filter to {filter_path}") + self.last_saved_agent_path = final_save_path # Store path for potential return + except Exception as e: + logger.error(f"Failed to save final agent/filter checkpoint to {final_save_path}: {e}", exc_info=True) + self.last_saved_agent_path = None + + end_time = time.time() + training_duration = end_time - start_time + logger.info(f"SAC training loop finished in {training_duration:.2f} seconds.") + logger.info(f"Final agent checkpoint saved to: {self.last_saved_agent_path}") + + # --- Close imputed transition log file --- # + if imputed_log_file: + try: + imputed_log_file.close() + logger.info(f"Closed imputed transition log file: {imputed_log_path}") + except Exception as e: + logger.error(f"Error closing imputed transition log file: {e}") + # --- End close log file --- # + + return self.last_saved_agent_path def train(self, gru_run_id_for_sac: str) -> str | None: """ @@ -918,12 +1009,10 @@ class SACTrainer: initial_lr=self.sac_cfg.get('actor_lr', 3e-4), lr_decay_rate=self.sac_cfg.get('lr_decay_rate', 0.96), decay_steps=self.sac_cfg.get('decay_steps', 100000), - buffer_capacity=self.sac_cfg.get('buffer_max_size', 100000), ou_noise_stddev=self.sac_cfg.get('ou_noise_stddev', 0.2), alpha=self.sac_cfg.get('alpha', 0.2), alpha_auto_tune=self.sac_cfg.get('alpha_auto_tune', True), target_entropy=self.sac_cfg.get('target_entropy', -1.0 * env.action_dim), - min_buffer_size=self.sac_cfg.get('min_buffer_size', 1000), edge_threshold_config=current_edge_threshold, # Pass edge threshold # --- Pass Env Params (Task 5.6) --- # reward_scale_config=reward_scale,