updated readme

This commit is contained in:
Yasha Sheynin 2025-04-16 16:54:05 -04:00
parent f598141500
commit 984a230bcd
10 changed files with 95 additions and 52 deletions

View File

@ -1,10 +1,10 @@
# v7 - GRU + Simplified SAC Trading Agent (V6 GRU Adaptation)
# GRU + Simplified SAC Trading Agent
This project implements a cryptocurrency trading system using a GRU model for price prediction and a **Simplified SAC (Soft Actor-Critic)** agent for position sizing.
The system predicts future *price* using a GRU model adapted from the V6 architecture. It calculates the *predicted percentage return* from this price prediction and estimates prediction *uncertainty* based on the standard deviation of Monte Carlo dropout predictions. These two values (`predicted_return`, `mc_unscaled_std_dev`) form the state input to the SAC reinforcement learning agent, which determines optimal position sizing (-1 to +1).
The system predicts future *price* using a GRU model adapted from the V6 architecture. It calculates the *predicted percentage return* from this price prediction and estimates prediction *uncertainty* based on the standard deviation of Monte Carlo dropout predictions. It also extracts recent *momentum* and *volatility* features. These values, along with a risk proxy (`z_proxy`), form the **5-dimensional state** input (`[predicted_return, mc_unscaled_std_dev, z_proxy, momentum_5, volatility_20]`) to the SAC reinforcement learning agent, which determines optimal position sizing (-1 to +1) using a **squashed Gaussian policy** and **automatic entropy tuning**.
The system incorporates efficiency improvements by pre-computing GRU predictions and uncertainties before generating SAC experiences or running the backtest. It includes detailed backtesting, performance reporting, and visualization capabilities.
The system incorporates efficiency improvements by pre-computing GRU predictions and uncertainties before generating SAC experiences or running the backtest. It includes detailed backtesting, performance reporting, and visualization capabilities, including **SAC training loss plots**.
## System Design
@ -20,22 +20,23 @@ The system integrates a GRU predictor and a Simplified SAC agent within a backte
* *Scaling*: Within `TradingSystem.train_gru`, a `StandardScaler` is fitted *only* on the training features. A `MinMaxScaler` is fitted *only* on the training future *price* targets. Train and validation features/targets are scaled using these fitted scalers.
* *Sequence Creation*: `src.data_pipeline.create_sequences_v2` creates input sequences `(batch, sequence_length, num_features)` and corresponding scaled target prices using the scaled features/targets and the unscaled start prices.
* *Model Training*: `CryptoGRUModel.train` builds the V6-style GRU model (if not already built) and trains it using Mean Squared Error (MSE) loss on the scaled sequences. Callbacks monitor `val_rmse` for early stopping and model checkpointing. The best model (`best_model_reg.keras`) and the fitted scalers (`feature_scaler.joblib`, `y_scaler.joblib`) are saved.
* If `LOAD_EXISTING_SYSTEM` is `True` and `TRAIN_GRU_MODEL` is `False`: Attempts to load a pre-trained GRU model and scalers. If `GRU_MODEL_LOAD_RUN_ID` is set in `main.py`, it loads from that specific run ID's directory (`v7/models/run_<run_id>`); otherwise, it attempts to load from the default `MODEL_SAVE_PATH` (expecting a `gru_model` subdirectory).
* If `LOAD_EXISTING_SYSTEM` is `True` and `TRAIN_GRU_MODEL` is `False`:
* Attempts to load a pre-trained GRU model and scalers. If `GRU_MODEL_LOAD_RUN_ID` is set in `main.py`, it loads the GRU from that specific run ID's directory (`gru_sac_predictor/models/run_<run_id>`); otherwise, it attempts to load from the default `MODEL_SAVE_PATH` (expecting the model and scalers to be directly in that path).
* **Note:** SAC model loading is handled *separately* based on the `LOAD_SAC_AGENT` flag and the `GRU_MODEL_LOAD_RUN_ID` setting (see Model Loading/Training section in `main.py` for details).
4. **SAC Training (on Validation Set):**
* **Training Loop:** The training process runs for a fixed number of agent update steps (`TOTAL_TRAINING_STEPS`) instead of epochs.
* **Training Loop:** The training process runs for a fixed number of epochs (`SAC_EPOCHS`).
* **Experience Generation** (`TradingSystem.generate_trading_experiences`):
* **Efficiency:** Pre-computes all required GRU outputs (predicted returns, uncertainties) for the entire validation set by calling `CryptoGRUModel.evaluate` *once*.
* **Initial Fill:** Generates an initial set of experiences (`experience_config['initial_experiences']`). Uses the sampling strategy.
* **Sampling (`_sample_experience_indices`):** When generating a specific number of experiences (initial fill or periodic updates), it applies **weighted sampling** (controlled by `recency_bias_strength`) and **stratified sampling** (ensuring minimum ratios `min_uncertainty_ratio`, `min_extreme_return_ratio` of high uncertainty/extreme return examples based on quantiles `high_uncertainty_quantile`, `extreme_return_quantile`) based on parameters in `experience_config`.
* **Experience Format:** Iterates through the (potentially sampled) pre-computed results. Forms the state `s_t = [predicted_return_t, uncertainty_t]`. The SAC agent (`SimplifiedSACTradingAgent.get_action`) provides a *non-deterministic* action `a_t`. The next state `s_{t+1}` is retrieved. A reward `r_t = action * actual_return` is calculated (transaction costs are currently ignored in reward calculation during generation for simplicity). The transition `(s_t, a_t, r_t, s_{t+1}, done=False)` is created.
* **Periodic Updates:** During the main training loop (controlled by `total_training_steps`), new batches of experiences (`experience_config['experiences_per_batch']`) are generated periodically (every `experience_config['batch_generation_interval']` loop steps) using the sampling strategy and added to the replay buffer.
* **Agent Training** (`SimplifiedSACTradingAgent.train`): In each step of the main training loop, the agent performs `experience_config['training_iterations_per_step']` update(s). Batches are sampled from the replay buffer. Actor and Critic networks are updated using the SAC algorithm. The agent uses a standard FIFO circular buffer for experience storage.
* **State Extraction:** Extracts pre-computed GRU outputs and relevant features (`momentum_5`, `volatility_20`) from the validation features dataframe.
* **Experience Format:** Iterates through the pre-computed results. Forms the 5D state `s_t = [pred_return_t, uncertainty_t, z_proxy_t, momentum_5_t, volatility_20_t]` (where `z_proxy` uses the position *before* the action). The SAC agent (`SimplifiedSACTradingAgent.get_action`) provides a *non-deterministic* action `a_t` and `log_prob`. The next state `s_{t+1}` is constructed similarly (using `action` for `z_proxy`). A reward `r_t = action * actual_return - cost` is calculated. The transition `(s_t, a_t, r_t, s_{t+1}, done)` is stored.
* **Note:** Experience sampling strategies (recency bias, stratification) defined in `experience_config` are currently *not* implemented in `generate_trading_experiences` but the configuration remains.
* **Agent Training** (`TradingSystem.train_sac` calls `SimplifiedSACTradingAgent.train`): Iterates for `SAC_EPOCHS`. In each epoch, the agent performs one training step. Batches are sampled from the replay buffer. Actor and Critic networks are updated using the SAC algorithm with automatic alpha tuning. Agent uses `store_transition` to add experiences to its internal NumPy buffer.
* **History Plotting:** After successful training, `plot_sac_training_history` is called to generate and save a plot of actor and critic losses.
5. **Backtesting (on Test Set):**
* *Pre-computation* (`ExtendedBacktester.backtest`): Similar to SAC training, preprocesses the test data, scales it, creates sequences, and calls `CryptoGRUModel.evaluate` *once* to get all predicted returns and uncertainties for the test set.
* *Iteration*: Steps chronologically through the pre-computed results.
* *State Generation*: Retrieves `predicted_return` and `uncertainty_sigma` from the pre-computed arrays to form the state `s_t`.
* *Action Selection*: The trained `SimplifiedSACTradingAgent` selects a *deterministic* action `a_t`.
* *Portfolio Simulation*: Calculates PnL based on the previous position held (`current_position`), the actual return over the step, and subtracts transaction costs based on the change in position (`abs(action - current_position)`).
* *Pre-computation* (`ExtendedBacktester.backtest`): Preprocesses test data, scales, creates sequences, calls `CryptoGRUModel.evaluate` once for GRU outputs, and extracts required features (`momentum_5`, `volatility_20`).
* *State Generation*: Constructs the 5D state `s_t = [pred_return, uncertainty, z_proxy, momentum_5, volatility_20]` using pre-computed results and the current position.
* *Action Selection*: The trained `SimplifiedSACTradingAgent` selects a *deterministic* action `a_t` (unpacking the tuple returned by `get_action`).
* *Portfolio Simulation*: Calculates PnL based on the previous position, actual return, and transaction costs.
* *Logging*: Records detailed metrics, trade history, and timestamps.
6. **Evaluation:**
* *Performance Metrics*: `ExtendedBacktester._calculate_performance_metrics` computes overall portfolio metrics (Sharpe, Sortino, Drawdown, correlations, etc.) and Buy & Hold benchmark metrics.
@ -53,15 +54,16 @@ The system integrates a GRU predictor and a Simplified SAC agent within a backte
* `evaluate()`: Performs standard prediction and MC dropout inference. Returns dict including `pred_percent_change`, `mc_unscaled_std_dev`, `predicted_unscaled_prices`, `true_unscaled_prices`.
* **`src.sac_agent_simplified.SimplifiedSACTradingAgent`**: (V7 Simplified)
* **Goal:** Learns a policy mapping state to optimal position size (-1.0 to +1.0). Optimized for faster training.
* **State Input:** 2-element array `[predicted_return, mc_unscaled_std_dev]`.
* **State Input:** 5-element array `[predicted_return, mc_unscaled_std_dev, z_proxy, momentum_5, volatility_20]`.
* **Action Output:** Float between -1.0 and +1.0.
* `get_action()`: Selects action (stochastic or deterministic). Adds uncertainty-scaled noise during exploration.
* `store_transition()`: Adds experience to internal NumPy buffer.
* `train()`: Updates agent using buffer samples (internally handles batch size). Uses `@tf.function` for performance.
* `save()` / `load()`: Handles Actor/Critic weights (`.weights.h5`), potentially `alpha.npy`.
* **Note:** Models and optimizers are built explicitly during `__init__` to prevent TensorFlow graph mode issues.
* **`src.trading_system.TradingSystem`**: Integrates GRU and SAC. Manages training pipelines, experience generation (including advanced sampling).
* **Note:** Models and optimizers are built explicitly during `__init__` using dummy inputs to prevent TensorFlow graph mode issues.
* **`src.trading_system.TradingSystem`**: Integrates GRU and SAC. Manages training pipelines, feature calculation, experience generation.
* **`src.trading_system.ExtendedBacktester`**: Performs efficient backtesting using pre-computed GRU outputs, calculates metrics, plots results, generates reports.
* **`src.trading_system.plot_sac_training_history`**: Generates plot for SAC actor/critic losses during training.
### 3. Model Architectures
@ -69,46 +71,46 @@ The system integrates a GRU predictor and a Simplified SAC agent within a backte
* Input -> GRU(100) -> Dropout(0.2) -> Dense(1, linear).
* Compiled with Adam (LR=0.001), MSE loss.
* **Simplified SAC (`src.sac_agent_simplified.SimplifiedSACTradingAgent`)**:
* **Actor Network**: MLP `(state_dim=2)` -> Dense(64, relu) -> [BN] -> Dense(64, relu) -> [BN] -> [Residual] -> Dense(1, tanh).
* **Critic Network (x2)**: MLP `(state_dim=2 + action_dim=1)` -> Dense(64, relu) -> [BN] -> Dense(64, relu) -> [BN] -> [Residual] -> Dense(1, linear).
* **Algorithm**: Implements SAC with Clipped Double-Q, fixed alpha (tunable via `SAC_ALPHA`), faster learning rates, smaller networks/buffer, optional Batch Normalization / Residual connections. Uses Huber loss for critics. No distributional critics. `@tf.function` used for update steps.
* **Actor Network**: MLP `(state_dim=5)` -> Dense(64, relu) -> [BN] -> Dense(64, relu) -> [BN] -> [Residual] -> Dense(1, name='mu'), Dense(1, name='log_std'). Output is `mu` and `log_std` for a **Gaussian policy**. `log_std` is clipped.
* **Critic Network (x2)**: MLP `(state_dim=5 + action_dim=1)` -> Dense(64, relu) -> [BN] -> Dense(64, relu) -> [BN] -> [Residual] -> Dense(1, linear).
* **Algorithm**: Implements SAC with Clipped Double-Q, **automatic entropy tuning** (optimizing `alpha` based on `target_entropy`), squashed actions (`tanh`), faster learning rates, smaller networks/buffer, optional Batch Normalization / Residual connections. Uses Huber loss for critics. `@tf.function` used for update steps (`_update_critics`, `_update_actor_and_alpha`).
### 4. Features
### 4. Features & State Representation
The GRU model uses the V6 feature set plus basic past returns:
* **TA-Lib Indicators & Derived Indicators:** SMA, EMA, MACD, SAR, ADX, RSI, Stochastics, WILLR, ROC, CCI, BBands, ATR, OBV, CMF, etc. (Requires TA-Lib installation). Fallback calculations for SMA, EMA, RSI if TA-Lib is unavailable.
* **Custom Crypto Features:** Parkinson Volatility, Garman-Klass Volatility, VWAP ratios, Volume Intensity, Wick Ratios.
* **Past Returns:** `return_1m`, `return_5m`, `return_15m`, `return_60m` (percentage change).
* **Scaling:** Features scaled with `StandardScaler` (fitted on train). Target variable (future price) scaled with `MinMaxScaler` (fitted on train).
* **GRU Features:** Uses the V6 feature set plus basic past returns (see `calculate_v6_features`). Cyclical time features (`hour_sin`, `hour_cos`) are added *before* data splitting.
* **SAC State (`state_dim=5`):**
1. `predicted_return`: GRU predicted percentage return for the next period.
2. `uncertainty`: GRU MC dropout standard deviation (unscaled).
3. `z_proxy`: Risk proxy, calculated as `current_position * volatility_20`.
4. `momentum_5`: 5-minute return (`return_5m` feature).
5. `volatility_20`: 20-day volatility (`volatility_14d` feature, name mismatch intended).
* **Scaling:** Features for GRU scaled with `StandardScaler`. Target price for GRU scaled with `MinMaxScaler`. SAC state components are used directly without separate scaling.
### 5. Evaluation
* **GRU Model:** Evaluated using RMSE loss on validation set. Callbacks monitor `val_rmse`. Plots compare predicted vs actual price.
* **SAC Agent & Overall System:** Evaluated via the `ExtendedBacktester` metrics (Sharpe, Sortino, Max Drawdown, correlations, etc.), plots (Portfolio vs B&H, Actions), and a final Markdown report.
* **SAC Agent & Overall System:** Evaluated via the `ExtendedBacktester` metrics (Sharpe, Sortino, Max Drawdown, correlations, etc.), plots (Portfolio vs B&H, Actions), and a final Markdown report. SAC training progress monitored via saved loss plots (`sac_training_history_<run_id>.png`).
## File Structure
- `data/`: *Not used by default if loading from DB.*
- `downloaded_data/`: **Place your V6 SQLite database files here.** (Or update `DB_DIR` in `main.py`).
- `models/`: Trained models (GRU + scalers, SAC weights) saved here under `run_<run_id>/` directories by default.
- `results/`: Backtest results (plots, reports, config) saved here under `<run_id>/` directories.
- `downloaded_data/`: **Place your SQLite database files here.** (Or update `DB_DIR` in `main.py`).
- `gru_sac_predictor/`: Project root directory.
- `models/`: Trained models saved here under `run_<run_id>/` directories.
- `results/`: Backtest results saved here under `<run_id>/` directories.
- `logs/`: Log files saved here under `<run_id>/` directories.
- `src/`: Core Python modules.
- `crypto_db_fetcher.py`: Class for fetching data from SQLite DBs.
- `data_pipeline.py`: DB loading function, data splitting, sequence creation.
- `gru_predictor.py`: V6-style GRU model for price regression and MC uncertainty.
- `sac_agent_simplified.py`: Simplified SAC agent implementation (V7.5+).
- `sac_agent.py`: Original SAC agent implementation (potentially outdated).
- `trading_system.py`: Integration class, feature calculation, scaling, experience generation, `ExtendedBacktester` class.
- `main.py`: Main script using DB loading, orchestrates training and backtesting.
- `requirements.txt`: Dependencies.
- `v7_instructions.txt`: Design notes for Simplified SAC.
- `experience_instructions.txt`: Design notes for experience generation.
- `README.md`: This file.
- `crypto_db_fetcher.py`
- `data_pipeline.py`
- `gru_predictor.py`
- `sac_agent_simplified.py`
- `trading_system.py`
- `main.py`: Main script.
- `requirements.txt`
- `README.md`
## Setup
1. **Data:** Place your V6 `downloaded_data` directory containing the SQLite files relative to the `v7` project root, or update the `DB_DIR` variable in `main.py` to point to the correct location.
1. **Data:** Place your V6 `downloaded_data` directory containing the SQLite files relative to the `gru_sac_predictor` project root, or update the `DB_DIR` variable in `main.py` to point to the correct location.
2. **Dependencies:** Install required packages:
```bash
pip install -r requirements.txt
@ -118,15 +120,15 @@ The GRU model uses the V6 feature set plus basic past returns:
* `DB_DIR`, `TICKER`, `EXCHANGE`, `START_DATE`, `END_DATE`, `INTERVAL`
* Model hyperparameters (GRU and SAC sections)
* Control Flags: `LOAD_EXISTING_SYSTEM`, `TRAIN_GRU_MODEL`, `TRAIN_SAC_AGENT`, `LOAD_SAC_AGENT`
* Loading Specific Models: `GRU_MODEL_LOAD_RUN_ID` (set to a specific run ID string like `'YYYYMMDD_HHMMSS'` to load that GRU model from `v7/models/run_<run_id>/`). Note: This expects GRU and SAC files to be in the *same* directory if loading this way.
* SAC Training: `TOTAL_TRAINING_STEPS` defines the length of SAC training (number of agent `train()` calls).
* Experience Generation: `experience_config` dictionary controls initial fill, periodic updates, and sampling strategies (recency bias, stratification for uncertainty/extreme returns).
* Loading Specific Models: `GRU_MODEL_LOAD_RUN_ID` (set to a specific run ID string like `'YYYYMMDD_HHMMSS'` to load *only* the GRU model from `gru_sac_predictor/models/run_<run_id>/`). SAC loading depends on `LOAD_SAC_AGENT` flag.
* SAC Training: `SAC_EPOCHS` defines the number of training epochs.
* Experience Generation: `experience_config` dictionary (sampling strategies currently not implemented).
* Backtesting: `INITIAL_CAPITAL`, `TRANSACTION_COST`.
4. **Run:** Execute from the project root directory (containing the `v7` folder):
4. **Run:** Execute from the project root directory (the one *containing* `gru_sac_predictor`):
```bash
python -m v7.main
python -m gru_sac_predictor.main
```
Output files (logs, models, plots, report) will be generated in `v7/logs/`, `v7/models/`, and `v7/results/` within run-specific subdirectories.
Output files (logs, models, plots, report) will be generated in `gru_sac_predictor/logs/`, `gru_sac_predictor/models/`, and `gru_sac_predictor/results/` within run-specific subdirectories.
## Reporting

View File

@ -0,0 +1,41 @@
# GRU+SAC Backtesting Performance Report
Report generated on: 2025-04-16 16:52:19.447276
Data range: N/A
Total duration: N/A
## Strategy Performance Metrics
* **Initial capital:** $0.00
* **Final portfolio value:** $0.00
* **Total return:** 0.00%
* **Annualized return:** 0.00%
* **Sharpe ratio (annualized):** 0.0000
* **Sortino ratio (annualized):** 0.0000
* **Volatility (annualized):** 0.00%
* **Maximum drawdown:** 0.00%
* **Total trades:** 0
## Buy and Hold Benchmark
* *Buy and Hold benchmark could not be calculated.*
## Position & Prediction Analysis
* **Average absolute position size:** 0.0000
* **Position sign accuracy vs return:** 0.00%
* **Prediction sign accuracy vs return:** 0.00%
* **Prediction RMSE (on returns):** 0.000000
## Correlations
* **Prediction-Return correlation:** 0.0000
* **Prediction-Position correlation:** 0.0000
* **Uncertainty-Position Size correlation:** 0.0000
## Notes
* Transaction cost used: 0.0500% per position change value.
* GRU lookback period: 60 minutes.
* V6 features + return features used.
* Uncertainty estimated via MC Dropout standard deviation.

Binary file not shown.

After

Width:  |  Height:  |  Size: 188 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 68 KiB