import os
import json
import re

import numpy as np
import faiss
import gradio as gr
import nltk
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from groq import Groq
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize, sent_tokenize
from nltk.stem import WordNetLemmatizer
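# Assumed Space setup (not part of the original file): a requirements.txt listing
# datasets, sentence-transformers, faiss-cpu, gradio, groq, nltk, and numpy is needed
# for these imports to resolve on Hugging Face Spaces.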
nltk.download("all") | |

# Load stopwords, the lemmatizer, and the embedding model once at module level
stop_words = set(stopwords.words("english"))
lemmatizer = WordNetLemmatizer()
embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# Load dataset
def load_and_preprocess_dataset():
    """Load the MedRAG textbooks dataset."""
    dataset = load_dataset("MedRAG/textbooks")
    print("Dataset loaded successfully.")
    return dataset

# Preprocessing function
def preprocess_text(text):
    """Preprocess text by lowercasing, removing special characters, and lemmatizing."""
    text = text.lower()  # Convert to lowercase
    text = re.sub(r"[^\w\s]", "", text)  # Remove special characters
    words = word_tokenize(text)  # Tokenization
    words = [lemmatizer.lemmatize(w) for w in words if w not in stop_words]  # Lemmatization & stopword removal
    return " ".join(words)

# Chunking function
def chunk_text(text, chunk_size=3):
    """Split text into chunks of sentences."""
    sentences = sent_tokenize(text)  # Split text into sentences
    return [" ".join(sentences[i:i + chunk_size]) for i in range(0, len(sentences), chunk_size)]

# Generate embeddings for all chunks
def generate_chunk_embeddings(chunks):
    """Encode chunks in batches with the module-level SentenceTransformer model."""
    # A multiprocessing.Pool over encode() pickles the model per worker and encodes one chunk
    # at a time; a single batched encode call is simpler and faster here.
    return embed_model.encode(chunks, batch_size=64, show_progress_bar=True)
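# Illustrative check (not part of the original app): all-MiniLM-L6-v2 produces 384-dimensional
# vectors, so for N input chunks the result has shape (N, 384).
# vectors = generate_chunk_embeddings(["chest pain", "shortness of breath"])
# print(vectors.shape)  # (2, 384)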

# Chunk and embed the dataset
def generate_embeddings(dataset):
    """Chunk the dataset and embed a cleaned copy of every chunk."""
    print("Preprocessing dataset...")
    # Chunk the raw text first: sent_tokenize needs the punctuation that preprocess_text removes
    dataset = dataset.map(lambda row: {"chunks": chunk_text(row["content"])})
    print("Generating embeddings...")
    # Flatten the per-row chunk lists so FAISS ids correspond one-to-one with chunks
    all_chunks = [chunk for row in dataset["train"]["chunks"] for chunk in row]
    # Embed cleaned chunks, but keep the raw chunks for retrieval and prompting
    embeddings = generate_chunk_embeddings([preprocess_text(c) for c in all_chunks])
    return all_chunks, embeddings

# Create FAISS index
def create_faiss_index(embeddings):
    """Create and save a FAISS index over the chunk embeddings."""
    embeddings_np = np.asarray(embeddings, dtype=np.float32)
    index = faiss.IndexFlatL2(embeddings_np.shape[1])
    index.add(embeddings_np)
    faiss.write_index(index, "faiss_medical.index")
    print("FAISS index created and saved.")

# Load FAISS index
def load_faiss_index():
    """Load the FAISS index."""
    index = faiss.read_index("faiss_medical.index")
    print("FAISS index loaded.")
    return index

# Retrieve medical summary
def retrieve_medical_summary(query, index, id_to_text, k=3):
    """Retrieve the k most relevant chunks from the FAISS index."""
    # Clean the query the same way the chunks were cleaned so both embeddings live in the same space
    query_embedding = embed_model.encode([preprocess_text(query)])
    D, I = index.search(np.asarray(query_embedding, dtype=np.float32), k)
    retrieved_docs = [id_to_text.get(int(idx), "No relevant data found.") for idx in I[0]]
    return "\n\n---\n\n".join(retrieved_docs) if retrieved_docs else "No relevant data found."

# Generate medical answer using Groq
def generate_medical_answer_groq(query, index, id_to_text):
    """Generate a medical response using Groq's API."""
    retrieved_summary = retrieve_medical_summary(query, index, id_to_text)
    if not retrieved_summary or retrieved_summary == "No relevant data found.":
        return "No relevant medical data found. Please consult a healthcare professional."

    # The Groq client needs GROQ_API_KEY (e.g. set as a Space secret); fail gracefully if missing
    api_key = os.getenv("GROQ_API_KEY")
    if not api_key:
        return "GROQ_API_KEY is not set. Please configure it before asking questions."
    client = Groq(api_key=api_key)

    try:
        response = client.chat.completions.create(
            model="llama-3.3-70b-versatile",
            messages=[
                {"role": "system", "content": "You are an expert AI specializing in medical knowledge."},
                {"role": "user", "content": f"Summarize the following medical literature and provide a structured medical answer:\n\n### Medical Literature ###\n{retrieved_summary}\n\n### Patient Question ###\n{query}\n\n### Medical Advice ###"}
            ],
            max_tokens=500,
            temperature=0.3
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        return f"Error generating response: {str(e)}"

# Gradio interface
def ask_medical_question(question):
    """Gradio interface for asking medical questions."""
    return generate_medical_answer_groq(question, index, id_to_text)

# Main function
def main():
    """Main function to set up the system."""
    global index, id_to_text

    # Load, chunk, and embed the dataset
    dataset = load_and_preprocess_dataset()
    all_chunks, embeddings = generate_embeddings(dataset)

    # Create and reload the FAISS index
    create_faiss_index(embeddings)
    index = load_faiss_index()

    # FAISS ids are positions in all_chunks, so map each id to its chunk text
    id_to_text = {idx: chunk for idx, chunk in enumerate(all_chunks)}
    with open("id_to_text.json", "w") as f:
        json.dump(id_to_text, f)

    # Launch Gradio app
    iface = gr.Interface(
        fn=ask_medical_question,
        inputs=gr.Textbox(lines=2, placeholder="Enter your medical question here..."),
        outputs=gr.Textbox(lines=10, placeholder="AI-generated medical advice will appear here..."),
        title="Medical Question Answering System",
        description="Ask any medical question, and the AI will provide an answer based on medical literature."
    )
    iface.launch()


# Run the main function
if __name__ == "__main__":
    main()