import os
import json
import numpy as np
import faiss
import gradio as gr
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from groq import Groq
import nltk
import re
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize, sent_tokenize
from nltk.stem import WordNetLemmatizer
from multiprocessing import Pool, cpu_count
# Download only the NLTK data the pipeline actually uses (tokenizers, stopwords, WordNet)
for pkg in ("punkt", "punkt_tab", "stopwords", "wordnet", "omw-1.4"):
    nltk.download(pkg, quiet=True)
# Load stopwords and lemmatizer
stop_words = set(stopwords.words("english"))
lemmatizer = WordNetLemmatizer()
# Load dataset
def load_and_preprocess_dataset():
    """Load the MedRAG textbooks dataset from the Hugging Face Hub."""
    dataset = load_dataset("MedRAG/textbooks")
    print("Dataset loaded successfully.")
    return dataset
# Preprocessing function
def preprocess_text(text):
    """Preprocess text by lowercasing, removing special characters, and lemmatizing."""
    text = text.lower()  # Convert to lowercase
    text = re.sub(r"[^\w\s]", "", text)  # Remove special characters
    words = word_tokenize(text)  # Tokenization
    words = [lemmatizer.lemmatize(w) for w in words if w not in stop_words]  # Lemmatization & stopword removal
    return " ".join(words)
# Chunking function
def chunk_text(text, chunk_size=3):
    """Split text into chunks of sentences."""
    sentences = sent_tokenize(text)  # Split text into sentences
    return [" ".join(sentences[i:i + chunk_size]) for i in range(0, len(sentences), chunk_size)]
# Generate embeddings in parallel
def generate_embeddings_parallel(chunks):
    """Generate embeddings for chunks in parallel."""
    embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    with Pool(cpu_count()) as pool:
        embeddings = pool.map(embed_model.encode, chunks)
    return embeddings
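# Note: Pool.map pickles the SentenceTransformer model into each worker process,
# which adds per-worker memory overhead; a single batched call in the parent
# process (embed_model.encode(chunks, batch_size=64)) would avoid that.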
# Generate embeddings for the dataset
def generate_embeddings(dataset):
    """Chunk and clean the corpus, then compute one embedding per document."""
    print("Preprocessing dataset...")
    # Chunk the raw content first (sent_tokenize needs the original punctuation),
    # then clean each chunk: preprocess_text strips the periods sentence splitting relies on.
    dataset = dataset.map(lambda row: {"chunks": [preprocess_text(c) for c in chunk_text(row["content"])]})
    print("Generating embeddings...")
    all_chunks = [chunk for row in dataset["train"]["chunks"] for chunk in row]
    embeddings = generate_embeddings_parallel(all_chunks)
    # Map the flat list of chunk embeddings back to their source rows and mean-pool them,
    # so each row gets one vector aligned with the per-row id_to_text mapping built in main()
    chunk_counts = [len(chunks) for chunks in dataset["train"]["chunks"]]
    offsets = np.cumsum([0] + chunk_counts)
    dataset = dataset.map(
        lambda row, idx: {"embedding": np.mean(embeddings[offsets[idx]:offsets[idx + 1]], axis=0)},
        with_indices=True,
    )
    return dataset
# Create FAISS index
def create_faiss_index(dataset):
    """Create and save a FAISS index for the embeddings."""
    embeddings_np = np.array([row["embedding"] for row in dataset["train"]], dtype=np.float32)
    # IndexFlatL2 performs exact (brute-force) L2 search over the 384-dim MiniLM vectors
    index = faiss.IndexFlatL2(embeddings_np.shape[1])
    index.add(embeddings_np)
    faiss.write_index(index, "faiss_medical.index")
    print("FAISS index created and saved.")
# Load FAISS index
def load_faiss_index():
    """Load the FAISS index."""
    index = faiss.read_index("faiss_medical.index")
    print("FAISS index loaded.")
    return index
# Retrieve medical summary
def retrieve_medical_summary(query, index, id_to_text, k=3):
    """Retrieve the most relevant medical literature from FAISS."""
    embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    query_embedding = embed_model.encode([query])
    # D holds the L2 distances, I the row indices of the k nearest neighbours
    D, I = index.search(np.array(query_embedding).astype("float32"), k)
    retrieved_docs = [id_to_text.get(int(idx), "No relevant data found.") for idx in I[0]]
    # id_to_text values are lists of chunks; join each into a single string
    retrieved_docs = [doc if isinstance(doc, str) else " ".join(doc) for doc in retrieved_docs]
    return "\n\n---\n\n".join(retrieved_docs) if retrieved_docs else "No relevant data found."
# Generate medical answer using Groq
def generate_medical_answer_groq(query, index, id_to_text):
    """Generate a medical response using Groq's API."""
    retrieved_summary = retrieve_medical_summary(query, index, id_to_text)
    if not retrieved_summary or retrieved_summary == "No relevant data found.":
        return "No relevant medical data found. Please consult a healthcare professional."
    # GROQ_API_KEY must be available in the environment (e.g. as a Space secret)
    client = Groq(api_key=os.getenv("GROQ_API_KEY"))
    try:
        response = client.chat.completions.create(
            model="llama-3.3-70b-versatile",
            messages=[
                {"role": "system", "content": "You are an expert AI specializing in medical knowledge."},
                {"role": "user", "content": f"Summarize the following medical literature and provide a structured medical answer:\n\n### Medical Literature ###\n{retrieved_summary}\n\n### Patient Question ###\n{query}\n\n### Medical Advice ###"}
            ],
            max_tokens=500,
            temperature=0.3
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        return f"Error generating response: {str(e)}"
# Gradio interface
def ask_medical_question(question):
    """Gradio interface for asking medical questions."""
    return generate_medical_answer_groq(question, index, id_to_text)
# Main function
def main():
    """Main function to set up the system."""
    global index, id_to_text
    # Load and preprocess dataset
    dataset = load_and_preprocess_dataset()
    dataset = generate_embeddings(dataset)
    # Create FAISS index
    create_faiss_index(dataset)
    # Load FAISS index
    index = load_faiss_index()
    # Create ID to text mapping
    medical_texts = dataset["train"]["chunks"]
    id_to_text = {idx: text for idx, text in enumerate(medical_texts)}
    with open("id_to_text.json", "w") as f:
        json.dump(id_to_text, f)
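    # Note: json.dump converts the integer keys to strings, so reloading this file
    # with json.load would require casting the keys back to int before lookups.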
    # Launch Gradio app
    iface = gr.Interface(
        fn=ask_medical_question,
        inputs=gr.Textbox(lines=2, placeholder="Enter your medical question here..."),
        outputs=gr.Textbox(lines=10, placeholder="AI-generated medical advice will appear here..."),
        title="Medical Question Answering System",
        description="Ask any medical question, and the AI will provide an answer based on medical literature."
    )
    iface.launch()

# Run the main function
if __name__ == "__main__":
    main()