Ali2206 committed on
Commit 1c5bd8e · verified · 1 Parent(s): f75a23b

Update app.py

Files changed (1)
  1. app.py +72 -273
app.py CHANGED
@@ -1,21 +1,15 @@
 import sys
 import os
 import pandas as pd
-import pdfplumber
 import json
 import gradio as gr
 from typing import List
-from concurrent.futures import ThreadPoolExecutor, as_completed
 import hashlib
 import shutil
 import re
-import psutil
-import subprocess
-import multiprocessing
-from functools import partial
+from datetime import datetime
 import time

-# Persistent directory
 persistent_dir = "/data/hf_cache"
 os.makedirs(persistent_dir, exist_ok=True)

@@ -23,16 +17,12 @@ model_cache_dir = os.path.join(persistent_dir, "txagent_models")
 tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
 file_cache_dir = os.path.join(persistent_dir, "cache")
 report_dir = os.path.join(persistent_dir, "reports")
-vllm_cache_dir = os.path.join(persistent_dir, "vllm_cache")

-for directory in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir, vllm_cache_dir]:
+for directory in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir]:
     os.makedirs(directory, exist_ok=True)

 os.environ["HF_HOME"] = model_cache_dir
 os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
-os.environ["VLLM_CACHE_DIR"] = vllm_cache_dir
-os.environ["TOKENIZERS_PARALLELISM"] = "false"
-os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

 current_dir = os.path.dirname(os.path.abspath(__file__))
 src_path = os.path.abspath(os.path.join(current_dir, "src"))
@@ -40,146 +30,56 @@ sys.path.insert(0, src_path)

 from txagent.txagent import TxAgent

-def sanitize_utf8(text: str) -> str:
-    return text.encode("utf-8", "ignore").decode("utf-8")
-
 def file_hash(path: str) -> str:
     with open(path, "rb") as f:
         return hashlib.md5(f.read()).hexdigest()

-def extract_page_range(file_path: str, start_page: int, end_page: int) -> str:
-    """Extract text from a range of PDF pages."""
-    try:
-        text_chunks = []
-        with pdfplumber.open(file_path) as pdf:
-            for page in pdf.pages[start_page:end_page]:
-                page_text = page.extract_text() or ""
-                text_chunks.append(f"=== Page {start_page + pdf.pages.index(page) + 1} ===\n{page_text.strip()}")
-        return "\n\n".join(text_chunks)
-    except Exception:
-        return ""
-
-def extract_all_pages(file_path: str, progress_callback=None) -> str:
-    """Extract text from all pages of a PDF using parallel processing."""
-    try:
-        with pdfplumber.open(file_path) as pdf:
-            total_pages = len(pdf.pages)
-
-        if total_pages == 0:
-            return ""
-
-        # Use 6 processes (adjust based on CPU cores)
-        num_processes = min(6, multiprocessing.cpu_count())
-        pages_per_process = max(1, total_pages // num_processes)
-
-        # Create page ranges for parallel processing
-        ranges = [(i * pages_per_process, min((i + 1) * pages_per_process, total_pages))
-                  for i in range(num_processes)]
-        if ranges[-1][1] != total_pages:
-            ranges[-1] = (ranges[-1][0], total_pages)
-
-        # Process page ranges in parallel
-        with multiprocessing.Pool(processes=num_processes) as pool:
-            extract_func = partial(extract_page_range, file_path)
-            results = []
-            for idx, result in enumerate(pool.starmap(extract_func, ranges)):
-                results.append(result)
-                if progress_callback:
-                    processed_pages = min((idx + 1) * pages_per_process, total_pages)
-                    progress_callback(processed_pages, total_pages)
-
-        return "\n\n".join(filter(None, results))
-    except Exception as e:
-        return f"PDF processing error: {str(e)}"
-
-def convert_file_to_json(file_path: str, file_type: str, progress_callback=None) -> str:
-    try:
-        h = file_hash(file_path)
-        cache_path = os.path.join(file_cache_dir, f"{h}.json")
-        if os.path.exists(cache_path):
-            with open(cache_path, "r", encoding="utf-8") as f:
-                return f.read()
-
-        if file_type == "pdf":
-            text = extract_all_pages(file_path, progress_callback)
-            result = json.dumps({"filename": os.path.basename(file_path), "content": text, "status": "initial"})
-        elif file_type == "csv":
-            df = pd.read_csv(file_path, encoding_errors="replace", header=None, dtype=str,
-                             skip_blank_lines=False, on_bad_lines="skip")
-            content = df.fillna("").astype(str).values.tolist()
-            result = json.dumps({"filename": os.path.basename(file_path), "rows": content})
-        elif file_type in ["xls", "xlsx"]:
-            try:
-                df = pd.read_excel(file_path, engine="openpyxl", header=None, dtype=str)
-            except Exception:
-                df = pd.read_excel(file_path, engine="xlrd", header=None, dtype=str)
-            content = df.fillna("").astype(str).values.tolist()
-            result = json.dumps({"filename": os.path.basename(file_path), "rows": content})
-        else:
-            result = json.dumps({"error": f"Unsupported file type: {file_type}"})
-        with open(cache_path, "w", encoding="utf-8") as f:
-            f.write(result)
-        return result
-    except Exception as e:
-        return json.dumps({"error": f"Error processing {os.path.basename(file_path)}: {str(e)}"})
-
-def log_system_usage(tag=""):
-    try:
-        cpu = psutil.cpu_percent(interval=1)
-        mem = psutil.virtual_memory()
-        print(f"[{tag}] CPU: {cpu}% | RAM: {mem.used // (1024**2)}MB / {mem.total // (1024**2)}MB")
-        result = subprocess.run(
-            ["nvidia-smi", "--query-gpu=memory.used,memory.total,utilization.gpu", "--format=csv,nounits,noheader"],
-            capture_output=True, text=True
-        )
-        if result.returncode == 0:
-            used, total, util = result.stdout.strip().split(", ")
-            print(f"[{tag}] GPU: {used}MB / {total}MB | Utilization: {util}%")
-    except Exception as e:
-        print(f"[{tag}] GPU/CPU monitor failed: {e}")
-
-def clean_response(text: str) -> str:
-    """Clean TxAgent response to keep only markdown sections with valid findings."""
-    text = sanitize_utf8(text)
-    # Remove tool call artifacts, None, and reasoning
-    text = re.sub(r"\[.*?\]|\bNone\b|To analyze the patient record excerpt.*?medications\.|Since the previous attempts.*?\.|I need to.*?medications\.|Retrieving tools.*?\.", "", text, flags=re.DOTALL)
-    # Remove extra whitespace and non-markdown content
-    text = re.sub(r"\n{3,}", "\n\n", text)
-    text = re.sub(r"[^\n#\-\*\w\s\.\,\:\(\)]+", "", text)  # Keep markdown-relevant characters
-
-    # Extract markdown sections with valid findings
-    sections = []
-    current_section = None
-    lines = text.splitlines()
-    for line in lines:
-        line = line.strip()
-        if not line:
-            continue
-        if re.match(r"###\s*(Missed Diagnoses|Medication Conflicts|Incomplete Assessments|Urgent Follow-up)", line):
-            current_section = line
-            sections.append([current_section])
-        elif current_section and re.match(r"-\s*.+", line) and not re.match(r"-\s*No issues identified", line):
-            sections[-1].append(line)
-
-    # Combine only non-empty sections
-    cleaned = []
-    for section in sections:
-        if len(section) > 1:  # Section has at least one finding
-            cleaned.append("\n".join(section))
-
-    text = "\n\n".join(cleaned).strip()
-    if not text:
-        text = ""  # Return empty string if no valid findings
-    return text
+def clean_response(text: str) -> str:
+    text = text.encode("utf-8", "ignore").decode("utf-8")
+    text = re.sub(r"\[.*?\]|\bNone\b", "", text, flags=re.DOTALL)
+    text = re.sub(r"\n{3,}", "\n\n", text)
+    text = re.sub(r"[^\n#\-\*\w\s\.,:\(\)]+", "", text)
+    return text.strip()
+
+def parse_excel_to_prompts(file_path: str) -> List[str]:
+    xl = pd.ExcelFile(file_path)
+    df = xl.parse(xl.sheet_names[0], header=0).fillna("")
+    groups = df.groupby("Booking Number")
+    prompts = []
+    for booking, group in groups:
+        records = []
+        for _, row in group.iterrows():
+            records.append(f"- {row['Form Name']}: {row['Form Item']} = {row['Item Response']} ({row['Interview Date']} by {row['Interviewer']})\n{row['Description']}")
+        record_text = "\n".join(records)
+        prompt = f"""
+Patient Booking Number: {booking}
+
+Instructions:
+Analyze the following patient case for missed diagnoses, medication conflicts, incomplete assessments, and any urgent follow-up needed. Summarize under the markdown headings.
+
+Data:
+{record_text}
+
+### Missed Diagnoses
+- ...
+
+### Medication Conflicts
+- ...
+
+### Incomplete Assessments
+- ...
+
+### Urgent Follow-up
+- ...
+"""
+        prompts.append(prompt)
+    return prompts

 def init_agent():
-    print("🔁 Initializing model...")
-    log_system_usage("Before Load")
     default_tool_path = os.path.abspath("data/new_tool.json")
     target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
     if not os.path.exists(target_tool_path):
         shutil.copy(default_tool_path, target_tool_path)
-
     agent = TxAgent(
         model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
         rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
@@ -191,160 +91,59 @@ def init_agent():
         additional_default_tools=[],
     )
     agent.init_model()
-    log_system_usage("After Load")
-    print("✅ Agent Ready")
     return agent

 def create_ui(agent):
     with gr.Blocks(theme=gr.themes.Soft()) as demo:
-        gr.Markdown("<h1 style='text-align: center;'>🩺 Clinical Oversight Assistant</h1>")
+        gr.Markdown("<h1 style='text-align: center;'>🩺 Clinical Oversight Assistant (Excel Optimized)</h1>")
         chatbot = gr.Chatbot(label="Analysis", height=600, type="messages")
-        file_upload = gr.File(file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="multiple")
-        msg_input = gr.Textbox(placeholder="Ask about potential oversights...", show_label=False)
+        file_upload = gr.File(file_types=[".xlsx"], file_count="single")
+        msg_input = gr.Textbox(placeholder="Ask about patient history...", show_label=False)
         send_btn = gr.Button("Analyze", variant="primary")
         download_output = gr.File(label="Download Full Report")

-        def analyze(message: str, history: List[dict], files: List):
+        def analyze(message: str, history: List[dict], file) -> tuple:
             history.append({"role": "user", "content": message})
-            history.append({"role": "assistant", "content": "⏳ Extracting text from files..."})
+            history.append({"role": "assistant", "content": "⏳ Processing Excel data..."})
             yield history, None

-            extracted = ""
-            file_hash_value = ""
-            if files:
-                # Progress callback for extraction
-                total_pages = 0
-                processed_pages = 0
-                def update_extraction_progress(current, total):
-                    nonlocal processed_pages, total_pages
-                    processed_pages = current
-                    total_pages = total
-                    animation = ["🌀", "🔄", "⚙️", "🔃"][(int(time.time() * 2) % 4)]
-                    history[-1] = {"role": "assistant", "content": f"Extracting text... {animation} Page {processed_pages}/{total_pages}"}
-                    return history, None
-
-                with ThreadPoolExecutor(max_workers=6) as executor:
-                    futures = [executor.submit(convert_file_to_json, f.name, f.name.split(".")[-1].lower(), update_extraction_progress) for f in files]
-                    results = [sanitize_utf8(f.result()) for f in as_completed(futures)]
-                    extracted = "\n".join(results)
-                    file_hash_value = file_hash(files[0].name) if files else ""
-
-            history.pop()  # Remove extraction message
-            history.append({"role": "assistant", "content": "✅ Text extraction complete."})
-            yield history, None
-
-            # Split extracted text into chunks of ~6,000 characters
-            chunk_size = 6000
-            chunks = [extracted[i:i + chunk_size] for i in range(0, len(extracted), chunk_size)]
-            combined_response = ""
-
-            prompt_template = """
-You are a medical analysis assistant. Analyze the following patient record excerpt for clinical oversights and provide a concise, evidence-based summary in markdown format under these headings: Missed Diagnoses, Medication Conflicts, Incomplete Assessments, and Urgent Follow-up. For each finding, include:
-- Clinical context (why the issue was missed or relevant details from the record).
-- Potential risks if unaddressed (e.g., disease progression, adverse events).
-- Actionable recommendations (e.g., tests, referrals, medication adjustments).
-Output ONLY the markdown-formatted findings, with bullet points under each heading. Do NOT include reasoning, tool calls, or intermediate steps. If no issues are found in a section, state "No issues identified." Ensure the output is specific to the provided text and avoids generic responses.
-
-Example Output:
-### Missed Diagnoses
-- Elevated BP noted without diagnosis. Missed due to inconsistent visits. Risks: stroke. Recommend: BP monitoring, antihypertensives.
-### Medication Conflicts
-- No issues identified.
-### Incomplete Assessments
-- Chest pain not evaluated. Time constraints likely cause. Risks: cardiac issues. Recommend: ECG, stress test.
-### Urgent Follow-up
-- Abnormal creatinine not addressed. Delayed lab review. Risks: renal failure. Recommend: nephrology referral.
-
-Patient Record Excerpt (Chunk {0} of {1}):
-{chunk}
-
-### Missed Diagnoses
-- ...
-
-### Medication Conflicts
-- ...
-
-### Incomplete Assessments
-- ...
-
-### Urgent Follow-up
-- ...
-"""
-
-            try:
-                # Process each chunk and stream results in real-time
-                for chunk_idx, chunk in enumerate(chunks, 1):
-                    # Update UI with chunk progress
-                    animation = ["🔍", "📊", "🧠", "🔎"][(int(time.time() * 2) % 4)]
-                    history.append({"role": "assistant", "content": f"Analyzing records... {animation} Chunk {chunk_idx}/{len(chunks)}"})
+            prompts = parse_excel_to_prompts(file.name)
+            full_output = ""
+
+            for idx, prompt in enumerate(prompts, 1):
+                chunk_output = ""
+                for result in agent.run_gradio_chat(
+                    message=prompt,
+                    history=[],
+                    temperature=0.2,
+                    max_new_tokens=1024,
+                    max_token=4096,
+                    call_agent=False,
+                    conversation=[],
+                ):
+                    if isinstance(result, list):
+                        for r in result:
+                            if hasattr(r, 'content') and r.content:
+                                chunk_output += clean_response(r.content) + "\n"
+                    elif isinstance(result, str):
+                        chunk_output += clean_response(result) + "\n"
+                if chunk_output:
+                    output = f"--- Booking {idx} ---\n{chunk_output.strip()}\n"
+                    history.append({"role": "assistant", "content": output})
+                    full_output += output + "\n"
                yield history, None

-                    prompt = prompt_template.format(chunk_idx, len(chunks), chunk=chunk[:4000])  # Truncate to avoid token limits
-                    chunk_response = ""
-                    for chunk_output in agent.run_gradio_chat(
-                        message=prompt,
-                        history=[],
-                        temperature=0.2,
-                        max_new_tokens=1024,
-                        max_token=4096,
-                        call_agent=False,
-                        conversation=[],
-                    ):
-                        if chunk_output is None:
-                            continue
-                        if isinstance(chunk_output, list):
-                            for m in chunk_output:
-                                if hasattr(m, 'content') and m.content:
-                                    cleaned = clean_response(m.content)
-                                    if cleaned and re.search(r"###\s*(Missed Diagnoses|Medication Conflicts|Incomplete Assessments|Urgent Follow-up)", cleaned):
-                                        chunk_response += cleaned + "\n\n"
-                                        # Update UI with partial response
-                                        if history[-1]["content"].startswith("Analyzing"):
-                                            history[-1] = {"role": "assistant", "content": f"--- Analysis for Chunk {chunk_idx} ---\n{chunk_response.strip()}"}
-                                        else:
-                                            history[-1]["content"] = f"--- Analysis for Chunk {chunk_idx} ---\n{chunk_response.strip()}"
-                                        yield history, None
-                        elif isinstance(chunk_output, str) and chunk_output.strip():
-                            cleaned = clean_response(chunk_output)
-                            if cleaned and re.search(r"###\s*(Missed Diagnoses|Medication Conflicts|Incomplete Assessments|Urgent Follow-up)", cleaned):
-                                chunk_response += cleaned + "\n\n"
-                                # Update UI with partial response
-                                if history[-1]["content"].startswith("Analyzing"):
-                                    history[-1] = {"role": "assistant", "content": f"--- Analysis for Chunk {chunk_idx} ---\n{chunk_response.strip()}"}
-                                else:
-                                    history[-1]["content"] = f"--- Analysis for Chunk {chunk_idx} ---\n{chunk_response.strip()}"
-                                yield history, None
-
-                    # Append completed chunk response to combined response
-                    if chunk_response:
-                        combined_response += f"--- Analysis for Chunk {chunk_idx} ---\n{chunk_response}\n"
-                    else:
-                        combined_response += f"--- Analysis for Chunk {chunk_idx} ---\nNo oversights identified for this chunk.\n\n"
-
-                # Finalize UI with complete response
-                if combined_response.strip() and not all("No oversights identified" in chunk for chunk in combined_response.split("--- Analysis for Chunk")):
-                    history[-1]["content"] = combined_response.strip()
-                else:
-                    history.append({"role": "assistant", "content": "No oversights identified in the provided records."})
-
-                # Generate report file
-                report_path = os.path.join(report_dir, f"{file_hash_value}_report.txt") if file_hash_value else None
-                if report_path:
-                    with open(report_path, "w", encoding="utf-8") as f:
-                        f.write(combined_response)
-                yield history, report_path if report_path and os.path.exists(report_path) else None
-
-            except Exception as e:
-                print("🚨 ERROR:", e)
-                history.append({"role": "assistant", "content": f"❌ Error occurred: {str(e)}"})
-                yield history, None
+            file_hash_value = file_hash(file.name)
+            report_path = os.path.join(report_dir, f"{file_hash_value}_report.txt")
+            with open(report_path, "w", encoding="utf-8") as f:
+                f.write(full_output)
+            yield history, report_path if os.path.exists(report_path) else None

         send_btn.click(analyze, inputs=[msg_input, gr.State([]), file_upload], outputs=[chatbot, download_output])
         msg_input.submit(analyze, inputs=[msg_input, gr.State([]), file_upload], outputs=[chatbot, download_output])
     return demo

 if __name__ == "__main__":
-    print("🚀 Launching app...")
     agent = init_agent()
     demo = create_ui(agent)
     demo.queue(api_open=False).launch(
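Note: the new parse_excel_to_prompts reads the first sheet of the uploaded workbook and expects the columns Booking Number, Form Name, Form Item, Item Response, Interview Date, Interviewer, and Description, emitting one prompt per booking rather than one per page chunk. A minimal standalone sketch of that grouping step (the bookings and all cell values below are invented for illustration):

import pandas as pd

# Hypothetical rows mirroring the columns parse_excel_to_prompts reads.
df = pd.DataFrame({
    "Booking Number": ["BK-001", "BK-001", "BK-002"],
    "Form Name": ["Intake", "Vitals", "Intake"],
    "Form Item": ["Chief Complaint", "Blood Pressure", "Chief Complaint"],
    "Item Response": ["Chest pain", "150/95", "Headache"],
    "Interview Date": ["2024-01-02", "2024-01-02", "2024-01-05"],
    "Interviewer": ["RN Smith", "RN Smith", "RN Lee"],
    "Description": ["Initial visit", "Routine check", "Initial visit"],
})

# One record line per row, one prompt per booking, matching the
# f-string layout used in the commit.
for booking, group in df.groupby("Booking Number"):
    records = [
        f"- {row['Form Name']}: {row['Form Item']} = {row['Item Response']} "
        f"({row['Interview Date']} by {row['Interviewer']})\n{row['Description']}"
        for _, row in group.iterrows()
    ]
    print(f"Patient Booking Number: {booking}")
    print("\n".join(records), end="\n\n")

Each prompt therefore scales with the number of rows for that booking, not with the size of the whole sheet.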
 
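Because analyze is a generator, each yield of a (history, report_path) pair streams a partial update to the Chatbot and File outputs while the per-booking loop is still running. A minimal sketch of that streaming pattern, with an invented stub standing in for agent.run_gradio_chat:

import gradio as gr

def analyze(message, history):
    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": "⏳ Processing..."})
    yield history, None  # first yield renders the placeholder immediately
    for idx in (1, 2):   # stub for the per-booking agent loop
        history.append({"role": "assistant",
                        "content": f"--- Booking {idx} ---\nNo issues identified."})
        yield history, None  # each yield re-renders the chat

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(label="Analysis", type="messages")
    msg_input = gr.Textbox(show_label=False)
    download_output = gr.File(label="Download Full Report")
    msg_input.submit(analyze, inputs=[msg_input, gr.State([])],
                     outputs=[chatbot, download_output])

demo.launch()

As in the commit, history arrives via gr.State([]) and so starts empty on every submit; the final yield is where a real report path would replace None.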