feat: Implement imputed bar handling in TradingEnv (skip, hold, penalty) and tests
parent 4b1b542430
commit c3526bb9f6
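The new behaviour is driven by two config keys read inside the environment: `config.sac.imputed_handling` (one of "skip", "hold", "penalty") and `config.sac.action_penalty` (the λ used by the "penalty" mode). A minimal OmegaConf sketch of those keys, mirroring the test fixture below (values are illustrative, not part of the commit):

from omegaconf import OmegaConf

# Only the keys read by TradingEnv's imputed-bar logic are shown.
# "skip"    -> advance past the imputed bar, 0 reward, flag it in info for the replay buffer
# "hold"    -> force the action to the current position, then fall through to the normal step logic
# "penalty" -> reward is only -action_penalty * (action - position)**2, scaled by reward_scale
cfg = OmegaConf.create({
    "sac": {
        "imputed_handling": "skip",   # "skip" | "hold" | "penalty"
        "action_penalty": 0.05,       # λ for the "penalty" mode
    }
})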
@@ -6,6 +6,8 @@ Uses pre-calculated GRU predictions (mu, sigma, p_cal) and actual returns.
 import numpy as np
 import pandas as pd
 import logging
+import gymnasium as gym
+from omegaconf import DictConfig # Added for config typing
 
 env_logger = logging.getLogger(__name__)
 
@@ -15,6 +17,8 @@ class TradingEnv:
         sigma_predictions: np.ndarray,
         p_cal_predictions: np.ndarray,
         actual_returns: np.ndarray,
+        bar_imputed_flags: np.ndarray, # Added imputed flags
+        config: DictConfig, # Added config
         initial_capital: float = 10000.0,
         transaction_cost: float = 0.0005,
         reward_scale: float = 100.0,
@@ -27,18 +31,22 @@ class TradingEnv:
             sigma_predictions: Predicted volatility (σ̂ = exp(log σ̂)).
             p_cal_predictions: Calibrated probability of price increase (p_cal).
             actual_returns: Actual log returns (y_ret).
+            bar_imputed_flags: Boolean array indicating if a bar was imputed.
+            config: OmegaConf configuration object.
             initial_capital: Starting capital for simulation (used notionally in reward).
             transaction_cost: Fractional cost per trade.
             reward_scale: Multiplier for the reward signal.
             action_penalty_lambda: Coefficient for the action magnitude penalty (λ).
         """
-        assert len(mu_predictions) == len(sigma_predictions) == len(p_cal_predictions) == len(actual_returns), \
-            "All input arrays must have the same length"
+        assert len(mu_predictions) == len(sigma_predictions) == len(p_cal_predictions) == len(actual_returns) == len(bar_imputed_flags), \
+            "All input arrays (predictions, returns, imputed_flags) must have the same length"
 
         self.mu = mu_predictions
         self.sigma = sigma_predictions
         self.p_cal = p_cal_predictions
         self.actual_returns = actual_returns
+        self.bar_imputed = bar_imputed_flags.astype(bool) # Store imputed flags
+        self.config = config # Store config
 
         self.initial_capital = initial_capital
         self.transaction_cost = transaction_cost
@@ -65,20 +73,36 @@ class TradingEnv:
         self.state_dim = 5
         self.action_dim = 1
 
+        # --- Define Gym Spaces ---
+        self.action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(self.action_dim,), dtype=np.float32)
+        self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.state_dim,), dtype=np.float32)
+        # --- End Define Gym Spaces ---
+
         env_logger.info(f"TradingEnv initialized with {self.n_steps} steps.")
 
     def _get_state(self) -> np.ndarray:
         """Construct the state vector for the current step."""
         if self.current_step >= self.n_steps:
-            # Handle episode end - return a dummy state or zeros
             return np.zeros(self.state_dim, dtype=np.float32)
 
         mu_t = self.mu[self.current_step]
         sigma_t = self.sigma[self.current_step]
         p_cal_t = self.p_cal[self.current_step]
 
-        edge_t = 2 * p_cal_t - 1
-        z_score_t = np.abs(mu_t) / (sigma_t + 1e-9)
+        # Calculate edge based on p_cal shape (binary vs ternary)
+        if isinstance(p_cal_t, (np.ndarray, list)) and len(p_cal_t) == 3:
+            # Ternary: edge = max(P(up), P(down)) - P(flat)
+            # Assuming order [Down, Flat, Up] for p_cal_t
+            edge_t = max(p_cal_t[2], p_cal_t[0]) - p_cal_t[1]
+        elif isinstance(p_cal_t, (float, np.number)):
+            # Binary: edge = 2 * P(up) - 1
+            edge_t = 2 * p_cal_t - 1
+        else:
+            env_logger.error(f"Unexpected type/shape for p_cal_t at step {self.current_step}: {p_cal_t}. Using edge=0.")
+            edge_t = 0.0
+
+        _EPS = 1e-9 # Define epsilon locally
+        z_score_t = np.abs(mu_t) / (sigma_t + _EPS)
 
         # State uses position *before* the action for this step is taken
         state = np.array([
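As a quick sanity check of the new edge logic above (standalone sketch, not part of the commit; probabilities are illustrative):

import numpy as np

# Binary case: p_cal is a single calibrated P(up).
p_up = 0.65
edge_binary = 2 * p_up - 1                          # 0.30

# Ternary case: p_cal is assumed ordered [P(down), P(flat), P(up)].
p_cal = np.array([0.20, 0.30, 0.50])
edge_ternary = max(p_cal[2], p_cal[0]) - p_cal[1]   # 0.50 - 0.30 = 0.20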
@@ -108,11 +132,48 @@
         Returns:
             tuple: (next_state, reward, done, info_dict)
         """
+        info = {'capital': self.current_capital, 'position': self.current_position, 'is_imputed_step_skipped': False}
+
         if self.current_step >= self.n_steps:
             # Should not happen if 'done' is handled correctly, but as safeguard
             env_logger.warning("Step called after environment finished.")
-            return self._get_state(), 0.0, True, {}
+            return self._get_state(), 0.0, True, info
 
+        # --- Handle Imputed Bar --- #
+        imputed = self.bar_imputed[self.current_step]
+        if imputed:
+            mode = self.config.sac.imputed_handling
+            env_logger.debug(f"SAC step {self.current_step} on imputed bar: handling={mode}")
+            if mode == "skip":
+                self.current_step += 1
+                next_state = self._get_state() # Get state for the *next* actual step
+                # Return 0 reward, not done, but indicate skip for buffer handling
+                info['is_imputed_step_skipped'] = True
+                return next_state, 0.0, False, info
+            elif mode == "hold":
+                # Action is forced to maintain current position
+                action = self.current_position
+            elif mode == "penalty":
+                # Calculate reward penalty based on config
+                target_position_penalty = np.clip(action, -1.0, 1.0)
+                reward = -self.config.sac.action_penalty * (target_position_penalty - self.current_position)**2
+                # Update position based on agent's intended action (clipped)
+                self.current_position = target_position_penalty
+                # Update capital notionally (no actual return, only cost if implemented)
+                # Cost is implicitly 0 here as there's no trade size if pos doesn't change
+                # If penalty mode allowed position change, cost would apply.
+                # For simplicity, we don't add cost here for the penalty step.
+                self.current_step += 1
+                next_state = self._get_state()
+                scaled_reward = reward * self.reward_scale # Scale the penalty
+                done = self.current_step >= self.n_steps
+                info['capital'] = self.current_capital
+                info['position'] = self.current_position
+                return next_state, scaled_reward, done, info
+            # else: default behavior (treat as normal bar) - implicitly handled by falling through
+        # --- End Handle Imputed Bar --- #
+
+        # --- Normal Step Logic (if not imputed or handling mode allows fallthrough like 'hold') --- #
         # Action is the TARGET position for the *end* of this step
         target_position = np.clip(action, -1.0, 1.0)
         trade_size = target_position - self.current_position
@@ -150,7 +211,9 @@
         done = self.current_step >= self.n_steps or self.current_capital <= 0
 
         next_state = self._get_state()
-        info = {'capital': self.current_capital, 'position': self.current_position}
+        # Update info dict (capital/position might have changed in normal step)
+        info['capital'] = self.current_capital
+        info['position'] = self.current_position
 
         # Log step details periodically
         # if self.current_step % 1000 == 0:
gru_sac_predictor/tests/test_trading_env.py | 166 (new file)
@@ -0,0 +1,166 @@
+import pytest
+import numpy as np
+from omegaconf import OmegaConf
+
+# Adjust import path based on structure
+from gru_sac_predictor.src.trading_env import TradingEnv
+
+# --- Test Fixtures ---
+
+@pytest.fixture
+def sample_env_data():
+    """Provides sample data for initializing the TradingEnv."""
+    n_steps = 10
+    data = {
+        'mu_predictions': np.random.randn(n_steps) * 0.001,
+        'sigma_predictions': np.abs(np.random.randn(n_steps) * 0.002 + 0.005),
+        'p_cal_predictions': np.random.rand(n_steps),
+        'actual_returns': np.random.randn(n_steps) * 0.0015,
+        'bar_imputed_flags': np.array([False, False, True, False, True, True, False, False, True, False], dtype=bool)
+    }
+    return data
+
+@pytest.fixture
+def base_env_config():
+    """Base configuration for the environment."""
+    return OmegaConf.create({
+        'sac': {
+            'imputed_handling': 'skip', # Default test mode
+            'action_penalty': 0.05
+        },
+        'environment': {
+            'initial_capital': 10000.0,
+            'transaction_cost': 0.0005,
+            'reward_scale': 100.0,
+            'action_penalty_lambda': 0.0 # Usually overridden by transaction_cost calc
+        }
+    })
+
+@pytest.fixture
+def trading_env_instance(sample_env_data, base_env_config):
+    """Creates a TradingEnv instance with default 'skip' mode."""
+    return TradingEnv(**sample_env_data, config=base_env_config)
+
+# --- Test Functions ---
+
+def test_env_initialization(trading_env_instance, sample_env_data):
+    assert trading_env_instance.n_steps == len(sample_env_data['actual_returns'])
+    assert trading_env_instance.current_step == 0
+    assert trading_env_instance.current_position == 0.0
+    assert np.array_equal(trading_env_instance.bar_imputed, sample_env_data['bar_imputed_flags'])
+
+def test_env_reset(trading_env_instance):
+    # Take a few steps
+    trading_env_instance.step(0.5)
+    trading_env_instance.step(-0.2)
+    assert trading_env_instance.current_step > 0
+    # Reset
+    initial_state = trading_env_instance.reset()
+    assert trading_env_instance.current_step == 0
+    assert trading_env_instance.current_position == 0.0
+    assert initial_state is not None
+    assert initial_state.shape == (trading_env_instance.state_dim,)
+
+def test_env_step_normal(trading_env_instance):
+    # Test a normal step (step 0 is not imputed)
+    initial_pos = trading_env_instance.current_position
+    action = 0.7
+    next_state, reward, done, info = trading_env_instance.step(action)
+
+    assert trading_env_instance.current_step == 1
+    assert trading_env_instance.current_position == action # Position updates to action
+    assert not info['is_imputed_step_skipped']
+    assert not done
+    assert next_state is not None
+    # Reward calculation is complex, just check type/sign if needed
+    assert isinstance(reward, float)
+
+def test_env_step_imputed_skip(trading_env_instance, sample_env_data):
+    # Step 2 is imputed in sample_env_data
+    trading_env_instance.step(0.5) # Step 0
+    trading_env_instance.step(0.6) # Step 1
+    assert trading_env_instance.current_step == 2
+    initial_pos_before_imputed = trading_env_instance.current_position
+
+    # Action for the imputed step (should be ignored by 'skip')
+    action_imputed = 0.9
+    next_state, reward, done, info = trading_env_instance.step(action_imputed)
+
+    # Should skip step 2 and now be at step 3
+    assert trading_env_instance.current_step == 3
+    # Position should NOT have changed from step 1
+    assert trading_env_instance.current_position == initial_pos_before_imputed
+    assert reward == 0.0 # Skip gives 0 reward
+    assert not done
+    assert info['is_imputed_step_skipped'] == True # Crucial check for buffer
+    # Check that the returned state is for step 3
+    expected_state_step_3 = trading_env_instance._get_state() # Get state now that we are at step 3
+    np.testing.assert_array_almost_equal(next_state, expected_state_step_3)
+
+def test_env_step_imputed_hold(sample_env_data, base_env_config):
+    cfg = base_env_config.copy()
+    cfg.sac.imputed_handling = 'hold'
+    env = TradingEnv(**sample_env_data, config=cfg)
+
+    # Step 2 is imputed
+    env.step(0.5) # Step 0
+    env.step(0.6) # Step 1
+    assert env.current_step == 2
+    position_before_imputed = env.current_position
+
+    # Action for the imputed step (should be overridden by 'hold')
+    action_imputed = -0.5
+    next_state, reward, done, info = env.step(action_imputed)
+
+    # Should process step 2 and move to step 3
+    assert env.current_step == 3
+    # Position should be the same as before the step
+    assert env.current_position == position_before_imputed
+    assert not info['is_imputed_step_skipped']
+    assert not done
+    # Reward should be calculated based on holding the position
+    expected_pnl = position_before_imputed * (np.exp(sample_env_data['actual_returns'][2]) - 1)
+    expected_cost = 0 # No trade size if holding
+    expected_penalty = 0 # No penalty in hold mode
+    expected_raw_reward = expected_pnl - expected_cost - expected_penalty
+    expected_scaled_reward = expected_raw_reward * cfg.environment.reward_scale
+    assert np.isclose(reward, expected_scaled_reward)
+
+def test_env_step_imputed_penalty(sample_env_data, base_env_config):
+    cfg = base_env_config.copy()
+    cfg.sac.imputed_handling = 'penalty'
+    cfg.sac.action_penalty = 0.1 # Use a specific penalty for testing
+    env = TradingEnv(**sample_env_data, config=cfg)
+
+    # Step 2 is imputed
+    env.step(0.5) # Step 0
+    env.step(0.6) # Step 1
+    assert env.current_step == 2
+    position_before_imputed = env.current_position # Should be 0.6
+
+    # Action for the imputed step
+    action_imputed = -0.2
+    next_state, reward, done, info = env.step(action_imputed)
+
+    # Should process step 2 and move to step 3
+    assert env.current_step == 3
+    # Position should update to the *agent's* action
+    assert env.current_position == np.clip(action_imputed, -1.0, 1.0)
+    assert not info['is_imputed_step_skipped']
+    assert not done
+
+    # Reward calculation is ONLY the penalty
+    expected_raw_reward = -cfg.sac.action_penalty * (action_imputed - position_before_imputed)**2
+    expected_scaled_reward = expected_raw_reward * cfg.environment.reward_scale
+    assert np.isclose(reward, expected_scaled_reward)
+
+def test_env_done_condition(trading_env_instance, sample_env_data):
+    n_steps = len(sample_env_data['actual_returns'])
+    # Step through the environment
+    done = False
+    for i in range(n_steps):
+        _, _, done, _ = trading_env_instance.step(np.random.uniform(-1, 1))
+        if i < n_steps - 1:
+            assert not done
+        else:
+            assert done # Should be done on the last step