Ali2206 commited on
Commit
a1a096d
Β·
verified Β·
1 Parent(s): c5da27e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +150 -198
app.py CHANGED
@@ -1,27 +1,26 @@
1
- import sys
2
- import os
3
  import pandas as pd
4
- import json
5
- import gradio as gr
6
- from typing import List, Tuple, Dict, Any, Union
7
- import hashlib
8
- import shutil
9
- import re
10
  from datetime import datetime
11
- import time
12
  from concurrent.futures import ThreadPoolExecutor, as_completed
 
13
 
14
- # Configuration and setup
15
- persistent_dir = "/data/hf_cache"
16
- os.makedirs(persistent_dir, exist_ok=True)
17
 
 
 
 
 
 
 
 
 
18
  model_cache_dir = os.path.join(persistent_dir, "txagent_models")
19
  tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
20
  file_cache_dir = os.path.join(persistent_dir, "cache")
21
  report_dir = os.path.join(persistent_dir, "reports")
22
 
23
- for directory in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir]:
24
- os.makedirs(directory, exist_ok=True)
25
 
26
  os.environ["HF_HOME"] = model_cache_dir
27
  os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
@@ -32,62 +31,47 @@ sys.path.insert(0, src_path)
32
 
33
  from txagent.txagent import TxAgent
34
 
35
- MAX_MODEL_TOKENS = 32768
36
- MAX_CHUNK_TOKENS = 8192
37
- MAX_NEW_TOKENS = 2048
38
- PROMPT_OVERHEAD = 500
39
 
40
  def clean_response(text: str) -> str:
41
- try:
42
- text = text.encode('utf-8', 'surrogatepass').decode('utf-8')
43
- except UnicodeError:
44
- text = text.encode('utf-8', 'replace').decode('utf-8')
45
  text = re.sub(r"\[.*?\]|\bNone\b", "", text, flags=re.DOTALL)
46
  text = re.sub(r"\n{3,}", "\n\n", text)
47
  text = re.sub(r"[^\n#\-\*\w\s\.,:\(\)]+", "", text)
48
  return text.strip()
49
 
50
- def estimate_tokens(text: str) -> int:
51
- return len(text) // 3.5 + 1
52
-
53
- def extract_text_from_excel(file_path: str) -> str:
54
  all_text = []
55
  try:
56
- xls = pd.ExcelFile(file_path)
57
- for sheet_name in xls.sheet_names:
58
- df = xls.parse(sheet_name)
59
- df = df.astype(str).fillna("")
60
  rows = df.apply(lambda row: " | ".join(row), axis=1)
61
- sheet_text = [f"[{sheet_name}] {line}" for line in rows]
62
- all_text.extend(sheet_text)
63
  except Exception as e:
64
- raise ValueError(f"Failed to extract text from Excel file: {str(e)}")
65
  return "\n".join(all_text)
66
 
67
- def split_text_into_chunks(text: str, max_tokens: int = MAX_CHUNK_TOKENS) -> List[str]:
68
- effective_max_tokens = max_tokens - PROMPT_OVERHEAD
69
- if effective_max_tokens <= 0:
70
- raise ValueError("Effective max tokens must be positive.")
71
- lines = text.split("\n")
72
- chunks, current_chunk, current_tokens = [], [], 0
73
- for line in lines:
74
- line_tokens = estimate_tokens(line)
75
- if current_tokens + line_tokens > effective_max_tokens:
76
- if current_chunk:
77
- chunks.append("\n".join(current_chunk))
78
- current_chunk, current_tokens = [line], line_tokens
79
  else:
80
- current_chunk.append(line)
81
- current_tokens += line_tokens
82
- if current_chunk:
83
- chunks.append("\n".join(current_chunk))
84
  return chunks
85
 
86
- def build_prompt_from_text(chunk: str) -> str:
87
- return f"""
88
- ### Unstructured Clinical Records
89
 
90
- Analyze the following clinical notes and provide a detailed, concise summary focusing on:
91
  - Diagnostic Patterns
92
  - Medication Issues
93
  - Missed Opportunities
@@ -99,179 +83,147 @@ Analyze the following clinical notes and provide a detailed, concise summary foc
99
  {chunk}
100
 
101
  ---
102
- Respond in well-structured bullet points with medical reasoning.
103
- """
104
-
105
- def init_agent():
106
- default_tool_path = os.path.abspath("data/new_tool.json")
107
- target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
108
- if not os.path.exists(target_tool_path):
109
- shutil.copy(default_tool_path, target_tool_path)
110
  agent = TxAgent(
111
  model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
112
  rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
113
- tool_files_dict={"new_tool": target_tool_path},
114
  force_finish=True,
115
  enable_checker=True,
116
  step_rag_num=4,
117
- seed=100,
118
- additional_default_tools=[]
119
  )
120
  agent.init_model()
121
  return agent
122
 
123
- def process_final_report(agent, file, chatbot_state: List[Dict[str, str]]) -> Tuple[List[Dict[str, str]], Union[str, None]]:
124
- messages = chatbot_state if chatbot_state else []
125
- report_path = None
126
-
127
- if file is None or not hasattr(file, "name"):
128
- messages.append({"role": "assistant", "content": "❌ Please upload a valid Excel file before analyzing."})
129
- return messages, report_path
130
 
131
- try:
132
- messages.append({"role": "user", "content": f"Processing Excel file: {os.path.basename(file.name)}"})
133
- extracted_text = extract_text_from_excel(file.name)
134
- chunks = split_text_into_chunks(extracted_text)
135
- chunk_responses = [None] * len(chunks)
136
-
137
- def analyze_chunk(index: int, chunk: str) -> Tuple[int, str]:
138
- prompt = build_prompt_from_text(chunk)
139
- prompt_tokens = estimate_tokens(prompt)
140
- if prompt_tokens > MAX_MODEL_TOKENS:
141
- return index, f"❌ Chunk {index+1} prompt too long. Skipping..."
142
  response = ""
143
- try:
144
- for result in agent.run_gradio_chat(
145
- message=prompt,
146
- history=[],
147
- temperature=0.2,
148
- max_new_tokens=MAX_NEW_TOKENS,
149
- max_token=MAX_MODEL_TOKENS,
150
- call_agent=False,
151
- conversation=[],
152
- ):
153
- if isinstance(result, str):
154
- response += result
155
- elif isinstance(result, list):
156
- for r in result:
157
- if hasattr(r, "content"):
158
- response += r.content
159
- elif hasattr(result, "content"):
160
- response += result.content
161
- except Exception as e:
162
- return index, f"❌ Error analyzing chunk {index+1}: {str(e)}"
163
- return index, clean_response(response)
164
-
165
- with ThreadPoolExecutor(max_workers=1) as executor:
166
- futures = [executor.submit(analyze_chunk, i, chunk) for i, chunk in enumerate(chunks)]
167
- for future in as_completed(futures):
168
- i, result = future.result()
169
- chunk_responses[i] = result
170
- if result.startswith("❌"):
171
- messages.append({"role": "assistant", "content": result})
172
-
173
- valid_responses = [res for res in chunk_responses if not res.startswith("❌")]
174
- if not valid_responses:
175
- messages.append({"role": "assistant", "content": "❌ No valid chunk responses to summarize."})
176
- return messages, report_path
177
-
178
- summary = "\n\n".join(valid_responses)
179
- final_prompt = f"Provide a structured, consolidated clinical analysis from these results:\n\n{summary}"
180
- messages.append({"role": "assistant", "content": "πŸ“Š Generating final report..."})
181
-
182
- final_report_text = ""
183
- for result in agent.run_gradio_chat(
184
- message=final_prompt,
185
- history=[],
186
- temperature=0.2,
187
- max_new_tokens=MAX_NEW_TOKENS,
188
- max_token=MAX_MODEL_TOKENS,
189
- call_agent=False,
190
- conversation=[],
191
- ):
192
- if isinstance(result, str):
193
- final_report_text += result
194
- elif isinstance(result, list):
195
- for r in result:
196
- if hasattr(r, "content"):
197
- final_report_text += r.content
198
- elif hasattr(result, "content"):
199
- final_report_text += result.content
200
-
201
- cleaned = clean_response(final_report_text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
  report_path = os.path.join(report_dir, f"report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md")
203
- with open(report_path, 'w') as f:
204
- f.write(f"# 🧠 Final Patient Report\n\n{cleaned}")
205
 
206
- messages.append({"role": "assistant", "content": f"πŸ“Š Final Report:\n\n{cleaned}"})
207
- messages.append({"role": "assistant", "content": f"βœ… Report generated and saved: {os.path.basename(report_path)}"})
 
208
 
209
  except Exception as e:
210
- messages.append({"role": "assistant", "content": f"❌ Error processing file: {str(e)}"})
211
-
212
- return messages, report_path
213
 
214
  def create_ui(agent):
215
- with gr.Blocks(css="""
216
- html, body, .gradio-container {
217
- height: 100vh;
218
- width: 100vw;
219
- padding: 0;
220
- margin: 0;
221
- font-family: 'Inter', sans-serif;
222
- background: #ffffff;
223
- }
224
- .gr-button.primary {
225
- background: #1e88e5;
226
- color: #fff;
227
- border: none;
228
- border-radius: 6px;
229
- font-weight: 600;
230
- }
231
- .gr-button.primary:hover {
232
- background: #1565c0;
233
- }
234
- .gr-chatbot {
235
- border: 1px solid #e0e0e0;
236
- background: #f9f9f9;
237
- border-radius: 10px;
238
- padding: 1rem;
239
- font-size: 15px;
240
- }
241
- .gr-markdown, .gr-file-upload {
242
- background: #ffffff;
243
- border-radius: 8px;
244
- box-shadow: 0 1px 3px rgba(0,0,0,0.08);
245
- }
246
- """) as demo:
247
- gr.Markdown("""
248
- <h2 style='color:#1e88e5'>🩺 Patient History AI Assistant</h2>
249
- <p>Upload a clinical Excel file and receive an advanced diagnostic summary.</p>
250
- """)
251
-
252
  with gr.Row():
253
  with gr.Column(scale=3):
254
- chatbot = gr.Chatbot(label="Clinical Assistant", height=700, type="messages")
255
  with gr.Column(scale=1):
256
- file_upload = gr.File(label="Upload Excel File", file_types=[".xlsx"])
257
- analyze_btn = gr.Button("🧠 Analyze", variant="primary")
258
- report_output = gr.File(label="Download Report", visible=False, interactive=False)
259
 
260
- chatbot_state = gr.State(value=[])
261
 
262
- def update_ui(file, current_state):
263
- messages, report_path = process_final_report(agent, file, current_state)
264
- return messages, gr.update(visible=report_path is not None, value=report_path), messages
265
 
266
- analyze_btn.click(fn=update_ui, inputs=[file_upload, chatbot_state], outputs=[chatbot, report_output, chatbot_state])
267
 
268
  return demo
269
 
270
  if __name__ == "__main__":
271
  try:
272
  agent = init_agent()
273
- demo = create_ui(agent)
274
- demo.launch(server_name="0.0.0.0", server_port=7860, allowed_paths=["/data/hf_cache/reports"], share=False)
275
- except Exception as e:
276
- print(f"Error: {str(e)}")
277
  sys.exit(1)
 
1
+ import sys, os, json, shutil, re, time, gc, hashlib
 
2
  import pandas as pd
 
 
 
 
 
 
3
  from datetime import datetime
 
4
  from concurrent.futures import ThreadPoolExecutor, as_completed
5
+ from typing import List, Tuple, Dict, Union
6
 
7
+ import gradio as gr
 
 
8
 
9
# Constants
# Token budgets — assumed context window passed to run_gradio_chat as
# max_token; TODO(review): confirm 131072 matches the deployed model.
MAX_MODEL_TOKENS = 131072
MAX_NEW_TOKENS = 4096
MAX_CHUNK_TOKENS = 8192
# Rough allowance for the fixed prompt template wrapped around each chunk.
PROMPT_OVERHEAD = 300

# Paths
# Persistent volume on the Space; survives container restarts.
persistent_dir = "/data/hf_cache"
model_cache_dir = os.path.join(persistent_dir, "txagent_models")
tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
file_cache_dir = os.path.join(persistent_dir, "cache")
report_dir = os.path.join(persistent_dir, "reports")

# Create all cache/report directories up front (idempotent).
for d in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir]:
    os.makedirs(d, exist_ok=True)
24
 
25
  os.environ["HF_HOME"] = model_cache_dir
26
  os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
 
31
 
32
  from txagent.txagent import TxAgent
33
 
34
def estimate_tokens(text: str) -> int:
    """Cheap token-count estimate: about one token per four characters,
    never less than one (keeps empty lines from counting as zero)."""
    return (len(text) >> 2) + 1
 
 
36
 
37
def clean_response(text: str) -> str:
    """Normalise raw model output for markdown display.

    Removes bracketed spans and bare "None" artifacts, collapses runs of
    three-or-more newlines, and strips characters outside a conservative
    whitelist of word characters, whitespace and basic punctuation.
    """
    # Guard against lone surrogates from lossy decoding upstream; without
    # this, later utf-8 encoding (report file write, UI) can raise.
    try:
        text = text.encode('utf-8', 'surrogatepass').decode('utf-8')
    except UnicodeError:
        text = text.encode('utf-8', 'replace').decode('utf-8')
    text = re.sub(r"\[.*?\]|\bNone\b", "", text, flags=re.DOTALL)
    text = re.sub(r"\n{3,}", "\n\n", text)
    # Whitelist pass: drop anything outside markdown-safe characters.
    text = re.sub(r"[^\n#\-\*\w\s\.,:\(\)]+", "", text)
    return text.strip()
42
 
43
def extract_text_from_excel(path: str) -> str:
    """Flatten every sheet of an Excel workbook into pipe-delimited lines.

    Each data row becomes "[<sheet name>] cell | cell | ...", one row per
    output line, sheets concatenated in workbook order.

    Raises:
        ValueError: if the file cannot be opened or parsed.
    """
    all_text = []
    try:
        xls = pd.ExcelFile(path)
        for sheet in xls.sheet_names:
            # fillna() must run BEFORE astype(str): casting first would
            # turn missing cells into the literal string "nan".
            df = xls.parse(sheet).fillna("").astype(str)
            rows = df.apply(lambda row: " | ".join(row), axis=1)
            all_text += [f"[{sheet}] {line}" for line in rows]
    except Exception as e:
        # Chain the cause so the original parser traceback is preserved.
        raise ValueError(f"Error reading Excel file: {str(e)}") from e
    return "\n".join(all_text)
54
 
55
def split_text(text: str, max_tokens=MAX_CHUNK_TOKENS) -> List[str]:
    """Greedily pack newline-separated lines into chunks whose estimated
    token cost stays within max_tokens minus the prompt overhead."""
    budget = max_tokens - PROMPT_OVERHEAD
    chunks: List[str] = []
    buffer: List[str] = []
    used = 0
    for line in text.split("\n"):
        cost = estimate_tokens(line)
        if used + cost > budget:
            # Flush what we have and start a fresh chunk with this line.
            if buffer:
                chunks.append("\n".join(buffer))
            buffer = [line]
            used = cost
        else:
            buffer.append(line)
            used += cost
    if buffer:
        chunks.append("\n".join(buffer))
    return chunks
70
 
71
+ def build_prompt(chunk: str) -> str:
72
+ return f"""### Unstructured Clinical Records
 
73
 
74
+ Analyze the clinical notes below and summarize with:
75
  - Diagnostic Patterns
76
  - Medication Issues
77
  - Missed Opportunities
 
83
  {chunk}
84
 
85
  ---
86
+ Respond concisely in bullet points with clinical reasoning."""
87
+
88
def init_agent() -> TxAgent:
    """Seed the tool cache with the bundled manifest (first run only),
    then construct a TxAgent and load its model weights."""
    cached_tool = os.path.join(tool_cache_dir, "new_tool.json")
    if not os.path.exists(cached_tool):
        shutil.copy(os.path.abspath("data/new_tool.json"), cached_tool)

    config = dict(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        tool_files_dict={"new_tool": cached_tool},
        force_finish=True,
        enable_checker=True,
        step_rag_num=4,
        seed=100,
    )
    agent = TxAgent(**config)
    agent.init_model()
    return agent
104
 
105
def analyze_chunks_parallel(agent, chunks: List[str]) -> List[str]:
    """Run the per-chunk analysis prompt over every chunk on a small
    thread pool and return cleaned responses in chunk order.

    Failures are reported in-band as strings starting with "❌" rather
    than raised, so a single bad chunk cannot abort the whole batch.
    """
    ordered: List[str] = [None] * len(chunks)

    def worker(idx, chunk):
        # Returns (index, text) so results can be slotted back in order.
        prompt = build_prompt(chunk)
        try:
            if estimate_tokens(prompt) > MAX_MODEL_TOKENS:
                return idx, f"❌ Chunk {idx+1} too long. Skipped."
            pieces = []
            for event in agent.run_gradio_chat(
                message=prompt,
                history=[],
                temperature=0.2,
                max_new_tokens=MAX_NEW_TOKENS,
                max_token=MAX_MODEL_TOKENS,
                call_agent=False,
                conversation=[]
            ):
                # Stream events may be plain text, message lists, or
                # single message objects — collect text from each shape.
                if isinstance(event, str):
                    pieces.append(event)
                elif isinstance(event, list):
                    pieces.extend(m.content for m in event if hasattr(m, "content"))
                elif hasattr(event, "content"):
                    pieces.append(event.content)
            gc.collect()
            return idx, clean_response("".join(pieces))
        except Exception as e:
            return idx, f"❌ Error in chunk {idx+1}: {str(e)}"

    with ThreadPoolExecutor(max_workers=4) as pool:
        for idx, text in pool.map(worker, range(len(chunks)), chunks):
            ordered[idx] = text

    return ordered
143
+
144
def generate_final_summary(agent, combined: str) -> str:
    """Ask the agent to consolidate the per-chunk summaries into a single
    structured report and return the cleaned text."""
    final_prompt = f"""Provide a structured medical report based on the following summaries:

{combined}

Respond in detailed medical bullet points."""
    parts = []
    for event in agent.run_gradio_chat(
        message=final_prompt,
        history=[],
        temperature=0.2,
        max_new_tokens=MAX_NEW_TOKENS,
        max_token=MAX_MODEL_TOKENS,
        call_agent=False,
        conversation=[]
    ):
        # Same streamed-event shapes as the per-chunk analysis loop.
        if isinstance(event, str):
            parts.append(event)
        elif isinstance(event, list):
            parts.extend(m.content for m in event if hasattr(m, "content"))
        elif hasattr(event, "content"):
            parts.append(event.content)
    return clean_response("".join(parts))
169
+
170
def process_report(agent, file, messages: List[Dict[str, str]]) -> Tuple[List[Dict[str, str]], Union[str, None]]:
    """Full pipeline: extract text, chunk, analyze, summarize, save report.

    Appends chat entries to `messages` (mutated in place) and returns it
    together with the saved report path, or None when nothing was produced.
    """
    def post(role, content):
        messages.append({"role": role, "content": content})

    if not file or not hasattr(file, "name"):
        post("assistant", "❌ Please upload a valid Excel file.")
        return messages, None

    post("user", f"πŸ“‚ Processing file: {os.path.basename(file.name)}")
    try:
        chunks = split_text(extract_text_from_excel(file.name))
        post("assistant", f"πŸ” Split into {len(chunks)} chunks. Analyzing...")

        # Keep only chunk results that did not fail in-band.
        usable = [r for r in analyze_chunks_parallel(agent, chunks) if not r.startswith("❌")]
        if not usable:
            post("assistant", "❌ No valid chunk outputs.")
            return messages, None

        summary = generate_final_summary(agent, "\n\n".join(usable))
        stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        report_path = os.path.join(report_dir, f"report_{stamp}.md")
        with open(report_path, 'w', encoding='utf-8') as f:
            f.write(f"# 🧠 Final Medical Report\n\n{summary}")

        post("assistant", f"πŸ“Š Final Report:\n\n{summary}")
        post("assistant", f"βœ… Report saved: {os.path.basename(report_path)}")
        return messages, report_path

    except Exception as e:
        post("assistant", f"❌ Error: {str(e)}")
        return messages, None
 
200
 
201
def create_ui(agent):
    """Build the Gradio app: chat panel, upload/analyze controls, and a
    hidden download slot that is revealed once a report file exists."""
    with gr.Blocks() as demo:
        gr.Markdown("<h2 style='color:#1e88e5'>🩺 Patient AI Assistant</h2><p>Upload a clinical Excel file and receive a diagnostic summary.</p>")
        with gr.Row():
            with gr.Column(scale=3):
                chat_panel = gr.Chatbot(label="Assistant", height=700, type="messages")
            with gr.Column(scale=1):
                file_input = gr.File(label="Upload Excel", file_types=[".xlsx"])
                run_btn = gr.Button("🧠 Analyze", variant="primary")
                report_file = gr.File(label="Download Report", visible=False, interactive=False)

        # Chat history persisted across clicks.
        chat_state = gr.State(value=[])

        def handle_analysis(file, chat):
            chat_log, path = process_report(agent, file, chat)
            # Reveal the download widget only when a report was written.
            return chat_log, gr.update(visible=bool(path), value=path), chat_log

        run_btn.click(fn=handle_analysis, inputs=[file_input, chat_state], outputs=[chat_panel, report_file, chat_state])

    return demo
221
 
222
if __name__ == "__main__":
    # Load the model first so startup fails fast if weights are missing.
    try:
        app = create_ui(init_agent())
        app.launch(
            server_name="0.0.0.0",
            server_port=7860,
            allowed_paths=["/data/hf_cache/reports"],
            share=False,
        )
    except Exception as err:
        print(f"Startup failed: {err}")
        sys.exit(1)