101 lines
3.3 KiB
Python
101 lines
3.3 KiB
Python
import asyncio
import json
from typing import Dict, List, Optional

import pandas as pd
from langchain.prompts import ChatPromptTemplate
from langchain_community.vectorstores import FAISS
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

from config import EMBEDDING_MODEL, MODEL_NAME, OPENAI_API_KEY


class RAGEngine:
    """Retrieval-augmented predictor for the next 5-minute VWAP direction.

    The engine holds an OpenAI embedding model, an OpenAI chat model and
    (once ``create_vectorstore`` has been called) a FAISS index of reference
    texts. ``predict`` retrieves the texts most similar to the query, sends
    them together with the query to the chat model and parses the JSON reply.
    """

    # NOTE(review): the "technical_signals" example in this prompt lists
    # key/value pairs inside square brackets, which is not valid JSON, so the
    # shape the model returns for that field is not pinned down. The text is
    # left unchanged here because downstream consumers are not visible from
    # this file - confirm the intended shape before editing it.
    _SYSTEM_PROMPT = """You are a technical analysis expert predicting 5-minute VWAP movements.

Analysis Requirements:
1. Compare current VWAP to moving averages (5, 20 periods)
2. Analyze volume trends vs price movement
3. Identify short-term support/resistance levels
4. Check momentum indicators (RSI trends)
5. Evaluate recent price action patterns

Confidence Score Rules:
- 0.8-1.0: Strong signals with multiple confirmations
- 0.6-0.8: Good signals with some confirmation
- 0.4-0.6: Mixed or unclear signals
- Below 0.4: Weak or contradictory signals

Volume Analysis:
- Compare current volume to 5-period average
- Check volume trend direction
- Evaluate price/volume relationship

Pattern Recognition:
- Higher highs / lower lows
- Support/resistance tests
- Volume spikes or drops
- VWAP crossovers
- Momentum divergences

Output strict JSON format:
{
"vwap_prediction_next_5min": "up" or "down",
"confidence_score": 0.0 to 1.0,
"technical_signals": [
"volume_trend": "increasing/decreasing",
"price_momentum": "positive/negative",
"vwap_trend": "above_ma/below_ma",
"support_resistance": "near_support/near_resistance"
],
"key_levels": {
"support": "price_level",
"resistance": "price_level",
"vwap_ma_cross": "price_level"
},
"reasoning": "Brief technical analysis explanation"
}"""

    def __init__(self):
        """Create the embedding and chat clients; the vector store starts unset."""
        self.embeddings = OpenAIEmbeddings(
            openai_api_key=OPENAI_API_KEY,
            model=EMBEDDING_MODEL,
        )
        self.llm = ChatOpenAI(
            openai_api_key=OPENAI_API_KEY,
            model=MODEL_NAME,
            temperature=0.1,
        )
        # Populated by create_vectorstore(); predict() cannot run before that.
        self.vectorstore = None

        # System prompt sent with every prediction request.
        self.system_prompt = self._SYSTEM_PROMPT

    def create_vectorstore(self, texts: List[str]):
        """Build the FAISS index from ``texts`` and store it on the engine.

        Each text is stored with metadata ``{"index": i}``, where ``i`` is its
        position in ``texts``.

        Args:
            texts: Reference texts to embed and index.

        Raises:
            ValueError: If ``texts`` is empty; an index cannot be built from
                zero embeddings.
        """
        texts = list(texts)
        if not texts:
            raise ValueError("create_vectorstore() requires at least one text")

        self.vectorstore = FAISS.from_texts(
            texts,
            self.embeddings,
            metadatas=[{"index": i} for i in range(len(texts))],
        )

    @staticmethod
    def _extract_json(raw: str) -> str:
        """Return the JSON object embedded in a chat-model reply.

        The reply may be bare JSON, wrapped in a Markdown code fence (with or
        without a language tag, with LF or CRLF line endings) or surrounded by
        prose. The span from the first ``{`` to the last ``}`` is returned;
        when no such span exists the stripped reply is returned unchanged so
        the caller's JSON parser reports the problem.
        """
        text = raw.strip()
        start = text.find("{")
        end = text.rfind("}")
        if start != -1 and end > start:
            return text[start:end + 1]
        return text

    async def predict(
        self, query: str, timestamp: pd.Timestamp, k: int = 5
    ) -> Optional[Dict]:
        """Predict the next VWAP move for ``query`` using retrieved context.

        Args:
            query: Text describing the current market data.
            timestamp: Value stored under ``"timestamp_prediction"`` in the
                returned dictionary.
            k: Number of similar texts to retrieve as context.

        Returns:
            The parsed JSON reply as a dictionary with ``timestamp_prediction``
            added, or ``None`` when the vector store is unset or any step
            fails (the error is printed).
        """
        if self.vectorstore is None:
            print(
                "Prediction error: vector store is not initialised; "
                "call create_vectorstore() first"
            )
            return None

        try:
            # Use the async search so the event loop is not blocked while the
            # query is embedded and the index is searched.
            similar_docs = await self.vectorstore.asimilarity_search(query, k=k)
            context = "\n".join(doc.page_content for doc in similar_docs)

            messages = [
                SystemMessage(content=self.system_prompt),
                HumanMessage(
                    content=f"Similar Patterns:\n{context}\n\nCurrent Data:\n{query}"
                ),
            ]

            response = await self.llm.ainvoke(messages)

            # Strip any code fence or surrounding prose before parsing.
            prediction = json.loads(self._extract_json(response.content))
            if not isinstance(prediction, dict):
                raise ValueError("model reply is not a JSON object")

            prediction['timestamp_prediction'] = timestamp
            return prediction

        except Exception as e:
            # Best-effort boundary: callers receive None instead of an exception.
            print(f"Prediction error: {e}")
            return None