nishantgaurav23 commited on
Commit
c8cc55e
·
verified ·
1 Parent(s): 060ddae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -6
app.py CHANGED
@@ -177,10 +177,35 @@ class SentenceTransformerRetriever:
177
  return None
178
 
179
  @log_function
180
- def encode(self, texts: List[str], batch_size: int = 32) -> torch.Tensor:
181
  try:
182
- embeddings = self.model.encode(texts, batch_size=batch_size, convert_to_tensor=True, show_progress_bar=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  return F.normalize(embeddings, p=2, dim=1)
 
184
  except Exception as e:
185
  logging.error(f"Error encoding texts: {str(e)}")
186
  raise
@@ -274,38 +299,59 @@ class RAGPipeline:
274
  @st.cache_data
275
  def load_and_process_csvs(_self):
276
  try:
 
277
  cache_data = _self.retriever.load_cache(_self.data_folder)
278
  if cache_data is not None:
279
  _self.documents = cache_data['documents']
280
  _self.retriever.store_embeddings(cache_data['embeddings'])
 
281
  return
282
 
 
283
  csv_files = glob.glob(os.path.join(_self.data_folder, "*.csv"))
284
  if not csv_files:
285
  raise FileNotFoundError(f"No CSV files found in {_self.data_folder}")
286
 
287
  all_documents = []
288
- for csv_file in tqdm(csv_files, desc="Reading CSV files"):
 
 
 
 
 
289
  try:
290
- df = pd.read_csv(csv_file)
291
  texts = df.apply(lambda x: " ".join(x.astype(str)), axis=1).tolist()
292
  all_documents.extend(texts)
 
 
 
 
 
293
  except Exception as e:
294
  logging.error(f"Error processing file {csv_file}: {e}")
295
  continue
296
 
 
 
 
297
  if not all_documents:
298
  raise ValueError("No documents were successfully loaded")
299
 
 
300
  _self.documents = all_documents
301
  embeddings = _self.retriever.encode(all_documents)
302
  _self.retriever.store_embeddings(embeddings)
303
 
 
304
  cache_data = {
305
  'embeddings': embeddings,
306
  'documents': _self.documents
307
  }
308
  _self.retriever.save_cache(_self.data_folder, cache_data)
 
 
 
309
  except Exception as e:
310
  logging.error(f"Error in load_and_process_csvs: {str(e)}")
311
  raise
@@ -403,13 +449,20 @@ def initialize_rag_pipeline():
403
  data_folder = "ESPN_data"
404
  if not os.path.exists(data_folder):
405
  os.makedirs(data_folder, exist_ok=True)
 
 
 
 
 
 
 
406
 
407
  rag = RAGPipeline(data_folder)
408
- rag.load_and_process_csvs()
409
  return rag
 
410
  except Exception as e:
411
  logging.error(f"Pipeline initialization error: {str(e)}")
412
- st.error("Failed to initialize the system. Please check your data folder and try again.")
413
  raise
414
 
415
  def main():
 
177
  return None
178
 
179
  @log_function
180
+ def encode(self, texts: List[str], batch_size: int = 64) -> torch.Tensor: # Increased batch size
181
  try:
182
+ # Show a Streamlit progress bar
183
+ progress_text = "Processing documents..."
184
+ progress_bar = st.progress(0)
185
+
186
+ total_batches = len(texts) // batch_size + (1 if len(texts) % batch_size != 0 else 0)
187
+ all_embeddings = []
188
+
189
+ for i in range(0, len(texts), batch_size):
190
+ batch = texts[i:i + batch_size]
191
+ batch_embeddings = self.model.encode(
192
+ batch,
193
+ convert_to_tensor=True,
194
+ show_progress_bar=False # Disable tqdm progress bar
195
+ )
196
+ all_embeddings.append(batch_embeddings)
197
+
198
+ # Update progress
199
+ progress = min((i + batch_size) / len(texts), 1.0)
200
+ progress_bar.progress(progress)
201
+
202
+ # Clear progress bar
203
+ progress_bar.empty()
204
+
205
+ # Concatenate all embeddings
206
+ embeddings = torch.cat(all_embeddings, dim=0)
207
  return F.normalize(embeddings, p=2, dim=1)
208
+
209
  except Exception as e:
210
  logging.error(f"Error encoding texts: {str(e)}")
211
  raise
 
299
  @st.cache_data
300
  def load_and_process_csvs(_self):
301
  try:
302
+ # Try loading from cache first
303
  cache_data = _self.retriever.load_cache(_self.data_folder)
304
  if cache_data is not None:
305
  _self.documents = cache_data['documents']
306
  _self.retriever.store_embeddings(cache_data['embeddings'])
307
+ st.success("Loaded documents from cache")
308
  return
309
 
310
+ st.info("Processing documents... This may take a while.")
311
  csv_files = glob.glob(os.path.join(_self.data_folder, "*.csv"))
312
  if not csv_files:
313
  raise FileNotFoundError(f"No CSV files found in {_self.data_folder}")
314
 
315
  all_documents = []
316
+ total_files = len(csv_files)
317
+
318
+ # Create a progress bar
319
+ progress_bar = st.progress(0)
320
+
321
+ for idx, csv_file in enumerate(csv_files):
322
  try:
323
+ df = pd.read_csv(csv_file, low_memory=False) # Added low_memory=False
324
  texts = df.apply(lambda x: " ".join(x.astype(str)), axis=1).tolist()
325
  all_documents.extend(texts)
326
+
327
+ # Update progress
328
+ progress = (idx + 1) / total_files
329
+ progress_bar.progress(progress)
330
+
331
  except Exception as e:
332
  logging.error(f"Error processing file {csv_file}: {e}")
333
  continue
334
 
335
+ # Clear progress bar
336
+ progress_bar.empty()
337
+
338
  if not all_documents:
339
  raise ValueError("No documents were successfully loaded")
340
 
341
+ st.info(f"Processing {len(all_documents)} documents...")
342
  _self.documents = all_documents
343
  embeddings = _self.retriever.encode(all_documents)
344
  _self.retriever.store_embeddings(embeddings)
345
 
346
+ # Save to cache
347
  cache_data = {
348
  'embeddings': embeddings,
349
  'documents': _self.documents
350
  }
351
  _self.retriever.save_cache(_self.data_folder, cache_data)
352
+
353
+ st.success("Document processing complete!")
354
+
355
  except Exception as e:
356
  logging.error(f"Error in load_and_process_csvs: {str(e)}")
357
  raise
 
449
  data_folder = "ESPN_data"
450
  if not os.path.exists(data_folder):
451
  os.makedirs(data_folder, exist_ok=True)
452
+
453
+ # Check for cache
454
+ cache_path = os.path.join("embeddings_cache", "embeddings.pkl")
455
+ if os.path.exists(cache_path):
456
+ st.info("Found cached data. Loading...")
457
+ else:
458
+ st.warning("Initial setup may take several minutes...")
459
 
460
  rag = RAGPipeline(data_folder)
 
461
  return rag
462
+
463
  except Exception as e:
464
  logging.error(f"Pipeline initialization error: {str(e)}")
465
+ st.error("Failed to initialize the system. Please check if all required files are present.")
466
  raise
467
 
468
  def main():