indentation fix pipeline

yasha 2025-04-18 18:18:52 +00:00
parent a02dd4342e
commit a86cdb2c8f
3 changed files with 233 additions and 219 deletions


@@ -227,7 +227,7 @@ The `config.yaml` file centrally controls the pipeline's behavior. Key sections
1. **Setup:** Install requirements (`pip install -r requirements.txt`), prepare data in the specified format/location.
2. **Configure:** Edit `config.yaml` (data paths, feature lists, model params, control flags, walk-forward settings, calibration, validation thresholds, tuning, aggregation).
3. **Run Pipeline:**
   ```bash
   # From project root (develop/gru_sac_predictor/)
   python gru_sac_predictor/run.py --config path/to/your_config.yaml
   ```
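To make the configure step concrete, here is a minimal sketch of the shape such a config can take. The `calibration`, `validation_gates`, `control`, and `gru` keys below appear in the code in this diff; the `data` and `features` entries are hypothetical placeholders, not the repo's authoritative schema.

```python
import yaml

# Hypothetical minimal config; the repo's own config.yaml is the source of truth.
MINIMAL_CONFIG = """
data:
  csv_path: data/prices.csv        # assumed key
features:
  whitelist: [ret_1, rsi_14]       # assumed key
gru:
  use_ternary_output: false
calibration:
  rolling_enabled: false
  optimize_edge_threshold: true
  edge_threshold: 0.1
validation_gates:
  gru:
    edge_filtered_acc_ci_lower_threshold: 0.55
    brier_score_threshold: 0.19
    edge_filtered_min_samples: 30
control:
  generate_plots: true
"""

config = yaml.safe_load(MINIMAL_CONFIG)
print(config['calibration']['edge_threshold'])  # -> 0.1
```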


@@ -77,7 +77,7 @@ class Backtester:
        self.coverage_alarm_threshold_drop = self.cal_cfg.get('coverage_alarm_threshold_drop', 0.03)
        self.coverage_alarm_window = self.cal_cfg.get('coverage_alarm_window', 1000)
        # --- End --- #
        self.results_df: Optional[pd.DataFrame] = None
        self.metrics: Optional[Dict[str, Any]] = None
@@ -98,11 +98,11 @@ class Backtester:
        self.ece_recalibration_threshold = self.cal_cfg.get('ece_recalibration_threshold', 0.03)

    def run_backtest(
        self,
        sac_agent_load_path: Optional[str],
        X_test_seq: np.ndarray,
        y_test_seq_dict: Dict[str, np.ndarray],
        test_indices: pd.Index,
        gru_handler: GRUModelHandler,
        # --- Pass Calibrator Instances & Initial Params/Threshold --- #
        calibrator: Optional[Calibrator],  # For Temp Scaling
@@ -163,7 +163,7 @@ class Backtester:
            logger.error(f"Fold {fold_num}: Rolling calibration enabled for ternary case, but logits_test not provided.")
            return None, None, None
        # --- End Validation --- #

        # 1. Initialize SAC Agent & Load Weights
        # Ensure agent state dim matches the state construction below
        agent_state_dim = 5  # mu, sigma, edge, |mu|/sigma, position
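The comment on `agent_state_dim` names the five state features. As a minimal sketch of assembling that state vector (the exact ordering and any normalisation inside `Backtester` are assumptions here, grounded only in that comment):

```python
import numpy as np

def build_sac_state(mu: float, sigma: float, edge: float, position: float) -> np.ndarray:
    """Hypothetical 5-dim SAC state: mu, sigma, edge, |mu|/sigma, position."""
    eps = 1e-9  # guard against division by zero for tiny sigma
    signal_to_noise = abs(mu) / (sigma + eps)
    return np.array([mu, sigma, edge, signal_to_noise, position], dtype=np.float32)

state = build_sac_state(mu=0.002, sigma=0.01, edge=0.3, position=0.5)
assert state.shape == (5,)  # matches agent_state_dim = 5
```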
@@ -223,7 +223,7 @@ class Backtester:
        log_sigma_test_state = preds_test_state[1][:, 1].flatten() if preds_test_state[1].shape[-1] == 2 else np.log(preds_test_state[1].flatten() + 1e-9)
        sigma_test_state = np.exp(log_sigma_test_state)
        p_raw_fallback = p_raw_test  # Use the passed raw probs if needed outside rolling

        # Extract actual returns and directions
        actual_ret_test = y_test_seq_dict.get('ret')
        dir_key = 'dir3' if is_ternary else 'dir'
@@ -262,7 +262,7 @@ class Backtester:
        metrics_log = []  # Store periodic metrics
        step_correct_nonzero, step_count_nonzero = 0, 0
        step_abs_actions = []

        logger.info(f"Fold {fold_num}: Starting backtest simulation loop ({n_test} steps)...")
        for i in range(n_test):
            # --- Step i: Calibration ---
@@ -332,7 +332,7 @@ class Backtester:
            # --- Step i: Update Position ---
            current_position = target_position

            # --- Step i: Update Metrics Log Data ---
            if abs(current_position) > 1e-6:
                step_count_nonzero += 1
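The loop above carries a position from step to step. The simulation's PnL accounting is not shown in this diff; a generic per-step sketch under an assumed linear transaction-cost model (the function and cost parameter are hypothetical, not the Backtester's actual logic) looks like this:

```python
import numpy as np

def simulate_steps(positions: np.ndarray, returns: np.ndarray,
                   cost_per_turnover: float = 1e-4) -> np.ndarray:
    """Hypothetical per-step PnL: position held into the step times realised
    return, minus a linear cost on position changes."""
    prev = np.concatenate([[0.0], positions[:-1]])  # position entering each step
    turnover = np.abs(positions - prev)             # |delta position| per step
    return prev * returns - cost_per_turnover * turnover

pnl = simulate_steps(np.array([0.0, 1.0, 1.0, -0.5]),
                     np.array([0.001, -0.002, 0.003, 0.001]))
```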
@@ -455,7 +455,7 @@ class Backtester:
                logger.info(f" New optimal temperature: {new_T:.4f} (Previous: {current_optimal_T:.4f})")
                current_optimal_T = new_T
                calibrator.optimal_T = new_T  # Update instance
            else:
                logger.info(f" Temperature unchanged or optimization failed (T={current_optimal_T:.4f}).")
            last_recalib_step = i
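This hunk periodically re-optimises the scaling temperature during the backtest. A minimal sketch of fitting a scalar temperature on a recent window by minimising binary NLL (the `Calibrator`'s actual internals are not shown in this diff; this is one standard way to do it):

```python
import numpy as np
from scipy.optimize import minimize_scalar

def refit_temperature(logits: np.ndarray, y: np.ndarray) -> float:
    """Fit T by minimising binary NLL of sigmoid(logits / T)."""
    def nll(T: float) -> float:
        p = 1.0 / (1.0 + np.exp(-logits / T))
        p = np.clip(p, 1e-7, 1 - 1e-7)  # numerical guard
        return float(-np.mean(y * np.log(p) + (1 - y) * np.log(1 - p)))
    res = minimize_scalar(nll, bounds=(0.05, 20.0), method='bounded')
    return float(res.x)

# e.g. refit on the most recent window when ECE drifts past
# ece_recalibration_threshold: new_T = refit_temperature(recent_logits, recent_labels)
```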


@@ -769,7 +769,7 @@ class TradingPipeline:
        except Exception as e:
            logging.error(f"Manual save of final feature whitelist failed: {e}", exc_info=True)
        # --- End Save Update --- #

        # --- MODIFIED: Prune the SCALED data splits using the determined whitelist --- #
        if self.X_train_scaled is None or self.X_val_scaled is None or self.X_test_scaled is None:
            logging.error("Scaled data splits not available for pruning.")
@@ -780,7 +780,7 @@ class TradingPipeline:
        self.X_val_pruned = self.feature_engineer.prune_features(self.X_val_scaled, self.final_whitelist)
        self.X_test_pruned = self.feature_engineer.prune_features(self.X_test_scaled, self.final_whitelist)
        # --- End Modification --- #

        logging.info(f"Feature shapes after pruning scaled data: Train={self.X_train_pruned.shape}, Val={self.X_val_pruned.shape}, Test={self.X_test_pruned.shape}")
        # Verification and empty checks remain the same, using X_*_pruned
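`prune_features` itself is not shown in this diff. A plausible minimal implementation (an assumption, not the repo's actual code) simply subsets each split's columns to the whitelist:

```python
import logging
from typing import List
import pandas as pd

def prune_features(df: pd.DataFrame, whitelist: List[str]) -> pd.DataFrame:
    """Keep only whitelisted columns, preserving whitelist order."""
    kept = [c for c in whitelist if c in df.columns]
    missing = set(whitelist) - set(kept)
    if missing:
        logging.warning(f"Whitelist features missing from data: {sorted(missing)}")
    return df[kept].copy()
```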
@@ -1094,116 +1094,116 @@ class TradingPipeline:
                patience=patience
            )
            if self.gru_model is None:
                logging.error("GRU model training failed. Exiting.")
                sys.exit(1)
            else:
                # Save the newly trained model
                saved_path = self.gru_handler.save()  # Uses run_id from handler
                if saved_path:
                    logging.info(f"Newly trained GRU model saved to {saved_path}")
                else:
                    logging.warning("Failed to save the newly trained GRU model.")
                # Set the loaded ID to the current run ID
                self.gru_model_run_id_loaded_from = self.run_id
                logging.info(f"Using GRU model trained in current run: {self.run_id}")
            # --- V3 Output Contract: Plot Learning Curve --- #
            if self.io and history is not None and self.config.get('control', {}).get('generate_plots', True):
                # Infer log dir path based on current models dir
                log_dir = os.path.dirname(self.current_run_models_dir).replace('/models', '/logs')
                csv_log_path = os.path.join(log_dir, 'gru_history.csv')
                if os.path.exists(csv_log_path):
                    logging.info(f"Plotting learning curve from {csv_log_path}...")
                    try:
                        history_df = pd.read_csv(csv_log_path)
                        # Determine metric keys (handle v2 vs v3 differences if necessary)
                        loss_key = 'loss'
                        val_loss_key = 'val_loss'
                        acc_key = None
                        val_acc_key = None
                        if 'dir3_accuracy' in history_df.columns:  # V3 specific?
                            acc_key = 'dir3_accuracy'
                            val_acc_key = 'val_dir3_accuracy'
                        elif 'accuracy' in history_df.columns:  # V2 or other?
                            acc_key = 'accuracy'
                            val_acc_key = 'val_accuracy'
                        if acc_key is None:
                            logging.warning("Could not find a suitable accuracy metric in history CSV for plotting.")
                            n_panes = 1  # Only plot loss
                        else:
                            n_panes = 2  # Plot loss and accuracy
                        # Get figure settings
                        fig_dpi = self.config.get('output', {}).get('figure_dpi', 150)
                        fig_size = self.config.get('output', {}).get('figure_size', [16, 9])
                        footer_text = "© GRU-SAC v3"
                        plt.style.use('seaborn-v0_8-darkgrid')
                        # Adjust figsize height based on panes
                        adjusted_fig_height = fig_size[1] * (n_panes / 3.0)  # Rough scaling
                        fig, axes = plt.subplots(n_panes, 1, figsize=(fig_size[0], adjusted_fig_height), sharex=True)
                        if n_panes == 1:
                            ax_loss = axes  # Single axis
                        else:
                            ax_loss, ax_acc = axes  # Multiple axes
                        epochs = history_df['epoch'] + 1  # epochs are 0-indexed in csv
                        # Pane 1: Loss (Log Scale)
                        ax_loss.plot(epochs, history_df[loss_key], label='Training Loss')
                        ax_loss.plot(epochs, history_df[val_loss_key], label='Validation Loss')
                        ax_loss.set_yscale('log')
                        ax_loss.set_ylabel('Loss (Log Scale)')
                        ax_loss.legend()
                        ax_loss.set_title('GRU Model Training Progress', fontsize=16)
                        ax_loss.grid(True, which="both", ls="--", linewidth=0.5)
                        # Pane 2: Accuracy (if available)
                        if n_panes == 2:
                            ax_acc.plot(epochs, history_df[acc_key], label=f'Training {acc_key}')
                            ax_acc.plot(epochs, history_df[val_acc_key], label=f'Validation {val_acc_key}')
                            ax_acc.set_ylabel('Accuracy')
                            ax_acc.set_xlabel('Epoch')
                            ax_acc.legend()
                            ax_acc.grid(True, which="both", ls="--", linewidth=0.5)
                        else:
                            # If only loss pane, set xlabel there
                            ax_loss.set_xlabel('Epoch')
                        # Add vertical line for early stopping epoch if available
                        if hasattr(history, 'epoch') and len(history.epoch) > 0:
                            # Early stopping epoch is the number of epochs run
                            early_stop_epoch = len(history.epoch)
                            if early_stop_epoch < max_epochs:  # Only draw if early stopping occurred
                                for ax in fig.axes:
                                    ax.axvline(x=early_stop_epoch, color='r', linestyle='--', linewidth=1, label=f'Early Stop @ {early_stop_epoch}')
                                # Add legend entry to the last plot
                                fig.axes[-1].legend()
                        # Add footer
                        plt.figtext(0.99, 0.01, footer_text, horizontalalignment='right',
                                    verticalalignment='bottom', fontsize=8, color='gray')
                        plt.tight_layout(rect=[0, 0.03, 1, 0.97])  # Adjust layout
                        # Save figure using IOManager
                        self.io.save_figure(fig, "gru_learning_curve", section='results')
                        logging.info("GRU learning curve plot saved.")
                        plt.close(fig)
                    except FileNotFoundError:
                        logging.warning(f"GRU history file not found at {csv_log_path}. Cannot plot learning curve.")
                    except Exception as e:
                        logging.error(f"Failed to plot GRU learning curve: {e}", exc_info=True)
                else:
                    logging.warning(f"GRU history file not found at {csv_log_path}. Cannot plot learning curve.")
            elif not self.io:
                logging.warning("IOManager not available, skipping GRU learning curve plot.")
            # --- End Plot Learning Curve --- #
        else:  # Load pre-trained GRU model
            load_run_id = gru_cfg.get('model_load_run_id', None)
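The plotting block above reads `gru_history.csv` back from the run's log directory. A file with those `epoch`/`loss`/`val_loss`/accuracy columns is typically produced on the training side by a Keras `CSVLogger` callback; the wiring below is an assumption about this repo's training code, not something shown in the diff:

```python
import os
import tensorflow as tf

log_dir = "runs/current/logs"  # assumed layout: .../logs alongside .../models
os.makedirs(log_dir, exist_ok=True)
csv_logger = tf.keras.callbacks.CSVLogger(os.path.join(log_dir, "gru_history.csv"))
# model.fit(X_train, y_train, validation_data=..., callbacks=[csv_logger])
```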
@@ -1357,8 +1357,8 @@ class TradingPipeline:
        except Exception as e:
            logger.error(f"Fold {self.current_fold}: Error during temperature calibration: {e}", exc_info=True)
            self.optimal_T = None
-       else:
-           logger.warning(f"Fold {self.current_fold}: Calibration method '{calibration_method}' or ternary state mismatch. Skipping GRU validation checks.")
+       else:  # Covers cases where calibration method is wrong or mismatch with ternary state
+           logger.warning(f"Fold {self.current_fold}: Calibration method '{calibration_method}' or ternary state mismatch ({is_ternary_check}). Skipping GRU validation checks.")
            self.optimal_T = None
            self.vector_cal_params = None
@@ -1368,60 +1368,74 @@ class TradingPipeline:
        self.optimized_edge_threshold = None  # Reset for the fold
        if optimize_edge:
            logger.info(f"Optimizing edge threshold using Youden's J on validation predictions...")
            try:
                # Prepare y_true for optimization (needs binary 0/1)
                if y_dir_val_to_check is None:
                    raise ValueError("Cannot optimize edge threshold without valid y_dir_val.")
                y_true_for_opt = None
                p_cal_for_opt = None
                if is_ternary_check:
                    if p_cal_val_to_check is not None:
                        # Convert ternary to binary: P(up) vs others
                        p_cal_for_opt = p_cal_val_to_check[:, -1]  # P(up)
                        y_true_for_opt = (np.argmax(y_dir_val_to_check, axis=1) == 2).astype(int)
                    else:
                        raise ValueError("Cannot optimize ternary edge threshold without valid calibrated probabilities.")
                else:
                    # Binary case
                    if p_cal_val_to_check is not None:
                        p_cal_for_opt = p_cal_val_to_check
                        y_true_for_opt = (np.asarray(y_dir_val_to_check) > 0.5).astype(int)
                    else:
                        raise ValueError("Cannot optimize binary edge threshold without valid calibrated probabilities.")
                # Perform optimization using the dedicated function from metrics
                # Note: This assumes Calibrator.optimize_edge_threshold was removed or is not used here
                # Ensure _calculate_optimal_edge_threshold is imported
                self.optimized_edge_threshold = _calculate_optimal_edge_threshold(y_true_for_opt, p_cal_for_opt)
                if self.optimized_edge_threshold is not None:
                    logger.info(f"Optimized edge threshold: {self.optimized_edge_threshold:.4f}")
                    # Save optimized threshold
                    thresh_file = f"optimized_edge_threshold_fold_{self.current_fold}.txt"
                    try:
                        # Ensure fold_results_dir is defined (should be available from context)
-                       if self.io and fold_results_dir:
+                       # Assuming fold_results_dir is defined in the outer scope
+                       if self.io and 'fold_results_dir' in locals() and fold_results_dir:
                            self.io.save_json({'optimized_edge_threshold': self.optimized_edge_threshold},
                                              thresh_file.replace('.txt', ''),  # Use filename as key for io
                                              base_dir=fold_results_dir, use_txt=True)
-                           logger.info(f"Saved optimized edge threshold to {thresh_file}")
+                           logger.info(f"Saved optimized edge threshold to {os.path.join(fold_results_dir, thresh_file)}")
                        elif self.io:
+                           # Fallback: Save to the main results directory if fold_results_dir isn't available
+                           # Construct the path for logging clarity
+                           fallback_path = os.path.join(self.io.get_section_path('results'), thresh_file)
                            self.io.save_json({'optimized_edge_threshold': self.optimized_edge_threshold},
                                              thresh_file.replace('.txt', ''),
                                              section='results', use_txt=True)  # Fallback save
-                           logger.info(f"Saved optimized edge threshold to main results dir: {thresh_file}")
+                           logger.info(f"Saved optimized edge threshold to main results dir: {fallback_path}")
                        else:
                            logger.warning("IOManager not available, cannot save optimized threshold.")
+                   except NameError:  # Specifically catch if fold_results_dir is not defined
+                       logger.warning("fold_results_dir not defined. Attempting fallback save to main results dir.")
+                       if self.io:
+                           fallback_path = os.path.join(self.io.get_section_path('results'), thresh_file)
+                           self.io.save_json({'optimized_edge_threshold': self.optimized_edge_threshold},
+                                             thresh_file.replace('.txt', ''),
+                                             section='results', use_txt=True)
+                           logger.info(f"Saved optimized edge threshold to main results dir: {fallback_path}")
+                       else:
+                           logger.warning("IOManager not available, cannot save optimized threshold.")
                    except Exception as e:
                        logger.error(f"Failed to save optimized edge threshold: {e}")
                else:
                    logger.warning("Edge threshold optimization failed or was skipped. Using config default.")
                    self.optimized_edge_threshold = edge_thr_config  # Fallback
            except Exception as e:
                logger.error(f"Error during edge threshold optimization: {e}", exc_info=True)
                self.optimized_edge_threshold = edge_thr_config  # Fallback
        else:
            # If optimization is disabled, store the config threshold for consistent use
            self.optimized_edge_threshold = edge_thr_config
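`_calculate_optimal_edge_threshold` is imported from the metrics module and not shown in this diff. A plausible stand-in using Youden's J via `sklearn.metrics.roc_curve` is sketched below; in particular, expressing the chosen probability cut as an edge of `|2p - 1|` is an assumption about how this repo defines "edge":

```python
from typing import Optional
import numpy as np
from sklearn.metrics import roc_curve

def calculate_optimal_edge_threshold(y_true: np.ndarray, p_cal: np.ndarray) -> Optional[float]:
    """Hypothetical sketch: pick the probability cut maximising
    Youden's J = TPR - FPR, expressed as an edge (assumed |2p - 1|)."""
    if len(np.unique(y_true)) < 2:
        return None  # ROC is undefined with a single class
    fpr, tpr, thresholds = roc_curve(y_true, p_cal)
    j = tpr - fpr
    best_p = thresholds[np.argmax(j)]
    return float(abs(2.0 * best_p - 1.0))
```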
@@ -1430,21 +1444,21 @@ class TradingPipeline:
        # --- Perform GRU Validation Checks using the threshold stored in self.optimized_edge_threshold --- #
        if p_cal_val_to_check is not None and y_dir_val_to_check is not None:
            self._perform_gru_validation_checks(
                p_cal_val=p_cal_val_to_check,
                y_dir_val=y_dir_val_to_check,
                is_ternary=is_ternary_check
            )
            # Note: _perform_gru_validation_checks was already modified to use self.optimized_edge_threshold
        else:
            logger.warning("Could not perform GRU validation checks due to missing calibrated predictions or labels.")

    # --- Helper for GRU Validation Checks (Replaces edge check) --- #
    def _perform_gru_validation_checks(self, p_cal_val, y_dir_val, is_ternary):
        """
        Performs GRU validation checks: Edge-Filtered Accuracy and Brier Score.
        Logs results and raises SystemExit if checks fail.

        Args:
            p_cal_val: Calibrated probabilities on validation set.
                For binary: (N,) shape, P(up).
@@ -1455,19 +1469,18 @@ class TradingPipeline:
            is_ternary (bool): Flag indicating if ternary classification is used.
        """
        logger.info(f"--- Fold {self.current_fold}: Performing GRU Validation Checks ---")
        # --- Define thresholds (Consider moving to config) --- #
        validation_criteria = self.config.get('validation_gates', {}).get('gru', {})
        edge_check_thr = validation_criteria.get('edge_filtered_acc_ci_lower_threshold', 0.55)
        brier_check_thr = validation_criteria.get('brier_score_threshold', 0.19)
        min_edge_samples = validation_criteria.get('edge_filtered_min_samples', 30)
        # --- End Thresholds --- #

        # --- Determine Edge Threshold --- #
        calib_config = self.config.get('calibration', {})
        optimize_edge = calib_config.get('optimize_edge_threshold', False)
        edge_thr_config = calib_config.get('edge_threshold', 0.1)  # Default/fallback
        self.fold_edge_threshold = edge_thr_config  # Initialize with config value
        if optimize_edge and not is_ternary:
@@ -1483,7 +1496,7 @@ class TradingPipeline:
                self.io.save_json({'optimized_edge_threshold': self.fold_edge_threshold},
                                  f'optimized_edge_threshold_fold_{self.current_fold}',
                                  base_dir=fold_results_dir, use_txt=True)
            else:
                logger.warning(f"Fold {self.current_fold}: Could not save optimized edge threshold, results dir missing.")
        elif optimize_edge and is_ternary:
            logger.warning(f"Fold {self.current_fold}: Edge threshold optimization requested but not supported for ternary. Using config value: {edge_thr_config:.4f}")
@@ -1527,13 +1540,13 @@ class TradingPipeline:
                    passed_edge_acc = ci_lower >= edge_check_thr
                    logger.info(f"Fold {self.current_fold}: Edge Acc Check (edge >= {self.fold_edge_threshold:.2f}): Acc={edge_accuracy:.3f} ({k_correct}/{n_filtered}), 95% CI Lower={ci_lower:.3f} >= {edge_check_thr} -> {'Pass' if passed_edge_acc else 'FAIL'}")
                except ValueError as binom_err:
                    logger.error(f"Fold {self.current_fold}: Edge Acc Check: Error calculating binomial test (k={k_correct}, n={n_filtered}): {binom_err}. Check considered FAIL.")
                    passed_edge_acc = False  # Consider error as failure
            else:
                logger.error(f"Fold {self.current_fold}: Edge Acc Check: Calculation failed (NaN). Check considered FAIL.")
                if n_filtered == 0:
                    logger.error(f" Reason: No validation samples met the edge threshold >= {self.fold_edge_threshold:.2f}")
                passed_edge_acc = False  # Consider NaN or 0 samples as failure
        except Exception as e:
            logger.error(f"Fold {self.current_fold}: Edge Acc Check: Unexpected error during calculation: {e}. Check considered FAIL.", exc_info=True)
            passed_edge_acc = False  # Consider error as failure
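The gate above filters validation samples by edge, computes directional accuracy on the survivors, and passes only if the 95% CI lower bound clears the threshold. A minimal self-contained sketch of that check, assuming the binary case with edge = |2p - 1| and using SciPy's Wilson interval for the CI the log line refers to:

```python
import numpy as np
from scipy.stats import binomtest

def edge_filtered_gate(p_cal: np.ndarray, y_true: np.ndarray,
                       edge_thr: float = 0.1, ci_lower_req: float = 0.55,
                       min_samples: int = 30) -> bool:
    """Pass iff the Wilson 95% CI lower bound of edge-filtered accuracy
    meets the requirement. Assumes binary labels and edge = |2p - 1|."""
    edge = np.abs(2.0 * p_cal - 1.0)
    mask = edge >= edge_thr
    n = int(mask.sum())
    if n < min_samples:
        return False  # too few confident samples to trust the estimate
    preds = (p_cal[mask] > 0.5).astype(int)
    k = int((preds == y_true[mask]).sum())
    ci = binomtest(k, n).proportion_ci(confidence_level=0.95, method='wilson')
    return ci.low >= ci_lower_req
```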
@@ -1566,9 +1579,9 @@ class TradingPipeline:
            error_msg = f"FOLD {self.current_fold} GRU VALIDATION FAILED: Edge Acc Pass={passed_edge_acc} (Req CI>={edge_check_thr}), Brier Pass={passed_brier} (Req Score<={brier_check_thr}). Aborting fold."
            logger.error(error_msg)
            # Use sys.exit with a specific message for clarity
            sys.exit(f"Fold {self.current_fold}: GRU validation gates failed (Edge Acc / Brier Score).")
-       else:
-           logger.info(f"Fold {self.current_fold}: GRU validation checks passed (Edge Acc & Brier Score).")
+       else:  # Corrected indentation
+           logger.info(f"Fold {self.current_fold}: GRU validation checks passed (Edge Acc & Brier Score).")  # Corrected indentation
    # --- End Validation Helper --- #

    def train_or_load_sac(self):
@@ -1628,7 +1641,7 @@ class TradingPipeline:
        # # --- Revision 1: Handle Rolling Calibrator Conflict --- #
        # ... (block removed) ...
        # # --- End Revision 1 --- #

        # Start the training process
        final_agent_path = self.sac_trainer.train(gru_run_id_for_sac=self.gru_model_run_id_loaded_from)
@@ -1791,6 +1804,7 @@ class TradingPipeline:
        # Get raw predictions needed for rolling calibration
        p_raw_test_for_bt = None
        logits_test_for_bt = None
+       is_ternary = self.config.get('gru', {}).get('use_ternary_output', False)  # Need to know if ternary
        if self.config.get('calibration', {}).get('rolling_enabled', False):
            logger.info(f"Fold {self.current_fold}: Getting raw GRU outputs for rolling calibration...")
            if is_ternary:
@@ -1798,43 +1812,43 @@ class TradingPipeline:
                if logits_test_for_bt is None:
                    logger.error(f"Fold {self.current_fold}: Failed to get GRU logits for rolling calibration.")
                    raise SystemExit(f"Fold {self.current_fold}: Failed GRU logit prediction.")
-           else:
+           else:  # Corrected indentation
                preds_test_raw = self.gru_handler.predict(self.X_test_seq)
                if preds_test_raw is None or len(preds_test_raw) < 3:
                    logger.error(f"Fold {self.current_fold}: Failed to get GRU raw predictions for rolling calibration.")
                    raise SystemExit(f"Fold {self.current_fold}: Failed GRU raw prediction.")
-               p_raw_test_for_bt = preds_test_raw[2].flatten()
+               p_raw_test_for_bt = preds_test_raw[2].flatten()  # Assuming index 2 is probabilities

        # Get the edge threshold determined during validation (optimized or fixed)
        edge_threshold_for_bt = getattr(self, 'fold_edge_threshold', self.config.get('calibration', {}).get('edge_threshold', 0.1))
        logger.info(f"Fold {self.current_fold}: Using edge threshold {edge_threshold_for_bt:.4f} for backtest execution.")

-       try:
+       try:  # Corrected indentation
            self.backtest_results_df, self.backtest_metrics, self.metrics_log_df = self.backtester.run_backtest(
                sac_agent_load_path=self.sac_agent_load_path,
                X_test_seq=self.X_test_seq,
                y_test_seq_dict=self.y_test_seq_dict,
                test_indices=self.test_indices,
                gru_handler=self.gru_handler,
                # --- Pass Calibrator instances and initial state --- #
                calibrator=calibrator_instance,
                vector_calibrator=vector_calibrator_instance,
                initial_optimal_T=getattr(self, 'optimal_T', None),  # Pass T if exists
                initial_vector_params=getattr(self, 'vector_cal_params', None),  # Pass params if exists
                fold_edge_threshold=edge_threshold_for_bt,
                # --- Pass raw predictions if needed for rolling cal --- #
                p_raw_test=p_raw_test_for_bt,
                logits_test=logits_test_for_bt,
                # --- Pass original prices --- #
                original_prices=self.df_test_original,  # Pass the DataFrame
                is_ternary=self.use_ternary,
                fold_num=self.current_fold
            )
-       except SystemExit as e:
+       except SystemExit as e:  # Corrected indentation
            # Catch exits from backtester validation/execution
            logger.error(f"Fold {self.current_fold}: Backtest aborted: {e}")
            raise  # Re-raise to stop the fold
-       except Exception as e:
+       except Exception as e:  # Corrected indentation
            logger.error(f"Fold {self.current_fold}: Unhandled error during backtester.run_backtest: {e}", exc_info=True)
            # Treat as failure, ensure metrics are None
            self.backtest_results_df = None
@@ -2027,7 +2041,7 @@ class TradingPipeline:
    def execute(self):
        """Runs the full trading pipeline end-to-end."""
        logger.info(f"--- Starting Trading Pipeline: Run ID {self.run_id} ---")

        # 1. Load and Preprocess Data
        self.load_and_preprocess_data()
        if self.data_processed is None:  # Check if data loading failed
@@ -2073,7 +2087,7 @@ class TradingPipeline:
        else:
            # Perform edge accuracy check only if calibration happened and model exists
            self._perform_gru_validation_checks(
                p_cal_val=self.p_cal_val,
                y_dir_val=self.y_dir_val,
                is_ternary=self.use_ternary
            )
@@ -2092,7 +2106,7 @@ class TradingPipeline:
    # --- Walk-Forward Fold Generation --- #
    def _generate_walk_forward_folds(self) -> Iterator[Tuple[pd.Timestamp, pd.Timestamp, pd.Timestamp, pd.Timestamp, pd.Timestamp, pd.Timestamp]]:
        """
        Generates start and end timestamps for train, validation, and test sets
        for each walk-forward fold based on config settings.
        Requires self.df_raw to be loaded first to determine the full date range.
        """
@@ -2741,8 +2755,8 @@ class TradingPipeline:
        except Exception as e:
            logger.error(f"Error during SAC agent weight averaging or saving: {e}", exc_info=True)

# --- Entry Point --- #
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the GRU-SAC Trading Pipeline.")
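The diff ends mid-entry-point. A plausible completion, mirroring the README's `--config` flag (the `TradingPipeline(config)` constructor signature is an assumption, not confirmed by this diff):

```python
import argparse
import yaml

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the GRU-SAC Trading Pipeline.")
    parser.add_argument("--config", required=True, help="Path to the pipeline config YAML.")
    args = parser.parse_args()

    with open(args.config) as fh:
        config = yaml.safe_load(fh)  # dict consumed by the pipeline

    pipeline = TradingPipeline(config)  # assumed constructor signature
    pipeline.execute()
```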