2025-04-20 17:52:49 +00:00

129 lines
5.1 KiB
Python

import argparse
import logging
import os
import sys
# Make the project root (the parent of this script's directory) importable so
# the `src.*` imports below resolve when this file is run as a plain script.
script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.split(script_dir)[0]
if project_root not in sys.path:
    sys.path.insert(0, project_root)
# Import necessary components AFTER setting up path
from src.logger_setup import setup_logger # Correct function import
from src.io_manager import IOManager
from src.utils.run_id import make_run_id, get_git_sha # Import Git SHA function
from src.trading_pipeline import TradingPipeline # Keep pipeline import
# --- Version --- #
# Release identifier printed in the startup log banner.
__version__: str = "3.0.0-dev"
# --- Config Loading Helper --- #
def load_config(config_path: str) -> dict:
    """Load the YAML config file at ``config_path`` and return it as a dict.

    A relative ``config_path`` is resolved against the current working
    directory first and, if no file exists there, against ``project_root``.
    This runs before the logger is configured, so status and error messages
    are written with ``print``.

    Args:
        config_path: Absolute or relative path to the YAML config file.

    Returns:
        The parsed top-level mapping of the config file.

    Exits:
        Prints an error to stderr and calls ``sys.exit(1)`` if a relative path
        cannot be found in either location, or if the file cannot be opened,
        decoded or parsed, or does not parse to a dict.
    """
    import yaml

    # Logic similar to TradingPipeline._load_config, but simplified for entry point
    if not os.path.isabs(config_path):
        # Try relative to current dir first, then project root
        potential_path = os.path.abspath(config_path)
        if not os.path.exists(potential_path):
            potential_path = os.path.join(project_root, config_path)
        if not os.path.exists(potential_path):
            print(f"ERROR: Config file not found at '{config_path}' (tried CWD and project root).", file=sys.stderr)
            sys.exit(1)
        config_path = potential_path

    try:
        # Explicit UTF-8: YAML files are UTF-8 by convention, and the platform
        # default encoding (e.g. cp1252 on Windows) would mis-decode them.
        with open(config_path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)
        if not isinstance(config, dict):
            # Also covers an empty file, for which safe_load returns None.
            raise TypeError("Config file did not parse as a dictionary.")
    except Exception as e:
        print(f"ERROR: Failed to load or parse config file '{config_path}': {e}", file=sys.stderr)
        sys.exit(1)

    # Kept outside the try so a problem printing this status line is not
    # misreported as a config load/parse failure.
    print(f"Config loaded ✓ ({config_path})")  # Log before full logger setup
    return config
# --- Main Execution Block --- #
def _default_config_path() -> str:
    """Return the default ``--config`` value.

    The first existing file among ``<project_root>/config.yaml`` and
    ``<project_root>/gru_sac_predictor/config.yaml`` wins. If neither exists,
    the absolute path of ``config.yaml`` in the current working directory is
    returned; that file may not exist either.
    """
    candidates = (
        os.path.join(project_root, 'config.yaml'),
        os.path.join(project_root, 'gru_sac_predictor', 'config.yaml'),
    )
    for candidate in candidates:
        if os.path.exists(candidate):
            return candidate
    return os.path.abspath('config.yaml')


def main():
    """Main execution function: parses args, sets up, runs pipeline.

    Steps:
      1. Parse CLI arguments (``--config``, ``--use-ternary``).
      2. Create a run ID and look up the full Git commit SHA.
      3. Load the YAML config and apply CLI overrides to it.
      4. Build the ``IOManager`` and the logger for this run.
      5. Construct and execute the ``TradingPipeline``.

    Any failure ends the process with exit status 1.
    """
    parser = argparse.ArgumentParser(description="Run the GRU-SAC Trading Pipeline.")
    parser.add_argument(
        '--config', type=str, default=_default_config_path(),
        help="Path to the configuration YAML file (default attempts relative to project root, package dir, or CWD)"
    )
    parser.add_argument(
        '--use-ternary',
        action='store_true',
        help="Enable ternary (up/flat/down) direction labels instead of binary."
    )
    args = parser.parse_args()

    # 1. Generate Run ID and Get Git SHA
    run_id = make_run_id()
    git_sha = get_git_sha(short=False) or "unknown"

    # 2. Load config. load_config() reports and exits on its own for a missing
    # or unparseable file; this guard covers anything it raises outside its
    # handler (e.g. `import yaml` failing) so the exit is never silent.
    try:
        config = load_config(args.config)
    except Exception as e:
        print(f"ERROR: Failed to load config '{args.config}': {e}", file=sys.stderr)
        sys.exit(1)

    # 3. Apply CLI overrides before any other component receives the config,
    # so IOManager, logger and pipeline all see the final values even if one
    # of them copies the dict instead of keeping a reference.
    if args.use_ternary:
        # `is None` covers both a missing section and an empty `gru:` key,
        # which YAML parses as None.
        if config.get('gru') is None:
            config['gru'] = {}
        config['gru']['label_type'] = 'ternary'  # Override label type

    # 4. Setup IOManager (passing loaded config dict)
    try:
        io = IOManager(cfg=config, run_id=run_id)  # Pass config dict, not path
        # Add git_sha as an attribute AFTER initialization
        io.git_sha = git_sha
    except Exception as e:
        print(f"ERROR: Failed to initialize IOManager: {e}", file=sys.stderr)
        sys.exit(1)

    # 5. Setup Logger (using path from IOManager)
    logger = setup_logger(cfg=config, run_id=run_id, io=io)

    # Log Banner
    logger.info("="*80)
    logger.info(f" GRU-SAC Predictor {__version__} | Commit: {git_sha[:8]} | Run: {run_id}")
    logger.info(f" Config File: {os.path.basename(args.config)}")
    logger.info("="*80)

    if args.use_ternary:
        logger.warning("CLI override: Using ternary labels (--use-ternary).")

    # 6. Initialize and Run Pipeline
    logger.info("Initializing TradingPipeline...")
    # Tracks the current phase so a failure is attributed to initialization
    # or execution correctly in the log message.
    stage = "initialization"
    try:
        # Pass the loaded (and potentially modified) config dictionary directly
        pipeline = TradingPipeline(config=config, io_manager=io)
        logger.info("TradingPipeline initialized. Starting execution...")
        stage = "execution"
        pipeline.execute()
        logger.info("--- Pipeline Execution Finished ---")
    except SystemExit as e:
        # Any SystemExit raised inside the pipeline, whatever its code, is
        # treated as a failed run.
        logger.critical(f"Pipeline halted prematurely by SystemExit: {e}")
        sys.exit(1)
    except ValueError as e:
        logger.critical(f"Pipeline {stage} failed with ValueError: {e}", exc_info=True)
        sys.exit(1)
    except Exception as e:
        logger.critical(f"An unexpected error occurred during pipeline {stage}: {e}", exc_info=True)
        sys.exit(1)
# Run the CLI only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()