market_predictor/market_predictor/prediction_service.py
2025-02-04 15:01:15 -05:00

133 lines
4.8 KiB
Python

import asyncio
import json
import logging
from datetime import datetime, date
from typing import Callable, Dict, List, Optional, Tuple

import pandas as pd
from tqdm import tqdm

from .data_processor import MarketDataProcessor
from .rag_engine import RAGEngine

# Optionally configure logging
# logging.basicConfig(level=logging.INFO)
class PredictionService:
    """Generate interval predictions over sliding market-data windows.

    For each window, the service rebuilds the RAG vector store from the
    window's training intervals, describes the current market state, and asks
    the LLM (with retrieved historical context) for a JSON prediction of the
    next interval.
    """

    def __init__(
        self,
        market_data: pd.DataFrame,
        training_window_size: int = 78,
        inference_window_size: int = 12,
        inference_offset: int = 0,
        max_concurrent: int = 3,
    ):
        """Initialize the service.

        Args:
            market_data: OHLC-style market data; columns are upper-cased in a
                private copy so the caller's frame is untouched.
            training_window_size: Number of intervals in each training window.
            inference_window_size: Number of intervals in each inference window.
            inference_offset: Offset applied when building inference windows.
            max_concurrent: Maximum number of in-flight LLM calls.
        """
        self.market_data = market_data.copy()
        # Normalize column names so downstream lookups are case-insensitive.
        self.market_data.columns = [col.upper() for col in self.market_data.columns]
        self.processor = MarketDataProcessor(
            df=self.market_data,
            training_window_size=training_window_size,
            inference_window_size=inference_window_size,
            inference_offset=inference_offset,
        )
        self.engine = RAGEngine()
        # Bounds concurrent LLM requests if predictions are ever awaited in parallel.
        self.semaphore = asyncio.Semaphore(max_concurrent)
        self.training_window_size = training_window_size
        self.inference_window_size = inference_window_size
        self.inference_offset = inference_offset

    async def main(self) -> List[Dict]:
        """Coordinate the prediction process using sliding windows.

        Returns:
            List of prediction dicts, one per window that succeeded. Windows
            whose prediction failed (see ``_predict_next_interval``) are
            skipped so callers never receive ``None`` entries.
        """
        windows = self.processor.create_prediction_windows()
        predictions: List[Dict] = []
        with tqdm(
            total=len(windows),
            desc="Processing predictions",
            leave=False,
        ) as pbar:
            for window in windows:
                # Rebuild the vector store from this window's training intervals.
                self.engine.create_vectorstore(
                    self.processor.combine_intervals(window["training"], is_training=True)
                )
                # Describe the current market state for the prompt.
                current_state = self.processor.combine_intervals(window["current"])
                prediction = await self._predict_next_interval(
                    current_state, window["target"].name
                )
                # Drop failed predictions (None) instead of polluting the result
                # list — a None row would break pd.DataFrame(predictions) downstream.
                if prediction is not None:
                    predictions.append(prediction)
                pbar.update(1)
        return predictions

    async def _predict_next_interval(
        self, current_state: str, target_timestamp: pd.Timestamp
    ) -> Optional[Dict]:
        """Predict the next interval using the LLM with RAG context.

        Args:
            current_state: Combined description of current market intervals.
            target_timestamp: Timestamp of the interval to predict.

        Returns:
            Dict containing the parsed prediction plus a
            ``timestamp_prediction`` key, or ``None`` if any step failed
            (the error is logged, never raised to the caller).
        """
        try:
            async with self.semaphore:  # limit concurrent LLM calls
                historical_context = self.engine.get_relevant_context(current_state)
                messages = [
                    {"role": "system", "content": self.engine.system_prompt},
                    {
                        "role": "user",
                        "content": f"Historical context:\n{historical_context}\n\nCurrent market state:\n{current_state}",
                    },
                ]
                response = await self.engine.llm.ainvoke(messages)
                raw_content = response.content.strip()
                if not raw_content:
                    raise ValueError("Empty response from LLM")
                # Trim any prose surrounding the JSON object. The previous
                # split-based cleanup raised an opaque IndexError when the
                # response contained no "{" at all; fail with a clear message
                # instead.
                start = raw_content.find("{")
                end = raw_content.rfind("}")
                if start == -1 or end < start:
                    raise ValueError(f"No JSON object in LLM response: {raw_content!r}")
                prediction = json.loads(raw_content[start : end + 1])
                prediction["timestamp_prediction"] = target_timestamp
                return prediction
        except Exception as e:
            logging.error(f"Prediction error for {target_timestamp}: {str(e)}")
            return None
async def batch_predictions(
    market_data: pd.DataFrame,
    training_window_size: int = 78,  # 1 trading day of 5-minute intervals
    inference_window_size: int = 12,  # 1 hour
    inference_offset: int = 0,
) -> pd.DataFrame:
    """Run the full sliding-window prediction pipeline over *market_data*.

    Args:
        market_data: Market data frame handed to :class:`PredictionService`.
        training_window_size: Intervals per training window.
        inference_window_size: Intervals per inference window.
        inference_offset: Offset applied when building inference windows.

    Returns:
        DataFrame with one row per successful prediction.
    """
    service = PredictionService(
        market_data,
        training_window_size=training_window_size,
        inference_window_size=inference_window_size,
        inference_offset=inference_offset,
    )
    # Bug fix: main() takes no arguments — the previous call passed three
    # positional args and raised TypeError at runtime. The window sizes were
    # already given to the service constructor above.
    predictions = await service.main()
    return pd.DataFrame(predictions)