initial commit for predictor module
This commit is contained in:
commit
f598141500
141
gru_sac_predictor/README.md
Normal file
141
gru_sac_predictor/README.md
Normal file
@ -0,0 +1,141 @@
|
||||
# v7 - GRU + Simplified SAC Trading Agent (V6 GRU Adaptation)
|
||||
|
||||
This project implements a cryptocurrency trading system using a GRU model for price prediction and a **Simplified SAC (Soft Actor-Critic)** agent for position sizing.
|
||||
|
||||
The system predicts future *price* using a GRU model adapted from the V6 architecture. It calculates the *predicted percentage return* from this price prediction and estimates prediction *uncertainty* based on the standard deviation of Monte Carlo dropout predictions. These two values (`predicted_return`, `mc_unscaled_std_dev`) form the state input to the SAC reinforcement learning agent, which determines optimal position sizing (-1 to +1).
|
||||
|
||||
The system incorporates efficiency improvements by pre-computing GRU predictions and uncertainties before generating SAC experiences or running the backtest. It includes detailed backtesting, performance reporting, and visualization capabilities.
|
||||
|
||||
## System Design
|
||||
|
||||
The system integrates a GRU predictor and a Simplified SAC agent within a backtesting framework.
|
||||
|
||||
### 1. Data Flow & Processing
|
||||
|
||||
1. **Loading:** Raw 1-minute OHLCV data is loaded from the SQLite database directory specified in `main.py` (e.g., `downloaded_data/`) using `src.data_pipeline.load_data_from_db` which utilizes `src.crypto_db_fetcher.CryptoDBFetcher`.
|
||||
2. **Splitting:** Data is chronologically split into training (60%), validation (20%), and test (20%) sets using `src.data_pipeline.create_data_pipeline`.
|
||||
3. **GRU Training / Loading (on Train/Validation Sets):**
|
||||
* If `TRAIN_GRU_MODEL` is `True`:
|
||||
* *Preprocessing*: `TradingSystem._preprocess_data_for_gru_training` calculates V6 features plus basic return features (`calculate_v6_features`) on the raw train/val data. It determines the future *price* target (`prediction_horizon` steps ahead) and aligns features, targets (prices), and the *unscaled* starting close prices needed for return calculation.
|
||||
* *Scaling*: Within `TradingSystem.train_gru`, a `StandardScaler` is fitted *only* on the training features. A `MinMaxScaler` is fitted *only* on the training future *price* targets. Train and validation features/targets are scaled using these fitted scalers.
|
||||
* *Sequence Creation*: `src.data_pipeline.create_sequences_v2` creates input sequences `(batch, sequence_length, num_features)` and corresponding scaled target prices using the scaled features/targets and the unscaled start prices.
|
||||
* *Model Training*: `CryptoGRUModel.train` builds the V6-style GRU model (if not already built) and trains it using Mean Squared Error (MSE) loss on the scaled sequences. Callbacks monitor `val_rmse` for early stopping and model checkpointing. The best model (`best_model_reg.keras`) and the fitted scalers (`feature_scaler.joblib`, `y_scaler.joblib`) are saved.
|
||||
* If `LOAD_EXISTING_SYSTEM` is `True` and `TRAIN_GRU_MODEL` is `False`: Attempts to load a pre-trained GRU model and scalers. If `GRU_MODEL_LOAD_RUN_ID` is set in `main.py`, it loads from that specific run ID's directory (`v7/models/run_<run_id>`); otherwise, it attempts to load from the default `MODEL_SAVE_PATH` (expecting a `gru_model` subdirectory).
|
||||
4. **SAC Training (on Validation Set):**
|
||||
* **Training Loop:** The training process runs for a fixed number of agent update steps (`TOTAL_TRAINING_STEPS`) instead of epochs.
|
||||
* **Experience Generation** (`TradingSystem.generate_trading_experiences`):
|
||||
* **Efficiency:** Pre-computes all required GRU outputs (predicted returns, uncertainties) for the entire validation set by calling `CryptoGRUModel.evaluate` *once*.
|
||||
* **Initial Fill:** Generates an initial set of experiences (`experience_config['initial_experiences']`). Uses the sampling strategy.
|
||||
* **Sampling (`_sample_experience_indices`):** When generating a specific number of experiences (initial fill or periodic updates), it applies **weighted sampling** (controlled by `recency_bias_strength`) and **stratified sampling** (ensuring minimum ratios `min_uncertainty_ratio`, `min_extreme_return_ratio` of high uncertainty/extreme return examples based on quantiles `high_uncertainty_quantile`, `extreme_return_quantile`) based on parameters in `experience_config`.
|
||||
* **Experience Format:** Iterates through the (potentially sampled) pre-computed results. Forms the state `s_t = [predicted_return_t, uncertainty_t]`. The SAC agent (`SimplifiedSACTradingAgent.get_action`) provides a *non-deterministic* action `a_t`. The next state `s_{t+1}` is retrieved. A reward `r_t = action * actual_return` is calculated (transaction costs are currently ignored in reward calculation during generation for simplicity). The transition `(s_t, a_t, r_t, s_{t+1}, done=False)` is created.
|
||||
* **Periodic Updates:** During the main training loop (controlled by `total_training_steps`), new batches of experiences (`experience_config['experiences_per_batch']`) are generated periodically (every `experience_config['batch_generation_interval']` loop steps) using the sampling strategy and added to the replay buffer.
|
||||
* **Agent Training** (`SimplifiedSACTradingAgent.train`): In each step of the main training loop, the agent performs `experience_config['training_iterations_per_step']` update(s). Batches are sampled from the replay buffer. Actor and Critic networks are updated using the SAC algorithm. The agent uses a standard FIFO circular buffer for experience storage.
|
||||
5. **Backtesting (on Test Set):**
|
||||
* *Pre-computation* (`ExtendedBacktester.backtest`): Similar to SAC training, preprocesses the test data, scales it, creates sequences, and calls `CryptoGRUModel.evaluate` *once* to get all predicted returns and uncertainties for the test set.
|
||||
* *Iteration*: Steps chronologically through the pre-computed results.
|
||||
* *State Generation*: Retrieves `predicted_return` and `uncertainty_sigma` from the pre-computed arrays to form the state `s_t`.
|
||||
* *Action Selection*: The trained `SimplifiedSACTradingAgent` selects a *deterministic* action `a_t`.
|
||||
* *Portfolio Simulation*: Calculates PnL based on the previous position held (`current_position`), the actual return over the step, and subtracts transaction costs based on the change in position (`abs(action - current_position)`).
|
||||
* *Logging*: Records detailed metrics, trade history, and timestamps.
|
||||
6. **Evaluation:**
|
||||
* *Performance Metrics*: `ExtendedBacktester._calculate_performance_metrics` computes overall portfolio metrics (Sharpe, Sortino, Drawdown, correlations, etc.) and Buy & Hold benchmark metrics.
|
||||
* *Visualization*: `ExtendedBacktester.plot_results` generates a 3-panel plot: GRU Predictions vs Actual Price (with uncertainty), SAC Actions (Position Size), and Portfolio Value vs Buy & Hold (with trade markers).
|
||||
* *Reporting*: `ExtendedBacktester.generate_performance_report` creates a detailed Markdown report.
|
||||
|
||||
### 2. Core Components & Inputs/Outputs
|
||||
|
||||
* **`src.crypto_db_fetcher.CryptoDBFetcher`**: Loads and resamples data from SQLite DBs.
|
||||
* **`src.data_pipeline`**: Functions for DB loading, data splitting, sequence creation.
|
||||
* **`src.trading_system.calculate_v6_features`**: Calculates features (TA-Lib based V6 set + past returns).
|
||||
* **`src.trading_system._preprocess_data_for_gru_training`**: Prepares features, future price targets, and start prices.
|
||||
* **`src.gru_predictor.CryptoGRUModel`**: (V6 Adaptation)
|
||||
* `train()`: Trains the GRU price prediction model. Saves model (`.keras`) and scalers (`.joblib`).
|
||||
* `evaluate()`: Performs standard prediction and MC dropout inference. Returns dict including `pred_percent_change`, `mc_unscaled_std_dev`, `predicted_unscaled_prices`, `true_unscaled_prices`.
|
||||
* **`src.sac_agent_simplified.SimplifiedSACTradingAgent`**: (V7 Simplified)
|
||||
* **Goal:** Learns a policy mapping state to optimal position size (-1.0 to +1.0). Optimized for faster training.
|
||||
* **State Input:** 2-element array `[predicted_return, mc_unscaled_std_dev]`.
|
||||
* **Action Output:** Float between -1.0 and +1.0.
|
||||
* `get_action()`: Selects action (stochastic or deterministic). Adds uncertainty-scaled noise during exploration.
|
||||
* `store_transition()`: Adds experience to internal NumPy buffer.
|
||||
* `train()`: Updates agent using buffer samples (internally handles batch size). Uses `@tf.function` for performance.
|
||||
* `save()` / `load()`: Handles Actor/Critic weights (`.weights.h5`), potentially `alpha.npy`.
|
||||
* **Note:** Models and optimizers are built explicitly during `__init__` to prevent TensorFlow graph mode issues.
|
||||
* **`src.trading_system.TradingSystem`**: Integrates GRU and SAC. Manages training pipelines, experience generation (including advanced sampling).
|
||||
* **`src.trading_system.ExtendedBacktester`**: Performs efficient backtesting using pre-computed GRU outputs, calculates metrics, plots results, generates reports.
|
||||
|
||||
### 3. Model Architectures
|
||||
|
||||
* **GRU (`src.gru_predictor.CryptoGRUModel._build_model`)**: V6 Architecture.
|
||||
* Input -> GRU(100) -> Dropout(0.2) -> Dense(1, linear).
|
||||
* Compiled with Adam (LR=0.001), MSE loss.
|
||||
* **Simplified SAC (`src.sac_agent_simplified.SimplifiedSACTradingAgent`)**:
|
||||
* **Actor Network**: MLP `(state_dim=2)` -> Dense(64, relu) -> [BN] -> Dense(64, relu) -> [BN] -> [Residual] -> Dense(1, tanh).
|
||||
* **Critic Network (x2)**: MLP `(state_dim=2 + action_dim=1)` -> Dense(64, relu) -> [BN] -> Dense(64, relu) -> [BN] -> [Residual] -> Dense(1, linear).
|
||||
* **Algorithm**: Implements SAC with Clipped Double-Q, fixed alpha (tunable via `SAC_ALPHA`), faster learning rates, smaller networks/buffer, optional Batch Normalization / Residual connections. Uses Huber loss for critics. No distributional critics. `@tf.function` used for update steps.
|
||||
|
||||
### 4. Features
|
||||
|
||||
The GRU model uses the V6 feature set plus basic past returns:
|
||||
* **TA-Lib Indicators & Derived Indicators:** SMA, EMA, MACD, SAR, ADX, RSI, Stochastics, WILLR, ROC, CCI, BBands, ATR, OBV, CMF, etc. (Requires TA-Lib installation). Fallback calculations for SMA, EMA, RSI if TA-Lib is unavailable.
|
||||
* **Custom Crypto Features:** Parkinson Volatility, Garman-Klass Volatility, VWAP ratios, Volume Intensity, Wick Ratios.
|
||||
* **Past Returns:** `return_1m`, `return_5m`, `return_15m`, `return_60m` (percentage change).
|
||||
* **Scaling:** Features scaled with `StandardScaler` (fitted on train). Target variable (future price) scaled with `MinMaxScaler` (fitted on train).
|
||||
|
||||
### 5. Evaluation
|
||||
|
||||
* **GRU Model:** Evaluated using RMSE loss on validation set. Callbacks monitor `val_rmse`. Plots compare predicted vs actual price.
|
||||
* **SAC Agent & Overall System:** Evaluated via the `ExtendedBacktester` metrics (Sharpe, Sortino, Max Drawdown, correlations, etc.), plots (Portfolio vs B&H, Actions), and a final Markdown report.
|
||||
|
||||
## File Structure
|
||||
|
||||
- `data/`: *Not used by default if loading from DB.*
|
||||
- `downloaded_data/`: **Place your V6 SQLite database files here.** (Or update `DB_DIR` in `main.py`).
|
||||
- `models/`: Trained models (GRU + scalers, SAC weights) saved here under `run_<run_id>/` directories by default.
|
||||
- `results/`: Backtest results (plots, reports, config) saved here under `<run_id>/` directories.
|
||||
- `logs/`: Log files saved here under `<run_id>/` directories.
|
||||
- `src/`: Core Python modules.
|
||||
- `crypto_db_fetcher.py`: Class for fetching data from SQLite DBs.
|
||||
- `data_pipeline.py`: DB loading function, data splitting, sequence creation.
|
||||
- `gru_predictor.py`: V6-style GRU model for price regression and MC uncertainty.
|
||||
- `sac_agent_simplified.py`: Simplified SAC agent implementation (V7.5+).
|
||||
- `sac_agent.py`: Original SAC agent implementation (potentially outdated).
|
||||
- `trading_system.py`: Integration class, feature calculation, scaling, experience generation, `ExtendedBacktester` class.
|
||||
- `main.py`: Main script using DB loading, orchestrates training and backtesting.
|
||||
- `requirements.txt`: Dependencies.
|
||||
- `v7_instructions.txt`: Design notes for Simplified SAC.
|
||||
- `experience_instructions.txt`: Design notes for experience generation.
|
||||
- `README.md`: This file.
|
||||
|
||||
## Setup
|
||||
|
||||
1. **Data:** Place your V6 `downloaded_data` directory containing the SQLite files relative to the `v7` project root, or update the `DB_DIR` variable in `main.py` to point to the correct location.
|
||||
2. **Dependencies:** Install required packages:
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
*Strongly Recommended:* Install TA-Lib for the full feature set. See TA-Lib installation guides for your OS.
|
||||
3. **Configuration:** Review and adjust parameters in `main.py`. Key parameters include:
|
||||
* `DB_DIR`, `TICKER`, `EXCHANGE`, `START_DATE`, `END_DATE`, `INTERVAL`
|
||||
* Model hyperparameters (GRU and SAC sections)
|
||||
* Control Flags: `LOAD_EXISTING_SYSTEM`, `TRAIN_GRU_MODEL`, `TRAIN_SAC_AGENT`, `LOAD_SAC_AGENT`
|
||||
* Loading Specific Models: `GRU_MODEL_LOAD_RUN_ID` (set to a specific run ID string like `'YYYYMMDD_HHMMSS'` to load that GRU model from `v7/models/run_<run_id>/`). Note: This expects GRU and SAC files to be in the *same* directory if loading this way.
|
||||
* SAC Training: `TOTAL_TRAINING_STEPS` defines the length of SAC training (number of agent `train()` calls).
|
||||
* Experience Generation: `experience_config` dictionary controls initial fill, periodic updates, and sampling strategies (recency bias, stratification for uncertainty/extreme returns).
|
||||
* Backtesting: `INITIAL_CAPITAL`, `TRANSACTION_COST`.
|
||||
4. **Run:** Execute from the project root directory (containing the `v7` folder):
|
||||
```bash
|
||||
python -m v7.main
|
||||
```
|
||||
Output files (logs, models, plots, report) will be generated in `v7/logs/`, `v7/models/`, and `v7/results/` within run-specific subdirectories.
|
||||
|
||||
## Reporting
|
||||
|
||||
The report generated by the `ExtendedBacktester` includes performance metrics, correlation analyses, and configuration details. Key metrics include:
|
||||
|
||||
* Total/Annualized Return
|
||||
* Sharpe & Sortino Ratios
|
||||
* Volatility & Max Drawdown
|
||||
* Buy & Hold Comparison
|
||||
* Position/Prediction Accuracy
|
||||
* Prediction/Position/Uncertainty Correlations
|
||||
* Total Trades
|
||||
BIN
gru_sac_predictor/__pycache__/main.cpython-312.pyc
Normal file
BIN
gru_sac_predictor/__pycache__/main.cpython-312.pyc
Normal file
Binary file not shown.
0
gru_sac_predictor/logs/main_v7.log
Normal file
0
gru_sac_predictor/logs/main_v7.log
Normal file
465
gru_sac_predictor/main.py
Normal file
465
gru_sac_predictor/main.py
Normal file
@ -0,0 +1,465 @@
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import os
|
||||
from datetime import datetime
|
||||
import warnings
|
||||
import logging
|
||||
import sys
|
||||
import json
|
||||
|
||||
# --- Generate Run ID ---
|
||||
run_id = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
# Import components
|
||||
# V7 Update: Import load_data_from_db
|
||||
from .src.data_pipeline import create_data_pipeline, load_data_from_db
|
||||
from .src.trading_system import TradingSystem, ExtendedBacktester, plot_sac_training_history
|
||||
# V7.3 Fix: Add missing imports
|
||||
# V7-V6 Final Update: Import CryptoGRUModel
|
||||
from .src.gru_predictor import CryptoGRUModel
|
||||
# V7.5 Import the simplified agent
|
||||
from .src.sac_agent_simplified import SimplifiedSACTradingAgent
|
||||
# GRU and SAC classes are implicitly imported via TradingSystem
|
||||
|
||||
# --- Base Output Directories ---
|
||||
BASE_RESULTS_DIR = "gru_sac_predictor/results"
|
||||
BASE_LOGS_DIR = "gru_sac_predictor/logs"
|
||||
BASE_MODELS_DIR = "gru_sac_predictor/models"
|
||||
|
||||
# --- Run Specific Directories ---
|
||||
RUN_RESULTS_DIR = os.path.join(BASE_RESULTS_DIR, run_id)
|
||||
RUN_LOGS_DIR = os.path.join(BASE_LOGS_DIR, run_id)
|
||||
RUN_MODELS_DIR = os.path.join(BASE_MODELS_DIR, f"run_{run_id}")
|
||||
|
||||
# --- Logging Setup ---
|
||||
log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
# Ensure logs directory exists
|
||||
os.makedirs(RUN_LOGS_DIR, exist_ok=True)
|
||||
log_file_path = os.path.join(RUN_LOGS_DIR, f"main_{run_id}.log") # Removed _v7
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format=log_format,
|
||||
handlers=[
|
||||
logging.FileHandler(log_file_path, mode='a'), # Use path variable
|
||||
logging.StreamHandler(sys.stdout)
|
||||
]
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# --- Configuration ---
|
||||
# V7 Update: Add DB parameters
|
||||
DB_DIR = '../downloaded_data' # V7 Fix: Point to correct relative path for V6 data
|
||||
TICKER = 'BTC-USD' # Example ticker
|
||||
EXCHANGE = 'COINBASE' # Example exchange
|
||||
START_DATE = '2025-03-01' # Example start date - NOTE: VERY SHORT!
|
||||
END_DATE = '2025-03-10' # Example end date - NOTE: VERY SHORT!
|
||||
INTERVAL = '1min' # Data interval to fetch and use
|
||||
|
||||
MODEL_SAVE_PATH = RUN_MODELS_DIR # Use run-specific directory
|
||||
# Updated paths to use RUN_RESULTS_DIR and include run_id
|
||||
RESULTS_PLOT_PATH = os.path.join(RUN_RESULTS_DIR, f'backtest_results_{run_id}.png') # Removed _v7
|
||||
REPORT_SAVE_PATH = os.path.join(RUN_RESULTS_DIR, f'backtest_performance_report_{run_id}.md') # Removed _v7
|
||||
# GRU_PLOT_PATH = 'gru_performance_v7.png' # Not used directly in main
|
||||
|
||||
# V7.6 Add specific run ID for loading GRU model
|
||||
GRU_MODEL_LOAD_RUN_ID = '20250416_142744' # Set this to a specific 'YYYYMMDD_HHMMSS' string to load that GRU model
|
||||
|
||||
# Data split ratios
|
||||
TRAIN_RATIO = 0.6
|
||||
VALIDATION_RATIO = 0.2
|
||||
|
||||
# Model/Training Parameters (V7.3)
|
||||
GRU_LOOKBACK = 60
|
||||
GRU_PREDICTION_HORIZON = 1
|
||||
GRU_EPOCHS = 20
|
||||
GRU_BATCH_SIZE = 32 # Updated default
|
||||
GRU_PATIENCE = 10 # Updated default
|
||||
GRU_LR_PATIENCE = 10 # Updated default
|
||||
GRU_LR_FACTOR = 0.5 # Updated default
|
||||
GRU_RETURN_SCALE = 0.03 # Updated default
|
||||
|
||||
# SAC Parameters (V7.5 - Simplified Agent)
|
||||
SAC_STATE_DIM = 5 # [pred_return, uncertainty, z, momentum_5, volatility_20] - Updated from 2
|
||||
SAC_HIDDEN_SIZE = 64
|
||||
SAC_GAMMA = 0.97
|
||||
SAC_TAU = 0.02
|
||||
# SAC_ALPHA = 0.1 # Removed - Will use automatic tuning
|
||||
SAC_ACTOR_LR = 3e-4 # Lowered from 5e-4
|
||||
SAC_CRITIC_LR = 5e-4 # Lowered from 8e-4
|
||||
SAC_BATCH_SIZE = 64
|
||||
SAC_BUFFER_MAX_SIZE = 20000
|
||||
SAC_MIN_BUFFER_SIZE = 1000
|
||||
SAC_UPDATE_INTERVAL = 1
|
||||
SAC_TARGET_UPDATE_INTERVAL = 2
|
||||
SAC_GRADIENT_CLIP = 1.0
|
||||
SAC_REWARD_SCALE = 2.0 # Decreased from 10.0
|
||||
SAC_USE_BATCH_NORM = True
|
||||
SAC_USE_RESIDUAL = True
|
||||
SAC_MODEL_DIR = 'models/simplified_sac' # Default dir within the agent class
|
||||
SAC_EPOCHS = 50 # Keep this from previous config for training loop control
|
||||
|
||||
# V7.9 Experience Generation Config (Based on instructions.txt)
|
||||
# TOTAL_TRAINING_STEPS = 1000 # Removed - Not used in current training loop
|
||||
experience_config = {
|
||||
# Basic setup
|
||||
'initial_experiences': 3000, # Start with this many experiences
|
||||
'experiences_per_batch': 64, # Generate this many in each new batch
|
||||
'batch_generation_interval': 500, # Generate a new batch every N training steps
|
||||
|
||||
# Distribution control (Flags for future implementation in generate_trading_experiences)
|
||||
'balance_market_regimes': False, # Not implemented
|
||||
'recency_bias_strength': 0.5, # 0 = uniform, >0 weights recent data more
|
||||
'high_uncertainty_quantile': 0.75, # Threshold for high uncertainty
|
||||
'extreme_return_quantile': 0.1, # Threshold for extreme returns (upper/lower)
|
||||
'min_uncertainty_ratio': 0.2, # Min % of samples with high uncertainty
|
||||
'min_extreme_return_ratio': 0.1, # Min % of samples with extreme returns
|
||||
|
||||
# Efficient processing
|
||||
'use_parallel_generation': False, # Not implemented
|
||||
'precompute_all_gru_outputs': True, # Already implemented
|
||||
'buffer_update_strategy': 'fifo', # Agent currently uses FIFO
|
||||
|
||||
# Training optimization
|
||||
'training_iterations_per_step': 1, # Number of agent.train calls per main loop step
|
||||
# Max/Min buffer size are defined by the agent itself now
|
||||
}
|
||||
|
||||
# Backtesting Parameters
|
||||
INITIAL_CAPITAL = 10000.0
|
||||
TRANSACTION_COST = 0.0005
|
||||
# V7.12 Add Opportunity Cost Penalty Parameters
|
||||
OPPORTUNITY_COST_PENALTY_FACTOR = 0.0 # How much to penalize missed high returns - Disabled (was 1.0)
|
||||
HIGH_RETURN_THRESHOLD = 0.002 # Actual return magnitude threshold to trigger penalty check
|
||||
ACTION_TOLERANCE = 0.3 # Action magnitude below which penalty applies if return threshold met - Lowered from 0.5
|
||||
# RISK_PENALTY_FACTOR = 0.0 # Removed as state reverted
|
||||
|
||||
# Control Flags
|
||||
LOAD_EXISTING_SYSTEM = True
|
||||
TRAIN_GRU_MODEL = False
|
||||
TRAIN_SAC_AGENT = True # V7.8 Set to True to train SAC
|
||||
LOAD_SAC_AGENT = False # V7.8 Set to False to avoid loading SAC
|
||||
RUN_BACKTEST = True
|
||||
GENERATE_PLOTS = True
|
||||
GENERATE_REPORT = True
|
||||
# --- End Configuration ---
|
||||
|
||||
def main():
|
||||
# Access config variables defined at module level
|
||||
global LOAD_EXISTING_SYSTEM, TRAIN_GRU_MODEL, TRAIN_SAC_AGENT, LOAD_SAC_AGENT
|
||||
|
||||
logger.info(f"--- Starting GRU+SAC Trading System Pipeline (Run ID: {run_id}) ---") # Removed V7
|
||||
|
||||
# Ensure results directory exists
|
||||
os.makedirs(RUN_RESULTS_DIR, exist_ok=True)
|
||||
# Ensure base models directory exists (RUN_MODELS_DIR created later if training)
|
||||
os.makedirs(BASE_MODELS_DIR, exist_ok=True)
|
||||
|
||||
# LOAD_EXISTING_SYSTEM is now declared global before use here
|
||||
# --- Save Configuration ---
|
||||
config_to_save = {
|
||||
"run_id": run_id,
|
||||
"db_dir": DB_DIR,
|
||||
"ticker": TICKER,
|
||||
"exchange": EXCHANGE,
|
||||
"start_date": START_DATE,
|
||||
"end_date": END_DATE,
|
||||
"interval": INTERVAL,
|
||||
"model_save_path": MODEL_SAVE_PATH,
|
||||
"results_plot_path": RESULTS_PLOT_PATH,
|
||||
"report_save_path": REPORT_SAVE_PATH,
|
||||
"train_ratio": TRAIN_RATIO,
|
||||
"validation_ratio": VALIDATION_RATIO,
|
||||
"gru_lookback": GRU_LOOKBACK,
|
||||
"gru_prediction_horizon": GRU_PREDICTION_HORIZON,
|
||||
"gru_epochs": GRU_EPOCHS,
|
||||
"gru_batch_size": GRU_BATCH_SIZE,
|
||||
"gru_patience": GRU_PATIENCE,
|
||||
"gru_lr_factor": GRU_LR_FACTOR,
|
||||
"gru_return_scale": GRU_RETURN_SCALE,
|
||||
"gru_model_load_run_id": GRU_MODEL_LOAD_RUN_ID,
|
||||
"sac_state_dim": SAC_STATE_DIM,
|
||||
"sac_hidden_size": SAC_HIDDEN_SIZE,
|
||||
"sac_gamma": SAC_GAMMA,
|
||||
"sac_tau": SAC_TAU,
|
||||
"sac_actor_lr": SAC_ACTOR_LR,
|
||||
"sac_critic_lr": SAC_CRITIC_LR,
|
||||
"sac_batch_size": SAC_BATCH_SIZE,
|
||||
"sac_buffer_max_size": SAC_BUFFER_MAX_SIZE,
|
||||
"sac_min_buffer_size": SAC_MIN_BUFFER_SIZE,
|
||||
"sac_update_interval": SAC_UPDATE_INTERVAL,
|
||||
"sac_target_update_interval": SAC_TARGET_UPDATE_INTERVAL,
|
||||
"sac_gradient_clip": SAC_GRADIENT_CLIP,
|
||||
"sac_reward_scale": SAC_REWARD_SCALE,
|
||||
"sac_use_batch_norm": SAC_USE_BATCH_NORM,
|
||||
"sac_use_residual": SAC_USE_RESIDUAL,
|
||||
"sac_model_dir": SAC_MODEL_DIR,
|
||||
"sac_epochs": SAC_EPOCHS,
|
||||
"experience_config": experience_config,
|
||||
"initial_capital": INITIAL_CAPITAL,
|
||||
"transaction_cost": TRANSACTION_COST,
|
||||
# V7.12 Add new params to saved config
|
||||
"opportunity_cost_penalty_factor": OPPORTUNITY_COST_PENALTY_FACTOR,
|
||||
"high_return_threshold": HIGH_RETURN_THRESHOLD,
|
||||
"action_tolerance": ACTION_TOLERANCE,
|
||||
"load_existing_system": LOAD_EXISTING_SYSTEM,
|
||||
"train_gru_model": TRAIN_GRU_MODEL,
|
||||
"train_sac_agent": TRAIN_SAC_AGENT,
|
||||
"load_sac_agent": LOAD_SAC_AGENT,
|
||||
"run_backtest": RUN_BACKTEST,
|
||||
"generate_plots": GENERATE_PLOTS,
|
||||
"generate_report": GENERATE_REPORT
|
||||
}
|
||||
config_save_path = os.path.join(RUN_RESULTS_DIR, f'config_{run_id}.json')
|
||||
try:
|
||||
with open(config_save_path, 'w') as f:
|
||||
json.dump(config_to_save, f, indent=4)
|
||||
logger.info(f"Run configuration saved to {config_save_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save run configuration: {e}")
|
||||
# --- End Save Configuration ---
|
||||
|
||||
# 1. Load Data from Database
|
||||
logger.info(f"Loading data from DB: {TICKER}/{EXCHANGE} ({START_DATE}-{END_DATE}) @ {INTERVAL}")
|
||||
data = load_data_from_db(
|
||||
db_dir=DB_DIR,
|
||||
ticker=TICKER,
|
||||
exchange=EXCHANGE,
|
||||
start_date=START_DATE,
|
||||
end_date=END_DATE,
|
||||
interval=INTERVAL
|
||||
)
|
||||
|
||||
if data.empty:
|
||||
logger.error("Failed to load data from database. Please check DB_DIR and parameters. Aborting.")
|
||||
return
|
||||
|
||||
# --- Re-inserted Steps Start ---
|
||||
# Basic Data Validation (Timestamp index assumed from load_data_from_db)
|
||||
if 'close' not in data.columns: # Check essential columns
|
||||
raise ValueError("Loaded data must contain 'close' column.")
|
||||
logger.info(f"Data loaded: {len(data)} rows, from {data.index.min()} to {data.index.max()}")
|
||||
initial_len = len(data); data.dropna(subset=['open', 'high', 'low', 'close', 'volume'], inplace=True)
|
||||
if len(data) < initial_len: logger.info(f"Dropped {initial_len - len(data)} NaN rows.")
|
||||
if len(data) < GRU_LOOKBACK * 3: raise ValueError(f"Insufficient data ({len(data)} rows) for lookback/splits.")
|
||||
|
||||
# Add cyclical features immediately
|
||||
logger.info("Calculating cyclical time features (hour_sin, hour_cos)...")
|
||||
timestamp_source = None
|
||||
if isinstance(data.index, pd.DatetimeIndex):
|
||||
timestamp_source = data.index
|
||||
logger.debug("Using index for hour features.")
|
||||
elif 'timestamp' in data.columns and pd.api.types.is_datetime64_any_dtype(data['timestamp']):
|
||||
timestamp_source = pd.to_datetime(data['timestamp'])
|
||||
logger.debug("Using 'timestamp' column for hour features.")
|
||||
elif 'date' in data.columns and pd.api.types.is_datetime64_any_dtype(data['date']):
|
||||
timestamp_source = pd.to_datetime(data['date'])
|
||||
logger.debug("Using 'date' column for hour features.")
|
||||
|
||||
if timestamp_source is not None:
|
||||
data['hour_sin'] = np.sin(2 * np.pi * timestamp_source.hour / 24)
|
||||
data['hour_cos'] = np.cos(2 * np.pi * timestamp_source.hour / 24)
|
||||
logger.info("Added hour_sin/hour_cos to main dataframe.")
|
||||
else:
|
||||
logger.warning("Could not find suitable timestamp source. Setting hour_sin/cos defaults (0.0, 1.0).")
|
||||
data['hour_sin'] = 0.0
|
||||
data['hour_cos'] = 1.0 # Default to cos(0) = 1
|
||||
|
||||
# 2. Split Data Chronologically
|
||||
logger.info("Splitting data...")
|
||||
test_ratio = round(1.0 - TRAIN_RATIO - VALIDATION_RATIO, 2)
|
||||
if test_ratio <= 0: raise ValueError("Train+Validation ratios must sum to < 1.")
|
||||
train_data, val_data, test_data = create_data_pipeline(data, [TRAIN_RATIO, VALIDATION_RATIO, test_ratio])
|
||||
if len(train_data) < GRU_LOOKBACK or len(val_data) < GRU_LOOKBACK or len(test_data) < GRU_LOOKBACK:
|
||||
warnings.warn(f"Splits smaller than GRU lookback ({GRU_LOOKBACK}). Backtesting might fail.")
|
||||
|
||||
# 3. Initialize Trading System
|
||||
logger.info("Initializing Trading System...")
|
||||
trading_system = TradingSystem(
|
||||
gru_model=CryptoGRUModel(), # Instantiate the correct model
|
||||
sac_agent=SimplifiedSACTradingAgent(
|
||||
state_dim=SAC_STATE_DIM,
|
||||
hidden_size=SAC_HIDDEN_SIZE,
|
||||
gamma=SAC_GAMMA,
|
||||
tau=SAC_TAU,
|
||||
actor_lr=SAC_ACTOR_LR,
|
||||
critic_lr=SAC_CRITIC_LR,
|
||||
batch_size=SAC_BATCH_SIZE,
|
||||
buffer_max_size=SAC_BUFFER_MAX_SIZE,
|
||||
min_buffer_size=SAC_MIN_BUFFER_SIZE,
|
||||
update_interval=SAC_UPDATE_INTERVAL,
|
||||
target_update_interval=SAC_TARGET_UPDATE_INTERVAL,
|
||||
gradient_clip=SAC_GRADIENT_CLIP,
|
||||
reward_scale=SAC_REWARD_SCALE,
|
||||
use_batch_norm=SAC_USE_BATCH_NORM,
|
||||
use_residual=SAC_USE_RESIDUAL,
|
||||
model_dir=os.path.join(MODEL_SAVE_PATH, 'sac_agent') # Point to subfolder within run
|
||||
), # Pass the configured agent
|
||||
gru_lookback=GRU_LOOKBACK
|
||||
)
|
||||
|
||||
# --- Model Loading/Training ---
|
||||
gru_loaded = False; sac_loaded = False
|
||||
if LOAD_EXISTING_SYSTEM:
|
||||
load_base_path = MODEL_SAVE_PATH
|
||||
logger.info(f"Attempting to load existing system components...")
|
||||
logger.info(f"Base path for loading: {load_base_path}")
|
||||
|
||||
gru_model_load_dir = None
|
||||
sac_model_load_dir = None
|
||||
if GRU_MODEL_LOAD_RUN_ID:
|
||||
gru_model_load_dir = os.path.join(BASE_MODELS_DIR, f'run_{GRU_MODEL_LOAD_RUN_ID}')
|
||||
logger.info(f"Using specific GRU load path based on run ID: {gru_model_load_dir}")
|
||||
if LOAD_SAC_AGENT:
|
||||
sac_model_load_dir = os.path.join(BASE_MODELS_DIR, f'run_{GRU_MODEL_LOAD_RUN_ID}')
|
||||
logger.info(f"Using specific SAC load path based on GRU run ID (LOAD_SAC_AGENT=True): {sac_model_load_dir}")
|
||||
else:
|
||||
sac_model_load_dir = os.path.join(MODEL_SAVE_PATH, 'sac_agent')
|
||||
logger.info(f"Defaulting SAC path to current run (LOAD_SAC_AGENT=False): {sac_model_load_dir}")
|
||||
elif os.path.exists(load_base_path):
|
||||
gru_model_load_dir = os.path.join(load_base_path, 'gru_model')
|
||||
sac_model_load_dir = os.path.join(load_base_path, 'sac_agent')
|
||||
logger.info(f"Using GRU load path based on MODEL_SAVE_PATH: {gru_model_load_dir}")
|
||||
logger.info(f"Using SAC load path based on MODEL_SAVE_PATH: {sac_model_load_dir}")
|
||||
else:
|
||||
logger.warning(f"LOAD_EXISTING_SYSTEM is True, but MODEL_SAVE_PATH does not exist: {load_base_path}. Cannot determine model paths.")
|
||||
LOAD_EXISTING_SYSTEM = False
|
||||
|
||||
if LOAD_EXISTING_SYSTEM:
|
||||
try:
|
||||
if gru_model_load_dir and os.path.isdir(gru_model_load_dir):
|
||||
logger.info(f"Found GRU model directory: {gru_model_load_dir}. Loading...")
|
||||
if trading_system.gru_model is None: trading_system.gru_model = CryptoGRUModel()
|
||||
if trading_system.gru_model.load(gru_model_load_dir):
|
||||
logger.info("GRU model loaded successfully.")
|
||||
gru_loaded = True
|
||||
trading_system.feature_scaler = trading_system.gru_model.feature_scaler
|
||||
trading_system.y_scaler = trading_system.gru_model.y_scaler
|
||||
logger.info("Scalers propagated from loaded GRU model.")
|
||||
else: logger.warning(f"GRU model directory found, but loading failed.")
|
||||
elif gru_model_load_dir: logger.warning(f"GRU model directory specified or derived, but not found at {gru_model_load_dir}. GRU model cannot be loaded.")
|
||||
else: logger.warning("GRU model path could not be determined. GRU model cannot be loaded.")
|
||||
|
||||
if LOAD_SAC_AGENT:
|
||||
if sac_model_load_dir and os.path.isdir(sac_model_load_dir):
|
||||
logger.info(f"Found SAC model directory: {sac_model_load_dir}. Loading (LOAD_SAC_AGENT=True)...")
|
||||
if trading_system.sac_agent is None:
|
||||
trading_system.sac_agent = SimplifiedSACTradingAgent(state_dim=SAC_STATE_DIM, model_dir=sac_model_load_dir)
|
||||
if trading_system.sac_agent.load(sac_model_load_dir):
|
||||
logger.info("SAC agent loaded successfully.")
|
||||
sac_loaded = True
|
||||
else: logger.warning(f"SAC model directory found, but loading failed.")
|
||||
elif sac_model_load_dir: logger.warning(f"SAC agent model directory derived, but not found at {sac_model_load_dir}. SAC agent cannot be loaded (LOAD_SAC_AGENT=True).")
|
||||
else: logger.info("Skipping SAC agent loading (LOAD_SAC_AGENT=False).")
|
||||
|
||||
if gru_loaded: TRAIN_GRU_MODEL = False
|
||||
if sac_loaded: TRAIN_SAC_AGENT = False; LOAD_SAC_AGENT = True
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not load existing system components: {e}. Proceeding based on training flags.")
|
||||
gru_loaded = False; sac_loaded = False
|
||||
TRAIN_GRU_MODEL = True; TRAIN_SAC_AGENT = True; LOAD_SAC_AGENT = False
|
||||
|
||||
elif LOAD_EXISTING_SYSTEM: pass
|
||||
else: logger.info("LOAD_EXISTING_SYSTEM=False. Proceeding with training flags.")
|
||||
|
||||
# --- Sanity Check After Loading ---
|
||||
if not gru_loaded and not TRAIN_GRU_MODEL:
|
||||
logger.error("Critical Error: GRU model was not loaded and TRAIN_GRU_MODEL is False. Cannot proceed.")
|
||||
return
|
||||
if not sac_loaded and not TRAIN_SAC_AGENT:
|
||||
if RUN_BACKTEST:
|
||||
logger.error("Critical Error: SAC agent was not loaded and TRAIN_SAC_AGENT is False. Aborting because RUN_BACKTEST is True.")
|
||||
return
|
||||
else: logger.warning("Proceeding without a functional SAC agent as RUN_BACKTEST is False.")
|
||||
|
||||
# Train GRU Model (if flag is set and not loaded)
|
||||
if TRAIN_GRU_MODEL:
|
||||
logger.info("--- Training GRU Model --- ")
|
||||
gru_save_dir = MODEL_SAVE_PATH
|
||||
history = trading_system.train_gru(
|
||||
train_data=train_data, val_data=val_data,
|
||||
prediction_horizon=GRU_PREDICTION_HORIZON,
|
||||
epochs=GRU_EPOCHS, batch_size=GRU_BATCH_SIZE,
|
||||
patience=GRU_PATIENCE,
|
||||
model_save_dir=gru_save_dir
|
||||
)
|
||||
if history is None: logger.error("GRU Training failed. Aborting."); return
|
||||
logger.info("--- GRU Model Training Finished --- ")
|
||||
elif not gru_loaded: logger.error("GRU Model must be trained or loaded."); return
|
||||
else: logger.info("Skipping GRU training (already loaded).")
|
||||
|
||||
# Train SAC Agent (if flag is set and not loaded)
|
||||
if TRAIN_SAC_AGENT:
|
||||
logger.info("--- Training SAC Agent --- ")
|
||||
if not trading_system.gru_model or not (trading_system.gru_model.is_trained or trading_system.gru_model.is_loaded):
|
||||
logger.error("Cannot train SAC: GRU model not ready."); return
|
||||
|
||||
if trading_system.sac_agent is None: logger.error("SAC Agent instance is missing in the trading system before training."); return
|
||||
trading_system.sac_agent.model_dir = os.path.join(MODEL_SAVE_PATH, 'sac_agent')
|
||||
logger.info(f"Ensured SAC agent model save dir is set to: {trading_system.sac_agent.model_dir}")
|
||||
|
||||
sac_history = trading_system.train_sac(
|
||||
val_data=val_data,
|
||||
epochs=SAC_EPOCHS,
|
||||
batch_size=SAC_BATCH_SIZE,
|
||||
transaction_cost=TRANSACTION_COST,
|
||||
prediction_horizon=GRU_PREDICTION_HORIZON
|
||||
)
|
||||
logger.info("Finished training SAC agent.")
|
||||
|
||||
if sac_history is not None:
|
||||
sac_save_dir = os.path.join(MODEL_SAVE_PATH, 'sac_agent')
|
||||
logger.info(f"Saving Simplified SAC agent to {sac_save_dir}")
|
||||
trading_system.sac_agent.save(sac_save_dir)
|
||||
|
||||
if sac_history:
|
||||
sac_plot_save_path = os.path.join(RUN_RESULTS_DIR, f'sac_training_history_{run_id}.png')
|
||||
logger.info(f"Plotting SAC training history to {sac_plot_save_path}...")
|
||||
try: plot_sac_training_history(sac_history, save_path=sac_plot_save_path)
|
||||
except Exception as plot_e: logger.error(f"Failed to plot SAC training history: {plot_e}", exc_info=True)
|
||||
else: logger.warning("SAC training finished, but no history data returned for plotting.")
|
||||
|
||||
elif not sac_loaded and LOAD_SAC_AGENT:
|
||||
# This block handles loading SAC if LOAD_EXISTING_SYSTEM was False but LOAD_SAC_AGENT was True (unlikely case)
|
||||
if trading_system.sac_agent is None: trading_system.sac_agent = SimplifiedSACTradingAgent(state_dim=SAC_STATE_DIM)
|
||||
sac_load_path = os.path.join(MODEL_SAVE_PATH, 'sac_agent') # Load from current run models
|
||||
if os.path.isdir(sac_load_path):
|
||||
logger.info(f"Attempting to load SAC weights from {sac_load_path} (LOAD_SAC_AGENT=True)...")
|
||||
try: trading_system.sac_agent.load(sac_load_path); logger.info("SAC weights loaded."); sac_loaded = True
|
||||
except Exception as e: logger.warning(f"Could not load SAC weights: {e}")
|
||||
else: logger.warning(f"LOAD_SAC_AGENT=True but no weights found at {sac_load_path}.")
|
||||
elif not sac_loaded: logger.warning("SAC Agent not trained or loaded.")
|
||||
else: logger.info("Skipping SAC training (already loaded).")
|
||||
|
||||
# 5. Backtest on Test Data
|
||||
if RUN_BACKTEST:
|
||||
logger.info("--- Running Extended Backtest --- ")
|
||||
if not trading_system.gru_model or not (trading_system.gru_model.is_trained or trading_system.gru_model.is_loaded):
|
||||
logger.error("Cannot backtest: GRU model not ready."); return
|
||||
if not trading_system.sac_agent: logger.error("Cannot backtest: SAC Agent not initialized."); return
|
||||
|
||||
instrument_label = f"{TICKER}/{EXCHANGE}"
|
||||
backtester = ExtendedBacktester(
|
||||
trading_system,
|
||||
initial_capital=INITIAL_CAPITAL,
|
||||
transaction_cost=TRANSACTION_COST,
|
||||
instrument_label=instrument_label
|
||||
)
|
||||
backtest_results = backtester.backtest(test_data, verbose=True)
|
||||
|
||||
# 6. Generate Plots and Report
|
||||
if GENERATE_PLOTS:
|
||||
logger.info(f"Generating overall performance plot: {RESULTS_PLOT_PATH}...")
|
||||
backtester.plot_results(save_path=RESULTS_PLOT_PATH)
|
||||
if GENERATE_REPORT:
|
||||
logger.info(f"Generating performance report: {REPORT_SAVE_PATH}...")
|
||||
backtester.generate_performance_report(report_path=REPORT_SAVE_PATH)
|
||||
else:
|
||||
logger.info("Skipping backtesting.")
|
||||
# --- Re-inserted Steps End ---
|
||||
|
||||
logger.info("--- GRU+SAC Pipeline Finished --- ")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
0
gru_sac_predictor/main_v7.log
Normal file
0
gru_sac_predictor/main_v7.log
Normal file
BIN
gru_sac_predictor/models/run_20250416_142744/actor.weights.h5
Normal file
BIN
gru_sac_predictor/models/run_20250416_142744/actor.weights.h5
Normal file
Binary file not shown.
Binary file not shown.
BIN
gru_sac_predictor/models/run_20250416_142744/critic1.weights.h5
Normal file
BIN
gru_sac_predictor/models/run_20250416_142744/critic1.weights.h5
Normal file
Binary file not shown.
BIN
gru_sac_predictor/models/run_20250416_142744/critic2.weights.h5
Normal file
BIN
gru_sac_predictor/models/run_20250416_142744/critic2.weights.h5
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
After Width: | Height: | Size: 76 KiB |
BIN
gru_sac_predictor/models/run_20250416_142744/log_alpha.npy
Normal file
BIN
gru_sac_predictor/models/run_20250416_142744/log_alpha.npy
Normal file
Binary file not shown.
BIN
gru_sac_predictor/models/run_20250416_142744/y_scaler.joblib
Normal file
BIN
gru_sac_predictor/models/run_20250416_142744/y_scaler.joblib
Normal file
Binary file not shown.
Binary file not shown.
BIN
gru_sac_predictor/models/run_20250416_144757/sac_agent/alpha.npy
Normal file
BIN
gru_sac_predictor/models/run_20250416_144757/sac_agent/alpha.npy
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
gru_sac_predictor/models/run_20250416_145128/sac_agent/alpha.npy
Normal file
BIN
gru_sac_predictor/models/run_20250416_145128/sac_agent/alpha.npy
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
gru_sac_predictor/models/run_20250416_150829/sac_agent/alpha.npy
Normal file
BIN
gru_sac_predictor/models/run_20250416_150829/sac_agent/alpha.npy
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
gru_sac_predictor/models/run_20250416_150924/sac_agent/alpha.npy
Normal file
BIN
gru_sac_predictor/models/run_20250416_150924/sac_agent/alpha.npy
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
gru_sac_predictor/models/run_20250416_151322/sac_agent/alpha.npy
Normal file
BIN
gru_sac_predictor/models/run_20250416_151322/sac_agent/alpha.npy
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
gru_sac_predictor/models/run_20250416_151849/sac_agent/alpha.npy
Normal file
BIN
gru_sac_predictor/models/run_20250416_151849/sac_agent/alpha.npy
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
gru_sac_predictor/models/run_20250416_152415/sac_agent/alpha.npy
Normal file
BIN
gru_sac_predictor/models/run_20250416_152415/sac_agent/alpha.npy
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
gru_sac_predictor/models/run_20250416_153132/sac_agent/alpha.npy
Normal file
BIN
gru_sac_predictor/models/run_20250416_153132/sac_agent/alpha.npy
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
gru_sac_predictor/models/run_20250416_153846/sac_agent/alpha.npy
Normal file
BIN
gru_sac_predictor/models/run_20250416_153846/sac_agent/alpha.npy
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
gru_sac_predictor/models/run_20250416_154636/sac_agent/alpha.npy
Normal file
BIN
gru_sac_predictor/models/run_20250416_154636/sac_agent/alpha.npy
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user