feat: Implement imputed bar handling in TradingEnv (skip, hold, penalty) and tests
parent 4b1b542430
commit c3526bb9f6
gru_sac_predictor/src/trading_env.py

@@ -6,6 +6,8 @@ Uses pre-calculated GRU predictions (mu, sigma, p_cal) and actual returns.
import numpy as np
import pandas as pd
import logging
import gymnasium as gym
from omegaconf import DictConfig  # Added for config typing

env_logger = logging.getLogger(__name__)

@@ -15,6 +17,8 @@ class TradingEnv:
        sigma_predictions: np.ndarray,
        p_cal_predictions: np.ndarray,
        actual_returns: np.ndarray,
        bar_imputed_flags: np.ndarray,  # Added imputed flags
        config: DictConfig,  # Added config
        initial_capital: float = 10000.0,
        transaction_cost: float = 0.0005,
        reward_scale: float = 100.0,
@@ -27,18 +31,22 @@ class TradingEnv:
            sigma_predictions: Predicted volatility (σ̂ = exp(log σ̂)).
            p_cal_predictions: Calibrated probability of price increase (p_cal).
            actual_returns: Actual log returns (y_ret).
            bar_imputed_flags: Boolean array indicating if a bar was imputed.
            config: OmegaConf configuration object.
            initial_capital: Starting capital for simulation (used notionally in reward).
            transaction_cost: Fractional cost per trade.
            reward_scale: Multiplier for the reward signal.
            action_penalty_lambda: Coefficient for the action magnitude penalty (λ).
        """
        assert len(mu_predictions) == len(sigma_predictions) == len(p_cal_predictions) == len(actual_returns), \
            "All input arrays must have the same length"
        assert len(mu_predictions) == len(sigma_predictions) == len(p_cal_predictions) == len(actual_returns) == len(bar_imputed_flags), \
            "All input arrays (predictions, returns, imputed_flags) must have the same length"

        self.mu = mu_predictions
        self.sigma = sigma_predictions
        self.p_cal = p_cal_predictions
        self.actual_returns = actual_returns
        self.bar_imputed = bar_imputed_flags.astype(bool)  # Store imputed flags
        self.config = config  # Store config

        self.initial_capital = initial_capital
        self.transaction_cost = transaction_cost
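For orientation, a minimal sketch of the config shape the constructor expects. Only `sac.imputed_handling` and `sac.action_penalty` are read in this diff; the `environment` block mirrors the test fixture further down and is otherwise an assumption:

```python
from omegaconf import OmegaConf

# Sketch of the expected config; keys inferred from this diff and the test fixture below.
cfg = OmegaConf.create({
    "sac": {
        "imputed_handling": "skip",   # one of: "skip", "hold", "penalty"
        "action_penalty": 0.05,       # coefficient used in "penalty" mode
    },
    "environment": {
        "initial_capital": 10000.0,
        "transaction_cost": 0.0005,
        "reward_scale": 100.0,
    },
})
```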
@@ -65,20 +73,36 @@ class TradingEnv:
        self.state_dim = 5
        self.action_dim = 1

        # --- Define Gym Spaces ---
        self.action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(self.action_dim,), dtype=np.float32)
        self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.state_dim,), dtype=np.float32)
        # --- End Define Gym Spaces ---

        env_logger.info(f"TradingEnv initialized with {self.n_steps} steps.")

    def _get_state(self) -> np.ndarray:
        """Construct the state vector for the current step."""
        if self.current_step >= self.n_steps:
            # Handle episode end - return a dummy state or zeros
            return np.zeros(self.state_dim, dtype=np.float32)

        mu_t = self.mu[self.current_step]
        sigma_t = self.sigma[self.current_step]
        p_cal_t = self.p_cal[self.current_step]

        edge_t = 2 * p_cal_t - 1
        z_score_t = np.abs(mu_t) / (sigma_t + 1e-9)
        # Calculate edge based on p_cal shape (binary vs ternary)
        if isinstance(p_cal_t, (np.ndarray, list)) and len(p_cal_t) == 3:
            # Ternary: edge = max(P(up), P(down)) - P(flat)
            # Assuming order [Down, Flat, Up] for p_cal_t
            edge_t = max(p_cal_t[2], p_cal_t[0]) - p_cal_t[1]
        elif isinstance(p_cal_t, (float, np.number)):
            # Binary: edge = 2 * P(up) - 1
            edge_t = 2 * p_cal_t - 1
        else:
            env_logger.error(f"Unexpected type/shape for p_cal_t at step {self.current_step}: {p_cal_t}. Using edge=0.")
            edge_t = 0.0

        _EPS = 1e-9  # Define epsilon locally
        z_score_t = np.abs(mu_t) / (sigma_t + _EPS)

        # State uses position *before* the action for this step is taken
        state = np.array([
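As a standalone illustration of the edge calculation above (not part of the commit, and assuming the same [Down, Flat, Up] class order):

```python
import numpy as np

def compute_edge(p_cal):
    """Illustrative re-statement of the edge logic in _get_state above."""
    p = np.asarray(p_cal)
    if p.ndim == 1 and p.size == 3:        # ternary head, order [Down, Flat, Up]
        return max(p[2], p[0]) - p[1]      # max(P(up), P(down)) - P(flat)
    return 2.0 * float(p) - 1.0            # binary head: 2 * P(up) - 1

print(compute_edge(0.65))               # binary: 0.30
print(compute_edge([0.2, 0.1, 0.7]))    # ternary: 0.7 - 0.1 = 0.60
```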
@@ -108,11 +132,48 @@ class TradingEnv:
        Returns:
            tuple: (next_state, reward, done, info_dict)
        """
        info = {'capital': self.current_capital, 'position': self.current_position, 'is_imputed_step_skipped': False}

        if self.current_step >= self.n_steps:
            # Should not happen if 'done' is handled correctly, but as safeguard
            env_logger.warning("Step called after environment finished.")
            return self._get_state(), 0.0, True, {}
            return self._get_state(), 0.0, True, info

        # --- Handle Imputed Bar --- #
        imputed = self.bar_imputed[self.current_step]
        if imputed:
            mode = self.config.sac.imputed_handling
            env_logger.debug(f"SAC step {self.current_step} on imputed bar: handling={mode}")
            if mode == "skip":
                self.current_step += 1
                next_state = self._get_state()  # Get state for the *next* actual step
                # Return 0 reward, not done, but indicate skip for buffer handling
                info['is_imputed_step_skipped'] = True
                return next_state, 0.0, False, info
            elif mode == "hold":
                # Action is forced to maintain current position
                action = self.current_position
            elif mode == "penalty":
                # Calculate reward penalty based on config
                target_position_penalty = np.clip(action, -1.0, 1.0)
                reward = -self.config.sac.action_penalty * (target_position_penalty - self.current_position)**2
                # Update position based on agent's intended action (clipped)
                self.current_position = target_position_penalty
                # Update capital notionally (no actual return, only cost if implemented)
                # Cost is implicitly 0 here as there's no trade size if pos doesn't change
                # If penalty mode allowed position change, cost would apply.
                # For simplicity, we don't add cost here for the penalty step.
                self.current_step += 1
                next_state = self._get_state()
                scaled_reward = reward * self.reward_scale  # Scale the penalty
                done = self.current_step >= self.n_steps
                info['capital'] = self.current_capital
                info['position'] = self.current_position
                return next_state, scaled_reward, done, info
            # else: default behavior (treat as normal bar) - implicitly handled by falling through
        # --- End Handle Imputed Bar --- #

        # --- Normal Step Logic (if not imputed or handling mode allows fallthrough like 'hold') --- #
        # Action is the TARGET position for the *end* of this step
        target_position = np.clip(action, -1.0, 1.0)
        trade_size = target_position - self.current_position
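The PnL and cost terms of the normal-step reward fall outside the hunks shown here. Judging from the values asserted in the tests below, it presumably follows something like this sketch; the names, the use of the post-trade position, and the omission of any `action_penalty_lambda` term are assumptions:

```python
import numpy as np

def sketch_normal_step_reward(prev_pos, action, log_return,
                              transaction_cost=0.0005, reward_scale=100.0):
    """Hedged sketch of the out-of-hunk reward, inferred from the test expectations below."""
    new_pos = float(np.clip(action, -1.0, 1.0))
    pnl = new_pos * (np.exp(log_return) - 1.0)           # simple return earned on the position
    cost = transaction_cost * abs(new_pos - prev_pos)    # fractional cost on the traded size
    return (pnl - cost) * reward_scale
```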
@@ -150,7 +211,9 @@ class TradingEnv:
        done = self.current_step >= self.n_steps or self.current_capital <= 0

        next_state = self._get_state()
        info = {'capital': self.current_capital, 'position': self.current_position}
        # Update info dict (capital/position might have changed in normal step)
        info['capital'] = self.current_capital
        info['position'] = self.current_position

        # Log step details periodically
        # if self.current_step % 1000 == 0:

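The `is_imputed_step_skipped` flag exists so the SAC training loop can drop skipped transitions instead of storing a zero-reward step. A minimal sketch of how a caller might use it, with `env` constructed as in the test fixtures below and a plain list standing in for the replay buffer (none of this is part of the commit):

```python
# Random-policy rollout sketch that honours the skip flag.
transitions = []              # stand-in for a SAC replay buffer
state = env.reset()
done = False
while not done:
    action = float(env.action_space.sample()[0])
    next_state, reward, done, info = env.step(action)
    if not info.get('is_imputed_step_skipped', False):
        transitions.append((state, action, reward, next_state, done))  # store real steps only
    state = next_state
```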
gru_sac_predictor/tests/test_trading_env.py (new file, 166 lines)

@@ -0,0 +1,166 @@
import pytest
import numpy as np
from omegaconf import OmegaConf

# Adjust import path based on structure
from gru_sac_predictor.src.trading_env import TradingEnv

# --- Test Fixtures ---

@pytest.fixture
def sample_env_data():
    """Provides sample data for initializing the TradingEnv."""
    n_steps = 10
    data = {
        'mu_predictions': np.random.randn(n_steps) * 0.001,
        'sigma_predictions': np.abs(np.random.randn(n_steps) * 0.002 + 0.005),
        'p_cal_predictions': np.random.rand(n_steps),
        'actual_returns': np.random.randn(n_steps) * 0.0015,
        'bar_imputed_flags': np.array([False, False, True, False, True, True, False, False, True, False], dtype=bool)
    }
    return data

@pytest.fixture
def base_env_config():
    """Base configuration for the environment."""
    return OmegaConf.create({
        'sac': {
            'imputed_handling': 'skip',  # Default test mode
            'action_penalty': 0.05
        },
        'environment': {
            'initial_capital': 10000.0,
            'transaction_cost': 0.0005,
            'reward_scale': 100.0,
            'action_penalty_lambda': 0.0  # Usually overridden by transaction_cost calc
        }
    })

@pytest.fixture
def trading_env_instance(sample_env_data, base_env_config):
    """Creates a TradingEnv instance with default 'skip' mode."""
    return TradingEnv(**sample_env_data, config=base_env_config)

# --- Test Functions ---

def test_env_initialization(trading_env_instance, sample_env_data):
    assert trading_env_instance.n_steps == len(sample_env_data['actual_returns'])
    assert trading_env_instance.current_step == 0
    assert trading_env_instance.current_position == 0.0
    assert np.array_equal(trading_env_instance.bar_imputed, sample_env_data['bar_imputed_flags'])

def test_env_reset(trading_env_instance):
    # Take a few steps
    trading_env_instance.step(0.5)
    trading_env_instance.step(-0.2)
    assert trading_env_instance.current_step > 0
    # Reset
    initial_state = trading_env_instance.reset()
    assert trading_env_instance.current_step == 0
    assert trading_env_instance.current_position == 0.0
    assert initial_state is not None
    assert initial_state.shape == (trading_env_instance.state_dim,)

def test_env_step_normal(trading_env_instance):
    # Test a normal step (step 0 is not imputed)
    initial_pos = trading_env_instance.current_position
    action = 0.7
    next_state, reward, done, info = trading_env_instance.step(action)

    assert trading_env_instance.current_step == 1
    assert trading_env_instance.current_position == action  # Position updates to action
    assert not info['is_imputed_step_skipped']
    assert not done
    assert next_state is not None
    # Reward calculation is complex, just check type/sign if needed
    assert isinstance(reward, float)

def test_env_step_imputed_skip(trading_env_instance, sample_env_data):
    # Step 2 is imputed in sample_env_data
    trading_env_instance.step(0.5)  # Step 0
    trading_env_instance.step(0.6)  # Step 1
    assert trading_env_instance.current_step == 2
    initial_pos_before_imputed = trading_env_instance.current_position

    # Action for the imputed step (should be ignored by 'skip')
    action_imputed = 0.9
    next_state, reward, done, info = trading_env_instance.step(action_imputed)

    # Should skip step 2 and now be at step 3
    assert trading_env_instance.current_step == 3
    # Position should NOT have changed from step 1
    assert trading_env_instance.current_position == initial_pos_before_imputed
    assert reward == 0.0  # Skip gives 0 reward
    assert not done
    assert info['is_imputed_step_skipped'] == True  # Crucial check for buffer
    # Check that the returned state is for step 3
    expected_state_step_3 = trading_env_instance._get_state()  # Get state now that we are at step 3
    np.testing.assert_array_almost_equal(next_state, expected_state_step_3)

def test_env_step_imputed_hold(sample_env_data, base_env_config):
    cfg = base_env_config.copy()
    cfg.sac.imputed_handling = 'hold'
    env = TradingEnv(**sample_env_data, config=cfg)

    # Step 2 is imputed
    env.step(0.5)  # Step 0
    env.step(0.6)  # Step 1
    assert env.current_step == 2
    position_before_imputed = env.current_position

    # Action for the imputed step (should be overridden by 'hold')
    action_imputed = -0.5
    next_state, reward, done, info = env.step(action_imputed)

    # Should process step 2 and move to step 3
    assert env.current_step == 3
    # Position should be the same as before the step
    assert env.current_position == position_before_imputed
    assert not info['is_imputed_step_skipped']
    assert not done
    # Reward should be calculated based on holding the position
    expected_pnl = position_before_imputed * (np.exp(sample_env_data['actual_returns'][2]) - 1)
    expected_cost = 0  # No trade size if holding
    expected_penalty = 0  # No penalty in hold mode
    expected_raw_reward = expected_pnl - expected_cost - expected_penalty
    expected_scaled_reward = expected_raw_reward * cfg.environment.reward_scale
    assert np.isclose(reward, expected_scaled_reward)

def test_env_step_imputed_penalty(sample_env_data, base_env_config):
    cfg = base_env_config.copy()
    cfg.sac.imputed_handling = 'penalty'
    cfg.sac.action_penalty = 0.1  # Use a specific penalty for testing
    env = TradingEnv(**sample_env_data, config=cfg)

    # Step 2 is imputed
    env.step(0.5)  # Step 0
    env.step(0.6)  # Step 1
    assert env.current_step == 2
    position_before_imputed = env.current_position  # Should be 0.6

    # Action for the imputed step
    action_imputed = -0.2
    next_state, reward, done, info = env.step(action_imputed)

    # Should process step 2 and move to step 3
    assert env.current_step == 3
    # Position should update to the *agent's* action
    assert env.current_position == np.clip(action_imputed, -1.0, 1.0)
    assert not info['is_imputed_step_skipped']
    assert not done

    # Reward calculation is ONLY the penalty
    expected_raw_reward = -cfg.sac.action_penalty * (action_imputed - position_before_imputed)**2
    expected_scaled_reward = expected_raw_reward * cfg.environment.reward_scale
    assert np.isclose(reward, expected_scaled_reward)

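For the concrete values in this test (position 0.6 before the imputed bar, action -0.2, `action_penalty` 0.1, `reward_scale` 100.0), the expected reward works out to roughly -6.4, which is a handy sanity check when debugging this mode:

```python
# Worked example with the values used in test_env_step_imputed_penalty.
raw = -0.1 * (-0.2 - 0.6) ** 2   # = -0.1 * 0.64 = -0.064
scaled = raw * 100.0             # ≈ -6.4
print(scaled)
```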
def test_env_done_condition(trading_env_instance, sample_env_data):
    n_steps = len(sample_env_data['actual_returns'])
    # Step through the environment
    done = False
    for i in range(n_steps):
        _, _, done, _ = trading_env_instance.step(np.random.uniform(-1, 1))
        if i < n_steps - 1:
            assert not done
        else:
            assert done  # Should be done on the last step
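A possible follow-up (not in this commit) would be a single parametrised smoke test covering all three handling modes; a sketch assuming the same fixtures and imports as above:

```python
@pytest.mark.parametrize("mode", ["skip", "hold", "penalty"])
def test_env_runs_to_completion_in_all_modes(sample_env_data, base_env_config, mode):
    """Sketch: every handling mode should step through all bars without error."""
    cfg = base_env_config.copy()
    cfg.sac.imputed_handling = mode
    env = TradingEnv(**sample_env_data, config=cfg)
    env.reset()
    done = False
    while not done:
        _, _, done, _ = env.step(np.random.uniform(-1, 1))
    # With this fixture, capital never hits zero, so done means all bars were consumed.
    assert env.current_step == env.n_steps
```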