# Source exported 2025-02-04 15:01:15 -05:00 — 244 lines, 9.3 KiB, Python
#!/usr/bin/env python3
"""
Market Window Analysis Script
This script performs rolling window analysis on market data using the market_predictor package.
Example Usage:
python main.py \
--symbol BTC-USD \
--start-date 2024-01-01 \
--end-date 2024-01-31 \
--interval 5m \
--training-window 60 \
--inference-window 12 \
--inference-offset 0
Arguments:
--symbol: Trading pair symbol (e.g., BTC-USD)
--start-date: Analysis start date (YYYY-MM-DD)
--end-date: Analysis end date (YYYY-MM-DD)
--interval: Data interval (1m, 5m, 15m, 1h, etc.)
--training-window: Number of intervals in training window
--inference-window: Number of intervals in inference window
--inference-offset: Offset between training and inference windows
--output: Optional output file path for predictions CSV
"""
import asyncio
import argparse
from datetime import datetime
import pandas as pd
from tqdm import tqdm
import nest_asyncio
import logging
from typing import Optional
from pathlib import Path
import matplotlib.pyplot as plt
from market_predictor.market_data_fetcher import MarketDataFetcher
from market_predictor.data_processor import MarketDataProcessor
from market_predictor.prediction_service import PredictionService
from market_predictor.performance_metrics import PerformanceMetrics
from market_predictor.analysis import PredictionAnalyzer
class MarketAnalyzer:
    """Rolling-window market analysis pipeline.

    Splits market data into chunks, runs the prediction service on each
    chunk concurrently, aggregates the per-chunk predictions, computes
    performance metrics, and writes predictions plus analysis plots into
    a timestamped per-run output directory.
    """

    # Columns every input DataFrame must provide (checked after upper-casing).
    REQUIRED_COLUMNS = {'CLOSE', 'VOLUME', 'HIGH', 'LOW', 'OPEN'}
    DEFAULT_OUTPUT_DIR = Path("output")

    def __init__(self, symbol: str, start_date: str, end_date: str):
        self.logger = logging.getLogger(__name__)
        self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        # Create descriptive directory name
        dir_name = f"{symbol}_{start_date}_{end_date}_{self.timestamp}"
        self.output_dir = self.DEFAULT_OUTPUT_DIR / dir_name
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def get_output_path(self, filename: str) -> Path:
        """Return the path for an output file inside this run's directory."""
        return self.output_dir / filename

    def validate_dataframe(self, df: pd.DataFrame) -> bool:
        """Validate DataFrame meets requirements (a non-empty DataFrame)."""
        # Delegate to the single validation helper so the two public
        # validators cannot drift apart.
        return self.is_valid_dataframe(df)

    def is_valid_dataframe(self, df: Optional[pd.DataFrame]) -> bool:
        """Explicit DataFrame validation: a real, non-empty DataFrame.

        BUG FIX: the isinstance check must run *before* touching
        ``df.empty`` — the original evaluated ``not df.empty`` first, so a
        non-DataFrame argument raised AttributeError instead of returning
        False.
        """
        return isinstance(df, pd.DataFrame) and not df.empty

    def prepare_market_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """Standardize and validate market data.

        Returns a copy with upper-cased column names.

        Raises:
            ValueError: if no data was provided or required columns are
                missing.
        """
        if not self.validate_dataframe(df):
            raise ValueError("No market data provided")
        # Standardize columns once
        df = df.copy()
        df.columns = [col.upper() for col in df.columns]
        # Validate required columns
        missing = self.REQUIRED_COLUMNS - set(df.columns)
        if missing:
            raise ValueError(f"Missing required columns: {missing}")
        return df

    async def process_chunk(self, chunk_data: pd.DataFrame, chunk_id: int,
                            training_window: int, inference_window: int,
                            inference_offset: int) -> Optional[pd.DataFrame]:
        """Process a single data chunk.

        Returns a DataFrame of predictions, or None when the chunk is
        invalid or processing fails (best-effort: one bad window must not
        abort the whole run).
        """
        try:
            if not self.validate_dataframe(chunk_data):
                self.logger.error(f"Invalid chunk data for chunk {chunk_id}")
                return None
            processor = MarketDataProcessor(
                df=chunk_data,
                training_window_size=training_window,
                inference_window_size=inference_window,
                inference_offset=inference_offset
            )
            service = PredictionService(
                market_data=processor.df,
                training_window_size=training_window,
                inference_window_size=inference_window,
                inference_offset=inference_offset
            )
            predictions = await service.main()
            return pd.DataFrame(predictions) if predictions else None
        except Exception as e:
            self.logger.error(f"Chunk {chunk_id} processing failed: {str(e)}")
            return None

    async def analyze(self, market_data: pd.DataFrame, **kwargs) -> Optional[pd.DataFrame]:
        """Main analysis pipeline.

        Required kwargs: training_window, inference_window,
        inference_offset. Optional: chunk_size (default 100).

        Returns the aggregated predictions DataFrame, or None on failure.
        """
        try:
            if not self.is_valid_dataframe(market_data):
                self.logger.error("Invalid market data provided")
                return None
            market_data = self.prepare_market_data(market_data)
            self.logger.info(f"Processing {len(market_data)} rows")
            chunk_size = kwargs.get('chunk_size', 100)
            # NOTE(review): chunks do not overlap, so rolling windows that
            # would span a chunk boundary are never evaluated — confirm this
            # is intended.
            chunks = []
            for chunk_id, i in enumerate(range(0, len(market_data), chunk_size)):
                chunk_data = market_data.iloc[i:i + chunk_size]
                if self.is_valid_dataframe(chunk_data):
                    chunks.append((chunk_data, chunk_id))
            tasks = [
                self.process_chunk(chunk_data, chunk_id,
                                   kwargs['training_window'],
                                   kwargs['inference_window'],
                                   kwargs['inference_offset'])
                for chunk_data, chunk_id in chunks
            ]
            results = []
            with tqdm(total=len(tasks), desc="Processing") as pbar:
                for task in asyncio.as_completed(tasks):
                    result = await task
                    if self.is_valid_dataframe(result):
                        results.append(result)
                    # Advance the bar for failed chunks too, so it completes.
                    pbar.update(1)
            if not results:
                return None
            predictions = pd.concat(results, ignore_index=True)
            metrics = PerformanceMetrics(predictions, market_data)
            predictions = metrics.get_predictions_df()
            if predictions is not None:
                # Save predictions
                predictions_file = self.get_output_path("predictions.csv")
                predictions.to_csv(predictions_file, index=False)
                # Generate analysis
                analyzer = PredictionAnalyzer(predictions_file)
                # Save plots with correct function references
                plots = [
                    ("accuracy_over_time.png", analyzer.plot_accuracy_over_time),
                    ("confusion_matrix.png", analyzer.plot_confusion_matrix),
                    ("returns_distribution.png", analyzer.plot_returns_distribution),
                    ("hourly_performance.png", analyzer.plot_hourly_performance),
                    ("vwap_changes_comparison.png", lambda: analyzer.plot_vwap_changes_comparison(window_size=20))
                ]
                for plot_name, plot_func in plots:
                    plt.figure()
                    plot_func()
                    plt.savefig(self.get_output_path(plot_name))
                    plt.close()
                self.logger.info(f"Analysis results saved to: {self.output_dir}")
            return predictions
        except Exception as e:
            self.logger.error(f"Analysis failed: {str(e)}")
            return None
def parse_args(argv=None):
    """Parse command line arguments.

    Args:
        argv: Optional list of argument strings. Defaults to ``sys.argv[1:]``
            (the ``argparse`` default). Added as a backward-compatible
            parameter so the CLI is testable without patching ``sys.argv``.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description="Market Window Analysis")
    parser.add_argument("--symbol", required=True, help="Trading pair symbol")
    parser.add_argument("--start-date", required=True, help="Start date (YYYY-MM-DD)")
    parser.add_argument("--end-date", required=True, help="End date (YYYY-MM-DD)")
    parser.add_argument("--interval", default="5m", help="Data interval")
    parser.add_argument("--training-window", type=int, default=60, help="Training window size")
    parser.add_argument("--inference-window", type=int, default=12, help="Inference window size")
    parser.add_argument("--inference-offset", type=int, default=0, help="Inference offset")
    parser.add_argument("--output", help="Output file path for predictions CSV")
    return parser.parse_args(argv)
async def main():
    """Script entry point: parse args, fetch market data, run the analysis.

    Any failure is logged with a traceback and re-raised so the process
    exits non-zero.
    """
    try:
        args = parse_args()
        analyzer = MarketAnalyzer(
            symbol=args.symbol,
            start_date=args.start_date,
            end_date=args.end_date,
        )
        # Fetch data for the requested symbol and date range.
        market_data = MarketDataFetcher(args.symbol).fetch_data(
            start_date=args.start_date,
            end_date=args.end_date,
            interval=args.interval,
        )
        # Run the rolling-window analysis pipeline.
        predictions = await analyzer.analyze(
            market_data,
            training_window=args.training_window,
            inference_window=args.inference_window,
            inference_offset=args.inference_offset,
        )
        if predictions is not None:
            print(f"Saved predictions and analysis results to: {analyzer.output_dir}")
    except Exception as e:
        logging.error(f"Analysis failed: {str(e)}", exc_info=True)
        raise
if __name__ == "__main__":
    # Configure root logging before any analysis output is produced.
    logging.basicConfig(level=logging.INFO)
    # nest_asyncio allows asyncio.run even when an event loop is already
    # running (e.g. when executed inside Jupyter/IPython).
    nest_asyncio.apply()
    asyncio.run(main())