Ali2206 committed
Commit 1dd5b3f · verified · Parent: 095998d

Update app.py

Files changed (1)
app.py +105 -98
app.py CHANGED
@@ -1,17 +1,26 @@
 import sys
 import os
-import pandas as pd
+import gc
 import json
-import gradio as gr
-from datetime import datetime
 import shutil
-import gc
 import re
+import time
+import pandas as pd
+import gradio as gr
 import torch
-from typing import List, Tuple, Dict
+from typing import List, Tuple, Dict, Union
 from concurrent.futures import ThreadPoolExecutor, as_completed
+from datetime import datetime
 
-# Directories
+# Constants
+MAX_MODEL_TOKENS = 131072
+MAX_NEW_TOKENS = 4096
+MAX_CHUNK_TOKENS = 8192
+PROMPT_OVERHEAD = 300
+BATCH_SIZE = 4  # 4 chunks per batch
+MAX_WORKERS = 6  # 6 parallel batches
+
+# Paths
 persistent_dir = "/data/hf_cache"
 model_cache_dir = os.path.join(persistent_dir, "txagent_models")
 tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
@@ -24,29 +33,27 @@ for d in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir]:
 os.environ["HF_HOME"] = model_cache_dir
 os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
 
-# Paths
 current_dir = os.path.dirname(os.path.abspath(__file__))
 src_path = os.path.abspath(os.path.join(current_dir, "src"))
 sys.path.insert(0, src_path)
 
 from txagent.txagent import TxAgent
 
-# Constants
-MAX_MODEL_TOKENS = 131072
-MAX_NEW_TOKENS = 4096
-MAX_CHUNK_TOKENS = 8192
-PROMPT_OVERHEAD = 300
-BATCH_SIZE = 2
-
 def estimate_tokens(text: str) -> int:
     return len(text) // 4 + 1
 
+def clean_response(text: str) -> str:
+    text = re.sub(r"\[.*?\]|\bNone\b", "", text, flags=re.DOTALL)
+    text = re.sub(r"\n{3,}", "\n\n", text)
+    text = re.sub(r"[^\n#\-\*\w\s\.,:\(\)]+", "", text)
+    return text.strip()
+
 def extract_text_from_excel(path: str) -> str:
     all_text = []
     xls = pd.ExcelFile(path)
-    for sheet in xls.sheet_names:
+    for sheet_name in xls.sheet_names:
         try:
-            df = xls.parse(sheet).astype(str).fillna("")
+            df = xls.parse(sheet_name).astype(str).fillna("")
         except Exception:
             continue
         for _, row in df.iterrows():
@@ -54,36 +61,31 @@ def extract_text_from_excel(path: str) -> str:
             if len(non_empty) >= 2:
                 line = " | ".join(non_empty)
                 if len(line) > 15:
-                    all_text.append(f"[{sheet}] {line}")
+                    all_text.append(f"[{sheet_name}] {line}")
     return "\n".join(all_text)
 
-def split_text(text: str, max_tokens: int = MAX_CHUNK_TOKENS) -> List[str]:
+def split_text(text: str, max_tokens=MAX_CHUNK_TOKENS) -> List[str]:
     effective_limit = max_tokens - PROMPT_OVERHEAD
-    chunks, current, tokens = [], [], 0
+    chunks, current, current_tokens = [], [], 0
     for line in text.split("\n"):
-        tks = estimate_tokens(line)
-        if tokens + tks > effective_limit:
+        tokens = estimate_tokens(line)
+        if current_tokens + tokens > effective_limit:
            if current:
                chunks.append("\n".join(current))
-            current, tokens = [line], tks
+            current, current_tokens = [line], tokens
        else:
            current.append(line)
-            tokens += tks
+            current_tokens += tokens
    if current:
        chunks.append("\n".join(current))
    return chunks
 
-def batch_chunks(chunks: List[str], batch_size: int = BATCH_SIZE) -> List[List[str]]:
-    return [chunks[i:i + batch_size] for i in range(0, len(chunks), batch_size)]
+def batch_chunks(chunks: List[str], batch_size: int = 4) -> List[List[str]]:
+    return [chunks[i:i+batch_size] for i in range(0, len(chunks), batch_size)]
 
 def build_prompt(chunk: str) -> str:
     return f"""### Unstructured Clinical Records\n\nAnalyze the clinical notes below and summarize with:\n- Diagnostic Patterns\n- Medication Issues\n- Missed Opportunities\n- Inconsistencies\n- Follow-up Recommendations\n\n---\n\n{chunk}\n\n---\nRespond concisely in bullet points with clinical reasoning."""
 
-def clean_response(text: str) -> str:
-    text = re.sub(r"\[.*?\]", "", text, flags=re.DOTALL)
-    text = re.sub(r"\n{3,}", "\n\n", text)
-    return text.strip()
-
 def init_agent() -> TxAgent:
     tool_path = os.path.join(tool_cache_dir, "new_tool.json")
     if not os.path.exists(tool_path):
@@ -100,46 +102,45 @@ def init_agent() -> TxAgent:
     agent.init_model()
     return agent
 
+def analyze_batch(agent, batch: List[str]) -> str:
+    prompt = "\n\n".join(build_prompt(chunk) for chunk in batch)
+    response = ""
+    try:
+        for r in agent.run_gradio_chat(
+            message=prompt,
+            history=[],
+            temperature=0.0,
+            max_new_tokens=MAX_NEW_TOKENS,
+            max_token=MAX_MODEL_TOKENS,
+            call_agent=False,
+            conversation=[]
+        ):
+            if isinstance(r, str):
+                response += r
+            elif isinstance(r, list):
+                for m in r:
+                    if hasattr(m, "content"):
+                        response += m.content
+            elif hasattr(r, "content"):
+                response += r.content
+    except Exception as e:
+        return f"❌ Error in batch: {str(e)}"
+    finally:
+        torch.cuda.empty_cache()
+        gc.collect()
+    return clean_response(response)
 
-
-def analyze_batches(agent, batches: List[List[str]], max_workers: int = 3) -> List[str]:
+def analyze_batches_parallel(agent, batches: List[List[str]]) -> List[str]:
     results = []
-
-    def process_single_batch(batch):
-        prompt = "\n\n".join(build_prompt(chunk) for chunk in batch)
-        response = ""
-        try:
-            for r in agent.run_gradio_chat(
-                message=prompt,
-                history=[],
-                temperature=0.0,
-                max_new_tokens=4096,
-                max_token=131072,
-                call_agent=False,
-                conversation=[]
-            ):
-                if isinstance(r, str):
-                    response += r
-                elif isinstance(r, list):
-                    for m in r:
-                        if hasattr(m, "content"):
-                            response += m.content
-                elif hasattr(r, "content"):
-                    response += r.content
-        except Exception as e:
-            response = f"❌ Error: {str(e)}"
-        return clean_response(response)
-
-    with ThreadPoolExecutor(max_workers=max_workers) as executor:
-        futures = [executor.submit(process_single_batch, batch) for batch in batches]
+    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
+        futures = [executor.submit(analyze_batch, agent, batch) for batch in batches]
        for future in as_completed(futures):
            results.append(future.result())
-
    return results
 
 def generate_final_summary(agent, combined: str) -> str:
-    final_prompt = f"""Summarize the following clinical summaries into a final medical report:\n\n{combined}"""
-    response = ""
+    final_prompt = f"""Provide a structured medical report based on the following summaries:\n\n{combined}\n\nRespond in detailed medical bullet points."""
+    full_report = ""
    for r in agent.run_gradio_chat(
        message=final_prompt,
        history=[],
@@ -150,43 +151,41 @@ def generate_final_summary(agent, combined: str) -> str:
         conversation=[]
     ):
         if isinstance(r, str):
-            response += r
+            full_report += r
        elif isinstance(r, list):
            for m in r:
                if hasattr(m, "content"):
-                    response += m.content
+                    full_report += m.content
        elif hasattr(r, "content"):
-            response += r.content
-    return clean_response(response)
+            full_report += r.content
+    return clean_response(full_report)
 
-def process_file(agent, file, messages: List[Dict[str, str]]) -> Tuple[List[Dict[str, str]], str]:
+def process_report(agent, file, messages: List[Dict[str, str]]) -> Tuple[List[Dict[str, str]], Union[str, None]]:
    if not file or not hasattr(file, "name"):
        messages.append({"role": "assistant", "content": "❌ Please upload a valid Excel file."})
        return messages, None
 
-    messages.append({"role": "user", "content": f"📂 Processing file: {file.name}"})
+    messages.append({"role": "user", "content": f"📂 Processing file: {os.path.basename(file.name)}"})
    try:
-        extracted_text = extract_text_from_excel(file.name)
-        chunks = split_text(extracted_text)
-        batches = batch_chunks(chunks)
-        messages.append({"role": "assistant", "content": f"🔍 Split into {len(batches)} batches. Analyzing..."})
+        extracted = extract_text_from_excel(file.name)
+        chunks = split_text(extracted)
+        batches = batch_chunks(chunks, batch_size=BATCH_SIZE)
+        messages.append({"role": "assistant", "content": f"🔍 Split into {len(batches)} batches. Analyzing in parallel..."})
 
-        batch_outputs = analyze_batches(agent, batches)
-        valid_outputs = [res for res in batch_outputs if not res.startswith("❌")]
+        batch_results = analyze_batches_parallel(agent, batches)
+        valid = [res for res in batch_results if not res.startswith("❌")]
 
-        if not valid_outputs:
+        if not valid:
            messages.append({"role": "assistant", "content": "❌ No valid batch outputs."})
            return messages, None
 
-        summary = generate_final_summary(agent, "\n\n".join(valid_outputs))
-
+        summary = generate_final_summary(agent, "\n\n".join(valid))
        report_path = os.path.join(report_dir, f"report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md")
-        with open(report_path, "w", encoding="utf-8") as f:
+        with open(report_path, 'w', encoding='utf-8') as f:
            f.write(f"# 🧠 Final Medical Report\n\n{summary}")
 
        messages.append({"role": "assistant", "content": f"📊 Final Report:\n\n{summary}"})
-        messages.append({"role": "assistant", "content": f"✅ Saved report: {os.path.basename(report_path)}"})
-
+        messages.append({"role": "assistant", "content": f"✅ Report saved: {os.path.basename(report_path)}"})
        return messages, report_path
 
    except Exception as e:
@@ -195,27 +194,35 @@ def process_file(agent, file, messages: List[Dict[str, str]]) -> Tuple[List[Dict
 
 def create_ui(agent):
     with gr.Blocks(css="""
-    html, body { background-color: #0e1621; color: #e0e0e0; }
-    button { background: #007bff; color: white; border-radius: 8px; padding: 8px 16px; }
-    .gr-chatbot { background: #1b2533; border: 1px solid #2a2f45; border-radius: 16px; padding: 10px; }
+    html, body, .gradio-container {background-color: #0e1621; color: #e0e0e0; font-family: 'Inter', sans-serif;}
+    h2, h3, h4 {color: #89b4fa; font-weight: 600;}
+    button.gr-button-primary {background-color: #007bff !important; color: white !important;}
+    .gr-chatbot, .gr-markdown, .gr-file-upload {border-radius: 16px; background-color: #1b2533;}
+    .gr-chatbot .message {font-size: 16px; padding: 12px; border-radius: 18px;}
+    .gr-chatbot .message.user {background-color: #334155;}
+    .gr-chatbot .message.assistant {background-color: #1e293b;}
    """) as demo:
-        gr.Markdown("""## 🧠 CPS: Clinical Patient Support Assistant""")
-        chatbot = gr.Chatbot(label="CPS Assistant", height=700, type="messages")
-        upload = gr.File(label="Upload Excel File", file_types=[".xlsx"])
-        analyze_btn = gr.Button("🧠 Analyze File")
-        download = gr.File(label="Download Report", visible=False)
-
-        state = gr.State([])
-
-        def handle_analyze(file, chat_state):
-            messages, report_path = process_file(agent, file, chat_state)
+        gr.Markdown("""<h2>📄 CPS: Clinical Patient Support System</h2><p>Upload a file and analyze medical notes.</p>""")
+        with gr.Column():
+            chatbot = gr.Chatbot(label="CPS Assistant", height=700, type="messages")
+            upload = gr.File(label="Upload Medical File", file_types=[".xlsx"])
+            analyze = gr.Button("🧠 Analyze", variant="primary")
+            download = gr.File(label="Download Report", visible=False, interactive=False)
+            state = gr.State(value=[])
+
+        def handle_analysis(file, chat):
+            messages, report_path = process_report(agent, file, chat)
            return messages, gr.update(visible=bool(report_path), value=report_path), messages
 
-        analyze_btn.click(fn=handle_analyze, inputs=[upload, state], outputs=[chatbot, download, state])
+        analyze.click(fn=handle_analysis, inputs=[upload, state], outputs=[chatbot, download, state])
 
    return demo
 
 if __name__ == "__main__":
-    agent = init_agent()
-    ui = create_ui(agent)
-    ui.launch(server_name="0.0.0.0", server_port=7860, allowed_paths=["/data/hf_cache/reports"], share=False)
+    try:
+        agent = init_agent()
+        ui = create_ui(agent)
+        ui.launch(server_name="0.0.0.0", server_port=7860, allowed_paths=["/data/hf_cache/reports"], share=False)
+    except Exception as err:
+        print(f"Startup failed: {err}")
+        sys.exit(1)
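
Note: the core of this change is the chunk → batch → thread-pool pipeline (split_text → batch_chunks → analyze_batches_parallel), and as_completed returns batch summaries in completion order rather than input order, so they may be reordered before generate_final_summary runs. Below is a minimal standalone sketch (not part of the commit) that mirrors the diff's estimate_tokens / split_text / batch_chunks logic with a stubbed worker in place of the TxAgent call; the sample rows and the fake_analyze helper are made up for illustration, so it can sanity-check the chunking and fan-out without a model or GPU.

# Minimal standalone sketch of the commit's chunk -> batch -> thread-pool flow.
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List

MAX_CHUNK_TOKENS = 8192   # same constants as the diff
PROMPT_OVERHEAD = 300
BATCH_SIZE = 4
MAX_WORKERS = 6

def estimate_tokens(text: str) -> int:
    return len(text) // 4 + 1  # rough ~4-chars-per-token heuristic

def split_text(text: str, max_tokens: int = MAX_CHUNK_TOKENS) -> List[str]:
    # Greedy line-by-line packing: start a new chunk once the running
    # token estimate would exceed the limit minus the prompt overhead.
    effective_limit = max_tokens - PROMPT_OVERHEAD
    chunks, current, current_tokens = [], [], 0
    for line in text.split("\n"):
        tokens = estimate_tokens(line)
        if current_tokens + tokens > effective_limit:
            if current:
                chunks.append("\n".join(current))
            current, current_tokens = [line], tokens
        else:
            current.append(line)
            current_tokens += tokens
    if current:
        chunks.append("\n".join(current))
    return chunks

def batch_chunks(chunks: List[str], batch_size: int = BATCH_SIZE) -> List[List[str]]:
    return [chunks[i:i + batch_size] for i in range(0, len(chunks), batch_size)]

def fake_analyze(batch: List[str]) -> str:
    # Hypothetical stand-in for analyze_batch(agent, batch); no model involved.
    return f"summary of {len(batch)} chunk(s)"

if __name__ == "__main__":
    rows = "\n".join(f"[Sheet1] patient {i} | BP 120/80 | metformin 500 mg" for i in range(50000))
    batches = batch_chunks(split_text(rows))
    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as pool:
        futures = [pool.submit(fake_analyze, b) for b in batches]
        results = [f.result() for f in as_completed(futures)]  # completion order
    print(f"{len(batches)} batches -> {len(results)} summaries")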