244 lines
9.3 KiB
Python
244 lines
9.3 KiB
Python
#!/usr/bin/env python3
"""
Market Window Analysis Script

This script performs rolling window analysis on market data using the
market_predictor package.

Example Usage:
    python main.py \
        --symbol BTC-USD \
        --start-date 2024-01-01 \
        --end-date 2024-01-31 \
        --interval 5m \
        --training-window 60 \
        --inference-window 12 \
        --inference-offset 0

Arguments:
    --symbol: Trading pair symbol (e.g., BTC-USD)
    --start-date: Analysis start date (YYYY-MM-DD)
    --end-date: Analysis end date (YYYY-MM-DD)
    --interval: Data interval (1m, 5m, 15m, 1h, etc.)
    --training-window: Number of intervals in training window
    --inference-window: Number of intervals in inference window
    --inference-offset: Offset between training and inference windows
    --output: Optional output file path for predictions CSV
"""
|
|
|
|
import asyncio
|
|
import argparse
|
|
from datetime import datetime
|
|
import pandas as pd
|
|
from tqdm import tqdm
|
|
import nest_asyncio
|
|
import logging
|
|
from typing import Optional
|
|
from pathlib import Path
|
|
import matplotlib.pyplot as plt
|
|
|
|
from market_predictor.market_data_fetcher import MarketDataFetcher
|
|
from market_predictor.data_processor import MarketDataProcessor
|
|
from market_predictor.prediction_service import PredictionService
|
|
from market_predictor.performance_metrics import PerformanceMetrics
|
|
from market_predictor.analysis import PredictionAnalyzer
|
|
|
|
class MarketAnalyzer:
    """Rolling-window market analysis pipeline.

    Splits market data into fixed-size chunks, runs the prediction
    service on each chunk concurrently, aggregates the results, and
    saves the combined predictions plus analysis plots to a timestamped
    directory under ``output/``.
    """

    # Upper-cased column names that prepare_market_data() requires.
    REQUIRED_COLUMNS = {'CLOSE', 'VOLUME', 'HIGH', 'LOW', 'OPEN'}
    DEFAULT_OUTPUT_DIR = Path("output")

    def __init__(self, symbol: str, start_date: str, end_date: str):
        """Set up logging and create this run's output directory.

        Args:
            symbol: Trading pair symbol (e.g. BTC-USD).
            start_date: Analysis start date (YYYY-MM-DD).
            end_date: Analysis end date (YYYY-MM-DD).
        """
        self.logger = logging.getLogger(__name__)
        self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        # Descriptive directory name; the timestamp keeps repeated runs
        # over the same symbol/date range from colliding.
        dir_name = f"{symbol}_{start_date}_{end_date}_{self.timestamp}"
        self.output_dir = self.DEFAULT_OUTPUT_DIR / dir_name
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def get_output_path(self, filename: str) -> Path:
        """Return the path of *filename* inside this run's output directory."""
        return self.output_dir / filename

    def validate_dataframe(self, df: Optional[pd.DataFrame]) -> bool:
        """Return True when *df* is present and non-empty.

        Assumes *df* is a DataFrame (or None); see is_valid_dataframe()
        for the type-safe variant usable on arbitrary objects.
        """
        return df is not None and not df.empty

    def prepare_market_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """Standardize and validate market data.

        Upper-cases every column name and verifies that all columns in
        REQUIRED_COLUMNS are present.

        Returns:
            A normalized copy of *df* (the caller's frame is not mutated).

        Raises:
            ValueError: If *df* is missing/empty or lacks required columns.
        """
        if not self.validate_dataframe(df):
            raise ValueError("No market data provided")

        # Standardize columns once, on a copy so the input is untouched.
        df = df.copy()
        df.columns = [col.upper() for col in df.columns]

        missing = self.REQUIRED_COLUMNS - set(df.columns)
        if missing:
            raise ValueError(f"Missing required columns: {missing}")

        return df

    async def process_chunk(self, chunk_data: pd.DataFrame, chunk_id: int,
                            training_window: int, inference_window: int,
                            inference_offset: int) -> Optional[pd.DataFrame]:
        """Run the prediction service on a single data chunk.

        Returns:
            The chunk's predictions as a DataFrame, or None when the
            chunk is invalid, produced no predictions, or processing
            failed. Failures are logged (with traceback) rather than
            raised so other chunks keep running.
        """
        try:
            if not self.validate_dataframe(chunk_data):
                self.logger.error("Invalid chunk data for chunk %s", chunk_id)
                return None

            processor = MarketDataProcessor(
                df=chunk_data,
                training_window_size=training_window,
                inference_window_size=inference_window,
                inference_offset=inference_offset
            )

            service = PredictionService(
                market_data=processor.df,
                training_window_size=training_window,
                inference_window_size=inference_window,
                inference_offset=inference_offset
            )

            predictions = await service.main()
            return pd.DataFrame(predictions) if predictions else None

        except Exception:
            # logger.exception keeps the traceback; the failed chunk is
            # dropped and the pipeline continues with the others.
            self.logger.exception("Chunk %s processing failed", chunk_id)
            return None

    def is_valid_dataframe(self, df: Optional[pd.DataFrame]) -> bool:
        """Explicit, type-safe DataFrame validation.

        Checks the type BEFORE touching ``.empty`` so that passing a
        non-DataFrame object (e.g. a list) returns False instead of
        raising AttributeError, which the previous attribute-first
        ordering did. isinstance() also covers the None case.
        """
        return isinstance(df, pd.DataFrame) and not df.empty

    async def analyze(self, market_data: pd.DataFrame, **kwargs) -> Optional[pd.DataFrame]:
        """Main analysis pipeline.

        Keyword Args:
            training_window (int): Required. Training window size.
            inference_window (int): Required. Inference window size.
            inference_offset (int): Required. Offset between windows.
            chunk_size (int): Rows per concurrently processed chunk
                (default 100).

        Returns:
            The combined predictions DataFrame, or None when the input
            is invalid, no chunk produced results, or processing failed.
        """
        try:
            if not self.is_valid_dataframe(market_data):
                self.logger.error("Invalid market data provided")
                return None

            market_data = self.prepare_market_data(market_data)
            self.logger.info("Processing %d rows", len(market_data))

            # Slice the data into fixed-size, non-overlapping chunks.
            # iloc clamps slices past the end, so no min() is needed.
            chunk_size = kwargs.get('chunk_size', 100)
            chunks = []
            for start in range(0, len(market_data), chunk_size):
                chunk_data = market_data.iloc[start:start + chunk_size]
                if self.is_valid_dataframe(chunk_data):
                    chunks.append((chunk_data, start // chunk_size))

            tasks = [
                self.process_chunk(chunk_data, chunk_id,
                                   kwargs['training_window'],
                                   kwargs['inference_window'],
                                   kwargs['inference_offset'])
                for chunk_data, chunk_id in chunks
            ]

            # Collect chunk results as they finish, with a progress bar.
            results = []
            with tqdm(total=len(tasks), desc="Processing") as pbar:
                for task in asyncio.as_completed(tasks):
                    result = await task
                    if self.is_valid_dataframe(result):
                        results.append(result)
                    pbar.update(1)

            if not results:
                return None

            predictions = pd.concat(results, ignore_index=True)
            metrics = PerformanceMetrics(predictions, market_data)
            predictions = metrics.get_predictions_df()

            if predictions is not None:
                # Persist predictions BEFORE plotting so the data
                # survives any plotting failure.
                predictions_file = self.get_output_path("predictions.csv")
                predictions.to_csv(predictions_file, index=False)

                # Generate analysis plots from the saved CSV.
                analyzer = PredictionAnalyzer(predictions_file)

                plots = [
                    ("accuracy_over_time.png", analyzer.plot_accuracy_over_time),
                    ("confusion_matrix.png", analyzer.plot_confusion_matrix),
                    ("returns_distribution.png", analyzer.plot_returns_distribution),
                    ("hourly_performance.png", analyzer.plot_hourly_performance),
                    ("vwap_changes_comparison.png",
                     lambda: analyzer.plot_vwap_changes_comparison(window_size=20)),
                ]

                for plot_name, plot_func in plots:
                    # One failed plot must not abort the whole run; log
                    # it and continue with the remaining plots.
                    try:
                        plt.figure()
                        plot_func()
                        plt.savefig(self.get_output_path(plot_name))
                    except Exception:
                        self.logger.exception("Failed to generate %s", plot_name)
                    finally:
                        plt.close()

            self.logger.info("Analysis results saved to: %s", self.output_dir)
            return predictions

        except Exception:
            self.logger.exception("Analysis failed")
            return None
|
|
|
|
def parse_args():
    """Parse and return the command line arguments for an analysis run."""
    parser = argparse.ArgumentParser(description="Market Window Analysis")

    # Mandatory identification of what to analyze and over which period.
    for flag, desc in (
        ("--symbol", "Trading pair symbol"),
        ("--start-date", "Start date (YYYY-MM-DD)"),
        ("--end-date", "End date (YYYY-MM-DD)"),
    ):
        parser.add_argument(flag, required=True, help=desc)

    # Optional tuning knobs with sensible defaults.
    parser.add_argument("--interval", default="5m", help="Data interval")
    parser.add_argument("--training-window", type=int, default=60,
                        help="Training window size")
    parser.add_argument("--inference-window", type=int, default=12,
                        help="Inference window size")
    parser.add_argument("--inference-offset", type=int, default=0,
                        help="Inference offset")
    parser.add_argument("--output", help="Output file path for predictions CSV")

    return parser.parse_args()
|
|
|
|
async def main():
    """Fetch market data for the requested symbol and run the analysis."""
    try:
        args = parse_args()

        analyzer = MarketAnalyzer(
            symbol=args.symbol,
            start_date=args.start_date,
            end_date=args.end_date,
        )

        # Download the raw market data for the requested period.
        market_data = MarketDataFetcher(args.symbol).fetch_data(
            start_date=args.start_date,
            end_date=args.end_date,
            interval=args.interval,
        )

        # Run the rolling-window analysis over the fetched data.
        predictions = await analyzer.analyze(
            market_data,
            training_window=args.training_window,
            inference_window=args.inference_window,
            inference_offset=args.inference_offset,
        )

        if predictions is not None:
            print(f"Saved predictions and analysis results to: {analyzer.output_dir}")

    except Exception as e:
        logging.error(f"Analysis failed: {str(e)}", exc_info=True)
        raise
|
|
|
|
if __name__ == "__main__":
    # Configure root logging before any analysis output is produced.
    logging.basicConfig(level=logging.INFO)
    # nest_asyncio lets asyncio.run() work even when an event loop is
    # already running (e.g. inside a Jupyter notebook).
    nest_asyncio.apply()
    asyncio.run(main())