Ali2206 commited on
Commit
3ed8d49
·
verified ·
1 Parent(s): 6287195

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -256
app.py CHANGED
@@ -3,15 +3,14 @@ import os
3
  import pandas as pd
4
  import json
5
  import gradio as gr
6
- from typing import List, Tuple, Dict, Any, Union
7
  import hashlib
8
  import shutil
9
  import re
10
  from datetime import datetime
11
- import time
12
  from concurrent.futures import ThreadPoolExecutor, as_completed
13
 
14
- # Configuration and setup
15
  persistent_dir = "/data/hf_cache"
16
  os.makedirs(persistent_dir, exist_ok=True)
17
 
@@ -20,29 +19,21 @@ tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
20
  file_cache_dir = os.path.join(persistent_dir, "cache")
21
  report_dir = os.path.join(persistent_dir, "reports")
22
 
23
- for directory in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir]:
24
- os.makedirs(directory, exist_ok=True)
25
 
26
  os.environ["HF_HOME"] = model_cache_dir
27
  os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
28
 
29
- current_dir = os.path.dirname(os.path.abspath(__file__))
30
- src_path = os.path.abspath(os.path.join(current_dir, "src"))
31
- sys.path.insert(0, src_path)
32
-
33
  from txagent.txagent import TxAgent
34
 
35
- # Constants
36
  MAX_MODEL_TOKENS = 32768
37
  MAX_CHUNK_TOKENS = 8192
38
  MAX_NEW_TOKENS = 2048
39
  PROMPT_OVERHEAD = 500
40
 
41
  def clean_response(text: str) -> str:
42
- try:
43
- text = text.encode('utf-8', 'surrogatepass').decode('utf-8')
44
- except UnicodeError:
45
- text = text.encode('utf-8', 'replace').decode('utf-8')
46
  text = re.sub(r"\[.*?\]|\bNone\b", "", text, flags=re.DOTALL)
47
  text = re.sub(r"\n{3,}", "\n\n", text)
48
  text = re.sub(r"[^\n#\-\*\w\s\.,:\(\)]+", "", text)
@@ -51,286 +42,126 @@ def clean_response(text: str) -> str:
51
  def estimate_tokens(text: str) -> int:
52
  return len(text) // 3.5 + 1
53
 
54
- def extract_text_from_excel(file_path: str) -> str:
55
  all_text = []
56
  try:
57
- xls = pd.ExcelFile(file_path)
58
- for sheet_name in xls.sheet_names:
59
- df = xls.parse(sheet_name)
60
- df = df.astype(str).fillna("")
61
- rows = df.apply(lambda row: " | ".join(row), axis=1)
62
- sheet_text = [f"[{sheet_name}] {line}" for line in rows]
63
- all_text.extend(sheet_text)
64
  except Exception as e:
65
- raise ValueError(f"Failed to extract text from Excel file: {str(e)}")
 
 
 
 
 
66
  return "\n".join(all_text)
67
 
68
- def split_text_into_chunks(text: str, max_tokens: int = MAX_CHUNK_TOKENS) -> List[str]:
69
- effective_max_tokens = max_tokens - PROMPT_OVERHEAD
70
- if effective_max_tokens <= 0:
71
- raise ValueError(f"Effective max tokens ({effective_max_tokens}) must be positive.")
72
- lines = text.split("\n")
73
- chunks, current_chunk, current_tokens = [], [], 0
74
  for line in lines:
75
- line_tokens = estimate_tokens(line)
76
- if current_tokens + line_tokens > effective_max_tokens:
77
- if current_chunk:
78
- chunks.append("\n".join(current_chunk))
79
- current_chunk, current_tokens = [line], line_tokens
 
 
80
  else:
81
- current_chunk.append(line)
82
- current_tokens += line_tokens
83
- if current_chunk:
84
- chunks.append("\n".join(current_chunk))
85
  return chunks
86
 
87
  def build_prompt_from_text(chunk: str) -> str:
88
  return f"""
89
  ### Unstructured Clinical Records
90
 
91
- You are reviewing unstructured, mixed-format clinical documentation from various forms, tables, and sheets.
92
-
93
- **Objective:** Identify patterns, missed diagnoses, inconsistencies, and follow-up gaps.
94
-
95
- Here is the extracted content chunk:
96
-
97
- {chunk}
98
-
99
- Please analyze the above and provide:
100
  - Diagnostic Patterns
101
  - Medication Issues
102
  - Missed Opportunities
103
  - Inconsistencies
104
  - Follow-up Recommendations
 
 
 
 
 
 
 
105
  """
106
 
107
  def init_agent():
108
- default_tool_path = os.path.abspath("data/new_tool.json")
109
- target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
110
- if not os.path.exists(target_tool_path):
111
- shutil.copy(default_tool_path, target_tool_path)
112
  agent = TxAgent(
113
  model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
114
  rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
115
- tool_files_dict={"new_tool": target_tool_path},
116
  force_finish=True,
117
  enable_checker=True,
118
  step_rag_num=4,
119
- seed=100,
120
- additional_default_tools=[]
121
  )
122
  agent.init_model()
123
  return agent
124
 
125
- def process_final_report(agent, file, chatbot_state: List[Dict[str, str]]) -> Tuple[List[Dict[str, str]], Union[str, None]]:
126
- messages = chatbot_state if chatbot_state else []
127
- report_path = None
128
-
129
- if file is None or not hasattr(file, "name"):
130
- messages.append({"role": "assistant", "content": "❌ Please upload a valid Excel file before analyzing."})
131
- return messages, report_path
132
-
133
- try:
134
- messages.append({"role": "user", "content": f"Processing Excel file: {os.path.basename(file.name)}"})
135
- messages.append({"role": "assistant", "content": "⏳ Extracting and analyzing data..."})
136
- extracted_text = extract_text_from_excel(file.name)
137
- chunks = split_text_into_chunks(extracted_text)
138
- chunk_responses = [None] * len(chunks)
139
-
140
- def analyze_chunk(index: int, chunk: str) -> Tuple[int, str]:
141
- prompt = build_prompt_from_text(chunk)
142
- prompt_tokens = estimate_tokens(prompt)
143
- if prompt_tokens > MAX_MODEL_TOKENS:
144
- return index, f"❌ Chunk {index+1} prompt too long ({prompt_tokens} tokens). Skipping..."
145
- response = ""
146
- try:
147
- for result in agent.run_gradio_chat(
148
- message=prompt,
149
- history=[],
150
- temperature=0.2,
151
- max_new_tokens=MAX_NEW_TOKENS,
152
- max_token=MAX_MODEL_TOKENS,
153
- call_agent=False,
154
- conversation=[],
155
- ):
156
- if isinstance(result, str):
157
- response += result
158
- elif hasattr(result, "content"):
159
- response += result.content
160
- elif isinstance(result, list):
161
- for r in result:
162
- if hasattr(r, "content"):
163
- response += r.content
164
- except Exception as e:
165
- return index, f"❌ Error analyzing chunk {index+1}: {str(e)}"
166
- return index, clean_response(response)
167
-
168
- with ThreadPoolExecutor(max_workers=1) as executor:
169
- futures = [executor.submit(analyze_chunk, i, chunk) for i, chunk in enumerate(chunks)]
170
- for future in as_completed(futures):
171
- i, result = future.result()
172
- chunk_responses[i] = result
173
- if not result.startswith("❌"):
174
- messages.append({"role": "assistant", "content": f"✅ Chunk {i+1} analysis complete"})
175
- else:
176
- messages.append({"role": "assistant", "content": result})
177
-
178
- valid_responses = [res for res in chunk_responses if not res.startswith("❌")]
179
- if not valid_responses:
180
- messages.append({"role": "assistant", "content": "❌ No valid chunk responses to summarize."})
181
- return messages, report_path
182
-
183
- summary = ""
184
- current_summary_tokens = 0
185
- for i, response in enumerate(valid_responses):
186
- response_tokens = estimate_tokens(response)
187
- if current_summary_tokens + response_tokens > MAX_MODEL_TOKENS - PROMPT_OVERHEAD - MAX_NEW_TOKENS:
188
- summary_prompt = f"Summarize the following analysis:\n\n{summary}\n\nProvide a concise summary."
189
- summary_response = ""
190
- try:
191
- for result in agent.run_gradio_chat(
192
- message=summary_prompt,
193
- history=[],
194
- temperature=0.2,
195
- max_new_tokens=MAX_NEW_TOKENS,
196
- max_token=MAX_MODEL_TOKENS,
197
- call_agent=False,
198
- conversation=[],
199
- ):
200
- if isinstance(result, str):
201
- summary_response += result
202
- elif hasattr(result, "content"):
203
- summary_response += result.content
204
- elif isinstance(result, list):
205
- for r in result:
206
- if hasattr(r, "content"):
207
- summary_response += r.content
208
- summary = clean_response(summary_response)
209
- current_summary_tokens = estimate_tokens(summary)
210
- except Exception as e:
211
- messages.append({"role": "assistant", "content": f"❌ Error summarizing intermediate results: {str(e)}"})
212
- return messages, report_path
213
- summary += f"\n\n### Chunk {i+1} Analysis\n{response}"
214
- current_summary_tokens += response_tokens
215
-
216
- final_prompt = f"Summarize the key findings from the following analyses:\n\n{summary}"
217
- messages.append({"role": "assistant", "content": "📊 Generating final report..."})
218
-
219
- final_report_text = ""
220
- try:
221
- for result in agent.run_gradio_chat(
222
- message=final_prompt,
223
- history=[],
224
- temperature=0.2,
225
- max_new_tokens=MAX_NEW_TOKENS,
226
- max_token=MAX_MODEL_TOKENS,
227
- call_agent=False,
228
- conversation=[],
229
- ):
230
- if isinstance(result, str):
231
- final_report_text += result
232
- elif hasattr(result, "content"):
233
- final_report_text += result.content
234
- elif isinstance(result, list):
235
- for r in result:
236
- if hasattr(r, "content"):
237
- final_report_text += r.content
238
- except Exception as e:
239
- messages.append({"role": "assistant", "content": f"❌ Error generating final report: {str(e)}"})
240
- return messages, report_path
241
-
242
- final_report = f"# 🧠 Final Patient Report\n\n{clean_response(final_report_text)}"
243
- messages[-1]["content"] = f"📊 Final Report:\n\n{clean_response(final_report_text)}"
244
-
245
- timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
246
- report_path = os.path.join(report_dir, f"report_{timestamp}.md")
247
-
248
- with open(report_path, 'w') as f:
249
- f.write(final_report)
250
-
251
- messages.append({"role": "assistant", "content": f"✅ Report generated and saved: report_{timestamp}.md"})
252
-
253
- except Exception as e:
254
- messages.append({"role": "assistant", "content": f"❌ Error processing file: {str(e)}"})
255
-
256
- return messages, report_path
257
 
258
  def create_ui(agent):
259
- with gr.Blocks(
260
- title="Patient History Chat",
261
- css="""
 
 
 
 
 
262
  .gradio-container {
263
- max-width: 900px !important;
264
- margin: auto;
265
- font-family: 'Segoe UI', sans-serif;
266
- background-color: #f8f9fa;
 
 
 
 
 
 
 
 
 
267
  }
268
- .gr-button.primary {
269
- background: linear-gradient(to right, #4b6cb7, #182848);
270
  color: white;
 
271
  border: none;
 
272
  border-radius: 8px;
 
273
  }
274
- .gr-button.primary:hover {
275
- background: linear-gradient(to right, #3552a3, #101a3e);
276
  }
277
- .gr-file-upload, .gr-chatbot, .gr-markdown {
278
- background-color: white;
279
- border-radius: 10px;
280
- box-shadow: 0 2px 4px rgba(0,0,0,0.1);
281
- padding: 1rem;
282
- }
283
- .gr-chatbot {
284
- border-left: 4px solid #4b6cb7;
285
- }
286
- .gr-file-upload input {
287
- font-size: 0.95rem;
288
- }
289
- .chat-message-content p {
290
- margin: 0.3em 0;
291
- }
292
- .chat-message-content ul {
293
- padding-left: 1.2em;
294
- margin: 0.4em 0;
295
- }
296
- """
297
- ) as demo:
298
- gr.Markdown("""
299
- <h2 style='color:#182848'>🏥 Patient History Analysis Tool</h2>
300
- <p style='color:#444;'>Upload an Excel file containing clinical data. The assistant will analyze it for patterns, inconsistencies, and recommendations.</p>
301
- """)
302
-
303
- with gr.Row():
304
- with gr.Column(scale=3):
305
- chatbot = gr.Chatbot(
306
- label="Clinical Assistant",
307
- show_copy_button=True,
308
- height=600,
309
- type="messages",
310
- avatar_images=(None, "https://i.imgur.com/6wX7Zb4.png"),
311
- render_markdown=True
312
- )
313
- with gr.Column(scale=1):
314
- file_upload = gr.File(label="Upload Excel File", file_types=[".xlsx"], height=100)
315
- analyze_btn = gr.Button("🧠 Analyze Patient History", variant="primary", elem_classes="primary")
316
- report_output = gr.File(label="Download Report", visible=False, interactive=False)
317
-
318
- chatbot_state = gr.State(value=[])
319
-
320
- def update_ui(file, current_state):
321
- messages, report_path = process_final_report(agent, file, current_state)
322
- formatted_messages = []
323
- for msg in messages:
324
- role = msg.get("role")
325
- content = msg.get("content", "")
326
- if role == "assistant":
327
- content = content.replace("- ", "\n- ")
328
- content = f"<div class='chat-message-content'>{content}</div>"
329
- formatted_messages.append({"role": role, "content": content})
330
- report_update = gr.update(visible=report_path is not None, value=report_path)
331
- return formatted_messages, report_update, formatted_messages
332
-
333
- analyze_btn.click(fn=update_ui, inputs=[file_upload, chatbot_state], outputs=[chatbot, report_output, chatbot_state], api_name="analyze")
334
 
335
  return demo
336
 
@@ -338,7 +169,7 @@ if __name__ == "__main__":
338
  try:
339
  agent = init_agent()
340
  demo = create_ui(agent)
341
- demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True, allowed_paths=["/data/hf_cache/reports"], share=False)
342
  except Exception as e:
343
  print(f"Error: {str(e)}")
344
- sys.exit(1)
 
3
  import pandas as pd
4
  import json
5
  import gradio as gr
6
+ from typing import List, Tuple, Union, Generator
7
  import hashlib
8
  import shutil
9
  import re
10
  from datetime import datetime
 
11
  from concurrent.futures import ThreadPoolExecutor, as_completed
12
 
13
+ # Setup directories
14
  persistent_dir = "/data/hf_cache"
15
  os.makedirs(persistent_dir, exist_ok=True)
16
 
 
19
  file_cache_dir = os.path.join(persistent_dir, "cache")
20
  report_dir = os.path.join(persistent_dir, "reports")
21
 
22
+ for d in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir]:
23
+ os.makedirs(d, exist_ok=True)
24
 
25
  os.environ["HF_HOME"] = model_cache_dir
26
  os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
27
 
28
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "src")))
 
 
 
29
  from txagent.txagent import TxAgent
30
 
 
31
  MAX_MODEL_TOKENS = 32768
32
  MAX_CHUNK_TOKENS = 8192
33
  MAX_NEW_TOKENS = 2048
34
  PROMPT_OVERHEAD = 500
35
 
36
  def clean_response(text: str) -> str:
 
 
 
 
37
  text = re.sub(r"\[.*?\]|\bNone\b", "", text, flags=re.DOTALL)
38
  text = re.sub(r"\n{3,}", "\n\n", text)
39
  text = re.sub(r"[^\n#\-\*\w\s\.,:\(\)]+", "", text)
 
42
  def estimate_tokens(text: str) -> int:
43
  return len(text) // 3.5 + 1
44
 
45
+ def extract_text_from_excel(file_obj: Union[str, os.PathLike, 'file']) -> str:
46
  all_text = []
47
  try:
48
+ xls = pd.ExcelFile(file_obj)
 
 
 
 
 
 
49
  except Exception as e:
50
+ raise ValueError(f" Error reading Excel file: {e}")
51
+ for sheet_name in xls.sheet_names:
52
+ df = xls.parse(sheet_name).astype(str).fillna("")
53
+ rows = df.apply(lambda row: " | ".join([cell for cell in row if cell.strip()]), axis=1)
54
+ sheet_text = [f"[{sheet_name}] {line}" for line in rows if line.strip()]
55
+ all_text.extend(sheet_text)
56
  return "\n".join(all_text)
57
 
58
+ def split_text_into_chunks(text: str, max_tokens: int = MAX_CHUNK_TOKENS, max_chunks: int = 30) -> List[str]:
59
+ effective_max = max_tokens - PROMPT_OVERHEAD
60
+ lines, chunks, curr_chunk, curr_tokens = text.split("\n"), [], [], 0
 
 
 
61
  for line in lines:
62
+ t = estimate_tokens(line)
63
+ if curr_tokens + t > effective_max:
64
+ if curr_chunk:
65
+ chunks.append("\n".join(curr_chunk))
66
+ if len(chunks) >= max_chunks:
67
+ break
68
+ curr_chunk, curr_tokens = [line], t
69
  else:
70
+ curr_chunk.append(line)
71
+ curr_tokens += t
72
+ if curr_chunk and len(chunks) < max_chunks:
73
+ chunks.append("\n".join(curr_chunk))
74
  return chunks
75
 
76
  def build_prompt_from_text(chunk: str) -> str:
77
  return f"""
78
  ### Unstructured Clinical Records
79
 
80
+ Analyze the following clinical notes and provide a detailed, concise summary focusing on:
 
 
 
 
 
 
 
 
81
  - Diagnostic Patterns
82
  - Medication Issues
83
  - Missed Opportunities
84
  - Inconsistencies
85
  - Follow-up Recommendations
86
+
87
+ ---
88
+
89
+ {chunk}
90
+
91
+ ---
92
+ Respond in well-structured bullet points with medical reasoning.
93
  """
94
 
95
  def init_agent():
96
+ tool_path = os.path.join(tool_cache_dir, "new_tool.json")
97
+ if not os.path.exists(tool_path):
98
+ shutil.copy(os.path.abspath("data/new_tool.json"), tool_path)
 
99
  agent = TxAgent(
100
  model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
101
  rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
102
+ tool_files_dict={"new_tool": tool_path},
103
  force_finish=True,
104
  enable_checker=True,
105
  step_rag_num=4,
106
+ seed=100
 
107
  )
108
  agent.init_model()
109
  return agent
110
 
111
+ def stream_report(agent, file: Union[str, 'file'], full_output: str) -> Generator[Tuple[str, Union[str, None], str], None, None]:
112
+ yield from stream_report_wrapper(agent)(file, full_output)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
  def create_ui(agent):
115
+ with gr.Blocks(css="""
116
+ body {
117
+ background: #10141f;
118
+ color: #ffffff;
119
+ font-family: 'Inter', sans-serif;
120
+ margin: 0;
121
+ padding: 0;
122
+ }
123
  .gradio-container {
124
+ padding: 30px;
125
+ width: 100vw;
126
+ max-width: 100%;
127
+ border-radius: 0;
128
+ background-color: #1a1f2e;
129
+ }
130
+ .output-markdown {
131
+ background-color: #131720;
132
+ border-radius: 12px;
133
+ padding: 20px;
134
+ min-height: 600px;
135
+ overflow-y: auto;
136
+ border: 1px solid #2c3344;
137
  }
138
+ .gr-button {
139
+ background: linear-gradient(135deg, #4b4ced, #37b6e9);
140
  color: white;
141
+ font-weight: 500;
142
  border: none;
143
+ padding: 10px 20px;
144
  border-radius: 8px;
145
+ transition: background 0.3s ease;
146
  }
147
+ .gr-button:hover {
148
+ background: linear-gradient(135deg, #37b6e9, #4b4ced);
149
  }
150
+ """) as demo:
151
+ gr.Markdown("""# 🧠 Clinical Reasoning Assistant
152
+ Upload clinical Excel records below and click **Analyze** to generate a medical summary.
153
+ """)
154
+ file_upload = gr.File(label="Upload Excel File", file_types=[".xlsx"])
155
+ analyze_btn = gr.Button("Analyze")
156
+ report_output_markdown = gr.Markdown(elem_classes="output-markdown")
157
+ report_file = gr.File(label="Download Report", visible=False)
158
+ full_output = gr.State(value="")
159
+
160
+ analyze_btn.click(
161
+ fn=stream_report,
162
+ inputs=[file_upload, full_output],
163
+ outputs=[report_output_markdown, report_file, full_output]
164
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
 
166
  return demo
167
 
 
169
  try:
170
  agent = init_agent()
171
  demo = create_ui(agent)
172
+ demo.launch(server_name="0.0.0.0", server_port=7860, allowed_paths=["/data/hf_cache/reports"], share=True)
173
  except Exception as e:
174
  print(f"Error: {str(e)}")
175
+ sys.exit(1)