import re import faiss import numpy as np import pandas as pd import gc import os import time from sentence_transformers import SentenceTransformer class DocumentRetriever: def __init__(self, csv_path="data/mtsamples_surgery.csv", top_k=3, similarity_threshold=0.2, batch_size=8): self.model = SentenceTransformer('all-MiniLM-L6-v2') self.dimension = self.model.get_sentence_embedding_dimension() self.index = faiss.IndexFlatIP(self.dimension) self.texts = [] self.metadata = [] self.top_k = top_k self.similarity_threshold = similarity_threshold self.batch_size = batch_size self._build_index(csv_path) def _preprocess_text(self, text): if not isinstance(text, str): return "" text = re.sub(r'\s+', ' ', text).strip() text = re.sub(r'[^\w\s.,?!:;()\[\]{}\-\'"]+', ' ', text) return text def _build_index(self, path): gc.collect() print(f"Loading CSV from {path}...") df = pd.read_csv(path) print(f"Loaded {len(df)} rows") print("Filtering and preprocessing texts...") df = df.dropna(subset=['transcription']) self.metadata = df[['medical_specialty', 'sample_name']].to_dict('records') self.texts = [] for i in range(0, len(df), self.batch_size): batch = df['transcription'].iloc[i:i+self.batch_size].tolist() self.texts.extend([self._preprocess_text(text) for text in batch]) gc.collect() print(f"Preprocessing complete. Starting encoding {len(self.texts)} documents...") for i in range(0, len(self.texts), self.batch_size): end_idx = min(i + self.batch_size, len(self.texts)) batch = self.texts[i:end_idx] print(f"Encoding batch {i//self.batch_size + 1}/{(len(self.texts) + self.batch_size - 1)//self.batch_size}...") batch_embeddings = self.model.encode(batch, show_progress_bar=False) faiss.normalize_L2(batch_embeddings) self.index.add(np.array(batch_embeddings)) del batch_embeddings gc.collect() time.sleep(0.1) print(f"Index built with {len(self.texts)} documents") def add_documents(self, new_texts, new_metadata=None): if not new_texts: return processed_texts = [self._preprocess_text(text) for text in new_texts] # Add to existing texts and metadata self.texts.extend(processed_texts) if new_metadata: self.metadata.extend(new_metadata) for i in range(0, len(processed_texts), self.batch_size): batch = processed_texts[i:i+min(self.batch_size, len(processed_texts)-i)] batch_embeddings = self.model.encode(batch, show_progress_bar=False) faiss.normalize_L2(batch_embeddings) self.index.add(np.array(batch_embeddings)) # def query(self, question, include_metadata=True): # try: # q_embedding = self.model.encode([question]) # faiss.normalize_L2(q_embedding) # k = min(self.top_k * 2, len(self.texts)) # scores, indices = self.index.search(np.array(q_embedding), k) # results = [] # for i, (score, idx) in enumerate(zip(scores[0], indices[0])): # if idx != -1 and score >= self.similarity_threshold and i < self.top_k: # doc_text = self.texts[idx] # if include_metadata and idx < len(self.metadata): # meta = self.metadata[idx] # doc_info = f"[Document {i+1}] (Score: {score:.2f}, Specialty: {meta.get('medical_specialty', 'Unknown')}, Sample: {meta.get('sample_name', 'Unknown')})\n\n{doc_text}" # else: # doc_info = f"[Document {i+1}] (Score: {score:.2f})\n\n{doc_text}" # results.append(doc_info) # gc.collect() # if not results: # return "No relevant documents found for this query." # return "\n\n" + "-"*80 + "\n\n".join(results) # except Exception as e: # return f"Error during retrieval: {str(e)}" def query(self, question, include_metadata=True): try: q_embedding = self.model.encode([question]) faiss.normalize_L2(q_embedding) k = min(self.top_k * 2, len(self.texts)) scores, indices = self.index.search(np.array(q_embedding), k) results = [] for i, (score, idx) in enumerate(zip(scores[0], indices[0])): if idx != -1 and score >= self.similarity_threshold and i < self.top_k: doc_text = self.texts[idx] if include_metadata and idx < len(self.metadata): meta = self.metadata[idx] # Add description to the output description = meta.get('description', 'No description available') doc_info = f"[Document {i+1}] (Score: {score:.2f})\nSpecialty: {meta.get('medical_specialty', 'Unknown')}\nSample: {meta.get('sample_name', 'Unknown')}\nDescription: {description}\n\n{doc_text}" else: doc_info = f"[Document {i+1}] (Score: {score:.2f})\n\n{doc_text}" results.append(doc_info) gc.collect() if not results: return "No relevant documents found for this query." return "\n\n" + "-"*80 + "\n\n".join(results) except Exception as e: return f"Error during retrieval: {str(e)}"