iundentation fix pipeline
This commit is contained in:
parent
a02dd4342e
commit
a86cdb2c8f
@ -1357,8 +1357,8 @@ class TradingPipeline:
|
||||
except Exception as e:
|
||||
logger.error(f"Fold {self.current_fold}: Error during temperature calibration: {e}", exc_info=True)
|
||||
self.optimal_T = None
|
||||
else:
|
||||
logger.warning(f"Fold {self.current_fold}: Calibration method '{calibration_method}' or ternary state mismatch. Skipping GRU validation checks.")
|
||||
else: # Covers cases where calibration method is wrong or mismatch with ternary state
|
||||
logger.warning(f"Fold {self.current_fold}: Calibration method '{calibration_method}' or ternary state mismatch ({is_ternary_check}). Skipping GRU validation checks.")
|
||||
self.optimal_T = None
|
||||
self.vector_cal_params = None
|
||||
|
||||
@ -1401,16 +1401,30 @@ class TradingPipeline:
|
||||
thresh_file = f"optimized_edge_threshold_fold_{self.current_fold}.txt"
|
||||
try:
|
||||
# Ensure fold_results_dir is defined (should be available from context)
|
||||
if self.io and fold_results_dir:
|
||||
# Assuming fold_results_dir is defined in the outer scope
|
||||
if self.io and 'fold_results_dir' in locals() and fold_results_dir:
|
||||
self.io.save_json({'optimized_edge_threshold': self.optimized_edge_threshold},
|
||||
thresh_file.replace('.txt',''), # Use filename as key for io
|
||||
base_dir=fold_results_dir, use_txt=True)
|
||||
logger.info(f"Saved optimized edge threshold to {thresh_file}")
|
||||
logger.info(f"Saved optimized edge threshold to {os.path.join(fold_results_dir, thresh_file)}")
|
||||
elif self.io:
|
||||
# Fallback: Save to the main results directory if fold_results_dir isn't available
|
||||
# Construct the path for logging clarity
|
||||
fallback_path = os.path.join(self.io.get_section_path('results'), thresh_file)
|
||||
self.io.save_json({'optimized_edge_threshold': self.optimized_edge_threshold},
|
||||
thresh_file.replace('.txt',''),
|
||||
section='results', use_txt=True) # Fallback save
|
||||
logger.info(f"Saved optimized edge threshold to main results dir: {thresh_file}")
|
||||
logger.info(f"Saved optimized edge threshold to main results dir: {fallback_path}")
|
||||
else:
|
||||
logger.warning("IOManager not available, cannot save optimized threshold.")
|
||||
except NameError: # Specifically catch if fold_results_dir is not defined
|
||||
logger.warning("fold_results_dir not defined. Attempting fallback save to main results dir.")
|
||||
if self.io:
|
||||
fallback_path = os.path.join(self.io.get_section_path('results'), thresh_file)
|
||||
self.io.save_json({'optimized_edge_threshold': self.optimized_edge_threshold},
|
||||
thresh_file.replace('.txt',''),
|
||||
section='results', use_txt=True)
|
||||
logger.info(f"Saved optimized edge threshold to main results dir: {fallback_path}")
|
||||
else:
|
||||
logger.warning("IOManager not available, cannot save optimized threshold.")
|
||||
except Exception as e:
|
||||
@ -1467,7 +1481,6 @@ class TradingPipeline:
|
||||
calib_config = self.config.get('calibration', {})
|
||||
optimize_edge = calib_config.get('optimize_edge_threshold', False)
|
||||
edge_thr_config = calib_config.get('edge_threshold', 0.1) # Default/fallback
|
||||
|
||||
self.fold_edge_threshold = edge_thr_config # Initialize with config value
|
||||
|
||||
if optimize_edge and not is_ternary:
|
||||
@ -1567,8 +1580,8 @@ class TradingPipeline:
|
||||
logger.error(error_msg)
|
||||
# Use sys.exit with a specific message for clarity
|
||||
sys.exit(f"Fold {self.current_fold}: GRU validation gates failed (Edge Acc / Brier Score).")
|
||||
else:
|
||||
logger.info(f"Fold {self.current_fold}: GRU validation checks passed (Edge Acc & Brier Score).")
|
||||
else: # Corrected indentation
|
||||
logger.info(f"Fold {self.current_fold}: GRU validation checks passed (Edge Acc & Brier Score).") # Corrected indentation
|
||||
# --- End Validation Helper --- #
|
||||
|
||||
def train_or_load_sac(self):
|
||||
@ -1791,6 +1804,7 @@ class TradingPipeline:
|
||||
# Get raw predictions needed for rolling calibration
|
||||
p_raw_test_for_bt = None
|
||||
logits_test_for_bt = None
|
||||
is_ternary = self.config.get('gru', {}).get('use_ternary_output', False) # Need to know if ternary
|
||||
if self.config.get('calibration', {}).get('rolling_enabled', False):
|
||||
logger.info(f"Fold {self.current_fold}: Getting raw GRU outputs for rolling calibration...")
|
||||
if is_ternary:
|
||||
@ -1798,18 +1812,18 @@ class TradingPipeline:
|
||||
if logits_test_for_bt is None:
|
||||
logger.error(f"Fold {self.current_fold}: Failed to get GRU logits for rolling calibration.")
|
||||
raise SystemExit(f"Fold {self.current_fold}: Failed GRU logit prediction.")
|
||||
else:
|
||||
else: # Corrected indentation
|
||||
preds_test_raw = self.gru_handler.predict(self.X_test_seq)
|
||||
if preds_test_raw is None or len(preds_test_raw) < 3:
|
||||
logger.error(f"Fold {self.current_fold}: Failed to get GRU raw predictions for rolling calibration.")
|
||||
raise SystemExit(f"Fold {self.current_fold}: Failed GRU raw prediction.")
|
||||
p_raw_test_for_bt = preds_test_raw[2].flatten()
|
||||
p_raw_test_for_bt = preds_test_raw[2].flatten() # Assuming index 2 is probabilities
|
||||
|
||||
# Get the edge threshold determined during validation (optimized or fixed)
|
||||
edge_threshold_for_bt = getattr(self, 'fold_edge_threshold', self.config.get('calibration', {}).get('edge_threshold', 0.1))
|
||||
logger.info(f"Fold {self.current_fold}: Using edge threshold {edge_threshold_for_bt:.4f} for backtest execution.")
|
||||
|
||||
try:
|
||||
try: # Corrected indentation
|
||||
self.backtest_results_df, self.backtest_metrics, self.metrics_log_df = self.backtester.run_backtest(
|
||||
sac_agent_load_path=self.sac_agent_load_path,
|
||||
X_test_seq=self.X_test_seq,
|
||||
@ -1830,11 +1844,11 @@ class TradingPipeline:
|
||||
is_ternary=self.use_ternary,
|
||||
fold_num=self.current_fold
|
||||
)
|
||||
except SystemExit as e:
|
||||
except SystemExit as e: # Corrected indentation
|
||||
# Catch exits from backtester validation/execution
|
||||
logger.error(f"Fold {self.current_fold}: Backtest aborted: {e}")
|
||||
raise # Re-raise to stop the fold
|
||||
except Exception as e:
|
||||
except Exception as e: # Corrected indentation
|
||||
logger.error(f"Fold {self.current_fold}: Unhandled error during backtester.run_backtest: {e}", exc_info=True)
|
||||
# Treat as failure, ensure metrics are None
|
||||
self.backtest_results_df = None
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user