ankanghosh committed
Commit ff4d9c5 (verified) · Parent(s): 3b2ec72

Update rag_engine.py

Files changed (1): rag_engine.py (+33 −3)
rag_engine.py CHANGED
@@ -175,18 +175,48 @@ query_embedding_cache = {}
 def get_embedding(text):
     if text in query_embedding_cache:
         return query_embedding_cache[text]
+
     try:
         # Ensure model initialization
         if 'model' not in st.session_state or st.session_state.model is None:
             tokenizer, model = load_model()
             if model is None:
-                return np.zeros((1, 384), dtype=np.float32)  # Fallback
+                return np.zeros((1, 384), dtype=np.float32)  # Fallback for e5-small-v2
         else:
             tokenizer, model = st.session_state.tokenizer, st.session_state.model
-
+
         input_text = f"query: {text}" if len(text) < 512 else f"passage: {text}"
 
-        # Rest of your code...
+        # Explicitly specify truncation parameters to avoid warnings
+        inputs = tokenizer(
+            input_text,
+            padding=True,
+            truncation=True,
+            return_tensors="pt",
+            max_length=512,
+            return_attention_mask=True
+        )
+
+        # Move to CPU explicitly before processing
+        inputs = {k: v.to('cpu') for k, v in inputs.items()}
+
+        with torch.no_grad():
+            outputs = model(**inputs)
+            embeddings = average_pool(outputs.last_hidden_state, inputs['attention_mask'])
+            embeddings = nn.functional.normalize(embeddings, p=2, dim=1)
+            # Ensure we detach and move to numpy on CPU
+            embeddings = embeddings.detach().cpu().numpy()
+
+        # Explicitly clean up
+        del outputs
+        torch.cuda.empty_cache() if torch.cuda.is_available() else None
+
+        query_embedding_cache[text] = embeddings
+        return embeddings
+    except Exception as e:
+        print(f"❌ Embedding error: {str(e)}")
+        st.error(f"Embedding error: {str(e)}")
+        return np.zeros((1, 384), dtype=np.float32)  # Changed from 1024 to 384 for e5-small-v2
 
 def retrieve_passages(query, faiss_index, text_chunks, metadata_dict, top_k=5, similarity_threshold=0.5):
     """Retrieve top-k most relevant passages using FAISS with metadata."""