Ali2206 committed on
Commit
d8282f1
·
verified ·
1 Parent(s): e44a01b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +327 -72
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import sys
2
  import os
3
- import pandas as pd
4
  import json
5
  import gradio as gr
6
  from typing import List, Tuple
@@ -8,7 +8,18 @@ import hashlib
8
  import shutil
9
  import re
10
  from datetime import datetime
 
 
 
 
 
 
11
 
 
 
 
 
 
12
  persistent_dir = "/data/hf_cache"
13
  os.makedirs(persistent_dir, exist_ok=True)
14
 
@@ -29,35 +40,101 @@ sys.path.insert(0, src_path)
29
 
30
  from txagent.txagent import TxAgent
31
 
 
 
 
32
  def file_hash(path: str) -> str:
 
33
  with open(path, "rb") as f:
34
  return hashlib.md5(f.read()).hexdigest()
35
 
36
  def clean_response(text: str) -> str:
 
37
  try:
38
  text = text.encode('utf-8', 'surrogatepass').decode('utf-8')
39
  except UnicodeError:
40
  text = text.encode('utf-8', 'replace').decode('utf-8')
 
41
  text = re.sub(r"\[.*?\]|\bNone\b", "", text, flags=re.DOTALL)
42
  text = re.sub(r"\n{3,}", "\n\n", text)
43
  text = re.sub(r"[^\n#\-\*\w\s\.,:\(\)]+", "", text)
44
  return text.strip()
45
 
46
- def parse_excel_as_whole_prompt(file_path: str) -> str:
47
- xl = pd.ExcelFile(file_path)
48
- df = xl.parse(xl.sheet_names[0], header=0).fillna("")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  records = []
50
- for _, row in df.iterrows():
51
- record = f"- {row['Form Name']}: {row['Form Item']} = {row['Item Response']} ({row['Interview Date']} by {row['Interviewer']})\n{row['Description']}"
52
- records.append(clean_response(record))
 
 
 
 
 
 
53
  record_text = "\n".join(records)
54
  prompt = f"""
55
- Patient Complete History:
56
 
57
  Instructions:
58
- Based on the complete patient record below, identify any potential missed diagnoses, medication conflicts, incomplete assessments, and urgent follow-up needs. Provide a clinical summary under the markdown headings.
59
 
60
- Patient History:
61
  {record_text}
62
 
63
  ### Missed Diagnoses
@@ -75,78 +152,256 @@ Patient History:
75
  return prompt
76
 
77
  def init_agent():
 
78
  default_tool_path = os.path.abspath("data/new_tool.json")
79
  target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
 
80
  if not os.path.exists(target_tool_path):
81
  shutil.copy(default_tool_path, target_tool_path)
82
- agent = TxAgent(
83
- model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
84
- rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
85
- tool_files_dict={"new_tool": target_tool_path},
86
- force_finish=True,
87
- enable_checker=True,
88
- step_rag_num=4,
89
- seed=100,
90
- additional_default_tools=[],
91
- )
92
- agent.init_model()
93
- return agent
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
  def create_ui(agent):
96
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
97
- gr.Markdown("<h1 style='text-align: center;'>🏥 Full Medical History Analyzer</h1>")
98
- chatbot = gr.Chatbot(label="Summary Output", height=600)
99
- file_upload = gr.File(label="Upload Excel File", file_types=[".xlsx"], file_count="single")
100
- msg_input = gr.Textbox(label="Optional Message", placeholder="Add context or instructions...", lines=2)
101
- send_btn = gr.Button("Analyze")
102
- download_output = gr.File(label="Download Report")
103
-
104
- def analyze(message: str, chat_history: List[Tuple[str, str]], file) -> Tuple[List[Tuple[str, str]], str]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  if not file:
106
- raise gr.Error("Please upload an Excel file.")
107
- new_history = chat_history + [(message, None)]
108
- new_history.append((None, "⏳ Analyzing full patient history..."))
109
- yield new_history, None
110
-
111
  try:
112
- prompt = parse_excel_as_whole_prompt(file.name)
113
- full_output = ""
114
- for result in agent.run_gradio_chat(
115
- message=prompt,
116
- history=[],
117
- temperature=0.2,
118
- max_new_tokens=2048,
119
- max_token=4096,
120
- call_agent=False,
121
- conversation=[],
122
- ):
123
- if isinstance(result, list):
124
- for r in result:
125
- if hasattr(r, 'content') and r.content:
126
- full_output += clean_response(r.content) + "\n"
127
- elif isinstance(result, str):
128
- full_output += clean_response(result) + "\n"
129
-
130
- new_history[-1] = (None, full_output.strip())
131
- report_path = os.path.join(report_dir, f"{file_hash(file.name)}_final_report.txt")
132
- with open(report_path, "w", encoding="utf-8") as f:
133
- f.write(full_output.strip())
134
- yield new_history, report_path
135
  except Exception as e:
136
- new_history.append((None, f" Error during analysis: {str(e)}"))
 
137
  yield new_history, None
138
-
139
- send_btn.click(analyze, inputs=[msg_input, chatbot, file_upload], outputs=[chatbot, download_output])
140
- msg_input.submit(analyze, inputs=[msg_input, chatbot, file_upload], outputs=[chatbot, download_output])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  return demo
142
 
143
  if __name__ == "__main__":
144
- agent = init_agent()
145
- demo = create_ui(agent)
146
- demo.queue(api_open=False).launch(
147
- server_name="0.0.0.0",
148
- server_port=7860,
149
- show_error=True,
150
- allowed_paths=[report_dir],
151
- share=False
152
- )
 
 
 
 
 
 
 
 
 
 
1
  import sys
2
  import os
3
+ import polars as pl
4
  import json
5
  import gradio as gr
6
  from typing import List, Tuple
 
8
  import shutil
9
  import re
10
  from datetime import datetime
11
+ import time
12
+ import asyncio
13
+ import aiofiles
14
+ import cachetools
15
+ import logging
16
+ import markdown
17
 
18
+ # Set up logging
19
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
20
+ logger = logging.getLogger(__name__)
21
+
22
+ # Configuration and setup
23
  persistent_dir = "/data/hf_cache"
24
  os.makedirs(persistent_dir, exist_ok=True)
25
 
 
40
 
41
  from txagent.txagent import TxAgent
42
 
43
+ # Cache for processed data
44
+ cache = cachetools.LRUCache(maxsize=100)
45
+
46
  def file_hash(path: str) -> str:
47
+ """Generate MD5 hash of a file."""
48
  with open(path, "rb") as f:
49
  return hashlib.md5(f.read()).hexdigest()
50
 
51
  def clean_response(text: str) -> str:
52
+ """Clean text by removing unwanted characters and normalizing."""
53
  try:
54
  text = text.encode('utf-8', 'surrogatepass').decode('utf-8')
55
  except UnicodeError:
56
  text = text.encode('utf-8', 'replace').decode('utf-8')
57
+
58
  text = re.sub(r"\[.*?\]|\bNone\b", "", text, flags=re.DOTALL)
59
  text = re.sub(r"\n{3,}", "\n\n", text)
60
  text = re.sub(r"[^\n#\-\*\w\s\.,:\(\)]+", "", text)
61
  return text.strip()
62
 
63
async def load_and_clean_data(file_path: str) -> pl.DataFrame:
    """Load the uploaded Excel workbook and normalise the columns the app uses.

    Each required column is cast to text, stripped of surrounding whitespace,
    and has nulls replaced by the empty string. Rows whose "Booking Number"
    does not start with "BKG" are dropped.

    Args:
        file_path: Path of the Excel file to read.

    Returns:
        The cleaned frame containing only "BKG" bookings.

    Raises:
        ValueError: If one or more required columns are absent.
        Exception: Any error raised while reading or transforming the file is
            logged and re-raised unchanged.
    """
    required_columns = [
        "Booking Number", "Form Name", "Form Item", "Item Response",
        "Interviewer", "Interview Date", "Description",
    ]
    try:
        logger.info(f"Loading Excel file: {file_path}")
        # read_excel is a blocking call; run it in a worker thread so this
        # coroutine does not stall the event loop while the file is parsed.
        raw = await asyncio.to_thread(pl.read_excel, file_path)

        missing = [name for name in required_columns if name not in raw.columns]
        if missing:
            raise ValueError(
                f"Excel file is missing required columns: {', '.join(missing)}"
            )

        df = raw.with_columns([
            # Cast first: the .str namespace only works on string columns, and
            # Excel columns such as dates or numeric IDs may be read as other
            # dtypes.
            pl.col(name).cast(pl.Utf8).str.strip_chars().fill_null("").alias(name)
            for name in required_columns
        ]).filter(pl.col("Booking Number").str.starts_with("BKG"))
        logger.info(f"Loaded {len(df)} records")
        return df
    except Exception as e:
        logger.error(f"Error loading data: {str(e)}")
        raise
78
+
79
def generate_summary(df: "pl.DataFrame") -> tuple[str, dict]:
    """Build the markdown summary header and count symptom keyword hits.

    Every value of the "Description" column is lower-cased and searched for a
    fixed set of keywords; each description contributes at most one hit per
    keyword.

    Args:
        df: Frame exposing a "Description" column (iterable of text), a
            "Booking Number" column with ``n_unique()``, and ``len()``.

    Returns:
        ``(summary, symptom_counts)`` where ``summary`` is markdown text and
        ``symptom_counts`` maps a symptom label to its number of matching
        descriptions. Labels with no match are absent from the mapping.
    """
    # (substring searched in the lower-cased description, label used as key)
    symptom_keywords = (
        ("chest discomfort", "Chest Discomfort"),
        ("headaches", "Headaches"),
        ("weight loss", "Weight Loss"),
        ("back pain", "Chronic Back Pain"),
        ("cough", "Persistent Cough"),
    )

    symptom_counts = {}
    for desc in df["Description"]:
        # Tolerate null / non-text cells instead of failing on .lower().
        text = "" if desc is None else str(desc).lower()
        for keyword, label in symptom_keywords:
            if keyword in text:
                symptom_counts[label] = symptom_counts.get(label, 0) + 1

    total_records = len(df)
    unique_bookings = df["Booking Number"].n_unique()
    chest_count = symptom_counts.get("Chest Discomfort", 0)

    # NOTE(review): the year "2023" and the narrative sentences below are
    # fixed text; only the counts are derived from the data.
    interesting_fact = (
        f"Chest discomfort was reported in {chest_count} records, "
        "frequently leading to ECG/lab referrals. Inconsistent follow-up documentation raises "
        "concerns about potential missed cardiovascular diagnoses."
    )

    summary = (
        f"## Summary\n\n"
        f"Analyzed {total_records:,} patient records from {unique_bookings:,} unique bookings in 2023. "
        f"Key findings include a high prevalence of chest discomfort ({chest_count} instances), "
        f"suggesting possible underdiagnosis of cardiovascular issues.\n\n"
        f"### Interesting Fact\n{interesting_fact}\n"
    )
    return summary, symptom_counts
111
+
112
+ def prepare_aggregate_prompt(df: pl.DataFrame) -> str:
113
+ """Prepare a single prompt for all patient data."""
114
+ groups = df.group_by("Booking Number").agg([
115
+ pl.col("Form Name"), pl.col("Form Item"),
116
+ pl.col("Item Response"), pl.col("Interviewer"),
117
+ pl.col("Interview Date"), pl.col("Description")
118
+ ])
119
+
120
  records = []
121
+ for booking in groups.iter_rows(named=True):
122
+ booking_id = booking["Booking Number"]
123
+ for i in range(len(booking["Form Name"])):
124
+ record = (
125
+ f"- {booking['Form Name'][i]}: {booking['Form Item'][i]} = {booking['Item Response'][i]} "
126
+ f"({booking['Interview Date'][i]} by {booking['Interviewer'][i]})\n{booking['Description'][i]}"
127
+ )
128
+ records.append(clean_response(record))
129
+
130
  record_text = "\n".join(records)
131
  prompt = f"""
132
+ Patient Medical History Analysis
133
 
134
  Instructions:
135
+ Analyze the following aggregated patient data from all bookings to identify potential missed diagnoses, medication conflicts, incomplete assessments, and urgent follow-up needs across the entire dataset. Provide a comprehensive summary under the specified markdown headings. Focus on patterns and recurring issues across multiple patients.
136
 
137
+ Data:
138
  {record_text}
139
 
140
  ### Missed Diagnoses
 
152
  return prompt
153
 
154
  def init_agent():
155
+ """Initialize TxAgent with tool configuration."""
156
  default_tool_path = os.path.abspath("data/new_tool.json")
157
  target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
158
+
159
  if not os.path.exists(target_tool_path):
160
  shutil.copy(default_tool_path, target_tool_path)
161
+
162
+ try:
163
+ agent = TxAgent(
164
+ model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
165
+ rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
166
+ tool_files_dict={"new_tool": target_tool_path},
167
+ force_finish=True,
168
+ enable_checker=True,
169
+ step_rag_num=4,
170
+ seed=100,
171
+ additional_default_tools=[],
172
+ )
173
+ agent.init_model()
174
+ return agent
175
+ except Exception as e:
176
+ logger.error(f"Failed to initialize TxAgent: {str(e)}")
177
+ raise
178
+
179
async def generate_report(agent, df: pl.DataFrame, file_hash_value: str) -> tuple[str, str]:
    """Run the aggregated analysis and write the markdown report.

    This is an async generator. It yields ``(text, report_path)`` pairs:

    * ``(partial_text, None)`` after each item produced by
      ``agent.run_gradio_chat`` (the cleaned output accumulated so far);
    * ``(final_report, report_path)`` once the report file has been written;
    * ``("Error: ...", None)`` if anything inside the analysis raises.

    Args:
        agent: Object providing ``run_gradio_chat``.
        df: Cleaned records passed to ``generate_summary`` and
            ``prepare_aggregate_prompt``.
        file_hash_value: Used as the report file-name prefix inside
            ``report_dir``.
    """
    logger.info("Generating comprehensive report...")
    report_path = os.path.join(report_dir, f"{file_hash_value}_report.md")

    # Only the summary text is used here; the per-symptom counts are not.
    summary, _ = generate_summary(df)

    prompt = prepare_aggregate_prompt(df)
    full_output = ""

    try:
        cleaned_parts = []
        # NOTE(review): run_gradio_chat is consumed with a plain ``for`` inside
        # a coroutine, so it presumably blocks the event loop while the model
        # generates - confirm whether that is acceptable.
        for result in agent.run_gradio_chat(
            message=prompt,
            history=[],
            temperature=0.2,
            max_new_tokens=2048,
            max_token=8192,
            call_agent=False,
            conversation=[],
        ):
            if isinstance(result, list):
                for r in result:
                    if hasattr(r, 'content') and r.content:
                        cleaned_parts.append(clean_response(r.content))
            elif isinstance(result, str):
                cleaned_parts.append(clean_response(result))
            full_output = "\n".join(cleaned_parts).strip()
            yield full_output, None  # Stream partial results

        sections = ["Missed Diagnoses", "Medication Conflicts", "Incomplete Assessments", "Urgent Follow-up"]
        filtered_output = _keep_populated_sections(full_output, sections)

        final_output = summary + "## Clinical Findings\n\n"
        if filtered_output:
            final_output += "\n".join(filtered_output) + "\n\n"
        else:
            final_output += "No significant clinical findings identified.\n\n"

        # NOTE(review): this conclusion is fixed text and is appended even when
        # no findings were kept above.
        final_output += (
            "## Conclusion\n\n"
            "The analysis reveals significant gaps in patient care, including potential missed cardiovascular diagnoses "
            "due to inconsistent follow-up on chest discomfort and elevated vitals. Low medication adherence is a recurring "
            "issue, likely impacting treatment efficacy. Incomplete assessments, particularly missing vital signs, hinder "
            "comprehensive care. Urgent follow-up is recommended for patients with chest discomfort and elevated vitals to "
            "prevent adverse outcomes."
        )

        # Explicit UTF-8 so the file does not depend on the platform default
        # encoding.
        async with aiofiles.open(report_path, "w", encoding="utf-8") as f:
            await f.write(final_output)

        logger.info(f"Report saved to {report_path}")
        yield final_output, report_path

    except Exception as e:
        logger.error(f"Error generating report: {str(e)}")
        yield f"Error: {str(e)}", None


def _keep_populated_sections(text: str, section_titles: list) -> list:
    """Return the heading and bullet lines of sections that contain bullets.

    A line starting with ``### <title>`` for one of *section_titles* opens a
    section. After the first such heading, every line whose stripped form
    starts with ``-`` (except the literal placeholder ``- ...``) is kept. A
    heading is emitted only when at least one bullet follows it, so empty
    sections are dropped.
    """
    kept = []
    pending_heading = None  # recognised heading still waiting for its first bullet
    seen_heading = False
    for line in text.split("\n"):
        if any(line.startswith(f"### {title}") for title in section_titles):
            pending_heading = line
            seen_heading = True
            continue
        if not seen_heading:
            continue
        stripped = line.strip()
        if stripped.startswith("-") and stripped != "- ...":
            if pending_heading is not None:
                kept.append(pending_heading)
                pending_heading = None
            kept.append(line)
    return kept
250
 
251
  def create_ui(agent):
252
+ """Create Gradio interface for clinical oversight analysis."""
253
+ with gr.Blocks(
254
+ theme=gr.themes.Soft(),
255
+ title="Clinical Oversight Assistant",
256
+ css="""
257
+ .gradio-container {max-width: 1000px; margin: auto; font-family: Arial, sans-serif;}
258
+ #chatbot {border: 1px solid #e5e7eb; border-radius: 8px; padding: 10px; background: #f9fafb;}
259
+ .markdown {white-space: pre-wrap;}
260
+ """
261
+ ) as demo:
262
+ gr.Markdown("# 🏥 Clinical Oversight Assistant (Excel Optimized)")
263
+
264
+ with gr.Tabs():
265
+ with gr.TabItem("Analysis"):
266
+ with gr.Row():
267
+ # Left column - Inputs
268
+ with gr.Column(scale=1):
269
+ file_upload = gr.File(
270
+ label="Upload Excel File",
271
+ file_types=[".xlsx"],
272
+ file_count="single",
273
+ interactive=True
274
+ )
275
+ msg_input = gr.Textbox(
276
+ label="Additional Instructions",
277
+ placeholder="Add any specific analysis requests...",
278
+ lines=3
279
+ )
280
+ with gr.Row():
281
+ clear_btn = gr.Button("Clear", variant="secondary")
282
+ send_btn = gr.Button("Analyze", variant="primary")
283
+
284
+ # Right column - Outputs
285
+ with gr.Column(scale=2):
286
+ chatbot = gr.Chatbot(
287
+ label="Analysis Results",
288
+ height=600,
289
+ bubble_full_width=False,
290
+ show_copy_button=True,
291
+ elem_id="chatbot"
292
+ )
293
+ download_output = gr.File(
294
+ label="Download Full Report",
295
+ interactive=False
296
+ )
297
+
298
+ with gr.TabItem("Instructions"):
299
+ gr.Markdown("""
300
+ ## How to Use This Tool
301
+
302
+ 1. **Upload Excel File**: Select your patient records Excel file
303
+ 2. **Add Instructions** (Optional): Provide any specific analysis requests
304
+ 3. **Click Analyze**: The system will process all patient records and generate a comprehensive report
305
+ 4. **Review Results**: Analysis appears in the chat window
306
+ 5. **Download Report**: Get a full markdown report of all findings
307
+
308
+ ### Excel File Requirements
309
+ Your Excel file must contain these columns:
310
+ - Booking Number
311
+ - Form Name
312
+ - Form Item
313
+ - Item Response
314
+ - Interview Date
315
+ - Interviewer
316
+ - Description
317
+
318
+ ### Analysis Includes
319
+ - Missed diagnoses
320
+ - Medication conflicts
321
+ - Incomplete assessments
322
+ - Urgent follow-up needs
323
+ """)
324
+
325
+ def format_message(role: str, content: str) -> Tuple[str, str]:
326
+ """Format messages for the chatbot in (user, bot) format."""
327
+ if role == "user":
328
+ return (content, None)
329
+ else:
330
+ return (None, content)
331
+
332
+ async def analyze(message: str, chat_history: List[Tuple[str, str]], file) -> Tuple[List[Tuple[str, str]], str]:
333
+ """Analyze uploaded file and generate comprehensive report."""
334
  if not file:
335
+ raise gr.Error("Please upload an Excel file first")
336
+
 
 
 
337
  try:
338
+ # Initialize chat history
339
+ new_history = chat_history + [format_message("user", message)]
340
+ new_history.append(format_message("assistant", "⏳ Processing Excel data..."))
341
+ yield new_history, None
342
+
343
+ # Load and clean data
344
+ df = await load_and_clean_data(file.name)
345
+ file_hash_value = file_hash(file.name)
346
+
347
+ # Generate report
348
+ async for output, report_path in generate_report(agent, df, file_hash_value):
349
+ if output:
350
+ new_history[-1] = format_message("assistant", output)
351
+ yield new_history, report_path
352
+ else:
353
+ yield new_history, report_path
354
+
 
 
 
 
 
 
355
  except Exception as e:
356
+ logger.error(f"Analysis failed: {str(e)}")
357
+ new_history.append(format_message("assistant", f"❌ Error: {str(e)}"))
358
  yield new_history, None
359
+ raise gr.Error(f"Analysis failed: {str(e)}")
360
+
361
+ def clear_chat():
362
+ """Clear chat history and download output."""
363
+ return [], None
364
+
365
+ # Event handlers
366
+ send_btn.click(
367
+ analyze,
368
+ inputs=[msg_input, chatbot, file_upload],
369
+ outputs=[chatbot, download_output],
370
+ api_name="analyze",
371
+ queue=True
372
+ )
373
+
374
+ msg_input.submit(
375
+ analyze,
376
+ inputs=[msg_input, chatbot, file_upload],
377
+ outputs=[chatbot, download_output],
378
+ queue=True
379
+ )
380
+
381
+ clear_btn.click(
382
+ clear_chat,
383
+ inputs=[],
384
+ outputs=[chatbot, download_output]
385
+ )
386
+
387
  return demo
388
 
389
  if __name__ == "__main__":
390
+ try:
391
+ agent = init_agent()
392
+ demo = create_ui(agent)
393
+
394
+ demo.queue(
395
+ api_open=False,
396
+ max_size=20
397
+ ).launch(
398
+ server_name="0.0.0.0",
399
+ server_port=7860,
400
+ show_error=True,
401
+ allowed_paths=[report_dir],
402
+ share=False
403
+ )
404
+ except Exception as e:
405
+ logger.error(f"Failed to launch application: {str(e)}")
406
+ print(f"Failed to launch application: {str(e)}")
407
+ sys.exit(1)