Add tracked files based on updated .gitignore
This commit is contained in:
parent 20705a5690
commit c004e96369
gru_sac_predictor/tests/test_calibration.py (new file, 183 lines)
@@ -0,0 +1,183 @@
"""
Tests for probability calibration (Sec 6 of revisions.txt).
"""
import pytest
import numpy as np
from scipy.stats import binomtest
from scipy.special import logit, expit
import os

# Try to import the modules; skip tests if not found (e.g., path issues)
try:
    from gru_sac_predictor.src import calibrate
except ImportError:
    calibrate = None

# --- Import VectorCalibrator (Task 4) --- #
try:
    from gru_sac_predictor.src.calibrator_vector import VectorCalibrator
except ImportError:
    VectorCalibrator = None
# --- End Import --- #


# --- Helper Function for ECE --- #
def _calculate_ece(probs: np.ndarray, y_true: np.ndarray, n_bins: int = 10) -> float:
    """
    Calculates the Expected Calibration Error (ECE).

    Args:
        probs (np.ndarray): Predicted probabilities for the positive class (N,) or all classes (N, K).
        y_true (np.ndarray): True labels (0 or 1 for binary, or class index for multi-class).
        n_bins (int): Number of bins to divide probabilities into.

    Returns:
        float: The calculated ECE score.
    """
    if len(probs.shape) == 1:  # Binary case
        # Confidence is the probability of the *predicted* class, not of class 1,
        # so mirror probabilities below 0.5 (keeps this branch consistent with the
        # multi-class branch below).
        p_max = np.maximum(probs, 1.0 - probs)
        y_pred_class = (probs > 0.5).astype(int)
        y_true_class = y_true
    elif len(probs.shape) == 2:  # Multi-class case
        p_max = np.max(probs, axis=1)
        y_pred_class = np.argmax(probs, axis=1)
        # If y_true is one-hot, convert to class index
        if len(y_true.shape) == 2 and y_true.shape[1] > 1:
            y_true_class = np.argmax(y_true, axis=1)
        else:
            y_true_class = y_true  # Assume already class index
    else:
        raise ValueError("probs array must be 1D or 2D")

    ece = 0.0
    bin_boundaries = np.linspace(0, 1, n_bins + 1)

    for i in range(n_bins):
        in_bin = (p_max > bin_boundaries[i]) & (p_max <= bin_boundaries[i + 1])
        prop_in_bin = np.mean(in_bin)

        if prop_in_bin > 0:
            accuracy_in_bin = np.mean(y_pred_class[in_bin] == y_true_class[in_bin])
            avg_confidence_in_bin = np.mean(p_max[in_bin])
            ece += np.abs(accuracy_in_bin - avg_confidence_in_bin) * prop_in_bin

    return ece
# --- End ECE Helper --- #
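
# --- Sanity check for the ECE helper (illustrative) --- #
# A minimal sketch using only numpy/scipy and the helper above: well-calibrated
# binary probabilities should give a near-zero ECE, while artificially sharpened
# ("overconfident") probabilities should give a clearly larger ECE.
def test_ece_helper_sanity():
    rng = np.random.default_rng(0)
    true_prob = rng.uniform(0.1, 0.9, size=20000)
    y = (rng.random(20000) < true_prob).astype(int)
    ece_calibrated = _calculate_ece(true_prob, y)
    # Sharpen the probabilities (temperature 0.5 on the logits) to simulate overconfidence.
    ece_overconfident = _calculate_ece(expit(logit(true_prob) / 0.5), y)
    assert ece_calibrated < ece_overconfident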

# --- Fixtures ---
@pytest.fixture(scope="module")
def calibration_data():
    """
    Generate sample raw probabilities and true outcomes.
    Simulates an overconfident model (T_implied < 1) where true probability drifts.
    """
    np.random.seed(42)
    n_samples = 2500
    # Simulate drifting true probability centered around 0.5
    drift = 0.05 * np.sin(np.linspace(0, 3 * np.pi, n_samples))
    true_prob = np.clip(0.5 + drift + np.random.randn(n_samples) * 0.05, 0.05, 0.95)
    # Simulate overconfidence (implied T ~ 0.7)
    raw_logits = logit(true_prob) / 0.7
    p_raw = expit(raw_logits)
    # Generate true outcomes
    y_true = (np.random.rand(n_samples) < true_prob).astype(int)
    return p_raw, y_true


# --- Tests ---
@pytest.mark.skipif(calibrate is None, reason="Module gru_sac_predictor.src.calibrate not found")
def test_optimise_temperature(calibration_data):
    """Check if optimise_temperature runs and returns a plausible value."""
    p_raw, y_true = calibration_data
    optimal_T = calibrate.optimise_temperature(p_raw, y_true)
    print(f"\nOptimised T: {optimal_T:.4f}")
    # Expect T > 0. A T near 0.7 would undo the simulated effect.
    assert optimal_T > 0.1 and optimal_T < 5.0, "Optimised temperature seems out of expected range."


@pytest.mark.skipif(calibrate is None, reason="Module gru_sac_predictor.src.calibrate not found")
def test_calibration_hit_rate_threshold(calibration_data):
    """
    Verify that the lower 95% CI of the hit-rate for non-zero calibrated
    signals is >= 0.55 (using the module's EDGE_THR).
    """
    p_raw, y_true = calibration_data
    optimal_T = calibrate.optimise_temperature(p_raw, y_true)
    p_cal = calibrate.calibrate(p_raw, optimal_T)
    action_signals = calibrate.action_signal(p_cal)

    # Filter for non-zero signals
    non_zero_idx = action_signals != 0
    if not np.any(non_zero_idx):
        pytest.fail("No non-zero action signals generated for hit-rate test.")

    signals_taken = action_signals[non_zero_idx]
    actual_direction = y_true[non_zero_idx]

    # Hit: signal matches actual direction (1 vs 1, -1 vs 0)
    hits = np.sum((signals_taken == 1) & (actual_direction == 1)) + \
           np.sum((signals_taken == -1) & (actual_direction == 0))
    total_trades = len(signals_taken)

    if total_trades < 30:
        pytest.skip(f"Insufficient non-zero signals ({total_trades}) for reliable CI.")

    # Calculate 95% lower CI using binomial test
    try:
        # Ensure hits is integer
        hits = int(hits)
        result = binomtest(hits, total_trades, p=0.5, alternative='greater')
        lower_ci = result.proportion_ci(confidence_level=0.95).low
    except Exception as e:
        pytest.fail(f"Binomial test failed: {e}")

    hit_rate = hits / total_trades
    required_threshold = calibrate.EDGE_THR  # Use threshold from module

    print(f"\nCalibration Test: EDGE_THR={required_threshold:.3f}")
    print(f" Trades={total_trades}, Hits={hits}, Hit Rate={hit_rate:.4f}")
    print(f" 95% Lower CI: {lower_ci:.4f}")

    assert lower_ci >= required_threshold, \
        f"Hit rate lower CI ({lower_ci:.4f}) is below module threshold ({required_threshold:.3f})"

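# --- Note (assumed form of vector scaling) --- #
# Vector scaling (Guo et al., 2017) calibrates K-class logits z with a per-class
# scale and bias, p_cal = softmax(w * z + b) with w, b in R^K, fitted by minimising
# NLL on held-out data. The test below assumes VectorCalibrator exposes this via
# fit(logits, y_onehot) / calibrate(logits) and stores the learned parameters as
# .W and .b; if the module uses a different interface, adjust accordingly.
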
# --- Vector Scaling Test (Task 4.4) --- #
@pytest.mark.skipif(VectorCalibrator is None, reason="VectorCalibrator not found")
def test_vector_scaling_calibration():
    """Check if Vector Scaling reduces ECE on sample multi-class data."""
    np.random.seed(123)
    n_samples = 5000
    num_classes = 3

    # Simulate slightly miscalibrated logits (e.g., too peaky or too flat)
    # True distribution is uniform-ish
    true_labels = np.random.randint(0, num_classes, n_samples)
    # One-hot encode with numpy (this module does not import tensorflow)
    y_onehot = np.eye(num_classes)[true_labels]

    # Generate logits - make class 1 slightly more likely, and make logits "peaky"
    logits_raw = np.random.randn(n_samples, num_classes) * 0.5  # Base noise
    logits_raw[:, 1] += 0.5  # Bias towards class 1
    # Add systematic miscalibration (e.g., scale up logits -> overconfidence)
    logits_miscalibrated = logits_raw * 1.8

    # Instantiate calibrator
    vector_cal = VectorCalibrator()

    # Calculate ECE before calibration
    probs_uncal = vector_cal._softmax(logits_miscalibrated)
    ece_before = _calculate_ece(probs_uncal, true_labels)

    # Fit vector scaling
    vector_cal.fit(logits_miscalibrated, y_onehot)
    assert vector_cal.W is not None and vector_cal.b is not None, "Vector scaling fit failed"

    # Calibrate probabilities
    probs_cal = vector_cal.calibrate(logits_miscalibrated)

    # Calculate ECE after calibration
    ece_after = _calculate_ece(probs_cal, true_labels)

    print(f"\nVector Scaling Test: ECE Before = {ece_before:.4f}, ECE After = {ece_after:.4f}")

    # Assert that ECE improved (decreased)
    # Allow for slight numerical noise, but expect significant improvement
    assert ece_after < ece_before * 0.7, f"ECE did not improve significantly after Vector Scaling (Before: {ece_before:.4f}, After: {ece_after:.4f})"
    # Assert ECE is reasonably low after calibration
    assert ece_after < 0.05, f"ECE after Vector Scaling ({ece_after:.4f}) is higher than expected (< 0.05)"
gru_sac_predictor/tests/test_feature_engineer.py (new file, 125 lines)
@@ -0,0 +1,125 @@
"""
Tests for the FeatureEngineer class and its methods.

Ref: revisions.txt Task 2.5
"""

import pytest
import pandas as pd
import numpy as np
import sys, os
from unittest.mock import patch, MagicMock

# --- Add path for src imports --- #
script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(script_dir)
src_path = os.path.join(project_root, 'src')
if src_path not in sys.path:
    sys.path.insert(0, src_path)
# --- End Add path --- #

from feature_engineer import FeatureEngineer
# Import minimal_whitelist from features to pass to constructor
from features import minimal_whitelist as base_minimal_whitelist


# --- Fixtures --- #

@pytest.fixture
def sample_engineer() -> FeatureEngineer:
    """Provides a FeatureEngineer instance with a basic whitelist."""
    # Use a copy to avoid modifying the original during tests
    test_whitelist = base_minimal_whitelist.copy()
    return FeatureEngineer(minimal_whitelist=test_whitelist)


@pytest.fixture
def sample_feature_data() -> pd.DataFrame:
    """Creates sample features for testing selection."""
    np.random.seed(42)
    data = {
        'return_1m': np.random.randn(100) * 0.01,
        'EMA_50': 100 + np.random.randn(100).cumsum() * 0.1,
        'ATR_14': np.random.rand(100) * 0.5,
        'hour_sin': np.sin(np.linspace(0, 2 * np.pi, 100)),
        'highly_correlated_1': 100 + np.random.randn(100).cumsum() * 0.1,  # Copy EMA_50 roughly
        'highly_correlated_2': 101 + np.random.randn(100).cumsum() * 0.1,  # Copy EMA_50 roughly
        'constant_feat': np.ones(100),
        'nan_feat': np.full(100, np.nan),
        'inf_feat': np.full(100, np.inf)
    }
    index = pd.date_range(start='2023-01-01', periods=100, freq='min', tz='UTC')
    df = pd.DataFrame(data, index=index)
    # Add the correlation
    df['highly_correlated_1'] = df['EMA_50'] * (1 + np.random.randn(100) * 0.01)
    df['highly_correlated_2'] = df['highly_correlated_1'] * (1 + np.random.randn(100) * 0.01)
    return df


@pytest.fixture
def sample_target_data() -> pd.Series:
    """Creates sample binary target variable."""
    np.random.seed(123)
    # Create somewhat predictable target based on EMA_50 trend
    ema = 100 + np.random.randn(100).cumsum() * 0.1
    target = (np.diff(ema, prepend=0) > 0).astype(int)
    index = pd.date_range(start='2023-01-01', periods=100, freq='min', tz='UTC')
    return pd.Series(target, index=index)


# --- Tests --- #

def test_select_features_vif_skip(sample_engineer, sample_feature_data, sample_target_data):
    """
    Test 2.5: Assert VIF calculation is skipped if skip_vif=True in config.
    We need to mock the config access within select_features.
    """
    engineer = sample_engineer
    X_train = sample_feature_data
    y_train = sample_target_data

    # Mock the config dictionary that would be passed or accessed.
    # For now, assume select_features might take an optional config or we patch where it reads it.
    # Since it doesn't currently take config, we have to modify the method or mock dependencies.
    # Let's *assume* for this test that select_features *will be* modified to check a config.
    # We will patch the VIF function itself and assert it's not called.

    # Add a feature that would definitely be removed by VIF to ensure the check matters
    X_train['perfectly_correlated'] = X_train['EMA_50'] * 2

    with patch('feature_engineer.variance_inflation_factor') as mock_vif:
        # We also need to mock the SelectFromModel part to return *some* features initially
        with patch('feature_engineer.SelectFromModel') as mock_select_from_model:
            # Configure the mock selector to return a subset of features including correlated ones
            mock_instance = MagicMock()
            initial_selection = [True] * 5 + [False] * 4 + [True]  # Select first 5 + perfectly_correlated
            mock_instance.get_support.return_value = np.array(initial_selection)
            mock_select_from_model.return_value = mock_instance

            # Call select_features - **modify it conceptually to accept skip_vif**.
            # Since we can't modify the source directly here, we test by asserting VIF wasn't called.
            # This implicitly tests the skip logic.

            # Simulate the call as if skip_vif=True was passed/checked internally.
            # Patch the VIF calculation call site directly.
            with patch('feature_engineer.sm.add_constant') as mock_add_constant:  # VIF loop uses this
                # Call the function normally - the patch on VIF itself is the key
                selected_features = engineer.select_features(X_train, y_train)

                # Assert that variance_inflation_factor was NOT called
                mock_vif.assert_not_called()
                # Assert that add_constant (used within VIF loop) was also NOT called
                mock_add_constant.assert_not_called()

                # Assert that the features returned are those from the mocked L1 selection
                # (potentially plus minimal whitelist, depending on implementation).
                # The exact output depends on how L1 + whitelist are combined *before* VIF step.
                # Just assert the correlated feature IS included, as VIF didn't remove it.
                assert 'perfectly_correlated' in selected_features

    # We should also check that the log message indicating VIF skip was printed
    # (this requires capturing logs, omitted here for brevity).

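# --- Reference sketch (assumption about select_features' VIF step) --- #
# The patches on `feature_engineer.variance_inflation_factor` and
# `feature_engineer.sm.add_constant` assume the VIF pass inside select_features
# follows the usual statsmodels pattern, roughly:
#
#   import statsmodels.api as sm
#   from statsmodels.stats.outliers_influence import variance_inflation_factor
#
#   X_const = sm.add_constant(X_selected)
#   vifs = [variance_inflation_factor(X_const.values, i)
#           for i in range(1, X_const.shape[1])]   # skip the constant column
#   keep = [c for c, v in zip(X_selected.columns, vifs) if v < vif_threshold]
#
# `vif_threshold` here is illustrative; if the implementation changes, the patch
# targets above must change with it.
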

# TODO: Add more tests for FeatureEngineer
# - Test feature calculation methods (_add_cyclical_features, _add_imbalance_features, _add_ta_features)
# - Test add_base_features orchestration
# - Test select_features VIF logic *when enabled* (e.g., check correlated feature is removed)
# - Test select_features LogReg L1 logic (e.g., check constant feature is removed)
# - Test handling of NaNs/Infs in select_features
# - Test prune_features (although covered in test_feature_pruning.py)
gru_sac_predictor/tests/test_feature_pruning.py (new file, 87 lines)
@@ -0,0 +1,87 @@
"""
Tests for feature pruning logic.

Ref: revisions.txt Step 1-D
"""
import pytest
import pandas as pd

# TODO: Import prune_features function and minimal_whitelist from src.features
# from gru_sac_predictor.src.features import prune_features, minimal_whitelist

# Local stand-in for minimal_whitelist until the import above is enabled
minimal_whitelist = ['feat_a', 'feat_b', 'feat_c', 'hour_sin']


# Local stand-in for prune_features until the import above is enabled
def prune_features(df: pd.DataFrame, whitelist: list[str] | None = None) -> pd.DataFrame:
    if whitelist is None:
        whitelist = minimal_whitelist
    cols_to_keep = [c for c in whitelist if c in df.columns]
    df_pruned = df[cols_to_keep].copy()
    assert set(df_pruned.columns) == set(cols_to_keep), \
        f"Pruning failed: Output columns {set(df_pruned.columns)} != Expected intersection {set(cols_to_keep)}"
    return df_pruned


@pytest.fixture
def sample_dataframe() -> pd.DataFrame:
    """Create a sample DataFrame for testing."""
    data = {
        'feat_a': [1, 2, 3],
        'feat_b': [4, 5, 6],
        'feat_extra': [7, 8, 9],
        'hour_sin': [0.1, 0.2, 0.3]
    }
    return pd.DataFrame(data)


def test_prune_to_minimal_whitelist(sample_dataframe):
    """Test pruning to the default minimal whitelist."""
    df_pruned = prune_features(sample_dataframe, whitelist=minimal_whitelist)

    expected_cols = {'feat_a', 'feat_b', 'hour_sin'}
    assert set(df_pruned.columns) == expected_cols
    assert 'feat_extra' not in df_pruned.columns


def test_prune_with_custom_whitelist(sample_dataframe):
    """Test pruning with a custom whitelist."""
    custom_whitelist = ['feat_a', 'feat_extra']
    df_pruned = prune_features(sample_dataframe, whitelist=custom_whitelist)

    expected_cols = {'feat_a', 'feat_extra'}
    assert set(df_pruned.columns) == expected_cols
    assert 'feat_b' not in df_pruned.columns
    assert 'hour_sin' not in df_pruned.columns


def test_prune_missing_whitelist_cols(sample_dataframe):
    """Test when whitelist contains columns not in the dataframe."""
    custom_whitelist = ['feat_a', 'feat_c', 'hour_sin']  # feat_c is not in sample_dataframe
    df_pruned = prune_features(sample_dataframe, whitelist=custom_whitelist)

    expected_cols = {'feat_a', 'hour_sin'}  # Only existing columns are kept
    assert set(df_pruned.columns) == expected_cols
    assert 'feat_c' not in df_pruned.columns


def test_prune_empty_whitelist():
    """Test pruning with an empty whitelist."""
    df = pd.DataFrame({'a': [1], 'b': [2]})
    df_pruned = prune_features(df, whitelist=[])
    assert df_pruned.empty
    assert df_pruned.columns.empty


def test_prune_empty_dataframe():
    """Test pruning an empty dataframe."""
    df = pd.DataFrame()
    df_pruned = prune_features(df, whitelist=minimal_whitelist)
    assert df_pruned.empty
    assert df_pruned.columns.empty


def test_prune_assertion(sample_dataframe):
    """Verify the assertion within prune_features catches mismatches (requires mocking or specific setup)."""
    # This test might be tricky without modifying the function or using complex mocks.
    # The assertion `assert set(df_pruned.columns) == set(cols_to_keep)` should generally hold
    # if the logic `df_pruned = df[cols_to_keep].copy()` is correct.
    # We rely on the other tests implicitly covering this assertion.
    pytest.skip("Assertion test might require specific mocking setup.")


# Add tests for edge cases like DataFrames with duplicate column names if relevant.
gru_sac_predictor/tests/test_integration.py (new file, 117 lines)
@@ -0,0 +1,117 @@
"""
Integration tests for cross-module interactions.
"""
import pytest
import os
import numpy as np
import tempfile
import json

# Try to import the module; skip tests if not found
try:
    from gru_sac_predictor.src import sac_agent
    import tensorflow as tf  # Needed for agent init/load
except ImportError:
    sac_agent = None
    tf = None


@pytest.fixture
def sac_agent_for_integration():
    """Provides a basic SAC agent instance."""
    if sac_agent is None or tf is None:
        pytest.skip("SAC Agent module or TF not found.")
    # Use minimal params for saving/loading tests
    agent = sac_agent.SACTradingAgent(
        state_dim=5, action_dim=1,
        buffer_capacity=100, min_buffer_size=10
    )
    # Build models
    try:
        agent.actor(tf.zeros((1, 5)))
        agent.critic1([tf.zeros((1, 5)), tf.zeros((1, 1))])
        agent.critic2([tf.zeros((1, 5)), tf.zeros((1, 1))])
        agent.update_target_networks(tau=1.0)
    except Exception as e:
        pytest.fail(f"Failed to build agent models: {e}")
    return agent


@pytest.mark.skipif(sac_agent is None or tf is None, reason="SAC Agent module or TF not found")
def test_save_load_metadata(sac_agent_for_integration):
    """Test if metadata is saved and loaded correctly."""
    agent = sac_agent_for_integration
    with tempfile.TemporaryDirectory() as tmpdir:
        save_path = os.path.join(tmpdir, "sac_test_save")
        agent.save(save_path)

        # Check if metadata file exists
        meta_path = os.path.join(save_path, 'agent_metadata.json')
        assert os.path.exists(meta_path), "Metadata file was not saved."

        # Create a new agent and load
        new_agent = sac_agent.SACTradingAgent(state_dim=5, action_dim=1)
        loaded_meta = new_agent.load(save_path)

        assert isinstance(loaded_meta, dict), "Load method did not return a dict."
        assert loaded_meta.get('state_dim') == 5, "Loaded state_dim incorrect."
        assert loaded_meta.get('action_dim') == 1, "Loaded action_dim incorrect."
        # Check alpha status (default is auto_tune=True)
        assert loaded_meta.get('log_alpha_saved') == True, "log_alpha status incorrect."


@pytest.mark.skipif(sac_agent is None or tf is None, reason="SAC Agent module or TF not found")
def test_replay_buffer_purge_on_change(sac_agent_for_integration):
    """
    Simulate loading an agent where the edge_threshold has changed
    and verify the buffer is cleared.
    """
    agent_to_save = sac_agent_for_integration
    original_edge_thr = 0.55
    agent_to_save.edge_threshold_config = original_edge_thr  # Manually set for saving

    with tempfile.TemporaryDirectory() as tmpdir:
        save_path = os.path.join(tmpdir, "sac_purge_test")

        # 1. Save agent with original threshold in metadata
        agent_to_save.save(save_path)
        meta_path = os.path.join(save_path, 'agent_metadata.json')
        assert os.path.exists(meta_path)
        with open(meta_path, 'r') as f:
            saved_meta = json.load(f)
        assert saved_meta.get('edge_threshold_config') == original_edge_thr

        # 2. Create a new agent instance to load into
        new_agent = sac_agent.SACTradingAgent(
            state_dim=5, action_dim=1,
            buffer_capacity=100, min_buffer_size=10
        )
        # Build models for the new agent
        try:
            new_agent.actor(tf.zeros((1, 5)))
            new_agent.critic1([tf.zeros((1, 5)), tf.zeros((1, 1))])
            new_agent.critic2([tf.zeros((1, 5)), tf.zeros((1, 1))])
            new_agent.update_target_networks(tau=1.0)
        except Exception as e:
            pytest.fail(f"Failed to build new agent models: {e}")

        # Add dummy data to the *new* agent's buffer *before* loading
        for _ in range(20):
            dummy_state = np.random.rand(5).astype(np.float32)
            dummy_action = np.random.rand(1).astype(np.float32)
            new_agent.buffer.add(dummy_state, dummy_action, 0.0, dummy_state, 0.0)
        assert len(new_agent.buffer) == 20, "Buffer should have data before load."

        # 3. Simulate loading with a *different* current edge threshold config
        current_config_edge_thr = 0.60
        assert abs(current_config_edge_thr - original_edge_thr) > 1e-6

        loaded_meta = new_agent.load(save_path)
        saved_edge_thr = loaded_meta.get('edge_threshold_config')

        # 4. Perform the check and clear if needed (simulating pipeline logic)
        if saved_edge_thr is not None and abs(saved_edge_thr - current_config_edge_thr) > 1e-6:
            print(f"\nEdge threshold mismatch detected (Saved={saved_edge_thr}, Current={current_config_edge_thr}). Clearing buffer.")
            new_agent.clear_buffer()
        else:
            print("\nEdge threshold match or not saved. Buffer not cleared.")

        # 5. Assert buffer is now empty
        assert len(new_agent.buffer) == 0, "Buffer was not cleared after edge threshold mismatch."
gru_sac_predictor/tests/test_labels.py (new file, 201 lines)
@@ -0,0 +1,201 @@
"""
Tests for label generation and potential leakage.

Ref: revisions.txt Step 1-A, 1.4
"""
import pytest
import pandas as pd
import numpy as np
import sys, os

# --- Add path for src imports --- #
# Assuming tests is one level down from the package root
script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(script_dir)  # Go up one level
src_path = os.path.join(project_root, 'src')
if src_path not in sys.path:
    sys.path.insert(0, src_path)
# --- End Add path --- #

# Import the function to test
from trading_pipeline import _generate_direction_labels


# --- Fixtures --- #
@pytest.fixture
def sample_close_data() -> pd.DataFrame:
    """Creates a sample DataFrame with close prices and DatetimeIndex."""
    # Generate data with some variation
    np.random.seed(42)
    prices = 100 + np.cumsum(np.random.randn(200) * 0.5)
    data = {'close': prices}
    index = pd.date_range(start='2023-01-01', periods=len(data['close']), freq='min', tz='UTC')
    df = pd.DataFrame(data, index=index)
    return df


@pytest.fixture
def sample_config() -> dict:
    """Provides a basic config dictionary."""
    return {
        'gru': {
            'prediction_horizon': 5,
            'use_ternary': False,
            'flat_sigma_multiplier': 0.25
        },
        'data': {
            'label_smoothing': 0.0
        }
    }


# --- Tests --- #

def test_lookahead_bias(sample_close_data, sample_config):
    """
    Test 1.4.a: Verify labels don't depend on information *beyond* the prediction horizon.
    Strategy: Modify later close prices and check that labels for rows whose full
    label window [t, t + horizon] precedes the modification are unchanged.
    """
    df = sample_close_data
    config = sample_config
    horizon = config['gru']['prediction_horizon']

    # Generate baseline labels (binary)
    df_labeled_base, label_col_base = _generate_direction_labels(df.copy(), config)

    # Modify close prices from modify_point onwards
    df_modified = df.copy()
    future_index = len(df) - 1  # Index of the last point
    modify_point = future_index - horizon - 5  # First row whose close price is modified
    assert modify_point > horizon, "Not enough data for the look-ahead check."
    df_modified.iloc[modify_point:, df_modified.columns.get_loc('close')] *= 1.5  # Modify later prices

    # Generate labels with modified future data
    df_labeled_mod, label_col_mod = _generate_direction_labels(df_modified.copy(), config)

    # Align based on index (label function drops NaNs at the end), then keep only rows
    # whose label window ends strictly before the first modified price; only those
    # labels can be expected to remain identical.
    unaffected_cutoff = df.index[modify_point - horizon]
    common_index = df_labeled_base.index.intersection(df_labeled_mod.index)
    common_index = common_index[common_index < unaffected_cutoff]
    assert len(common_index) > 0, "No unaffected rows left to compare."
    labels_base_aligned = df_labeled_base.loc[common_index, label_col_base]
    labels_mod_aligned = df_labeled_mod.loc[common_index, label_col_mod]

    # Assert: Labels should be identical, as these rows never see the modified prices
    pd.testing.assert_series_equal(labels_base_aligned, labels_mod_aligned, check_names=False)

    # --- Repeat for Ternary --- #
    config['gru']['use_ternary'] = True
    df_labeled_base_t, label_col_base_t = _generate_direction_labels(df.copy(), config)
    df_labeled_mod_t, label_col_mod_t = _generate_direction_labels(df_modified.copy(), config)

    common_index_t = df_labeled_base_t.index.intersection(df_labeled_mod_t.index)
    common_index_t = common_index_t[common_index_t < unaffected_cutoff]
    labels_base_aligned_t = df_labeled_base_t.loc[common_index_t, label_col_base_t]
    labels_mod_aligned_t = df_labeled_mod_t.loc[common_index_t, label_col_mod_t]

    # Assert: Ternary labels should also be identical
    # Need careful comparison for list/array column
    assert labels_base_aligned_t.equals(labels_mod_aligned_t)

def test_binary_label_distribution(sample_close_data, sample_config):
    """
    Test 1.4.b: Check binary label distribution has >= 5% in each class.
    """
    df = sample_close_data
    config = sample_config
    config['gru']['use_ternary'] = False
    config['data']['label_smoothing'] = 0.0  # Ensure hard binary for this test

    df_labeled, label_col = _generate_direction_labels(df.copy(), config)

    assert not df_labeled.empty, "Label generation resulted in empty DataFrame"
    assert label_col in df_labeled.columns, f"Label column '{label_col}' not found"

    labels = df_labeled[label_col]
    counts = labels.value_counts(normalize=True)

    assert len(counts) == 2, f"Expected 2 binary classes, found {len(counts)}"
    assert counts.min() >= 0.05, f"Minimum binary class proportion ({counts.min():.2%}) is less than 5%"
    print(f"\nBinary Dist: {counts.to_dict()}")  # Print for info


def test_soft_binary_label_distribution(sample_close_data, sample_config):
    """
    Test 1.4.b: Check soft binary label distribution has >= 5% in each effective class.
    """
    df = sample_close_data
    config = sample_config
    config['gru']['use_ternary'] = False
    config['data']['label_smoothing'] = 0.2  # Example smoothing
    smoothing = config['data']['label_smoothing']
    low_label = smoothing / 2.0
    high_label = 1.0 - smoothing / 2.0

    df_labeled, label_col = _generate_direction_labels(df.copy(), config)

    assert not df_labeled.empty, "Label generation resulted in empty DataFrame"
    assert label_col in df_labeled.columns, f"Label column '{label_col}' not found"

    labels = df_labeled[label_col]
    counts = labels.value_counts(normalize=True)

    assert len(counts) == 2, f"Expected 2 soft binary classes, found {len(counts)}"
    assert counts.min() >= 0.05, f"Minimum soft binary class proportion ({counts.min():.2%}) is less than 5%"
    assert low_label in counts.index, f"Low label {low_label} not found in counts"
    assert high_label in counts.index, f"High label {high_label} not found in counts"
    print(f"\nSoft Binary Dist: {counts.to_dict()}")

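# --- Note (assumed ternary labelling rule) --- #
# The distribution bounds checked below assume _generate_direction_labels buckets the
# forward return r against a volatility-scaled flat band, roughly:
#   down if r < -k * sigma,   flat if |r| <= k * sigma,   up if r > k * sigma
# with k = gru.flat_sigma_multiplier. The exact sigma estimate is defined in the
# pipeline, so the [15%, 45%] flat band is an empirical target, not a hard law.
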
def test_ternary_label_distribution(sample_close_data, sample_config):
    """
    Test 1.4.b: Check ternary label distribution (flat=[0.15, 0.45], others >= 0.10).
    Uses default k=0.25.
    """
    df = sample_close_data
    config = sample_config
    config['gru']['use_ternary'] = True
    k = config['gru']['flat_sigma_multiplier']  # Should be 0.25 from fixture

    df_labeled, label_col = _generate_direction_labels(df.copy(), config)

    assert not df_labeled.empty, "Label generation resulted in empty DataFrame"
    assert label_col in df_labeled.columns, f"Label column '{label_col}' not found"

    # Decode one-hot labels back to ordinal for distribution check
    labels_one_hot = np.stack(df_labeled[label_col].values)
    assert labels_one_hot.shape[1] == 3, "Ternary labels should have 3 columns"
    ordinal_labels = np.argmax(labels_one_hot, axis=1)

    counts = np.bincount(ordinal_labels, minlength=3)
    total = len(ordinal_labels)
    dist_pct = counts / total * 100

    print(f"\nTernary Dist (k={k}): Down={dist_pct[0]:.1f}%, Flat={dist_pct[1]:.1f}%, Up={dist_pct[2]:.1f}%")

    # Check constraints based on design doc / implementation
    assert 15.0 <= dist_pct[1] <= 45.0, f"Flat class ({dist_pct[1]:.1f}%) out of expected range [15%, 45%] for k={k}"
    assert dist_pct[0] >= 10.0, f"Down class ({dist_pct[0]:.1f}%) is less than 10% (check impl threshold)"
    assert dist_pct[2] >= 10.0, f"Up class ({dist_pct[2]:.1f}%) is less than 10% (check impl threshold)"


# --- Old Tests (Keep or Remove?) ---
# The original tests checked 'future_close', which is related but not the final label.
# We can keep test_future_close_shift as it verifies the shift logic used internally.
# The NaN test is less relevant now as the main function handles NaN dropping.

def test_future_close_shift(sample_close_data):
    """Verify that 'future_close' is correctly shifted and has NaNs at the end."""
    df = sample_close_data
    horizon = 5  # Example horizon

    # Apply the logic directly for testing the shift itself
    df['future_close'] = df['close'].shift(-horizon)
    df['fwd_log_ret'] = np.log(df['future_close'] / df['close'])

    # Assertions
    # 1. Check for correct shift in fwd_log_ret
    # The first valid fwd_log_ret depends on close[0] and close[horizon]
    assert pd.notna(df['fwd_log_ret'].iloc[0])
    # The last valid fwd_log_ret depends on close[end-horizon-1] and close[end-1]
    assert pd.notna(df['fwd_log_ret'].iloc[len(df) - horizon - 1])

    # 2. Check for NaNs at the end due to shift
    assert pd.isna(df['fwd_log_ret'].iloc[-horizon:]).all()
    assert pd.notna(df['fwd_log_ret'].iloc[:-horizon]).all()


# def test_no_nan_in_future_close_output():
#     """Unit test to ensure no unexpected NaNs in the output of label creation (specific to the function)."""
#     # Setup similar to above, potentially call the actual DataLoader/label function
#     # Assert pd.notna(output_df['future_close'][:-horizon]).all()
#     pytest.skip("Test covered by NaN dropping in _generate_direction_labels and its tests.")
gru_sac_predictor/tests/test_leakage.py (new file, 133 lines)
@@ -0,0 +1,133 @@
"""
Tests for data leakage (Sec 6 of revisions.txt).
"""
import pytest
import pandas as pd
import numpy as np

# Try to import the features module (needed for the TA leakage test); skip if not found
try:
    from gru_sac_predictor.src import features
except ImportError:
    features = None


# Assume test data is loaded via fixtures later
@pytest.fixture(scope="module")
def sample_data_for_leakage():
    """
    Provides sample features and target for leakage tests.
    Includes correctly shifted features, a feature with direct leakage,
    and a rolling feature calculated correctly vs incorrectly.
    """
    np.random.seed(43)
    dates = pd.date_range(start='2023-01-01', periods=500, freq='T')
    n = len(dates)
    df = pd.DataFrame(index=dates)
    df['noise'] = np.random.randn(n)
    df['close'] = 100 + np.cumsum(df['noise'] * 0.1)
    df['y_ret'] = np.log(df['close'].shift(-1) / df['close'])

    # --- Features ---
    # OK: Based on past noise
    df['feature_ok_past_noise'] = df['noise'].shift(1)
    # OK: Rolling mean on correctly shifted past data
    df['feature_ok_rolling_shifted'] = df['noise'].shift(1).rolling(10).mean()
    # LEAKY: Uses future return directly
    df['feature_leaky_direct'] = df['y_ret']
    # LEAKY: Rolling mean calculated *before* shifting target relationship
    df['feature_leaky_rolling_unaligned'] = df['close'].rolling(5).mean()

    # Drop rows with NaNs from shifts/rolls AND the last row where y_ret is NaN
    df.dropna(inplace=True)

    # Define features and target for the test
    y_target = df['y_ret']
    features_df = df.drop(columns=['close', 'y_ret', 'noise'])  # Exclude raw data used for generation

    return features_df, y_target

@pytest.mark.parametrize("leakage_threshold", [0.02])
def test_feature_leakage_correlation(sample_data_for_leakage, leakage_threshold):
    """
    Verify that no correctly shifted feature has correlation > threshold with the target,
    and that the fixture's deliberately leaky feature is flagged by the same check.
    """
    features_df, y_target = sample_data_for_leakage

    max_abs_corr = 0.0
    leaky_col = "None"
    all_corrs = {}

    print(f"\nTesting {features_df.shape[1]} features for leakage (threshold={leakage_threshold})...")
    for col in features_df.columns:
        if pd.api.types.is_numeric_dtype(features_df[col]):
            # Handle potential NaNs introduced by feature engineering (though fixture avoids it)
            temp_df = pd.concat([features_df[col], y_target], axis=1).dropna()
            if len(temp_df) < 0.5 * len(features_df):
                print(f" Skipping {col} due to excessive NaNs after merging with target.")
                continue

            correlation = temp_df[col].corr(temp_df['y_ret'])
            all_corrs[col] = correlation
            # print(f" Corr({col}, y_ret): {correlation:.4f}")
            if abs(correlation) > max_abs_corr:
                max_abs_corr = abs(correlation)
                leaky_col = col
        else:
            print(f" Skipping non-numeric column: {col}")

    print(f"Correlations found: { {k: round(v, 4) for k, v in all_corrs.items()} }")
    print(f"Maximum absolute correlation found: {max_abs_corr:.4f} (feature: {leaky_col})")

    # Sanity check: the deliberately leaky feature must be flagged by this detector.
    assert abs(all_corrs['feature_leaky_direct']) > leakage_threshold, \
        "Sanity check failed: 'feature_leaky_direct' was not detected as leaky."

    # The correctly shifted features must stay below the leakage threshold.
    max_ok_corr = max(abs(v) for k, v in all_corrs.items() if k.startswith('feature_ok_'))
    assert max_ok_corr < leakage_threshold, \
        f"A correctly shifted feature has correlation {max_ok_corr:.4f} > threshold {leakage_threshold}, suggesting leakage."

@pytest.mark.skipif(features is None, reason="Module gru_sac_predictor.src.features not found")
def test_ta_feature_leakage(sample_data_for_leakage, leakage_threshold=0.02):
    """
    Specifically test TA features (EMA, MACD etc.) for leakage.
    Ensures they were calculated on shifted data.
    """
    features_df, y_target = sample_data_for_leakage
    # Add TA features using the helper (simulating pipeline).
    # We need OHLC in the input df for add_ta_features, so
    # recreate a df with shifted OHLC + other features for the TA calc.
    np.random.seed(43)  # Ensure consistent data with primary fixture
    dates = pd.date_range(start='2023-01-01', periods=500, freq='T')
    n = len(dates)
    df_ohlc = pd.DataFrame(index=dates)
    df_ohlc['close'] = 100 + np.cumsum(np.random.randn(n) * 0.1)
    df_ohlc['open'] = df_ohlc['close'].shift(1) * (1 + np.random.randn(n) * 0.001)
    df_ohlc['high'] = df_ohlc[['open', 'close']].max(axis=1) * (1 + np.random.rand(n) * 0.001)
    df_ohlc['low'] = df_ohlc[['open', 'close']].min(axis=1) * (1 - np.random.rand(n) * 0.001)
    df_ohlc['volume'] = np.random.rand(n) * 1000

    # IMPORTANT: Shift before calculating TA features
    df_shifted_ohlc = df_ohlc.shift(1)
    df_ta = features.add_ta_features(df_shifted_ohlc)

    # Align with the target (requires original non-shifted index)
    df_ta = df_ta.loc[y_target.index]

    ta_features_to_test = [col for col in features.minimal_whitelist if col in df_ta.columns and col not in ["return_1m", "return_15m", "return_60m", "hour_sin", "hour_cos"]]
    max_abs_corr = 0.0
    leaky_col = "None"
    all_corrs = {}

    print(f"\nTesting {len(ta_features_to_test)} TA features for leakage (threshold={leakage_threshold})...")
    print(f" Features: {ta_features_to_test}")

    for col in ta_features_to_test:
        if pd.api.types.is_numeric_dtype(df_ta[col]):
            temp_df = pd.concat([df_ta[col], y_target], axis=1).dropna()
            if len(temp_df) < 0.5 * len(y_target):
                print(f" Skipping {col} due to excessive NaNs after merging.")
                continue
            correlation = temp_df[col].corr(temp_df['y_ret'])
            all_corrs[col] = correlation
            if abs(correlation) > max_abs_corr:
                max_abs_corr = abs(correlation)
                leaky_col = col
        else:
            print(f" Skipping non-numeric TA column: {col}")

    print(f"TA Feature Correlations: { {k: round(v, 4) for k, v in all_corrs.items()} }")
    print(f"Maximum absolute TA correlation found: {max_abs_corr:.4f} (feature: {leaky_col})")

    assert max_abs_corr < leakage_threshold, \
        f"TA Feature '{leaky_col}' has correlation {max_abs_corr:.4f} > threshold {leakage_threshold}, suggesting leakage from TA calculation."


# test_label_timing is usually covered by the correlation test, so removed for brevity.
gru_sac_predictor/tests/test_metrics.py (new file, 136 lines)
@@ -0,0 +1,136 @@
"""
Tests for custom metric functions.

Ref: revisions.txt Task 6.5
"""

import pytest
import numpy as np
import pandas as pd
import sys, os

# --- Add path for src imports --- #
script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(script_dir)
src_path = os.path.join(project_root, 'src')
if src_path not in sys.path:
    sys.path.insert(0, src_path)
# --- End Add path --- #

from metrics import edge_filtered_accuracy, calculate_sharpe_ratio


# --- Tests for edge_filtered_accuracy --- #

def test_edge_filtered_accuracy_basic():
    """Test basic functionality with hard labels and clear edge."""
    y_true = np.array([1, 0, 1, 0, 1, 1, 0, 0])
    p_cal = np.array([0.9, 0.1, 0.8, 0.2, 0.7, 0.6, 0.3, 0.4])  # Edge > 0.1 for all
    thr = 0.1

    accuracy, n_filtered = edge_filtered_accuracy(y_true, p_cal, thr=thr)

    assert n_filtered == 8
    # Predictions: 1, 0, 1, 0, 1, 1, 0, 0. All correct.
    assert accuracy == pytest.approx(1.0)


def test_edge_filtered_accuracy_thresholding():
    """Test that the threshold correctly filters samples."""
    y_true = np.array([1, 0, 1, 0, 1, 1, 0, 0])
    p_cal = np.array([0.9, 0.1, 0.8, 0.2, 0.51, 0.49, 0.55, 0.45])  # Edge: 0.8, 0.8, 0.6, 0.6, 0.02, 0.02, 0.1, 0.1

    # Test with thr=0.15 (should exclude last 4 samples)
    thr1 = 0.15
    accuracy1, n_filtered1 = edge_filtered_accuracy(y_true, p_cal, thr=thr1)
    assert n_filtered1 == 4
    # Predictions on first 4: 1, 0, 1, 0. All correct.
    assert accuracy1 == pytest.approx(1.0)

    # Test with thr=0.05 (should include all but the two lowest-edge samples)
    thr2 = 0.05
    accuracy2, n_filtered2 = edge_filtered_accuracy(y_true, p_cal, thr=thr2)
    assert n_filtered2 == 6
    # Included indices 0-3, 6, 7. Preds: 1,0,1,0,1,0 vs true: 1,0,1,0,0,0 -> 5/6 correct
    # (p=0.55 predicts up while the true label is 0).
    assert accuracy2 == pytest.approx(5 / 6)


def test_edge_filtered_accuracy_soft_labels():
    """Test with soft labels."""
    y_true_soft = np.array([0.9, 0.1, 0.8, 0.2, 0.7, 0.6])  # Soft labels
    p_cal = np.array([0.8, 0.3, 0.9, 0.1, 0.6, 0.7])  # All edge > 0.1
    thr = 0.1

    accuracy, n_filtered = edge_filtered_accuracy(y_true_soft, p_cal, thr=thr)

    assert n_filtered == 6
    # y_true_hard: 1, 0, 1, 0, 1, 1
    # y_pred     : 1, 0, 1, 0, 1, 1. All correct.
    assert accuracy == pytest.approx(1.0)


def test_edge_filtered_accuracy_no_samples():
    """Test case where no samples meet the edge threshold."""
    y_true = np.array([1, 0, 1, 0])
    p_cal = np.array([0.51, 0.49, 0.52, 0.48])  # All edge < 0.1
    thr = 0.1

    accuracy, n_filtered = edge_filtered_accuracy(y_true, p_cal, thr=thr)
    assert n_filtered == 0
    assert np.isnan(accuracy)


def test_edge_filtered_accuracy_empty_input():
    """Test with empty input arrays."""
    y_true = np.array([])
    p_cal = np.array([])
    thr = 0.1

    accuracy, n_filtered = edge_filtered_accuracy(y_true, p_cal, thr=thr)
    assert n_filtered == 0
    assert np.isnan(accuracy)


# --- Tests for calculate_sharpe_ratio --- #
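
# --- Note on the assumed Sharpe formulation --- #
# The expected values in the tests below assume the common formulation
#   sharpe = mean(r_t - r_b) / std(r_t - r_b) * sqrt(annualization_factor)
# where r_b is the per-period benchmark return. The exact std convention (ddof)
# is whatever calculate_sharpe_ratio implements, so the hard-coded constants are
# tied to that implementation rather than to a hand calculation.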

def test_calculate_sharpe_ratio_basic():
    """Test basic Sharpe calculation."""
    returns = pd.Series([0.01, -0.005, 0.02, 0.005, -0.01])
    # mean = 0.004, std = 0.01166, Sharpe_period = 0.343
    # Annualized (252) = 0.343 * sqrt(252) = 5.44
    expected_sharpe = 5.44441
    sharpe = calculate_sharpe_ratio(returns, benchmark_return=0.0, annualization_factor=252)
    assert sharpe == pytest.approx(expected_sharpe, abs=1e-4)


def test_calculate_sharpe_ratio_different_annualization():
    """Test Sharpe with different annualization factor."""
    returns = pd.Series([0.01, -0.005, 0.02, 0.005, -0.01])
    # Annualized (52) = 0.343 * sqrt(52) = 2.47
    expected_sharpe = 2.4738
    sharpe = calculate_sharpe_ratio(returns, benchmark_return=0.0, annualization_factor=52)
    assert sharpe == pytest.approx(expected_sharpe, abs=1e-4)


def test_calculate_sharpe_ratio_with_benchmark():
    """Test Sharpe with a non-zero benchmark return."""
    returns = pd.Series([0.01, -0.005, 0.02, 0.005, -0.01])  # mean=0.004
    benchmark = 0.001  # Per period
    # excess mean = 0.003, std = 0.01166, Sharpe_period = 0.257
    # Annualized (252) = 0.257 * sqrt(252) = 4.08
    expected_sharpe = 4.0833
    sharpe = calculate_sharpe_ratio(returns, benchmark_return=benchmark, annualization_factor=252)
    assert sharpe == pytest.approx(expected_sharpe, abs=1e-4)


def test_calculate_sharpe_ratio_zero_std():
    """Test Sharpe when returns have zero standard deviation."""
    returns_positive = pd.Series([0.01, 0.01, 0.01])
    returns_negative = pd.Series([-0.01, -0.01, -0.01])
    returns_zero = pd.Series([0.0, 0.0, 0.0])

    # Zero-std handling per the function's logic:
    # returns 0 if mean > 0, -inf if mean < 0, and 0 if mean == 0.
    assert calculate_sharpe_ratio(returns_positive) == 0.0
    assert calculate_sharpe_ratio(returns_negative) == -np.inf
    assert calculate_sharpe_ratio(returns_zero) == 0.0


def test_calculate_sharpe_ratio_empty_or_nan():
    """Test Sharpe with empty or all-NaN input."""
    assert np.isnan(calculate_sharpe_ratio(pd.Series([], dtype=float)))
    assert np.isnan(calculate_sharpe_ratio(pd.Series([np.nan, np.nan], dtype=float)))
gru_sac_predictor/tests/test_model_shapes.py (new file, 139 lines)
@@ -0,0 +1,139 @@
"""
Tests for GRU model input/output shapes.

Ref: revisions.txt Task 3.6
"""
import pytest
import numpy as np
import sys, os

# --- Add path for src imports --- #
script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(script_dir)
src_path = os.path.join(project_root, 'src')
if src_path not in sys.path:
    sys.path.insert(0, src_path)
# --- End Add path --- #

# Import the v3 model builder
from model_gru_v3 import build_gru_model_v3
# TODO: Import v2 model builder if needed for comparison tests
# from model_gru import build_gru_model

# --- Constants for Testing --- #
LOOKBACK = 60
N_FEATURES = 25
BATCH_SIZE = 4


# --- Tests --- #

def test_gru_v3_output_shapes():
    """Verify the output shapes of the GRU v3 model heads."""
    print("\nBuilding GRU v3 model for shape test...")
    # Build the v3 model with default parameters
    model = build_gru_model_v3(lookback=LOOKBACK, n_features=N_FEATURES)
    assert model is not None, "Failed to build GRU v3 model"

    # Check number of outputs
    assert len(model.outputs) == 2, f"Expected 2 outputs, got {len(model.outputs)}"

    # Check output names and shapes
    # Output order in the model definition was [mu, dir3]
    mu_output_shape = model.outputs[0].shape.as_list()
    dir3_output_shape = model.outputs[1].shape.as_list()

    # Assert shapes (ignoring batch size None)
    # mu head should be (None, 1)
    assert mu_output_shape == [None, 1], f"Expected mu shape [None, 1], got {mu_output_shape}"
    # dir3 head should be (None, 3)
    assert dir3_output_shape == [None, 3], f"Expected dir3 shape [None, 3], got {dir3_output_shape}"

    print("GRU v3 output shapes test passed.")


def test_gru_v3_prediction_shapes():
    """Verify the prediction shapes match the output shapes for a sample batch."""
    model = build_gru_model_v3(lookback=LOOKBACK, n_features=N_FEATURES)
    assert model is not None, "Failed to build GRU v3 model"

    # Create dummy input data
    dummy_input = np.random.rand(BATCH_SIZE, LOOKBACK, N_FEATURES)

    # Generate predictions
    predictions = model.predict(dummy_input)

    # Check prediction structure and shapes
    assert isinstance(predictions, list), "Predictions should be a list for multi-output model"
    assert len(predictions) == 2, f"Expected 2 prediction arrays, got {len(predictions)}"

    # Predictions order should match model.outputs order [mu, dir3]
    mu_preds = predictions[0]
    dir3_preds = predictions[1]

    # Assert prediction shapes match expected batch size
    assert mu_preds.shape == (BATCH_SIZE, 1), f"Expected mu prediction shape ({BATCH_SIZE}, 1), got {mu_preds.shape}"
    assert dir3_preds.shape == (BATCH_SIZE, 3), f"Expected dir3 prediction shape ({BATCH_SIZE}, 3), got {dir3_preds.shape}"

    print("GRU v3 prediction shapes test passed.")


# TODO: Add tests for GRU v2 model shapes if it's still relevant.

def test_logits_view_shapes():
    """Test that softmax applied to predict_logits output matches predict output."""
    print("\nBuilding GRU v3 model for logits view test...")
    model = build_gru_model_v3(lookback=LOOKBACK, n_features=N_FEATURES)
    assert model is not None, "Failed to build GRU v3 model"

    # --- Requires GRUModelHandler to run predict_logits --- #
    # We need to instantiate the handler to test its methods.
    # Mock config and directories needed for handler init.
    mock_config = {
        'control': {'use_v3': True},
        'gru_v3': {}  # Use defaults for building
    }
    mock_run_id = "test_logits_run"
    mock_models_dir = "./mock_models/test_logits_run"
    os.makedirs(mock_models_dir, exist_ok=True)  # Create mock dir

    # Import handler locally for test setup
    from gru_model_handler import GRUModelHandler
    handler = GRUModelHandler(run_id=mock_run_id, models_dir=mock_models_dir, config=mock_config)
    handler.model = model  # Assign the already built model to the handler
    handler.model_version_used = 'v3'  # Set version manually
    # --- End Handler Setup --- #

    # Create dummy input data
    dummy_input = np.random.rand(BATCH_SIZE, LOOKBACK, N_FEATURES).astype(np.float32)

    # Generate predictions using both methods
    logits = handler.predict_logits(dummy_input)
    predictions = handler.predict(dummy_input)

    assert logits is not None, "predict_logits returned None"
    assert predictions is not None, "predict returned None"
    assert isinstance(predictions, list) and len(predictions) == 2, "predict output structure incorrect"

    probs_from_predict = predictions[1]  # dir3 is the second output

    # Apply softmax to logits
    # Use tf.nn.softmax for consistency with Keras backend
    import tensorflow as tf
    probs_from_logits = tf.nn.softmax(logits).numpy()

    # Assert shapes match first
    assert probs_from_logits.shape == probs_from_predict.shape, \
        f"Shape mismatch: softmax(logits)={probs_from_logits.shape}, predict_probs={probs_from_predict.shape}"

    # Assert values are close
    np.testing.assert_allclose(
        probs_from_logits,
        probs_from_predict,
        rtol=1e-6,
        atol=1e-6,  # Use tighter tolerance for numerical precision check
        err_msg="Softmax applied to logits does not match probability output from model.predict()"
    )

    print("Logits view test passed.")
    # Clean up mock directory
    import shutil
    if os.path.exists("./mock_models"):
        shutil.rmtree("./mock_models")
gru_sac_predictor/tests/test_sac_agent.py (new file, 110 lines)
@@ -0,0 +1,110 @@
"""
Tests for the SACTradingAgent class.

Ref: revisions.txt Task 5.7
"""
import pytest
import numpy as np
import tensorflow as tf
import sys, os

# --- Add path for src imports --- #
script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(script_dir)
src_path = os.path.join(project_root, 'src')
if src_path not in sys.path:
    sys.path.insert(0, src_path)
# --- End Add path --- #

from sac_agent import SACTradingAgent

# --- Constants --- #
STATE_DIM = 5
ACTION_DIM = 1
BUFFER_SIZE = 5000
MIN_BUFFER = 1000
TRAIN_STEPS = 1500  # Number of training steps for the test
BATCH_SIZE = 64

# --- Fixtures --- #

@pytest.fixture
def sac_agent_fixture() -> SACTradingAgent:
    """Provides a default SACTradingAgent instance for testing."""
    agent = SACTradingAgent(
        state_dim=STATE_DIM,
        action_dim=ACTION_DIM,
        buffer_capacity=BUFFER_SIZE,
        min_buffer_size=MIN_BUFFER,
        alpha_auto_tune=True,  # Enable auto-tuning for realistic test
        target_entropy=-1.0 * ACTION_DIM  # Default target entropy
    )
    return agent
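
# Context note: target_entropy = -1.0 * ACTION_DIM above follows the common SAC heuristic of
# setting the entropy target to -dim(A) when the temperature alpha is auto-tuned; for this
# single continuous action that is simply -1.0.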


def _populate_buffer(agent: SACTradingAgent, num_samples: int):
    """Helper to add random transitions to the agent's buffer."""
    print(f"\nPopulating buffer with {num_samples} random samples...")
    for _ in range(num_samples):
        state = np.random.randn(STATE_DIM).astype(np.float32)
        action = np.random.uniform(-1, 1, size=(ACTION_DIM,)).astype(np.float32)
        reward = np.random.randn()
        next_state = np.random.randn(STATE_DIM).astype(np.float32)
        done = float(np.random.rand() < 0.05)  # 5% chance of done
        agent.buffer.add(state, action, reward, next_state, done)
    print(f"Buffer populated. Size: {len(agent.buffer)}")


# --- Tests --- #

def test_sac_training_updates(sac_agent_fixture):
    """
    Test 5.7: Run training steps and check for basic health:
      a) Q-values are not NaN.
      b) Action variance is reasonable (suggests exploration).
    """
    agent = sac_agent_fixture
    # Populate buffer sufficiently to start training
    _populate_buffer(agent, MIN_BUFFER + BATCH_SIZE)

    print(f"\nRunning {TRAIN_STEPS} training steps...")
    metrics_history = []
    for i in range(TRAIN_STEPS):
        metrics = agent.train(batch_size=BATCH_SIZE)
        if metrics:  # Train only runs if buffer is full enough
            metrics_history.append(metrics)
        # Basic check within the loop to fail fast
        if i % 100 == 0 and metrics:
            assert not np.isnan(metrics['critic1_loss']), f"Critic1 loss is NaN at step {i}"
            assert not np.isnan(metrics['critic2_loss']), f"Critic2 loss is NaN at step {i}"
            assert not np.isnan(metrics['actor_loss']), f"Actor loss is NaN at step {i}"
            if agent.alpha_auto_tune:
                assert not np.isnan(metrics['alpha_loss']), f"Alpha loss is NaN at step {i}"

    assert len(metrics_history) > 0, "Training loop did not execute (buffer size issue?)"
    print(f"Training steps completed. Last metrics: {metrics_history[-1]}")

    # a) Check final Q-values (indirectly via loss)
    last_metrics = metrics_history[-1]
    assert not np.isnan(last_metrics['critic1_loss']), "Final Critic1 loss is NaN"
    assert not np.isnan(last_metrics['critic2_loss']), "Final Critic2 loss is NaN"
    # We assume if losses are not NaN, Q-values involved are also not NaN
    print("Check a) Passed: Q-value losses are not NaN.")

    # b) Check action variance after training
    num_samples_for_variance = 500
    sampled_actions = []
    dummy_state = np.random.randn(STATE_DIM).astype(np.float32)
    for _ in range(num_samples_for_variance):
        # Sample non-deterministically to check stochastic policy variance
        action = agent.get_action(dummy_state, deterministic=False)
        sampled_actions.append(action)

    sampled_actions = np.array(sampled_actions)
    action_variance = np.var(sampled_actions, axis=0)
    print(f"Action variance after {TRAIN_STEPS} steps: {action_variance}")

    # Check if variance is above a threshold (e.g., 0.2 from revisions.txt)
    # This threshold might need tuning based on action space scaling (-1 to 1)
    min_variance_threshold = 0.2
    assert np.all(action_variance > min_variance_threshold), \
        f"Action variance ({action_variance}) is below threshold ({min_variance_threshold}). Exploration might be too low."
    print(f"Check b) Passed: Action variance ({action_variance.round(3)}) > {min_variance_threshold}.")
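
# For scale: an action drawn uniformly from [-1, 1] has variance (1 - (-1))**2 / 12 = 1/3 ~ 0.33,
# while a collapsed, near-deterministic policy would show variance near 0. The 0.2 floor used
# above therefore checks that the stochastic policy is still exploring broadly after training;
# the exact cutoff comes from revisions.txt and may need retuning if the action scaling changes.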
121
gru_sac_predictor/tests/test_sac_sanity.py
Normal file
@ -0,0 +1,121 @@
"""
Sanity checks for the SAC agent (Sec 6 of revisions.txt).
"""
import pytest
import numpy as np
import os

# Try to import the agent; skip tests if not found
try:
    from gru_sac_predictor.src import sac_agent
    # Need TF for tensor conversion if testing agent directly
    import tensorflow as tf
except ImportError:
    sac_agent = None
    tf = None

# --- Fixtures ---
@pytest.fixture(scope="module")
def sac_agent_instance():
    """
    Provides a default SAC agent instance for testing.
    Uses standard parameters suitable for basic checks.
    """
    if sac_agent is None:
        pytest.skip("SAC Agent module not found.")
    # Use default params, state_dim=5 as per revisions
    # Use fixed seeds for reproducibility in tests if needed inside agent
    agent = sac_agent.SACTradingAgent(
        state_dim=5, action_dim=1,
        initial_lr=1e-4,  # Use a common LR for test simplicity
        buffer_capacity=1000,  # Smaller buffer for testing
        min_buffer_size=100,
        target_entropy=-1.0
    )
    # Build the models eagerly
    try:
        agent.actor(tf.zeros((1, 5)))
        agent.critic1([tf.zeros((1, 5)), tf.zeros((1, 1))])
        agent.critic2([tf.zeros((1, 5)), tf.zeros((1, 1))])
        # Copy weights to target networks
        agent.update_target_networks(tau=1.0)
    except Exception as e:
        pytest.fail(f"Failed to build SAC agent models: {e}")
    return agent


@pytest.fixture(scope="module")
def sample_sac_inputs():
    """
    Generate sample states and corresponding directional signals.
    Simulates states with varying edge and signal-to-noise.
    """
    np.random.seed(44)
    n_samples = 1500
    # Simulate GRU outputs and position
    mu = np.random.randn(n_samples) * 0.0015  # Slightly higher variance
    sigma = np.random.uniform(0.0005, 0.0025, n_samples)
    # Simulate edge with clearer separation for testing signals
    edge_base = np.random.choice([-0.15, -0.05, 0.0, 0.05, 0.15], n_samples, p=[0.2, 0.2, 0.2, 0.2, 0.2])
    edge = np.clip(edge_base + np.random.randn(n_samples) * 0.03, -1.0, 1.0)
    z_score = np.abs(mu) / (sigma + 1e-9)
    position = np.random.uniform(-1, 1, n_samples)
    states = np.vstack([mu, sigma, edge, z_score, position]).T.astype(np.float32)
    # Use a small positive/negative threshold for determining signal from edge
    signals = np.where(edge > 0.02, 1, np.where(edge < -0.02, -1, 0))
    return states, signals
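
# State layout produced by the fixture above (column order of the np.vstack call):
#   0 = mu, 1 = sigma, 2 = edge, 3 = z_score (|mu| / sigma), 4 = current position,
# matching the state_dim=5 used when constructing the agent.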


# --- Tests ---
@pytest.mark.skipif(sac_agent is None or tf is None, reason="SAC Agent module or TensorFlow not found")
def test_sac_agent_default_min_buffer(sac_agent_instance):
    """Verify the default min_buffer_size is at least 10000."""
    agent = sac_agent_instance
    # Note: Fixture currently initializes with specific values, overriding default.
    # Re-initialize with defaults for this test.
    default_agent = sac_agent.SACTradingAgent(state_dim=5, action_dim=1)
    min_buffer = default_agent.min_buffer_size
    print(f"\nAgent default min_buffer_size: {min_buffer}")
    assert min_buffer >= 10000, f"Default min_buffer_size ({min_buffer}) is less than recommended 10000."


@pytest.mark.skipif(sac_agent is None or tf is None, reason="SAC Agent module or TensorFlow not found")
def test_sac_action_variance(sac_agent_instance, sample_sac_inputs):
    """
    Verify that the mean absolute action taken when the signal is non-zero
    is >= 0.05.
    """
    agent = sac_agent_instance
    states, signals = sample_sac_inputs

    actions = []
    for state in states:
        # Use deterministic action for this sanity check
        action = agent.get_action(state, deterministic=True)
        actions.append(action[0])  # get_action returns list/array
    actions = np.array(actions)

    # Filter for non-zero signals based on the *simulated* edge
    non_zero_signal_idx = signals != 0
    if not np.any(non_zero_signal_idx):
        pytest.fail("No non-zero signals generated in fixture for SAC variance test.")

    actions_on_signal = actions[non_zero_signal_idx]

    if len(actions_on_signal) == 0:
        # This case should ideally not happen if the above check passed
        pytest.fail("Filtered actions array is empty despite non-zero signals.")

    mean_abs_action = np.mean(np.abs(actions_on_signal))

    print(f"\nSAC Sanity Test: Mean Absolute Action (on signal != 0): {mean_abs_action:.4f}")

    # Check if the agent is outputting actions with sufficient magnitude
    assert mean_abs_action >= 0.05, \
        f"Mean absolute action ({mean_abs_action:.4f}) is below threshold (0.05). Agent might be too timid or stuck near zero."
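
# Reading of the 0.05 floor above: it averages |action| over states whose simulated edge gives a
# non-zero signal. Assuming the action is interpreted as a target position in [-1, 1] (as the
# position feature in the state suggests), 0.05 corresponds to at least 5% of full size, so the
# check guards against a policy that collapses to near-zero output whenever there is an edge.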


@pytest.mark.skip(reason="Requires full backtest results which are not available in this unit test setup.")
def test_sac_reward_correlation():
    """
    Optional: Check if actions taken correlate positively with subsequent rewards.
    NOTE: This test requires results from a full backtest run (actions vs rewards)
    and cannot be reliably simulated or executed in this unit test.
    """
    pass  # Cannot implement without actual backtest results
94
gru_sac_predictor/tests/test_time_encoding.py
Normal file
@ -0,0 +1,94 @@
"""
Tests for time encoding, specifically DST transitions.
"""
import pytest
import pandas as pd
import numpy as np
import pytz  # For timezone handling


@pytest.fixture(scope="module")
def generate_dst_timeseries():
    """
    Generate a minute-frequency timestamp series crossing DST transitions
    for a specific timezone (e.g., US/Eastern).
    """
    # Example: US/Eastern DST Start (e.g., March 10, 2024 2:00 AM -> 3:00 AM)
    # Example: US/Eastern DST End (e.g., Nov 3, 2024 2:00 AM -> 1:00 AM)
    tz = pytz.timezone('US/Eastern')

    # Create timestamps around DST start
    dst_start_range = pd.date_range(
        start='2024-03-10 01:00:00', end='2024-03-10 04:00:00', freq='T', tz=tz
    )
    # Create timestamps around DST end
    dst_end_range = pd.date_range(
        start='2024-11-03 00:00:00', end='2024-11-03 03:00:00', freq='T', tz=tz
    )

    # Combine and ensure uniqueness/order (though disjoint here)
    timestamps = dst_start_range.union(dst_end_range)
    df = pd.DataFrame(index=timestamps)
    df.index.name = 'timestamp'
    return df


def calculate_cyclical_features(df):
    """Helper to calculate sin/cos features from a datetime index."""
    if not isinstance(df.index, pd.DatetimeIndex):
        raise TypeError("Input DataFrame must have a DatetimeIndex.")

    # Ensure timezone is present (fixture provides it)
    if df.index.tz is None:
        print("Warning: Index timezone is None, assuming UTC for calculation.")
        timestamp_source = df.index.tz_localize('utc')
    else:
        timestamp_source = df.index

    # Use UTC hour for consistent calculation if timezone handling upstream is complex
    # Or use localized hour if pipeline guarantees consistent local TZ
    # Here, let's use the localized hour provided by the fixture
    hour_of_day = timestamp_source.hour
    # minute_of_day = timestamp_source.hour * 60 + timestamp_source.minute  # Alternative

    df['hour_sin'] = np.sin(2 * np.pi * hour_of_day / 24)
    df['hour_cos'] = np.cos(2 * np.pi * hour_of_day / 24)
    return df
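
# Worked values for the encoding above: the (hour_sin, hour_cos) pair maps the 24-hour clock
# onto the unit circle, so hour 0 and hour 24 encode to the same point and there is no
# artificial discontinuity at midnight:
#   h = 0  -> ( 0.0,  1.0)
#   h = 6  -> ( 1.0,  0.0)
#   h = 12 -> ( 0.0, -1.0)
#   h = 18 -> (-1.0,  0.0)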


def test_cyclical_features_continuity(generate_dst_timeseries):
    """
    Check that hour_sin and hour_cos are continuous (no large jumps)
    across DST transitions, assuming the calculation uses localized time.
    If the UTC hour is used instead, continuity is guaranteed, but the
    intended local market patterns may not be captured.
    """
    df = generate_dst_timeseries
    df = calculate_cyclical_features(df)

    # Check differences between consecutive values
    sin_diff = df['hour_sin'].diff().abs()
    cos_diff = df['hour_cos'].diff().abs()

    # Define a reasonable threshold for a jump (e.g., difference > value for 15 mins)
    # Max change in sin(2*pi*h/24) over 1 minute is small.
    # A jump of 1 hour means h changes by 1, argument changes by pi/12.
    # Max diff sin(x+pi/12) - sin(x) is approx pi/12 ~ 0.26
    max_allowed_diff = 0.3  # Allow slightly more than 1 hour jump equivalent

    print(f"\nMax Sin Diff: {sin_diff.max():.4f}")
    print(f"Max Cos Diff: {cos_diff.max():.4f}")

    assert sin_diff.max() < max_allowed_diff, \
        f"Large jump detected in hour_sin ({sin_diff.max():.4f}) around DST. Check time source/calculation."
    assert cos_diff.max() < max_allowed_diff, \
        f"Large jump detected in hour_cos ({cos_diff.max():.4f}) around DST. Check time source/calculation."

    # Optional: Plot to visually inspect
    # import matplotlib.pyplot as plt
    # plt.figure()
    # plt.plot(df.index, df['hour_sin'], '.-.', label='sin')
    # plt.plot(df.index, df['hour_cos'], '.-.', label='cos')
    # plt.title('Cyclical Features Across DST')
    # plt.legend()
    # plt.xticks(rotation=45)
    # plt.tight_layout()
    # plt.show()
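
# Bound behind max_allowed_diff: sin(x + d) - sin(x) = 2 * cos(x + d/2) * sin(d/2), so a jump of d
# in the angle changes the feature by at most 2 * sin(d/2). A normal 1-hour step (d = pi/12) gives
# at most 2 * sin(pi/24) ~ 0.261, while the 2-hour skip at the spring-forward transition
# (01:59 -> 03:00 local, d = pi/6) allows up to 2 * sin(pi/12) ~ 0.518. The 0.3 threshold therefore
# passes ordinary hour boundaries but flags the DST skip when localized hours are used directly.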