Fix indentation in pipeline
This commit is contained in:
parent
a02dd4342e
commit
a86cdb2c8f
@ -227,7 +227,7 @@ The `config.yaml` file centrally controls the pipeline's behavior. Key sections
|
|||||||
1. **Setup:** Install requirements (`pip install -r requirements.txt`), prepare data in the specified format/location.
|
1. **Setup:** Install requirements (`pip install -r requirements.txt`), prepare data in the specified format/location.
|
||||||
2. **Configure:** Edit `config.yaml` (data paths, feature lists, model params, control flags, walk-forward settings, calibration, validation thresholds, tuning, aggregation).
|
2. **Configure:** Edit `config.yaml` (data paths, feature lists, model params, control flags, walk-forward settings, calibration, validation thresholds, tuning, aggregation).
|
||||||
3. **Run Pipeline:**
|
3. **Run Pipeline:**
|
||||||
```bash
|
```bash
|
||||||
# From project root (develop/gru_sac_predictor/)
|
# From project root (develop/gru_sac_predictor/)
|
||||||
python gru_sac_predictor/run.py --config path/to/your_config.yaml
|
python gru_sac_predictor/run.py --config path/to/your_config.yaml
|
||||||
```
|
```
|
||||||
|
|||||||
@ -77,7 +77,7 @@ class Backtester:
|
|||||||
self.coverage_alarm_threshold_drop = self.cal_cfg.get('coverage_alarm_threshold_drop', 0.03)
|
self.coverage_alarm_threshold_drop = self.cal_cfg.get('coverage_alarm_threshold_drop', 0.03)
|
||||||
self.coverage_alarm_window = self.cal_cfg.get('coverage_alarm_window', 1000)
|
self.coverage_alarm_window = self.cal_cfg.get('coverage_alarm_window', 1000)
|
||||||
# --- End --- #
|
# --- End --- #
|
||||||
|
|
||||||
self.results_df: Optional[pd.DataFrame] = None
|
self.results_df: Optional[pd.DataFrame] = None
|
||||||
self.metrics: Optional[Dict[str, Any]] = None
|
self.metrics: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
@ -98,11 +98,11 @@ class Backtester:
|
|||||||
self.ece_recalibration_threshold = self.cal_cfg.get('ece_recalibration_threshold', 0.03)
|
self.ece_recalibration_threshold = self.cal_cfg.get('ece_recalibration_threshold', 0.03)
|
||||||
|
|
||||||
def run_backtest(
|
def run_backtest(
|
||||||
self,
|
self,
|
||||||
sac_agent_load_path: Optional[str],
|
sac_agent_load_path: Optional[str],
|
||||||
X_test_seq: np.ndarray,
|
X_test_seq: np.ndarray,
|
||||||
y_test_seq_dict: Dict[str, np.ndarray],
|
y_test_seq_dict: Dict[str, np.ndarray],
|
||||||
test_indices: pd.Index,
|
test_indices: pd.Index,
|
||||||
gru_handler: GRUModelHandler,
|
gru_handler: GRUModelHandler,
|
||||||
# --- Pass Calibrator Instances & Initial Params/Threshold --- #
|
# --- Pass Calibrator Instances & Initial Params/Threshold --- #
|
||||||
calibrator: Optional[Calibrator], # For Temp Scaling
|
calibrator: Optional[Calibrator], # For Temp Scaling
|
||||||
@ -163,7 +163,7 @@ class Backtester:
|
|||||||
logger.error(f"Fold {fold_num}: Rolling calibration enabled for ternary case, but logits_test not provided.")
|
logger.error(f"Fold {fold_num}: Rolling calibration enabled for ternary case, but logits_test not provided.")
|
||||||
return None, None, None
|
return None, None, None
|
||||||
# --- End Validation --- #
|
# --- End Validation --- #
|
||||||
|
|
||||||
# 1. Initialize SAC Agent & Load Weights
|
# 1. Initialize SAC Agent & Load Weights
|
||||||
# Ensure agent state dim matches the state construction below
|
# Ensure agent state dim matches the state construction below
|
||||||
agent_state_dim = 5 # mu, sigma, edge, |mu|/sigma, position
|
agent_state_dim = 5 # mu, sigma, edge, |mu|/sigma, position
|
||||||
@ -223,7 +223,7 @@ class Backtester:
|
|||||||
log_sigma_test_state = preds_test_state[1][:, 1].flatten() if preds_test_state[1].shape[-1] == 2 else np.log(preds_test_state[1].flatten() + 1e-9)
|
log_sigma_test_state = preds_test_state[1][:, 1].flatten() if preds_test_state[1].shape[-1] == 2 else np.log(preds_test_state[1].flatten() + 1e-9)
|
||||||
sigma_test_state = np.exp(log_sigma_test_state)
|
sigma_test_state = np.exp(log_sigma_test_state)
|
||||||
p_raw_fallback = p_raw_test # Use the passed raw probs if needed outside rolling
|
p_raw_fallback = p_raw_test # Use the passed raw probs if needed outside rolling
|
||||||
|
|
||||||
# Extract actual returns and directions
|
# Extract actual returns and directions
|
||||||
actual_ret_test = y_test_seq_dict.get('ret')
|
actual_ret_test = y_test_seq_dict.get('ret')
|
||||||
dir_key = 'dir3' if is_ternary else 'dir'
|
dir_key = 'dir3' if is_ternary else 'dir'
|
||||||
@ -262,7 +262,7 @@ class Backtester:
|
|||||||
metrics_log = [] # Store periodic metrics
|
metrics_log = [] # Store periodic metrics
|
||||||
step_correct_nonzero, step_count_nonzero = 0, 0
|
step_correct_nonzero, step_count_nonzero = 0, 0
|
||||||
step_abs_actions = []
|
step_abs_actions = []
|
||||||
|
|
||||||
logger.info(f"Fold {fold_num}: Starting backtest simulation loop ({n_test} steps)...")
|
logger.info(f"Fold {fold_num}: Starting backtest simulation loop ({n_test} steps)...")
|
||||||
for i in range(n_test):
|
for i in range(n_test):
|
||||||
# --- Step i: Calibration ---
|
# --- Step i: Calibration ---
|
||||||
@ -332,7 +332,7 @@ class Backtester:
|
|||||||
|
|
||||||
# --- Step i: Update Position ---
|
# --- Step i: Update Position ---
|
||||||
current_position = target_position
|
current_position = target_position
|
||||||
|
|
||||||
# --- Step i: Update Metrics Log Data ---
|
# --- Step i: Update Metrics Log Data ---
|
||||||
if abs(current_position) > 1e-6:
|
if abs(current_position) > 1e-6:
|
||||||
step_count_nonzero += 1
|
step_count_nonzero += 1
|
||||||
@ -455,7 +455,7 @@ class Backtester:
|
|||||||
logger.info(f" New optimal temperature: {new_T:.4f} (Previous: {current_optimal_T:.4f})")
|
logger.info(f" New optimal temperature: {new_T:.4f} (Previous: {current_optimal_T:.4f})")
|
||||||
current_optimal_T = new_T
|
current_optimal_T = new_T
|
||||||
calibrator.optimal_T = new_T # Update instance
|
calibrator.optimal_T = new_T # Update instance
|
||||||
else:
|
else:
|
||||||
logger.info(f" Temperature unchanged or optimization failed (T={current_optimal_T:.4f}).")
|
logger.info(f" Temperature unchanged or optimization failed (T={current_optimal_T:.4f}).")
|
||||||
|
|
||||||
last_recalib_step = i
|
last_recalib_step = i
|
||||||
|
|||||||
@ -769,7 +769,7 @@ class TradingPipeline:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Manual save of final feature whitelist failed: {e}", exc_info=True)
|
logging.error(f"Manual save of final feature whitelist failed: {e}", exc_info=True)
|
||||||
# --- End Save Update --- #
|
# --- End Save Update --- #
|
||||||
|
|
||||||
# --- MODIFIED: Prune the SCALED data splits using the determined whitelist --- #
|
# --- MODIFIED: Prune the SCALED data splits using the determined whitelist --- #
|
||||||
if self.X_train_scaled is None or self.X_val_scaled is None or self.X_test_scaled is None:
|
if self.X_train_scaled is None or self.X_val_scaled is None or self.X_test_scaled is None:
|
||||||
logging.error("Scaled data splits not available for pruning.")
|
logging.error("Scaled data splits not available for pruning.")
|
||||||
@ -780,7 +780,7 @@ class TradingPipeline:
|
|||||||
self.X_val_pruned = self.feature_engineer.prune_features(self.X_val_scaled, self.final_whitelist)
|
self.X_val_pruned = self.feature_engineer.prune_features(self.X_val_scaled, self.final_whitelist)
|
||||||
self.X_test_pruned = self.feature_engineer.prune_features(self.X_test_scaled, self.final_whitelist)
|
self.X_test_pruned = self.feature_engineer.prune_features(self.X_test_scaled, self.final_whitelist)
|
||||||
# --- End Modification --- #
|
# --- End Modification --- #
|
||||||
|
|
||||||
logging.info(f"Feature shapes after pruning scaled data: Train={self.X_train_pruned.shape}, Val={self.X_val_pruned.shape}, Test={self.X_test_pruned.shape}")
|
logging.info(f"Feature shapes after pruning scaled data: Train={self.X_train_pruned.shape}, Val={self.X_val_pruned.shape}, Test={self.X_test_pruned.shape}")
|
||||||
|
|
||||||
# Verification and empty checks remain the same, using X_*_pruned
|
# Verification and empty checks remain the same, using X_*_pruned
|
||||||
@ -1094,116 +1094,116 @@ class TradingPipeline:
|
|||||||
patience=patience
|
patience=patience
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.gru_model is None:
|
if self.gru_model is None:
|
||||||
logging.error("GRU model training failed. Exiting.")
|
logging.error("GRU model training failed. Exiting.")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
else:
|
|
||||||
# Save the newly trained model
|
|
||||||
saved_path = self.gru_handler.save() # Uses run_id from handler
|
|
||||||
if saved_path:
|
|
||||||
logging.info(f"Newly trained GRU model saved to {saved_path}")
|
|
||||||
else:
|
else:
|
||||||
logging.warning("Failed to save the newly trained GRU model.")
|
# Save the newly trained model
|
||||||
# Set the loaded ID to the current run ID
|
saved_path = self.gru_handler.save() # Uses run_id from handler
|
||||||
self.gru_model_run_id_loaded_from = self.run_id
|
if saved_path:
|
||||||
logging.info(f"Using GRU model trained in current run: {self.run_id}")
|
logging.info(f"Newly trained GRU model saved to {saved_path}")
|
||||||
|
|
||||||
# --- V3 Output Contract: Plot Learning Curve --- #
|
|
||||||
if self.io and history is not None and self.config.get('control', {}).get('generate_plots', True):
|
|
||||||
# Infer log dir path based on current models dir
|
|
||||||
log_dir = os.path.dirname(self.current_run_models_dir).replace('/models', '/logs')
|
|
||||||
csv_log_path = os.path.join(log_dir, 'gru_history.csv')
|
|
||||||
if os.path.exists(csv_log_path):
|
|
||||||
logging.info(f"Plotting learning curve from {csv_log_path}...")
|
|
||||||
try:
|
|
||||||
history_df = pd.read_csv(csv_log_path)
|
|
||||||
|
|
||||||
# Determine metric keys (handle v2 vs v3 differences if necessary)
|
|
||||||
loss_key = 'loss'
|
|
||||||
val_loss_key = 'val_loss'
|
|
||||||
acc_key = None
|
|
||||||
val_acc_key = None
|
|
||||||
if 'dir3_accuracy' in history_df.columns: # V3 specific?
|
|
||||||
acc_key = 'dir3_accuracy'
|
|
||||||
val_acc_key = 'val_dir3_accuracy'
|
|
||||||
elif 'accuracy' in history_df.columns: # V2 or other?
|
|
||||||
acc_key = 'accuracy'
|
|
||||||
val_acc_key = 'val_accuracy'
|
|
||||||
|
|
||||||
if acc_key is None:
|
|
||||||
logging.warning("Could not find a suitable accuracy metric in history CSV for plotting.")
|
|
||||||
n_panes = 1 # Only plot loss
|
|
||||||
else:
|
|
||||||
n_panes = 2 # Plot loss and accuracy
|
|
||||||
|
|
||||||
# Get figure settings
|
|
||||||
fig_dpi = self.config.get('output', {}).get('figure_dpi', 150)
|
|
||||||
fig_size = self.config.get('output', {}).get('figure_size', [16, 9])
|
|
||||||
footer_text = "© GRU-SAC v3"
|
|
||||||
|
|
||||||
plt.style.use('seaborn-v0_8-darkgrid')
|
|
||||||
# Adjust figsize height based on panes
|
|
||||||
adjusted_fig_height = fig_size[1] * (n_panes / 3.0) # Rough scaling
|
|
||||||
fig, axes = plt.subplots(n_panes, 1, figsize=(fig_size[0], adjusted_fig_height), sharex=True)
|
|
||||||
|
|
||||||
if n_panes == 1:
|
|
||||||
ax_loss = axes # Single axis
|
|
||||||
else:
|
|
||||||
ax_loss, ax_acc = axes # Multiple axes
|
|
||||||
|
|
||||||
epochs = history_df['epoch'] + 1 # epochs are 0-indexed in csv
|
|
||||||
|
|
||||||
# Pane 1: Loss (Log Scale)
|
|
||||||
ax_loss.plot(epochs, history_df[loss_key], label='Training Loss')
|
|
||||||
ax_loss.plot(epochs, history_df[val_loss_key], label='Validation Loss')
|
|
||||||
ax_loss.set_yscale('log')
|
|
||||||
ax_loss.set_ylabel('Loss (Log Scale)')
|
|
||||||
ax_loss.legend()
|
|
||||||
ax_loss.set_title('GRU Model Training Progress', fontsize=16)
|
|
||||||
ax_loss.grid(True, which="both", ls="--", linewidth=0.5)
|
|
||||||
|
|
||||||
# Pane 2: Accuracy (if available)
|
|
||||||
if n_panes == 2:
|
|
||||||
ax_acc.plot(epochs, history_df[acc_key], label=f'Training {acc_key}')
|
|
||||||
ax_acc.plot(epochs, history_df[val_acc_key], label=f'Validation {val_acc_key}')
|
|
||||||
ax_acc.set_ylabel('Accuracy')
|
|
||||||
ax_acc.set_xlabel('Epoch')
|
|
||||||
ax_acc.legend()
|
|
||||||
ax_acc.grid(True, which="both", ls="--", linewidth=0.5)
|
|
||||||
else:
|
|
||||||
# If only loss pane, set xlabel there
|
|
||||||
ax_loss.set_xlabel('Epoch')
|
|
||||||
|
|
||||||
# Add vertical line for early stopping epoch if available
|
|
||||||
if hasattr(history, 'epoch') and len(history.epoch) > 0:
|
|
||||||
# Early stopping epoch is the number of epochs run
|
|
||||||
early_stop_epoch = len(history.epoch)
|
|
||||||
if early_stop_epoch < max_epochs: # Only draw if early stopping occurred
|
|
||||||
for ax in fig.axes:
|
|
||||||
ax.axvline(x=early_stop_epoch, color='r', linestyle='--', linewidth=1, label=f'Early Stop @ {early_stop_epoch}')
|
|
||||||
# Add legend entry to the last plot
|
|
||||||
fig.axes[-1].legend()
|
|
||||||
|
|
||||||
# Add footer
|
|
||||||
plt.figtext(0.99, 0.01, footer_text, horizontalalignment='right',
|
|
||||||
verticalalignment='bottom', fontsize=8, color='gray')
|
|
||||||
|
|
||||||
plt.tight_layout(rect=[0, 0.03, 1, 0.97]) # Adjust layout
|
|
||||||
|
|
||||||
# Save figure using IOManager
|
|
||||||
self.io.save_figure(fig, "gru_learning_curve", section='results')
|
|
||||||
logging.info("GRU learning curve plot saved.")
|
|
||||||
plt.close(fig)
|
|
||||||
|
|
||||||
except FileNotFoundError:
|
|
||||||
logging.warning(f"GRU history file not found at {csv_log_path}. Cannot plot learning curve.")
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(f"Failed to plot GRU learning curve: {e}", exc_info=True)
|
|
||||||
else:
|
else:
|
||||||
logging.warning(f"GRU history file not found at {csv_log_path}. Cannot plot learning curve.")
|
logging.warning("Failed to save the newly trained GRU model.")
|
||||||
elif not self.io:
|
# Set the loaded ID to the current run ID
|
||||||
logging.warning("IOManager not available, skipping GRU learning curve plot.")
|
self.gru_model_run_id_loaded_from = self.run_id
|
||||||
# --- End Plot Learning Curve --- #
|
logging.info(f"Using GRU model trained in current run: {self.run_id}")
|
||||||
|
|
||||||
|
# --- V3 Output Contract: Plot Learning Curve --- #
|
||||||
|
if self.io and history is not None and self.config.get('control', {}).get('generate_plots', True):
|
||||||
|
# Infer log dir path based on current models dir
|
||||||
|
log_dir = os.path.dirname(self.current_run_models_dir).replace('/models', '/logs')
|
||||||
|
csv_log_path = os.path.join(log_dir, 'gru_history.csv')
|
||||||
|
if os.path.exists(csv_log_path):
|
||||||
|
logging.info(f"Plotting learning curve from {csv_log_path}...")
|
||||||
|
try:
|
||||||
|
history_df = pd.read_csv(csv_log_path)
|
||||||
|
|
||||||
|
# Determine metric keys (handle v2 vs v3 differences if necessary)
|
||||||
|
loss_key = 'loss'
|
||||||
|
val_loss_key = 'val_loss'
|
||||||
|
acc_key = None
|
||||||
|
val_acc_key = None
|
||||||
|
if 'dir3_accuracy' in history_df.columns: # V3 specific?
|
||||||
|
acc_key = 'dir3_accuracy'
|
||||||
|
val_acc_key = 'val_dir3_accuracy'
|
||||||
|
elif 'accuracy' in history_df.columns: # V2 or other?
|
||||||
|
acc_key = 'accuracy'
|
||||||
|
val_acc_key = 'val_accuracy'
|
||||||
|
|
||||||
|
if acc_key is None:
|
||||||
|
logging.warning("Could not find a suitable accuracy metric in history CSV for plotting.")
|
||||||
|
n_panes = 1 # Only plot loss
|
||||||
|
else:
|
||||||
|
n_panes = 2 # Plot loss and accuracy
|
||||||
|
|
||||||
|
# Get figure settings
|
||||||
|
fig_dpi = self.config.get('output', {}).get('figure_dpi', 150)
|
||||||
|
fig_size = self.config.get('output', {}).get('figure_size', [16, 9])
|
||||||
|
footer_text = "© GRU-SAC v3"
|
||||||
|
|
||||||
|
plt.style.use('seaborn-v0_8-darkgrid')
|
||||||
|
# Adjust figsize height based on panes
|
||||||
|
adjusted_fig_height = fig_size[1] * (n_panes / 3.0) # Rough scaling
|
||||||
|
fig, axes = plt.subplots(n_panes, 1, figsize=(fig_size[0], adjusted_fig_height), sharex=True)
|
||||||
|
|
||||||
|
if n_panes == 1:
|
||||||
|
ax_loss = axes # Single axis
|
||||||
|
else:
|
||||||
|
ax_loss, ax_acc = axes # Multiple axes
|
||||||
|
|
||||||
|
epochs = history_df['epoch'] + 1 # epochs are 0-indexed in csv
|
||||||
|
|
||||||
|
# Pane 1: Loss (Log Scale)
|
||||||
|
ax_loss.plot(epochs, history_df[loss_key], label='Training Loss')
|
||||||
|
ax_loss.plot(epochs, history_df[val_loss_key], label='Validation Loss')
|
||||||
|
ax_loss.set_yscale('log')
|
||||||
|
ax_loss.set_ylabel('Loss (Log Scale)')
|
||||||
|
ax_loss.legend()
|
||||||
|
ax_loss.set_title('GRU Model Training Progress', fontsize=16)
|
||||||
|
ax_loss.grid(True, which="both", ls="--", linewidth=0.5)
|
||||||
|
|
||||||
|
# Pane 2: Accuracy (if available)
|
||||||
|
if n_panes == 2:
|
||||||
|
ax_acc.plot(epochs, history_df[acc_key], label=f'Training {acc_key}')
|
||||||
|
ax_acc.plot(epochs, history_df[val_acc_key], label=f'Validation {val_acc_key}')
|
||||||
|
ax_acc.set_ylabel('Accuracy')
|
||||||
|
ax_acc.set_xlabel('Epoch')
|
||||||
|
ax_acc.legend()
|
||||||
|
ax_acc.grid(True, which="both", ls="--", linewidth=0.5)
|
||||||
|
else:
|
||||||
|
# If only loss pane, set xlabel there
|
||||||
|
ax_loss.set_xlabel('Epoch')
|
||||||
|
|
||||||
|
# Add vertical line for early stopping epoch if available
|
||||||
|
if hasattr(history, 'epoch') and len(history.epoch) > 0:
|
||||||
|
# Early stopping epoch is the number of epochs run
|
||||||
|
early_stop_epoch = len(history.epoch)
|
||||||
|
if early_stop_epoch < max_epochs: # Only draw if early stopping occurred
|
||||||
|
for ax in fig.axes:
|
||||||
|
ax.axvline(x=early_stop_epoch, color='r', linestyle='--', linewidth=1, label=f'Early Stop @ {early_stop_epoch}')
|
||||||
|
# Add legend entry to the last plot
|
||||||
|
fig.axes[-1].legend()
|
||||||
|
|
||||||
|
# Add footer
|
||||||
|
plt.figtext(0.99, 0.01, footer_text, horizontalalignment='right',
|
||||||
|
verticalalignment='bottom', fontsize=8, color='gray')
|
||||||
|
|
||||||
|
plt.tight_layout(rect=[0, 0.03, 1, 0.97]) # Adjust layout
|
||||||
|
|
||||||
|
# Save figure using IOManager
|
||||||
|
self.io.save_figure(fig, "gru_learning_curve", section='results')
|
||||||
|
logging.info("GRU learning curve plot saved.")
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
|
except FileNotFoundError:
|
||||||
|
logging.warning(f"GRU history file not found at {csv_log_path}. Cannot plot learning curve.")
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Failed to plot GRU learning curve: {e}", exc_info=True)
|
||||||
|
else:
|
||||||
|
logging.warning(f"GRU history file not found at {csv_log_path}. Cannot plot learning curve.")
|
||||||
|
elif not self.io:
|
||||||
|
logging.warning("IOManager not available, skipping GRU learning curve plot.")
|
||||||
|
# --- End Plot Learning Curve --- #
|
||||||
|
|
||||||
else: # Load pre-trained GRU model
|
else: # Load pre-trained GRU model
|
||||||
load_run_id = gru_cfg.get('model_load_run_id', None)
|
load_run_id = gru_cfg.get('model_load_run_id', None)
|
||||||
@ -1357,8 +1357,8 @@ class TradingPipeline:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Fold {self.current_fold}: Error during temperature calibration: {e}", exc_info=True)
|
logger.error(f"Fold {self.current_fold}: Error during temperature calibration: {e}", exc_info=True)
|
||||||
self.optimal_T = None
|
self.optimal_T = None
|
||||||
else:
|
else: # Covers cases where calibration method is wrong or mismatch with ternary state
|
||||||
logger.warning(f"Fold {self.current_fold}: Calibration method '{calibration_method}' or ternary state mismatch. Skipping GRU validation checks.")
|
logger.warning(f"Fold {self.current_fold}: Calibration method '{calibration_method}' or ternary state mismatch ({is_ternary_check}). Skipping GRU validation checks.")
|
||||||
self.optimal_T = None
|
self.optimal_T = None
|
||||||
self.vector_cal_params = None
|
self.vector_cal_params = None
|
||||||
|
|
||||||
@ -1368,60 +1368,74 @@ class TradingPipeline:
|
|||||||
self.optimized_edge_threshold = None # Reset for the fold
|
self.optimized_edge_threshold = None # Reset for the fold
|
||||||
|
|
||||||
if optimize_edge:
|
if optimize_edge:
|
||||||
logger.info(f"Optimizing edge threshold using Youden's J on validation predictions...")
|
logger.info(f"Optimizing edge threshold using Youden's J on validation predictions...")
|
||||||
try:
|
try:
|
||||||
# Prepare y_true for optimization (needs binary 0/1)
|
# Prepare y_true for optimization (needs binary 0/1)
|
||||||
if y_dir_val_to_check is None:
|
if y_dir_val_to_check is None:
|
||||||
raise ValueError("Cannot optimize edge threshold without valid y_dir_val.")
|
raise ValueError("Cannot optimize edge threshold without valid y_dir_val.")
|
||||||
y_true_for_opt = None
|
y_true_for_opt = None
|
||||||
p_cal_for_opt = None
|
p_cal_for_opt = None
|
||||||
if is_ternary_check:
|
if is_ternary_check:
|
||||||
if p_cal_val_to_check is not None:
|
if p_cal_val_to_check is not None:
|
||||||
# Convert ternary to binary: P(up) vs others
|
# Convert ternary to binary: P(up) vs others
|
||||||
p_cal_for_opt = p_cal_val_to_check[:, -1] # P(up)
|
p_cal_for_opt = p_cal_val_to_check[:, -1] # P(up)
|
||||||
y_true_for_opt = (np.argmax(y_dir_val_to_check, axis=1) == 2).astype(int)
|
y_true_for_opt = (np.argmax(y_dir_val_to_check, axis=1) == 2).astype(int)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Cannot optimize ternary edge threshold without valid calibrated probabilities.")
|
raise ValueError("Cannot optimize ternary edge threshold without valid calibrated probabilities.")
|
||||||
else:
|
else:
|
||||||
# Binary case
|
# Binary case
|
||||||
if p_cal_val_to_check is not None:
|
if p_cal_val_to_check is not None:
|
||||||
p_cal_for_opt = p_cal_val_to_check
|
p_cal_for_opt = p_cal_val_to_check
|
||||||
y_true_for_opt = (np.asarray(y_dir_val_to_check) > 0.5).astype(int)
|
y_true_for_opt = (np.asarray(y_dir_val_to_check) > 0.5).astype(int)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Cannot optimize binary edge threshold without valid calibrated probabilities.")
|
raise ValueError("Cannot optimize binary edge threshold without valid calibrated probabilities.")
|
||||||
|
|
||||||
# Perform optimization using the dedicated function from metrics
|
|
||||||
# Note: This assumes Calibrator.optimize_edge_threshold was removed or is not used here
|
|
||||||
# Ensure _calculate_optimal_edge_threshold is imported
|
|
||||||
self.optimized_edge_threshold = _calculate_optimal_edge_threshold(y_true_for_opt, p_cal_for_opt)
|
|
||||||
|
|
||||||
if self.optimized_edge_threshold is not None:
|
|
||||||
logger.info(f"Optimized edge threshold: {self.optimized_edge_threshold:.4f}")
|
|
||||||
# Save optimized threshold
|
|
||||||
thresh_file = f"optimized_edge_threshold_fold_{self.current_fold}.txt"
|
|
||||||
try:
|
|
||||||
# Ensure fold_results_dir is defined (should be available from context)
|
|
||||||
if self.io and fold_results_dir:
|
|
||||||
self.io.save_json({'optimized_edge_threshold': self.optimized_edge_threshold},
|
|
||||||
thresh_file.replace('.txt',''), # Use filename as key for io
|
|
||||||
base_dir=fold_results_dir, use_txt=True)
|
|
||||||
logger.info(f"Saved optimized edge threshold to {thresh_file}")
|
|
||||||
elif self.io:
|
|
||||||
self.io.save_json({'optimized_edge_threshold': self.optimized_edge_threshold},
|
|
||||||
thresh_file.replace('.txt',''),
|
|
||||||
section='results', use_txt=True) # Fallback save
|
|
||||||
logger.info(f"Saved optimized edge threshold to main results dir: {thresh_file}")
|
|
||||||
else:
|
|
||||||
logger.warning("IOManager not available, cannot save optimized threshold.")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to save optimized edge threshold: {e}")
|
|
||||||
else:
|
|
||||||
logger.warning("Edge threshold optimization failed or was skipped. Using config default.")
|
|
||||||
self.optimized_edge_threshold = edge_thr_config # Fallback
|
|
||||||
|
|
||||||
except Exception as e:
|
# Perform optimization using the dedicated function from metrics
|
||||||
logger.error(f"Error during edge threshold optimization: {e}", exc_info=True)
|
# Note: This assumes Calibrator.optimize_edge_threshold was removed or is not used here
|
||||||
self.optimized_edge_threshold = edge_thr_config # Fallback
|
# Ensure _calculate_optimal_edge_threshold is imported
|
||||||
|
self.optimized_edge_threshold = _calculate_optimal_edge_threshold(y_true_for_opt, p_cal_for_opt)
|
||||||
|
|
||||||
|
if self.optimized_edge_threshold is not None:
|
||||||
|
logger.info(f"Optimized edge threshold: {self.optimized_edge_threshold:.4f}")
|
||||||
|
# Save optimized threshold
|
||||||
|
thresh_file = f"optimized_edge_threshold_fold_{self.current_fold}.txt"
|
||||||
|
try:
|
||||||
|
# Ensure fold_results_dir is defined (should be available from context)
|
||||||
|
# Assuming fold_results_dir is defined in the outer scope
|
||||||
|
if self.io and 'fold_results_dir' in locals() and fold_results_dir:
|
||||||
|
self.io.save_json({'optimized_edge_threshold': self.optimized_edge_threshold},
|
||||||
|
thresh_file.replace('.txt',''), # Use filename as key for io
|
||||||
|
base_dir=fold_results_dir, use_txt=True)
|
||||||
|
logger.info(f"Saved optimized edge threshold to {os.path.join(fold_results_dir, thresh_file)}")
|
||||||
|
elif self.io:
|
||||||
|
# Fallback: Save to the main results directory if fold_results_dir isn't available
|
||||||
|
# Construct the path for logging clarity
|
||||||
|
fallback_path = os.path.join(self.io.get_section_path('results'), thresh_file)
|
||||||
|
self.io.save_json({'optimized_edge_threshold': self.optimized_edge_threshold},
|
||||||
|
thresh_file.replace('.txt',''),
|
||||||
|
section='results', use_txt=True) # Fallback save
|
||||||
|
logger.info(f"Saved optimized edge threshold to main results dir: {fallback_path}")
|
||||||
|
else:
|
||||||
|
logger.warning("IOManager not available, cannot save optimized threshold.")
|
||||||
|
except NameError: # Specifically catch if fold_results_dir is not defined
|
||||||
|
logger.warning("fold_results_dir not defined. Attempting fallback save to main results dir.")
|
||||||
|
if self.io:
|
||||||
|
fallback_path = os.path.join(self.io.get_section_path('results'), thresh_file)
|
||||||
|
self.io.save_json({'optimized_edge_threshold': self.optimized_edge_threshold},
|
||||||
|
thresh_file.replace('.txt',''),
|
||||||
|
section='results', use_txt=True)
|
||||||
|
logger.info(f"Saved optimized edge threshold to main results dir: {fallback_path}")
|
||||||
|
else:
|
||||||
|
logger.warning("IOManager not available, cannot save optimized threshold.")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to save optimized edge threshold: {e}")
|
||||||
|
else:
|
||||||
|
logger.warning("Edge threshold optimization failed or was skipped. Using config default.")
|
||||||
|
self.optimized_edge_threshold = edge_thr_config # Fallback
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error during edge threshold optimization: {e}", exc_info=True)
|
||||||
|
self.optimized_edge_threshold = edge_thr_config # Fallback
|
||||||
else:
|
else:
|
||||||
# If optimization is disabled, store the config threshold for consistent use
|
# If optimization is disabled, store the config threshold for consistent use
|
||||||
self.optimized_edge_threshold = edge_thr_config
|
self.optimized_edge_threshold = edge_thr_config
|
||||||
@ -1430,21 +1444,21 @@ class TradingPipeline:
|
|||||||
|
|
||||||
# --- Perform GRU Validation Checks using the threshold stored in self.optimized_edge_threshold --- #
|
# --- Perform GRU Validation Checks using the threshold stored in self.optimized_edge_threshold --- #
|
||||||
if p_cal_val_to_check is not None and y_dir_val_to_check is not None:
|
if p_cal_val_to_check is not None and y_dir_val_to_check is not None:
|
||||||
self._perform_gru_validation_checks(
|
self._perform_gru_validation_checks(
|
||||||
p_cal_val=p_cal_val_to_check,
|
p_cal_val=p_cal_val_to_check,
|
||||||
y_dir_val=y_dir_val_to_check,
|
y_dir_val=y_dir_val_to_check,
|
||||||
is_ternary=is_ternary_check
|
is_ternary=is_ternary_check
|
||||||
)
|
)
|
||||||
# Note: _perform_gru_validation_checks was already modified to use self.optimized_edge_threshold
|
# Note: _perform_gru_validation_checks was already modified to use self.optimized_edge_threshold
|
||||||
else:
|
else:
|
||||||
logger.warning("Could not perform GRU validation checks due to missing calibrated predictions or labels.")
|
logger.warning("Could not perform GRU validation checks due to missing calibrated predictions or labels.")
|
||||||
|
|
||||||
# --- Helper for GRU Validation Checks (Replaces edge check) --- #
|
# --- Helper for GRU Validation Checks (Replaces edge check) --- #
|
||||||
def _perform_gru_validation_checks(self, p_cal_val, y_dir_val, is_ternary):
|
def _perform_gru_validation_checks(self, p_cal_val, y_dir_val, is_ternary):
|
||||||
"""
|
"""
|
||||||
Performs GRU validation checks: Edge-Filtered Accuracy and Brier Score.
|
Performs GRU validation checks: Edge-Filtered Accuracy and Brier Score.
|
||||||
Logs results and raises SystemExit if checks fail.
|
Logs results and raises SystemExit if checks fail.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
p_cal_val: Calibrated probabilities on validation set.
|
p_cal_val: Calibrated probabilities on validation set.
|
||||||
For binary: (N,) shape, P(up).
|
For binary: (N,) shape, P(up).
|
||||||
@ -1455,19 +1469,18 @@ class TradingPipeline:
|
|||||||
is_ternary (bool): Flag indicating if ternary classification is used.
|
is_ternary (bool): Flag indicating if ternary classification is used.
|
||||||
"""
|
"""
|
||||||
logger.info(f"--- Fold {self.current_fold}: Performing GRU Validation Checks --- ")
|
logger.info(f"--- Fold {self.current_fold}: Performing GRU Validation Checks --- ")
|
||||||
|
|
||||||
# --- Define thresholds (Consider moving to config) --- #
|
# --- Define thresholds (Consider moving to config) --- #
|
||||||
validation_criteria = self.config.get('validation_gates', {}).get('gru', {})
|
validation_criteria = self.config.get('validation_gates', {}).get('gru', {})
|
||||||
edge_check_thr = validation_criteria.get('edge_filtered_acc_ci_lower_threshold', 0.55)
|
edge_check_thr = validation_criteria.get('edge_filtered_acc_ci_lower_threshold', 0.55)
|
||||||
brier_check_thr = validation_criteria.get('brier_score_threshold', 0.19)
|
brier_check_thr = validation_criteria.get('brier_score_threshold', 0.19)
|
||||||
min_edge_samples = validation_criteria.get('edge_filtered_min_samples', 30)
|
min_edge_samples = validation_criteria.get('edge_filtered_min_samples', 30)
|
||||||
# --- End Thresholds --- #
|
# --- End Thresholds --- #
|
||||||
|
|
||||||
# --- Determine Edge Threshold --- #
|
# --- Determine Edge Threshold --- #
|
||||||
calib_config = self.config.get('calibration', {})
|
calib_config = self.config.get('calibration', {})
|
||||||
optimize_edge = calib_config.get('optimize_edge_threshold', False)
|
optimize_edge = calib_config.get('optimize_edge_threshold', False)
|
||||||
edge_thr_config = calib_config.get('edge_threshold', 0.1) # Default/fallback
|
edge_thr_config = calib_config.get('edge_threshold', 0.1) # Default/fallback
|
||||||
|
|
||||||
self.fold_edge_threshold = edge_thr_config # Initialize with config value
|
self.fold_edge_threshold = edge_thr_config # Initialize with config value
|
||||||
|
|
||||||
if optimize_edge and not is_ternary:
|
if optimize_edge and not is_ternary:
|
||||||
@ -1483,7 +1496,7 @@ class TradingPipeline:
|
|||||||
self.io.save_json({'optimized_edge_threshold': self.fold_edge_threshold},
|
self.io.save_json({'optimized_edge_threshold': self.fold_edge_threshold},
|
||||||
f'optimized_edge_threshold_fold_{self.current_fold}',
|
f'optimized_edge_threshold_fold_{self.current_fold}',
|
||||||
base_dir=fold_results_dir, use_txt=True)
|
base_dir=fold_results_dir, use_txt=True)
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Fold {self.current_fold}: Could not save optimized edge threshold, results dir missing.")
|
logger.warning(f"Fold {self.current_fold}: Could not save optimized edge threshold, results dir missing.")
|
||||||
elif optimize_edge and is_ternary:
|
elif optimize_edge and is_ternary:
|
||||||
logger.warning(f"Fold {self.current_fold}: Edge threshold optimization requested but not supported for ternary. Using config value: {edge_thr_config:.4f}")
|
logger.warning(f"Fold {self.current_fold}: Edge threshold optimization requested but not supported for ternary. Using config value: {edge_thr_config:.4f}")
|
||||||
@ -1527,13 +1540,13 @@ class TradingPipeline:
|
|||||||
passed_edge_acc = ci_lower >= edge_check_thr
|
passed_edge_acc = ci_lower >= edge_check_thr
|
||||||
logger.info(f"Fold {self.current_fold}: Edge Acc Check (edge >= {self.fold_edge_threshold:.2f}): Acc={edge_accuracy:.3f} ({k_correct}/{n_filtered}), 95% CI Lower={ci_lower:.3f} >= {edge_check_thr} -> {'Pass' if passed_edge_acc else 'FAIL'}")
|
logger.info(f"Fold {self.current_fold}: Edge Acc Check (edge >= {self.fold_edge_threshold:.2f}): Acc={edge_accuracy:.3f} ({k_correct}/{n_filtered}), 95% CI Lower={ci_lower:.3f} >= {edge_check_thr} -> {'Pass' if passed_edge_acc else 'FAIL'}")
|
||||||
except ValueError as binom_err:
|
except ValueError as binom_err:
|
||||||
logger.error(f"Fold {self.current_fold}: Edge Acc Check: Error calculating binomial test (k={k_correct}, n={n_filtered}): {binom_err}. Check considered FAIL.")
|
logger.error(f"Fold {self.current_fold}: Edge Acc Check: Error calculating binomial test (k={k_correct}, n={n_filtered}): {binom_err}. Check considered FAIL.")
|
||||||
passed_edge_acc = False # Consider error as failure
|
passed_edge_acc = False # Consider error as failure
|
||||||
else:
|
else:
|
||||||
logger.error(f"Fold {self.current_fold}: Edge Acc Check: Calculation failed (NaN). Check considered FAIL.")
|
logger.error(f"Fold {self.current_fold}: Edge Acc Check: Calculation failed (NaN). Check considered FAIL.")
|
||||||
if n_filtered == 0:
|
if n_filtered == 0:
|
||||||
logger.error(f" Reason: No validation samples met the edge threshold >= {self.fold_edge_threshold:.2f}")
|
logger.error(f" Reason: No validation samples met the edge threshold >= {self.fold_edge_threshold:.2f}")
|
||||||
passed_edge_acc = False # Consider NaN or 0 samples as failure
|
passed_edge_acc = False # Consider NaN or 0 samples as failure
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Fold {self.current_fold}: Edge Acc Check: Unexpected error during calculation: {e}. Check considered FAIL.", exc_info=True)
|
logger.error(f"Fold {self.current_fold}: Edge Acc Check: Unexpected error during calculation: {e}. Check considered FAIL.", exc_info=True)
|
||||||
passed_edge_acc = False # Consider error as failure
|
passed_edge_acc = False # Consider error as failure
|
||||||
@ -1566,9 +1579,9 @@ class TradingPipeline:
|
|||||||
error_msg = f"FOLD {self.current_fold} GRU VALIDATION FAILED: Edge Acc Pass={passed_edge_acc} (Req CI>={edge_check_thr}), Brier Pass={passed_brier} (Req Score<={brier_check_thr}). Aborting fold."
|
error_msg = f"FOLD {self.current_fold} GRU VALIDATION FAILED: Edge Acc Pass={passed_edge_acc} (Req CI>={edge_check_thr}), Brier Pass={passed_brier} (Req Score<={brier_check_thr}). Aborting fold."
|
||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
# Use sys.exit with a specific message for clarity
|
# Use sys.exit with a specific message for clarity
|
||||||
sys.exit(f"Fold {self.current_fold}: GRU validation gates failed (Edge Acc / Brier Score).")
|
sys.exit(f"Fold {self.current_fold}: GRU validation gates failed (Edge Acc / Brier Score).")
|
||||||
else:
|
else: # Corrected indentation
|
||||||
logger.info(f"Fold {self.current_fold}: GRU validation checks passed (Edge Acc & Brier Score).")
|
logger.info(f"Fold {self.current_fold}: GRU validation checks passed (Edge Acc & Brier Score).") # Corrected indentation
|
||||||
# --- End Validation Helper --- #
|
# --- End Validation Helper --- #
|
||||||
|
|
||||||
def train_or_load_sac(self):
|
def train_or_load_sac(self):
|
||||||
@ -1628,7 +1641,7 @@ class TradingPipeline:
|
|||||||
# # --- Revision 1: Handle Rolling Calibrator Conflict --- #
|
# # --- Revision 1: Handle Rolling Calibrator Conflict --- #
|
||||||
# ... (block removed) ...
|
# ... (block removed) ...
|
||||||
# # --- End Revision 1 --- #
|
# # --- End Revision 1 --- #
|
||||||
|
|
||||||
# Start the training process
|
# Start the training process
|
||||||
final_agent_path = self.sac_trainer.train(gru_run_id_for_sac=self.gru_model_run_id_loaded_from)
|
final_agent_path = self.sac_trainer.train(gru_run_id_for_sac=self.gru_model_run_id_loaded_from)
|
||||||
|
|
||||||
@ -1791,6 +1804,7 @@ class TradingPipeline:
|
|||||||
# Get raw predictions needed for rolling calibration
|
# Get raw predictions needed for rolling calibration
|
||||||
p_raw_test_for_bt = None
|
p_raw_test_for_bt = None
|
||||||
logits_test_for_bt = None
|
logits_test_for_bt = None
|
||||||
|
is_ternary = self.config.get('gru', {}).get('use_ternary_output', False) # Need to know if ternary
|
||||||
if self.config.get('calibration', {}).get('rolling_enabled', False):
|
if self.config.get('calibration', {}).get('rolling_enabled', False):
|
||||||
logger.info(f"Fold {self.current_fold}: Getting raw GRU outputs for rolling calibration...")
|
logger.info(f"Fold {self.current_fold}: Getting raw GRU outputs for rolling calibration...")
|
||||||
if is_ternary:
|
if is_ternary:
|
||||||
@ -1798,43 +1812,43 @@ class TradingPipeline:
|
|||||||
if logits_test_for_bt is None:
|
if logits_test_for_bt is None:
|
||||||
logger.error(f"Fold {self.current_fold}: Failed to get GRU logits for rolling calibration.")
|
logger.error(f"Fold {self.current_fold}: Failed to get GRU logits for rolling calibration.")
|
||||||
raise SystemExit(f"Fold {self.current_fold}: Failed GRU logit prediction.")
|
raise SystemExit(f"Fold {self.current_fold}: Failed GRU logit prediction.")
|
||||||
else:
|
else: # Corrected indentation
|
||||||
preds_test_raw = self.gru_handler.predict(self.X_test_seq)
|
preds_test_raw = self.gru_handler.predict(self.X_test_seq)
|
||||||
if preds_test_raw is None or len(preds_test_raw) < 3:
|
if preds_test_raw is None or len(preds_test_raw) < 3:
|
||||||
logger.error(f"Fold {self.current_fold}: Failed to get GRU raw predictions for rolling calibration.")
|
logger.error(f"Fold {self.current_fold}: Failed to get GRU raw predictions for rolling calibration.")
|
||||||
raise SystemExit(f"Fold {self.current_fold}: Failed GRU raw prediction.")
|
raise SystemExit(f"Fold {self.current_fold}: Failed GRU raw prediction.")
|
||||||
p_raw_test_for_bt = preds_test_raw[2].flatten()
|
p_raw_test_for_bt = preds_test_raw[2].flatten() # Assuming index 2 is probabilities
|
||||||
|
|
||||||
# Get the edge threshold determined during validation (optimized or fixed)
|
# Get the edge threshold determined during validation (optimized or fixed)
|
||||||
edge_threshold_for_bt = getattr(self, 'fold_edge_threshold', self.config.get('calibration', {}).get('edge_threshold', 0.1))
|
edge_threshold_for_bt = getattr(self, 'fold_edge_threshold', self.config.get('calibration', {}).get('edge_threshold', 0.1))
|
||||||
logger.info(f"Fold {self.current_fold}: Using edge threshold {edge_threshold_for_bt:.4f} for backtest execution.")
|
logger.info(f"Fold {self.current_fold}: Using edge threshold {edge_threshold_for_bt:.4f} for backtest execution.")
|
||||||
|
|
||||||
try:
|
try: # Corrected indentation
|
||||||
self.backtest_results_df, self.backtest_metrics, self.metrics_log_df = self.backtester.run_backtest(
|
self.backtest_results_df, self.backtest_metrics, self.metrics_log_df = self.backtester.run_backtest(
|
||||||
sac_agent_load_path=self.sac_agent_load_path,
|
sac_agent_load_path=self.sac_agent_load_path,
|
||||||
X_test_seq=self.X_test_seq,
|
X_test_seq=self.X_test_seq,
|
||||||
y_test_seq_dict=self.y_test_seq_dict,
|
y_test_seq_dict=self.y_test_seq_dict,
|
||||||
test_indices=self.test_indices,
|
test_indices=self.test_indices,
|
||||||
gru_handler=self.gru_handler,
|
gru_handler=self.gru_handler,
|
||||||
# --- Pass Calibrator instances and initial state --- #
|
# --- Pass Calibrator instances and initial state --- #
|
||||||
calibrator=calibrator_instance,
|
calibrator=calibrator_instance,
|
||||||
vector_calibrator=vector_calibrator_instance,
|
vector_calibrator=vector_calibrator_instance,
|
||||||
initial_optimal_T=getattr(self, 'optimal_T', None), # Pass T if exists
|
initial_optimal_T=getattr(self, 'optimal_T', None), # Pass T if exists
|
||||||
initial_vector_params=getattr(self, 'vector_cal_params', None), # Pass params if exists
|
initial_vector_params=getattr(self, 'vector_cal_params', None), # Pass params if exists
|
||||||
fold_edge_threshold=edge_threshold_for_bt,
|
fold_edge_threshold=edge_threshold_for_bt,
|
||||||
# --- Pass raw predictions if needed for rolling cal --- #
|
# --- Pass raw predictions if needed for rolling cal --- #
|
||||||
p_raw_test=p_raw_test_for_bt,
|
p_raw_test=p_raw_test_for_bt,
|
||||||
logits_test=logits_test_for_bt,
|
logits_test=logits_test_for_bt,
|
||||||
# --- Pass original prices --- #
|
# --- Pass original prices --- #
|
||||||
original_prices=self.df_test_original, # Pass the DataFrame
|
original_prices=self.df_test_original, # Pass the DataFrame
|
||||||
is_ternary=self.use_ternary,
|
is_ternary=self.use_ternary,
|
||||||
fold_num=self.current_fold
|
fold_num=self.current_fold
|
||||||
)
|
)
|
||||||
except SystemExit as e:
|
except SystemExit as e: # Corrected indentation
|
||||||
# Catch exits from backtester validation/execution
|
# Catch exits from backtester validation/execution
|
||||||
logger.error(f"Fold {self.current_fold}: Backtest aborted: {e}")
|
logger.error(f"Fold {self.current_fold}: Backtest aborted: {e}")
|
||||||
raise # Re-raise to stop the fold
|
raise # Re-raise to stop the fold
|
||||||
except Exception as e:
|
except Exception as e: # Corrected indentation
|
||||||
logger.error(f"Fold {self.current_fold}: Unhandled error during backtester.run_backtest: {e}", exc_info=True)
|
logger.error(f"Fold {self.current_fold}: Unhandled error during backtester.run_backtest: {e}", exc_info=True)
|
||||||
# Treat as failure, ensure metrics are None
|
# Treat as failure, ensure metrics are None
|
||||||
self.backtest_results_df = None
|
self.backtest_results_df = None
|
||||||
@ -2027,7 +2041,7 @@ class TradingPipeline:
|
|||||||
def execute(self):
|
def execute(self):
|
||||||
"""Runs the full trading pipeline end-to-end."""
|
"""Runs the full trading pipeline end-to-end."""
|
||||||
logger.info(f"--- Starting Trading Pipeline: Run ID {self.run_id} ---")
|
logger.info(f"--- Starting Trading Pipeline: Run ID {self.run_id} ---")
|
||||||
|
|
||||||
# 1. Load and Preprocess Data
|
# 1. Load and Preprocess Data
|
||||||
self.load_and_preprocess_data()
|
self.load_and_preprocess_data()
|
||||||
if self.data_processed is None: # Check if data loading failed
|
if self.data_processed is None: # Check if data loading failed
|
||||||
@ -2073,7 +2087,7 @@ class TradingPipeline:
|
|||||||
else:
|
else:
|
||||||
# Perform edge accuracy check only if calibration happened and model exists
|
# Perform edge accuracy check only if calibration happened and model exists
|
||||||
self._perform_gru_validation_checks(
|
self._perform_gru_validation_checks(
|
||||||
p_cal_val=self.p_cal_val,
|
p_cal_val=self.p_cal_val,
|
||||||
y_dir_val=self.y_dir_val,
|
y_dir_val=self.y_dir_val,
|
||||||
is_ternary=self.use_ternary
|
is_ternary=self.use_ternary
|
||||||
)
|
)
|
||||||
@ -2092,7 +2106,7 @@ class TradingPipeline:
|
|||||||
# --- Walk-Forward Fold Generation --- #
|
# --- Walk-Forward Fold Generation --- #
|
||||||
def _generate_walk_forward_folds(self) -> Iterator[Tuple[pd.Timestamp, pd.Timestamp, pd.Timestamp, pd.Timestamp, pd.Timestamp, pd.Timestamp]]:
|
def _generate_walk_forward_folds(self) -> Iterator[Tuple[pd.Timestamp, pd.Timestamp, pd.Timestamp, pd.Timestamp, pd.Timestamp, pd.Timestamp]]:
|
||||||
"""
|
"""
|
||||||
Generates start and end timestamps for train, validation, and test sets
|
Generates start and end timestamps for train, validation, and test sets
|
||||||
for each walk-forward fold based on config settings.
|
for each walk-forward fold based on config settings.
|
||||||
Requires self.df_raw to be loaded first to determine the full date range.
|
Requires self.df_raw to be loaded first to determine the full date range.
|
||||||
"""
|
"""
|
||||||
@ -2741,8 +2755,8 @@ class TradingPipeline:
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error during SAC agent weight averaging or saving: {e}", exc_info=True)
|
logger.error(f"Error during SAC agent weight averaging or saving: {e}", exc_info=True)
|
||||||
|
|
||||||
# --- Entry Point --- #
|
# --- Entry Point --- #
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Run the GRU-SAC Trading Pipeline.")
|
parser = argparse.ArgumentParser(description="Run the GRU-SAC Trading Pipeline.")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user