# medTranscript_QA_agent/tools/retriever_tool.py
import re
import faiss
import numpy as np
import pandas as pd
import gc
import os
import time
from sentence_transformers import SentenceTransformer
class DocumentRetriever:
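    """Semantic retriever over medical transcription samples.

    Encodes transcriptions with a MiniLM sentence transformer, L2-normalizes the
    embeddings, and stores them in a FAISS inner-product index so that search
    scores correspond to cosine similarity.
    """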
def __init__(self, csv_path="data/mtsamples_surgery.csv", top_k=3, similarity_threshold=0.2, batch_size=8):
self.model = SentenceTransformer('all-MiniLM-L6-v2')
self.dimension = self.model.get_sentence_embedding_dimension()
self.index = faiss.IndexFlatIP(self.dimension)
self.texts = []
self.metadata = []
self.top_k = top_k
self.similarity_threshold = similarity_threshold
self.batch_size = batch_size
self._build_index(csv_path)
def _preprocess_text(self, text):
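        """Collapse whitespace and strip unusual characters; returns "" for non-string input."""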
if not isinstance(text, str):
return ""
text = re.sub(r'\s+', ' ', text).strip()
text = re.sub(r'[^\w\s.,?!:;()\[\]{}\-\'"]+', ' ', text)
return text
def _build_index(self, path):
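        """Load the CSV, preprocess the transcription column, and encode it into the index in batches."""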
gc.collect()
print(f"Loading CSV from {path}...")
df = pd.read_csv(path)
print(f"Loaded {len(df)} rows")
print("Filtering and preprocessing texts...")
df = df.dropna(subset=['transcription'])
        # Keep the description column (when present) so query() can surface it alongside each match
        meta_columns = [col for col in ['medical_specialty', 'sample_name', 'description'] if col in df.columns]
        self.metadata = df[meta_columns].to_dict('records')
self.texts = []
for i in range(0, len(df), self.batch_size):
batch = df['transcription'].iloc[i:i+self.batch_size].tolist()
self.texts.extend([self._preprocess_text(text) for text in batch])
gc.collect()
print(f"Preprocessing complete. Starting encoding {len(self.texts)} documents...")
for i in range(0, len(self.texts), self.batch_size):
end_idx = min(i + self.batch_size, len(self.texts))
batch = self.texts[i:end_idx]
print(f"Encoding batch {i//self.batch_size + 1}/{(len(self.texts) + self.batch_size - 1)//self.batch_size}...")
batch_embeddings = self.model.encode(batch, show_progress_bar=False)
faiss.normalize_L2(batch_embeddings)
self.index.add(np.array(batch_embeddings))
del batch_embeddings
gc.collect()
time.sleep(0.1)
print(f"Index built with {len(self.texts)} documents")
def add_documents(self, new_texts, new_metadata=None):
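        """Encode additional documents (with optional metadata) and append them to the index."""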
if not new_texts:
return
processed_texts = [self._preprocess_text(text) for text in new_texts]
# Add to existing texts and metadata
self.texts.extend(processed_texts)
        # Keep metadata aligned with texts: pad with empty records when none is supplied
        self.metadata.extend(new_metadata if new_metadata else [{} for _ in processed_texts])
for i in range(0, len(processed_texts), self.batch_size):
            batch = processed_texts[i:i + self.batch_size]
batch_embeddings = self.model.encode(batch, show_progress_bar=False)
faiss.normalize_L2(batch_embeddings)
self.index.add(np.array(batch_embeddings))
def query(self, question, include_metadata=True):
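        """Return up to top_k documents whose similarity to the question meets the threshold, formatted as text."""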
try:
q_embedding = self.model.encode([question])
faiss.normalize_L2(q_embedding)
k = min(self.top_k * 2, len(self.texts))
scores, indices = self.index.search(np.array(q_embedding), k)
results = []
for i, (score, idx) in enumerate(zip(scores[0], indices[0])):
                # Candidates are fetched as 2 * top_k so threshold-filtered misses can be backfilled
                if idx != -1 and score >= self.similarity_threshold and len(results) < self.top_k:
doc_text = self.texts[idx]
if include_metadata and idx < len(self.metadata):
meta = self.metadata[idx]
# Add description to the output
description = meta.get('description', 'No description available')
doc_info = f"[Document {i+1}] (Score: {score:.2f})\nSpecialty: {meta.get('medical_specialty', 'Unknown')}\nSample: {meta.get('sample_name', 'Unknown')}\nDescription: {description}\n\n{doc_text}"
else:
doc_info = f"[Document {i+1}] (Score: {score:.2f})\n\n{doc_text}"
results.append(doc_info)
gc.collect()
if not results:
return "No relevant documents found for this query."
return "\n\n" + "-"*80 + "\n\n".join(results)
except Exception as e:
return f"Error during retrieval: {str(e)}"