iundentation fix pipeline
This commit is contained in:
parent
a02dd4342e
commit
a86cdb2c8f
@ -1357,8 +1357,8 @@ class TradingPipeline:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Fold {self.current_fold}: Error during temperature calibration: {e}", exc_info=True)
|
logger.error(f"Fold {self.current_fold}: Error during temperature calibration: {e}", exc_info=True)
|
||||||
self.optimal_T = None
|
self.optimal_T = None
|
||||||
else:
|
else: # Covers cases where calibration method is wrong or mismatch with ternary state
|
||||||
logger.warning(f"Fold {self.current_fold}: Calibration method '{calibration_method}' or ternary state mismatch. Skipping GRU validation checks.")
|
logger.warning(f"Fold {self.current_fold}: Calibration method '{calibration_method}' or ternary state mismatch ({is_ternary_check}). Skipping GRU validation checks.")
|
||||||
self.optimal_T = None
|
self.optimal_T = None
|
||||||
self.vector_cal_params = None
|
self.vector_cal_params = None
|
||||||
|
|
||||||
@ -1401,16 +1401,30 @@ class TradingPipeline:
|
|||||||
thresh_file = f"optimized_edge_threshold_fold_{self.current_fold}.txt"
|
thresh_file = f"optimized_edge_threshold_fold_{self.current_fold}.txt"
|
||||||
try:
|
try:
|
||||||
# Ensure fold_results_dir is defined (should be available from context)
|
# Ensure fold_results_dir is defined (should be available from context)
|
||||||
if self.io and fold_results_dir:
|
# Assuming fold_results_dir is defined in the outer scope
|
||||||
|
if self.io and 'fold_results_dir' in locals() and fold_results_dir:
|
||||||
self.io.save_json({'optimized_edge_threshold': self.optimized_edge_threshold},
|
self.io.save_json({'optimized_edge_threshold': self.optimized_edge_threshold},
|
||||||
thresh_file.replace('.txt',''), # Use filename as key for io
|
thresh_file.replace('.txt',''), # Use filename as key for io
|
||||||
base_dir=fold_results_dir, use_txt=True)
|
base_dir=fold_results_dir, use_txt=True)
|
||||||
logger.info(f"Saved optimized edge threshold to {thresh_file}")
|
logger.info(f"Saved optimized edge threshold to {os.path.join(fold_results_dir, thresh_file)}")
|
||||||
elif self.io:
|
elif self.io:
|
||||||
|
# Fallback: Save to the main results directory if fold_results_dir isn't available
|
||||||
|
# Construct the path for logging clarity
|
||||||
|
fallback_path = os.path.join(self.io.get_section_path('results'), thresh_file)
|
||||||
self.io.save_json({'optimized_edge_threshold': self.optimized_edge_threshold},
|
self.io.save_json({'optimized_edge_threshold': self.optimized_edge_threshold},
|
||||||
thresh_file.replace('.txt',''),
|
thresh_file.replace('.txt',''),
|
||||||
section='results', use_txt=True) # Fallback save
|
section='results', use_txt=True) # Fallback save
|
||||||
logger.info(f"Saved optimized edge threshold to main results dir: {thresh_file}")
|
logger.info(f"Saved optimized edge threshold to main results dir: {fallback_path}")
|
||||||
|
else:
|
||||||
|
logger.warning("IOManager not available, cannot save optimized threshold.")
|
||||||
|
except NameError: # Specifically catch if fold_results_dir is not defined
|
||||||
|
logger.warning("fold_results_dir not defined. Attempting fallback save to main results dir.")
|
||||||
|
if self.io:
|
||||||
|
fallback_path = os.path.join(self.io.get_section_path('results'), thresh_file)
|
||||||
|
self.io.save_json({'optimized_edge_threshold': self.optimized_edge_threshold},
|
||||||
|
thresh_file.replace('.txt',''),
|
||||||
|
section='results', use_txt=True)
|
||||||
|
logger.info(f"Saved optimized edge threshold to main results dir: {fallback_path}")
|
||||||
else:
|
else:
|
||||||
logger.warning("IOManager not available, cannot save optimized threshold.")
|
logger.warning("IOManager not available, cannot save optimized threshold.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -1467,7 +1481,6 @@ class TradingPipeline:
|
|||||||
calib_config = self.config.get('calibration', {})
|
calib_config = self.config.get('calibration', {})
|
||||||
optimize_edge = calib_config.get('optimize_edge_threshold', False)
|
optimize_edge = calib_config.get('optimize_edge_threshold', False)
|
||||||
edge_thr_config = calib_config.get('edge_threshold', 0.1) # Default/fallback
|
edge_thr_config = calib_config.get('edge_threshold', 0.1) # Default/fallback
|
||||||
|
|
||||||
self.fold_edge_threshold = edge_thr_config # Initialize with config value
|
self.fold_edge_threshold = edge_thr_config # Initialize with config value
|
||||||
|
|
||||||
if optimize_edge and not is_ternary:
|
if optimize_edge and not is_ternary:
|
||||||
@ -1567,8 +1580,8 @@ class TradingPipeline:
|
|||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
# Use sys.exit with a specific message for clarity
|
# Use sys.exit with a specific message for clarity
|
||||||
sys.exit(f"Fold {self.current_fold}: GRU validation gates failed (Edge Acc / Brier Score).")
|
sys.exit(f"Fold {self.current_fold}: GRU validation gates failed (Edge Acc / Brier Score).")
|
||||||
else:
|
else: # Corrected indentation
|
||||||
logger.info(f"Fold {self.current_fold}: GRU validation checks passed (Edge Acc & Brier Score).")
|
logger.info(f"Fold {self.current_fold}: GRU validation checks passed (Edge Acc & Brier Score).") # Corrected indentation
|
||||||
# --- End Validation Helper --- #
|
# --- End Validation Helper --- #
|
||||||
|
|
||||||
def train_or_load_sac(self):
|
def train_or_load_sac(self):
|
||||||
@ -1791,6 +1804,7 @@ class TradingPipeline:
|
|||||||
# Get raw predictions needed for rolling calibration
|
# Get raw predictions needed for rolling calibration
|
||||||
p_raw_test_for_bt = None
|
p_raw_test_for_bt = None
|
||||||
logits_test_for_bt = None
|
logits_test_for_bt = None
|
||||||
|
is_ternary = self.config.get('gru', {}).get('use_ternary_output', False) # Need to know if ternary
|
||||||
if self.config.get('calibration', {}).get('rolling_enabled', False):
|
if self.config.get('calibration', {}).get('rolling_enabled', False):
|
||||||
logger.info(f"Fold {self.current_fold}: Getting raw GRU outputs for rolling calibration...")
|
logger.info(f"Fold {self.current_fold}: Getting raw GRU outputs for rolling calibration...")
|
||||||
if is_ternary:
|
if is_ternary:
|
||||||
@ -1798,18 +1812,18 @@ class TradingPipeline:
|
|||||||
if logits_test_for_bt is None:
|
if logits_test_for_bt is None:
|
||||||
logger.error(f"Fold {self.current_fold}: Failed to get GRU logits for rolling calibration.")
|
logger.error(f"Fold {self.current_fold}: Failed to get GRU logits for rolling calibration.")
|
||||||
raise SystemExit(f"Fold {self.current_fold}: Failed GRU logit prediction.")
|
raise SystemExit(f"Fold {self.current_fold}: Failed GRU logit prediction.")
|
||||||
else:
|
else: # Corrected indentation
|
||||||
preds_test_raw = self.gru_handler.predict(self.X_test_seq)
|
preds_test_raw = self.gru_handler.predict(self.X_test_seq)
|
||||||
if preds_test_raw is None or len(preds_test_raw) < 3:
|
if preds_test_raw is None or len(preds_test_raw) < 3:
|
||||||
logger.error(f"Fold {self.current_fold}: Failed to get GRU raw predictions for rolling calibration.")
|
logger.error(f"Fold {self.current_fold}: Failed to get GRU raw predictions for rolling calibration.")
|
||||||
raise SystemExit(f"Fold {self.current_fold}: Failed GRU raw prediction.")
|
raise SystemExit(f"Fold {self.current_fold}: Failed GRU raw prediction.")
|
||||||
p_raw_test_for_bt = preds_test_raw[2].flatten()
|
p_raw_test_for_bt = preds_test_raw[2].flatten() # Assuming index 2 is probabilities
|
||||||
|
|
||||||
# Get the edge threshold determined during validation (optimized or fixed)
|
# Get the edge threshold determined during validation (optimized or fixed)
|
||||||
edge_threshold_for_bt = getattr(self, 'fold_edge_threshold', self.config.get('calibration', {}).get('edge_threshold', 0.1))
|
edge_threshold_for_bt = getattr(self, 'fold_edge_threshold', self.config.get('calibration', {}).get('edge_threshold', 0.1))
|
||||||
logger.info(f"Fold {self.current_fold}: Using edge threshold {edge_threshold_for_bt:.4f} for backtest execution.")
|
logger.info(f"Fold {self.current_fold}: Using edge threshold {edge_threshold_for_bt:.4f} for backtest execution.")
|
||||||
|
|
||||||
try:
|
try: # Corrected indentation
|
||||||
self.backtest_results_df, self.backtest_metrics, self.metrics_log_df = self.backtester.run_backtest(
|
self.backtest_results_df, self.backtest_metrics, self.metrics_log_df = self.backtester.run_backtest(
|
||||||
sac_agent_load_path=self.sac_agent_load_path,
|
sac_agent_load_path=self.sac_agent_load_path,
|
||||||
X_test_seq=self.X_test_seq,
|
X_test_seq=self.X_test_seq,
|
||||||
@ -1830,11 +1844,11 @@ class TradingPipeline:
|
|||||||
is_ternary=self.use_ternary,
|
is_ternary=self.use_ternary,
|
||||||
fold_num=self.current_fold
|
fold_num=self.current_fold
|
||||||
)
|
)
|
||||||
except SystemExit as e:
|
except SystemExit as e: # Corrected indentation
|
||||||
# Catch exits from backtester validation/execution
|
# Catch exits from backtester validation/execution
|
||||||
logger.error(f"Fold {self.current_fold}: Backtest aborted: {e}")
|
logger.error(f"Fold {self.current_fold}: Backtest aborted: {e}")
|
||||||
raise # Re-raise to stop the fold
|
raise # Re-raise to stop the fold
|
||||||
except Exception as e:
|
except Exception as e: # Corrected indentation
|
||||||
logger.error(f"Fold {self.current_fold}: Unhandled error during backtester.run_backtest: {e}", exc_info=True)
|
logger.error(f"Fold {self.current_fold}: Unhandled error during backtester.run_backtest: {e}", exc_info=True)
|
||||||
# Treat as failure, ensure metrics are None
|
# Treat as failure, ensure metrics are None
|
||||||
self.backtest_results_df = None
|
self.backtest_results_df = None
|
||||||
@ -2742,7 +2756,7 @@ class TradingPipeline:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error during SAC agent weight averaging or saving: {e}", exc_info=True)
|
logger.error(f"Error during SAC agent weight averaging or saving: {e}", exc_info=True)
|
||||||
|
|
||||||
# --- Entry Point --- #
|
# --- Entry Point --- #
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Run the GRU-SAC Trading Pipeline.")
|
parser = argparse.ArgumentParser(description="Run the GRU-SAC Trading Pipeline.")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user