indentation fix pipeline

yasha 2025-04-18 18:18:52 +00:00
parent a02dd4342e
commit a86cdb2c8f
3 changed files with 233 additions and 219 deletions


@@ -227,7 +227,7 @@ The `config.yaml` file centrally controls the pipeline's behavior. Key sections
1. **Setup:** Install requirements (`pip install -r requirements.txt`), prepare data in the specified format/location.
2. **Configure:** Edit `config.yaml` (data paths, feature lists, model params, control flags, walk-forward settings, calibration, validation thresholds, tuning, aggregation); a reading sketch follows below.
3. **Run Pipeline:**
   ```bash
   # From project root (develop/gru_sac_predictor/)
   python gru_sac_predictor/run.py --config path/to/your_config.yaml
   ```
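The `config.get(...)` calls in the diffs below reveal several of these sections. For orientation, a minimal sketch of reading them; the key names are taken from this commit's code, but the overall shape is an assumption, not a complete schema:

```python
# Hedged sketch: config sections inferred from config.get(...) calls in this commit.
import yaml

with open("path/to/your_config.yaml") as fh:
    config = yaml.safe_load(fh)

generate_plots = config.get('control', {}).get('generate_plots', True)
figure_dpi = config.get('output', {}).get('figure_dpi', 150)
figure_size = config.get('output', {}).get('figure_size', [16, 9])
calibration = config.get('calibration', {})
edge_threshold = calibration.get('edge_threshold', 0.1)            # default/fallback edge
optimize_edge = calibration.get('optimize_edge_threshold', False)  # Youden's J search on/off
rolling_enabled = calibration.get('rolling_enabled', False)        # rolling recalibration in backtest
use_ternary = config.get('gru', {}).get('use_ternary_output', False)
print(generate_plots, figure_dpi, figure_size, edge_threshold, optimize_edge, rolling_enabled, use_ternary)
```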


@@ -455,7 +455,7 @@ class Backtester:
                logger.info(f"  New optimal temperature: {new_T:.4f} (Previous: {current_optimal_T:.4f})")
                current_optimal_T = new_T
                calibrator.optimal_T = new_T  # Update instance
            else:
                logger.info(f"  Temperature unchanged or optimization failed (T={current_optimal_T:.4f}).")
            last_recalib_step = i
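For context on `calibrator.optimal_T`: the Backtester periodically refits a temperature on recent data. A minimal sketch of standard temperature scaling, assuming the usual formulation (scalar T minimising the NLL of softmax(logits / T)); this is an illustration, not this repo's Calibrator API:

```python
# Hedged sketch of temperature scaling (not the repo's Calibrator implementation).
import numpy as np
from scipy.optimize import minimize_scalar

def fit_temperature(logits: np.ndarray, y_onehot: np.ndarray) -> float:
    """Fit a scalar temperature T by minimising NLL of softmax(logits / T)."""
    def nll(T: float) -> float:
        z = logits / T
        z = z - z.max(axis=1, keepdims=True)  # numerical stability
        log_p = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
        return float(-(y_onehot * log_p).sum(axis=1).mean())
    # Bounded 1-D search; the bounds here are an assumption, not a repo setting.
    return float(minimize_scalar(nll, bounds=(0.05, 10.0), method='bounded').x)
```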


@@ -1094,116 +1094,116 @@ class TradingPipeline:
                patience=patience
            )
            if self.gru_model is None:
                logging.error("GRU model training failed. Exiting.")
                sys.exit(1)
            else:
                # Save the newly trained model
                saved_path = self.gru_handler.save()  # Uses run_id from handler
                if saved_path:
                    logging.info(f"Newly trained GRU model saved to {saved_path}")
                else:
                    logging.warning("Failed to save the newly trained GRU model.")
                # Set the loaded ID to the current run ID
                self.gru_model_run_id_loaded_from = self.run_id
                logging.info(f"Using GRU model trained in current run: {self.run_id}")
                # --- V3 Output Contract: Plot Learning Curve --- #
                if self.io and history is not None and self.config.get('control', {}).get('generate_plots', True):
                    # Infer log dir path based on current models dir
                    log_dir = os.path.dirname(self.current_run_models_dir).replace('/models', '/logs')
                    csv_log_path = os.path.join(log_dir, 'gru_history.csv')
                    if os.path.exists(csv_log_path):
                        logging.info(f"Plotting learning curve from {csv_log_path}...")
                        try:
                            history_df = pd.read_csv(csv_log_path)
                            # Determine metric keys (handle v2 vs v3 differences if necessary)
                            loss_key = 'loss'
                            val_loss_key = 'val_loss'
                            acc_key = None
                            val_acc_key = None
                            if 'dir3_accuracy' in history_df.columns:  # V3 specific?
                                acc_key = 'dir3_accuracy'
                                val_acc_key = 'val_dir3_accuracy'
                            elif 'accuracy' in history_df.columns:  # V2 or other?
                                acc_key = 'accuracy'
                                val_acc_key = 'val_accuracy'
                            if acc_key is None:
                                logging.warning("Could not find a suitable accuracy metric in history CSV for plotting.")
                                n_panes = 1  # Only plot loss
                            else:
                                n_panes = 2  # Plot loss and accuracy
                            # Get figure settings
                            fig_dpi = self.config.get('output', {}).get('figure_dpi', 150)
                            fig_size = self.config.get('output', {}).get('figure_size', [16, 9])
                            footer_text = "© GRU-SAC v3"
                            plt.style.use('seaborn-v0_8-darkgrid')
                            # Adjust figsize height based on panes
                            adjusted_fig_height = fig_size[1] * (n_panes / 3.0)  # Rough scaling
                            fig, axes = plt.subplots(n_panes, 1, figsize=(fig_size[0], adjusted_fig_height), sharex=True)
                            if n_panes == 1:
                                ax_loss = axes  # Single axis
                            else:
                                ax_loss, ax_acc = axes  # Multiple axes
                            epochs = history_df['epoch'] + 1  # epochs are 0-indexed in csv
                            # Pane 1: Loss (Log Scale)
                            ax_loss.plot(epochs, history_df[loss_key], label='Training Loss')
                            ax_loss.plot(epochs, history_df[val_loss_key], label='Validation Loss')
                            ax_loss.set_yscale('log')
                            ax_loss.set_ylabel('Loss (Log Scale)')
                            ax_loss.legend()
                            ax_loss.set_title('GRU Model Training Progress', fontsize=16)
                            ax_loss.grid(True, which="both", ls="--", linewidth=0.5)
                            # Pane 2: Accuracy (if available)
                            if n_panes == 2:
                                ax_acc.plot(epochs, history_df[acc_key], label=f'Training {acc_key}')
                                ax_acc.plot(epochs, history_df[val_acc_key], label=f'Validation {val_acc_key}')
                                ax_acc.set_ylabel('Accuracy')
                                ax_acc.set_xlabel('Epoch')
                                ax_acc.legend()
                                ax_acc.grid(True, which="both", ls="--", linewidth=0.5)
                            else:
                                # If only loss pane, set xlabel there
                                ax_loss.set_xlabel('Epoch')
                            # Add vertical line for early stopping epoch if available
                            if hasattr(history, 'epoch') and len(history.epoch) > 0:
                                # Early stopping epoch is the number of epochs run
                                early_stop_epoch = len(history.epoch)
                                if early_stop_epoch < max_epochs:  # Only draw if early stopping occurred
                                    for ax in fig.axes:
                                        ax.axvline(x=early_stop_epoch, color='r', linestyle='--', linewidth=1, label=f'Early Stop @ {early_stop_epoch}')
                                    # Add legend entry to the last plot
                                    fig.axes[-1].legend()
                            # Add footer
                            plt.figtext(0.99, 0.01, footer_text, horizontalalignment='right',
                                        verticalalignment='bottom', fontsize=8, color='gray')
                            plt.tight_layout(rect=[0, 0.03, 1, 0.97])  # Adjust layout
                            # Save figure using IOManager
                            self.io.save_figure(fig, "gru_learning_curve", section='results')
                            logging.info("GRU learning curve plot saved.")
                            plt.close(fig)
                        except FileNotFoundError:
                            logging.warning(f"GRU history file not found at {csv_log_path}. Cannot plot learning curve.")
                        except Exception as e:
                            logging.error(f"Failed to plot GRU learning curve: {e}", exc_info=True)
                    else:
                        logging.warning(f"GRU history file not found at {csv_log_path}. Cannot plot learning curve.")
                elif not self.io:
                    logging.warning("IOManager not available, skipping GRU learning curve plot.")
                # --- End Plot Learning Curve --- #
        else:  # Load pre-trained GRU model
            load_run_id = gru_cfg.get('model_load_run_id', None)
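The plotting block above expects `gru_history.csv` to carry a 0-indexed `epoch` column plus loss/accuracy metrics, and it uses `len(history.epoch)` to place the early-stop marker. A minimal sketch of how such a file is typically produced with Keras callbacks; the toy data and model below are placeholders, since the real training loop lives in `gru_handler`:

```python
# Hedged sketch: producing a gru_history.csv via Keras callbacks (toy stand-ins).
import numpy as np
import tensorflow as tf

X_train, y_train = np.random.rand(64, 10, 4), np.random.randint(0, 2, 64)
X_val, y_val = np.random.rand(16, 10, 4), np.random.randint(0, 2, 16)

model = tf.keras.Sequential([
    tf.keras.Input(shape=(10, 4)),
    tf.keras.layers.GRU(8),
    tf.keras.layers.Dense(1, activation='sigmoid'),
])
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

history = model.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
    epochs=20,
    callbacks=[
        tf.keras.callbacks.CSVLogger('gru_history.csv'),  # rows: epoch, accuracy, loss, val_...
        tf.keras.callbacks.EarlyStopping(patience=3, restore_best_weights=True),
    ],
    verbose=0,
)
print(len(history.epoch))  # epochs actually run -- what the plot uses for the early-stop line
```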
@@ -1357,8 +1357,8 @@ class TradingPipeline:
            except Exception as e:
                logger.error(f"Fold {self.current_fold}: Error during temperature calibration: {e}", exc_info=True)
                self.optimal_T = None
        else:  # Covers cases where calibration method is wrong or mismatch with ternary state
            logger.warning(f"Fold {self.current_fold}: Calibration method '{calibration_method}' or ternary state mismatch ({is_ternary_check}). Skipping GRU validation checks.")
            self.optimal_T = None
            self.vector_cal_params = None
@@ -1368,60 +1368,74 @@ class TradingPipeline:
        self.optimized_edge_threshold = None  # Reset for the fold
        if optimize_edge:
            logger.info(f"Optimizing edge threshold using Youden's J on validation predictions...")
            try:
                # Prepare y_true for optimization (needs binary 0/1)
                if y_dir_val_to_check is None:
                    raise ValueError("Cannot optimize edge threshold without valid y_dir_val.")
                y_true_for_opt = None
                p_cal_for_opt = None
                if is_ternary_check:
                    if p_cal_val_to_check is not None:
                        # Convert ternary to binary: P(up) vs others
                        p_cal_for_opt = p_cal_val_to_check[:, -1]  # P(up)
                        y_true_for_opt = (np.argmax(y_dir_val_to_check, axis=1) == 2).astype(int)
                    else:
                        raise ValueError("Cannot optimize ternary edge threshold without valid calibrated probabilities.")
                else:
                    # Binary case
                    if p_cal_val_to_check is not None:
                        p_cal_for_opt = p_cal_val_to_check
                        y_true_for_opt = (np.asarray(y_dir_val_to_check) > 0.5).astype(int)
                    else:
                        raise ValueError("Cannot optimize binary edge threshold without valid calibrated probabilities.")
                # Perform optimization using the dedicated function from metrics
                # Note: This assumes Calibrator.optimize_edge_threshold was removed or is not used here
                # Ensure _calculate_optimal_edge_threshold is imported
                self.optimized_edge_threshold = _calculate_optimal_edge_threshold(y_true_for_opt, p_cal_for_opt)
                if self.optimized_edge_threshold is not None:
                    logger.info(f"Optimized edge threshold: {self.optimized_edge_threshold:.4f}")
                    # Save optimized threshold
                    thresh_file = f"optimized_edge_threshold_fold_{self.current_fold}.txt"
                    try:
                        # Ensure fold_results_dir is defined (should be available from context)
                        # Assuming fold_results_dir is defined in the outer scope
                        if self.io and 'fold_results_dir' in locals() and fold_results_dir:
                            self.io.save_json({'optimized_edge_threshold': self.optimized_edge_threshold},
                                              thresh_file.replace('.txt',''),  # Use filename as key for io
                                              base_dir=fold_results_dir, use_txt=True)
                            logger.info(f"Saved optimized edge threshold to {os.path.join(fold_results_dir, thresh_file)}")
                        elif self.io:
                            # Fallback: Save to the main results directory if fold_results_dir isn't available
                            # Construct the path for logging clarity
                            fallback_path = os.path.join(self.io.get_section_path('results'), thresh_file)
                            self.io.save_json({'optimized_edge_threshold': self.optimized_edge_threshold},
                                              thresh_file.replace('.txt',''),
                                              section='results', use_txt=True)  # Fallback save
                            logger.info(f"Saved optimized edge threshold to main results dir: {fallback_path}")
                        else:
                            logger.warning("IOManager not available, cannot save optimized threshold.")
                    except NameError:  # Specifically catch if fold_results_dir is not defined
                        logger.warning("fold_results_dir not defined. Attempting fallback save to main results dir.")
                        if self.io:
                            fallback_path = os.path.join(self.io.get_section_path('results'), thresh_file)
                            self.io.save_json({'optimized_edge_threshold': self.optimized_edge_threshold},
                                              thresh_file.replace('.txt',''),
                                              section='results', use_txt=True)
                            logger.info(f"Saved optimized edge threshold to main results dir: {fallback_path}")
                        else:
                            logger.warning("IOManager not available, cannot save optimized threshold.")
                    except Exception as e:
                        logger.error(f"Failed to save optimized edge threshold: {e}")
                else:
                    logger.warning("Edge threshold optimization failed or was skipped. Using config default.")
                    self.optimized_edge_threshold = edge_thr_config  # Fallback
            except Exception as e:
                logger.error(f"Error during edge threshold optimization: {e}", exc_info=True)
                self.optimized_edge_threshold = edge_thr_config  # Fallback
        else:
            # If optimization is disabled, store the config threshold for consistent use
            self.optimized_edge_threshold = edge_thr_config
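For reference, a minimal sketch of what a Youden's J search like `_calculate_optimal_edge_threshold` typically does: pick the probability cut maximising TPR - FPR on the ROC curve. Converting the cut `p*` to an edge as |2p* - 1| is an assumption about this repo's convention, not confirmed by the diff:

```python
# Hedged sketch of a Youden's J threshold search (not the repo's exact helper).
from typing import Optional
import numpy as np
from sklearn.metrics import roc_curve

def optimal_edge_threshold(y_true: np.ndarray, p_cal: np.ndarray) -> Optional[float]:
    """Maximise Youden's J = TPR - FPR, then express the cut as an edge |2p - 1|."""
    if len(np.unique(y_true)) < 2:
        return None  # ROC is undefined with a single class
    fpr, tpr, thresholds = roc_curve(y_true, p_cal)
    j = tpr - fpr
    best = np.argmax(j[1:]) + 1  # skip the sentinel threshold at index 0
    p_star = thresholds[best]
    return float(abs(2.0 * p_star - 1.0))
```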
@@ -1430,14 +1444,14 @@ class TradingPipeline:
        # --- Perform GRU Validation Checks using the threshold stored in self.optimized_edge_threshold --- #
        if p_cal_val_to_check is not None and y_dir_val_to_check is not None:
            self._perform_gru_validation_checks(
                p_cal_val=p_cal_val_to_check,
                y_dir_val=y_dir_val_to_check,
                is_ternary=is_ternary_check
            )
            # Note: _perform_gru_validation_checks was already modified to use self.optimized_edge_threshold
        else:
            logger.warning("Could not perform GRU validation checks due to missing calibrated predictions or labels.")

    # --- Helper for GRU Validation Checks (Replaces edge check) --- #
    def _perform_gru_validation_checks(self, p_cal_val, y_dir_val, is_ternary):
@@ -1467,7 +1481,6 @@ class TradingPipeline:
        calib_config = self.config.get('calibration', {})
        optimize_edge = calib_config.get('optimize_edge_threshold', False)
        edge_thr_config = calib_config.get('edge_threshold', 0.1)  # Default/fallback
        self.fold_edge_threshold = edge_thr_config  # Initialize with config value
        if optimize_edge and not is_ternary:
@@ -1483,7 +1496,7 @@ class TradingPipeline:
                self.io.save_json({'optimized_edge_threshold': self.fold_edge_threshold},
                                  f'optimized_edge_threshold_fold_{self.current_fold}',
                                  base_dir=fold_results_dir, use_txt=True)
            else:
                logger.warning(f"Fold {self.current_fold}: Could not save optimized edge threshold, results dir missing.")
        elif optimize_edge and is_ternary:
            logger.warning(f"Fold {self.current_fold}: Edge threshold optimization requested but not supported for ternary. Using config value: {edge_thr_config:.4f}")
@@ -1527,13 +1540,13 @@ class TradingPipeline:
                    passed_edge_acc = ci_lower >= edge_check_thr
                    logger.info(f"Fold {self.current_fold}: Edge Acc Check (edge >= {self.fold_edge_threshold:.2f}): Acc={edge_accuracy:.3f} ({k_correct}/{n_filtered}), 95% CI Lower={ci_lower:.3f} >= {edge_check_thr} -> {'Pass' if passed_edge_acc else 'FAIL'}")
                except ValueError as binom_err:
                    logger.error(f"Fold {self.current_fold}: Edge Acc Check: Error calculating binomial test (k={k_correct}, n={n_filtered}): {binom_err}. Check considered FAIL.")
                    passed_edge_acc = False  # Consider error as failure
            else:
                logger.error(f"Fold {self.current_fold}: Edge Acc Check: Calculation failed (NaN). Check considered FAIL.")
                if n_filtered == 0:
                    logger.error(f"  Reason: No validation samples met the edge threshold >= {self.fold_edge_threshold:.2f}")
                passed_edge_acc = False  # Consider NaN or 0 samples as failure
        except Exception as e:
            logger.error(f"Fold {self.current_fold}: Edge Acc Check: Unexpected error during calculation: {e}. Check considered FAIL.", exc_info=True)
            passed_edge_acc = False  # Consider error as failure
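The gate above compares a 95% CI lower bound on edge-filtered accuracy against `edge_check_thr`. A minimal sketch of that kind of binomial check, using SciPy's `binomtest(...).proportion_ci` (an illustration of the statistic, not the repo's exact code):

```python
# Hedged sketch of an edge-accuracy gate via a binomial confidence interval.
from scipy.stats import binomtest

def edge_accuracy_gate(k_correct: int, n_filtered: int, edge_check_thr: float = 0.55) -> bool:
    """Pass iff the 95% CI lower bound on accuracy clears the threshold."""
    if n_filtered == 0:
        return False  # no samples survived the edge filter
    ci_lower = binomtest(k_correct, n_filtered).proportion_ci(confidence_level=0.95).low
    return ci_lower >= edge_check_thr

print(edge_accuracy_gate(60, 100))  # 60/100 correct -> CI lower bound ~0.497 -> False
```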
@@ -1567,8 +1580,8 @@ class TradingPipeline:
            logger.error(error_msg)
            # Use sys.exit with a specific message for clarity
            sys.exit(f"Fold {self.current_fold}: GRU validation gates failed (Edge Acc / Brier Score).")
        else:  # Corrected indentation
            logger.info(f"Fold {self.current_fold}: GRU validation checks passed (Edge Acc & Brier Score).")  # Corrected indentation
    # --- End Validation Helper --- #

    def train_or_load_sac(self):
@@ -1791,6 +1804,7 @@ class TradingPipeline:
        # Get raw predictions needed for rolling calibration
        p_raw_test_for_bt = None
        logits_test_for_bt = None
        is_ternary = self.config.get('gru', {}).get('use_ternary_output', False)  # Need to know if ternary
        if self.config.get('calibration', {}).get('rolling_enabled', False):
            logger.info(f"Fold {self.current_fold}: Getting raw GRU outputs for rolling calibration...")
            if is_ternary:
@@ -1798,43 +1812,43 @@ class TradingPipeline:
                if logits_test_for_bt is None:
                    logger.error(f"Fold {self.current_fold}: Failed to get GRU logits for rolling calibration.")
                    raise SystemExit(f"Fold {self.current_fold}: Failed GRU logit prediction.")
            else:  # Corrected indentation
                preds_test_raw = self.gru_handler.predict(self.X_test_seq)
                if preds_test_raw is None or len(preds_test_raw) < 3:
                    logger.error(f"Fold {self.current_fold}: Failed to get GRU raw predictions for rolling calibration.")
                    raise SystemExit(f"Fold {self.current_fold}: Failed GRU raw prediction.")
                p_raw_test_for_bt = preds_test_raw[2].flatten()  # Assuming index 2 is probabilities
        # Get the edge threshold determined during validation (optimized or fixed)
        edge_threshold_for_bt = getattr(self, 'fold_edge_threshold', self.config.get('calibration', {}).get('edge_threshold', 0.1))
        logger.info(f"Fold {self.current_fold}: Using edge threshold {edge_threshold_for_bt:.4f} for backtest execution.")
        try:  # Corrected indentation
            self.backtest_results_df, self.backtest_metrics, self.metrics_log_df = self.backtester.run_backtest(
                sac_agent_load_path=self.sac_agent_load_path,
                X_test_seq=self.X_test_seq,
                y_test_seq_dict=self.y_test_seq_dict,
                test_indices=self.test_indices,
                gru_handler=self.gru_handler,
                # --- Pass Calibrator instances and initial state --- #
                calibrator=calibrator_instance,
                vector_calibrator=vector_calibrator_instance,
                initial_optimal_T=getattr(self, 'optimal_T', None),  # Pass T if exists
                initial_vector_params=getattr(self, 'vector_cal_params', None),  # Pass params if exists
                fold_edge_threshold=edge_threshold_for_bt,
                # --- Pass raw predictions if needed for rolling cal --- #
                p_raw_test=p_raw_test_for_bt,
                logits_test=logits_test_for_bt,
                # --- Pass original prices --- #
                original_prices=self.df_test_original,  # Pass the DataFrame
                is_ternary=self.use_ternary,
                fold_num=self.current_fold
            )
        except SystemExit as e:  # Corrected indentation
            # Catch exits from backtester validation/execution
            logger.error(f"Fold {self.current_fold}: Backtest aborted: {e}")
            raise  # Re-raise to stop the fold
        except Exception as e:  # Corrected indentation
            logger.error(f"Fold {self.current_fold}: Unhandled error during backtester.run_backtest: {e}", exc_info=True)
            # Treat as failure, ensure metrics are None
            self.backtest_results_df = None
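The `run_backtest` call receives the initial temperature plus raw outputs so calibration can be refreshed mid-backtest. A minimal sketch of the rolling pattern implied by the `last_recalib_step = i` bookkeeping in the Backtester hunk above; the interval/window parameters and the injected `fit_temperature` are assumptions, not repo settings:

```python
# Hedged sketch of rolling recalibration inside a backtest loop (names are assumptions).
import numpy as np

def rolling_recalibrate(logits_test, y_onehot_test, fit_temperature, interval=500, window=1000):
    """Refit the temperature every `interval` steps on the trailing `window` of outcomes."""
    T, last_recalib_step, temps = 1.0, 0, []
    for i in range(len(logits_test)):
        if i - last_recalib_step >= interval and i >= window:
            T = fit_temperature(logits_test[i - window:i], y_onehot_test[i - window:i])
            last_recalib_step = i
        temps.append(T)
    return np.array(temps)
```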
@@ -2742,7 +2756,7 @@ class TradingPipeline:
        except Exception as e:
            logger.error(f"Error during SAC agent weight averaging or saving: {e}", exc_info=True)

# --- Entry Point --- #
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the GRU-SAC Trading Pipeline.")
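The error handler above belongs to SAC agent weight averaging across folds. As background, a minimal sketch of parameter averaging, assuming Keras-style `get_weights`/`set_weights` networks; the loader helper is hypothetical:

```python
# Hedged sketch of averaging weights across per-fold agents (Keras-style API assumed).
import numpy as np

def average_weights(models):
    """Element-wise mean of each weight tensor across identically shaped models."""
    per_model = [m.get_weights() for m in models]
    return [np.mean(np.stack(tensors, axis=0), axis=0) for tensors in zip(*per_model)]

# Usage sketch (load_actor is a hypothetical helper, not a repo function):
# fold_actors = [load_actor(run_id, fold) for fold in range(n_folds)]
# consensus = fold_actors[0]
# consensus.set_weights(average_weights(fold_actors))
# consensus.save("sac_actor_averaged.keras")
```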