Update rag_engine.py
rag_engine.py  CHANGED  (+174 -19)

@@ -13,10 +13,12 @@ import streamlit as st
 from utils import setup_gcp_auth, setup_openai_auth
 import gc

-# Force model to CPU for stability
+# Force model to CPU for stability and to avoid GPU memory issues on resource-constrained environments
+# This is especially important for deployment on platforms like Hugging Face Spaces
 os.environ["CUDA_VISIBLE_DEVICES"] = ""

-#
+# Define local paths for files downloaded from Google Cloud Storage
+# These files are cached locally to avoid repeated downloads and improve performance
 local_embeddings_file = "all_embeddings.npy"
 local_faiss_index_file = "faiss_index.faiss"
 local_text_chunks_file = "text_chunks.txt"
@@ -28,9 +30,18 @@ local_metadata_file = "metadata.jsonl"

 @st.cache_resource(show_spinner=False)
 def cached_load_model():
-    """
+    """
+    Load and cache the E5-large-v2 embedding model and tokenizer.
+
+    Uses Streamlit's cache_resource decorator to ensure the model
+    is loaded only once during the application session, improving
+    performance and reducing memory usage.
+
+    Returns:
+        tuple: (tokenizer, model) pair or (None, None) if loading fails
+    """
     try:
-        # Force model to CPU
+        # Force model to CPU for stability
         device = torch.device("cpu")

         # Get embedding model path from secrets
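
A minimal sketch of the cached loading pattern this hunk documents, assuming the Hugging Face transformers API; in the real file the checkpoint path comes from Streamlit secrets, so the model name below is only a placeholder:

import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModel

@st.cache_resource(show_spinner=False)
def load_embedding_model(model_name: str = "intfloat/e5-large-v2"):
    # Loaded once per session; later calls reuse the cached (tokenizer, model) pair.
    try:
        device = torch.device("cpu")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModel.from_pretrained(model_name, torch_dtype=torch.float16)
        model = model.to(device)
        model.eval()
        torch.set_grad_enabled(False)  # inference only, no gradients needed
        return tokenizer, model
    except Exception as e:
        print(f"❌ Model loading failed: {e}")
        return None, None
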
@@ -47,11 +58,11 @@ def cached_load_model():
             torch_dtype=torch.float16
         )

-        # Move model to CPU and set to eval mode
+        # Move model to CPU and set to eval mode for inference
         model = model.to(device)
         model.eval()

-        # Disable gradient computation
+        # Disable gradient computation to save memory during inference
         torch.set_grad_enabled(False)

         print("✅ Model loaded successfully (cached)")
@@ -63,7 +74,19 @@ def cached_load_model():

 @st.cache_resource(show_spinner=False)
 def cached_load_data_files():
-    """
+    """
+    Load and cache data files needed for the RAG system.
+
+    This function loads:
+    - FAISS index for vector similarity search
+    - Text chunks containing the original spiritual text passages
+    - Metadata dictionary with publication and author information
+
+    All files are downloaded from Google Cloud Storage if not already present locally.
+
+    Returns:
+        tuple: (faiss_index, text_chunks, metadata_dict) or (None, None, None) if loading fails
+    """
     # Initialize GCP and OpenAI clients
     bucket = setup_gcp_client()
     openai_initialized = setup_openai_client()
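
For reference, loading the three cached artifacts named by the module-level constants can look roughly like the sketch below once they are on local disk; the exact on-disk formats of the chunk and metadata files are assumptions, not read from this diff:

import json
import faiss

def load_local_artifacts():
    # FAISS index built over the passage embeddings.
    index = faiss.read_index("faiss_index.faiss")

    # Assumed format: one text chunk per line.
    with open("text_chunks.txt", encoding="utf-8") as f:
        text_chunks = {i: line.rstrip("\n") for i, line in enumerate(f)}

    # Assumed format: JSON Lines, one metadata object per line.
    metadata_dict = {}
    with open("metadata.jsonl", encoding="utf-8") as f:
        for i, line in enumerate(f):
            metadata_dict[i] = json.loads(line)

    return index, text_chunks, metadata_dict
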
@@ -130,6 +153,14 @@ def cached_load_data_files():
 # =============================================================================

 def setup_gcp_client():
+    """
+    Initialize and return the Google Cloud Storage client.
+
+    Sets up GCP authentication and creates a client for the configured bucket.
+
+    Returns:
+        google.cloud.storage.bucket.Bucket: The GCS bucket object or None if initialization fails
+    """
     try:
         credentials = setup_gcp_auth()
         try:
@@ -147,6 +178,14 @@ def setup_gcp_client():
         return None

 def setup_openai_client():
+    """
+    Initialize the OpenAI client.
+
+    Sets up OpenAI API authentication for generating answers using the LLM.
+
+    Returns:
+        bool: True if initialization was successful, False otherwise
+    """
     try:
         setup_openai_auth()
         print("✅ OpenAI client initialized successfully")
@@ -156,7 +195,19 @@ def setup_openai_client():
         return False

 def download_file_from_gcs(bucket, gcs_path, local_path):
-    """
+    """
+    Download a file from Google Cloud Storage to local storage.
+
+    Only downloads if the file isn't already present locally, avoiding redundant downloads.
+
+    Args:
+        bucket: GCS bucket object
+        gcs_path (str): Path to the file in GCS
+        local_path (str): Local path where the file should be saved
+
+    Returns:
+        bool: True if download was successful or file already exists, False otherwise
+    """
     try:
         if os.path.exists(local_path):
             print(f"File already exists locally: {local_path}")
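
The download-if-missing behaviour described in this docstring maps onto the standard google-cloud-storage client roughly as follows (a sketch, not the file's exact implementation):

import os
from google.cloud import storage

def fetch_if_missing(bucket: storage.Bucket, gcs_path: str, local_path: str) -> bool:
    # Skip the network call when an earlier run already fetched the file.
    if os.path.exists(local_path):
        return True
    try:
        blob = bucket.blob(gcs_path)
        blob.download_to_filename(local_path)
        return True
    except Exception as e:
        print(f"❌ Download failed for {gcs_path}: {e}")
        return False
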
@@ -171,15 +222,38 @@ def download_file_from_gcs(bucket, gcs_path, local_path):
         return False

 def average_pool(last_hidden_states, attention_mask):
-    """
+    """
+    Perform average pooling on model outputs for sentence embeddings.
+
+    This function creates a fixed-size vector representation of a text sequence by averaging
+    the token embeddings, accounting for padding tokens using the attention mask.
+
+    Args:
+        last_hidden_states: Hidden states output from the model
+        attention_mask: Attention mask indicating which tokens to include
+
+    Returns:
+        torch.Tensor: Pooled representation of the input sequence
+    """
     last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
     return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

-# In-memory cache for query embeddings
+# In-memory cache for query embeddings to avoid redundant computations
 query_embedding_cache = {}

 def get_embedding(text):
-    """
+    """
+    Generate embeddings for a text query using the cached model.
+
+    Uses an in-memory cache to avoid redundant embedding generation for repeated queries.
+    Properly prefixes inputs with "query:" or "passage:" as required by the E5 model.
+
+    Args:
+        text (str): The query text to embed
+
+    Returns:
+        numpy.ndarray: The embedding vector or a zero vector if embedding fails
+    """
     if text in query_embedding_cache:
         return query_embedding_cache[text]

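
The pooling added here is standard masked mean pooling; the toy example below runs the same computation on dummy tensors so the role of the attention mask is explicit:

import torch

def average_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # Zero out padding positions, then divide by the count of real tokens.
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

# Dummy batch: 2 sequences, 4 positions, hidden size 3; the second sequence has 2 padding tokens.
hidden = torch.randn(2, 4, 3)
mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])
print(average_pool(hidden, mask).shape)  # torch.Size([2, 3])
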
@@ -189,6 +263,8 @@ def get_embedding(text):
         print("Model is None, returning zero embedding")
         return np.zeros((1, 384), dtype=np.float32)

+    # Format input based on text length
+    # For E5 models, "query:" prefix is for questions, "passage:" for documents
     input_text = f"query: {text}" if len(text) < 512 else f"passage: {text}"
     inputs = tokenizer(
         input_text,
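
Putting the prefix convention, the tokenizer call, and the in-memory cache together, the embedding path sketched under the same assumptions (E5-style model, 512-token truncation, L2-normalised vectors so inner-product search behaves like cosine similarity) looks roughly like this; it is an illustration, not the file's exact code:

import numpy as np
import torch
import torch.nn.functional as F

query_embedding_cache = {}

def embed_query(text, tokenizer, model):
    # Reuse previously computed embeddings for identical queries.
    if text in query_embedding_cache:
        return query_embedding_cache[text]

    # E5 convention: questions get the "query:" prefix, documents get "passage:".
    inputs = tokenizer(f"query: {text}", max_length=512, truncation=True,
                       padding=True, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Masked mean pooling, equivalent to the file's average_pool helper.
    mask = inputs["attention_mask"]
    hidden = outputs.last_hidden_state.masked_fill(~mask[..., None].bool(), 0.0)
    pooled = hidden.sum(dim=1) / mask.sum(dim=1)[..., None]

    # Assumed: normalise so FAISS inner-product scores approximate cosine similarity.
    embedding = F.normalize(pooled, p=2, dim=1).cpu().numpy().astype(np.float32)
    query_embedding_cache[text] = embedding
    return embedding
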
@@ -212,7 +288,27 @@ def get_embedding(text):
         return np.zeros((1, 384), dtype=np.float32)

 def retrieve_passages(query, faiss_index, text_chunks, metadata_dict, top_k=5, similarity_threshold=0.5):
-    """
+    """
+    Retrieve the most relevant passages for a given spiritual query.
+
+    This function:
+    1. Embeds the user query using the same model used for text chunks
+    2. Finds similar passages using the FAISS index with cosine similarity
+    3. Filters results based on similarity threshold to ensure relevance
+    4. Enriches results with metadata (title, author, publisher)
+    5. Ensures passage diversity by including only one passage per source title
+
+    Args:
+        query (str): The user's spiritual question
+        faiss_index: FAISS index containing passage embeddings
+        text_chunks (dict): Dictionary mapping IDs to text chunks and metadata
+        metadata_dict (dict): Dictionary containing publication information
+        top_k (int): Maximum number of passages to retrieve
+        similarity_threshold (float): Minimum similarity score (0.0-1.0) for retrieved passages
+
+    Returns:
+        tuple: (retrieved_passages, retrieved_sources) containing the text and source information
+    """
     try:
         print(f"\n🔍 Retrieving passages for query: {query}")
         query_embedding = get_embedding(query)
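
A condensed sketch of the retrieval loop that steps 1 to 5 of this docstring describe, assuming an inner-product FAISS index over normalised embeddings (so higher scores mean more similar) and simple dict lookups for text_chunks and metadata_dict:

import numpy as np

def retrieve(query_embedding, faiss_index, text_chunks, metadata_dict,
             top_k=5, similarity_threshold=0.5):
    # Step 2: nearest-neighbour search over the passage embeddings.
    scores, indices = faiss_index.search(np.asarray(query_embedding, dtype=np.float32), top_k)

    passages, sources, seen_titles = [], [], set()
    for score, idx in zip(scores[0], indices[0]):
        # Step 3: drop missing results and weak matches.
        if idx < 0 or score < similarity_threshold:
            continue
        # Step 4: attach (title, author, publisher) metadata.
        meta = metadata_dict.get(int(idx), {})
        source = (meta.get("title", "Unknown"), meta.get("author", "Unknown"),
                  meta.get("publisher", "Unknown"))
        # Step 5: keep at most one passage per source title for diversity.
        if source[0] in seen_titles:
            continue
        seen_titles.add(source[0])
        passages.append(text_chunks[int(idx)])
        sources.append(source)
    return passages, sources
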
@@ -244,12 +340,32 @@ def retrieve_passages(query, faiss_index, text_chunks, metadata_dict, top_k=5, similarity_threshold=0.5):
         return [], []

 def answer_with_llm(query, context=None, word_limit=100):
-    """
+    """
+    Generate an answer using the OpenAI GPT model with formatted citations.
+
+    This function:
+    1. Formats retrieved passages with source information
+    2. Creates a prompt with system and user messages
+    3. Calls the OpenAI API to generate an answer
+    4. Trims the response to the specified word limit
+
+    The system prompt ensures answers maintain appropriate respect for spiritual traditions,
+    synthesize rather than quote directly, and acknowledge gaps when relevant information
+    isn't available.
+
+    Args:
+        query (str): The user's spiritual question
+        context (list, optional): List of (source_info, text) tuples for context
+        word_limit (int): Maximum word count for the generated answer
+
+    Returns:
+        str: The generated answer or an error message
+    """
     try:
         if context:
             formatted_contexts = []
             total_chars = 0
-            max_context_chars = 4000
+            max_context_chars = 4000  # Limit context size to avoid exceeding token limits
             for (title, author, publisher), text in context:
                 remaining_space = max(0, max_context_chars - total_chars)
                 excerpt_len = min(150, remaining_space)
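
The budgeting lines visible in this hunk amount to the pattern below; the 4000-character budget and 150-character excerpts come from the diff, while the exact citation template inside the f-string is an assumption:

def build_context_block(context, max_context_chars=4000, excerpt_chars=150):
    # Greedily append short excerpts until the character budget is spent,
    # keeping the prompt safely under the model's token limit.
    formatted_contexts, total_chars = [], 0
    for (title, author, publisher), text in context:
        remaining_space = max(0, max_context_chars - total_chars)
        if remaining_space == 0:
            break
        excerpt = text[:min(excerpt_chars, remaining_space)]
        entry = f"[{title} by {author}, {publisher}] {excerpt}"
        formatted_contexts.append(entry)
        total_chars += len(entry)
    return "\n\n".join(formatted_contexts)
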
@@ -274,6 +390,7 @@ def answer_with_llm(query, context=None, word_limit=100):
             "Maintain appropriate, respectful language at all times."
             "Do not use profanity, expletives, obscenities, slurs, hate speech, sexually explicit content, or language promoting violence."
             "As a spiritual guidance system, ensure all responses reflect dignity, peace, love, and compassion consistent with spiritual traditions."
+            "Provide concise, focused answers without lists or lengthy explanations."
         )

         user_message = f"""
@@ -298,6 +415,8 @@
             max_tokens=200,
             temperature=0.7
         )
+
+        # Extract the answer and apply word limit
         answer = response.choices[0].message.content.strip()
         words = answer.split()
         if len(words) > word_limit:
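
The word-limit step added after the API call is plain whitespace truncation; roughly:

def trim_to_word_limit(answer: str, word_limit: int = 100) -> str:
    # The real code applies this to the stripped message content returned by the API.
    words = answer.split()
    return " ".join(words[:word_limit]) if len(words) > word_limit else answer
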
@@ -310,7 +429,18 @@
         return "I apologize, but I'm unable to answer at the moment."

 def format_citations(sources):
-    """
+    """
+    Format citations for display to the user.
+
+    Creates properly formatted citations for each source used in generating the answer.
+    Each citation appears on a new line with consistent formatting.
+
+    Args:
+        sources (list): List of (title, author, publisher) tuples
+
+    Returns:
+        str: Formatted citations as a string with each citation on a new line
+    """
     formatted_citations = []
     for title, author, publisher in sources:
         if publisher.endswith(('.', '!', '?')):
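
A sketch of the per-line citation formatting this docstring describes; the punctuation check mirrors the code visible in the hunk, while the 📚 prefix and "by" template are assumptions:

def format_citations_sketch(sources):
    # One citation per line; avoid doubling the final punctuation when the
    # publisher string already ends with '.', '!' or '?'.
    lines = []
    for title, author, publisher in sources:
        end = "" if publisher.endswith(('.', '!', '?')) else "."
        lines.append(f"📚 {title} by {author}, {publisher}{end}")
    return "\n".join(lines)
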
@@ -325,7 +455,21 @@ def format_citations(sources):

 @st.cache_data(ttl=3600, show_spinner=False)
 def cached_process_query(query, top_k=5, word_limit=100):
-    """
+    """
+    Process a user query with caching to avoid redundant computation.
+
+    This function is cached with a time-to-live (TTL) of 1 hour, meaning identical
+    queries within this time period will return cached results rather than
+    reprocessing, improving responsiveness.
+
+    Args:
+        query (str): The user's spiritual question
+        top_k (int): Number of sources to retrieve and use for answer generation
+        word_limit (int): Maximum word count for the generated answer
+
+    Returns:
+        dict: Dictionary containing the query, answer, and citations
+    """
     print(f"\n🔍 Processing query (cached): {query}")
     faiss_index, text_chunks, metadata_dict = cached_load_data_files()
     if faiss_index is None or text_chunks is None or metadata_dict is None:
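
For contrast with the cache_resource decorator used for the model and data files, the query-level cache uses Streamlit's data cache with a one-hour TTL; a minimal illustration with a stubbed-out body:

import streamlit as st

@st.cache_data(ttl=3600, show_spinner=False)
def cached_answer(query: str, top_k: int = 5, word_limit: int = 100) -> dict:
    # Identical (query, top_k, word_limit) calls within an hour return this
    # cached dict instead of re-running retrieval and generation.
    return {"query": query, "answer_with_rag": "...", "citations": ""}
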
@@ -350,10 +494,21 @@
     return {"query": query, "answer_with_rag": llm_answer_with_rag, "citations": sources}

 def process_query(query, top_k=5, word_limit=100):
-    """
-
+    """
+    Process a query through the RAG pipeline with proper formatting.
+
+    This is the main entry point for query processing, wrapping the cached
+    query processing function.
+
+    Args:
+        query (str): The user's spiritual question
+        top_k (int): Number of sources to retrieve and use for answer generation
+        word_limit (int): Maximum word count for the generated answer
+
+    Returns:
+        dict: Dictionary containing the query, answer, and citations
     """
     return cached_process_query(query, top_k, word_limit)

-# Alias for backward compatibility
+# Alias for backward compatibility
 load_model = cached_load_model