Update app.py
app.py
CHANGED
@@ -1,520 +1,521 @@
import os
import json
import glob
from pathlib import Path
import torch
import streamlit as st
from dotenv import load_dotenv
from langchain_groq import ChatGroq
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
import numpy as np
from sentence_transformers import util
import time

# Set device for model (CUDA if available)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load environment variables - works for both local and Hugging Face Spaces
load_dotenv()

# Set up the clinical assistant LLM
# Look for the API key in the environment first (populated by load_dotenv locally
# or by Spaces secrets), then fall back to st.secrets
try:
    groq_api_key = os.environ.get('GROQ_API_KEY')

    # If not found in the environment, try st.secrets (Streamlit Cloud/Spaces)
    if not groq_api_key and hasattr(st, 'secrets') and 'GROQ_API_KEY' in st.secrets:
        groq_api_key = st.secrets['GROQ_API_KEY']

    if not groq_api_key:
        st.warning("API Key is not set in the secrets. Using a placeholder for UI demonstration.")

        # For UI demonstration without an API key
        class MockLLM:
            def invoke(self, prompt):
                return {"answer": "This is a placeholder response. Please set up your GROQ_API_KEY to get real responses."}

        llm = MockLLM()
    else:
        llm = ChatGroq(groq_api_key=groq_api_key, model_name="llama-3.3-70b-versatile")

except Exception as e:
    st.error(f"Error setting up LLM: {str(e)}")
    # Capture the message now: `e` is unbound once the except block exits
    error_message = str(e)

    class MockLLM:
        def invoke(self, prompt):
            return {"answer": f"Error setting up LLM: {error_message}. Please check your API key configuration."}

    llm = MockLLM()

# Set up embeddings for clinical context (Bio_ClinicalBERT)
embeddings = HuggingFaceEmbeddings(
    model_name="emilyalsentzer/Bio_ClinicalBERT",
    model_kwargs={"device": device}
)

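# Clinical knowledge is read from two directories that sit next to this script:
#   Diagnosis_flowchart/*.json   - diagnostic flowcharts
#   Finished/<category>/*.json   - de-identified patient cases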
def load_clinical_data():
    """Load both flowcharts and patient cases"""
    docs = []

    # Get the absolute path to the current script
    current_dir = os.path.dirname(os.path.abspath(__file__))

    # Try to handle potential errors with file loading
    try:
        # Load diagnosis flowcharts
        flowchart_dir = os.path.join(current_dir, "Diagnosis_flowchart")
        if os.path.exists(flowchart_dir):
            for fpath in glob.glob(os.path.join(flowchart_dir, "*.json")):
                try:
                    with open(fpath, 'r', encoding='utf-8') as f:
                        data = json.load(f)
                        content = f"""
                        DIAGNOSTIC FLOWCHART: {Path(fpath).stem}
                        Diagnostic Path: {data.get('diagnostic', 'N/A')}
                        Key Criteria: {data.get('knowledge', 'N/A')}
                        """
                        docs.append(Document(
                            page_content=content,
                            metadata={"source": fpath, "type": "flowchart"}
                        ))
                except Exception as e:
                    st.warning(f"Error loading flowchart file {fpath}: {str(e)}")
        else:
            st.warning(f"Flowchart directory not found at {flowchart_dir}")

        # Load patient cases
        finished_dir = os.path.join(current_dir, "Finished")
        if os.path.exists(finished_dir):
            for category_dir in glob.glob(os.path.join(finished_dir, "*")):
                if os.path.isdir(category_dir):
                    for case_file in glob.glob(os.path.join(category_dir, "*.json")):
                        try:
                            with open(case_file, 'r', encoding='utf-8') as f:
                                case_data = json.load(f)
                                notes = "\n".join(
                                    f"{k}: {v}" for k, v in case_data.items() if k.startswith("input")
                                )
                                docs.append(Document(
                                    page_content=f"""
                                    PATIENT CASE: {Path(case_file).stem}
                                    Category: {Path(category_dir).name}
                                    Notes: {notes}
                                    """,
                                    metadata={"source": case_file, "type": "patient_case"}
                                ))
                        except Exception as e:
                            st.warning(f"Error loading case file {case_file}: {str(e)}")
        else:
            st.warning(f"Finished directory not found at {finished_dir}")

        # If no documents were loaded, add a sample document for testing
        if not docs:
            st.warning("No clinical data files found. Using sample data for demonstration.")
            docs.append(Document(
                page_content="""SAMPLE CLINICAL DATA: This is sample data for demonstration purposes.
                This application requires clinical data files to be present in the correct directories.
                Please ensure the Diagnosis_flowchart and Finished directories exist with proper JSON files.""",
                metadata={"source": "sample", "type": "sample"}
            ))
    except Exception as e:
        st.error(f"Error loading clinical data: {str(e)}")
        # Add a fallback document
        docs.append(Document(
            page_content="Error loading clinical data. This is a fallback document for demonstration purposes.",
            metadata={"source": "error", "type": "error"}
        ))
    return docs

def build_vectorstore():
    """Build and return the vectorstore using FAISS"""
    documents = load_clinical_data()
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    splits = splitter.split_documents(documents)
    vectorstore = FAISS.from_documents(splits, embeddings)
    return vectorstore

# Path for saving/loading the vectorstore
def get_vectorstore_path():
    """Get the path for saving/loading the vectorstore"""
    current_dir = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(current_dir, "vectorstore")

# Initialize vectorstore with disk persistence
@st.cache_resource(show_spinner="Loading clinical knowledge base...")
def get_vectorstore():
    """Get or create the vectorstore with disk persistence"""
    vectorstore_path = get_vectorstore_path()

    # Try to load from disk first
    try:
        if os.path.exists(vectorstore_path):
            st.info("Loading vectorstore from disk...")
            # Set allow_dangerous_deserialization to True since we trust our own vectorstore files
            return FAISS.load_local(vectorstore_path, embeddings, allow_dangerous_deserialization=True)
    except Exception as e:
        st.warning(f"Could not load vectorstore from disk: {str(e)}. Building new vectorstore.")

    # If loading fails or the store doesn't exist, build a new one
    st.info("Building new vectorstore...")
    vectorstore = build_vectorstore()

    # Save to disk for future use
    try:
        os.makedirs(vectorstore_path, exist_ok=True)
        vectorstore.save_local(vectorstore_path)
        st.success("Vectorstore saved to disk for future use")
    except Exception as e:
        st.warning(f"Could not save vectorstore to disk: {str(e)}")

    return vectorstore

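# RAG flow: retrieve the most similar chunks from the FAISS store, stuff them
# into the prompt template below, and have the LLM answer from that context.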
def run_rag_chat(query, vectorstore):
    """Run Retrieval-Augmented Generation (RAG) for clinical questions"""
    try:
        # Retrieve the top-3 chunks per query (k is set on the retriever itself)
        retriever = vectorstore.as_retriever(search_kwargs={"k": 3})

        prompt_template = ChatPromptTemplate.from_template("""
        You are a clinical assistant AI. Based on the following clinical context, provide a reasoned and medically sound answer to the question.

        <context>
        {context}
        </context>

        Question: {input}

        Answer:
        """)

        retrieved_docs = retriever.invoke(query)

        # Create document chain first
        document_chain = create_stuff_documents_chain(llm, prompt_template)

        # Then create retrieval chain
        chain = create_retrieval_chain(retriever, document_chain)

        # Invoke the chain
        response = chain.invoke({"input": query})

        # Expose the retrieved documents in the response for transparency
        response["context"] = retrieved_docs

        return response
    except Exception as e:
        st.error(f"Error in RAG processing: {str(e)}")
        # Return a fallback response
        return {
            "answer": f"I encountered an error processing your query: {str(e)}",
            "context": [],
            "input": query
        }

+
def calculate_hit_rate(retriever, query, expected_docs, k=3):
|
221 |
+
"""Calculate the hit rate for top-k retrieved documents"""
|
222 |
+
retrieved_docs = retriever.get_relevant_documents(query, k=k)
|
223 |
+
retrieved_contents = [doc.page_content for doc in retrieved_docs]
|
224 |
+
|
225 |
+
hits = 0
|
226 |
+
for expected in expected_docs:
|
227 |
+
if any(expected in retrieved for retrieved in retrieved_contents):
|
228 |
+
hits += 1
|
229 |
+
|
230 |
+
return hits / len(expected_docs) if expected_docs else 0.0
|
231 |
+
|
232 |
+
def evaluate_rag_response(response, embeddings):
|
233 |
+
"""Evaluate the RAG response for faithfulness and hit rate"""
|
234 |
+
scores = {}
|
235 |
+
|
236 |
+
# Faithfulness: Answer-Context Similarity
|
237 |
+
answer_embed = embeddings.embed_query(response["answer"])
|
238 |
+
context_embeds = [embeddings.embed_query(doc.page_content) for doc in response["context"]]
|
239 |
+
similarities = [util.cos_sim(answer_embed, ctx_embed).item() for ctx_embed in context_embeds]
|
240 |
+
scores["faithfulness"] = float(np.mean(similarities)) if similarities else 0.0
|
241 |
+
|
242 |
+
# Custom Hit Rate Calculation
|
243 |
+
retriever = response["retriever"]
|
244 |
+
scores["hit_rate"] = calculate_hit_rate(
|
245 |
+
retriever,
|
246 |
+
query=response["input"],
|
247 |
+
expected_docs=[doc.page_content for doc in response["context"]],
|
248 |
+
k=3
|
249 |
+
)
|
250 |
+
|
251 |
+
return scores
|
252 |
+
|
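# Streamlit UI: cover page, chat-based diagnostic assistant, and About page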
def main():
    """Main function to run the Streamlit app"""
    # Set page configuration
    st.set_page_config(
        page_title="DiReCT - Clinical Diagnostic Assistant",
        page_icon="🩺",
        layout="wide",
        initial_sidebar_state="expanded"
    )

    # Load vectorstore only once using session state
    if 'vectorstore' not in st.session_state:
        with st.spinner("Loading clinical knowledge base... This may take a minute."):
            try:
                st.session_state.vectorstore = get_vectorstore()
                # Use custom styled message without the success icon
                st.markdown("<div style='padding:10px 15px;background-color:rgba(40,167,69,0.2);border-radius:5px;border-left:5px solid rgba(40,167,69,0.8);'>Clinical knowledge base loaded successfully!</div>", unsafe_allow_html=True)
            except Exception as e:
                st.error(f"Error loading knowledge base: {str(e)}")
                st.session_state.vectorstore = None

    # Custom CSS for modern look with dark theme compatibility
    st.markdown("""
    <style>
    .stApp {max-width: 1200px; margin: 0 auto;}
    .css-18e3th9 {padding-top: 2rem;}
    .stButton>button {background-color: #3498db; color: white;}
    .stButton>button:hover {background-color: #2980b9;}
    .chat-message {border-radius: 10px; padding: 10px; margin-bottom: 10px;}
    .chat-message-user {background-color: rgba(52, 152, 219, 0.2); color: inherit;}
    .chat-message-assistant {background-color: rgba(240, 240, 240, 0.2); color: inherit;}
    .source-box {background-color: rgba(255, 255, 255, 0.1); color: inherit; border-radius: 5px; padding: 15px; margin-bottom: 10px; border-left: 5px solid #3498db;}
    .metrics-box {background-color: rgba(255, 255, 255, 0.1); color: inherit; border-radius: 5px; padding: 15px; margin-top: 20px;}
    .features-container {display: flex; flex-wrap: wrap; gap: 20px; justify-content: center; margin-top: 30px;}
    .feature-item {flex: 1 1 calc(50% - 20px); min-width: 300px; display: flex; align-items: center; padding: 20px; border-radius: 10px; background: linear-gradient(135deg, rgba(72, 126, 176, 0.1), rgba(72, 126, 176, 0.2)); transition: transform 0.3s, box-shadow 0.3s; border: 1px solid rgba(255, 255, 255, 0.1);}
    .feature-item:hover {transform: translateY(-5px); box-shadow: 0 10px 20px rgba(0, 0, 0, 0.1);}
    .feature-icon {width: 60px; height: 60px; border-radius: 50%; background: linear-gradient(135deg, #3498db, #2980b9); display: flex; align-items: center; justify-content: center; margin-right: 20px; box-shadow: 0 5px 15px rgba(52, 152, 219, 0.3);}
    .feature-icon i {font-size: 24px; color: white;}
    .feature-content {flex: 1;}
    .feature-content h3 {margin-top: 0; margin-bottom: 10px; color: inherit;}
    .feature-content p {margin: 0; font-size: 0.9em; color: inherit; opacity: 0.8;}
    .input-container {margin-bottom: 20px; padding: 15px; border-radius: 10px; background-color: rgba(255, 255, 255, 0.05); border: 1px solid rgba(255, 255, 255, 0.1);}
    </style>
    """, unsafe_allow_html=True)

    # App states
    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []
    if 'page' not in st.session_state:
        st.session_state.page = 'cover'

    # Sidebar
    with st.sidebar:
        st.image("https://img.icons8.com/color/96/000000/caduceus.png", width=80)
        st.title("DiReCT")
        st.markdown("### Diagnostic Reasoning for Clinical Text")
        st.markdown("---")

        if st.button("Home", key="home_btn"):
            st.session_state.page = 'cover'
        if st.button("Diagnostic Assistant", key="assistant_btn"):
            st.session_state.page = 'chat'
        if st.button("About", key="about_btn"):
            st.session_state.page = 'about'

        st.markdown("---")
        st.markdown("### Model Information")
        st.markdown("**Embedding Model:** Bio_ClinicalBERT")
        st.markdown("**LLM:** Llama-3.3-70B")
        st.markdown("**Vector Store:** FAISS")

    # Cover page
    if st.session_state.page == 'cover':
        # Hero section with animation
        col1, col2 = st.columns([2, 1])
        with col1:
            st.markdown("<h1 style='font-size:3.5em;'>DiReCT</h1>", unsafe_allow_html=True)
            st.markdown("<h2 style='font-size:1.8em;color:#3498db;'>Diagnostic Reasoning for Clinical Text</h2>", unsafe_allow_html=True)
            st.markdown("""<p style='font-size:1.2em;'>A powerful RAG-based clinical diagnostic assistant that leverages the MIMIC-IV-Ext dataset to provide accurate medical insights and diagnostic reasoning.</p>""", unsafe_allow_html=True)

            st.markdown("""<br>""", unsafe_allow_html=True)
            if st.button("Get Started", key="get_started"):
                st.session_state.page = 'chat'
                st.rerun()

        with col2:
            # Animated medical icon
            st.markdown("""
            <div style='display:flex;justify-content:center;align-items:center;height:100%;'>
                <img src="https://img.icons8.com/color/240/000000/healthcare-and-medical.png" style='max-width:90%;'>
            </div>
            """, unsafe_allow_html=True)

        # Modern Features section with Streamlit native components
        st.markdown("<br><br>", unsafe_allow_html=True)
        st.markdown("<h2 style='text-align:center;'>Key Features</h2>", unsafe_allow_html=True)

        # Create a 2x2 grid for features using Streamlit columns
        col1, col2 = st.columns(2)

        # Feature 1
        with col1:
            st.markdown("""
            <div style="background: linear-gradient(135deg, rgba(72, 126, 176, 0.1), rgba(72, 126, 176, 0.2));
                        padding: 20px; border-radius: 10px; height: 100%;
                        border: 1px solid rgba(255, 255, 255, 0.1); margin-bottom: 20px;">
                <h3>🔍 Intelligent Retrieval</h3>
                <p>Finds the most relevant clinical information from the MIMIC-IV-Ext dataset</p>
            </div>
            """, unsafe_allow_html=True)

        # Feature 2
        with col2:
            st.markdown("""
            <div style="background: linear-gradient(135deg, rgba(72, 126, 176, 0.1), rgba(72, 126, 176, 0.2));
                        padding: 20px; border-radius: 10px; height: 100%;
                        border: 1px solid rgba(255, 255, 255, 0.1); margin-bottom: 20px;">
                <h3>🧠 Advanced Reasoning</h3>
                <p>Applies clinical knowledge to generate accurate diagnostic insights</p>
            </div>
            """, unsafe_allow_html=True)

        # Feature 3
        with col1:
            st.markdown("""
            <div style="background: linear-gradient(135deg, rgba(72, 126, 176, 0.1), rgba(72, 126, 176, 0.2));
                        padding: 20px; border-radius: 10px; height: 100%;
                        border: 1px solid rgba(255, 255, 255, 0.1);">
                <h3>📄 Source Transparency</h3>
                <p>Provides references to all clinical sources used in generating responses</p>
            </div>
            """, unsafe_allow_html=True)

        # Feature 4
        with col2:
            st.markdown("""
            <div style="background: linear-gradient(135deg, rgba(72, 126, 176, 0.1), rgba(72, 126, 176, 0.2));
                        padding: 20px; border-radius: 10px; height: 100%;
                        border: 1px solid rgba(255, 255, 255, 0.1);">
                <h3>🌓 Dark/Light Theme Compatible</h3>
                <p>Optimized interface that works seamlessly in both dark and light themes</p>
            </div>
            """, unsafe_allow_html=True)

+
# Chat interface
|
398 |
+
elif st.session_state.page == 'chat':
|
399 |
+
# Initialize session state for input if not exists
|
400 |
+
if 'user_input' not in st.session_state:
|
401 |
+
st.session_state.user_input = ""
|
402 |
+
|
403 |
+
# Header with clear button
|
404 |
+
col1, col2 = st.columns([3, 1])
|
405 |
+
with col1:
|
406 |
+
st.markdown("<h1>Clinical Diagnostic Assistant</h1>", unsafe_allow_html=True)
|
407 |
+
with col2:
|
408 |
+
# Add a clear button in the header
|
409 |
+
if st.button("ποΈ Clear Chat"):
|
410 |
+
st.session_state.chat_history = []
|
411 |
+
st.session_state.user_input = ""
|
412 |
+
st.rerun()
|
413 |
+
|
414 |
+
st.markdown("Ask any clinical diagnostic question and get insights based on medical knowledge and patient cases.")
|
415 |
+
|
416 |
+
# Fixed input area at the top
|
417 |
+
with st.container():
|
418 |
+
st.markdown("<div class='input-container'>", unsafe_allow_html=True)
|
419 |
+
user_input = st.text_area("Ask a clinical question:", st.session_state.user_input, height=100, key="question_input")
|
420 |
+
col1, col2 = st.columns([1, 5])
|
421 |
+
with col1:
|
422 |
+
submit_button = st.button("Submit")
|
423 |
+
st.markdown("</div>", unsafe_allow_html=True)
|
424 |
+
|
425 |
+
# Create a container for chat history
|
426 |
+
chat_container = st.container()
|
427 |
+
|
428 |
+
# Process query
|
429 |
+
if submit_button and user_input:
|
430 |
+
if st.session_state.vectorstore is None:
|
431 |
+
st.error("Knowledge base not loaded. Please refresh the page and try again.")
|
432 |
+
else:
|
433 |
+
with st.spinner("Analyzing clinical data..."):
|
434 |
+
try:
|
435 |
+
# Add a small delay for UX
|
436 |
+
time.sleep(0.5)
|
437 |
+
|
438 |
+
# Run RAG
|
439 |
+
response = run_rag_chat(user_input, st.session_state.vectorstore)
|
440 |
+
response["retriever"] = st.session_state.vectorstore.as_retriever()
|
441 |
+
|
442 |
+
# Clear previous chat history and only keep the current response
|
443 |
+
st.session_state.chat_history = [(user_input, response)]
|
444 |
+
|
445 |
+
# Clear the input field
|
446 |
+
st.session_state.user_input = ""
|
447 |
+
|
448 |
+
# Rerun to update UI
|
449 |
+
st.rerun()
|
450 |
+
except Exception as e:
|
451 |
+
st.error(f"Error processing query: {str(e)}")
|
452 |
+
|
453 |
+
# Display chat history in the container
|
454 |
+
with chat_container:
|
455 |
+
for i, (query, response) in enumerate(st.session_state.chat_history):
|
456 |
+
st.markdown(f"<div class='chat-message chat-message-user'><b>π§ββοΈ You:</b> {query}</div>", unsafe_allow_html=True)
|
457 |
+
|
458 |
+
st.markdown(f"<div class='chat-message chat-message-assistant'><b>π©Ί DiReCT:</b> {response['answer']}</div>", unsafe_allow_html=True)
|
459 |
+
|
460 |
+
with st.expander("View Sources"):
|
461 |
+
for doc in response["context"]:
|
462 |
+
st.markdown(f"<div class='source-box'>"
|
463 |
+
f"<b>Source:</b> {Path(doc.metadata['source']).stem}<br>"
|
464 |
+
f"<b>Type:</b> {doc.metadata['type']}<br>"
|
465 |
+
f"<b>Content:</b> {doc.page_content[:300]}...</div>",
|
466 |
+
unsafe_allow_html=True)
|
467 |
+
|
468 |
+
# Show evaluation metrics if available
|
469 |
+
try:
|
470 |
+
eval_scores = evaluate_rag_response(response, embeddings)
|
471 |
+
with st.expander("View Evaluation Metrics"):
|
472 |
+
col1, col2 = st.columns(2)
|
473 |
+
with col1:
|
474 |
+
st.metric("Hit Rate (Top-3)", f"{eval_scores['hit_rate']:.2f}")
|
475 |
+
with col2:
|
476 |
+
st.metric("Faithfulness", f"{eval_scores['faithfulness']:.2f}")
|
477 |
+
except Exception as e:
|
478 |
+
st.warning(f"Evaluation metrics unavailable: {str(e)}")
|
479 |
+
|
    # About page
    elif st.session_state.page == 'about':
        st.markdown("<h1>About DiReCT</h1>", unsafe_allow_html=True)

        st.markdown("""
        ### Project Overview

        DiReCT (Diagnostic Reasoning for Clinical Text) is a Retrieval-Augmented Generation (RAG) system designed to assist medical professionals with diagnostic reasoning based on clinical notes and medical knowledge.

        ### Data Sources

        This application uses the MIMIC-IV-Ext dataset, which contains de-identified clinical notes and medical records. The system processes:

        - Diagnostic flowcharts
        - Patient cases
        - Clinical guidelines

        ### Technical Implementation

        - **Embedding Model**: Bio_ClinicalBERT for domain-specific text understanding
        - **Vector Database**: FAISS for efficient similarity search
        - **LLM**: Llama-3.3-70B for generating medically accurate responses
        - **Framework**: Built with LangChain and Streamlit

        ### Evaluation Metrics

        The system evaluates responses using:

        - **Hit Rate**: Measures how many relevant documents were retrieved
        - **Faithfulness**: Measures how well the response aligns with the retrieved context

        ### Ethical Considerations

        This system is designed as a clinical decision support tool and not as a replacement for professional medical judgment. All patient data used has been properly de-identified in compliance with healthcare privacy regulations.
        """)

        st.markdown("<br>", unsafe_allow_html=True)
        st.markdown("### Developers")
        st.markdown("This project was developed as part of an academic assignment on RAG systems for clinical applications.")

if __name__ == "__main__":
    main()