From c3526bb9f6cac83fd07b7c13e70672b1bd066935 Mon Sep 17 00:00:00 2001
From: yasha
Date: Sat, 19 Apr 2025 02:16:24 +0000
Subject: [PATCH] feat: Implement imputed bar handling in TradingEnv (skip, hold, penalty) and tests

---
 gru_sac_predictor/src/trading_env.py        |  77 ++++++++-
 gru_sac_predictor/tests/test_trading_env.py | 166 ++++++++++++++++++++
 2 files changed, 236 insertions(+), 7 deletions(-)
 create mode 100644 gru_sac_predictor/tests/test_trading_env.py

diff --git a/gru_sac_predictor/src/trading_env.py b/gru_sac_predictor/src/trading_env.py
index a0c0aade..4582baf4 100644
--- a/gru_sac_predictor/src/trading_env.py
+++ b/gru_sac_predictor/src/trading_env.py
@@ -6,6 +6,8 @@ Uses pre-calculated GRU predictions (mu, sigma, p_cal) and actual returns.
 import numpy as np
 import pandas as pd
 import logging
+import gymnasium as gym
+from omegaconf import DictConfig # Added for config typing
 
 env_logger = logging.getLogger(__name__)
 
@@ -15,6 +17,8 @@ class TradingEnv:
         sigma_predictions: np.ndarray,
         p_cal_predictions: np.ndarray,
         actual_returns: np.ndarray,
+        bar_imputed_flags: np.ndarray, # Added imputed flags
+        config: DictConfig, # Added config
         initial_capital: float = 10000.0,
         transaction_cost: float = 0.0005,
         reward_scale: float = 100.0,
@@ -27,18 +31,22 @@ class TradingEnv:
             sigma_predictions: Predicted volatility (σ̂ = exp(log σ̂)).
             p_cal_predictions: Calibrated probability of price increase (p_cal).
             actual_returns: Actual log returns (y_ret).
+            bar_imputed_flags: Boolean array indicating whether each bar was imputed.
+            config: OmegaConf configuration object (reads sac.imputed_handling and sac.action_penalty).
             initial_capital: Starting capital for simulation (used notionally in reward).
             transaction_cost: Fractional cost per trade.
             reward_scale: Multiplier for the reward signal.
             action_penalty_lambda: Coefficient for the action magnitude penalty (λ).
""" - assert len(mu_predictions) == len(sigma_predictions) == len(p_cal_predictions) == len(actual_returns), \ - "All input arrays must have the same length" + assert len(mu_predictions) == len(sigma_predictions) == len(p_cal_predictions) == len(actual_returns) == len(bar_imputed_flags), \ + "All input arrays (predictions, returns, imputed_flags) must have the same length" self.mu = mu_predictions self.sigma = sigma_predictions self.p_cal = p_cal_predictions self.actual_returns = actual_returns + self.bar_imputed = bar_imputed_flags.astype(bool) # Store imputed flags + self.config = config # Store config self.initial_capital = initial_capital self.transaction_cost = transaction_cost @@ -65,20 +73,36 @@ class TradingEnv: self.state_dim = 5 self.action_dim = 1 + # --- Define Gym Spaces --- + self.action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(self.action_dim,), dtype=np.float32) + self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.state_dim,), dtype=np.float32) + # --- End Define Gym Spaces --- + env_logger.info(f"TradingEnv initialized with {self.n_steps} steps.") def _get_state(self) -> np.ndarray: """Construct the state vector for the current step.""" if self.current_step >= self.n_steps: - # Handle episode end - return a dummy state or zeros return np.zeros(self.state_dim, dtype=np.float32) mu_t = self.mu[self.current_step] sigma_t = self.sigma[self.current_step] p_cal_t = self.p_cal[self.current_step] - edge_t = 2 * p_cal_t - 1 - z_score_t = np.abs(mu_t) / (sigma_t + 1e-9) + # Calculate edge based on p_cal shape (binary vs ternary) + if isinstance(p_cal_t, (np.ndarray, list)) and len(p_cal_t) == 3: + # Ternary: edge = max(P(up), P(down)) - P(flat) + # Assuming order [Down, Flat, Up] for p_cal_t + edge_t = max(p_cal_t[2], p_cal_t[0]) - p_cal_t[1] + elif isinstance(p_cal_t, (float, np.number)): + # Binary: edge = 2 * P(up) - 1 + edge_t = 2 * p_cal_t - 1 + else: + env_logger.error(f"Unexpected type/shape for p_cal_t at step {self.current_step}: {p_cal_t}. 
+            edge_t = 0.0
+
+        _EPS = 1e-9 # Define epsilon locally
+        z_score_t = np.abs(mu_t) / (sigma_t + _EPS)
 
         # State uses position *before* the action for this step is taken
         state = np.array([
@@ -108,11 +132,48 @@ class TradingEnv:
         Returns:
             tuple: (next_state, reward, done, info_dict)
         """
+        info = {'capital': self.current_capital, 'position': self.current_position, 'is_imputed_step_skipped': False}
+
         if self.current_step >= self.n_steps:
             # Should not happen if 'done' is handled correctly, but as safeguard
             env_logger.warning("Step called after environment finished.")
-            return self._get_state(), 0.0, True, {}
+            return self._get_state(), 0.0, True, info
 
+        # --- Handle Imputed Bar --- #
+        imputed = self.bar_imputed[self.current_step]
+        if imputed:
+            mode = self.config.sac.imputed_handling
+            env_logger.debug(f"SAC step {self.current_step} on imputed bar: handling={mode}")
+            if mode == "skip":
+                self.current_step += 1
+                next_state = self._get_state() # Get state for the *next* actual step
+                # Return 0 reward, not done, but indicate skip for buffer handling
+                info['is_imputed_step_skipped'] = True
+                return next_state, 0.0, False, info
+            elif mode == "hold":
+                # Action is forced to maintain current position
+                action = self.current_position
+            elif mode == "penalty":
+                # Calculate reward penalty based on config
+                target_position_penalty = np.clip(action, -1.0, 1.0)
+                reward = -self.config.sac.action_penalty * (target_position_penalty - self.current_position)**2
+                # Update position based on agent's intended action (clipped)
+                self.current_position = target_position_penalty
+                # No PnL is credited for the imputed bar (its return is synthetic)
+                # and, for simplicity, no transaction cost is charged even though
+                # the position changes; the quadratic penalty above is the only
+                # reward component for this step. Capital is left unchanged.
+                self.current_step += 1
+                next_state = self._get_state()
+                scaled_reward = reward * self.reward_scale # Scale the penalty
+                done = self.current_step >= self.n_steps
+                info['capital'] = self.current_capital
+                info['position'] = self.current_position
+                return next_state, scaled_reward, done, info
+            # else: default behavior (treat as normal bar) - implicitly handled by falling through
+        # --- End Handle Imputed Bar --- #
+
+        # --- Normal Step Logic (if not imputed or handling mode allows fallthrough like 'hold') --- #
         # Action is the TARGET position for the *end* of this step
         target_position = np.clip(action, -1.0, 1.0)
         trade_size = target_position - self.current_position
@@ -150,7 +211,9 @@ class TradingEnv:
 
         done = self.current_step >= self.n_steps or self.current_capital <= 0
         next_state = self._get_state()
-        info = {'capital': self.current_capital, 'position': self.current_position}
+        # Update info dict (capital/position might have changed in normal step)
+        info['capital'] = self.current_capital
+        info['position'] = self.current_position
 
         # Log step details periodically
         # if self.current_step % 1000 == 0:
diff --git a/gru_sac_predictor/tests/test_trading_env.py b/gru_sac_predictor/tests/test_trading_env.py
new file mode 100644
index 00000000..8af07e6b
--- /dev/null
+++ b/gru_sac_predictor/tests/test_trading_env.py
@@ -0,0 +1,166 @@
+import pytest
+import numpy as np
+from omegaconf import OmegaConf
+
+# Adjust import path based on structure
+from gru_sac_predictor.src.trading_env import TradingEnv
+
+# --- Test Fixtures ---
+
+@pytest.fixture
+def sample_env_data():
+    """Provides sample data for initializing the TradingEnv."""
+    n_steps = 10
+    data = {
+        'mu_predictions': np.random.randn(n_steps) * 0.001,
+        'sigma_predictions': np.abs(np.random.randn(n_steps) * 0.002 + 0.005),
+        'p_cal_predictions': np.random.rand(n_steps),
+        'actual_returns': np.random.randn(n_steps) * 0.0015,
+        'bar_imputed_flags': np.array([False, False, True, False, True, True, False, False, True, False], dtype=bool)
+    }
+    return data
+
+@pytest.fixture
+def base_env_config():
+    """Base configuration for the environment."""
+    return OmegaConf.create({
+        'sac': {
+            'imputed_handling': 'skip', # Default test mode
+            'action_penalty': 0.05
+        },
+        'environment': {
+            'initial_capital': 10000.0,
+            'transaction_cost': 0.0005,
+            'reward_scale': 100.0,
+            'action_penalty_lambda': 0.0 # Usually overridden by transaction_cost calc
+        }
+    })
+
+@pytest.fixture
+def trading_env_instance(sample_env_data, base_env_config):
+    """Creates a TradingEnv instance with default 'skip' mode."""
+    return TradingEnv(**sample_env_data, config=base_env_config)
+
+# --- Test Functions ---
+
+def test_env_initialization(trading_env_instance, sample_env_data):
+    assert trading_env_instance.n_steps == len(sample_env_data['actual_returns'])
+    assert trading_env_instance.current_step == 0
+    assert trading_env_instance.current_position == 0.0
+    assert np.array_equal(trading_env_instance.bar_imputed, sample_env_data['bar_imputed_flags'])
+
+def test_env_reset(trading_env_instance):
+    # Take a few steps
+    trading_env_instance.step(0.5)
+    trading_env_instance.step(-0.2)
+    assert trading_env_instance.current_step > 0
+    # Reset
+    initial_state = trading_env_instance.reset()
+    assert trading_env_instance.current_step == 0
+    assert trading_env_instance.current_position == 0.0
+    assert initial_state is not None
+    assert initial_state.shape == (trading_env_instance.state_dim,)
+
+def test_env_step_normal(trading_env_instance):
+    # Test a normal step (step 0 is not imputed)
+    initial_pos = trading_env_instance.current_position
+    action = 0.7
+    next_state, reward, done, info = trading_env_instance.step(action)
+
+    assert trading_env_instance.current_step == 1
+    assert trading_env_instance.current_position == action # Position updates to action
+    assert not info['is_imputed_step_skipped']
+    assert not done
+    assert next_state is not None
+    # Reward calculation is complex, just check type/sign if needed
+    assert isinstance(reward, float)
+
+def test_env_step_imputed_skip(trading_env_instance, sample_env_data):
+    # Step 2 is imputed in sample_env_data
+    trading_env_instance.step(0.5) # Step 0
+    trading_env_instance.step(0.6) # Step 1
+    assert trading_env_instance.current_step == 2
+    initial_pos_before_imputed = trading_env_instance.current_position
+
+    # Action for the imputed step (should be ignored by 'skip')
+    action_imputed = 0.9
+    next_state, reward, done, info = trading_env_instance.step(action_imputed)
+
+    # Should skip step 2 and now be at step 3
+    assert trading_env_instance.current_step == 3
+    # Position should NOT have changed from step 1
+    assert trading_env_instance.current_position == initial_pos_before_imputed
+    assert reward == 0.0 # Skip gives 0 reward
+    assert not done
+    assert info['is_imputed_step_skipped'] == True # Crucial check for buffer
+    # Check that the returned state is for step 3
+    expected_state_step_3 = trading_env_instance._get_state() # Get state now that we are at step 3
+    np.testing.assert_array_almost_equal(next_state, expected_state_step_3)
+
+def test_env_step_imputed_hold(sample_env_data, base_env_config):
+    cfg = base_env_config.copy()
+    cfg.sac.imputed_handling = 'hold'
+    env = TradingEnv(**sample_env_data, config=cfg)
+
+    # Step 2 is imputed
+    env.step(0.5) # Step 0
+    env.step(0.6) # Step 1
+    assert env.current_step == 2
+    position_before_imputed = env.current_position
+
+    # Action for the imputed step (should be overridden by 'hold')
+    action_imputed = -0.5
+    next_state, reward, done, info = env.step(action_imputed)
+
+    # Should process step 2 and move to step 3
+    assert env.current_step == 3
+    # Position should be the same as before the step
+    assert env.current_position == position_before_imputed
+    assert not info['is_imputed_step_skipped']
+    assert not done
+    # Reward should be calculated based on holding the position
+    expected_pnl = position_before_imputed * (np.exp(sample_env_data['actual_returns'][2]) - 1)
+    expected_cost = 0 # No trade size if holding
+    expected_penalty = 0 # No penalty in hold mode
+    expected_raw_reward = expected_pnl - expected_cost - expected_penalty
+    expected_scaled_reward = expected_raw_reward * cfg.environment.reward_scale
+    assert np.isclose(reward, expected_scaled_reward)
+
+def test_env_step_imputed_penalty(sample_env_data, base_env_config):
+    cfg = base_env_config.copy()
+    cfg.sac.imputed_handling = 'penalty'
+    cfg.sac.action_penalty = 0.1 # Use a specific penalty for testing
+    env = TradingEnv(**sample_env_data, config=cfg)
+
+    # Step 2 is imputed
+    env.step(0.5) # Step 0
+    env.step(0.6) # Step 1
+    assert env.current_step == 2
+    position_before_imputed = env.current_position # Should be 0.6
+
+    # Action for the imputed step
+    action_imputed = -0.2
+    next_state, reward, done, info = env.step(action_imputed)
+
+    # Should process step 2 and move to step 3
+    assert env.current_step == 3
+    # Position should update to the *agent's* action
+    assert env.current_position == np.clip(action_imputed, -1.0, 1.0)
+    assert not info['is_imputed_step_skipped']
+    assert not done
+
+    # Reward calculation is ONLY the penalty
+    expected_raw_reward = -cfg.sac.action_penalty * (action_imputed - position_before_imputed)**2
+    expected_scaled_reward = expected_raw_reward * cfg.environment.reward_scale
+    assert np.isclose(reward, expected_scaled_reward)
+
+def test_env_done_condition(trading_env_instance, sample_env_data):
+    n_steps = len(sample_env_data['actual_returns'])
+    # Step through the environment
+    done = False
+    for i in range(n_steps):
+        _, _, done, _ = trading_env_instance.step(np.random.uniform(-1, 1))
+        if i < n_steps - 1:
+            assert not done
+        else:
+            assert done # Should be done on the last step
\ No newline at end of file
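
Note on intended usage (outside the diff above): the 'skip' branch returns
info['is_imputed_step_skipped'] = True precisely so the training loop can keep those
zero-reward placeholder transitions out of the SAC replay buffer. A minimal sketch of
that consumer side is shown below, assuming an agent exposing select_action(state) and
a replay buffer exposing add(...); both names are illustrative and are not defined
anywhere in this patch.

    def collect_episode(env, agent, replay_buffer):
        """Roll out one episode, keeping skipped imputed-bar transitions out of the buffer."""
        state = env.reset()
        done = False
        while not done:
            action = agent.select_action(state)            # hypothetical agent API
            next_state, reward, done, info = env.step(action)
            if not info.get('is_imputed_step_skipped', False):
                # Store only real transitions; a skipped imputed bar carries no learning signal.
                replay_buffer.add(state, action, reward, next_state, done)
            # Either way, continue the rollout from the state the env returned.
            state = next_state

Dropping the skipped transition keeps the critic from fitting zero-reward bars on which
the action had no effect, while the rollout clock still advances past the imputed bar.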