feat: Implement imputed bar handling in TradingEnv (skip, hold, penalty) and tests

yasha 2025-04-19 02:16:24 +00:00
parent 4b1b542430
commit c3526bb9f6
2 changed files with 236 additions and 7 deletions

View File

@@ -6,6 +6,8 @@ Uses pre-calculated GRU predictions (mu, sigma, p_cal) and actual returns.
 import numpy as np
 import pandas as pd
 import logging
+import gymnasium as gym
+from omegaconf import DictConfig  # Added for config typing
 
 env_logger = logging.getLogger(__name__)
@@ -15,6 +17,8 @@ class TradingEnv:
         sigma_predictions: np.ndarray,
         p_cal_predictions: np.ndarray,
         actual_returns: np.ndarray,
+        bar_imputed_flags: np.ndarray,  # Added imputed flags
+        config: DictConfig,  # Added config
         initial_capital: float = 10000.0,
         transaction_cost: float = 0.0005,
         reward_scale: float = 100.0,
@@ -27,18 +31,22 @@ class TradingEnv:
             sigma_predictions: Predicted volatility (σ̂ = exp(log σ̂)).
             p_cal_predictions: Calibrated probability of price increase (p_cal).
             actual_returns: Actual log returns (y_ret).
+            bar_imputed_flags: Boolean array indicating if a bar was imputed.
+            config: OmegaConf configuration object.
            initial_capital: Starting capital for simulation (used notionally in reward).
             transaction_cost: Fractional cost per trade.
             reward_scale: Multiplier for the reward signal.
             action_penalty_lambda: Coefficient for the action magnitude penalty (λ).
         """
-        assert len(mu_predictions) == len(sigma_predictions) == len(p_cal_predictions) == len(actual_returns), \
-            "All input arrays must have the same length"
+        assert len(mu_predictions) == len(sigma_predictions) == len(p_cal_predictions) == len(actual_returns) == len(bar_imputed_flags), \
+            "All input arrays (predictions, returns, imputed_flags) must have the same length"
 
         self.mu = mu_predictions
         self.sigma = sigma_predictions
         self.p_cal = p_cal_predictions
         self.actual_returns = actual_returns
+        self.bar_imputed = bar_imputed_flags.astype(bool)  # Store imputed flags
+        self.config = config  # Store config
         self.initial_capital = initial_capital
         self.transaction_cost = transaction_cost
@@ -65,20 +73,36 @@ class TradingEnv:
         self.state_dim = 5
         self.action_dim = 1
 
+        # --- Define Gym Spaces ---
+        self.action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(self.action_dim,), dtype=np.float32)
+        self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.state_dim,), dtype=np.float32)
+        # --- End Define Gym Spaces ---
+
         env_logger.info(f"TradingEnv initialized with {self.n_steps} steps.")
 
     def _get_state(self) -> np.ndarray:
         """Construct the state vector for the current step."""
         if self.current_step >= self.n_steps:
+            # Handle episode end - return a dummy state of zeros
             return np.zeros(self.state_dim, dtype=np.float32)
 
         mu_t = self.mu[self.current_step]
         sigma_t = self.sigma[self.current_step]
         p_cal_t = self.p_cal[self.current_step]
 
-        edge_t = 2 * p_cal_t - 1
-        z_score_t = np.abs(mu_t) / (sigma_t + 1e-9)
+        # Calculate edge based on p_cal shape (binary vs ternary)
+        if isinstance(p_cal_t, (np.ndarray, list)) and len(p_cal_t) == 3:
+            # Ternary: edge = max(P(up), P(down)) - P(flat)
+            # Assuming order [Down, Flat, Up] for p_cal_t
+            edge_t = max(p_cal_t[2], p_cal_t[0]) - p_cal_t[1]
+        elif isinstance(p_cal_t, (float, np.number)):
+            # Binary: edge = 2 * P(up) - 1
+            edge_t = 2 * p_cal_t - 1
+        else:
+            env_logger.error(f"Unexpected type/shape for p_cal_t at step {self.current_step}: {p_cal_t}. Using edge=0.")
+            edge_t = 0.0
+
+        _EPS = 1e-9  # Define epsilon locally
+        z_score_t = np.abs(mu_t) / (sigma_t + _EPS)
 
         # State uses position *before* the action for this step is taken
         state = np.array([
@@ -108,11 +132,48 @@
         Returns:
             tuple: (next_state, reward, done, info_dict)
         """
+        info = {'capital': self.current_capital, 'position': self.current_position, 'is_imputed_step_skipped': False}
+
         if self.current_step >= self.n_steps:
             # Should not happen if 'done' is handled correctly, but as safeguard
             env_logger.warning("Step called after environment finished.")
-            return self._get_state(), 0.0, True, {}
+            return self._get_state(), 0.0, True, info
 
+        # --- Handle Imputed Bar --- #
+        imputed = self.bar_imputed[self.current_step]
+        if imputed:
+            mode = self.config.sac.imputed_handling
+            env_logger.debug(f"SAC step {self.current_step} on imputed bar: handling={mode}")
+
+            if mode == "skip":
+                self.current_step += 1
+                next_state = self._get_state()  # Get state for the *next* actual step
+                # Return 0 reward, not done, but flag the skip so the caller can keep it out of the replay buffer
+                info['is_imputed_step_skipped'] = True
+                return next_state, 0.0, False, info
+
+            elif mode == "hold":
+                # Force the action to maintain the current position; fall through to the normal step logic
+                action = self.current_position
+
+            elif mode == "penalty":
+                # Reward is only a penalty (from config) on moving away from the current position; no PnL is realised
+                target_position_penalty = np.clip(action, -1.0, 1.0)
+                reward = -self.config.sac.action_penalty * (target_position_penalty - self.current_position)**2
+                # Position is updated to the agent's intended (clipped) action
+                self.current_position = target_position_penalty
+                # No transaction cost is charged and capital is left unchanged for the penalty step (simplification)
+                self.current_step += 1
+                next_state = self._get_state()
+                scaled_reward = reward * self.reward_scale  # Scale the penalty
+                done = self.current_step >= self.n_steps
+                info['capital'] = self.current_capital
+                info['position'] = self.current_position
+                return next_state, scaled_reward, done, info
+
+            # else: any other mode falls through and the bar is treated as a normal one
+        # --- End Handle Imputed Bar --- #
+
+        # --- Normal Step Logic (bar not imputed, or a fall-through mode such as 'hold') --- #
         # Action is the TARGET position for the *end* of this step
         target_position = np.clip(action, -1.0, 1.0)
         trade_size = target_position - self.current_position
@@ -150,7 +211,9 @@
         done = self.current_step >= self.n_steps or self.current_capital <= 0
         next_state = self._get_state()
-        info = {'capital': self.current_capital, 'position': self.current_position}
+        # Update info dict (capital/position might have changed in the normal step)
+        info['capital'] = self.current_capital
+        info['position'] = self.current_position
 
         # Log step details periodically
         # if self.current_step % 1000 == 0:
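
Note (not part of the diff): the is_imputed_step_skipped flag returned in the info dict is what the SAC training loop is expected to check before writing a transition to the replay buffer (the "buffer handling" mentioned in the skip branch). A rough sketch of that usage, where agent, replay_buffer, and their methods are hypothetical placeholders, not code from this repository:

    state = env.reset()
    done = False
    while not done:
        action = agent.select_action(state)            # hypothetical SAC policy call
        next_state, reward, done, info = env.step(action)
        if not info.get('is_imputed_step_skipped', False):
            # Only store transitions for real (non-skipped) bars
            replay_buffer.add(state, action, reward, next_state, done)
        state = next_state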

View File

@@ -0,0 +1,166 @@
+import pytest
+import numpy as np
+from omegaconf import OmegaConf
+
+# Adjust import path based on structure
+from gru_sac_predictor.src.trading_env import TradingEnv
+
+
+# --- Test Fixtures ---
+
+@pytest.fixture
+def sample_env_data():
+    """Provides sample data for initializing the TradingEnv."""
+    n_steps = 10
+    data = {
+        'mu_predictions': np.random.randn(n_steps) * 0.001,
+        'sigma_predictions': np.abs(np.random.randn(n_steps) * 0.002 + 0.005),
+        'p_cal_predictions': np.random.rand(n_steps),
+        'actual_returns': np.random.randn(n_steps) * 0.0015,
+        'bar_imputed_flags': np.array([False, False, True, False, True, True, False, False, True, False], dtype=bool)
+    }
+    return data
+
+
+@pytest.fixture
+def base_env_config():
+    """Base configuration for the environment."""
+    return OmegaConf.create({
+        'sac': {
+            'imputed_handling': 'skip',  # Default test mode
+            'action_penalty': 0.05
+        },
+        'environment': {
+            'initial_capital': 10000.0,
+            'transaction_cost': 0.0005,
+            'reward_scale': 100.0,
+            'action_penalty_lambda': 0.0  # Usually overridden by transaction_cost calc
+        }
+    })
+
+
+@pytest.fixture
+def trading_env_instance(sample_env_data, base_env_config):
+    """Creates a TradingEnv instance with default 'skip' mode."""
+    return TradingEnv(**sample_env_data, config=base_env_config)
+
+
+# --- Test Functions ---
+
+def test_env_initialization(trading_env_instance, sample_env_data):
+    assert trading_env_instance.n_steps == len(sample_env_data['actual_returns'])
+    assert trading_env_instance.current_step == 0
+    assert trading_env_instance.current_position == 0.0
+    assert np.array_equal(trading_env_instance.bar_imputed, sample_env_data['bar_imputed_flags'])
+
+
+def test_env_reset(trading_env_instance):
+    # Take a few steps
+    trading_env_instance.step(0.5)
+    trading_env_instance.step(-0.2)
+    assert trading_env_instance.current_step > 0
+
+    # Reset
+    initial_state = trading_env_instance.reset()
+    assert trading_env_instance.current_step == 0
+    assert trading_env_instance.current_position == 0.0
+    assert initial_state is not None
+    assert initial_state.shape == (trading_env_instance.state_dim,)
+
+
+def test_env_step_normal(trading_env_instance):
+    # Test a normal step (step 0 is not imputed)
+    initial_pos = trading_env_instance.current_position
+    action = 0.7
+    next_state, reward, done, info = trading_env_instance.step(action)
+
+    assert trading_env_instance.current_step == 1
+    assert trading_env_instance.current_position == action  # Position updates to action
+    assert not info['is_imputed_step_skipped']
+    assert not done
+    assert next_state is not None
+    # Reward calculation is complex; just check the type here
+    assert isinstance(reward, float)
+
+
+def test_env_step_imputed_skip(trading_env_instance, sample_env_data):
+    # Step 2 is imputed in sample_env_data
+    trading_env_instance.step(0.5)  # Step 0
+    trading_env_instance.step(0.6)  # Step 1
+    assert trading_env_instance.current_step == 2
+    initial_pos_before_imputed = trading_env_instance.current_position
+
+    # Action for the imputed step (should be ignored by 'skip')
+    action_imputed = 0.9
+    next_state, reward, done, info = trading_env_instance.step(action_imputed)
+
+    # Should skip step 2 and now be at step 3
+    assert trading_env_instance.current_step == 3
+    # Position should NOT have changed from step 1
+    assert trading_env_instance.current_position == initial_pos_before_imputed
+    assert reward == 0.0  # Skip gives 0 reward
+    assert not done
+    assert info['is_imputed_step_skipped'] is True  # Crucial check for buffer handling
+
+    # Check that the returned state is for step 3
+    expected_state_step_3 = trading_env_instance._get_state()  # Get state now that we are at step 3
+    np.testing.assert_array_almost_equal(next_state, expected_state_step_3)
+
+
+def test_env_step_imputed_hold(sample_env_data, base_env_config):
+    cfg = base_env_config.copy()
+    cfg.sac.imputed_handling = 'hold'
+    env = TradingEnv(**sample_env_data, config=cfg)
+
+    # Step 2 is imputed
+    env.step(0.5)  # Step 0
+    env.step(0.6)  # Step 1
+    assert env.current_step == 2
+    position_before_imputed = env.current_position
+
+    # Action for the imputed step (should be overridden by 'hold')
+    action_imputed = -0.5
+    next_state, reward, done, info = env.step(action_imputed)
+
+    # Should process step 2 and move to step 3
+    assert env.current_step == 3
+    # Position should be the same as before the step
+    assert env.current_position == position_before_imputed
+    assert not info['is_imputed_step_skipped']
+    assert not done
+
+    # Reward should be calculated based on holding the position
+    expected_pnl = position_before_imputed * (np.exp(sample_env_data['actual_returns'][2]) - 1)
+    expected_cost = 0  # No trade size if holding
+    expected_penalty = 0  # No penalty in hold mode
+    expected_raw_reward = expected_pnl - expected_cost - expected_penalty
+    expected_scaled_reward = expected_raw_reward * cfg.environment.reward_scale
+    assert np.isclose(reward, expected_scaled_reward)
+
+
+def test_env_step_imputed_penalty(sample_env_data, base_env_config):
+    cfg = base_env_config.copy()
+    cfg.sac.imputed_handling = 'penalty'
+    cfg.sac.action_penalty = 0.1  # Use a specific penalty for testing
+    env = TradingEnv(**sample_env_data, config=cfg)
+
+    # Step 2 is imputed
+    env.step(0.5)  # Step 0
+    env.step(0.6)  # Step 1
+    assert env.current_step == 2
+    position_before_imputed = env.current_position  # Should be 0.6
+
+    # Action for the imputed step
+    action_imputed = -0.2
+    next_state, reward, done, info = env.step(action_imputed)
+
+    # Should process step 2 and move to step 3
+    assert env.current_step == 3
+    # Position should update to the *agent's* action
+    assert env.current_position == np.clip(action_imputed, -1.0, 1.0)
+    assert not info['is_imputed_step_skipped']
+    assert not done
+
+    # Reward calculation is ONLY the penalty
+    expected_raw_reward = -cfg.sac.action_penalty * (action_imputed - position_before_imputed)**2
+    expected_scaled_reward = expected_raw_reward * cfg.environment.reward_scale
+    assert np.isclose(reward, expected_scaled_reward)
+
+
+def test_env_done_condition(trading_env_instance, sample_env_data):
+    n_steps = len(sample_env_data['actual_returns'])
+
+    # Step through the environment
+    done = False
+    for i in range(n_steps):
+        _, _, done, _ = trading_env_instance.step(np.random.uniform(-1, 1))
+        if i < n_steps - 1:
+            assert not done
+        else:
+            assert done  # Should be done on the last step
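
Note (not part of the diff): as a quick sanity check of the numbers exercised by test_env_step_imputed_penalty, with action_penalty = 0.1, a held position of 0.6 after the first two steps, an intended action of -0.2, and reward_scale = 100.0, the expected scaled reward is plain arithmetic, independent of the environment code:

    action_penalty = 0.1
    position_before = 0.6      # position after the two non-imputed steps in the test
    action_imputed = -0.2
    reward_scale = 100.0

    raw_reward = -action_penalty * (action_imputed - position_before) ** 2   # -0.1 * 0.64 = -0.064
    scaled_reward = raw_reward * reward_scale                                # -6.4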