initial commit for predictor module

This commit is contained in:
Yasha Sheynin 2025-04-16 16:50:59 -04:00
commit f598141500
166 changed files with 6672 additions and 0 deletions

141
gru_sac_predictor/README.md Normal file
View File

@ -0,0 +1,141 @@
# v7 - GRU + Simplified SAC Trading Agent (V6 GRU Adaptation)
This project implements a cryptocurrency trading system using a GRU model for price prediction and a **Simplified SAC (Soft Actor-Critic)** agent for position sizing.
The system predicts future *price* using a GRU model adapted from the V6 architecture. It calculates the *predicted percentage return* from this price prediction and estimates prediction *uncertainty* based on the standard deviation of Monte Carlo dropout predictions. These two values (`predicted_return`, `mc_unscaled_std_dev`) form the state input to the SAC reinforcement learning agent, which determines optimal position sizing (-1 to +1).
The system incorporates efficiency improvements by pre-computing GRU predictions and uncertainties before generating SAC experiences or running the backtest. It includes detailed backtesting, performance reporting, and visualization capabilities.
## System Design
The system integrates a GRU predictor and a Simplified SAC agent within a backtesting framework.
### 1. Data Flow & Processing
1. **Loading:** Raw 1-minute OHLCV data is loaded from the SQLite database directory specified in `main.py` (e.g., `downloaded_data/`) using `src.data_pipeline.load_data_from_db` which utilizes `src.crypto_db_fetcher.CryptoDBFetcher`.
2. **Splitting:** Data is chronologically split into training (60%), validation (20%), and test (20%) sets using `src.data_pipeline.create_data_pipeline`.
3. **GRU Training / Loading (on Train/Validation Sets):**
* If `TRAIN_GRU_MODEL` is `True`:
* *Preprocessing*: `TradingSystem._preprocess_data_for_gru_training` calculates V6 features plus basic return features (`calculate_v6_features`) on the raw train/val data. It determines the future *price* target (`prediction_horizon` steps ahead) and aligns features, targets (prices), and the *unscaled* starting close prices needed for return calculation.
* *Scaling*: Within `TradingSystem.train_gru`, a `StandardScaler` is fitted *only* on the training features. A `MinMaxScaler` is fitted *only* on the training future *price* targets. Train and validation features/targets are scaled using these fitted scalers.
* *Sequence Creation*: `src.data_pipeline.create_sequences_v2` creates input sequences `(batch, sequence_length, num_features)` and corresponding scaled target prices using the scaled features/targets and the unscaled start prices.
* *Model Training*: `CryptoGRUModel.train` builds the V6-style GRU model (if not already built) and trains it using Mean Squared Error (MSE) loss on the scaled sequences. Callbacks monitor `val_rmse` for early stopping and model checkpointing. The best model (`best_model_reg.keras`) and the fitted scalers (`feature_scaler.joblib`, `y_scaler.joblib`) are saved.
* If `LOAD_EXISTING_SYSTEM` is `True` and `TRAIN_GRU_MODEL` is `False`: Attempts to load a pre-trained GRU model and scalers. If `GRU_MODEL_LOAD_RUN_ID` is set in `main.py`, it loads from that specific run ID's directory (`v7/models/run_<run_id>`); otherwise, it attempts to load from the default `MODEL_SAVE_PATH` (expecting a `gru_model` subdirectory).
4. **SAC Training (on Validation Set):**
* **Training Loop:** The training process runs for a fixed number of agent update steps (`TOTAL_TRAINING_STEPS`) instead of epochs.
* **Experience Generation** (`TradingSystem.generate_trading_experiences`):
* **Efficiency:** Pre-computes all required GRU outputs (predicted returns, uncertainties) for the entire validation set by calling `CryptoGRUModel.evaluate` *once*.
* **Initial Fill:** Generates an initial set of experiences (`experience_config['initial_experiences']`). Uses the sampling strategy.
* **Sampling (`_sample_experience_indices`):** When generating a specific number of experiences (initial fill or periodic updates), it applies **weighted sampling** (controlled by `recency_bias_strength`) and **stratified sampling** (ensuring minimum ratios `min_uncertainty_ratio`, `min_extreme_return_ratio` of high uncertainty/extreme return examples based on quantiles `high_uncertainty_quantile`, `extreme_return_quantile`) based on parameters in `experience_config`.
* **Experience Format:** Iterates through the (potentially sampled) pre-computed results. Forms the state `s_t = [predicted_return_t, uncertainty_t]`. The SAC agent (`SimplifiedSACTradingAgent.get_action`) provides a *non-deterministic* action `a_t`. The next state `s_{t+1}` is retrieved. A reward `r_t = action * actual_return` is calculated (transaction costs are currently ignored in reward calculation during generation for simplicity). The transition `(s_t, a_t, r_t, s_{t+1}, done=False)` is created.
* **Periodic Updates:** During the main training loop (controlled by `total_training_steps`), new batches of experiences (`experience_config['experiences_per_batch']`) are generated periodically (every `experience_config['batch_generation_interval']` loop steps) using the sampling strategy and added to the replay buffer.
* **Agent Training** (`SimplifiedSACTradingAgent.train`): In each step of the main training loop, the agent performs `experience_config['training_iterations_per_step']` update(s). Batches are sampled from the replay buffer. Actor and Critic networks are updated using the SAC algorithm. The agent uses a standard FIFO circular buffer for experience storage.
5. **Backtesting (on Test Set):**
* *Pre-computation* (`ExtendedBacktester.backtest`): Similar to SAC training, preprocesses the test data, scales it, creates sequences, and calls `CryptoGRUModel.evaluate` *once* to get all predicted returns and uncertainties for the test set.
* *Iteration*: Steps chronologically through the pre-computed results.
* *State Generation*: Retrieves `predicted_return` and `uncertainty_sigma` from the pre-computed arrays to form the state `s_t`.
* *Action Selection*: The trained `SimplifiedSACTradingAgent` selects a *deterministic* action `a_t`.
* *Portfolio Simulation*: Calculates PnL based on the previous position held (`current_position`), the actual return over the step, and subtracts transaction costs based on the change in position (`abs(action - current_position)`).
* *Logging*: Records detailed metrics, trade history, and timestamps.
6. **Evaluation:**
* *Performance Metrics*: `ExtendedBacktester._calculate_performance_metrics` computes overall portfolio metrics (Sharpe, Sortino, Drawdown, correlations, etc.) and Buy & Hold benchmark metrics.
* *Visualization*: `ExtendedBacktester.plot_results` generates a 3-panel plot: GRU Predictions vs Actual Price (with uncertainty), SAC Actions (Position Size), and Portfolio Value vs Buy & Hold (with trade markers).
* *Reporting*: `ExtendedBacktester.generate_performance_report` creates a detailed Markdown report.
### 2. Core Components & Inputs/Outputs
* **`src.crypto_db_fetcher.CryptoDBFetcher`**: Loads and resamples data from SQLite DBs.
* **`src.data_pipeline`**: Functions for DB loading, data splitting, sequence creation.
* **`src.trading_system.calculate_v6_features`**: Calculates features (TA-Lib based V6 set + past returns).
* **`src.trading_system._preprocess_data_for_gru_training`**: Prepares features, future price targets, and start prices.
* **`src.gru_predictor.CryptoGRUModel`**: (V6 Adaptation)
* `train()`: Trains the GRU price prediction model. Saves model (`.keras`) and scalers (`.joblib`).
* `evaluate()`: Performs standard prediction and MC dropout inference. Returns dict including `pred_percent_change`, `mc_unscaled_std_dev`, `predicted_unscaled_prices`, `true_unscaled_prices`.
* **`src.sac_agent_simplified.SimplifiedSACTradingAgent`**: (V7 Simplified)
* **Goal:** Learns a policy mapping state to optimal position size (-1.0 to +1.0). Optimized for faster training.
* **State Input:** 2-element array `[predicted_return, mc_unscaled_std_dev]`.
* **Action Output:** Float between -1.0 and +1.0.
* `get_action()`: Selects action (stochastic or deterministic). Adds uncertainty-scaled noise during exploration.
* `store_transition()`: Adds experience to internal NumPy buffer.
* `train()`: Updates agent using buffer samples (internally handles batch size). Uses `@tf.function` for performance.
* `save()` / `load()`: Handles Actor/Critic weights (`.weights.h5`), potentially `alpha.npy`.
* **Note:** Models and optimizers are built explicitly during `__init__` to prevent TensorFlow graph mode issues.
* **`src.trading_system.TradingSystem`**: Integrates GRU and SAC. Manages training pipelines, experience generation (including advanced sampling).
* **`src.trading_system.ExtendedBacktester`**: Performs efficient backtesting using pre-computed GRU outputs, calculates metrics, plots results, generates reports.
### 3. Model Architectures
* **GRU (`src.gru_predictor.CryptoGRUModel._build_model`)**: V6 Architecture.
* Input -> GRU(100) -> Dropout(0.2) -> Dense(1, linear).
* Compiled with Adam (LR=0.001), MSE loss.
* **Simplified SAC (`src.sac_agent_simplified.SimplifiedSACTradingAgent`)**:
* **Actor Network**: MLP `(state_dim=2)` -> Dense(64, relu) -> [BN] -> Dense(64, relu) -> [BN] -> [Residual] -> Dense(1, tanh).
* **Critic Network (x2)**: MLP `(state_dim=2 + action_dim=1)` -> Dense(64, relu) -> [BN] -> Dense(64, relu) -> [BN] -> [Residual] -> Dense(1, linear).
* **Algorithm**: Implements SAC with Clipped Double-Q, fixed alpha (tunable via `SAC_ALPHA`), faster learning rates, smaller networks/buffer, optional Batch Normalization / Residual connections. Uses Huber loss for critics. No distributional critics. `@tf.function` used for update steps.
### 4. Features
The GRU model uses the V6 feature set plus basic past returns:
* **TA-Lib Indicators & Derived Indicators:** SMA, EMA, MACD, SAR, ADX, RSI, Stochastics, WILLR, ROC, CCI, BBands, ATR, OBV, CMF, etc. (Requires TA-Lib installation). Fallback calculations for SMA, EMA, RSI if TA-Lib is unavailable.
* **Custom Crypto Features:** Parkinson Volatility, Garman-Klass Volatility, VWAP ratios, Volume Intensity, Wick Ratios.
* **Past Returns:** `return_1m`, `return_5m`, `return_15m`, `return_60m` (percentage change).
* **Scaling:** Features scaled with `StandardScaler` (fitted on train). Target variable (future price) scaled with `MinMaxScaler` (fitted on train).
### 5. Evaluation
* **GRU Model:** Evaluated using RMSE loss on validation set. Callbacks monitor `val_rmse`. Plots compare predicted vs actual price.
* **SAC Agent & Overall System:** Evaluated via the `ExtendedBacktester` metrics (Sharpe, Sortino, Max Drawdown, correlations, etc.), plots (Portfolio vs B&H, Actions), and a final Markdown report.
## File Structure
- `data/`: *Not used by default if loading from DB.*
- `downloaded_data/`: **Place your V6 SQLite database files here.** (Or update `DB_DIR` in `main.py`).
- `models/`: Trained models (GRU + scalers, SAC weights) saved here under `run_<run_id>/` directories by default.
- `results/`: Backtest results (plots, reports, config) saved here under `<run_id>/` directories.
- `logs/`: Log files saved here under `<run_id>/` directories.
- `src/`: Core Python modules.
- `crypto_db_fetcher.py`: Class for fetching data from SQLite DBs.
- `data_pipeline.py`: DB loading function, data splitting, sequence creation.
- `gru_predictor.py`: V6-style GRU model for price regression and MC uncertainty.
- `sac_agent_simplified.py`: Simplified SAC agent implementation (V7.5+).
- `sac_agent.py`: Original SAC agent implementation (potentially outdated).
- `trading_system.py`: Integration class, feature calculation, scaling, experience generation, `ExtendedBacktester` class.
- `main.py`: Main script using DB loading, orchestrates training and backtesting.
- `requirements.txt`: Dependencies.
- `v7_instructions.txt`: Design notes for Simplified SAC.
- `experience_instructions.txt`: Design notes for experience generation.
- `README.md`: This file.
## Setup
1. **Data:** Place your V6 `downloaded_data` directory containing the SQLite files relative to the `v7` project root, or update the `DB_DIR` variable in `main.py` to point to the correct location.
2. **Dependencies:** Install required packages:
```bash
pip install -r requirements.txt
```
*Strongly Recommended:* Install TA-Lib for the full feature set. See TA-Lib installation guides for your OS.
3. **Configuration:** Review and adjust parameters in `main.py`. Key parameters include:
* `DB_DIR`, `TICKER`, `EXCHANGE`, `START_DATE`, `END_DATE`, `INTERVAL`
* Model hyperparameters (GRU and SAC sections)
* Control Flags: `LOAD_EXISTING_SYSTEM`, `TRAIN_GRU_MODEL`, `TRAIN_SAC_AGENT`, `LOAD_SAC_AGENT`
* Loading Specific Models: `GRU_MODEL_LOAD_RUN_ID` (set to a specific run ID string like `'YYYYMMDD_HHMMSS'` to load that GRU model from `v7/models/run_<run_id>/`). Note: This expects GRU and SAC files to be in the *same* directory if loading this way.
* SAC Training: `TOTAL_TRAINING_STEPS` defines the length of SAC training (number of agent `train()` calls).
* Experience Generation: `experience_config` dictionary controls initial fill, periodic updates, and sampling strategies (recency bias, stratification for uncertainty/extreme returns).
* Backtesting: `INITIAL_CAPITAL`, `TRANSACTION_COST`.
4. **Run:** Execute from the project root directory (containing the `v7` folder):
```bash
python -m v7.main
```
Output files (logs, models, plots, report) will be generated in `v7/logs/`, `v7/models/`, and `v7/results/` within run-specific subdirectories.
## Reporting
The report generated by the `ExtendedBacktester` includes performance metrics, correlation analyses, and configuration details. Key metrics include:
* Total/Annualized Return
* Sharpe & Sortino Ratios
* Volatility & Max Drawdown
* Buy & Hold Comparison
* Position/Prediction Accuracy
* Prediction/Position/Uncertainty Correlations
* Total Trades

Binary file not shown.

View File

465
gru_sac_predictor/main.py Normal file
View File

@ -0,0 +1,465 @@
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
from datetime import datetime
import warnings
import logging
import sys
import json
# --- Generate Run ID ---
run_id = datetime.now().strftime("%Y%m%d_%H%M%S")
# Import components
# V7 Update: Import load_data_from_db
from .src.data_pipeline import create_data_pipeline, load_data_from_db
from .src.trading_system import TradingSystem, ExtendedBacktester, plot_sac_training_history
# V7.3 Fix: Add missing imports
# V7-V6 Final Update: Import CryptoGRUModel
from .src.gru_predictor import CryptoGRUModel
# V7.5 Import the simplified agent
from .src.sac_agent_simplified import SimplifiedSACTradingAgent
# GRU and SAC classes are implicitly imported via TradingSystem
# --- Base Output Directories ---
BASE_RESULTS_DIR = "gru_sac_predictor/results"
BASE_LOGS_DIR = "gru_sac_predictor/logs"
BASE_MODELS_DIR = "gru_sac_predictor/models"
# --- Run Specific Directories ---
RUN_RESULTS_DIR = os.path.join(BASE_RESULTS_DIR, run_id)
RUN_LOGS_DIR = os.path.join(BASE_LOGS_DIR, run_id)
RUN_MODELS_DIR = os.path.join(BASE_MODELS_DIR, f"run_{run_id}")
# --- Logging Setup ---
log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
# Ensure logs directory exists
os.makedirs(RUN_LOGS_DIR, exist_ok=True)
log_file_path = os.path.join(RUN_LOGS_DIR, f"main_{run_id}.log") # Removed _v7
logging.basicConfig(
level=logging.INFO,
format=log_format,
handlers=[
logging.FileHandler(log_file_path, mode='a'), # Use path variable
logging.StreamHandler(sys.stdout)
]
)
logger = logging.getLogger(__name__)
# --- Configuration ---
# V7 Update: Add DB parameters
DB_DIR = '../downloaded_data' # V7 Fix: Point to correct relative path for V6 data
TICKER = 'BTC-USD' # Example ticker
EXCHANGE = 'COINBASE' # Example exchange
START_DATE = '2025-03-01' # Example start date - NOTE: VERY SHORT!
END_DATE = '2025-03-10' # Example end date - NOTE: VERY SHORT!
INTERVAL = '1min' # Data interval to fetch and use
MODEL_SAVE_PATH = RUN_MODELS_DIR # Use run-specific directory
# Updated paths to use RUN_RESULTS_DIR and include run_id
RESULTS_PLOT_PATH = os.path.join(RUN_RESULTS_DIR, f'backtest_results_{run_id}.png') # Removed _v7
REPORT_SAVE_PATH = os.path.join(RUN_RESULTS_DIR, f'backtest_performance_report_{run_id}.md') # Removed _v7
# GRU_PLOT_PATH = 'gru_performance_v7.png' # Not used directly in main
# V7.6 Add specific run ID for loading GRU model
GRU_MODEL_LOAD_RUN_ID = '20250416_142744' # Set this to a specific 'YYYYMMDD_HHMMSS' string to load that GRU model
# Data split ratios
TRAIN_RATIO = 0.6
VALIDATION_RATIO = 0.2
# Model/Training Parameters (V7.3)
GRU_LOOKBACK = 60
GRU_PREDICTION_HORIZON = 1
GRU_EPOCHS = 20
GRU_BATCH_SIZE = 32 # Updated default
GRU_PATIENCE = 10 # Updated default
GRU_LR_PATIENCE = 10 # Updated default
GRU_LR_FACTOR = 0.5 # Updated default
GRU_RETURN_SCALE = 0.03 # Updated default
# SAC Parameters (V7.5 - Simplified Agent)
SAC_STATE_DIM = 5 # [pred_return, uncertainty, z, momentum_5, volatility_20] - Updated from 2
SAC_HIDDEN_SIZE = 64
SAC_GAMMA = 0.97
SAC_TAU = 0.02
# SAC_ALPHA = 0.1 # Removed - Will use automatic tuning
SAC_ACTOR_LR = 3e-4 # Lowered from 5e-4
SAC_CRITIC_LR = 5e-4 # Lowered from 8e-4
SAC_BATCH_SIZE = 64
SAC_BUFFER_MAX_SIZE = 20000
SAC_MIN_BUFFER_SIZE = 1000
SAC_UPDATE_INTERVAL = 1
SAC_TARGET_UPDATE_INTERVAL = 2
SAC_GRADIENT_CLIP = 1.0
SAC_REWARD_SCALE = 2.0 # Decreased from 10.0
SAC_USE_BATCH_NORM = True
SAC_USE_RESIDUAL = True
SAC_MODEL_DIR = 'models/simplified_sac' # Default dir within the agent class
SAC_EPOCHS = 50 # Keep this from previous config for training loop control
# V7.9 Experience Generation Config (Based on instructions.txt)
# TOTAL_TRAINING_STEPS = 1000 # Removed - Not used in current training loop
experience_config = {
# Basic setup
'initial_experiences': 3000, # Start with this many experiences
'experiences_per_batch': 64, # Generate this many in each new batch
'batch_generation_interval': 500, # Generate a new batch every N training steps
# Distribution control (Flags for future implementation in generate_trading_experiences)
'balance_market_regimes': False, # Not implemented
'recency_bias_strength': 0.5, # 0 = uniform, >0 weights recent data more
'high_uncertainty_quantile': 0.75, # Threshold for high uncertainty
'extreme_return_quantile': 0.1, # Threshold for extreme returns (upper/lower)
'min_uncertainty_ratio': 0.2, # Min % of samples with high uncertainty
'min_extreme_return_ratio': 0.1, # Min % of samples with extreme returns
# Efficient processing
'use_parallel_generation': False, # Not implemented
'precompute_all_gru_outputs': True, # Already implemented
'buffer_update_strategy': 'fifo', # Agent currently uses FIFO
# Training optimization
'training_iterations_per_step': 1, # Number of agent.train calls per main loop step
# Max/Min buffer size are defined by the agent itself now
}
# Backtesting Parameters
INITIAL_CAPITAL = 10000.0
TRANSACTION_COST = 0.0005
# V7.12 Add Opportunity Cost Penalty Parameters
OPPORTUNITY_COST_PENALTY_FACTOR = 0.0 # How much to penalize missed high returns - Disabled (was 1.0)
HIGH_RETURN_THRESHOLD = 0.002 # Actual return magnitude threshold to trigger penalty check
ACTION_TOLERANCE = 0.3 # Action magnitude below which penalty applies if return threshold met - Lowered from 0.5
# RISK_PENALTY_FACTOR = 0.0 # Removed as state reverted
# Control Flags
LOAD_EXISTING_SYSTEM = True
TRAIN_GRU_MODEL = False
TRAIN_SAC_AGENT = True # V7.8 Set to True to train SAC
LOAD_SAC_AGENT = False # V7.8 Set to False to avoid loading SAC
RUN_BACKTEST = True
GENERATE_PLOTS = True
GENERATE_REPORT = True
# --- End Configuration ---
def main():
# Access config variables defined at module level
global LOAD_EXISTING_SYSTEM, TRAIN_GRU_MODEL, TRAIN_SAC_AGENT, LOAD_SAC_AGENT
logger.info(f"--- Starting GRU+SAC Trading System Pipeline (Run ID: {run_id}) ---") # Removed V7
# Ensure results directory exists
os.makedirs(RUN_RESULTS_DIR, exist_ok=True)
# Ensure base models directory exists (RUN_MODELS_DIR created later if training)
os.makedirs(BASE_MODELS_DIR, exist_ok=True)
# LOAD_EXISTING_SYSTEM is now declared global before use here
# --- Save Configuration ---
config_to_save = {
"run_id": run_id,
"db_dir": DB_DIR,
"ticker": TICKER,
"exchange": EXCHANGE,
"start_date": START_DATE,
"end_date": END_DATE,
"interval": INTERVAL,
"model_save_path": MODEL_SAVE_PATH,
"results_plot_path": RESULTS_PLOT_PATH,
"report_save_path": REPORT_SAVE_PATH,
"train_ratio": TRAIN_RATIO,
"validation_ratio": VALIDATION_RATIO,
"gru_lookback": GRU_LOOKBACK,
"gru_prediction_horizon": GRU_PREDICTION_HORIZON,
"gru_epochs": GRU_EPOCHS,
"gru_batch_size": GRU_BATCH_SIZE,
"gru_patience": GRU_PATIENCE,
"gru_lr_factor": GRU_LR_FACTOR,
"gru_return_scale": GRU_RETURN_SCALE,
"gru_model_load_run_id": GRU_MODEL_LOAD_RUN_ID,
"sac_state_dim": SAC_STATE_DIM,
"sac_hidden_size": SAC_HIDDEN_SIZE,
"sac_gamma": SAC_GAMMA,
"sac_tau": SAC_TAU,
"sac_actor_lr": SAC_ACTOR_LR,
"sac_critic_lr": SAC_CRITIC_LR,
"sac_batch_size": SAC_BATCH_SIZE,
"sac_buffer_max_size": SAC_BUFFER_MAX_SIZE,
"sac_min_buffer_size": SAC_MIN_BUFFER_SIZE,
"sac_update_interval": SAC_UPDATE_INTERVAL,
"sac_target_update_interval": SAC_TARGET_UPDATE_INTERVAL,
"sac_gradient_clip": SAC_GRADIENT_CLIP,
"sac_reward_scale": SAC_REWARD_SCALE,
"sac_use_batch_norm": SAC_USE_BATCH_NORM,
"sac_use_residual": SAC_USE_RESIDUAL,
"sac_model_dir": SAC_MODEL_DIR,
"sac_epochs": SAC_EPOCHS,
"experience_config": experience_config,
"initial_capital": INITIAL_CAPITAL,
"transaction_cost": TRANSACTION_COST,
# V7.12 Add new params to saved config
"opportunity_cost_penalty_factor": OPPORTUNITY_COST_PENALTY_FACTOR,
"high_return_threshold": HIGH_RETURN_THRESHOLD,
"action_tolerance": ACTION_TOLERANCE,
"load_existing_system": LOAD_EXISTING_SYSTEM,
"train_gru_model": TRAIN_GRU_MODEL,
"train_sac_agent": TRAIN_SAC_AGENT,
"load_sac_agent": LOAD_SAC_AGENT,
"run_backtest": RUN_BACKTEST,
"generate_plots": GENERATE_PLOTS,
"generate_report": GENERATE_REPORT
}
config_save_path = os.path.join(RUN_RESULTS_DIR, f'config_{run_id}.json')
try:
with open(config_save_path, 'w') as f:
json.dump(config_to_save, f, indent=4)
logger.info(f"Run configuration saved to {config_save_path}")
except Exception as e:
logger.error(f"Failed to save run configuration: {e}")
# --- End Save Configuration ---
# 1. Load Data from Database
logger.info(f"Loading data from DB: {TICKER}/{EXCHANGE} ({START_DATE}-{END_DATE}) @ {INTERVAL}")
data = load_data_from_db(
db_dir=DB_DIR,
ticker=TICKER,
exchange=EXCHANGE,
start_date=START_DATE,
end_date=END_DATE,
interval=INTERVAL
)
if data.empty:
logger.error("Failed to load data from database. Please check DB_DIR and parameters. Aborting.")
return
# --- Re-inserted Steps Start ---
# Basic Data Validation (Timestamp index assumed from load_data_from_db)
if 'close' not in data.columns: # Check essential columns
raise ValueError("Loaded data must contain 'close' column.")
logger.info(f"Data loaded: {len(data)} rows, from {data.index.min()} to {data.index.max()}")
initial_len = len(data); data.dropna(subset=['open', 'high', 'low', 'close', 'volume'], inplace=True)
if len(data) < initial_len: logger.info(f"Dropped {initial_len - len(data)} NaN rows.")
if len(data) < GRU_LOOKBACK * 3: raise ValueError(f"Insufficient data ({len(data)} rows) for lookback/splits.")
# Add cyclical features immediately
logger.info("Calculating cyclical time features (hour_sin, hour_cos)...")
timestamp_source = None
if isinstance(data.index, pd.DatetimeIndex):
timestamp_source = data.index
logger.debug("Using index for hour features.")
elif 'timestamp' in data.columns and pd.api.types.is_datetime64_any_dtype(data['timestamp']):
timestamp_source = pd.to_datetime(data['timestamp'])
logger.debug("Using 'timestamp' column for hour features.")
elif 'date' in data.columns and pd.api.types.is_datetime64_any_dtype(data['date']):
timestamp_source = pd.to_datetime(data['date'])
logger.debug("Using 'date' column for hour features.")
if timestamp_source is not None:
data['hour_sin'] = np.sin(2 * np.pi * timestamp_source.hour / 24)
data['hour_cos'] = np.cos(2 * np.pi * timestamp_source.hour / 24)
logger.info("Added hour_sin/hour_cos to main dataframe.")
else:
logger.warning("Could not find suitable timestamp source. Setting hour_sin/cos defaults (0.0, 1.0).")
data['hour_sin'] = 0.0
data['hour_cos'] = 1.0 # Default to cos(0) = 1
# 2. Split Data Chronologically
logger.info("Splitting data...")
test_ratio = round(1.0 - TRAIN_RATIO - VALIDATION_RATIO, 2)
if test_ratio <= 0: raise ValueError("Train+Validation ratios must sum to < 1.")
train_data, val_data, test_data = create_data_pipeline(data, [TRAIN_RATIO, VALIDATION_RATIO, test_ratio])
if len(train_data) < GRU_LOOKBACK or len(val_data) < GRU_LOOKBACK or len(test_data) < GRU_LOOKBACK:
warnings.warn(f"Splits smaller than GRU lookback ({GRU_LOOKBACK}). Backtesting might fail.")
# 3. Initialize Trading System
logger.info("Initializing Trading System...")
trading_system = TradingSystem(
gru_model=CryptoGRUModel(), # Instantiate the correct model
sac_agent=SimplifiedSACTradingAgent(
state_dim=SAC_STATE_DIM,
hidden_size=SAC_HIDDEN_SIZE,
gamma=SAC_GAMMA,
tau=SAC_TAU,
actor_lr=SAC_ACTOR_LR,
critic_lr=SAC_CRITIC_LR,
batch_size=SAC_BATCH_SIZE,
buffer_max_size=SAC_BUFFER_MAX_SIZE,
min_buffer_size=SAC_MIN_BUFFER_SIZE,
update_interval=SAC_UPDATE_INTERVAL,
target_update_interval=SAC_TARGET_UPDATE_INTERVAL,
gradient_clip=SAC_GRADIENT_CLIP,
reward_scale=SAC_REWARD_SCALE,
use_batch_norm=SAC_USE_BATCH_NORM,
use_residual=SAC_USE_RESIDUAL,
model_dir=os.path.join(MODEL_SAVE_PATH, 'sac_agent') # Point to subfolder within run
), # Pass the configured agent
gru_lookback=GRU_LOOKBACK
)
# --- Model Loading/Training ---
gru_loaded = False; sac_loaded = False
if LOAD_EXISTING_SYSTEM:
load_base_path = MODEL_SAVE_PATH
logger.info(f"Attempting to load existing system components...")
logger.info(f"Base path for loading: {load_base_path}")
gru_model_load_dir = None
sac_model_load_dir = None
if GRU_MODEL_LOAD_RUN_ID:
gru_model_load_dir = os.path.join(BASE_MODELS_DIR, f'run_{GRU_MODEL_LOAD_RUN_ID}')
logger.info(f"Using specific GRU load path based on run ID: {gru_model_load_dir}")
if LOAD_SAC_AGENT:
sac_model_load_dir = os.path.join(BASE_MODELS_DIR, f'run_{GRU_MODEL_LOAD_RUN_ID}')
logger.info(f"Using specific SAC load path based on GRU run ID (LOAD_SAC_AGENT=True): {sac_model_load_dir}")
else:
sac_model_load_dir = os.path.join(MODEL_SAVE_PATH, 'sac_agent')
logger.info(f"Defaulting SAC path to current run (LOAD_SAC_AGENT=False): {sac_model_load_dir}")
elif os.path.exists(load_base_path):
gru_model_load_dir = os.path.join(load_base_path, 'gru_model')
sac_model_load_dir = os.path.join(load_base_path, 'sac_agent')
logger.info(f"Using GRU load path based on MODEL_SAVE_PATH: {gru_model_load_dir}")
logger.info(f"Using SAC load path based on MODEL_SAVE_PATH: {sac_model_load_dir}")
else:
logger.warning(f"LOAD_EXISTING_SYSTEM is True, but MODEL_SAVE_PATH does not exist: {load_base_path}. Cannot determine model paths.")
LOAD_EXISTING_SYSTEM = False
if LOAD_EXISTING_SYSTEM:
try:
if gru_model_load_dir and os.path.isdir(gru_model_load_dir):
logger.info(f"Found GRU model directory: {gru_model_load_dir}. Loading...")
if trading_system.gru_model is None: trading_system.gru_model = CryptoGRUModel()
if trading_system.gru_model.load(gru_model_load_dir):
logger.info("GRU model loaded successfully.")
gru_loaded = True
trading_system.feature_scaler = trading_system.gru_model.feature_scaler
trading_system.y_scaler = trading_system.gru_model.y_scaler
logger.info("Scalers propagated from loaded GRU model.")
else: logger.warning(f"GRU model directory found, but loading failed.")
elif gru_model_load_dir: logger.warning(f"GRU model directory specified or derived, but not found at {gru_model_load_dir}. GRU model cannot be loaded.")
else: logger.warning("GRU model path could not be determined. GRU model cannot be loaded.")
if LOAD_SAC_AGENT:
if sac_model_load_dir and os.path.isdir(sac_model_load_dir):
logger.info(f"Found SAC model directory: {sac_model_load_dir}. Loading (LOAD_SAC_AGENT=True)...")
if trading_system.sac_agent is None:
trading_system.sac_agent = SimplifiedSACTradingAgent(state_dim=SAC_STATE_DIM, model_dir=sac_model_load_dir)
if trading_system.sac_agent.load(sac_model_load_dir):
logger.info("SAC agent loaded successfully.")
sac_loaded = True
else: logger.warning(f"SAC model directory found, but loading failed.")
elif sac_model_load_dir: logger.warning(f"SAC agent model directory derived, but not found at {sac_model_load_dir}. SAC agent cannot be loaded (LOAD_SAC_AGENT=True).")
else: logger.info("Skipping SAC agent loading (LOAD_SAC_AGENT=False).")
if gru_loaded: TRAIN_GRU_MODEL = False
if sac_loaded: TRAIN_SAC_AGENT = False; LOAD_SAC_AGENT = True
except Exception as e:
logger.warning(f"Could not load existing system components: {e}. Proceeding based on training flags.")
gru_loaded = False; sac_loaded = False
TRAIN_GRU_MODEL = True; TRAIN_SAC_AGENT = True; LOAD_SAC_AGENT = False
elif LOAD_EXISTING_SYSTEM: pass
else: logger.info("LOAD_EXISTING_SYSTEM=False. Proceeding with training flags.")
# --- Sanity Check After Loading ---
if not gru_loaded and not TRAIN_GRU_MODEL:
logger.error("Critical Error: GRU model was not loaded and TRAIN_GRU_MODEL is False. Cannot proceed.")
return
if not sac_loaded and not TRAIN_SAC_AGENT:
if RUN_BACKTEST:
logger.error("Critical Error: SAC agent was not loaded and TRAIN_SAC_AGENT is False. Aborting because RUN_BACKTEST is True.")
return
else: logger.warning("Proceeding without a functional SAC agent as RUN_BACKTEST is False.")
# Train GRU Model (if flag is set and not loaded)
if TRAIN_GRU_MODEL:
logger.info("--- Training GRU Model --- ")
gru_save_dir = MODEL_SAVE_PATH
history = trading_system.train_gru(
train_data=train_data, val_data=val_data,
prediction_horizon=GRU_PREDICTION_HORIZON,
epochs=GRU_EPOCHS, batch_size=GRU_BATCH_SIZE,
patience=GRU_PATIENCE,
model_save_dir=gru_save_dir
)
if history is None: logger.error("GRU Training failed. Aborting."); return
logger.info("--- GRU Model Training Finished --- ")
elif not gru_loaded: logger.error("GRU Model must be trained or loaded."); return
else: logger.info("Skipping GRU training (already loaded).")
# Train SAC Agent (if flag is set and not loaded)
if TRAIN_SAC_AGENT:
logger.info("--- Training SAC Agent --- ")
if not trading_system.gru_model or not (trading_system.gru_model.is_trained or trading_system.gru_model.is_loaded):
logger.error("Cannot train SAC: GRU model not ready."); return
if trading_system.sac_agent is None: logger.error("SAC Agent instance is missing in the trading system before training."); return
trading_system.sac_agent.model_dir = os.path.join(MODEL_SAVE_PATH, 'sac_agent')
logger.info(f"Ensured SAC agent model save dir is set to: {trading_system.sac_agent.model_dir}")
sac_history = trading_system.train_sac(
val_data=val_data,
epochs=SAC_EPOCHS,
batch_size=SAC_BATCH_SIZE,
transaction_cost=TRANSACTION_COST,
prediction_horizon=GRU_PREDICTION_HORIZON
)
logger.info("Finished training SAC agent.")
if sac_history is not None:
sac_save_dir = os.path.join(MODEL_SAVE_PATH, 'sac_agent')
logger.info(f"Saving Simplified SAC agent to {sac_save_dir}")
trading_system.sac_agent.save(sac_save_dir)
if sac_history:
sac_plot_save_path = os.path.join(RUN_RESULTS_DIR, f'sac_training_history_{run_id}.png')
logger.info(f"Plotting SAC training history to {sac_plot_save_path}...")
try: plot_sac_training_history(sac_history, save_path=sac_plot_save_path)
except Exception as plot_e: logger.error(f"Failed to plot SAC training history: {plot_e}", exc_info=True)
else: logger.warning("SAC training finished, but no history data returned for plotting.")
elif not sac_loaded and LOAD_SAC_AGENT:
# This block handles loading SAC if LOAD_EXISTING_SYSTEM was False but LOAD_SAC_AGENT was True (unlikely case)
if trading_system.sac_agent is None: trading_system.sac_agent = SimplifiedSACTradingAgent(state_dim=SAC_STATE_DIM)
sac_load_path = os.path.join(MODEL_SAVE_PATH, 'sac_agent') # Load from current run models
if os.path.isdir(sac_load_path):
logger.info(f"Attempting to load SAC weights from {sac_load_path} (LOAD_SAC_AGENT=True)...")
try: trading_system.sac_agent.load(sac_load_path); logger.info("SAC weights loaded."); sac_loaded = True
except Exception as e: logger.warning(f"Could not load SAC weights: {e}")
else: logger.warning(f"LOAD_SAC_AGENT=True but no weights found at {sac_load_path}.")
elif not sac_loaded: logger.warning("SAC Agent not trained or loaded.")
else: logger.info("Skipping SAC training (already loaded).")
# 5. Backtest on Test Data
if RUN_BACKTEST:
logger.info("--- Running Extended Backtest --- ")
if not trading_system.gru_model or not (trading_system.gru_model.is_trained or trading_system.gru_model.is_loaded):
logger.error("Cannot backtest: GRU model not ready."); return
if not trading_system.sac_agent: logger.error("Cannot backtest: SAC Agent not initialized."); return
instrument_label = f"{TICKER}/{EXCHANGE}"
backtester = ExtendedBacktester(
trading_system,
initial_capital=INITIAL_CAPITAL,
transaction_cost=TRANSACTION_COST,
instrument_label=instrument_label
)
backtest_results = backtester.backtest(test_data, verbose=True)
# 6. Generate Plots and Report
if GENERATE_PLOTS:
logger.info(f"Generating overall performance plot: {RESULTS_PLOT_PATH}...")
backtester.plot_results(save_path=RESULTS_PLOT_PATH)
if GENERATE_REPORT:
logger.info(f"Generating performance report: {REPORT_SAVE_PATH}...")
backtester.generate_performance_report(report_path=REPORT_SAVE_PATH)
else:
logger.info("Skipping backtesting.")
# --- Re-inserted Steps End ---
logger.info("--- GRU+SAC Pipeline Finished --- ")
if __name__ == "__main__":
main()

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 76 KiB

Some files were not shown because too many files have changed in this diff Show More