Ali2206 committed (verified)
Commit 6287195 · Parent(s): d201c84

Update app.py

Files changed (1): app.py (+256 -87)
app.py CHANGED
@@ -3,14 +3,15 @@ import os
 import pandas as pd
 import json
 import gradio as gr
-from typing import List, Tuple, Union, Generator
+from typing import List, Tuple, Dict, Any, Union
 import hashlib
 import shutil
 import re
 from datetime import datetime
+import time
 from concurrent.futures import ThreadPoolExecutor, as_completed
 
-# Setup directories
+# Configuration and setup
 persistent_dir = "/data/hf_cache"
 os.makedirs(persistent_dir, exist_ok=True)
 
@@ -19,21 +20,29 @@ tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
 file_cache_dir = os.path.join(persistent_dir, "cache")
 report_dir = os.path.join(persistent_dir, "reports")
 
-for d in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir]:
-    os.makedirs(d, exist_ok=True)
+for directory in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir]:
+    os.makedirs(directory, exist_ok=True)
 
 os.environ["HF_HOME"] = model_cache_dir
 os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
 
-sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "src")))
+current_dir = os.path.dirname(os.path.abspath(__file__))
+src_path = os.path.abspath(os.path.join(current_dir, "src"))
+sys.path.insert(0, src_path)
+
 from txagent.txagent import TxAgent
 
+# Constants
 MAX_MODEL_TOKENS = 32768
 MAX_CHUNK_TOKENS = 8192
 MAX_NEW_TOKENS = 2048
 PROMPT_OVERHEAD = 500
 
 def clean_response(text: str) -> str:
+    try:
+        text = text.encode('utf-8', 'surrogatepass').decode('utf-8')
+    except UnicodeError:
+        text = text.encode('utf-8', 'replace').decode('utf-8')
     text = re.sub(r"\[.*?\]|\bNone\b", "", text, flags=re.DOTALL)
     text = re.sub(r"\n{3,}", "\n\n", text)
     text = re.sub(r"[^\n#\-\*\w\s\.,:\(\)]+", "", text)
@@ -42,126 +51,286 @@ def clean_response(text: str) -> str:
 def estimate_tokens(text: str) -> int:
     return len(text) // 3.5 + 1
 
-def extract_text_from_excel(file_obj: Union[str, os.PathLike, 'file']) -> str:
+def extract_text_from_excel(file_path: str) -> str:
     all_text = []
     try:
-        xls = pd.ExcelFile(file_obj)
+        xls = pd.ExcelFile(file_path)
+        for sheet_name in xls.sheet_names:
+            df = xls.parse(sheet_name)
+            df = df.astype(str).fillna("")
+            rows = df.apply(lambda row: " | ".join(row), axis=1)
+            sheet_text = [f"[{sheet_name}] {line}" for line in rows]
+            all_text.extend(sheet_text)
     except Exception as e:
-        raise ValueError(f"❌ Error reading Excel file: {e}")
-    for sheet_name in xls.sheet_names:
-        df = xls.parse(sheet_name).astype(str).fillna("")
-        rows = df.apply(lambda row: " | ".join([cell for cell in row if cell.strip()]), axis=1)
-        sheet_text = [f"[{sheet_name}] {line}" for line in rows if line.strip()]
-        all_text.extend(sheet_text)
+        raise ValueError(f"Failed to extract text from Excel file: {str(e)}")
     return "\n".join(all_text)
 
-def split_text_into_chunks(text: str, max_tokens: int = MAX_CHUNK_TOKENS, max_chunks: int = 30) -> List[str]:
-    effective_max = max_tokens - PROMPT_OVERHEAD
-    lines, chunks, curr_chunk, curr_tokens = text.split("\n"), [], [], 0
+def split_text_into_chunks(text: str, max_tokens: int = MAX_CHUNK_TOKENS) -> List[str]:
+    effective_max_tokens = max_tokens - PROMPT_OVERHEAD
+    if effective_max_tokens <= 0:
+        raise ValueError(f"Effective max tokens ({effective_max_tokens}) must be positive.")
+    lines = text.split("\n")
+    chunks, current_chunk, current_tokens = [], [], 0
     for line in lines:
-        t = estimate_tokens(line)
-        if curr_tokens + t > effective_max:
-            if curr_chunk:
-                chunks.append("\n".join(curr_chunk))
-                if len(chunks) >= max_chunks:
-                    break
-            curr_chunk, curr_tokens = [line], t
+        line_tokens = estimate_tokens(line)
+        if current_tokens + line_tokens > effective_max_tokens:
+            if current_chunk:
+                chunks.append("\n".join(current_chunk))
+            current_chunk, current_tokens = [line], line_tokens
         else:
-            curr_chunk.append(line)
-            curr_tokens += t
-    if curr_chunk and len(chunks) < max_chunks:
-        chunks.append("\n".join(curr_chunk))
+            current_chunk.append(line)
+            current_tokens += line_tokens
+    if current_chunk:
+        chunks.append("\n".join(current_chunk))
     return chunks
 
 def build_prompt_from_text(chunk: str) -> str:
     return f"""
 ### Unstructured Clinical Records
 
-Analyze the following clinical notes and provide a detailed, concise summary focusing on:
+You are reviewing unstructured, mixed-format clinical documentation from various forms, tables, and sheets.
+
+**Objective:** Identify patterns, missed diagnoses, inconsistencies, and follow-up gaps.
+
+Here is the extracted content chunk:
+
+{chunk}
+
+Please analyze the above and provide:
 - Diagnostic Patterns
 - Medication Issues
 - Missed Opportunities
 - Inconsistencies
 - Follow-up Recommendations
-
----
-
-{chunk}
-
----
-Respond in well-structured bullet points with medical reasoning.
 """
 
 def init_agent():
-    tool_path = os.path.join(tool_cache_dir, "new_tool.json")
-    if not os.path.exists(tool_path):
-        shutil.copy(os.path.abspath("data/new_tool.json"), tool_path)
+    default_tool_path = os.path.abspath("data/new_tool.json")
+    target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
+    if not os.path.exists(target_tool_path):
+        shutil.copy(default_tool_path, target_tool_path)
     agent = TxAgent(
         model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
         rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
-        tool_files_dict={"new_tool": tool_path},
+        tool_files_dict={"new_tool": target_tool_path},
         force_finish=True,
         enable_checker=True,
         step_rag_num=4,
-        seed=100
+        seed=100,
+        additional_default_tools=[]
     )
     agent.init_model()
     return agent
 
-def stream_report(agent, file: Union[str, 'file'], full_output: str) -> Generator[Tuple[str, Union[str, None], str], None, None]:
-    yield from stream_report_wrapper(agent)(file, full_output)
+def process_final_report(agent, file, chatbot_state: List[Dict[str, str]]) -> Tuple[List[Dict[str, str]], Union[str, None]]:
+    messages = chatbot_state if chatbot_state else []
+    report_path = None
+
+    if file is None or not hasattr(file, "name"):
+        messages.append({"role": "assistant", "content": "❌ Please upload a valid Excel file before analyzing."})
+        return messages, report_path
+
+    try:
+        messages.append({"role": "user", "content": f"Processing Excel file: {os.path.basename(file.name)}"})
+        messages.append({"role": "assistant", "content": "⏳ Extracting and analyzing data..."})
+        extracted_text = extract_text_from_excel(file.name)
+        chunks = split_text_into_chunks(extracted_text)
+        chunk_responses = [None] * len(chunks)
+
+        def analyze_chunk(index: int, chunk: str) -> Tuple[int, str]:
+            prompt = build_prompt_from_text(chunk)
+            prompt_tokens = estimate_tokens(prompt)
+            if prompt_tokens > MAX_MODEL_TOKENS:
+                return index, f"❌ Chunk {index+1} prompt too long ({prompt_tokens} tokens). Skipping..."
+            response = ""
+            try:
+                for result in agent.run_gradio_chat(
+                    message=prompt,
+                    history=[],
+                    temperature=0.2,
+                    max_new_tokens=MAX_NEW_TOKENS,
+                    max_token=MAX_MODEL_TOKENS,
+                    call_agent=False,
+                    conversation=[],
+                ):
+                    if isinstance(result, str):
+                        response += result
+                    elif hasattr(result, "content"):
+                        response += result.content
+                    elif isinstance(result, list):
+                        for r in result:
+                            if hasattr(r, "content"):
+                                response += r.content
+            except Exception as e:
+                return index, f"❌ Error analyzing chunk {index+1}: {str(e)}"
+            return index, clean_response(response)
+
+        with ThreadPoolExecutor(max_workers=1) as executor:
+            futures = [executor.submit(analyze_chunk, i, chunk) for i, chunk in enumerate(chunks)]
+            for future in as_completed(futures):
+                i, result = future.result()
+                chunk_responses[i] = result
+                if not result.startswith("❌"):
+                    messages.append({"role": "assistant", "content": f"✅ Chunk {i+1} analysis complete"})
+                else:
+                    messages.append({"role": "assistant", "content": result})
+
+        valid_responses = [res for res in chunk_responses if not res.startswith("❌")]
+        if not valid_responses:
+            messages.append({"role": "assistant", "content": "❌ No valid chunk responses to summarize."})
+            return messages, report_path
+
+        summary = ""
+        current_summary_tokens = 0
+        for i, response in enumerate(valid_responses):
+            response_tokens = estimate_tokens(response)
+            if current_summary_tokens + response_tokens > MAX_MODEL_TOKENS - PROMPT_OVERHEAD - MAX_NEW_TOKENS:
+                summary_prompt = f"Summarize the following analysis:\n\n{summary}\n\nProvide a concise summary."
+                summary_response = ""
+                try:
+                    for result in agent.run_gradio_chat(
+                        message=summary_prompt,
+                        history=[],
+                        temperature=0.2,
+                        max_new_tokens=MAX_NEW_TOKENS,
+                        max_token=MAX_MODEL_TOKENS,
+                        call_agent=False,
+                        conversation=[],
+                    ):
+                        if isinstance(result, str):
+                            summary_response += result
+                        elif hasattr(result, "content"):
+                            summary_response += result.content
+                        elif isinstance(result, list):
+                            for r in result:
+                                if hasattr(r, "content"):
+                                    summary_response += r.content
+                    summary = clean_response(summary_response)
+                    current_summary_tokens = estimate_tokens(summary)
+                except Exception as e:
+                    messages.append({"role": "assistant", "content": f"❌ Error summarizing intermediate results: {str(e)}"})
+                    return messages, report_path
+            summary += f"\n\n### Chunk {i+1} Analysis\n{response}"
+            current_summary_tokens += response_tokens
+
+        final_prompt = f"Summarize the key findings from the following analyses:\n\n{summary}"
+        messages.append({"role": "assistant", "content": "📊 Generating final report..."})
+
+        final_report_text = ""
+        try:
+            for result in agent.run_gradio_chat(
+                message=final_prompt,
+                history=[],
+                temperature=0.2,
+                max_new_tokens=MAX_NEW_TOKENS,
+                max_token=MAX_MODEL_TOKENS,
+                call_agent=False,
+                conversation=[],
+            ):
+                if isinstance(result, str):
+                    final_report_text += result
+                elif hasattr(result, "content"):
+                    final_report_text += result.content
+                elif isinstance(result, list):
+                    for r in result:
+                        if hasattr(r, "content"):
+                            final_report_text += r.content
+        except Exception as e:
+            messages.append({"role": "assistant", "content": f"❌ Error generating final report: {str(e)}"})
+            return messages, report_path
+
+        final_report = f"# 🧠 Final Patient Report\n\n{clean_response(final_report_text)}"
+        messages[-1]["content"] = f"📊 Final Report:\n\n{clean_response(final_report_text)}"
+
+        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
+        report_path = os.path.join(report_dir, f"report_{timestamp}.md")
+
+        with open(report_path, 'w') as f:
+            f.write(final_report)
+
+        messages.append({"role": "assistant", "content": f"✅ Report generated and saved: report_{timestamp}.md"})
+
+    except Exception as e:
+        messages.append({"role": "assistant", "content": f"❌ Error processing file: {str(e)}"})
+
+    return messages, report_path
 
 def create_ui(agent):
-    with gr.Blocks(css="""
-    body {
-        background: #10141f;
-        color: #ffffff;
-        font-family: 'Inter', sans-serif;
-        margin: 0;
-        padding: 0;
-    }
+    with gr.Blocks(
+        title="Patient History Chat",
+        css="""
     .gradio-container {
-        padding: 30px;
-        width: 100vw;
-        max-width: 100%;
-        border-radius: 0;
-        background-color: #1a1f2e;
-    }
-    .output-markdown {
-        background-color: #131720;
-        border-radius: 12px;
-        padding: 20px;
-        min-height: 600px;
-        overflow-y: auto;
-        border: 1px solid #2c3344;
+        max-width: 900px !important;
+        margin: auto;
+        font-family: 'Segoe UI', sans-serif;
+        background-color: #f8f9fa;
     }
-    .gr-button {
-        background: linear-gradient(135deg, #4b4ced, #37b6e9);
+    .gr-button.primary {
+        background: linear-gradient(to right, #4b6cb7, #182848);
         color: white;
-        font-weight: 500;
         border: none;
-        padding: 10px 20px;
         border-radius: 8px;
-        transition: background 0.3s ease;
     }
-    .gr-button:hover {
-        background: linear-gradient(135deg, #37b6e9, #4b4ced);
+    .gr-button.primary:hover {
+        background: linear-gradient(to right, #3552a3, #101a3e);
     }
-    """) as demo:
-        gr.Markdown("""# 🧠 Clinical Reasoning Assistant
-Upload clinical Excel records below and click **Analyze** to generate a medical summary.
-""")
-        file_upload = gr.File(label="Upload Excel File", file_types=[".xlsx"])
-        analyze_btn = gr.Button("Analyze")
-        report_output_markdown = gr.Markdown(elem_classes="output-markdown")
-        report_file = gr.File(label="Download Report", visible=False)
-        full_output = gr.State(value="")
-
-        analyze_btn.click(
-            fn=stream_report,
-            inputs=[file_upload, full_output],
-            outputs=[report_output_markdown, report_file, full_output]
-        )
+    .gr-file-upload, .gr-chatbot, .gr-markdown {
+        background-color: white;
+        border-radius: 10px;
+        box-shadow: 0 2px 4px rgba(0,0,0,0.1);
+        padding: 1rem;
+    }
+    .gr-chatbot {
+        border-left: 4px solid #4b6cb7;
+    }
+    .gr-file-upload input {
+        font-size: 0.95rem;
+    }
+    .chat-message-content p {
+        margin: 0.3em 0;
+    }
+    .chat-message-content ul {
+        padding-left: 1.2em;
+        margin: 0.4em 0;
+    }
+    """
+    ) as demo:
+        gr.Markdown("""
+        <h2 style='color:#182848'>🏥 Patient History Analysis Tool</h2>
+        <p style='color:#444;'>Upload an Excel file containing clinical data. The assistant will analyze it for patterns, inconsistencies, and recommendations.</p>
+        """)
+
+        with gr.Row():
+            with gr.Column(scale=3):
+                chatbot = gr.Chatbot(
+                    label="Clinical Assistant",
+                    show_copy_button=True,
+                    height=600,
+                    type="messages",
+                    avatar_images=(None, "https://i.imgur.com/6wX7Zb4.png"),
+                    render_markdown=True
+                )
+            with gr.Column(scale=1):
+                file_upload = gr.File(label="Upload Excel File", file_types=[".xlsx"], height=100)
+                analyze_btn = gr.Button("🧠 Analyze Patient History", variant="primary", elem_classes="primary")
+                report_output = gr.File(label="Download Report", visible=False, interactive=False)
+
+        chatbot_state = gr.State(value=[])
+
+        def update_ui(file, current_state):
+            messages, report_path = process_final_report(agent, file, current_state)
+            formatted_messages = []
+            for msg in messages:
+                role = msg.get("role")
+                content = msg.get("content", "")
+                if role == "assistant":
+                    content = content.replace("- ", "\n- ")
+                    content = f"<div class='chat-message-content'>{content}</div>"
+                formatted_messages.append({"role": role, "content": content})
+            report_update = gr.update(visible=report_path is not None, value=report_path)
+            return formatted_messages, report_update, formatted_messages
+
+        analyze_btn.click(fn=update_ui, inputs=[file_upload, chatbot_state], outputs=[chatbot, report_output, chatbot_state], api_name="analyze")
 
     return demo
 
@@ -169,7 +338,7 @@ if __name__ == "__main__":
     try:
        agent = init_agent()
        demo = create_ui(agent)
-        demo.launch(server_name="0.0.0.0", server_port=7860, allowed_paths=["/data/hf_cache/reports"], share=True)
+        demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True, allowed_paths=["/data/hf_cache/reports"], share=False)
     except Exception as e:
        print(f"Error: {str(e)}")
-        sys.exit(1)
+        sys.exit(1)
 
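For reviewers, a standalone sketch of the chunking behavior this commit settles on: the old `max_chunks=30` cap (whose `break` silently dropped overflow text) is gone, so every line now lands in some chunk. Names and constants mirror the diff; the one adjustment is the `int()` in `estimate_tokens`, added because the committed `len(text) // 3.5` floor-divides by a float and therefore returns a float despite the `-> int` annotation.

from typing import List

MAX_CHUNK_TOKENS = 8192
PROMPT_OVERHEAD = 500

def estimate_tokens(text: str) -> int:
    # Rough heuristic of ~3.5 characters per token. int() is added here
    # because the committed version returns a float.
    return int(len(text) // 3.5) + 1

def split_text_into_chunks(text: str, max_tokens: int = MAX_CHUNK_TOKENS) -> List[str]:
    # Greedy line-based packing, as in the new app.py: start a fresh chunk
    # whenever adding the next line would exceed the per-chunk budget.
    effective_max_tokens = max_tokens - PROMPT_OVERHEAD
    if effective_max_tokens <= 0:
        raise ValueError(f"Effective max tokens ({effective_max_tokens}) must be positive.")
    chunks, current_chunk, current_tokens = [], [], 0
    for line in text.split("\n"):
        line_tokens = estimate_tokens(line)
        if current_tokens + line_tokens > effective_max_tokens:
            if current_chunk:
                chunks.append("\n".join(current_chunk))
            current_chunk, current_tokens = [line], line_tokens
        else:
            current_chunk.append(line)
            current_tokens += line_tokens
    if current_chunk:
        chunks.append("\n".join(current_chunk))
    return chunks

if __name__ == "__main__":
    sample = "\n".join(f"[Sheet1] row {i} | value" for i in range(50_000))
    chunks = split_text_into_chunks(sample)
    assert sum(len(c.split("\n")) for c in chunks) == 50_000  # nothing dropped
    assert all(estimate_tokens(c) <= MAX_CHUNK_TOKENS for c in chunks)
    print(f"{len(chunks)} chunks")

Note that a single line longer than the budget still becomes its own oversized chunk; that edge case is unchanged from the committed logic.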
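The UI half of the commit switches `gr.Chatbot` to `type="messages"`, which expects a list of `{"role": ..., "content": ...}` dicts, exactly the shape `process_final_report` builds and `update_ui` passes through. Below is a minimal wiring sketch, assuming a recent Gradio (the `messages` type landed in the 4.x line); the handler is an echo-style stand-in for the real analysis pipeline, not the app's code.

import os
import gradio as gr

def update_ui(file, current_state):
    # Stand-in for process_final_report(agent, file, current_state): append
    # openai-style message dicts and return the list twice, once for the
    # Chatbot display and once for the gr.State that persists it.
    messages = list(current_state or [])
    path = getattr(file, "name", file)  # gr.File may yield an object or a path
    if not path:
        messages.append({"role": "assistant",
                         "content": "❌ Please upload a valid Excel file before analyzing."})
    else:
        messages.append({"role": "user",
                         "content": f"Processing Excel file: {os.path.basename(path)}"})
        messages.append({"role": "assistant",
                         "content": "⏳ Extracting and analyzing data..."})
    return messages, messages

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(type="messages", height=600)  # dicts, not tuples
    file_upload = gr.File(label="Upload Excel File", file_types=[".xlsx"])
    analyze_btn = gr.Button("Analyze", variant="primary")
    chatbot_state = gr.State(value=[])
    analyze_btn.click(fn=update_ui,
                      inputs=[file_upload, chatbot_state],
                      outputs=[chatbot, chatbot_state])

if __name__ == "__main__":
    demo.launch()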