ankanghosh committed on
Commit c440b0b · verified · 1 Parent(s): 1a4c2ba

Update rag_engine.py

Files changed (1)
  1. rag_engine.py +70 -52
rag_engine.py CHANGED
@@ -11,6 +11,7 @@ import textwrap
 import unicodedata
 import streamlit as st
 from utils import setup_gcp_auth, setup_openai_auth
+import gc  # Added for explicit garbage collection
 
 # Force model to CPU for stability
 os.environ["CUDA_VISIBLE_DEVICES"] = ""
@@ -22,6 +23,7 @@ def initialize_session_state():
     st.session_state.model = None
     st.session_state.tokenizer = None
     st.session_state.device = torch.device("cpu")
+    st.session_state.data_loaded = False
     print("Initialized session state variables")
 
 # Call the initialization function right away
@@ -38,7 +40,6 @@ def setup_gcp_client():
         return bucket
     except Exception as e:
         print(f"❌ GCP client initialization error: {str(e)}")
-        st.error(f"GCP client initialization error: {str(e)}")
         return None
 
 # Setup OpenAI authentication
@@ -49,7 +50,6 @@ def setup_openai_client():
         return True
     except Exception as e:
         print(f"❌ OpenAI client initialization error: {str(e)}")
-        st.error(f"OpenAI client initialization error: {str(e)}")
         return False
 
 # GCS Paths
@@ -66,32 +66,36 @@ local_metadata_file = "metadata.jsonl"
 
 def load_model():
     try:
-        # Initialize model if it doesn't exist
-        if 'model' not in st.session_state or st.session_state.model is None:
-            # Force model to CPU - more stable than GPU for this use case
-            os.environ["CUDA_VISIBLE_DEVICES"] = ""
-
-            print("Loading tokenizer...")
-            tokenizer = AutoTokenizer.from_pretrained("intfloat/e5-small-v2")
-
-            print("Loading model...")
-            model = AutoModel.from_pretrained(
-                "intfloat/e5-small-v2",
-                torch_dtype=torch.float16  # Use half precision
-            )
-
-            # Move model to CPU explicitly
-            model = model.to('cpu')
-            model.eval()
-            torch.set_grad_enabled(False)
-
-            # Store in session state
-            st.session_state.tokenizer = tokenizer
-            st.session_state.model = model
-
-            print("✅ Model loaded successfully")
+        # Check if model is already loaded
+        if st.session_state.model is not None and st.session_state.tokenizer is not None:
+            print("Model already loaded, reusing existing instance")
+            return st.session_state.tokenizer, st.session_state.model
+
+        # Force model to CPU - more stable than GPU for this use case
+        os.environ["CUDA_VISIBLE_DEVICES"] = ""
+
+        print("Loading tokenizer...")
+        tokenizer = AutoTokenizer.from_pretrained("intfloat/e5-small-v2")
+
+        print("Loading model...")
+        model = AutoModel.from_pretrained(
+            "intfloat/e5-small-v2",
+            torch_dtype=torch.float16  # Use half precision
+        )
+
+        # Move model to CPU explicitly
+        model = model.to('cpu')
+        model.eval()
+        torch.set_grad_enabled(False)
+
+        # Store in session state
+        st.session_state.tokenizer = tokenizer
+        st.session_state.model = model
+        st.session_state.model_initialized = True
+
+        print("✅ Model loaded successfully")
 
-        return st.session_state.tokenizer, st.session_state.model
+        return tokenizer, model
     except Exception as e:
         print(f"❌ Error loading model: {str(e)}")
         # Return None values instead of raising to avoid crashing
@@ -100,32 +104,41 @@ def load_model():
 def download_file_from_gcs(bucket, gcs_path, local_path):
     """Download a file from GCS to local storage."""
     try:
+        # Check if file already exists locally
+        if os.path.exists(local_path):
+            print(f"File already exists locally: {local_path}")
+            return True
+
         blob = bucket.blob(gcs_path)
         blob.download_to_filename(local_path)
         print(f"✅ Downloaded {gcs_path} → {local_path}")
         return True
     except Exception as e:
         print(f"❌ Error downloading {gcs_path}: {str(e)}")
-        st.error(f"Error downloading {gcs_path}: {str(e)}")
         return False
 
 def load_data_files():
+    # Check if already loaded in session state
+    if hasattr(st.session_state, 'faiss_index') and st.session_state.faiss_index is not None:
+        print("Using cached data files from session state")
+        return st.session_state.faiss_index, st.session_state.text_chunks, st.session_state.metadata_dict
+
     # Initialize GCP and OpenAI clients
     bucket = setup_gcp_client()
     openai_initialized = setup_openai_client()
 
     if not bucket or not openai_initialized:
-        st.error("Failed to initialize required services")
+        print("Failed to initialize required services")
         return None, None, None
 
-    # Download necessary files - remove the spinner from here
+    # Download necessary files
     success = True
     success &= download_file_from_gcs(bucket, faiss_index_file_gcs, local_faiss_index_file)
     success &= download_file_from_gcs(bucket, text_chunks_file_gcs, local_text_chunks_file)
     success &= download_file_from_gcs(bucket, metadata_file_gcs, local_metadata_file)
 
     if not success:
-        st.error("Failed to download required files")
+        print("Failed to download required files")
         return None, None, None
 
     # Load FAISS index
@@ -133,7 +146,6 @@ def load_data_files():
         faiss_index = faiss.read_index(local_faiss_index_file)
     except Exception as e:
         print(f"❌ Error loading FAISS index: {str(e)}")
-        st.error(f"Error loading FAISS index: {str(e)}")
         return None, None, None
 
     # Load text chunks
@@ -146,7 +158,6 @@ def load_data_files():
                 text_chunks[int(parts[0])] = (parts[1], parts[2], parts[3])
     except Exception as e:
         print(f"❌ Error loading text chunks: {str(e)}")
-        st.error(f"Error loading text chunks: {str(e)}")
         return None, None, None
 
     # Load metadata.jsonl for publisher information
@@ -158,10 +169,16 @@ def load_data_files():
                 metadata_dict[item["Title"]] = item  # Store for easy lookup
     except Exception as e:
         print(f"❌ Error loading metadata: {str(e)}")
-        st.error(f"Error loading metadata: {str(e)}")
         return None, None, None
 
     print(f"✅ FAISS index and text chunks loaded. {len(text_chunks)} passages available.")
+
+    # Store in session state
+    st.session_state.faiss_index = faiss_index
+    st.session_state.text_chunks = text_chunks
+    st.session_state.metadata_dict = metadata_dict
+    st.session_state.data_loaded = True
+
     return faiss_index, text_chunks, metadata_dict
 
 def average_pool(last_hidden_states, attention_mask):
@@ -177,16 +194,17 @@ def get_embedding(text):
 
     try:
         # Ensure model initialization
-        if 'model' not in st.session_state or st.session_state.model is None:
+        if not hasattr(st.session_state, 'model') or st.session_state.model is None:
             tokenizer, model = load_model()
             if model is None:
-                return np.zeros((1, 384), dtype=np.float32)  # Fallback for e5-small-v2
+                return np.zeros((1, 384), dtype=np.float32)
         else:
             tokenizer, model = st.session_state.tokenizer, st.session_state.model
 
+        # Prepare text
         input_text = f"query: {text}" if len(text) < 512 else f"passage: {text}"
 
-        # Explicitly specify truncation parameters to avoid warnings
+        # Explicitly specify truncation parameters
         inputs = tokenizer(
             input_text,
             padding=True,
@@ -196,26 +214,25 @@ def get_embedding(text):
             return_attention_mask=True
         )
 
-        # Move to CPU explicitly before processing
+        # Move to CPU explicitly
        inputs = {k: v.to('cpu') for k, v in inputs.items()}
 
        with torch.no_grad():
            outputs = model(**inputs)
            embeddings = average_pool(outputs.last_hidden_state, inputs['attention_mask'])
            embeddings = nn.functional.normalize(embeddings, p=2, dim=1)
-            # Ensure we detach and move to numpy on CPU
            embeddings = embeddings.detach().cpu().numpy()
 
        # Explicitly clean up
-        del outputs
+        del outputs, inputs
+        gc.collect()
        torch.cuda.empty_cache() if torch.cuda.is_available() else None
 
        query_embedding_cache[text] = embeddings
        return embeddings
    except Exception as e:
        print(f"❌ Embedding error: {str(e)}")
-        st.error(f"Embedding error: {str(e)}")
-        return np.zeros((1, 384), dtype=np.float32)  # Changed from 1024 to 384 for e5-small-v2
+        return np.zeros((1, 384), dtype=np.float32)
 
 def retrieve_passages(query, faiss_index, text_chunks, metadata_dict, top_k=5, similarity_threshold=0.5):
     """Retrieve top-k most relevant passages using FAISS with metadata."""
@@ -242,6 +259,7 @@ def retrieve_passages(query, faiss_index, text_chunks, metadata_dict, top_k=5, s
             if clean_title in cited_titles:
                 continue
 
+            # Get metadata safely
             metadata_entry = metadata_dict.get(clean_title, {})
             author = metadata_entry.get("Author", "Unknown")
             publisher = metadata_entry.get("Publisher", "Unknown")
@@ -258,7 +276,6 @@ def retrieve_passages(query, faiss_index, text_chunks, metadata_dict, top_k=5, s
         return retrieved_passages, retrieved_sources
     except Exception as e:
         print(f"❌ Error in retrieve_passages: {str(e)}")
-        st.error(f"Error in retrieve_passages: {str(e)}")
         return [], []
 
 def answer_with_llm(query, context=None, word_limit=100):
@@ -325,7 +342,6 @@ def answer_with_llm(query, context=None, word_limit=100):
 
     except Exception as e:
         print(f"❌ LLM API error: {str(e)}")
-        st.error(f"LLM API error: {str(e)}")
         return "I apologize, but I'm unable to answer at the moment."
 
 def format_citations(sources):
@@ -345,19 +361,21 @@ def process_query(query, top_k=5, word_limit=100):
     print(f"\n🔍 Processing query: {query}")
 
     # Load data files if not already loaded
-    if not hasattr(st.session_state, 'data_loaded') or not st.session_state.data_loaded:
-        st.session_state.faiss_index, st.session_state.text_chunks, st.session_state.metadata_dict = load_data_files()
-        st.session_state.data_loaded = True
+    faiss_index, text_chunks, metadata_dict = load_data_files()
 
     # Check if data loaded successfully
-    if not st.session_state.faiss_index or not st.session_state.text_chunks or not st.session_state.metadata_dict:
-        return {"query": query, "answer_with_rag": "⚠️ System error: Data files not loaded properly.", "citations": "No citations available."}
+    if faiss_index is None or text_chunks is None or metadata_dict is None:
+        return {
+            "query": query,
+            "answer_with_rag": "⚠️ System error: Data files not loaded properly.",
+            "citations": "No citations available."
+        }
 
     retrieved_context, retrieved_sources = retrieve_passages(
         query,
-        st.session_state.faiss_index,
-        st.session_state.text_chunks,
-        st.session_state.metadata_dict,
+        faiss_index,
+        text_chunks,
+        metadata_dict,
         top_k=top_k
     )
 
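For illustration, a minimal sketch of how the refactored flow could be driven from a Streamlit app. This assumes rag_engine.py is importable by the app and that the success path of process_query returns the same keys as the error path shown above ("query", "answer_with_rag", "citations"); the caller script and widget choices are illustrative, not part of this commit.

    # app_sketch.py (hypothetical caller, not part of this commit)
    import streamlit as st
    from rag_engine import process_query

    query = st.text_input("Ask a question")
    if query:
        # load_data_files() and load_model() now cache their results in
        # st.session_state, so repeated queries reuse the FAISS index,
        # text chunks, metadata, and embedding model instead of reloading them.
        result = process_query(query, top_k=5, word_limit=100)
        st.write(result["answer_with_rag"])
        st.caption(result["citations"])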