# rag_engine.py
import os
import json
import time
import faiss
import numpy as np
import requests
from dotenv import load_dotenv
from sentence_transformers import SentenceTransformer
import streamlit as st
# Load environment variables
#load_dotenv()
class RAGEngine:
    """Retrieval-augmented generation engine over local JSON documents.

    Embeds document chunks with a SentenceTransformer model, stores them in a
    FAISS L2 index, and answers queries by retrieving the closest chunks and
    forwarding them as context to a hosted DeepSeek LLM.
    """

    def __init__(self):
        # Sentence embedding model; all-MiniLM-L6-v2 emits 384-dim vectors.
        self.embedder = SentenceTransformer('all-MiniLM-L6-v2')
        self.embedding_dim = 384
        self.index = faiss.IndexFlatL2(self.embedding_dim)
        self.texts = []
        # Load and index all documents once at construction time.
        self.documents = self.load_documents()
        self.create_vector_store()
        # Hugging Face Inference API details.
        self.api_token = os.getenv("HF_API_TOKEN")
        self.model_url = "https://api-inference.huggingface.co/models/deepseek-ai/deepseek-llm-7b-instruct"

    def load_documents(self):
        """Load every ``*.json`` file under ``data/`` and flatten it into chunks.

        Returns:
            list[dict]: chunk dicts of the form ``{"text": "..."}``.

        Raises:
            FileNotFoundError: if the ``data`` folder is missing.
        """
        docs = []
        data_folder = "data"
        if not os.path.exists(data_folder):
            raise FileNotFoundError(f"Data folder '{data_folder}' does not exist!")
        for file_name in os.listdir(data_folder):
            if file_name.endswith(".json"):
                with open(os.path.join(data_folder, file_name), 'r', encoding='utf-8') as f:
                    try:
                        data = json.load(f)
                        docs.extend(self.flatten_data(data))
                    except json.JSONDecodeError as e:
                        # Skip malformed files instead of aborting the whole load.
                        print(f"Error loading {file_name}: {e}")
        return docs

    def flatten_data(self, data):
        """Flatten a parsed JSON payload (dict, or list of dicts) into chunks."""
        flattened = []
        if isinstance(data, list):
            for item in data:
                flattened.extend(self.extract_fields(item))
        elif isinstance(data, dict):
            flattened.extend(self.extract_fields(data))
        else:
            print("Unexpected data format:", type(data))
        return flattened

    def extract_fields(self, item):
        """Smart chunking: extract key fields instead of dumping full JSON.

        Non-dict items produce no chunks. One level of nesting is expanded:
        dict values become "key -> sub_key: value" chunks, list values become
        "key[i]: item" chunks.
        """
        chunks = []
        if isinstance(item, dict):
            for key, value in item.items():
                if isinstance(value, (str, int, float)):
                    chunks.append({"text": f"{key}: {value}"})
                elif isinstance(value, dict):
                    # One chunk per nested key keeps retrieval granular.
                    for sub_key, sub_value in value.items():
                        chunks.append({"text": f"{key} -> {sub_key}: {sub_value}"})
                elif isinstance(value, list):
                    for idx, sub_item in enumerate(value):
                        chunks.append({"text": f"{key}[{idx}]: {sub_item}"})
        return chunks

    def create_vector_store(self):
        """Embed all document chunks and add them to the FAISS index."""
        self.texts = [doc['text'] for doc in self.documents]
        if not self.texts:
            # FIX: index.add on an empty/mis-shaped array raises; with no
            # documents there is simply nothing to index.
            return
        # FIX: batch-encode in a single call instead of one encode() per
        # chunk — much faster and behaviorally identical. FAISS requires
        # float32 input, which we enforce explicitly.
        embeddings = np.asarray(self.embedder.encode(self.texts), dtype="float32")
        self.index.add(embeddings)

    def search_documents(self, query, top_k=5):
        """Return up to ``top_k`` stored texts closest to the query embedding."""
        query_emb = np.asarray(self.embedder.encode(query), dtype="float32")
        query_emb = np.expand_dims(query_emb, axis=0)
        distances, indices = self.index.search(query_emb, top_k)
        # FIX: FAISS pads missing results with -1; the old `i < len(...)`
        # check let -1 through, silently returning the *last* text via
        # negative indexing. Require 0 <= i as well.
        return [self.texts[i] for i in indices[0] if 0 <= i < len(self.texts)]

    def ask_deepseek(self, context, query, retries=3, wait_time=5):
        """Ask the hosted DeepSeek model, retrying on transient failures.

        Args:
            context: retrieved document text injected into the prompt.
            query: the user's question.
            retries: number of attempts before giving up.
            wait_time: seconds to sleep between attempts.

        Returns:
            str: the model's answer, or an error string after exhausting retries.
        """
        prompt = (
            "You are an expert Honkai Star Rail Build Advisor.\n"
            "You specialize in optimizing character performance based on Light Cones, Relics, Stats, Eidolons, and Team Synergies.\n"
            "Provide detailed build advice for the given query using the provided context.\n"
            "Always prioritize the most effective and meta-relevant recommendations.\n\n"
            "Format your answer like this:\n"
            "- Best Light Cones (Top 3)\n"
            "- Recommended Relic Sets and Main Stats\n"
            "- Important Substats to Prioritize\n"
            "- Optimal Eidolon Level (if necessary)\n"
            "- Best Team Compositions (Synergies and Playstyle)\n"
            "- Any Special Notes\n\n"
            f"Context:\n{context}\n\n"
            f"Question:\n{query}\n"
            "Answer:"
        )
        headers = {
            "Authorization": f"Bearer {self.api_token}",
            "Content-Type": "application/json"
        }
        payload = {
            "inputs": prompt,
            "parameters": {"temperature": 0.7, "max_new_tokens": 800}
        }
        for attempt in range(retries):
            try:
                # FIX: without a timeout a hung connection blocks forever;
                # network errors now count as a failed attempt instead of
                # crashing the whole query.
                response = requests.post(self.model_url, headers=headers, json=payload, timeout=60)
            except requests.RequestException as e:
                print(f"Request error (attempt {attempt+1}/{retries}): {e}")
            else:
                if response.status_code == 200:
                    generated_text = response.json()[0]["generated_text"]
                    # The API echoes the prompt; keep only text after "Answer:".
                    return generated_text.split("Answer:")[-1].strip()
                print(f"Request failed (attempt {attempt+1}/{retries}): {response.status_code}")
            if attempt < retries - 1:
                time.sleep(wait_time)
        return f"Error: Could not get a valid response after {retries} attempts."

    def answer_query(self, query):
        """Retrieve relevant chunks for ``query`` and ask the LLM with them."""
        relevant_docs = self.search_documents(query)
        context = "\n".join(relevant_docs)
        answer = self.ask_deepseek(context, query)
        return answer

    def stream_answer(self, query):
        """Streamed generation for Streamlit: yield the answer word by word."""
        answer = self.answer_query(query)
        for word in answer.split():
            yield word + " "
            time.sleep(0.02)  # Feel free to tweak typing speed