"""
Market Prediction Analysis Module
=================================

This module provides tools for analyzing market prediction performance
metrics and visualizations.

Features
--------
- Prediction accuracy analysis
- VWAP and price comparison
- Performance metrics calculation
- Time-series visualizations
- Error analysis

Example Usage
-------------
```python
analyzer = PredictionAnalyzer("predictions.csv")
analyzer.plot_accuracy_over_time()
analyzer.plot_hourly_performance()
analyzer.plot_returns_distribution()
```
"""

import argparse
from datetime import datetime
from pathlib import Path
from typing import Union

import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix


class PredictionAnalyzer:
    """
    Analyze market prediction performance and produce diagnostic plots.

    This class provides methods for analyzing prediction accuracy,
    comparing predicted vs actual values, and generating performance
    visualizations for both VWAP and price predictions. Every ``plot_*``
    method draws into pyplot's current figure and returns ``None``; the
    caller is responsible for saving/showing/closing it.

    Attributes
    ----------
    df : pd.DataFrame
        Predictions data loaded from CSV. ``timestamp_prediction`` is
        parsed to datetimes and kept as a regular column (not the index).
        Derived columns ``cumulative_return`` and ``rolling_accuracy`` are
        added on construction; ``vwap_rmse`` and ``price_rmse`` are added
        by ``plot_hourly_performance``.
    ROLLING_WINDOW : int
        Number of rows used by the rolling statistics (20).

    Methods
    -------
    plot_accuracy_over_time()
        Plots prediction accuracy trends
    plot_hourly_performance()
        Plots hourly performance metrics
    plot_returns_distribution()
        Plots return distribution analysis
    plot_confusion_matrix()
        Plots prediction confusion matrix
    plot_vwap_changes_comparison(window_size=20)
        Plots rolling predicted vs actual VWAP changes
    generate_report()
        Prints and returns a text performance summary
    """

    # Window size in rows (not wall-clock time) shared by the rolling statistics.
    ROLLING_WINDOW = 20

    def __init__(self, predictions_file: Union[str, Path]) -> None:
        """
        Initialize analyzer with predictions data.

        Parameters
        ----------
        predictions_file : str or Path
            Path to CSV file containing predictions data. It must contain at
            least ``timestamp_prediction``, ``actual_return`` and
            ``prediction_correct``; the plotting and reporting methods read
            further columns.
        """
        self.df = pd.read_csv(predictions_file)
        self.df['timestamp_prediction'] = pd.to_datetime(self.df['timestamp_prediction'])
        self._calculate_metrics()

    def _calculate_metrics(self) -> None:
        """
        Add derived metric columns to ``self.df`` in place.

        Computes:
        - ``cumulative_return``: running sum of ``actual_return``
        - ``rolling_accuracy``: rolling mean of ``prediction_correct`` over
          ``ROLLING_WINDOW`` rows. ``min_periods=1`` means the early rows
          average whatever rows are available so far.

        Rolling RMSE columns are not computed here; see
        ``plot_hourly_performance``.
        """
        # NOTE(review): cumsum/rolling assume rows are in chronological order;
        # the CSV is not sorted on load -- confirm the producer writes it sorted.
        self.df['cumulative_return'] = self.df['actual_return'].cumsum()
        self.df['rolling_accuracy'] = (
            self.df['prediction_correct']
            .rolling(self.ROLLING_WINDOW, min_periods=1)
            .mean()
        )

    @staticmethod
    def _rolling_rmse(actual: pd.Series, predicted: pd.Series, window: int) -> pd.Series:
        """
        Return the rolling root-mean-square error between two aligned series.

        Squared errors are averaged over ``window`` rows *before* the square
        root is taken. Taking the root first would instead give the rolling
        mean absolute error. The first ``window - 1`` values are NaN.
        """
        squared_error = (actual - predicted) ** 2
        return np.sqrt(squared_error.rolling(window=window).mean())

    def plot_accuracy_over_time(self) -> None:
        """
        Plot prediction accuracy trends over time.

        Generates a two-panel figure:
        - Direction accuracy: rolling mean of ``prediction_correct`` over
          ``ROLLING_WINDOW`` rows. No ``min_periods`` is set, so the first
          ``ROLLING_WINDOW - 1`` points are NaN and are not drawn.
        - VWAP change comparison: raw ``actual_vwap_change`` vs
          ``expected_vwap_change``.
        """
        _, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 10))
        timestamps = self.df['timestamp_prediction']

        # Direction accuracy
        ax1.plot(timestamps,
                 self.df['prediction_correct'].rolling(self.ROLLING_WINDOW).mean(),
                 label='Direction Accuracy', color='blue')
        ax1.set_title('Prediction Accuracy Over Time')
        ax1.set_ylabel('Direction Accuracy')
        ax1.grid(True, alpha=0.3)
        ax1.legend()

        # Magnitude accuracy
        ax2.plot(timestamps, self.df['actual_vwap_change'],
                 label='Actual VWAP Change', alpha=0.6)
        ax2.plot(timestamps, self.df['expected_vwap_change'],
                 label='Predicted VWAP Change', alpha=0.6)
        ax2.set_title('VWAP Change Prediction vs Actual')
        ax2.set_ylabel('VWAP Change %')
        ax2.grid(True, alpha=0.3)
        ax2.legend()

        plt.tight_layout()

    def plot_confusion_matrix(self) -> None:
        """
        Plot the direction confusion matrix next to the VWAP-change error histogram.

        Left panel: confusion matrix with ``actual_movement`` as true labels
        (rows) and ``vwap_direction_next_5min`` as predicted labels (columns).
        Right panel: histogram of ``actual_vwap_change - expected_vwap_change``.
        """
        _, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

        # Direction confusion matrix; sklearn orders classes by sorted label value.
        # NOTE(review): the 'down'/'up' tick labels assume the two sorted class
        # labels mean "down" then "up" -- confirm against the CSV's encoding.
        cm_direction = confusion_matrix(
            self.df['actual_movement'],
            self.df['vwap_direction_next_5min']
        )
        sns.heatmap(cm_direction, annot=True, fmt='d', cmap='Blues',
                    xticklabels=['down', 'up'],
                    yticklabels=['down', 'up'],
                    ax=ax1)
        ax1.set_title('Direction Prediction Matrix')
        ax1.set_xlabel('Predicted')
        ax1.set_ylabel('Actual')

        # Magnitude error distribution
        magnitude_error = self.df['actual_vwap_change'] - self.df['expected_vwap_change']
        sns.histplot(magnitude_error, bins=50, ax=ax2)
        ax2.set_title('VWAP Change Prediction Error')
        ax2.set_xlabel('Error (Actual - Predicted)')

        plt.tight_layout()

    def plot_returns_distribution(self) -> None:
        """
        Plot return and change distributions split by direction correctness.

        Left panel: histogram of ``actual_return`` coloured by
        ``prediction_correct``. Right panel: scatter of
        ``actual_vwap_change`` vs ``actual_price_change``, coloured the same way.
        """
        _, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

        # Returns by prediction accuracy
        sns.histplot(data=self.df, x='actual_return',
                     hue='prediction_correct', bins=50, ax=ax1)
        ax1.set_title('Return Distribution by Direction Accuracy')

        # VWAP vs Price changes
        sns.scatterplot(data=self.df,
                        x='actual_vwap_change',
                        y='actual_price_change',
                        hue='prediction_correct',
                        alpha=0.6,
                        ax=ax2)
        ax2.set_title('VWAP vs Price Changes')
        ax2.set_xlabel('VWAP Change')
        ax2.set_ylabel('Price Change')

        plt.tight_layout()

    def plot_hourly_performance(self) -> None:
        """
        Plot hourly direction accuracy, VWAP/price comparisons and rolling RMSE.

        Side effect: stores the per-row rolling RMSE in ``self.df`` as
        ``vwap_rmse`` and ``price_rmse``. Rows are then bucketed into
        1-hour bins on ``timestamp_prediction``, and each plotted column is
        averaged per bin. Empty hours yield NaN and show as gaps.
        """
        window = self.ROLLING_WINDOW
        self.df['vwap_rmse'] = self._rolling_rmse(
            self.df['actual_next_vwap'], self.df['predicted_vwap'], window
        )
        self.df['price_rmse'] = self._rolling_rmse(
            self.df['actual_next_price'], self.df['predicted_price'], window
        )

        # Selecting the columns before .mean() keeps the result's columns flat
        # (a dict-of-lists .agg would produce a MultiIndex).
        hourly_columns = [
            'prediction_correct',
            'predicted_vwap',
            'actual_next_vwap',
            'predicted_price',
            'actual_next_price',
            'vwap_rmse',
            'price_rmse',
        ]
        # 'h' is the hourly alias; 'H' is deprecated since pandas 2.2.
        hourly_metrics = (
            self.df.groupby(pd.Grouper(key='timestamp_prediction', freq='h'))[hourly_columns]
            .mean()
            .reset_index()
        )
        hours = hourly_metrics['timestamp_prediction']

        _, (ax1, ax2, ax3, ax4) = plt.subplots(4, 1, figsize=(15, 20))

        # Plot 1: Direction Accuracy
        ax1.plot(hours, hourly_metrics['prediction_correct'],
                 marker='o', linestyle='-')
        ax1.set_title('Direction Prediction Accuracy')
        ax1.set_ylabel('Accuracy')
        ax1.grid(True, alpha=0.3)

        # Plot 2: VWAP Comparison
        ax2.plot(hours, hourly_metrics['predicted_vwap'],
                 label='Predicted VWAP', color='blue')
        ax2.plot(hours, hourly_metrics['actual_next_vwap'],
                 label='Actual VWAP', color='red')
        ax2.set_title('Predicted vs Actual VWAP')
        ax2.set_ylabel('VWAP Value')
        ax2.grid(True, alpha=0.3)
        ax2.legend()

        # Plot 3: Price Comparison
        ax3.plot(hours, hourly_metrics['predicted_price'],
                 label='Predicted Price', color='green')
        ax3.plot(hours, hourly_metrics['actual_next_price'],
                 label='Actual Price', color='orange')
        ax3.set_title('Predicted vs Actual Price')
        ax3.set_ylabel('Price Value')
        ax3.grid(True, alpha=0.3)
        ax3.legend()

        # Plot 4: Rolling RMSE Comparison (hourly mean of the per-row rolling RMSE)
        ax4.plot(hours, hourly_metrics['vwap_rmse'],
                 label='VWAP RMSE', color='purple')
        ax4.plot(hours, hourly_metrics['price_rmse'],
                 label='Price RMSE', color='brown')
        ax4.set_title(f'Prediction RMSE ({window}-period rolling)')
        ax4.set_ylabel('RMSE')
        ax4.grid(True, alpha=0.3)
        ax4.legend()

        # Format datetime x-axis
        for ax in (ax1, ax2, ax3, ax4):
            ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d %H:%M'))
            ax.xaxis.set_major_locator(mdates.HourLocator(interval=4))
            plt.setp(ax.xaxis.get_majorticklabels(), rotation=45, ha='right')

        plt.tight_layout()

    def plot_vwap_changes_comparison(self, window_size: int = 20) -> None:
        """
        Plot rolling averages of predicted vs actual VWAP changes.

        Parameters
        ----------
        window_size : int
            Number of rows in the rolling mean. The first ``window_size - 1``
            points are NaN and are not drawn.
        """
        plt.figure(figsize=(15, 6))

        # Calculate rolling averages
        actual_rolling = self.df['actual_vwap_change'].rolling(window=window_size).mean()
        predicted_rolling = self.df['expected_vwap_change'].rolling(window=window_size).mean()

        # Plot both lines
        plt.plot(self.df['timestamp_prediction'], actual_rolling,
                 label='Actual VWAP Change', color='blue', alpha=0.7)
        plt.plot(self.df['timestamp_prediction'], predicted_rolling,
                 label='Predicted VWAP Change', color='red', alpha=0.7)

        plt.title(f'Predicted vs Actual VWAP Changes ({window_size}-period rolling average)')
        plt.xlabel('Time')
        plt.ylabel('VWAP Change %')
        plt.grid(True, alpha=0.3)
        plt.legend()
        plt.xticks(rotation=45)
        plt.tight_layout()

    def generate_report(self) -> str:
        """
        Build, print and return a plain-text performance summary.

        Includes the prediction count, overall direction accuracy, the mean
        and summed ``actual_return``, and an sklearn classification report
        of ``actual_movement`` (true) vs ``vwap_direction_next_5min``
        (predicted).
        """
        report = (
            f"\nPerformance Metrics:\n"
            f"Total Predictions: {len(self.df)}\n"
            f"Overall Accuracy: {self.df['prediction_correct'].mean():.2%}\n"
            f"Mean Return: {self.df['actual_return'].mean():.4f}\n"
            f"Cumulative Return: {self.df['actual_return'].sum():.4f}\n\n"
            f"Classification Report:\n"
            f"{classification_report(self.df['actual_movement'], self.df['vwap_direction_next_5min'])}"
        )
        print(report)
        return report


def main():
    """
    Command-line entry point.

    Loads a predictions CSV, prints the performance report and saves every
    analysis plot as a PNG into ``--output-dir``, creating the directory and
    any missing parents.
    """
    parser = argparse.ArgumentParser(
        description='Analyze market prediction performance and save plots.'
    )
    parser.add_argument('predictions_file', help='Path to predictions CSV')
    parser.add_argument('--output-dir', default='analysis_output',
                        help='Directory for output plots')
    args = parser.parse_args()

    # parents=True so nested output paths (e.g. runs/2024/plots) also work.
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    analyzer = PredictionAnalyzer(args.predictions_file)
    analyzer.generate_report()

    # Each plot method draws into pyplot's current figure. Save it and close it
    # immediately so figures do not accumulate. Nothing is left open afterwards,
    # so no final plt.show() is needed.
    plots = [
        (analyzer.plot_accuracy_over_time, 'accuracy_over_time.png'),
        (analyzer.plot_confusion_matrix, 'confusion_matrix.png'),
        (analyzer.plot_returns_distribution, 'returns_distribution.png'),
        (analyzer.plot_hourly_performance, 'hourly_performance.png'),
        (lambda: analyzer.plot_vwap_changes_comparison(window_size=20),
         'vwap_changes_comparison.png'),
    ]
    for draw, filename in plots:
        draw()
        plt.savefig(output_dir / filename)
        plt.close()


if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()