feat: Add imputed transition logging to SACTrainer and update docs
This commit is contained in:
parent
c3526bb9f6
commit
35c60ae848
@ -1,6 +1,6 @@
|
|||||||
# GRU + SAC Crypto Trading System (v3 - Consolidated & Enhanced)
|
# GRU + SAC Crypto Trading System (v3 - Refactored & Enhanced)
|
||||||
|
|
||||||
This project implements a cryptocurrency trading system using a GRU model for market prediction and a Soft Actor-Critic (SAC) agent for position sizing. This version reflects significant refactoring, feature additions, and the consolidation of model logic.
|
This project implements a cryptocurrency trading system using a GRU model for market prediction and a Soft Actor-Critic (SAC) agent for position sizing. This version reflects significant refactoring for modularity, feature additions, and the consolidation of model logic.
|
||||||
|
|
||||||
The core idea is to decouple prediction and action:
|
The core idea is to decouple prediction and action:
|
||||||
1. A **GRU model** (v2 or v3 architecture, selected via config) forecasts future log-returns (μ̂) and class probabilities (binary p(up) or ternary p(down, flat, up)).
|
1. A **GRU model** (v2 or v3 architecture, selected via config) forecasts future log-returns (μ̂) and class probabilities (binary p(up) or ternary p(down, flat, up)).
|
||||||
@ -11,32 +11,71 @@ This approach aims for a robust system where the RL agent focuses solely on risk
|
|||||||
|
|
||||||
## Key Features & Enhancements
|
## Key Features & Enhancements
|
||||||
|
|
||||||
* **Consolidated GRU Logic:** Both v2 and v3 GRU model architectures are now implemented and managed within `src/gru_model_handler.py`.
|
* **Modular Pipeline Structure:** Pipeline logic is refactored into stage-specific functions within `src/pipeline_stages/` for improved readability, maintainability, and testability (see Project Structure).
|
||||||
* **Walk-Forward Validation:** Replaces static train/val/test splits with a robust walk-forward validation framework (`TradingPipeline.execute`, `_generate_walk_forward_folds`) for more realistic performance estimation.
|
* **Consolidated GRU Logic:** Both v2 and v3 GRU model architectures are implemented and managed within `src/gru_model_handler.py`.
|
||||||
* **Hyperparameter Optimization (Optuna):** Integrated Optuna sweep for GRU hyperparameters (`src/gru_hyper_tuner.py`) with restricted search space, configurable objective (`edge_acc - brier`), and Keras callback for efficient pruning based on `val_loss`.
|
* **Walk-Forward Validation:** Robust walk-forward validation framework (`TradingPipeline.execute`, `_generate_walk_forward_folds`) for realistic performance estimation.
|
||||||
|
* **Hyperparameter Optimization (Optuna):** Integrated Optuna sweep for GRU hyperparameters (`src/gru_hyper_tuner.py`) with configurable search space, objective, and pruning.
|
||||||
* **Advanced Calibration:**
|
* **Advanced Calibration:**
|
||||||
* Supports Temperature and Vector Scaling (`calibration.method`) with optional L2 regularization (`calibration.l2_lambda`).
|
* Supports Temperature and Vector Scaling (`calibration.method`) with optional L2 regularization (`calibration.l2_lambda`).
|
||||||
* Optimizes edge threshold via Youden's J on validation data (`calibration.optimize_edge_threshold`). Saved per fold (`optimized_edge_threshold_fold_N.txt`) and used consistently.
|
* Optimizes edge threshold via Youden's J on validation data (`calibration.optimize_edge_threshold`).
|
||||||
* **Rolling Calibration (Experimental):** Implemented within `Backtester` to refit the calibrator periodically during the backtest (`calibration.rolling_enabled`, `recalibrate_every_n`, `recalibration_window`). Uses the **static** calibration from training time if SAC training is active to prevent lookahead.
|
* **Rolling Calibration (Experimental):** Implemented within `Backtester` to refit the calibrator periodically during the backtest (`calibration.rolling_enabled`, `recalibrate_every_n`, `recalibration_window`).
|
||||||
* **Coverage Alarm (ECE-based):** Optional alarm triggers early recalibration if **Expected Calibration Error (ECE)** exceeds a threshold (`calibration.coverage_alarm_enabled`, `ece_recalibration_threshold`).
|
* **Coverage Alarm (ECE-based):** Optional alarm triggers early recalibration if **Expected Calibration Error (ECE)** exceeds a threshold (`calibration.coverage_alarm_enabled`, `ece_recalibration_threshold`).
|
||||||
* **Prioritized Experience Replay (PER):** Implemented for SAC training (`sac.use_per`) with **TD-error clipping** and **alpha annealing** (linear decay). Logs TD error distribution statistics.
|
* **Prioritized Experience Replay (PER):** Implemented for SAC training (`sac.use_per`) with TD-error clipping and alpha annealing.
|
||||||
* **SAC Enhancements:** Reward scaling, state normalization (`MeanStdFilter`), configurable **action penalty** (default: `0.01 / transaction_cost`), oracle seeding with **Importance Sampling (IS) weight decay** (`per_seed_decay_steps`).
|
* **SAC Enhancements:** Reward scaling, state normalization (`MeanStdFilter`), configurable action penalty, oracle seeding with Importance Sampling (IS) weight decay.
|
||||||
* **Refined Validation Gates:** Configurable thresholds (`validation_gates`) for:
|
* **Refined Validation Gates:** Configurable thresholds (`validation_gates`) for baseline checks, GRU validation, fold backtest performance, and final release decisions.
|
||||||
* **Baseline Gate:** Checks Logistic Regression CI on **raw/engineered** training features before scaling.
|
* **Micro-structure Features:** Added bar-level features (`FeatureEngineer._add_microstructure_features`) with NaN guards.
|
||||||
* **GRU Gate:** Checks Edge Acc CI and Brier score on validation set after calibration, using the fold's **determined edge threshold**.
|
* **Leakage Guard:** Feature calculations use `shift(1)`. Selection includes correlation check against future returns. Minimal whitelist applied before VIF.
|
||||||
* **Final Release Decision:** Checks aggregated metrics (e.g., **median Sharpe ≥ 1.3**, **≥ 75% successful folds**) across all successful folds. Backtest gate failures *per fold* are logged but do not halt the entire pipeline.
|
* **Configuration:** Centralized and expanded `config.yaml`.
|
||||||
* **Micro-structure Features:** Added bar-level features (`FeatureEngineer._add_microstructure_features`) with **NaN guards** for robustness.
|
|
||||||
* **Leakage Guard:** Feature calculations use `shift(1)`. Selection occurs on **raw/engineered features** and includes correlation check (`corr(ret+h, feat_t-1)`) against future returns. Minimal whitelist applied before VIF.
|
|
||||||
* **Configuration:** Centralized and expanded `config.yaml` with annotations for mutually exclusive options (e.g., walk-forward vs static split).
|
|
||||||
* **Output Management:** Standardized output structure via `IOManager` and `LoggerSetup`.
|
* **Output Management:** Standardized output structure via `IOManager` and `LoggerSetup`.
|
||||||
* **SAC Agent Aggregation:** Optional post-processing step to average weights from agents trained across successful folds (`TradingPipeline.aggregate_sac_agents`, `sac_aggregation.enabled`).
|
* **SAC Agent Aggregation:** Optional post-processing step to average weights from agents trained across successful folds.
|
||||||
|
|
||||||
|
## Data Quality
|
||||||
|
|
||||||
|
This system includes mechanisms to handle potential missing data points (bars) in the input time series. This is crucial for maintaining data integrity and preventing errors during feature engineering and model training.
|
||||||
|
|
||||||
|
**Handling Missing Bars:**
|
||||||
|
|
||||||
|
* **Detection:** The pipeline automatically detects missing bars based on the expected `data.bar_frequency` (e.g., "1T" for 1 minute) after initial data loading.
|
||||||
|
* **Reporting:** A warning is logged detailing the total number of missing bars found and the length of the longest consecutive gap. A summary report (`missing_bars_summary.json`) is saved in the run's results directory.
|
||||||
|
* **Filling Strategies:** Several strategies are available, configured via `data.missing.strategy`:
|
||||||
|
* `"drop"`: No filling is performed. Missing bars remain gaps or NaNs. (Use with caution).
|
||||||
|
* `"neutral"`: Forward-fills the 'close' price, sets 'open', 'high', 'low' equal to the filled 'close', and sets 'volume' to 0 for imputed bars.
|
||||||
|
* `"ffill"`: Forward-fills all OHLCV columns, then back-fills any remaining NaNs at the beginning.
|
||||||
|
* `"interpolate"`: Interpolates missing values using the method specified in `data.missing.interpolate.method` (e.g., 'linear') up to a `limit` defined in `data.missing.interpolate.limit`.
|
||||||
|
* **Imputed Flag:** After filling, a boolean column `bar_imputed` is added to the DataFrame, marking rows that were originally missing.
|
||||||
|
* **Max Gap Check:** The pipeline will raise an error if the longest detected consecutive gap exceeds `data.missing.max_gap`.
|
||||||
|
|
||||||
|
**Impact on Downstream Components:**
|
||||||
|
|
||||||
|
* **Feature Engineering:** Features are calculated on the potentially gap-filled data.
|
||||||
|
* **Sequence Creation:** Sequences containing imputed bars can be optionally dropped before GRU training, controlled by `gru.drop_imputed_sequences`.
|
||||||
|
* **GRU Model:** The `bar_imputed` flag is included as a feature input to the GRU model, allowing it to potentially learn patterns related to imputed data.
|
||||||
|
* **SAC Environment:** The `TradingEnv` is aware of imputed bars. The behavior during an imputed step is controlled by `sac.imputed_handling`:
|
||||||
|
* `"skip"`: The environment skips the step, no action is taken, no reward is given, and the transition is not added to the replay buffer.
|
||||||
|
* `"hold"`: The agent's action is overridden to maintain its current position. The step proceeds normally otherwise (reward calculated based on held position).
|
||||||
|
* `"penalty"`: The agent's chosen action is taken, but a penalty reward (based on `sac.action_penalty`) is applied instead of the normal PnL reward.
|
||||||
|
|
||||||
|
**Recommended Defaults:**
|
||||||
|
|
||||||
|
Using `"neutral"` or `"ffill"` for `strategy` is generally recommended for continuous time series. `max_gap` should be set to a reasonably small number (e.g., 5-10) to avoid filling excessively long gaps with potentially inaccurate data. For the SAC environment, `"hold"` or `"skip"` are common choices, depending on whether you want the agent to explicitly learn from imputed steps (or lack thereof).
|
||||||
|
|
||||||
## System Design & Workflow
|
## System Design & Workflow
|
||||||
|
|
||||||
The system is orchestrated by `run.py`, which sets up logging and I/O via `LoggerSetup` and `IOManager`, then instantiates and executes the `TradingPipeline` class (`src/trading_pipeline.py`). The pipeline follows a sequence of steps, potentially looped for Walk-Forward validation.
|
The system is orchestrated by `run.py`, which sets up logging and I/O via `LoggerSetup` and `IOManager`, then instantiates and executes the `TradingPipeline` class (`src/trading_pipeline.py`). The pipeline follows a sequence of steps, potentially looped for Walk-Forward validation.
|
||||||
|
|
||||||
|
### Pipeline Stages (Refactored)
|
||||||
|
|
||||||
|
The core logic for each step in the pipeline has been moved into dedicated functions within the `src/pipeline_stages/` directory. The `TradingPipeline` class now acts primarily as an orchestrator, calling these stage functions in sequence and managing the overall state and data flow.
|
||||||
|
|
||||||
|
* **`src/pipeline_stages/data_processing.py`**: Handles loading, initial preprocessing, feature engineering, labeling, and data splitting logic for each fold.
|
||||||
|
* **`src/pipeline_stages/feature_processing.py`**: Manages feature scaling, selection (L1+VIF), and pruning based on the selected whitelist.
|
||||||
|
* **`src/pipeline_stages/sequence_creation.py`**: Creates input sequences suitable for the GRU model from the processed feature data.
|
||||||
|
* **`src/pipeline_stages/modelling.py`**: Contains functions for training/loading the GRU model (including hyperparameter tuning), calibrating probabilities (Temperature/Vector Scaling, edge threshold optimization), training/loading the SAC agent, and aggregating SAC agents.
|
||||||
|
* **`src/pipeline_stages/evaluation.py`**: Includes functions for running baseline checks (Logistic Regression), performing GRU validation checks (Edge Accuracy, Brier Score), and executing the main backtest simulation (instantiating and running the `Backtester`).
|
||||||
|
|
||||||
|
### Workflow Diagram
|
||||||
|
|
||||||
```mermaid
|
```mermaid
|
||||||
%%{init: {'themeVariables': { 'fontSize': '26px' }}}%%
|
|
||||||
graph TD
|
graph TD
|
||||||
A[run.py: Init Logger/IOManager/Config] --> B(TradingPipeline);
|
A[run.py: Init Logger/IOManager/Config] --> B(TradingPipeline);
|
||||||
|
|
||||||
@ -47,21 +86,21 @@ graph TD
|
|||||||
D --> F[Select Fold Data];
|
D --> F[Select Fold Data];
|
||||||
E --> F;
|
E --> F;
|
||||||
|
|
||||||
subgraph Fold Processing [Fold Processing]
|
subgraph Fold Processing [Fold Processing - Calls Stage Functions]
|
||||||
direction TB
|
direction TB
|
||||||
F --> G[Engineer Features];
|
F --> G[data_processing: Engineer Features];
|
||||||
G --> H[Define Labels & Align];
|
G --> H[data_processing: Define Labels & Align];
|
||||||
H --> I[Split Fold Data];
|
H --> I[data_processing: Split Fold Data];
|
||||||
I --> J1[Baseline Check];
|
I --> J1[evaluation: Baseline Check];
|
||||||
J1 -- Pass --> L[Select Features];
|
J1 -- Pass --> L[feature_processing: Select Features];
|
||||||
L --> K[Scale Features];
|
L --> K[feature_processing: Scale Features];
|
||||||
K --> M[Prune Scaled Features];
|
K --> M[feature_processing: Prune Scaled Features];
|
||||||
M --> N[Create Sequences];
|
M --> N[sequence_creation: Create Sequences];
|
||||||
N --> O[Train/Load GRU];
|
N --> O[modelling: Train/Load GRU];
|
||||||
O --> P[Calibrate Probabilities];
|
O --> P[modelling: Calibrate Probabilities];
|
||||||
P --> R[GRU Validation Gate];
|
P --> R[evaluation: GRU Validation Gate];
|
||||||
R -- Pass --> S[Train/Load SAC Agent];
|
R -- Pass --> S[modelling: Train/Load SAC Agent];
|
||||||
S --> T[Run Backtest];
|
S --> T[evaluation: Run Backtest];
|
||||||
T --> U[Record Fold Results];
|
T --> U[Record Fold Results];
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -73,172 +112,144 @@ graph TD
|
|||||||
|
|
||||||
B --> Walk-ForwardLoop;
|
B --> Walk-ForwardLoop;
|
||||||
X1 --> Y[Aggregate Fold Metrics];
|
X1 --> Y[Aggregate Fold Metrics];
|
||||||
Y --> Z[Aggregate SAC Agents];
|
Y --> Z[modelling: Aggregate SAC Agents];
|
||||||
Z --> Z1[Final Release Decision];
|
Z --> Z1[Final Release Decision];
|
||||||
Z1 --> Z_End([End Pipeline Run]);
|
Z1 --> Z_End([End Pipeline Run]);
|
||||||
|
|
||||||
```
|
```
|
||||||
*Diagram outlines the consolidated v3 pipeline flow after `revisions.txt` modifications.*
|
*Diagram outlines the consolidated v3 pipeline flow, highlighting calls to stage functions.*
|
||||||
|
|
||||||
### Detailed Steps (Walk-Forward Enabled)
|
### Detailed Steps (Walk-Forward Enabled)
|
||||||
|
|
||||||
1. **Initialization (`run.py`):**
|
1. **Initialization (`run.py`):** Sets up infrastructure (config, logging, IO). Instantiates `TradingPipeline`.
|
||||||
* Parses args (`--config`, etc.).
|
2. **Data Loading (`TradingPipeline` calls `data_processing.load_and_preprocess`):** Loads the *entire* raw dataset.
|
||||||
* Loads `config.yaml`.
|
3. **Fold Generation (`TradingPipeline._generate_walk_forward_folds`):** Yields date ranges for each fold.
|
||||||
* Initializes `IOManager`, `LoggerSetup`.
|
4. **Fold Loop (`TradingPipeline.execute`):** Iterates through folds.
|
||||||
* Instantiates `TradingPipeline`.
|
* **Select Fold Data:** Extracts raw data for the current fold range.
|
||||||
2. **Data Loading (`TradingPipeline.load_and_preprocess_data`):**
|
* **Feature Engineering (`TradingPipeline` calls `data_processing.engineer_features_for_fold`):** Computes features on fold data.
|
||||||
* Loads the *entire* raw dataset specified in the config.
|
* **Labeling (`TradingPipeline` calls `data_processing.define_labels_and_align_fold`):** Calculates labels.
|
||||||
3. **Fold Generation (`TradingPipeline._generate_walk_forward_folds`):**
|
* **Split Fold Data (`TradingPipeline` calls `data_processing.split_data_fold`):** Splits into `train`, `val`, `test`.
|
||||||
* Based on `walk_forward` config (train/val/test/step days), yields date ranges for each fold.
|
* **Baseline Gate (`TradingPipeline` calls `evaluation.run_baseline_checks_fold`):** Runs Logistic Regression check. Halts fold on failure.
|
||||||
4. **Fold Loop (`TradingPipeline.execute`):** Iterates through generated folds.
|
* **Select Features (`TradingPipeline` calls `feature_processing.select_features_fold`):** Performs selection. Saves whitelist.
|
||||||
* **Select Fold Data:** Extracts raw data corresponding to the current fold's (Train+Val+Test) date range.
|
* **Scale Features (`TradingPipeline` calls `feature_processing.scale_features_fold`):** Fits/applies scaler. Saves scaler.
|
||||||
* **Feature Engineering (`engineer_features`):** Computes base, TA, and micro-structure features (with NaN guards) on the fold's raw data (using `shift(1)` for time-dependent features).
|
* **Prune Features (`TradingPipeline` calls `feature_processing.prune_features_fold`):** Prunes scaled data using whitelist.
|
||||||
* **Labeling (`define_labels_and_align`):** Calculates forward returns and target labels (binary/ternary) for the fold's engineered data.
|
* **Create Sequences (`TradingPipeline` calls `sequence_creation.create_sequences_fold`):** Creates GRU input sequences.
|
||||||
* **Split Fold Data (`split_data`):** Splits the fold's labeled data into `train`, `val`, and `test` sets based on the fold's date ranges. Stores results like `self.X_train_raw`, `self.y_val`, `self.df_test_original`.
|
* **Train/Load GRU (`TradingPipeline` calls `modelling.train_or_load_gru_fold`):** Handles training, Optuna sweep, or loading. Handles re-processing (scale/prune/sequence) internally if loaded scaler differs. Saves model/params.
|
||||||
* **Baseline Gate (`run_baseline_checks`):** Trains/validates Logistic Regression on fold's ***raw/engineered*** training features. Exits fold if CI lower bound < config threshold (`validation_gates.baseline`). Saves `baseline_report_fold_N.txt`.
|
* **Calibrate Probabilities (`TradingPipeline` calls `modelling.calibrate_probabilities_fold`):** Fits calibrator, optimizes edge threshold. Saves parameters.
|
||||||
* **Select Features (`select_and_prune_features` - Selection Part):** Performs leakage check (`corr(ret+h, feat_t-1)`) and L1 + VIF selection (applying minimal whitelist *before* VIF) on fold's *raw/engineered* training features. Saves `final_whitelist_fold_N.json`.
|
* **GRU Validation Gate (`TradingPipeline` calls `evaluation.run_gru_validation_checks_fold`):** Checks calibrated validation predictions. Halts fold on failure.
|
||||||
* **Scale Features (`scale_features`):** Fits `StandardScaler` on fold's *raw* training features (numeric only). Scales train, val, test features (`X_train_scaled`, etc.). Saves `feature_scaler_fold_N.joblib`.
|
* **Train/Load SAC (`TradingPipeline` calls `modelling.train_or_load_sac_fold`):** Handles SAC training (calling `SACTrainer`) or determines load path.
|
||||||
* **Prune Features (`select_and_prune_features` - Pruning Part):** Prunes the *scaled* data splits (`X_train_scaled` -> `X_train_pruned`) using the `final_whitelist` determined earlier.
|
* **Run Backtest (`TradingPipeline` calls `evaluation.run_backtest_fold`):** Instantiates `Backtester`, runs simulation, handles rolling calibration, performs backtest validation checks. Halts fold on failure.
|
||||||
* **Create Sequences (`create_sequences`):** Converts the fold's *pruned, scaled* train/val/test sets into sequences (`X_train_seq`, etc.).
|
* **Store Fold Results:** Appends metrics and SAC agent path (if trained) for aggregation.
|
||||||
* **Train/Load GRU (`train_or_load_gru`):**
|
5. **Aggregate Metrics (`TradingPipeline.aggregate_fold_metrics`):** Calculates summary statistics across successful folds.
|
||||||
* If `sweep_enabled`, runs `GRUHyperTuner` (with updated objective, restricted search space, Keras callback for pruning on `val_loss`, logging objective components) to find best hyperparameters using fold's train/val sequences. Trains final model with best params. Saves best params JSON and Optuna plots.
|
6. **Aggregate SAC Agents (`TradingPipeline` calls `modelling.aggregate_sac_agents`):** (Optional) Averages weights of successful fold agents.
|
||||||
* If not sweeping, trains/loads GRU using config defaults.
|
7. **Final Release Decision (`TradingPipeline.final_release_decision`):** Evaluates aggregated metrics against final criteria.
|
||||||
* Saves the final fold GRU model (`gru_model_fold_N.keras`), history, and learning curve plot.
|
8. **Log Final Status:** Reports overall pipeline success/failure.
|
||||||
* **Calibrate Probabilities (`calibrate_probabilities`):**
|
|
||||||
* Fits the selected calibrator (Temp/Vector with optional L2 reg) on the fold's validation sequences.
|
|
||||||
* If `optimize_edge_threshold`, calculates optimal threshold using Youden's J, stores it internally (`self.optimized_edge_threshold`), and saves it (`optimized_edge_threshold_fold_N.txt`).
|
|
||||||
* Saves fold calibration parameters (`calibration_{temp/vector}_fold_N.npy`).
|
|
||||||
* **GRU Validation Gate (`_perform_gru_validation_checks`):** Checks edge-filtered accuracy CI and Brier score against updated config thresholds (`validation_gates.gru`) using the fold's determined `optimized_edge_threshold`. Exits fold if failed.
|
|
||||||
* **Train/Load SAC (`train_or_load_sac`):**
|
|
||||||
* If `train_sac`, initializes `SACTrainer` using a config copy (passing the fold's `optimized_edge_threshold` and disabling rolling calibration if active). Trains agent handling PER (with clipping, alpha annealing), Oracle Seeding (with IS weight decay), State Normalization, Action Penalty (`0.01/cost`). Saves agent, filter state, logs, plots in a fold-specific `sac_train_...` dir.
|
|
||||||
* If loading, determines path from config.
|
|
||||||
* **Run Backtest (`run_backtest`):**
|
|
||||||
* Initializes `Backtester`.
|
|
||||||
* Passes the fold's SAC agent path, test sequences, GRU handler, *initial* calibration state, the fold's `optimized_edge_threshold`, and original test prices.
|
|
||||||
* If `rolling_enabled`, the backtester uses raw predictions to refit calibrator, potentially triggered early by **ECE Coverage Alarm** (`ECE > config threshold`).
|
|
||||||
* Logs backtest performance (Sharpe, MDD, Win Rate) and gate pass/fail status. Fold failure here does *not* halt the pipeline immediately but is recorded.
|
|
||||||
* **Store Fold Results:** Appends the `backtest_metrics` dict (including status) to `all_fold_metrics`. Stores the `sac_agent_load_path` if SAC was trained successfully.
|
|
||||||
5. **Aggregate Metrics (`aggregate_fold_metrics`):** Calculates summary statistics (mean, std, min, max, median) across metrics from all *successful* folds. Saves `aggregated_wf_metrics.json`.
|
|
||||||
6. **Aggregate SAC Agents (`aggregate_sac_agents`):** (Optional: if `sac_aggregation.enabled`)
|
|
||||||
* Loads SAC agents from the stored paths of successful folds.
|
|
||||||
* Averages the weights of the loaded agents.
|
|
||||||
* Saves the aggregated agent to the main run's model dir (`models/run_.../sac_agent_aggregated/`). Saves `sac_aggregation_info.txt`.
|
|
||||||
7. **Final Release Decision (`final_release_decision`):** Evaluates aggregated metrics against overall release criteria defined in `validation_gates.final_release` (e.g., min % successful folds, median Sharpe).
|
|
||||||
8. **Log Final Status:** Logs whether the pipeline passed or failed the final release criteria.
|
|
||||||
|
|
||||||
*(Note: If Walk-Forward is disabled (`walk_forward.enabled=false`), the pipeline runs steps D-T once using static splits based on `split_ratios` config).*
|
*(Note: If Walk-Forward is disabled, the pipeline runs the fold processing steps once using static splits).*\
|
||||||
|
|
||||||
|
|
||||||
|
## Project Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
gru_sac_predictor/
|
||||||
|
├── config/
|
||||||
|
│ └── config.yaml # Main configuration file
|
||||||
|
├── data/ # Data storage (e.g., parquet files)
|
||||||
|
├── logs/ # Log output directory (run-specific subdirs)
|
||||||
|
├── models/ # Saved models/scalers etc. (run-specific subdirs)
|
||||||
|
├── results/ # Output results, metrics, plots (run-specific subdirs)
|
||||||
|
├── src/
|
||||||
|
│ ├── pipeline_stages/ # **Refactored stage-specific logic**
|
||||||
|
│ │ ├── data_processing.py
|
||||||
|
│ │ ├── evaluation.py
|
||||||
|
│ │ ├── feature_processing.py
|
||||||
|
│ │ ├── modelling.py
|
||||||
|
│ │ └── sequence_creation.py
|
||||||
|
│ ├── __init__.py
|
||||||
|
│ ├── backtester.py # Backtesting simulation engine
|
||||||
|
│ ├── baseline_checker.py # Baseline logistic regression check
|
||||||
|
│ ├── calibrator.py # Temperature scaling
|
||||||
|
│ ├── calibrator_vector.py # Vector scaling
|
||||||
|
│ ├── data_loader.py # Loads raw data
|
||||||
|
│ ├── feature_engineer.py # Feature creation logic
|
||||||
|
│ ├── features.py # Feature lists/definitions (optional)
|
||||||
|
│ ├── gru_hyper_tuner.py # Optuna hyperparameter tuner for GRU
|
||||||
|
│ ├── gru_model_handler.py # GRU model building, training, loading (v2/v3)
|
||||||
|
│ ├── io_manager.py # Handles file I/O structure
|
||||||
|
│ ├── logger_setup.py # Configures logging
|
||||||
|
│ ├── metrics.py # Performance metrics calculation
|
||||||
|
│ ├── sac_agent.py # SAC agent network definitions
|
||||||
|
│ ├── sac_trainer.py # Offline SAC training orchestration
|
||||||
|
│ ├── trading_env.py # Gym-like environment for SAC training
|
||||||
|
│ ├── trading_pipeline.py # Main pipeline orchestrator class
|
||||||
|
│ └── utils/ # Utility functions (e.g., run_id generation)
|
||||||
|
├── tests/ # Unit/integration tests (optional)
|
||||||
|
├── run.py # Main execution script
|
||||||
|
├── requirements.txt # Python dependencies
|
||||||
|
└── README.md # This file
|
||||||
|
```
|
||||||
|
|
||||||
## Core Components Architecture
|
## Core Components Architecture
|
||||||
|
|
||||||
This section details the architecture and purpose of key modules.
|
* **`run.py`**: Entry point, sets up IO/logging, runs the pipeline.
|
||||||
|
* **`TradingPipeline` (`src/trading_pipeline.py`):** Orchestrates the workflow by calling stage functions, manages overall state, handles walk-forward loop and validation gates.
|
||||||
|
* **`src/pipeline_stages/*.py`**: Contain the core implementation logic for each distinct step of the pipeline (data processing, feature processing, sequencing, modelling, evaluation).
|
||||||
|
* **`IOManager`, `LoggerSetup`**: Utilities for managing outputs and logging.
|
||||||
|
* **`DataLoader`, `FeatureEngineer`**: Data loading and feature generation.
|
||||||
|
* **`GRUModelHandler`, `GRUHyperTuner`**: GRU model implementation (v2/v3), training, loading, and Optuna tuning.
|
||||||
|
* **`Calibrator`, `VectorCalibrator`**: Probability calibration logic.
|
||||||
|
* **`SACTradingAgent`, `TradingEnv`, `SACTrainer`**: SAC agent definition, training environment, and offline training orchestration (including PER, normalization, seeding).
|
||||||
|
* **`Backtester`, `BaselineChecker`, `metrics.py`**: Backtesting simulation, baseline checks, and performance metric calculations.
|
||||||
|
|
||||||
### 1. Orchestration & Utilities (`TradingPipeline`, `IOManager`, `LoggerSetup`)
|
|
||||||
* **`TradingPipeline` (`src/trading_pipeline.py`):** The main class coordinating the workflow (data prep, feature eng, model training, calibration, backtesting, aggregation) potentially within a walk-forward loop. Reads config, manages data flow, implements validation gates, and calls other components.
|
|
||||||
* **`IOManager` (`src/io_manager.py`):** Handles standardized file I/O (saving/loading models, dataframes, scalers, configs, plots, reports) within a run-specific directory structure.
|
|
||||||
* **`LoggerSetup` (`src/logger_setup.py`):** Configures Python's `logging` for console and file output with standardized formats and levels.
|
|
||||||
|
|
||||||
### 2. Data Handling (`DataLoader`, `FeatureEngineer`)
|
|
||||||
* **`DataLoader` (`src/data_loader.py`):** Loads raw OHLCV data from specified database files/sources.
|
|
||||||
* **`FeatureEngineer` (`src/feature_engineer.py`):** Generates features:
|
|
||||||
* **TA:** Returns, ATR, EMA, RSI (MACD removed).
|
|
||||||
* **Cyclical:** Hour, Week (sin/cos).
|
|
||||||
* **Imbalance:** Chaikin AD, SVI, Gap Imbalance.
|
|
||||||
* **Micro-structure:** Spread Proxy, Vol-Norm Volume Spike, Return Asymmetry, Close-Location Value, Keltner Band Position (with NaN guards).
|
|
||||||
* **Leakage Guard:** Uses `shift(1)` on inputs for time-dependent features.
|
|
||||||
* **Target Definition:** Calculates forward returns and binary/ternary labels.
|
|
||||||
* **Selection/Pruning:** Performs **leakage check** (`corr(ret+h, feat_t-1)`), then L1+VIF selection on *raw/engineered* data (applying minimal whitelist before VIF), then prunes *scaled* data based on the selection.
|
|
||||||
|
|
||||||
### 3. GRU Predictor (`gru_model_handler.py`, `gru_hyper_tuner.py`)
|
|
||||||
* **`gru_model_handler.py`:** **Consolidated GRU implementation.**
|
|
||||||
* Contains builders for both v2 (`build_gru_model`) and v3 (`build_gru_model_v3`) architectures.
|
|
||||||
* **v3 Architecture:** GRU -> LayerNorm -> Optional MultiHeadAttention -> GlobalAvgPool -> Dense Heads (`mu`, `dir3_logits`). Includes L2 regularization. Uses Huber loss for `mu`, Focal loss for `dir3`.
|
|
||||||
* Manages training (with early stopping, CSV logging), saving (.keras format), loading (handles custom loss `gaussian_nll`), and prediction (`predict`, `predict_logits`).
|
|
||||||
* Selects v2/v3 based on `control.use_v3` flag.
|
|
||||||
* **`gru_hyper_tuner.py`:** Implements Optuna hyperparameter sweep for GRU.
|
|
||||||
* Called by `TradingPipeline` if sweep is enabled.
|
|
||||||
* Uses **restricted search space** based on config (`hyperparameter_tuning.gru`).
|
|
||||||
* Uses combined objective based on config (`objective_metric`, `objective_edge_acc_weight`, `objective_brier_weight`). Logs components to trial attributes.
|
|
||||||
* Supports pruning via **Keras callback** reporting `val_loss` each epoch.
|
|
||||||
* Trains final fold model using best found parameters.
|
|
||||||
* Saves best parameters JSON and Optuna plots per fold.
|
|
||||||
|
|
||||||
### 4. Probability Calibration (`Calibrator`, `VectorCalibrator`, `metrics.py`)
|
|
||||||
* **`Calibrator` (`src/calibrator.py`):** Implements Temperature Scaling (learns scalar `T`) with optional L2 regularization.
|
|
||||||
* **`VectorCalibrator` (`src/calibrator_vector.py`):** Implements Vector Scaling (learns matrix `W`, bias `b`) with optional L2 regularization. Preferred for ternary.
|
|
||||||
* **Integration:** `TradingPipeline.calibrate_probabilities` fits chosen calibrator per fold (with L2 reg). If `optimize_edge_threshold=true`, calculates and stores optimal edge for the fold. `Backtester` applies calibration step-by-step, handles rolling recalibration if enabled (with **ECE-based coverage alarm**), using the **static fold calibration** if SAC training was active for the fold.
|
|
||||||
* **Edge Threshold Optimization:** `metrics._calculate_optimal_edge_threshold` finds best threshold using Youden's J. Called by `TradingPipeline` if enabled.
|
|
||||||
|
|
||||||
### 5. SAC Agent (`sac_agent.py`, `sac_trainer.py`, `trading_env.py`)
|
|
||||||
* **`sac_agent.py`:** Defines the SAC agent networks (Actor, Critic) and update logic.
|
|
||||||
* Actor outputs squashed Gaussian distribution parameters.
|
|
||||||
* Uses twin Q-critics.
|
|
||||||
* Handles automatic entropy tuning (alpha).
|
|
||||||
* `train` method accepts a batch and returns losses + TD errors (for PER).
|
|
||||||
* **`trading_env.py`:** Gym-style environment using GRU predictions.
|
|
||||||
* **State:** `[mu, sigma, edge, |mu|/sigma, position]`.
|
|
||||||
* Takes action (-1 to +1).
|
|
||||||
* Calculates reward based on PnL, potentially scaled (`reward_scale`) and penalized (**action penalty**, default: `0.01 / transaction_cost`).
|
|
||||||
* **`sac_trainer.py`:** Orchestrates offline SAC training.
|
|
||||||
* Loads GRU dependencies for a specific run.
|
|
||||||
* Prepares validation data for the `TradingEnv` (uses fold's static calibration).
|
|
||||||
* Initializes `TradingEnv` and `SACTradingAgent`.
|
|
||||||
* Manages the Replay Buffer:
|
|
||||||
* Implements `PrioritizedReplayBuffer` if `sac.use_per` is true (with **TD error clipping** and **alpha annealing**). Logs TD error distribution stats.
|
|
||||||
* Uses `collections.deque` for standard uniform replay.
|
|
||||||
* Handles **Oracle Seeding** into the buffer (with **IS weight decay** for seeded samples).
|
|
||||||
* Runs the training loop: interacts with env, stores transitions, samples batches (uniform or PER), updates agent, updates PER priorities.
|
|
||||||
* Handles **State Normalization** using `MeanStdFilter` if `sac.use_state_filter` is true (saves/loads filter state).
|
|
||||||
* Saves agent checkpoints, final agent, state filter, and logs (rewards, TensorBoard).
|
|
||||||
|
|
||||||
### 6. Evaluation (`Backtester`, `BaselineChecker`, `metrics.py`)
|
|
||||||
* **`Backtester` (`src/backtester.py`):** Evaluates the full system on test data.
|
|
||||||
* Takes trained GRU, SAC agent, initial calibration state, and the fold's edge threshold.
|
|
||||||
* Simulates step-by-step trading.
|
|
||||||
* Applies **rolling calibration** logic if enabled (with ECE check).
|
|
||||||
* Calculates PnL, equity curve, standard performance metrics.
|
|
||||||
* Saves detailed results dataframe, metrics summary, and plots per fold.
|
|
||||||
* Logs performance and whether fold backtest gates passed/failed, but does **not** halt the fold on failure (decision made during final aggregation).
|
|
||||||
* **`BaselineChecker` (`src/baseline_checker.py`):** Performs initial logistic regression check on **raw/engineered** training features.
|
|
||||||
* **`metrics.py`:** Contains calculation functions for Sharpe, Brier, Edge-Filtered Accuracy, ECE, and the Youden's J optimization helper.
|
|
||||||
|
|
||||||
## Configuration (`config.yaml`)
|
## Configuration (`config.yaml`)
|
||||||
|
|
||||||
The `config.yaml` file centrally controls the pipeline's behavior. Key sections include:
|
The `config.yaml` file centrally controls the pipeline's behavior. See comments within the default `config.yaml` for detailed explanations of each parameter. Key sections include:
|
||||||
|
|
||||||
* `base_dirs`: Output directories.
|
* `base_dirs`, `output`: Directory and output settings.
|
||||||
* `output`: Figure DPI, size, logging level.
|
* `data`, `features`: Data sources, labeling, feature selection controls.
|
||||||
* `data`: Data source details, label smoothing.
|
* `walk_forward`, `split_ratios`: Controls walk-forward vs. static splits.
|
||||||
* `features`: Minimal feature whitelist, leakage threshold.
|
* `gru`, `gru_v3`: GRU architecture, training parameters.
|
||||||
* `walk_forward`: Settings for WF validation (enable, days, step). **Note:** `walk_forward.enabled=true` overrides `split_ratios`.
|
* `hyperparameter_tuning`: Optuna sweep settings for GRU.
|
||||||
* `split_ratios`: Used only if `walk_forward.enabled=false`.
|
* `calibration`: Calibration method, parameters, rolling calibration, ECE alarm.
|
||||||
* `gru`: General GRU settings (horizon, lookback, ternary flag, flat sigma mult).
|
* `validation_gates`: Thresholds for baseline, GRU, backtest, and final release checks.
|
||||||
* `gru_v3`: Specific hyperparameters for the v3 architecture (units, attention, losses, reg).
|
* `sac`, `environment`: SAC agent hyperparameters, PER, seeding, environment settings.
|
||||||
* `hyperparameter_tuning`: Controls Optuna sweep for GRU (enable, trials, timeout, pruning, objective metric/weights).
|
* `sac_aggregation`: Agent averaging settings.
|
||||||
* `calibration`: Method (temp/vector), L2 lambda, optimize edge threshold flag, rolling calibration settings (enable, freq, window, ECE alarm).
|
* `control`: High-level flags (train/load models, enable plots, use v3 GRU).
|
||||||
* `validation_gates`: Thresholds for baseline, GRU gates, and final release decision (median Sharpe, % success).
|
|
||||||
* `sac`: SAC hyperparameters (gamma, tau, LR, alpha, PER settings, oracle seeding, IS weight decay steps, state filter).
|
## Installation
|
||||||
* `sac_aggregation`: Controls post-run agent averaging (enable, method).
|
|
||||||
* `environment`: Trading env parameters (capital, costs, reward scale, action penalty lambda).
|
1. Clone the repository.
|
||||||
* `control`: Flags to enable/disable major stages (train GRU, train SAC, run backtest, use v3, plots), model loading/resuming IDs.
|
2. Ensure you have Python 3.8+ installed.
|
||||||
|
3. Set up a virtual environment (recommended):
|
||||||
|
```bash
|
||||||
|
python -m venv .venv
|
||||||
|
source .venv/bin/activate # On Windows use `.venv\\Scripts\\activate`
|
||||||
|
```
|
||||||
|
4. Install dependencies:
|
||||||
|
```bash
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
*Note: This installs necessary libraries like TensorFlow, PyTorch, Optuna, scikit-learn, pandas, etc.*
|
||||||
|
5. Prepare your data according to the expected format and update paths in `config.yaml`.
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
1. **Setup:** Install requirements (`pip install -r requirements.txt`), prepare data in the specified format/location.
|
1. **Configure:** Edit `config.yaml` to set data paths, feature lists, model parameters, control flags (e.g., `train_gru`, `train_sac`), walk-forward settings, validation thresholds, etc.
|
||||||
2. **Configure:** Edit `config.yaml` (data paths, feature lists, model params, control flags, walk-forward settings, calibration, validation thresholds, tuning, aggregation).
|
2. **Run Pipeline:** Execute `run.py` from the project root directory (`gru_sac_predictor/`), specifying the configuration file:
|
||||||
3. **Run Pipeline:**
|
```bash
|
||||||
```bash
|
# Example execution from the parent directory 'develop/gru_sac_predictor/'
|
||||||
# From project root (develop/gru_sac_predictor/)
|
python gru_sac_predictor/run.py --config gru_sac_predictor/config/config.yaml
|
||||||
python gru_sac_predictor/run.py --config path/to/your_config.yaml
|
|
||||||
```
|
```
|
||||||
4. **Outputs:** Check `logs/`, `models/`, `results/` directories for run-specific outputs, including fold-specific artifacts if WF is enabled.
|
* You can use other command-line arguments like `--use-ternary` if implemented in `run.py`.
|
||||||
|
3. **Outputs:** Check the directories specified in `config.yaml` (typically subdirectories within `logs/`, `models/`, `results/`) for run-specific outputs. If Walk-Forward is enabled, you will find fold-specific subdirectories containing models, scalers, plots, and results for each fold.
|
||||||
|
|
||||||
## Output Artifacts (Walk-Forward Enabled Example)
|
## Output Artifacts (Walk-Forward Enabled Example)
|
||||||
|
|
||||||
* **Main Run Dirs:** `logs/run_<id>/`, `models/run_<id>/`, `results/run_<id>/`
|
* **Main Run Dirs:** `logs/run_<id>/`, `models/run_<id>/`, `results/run_<id>/`
|
||||||
* `run_config.yaml`, `pipeline_<id>.log`
|
* `run_config.yaml`, `pipeline_<id>.log`
|
||||||
* (Post-run) `aggregated_wf_metrics.json`, `sac_aggregation_info.txt`
|
* (Post-run) `aggregated_wf_metrics.json`, `sac_aggregation_info.txt` (if enabled)
|
||||||
* (Post-run) `models/.../sac_agent_aggregated/`
|
* (Post-run) `models/.../sac_agent_aggregated/` (if enabled)
|
||||||
* **Fold Dirs (within main run dirs):** e.g., `models/run_<id>/fold_1/`
|
* **Fold Dirs (within main run dirs):** e.g., `models/run_<id>/fold_1/`
|
||||||
* `models/run_<id>/fold_N/models/`: `gru_model_fold_N.keras`, `calibration_{...}_fold_N.npy`, `feature_scaler_fold_N.joblib`, `final_whitelist_fold_N.json`
|
* `models/run_<id>/fold_N/models/`: `gru_model_fold_N.keras`, `calibration_{...}_fold_N.npy`, `feature_scaler_fold_N.joblib`, `final_whitelist_fold_N.json`
|
||||||
* `models/run_<id>/fold_N/hypertuning/`: (If sweep enabled) `best_gru_params.json`, Optuna plots.
|
* `models/run_<id>/fold_N/hypertuning/`: (If sweep enabled) `best_gru_params.json`, Optuna plots.
|
||||||
@ -248,4 +259,17 @@ The `config.yaml` file centrally controls the pipeline's behavior. Key sections
|
|||||||
|
|
||||||
## Dependencies
|
## Dependencies
|
||||||
|
|
||||||
See `requirements.txt`. Key libraries include: TensorFlow, NumPy, Pandas, PyYAML, Scikit-learn, Statsmodels, TA-Lib (via `ta` wrapper), Matplotlib, Seaborn, Optuna, PyTorch (for SAC aggregation). Note `tensorflow-addons` is required for optimal focal loss / attention layers.
|
All major Python dependencies are listed in `requirements.txt`. Key libraries include:
|
||||||
|
|
||||||
|
* TensorFlow (for GRU)
|
||||||
|
* PyTorch (for SAC)
|
||||||
|
* Optuna (for hyperparameter tuning)
|
||||||
|
* scikit-learn (for scaling, metrics, baseline)
|
||||||
|
* pandas, numpy
|
||||||
|
* pyyaml
|
||||||
|
* matplotlib, seaborn
|
||||||
|
|
||||||
|
Install them using:
|
||||||
|
```bash
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|||||||
26
gru_sac_predictor/docs/v3_changelog.md
Normal file
26
gru_sac_predictor/docs/v3_changelog.md
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user