khalil2233 committed
Commit
e23d57a
·
verified ·
1 Parent(s): a6ae1a3
Files changed (1)
  1. app.py +248 -15
app.py CHANGED
@@ -1,28 +1,239 @@
- import os
  import json
  import numpy as np
  import faiss
- import gradio as gr
  from sentence_transformers import SentenceTransformer
- from groq import Groq

- # Load FAISS index
- FAISS_INDEX_PATH = "faiss_medical.index"
  index = faiss.read_index(FAISS_INDEX_PATH)

- # Load embedding model
  embed_model = SentenceTransformer("all-MiniLM-L6-v2")

- # Load FAISS ID → Text Mapping
  with open("id_to_text.json", "r") as f:
      id_to_text = json.load(f)

- # Convert JSON keys to integers (FAISS returns int IDs)
  id_to_text = {int(k): v for k, v in id_to_text.items()}

- # Initialize Groq client
- client = Groq(api_key=os.getenv("GROQ_API_KEY"))
-
  def retrieve_medical_summary(query, k=3):
      """
      Retrieve the most relevant medical literature from FAISS.
@@ -43,12 +254,30 @@ def retrieve_medical_summary(query, k=3):
      # Retrieve the closest matching text using FAISS index IDs
      retrieved_docs = [id_to_text.get(int(idx), "No relevant data found.") for idx in I[0]]

-     # Ensure all retrieved texts are strings (Flatten lists if needed)
      retrieved_docs = [doc if isinstance(doc, str) else " ".join(doc) for doc in retrieved_docs]

-     # Join multiple retrieved documents into one response
      return "\n\n---\n\n".join(retrieved_docs) if retrieved_docs else "No relevant data found."


  def generate_medical_answer_groq(query, model="llama-3.3-70b-versatile", max_tokens=500, temperature=0.3):
      """
      Generates a medical response using Groq's API with LLaMA 3.3-70B, after retrieving relevant literature from FAISS.
@@ -63,7 +292,7 @@ def generate_medical_answer_groq(query, model="llama-3.3-70b-versatile", max_tok
          str: The AI-generated medical advice.
      """

-     # Retrieve relevant medical literature from FAISS
      retrieved_summary = retrieve_medical_summary(query)
      print("\n🔍 Retrieved Medical Text for Query:", query)
      print(retrieved_summary, "\n")
@@ -72,7 +301,7 @@ def generate_medical_answer_groq(query, model="llama-3.3-70b-versatile", max_tok
          return "No relevant medical data found. Please consult a healthcare professional."

      try:
-         # Send request to Groq API
          response = client.chat.completions.create(
              model=model,
              messages=[
@@ -88,6 +317,10 @@ def generate_medical_answer_groq(query, model="llama-3.3-70b-versatile", max_tok
      except Exception as e:
          return f"Error generating response: {str(e)}"


  # Gradio Interface
  def ask_medical_question(question):
      return generate_medical_answer_groq(question)
 
+ from datasets import load_dataset
+
+ # Load dataset from Hugging Face
+ dataset = load_dataset("MedRAG/textbooks")
+
+ # Preview dataset
+ print(dataset)
+
+ import pandas as pd
+
+ # Convert to Pandas DataFrame
+ df = pd.DataFrame(dataset["train"])
+
+ # Display first rows
+ print(df.head())
+
+ # Check column data types
+ print(df.dtypes)
+
+ import nltk
+ import shutil
+
+ # Reset existing NLTK resources
+ nltk.data.path.append('/root/nltk_data')  # Add the nltk_data path
+ nltk.data.clear_cache()  # Clear the cached data
+
+ # Re-download the NLTK packages (includes 'punkt')
+ nltk.download('all')
+
+ import re
+ import nltk
+ from nltk.corpus import stopwords
+ from nltk.tokenize import word_tokenize, sent_tokenize
+ from nltk.stem import WordNetLemmatizer
+
+ # Download necessary NLTK components
+ nltk.download("stopwords")
+ nltk.download("punkt")
+ nltk.download("wordnet")
+ nltk.download("omw-1.4")
+
+ # Load stopwords and lemmatizer
+ stop_words = set(stopwords.words("english"))
+ lemmatizer = WordNetLemmatizer()
+
+ # Step 1: Preprocessing Function
+ def preprocess_text(text):
+     text = text.lower()  # Convert to lowercase
+     text = re.sub(r"[^\w\s]", "", text)  # Remove special characters
+     words = word_tokenize(text)  # Tokenization
+     words = [lemmatizer.lemmatize(w) for w in words if w not in stop_words]  # Lemmatization & stopword removal
+     return " ".join(words)
+
+ # Apply preprocessing before chunking
+ dataset = dataset.map(lambda row: {"cleaned_content": preprocess_text(row["content"])})
+
+ # Step 2: Chunking Function
+ def chunk_text(text, chunk_size=3):
+     sentences = sent_tokenize(text)  # Split text into sentences
+     return [" ".join(sentences[i:i+chunk_size]) for i in range(0, len(sentences), chunk_size)]
+
+ # Apply chunking on the cleaned text
+ dataset = dataset.map(lambda row: {"chunks": chunk_text(row["cleaned_content"])})
+
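For intuition, here is a tiny illustration of what the two helpers above produce. It assumes the `preprocess_text` and `chunk_text` definitions and the NLTK downloads from this commit have already run; the sample sentence is made up for the example.

sample = ("Pneumonia is an infection of the lungs. It often causes cough, fever, and chest pain. "
          "Doctors may prescribe antibiotics. Rest and fluids support recovery.")

cleaned = preprocess_text(sample)
print(cleaned)  # lowercased, punctuation and stopwords removed, remaining tokens lemmatized

# Chunking the cleaned text mirrors the pipeline above; because preprocessing strips
# sentence-ending punctuation, sent_tokenize typically sees one long sentence here,
# so the whole document lands in a single chunk.
print(chunk_text(cleaned, chunk_size=3))

# Chunking the raw text keeps sentence boundaries: 4 sentences -> 2 chunks of up to 3 sentences.
print(chunk_text(sample, chunk_size=3))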
+ from sentence_transformers import SentenceTransformer
+
+ # Load BioBERT or MiniLM for fast embedding
+ embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
+
+ def generate_embedding(row):
+     embedding = embed_model.encode(row["chunks"], convert_to_tensor=False).tolist()
+
+     # Fix: Ensure embedding is a flat list, not nested
+     row["embedding"] = embedding[0] if isinstance(embedding, list) and len(embedding) == 1 else embedding
+     return row
+
+ dataset = dataset.map(generate_embedding)
+
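The flattening fix in `generate_embedding` is there because `encode` returns one vector per input string. A quick standalone check of the shapes involved, using the same model as above (whose vectors are 384-dimensional):

from sentence_transformers import SentenceTransformer
import numpy as np

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

one = np.asarray(model.encode(["a single chunk"], convert_to_tensor=False))
many = np.asarray(model.encode(["first chunk", "second chunk", "third chunk"], convert_to_tensor=False))

print(one.shape)   # (1, 384)  -> unwrapped above to a flat 384-vector
print(many.shape)  # (3, 384)  -> rows with several chunks keep a nested list of vectors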
+ import numpy as np
+
+ # Flatten embeddings (convert [[...]] → [...])
+ valid_embeddings = [
+     np.array(row["embedding"]).flatten().tolist()  # Ensure each embedding is 1D
+     for row in dataset["train"]
+     if isinstance(row["embedding"], list) and len(row["embedding"]) == 384
+ ]
+
+ # Convert to NumPy array
+ embeddings_np = np.array(valid_embeddings, dtype=np.float32)
+
+ # Check shape
+ print("✅ Fixed Embeddings Shape:", embeddings_np.shape)  # Expected: (num_samples, 384)
+
+ import faiss
+
+ # Check if embeddings are 2D
+ if len(embeddings_np.shape) == 1:
+     embeddings_np = embeddings_np.reshape(1, -1)  # Ensure it's (num_samples, embedding_dim)
+
+ # Check final shape
+ print("Fixed Embeddings Shape:", embeddings_np.shape)
+
+ # Create FAISS index
+ index = faiss.IndexFlatL2(embeddings_np.shape[1])
+ index.add(embeddings_np)  # Add all embeddings
+
+ print("✅ Embeddings successfully stored in FAISS!")
+ print("Total embeddings in FAISS:", index.ntotal)
+
+ FAISS_INDEX_PATH = "/content/faiss_medical.index"  # Save in Colab's file system
+
+ # Save the FAISS index
+ faiss.write_index(index, FAISS_INDEX_PATH)
+
+ print(f"✅ FAISS index successfully saved at: {FAISS_INDEX_PATH}")
+
+ # Load FAISS index from file
+ index = faiss.read_index(FAISS_INDEX_PATH)
+
+ print(f"✅ FAISS index loaded from: {FAISS_INDEX_PATH}")
+ print(f"Total embeddings stored: {index.ntotal}")
+
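Note that `/content/...` exists only inside a Colab session. If the index file is also shipped at the root of the Space repository (an assumption about this project's layout, matching the `faiss_medical.index` path used by the previous version of app.py), a small fallback keeps the same loading code working in both environments:

import os
import faiss

# Prefer the Colab path, otherwise fall back to an index sitting next to app.py.
FAISS_INDEX_PATH = (
    "/content/faiss_medical.index"
    if os.path.exists("/content/faiss_medical.index")
    else "faiss_medical.index"
)
index = faiss.read_index(FAISS_INDEX_PATH)
print("Loaded", index.ntotal, "vectors from", FAISS_INDEX_PATH)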
+ print("🔍 Available columns:", dataset.column_names)  # Should include "chunks"
+
+ medical_texts = dataset["train"]["chunks"]  # ✅ Correct way to access chunks
+ # Use the same text that will be encoded
+
+ print("🔍 Dataset structure:", dataset)
+ print("🔍 Available columns in train:", dataset["train"].column_names)
+ print("✅ First 3 chunked texts:", dataset["train"]["chunks"][:3])
+
  import json
+ id_to_text = {idx: text for idx, text in enumerate(medical_texts)}
+
+ with open("id_to_text.json", "w") as f:
+     json.dump(id_to_text, f)
+
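The mapping is built with integer keys, but JSON object keys are always written as strings, which is why app.py converts them back with `int(k)` after loading. A minimal round trip shows the behaviour:

import json

mapping = {0: "first chunk", 1: "second chunk"}
restored = json.loads(json.dumps(mapping))

print(list(restored.keys()))                         # ['0', '1'] -- keys come back as strings
restored = {int(k): v for k, v in restored.items()}  # same conversion as in app.py
print(restored[0])                                   # 'first chunk'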
+ import os
+
+ # ✅ Check if file exists
+ if os.path.exists("id_to_text.json"):
+     print("✅ `id_to_text.json` exists!")
+
+     # ✅ Load the JSON file
+     with open("id_to_text.json", "r") as f:
+         id_to_text = json.load(f)
+
+     # ✅ Compare number of records
+     print(f"📊 Records in `id_to_text.json`: {len(id_to_text)}")
+     print(f"📊 Records in `medical_texts`: {len(medical_texts)}")
+
+     if len(id_to_text) == len(medical_texts):
+         print("✅ JSON file contains the correct number of records!")
+     else:
+         print("❌ Mismatch! FAISS ID mapping and dataset size are different.")
+
+ else:
+     print("❌ `id_to_text.json` was not found! Make sure it was saved correctly.")
+
+ import random
+
+ # ✅ Pick 3 random FAISS IDs
+ sample_ids = random.sample(list(id_to_text.keys()), 3)
+
+ # ✅ Print their corresponding texts
+ for faiss_id in sample_ids:
+     print(f"FAISS ID {faiss_id} → Text: {id_to_text[faiss_id][:100]}...")  # Show only first 100 chars
+
+ import faiss
  import numpy as np
+ from sentence_transformers import SentenceTransformer
+
+ # ✅ Load FAISS
+ FAISS_INDEX_PATH = "/content/faiss_medical.index"
+ index = faiss.read_index(FAISS_INDEX_PATH)
+
+ # ✅ Load Sentence Transformer model
+ embed_model = SentenceTransformer("all-MiniLM-L6-v2")
+
+ # ✅ Test a retrieval query
+ query = "What are the symptoms of pneumonia?"
+ query_embedding = embed_model.encode([query])
+
+ # ✅ Perform FAISS search
+ D, I = index.search(np.array(query_embedding).astype("float32"), 3)  # Retrieve top 3 matches
+
+ # ✅ Print the FAISS results & compare with JSON mapping
+ print("🔍 FAISS Search Results:", I[0])
+ print("📏 FAISS Distances:", D[0])
+
+ # ✅ Load `id_to_text.json`
+ with open("id_to_text.json", "r") as f:
+     id_to_text = json.load(f)
+
+ id_to_text = {int(k): v for k, v in id_to_text.items()}  # Ensure keys are integers
+
+ # ✅ Print the matching texts
+ for faiss_id in I[0]:
+     print(f"FAISS ID {faiss_id} → Text: {id_to_text[faiss_id][:100]}...")  # Show first 100 characters
+
  import faiss
+ import numpy as np
  from sentence_transformers import SentenceTransformer
+ import json

+ # ✅ Load FAISS index
+ FAISS_INDEX_PATH = "/content/faiss_medical.index"
  index = faiss.read_index(FAISS_INDEX_PATH)

+ # ✅ Load embedding model
  embed_model = SentenceTransformer("all-MiniLM-L6-v2")

+ # ✅ Load FAISS ID → Text Mapping
  with open("id_to_text.json", "r") as f:
      id_to_text = json.load(f)

+ # ✅ Convert JSON keys to integers (FAISS returns int IDs)
  id_to_text = {int(k): v for k, v in id_to_text.items()}


  def retrieve_medical_summary(query, k=3):
      """
      Retrieve the most relevant medical literature from FAISS.
 
      # Retrieve the closest matching text using FAISS index IDs
      retrieved_docs = [id_to_text.get(int(idx), "No relevant data found.") for idx in I[0]]

+     # ✅ Ensure all retrieved texts are strings (Flatten lists if needed)
      retrieved_docs = [doc if isinstance(doc, str) else " ".join(doc) for doc in retrieved_docs]

+     # ✅ Join multiple retrieved documents into one response
      return "\n\n---\n\n".join(retrieved_docs) if retrieved_docs else "No relevant data found."

+
+ # ✅ Example Test
+ query = "What are the symptoms of pneumonia?"
+ retrieved_summary = retrieve_medical_summary(query, k=3)
+
+ print("📖 Retrieved Medical Summary:\n", retrieved_summary)
+
+ import os
+ from groq import Groq
+
+ # ✅ Store API Key in Environment Variable
+ os.environ["GROQ_API_KEY"] = "YOUR_GROQ_API_KEY"  # Replace with your actual key; never commit a real key
+
+ # ✅ Initialize Groq client correctly (Retrieve API key properly)
+ client = Groq(api_key=os.getenv("GROQ_API_KEY"))
+
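Rather than assigning the key in source, the client can be built from a secret that is already present in the environment; on Hugging Face Spaces a repository secret named GROQ_API_KEY is exposed to the app as an environment variable (the variable name follows this project's convention). A minimal sketch:

import os
from groq import Groq

api_key = os.getenv("GROQ_API_KEY")  # set as a Space secret or exported in the shell
if not api_key:
    raise RuntimeError("GROQ_API_KEY is not set; configure it as a secret instead of hardcoding it.")

client = Groq(api_key=api_key)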
  def generate_medical_answer_groq(query, model="llama-3.3-70b-versatile", max_tokens=500, temperature=0.3):
      """
      Generates a medical response using Groq's API with LLaMA 3.3-70B, after retrieving relevant literature from FAISS.

          str: The AI-generated medical advice.
      """

+     # ✅ Retrieve relevant medical literature from FAISS
      retrieved_summary = retrieve_medical_summary(query)
      print("\n🔍 Retrieved Medical Text for Query:", query)
      print(retrieved_summary, "\n")

          return "No relevant medical data found. Please consult a healthcare professional."

      try:
+         # ✅ Send request to Groq API
          response = client.chat.completions.create(
              model=model,
              messages=[

      except Exception as e:
          return f"Error generating response: {str(e)}"

+ # ✅ Example Usage
+ query = "What are the symptoms of pneumonia?"
+ print("🩺 AI-Generated Response:", generate_medical_answer_groq(query))
+
  # Gradio Interface
  def ask_medical_question(question):
      return generate_medical_answer_groq(question)
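The diff ends here; the Gradio wiring below this point is presumably unchanged by the commit and not shown. For reference, a minimal sketch of how `ask_medical_question` is typically exposed as a text-in/text-out app (the component choices and title are assumptions, not the Space's actual code):

import gradio as gr

# Hypothetical wiring: expose ask_medical_question as a simple question/answer interface.
demo = gr.Interface(
    fn=ask_medical_question,
    inputs=gr.Textbox(label="Ask a medical question"),
    outputs=gr.Textbox(label="Answer"),
    title="Medical RAG Assistant",
)

if __name__ == "__main__":
    demo.launch()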