# rag_engine.py
import os
import json
import time

import faiss
import numpy as np
import requests
from dotenv import load_dotenv
from sentence_transformers import SentenceTransformer
# Load environment variables (e.g. HF_API_TOKEN) from a local .env file.
# This is a no-op when no .env exists, such as on a deployed Space where
# secrets are injected as environment variables.
load_dotenv()
class RAGEngine:
    def __init__(self):
        # Sentence-transformer model used to embed documents and queries;
        # all-MiniLM-L6-v2 produces 384-dimensional vectors.
        self.embedder = SentenceTransformer('all-MiniLM-L6-v2')
        self.embedding_dim = 384
        self.index = faiss.IndexFlatL2(self.embedding_dim)
        self.texts = []

        # Load documents and build the FAISS index up front.
        self.documents = self.load_documents()
        self.create_vector_store()

        # Hugging Face Inference API details.
        self.api_token = os.getenv("HF_API_TOKEN")
        self.model_url = "https://api-inference.huggingface.co/models/deepseek-ai/deepseek-llm-7b-instruct"
    def load_documents(self):
        docs = []
        data_folder = "data"
        if not os.path.exists(data_folder):
            raise FileNotFoundError(f"Data folder '{data_folder}' does not exist!")
        for file_name in os.listdir(data_folder):
            if file_name.endswith(".json"):
                with open(os.path.join(data_folder, file_name), 'r', encoding='utf-8') as f:
                    try:
                        data = json.load(f)
                        docs.extend(self.flatten_data(data))
                    except json.JSONDecodeError as e:
                        print(f"Error loading {file_name}: {e}")
        return docs
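
    # Expected layout (hypothetical file names): data/characters.json,
    # data/light_cones.json, ... each holding a JSON list or object of
    # build data; any non-.json files in data/ are ignored.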
    def flatten_data(self, data):
        flattened = []
        if isinstance(data, list):
            for item in data:
                flattened.extend(self.extract_fields(item))
        elif isinstance(data, dict):
            flattened.extend(self.extract_fields(data))
        else:
            print("Unexpected data format:", type(data))
        return flattened
    def extract_fields(self, item):
        """Smart chunking: extract key fields instead of dumping full JSON."""
        chunks = []
        if isinstance(item, dict):
            for key, value in item.items():
                if isinstance(value, (str, int, float)):
                    chunks.append({"text": f"{key}: {value}"})
                elif isinstance(value, dict):
                    for sub_key, sub_value in value.items():
                        chunks.append({"text": f"{key} -> {sub_key}: {sub_value}"})
                elif isinstance(value, list):
                    for idx, sub_item in enumerate(value):
                        chunks.append({"text": f"{key}[{idx}]: {sub_item}"})
        return chunks
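
    # For illustration, a hypothetical record like
    #     {"name": "Seele", "stats": {"path": "The Hunt"}, "roles": ["DPS"]}
    # would flatten into three chunks:
    #     {"text": "name: Seele"}
    #     {"text": "stats -> path: The Hunt"}
    #     {"text": "roles[0]: DPS"}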
    def create_vector_store(self):
        if not self.documents:
            return  # Nothing to index; avoids adding a malformed empty array
        self.texts = [doc['text'] for doc in self.documents]
        # Batch-encode all chunks at once; encode() on a list returns a 2D
        # array, and FAISS requires float32.
        embeddings = self.embedder.encode(self.texts)
        self.index.add(np.asarray(embeddings, dtype=np.float32))
    def search_documents(self, query, top_k=5):
        query_emb = self.embedder.encode(query)
        query_emb = np.expand_dims(np.asarray(query_emb, dtype=np.float32), axis=0)
        distances, indices = self.index.search(query_emb, top_k)
        # FAISS pads results with -1 when the index holds fewer than top_k
        # vectors, so guard against negative indices as well as out-of-range ones.
        return [self.texts[i] for i in indices[0] if 0 <= i < len(self.texts)]
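
    # Example (hypothetical query): search_documents("Best relics for Seele", top_k=3)
    # returns up to three chunk strings ranked by ascending L2 distance in
    # embedding space.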
    def ask_deepseek(self, context, query, retries=3, wait_time=5):
        prompt = (
            "You are an expert Honkai Star Rail Build Advisor.\n"
            "You specialize in optimizing character performance based on Light Cones, Relics, Stats, Eidolons, and Team Synergies.\n"
            "Provide detailed build advice for the given query using the provided context.\n"
            "Always prioritize the most effective and meta-relevant recommendations.\n\n"
            "Format your answer like this:\n"
            "- Best Light Cones (Top 3)\n"
            "- Recommended Relic Sets and Main Stats\n"
            "- Important Substats to Prioritize\n"
            "- Optimal Eidolon Level (if necessary)\n"
            "- Best Team Compositions (Synergies and Playstyle)\n"
            "- Any Special Notes\n\n"
            f"Context:\n{context}\n\n"
            f"Question:\n{query}\n"
            "Answer:"
        )
        headers = {
            "Authorization": f"Bearer {self.api_token}",
            "Content-Type": "application/json"
        }
        payload = {
            "inputs": prompt,
            "parameters": {"temperature": 0.7, "max_new_tokens": 800}
        }
        for attempt in range(retries):
            response = requests.post(self.model_url, headers=headers, json=payload, timeout=60)
            if response.status_code == 200:
                generated_text = response.json()[0]["generated_text"]
                # The API echoes the prompt, so keep only the text after "Answer:".
                return generated_text.split("Answer:")[-1].strip()
            # Retry on transient failures (e.g. 503 while the model loads).
            print(f"Request failed (attempt {attempt + 1}/{retries}): {response.status_code}")
            if attempt < retries - 1:
                time.sleep(wait_time)
        return f"Error: Could not get a valid response after {retries} attempts."
    def answer_query(self, query):
        relevant_docs = self.search_documents(query)
        context = "\n".join(relevant_docs)
        return self.ask_deepseek(context, query)
    def stream_answer(self, query):
        """Yield the answer word by word, e.g. for st.write_stream in Streamlit."""
        answer = self.answer_query(query)
        for word in answer.split():
            yield word + " "
            time.sleep(0.02)  # Typing-effect delay; feel free to tweak
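

# Minimal smoke test, assuming a populated data/ folder and an HF_API_TOKEN
# environment variable (or .env entry); the query below is just an example.
if __name__ == "__main__":
    engine = RAGEngine()
    question = "What are the best Light Cones for Seele?"
    for chunk in engine.stream_answer(question):
        print(chunk, end="", flush=True)
    print()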