ankanghosh committed on
Commit c72a64f · verified · 1 Parent(s): 77c7ac2

Update rag_engine.py

Files changed (1)
  1. rag_engine.py +174 -19
rag_engine.py CHANGED
@@ -13,10 +13,12 @@ import streamlit as st
 from utils import setup_gcp_auth, setup_openai_auth
 import gc

-# Force model to CPU for stability
+# Force model to CPU for stability and to avoid GPU memory issues on resource-constrained environments
+# This is especially important for deployment on platforms like Hugging Face Spaces
 os.environ["CUDA_VISIBLE_DEVICES"] = ""

-# Local Paths
+# Define local paths for files downloaded from Google Cloud Storage
+# These files are cached locally to avoid repeated downloads and improve performance
 local_embeddings_file = "all_embeddings.npy"
 local_faiss_index_file = "faiss_index.faiss"
 local_text_chunks_file = "text_chunks.txt"
@@ -28,9 +30,18 @@ local_metadata_file = "metadata.jsonl"

 @st.cache_resource(show_spinner=False)
 def cached_load_model():
-    """Cached version of load_model() for embedding model loading."""
+    """
+    Load and cache the E5-large-v2 embedding model and tokenizer.
+
+    Uses Streamlit's cache_resource decorator to ensure the model
+    is loaded only once during the application session, improving
+    performance and reducing memory usage.
+
+    Returns:
+        tuple: (tokenizer, model) pair or (None, None) if loading fails
+    """
     try:
-        # Force model to CPU
+        # Force model to CPU for stability
         device = torch.device("cpu")

         # Get embedding model path from secrets
@@ -47,11 +58,11 @@ def cached_load_model():
             torch_dtype=torch.float16
         )

-        # Move model to CPU and set to eval mode
+        # Move model to CPU and set to eval mode for inference
         model = model.to(device)
         model.eval()

-        # Disable gradient computation
+        # Disable gradient computation to save memory during inference
         torch.set_grad_enabled(False)

         print("✅ Model loaded successfully (cached)")
@@ -63,7 +74,19 @@ def cached_load_model():

 @st.cache_resource(show_spinner=False)
 def cached_load_data_files():
-    """Cached version of load_data_files() for FAISS index, text chunks, and metadata."""
+    """
+    Load and cache data files needed for the RAG system.
+
+    This function loads:
+    - FAISS index for vector similarity search
+    - Text chunks containing the original spiritual text passages
+    - Metadata dictionary with publication and author information
+
+    All files are downloaded from Google Cloud Storage if not already present locally.
+
+    Returns:
+        tuple: (faiss_index, text_chunks, metadata_dict) or (None, None, None) if loading fails
+    """
     # Initialize GCP and OpenAI clients
     bucket = setup_gcp_client()
     openai_initialized = setup_openai_client()
@@ -130,6 +153,14 @@ def cached_load_data_files():
 # =============================================================================

 def setup_gcp_client():
+    """
+    Initialize and return the Google Cloud Storage client.
+
+    Sets up GCP authentication and creates a client for the configured bucket.
+
+    Returns:
+        google.cloud.storage.bucket.Bucket: The GCS bucket object or None if initialization fails
+    """
     try:
         credentials = setup_gcp_auth()
         try:
@@ -147,6 +178,14 @@ def setup_gcp_client():
         return None

 def setup_openai_client():
+    """
+    Initialize the OpenAI client.
+
+    Sets up OpenAI API authentication for generating answers using the LLM.
+
+    Returns:
+        bool: True if initialization was successful, False otherwise
+    """
     try:
         setup_openai_auth()
         print("✅ OpenAI client initialized successfully")
@@ -156,7 +195,19 @@ def setup_openai_client():
         return False

 def download_file_from_gcs(bucket, gcs_path, local_path):
-    """Download a file from GCS to local storage if not already present."""
+    """
+    Download a file from Google Cloud Storage to local storage.
+
+    Only downloads if the file isn't already present locally, avoiding redundant downloads.
+
+    Args:
+        bucket: GCS bucket object
+        gcs_path (str): Path to the file in GCS
+        local_path (str): Local path where the file should be saved
+
+    Returns:
+        bool: True if download was successful or file already exists, False otherwise
+    """
     try:
         if os.path.exists(local_path):
             print(f"File already exists locally: {local_path}")
@@ -171,15 +222,38 @@ def download_file_from_gcs(bucket, gcs_path, local_path):
         return False

 def average_pool(last_hidden_states, attention_mask):
-    """Average pooling for sentence embeddings."""
+    """
+    Perform average pooling on model outputs for sentence embeddings.
+
+    This function creates a fixed-size vector representation of a text sequence by averaging
+    the token embeddings, accounting for padding tokens using the attention mask.
+
+    Args:
+        last_hidden_states: Hidden states output from the model
+        attention_mask: Attention mask indicating which tokens to include
+
+    Returns:
+        torch.Tensor: Pooled representation of the input sequence
+    """
     last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
     return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

-# In-memory cache for query embeddings
+# In-memory cache for query embeddings to avoid redundant computations
 query_embedding_cache = {}

 def get_embedding(text):
-    """Generate embeddings for a text query using the cached model."""
+    """
+    Generate embeddings for a text query using the cached model.
+
+    Uses an in-memory cache to avoid redundant embedding generation for repeated queries.
+    Properly prefixes inputs with "query:" or "passage:" as required by the E5 model.
+
+    Args:
+        text (str): The query text to embed
+
+    Returns:
+        numpy.ndarray: The embedding vector or a zero vector if embedding fails
+    """
     if text in query_embedding_cache:
         return query_embedding_cache[text]

@@ -189,6 +263,8 @@ def get_embedding(text):
             print("Model is None, returning zero embedding")
             return np.zeros((1, 384), dtype=np.float32)

+        # Format input based on text length
+        # For E5 models, "query:" prefix is for questions, "passage:" for documents
         input_text = f"query: {text}" if len(text) < 512 else f"passage: {text}"
         inputs = tokenizer(
             input_text,
@@ -212,7 +288,27 @@ def get_embedding(text):
         return np.zeros((1, 384), dtype=np.float32)

 def retrieve_passages(query, faiss_index, text_chunks, metadata_dict, top_k=5, similarity_threshold=0.5):
-    """Retrieve top-k most relevant passages using FAISS and accompanying metadata."""
+    """
+    Retrieve the most relevant passages for a given spiritual query.
+
+    This function:
+    1. Embeds the user query using the same model used for text chunks
+    2. Finds similar passages using the FAISS index with cosine similarity
+    3. Filters results based on similarity threshold to ensure relevance
+    4. Enriches results with metadata (title, author, publisher)
+    5. Ensures passage diversity by including only one passage per source title
+
+    Args:
+        query (str): The user's spiritual question
+        faiss_index: FAISS index containing passage embeddings
+        text_chunks (dict): Dictionary mapping IDs to text chunks and metadata
+        metadata_dict (dict): Dictionary containing publication information
+        top_k (int): Maximum number of passages to retrieve
+        similarity_threshold (float): Minimum similarity score (0.0-1.0) for retrieved passages
+
+    Returns:
+        tuple: (retrieved_passages, retrieved_sources) containing the text and source information
+    """
     try:
         print(f"\n🔍 Retrieving passages for query: {query}")
         query_embedding = get_embedding(query)
@@ -244,12 +340,32 @@ def retrieve_passages(query, faiss_index, text_chunks, metadata_dict, top_k=5, s
         return [], []

 def answer_with_llm(query, context=None, word_limit=100):
-    """Generate an answer using the OpenAI GPT model with formatted citations."""
+    """
+    Generate an answer using the OpenAI GPT model with formatted citations.
+
+    This function:
+    1. Formats retrieved passages with source information
+    2. Creates a prompt with system and user messages
+    3. Calls the OpenAI API to generate an answer
+    4. Trims the response to the specified word limit
+
+    The system prompt ensures answers maintain appropriate respect for spiritual traditions,
+    synthesize rather than quote directly, and acknowledge gaps when relevant information
+    isn't available.
+
+    Args:
+        query (str): The user's spiritual question
+        context (list, optional): List of (source_info, text) tuples for context
+        word_limit (int): Maximum word count for the generated answer
+
+    Returns:
+        str: The generated answer or an error message
+    """
     try:
         if context:
             formatted_contexts = []
             total_chars = 0
-            max_context_chars = 4000
+            max_context_chars = 4000  # Limit context size to avoid exceeding token limits
             for (title, author, publisher), text in context:
                 remaining_space = max(0, max_context_chars - total_chars)
                 excerpt_len = min(150, remaining_space)
@@ -274,6 +390,7 @@ def answer_with_llm(query, context=None, word_limit=100):
             "Maintain appropriate, respectful language at all times."
             "Do not use profanity, expletives, obscenities, slurs, hate speech, sexually explicit content, or language promoting violence."
             "As a spiritual guidance system, ensure all responses reflect dignity, peace, love, and compassion consistent with spiritual traditions."
+            "Provide concise, focused answers without lists or lengthy explanations."
         )

         user_message = f"""
@@ -298,6 +415,8 @@ def answer_with_llm(query, context=None, word_limit=100):
             max_tokens=200,
             temperature=0.7
         )
+
+        # Extract the answer and apply word limit
         answer = response.choices[0].message.content.strip()
         words = answer.split()
         if len(words) > word_limit:
@@ -310,7 +429,18 @@ def answer_with_llm(query, context=None, word_limit=100):
         return "I apologize, but I'm unable to answer at the moment."

 def format_citations(sources):
-    """Format citations so that each appears on a new line, ending with proper punctuation."""
+    """
+    Format citations for display to the user.
+
+    Creates properly formatted citations for each source used in generating the answer.
+    Each citation appears on a new line with consistent formatting.
+
+    Args:
+        sources (list): List of (title, author, publisher) tuples
+
+    Returns:
+        str: Formatted citations as a string with each citation on a new line
+    """
     formatted_citations = []
     for title, author, publisher in sources:
         if publisher.endswith(('.', '!', '?')):
@@ -325,7 +455,21 @@ def format_citations(sources):

 @st.cache_data(ttl=3600, show_spinner=False)
 def cached_process_query(query, top_k=5, word_limit=100):
-    """Cached query processing to avoid redundant computation for repeated queries."""
+    """
+    Process a user query with caching to avoid redundant computation.
+
+    This function is cached with a time-to-live (TTL) of 1 hour, meaning identical
+    queries within this time period will return cached results rather than
+    reprocessing, improving responsiveness.
+
+    Args:
+        query (str): The user's spiritual question
+        top_k (int): Number of sources to retrieve and use for answer generation
+        word_limit (int): Maximum word count for the generated answer
+
+    Returns:
+        dict: Dictionary containing the query, answer, and citations
+    """
     print(f"\n🔍 Processing query (cached): {query}")
     faiss_index, text_chunks, metadata_dict = cached_load_data_files()
     if faiss_index is None or text_chunks is None or metadata_dict is None:
@@ -350,10 +494,21 @@ def cached_process_query(query, top_k=5, word_limit=100):
     return {"query": query, "answer_with_rag": llm_answer_with_rag, "citations": sources}

 def process_query(query, top_k=5, word_limit=100):
-    """Process a query through the RAG pipeline with proper formatting.
-    This function wraps the cached query processing.
+    """
+    Process a query through the RAG pipeline with proper formatting.
+
+    This is the main entry point for query processing, wrapping the cached
+    query processing function.
+
+    Args:
+        query (str): The user's spiritual question
+        top_k (int): Number of sources to retrieve and use for answer generation
+        word_limit (int): Maximum word count for the generated answer
+
+    Returns:
+        dict: Dictionary containing the query, answer, and citations
     """
     return cached_process_query(query, top_k, word_limit)

-# Alias for backward compatibility.
+# Alias for backward compatibility
 load_model = cached_load_model
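
For reference, a minimal usage sketch of the public entry point documented in this commit. It is not part of the diff; it assumes rag_engine.py is importable and that the Streamlit secrets, GCP, and OpenAI credentials the module expects are already configured, and the example question is purely illustrative.

    # Hypothetical caller script, not included in the commit
    from rag_engine import process_query

    result = process_query("What is the nature of the self?", top_k=5, word_limit=100)
    print(result["answer_with_rag"])  # generated answer text
    print(result["citations"])        # source information used for the answer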