import os
import json
import numpy as np
import faiss
import gradio as gr
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from groq import Groq
import nltk
import re
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize, sent_tokenize
from nltk.stem import WordNetLemmatizer
# Download only the NLTK data actually used below (downloading "all" is slow and unnecessary)
nltk.download("punkt", quiet=True)      # sentence/word tokenizer models
nltk.download("punkt_tab", quiet=True)  # required by newer NLTK releases
nltk.download("stopwords", quiet=True)
nltk.download("wordnet", quiet=True)    # lemmatizer data
# Load stopwords, lemmatizer, and the embedding model once at import time,
# instead of re-instantiating the model inside every function that needs it
stop_words = set(stopwords.words("english"))
lemmatizer = WordNetLemmatizer()
embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
# Load dataset
def load_and_preprocess_dataset():
"""Load and preprocess the dataset."""
dataset = load_dataset("MedRAG/textbooks")
print("Dataset loaded successfully.")
return dataset
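# The MedRAG/textbooks rows carry the chapter text in a "content" field, which
# is what preprocess_text() and chunk_text() operate on below.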
# Preprocessing function
def preprocess_text(text):
"""Preprocess text by lowercasing, removing special characters, and lemmatizing."""
text = text.lower() # Convert to lowercase
text = re.sub(r"[^\w\s]", "", text) # Remove special characters
words = word_tokenize(text) # Tokenization
words = [lemmatizer.lemmatize(w) for w in words if w not in stop_words] # Lemmatization & stopword removal
return " ".join(words)
# Chunking function
def chunk_text(text, chunk_size=3):
"""Split text into chunks of sentences."""
sentences = sent_tokenize(text) # Split text into sentences
return [" ".join(sentences[i:i + chunk_size]) for i in range(0, len(sentences), chunk_size)]
# Generate embeddings for all chunks
def generate_chunk_embeddings(chunks):
    """Generate embeddings for the chunk list using the model's built-in batching."""
    # encode() batches internally; mapping encode over a multiprocessing pool
    # would pickle the model into every worker and embed one chunk at a time.
    return embed_model.encode(chunks, batch_size=64, show_progress_bar=True)
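# all-MiniLM-L6-v2 produces 384-dimensional vectors, so the FAISS index built
# below will have dimension 384.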
# Generate embeddings for the dataset
def generate_embeddings(dataset):
    """Preprocess, chunk, and embed the dataset; return (all_chunks, embeddings)."""
    print("Preprocessing dataset...")
    dataset = dataset.map(lambda row: {"cleaned_content": preprocess_text(row["content"])})
    dataset = dataset.map(lambda row: {"chunks": chunk_text(row["cleaned_content"])})
    print("Generating embeddings...")
    all_chunks = [chunk for row in dataset["train"]["chunks"] for chunk in row]
    embeddings = generate_chunk_embeddings(all_chunks)
    # Embeddings are per chunk, not per document row, so they cannot be mapped
    # back onto the rows by row index; return them alongside the chunk list.
    return all_chunks, embeddings
# Create FAISS index
def create_faiss_index(embeddings):
    """Create and save a FAISS index over the chunk embeddings."""
    embeddings_np = np.asarray(embeddings, dtype=np.float32)  # one row per chunk
    index = faiss.IndexFlatL2(embeddings_np.shape[1])
    index.add(embeddings_np)
    faiss.write_index(index, "faiss_medical.index")
    print("FAISS index created and saved.")
# Load FAISS index
def load_faiss_index():
"""Load the FAISS index."""
index = faiss.read_index("faiss_medical.index")
print("FAISS index loaded.")
return index
# Retrieve medical summary
def retrieve_medical_summary(query, index, id_to_text, k=3):
    """Retrieve the k most relevant chunks from the FAISS index."""
    # Preprocess the query the same way the indexed chunks were preprocessed,
    # so query and corpus share the same representation.
    query_embedding = embed_model.encode([preprocess_text(query)])
    D, I = index.search(np.asarray(query_embedding, dtype=np.float32), k)
    retrieved_docs = [id_to_text.get(int(idx), "No relevant data found.") for idx in I[0]]
    return "\n\n---\n\n".join(retrieved_docs) if retrieved_docs else "No relevant data found."
# Generate medical answer using Groq
def generate_medical_answer_groq(query, index, id_to_text):
"""Generate a medical response using Groq's API."""
retrieved_summary = retrieve_medical_summary(query, index, id_to_text)
if not retrieved_summary or retrieved_summary == "No relevant data found.":
return "No relevant medical data found. Please consult a healthcare professional."
client = Groq(api_key=os.getenv("GROQ_API_KEY"))
try:
response = client.chat.completions.create(
model="llama-3.3-70b-versatile",
messages=[
{"role": "system", "content": "You are an expert AI specializing in medical knowledge."},
{"role": "user", "content": f"Summarize the following medical literature and provide a structured medical answer:\n\n### Medical Literature ###\n{retrieved_summary}\n\n### Patient Question ###\n{query}\n\n### Medical Advice ###"}
],
max_tokens=500,
temperature=0.3
)
return response.choices[0].message.content.strip()
except Exception as e:
return f"Error generating response: {str(e)}"
# Gradio interface
def ask_medical_question(question):
"""Gradio interface for asking medical questions."""
return generate_medical_answer_groq(question, index, id_to_text)
# Main function
def main():
    """Main function to set up the system."""
    global index, id_to_text
    # Load the dataset, then chunk and embed it
    dataset = load_and_preprocess_dataset()
    all_chunks, embeddings = generate_embeddings(dataset)
    # Create FAISS index over the chunk embeddings
    create_faiss_index(embeddings)
    # Load FAISS index
    index = load_faiss_index()
    # FAISS returns row ids, and row i of the index holds the embedding of
    # all_chunks[i], so the id-to-text mapping is chunk-level
    id_to_text = {idx: chunk for idx, chunk in enumerate(all_chunks)}
    with open("id_to_text.json", "w") as f:
        json.dump(id_to_text, f)
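    # Note: json.dump coerces the int keys to strings; anything reloading
    # id_to_text.json must convert them back with int(k).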
# Launch Gradio app
iface = gr.Interface(
fn=ask_medical_question,
inputs=gr.Textbox(lines=2, placeholder="Enter your medical question here..."),
outputs=gr.Textbox(lines=10, placeholder="AI-generated medical advice will appear here..."),
title="Medical Question Answering System",
description="Ask any medical question, and the AI will provide an answer based on medical literature."
)
iface.launch()
# Run the main function
if __name__ == "__main__":
main()