diff --git a/gru_sac_predictor/README.md b/gru_sac_predictor/README.md
index 907675af..c78d3897 100644
--- a/gru_sac_predictor/README.md
+++ b/gru_sac_predictor/README.md
@@ -227,7 +227,7 @@ The `config.yaml` file centrally controls the pipeline's behavior. Key sections
 1. **Setup:** Install requirements (`pip install -r requirements.txt`), prepare data in the specified format/location.
 2. **Configure:** Edit `config.yaml` (data paths, feature lists, model params, control flags, walk-forward settings, calibration, validation thresholds, tuning, aggregation).
 3. **Run Pipeline:**
-    ```bash
+    ```bash  # From project root (develop/gru_sac_predictor/)
     python gru_sac_predictor/run.py --config path/to/your_config.yaml
     ```
diff --git a/gru_sac_predictor/src/backtester.py b/gru_sac_predictor/src/backtester.py
index 8283014d..f80eb5ff 100644
--- a/gru_sac_predictor/src/backtester.py
+++ b/gru_sac_predictor/src/backtester.py
@@ -77,7 +77,7 @@ class Backtester:
         self.coverage_alarm_threshold_drop = self.cal_cfg.get('coverage_alarm_threshold_drop', 0.03)
         self.coverage_alarm_window = self.cal_cfg.get('coverage_alarm_window', 1000)
         # --- End --- #
-        
+
         self.results_df: Optional[pd.DataFrame] = None
         self.metrics: Optional[Dict[str, Any]] = None
@@ -98,11 +98,11 @@ class Backtester:
         self.ece_recalibration_threshold = self.cal_cfg.get('ece_recalibration_threshold', 0.03)

     def run_backtest(
-        self, 
+        self,
         sac_agent_load_path: Optional[str],
         X_test_seq: np.ndarray,
         y_test_seq_dict: Dict[str, np.ndarray],
-        test_indices: pd.Index, 
+        test_indices: pd.Index,
         gru_handler: GRUModelHandler,
         # --- Pass Calibrator Instances & Initial Params/Threshold --- #
         calibrator: Optional[Calibrator], # For Temp Scaling
@@ -163,7 +163,7 @@ class Backtester:
             logger.error(f"Fold {fold_num}: Rolling calibration enabled for ternary case, but logits_test not provided.")
             return None, None, None
         # --- End Validation --- #
-        
+
         # 1.
Initialize SAC Agent & Load Weights # Ensure agent state dim matches the state construction below agent_state_dim = 5 # mu, sigma, edge, |mu|/sigma, position @@ -223,7 +223,7 @@ class Backtester: log_sigma_test_state = preds_test_state[1][:, 1].flatten() if preds_test_state[1].shape[-1] == 2 else np.log(preds_test_state[1].flatten() + 1e-9) sigma_test_state = np.exp(log_sigma_test_state) p_raw_fallback = p_raw_test # Use the passed raw probs if needed outside rolling - + # Extract actual returns and directions actual_ret_test = y_test_seq_dict.get('ret') dir_key = 'dir3' if is_ternary else 'dir' @@ -262,7 +262,7 @@ class Backtester: metrics_log = [] # Store periodic metrics step_correct_nonzero, step_count_nonzero = 0, 0 step_abs_actions = [] - + logger.info(f"Fold {fold_num}: Starting backtest simulation loop ({n_test} steps)...") for i in range(n_test): # --- Step i: Calibration --- @@ -332,7 +332,7 @@ class Backtester: # --- Step i: Update Position --- current_position = target_position - + # --- Step i: Update Metrics Log Data --- if abs(current_position) > 1e-6: step_count_nonzero += 1 @@ -455,7 +455,7 @@ class Backtester: logger.info(f" New optimal temperature: {new_T:.4f} (Previous: {current_optimal_T:.4f})") current_optimal_T = new_T calibrator.optimal_T = new_T # Update instance - else: + else: logger.info(f" Temperature unchanged or optimization failed (T={current_optimal_T:.4f}).") last_recalib_step = i diff --git a/gru_sac_predictor/src/trading_pipeline.py b/gru_sac_predictor/src/trading_pipeline.py index 6939a171..9bfd264b 100644 --- a/gru_sac_predictor/src/trading_pipeline.py +++ b/gru_sac_predictor/src/trading_pipeline.py @@ -769,7 +769,7 @@ class TradingPipeline: except Exception as e: logging.error(f"Manual save of final feature whitelist failed: {e}", exc_info=True) # --- End Save Update --- # - + # --- MODIFIED: Prune the SCALED data splits using the determined whitelist --- # if self.X_train_scaled is None or self.X_val_scaled is None or self.X_test_scaled is None: logging.error("Scaled data splits not available for pruning.") @@ -780,7 +780,7 @@ class TradingPipeline: self.X_val_pruned = self.feature_engineer.prune_features(self.X_val_scaled, self.final_whitelist) self.X_test_pruned = self.feature_engineer.prune_features(self.X_test_scaled, self.final_whitelist) # --- End Modification --- # - + logging.info(f"Feature shapes after pruning scaled data: Train={self.X_train_pruned.shape}, Val={self.X_val_pruned.shape}, Test={self.X_test_pruned.shape}") # Verification and empty checks remain the same, using X_*_pruned @@ -1094,116 +1094,116 @@ class TradingPipeline: patience=patience ) - if self.gru_model is None: - logging.error("GRU model training failed. Exiting.") - sys.exit(1) - else: - # Save the newly trained model - saved_path = self.gru_handler.save() # Uses run_id from handler - if saved_path: - logging.info(f"Newly trained GRU model saved to {saved_path}") + if self.gru_model is None: + logging.error("GRU model training failed. 
Exiting.") + sys.exit(1) else: - logging.warning("Failed to save the newly trained GRU model.") - # Set the loaded ID to the current run ID - self.gru_model_run_id_loaded_from = self.run_id - logging.info(f"Using GRU model trained in current run: {self.run_id}") - - # --- V3 Output Contract: Plot Learning Curve --- # - if self.io and history is not None and self.config.get('control', {}).get('generate_plots', True): - # Infer log dir path based on current models dir - log_dir = os.path.dirname(self.current_run_models_dir).replace('/models', '/logs') - csv_log_path = os.path.join(log_dir, 'gru_history.csv') - if os.path.exists(csv_log_path): - logging.info(f"Plotting learning curve from {csv_log_path}...") - try: - history_df = pd.read_csv(csv_log_path) - - # Determine metric keys (handle v2 vs v3 differences if necessary) - loss_key = 'loss' - val_loss_key = 'val_loss' - acc_key = None - val_acc_key = None - if 'dir3_accuracy' in history_df.columns: # V3 specific? - acc_key = 'dir3_accuracy' - val_acc_key = 'val_dir3_accuracy' - elif 'accuracy' in history_df.columns: # V2 or other? - acc_key = 'accuracy' - val_acc_key = 'val_accuracy' - - if acc_key is None: - logging.warning("Could not find a suitable accuracy metric in history CSV for plotting.") - n_panes = 1 # Only plot loss - else: - n_panes = 2 # Plot loss and accuracy - - # Get figure settings - fig_dpi = self.config.get('output', {}).get('figure_dpi', 150) - fig_size = self.config.get('output', {}).get('figure_size', [16, 9]) - footer_text = "© GRU-SAC v3" - - plt.style.use('seaborn-v0_8-darkgrid') - # Adjust figsize height based on panes - adjusted_fig_height = fig_size[1] * (n_panes / 3.0) # Rough scaling - fig, axes = plt.subplots(n_panes, 1, figsize=(fig_size[0], adjusted_fig_height), sharex=True) - - if n_panes == 1: - ax_loss = axes # Single axis - else: - ax_loss, ax_acc = axes # Multiple axes - - epochs = history_df['epoch'] + 1 # epochs are 0-indexed in csv - - # Pane 1: Loss (Log Scale) - ax_loss.plot(epochs, history_df[loss_key], label='Training Loss') - ax_loss.plot(epochs, history_df[val_loss_key], label='Validation Loss') - ax_loss.set_yscale('log') - ax_loss.set_ylabel('Loss (Log Scale)') - ax_loss.legend() - ax_loss.set_title('GRU Model Training Progress', fontsize=16) - ax_loss.grid(True, which="both", ls="--", linewidth=0.5) - - # Pane 2: Accuracy (if available) - if n_panes == 2: - ax_acc.plot(epochs, history_df[acc_key], label=f'Training {acc_key}') - ax_acc.plot(epochs, history_df[val_acc_key], label=f'Validation {val_acc_key}') - ax_acc.set_ylabel('Accuracy') - ax_acc.set_xlabel('Epoch') - ax_acc.legend() - ax_acc.grid(True, which="both", ls="--", linewidth=0.5) - else: - # If only loss pane, set xlabel there - ax_loss.set_xlabel('Epoch') - - # Add vertical line for early stopping epoch if available - if hasattr(history, 'epoch') and len(history.epoch) > 0: - # Early stopping epoch is the number of epochs run - early_stop_epoch = len(history.epoch) - if early_stop_epoch < max_epochs: # Only draw if early stopping occurred - for ax in fig.axes: - ax.axvline(x=early_stop_epoch, color='r', linestyle='--', linewidth=1, label=f'Early Stop @ {early_stop_epoch}') - # Add legend entry to the last plot - fig.axes[-1].legend() - - # Add footer - plt.figtext(0.99, 0.01, footer_text, horizontalalignment='right', - verticalalignment='bottom', fontsize=8, color='gray') - - plt.tight_layout(rect=[0, 0.03, 1, 0.97]) # Adjust layout - - # Save figure using IOManager - self.io.save_figure(fig, "gru_learning_curve", 
section='results') - logging.info("GRU learning curve plot saved.") - plt.close(fig) - - except FileNotFoundError: - logging.warning(f"GRU history file not found at {csv_log_path}. Cannot plot learning curve.") - except Exception as e: - logging.error(f"Failed to plot GRU learning curve: {e}", exc_info=True) + # Save the newly trained model + saved_path = self.gru_handler.save() # Uses run_id from handler + if saved_path: + logging.info(f"Newly trained GRU model saved to {saved_path}") else: - logging.warning(f"GRU history file not found at {csv_log_path}. Cannot plot learning curve.") - elif not self.io: - logging.warning("IOManager not available, skipping GRU learning curve plot.") - # --- End Plot Learning Curve --- # + logging.warning("Failed to save the newly trained GRU model.") + # Set the loaded ID to the current run ID + self.gru_model_run_id_loaded_from = self.run_id + logging.info(f"Using GRU model trained in current run: {self.run_id}") + + # --- V3 Output Contract: Plot Learning Curve --- # + if self.io and history is not None and self.config.get('control', {}).get('generate_plots', True): + # Infer log dir path based on current models dir + log_dir = os.path.dirname(self.current_run_models_dir).replace('/models', '/logs') + csv_log_path = os.path.join(log_dir, 'gru_history.csv') + if os.path.exists(csv_log_path): + logging.info(f"Plotting learning curve from {csv_log_path}...") + try: + history_df = pd.read_csv(csv_log_path) + + # Determine metric keys (handle v2 vs v3 differences if necessary) + loss_key = 'loss' + val_loss_key = 'val_loss' + acc_key = None + val_acc_key = None + if 'dir3_accuracy' in history_df.columns: # V3 specific? + acc_key = 'dir3_accuracy' + val_acc_key = 'val_dir3_accuracy' + elif 'accuracy' in history_df.columns: # V2 or other? 
+ acc_key = 'accuracy' + val_acc_key = 'val_accuracy' + + if acc_key is None: + logging.warning("Could not find a suitable accuracy metric in history CSV for plotting.") + n_panes = 1 # Only plot loss + else: + n_panes = 2 # Plot loss and accuracy + + # Get figure settings + fig_dpi = self.config.get('output', {}).get('figure_dpi', 150) + fig_size = self.config.get('output', {}).get('figure_size', [16, 9]) + footer_text = "© GRU-SAC v3" + + plt.style.use('seaborn-v0_8-darkgrid') + # Adjust figsize height based on panes + adjusted_fig_height = fig_size[1] * (n_panes / 3.0) # Rough scaling + fig, axes = plt.subplots(n_panes, 1, figsize=(fig_size[0], adjusted_fig_height), sharex=True) + + if n_panes == 1: + ax_loss = axes # Single axis + else: + ax_loss, ax_acc = axes # Multiple axes + + epochs = history_df['epoch'] + 1 # epochs are 0-indexed in csv + + # Pane 1: Loss (Log Scale) + ax_loss.plot(epochs, history_df[loss_key], label='Training Loss') + ax_loss.plot(epochs, history_df[val_loss_key], label='Validation Loss') + ax_loss.set_yscale('log') + ax_loss.set_ylabel('Loss (Log Scale)') + ax_loss.legend() + ax_loss.set_title('GRU Model Training Progress', fontsize=16) + ax_loss.grid(True, which="both", ls="--", linewidth=0.5) + + # Pane 2: Accuracy (if available) + if n_panes == 2: + ax_acc.plot(epochs, history_df[acc_key], label=f'Training {acc_key}') + ax_acc.plot(epochs, history_df[val_acc_key], label=f'Validation {val_acc_key}') + ax_acc.set_ylabel('Accuracy') + ax_acc.set_xlabel('Epoch') + ax_acc.legend() + ax_acc.grid(True, which="both", ls="--", linewidth=0.5) + else: + # If only loss pane, set xlabel there + ax_loss.set_xlabel('Epoch') + + # Add vertical line for early stopping epoch if available + if hasattr(history, 'epoch') and len(history.epoch) > 0: + # Early stopping epoch is the number of epochs run + early_stop_epoch = len(history.epoch) + if early_stop_epoch < max_epochs: # Only draw if early stopping occurred + for ax in fig.axes: + ax.axvline(x=early_stop_epoch, color='r', linestyle='--', linewidth=1, label=f'Early Stop @ {early_stop_epoch}') + # Add legend entry to the last plot + fig.axes[-1].legend() + + # Add footer + plt.figtext(0.99, 0.01, footer_text, horizontalalignment='right', + verticalalignment='bottom', fontsize=8, color='gray') + + plt.tight_layout(rect=[0, 0.03, 1, 0.97]) # Adjust layout + + # Save figure using IOManager + self.io.save_figure(fig, "gru_learning_curve", section='results') + logging.info("GRU learning curve plot saved.") + plt.close(fig) + + except FileNotFoundError: + logging.warning(f"GRU history file not found at {csv_log_path}. Cannot plot learning curve.") + except Exception as e: + logging.error(f"Failed to plot GRU learning curve: {e}", exc_info=True) + else: + logging.warning(f"GRU history file not found at {csv_log_path}. Cannot plot learning curve.") + elif not self.io: + logging.warning("IOManager not available, skipping GRU learning curve plot.") + # --- End Plot Learning Curve --- # else: # Load pre-trained GRU model load_run_id = gru_cfg.get('model_load_run_id', None) @@ -1357,8 +1357,8 @@ class TradingPipeline: except Exception as e: logger.error(f"Fold {self.current_fold}: Error during temperature calibration: {e}", exc_info=True) self.optimal_T = None - else: - logger.warning(f"Fold {self.current_fold}: Calibration method '{calibration_method}' or ternary state mismatch. 
Skipping GRU validation checks.") + else: # Covers cases where calibration method is wrong or mismatch with ternary state + logger.warning(f"Fold {self.current_fold}: Calibration method '{calibration_method}' or ternary state mismatch ({is_ternary_check}). Skipping GRU validation checks.") self.optimal_T = None self.vector_cal_params = None @@ -1368,60 +1368,74 @@ class TradingPipeline: self.optimized_edge_threshold = None # Reset for the fold if optimize_edge: - logger.info(f"Optimizing edge threshold using Youden's J on validation predictions...") - try: - # Prepare y_true for optimization (needs binary 0/1) - if y_dir_val_to_check is None: - raise ValueError("Cannot optimize edge threshold without valid y_dir_val.") - y_true_for_opt = None - p_cal_for_opt = None - if is_ternary_check: - if p_cal_val_to_check is not None: - # Convert ternary to binary: P(up) vs others - p_cal_for_opt = p_cal_val_to_check[:, -1] # P(up) - y_true_for_opt = (np.argmax(y_dir_val_to_check, axis=1) == 2).astype(int) - else: - raise ValueError("Cannot optimize ternary edge threshold without valid calibrated probabilities.") - else: - # Binary case - if p_cal_val_to_check is not None: - p_cal_for_opt = p_cal_val_to_check - y_true_for_opt = (np.asarray(y_dir_val_to_check) > 0.5).astype(int) - else: - raise ValueError("Cannot optimize binary edge threshold without valid calibrated probabilities.") - - # Perform optimization using the dedicated function from metrics - # Note: This assumes Calibrator.optimize_edge_threshold was removed or is not used here - # Ensure _calculate_optimal_edge_threshold is imported - self.optimized_edge_threshold = _calculate_optimal_edge_threshold(y_true_for_opt, p_cal_for_opt) - - if self.optimized_edge_threshold is not None: - logger.info(f"Optimized edge threshold: {self.optimized_edge_threshold:.4f}") - # Save optimized threshold - thresh_file = f"optimized_edge_threshold_fold_{self.current_fold}.txt" - try: - # Ensure fold_results_dir is defined (should be available from context) - if self.io and fold_results_dir: - self.io.save_json({'optimized_edge_threshold': self.optimized_edge_threshold}, - thresh_file.replace('.txt',''), # Use filename as key for io - base_dir=fold_results_dir, use_txt=True) - logger.info(f"Saved optimized edge threshold to {thresh_file}") - elif self.io: - self.io.save_json({'optimized_edge_threshold': self.optimized_edge_threshold}, - thresh_file.replace('.txt',''), - section='results', use_txt=True) # Fallback save - logger.info(f"Saved optimized edge threshold to main results dir: {thresh_file}") - else: - logger.warning("IOManager not available, cannot save optimized threshold.") - except Exception as e: - logger.error(f"Failed to save optimized edge threshold: {e}") - else: - logger.warning("Edge threshold optimization failed or was skipped. 
Using config default.") - self.optimized_edge_threshold = edge_thr_config # Fallback + logger.info(f"Optimizing edge threshold using Youden's J on validation predictions...") + try: + # Prepare y_true for optimization (needs binary 0/1) + if y_dir_val_to_check is None: + raise ValueError("Cannot optimize edge threshold without valid y_dir_val.") + y_true_for_opt = None + p_cal_for_opt = None + if is_ternary_check: + if p_cal_val_to_check is not None: + # Convert ternary to binary: P(up) vs others + p_cal_for_opt = p_cal_val_to_check[:, -1] # P(up) + y_true_for_opt = (np.argmax(y_dir_val_to_check, axis=1) == 2).astype(int) + else: + raise ValueError("Cannot optimize ternary edge threshold without valid calibrated probabilities.") + else: + # Binary case + if p_cal_val_to_check is not None: + p_cal_for_opt = p_cal_val_to_check + y_true_for_opt = (np.asarray(y_dir_val_to_check) > 0.5).astype(int) + else: + raise ValueError("Cannot optimize binary edge threshold without valid calibrated probabilities.") - except Exception as e: - logger.error(f"Error during edge threshold optimization: {e}", exc_info=True) - self.optimized_edge_threshold = edge_thr_config # Fallback + # Perform optimization using the dedicated function from metrics + # Note: This assumes Calibrator.optimize_edge_threshold was removed or is not used here + # Ensure _calculate_optimal_edge_threshold is imported + self.optimized_edge_threshold = _calculate_optimal_edge_threshold(y_true_for_opt, p_cal_for_opt) + + if self.optimized_edge_threshold is not None: + logger.info(f"Optimized edge threshold: {self.optimized_edge_threshold:.4f}") + # Save optimized threshold + thresh_file = f"optimized_edge_threshold_fold_{self.current_fold}.txt" + try: + # Ensure fold_results_dir is defined (should be available from context) + # Assuming fold_results_dir is defined in the outer scope + if self.io and 'fold_results_dir' in locals() and fold_results_dir: + self.io.save_json({'optimized_edge_threshold': self.optimized_edge_threshold}, + thresh_file.replace('.txt',''), # Use filename as key for io + base_dir=fold_results_dir, use_txt=True) + logger.info(f"Saved optimized edge threshold to {os.path.join(fold_results_dir, thresh_file)}") + elif self.io: + # Fallback: Save to the main results directory if fold_results_dir isn't available + # Construct the path for logging clarity + fallback_path = os.path.join(self.io.get_section_path('results'), thresh_file) + self.io.save_json({'optimized_edge_threshold': self.optimized_edge_threshold}, + thresh_file.replace('.txt',''), + section='results', use_txt=True) # Fallback save + logger.info(f"Saved optimized edge threshold to main results dir: {fallback_path}") + else: + logger.warning("IOManager not available, cannot save optimized threshold.") + except NameError: # Specifically catch if fold_results_dir is not defined + logger.warning("fold_results_dir not defined. Attempting fallback save to main results dir.") + if self.io: + fallback_path = os.path.join(self.io.get_section_path('results'), thresh_file) + self.io.save_json({'optimized_edge_threshold': self.optimized_edge_threshold}, + thresh_file.replace('.txt',''), + section='results', use_txt=True) + logger.info(f"Saved optimized edge threshold to main results dir: {fallback_path}") + else: + logger.warning("IOManager not available, cannot save optimized threshold.") + except Exception as e: + logger.error(f"Failed to save optimized edge threshold: {e}") + else: + logger.warning("Edge threshold optimization failed or was skipped. 
Using config default.") + self.optimized_edge_threshold = edge_thr_config # Fallback + + except Exception as e: + logger.error(f"Error during edge threshold optimization: {e}", exc_info=True) + self.optimized_edge_threshold = edge_thr_config # Fallback else: # If optimization is disabled, store the config threshold for consistent use self.optimized_edge_threshold = edge_thr_config @@ -1430,21 +1444,21 @@ class TradingPipeline: # --- Perform GRU Validation Checks using the threshold stored in self.optimized_edge_threshold --- # if p_cal_val_to_check is not None and y_dir_val_to_check is not None: - self._perform_gru_validation_checks( - p_cal_val=p_cal_val_to_check, - y_dir_val=y_dir_val_to_check, - is_ternary=is_ternary_check - ) + self._perform_gru_validation_checks( + p_cal_val=p_cal_val_to_check, + y_dir_val=y_dir_val_to_check, + is_ternary=is_ternary_check + ) # Note: _perform_gru_validation_checks was already modified to use self.optimized_edge_threshold else: - logger.warning("Could not perform GRU validation checks due to missing calibrated predictions or labels.") + logger.warning("Could not perform GRU validation checks due to missing calibrated predictions or labels.") # --- Helper for GRU Validation Checks (Replaces edge check) --- # def _perform_gru_validation_checks(self, p_cal_val, y_dir_val, is_ternary): """ Performs GRU validation checks: Edge-Filtered Accuracy and Brier Score. Logs results and raises SystemExit if checks fail. - + Args: p_cal_val: Calibrated probabilities on validation set. For binary: (N,) shape, P(up). @@ -1455,19 +1469,18 @@ class TradingPipeline: is_ternary (bool): Flag indicating if ternary classification is used. """ logger.info(f"--- Fold {self.current_fold}: Performing GRU Validation Checks --- ") - + # --- Define thresholds (Consider moving to config) --- # validation_criteria = self.config.get('validation_gates', {}).get('gru', {}) edge_check_thr = validation_criteria.get('edge_filtered_acc_ci_lower_threshold', 0.55) brier_check_thr = validation_criteria.get('brier_score_threshold', 0.19) min_edge_samples = validation_criteria.get('edge_filtered_min_samples', 30) # --- End Thresholds --- # - + # --- Determine Edge Threshold --- # calib_config = self.config.get('calibration', {}) optimize_edge = calib_config.get('optimize_edge_threshold', False) edge_thr_config = calib_config.get('edge_threshold', 0.1) # Default/fallback - self.fold_edge_threshold = edge_thr_config # Initialize with config value if optimize_edge and not is_ternary: @@ -1483,7 +1496,7 @@ class TradingPipeline: self.io.save_json({'optimized_edge_threshold': self.fold_edge_threshold}, f'optimized_edge_threshold_fold_{self.current_fold}', base_dir=fold_results_dir, use_txt=True) - else: + else: logger.warning(f"Fold {self.current_fold}: Could not save optimized edge threshold, results dir missing.") elif optimize_edge and is_ternary: logger.warning(f"Fold {self.current_fold}: Edge threshold optimization requested but not supported for ternary. Using config value: {edge_thr_config:.4f}") @@ -1527,13 +1540,13 @@ class TradingPipeline: passed_edge_acc = ci_lower >= edge_check_thr logger.info(f"Fold {self.current_fold}: Edge Acc Check (edge >= {self.fold_edge_threshold:.2f}): Acc={edge_accuracy:.3f} ({k_correct}/{n_filtered}), 95% CI Lower={ci_lower:.3f} >= {edge_check_thr} -> {'Pass' if passed_edge_acc else 'FAIL'}") except ValueError as binom_err: - logger.error(f"Fold {self.current_fold}: Edge Acc Check: Error calculating binomial test (k={k_correct}, n={n_filtered}): {binom_err}. 
Check considered FAIL.") - passed_edge_acc = False # Consider error as failure + logger.error(f"Fold {self.current_fold}: Edge Acc Check: Error calculating binomial test (k={k_correct}, n={n_filtered}): {binom_err}. Check considered FAIL.") + passed_edge_acc = False # Consider error as failure else: - logger.error(f"Fold {self.current_fold}: Edge Acc Check: Calculation failed (NaN). Check considered FAIL.") - if n_filtered == 0: - logger.error(f" Reason: No validation samples met the edge threshold >= {self.fold_edge_threshold:.2f}") - passed_edge_acc = False # Consider NaN or 0 samples as failure + logger.error(f"Fold {self.current_fold}: Edge Acc Check: Calculation failed (NaN). Check considered FAIL.") + if n_filtered == 0: + logger.error(f" Reason: No validation samples met the edge threshold >= {self.fold_edge_threshold:.2f}") + passed_edge_acc = False # Consider NaN or 0 samples as failure except Exception as e: logger.error(f"Fold {self.current_fold}: Edge Acc Check: Unexpected error during calculation: {e}. Check considered FAIL.", exc_info=True) passed_edge_acc = False # Consider error as failure @@ -1566,9 +1579,9 @@ class TradingPipeline: error_msg = f"FOLD {self.current_fold} GRU VALIDATION FAILED: Edge Acc Pass={passed_edge_acc} (Req CI>={edge_check_thr}), Brier Pass={passed_brier} (Req Score<={brier_check_thr}). Aborting fold." logger.error(error_msg) # Use sys.exit with a specific message for clarity - sys.exit(f"Fold {self.current_fold}: GRU validation gates failed (Edge Acc / Brier Score).") - else: - logger.info(f"Fold {self.current_fold}: GRU validation checks passed (Edge Acc & Brier Score).") + sys.exit(f"Fold {self.current_fold}: GRU validation gates failed (Edge Acc / Brier Score).") + else: # Corrected indentation + logger.info(f"Fold {self.current_fold}: GRU validation checks passed (Edge Acc & Brier Score).") # Corrected indentation # --- End Validation Helper --- # def train_or_load_sac(self): @@ -1628,7 +1641,7 @@ class TradingPipeline: # # --- Revision 1: Handle Rolling Calibrator Conflict --- # # ... (block removed) ... 
# # --- End Revision 1 --- # - + # Start the training process final_agent_path = self.sac_trainer.train(gru_run_id_for_sac=self.gru_model_run_id_loaded_from) @@ -1791,6 +1804,7 @@ class TradingPipeline: # Get raw predictions needed for rolling calibration p_raw_test_for_bt = None logits_test_for_bt = None + is_ternary = self.config.get('gru', {}).get('use_ternary_output', False) # Need to know if ternary if self.config.get('calibration', {}).get('rolling_enabled', False): logger.info(f"Fold {self.current_fold}: Getting raw GRU outputs for rolling calibration...") if is_ternary: @@ -1798,43 +1812,43 @@ class TradingPipeline: if logits_test_for_bt is None: logger.error(f"Fold {self.current_fold}: Failed to get GRU logits for rolling calibration.") raise SystemExit(f"Fold {self.current_fold}: Failed GRU logit prediction.") - else: + else: # Corrected indentation preds_test_raw = self.gru_handler.predict(self.X_test_seq) if preds_test_raw is None or len(preds_test_raw) < 3: logger.error(f"Fold {self.current_fold}: Failed to get GRU raw predictions for rolling calibration.") raise SystemExit(f"Fold {self.current_fold}: Failed GRU raw prediction.") - p_raw_test_for_bt = preds_test_raw[2].flatten() + p_raw_test_for_bt = preds_test_raw[2].flatten() # Assuming index 2 is probabilities # Get the edge threshold determined during validation (optimized or fixed) edge_threshold_for_bt = getattr(self, 'fold_edge_threshold', self.config.get('calibration', {}).get('edge_threshold', 0.1)) logger.info(f"Fold {self.current_fold}: Using edge threshold {edge_threshold_for_bt:.4f} for backtest execution.") - try: + try: # Corrected indentation self.backtest_results_df, self.backtest_metrics, self.metrics_log_df = self.backtester.run_backtest( sac_agent_load_path=self.sac_agent_load_path, X_test_seq=self.X_test_seq, y_test_seq_dict=self.y_test_seq_dict, test_indices=self.test_indices, gru_handler=self.gru_handler, - # --- Pass Calibrator instances and initial state --- # - calibrator=calibrator_instance, - vector_calibrator=vector_calibrator_instance, - initial_optimal_T=getattr(self, 'optimal_T', None), # Pass T if exists - initial_vector_params=getattr(self, 'vector_cal_params', None), # Pass params if exists - fold_edge_threshold=edge_threshold_for_bt, - # --- Pass raw predictions if needed for rolling cal --- # - p_raw_test=p_raw_test_for_bt, - logits_test=logits_test_for_bt, - # --- Pass original prices --- # - original_prices=self.df_test_original, # Pass the DataFrame - is_ternary=self.use_ternary, - fold_num=self.current_fold - ) - except SystemExit as e: + # --- Pass Calibrator instances and initial state --- # + calibrator=calibrator_instance, + vector_calibrator=vector_calibrator_instance, + initial_optimal_T=getattr(self, 'optimal_T', None), # Pass T if exists + initial_vector_params=getattr(self, 'vector_cal_params', None), # Pass params if exists + fold_edge_threshold=edge_threshold_for_bt, + # --- Pass raw predictions if needed for rolling cal --- # + p_raw_test=p_raw_test_for_bt, + logits_test=logits_test_for_bt, + # --- Pass original prices --- # + original_prices=self.df_test_original, # Pass the DataFrame + is_ternary=self.use_ternary, + fold_num=self.current_fold + ) + except SystemExit as e: # Corrected indentation # Catch exits from backtester validation/execution logger.error(f"Fold {self.current_fold}: Backtest aborted: {e}") raise # Re-raise to stop the fold - except Exception as e: + except Exception as e: # Corrected indentation logger.error(f"Fold {self.current_fold}: Unhandled error 
during backtester.run_backtest: {e}", exc_info=True) # Treat as failure, ensure metrics are None self.backtest_results_df = None @@ -2027,7 +2041,7 @@ class TradingPipeline: def execute(self): """Runs the full trading pipeline end-to-end.""" logger.info(f"--- Starting Trading Pipeline: Run ID {self.run_id} ---") - + # 1. Load and Preprocess Data self.load_and_preprocess_data() if self.data_processed is None: # Check if data loading failed @@ -2073,7 +2087,7 @@ class TradingPipeline: else: # Perform edge accuracy check only if calibration happened and model exists self._perform_gru_validation_checks( - p_cal_val=self.p_cal_val, + p_cal_val=self.p_cal_val, y_dir_val=self.y_dir_val, is_ternary=self.use_ternary ) @@ -2092,7 +2106,7 @@ class TradingPipeline: # --- Walk-Forward Fold Generation --- # def _generate_walk_forward_folds(self) -> Iterator[Tuple[pd.Timestamp, pd.Timestamp, pd.Timestamp, pd.Timestamp, pd.Timestamp, pd.Timestamp]]: """ - Generates start and end timestamps for train, validation, and test sets + Generates start and end timestamps for train, validation, and test sets for each walk-forward fold based on config settings. Requires self.df_raw to be loaded first to determine the full date range. """ @@ -2741,8 +2755,8 @@ class TradingPipeline: except Exception as e: logger.error(f"Error during SAC agent weight averaging or saving: {e}", exc_info=True) - - # --- Entry Point --- # + +# --- Entry Point --- # if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run the GRU-SAC Trading Pipeline.")
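
For reference on the edge-threshold optimisation this patch routes through `_calculate_optimal_edge_threshold`, here is a minimal standalone sketch of selecting a cut-off by Youden's J on calibrated validation probabilities. It assumes the binary case and that the pipeline's "edge" is the distance of calibrated P(up) from 0.5; `youden_j_edge_threshold` is a hypothetical helper, not the project's actual function.

```python
import numpy as np
from sklearn.metrics import roc_curve

def youden_j_edge_threshold(y_true: np.ndarray, p_cal: np.ndarray) -> float:
    """Pick the probability cut-off maximising Youden's J = TPR - FPR and
    return it as an edge magnitude |p - 0.5| (assumed edge definition)."""
    fpr, tpr, thresholds = roc_curve(y_true, p_cal)
    j = tpr - fpr
    best_cutoff = thresholds[int(np.argmax(j))]
    return float(abs(best_cutoff - 0.5))

# Synthetic usage: labels loosely follow the calibrated probabilities.
rng = np.random.default_rng(0)
p_cal = rng.uniform(0.0, 1.0, 2000)
y_true = (rng.uniform(0.0, 1.0, 2000) < p_cal).astype(int)
print(f"optimised edge threshold: {youden_j_edge_threshold(y_true, p_cal):.4f}")
```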
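The backtester's rolling calibration re-fits the temperature when calibration drifts (the patch reads `ece_recalibration_threshold`, default 0.03, and updates `calibrator.optimal_T`). A minimal sketch of that idea for the binary/logit case follows; the function names and window handling are illustrative assumptions, not the repo's `Calibrator` API.

```python
import numpy as np
from scipy.optimize import minimize_scalar

def expected_calibration_error(p: np.ndarray, y: np.ndarray, n_bins: int = 10) -> float:
    """Binned ECE for binary probabilities: weighted |accuracy - confidence| per bin."""
    bins = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(bins[:-1], bins[1:]):
        mask = (p >= lo) & ((p < hi) if hi < 1.0 else (p <= hi))
        if mask.any():
            ece += mask.mean() * abs(y[mask].mean() - p[mask].mean())
    return float(ece)

def fit_temperature(logits: np.ndarray, y: np.ndarray) -> float:
    """Find T > 0 minimising binary NLL of sigmoid(logits / T)."""
    def nll(t: float) -> float:
        p = 1.0 / (1.0 + np.exp(-logits / t))
        p = np.clip(p, 1e-7, 1.0 - 1e-7)
        return float(-np.mean(y * np.log(p) + (1 - y) * np.log(1 - p)))
    res = minimize_scalar(nll, bounds=(0.05, 10.0), method="bounded")
    return float(res.x)

def maybe_recalibrate(logits_window, y_window, current_t: float,
                      ece_threshold: float = 0.03) -> float:
    """Re-fit T only when the rolling window's ECE drifts past the threshold."""
    logits = np.asarray(logits_window, dtype=float)
    y = np.asarray(y_window, dtype=float)
    p = 1.0 / (1.0 + np.exp(-logits / current_t))
    if expected_calibration_error(p, y) > ece_threshold:
        return fit_temperature(logits, y)
    return current_t
```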
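The `_perform_gru_validation_checks` gate combines an edge-filtered accuracy test (the lower bound of a 95% binomial CI must clear `edge_filtered_acc_ci_lower_threshold`, default 0.55, with at least `edge_filtered_min_samples` = 30 qualifying samples) with a Brier-score ceiling (`brier_score_threshold`, default 0.19). A self-contained binary sketch of both gates, again assuming edge = |p_cal − 0.5|:

```python
import numpy as np
from scipy.stats import binomtest
from sklearn.metrics import brier_score_loss

def gru_validation_gate(p_cal, y_true, edge_threshold: float = 0.1,
                        ci_lower_req: float = 0.55, brier_req: float = 0.19,
                        min_samples: int = 30):
    """Return (passed, details) for the two gates described in the patch."""
    p_cal = np.asarray(p_cal, dtype=float)
    y_true = np.asarray(y_true, dtype=int)

    # Gate 1: accuracy on samples whose edge |p - 0.5| clears the threshold,
    # judged by the lower bound of a 95% Wilson interval.
    mask = np.abs(p_cal - 0.5) >= edge_threshold
    n_filtered = int(mask.sum())
    if n_filtered < min_samples:
        ci_lower = float("nan")
        passed_edge = False  # too few filtered samples counts as a failure
    else:
        preds = (p_cal[mask] > 0.5).astype(int)
        k_correct = int((preds == y_true[mask]).sum())
        ci = binomtest(k_correct, n_filtered).proportion_ci(
            confidence_level=0.95, method="wilson")
        ci_lower = float(ci.low)
        passed_edge = ci_lower >= ci_lower_req

    # Gate 2: overall probabilistic quality via the Brier score.
    brier = float(brier_score_loss(y_true, p_cal))
    passed_brier = brier <= brier_req

    return passed_edge and passed_brier, {
        "ci_lower": ci_lower, "brier": brier, "n_filtered": n_filtered}
```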
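`_generate_walk_forward_folds` is documented as yielding train/validation/test start and end timestamps per fold from config settings. Below is a sketch of one common way to do that with fixed-length, forward-stepping windows; the day-count parameters are hypothetical, since the actual config keys are not shown in this patch.

```python
from typing import Iterator, Tuple
import pandas as pd

def walk_forward_folds(start: str, end: str, train_days: int, val_days: int,
                       test_days: int, step_days: int
                       ) -> Iterator[Tuple[pd.Timestamp, ...]]:
    """Yield (train_start, train_end, val_start, val_end, test_start, test_end)
    for consecutive fixed-length walk-forward folds."""
    cursor, last = pd.Timestamp(start), pd.Timestamp(end)
    while True:
        train_start = cursor
        train_end = train_start + pd.Timedelta(days=train_days)
        val_end = train_end + pd.Timedelta(days=val_days)
        test_end = val_end + pd.Timedelta(days=test_days)
        if test_end > last:
            break
        yield train_start, train_end, train_end, val_end, val_end, test_end
        cursor += pd.Timedelta(days=step_days)

# Example: roughly six-month training, one-month validation/test, stepping monthly.
for fold in walk_forward_folds("2020-01-01", "2021-06-30", 180, 30, 30, 30):
    print(fold)
```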