fastx committed on
Commit
dd78d61
·
verified ·
1 Parent(s): 36417d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +521 -520
app.py CHANGED
@@ -1,520 +1,521 @@
1
- import os
2
- import json
3
- import glob
4
- from pathlib import Path
5
- import torch
6
- import streamlit as st
7
- from dotenv import load_dotenv
8
- from langchain_groq import ChatGroq
9
- from langchain_community.embeddings import HuggingFaceEmbeddings
10
- from langchain_text_splitters import RecursiveCharacterTextSplitter
11
- from langchain_community.vectorstores import FAISS
12
- from langchain_core.documents import Document
13
- from langchain_core.prompts import ChatPromptTemplate
14
- from langchain.chains import create_retrieval_chain
15
- from langchain.chains.combine_documents import create_stuff_documents_chain
16
- import numpy as np
17
- from sentence_transformers import util
18
- import time
19
-
20
- # Set device for model (CUDA if available)
21
- device = "cuda" if torch.cuda.is_available() else "cpu"
22
-
23
- # Load environment variables - works for both local and Hugging Face Spaces
24
- load_dotenv()
25
-
26
- # Set up the clinical assistant LLM
27
- # Try to get API key from Hugging Face Spaces secrets first, then fall back to .env file
28
- try:
29
- # For Hugging Face Spaces
30
- from huggingface_hub.inference_api import InferenceApi
31
- import os
32
- groq_api_key = os.environ.get('GROQ_API_KEY')
33
-
34
- # If not found in environment, try to get from st.secrets (Streamlit Cloud/Spaces)
35
- if not groq_api_key and hasattr(st, 'secrets') and 'GROQ_API_KEY' in st.secrets:
36
- groq_api_key = st.secrets['GROQ_API_KEY']
37
-
38
- if not groq_api_key:
39
- st.warning("API Key is not set in the secrets. Using a placeholder for UI demonstration.")
40
- # For UI demonstration without API key
41
- class MockLLM:
42
- def invoke(self, prompt):
43
- return {"answer": "This is a placeholder response. Please set up your GROQ_API_KEY to get real responses."}
44
- llm = MockLLM()
45
- else:
46
- llm = ChatGroq(groq_api_key=groq_api_key, model_name="llama-3.3-70b-versatile")
47
-
48
- except Exception as e:
49
- st.error(f"Error setting up LLM: {str(e)}")
50
- class MockLLM:
51
- def invoke(self, prompt):
52
- return {"answer": f"Error setting up LLM: {str(e)}. Please check your API key configuration."}
53
- llm = MockLLM()
54
-
55
- # Set up embeddings for clinical context (Bio_ClinicalBERT)
56
- embeddings = HuggingFaceEmbeddings(
57
- model_name="emilyalsentzer/Bio_ClinicalBERT",
58
- model_kwargs={"device": device}
59
- )
60
-
61
- def load_clinical_data():
62
- """Load both flowcharts and patient cases"""
63
- docs = []
64
-
65
- # Get the absolute path to the current script
66
- current_dir = os.path.dirname(os.path.abspath(__file__))
67
-
68
- # Try to handle potential errors with file loading
69
- try:
70
- # Load diagnosis flowcharts
71
- flowchart_dir = os.path.join(current_dir, "Diagnosis_flowchart")
72
- if os.path.exists(flowchart_dir):
73
- for fpath in glob.glob(os.path.join(flowchart_dir, "*.json")):
74
- try:
75
- with open(fpath, 'r', encoding='utf-8') as f:
76
- data = json.load(f)
77
- content = f"""
78
- DIAGNOSTIC FLOWCHART: {Path(fpath).stem}
79
- Diagnostic Path: {data.get('diagnostic', 'N/A')}
80
- Key Criteria: {data.get('knowledge', 'N/A')}
81
- """
82
- docs.append(Document(
83
- page_content=content,
84
- metadata={"source": fpath, "type": "flowchart"}
85
- ))
86
- except Exception as e:
87
- st.warning(f"Error loading flowchart file {fpath}: {str(e)}")
88
- else:
89
- st.warning(f"Flowchart directory not found at {flowchart_dir}")
90
-
91
- # Load patient cases
92
- finished_dir = os.path.join(current_dir, "Finished")
93
- if os.path.exists(finished_dir):
94
- for category_dir in glob.glob(os.path.join(finished_dir, "*")):
95
- if os.path.isdir(category_dir):
96
- for case_file in glob.glob(os.path.join(category_dir, "*.json")):
97
- try:
98
- with open(case_file, 'r', encoding='utf-8') as f:
99
- case_data = json.load(f)
100
- notes = "\n".join(
101
- f"{k}: {v}" for k, v in case_data.items() if k.startswith("input")
102
- )
103
- docs.append(Document(
104
- page_content=f"""
105
- PATIENT CASE: {Path(case_file).stem}
106
- Category: {Path(category_dir).name}
107
- Notes: {notes}
108
- """,
109
- metadata={"source": case_file, "type": "patient_case"}
110
- ))
111
- except Exception as e:
112
- st.warning(f"Error loading case file {case_file}: {str(e)}")
113
- else:
114
- st.warning(f"Finished directory not found at {finished_dir}")
115
-
116
- # If no documents were loaded, add a sample document for testing
117
- if not docs:
118
- st.warning("No clinical data files found. Using sample data for demonstration.")
119
- docs.append(Document(
120
- page_content="""SAMPLE CLINICAL DATA: This is sample data for demonstration purposes.
121
- This application requires clinical data files to be present in the correct directories.
122
- Please ensure the Diagnosis_flowchart and Finished directories exist with proper JSON files.""",
123
- metadata={"source": "sample", "type": "sample"}
124
- ))
125
- except Exception as e:
126
- st.error(f"Error loading clinical data: {str(e)}")
127
- # Add a fallback document
128
- docs.append(Document(
129
- page_content="Error loading clinical data. This is a fallback document for demonstration purposes.",
130
- metadata={"source": "error", "type": "error"}
131
- ))
132
- return docs
133
-
134
- def build_vectorstore():
135
- """Build and return the vectorstore using FAISS"""
136
- documents = load_clinical_data()
137
- splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
138
- splits = splitter.split_documents(documents)
139
- vectorstore = FAISS.from_documents(splits, embeddings)
140
- return vectorstore
141
-
142
- # Path for saving/loading the vectorstore
143
- def get_vectorstore_path():
144
- """Get the path for saving/loading the vectorstore"""
145
- current_dir = os.path.dirname(os.path.abspath(__file__))
146
- return os.path.join(current_dir, "vectorstore")
147
-
148
- # Initialize vectorstore with disk persistence
149
- @st.cache_resource(show_spinner="Loading clinical knowledge base...")
150
- def get_vectorstore():
151
- """Get or create the vectorstore with disk persistence"""
152
- vectorstore_path = get_vectorstore_path()
153
-
154
- # Try to load from disk first
155
- try:
156
- if os.path.exists(vectorstore_path):
157
- st.info("Loading vectorstore from disk...")
158
- # Set allow_dangerous_deserialization to True since we trust our own vectorstore files
159
- return FAISS.load_local(vectorstore_path, embeddings, allow_dangerous_deserialization=True)
160
- except Exception as e:
161
- st.warning(f"Could not load vectorstore from disk: {str(e)}. Building new vectorstore.")
162
-
163
- # If loading fails or doesn't exist, build a new one
164
- st.info("Building new vectorstore...")
165
- vectorstore = build_vectorstore()
166
-
167
- # Save to disk for future use
168
- try:
169
- os.makedirs(vectorstore_path, exist_ok=True)
170
- vectorstore.save_local(vectorstore_path)
171
- st.success("Vectorstore saved to disk for future use")
172
- except Exception as e:
173
- st.warning(f"Could not save vectorstore to disk: {str(e)}")
174
-
175
- return vectorstore
176
-
177
- def run_rag_chat(query, vectorstore):
178
- """Run the Retrieval-Augmented Generation (RAG) for clinical questions"""
179
- try:
180
- retriever = vectorstore.as_retriever()
181
-
182
- prompt_template = ChatPromptTemplate.from_template("""
183
- You are a clinical assistant AI. Based on the following clinical context, provide a reasoned and medically sound answer to the question.
184
-
185
- <context>
186
- {context}
187
- </context>
188
-
189
- Question: {input}
190
-
191
- Answer:
192
- """)
193
-
194
- retrieved_docs = retriever.invoke(query, k=3)
195
- retrieved_context = "\n".join([doc.page_content for doc in retrieved_docs])
196
-
197
- # Create document chain first
198
- document_chain = create_stuff_documents_chain(llm, prompt_template)
199
-
200
- # Then create retrieval chain
201
- chain = create_retrieval_chain(retriever, document_chain)
202
-
203
- # Invoke the chain
204
- response = chain.invoke({"input": query})
205
-
206
- # Add retrieved documents to response for transparency
207
- response["context"] = retrieved_docs
208
-
209
- return response
210
- except Exception as e:
211
- st.error(f"Error in RAG processing: {str(e)}")
212
- # Return a fallback response
213
- return {
214
- "answer": f"I encountered an error processing your query: {str(e)}",
215
- "context": [],
216
- "input": query
217
- }
218
-
219
- def calculate_hit_rate(retriever, query, expected_docs, k=3):
220
- """Calculate the hit rate for top-k retrieved documents"""
221
- retrieved_docs = retriever.get_relevant_documents(query, k=k)
222
- retrieved_contents = [doc.page_content for doc in retrieved_docs]
223
-
224
- hits = 0
225
- for expected in expected_docs:
226
- if any(expected in retrieved for retrieved in retrieved_contents):
227
- hits += 1
228
-
229
- return hits / len(expected_docs) if expected_docs else 0.0
230
-
231
- def evaluate_rag_response(response, embeddings):
232
- """Evaluate the RAG response for faithfulness and hit rate"""
233
- scores = {}
234
-
235
- # Faithfulness: Answer-Context Similarity
236
- answer_embed = embeddings.embed_query(response["answer"])
237
- context_embeds = [embeddings.embed_query(doc.page_content) for doc in response["context"]]
238
- similarities = [util.cos_sim(answer_embed, ctx_embed).item() for ctx_embed in context_embeds]
239
- scores["faithfulness"] = float(np.mean(similarities)) if similarities else 0.0
240
-
241
- # Custom Hit Rate Calculation
242
- retriever = response["retriever"]
243
- scores["hit_rate"] = calculate_hit_rate(
244
- retriever,
245
- query=response["input"],
246
- expected_docs=[doc.page_content for doc in response["context"]],
247
- k=3
248
- )
249
-
250
- return scores
251
-
252
- def main():
253
- """Main function to run the Streamlit app"""
254
- # Set page configuration
255
- st.set_page_config(
256
- page_title="DiReCT - Clinical Diagnostic Assistant",
257
- page_icon="🩺",
258
- layout="wide",
259
- initial_sidebar_state="expanded"
260
- )
261
-
262
- # Load vectorstore only once using session state
263
- if 'vectorstore' not in st.session_state:
264
- with st.spinner("Loading clinical knowledge base... This may take a minute."):
265
- try:
266
- st.session_state.vectorstore = get_vectorstore()
267
- # Use custom styled message without the success icon
268
- st.markdown("<div style='padding:10px 15px;background-color:rgba(40,167,69,0.2);border-radius:5px;border-left:5px solid rgba(40,167,69,0.8);'>Clinical knowledge base loaded successfully!</div>", unsafe_allow_html=True)
269
- except Exception as e:
270
- st.error(f"Error loading knowledge base: {str(e)}")
271
- st.session_state.vectorstore = None
272
-
273
- # Custom CSS for modern look with dark theme compatibility
274
- st.markdown("""
275
- <style>
276
- .stApp {max-width: 1200px; margin: 0 auto;}
277
- .css-18e3th9 {padding-top: 2rem;}
278
- .stButton>button {background-color: #3498db; color: white;}
279
- .stButton>button:hover {background-color: #2980b9;}
280
- .chat-message {border-radius: 10px; padding: 10px; margin-bottom: 10px;}
281
- .chat-message-user {background-color: rgba(52, 152, 219, 0.2); color: inherit;}
282
- .chat-message-assistant {background-color: rgba(240, 240, 240, 0.2); color: inherit;}
283
- .source-box {background-color: rgba(255, 255, 255, 0.1); color: inherit; border-radius: 5px; padding: 15px; margin-bottom: 10px; border-left: 5px solid #3498db;}
284
- .metrics-box {background-color: rgba(255, 255, 255, 0.1); color: inherit; border-radius: 5px; padding: 15px; margin-top: 20px;}
285
- .features-container {display: flex; flex-wrap: wrap; gap: 20px; justify-content: center; margin-top: 30px;}
286
- .feature-item {flex: 1 1 calc(50% - 20px); min-width: 300px; display: flex; align-items: center; padding: 20px; border-radius: 10px; background: linear-gradient(135deg, rgba(72, 126, 176, 0.1), rgba(72, 126, 176, 0.2)); transition: transform 0.3s, box-shadow 0.3s; border: 1px solid rgba(255, 255, 255, 0.1);}
287
- .feature-item:hover {transform: translateY(-5px); box-shadow: 0 10px 20px rgba(0, 0, 0, 0.1);}
288
- .feature-icon {width: 60px; height: 60px; border-radius: 50%; background: linear-gradient(135deg, #3498db, #2980b9); display: flex; align-items: center; justify-content: center; margin-right: 20px; box-shadow: 0 5px 15px rgba(52, 152, 219, 0.3);}
289
- .feature-icon i {font-size: 24px; color: white;}
290
- .feature-content {flex: 1;}
291
- .feature-content h3 {margin-top: 0; margin-bottom: 10px; color: inherit;}
292
- .feature-content p {margin: 0; font-size: 0.9em; color: inherit; opacity: 0.8;}
293
- .input-container {margin-bottom: 20px; padding: 15px; border-radius: 10px; background-color: rgba(255, 255, 255, 0.05); border: 1px solid rgba(255, 255, 255, 0.1);}
294
- </style>
295
- """, unsafe_allow_html=True)
296
-
297
- # App states
298
- if 'chat_history' not in st.session_state:
299
- st.session_state.chat_history = []
300
- if 'page' not in st.session_state:
301
- st.session_state.page = 'cover'
302
-
303
- # Sidebar
304
- with st.sidebar:
305
- st.image("https://img.icons8.com/color/96/000000/caduceus.png", width=80)
306
- st.title("DiReCT")
307
- st.markdown("### Diagnostic Reasoning for Clinical Text")
308
- st.markdown("---")
309
-
310
- if st.button("Home", key="home_btn"):
311
- st.session_state.page = 'cover'
312
- if st.button("Diagnostic Assistant", key="assistant_btn"):
313
- st.session_state.page = 'chat'
314
- if st.button("About", key="about_btn"):
315
- st.session_state.page = 'about'
316
-
317
- st.markdown("---")
318
- st.markdown("### Model Information")
319
- st.markdown("**Embedding Model:** Bio_ClinicalBERT")
320
- st.markdown("**LLM:** Llama-3.3-70B")
321
- st.markdown("**Vector Store:** FAISS")
322
-
323
- # Cover page
324
- if st.session_state.page == 'cover':
325
- # Hero section with animation
326
- col1, col2 = st.columns([2, 1])
327
- with col1:
328
- st.markdown("<h1 style='font-size:3.5em;'>DiReCT</h1>", unsafe_allow_html=True)
329
- st.markdown("<h2 style='font-size:1.8em;color:#3498db;'>Diagnostic Reasoning for Clinical Text</h2>", unsafe_allow_html=True)
330
- st.markdown("""<p style='font-size:1.2em;'>A powerful RAG-based clinical diagnostic assistant that leverages the MIMIC-IV-Ext dataset to provide accurate medical insights and diagnostic reasoning.</p>""", unsafe_allow_html=True)
331
-
332
- st.markdown("""<br>""", unsafe_allow_html=True)
333
- if st.button("Get Started", key="get_started"):
334
- st.session_state.page = 'chat'
335
- st.rerun()
336
-
337
- with col2:
338
- # Animated medical icon
339
- st.markdown("""
340
- <div style='display:flex;justify-content:center;align-items:center;height:100%;'>
341
- <img src="https://img.icons8.com/color/240/000000/healthcare-and-medical.png" style='max-width:90%;'>
342
- </div>
343
- """, unsafe_allow_html=True)
344
-
345
- # Modern Features section with Streamlit native components
346
- st.markdown("<br><br>", unsafe_allow_html=True)
347
- st.markdown("<h2 style='text-align:center;'>Key Features</h2>", unsafe_allow_html=True)
348
-
349
- # Create a 2x2 grid for features using Streamlit columns
350
- col1, col2 = st.columns(2)
351
-
352
- # Feature 1
353
- with col1:
354
- st.markdown("""
355
- <div style="background: linear-gradient(135deg, rgba(72, 126, 176, 0.1), rgba(72, 126, 176, 0.2));
356
- padding: 20px; border-radius: 10px; height: 100%;
357
- border: 1px solid rgba(255, 255, 255, 0.1); margin-bottom: 20px;">
358
- <h3>🔍 Intelligent Retrieval</h3>
359
- <p>Finds the most relevant clinical information from the MIMIC-IV-Ext dataset</p>
360
- </div>
361
- """, unsafe_allow_html=True)
362
-
363
- # Feature 2
364
- with col2:
365
- st.markdown("""
366
- <div style="background: linear-gradient(135deg, rgba(72, 126, 176, 0.1), rgba(72, 126, 176, 0.2));
367
- padding: 20px; border-radius: 10px; height: 100%;
368
- border: 1px solid rgba(255, 255, 255, 0.1); margin-bottom: 20px;">
369
- <h3>🧠 Advanced Reasoning</h3>
370
- <p>Applies clinical knowledge to generate accurate diagnostic insights</p>
371
- </div>
372
- """, unsafe_allow_html=True)
373
-
374
- # Feature 3
375
- with col1:
376
- st.markdown("""
377
- <div style="background: linear-gradient(135deg, rgba(72, 126, 176, 0.1), rgba(72, 126, 176, 0.2));
378
- padding: 20px; border-radius: 10px; height: 100%;
379
- border: 1px solid rgba(255, 255, 255, 0.1);">
380
- <h3>📄 Source Transparency</h3>
381
- <p>Provides references to all clinical sources used in generating responses</p>
382
- </div>
383
- """, unsafe_allow_html=True)
384
-
385
- # Feature 4
386
- with col2:
387
- st.markdown("""
388
- <div style="background: linear-gradient(135deg, rgba(72, 126, 176, 0.1), rgba(72, 126, 176, 0.2));
389
- padding: 20px; border-radius: 10px; height: 100%;
390
- border: 1px solid rgba(255, 255, 255, 0.1);">
391
- <h3>🌓 Dark/Light Theme Compatible</h3>
392
- <p>Optimized interface that works seamlessly in both dark and light themes</p>
393
- </div>
394
- """, unsafe_allow_html=True)
395
-
396
- # Chat interface
397
- elif st.session_state.page == 'chat':
398
- # Initialize session state for input if not exists
399
- if 'user_input' not in st.session_state:
400
- st.session_state.user_input = ""
401
-
402
- # Header with clear button
403
- col1, col2 = st.columns([3, 1])
404
- with col1:
405
- st.markdown("<h1>Clinical Diagnostic Assistant</h1>", unsafe_allow_html=True)
406
- with col2:
407
- # Add a clear button in the header
408
- if st.button("🗑️ Clear Chat"):
409
- st.session_state.chat_history = []
410
- st.session_state.user_input = ""
411
- st.rerun()
412
-
413
- st.markdown("Ask any clinical diagnostic question and get insights based on medical knowledge and patient cases.")
414
-
415
- # Fixed input area at the top
416
- with st.container():
417
- st.markdown("<div class='input-container'>", unsafe_allow_html=True)
418
- user_input = st.text_area("Ask a clinical question:", st.session_state.user_input, height=100, key="question_input")
419
- col1, col2 = st.columns([1, 5])
420
- with col1:
421
- submit_button = st.button("Submit")
422
- st.markdown("</div>", unsafe_allow_html=True)
423
-
424
- # Create a container for chat history
425
- chat_container = st.container()
426
-
427
- # Process query
428
- if submit_button and user_input:
429
- if st.session_state.vectorstore is None:
430
- st.error("Knowledge base not loaded. Please refresh the page and try again.")
431
- else:
432
- with st.spinner("Analyzing clinical data..."):
433
- try:
434
- # Add a small delay for UX
435
- time.sleep(0.5)
436
-
437
- # Run RAG
438
- response = run_rag_chat(user_input, st.session_state.vectorstore)
439
- response["retriever"] = st.session_state.vectorstore.as_retriever()
440
-
441
- # Clear previous chat history and only keep the current response
442
- st.session_state.chat_history = [(user_input, response)]
443
-
444
- # Clear the input field
445
- st.session_state.user_input = ""
446
-
447
- # Rerun to update UI
448
- st.rerun()
449
- except Exception as e:
450
- st.error(f"Error processing query: {str(e)}")
451
-
452
- # Display chat history in the container
453
- with chat_container:
454
- for i, (query, response) in enumerate(st.session_state.chat_history):
455
- st.markdown(f"<div class='chat-message chat-message-user'><b>🧑‍⚕️ You:</b> {query}</div>", unsafe_allow_html=True)
456
-
457
- st.markdown(f"<div class='chat-message chat-message-assistant'><b>🩺 DiReCT:</b> {response['answer']}</div>", unsafe_allow_html=True)
458
-
459
- with st.expander("View Sources"):
460
- for doc in response["context"]:
461
- st.markdown(f"<div class='source-box'>"
462
- f"<b>Source:</b> {Path(doc.metadata['source']).stem}<br>"
463
- f"<b>Type:</b> {doc.metadata['type']}<br>"
464
- f"<b>Content:</b> {doc.page_content[:300]}...</div>",
465
- unsafe_allow_html=True)
466
-
467
- # Show evaluation metrics if available
468
- try:
469
- eval_scores = evaluate_rag_response(response, embeddings)
470
- with st.expander("View Evaluation Metrics"):
471
- col1, col2 = st.columns(2)
472
- with col1:
473
- st.metric("Hit Rate (Top-3)", f"{eval_scores['hit_rate']:.2f}")
474
- with col2:
475
- st.metric("Faithfulness", f"{eval_scores['faithfulness']:.2f}")
476
- except Exception as e:
477
- st.warning(f"Evaluation metrics unavailable: {str(e)}")
478
-
479
- # About page
480
- elif st.session_state.page == 'about':
481
- st.markdown("<h1>About DiReCT</h1>", unsafe_allow_html=True)
482
-
483
- st.markdown("""
484
- ### Project Overview
485
-
486
- DiReCT (Diagnostic Reasoning for Clinical Text) is a Retrieval-Augmented Generation (RAG) system designed to assist medical professionals with diagnostic reasoning based on clinical notes and medical knowledge.
487
-
488
- ### Data Sources
489
-
490
- This application uses the MIMIC-IV-Ext dataset, which contains de-identified clinical notes and medical records. The system processes:
491
-
492
- - Diagnostic flowcharts
493
- - Patient cases
494
- - Clinical guidelines
495
-
496
- ### Technical Implementation
497
-
498
- - **Embedding Model**: Bio_ClinicalBERT for domain-specific text understanding
499
- - **Vector Database**: FAISS for efficient similarity search
500
- - **LLM**: Llama-3.3-70B for generating medically accurate responses
501
- - **Framework**: Built with LangChain and Streamlit
502
-
503
- ### Evaluation Metrics
504
-
505
- The system evaluates responses using:
506
-
507
- - **Hit Rate**: Measures how many relevant documents were retrieved
508
- - **Faithfulness**: Measures how well the response aligns with the retrieved context
509
-
510
- ### Ethical Considerations
511
-
512
- This system is designed as a clinical decision support tool and not as a replacement for professional medical judgment. All patient data used has been properly de-identified in compliance with healthcare privacy regulations.
513
- """)
514
-
515
- st.markdown("<br>", unsafe_allow_html=True)
516
- st.markdown("### Developers")
517
- st.markdown("This project was developed as part of an academic assignment on RAG systems for clinical applications.")
518
-
519
- if __name__ == "__main__":
520
- main()
 
 
1
+ import os
2
+ import json
3
+ import glob
4
+ from pathlib import Path
5
+ import torch
6
+ import streamlit as st
7
+ from dotenv import load_dotenv
8
+ from langchain_groq import ChatGroq
9
+ from langchain_community.embeddings import HuggingFaceEmbeddings
10
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
11
+ from langchain_community.vectorstores import FAISS
12
+ from langchain_core.documents import Document
13
+ from langchain_core.prompts import ChatPromptTemplate
14
+ from langchain.chains import create_retrieval_chain
15
+ from langchain.chains.combine_documents import create_stuff_documents_chain
16
+ import numpy as np
17
+ from sentence_transformers import util
18
+ import time
19
+
20
+ # Set device for model (CUDA if available)
21
+ device = "cuda" if torch.cuda.is_available() else "cpu"
22
+
23
+ # Load environment variables - works for both local and Hugging Face Spaces
24
+ load_dotenv()
25
+
26
+ # Set up the clinical assistant LLM
27
+ # Try to get API key from Hugging Face Spaces secrets first, then fall back to .env file
28
+ try:
29
+ # For Hugging Face Spaces
30
+ from huggingface_hub.inference_api import InferenceApi
31
+ import os
32
+ groq_api_key = os.environ.get('GROQ_API_KEY')
33
+
34
+ # If not found in environment, try to get from st.secrets (Streamlit Cloud/Spaces)
35
+ if not groq_api_key and hasattr(st, 'secrets') and 'GROQ_API_KEY' in st.secrets:
36
+ groq_api_key = st.secrets['GROQ_API_KEY']
37
+
38
+ if not groq_api_key:
39
+ st.warning("API Key is not set in the secrets. Using a placeholder for UI demonstration.")
40
+ # For UI demonstration without API key
41
+ class MockLLM:
42
+ def invoke(self, prompt):
43
+ return {"answer": "This is a placeholder response. Please set up your GROQ_API_KEY to get real responses."}
44
+ llm = MockLLM()
45
+ else:
46
+ llm = ChatGroq(groq_api_key=groq_api_key, model_name="llama-3.3-70b-versatile")
47
+
48
+ except Exception as e:
49
+ st.error(f"Error setting up LLM: {str(e)}")
50
+ class MockLLM:
51
+ def invoke(self, prompt):
52
+ return {"answer": f"Error setting up LLM: {str(e)}. Please check your API key configuration."}
53
+ llm = MockLLM()
54
+
55
+ # Set up embeddings for clinical context (Bio_ClinicalBERT)
56
+ embeddings = HuggingFaceEmbeddings(
57
+ model_name="emilyalsentzer/Bio_ClinicalBERT",
58
+ model_kwargs={"device": device}
59
+ )
60
+
61
+
62
+ def load_clinical_data():
63
+ """Load both flowcharts and patient cases"""
64
+ docs = []
65
+
66
+ # Get the absolute path to the current script
67
+ current_dir = os.path.dirname(os.path.abspath(__file__))
68
+
69
+ # Try to handle potential errors with file loading
70
+ try:
71
+ # Load diagnosis flowcharts
72
+ flowchart_dir = os.path.join(current_dir, "Diagnosis_flowchart")
73
+ if os.path.exists(flowchart_dir):
74
+ for fpath in glob.glob(os.path.join(flowchart_dir, "*.json")):
75
+ try:
76
+ with open(fpath, 'r', encoding='utf-8') as f:
77
+ data = json.load(f)
78
+ content = f"""
79
+ DIAGNOSTIC FLOWCHART: {Path(fpath).stem}
80
+ Diagnostic Path: {data.get('diagnostic', 'N/A')}
81
+ Key Criteria: {data.get('knowledge', 'N/A')}
82
+ """
83
+ docs.append(Document(
84
+ page_content=content,
85
+ metadata={"source": fpath, "type": "flowchart"}
86
+ ))
87
+ except Exception as e:
88
+ st.warning(f"Error loading flowchart file {fpath}: {str(e)}")
89
+ else:
90
+ st.warning(f"Flowchart directory not found at {flowchart_dir}")
91
+
92
+ # Load patient cases
93
+ finished_dir = os.path.join(current_dir, "Finished")
94
+ if os.path.exists(finished_dir):
95
+ for category_dir in glob.glob(os.path.join(finished_dir, "*")):
96
+ if os.path.isdir(category_dir):
97
+ for case_file in glob.glob(os.path.join(category_dir, "*.json")):
98
+ try:
99
+ with open(case_file, 'r', encoding='utf-8') as f:
100
+ case_data = json.load(f)
101
+ notes = "\n".join(
102
+ f"{k}: {v}" for k, v in case_data.items() if k.startswith("input")
103
+ )
104
+ docs.append(Document(
105
+ page_content=f"""
106
+ PATIENT CASE: {Path(case_file).stem}
107
+ Category: {Path(category_dir).name}
108
+ Notes: {notes}
109
+ """,
110
+ metadata={"source": case_file, "type": "patient_case"}
111
+ ))
112
+ except Exception as e:
113
+ st.warning(f"Error loading case file {case_file}: {str(e)}")
114
+ else:
115
+ st.warning(f"Finished directory not found at {finished_dir}")
116
+
117
+ # If no documents were loaded, add a sample document for testing
118
+ if not docs:
119
+ st.warning("No clinical data files found. Using sample data for demonstration.")
120
+ docs.append(Document(
121
+ page_content="""SAMPLE CLINICAL DATA: This is sample data for demonstration purposes.
122
+ This application requires clinical data files to be present in the correct directories.
123
+ Please ensure the Diagnosis_flowchart and Finished directories exist with proper JSON files.""",
124
+ metadata={"source": "sample", "type": "sample"}
125
+ ))
126
+ except Exception as e:
127
+ st.error(f"Error loading clinical data: {str(e)}")
128
+ # Add a fallback document
129
+ docs.append(Document(
130
+ page_content="Error loading clinical data. This is a fallback document for demonstration purposes.",
131
+ metadata={"source": "error", "type": "error"}
132
+ ))
133
+ return docs
134
+
135
+ def build_vectorstore():
136
+ """Build and return the vectorstore using FAISS"""
137
+ documents = load_clinical_data()
138
+ splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
139
+ splits = splitter.split_documents(documents)
140
+ vectorstore = FAISS.from_documents(splits, embeddings)
141
+ return vectorstore
142
+
143
+ # Path for saving/loading the vectorstore
144
+ def get_vectorstore_path():
145
+ """Get the path for saving/loading the vectorstore"""
146
+ current_dir = os.path.dirname(os.path.abspath(__file__))
147
+ return os.path.join(current_dir, "vectorstore")
148
+
149
+ # Initialize vectorstore with disk persistence
150
+ @st.cache_resource(show_spinner="Loading clinical knowledge base...")
151
+ def get_vectorstore():
152
+ """Get or create the vectorstore with disk persistence"""
153
+ vectorstore_path = get_vectorstore_path()
154
+
155
+ # Try to load from disk first
156
+ try:
157
+ if os.path.exists(vectorstore_path):
158
+ st.info("Loading vectorstore from disk...")
159
+ # Set allow_dangerous_deserialization to True since we trust our own vectorstore files
160
+ return FAISS.load_local(vectorstore_path, embeddings, allow_dangerous_deserialization=True)
161
+ except Exception as e:
162
+ st.warning(f"Could not load vectorstore from disk: {str(e)}. Building new vectorstore.")
163
+
164
+ # If loading fails or doesn't exist, build a new one
165
+ st.info("Building new vectorstore...")
166
+ vectorstore = build_vectorstore()
167
+
168
+ # Save to disk for future use
169
+ try:
170
+ os.makedirs(vectorstore_path, exist_ok=True)
171
+ vectorstore.save_local(vectorstore_path)
172
+ st.success("Vectorstore saved to disk for future use")
173
+ except Exception as e:
174
+ st.warning(f"Could not save vectorstore to disk: {str(e)}")
175
+
176
+ return vectorstore
177
+
178
+ def run_rag_chat(query, vectorstore):
179
+ """Run the Retrieval-Augmented Generation (RAG) for clinical questions"""
180
+ try:
181
+ retriever = vectorstore.as_retriever()
182
+
183
+ prompt_template = ChatPromptTemplate.from_template("""
184
+ You are a clinical assistant AI. Based on the following clinical context, provide a reasoned and medically sound answer to the question.
185
+
186
+ <context>
187
+ {context}
188
+ </context>
189
+
190
+ Question: {input}
191
+
192
+ Answer:
193
+ """)
194
+
195
+ retrieved_docs = retriever.invoke(query, k=3)
196
+ retrieved_context = "\n".join([doc.page_content for doc in retrieved_docs])
197
+
198
+ # Create document chain first
199
+ document_chain = create_stuff_documents_chain(llm, prompt_template)
200
+
201
+ # Then create retrieval chain
202
+ chain = create_retrieval_chain(retriever, document_chain)
203
+
204
+ # Invoke the chain
205
+ response = chain.invoke({"input": query})
206
+
207
+ # Add retrieved documents to response for transparency
208
+ response["context"] = retrieved_docs
209
+
210
+ return response
211
+ except Exception as e:
212
+ st.error(f"Error in RAG processing: {str(e)}")
213
+ # Return a fallback response
214
+ return {
215
+ "answer": f"I encountered an error processing your query: {str(e)}",
216
+ "context": [],
217
+ "input": query
218
+ }
219
+
220
+ def calculate_hit_rate(retriever, query, expected_docs, k=3):
221
+ """Calculate the hit rate for top-k retrieved documents"""
222
+ retrieved_docs = retriever.get_relevant_documents(query, k=k)
223
+ retrieved_contents = [doc.page_content for doc in retrieved_docs]
224
+
225
+ hits = 0
226
+ for expected in expected_docs:
227
+ if any(expected in retrieved for retrieved in retrieved_contents):
228
+ hits += 1
229
+
230
+ return hits / len(expected_docs) if expected_docs else 0.0
231
+
232
def evaluate_rag_response(response, embeddings):
    """Score a RAG response for faithfulness and top-3 hit rate.

    Args:
        response: Mapping with the keys ``"answer"`` (generated text),
            ``"context"`` (documents exposing ``page_content``) and ``"input"``
            (the query). An optional ``"retriever"`` entry enables the hit-rate
            computation.
        embeddings: Object exposing ``embed_query(text)`` that returns a
            numeric vector.

    Returns:
        Dict with two floats:
        ``"faithfulness"`` - mean cosine similarity between the answer
        embedding and each context-document embedding (``0.0`` without context);
        ``"hit_rate"`` - result of ``calculate_hit_rate`` with ``k=3``
        (``0.0`` when no retriever or no context is available).
    """
    scores = {"faithfulness": 0.0, "hit_rate": 0.0}
    context_docs = response.get("context") or []
    if not context_docs:
        # Nothing to compare against: skip the embedding and retrieval calls.
        return scores

    context_texts = [doc.page_content for doc in context_docs]

    # Faithfulness: answer/context cosine similarity, computed in one pass.
    answer_vec = np.asarray(embeddings.embed_query(response["answer"]), dtype=float)
    context_mat = np.asarray(
        [embeddings.embed_query(text) for text in context_texts],
        dtype=float,
    )
    dots = context_mat @ answer_vec
    norms = np.linalg.norm(context_mat, axis=1) * np.linalg.norm(answer_vec)
    # Zero-norm embeddings get similarity 0 instead of producing NaN.
    similarities = np.divide(dots, norms, out=np.zeros_like(dots), where=norms > 0)
    scores["faithfulness"] = float(similarities.mean())

    # Hit rate: only computable when the caller attached a retriever.
    # NOTE(review): the expected documents are the retrieved context itself, so
    # this checks that re-running retrieval returns the same documents rather
    # than relevance against independent ground truth.
    retriever = response.get("retriever")
    if retriever is not None:
        scores["hit_rate"] = calculate_hit_rate(
            retriever,
            query=response["input"],
            expected_docs=context_texts,
            k=3,
        )

    return scores
252
+
253
# Global stylesheet injected once per run; colors use rgba/inherit so the
# layout stays readable in both dark and light themes.
_CUSTOM_CSS = """
<style>
.stApp {max-width: 1200px; margin: 0 auto;}
.css-18e3th9 {padding-top: 2rem;}
.stButton>button {background-color: #3498db; color: white;}
.stButton>button:hover {background-color: #2980b9;}
.chat-message {border-radius: 10px; padding: 10px; margin-bottom: 10px;}
.chat-message-user {background-color: rgba(52, 152, 219, 0.2); color: inherit;}
.chat-message-assistant {background-color: rgba(240, 240, 240, 0.2); color: inherit;}
.source-box {background-color: rgba(255, 255, 255, 0.1); color: inherit; border-radius: 5px; padding: 15px; margin-bottom: 10px; border-left: 5px solid #3498db;}
.metrics-box {background-color: rgba(255, 255, 255, 0.1); color: inherit; border-radius: 5px; padding: 15px; margin-top: 20px;}
.features-container {display: flex; flex-wrap: wrap; gap: 20px; justify-content: center; margin-top: 30px;}
.feature-item {flex: 1 1 calc(50% - 20px); min-width: 300px; display: flex; align-items: center; padding: 20px; border-radius: 10px; background: linear-gradient(135deg, rgba(72, 126, 176, 0.1), rgba(72, 126, 176, 0.2)); transition: transform 0.3s, box-shadow 0.3s; border: 1px solid rgba(255, 255, 255, 0.1);}
.feature-item:hover {transform: translateY(-5px); box-shadow: 0 10px 20px rgba(0, 0, 0, 0.1);}
.feature-icon {width: 60px; height: 60px; border-radius: 50%; background: linear-gradient(135deg, #3498db, #2980b9); display: flex; align-items: center; justify-content: center; margin-right: 20px; box-shadow: 0 5px 15px rgba(52, 152, 219, 0.3);}
.feature-icon i {font-size: 24px; color: white;}
.feature-content {flex: 1;}
.feature-content h3 {margin-top: 0; margin-bottom: 10px; color: inherit;}
.feature-content p {margin: 0; font-size: 0.9em; color: inherit; opacity: 0.8;}
.input-container {margin-bottom: 20px; padding: 15px; border-radius: 10px; background-color: rgba(255, 255, 255, 0.05); border: 1px solid rgba(255, 255, 255, 0.1);}
</style>
"""


def _load_vectorstore_once():
    """Build the vector store on the first run and keep it in session state."""
    if 'vectorstore' in st.session_state:
        return
    with st.spinner("Loading clinical knowledge base... This may take a minute."):
        try:
            st.session_state.vectorstore = get_vectorstore()
            # Use custom styled message without the success icon
            st.markdown("<div style='padding:10px 15px;background-color:rgba(40,167,69,0.2);border-radius:5px;border-left:5px solid rgba(40,167,69,0.8);'>Clinical knowledge base loaded successfully!</div>", unsafe_allow_html=True)
        except Exception as e:
            st.error(f"Error loading knowledge base: {str(e)}")
            # Store None so the chat page can report the missing knowledge base.
            st.session_state.vectorstore = None


def _render_sidebar():
    """Draw the sidebar: branding, page navigation buttons and model info."""
    with st.sidebar:
        st.image("https://img.icons8.com/color/96/000000/caduceus.png", width=80)
        st.title("DiReCT")
        st.markdown("### Diagnostic Reasoning for Clinical Text")
        st.markdown("---")

        if st.button("Home", key="home_btn"):
            st.session_state.page = 'cover'
        if st.button("Diagnostic Assistant", key="assistant_btn"):
            st.session_state.page = 'chat'
        if st.button("About", key="about_btn"):
            st.session_state.page = 'about'

        st.markdown("---")
        st.markdown("### Model Information")
        st.markdown("**Embedding Model:** Bio_ClinicalBERT")
        st.markdown("**LLM:** Llama-3.3-70B")
        st.markdown("**Vector Store:** FAISS")


def _feature_card(title, description, *, bottom_margin=False):
    """Render one feature card; ``bottom_margin`` adds spacing below the card."""
    margin = " margin-bottom: 20px;" if bottom_margin else ""
    st.markdown(f"""
    <div style="background: linear-gradient(135deg, rgba(72, 126, 176, 0.1), rgba(72, 126, 176, 0.2));
    padding: 20px; border-radius: 10px; height: 100%;
    border: 1px solid rgba(255, 255, 255, 0.1);{margin}">
    <h3>{title}</h3>
    <p>{description}</p>
    </div>
    """, unsafe_allow_html=True)


def _render_cover_page():
    """Render the landing page: hero section plus a 2x2 grid of feature cards."""
    hero_col, image_col = st.columns([2, 1])
    with hero_col:
        st.markdown("<h1 style='font-size:3.5em;'>DiReCT</h1>", unsafe_allow_html=True)
        st.markdown("<h2 style='font-size:1.8em;color:#3498db;'>Diagnostic Reasoning for Clinical Text</h2>", unsafe_allow_html=True)
        st.markdown("""<p style='font-size:1.2em;'>A powerful RAG-based clinical diagnostic assistant that leverages the MIMIC-IV-Ext dataset to provide accurate medical insights and diagnostic reasoning.</p>""", unsafe_allow_html=True)

        st.markdown("""<br>""", unsafe_allow_html=True)
        if st.button("Get Started", key="get_started"):
            st.session_state.page = 'chat'
            st.rerun()

    with image_col:
        st.markdown("""
        <div style='display:flex;justify-content:center;align-items:center;height:100%;'>
        <img src="https://img.icons8.com/color/240/000000/healthcare-and-medical.png" style='max-width:90%;'>
        </div>
        """, unsafe_allow_html=True)

    st.markdown("<br><br>", unsafe_allow_html=True)
    st.markdown("<h2 style='text-align:center;'>Key Features</h2>", unsafe_allow_html=True)

    # Two columns, filled left/right/left/right to form a 2x2 grid.
    left_col, right_col = st.columns(2)
    with left_col:
        _feature_card(
            "🔍 Intelligent Retrieval",
            "Finds the most relevant clinical information from the MIMIC-IV-Ext dataset",
            bottom_margin=True,
        )
    with right_col:
        _feature_card(
            "🧠 Advanced Reasoning",
            "Applies clinical knowledge to generate accurate diagnostic insights",
            bottom_margin=True,
        )
    with left_col:
        _feature_card(
            "📄 Source Transparency",
            "Provides references to all clinical sources used in generating responses",
        )
    with right_col:
        _feature_card(
            "🌓 Dark/Light Theme Compatible",
            "Optimized interface that works seamlessly in both dark and light themes",
        )


def _answer_query(user_input):
    """Run the RAG pipeline for ``user_input`` and store the result as the chat history."""
    if st.session_state.vectorstore is None:
        st.error("Knowledge base not loaded. Please refresh the page and try again.")
        return

    with st.spinner("Analyzing clinical data..."):
        try:
            # Add a small delay for UX
            time.sleep(0.5)

            response = run_rag_chat(user_input, st.session_state.vectorstore)
            # Attached so the evaluation step can re-run retrieval later.
            response["retriever"] = st.session_state.vectorstore.as_retriever()
        except Exception as e:
            st.error(f"Error processing query: {str(e)}")
            return

    # Only the current exchange is kept; earlier history is discarded.
    st.session_state.chat_history = [(user_input, response)]
    st.session_state.user_input = ""

    # Kept outside the try block so the rerun signal is never treated as a
    # query-processing error.
    st.rerun()


def _render_chat_history():
    """Render every stored exchange with its sources and evaluation metrics."""
    import html  # local import: only needed for escaping text embedded in HTML

    for query, response in st.session_state.chat_history:
        # Query, answer and document text are not trusted markup, so they are
        # escaped before being placed inside unsafe_allow_html blocks.
        st.markdown(f"<div class='chat-message chat-message-user'><b>🧑‍⚕️ You:</b> {html.escape(str(query))}</div>", unsafe_allow_html=True)

        st.markdown(f"<div class='chat-message chat-message-assistant'><b>🩺 DiReCT:</b> {html.escape(str(response['answer']))}</div>", unsafe_allow_html=True)

        with st.expander("View Sources"):
            for doc in response["context"]:
                source_name = html.escape(Path(str(doc.metadata.get('source', 'unknown'))).stem)
                doc_type = html.escape(str(doc.metadata.get('type', 'unknown')))
                excerpt = html.escape(doc.page_content[:300])
                st.markdown(
                    "<div class='source-box'>"
                    f"<b>Source:</b> {source_name}<br>"
                    f"<b>Type:</b> {doc_type}<br>"
                    f"<b>Content:</b> {excerpt}...</div>",
                    unsafe_allow_html=True,
                )

        # Show evaluation metrics if available
        try:
            # Cache the scores on the response so reruns do not re-embed and
            # re-retrieve for an exchange that has already been evaluated.
            if "eval_scores" not in response:
                response["eval_scores"] = evaluate_rag_response(response, embeddings)
            eval_scores = response["eval_scores"]
            with st.expander("View Evaluation Metrics"):
                hit_col, faith_col = st.columns(2)
                with hit_col:
                    st.metric("Hit Rate (Top-3)", f"{eval_scores['hit_rate']:.2f}")
                with faith_col:
                    st.metric("Faithfulness", f"{eval_scores['faithfulness']:.2f}")
        except Exception as e:
            st.warning(f"Evaluation metrics unavailable: {str(e)}")


def _render_chat_page():
    """Render the assistant page: header, question input and chat history."""
    if 'user_input' not in st.session_state:
        st.session_state.user_input = ""

    # Header with clear button
    title_col, clear_col = st.columns([3, 1])
    with title_col:
        st.markdown("<h1>Clinical Diagnostic Assistant</h1>", unsafe_allow_html=True)
    with clear_col:
        if st.button("🗑️ Clear Chat"):
            st.session_state.chat_history = []
            st.session_state.user_input = ""
            st.rerun()

    st.markdown("Ask any clinical diagnostic question and get insights based on medical knowledge and patient cases.")

    # Input area at the top of the page
    with st.container():
        st.markdown("<div class='input-container'>", unsafe_allow_html=True)
        user_input = st.text_area("Ask a clinical question:", st.session_state.user_input, height=100, key="question_input")
        button_col, _spacer_col = st.columns([1, 5])
        with button_col:
            submit_button = st.button("Submit")
        st.markdown("</div>", unsafe_allow_html=True)

    # Reserve the history slot before processing so errors raised while
    # answering appear below it, as in the original layout.
    chat_container = st.container()

    if submit_button and user_input:
        _answer_query(user_input)

    with chat_container:
        _render_chat_history()


def _render_about_page():
    """Render the static About page."""
    st.markdown("<h1>About DiReCT</h1>", unsafe_allow_html=True)

    st.markdown("""
    ### Project Overview

    DiReCT (Diagnostic Reasoning for Clinical Text) is a Retrieval-Augmented Generation (RAG) system designed to assist medical professionals with diagnostic reasoning based on clinical notes and medical knowledge.

    ### Data Sources

    This application uses the MIMIC-IV-Ext dataset, which contains de-identified clinical notes and medical records. The system processes:

    - Diagnostic flowcharts
    - Patient cases
    - Clinical guidelines

    ### Technical Implementation

    - **Embedding Model**: Bio_ClinicalBERT for domain-specific text understanding
    - **Vector Database**: FAISS for efficient similarity search
    - **LLM**: Llama-3.3-70B for generating medically accurate responses
    - **Framework**: Built with LangChain and Streamlit

    ### Evaluation Metrics

    The system evaluates responses using:

    - **Hit Rate**: Measures how many relevant documents were retrieved
    - **Faithfulness**: Measures how well the response aligns with the retrieved context

    ### Ethical Considerations

    This system is designed as a clinical decision support tool and not as a replacement for professional medical judgment. All patient data used has been properly de-identified in compliance with healthcare privacy regulations.
    """)

    st.markdown("<br>", unsafe_allow_html=True)
    st.markdown("### Developers")
    st.markdown("This project was developed as part of an academic assignment on RAG systems for clinical applications.")


def main():
    """Run the Streamlit app.

    Configures the page, loads the vector store once per session, injects the
    stylesheet, draws the sidebar and then renders the page selected in
    ``st.session_state.page`` (``'cover'``, ``'chat'`` or ``'about'``).
    """
    st.set_page_config(
        page_title="DiReCT - Clinical Diagnostic Assistant",
        page_icon="🩺",
        layout="wide",
        initial_sidebar_state="expanded"
    )

    _load_vectorstore_once()

    st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)

    # App states
    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []
    if 'page' not in st.session_state:
        st.session_state.page = 'cover'

    # The sidebar runs first so a navigation click takes effect in this run.
    _render_sidebar()

    page = st.session_state.page
    if page == 'cover':
        _render_cover_page()
    elif page == 'chat':
        _render_chat_page()
    elif page == 'about':
        _render_about_page()
519
+
520
# Launch the Streamlit UI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()