feat: Implement imputed bar handling in TradingEnv (skip, hold, penalty) and tests

This commit is contained in:
yasha 2025-04-19 02:16:24 +00:00
parent 4b1b542430
commit c3526bb9f6
2 changed files with 236 additions and 7 deletions

View File

@ -6,6 +6,8 @@ Uses pre-calculated GRU predictions (mu, sigma, p_cal) and actual returns.
import numpy as np
import pandas as pd
import logging
import gymnasium as gym
from omegaconf import DictConfig # Added for config typing
env_logger = logging.getLogger(__name__)
@ -15,6 +17,8 @@ class TradingEnv:
sigma_predictions: np.ndarray,
p_cal_predictions: np.ndarray,
actual_returns: np.ndarray,
bar_imputed_flags: np.ndarray, # Added imputed flags
config: DictConfig, # Added config
initial_capital: float = 10000.0,
transaction_cost: float = 0.0005,
reward_scale: float = 100.0,
@ -27,18 +31,22 @@ class TradingEnv:
sigma_predictions: Predicted volatility (σ̂ = exp(log σ̂)).
p_cal_predictions: Calibrated probability of price increase (p_cal).
actual_returns: Actual log returns (y_ret).
bar_imputed_flags: Boolean array indicating if a bar was imputed.
config: OmegaConf configuration object.
initial_capital: Starting capital for simulation (used notionally in reward).
transaction_cost: Fractional cost per trade.
reward_scale: Multiplier for the reward signal.
action_penalty_lambda: Coefficient for the action magnitude penalty (λ).
"""
assert len(mu_predictions) == len(sigma_predictions) == len(p_cal_predictions) == len(actual_returns), \
"All input arrays must have the same length"
assert len(mu_predictions) == len(sigma_predictions) == len(p_cal_predictions) == len(actual_returns) == len(bar_imputed_flags), \
"All input arrays (predictions, returns, imputed_flags) must have the same length"
self.mu = mu_predictions
self.sigma = sigma_predictions
self.p_cal = p_cal_predictions
self.actual_returns = actual_returns
self.bar_imputed = bar_imputed_flags.astype(bool) # Store imputed flags
self.config = config # Store config
self.initial_capital = initial_capital
self.transaction_cost = transaction_cost
@ -65,20 +73,36 @@ class TradingEnv:
self.state_dim = 5
self.action_dim = 1
# --- Define Gym Spaces ---
self.action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(self.action_dim,), dtype=np.float32)
self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.state_dim,), dtype=np.float32)
# --- End Define Gym Spaces ---
env_logger.info(f"TradingEnv initialized with {self.n_steps} steps.")
def _get_state(self) -> np.ndarray:
"""Construct the state vector for the current step."""
if self.current_step >= self.n_steps:
# Handle episode end - return a dummy state or zeros
return np.zeros(self.state_dim, dtype=np.float32)
mu_t = self.mu[self.current_step]
sigma_t = self.sigma[self.current_step]
p_cal_t = self.p_cal[self.current_step]
edge_t = 2 * p_cal_t - 1
z_score_t = np.abs(mu_t) / (sigma_t + 1e-9)
# Calculate edge based on p_cal shape (binary vs ternary)
if isinstance(p_cal_t, (np.ndarray, list)) and len(p_cal_t) == 3:
# Ternary: edge = max(P(up), P(down)) - P(flat)
# Assuming order [Down, Flat, Up] for p_cal_t
edge_t = max(p_cal_t[2], p_cal_t[0]) - p_cal_t[1]
elif isinstance(p_cal_t, (float, np.number)):
# Binary: edge = 2 * P(up) - 1
edge_t = 2 * p_cal_t - 1
else:
env_logger.error(f"Unexpected type/shape for p_cal_t at step {self.current_step}: {p_cal_t}. Using edge=0.")
edge_t = 0.0
_EPS = 1e-9 # Define epsilon locally
z_score_t = np.abs(mu_t) / (sigma_t + _EPS)
# State uses position *before* the action for this step is taken
state = np.array([
@ -108,11 +132,48 @@ class TradingEnv:
Returns:
tuple: (next_state, reward, done, info_dict)
"""
info = {'capital': self.current_capital, 'position': self.current_position, 'is_imputed_step_skipped': False}
if self.current_step >= self.n_steps:
# Should not happen if 'done' is handled correctly, but as safeguard
env_logger.warning("Step called after environment finished.")
return self._get_state(), 0.0, True, {}
return self._get_state(), 0.0, True, info
# --- Handle Imputed Bar --- #
imputed = self.bar_imputed[self.current_step]
if imputed:
mode = self.config.sac.imputed_handling
env_logger.debug(f"SAC step {self.current_step} on imputed bar: handling={mode}")
if mode == "skip":
self.current_step += 1
next_state = self._get_state() # Get state for the *next* actual step
# Return 0 reward, not done, but indicate skip for buffer handling
info['is_imputed_step_skipped'] = True
return next_state, 0.0, False, info
elif mode == "hold":
# Action is forced to maintain current position
action = self.current_position
elif mode == "penalty":
# Calculate reward penalty based on config
target_position_penalty = np.clip(action, -1.0, 1.0)
reward = -self.config.sac.action_penalty * (target_position_penalty - self.current_position)**2
# Update position based on agent's intended action (clipped)
self.current_position = target_position_penalty
# Update capital notionally (no actual return, only cost if implemented)
# Cost is implicitly 0 here as there's no trade size if pos doesn't change
# If penalty mode allowed position change, cost would apply.
# For simplicity, we don't add cost here for the penalty step.
self.current_step += 1
next_state = self._get_state()
scaled_reward = reward * self.reward_scale # Scale the penalty
done = self.current_step >= self.n_steps
info['capital'] = self.current_capital
info['position'] = self.current_position
return next_state, scaled_reward, done, info
# else: default behavior (treat as normal bar) - implicitly handled by falling through
# --- End Handle Imputed Bar --- #
# --- Normal Step Logic (if not imputed or handling mode allows fallthrough like 'hold') --- #
# Action is the TARGET position for the *end* of this step
target_position = np.clip(action, -1.0, 1.0)
trade_size = target_position - self.current_position
@ -150,7 +211,9 @@ class TradingEnv:
done = self.current_step >= self.n_steps or self.current_capital <= 0
next_state = self._get_state()
info = {'capital': self.current_capital, 'position': self.current_position}
# Update info dict (capital/position might have changed in normal step)
info['capital'] = self.current_capital
info['position'] = self.current_position
# Log step details periodically
# if self.current_step % 1000 == 0:

View File

@ -0,0 +1,166 @@
import pytest
import numpy as np
from omegaconf import OmegaConf
# Adjust import path based on structure
from gru_sac_predictor.src.trading_env import TradingEnv
# --- Test Fixtures ---
@pytest.fixture
def sample_env_data():
"""Provides sample data for initializing the TradingEnv."""
n_steps = 10
data = {
'mu_predictions': np.random.randn(n_steps) * 0.001,
'sigma_predictions': np.abs(np.random.randn(n_steps) * 0.002 + 0.005),
'p_cal_predictions': np.random.rand(n_steps),
'actual_returns': np.random.randn(n_steps) * 0.0015,
'bar_imputed_flags': np.array([False, False, True, False, True, True, False, False, True, False], dtype=bool)
}
return data
@pytest.fixture
def base_env_config():
"""Base configuration for the environment."""
return OmegaConf.create({
'sac': {
'imputed_handling': 'skip', # Default test mode
'action_penalty': 0.05
},
'environment': {
'initial_capital': 10000.0,
'transaction_cost': 0.0005,
'reward_scale': 100.0,
'action_penalty_lambda': 0.0 # Usually overridden by transaction_cost calc
}
})
@pytest.fixture
def trading_env_instance(sample_env_data, base_env_config):
"""Creates a TradingEnv instance with default 'skip' mode."""
return TradingEnv(**sample_env_data, config=base_env_config)
# --- Test Functions ---
def test_env_initialization(trading_env_instance, sample_env_data):
assert trading_env_instance.n_steps == len(sample_env_data['actual_returns'])
assert trading_env_instance.current_step == 0
assert trading_env_instance.current_position == 0.0
assert np.array_equal(trading_env_instance.bar_imputed, sample_env_data['bar_imputed_flags'])
def test_env_reset(trading_env_instance):
# Take a few steps
trading_env_instance.step(0.5)
trading_env_instance.step(-0.2)
assert trading_env_instance.current_step > 0
# Reset
initial_state = trading_env_instance.reset()
assert trading_env_instance.current_step == 0
assert trading_env_instance.current_position == 0.0
assert initial_state is not None
assert initial_state.shape == (trading_env_instance.state_dim,)
def test_env_step_normal(trading_env_instance):
# Test a normal step (step 0 is not imputed)
initial_pos = trading_env_instance.current_position
action = 0.7
next_state, reward, done, info = trading_env_instance.step(action)
assert trading_env_instance.current_step == 1
assert trading_env_instance.current_position == action # Position updates to action
assert not info['is_imputed_step_skipped']
assert not done
assert next_state is not None
# Reward calculation is complex, just check type/sign if needed
assert isinstance(reward, float)
def test_env_step_imputed_skip(trading_env_instance, sample_env_data):
# Step 2 is imputed in sample_env_data
trading_env_instance.step(0.5) # Step 0
trading_env_instance.step(0.6) # Step 1
assert trading_env_instance.current_step == 2
initial_pos_before_imputed = trading_env_instance.current_position
# Action for the imputed step (should be ignored by 'skip')
action_imputed = 0.9
next_state, reward, done, info = trading_env_instance.step(action_imputed)
# Should skip step 2 and now be at step 3
assert trading_env_instance.current_step == 3
# Position should NOT have changed from step 1
assert trading_env_instance.current_position == initial_pos_before_imputed
assert reward == 0.0 # Skip gives 0 reward
assert not done
assert info['is_imputed_step_skipped'] == True # Crucial check for buffer
# Check that the returned state is for step 3
expected_state_step_3 = trading_env_instance._get_state() # Get state now that we are at step 3
np.testing.assert_array_almost_equal(next_state, expected_state_step_3)
def test_env_step_imputed_hold(sample_env_data, base_env_config):
cfg = base_env_config.copy()
cfg.sac.imputed_handling = 'hold'
env = TradingEnv(**sample_env_data, config=cfg)
# Step 2 is imputed
env.step(0.5) # Step 0
env.step(0.6) # Step 1
assert env.current_step == 2
position_before_imputed = env.current_position
# Action for the imputed step (should be overridden by 'hold')
action_imputed = -0.5
next_state, reward, done, info = env.step(action_imputed)
# Should process step 2 and move to step 3
assert env.current_step == 3
# Position should be the same as before the step
assert env.current_position == position_before_imputed
assert not info['is_imputed_step_skipped']
assert not done
# Reward should be calculated based on holding the position
expected_pnl = position_before_imputed * (np.exp(sample_env_data['actual_returns'][2]) - 1)
expected_cost = 0 # No trade size if holding
expected_penalty = 0 # No penalty in hold mode
expected_raw_reward = expected_pnl - expected_cost - expected_penalty
expected_scaled_reward = expected_raw_reward * cfg.environment.reward_scale
assert np.isclose(reward, expected_scaled_reward)
def test_env_step_imputed_penalty(sample_env_data, base_env_config):
cfg = base_env_config.copy()
cfg.sac.imputed_handling = 'penalty'
cfg.sac.action_penalty = 0.1 # Use a specific penalty for testing
env = TradingEnv(**sample_env_data, config=cfg)
# Step 2 is imputed
env.step(0.5) # Step 0
env.step(0.6) # Step 1
assert env.current_step == 2
position_before_imputed = env.current_position # Should be 0.6
# Action for the imputed step
action_imputed = -0.2
next_state, reward, done, info = env.step(action_imputed)
# Should process step 2 and move to step 3
assert env.current_step == 3
# Position should update to the *agent's* action
assert env.current_position == np.clip(action_imputed, -1.0, 1.0)
assert not info['is_imputed_step_skipped']
assert not done
# Reward calculation is ONLY the penalty
expected_raw_reward = -cfg.sac.action_penalty * (action_imputed - position_before_imputed)**2
expected_scaled_reward = expected_raw_reward * cfg.environment.reward_scale
assert np.isclose(reward, expected_scaled_reward)
def test_env_done_condition(trading_env_instance, sample_env_data):
n_steps = len(sample_env_data['actual_returns'])
# Step through the environment
done = False
for i in range(n_steps):
_, _, done, _ = trading_env_instance.step(np.random.uniform(-1, 1))
if i < n_steps - 1:
assert not done
else:
assert done # Should be done on the last step