"""Sliding-window market prediction service.

Builds a RAG index over each training window and asks an LLM to predict
the next interval, returning the parsed JSON predictions.
"""
import asyncio
import json
import logging

import pandas as pd
from tqdm import tqdm

from datetime import datetime, date
from typing import List, Dict, Optional, Callable, Tuple

from .data_processor import MarketDataProcessor
from .rag_engine import RAGEngine

# Optionally configure logging
# logging.basicConfig(level=logging.INFO)
|
class PredictionService:
    """Generate per-interval market predictions over sliding windows.

    For each window produced by ``MarketDataProcessor``, the service rebuilds
    the RAG vector store from the training slice, retrieves relevant context
    for the current slice, and asks the LLM to predict the target interval.
    """

    def __init__(
        self,
        market_data: pd.DataFrame,
        training_window_size: int = 78,
        inference_window_size: int = 12,
        inference_offset: int = 0,
        max_concurrent: int = 3,
    ):
        """Set up the data processor, RAG engine, and concurrency limit.

        Args:
            market_data: Raw market data; copied, never mutated in place.
            training_window_size: Intervals per training window.
            inference_window_size: Intervals per inference window.
            inference_offset: Offset applied when slicing inference windows.
            max_concurrent: Maximum number of in-flight LLM calls.
        """
        # Copy so the caller's DataFrame is never mutated.
        self.market_data = market_data.copy()
        # Normalize column names — downstream processing expects upper-case.
        self.market_data.columns = [col.upper() for col in self.market_data.columns]
        self.processor = MarketDataProcessor(
            df=self.market_data,
            training_window_size=training_window_size,
            inference_window_size=inference_window_size,
            inference_offset=inference_offset,
        )
        self.engine = RAGEngine()
        # Bounds concurrent LLM requests (rate-limit friendliness).
        self.semaphore = asyncio.Semaphore(max_concurrent)
        self.training_window_size = training_window_size
        self.inference_window_size = inference_window_size
        self.inference_offset = inference_offset

    async def main(self) -> List[Dict]:
        """Coordinate the prediction process using sliding windows.

        Returns:
            List of prediction dicts, one per window that succeeded.
            Failed windows (LLM error, unparseable response) are skipped.
        """
        # Get prediction windows from the data processor.
        windows = self.processor.create_prediction_windows()

        predictions: List[Dict] = []
        # Progress bar over the window loop.
        with tqdm(total=len(windows),
                  desc="Processing predictions",  # plain string; was a placeholder-less f-string
                  leave=False) as pbar:
            for window in windows:
                # Rebuild the RAG index from this window's training slice.
                self.engine.create_vectorstore(
                    self.processor.combine_intervals(window["training"], is_training=True)
                )

                # Textual description of the current market state.
                current_state = self.processor.combine_intervals(window["current"])

                # Predict the next interval; window["target"].name is assumed
                # to be the target row's timestamp index — TODO confirm.
                prediction = await self._predict_next_interval(
                    current_state, window["target"].name
                )
                # _predict_next_interval returns None on failure; drop those so
                # the result list (and any DataFrame built from it) contains
                # only valid prediction rows.
                if prediction is not None:
                    predictions.append(prediction)
                pbar.update(1)

        return predictions

    async def _predict_next_interval(
        self, current_state: str, target_timestamp: pd.Timestamp
    ) -> Optional[Dict]:
        """
        Predict the next interval using the LLM with RAG context.

        Args:
            current_state: Combined description of current market intervals.
            target_timestamp: Timestamp of the interval to predict.

        Returns:
            Dict containing the parsed prediction plus a
            ``timestamp_prediction`` key, or None if the call/parse failed.
        """
        try:
            async with self.semaphore:  # cap concurrent LLM calls
                # Retrieve historical context similar to the current state.
                historical_context = self.engine.get_relevant_context(current_state)

                # Format messages for the LLM.
                messages = [
                    {"role": "system", "content": self.engine.system_prompt},
                    {
                        "role": "user",
                        "content": f"Historical context:\n{historical_context}\n\nCurrent market state:\n{current_state}",
                    },
                ]

                # Get the LLM prediction.
                response = await self.engine.llm.ainvoke(messages)
                raw_content = response.content.strip()

                if not raw_content:
                    raise ValueError("Empty response from LLM")

                # Strip any chatter the model emitted around the JSON object.
                # If no brace exists at all, the IndexError is caught below.
                if not raw_content.startswith("{"):
                    raw_content = "{" + raw_content.split("{", 1)[1]
                if not raw_content.endswith("}"):
                    raw_content = raw_content.rsplit("}", 1)[0] + "}"

                # Parse and tag the prediction with its target timestamp.
                prediction = json.loads(raw_content)
                prediction["timestamp_prediction"] = target_timestamp

                return prediction

        except Exception as e:
            # Best-effort per-window behavior: log and signal failure with
            # None rather than aborting the whole batch.
            logging.error(f"Prediction error for {target_timestamp}: {str(e)}")
            return None
|
async def batch_predictions(
    market_data: pd.DataFrame,
    training_window_size: int = 78,  # 1 trading day
    inference_window_size: int = 12,  # 1 hour
    inference_offset: int = 0,
) -> pd.DataFrame:
    """Run the full sliding-window prediction pipeline over ``market_data``.

    Args:
        market_data: Raw market data frame handed to PredictionService.
        training_window_size: Intervals per training window (78 = 1 trading day).
        inference_window_size: Intervals per inference window (12 = 1 hour).
        inference_offset: Offset applied when slicing inference windows.

    Returns:
        DataFrame with one row per successful prediction.
    """
    service = PredictionService(
        market_data,
        training_window_size=training_window_size,
        inference_window_size=inference_window_size,
        inference_offset=inference_offset,
    )
    # Bug fix: main() takes no arguments — the window configuration was
    # already supplied to the PredictionService constructor above. The old
    # call passed three extra positionals and raised TypeError at runtime.
    predictions = await service.main()
    return pd.DataFrame(predictions)