import os
import json
import numpy as np
import faiss
import gradio as gr
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from groq import Groq
import nltk
import re
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize, sent_tokenize
from nltk.stem import WordNetLemmatizer

nltk.download("all")

# Load stopwords and lemmatizer
stop_words = set(stopwords.words("english"))
lemmatizer = WordNetLemmatizer()

# Load dataset
def load_and_preprocess_dataset():
    """Load and preprocess the dataset."""
    dataset = load_dataset("MedRAG/textbooks")
    print("Dataset loaded successfully.")
    return dataset

# Preprocessing function
def preprocess_text(text):
    """Preprocess text by lowercasing, removing special characters, and lemmatizing."""
    text = text.lower()  # Convert to lowercase
    text = re.sub(r"[^\w\s]", "", text)  # Remove special characters
    words = word_tokenize(text)  # Tokenization
    words = [lemmatizer.lemmatize(w) for w in words if w not in stop_words]  # Lemmatization & stopword removal
    return " ".join(words)

# Chunking function
def chunk_text(text, chunk_size=3):
    """Split text into chunks of sentences."""
    sentences = sent_tokenize(text)  # Split text into sentences
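    # Group consecutive sentences into non-overlapping chunks of chunk_size sentences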
    return [" ".join(sentences[i:i + chunk_size]) for i in range(0, len(sentences), chunk_size)]

# Generate embeddings for the chunks
def generate_embeddings_parallel(chunks):
    """Generate embeddings for the text chunks in batches."""
    embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    # encode() batches and parallelizes internally, which is faster than mapping single
    # chunks over a multiprocessing Pool and avoids pickling the model into worker processes.
    return embed_model.encode(chunks, batch_size=64, show_progress_bar=True)

# Generate embeddings for the dataset
def generate_embeddings(dataset):
    """Preprocess and chunk the dataset, then embed every chunk.

    Returns the flat list of chunks and their embeddings in the same order,
    so that FAISS ids line up with chunk texts.
    """
    print("Preprocessing dataset...")
    dataset = dataset.map(lambda row: {"cleaned_content": preprocess_text(row["content"])})
    dataset = dataset.map(lambda row: {"chunks": chunk_text(row["cleaned_content"])})

    print("Generating embeddings...")
    # Flatten the per-row chunk lists so each chunk gets its own embedding (and FAISS id)
    all_chunks = [chunk for row in dataset["train"]["chunks"] for chunk in row]
    embeddings = generate_embeddings_parallel(all_chunks)
    return all_chunks, embeddings

# Create FAISS index
def create_faiss_index(embeddings):
    """Create and save a FAISS index over the chunk embeddings."""
    embeddings_np = np.asarray(embeddings, dtype=np.float32)
    index = faiss.IndexFlatL2(embeddings_np.shape[1])  # exact L2 search over the embedding dimension
    index.add(embeddings_np)
    faiss.write_index(index, "faiss_medical.index")
    print("FAISS index created and saved.")

# Load FAISS index
def load_faiss_index():
    """Load the FAISS index."""
    index = faiss.read_index("faiss_medical.index")
    print("FAISS index loaded.")
    return index

# Retrieve medical summary
def retrieve_medical_summary(query, index, id_to_text, k=3):
    """Retrieve the most relevant medical literature from FAISS."""
    embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    query_embedding = embed_model.encode([query])
    D, I = index.search(np.array(query_embedding).astype("float32"), k)
    retrieved_docs = [id_to_text.get(int(idx), "No relevant data found.") for idx in I[0]]
    retrieved_docs = [doc if isinstance(doc, str) else " ".join(doc) for doc in retrieved_docs]
    return "\n\n---\n\n".join(retrieved_docs) if retrieved_docs else "No relevant data found."

# Generate medical answer using Groq
def generate_medical_answer_groq(query, index, id_to_text):
    """Generate a medical response using Groq's API."""
    retrieved_summary = retrieve_medical_summary(query, index, id_to_text)
    if not retrieved_summary or retrieved_summary == "No relevant data found.":
        return "No relevant medical data found. Please consult a healthcare professional."

    client = Groq(api_key=os.getenv("GROQ_API_KEY"))
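    # The API key is read from the GROQ_API_KEY environment variable; it must be set for the request below to succeed.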
    try:
        response = client.chat.completions.create(
            model="llama-3.3-70b-versatile",
            messages=[
                {"role": "system", "content": "You are an expert AI specializing in medical knowledge."},
                {"role": "user", "content": f"Summarize the following medical literature and provide a structured medical answer:\n\n### Medical Literature ###\n{retrieved_summary}\n\n### Patient Question ###\n{query}\n\n### Medical Advice ###"}
            ],
            max_tokens=500,
            temperature=0.3
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        return f"Error generating response: {str(e)}"

# Gradio interface
def ask_medical_question(question):
    """Gradio interface for asking medical questions."""
    return generate_medical_answer_groq(question, index, id_to_text)

# Main function
def main():
    """Main function to set up the system."""
    global index, id_to_text
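    # index and id_to_text are made global so the Gradio callback ask_medical_question can use them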

    # Load the dataset, then preprocess, chunk, and embed it
    dataset = load_and_preprocess_dataset()
    chunks, embeddings = generate_embeddings(dataset)

    # Create FAISS index over the chunk embeddings
    create_faiss_index(embeddings)

    # Load FAISS index
    index = load_faiss_index()

    # Map each FAISS id back to the chunk text it was built from
    id_to_text = {idx: text for idx, text in enumerate(chunks)}
    with open("id_to_text.json", "w") as f:
        json.dump(id_to_text, f)
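    # Note: JSON serialization turns the integer keys into strings; a loader would need to cast them back to int.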

    # Launch Gradio app
    iface = gr.Interface(
        fn=ask_medical_question,
        inputs=gr.Textbox(lines=2, placeholder="Enter your medical question here..."),
        outputs=gr.Textbox(lines=10, placeholder="AI-generated medical advice will appear here..."),
        title="Medical Question Answering System",
        description="Ask any medical question, and the AI will provide an answer based on medical literature."
    )
    iface.launch()

# Run the main function
if __name__ == "__main__":
    main()