iundentation fix pipeline
This commit is contained in:
parent
a02dd4342e
commit
a86cdb2c8f
@ -227,7 +227,7 @@ The `config.yaml` file centrally controls the pipeline's behavior. Key sections
|
||||
1. **Setup:** Install requirements (`pip install -r requirements.txt`), prepare data in the specified format/location.
|
||||
2. **Configure:** Edit `config.yaml` (data paths, feature lists, model params, control flags, walk-forward settings, calibration, validation thresholds, tuning, aggregation).
|
||||
3. **Run Pipeline:**
|
||||
```bash
|
||||
```bash
|
||||
# From project root (develop/gru_sac_predictor/)
|
||||
python gru_sac_predictor/run.py --config path/to/your_config.yaml
|
||||
```
|
||||
|
||||
@ -455,7 +455,7 @@ class Backtester:
|
||||
logger.info(f" New optimal temperature: {new_T:.4f} (Previous: {current_optimal_T:.4f})")
|
||||
current_optimal_T = new_T
|
||||
calibrator.optimal_T = new_T # Update instance
|
||||
else:
|
||||
else:
|
||||
logger.info(f" Temperature unchanged or optimization failed (T={current_optimal_T:.4f}).")
|
||||
|
||||
last_recalib_step = i
|
||||
|
||||
@ -1094,116 +1094,116 @@ class TradingPipeline:
|
||||
patience=patience
|
||||
)
|
||||
|
||||
if self.gru_model is None:
|
||||
logging.error("GRU model training failed. Exiting.")
|
||||
sys.exit(1)
|
||||
else:
|
||||
# Save the newly trained model
|
||||
saved_path = self.gru_handler.save() # Uses run_id from handler
|
||||
if saved_path:
|
||||
logging.info(f"Newly trained GRU model saved to {saved_path}")
|
||||
if self.gru_model is None:
|
||||
logging.error("GRU model training failed. Exiting.")
|
||||
sys.exit(1)
|
||||
else:
|
||||
logging.warning("Failed to save the newly trained GRU model.")
|
||||
# Set the loaded ID to the current run ID
|
||||
self.gru_model_run_id_loaded_from = self.run_id
|
||||
logging.info(f"Using GRU model trained in current run: {self.run_id}")
|
||||
|
||||
# --- V3 Output Contract: Plot Learning Curve --- #
|
||||
if self.io and history is not None and self.config.get('control', {}).get('generate_plots', True):
|
||||
# Infer log dir path based on current models dir
|
||||
log_dir = os.path.dirname(self.current_run_models_dir).replace('/models', '/logs')
|
||||
csv_log_path = os.path.join(log_dir, 'gru_history.csv')
|
||||
if os.path.exists(csv_log_path):
|
||||
logging.info(f"Plotting learning curve from {csv_log_path}...")
|
||||
try:
|
||||
history_df = pd.read_csv(csv_log_path)
|
||||
|
||||
# Determine metric keys (handle v2 vs v3 differences if necessary)
|
||||
loss_key = 'loss'
|
||||
val_loss_key = 'val_loss'
|
||||
acc_key = None
|
||||
val_acc_key = None
|
||||
if 'dir3_accuracy' in history_df.columns: # V3 specific?
|
||||
acc_key = 'dir3_accuracy'
|
||||
val_acc_key = 'val_dir3_accuracy'
|
||||
elif 'accuracy' in history_df.columns: # V2 or other?
|
||||
acc_key = 'accuracy'
|
||||
val_acc_key = 'val_accuracy'
|
||||
|
||||
if acc_key is None:
|
||||
logging.warning("Could not find a suitable accuracy metric in history CSV for plotting.")
|
||||
n_panes = 1 # Only plot loss
|
||||
else:
|
||||
n_panes = 2 # Plot loss and accuracy
|
||||
|
||||
# Get figure settings
|
||||
fig_dpi = self.config.get('output', {}).get('figure_dpi', 150)
|
||||
fig_size = self.config.get('output', {}).get('figure_size', [16, 9])
|
||||
footer_text = "© GRU-SAC v3"
|
||||
|
||||
plt.style.use('seaborn-v0_8-darkgrid')
|
||||
# Adjust figsize height based on panes
|
||||
adjusted_fig_height = fig_size[1] * (n_panes / 3.0) # Rough scaling
|
||||
fig, axes = plt.subplots(n_panes, 1, figsize=(fig_size[0], adjusted_fig_height), sharex=True)
|
||||
|
||||
if n_panes == 1:
|
||||
ax_loss = axes # Single axis
|
||||
else:
|
||||
ax_loss, ax_acc = axes # Multiple axes
|
||||
|
||||
epochs = history_df['epoch'] + 1 # epochs are 0-indexed in csv
|
||||
|
||||
# Pane 1: Loss (Log Scale)
|
||||
ax_loss.plot(epochs, history_df[loss_key], label='Training Loss')
|
||||
ax_loss.plot(epochs, history_df[val_loss_key], label='Validation Loss')
|
||||
ax_loss.set_yscale('log')
|
||||
ax_loss.set_ylabel('Loss (Log Scale)')
|
||||
ax_loss.legend()
|
||||
ax_loss.set_title('GRU Model Training Progress', fontsize=16)
|
||||
ax_loss.grid(True, which="both", ls="--", linewidth=0.5)
|
||||
|
||||
# Pane 2: Accuracy (if available)
|
||||
if n_panes == 2:
|
||||
ax_acc.plot(epochs, history_df[acc_key], label=f'Training {acc_key}')
|
||||
ax_acc.plot(epochs, history_df[val_acc_key], label=f'Validation {val_acc_key}')
|
||||
ax_acc.set_ylabel('Accuracy')
|
||||
ax_acc.set_xlabel('Epoch')
|
||||
ax_acc.legend()
|
||||
ax_acc.grid(True, which="both", ls="--", linewidth=0.5)
|
||||
else:
|
||||
# If only loss pane, set xlabel there
|
||||
ax_loss.set_xlabel('Epoch')
|
||||
|
||||
# Add vertical line for early stopping epoch if available
|
||||
if hasattr(history, 'epoch') and len(history.epoch) > 0:
|
||||
# Early stopping epoch is the number of epochs run
|
||||
early_stop_epoch = len(history.epoch)
|
||||
if early_stop_epoch < max_epochs: # Only draw if early stopping occurred
|
||||
for ax in fig.axes:
|
||||
ax.axvline(x=early_stop_epoch, color='r', linestyle='--', linewidth=1, label=f'Early Stop @ {early_stop_epoch}')
|
||||
# Add legend entry to the last plot
|
||||
fig.axes[-1].legend()
|
||||
|
||||
# Add footer
|
||||
plt.figtext(0.99, 0.01, footer_text, horizontalalignment='right',
|
||||
verticalalignment='bottom', fontsize=8, color='gray')
|
||||
|
||||
plt.tight_layout(rect=[0, 0.03, 1, 0.97]) # Adjust layout
|
||||
|
||||
# Save figure using IOManager
|
||||
self.io.save_figure(fig, "gru_learning_curve", section='results')
|
||||
logging.info("GRU learning curve plot saved.")
|
||||
plt.close(fig)
|
||||
|
||||
except FileNotFoundError:
|
||||
logging.warning(f"GRU history file not found at {csv_log_path}. Cannot plot learning curve.")
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to plot GRU learning curve: {e}", exc_info=True)
|
||||
# Save the newly trained model
|
||||
saved_path = self.gru_handler.save() # Uses run_id from handler
|
||||
if saved_path:
|
||||
logging.info(f"Newly trained GRU model saved to {saved_path}")
|
||||
else:
|
||||
logging.warning(f"GRU history file not found at {csv_log_path}. Cannot plot learning curve.")
|
||||
elif not self.io:
|
||||
logging.warning("IOManager not available, skipping GRU learning curve plot.")
|
||||
# --- End Plot Learning Curve --- #
|
||||
logging.warning("Failed to save the newly trained GRU model.")
|
||||
# Set the loaded ID to the current run ID
|
||||
self.gru_model_run_id_loaded_from = self.run_id
|
||||
logging.info(f"Using GRU model trained in current run: {self.run_id}")
|
||||
|
||||
# --- V3 Output Contract: Plot Learning Curve --- #
|
||||
if self.io and history is not None and self.config.get('control', {}).get('generate_plots', True):
|
||||
# Infer log dir path based on current models dir
|
||||
log_dir = os.path.dirname(self.current_run_models_dir).replace('/models', '/logs')
|
||||
csv_log_path = os.path.join(log_dir, 'gru_history.csv')
|
||||
if os.path.exists(csv_log_path):
|
||||
logging.info(f"Plotting learning curve from {csv_log_path}...")
|
||||
try:
|
||||
history_df = pd.read_csv(csv_log_path)
|
||||
|
||||
# Determine metric keys (handle v2 vs v3 differences if necessary)
|
||||
loss_key = 'loss'
|
||||
val_loss_key = 'val_loss'
|
||||
acc_key = None
|
||||
val_acc_key = None
|
||||
if 'dir3_accuracy' in history_df.columns: # V3 specific?
|
||||
acc_key = 'dir3_accuracy'
|
||||
val_acc_key = 'val_dir3_accuracy'
|
||||
elif 'accuracy' in history_df.columns: # V2 or other?
|
||||
acc_key = 'accuracy'
|
||||
val_acc_key = 'val_accuracy'
|
||||
|
||||
if acc_key is None:
|
||||
logging.warning("Could not find a suitable accuracy metric in history CSV for plotting.")
|
||||
n_panes = 1 # Only plot loss
|
||||
else:
|
||||
n_panes = 2 # Plot loss and accuracy
|
||||
|
||||
# Get figure settings
|
||||
fig_dpi = self.config.get('output', {}).get('figure_dpi', 150)
|
||||
fig_size = self.config.get('output', {}).get('figure_size', [16, 9])
|
||||
footer_text = "© GRU-SAC v3"
|
||||
|
||||
plt.style.use('seaborn-v0_8-darkgrid')
|
||||
# Adjust figsize height based on panes
|
||||
adjusted_fig_height = fig_size[1] * (n_panes / 3.0) # Rough scaling
|
||||
fig, axes = plt.subplots(n_panes, 1, figsize=(fig_size[0], adjusted_fig_height), sharex=True)
|
||||
|
||||
if n_panes == 1:
|
||||
ax_loss = axes # Single axis
|
||||
else:
|
||||
ax_loss, ax_acc = axes # Multiple axes
|
||||
|
||||
epochs = history_df['epoch'] + 1 # epochs are 0-indexed in csv
|
||||
|
||||
# Pane 1: Loss (Log Scale)
|
||||
ax_loss.plot(epochs, history_df[loss_key], label='Training Loss')
|
||||
ax_loss.plot(epochs, history_df[val_loss_key], label='Validation Loss')
|
||||
ax_loss.set_yscale('log')
|
||||
ax_loss.set_ylabel('Loss (Log Scale)')
|
||||
ax_loss.legend()
|
||||
ax_loss.set_title('GRU Model Training Progress', fontsize=16)
|
||||
ax_loss.grid(True, which="both", ls="--", linewidth=0.5)
|
||||
|
||||
# Pane 2: Accuracy (if available)
|
||||
if n_panes == 2:
|
||||
ax_acc.plot(epochs, history_df[acc_key], label=f'Training {acc_key}')
|
||||
ax_acc.plot(epochs, history_df[val_acc_key], label=f'Validation {val_acc_key}')
|
||||
ax_acc.set_ylabel('Accuracy')
|
||||
ax_acc.set_xlabel('Epoch')
|
||||
ax_acc.legend()
|
||||
ax_acc.grid(True, which="both", ls="--", linewidth=0.5)
|
||||
else:
|
||||
# If only loss pane, set xlabel there
|
||||
ax_loss.set_xlabel('Epoch')
|
||||
|
||||
# Add vertical line for early stopping epoch if available
|
||||
if hasattr(history, 'epoch') and len(history.epoch) > 0:
|
||||
# Early stopping epoch is the number of epochs run
|
||||
early_stop_epoch = len(history.epoch)
|
||||
if early_stop_epoch < max_epochs: # Only draw if early stopping occurred
|
||||
for ax in fig.axes:
|
||||
ax.axvline(x=early_stop_epoch, color='r', linestyle='--', linewidth=1, label=f'Early Stop @ {early_stop_epoch}')
|
||||
# Add legend entry to the last plot
|
||||
fig.axes[-1].legend()
|
||||
|
||||
# Add footer
|
||||
plt.figtext(0.99, 0.01, footer_text, horizontalalignment='right',
|
||||
verticalalignment='bottom', fontsize=8, color='gray')
|
||||
|
||||
plt.tight_layout(rect=[0, 0.03, 1, 0.97]) # Adjust layout
|
||||
|
||||
# Save figure using IOManager
|
||||
self.io.save_figure(fig, "gru_learning_curve", section='results')
|
||||
logging.info("GRU learning curve plot saved.")
|
||||
plt.close(fig)
|
||||
|
||||
except FileNotFoundError:
|
||||
logging.warning(f"GRU history file not found at {csv_log_path}. Cannot plot learning curve.")
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to plot GRU learning curve: {e}", exc_info=True)
|
||||
else:
|
||||
logging.warning(f"GRU history file not found at {csv_log_path}. Cannot plot learning curve.")
|
||||
elif not self.io:
|
||||
logging.warning("IOManager not available, skipping GRU learning curve plot.")
|
||||
# --- End Plot Learning Curve --- #
|
||||
|
||||
else: # Load pre-trained GRU model
|
||||
load_run_id = gru_cfg.get('model_load_run_id', None)
|
||||
@ -1357,8 +1357,8 @@ class TradingPipeline:
|
||||
except Exception as e:
|
||||
logger.error(f"Fold {self.current_fold}: Error during temperature calibration: {e}", exc_info=True)
|
||||
self.optimal_T = None
|
||||
else:
|
||||
logger.warning(f"Fold {self.current_fold}: Calibration method '{calibration_method}' or ternary state mismatch. Skipping GRU validation checks.")
|
||||
else: # Covers cases where calibration method is wrong or mismatch with ternary state
|
||||
logger.warning(f"Fold {self.current_fold}: Calibration method '{calibration_method}' or ternary state mismatch ({is_ternary_check}). Skipping GRU validation checks.")
|
||||
self.optimal_T = None
|
||||
self.vector_cal_params = None
|
||||
|
||||
@ -1368,60 +1368,74 @@ class TradingPipeline:
|
||||
self.optimized_edge_threshold = None # Reset for the fold
|
||||
|
||||
if optimize_edge:
|
||||
logger.info(f"Optimizing edge threshold using Youden's J on validation predictions...")
|
||||
try:
|
||||
# Prepare y_true for optimization (needs binary 0/1)
|
||||
if y_dir_val_to_check is None:
|
||||
raise ValueError("Cannot optimize edge threshold without valid y_dir_val.")
|
||||
y_true_for_opt = None
|
||||
p_cal_for_opt = None
|
||||
if is_ternary_check:
|
||||
if p_cal_val_to_check is not None:
|
||||
# Convert ternary to binary: P(up) vs others
|
||||
p_cal_for_opt = p_cal_val_to_check[:, -1] # P(up)
|
||||
y_true_for_opt = (np.argmax(y_dir_val_to_check, axis=1) == 2).astype(int)
|
||||
else:
|
||||
raise ValueError("Cannot optimize ternary edge threshold without valid calibrated probabilities.")
|
||||
else:
|
||||
# Binary case
|
||||
if p_cal_val_to_check is not None:
|
||||
p_cal_for_opt = p_cal_val_to_check
|
||||
y_true_for_opt = (np.asarray(y_dir_val_to_check) > 0.5).astype(int)
|
||||
else:
|
||||
raise ValueError("Cannot optimize binary edge threshold without valid calibrated probabilities.")
|
||||
logger.info(f"Optimizing edge threshold using Youden's J on validation predictions...")
|
||||
try:
|
||||
# Prepare y_true for optimization (needs binary 0/1)
|
||||
if y_dir_val_to_check is None:
|
||||
raise ValueError("Cannot optimize edge threshold without valid y_dir_val.")
|
||||
y_true_for_opt = None
|
||||
p_cal_for_opt = None
|
||||
if is_ternary_check:
|
||||
if p_cal_val_to_check is not None:
|
||||
# Convert ternary to binary: P(up) vs others
|
||||
p_cal_for_opt = p_cal_val_to_check[:, -1] # P(up)
|
||||
y_true_for_opt = (np.argmax(y_dir_val_to_check, axis=1) == 2).astype(int)
|
||||
else:
|
||||
raise ValueError("Cannot optimize ternary edge threshold without valid calibrated probabilities.")
|
||||
else:
|
||||
# Binary case
|
||||
if p_cal_val_to_check is not None:
|
||||
p_cal_for_opt = p_cal_val_to_check
|
||||
y_true_for_opt = (np.asarray(y_dir_val_to_check) > 0.5).astype(int)
|
||||
else:
|
||||
raise ValueError("Cannot optimize binary edge threshold without valid calibrated probabilities.")
|
||||
|
||||
# Perform optimization using the dedicated function from metrics
|
||||
# Note: This assumes Calibrator.optimize_edge_threshold was removed or is not used here
|
||||
# Ensure _calculate_optimal_edge_threshold is imported
|
||||
self.optimized_edge_threshold = _calculate_optimal_edge_threshold(y_true_for_opt, p_cal_for_opt)
|
||||
# Perform optimization using the dedicated function from metrics
|
||||
# Note: This assumes Calibrator.optimize_edge_threshold was removed or is not used here
|
||||
# Ensure _calculate_optimal_edge_threshold is imported
|
||||
self.optimized_edge_threshold = _calculate_optimal_edge_threshold(y_true_for_opt, p_cal_for_opt)
|
||||
|
||||
if self.optimized_edge_threshold is not None:
|
||||
logger.info(f"Optimized edge threshold: {self.optimized_edge_threshold:.4f}")
|
||||
# Save optimized threshold
|
||||
thresh_file = f"optimized_edge_threshold_fold_{self.current_fold}.txt"
|
||||
try:
|
||||
# Ensure fold_results_dir is defined (should be available from context)
|
||||
if self.io and fold_results_dir:
|
||||
self.io.save_json({'optimized_edge_threshold': self.optimized_edge_threshold},
|
||||
thresh_file.replace('.txt',''), # Use filename as key for io
|
||||
base_dir=fold_results_dir, use_txt=True)
|
||||
logger.info(f"Saved optimized edge threshold to {thresh_file}")
|
||||
elif self.io:
|
||||
self.io.save_json({'optimized_edge_threshold': self.optimized_edge_threshold},
|
||||
thresh_file.replace('.txt',''),
|
||||
section='results', use_txt=True) # Fallback save
|
||||
logger.info(f"Saved optimized edge threshold to main results dir: {thresh_file}")
|
||||
else:
|
||||
logger.warning("IOManager not available, cannot save optimized threshold.")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save optimized edge threshold: {e}")
|
||||
else:
|
||||
logger.warning("Edge threshold optimization failed or was skipped. Using config default.")
|
||||
self.optimized_edge_threshold = edge_thr_config # Fallback
|
||||
if self.optimized_edge_threshold is not None:
|
||||
logger.info(f"Optimized edge threshold: {self.optimized_edge_threshold:.4f}")
|
||||
# Save optimized threshold
|
||||
thresh_file = f"optimized_edge_threshold_fold_{self.current_fold}.txt"
|
||||
try:
|
||||
# Ensure fold_results_dir is defined (should be available from context)
|
||||
# Assuming fold_results_dir is defined in the outer scope
|
||||
if self.io and 'fold_results_dir' in locals() and fold_results_dir:
|
||||
self.io.save_json({'optimized_edge_threshold': self.optimized_edge_threshold},
|
||||
thresh_file.replace('.txt',''), # Use filename as key for io
|
||||
base_dir=fold_results_dir, use_txt=True)
|
||||
logger.info(f"Saved optimized edge threshold to {os.path.join(fold_results_dir, thresh_file)}")
|
||||
elif self.io:
|
||||
# Fallback: Save to the main results directory if fold_results_dir isn't available
|
||||
# Construct the path for logging clarity
|
||||
fallback_path = os.path.join(self.io.get_section_path('results'), thresh_file)
|
||||
self.io.save_json({'optimized_edge_threshold': self.optimized_edge_threshold},
|
||||
thresh_file.replace('.txt',''),
|
||||
section='results', use_txt=True) # Fallback save
|
||||
logger.info(f"Saved optimized edge threshold to main results dir: {fallback_path}")
|
||||
else:
|
||||
logger.warning("IOManager not available, cannot save optimized threshold.")
|
||||
except NameError: # Specifically catch if fold_results_dir is not defined
|
||||
logger.warning("fold_results_dir not defined. Attempting fallback save to main results dir.")
|
||||
if self.io:
|
||||
fallback_path = os.path.join(self.io.get_section_path('results'), thresh_file)
|
||||
self.io.save_json({'optimized_edge_threshold': self.optimized_edge_threshold},
|
||||
thresh_file.replace('.txt',''),
|
||||
section='results', use_txt=True)
|
||||
logger.info(f"Saved optimized edge threshold to main results dir: {fallback_path}")
|
||||
else:
|
||||
logger.warning("IOManager not available, cannot save optimized threshold.")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save optimized edge threshold: {e}")
|
||||
else:
|
||||
logger.warning("Edge threshold optimization failed or was skipped. Using config default.")
|
||||
self.optimized_edge_threshold = edge_thr_config # Fallback
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during edge threshold optimization: {e}", exc_info=True)
|
||||
self.optimized_edge_threshold = edge_thr_config # Fallback
|
||||
except Exception as e:
|
||||
logger.error(f"Error during edge threshold optimization: {e}", exc_info=True)
|
||||
self.optimized_edge_threshold = edge_thr_config # Fallback
|
||||
else:
|
||||
# If optimization is disabled, store the config threshold for consistent use
|
||||
self.optimized_edge_threshold = edge_thr_config
|
||||
@ -1430,14 +1444,14 @@ class TradingPipeline:
|
||||
|
||||
# --- Perform GRU Validation Checks using the threshold stored in self.optimized_edge_threshold --- #
|
||||
if p_cal_val_to_check is not None and y_dir_val_to_check is not None:
|
||||
self._perform_gru_validation_checks(
|
||||
p_cal_val=p_cal_val_to_check,
|
||||
y_dir_val=y_dir_val_to_check,
|
||||
is_ternary=is_ternary_check
|
||||
)
|
||||
self._perform_gru_validation_checks(
|
||||
p_cal_val=p_cal_val_to_check,
|
||||
y_dir_val=y_dir_val_to_check,
|
||||
is_ternary=is_ternary_check
|
||||
)
|
||||
# Note: _perform_gru_validation_checks was already modified to use self.optimized_edge_threshold
|
||||
else:
|
||||
logger.warning("Could not perform GRU validation checks due to missing calibrated predictions or labels.")
|
||||
logger.warning("Could not perform GRU validation checks due to missing calibrated predictions or labels.")
|
||||
|
||||
# --- Helper for GRU Validation Checks (Replaces edge check) --- #
|
||||
def _perform_gru_validation_checks(self, p_cal_val, y_dir_val, is_ternary):
|
||||
@ -1467,7 +1481,6 @@ class TradingPipeline:
|
||||
calib_config = self.config.get('calibration', {})
|
||||
optimize_edge = calib_config.get('optimize_edge_threshold', False)
|
||||
edge_thr_config = calib_config.get('edge_threshold', 0.1) # Default/fallback
|
||||
|
||||
self.fold_edge_threshold = edge_thr_config # Initialize with config value
|
||||
|
||||
if optimize_edge and not is_ternary:
|
||||
@ -1483,7 +1496,7 @@ class TradingPipeline:
|
||||
self.io.save_json({'optimized_edge_threshold': self.fold_edge_threshold},
|
||||
f'optimized_edge_threshold_fold_{self.current_fold}',
|
||||
base_dir=fold_results_dir, use_txt=True)
|
||||
else:
|
||||
else:
|
||||
logger.warning(f"Fold {self.current_fold}: Could not save optimized edge threshold, results dir missing.")
|
||||
elif optimize_edge and is_ternary:
|
||||
logger.warning(f"Fold {self.current_fold}: Edge threshold optimization requested but not supported for ternary. Using config value: {edge_thr_config:.4f}")
|
||||
@ -1527,13 +1540,13 @@ class TradingPipeline:
|
||||
passed_edge_acc = ci_lower >= edge_check_thr
|
||||
logger.info(f"Fold {self.current_fold}: Edge Acc Check (edge >= {self.fold_edge_threshold:.2f}): Acc={edge_accuracy:.3f} ({k_correct}/{n_filtered}), 95% CI Lower={ci_lower:.3f} >= {edge_check_thr} -> {'Pass' if passed_edge_acc else 'FAIL'}")
|
||||
except ValueError as binom_err:
|
||||
logger.error(f"Fold {self.current_fold}: Edge Acc Check: Error calculating binomial test (k={k_correct}, n={n_filtered}): {binom_err}. Check considered FAIL.")
|
||||
passed_edge_acc = False # Consider error as failure
|
||||
logger.error(f"Fold {self.current_fold}: Edge Acc Check: Error calculating binomial test (k={k_correct}, n={n_filtered}): {binom_err}. Check considered FAIL.")
|
||||
passed_edge_acc = False # Consider error as failure
|
||||
else:
|
||||
logger.error(f"Fold {self.current_fold}: Edge Acc Check: Calculation failed (NaN). Check considered FAIL.")
|
||||
if n_filtered == 0:
|
||||
logger.error(f" Reason: No validation samples met the edge threshold >= {self.fold_edge_threshold:.2f}")
|
||||
passed_edge_acc = False # Consider NaN or 0 samples as failure
|
||||
logger.error(f"Fold {self.current_fold}: Edge Acc Check: Calculation failed (NaN). Check considered FAIL.")
|
||||
if n_filtered == 0:
|
||||
logger.error(f" Reason: No validation samples met the edge threshold >= {self.fold_edge_threshold:.2f}")
|
||||
passed_edge_acc = False # Consider NaN or 0 samples as failure
|
||||
except Exception as e:
|
||||
logger.error(f"Fold {self.current_fold}: Edge Acc Check: Unexpected error during calculation: {e}. Check considered FAIL.", exc_info=True)
|
||||
passed_edge_acc = False # Consider error as failure
|
||||
@ -1567,8 +1580,8 @@ class TradingPipeline:
|
||||
logger.error(error_msg)
|
||||
# Use sys.exit with a specific message for clarity
|
||||
sys.exit(f"Fold {self.current_fold}: GRU validation gates failed (Edge Acc / Brier Score).")
|
||||
else:
|
||||
logger.info(f"Fold {self.current_fold}: GRU validation checks passed (Edge Acc & Brier Score).")
|
||||
else: # Corrected indentation
|
||||
logger.info(f"Fold {self.current_fold}: GRU validation checks passed (Edge Acc & Brier Score).") # Corrected indentation
|
||||
# --- End Validation Helper --- #
|
||||
|
||||
def train_or_load_sac(self):
|
||||
@ -1791,6 +1804,7 @@ class TradingPipeline:
|
||||
# Get raw predictions needed for rolling calibration
|
||||
p_raw_test_for_bt = None
|
||||
logits_test_for_bt = None
|
||||
is_ternary = self.config.get('gru', {}).get('use_ternary_output', False) # Need to know if ternary
|
||||
if self.config.get('calibration', {}).get('rolling_enabled', False):
|
||||
logger.info(f"Fold {self.current_fold}: Getting raw GRU outputs for rolling calibration...")
|
||||
if is_ternary:
|
||||
@ -1798,43 +1812,43 @@ class TradingPipeline:
|
||||
if logits_test_for_bt is None:
|
||||
logger.error(f"Fold {self.current_fold}: Failed to get GRU logits for rolling calibration.")
|
||||
raise SystemExit(f"Fold {self.current_fold}: Failed GRU logit prediction.")
|
||||
else:
|
||||
else: # Corrected indentation
|
||||
preds_test_raw = self.gru_handler.predict(self.X_test_seq)
|
||||
if preds_test_raw is None or len(preds_test_raw) < 3:
|
||||
logger.error(f"Fold {self.current_fold}: Failed to get GRU raw predictions for rolling calibration.")
|
||||
raise SystemExit(f"Fold {self.current_fold}: Failed GRU raw prediction.")
|
||||
p_raw_test_for_bt = preds_test_raw[2].flatten()
|
||||
p_raw_test_for_bt = preds_test_raw[2].flatten() # Assuming index 2 is probabilities
|
||||
|
||||
# Get the edge threshold determined during validation (optimized or fixed)
|
||||
edge_threshold_for_bt = getattr(self, 'fold_edge_threshold', self.config.get('calibration', {}).get('edge_threshold', 0.1))
|
||||
logger.info(f"Fold {self.current_fold}: Using edge threshold {edge_threshold_for_bt:.4f} for backtest execution.")
|
||||
|
||||
try:
|
||||
try: # Corrected indentation
|
||||
self.backtest_results_df, self.backtest_metrics, self.metrics_log_df = self.backtester.run_backtest(
|
||||
sac_agent_load_path=self.sac_agent_load_path,
|
||||
X_test_seq=self.X_test_seq,
|
||||
y_test_seq_dict=self.y_test_seq_dict,
|
||||
test_indices=self.test_indices,
|
||||
gru_handler=self.gru_handler,
|
||||
# --- Pass Calibrator instances and initial state --- #
|
||||
calibrator=calibrator_instance,
|
||||
vector_calibrator=vector_calibrator_instance,
|
||||
initial_optimal_T=getattr(self, 'optimal_T', None), # Pass T if exists
|
||||
initial_vector_params=getattr(self, 'vector_cal_params', None), # Pass params if exists
|
||||
fold_edge_threshold=edge_threshold_for_bt,
|
||||
# --- Pass raw predictions if needed for rolling cal --- #
|
||||
p_raw_test=p_raw_test_for_bt,
|
||||
logits_test=logits_test_for_bt,
|
||||
# --- Pass original prices --- #
|
||||
original_prices=self.df_test_original, # Pass the DataFrame
|
||||
is_ternary=self.use_ternary,
|
||||
fold_num=self.current_fold
|
||||
)
|
||||
except SystemExit as e:
|
||||
# --- Pass Calibrator instances and initial state --- #
|
||||
calibrator=calibrator_instance,
|
||||
vector_calibrator=vector_calibrator_instance,
|
||||
initial_optimal_T=getattr(self, 'optimal_T', None), # Pass T if exists
|
||||
initial_vector_params=getattr(self, 'vector_cal_params', None), # Pass params if exists
|
||||
fold_edge_threshold=edge_threshold_for_bt,
|
||||
# --- Pass raw predictions if needed for rolling cal --- #
|
||||
p_raw_test=p_raw_test_for_bt,
|
||||
logits_test=logits_test_for_bt,
|
||||
# --- Pass original prices --- #
|
||||
original_prices=self.df_test_original, # Pass the DataFrame
|
||||
is_ternary=self.use_ternary,
|
||||
fold_num=self.current_fold
|
||||
)
|
||||
except SystemExit as e: # Corrected indentation
|
||||
# Catch exits from backtester validation/execution
|
||||
logger.error(f"Fold {self.current_fold}: Backtest aborted: {e}")
|
||||
raise # Re-raise to stop the fold
|
||||
except Exception as e:
|
||||
except Exception as e: # Corrected indentation
|
||||
logger.error(f"Fold {self.current_fold}: Unhandled error during backtester.run_backtest: {e}", exc_info=True)
|
||||
# Treat as failure, ensure metrics are None
|
||||
self.backtest_results_df = None
|
||||
@ -2742,7 +2756,7 @@ class TradingPipeline:
|
||||
except Exception as e:
|
||||
logger.error(f"Error during SAC agent weight averaging or saving: {e}", exc_info=True)
|
||||
|
||||
# --- Entry Point --- #
|
||||
# --- Entry Point --- #
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run the GRU-SAC Trading Pipeline.")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user