Ali2206 committed
Commit 7e55ae2 · verified · 1 Parent(s): 0152260

Update app.py

Files changed (1): app.py +76 -125
app.py CHANGED
@@ -1,18 +1,17 @@
-import sys, os, json, shutil, re, time, gc
+import sys
+import os
 import pandas as pd
-from datetime import datetime
-from typing import List, Tuple, Dict, Union
+import json
 import gradio as gr
+from datetime import datetime
+import shutil
+import gc
+import re
+import torch
+from typing import List, Tuple, Dict
 from concurrent.futures import ThreadPoolExecutor
 
-# Constants
-MAX_MODEL_TOKENS = 131072
-MAX_NEW_TOKENS = 4096
-MAX_CHUNK_TOKENS = 8192
-PROMPT_OVERHEAD = 300
-BATCH_SIZE = 2  # NEW: batch 2 prompts together for faster processing
-
-# Paths
+# Directories
 persistent_dir = "/data/hf_cache"
 model_cache_dir = os.path.join(persistent_dir, "txagent_models")
 tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
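
The constants this commit relocates (they reappear below the TxAgent import) imply the following budget; a quick arithmetic check, as a sketch:

# Each chunk leaves room for the prompt scaffolding:
usable_per_chunk = MAX_CHUNK_TOKENS - PROMPT_OVERHEAD          # 8192 - 300 = 7892 tokens
# Rough worst case for one batched request plus its reply:
worst_case = BATCH_SIZE * MAX_CHUNK_TOKENS + MAX_NEW_TOKENS    # 2 * 8192 + 4096 = 20480
assert worst_case <= MAX_MODEL_TOKENS                          # comfortably under 131072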
@@ -25,59 +24,66 @@ for d in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir]:
 os.environ["HF_HOME"] = model_cache_dir
 os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
 
+# Paths
 current_dir = os.path.dirname(os.path.abspath(__file__))
 src_path = os.path.abspath(os.path.join(current_dir, "src"))
 sys.path.insert(0, src_path)
 
 from txagent.txagent import TxAgent
 
+# Constants
+MAX_MODEL_TOKENS = 131072
+MAX_NEW_TOKENS = 4096
+MAX_CHUNK_TOKENS = 8192
+PROMPT_OVERHEAD = 300
+BATCH_SIZE = 2
+
 def estimate_tokens(text: str) -> int:
     return len(text) // 4 + 1
 
-def clean_response(text: str) -> str:
-    text = re.sub(r"\[.*?\]|\bNone\b", "", text, flags=re.DOTALL)
-    text = re.sub(r"\n{3,}", "\n\n", text)
-    text = re.sub(r"[^\n#\-\*\w\s\.,:\(\)]+", "", text)
-    return text.strip()
-
 def extract_text_from_excel(path: str) -> str:
     all_text = []
     xls = pd.ExcelFile(path)
-    for sheet_name in xls.sheet_names:
+    for sheet in xls.sheet_names:
         try:
-            df = xls.parse(sheet_name).astype(str).fillna("")
+            df = xls.parse(sheet).astype(str).fillna("")
         except Exception:
             continue
-        for idx, row in df.iterrows():
+        for _, row in df.iterrows():
             non_empty = [cell.strip() for cell in row if cell.strip()]
             if len(non_empty) >= 2:
-                text_line = " | ".join(non_empty)
-                if len(text_line) > 15:
-                    all_text.append(f"[{sheet_name}] {text_line}")
+                line = " | ".join(non_empty)
+                if len(line) > 15:
+                    all_text.append(f"[{sheet}] {line}")
     return "\n".join(all_text)
 
-def split_text(text: str, max_tokens=MAX_CHUNK_TOKENS) -> List[str]:
+def split_text(text: str, max_tokens: int = MAX_CHUNK_TOKENS) -> List[str]:
     effective_limit = max_tokens - PROMPT_OVERHEAD
-    chunks, current, current_tokens = [], [], 0
+    chunks, current, tokens = [], [], 0
     for line in text.split("\n"):
-        tokens = estimate_tokens(line)
-        if current_tokens + tokens > effective_limit:
+        tks = estimate_tokens(line)
+        if tokens + tks > effective_limit:
             if current:
                 chunks.append("\n".join(current))
-            current, current_tokens = [line], tokens
+            current, tokens = [line], tks
         else:
             current.append(line)
-            current_tokens += tokens
+            tokens += tks
     if current:
         chunks.append("\n".join(current))
     return chunks
 
-def batch_chunks(chunks: List[str], batch_size: int = 2) -> List[List[str]]:
-    return [chunks[i:i+batch_size] for i in range(0, len(chunks), batch_size)]
+def batch_chunks(chunks: List[str], batch_size: int = BATCH_SIZE) -> List[List[str]]:
+    return [chunks[i:i + batch_size] for i in range(0, len(chunks), batch_size)]
 
 def build_prompt(chunk: str) -> str:
     return f"""### Unstructured Clinical Records\n\nAnalyze the clinical notes below and summarize with:\n- Diagnostic Patterns\n- Medication Issues\n- Missed Opportunities\n- Inconsistencies\n- Follow-up Recommendations\n\n---\n\n{chunk}\n\n---\nRespond concisely in bullet points with clinical reasoning."""
 
+def clean_response(text: str) -> str:
+    text = re.sub(r"\[.*?\]", "", text, flags=re.DOTALL)
+    text = re.sub(r"\n{3,}", "\n\n", text)
+    return text.strip()
+
 def init_agent() -> TxAgent:
     tool_path = os.path.join(tool_cache_dir, "new_tool.json")
     if not os.path.exists(tool_path):
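
The renamed chunking helpers can be sanity-checked in isolation (a minimal sketch; the synthetic rows below are illustrative, not from the app):

# Sketch: exercise split_text/batch_chunks on synthetic sheet rows.
sample = "\n".join(f"[Sheet1] patient {i} | note text {i}" for i in range(20000))
chunks = split_text(sample)     # each chunk stays within MAX_CHUNK_TOKENS - PROMPT_OVERHEAD
batches = batch_chunks(chunks)  # groups of at most BATCH_SIZE chunks
print(f"{len(chunks)} chunks -> {len(batches)} batches")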
@@ -97,9 +103,9 @@ def init_agent() -> TxAgent:
 def analyze_batches(agent, batches: List[List[str]]) -> List[str]:
     results = []
     for batch in batches:
-        prompt = "\n\n".join(build_prompt(chunk) for chunk in batch)
-        response = ""
+        prompt = "\n\n".join(build_prompt(c) for c in batch)
         try:
+            response = ""
             for r in agent.run_gradio_chat(
                 message=prompt,
                 history=[],
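
Since build_prompt is applied per chunk and the results are joined, each batch reaches the model as one combined message. ThreadPoolExecutor is imported at the top of app.py but appears unused in the code shown here; if the agent were safe to call from multiple threads (an assumption, and often untrue for a single GPU-bound model), the loop could be fanned out as in this hypothetical sketch, reusing analyze_batches one batch at a time:

from concurrent.futures import ThreadPoolExecutor

def analyze_batches_parallel(agent, batches, workers=2):
    # Hypothetical helper, not part of this commit.
    with ThreadPoolExecutor(max_workers=workers) as pool:
        futures = [pool.submit(analyze_batches, agent, [b]) for b in batches]
        return [f.result()[0] for f in futures]  # result order matches batch order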
@@ -119,14 +125,14 @@ def analyze_batches(agent, batches: List[List[str]]) -> List[str]:
                     response += r.content
             results.append(clean_response(response))
         except Exception as e:
-            results.append(f"❌ Error in batch: {str(e)}")
-            torch.cuda.empty_cache()
-            gc.collect()
+            results.append(f"❌ Error: {str(e)}")
+        torch.cuda.empty_cache()
+        gc.collect()
     return results
 
 def generate_final_summary(agent, combined: str) -> str:
-    final_prompt = f"""Provide a structured medical report based on the following summaries:\n\n{combined}\n\nRespond in detailed medical bullet points."""
-    full_report = ""
+    final_prompt = f"""Summarize the following clinical summaries into a final medical report:\n\n{combined}"""
+    response = ""
     for r in agent.run_gradio_chat(
         message=final_prompt,
         history=[],
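
The per-batch cache flush could equivalently be written with try/finally, which makes the run-regardless-of-outcome intent explicit (a sketch under that assumption; run_batch is a hypothetical callable that streams one batch and returns its text):

def run_batch_safely(run_batch, results):
    try:
        results.append(clean_response(run_batch()))
    except Exception as e:
        results.append(f"❌ Error: {e}")
    finally:
        torch.cuda.empty_cache()  # release cached GPU blocks between batches
        gc.collect()              # reclaim Python-side garbage promptly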
@@ -137,41 +143,43 @@ def generate_final_summary(agent, combined: str) -> str:
         conversation=[]
     ):
         if isinstance(r, str):
-            full_report += r
+            response += r
         elif isinstance(r, list):
             for m in r:
                 if hasattr(m, "content"):
-                    full_report += m.content
+                    response += m.content
         elif hasattr(r, "content"):
-            full_report += r.content
-    return clean_response(full_report)
+            response += r.content
+    return clean_response(response)
 
-def process_report(agent, file, messages: List[Dict[str, str]]) -> Tuple[List[Dict[str, str]], Union[str, None]]:
+def process_file(agent, file, messages: List[Dict[str, str]]) -> Tuple[List[Dict[str, str]], str]:
     if not file or not hasattr(file, "name"):
         messages.append({"role": "assistant", "content": "❌ Please upload a valid Excel file."})
         return messages, None
 
-    messages.append({"role": "user", "content": f"📂 Processing file: {os.path.basename(file.name)}"})
+    messages.append({"role": "user", "content": f"📂 Processing file: {file.name}"})
     try:
-        extracted = extract_text_from_excel(file.name)
-        chunks = split_text(extracted)
-        batches = batch_chunks(chunks, batch_size=BATCH_SIZE)
+        extracted_text = extract_text_from_excel(file.name)
+        chunks = split_text(extracted_text)
+        batches = batch_chunks(chunks)
         messages.append({"role": "assistant", "content": f"🔍 Split into {len(batches)} batches. Analyzing..."})
 
-        batch_results = analyze_batches(agent, batches)
-        valid = [res for res in batch_results if not res.startswith("❌")]
+        batch_outputs = analyze_batches(agent, batches)
+        valid_outputs = [res for res in batch_outputs if not res.startswith("❌")]
 
-        if not valid:
+        if not valid_outputs:
             messages.append({"role": "assistant", "content": "❌ No valid batch outputs."})
             return messages, None
 
-        summary = generate_final_summary(agent, "\n\n".join(valid))
+        summary = generate_final_summary(agent, "\n\n".join(valid_outputs))
+
         report_path = os.path.join(report_dir, f"report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md")
-        with open(report_path, 'w', encoding='utf-8') as f:
+        with open(report_path, "w", encoding="utf-8") as f:
             f.write(f"# 🧠 Final Medical Report\n\n{summary}")
 
         messages.append({"role": "assistant", "content": f"📊 Final Report:\n\n{summary}"})
-        messages.append({"role": "assistant", "content": f"✅ Report saved: {os.path.basename(report_path)}"})
+        messages.append({"role": "assistant", "content": f"✅ Saved report: {os.path.basename(report_path)}"})
+
         return messages, report_path
 
     except Exception as e:
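
Because process_file only probes the upload for a name attribute, it can be exercised without the Gradio UI (a sketch; records.xlsx is a hypothetical sample workbook and the agent must already be initialized):

from types import SimpleNamespace

# Sketch: headless driver for process_file.
agent = init_agent()
upload = SimpleNamespace(name="records.xlsx")  # hypothetical sample path
messages, report_path = process_file(agent, upload, [])
for m in messages:
    print(m["role"], "->", m["content"][:80])
print("report:", report_path)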
@@ -180,84 +188,27 @@ def process_report(agent, file, messages: List[Dict[str, str]]) -> Tuple[List[Dict[str, str]], Union[str, None]]:
 
 def create_ui(agent):
     with gr.Blocks(css="""
-    html, body, .gradio-container {
-        background-color: #0e1621;
-        color: #e0e0e0;
-        font-family: 'Inter', sans-serif;
-        padding: 0;
-        margin: 0;
-    }
-    h2, h3, h4 {
-        color: #89b4fa;
-        font-weight: 600;
-    }
-    button.gr-button-primary {
-        background-color: #007bff !important;
-        color: white !important;
-        font-weight: bold;
-        border-radius: 8px !important;
-        padding: 0.65em 1.2em !important;
-        font-size: 16px !important;
-        border: none;
-    }
-    button.gr-button-primary:hover {
-        background-color: #0056b3 !important;
-    }
-    .gr-chatbot, .gr-markdown, .gr-file-upload {
-        border-radius: 16px;
-        background-color: #1b2533;
-        border: 1px solid #2a2f45;
-        padding: 10px;
-    }
-    .gr-chatbot .message {
-        font-size: 16px;
-        padding: 12px 16px;
-        border-radius: 18px;
-        margin: 8px 0;
-        max-width: 80%;
-        word-break: break-word;
-        white-space: pre-wrap;
-    }
-    .gr-chatbot .message.user {
-        background-color: #334155;
-        align-self: flex-end;
-        margin-left: auto;
-    }
-    .gr-chatbot .message.assistant {
-        background-color: #1e293b;
-        align-self: flex-start;
-        margin-right: auto;
-    }
-    .gr-file-upload .file-name {
-        font-size: 14px;
-        color: #89b4fa;
-    }
+    html, body { background-color: #0e1621; color: #e0e0e0; }
+    button { background: #007bff; color: white; border-radius: 8px; padding: 8px 16px; }
+    .gr-chatbot { background: #1b2533; border: 1px solid #2a2f45; border-radius: 16px; padding: 10px; }
     """) as demo:
-        gr.Markdown("""
-        <h2>📄 CPS: Clinical Patient Support System</h2>
-        <p>CPS Assistant helps you analyze and summarize unstructured medical files using AI.</p>
-        """)
-        with gr.Column():
-            chatbot = gr.Chatbot(label="CPS Assistant", height=700, type="messages")
-            upload = gr.File(label="Upload Medical File", file_types=[".xlsx"])
-            analyze = gr.Button("🧠 Analyze", variant="primary")
-            download = gr.File(label="Download Report", visible=False, interactive=False)
+        gr.Markdown("""## 🧠 CPS: Clinical Patient Support Assistant""")
+        chatbot = gr.Chatbot(label="CPS Assistant", height=700, type="messages")
+        upload = gr.File(label="Upload Excel File", file_types=[".xlsx"])
+        analyze_btn = gr.Button("🧠 Analyze File")
+        download = gr.File(label="Download Report", visible=False)
 
-        state = gr.State(value=[])
+        state = gr.State([])
 
-        def handle_analysis(file, chat):
-            messages, report_path = process_report(agent, file, chat)
+        def handle_analyze(file, chat_state):
+            messages, report_path = process_file(agent, file, chat_state)
             return messages, gr.update(visible=bool(report_path), value=report_path), messages
 
-        analyze.click(fn=handle_analysis, inputs=[upload, state], outputs=[chatbot, download, state])
+        analyze_btn.click(fn=handle_analyze, inputs=[upload, state], outputs=[chatbot, download, state])
 
     return demo
 
 if __name__ == "__main__":
-    try:
-        agent = init_agent()
-        ui = create_ui(agent)
-        ui.launch(server_name="0.0.0.0", server_port=7860, allowed_paths=["/data/hf_cache/reports"], share=False)
-    except Exception as err:
-        print(f"Startup failed: {err}")
-        sys.exit(1)
+    agent = init_agent()
+    ui = create_ui(agent)
+    ui.launch(server_name="0.0.0.0", server_port=7860, allowed_paths=["/data/hf_cache/reports"], share=False)