market_predictor/rag_engine.py
2025-02-02 18:22:12 -05:00

101 lines
3.3 KiB
Python

import asyncio
import json
from typing import List, Dict, Optional

import pandas as pd
from langchain.prompts import ChatPromptTemplate
from langchain_community.vectorstores import FAISS
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import OpenAIEmbeddings, ChatOpenAI

from config import OPENAI_API_KEY, MODEL_NAME, EMBEDDING_MODEL
class RAGEngine:
    """Retrieval-augmented predictor for short-term VWAP direction.

    The engine embeds a corpus of text snippets into a FAISS vector store,
    retrieves the snippets most similar to a query, and asks a chat model to
    answer with a JSON prediction following the schema in the system prompt.
    """

    # Number of similar snippets retrieved as context for each prediction.
    _TOP_K = 5

    # The prompt body is deliberately flush-left: every character between the
    # triple quotes (leading whitespace included) is sent to the model.
    # NOTE: "technical_signals" is shown as a JSON object. The earlier
    # bracketed form ([ "key": "value", ... ]) is not valid JSON, which
    # contradicted the "strict JSON" instruction.
    _SYSTEM_PROMPT = """You are a technical analysis expert predicting 5-minute VWAP movements.
Analysis Requirements:
1. Compare current VWAP to moving averages (5, 20 periods)
2. Analyze volume trends vs price movement
3. Identify short-term support/resistance levels
4. Check momentum indicators (RSI trends)
5. Evaluate recent price action patterns
Confidence Score Rules:
- 0.8-1.0: Strong signals with multiple confirmations
- 0.6-0.8: Good signals with some confirmation
- 0.4-0.6: Mixed or unclear signals
- Below 0.4: Weak or contradictory signals
Volume Analysis:
- Compare current volume to 5-period average
- Check volume trend direction
- Evaluate price/volume relationship
Pattern Recognition:
- Higher highs / lower lows
- Support/resistance tests
- Volume spikes or drops
- VWAP crossovers
- Momentum divergences
Output strict JSON format:
{
"vwap_prediction_next_5min": "up" or "down",
"confidence_score": 0.0 to 1.0,
"technical_signals": {
"volume_trend": "increasing/decreasing",
"price_momentum": "positive/negative",
"vwap_trend": "above_ma/below_ma",
"support_resistance": "near_support/near_resistance"
},
"key_levels": {
"support": "price_level",
"resistance": "price_level",
"vwap_ma_cross": "price_level"
},
"reasoning": "Brief technical analysis explanation"
}"""

    def __init__(self):
        """Create the embedding client and chat model; no index is built yet."""
        self.embeddings = OpenAIEmbeddings(
            openai_api_key=OPENAI_API_KEY,
            model=EMBEDDING_MODEL,
        )
        self.llm = ChatOpenAI(
            openai_api_key=OPENAI_API_KEY,
            model=MODEL_NAME,
            temperature=0.1,
        )
        # Populated by create_vectorstore(); predict() cannot retrieve
        # context until then.
        self.vectorstore = None
        self.system_prompt = self._SYSTEM_PROMPT

    def create_vectorstore(self, texts: List[str]):
        """Build (or replace) the FAISS index from ``texts``.

        Each text is stored with a ``{"index": i}`` metadata entry recording
        its position in ``texts``.

        Args:
            texts: Snippets to embed and index.

        Raises:
            ValueError: If ``texts`` is empty. Failing here gives a clear
                message instead of an obscure error from inside index
                construction.
        """
        if not texts:
            raise ValueError("texts must contain at least one document")
        self.vectorstore = FAISS.from_texts(
            texts,
            self.embeddings,
            metadatas=[{"index": i} for i in range(len(texts))],
        )

    @staticmethod
    def _parse_json_response(raw: str) -> Dict:
        """Extract the JSON object from a chat-model reply.

        The reply is first parsed as-is. If that fails (Markdown code fences,
        with or without a ``json`` tag, or prose around the payload), the
        span from the first ``{`` to the last ``}`` is parsed instead.

        Args:
            raw: Text content of the model reply.

        Returns:
            The decoded JSON object.

        Raises:
            json.JSONDecodeError: If no parseable JSON can be found.
            ValueError: If the decoded JSON is not an object.
        """
        text = raw.strip()
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            start = text.find("{")
            end = text.rfind("}")
            if start == -1 or end <= start:
                # Nothing that looks like an object: surface the original error.
                raise
            parsed = json.loads(text[start:end + 1])
        if not isinstance(parsed, dict):
            raise ValueError(
                f"expected a JSON object, got {type(parsed).__name__}"
            )
        return parsed

    async def predict(self, query: str, timestamp: pd.Timestamp) -> Optional[Dict]:
        """Predict the next VWAP move for ``query`` using retrieved context.

        Args:
            query: Text description of the current market data.
            timestamp: Stored unchanged in the result under
                ``"timestamp_prediction"``.

        Returns:
            The model's JSON prediction as a dict with the timestamp added,
            or ``None`` if the vector store has not been created or any step
            (retrieval, model call, JSON parsing) fails. Failures are printed
            rather than raised.
        """
        if self.vectorstore is None:
            print(
                "Prediction error: vector store not initialised; "
                "call create_vectorstore() first"
            )
            return None
        try:
            # Async retrieval keeps the event loop free while the query is
            # embedded and searched.
            similar_docs = await self.vectorstore.asimilarity_search(
                query, k=self._TOP_K
            )
            context = "\n".join(doc.page_content for doc in similar_docs)
            messages = [
                SystemMessage(content=self.system_prompt),
                HumanMessage(
                    content=f"Similar Patterns:\n{context}\n\nCurrent Data:\n{query}"
                ),
            ]
            response = await self.llm.ainvoke(messages)
            prediction = self._parse_json_response(response.content)
            prediction['timestamp_prediction'] = timestamp
            return prediction
        except Exception as e:
            # Deliberate best-effort boundary: one failed prediction (network
            # error, malformed reply, ...) is reported and yields None.
            print(f"Prediction error: {e}")
            return None