# Configuration for GRU-SAC Predictor

# --- Run Identification & Output ---
# Template for generating unique run IDs. '{timestamp}' will be replaced by
# YYYYMMDD_HHMMSS. Allows grouping results, logs, and models.
run_id_template: '{timestamp}'

# --- Base Directories (Task 0.1) --- #
# Every entry is a base directory relative to the package root.
base_dirs:
  results: 'results'
  logs: 'logs'
  models: 'models'
# --- End Base Directories --- #

# --- Output Settings (Task 0.1) --- #
output:
  # DPI for saved matplotlib figures.
  figure_dpi: 150
  # Default figure size (width, height in inches).
  figure_size: [16, 9]
  # Logging level (DEBUG, INFO, WARNING, ERROR).
  log_level: INFO
# --- End Output Settings --- #

# --- Data Parameters ---
data:
  # Path to the directory containing the market data database
  # (relative to where main.py is run).
  db_dir: '../data/crypto_market_data'
  # Name of the exchange table/data source in the database.
  exchange: 'bnbspot'
  # Instrument identifier (e.g., trading pair) within the exchange data.
  ticker: 'SOL-USDT'
  # Start date for loading data (YYYY-MM-DD).
  # Note: Ensure enough data for lookback + splits.
  start_date: '2025-03-01'
  # End date for loading data (YYYY-MM-DD).
  end_date: '2025-03-10'
  # Data frequency/interval (e.g., '1min', '5min', '1h').
  interval: '1min'

  # --- New Data Loader Params (v3 Rev) ---
  # NOTE(review): these four keys are assumed to live under `data`
  # (they describe DataLoader behaviour) - confirm against the loader code.
  # Task 1.1: Enable volatility-based sampling in DataLoader.
  vol_sampling: false
  # Task 1.1: Window size for volatility calculation.
  vol_window: 30
  # Task 1.1: Keep samples where vol > this quantile.
  vol_quantile: 0.5
  # Task 1.2: Apply label smoothing to binary targets
  # (0.0 = off, 0.1 = [0.05, 0.95]).
  label_smoothing: 0.0

# --- Feature Engineering Params ---
# (Placeholder for potential future config, like VIF skip - Task 2.5)
# features:
#   skip_vif: false

# --- Data Split ---
# The test ratio is calculated as 1.0 - train - validation.
# Ensure train + validation < 1.0.
split_ratios:
  # Proportion of the loaded data to use for training (0.0 to <1.0).
  train: 0.6
  # Proportion of the loaded data to use for validation (0.0 to <1.0).
  validation: 0.2

# --- GRU Model Parameters ---
gru:
  # General
  # How many steps ahead the model predicts.
  prediction_horizon: 5
  # Sequence length input to the GRU.
  lookback: 60

  # --- New Label/Version Params (v3 Rev) ---
  # Task 1.3: Use ternary (up/flat/down) labels instead of binary.
  use_ternary: false
  # Task 1.3: k for ternary flat threshold (eps = k * rolling_sigma_N).
  flat_sigma_multiplier: 0.25

  # --- v2 Specific Params (Legacy) ---
  # Max training epochs (used if v2).
  epochs: 25
  # Batch size (used if v2).
  batch_size: 256
  # Early stopping patience (used if v2).
  patience: 5
  # Run ID to load pre-trained v2 model from (if train_gru=false,
  # use_v3=false). Example value: '20250417_173635'.
  model_load_run_id: null

  # v2 Loss Weighting (Deprecated?)
  recency_weighting:
    enabled: false  # true
    linear_start: 0.2
    linear_end: 1.0
  # NOTE(review): the two keys below are assumed to be siblings of
  # `recency_weighting` (direct children of `gru`) - confirm against the code
  # that reads them.
  signed_weighting_beta: 0.0
  composite_loss_kappa: 0.0

# --- GRU v3 Model Specific Parameters (v3 Rev) --- #
gru_v3:
  # Architecture (Task 3.1)
  gru_units: 96
  attention_units: 16

  # Training (Task 3.4 - these replace v2 equivalents when use_v3=true)
  epochs: 30
  batch_size: 128
  patience: 5
  # Run ID to load pre-trained v3 model from (if train_gru=false, use_v3=true).
  model_load_run_id: null

  # Compilation (Task 3.3 / 3.4)
  # Written with an explicit decimal point: YAML 1.1 loaders (e.g. PyYAML)
  # read a bare `1e-4` as the string '1e-4', not a float.
  learning_rate: 1.0e-4
  # Gamma for Categorical Focal Crossentropy (dir3 head).
  focal_gamma: 2.0
  # Label smoothing for Focal Loss (passed to loss func).
  focal_label_smoothing: 0.1
  # Delta for Huber loss (mu head).
  huber_delta: 1.0
  # Weight for the mu head loss.
  loss_weight_mu: 0.3
  # Weight for the dir3 head loss.
  loss_weight_dir3: 1.0

# --- Calibration Parameters ---
calibration:
  # Task 4.2: Calibration method: 'temperature' or 'vector'.
  method: 'temperature'
  # Edge threshold |2p-1| for edge_filtered_accuracy & binary action signal
  # (e.g., 0.1 => p>0.55 or p<0.45).
  edge_threshold: 0.1
  # Recalibrate Temperature every N steps during backtest (0=disable).
  recalibrate_every_n: 0
  # Window size for rolling recalibration.
  recalibration_window: 10000

# --- SAC Agent Parameters ---
sac:
  # Env state dimension (should match TradingEnv).
  state_dim: 5
  # Hidden layer size in actor/critic networks.
  hidden_size: 64
  # Discount factor.
  gamma: 0.97
  # Target network update rate.
  tau: 0.005
  # Initial learning rate for actor/critic/alpha optimizers.
  # Written with an explicit decimal point: YAML 1.1 loaders (e.g. PyYAML)
  # read a bare `3e-4` as the string '3e-4', not a float.
  actor_lr: 3.0e-4
  # Decay rate for LR scheduler.
  lr_decay_rate: 0.96
  # Decay steps for LR scheduler.
  decay_steps: 100000
  # Max size of the replay buffer.
  buffer_max_size: 100000
  # OU Noise standard deviation.
  ou_noise_stddev: 0.2
  # OU Noise theta parameter.
  ou_noise_theta: 0.15
  # OU Noise dt parameter.
  ou_noise_dt: 0.01
  # Initial alpha (entropy coefficient).
  alpha: 0.2
  # Automatically tune alpha?
  alpha_auto_tune: true
  # Task 5.3: Target entropy. If null/default (-action_dim) & auto_tune=true,
  # calculates -0.5*log(4). Otherwise uses value.
  target_entropy: null
  # Use Batch Normalization in actor/critic?
  use_batch_norm: true
  # Total steps for SAC training.
  total_training_steps: 120000
  # Minimum experiences in buffer before training starts.
  min_buffer_size: 10000
  # Batch size for sampling from replay buffer.
  batch_size: 256
  # Log training metrics every N steps.
  log_interval: 1000
  # Save agent checkpoints every N steps.
  save_interval: 10000

  # --- New SAC Params (v3 Rev) ---
  # Task 5.2: Normalize environment states using MeanStdFilter.
  use_state_filter: true
  # Task 5.5: Share of the buffer to pre-fill with heuristic actions
  # (0.0 to disable).
  # NOTE(review): the key says "pct" but 0.2 looks like a fraction (20%) -
  # confirm how the consumer interprets it.
  oracle_seeding_pct: 0.2

# --- Environment Parameters (Used by train_sac.py & backtester.py) ---
environment:
  # Notional capital for env/backtest consistency.
  initial_capital: 10000.0
  # Fractional cost per trade (e.g., 0.0005 = 0.05%).
  transaction_cost: 0.0005

  # --- New Env Params (v3 Rev) ---
  # Task 5.1: Multiplier applied to the raw environment reward.
  reward_scale: 100.0
  # Task 5.4: Coefficient (lambda) for action magnitude penalty
  # (reward -= lambda * action^2).
  action_penalty_lambda: 0.0

# --- Backtesting Parameters ---
# (initial_capital, transaction_cost now primarily controlled under 'environment')
# backtest:
#   initial_capital: 10000.0 # Deprecated: Use environment.initial_capital.
#   transaction_cost: 0.0005 # Deprecated: Use environment.transaction_cost.

# --- Experience Generation (Simplified for config) ---
# Configuration for how experiences are generated or sampled for SAC training.
# (Currently only 'generate_new_on_epoch' is directly used from here in main.py)
experience:
  # If true, generate fresh experiences using validation data at the start of
  # each SAC epoch. If false, generate experiences once initially.
  generate_new_on_epoch: false

# --- Control Flags ---
# Determine which parts of the pipeline to run.
# NOTE(review): every key down to `generate_plots` is assumed to be a child of
# `control` (the "Other Pipeline Controls" heading suggests so) - confirm
# against the code that reads this section.
control:
  # --- Model Version Control (Task 3.5) ---
  # Use GRU v3 model/logic? If false, uses v2.
  use_v3: true
  # --- End Version Control --- #

  # Train the selected GRU model? (v2 or v3 based on use_v3).
  train_gru: true
  # Run the offline SAC training script before backtesting?
  train_sac: true

  # --- SAC Loading/Resuming ---
  # For resuming training in train_sac.py:
  # Run ID of SAC agent to load *before* starting training
  # (e.g., "sac_train_..."). If null, starts fresh.
  sac_resume_run_id: null
  # Checkpoint step to resume from: 'final' or step number.
  sac_resume_step: 'final'
  # For loading agent for backtesting in run_pipeline.py:
  # Run ID of the SAC training run to load weights from for *backtesting*
  # (e.g., "sac_train_..."). If null, uses initial weights.
  sac_load_run_id: null
  # Which SAC checkpoint to load for backtesting: 'final' or step number.
  sac_load_step: 'final'

  # --- Other Pipeline Controls ---
  # Run the backtest?
  run_backtest: true
  # Generate output plots?
  generate_plots: true
  # generate_report: True # Deprecated: Metrics are saved to a .txt file.