iundentation fix pipeline

This commit is contained in:
yasha 2025-04-18 18:18:52 +00:00
parent a02dd4342e
commit a86cdb2c8f
3 changed files with 233 additions and 219 deletions

View File

@ -1357,8 +1357,8 @@ class TradingPipeline:
except Exception as e:
logger.error(f"Fold {self.current_fold}: Error during temperature calibration: {e}", exc_info=True)
self.optimal_T = None
else:
logger.warning(f"Fold {self.current_fold}: Calibration method '{calibration_method}' or ternary state mismatch. Skipping GRU validation checks.")
else: # Covers cases where calibration method is wrong or mismatch with ternary state
logger.warning(f"Fold {self.current_fold}: Calibration method '{calibration_method}' or ternary state mismatch ({is_ternary_check}). Skipping GRU validation checks.")
self.optimal_T = None
self.vector_cal_params = None
@ -1401,16 +1401,30 @@ class TradingPipeline:
thresh_file = f"optimized_edge_threshold_fold_{self.current_fold}.txt"
try:
# Ensure fold_results_dir is defined (should be available from context)
if self.io and fold_results_dir:
# Assuming fold_results_dir is defined in the outer scope
if self.io and 'fold_results_dir' in locals() and fold_results_dir:
self.io.save_json({'optimized_edge_threshold': self.optimized_edge_threshold},
thresh_file.replace('.txt',''), # Use filename as key for io
base_dir=fold_results_dir, use_txt=True)
logger.info(f"Saved optimized edge threshold to {thresh_file}")
logger.info(f"Saved optimized edge threshold to {os.path.join(fold_results_dir, thresh_file)}")
elif self.io:
# Fallback: Save to the main results directory if fold_results_dir isn't available
# Construct the path for logging clarity
fallback_path = os.path.join(self.io.get_section_path('results'), thresh_file)
self.io.save_json({'optimized_edge_threshold': self.optimized_edge_threshold},
thresh_file.replace('.txt',''),
section='results', use_txt=True) # Fallback save
logger.info(f"Saved optimized edge threshold to main results dir: {thresh_file}")
logger.info(f"Saved optimized edge threshold to main results dir: {fallback_path}")
else:
logger.warning("IOManager not available, cannot save optimized threshold.")
except NameError: # Specifically catch if fold_results_dir is not defined
logger.warning("fold_results_dir not defined. Attempting fallback save to main results dir.")
if self.io:
fallback_path = os.path.join(self.io.get_section_path('results'), thresh_file)
self.io.save_json({'optimized_edge_threshold': self.optimized_edge_threshold},
thresh_file.replace('.txt',''),
section='results', use_txt=True)
logger.info(f"Saved optimized edge threshold to main results dir: {fallback_path}")
else:
logger.warning("IOManager not available, cannot save optimized threshold.")
except Exception as e:
@ -1467,7 +1481,6 @@ class TradingPipeline:
calib_config = self.config.get('calibration', {})
optimize_edge = calib_config.get('optimize_edge_threshold', False)
edge_thr_config = calib_config.get('edge_threshold', 0.1) # Default/fallback
self.fold_edge_threshold = edge_thr_config # Initialize with config value
if optimize_edge and not is_ternary:
@ -1567,8 +1580,8 @@ class TradingPipeline:
logger.error(error_msg)
# Use sys.exit with a specific message for clarity
sys.exit(f"Fold {self.current_fold}: GRU validation gates failed (Edge Acc / Brier Score).")
else:
logger.info(f"Fold {self.current_fold}: GRU validation checks passed (Edge Acc & Brier Score).")
else: # Corrected indentation
logger.info(f"Fold {self.current_fold}: GRU validation checks passed (Edge Acc & Brier Score).") # Corrected indentation
# --- End Validation Helper --- #
def train_or_load_sac(self):
@ -1791,6 +1804,7 @@ class TradingPipeline:
# Get raw predictions needed for rolling calibration
p_raw_test_for_bt = None
logits_test_for_bt = None
is_ternary = self.config.get('gru', {}).get('use_ternary_output', False) # Need to know if ternary
if self.config.get('calibration', {}).get('rolling_enabled', False):
logger.info(f"Fold {self.current_fold}: Getting raw GRU outputs for rolling calibration...")
if is_ternary:
@ -1798,18 +1812,18 @@ class TradingPipeline:
if logits_test_for_bt is None:
logger.error(f"Fold {self.current_fold}: Failed to get GRU logits for rolling calibration.")
raise SystemExit(f"Fold {self.current_fold}: Failed GRU logit prediction.")
else:
else: # Corrected indentation
preds_test_raw = self.gru_handler.predict(self.X_test_seq)
if preds_test_raw is None or len(preds_test_raw) < 3:
logger.error(f"Fold {self.current_fold}: Failed to get GRU raw predictions for rolling calibration.")
raise SystemExit(f"Fold {self.current_fold}: Failed GRU raw prediction.")
p_raw_test_for_bt = preds_test_raw[2].flatten()
p_raw_test_for_bt = preds_test_raw[2].flatten() # Assuming index 2 is probabilities
# Get the edge threshold determined during validation (optimized or fixed)
edge_threshold_for_bt = getattr(self, 'fold_edge_threshold', self.config.get('calibration', {}).get('edge_threshold', 0.1))
logger.info(f"Fold {self.current_fold}: Using edge threshold {edge_threshold_for_bt:.4f} for backtest execution.")
try:
try: # Corrected indentation
self.backtest_results_df, self.backtest_metrics, self.metrics_log_df = self.backtester.run_backtest(
sac_agent_load_path=self.sac_agent_load_path,
X_test_seq=self.X_test_seq,
@ -1830,11 +1844,11 @@ class TradingPipeline:
is_ternary=self.use_ternary,
fold_num=self.current_fold
)
except SystemExit as e:
except SystemExit as e: # Corrected indentation
# Catch exits from backtester validation/execution
logger.error(f"Fold {self.current_fold}: Backtest aborted: {e}")
raise # Re-raise to stop the fold
except Exception as e:
except Exception as e: # Corrected indentation
logger.error(f"Fold {self.current_fold}: Unhandled error during backtester.run_backtest: {e}", exc_info=True)
# Treat as failure, ensure metrics are None
self.backtest_results_df = None