ankanghosh committed
Commit db4ca04 · verified · 1 Parent(s): aeac8df

Update rag_engine.py

Files changed (1)
  1. rag_engine.py +82 -151
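The change replaces the module's manual st.session_state caching with Streamlit's cache decorators: st.cache_resource for heavyweight shared objects (the embedding model, the FAISS index and data files) and st.cache_data for per-query results. A minimal sketch of that pattern, assuming a generic Streamlit app rather than this repository's exact code:

import streamlit as st

@st.cache_resource(show_spinner=False)
def load_heavy_resource():
    # Built once per process and shared across reruns and sessions;
    # suited to models, clients, and index handles.
    return object()

@st.cache_data(ttl=3600, show_spinner=False)
def answer(query: str) -> dict:
    # Memoized per argument set; recomputed after the 1-hour TTL expires.
    return {"query": query, "answer": "..."}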
rag_engine.py CHANGED
@@ -22,46 +22,14 @@ local_faiss_index_file = "faiss_index.faiss"
 local_text_chunks_file = "text_chunks.txt"
 local_metadata_file = "metadata.jsonl"
 
-# Load GCP authentication from utility function
-def setup_gcp_client():
-    try:
-        credentials = setup_gcp_auth()
-
-        # Get bucket name from secrets - required
-        try:
-            bucket_name_gcs = st.secrets["BUCKET_NAME_GCS"]
-        except KeyError:
-            print("❌ Error: GCS bucket name not found in secrets")
-            return None
-
-        storage_client = storage.Client(credentials=credentials)
-        bucket = storage_client.bucket(bucket_name_gcs)
-        print("✅ GCP client initialized successfully")
-        return bucket
-    except Exception as e:
-        print(f"❌ GCP client initialization error: {str(e)}")
-        return None
 
-# Setup OpenAI authentication
-def setup_openai_client():
     try:
-        setup_openai_auth()
-        print("✅ OpenAI client initialized successfully")
-        return True
-    except Exception as e:
-        print(f"❌ OpenAI client initialization error: {str(e)}")
-        return False
-
-def load_model():
-    """Load the embedding model and store in session state"""
-    try:
-        # Check if model already loaded
-        if 'model' in st.session_state and st.session_state.model is not None:
-            print("Model already loaded in session state")
-            return st.session_state.tokenizer, st.session_state.model
-
-        print("Loading new model instance...")
-
         # Force model to CPU
         device = torch.device("cpu")
 
@@ -79,49 +47,24 @@ def load_model():
             torch_dtype=torch.float16
         )
 
-        # Move to CPU and set to eval mode
         model = model.to(device)
         model.eval()
 
         # Disable gradient computation
         torch.set_grad_enabled(False)
 
-        # Store in session state
-        st.session_state.tokenizer = tokenizer
-        st.session_state.model = model
-
-        print("✅ Model loaded successfully")
         return tokenizer, model
 
     except Exception as e:
         print(f"❌ Error loading model: {str(e)}")
-        # Return None values - don't raise exception
         return None, None
 
-def download_file_from_gcs(bucket, gcs_path, local_path):
-    """Download a file from GCS to local storage."""
-    try:
-        # Check if file already exists
-        if os.path.exists(local_path):
-            print(f"File already exists locally: {local_path}")
-            return True
-
-        blob = bucket.blob(gcs_path)
-        blob.download_to_filename(local_path)
-        print(f"✅ Downloaded {gcs_path} → {local_path}")
-        return True
-    except Exception as e:
-        print(f"❌ Error downloading {gcs_path}: {str(e)}")
-        return False
-
-def load_data_files():
-    """Load FAISS index, text chunks, and metadata"""
-    # Check if already loaded in session state
-    if 'faiss_index' in st.session_state and st.session_state.faiss_index is not None:
-        print("Using cached data files from session state")
-        return st.session_state.faiss_index, st.session_state.text_chunks, st.session_state.metadata_dict
-
-    # Initialize clients
     bucket = setup_gcp_client()
     openai_initialized = setup_openai_client()
 
@@ -139,7 +82,7 @@ def load_data_files():
         print(f"❌ Error: Required GCS path not found in secrets: {e}")
         return None, None, None
 
-    # Download necessary files
     success = True
     success &= download_file_from_gcs(bucket, faiss_index_file_gcs, local_faiss_index_file)
     success &= download_file_from_gcs(bucket, text_chunks_file_gcs, local_text_chunks_file)
@@ -158,7 +101,7 @@ def load_data_files():
 
     # Load text chunks
     try:
-        text_chunks = {}  # {ID -> (Title, Author, Text)}
         with open(local_text_chunks_file, "r", encoding="utf-8") as f:
             for line in f:
                 parts = line.strip().split("\t")
@@ -179,45 +122,74 @@ def load_data_files():
         print(f"❌ Error loading metadata: {str(e)}")
         return None, None, None
 
-    print(f"✅ Data loaded successfully: {len(text_chunks)} passages available")
-
-    # Store in session state
-    st.session_state.faiss_index = faiss_index
-    st.session_state.text_chunks = text_chunks
-    st.session_state.metadata_dict = metadata_dict
-
     return faiss_index, text_chunks, metadata_dict
 
 def average_pool(last_hidden_states, attention_mask):
     """Average pooling for sentence embeddings."""
     last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
     return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
 
-# Cache for query embeddings
 query_embedding_cache = {}
 
 def get_embedding(text):
-    """Generate embeddings for a text query"""
-    # Check cache first
     if text in query_embedding_cache:
         return query_embedding_cache[text]
 
     try:
-        # Get model
-        if 'model' not in st.session_state or st.session_state.model is None:
-            tokenizer, model = load_model()
-        else:
-            tokenizer, model = st.session_state.tokenizer, st.session_state.model
-
-        # Handle model load failure
         if model is None:
             print("Model is None, returning zero embedding")
             return np.zeros((1, 384), dtype=np.float32)
 
-        # Prepare text
         input_text = f"query: {text}" if len(text) < 512 else f"passage: {text}"
-
-        # Tokenize
         inputs = tokenizer(
             input_text,
             padding=True,
@@ -226,19 +198,13 @@ def get_embedding(text):
             max_length=512,
             return_attention_mask=True
         )
-
-        # Generate embeddings
         with torch.no_grad():
             outputs = model(**inputs)
             embeddings = average_pool(outputs.last_hidden_state, inputs['attention_mask'])
             embeddings = nn.functional.normalize(embeddings, p=2, dim=1)
             embeddings = embeddings.detach().cpu().numpy()
-
-        # Clean up
         del outputs, inputs
         gc.collect()
-
-        # Cache and return
         query_embedding_cache[text] = embeddings
         return embeddings
     except Exception as e:
@@ -246,49 +212,31 @@ def get_embedding(text):
         return np.zeros((1, 384), dtype=np.float32)
 
 def retrieve_passages(query, faiss_index, text_chunks, metadata_dict, top_k=5, similarity_threshold=0.5):
-    """Retrieve top-k most relevant passages using FAISS with metadata."""
     try:
         print(f"\n🔍 Retrieving passages for query: {query}")
-
-        # Get query embedding
         query_embedding = get_embedding(query)
-
-        # Search in FAISS index
         distances, indices = faiss_index.search(query_embedding, top_k * 2)
-
         print(f"Found {len(distances[0])} potential matches")
         retrieved_passages = []
         retrieved_sources = []
         cited_titles = set()
-
-        # Process results
         for dist, idx in zip(distances[0], indices[0]):
             print(f"Distance: {dist:.4f}, Index: {idx}")
             if idx in text_chunks and dist >= similarity_threshold:
                 title_with_txt, author, text = text_chunks[idx]
-
-                # Clean title
                 clean_title = title_with_txt.replace(".txt", "") if title_with_txt.endswith(".txt") else title_with_txt
                 clean_title = unicodedata.normalize("NFC", clean_title)
-
-                # Skip duplicates
                 if clean_title in cited_titles:
                     continue
-
-                # Get metadata
                 metadata_entry = metadata_dict.get(clean_title, {})
                 author = metadata_entry.get("Author", "Unknown")
                 publisher = metadata_entry.get("Publisher", "Unknown")
-
-                # Add to results
                 cited_titles.add(clean_title)
                 retrieved_passages.append(text)
-                retrieved_sources.append((clean_title, author, publisher))
-
-                # Stop if we have enough
                 if len(retrieved_passages) == top_k:
                     break
-
         print(f"Retrieved {len(retrieved_passages)} passages")
         return retrieved_passages, retrieved_sources
     except Exception as e:
@@ -296,39 +244,31 @@ def retrieve_passages(query, faiss_index, text_chunks, metadata_dict, top_k=5, similarity_threshold=0.5):
         return [], []
 
 def answer_with_llm(query, context=None, word_limit=100):
-    """Generate an answer using OpenAI GPT model with formatted citations."""
     try:
-        # Format context
         if context:
             formatted_contexts = []
             total_chars = 0
             max_context_chars = 4000
-
             for (title, author, publisher), text in context:
                 remaining_space = max(0, max_context_chars - total_chars)
                 excerpt_len = min(150, remaining_space)
-
                 if excerpt_len > 50:
                     excerpt = text[:excerpt_len].strip() + "..." if len(text) > excerpt_len else text
                     formatted_context = f"[{title} by {author}, Published by {publisher}] {excerpt}"
                     formatted_contexts.append(formatted_context)
                     total_chars += len(formatted_context)
-
                 if total_chars >= max_context_chars:
                     break
-
             formatted_context = "\n".join(formatted_contexts)
         else:
             formatted_context = "No relevant information available."
 
-        # System message
         system_message = (
             "You are an AI specialized in Indian spiritual texts. "
             "Answer based on context, summarizing ideas rather than quoting verbatim. "
             "Ensure proper citation and do not include direct excerpts."
         )
-
-        # User message
         user_message = f"""
         Context:
         {formatted_context}
@@ -336,14 +276,12 @@ def answer_with_llm(query, context=None, word_limit=100):
         {query}
         """
 
-        # Get LLM model from secrets
         try:
             llm_model = st.secrets["LLM_MODEL"]
         except KeyError:
             print("❌ Error: LLM model not found in secrets")
             return "I apologize, but I'm unable to answer at the moment."
 
-        # Call OpenAI API
         response = openai.chat.completions.create(
             model=llm_model,
             messages=[
@@ -353,50 +291,42 @@ def answer_with_llm(query, context=None, word_limit=100):
             max_tokens=200,
             temperature=0.7
         )
-
         answer = response.choices[0].message.content.strip()
-
-        # Enforce word limit
         words = answer.split()
         if len(words) > word_limit:
             answer = " ".join(words[:word_limit])
             if not answer.endswith((".", "!", "?")):
                 answer += "."
-
         return answer
-
     except Exception as e:
         print(f"❌ LLM API error: {str(e)}")
         return "I apologize, but I'm unable to answer at the moment."
 
 def format_citations(sources):
-    """Format citations to display each one on a new line with a full stop if needed."""
     formatted_citations = []
     for title, author, publisher in sources:
-        # Check if the publisher already ends with a period, question mark, or exclamation mark
         if publisher.endswith(('.', '!', '?')):
             formatted_citations.append(f"📚 {title} by {author}, Published by {publisher}")
         else:
             formatted_citations.append(f"📚 {title} by {author}, Published by {publisher}.")
-
     return "\n".join(formatted_citations)
 
-def process_query(query, top_k=5, word_limit=100):
-    """Process a query through the RAG pipeline with proper formatting."""
-    print(f"\n🔍 Processing query: {query}")
-
-    # Load data files if not already loaded
-    faiss_index, text_chunks, metadata_dict = load_data_files()
-
-    # Check if data loaded successfully
     if faiss_index is None or text_chunks is None or metadata_dict is None:
         return {
             "query": query,
             "answer_with_rag": "⚠️ System error: Data files not loaded properly.",
             "citations": "No citations available."
         }
-
-    # Get relevant passages
     retrieved_context, retrieved_sources = retrieve_passages(
         query,
         faiss_index,
@@ -404,15 +334,16 @@ def process_query(query, top_k=5, word_limit=100):
         metadata_dict,
         top_k=top_k
     )
-
-    # Format citations
     sources = format_citations(retrieved_sources) if retrieved_sources else "No citation available."
-
-    # Generate answer
    if retrieved_context:
        context_with_sources = list(zip(retrieved_sources, retrieved_context))
        llm_answer_with_rag = answer_with_llm(query, context_with_sources, word_limit=word_limit)
    else:
        llm_answer_with_rag = "⚠️ No relevant context found."
 
-    return {"query": query, "answer_with_rag": llm_answer_with_rag, "citations": sources}
 local_text_chunks_file = "text_chunks.txt"
 local_metadata_file = "metadata.jsonl"
 
+# =============================================================================
+# RESOURCE CACHING
+# =============================================================================
 
+@st.cache_resource(show_spinner=False)
+def cached_load_model():
+    """Cached version of load_model() for embedding model loading."""
     try:
         # Force model to CPU
         device = torch.device("cpu")
 
             torch_dtype=torch.float16
         )
 
+        # Move model to CPU and set to eval mode
         model = model.to(device)
         model.eval()
 
         # Disable gradient computation
         torch.set_grad_enabled(False)
 
+        print("✅ Model loaded successfully (cached)")
         return tokenizer, model
 
     except Exception as e:
         print(f"❌ Error loading model: {str(e)}")
         return None, None
 
+@st.cache_resource(show_spinner=False)
+def cached_load_data_files():
+    """Cached version of load_data_files() for FAISS index, text chunks, and metadata."""
+    # Initialize GCP and OpenAI clients
     bucket = setup_gcp_client()
     openai_initialized = setup_openai_client()
 
         print(f"❌ Error: Required GCS path not found in secrets: {e}")
         return None, None, None
 
+    # Download necessary files if not already present locally
     success = True
     success &= download_file_from_gcs(bucket, faiss_index_file_gcs, local_faiss_index_file)
     success &= download_file_from_gcs(bucket, text_chunks_file_gcs, local_text_chunks_file)
 
     # Load text chunks
     try:
+        text_chunks = {}  # Mapping: ID -> (Title, Author, Text)
         with open(local_text_chunks_file, "r", encoding="utf-8") as f:
             for line in f:
                 parts = line.strip().split("\t")
         print(f"❌ Error loading metadata: {str(e)}")
         return None, None, None
 
+    print(f"✅ Data loaded successfully (cached): {len(text_chunks)} passages available")
     return faiss_index, text_chunks, metadata_dict
 
+# =============================================================================
+# UTILITY FUNCTIONS
+# =============================================================================
+
+def setup_gcp_client():
+    try:
+        credentials = setup_gcp_auth()
+        try:
+            bucket_name_gcs = st.secrets["BUCKET_NAME_GCS"]
+        except KeyError:
+            print("❌ Error: GCS bucket name not found in secrets")
+            return None
+
+        storage_client = storage.Client(credentials=credentials)
+        bucket = storage_client.bucket(bucket_name_gcs)
+        print("✅ GCP client initialized successfully")
+        return bucket
+    except Exception as e:
+        print(f"❌ GCP client initialization error: {str(e)}")
+        return None
+
+def setup_openai_client():
+    try:
+        setup_openai_auth()
+        print("✅ OpenAI client initialized successfully")
+        return True
+    except Exception as e:
+        print(f"❌ OpenAI client initialization error: {str(e)}")
+        return False
+
+def download_file_from_gcs(bucket, gcs_path, local_path):
+    """Download a file from GCS to local storage if not already present."""
+    try:
+        if os.path.exists(local_path):
+            print(f"File already exists locally: {local_path}")
+            return True
+
+        blob = bucket.blob(gcs_path)
+        blob.download_to_filename(local_path)
+        print(f"✅ Downloaded {gcs_path} → {local_path}")
+        return True
+    except Exception as e:
+        print(f"❌ Error downloading {gcs_path}: {str(e)}")
+        return False
+
 def average_pool(last_hidden_states, attention_mask):
     """Average pooling for sentence embeddings."""
     last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
     return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
 
+# In-memory cache for query embeddings
 query_embedding_cache = {}
 
 def get_embedding(text):
+    """Generate embeddings for a text query using the cached model."""
     if text in query_embedding_cache:
         return query_embedding_cache[text]
 
     try:
+        tokenizer, model = cached_load_model()
         if model is None:
             print("Model is None, returning zero embedding")
             return np.zeros((1, 384), dtype=np.float32)
 
         input_text = f"query: {text}" if len(text) < 512 else f"passage: {text}"
         inputs = tokenizer(
             input_text,
             padding=True,
             max_length=512,
             return_attention_mask=True
         )
         with torch.no_grad():
             outputs = model(**inputs)
             embeddings = average_pool(outputs.last_hidden_state, inputs['attention_mask'])
             embeddings = nn.functional.normalize(embeddings, p=2, dim=1)
             embeddings = embeddings.detach().cpu().numpy()
         del outputs, inputs
         gc.collect()
         query_embedding_cache[text] = embeddings
         return embeddings
     except Exception as e:
         return np.zeros((1, 384), dtype=np.float32)
 
 def retrieve_passages(query, faiss_index, text_chunks, metadata_dict, top_k=5, similarity_threshold=0.5):
+    """Retrieve top-k most relevant passages using FAISS and accompanying metadata."""
     try:
         print(f"\n🔍 Retrieving passages for query: {query}")
         query_embedding = get_embedding(query)
         distances, indices = faiss_index.search(query_embedding, top_k * 2)
         print(f"Found {len(distances[0])} potential matches")
         retrieved_passages = []
         retrieved_sources = []
         cited_titles = set()
         for dist, idx in zip(distances[0], indices[0]):
             print(f"Distance: {dist:.4f}, Index: {idx}")
             if idx in text_chunks and dist >= similarity_threshold:
                 title_with_txt, author, text = text_chunks[idx]
                 clean_title = title_with_txt.replace(".txt", "") if title_with_txt.endswith(".txt") else title_with_txt
                 clean_title = unicodedata.normalize("NFC", clean_title)
                 if clean_title in cited_titles:
                     continue
                 metadata_entry = metadata_dict.get(clean_title, {})
                 author = metadata_entry.get("Author", "Unknown")
                 publisher = metadata_entry.get("Publisher", "Unknown")
                 cited_titles.add(clean_title)
                 retrieved_passages.append(text)
+                retrieved_sources.append((clean_title, author, publisher))
                 if len(retrieved_passages) == top_k:
                     break
         print(f"Retrieved {len(retrieved_passages)} passages")
         return retrieved_passages, retrieved_sources
     except Exception as e:
         return [], []
 
 def answer_with_llm(query, context=None, word_limit=100):
+    """Generate an answer using the OpenAI GPT model with formatted citations."""
     try:
         if context:
             formatted_contexts = []
             total_chars = 0
             max_context_chars = 4000
             for (title, author, publisher), text in context:
                 remaining_space = max(0, max_context_chars - total_chars)
                 excerpt_len = min(150, remaining_space)
                 if excerpt_len > 50:
                     excerpt = text[:excerpt_len].strip() + "..." if len(text) > excerpt_len else text
                     formatted_context = f"[{title} by {author}, Published by {publisher}] {excerpt}"
                     formatted_contexts.append(formatted_context)
                     total_chars += len(formatted_context)
                 if total_chars >= max_context_chars:
                     break
             formatted_context = "\n".join(formatted_contexts)
         else:
             formatted_context = "No relevant information available."
 
         system_message = (
             "You are an AI specialized in Indian spiritual texts. "
             "Answer based on context, summarizing ideas rather than quoting verbatim. "
             "Ensure proper citation and do not include direct excerpts."
         )
         user_message = f"""
         Context:
         {formatted_context}
         {query}
         """
 
         try:
             llm_model = st.secrets["LLM_MODEL"]
         except KeyError:
             print("❌ Error: LLM model not found in secrets")
             return "I apologize, but I'm unable to answer at the moment."
 
         response = openai.chat.completions.create(
             model=llm_model,
             messages=[
             max_tokens=200,
             temperature=0.7
         )
         answer = response.choices[0].message.content.strip()
         words = answer.split()
         if len(words) > word_limit:
             answer = " ".join(words[:word_limit])
             if not answer.endswith((".", "!", "?")):
                 answer += "."
         return answer
     except Exception as e:
         print(f"❌ LLM API error: {str(e)}")
         return "I apologize, but I'm unable to answer at the moment."
 
 def format_citations(sources):
+    """Format citations so that each appears on a new line, ending with proper punctuation."""
     formatted_citations = []
     for title, author, publisher in sources:
         if publisher.endswith(('.', '!', '?')):
             formatted_citations.append(f"📚 {title} by {author}, Published by {publisher}")
         else:
             formatted_citations.append(f"📚 {title} by {author}, Published by {publisher}.")
     return "\n".join(formatted_citations)
 
+# =============================================================================
+# DATA CACHING FOR QUERY RESULTS
+# =============================================================================
+
+@st.cache_data(ttl=3600, show_spinner=False)
+def cached_process_query(query, top_k=5, word_limit=100):
+    """Cached query processing to avoid redundant computation for repeated queries."""
+    print(f"\n🔍 Processing query (cached): {query}")
+    faiss_index, text_chunks, metadata_dict = cached_load_data_files()
     if faiss_index is None or text_chunks is None or metadata_dict is None:
         return {
             "query": query,
             "answer_with_rag": "⚠️ System error: Data files not loaded properly.",
             "citations": "No citations available."
         }
     retrieved_context, retrieved_sources = retrieve_passages(
         query,
         faiss_index,
         metadata_dict,
         top_k=top_k
     )
     sources = format_citations(retrieved_sources) if retrieved_sources else "No citation available."
     if retrieved_context:
         context_with_sources = list(zip(retrieved_sources, retrieved_context))
         llm_answer_with_rag = answer_with_llm(query, context_with_sources, word_limit=word_limit)
     else:
         llm_answer_with_rag = "⚠️ No relevant context found."
+    return {"query": query, "answer_with_rag": llm_answer_with_rag, "citations": sources}
 
+def process_query(query, top_k=5, word_limit=100):
+    """Process a query through the RAG pipeline with proper formatting.
+    This function wraps the cached query processing.
+    """
+    return cached_process_query(query, top_k, word_limit)
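For reference, a minimal sketch of how the refactored pipeline might be driven from a Streamlit app; the app file name and UI layout here are assumptions for illustration, not part of this commit:

# Hypothetical app sketch; assumes it sits next to rag_engine.py.
import streamlit as st
from rag_engine import cached_load_model, cached_load_data_files, process_query

st.title("Indian Spiritual Texts Q&A")

# Warm the resource caches once; later reruns reuse the cached objects.
cached_load_model()
cached_load_data_files()

query = st.text_input("Ask a question")
if query:
    result = process_query(query, top_k=5, word_limit=100)
    st.write(result["answer_with_rag"])
    st.text(result["citations"])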