iundentation fix pipeline

This commit is contained in:
yasha 2025-04-18 18:18:52 +00:00
parent a02dd4342e
commit a86cdb2c8f
3 changed files with 233 additions and 219 deletions

View File

@ -1357,8 +1357,8 @@ class TradingPipeline:
except Exception as e: except Exception as e:
logger.error(f"Fold {self.current_fold}: Error during temperature calibration: {e}", exc_info=True) logger.error(f"Fold {self.current_fold}: Error during temperature calibration: {e}", exc_info=True)
self.optimal_T = None self.optimal_T = None
else: else: # Covers cases where calibration method is wrong or mismatch with ternary state
logger.warning(f"Fold {self.current_fold}: Calibration method '{calibration_method}' or ternary state mismatch. Skipping GRU validation checks.") logger.warning(f"Fold {self.current_fold}: Calibration method '{calibration_method}' or ternary state mismatch ({is_ternary_check}). Skipping GRU validation checks.")
self.optimal_T = None self.optimal_T = None
self.vector_cal_params = None self.vector_cal_params = None
@ -1401,16 +1401,30 @@ class TradingPipeline:
thresh_file = f"optimized_edge_threshold_fold_{self.current_fold}.txt" thresh_file = f"optimized_edge_threshold_fold_{self.current_fold}.txt"
try: try:
# Ensure fold_results_dir is defined (should be available from context) # Ensure fold_results_dir is defined (should be available from context)
if self.io and fold_results_dir: # Assuming fold_results_dir is defined in the outer scope
if self.io and 'fold_results_dir' in locals() and fold_results_dir:
self.io.save_json({'optimized_edge_threshold': self.optimized_edge_threshold}, self.io.save_json({'optimized_edge_threshold': self.optimized_edge_threshold},
thresh_file.replace('.txt',''), # Use filename as key for io thresh_file.replace('.txt',''), # Use filename as key for io
base_dir=fold_results_dir, use_txt=True) base_dir=fold_results_dir, use_txt=True)
logger.info(f"Saved optimized edge threshold to {thresh_file}") logger.info(f"Saved optimized edge threshold to {os.path.join(fold_results_dir, thresh_file)}")
elif self.io: elif self.io:
# Fallback: Save to the main results directory if fold_results_dir isn't available
# Construct the path for logging clarity
fallback_path = os.path.join(self.io.get_section_path('results'), thresh_file)
self.io.save_json({'optimized_edge_threshold': self.optimized_edge_threshold}, self.io.save_json({'optimized_edge_threshold': self.optimized_edge_threshold},
thresh_file.replace('.txt',''), thresh_file.replace('.txt',''),
section='results', use_txt=True) # Fallback save section='results', use_txt=True) # Fallback save
logger.info(f"Saved optimized edge threshold to main results dir: {thresh_file}") logger.info(f"Saved optimized edge threshold to main results dir: {fallback_path}")
else:
logger.warning("IOManager not available, cannot save optimized threshold.")
except NameError: # Specifically catch if fold_results_dir is not defined
logger.warning("fold_results_dir not defined. Attempting fallback save to main results dir.")
if self.io:
fallback_path = os.path.join(self.io.get_section_path('results'), thresh_file)
self.io.save_json({'optimized_edge_threshold': self.optimized_edge_threshold},
thresh_file.replace('.txt',''),
section='results', use_txt=True)
logger.info(f"Saved optimized edge threshold to main results dir: {fallback_path}")
else: else:
logger.warning("IOManager not available, cannot save optimized threshold.") logger.warning("IOManager not available, cannot save optimized threshold.")
except Exception as e: except Exception as e:
@ -1467,7 +1481,6 @@ class TradingPipeline:
calib_config = self.config.get('calibration', {}) calib_config = self.config.get('calibration', {})
optimize_edge = calib_config.get('optimize_edge_threshold', False) optimize_edge = calib_config.get('optimize_edge_threshold', False)
edge_thr_config = calib_config.get('edge_threshold', 0.1) # Default/fallback edge_thr_config = calib_config.get('edge_threshold', 0.1) # Default/fallback
self.fold_edge_threshold = edge_thr_config # Initialize with config value self.fold_edge_threshold = edge_thr_config # Initialize with config value
if optimize_edge and not is_ternary: if optimize_edge and not is_ternary:
@ -1567,8 +1580,8 @@ class TradingPipeline:
logger.error(error_msg) logger.error(error_msg)
# Use sys.exit with a specific message for clarity # Use sys.exit with a specific message for clarity
sys.exit(f"Fold {self.current_fold}: GRU validation gates failed (Edge Acc / Brier Score).") sys.exit(f"Fold {self.current_fold}: GRU validation gates failed (Edge Acc / Brier Score).")
else: else: # Corrected indentation
logger.info(f"Fold {self.current_fold}: GRU validation checks passed (Edge Acc & Brier Score).") logger.info(f"Fold {self.current_fold}: GRU validation checks passed (Edge Acc & Brier Score).") # Corrected indentation
# --- End Validation Helper --- # # --- End Validation Helper --- #
def train_or_load_sac(self): def train_or_load_sac(self):
@ -1791,6 +1804,7 @@ class TradingPipeline:
# Get raw predictions needed for rolling calibration # Get raw predictions needed for rolling calibration
p_raw_test_for_bt = None p_raw_test_for_bt = None
logits_test_for_bt = None logits_test_for_bt = None
is_ternary = self.config.get('gru', {}).get('use_ternary_output', False) # Need to know if ternary
if self.config.get('calibration', {}).get('rolling_enabled', False): if self.config.get('calibration', {}).get('rolling_enabled', False):
logger.info(f"Fold {self.current_fold}: Getting raw GRU outputs for rolling calibration...") logger.info(f"Fold {self.current_fold}: Getting raw GRU outputs for rolling calibration...")
if is_ternary: if is_ternary:
@ -1798,18 +1812,18 @@ class TradingPipeline:
if logits_test_for_bt is None: if logits_test_for_bt is None:
logger.error(f"Fold {self.current_fold}: Failed to get GRU logits for rolling calibration.") logger.error(f"Fold {self.current_fold}: Failed to get GRU logits for rolling calibration.")
raise SystemExit(f"Fold {self.current_fold}: Failed GRU logit prediction.") raise SystemExit(f"Fold {self.current_fold}: Failed GRU logit prediction.")
else: else: # Corrected indentation
preds_test_raw = self.gru_handler.predict(self.X_test_seq) preds_test_raw = self.gru_handler.predict(self.X_test_seq)
if preds_test_raw is None or len(preds_test_raw) < 3: if preds_test_raw is None or len(preds_test_raw) < 3:
logger.error(f"Fold {self.current_fold}: Failed to get GRU raw predictions for rolling calibration.") logger.error(f"Fold {self.current_fold}: Failed to get GRU raw predictions for rolling calibration.")
raise SystemExit(f"Fold {self.current_fold}: Failed GRU raw prediction.") raise SystemExit(f"Fold {self.current_fold}: Failed GRU raw prediction.")
p_raw_test_for_bt = preds_test_raw[2].flatten() p_raw_test_for_bt = preds_test_raw[2].flatten() # Assuming index 2 is probabilities
# Get the edge threshold determined during validation (optimized or fixed) # Get the edge threshold determined during validation (optimized or fixed)
edge_threshold_for_bt = getattr(self, 'fold_edge_threshold', self.config.get('calibration', {}).get('edge_threshold', 0.1)) edge_threshold_for_bt = getattr(self, 'fold_edge_threshold', self.config.get('calibration', {}).get('edge_threshold', 0.1))
logger.info(f"Fold {self.current_fold}: Using edge threshold {edge_threshold_for_bt:.4f} for backtest execution.") logger.info(f"Fold {self.current_fold}: Using edge threshold {edge_threshold_for_bt:.4f} for backtest execution.")
try: try: # Corrected indentation
self.backtest_results_df, self.backtest_metrics, self.metrics_log_df = self.backtester.run_backtest( self.backtest_results_df, self.backtest_metrics, self.metrics_log_df = self.backtester.run_backtest(
sac_agent_load_path=self.sac_agent_load_path, sac_agent_load_path=self.sac_agent_load_path,
X_test_seq=self.X_test_seq, X_test_seq=self.X_test_seq,
@ -1830,11 +1844,11 @@ class TradingPipeline:
is_ternary=self.use_ternary, is_ternary=self.use_ternary,
fold_num=self.current_fold fold_num=self.current_fold
) )
except SystemExit as e: except SystemExit as e: # Corrected indentation
# Catch exits from backtester validation/execution # Catch exits from backtester validation/execution
logger.error(f"Fold {self.current_fold}: Backtest aborted: {e}") logger.error(f"Fold {self.current_fold}: Backtest aborted: {e}")
raise # Re-raise to stop the fold raise # Re-raise to stop the fold
except Exception as e: except Exception as e: # Corrected indentation
logger.error(f"Fold {self.current_fold}: Unhandled error during backtester.run_backtest: {e}", exc_info=True) logger.error(f"Fold {self.current_fold}: Unhandled error during backtester.run_backtest: {e}", exc_info=True)
# Treat as failure, ensure metrics are None # Treat as failure, ensure metrics are None
self.backtest_results_df = None self.backtest_results_df = None
@ -2742,7 +2756,7 @@ class TradingPipeline:
except Exception as e: except Exception as e:
logger.error(f"Error during SAC agent weight averaging or saving: {e}", exc_info=True) logger.error(f"Error during SAC agent weight averaging or saving: {e}", exc_info=True)
# --- Entry Point --- # # --- Entry Point --- #
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run the GRU-SAC Trading Pipeline.") parser = argparse.ArgumentParser(description="Run the GRU-SAC Trading Pipeline.")