iundentation fix pipeline
This commit is contained in:
parent
a02dd4342e
commit
a86cdb2c8f
@ -227,7 +227,7 @@ The `config.yaml` file centrally controls the pipeline's behavior. Key sections
|
|||||||
1. **Setup:** Install requirements (`pip install -r requirements.txt`), prepare data in the specified format/location.
|
1. **Setup:** Install requirements (`pip install -r requirements.txt`), prepare data in the specified format/location.
|
||||||
2. **Configure:** Edit `config.yaml` (data paths, feature lists, model params, control flags, walk-forward settings, calibration, validation thresholds, tuning, aggregation).
|
2. **Configure:** Edit `config.yaml` (data paths, feature lists, model params, control flags, walk-forward settings, calibration, validation thresholds, tuning, aggregation).
|
||||||
3. **Run Pipeline:**
|
3. **Run Pipeline:**
|
||||||
```bash
|
```bash
|
||||||
# From project root (develop/gru_sac_predictor/)
|
# From project root (develop/gru_sac_predictor/)
|
||||||
python gru_sac_predictor/run.py --config path/to/your_config.yaml
|
python gru_sac_predictor/run.py --config path/to/your_config.yaml
|
||||||
```
|
```
|
||||||
|
|||||||
@ -455,7 +455,7 @@ class Backtester:
|
|||||||
logger.info(f" New optimal temperature: {new_T:.4f} (Previous: {current_optimal_T:.4f})")
|
logger.info(f" New optimal temperature: {new_T:.4f} (Previous: {current_optimal_T:.4f})")
|
||||||
current_optimal_T = new_T
|
current_optimal_T = new_T
|
||||||
calibrator.optimal_T = new_T # Update instance
|
calibrator.optimal_T = new_T # Update instance
|
||||||
else:
|
else:
|
||||||
logger.info(f" Temperature unchanged or optimization failed (T={current_optimal_T:.4f}).")
|
logger.info(f" Temperature unchanged or optimization failed (T={current_optimal_T:.4f}).")
|
||||||
|
|
||||||
last_recalib_step = i
|
last_recalib_step = i
|
||||||
|
|||||||
@ -1094,116 +1094,116 @@ class TradingPipeline:
|
|||||||
patience=patience
|
patience=patience
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.gru_model is None:
|
if self.gru_model is None:
|
||||||
logging.error("GRU model training failed. Exiting.")
|
logging.error("GRU model training failed. Exiting.")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
else:
|
|
||||||
# Save the newly trained model
|
|
||||||
saved_path = self.gru_handler.save() # Uses run_id from handler
|
|
||||||
if saved_path:
|
|
||||||
logging.info(f"Newly trained GRU model saved to {saved_path}")
|
|
||||||
else:
|
else:
|
||||||
logging.warning("Failed to save the newly trained GRU model.")
|
# Save the newly trained model
|
||||||
# Set the loaded ID to the current run ID
|
saved_path = self.gru_handler.save() # Uses run_id from handler
|
||||||
self.gru_model_run_id_loaded_from = self.run_id
|
if saved_path:
|
||||||
logging.info(f"Using GRU model trained in current run: {self.run_id}")
|
logging.info(f"Newly trained GRU model saved to {saved_path}")
|
||||||
|
|
||||||
# --- V3 Output Contract: Plot Learning Curve --- #
|
|
||||||
if self.io and history is not None and self.config.get('control', {}).get('generate_plots', True):
|
|
||||||
# Infer log dir path based on current models dir
|
|
||||||
log_dir = os.path.dirname(self.current_run_models_dir).replace('/models', '/logs')
|
|
||||||
csv_log_path = os.path.join(log_dir, 'gru_history.csv')
|
|
||||||
if os.path.exists(csv_log_path):
|
|
||||||
logging.info(f"Plotting learning curve from {csv_log_path}...")
|
|
||||||
try:
|
|
||||||
history_df = pd.read_csv(csv_log_path)
|
|
||||||
|
|
||||||
# Determine metric keys (handle v2 vs v3 differences if necessary)
|
|
||||||
loss_key = 'loss'
|
|
||||||
val_loss_key = 'val_loss'
|
|
||||||
acc_key = None
|
|
||||||
val_acc_key = None
|
|
||||||
if 'dir3_accuracy' in history_df.columns: # V3 specific?
|
|
||||||
acc_key = 'dir3_accuracy'
|
|
||||||
val_acc_key = 'val_dir3_accuracy'
|
|
||||||
elif 'accuracy' in history_df.columns: # V2 or other?
|
|
||||||
acc_key = 'accuracy'
|
|
||||||
val_acc_key = 'val_accuracy'
|
|
||||||
|
|
||||||
if acc_key is None:
|
|
||||||
logging.warning("Could not find a suitable accuracy metric in history CSV for plotting.")
|
|
||||||
n_panes = 1 # Only plot loss
|
|
||||||
else:
|
|
||||||
n_panes = 2 # Plot loss and accuracy
|
|
||||||
|
|
||||||
# Get figure settings
|
|
||||||
fig_dpi = self.config.get('output', {}).get('figure_dpi', 150)
|
|
||||||
fig_size = self.config.get('output', {}).get('figure_size', [16, 9])
|
|
||||||
footer_text = "© GRU-SAC v3"
|
|
||||||
|
|
||||||
plt.style.use('seaborn-v0_8-darkgrid')
|
|
||||||
# Adjust figsize height based on panes
|
|
||||||
adjusted_fig_height = fig_size[1] * (n_panes / 3.0) # Rough scaling
|
|
||||||
fig, axes = plt.subplots(n_panes, 1, figsize=(fig_size[0], adjusted_fig_height), sharex=True)
|
|
||||||
|
|
||||||
if n_panes == 1:
|
|
||||||
ax_loss = axes # Single axis
|
|
||||||
else:
|
|
||||||
ax_loss, ax_acc = axes # Multiple axes
|
|
||||||
|
|
||||||
epochs = history_df['epoch'] + 1 # epochs are 0-indexed in csv
|
|
||||||
|
|
||||||
# Pane 1: Loss (Log Scale)
|
|
||||||
ax_loss.plot(epochs, history_df[loss_key], label='Training Loss')
|
|
||||||
ax_loss.plot(epochs, history_df[val_loss_key], label='Validation Loss')
|
|
||||||
ax_loss.set_yscale('log')
|
|
||||||
ax_loss.set_ylabel('Loss (Log Scale)')
|
|
||||||
ax_loss.legend()
|
|
||||||
ax_loss.set_title('GRU Model Training Progress', fontsize=16)
|
|
||||||
ax_loss.grid(True, which="both", ls="--", linewidth=0.5)
|
|
||||||
|
|
||||||
# Pane 2: Accuracy (if available)
|
|
||||||
if n_panes == 2:
|
|
||||||
ax_acc.plot(epochs, history_df[acc_key], label=f'Training {acc_key}')
|
|
||||||
ax_acc.plot(epochs, history_df[val_acc_key], label=f'Validation {val_acc_key}')
|
|
||||||
ax_acc.set_ylabel('Accuracy')
|
|
||||||
ax_acc.set_xlabel('Epoch')
|
|
||||||
ax_acc.legend()
|
|
||||||
ax_acc.grid(True, which="both", ls="--", linewidth=0.5)
|
|
||||||
else:
|
|
||||||
# If only loss pane, set xlabel there
|
|
||||||
ax_loss.set_xlabel('Epoch')
|
|
||||||
|
|
||||||
# Add vertical line for early stopping epoch if available
|
|
||||||
if hasattr(history, 'epoch') and len(history.epoch) > 0:
|
|
||||||
# Early stopping epoch is the number of epochs run
|
|
||||||
early_stop_epoch = len(history.epoch)
|
|
||||||
if early_stop_epoch < max_epochs: # Only draw if early stopping occurred
|
|
||||||
for ax in fig.axes:
|
|
||||||
ax.axvline(x=early_stop_epoch, color='r', linestyle='--', linewidth=1, label=f'Early Stop @ {early_stop_epoch}')
|
|
||||||
# Add legend entry to the last plot
|
|
||||||
fig.axes[-1].legend()
|
|
||||||
|
|
||||||
# Add footer
|
|
||||||
plt.figtext(0.99, 0.01, footer_text, horizontalalignment='right',
|
|
||||||
verticalalignment='bottom', fontsize=8, color='gray')
|
|
||||||
|
|
||||||
plt.tight_layout(rect=[0, 0.03, 1, 0.97]) # Adjust layout
|
|
||||||
|
|
||||||
# Save figure using IOManager
|
|
||||||
self.io.save_figure(fig, "gru_learning_curve", section='results')
|
|
||||||
logging.info("GRU learning curve plot saved.")
|
|
||||||
plt.close(fig)
|
|
||||||
|
|
||||||
except FileNotFoundError:
|
|
||||||
logging.warning(f"GRU history file not found at {csv_log_path}. Cannot plot learning curve.")
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(f"Failed to plot GRU learning curve: {e}", exc_info=True)
|
|
||||||
else:
|
else:
|
||||||
logging.warning(f"GRU history file not found at {csv_log_path}. Cannot plot learning curve.")
|
logging.warning("Failed to save the newly trained GRU model.")
|
||||||
elif not self.io:
|
# Set the loaded ID to the current run ID
|
||||||
logging.warning("IOManager not available, skipping GRU learning curve plot.")
|
self.gru_model_run_id_loaded_from = self.run_id
|
||||||
# --- End Plot Learning Curve --- #
|
logging.info(f"Using GRU model trained in current run: {self.run_id}")
|
||||||
|
|
||||||
|
# --- V3 Output Contract: Plot Learning Curve --- #
|
||||||
|
if self.io and history is not None and self.config.get('control', {}).get('generate_plots', True):
|
||||||
|
# Infer log dir path based on current models dir
|
||||||
|
log_dir = os.path.dirname(self.current_run_models_dir).replace('/models', '/logs')
|
||||||
|
csv_log_path = os.path.join(log_dir, 'gru_history.csv')
|
||||||
|
if os.path.exists(csv_log_path):
|
||||||
|
logging.info(f"Plotting learning curve from {csv_log_path}...")
|
||||||
|
try:
|
||||||
|
history_df = pd.read_csv(csv_log_path)
|
||||||
|
|
||||||
|
# Determine metric keys (handle v2 vs v3 differences if necessary)
|
||||||
|
loss_key = 'loss'
|
||||||
|
val_loss_key = 'val_loss'
|
||||||
|
acc_key = None
|
||||||
|
val_acc_key = None
|
||||||
|
if 'dir3_accuracy' in history_df.columns: # V3 specific?
|
||||||
|
acc_key = 'dir3_accuracy'
|
||||||
|
val_acc_key = 'val_dir3_accuracy'
|
||||||
|
elif 'accuracy' in history_df.columns: # V2 or other?
|
||||||
|
acc_key = 'accuracy'
|
||||||
|
val_acc_key = 'val_accuracy'
|
||||||
|
|
||||||
|
if acc_key is None:
|
||||||
|
logging.warning("Could not find a suitable accuracy metric in history CSV for plotting.")
|
||||||
|
n_panes = 1 # Only plot loss
|
||||||
|
else:
|
||||||
|
n_panes = 2 # Plot loss and accuracy
|
||||||
|
|
||||||
|
# Get figure settings
|
||||||
|
fig_dpi = self.config.get('output', {}).get('figure_dpi', 150)
|
||||||
|
fig_size = self.config.get('output', {}).get('figure_size', [16, 9])
|
||||||
|
footer_text = "© GRU-SAC v3"
|
||||||
|
|
||||||
|
plt.style.use('seaborn-v0_8-darkgrid')
|
||||||
|
# Adjust figsize height based on panes
|
||||||
|
adjusted_fig_height = fig_size[1] * (n_panes / 3.0) # Rough scaling
|
||||||
|
fig, axes = plt.subplots(n_panes, 1, figsize=(fig_size[0], adjusted_fig_height), sharex=True)
|
||||||
|
|
||||||
|
if n_panes == 1:
|
||||||
|
ax_loss = axes # Single axis
|
||||||
|
else:
|
||||||
|
ax_loss, ax_acc = axes # Multiple axes
|
||||||
|
|
||||||
|
epochs = history_df['epoch'] + 1 # epochs are 0-indexed in csv
|
||||||
|
|
||||||
|
# Pane 1: Loss (Log Scale)
|
||||||
|
ax_loss.plot(epochs, history_df[loss_key], label='Training Loss')
|
||||||
|
ax_loss.plot(epochs, history_df[val_loss_key], label='Validation Loss')
|
||||||
|
ax_loss.set_yscale('log')
|
||||||
|
ax_loss.set_ylabel('Loss (Log Scale)')
|
||||||
|
ax_loss.legend()
|
||||||
|
ax_loss.set_title('GRU Model Training Progress', fontsize=16)
|
||||||
|
ax_loss.grid(True, which="both", ls="--", linewidth=0.5)
|
||||||
|
|
||||||
|
# Pane 2: Accuracy (if available)
|
||||||
|
if n_panes == 2:
|
||||||
|
ax_acc.plot(epochs, history_df[acc_key], label=f'Training {acc_key}')
|
||||||
|
ax_acc.plot(epochs, history_df[val_acc_key], label=f'Validation {val_acc_key}')
|
||||||
|
ax_acc.set_ylabel('Accuracy')
|
||||||
|
ax_acc.set_xlabel('Epoch')
|
||||||
|
ax_acc.legend()
|
||||||
|
ax_acc.grid(True, which="both", ls="--", linewidth=0.5)
|
||||||
|
else:
|
||||||
|
# If only loss pane, set xlabel there
|
||||||
|
ax_loss.set_xlabel('Epoch')
|
||||||
|
|
||||||
|
# Add vertical line for early stopping epoch if available
|
||||||
|
if hasattr(history, 'epoch') and len(history.epoch) > 0:
|
||||||
|
# Early stopping epoch is the number of epochs run
|
||||||
|
early_stop_epoch = len(history.epoch)
|
||||||
|
if early_stop_epoch < max_epochs: # Only draw if early stopping occurred
|
||||||
|
for ax in fig.axes:
|
||||||
|
ax.axvline(x=early_stop_epoch, color='r', linestyle='--', linewidth=1, label=f'Early Stop @ {early_stop_epoch}')
|
||||||
|
# Add legend entry to the last plot
|
||||||
|
fig.axes[-1].legend()
|
||||||
|
|
||||||
|
# Add footer
|
||||||
|
plt.figtext(0.99, 0.01, footer_text, horizontalalignment='right',
|
||||||
|
verticalalignment='bottom', fontsize=8, color='gray')
|
||||||
|
|
||||||
|
plt.tight_layout(rect=[0, 0.03, 1, 0.97]) # Adjust layout
|
||||||
|
|
||||||
|
# Save figure using IOManager
|
||||||
|
self.io.save_figure(fig, "gru_learning_curve", section='results')
|
||||||
|
logging.info("GRU learning curve plot saved.")
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
|
except FileNotFoundError:
|
||||||
|
logging.warning(f"GRU history file not found at {csv_log_path}. Cannot plot learning curve.")
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Failed to plot GRU learning curve: {e}", exc_info=True)
|
||||||
|
else:
|
||||||
|
logging.warning(f"GRU history file not found at {csv_log_path}. Cannot plot learning curve.")
|
||||||
|
elif not self.io:
|
||||||
|
logging.warning("IOManager not available, skipping GRU learning curve plot.")
|
||||||
|
# --- End Plot Learning Curve --- #
|
||||||
|
|
||||||
else: # Load pre-trained GRU model
|
else: # Load pre-trained GRU model
|
||||||
load_run_id = gru_cfg.get('model_load_run_id', None)
|
load_run_id = gru_cfg.get('model_load_run_id', None)
|
||||||
@ -1357,8 +1357,8 @@ class TradingPipeline:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Fold {self.current_fold}: Error during temperature calibration: {e}", exc_info=True)
|
logger.error(f"Fold {self.current_fold}: Error during temperature calibration: {e}", exc_info=True)
|
||||||
self.optimal_T = None
|
self.optimal_T = None
|
||||||
else:
|
else: # Covers cases where calibration method is wrong or mismatch with ternary state
|
||||||
logger.warning(f"Fold {self.current_fold}: Calibration method '{calibration_method}' or ternary state mismatch. Skipping GRU validation checks.")
|
logger.warning(f"Fold {self.current_fold}: Calibration method '{calibration_method}' or ternary state mismatch ({is_ternary_check}). Skipping GRU validation checks.")
|
||||||
self.optimal_T = None
|
self.optimal_T = None
|
||||||
self.vector_cal_params = None
|
self.vector_cal_params = None
|
||||||
|
|
||||||
@ -1368,60 +1368,74 @@ class TradingPipeline:
|
|||||||
self.optimized_edge_threshold = None # Reset for the fold
|
self.optimized_edge_threshold = None # Reset for the fold
|
||||||
|
|
||||||
if optimize_edge:
|
if optimize_edge:
|
||||||
logger.info(f"Optimizing edge threshold using Youden's J on validation predictions...")
|
logger.info(f"Optimizing edge threshold using Youden's J on validation predictions...")
|
||||||
try:
|
try:
|
||||||
# Prepare y_true for optimization (needs binary 0/1)
|
# Prepare y_true for optimization (needs binary 0/1)
|
||||||
if y_dir_val_to_check is None:
|
if y_dir_val_to_check is None:
|
||||||
raise ValueError("Cannot optimize edge threshold without valid y_dir_val.")
|
raise ValueError("Cannot optimize edge threshold without valid y_dir_val.")
|
||||||
y_true_for_opt = None
|
y_true_for_opt = None
|
||||||
p_cal_for_opt = None
|
p_cal_for_opt = None
|
||||||
if is_ternary_check:
|
if is_ternary_check:
|
||||||
if p_cal_val_to_check is not None:
|
if p_cal_val_to_check is not None:
|
||||||
# Convert ternary to binary: P(up) vs others
|
# Convert ternary to binary: P(up) vs others
|
||||||
p_cal_for_opt = p_cal_val_to_check[:, -1] # P(up)
|
p_cal_for_opt = p_cal_val_to_check[:, -1] # P(up)
|
||||||
y_true_for_opt = (np.argmax(y_dir_val_to_check, axis=1) == 2).astype(int)
|
y_true_for_opt = (np.argmax(y_dir_val_to_check, axis=1) == 2).astype(int)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Cannot optimize ternary edge threshold without valid calibrated probabilities.")
|
raise ValueError("Cannot optimize ternary edge threshold without valid calibrated probabilities.")
|
||||||
else:
|
else:
|
||||||
# Binary case
|
# Binary case
|
||||||
if p_cal_val_to_check is not None:
|
if p_cal_val_to_check is not None:
|
||||||
p_cal_for_opt = p_cal_val_to_check
|
p_cal_for_opt = p_cal_val_to_check
|
||||||
y_true_for_opt = (np.asarray(y_dir_val_to_check) > 0.5).astype(int)
|
y_true_for_opt = (np.asarray(y_dir_val_to_check) > 0.5).astype(int)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Cannot optimize binary edge threshold without valid calibrated probabilities.")
|
raise ValueError("Cannot optimize binary edge threshold without valid calibrated probabilities.")
|
||||||
|
|
||||||
# Perform optimization using the dedicated function from metrics
|
# Perform optimization using the dedicated function from metrics
|
||||||
# Note: This assumes Calibrator.optimize_edge_threshold was removed or is not used here
|
# Note: This assumes Calibrator.optimize_edge_threshold was removed or is not used here
|
||||||
# Ensure _calculate_optimal_edge_threshold is imported
|
# Ensure _calculate_optimal_edge_threshold is imported
|
||||||
self.optimized_edge_threshold = _calculate_optimal_edge_threshold(y_true_for_opt, p_cal_for_opt)
|
self.optimized_edge_threshold = _calculate_optimal_edge_threshold(y_true_for_opt, p_cal_for_opt)
|
||||||
|
|
||||||
if self.optimized_edge_threshold is not None:
|
if self.optimized_edge_threshold is not None:
|
||||||
logger.info(f"Optimized edge threshold: {self.optimized_edge_threshold:.4f}")
|
logger.info(f"Optimized edge threshold: {self.optimized_edge_threshold:.4f}")
|
||||||
# Save optimized threshold
|
# Save optimized threshold
|
||||||
thresh_file = f"optimized_edge_threshold_fold_{self.current_fold}.txt"
|
thresh_file = f"optimized_edge_threshold_fold_{self.current_fold}.txt"
|
||||||
try:
|
try:
|
||||||
# Ensure fold_results_dir is defined (should be available from context)
|
# Ensure fold_results_dir is defined (should be available from context)
|
||||||
if self.io and fold_results_dir:
|
# Assuming fold_results_dir is defined in the outer scope
|
||||||
self.io.save_json({'optimized_edge_threshold': self.optimized_edge_threshold},
|
if self.io and 'fold_results_dir' in locals() and fold_results_dir:
|
||||||
thresh_file.replace('.txt',''), # Use filename as key for io
|
self.io.save_json({'optimized_edge_threshold': self.optimized_edge_threshold},
|
||||||
base_dir=fold_results_dir, use_txt=True)
|
thresh_file.replace('.txt',''), # Use filename as key for io
|
||||||
logger.info(f"Saved optimized edge threshold to {thresh_file}")
|
base_dir=fold_results_dir, use_txt=True)
|
||||||
elif self.io:
|
logger.info(f"Saved optimized edge threshold to {os.path.join(fold_results_dir, thresh_file)}")
|
||||||
self.io.save_json({'optimized_edge_threshold': self.optimized_edge_threshold},
|
elif self.io:
|
||||||
thresh_file.replace('.txt',''),
|
# Fallback: Save to the main results directory if fold_results_dir isn't available
|
||||||
section='results', use_txt=True) # Fallback save
|
# Construct the path for logging clarity
|
||||||
logger.info(f"Saved optimized edge threshold to main results dir: {thresh_file}")
|
fallback_path = os.path.join(self.io.get_section_path('results'), thresh_file)
|
||||||
else:
|
self.io.save_json({'optimized_edge_threshold': self.optimized_edge_threshold},
|
||||||
logger.warning("IOManager not available, cannot save optimized threshold.")
|
thresh_file.replace('.txt',''),
|
||||||
except Exception as e:
|
section='results', use_txt=True) # Fallback save
|
||||||
logger.error(f"Failed to save optimized edge threshold: {e}")
|
logger.info(f"Saved optimized edge threshold to main results dir: {fallback_path}")
|
||||||
else:
|
else:
|
||||||
logger.warning("Edge threshold optimization failed or was skipped. Using config default.")
|
logger.warning("IOManager not available, cannot save optimized threshold.")
|
||||||
self.optimized_edge_threshold = edge_thr_config # Fallback
|
except NameError: # Specifically catch if fold_results_dir is not defined
|
||||||
|
logger.warning("fold_results_dir not defined. Attempting fallback save to main results dir.")
|
||||||
|
if self.io:
|
||||||
|
fallback_path = os.path.join(self.io.get_section_path('results'), thresh_file)
|
||||||
|
self.io.save_json({'optimized_edge_threshold': self.optimized_edge_threshold},
|
||||||
|
thresh_file.replace('.txt',''),
|
||||||
|
section='results', use_txt=True)
|
||||||
|
logger.info(f"Saved optimized edge threshold to main results dir: {fallback_path}")
|
||||||
|
else:
|
||||||
|
logger.warning("IOManager not available, cannot save optimized threshold.")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to save optimized edge threshold: {e}")
|
||||||
|
else:
|
||||||
|
logger.warning("Edge threshold optimization failed or was skipped. Using config default.")
|
||||||
|
self.optimized_edge_threshold = edge_thr_config # Fallback
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error during edge threshold optimization: {e}", exc_info=True)
|
logger.error(f"Error during edge threshold optimization: {e}", exc_info=True)
|
||||||
self.optimized_edge_threshold = edge_thr_config # Fallback
|
self.optimized_edge_threshold = edge_thr_config # Fallback
|
||||||
else:
|
else:
|
||||||
# If optimization is disabled, store the config threshold for consistent use
|
# If optimization is disabled, store the config threshold for consistent use
|
||||||
self.optimized_edge_threshold = edge_thr_config
|
self.optimized_edge_threshold = edge_thr_config
|
||||||
@ -1430,14 +1444,14 @@ class TradingPipeline:
|
|||||||
|
|
||||||
# --- Perform GRU Validation Checks using the threshold stored in self.optimized_edge_threshold --- #
|
# --- Perform GRU Validation Checks using the threshold stored in self.optimized_edge_threshold --- #
|
||||||
if p_cal_val_to_check is not None and y_dir_val_to_check is not None:
|
if p_cal_val_to_check is not None and y_dir_val_to_check is not None:
|
||||||
self._perform_gru_validation_checks(
|
self._perform_gru_validation_checks(
|
||||||
p_cal_val=p_cal_val_to_check,
|
p_cal_val=p_cal_val_to_check,
|
||||||
y_dir_val=y_dir_val_to_check,
|
y_dir_val=y_dir_val_to_check,
|
||||||
is_ternary=is_ternary_check
|
is_ternary=is_ternary_check
|
||||||
)
|
)
|
||||||
# Note: _perform_gru_validation_checks was already modified to use self.optimized_edge_threshold
|
# Note: _perform_gru_validation_checks was already modified to use self.optimized_edge_threshold
|
||||||
else:
|
else:
|
||||||
logger.warning("Could not perform GRU validation checks due to missing calibrated predictions or labels.")
|
logger.warning("Could not perform GRU validation checks due to missing calibrated predictions or labels.")
|
||||||
|
|
||||||
# --- Helper for GRU Validation Checks (Replaces edge check) --- #
|
# --- Helper for GRU Validation Checks (Replaces edge check) --- #
|
||||||
def _perform_gru_validation_checks(self, p_cal_val, y_dir_val, is_ternary):
|
def _perform_gru_validation_checks(self, p_cal_val, y_dir_val, is_ternary):
|
||||||
@ -1467,7 +1481,6 @@ class TradingPipeline:
|
|||||||
calib_config = self.config.get('calibration', {})
|
calib_config = self.config.get('calibration', {})
|
||||||
optimize_edge = calib_config.get('optimize_edge_threshold', False)
|
optimize_edge = calib_config.get('optimize_edge_threshold', False)
|
||||||
edge_thr_config = calib_config.get('edge_threshold', 0.1) # Default/fallback
|
edge_thr_config = calib_config.get('edge_threshold', 0.1) # Default/fallback
|
||||||
|
|
||||||
self.fold_edge_threshold = edge_thr_config # Initialize with config value
|
self.fold_edge_threshold = edge_thr_config # Initialize with config value
|
||||||
|
|
||||||
if optimize_edge and not is_ternary:
|
if optimize_edge and not is_ternary:
|
||||||
@ -1483,7 +1496,7 @@ class TradingPipeline:
|
|||||||
self.io.save_json({'optimized_edge_threshold': self.fold_edge_threshold},
|
self.io.save_json({'optimized_edge_threshold': self.fold_edge_threshold},
|
||||||
f'optimized_edge_threshold_fold_{self.current_fold}',
|
f'optimized_edge_threshold_fold_{self.current_fold}',
|
||||||
base_dir=fold_results_dir, use_txt=True)
|
base_dir=fold_results_dir, use_txt=True)
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Fold {self.current_fold}: Could not save optimized edge threshold, results dir missing.")
|
logger.warning(f"Fold {self.current_fold}: Could not save optimized edge threshold, results dir missing.")
|
||||||
elif optimize_edge and is_ternary:
|
elif optimize_edge and is_ternary:
|
||||||
logger.warning(f"Fold {self.current_fold}: Edge threshold optimization requested but not supported for ternary. Using config value: {edge_thr_config:.4f}")
|
logger.warning(f"Fold {self.current_fold}: Edge threshold optimization requested but not supported for ternary. Using config value: {edge_thr_config:.4f}")
|
||||||
@ -1527,13 +1540,13 @@ class TradingPipeline:
|
|||||||
passed_edge_acc = ci_lower >= edge_check_thr
|
passed_edge_acc = ci_lower >= edge_check_thr
|
||||||
logger.info(f"Fold {self.current_fold}: Edge Acc Check (edge >= {self.fold_edge_threshold:.2f}): Acc={edge_accuracy:.3f} ({k_correct}/{n_filtered}), 95% CI Lower={ci_lower:.3f} >= {edge_check_thr} -> {'Pass' if passed_edge_acc else 'FAIL'}")
|
logger.info(f"Fold {self.current_fold}: Edge Acc Check (edge >= {self.fold_edge_threshold:.2f}): Acc={edge_accuracy:.3f} ({k_correct}/{n_filtered}), 95% CI Lower={ci_lower:.3f} >= {edge_check_thr} -> {'Pass' if passed_edge_acc else 'FAIL'}")
|
||||||
except ValueError as binom_err:
|
except ValueError as binom_err:
|
||||||
logger.error(f"Fold {self.current_fold}: Edge Acc Check: Error calculating binomial test (k={k_correct}, n={n_filtered}): {binom_err}. Check considered FAIL.")
|
logger.error(f"Fold {self.current_fold}: Edge Acc Check: Error calculating binomial test (k={k_correct}, n={n_filtered}): {binom_err}. Check considered FAIL.")
|
||||||
passed_edge_acc = False # Consider error as failure
|
passed_edge_acc = False # Consider error as failure
|
||||||
else:
|
else:
|
||||||
logger.error(f"Fold {self.current_fold}: Edge Acc Check: Calculation failed (NaN). Check considered FAIL.")
|
logger.error(f"Fold {self.current_fold}: Edge Acc Check: Calculation failed (NaN). Check considered FAIL.")
|
||||||
if n_filtered == 0:
|
if n_filtered == 0:
|
||||||
logger.error(f" Reason: No validation samples met the edge threshold >= {self.fold_edge_threshold:.2f}")
|
logger.error(f" Reason: No validation samples met the edge threshold >= {self.fold_edge_threshold:.2f}")
|
||||||
passed_edge_acc = False # Consider NaN or 0 samples as failure
|
passed_edge_acc = False # Consider NaN or 0 samples as failure
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Fold {self.current_fold}: Edge Acc Check: Unexpected error during calculation: {e}. Check considered FAIL.", exc_info=True)
|
logger.error(f"Fold {self.current_fold}: Edge Acc Check: Unexpected error during calculation: {e}. Check considered FAIL.", exc_info=True)
|
||||||
passed_edge_acc = False # Consider error as failure
|
passed_edge_acc = False # Consider error as failure
|
||||||
@ -1567,8 +1580,8 @@ class TradingPipeline:
|
|||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
# Use sys.exit with a specific message for clarity
|
# Use sys.exit with a specific message for clarity
|
||||||
sys.exit(f"Fold {self.current_fold}: GRU validation gates failed (Edge Acc / Brier Score).")
|
sys.exit(f"Fold {self.current_fold}: GRU validation gates failed (Edge Acc / Brier Score).")
|
||||||
else:
|
else: # Corrected indentation
|
||||||
logger.info(f"Fold {self.current_fold}: GRU validation checks passed (Edge Acc & Brier Score).")
|
logger.info(f"Fold {self.current_fold}: GRU validation checks passed (Edge Acc & Brier Score).") # Corrected indentation
|
||||||
# --- End Validation Helper --- #
|
# --- End Validation Helper --- #
|
||||||
|
|
||||||
def train_or_load_sac(self):
|
def train_or_load_sac(self):
|
||||||
@ -1791,6 +1804,7 @@ class TradingPipeline:
|
|||||||
# Get raw predictions needed for rolling calibration
|
# Get raw predictions needed for rolling calibration
|
||||||
p_raw_test_for_bt = None
|
p_raw_test_for_bt = None
|
||||||
logits_test_for_bt = None
|
logits_test_for_bt = None
|
||||||
|
is_ternary = self.config.get('gru', {}).get('use_ternary_output', False) # Need to know if ternary
|
||||||
if self.config.get('calibration', {}).get('rolling_enabled', False):
|
if self.config.get('calibration', {}).get('rolling_enabled', False):
|
||||||
logger.info(f"Fold {self.current_fold}: Getting raw GRU outputs for rolling calibration...")
|
logger.info(f"Fold {self.current_fold}: Getting raw GRU outputs for rolling calibration...")
|
||||||
if is_ternary:
|
if is_ternary:
|
||||||
@ -1798,43 +1812,43 @@ class TradingPipeline:
|
|||||||
if logits_test_for_bt is None:
|
if logits_test_for_bt is None:
|
||||||
logger.error(f"Fold {self.current_fold}: Failed to get GRU logits for rolling calibration.")
|
logger.error(f"Fold {self.current_fold}: Failed to get GRU logits for rolling calibration.")
|
||||||
raise SystemExit(f"Fold {self.current_fold}: Failed GRU logit prediction.")
|
raise SystemExit(f"Fold {self.current_fold}: Failed GRU logit prediction.")
|
||||||
else:
|
else: # Corrected indentation
|
||||||
preds_test_raw = self.gru_handler.predict(self.X_test_seq)
|
preds_test_raw = self.gru_handler.predict(self.X_test_seq)
|
||||||
if preds_test_raw is None or len(preds_test_raw) < 3:
|
if preds_test_raw is None or len(preds_test_raw) < 3:
|
||||||
logger.error(f"Fold {self.current_fold}: Failed to get GRU raw predictions for rolling calibration.")
|
logger.error(f"Fold {self.current_fold}: Failed to get GRU raw predictions for rolling calibration.")
|
||||||
raise SystemExit(f"Fold {self.current_fold}: Failed GRU raw prediction.")
|
raise SystemExit(f"Fold {self.current_fold}: Failed GRU raw prediction.")
|
||||||
p_raw_test_for_bt = preds_test_raw[2].flatten()
|
p_raw_test_for_bt = preds_test_raw[2].flatten() # Assuming index 2 is probabilities
|
||||||
|
|
||||||
# Get the edge threshold determined during validation (optimized or fixed)
|
# Get the edge threshold determined during validation (optimized or fixed)
|
||||||
edge_threshold_for_bt = getattr(self, 'fold_edge_threshold', self.config.get('calibration', {}).get('edge_threshold', 0.1))
|
edge_threshold_for_bt = getattr(self, 'fold_edge_threshold', self.config.get('calibration', {}).get('edge_threshold', 0.1))
|
||||||
logger.info(f"Fold {self.current_fold}: Using edge threshold {edge_threshold_for_bt:.4f} for backtest execution.")
|
logger.info(f"Fold {self.current_fold}: Using edge threshold {edge_threshold_for_bt:.4f} for backtest execution.")
|
||||||
|
|
||||||
try:
|
try: # Corrected indentation
|
||||||
self.backtest_results_df, self.backtest_metrics, self.metrics_log_df = self.backtester.run_backtest(
|
self.backtest_results_df, self.backtest_metrics, self.metrics_log_df = self.backtester.run_backtest(
|
||||||
sac_agent_load_path=self.sac_agent_load_path,
|
sac_agent_load_path=self.sac_agent_load_path,
|
||||||
X_test_seq=self.X_test_seq,
|
X_test_seq=self.X_test_seq,
|
||||||
y_test_seq_dict=self.y_test_seq_dict,
|
y_test_seq_dict=self.y_test_seq_dict,
|
||||||
test_indices=self.test_indices,
|
test_indices=self.test_indices,
|
||||||
gru_handler=self.gru_handler,
|
gru_handler=self.gru_handler,
|
||||||
# --- Pass Calibrator instances and initial state --- #
|
# --- Pass Calibrator instances and initial state --- #
|
||||||
calibrator=calibrator_instance,
|
calibrator=calibrator_instance,
|
||||||
vector_calibrator=vector_calibrator_instance,
|
vector_calibrator=vector_calibrator_instance,
|
||||||
initial_optimal_T=getattr(self, 'optimal_T', None), # Pass T if exists
|
initial_optimal_T=getattr(self, 'optimal_T', None), # Pass T if exists
|
||||||
initial_vector_params=getattr(self, 'vector_cal_params', None), # Pass params if exists
|
initial_vector_params=getattr(self, 'vector_cal_params', None), # Pass params if exists
|
||||||
fold_edge_threshold=edge_threshold_for_bt,
|
fold_edge_threshold=edge_threshold_for_bt,
|
||||||
# --- Pass raw predictions if needed for rolling cal --- #
|
# --- Pass raw predictions if needed for rolling cal --- #
|
||||||
p_raw_test=p_raw_test_for_bt,
|
p_raw_test=p_raw_test_for_bt,
|
||||||
logits_test=logits_test_for_bt,
|
logits_test=logits_test_for_bt,
|
||||||
# --- Pass original prices --- #
|
# --- Pass original prices --- #
|
||||||
original_prices=self.df_test_original, # Pass the DataFrame
|
original_prices=self.df_test_original, # Pass the DataFrame
|
||||||
is_ternary=self.use_ternary,
|
is_ternary=self.use_ternary,
|
||||||
fold_num=self.current_fold
|
fold_num=self.current_fold
|
||||||
)
|
)
|
||||||
except SystemExit as e:
|
except SystemExit as e: # Corrected indentation
|
||||||
# Catch exits from backtester validation/execution
|
# Catch exits from backtester validation/execution
|
||||||
logger.error(f"Fold {self.current_fold}: Backtest aborted: {e}")
|
logger.error(f"Fold {self.current_fold}: Backtest aborted: {e}")
|
||||||
raise # Re-raise to stop the fold
|
raise # Re-raise to stop the fold
|
||||||
except Exception as e:
|
except Exception as e: # Corrected indentation
|
||||||
logger.error(f"Fold {self.current_fold}: Unhandled error during backtester.run_backtest: {e}", exc_info=True)
|
logger.error(f"Fold {self.current_fold}: Unhandled error during backtester.run_backtest: {e}", exc_info=True)
|
||||||
# Treat as failure, ensure metrics are None
|
# Treat as failure, ensure metrics are None
|
||||||
self.backtest_results_df = None
|
self.backtest_results_df = None
|
||||||
@ -2742,7 +2756,7 @@ class TradingPipeline:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error during SAC agent weight averaging or saving: {e}", exc_info=True)
|
logger.error(f"Error during SAC agent weight averaging or saving: {e}", exc_info=True)
|
||||||
|
|
||||||
# --- Entry Point --- #
|
# --- Entry Point --- #
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Run the GRU-SAC Trading Pipeline.")
|
parser = argparse.ArgumentParser(description="Run the GRU-SAC Trading Pipeline.")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user