feat: Add imputed transition logging to SACTrainer and update docs

This commit is contained in:
yasha 2025-04-19 02:18:27 +00:00
parent c3526bb9f6
commit 35c60ae848
3 changed files with 741 additions and 602 deletions

View File

@ -1,6 +1,6 @@
# GRU + SAC Crypto Trading System (v3 - Consolidated & Enhanced) # GRU + SAC Crypto Trading System (v3 - Refactored & Enhanced)
This project implements a cryptocurrency trading system using a GRU model for market prediction and a Soft Actor-Critic (SAC) agent for position sizing. This version reflects significant refactoring, feature additions, and the consolidation of model logic. This project implements a cryptocurrency trading system using a GRU model for market prediction and a Soft Actor-Critic (SAC) agent for position sizing. This version reflects significant refactoring for modularity, feature additions, and the consolidation of model logic.
The core idea is to decouple prediction and action: The core idea is to decouple prediction and action:
1. A **GRU model** (v2 or v3 architecture, selected via config) forecasts future log-returns (μ̂) and class probabilities (binary p(up) or ternary p(down, flat, up)). 1. A **GRU model** (v2 or v3 architecture, selected via config) forecasts future log-returns (μ̂) and class probabilities (binary p(up) or ternary p(down, flat, up)).
@ -11,32 +11,71 @@ This approach aims for a robust system where the RL agent focuses solely on risk
## Key Features & Enhancements ## Key Features & Enhancements
* **Consolidated GRU Logic:** Both v2 and v3 GRU model architectures are now implemented and managed within `src/gru_model_handler.py`. * **Modular Pipeline Structure:** Pipeline logic is refactored into stage-specific functions within `src/pipeline_stages/` for improved readability, maintainability, and testability (see Project Structure).
* **Walk-Forward Validation:** Replaces static train/val/test splits with a robust walk-forward validation framework (`TradingPipeline.execute`, `_generate_walk_forward_folds`) for more realistic performance estimation. * **Consolidated GRU Logic:** Both v2 and v3 GRU model architectures are implemented and managed within `src/gru_model_handler.py`.
* **Hyperparameter Optimization (Optuna):** Integrated Optuna sweep for GRU hyperparameters (`src/gru_hyper_tuner.py`) with restricted search space, configurable objective (`edge_acc - brier`), and Keras callback for efficient pruning based on `val_loss`. * **Walk-Forward Validation:** Robust walk-forward validation framework (`TradingPipeline.execute`, `_generate_walk_forward_folds`) for realistic performance estimation.
* **Hyperparameter Optimization (Optuna):** Integrated Optuna sweep for GRU hyperparameters (`src/gru_hyper_tuner.py`) with configurable search space, objective, and pruning.
* **Advanced Calibration:** * **Advanced Calibration:**
* Supports Temperature and Vector Scaling (`calibration.method`) with optional L2 regularization (`calibration.l2_lambda`). * Supports Temperature and Vector Scaling (`calibration.method`) with optional L2 regularization (`calibration.l2_lambda`).
* Optimizes edge threshold via Youden's J on validation data (`calibration.optimize_edge_threshold`). Saved per fold (`optimized_edge_threshold_fold_N.txt`) and used consistently. * Optimizes edge threshold via Youden's J on validation data (`calibration.optimize_edge_threshold`).
* **Rolling Calibration (Experimental):** Implemented within `Backtester` to refit the calibrator periodically during the backtest (`calibration.rolling_enabled`, `recalibrate_every_n`, `recalibration_window`). Uses the **static** calibration from training time if SAC training is active to prevent lookahead. * **Rolling Calibration (Experimental):** Implemented within `Backtester` to refit the calibrator periodically during the backtest (`calibration.rolling_enabled`, `recalibrate_every_n`, `recalibration_window`).
* **Coverage Alarm (ECE-based):** Optional alarm triggers early recalibration if **Expected Calibration Error (ECE)** exceeds a threshold (`calibration.coverage_alarm_enabled`, `ece_recalibration_threshold`). * **Coverage Alarm (ECE-based):** Optional alarm triggers early recalibration if **Expected Calibration Error (ECE)** exceeds a threshold (`calibration.coverage_alarm_enabled`, `ece_recalibration_threshold`).
* **Prioritized Experience Replay (PER):** Implemented for SAC training (`sac.use_per`) with **TD-error clipping** and **alpha annealing** (linear decay). Logs TD error distribution statistics. * **Prioritized Experience Replay (PER):** Implemented for SAC training (`sac.use_per`) with TD-error clipping and alpha annealing.
* **SAC Enhancements:** Reward scaling, state normalization (`MeanStdFilter`), configurable **action penalty** (default: `0.01 / transaction_cost`), oracle seeding with **Importance Sampling (IS) weight decay** (`per_seed_decay_steps`). * **SAC Enhancements:** Reward scaling, state normalization (`MeanStdFilter`), configurable action penalty, oracle seeding with Importance Sampling (IS) weight decay.
* **Refined Validation Gates:** Configurable thresholds (`validation_gates`) for: * **Refined Validation Gates:** Configurable thresholds (`validation_gates`) for baseline checks, GRU validation, fold backtest performance, and final release decisions.
* **Baseline Gate:** Checks Logistic Regression CI on **raw/engineered** training features before scaling. * **Micro-structure Features:** Added bar-level features (`FeatureEngineer._add_microstructure_features`) with NaN guards.
* **GRU Gate:** Checks Edge Acc CI and Brier score on validation set after calibration, using the fold's **determined edge threshold**. * **Leakage Guard:** Feature calculations use `shift(1)`. Selection includes correlation check against future returns. Minimal whitelist applied before VIF.
* **Final Release Decision:** Checks aggregated metrics (e.g., **median Sharpe ≥ 1.3**, **≥ 75% successful folds**) across all successful folds. Backtest gate failures *per fold* are logged but do not halt the entire pipeline. * **Configuration:** Centralized and expanded `config.yaml`.
* **Micro-structure Features:** Added bar-level features (`FeatureEngineer._add_microstructure_features`) with **NaN guards** for robustness.
* **Leakage Guard:** Feature calculations use `shift(1)`. Selection occurs on **raw/engineered features** and includes correlation check (`corr(ret+h, feat_t-1)`) against future returns. Minimal whitelist applied before VIF.
* **Configuration:** Centralized and expanded `config.yaml` with annotations for mutually exclusive options (e.g., walk-forward vs static split).
* **Output Management:** Standardized output structure via `IOManager` and `LoggerSetup`. * **Output Management:** Standardized output structure via `IOManager` and `LoggerSetup`.
* **SAC Agent Aggregation:** Optional post-processing step to average weights from agents trained across successful folds (`TradingPipeline.aggregate_sac_agents`, `sac_aggregation.enabled`). * **SAC Agent Aggregation:** Optional post-processing step to average weights from agents trained across successful folds.
## Data Quality
This system includes mechanisms to handle potential missing data points (bars) in the input time series. This is crucial for maintaining data integrity and preventing errors during feature engineering and model training.
**Handling Missing Bars:**
* **Detection:** The pipeline automatically detects missing bars based on the expected `data.bar_frequency` (e.g., "1T" for 1 minute) after initial data loading.
* **Reporting:** A warning is logged detailing the total number of missing bars found and the length of the longest consecutive gap. A summary report (`missing_bars_summary.json`) is saved in the run's results directory.
* **Filling Strategies:** Several strategies are available, configured via `data.missing.strategy`:
* `"drop"`: No filling is performed. Missing bars remain gaps or NaNs. (Use with caution).
* `"neutral"`: Forward-fills the 'close' price, sets 'open', 'high', 'low' equal to the filled 'close', and sets 'volume' to 0 for imputed bars.
* `"ffill"`: Forward-fills all OHLCV columns, then back-fills any remaining NaNs at the beginning.
* `"interpolate"`: Interpolates missing values using the method specified in `data.missing.interpolate.method` (e.g., 'linear') up to a `limit` defined in `data.missing.interpolate.limit`.
* **Imputed Flag:** After filling, a boolean column `bar_imputed` is added to the DataFrame, marking rows that were originally missing.
* **Max Gap Check:** The pipeline will raise an error if the longest detected consecutive gap exceeds `data.missing.max_gap`.
**Impact on Downstream Components:**
* **Feature Engineering:** Features are calculated on the potentially gap-filled data.
* **Sequence Creation:** Sequences containing imputed bars can be optionally dropped before GRU training, controlled by `gru.drop_imputed_sequences`.
* **GRU Model:** The `bar_imputed` flag is included as a feature input to the GRU model, allowing it to potentially learn patterns related to imputed data.
* **SAC Environment:** The `TradingEnv` is aware of imputed bars. The behavior during an imputed step is controlled by `sac.imputed_handling`:
* `"skip"`: The environment skips the step, no action is taken, no reward is given, and the transition is not added to the replay buffer.
* `"hold"`: The agent's action is overridden to maintain its current position. The step proceeds normally otherwise (reward calculated based on held position).
* `"penalty"`: The agent's chosen action is taken, but a penalty reward (based on `sac.action_penalty`) is applied instead of the normal PnL reward.
**Recommended Defaults:**
Using `"neutral"` or `"ffill"` for `strategy` is generally recommended for continuous time series. `max_gap` should be set to a reasonably small number (e.g., 5-10) to avoid filling excessively long gaps with potentially inaccurate data. For the SAC environment, `"hold"` or `"skip"` are common choices, depending on whether you want the agent to explicitly learn from imputed steps (or lack thereof).
## System Design & Workflow ## System Design & Workflow
The system is orchestrated by `run.py`, which sets up logging and I/O via `LoggerSetup` and `IOManager`, then instantiates and executes the `TradingPipeline` class (`src/trading_pipeline.py`). The pipeline follows a sequence of steps, potentially looped for Walk-Forward validation. The system is orchestrated by `run.py`, which sets up logging and I/O via `LoggerSetup` and `IOManager`, then instantiates and executes the `TradingPipeline` class (`src/trading_pipeline.py`). The pipeline follows a sequence of steps, potentially looped for Walk-Forward validation.
### Pipeline Stages (Refactored)
The core logic for each step in the pipeline has been moved into dedicated functions within the `src/pipeline_stages/` directory. The `TradingPipeline` class now acts primarily as an orchestrator, calling these stage functions in sequence and managing the overall state and data flow.
* **`src/pipeline_stages/data_processing.py`**: Handles loading, initial preprocessing, feature engineering, labeling, and data splitting logic for each fold.
* **`src/pipeline_stages/feature_processing.py`**: Manages feature scaling, selection (L1+VIF), and pruning based on the selected whitelist.
* **`src/pipeline_stages/sequence_creation.py`**: Creates input sequences suitable for the GRU model from the processed feature data.
* **`src/pipeline_stages/modelling.py`**: Contains functions for training/loading the GRU model (including hyperparameter tuning), calibrating probabilities (Temperature/Vector Scaling, edge threshold optimization), training/loading the SAC agent, and aggregating SAC agents.
* **`src/pipeline_stages/evaluation.py`**: Includes functions for running baseline checks (Logistic Regression), performing GRU validation checks (Edge Accuracy, Brier Score), and executing the main backtest simulation (instantiating and running the `Backtester`).
### Workflow Diagram
```mermaid ```mermaid
%%{init: {'themeVariables': { 'fontSize': '26px' }}}%%
graph TD graph TD
A[run.py: Init Logger/IOManager/Config] --> B(TradingPipeline); A[run.py: Init Logger/IOManager/Config] --> B(TradingPipeline);
@ -47,21 +86,21 @@ graph TD
D --> F[Select Fold Data]; D --> F[Select Fold Data];
E --> F; E --> F;
subgraph Fold Processing [Fold Processing] subgraph Fold Processing [Fold Processing - Calls Stage Functions]
direction TB direction TB
F --> G[Engineer Features]; F --> G[data_processing: Engineer Features];
G --> H[Define Labels & Align]; G --> H[data_processing: Define Labels & Align];
H --> I[Split Fold Data]; H --> I[data_processing: Split Fold Data];
I --> J1[Baseline Check]; I --> J1[evaluation: Baseline Check];
J1 -- Pass --> L[Select Features]; J1 -- Pass --> L[feature_processing: Select Features];
L --> K[Scale Features]; L --> K[feature_processing: Scale Features];
K --> M[Prune Scaled Features]; K --> M[feature_processing: Prune Scaled Features];
M --> N[Create Sequences]; M --> N[sequence_creation: Create Sequences];
N --> O[Train/Load GRU]; N --> O[modelling: Train/Load GRU];
O --> P[Calibrate Probabilities]; O --> P[modelling: Calibrate Probabilities];
P --> R[GRU Validation Gate]; P --> R[evaluation: GRU Validation Gate];
R -- Pass --> S[Train/Load SAC Agent]; R -- Pass --> S[modelling: Train/Load SAC Agent];
S --> T[Run Backtest]; S --> T[evaluation: Run Backtest];
T --> U[Record Fold Results]; T --> U[Record Fold Results];
end end
@ -73,172 +112,144 @@ graph TD
B --> Walk-ForwardLoop; B --> Walk-ForwardLoop;
X1 --> Y[Aggregate Fold Metrics]; X1 --> Y[Aggregate Fold Metrics];
Y --> Z[Aggregate SAC Agents]; Y --> Z[modelling: Aggregate SAC Agents];
Z --> Z1[Final Release Decision]; Z --> Z1[Final Release Decision];
Z1 --> Z_End([End Pipeline Run]); Z1 --> Z_End([End Pipeline Run]);
``` ```
*Diagram outlines the consolidated v3 pipeline flow after `revisions.txt` modifications.* *Diagram outlines the consolidated v3 pipeline flow, highlighting calls to stage functions.*
### Detailed Steps (Walk-Forward Enabled) ### Detailed Steps (Walk-Forward Enabled)
1. **Initialization (`run.py`):** 1. **Initialization (`run.py`):** Sets up infrastructure (config, logging, IO). Instantiates `TradingPipeline`.
* Parses args (`--config`, etc.). 2. **Data Loading (`TradingPipeline` calls `data_processing.load_and_preprocess`):** Loads the *entire* raw dataset.
* Loads `config.yaml`. 3. **Fold Generation (`TradingPipeline._generate_walk_forward_folds`):** Yields date ranges for each fold.
* Initializes `IOManager`, `LoggerSetup`. 4. **Fold Loop (`TradingPipeline.execute`):** Iterates through folds.
* Instantiates `TradingPipeline`. * **Select Fold Data:** Extracts raw data for the current fold range.
2. **Data Loading (`TradingPipeline.load_and_preprocess_data`):** * **Feature Engineering (`TradingPipeline` calls `data_processing.engineer_features_for_fold`):** Computes features on fold data.
* Loads the *entire* raw dataset specified in the config. * **Labeling (`TradingPipeline` calls `data_processing.define_labels_and_align_fold`):** Calculates labels.
3. **Fold Generation (`TradingPipeline._generate_walk_forward_folds`):** * **Split Fold Data (`TradingPipeline` calls `data_processing.split_data_fold`):** Splits into `train`, `val`, `test`.
* Based on `walk_forward` config (train/val/test/step days), yields date ranges for each fold. * **Baseline Gate (`TradingPipeline` calls `evaluation.run_baseline_checks_fold`):** Runs Logistic Regression check. Halts fold on failure.
4. **Fold Loop (`TradingPipeline.execute`):** Iterates through generated folds. * **Select Features (`TradingPipeline` calls `feature_processing.select_features_fold`):** Performs selection. Saves whitelist.
* **Select Fold Data:** Extracts raw data corresponding to the current fold's (Train+Val+Test) date range. * **Scale Features (`TradingPipeline` calls `feature_processing.scale_features_fold`):** Fits/applies scaler. Saves scaler.
* **Feature Engineering (`engineer_features`):** Computes base, TA, and micro-structure features (with NaN guards) on the fold's raw data (using `shift(1)` for time-dependent features). * **Prune Features (`TradingPipeline` calls `feature_processing.prune_features_fold`):** Prunes scaled data using whitelist.
* **Labeling (`define_labels_and_align`):** Calculates forward returns and target labels (binary/ternary) for the fold's engineered data. * **Create Sequences (`TradingPipeline` calls `sequence_creation.create_sequences_fold`):** Creates GRU input sequences.
* **Split Fold Data (`split_data`):** Splits the fold's labeled data into `train`, `val`, and `test` sets based on the fold's date ranges. Stores results like `self.X_train_raw`, `self.y_val`, `self.df_test_original`. * **Train/Load GRU (`TradingPipeline` calls `modelling.train_or_load_gru_fold`):** Handles training, Optuna sweep, or loading. Handles re-processing (scale/prune/sequence) internally if loaded scaler differs. Saves model/params.
* **Baseline Gate (`run_baseline_checks`):** Trains/validates Logistic Regression on fold's ***raw/engineered*** training features. Exits fold if CI lower bound < config threshold (`validation_gates.baseline`). Saves `baseline_report_fold_N.txt`. * **Calibrate Probabilities (`TradingPipeline` calls `modelling.calibrate_probabilities_fold`):** Fits calibrator, optimizes edge threshold. Saves parameters.
* **Select Features (`select_and_prune_features` - Selection Part):** Performs leakage check (`corr(ret+h, feat_t-1)`) and L1 + VIF selection (applying minimal whitelist *before* VIF) on fold's *raw/engineered* training features. Saves `final_whitelist_fold_N.json`. * **GRU Validation Gate (`TradingPipeline` calls `evaluation.run_gru_validation_checks_fold`):** Checks calibrated validation predictions. Halts fold on failure.
* **Scale Features (`scale_features`):** Fits `StandardScaler` on fold's *raw* training features (numeric only). Scales train, val, test features (`X_train_scaled`, etc.). Saves `feature_scaler_fold_N.joblib`. * **Train/Load SAC (`TradingPipeline` calls `modelling.train_or_load_sac_fold`):** Handles SAC training (calling `SACTrainer`) or determines load path.
* **Prune Features (`select_and_prune_features` - Pruning Part):** Prunes the *scaled* data splits (`X_train_scaled` -> `X_train_pruned`) using the `final_whitelist` determined earlier. * **Run Backtest (`TradingPipeline` calls `evaluation.run_backtest_fold`):** Instantiates `Backtester`, runs simulation, handles rolling calibration, performs backtest validation checks. Halts fold on failure.
* **Create Sequences (`create_sequences`):** Converts the fold's *pruned, scaled* train/val/test sets into sequences (`X_train_seq`, etc.). * **Store Fold Results:** Appends metrics and SAC agent path (if trained) for aggregation.
* **Train/Load GRU (`train_or_load_gru`):** 5. **Aggregate Metrics (`TradingPipeline.aggregate_fold_metrics`):** Calculates summary statistics across successful folds.
* If `sweep_enabled`, runs `GRUHyperTuner` (with updated objective, restricted search space, Keras callback for pruning on `val_loss`, logging objective components) to find best hyperparameters using fold's train/val sequences. Trains final model with best params. Saves best params JSON and Optuna plots. 6. **Aggregate SAC Agents (`TradingPipeline` calls `modelling.aggregate_sac_agents`):** (Optional) Averages weights of successful fold agents.
* If not sweeping, trains/loads GRU using config defaults. 7. **Final Release Decision (`TradingPipeline.final_release_decision`):** Evaluates aggregated metrics against final criteria.
* Saves the final fold GRU model (`gru_model_fold_N.keras`), history, and learning curve plot. 8. **Log Final Status:** Reports overall pipeline success/failure.
* **Calibrate Probabilities (`calibrate_probabilities`):**
* Fits the selected calibrator (Temp/Vector with optional L2 reg) on the fold's validation sequences.
* If `optimize_edge_threshold`, calculates optimal threshold using Youden's J, stores it internally (`self.optimized_edge_threshold`), and saves it (`optimized_edge_threshold_fold_N.txt`).
* Saves fold calibration parameters (`calibration_{temp/vector}_fold_N.npy`).
* **GRU Validation Gate (`_perform_gru_validation_checks`):** Checks edge-filtered accuracy CI and Brier score against updated config thresholds (`validation_gates.gru`) using the fold's determined `optimized_edge_threshold`. Exits fold if failed.
* **Train/Load SAC (`train_or_load_sac`):**
* If `train_sac`, initializes `SACTrainer` using a config copy (passing the fold's `optimized_edge_threshold` and disabling rolling calibration if active). Trains agent handling PER (with clipping, alpha annealing), Oracle Seeding (with IS weight decay), State Normalization, Action Penalty (`0.01/cost`). Saves agent, filter state, logs, plots in a fold-specific `sac_train_...` dir.
* If loading, determines path from config.
* **Run Backtest (`run_backtest`):**
* Initializes `Backtester`.
* Passes the fold's SAC agent path, test sequences, GRU handler, *initial* calibration state, the fold's `optimized_edge_threshold`, and original test prices.
* If `rolling_enabled`, the backtester uses raw predictions to refit calibrator, potentially triggered early by **ECE Coverage Alarm** (`ECE > config threshold`).
* Logs backtest performance (Sharpe, MDD, Win Rate) and gate pass/fail status. Fold failure here does *not* halt the pipeline immediately but is recorded.
* **Store Fold Results:** Appends the `backtest_metrics` dict (including status) to `all_fold_metrics`. Stores the `sac_agent_load_path` if SAC was trained successfully.
5. **Aggregate Metrics (`aggregate_fold_metrics`):** Calculates summary statistics (mean, std, min, max, median) across metrics from all *successful* folds. Saves `aggregated_wf_metrics.json`.
6. **Aggregate SAC Agents (`aggregate_sac_agents`):** (Optional: if `sac_aggregation.enabled`)
* Loads SAC agents from the stored paths of successful folds.
* Averages the weights of the loaded agents.
* Saves the aggregated agent to the main run's model dir (`models/run_.../sac_agent_aggregated/`). Saves `sac_aggregation_info.txt`.
7. **Final Release Decision (`final_release_decision`):** Evaluates aggregated metrics against overall release criteria defined in `validation_gates.final_release` (e.g., min % successful folds, median Sharpe).
8. **Log Final Status:** Logs whether the pipeline passed or failed the final release criteria.
*(Note: If Walk-Forward is disabled (`walk_forward.enabled=false`), the pipeline runs steps D-T once using static splits based on `split_ratios` config).* *(Note: If Walk-Forward is disabled, the pipeline runs the fold processing steps once using static splits).*\
## Project Structure
```
gru_sac_predictor/
├── config/
│ └── config.yaml # Main configuration file
├── data/ # Data storage (e.g., parquet files)
├── logs/ # Log output directory (run-specific subdirs)
├── models/ # Saved models/scalers etc. (run-specific subdirs)
├── results/ # Output results, metrics, plots (run-specific subdirs)
├── src/
│ ├── pipeline_stages/ # **Refactored stage-specific logic**
│ │ ├── data_processing.py
│ │ ├── evaluation.py
│ │ ├── feature_processing.py
│ │ ├── modelling.py
│ │ └── sequence_creation.py
│ ├── __init__.py
│ ├── backtester.py # Backtesting simulation engine
│ ├── baseline_checker.py # Baseline logistic regression check
│ ├── calibrator.py # Temperature scaling
│ ├── calibrator_vector.py # Vector scaling
│ ├── data_loader.py # Loads raw data
│ ├── feature_engineer.py # Feature creation logic
│ ├── features.py # Feature lists/definitions (optional)
│ ├── gru_hyper_tuner.py # Optuna hyperparameter tuner for GRU
│ ├── gru_model_handler.py # GRU model building, training, loading (v2/v3)
│ ├── io_manager.py # Handles file I/O structure
│ ├── logger_setup.py # Configures logging
│ ├── metrics.py # Performance metrics calculation
│ ├── sac_agent.py # SAC agent network definitions
│ ├── sac_trainer.py # Offline SAC training orchestration
│ ├── trading_env.py # Gym-like environment for SAC training
│ ├── trading_pipeline.py # Main pipeline orchestrator class
│ └── utils/ # Utility functions (e.g., run_id generation)
├── tests/ # Unit/integration tests (optional)
├── run.py # Main execution script
├── requirements.txt # Python dependencies
└── README.md # This file
```
## Core Components Architecture ## Core Components Architecture
This section details the architecture and purpose of key modules. * **`run.py`**: Entry point, sets up IO/logging, runs the pipeline.
* **`TradingPipeline` (`src/trading_pipeline.py`):** Orchestrates the workflow by calling stage functions, manages overall state, handles walk-forward loop and validation gates.
* **`src/pipeline_stages/*.py`**: Contain the core implementation logic for each distinct step of the pipeline (data processing, feature processing, sequencing, modelling, evaluation).
* **`IOManager`, `LoggerSetup`**: Utilities for managing outputs and logging.
* **`DataLoader`, `FeatureEngineer`**: Data loading and feature generation.
* **`GRUModelHandler`, `GRUHyperTuner`**: GRU model implementation (v2/v3), training, loading, and Optuna tuning.
* **`Calibrator`, `VectorCalibrator`**: Probability calibration logic.
* **`SACTradingAgent`, `TradingEnv`, `SACTrainer`**: SAC agent definition, training environment, and offline training orchestration (including PER, normalization, seeding).
* **`Backtester`, `BaselineChecker`, `metrics.py`**: Backtesting simulation, baseline checks, and performance metric calculations.
### 1. Orchestration & Utilities (`TradingPipeline`, `IOManager`, `LoggerSetup`)
* **`TradingPipeline` (`src/trading_pipeline.py`):** The main class coordinating the workflow (data prep, feature eng, model training, calibration, backtesting, aggregation) potentially within a walk-forward loop. Reads config, manages data flow, implements validation gates, and calls other components.
* **`IOManager` (`src/io_manager.py`):** Handles standardized file I/O (saving/loading models, dataframes, scalers, configs, plots, reports) within a run-specific directory structure.
* **`LoggerSetup` (`src/logger_setup.py`):** Configures Python's `logging` for console and file output with standardized formats and levels.
### 2. Data Handling (`DataLoader`, `FeatureEngineer`)
* **`DataLoader` (`src/data_loader.py`):** Loads raw OHLCV data from specified database files/sources.
* **`FeatureEngineer` (`src/feature_engineer.py`):** Generates features:
* **TA:** Returns, ATR, EMA, RSI (MACD removed).
* **Cyclical:** Hour, Week (sin/cos).
* **Imbalance:** Chaikin AD, SVI, Gap Imbalance.
* **Micro-structure:** Spread Proxy, Vol-Norm Volume Spike, Return Asymmetry, Close-Location Value, Keltner Band Position (with NaN guards).
* **Leakage Guard:** Uses `shift(1)` on inputs for time-dependent features.
* **Target Definition:** Calculates forward returns and binary/ternary labels.
* **Selection/Pruning:** Performs **leakage check** (`corr(ret+h, feat_t-1)`), then L1+VIF selection on *raw/engineered* data (applying minimal whitelist before VIF), then prunes *scaled* data based on the selection.
### 3. GRU Predictor (`gru_model_handler.py`, `gru_hyper_tuner.py`)
* **`gru_model_handler.py`:** **Consolidated GRU implementation.**
* Contains builders for both v2 (`build_gru_model`) and v3 (`build_gru_model_v3`) architectures.
* **v3 Architecture:** GRU -> LayerNorm -> Optional MultiHeadAttention -> GlobalAvgPool -> Dense Heads (`mu`, `dir3_logits`). Includes L2 regularization. Uses Huber loss for `mu`, Focal loss for `dir3`.
* Manages training (with early stopping, CSV logging), saving (.keras format), loading (handles custom loss `gaussian_nll`), and prediction (`predict`, `predict_logits`).
* Selects v2/v3 based on `control.use_v3` flag.
* **`gru_hyper_tuner.py`:** Implements Optuna hyperparameter sweep for GRU.
* Called by `TradingPipeline` if sweep is enabled.
* Uses **restricted search space** based on config (`hyperparameter_tuning.gru`).
* Uses combined objective based on config (`objective_metric`, `objective_edge_acc_weight`, `objective_brier_weight`). Logs components to trial attributes.
* Supports pruning via **Keras callback** reporting `val_loss` each epoch.
* Trains final fold model using best found parameters.
* Saves best parameters JSON and Optuna plots per fold.
### 4. Probability Calibration (`Calibrator`, `VectorCalibrator`, `metrics.py`)
* **`Calibrator` (`src/calibrator.py`):** Implements Temperature Scaling (learns scalar `T`) with optional L2 regularization.
* **`VectorCalibrator` (`src/calibrator_vector.py`):** Implements Vector Scaling (learns matrix `W`, bias `b`) with optional L2 regularization. Preferred for ternary.
* **Integration:** `TradingPipeline.calibrate_probabilities` fits chosen calibrator per fold (with L2 reg). If `optimize_edge_threshold=true`, calculates and stores optimal edge for the fold. `Backtester` applies calibration step-by-step, handles rolling recalibration if enabled (with **ECE-based coverage alarm**), using the **static fold calibration** if SAC training was active for the fold.
* **Edge Threshold Optimization:** `metrics._calculate_optimal_edge_threshold` finds best threshold using Youden's J. Called by `TradingPipeline` if enabled.
### 5. SAC Agent (`sac_agent.py`, `sac_trainer.py`, `trading_env.py`)
* **`sac_agent.py`:** Defines the SAC agent networks (Actor, Critic) and update logic.
* Actor outputs squashed Gaussian distribution parameters.
* Uses twin Q-critics.
* Handles automatic entropy tuning (alpha).
* `train` method accepts a batch and returns losses + TD errors (for PER).
* **`trading_env.py`:** Gym-style environment using GRU predictions.
* **State:** `[mu, sigma, edge, |mu|/sigma, position]`.
* Takes action (-1 to +1).
* Calculates reward based on PnL, potentially scaled (`reward_scale`) and penalized (**action penalty**, default: `0.01 / transaction_cost`).
* **`sac_trainer.py`:** Orchestrates offline SAC training.
* Loads GRU dependencies for a specific run.
* Prepares validation data for the `TradingEnv` (uses fold's static calibration).
* Initializes `TradingEnv` and `SACTradingAgent`.
* Manages the Replay Buffer:
* Implements `PrioritizedReplayBuffer` if `sac.use_per` is true (with **TD error clipping** and **alpha annealing**). Logs TD error distribution stats.
* Uses `collections.deque` for standard uniform replay.
* Handles **Oracle Seeding** into the buffer (with **IS weight decay** for seeded samples).
* Runs the training loop: interacts with env, stores transitions, samples batches (uniform or PER), updates agent, updates PER priorities.
* Handles **State Normalization** using `MeanStdFilter` if `sac.use_state_filter` is true (saves/loads filter state).
* Saves agent checkpoints, final agent, state filter, and logs (rewards, TensorBoard).
### 6. Evaluation (`Backtester`, `BaselineChecker`, `metrics.py`)
* **`Backtester` (`src/backtester.py`):** Evaluates the full system on test data.
* Takes trained GRU, SAC agent, initial calibration state, and the fold's edge threshold.
* Simulates step-by-step trading.
* Applies **rolling calibration** logic if enabled (with ECE check).
* Calculates PnL, equity curve, standard performance metrics.
* Saves detailed results dataframe, metrics summary, and plots per fold.
* Logs performance and whether fold backtest gates passed/failed, but does **not** halt the fold on failure (decision made during final aggregation).
* **`BaselineChecker` (`src/baseline_checker.py`):** Performs initial logistic regression check on **raw/engineered** training features.
* **`metrics.py`:** Contains calculation functions for Sharpe, Brier, Edge-Filtered Accuracy, ECE, and the Youden's J optimization helper.
## Configuration (`config.yaml`) ## Configuration (`config.yaml`)
The `config.yaml` file centrally controls the pipeline's behavior. Key sections include: The `config.yaml` file centrally controls the pipeline's behavior. See comments within the default `config.yaml` for detailed explanations of each parameter. Key sections include:
* `base_dirs`: Output directories. * `base_dirs`, `output`: Directory and output settings.
* `output`: Figure DPI, size, logging level. * `data`, `features`: Data sources, labeling, feature selection controls.
* `data`: Data source details, label smoothing. * `walk_forward`, `split_ratios`: Controls walk-forward vs. static splits.
* `features`: Minimal feature whitelist, leakage threshold. * `gru`, `gru_v3`: GRU architecture, training parameters.
* `walk_forward`: Settings for WF validation (enable, days, step). **Note:** `walk_forward.enabled=true` overrides `split_ratios`. * `hyperparameter_tuning`: Optuna sweep settings for GRU.
* `split_ratios`: Used only if `walk_forward.enabled=false`. * `calibration`: Calibration method, parameters, rolling calibration, ECE alarm.
* `gru`: General GRU settings (horizon, lookback, ternary flag, flat sigma mult). * `validation_gates`: Thresholds for baseline, GRU, backtest, and final release checks.
* `gru_v3`: Specific hyperparameters for the v3 architecture (units, attention, losses, reg). * `sac`, `environment`: SAC agent hyperparameters, PER, seeding, environment settings.
* `hyperparameter_tuning`: Controls Optuna sweep for GRU (enable, trials, timeout, pruning, objective metric/weights). * `sac_aggregation`: Agent averaging settings.
* `calibration`: Method (temp/vector), L2 lambda, optimize edge threshold flag, rolling calibration settings (enable, freq, window, ECE alarm). * `control`: High-level flags (train/load models, enable plots, use v3 GRU).
* `validation_gates`: Thresholds for baseline, GRU gates, and final release decision (median Sharpe, % success).
* `sac`: SAC hyperparameters (gamma, tau, LR, alpha, PER settings, oracle seeding, IS weight decay steps, state filter). ## Installation
* `sac_aggregation`: Controls post-run agent averaging (enable, method).
* `environment`: Trading env parameters (capital, costs, reward scale, action penalty lambda). 1. Clone the repository.
* `control`: Flags to enable/disable major stages (train GRU, train SAC, run backtest, use v3, plots), model loading/resuming IDs. 2. Ensure you have Python 3.8+ installed.
3. Set up a virtual environment (recommended):
```bash
python -m venv .venv
source .venv/bin/activate # On Windows use `.venv\\Scripts\\activate`
```
4. Install dependencies:
```bash
pip install -r requirements.txt
```
*Note: This installs necessary libraries like TensorFlow, PyTorch, Optuna, scikit-learn, pandas, etc.*
5. Prepare your data according to the expected format and update paths in `config.yaml`.
## Usage ## Usage
1. **Setup:** Install requirements (`pip install -r requirements.txt`), prepare data in the specified format/location. 1. **Configure:** Edit `config.yaml` to set data paths, feature lists, model parameters, control flags (e.g., `train_gru`, `train_sac`), walk-forward settings, validation thresholds, etc.
2. **Configure:** Edit `config.yaml` (data paths, feature lists, model params, control flags, walk-forward settings, calibration, validation thresholds, tuning, aggregation). 2. **Run Pipeline:** Execute `run.py` from the project root directory (`gru_sac_predictor/`), specifying the configuration file:
3. **Run Pipeline:**
```bash ```bash
# From project root (develop/gru_sac_predictor/) # Example execution from the parent directory 'develop/gru_sac_predictor/'
python gru_sac_predictor/run.py --config path/to/your_config.yaml python gru_sac_predictor/run.py --config gru_sac_predictor/config/config.yaml
``` ```
4. **Outputs:** Check `logs/`, `models/`, `results/` directories for run-specific outputs, including fold-specific artifacts if WF is enabled. * You can use other command-line arguments like `--use-ternary` if implemented in `run.py`.
3. **Outputs:** Check the directories specified in `config.yaml` (typically subdirectories within `logs/`, `models/`, `results/`) for run-specific outputs. If Walk-Forward is enabled, you will find fold-specific subdirectories containing models, scalers, plots, and results for each fold.
## Output Artifacts (Walk-Forward Enabled Example) ## Output Artifacts (Walk-Forward Enabled Example)
* **Main Run Dirs:** `logs/run_<id>/`, `models/run_<id>/`, `results/run_<id>/` * **Main Run Dirs:** `logs/run_<id>/`, `models/run_<id>/`, `results/run_<id>/`
* `run_config.yaml`, `pipeline_<id>.log` * `run_config.yaml`, `pipeline_<id>.log`
* (Post-run) `aggregated_wf_metrics.json`, `sac_aggregation_info.txt` * (Post-run) `aggregated_wf_metrics.json`, `sac_aggregation_info.txt` (if enabled)
* (Post-run) `models/.../sac_agent_aggregated/` * (Post-run) `models/.../sac_agent_aggregated/` (if enabled)
* **Fold Dirs (within main run dirs):** e.g., `models/run_<id>/fold_1/` * **Fold Dirs (within main run dirs):** e.g., `models/run_<id>/fold_1/`
* `models/run_<id>/fold_N/models/`: `gru_model_fold_N.keras`, `calibration_{...}_fold_N.npy`, `feature_scaler_fold_N.joblib`, `final_whitelist_fold_N.json` * `models/run_<id>/fold_N/models/`: `gru_model_fold_N.keras`, `calibration_{...}_fold_N.npy`, `feature_scaler_fold_N.joblib`, `final_whitelist_fold_N.json`
* `models/run_<id>/fold_N/hypertuning/`: (If sweep enabled) `best_gru_params.json`, Optuna plots. * `models/run_<id>/fold_N/hypertuning/`: (If sweep enabled) `best_gru_params.json`, Optuna plots.
@ -248,4 +259,17 @@ The `config.yaml` file centrally controls the pipeline's behavior. Key sections
## Dependencies ## Dependencies
See `requirements.txt`. Key libraries include: TensorFlow, NumPy, Pandas, PyYAML, Scikit-learn, Statsmodels, TA-Lib (via `ta` wrapper), Matplotlib, Seaborn, Optuna, PyTorch (for SAC aggregation). Note `tensorflow-addons` is required for optimal focal loss / attention layers. All major Python dependencies are listed in `requirements.txt`. Key libraries include:
* TensorFlow (for GRU)
* PyTorch (for SAC)
* Optuna (for hyperparameter tuning)
* scikit-learn (for scaling, metrics, baseline)
* pandas, numpy
* pyyaml
* matplotlib, seaborn
Install them using:
```bash
pip install -r requirements.txt
```

File diff suppressed because one or more lines are too long

View File

@ -19,6 +19,8 @@ from datetime import datetime
from tqdm import tqdm from tqdm import tqdm
from tensorflow.keras.callbacks import TensorBoard from tensorflow.keras.callbacks import TensorBoard
import collections import collections
import csv
import time
# Import necessary components from the pipeline # Import necessary components from the pipeline
# Use absolute imports assuming the package structure is correct # Use absolute imports assuming the package structure is correct
@ -208,6 +210,8 @@ class SACTrainer:
self.control_cfg = config.get('control', {}) self.control_cfg = config.get('control', {})
self.data_cfg = config['data'] self.data_cfg = config['data']
# --- Cache useful config values ---
self.use_ternary = self.config.get('gru', {}).get('use_ternary', False)
# --- Store PER config params --- # # --- Store PER config params --- #
self.use_per = self.sac_cfg.get('use_per', False) self.use_per = self.sac_cfg.get('use_per', False)
self.per_alpha = self.sac_cfg.get('per_alpha', 0.6) self.per_alpha = self.sac_cfg.get('per_alpha', 0.6)
@ -301,23 +305,85 @@ class SACTrainer:
# 3. Load GRU Model # 3. Load GRU Model
model_path = os.path.join(gru_run_models_dir, f"gru_model_{gru_run_id}.keras") model_path = os.path.join(gru_run_models_dir, f"gru_model_{gru_run_id}.keras")
# Need a temporary GRU handler instance to use its load method # Need a temporary GRU handler instance to use its load method
temp_gru_handler = GRUModelHandler(run_id="temp_load", models_dir="temp_load") # Pass self.config (the trainer's config) to the GRUModelHandler
dependencies['gru_model'] = temp_gru_handler.load(model_path) temp_gru_handler = GRUModelHandler(run_id="temp_load", models_dir="temp_load", config=self.config)
# Attempt to load the model using the handler
# The load method likely needs the full path, not just the directory and ID
loaded_model_info = temp_gru_handler.load(model_path=model_path) # Pass full path
# Adjust based on what gru_handler.load returns
# Assuming it returns (model, info_dict) or None
if loaded_model_info and isinstance(loaded_model_info, tuple) and len(loaded_model_info) > 0:
dependencies['gru_model'] = loaded_model_info[0] # Get the actual model
if dependencies['gru_model'] is None: if dependencies['gru_model'] is None:
logger.error(f"Failed to load GRU model from {model_path}") logger.error(f"GRU handler load method returned None for model from {model_path}")
return None return None
elif loaded_model_info and not isinstance(loaded_model_info, tuple): # If it just returns the model
dependencies['gru_model'] = loaded_model_info
else: # If it returned None or unexpected tuple
logger.error(f"Failed to load GRU model using handler from {model_path}")
return None
logger.info(f"Loaded GRU model from {model_path}") logger.info(f"Loaded GRU model from {model_path}")
# 4. Load Optimal Temperature # 4. Load Calibration Info and Parameters
temp_path = os.path.join(gru_run_models_dir, f"calibration_temp_{gru_run_id}.npy") calib_info_path = os.path.join(gru_run_models_dir, f"calibration_info_{gru_run_id}.json")
calib_params = None
calib_method = None
try: try:
dependencies['optimal_T'] = float(np.load(temp_path)) with open(calib_info_path, 'r') as f:
logger.info(f"Loaded optimal temperature T={dependencies['optimal_T']:.4f} from {temp_path}") calib_info = json.load(f)
logger.info(f"Loaded calibration info: {calib_info}")
calib_method = calib_info.get("method")
params_filename = calib_info.get("params_filename")
if calib_method and params_filename:
params_path = os.path.join(gru_run_models_dir, params_filename)
if os.path.exists(params_path):
if calib_method == "temperature":
calib_params = float(np.load(params_path))
logger.info(f"Loaded optimal temperature T={calib_params:.4f} from {params_path}")
elif calib_method == "vector":
# Load vector params (assuming saved as .npy)
calib_params = np.load(params_path, allow_pickle=True)
logger.info(f"Loaded vector calibration params from {params_path} (type: {type(calib_params)}, shape: {getattr(calib_params, 'shape', 'N/A')})")
else:
logger.warning(f"Unknown calibration method '{calib_method}' in info file. Cannot load params.")
else:
logger.error(f"Calibration parameter file specified in info ({params_filename}) not found at {params_path}")
return None # Fail if params file missing
elif calib_method in ["temperature_failed", "vector_failed", "skipped_mismatch", "temperature_error", "vector_error", "vector_unavailable"]:
logger.warning(f"Calibration was not successful or was skipped during GRU run (method: {calib_method}). SAC training may be suboptimal.")
# Decide if this is fatal. For now, let's allow it but store None.
calib_params = None
else:
logger.error(f"Invalid calibration info content: {calib_info}")
return None # Fail if info is invalid
except FileNotFoundError:
logger.error(f"Calibration info file not found at {calib_info_path}. Cannot determine calibration parameters.")
# Check for legacy temperature file for backward compatibility?
legacy_temp_path = os.path.join(gru_run_models_dir, f"calibration_temp_{gru_run_id}.npy")
if os.path.exists(legacy_temp_path):
logger.warning("Calibration info file missing, but found legacy temperature file. Attempting to load it.")
try:
calib_params = float(np.load(legacy_temp_path))
calib_method = "temperature" # Assume temperature
logger.info(f"Loaded legacy optimal temperature T={calib_params:.4f} from {legacy_temp_path}")
except Exception as legacy_e:
logger.error(f"Failed to load legacy temperature file {legacy_temp_path}: {legacy_e}")
return None # Fail if legacy load fails
else:
logger.error("Neither calibration info file nor legacy temperature file found.")
return None # Fail if no calibration info found
except Exception as e: except Exception as e:
logger.error(f"Failed to load optimal temperature from {temp_path}: {e}", exc_info=True) logger.error(f"Failed to load calibration info/parameters: {e}", exc_info=True)
# Allow continuation without T? Or require it? Let's require it for now.
return None return None
# Store method and params in dependencies
dependencies['calibration_method'] = calib_method
dependencies['calibration_params'] = calib_params
logger.info("--- Successfully loaded all GRU dependencies ---") logger.info("--- Successfully loaded all GRU dependencies ---")
return dependencies return dependencies
@ -349,107 +415,284 @@ class SACTrainer:
df_raw.dropna(subset=['open', 'high', 'low', 'close', 'volume'], inplace=True) df_raw.dropna(subset=['open', 'high', 'low', 'close', 'volume'], inplace=True)
logger.info("Loaded raw data.") logger.info("Loaded raw data.")
# 2. Engineer Base Features (using a temporary FeatureEngineer) # 2. Engineer Base Features (using config)
# Pass the *minimal* whitelist as a fallback if the loaded one causes issues # FIX: Instantiate FeatureEngineer correctly using self.config
temp_feature_engineer = FeatureEngineer(minimal_whitelist=minimal_whitelist) temp_feature_engineer = FeatureEngineer(config=self.config)
df_engineered = temp_feature_engineer.add_base_features(df_raw) df_engineered = temp_feature_engineer.add_base_features(df_raw)
df_engineered.dropna(inplace=True) # Drop NaNs after feature eng df_engineered.dropna(inplace=True) # Drop NaNs after feature eng
if df_engineered.empty: raise ValueError("Dataframe empty after feature engineering.") if df_engineered.empty: raise ValueError("Dataframe empty after feature engineering.")
logger.info("Engineered base features.") logger.info("Engineered base features.")
# 3. Prune Features using *loaded* whitelist # 3. Define Labels (on the full engineered data)
loaded_whitelist = gru_dependencies['whitelist']
missing_in_eng = [f for f in loaded_whitelist if f not in df_engineered.columns]
if missing_in_eng:
raise ValueError(f"Features from loaded whitelist missing in engineered data: {missing_in_eng}")
df_features = df_engineered[loaded_whitelist]
logger.info("Pruned features using loaded whitelist.")
# 4. Define Labels
horizon = self.config['gru'].get('prediction_horizon', 5) horizon = self.config['gru'].get('prediction_horizon', 5)
target_ret_col = f'fwd_log_ret_{horizon}' target_ret_col = f'fwd_log_ret_{horizon}'
target_dir_col = f'direction_label_{horizon}' target_dir_col = f'direction_label3_{horizon}' if self.use_ternary else f'direction_label_{horizon}'
df_engineered[target_ret_col] = np.log(df_engineered['close'].shift(-horizon) / df_engineered['close']) _EPS = 1e-9
if 'future_close' not in df_engineered.columns:
df_engineered['future_close'] = df_engineered['close'].shift(-horizon)
if 'future_close' in df_engineered.columns:
df_engineered[target_ret_col] = np.log(df_engineered['future_close'] / (df_engineered['close'] + _EPS))
else:
df_engineered[target_ret_col] = np.log(df_engineered['close'].shift(-horizon) / (df_engineered['close'] + _EPS))
if self.use_ternary:
flat_sigma = self.config.get('gru', {}).get('flat_sigma_multiplier', 0.3)
if 'ATR_14' in df_engineered.columns:
atr_aligned = df_engineered['ATR_14'].reindex(df_engineered.index).bfill().ffill()
threshold = flat_sigma * atr_aligned / (df_engineered['close'] + _EPS)
df_engineered[target_dir_col] = np.select(
[df_engineered[target_ret_col] > threshold, df_engineered[target_ret_col] < -threshold],
[2, 0], default=1
)
else:
raise ValueError("ATR_14 needed for ternary labels not found.")
else: # Binary case
df_engineered[target_dir_col] = (df_engineered[target_ret_col] > 0).astype(int) df_engineered[target_dir_col] = (df_engineered[target_ret_col] > 0).astype(int)
# Align by dropping NaNs in targets AND ensuring indices match features
# 4. Align Features and Targets
df_engineered.dropna(subset=[target_ret_col, target_dir_col], inplace=True) df_engineered.dropna(subset=[target_ret_col, target_dir_col], inplace=True)
common_index = df_features.index.intersection(df_engineered.index) # Identify *all* potential feature columns (excluding targets)
if common_index.empty: potential_feature_cols = [col for col in df_engineered.columns if col not in [target_ret_col, target_dir_col, 'future_close']]
raise ValueError("No common index between features and targets after label definition.") # Ensure whitelist features are present in potential features
df_features = df_features.loc[common_index] loaded_whitelist = gru_dependencies['whitelist']
df_targets = df_engineered.loc[common_index, [target_ret_col, target_dir_col]] missing_whitelist_check = set(loaded_whitelist) - set(potential_feature_cols)
if missing_whitelist_check:
logger.warning(f"Whitelist features missing from engineered columns: {missing_whitelist_check}. Adding with 0 fill.")
for col in missing_whitelist_check:
df_engineered[col] = 0.0
potential_feature_cols.append(col)
# Now df_engineered contains all potential features and targets, aligned
logger.info("Defined labels and aligned features/targets.") logger.info("Defined labels and aligned features/targets.")
# 5. Split Data (to get validation set indices) # 5. Split Data (using aligned engineered data)
split_cfg = self.config['split_ratios'] wf_enabled = self.config.get('walk_forward', {}).get('enabled', False)
train_ratio, val_ratio = split_cfg['train'], split_cfg['validation'] if wf_enabled: logger.warning("Walk-forward enabled, but SAC uses split_ratios.")
total_len = len(df_features) split_ratios = self.config.get('walk_forward', {}).get('split_ratios', {})
train_ratio = split_ratios.get('train', 0.6); val_ratio = split_ratios.get('validation', 0.2)
if not (0 < train_ratio < 1 and 0 < val_ratio < 1 and (train_ratio + val_ratio) <= 1.0):
raise ValueError(f"Invalid split ratios: train={train_ratio}, validation={val_ratio}")
total_len = len(df_engineered)
train_end_idx = int(total_len * train_ratio) train_end_idx = int(total_len * train_ratio)
val_end_idx = int(total_len * (train_ratio + val_ratio)) val_end_idx = int(total_len * (train_ratio + val_ratio))
val_indices = df_features.index[train_end_idx:val_end_idx] val_indices = df_engineered.index[train_end_idx:val_end_idx]
if val_indices.empty: raise ValueError("Validation split resulted in empty indices.") if val_indices.empty: raise ValueError("Validation split resulted in empty indices.")
X_val_pruned = df_features.loc[val_indices] df_val_aligned = df_engineered.loc[val_indices]
y_val = df_targets.loc[val_indices]
logger.info(f"Isolated validation set data (Features: {X_val_pruned.shape}, Targets: {y_val.shape}).")
# 6. Scale Validation Features using *loaded* scaler # -- Determine columns expected by scaler --
scaler = gru_dependencies['scaler'] scaler = gru_dependencies['scaler']
numeric_cols = X_val_pruned.select_dtypes(include=np.number).columns expected_scaler_features = []
X_val_scaled = X_val_pruned.copy() if hasattr(scaler, 'feature_names_in_'):
if not numeric_cols.empty: expected_scaler_features = scaler.feature_names_in_.tolist()
X_val_scaled[numeric_cols] = scaler.transform(X_val_pruned[numeric_cols]) logger.debug(f"Scaler expects features: {expected_scaler_features}")
logger.info("Scaled validation features using loaded scaler.") else:
# Fallback: Use all numeric columns if scaler has no names
logger.warning("Scaler has no feature names saved. Assuming it was fit on all numeric columns present at that time.")
# This is risky; need to ensure columns match implicitly
expected_scaler_features = df_val_aligned.select_dtypes(include=np.number).columns.tolist()
# Manually exclude known target columns if needed as a safeguard
expected_scaler_features = [f for f in expected_scaler_features if f not in [target_ret_col, target_dir_col]]
logger.debug(f"Falling back to using numeric columns for scaler: {expected_scaler_features}")
# 7. Create Validation Sequences # Ensure all expected scaler features exist in the validation slice
missing_for_scaler = set(expected_scaler_features) - set(df_val_aligned.columns)
if missing_for_scaler:
# Attempt to add missing with 0 fill (e.g., if a feature wasn't calculable for val split)
logger.warning(f"Columns expected by scaler missing from validation data: {missing_for_scaler}. Adding with 0 fill.")
for col in missing_for_scaler:
df_val_aligned[col] = 0.0
# Re-verify
missing_for_scaler = set(expected_scaler_features) - set(df_val_aligned.columns)
if missing_for_scaler:
raise ValueError(f"Could not prepare all features expected by scaler, still missing: {missing_for_scaler}")
# Isolate the exact features needed for the scaler transform
X_val_engineered_for_scaler = df_val_aligned[expected_scaler_features]
y_val = df_val_aligned[[target_ret_col, target_dir_col]]
logger.info(f"Isolated validation set data for scaler (Features: {X_val_engineered_for_scaler.shape}, Targets: {y_val.shape}).")
# 6. Scale the features expected by the scaler
X_val_scaled_full = X_val_engineered_for_scaler.copy()
try:
X_val_scaled_full[expected_scaler_features] = scaler.transform(X_val_engineered_for_scaler)
logger.info("Scaled validation features using loaded scaler.")
except ValueError as e:
logger.error(f"Error applying scaler transform even after aligning columns: {e}")
raise
# 7. WORKAROUND: Drop non-feature columns (like future_close) *after* scaling
# Identify actual features intended for the model (whitelist)
loaded_whitelist = gru_dependencies['whitelist']
cols_to_keep_after_scale = [f for f in loaded_whitelist if f in X_val_scaled_full.columns]
missing_whitelist_post_scale = set(loaded_whitelist) - set(cols_to_keep_after_scale)
if missing_whitelist_post_scale:
# This shouldn't happen if whitelist features were in expected_scaler_features
logger.warning(f"Whitelist features missing AFTER scaling: {missing_whitelist_post_scale}. This might indicate issues.")
if not cols_to_keep_after_scale:
raise ValueError("Whitelist resulted in no features remaining after scaling step.")
X_val_scaled_eng = X_val_scaled_full[cols_to_keep_after_scale].copy()
logger.info(f"Selected whitelisted features after scaling. Shape: {X_val_scaled_eng.shape}")
# 8. Pruning step is now effectively done by selecting whitelist columns above.
X_val_pruned_scaled = X_val_scaled_eng # Rename for clarity in subsequent steps
logger.info(f"Using whitelisted scaled features for sequencing. Final shape: {X_val_pruned_scaled.shape}")
# 9. Create Validation Sequences (using the pruned & scaled data)
lookback = self.config['gru']['lookback'] lookback = self.config['gru']['lookback']
X_val_seq = [] X_val_seq_list, y_val_seq_targets_list, val_seq_indices = [], [], []
y_val_seq_targets = [] # Store corresponding targets features_np_arr = X_val_pruned_scaled.values
val_seq_indices = [] # Store corresponding indices targets_np_arr = y_val.values
features_np = X_val_scaled.values for i in range(lookback, len(features_np_arr)):
targets_np = y_val.values # Contains both ret and dir X_val_seq_list.append(features_np_arr[i-lookback : i])
for i in range(lookback, len(features_np)): y_val_seq_targets_list.append(targets_np_arr[i])
X_val_seq.append(features_np[i-lookback : i])
y_val_seq_targets.append(targets_np[i]) # Target corresponds to end of sequence
val_seq_indices.append(y_val.index[i]) val_seq_indices.append(y_val.index[i])
if not X_val_seq: if not X_val_seq_list: raise ValueError("Validation sequence creation resulted in empty list.")
raise ValueError("Validation sequence creation resulted in empty list.") X_val_seq = np.array(X_val_seq_list)
y_val_seq_targets = np.array(y_val_seq_targets_list)
X_val_seq = np.array(X_val_seq) actual_ret_val_seq = y_val_seq_targets[:, 0]
y_val_seq_targets = np.array(y_val_seq_targets) y_dir_val_seq = y_val_seq_targets[:, 1]
actual_ret_val_seq = y_val_seq_targets[:, 0] # First column is return
y_dir_val_seq = y_val_seq_targets[:, 1] # Second column is direction
logger.info(f"Created validation sequences (X shape: {X_val_seq.shape}).") logger.info(f"Created validation sequences (X shape: {X_val_seq.shape}).")
# 8. Get GRU Predictions on Validation Sequences using *loaded* GRU model # 10. Get GRU Predictions using loaded GRU model directly
gru_model = gru_dependencies['gru_model'] gru_model = gru_dependencies['gru_model']
# Use a temporary handler instance with the loaded model logger.info(f"Generating GRU predictions using loaded model (type: {type(gru_model)}).")
temp_gru_handler = GRUModelHandler(run_id="temp_predict", models_dir="temp")
temp_gru_handler.model = gru_model # Assign the loaded model # Check model type and predict accordingly
predictions_val = temp_gru_handler.predict(X_val_seq) if not hasattr(gru_model, 'predict'):
if predictions_val is None or len(predictions_val) < 3: raise TypeError("Loaded GRU model object does not have a 'predict' method.")
raise ValueError("GRU prediction on validation sequences failed.")
mu_val_pred = predictions_val[0].flatten() predictions_val = gru_model.predict(X_val_seq)
log_sigma_val_pred = predictions_val[1][:, 1].flatten()
p_raw_val_pred = predictions_val[2].flatten() # --- Explicit Shape Logging --- #
sigma_val_pred = np.exp(log_sigma_val_pred) if isinstance(predictions_val, list):
logger.info("Generated GRU predictions on validation sequences.") logger.info(f"DEBUG: GRU prediction output list length: {len(predictions_val)}")
if len(predictions_val) > 0: logger.info(f"DEBUG: Shape of predictions_val[0]: {predictions_val[0].shape}")
if len(predictions_val) > 1: logger.info(f"DEBUG: Shape of predictions_val[1]: {predictions_val[1].shape}")
if len(predictions_val) > 2: logger.info(f"DEBUG: Shape of predictions_val[2]: {predictions_val[2].shape}")
else:
logger.info(f"DEBUG: GRU prediction output type: {type(predictions_val)}")
# --- End Explicit Shape Logging --- #
# --- Infer output structure (Corrected Indexing) ---
mu_val_pred, sigma_val_pred, p_raw_val_pred, logits_val_pred = None, None, None, None
if isinstance(predictions_val, list) and len(predictions_val) == 2:
logger.debug(f"GRU model returned list of {len(predictions_val)} outputs. Correcting index assumption: [mu, dir].")
# Corrected Indexing based on logs:
mu_output = predictions_val[0] # Index 0 has shape (N, 1)
dir_output = predictions_val[1] # Index 1 has shape (N, 3)
# Extract Mu
if mu_output.ndim == 2 and mu_output.shape[-1] == 1:
mu_val_pred = mu_output.flatten()
else:
# Corrected error message to reflect index 0 check
raise ValueError(f"Unexpected shape for mu output (index 0): {mu_output.shape}. Expected (N, 1).")
# Handle missing Sigma - Use a default/fallback
logger.warning("Log_sigma_sq output not found in model prediction. Using default sigma=0.1.")
sigma_val_pred = np.ones_like(mu_val_pred) * 0.1
# Determine direction output type from dir_output (index 1)
if self.use_ternary:
if dir_output.ndim == 2 and dir_output.shape[-1] == 3:
# Assume dir_output (index 1) is logits if ternary
logits_val_pred = dir_output
from scipy.special import softmax # Local import ok here
p_raw_val_pred = softmax(logits_val_pred, axis=-1)
logger.info("Inferred ternary output (logits assumed at index 1). Calculated raw probabilities.")
else:
raise ValueError(f"Expected ternary output shape (N, 3) at index 1, got {dir_output.shape}")
else: # Binary
# Binary case is tricky if dir_output (index 1) is still (N, 3)
# This shouldn't happen if the model structure changes for binary
# Let's assume the mu_output (index 0) might be P(up) in binary?
logger.warning(f"Binary mode, dir_output shape is {dir_output.shape}. Trying to use mu output (index 0) as P(up) - THIS IS RISKY.")
if mu_output.ndim == 2 and mu_output.shape[-1] == 1:
p_raw_val_pred = mu_output.flatten()
epsilon = 1e-7
p_clipped = np.clip(p_raw_val_pred, epsilon, 1 - epsilon)
logits_val_pred = np.log(p_clipped / (1 - p_clipped))
logger.info("Using mu output (index 0) as P(up) for binary mode. Inferred raw logits.")
else:
raise ValueError(f"Cannot determine binary P(up) output. mu_output shape: {mu_output.shape}")
else:
raise ValueError(f"GRU prediction failed or returned unexpected type/structure: {type(predictions_val)}, length: {len(predictions_val) if isinstance(predictions_val, list) else 'N/A'}. Expected list of 2 arrays.")
# Final check for necessary predictions
if mu_val_pred is None or sigma_val_pred is None or p_raw_val_pred is None:
raise ValueError("Failed to extract mu, sigma, or raw probabilities from GRU model output.")
# Verify lengths # Verify lengths
n_seq = len(X_val_seq) n_seq = len(X_val_seq)
if not (len(mu_val_pred) == n_seq and len(sigma_val_pred) == n_seq and \ if not (len(mu_val_pred) == n_seq and len(sigma_val_pred) == n_seq and \
len(p_raw_val_pred) == n_seq and len(actual_ret_val_seq) == n_seq): p_raw_val_pred.shape[0] == n_seq and len(actual_ret_val_seq) == n_seq):
raise ValueError(f"Length mismatch after validation predictions: Expected {n_seq}, got mu={len(mu_val_pred)}, sigma={len(sigma_val_pred)}, p_raw={len(p_raw_val_pred)}, ret={len(actual_ret_val_seq)}") raise ValueError(f"Length mismatch after validation predictions: Expected {n_seq}, "
f"got mu={len(mu_val_pred)}, sigma={len(sigma_val_pred)}, "
f"p_raw={p_raw_val_pred.shape[0]}, ret={len(actual_ret_val_seq)}")
# 9. Calibrate Predictions using *loaded* optimal_T # 11. Calibrate Predictions using loaded parameters
optimal_T = gru_dependencies['optimal_T'] calib_method = gru_dependencies.get('calibration_method')
# Use a temporary calibrator instance calib_params = gru_dependencies.get('calibration_params')
temp_calibrator = Calibrator(edge_threshold=0.5) # Edge threshold doesn't matter here p_cal_val_pred = None
temp_calibrator.optimal_T = optimal_T
p_cal_val_pred = temp_calibrator.calibrate(p_raw_val_pred)
logger.info(f"Calibrated validation predictions using loaded T={optimal_T:.4f}.")
# 10. Return the necessary components for the TradingEnv if calib_method == "temperature" and calib_params is not None:
optimal_T = calib_params
# Temperature scaling needs logits
if logits_val_pred is not None:
scaled_logits = logits_val_pred / optimal_T
if self.use_ternary:
p_cal_val_pred = softmax(scaled_logits, axis=-1)
logger.info(f"Applied temperature scaling (T={optimal_T:.4f}) to ternary logits.")
else: # Binary - apply sigmoid
p_cal_val_pred = 1 / (1 + np.exp(-scaled_logits)) # Sigmoid
logger.info(f"Applied temperature scaling (T={optimal_T:.4f}) to binary logits.")
else:
logger.error(f"Cannot apply temperature scaling (method={calib_method}): Raw logits not available.")
p_cal_val_pred = p_raw_val_pred # Fallback to raw
elif calib_method == "vector" and calib_params is not None:
# Vector calibration needs logits
if logits_val_pred is not None:
try:
from gru_sac_predictor.src.calibrator_vector import VectorCalibrator # Ensure import
temp_vector_calibrator = VectorCalibrator()
temp_vector_calibrator.optimal_params = calib_params # Set loaded params
# Calibrate expects logits, returns probabilities
p_cal_val_pred = temp_vector_calibrator.calibrate(logits_val_pred)
logger.info(f"Calibrated validation predictions using loaded vector parameters.")
except ImportError:
logger.error("Cannot import VectorCalibrator. Vector calibration step skipped.")
p_cal_val_pred = p_raw_val_pred # Fallback to raw
except Exception as vec_e:
logger.error(f"Error during vector calibration application: {vec_e}", exc_info=True)
p_cal_val_pred = p_raw_val_pred # Fallback to raw
else:
logger.error(f"Cannot apply vector calibration (method={calib_method}): Raw logits not available.")
p_cal_val_pred = p_raw_val_pred # Fallback to raw
else: # No calibration or failed
logger.warning(f"Calibration method was '{calib_method}' or params were None or previous step failed. Using RAW predictions for SAC environment.")
p_cal_val_pred = p_raw_val_pred # Use the raw probs (softmaxed for ternary, P(up) for binary)
# Final check for calibrated probabilities
if p_cal_val_pred is None:
logger.error("Failed to obtain final (calibrated or raw) predictions. Cannot proceed.")
return None
# Verify shape of final probabilities
expected_final_dim = 3 if self.use_ternary else 1
if p_cal_val_pred.ndim == 2 and p_cal_val_pred.shape[-1] == expected_final_dim:
if not self.use_ternary: p_cal_val_pred = p_cal_val_pred.flatten() # Flatten binary case
elif p_cal_val_pred.ndim == 1 and not self.use_ternary and expected_final_dim == 1:
pass # Already flat for binary
else:
logger.error(f"Final probability array has unexpected shape: {p_cal_val_pred.shape}. Expected dim {expected_final_dim}.")
return None
# 12. Return the necessary components for the TradingEnv
logger.info("--- Successfully prepared validation data for SAC Environment ---") logger.info("--- Successfully prepared validation data for SAC Environment ---")
return mu_val_pred, sigma_val_pred, p_cal_val_pred, actual_ret_val_seq return mu_val_pred, sigma_val_pred, p_cal_val_pred, actual_ret_val_seq
@ -512,358 +755,206 @@ class SACTrainer:
logger.warning(f"SAC agent path not found for resume: {load_path}. Starting fresh.") logger.warning(f"SAC agent path not found for resume: {load_path}. Starting fresh.")
def _training_loop(self, agent: SACTradingAgent, env: TradingEnv) -> str | None: def _training_loop(self, agent: SACTradingAgent, env: TradingEnv) -> str | None:
"""Runs the main SAC training loop.""" """The main SAC training loop."""
buffer_max_size = self.sac_cfg.get('buffer_max_size', 100000) total_steps = self.sac_cfg.get('total_training_steps', 100000)
min_buffer_size = self.sac_cfg.get('min_buffer_size', 10000) start_steps = self.sac_cfg.get('start_steps', 10000)
update_after = self.sac_cfg.get('update_after', 1000)
update_every = self.sac_cfg.get('update_every', 50)
save_freq = self.sac_cfg.get('save_freq', 5000)
log_freq = self.sac_cfg.get('log_freq', 100)
buffer_capacity = self.sac_cfg.get('buffer_capacity', 1000000)
batch_size = self.sac_cfg.get('batch_size', 256) batch_size = self.sac_cfg.get('batch_size', 256)
total_training_steps = self.sac_cfg.get('total_training_steps', 100000)
# --- Initialize Replay Buffer (Potentially PER) --- # # Initialize Replay Buffer (Standard or Prioritized)
if self.use_per: if self.use_per:
logger.info(f"Initializing Prioritized Replay Buffer (Capacity={buffer_max_size}, alpha={self.per_alpha}, beta_start={self.per_beta_start}, beta_frames={self.per_beta_frames})") logger.info(f"Using Prioritized Replay Buffer (Capacity: {buffer_capacity})")
replay_buffer = PrioritizedReplayBuffer( replay_buffer = PrioritizedReplayBuffer(
buffer_max_size, capacity=buffer_capacity,
alpha=self.per_alpha, alpha=self.per_alpha, # Initial alpha
beta_start=self.per_beta_start, beta_start=self.per_beta_start,
beta_frames=self.per_beta_frames beta_frames=self.per_beta_frames
) )
else: else:
logger.info(f"Initializing Standard Replay Buffer (Deque, Capacity={buffer_max_size})") logger.info(f"Using Standard Replay Buffer (Capacity: {buffer_capacity})")
replay_buffer = collections.deque(maxlen=buffer_max_size) replay_buffer = collections.deque(maxlen=buffer_capacity)
replay_buffer.counter = 0 # Add counter for uniform sampling logic
# --- End Buffer Init --- #
# --- Oracle Seeding (Revision 4-B) --- # # TensorBoard setup
oracle_seeding_pct = self.sac_cfg.get('oracle_seeding_pct', 0.0) tb_callback = TensorBoard(log_dir=self.sac_tb_log_dir)
num_existing_samples = len(replay_buffer) # Count samples potentially loaded during resume # --- Revision 4: Set model for TensorBoard --- #
target_seed_steps = int(buffer_max_size * oracle_seeding_pct) # Check if agent has actor/critic models accessible
actual_seed_steps = max(0, target_seed_steps - num_existing_samples) # This depends heavily on SACTradingAgent implementation
# Assuming agent.actor and agent.critic1/2 are the models
if actual_seed_steps > 0: if hasattr(agent, 'actor') and hasattr(agent, 'critic1') and hasattr(agent, 'critic2'):
logger.info(f"Performing Oracle Seeding: Adding ~{actual_seed_steps} steps ({oracle_seeding_pct * 100:.1f}% target) with heuristic policy...") tb_callback.set_model(agent.actor) # Link to one model is often sufficient
logger.info("TensorBoard callback linked to SAC agent model (actor).")
# Need edge threshold from config (used in the heuristic action)
edge_threshold_heuristic = self.config.get('calibration', {}).get('edge_threshold')
if edge_threshold_heuristic is None:
logger.error("Cannot perform oracle seeding: 'calibration.edge_threshold' not found in config.")
elif edge_threshold_heuristic <= 0:
logger.warning(f"Edge threshold for heuristic is {edge_threshold_heuristic:.3f}. Oracle seeding action may be ill-defined. Using 0.01 instead.")
edge_threshold_heuristic = 0.01
else: else:
state = env.reset() # Start seeding from the beginning of the env data logger.warning("Could not link TensorBoard callback to agent models (actor/critic not found).")
n_seeded = 0 # --- End Revision 4 ---
# --- Initialize optional imputed transition logger --- #
imputed_log_path = os.path.join(self.sac_run_results_dir, 'sac_imputed_transitions.csv')
imputed_log_file = None
imputed_csv_writer = None
try: try:
for _ in tqdm(range(actual_seed_steps), desc="Oracle Seeding", file=sys.stdout, leave=False): imputed_log_file = open(imputed_log_path, 'w', newline='')
if len(replay_buffer) >= buffer_max_size: imputed_csv_writer = csv.writer(imputed_log_file)
logger.warning("Buffer full during oracle seeding.") imputed_csv_writer.writerow(['step', 'imputed_handling_mode', 'action', 'reward', 'position_before', 'position_after'])
break logger.info(f"Logging imputed transitions details to: {imputed_log_path}")
except IOError as e:
logger.error(f"Failed to open imputed transition log file {imputed_log_path}: {e}. Logging disabled.")
if imputed_log_file: imputed_log_file.close()
imputed_log_file = None
imputed_csv_writer = None
# --- End init imputed logger --- #
current_state_for_seeding = state # Keep original state before potential filtering
if self.state_filter:
state = self.state_filter(state, update=False) # Apply filter, don't update filter stats during seeding
# Heuristic action based on edge
# State: [mu, sigma, edge, |mu|/sigma, position] (assuming this order)
try:
# Ensure state has enough elements before indexing
if len(state) < 3:
logger.error(f"State vector too short ({len(state)} elements) for oracle seeding. Stopping seeding.")
break
edge = state[2] # Assuming edge is the 3rd element (index 2)
except IndexError:
logger.error(f"IndexError accessing edge (state[2]) during oracle seeding. State shape: {state.shape if hasattr(state, 'shape') else type(state)}. Stopping seeding.")
break
oracle_action = np.clip(edge / edge_threshold_heuristic, -1.0, 1.0)
next_state, reward, done, _ = env.step(oracle_action)
# Store experience with the state *before* filtering (as agent expects raw state)
experience = (current_state_for_seeding, oracle_action, reward, next_state, done)
if self.use_per:
# Mark as seeded
replay_buffer.add(replay_buffer.max_priority, experience, is_seeded=True)
elif hasattr(replay_buffer, 'counter'):
replay_buffer.append(experience)
replay_buffer.counter += 1
n_seeded += 1
state = next_state
if done:
# Reset env if it finishes during seeding, continue seeding if needed
logger.info("Environment finished during oracle seeding. Resetting.")
state = env.reset() state = env.reset()
logger.info(f"Oracle seeding completed. Added {n_seeded} experiences.") # Normalize initial state if filter is active
except Exception as seed_err:
logger.error(f"Error during oracle seeding loop: {seed_err}. Proceeding with {n_seeded} seeded samples.", exc_info=True)
else:
logger.info("Oracle seeding skipped (percentage=0 or buffer already seeded enough). Existing samples: {num_existing_samples}")
# --- End Oracle Seeding --- #
# --- Training Loop Setup --- #
start_learning_after_steps = self.sac_cfg.get('start_learning_after_steps', 1000)
save_checkpoint_freq = self.sac_cfg.get('save_checkpoint_freq_steps', 50000)
total_steps = self.sac_cfg.get('total_training_steps', 100000)
# Revision 5: Alpha annealing params
alpha_start = self.sac_cfg.get('per_alpha_start', 0.6)
alpha_end = self.sac_cfg.get('per_alpha_end', 0.4)
current_alpha = alpha_start # Initialize alpha
summary_writer = tf.summary.create_file_writer(self.sac_tb_log_dir)
episode_reward = 0
episode_steps = 0
episode_rewards_log = [] # For saving reward history
best_eval_score = -np.inf # Placeholder for potential periodic evaluation
final_agent_path = None
logger.info(f"Starting SAC training loop for {total_training_steps} steps...")
logger.info(f" Min buffer size: {min_buffer_size}, Batch size: {batch_size}, Updates/step: {updates_per_step}")
logger.info(f" Log interval: {log_interval}, Checkpoint interval: {checkpoint_interval}")
summary_writer = tf.summary.create_file_writer(self.sac_tb_log_dir)
pbar = tqdm(range(1, total_training_steps + 1), desc="SAC Training", file=sys.stdout)
last_saved_path = None # Track last successful save
# --- Main Training Loop --- #
for step in pbar:
# --- Revision 5: Update Annealed Alpha --- #
if self.use_per:
alpha_fraction = min(1.0, step / total_steps)
current_alpha = alpha_start + alpha_fraction * (alpha_end - alpha_start)
# Calculate Seed Decay factor for IS weights
seed_decay_factor = max(0.0, 1.0 - step / self.per_seed_decay_steps)
# --- End Revision 5 --- #
original_state = state # Keep original state for buffer
if self.state_filter: if self.state_filter:
state = self.state_filter(state, update=True) # Apply filter and update running stats state = self.state_filter(state, update=True) # Update filter with initial state
if step < start_learning_after_steps: total_reward = 0.0
# Use random actions during warmup to explore start_time = time.time()
action = np.random.uniform(env.action_space.low, env.action_space.high, size=env.action_space.shape)
logger.info(f"Starting SAC training loop for {total_steps} steps...")
for step in tqdm(range(total_steps), desc="SAC Training Steps"):
# Store state before taking action (for replay buffer)
state_before_action = state
position_before_action = env.current_position # Store for imputed log
# Select action
if step < start_steps:
action = env.action_space.sample()[0] # Sample random action
else: else:
# Get action from agent action = agent.select_action(state)
action, _ = agent.select_action(state)
# Environment step # Step the environment
next_state, reward, done, info = env.step(action[0]) # Env expects single float action next_state_raw, reward, done, info = env.step(action)
# Store experience in buffer (use original state) # Check if the step was skipped due to imputed bar handling
experience = (original_state, action, reward, next_state, done) is_skipped = info.get('is_imputed_step_skipped', False)
# Normalize next state if filter is active
if self.state_filter:
next_state = self.state_filter(next_state_raw, update=True) # Update filter
else:
next_state = next_state_raw
# Store transition ONLY if the step was NOT skipped
if not is_skipped:
if self.use_per: if self.use_per:
# Mark as seeded # Add with initial error=1 (or max priority) and is_seeded=False
replay_buffer.add(replay_buffer.max_priority, experience, is_seeded=True) # The error will be updated after the first training step on this sample.
elif hasattr(replay_buffer, 'counter'): replay_buffer.add(error=1.0, sample=(state_before_action, action, reward, next_state, done))
replay_buffer.append(experience) else:
replay_buffer.counter += 1 # Increment counter for deque replay_buffer.append((state_before_action, action, reward, next_state, done))
else:
logger.debug(f"Step {step}: Skipped adding transition to buffer due to imputed bar handling (mode=skip).")
# --- Log imputed transition details to CSV --- #
# Check if the *previous* step was imputed and handled by hold/penalty
# env.current_step was incremented *inside* env.step()
was_imputed_idx = env.current_step - 1
if 0 <= was_imputed_idx < env.n_steps and env.bar_imputed[was_imputed_idx] and not is_skipped and imputed_csv_writer:
# Don't log if skipped, log if hold/penalty was applied
imputed_handling_mode = self.config.sac.get('imputed_handling', 'unknown')
action_taken = env.current_position if imputed_handling_mode == 'hold' else action # Action applied or agent's intended action
position_after_action = env.current_position
try:
imputed_csv_writer.writerow([
was_imputed_idx, # Log the actual step number where imputation occurred
imputed_handling_mode,
f"{action_taken:.4f}",
f"{reward:.6f}", # Log the reward received for this step
f"{position_before_action:.4f}",
f"{position_after_action:.4f}"
])
except Exception as log_e:
logger.warning(f"Failed to write imputed transition to CSV: {log_e}")
# --- End Log imputed transition --- #
state = next_state state = next_state
current_episode_reward += reward total_reward += reward
current_episode_steps += 1
# Perform SAC updates # Perform SAC agent updates
if len(replay_buffer) >= min_buffer_size: if step >= update_after and step % update_every == 0:
for _ in range(updates_per_step): for update_i in range(update_every): # Perform multiple updates per interval
sample_indices, batch_with_seed_flags, importance_weights = None, None, None # Initialize
# --- Sample from Trainer's Buffer --- #
if self.use_per: if self.use_per:
sample_indices, batch_with_seed_flags, importance_weights = replay_buffer.sample(batch_size) if len(replay_buffer) > batch_size:
# --- Revision 5: Apply Seed Decay to IS Weights --- # idxs, batch_data, is_weights = replay_buffer.sample(batch_size, beta=replay_buffer.beta) # Use annealed beta
batch = [] # Store only the samples # Unpack batch_data which contains (sample, is_seeded) tuples
seed_mask = np.zeros_like(importance_weights, dtype=bool) batch = [item[0] for item in batch_data]
for i, (sample, is_seeded) in enumerate(batch_with_seed_flags): update_info = agent.update(batch, is_weights=is_weights, per_beta=replay_buffer.beta)
batch.append(sample) if update_info and 'td_errors' in update_info:
if is_seeded: # Update priorities using current annealed alpha
seed_mask[i] = True current_alpha = agent.get_current_per_alpha(step)
# Apply decay factor to IS weights of seeded samples replay_buffer.update_priorities(idxs, update_info['td_errors'], alpha=current_alpha)
# --- Revision 6: Log IS weight decay --- #
num_seeded_in_batch = np.sum(seed_mask)
if num_seeded_in_batch > 0:
original_weights_seeded = importance_weights[seed_mask].copy()
importance_weights[seed_mask] *= seed_decay_factor
weights_after_decay = importance_weights[seed_mask]
# Log periodically or if decay factor changes significantly
if step % self.log_interval == 0: # Align with other logging
logger.info(f"Step {step}: Applied IS decay factor {seed_decay_factor:.4f} to {num_seeded_in_batch} seeded samples in batch.")
# Optional: Log weight changes
# logger.debug(f" Weights before: {np.round(original_weights_seeded, 3)}")
# logger.debug(f" Weights after: {np.round(weights_after_decay, 3)}")
# --- End Revision 6 --- #
importance_weights_tensor = tf.convert_to_tensor(importance_weights, dtype=tf.float32)
else: # Uniform sampling from deque
current_buffer_size = min(replay_buffer.counter, replay_buffer.maxlen)
if current_buffer_size < batch_size:
sample_indices = np.random.choice(current_buffer_size, batch_size, replace=True)
else: else:
sample_indices = np.random.choice(current_buffer_size, batch_size, replace=False) continue # Not enough samples yet for PER
batch = [replay_buffer[i] for i in sample_indices] else: # Standard buffer
importance_weights_tensor = None # No IS weights for uniform if len(replay_buffer) > batch_size:
# --- End Sampling --- # indices = np.random.choice(len(replay_buffer), size=batch_size, replace=False)
batch = [replay_buffer[i] for i in indices]
state_batch, action_batch, reward_batch, next_state_batch, done_batch = map(np.stack, zip(*batch)) update_info = agent.update(batch)
# Apply state filter to sampled states if enabled
if self.state_filter:
state_batch_filtered = self.state_filter(state_batch, update=False)
next_state_batch_filtered = self.state_filter(next_state_batch, update=False)
else: else:
state_batch_filtered = state_batch continue # Not enough samples yet
next_state_batch_filtered = next_state_batch
# Convert batch to tensors # Log training metrics (losses, Q-values, alpha) to TensorBoard
state_tensor = tf.convert_to_tensor(state_batch_filtered, dtype=tf.float32) if update_info and step % log_freq == 0 and update_i == 0: # Log once per interval
action_tensor = tf.convert_to_tensor(action_batch, dtype=tf.float32) with tb_callback.writer.as_default():
reward_tensor = tf.convert_to_tensor(reward_batch, dtype=tf.float32) for key, value in update_info.items():
next_state_tensor = tf.convert_to_tensor(next_state_batch_filtered, dtype=tf.float32) if key != 'td_errors': # Don't log TD errors directly
done_tensor = tf.convert_to_tensor(done_batch, dtype=tf.float32) tf.summary.scalar(f'sac/{key}', value, step=step)
# logger.debug(f"Step {step}: Logged SAC metrics to TensorBoard.")
# --- Call agent's train method with the batch --- # # Check environment done state
loss_info = agent.train(
state_tensor,
action_tensor,
reward_tensor,
next_state_tensor,
done_tensor,
importance_weights=importance_weights_tensor # Pass potentially decayed weights for PER
)
# --- End Agent Update Call --- #
# --- Revision 5: Update PER priorities with current alpha --- #
if self.use_per and loss_info is not None:
td_errors = loss_info.get('td_errors') # Get TD errors from loss_info dict
if td_errors is not None:
replay_buffer.update_priorities(sample_indices, td_errors, current_alpha)
# --- Revision 6: Log TD Error Distribution --- #
# Convert to numpy if it's a tensor
if tf.is_tensor(td_errors):
td_errors = td_errors.numpy()
td_error_abs = np.abs(td_errors)
log_hist_interval = self.sac_cfg.get('log_hist_interval', 5000)
# Log percentiles every step where update happens
step_info_for_logging = {
'td_error_abs_p25': np.percentile(td_error_abs, 25),
'td_error_abs_p50': np.percentile(td_error_abs, 50),
'td_error_abs_p75': np.percentile(td_error_abs, 75),
'td_error_abs_p95': np.percentile(td_error_abs, 95),
'td_error_abs_max': np.max(td_error_abs),
'td_error_abs_mean': np.mean(td_error_abs)
}
# Log histogram to TensorBoard periodically
if log_hist_interval > 0 and step % log_hist_interval == 0:
try:
with summary_writer.as_default(step=step):
tf.summary.histogram('td_error_abs_distribution', td_error_abs, description='Absolute TD Errors')
summary_writer.flush()
logger.debug(f"Step {step}: Logged TD error histogram to TensorBoard.")
except Exception as hist_err:
logger.warning(f"Failed to log TD error histogram to TensorBoard at step {step}: {hist_err}")
# --- End Revision 6 --- #
else:
logger.warning(f"Step {step}: TD errors not found in loss_info from agent.train(). Cannot update PER priorities or log distribution.")
# --- End Revision 5 --- #
# Logging and Checkpointing
if step % log_interval == 0 and len(replay_buffer) >= min_buffer_size:
with summary_writer.as_default(): # Use context manager
# Log loss_info if available
if 'loss_info' in locals() and loss_info:
for k, v in loss_info.items():
# Skip logging raw TD errors tensor here
if k != 'td_errors':
tf.summary.scalar(f'loss/{k}', v, step=step)
tf.summary.scalar('alpha', agent.log_alpha.numpy().item(), step=step)
tf.summary.scalar('buffer/beta', replay_buffer.beta, step=step)
tf.summary.scalar('buffer/per_alpha', current_alpha, step=step)
# Log TD error percentiles calculated in Revision 6
if 'step_info_for_logging' in locals():
for k, v in step_info_for_logging.items():
tf.summary.scalar(f'td_error/{k}', v, step=step)
logger.debug(f"Step {step}: Logged losses, alpha, PER params, TD error stats.")
if step % checkpoint_interval == 0:
# Save checkpoint
chkpt_path = os.path.join(self.sac_run_models_dir, f'sac_agent_step_{step}')
meta_data = {'step': step, 'edge_threshold': edge_threshold_heuristic if 'edge_threshold_heuristic' in locals() else self.config.get('calibration',{}).get('edge_threshold')}
agent.save(chkpt_path, meta_data=meta_data)
# --- Save State Filter (Task 5.2) --- #
if self.state_filter:
filter_state = self.state_filter.get_state()
np.savez(os.path.join(chkpt_path, 'state_filter.npz'), **filter_state)
# --- End Save State Filter --- #
logger.info(f"Saved SAC checkpoint at step {step} to {chkpt_path}")
# Episode end handling
if done: if done:
logger.debug(f"Episode {episode_count} finished after {current_episode_steps} steps. Reward: {current_episode_reward:.2f}. Buffer size: {len(replay_buffer)}.")
episode_rewards_log.append({'episode': episode_count, 'total_step': step, 'episode_reward': current_episode_reward, 'episode_steps': current_episode_steps})
with summary_writer.as_default(): # Use context manager
tf.summary.scalar('reward/episode', current_episode_reward, step=step)
tf.summary.scalar('steps/episode', current_episode_steps, step=step)
# Reset for next episode
state = env.reset() state = env.reset()
current_episode_reward = 0.0 # Normalize reset state
current_episode_steps = 0
else:
state = next_state
# Update progress bar description
if step % 100 == 0 and len(replay_buffer) >= min_buffer_size:
last_reward = episode_rewards_log[-1]['episode_reward'] if episode_rewards_log else np.nan
pbar.set_description(f"SAC Training | Ep Reward (Last): {last_reward:.2f} | Buffer: {len(replay_buffer)}")
# Save agent checkpoint
if (step + 1) % checkpoint_interval == 0 or step == total_training_steps - 1:
save_path = os.path.join(self.sac_run_models_dir, f'sac_agent_step_{step + 1}')
agent.save_weights(save_path)
logger.info(f"SAC agent weights saved at step {step + 1} to {save_path}")
# --- Save State Filter (Task 5.2) --- #
if self.state_filter: if self.state_filter:
state_filter_path = os.path.join(self.sac_run_models_dir, f'state_filter_step_{step + 1}.npz') state = self.state_filter(state, update=True)
total_reward = 0.0 # Reset episodic reward
# Save agent checkpoint periodically
if step % save_freq == 0 and step > 0:
save_path = os.path.join(self.sac_run_models_dir, f'sac_agent_step_{step}')
try: try:
self.state_filter.save_npz(state_filter_path) agent.save(save_path)
logger.info(f"State filter saved to {state_filter_path}") # Also save state filter if used
if self.state_filter:
filter_path = os.path.join(save_path, 'state_filter.pkl')
joblib.dump(self.state_filter, filter_path)
logger.info(f"Saved state filter to {filter_path}")
except Exception as e: except Exception as e:
logger.error(f"Failed to save state filter: {e}") logger.error(f"Failed to save agent/filter checkpoint at step {step} to {save_path}: {e}", exc_info=True)
# --- End Save State Filter --- #
last_saved_path = save_path # Update last saved path
# Also save the reward log periodically
if episode_rewards_log:
rewards_df = pd.DataFrame(episode_rewards_log)
rewards_log_path = os.path.join(self.sac_run_logs_dir, 'episode_rewards.csv')
try:
rewards_df.to_csv(rewards_log_path, index=False)
except Exception as e:
logger.error(f"Failed to save episode rewards log: {e}")
# --- Final Save --- # # --- Final Save --- #
pbar.close()
summary_writer.close()
final_save_path = os.path.join(self.sac_run_models_dir, 'sac_agent_final') final_save_path = os.path.join(self.sac_run_models_dir, 'sac_agent_final')
agent.save_weights(final_save_path) try:
logger.info(f"Final SAC agent weights saved to {final_save_path}") agent.save(final_save_path)
# Save final state filter
if self.state_filter: if self.state_filter:
final_state_filter_path = os.path.join(self.sac_run_models_dir, 'state_filter_final.npz') filter_path = os.path.join(final_save_path, 'state_filter.pkl')
try: joblib.dump(self.state_filter, filter_path)
self.state_filter.save_npz(final_state_filter_path) logger.info(f"Saved final state filter to {filter_path}")
logger.info(f"Final state filter saved to {final_state_filter_path}") self.last_saved_agent_path = final_save_path # Store path for potential return
except Exception as e: except Exception as e:
logger.error(f"Failed to save final state filter: {e}") logger.error(f"Failed to save final agent/filter checkpoint to {final_save_path}: {e}", exc_info=True)
self.last_saved_agent_path = None
# Save final rewards log end_time = time.time()
if episode_rewards_log: training_duration = end_time - start_time
rewards_df = pd.DataFrame(episode_rewards_log) logger.info(f"SAC training loop finished in {training_duration:.2f} seconds.")
rewards_log_path = os.path.join(self.sac_run_logs_dir, 'episode_rewards.csv') logger.info(f"Final agent checkpoint saved to: {self.last_saved_agent_path}")
# --- Close imputed transition log file --- #
if imputed_log_file:
try: try:
rewards_df.to_csv(rewards_log_path, index=False) imputed_log_file.close()
logger.info(f"Final episode rewards log saved to {rewards_log_path}") logger.info(f"Closed imputed transition log file: {imputed_log_path}")
except Exception as e: except Exception as e:
logger.error(f"Failed to save final episode rewards log: {e}") logger.error(f"Error closing imputed transition log file: {e}")
# --- End close log file --- #
return final_save_path if os.path.exists(final_save_path) else last_saved_path return self.last_saved_agent_path
def train(self, gru_run_id_for_sac: str) -> str | None: def train(self, gru_run_id_for_sac: str) -> str | None:
""" """
@ -918,12 +1009,10 @@ class SACTrainer:
initial_lr=self.sac_cfg.get('actor_lr', 3e-4), initial_lr=self.sac_cfg.get('actor_lr', 3e-4),
lr_decay_rate=self.sac_cfg.get('lr_decay_rate', 0.96), lr_decay_rate=self.sac_cfg.get('lr_decay_rate', 0.96),
decay_steps=self.sac_cfg.get('decay_steps', 100000), decay_steps=self.sac_cfg.get('decay_steps', 100000),
buffer_capacity=self.sac_cfg.get('buffer_max_size', 100000),
ou_noise_stddev=self.sac_cfg.get('ou_noise_stddev', 0.2), ou_noise_stddev=self.sac_cfg.get('ou_noise_stddev', 0.2),
alpha=self.sac_cfg.get('alpha', 0.2), alpha=self.sac_cfg.get('alpha', 0.2),
alpha_auto_tune=self.sac_cfg.get('alpha_auto_tune', True), alpha_auto_tune=self.sac_cfg.get('alpha_auto_tune', True),
target_entropy=self.sac_cfg.get('target_entropy', -1.0 * env.action_dim), target_entropy=self.sac_cfg.get('target_entropy', -1.0 * env.action_dim),
min_buffer_size=self.sac_cfg.get('min_buffer_size', 1000),
edge_threshold_config=current_edge_threshold, # Pass edge threshold edge_threshold_config=current_edge_threshold, # Pass edge threshold
# --- Pass Env Params (Task 5.6) --- # # --- Pass Env Params (Task 5.6) --- #
reward_scale_config=reward_scale, reward_scale_config=reward_scale,