# Source listing metadata: 129 lines, 5.1 KiB, Python
import argparse
import logging
import os
import sys

# Make the project root (the parent of this script's directory) importable,
# so that the `src.*` package imports further down resolve regardless of the
# directory the script is launched from.
script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(script_dir)
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# Import project components only AFTER project_root has been put on sys.path
# above; these are absolute `src.*` imports and rely on that path entry.
from src.logger_setup import setup_logger  # Correct function import
from src.io_manager import IOManager
from src.utils.run_id import make_run_id, get_git_sha  # Import Git SHA function
from src.trading_pipeline import TradingPipeline  # Keep pipeline import

# --- Define Version --- #
# Shown in the startup log banner emitted by main().
__version__ = "3.0.0-dev"

# --- Config Loading Helper --- #
def load_config(config_path: str) -> dict:
    """Load a YAML config file and return its contents as a dictionary.

    A relative ``config_path`` is resolved against the current working
    directory first and, if nothing exists there, against ``project_root``.
    An absolute path is used as given.

    Args:
        config_path: Absolute or relative path to the YAML configuration file.

    Returns:
        The parsed configuration mapping.

    Note:
        This helper never raises for load failures. It prints an error to
        stderr and calls ``sys.exit(1)`` when a relative path cannot be found
        in either location, when the file cannot be opened or parsed, or when
        the parsed top-level value is not a dictionary (an empty file parses
        as ``None`` and is rejected the same way).
    """
    import yaml

    # Logic similar to TradingPipeline._load_config, but simplified for entry point
    if not os.path.isabs(config_path):
        # Try relative to current dir first, then project root
        potential_path = os.path.abspath(config_path)
        if not os.path.exists(potential_path):
            potential_path = os.path.join(project_root, config_path)

        if os.path.exists(potential_path):
            config_path = potential_path
        else:
            print(f"ERROR: Config file not found at '{config_path}' (tried CWD and project root).", file=sys.stderr)
            sys.exit(1)

    try:
        # Decode explicitly as UTF-8 instead of relying on the platform's
        # default locale encoding, which can mis-decode non-ASCII YAML.
        with open(config_path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)
        if not isinstance(config, dict):
            raise TypeError("Config file did not parse as a dictionary.")
        print(f"Config loaded ✓ ({config_path})")  # Log before full logger setup
        return config
    except Exception as e:
        # Deliberately broad: this is the entry-point boundary, and any failure
        # to obtain a usable config is fatal and reported in one place.
        print(f"ERROR: Failed to load or parse config file '{config_path}': {e}", file=sys.stderr)
        sys.exit(1)

# --- Main Execution Block --- #
def main():
    """Parse CLI arguments, set up run infrastructure, and run the pipeline.

    Steps, in order:
      1. Generate a run id and look up the Git commit SHA.
      2. Load the YAML config named by ``--config``.
      3. Build the ``IOManager`` from the config dict and run id.
      4. Build the logger and log a startup banner.
      5. Apply CLI overrides (``--use-ternary``) to the config dict.
      6. Construct ``TradingPipeline`` and call ``execute()``.

    The process exits with status 1 if config loading, ``IOManager``
    construction, or pipeline construction/execution fails.
    """
    parser = argparse.ArgumentParser(description="Run the GRU-SAC Trading Pipeline.")

    # Default config path seeking strategy: project root, then the package
    # directory, and finally the current working directory as a last resort.
    default_config_rel_root = os.path.join(project_root, 'config.yaml')
    default_config_pkg = os.path.join(project_root, 'gru_sac_predictor', 'config.yaml')
    default_config_cwd = os.path.abspath('config.yaml')

    if os.path.exists(default_config_rel_root):
        default_config = default_config_rel_root
    elif os.path.exists(default_config_pkg):
        default_config = default_config_pkg
    else:
        default_config = default_config_cwd

    parser.add_argument(
        '--config', type=str, default=default_config,
        help="Path to the configuration YAML file (default attempts relative to project root, package dir, or CWD)"
    )
    parser.add_argument(
        '--use-ternary',
        action='store_true',
        help="Enable ternary (up/flat/down) direction labels instead of binary."
    )
    args = parser.parse_args()

    # 1. Generate Run ID and Get Git SHA
    run_id = make_run_id()
    git_sha = get_git_sha(short=False) or "unknown"

    # 2. Load Config first
    try:
        config = load_config(args.config)  # Load config dictionary
    except Exception:
        # load_config reports its own errors and exits via SystemExit, which
        # this handler does not intercept; this only guards against anything
        # unexpected escaping from it.
        sys.exit(1)  # Exit if config loading fails

    # 3. Setup IOManager (passing loaded config dict)
    try:
        io = IOManager(cfg=config, run_id=run_id)  # Pass config dict, not path
        # Add git_sha as an attribute AFTER initialization
        io.git_sha = git_sha
    except Exception as e:
        # The logger does not exist yet, so report on stderr like the other
        # startup errors in this script.
        print(f"ERROR: Failed to initialize IOManager: {e}", file=sys.stderr)
        sys.exit(1)

    # 4. Setup Logger (using path from IOManager)
    logger = setup_logger(cfg=config, run_id=run_id, io=io)  # Pass config dict here too (use cfg=)

    # Log Banner
    logger.info("="*80)
    logger.info(f" GRU-SAC Predictor {__version__} | Commit: {git_sha[:8]} | Run: {run_id}")
    logger.info(f" Config File: {os.path.basename(args.config)}")
    logger.info("="*80)

    # 5. Modify config based on CLI args (if any)
    if args.use_ternary:
        # Ensure a 'gru' section exists. An empty `gru:` key in YAML parses
        # as None, so treat that the same as a missing section.
        if config.get('gru') is None:
            config['gru'] = {}
        config['gru']['label_type'] = 'ternary'  # Override label type
        logger.warning("CLI override: Using ternary labels (--use-ternary).")

    # 6. Initialize and Run Pipeline
    logger.info("Initializing TradingPipeline...")
    try:
        # Pass the loaded (and potentially modified) config dictionary directly
        pipeline = TradingPipeline(config=config, io_manager=io)
        logger.info("TradingPipeline initialized. Starting execution...")
        pipeline.execute()
        logger.info("--- Pipeline Execution Finished ---")
    except SystemExit as e:
        # Any SystemExit raised inside the pipeline is logged and converted
        # to exit status 1.
        logger.critical(f"Pipeline halted prematurely by SystemExit: {e}")
        sys.exit(1)
    except ValueError as e:
        logger.critical(f"Pipeline initialization failed with ValueError: {e}")
        sys.exit(1)
    except Exception as e:
        logger.critical(f"An unexpected error occurred during pipeline execution: {e}", exc_info=True)
        sys.exit(1)

# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()