Update rag_engine.py
rag_engine.py +82 -151
rag_engine.py
CHANGED
@@ -22,46 +22,14 @@ local_faiss_index_file = "faiss_index.faiss"
 local_text_chunks_file = "text_chunks.txt"
 local_metadata_file = "metadata.jsonl"

-#
-def setup_gcp_client():
-    try:
-        credentials = setup_gcp_auth()
-
-        # Get bucket name from secrets - required
-        try:
-            bucket_name_gcs = st.secrets["BUCKET_NAME_GCS"]
-        except KeyError:
-            print("❌ Error: GCS bucket name not found in secrets")
-            return None
-
-        storage_client = storage.Client(credentials=credentials)
-        bucket = storage_client.bucket(bucket_name_gcs)
-        print("✅ GCP client initialized successfully")
-        return bucket
-    except Exception as e:
-        print(f"❌ GCP client initialization error: {str(e)}")
-        return None

-
-def setup_openai_client():
     try:
-        setup_openai_auth()
-        print("✅ OpenAI client initialized successfully")
-        return True
-    except Exception as e:
-        print(f"❌ OpenAI client initialization error: {str(e)}")
-        return False
-
-def load_model():
-    """Load the embedding model and store in session state"""
-    try:
-        # Check if model already loaded
-        if 'model' in st.session_state and st.session_state.model is not None:
-            print("Model already loaded in session state")
-            return st.session_state.tokenizer, st.session_state.model
-
-        print("Loading new model instance...")
-
         # Force model to CPU
         device = torch.device("cpu")

@@ -79,49 +47,24 @@ def load_model():
             torch_dtype=torch.float16
         )

-        # Move to CPU and set to eval mode
         model = model.to(device)
         model.eval()

         # Disable gradient computation
         torch.set_grad_enabled(False)

-
-        st.session_state.tokenizer = tokenizer
-        st.session_state.model = model
-
-        print("✅ Model loaded successfully")
         return tokenizer, model

     except Exception as e:
         print(f"❌ Error loading model: {str(e)}")
-        # Return None values - don't raise exception
         return None, None

-
-
-
-
-        if os.path.exists(local_path):
-            print(f"File already exists locally: {local_path}")
-            return True
-
-        blob = bucket.blob(gcs_path)
-        blob.download_to_filename(local_path)
-        print(f"✅ Downloaded {gcs_path} → {local_path}")
-        return True
-    except Exception as e:
-        print(f"❌ Error downloading {gcs_path}: {str(e)}")
-        return False
-
-def load_data_files():
-    """Load FAISS index, text chunks, and metadata"""
-    # Check if already loaded in session state
-    if 'faiss_index' in st.session_state and st.session_state.faiss_index is not None:
-        print("Using cached data files from session state")
-        return st.session_state.faiss_index, st.session_state.text_chunks, st.session_state.metadata_dict
-
-    # Initialize clients
     bucket = setup_gcp_client()
     openai_initialized = setup_openai_client()

@@ -139,7 +82,7 @@ def load_data_files():
         print(f"❌ Error: Required GCS path not found in secrets: {e}")
         return None, None, None

-    # Download necessary files
     success = True
     success &= download_file_from_gcs(bucket, faiss_index_file_gcs, local_faiss_index_file)
     success &= download_file_from_gcs(bucket, text_chunks_file_gcs, local_text_chunks_file)
@@ -158,7 +101,7 @@ def load_data_files():

     # Load text chunks
     try:
-        text_chunks = {} #
         with open(local_text_chunks_file, "r", encoding="utf-8") as f:
             for line in f:
                 parts = line.strip().split("\t")
@@ -179,45 +122,74 @@ def load_data_files():
         print(f"❌ Error loading metadata: {str(e)}")
         return None, None, None

-    print(f"✅ Data loaded successfully: {len(text_chunks)} passages available")
-
-    # Store in session state
-    st.session_state.faiss_index = faiss_index
-    st.session_state.text_chunks = text_chunks
-    st.session_state.metadata_dict = metadata_dict
-
     return faiss_index, text_chunks, metadata_dict

 def average_pool(last_hidden_states, attention_mask):
     """Average pooling for sentence embeddings."""
     last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
     return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

-#
 query_embedding_cache = {}

 def get_embedding(text):
-    """Generate embeddings for a text query"""
-    # Check cache first
     if text in query_embedding_cache:
         return query_embedding_cache[text]

     try:
-
-        if 'model' not in st.session_state or st.session_state.model is None:
-            tokenizer, model = load_model()
-        else:
-            tokenizer, model = st.session_state.tokenizer, st.session_state.model
-
-        # Handle model load failure
         if model is None:
             print("Model is None, returning zero embedding")
             return np.zeros((1, 384), dtype=np.float32)

-        # Prepare text
         input_text = f"query: {text}" if len(text) < 512 else f"passage: {text}"
-
-        # Tokenize
         inputs = tokenizer(
             input_text,
             padding=True,
@@ -226,19 +198,13 @@ def get_embedding(text):
             max_length=512,
             return_attention_mask=True
         )
-
-        # Generate embeddings
         with torch.no_grad():
             outputs = model(**inputs)
             embeddings = average_pool(outputs.last_hidden_state, inputs['attention_mask'])
             embeddings = nn.functional.normalize(embeddings, p=2, dim=1)
             embeddings = embeddings.detach().cpu().numpy()
-
-        # Clean up
         del outputs, inputs
         gc.collect()
-
-        # Cache and return
         query_embedding_cache[text] = embeddings
         return embeddings
     except Exception as e:
@@ -246,49 +212,31 @@ def get_embedding(text):
         return np.zeros((1, 384), dtype=np.float32)

 def retrieve_passages(query, faiss_index, text_chunks, metadata_dict, top_k=5, similarity_threshold=0.5):
-    """Retrieve top-k most relevant passages using FAISS
     try:
         print(f"\n🔍 Retrieving passages for query: {query}")
-
-        # Get query embedding
         query_embedding = get_embedding(query)
-
-        # Search in FAISS index
         distances, indices = faiss_index.search(query_embedding, top_k * 2)
-
         print(f"Found {len(distances[0])} potential matches")
         retrieved_passages = []
         retrieved_sources = []
         cited_titles = set()
-
-        # Process results
         for dist, idx in zip(distances[0], indices[0]):
             print(f"Distance: {dist:.4f}, Index: {idx}")
             if idx in text_chunks and dist >= similarity_threshold:
                 title_with_txt, author, text = text_chunks[idx]
-
-                # Clean title
                 clean_title = title_with_txt.replace(".txt", "") if title_with_txt.endswith(".txt") else title_with_txt
                 clean_title = unicodedata.normalize("NFC", clean_title)
-
-                # Skip duplicates
                 if clean_title in cited_titles:
                     continue
-
-                # Get metadata
                 metadata_entry = metadata_dict.get(clean_title, {})
                 author = metadata_entry.get("Author", "Unknown")
                 publisher = metadata_entry.get("Publisher", "Unknown")
-
-                # Add to results
                 cited_titles.add(clean_title)
                 retrieved_passages.append(text)
-                retrieved_sources.append((clean_title, author, publisher))
-
-                # Stop if we have enough
                 if len(retrieved_passages) == top_k:
                     break
-
         print(f"Retrieved {len(retrieved_passages)} passages")
         return retrieved_passages, retrieved_sources
     except Exception as e:
@@ -296,39 +244,31 @@ def retrieve_passages(query, faiss_index, text_chunks, metadata_dict, top_k=5, s
         return [], []

 def answer_with_llm(query, context=None, word_limit=100):
-    """Generate an answer using OpenAI GPT model with formatted citations."""
     try:
-        # Format context
         if context:
             formatted_contexts = []
             total_chars = 0
             max_context_chars = 4000
-
             for (title, author, publisher), text in context:
                 remaining_space = max(0, max_context_chars - total_chars)
                 excerpt_len = min(150, remaining_space)
-
                 if excerpt_len > 50:
                     excerpt = text[:excerpt_len].strip() + "..." if len(text) > excerpt_len else text
                     formatted_context = f"[{title} by {author}, Published by {publisher}] {excerpt}"
                     formatted_contexts.append(formatted_context)
                     total_chars += len(formatted_context)
-
                 if total_chars >= max_context_chars:
                     break
-
             formatted_context = "\n".join(formatted_contexts)
         else:
             formatted_context = "No relevant information available."

-        # System message
         system_message = (
             "You are an AI specialized in Indian spiritual texts. "
             "Answer based on context, summarizing ideas rather than quoting verbatim. "
             "Ensure proper citation and do not include direct excerpts."
         )
-
-        # User message
         user_message = f"""
         Context:
         {formatted_context}
@@ -336,14 +276,12 @@ def answer_with_llm(query, context=None, word_limit=100):
         {query}
         """

-        # Get LLM model from secrets
         try:
             llm_model = st.secrets["LLM_MODEL"]
         except KeyError:
             print("❌ Error: LLM model not found in secrets")
             return "I apologize, but I'm unable to answer at the moment."

-        # Call OpenAI API
         response = openai.chat.completions.create(
             model=llm_model,
             messages=[
@@ -353,50 +291,42 @@ def answer_with_llm(query, context=None, word_limit=100):
             max_tokens=200,
             temperature=0.7
         )
-
         answer = response.choices[0].message.content.strip()
-
-        # Enforce word limit
         words = answer.split()
         if len(words) > word_limit:
             answer = " ".join(words[:word_limit])
             if not answer.endswith((".", "!", "?")):
                 answer += "."
-
         return answer
-
     except Exception as e:
         print(f"❌ LLM API error: {str(e)}")
         return "I apologize, but I'm unable to answer at the moment."

 def format_citations(sources):
-    """Format citations
     formatted_citations = []
     for title, author, publisher in sources:
-        # Check if the publisher already ends with a period, question mark, or exclamation mark
         if publisher.endswith(('.', '!', '?')):
             formatted_citations.append(f"📚 {title} by {author}, Published by {publisher}")
         else:
             formatted_citations.append(f"📚 {title} by {author}, Published by {publisher}.")
-
     return "\n".join(formatted_citations)

-
-
-
-
-
-
-
-
     if faiss_index is None or text_chunks is None or metadata_dict is None:
         return {
             "query": query,
             "answer_with_rag": "⚠️ System error: Data files not loaded properly.",
             "citations": "No citations available."
         }
-
-    # Get relevant passages
     retrieved_context, retrieved_sources = retrieve_passages(
         query,
         faiss_index,
@@ -404,15 +334,16 @@ def process_query(query, top_k=5, word_limit=100):
         metadata_dict,
         top_k=top_k
     )
-
-    # Format citations
     sources = format_citations(retrieved_sources) if retrieved_sources else "No citation available."
-
-    # Generate answer
     if retrieved_context:
         context_with_sources = list(zip(retrieved_sources, retrieved_context))
         llm_answer_with_rag = answer_with_llm(query, context_with_sources, word_limit=word_limit)
     else:
         llm_answer_with_rag = "⚠️ No relevant context found."

-
@@ -22,46 +22,14 @@ local_faiss_index_file = "faiss_index.faiss"
 local_text_chunks_file = "text_chunks.txt"
 local_metadata_file = "metadata.jsonl"

+# =============================================================================
+# RESOURCE CACHING
+# =============================================================================

+@st.cache_resource(show_spinner=False)
+def cached_load_model():
+    """Cached version of load_model() for embedding model loading."""
     try:
         # Force model to CPU
         device = torch.device("cpu")

@@ -79,49 +47,24 @@ def load_model():
             torch_dtype=torch.float16
         )

+        # Move model to CPU and set to eval mode
         model = model.to(device)
         model.eval()

         # Disable gradient computation
         torch.set_grad_enabled(False)

+        print("✅ Model loaded successfully (cached)")
         return tokenizer, model

     except Exception as e:
         print(f"❌ Error loading model: {str(e)}")
         return None, None

+@st.cache_resource(show_spinner=False)
+def cached_load_data_files():
+    """Cached version of load_data_files() for FAISS index, text chunks, and metadata."""
+    # Initialize GCP and OpenAI clients
     bucket = setup_gcp_client()
     openai_initialized = setup_openai_client()

@@ -139,7 +82,7 @@ def load_data_files():
         print(f"❌ Error: Required GCS path not found in secrets: {e}")
         return None, None, None

+    # Download necessary files if not already present locally
     success = True
     success &= download_file_from_gcs(bucket, faiss_index_file_gcs, local_faiss_index_file)
     success &= download_file_from_gcs(bucket, text_chunks_file_gcs, local_text_chunks_file)
@@ -158,7 +101,7 @@ def load_data_files():

     # Load text chunks
     try:
+        text_chunks = {} # Mapping: ID -> (Title, Author, Text)
         with open(local_text_chunks_file, "r", encoding="utf-8") as f:
             for line in f:
                 parts = line.strip().split("\t")
@@ -179,45 +122,74 @@ def load_data_files():
         print(f"❌ Error loading metadata: {str(e)}")
         return None, None, None

+    print(f"✅ Data loaded successfully (cached): {len(text_chunks)} passages available")
     return faiss_index, text_chunks, metadata_dict

+# =============================================================================
+# UTILITY FUNCTIONS
+# =============================================================================
+
+def setup_gcp_client():
+    try:
+        credentials = setup_gcp_auth()
+        try:
+            bucket_name_gcs = st.secrets["BUCKET_NAME_GCS"]
+        except KeyError:
+            print("❌ Error: GCS bucket name not found in secrets")
+            return None
+
+        storage_client = storage.Client(credentials=credentials)
+        bucket = storage_client.bucket(bucket_name_gcs)
+        print("✅ GCP client initialized successfully")
+        return bucket
+    except Exception as e:
+        print(f"❌ GCP client initialization error: {str(e)}")
+        return None
+
+def setup_openai_client():
+    try:
+        setup_openai_auth()
+        print("✅ OpenAI client initialized successfully")
+        return True
+    except Exception as e:
+        print(f"❌ OpenAI client initialization error: {str(e)}")
+        return False
+
+def download_file_from_gcs(bucket, gcs_path, local_path):
+    """Download a file from GCS to local storage if not already present."""
+    try:
+        if os.path.exists(local_path):
+            print(f"File already exists locally: {local_path}")
+            return True
+
+        blob = bucket.blob(gcs_path)
+        blob.download_to_filename(local_path)
+        print(f"✅ Downloaded {gcs_path} → {local_path}")
+        return True
+    except Exception as e:
+        print(f"❌ Error downloading {gcs_path}: {str(e)}")
+        return False
+
 def average_pool(last_hidden_states, attention_mask):
     """Average pooling for sentence embeddings."""
     last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
     return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

+# In-memory cache for query embeddings
 query_embedding_cache = {}

 def get_embedding(text):
+    """Generate embeddings for a text query using the cached model."""
     if text in query_embedding_cache:
         return query_embedding_cache[text]

     try:
+        tokenizer, model = cached_load_model()
         if model is None:
             print("Model is None, returning zero embedding")
             return np.zeros((1, 384), dtype=np.float32)

         input_text = f"query: {text}" if len(text) < 512 else f"passage: {text}"
         inputs = tokenizer(
             input_text,
             padding=True,
@@ -226,19 +198,13 @@ def get_embedding(text):
             max_length=512,
             return_attention_mask=True
         )
         with torch.no_grad():
             outputs = model(**inputs)
             embeddings = average_pool(outputs.last_hidden_state, inputs['attention_mask'])
             embeddings = nn.functional.normalize(embeddings, p=2, dim=1)
             embeddings = embeddings.detach().cpu().numpy()
         del outputs, inputs
         gc.collect()
         query_embedding_cache[text] = embeddings
         return embeddings
     except Exception as e:
@@ -246,49 +212,31 @@ def get_embedding(text):
         return np.zeros((1, 384), dtype=np.float32)

 def retrieve_passages(query, faiss_index, text_chunks, metadata_dict, top_k=5, similarity_threshold=0.5):
+    """Retrieve top-k most relevant passages using FAISS and accompanying metadata."""
     try:
         print(f"\n🔍 Retrieving passages for query: {query}")
         query_embedding = get_embedding(query)
         distances, indices = faiss_index.search(query_embedding, top_k * 2)
         print(f"Found {len(distances[0])} potential matches")
         retrieved_passages = []
         retrieved_sources = []
         cited_titles = set()
         for dist, idx in zip(distances[0], indices[0]):
             print(f"Distance: {dist:.4f}, Index: {idx}")
             if idx in text_chunks and dist >= similarity_threshold:
                 title_with_txt, author, text = text_chunks[idx]
                 clean_title = title_with_txt.replace(".txt", "") if title_with_txt.endswith(".txt") else title_with_txt
                 clean_title = unicodedata.normalize("NFC", clean_title)
                 if clean_title in cited_titles:
                     continue
                 metadata_entry = metadata_dict.get(clean_title, {})
                 author = metadata_entry.get("Author", "Unknown")
                 publisher = metadata_entry.get("Publisher", "Unknown")
                 cited_titles.add(clean_title)
                 retrieved_passages.append(text)
+                retrieved_sources.append((clean_title, author, publisher))
                 if len(retrieved_passages) == top_k:
                     break
         print(f"Retrieved {len(retrieved_passages)} passages")
         return retrieved_passages, retrieved_sources
     except Exception as e:
@@ -296,39 +244,31 @@ def retrieve_passages(query, faiss_index, text_chunks, metadata_dict, top_k=5, s
         return [], []

 def answer_with_llm(query, context=None, word_limit=100):
+    """Generate an answer using the OpenAI GPT model with formatted citations."""
     try:
         if context:
             formatted_contexts = []
             total_chars = 0
             max_context_chars = 4000
             for (title, author, publisher), text in context:
                 remaining_space = max(0, max_context_chars - total_chars)
                 excerpt_len = min(150, remaining_space)
                 if excerpt_len > 50:
                     excerpt = text[:excerpt_len].strip() + "..." if len(text) > excerpt_len else text
                     formatted_context = f"[{title} by {author}, Published by {publisher}] {excerpt}"
                     formatted_contexts.append(formatted_context)
                     total_chars += len(formatted_context)
                 if total_chars >= max_context_chars:
                     break
             formatted_context = "\n".join(formatted_contexts)
         else:
             formatted_context = "No relevant information available."

         system_message = (
             "You are an AI specialized in Indian spiritual texts. "
             "Answer based on context, summarizing ideas rather than quoting verbatim. "
             "Ensure proper citation and do not include direct excerpts."
         )
         user_message = f"""
         Context:
         {formatted_context}
@@ -336,14 +276,12 @@ def answer_with_llm(query, context=None, word_limit=100):
         {query}
         """

         try:
             llm_model = st.secrets["LLM_MODEL"]
         except KeyError:
             print("❌ Error: LLM model not found in secrets")
             return "I apologize, but I'm unable to answer at the moment."

         response = openai.chat.completions.create(
             model=llm_model,
             messages=[
@@ -353,50 +291,42 @@ def answer_with_llm(query, context=None, word_limit=100):
             max_tokens=200,
             temperature=0.7
         )
         answer = response.choices[0].message.content.strip()
         words = answer.split()
         if len(words) > word_limit:
             answer = " ".join(words[:word_limit])
             if not answer.endswith((".", "!", "?")):
                 answer += "."
         return answer
     except Exception as e:
         print(f"❌ LLM API error: {str(e)}")
         return "I apologize, but I'm unable to answer at the moment."

 def format_citations(sources):
+    """Format citations so that each appears on a new line, ending with proper punctuation."""
     formatted_citations = []
     for title, author, publisher in sources:
         if publisher.endswith(('.', '!', '?')):
             formatted_citations.append(f"📚 {title} by {author}, Published by {publisher}")
         else:
             formatted_citations.append(f"📚 {title} by {author}, Published by {publisher}.")
     return "\n".join(formatted_citations)

+# =============================================================================
+# DATA CACHING FOR QUERY RESULTS
+# =============================================================================
+
+@st.cache_data(ttl=3600, show_spinner=False)
+def cached_process_query(query, top_k=5, word_limit=100):
+    """Cached query processing to avoid redundant computation for repeated queries."""
+    print(f"\n🔍 Processing query (cached): {query}")
+    faiss_index, text_chunks, metadata_dict = cached_load_data_files()
     if faiss_index is None or text_chunks is None or metadata_dict is None:
         return {
             "query": query,
             "answer_with_rag": "⚠️ System error: Data files not loaded properly.",
             "citations": "No citations available."
         }
     retrieved_context, retrieved_sources = retrieve_passages(
         query,
         faiss_index,
@@ -404,15 +334,16 @@ def process_query(query, top_k=5, word_limit=100):
         metadata_dict,
         top_k=top_k
     )
     sources = format_citations(retrieved_sources) if retrieved_sources else "No citation available."
     if retrieved_context:
         context_with_sources = list(zip(retrieved_sources, retrieved_context))
         llm_answer_with_rag = answer_with_llm(query, context_with_sources, word_limit=word_limit)
     else:
         llm_answer_with_rag = "⚠️ No relevant context found."
+    return {"query": query, "answer_with_rag": llm_answer_with_rag, "citations": sources}

+def process_query(query, top_k=5, word_limit=100):
+    """Process a query through the RAG pipeline with proper formatting.
+    This function wraps the cached query processing.
+    """
+    return cached_process_query(query, top_k, word_limit)
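
The public entry point after this change is still process_query(query, top_k, word_limit), now backed by Streamlit's st.cache_data (per-query results, ttl=3600) and st.cache_resource (model, FAISS index, GCS client). A minimal usage sketch follows, assuming the Space has a Streamlit front end such as an app.py that imports this module; the file name, page title, and widget labels below are illustrative and not part of this commit.

import streamlit as st

from rag_engine import process_query

st.title("Ask a question")  # illustrative UI, not part of rag_engine.py

query = st.text_input("Your question")
top_k = st.slider("Passages to retrieve", min_value=1, max_value=10, value=5)

if query:
    # process_query() wraps cached_process_query(), so a repeated question is
    # answered from st.cache_data without re-running retrieval or the OpenAI
    # call, while the embedding model and data files are reused across reruns
    # via st.cache_resource.
    result = process_query(query, top_k=top_k, word_limit=100)
    st.markdown(result["answer_with_rag"])
    st.caption(result["citations"])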