# NOTE(review): the three lines here ("Spaces:" / "Running" / "Running") were
# Hugging Face Spaces UI status residue captured in a copy-paste, not code.
import gc
import os
import re
import time

import numpy as np
import pandas as pd

import faiss
from sentence_transformers import SentenceTransformer
class DocumentRetriever:
    """FAISS-backed semantic retriever over a CSV of medical transcriptions.

    Documents are embedded with a SentenceTransformer model, L2-normalized,
    and stored in an inner-product FAISS index, so search scores are cosine
    similarities in [-1, 1].
    """

    def __init__(self, csv_path="data/mtsamples_surgery.csv", top_k=3,
                 similarity_threshold=0.2, batch_size=8):
        """Load the CSV at *csv_path* and build the in-memory index.

        Args:
            csv_path: CSV with a 'transcription' column and, optionally,
                'medical_specialty', 'sample_name', and 'description' columns.
            top_k: maximum number of documents returned per query.
            similarity_threshold: minimum cosine score for a hit.
            batch_size: encoding batch size (bounds peak memory use).
        """
        self.model = SentenceTransformer('all-MiniLM-L6-v2')
        self.dimension = self.model.get_sentence_embedding_dimension()
        # Inner product over L2-normalized vectors == cosine similarity.
        self.index = faiss.IndexFlatIP(self.dimension)
        self.texts = []      # preprocessed documents, position == FAISS id
        self.metadata = []   # dicts parallel to self.texts
        self.top_k = top_k
        self.similarity_threshold = similarity_threshold
        self.batch_size = batch_size
        self._build_index(csv_path)

    def _preprocess_text(self, text):
        """Normalize *text*: drop unusual characters, collapse whitespace.

        Non-string input (e.g. NaN from pandas) yields "".
        """
        if not isinstance(text, str):
            return ""
        # Replace disallowed characters with spaces FIRST, then collapse the
        # whitespace that replacement introduces. The original order did the
        # collapse first, leaving stray edge/double spaces behind.
        text = re.sub(r'[^\w\s.,?!:;()\[\]{}\-\'"]+', ' ', text)
        return re.sub(r'\s+', ' ', text).strip()

    def _build_index(self, path):
        """Read the CSV at *path*, preprocess, embed, and index every row."""
        print(f"Loading CSV from {path}...")
        df = pd.read_csv(path)
        print(f"Loaded {len(df)} rows")
        print("Filtering and preprocessing texts...")
        df = df.dropna(subset=['transcription'])
        # Include 'description' when available: query() reports it, but the
        # original metadata omitted the column, so the fallback text
        # ("No description available") was always shown.
        meta_cols = [c for c in ('medical_specialty', 'sample_name', 'description')
                     if c in df.columns]
        self.metadata = df[meta_cols].to_dict('records')
        self.texts = [self._preprocess_text(t) for t in df['transcription']]
        print(f"Preprocessing complete. Starting encoding {len(self.texts)} documents...")
        n_batches = (len(self.texts) + self.batch_size - 1) // self.batch_size
        for i in range(0, len(self.texts), self.batch_size):
            batch = self.texts[i:i + self.batch_size]
            print(f"Encoding batch {i // self.batch_size + 1}/{n_batches}...")
            embeddings = self.model.encode(batch, show_progress_bar=False)
            # normalize_L2 mutates in place and requires contiguous float32.
            embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
            faiss.normalize_L2(embeddings)
            self.index.add(embeddings)
            del embeddings
            gc.collect()  # keep peak memory low on constrained hosts
        print(f"Index built with {len(self.texts)} documents")

    def add_documents(self, new_texts, new_metadata=None):
        """Append *new_texts* (and optional parallel *new_metadata*) to the index.

        Args:
            new_texts: iterable of raw document strings.
            new_metadata: optional list of dicts, one per text.

        Raises:
            ValueError: if new_metadata is given with a different length than
                new_texts (would desynchronize FAISS ids from metadata).
        """
        if not new_texts:
            return
        if new_metadata is not None and len(new_metadata) != len(new_texts):
            raise ValueError("new_metadata must be the same length as new_texts")
        processed = [self._preprocess_text(t) for t in new_texts]
        self.texts.extend(processed)
        # Always keep metadata aligned with texts so query() maps a FAISS id
        # to the right record; pad with empty dicts when none is supplied
        # (the original skipped this, shifting all later metadata lookups).
        self.metadata.extend(new_metadata if new_metadata is not None
                             else [{} for _ in processed])
        for i in range(0, len(processed), self.batch_size):
            embeddings = self.model.encode(processed[i:i + self.batch_size],
                                           show_progress_bar=False)
            embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
            faiss.normalize_L2(embeddings)
            self.index.add(embeddings)

    def query(self, question, include_metadata=True):
        """Return up to top_k formatted documents relevant to *question*.

        Hits scoring below similarity_threshold are dropped. On any failure a
        human-readable error string is returned instead of raising.
        """
        try:
            q_embedding = np.ascontiguousarray(
                self.model.encode([question]), dtype=np.float32)
            faiss.normalize_L2(q_embedding)
            # Over-fetch (2x) so threshold filtering can still fill top_k.
            k = min(self.top_k * 2, len(self.texts))
            scores, indices = self.index.search(q_embedding, k)
            results = []
            for i, (score, idx) in enumerate(zip(scores[0], indices[0])):
                # FAISS pads missing results with id -1.
                if idx == -1 or score < self.similarity_threshold or i >= self.top_k:
                    continue
                doc_text = self.texts[idx]
                if include_metadata and idx < len(self.metadata):
                    meta = self.metadata[idx]
                    description = meta.get('description', 'No description available')
                    doc_info = (
                        f"[Document {i+1}] (Score: {score:.2f})\n"
                        f"Specialty: {meta.get('medical_specialty', 'Unknown')}\n"
                        f"Sample: {meta.get('sample_name', 'Unknown')}\n"
                        f"Description: {description}\n\n{doc_text}"
                    )
                else:
                    doc_info = f"[Document {i+1}] (Score: {score:.2f})\n\n{doc_text}"
                results.append(doc_info)
            if not results:
                return "No relevant documents found for this query."
            # Fix of an operator-precedence bug: the original concatenated the
            # dashes to the first element only instead of separating every
            # document with the full "\n\n----...----\n\n" divider.
            separator = "\n\n" + "-" * 80 + "\n\n"
            return separator + separator.join(results)
        except Exception as e:
            return f"Error during retrieval: {str(e)}"