diff --git a/gru_sac_predictor/config.yaml b/gru_sac_predictor/config.yaml new file mode 100644 index 00000000..dc464d13 --- /dev/null +++ b/gru_sac_predictor/config.yaml @@ -0,0 +1,100 @@ +# Configuration for GRU-SAC Predictor + +# --- Run Identification & Output --- +run_id_template: '{timestamp}' # Template for generating unique run IDs. '{timestamp}' will be replaced by YYYYMMDD_HHMMSS. Allows grouping results, logs, and models. + +base_dirs: + results: 'results' # Base directory relative to package root + logs: 'logs' # Base directory relative to package root + models: 'models' # Base directory relative to package root + +# --- Data Parameters --- +data: + db_dir: '../../data/crypto_market_data' # Path to the directory containing the market data database (relative to where main.py is run). + exchange: 'bnbspot' # Name of the exchange table/data source in the database. + ticker: 'SOL-USDT' # Instrument identifier (e.g., trading pair) within the exchange data. + start_date: '2024-06-01' # Start date for loading data (YYYY-MM-DD). Note: Ensure enough data for lookback + splits. + end_date: '2025-03-10' # End date for loading data (YYYY-MM-DD). + interval: '1min' # Data frequency/interval (e.g., '1min', '5min', '1h'). + +# --- Data Split --- +split_ratios: + train: 0.6 # Proportion of the loaded data to use for training (0.0 to <1.0). + validation: 0.2 # Proportion of the loaded data to use for validation (0.0 to <1.0). + # Test ratio is calculated as 1.0 - train - validation. Ensure train + validation < 1.0. 
+ +# --- GRU Model Parameters --- +gru: + lookback: 60 + epochs: 25 + batch_size: 256 + prediction_horizon: 5 + patience: 5 + model_load_run_id: '20250417_173635' + recency_weighting: + enabled: true + linear_start: 0.2 + linear_end: 1.0 + signed_weighting_beta: 0.0 + composite_loss_kappa: 0.0 + +# --- Calibration Parameters --- +calibration: + edge_threshold: 0.55 + recalibrate_every_n: 0 + recalibration_window: 10000 + +# --- SAC Agent Parameters --- +sac: + state_dim: 5 + hidden_size: 64 + gamma: 0.97 + tau: 0.02 + actor_lr: 3e-4 + buffer_max_size: 100000 + ou_noise_stddev: 0.2 + ou_noise_theta: 0.15 + ou_noise_dt: 0.01 + alpha: 0.2 + alpha_auto_tune: true + use_batch_norm: true + total_training_steps: 100 + min_buffer_size: 2000 + batch_size: 256 + log_interval: 1000 + save_interval: 10000 + +# --- Environment Parameters (Used by train_sac.py) --- +environment: + initial_capital: 10000.0 # Notional capital for env/backtest consistency + transaction_cost: 0.0005 # Fractional cost per trade (e.g., 0.0005 = 0.05%) + +# --- Backtesting Parameters --- +backtest: + initial_capital: 10000.0 # Starting capital for run_pipeline backtest. + transaction_cost: 0.0005 # Transaction cost for run_pipeline backtest. + +# --- Experience Generation (Simplified for config) --- +# Configuration for how experiences are generated or sampled for SAC training. +# (Currently only 'generate_new_on_epoch' is directly used from here in main.py) +experience: + generate_new_on_epoch: False # If true, generate fresh experiences using validation data at the start of each SAC epoch. If false, generate experiences once initially. + +# --- Control Flags --- +# Determine which parts of the pipeline to run. +control: + train_gru: true # Train the GRU model? + train_sac: true # Run the offline SAC training script before backtesting? 
+ + # --- SAC Loading/Resuming --- + # For resuming training in train_sac.py: + sac_resume_run_id: null # Run ID of SAC agent to load *before* starting training (e.g., "sac_train_..."). If null, starts fresh. + sac_resume_step: final # Checkpoint step to resume from: 'final' or step number. + # For loading agent for backtesting in run_pipeline.py: + sac_load_run_id: null # Run ID of the SAC training run to load weights from for *backtesting* (e.g., "sac_train_..."). If null, uses initial weights. + sac_load_step: final # Which SAC checkpoint to load for backtesting: 'final' or step number. + + # --- Other Pipeline Controls --- + run_backtest: true # Run the backtest? + generate_plots: true # Generate output plots? + # generate_report: True # Deprecated: Metrics are saved to a .txt file. \ No newline at end of file diff --git a/gru_sac_predictor/models/run_20250416_142744/actor.weights.h5 b/gru_sac_predictor/models/run_20250416_142744/actor.weights.h5 deleted file mode 100644 index 0f80090e..00000000 Binary files a/gru_sac_predictor/models/run_20250416_142744/actor.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_142744/best_model_reg.keras b/gru_sac_predictor/models/run_20250416_142744/best_model_reg.keras deleted file mode 100644 index c9ba7dfa..00000000 Binary files a/gru_sac_predictor/models/run_20250416_142744/best_model_reg.keras and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_142744/critic1.weights.h5 b/gru_sac_predictor/models/run_20250416_142744/critic1.weights.h5 deleted file mode 100644 index 70b5a1a4..00000000 Binary files a/gru_sac_predictor/models/run_20250416_142744/critic1.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_142744/critic2.weights.h5 b/gru_sac_predictor/models/run_20250416_142744/critic2.weights.h5 deleted file mode 100644 index f4661419..00000000 Binary files a/gru_sac_predictor/models/run_20250416_142744/critic2.weights.h5 and /dev/null differ 
diff --git a/gru_sac_predictor/models/run_20250416_142744/feature_scaler.joblib b/gru_sac_predictor/models/run_20250416_142744/feature_scaler.joblib deleted file mode 100644 index 65594f07..00000000 Binary files a/gru_sac_predictor/models/run_20250416_142744/feature_scaler.joblib and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_142744/gru_training_history.png b/gru_sac_predictor/models/run_20250416_142744/gru_training_history.png deleted file mode 100644 index 69ac8723..00000000 Binary files a/gru_sac_predictor/models/run_20250416_142744/gru_training_history.png and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_142744/log_alpha.npy b/gru_sac_predictor/models/run_20250416_142744/log_alpha.npy deleted file mode 100644 index f6dcc78a..00000000 Binary files a/gru_sac_predictor/models/run_20250416_142744/log_alpha.npy and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_142744/y_scaler.joblib b/gru_sac_predictor/models/run_20250416_142744/y_scaler.joblib deleted file mode 100644 index 815292e7..00000000 Binary files a/gru_sac_predictor/models/run_20250416_142744/y_scaler.joblib and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_144757/sac_agent/actor.weights.h5 b/gru_sac_predictor/models/run_20250416_144757/sac_agent/actor.weights.h5 deleted file mode 100644 index 821c86f5..00000000 Binary files a/gru_sac_predictor/models/run_20250416_144757/sac_agent/actor.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_144757/sac_agent/alpha.npy b/gru_sac_predictor/models/run_20250416_144757/sac_agent/alpha.npy deleted file mode 100644 index d8cc008b..00000000 Binary files a/gru_sac_predictor/models/run_20250416_144757/sac_agent/alpha.npy and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_144757/sac_agent/critic_1.weights.h5 b/gru_sac_predictor/models/run_20250416_144757/sac_agent/critic_1.weights.h5 deleted file mode 100644 index 
219924ae..00000000 Binary files a/gru_sac_predictor/models/run_20250416_144757/sac_agent/critic_1.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_144757/sac_agent/critic_2.weights.h5 b/gru_sac_predictor/models/run_20250416_144757/sac_agent/critic_2.weights.h5 deleted file mode 100644 index 556227b0..00000000 Binary files a/gru_sac_predictor/models/run_20250416_144757/sac_agent/critic_2.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_144757/sac_agent/target_critic_1.weights.h5 b/gru_sac_predictor/models/run_20250416_144757/sac_agent/target_critic_1.weights.h5 deleted file mode 100644 index 0d6ff386..00000000 Binary files a/gru_sac_predictor/models/run_20250416_144757/sac_agent/target_critic_1.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_144757/sac_agent/target_critic_2.weights.h5 b/gru_sac_predictor/models/run_20250416_144757/sac_agent/target_critic_2.weights.h5 deleted file mode 100644 index 2d695ae2..00000000 Binary files a/gru_sac_predictor/models/run_20250416_144757/sac_agent/target_critic_2.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_145128/sac_agent/actor.weights.h5 b/gru_sac_predictor/models/run_20250416_145128/sac_agent/actor.weights.h5 deleted file mode 100644 index b474b03b..00000000 Binary files a/gru_sac_predictor/models/run_20250416_145128/sac_agent/actor.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_145128/sac_agent/alpha.npy b/gru_sac_predictor/models/run_20250416_145128/sac_agent/alpha.npy deleted file mode 100644 index d8cc008b..00000000 Binary files a/gru_sac_predictor/models/run_20250416_145128/sac_agent/alpha.npy and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_145128/sac_agent/critic_1.weights.h5 b/gru_sac_predictor/models/run_20250416_145128/sac_agent/critic_1.weights.h5 deleted file mode 100644 index b8e977f8..00000000 Binary files 
a/gru_sac_predictor/models/run_20250416_145128/sac_agent/critic_1.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_145128/sac_agent/critic_2.weights.h5 b/gru_sac_predictor/models/run_20250416_145128/sac_agent/critic_2.weights.h5 deleted file mode 100644 index a6643e59..00000000 Binary files a/gru_sac_predictor/models/run_20250416_145128/sac_agent/critic_2.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_145128/sac_agent/target_critic_1.weights.h5 b/gru_sac_predictor/models/run_20250416_145128/sac_agent/target_critic_1.weights.h5 deleted file mode 100644 index 5a5b2bd0..00000000 Binary files a/gru_sac_predictor/models/run_20250416_145128/sac_agent/target_critic_1.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_145128/sac_agent/target_critic_2.weights.h5 b/gru_sac_predictor/models/run_20250416_145128/sac_agent/target_critic_2.weights.h5 deleted file mode 100644 index 5b674e70..00000000 Binary files a/gru_sac_predictor/models/run_20250416_145128/sac_agent/target_critic_2.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_150829/sac_agent/actor.weights.h5 b/gru_sac_predictor/models/run_20250416_150829/sac_agent/actor.weights.h5 deleted file mode 100644 index 1d707323..00000000 Binary files a/gru_sac_predictor/models/run_20250416_150829/sac_agent/actor.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_150829/sac_agent/alpha.npy b/gru_sac_predictor/models/run_20250416_150829/sac_agent/alpha.npy deleted file mode 100644 index d8cc008b..00000000 Binary files a/gru_sac_predictor/models/run_20250416_150829/sac_agent/alpha.npy and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_150829/sac_agent/critic_1.weights.h5 b/gru_sac_predictor/models/run_20250416_150829/sac_agent/critic_1.weights.h5 deleted file mode 100644 index 5306d00d..00000000 Binary files 
a/gru_sac_predictor/models/run_20250416_150829/sac_agent/critic_1.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_150829/sac_agent/critic_2.weights.h5 b/gru_sac_predictor/models/run_20250416_150829/sac_agent/critic_2.weights.h5 deleted file mode 100644 index 0c91518c..00000000 Binary files a/gru_sac_predictor/models/run_20250416_150829/sac_agent/critic_2.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_150829/sac_agent/target_critic_1.weights.h5 b/gru_sac_predictor/models/run_20250416_150829/sac_agent/target_critic_1.weights.h5 deleted file mode 100644 index 566cb8a0..00000000 Binary files a/gru_sac_predictor/models/run_20250416_150829/sac_agent/target_critic_1.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_150829/sac_agent/target_critic_2.weights.h5 b/gru_sac_predictor/models/run_20250416_150829/sac_agent/target_critic_2.weights.h5 deleted file mode 100644 index 2a9289d3..00000000 Binary files a/gru_sac_predictor/models/run_20250416_150829/sac_agent/target_critic_2.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_150924/sac_agent/actor.weights.h5 b/gru_sac_predictor/models/run_20250416_150924/sac_agent/actor.weights.h5 deleted file mode 100644 index 48f87255..00000000 Binary files a/gru_sac_predictor/models/run_20250416_150924/sac_agent/actor.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_150924/sac_agent/alpha.npy b/gru_sac_predictor/models/run_20250416_150924/sac_agent/alpha.npy deleted file mode 100644 index d8cc008b..00000000 Binary files a/gru_sac_predictor/models/run_20250416_150924/sac_agent/alpha.npy and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_150924/sac_agent/critic_1.weights.h5 b/gru_sac_predictor/models/run_20250416_150924/sac_agent/critic_1.weights.h5 deleted file mode 100644 index 2b49793b..00000000 Binary files 
a/gru_sac_predictor/models/run_20250416_150924/sac_agent/critic_1.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_150924/sac_agent/critic_2.weights.h5 b/gru_sac_predictor/models/run_20250416_150924/sac_agent/critic_2.weights.h5 deleted file mode 100644 index 8819325c..00000000 Binary files a/gru_sac_predictor/models/run_20250416_150924/sac_agent/critic_2.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_150924/sac_agent/target_critic_1.weights.h5 b/gru_sac_predictor/models/run_20250416_150924/sac_agent/target_critic_1.weights.h5 deleted file mode 100644 index d9d16fbe..00000000 Binary files a/gru_sac_predictor/models/run_20250416_150924/sac_agent/target_critic_1.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_150924/sac_agent/target_critic_2.weights.h5 b/gru_sac_predictor/models/run_20250416_150924/sac_agent/target_critic_2.weights.h5 deleted file mode 100644 index 2c5effde..00000000 Binary files a/gru_sac_predictor/models/run_20250416_150924/sac_agent/target_critic_2.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_151322/sac_agent/actor.weights.h5 b/gru_sac_predictor/models/run_20250416_151322/sac_agent/actor.weights.h5 deleted file mode 100644 index 07956adf..00000000 Binary files a/gru_sac_predictor/models/run_20250416_151322/sac_agent/actor.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_151322/sac_agent/alpha.npy b/gru_sac_predictor/models/run_20250416_151322/sac_agent/alpha.npy deleted file mode 100644 index ff5c2883..00000000 Binary files a/gru_sac_predictor/models/run_20250416_151322/sac_agent/alpha.npy and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_151322/sac_agent/critic_1.weights.h5 b/gru_sac_predictor/models/run_20250416_151322/sac_agent/critic_1.weights.h5 deleted file mode 100644 index 459fdc33..00000000 Binary files 
a/gru_sac_predictor/models/run_20250416_151322/sac_agent/critic_1.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_151322/sac_agent/critic_2.weights.h5 b/gru_sac_predictor/models/run_20250416_151322/sac_agent/critic_2.weights.h5 deleted file mode 100644 index 5bfe8264..00000000 Binary files a/gru_sac_predictor/models/run_20250416_151322/sac_agent/critic_2.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_151322/sac_agent/target_critic_1.weights.h5 b/gru_sac_predictor/models/run_20250416_151322/sac_agent/target_critic_1.weights.h5 deleted file mode 100644 index 175e4146..00000000 Binary files a/gru_sac_predictor/models/run_20250416_151322/sac_agent/target_critic_1.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_151322/sac_agent/target_critic_2.weights.h5 b/gru_sac_predictor/models/run_20250416_151322/sac_agent/target_critic_2.weights.h5 deleted file mode 100644 index f2efe7f3..00000000 Binary files a/gru_sac_predictor/models/run_20250416_151322/sac_agent/target_critic_2.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_151849/sac_agent/actor.weights.h5 b/gru_sac_predictor/models/run_20250416_151849/sac_agent/actor.weights.h5 deleted file mode 100644 index d49f4f57..00000000 Binary files a/gru_sac_predictor/models/run_20250416_151849/sac_agent/actor.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_151849/sac_agent/alpha.npy b/gru_sac_predictor/models/run_20250416_151849/sac_agent/alpha.npy deleted file mode 100644 index d8cc008b..00000000 Binary files a/gru_sac_predictor/models/run_20250416_151849/sac_agent/alpha.npy and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_151849/sac_agent/critic_1.weights.h5 b/gru_sac_predictor/models/run_20250416_151849/sac_agent/critic_1.weights.h5 deleted file mode 100644 index 08716159..00000000 Binary files 
a/gru_sac_predictor/models/run_20250416_151849/sac_agent/critic_1.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_151849/sac_agent/critic_2.weights.h5 b/gru_sac_predictor/models/run_20250416_151849/sac_agent/critic_2.weights.h5 deleted file mode 100644 index 43b82f3f..00000000 Binary files a/gru_sac_predictor/models/run_20250416_151849/sac_agent/critic_2.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_151849/sac_agent/target_critic_1.weights.h5 b/gru_sac_predictor/models/run_20250416_151849/sac_agent/target_critic_1.weights.h5 deleted file mode 100644 index 20ee67f2..00000000 Binary files a/gru_sac_predictor/models/run_20250416_151849/sac_agent/target_critic_1.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_151849/sac_agent/target_critic_2.weights.h5 b/gru_sac_predictor/models/run_20250416_151849/sac_agent/target_critic_2.weights.h5 deleted file mode 100644 index 1f58a4d9..00000000 Binary files a/gru_sac_predictor/models/run_20250416_151849/sac_agent/target_critic_2.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_152415/sac_agent/actor.weights.h5 b/gru_sac_predictor/models/run_20250416_152415/sac_agent/actor.weights.h5 deleted file mode 100644 index d5fc1f9c..00000000 Binary files a/gru_sac_predictor/models/run_20250416_152415/sac_agent/actor.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_152415/sac_agent/alpha.npy b/gru_sac_predictor/models/run_20250416_152415/sac_agent/alpha.npy deleted file mode 100644 index fdf08736..00000000 Binary files a/gru_sac_predictor/models/run_20250416_152415/sac_agent/alpha.npy and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_152415/sac_agent/critic_1.weights.h5 b/gru_sac_predictor/models/run_20250416_152415/sac_agent/critic_1.weights.h5 deleted file mode 100644 index ec11f9e4..00000000 Binary files 
a/gru_sac_predictor/models/run_20250416_152415/sac_agent/critic_1.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_152415/sac_agent/critic_2.weights.h5 b/gru_sac_predictor/models/run_20250416_152415/sac_agent/critic_2.weights.h5 deleted file mode 100644 index 78096da4..00000000 Binary files a/gru_sac_predictor/models/run_20250416_152415/sac_agent/critic_2.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_152415/sac_agent/target_critic_1.weights.h5 b/gru_sac_predictor/models/run_20250416_152415/sac_agent/target_critic_1.weights.h5 deleted file mode 100644 index 91a0b1c6..00000000 Binary files a/gru_sac_predictor/models/run_20250416_152415/sac_agent/target_critic_1.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_152415/sac_agent/target_critic_2.weights.h5 b/gru_sac_predictor/models/run_20250416_152415/sac_agent/target_critic_2.weights.h5 deleted file mode 100644 index a41c4a39..00000000 Binary files a/gru_sac_predictor/models/run_20250416_152415/sac_agent/target_critic_2.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_153132/sac_agent/actor.weights.h5 b/gru_sac_predictor/models/run_20250416_153132/sac_agent/actor.weights.h5 deleted file mode 100644 index 76cbdefc..00000000 Binary files a/gru_sac_predictor/models/run_20250416_153132/sac_agent/actor.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_153132/sac_agent/alpha.npy b/gru_sac_predictor/models/run_20250416_153132/sac_agent/alpha.npy deleted file mode 100644 index fdf08736..00000000 Binary files a/gru_sac_predictor/models/run_20250416_153132/sac_agent/alpha.npy and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_153132/sac_agent/critic_1.weights.h5 b/gru_sac_predictor/models/run_20250416_153132/sac_agent/critic_1.weights.h5 deleted file mode 100644 index 931046cf..00000000 Binary files 
a/gru_sac_predictor/models/run_20250416_153132/sac_agent/critic_1.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_153132/sac_agent/critic_2.weights.h5 b/gru_sac_predictor/models/run_20250416_153132/sac_agent/critic_2.weights.h5 deleted file mode 100644 index daaeb3ff..00000000 Binary files a/gru_sac_predictor/models/run_20250416_153132/sac_agent/critic_2.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_153132/sac_agent/target_critic_1.weights.h5 b/gru_sac_predictor/models/run_20250416_153132/sac_agent/target_critic_1.weights.h5 deleted file mode 100644 index c2bf2d44..00000000 Binary files a/gru_sac_predictor/models/run_20250416_153132/sac_agent/target_critic_1.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_153132/sac_agent/target_critic_2.weights.h5 b/gru_sac_predictor/models/run_20250416_153132/sac_agent/target_critic_2.weights.h5 deleted file mode 100644 index 1eee4657..00000000 Binary files a/gru_sac_predictor/models/run_20250416_153132/sac_agent/target_critic_2.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_153846/sac_agent/actor.weights.h5 b/gru_sac_predictor/models/run_20250416_153846/sac_agent/actor.weights.h5 deleted file mode 100644 index 67ea1863..00000000 Binary files a/gru_sac_predictor/models/run_20250416_153846/sac_agent/actor.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_153846/sac_agent/alpha.npy b/gru_sac_predictor/models/run_20250416_153846/sac_agent/alpha.npy deleted file mode 100644 index d8cc008b..00000000 Binary files a/gru_sac_predictor/models/run_20250416_153846/sac_agent/alpha.npy and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_153846/sac_agent/critic_1.weights.h5 b/gru_sac_predictor/models/run_20250416_153846/sac_agent/critic_1.weights.h5 deleted file mode 100644 index 717ed3d9..00000000 Binary files 
a/gru_sac_predictor/models/run_20250416_153846/sac_agent/critic_1.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_153846/sac_agent/critic_2.weights.h5 b/gru_sac_predictor/models/run_20250416_153846/sac_agent/critic_2.weights.h5 deleted file mode 100644 index 6d3328d7..00000000 Binary files a/gru_sac_predictor/models/run_20250416_153846/sac_agent/critic_2.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_153846/sac_agent/target_critic_1.weights.h5 b/gru_sac_predictor/models/run_20250416_153846/sac_agent/target_critic_1.weights.h5 deleted file mode 100644 index 5a67aa12..00000000 Binary files a/gru_sac_predictor/models/run_20250416_153846/sac_agent/target_critic_1.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_153846/sac_agent/target_critic_2.weights.h5 b/gru_sac_predictor/models/run_20250416_153846/sac_agent/target_critic_2.weights.h5 deleted file mode 100644 index 44f93484..00000000 Binary files a/gru_sac_predictor/models/run_20250416_153846/sac_agent/target_critic_2.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_154636/sac_agent/actor.weights.h5 b/gru_sac_predictor/models/run_20250416_154636/sac_agent/actor.weights.h5 deleted file mode 100644 index e0b1186f..00000000 Binary files a/gru_sac_predictor/models/run_20250416_154636/sac_agent/actor.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_154636/sac_agent/alpha.npy b/gru_sac_predictor/models/run_20250416_154636/sac_agent/alpha.npy deleted file mode 100644 index d8cc008b..00000000 Binary files a/gru_sac_predictor/models/run_20250416_154636/sac_agent/alpha.npy and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_154636/sac_agent/critic_1.weights.h5 b/gru_sac_predictor/models/run_20250416_154636/sac_agent/critic_1.weights.h5 deleted file mode 100644 index eddc6ce5..00000000 Binary files 
a/gru_sac_predictor/models/run_20250416_154636/sac_agent/critic_1.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_154636/sac_agent/critic_2.weights.h5 b/gru_sac_predictor/models/run_20250416_154636/sac_agent/critic_2.weights.h5 deleted file mode 100644 index 9e1a8b2e..00000000 Binary files a/gru_sac_predictor/models/run_20250416_154636/sac_agent/critic_2.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_154636/sac_agent/target_critic_1.weights.h5 b/gru_sac_predictor/models/run_20250416_154636/sac_agent/target_critic_1.weights.h5 deleted file mode 100644 index dfbd183e..00000000 Binary files a/gru_sac_predictor/models/run_20250416_154636/sac_agent/target_critic_1.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_154636/sac_agent/target_critic_2.weights.h5 b/gru_sac_predictor/models/run_20250416_154636/sac_agent/target_critic_2.weights.h5 deleted file mode 100644 index 1af03e89..00000000 Binary files a/gru_sac_predictor/models/run_20250416_154636/sac_agent/target_critic_2.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_164726/sac_agent/actor.weights.h5 b/gru_sac_predictor/models/run_20250416_164726/sac_agent/actor.weights.h5 deleted file mode 100644 index 3fe4bd15..00000000 Binary files a/gru_sac_predictor/models/run_20250416_164726/sac_agent/actor.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_164726/sac_agent/alpha.npy b/gru_sac_predictor/models/run_20250416_164726/sac_agent/alpha.npy deleted file mode 100644 index d628ca3f..00000000 Binary files a/gru_sac_predictor/models/run_20250416_164726/sac_agent/alpha.npy and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_164726/sac_agent/critic_1.weights.h5 b/gru_sac_predictor/models/run_20250416_164726/sac_agent/critic_1.weights.h5 deleted file mode 100644 index 5328bccb..00000000 Binary files 
a/gru_sac_predictor/models/run_20250416_164726/sac_agent/critic_1.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_164726/sac_agent/critic_2.weights.h5 b/gru_sac_predictor/models/run_20250416_164726/sac_agent/critic_2.weights.h5 deleted file mode 100644 index 21330c03..00000000 Binary files a/gru_sac_predictor/models/run_20250416_164726/sac_agent/critic_2.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_164726/sac_agent/target_critic_1.weights.h5 b/gru_sac_predictor/models/run_20250416_164726/sac_agent/target_critic_1.weights.h5 deleted file mode 100644 index 0b740f5d..00000000 Binary files a/gru_sac_predictor/models/run_20250416_164726/sac_agent/target_critic_1.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_164726/sac_agent/target_critic_2.weights.h5 b/gru_sac_predictor/models/run_20250416_164726/sac_agent/target_critic_2.weights.h5 deleted file mode 100644 index 3c7e3cf1..00000000 Binary files a/gru_sac_predictor/models/run_20250416_164726/sac_agent/target_critic_2.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_170503/sac_agent/actor.weights.h5 b/gru_sac_predictor/models/run_20250416_170503/sac_agent/actor.weights.h5 deleted file mode 100644 index e6752f34..00000000 Binary files a/gru_sac_predictor/models/run_20250416_170503/sac_agent/actor.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_170503/sac_agent/alpha.npy b/gru_sac_predictor/models/run_20250416_170503/sac_agent/alpha.npy deleted file mode 100644 index 2b472e76..00000000 Binary files a/gru_sac_predictor/models/run_20250416_170503/sac_agent/alpha.npy and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_170503/sac_agent/critic_1.weights.h5 b/gru_sac_predictor/models/run_20250416_170503/sac_agent/critic_1.weights.h5 deleted file mode 100644 index 45bbf9dc..00000000 Binary files 
a/gru_sac_predictor/models/run_20250416_170503/sac_agent/critic_1.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_170503/sac_agent/critic_2.weights.h5 b/gru_sac_predictor/models/run_20250416_170503/sac_agent/critic_2.weights.h5 deleted file mode 100644 index fb92e4ee..00000000 Binary files a/gru_sac_predictor/models/run_20250416_170503/sac_agent/critic_2.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_170503/sac_agent/target_critic_1.weights.h5 b/gru_sac_predictor/models/run_20250416_170503/sac_agent/target_critic_1.weights.h5 deleted file mode 100644 index afa6d5ad..00000000 Binary files a/gru_sac_predictor/models/run_20250416_170503/sac_agent/target_critic_1.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_170503/sac_agent/target_critic_2.weights.h5 b/gru_sac_predictor/models/run_20250416_170503/sac_agent/target_critic_2.weights.h5 deleted file mode 100644 index 0692b875..00000000 Binary files a/gru_sac_predictor/models/run_20250416_170503/sac_agent/target_critic_2.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_182038/sac_agent/actor.weights.h5 b/gru_sac_predictor/models/run_20250416_182038/sac_agent/actor.weights.h5 deleted file mode 100644 index 82f6ff98..00000000 Binary files a/gru_sac_predictor/models/run_20250416_182038/sac_agent/actor.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_182038/sac_agent/alpha.npy b/gru_sac_predictor/models/run_20250416_182038/sac_agent/alpha.npy deleted file mode 100644 index 5b9784ad..00000000 Binary files a/gru_sac_predictor/models/run_20250416_182038/sac_agent/alpha.npy and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_182038/sac_agent/critic_1.weights.h5 b/gru_sac_predictor/models/run_20250416_182038/sac_agent/critic_1.weights.h5 deleted file mode 100644 index b5625db8..00000000 Binary files 
a/gru_sac_predictor/models/run_20250416_182038/sac_agent/critic_1.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_182038/sac_agent/critic_2.weights.h5 b/gru_sac_predictor/models/run_20250416_182038/sac_agent/critic_2.weights.h5 deleted file mode 100644 index 2367e829..00000000 Binary files a/gru_sac_predictor/models/run_20250416_182038/sac_agent/critic_2.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_182038/sac_agent/target_critic_1.weights.h5 b/gru_sac_predictor/models/run_20250416_182038/sac_agent/target_critic_1.weights.h5 deleted file mode 100644 index 1d4ceebd..00000000 Binary files a/gru_sac_predictor/models/run_20250416_182038/sac_agent/target_critic_1.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_182038/sac_agent/target_critic_2.weights.h5 b/gru_sac_predictor/models/run_20250416_182038/sac_agent/target_critic_2.weights.h5 deleted file mode 100644 index 984a526e..00000000 Binary files a/gru_sac_predictor/models/run_20250416_182038/sac_agent/target_critic_2.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_183051/sac_agent/actor.weights.h5 b/gru_sac_predictor/models/run_20250416_183051/sac_agent/actor.weights.h5 deleted file mode 100644 index 37063e4f..00000000 Binary files a/gru_sac_predictor/models/run_20250416_183051/sac_agent/actor.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_183051/sac_agent/alpha.npy b/gru_sac_predictor/models/run_20250416_183051/sac_agent/alpha.npy deleted file mode 100644 index 3f8055fd..00000000 Binary files a/gru_sac_predictor/models/run_20250416_183051/sac_agent/alpha.npy and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_183051/sac_agent/critic_1.weights.h5 b/gru_sac_predictor/models/run_20250416_183051/sac_agent/critic_1.weights.h5 deleted file mode 100644 index 7c96ab2d..00000000 Binary files 
a/gru_sac_predictor/models/run_20250416_183051/sac_agent/critic_1.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_183051/sac_agent/critic_2.weights.h5 b/gru_sac_predictor/models/run_20250416_183051/sac_agent/critic_2.weights.h5 deleted file mode 100644 index 9ae51864..00000000 Binary files a/gru_sac_predictor/models/run_20250416_183051/sac_agent/critic_2.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_183051/sac_agent/target_critic_1.weights.h5 b/gru_sac_predictor/models/run_20250416_183051/sac_agent/target_critic_1.weights.h5 deleted file mode 100644 index f117f012..00000000 Binary files a/gru_sac_predictor/models/run_20250416_183051/sac_agent/target_critic_1.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_183051/sac_agent/target_critic_2.weights.h5 b/gru_sac_predictor/models/run_20250416_183051/sac_agent/target_critic_2.weights.h5 deleted file mode 100644 index 89952e05..00000000 Binary files a/gru_sac_predictor/models/run_20250416_183051/sac_agent/target_critic_2.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_183508/sac_agent/actor.weights.h5 b/gru_sac_predictor/models/run_20250416_183508/sac_agent/actor.weights.h5 deleted file mode 100644 index d347bb7e..00000000 Binary files a/gru_sac_predictor/models/run_20250416_183508/sac_agent/actor.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_183508/sac_agent/alpha.npy b/gru_sac_predictor/models/run_20250416_183508/sac_agent/alpha.npy deleted file mode 100644 index 4f7d1b57..00000000 Binary files a/gru_sac_predictor/models/run_20250416_183508/sac_agent/alpha.npy and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_183508/sac_agent/critic_1.weights.h5 b/gru_sac_predictor/models/run_20250416_183508/sac_agent/critic_1.weights.h5 deleted file mode 100644 index b6909207..00000000 Binary files 
a/gru_sac_predictor/models/run_20250416_183508/sac_agent/critic_1.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_183508/sac_agent/critic_2.weights.h5 b/gru_sac_predictor/models/run_20250416_183508/sac_agent/critic_2.weights.h5 deleted file mode 100644 index 1729e426..00000000 Binary files a/gru_sac_predictor/models/run_20250416_183508/sac_agent/critic_2.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_183508/sac_agent/target_critic_1.weights.h5 b/gru_sac_predictor/models/run_20250416_183508/sac_agent/target_critic_1.weights.h5 deleted file mode 100644 index 00081b5d..00000000 Binary files a/gru_sac_predictor/models/run_20250416_183508/sac_agent/target_critic_1.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250416_183508/sac_agent/target_critic_2.weights.h5 b/gru_sac_predictor/models/run_20250416_183508/sac_agent/target_critic_2.weights.h5 deleted file mode 100644 index d2f9f298..00000000 Binary files a/gru_sac_predictor/models/run_20250416_183508/sac_agent/target_critic_2.weights.h5 and /dev/null differ diff --git a/gru_sac_predictor/models/run_20250418_013239/feature_scaler_20250418_013239.joblib b/gru_sac_predictor/models/run_20250418_013239/feature_scaler_20250418_013239.joblib new file mode 100644 index 00000000..de8edf24 Binary files /dev/null and b/gru_sac_predictor/models/run_20250418_013239/feature_scaler_20250418_013239.joblib differ diff --git a/gru_sac_predictor/models/run_20250418_013239/final_whitelist_20250418_013239.json b/gru_sac_predictor/models/run_20250418_013239/final_whitelist_20250418_013239.json new file mode 100644 index 00000000..3be8c986 --- /dev/null +++ b/gru_sac_predictor/models/run_20250418_013239/final_whitelist_20250418_013239.json @@ -0,0 +1,13 @@ +[ + "ATR_14", + "EMA_50", + "MACD_signal", + "chaikin_AD_10", + "hour_cos", + "hour_sin", + "return_15m", + "return_1m", + "return_60m", + "svi_10", + "volatility_14d" +] \ No newline at end 
of file diff --git a/gru_sac_predictor/models/run_20250418_013350/feature_scaler_20250418_013350.joblib b/gru_sac_predictor/models/run_20250418_013350/feature_scaler_20250418_013350.joblib new file mode 100644 index 00000000..0931d7f7 Binary files /dev/null and b/gru_sac_predictor/models/run_20250418_013350/feature_scaler_20250418_013350.joblib differ diff --git a/gru_sac_predictor/models/run_20250418_013350/final_whitelist_20250418_013350.json b/gru_sac_predictor/models/run_20250418_013350/final_whitelist_20250418_013350.json new file mode 100644 index 00000000..3be8c986 --- /dev/null +++ b/gru_sac_predictor/models/run_20250418_013350/final_whitelist_20250418_013350.json @@ -0,0 +1,13 @@ +[ + "ATR_14", + "EMA_50", + "MACD_signal", + "chaikin_AD_10", + "hour_cos", + "hour_sin", + "return_15m", + "return_1m", + "return_60m", + "svi_10", + "volatility_14d" +] \ No newline at end of file diff --git a/gru_sac_predictor/models/run_20250418_013938/feature_scaler_20250418_013938.joblib b/gru_sac_predictor/models/run_20250418_013938/feature_scaler_20250418_013938.joblib new file mode 100644 index 00000000..c20d654b Binary files /dev/null and b/gru_sac_predictor/models/run_20250418_013938/feature_scaler_20250418_013938.joblib differ diff --git a/gru_sac_predictor/models/run_20250418_013938/final_whitelist_20250418_013938.json b/gru_sac_predictor/models/run_20250418_013938/final_whitelist_20250418_013938.json new file mode 100644 index 00000000..d2581747 --- /dev/null +++ b/gru_sac_predictor/models/run_20250418_013938/final_whitelist_20250418_013938.json @@ -0,0 +1,13 @@ +[ + "ATR_14", + "EMA_10", + "MACD_signal", + "chaikin_AD_10", + "hour_cos", + "hour_sin", + "return_15m", + "return_1m", + "return_60m", + "svi_10", + "volatility_14d" +] \ No newline at end of file diff --git a/gru_sac_predictor/notebooks/example_pipeline_run.ipynb b/gru_sac_predictor/notebooks/example_pipeline_run.ipynb new file mode 100644 index 00000000..ae79be2d --- /dev/null +++ 
b/gru_sac_predictor/notebooks/example_pipeline_run.ipynb @@ -0,0 +1,337 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# GRU-SAC Trading Pipeline: Example Usage\n", + "\n", + "This notebook demonstrates how to instantiate and run the refactored `TradingPipeline` class.\n", + "\n", + "**Goal:** Run the complete pipeline (data loading, feature engineering, GRU training/loading, calibration, optional SAC training, backtesting) using a configuration file and inspect the results." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Imports and Setup" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Initial sys.path: ['/home/yasha/develop', '/usr/lib/python310.zip', '/usr/lib/python3.10', '/usr/lib/python3.10/lib-dynload', '', '/home/yasha/develop/gru_sac_predictor/.venv/lib/python3.10/site-packages']\n", + "Notebook directory (notebook_dir): /home/yasha/develop/gru_sac_predictor/notebooks\n", + "Calculated path for imports (project_root_for_imports): /home/yasha/develop/gru_sac_predictor\n", + "Checking if /home/yasha/develop/gru_sac_predictor is in sys.path...\n", + "Path not found. 
Adding /home/yasha/develop/gru_sac_predictor to sys.path.\n", + "sys.path after insert: ['/home/yasha/develop/gru_sac_predictor', '/home/yasha/develop', '/usr/lib/python310.zip', '/usr/lib/python3.10', '/usr/lib/python3.10/lib-dynload', '', '/home/yasha/develop/gru_sac_predictor/.venv/lib/python3.10/site-packages']\n", + "Package path (package_path): /home/yasha/develop/gru_sac_predictor/gru_sac_predictor\n", + "Src path (src_path): /home/yasha/develop/gru_sac_predictor/gru_sac_predictor/src\n", + "\n", + "Attempting to import TradingPipeline...\n", + "ERROR: Failed to import TradingPipeline: No module named 'gru_sac_predictor.src'\n", + "Final sys.path before error: ['/home/yasha/develop/gru_sac_predictor', '/home/yasha/develop', '/usr/lib/python310.zip', '/usr/lib/python3.10', '/usr/lib/python3.10/lib-dynload', '', '/home/yasha/develop/gru_sac_predictor/.venv/lib/python3.10/site-packages']\n", + "Please verify the calculated paths above and ensure the directory containing 'gru_sac_predictor' is correctly added to sys.path.\n" + ] + } + ], + "source": [ + "import os\n", + "import sys\n", + "import yaml\n", + "import pandas as pd\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib.image as mpimg\n", + "import logging\n", + "\n", + "print(f'Initial sys.path: {sys.path}')\n", + "\n", + "# --- Path Setup ---\n", + "# Initialize project_root to None\n", + "project_root = None\n", + "project_root_for_imports = None # Initialize separately for clarity\n", + "try:\n", + " notebook_dir = os.path.abspath('') # Get current directory (should be notebooks/)\n", + " print(f'Notebook directory (notebook_dir): {notebook_dir}')\n", + "\n", + " # *** CORRECTED LINE BELOW ***\n", + " # Go up ONE level to get the directory containing the gru_sac_predictor package\n", + " # Assuming notebook is in develop/gru_sac_predictor/notebooks/\n", + " # This should result in '/home/yasha/develop/gru_sac_predictor'\n", + " project_root_for_imports = 
os.path.dirname(notebook_dir)\n", + " print(f'Calculated path for imports (project_root_for_imports): {project_root_for_imports}')\n", + "\n", + " # Add the calculated path to sys.path to allow imports from gru_sac_predictor\n", + " print(f'Checking if {project_root_for_imports} is in sys.path...')\n", + " if project_root_for_imports not in sys.path:\n", + " print(f'Path not found. Adding {project_root_for_imports} to sys.path.')\n", + " sys.path.insert(0, project_root_for_imports)\n", + " print(f'sys.path after insert: {sys.path}')\n", + " else:\n", + " print(f'Path {project_root_for_imports} already in sys.path.')\n", + "\n", + " # Define project_root consistently, used later for finding config.yaml\n", + " project_root = project_root_for_imports\n", + " if project_root: # Check if project_root was set successfully\n", + " package_path = os.path.join(project_root, 'gru_sac_predictor')\n", + " src_path = os.path.join(package_path, 'src')\n", + " print(f'Package path (package_path): {package_path}')\n", + " print(f'Src path (src_path): {src_path}')\n", + " else:\n", + " print(\"Project root could not be determined.\")\n", + "\n", + "except Exception as e:\n", + " print(f'Error during path setup: {e}')\n", + "\n", + "# --- Import the main pipeline class ---\n", + "print(\"\\nAttempting to import TradingPipeline...\")\n", + "try:\n", + " # Now this import should work if the path setup is correct\n", + " from gru_sac_predictor.src.trading_pipeline import TradingPipeline\n", + " print('Successfully imported TradingPipeline.')\n", + "except ImportError as e:\n", + " print(f'ERROR: Failed to import TradingPipeline: {e}')\n", + " print(f'Final sys.path before error: {sys.path}')\n", + " print(\"Please verify the calculated paths above and ensure the directory containing 'gru_sac_predictor' is correctly added to sys.path.\")\n", + " # Handle error appropriately, maybe raise it\n", + "except Exception as e: # Catch other potential errors\n", + " print(f'An unexpected error 
occurred during import: {e}')\n", + " print(f'Final sys.path before error: {sys.path}')\n", + "\n", + "# Configure basic logging for the notebook\n", + "logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Configuration\n", + "\n", + "Specify the path to the configuration file (`config.yaml`). This file defines all parameters for the data, models, training, and backtesting." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using config file: /home/../gru_sac_predictor/config.yaml\n", + "ERROR: Config file not found at /home/../gru_sac_predictor/config.yaml\n" + ] + } + ], + "source": [ + "# Path to the configuration file \n", + "# Assumes config.yaml is in the gru_sac_predictor package directory, one level above src\n", + "config_rel_path = '../config.yaml'\n", + "# Construct absolute path relative to the project root identified earlier\n", + "if 'project_root' in locals():\n", + " config_abs_path = os.path.join(project_root, config_rel_path)\n", + "else:\n", + " print('ERROR: project_root not defined. 
Cannot find config file.')\n", + " config_abs_path = None\n", + "\n", + "if config_abs_path:\n", + " print(f'Using config file: {config_abs_path}')\n", + " # Verify the config file exists\n", + " if not os.path.exists(config_abs_path):\n", + " print(f'ERROR: Config file not found at {config_abs_path}')\n", + " else:\n", + " print('Config file found.')\n", + " # Optionally load and display config for verification\n", + " try:\n", + " with open(config_abs_path, 'r') as f:\n", + " config_data = yaml.safe_load(f)\n", + " # print('\\nConfiguration:')\n", + " # print(yaml.dump(config_data, default_flow_style=False)) # Pretty print\n", + " except Exception as e:\n", + " print(f'Error reading config file: {e}')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Instantiate and Run the Pipeline\n", + "\n", + "Create an instance of the `TradingPipeline` and run its `execute()` method. This will perform all the steps defined in the configuration.\n", + "\n", + "**Note:** Depending on the configuration (especially `train_gru` and `train_sac` flags) and the data size, this cell might take a significant amount of time to run." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pipeline_instance = None # Define outside try block\n", + "if 'TradingPipeline' in locals() and config_abs_path and os.path.exists(config_abs_path): \n", + " try:\n", + " # Instantiate the pipeline\n", + " pipeline_instance = TradingPipeline(config_path=config_abs_path)\n", + " \n", + " # Execute the full pipeline\n", + " print('\\n=== Starting Pipeline Execution ===')\n", + " pipeline_instance.execute()\n", + " print('=== Pipeline Execution Finished ===')\n", + " \n", + " except FileNotFoundError as e:\n", + " print(f'ERROR during pipeline instantiation (FileNotFound): {e}')\n", + " except Exception as e:\n", + " print(f'An error occurred during pipeline execution: {e}')\n", + " logging.error('Pipeline execution failed.', exc_info=True) # Log traceback\n", + "else:\n", + " print('TradingPipeline class not imported, config path invalid, or config file not found. Cannot run pipeline.')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Inspect Results\n", + "\n", + "After the pipeline execution, we can inspect the results stored within the `pipeline_instance` object and the files saved to the run directory." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if pipeline_instance is not None and pipeline_instance.backtest_metrics:\n", + " print('\\n--- Backtest Metrics --- ')\n", + " # Pretty print the metrics dictionary\n", + " metrics = pipeline_instance.backtest_metrics\n", + " # Update Run ID in metrics before printing\n", + " metrics['Run ID'] = pipeline_instance.run_id \n", + " \n", + " for key, value in metrics.items():\n", + " if key == \"Confusion Matrix (GRU Signal vs Actual Dir)\":\n", + " print(f'{key}:\\n{np.array(value)}') \n", + " elif key == \"Classification Report (GRU Signal)\":\n", + " print(f'{key}:\\n{value}')\n", + " elif isinstance(value, float):\n", + " print(f'{key}: {value:.4f}')\n", + " else:\n", + " print(f'{key}: {value}')\n", + "else:\n", + " print('\\nPipeline object not found or backtest did not produce metrics.')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if pipeline_instance is not None and pipeline_instance.backtest_results_df is not None:\n", + " print('\\n--- Backtest Results DataFrame (Head) --- ')\n", + " pd.set_option('display.max_columns', None) # Show all columns\n", + " pd.set_option('display.width', 1000) # Wider display\n", + " display(pipeline_instance.backtest_results_df.head())\n", + " print('\\n--- Backtest Results DataFrame (Tail) --- ')\n", + " display(pipeline_instance.backtest_results_df.tail())\n", + " \n", + " # Display basic stats\n", + " print('\\n--- Backtest Results DataFrame (Description) --- ')\n", + " display(pipeline_instance.backtest_results_df.describe())\n", + "else:\n", + " print('\\nPipeline object not found or backtest did not produce results DataFrame.')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. Display Saved Plots\n", + "\n", + "Load and display the plots generated during the backtest. 
These are saved in the `results/` directory." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if pipeline_instance is not None and pipeline_instance.dirs.get('results'):\n", + " results_dir = pipeline_instance.dirs['results']\n", + " run_id = pipeline_instance.run_id\n", + " print(f'Looking for plots in: {results_dir}\\n')\n", + " \n", + " plot_files = [\n", + " f'backtest_summary_{run_id}.png', \n", + " f'confusion_matrix_{run_id}.png', \n", + " f'reliability_curve_val_{run_id}.png' # Optional validation plot\n", + " ]\n", + " \n", + " for plot_file in plot_files:\n", + " plot_path = os.path.join(results_dir, plot_file)\n", + " if os.path.exists(plot_path):\n", + " print(f'--- Displaying: {plot_file} ---')\n", + " try:\n", + " img = mpimg.imread(plot_path)\n", + " # Determine appropriate figure size based on plot type\n", + " figsize = (15, 12) if 'summary' in plot_file else (7, 6)\n", + " plt.figure(figsize=figsize)\n", + " plt.imshow(img)\n", + " plt.axis('off') # Hide axes for image display\n", + " plt.title(plot_file)\n", + " plt.show()\n", + " except Exception as e:\n", + " print(f' Error loading/displaying plot {plot_file}: {e}')\n", + " else:\n", + " print(f'Plot not found: {plot_path}')\n", + " \n", + "else:\n", + " print('\\nPipeline object not found or results directory is not available. Cannot display plots.')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6. Conclusion\n", + "\n", + "This notebook demonstrated the basic workflow of using the `TradingPipeline`. You can modify the `config.yaml` file to experiment with different parameters, data ranges, and control flags (e.g., enabling/disabling GRU or SAC training). The results (metrics, plots, detailed CSV) are saved in the run-specific directory under `results/`." 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/gru_sac_predictor/src/__pycache__/__init__.cpython-312.pyc b/gru_sac_predictor/src/__pycache__/__init__.cpython-312.pyc deleted file mode 100644 index 0d1fbc17..00000000 Binary files a/gru_sac_predictor/src/__pycache__/__init__.cpython-312.pyc and /dev/null differ diff --git a/gru_sac_predictor/src/__pycache__/crypto_db_fetcher.cpython-312.pyc b/gru_sac_predictor/src/__pycache__/crypto_db_fetcher.cpython-312.pyc deleted file mode 100644 index ceb82e15..00000000 Binary files a/gru_sac_predictor/src/__pycache__/crypto_db_fetcher.cpython-312.pyc and /dev/null differ diff --git a/gru_sac_predictor/src/__pycache__/data_pipeline.cpython-312.pyc b/gru_sac_predictor/src/__pycache__/data_pipeline.cpython-312.pyc deleted file mode 100644 index e2827e01..00000000 Binary files a/gru_sac_predictor/src/__pycache__/data_pipeline.cpython-312.pyc and /dev/null differ diff --git a/gru_sac_predictor/src/__pycache__/gru_predictor.cpython-312.pyc b/gru_sac_predictor/src/__pycache__/gru_predictor.cpython-312.pyc deleted file mode 100644 index ec868d2f..00000000 Binary files a/gru_sac_predictor/src/__pycache__/gru_predictor.cpython-312.pyc and /dev/null differ diff --git a/gru_sac_predictor/src/__pycache__/sac_agent.cpython-312.pyc b/gru_sac_predictor/src/__pycache__/sac_agent.cpython-312.pyc deleted file mode 100644 index eed5f91d..00000000 Binary files a/gru_sac_predictor/src/__pycache__/sac_agent.cpython-312.pyc and /dev/null differ diff --git a/gru_sac_predictor/src/__pycache__/sac_agent_simplified.cpython-312.pyc 
b/gru_sac_predictor/src/__pycache__/sac_agent_simplified.cpython-312.pyc deleted file mode 100644 index 25a163f6..00000000 Binary files a/gru_sac_predictor/src/__pycache__/sac_agent_simplified.cpython-312.pyc and /dev/null differ diff --git a/gru_sac_predictor/src/__pycache__/trading_system.cpython-312.pyc b/gru_sac_predictor/src/__pycache__/trading_system.cpython-312.pyc deleted file mode 100644 index a2dae5f8..00000000 Binary files a/gru_sac_predictor/src/__pycache__/trading_system.cpython-312.pyc and /dev/null differ diff --git a/gru_sac_predictor/src/backtester.py b/gru_sac_predictor/src/backtester.py new file mode 100644 index 00000000..66fb58c7 --- /dev/null +++ b/gru_sac_predictor/src/backtester.py @@ -0,0 +1,440 @@ +""" +Backtesting Engine. + +Simulates trading strategy execution on historical test data, calculates +performance metrics, and generates reports and plots. +""" + +import os +import logging +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt +import matplotlib.dates as mdates +from sklearn.metrics import confusion_matrix, classification_report +import seaborn as sns +from typing import Dict, Any, Tuple, Optional + +# Import required components (use absolute paths) +from gru_sac_predictor.src.sac_agent import SACTradingAgent +from gru_sac_predictor.src.gru_model_handler import GRUModelHandler +from gru_sac_predictor.src.calibrator import Calibrator + +logger = logging.getLogger(__name__) + + +def calculate_sharpe_ratio(returns, periods_per_year=252*24*60): # Default for 1-min data + """Calculate annualized Sharpe ratio from a series of returns.""" + returns = pd.Series(returns) + if returns.std() == 0: + return 0.0 + # Assuming risk-free rate is 0 for simplicity + return np.sqrt(periods_per_year) * returns.mean() / returns.std() + +def calculate_max_drawdown(equity_curve): + """Calculate the maximum drawdown from an equity curve series.""" + equity_curve = pd.Series(equity_curve) + rolling_max = equity_curve.cummax() + 
drawdown = (equity_curve - rolling_max) / rolling_max + max_drawdown = drawdown.min() + return abs(max_drawdown) # Return positive value + +class Backtester: + """Runs the backtest simulation and generates results.""" + + def __init__(self, config: dict): + """ + Initialize the Backtester. + Args: + config (dict): Pipeline configuration dictionary, expected to contain + 'backtest' and potentially 'calibration', 'sac' sections. + """ + self.config = config + self.backtest_cfg = config.get('backtest', {}) + self.cal_cfg = config.get('calibration', {}) + self.sac_cfg = config.get('sac', {}) + + self.initial_capital = self.backtest_cfg.get('initial_capital', 10000.0) + self.transaction_cost = self.backtest_cfg.get('transaction_cost', 0.0005) + self.edge_threshold = self.cal_cfg.get('edge_threshold', 0.55) + + self.results_df: Optional[pd.DataFrame] = None + self.metrics: Optional[Dict[str, Any]] = None + + logger.info("Backtester initialized.") + logger.info(f" Initial Capital: {self.initial_capital:.2f}") + logger.info(f" Transaction Cost: {self.transaction_cost*100:.4f}%") + logger.info(f" Edge Threshold: {self.edge_threshold:.3f}") + + def run_backtest( + self, + sac_agent_load_path: Optional[str], + X_test_seq: np.ndarray, + y_test_seq_dict: Dict[str, np.ndarray], + test_indices: pd.Index, + gru_handler: GRUModelHandler, + calibrator: Calibrator, + original_prices: Optional[pd.Series] = None # Added for plotting + ) -> Tuple[Optional[pd.DataFrame], Optional[Dict[str, Any]]]: + """ + Executes the backtesting simulation loop. + + Args: + sac_agent_load_path (Optional[str]): Path to load the SAC agent from. If None, uses untrained agent. + X_test_seq (np.ndarray): Test set feature sequences. + y_test_seq_dict (Dict[str, np.ndarray]): Dict of test set targets (e.g., 'ret', 'dir'). + test_indices (pd.Index): Timestamps corresponding to the test sequences' targets. + gru_handler (GRUModelHandler): Instance to get GRU predictions. 
+ calibrator (Calibrator): Instance to calibrate probabilities. + original_prices (Optional[pd.Series]): Original close prices aligned with test_indices (for plotting). + + Returns: + Tuple[Optional[pd.DataFrame], Optional[Dict[str, Any]]]: + - DataFrame with detailed backtest results per step. + - Dictionary containing calculated performance metrics. + Returns (None, None) if backtest cannot run. + """ + logger.info("--- Starting Backtest Simulation ---") + + if X_test_seq is None or y_test_seq_dict is None or test_indices is None: + logger.error("Test sequence data (X, y, indices) is missing. Cannot run backtest.") + return None, None + if gru_handler.model is None: + logger.error("GRU model is not loaded in the handler. Cannot run backtest.") + return None + + # 1. Initialize SAC Agent & Load Weights + agent_state_dim = self.sac_cfg.get('state_dim', 5) + agent = SACTradingAgent( + state_dim=agent_state_dim, # Ensure this matches the state construction below + action_dim=1, # Typically -1 to 1 for position + gamma=self.sac_cfg.get('gamma', 0.99), # These params might not matter much if just loading weights + tau=self.sac_cfg.get('tau', 0.005), + initial_lr=self.sac_cfg.get('actor_lr', 3e-4), + alpha=self.sac_cfg.get('alpha', 0.2), + alpha_auto_tune=self.sac_cfg.get('alpha_auto_tune', True), + # Other SAC params... + edge_threshold_config=self.edge_threshold # Store config threshold + ) + if sac_agent_load_path and os.path.exists(sac_agent_load_path): + logger.info(f"Loading SAC agent weights from: {sac_agent_load_path}") + try: + agent.load(sac_agent_load_path) + except Exception as e: + logger.error(f"Failed to load SAC agent weights from {sac_agent_load_path}: {e}. Proceeding with untrained agent.", exc_info=True) + else: + logger.warning(f"SAC agent load path not found or not specified ({sac_agent_load_path}). Proceeding with untrained agent.") + + # 2. 
Get GRU Predictions on Test Set + logger.info(f"Generating GRU predictions for {len(X_test_seq)} test sequences...") + predictions_test = gru_handler.predict(X_test_seq) + if predictions_test is None or len(predictions_test) < 3: + logger.error("Failed to get GRU predictions on test set. Cannot run backtest.") + return None, None + + mu_test = predictions_test[0].flatten() + log_sigma_test = predictions_test[1][:, 1].flatten() + p_raw_test = predictions_test[2].flatten() + sigma_test = np.exp(log_sigma_test) + + # Extract actual returns and directions + actual_ret_test = y_test_seq_dict.get('ret') + actual_dir_test = y_test_seq_dict.get('dir') + if actual_ret_test is None or actual_dir_test is None: + logger.error("Actual return ('ret') or direction ('dir') missing from y_test_seq_dict. Cannot run backtest.") + return None + + # Verify prediction lengths + n_test = len(X_test_seq) + if not (len(mu_test) == n_test and len(sigma_test) == n_test and \ + len(p_raw_test) == n_test and len(actual_ret_test) == n_test and \ + len(actual_dir_test) == n_test and len(test_indices) == n_test): + logger.error(f"Length mismatch in test predictions/targets/indices: Expected {n_test}, got mu={len(mu_test)}, sigma={len(sigma_test)}, p_raw={len(p_raw_test)}, ret={len(actual_ret_test)}, dir={len(actual_dir_test)}, indices={len(test_indices)}") + return None, None + + # 3. Calibrate Test Probabilities + logger.info(f"Calibrating test predictions using T={calibrator.optimal_T:.4f}") + p_cal_test = calibrator.calibrate(p_raw_test) + edge_test = 2 * p_cal_test - 1 # Edge = P(up) - P(down) + z_score_test = np.abs(mu_test) / (sigma_test + 1e-9) # abs(mu)/sigma + # Generate GRU-based signals for confusion matrix + gru_signal_test = calibrator.action_signal(p_cal_test) + + # 4. 
Simulation Loop + capital = self.initial_capital + current_position = 0.0 # Starts neutral (-1 to 1) + equity_curve = [capital] + positions = [current_position] + actions_taken = [0.0] # SAC agent's desired fractional position + pnl_steps = [] + trades_executed = [] # Store details of trades + + logger.info(f"Starting backtest simulation loop: {n_test} steps...") + for i in range(n_test): + # Construct current state for SAC agent + # state = [mu, sigma, edge, |mu|/sigma, position] + state = np.array([ + mu_test[i], + sigma_test[i], + edge_test[i], + z_score_test[i], + current_position # Position from start of the period + ], dtype=np.float32) + + # Get action from SAC agent (deterministic for backtest) + # Action output is assumed to be the target position [-1, 1] + sac_action = agent.get_action(state, deterministic=True)[0] # Get scalar action + target_position = np.clip(sac_action, -1.0, 1.0) # Ensure action is within bounds + + # Calculate PnL for the step based on position held *during* the step + # Return is for the period ending at the *current* index i + step_actual_return = actual_ret_test[i] + gross_pnl = current_position * capital * (np.exp(step_actual_return) - 1) + + # Calculate trade size and cost + # Trade happens *after* observing return for period i, to establish position for period i+1 + trade = target_position - current_position + cost = abs(trade) * capital * self.transaction_cost + + # Calculate net PnL for the step + net_pnl = gross_pnl - cost + + # Update capital + capital += net_pnl + + # Store results for this step + equity_curve.append(capital) + positions.append(target_position) # Store the position held for the *next* step + actions_taken.append(sac_action) + pnl_steps.append(net_pnl) + if abs(trade) > 1e-6: # Record non-zero trades + trades_executed.append({ + 'timestamp': test_indices[i], + 'trade_size': trade, + 'cost': cost, + 'position_before': current_position, + 'position_after': target_position + }) + + # Update position for the 
next iteration + current_position = target_position + + if capital <= 0: + logger.warning(f"Capital depleted at step {i+1}. Stopping backtest.") + n_test = i + 1 # Adjust length to current step + break + + logger.info("Backtest simulation loop finished.") + logger.info(f"Final Equity: {capital:.2f}") + + # 5. Prepare Results DataFrame + if n_test == 0: + logger.warning("Backtest executed 0 steps.") + return pd.DataFrame(), {} + + results_data = { + 'equity': equity_curve[1:], # Exclude initial capital + 'position': positions[1:], # Position held for the period starting at this index + 'action': actions_taken[1:], # Agent's target position output + 'pnl': pnl_steps, + 'actual_return': actual_ret_test[:n_test], + 'mu_pred': mu_test[:n_test], + 'sigma_pred': sigma_test[:n_test], + 'p_cal_pred': p_cal_test[:n_test], + 'edge_pred': edge_test[:n_test], + 'gru_signal': gru_signal_test[:n_test], + 'actual_dir': actual_dir_test[:n_test] + } + # Add original prices if available + if original_prices is not None: + # Ensure alignment with test_indices before adding + aligned_prices = original_prices.loc[test_indices[:n_test]] + results_data['close_price'] = aligned_prices.values + + self.results_df = pd.DataFrame(results_data, index=test_indices[:n_test]) + self.results_df['returns'] = self.results_df['equity'].pct_change().fillna(0.0) + self.results_df['cumulative_return'] = (1 + self.results_df['returns']).cumprod() - 1 + + # Calculate Buy & Hold Benchmark + if 'close_price' in self.results_df.columns and not self.results_df.empty: + bh_returns = self.results_df['close_price'].pct_change().fillna(0.0) + self.results_df['bh_cumulative_return'] = (1 + bh_returns).cumprod() - 1 + bh_sharpe = calculate_sharpe_ratio(bh_returns) + else: + bh_sharpe = 0.0 + self.results_df['bh_cumulative_return'] = 0.0 + logger.warning("Could not calculate Buy & Hold benchmark due to missing price data.") + + # 6. 
Calculate Performance Metrics + logger.info("Calculating performance metrics...") + total_pnl = self.results_df['pnl'].sum() + final_equity = self.results_df['equity'].iloc[-1] + total_return_pct = (final_equity / self.initial_capital - 1) * 100 + sharpe = calculate_sharpe_ratio(self.results_df['returns']) + max_dd = calculate_max_drawdown(self.results_df['equity']) + + wins = self.results_df[self.results_df['pnl'] > 0]['pnl'] + losses = self.results_df[self.results_df['pnl'] < 0]['pnl'] + profit_factor = wins.sum() / abs(losses.sum()) if losses.sum() != 0 else np.inf + + num_trades = len(trades_executed) + # Simple win rate based on pnl > 0 per step holding position + win_rate_steps = (self.results_df['pnl'] > 0).mean() + + # Confusion Matrix for GRU signals + conf_matrix = confusion_matrix(actual_dir_test[:n_test], gru_signal_test[:n_test], labels=[-1, 0, 1]) + class_report = classification_report(actual_dir_test[:n_test], gru_signal_test[:n_test], labels=[-1, 0, 1], target_names=['Short', 'Neutral', 'Long'], zero_division=0) + + self.metrics = { + "Run ID": self.config.get('run_id_template', 'N/A').format(timestamp="..."), # Use actual run ID later + "Test Period Start": test_indices[0].strftime('%Y-%m-%d %H:%M'), + "Test Period End": test_indices[n_test-1].strftime('%Y-%m-%d %H:%M'), + "Initial Capital": self.initial_capital, + "Final Equity": final_equity, + "Total Net PnL": total_pnl, + "Total Return (%)": total_return_pct, + "Annualized Sharpe Ratio": sharpe, + "Max Drawdown (%)": max_dd * 100, + "Profit Factor": profit_factor, + "Number of Trades": num_trades, + "Step Win Rate (%)": win_rate_steps * 100, + "Transaction Cost (% per trade)": self.transaction_cost * 100, + "Edge Threshold": self.edge_threshold, + "Calibration Temperature (Optimal T)": calibrator.optimal_T, + "Buy & Hold Sharpe Ratio": bh_sharpe, + "Confusion Matrix (GRU Signal vs Actual Dir)": conf_matrix.tolist(), # Convert to list for saving + "Classification Report (GRU Signal)": 
    def save_results(
        self,
        results_df: pd.DataFrame,
        metrics: Dict[str, Any],
        results_dir: str,
        run_id: str
    ):
        """
        Saves the backtest results, metrics report, and plots.

        Produces up to four artifacts inside ``results_dir``:
          - ``performance_metrics_{run_id}.txt``: human-readable metrics report
          - ``backtest_results_{run_id}.csv``: per-step results DataFrame
          - ``backtest_summary_{run_id}.png``: 3-panel summary plot
          - ``confusion_matrix_{run_id}.png``: GRU-signal confusion heatmap
        Plot generation is skipped when config ``control.generate_plots`` is
        false. Each stage catches and logs its own exceptions, so one failed
        artifact does not prevent the others from being written.

        Args:
            results_df (pd.DataFrame): DataFrame from run_backtest.
            metrics (Dict[str, Any]): Metrics dictionary from run_backtest.
            results_dir (str): Directory to save the results.
            run_id (str): The pipeline run ID for filenames.
        """
        logger.info("--- Saving Backtest Results --- ")
        if results_df is None or metrics is None:
            logger.warning("No results DataFrame or metrics to save.")
            return

        os.makedirs(results_dir, exist_ok=True)

        # 1. Save Metrics Report (plain text, one "key: value" line per metric)
        metrics_path = os.path.join(results_dir, f'performance_metrics_{run_id}.txt')
        try:
            with open(metrics_path, 'w') as f:
                f.write(f"--- Performance Metrics (Run ID: {run_id}) ---\n\n")
                for key, value in metrics.items():
                    if key == "Confusion Matrix (GRU Signal vs Actual Dir)":
                        f.write(f"{key}:\n{np.array(value)}\n\n")  # Nicer print for matrix
                    elif key == "Classification Report (GRU Signal)":
                        f.write(f"{key}:\n{value}\n\n")
                    elif isinstance(value, float):
                        f.write(f"{key}: {value:.4f}\n")
                    else:
                        f.write(f"{key}: {value}\n")
            logger.info(f"Performance metrics saved to {metrics_path}")
        except Exception as e:
            logger.error(f"Failed to save metrics report: {e}", exc_info=True)

        # 2. Save Results DataFrame (full per-step detail, index = timestamps)
        results_csv_path = os.path.join(results_dir, f'backtest_results_{run_id}.csv')
        try:
            results_df.to_csv(results_csv_path)
            logger.info(f"Detailed backtest results saved to {results_csv_path}")
        except Exception as e:
            logger.error(f"Failed to save results DataFrame: {e}", exc_info=True)

        # 3. Generate and Save Plots
        if self.config.get('control', {}).get('generate_plots', True):
            logger.info("Generating backtest plots...")
            try:
                # Plot 1: Multi-subplot (Price/Pred, Action, Equity/BH)
                fig, axes = plt.subplots(3, 1, figsize=(14, 12), sharex=True)

                # Subplot 1: Price vs Prediction
                if 'close_price' in results_df.columns:
                    ax = axes[0]
                    ax.plot(results_df.index, results_df['close_price'], label='Actual Price', color='black', alpha=0.8)
                    # Reconstruct predicted price from mu (log return prediction)
                    # pred_price = results_df['close_price'].shift(1) * np.exp(results_df['mu_pred'])
                    # ax.plot(results_df.index, pred_price, label='Predicted Price (from mu)', color='blue', alpha=0.6)
                    # Plot mu +/- sigma directly (interpreting mu as predicted return)
                    # NOTE: price and log-return share one axis here, so the
                    # mu/sigma band will hug zero relative to the price scale.
                    ax.plot(results_df.index, results_df['mu_pred'], label='Predicted Log Return (mu)', color='blue', alpha=0.7)
                    ax.fill_between(results_df.index,
                                    results_df['mu_pred'] - results_df['sigma_pred'],
                                    results_df['mu_pred'] + results_df['sigma_pred'],
                                    color='blue', alpha=0.2, label='Predicted Sigma')
                    ax.set_ylabel("Price / Log Return")
                    ax.set_title(f'Price, Predictions & Uncertainty (Run: {run_id})')
                    ax.legend()
                    ax.grid(True)
                else:
                    axes[0].text(0.5, 0.5, 'Price data not available for plotting', ha='center', va='center')
                    axes[0].set_title('Price vs Prediction (Data Missing)')

                # Subplot 2: SAC Agent Action (raw target position in [-1, 1])
                ax = axes[1]
                ax.plot(results_df.index, results_df['action'], label='SAC Agent Target Position', color='red')
                ax.set_ylabel("Target Position (-1 to 1)")
                ax.set_ylim(-1.1, 1.1)
                ax.set_title('SAC Agent Action')
                ax.legend()
                ax.grid(True)

                # Subplot 3: Equity Curve vs Buy&Hold (both normalized to 1.0 at start)
                ax = axes[2]
                equity_norm = results_df['equity'] / self.initial_capital
                ax.plot(results_df.index, equity_norm, label='SAC Strategy Equity', color='green')
                if 'bh_cumulative_return' in results_df.columns:
                    bh_equity_norm = 1 + results_df['bh_cumulative_return']
                    ax.plot(results_df.index, bh_equity_norm, label='Buy & Hold Equity', color='gray', linestyle='--')
                ax.set_ylabel("Normalized Equity")
                ax.set_xlabel("Time")
                ax.set_title('Portfolio Equity vs Buy & Hold')
                ax.legend()
                ax.grid(True)
                ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d %H:%M'))
                plt.xticks(rotation=45)

                plt.tight_layout()
                plot1_path = os.path.join(results_dir, f'backtest_summary_{run_id}.png')
                plt.savefig(plot1_path)
                plt.close(fig)
                logger.info(f"Summary plot saved to {plot1_path}")

                # Plot 2: Confusion Matrix Heatmap (GRU signal vs realized direction)
                if "Confusion Matrix (GRU Signal vs Actual Dir)" in metrics:
                    cm = np.array(metrics["Confusion Matrix (GRU Signal vs Actual Dir)"])
                    fig_cm, ax_cm = plt.subplots(figsize=(6, 5))
                    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                                xticklabels=['Pred Short', 'Pred Neutral', 'Pred Long'],
                                yticklabels=['Actual Short', 'Actual Neutral', 'Actual Long'],
                                ax=ax_cm)
                    ax_cm.set_xlabel("Predicted Signal")
                    ax_cm.set_ylabel("Actual Direction")
                    ax_cm.set_title(f"GRU Signal Confusion Matrix (Run: {run_id})")
                    plt.tight_layout()
                    cm_plot_path = os.path.join(results_dir, f'confusion_matrix_{run_id}.png')
                    plt.savefig(cm_plot_path)
                    plt.close(fig_cm)
                    logger.info(f"Confusion matrix plot saved to {cm_plot_path}")

            except Exception as e:
                logger.error(f"Failed to generate plots: {e}", exc_info=True)
        else:
            logger.info("Skipping plot generation as per config.")

        logger.info("--- Finished Saving Backtest Results ---")
# ------------------------------------------------------------------
# Hyper-parameters
# ------------------------------------------------------------------
# Minimum calibrated edge magnitude before we take a trade.
# This must be a real binding: it is listed in ``__all__``, and leaving it
# commented out makes ``from calibrate import *`` raise AttributeError.
EDGE_THR: float = 0.55


# ------------------------------------------------------------------
# Temperature scaling
# ------------------------------------------------------------------

def _nll_temperature(T: float, logit_p: np.ndarray, y_true: np.ndarray) -> float:
    """Negative log-likelihood of binary labels given logits scaled by 1/T.

    Args:
        T: Temperature divisor.
        logit_p: Logits of the raw probabilities.
        y_true: Binary labels (0/1), same shape as ``logit_p``.

    Returns:
        Mean binary cross-entropy under the scaled probabilities.
    """
    p_cal = expit(logit_p / T)
    # Clip before taking logs so near-0/1 probabilities cannot produce -inf.
    eps = 1e-12
    p_cal = np.clip(p_cal, eps, 1 - eps)
    nll = -(y_true * np.log(p_cal) + (1 - y_true) * np.log(1 - p_cal))
    return float(np.mean(nll))


def optimise_temperature(p_raw: np.ndarray, y_true: np.ndarray) -> float:
    """Return optimal temperature `T` that minimises NLL on `y_true`.

    Args:
        p_raw: Raw (uncalibrated) model probabilities.
        y_true: Binary ground-truth labels.

    Returns:
        The NLL-minimising temperature within the bounded interval
        [0.05, 10.0].
    """
    p_raw = p_raw.flatten().astype(float)
    y_true = y_true.flatten().astype(int)
    # Clip away exact 0/1 so logit() stays finite.
    logit_p = logit(np.clip(p_raw, 1e-6, 1 - 1e-6))

    res = minimize_scalar(
        lambda T: _nll_temperature(T, logit_p, y_true),
        bounds=(0.05, 10.0),
        method="bounded",
    )
    return float(res.x)


# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------

def calibrate(p_raw: np.ndarray, T: float) -> np.ndarray:
    """Return temperature-scaled probabilities ``expit(logit(p) / T)``."""
    return expit(logit(np.clip(p_raw, 1e-6, 1 - 1e-6)) / T)


def reliability_curve(
    p_raw: np.ndarray,
    y_true: np.ndarray,
    n_bins: int = 10,
    show: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
    """Return (bin_centres, empirical_prob) and optionally plot reliability.

    Empty bins are reported as NaN.
    """
    p_raw = p_raw.flatten()
    y_true = y_true.flatten()
    bins = np.linspace(0, 1, n_bins + 1)
    # Bin against the inner edges with right=True so every probability in
    # [0, 1] maps to an index in [0, n_bins - 1]. The previous
    # ``np.digitize(p_raw, bins) - 1`` mapped p == 1.0 to index ``n_bins``,
    # silently dropping those samples from the curve (the class-based
    # Calibrator.reliability_curve already uses this corrected form).
    bin_ids = np.digitize(p_raw, bins[1:], right=True)

    bin_centres = 0.5 * (bins[:-1] + bins[1:])
    acc = np.zeros(n_bins)
    for i in range(n_bins):
        idx = bin_ids == i
        if np.any(idx):
            acc[i] = y_true[idx].mean()
        else:
            acc[i] = np.nan

    if show:
        plt.figure(figsize=(4, 4))
        plt.plot([0, 1], [0, 1], "k--", label="perfect")
        plt.plot(bin_centres, acc, "o-", label="empirical")
        plt.xlabel("Predicted P(up)")
        plt.ylabel("Actual frequency")
        plt.title("Reliability curve")
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
    return bin_centres, acc


# ------------------------------------------------------------------
# Signal filter
# ------------------------------------------------------------------

def action_signal(p_cal: np.ndarray, edge_threshold: float = EDGE_THR) -> np.ndarray:
    """Return trading signal: 1, -1 or 0 based on calibrated edge threshold."""
    up = p_cal > edge_threshold
    dn = p_cal < 1 - edge_threshold
    return np.where(up, 1, np.where(dn, -1, 0))
    def _nll_objective(self, T: float, logit_p: np.ndarray, y_true: np.ndarray) -> float:
        """Negative log-likelihood objective function for temperature optimization.

        Args:
            T (float): Candidate temperature (must be positive).
            logit_p (np.ndarray): Logits of the raw probabilities.
            y_true (np.ndarray): Binary labels (0/1).

        Returns:
            float: Mean binary cross-entropy of the T-scaled probabilities
            (np.inf for non-positive T, so the optimizer steers away).
        """
        if T <= 0:
            return np.inf  # Temperature must be positive
        p_cal = expit(logit_p / T)
        # Binary cross-entropy (NLL)
        eps = 1e-12  # Epsilon for numerical stability (avoids log(0))
        p_cal = np.clip(p_cal, eps, 1 - eps)
        # Ensure y_true is broadcastable if necessary (should be 1D)
        nll = -(y_true * np.log(p_cal) + (1 - y_true) * np.log(1 - p_cal))
        return float(np.mean(nll))

    def optimise_temperature(self, p_raw: np.ndarray, y_true: np.ndarray, bounds=(0.1, 10.0)) -> float:
        """
        Finds the optimal temperature `T` by minimizing NLL on validation data.

        Side effect: stores the result in ``self.optimal_T``. On any failure
        (degenerate logits, optimizer non-convergence, exception) it falls
        back to T=1.0, i.e. "no scaling".

        Args:
            p_raw (np.ndarray): Raw model probabilities (validation set).
            y_true (np.ndarray): True binary labels (validation set).
            bounds (tuple): Bounds for the temperature search.

        Returns:
            float: The optimal temperature found.
        """
        logger.info(f"Optimizing calibration temperature using {len(p_raw)} samples...")
        # Ensure inputs are flat numpy arrays and correct type
        p_raw = np.asarray(p_raw).flatten().astype(float)
        y_true = np.asarray(y_true).flatten().astype(int)

        # Clip raw probabilities and compute logits for numerical stability
        eps = 1e-7
        p_clipped = np.clip(p_raw, eps, 1 - eps)
        logit_p = logit(p_clipped)

        # Handle cases where all predictions are the same (logit might be inf)
        if np.isinf(logit_p).any():
            logger.warning("Infinite values encountered in logits during temperature scaling. Clipping may be too aggressive or predictions are uniform. Returning T=1.0")
            self.optimal_T = 1.0
            return 1.0

        try:
            # Bounded scalar minimization of the NLL over T.
            res = minimize_scalar(
                lambda T: self._nll_objective(T, logit_p, y_true),
                bounds=bounds,
                method="bounded",
            )

            if res.success:
                optimal_T_found = float(res.x)
                logger.info(f"Optimal temperature found: T = {optimal_T_found:.4f}")
                self.optimal_T = optimal_T_found
                return optimal_T_found
            else:
                logger.warning(f"Temperature optimization failed: {res.message}. Returning T=1.0")
                self.optimal_T = 1.0
                return 1.0
        except Exception as e:
            logger.error(f"Error during temperature optimization: {e}", exc_info=True)
            logger.warning("Returning T=1.0 due to optimization error.")
            self.optimal_T = 1.0
            return 1.0

    def calibrate(self, p_raw: np.ndarray, T: Optional[float] = None) -> np.ndarray:
        """
        Applies temperature scaling to raw probabilities.

        Args:
            p_raw (np.ndarray): Raw model probabilities.
            T (Optional[float]): Temperature value. If None, uses the stored optimal_T.
                                 Defaults to 1.0 if optimal_T is also None.

        Returns:
            np.ndarray: Calibrated probabilities ``expit(logit(p) / T)``.
        """
        temp = T if T is not None else self.optimal_T
        if temp is None:
            logger.warning("Temperature T not provided and not optimized yet. Using T=1.0 for calibration.")
            temp = 1.0
        if temp <= 0:
            logger.error(f"Invalid temperature T={temp}. Using T=1.0 instead.")
            temp = 1.0

        # Clip raw probabilities and compute logits
        eps = 1e-7
        p_clipped = np.clip(np.asarray(p_raw).astype(float), eps, 1 - eps)
        logit_p = logit(p_clipped)

        # Apply temperature scaling
        p_cal = expit(logit_p / temp)
        return p_cal

    def reliability_curve(
        self,
        p_pred: np.ndarray,  # Expects raw OR calibrated probabilities
        y_true: np.ndarray,
        n_bins: int = 10,
        plot_title: str = "Reliability Curve",
        save_path: Optional[str] = None
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Computes and optionally plots the reliability curve.

        Args:
            p_pred (np.ndarray): Predicted probabilities (raw or calibrated).
            y_true (np.ndarray): True binary labels.
            n_bins (int): Number of bins for the curve.
            plot_title (str): Title for the plot.
            save_path (Optional[str]): If provided, saves the plot to this path.

        Returns:
            Tuple[np.ndarray, np.ndarray]: (bin_centers, empirical_prob).
            Empty bins yield NaN in ``empirical_prob``; plot failures are
            logged and do not affect the returned arrays.
        """
        p_pred = np.asarray(p_pred).flatten()
        y_true = np.asarray(y_true).flatten()

        bins = np.linspace(0, 1, n_bins + 1)
        # Handle potential edge cases with digitize for values exactly 1.0
        bin_ids = np.digitize(p_pred, bins[1:], right=True)  # Bin index 0 to n_bins-1

        bin_centres = 0.5 * (bins[:-1] + bins[1:])
        empirical_prob = np.zeros(n_bins)
        bin_counts = np.zeros(n_bins)

        for i in range(n_bins):
            idx = bin_ids == i
            bin_counts[i] = np.sum(idx)
            if bin_counts[i] > 0:
                empirical_prob[i] = y_true[idx].mean()
            else:
                empirical_prob[i] = np.nan  # Use NaN for empty bins

        if save_path:
            try:
                plt.figure(figsize=(6, 6))
                plt.plot([0, 1], [0, 1], "k--", label="Perfect Calibration")
                # Only plot bins with counts
                valid_bins = bin_counts > 0
                plt.plot(bin_centres[valid_bins], empirical_prob[valid_bins], "o-", label="Model")
                plt.xlabel("Mean Predicted Probability (per bin)")
                plt.ylabel("Fraction of Positives (per bin)")
                plt.title(plot_title)
                plt.legend()
                plt.grid(True)
                plt.tight_layout()
                plt.savefig(save_path)
                plt.close()
                logger.info(f"Reliability curve saved to {save_path}")
            except Exception as e:
                logger.error(f"Failed to generate or save reliability plot: {e}", exc_info=True)

        return bin_centres, empirical_prob

    def action_signal(self, p_cal: np.ndarray) -> np.ndarray:
        """
        Generates trading signal (1, -1, or 0) based on calibrated probability
        and the instance's edge threshold.

        Args:
            p_cal (np.ndarray): Calibrated probabilities P(up).

        Returns:
            np.ndarray: Action signals (1 for long, -1 for short, 0 for neutral).
        """
        p_cal = np.asarray(p_cal)
        # Signal long if P(up) > threshold
        go_long = p_cal > self.edge_threshold
        # Signal short if P(down) > threshold, which is P(up) < 1 - threshold
        go_short = p_cal < (1.0 - self.edge_threshold)

        # Assign signals: 1 for long, -1 for short, 0 otherwise
        signal = np.where(go_long, 1, np.where(go_short, -1, 0))
        return signal
- """ - - # V7 Update: Adjusted defaults slightly, cache dir relative to project - def __init__(self, db_dir: str = "downloaded_data", cache_dir: str = "data/cache", use_cache: bool = False): - """ - Initialize the crypto database fetcher. - - Args: - db_dir: Directory where SQLite database files are stored - cache_dir: Directory to store cached data (relative to project root) - use_cache: Whether to use cached data when available (default False for V7) - """ - # V7 Update: Make db_dir potentially relative to workspace - if not os.path.isabs(db_dir): - # This assumes the script is run from the project root (e.g., v7/) - # Adjust if running from elsewhere - base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) # Go up two levels from src - self.db_dir = os.path.join(base_path, db_dir) - else: - self.db_dir = db_dir - - self.cache_dir = cache_dir - self.use_cache = use_cache - - if self.use_cache: - os.makedirs(cache_dir, exist_ok=True) - - # Map of exchanges and pairs - discovered lazily now - self._available_exchanges = None - self._available_pairs = None - self._db_files = None # Cache discovered DB files - - logger.info(f"Initialized CryptoDBFetcher with db_dir={self.db_dir}") - - @property - def available_exchanges(self) -> List[str]: - if self._available_exchanges is None: - self._available_exchanges = self._discover_available_exchanges() - logger.info(f"Discovered exchanges: {', '.join(self._available_exchanges)}") - return self._available_exchanges - - @property - def available_pairs(self) -> List[str]: - if self._available_pairs is None: - self._available_pairs = self._discover_available_pairs() - logger.info(f"Discovered pairs: {', '.join(self._available_pairs)}") - return self._available_pairs - - def _discover_available_exchanges(self) -> List[str]: - """Discover available exchanges from database files.""" - exchanges = set() - try: - db_files = self._get_db_files() - if not db_files: return [] - - # Use the first DB file to discover 
exchanges - with sqlite3.connect(db_files[0]) as conn: - tables = conn.execute("SELECT name FROM sqlite_master WHERE type='table' AND name LIKE '%_ohlcv_%'").fetchall() - for table in tables: - table_name = table[0] - try: - # Attempt 1: Extract from table name (e.g., coinbase_ohlcv_1min) - match = re.match(r'([a-z0-9]+)_ohlcv_', table_name, re.IGNORECASE) - if match: - exchanges.add(match.group(1).upper()) - # Attempt 2: Query distinct exchange_id if column exists - elif 'exchange_id' in [c[1].lower() for c in conn.execute(f"PRAGMA table_info({table_name})").fetchall()]: - query = f"SELECT DISTINCT exchange_id FROM {table_name}" - for row in conn.execute(query).fetchall(): exchanges.add(row[0].upper()) - except sqlite3.Error as e: logger.debug(f"Could not query table {table_name}: {e}") - - return sorted(list(exchanges)) - except Exception as e: logger.error(f"Error discovering exchanges: {e}"); return [] - - def _discover_available_pairs(self) -> List[str]: - """Discover available trading pairs from database files.""" - pairs = set() - try: - db_files = self._get_db_files() - if not db_files: return [] - - with sqlite3.connect(db_files[0]) as conn: - tables = conn.execute("SELECT name FROM sqlite_master WHERE type='table' AND name LIKE '%_ohlcv_%'").fetchall() - for table in tables: - table_name = table[0] - try: - # Check if instrument_id column exists - if 'instrument_id' in [c[1].lower() for c in conn.execute(f"PRAGMA table_info({table_name})").fetchall()]: - query = f"SELECT DISTINCT instrument_id FROM {table_name}" - for row in conn.execute(query).fetchall(): - instr_id = row[0] - if instr_id.startswith("PAIR-"): instr_id = instr_id[5:] - pairs.add(instr_id) - else: - # Attempt to parse from table name if instrument_id col missing - # e.g., coinbase_ohlcv_1min_btc_usdt - match = re.search(r'_([A-Z0-9]+_[A-Z0-9]+)$|_([A-Z0-9]+-[A-Z0-9]+)$', table_name, re.IGNORECASE) - if match: - pair = match.group(1) or match.group(2) - if pair: 
pairs.add(pair.replace('_','-').upper()) - except sqlite3.Error as e: logger.debug(f"Could not query table {table_name}: {e}") - return sorted(list(pairs)) - except Exception as e: logger.error(f"Error discovering pairs: {e}"); return [] - - def _get_db_files(self) -> List[str]: - """Get available database files, sorted by date desc (cached).""" - if self._db_files is not None: - return self._db_files - - logger.info(f"Scanning for DB files in: {self.db_dir}") - if not os.path.exists(self.db_dir): - logger.error(f"Database directory {self.db_dir} does not exist") - self._db_files = [] - return [] - - patterns = ["*.mktdata.ohlcv.db", "*.db", "*.sqlite", "*.sqlite3"] - db_files = [] - for pattern in patterns: - files = glob.glob(os.path.join(self.db_dir, pattern)) - if files: logger.debug(f"Found {len(files)} files with pattern {pattern}") - db_files.extend(files) - - if not db_files: - logger.warning(f"No database files found in {self.db_dir} matching patterns: {patterns}") - self._db_files = [] - return [] - - db_files = list(set(db_files)) # Remove duplicates - - # Sort by date (newest first) if possible - date_pattern = re.compile(r'(\d{8})') - file_dates = [] - for file in db_files: - basename = os.path.basename(file) - match = date_pattern.search(basename) - date_obj = None - if match: - try: date_obj = pd.to_datetime(match.group(1), format='%Y%m%d') - except ValueError: pass - # Fallback: try modification time - if date_obj is None: - try: date_obj = pd.to_datetime(os.path.getmtime(file), unit='s') - except Exception: date_obj = pd.Timestamp.min # Default to oldest if error - file_dates.append((date_obj, file)) - - file_dates.sort(key=lambda x: x[0], reverse=True) # Sort by date object, newest first - self._db_files = [file for _, file in file_dates] - - logger.info(f"Found {len(self._db_files)} DB files. 
Using newest: {os.path.basename(self._db_files[0]) if self._db_files else 'None'}") - return self._db_files - - # V7 Update: Simplified date finding - use files covering the range - def _get_relevant_db_files(self, start_dt: pd.Timestamp, end_dt: pd.Timestamp) -> List[str]: - """Find DB files potentially containing data for the date range.""" - all_files = self._get_db_files() - relevant_files = set() - date_pattern = re.compile(r'(\d{8})') - - for file in all_files: - basename = os.path.basename(file) - match = date_pattern.search(basename) - if match: - try: - file_date = pd.to_datetime(match.group(1), format='%Y%m%d') - # Check if the file's date is within or overlaps the target range - # (Assume file contains data for that single day) - if start_dt.date() <= file_date.date() <= end_dt.date(): - relevant_files.add(file) - except ValueError: - pass # Ignore files with unparseable dates - else: - # If no date in filename, conservatively include recent files - # based on modification time (might be less accurate) - try: - mod_time = pd.to_datetime(os.path.getmtime(file), unit='s') - # Include if modified within or shortly after the requested range - if start_dt <= mod_time <= (end_dt + timedelta(days=1)): - relevant_files.add(file) - except Exception: - pass # Ignore files with mod time errors - - # If no files found based on date, return the most recent one as a fallback - if not relevant_files and all_files: - logger.warning(f"No DB files found matching date range {start_dt.date()} - {end_dt.date()}. 
Using most recent file.") - return [all_files[0]] - elif not relevant_files: - logger.error("No relevant DB files found and no fallback files available.") - return [] - - # Sort the relevant files chronologically (oldest first for processing) - return sorted(list(relevant_files), key=lambda f: os.path.basename(f)) - - def _convert_ticker_to_instrument_id(self, ticker: str) -> str: - if not ticker.startswith("PAIR-"): return f"PAIR-{ticker}" - return ticker - - def _convert_interval(self, interval: str) -> Optional[str]: - interval = interval.lower() - interval_map = {"1m": "1min", "5m": "5min", "15m": "15min", "30m": "30min", - "1h": "1hour", "4h": "4hour", "1d": "1day", "1D": "1day"} - if interval in interval_map: return interval_map[interval] - if interval.endswith(('min', 'hour', 'day')): return interval - logger.warning(f"Unsupported interval format: {interval}") - return None # Return None for unsupported intervals - - def _get_table_name(self, conn: sqlite3.Connection, exchange: str, interval: str = "1min") -> Optional[str]: - """Find the correct table name, trying variations.""" - base_table = f"{exchange.lower()}_ohlcv_{interval}" - cursor = conn.cursor() - cursor.execute("SELECT name FROM sqlite_master WHERE type='table'") - tables = [row[0] for row in cursor.fetchall()] - - if base_table in tables: return base_table - - # Try variations (case, interval format) - variations = [ - f"{exchange.upper()}_ohlcv_{interval}", - f"{exchange.lower()}_ohlcv_1m", # Always check 1min as source - f"{exchange.upper()}_ohlcv_1m", - ] - for var in variations: - if var in tables: - logger.debug(f"Found table using variation: {var}") - return var - - # Check if any OHLCV table exists for the exchange - for t in tables: - if t.lower().startswith(f"{exchange.lower()}_ohlcv_"): - logger.warning(f"Using first available OHLCV table for exchange: {t}") - return t - - logger.warning(f"Table for {exchange} interval {interval} not found.") - return None - - def 
_query_data_from_db(self, db_file: str, ticker: str, start_timestamp: int, end_timestamp: int, - interval: str = "1min", exchange: str = "COINBASE") -> pd.DataFrame: - """ - Query market data from a database file. - """ - instrument_id = self._convert_ticker_to_instrument_id(ticker) - # Always query 1min interval from DB as it's the source resolution - query_interval = "1min" - - try: - # logger.debug(f"Querying DB {db_file} for {instrument_id}...") - with sqlite3.connect(db_file) as conn: - table_name = self._get_table_name(conn, exchange, query_interval) - if not table_name: - return pd.DataFrame() - - cursor = conn.cursor() - cursor.execute(f"PRAGMA table_info({table_name})") - columns_info = cursor.fetchall() - column_names = [col[1].lower() for col in columns_info] - # logger.debug(f"Columns in {table_name}: {column_names}") - - # Build query dynamically - select_cols = ["tstamp", "open", "high", "low", "close", "volume"] - select_str = ", ".join([c for c in select_cols if c in column_names]) - if not all(c in column_names for c in ["tstamp", "open", "high", "low", "close", "volume"]): - logger.warning(f"Table {table_name} in {db_file} missing standard OHLCV columns.") - # Attempt to map common variations if needed - skipped for simplicity - # For now, return empty if standard columns are missing - return pd.DataFrame() - - where_clauses = ["tstamp >= ?", "tstamp <= ?"] - params = [start_timestamp, end_timestamp] - - if 'instrument_id' in column_names: - where_clauses.append("instrument_id = ?") - params.append(instrument_id) - if 'exchange_id' in column_names: - # Normalize exchange name from DB if needed - cursor.execute(f"SELECT DISTINCT exchange_id FROM {table_name} WHERE exchange_id LIKE ? 
LIMIT 1", (f'%{exchange}%',)) - db_exchange = cursor.fetchone() - if db_exchange: - params.append(db_exchange[0]) - where_clauses.append("exchange_id = ?") - else: # If exchange not found, query might return empty - params.append(exchange) - where_clauses.append("exchange_id = ?") - - query = f"SELECT {select_str} FROM {table_name} WHERE {' AND '.join(where_clauses)} ORDER BY tstamp" - # logger.debug(f"Executing query: {query} with params {params[:2]}...{params[-1]}") - - df = pd.read_sql_query(query, conn, params=params) - if df.empty: return pd.DataFrame() - - # Convert timestamp and set index - df['date'] = pd.to_datetime(df['tstamp'], unit='ns', utc=True) - df = df.set_index('date').drop(columns=['tstamp']) # Drop original tstamp - # logger.debug(f"Query returned {len(df)} rows.") - return df - - except Exception as e: - logger.error(f"Error querying {db_file} table {table_name if 'table_name' in locals() else 'N/A'}: {e}", exc_info=False) - return pd.DataFrame() - - def _resample_data(self, df: pd.DataFrame, interval: str) -> pd.DataFrame: - """Resample 1-minute data to a different interval.""" - if df.empty or not isinstance(df.index, pd.DatetimeIndex): return df - - # V7 Update: More robust interval conversion - try: - freq = pd.tseries.frequencies.to_offset(interval) - if freq is None: raise ValueError("Invalid frequency") - except ValueError: - # Try manual mapping for common cases - interval_map = {"1min": "1min", "5min": "5min", "15min": "15min", "30min": "30min", - "1hour": "1h", "4hour": "4h", "1day": "1d"} - if interval in interval_map: freq = interval_map[interval] - else: logger.error(f"Unsupported interval for resampling: {interval}"); return df - - logger.info(f"Resampling data to {freq}...") - try: - agg_dict = {'open': 'first', 'high': 'max', 'low': 'min', 'close': 'last'} - # Only include volume if present - if 'volume' in df.columns: agg_dict['volume'] = 'sum' - - # Check for required columns - missing_cols = [c for c in 
['open','high','low','close'] if c not in df.columns] - if missing_cols: - logger.error(f"Cannot resample, missing required columns: {missing_cols}") - return pd.DataFrame() # Return empty if essential cols missing - - resampled = df.resample(freq).agg(agg_dict) - resampled = resampled.dropna(subset=['open', 'high', 'low', 'close']) # Drop rows where OHLC couldn't be computed - logger.info(f"Resampling complete. New shape: {resampled.shape}") - - except Exception as e: - logger.error(f"Error during resampling to {freq}: {e}", exc_info=True) - return pd.DataFrame() # Return empty on error - - return resampled - - def fetch_data(self, ticker: str, start_date: str = None, end_date: str = None, interval: str = "1min", - exchange: str = "COINBASE") -> pd.DataFrame: - """ - Fetch cryptocurrency market data for a given ticker and date range. - Always sources 1-minute data from DB and resamples if needed. - """ - logger.info(f"Fetching {ticker} data from {start_date} to {end_date} at {interval} (Exchange: {exchange})...") - - # V7 Update: Stricter date handling - try: - start_dt = pd.to_datetime(start_date, utc=True) if start_date else pd.Timestamp.now(tz='utc') - timedelta(days=30) - end_dt = pd.to_datetime(end_date, utc=True) if end_date else pd.Timestamp.now(tz='utc') - if start_dt >= end_dt: raise ValueError("Start date must be before end date") - except Exception as e: logger.error(f"Invalid date format/range: {e}"); return pd.DataFrame() - logger.info(f"Querying date range: {start_dt.date()} to {end_dt.date()}") - - # V7 Update: Check supported interval format - target_interval = interval # Store requested interval - resample_freq_pd = None - try: resample_freq_pd = pd.tseries.frequencies.to_offset(target_interval) - except ValueError: - interval_map = {"1min": "1T", "5min": "5T", "15min": "15T", "30min": "30T", - "1hour": "1H", "4hour": "4H", "1day": "1D"} - if target_interval in interval_map: resample_freq_pd = interval_map[target_interval] - else: 
logger.error(f"Unsupported interval: {target_interval}"); return pd.DataFrame() - - # V7 Update: Cache key includes exchange - cache_key = f"{exchange}_{ticker}_{start_dt.strftime('%Y%m%d')}_{end_dt.strftime('%Y%m%d')}_{target_interval}".replace("-","_") - cache_path = os.path.join(self.cache_dir, f"{cache_key}.parquet") # Use parquet for better type handling - - if self.use_cache and os.path.exists(cache_path): - logger.info(f"Loading data from cache: {cache_path}") - try: - data = pd.read_parquet(cache_path) - # Ensure index is datetime and UTC (Parquet often preserves this) - if not isinstance(data.index, pd.DatetimeIndex): data.index = pd.to_datetime(data.index) - if data.index.tz is None: data.index = data.index.tz_localize('utc') - elif data.index.tz != 'UTC': data.index = data.index.tz_convert('utc') - logger.info(f"Loaded {len(data)} rows from cache.") - return data - except Exception as e: logger.warning(f"Error loading cache: {e}. Fetching fresh data.") - - # Convert timestamps for DB query - start_timestamp_ns = int(start_dt.timestamp() * 1e9) - end_timestamp_ns = int(end_dt.timestamp() * 1e9) - - # V7 Update: Use _get_relevant_db_files - db_files_to_query = self._get_relevant_db_files(start_dt, end_dt) - if not db_files_to_query: - logger.error("No relevant database files found for the specified date range.") - return pd.DataFrame() - logger.info(f"Querying {len(db_files_to_query)} DB files: {[os.path.basename(f) for f in db_files_to_query]}") - - all_data = [] - for db_file in db_files_to_query: - # Always query 1min data - df = self._query_data_from_db(db_file, ticker, start_timestamp_ns, end_timestamp_ns, "1min", exchange) - if not df.empty: all_data.append(df) - - if not all_data: logger.warning(f"No data found in DBs for {ticker}..."); return pd.DataFrame() - - combined_df = pd.concat(all_data) - combined_df = combined_df[~combined_df.index.duplicated(keep='first')].sort_index() - # Filter exact date range AFTER combining, before resampling - 
combined_df = combined_df[(combined_df.index >= start_dt) & (combined_df.index <= end_dt)] - logger.info(f"Combined data shape before resampling: {combined_df.shape}") - - # Resample if target interval is not 1 minute - final_df = combined_df - if target_interval != "1min": - final_df = self._resample_data(combined_df, target_interval) - if final_df.empty: - logger.error("Resampling resulted in empty DataFrame.") - return pd.DataFrame() - - # Cache the result - if self.use_cache: - try: - os.makedirs(os.path.dirname(cache_path), exist_ok=True) - final_df.to_parquet(cache_path) - logger.info(f"Saved {len(final_df)} rows to cache: {cache_path}") - except Exception as e: logger.warning(f"Failed to save cache: {e}") - - logger.info(f"Fetch data complete. Returning {len(final_df)} rows.") - return final_df - - # ... (keep _extract_date_from_filename) ... - def _extract_date_from_filename(self, filename): - if not filename: return None - base_name = os.path.basename(filename) - match = re.match(r"(\d{8})\.mktdata", base_name) - if match: - try: return datetime.strptime(match.group(1), "%Y%m%d").date() - except ValueError: return None - return None - -# --- Removed fetch_intraday_data, fetch_daily_data, batch_fetch_data, DataFetcherFactory --- -# These methods are not directly used by the V7 workflow as defined so far. -# They can be added back if needed. \ No newline at end of file diff --git a/gru_sac_predictor/src/data_loader.py b/gru_sac_predictor/src/data_loader.py new file mode 100644 index 00000000..b4846a7d --- /dev/null +++ b/gru_sac_predictor/src/data_loader.py @@ -0,0 +1,394 @@ +""" +Data Loader for Cryptocurrency Market Data from SQLite Databases. 
+""" + +import os +import logging +import pandas as pd +import sqlite3 +import glob +import re +import sys +from datetime import datetime, timedelta +from typing import List, Optional + +logger = logging.getLogger(__name__) + +class DataLoader: + """ + Loads historical cryptocurrency market data from SQLite databases. + Combines functionality from the previous CryptoDBFetcher and data loading logic. + """ + def __init__(self, db_dir: str, cache_dir: str = "data/cache", use_cache: bool = False): + """ + Initialize the DataLoader. + + Args: + db_dir (str): Directory where SQLite database files are stored. Can be relative to project root or absolute. + cache_dir (str): Directory to store cached data (currently not implemented). + use_cache (bool): Whether to use cached data (currently not implemented). + """ + # Resolve potential relative db_dir path + if not os.path.isabs(db_dir): + # Assume db_dir is relative to the project root (two levels up from src/) + # This might need adjustment depending on where the main script is run + script_dir = os.path.dirname(os.path.abspath(__file__)) + project_root = os.path.dirname(os.path.dirname(script_dir)) + self.db_dir = os.path.abspath(os.path.join(project_root, db_dir)) + logger.info(f"Resolved relative db_dir '{db_dir}' to absolute path: {self.db_dir}") + else: + self.db_dir = db_dir + + self.cache_dir = cache_dir # Placeholder for future cache implementation + self.use_cache = use_cache # Placeholder + + self._db_files = None # Cache discovered DB files + + logger.info(f"Initialized DataLoader with db_dir='{self.db_dir}'") + if not os.path.exists(self.db_dir): + logger.warning(f"Database directory does not exist: {self.db_dir}") + + def _get_db_files(self) -> List[str]: + """Get available database files, sorted by date desc (cached). 
Uses recursive glob.""" + if self._db_files is not None: + return self._db_files + + logger.info(f"Scanning for DB files recursively in: {self.db_dir}") + if not os.path.exists(self.db_dir): + logger.error(f"Database directory {self.db_dir} does not exist") + self._db_files = [] + return [] + + patterns = ["*.mktdata.ohlcv.db", "*.db", "*.sqlite", "*.sqlite3"] + db_files = [] + for pattern in patterns: + # Recursive search + recursive_pattern = os.path.join(self.db_dir, '**', pattern) + try: + files = glob.glob(recursive_pattern, recursive=True) + if files: + logger.debug(f"Found {len(files)} files recursively with pattern '{pattern}'") + db_files.extend(files) + except Exception as e: + logger.error(f"Error during glob pattern '{recursive_pattern}': {e}") + + if not db_files: + logger.warning(f"No database files found in '{self.db_dir}' matching patterns: {patterns}") + self._db_files = [] + return [] + + db_files = sorted(list(set(db_files))) # Remove duplicates and sort alphabetically for consistency before date sort + + # Sort by date (newest first) if possible + date_pattern = re.compile(r'(\d{8})') + file_dates = [] + for file in db_files: + basename = os.path.basename(file) + match = date_pattern.search(basename) + date_obj = None + if match: + try: + date_obj = pd.to_datetime(match.group(1), format='%Y%m%d') + except ValueError: + pass + # Fallback: try modification time + if date_obj is None: + try: + date_obj = pd.to_datetime(os.path.getmtime(file), unit='s') + except Exception: + date_obj = pd.Timestamp.min # Default to oldest if error + file_dates.append((date_obj, file)) + + # Sort by date object, newest first + file_dates.sort(key=lambda x: x[0], reverse=True) + self._db_files = [file for _, file in file_dates] + + logger.info(f"Found {len(self._db_files)} DB files. 
Using newest: {os.path.basename(self._db_files[0]) if self._db_files else 'None'}") + return self._db_files + + def _get_relevant_db_files(self, start_dt: pd.Timestamp, end_dt: pd.Timestamp) -> List[str]: + """Find DB files potentially containing data for the date range.""" + all_files = self._get_db_files() + relevant_files = set() + date_pattern = re.compile(r'(\d{8})') + + start_date_only = start_dt.date() + end_date_only = end_dt.date() + + for file in all_files: + basename = os.path.basename(file) + match = date_pattern.search(basename) + file_date = None + if match: + try: + file_date = pd.to_datetime(match.group(1), format='%Y%m%d').date() + except ValueError: + pass # Ignore files with unparseable dates in name + + # Strategy 1: Check if filename date is within the requested range + if file_date and start_date_only <= file_date <= end_date_only: + relevant_files.add(file) + continue # Found by filename date, no need to check mtime + + # Strategy 2: Check if file modification time falls within the range (less precise) + # Useful for files without dates in the name or if a single file spans multiple dates + try: + mod_time_dt = pd.to_datetime(os.path.getmtime(file), unit='s', utc=True) + # Check if the file's modification date is within the range or shortly after + # We add a buffer (e.g., 1 day) because file might contain data slightly past its mod time + if start_dt <= mod_time_dt <= (end_dt + timedelta(days=1)): + relevant_files.add(file) + except Exception as e: + logger.debug(f"Could not get or parse modification time for {file}: {e}") + + # If no files found based on date/mtime, use the most recent file as a fallback + # This is a safety measure, but might lead to incorrect data if the range is old + if not relevant_files and all_files: + logger.warning(f"No DB files found matching date range {start_date_only} - {end_date_only}. 
Using most recent file as fallback: {os.path.basename(all_files[0])}") + return [all_files[0]] + elif not relevant_files: + logger.error("No relevant DB files found and no fallback files available.") + return [] + + # Sort the relevant files chronologically (oldest first for processing) + # Sorting by basename which often includes date is a reasonable heuristic + return sorted(list(relevant_files), key=lambda f: os.path.basename(f)) + + def _get_table_name(self, conn: sqlite3.Connection, exchange: str, interval: str = "1min") -> Optional[str]: + """Find the correct table name, trying variations. Prioritizes 1min.""" + cursor = conn.cursor() + try: + cursor.execute("SELECT name FROM sqlite_master WHERE type='table'") + tables = [row[0] for row in cursor.fetchall()] + except sqlite3.Error as e: + logger.error(f"Failed to list tables in database: {e}") + return None + + # Standard format check (lowercase exchange, exact interval) + base_table = f"{exchange.lower()}_ohlcv_{interval}" + if base_table in tables: return base_table + + # Check for 1min source table specifically (common case) + one_min_table = f"{exchange.lower()}_ohlcv_1min" + if one_min_table in tables: return one_min_table + + # Try other variations (case, common interval formats) + variations = [ + f"{exchange.upper()}_ohlcv_{interval}", + f"{exchange.upper()}_ohlcv_1min", + f"{exchange.lower()}_ohlcv_1m", # Common abbreviation + f"{exchange.upper()}_ohlcv_1m", + ] + for var in variations: + if var in tables: + logger.debug(f"Found table using variation: {var}") + return var + + # Fallback: Check if *any* OHLCV table exists for the exchange + for t in tables: + if t.lower().startswith(f"{exchange.lower()}_ohlcv_"): + logger.warning(f"Using first available OHLCV table found for exchange '{exchange}': {t}. 
Interval might not match '{interval}'.") + return t + + logger.warning(f"No suitable OHLCV table found for exchange '{exchange}' with interval '{interval}' or '1min' in the database.") + return None + + def _query_data_from_db(self, db_file: str, ticker: str, exchange: str, start_timestamp_ns: int, end_timestamp_ns: int) -> pd.DataFrame: + """ + Query market data from a single database file for a specific ticker and time range (nanoseconds). + Always queries the 1-minute interval table. + """ + instrument_id = f"PAIR-{ticker}" if not ticker.startswith("PAIR-") else ticker + query_interval = "1min" # Always query base interval from DB + + try: + logger.debug(f"Querying DB '{os.path.basename(db_file)}' for {instrument_id} between {start_timestamp_ns} and {end_timestamp_ns}") + with sqlite3.connect(f'file:{db_file}?mode=ro', uri=True) as conn: # Read-only mode + table_name = self._get_table_name(conn, exchange, query_interval) + if not table_name: + logger.warning(f"No table found for {exchange}/1min in {os.path.basename(db_file)}") + return pd.DataFrame() + + cursor = conn.cursor() + try: + cursor.execute(f"PRAGMA table_info({table_name})") + columns_info = cursor.fetchall() + column_names = [col[1].lower() for col in columns_info] + except sqlite3.Error as e: + logger.error(f"Failed to get column info for table '{table_name}' in {db_file}: {e}") + return pd.DataFrame() + + # Check for essential columns + select_cols = ["tstamp", "open", "high", "low", "close", "volume"] + if not all(c in column_names for c in select_cols): + logger.warning(f"Table '{table_name}' in {db_file} missing one or more standard columns: {select_cols}. 
Found: {column_names}") + return pd.DataFrame() + + # Build query + select_str = ", ".join(select_cols) + where_clauses = ["tstamp >= ?", "tstamp <= ?"] + params: list = [start_timestamp_ns, end_timestamp_ns] + + if 'instrument_id' in column_names: + where_clauses.append("instrument_id = ?") + params.append(instrument_id) + # Note: exchange_id filtering is complex due to potential variations; rely on table name for now. + + query = f"SELECT {select_str} FROM {table_name} WHERE {' AND '.join(where_clauses)} ORDER BY tstamp" + + df = pd.read_sql_query(query, conn, params=params) + if df.empty: + logger.debug(f"Query returned no data for {instrument_id} in {os.path.basename(db_file)} for the time range.") + return pd.DataFrame() + + # Convert timestamp and set index + df['date'] = pd.to_datetime(df['tstamp'], unit='ns', utc=True) + df = df.set_index('date').drop(columns=['tstamp']) + # Ensure numeric types + for col in ['open', 'high', 'low', 'close', 'volume']: + df[col] = pd.to_numeric(df[col], errors='coerce') + + # Drop rows with NaNs that might result from coerce + df.dropna(subset=['open', 'high', 'low', 'close'], inplace=True) + + logger.debug(f"Query from {os.path.basename(db_file)} returned {len(df)} rows for {instrument_id}.") + return df + + except sqlite3.Error as e: + logger.error(f"SQLite error querying {db_file} table '{table_name if 'table_name' in locals() else 'N/A'}': {e}") + except Exception as e: + logger.error(f"Unexpected error querying {db_file}: {e}", exc_info=False) + return pd.DataFrame() + + def _resample_data(self, df: pd.DataFrame, interval: str) -> pd.DataFrame: + """Resample 1-minute data to a different interval.""" + if df.empty or not isinstance(df.index, pd.DatetimeIndex): + logger.warning("Input DataFrame for resampling is empty or has non-DatetimeIndex.") + return df + if interval == '1min': # No resampling needed + return df + + logger.info(f"Resampling data to {interval}...") + try: + # Define aggregation rules + agg_dict = 
{'open': 'first', 'high': 'max', 'low': 'min', 'close': 'last'} + if 'volume' in df.columns: agg_dict['volume'] = 'sum' + + # Check for required columns before resampling + required_cols = ['open', 'high', 'low', 'close'] + missing_cols = [c for c in required_cols if c not in df.columns] + if missing_cols: + logger.error(f"Cannot resample, missing required columns: {missing_cols}") + return pd.DataFrame() # Return empty if essential cols missing + + # Perform resampling + resampled_df = df.resample(interval).agg(agg_dict) + + # Drop rows where essential OHLC data is missing after resampling + resampled_df = resampled_df.dropna(subset=['open', 'high', 'low', 'close']) + + if resampled_df.empty: + logger.warning(f"Resampling to {interval} resulted in an empty DataFrame.") + else: + logger.info(f"Resampling complete. New shape: {resampled_df.shape}") + return resampled_df + + except ValueError as e: + logger.error(f"Invalid interval string for resampling: '{interval}'. Error: {e}") + return pd.DataFrame() + except Exception as e: + logger.error(f"Error during resampling to {interval}: {e}", exc_info=True) + return pd.DataFrame() + + def load_data(self, ticker: str, exchange: str, start_date: str, end_date: str, interval: str) -> pd.DataFrame: + """ + Loads, combines, and optionally resamples data from relevant DB files. + + Args: + ticker (str): The trading pair symbol (e.g., 'SOL-USDT'). + exchange (str): The exchange name (e.g., 'bnbspot'). + start_date (str): Start date string (YYYY-MM-DD). + end_date (str): End date string (YYYY-MM-DD). + interval (str): The desired final data interval (e.g., '1min', '5min', '1h'). + + Returns: + pd.DataFrame: Combined and resampled OHLCV data, indexed by UTC timestamp. + Returns an empty DataFrame on failure. 
+ """ + logger.info(f"Loading data for {ticker} ({exchange}) from {start_date} to {end_date}, interval {interval}") + + try: + # Parse dates - add time component to cover full days + start_dt = pd.to_datetime(start_date, utc=True).replace(hour=0, minute=0, second=0, microsecond=0) + end_dt = pd.to_datetime(end_date, utc=True).replace(hour=23, minute=59, second=59, microsecond=999999) + if start_dt >= end_dt: + raise ValueError("Start date must be before end date") + except Exception as e: + logger.error(f"Invalid date format or range: {e}") + return pd.DataFrame() + + # Timestamps for DB query (nanoseconds) + start_timestamp_ns = int(start_dt.timestamp() * 1e9) + end_timestamp_ns = int(end_dt.timestamp() * 1e9) + + # Find relevant database files based on the date range + db_files_to_query = self._get_relevant_db_files(start_dt, end_dt) + if not db_files_to_query: + logger.error("No relevant database files found for the specified date range.") + return pd.DataFrame() + + logger.info(f"Identified {len(db_files_to_query)} potential DB files: {[os.path.basename(f) for f in db_files_to_query]}") + + # Query each relevant file and collect data + all_data = [] + for db_file in db_files_to_query: + df_part = self._query_data_from_db(db_file, ticker, exchange, start_timestamp_ns, end_timestamp_ns) + if not df_part.empty: + all_data.append(df_part) + + if not all_data: + logger.warning(f"No data found in any identified DB files for {ticker} ({exchange}) in the specified range.") + return pd.DataFrame() + + # Combine data from all files + try: + combined_df = pd.concat(all_data) + # Remove duplicate indices (e.g., from overlapping file queries), keeping the first occurrence + combined_df = combined_df[~combined_df.index.duplicated(keep='first')] + # Sort chronologically + combined_df = combined_df.sort_index() + except Exception as e: + logger.error(f"Error concatenating or sorting dataframes: {e}") + return pd.DataFrame() + + logger.info(f"Combined data shape before final 
filtering/resampling: {combined_df.shape}") + + # Apply precise date range filtering *after* combining and sorting + final_df = combined_df[(combined_df.index >= start_dt) & (combined_df.index <= end_dt)] + + if final_df.empty: + logger.warning(f"Dataframe is empty after final date range filtering ({start_dt} to {end_dt}).") + return pd.DataFrame() + + logger.info(f"Shape after final date filtering: {final_df.shape}") + + # Resample if the requested interval is different from 1min + if interval != "1min": + final_df = self._resample_data(final_df, interval) + if final_df.empty: + logger.error(f"Resampling to {interval} resulted in an empty DataFrame. Check resampling logic or input data.") + return pd.DataFrame() + + # Final check for NaNs in essential columns + essential_cols = ['open', 'high', 'low', 'close'] + if final_df[essential_cols].isnull().any().any(): + rows_before = len(final_df) + final_df.dropna(subset=essential_cols, inplace=True) + logger.warning(f"Dropped {rows_before - len(final_df)} rows with NaN values in essential OHLC columns after potential resampling.") + + if final_df.empty: + logger.error(f"Final DataFrame is empty after NaN checks for {ticker}.") + return pd.DataFrame() + + logger.info(f"Successfully loaded and processed data for {ticker}. 
Final shape: {final_df.shape}") + return final_df \ No newline at end of file diff --git a/gru_sac_predictor/src/data_pipeline.py b/gru_sac_predictor/src/data_pipeline.py deleted file mode 100644 index cac5b638..00000000 --- a/gru_sac_predictor/src/data_pipeline.py +++ /dev/null @@ -1,290 +0,0 @@ -import pandas as pd -import numpy as np -import logging -import os -import sys - -# V7 Update: Import the fetcher -from .crypto_db_fetcher import CryptoDBFetcher - -data_pipeline_logger = logging.getLogger(__name__) - -# V7 Update: Add load_data_from_db function from V6 -def load_data_from_db( - db_dir: str, - ticker: str, - exchange: str, - start_date: str, - end_date: str, - interval: str = "1min" -) -> pd.DataFrame: - """ - Loads cryptocurrency OHLCV data from the local SQLite database using CryptoDBFetcher. - Adapted from V6. - - Args: - db_dir: Directory containing the SQLite database files. - ticker: The trading pair symbol (e.g., 'BTC-USDT'). - exchange: The exchange name (e.g., 'COINBASE'). - start_date: Start date string (YYYY-MM-DD). - end_date: End date string (YYYY-MM-DD). - interval: The desired data interval (e.g., '1min', '5min', '1h'). - - Returns: - A Pandas DataFrame containing the OHLCV data, indexed by timestamp. - Returns an empty DataFrame if data loading fails. 
- """ - data_pipeline_logger.info(f"Loading data via DB: {ticker} from {exchange} ({start_date} to {end_date}, interval: {interval})") - try: - # Initialize fetcher (db_dir path is handled within fetcher now) - fetcher = CryptoDBFetcher(db_dir=db_dir) - df = fetcher.fetch_data( - ticker=ticker, - start_date=start_date, - end_date=end_date, - interval=interval, - exchange=exchange - ) - if df.empty: - data_pipeline_logger.warning(f"No data found for {ticker} in the specified DB range.") - else: - # Ensure index is datetime and timezone-aware (UTC) - if not isinstance(df.index, pd.DatetimeIndex): - try: - df.index = pd.to_datetime(df.index, errors='coerce', utc=True) - if df.index.isnull().any(): - data_pipeline_logger.warning("Dropping rows with invalid datetime index after conversion.") - df = df.dropna(subset=[df.index.name]) - except Exception as idx_e: - data_pipeline_logger.error(f"Failed to convert index to DatetimeIndex: {idx_e}") - return pd.DataFrame() - elif df.index.tz is None: # Ensure timezone if index is already datetime - df.index = df.index.tz_localize('utc') - elif df.index.tz != 'UTC': # Convert to UTC if different timezone - df.index = df.index.tz_convert('utc') - - df = df.sort_index() - data_pipeline_logger.info(f"Successfully loaded {len(df)} rows from DB.") - - return df - except Exception as e: - data_pipeline_logger.error(f"Error loading data from database via fetcher: {e}", exc_info=True) - return pd.DataFrame() - -def create_data_pipeline(historical_data, split_ratios=[0.6, 0.2, 0.2]): - """ - Prepare data pipeline for training both GRU and SAC models using chronological split. - - Args: - historical_data: DataFrame with OHLCV data, sorted chronologically. - Must have a DatetimeIndex. - split_ratios: Train/validation/test split ratios based on time. - - Returns: - Tuple of (train_data, validation_data, test_data) DataFrames. 
- """ - if historical_data is None or historical_data.empty: - logging.error("Input data is empty, cannot create pipeline.") - return None, None, None - - # Ensure data is sorted by time - historical_data.sort_index(inplace=True) - - # V7 Change: Use index for duration calculation - try: - if not isinstance(historical_data.index, pd.DatetimeIndex): - raise TypeError("Data index must be a DatetimeIndex for splitting.") - total_duration = historical_data.index[-1] - historical_data.index[0] - except IndexError: - logging.error("Cannot calculate duration: Data has insufficient rows.") - return None, None, None - except TypeError as e: - logging.error(f"Error calculating duration: {e}") - return None, None, None - - train_ratio, val_ratio, test_ratio = split_ratios - - # Calculate split points based on time duration - train_end_time = historical_data.index[0] + total_duration * train_ratio - val_end_time = train_end_time + total_duration * val_ratio - - # Perform the split using time index - # Ensure the split points are valid timestamps within the index range - # Use searchsorted to find the index locations closest to the calculated times - train_end_idx = historical_data.index.searchsorted(train_end_time) - val_end_idx = historical_data.index.searchsorted(val_end_time) - - train_data = historical_data.iloc[:train_end_idx] - val_data = historical_data.iloc[train_end_idx:val_end_idx] - test_data = historical_data.iloc[val_end_idx:] - - logging.info(f"Data split complete:") - logging.info(f" Train: {len(train_data)} rows ({train_data.index.min()} to {train_data.index.max()})") - logging.info(f" Validation: {len(val_data)} rows ({val_data.index.min()} to {val_data.index.max()})") - logging.info(f" Test: {len(test_data)} rows ({test_data.index.min()} to {test_data.index.max()})") - - return train_data, val_data, test_data - -def create_sequences_v2(features_scaled, targets_scaled, start_price_unscaled, seq_length=60): - """ - Create sequences for GRU training, handling 
potential mismatches in indices. - - Args: - features_scaled: Scaled feature DataFrame - targets_scaled: Scaled target Series - start_price_unscaled: Unscaled starting price Series - seq_length: Sequence length for GRU - - Returns: - Tuple of (X sequences, y targets, starting prices) or (None, None, None) if creation fails - """ - data_pipeline_logger.info(f"Creating sequences (v2) with length {seq_length}...") - - try: - # Type checking and conversion for features - if features_scaled is None: - data_pipeline_logger.error("features_scaled is None") - return None, None, None - - if not isinstance(features_scaled, pd.DataFrame): - data_pipeline_logger.warning(f"features_scaled is not DataFrame but {type(features_scaled)}") - try: - if isinstance(features_scaled, pd.Series): - # Try to convert Series to DataFrame - features_scaled = pd.DataFrame(features_scaled) - else: - # Try to convert numpy array to DataFrame - features_scaled = pd.DataFrame(features_scaled) - except Exception as e: - data_pipeline_logger.error(f"Failed to convert features_scaled to DataFrame: {e}") - return None, None, None - - # Type checking and conversion for targets - if targets_scaled is None: - data_pipeline_logger.error("targets_scaled is None") - return None, None, None - - if not isinstance(targets_scaled, pd.Series): - data_pipeline_logger.warning(f"targets_scaled is not Series but {type(targets_scaled)}") - try: - if isinstance(targets_scaled, pd.DataFrame) and targets_scaled.shape[1] == 1: - # Convert single-column DataFrame to Series - targets_scaled = targets_scaled.iloc[:, 0] - elif isinstance(targets_scaled, np.ndarray) and targets_scaled.ndim == 1: - # Convert 1D array to Series - targets_scaled = pd.Series(targets_scaled) - else: - data_pipeline_logger.error(f"targets_scaled shape is not compatible: {getattr(targets_scaled, 'shape', 'unknown')}") - return None, None, None - except Exception as e: - data_pipeline_logger.error(f"Failed to convert targets_scaled to Series: {e}") - 
return None, None, None - - # Type checking and conversion for prices - if start_price_unscaled is None: - data_pipeline_logger.error("start_price_unscaled is None") - return None, None, None - - if not isinstance(start_price_unscaled, pd.Series): - data_pipeline_logger.warning(f"start_price_unscaled is not Series but {type(start_price_unscaled)}") - try: - if isinstance(start_price_unscaled, pd.DataFrame) and start_price_unscaled.shape[1] == 1: - # Convert single-column DataFrame to Series - start_price_unscaled = start_price_unscaled.iloc[:, 0] - elif isinstance(start_price_unscaled, np.ndarray) and start_price_unscaled.ndim == 1: - # Convert 1D array to Series - start_price_unscaled = pd.Series(start_price_unscaled) - else: - data_pipeline_logger.error(f"start_price_unscaled shape is not compatible: {getattr(start_price_unscaled, 'shape', 'unknown')}") - return None, None, None - except Exception as e: - data_pipeline_logger.error(f"Failed to convert start_price_unscaled to Series: {e}") - return None, None, None - - # Log input info - data_pipeline_logger.info(f"Features index type: {type(features_scaled.index)}, length: {len(features_scaled)}") - data_pipeline_logger.info(f"Targets index type: {type(targets_scaled.index)}, length: {len(targets_scaled)}") - data_pipeline_logger.info(f"Start price index type: {type(start_price_unscaled.index)}, length: {len(start_price_unscaled)}") - - # Check for index compatibility - if (len(features_scaled) != len(targets_scaled) or - len(features_scaled) != len(start_price_unscaled)): - data_pipeline_logger.warning(f"Input lengths don't match! 
Features: {len(features_scaled)}, Targets: {len(targets_scaled)}, Prices: {len(start_price_unscaled)}") - - # Try to align on common index if all have DatetimeIndex - if (isinstance(features_scaled.index, pd.DatetimeIndex) and - isinstance(targets_scaled.index, pd.DatetimeIndex) and - isinstance(start_price_unscaled.index, pd.DatetimeIndex)): - - # Find common dates - common_index = features_scaled.index.intersection( - targets_scaled.index.intersection(start_price_unscaled.index) - ) - - # Check if we have any overlap - if len(common_index) < seq_length: - data_pipeline_logger.error(f"Not enough common indices ({len(common_index)}) for sequence length ({seq_length})") - - # If we don't have enough common indices, create synthetic indices - # First find the shortest length - min_len = min(len(features_scaled), len(targets_scaled), len(start_price_unscaled)) - - # If we have enough data for at least one sequence - if min_len >= seq_length: - data_pipeline_logger.warning(f"Using the shortest dataset length ({min_len})") - - # Convert features to numpy - features_np = features_scaled.values[:min_len] - - # Try to keep targets a Series and preserve its index - if len(targets_scaled) > min_len: - targets_scaled = targets_scaled.iloc[:min_len] - - # Try to keep prices a Series and preserve its index - if len(start_price_unscaled) > min_len: - start_price_unscaled = start_price_unscaled.iloc[:min_len] - - # Create sequences from numpy arrays - X_sequences, y_targets, starting_prices = [], [], [] - for i in range(len(features_np) - seq_length): - X_sequences.append(features_np[i:i+seq_length]) - y_targets.append(targets_scaled.iloc[i+seq_length]) - starting_prices.append(start_price_unscaled.iloc[i+seq_length]) - - if len(X_sequences) == 0: - data_pipeline_logger.error("No sequences created") - return None, None, None - - return np.array(X_sequences), np.array(y_targets), np.array(starting_prices) - else: - data_pipeline_logger.error(f"Cannot create sequences: Not enough 
data for sequence length {seq_length}") - return None, None, None - - # Align on common index - data_pipeline_logger.info(f"Aligning on {len(common_index)} common indices") - features_scaled = features_scaled.loc[common_index] - targets_scaled = targets_scaled.loc[common_index] - start_price_unscaled = start_price_unscaled.loc[common_index] - else: - data_pipeline_logger.error("Cannot create sequences v2: No common index found.") - return None, None, None - - # Convert features to numpy array for sequence creation - features_np = features_scaled.values - - # Create sequences - X_sequences, y_targets, starting_prices = [], [], [] - for i in range(len(features_np) - seq_length): - X_sequences.append(features_np[i:i+seq_length]) - y_targets.append(targets_scaled.iloc[i+seq_length]) - starting_prices.append(start_price_unscaled.iloc[i+seq_length]) - - if len(X_sequences) == 0: - data_pipeline_logger.error("No sequences created, check sequence length vs data length") - return None, None, None - - data_pipeline_logger.info(f"Created {len(X_sequences)} sequences of length {seq_length}") - return np.array(X_sequences), np.array(y_targets), np.array(starting_prices) - except Exception as e: - data_pipeline_logger.error(f"Error creating sequences: {e}", exc_info=True) - return None, None, None \ No newline at end of file diff --git a/gru_sac_predictor/src/feature_engineer.py b/gru_sac_predictor/src/feature_engineer.py new file mode 100644 index 00000000..c03ecbca --- /dev/null +++ b/gru_sac_predictor/src/feature_engineer.py @@ -0,0 +1,331 @@ +""" +Feature Engineering Component. + +Handles adding base features (cyclical, imbalance, TA) and selecting features +using Logistic Regression (L1) and Variance Inflation Factor (VIF). 
+""" + +import pandas as pd +import numpy as np +import logging +import json + +from sklearn.linear_model import LogisticRegression +from sklearn.feature_selection import SelectFromModel +from statsmodels.stats.outliers_influence import variance_inflation_factor +import statsmodels.api as sm + +# Import TA library functions directly +from ta.volatility import AverageTrueRange +from ta.momentum import RSIIndicator +from ta.trend import EMAIndicator, MACD + +logger = logging.getLogger(__name__) + +_EPS = 1e-6 + +class FeatureEngineer: + """Encapsulates feature creation and selection logic.""" + + def __init__(self, minimal_whitelist: list): + """ + Initialize the FeatureEngineer. + + Args: + minimal_whitelist (list): A base list of features considered essential. + """ + self.minimal_whitelist = minimal_whitelist + logger.info(f"FeatureEngineer initialized with minimal whitelist: {self.minimal_whitelist}") + + def _add_cyclical_features(self, df: pd.DataFrame) -> pd.DataFrame: + """Adds sine and cosine transformations of the hour.""" + if isinstance(df.index, pd.DatetimeIndex): + timestamp_source = df.index + logger.info("Adding cyclical hour features (sin/cos)...") + df['hour_sin'] = np.sin(2 * np.pi * timestamp_source.hour / 24) + df['hour_cos'] = np.cos(2 * np.pi * timestamp_source.hour / 24) + else: + logger.warning("Index is not DatetimeIndex. Skipping cyclical hour features.") + # Add placeholders if needed by downstream code, though it's better to ensure datetime index upstream + df['hour_sin'] = 0.0 + df['hour_cos'] = 1.0 + return df + + def _add_imbalance_features(self, df: pd.DataFrame) -> pd.DataFrame: + """Add Chaikin AD line, signed volume imbalance, gap imbalance.""" + logger.info("Adding imbalance features...") + if not {"open", "high", "low", "close", "volume"}.issubset(df.columns): + logger.warning("Missing required columns for imbalance features. 
Skipping.") + return df + + try: + clv = ((df["close"] - df["low"]) - (df["high"] - df["close"])) / ( + df["high"] - df["low"] + _EPS + ) + df["chaikin_AD_10"] = (clv * df["volume"]).rolling(10).sum() + + signed_vol = np.where(df["close"] >= df["open"], df["volume"], -df["volume"]) + df["svi_10"] = pd.Series(signed_vol, index=df.index).rolling(10).sum() + + med_vol = df["volume"].rolling(50).median() + gap_up = (df["low"] > df["high"].shift(1)) & (df["volume"] > 2 * med_vol) + gap_dn = (df["high"] < df["low"].shift(1)) & (df["volume"] > 2 * med_vol) + df["gap_imbalance"] = gap_up.astype(int) - gap_dn.astype(int) + + # Fill NaNs introduced - use fillna(0) for simplicity here + # More robust filling (bfill/ffill) might be needed depending on feature use + df.fillna({"chaikin_AD_10": 0, "svi_10": 0, "gap_imbalance": 0}, inplace=True) + logger.info("Successfully added imbalance features.") + except Exception as e: + logger.error(f"Error calculating imbalance features: {e}", exc_info=True) + + return df + + def _add_ta_features(self, df: pd.DataFrame) -> pd.DataFrame: + """Adds TA features using the 'ta' library.""" + logger.info("Adding TA features...") + required_cols = {'open', 'high', 'low', 'close', 'volume'} + if not required_cols.issubset(df.columns): + logger.warning(f"Missing required columns for TA features ({required_cols - set(df.columns)}). 
Skipping TA.") + return df + + # Apply shift(1) to prevent lookahead bias in TA features based on close + # Features will be calculated based on data up to t-1 + df_shifted = df.shift(1) + df_ta = pd.DataFrame(index=df.index) # Create empty DF to store results aligned with original index + + try: + # Calculate returns first (use shifted close) + # Fill NaNs robustly before pct_change on the *shifted* data + close_filled = df_shifted["close"].bfill().ffill() + df_ta["return_1m"] = close_filled.pct_change() + df_ta["return_15m"] = close_filled.pct_change(15) + df_ta["return_60m"] = close_filled.pct_change(60) + + # Calculate TA features using ta library on *shifted* data + df_ta["ATR_14"] = AverageTrueRange(df_shifted['high'], df_shifted['low'], df_shifted['close'], window=14).average_true_range() + + # Daily volatility (use calculated 1m return) + df_ta["volatility_14d"] = ( + df_ta["return_1m"].rolling(60 * 24 * 14, min_periods=30).std() # rough 14d for 1‑min bars + ) + + # EMA 10 / 50 + MACD using ta library (on shifted close) + df_ta["EMA_10"] = EMAIndicator(df_shifted["close"], 10).ema_indicator() + df_ta["EMA_50"] = EMAIndicator(df_shifted["close"], 50).ema_indicator() + macd = MACD(df_shifted["close"], window_slow=26, window_fast=12, window_sign=9) + df_ta["MACD"] = macd.macd() + df_ta["MACD_signal"] = macd.macd_signal() + + # RSI 14 using ta library (on shifted close) + df_ta["RSI_14"] = RSIIndicator(df_shifted["close"], window=14).rsi() + + # Handle potential NaNs introduced by TA calculations + df_ta.bfill(inplace=True) + df_ta.ffill(inplace=True) + + # Add the calculated TA features back to the original df + ta_cols_to_add = [col for col in df_ta.columns if col not in df.columns] + for col in ta_cols_to_add: + df[col] = df_ta[col] + logger.info("Successfully added TA features.") + + except Exception as e: + logger.error(f"Error calculating TA features: {e}", exc_info=True) + + return df + + def add_base_features(self, df: pd.DataFrame) -> pd.DataFrame: + 
""" + Adds a standard set of base features: cyclical, imbalance, and TA. + + Args: + df (pd.DataFrame): Input DataFrame with OHLCV data and DatetimeIndex. + + Returns: + pd.DataFrame: DataFrame with added features. + """ + logger.info("--- Adding Base Features --- ") + df_out = df.copy() + df_out = self._add_cyclical_features(df_out) + df_out = self._add_imbalance_features(df_out) + df_out = self._add_ta_features(df_out) + + # Ensure minimal whitelist columns exist, fill with 0 if missing after calculation errors + for col in self.minimal_whitelist: + if col not in df_out.columns: + logger.warning(f"Minimal whitelist feature '{col}' not found after calculations. Adding column filled with 0.") + df_out[col] = 0.0 + + logger.info(f"Base feature engineering complete. DataFrame shape: {df_out.shape}") + return df_out + + def select_features(self, X_train_raw: pd.DataFrame, y_dir_train: pd.Series, vif_threshold: float = 10.0, logreg_c: float = 0.1) -> list: + """ + Performs feature selection using Logistic Regression (L1) followed by VIF filtering. + + Args: + X_train_raw (pd.DataFrame): Raw training features (should include minimal whitelist + others). + y_dir_train (pd.Series): Binary direction target for training. + vif_threshold (float): VIF threshold for multicollinearity removal. + logreg_c (float): Inverse of regularization strength for Logistic Regression. + + Returns: + list: The final list of selected feature names. + """ + logger.info("--- Selecting Features (LogReg L1 + VIF) ---") + + if X_train_raw is None or X_train_raw.empty or y_dir_train is None or y_dir_train.empty: + logger.error("Input data for feature selection is missing or empty. 
Returning minimal whitelist.") + return self.minimal_whitelist + + initial_features = X_train_raw.columns.tolist() + logger.info(f"Starting selection from {len(initial_features)} raw features.") + + final_whitelist = self.minimal_whitelist # Default to minimal if errors occur + + try: + # --- LogReg L1 Selection --- + logger.info(f"Performing Logistic Regression (L1, C={logreg_c}) selection...") + + # Handle potential NaN/Inf values robustly before fitting + X_train_processed = X_train_raw.copy() + X_train_processed.replace([np.inf, -np.inf], np.nan, inplace=True) + # Check if imputation is needed + if X_train_processed.isnull().any().any(): + logger.warning("NaNs detected in training features for LogReg. Imputing with column median.") + # Impute with median, could use other strategies (mean, IterativeImputer) + X_train_processed = X_train_processed.fillna(X_train_processed.median()) + # Check if any NaNs remain (e.g., all values were NaN) + if X_train_processed.isnull().any().any(): + logger.warning("NaNs still present after median imputation. 
Filling remaining NaNs with 0.") + X_train_processed.fillna(0, inplace=True) + + logreg_selector = LogisticRegression( + penalty='l1', + solver='liblinear', + C=logreg_c, + random_state=42, + max_iter=1000, + class_weight='balanced' # Added for potentially imbalanced targets + ) + + feature_selector = SelectFromModel(estimator=logreg_selector, prefit=False, threshold=-np.inf) # Select all features with non-zero coefficients + feature_selector.fit(X_train_processed, y_dir_train) + + selected_features_mask = feature_selector.get_support() + selected_features_names = X_train_raw.columns[selected_features_mask].tolist() + logger.info(f"Features selected by LogReg L1: {selected_features_names}") + + # Combine with the minimal whitelist + candidate_whitelist_set = set(self.minimal_whitelist).union(set(selected_features_names)) + candidate_whitelist = sorted(list(candidate_whitelist_set)) + logger.info(f"Candidate whitelist after LogReg ({len(candidate_whitelist)} features): {candidate_whitelist}") + + # --- VIF Filtering --- + logger.info(f"Performing VIF filtering (threshold={vif_threshold}) on candidate features...") + + # Prepare data subset for VIF calculation + X_vif = X_train_raw[candidate_whitelist].copy() + + # Handle NaN/Inf robustly before VIF + X_vif.replace([np.inf, -np.inf], np.nan, inplace=True) + if X_vif.isnull().any().any(): + logger.warning("NaNs detected in features for VIF. Imputing with median.") + X_vif = X_vif.fillna(X_vif.median()) + if X_vif.isnull().any().any(): + logger.warning("NaNs still present after median imputation for VIF. 
Filling remaining NaNs with 0.") + X_vif.fillna(0, inplace=True) + + features_after_vif = list(candidate_whitelist) # Start with the candidate list + + # Iteratively remove features with high VIF + while len(features_after_vif) > 1: + # Add constant for VIF calculation + X_vif_subset = X_vif[features_after_vif] + X_vif_const = sm.add_constant(X_vif_subset, prepend=False, has_constant='raise') + + try: + vif_data = pd.DataFrame() + vif_data["feature"] = X_vif_const.columns[:-1] # Exclude the added constant + vif_values = [variance_inflation_factor(X_vif_const.values, i) + for i in range(X_vif_const.shape[1] - 1)] + vif_data["VIF"] = vif_values + except Exception as vif_e: + logger.error(f"Error calculating VIF: {vif_e}. Stopping VIF filtering.") + # Return the features identified so far before the error + final_whitelist = features_after_vif + logger.warning(f"Returning {len(final_whitelist)} features identified before VIF error: {final_whitelist}") + return final_whitelist + + max_vif = vif_data["VIF"].max() + # Check for infinite VIF separately, often indicates perfect multicollinearity + if np.isinf(max_vif): + feature_to_remove = vif_data.loc[vif_data["VIF"].idxmax(), "feature"] + logger.warning(f"Removing feature '{feature_to_remove}' due to infinite VIF.") + elif max_vif > vif_threshold: + feature_to_remove = vif_data.loc[vif_data["VIF"].idxmax(), "feature"] + logger.info(f"Removing feature '{feature_to_remove}' due to high VIF ({max_vif:.2f} > {vif_threshold})...") + else: + logger.info(f"VIF filtering complete. Max VIF = {max_vif:.2f} <= {vif_threshold}.") + break # All VIFs are below threshold + + features_after_vif.remove(feature_to_remove) + + final_whitelist = features_after_vif + if len(final_whitelist) == 0: + logger.warning("VIF filtering removed all features. 
Falling back to minimal whitelist.") + final_whitelist = self.minimal_whitelist + else: + logger.info(f"Final whitelist after VIF filtering ({len(final_whitelist)} features): {final_whitelist}") + + except Exception as e: + logger.error(f"Feature selection or VIF filtering failed: {e}", exc_info=True) + logger.warning("Falling back to the minimal_whitelist.") + final_whitelist = self.minimal_whitelist + + # Ensure final whitelist is a list of strings + if not isinstance(final_whitelist, list): + final_whitelist = list(final_whitelist) + + return final_whitelist + + def prune_features(self, df: pd.DataFrame, whitelist: list) -> pd.DataFrame: + """ + Selects columns from a DataFrame based on a provided whitelist. + + Args: + df (pd.DataFrame): The input DataFrame. + whitelist (list): List of column names to keep. + + Returns: + pd.DataFrame: DataFrame containing only the whitelisted columns. + """ + if df is None or df.empty: + logger.warning("Input DataFrame for pruning is None or empty. Returning empty DataFrame.") + return pd.DataFrame() + + logger.debug(f"Pruning DataFrame (shape {df.shape}) to whitelist: {whitelist}") + + # Find columns present in both DataFrame and whitelist + cols_to_keep = [c for c in whitelist if c in df.columns] + missing_cols = [c for c in whitelist if c not in df.columns] + + if missing_cols: + logger.warning(f"The following whitelisted columns were not found in the DataFrame: {missing_cols}") + + if not cols_to_keep: + logger.error("No columns from the whitelist found in the DataFrame. 
_EPS = 1e-6


def add_imbalance_features(df: pd.DataFrame) -> pd.DataFrame:
    """Add Chaikin AD line, signed volume imbalance, gap imbalance."""
    # Bail out unchanged when any OHLCV column is missing.
    if not {"open", "high", "low", "close", "volume"}.issubset(df.columns):
        return df

    # Close-location value: where the close sits within the bar's range.
    bar_range = df["high"] - df["low"] + _EPS
    close_loc = ((df["close"] - df["low"]) - (df["high"] - df["close"])) / bar_range
    df["chaikin_AD_10"] = (close_loc * df["volume"]).rolling(10).sum()

    # Signed volume: positive on up bars, negative on down bars.
    up_bar = df["close"] >= df["open"]
    signed = pd.Series(np.where(up_bar, df["volume"], -df["volume"]), index=df.index)
    df["svi_10"] = signed.rolling(10).sum()

    # Gap imbalance: a gap only counts when volume spikes above 2x the median.
    med_vol = df["volume"].rolling(50).median()
    vol_spike = df["volume"] > 2 * med_vol
    gap_up = (df["low"] > df["high"].shift(1)) & vol_spike
    gap_dn = (df["high"] < df["low"].shift(1)) & vol_spike
    df["gap_imbalance"] = gap_up.astype(int) - gap_dn.astype(int)

    # Zero-fill warm-up NaNs left by the rolling windows (applies frame-wide).
    df.fillna(0, inplace=True)
    return df
talib.EMA(df_copy["close"], timeperiod=50) + df_copy["EMA_10"] = EMAIndicator(df_copy["close"], 10).ema_indicator() + df_copy["EMA_50"] = EMAIndicator(df_copy["close"], 50).ema_indicator() + # talib.MACD returns macd, macdsignal, macdhist + # macd, macdsignal, macdhist = talib.MACD(df_copy["close"], fastperiod=12, slowperiod=26, signalperiod=9) + macd = MACD(df_copy["close"], window_slow=26, window_fast=12, window_sign=9) + df_copy["MACD"] = macd.macd() + df_copy["MACD_signal"] = macd.macd_signal() + + # RSI 14 using ta library + # df_copy["RSI_14"] = talib.RSI(df_copy["close"], timeperiod=14) + df_copy["RSI_14"] = RSIIndicator(df_copy["close"], window=14).rsi() + + # Cyclical hour already recommended to add upstream (data_pipeline). + + # Handle potential NaNs introduced by TA calculations + # df.fillna(method="bfill", inplace=True) # Deprecated + df_copy.bfill(inplace=True) + df_copy.ffill(inplace=True) # Add ffill for any remaining NaNs at the beginning + + return df_copy + + +# ------------------------------------------------------------------ +# Pruning & whitelist +# ------------------------------------------------------------------ + +minimal_whitelist = [ + "return_1m", + "return_15m", + "return_60m", + "ATR_14", + "volatility_14d", + "chaikin_AD_10", + "svi_10", + "EMA_10", + "EMA_50", + "MACD", + "MACD_signal", + "hour_sin", + "hour_cos", +] + + +def prune_features(df: pd.DataFrame, whitelist: list[str] | None = None) -> pd.DataFrame: + """Return DataFrame containing only *whitelisted* columns.""" + if whitelist is None: + whitelist = minimal_whitelist + # Find columns present in both DataFrame and whitelist + cols_to_keep = [c for c in whitelist if c in df.columns] + # Ensure the set of kept columns exactly matches the intersection + df_pruned = df[cols_to_keep].copy() + assert set(df_pruned.columns) == set(cols_to_keep), \ + f"Pruning failed: Output columns {set(df_pruned.columns)} != Expected {set(cols_to_keep)}" + # Optional: Assert against the full 
whitelist if input is expected to always contain all + # assert set(df_pruned.columns) == set(whitelist), \ + # f"Pruning failed: Output columns {set(df_pruned.columns)} != Full whitelist {set(whitelist)}" + return df_pruned \ No newline at end of file diff --git a/gru_sac_predictor/src/gru_model_handler.py b/gru_sac_predictor/src/gru_model_handler.py new file mode 100644 index 00000000..06d4aa27 --- /dev/null +++ b/gru_sac_predictor/src/gru_model_handler.py @@ -0,0 +1,217 @@ +""" +Handles GRU Model Training, Loading, Saving, and Prediction. +""" + +import tensorflow as tf +from tensorflow.keras import Model, callbacks +from tensorflow.keras import saving # Use saving module for custom object registration +import numpy as np +import os +import logging +from tqdm.keras import TqdmCallback +from typing import Dict, Tuple, Any + +# Import necessary components from model_gru +# It's better if build_gru_model and gaussian_nll are defined here +# or imported reliably. Assuming they stay in model_gru for now. +try: + from .model_gru import build_gru_model, gaussian_nll +except ImportError: + # If run directly, might need adjustment or place build_gru_model here + logging.error("Failed to import build_gru_model/gaussian_nll from .model_gru. 
# Fallback definition, used only when the import from .model_gru failed.
# NOTE(review): originally this def lived inside the `except ImportError:`
# handler above; the globals() guard reproduces "define only if absent".
if 'gaussian_nll' not in globals():
    @saving.register_keras_serializable(package='GRU')
    def gaussian_nll(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """Gaussian negative log-likelihood over a (mu, log_sigma) output head."""
        mu, log_sigma = tf.split(y_pred, 2, axis=-1)
        y_true_shaped = tf.reshape(y_true, tf.shape(mu))
        inv_var = tf.exp(-2.0 * log_sigma)
        nll = 0.5 * inv_var * tf.square(y_true_shaped - mu) + log_sigma
        return tf.reduce_mean(nll)

logger = logging.getLogger(__name__)


class GRUModelHandler:
    """Manages the lifecycle of the GRU model."""

    def __init__(self, run_id: str, models_dir: str):
        """
        Initialize the handler.

        Args:
            run_id (str): The current pipeline run ID.
            models_dir (str): The base directory where models for this run are saved.
        """
        self.run_id = run_id
        self.models_dir = models_dir  # directory specific to this run
        self.model: Model | None = None
        logger.info(f"GRUModelHandler initialized for run {run_id} in {models_dir}")

    def train(
        self,
        X_train: np.ndarray,
        y_train_dict: Dict[str, np.ndarray],
        X_val: np.ndarray,
        y_val_dict: Dict[str, np.ndarray],
        lookback: int,
        n_features: int,
        max_epochs: int = 25,
        batch_size: int = 128,
        patience: int = 3
    ) -> Tuple[Model | None, Any]:
        """
        Builds and trains the GRU model.

        Args:
            X_train: Training feature sequences.
            y_train_dict: Dictionary of training targets.
            X_val: Validation feature sequences.
            y_val_dict: Dictionary of validation targets.
            lookback: Sequence length.
            n_features: Number of features per timestep.
            max_epochs: Maximum training epochs.
            batch_size: Training batch size.
            patience: Early stopping patience.

        Returns:
            Tuple[Model | None, Any]: The trained Keras model (or None on failure)
            and the training history object.
        """
        logger.info(f"Building GRU model: lookback={lookback}, n_features={n_features}")
        try:
            # Guard against the import fallback having left build_gru_model undefined.
            if 'build_gru_model' not in globals() and 'build_gru_model' not in locals():
                raise NameError("build_gru_model function is not defined or imported.")
            self.model = build_gru_model(lookback, n_features)
            logger.info("Model built successfully.")
            self.model.summary(print_fn=logger.info)
        except Exception as e:
            logger.error(f"Failed to build GRU model: {e}", exc_info=True)
            return None, None

        early_stop = callbacks.EarlyStopping(
            monitor="val_loss",  # overall validation loss across heads
            patience=patience,
            mode='min',
            restore_best_weights=True,
            verbose=1,
        )
        progress = TqdmCallback(verbose=1)

        logger.info(f"Starting GRU training: epochs={max_epochs}, batch={batch_size}, patience={patience}")
        logger.info(f" Train X shape: {X_train.shape}")
        logger.info(f" Val X shape: {X_val.shape}")
        logger.info(f" Train y keys: {list(y_train_dict.keys())}")
        logger.info(f" Val y keys: {list(y_val_dict.keys())}")

        history = None
        try:
            history = self.model.fit(
                X_train,
                y_train_dict,
                validation_data=(X_val, y_val_dict),
                epochs=max_epochs,
                batch_size=batch_size,
                callbacks=[early_stop, progress],
                verbose=0,  # tqdm renders progress
            )
            logger.info("GRU training finished.")
            logger.info(f"Best validation loss: {min(history.history['val_loss']):.4f}")
            return self.model, history
        except Exception as e:
            logger.error(f"Error during GRU model training: {e}", exc_info=True)
            # history may still be useful to the caller even when fit() aborted
            return None, history
+ """ + if self.model is None: + logger.error("No model available to save.") + return None + + # Use .keras format for modern saving + save_path = os.path.join(self.models_dir, f"{model_name}_{self.run_id}.keras") + try: + self.model.save(save_path) + logger.info(f"GRU model saved successfully to: {save_path}") + return save_path + except Exception as e: + logger.error(f"Failed to save GRU model to {save_path}: {e}", exc_info=True) + return None + + def load(self, model_path: str) -> Model | None: + """ + Loads a GRU model from the specified path. + Handles the custom gaussian_nll loss function. + + Args: + model_path (str): The full path to the saved Keras model file. + + Returns: + Model | None: The loaded Keras model, or None if loading failed. + """ + if not os.path.exists(model_path): + logger.error(f"Model file not found at: {model_path}") + return None + + logger.info(f"Loading GRU model from: {model_path}") + try: + # Ensure gaussian_nll is registered if it wasn't globally + custom_objects = {} + if 'gaussian_nll' in globals() or 'gaussian_nll' in locals(): + custom_objects['gaussian_nll'] = gaussian_nll + else: + # This case should ideally not happen if imports/fallbacks work + logger.warning("gaussian_nll custom object not found during load. Model might fail if it uses it.") + + # Load using custom_objects dictionary + self.model = tf.keras.models.load_model(model_path, custom_objects=custom_objects) + logger.info("GRU model loaded successfully.") + self.model.summary(print_fn=logger.info) # Log summary of loaded model + return self.model + except Exception as e: + logger.error(f"Failed to load GRU model from {model_path}: {e}", exc_info=True) + self.model = None + return None + + def predict(self, X_data: np.ndarray, batch_size: int = 1024) -> Any: + """ + Generates predictions using the loaded/trained model. + + Args: + X_data (np.ndarray): Input data sequences (shape: [n_samples, lookback, n_features]). + batch_size (int): Batch size for prediction. 
+ + Returns: + Any: The model's predictions (typically a list of numpy arrays for multi-output models). + Returns None if no model is available or prediction fails. + """ + if self.model is None: + logger.error("No model available for prediction.") + return None + if X_data is None or len(X_data) == 0: + logger.warning("Input data for prediction is None or empty.") + return None + + logger.info(f"Generating predictions for {len(X_data)} samples...") + try: + predictions = self.model.predict(X_data, batch_size=batch_size) + logger.info("Predictions generated successfully.") + # Log shapes of predictions if it's a list (multi-output) + if isinstance(predictions, list): + pred_shapes = [p.shape for p in predictions] + logger.debug(f"Prediction output shapes: {pred_shapes}") + else: + logger.debug(f"Prediction output shape: {predictions.shape}") + return predictions + except Exception as e: + logger.error(f"Error during model prediction: {e}", exc_info=True) + return None \ No newline at end of file diff --git a/gru_sac_predictor/src/gru_predictor.py b/gru_sac_predictor/src/gru_predictor.py deleted file mode 100644 index 09e3edd0..00000000 --- a/gru_sac_predictor/src/gru_predictor.py +++ /dev/null @@ -1,646 +0,0 @@ -""" -V7 - GRU Model for Cryptocurrency Price Prediction (Adapted from V6) - -This module implements a GRU-based neural network model for predicting -cryptocurrency prices directly (regression), calculating uncertainty via -Monte Carlo dropout, and deriving predicted returns. 
-""" -import numpy as np -import tensorflow as tf -import os -import joblib -import logging -import time -import sys -from sklearn.preprocessing import MinMaxScaler, StandardScaler # Keep both options -from sklearn.metrics import mean_absolute_error, mean_squared_error, accuracy_score, classification_report, confusion_matrix -from tensorflow.keras.models import Sequential, load_model, Model # Keep Model if needed -from tensorflow.keras.layers import GRU, Dropout, Dense, Input # Keep Input if needed -from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau -from tensorflow.keras.optimizers import Adam -import matplotlib.pyplot as plt - -# Configure logging (ensure it doesn't conflict with other loggers) -gru_logger = logging.getLogger(__name__) -# Avoid adding handlers if they already exist from a root config -if not gru_logger.hasHandlers(): - gru_logger.setLevel(logging.INFO) - # Check if root logger has handlers to avoid duplicate console output - root_logger = logging.getLogger() - if not root_logger.hasHandlers() or all(isinstance(h, logging.FileHandler) for h in root_logger.handlers): - console_handler = logging.StreamHandler(sys.stdout) - console_handler.setLevel(logging.INFO) - formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') - console_handler.setFormatter(formatter) - gru_logger.addHandler(console_handler) - gru_logger.propagate = False # Prevent propagation to root if we added a handler - else: - gru_logger.propagate = True # Propagate to root if it has handlers - -# Test log message -gru_logger.info("GRU predictor V7 (V6 Adaptation) logger initialized.") - -# --- Helper: Predict in Batches --- -def predict_in_batches(model, X, batch_size=1024): - """Process predictions in batches to avoid OOM errors""" - n_samples = X.shape[0] - n_batches = (n_samples + batch_size - 1) // batch_size - predictions = [] - for i in range(n_batches): - start_idx = i * batch_size - end_idx = min((i + 1) * 
batch_size, n_samples) - # Use training=False for standard prediction - batch_predictions = model(X[start_idx:end_idx], training=False).numpy() - predictions.append(batch_predictions) - if n_batches > 10 and (i+1) % max(1, n_batches//10) == 0: - gru_logger.info(f"Prediction batch progress: {i+1}/{n_batches}") - return np.vstack(predictions) - - -class CryptoGRUModel: - """ - GRU-based model for cryptocurrency price prediction, adapted from V6. - Predicts price directly and calculates MC uncertainty. - """ - - def __init__(self, model_dir=None): - """ - Initialize the GRU regression model. - - Args: - model_dir (str, optional): Directory to load pre-trained model and scalers. - """ - gru_logger.info("Initializing V7 CryptoGRUModel (V6 Adaptation)...") - self.model: tf.keras.Model = None - self.feature_scaler: StandardScaler | MinMaxScaler = None # Allow either - self.y_scaler: MinMaxScaler = None # V6 used MinMaxScaler for target - self.model_dir = model_dir # Store model dir for saving/loading convenience - self.is_trained = False - self.is_loaded = False - - if model_dir: - self.load(model_dir) - - def _build_model(self, input_data): - """ - Build the GRU model architecture dynamically from input data dimensions, - exactly matching the V6 implementation. - - Args: - input_data (np.array): Training data, used to determine input shape. - - Returns: - tf.keras.Model: Compiled GRU model matching V6. 
- """ - gru_logger.info("Building V6 GRU model architecture...") - # Determine input shape from data - input_shape = (input_data.shape[1], input_data.shape[2]) - - # Create Sequential model to match V6 implementation - model = Sequential(name="V6_GRU_Regression") - - # Add explicit Input layer as recommended - model.add(Input(shape=input_shape, name="input_layer")) - - # Add single GRU layer (100 units) - remove input_shape argument - model.add(GRU(100, name="gru_100")) - - # Add Dropout layer (V6 used 0.2) - model.add(Dropout(0.2, name="dropout_0.2")) - - # Output layer with linear activation for price prediction - model.add(Dense(1, activation='linear', name="output_price")) # output_dim = 1 - - # Compile the model using MSE loss as in V6 - model.compile( - optimizer=Adam(learning_rate=0.001), # V6 used 0.001 LR - loss='mse', # V6 used MSE - metrics=['mae', 'mape', tf.keras.metrics.RootMeanSquaredError(name='rmse')] # V6 metrics - ) - - gru_logger.info(f"V6 Model built: 1 GRU layer (100 units), Dropout (0.2), Linear Output.") - gru_logger.info(f"Input shape: {(input_data.shape[1], input_data.shape[2])}") - model.summary(print_fn=gru_logger.info) - - return model - - def train(self, X_train, y_train_scaled, X_val, y_val_scaled, - feature_scaler, y_scaler, # Pass fitted scalers - batch_size=32, epochs=20, patience=10, - model_save_dir='models/gru_predictor_trained'): - """ - Train the GRU regression model using V6 parameters. - Assumes sequences (X) and scaled targets (y) are provided. - Fitted scalers must also be provided for saving. - - Args: - X_train (np.array): Training sequences. - y_train_scaled (np.array): Scaled training target prices (shape: [samples, 1]). - X_val (np.array): Validation sequences. - y_val_scaled (np.array): Scaled validation target prices (shape: [samples, 1]). - feature_scaler: Fitted feature scaler (StandardScaler or MinMaxScaler). - y_scaler: Fitted target scaler (MinMaxScaler). 
- batch_size (int): Batch size for training (default: 32 as per V6). - epochs (int): Maximum number of epochs (default: 20 as per V6). - patience (int): Patience for early stopping (default: 10). - model_save_dir (str): Directory to save the best model and scalers. - - Returns: - dict: Training history. - """ - gru_logger.info("--- Starting GRU Model Training (V6 Adaptation) ---") - - # Store scalers - if feature_scaler is None or y_scaler is None: - gru_logger.error("Fitted scalers must be provided for training.") - return None - self.feature_scaler = feature_scaler - self.y_scaler = y_scaler - - # Build model dynamically based on input data shape if not already loaded/built - if self.model is None: - gru_logger.info(f"Building model dynamically from input data with shape {X_train.shape}") - self.model = self._build_model(X_train) - - # Set up callbacks (monitoring val_rmse as in V6) - os.makedirs(model_save_dir, exist_ok=True) - best_model_path = os.path.join(model_save_dir, "best_model_reg.keras") - - gru_logger.info("Monitoring val_rmse for EarlyStopping and ModelCheckpoint.") - callbacks = [ - EarlyStopping( - monitor='val_rmse', patience=patience, restore_best_weights=True, - verbose=1, mode='min' - ), - ModelCheckpoint( - filepath=best_model_path, monitor='val_rmse', save_best_only=True, - verbose=1, mode='min' - ), - ReduceLROnPlateau( - monitor='val_rmse', factor=0.5, patience=patience // 2, # Reduce LR faster - min_lr=1e-6, verbose=1, mode='min' - ) - ] - - # Record start time - start_time = time.time() - gru_logger.info(f"Starting training: epochs={epochs}, batch_size={batch_size}, patience={patience}") - - # --- Log first input sequence and target price --- - if len(X_train) > 0 and len(y_train_scaled) > 0: - gru_logger.info(f"First training sequence (X_train[0]) shape: {X_train[0].shape}") - log_steps = min(5, X_train.shape[1]) - gru_logger.info(f"First {log_steps} steps of first training sequence (scaled):\n{X_train[0][:log_steps]}") - 
gru_logger.info(f"Target scaled price for first sequence (y_train_scaled[0]): {y_train_scaled[0]:.6f}") - else: - gru_logger.warning("X_train or y_train_scaled is empty, cannot log first sequence.") - - # Train the model - history = self.model.fit( - X_train, y_train_scaled, - validation_data=(X_val, y_val_scaled), - batch_size=batch_size, - epochs=epochs, - callbacks=callbacks, - verbose=1, - ) - - # Calculate training duration - training_duration = time.time() - start_time - hours, remainder = divmod(training_duration, 3600) - minutes, seconds = divmod(remainder, 60) - gru_logger.info(f"Training completed in {int(hours)}h {int(minutes)}m {int(seconds)}s") - - # Load the best model saved by ModelCheckpoint - if os.path.exists(best_model_path): - gru_logger.info(f"Loading best regression model from {best_model_path}") - # No custom objects needed for standard MSE loss and metrics - self.model = load_model(best_model_path) - self.is_trained = True - self.is_loaded = False # Trained in this session - # Save scalers alongside the best model - self.save(model_save_dir) # Save model and scalers - else: - gru_logger.warning(f"Best model file not found at {best_model_path}. Using the final state.") - self.is_trained = True - self.is_loaded = False - # Save the final state model and scalers - self.save(model_save_dir) - - return history.history - - def predict_scaled_price(self, X): - """ - Make scaled price predictions with the trained regression model. - - Args: - X (np.array): Input sequences. - - Returns: - np.array: Predicted scaled prices (shape: [samples, 1]). - """ - if self.model is None: - gru_logger.error("Model not loaded or trained. 
Cannot predict.") - raise ValueError("Model is not available for prediction.") - - gru_logger.info(f"Predicting scaled prices on data with shape {X.shape}") - return self.model.predict(X) - # Consider using predict_in_batches for large X - - - def evaluate(self, X_test, y_test_scaled, y_start_price_test, n_mc_samples=30): - """ - Evaluate the regression model on test data, calculate uncertainty, - and derive predicted returns and confidence. Mirrors V6 evaluation logic. - - Args: - X_test (np.array): Test sequences. - y_test_scaled (np.array): Scaled true target prices (shape: [samples, 1]). - y_start_price_test (np.array): Unscaled price at the start of each test target window (shape: [samples] or [samples, 1]). - n_mc_samples (int): Number of Monte Carlo samples for uncertainty estimation. - - Returns: - dict: Dictionary containing evaluation results: - 'pred_percent_change': Predicted % change based on unscaled price predictions. - 'raw_confidence_score': Confidence score (1 - normalized MC std dev). - 'predicted_unscaled_prices': Unscaled price predictions. - 'mae': Mean Absolute Error (unscaled). - 'rmse': Root Mean Squared Error (unscaled). - 'mape': Mean Absolute Percentage Error (unscaled). - 'mc_unscaled_std_dev': Unscaled standard deviation from MC dropout. - 'misc_metrics': Dict with scaled metrics if needed. - """ - gru_logger.info("--- Starting Model Evaluation (V6 Adaptation) ---") - if self.model is None or self.y_scaler is None: - gru_logger.error("Model or y_scaler not available for evaluation.") - return None - if X_test is None or y_test_scaled is None or y_start_price_test is None: - gru_logger.error("Missing data for evaluation (X_test, y_test_scaled, or y_start_price_test).") - return None - - # --- 1. Standard Prediction --- - gru_logger.info(f"Predicting on Test set (Standard) with shape {X_test.shape}") - y_pred_test_standard_scaled = predict_in_batches(self.model, X_test) # Use batch predictor - - # --- 2. 
Monte Carlo Dropout Prediction --- - gru_logger.info(f"Running Monte Carlo dropout inference ({n_mc_samples} samples)...") - mc_preds_test_list = [] - - # Define the prediction step with training=True inside the loop - @tf.function - def mc_predict_step_test(batch): - return self.model(batch, training=True) # Enable dropout - - for i in range(n_mc_samples): - mc_preds = [] - # Batch processing within MC loop - for j in range(0, len(X_test), 1024): - batch = X_test[j:j+1024] - batch_preds = mc_predict_step_test(tf.constant(batch, dtype=tf.float32)).numpy() - mc_preds.append(batch_preds) - mc_preds_test_list.append(np.vstack(mc_preds)) - if n_mc_samples > 1 and (i+1) % max(1, n_mc_samples//5) == 0: - gru_logger.info(f" MC progress: {i+1}/{n_mc_samples}") - - mc_preds_stack = np.stack(mc_preds_test_list) - y_pred_test_mc_std_scaled = np.std(mc_preds_stack, axis=0) - gru_logger.info(f"MC dropout completed. Scaled Std Dev: Min={np.min(y_pred_test_mc_std_scaled):.6f}, Max={np.max(y_pred_test_mc_std_scaled):.6f}, Mean={np.mean(y_pred_test_mc_std_scaled):.6f}") - - # --- 3. Inverse Transform --- - gru_logger.info("Inverse transforming predictions and true values...") - try: - y_pred_test_standard_unscaled = self.y_scaler.inverse_transform(y_pred_test_standard_scaled).flatten() - # Reshape y_test_scaled to 2D before inverse_transform - if y_test_scaled.ndim == 1: - y_test_scaled_2d = y_test_scaled.reshape(-1, 1) - else: - y_test_scaled_2d = y_test_scaled # Assume it's already 2D if not 1D - y_test_true_unscaled = self.y_scaler.inverse_transform(y_test_scaled_2d).flatten() - except Exception as e: - gru_logger.error(f"Error during inverse transform: {e}", exc_info=True) - return None - - # --- 4. 
Unscale MC Standard Deviation --- - gru_logger.info("Unscaling MC standard deviation...") - y_pred_test_mc_std_unscaled_flat = np.zeros_like(y_pred_test_mc_std_scaled.flatten()) # Default - try: - # Use scaler's data range if available (more robust for MinMaxScaler) - if hasattr(self.y_scaler, 'data_min_') and hasattr(self.y_scaler, 'data_max_'): - data_range = self.y_scaler.data_max_[0] - self.y_scaler.data_min_[0] - if data_range > 1e-9: - y_pred_test_mc_std_unscaled_flat = y_pred_test_mc_std_scaled.flatten() * data_range - gru_logger.info(f"Unscaled MC std dev (Test): Mean={np.mean(y_pred_test_mc_std_unscaled_flat):.6f}") - else: - gru_logger.warning("Scaler data range is near zero. Cannot reliably unscale MC std dev.") - else: - gru_logger.warning("y_scaler missing data_min_/data_max_. Cannot unscale MC std dev accurately.") - except Exception as e: - gru_logger.error(f"Error unscaling MC std dev: {e}", exc_info=True) - # Keep the default of zeros or potentially use scaled std dev as fallback? - - # --- 5. Calculate Raw Confidence Score --- - gru_logger.info("Calculating raw confidence score (1 - normalized std dev)...") - test_raw_confidence_score = np.ones_like(y_pred_test_mc_std_unscaled_flat) * 0.5 # Default if std dev is constant - epsilon = 1e-9 - max_std_dev_test = np.max(y_pred_test_mc_std_unscaled_flat) + epsilon - min_std_dev_test = np.min(y_pred_test_mc_std_unscaled_flat) - - if max_std_dev_test > min_std_dev_test: - # Normalize the UNscaled std dev between 0 and 1 - normalized_std_dev = (y_pred_test_mc_std_unscaled_flat - min_std_dev_test) / (max_std_dev_test - min_std_dev_test) - # Confidence is inverse of normalized uncertainty - test_raw_confidence_score = 1.0 - normalized_std_dev - else: - gru_logger.warning("MC standard deviation is constant or near-constant. 
Setting raw confidence to 0.5.") - - test_raw_confidence_score = np.clip(test_raw_confidence_score, 0.0, 1.0) # Ensure bounds - gru_logger.info(f"Raw Confidence score: Min={np.min(test_raw_confidence_score):.4f}, Max={np.max(test_raw_confidence_score):.4f}, Mean={np.mean(test_raw_confidence_score):.4f}") - - # --- 6. Calculate Predicted Percentage Change --- - gru_logger.info("Calculating predicted percentage change...") - # Ensure y_start_price_test is flattened and aligned - y_start_price_flat = y_start_price_test.flatten() - if len(y_pred_test_standard_unscaled) != len(y_start_price_flat): - gru_logger.error(f"Length mismatch: Pred Price ({len(y_pred_test_standard_unscaled)}) vs Start Price ({len(y_start_price_flat)})") - # Attempt to align if possible (e.g., maybe y_start_price_test has extra initial values) - if len(y_start_price_flat) > len(y_pred_test_standard_unscaled): - diff = len(y_start_price_flat) - len(y_pred_test_standard_unscaled) - gru_logger.warning(f"Attempting alignment by trimming {diff} elements from start_price array.") - y_start_price_flat = y_start_price_flat[diff:] # Trim from the beginning? Or end? Needs context. Assume end. - # y_start_price_flat = y_start_price_flat[:-diff] # Trim from end - - if len(y_pred_test_standard_unscaled) != len(y_start_price_flat): - gru_logger.error("Alignment failed. 
Cannot calculate predicted change.") - pred_percent_change = np.zeros_like(y_pred_test_standard_unscaled) # Fallback - else: - gru_logger.info("Alignment successful.") - pred_percent_change = np.where( - np.abs(y_start_price_flat) > epsilon, - (y_pred_test_standard_unscaled / y_start_price_flat) - 1, - 0 # Assign 0 change if start price is near zero - ) - else: - pred_percent_change = np.where( - np.abs(y_start_price_flat) > epsilon, - (y_pred_test_standard_unscaled / y_start_price_flat) - 1, - 0 # Assign 0 change if start price is near zero - ) - gru_logger.info(f"Predicted Percent Change: Min={np.min(pred_percent_change):.4f}, Max={np.max(pred_percent_change):.4f}, Mean={np.mean(pred_percent_change):.4f}") - - # --- 7. Calculate Regression Metrics --- - gru_logger.info("Calculating regression metrics (on unscaled data)...") - eval_mae = mean_absolute_error(y_test_true_unscaled, y_pred_test_standard_unscaled) - eval_mse = mean_squared_error(y_test_true_unscaled, y_pred_test_standard_unscaled) - eval_rmse = np.sqrt(eval_mse) - mask = y_test_true_unscaled != 0 - eval_mape = np.mean(np.abs((y_test_true_unscaled[mask] - y_pred_test_standard_unscaled[mask]) / y_test_true_unscaled[mask])) * 100 if np.any(mask) else 0.0 - - gru_logger.info(f"Test MAE (Unscaled): {eval_mae:.4f}") - gru_logger.info(f"Test RMSE (Unscaled): {eval_rmse:.4f}") - gru_logger.info(f"Test MAPE (Unscaled): {eval_mape:.4f}%") - - # --- 8. Return Results --- - results = { - 'pred_percent_change': pred_percent_change, - 'raw_confidence_score': test_raw_confidence_score, - 'predicted_unscaled_prices': y_pred_test_standard_unscaled, - 'true_unscaled_prices': y_test_true_unscaled, # Include true values for plotting/analysis - 'mae': eval_mae, - 'rmse': eval_rmse, - 'mape': eval_mape, - 'mc_unscaled_std_dev': y_pred_test_mc_std_unscaled_flat, - # Add derived direction accuracy for comparison? 
- # 'derived_direction_accuracy': accuracy_score(np.sign(y_test_true_unscaled - y_start_price_flat), np.sign(y_pred_test_standard_unscaled - y_start_price_flat)) - } - gru_logger.info("--- Evaluation Completed ---") - return results - - - def save(self, model_dir): - """ - Save the trained regression model and scalers. - - Args: - model_dir (str): Directory to save artifacts. - """ - if not (self.is_trained or self.is_loaded): - gru_logger.error("Cannot save, model not trained/loaded.") - return - if self.model is None or self.feature_scaler is None or self.y_scaler is None: - gru_logger.error("Cannot save, model or scalers missing.") - return - - os.makedirs(model_dir, exist_ok=True) - model_path = os.path.join(model_dir, "best_model_reg.keras") # V6 name - feature_scaler_path = os.path.join(model_dir, "feature_scaler.joblib") - y_scaler_path = os.path.join(model_dir, "y_scaler.joblib") - - try: - self.model.save(model_path) - gru_logger.info(f"Keras model saved to {model_path}") - joblib.dump(self.feature_scaler, feature_scaler_path) - gru_logger.info(f"Feature scaler saved to {feature_scaler_path}") - joblib.dump(self.y_scaler, y_scaler_path) - gru_logger.info(f"Target scaler (MinMaxScaler) saved to {y_scaler_path}") - except Exception as e: - gru_logger.error(f"Error saving model/scalers to {model_dir}: {e}", exc_info=True) - - def load(self, model_dir): - """ - Load a previously trained regression model and its scalers. - - Args: - model_dir (str): Directory containing model (.keras) and scalers (.joblib). 
- """ - gru_logger.info(f"Attempting to load V6-style GRU regression model and scalers from: {model_dir}") - self.model_dir = model_dir - model_path = os.path.join(model_dir, 'best_model_reg.keras') # V7.20 Fix: Load correct filename - scaler_feature_path = os.path.join(model_dir, 'feature_scaler.joblib') - scaler_y_path = os.path.join(model_dir, 'y_scaler.joblib') - - # Check if all required files exist - files_exist = all(os.path.exists(p) for p in [model_path, scaler_feature_path, scaler_y_path]) - - if not files_exist: - gru_logger.warning(f"Cannot load model. Required files missing in {model_dir}.") - gru_logger.warning(f" Missing: {[p for p in [model_path, scaler_feature_path, scaler_y_path] if not os.path.exists(p)]}") - self.is_loaded = False - return False - - try: - # Load Keras model - gru_logger.info(f"Loading GRU model from: {model_path}") - # Load without compiling if optimizer state is not needed or causes issues - self.model = load_model(model_path, compile=False) - - # Load scalers - gru_logger.info(f"Loading feature scaler from: {scaler_feature_path}") - self.feature_scaler = joblib.load(scaler_feature_path) - gru_logger.info(f"Loading target scaler from: {scaler_y_path}") - self.y_scaler = joblib.load(scaler_y_path) - - self.is_loaded = True - self.is_trained = False # Loaded, not trained in this session - gru_logger.info(f"CryptoGRUModel loaded successfully from {model_dir}.") - - # V7.9 Explicitly build after loading to be safe - if self.model and hasattr(self.model, 'input_shape'): - gru_logger.info("Triggering model build after loading...") - dummy_input = tf.zeros((1,) + self.model.input_shape[1:], dtype=tf.float32) - _ = self.model(dummy_input) - gru_logger.info("Model built after loading.") - else: - gru_logger.warning("Could not get input shape to explicitly build model after loading.") - - return True - except Exception as e: - gru_logger.error(f"Error loading model/scalers from {model_dir}: {e}", exc_info=True) - self.model = None; 
self.feature_scaler = None; self.y_scaler = None - self.is_trained = False; self.is_loaded = False - return False - - def plot_training_history(self, history, save_path=None): - """ - Plot training history (Loss, MAE, RMSE). Adapted from V6. - - Args: - history (dict): Training history from model.fit(). - save_path (str, optional): Path to save the plot. - """ - if not history: - gru_logger.warning("No history data provided to plot.") - return - - plt.figure(figsize=(18, 5)) - - # Plot training & validation loss (MSE) - plt.subplot(1, 3, 1) - plt.plot(history.get('loss', []), label='Train Loss (MSE)') - plt.plot(history.get('val_loss', []), label='Val Loss (MSE)') - plt.title('Model Loss (MSE)') - plt.ylabel('Loss') - plt.xlabel('Epoch') - plt.legend(loc='upper right') - plt.grid(True) - - # Plot training & validation MAE - plt.subplot(1, 3, 2) - plt.plot(history.get('mae', []), label='Train MAE') - plt.plot(history.get('val_mae', []), label='Val MAE') - plt.title('Model Mean Absolute Error') - plt.ylabel('MAE') - plt.xlabel('Epoch') - plt.legend(loc='upper right') - plt.grid(True) - - # Plot training & validation RMSE - plt.subplot(1, 3, 3) - plt.plot(history.get('rmse', []), label='Train RMSE') - plt.plot(history.get('val_rmse', []), label='Val RMSE') - plt.title('Model Root Mean Squared Error') - plt.ylabel('RMSE') - plt.xlabel('Epoch') - plt.legend(loc='upper right') - plt.grid(True) - - plt.tight_layout() - - if save_path: - try: - plt.savefig(save_path) - gru_logger.info(f"Training history plot saved to {save_path}") - except Exception as e: - gru_logger.error(f"Error saving training history plot: {e}") - else: - plt.show() # Display plot if not saving - - plt.close() - - def plot_evaluation_results(self, eval_results, save_path=None): - """ - Plot evaluation results for the REGRESSION model using data from evaluate(). - Adapted from V6. - - Args: - eval_results (dict): Dictionary returned by the evaluate() method. 
- save_path (str, optional): Path to save the plots. - """ - if not eval_results: - gru_logger.warning("No evaluation results provided to plot.") - return - - y_true = eval_results.get('true_unscaled_prices') - y_pred = eval_results.get('predicted_unscaled_prices') - mae = eval_results.get('mae', -1) - mape = eval_results.get('mape', -1) - confidence = eval_results.get('raw_confidence_score') - uncertainty_std = eval_results.get('mc_unscaled_std_dev') - - if y_true is None or y_pred is None: - gru_logger.error("Missing true or predicted prices in evaluation results for plotting.") - return - - plt.figure(figsize=(15, 12)) # Adjusted size - - # Plot 1: True vs. Predicted Prices with Uncertainty Bands - plt.subplot(3, 1, 1) # Changed to 3 rows - plt.plot(y_true, label='True Prices', alpha=0.7) - plt.plot(y_pred, label=f'Predicted Prices (MAE: {mae:.2f})', alpha=0.7) - if uncertainty_std is not None: - plt.fill_between( - range(len(y_pred)), - y_pred - uncertainty_std, - y_pred + uncertainty_std, - color='orange', alpha=0.2, label='Uncertainty (MC Std Dev)' - ) - plt.title(f"True vs. 
Predicted Prices (MAPE: {mape:.2f}%)") - plt.ylabel('Price') - plt.legend() - plt.grid(True) - - # Plot 2: Prediction Errors (Residuals) - plt.subplot(3, 1, 2) - errors = y_true - y_pred - plt.hist(errors, bins=50, alpha=0.7) - plt.title(f"Prediction Errors (Residuals) - MAE: {mae:.4f}") - plt.xlabel('Error (True - Predicted)') - plt.ylabel('Frequency') - plt.grid(True) - - # Plot 3: Confidence Scores - plt.subplot(3, 1, 3) - if confidence is not None: - plt.plot(confidence, label='Raw Confidence Score', color='green') - plt.title(f"Confidence Score (Mean: {np.mean(confidence):.3f})") - plt.ylabel("Confidence (0-1)") - plt.ylim(0, 1.05) - plt.legend() - else: - plt.title("Confidence Score (Not Available)") - plt.xlabel('Time Step (Test Set)') - plt.grid(True) - - - plt.tight_layout() - - if save_path: - try: - plt.savefig(save_path) - gru_logger.info(f"Regression evaluation plots saved to {save_path}") - except Exception as e: - gru_logger.error(f"Error saving evaluation plot: {e}") - else: - plt.show() - - plt.close() - -# Example usage placeholder (won't run directly here) -if __name__ == "__main__": - gru_logger.info("CryptoGRUModel (V6 Adaptation) module loaded.") - # Example: - # model = CryptoGRUModel() - # # Need to load data, preprocess, scale, create sequences first... 
- # # history = model.train(X_train, y_train_scaled, X_val, y_val_scaled, feature_scaler, y_scaler) - # # eval_results = model.evaluate(X_test, y_test_scaled, y_start_price_test) - # # model.plot_training_history(history, save_path='training_hist.png') - # # model.plot_evaluation_results(eval_results, save_path='evaluation.png') diff --git a/gru_sac_predictor/src/model_gru.py b/gru_sac_predictor/src/model_gru.py new file mode 100644 index 00000000..fa9a495e --- /dev/null +++ b/gru_sac_predictor/src/model_gru.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +import tensorflow as tf +#import tensorflow_addons as tfa +from tensorflow.keras import layers, Model, callbacks, saving +from typing import Tuple, Dict +import numpy as np +import tqdm +from tqdm.keras import TqdmCallback + +# =================================================================== +# UTILITIES +# =================================================================== + +@saving.register_keras_serializable(package='GRU') +def gaussian_nll(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor: # noqa: D401 + """Gaussian negative‑log likelihood for *scalar* targets. + + The model is assumed to predict concatenated [mu, log_sigma]. + + Given targets :math:`y` and predictions :math:`\mu, \log\sigma`, + the NLL is + + .. math:: + + \mathcal{L}_{NLL} = \frac{1}{2} \exp(-2\log\sigma)(y-\mu)^2 + + \log\sigma. + + A small constant is added for numerical stability. 
+ """ + mu, log_sigma = tf.split(y_pred, 2, axis=-1) + # Ensure y_true has the same shape as mu for the subtraction + y_true_shaped = tf.reshape(y_true, tf.shape(mu)) + inv_var = tf.exp(-2.0 * log_sigma) # = 1/sigma^2 + nll = 0.5 * inv_var * tf.square(y_true_shaped - mu) + log_sigma + return tf.reduce_mean(nll) + + +# =================================================================== +# MODEL FACTORY +# =================================================================== + +def build_gru_model(lookback: int, n_features: int) -> Model: + """Return the three‑head GRU model described in *revisions.txt*. + + Architecture: + GRU(64) → + • Dense(2, name="gauss_params") → [μ̂, log σ̂] + • Dense(1, name="ret") – μ̂ sliced from gauss_params + • Dense(1, activation="sigmoid", name="dir") – probability price rises + """ + inputs = layers.Input(shape=(lookback, n_features), name="input") + h = layers.GRU(64, name="gru")(inputs) + + # concatenated [mu, log_sigma] for Gaussian‑NLL + gauss_out = layers.Dense(2, name="gauss_params")(h) + # Slice mu so we can attach a separate Huber loss on μ̂ + mu_out = layers.Lambda(lambda x: x[:, :1], name="ret")(gauss_out) + dir_prob = layers.Dense(1, activation="sigmoid", name="dir")(h) + + model = Model(inputs, outputs=[mu_out, gauss_out, dir_prob], name="crypto_gru") + + losses: Dict[str, tf.keras.losses.Loss | str] = { + "ret": tf.keras.losses.Huber(delta=1e-3), + "gauss_params": gaussian_nll, + "dir": tf.keras.losses.BinaryFocalCrossentropy(alpha=0.5, gamma=2.0, from_logits=False), + } + + loss_weights = {"ret": 1.0, "gauss_params": 0.2, "dir": 0.4} + + # Note: The targets passed to model.fit must align with the outputs. 
+ # y_train should be a dictionary like: + # { "ret": y_ret_train, "gauss_params": y_ret_train, "dir": y_dir_train } + model.compile( + optimizer="adam", + loss=losses, + loss_weights=loss_weights, + run_eagerly=False, + ) + return model + + +# =================================================================== +# HIGH‑LEVEL TRAINING WRAPPER (REMOVED) +# =================================================================== +# Standalone train_model function removed as logic is now in GRUModelHandler.train +# def train_model(...): +# ... \ No newline at end of file diff --git a/gru_sac_predictor/src/sac_agent.py b/gru_sac_predictor/src/sac_agent.py index bb5e0085..276b0817 100644 --- a/gru_sac_predictor/src/sac_agent.py +++ b/gru_sac_predictor/src/sac_agent.py @@ -5,6 +5,7 @@ import tensorflow_probability as tfp from tensorflow.keras.optimizers.schedules import ExponentialDecay import logging import os +import json sac_logger = logging.getLogger(__name__) sac_logger.setLevel(logging.INFO) @@ -35,15 +36,22 @@ class ReplayBuffer: def __init__(self, capacity=100000, state_dim=2, action_dim=1): self.capacity = capacity + self.state_dim = state_dim # Store dims for reset + self.action_dim = action_dim self.counter = 0 - # Initialize buffer arrays - self.states = np.zeros((capacity, state_dim), dtype=np.float32) - self.actions = np.zeros((capacity, action_dim), dtype=np.float32) - self.rewards = np.zeros((capacity, 1), dtype=np.float32) - self.next_states = np.zeros((capacity, state_dim), dtype=np.float32) - self.dones = np.zeros((capacity, 1), dtype=np.float32) - + self._initialize_buffers() + + def _initialize_buffers(self): + """Initializes or resets the buffer arrays.""" + self.states = np.zeros((self.capacity, self.state_dim), dtype=np.float32) + self.actions = np.zeros((self.capacity, self.action_dim), dtype=np.float32) + self.rewards = np.zeros((self.capacity, 1), dtype=np.float32) + self.next_states = np.zeros((self.capacity, self.state_dim), dtype=np.float32) + 
self.dones = np.zeros((self.capacity, 1), dtype=np.float32) + self.counter = 0 + sac_logger.info(f"Replay buffer initialized/reset with capacity {self.capacity}.") + def add(self, state, action, reward, next_state, done): """Add experience to buffer""" idx = self.counter % self.capacity @@ -75,9 +83,12 @@ class ReplayBuffer: def sample(self, batch_size): """Sample batch of experiences from buffer""" + # TODO: Implement prioritized experience replay (PER) or other sampling strategies here. + # Currently uses uniform random sampling. max_idx = min(self.counter, self.capacity) if max_idx < batch_size: - print(f"Warning: Trying to sample {batch_size} elements, but buffer only has {max_idx}. Sampling with replacement.") + # logger.warning(...) # Use logger instead of print + sac_logger.warning(f"Sampling {batch_size}, but buffer only has {max_idx}. Sampling with replacement.") indices = np.random.choice(max_idx, batch_size, replace=True) else: indices = np.random.choice(max_idx, batch_size, replace=False) @@ -90,6 +101,11 @@ class ReplayBuffer: return states, actions, rewards, next_states, dones + def clear_buffer(self): + """Resets the buffer contents and counter.""" + sac_logger.warning("Clearing replay buffer contents.") + self._initialize_buffers() + def __len__(self): """Get current size of buffer""" return min(self.counter, self.capacity) @@ -98,7 +114,7 @@ class SACTradingAgent: """V7.3 Enhanced: SAC agent with updated params and architecture fixes.""" def __init__(self, - state_dim=2, # Standard [pred_ret, uncert] + state_dim=5, # [mu, sigma, edge, |mu|/sigma, position] action_dim=1, gamma=0.99, tau=0.005, @@ -113,9 +129,30 @@ class SACTradingAgent: alpha=0.2, alpha_auto_tune=True, target_entropy=-1.0, - min_buffer_size=1000): + min_buffer_size=10000, + edge_threshold_config: float | None = None): """ Initialize the SAC agent with enhancements. + Args: + state_dim (int): The dimension of the state space. + action_dim (int): The dimension of the action space. 
+ gamma (float): The discount factor. + tau (float): The target network update rate. + initial_lr (float): The initial learning rate. + decay_steps (int): The number of steps over which to decay the learning rate. + end_lr (float): The final learning rate. + lr_decay_rate (float): The rate at which to decay the learning rate. + buffer_capacity (int): The capacity of the replay buffer. + ou_noise_stddev (float): The standard deviation of the Ornstein-Uhlenbeck noise. + ou_noise_theta (float): The theta parameter of the Ornstein-Uhlenbeck noise. + ou_noise_dt (float): The dt parameter of the Ornstein-Uhlenbeck noise. + alpha (float): The initial alpha value. + alpha_auto_tune (bool): Whether to automatically tune the alpha value. + target_entropy (float): The target entropy for the alpha auto-tuning. + min_buffer_size (int): The minimum size of the replay buffer before training. + edge_threshold_config (float | None): The edge threshold value from the config + used during this agent's training setup. + Stored for metadata purposes. 
""" self.state_dim = state_dim self.action_dim = action_dim @@ -124,11 +161,12 @@ class SACTradingAgent: self.min_buffer_size = min_buffer_size self.target_entropy = tf.constant(target_entropy, dtype=tf.float32) self.alpha_auto_tune = alpha_auto_tune + self.edge_threshold_config = edge_threshold_config if self.alpha_auto_tune: self.log_alpha = tf.Variable(tf.math.log(alpha), trainable=True, name='log_alpha') self.alpha = tfp.util.DeferredTensor(self.log_alpha, tf.exp) - self.alpha_optimizer = tf.keras.optimizers.Adam(learning_rate=initial_lr) + self.alpha_optimizer = tf.keras.optimizers.Adam(learning_rate=float(initial_lr)) else: self.alpha = tf.constant(alpha, dtype=tf.float32) @@ -137,9 +175,12 @@ class SACTradingAgent: std_deviation=float(ou_noise_stddev) * np.ones(action_dim), theta=ou_noise_theta, dt=ou_noise_dt) + # Ensure explicit types for ExponentialDecay arguments self.lr_schedule = ExponentialDecay( - initial_learning_rate=initial_lr, decay_steps=decay_steps, - decay_rate=lr_decay_rate, staircase=False) + initial_learning_rate=float(initial_lr), + decay_steps=int(decay_steps), + decay_rate=float(lr_decay_rate), + staircase=False) sac_logger.info(f"Using ExponentialDecay LR: init={initial_lr}, steps={decay_steps}, rate={lr_decay_rate}") self.actor_optimizer = tf.keras.optimizers.Adam(learning_rate=self.lr_schedule) self.critic1_optimizer = tf.keras.optimizers.Adam(learning_rate=self.lr_schedule) @@ -334,24 +375,81 @@ class SACTradingAgent: return metrics def save(self, path): + """Saves agent weights and potentially metadata.""" try: - self.actor.save_weights(f"{path}/actor.weights.h5"); self.critic1.save_weights(f"{path}/critic1.weights.h5") - self.critic2.save_weights(f"{path}/critic2.weights.h5") - if self.alpha_auto_tune and hasattr(self, 'log_alpha'): np.save(f"{path}/log_alpha.npy", self.log_alpha.numpy()) - sac_logger.info(f"Enhanced SAC Agent weights saved to {path}/") - except Exception as e: sac_logger.error(f"Error saving SAC weights: {e}") - + 
os.makedirs(path, exist_ok=True) + self.actor.save_weights(os.path.join(path, "actor.weights.h5")) + self.critic1.save_weights(os.path.join(path, "critic1.weights.h5")) + self.critic2.save_weights(os.path.join(path, "critic2.weights.h5")) + + metadata = {} + if self.alpha_auto_tune and hasattr(self, 'log_alpha'): + # Save log_alpha directly + np.save(os.path.join(path, "log_alpha.npy"), self.log_alpha.numpy()) + metadata['log_alpha_saved'] = True + else: + metadata['log_alpha_saved'] = False + metadata['fixed_alpha'] = float(self.alpha) # Save fixed alpha value + + # Add other relevant metadata if needed (like state_dim, action_dim used during training) + metadata['state_dim'] = self.state_dim + metadata['action_dim'] = self.action_dim + # Add edge threshold to metadata if it was stored + if self.edge_threshold_config is not None: + metadata['edge_threshold_config'] = self.edge_threshold_config + + meta_path = os.path.join(path, 'agent_metadata.json') + with open(meta_path, 'w') as f: + json.dump(metadata, f, indent=4) + + sac_logger.info(f"SAC Agent weights and metadata saved to {path}/") + except Exception as e: + sac_logger.error(f"Error saving SAC weights/metadata: {e}", exc_info=True) + def load(self, path): + """Loads agent weights and potentially metadata.""" try: + # Load weights (existing logic seems ok, ensures models are built) if not self.actor.built: self.actor.build((None, self.state_dim)) if not self.critic1.built: self.critic1.build([(None, self.state_dim), (None, self.action_dim)]) if not self.critic2.built: self.critic2.build([(None, self.state_dim), (None, self.action_dim)]) if not self.target_critic1.built: self.target_critic1.build([(None, self.state_dim), (None, self.action_dim)]) if not self.target_critic2.built: self.target_critic2.build([(None, self.state_dim), (None, self.action_dim)]) - self.actor.load_weights(f"{path}/actor.weights.h5"); self.critic1.load_weights(f"{path}/critic1.weights.h5") - 
self.critic2.load_weights(f"{path}/critic2.weights.h5"); self.target_critic1.load_weights(f"{path}/critic1.weights.h5") - self.target_critic2.load_weights(f"{path}/critic2.weights.h5") - log_alpha_path = f"{path}/log_alpha.npy" - if self.alpha_auto_tune and os.path.exists(log_alpha_path): self.log_alpha.assign(np.load(log_alpha_path)); sac_logger.info(f"Loaded log_alpha value") - sac_logger.info(f"Enhanced SAC Agent weights loaded from {path}/") - except Exception as e: sac_logger.error(f"Error loading SAC weights from {path}: {e}. Ensure files exist/shapes match.") \ No newline at end of file + self.actor.load_weights(os.path.join(path, "actor.weights.h5")); self.critic1.load_weights(os.path.join(path, "critic1.weights.h5")) + self.critic2.load_weights(os.path.join(path, "critic2.weights.h5")); self.target_critic1.load_weights(os.path.join(path, "critic1.weights.h5")) + self.target_critic2.load_weights(os.path.join(path, "critic2.weights.h5")) + + # Load metadata + meta_path = os.path.join(path, 'agent_metadata.json') + metadata = {} + if os.path.exists(meta_path): + with open(meta_path, 'r') as f: + metadata = json.load(f) + sac_logger.info(f"Loaded agent metadata: {metadata}") + + # Load log_alpha if saved and auto-tuning + log_alpha_path = os.path.join(path, "log_alpha.npy") + if self.alpha_auto_tune and metadata.get('log_alpha_saved', False) and os.path.exists(log_alpha_path): + self.log_alpha.assign(np.load(log_alpha_path)) + sac_logger.info(f"Restored log_alpha value from saved state.") + elif not self.alpha_auto_tune and 'fixed_alpha' in metadata: + # Restore fixed alpha if not auto-tuning + self.alpha = tf.constant(metadata['fixed_alpha'], dtype=tf.float32) + sac_logger.info(f"Restored fixed alpha value: {self.alpha:.4f}") + + else: + sac_logger.warning(f"Agent metadata file not found at {meta_path}. 
Cannot verify parameters or load log_alpha.") + + sac_logger.info(f"SAC Agent weights loaded from {path}/") + return metadata # Return metadata for potential checks + + except Exception as e: + sac_logger.error(f"Error loading SAC weights/metadata from {path}: {e}. Ensure files exist/shapes match.", exc_info=True) + return {} # Return empty dict on failure + + def clear_buffer(self): + """Clears the agent's replay buffer.""" + if hasattr(self, 'buffer') and hasattr(self.buffer, 'clear_buffer'): + self.buffer.clear_buffer() + else: + sac_logger.error("Agent or buffer does not have a clear_buffer method.") \ No newline at end of file diff --git a/gru_sac_predictor/src/sac_agent_simplified.py b/gru_sac_predictor/src/sac_agent_simplified.py deleted file mode 100644 index 0c8863fd..00000000 --- a/gru_sac_predictor/src/sac_agent_simplified.py +++ /dev/null @@ -1,503 +0,0 @@ -import os -import numpy as np -import tensorflow as tf -from tensorflow.keras.models import Model -from tensorflow.keras.layers import Input, Dense, Concatenate, BatchNormalization, Add -from tensorflow.keras.optimizers import Adam -import tensorflow.keras.backend as K -import tensorflow_probability as tfp # V7.13 Import TFP - -tfd = tfp.distributions # V7.13 TFP distribution alias - -LOG_STD_MIN = -20 # V7.13 Min log std dev for numerical stability -LOG_STD_MAX = 2 # V7.13 Max log std dev for numerical stability - -class SimplifiedSACTradingAgent: - """ - Simplified SAC Trading Agent optimized for GRU-predicted returns and uncertainty - with guarantees for performance on M1 chips and smaller datasets. - V7.13: Updated for 5D state and automatic alpha tuning. 
- """ - def __init__( - self, - state_dim=5, # V7.13 Updated state: [pred_ret, unc, z, mom, vol] - action_dim=1, # Position size between -1 and 1 - hidden_size=64, # Reduced network size for faster training - gamma=0.97, # Discount factor for faster adaptation - tau=0.02, # Target network update rate - # alpha=0.1, # V7.13 Removed: Use automatic alpha tuning - actor_lr=3e-4, # V7.13 Default updated - critic_lr=5e-4, # V7.13 Default updated - alpha_lr=3e-4, # V7.13 Learning rate for alpha tuning - batch_size=64, # Smaller batch size for faster updates - buffer_max_size=20000, # Smaller buffer for recency bias - min_buffer_size=1000, # Start learning after this many experiences - update_interval=1, # Update actor every step - target_update_interval=2, # Update target networks every 2 steps - gradient_clip=1.0, # Clip gradients for stability - reward_scale=2.0, # V7.13 Default updated - use_batch_norm=True, # Use batch normalization - use_residual=True, # Use residual connections - target_entropy=None, # V7.13 Target entropy for alpha tuning - model_dir='models/simplified_sac', - ): - self.state_dim = state_dim - self.action_dim = action_dim - self.hidden_size = hidden_size - self.gamma = gamma - self.tau = tau - # self.alpha = alpha # V7.13 Removed - self.actor_lr = actor_lr - self.critic_lr = critic_lr - self.alpha_lr = alpha_lr # V7.13 Store alpha LR - self.batch_size = batch_size - self.buffer_max_size = buffer_max_size - self.min_buffer_size = min_buffer_size - self.update_interval = update_interval - self.target_update_interval = target_update_interval - self.gradient_clip = gradient_clip - self.reward_scale = reward_scale - self.use_batch_norm = use_batch_norm - self.use_residual = use_residual - self.model_dir = model_dir - - # V7.13 Alpha tuning setup - if target_entropy is None: - # Default target entropy heuristic: -dim(A)/2 suggested in instructions - self.target_entropy = -np.prod(self.action_dim) / 2.0 - else: - self.target_entropy = target_entropy - # 
Initialize log_alpha (trainable variable) - start near log(0.1) ~ -2.3 - self.log_alpha = tf.Variable(np.log(0.1), dtype=tf.float32, name='log_alpha') - self.alpha = tfp.util.DeferredTensor(self.log_alpha, tf.exp) # Exponentiated alpha - self.alpha_optimizer = Adam(learning_rate=self.alpha_lr, name='alpha_optimizer') - print(f"Initialized SAC with automatic alpha tuning. Target Entropy: {self.target_entropy:.2f}") - - # Experience replay buffer (simple numpy arrays instead of deque for performance) - self.buffer_counter = 0 - self.buffer_capacity = buffer_max_size - self.state_buffer = np.zeros((self.buffer_capacity, self.state_dim)) - self.action_buffer = np.zeros((self.buffer_capacity, self.action_dim)) - self.reward_buffer = np.zeros((self.buffer_capacity, 1)) - self.next_state_buffer = np.zeros((self.buffer_capacity, self.state_dim)) - self.done_buffer = np.zeros((self.buffer_capacity, 1)) - - # Step counter - self.train_step_counter = 0 - - # Create actor and critic networks - self.actor = self._build_actor() - self.critic_1 = self._build_critic() - self.critic_2 = self._build_critic() - self.target_critic_1 = self._build_critic() - self.target_critic_2 = self._build_critic() - - # Initialize target networks with actor and critic's weights - self.target_critic_1.set_weights(self.critic_1.get_weights()) - self.target_critic_2.set_weights(self.critic_2.get_weights()) - - # Create optimizers with gradient clipping - self.actor_optimizer = Adam(learning_rate=self.actor_lr, clipnorm=self.gradient_clip) - self.critic_optimizer = Adam(learning_rate=self.critic_lr, clipnorm=self.gradient_clip) - # V7.13 Alpha optimizer already created above - - # Loss tracking - self.actor_loss_history = [] - self.critic_loss_history = [] - - # Ensure model directory exists - os.makedirs(self.model_dir, exist_ok=True) - - # V7.9: Explicitly build models to prevent graph mode errors with optimizers - self._build_models_with_dummy_input() - - def _build_models_with_dummy_input(self): - 
"""Builds actor and critic models with dummy inputs to initialize weights and optimizers.""" - try: - # Dummy state and action tensors - dummy_state = tf.zeros((1, self.state_dim), dtype=tf.float32) - dummy_action = tf.zeros((1, self.action_dim), dtype=tf.float32) - - # Build actor - self.actor(dummy_state) - # Build critics - self.critic_1([dummy_state, dummy_action]) - self.critic_2([dummy_state, dummy_action]) - self.target_critic_1([dummy_state, dummy_action]) - self.target_critic_2([dummy_state, dummy_action]) - - # Optionally, build optimizers (Adam usually builds on first apply_gradients) - # V7.10: Explicitly build optimizers too - if hasattr(self.actor_optimizer, 'build'): # Check if build method exists (newer TF) - self.actor_optimizer.build(self.actor.trainable_variables) - if hasattr(self.critic_optimizer, 'build'): - self.critic_optimizer.build(self.critic_1.trainable_variables + self.critic_2.trainable_variables) - # V7.13 Build alpha optimizer - if hasattr(self.alpha_optimizer, 'build'): - self.alpha_optimizer.build([self.log_alpha]) - print("Simplified SAC Agent models and optimizers built explicitly.") # Log success - except Exception as e: - print(f"Warning: Failed to explicitly build SAC models/optimizers: {e}") - - def _build_actor(self): - """Build a simplified actor network with optional batch norm and residual connections - V7.13: Outputs mean and log_std for a Gaussian policy. 
- """ - # Input layer - state_input = Input(shape=(self.state_dim,)) - - # First hidden layer - x = Dense(self.hidden_size, activation='relu')(state_input) - if self.use_batch_norm: - x = BatchNormalization()(x) - - # Second hidden layer - y = Dense(self.hidden_size, activation='relu')(x) - if self.use_batch_norm: - y = BatchNormalization()(y) - - # Optional residual connection - if self.use_residual and self.hidden_size == self.hidden_size: - z = Add()([x, y]) - else: - z = y - - # V7.13 Output layer(s) for mean and log_std - mu = Dense(self.action_dim, activation=None, name='mu')(z) - log_std = Dense(self.action_dim, activation=None, name='log_std')(z) - - # V7.13 Clip log_std for stability - log_std = tf.keras.ops.clip(log_std, LOG_STD_MIN, LOG_STD_MAX) - - # V7.13 Create model - outputs mean and log_std - model = Model(inputs=state_input, outputs=[mu, log_std]) - return model - - def _build_critic(self): - """Build a simplified critic network for Q-value estimation""" - # Input layers - state_input = Input(shape=(self.state_dim,)) - action_input = Input(shape=(self.action_dim,)) - - # Concatenate state and action - concat = Concatenate()([state_input, action_input]) - - # First hidden layer - x = Dense(self.hidden_size, activation='relu')(concat) - if self.use_batch_norm: - x = BatchNormalization()(x) - - # Second hidden layer - y = Dense(self.hidden_size, activation='relu')(x) - if self.use_batch_norm: - y = BatchNormalization()(y) - - # Optional residual connection - if self.use_residual and self.hidden_size == self.hidden_size: - z = Add()([x, y]) - else: - z = y - - # Output layer (linear activation for Q-value) - outputs = Dense( - 1, - activation=None, - kernel_initializer=tf.keras.initializers.RandomUniform(minval=-0.003, maxval=0.003) - )(z) - - # Create model - model = Model(inputs=[state_input, action_input], outputs=outputs) - return model - - def get_action(self, state, deterministic=False): - """Select trading action (-1 to 1) based on current 
state - V7.13: Samples from Gaussian policy, calculates log_prob. - """ - state_tensor = tf.convert_to_tensor([state], dtype=tf.float32) - # V7.13 Get mean and log_std from actor - mu, log_std = self.actor(state_tensor) - std = tf.exp(log_std) - - # Create policy distribution - policy_dist = tfd.Normal(mu, std) - - if deterministic: - # Use mean for deterministic action - action = mu - else: - # Sample using reparameterization trick - action = policy_dist.sample() - - # Calculate log probability of the sampled/deterministic action - log_prob = policy_dist.log_prob(action) - log_prob = tf.reduce_sum(log_prob, axis=1, keepdims=True) # Sum across action dim if > 1 - - # Apply tanh squashing - # Action is sampled from Normal, then squashed by tanh - squashed_action = tf.tanh(action) - - # Adjust log_prob for tanh squashing (important!) - # log π(a|s) = log ρ(u|s) - Σ log(1 - tanh(u)^2) - # where u is the pre-tanh action sampled from Normal - # See Appendix C in SAC paper: https://arxiv.org/abs/1801.01290 - log_prob -= tf.reduce_sum(tf.math.log(1.0 - squashed_action**2 + 1e-6), axis=1, keepdims=True) - - # Return the SQUASHED action and its log_prob - # Return numpy arrays for consistency with buffer etc. 
- return squashed_action[0].numpy(), log_prob[0].numpy() - - def store_transition(self, state, action, reward, next_state, done): - """Store experience in replay buffer with efficient circular indexing""" - # Scale reward - scaled_reward = reward * self.reward_scale - - # Get the index to store experience - index = self.buffer_counter % self.buffer_capacity - - # Store experience - self.state_buffer[index] = state - self.action_buffer[index] = action - self.reward_buffer[index] = scaled_reward - self.next_state_buffer[index] = next_state - self.done_buffer[index] = done - - # Increment counter - self.buffer_counter += 1 - - def sample_batch(self): - """Sample a batch of experiences from replay buffer""" - # Get the valid buffer size (min of counter and capacity) - buffer_size = min(self.buffer_counter, self.buffer_capacity) - - # Sample random indices - batch_indices = np.random.choice(buffer_size, self.batch_size) - - # Get batch - state_batch = self.state_buffer[batch_indices] - action_batch = self.action_buffer[batch_indices] - reward_batch = self.reward_buffer[batch_indices] - next_state_batch = self.next_state_buffer[batch_indices] - done_batch = self.done_buffer[batch_indices] - - return state_batch, action_batch, reward_batch, next_state_batch, done_batch - - def update_target_networks(self): - """Update target critic networks using Polyak averaging""" - # Update target critic 1 - target_weights = self.target_critic_1.get_weights() - critic_weights = self.critic_1.get_weights() - new_weights = [] - for i in range(len(target_weights)): - new_weights.append(self.tau * critic_weights[i] + (1 - self.tau) * target_weights[i]) - self.target_critic_1.set_weights(new_weights) - - # Update target critic 2 - target_weights = self.target_critic_2.get_weights() - critic_weights = self.critic_2.get_weights() - new_weights = [] - for i in range(len(target_weights)): - new_weights.append(self.tau * critic_weights[i] + (1 - self.tau) * target_weights[i]) - 
self.target_critic_2.set_weights(new_weights) - - @tf.function - def _update_critics(self, states, actions, rewards, next_states, dones): - """Update critic networks with TensorFlow graph execution""" - with tf.GradientTape(persistent=True) as tape: - # V7.18 START: Get next action and log_prob for target Q calculation - next_mu, next_log_std = self.actor(next_states) - next_std = tf.exp(next_log_std) - next_policy_dist = tfd.Normal(next_mu, next_std) - - # Sample action for policy evaluation - next_action_presquash = next_policy_dist.sample() - next_action = tf.tanh(next_action_presquash) # Apply squashing - - # Calculate log_prob, adjusting for tanh squashing - next_log_prob = next_policy_dist.log_prob(next_action_presquash) - next_log_prob = tf.reduce_sum(next_log_prob, axis=1, keepdims=True) - next_log_prob -= tf.reduce_sum(tf.math.log(1.0 - next_action**2 + 1e-6), axis=1, keepdims=True) - # V7.18 END: Get next action and log_prob - - # Get target Q values using the SQUASHED next_action - target_q1 = self.target_critic_1([next_states, next_action]) - target_q2 = self.target_critic_2([next_states, next_action]) - - # Use minimum Q value (Double Q learning) - min_target_q = tf.minimum(target_q1, target_q2) - - # V7.18 Calculate target Q including entropy term - target_q_entropy_adjusted = min_target_q - self.alpha * next_log_prob - - # Calculate target values: r + γ(1-done) * (min(Q′(s′,a′)) - α log π(a′|s′)) - target_values = rewards + self.gamma * (1 - dones) * target_q_entropy_adjusted - - # Stop gradient flow through target values - target_values = tf.stop_gradient(target_values) - - # Get current Q estimates - current_q1 = self.critic_1([states, actions]) - current_q2 = self.critic_2([states, actions]) - - # Calculate critic losses (Huber loss for robustness) - critic1_loss = tf.reduce_mean(tf.keras.losses.huber(target_values, current_q1, delta=1.0)) - critic2_loss = tf.reduce_mean(tf.keras.losses.huber(target_values, current_q2, delta=1.0)) - critic_loss = 
critic1_loss + critic2_loss - - # Get critic gradients - critic1_gradients = tape.gradient(critic1_loss, self.critic_1.trainable_variables) - critic2_gradients = tape.gradient(critic2_loss, self.critic_2.trainable_variables) - - # Apply critic gradients - self.critic_optimizer.apply_gradients(zip(critic1_gradients, self.critic_1.trainable_variables)) - self.critic_optimizer.apply_gradients(zip(critic2_gradients, self.critic_2.trainable_variables)) - - del tape - return critic_loss - - @tf.function - def _update_actor_and_alpha(self, states): - """Update actor network and alpha temperature with TensorFlow graph execution.""" - with tf.GradientTape(persistent=True) as tape: - # Get policy distribution parameters - mu, log_std = self.actor(states) - std = tf.exp(log_std) - policy_dist = tfd.Normal(mu, std) - - # Sample action using reparameterization trick - actions_presquash = policy_dist.sample() - log_prob_presquash = policy_dist.log_prob(actions_presquash) - log_prob_presquash = tf.reduce_sum(log_prob_presquash, axis=1, keepdims=True) - - # Apply squashing - squashed_actions = tf.tanh(actions_presquash) - - # Adjust log_prob for squashing - log_prob = log_prob_presquash - tf.reduce_sum(tf.math.log(1.0 - squashed_actions**2 + 1e-6), axis=1, keepdims=True) - - # Get Q values from critics using squashed actions - q1_values = self.critic_1([states, squashed_actions]) - q2_values = self.critic_2([states, squashed_actions]) - min_q_values = tf.minimum(q1_values, q2_values) - - # Calculate actor loss: E[alpha * log_prob - Q] - actor_loss = tf.reduce_mean(self.alpha * log_prob - min_q_values) - - # Calculate alpha loss: E[-alpha * (log_prob + target_entropy)] - # Note: We use log_alpha and optimize that variable - alpha_loss = -tf.reduce_mean(self.log_alpha * tf.stop_gradient(log_prob + self.target_entropy)) - - # --- Actor Gradients --- - actor_gradients = tape.gradient(actor_loss, self.actor.trainable_variables) - 
self.actor_optimizer.apply_gradients(zip(actor_gradients, self.actor.trainable_variables)) - - # --- Alpha Gradients --- - alpha_gradients = tape.gradient(alpha_loss, [self.log_alpha]) - self.alpha_optimizer.apply_gradients(zip(alpha_gradients, [self.log_alpha])) - - del tape # Release tape resources - - return actor_loss, alpha_loss # Return losses for tracking - - def train(self, num_iterations=1): - """Train the agent for a specified number of iterations""" - # Don't train if buffer doesn't have enough experiences - if self.buffer_counter < self.min_buffer_size: - return None # V7.17 Return None if not training - - total_actor_loss = 0 - total_critic_loss = 0 - total_alpha_loss = 0 # V7.19 Track alpha loss - - iterations_run = 0 - for _ in range(num_iterations): - # Sample a batch of experiences - states, actions, rewards, next_states, dones = self.sample_batch() - - # Convert to tensors - states = tf.convert_to_tensor(states, dtype=tf.float32) - actions = tf.convert_to_tensor(actions, dtype=tf.float32) - rewards = tf.convert_to_tensor(rewards, dtype=tf.float32) - next_states = tf.convert_to_tensor(next_states, dtype=tf.float32) - dones = tf.convert_to_tensor(dones, dtype=tf.float32) - - # Update critics - critic_loss = self._update_critics(states, actions, rewards, next_states, dones) - total_critic_loss += critic_loss - - # Update actor and alpha less frequently for stability? 
- # Original implementation updates actor every step, let's stick to that for now - # if self.train_step_counter % self.update_interval == 0: - actor_loss, alpha_loss = self._update_actor_and_alpha(states) - total_actor_loss += actor_loss - total_alpha_loss += alpha_loss # V7.19 Accumulate alpha loss - - # Track losses (consider appending outside the loop if num_iterations > 1) - self.actor_loss_history.append(actor_loss.numpy()) - self.critic_loss_history.append(critic_loss.numpy()) - - # Update target networks less frequently - if self.train_step_counter % self.target_update_interval == 0: - self.update_target_networks() - - # Increment step counter - self.train_step_counter += 1 - iterations_run += 1 - - # Return average losses for the iterations run in this call - avg_actor_loss = total_actor_loss / iterations_run if iterations_run > 0 else tf.constant(0.0) - avg_critic_loss = total_critic_loss / iterations_run if iterations_run > 0 else tf.constant(0.0) - avg_alpha_loss = total_alpha_loss / iterations_run if iterations_run > 0 else tf.constant(0.0) - - # V7.17 Return tuple (actor_loss, critic_loss) - Needs update for alpha - # V7.19 Return losses including alpha loss (or maybe just actor/critic for main history) - return avg_actor_loss, avg_critic_loss # Keep original return for compatibility with main loop plotting for now - - def save(self, checkpoint_dir=None): - """Save the model weights""" - if checkpoint_dir is None: - checkpoint_dir = self.model_dir - - os.makedirs(checkpoint_dir, exist_ok=True) - - # Save actor weights - self.actor.save_weights(os.path.join(checkpoint_dir, 'actor.weights.h5')) - - # Save critic weights - self.critic_1.save_weights(os.path.join(checkpoint_dir, 'critic_1.weights.h5')) - self.critic_2.save_weights(os.path.join(checkpoint_dir, 'critic_2.weights.h5')) - - # Save target critic weights - self.target_critic_1.save_weights(os.path.join(checkpoint_dir, 'target_critic_1.weights.h5')) - 
self.target_critic_2.save_weights(os.path.join(checkpoint_dir, 'target_critic_2.weights.h5')) - - # Save alpha - np.save(os.path.join(checkpoint_dir, 'alpha.npy'), self.alpha) - - print(f"Model saved to {checkpoint_dir}") - - def load(self, checkpoint_dir=None): - """Load the model weights""" - if checkpoint_dir is None: - checkpoint_dir = self.model_dir - - try: - # Load actor weights - self.actor.load_weights(os.path.join(checkpoint_dir, 'actor.weights.h5')) - - # Load critic weights - self.critic_1.load_weights(os.path.join(checkpoint_dir, 'critic_1.weights.h5')) - self.critic_2.load_weights(os.path.join(checkpoint_dir, 'critic_2.weights.h5')) - - # Load target critic weights - self.target_critic_1.load_weights(os.path.join(checkpoint_dir, 'target_critic_1.weights.h5')) - self.target_critic_2.load_weights(os.path.join(checkpoint_dir, 'target_critic_2.weights.h5')) - - # Load alpha - if os.path.exists(os.path.join(checkpoint_dir, 'alpha.npy')): - self.alpha = float(np.load(os.path.join(checkpoint_dir, 'alpha.npy'))) - - print(f"Model loaded from {checkpoint_dir}") - return True - except Exception as e: - print(f"Error loading model: {e}") - return False \ No newline at end of file diff --git a/gru_sac_predictor/src/sac_trainer.py b/gru_sac_predictor/src/sac_trainer.py new file mode 100644 index 00000000..8dd0a9f7 --- /dev/null +++ b/gru_sac_predictor/src/sac_trainer.py @@ -0,0 +1,470 @@ +""" +SAC Trainer Component. + +Handles the offline training process for the SAC agent, including loading +dependencies from a specified GRU run, preparing data for the environment, +and executing the training loop. 
+""" + +import tensorflow as tf +import json +import logging +import os +import sys +import yaml +import joblib +import pandas as pd +import numpy as np +from datetime import datetime +from tqdm import tqdm +from tensorflow.keras.callbacks import TensorBoard + +# Import necessary components from the pipeline +# Use absolute imports assuming the package structure is correct +from gru_sac_predictor.src.data_loader import DataLoader +from gru_sac_predictor.src.feature_engineer import FeatureEngineer +from gru_sac_predictor.src.gru_model_handler import GRUModelHandler +from gru_sac_predictor.src.calibrator import Calibrator +from gru_sac_predictor.src.sac_agent import SACTradingAgent +from gru_sac_predictor.src.trading_env import TradingEnv +try: + from gru_sac_predictor.src.features import minimal_whitelist # For FE fallback +except ImportError: + minimal_whitelist = [] # Define empty if import fails + +logger = logging.getLogger(__name__) + +class SACTrainer: + """Manages the offline SAC training workflow.""" + + def __init__(self, config: dict, base_models_dir: str, base_logs_dir: str, base_results_dir: str): + """ + Initialize the SACTrainer. + + Args: + config (dict): The main pipeline configuration dictionary. + base_models_dir (str): The base directory where all run models are stored (e.g., project_root/models). + base_logs_dir (str): Base directory for logs. + base_results_dir (str): Base directory for results. 
+ """ + self.config = config + self.base_models_dir = base_models_dir + self.sac_cfg = config['sac'] + self.env_cfg = config.get('environment', {}) + self.control_cfg = config.get('control', {}) + self.data_cfg = config['data'] + + # Generate a specific run ID for this SAC training instance + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + self.sac_train_run_id = f"sac_train_{timestamp}" + logger.info(f"Initializing SACTrainer with Run ID: {self.sac_train_run_id}") + + # Setup directories for this specific SAC training run + self.sac_run_models_dir = os.path.join(self.base_models_dir, self.sac_train_run_id) + self.sac_run_logs_dir = os.path.join(base_logs_dir, self.sac_train_run_id) + self.sac_run_results_dir = os.path.join(base_results_dir, self.sac_train_run_id) + self.sac_tb_log_dir = os.path.join(self.sac_run_logs_dir, 'tensorboard') + + os.makedirs(self.sac_run_models_dir, exist_ok=True) + os.makedirs(self.sac_run_logs_dir, exist_ok=True) + os.makedirs(self.sac_run_results_dir, exist_ok=True) + os.makedirs(self.sac_tb_log_dir, exist_ok=True) + + # Configure logging specifically for this trainer instance if needed + # For now, relies on the pipeline's logger setup + logger.info(f" SAC Models Dir: {self.sac_run_models_dir}") + logger.info(f" SAC Logs Dir: {self.sac_run_logs_dir}") + logger.info(f" SAC Results Dir:{self.sac_run_results_dir}") + logger.info(f" SAC TB Dir: {self.sac_tb_log_dir}") + + # Save config subset relevant to SAC training? + # Or assume full config is saved by the main pipeline + + def _load_gru_dependencies(self, gru_run_id: str) -> dict | None: + """ + Loads artifacts (whitelist, scaler, GRU model, T) from a completed GRU pipeline run. + + Args: + gru_run_id (str): The run ID of the GRU pipeline run. + + Returns: + dict | None: A dictionary containing the loaded dependencies + ('whitelist', 'scaler', 'gru_model', 'optimal_T'), or None on failure. 
+ """ + logger.info(f"--- Loading Dependencies from GRU Run ID: {gru_run_id} ---") + gru_run_models_dir = os.path.join(self.base_models_dir, f"run_{gru_run_id}") + if not os.path.exists(gru_run_models_dir): + logger.error(f"Models directory for GRU run {gru_run_id} not found at: {gru_run_models_dir}") + return None + + dependencies = {} + + # 1. Load Whitelist + whitelist_path = os.path.join(gru_run_models_dir, f"final_whitelist_{gru_run_id}.json") + try: + with open(whitelist_path, 'r') as f: + dependencies['whitelist'] = json.load(f) + logger.info(f"Loaded whitelist ({len(dependencies['whitelist'])} features) from {whitelist_path}") + if not dependencies['whitelist']: + raise ValueError("Loaded whitelist is empty.") + except Exception as e: + logger.error(f"Failed to load whitelist from {whitelist_path}: {e}", exc_info=True) + return None + + # 2. Load Scaler + scaler_path = os.path.join(gru_run_models_dir, f"feature_scaler_{gru_run_id}.joblib") + try: + dependencies['scaler'] = joblib.load(scaler_path) + logger.info(f"Loaded scaler from {scaler_path}") + except Exception as e: + logger.error(f"Failed to load scaler from {scaler_path}: {e}", exc_info=True) + return None + + # 3. Load GRU Model + model_path = os.path.join(gru_run_models_dir, f"gru_model_{gru_run_id}.keras") + # Need a temporary GRU handler instance to use its load method + temp_gru_handler = GRUModelHandler(run_id="temp_load", models_dir="temp_load") + dependencies['gru_model'] = temp_gru_handler.load(model_path) + if dependencies['gru_model'] is None: + logger.error(f"Failed to load GRU model from {model_path}") + return None + logger.info(f"Loaded GRU model from {model_path}") + + # 4. 
# Load Optimal Temperature (fitted during the GRU run's calibration step)
        temp_path = os.path.join(gru_run_models_dir, f"calibration_temp_{gru_run_id}.npy")
        try:
            dependencies['optimal_T'] = float(np.load(temp_path))
            logger.info(f"Loaded optimal temperature T={dependencies['optimal_T']:.4f} from {temp_path}")
        except Exception as e:
            logger.error(f"Failed to load optimal temperature from {temp_path}: {e}", exc_info=True)
            # Allow continuation without T? Or require it? Let's require it for now.
            return None

        logger.info("--- Successfully loaded all GRU dependencies ---")
        return dependencies

    def _prepare_data_for_sac(self, gru_dependencies: dict) -> tuple | None:
        """
        Replicates the necessary data loading and preparation steps using the loaded
        GRU dependencies, specifically focusing on the VALIDATION dataset to
        generate inputs for the TradingEnv.

        Args:
            gru_dependencies (dict): The dictionary returned by _load_gru_dependencies.

        Returns:
            tuple | None: A tuple containing (mu_val, sigma_val, p_cal_val, actual_ret_val)
                          for the validation set, or None on failure.
        """
        logger.info("--- Preparing Validation Data for SAC Environment --- ")
        try:
            # 1. Load Raw Data (using a temporary DataLoader)
            temp_data_loader = DataLoader(db_dir=self.data_cfg['db_dir'])
            df_raw = temp_data_loader.load_data(
                ticker=self.data_cfg['ticker'],
                exchange=self.data_cfg['exchange'],
                start_date=self.data_cfg['start_date'],
                end_date=self.data_cfg['end_date'],
                interval=self.data_cfg['interval']
            )
            if df_raw is None or df_raw.empty: raise ValueError("Raw data loading failed")
            df_raw.dropna(subset=['open', 'high', 'low', 'close', 'volume'], inplace=True)
            logger.info("Loaded raw data.")

            # 2. Engineer Base Features (using a temporary FeatureEngineer)
            # Pass the *minimal* whitelist as a fallback if the loaded one causes issues
            temp_feature_engineer = FeatureEngineer(minimal_whitelist=minimal_whitelist)
            df_engineered = temp_feature_engineer.add_base_features(df_raw)
            df_engineered.dropna(inplace=True) # Drop NaNs after feature eng
            if df_engineered.empty: raise ValueError("Dataframe empty after feature engineering.")
            logger.info("Engineered base features.")

            # 3. Prune Features using *loaded* whitelist
            # This must match the feature set the saved GRU model was trained on.
            loaded_whitelist = gru_dependencies['whitelist']
            missing_in_eng = [f for f in loaded_whitelist if f not in df_engineered.columns]
            if missing_in_eng:
                raise ValueError(f"Features from loaded whitelist missing in engineered data: {missing_in_eng}")
            df_features = df_engineered[loaded_whitelist]
            logger.info("Pruned features using loaded whitelist.")

            # 4. Define Labels
            horizon = self.config['gru'].get('prediction_horizon', 5)
            target_ret_col = f'fwd_log_ret_{horizon}'
            target_dir_col = f'direction_label_{horizon}'
            df_engineered[target_ret_col] = np.log(df_engineered['close'].shift(-horizon) / df_engineered['close'])
            df_engineered[target_dir_col] = (df_engineered[target_ret_col] > 0).astype(int)
            # Align by dropping NaNs in targets AND ensuring indices match features
            df_engineered.dropna(subset=[target_ret_col, target_dir_col], inplace=True)
            common_index = df_features.index.intersection(df_engineered.index)
            if common_index.empty:
                raise ValueError("No common index between features and targets after label definition.")
            df_features = df_features.loc[common_index]
            df_targets = df_engineered.loc[common_index, [target_ret_col, target_dir_col]]
            logger.info("Defined labels and aligned features/targets.")

            # 5. Split Data (to get validation set indices)
            # NOTE(review): split indices must reproduce the original GRU run's split
            # exactly; this relies on identical data/config — confirm when resuming old runs.
            split_cfg = self.config['split_ratios']
            train_ratio, val_ratio = split_cfg['train'], split_cfg['validation']
            total_len = len(df_features)
            train_end_idx = int(total_len * train_ratio)
            val_end_idx = int(total_len * (train_ratio + val_ratio))
            val_indices = df_features.index[train_end_idx:val_end_idx]
            if val_indices.empty: raise ValueError("Validation split resulted in empty indices.")
            X_val_pruned = df_features.loc[val_indices]
            y_val = df_targets.loc[val_indices]
            logger.info(f"Isolated validation set data (Features: {X_val_pruned.shape}, Targets: {y_val.shape}).")

            # 6. Scale Validation Features using *loaded* scaler
            scaler = gru_dependencies['scaler']
            numeric_cols = X_val_pruned.select_dtypes(include=np.number).columns
            X_val_scaled = X_val_pruned.copy()
            if not numeric_cols.empty:
                X_val_scaled[numeric_cols] = scaler.transform(X_val_pruned[numeric_cols])
            logger.info("Scaled validation features using loaded scaler.")

            # 7. Create Validation Sequences
            lookback = self.config['gru']['lookback']
            X_val_seq = []
            y_val_seq_targets = [] # Store corresponding targets
            val_seq_indices = [] # Store corresponding indices
            features_np = X_val_scaled.values
            targets_np = y_val.values # Contains both ret and dir
            for i in range(lookback, len(features_np)):
                X_val_seq.append(features_np[i-lookback : i])
                y_val_seq_targets.append(targets_np[i]) # Target corresponds to end of sequence
                val_seq_indices.append(y_val.index[i])

            if not X_val_seq:
                raise ValueError("Validation sequence creation resulted in empty list.")

            X_val_seq = np.array(X_val_seq)
            y_val_seq_targets = np.array(y_val_seq_targets)
            actual_ret_val_seq = y_val_seq_targets[:, 0] # First column is return
            y_dir_val_seq = y_val_seq_targets[:, 1] # Second column is direction
            # NOTE(review): y_dir_val_seq and val_seq_indices are built but unused below.
            logger.info(f"Created validation sequences (X shape: {X_val_seq.shape}).")

            # 8. Get GRU Predictions on Validation Sequences using *loaded* GRU model
            gru_model = gru_dependencies['gru_model']
            # Use a temporary handler instance with the loaded model
            temp_gru_handler = GRUModelHandler(run_id="temp_predict", models_dir="temp")
            temp_gru_handler.model = gru_model # Assign the loaded model
            predictions_val = temp_gru_handler.predict(X_val_seq)
            if predictions_val is None or len(predictions_val) < 3:
                raise ValueError("GRU prediction on validation sequences failed.")
            # NOTE(review): assumes predict() returns [mu, dist_params, p_raw] and that
            # column 1 of the second head holds log-sigma — confirm against GRUModelHandler.
            mu_val_pred = predictions_val[0].flatten()
            log_sigma_val_pred = predictions_val[1][:, 1].flatten()
            p_raw_val_pred = predictions_val[2].flatten()
            sigma_val_pred = np.exp(log_sigma_val_pred)
            logger.info("Generated GRU predictions on validation sequences.")

            # Verify lengths
            n_seq = len(X_val_seq)
            if not (len(mu_val_pred) == n_seq and len(sigma_val_pred) == n_seq and \
                    len(p_raw_val_pred) == n_seq and len(actual_ret_val_seq) == n_seq):
                raise ValueError(f"Length mismatch after validation predictions: Expected {n_seq}, got mu={len(mu_val_pred)}, sigma={len(sigma_val_pred)}, p_raw={len(p_raw_val_pred)}, ret={len(actual_ret_val_seq)}")

            # 9. Calibrate Predictions using *loaded* optimal_T
            optimal_T = gru_dependencies['optimal_T']
            # Use a temporary calibrator instance
            temp_calibrator = Calibrator(edge_threshold=0.5) # Edge threshold doesn't matter here
            temp_calibrator.optimal_T = optimal_T
            p_cal_val_pred = temp_calibrator.calibrate(p_raw_val_pred)
            logger.info(f"Calibrated validation predictions using loaded T={optimal_T:.4f}.")

            # 10. Return the necessary components for the TradingEnv
            logger.info("--- Successfully prepared validation data for SAC Environment ---")
            return mu_val_pred, sigma_val_pred, p_cal_val_pred, actual_ret_val_seq

        except Exception as e:
            logger.error(f"Error preparing data for SAC environment: {e}", exc_info=True)
            return None

    def _load_agent_for_resume(self, agent: SACTradingAgent) -> None:
        """Loads agent weights if resuming is specified in config."""
        load_run_id = self.control_cfg.get('sac_resume_run_id')
        load_step = self.control_cfg.get('sac_resume_step', 'final')
        current_edge_threshold = self.config.get('calibration', {}).get('edge_threshold', 0.55)

        if not load_run_id:
            logger.info("No SAC resume run ID specified. Starting training from scratch.")
            return

        # Construct path relative to base models dir
        if load_step == 'final':
            load_path = os.path.join(self.base_models_dir, f"{load_run_id}", 'sac_agent_final')
        else:
            try:
                step_num = int(load_step)
                load_path = os.path.join(self.base_models_dir, f"{load_run_id}", f'sac_agent_step_{step_num}')
            except ValueError:
                logger.error(f"Invalid sac_resume_step: {load_step}. Must be 'final' or an integer. Starting fresh.")
                return

        logger.info(f"Attempting to load SAC agent from {load_path} to resume training...")
        if os.path.exists(load_path):
            try:
                loaded_meta = agent.load(load_path)
                # Check for Buffer Purge on Load: a changed edge threshold invalidates
                # old experiences, so the replay buffer is cleared on mismatch.
                saved_edge_thr = loaded_meta.get('edge_threshold_config')
                if saved_edge_thr is not None and abs(saved_edge_thr - current_edge_threshold) > 1e-6:
                    logger.warning(f'Edge threshold mismatch on load (Saved={saved_edge_thr:.3f}, Current={current_edge_threshold:.3f}). Clearing replay buffer before resuming.')
                    agent.clear_buffer()
                elif saved_edge_thr is None:
                    logger.warning("Loaded SAC agent metadata did not contain 'edge_threshold_config'. Cannot verify consistency.")
                else:
                    logger.info('Edge threshold consistent with loaded agent metadata.')
            except Exception as e:
                logger.error(f"Failed to load SAC agent for resume: {e}. Starting fresh.", exc_info=True)
        else:
            logger.warning(f"SAC agent path not found for resume: {load_path}. Starting fresh.")

    def _training_loop(self, agent: SACTradingAgent, env: TradingEnv) -> str | None:
        """Runs the main SAC training loop."""
        total_steps = self.sac_cfg.get('total_training_steps', 100000)
        batch_size = self.sac_cfg.get('batch_size', 256)
        # NOTE(review): log_interval is read but never used in this loop.
        log_interval = self.sac_cfg.get('log_interval', 1000)
        save_interval = self.sac_cfg.get('save_interval', 10000)
        generate_new_on_epoch = self.config.get('experience', {}).get('generate_new_on_epoch', False)
        epoch_len = env.n_steps

        # Initialize TensorBoard
        summary_writer = None
        try:
            summary_writer = tf.summary.create_file_writer(self.sac_tb_log_dir)
            log_to_tensorboard = True
            logger.info(f"TensorBoard logging initialized to: {self.sac_tb_log_dir}")
        except Exception as e:
            logger.error(f"Failed to initialize TensorBoard: {e}")
            log_to_tensorboard = False

        logger.info(f"Starting SAC training loop: {total_steps=}, {batch_size=}, {log_interval=}, {save_interval=}")
        state = env.reset()
        episode_reward = 0
        episode_steps = 0
        episode_num = 0  # NOTE(review): counted but never reported.
        all_rewards = []
        final_save_path = None # Track the last save path

        for step in tqdm(range(total_steps), desc="SAC Training Steps"):
            # NOTE(review): despite the flag name, this only clears the buffer here;
            # fresh experiences come from continued env interaction afterwards.
            if generate_new_on_epoch and step > 0 and step % epoch_len == 0:
                logger.info(f"Start of epoch {step // epoch_len + 1}. Clearing replay buffer.")
                agent.clear_buffer()

            action = agent.get_action(state, deterministic=False)
            # action[0]: env.step expects a scalar — assumes a 1-element action vector.
            next_state, reward, done, info = env.step(action[0])
            agent.buffer.add(state, action, reward, next_state, float(done))
            state = next_state
            episode_reward += reward
            episode_steps += 1

            # Only start gradient updates once the buffer holds enough samples.
            if len(agent.buffer) >= agent.min_buffer_size:
                metrics = agent.train(batch_size)
                if metrics and log_to_tensorboard and summary_writer:
                    with summary_writer.as_default(step=step):
                        for key, value in metrics.items():
                            tf.summary.scalar(f'SAC_Metrics/{key}', value)

            if done:
                if log_to_tensorboard and summary_writer:
                    with summary_writer.as_default(step=step):
                        tf.summary.scalar('SAC_Episode/Reward', episode_reward)
                        tf.summary.scalar('SAC_Episode/Steps', episode_steps)
                all_rewards.append(episode_reward)
                state = env.reset()
                episode_reward = 0
                episode_steps = 0
                episode_num += 1

            if (step + 1) % save_interval == 0:
                save_path = os.path.join(self.sac_run_models_dir, f'sac_agent_step_{step+1}')
                os.makedirs(save_path, exist_ok=True)
                agent.save(save_path)
                logger.info(f"SAC agent weights saved at step {step+1} to {save_path}")
                final_save_path = save_path # Update last saved path

        # Save final agent
        logger.info("Saving final SAC agent...")
        final_save_path = os.path.join(self.sac_run_models_dir, 'sac_agent_final')
        os.makedirs(final_save_path, exist_ok=True)
        agent.save(final_save_path)
        logger.info(f"Final SAC agent weights saved to {final_save_path}")

        # Save rewards
        rewards_df = pd.DataFrame({'episode_reward': all_rewards})
        rewards_df.to_csv(os.path.join(self.sac_run_results_dir, f'sac_episode_rewards_{self.sac_train_run_id}.csv'))

        if summary_writer: summary_writer.close()
        env.close()
        return final_save_path

    def train(self, gru_run_id_for_sac: str) -> str | None:
        """
        Main entry point to start the SAC training process.

        Args:
            gru_run_id_for_sac (str): The run ID of the GRU pipeline run whose artifacts should be used.

        Returns:
            str | None: Path to the final saved SAC agent, or None if training failed.
        """
        logger.info(f"=== Starting SAC Training Process (SAC Run ID: {self.sac_train_run_id}) ===")
        logger.info(f"Using artifacts from GRU Run ID: {gru_run_id_for_sac}")

        # 1. Load GRU dependencies
        gru_dependencies = self._load_gru_dependencies(gru_run_id_for_sac)
        if gru_dependencies is None:
            logger.error("Failed to load GRU dependencies. Aborting SAC training.")
            return None

        # 2. Prepare data for SAC environment (using validation set)
        env_data = self._prepare_data_for_sac(gru_dependencies)
        if env_data is None:
            logger.error("Failed to prepare data for SAC environment. Aborting SAC training.")
            return None
        mu_val, sigma_val, p_cal_val, actual_ret_val = env_data

        # 3. Initialize Environment
        logger.info("Initializing Trading Environment...")
        env = TradingEnv(
            mu_predictions=mu_val,
            sigma_predictions=sigma_val,
            p_cal_predictions=p_cal_val,
            actual_returns=actual_ret_val,
            initial_capital=self.env_cfg.get('initial_capital', 10000.0),
            transaction_cost=self.env_cfg.get('transaction_cost', 0.0005)
        )
        logger.info(f"TradingEnv initialized with {env.n_steps} steps.")

        # 4. Initialize SAC Agent
        # NOTE(review): the hard-coded fallbacks (gamma=0.99, tau=0.005) differ from the
        # values shipped in config.yaml (0.97, 0.02); they only apply if the keys are missing.
        logger.info("Initializing SAC Agent...")
        current_edge_threshold = self.config.get('calibration', {}).get('edge_threshold', 0.55)
        agent = SACTradingAgent(
            state_dim=env.state_dim,
            action_dim=env.action_dim,
            gamma=self.sac_cfg.get('gamma', 0.99),
            tau=self.sac_cfg.get('tau', 0.005),
            initial_lr=self.sac_cfg.get('actor_lr', 3e-4),
            lr_decay_rate=self.sac_cfg.get('lr_decay_rate', 0.96),
            decay_steps=self.sac_cfg.get('decay_steps', 100000),
            buffer_capacity=self.sac_cfg.get('buffer_max_size', 100000),
            ou_noise_stddev=self.sac_cfg.get('ou_noise_stddev', 0.2),
            alpha=self.sac_cfg.get('alpha', 0.2),
            alpha_auto_tune=self.sac_cfg.get('alpha_auto_tune', True),
            target_entropy=self.sac_cfg.get('target_entropy', -1.0 * env.action_dim),
            min_buffer_size=self.sac_cfg.get('min_buffer_size', 1000),
            edge_threshold_config=current_edge_threshold # Pass edge threshold
        )
        logger.info("SAC Agent initialized.")

        # 5. Load agent weights if resuming
        self._load_agent_for_resume(agent)

        # 6. Run training loop
        final_agent_path = self._training_loop(agent, env)

        if final_agent_path:
            logger.info(f"=== SAC Training Process Completed Successfully ===")
        else:
            logger.error("=== SAC Training Process Failed ===")

        return final_agent_path
# (no newline at end of file)
# diff --git a/gru_sac_predictor/src/trading_env.py b/gru_sac_predictor/src/trading_env.py
# new file mode 100644
# index 00000000..36c6a4d9
# --- /dev/null
# +++ b/gru_sac_predictor/src/trading_env.py
"""
Simplified Trading Environment for SAC Training.

Uses pre-calculated GRU predictions (mu, sigma, p_cal) and actual returns.
+""" +import numpy as np +import pandas as pd +import logging + +env_logger = logging.getLogger(__name__) + +class TradingEnv: + def __init__(self, + mu_predictions: np.ndarray, + sigma_predictions: np.ndarray, + p_cal_predictions: np.ndarray, + actual_returns: np.ndarray, + initial_capital: float = 10000.0, + transaction_cost: float = 0.0005): + """ + Initialize the environment. + + Args: + mu_predictions: Predicted log returns (μ̂). + sigma_predictions: Predicted volatility (σ̂ = exp(log σ̂)). + p_cal_predictions: Calibrated probability of price increase (p_cal). + actual_returns: Actual log returns (y_ret). + initial_capital: Starting capital for simulation (used notionally in reward). + transaction_cost: Fractional cost per trade. + """ + assert len(mu_predictions) == len(sigma_predictions) == len(p_cal_predictions) == len(actual_returns), \ + "All input arrays must have the same length" + + self.mu = mu_predictions + self.sigma = sigma_predictions + self.p_cal = p_cal_predictions + self.actual_returns = actual_returns + + self.initial_capital = initial_capital + self.transaction_cost = transaction_cost + + self.n_steps = len(actual_returns) + self.current_step = 0 + self.current_position = 0.0 # Fraction of capital (-1 to 1) + self.current_capital = initial_capital # Track for info, not used in reward directly + + # State dimension: [mu, sigma, edge, |mu|/sigma, position] + self.state_dim = 5 + self.action_dim = 1 + + env_logger.info(f"TradingEnv initialized with {self.n_steps} steps.") + + def _get_state(self) -> np.ndarray: + """Construct the state vector for the current step.""" + if self.current_step >= self.n_steps: + # Handle episode end - return a dummy state or zeros + return np.zeros(self.state_dim, dtype=np.float32) + + mu_t = self.mu[self.current_step] + sigma_t = self.sigma[self.current_step] + p_cal_t = self.p_cal[self.current_step] + + edge_t = 2 * p_cal_t - 1 + z_score_t = np.abs(mu_t) / (sigma_t + 1e-9) + + # State uses position *before* the 
action for this step is taken + state = np.array([ + mu_t, + sigma_t, + edge_t, + z_score_t, + self.current_position + ], dtype=np.float32) + return state + + def reset(self) -> np.ndarray: + """Reset the environment to the beginning.""" + self.current_step = 0 + self.current_position = 0.0 + self.current_capital = self.initial_capital + env_logger.debug("Environment reset.") + return self._get_state() + + def step(self, action: float) -> tuple[np.ndarray, float, bool, dict]: + """ + Execute one time step. + + Args: + action (float): Agent's desired position size (-1 to 1). + + Returns: + tuple: (next_state, reward, done, info_dict) + """ + if self.current_step >= self.n_steps: + # Should not happen if 'done' is handled correctly, but as safeguard + env_logger.warning("Step called after environment finished.") + return self._get_state(), 0.0, True, {} + + # Action is the TARGET position for the *end* of this step + target_position = np.clip(action, -1.0, 1.0) + trade_size = target_position - self.current_position + + # Calculate PnL based on position held *during* this step + step_actual_return = self.actual_returns[self.current_step] + # Use simple return for PnL calculation: exp(log_ret) - 1 + pnl_fraction = self.current_position * (np.exp(step_actual_return) - 1) + + # Calculate transaction costs for the trade executed now + cost_fraction = abs(trade_size) * self.transaction_cost + + # Reward is net PnL fraction (doesn't scale with capital directly) + reward = pnl_fraction - cost_fraction + + # Update internal state for the *next* step + self.current_position = target_position + self.current_capital *= (1 + pnl_fraction - cost_fraction) # Update tracked capital + self.current_step += 1 + + # Check if done + done = self.current_step >= self.n_steps or self.current_capital <= 0 + + next_state = self._get_state() + info = {'capital': self.current_capital, 'position': self.current_position} + + # Log step details periodically + # if self.current_step % 1000 == 0: + 
# env_logger.debug(f"Step {self.current_step}: Action={action:.2f}, Pos={self.current_position:.2f}, Ret={step_actual_return:.5f}, Rew={reward:.5f}, Cap={self.current_capital:.2f}") + + if done: + env_logger.info(f"Environment finished at step {self.current_step}. Final Capital: {self.current_capital:.2f}") + + return next_state, reward, done, info + + def close(self): + """Clean up any resources (if needed).""" + env_logger.info("TradingEnv closed.") + pass \ No newline at end of file diff --git a/gru_sac_predictor/src/trading_pipeline.py b/gru_sac_predictor/src/trading_pipeline.py new file mode 100644 index 00000000..36041172 --- /dev/null +++ b/gru_sac_predictor/src/trading_pipeline.py @@ -0,0 +1,1003 @@ +""" +Main Orchestrator for the Trading Pipeline. + +Coordinates data loading, feature engineering, model training, calibration, +SAC training, and backtesting. +""" + +import os +import sys +import logging +import yaml +import pandas as pd +import numpy as np +from datetime import datetime +import argparse +import joblib +import json + +# Determine the project root directory based on the script location +# This assumes the script is in src/ and the project root is two levels up +script_dir = os.path.dirname(os.path.abspath(__file__)) +project_root = os.path.dirname(os.path.dirname(script_dir)) +# Add project root to sys.path to allow absolute imports from the package +if project_root not in sys.path: + sys.path.insert(0, project_root) + +# Now use absolute imports based on the package structure +from gru_sac_predictor.src.data_loader import DataLoader +# Import other components as they are created +from gru_sac_predictor.src.feature_engineer import FeatureEngineer +# Try importing minimal_whitelist from features.py (assuming it exists there) +try: + from gru_sac_predictor.src.features import minimal_whitelist +except ImportError: + # Fallback: Define it here if features.py doesn't exist or doesn't have it + logging.warning("Could not import minimal_whitelist 
from .features, defining fallback.")
    minimal_whitelist = [
        "return_1m", "return_15m", "return_60m", "ATR_14", "volatility_14d",
        "chaikin_AD_10", "svi_10", "EMA_10", "EMA_50", "MACD", "MACD_signal",
        "hour_sin", "hour_cos",
    ]
from gru_sac_predictor.src.gru_model_handler import GRUModelHandler
from gru_sac_predictor.src.calibrator import Calibrator
from gru_sac_predictor.src.sac_trainer import SACTrainer
from gru_sac_predictor.src.backtester import Backtester

# Removed redundant imports for feature selection
from sklearn.preprocessing import StandardScaler


logger = logging.getLogger(__name__) # Use module-level logger

class TradingPipeline:
    """Orchestrates the entire trading strategy pipeline."""

    def __init__(self, config_path: str):
        """Initialize the pipeline with configuration.

        Args:
            config_path (str): Path to the YAML config file (resolved against
                several candidate locations by _load_config).
        """
        self.config_path = config_path
        self.config = self._load_config()
        self.run_id = self._generate_run_id()
        self._setup_directories()
        self._setup_logging()
        logging.info(f"--- Starting Pipeline Run: {self.run_id} ---")
        logging.info(f"Using config: {self.config_path}")

        # Instantiate components
        db_dir_from_config = self.config['data']['db_dir']
        self.data_loader = DataLoader(db_dir=db_dir_from_config)
        self.feature_engineer = FeatureEngineer(minimal_whitelist=minimal_whitelist)
        self.gru_handler = GRUModelHandler(run_id=self.run_id, models_dir=self.current_run_models_dir)
        cal_cfg = self.config.get('calibration', {})
        self.calibrator = Calibrator(edge_threshold=cal_cfg.get('edge_threshold', 0.55))
        self.sac_trainer = None
        self.backtester = Backtester(config=self.config)

        # Initialize data/state variables (populated by the pipeline stages below)
        self.df_raw = None
        self.df_engineered_full = None
        self.df_features_minimal = None
        self.df_targets = None
        self.df_train = None
        self.df_val = None
        self.df_test = None
        self.X_train_raw = None
        self.X_val_raw = None
        self.X_test_raw = None
        self.y_train = None
        self.y_val = None
        self.y_test = None
        self.y_dir_train = None
        self.final_whitelist = None
        self.scaler = None
        self.X_train_pruned = None
        self.X_val_pruned = None
        self.X_test_pruned = None
        self.X_train_scaled = None
        self.X_val_scaled = None
        self.X_test_scaled = None
        self.X_train_seq, self.y_train_seq_dict = None, None
        self.X_val_seq, self.y_val_seq_dict = None, None
        self.X_test_seq, self.y_test_seq_dict = None, None
        self.gru_model = None
        self.gru_model_run_id_loaded_from = None
        self.optimal_T = None
        self.sac_agent_load_path = None
        self.train_indices, self.val_indices, self.test_indices = None, None, None
        self.backtest_results_df = None
        self.backtest_metrics = None

        self._save_run_config()

    def _load_config(self) -> dict:
        """Loads the YAML configuration file.

        Tries several locations in order (relative to script dir, project root,
        CWD, then a common sibling layout) and updates self.config_path to the
        path that was actually used. Exits the process on failure.
        """
        try:
            # Try loading relative to the script first (if running from src)
            if not os.path.isabs(self.config_path):
                potential_path = os.path.join(script_dir, self.config_path)
                if not os.path.exists(potential_path):
                    # If not found relative to script, try relative to project root
                    potential_path = os.path.join(project_root, self.config_path)
                    if not os.path.exists(potential_path):
                        # If still not found, try relative to CWD as last resort
                        potential_path = os.path.abspath(self.config_path)

                if os.path.exists(potential_path):
                    self.config_path = potential_path
                else:
                    # Try one level up from project root (common structure)
                    potential_path = os.path.join(os.path.dirname(project_root), 'gru_sac_predictor', 'config.yaml')
                    if os.path.exists(potential_path):
                        self.config_path = potential_path
                    else:
                        raise FileNotFoundError(f"Config file not found at relative paths, CWD, or common location: {self.config_path}")

            with open(self.config_path, 'r') as f:
                config = yaml.safe_load(f)
            # Basic validation
            if 'data' not in config or 'gru' not in config or 'sac' not in config:
                raise ValueError("Config file missing essential sections: data, gru, sac")
            # Validate calibration config if present
            if 'calibration' in config and 'edge_threshold' not in config['calibration']:
                logging.warning("'edge_threshold' not found in calibration config, using default 0.55")
                config['calibration']['edge_threshold'] = 0.55 # Add default if missing
            elif 'calibration' not in config:
                logging.warning("'calibration' section not found in config, using default edge_threshold 0.55")
                config['calibration'] = {'edge_threshold': 0.55} # Add default section

            return config
        except FileNotFoundError:
            print(f"ERROR: Configuration file not found at '{self.config_path}'")
            sys.exit(1)
        except yaml.YAMLError as e:
            print(f"ERROR: Error parsing configuration file '{self.config_path}': {e}")
            sys.exit(1)
        except Exception as e:
            print(f"ERROR: An unexpected error occurred while loading config: {e}")
            sys.exit(1)

    def _generate_run_id(self) -> str:
        """Generates a unique run ID based on the template in config."""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        template = self.config.get('run_id_template', '{timestamp}')
        return template.format(timestamp=timestamp)

    def _setup_directories(self):
        """Creates directories for logs, results, and models for this run."""
        self.dirs = {}
        base_dirs_config = self.config.get('base_dirs', {})
        # Calculate base models dir path (needed for loading previous models)
        # Assume it's relative to project root
        models_rel_path = base_dirs_config.get('models', 'models')
        self.base_models_dir_path = os.path.join(project_root, models_rel_path)

        for dir_type, rel_path in base_dirs_config.items():
            # Paths are relative to the project root
            abs_path = os.path.join(project_root, rel_path, self.run_id)
            os.makedirs(abs_path, exist_ok=True)
            self.dirs[dir_type] = abs_path
            # No need to log here, happens in _setup_logging

        # Specific dir for current run models (if models base dir exists)
        if 'models' in self.dirs:
            self.current_run_models_dir = self.dirs['models']
        else:
            # Fallback if models base dir not in config
            self.current_run_models_dir = os.path.join(self.base_models_dir_path, self.run_id) # Use calculated base path
            os.makedirs(self.current_run_models_dir, exist_ok=True)
            # Log this warning after logging is set up
            # logging.warning(f"'models' base dir not found in config, using default: {self.current_run_models_dir}")

    def _setup_logging(self):
        """Configures logging to file and console."""
        log_dir = self.dirs.get('logs')
        if not log_dir:
            print(f"Warning: 'logs' directory not configured. Logging to console only.")
            log_file_path = None
        else:
            log_file_path = os.path.join(log_dir, f'pipeline_{self.run_id}.log')

        # Remove existing handlers to avoid duplicate logs if re-initialized
        for handler in logging.root.handlers[:]:
            logging.root.removeHandler(handler)

        log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        log_level = logging.INFO # Consider making this configurable

        handlers = [logging.StreamHandler(sys.stdout)]
        if log_file_path:
            handlers.append(logging.FileHandler(log_file_path))

        logging.basicConfig(level=log_level, format=log_format, handlers=handlers)

        # Configure TensorFlow logging (optional, reduces verbosity)
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # ERROR messages only
        logging.getLogger('tensorflow').setLevel(logging.ERROR)

        # Now log the directory paths
        logging.info(f"Using Base Models Directory: {self.base_models_dir_path}")
        for dir_type, abs_path in self.dirs.items():
            logging.info(f"Using {dir_type} directory: {abs_path}")
        if 'models' not in self.dirs:
            logging.warning(f"'models' base dir not found in config, using default models dir: {self.current_run_models_dir}")

        logging.info(f"Logging setup complete. Log file: {log_file_path if log_file_path else 'Console only'}")

    def _save_run_config(self):
        """Saves the configuration used for this run."""
        results_dir = self.dirs.get('results')
        if results_dir:
            config_save_path = os.path.join(results_dir, 'run_config.yaml')
            try:
                with open(config_save_path, 'w') as f:
                    yaml.dump(self.config, f, default_flow_style=False)
                logging.info(f"Saved run configuration to {config_save_path}")
            except Exception as e:
                logging.error(f"Failed to save run configuration: {e}")
        else:
            logging.warning("'results' directory not configured. Skipping saving run config.")

    # --- Pipeline Stages ---

    def load_and_preprocess_data(self):
        """Loads raw data using DataLoader and performs initial checks."""
        logging.info("--- Stage: Loading and Preprocessing Data ---")
        data_cfg = self.config['data']
        self.df_raw = self.data_loader.load_data(
            ticker=data_cfg['ticker'],
            exchange=data_cfg['exchange'],
            start_date=data_cfg['start_date'],
            end_date=data_cfg['end_date'],
            interval=data_cfg['interval']
        )
        if self.df_raw is None or self.df_raw.empty:
            logging.error("Failed to load data. Exiting.")
            sys.exit(1)

        # Basic checks
        if not isinstance(self.df_raw.index, pd.DatetimeIndex):
            logging.error("Data index is not a DatetimeIndex after loading. Exiting.")
            sys.exit(1)
        # NOTE(review): `.tz.zone` is a pytz-specific attribute — confirm the loader
        # always yields pytz timezones (stdlib/zoneinfo tz objects have no .zone).
        if self.df_raw.index.tz is None or self.df_raw.index.tz.zone.upper() != 'UTC': # Case-insensitive check
            logging.warning(f"Data index timezone is not UTC ({self.df_raw.index.tz}). Attempting conversion.")
            try:
                if self.df_raw.index.tz is None:
                    self.df_raw = self.df_raw.tz_localize('UTC')
                else:
                    self.df_raw = self.df_raw.tz_convert('UTC')
                logging.info(f"Data index timezone converted to UTC.")
            except Exception as e:
                logging.error(f"Failed to convert index timezone to UTC: {e}. Exiting.")
                sys.exit(1)

        # Drop rows with NaN in essential OHLCV columns early
        initial_rows = len(self.df_raw)
        self.df_raw.dropna(subset=['open', 'high', 'low', 'close', 'volume'], inplace=True)
        if len(self.df_raw) < initial_rows:
            logging.warning(f"Dropped {initial_rows - len(self.df_raw)} rows with NaN in OHLCV during loading.")

        logging.info(f"Raw data loaded successfully: {self.df_raw.shape[0]} rows from {self.df_raw.index.min()} to {self.df_raw.index.max()}")

    def engineer_features(self):
        """Adds features using FeatureEngineer."""
        logging.info("--- Stage: Engineering Features ---")
        if self.df_raw is None:
            logging.error("Raw data not loaded. Cannot engineer features.")
            sys.exit(1)

        # Add base features (cyclical, imbalance, TA)
        self.df_engineered_full = self.feature_engineer.add_base_features(self.df_raw.copy())

        # Prune to minimal whitelist immediately for comparison/logging if needed
        # self.df_features_minimal = self.feature_engineer.prune_features(self.df_engineered_full, self.feature_engineer.minimal_whitelist)
        # logger.info(f"Features pruned to minimal whitelist ({self.df_features_minimal.shape[1]}): {self.df_features_minimal.columns.tolist()}")

        # Drop rows with NaNs potentially introduced by feature engineering (especially rolling features at the start)
        initial_rows = len(self.df_engineered_full)
        self.df_engineered_full.dropna(inplace=True)
        if len(self.df_engineered_full) < initial_rows:
            logging.warning(f"Dropped {initial_rows - len(self.df_engineered_full)} rows with NaN values after feature engineering.")

        if self.df_engineered_full.empty:
            logging.error("DataFrame is empty after feature engineering and NaN removal. Exiting.")
            sys.exit(1)

        logging.info(f"Feature engineering complete. Shape: {self.df_engineered_full.shape}")

    def define_labels_and_align(self):
        """Defines prediction labels (returns, direction) and aligns with features."""
        logging.info("--- Stage: Defining Labels and Aligning ---")
        if self.df_engineered_full is None:
            logging.error("Engineered data not available. Cannot define labels.")
            sys.exit(1)

        # Calculate forward returns and direction based on 'close' price
        # Note: Ensure 'close' is present from the original raw data load
        if 'close' not in self.df_engineered_full.columns:
            logging.error("'close' column missing in engineered data. Cannot define labels.")
            sys.exit(1)

        horizon = self.config['gru'].get('prediction_horizon', 5) # Default horizon if not in config
        target_ret_col = f'fwd_log_ret_{horizon}'
        target_dir_col = f'direction_label_{horizon}'

        # Shift close price into the future by 'horizon' periods
        shifted_close = self.df_engineered_full['close'].shift(-horizon)

        # Calculate log return
        self.df_engineered_full[target_ret_col] = np.log(shifted_close / self.df_engineered_full['close'])

        # Calculate direction label (1 if future price > current price, 0 otherwise)
        self.df_engineered_full[target_dir_col] = (self.df_engineered_full[target_ret_col] > 0).astype(int)

        # Drop rows where targets are NaN (due to the shift at the end of the DataFrame)
        initial_rows = len(self.df_engineered_full)
        self.df_engineered_full.dropna(subset=[target_ret_col, target_dir_col], inplace=True)
        final_rows = len(self.df_engineered_full)
        if final_rows < initial_rows:
            logging.info(f"Dropped {initial_rows - final_rows} rows due to NaN targets (horizon={horizon}).")

        if self.df_engineered_full.empty:
            logging.error("DataFrame is empty after defining labels and dropping NaNs. Exiting.")
            sys.exit(1)

        # Separate features (X) and targets (y) - X contains all engineered features for now
        self.X_raw_aligned = self.df_engineered_full.drop(columns=[target_ret_col, target_dir_col])
        self.y_aligned = self.df_engineered_full[[target_ret_col, target_dir_col]]
        self.y_dir_aligned = self.df_engineered_full[target_dir_col] # Keep separate handle for feature selection

        logging.info(f"Labels (horizon={horizon}) defined and aligned. Features shape: {self.X_raw_aligned.shape}, Targets shape: {self.y_aligned.shape}")

    def split_data(self):
        """Splits features and targets into train, validation, and test sets chronologically."""
        logging.info("--- Stage: Splitting Data ---")
        if self.X_raw_aligned is None or self.y_aligned is None:
            logging.error("Aligned features/targets not available for splitting.")
            sys.exit(1)
        if not isinstance(self.X_raw_aligned.index, pd.DatetimeIndex):
            logging.error("Feature index must be DatetimeIndex for chronological split. Aborting.")
            sys.exit(1)

        split_cfg = self.config['split_ratios']
        train_ratio = split_cfg['train']
        val_ratio = split_cfg['validation']
        test_ratio = round(1.0 - train_ratio - val_ratio, 2)
        # NOTE(review): uses module `logger` here while the rest of the class
        # uses the root `logging` — inconsistent but functionally equivalent.
        logger.info(f"Using split ratios: Train={train_ratio:.2f}, Val={val_ratio:.2f}, Test={test_ratio:.2f}")

        total_len = len(self.X_raw_aligned)
        train_end_idx = int(total_len * train_ratio)
        val_end_idx = int(total_len * (train_ratio + val_ratio))

        # Split features
        self.X_train_raw = self.X_raw_aligned.iloc[:train_end_idx]
        self.X_val_raw = self.X_raw_aligned.iloc[train_end_idx:val_end_idx]
        self.X_test_raw = self.X_raw_aligned.iloc[val_end_idx:]

        # Split targets
        self.y_train = self.y_aligned.iloc[:train_end_idx]
        self.y_val = self.y_aligned.iloc[train_end_idx:val_end_idx]
        self.y_test = self.y_aligned.iloc[val_end_idx:]

        # Keep separate handle to direction target for training feature selector
        self.y_dir_train = self.y_dir_aligned.iloc[:train_end_idx]

        logging.info(f"Data split complete:")
        logging.info(f"  Train: X={self.X_train_raw.shape}, y={self.y_train.shape} ({self.X_train_raw.index.min()} to {self.X_train_raw.index.max()})")
        logging.info(f"  Val:   X={self.X_val_raw.shape}, y={self.y_val.shape} ({self.X_val_raw.index.min()} to {self.X_val_raw.index.max()})")
        logging.info(f"  Test:  X={self.X_test_raw.shape}, y={self.y_test.shape} ({self.X_test_raw.index.min()} to {self.X_test_raw.index.max()})")

        if len(self.X_train_raw) == 0 or len(self.X_val_raw) == 0 or len(self.X_test_raw) == 0:
            logging.error("One or more data splits are empty. Check data length and split ratios. Aborting.")
            sys.exit(1)

    def select_and_prune_features(self):
        """Performs feature selection (e.g., VIF, L1) and prunes data splits."""
        logging.info("--- Stage: Selecting and Pruning Features ---")
        if self.X_train_raw is None or self.y_dir_train is None:
            logging.error("Training data (X_train_raw, y_dir_train) not available for feature selection.")
            sys.exit(1)

        # Perform feature selection using the training set
        self.final_whitelist = self.feature_engineer.select_features(
            self.X_train_raw,
            self.y_dir_train,
            # Optionally get VIF threshold from config? Defaulting for now.
            # vif_threshold=self.config.get('feature_selection', {}).get('vif_threshold', 10.0)
        )

        # --- Save the final whitelist --- # Should this be done here or in FeatureEngineer?
        # Let's keep it here for pipeline-level artifact saving.
        whitelist_save_path = os.path.join(self.current_run_models_dir, f'final_whitelist_{self.run_id}.json')
        try:
            with open(whitelist_save_path, 'w') as f:
                json.dump(self.final_whitelist, f, indent=4)
            logging.info(f"Saved final feature whitelist ({len(self.final_whitelist)} features) to {whitelist_save_path}")
        except Exception as e:
            logging.error(f"Failed to save final feature whitelist: {e}", exc_info=True)
            # Decide if this is critical - maybe abort if we can't save it?
            pass

        # Prune all data splits using the final whitelist
        logging.info(f"Pruning feature sets using final whitelist: {self.final_whitelist}")
        self.X_train_pruned = self.feature_engineer.prune_features(self.X_train_raw, self.final_whitelist)
        self.X_val_pruned = self.feature_engineer.prune_features(self.X_val_raw, self.final_whitelist)
        self.X_test_pruned = self.feature_engineer.prune_features(self.X_test_raw, self.final_whitelist)

        logging.info(f"Feature shapes after pruning: Train={self.X_train_pruned.shape}, Val={self.X_val_pruned.shape}, Test={self.X_test_pruned.shape}")

        # Verify all splits have the same columns after pruning
        if not (self.X_train_pruned.columns.equals(self.X_val_pruned.columns) and
                self.X_train_pruned.columns.equals(self.X_test_pruned.columns)):
            logging.error("Column mismatch between pruned data splits. Check pruning logic.")
            # Log details for debugging
            logging.error(f"Train cols: {self.X_train_pruned.columns.tolist()}")
            logging.error(f"Val cols: {self.X_val_pruned.columns.tolist()}")
            logging.error(f"Test cols: {self.X_test_pruned.columns.tolist()}")
            sys.exit(1)

        # Check if feature sets are empty after pruning
        if self.X_train_pruned.empty or self.X_val_pruned.empty or self.X_test_pruned.empty:
            logging.error("One or more feature splits are empty after pruning. 
Exiting.") + sys.exit(1) + + def scale_features(self): + """Scales features using StandardScaler fitted on the training set.""" + logging.info("--- Stage: Scaling Features ---") + if self.X_train_pruned is None or self.X_val_pruned is None or self.X_test_pruned is None: + logging.error("Pruned feature sets not available for scaling.") + sys.exit(1) + + scaler_path = os.path.join(self.current_run_models_dir, f'feature_scaler_{self.run_id}.joblib') + + # Ensure we only scale numeric columns + numeric_cols = self.X_train_pruned.select_dtypes(include=np.number).columns + if len(numeric_cols) < self.X_train_pruned.shape[1]: + non_numeric_cols = self.X_train_pruned.select_dtypes(exclude=np.number).columns + logging.warning(f"Non-numeric columns detected in pruned features: {non_numeric_cols.tolist()}. These will not be scaled.") + # If non-numeric columns exist, they should ideally be handled earlier (e.g., encoding) or excluded. + # For now, we proceed by scaling only numeric ones, but this might indicate an issue. 
+ + if not numeric_cols.empty: + # Check if scaler was loaded previously (when loading GRU) + if self.scaler is None: + logging.info("Fitting StandardScaler on training data (numeric columns only)...") + self.scaler = StandardScaler() + self.scaler.fit(self.X_train_pruned[numeric_cols]) + + # Save the fitted scaler + try: + joblib.dump(self.scaler, scaler_path) + logging.info(f"Feature scaler saved to {scaler_path}") + except Exception as e: + logging.error(f"Failed to save feature scaler: {e}") + else: + logging.info("Using pre-loaded scaler for feature scaling.") + + # Apply scaling to all splits (numeric columns only) + # Create copies to store scaled data, preserving original pruned dataframes + self.X_train_scaled = self.X_train_pruned.copy() + self.X_val_scaled = self.X_val_pruned.copy() + self.X_test_scaled = self.X_test_pruned.copy() + + self.X_train_scaled[numeric_cols] = self.scaler.transform(self.X_train_pruned[numeric_cols]) + self.X_val_scaled[numeric_cols] = self.scaler.transform(self.X_val_pruned[numeric_cols]) + self.X_test_scaled[numeric_cols] = self.scaler.transform(self.X_test_pruned[numeric_cols]) + logging.info("Features scaled successfully.") + else: + logging.warning("No numeric columns found to scale. 
Skipping scaling step.") + # Assign unscaled data to scaled variables to allow pipeline continuation + self.X_train_scaled = self.X_train_pruned + self.X_val_scaled = self.X_val_pruned + self.X_test_scaled = self.X_test_pruned + + def run_baseline_checks(self): + """(Optional) Runs baseline model checks.""" + logging.info("--- Stage: Baseline Checks (Placeholder) ---") + # Placeholder - Implement if needed, e.g., LogReg on minimal features + logging.warning("Baseline checks stage not implemented.") + + def create_sequences(self): + """Creates sequences for GRU input using scaled features and aligned targets.""" + logging.info("--- Stage: Creating Sequences ---") + if self.X_train_scaled is None or self.y_train is None or \ + self.X_val_scaled is None or self.y_val is None or \ + self.X_test_scaled is None or self.y_test is None: + logging.error("Scaled features or aligned targets not available for sequence creation.") + sys.exit(1) + + lookback = self.config['gru'].get('lookback', 60) + horizon = self.config['gru'].get('prediction_horizon', 5) # Needed to identify target columns + target_ret_col = f'fwd_log_ret_{horizon}' + target_dir_col = f'direction_label_{horizon}' + + logging.info(f"Creating sequences with lookback={lookback}") + + # Helper function adapted from run_pipeline.py + def _create_sequences_helper(features_scaled_df, targets_df, lookback, ret_col, dir_col): + # Convert DataFrames to numpy arrays for efficiency + features_np = features_scaled_df.values + # Select target columns and convert to numpy + y_ret_np = targets_df[ret_col].values + y_dir_np = targets_df[dir_col].values + + X, y_ret_seq, y_dir_seq = [], [], [] + # Store original indices corresponding to the *target* timestep + target_indices = [] + + # Iterate from lookback index up to the length of the features + for i in range(lookback, len(features_np)): + # Append the sequence of features [i-lookback : i] + X.append(features_np[i-lookback : i]) + # Append the target value at timestep i-1 
(target corresponds to the period *ending* at i) + # The features X[i-lookback : i] predict the target at time i. + y_ret_seq.append(y_ret_np[i]) + y_dir_seq.append(y_dir_np[i]) + # Store the index of the target timestep 'i' + target_indices.append(targets_df.index[i]) + + if not X: # Check if any sequences were created + return None, None, None, None + + # Convert lists to numpy arrays + X_np = np.array(X) + y_ret_seq_np = np.array(y_ret_seq) + y_dir_seq_np = np.array(y_dir_seq) + target_indices_pd = pd.Index(target_indices) # Keep as pandas Index + + return X_np, y_ret_seq_np, y_dir_seq_np, target_indices_pd + + # Create sequences for train, validation, and test sets + self.X_train_seq, y_ret_train_seq, y_dir_train_seq, self.train_indices = _create_sequences_helper( + self.X_train_scaled, self.y_train, lookback, target_ret_col, target_dir_col + ) + self.X_val_seq, y_ret_val_seq, y_dir_val_seq, self.val_indices = _create_sequences_helper( + self.X_val_scaled, self.y_val, lookback, target_ret_col, target_dir_col + ) + self.X_test_seq, y_ret_test_seq, y_dir_test_seq, self.test_indices = _create_sequences_helper( + self.X_test_scaled, self.y_test, lookback, target_ret_col, target_dir_col + ) + + # Check if sequences were created successfully + if self.X_train_seq is None or self.X_val_seq is None: + logger.error(f"Sequence creation resulted in empty train or val arrays. Check lookback ({lookback}) vs split sizes. 
Aborting.") + sys.exit(1) + + logging.info(f"Sequence shapes created:") + logging.info(f" Train: X={self.X_train_seq.shape}, y_ret={y_ret_train_seq.shape}, y_dir={y_dir_train_seq.shape}") + logging.info(f" Val: X={self.X_val_seq.shape}, y_ret={y_ret_val_seq.shape}, y_dir={y_dir_val_seq.shape}") + logging.info(f" Test: X={self.X_test_seq.shape if self.X_test_seq is not None else 'None'}, y_ret={y_ret_test_seq.shape if y_ret_test_seq is not None else 'None'}, y_dir={y_dir_test_seq.shape if y_dir_test_seq is not None else 'None'}") + + # Prepare target dictionaries required by the GRU model's training and evaluation + self.y_train_seq_dict = { + "ret": y_ret_train_seq, + "gauss_params": y_ret_train_seq, # Use ret target for NLL too + "dir": y_dir_train_seq + } + self.y_val_seq_dict = { + "ret": y_ret_val_seq, + "gauss_params": y_ret_val_seq, + "dir": y_dir_val_seq + } + # Test targets dictionary (useful for later evaluation/backtesting) + if y_ret_test_seq is not None and y_dir_test_seq is not None: + self.y_test_seq_dict = { + "ret": y_ret_test_seq, + "gauss_params": y_ret_test_seq, + "dir": y_dir_test_seq + } + else: + self.y_test_seq_dict = None + logging.warning("Test sequences or targets could not be created. Backtesting might fail.") + + def train_or_load_gru(self): + """Trains a new GRU model or loads a pre-trained one using GRUModelHandler.""" + logging.info("--- Stage: Training or Loading GRU Model ---") + gru_cfg = self.config['gru'] + train_gru_flag = self.config['control'].get('train_gru', False) + + if train_gru_flag: + logging.info(f"Attempting to train a new GRU model for run {self.run_id}...") + if self.X_train_seq is None or self.y_train_seq_dict is None or \ + self.X_val_seq is None or self.y_val_seq_dict is None: + logging.error("Sequence data (train/val) not available for GRU training. 
Exiting.") + sys.exit(1) + + # Get parameters from config + lookback = gru_cfg.get('lookback', 60) + # Get feature count from scaled data (use shape[2] for sequences) + n_features = self.X_train_seq.shape[2] + epochs = gru_cfg.get('epochs', 25) + batch_size = gru_cfg.get('batch_size', 128) + patience = gru_cfg.get('patience', 5) # Use gru patience from config + + # Train the model + self.gru_model, history = self.gru_handler.train( + X_train=self.X_train_seq, + y_train_dict=self.y_train_seq_dict, + X_val=self.X_val_seq, + y_val_dict=self.y_val_seq_dict, + lookback=lookback, + n_features=n_features, + max_epochs=epochs, + batch_size=batch_size, + patience=patience + ) + + if self.gru_model is None: + logging.error("GRU model training failed. Exiting.") + sys.exit(1) + else: + # Save the newly trained model + saved_path = self.gru_handler.save() # Uses run_id from handler + if saved_path: + logging.info(f"Newly trained GRU model saved to {saved_path}") + else: + logging.warning("Failed to save the newly trained GRU model.") + # Set the loaded ID to the current run ID + self.gru_model_run_id_loaded_from = self.run_id + logging.info(f"Using GRU model trained in current run: {self.run_id}") + + else: # Load pre-trained GRU model + load_run_id = gru_cfg.get('model_load_run_id', None) + if not load_run_id: + logging.error("train_gru is False, but no gru.model_load_run_id specified in config. Exiting.") + sys.exit(1) + + logging.info(f"Attempting to load pre-trained GRU model from run ID: {load_run_id}") + # Construct the expected path using the base models directory + model_filename = f'gru_model_{load_run_id}.keras' # Assuming .keras extension + model_path = os.path.join(self.base_models_dir_path, f'run_{load_run_id}', model_filename) + + # Load the model using the handler + self.gru_model = self.gru_handler.load(model_path) + + if self.gru_model is None: + logging.error(f"Failed to load GRU model from path: {model_path}. 
Exiting.") + sys.exit(1) + else: + self.gru_model_run_id_loaded_from = load_run_id + logging.info(f"Successfully loaded GRU model from run: {load_run_id}") + + # --- Try loading associated scaler --- # + scaler_filename = f'feature_scaler_{load_run_id}.joblib' + scaler_load_path = os.path.join(self.base_models_dir_path, f'run_{load_run_id}', scaler_filename) + logging.info(f"Attempting to load associated scaler from: {scaler_load_path}") + if os.path.exists(scaler_load_path): + try: + self.scaler = joblib.load(scaler_load_path) + logging.info("Associated feature scaler loaded successfully.") + # --- Re-apply scaling using the loaded scaler --- # + numeric_cols = self.X_train_pruned.select_dtypes(include=np.number).columns + if not numeric_cols.empty: + # Important: Scale the _pruned data before sequence creation if scaler is loaded + logging.info("Re-scaling features using loaded scaler...") + self.X_train_scaled = self.X_train_pruned.copy() + self.X_val_scaled = self.X_val_pruned.copy() + self.X_test_scaled = self.X_test_pruned.copy() + self.X_train_scaled[numeric_cols] = self.scaler.transform(self.X_train_pruned[numeric_cols]) + self.X_val_scaled[numeric_cols] = self.scaler.transform(self.X_val_pruned[numeric_cols]) + self.X_test_scaled[numeric_cols] = self.scaler.transform(self.X_test_pruned[numeric_cols]) + logging.info("Features re-scaled successfully.") + # Need to recreate sequences after re-scaling! + self.create_sequences() + else: + logging.warning("Loaded scaler, but no numeric columns found in pruned data to re-scale.") + except Exception as e: + logging.error(f"Failed to load or apply associated scaler: {e}. Scaling might be inconsistent. Exiting.") + sys.exit(1) + else: + logging.error(f"Associated feature scaler not found at {scaler_load_path} for run {load_run_id}. Cannot proceed. 
Exiting.") + sys.exit(1) + # --- End Scaler Loading/Applying --- # + + # Final check: Ensure a GRU model is loaded/trained + if self.gru_model is None: + logging.error("No GRU model is available after train/load step. Exiting.") + sys.exit(1) + + def calibrate_probabilities(self): + """Calibrates GRU output probabilities using temperature scaling.""" + logging.info("--- Stage: Calibrating Probabilities ---") + if self.gru_model is None: + logging.error("GRU model not available for calibration. Exiting.") + sys.exit(1) + if self.X_val_seq is None or self.y_val_seq_dict is None: + logging.error("Validation sequence data not available for calibration. Exiting.") + sys.exit(1) + + # Check if a calibration temp file exists for the loaded GRU model run + loaded_T = None + if self.gru_model_run_id_loaded_from: + temp_filename = f'calibration_temp_{self.gru_model_run_id_loaded_from}.npy' + temp_load_path = os.path.join(self.base_models_dir_path, f'run_{self.gru_model_run_id_loaded_from}', temp_filename) + if os.path.exists(temp_load_path): + try: + loaded_T = np.load(temp_load_path) + logging.info(f"Loaded calibration temperature T={loaded_T:.4f} from GRU run {self.gru_model_run_id_loaded_from}.") + self.optimal_T = float(loaded_T) # Store the loaded value + except Exception as e: + logging.warning(f"Failed to load calibration temp from {temp_load_path}: {e}. Recalculating.") + loaded_T = None # Ensure recalculation happens + else: + logging.info(f"No existing calibration temperature found for run {self.gru_model_run_id_loaded_from} at {temp_load_path}.") + + # If temperature wasn't loaded, calculate it using the validation set + if loaded_T is None: + logging.info("Calculating optimal temperature on validation set...") + # Get predictions on validation set + predictions_val = self.gru_handler.predict(self.X_val_seq) + if predictions_val is None or len(predictions_val) < 3: + logging.error("Failed to get validation predictions for calibration. 
Exiting.") + sys.exit(1) + + # Extract raw probabilities P(dir=up) - assuming it's the 3rd output + p_raw_val = predictions_val[2].flatten() + y_dir_val = self.y_val_seq_dict['dir'] + + # Check for length mismatch + if len(p_raw_val) != len(y_dir_val): + logging.error(f"Mismatch between validation predictions ({len(p_raw_val)}) and targets ({len(y_dir_val)}) for calibration. Exiting.") + sys.exit(1) + + # Optimize temperature using the Calibrator instance + self.optimal_T = self.calibrator.optimise_temperature(p_raw_val, y_dir_val) + + # Save the newly calculated temperature for the *current* run ID + temp_save_path = os.path.join(self.current_run_models_dir, f'calibration_temp_{self.run_id}.npy') + try: + np.save(temp_save_path, self.optimal_T) + logging.info(f"Saved newly calculated calibration temperature T={self.optimal_T:.4f} to {temp_save_path}") + except Exception as e: + logging.error(f"Failed to save calibration temperature: {e}") + + # Store the final temperature in the calibrator instance as well + self.calibrator.optimal_T = self.optimal_T + + # Optional: Generate reliability curve plot for validation set + if self.config.get('control', {}).get('generate_plots', True) and 'p_raw_val' in locals(): + logging.info("Generating validation reliability curve plot...") + results_plot_dir = self.dirs.get('results', None) + if results_plot_dir: + rel_curve_path = os.path.join(results_plot_dir, f'reliability_curve_val_{self.run_id}.png') + try: + # Pass calibrated probabilities using the found optimal_T + p_cal_val = self.calibrator.calibrate(p_raw_val) + y_dir_val = self.y_val_seq_dict['dir'] # Ensure y_dir_val is available + # Pass save_path to the method + self.calibrator.reliability_curve( + p_pred=p_cal_val, + y_true=y_dir_val, + plot_title=f"Reliability Curve (Validation, T={self.optimal_T:.2f})", + save_path=rel_curve_path + ) + except Exception as e: + logging.error(f"Failed to generate validation reliability curve: {e}", exc_info=True) + else: + 
logging.warning("Results directory not found, cannot save reliability curve plot.") + + def train_or_load_sac(self): + """Trains a new SAC agent offline or loads a pre-trained one for backtesting.""" + logging.info("--- Stage: Training or Loading SAC Agent ---") + train_sac_flag = self.config['control'].get('train_sac', False) + + if train_sac_flag: + if self.gru_model_run_id_loaded_from is None: + logging.error("Cannot run SAC training: GRU model run ID is not set (no model trained or loaded). Aborting.") + sys.exit(1) + + logging.info(f"SAC training is enabled. Instantiating SACTrainer...") + # Instantiate SACTrainer, passing necessary base directories from the main pipeline + # Ensure logs/results dirs exist + base_logs = self.dirs.get('logs') + if not base_logs: + base_logs = os.path.join(project_root, 'logs') + os.makedirs(base_logs, exist_ok=True) + logging.warning(f"Using default base logs dir for SACTrainer: {base_logs}") + + base_results = self.dirs.get('results') + if not base_results: + base_results = os.path.join(project_root, 'results') + os.makedirs(base_results, exist_ok=True) + logging.warning(f"Using default base results dir for SACTrainer: {base_results}") + + self.sac_trainer = SACTrainer( + config=self.config, # Pass the full config + base_models_dir=self.base_models_dir_path, + base_logs_dir=base_logs, + base_results_dir=base_results + ) + + # Start the training process + final_agent_path = self.sac_trainer.train(gru_run_id_for_sac=self.gru_model_run_id_loaded_from) + + if final_agent_path: + logger.info(f"SAC training completed. Final agent saved at: {final_agent_path}") + # Set the agent path to the newly trained agent for subsequent backtesting + self.sac_agent_load_path = final_agent_path + else: + logger.error("SAC training failed. Proceeding without a newly trained agent.") + # Decide whether to fallback to loading or abort? Fallback for now. 
+ self.sac_agent_load_path = self._determine_sac_load_path_from_config() + if self.sac_agent_load_path: + logger.warning(f"Falling back to loading SAC agent specified in config: {self.sac_agent_load_path}") + else: + logger.error("SAC training failed and no load path specified in config. Cannot proceed with backtesting.") + # Optionally exit: sys.exit(1) + # For now, allow pipeline to continue, backtester should handle None path + + else: # Load SAC agent based on config for backtesting + logging.info("SAC training is disabled (train_sac=False). Determining agent path to load for backtesting...") + self.sac_agent_load_path = self._determine_sac_load_path_from_config() + if self.sac_agent_load_path: + logger.info(f"SAC agent path for backtesting set to load from: {self.sac_agent_load_path}") + else: + logger.warning("No 'sac_load_run_id' specified in config. Backtester will need to handle using untrained/initial weights.") + + def _determine_sac_load_path_from_config(self) -> str | None: + """Helper to determine the SAC agent load path based on config control flags.""" + load_run_id = self.config['control'].get('sac_load_run_id') + load_step = self.config['control'].get('sac_load_step', 'final') + sac_agent_path = None + if load_run_id: + # Construct path assuming structure like: //agent_.pt + # The sac_train_run_id usually differs from the pipeline run_id + models_base = self.base_models_dir_path # Use the stored base models path + # Assume the SAC trainer saves checkpoints inside its own run folder (e.g., models/sac_train_.../sac_agent_final) + if load_step == 'final': + # SAC trainer saves final model in a folder named 'sac_agent_final' + sac_agent_path = os.path.join(models_base, load_run_id, 'sac_agent_final') + else: + # SAC trainer saves step checkpoints in folder 'sac_agent_step_N' + sac_agent_path = os.path.join(models_base, load_run_id, f'sac_agent_step_{load_step}') + + # Check if the determined path exists + if not os.path.exists(sac_agent_path): + 
logger.warning(f"Determined SAC load path does not exist: {sac_agent_path}. Will proceed without loading specified agent.") + sac_agent_path = None # Reset path if not found + + return sac_agent_path + + def run_backtest(self): + """Runs the backtest using the trained/loaded models and test data.""" + logging.info("--- Stage: Running Backtest ---") + if not self.config['control'].get('run_backtest', False): + logging.warning("Skipping backtest stage as per config.") + return + + # Check if necessary data is available + if self.X_test_seq is None or self.y_test_seq_dict is None or self.test_indices is None: + logger.error("Test sequence data (X, y, indices) not available for backtesting. Skipping.") + return + if self.gru_handler is None or self.calibrator is None: + logger.error("GRU Handler or Calibrator not initialized. Skipping backtest.") + return + if self.optimal_T is None: + logger.warning("Optimal calibration temperature not set. Calibration might be incorrect. Proceeding...") + # Calibrator will default to T=1.0 if its optimal_T is None + + # Extract original prices for plotting (need to align with test sequence indices) + original_prices = None + if self.y_test is not None: + # Assuming y_test index aligns with X_test_raw index before sequencing + # We need the price *at* the time of the target prediction + # test_indices maps to the target time step `i` in the sequence loop + # Need the price from the *original* test split DF aligned with these indices + if 'close' in self.df_test.columns: # df_test has original columns before feature selection + try: + original_prices = self.df_test.loc[self.test_indices, 'close'] + except KeyError: + logger.warning("Could not align original close prices with test indices. Price plot will be limited.") + except Exception as e: + logger.error(f"Error aligning original prices: {e}") + else: + logger.warning("'close' column not found in df_test. 
Cannot extract original prices for plotting.") + + # Run the backtest using the Backtester instance + self.backtest_results_df, self.backtest_metrics = self.backtester.run_backtest( + sac_agent_load_path=self.sac_agent_load_path, + X_test_seq=self.X_test_seq, + y_test_seq_dict=self.y_test_seq_dict, + test_indices=self.test_indices, + gru_handler=self.gru_handler, + calibrator=self.calibrator, + original_prices=original_prices + ) + + if self.backtest_results_df is None or self.backtest_metrics is None: + logger.error("Backtesting failed to produce results.") + else: + logger.info("Backtest completed successfully.") + + def save_results(self): + """Saves backtest results, metrics, and plots using the Backtester instance.""" + logging.info("--- Stage: Saving Results ---") + if self.backtest_results_df is None or self.backtest_metrics is None: + logging.warning("No backtest results available to save. Skipping.") + return + + results_dir = self.dirs.get('results') + if not results_dir: + logger.error("Results directory not configured. Cannot save backtest results.") + return + + # Pass results to the backtester's save method + self.backtester.save_results( + results_df=self.backtest_results_df, + metrics=self.backtest_metrics, + results_dir=results_dir, + run_id=self.run_id + ) + + def execute(self): + """Executes the pipeline stages sequentially.""" + logging.info("=== Starting Pipeline Execution ===") + try: + self.load_and_preprocess_data() + self.engineer_features() + self.define_labels_and_align() + self.split_data() + self.select_and_prune_features() + # Scaling is now handled conditionally within train_or_load_gru if loading + # If training, scaling happens afterwards based on fitted scaler. 
+ # self.scale_features() # Moved + self.train_or_load_gru() # Handles loading/training GRU and associated scaler/rescaling + # If GRU was trained, scale features *now* using the fitted scaler + if self.config['control'].get('train_gru', False): + self.scale_features() + # Need to recreate sequences *after* scaling if GRU was trained + self.create_sequences() + self.run_baseline_checks() # Optional + # self.create_sequences() # Sequence creation now happens conditionally after scaling + self.calibrate_probabilities() + self.train_or_load_sac() + self.run_backtest() # Call the implemented backtest method + self.save_results() # Call the implemented save method + logging.info("=== Pipeline Execution Finished Successfully ===") + + except Exception as e: + logging.error(f"Pipeline execution failed: {e}", exc_info=True) + logging.error("=== Pipeline Execution Terminated Due to Error ===") + sys.exit(1) + +# --- Entry Point --- # + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run the GRU-SAC Trading Pipeline.") + # Default config path relative to the project root/gru_sac_predictor level + # Adjusting default config path seeking strategy + # 1. Relative to project root (../config.yaml from src) + default_config_rel_root = os.path.abspath(os.path.join(script_dir, '..', 'config.yaml')) + # 2. Relative to package dir (../gru_sac_predictor/config.yaml from project root) + default_config_pkg = os.path.join(os.path.dirname(project_root), 'gru_sac_predictor', 'config.yaml') + # 3. 
Relative to CWD + default_config_cwd = os.path.abspath('config.yaml') + + # Determine the best default path + if os.path.exists(default_config_rel_root): + default_config = default_config_rel_root + elif os.path.exists(default_config_pkg): + default_config = default_config_pkg + else: + default_config = default_config_cwd # Fallback to CWD + if not os.path.exists(default_config): + # If none exist, use a placeholder path to show in help message + default_config = '../config.yaml' # Placeholder for help message + + parser.add_argument( + '--config', + type=str, + default=default_config, + help=f"Path to the configuration YAML file (default attempts: {default_config_rel_root}, {default_config_pkg}, {default_config_cwd})" + ) + args = parser.parse_args() + + config_to_use = args.config + # Verify the provided/default config path exists + if not os.path.exists(config_to_use): + print(f"Error: Config file not found at the specified path: {config_to_use}") + print("Please ensure the path is correct or place config.yaml in the expected location.") + sys.exit(1) + + # Instantiate and run the pipeline + pipeline = TradingPipeline(config_path=config_to_use) + pipeline.execute() \ No newline at end of file diff --git a/gru_sac_predictor/src/trading_system.py b/gru_sac_predictor/src/trading_system.py deleted file mode 100644 index b008e85f..00000000 --- a/gru_sac_predictor/src/trading_system.py +++ /dev/null @@ -1,1856 +0,0 @@ -import numpy as np -import pandas as pd -import os -from tqdm import tqdm # Added for progress tracking -import logging # Added for feature functions -from scipy import stats # Added for feature functions -from typing import Tuple -import tensorflow as tf -import warnings -import sys - -# Optional: Import TA-Lib if available (used by V6 features) -try: - import talib - TALIB_AVAILABLE = True - # logging.info("TA-Lib found. Using TA-Lib for V6 features.") # Reduce noise -except ImportError: - TALIB_AVAILABLE = False - logging.warning("TA-Lib not found. 
V6 feature calculation will be limited.") - -# Import other components from the src directory -from .gru_predictor import CryptoGRUModel -from .sac_agent_simplified import SimplifiedSACTradingAgent -# Remove direct import of prepare_features as logic is integrated here -# from .data_pipeline import prepare_features - -# Optional: Matplotlib for plotting -try: - import matplotlib.pyplot as plt -except ImportError: - plt = None - -# --- V7.12 START: Add SAC Training History Plotting Function --- -def plot_sac_training_history(history, save_path=None, run_id=None): - """ - Plots the actor and critic loss from the SAC training history. - - Args: - history (list[dict]): A list of dictionaries, where each dict contains - metrics like 'actor_loss' and 'critic_loss'. - save_path (str, optional): Path to save the plot. If None, tries to show plot. - run_id (str, optional): Run ID to include in the plot title. - """ - if plt is None: - print("Matplotlib not available, skipping SAC training plot.") - return - - if not history: - print("No SAC training history data provided, skipping plot.") - return - - try: - actor_losses = [item['actor_loss'] for item in history if 'actor_loss' in item and item['actor_loss'] is not None] - critic_losses = [item['critic_loss'] for item in history if 'critic_loss' in item and item['critic_loss'] is not None] - steps = range(1, len(history) + 1) # Assuming one entry per training step/batch - - if not actor_losses or not critic_losses: - print("Warning: Could not extract valid actor or critic losses from history. 
Skipping plot.") - return - - # Determine common length if lists differ (shouldn't usually happen if logged together) - min_len = min(len(actor_losses), len(critic_losses)) - if min_len < len(history): - print(f"Warning: Plotting only {min_len} steps due to missing data.") - steps = range(1, min_len + 1) - actor_losses = actor_losses[:min_len] - critic_losses = critic_losses[:min_len] - - fig, axs = plt.subplots(2, 1, figsize=(12, 10), sharex=True) - fig.suptitle(f'SAC Training History {"(Run: " + run_id + ")" if run_id else ""}', fontsize=16) - - # Actor Loss Plot - axs[0].plot(steps, actor_losses, label='Actor Loss', color='tab:blue') - axs[0].set_ylabel('Loss') - axs[0].set_title('Actor Loss over Training Steps') - axs[0].legend() - axs[0].grid(True) - - # Critic Loss Plot - axs[1].plot(steps, critic_losses, label='Critic Loss', color='tab:orange') - axs[1].set_xlabel('Agent Training Step') - axs[1].set_ylabel('Loss') - axs[1].set_title('Critic Loss over Training Steps') - axs[1].legend() - axs[1].grid(True) - - plt.tight_layout(rect=[0, 0.03, 1, 0.95]) # Adjust layout to prevent title overlap - - if save_path: - try: - plt.savefig(save_path) - print(f"SAC training history plot saved to {save_path}") - except Exception as e: - print(f"Error saving SAC training plot: {e}") - else: - try: - print("Displaying SAC training plot...") - plt.show() - except Exception as e: - print(f"Error displaying SAC training plot: {e}") - - except KeyError as e: - print(f"KeyError accessing loss data in history: {e}. 
Skipping plot.") - except Exception as e: - print(f"An unexpected error occurred during SAC plot generation: {e}") - finally: - # Ensure plot is closed to free memory, especially if not shown interactively - if plt: - plt.close(fig) -# --- V7.12 END: Add SAC Training History Plotting Function --- - - -# --- V6 Feature Calculation Logic (Copied & Adapted) --- -# Configure logging for feature calculation functions -feature_logger = logging.getLogger('TradingSystemFeatures') -# Set level to WARNING to avoid excessive logs during backtest -feature_logger.setLevel(logging.WARNING) - -# V7 Update: Add Scalers and sequence creation -from sklearn.preprocessing import MinMaxScaler, StandardScaler -from .data_pipeline import create_sequences_v2 - -# Copied from v6/src/cryptofeatures.py -def calculate_vwap(df: pd.DataFrame, period: int = None) -> pd.Series: - """ - Calculate Volume Weighted Average Price (VWAP). - Adapted from V6. - """ - try: - typical_price = (df['high'] + df['low'] + df['close']) / 3 - tp_volume = typical_price * df['volume'] - volume_sum = df['volume'].rolling(window=period, min_periods=1).sum() - tp_volume_sum = tp_volume.rolling(window=period, min_periods=1).sum() - volume_sum_safe = volume_sum.replace(0, np.nan) - vwap = tp_volume_sum / volume_sum_safe - vwap = vwap.fillna(method='ffill').fillna(method='bfill').fillna(df['close']) # More robust fill - return vwap - except Exception as e: - feature_logger.error(f"Failed to calculate VWAP (period={period}): {e}") - return df['close'] # Fallback - -# Copied from v6/src/cryptofeatures.py and adapted -def add_crypto_features(df: pd.DataFrame) -> pd.DataFrame: - """Add cryptocurrency-specific features. 
Adapted from V6.""" - df_features = df.copy() - try: - # Parkinson Volatility - Add better safeguards - high_safe = df_features['high'].replace(0, np.nan) - low_safe = df_features['low'].replace(0, np.nan).fillna(high_safe * 0.999) # Avoid division by zero - high_to_low = np.clip(high_safe / low_safe, 1.0, 10.0) # Clip extreme ratios - log_hl_sq = np.log(high_to_low)**2 - df_features['parkinson_vol_14'] = np.sqrt((1 / (4 * np.log(2))) * log_hl_sq.rolling(window=14).mean()) - - # Garman-Klass Volatility - Add better safeguards - close_safe = df_features['close'].replace(0, np.nan) - open_safe = df_features['open'].replace(0, np.nan).fillna(close_safe * 0.999) # Avoid division by zero - close_to_open = np.clip(close_safe / open_safe, 0.5, 2.0) # Clip extreme ratios - log_co_sq = np.log(close_to_open)**2 - gk_vol = (0.5 * log_hl_sq - (2 * np.log(2) - 1) * log_co_sq).rolling(window=14).mean() - # Ensure non-negative values before sqrt - gk_vol = np.maximum(gk_vol, 0) - df_features['garman_klass_vol_14'] = np.sqrt(gk_vol) - - # VWAP Features - Improve robustness - for period in [30, 60, 120]: - df_features[f'vwap_{period}'] = calculate_vwap(df_features, period=period) - vwap_safe = df_features[f'vwap_{period}'].replace(0, np.nan).fillna(close_safe) - # Clip the ratio to a reasonable range to avoid extreme values - df_features[f'close_to_vwap_{period}'] = np.clip((close_safe / vwap_safe) - 1, -0.5, 0.5) - - # V7 Fix: Cyclical features are now calculated earlier in main.py - # # Cyclical Features (Simplified) - # if isinstance(df_features.index, pd.DatetimeIndex): - # df_features['hour_sin'] = np.sin(2 * np.pi * df_features.index.hour / 24) - # df_features['hour_cos'] = np.cos(2 * np.pi * df_features.index.hour / 24) - - # Volume Intensity - Improve stability - vol = df_features['volume'].replace(0, np.nan) - vol_mean_30 = vol.rolling(window=30).mean() - vol_mean_30_safe = vol_mean_30.replace(0, np.nan).fillna(vol.mean()) - df_features['vol_intensity'] = np.clip(vol / 
vol_mean_30_safe, 0, 10) # Clip to reasonable range - - # Price Pattern Features - Improved stability - body = abs(close_safe - open_safe) - # Use a meaningful minimum body size to avoid extreme ratios - min_body_size = df_features['close'].mean() * 0.0001 # Small percentage of avg price - body_safe = np.maximum(body, min_body_size) - - # Calculate wicks with better handling of extreme cases - upper_wick = df_features['high'] - df_features[['open', 'close']].max(axis=1) - lower_wick = df_features[['open', 'close']].min(axis=1) - df_features['low'] - # Clip the ratios to reasonable ranges - df_features['upper_wick_ratio'] = np.clip(upper_wick / body_safe, 0, 5) - df_features['lower_wick_ratio'] = np.clip(lower_wick / body_safe, 0, 5) - - # Fill NaNs introduced here - cols_to_fill = ['parkinson_vol_14', 'garman_klass_vol_14', 'vol_intensity', - 'upper_wick_ratio', 'lower_wick_ratio'] # Removed hour_sin/cos - cols_to_fill.extend([f'close_to_vwap_{p}' for p in [30, 60, 120]]) - for col in cols_to_fill: - if col in df_features.columns: - # Fill NaNs with median values rather than zeros to maintain distribution - median_val = df_features[col].median() - df_features[col].fillna(median_val if not np.isnan(median_val) else 0, inplace=True) - - except Exception as e: - feature_logger.error(f"Error calculating crypto-specific features: {e}", exc_info=True) - return df_features - -# Adapted from v6/src/data_preprocessing.py::calculate_technical_indicators -def calculate_v6_features(df: pd.DataFrame) -> pd.DataFrame: - """Calculates V6 technical indicators + basic return features.""" - df_features = df.copy() - required_cols = ['open', 'high', 'low', 'close', 'volume'] - if not all(col in df_features.columns for col in required_cols): - missing = [col for col in required_cols if col not in df_features.columns] - feature_logger.error(f"Missing required V6 columns: {missing}") - return df_features - for col in required_cols: - df_features[col] = pd.to_numeric(df_features[col], 
errors='coerce') - if df_features[col].isnull().any(): - df_features[col] = df_features[col].ffill().bfill() - if df_features[col].isnull().any(): df_features[col] = df_features[col].fillna(0) - - # --- V7.2 Add Past Return Features --- - for lag in [1, 5, 15, 60]: # 1m, 5m, 15m, 1h returns - # Use pct_change for robustness to price levels - df_features[f'return_{lag}m'] = df_features['close'].pct_change(periods=lag) - # --- End Return Features --- - - if TALIB_AVAILABLE: - try: - close = df_features['close'].values; open_price = df_features['open'].values - high = df_features['high'].values; low = df_features['low'].values; volume = df_features['volume'].values - for period in [5, 10, 20, 30, 50, 100]: - df_features[f'SMA_{period}'] = talib.SMA(close, timeperiod=period) - df_features[f'EMA_{period}'] = talib.EMA(close, timeperiod=period) - macd, macdsignal, macdhist = talib.MACD(close, fastperiod=12, slowperiod=26, signalperiod=9) - df_features['MACD'] = macd; df_features['MACD_signal'] = macdsignal; df_features['MACD_hist'] = macdhist - df_features['SAR'] = talib.SAR(high, low, acceleration=0.02, maximum=0.2) - df_features['ADX_14'] = talib.ADX(high, low, close, timeperiod=14) - for period in [9, 14, 21]: df_features[f'RSI_{period}'] = talib.RSI(close, timeperiod=period) - slowk, slowd = talib.STOCH(high, low, close, fastk_period=14, slowk_period=3, slowk_matype=0, slowd_period=3, slowd_matype=0) - df_features['STOCH_K'] = slowk; df_features['STOCH_D'] = slowd - df_features['WILLR_14'] = talib.WILLR(high, low, close, timeperiod=14) - for period in [5, 10, 20]: df_features[f'ROC_{period}'] = talib.ROC(close, timeperiod=period) - df_features['CCI_14'] = talib.CCI(high, low, close, timeperiod=14) - upper20, middle20, lower20 = talib.BBANDS(close, timeperiod=20, nbdevup=2, nbdevdn=2, matype=0) - df_features['BB_upper_20'] = upper20; df_features['BB_middle_20'] = middle20; df_features['BB_lower_20'] = lower20 - middle20_safe = np.where(middle20 == 0, np.nan, middle20) 
- bb_width = (upper20 - lower20) / middle20_safe - df_features['BB_width_20'] = np.nan_to_num(bb_width, nan=0.0) - for period in [7, 14, 21]: df_features[f'ATR_{period}'] = talib.ATR(high, low, close, timeperiod=period) - df_features['OBV'] = talib.OBV(close, volume) - df_features['CMF_20'] = talib.ADOSC(high, low, close, volume, fastperiod=3, slowperiod=10) - vol_sma_20 = talib.SMA(volume, timeperiod=20) - vol_sma_20_safe = np.where(vol_sma_20 == 0, np.nan, vol_sma_20) - df_features['volume_SMA_20'] = vol_sma_20 - df_features['volume_ratio'] = np.nan_to_num(volume / vol_sma_20_safe, nan=1.0) - df_features['DOJI'] = talib.CDLDOJI(open_price, high, low, close) / 100 - df_features['ENGULFING'] = talib.CDLENGULFING(open_price, high, low, close) / 100 - df_features['HAMMER'] = talib.CDLHAMMER(open_price, high, low, close) / 100 - df_features['SHOOTING_STAR'] = talib.CDLSHOOTINGSTAR(open_price, high, low, close) / 100 - except Exception as e: feature_logger.error(f"Error calculating TA-Lib indicators: {e}", exc_info=True) - - if not TALIB_AVAILABLE: - feature_logger.warning("Calculating subset of features manually (TA-Lib unavailable)") - for period in [5, 10, 20]: - df_features[f'SMA_{period}'] = df_features['close'].rolling(window=period).mean() - df_features[f'EMA_{period}'] = df_features['close'].ewm(span=period, adjust=False).mean() - # df_features[f'ROC_{period}'] = df_features['close'].pct_change(periods=period) * 100 # Replaced by return_xm - delta = df_features['close'].diff(); gain = (delta.where(delta > 0, 0)).rolling(window=14).mean() - loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean().replace(0, 1e-10); rs = gain / loss - df_features['RSI_14'] = 100 - (100 / (1 + rs)) - - # Derived Features (always calculated) - for period in [14, 30]: - roll_max = df_features['close'].rolling(window=period).max() - roll_min = df_features['close'].rolling(window=period).min() - # Ensure the range isn't too small to avoid extreme ratios - min_range = 
df_features['close'].mean() * 0.001 # 0.1% of mean price as minimum range - denominator = np.maximum(roll_max - roll_min, min_range) - # Relative position is a ratio - apply Fisher transform to make more normal - rel_pos = np.clip((df_features['close'] - roll_min) / denominator, 0.01, 0.99) - # Apply logit transform (inverse of sigmoid) to normalize the distribution - df_features[f'rel_position_{period}'] = np.log(rel_pos / (1 - rel_pos)) - - for period in [20, 50]: - if f'SMA_{period}' in df_features.columns: - sma = df_features[f'SMA_{period}'] - close = df_features['close'] - # Apply log transform to price-to-SMA ratio to make distribution more normal - ratio = np.clip(close / np.maximum(sma, close.mean() * 0.01), 0.5, 2.0) - df_features[f'price_dist_SMA_{period}'] = np.log(ratio) - - # Intraday return - with log transform for better distribution - open_prices = df_features['open'].replace(0, np.nan) - close_prices = df_features['close'] - # Use a reasonable minimum price to avoid extreme ratios - min_open = close_prices.mean() * 0.01 # 1% of mean close - open_safe = np.maximum(open_prices.fillna(close_prices), min_open) - # Apply log to price ratio for better distribution - ratio = np.clip(close_prices / open_safe, 0.5, 2.0) - df_features['intraday_return'] = np.log(ratio) - - # Log return with improved handling - prev_close = df_features['close'].shift(1).replace(0, np.nan) - current_close = df_features['close'] - # Avoid division by very small values - ratio = np.clip(current_close / np.maximum(prev_close.fillna(current_close), current_close.mean() * 0.01), 0.5, 2.0) - df_features['log_return'] = np.log(ratio) - - # Volatility with robust handling - apply log transform to make more normal - log_return_clipped = np.clip(df_features['log_return'], -0.2, 0.2) - vol = log_return_clipped.rolling(window=14).std().fillna(0) - # Log transform volatility for more normal distribution - df_features['volatility_14d'] = np.log1p(vol * 100) # log1p avoids issues with zero 
values - - # Add Crypto Features - df_features = add_crypto_features(df_features) - - # Final NaN fill (important after all calculations) - # Fill with 0 as done in V6 for most features - cols_to_fill = df_features.columns.difference(required_cols) - for col in cols_to_fill: - if df_features[col].isnull().any(): - df_features[col].fillna(0, inplace=True) - - # Drop original OHLCV + volume columns? No, keep them for potential analysis/state - # df_features = df_features.drop(columns=required_cols) - - return df_features -# --- End V6 Feature Logic --- - - -class TradingSystem: - """ - V7 Trading System: Integrates GRU Predictor and SAC Agent. - Generates features, uses GRU for price/uncertainty prediction, - and SAC for trading decisions. - """ - - # V7.2 Add system logger - _logger = logging.getLogger(__name__) - if not _logger.hasHandlers(): - _logger.setLevel(logging.INFO) - handler = logging.StreamHandler(sys.stdout) - formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') - handler.setFormatter(formatter) - _logger.addHandler(handler) - _logger.propagate = False - - def __init__(self, gru_model: CryptoGRUModel = None, sac_agent: SimplifiedSACTradingAgent = None, gru_lookback=60): - """ - Initialize the TradingSystem. - - Args: - gru_model (CryptoGRUModel, optional): Pre-initialized GRU model. Defaults to None. - sac_agent (SACTradingAgent, optional): Pre-initialized SAC agent. Defaults to None. - If None, a default agent with state_dim=2 is created. - gru_lookback (int): Number of time steps for GRU input sequence. 
- """ - self._logger.info("Initializing V7 Trading System...") - self.gru_model = gru_model or CryptoGRUModel() # Use provided or create new - self.sac_agent = sac_agent or SimplifiedSACTradingAgent(state_dim=2) # V7 state: [pred_return, uncertainty] - self.gru_lookback = gru_lookback - self.feature_scaler = None - self.y_scaler = None - self.last_prediction = None - self.experiences = [] # Store (state, action, reward, next_state, done) - # V7.23: Add storage for scaled feature indices - self._momentum_feature_idx = None - self._volatility_feature_idx = None - # Initialize scalers from GRU model if available and loaded - if self.gru_model and self.gru_model.is_loaded: - self.feature_scaler = self.gru_model.feature_scaler - self.y_scaler = self.gru_model.y_scaler - _logger.info("Scalers initialized from pre-loaded GRU model.") - # V7.23: Attempt to set indices if scaler loaded - self._set_feature_indices() - - # V7.23: Helper method to find and store feature indices - def _set_feature_indices(self): - """Finds and stores the column indices for momentum and volatility features.""" - if self.feature_scaler is None: - self._logger.warning("Feature scaler not available. Cannot set feature indices.") - self._momentum_feature_idx = None - self._volatility_feature_idx = None - return False - - if not hasattr(self.feature_scaler, 'feature_names_in_'): - # Fallback for older sklearn versions or different scaler types - if hasattr(self.feature_scaler, 'n_features_in_'): - self._logger.warning("Feature scaler lacks 'feature_names_in_'. Cannot reliably find named features. State scaling may fail.") - # Cannot set indices reliably without names - self._momentum_feature_idx = None - self._volatility_feature_idx = None - return False - else: - self._logger.error("Feature scaler lacks both 'feature_names_in_' and 'n_features_in_'. 
Cannot determine feature indices.") - self._momentum_feature_idx = None - self._volatility_feature_idx = None - return False - - feature_columns = list(self.feature_scaler.feature_names_in_) - try: - # Use the correct feature names: 'return_5m' and 'volatility_14d' - self._momentum_feature_idx = feature_columns.index('return_5m') - self._volatility_feature_idx = feature_columns.index('volatility_14d') - self._logger.info(f"Successfully set feature indices: momentum_5m={self._momentum_feature_idx}, volatility_14d={self._volatility_feature_idx}") - return True - except ValueError as e: - self._logger.error(f"Could not find 'return_5m' or 'volatility_14d' in scaler's feature columns: {feature_columns}. Error: {e}. State construction will likely fail.") - self._momentum_feature_idx = None - self._volatility_feature_idx = None - return False - - # V7-V6 Update: Extract features and predict return/uncertainty for SAC state - def _extract_features_and_predict(self, data_df_full, current_idx): - """ - Extracts feature sequence and uses the trained GRU model to predict - the next period's return and associated confidence/uncertainty. - - Args: - data_df_full: Full DataFrame with OHLCV and potentially pre-calculated features. - current_idx: The current index in the DataFrame. - - Returns: - Tuple(np.array, float, float) or (None, None, None): - - features_sequence_scaled: Scaled NumPy array (sequence_length, num_features) - - predicted_return: GRU predicted return for the next period. - - uncertainty_sigma: Uncertainty estimate for the next period. 
- """ - # --- Step 1: Extract Raw Feature Sequence --- - # Determine the slice needed for feature calculation (V6 style) - # Needs enough history for longest V6 feature calc + sequence length - # Assume V6 features require ~100 periods + sequence length - required_feature_history = 100 + self.gru_lookback # Conservative estimate - if current_idx < required_feature_history - 1: - # feature_logger.debug(f"Not enough history at idx {current_idx} for feature calc + lookback ({required_feature_history})") - return None, None, None - - start_idx_calc = max(0, current_idx - required_feature_history + 1) - end_idx_calc = current_idx + 1 # Include current index - data_slice_for_calc = data_df_full.iloc[start_idx_calc:end_idx_calc] - - # --- Step 2: Calculate All V6 Features --- - try: - df_with_features = calculate_v6_features(data_slice_for_calc) - if df_with_features.empty or len(df_with_features) < self.gru_lookback: - feature_logger.warning(f"Feature calculation failed or insufficient length at idx {current_idx}") - return None, None, None - except Exception as e: - feature_logger.error(f"Error calculating V6 features at idx {current_idx}: {e}", exc_info=False) - return None, None, None - - # --- Step 3: Select Final Sequence & Scale --- - # The sequence ends at the *current* index (current_idx) - feature_sequence_unscaled_df = df_with_features.iloc[-self.gru_lookback:] - - # Check if feature scaler exists (should after GRU training/loading) - if self.gru_model is None or self.feature_scaler is None: - feature_logger.error(f"GRU model or feature scaler not available at idx {current_idx}. 
Cannot scale features.") - return None, None, None - - try: - # Ensure columns match scaler's expectations - expected_features = self.feature_scaler.feature_names_in_ - # Select and reorder columns - feature_sequence_aligned_df = feature_sequence_unscaled_df[expected_features] - - # Scale the sequence - Pass the DataFrame directly - # V7 Warning Fix: Pass DataFrame to transform, not .values - features_sequence_scaled_array = self.feature_scaler.transform(feature_sequence_aligned_df) - - # Ensure output is numpy array for reshaping - if isinstance(features_sequence_scaled_array, pd.DataFrame): - features_sequence_scaled = features_sequence_scaled_array.values - else: - features_sequence_scaled = features_sequence_scaled_array # Assume it's already numpy - - if features_sequence_scaled.shape[0] != self.gru_lookback: - feature_logger.error(f"Scaled sequence shape incorrect at idx {current_idx}: {features_sequence_scaled.shape}") - return None, None, None - except KeyError as e: - feature_logger.error(f"Missing expected feature columns for scaler at idx {current_idx}: {e}") - feature_logger.error(f" Available columns: {feature_sequence_unscaled_df.columns.tolist()}") - feature_logger.error(f" Expected columns: {expected_features}") - return None, None, None - except Exception as e: - feature_logger.error(f"Error scaling feature sequence at idx {current_idx}: {e}", exc_info=False) - return None, None, None - - # --- Step 4: Predict Return and Uncertainty using GRU --- - # Perform efficient single-step prediction and MC dropout here - try: - # Reshape for model prediction (1, seq_len, n_features) - model_input = features_sequence_scaled.reshape(1, self.gru_lookback, features_sequence_scaled.shape[1]) - - # 4a. Get standard prediction (scaled price) - pred_scaled = self.gru_model.model.predict(model_input, verbose=0)[0, 0] - - # 4b. 
Perform MC Dropout for uncertainty - n_mc_samples = 30 # Use a reasonable number of samples - mc_preds_scaled_list = [] - - # Define the prediction step with training=True - @tf.function - def mc_predict_step_single(batch): - return self.gru_model.model(batch, training=True) # Enable dropout - - for _ in range(n_mc_samples): - mc_pred_scaled = mc_predict_step_single(tf.constant(model_input, dtype=tf.float32)).numpy() - mc_preds_scaled_list.append(mc_pred_scaled[0, 0]) # Get the scalar prediction - - mc_preds_scaled_array = np.array(mc_preds_scaled_list) - mc_std_scaled = np.std(mc_preds_scaled_array) - - # 4c. Unscale prediction and uncertainty - if self.y_scaler is None: - feature_logger.error(f"Y-scaler not available at idx {current_idx}, cannot unscale prediction/uncertainty.") - return None, None, None - - pred_unscaled = self.y_scaler.inverse_transform([[pred_scaled]])[0, 0] - mc_unscaled_std_dev = 0.0 # Default - try: - # Use scaler's data range - if hasattr(self.y_scaler, 'data_min_') and hasattr(self.y_scaler, 'data_max_'): - data_range = self.y_scaler.data_max_[0] - self.y_scaler.data_min_[0] - if data_range > 1e-9: - mc_unscaled_std_dev = mc_std_scaled * data_range - else: feature_logger.warning("Scaler data range is near zero for unscaling std dev.") - else: feature_logger.warning("y_scaler missing data_min_/data_max_ for unscaling std dev.") - except Exception as e: - feature_logger.error(f"Error unscaling MC std dev at idx {current_idx}: {e}") - - # 4d. 
Calculate predicted return - # Get the *actual* close price at the *current* index to calculate return - # Use the unscaled features dataframe for this - start_price = feature_sequence_unscaled_df.iloc[-1]['close'] - epsilon = 1e-9 - if abs(start_price) > epsilon: - predicted_return = (pred_unscaled / start_price) - 1.0 - else: - predicted_return = 0.0 - feature_logger.warning(f"Start price is near zero at idx {current_idx}, cannot calculate return.") - - # Return the SCALED features, predicted return, and UNCERTAINTY SIGMA - return features_sequence_scaled, predicted_return, mc_unscaled_std_dev - - except Exception as e: - feature_logger.error(f"Error during GRU prediction/evaluation step at idx {current_idx}: {e}", exc_info=False) - return None, None, None - - # V7-V6 Update: Rename and modify to reflect price target - def _preprocess_data_for_gru_training(self, df_data: pd.DataFrame, prediction_horizon: int) -> Tuple[pd.DataFrame, pd.Series, pd.Series]: - """ - Preprocess data specifically for GRU model training (V6 Style - Price Target). - Calculates V6 features, defines future price target, and aligns data. - Scaling and sequence creation will happen *within* train_gru. - - Args: - df_data: Raw OHLCV DataFrame - prediction_horizon: How many steps ahead to predict the price. - - Returns: - Tuple of (features DataFrame, future_price_target Series, start_price Series) - Returns (None, None, None) if processing fails. 
- """ - if df_data is None or df_data.empty: - logging.error("Empty/None dataframe provided to _preprocess_data_for_gru_training") - return None, None, None - - logging.info(f"Preprocessing data for GRU training (Target: Future Price, Horizon: {prediction_horizon})...") - - # --- Calculate V6 Features --- - logging.debug(f"Calculating V6 features for data shape {df_data.shape}") - try: - features_df = calculate_v6_features(df_data.copy()) # Use a copy - if features_df is None or features_df.empty: - logging.error("Feature calculation returned empty DataFrame.") - return None, None, None - except Exception as e: - logging.error(f"Error calculating V6 features: {e}", exc_info=True) - return None, None, None - - # --- Define Target (Future Price) & Start Price --- - try: - # Target is the closing price `prediction_horizon` steps into the future - future_price_target_ser = features_df['close'].shift(-prediction_horizon) - - # Start price is the current closing price (aligned with features) - start_price_ser = features_df['close'] - - # --- Align Features, Target, and Start Price --- - # Drop rows where the future target is NaN (typically at the end) - common_index = future_price_target_ser.dropna().index - - # Ensure indices exist in all structures before slicing - common_index = common_index.intersection(features_df.index).intersection(start_price_ser.index) - - if common_index.empty: - logging.error("No common index found after aligning features and future price target.") - return None, None, None - - features_aligned_df = features_df.loc[common_index] - target_aligned_ser = future_price_target_ser.loc[common_index] - start_price_aligned_ser = start_price_ser.loc[common_index] - - except KeyError as e: - logging.error(f"Missing 'close' column for target/start price calculation: {e}") - return None, None, None - except Exception as e: - logging.error(f"Error defining/aligning target and start price: {e}", exc_info=True) - return None, None, None - - # Final check for 
NaN values introduced during alignment or feature calc - if features_aligned_df.isna().any().any(): - nan_cols = features_aligned_df.columns[features_aligned_df.isna().any()].tolist() - logging.warning(f"NaN values detected in final aligned feature columns: {nan_cols}. Dropping rows.") - features_aligned_df = features_aligned_df.dropna() - # Realign target and start price - common_index = features_aligned_df.index - target_aligned_ser = target_aligned_ser.loc[common_index] - start_price_aligned_ser = start_price_aligned_ser.loc[common_index] - - # --- Final Shape Check --- - if not (len(features_aligned_df) == len(target_aligned_ser) == len(start_price_aligned_ser)): - logging.error(f"Final shape mismatch after NaN handling: Features={len(features_aligned_df)}, Target={len(target_aligned_ser)}, StartPrice={len(start_price_aligned_ser)}") - return None, None, None - if features_aligned_df.empty: - logging.error("Preprocessing resulted in empty features DataFrame after alignment.") - return None, None, None - - logging.info(f"Preprocessing complete. Features: {features_aligned_df.shape}, Target: {target_aligned_ser.shape}, Start Price: {start_price_aligned_ser.shape}") - return features_aligned_df, target_aligned_ser, start_price_aligned_ser - - def generate_trading_experiences(self, val_data, transaction_cost=0.00001, prediction_horizon=1): - """ - V7-V6 Update: Use predicted return and uncertainty sigma for SAC state. - V7 Efficiency Update: Pre-compute GRU predictions/uncertainty for the whole validation set. 
- """ - print(f"Generating trading experiences for SAC agent on {len(val_data)} validation data points...") - - # --- Pre-computation Step --- - if self.gru_model is None or not (self.gru_model.is_trained or self.gru_model.is_loaded): - logging.error("Cannot generate experiences: GRU model not ready.") - return [] - if self.feature_scaler is None or self.y_scaler is None: - logging.error("Cannot generate experiences: Scalers not loaded/trained.") - return [] - - logging.info("Preprocessing validation data for experience generation...") - val_features_df, val_target_price_ser, val_start_price_ser = self._preprocess_data_for_gru_training(val_data, prediction_horizon) - if val_features_df is None: - logging.error("Failed to preprocess validation data for experiences.") - return [] - - logging.info("Scaling validation features and targets...") - try: - val_features_scaled = self.feature_scaler.transform(val_features_df.select_dtypes(include=np.number).fillna(0)) - val_target_scaled = self.y_scaler.transform(val_target_price_ser.fillna(0).values.reshape(-1, 1)) - except Exception as e: - logging.error(f"Error scaling validation data for experiences: {e}", exc_info=True) - return [] - - logging.info("Creating validation sequences...") - try: - val_start_price_aligned_ser = val_start_price_ser.loc[val_features_df.index] # Align before sequencing - X_val, y_val_scaled_seq, y_start_price_val_seq = create_sequences_v2( - pd.DataFrame(val_features_scaled, index=val_features_df.index), - pd.Series(val_target_scaled.flatten(), index=val_features_df.index), - val_start_price_aligned_ser, - self.gru_lookback - ) - if X_val is None or y_val_scaled_seq is None or y_start_price_val_seq is None: - logging.error("Sequence creation failed for validation data.") - return [] - except Exception as e: - logging.error(f"Error creating validation sequences: {e}", exc_info=True) - return [] - - logging.info(f"Pre-computing GRU predictions/uncertainty for {len(X_val)} validation sequences...") 
- try: - eval_results = self.gru_model.evaluate(X_val, y_val_scaled_seq, y_start_price_val_seq, n_mc_samples=30) - if eval_results is None: - logging.error("GRU evaluate failed on validation sequences.") - return [] - all_pred_returns = eval_results['pred_percent_change'] - all_uncertainties = eval_results['mc_unscaled_std_dev'] - except Exception as e: - logging.error(f"Error during pre-computation evaluate call: {e}", exc_info=True) - return [] - - # --- V7.15 START: Extract Momentum and Volatility for Experience Generation --- - # Similar to backtest logic, extract features aligned with sequences - num_sequences = len(all_pred_returns) # Use num_sequences from evaluate results - all_momentum_5 = np.zeros(num_sequences) - all_volatility_20 = np.zeros(num_sequences) - try: - # Check required columns in the *unscaled* preprocessed features df - required_state_cols = ['return_5m', 'volatility_14d'] - if not all(col in val_features_df.columns for col in required_state_cols): - missing_cols = [col for col in required_state_cols if col not in val_features_df.columns] - logging.error(f"Missing required state columns in val_features_df: {missing_cols}") - return [] - - # Align features with the sequences - if len(val_features_df) >= self.gru_lookback - 1 + num_sequences: - aligned_feature_indices = val_features_df.index[self.gru_lookback - 1 : self.gru_lookback - 1 + num_sequences] - aligned_features_for_state = val_features_df.loc[aligned_feature_indices] - - all_momentum_5 = aligned_features_for_state['return_5m'].values - all_volatility_20 = aligned_features_for_state['volatility_14d'].values - - if len(all_momentum_5) != num_sequences or len(all_volatility_20) != num_sequences: - logging.error(f"Exp Gen: Length mismatch extracting state features: Mom5({len(all_momentum_5)}), Vol20({len(all_volatility_20)}) vs NumSeq({num_sequences})") - return [] - else: - logging.error("Exp Gen: Length mismatch: val_features_df too short to extract aligned state features.") - return 
[] - except KeyError as e: - logging.error(f"Exp Gen: KeyError extracting state features: {e}") - return [] - except Exception as e: - logging.error(f"Exp Gen: Error extracting state features: {e}", exc_info=True) - return [] - # --- V7.15 END: Extract Momentum and Volatility --- - - logging.info("GRU pre-computation finished. Generating experiences loop...") - # --- End Pre-computation --- - - experiences = [] - current_position = 0.0 # Position *before* taking action at step i - - # Loop through the *results* (length = num_sequences) - # num_sequences = len(all_pred_returns) # Already defined - if num_sequences <= 1: - logging.warning("Not enough sequences generated from validation data to create experiences.") - return [] - - # Need original close prices aligned with the sequences for reward calculation - # The i-th sequence corresponds to the original data ending at index i + gru_lookback - 1 - # The prediction is for the step *after* this sequence - # The relevant close prices for reward are at original indices [i + gru_lookback] and [i + gru_lookback + 1] - original_close_prices = val_data['close'] # Use original validation data - - for i in tqdm(range(num_sequences - 1), desc="Generating Experiences"): # Iterate up to second-to-last sequence result - # --- V7.15 START: Construct 5D state s_t --- - # V7.23 Get SCALED momentum/volatility for state t - pred_return_t = all_pred_returns[i] - uncertainty_t = all_uncertainties[i] - momentum_5_t_scaled = all_momentum_5[i] # Use scaled value - volatility_20_t_scaled = all_volatility_20[i] # Use scaled value - - # Calculate z_proxy using position *before* action (current_position) - # V7.23: Use SCALED volatility for consistency within state. 
- z_proxy_t = current_position * volatility_20_t_scaled - state = np.array([pred_return_t, uncertainty_t, z_proxy_t, momentum_5_t_scaled, volatility_20_t_scaled], dtype=np.float32) - - # Handle potential NaNs/Infs - if np.any(np.isnan(state)) or np.any(np.isinf(state)): - logging.warning(f"NaN/Inf in state at step {i}. Replacing with 0. State: {state}") - state = np.nan_to_num(state, nan=0.0, posinf=0.0, neginf=0.0) - # --- V7.15 END: Construct 5D state s_t --- - - # Get action a_t (unpack action and log_prob) - action_tuple = self.sac_agent.get_action(state, deterministic=False) - if isinstance(action_tuple, (tuple, list)) and len(action_tuple) > 0: - action = action_tuple[0] # Action is the first element - # Ensure action is scalar if action_dim is 1 - if self.sac_agent.action_dim == 1 and isinstance(action, (np.ndarray, list)): - action = action[0] - else: - logging.error(f"SAC agent get_action did not return expected tuple at step {i}. Got: {action_tuple}. Skipping experience.") - continue # Skip this experience if action format is wrong - - # --- V7.15 START: Construct 5D next_state s_{t+1} --- - # V7.23 Get SCALED momentum/volatility for state t+1 - pred_return_t1 = all_pred_returns[i+1] - uncertainty_t1 = all_uncertainties[i+1] - momentum_5_t1_scaled = all_momentum_5[i+1] # Use scaled value - volatility_20_t1_scaled = all_volatility_20[i+1] # Use scaled value - - # Calculate z_proxy using the *action* taken (action is position for next step) - # V7.23: Use SCALED volatility for consistency - z_proxy_t1 = action * volatility_20_t1_scaled # Use action, not current_position - next_state = np.array([pred_return_t1, uncertainty_t1, z_proxy_t1, momentum_5_t1_scaled, volatility_20_t1_scaled], dtype=np.float32) - - # Handle potential NaNs/Infs - if np.any(np.isnan(next_state)) or np.any(np.isinf(next_state)): - logging.warning(f"NaN/Inf in next_state at step {i}. Replacing with 0. 
State: {next_state}") - next_state = np.nan_to_num(next_state, nan=0.0, posinf=0.0, neginf=0.0) - # --- V7.15 END: Construct 5D next_state s_{t+1} --- - - # Calculate actual return for reward r_t - # Map sequence index 'i' back to the original dataframe index - original_idx_t = i + self.gru_lookback # Index of the *last* element of the sequence i - original_idx_t_plus_1 = original_idx_t + 1 # Index for the next closing price - - try: - # Use original_close_prices Series which should have the same index as val_data - close_t = original_close_prices.iloc[original_idx_t] - close_t1 = original_close_prices.iloc[original_idx_t_plus_1] - - if close_t != 0: - actual_return = (close_t1 / close_t) - 1.0 - else: - actual_return = 0.0 # Assign 0 return if start price is zero - except IndexError: - logging.warning(f"IndexError accessing original close prices for reward at step {i} (original indices {original_idx_t}, {original_idx_t_plus_1}). Skipping experience.") - continue - except Exception as e: - logging.error(f"Error accessing original close prices for reward at step {i}: {e}") - continue - - # Reward calculation (unchanged, uses action) - reward = action * actual_return - transaction_cost * abs(action - current_position) - - # Done flag (only True for the very last possible transition) - done = (i == num_sequences - 2) - # Store the experience (state, action, reward, next_state, done) - experiences.append((state, [action], reward, next_state, float(done))) # Ensure action is in a list for consistency if needed by buffer - - # Update current_position for the *next* iteration's z_proxy_t calculation - current_position = action - - print(f"Generated {len(experiences)} experiences efficiently.") - return experiences - - def train_sac(self, val_data, epochs=100, batch_size=256, transaction_cost=0.00001, - generate_new_experiences_on_epoch=False, # Default is now False - prediction_horizon=1): # Need prediction horizon for experience gen - """ - Train SAC agent using 
experiences generated from historical data. - (Code largely unchanged, relies on updated generate_trading_experiences) - """ - print(f"Starting SAC training for {epochs} epochs...") - if not generate_new_experiences_on_epoch: - print("Generating initial experiences..."); experiences = self.generate_trading_experiences(val_data, transaction_cost) - print(f"Adding {len(experiences)} experiences to replay buffer...") - for state, action, reward, next_state, done in tqdm(experiences, desc="Filling Buffer"): - action_to_add = action[0] if isinstance(action, (list, np.ndarray)) and len(action) == 1 else action - # V7.16 Fix: Use agent's store_transition method - # self.sac_agent.buffer.add(state, action_to_add, reward, next_state, done) - self.sac_agent.store_transition(state, action_to_add, reward, next_state, done) - print("Initial buffer fill complete.") - metrics_history = [] - for epoch in tqdm(range(epochs), desc="SAC Training Epochs"): - if generate_new_experiences_on_epoch: - experiences = self.generate_trading_experiences(val_data, transaction_cost, prediction_horizon) # Pass prediction_horizon - for state, action, reward, next_state, done in experiences: - action_to_add = action[0] if isinstance(action, (list, np.ndarray)) and len(action) == 1 else action - self.sac_agent.buffer.add(state, action_to_add, reward, next_state, done) - # V7.2 Revert: Use SAC agent default min buffer size - # V7.17 Fix: Check buffer size using agent's buffer_counter - # if len(self.sac_agent.buffer) >= self.sac_agent.min_buffer_size if hasattr(self.sac_agent, 'min_buffer_size') else batch_size: - if self.sac_agent.buffer_counter >= (self.sac_agent.min_buffer_size if hasattr(self.sac_agent, 'min_buffer_size') else batch_size): - # V7.17 Fix: train returns tuple (actor_loss, critic_loss) or None - loss_tuple = self.sac_agent.train(batch_size) - if loss_tuple: - actor_loss, critic_loss = loss_tuple - metrics = {'actor_loss': actor_loss, 'critic_loss': critic_loss} - 
metrics_history.append(metrics) - # Print metrics every 10 epochs instead of 100 - if metrics and epoch % 10 == 0: print(f"Epoch {epoch}/{epochs}: {metrics}") - else: - # Handle case where train() returns None (e.g., not enough data yet) - pass # Or log a warning if needed - elif epoch == 0: - min_buff = self.sac_agent.min_buffer_size if hasattr(self.sac_agent, 'min_buffer_size') else batch_size - # V7.16 Fix: Use buffer_counter in log message - print(f"Epoch {epoch}: Buffer size ({self.sac_agent.buffer_counter}) < min ({min_buff}). Skipping.") - print("SAC training complete."); return metrics_history - - def backtest_simple(self, data, transaction_cost=0.00001): - """ - V7-V6 Update: Use predicted return and confidence for SAC state. - """ - print(f"Starting simple backtest on {len(data)} data points...") - portfolio_value = 1.0; current_position = 0.0 - portfolio_values = [portfolio_value]; positions_history = [current_position] - predictions = []; uncertainties = [] # Store pred_return and confidence - - for i in tqdm(range(len(data) - 1), desc="Simple Backtest"): - # V7-V6 Update: Use _extract_features_and_predict - _, pred_return, uncertainty_sigma = self._extract_features_and_predict(data, i) - # V7-V6 Sigma Update: Check uncertainty_sigma - if pred_return is None or uncertainty_sigma is None: - continue - - predictions.append(pred_return) - # V7-V6 Sigma Update: Store uncertainty_sigma - uncertainties.append(uncertainty_sigma) - - # State for SAC: [predicted_return, mc_unscaled_std_dev] - # V7-V6 Sigma Update: Use uncertainty_sigma - state = np.array([pred_return, uncertainty_sigma], dtype=np.float32) - - action = self.sac_agent.get_action(state, deterministic=True)[0] - - tx_cost_fraction = transaction_cost * abs(action - current_position) - - # Use actual close prices for backtest return calc - if data.iloc[i]['close'] != 0: - price_return = (data.iloc[i+1]['close'] / data.iloc[i]['close']) - 1.0 - else: - price_return = 0 - - position_return = 
current_position * price_return - portfolio_value *= (1.0 + position_return) * (1.0 - tx_cost_fraction) - positions_history.append(action); portfolio_values.append(portfolio_value) - current_position = action - - # Calculate metrics - portfolio_values = np.array(portfolio_values) - returns = np.diff(portfolio_values) / portfolio_values[:-1] # Calculate portfolio returns - returns = np.nan_to_num(returns) - - sharpe_ratio = 0.0 - if np.std(returns) > 1e-9: - # Assume daily data for annualization? Needs correct frequency. - # Let's assume minutes, T = minutes in a year - T = 365 * 24 * 60 - sharpe_ratio = (np.mean(returns) / np.std(returns)) * np.sqrt(T) - - max_drawdown = self._calculate_max_drawdown(portfolio_values) - results = { - 'final_value': portfolio_values[-1], - 'total_return': portfolio_values[-1] - 1.0, - 'sharpe_ratio': sharpe_ratio, - 'max_drawdown': max_drawdown, - 'positions': positions_history, - 'portfolio_values': portfolio_values, - 'predictions': predictions, # Stored predicted returns - 'uncertainties': uncertainties # Stored confidence scores - } - print("Simple backtest complete.") - print(f" Final Value: {results['final_value']:.4f}, Sharpe: {results['sharpe_ratio']:.4f}, Max Drawdown: {results['max_drawdown']:.4f}") - return results - - def _calculate_max_drawdown(self, portfolio_values): - """Calculate maximum drawdown from portfolio values""" - if len(portfolio_values) < 2: - return 0.0 - values = np.array(portfolio_values) - running_max = np.maximum.accumulate(values) - running_max[running_max == 0] = 1.0 # Avoid division by zero - drawdown = (running_max - values) / running_max - return np.max(drawdown) - - def save(self, path): - """Save the integrated trading system (GRU model/scalers and SAC agent).""" - print(f"Saving trading system to {path}...") - os.makedirs(path, exist_ok=True) - gru_path = os.path.join(path, "gru_model") - sac_path = os.path.join(path, "sac_agent") - os.makedirs(gru_path, exist_ok=True) - os.makedirs(sac_path, 
exist_ok=True) - - if self.gru_model: - # V7-V6 Update: Use CryptoGRUModel save - self.gru_model.save(gru_path) # CryptoGRUModel.save handles model+scalers - if self.sac_agent: - self.sac_agent.save(sac_path) - print("Trading system saved.") - - def load(self, path): - """Load the integrated trading system (GRU model/scalers and SAC agent).""" - print(f"Loading trading system from {path}...") - gru_path = os.path.join(path, "gru_model") - sac_path = os.path.join(path, "sac_agent") - models_loaded = True - try: - if os.path.isdir(gru_path): - # V7-V6 Update: Instantiate and load CryptoGRUModel - self.gru_model = CryptoGRUModel() - if self.gru_model.load(gru_path): - self.feature_scaler = self.gru_model.feature_scaler - self.y_scaler = self.gru_model.y_scaler - # V7.23: Set feature indices after loading scaler - if self._set_feature_indices(): - gru_ok = True - else: - print("Warning: Failed to set feature indices after loading GRU scaler.") - gru_ok = False # Mark as failed if indices can't be set - else: - print("Warning: Failed to load GRU model/scalers.") - else: print(f"Warning: GRU model directory not found: {gru_path}") - if os.path.isdir(sac_path): - # V7.23 Use correct state dim when loading - self.sac_agent = SimplifiedSACTradingAgent(state_dim=5) - if self.sac_agent.load(sac_path): - sac_ok = True - else: - print("Warning: Failed to load SAC agent.") - else: print(f"Warning: SAC agent directory not found: {sac_path}") - # V7.23 Check if both models loaded successfully - models_loaded = gru_ok and sac_ok - status = 'successful' if models_loaded else ('partially successful' if gru_ok or sac_ok else 'failed') - print(f"Trading system loading {status}.") - except Exception as e: - print(f"An error occurred during loading: {e}") - # Ensure flags reflect failure - models_loaded = False - - return models_loaded # Return overall success status - - # V7-V6 Update: Adapt GRU training call for price regression - def train_gru(self, train_data: pd.DataFrame, val_data: 
pd.DataFrame, - prediction_horizon: int, epochs=20, batch_size=32, - patience=10, # Revert to V6 defaults - model_save_dir='models/gru_predictor_trained'): - """V7 Adaptation: Train GRU model (V6 style - price prediction). """ - print(f"--- Starting GRU Training Pipeline (V6 Adaptation) ---") - print(f"Preprocessing training data (Target: Future Price, Horizon: {prediction_horizon})...") - # V7-V6 Update: Use renamed preprocessing function - train_features_df, train_target_price_ser, train_start_price_ser = self._preprocess_data_for_gru_training(train_data, prediction_horizon) - if train_features_df is None: - print("ERROR: Failed to preprocess training data for GRU.") - return None - - print(f"Preprocessing validation data (Target: Future Price, Horizon: {prediction_horizon})...") - # V7-V6 Update: Use renamed preprocessing function - val_features_df, val_target_price_ser, val_start_price_ser = self._preprocess_data_for_gru_training(val_data, prediction_horizon) - if val_features_df is None: - print("ERROR: Failed to preprocess validation data for GRU.") - return None - - # --- V7-V6 Update: Add Scaling Step --- - logging.info("Scaling features and target price...") - try: - # Feature Scaling (Use StandardScaler as in V6 preproc example, but make flexible) - # Consider making scaler type a parameter if needed - feature_scaler = StandardScaler() - # Ensure features are numeric and handle potential NaNs before scaling - train_features_numeric = train_features_df.select_dtypes(include=np.number).fillna(0) - val_features_numeric = val_features_df.select_dtypes(include=np.number).fillna(0) - - # Log columns used for scaling - logging.info(f"Feature columns for scaling: {train_features_numeric.columns.tolist()}") - - train_features_scaled = feature_scaler.fit_transform(train_features_numeric) - val_features_scaled = feature_scaler.transform(val_features_numeric) - self.feature_scaler = feature_scaler # Store fitted scaler - logging.info(f"Feature scaling complete. 
Scaler type: {type(feature_scaler)}") - - # Target Scaling (Use MinMaxScaler as in V6) - y_scaler = MinMaxScaler() # Default range (0, 1) is fine for prices - # Handle potential NaNs in target before scaling - train_target_numeric = train_target_price_ser.fillna(0).values.reshape(-1, 1) - val_target_numeric = val_target_price_ser.fillna(0).values.reshape(-1, 1) - - train_target_scaled = y_scaler.fit_transform(train_target_numeric) - val_target_scaled = y_scaler.transform(val_target_numeric) - self.y_scaler = y_scaler # Store fitted scaler - logging.info(f"Target price scaling complete. Scaler type: {type(y_scaler)}") - - # V7.23: Set feature indices now that scaler is fitted - if not self._set_feature_indices(): - logging.error("Failed to set feature indices after scaling. Cannot proceed with GRU training.") - return None # Exit if indices aren't found - - except Exception as e: - logging.error(f"Error during scaling: {e}", exc_info=True) - return None # Corrected indentation - # --- End Scaling Step --- - - # --- V7-V6 Update: Add Sequence Creation Step --- - logging.info("Creating sequences...") - try: # Added missing except block below - # Pass scaled features as DataFrame to preserve column info if needed by create_sequences_v2? - # Or just pass the numpy array? Let's assume numpy is fine for V6 logic. - # Also need to pass the UNscaled start price series correctly aligned. 
- - # Align start price series with the scaled data length before sequencing - train_start_price_aligned_ser = train_start_price_ser.loc[train_features_df.index] # Index from before scaling - val_start_price_aligned_ser = val_start_price_ser.loc[val_features_df.index] - - X_train, y_train_scaled, y_start_price_train = create_sequences_v2( - pd.DataFrame(train_features_scaled, index=train_features_df.index), # Pass features with index - pd.Series(train_target_scaled.flatten(), index=train_features_df.index), # Pass target with index - train_start_price_aligned_ser, # Pass aligned start price Series - self.gru_lookback - ) - X_val, y_val_scaled, y_start_price_val = create_sequences_v2( - pd.DataFrame(val_features_scaled, index=val_features_df.index), - pd.Series(val_target_scaled.flatten(), index=val_features_df.index), - val_start_price_aligned_ser, - self.gru_lookback - ) - - if X_train is None or X_val is None or y_train_scaled is None or y_val_scaled is None: - logging.error("Sequence creation failed. 
Returned None.") - return None - - except Exception as e: # Added missing except block - logging.error(f"Error creating sequences: {e}", exc_info=True) # Corrected indentation - return None # Corrected indentation - # --- End Sequence Creation --- - - # Debug info about the final data shapes for training - print(f"Data ready for GRU training:") - print(f" X_train shape: {X_train.shape}") - print(f" y_train_scaled shape: {y_train_scaled.shape}") - print(f" X_val shape: {X_val.shape}") - print(f" y_val_scaled shape: {y_val_scaled.shape}") - - # Initialize GRU model if needed - if self.gru_model is None: - # V7-V6 Update: Use CryptoGRUModel - self.gru_model = CryptoGRUModel() - - # --- V7-V6 Update: Call CryptoGRUModel.train --- - print(f"Training GRU model (V6 Adaptation) for {epochs} epochs...") - history = self.gru_model.train( - X_train, y_train_scaled, # Pass sequences and scaled targets - X_val, y_val_scaled, - feature_scaler=self.feature_scaler, # Pass fitted scalers - y_scaler=self.y_scaler, - epochs=epochs, - batch_size=batch_size, - patience=patience, - model_save_dir=model_save_dir # Pass save dir - # Removed LR args, handled within model train - ) - # --- End Updated Call --- - - if history is not None: - print(f"GRU model training complete. Model and scalers saved to {model_save_dir}.") - # Store scalers in TradingSystem instance after successful training - self.feature_scaler = self.gru_model.feature_scaler - self.y_scaler = self.gru_model.y_scaler - - # Optional: Plot training history - try: - history_plot_path = os.path.join(model_save_dir, 'gru_training_history.png') - self.gru_model.plot_training_history(history, save_path=history_plot_path) - except Exception as plot_e: - logging.warning(f"Could not plot GRU training history: {plot_e}") - - return history - else: - print("GRU model training failed.") - return None - - -class ExtendedBacktester: - """ - Enhanced backtesting framework for GRU+SAC integration. 
- V7 Efficiency Update: Pre-computes GRU outputs for faster backtesting. - """ - def __init__(self, trading_system: TradingSystem, initial_capital=10000.0, transaction_cost=0.0001, instrument_label="Unknown Instrument"): - print("Initializing ExtendedBacktester...") - self.trading_system = trading_system - self.initial_capital = initial_capital - self.transaction_cost = transaction_cost - self.instrument_label = instrument_label # Store instrument label - self.portfolio_values = [] - self.positions = [] - self.predictions = [] # Will store predicted returns used for state - self.uncertainties_used_in_state = [] # Will store uncertainty sigmas used for state - self.all_precomputed_uncertainties = None # Stores the full array from evaluate() - self.actions = [] # Actions taken by SAC - self.actual_returns = [] # Actual price returns observed - self.trade_counts = 0 - self.trade_history = [] - self.timestamps = [] # Timestamps for portfolio/action logging - self.gru_lookback = self.trading_system.gru_lookback # Get from TradingSystem - self.buy_hold_values = None # Initialize B&H value storage - # Add storage for GRU prediction plot data - self.prediction_timestamps = None - self.predicted_prices = None - self.true_prices_for_pred = None - # self.uncertainties already exists for uncertainty sigma - print("ExtendedBacktester initialized.") - - def backtest(self, test_data, verbose=True, prediction_horizon=1): - """V7 Efficiency Update: Pre-compute GRU outputs.""" - if not isinstance(test_data, pd.DataFrame) or 'close' not in test_data.columns: - raise ValueError("test_data must be a pandas DataFrame with a 'close' column.") - if self.trading_system.gru_model is None or not (self.trading_system.gru_model.is_trained or self.trading_system.gru_model.is_loaded): - raise ValueError("Cannot backtest: GRU model not ready.") - if self.trading_system.feature_scaler is None or self.trading_system.y_scaler is None: - raise ValueError("Cannot backtest: Scalers not 
loaded/trained.") - if not self.trading_system.sac_agent: - raise ValueError("Cannot backtest: SAC Agent not initialized.") - - print("--- Pre-computing GRU outputs for Test Set ---") - - # 1. Preprocess test data (Features, Target Price, Start Price) - feature_logger.info("Preprocessing test data...") # Use feature_logger - test_features_df, test_target_price_ser, test_start_price_ser = self.trading_system._preprocess_data_for_gru_training(test_data, prediction_horizon) - if test_features_df is None: - feature_logger.error("Failed to preprocess test data.") # Use feature_logger - return None - - # 2. Scale Features and Target - feature_logger.info("Scaling test features and targets...") # Use feature_logger - try: - test_features_scaled = self.trading_system.feature_scaler.transform(test_features_df.select_dtypes(include=np.number).fillna(0)) - test_target_scaled = self.trading_system.y_scaler.transform(test_target_price_ser.fillna(0).values.reshape(-1, 1)) - except Exception as e: - feature_logger.error(f"Error scaling test data: {e}", exc_info=True) # Use feature_logger - return None - - # 3. Create Sequences - feature_logger.info("Creating test sequences...") # Use feature_logger - try: - test_start_price_aligned_ser = test_start_price_ser.loc[test_features_df.index] # Align before sequencing - X_test, y_test_scaled_seq, y_start_price_test_seq = create_sequences_v2( - pd.DataFrame(test_features_scaled, index=test_features_df.index), - pd.Series(test_target_scaled.flatten(), index=test_features_df.index), - test_start_price_aligned_ser, - self.gru_lookback - ) - if X_test is None or y_test_scaled_seq is None or y_start_price_test_seq is None: - feature_logger.error("Sequence creation failed for test data.") # Use feature_logger - return None - except Exception as e: - feature_logger.error(f"Error creating test sequences: {e}", exc_info=True) # Use feature_logger - return None - - # 4. 
Run GRU Evaluate Once - feature_logger.info(f"Pre-computing GRU predictions/uncertainty for {len(X_test)} test sequences...") # Use feature_logger - try: - eval_results = self.trading_system.gru_model.evaluate(X_test, y_test_scaled_seq, y_start_price_test_seq, n_mc_samples=30) - if eval_results is None: - feature_logger.error("GRU evaluate failed on test sequences.") # Use feature_logger - return None - all_pred_returns = eval_results['pred_percent_change'] - all_uncertainties = eval_results['mc_unscaled_std_dev'] - # Store data needed for GRU prediction plot - self.predicted_prices = eval_results.get('predicted_unscaled_prices') - self.true_prices_for_pred = eval_results.get('true_unscaled_prices') - # V7 Fix: Store full uncertainty array separately - self.all_precomputed_uncertainties = all_uncertainties # Store full array for plotting - # self.uncertainties will store the ones used in the loop state below - self.uncertainties_used_in_state = [] # Initialize loop list - - # Store the corresponding timestamps - # The prediction results align with the *end* of the sequences - # The index of test_start_price_aligned_ser aligns with the start of the window - # So, the timestamp for the k-th prediction is at index k + gru_lookback - 1 - num_sequences_eval = len(eval_results['predicted_unscaled_prices']) # Get length from results - if len(test_start_price_aligned_ser) >= self.gru_lookback - 1 + num_sequences_eval: - prediction_indices = test_start_price_aligned_ser.index[self.gru_lookback - 1 : self.gru_lookback - 1 + num_sequences_eval] - self.prediction_timestamps = prediction_indices.to_list() - else: - feature_logger.error("Cannot extract prediction timestamps due to index length mismatch.") - self.prediction_timestamps = [] # Assign empty list on error - - # Validate stored prediction data - if self.predicted_prices is None or self.true_prices_for_pred is None or self.uncertainties_used_in_state is None: - feature_logger.error("Missing price/uncertainty data from 
GRU evaluate results.") - return None - if len(self.prediction_timestamps) != len(self.predicted_prices): - feature_logger.error("Prediction timestamp length mismatch.") - return None - - except Exception as e: - feature_logger.error(f"Error during pre-computation evaluate call on test data: {e}", exc_info=True) # Use feature_logger - return None - - # --- Calculate num_sequences FIRST --- - # Moved from below to be available for feature extraction - num_sequences = len(all_pred_returns) - if num_sequences <= 1: - feature_logger.warning("Not enough sequences generated from test data to perform backtest.") # Use feature_logger - return None - - # --- V7.13 START: Extract Momentum and Volatility Features --- - # These features should be aligned with the *sequences* used for GRU prediction - # Use test_features_df which is aligned with the sequences' start times - try: - # Check if required columns exist in the preprocessed features - required_state_cols = ['return_5m', 'volatility_14d'] # Example columns used in state - if not all(col in test_features_df.columns for col in required_state_cols): - missing_cols = [col for col in required_state_cols if col not in test_features_df.columns] - feature_logger.error(f"Missing required state columns in test_features_df: {missing_cols}") - return None - - # Select the features corresponding to the *end* of each sequence - # The evaluate results (all_pred_returns) align with the *output* of the sequences - # We need features from the *last timestep* of the input sequence (index k + gru_lookback - 1) - # Slice test_features_df to match the number of sequences evaluated - if len(test_features_df) >= self.gru_lookback - 1 + num_sequences: - aligned_feature_indices = test_features_df.index[self.gru_lookback - 1 : self.gru_lookback - 1 + num_sequences] - aligned_features_for_state = test_features_df.loc[aligned_feature_indices] - - # V7.14 Use return_5m for momentum_5, volatility_14d for volatility_20 (adapt names if needed) - 
all_momentum_5 = aligned_features_for_state['return_5m'].values - all_volatility_20 = aligned_features_for_state['volatility_14d'].values # Assuming 'volatility_14d' is the correct column - - if len(all_momentum_5) != num_sequences or len(all_volatility_20) != num_sequences: - feature_logger.error(f"Length mismatch extracting state features: Mom5({len(all_momentum_5)}), Vol20({len(all_volatility_20)}) vs NumSeq({num_sequences})") - return None - else: - feature_logger.error("Length mismatch: test_features_df too short to extract aligned state features.") - return None - - except KeyError as e: - feature_logger.error(f"KeyError extracting state features (momentum/volatility): {e}") - return None - except Exception as e: - feature_logger.error(f"Error extracting state features (momentum/volatility): {e}", exc_info=True) - return None - # --- V7.13 END: Extract Momentum and Volatility Features --- - - feature_logger.info("--- GRU Pre-computation Complete. Starting Backtest Simulation ---") # Use feature_logger - - # --- Refined Timestamp Handling START (from original test_data before preprocessing) --- - time_col_name = None - initial_timestamp_for_loop = None - - # Use a fresh copy for timestamp handling to avoid modifying the original input df - test_data_copy = test_data.copy() - - if 'timestamp' in test_data_copy.columns: - time_col_name = 'timestamp' - test_data_copy[time_col_name] = pd.to_datetime(test_data_copy[time_col_name]) - # Get timestamp corresponding to the *end* of the first sequence - if len(test_data_copy) > self.gru_lookback: - initial_timestamp_for_loop = test_data_copy.iloc[self.gru_lookback][time_col_name] - else: - raise ValueError("Test data too short for lookback after timestamp check.") - print(f"Using '{time_col_name}' column for time.") - - elif isinstance(test_data_copy.index, pd.DatetimeIndex): - feature_logger.info("Using DatetimeIndex for time.") # Changed from print - original_index_name = test_data_copy.index.name - if 
len(test_data_copy) > self.gru_lookback: - initial_timestamp_for_loop = test_data_copy.index[self.gru_lookback] - else: - raise ValueError("Test data index too short for lookback after timestamp check.") - - test_data_copy = test_data_copy.reset_index() - - # Find the name of the column created from the index - if original_index_name and original_index_name in test_data_copy.columns: - time_col_name = original_index_name - elif 'index' in test_data_copy.columns and pd.api.types.is_datetime64_any_dtype(test_data_copy['index']): - time_col_name = 'index' - else: - found_col = None - for col in test_data_copy.columns: - if pd.api.types.is_datetime64_any_dtype(test_data_copy[col]): - found_col = col; break - if found_col: time_col_name = found_col - else: raise ValueError("Could not identify the reset index column containing timestamps.") - print(f"Identified reset index time column as: '{time_col_name}'") - else: - raise ValueError("Could not find 'timestamp' column or DatetimeIndex.") - # --- Refined Timestamp Handling END --- - - # --- Backtest Loop Initialization --- - self.portfolio_values = [self.initial_capital] - self.positions = [0.0] # Position held *before* action at step i - self.predictions = [] - self.uncertainties_used_in_state = [] - self.actions = [] - self.actual_returns = [] - self.trade_history = [] - self.trade_counts = 0 - # Initialize timestamp list with the timestamp for the END of the first sequence period - self.timestamps = [initial_timestamp_for_loop] - - portfolio_value = self.initial_capital - current_position = 0.0 # Position held *before* taking action at step i - - # --- Main Backtest Loop (Iterates over pre-computed results) --- - # Use original close prices from the time-handled copy - original_close_prices = test_data_copy['close'] - - for i in tqdm(range(num_sequences - 1), desc="Backtest Simulation", disable=not verbose): - # 1. 
Get state from pre-computed results - pred_return = all_pred_returns[i] - uncertainty_sigma = all_uncertainties[i] - # V7.23 Get SCALED momentum and volatility for the current step - momentum_5_scaled = all_momentum_5[i] - volatility_20_scaled = all_volatility_20[i] - - # Calculate z-proxy (Position as proxy for risk aversion) - # V7.23 Use SCALED volatility for consistency in state - z_proxy = current_position * volatility_20_scaled - - # Construct 5D state using SCALED momentum/volatility - state = np.array([ - pred_return, - uncertainty_sigma, - z_proxy, - momentum_5_scaled, - volatility_20_scaled - ], dtype=np.float32) - - # Handle NaNs/Infs in state - if np.any(np.isnan(state)) or np.any(np.isinf(state)): - feature_logger.warning(f"NaN/Inf detected in state at step {i}. Replacing with 0. State: {state}") - state = np.nan_to_num(state, nan=0.0, posinf=0.0, neginf=0.0) - - # 2. Get deterministic action from SAC agent - action_tuple = self.trading_system.sac_agent.get_action(state, deterministic=True) - # V7.14 Ensure unpacking handles potential tuple vs single value return if agent API changes - if isinstance(action_tuple, (tuple, list)) and len(action_tuple) > 0: - raw_action = action_tuple[0] # Take the first element (the action) - else: - # Fallback if the API returns just the action (for robustness) - raw_action = action_tuple - feature_logger.warning("SAC agent get_action did not return a tuple as expected. Assuming single value is action.") - - # V7.14 Ensure raw_action is a scalar if action_dim is 1 - if self.trading_system.sac_agent.action_dim == 1 and isinstance(raw_action, (np.ndarray, list)): - raw_action = raw_action[0] - - # 3. 
Calculate PnL and portfolio value - position_change = raw_action - current_position - - # Map sequence index 'i' back to original dataframe index for price lookup - # Use the index from the time-handled copy (test_data_copy) - original_idx_t = i + self.gru_lookback # End of sequence i / Start of action period - original_idx_t_plus_1 = original_idx_t + 1 # End of action period - - try: - # Use prices from the time-handled copy - current_price = original_close_prices.iloc[original_idx_t] - next_price = original_close_prices.iloc[original_idx_t_plus_1] - if current_price != 0: - price_return = (next_price / current_price) - 1.0 - else: - price_return = 0.0 - except IndexError: - feature_logger.warning(f"IndexError accessing original close prices for PnL at step {i} (original indices {original_idx_t}, {original_idx_t_plus_1}). Ending backtest early.") - break - except Exception as e: - feature_logger.error(f"Error accessing original close prices for PnL at step {i}: {e}") - break - - tx_cost = abs(position_change) * portfolio_value * self.transaction_cost - # PnL based on position held *during* the period (current_position) - position_pnl = current_position * price_return * portfolio_value - new_portfolio_value = portfolio_value + position_pnl - tx_cost - - # 4. Store results & update state for next iteration - self.predictions.append(pred_return) - # V7 Fix: Append uncertainty used in state to the correct list - self.uncertainties_used_in_state.append(uncertainty_sigma) - self.actions.append(raw_action) - self.actual_returns.append(price_return) - self.portfolio_values.append(new_portfolio_value) - self.positions.append(raw_action) # Store position *after* action - try: - # Use .loc with the identified time_col_name from the time-handled copy - self.timestamps.append(test_data_copy.loc[original_idx_t_plus_1, time_col_name]) - except KeyError: - feature_logger.warning(f"Could not find timestamp for index {original_idx_t_plus_1}. 
Appending last known timestamp.") - self.timestamps.append(self.timestamps[-1]) - - if abs(position_change) > 0.01: - trade_timestamp = test_data_copy.loc[original_idx_t, time_col_name] # Timestamp when action is decided - trade = {'timestamp': trade_timestamp, 'price': current_price, 'old_position': current_position, - 'new_position': raw_action, 'prediction': pred_return, 'uncertainty': uncertainty_sigma, - 'portfolio_value_before': portfolio_value, 'portfolio_value_after': new_portfolio_value, 'transaction_cost': tx_cost} - self.trade_history.append(trade) - self.trade_counts += 1 - - portfolio_value = new_portfolio_value - current_position = raw_action # Position for the *next* period - - if portfolio_value <= 0: - print(f"Portfolio zeroed at step {i}. Stopping.") - # Adjust fill length based on remaining sequences - fill_len = (num_sequences - 1) - (i + 1) - self.portfolio_values.extend([0.0] * fill_len) - self.positions.extend([current_position] * fill_len) # Fill with last position - self.predictions.extend([0.0] * fill_len) - # V7 Fix: Extend the correct list - self.uncertainties_used_in_state.extend([0.0] * fill_len) - self.actions.extend([current_position] * fill_len) - self.actual_returns.extend([0.0] * fill_len) - # Extend timestamps carefully using the time-handled copy - last_valid_orig_idx = original_idx_t_plus_1 - if last_valid_orig_idx + fill_len <= len(test_data_copy) - 1: - remaining_timestamps = test_data_copy.loc[last_valid_orig_idx + 1 : last_valid_orig_idx + fill_len, time_col_name].tolist() - self.timestamps.extend(remaining_timestamps) - else: - # If fill_len exceeds available data, add NaT or last timestamp - available_count = len(test_data_copy) - 1 - last_valid_orig_idx - if available_count > 0: - remaining_timestamps = test_data_copy.loc[last_valid_orig_idx + 1 : last_valid_orig_idx + available_count, time_col_name].tolist() - self.timestamps.extend(remaining_timestamps) - if fill_len - available_count > 0: - 
self.timestamps.extend([pd.NaT] * (fill_len - available_count)) - break # Corrected indentation again - # --- End Backtest Loop --- - - # --- Calculate Buy and Hold Benchmark --- - feature_logger.info("Calculating Buy and Hold benchmark...") - try: - # Use the original test_data before reset_index if applicable - start_price_bh = test_data['close'].iloc[0] - # Use the close price corresponding to the *last timestamp* used in the strategy backtest - # The last timestamp added was for original_idx_t_plus_1 where i = num_sequences - 2 - last_strategy_idx = (num_sequences - 2) + self.gru_lookback + 1 - end_price_bh = test_data_copy['close'].iloc[last_strategy_idx] - - if pd.isna(start_price_bh) or start_price_bh <= 1e-9: - feature_logger.warning("Could not calculate Buy & Hold: Invalid start price.") - self.buy_hold_values = None - else: - initial_assets_bh = self.initial_capital / start_price_bh - # Calculate B&H value for the same period as the strategy - # Use close prices from the time-handled copy aligned with strategy timestamps - relevant_close_prices = test_data_copy['close'].iloc[self.gru_lookback : last_strategy_idx + 1] - self.buy_hold_values = initial_assets_bh * relevant_close_prices - # Ensure the length matches the strategy portfolio values - if len(self.buy_hold_values) != len(self.portfolio_values): - feature_logger.warning(f"Buy & Hold length mismatch ({len(self.buy_hold_values)}) vs Strategy ({len(self.portfolio_values)}). 
Adjusting B&H series.") - # Pad or truncate B&H to match strategy length (more robust approach needed if indices differ significantly) - bh_series = pd.Series(self.buy_hold_values, index=self.timestamps) # Use strategy timestamps - self.buy_hold_values = bh_series.reindex(self.timestamps).fillna(method='ffill').values - else: - self.buy_hold_values = self.buy_hold_values.values # Convert to numpy - - bh_final_value = self.buy_hold_values[-1] - feature_logger.info(f"Buy & Hold Final Value: ${bh_final_value:.2f}") - except Exception as e: - feature_logger.error(f"Error calculating Buy & Hold benchmark: {e}", exc_info=True) - self.buy_hold_values = None - # --- End Buy and Hold Calculation --- - - # Calculate & print metrics - metrics = self._calculate_performance_metrics() # Now calculates B&H metrics internally - metrics['trade_history'] = self.trade_history - metrics['timestamps'] = self.timestamps # Already populated - if verbose: - print("\n--- Extended Backtest Complete ---") - print(f"Final portfolio value: {metrics.get('final_value', 0):.2f}") - # ... (print other metrics) ... 
- print("---------------------------------") - return metrics - - def _calculate_performance_metrics(self): - if len(self.portfolio_values) < 2: - return {'final_value': 0.0, 'initial_value': 0.0, 'total_return': 0.0, 'annual_return': 0.0, 'sharpe_ratio': 0.0, 'sortino_ratio': 0.0, 'volatility': 0.0, 'max_drawdown': 0.0, 'avg_position': 0.0, 'position_accuracy': 0.0, 'pred_accuracy': 0.0, 'prediction_rmse': 0.0, 'pred_return_corr': 0.0, 'pred_pos_corr': 0.0, 'unc_pos_corr': 0.0, 'total_trades': 0} - portfolio_values = np.array(self.portfolio_values) - positions = np.array(self.positions[1:]) - predictions = np.array(self.predictions) - actual_returns = np.array(self.actual_returns) - # V7 Fix: Use the uncertainties actually used in the state loop - uncertainties_used = np.array(self.uncertainties_used_in_state) - safe_portfolio_values = portfolio_values[:-1].copy() - safe_portfolio_values[safe_portfolio_values == 0] = 1e-9 - period_returns = np.diff(portfolio_values) / safe_portfolio_values - total_minutes = len(period_returns) - minutes_per_year = 365 * 24 * 60 - if total_minutes == 0: - # Return default metrics if no periods - return {'final_value': 0.0, 'initial_value': 0.0, 'total_return': 0.0, 'annual_return': 0.0, 'sharpe_ratio': 0.0, 'sortino_ratio': 0.0, 'volatility': 0.0, 'max_drawdown': 0.0, 'avg_position': 0.0, 'position_accuracy': 0.0, 'pred_accuracy': 0.0, 'prediction_rmse': 0.0, 'pred_return_corr': 0.0, 'pred_pos_corr': 0.0, 'unc_pos_corr': 0.0, 'total_trades': 0} - - final_value = portfolio_values[-1] - initial_value = portfolio_values[0] - total_return = (final_value / initial_value) - 1.0 if initial_value != 0 else 0.0 - annual_return = ((1.0 + total_return) ** (minutes_per_year / total_minutes)) - 1.0 if total_return > -1 else -1.0 - volatility = np.std(period_returns) * np.sqrt(minutes_per_year) - mean_return = np.mean(period_returns) - std_return = np.std(period_returns) - sharpe_ratio = (mean_return / std_return) * np.sqrt(minutes_per_year) if 
std_return > 1e-9 else 0.0 - negative_returns = period_returns[period_returns < 0] - downside_std = np.std(negative_returns) if len(negative_returns) > 0 else 0.0 - sortino_ratio = (mean_return / downside_std) * np.sqrt(minutes_per_year) if downside_std > 1e-9 else 0.0 - max_drawdown = self.trading_system._calculate_max_drawdown(portfolio_values) - # V7 Fix: Use uncertainties_used here - common_length = min(len(predictions), len(actual_returns), len(positions), len(uncertainties_used)) - if common_length > 0: - aligned_predictions = predictions[:common_length] - aligned_actual_returns = actual_returns[:common_length] - aligned_positions = positions[:common_length] - # V7 Fix: Use uncertainties_used here - aligned_uncertainties = uncertainties_used[:common_length] - avg_position = np.mean(np.abs(aligned_positions)) - non_zero_mask = aligned_positions != 0 - # Use nan_to_num on signs before comparison for robustness - sign_pos = np.sign(np.nan_to_num(aligned_positions[non_zero_mask])) - sign_ret = np.sign(np.nan_to_num(aligned_actual_returns[non_zero_mask])) - position_accuracy = np.mean(sign_pos == sign_ret) if np.any(non_zero_mask) else 0.0 - pred_accuracy = np.mean(np.sign(aligned_predictions) == np.sign(aligned_actual_returns)) - prediction_rmse = np.sqrt(np.mean((aligned_predictions - aligned_actual_returns)**2)) - - # Robust Correlation Calculation with warnings suppressed: - with warnings.catch_warnings(): - # Ignore the specific RuntimeWarning from np.corrcoef with zero variance - warnings.filterwarnings('ignore', r'invalid value encountered in divide', RuntimeWarning) - - pred_safe = np.nan_to_num(aligned_predictions) - ret_safe = np.nan_to_num(aligned_actual_returns) - pos_safe = np.nan_to_num(aligned_positions) - unc_safe = np.nan_to_num(aligned_uncertainties) - abs_pos_safe = np.nan_to_num(np.abs(aligned_positions)) - - pred_return_corr = np.corrcoef(pred_safe, ret_safe)[0, 1] if np.std(pred_safe) > 1e-9 and np.std(ret_safe) > 1e-9 else 0.0 - 
pred_pos_corr = np.corrcoef(pred_safe, pos_safe)[0, 1] if np.std(pred_safe) > 1e-9 and np.std(pos_safe) > 1e-9 else 0.0 - unc_pos_corr = np.corrcoef(unc_safe, abs_pos_safe)[0, 1] if np.std(unc_safe) > 1e-9 and np.std(abs_pos_safe) > 1e-9 else 0.0 - - # Explicitly handle potential NaNs from corrcoef if std dev check wasn't perfect - if np.isnan(pred_return_corr): pred_return_corr = 0.0 - if np.isnan(pred_pos_corr): pred_pos_corr = 0.0 - if np.isnan(unc_pos_corr): unc_pos_corr = 0.0 - else: - avg_position, position_accuracy, pred_accuracy, prediction_rmse, pred_return_corr, pred_pos_corr, unc_pos_corr = (0.0,) * 7 - - metrics = { - 'final_value': final_value, 'initial_value': initial_value, 'total_return': total_return, - 'annual_return': annual_return, 'sharpe_ratio': sharpe_ratio, 'sortino_ratio': sortino_ratio, - 'volatility': volatility, 'max_drawdown': max_drawdown, 'avg_position': avg_position, - 'position_accuracy': position_accuracy, 'pred_accuracy': pred_accuracy, - 'prediction_rmse': prediction_rmse, 'pred_return_corr': pred_return_corr, - 'pred_pos_corr': pred_pos_corr, 'unc_pos_corr': unc_pos_corr, - 'total_trades': self.trade_counts - } - - # Add Buy & Hold metrics if available - if self.buy_hold_values is not None and len(self.buy_hold_values) > 0: - bh_final = self.buy_hold_values[-1] - bh_initial = self.buy_hold_values[0] # Should correspond to initial capital approx. 
- metrics['buy_hold_final_value'] = bh_final - metrics['buy_hold_total_return'] = (bh_final / self.initial_capital) - 1.0 if self.initial_capital > 0 else 0.0 - else: - metrics['buy_hold_final_value'] = None - metrics['buy_hold_total_return'] = None - - for key, value in metrics.items(): - # Avoid rounding None - if isinstance(value, (float, np.float_)): - metrics[key] = round(value, 6) - return metrics - - # V7 Plot Update: Rewrite plot_results for 3 specific subplots - def plot_results(self, save_path='backtest_results_v7.png', n_std_dev=1.5): - """Generate and save a 3-panel plot: GRU Preds, SAC Actions, Portfolio Perf.""" - if plt is None: - print("Matplotlib not installed. Skipping plotting.") - return None - - # --- Log Timestamp Range Used for Plotting --- - if self.timestamps: - feature_logger.info(f"Plotting Results - Timestamp range: {min(self.timestamps)} to {max(self.timestamps)}") - else: - feature_logger.warning("Plotting Results - Timestamps list is empty.") - - # --- Construct Title Suffix --- - title_suffix = f" ({self.instrument_label})" - if self.timestamps: - try: - start_ts_str = pd.to_datetime(min(self.timestamps)).strftime('%Y-%m-%d %H:%M') - end_ts_str = pd.to_datetime(max(self.timestamps)).strftime('%Y-%m-%d %H:%M') - title_suffix = f" ({self.instrument_label} | {start_ts_str} to {end_ts_str})" - except Exception as e: - feature_logger.warning(f"Could not format timestamps for plot title: {e}") - else: - feature_logger.warning("Timestamps missing for plot title.") - - # --- Data Validation for Plotting --- - valid_portfolio = self.portfolio_values is not None and len(self.portfolio_values) > 1 - valid_bh = self.buy_hold_values is not None and len(self.buy_hold_values) == len(self.portfolio_values) - valid_actions = self.actions is not None and len(self.actions) == len(self.portfolio_values) -1 # Actions map to periods - # V7 Fix: Check all_precomputed_uncertainties for plotting - valid_gru_preds = (self.predicted_prices is not None and - 
self.true_prices_for_pred is not None and - self.all_precomputed_uncertainties is not None and # Check the correct variable - self.prediction_timestamps is not None and - len(self.prediction_timestamps) == len(self.predicted_prices)) - - # Timestamps for portfolio/action plots (start from the first action's effect) - portfolio_time_axis = self.timestamps # Should align with portfolio_values - - # Timestamps for GRU predictions (align with the prediction time) - gru_time_axis = pd.to_datetime(self.prediction_timestamps) if valid_gru_preds else None - - # Determine the number of plots needed - n_plots = sum([valid_gru_preds, valid_actions, valid_portfolio]) - if n_plots == 0: - print("Warning: No valid data available for any plots.") - return None - - fig, axs = plt.subplots(n_plots, 1, figsize=(15, 6 * n_plots), sharex=True) - # Ensure axs is always an array, even if n_plots is 1 - if n_plots == 1: - axs = [axs] - plot_idx = 0 - - # --- Subplot 1: GRU Price Prediction vs Actual --- - if valid_gru_preds: - ax = axs[plot_idx] - # Align true prices with prediction timestamps - true_prices = pd.Series(self.true_prices_for_pred, index=gru_time_axis) - pred_prices = pd.Series(self.predicted_prices, index=gru_time_axis) - # V7 Fix: Use all_precomputed_uncertainties for plotting - if len(self.all_precomputed_uncertainties) != len(gru_time_axis): - print(f"Warning: Uncertainty length ({len(self.all_precomputed_uncertainties)}) doesn't match time axis ({len(gru_time_axis)}) for plotting.") - # Attempt to align or skip plotting uncertainty - uncertainty = None # Skip fill_between if lengths mismatch - else: - uncertainty = pd.Series(self.all_precomputed_uncertainties, index=gru_time_axis) - - ax.plot(gru_time_axis, true_prices, label='Actual Price', color='#2ca02c', alpha=0.8, linewidth=1.5) - ax.plot(gru_time_axis, pred_prices, label='Predicted Price (GRU)', color='#1f77b4', alpha=0.8, linewidth=1.5) - - # Only plot uncertainty band if uncertainty Series was created 
successfully - if uncertainty is not None: - lower_bound = pred_prices - n_std_dev * uncertainty - upper_bound = pred_prices + n_std_dev * uncertainty - ax.fill_between(gru_time_axis, lower_bound, upper_bound, - color='#aec7e8', alpha=0.3, - label=f'Prediction +/- {n_std_dev} Std Dev') - - ax.set_title('GRU Price Prediction vs Actual' + title_suffix) - ax.set_ylabel('Price') - ax.legend() - ax.grid(True) - ax.ticklabel_format(style='plain', axis='y') - plot_idx += 1 - else: - print("Skipping GRU Prediction plot due to missing data.") - - # --- Subplot 2: SAC Agent Actions --- - if valid_actions: - ax = axs[plot_idx] - # Align actions with the correct timestamps (action taken at T determines position for T to T+1) - action_time_axis = portfolio_time_axis[:-1] # Actions align with the start of the period - if len(action_time_axis) != len(self.actions): - feature_logger.warning(f"Action timestamp length mismatch ({len(action_time_axis)}) vs actions ({len(self.actions)}). Skipping SAC plot.") - else: - ax.plot(action_time_axis, self.actions, label='SAC Position Size', drawstyle='steps-post', color='#ff7f0e') - ax.set_title('SAC Agent Position Size (-1 to 1)' + title_suffix) - ax.set_ylabel('Position') - ax.set_ylim(-1.1, 1.1) - ax.grid(True) - ax.legend() - plot_idx += 1 - else: - print("Skipping SAC Action plot due to missing data.") - - # --- Subplot 3: Portfolio Performance vs Buy & Hold --- - if valid_portfolio: - ax = axs[plot_idx] - ax.plot(portfolio_time_axis, self.portfolio_values, label=f'Strategy Value: ${self.portfolio_values[-1]:,.2f}', color='blue') - - if valid_bh: - ax.plot(portfolio_time_axis, self.buy_hold_values, label=f'Buy & Hold Value: ${self.buy_hold_values[-1]:,.2f}', color='orange', linestyle='--') - elif self.buy_hold_values is not None: - feature_logger.warning("Could not plot Buy & Hold line due to length mismatch.") - - # Add trade markers - if self.trade_history: - buys = [(pd.to_datetime(t['timestamp']), t['price']) for t in 
self.trade_history if t['new_position'] > t['old_position']] - sells = [(pd.to_datetime(t['timestamp']), t['price']) for t in self.trade_history if t['new_position'] < t['old_position']] - if buys: - buy_times, buy_prices = zip(*buys) - # Find corresponding portfolio values at buy times for marker placement - portfolio_series = pd.Series(self.portfolio_values, index=pd.to_datetime(portfolio_time_axis)) - buy_portfolio_values = portfolio_series.reindex(buy_times, method='ffill') - ax.plot(buy_times, buy_portfolio_values, '^', color='green', markersize=8, label='Buy Signal') - if sells: - sell_times, sell_prices = zip(*sells) - portfolio_series = pd.Series(self.portfolio_values, index=pd.to_datetime(portfolio_time_axis)) - sell_portfolio_values = portfolio_series.reindex(sell_times, method='ffill') - ax.plot(sell_times, sell_portfolio_values, 'v', color='red', markersize=8, label='Sell Signal') - - ax.set_title('Strategy Portfolio Value vs. Buy & Hold' + title_suffix) - ax.set_ylabel('Portfolio Value ($)') - ax.grid(True); ax.legend(); ax.ticklabel_format(style='plain', axis='y') - plot_idx += 1 - else: - print("Skipping Portfolio Performance plot due to missing data.") - - # Final adjustments - axs[-1].set_xlabel('Time') # Add xlabel only to the bottom-most plot - fig.autofmt_xdate() - plt.tight_layout() - - if save_path: - try: # Added missing except block below - plt.savefig(save_path); - print(f"Combined performance plot saved to {save_path}") - except Exception as e: # Added missing except block - print(f"Error saving combined plot: {e}") # Corrected indentation - # plt.show() # Optional: uncomment to display plot interactively - plt.close(); - return fig - - def generate_performance_report(self, report_path="backtest_performance_report.md"): - # ... 
(implementation remains the same, note update) - metrics = self._calculate_performance_metrics() - if not metrics: - feature_logger.error("Cannot generate report: No metrics calculated.") - return None - - report = f"# GRU+SAC Backtesting Performance Report\n\n" - report += f"Report generated on: {pd.Timestamp.now()}\n" - # Ensure timestamps list is not empty before accessing - if self.timestamps: - # --- Log Timestamp Range Used for Report Header --- - report_start_ts = self.timestamps[0] - report_end_ts = self.timestamps[-1] - feature_logger.info(f"Generating Report - Timestamp range: {report_start_ts} to {report_end_ts}") - report += f"Data range: {self.timestamps[0]} to {self.timestamps[-1]}\n" - report += f"Total duration: {self.timestamps[-1] - self.timestamps[0]}\n\n" - else: - feature_logger.warning("Generating Report - Timestamps list is empty.") - report += "Data range: N/A\n" - report += "Total duration: N/A\n\n" - - report += "## Strategy Performance Metrics\n\n" - report += f"* **Initial capital:** ${metrics.get('initial_value', 0):,.2f}\n" - report += f"* **Final portfolio value:** ${metrics.get('final_value', 0):,.2f}\n" - report += f"* **Total return:** {metrics.get('total_return', 0)*100:.2f}%\n" - report += f"* **Annualized return:** {metrics.get('annual_return', 0)*100:.2f}%\n" - report += f"* **Sharpe ratio (annualized):** {metrics.get('sharpe_ratio', 0):.4f}\n" - report += f"* **Sortino ratio (annualized):** {metrics.get('sortino_ratio', 0):.4f}\n" - report += f"* **Volatility (annualized):** {metrics.get('volatility', 0)*100:.2f}%\n" - report += f"* **Maximum drawdown:** {metrics.get('max_drawdown', 0)*100:.2f}%\n" - report += f"* **Total trades:** {metrics.get('total_trades', 0)}\n" - - # Add Buy and Hold Performance section - report += "\n## Buy and Hold Benchmark\n\n" - bh_final_value = metrics.get('buy_hold_final_value') - bh_total_return = metrics.get('buy_hold_total_return') - if bh_final_value is not None and bh_total_return is not 
None: - report += f"* **Final value (B&H):** ${bh_final_value:,.2f}\n" - report += f"* **Total return (B&H):** {bh_total_return*100:.2f}%\n" - else: - report += "* *Buy and Hold benchmark could not be calculated.*\n" - - report += "\n## Position & Prediction Analysis\n\n" - report += f"* **Average absolute position size:** {metrics.get('avg_position', 0):.4f}\n" - report += f"* **Position sign accuracy vs return:** {metrics.get('position_accuracy', 0)*100:.2f}%\n" - report += f"* **Prediction sign accuracy vs return:** {metrics.get('pred_accuracy', 0)*100:.2f}%\n" - report += f"* **Prediction RMSE (on returns):** {metrics.get('prediction_rmse', 0):.6f}\n" - report += "\n## Correlations\n\n" - report += f"* **Prediction-Return correlation:** {metrics.get('pred_return_corr', 0):.4f}\n" - report += f"* **Prediction-Position correlation:** {metrics.get('pred_pos_corr', 0):.4f}\n" - report += f"* **Uncertainty-Position Size correlation:** {metrics.get('unc_pos_corr', 0):.4f}\n" - report += "\n## Notes\n\n" - report += f"* Transaction cost used: {self.transaction_cost * 100:.4f}% per position change value.\n" - report += f"* GRU lookback period: {self.gru_lookback} minutes.\n" - report += "* V6 features + return features used.\n" - report += "* Uncertainty estimated via MC Dropout standard deviation.\n" - # report += "* Uncertainty estimated directly by GRU second head.\n" - if report_path: - try: - with open(report_path, "w") as f: - f.write(report) - print(f"Performance report saved to {report_path}") - except Exception as e: - print(f"Error saving performance report: {e}") - return report \ No newline at end of file diff --git a/gru_sac_predictor/tests/test_calibration.py b/gru_sac_predictor/tests/test_calibration.py new file mode 100644 index 00000000..241d7b9a --- /dev/null +++ b/gru_sac_predictor/tests/test_calibration.py @@ -0,0 +1,89 @@ +""" +Tests for probability calibration (Sec 6 of revisions.txt). 
"""
Tests for probability calibration (Sec 6 of revisions.txt).
"""
import pytest
import numpy as np
from scipy.stats import binomtest
from scipy.special import logit, expit
import os

# Try to import the modules; skip tests if not found (e.g., path issues)
try:
    from gru_sac_predictor.src import calibrate
except ImportError:
    calibrate = None


# --- Fixtures ---
@pytest.fixture(scope="module")
def calibration_data():
    """
    Generate sample raw probabilities and true outcomes.
    Simulates an overconfident model (T_implied < 1) where true probability drifts.
    """
    np.random.seed(42)
    n = 2500
    # Slow sinusoidal drift of the true hit probability around 0.5.
    drift = 0.05 * np.sin(np.linspace(0, 3 * np.pi, n))
    true_prob = np.clip(0.5 + drift + np.random.randn(n) * 0.05, 0.05, 0.95)
    # Dividing logits by T ~ 0.7 sharpens them, i.e. simulates overconfidence.
    p_raw = expit(logit(true_prob) / 0.7)
    # Bernoulli outcomes drawn from the true probabilities.
    y_true = (np.random.rand(n) < true_prob).astype(int)
    return p_raw, y_true


# --- Tests ---
@pytest.mark.skipif(calibrate is None, reason="Module gru_sac_predictor.src.calibrate not found")
def test_optimise_temperature(calibration_data):
    """Check if optimise_temperature runs and returns a plausible value."""
    p_raw, y_true = calibration_data
    optimal_T = calibrate.optimise_temperature(p_raw, y_true)
    print(f"\nOptimised T: {optimal_T:.4f}")
    # Expect T > 0. A T near 0.7 would undo the simulated effect.
    assert 0.1 < optimal_T < 5.0, "Optimised temperature seems out of expected range."


@pytest.mark.skipif(calibrate is None, reason="Module gru_sac_predictor.src.calibrate not found")
def test_calibration_hit_rate_threshold(calibration_data):
    """
    Verify that the lower 95% CI of the hit-rate for non-zero calibrated
    signals is >= 0.55 (using the module's EDGE_THR).
    """
    p_raw, y_true = calibration_data
    optimal_T = calibrate.optimise_temperature(p_raw, y_true)
    p_cal = calibrate.calibrate(p_raw, optimal_T)
    action_signals = calibrate.action_signal(p_cal)

    # Keep only steps where a trade was actually signalled.
    traded_mask = action_signals != 0
    if not np.any(traded_mask):
        pytest.fail("No non-zero action signals generated for hit-rate test.")

    signals_taken = action_signals[traded_mask]
    actual_direction = y_true[traded_mask]
    total_trades = len(signals_taken)

    # Hit: long signal on an up move (1 vs 1) or short signal on a down move (-1 vs 0).
    hits = int(np.sum((signals_taken == 1) & (actual_direction == 1)) +
               np.sum((signals_taken == -1) & (actual_direction == 0)))

    if total_trades < 30:
        pytest.skip(f"Insufficient non-zero signals ({total_trades}) for reliable CI.")

    # One-sided binomial test gives the 95% lower confidence bound.
    try:
        result = binomtest(hits, total_trades, p=0.5, alternative='greater')
        lower_ci = result.proportion_ci(confidence_level=0.95).low
    except Exception as e:
        pytest.fail(f"Binomial test failed: {e}")

    hit_rate = hits / total_trades
    required_threshold = calibrate.EDGE_THR  # Use threshold from module

    print(f"\nCalibration Test: EDGE_THR={required_threshold:.3f}")
    print(f"  Trades={total_trades}, Hits={hits}, Hit Rate={hit_rate:.4f}")
    print(f"  95% Lower CI: {lower_ci:.4f}")

    assert lower_ci >= required_threshold, \
        f"Hit rate lower CI ({lower_ci:.4f}) is below module threshold ({required_threshold:.3f})"
"""
Tests for feature pruning logic.
"""
import pytest
import pandas as pd
import numpy as np

# Try to import the module; skip tests if not found
try:
    from gru_sac_predictor.src import features
except ImportError:
    features = None


# --- Fixtures ---
@pytest.fixture
def sample_features_df():
    """Create a DataFrame with more columns than the whitelist."""
    n = 100
    columns = {
        # Whitelisted
        "return_1m": np.random.randn(n),
        "return_15m": np.random.randn(n),
        "return_60m": np.random.randn(n),
        "ATR_14": np.random.rand(n) * 0.01,
        "volatility_14d": np.random.rand(n) * 0.02,
        "chaikin_AD_10": np.random.randn(n) * 1000,
        "svi_10": np.random.randn(n) * 500,
        "EMA_10": 100 + np.random.randn(n),
        "EMA_50": 100 + np.random.randn(n),
        "MACD": np.random.randn(n) * 0.1,
        "MACD_signal": np.random.randn(n) * 0.05,
        "hour_sin": np.sin(np.linspace(0, 2 * np.pi, n)),
        "hour_cos": np.cos(np.linspace(0, 2 * np.pi, n)),
        # Non-whitelisted
        "close": 100 + np.random.randn(n),
        "open": 100 + np.random.randn(n),
        "RSI_14": 50 + np.random.randn(n) * 10,  # Assumed not in final whitelist
        "some_other_feature": np.random.rand(n),
    }
    return pd.DataFrame(columns)


# --- Tests ---
@pytest.mark.skipif(features is None, reason="Module gru_sac_predictor.src.features not found")
def test_prune_features_uses_whitelist(sample_features_df):
    """
    Verify prune_features returns only columns present in features.minimal_whitelist.
    """
    df_in = sample_features_df
    allowed = features.minimal_whitelist
    pruned = features.prune_features(df_in)

    print(f"\nWhitelist: {allowed}")
    print(f"Input columns: {df_in.columns.tolist()}")
    print(f"Output columns: {pruned.columns.tolist()}")

    # Every surviving column must come from the whitelist.
    assert all(col in allowed for col in pruned.columns), \
        "Output DataFrame contains columns not in the whitelist."

    # The survivors must be exactly the whitelist ∩ input columns.
    expected = [col for col in allowed if col in df_in.columns]
    assert sorted(pruned.columns.tolist()) == sorted(expected), \
        "Output columns do not match the expected intersection of input and whitelist."

    # Spot-check that known non-whitelisted columns were dropped.
    assert "close" not in pruned.columns, "'close' column was not pruned."
    assert "some_other_feature" not in pruned.columns, "'some_other_feature' was not pruned."
"""
Integration tests for cross-module interactions.
"""
import pytest
import os
import numpy as np
import tempfile
import json

# Try to import the module; skip tests if not found
try:
    from gru_sac_predictor.src import sac_agent
    import tensorflow as tf  # Needed for agent init/load
except ImportError:
    sac_agent = None
    tf = None


@pytest.fixture
def sac_agent_for_integration():
    """Provides a basic SAC agent instance with its networks built."""
    if sac_agent is None or tf is None:
        pytest.skip("SAC Agent module or TF not found.")
    # Use minimal params for saving/loading tests
    agent = sac_agent.SACTradingAgent(
        state_dim=5, action_dim=1,
        buffer_capacity=100, min_buffer_size=10
    )
    # Build models (networks are lazily built on first call).
    try:
        agent.actor(tf.zeros((1, 5)))
        agent.critic1([tf.zeros((1, 5)), tf.zeros((1, 1))])
        agent.critic2([tf.zeros((1, 5)), tf.zeros((1, 1))])
        agent.update_target_networks(tau=1.0)
    except Exception as e:
        pytest.fail(f"Failed to build agent models: {e}")
    return agent


@pytest.mark.skipif(sac_agent is None or tf is None, reason="SAC Agent module or TF not found")
def test_save_load_metadata(sac_agent_for_integration):
    """Test if metadata is saved and loaded correctly."""
    agent = sac_agent_for_integration
    with tempfile.TemporaryDirectory() as tmpdir:
        save_path = os.path.join(tmpdir, "sac_test_save")
        agent.save(save_path)

        # Check if metadata file exists
        meta_path = os.path.join(save_path, 'agent_metadata.json')
        assert os.path.exists(meta_path), "Metadata file was not saved."

        # Create a new agent and load
        new_agent = sac_agent.SACTradingAgent(state_dim=5, action_dim=1)
        loaded_meta = new_agent.load(save_path)

        assert isinstance(loaded_meta, dict), "Load method did not return a dict."
        assert loaded_meta.get('state_dim') == 5, "Loaded state_dim incorrect."
        assert loaded_meta.get('action_dim') == 1, "Loaded action_dim incorrect."
        # Check alpha status (default is auto_tune=True).
        # FIX: use `is True` rather than `== True` (PEP 8 / flake8 E712);
        # the value round-trips through JSON as a genuine bool.
        assert loaded_meta.get('log_alpha_saved') is True, "log_alpha status incorrect."


@pytest.mark.skipif(sac_agent is None or tf is None, reason="SAC Agent module or TF not found")
def test_replay_buffer_purge_on_change(sac_agent_for_integration):
    """
    Simulate loading an agent where the edge_threshold has changed
    and verify the buffer is cleared.
    """
    agent_to_save = sac_agent_for_integration
    original_edge_thr = 0.55
    agent_to_save.edge_threshold_config = original_edge_thr  # Manually set for saving

    with tempfile.TemporaryDirectory() as tmpdir:
        save_path = os.path.join(tmpdir, "sac_purge_test")

        # 1. Save agent with original threshold in metadata
        agent_to_save.save(save_path)
        meta_path = os.path.join(save_path, 'agent_metadata.json')
        assert os.path.exists(meta_path)
        with open(meta_path, 'r') as f:
            saved_meta = json.load(f)
        assert saved_meta.get('edge_threshold_config') == original_edge_thr

        # 2. Create a new agent instance to load into
        new_agent = sac_agent.SACTradingAgent(
            state_dim=5, action_dim=1,
            buffer_capacity=100, min_buffer_size=10
        )
        # Build models for the new agent
        try:
            new_agent.actor(tf.zeros((1, 5)))
            new_agent.critic1([tf.zeros((1, 5)), tf.zeros((1, 1))])
            new_agent.critic2([tf.zeros((1, 5)), tf.zeros((1, 1))])
            new_agent.update_target_networks(tau=1.0)
        except Exception as e:
            pytest.fail(f"Failed to build new agent models: {e}")

        # Add dummy data to the *new* agent's buffer *before* loading
        for _ in range(20):
            dummy_state = np.random.rand(5).astype(np.float32)
            dummy_action = np.random.rand(1).astype(np.float32)
            new_agent.buffer.add(dummy_state, dummy_action, 0.0, dummy_state, 0.0)
        assert len(new_agent.buffer) == 20, "Buffer should have data before load."

        # 3. Simulate loading with a *different* current edge threshold config
        current_config_edge_thr = 0.60
        assert abs(current_config_edge_thr - original_edge_thr) > 1e-6

        loaded_meta = new_agent.load(save_path)
        saved_edge_thr = loaded_meta.get('edge_threshold_config')

        # 4. Perform the check and clear if needed (simulating pipeline logic)
        if saved_edge_thr is not None and abs(saved_edge_thr - current_config_edge_thr) > 1e-6:
            print(f"\nEdge threshold mismatch detected (Saved={saved_edge_thr}, Current={current_config_edge_thr}). Clearing buffer.")
            new_agent.clear_buffer()
        else:
            print(f"\nEdge threshold match or not saved. Buffer not cleared.")

        # 5. Assert buffer is now empty
        assert len(new_agent.buffer) == 0, "Buffer was not cleared after edge threshold mismatch."
+""" +import pytest +import pandas as pd +import numpy as np + +# Assume test data is loaded via fixtures later +@pytest.fixture(scope="module") +def sample_data_for_leakage(): + """ + Provides sample features and target for leakage tests. + Includes correctly shifted features, a feature with direct leakage, + and a rolling feature calculated correctly vs incorrectly. + """ + np.random.seed(43) + dates = pd.date_range(start='2023-01-01', periods=500, freq='T') + n = len(dates) + df = pd.DataFrame(index=dates) + df['noise'] = np.random.randn(n) + df['close'] = 100 + np.cumsum(df['noise'] * 0.1) + df['y_ret'] = np.log(df['close'].shift(-1) / df['close']) + + # --- Features --- + # OK: Based on past noise + df['feature_ok_past_noise'] = df['noise'].shift(1) + # OK: Rolling mean on correctly shifted past data + df['feature_ok_rolling_shifted'] = df['noise'].shift(1).rolling(10).mean() + # LEAKY: Uses future return directly + df['feature_leaky_direct'] = df['y_ret'] + # LEAKY: Rolling mean calculated *before* shifting target relationship + df['feature_leaky_rolling_unaligned'] = df['close'].rolling(5).mean() + + # Drop rows with NaNs from shifts/rolls AND the last row where y_ret is NaN + df.dropna(inplace=True) + + # Define features and target for the test + y_target = df['y_ret'] + features_df = df.drop(columns=['close', 'y_ret', 'noise']) # Exclude raw data used for generation + + return features_df, y_target + +@pytest.mark.parametrize("leakage_threshold", [0.02]) +def test_feature_leakage_correlation(sample_data_for_leakage, leakage_threshold): + """ + Verify that no feature has correlation > threshold with the correctly shifted target. 
+ """ + features_df, y_target = sample_data_for_leakage + + max_abs_corr = 0.0 + leaky_col = "None" + all_corrs = {} + + print(f"\nTesting {features_df.shape[1]} features for leakage (threshold={leakage_threshold})...") + for col in features_df.columns: + if pd.api.types.is_numeric_dtype(features_df[col]): + # Handle potential NaNs introduced by feature engineering (though fixture avoids it) + temp_df = pd.concat([features_df[col], y_target], axis=1).dropna() + if len(temp_df) < 0.5 * len(features_df): + print(f" Skipping {col} due to excessive NaNs after merging with target.") + continue + + correlation = temp_df[col].corr(temp_df['y_ret']) + all_corrs[col] = correlation + # print(f" Corr({col}, y_ret): {correlation:.4f}") + if abs(correlation) > max_abs_corr: + max_abs_corr = abs(correlation) + leaky_col = col + else: + print(f" Skipping non-numeric column: {col}") + + print(f"Correlations found: { {k: round(v, 4) for k, v in all_corrs.items()} }") + print(f"Maximum absolute correlation found: {max_abs_corr:.4f} (feature: {leaky_col})") + + assert max_abs_corr < leakage_threshold, \ + f"Feature '{leaky_col}' has correlation {max_abs_corr:.4f} > threshold {leakage_threshold}, suggesting leakage." + +@pytest.mark.skipif(features is None, reason="Module gru_sac_predictor.src.features not found") +def test_ta_feature_leakage(sample_data_for_leakage, leakage_threshold=0.02): + """ + Specifically test TA features (EMA, MACD etc.) for leakage. + Ensures they were calculated on shifted data. 
+ """ + features_df, y_target = sample_data_for_leakage + # Add TA features using the helper (simulating pipeline) + # We need OHLC in the input df for add_ta_features + # Recreate a df with shifted OHLC + other features for TA calc + np.random.seed(43) # Ensure consistent data with primary fixture + dates = pd.date_range(start='2023-01-01', periods=500, freq='T') + n = len(dates) + df_ohlc = pd.DataFrame(index=dates) + df_ohlc['close'] = 100 + np.cumsum(np.random.randn(n) * 0.1) + df_ohlc['open'] = df_ohlc['close'].shift(1) * (1 + np.random.randn(n) * 0.001) + df_ohlc['high'] = df_ohlc[['open','close']].max(axis=1) * (1 + np.random.rand(n) * 0.001) + df_ohlc['low'] = df_ohlc[['open','close']].min(axis=1) * (1 - np.random.rand(n) * 0.001) + df_ohlc['volume'] = np.random.rand(n) * 1000 + + # IMPORTANT: Shift before calculating TA features + df_shifted_ohlc = df_ohlc.shift(1) + df_ta = features.add_ta_features(df_shifted_ohlc) + + # Align with the target (requires original non-shifted index) + df_ta = df_ta.loc[y_target.index] + + ta_features_to_test = [col for col in features.minimal_whitelist if col in df_ta.columns and col not in ["return_1m", "return_15m", "return_60m", "hour_sin", "hour_cos"]] + max_abs_corr = 0.0 + leaky_col = "None" + all_corrs = {} + + print(f"\nTesting {len(ta_features_to_test)} TA features for leakage (threshold={leakage_threshold})...") + print(f" Features: {ta_features_to_test}") + + for col in ta_features_to_test: + if pd.api.types.is_numeric_dtype(df_ta[col]): + temp_df = pd.concat([df_ta[col], y_target], axis=1).dropna() + if len(temp_df) < 0.5 * len(y_target): + print(f" Skipping {col} due to excessive NaNs after merging.") + continue + correlation = temp_df[col].corr(temp_df['y_ret']) + all_corrs[col] = correlation + if abs(correlation) > max_abs_corr: + max_abs_corr = abs(correlation) + leaky_col = col + else: + print(f" Skipping non-numeric TA column: {col}") + + print(f"TA Feature Correlations: { {k: round(v, 4) for k, v in 
all_corrs.items()} }") + print(f"Maximum absolute TA correlation found: {max_abs_corr:.4f} (feature: {leaky_col})") + + assert max_abs_corr < leakage_threshold, \ + f"TA Feature '{leaky_col}' has correlation {max_abs_corr:.4f} > threshold {leakage_threshold}, suggesting leakage from TA calculation." + +# test_label_timing is usually covered by the correlation test, so removed for brevity. \ No newline at end of file diff --git a/gru_sac_predictor/tests/test_sac_sanity.py b/gru_sac_predictor/tests/test_sac_sanity.py new file mode 100644 index 00000000..8d44bf67 --- /dev/null +++ b/gru_sac_predictor/tests/test_sac_sanity.py @@ -0,0 +1,121 @@ +""" +Sanity checks for the SAC agent (Sec 6 of revisions.txt). +""" +import pytest +import numpy as np +import os + +# Try to import the agent; skip tests if not found +try: + from gru_sac_predictor.src import sac_agent + # Need TF for tensor conversion if testing agent directly + import tensorflow as tf +except ImportError: + sac_agent = None + tf = None + +# --- Fixtures --- +@pytest.fixture(scope="module") +def sac_agent_instance(): + """ + Provides a default SAC agent instance for testing. + Uses standard parameters suitable for basic checks. 
+ """ + if sac_agent is None: + pytest.skip("SAC Agent module not found.") + # Use default params, state_dim=5 as per revisions + # Use fixed seeds for reproducibility in tests if needed inside agent + agent = sac_agent.SACTradingAgent( + state_dim=5, action_dim=1, + initial_lr=1e-4, # Use a common LR for test simplicity + buffer_capacity=1000, # Smaller buffer for testing + min_buffer_size=100, + target_entropy=-1.0 + ) + # Build the models eagerly + try: + agent.actor(tf.zeros((1, 5))) + agent.critic1([tf.zeros((1, 5)), tf.zeros((1, 1))]) + agent.critic2([tf.zeros((1, 5)), tf.zeros((1, 1))]) + # Copy weights to target networks + agent.update_target_networks(tau=1.0) + except Exception as e: + pytest.fail(f"Failed to build SAC agent models: {e}") + return agent + +@pytest.fixture(scope="module") +def sample_sac_inputs(): + """ + Generate sample states and corresponding directional signals. + Simulates states with varying edge and signal-to-noise. + """ + np.random.seed(44) + n_samples = 1500 + # Simulate GRU outputs and position + mu = np.random.randn(n_samples) * 0.0015 # Slightly higher variance + sigma = np.random.uniform(0.0005, 0.0025, n_samples) + # Simulate edge with clearer separation for testing signals + edge_base = np.random.choice([-0.15, -0.05, 0.0, 0.05, 0.15], n_samples, p=[0.2, 0.2, 0.2, 0.2, 0.2]) + edge = np.clip(edge_base + np.random.randn(n_samples) * 0.03, -1.0, 1.0) + z_score = np.abs(mu) / (sigma + 1e-9) + position = np.random.uniform(-1, 1, n_samples) + states = np.vstack([mu, sigma, edge, z_score, position]).T.astype(np.float32) + # Use a small positive/negative threshold for determining signal from edge + signals = np.where(edge > 0.02, 1, np.where(edge < -0.02, -1, 0)) + return states, signals + +# --- Tests --- +@pytest.mark.skipif(sac_agent is None or tf is None, reason="SAC Agent module or TensorFlow not found") +def test_sac_agent_default_min_buffer(sac_agent_instance): + """Verify the default min_buffer_size is at least 10000.""" + 
agent = sac_agent_instance + # Note: Fixture currently initializes with specific values, overriding default. + # Re-initialize with defaults for this test. + default_agent = sac_agent.SACTradingAgent(state_dim=5, action_dim=1) + min_buffer = default_agent.min_buffer_size + print(f"\nAgent default min_buffer_size: {min_buffer}") + assert min_buffer >= 10000, f"Default min_buffer_size ({min_buffer}) is less than recommended 10000." + +@pytest.mark.skipif(sac_agent is None or tf is None, reason="SAC Agent module or TensorFlow not found") +def test_sac_action_variance(sac_agent_instance, sample_sac_inputs): + """ + Verify that the mean absolute action taken when the signal is non-zero + is >= 0.05. + """ + agent = sac_agent_instance + states, signals = sample_sac_inputs + + actions = [] + for state in states: + # Use deterministic action for this sanity check + action = agent.get_action(state, deterministic=True) + actions.append(action[0]) # get_action returns list/array + actions = np.array(actions) + + # Filter for non-zero signals based on the *simulated* edge + non_zero_signal_idx = signals != 0 + if not np.any(non_zero_signal_idx): + pytest.fail("No non-zero signals generated in fixture for SAC variance test.") + + actions_on_signal = actions[non_zero_signal_idx] + + if len(actions_on_signal) == 0: + # This case should ideally not happen if the above check passed + pytest.fail("Filtered actions array is empty despite non-zero signals.") + + mean_abs_action = np.mean(np.abs(actions_on_signal)) + + print(f"\nSAC Sanity Test: Mean Absolute Action (on signal != 0): {mean_abs_action:.4f}") + + # Check if the agent is outputting actions with sufficient magnitude + assert mean_abs_action >= 0.05, \ + f"Mean absolute action ({mean_abs_action:.4f}) is below threshold (0.05). Agent might be too timid or stuck near zero." 
import pytest  # Re-import is harmless and keeps this fragment self-describing.


@pytest.mark.skip(reason="Requires full backtest results which are not available in this unit test setup.")
def test_sac_reward_correlation():
    """
    Optional: Check if actions taken correlate positively with subsequent rewards.
    NOTE: This test requires results from a full backtest run (actions vs rewards)
    and cannot be reliably simulated or executed in this unit test.
    """
    pass  # Cannot implement without actual backtest results


# --- new file: gru_sac_predictor/tests/test_time_encoding.py ---
"""
Tests for time encoding, specifically DST transitions.
"""
import pytest
import pandas as pd
import numpy as np
import pytz  # For timezone handling


@pytest.fixture(scope="module")
def generate_dst_timeseries():
    """
    Generate a minute-frequency timestamp series crossing DST transitions
    for a specific timezone (e.g., US/Eastern).

    The index is the union of two disjoint windows (around the 2024 DST start
    and DST end), so consecutive rows across the seam between the windows are
    NOT adjacent in time — consumers must check the actual time step.
    """
    # Example: US/Eastern DST Start (March 10, 2024: 2:00 AM -> 3:00 AM)
    # Example: US/Eastern DST End   (Nov 3, 2024: 2:00 AM -> 1:00 AM)
    tz = pytz.timezone('US/Eastern')

    # Create timestamps around DST start
    dst_start_range = pd.date_range(
        start='2024-03-10 01:00:00', end='2024-03-10 04:00:00', freq='T', tz=tz
    )
    # Create timestamps around DST end
    dst_end_range = pd.date_range(
        start='2024-11-03 00:00:00', end='2024-11-03 03:00:00', freq='T', tz=tz
    )

    # Combine and ensure uniqueness/order (though disjoint here)
    timestamps = dst_start_range.union(dst_end_range)
    df = pd.DataFrame(index=timestamps)
    df.index.name = 'timestamp'
    return df


def calculate_cyclical_features(df):
    """
    Add ``hour_sin``/``hour_cos`` columns computed from the DataFrame's index.

    The hour is taken in UTC: local wall-clock hours jump at DST transitions
    (e.g. 01:59 -> 03:00), which makes a local-hour encoding discontinuous,
    whereas UTC hours advance uniformly and guarantee continuity. If local
    market seasonality is required, a different fold-aware encoding would be
    needed instead — this helper deliberately trades that for continuity.

    Raises:
        TypeError: if the index is not a DatetimeIndex.
    """
    if not isinstance(df.index, pd.DatetimeIndex):
        raise TypeError("Input DataFrame must have a DatetimeIndex.")

    # Ensure timezone is present (fixture provides it)
    if df.index.tz is None:
        print("Warning: Index timezone is None, assuming UTC for calculation.")
        timestamp_source = df.index.tz_localize('utc')
    else:
        timestamp_source = df.index

    # Use the UTC hour so the encoding is continuous across DST transitions.
    hour_of_day = timestamp_source.tz_convert('utc').hour

    df['hour_sin'] = np.sin(2 * np.pi * hour_of_day / 24)
    df['hour_cos'] = np.cos(2 * np.pi * hour_of_day / 24)
    return df


def test_cyclical_features_continuity(generate_dst_timeseries):
    """
    hour_sin / hour_cos must not jump by more than ~one hour's worth between
    *adjacent minutes*, including across the DST transitions.

    Two corrections over a naive ``diff()``:
      * Rows separated by more than one minute of real time (the multi-month
        seam between the fixture's March and November windows) are excluded —
        a jump there is a fixture gap, not a DST artifact.
      * Features are computed from UTC hours (see calculate_cyclical_features),
        so the DST wall-clock jump cannot produce a spurious discontinuity.
    """
    df = generate_dst_timeseries
    df = calculate_cyclical_features(df)

    # Only compare rows that are truly one minute apart in absolute time.
    # (Differencing tz-aware timestamps yields the absolute elapsed time, so
    # the DST-start minute 01:59 -> 03:00 local is correctly kept in scope.)
    one_minute_apart = df.index.to_series().diff() == pd.Timedelta(minutes=1)

    sin_diff = df['hour_sin'].diff().abs()[one_minute_apart]
    cos_diff = df['hour_cos'].diff().abs()[one_minute_apart]

    # Max change of sin(2*pi*h/24) for a 1-hour step is ~sin(pi/12) ~= 0.26;
    # allow slight headroom.
    max_allowed_diff = 0.3

    print(f"\nMax Sin Diff: {sin_diff.max():.4f}")
    print(f"Max Cos Diff: {cos_diff.max():.4f}")

    assert sin_diff.max() < max_allowed_diff, \
        f"Large jump detected in hour_sin ({sin_diff.max():.4f}) around DST. Check time source/calculation."
    assert cos_diff.max() < max_allowed_diff, \
        f"Large jump detected in hour_cos ({cos_diff.max():.4f}) around DST. Check time source/calculation."