Ali2206 committed on
Commit
e41225f
·
verified ·
1 Parent(s): e594ff1

Update app.py

Files changed (1)
  1. app.py +117 -196
app.py CHANGED
@@ -2,12 +2,11 @@ import sys
2
  import os
3
  import pandas as pd
4
  import gradio as gr
5
- from typing import List, Tuple, Dict, Any, Union
6
  import shutil
7
  import re
8
  from datetime import datetime
9
  import time
10
- from transformers import AutoTokenizer
11
  import asyncio
12
  import logging
13
  from concurrent.futures import ThreadPoolExecutor, as_completed
@@ -33,22 +32,15 @@ sys.path.insert(0, src_path)
33
 
34
  from txagent.txagent import TxAgent
35
 
36
- # Constants
37
  MAX_MODEL_TOKENS = 131072 # TxAgent's max token limit
38
  MAX_CHUNK_TOKENS = 32768 # Larger chunks to reduce number of chunks
39
  MAX_NEW_TOKENS = 512 # Optimized for fast generation
40
  PROMPT_OVERHEAD = 500 # Estimated tokens for prompt template
41
  MAX_CONCURRENT = 8 # High concurrency for A100 80GB
42
 
43
- # Initialize tokenizer for precise token counting
44
- try:
45
- tokenizer = AutoTokenizer.from_pretrained("mims-harvard/TxAgent-T1-Llama-3.1-8B")
46
- except Exception as e:
47
- print(f"Warning: Could not load tokenizer, falling back to heuristic: {str(e)}")
48
- tokenizer = None
49
-
50
  # Setup logging
51
- logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
52
  logger = logging.getLogger(__name__)
53
 
54
  def clean_response(text: str) -> str:
@@ -62,13 +54,9 @@ def clean_response(text: str) -> str:
62
  return text.strip()
63
 
64
  def estimate_tokens(text: str) -> int:
65
- """Estimate tokens using tokenizer if available, else fall back to heuristic."""
66
- if tokenizer:
67
- return len(tokenizer.encode(text, add_special_tokens=False))
68
- return len(text) // 3.5 + 1
69
 
70
  def extract_text_from_excel(file_path: str) -> str:
71
- """Extract text from all sheets in an Excel file."""
72
  all_text = []
73
  try:
74
  xls = pd.ExcelFile(file_path)
@@ -79,15 +67,16 @@ def extract_text_from_excel(file_path: str) -> str:
79
  sheet_text = [f"[{sheet_name}] {line}" for line in rows]
80
  all_text.extend(sheet_text)
81
  except Exception as e:
82
- raise ValueError(f"Failed to extract text from Excel file: {str(e)}")
 
83
  return "\n".join(all_text)
84
 
85
- def split_text_into_chunks(text: str, max_tokens: int = MAX_CHUNK_TOKENS) -> List[str]:
86
- """Split text into chunks within token limits, accounting for prompt overhead."""
87
- effective_max_tokens = max_tokens - PROMPT_OVERHEAD
88
- if effective_max_tokens <= 0:
89
- raise ValueError(f"Effective max tokens ({effective_max_tokens}) must be positive.")
90
-
91
  lines = text.split("\n")
92
  chunks = []
93
  current_chunk = []
@@ -95,7 +84,7 @@ def split_text_into_chunks(text: str, max_tokens: int = MAX_CHUNK_TOKENS) -> Lis
95
 
96
  for line in lines:
97
  line_tokens = estimate_tokens(line)
98
- if current_tokens + line_tokens > effective_max_tokens:
99
  if current_chunk:
100
  chunks.append("\n".join(current_chunk))
101
  current_chunk = [line]
@@ -106,11 +95,11 @@ def split_text_into_chunks(text: str, max_tokens: int = MAX_CHUNK_TOKENS) -> Lis
106
 
107
  if current_chunk:
108
  chunks.append("\n".join(current_chunk))
109
-
 
110
  return chunks
111
 
112
  def build_prompt_from_text(chunk: str) -> str:
113
- """Build a prompt for analyzing a chunk of clinical data."""
114
  return f"""
115
  ### Unstructured Clinical Records
116
 
@@ -122,7 +111,7 @@ Here is the extracted content chunk:
122
 
123
  {chunk}
124
 
125
- Please analyze the above and provide:
126
  - Diagnostic Patterns
127
  - Medication Issues
128
  - Missed Opportunities
@@ -131,7 +120,6 @@ Please analyze the above and provide:
131
  """
132
 
133
  def init_agent():
134
- """Initialize the TxAgent with optimized vLLM settings for A100 80GB."""
135
  default_tool_path = os.path.abspath("data/new_tool.json")
136
  target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
137
 
@@ -146,24 +134,23 @@ def init_agent():
146
  enable_checker=True,
147
  step_rag_num=4,
148
  seed=100,
149
- additional_default_tools=[]
 
150
  )
151
  agent.init_model()
152
  return agent
153
 
154
- async def process_chunk(agent, chunk: str, chunk_index: int, total_chunks: int) -> Tuple[int, str, str]:
155
- """Process a single chunk and return index, response, and status message."""
156
- logger.info(f"Processing chunk {chunk_index+1}/{total_chunks}")
157
- prompt = build_prompt_from_text(chunk)
158
- prompt_tokens = estimate_tokens(prompt)
159
-
160
- if prompt_tokens > MAX_MODEL_TOKENS:
161
- error_msg = f"❌ Chunk {chunk_index+1} prompt too long ({prompt_tokens} tokens). Skipping..."
162
- logger.warning(error_msg)
163
- return chunk_index, "", error_msg
164
-
165
- response = ""
166
  try:
167
  for result in agent.run_gradio_chat(
168
  message=prompt,
169
  history=[],
@@ -181,205 +168,139 @@ async def process_chunk(agent, chunk: str, chunk_index: int, total_chunks: int)
181
  for r in result:
182
  if hasattr(r, "content"):
183
  response += r.content
184
- status = f"✅ Chunk {chunk_index+1} analysis complete"
185
- logger.info(status)
186
- except Exception as e:
187
- status = f"❌ Error analyzing chunk {chunk_index+1}: {str(e)}"
188
- logger.error(status)
189
- response = ""
190
 
191
- return chunk_index, clean_response(response), status
 
 
192
 
193
- async def process_final_report(agent, file, chatbot_state: List[Dict[str, str]]) -> Tuple[List[Dict[str, str]], Union[str, None]]:
194
- """Process the Excel file and generate a final report."""
195
- messages = chatbot_state if chatbot_state else []
196
  report_path = None
197
-
198
- if file is None or not hasattr(file, "name"):
199
- messages.append({"role": "assistant", "content": "❌ Please upload a valid Excel file before analyzing."})
200
- return messages, report_path
201
-
202
  try:
203
- messages.append({"role": "user", "content": f"Processing Excel file: {os.path.basename(file.name)}"})
204
- messages.append({"role": "assistant", "content": "⏳ Extracting and analyzing data..."})
205
-
206
- # Extract text and split into chunks
 
 
207
  start_time = time.time()
208
- extracted_text = extract_text_from_excel(file.name)
209
- chunks = split_text_into_chunks(extracted_text, max_tokens=MAX_CHUNK_TOKENS)
210
- logger.info(f"Extracted text and split into {len(chunks)} chunks in {time.time() - start_time:.2f} seconds")
211
-
 
 
212
  chunk_responses = [None] * len(chunks)
213
- batch_size = MAX_CONCURRENT
214
-
215
- # Process chunks in batches
216
- for batch_start in range(0, len(chunks), batch_size):
217
- batch_chunks = chunks[batch_start:batch_start + batch_size]
218
- batch_indices = list(range(batch_start, min(batch_start + batch_size, len(chunks))))
219
- logger.info(f"Processing batch {batch_start//batch_size + 1}/{(len(chunks) + batch_size - 1)//batch_size}")
220
-
221
- with ThreadPoolExecutor(max_workers=MAX_CONCURRENT) as executor:
222
- futures = [
223
- executor.submit(lambda c, i: asyncio.run(process_chunk(agent, c, i, len(chunks))), chunk, i)
224
- for i, chunk in zip(batch_indices, batch_chunks)
225
- ]
226
- for future in as_completed(futures):
227
- chunk_index, response, status = future.result()
228
- chunk_responses[chunk_index] = response
229
- messages.append({"role": "assistant", "content": status})
230
-
231
- # Filter out empty responses
232
- chunk_responses = [r for r in chunk_responses if r]
233
- if not chunk_responses:
234
- messages.append({"role": "assistant", "content": "❌ No valid chunk responses to summarize."})
235
- return messages, report_path
236
-
237
- # Summarize chunk responses incrementally
238
- summary = ""
239
- current_summary_tokens = 0
240
- for i, response in enumerate(chunk_responses):
241
- response_tokens = estimate_tokens(response)
242
- if current_summary_tokens + response_tokens > MAX_MODEL_TOKENS - PROMPT_OVERHEAD - MAX_NEW_TOKENS:
243
- summary_prompt = f"Summarize the following analysis:\n\n{summary}\n\nProvide a concise summary."
244
- summary_response = ""
245
- try:
246
- for result in agent.run_gradio_chat(
247
- message=summary_prompt,
248
- history=[],
249
- temperature=0.2,
250
- max_new_tokens=MAX_NEW_TOKENS,
251
- max_token=MAX_MODEL_TOKENS,
252
- call_agent=False,
253
- conversation=[],
254
- ):
255
- if isinstance(result, str):
256
- summary_response += result
257
- elif hasattr(result, "content"):
258
- summary_response += result.content
259
- elif isinstance(result, list):
260
- for r in result:
261
- if hasattr(r, "content"):
262
- summary_response += r.content
263
- summary = clean_response(summary_response)
264
- current_summary_tokens = estimate_tokens(summary)
265
- except Exception as e:
266
- messages.append({"role": "assistant", "content": f"❌ Error summarizing intermediate results: {str(e)}"})
267
- return messages, report_path
268
-
269
- summary += f"\n\n### Chunk {i+1} Analysis\n{response}"
270
- current_summary_tokens += response_tokens
271
-
272
- # Final summarization
273
- final_prompt = f"Summarize the key findings from the following analyses:\n\n{summary}"
274
  messages.append({"role": "assistant", "content": "📊 Generating final report..."})
275
-
276
- final_report_text = ""
277
- try:
278
- for result in agent.run_gradio_chat(
279
- message=final_prompt,
280
- history=[],
281
- temperature=0.2,
282
- max_new_tokens=MAX_NEW_TOKENS,
283
- max_token=MAX_MODEL_TOKENS,
284
- call_agent=False,
285
- conversation=[],
286
- ):
287
- if isinstance(result, str):
288
- final_report_text += result
289
- elif hasattr(result, "content"):
290
- final_report_text += result.content
291
- elif isinstance(result, list):
292
- for r in result:
293
- if hasattr(r, "content"):
294
- final_report_text += r.content
295
- except Exception as e:
296
- messages.append({"role": "assistant", "content": f"❌ Error generating final report: {str(e)}"})
297
- return messages, report_path
298
-
299
- final_report = f"# \U0001f9e0 Final Patient Report\n\n{clean_response(final_report_text)}"
300
- messages[-1]["content"] = f"📊 Final Report:\n\n{clean_response(final_report_text)}"
301
-
302
- # Save the report
303
  timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
304
  report_path = os.path.join(report_dir, f"report_{timestamp}.md")
305
 
306
  with open(report_path, 'w') as f:
307
  f.write(final_report)
308
-
309
- messages.append({"role": "assistant", "content": f"✅ Report generated and saved: report_{timestamp}.md"})
310
- logger.info(f"Total processing time: {time.time() - start_time:.2f} seconds")
311
-
312
- return messages, report_path
313
-
314
  except Exception as e:
315
- messages.append({"role": "assistant", "content": f"❌ Error processing file: {str(e)}"})
316
  logger.error(f"Processing failed: {str(e)}")
317
- return messages, report_path
 
318
 
319
- async def create_ui(agent):
320
- """Create the Gradio UI for the patient history analysis tool."""
321
- with gr.Blocks(title="Patient History Chat", css=".gradio-container {max-width: 900px !important}") as demo:
322
- gr.Markdown("## 🏥 Patient History Analysis Tool")
323
 
324
  with gr.Row():
325
  with gr.Column(scale=3):
326
  chatbot = gr.Chatbot(
327
- label="Clinical Assistant",
328
  show_copy_button=True,
329
  height=600,
330
- type="messages",
331
- avatar_images=(
332
- None,
333
- "https://i.imgur.com/6wX7Zb4.png"
334
- )
335
  )
336
  with gr.Column(scale=1):
337
- file_upload = gr.File(
338
  label="Upload Excel File",
339
  file_types=[".xlsx"],
340
  height=100
341
  )
342
  analyze_btn = gr.Button(
343
- "🧠 Analyze Patient History",
344
  variant="primary"
345
  )
346
  report_output = gr.File(
347
  label="Download Report",
348
- visible=False,
349
- interactive=False
350
  )
351
-
352
- # State to maintain chatbot messages
353
- chatbot_state = gr.State(value=[])
354
-
355
- async def update_ui(file, current_state):
356
- messages = current_state if current_state else []
357
- messages, report_path = await process_final_report(agent, file, messages)
358
- report_update = gr.update(visible=report_path is not None, value=report_path)
359
- return messages, report_update, messages
360
-
361
  analyze_btn.click(
362
- fn=update_ui,
363
- inputs=[file_upload, chatbot_state],
364
- outputs=[chatbot, report_output, chatbot_state],
365
- api_name="analyze"
366
  )
367
-
368
  return demo
369
 
370
  if __name__ == "__main__":
371
  try:
372
  agent = init_agent()
373
- demo = asyncio.run(create_ui(agent))
374
  demo.launch(
375
  server_name="0.0.0.0",
376
  server_port=7860,
377
  show_error=True,
378
- allowed_paths=["/data/hf_cache/reports"],
379
- share=False,
380
- inline=False,
381
- max_threads=40
382
  )
383
  except Exception as e:
384
- print(f"Error: {str(e)}")
385
  sys.exit(1)
 
2
  import os
3
  import pandas as pd
4
  import gradio as gr
5
+ from typing import List, Tuple, Dict, Any, Union, AsyncGenerator
6
  import shutil
7
  import re
8
  from datetime import datetime
9
  import time
 
10
  import asyncio
11
  import logging
12
  from concurrent.futures import ThreadPoolExecutor, as_completed
 
32
 
33
  from txagent.txagent import TxAgent
34
 
35
+ # Updated token limits as specified
36
  MAX_MODEL_TOKENS = 131072 # TxAgent's max token limit
37
  MAX_CHUNK_TOKENS = 32768 # Larger chunks to reduce number of chunks
38
  MAX_NEW_TOKENS = 512 # Optimized for fast generation
39
  PROMPT_OVERHEAD = 500 # Estimated tokens for prompt template
40
  MAX_CONCURRENT = 8 # High concurrency for A100 80GB
41
 
42
  # Setup logging
43
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
44
  logger = logging.getLogger(__name__)
45
 
46
  def clean_response(text: str) -> str:
 
54
  return text.strip()
55
 
56
  def estimate_tokens(text: str) -> int:
57
+ return int(len(text) / 3.5) + 1 # Conservative heuristic: roughly 3.5 characters per token
 
 
 
58
 
59
  def extract_text_from_excel(file_path: str) -> str:
 
60
  all_text = []
61
  try:
62
  xls = pd.ExcelFile(file_path)
 
67
  sheet_text = [f"[{sheet_name}] {line}" for line in rows]
68
  all_text.extend(sheet_text)
69
  except Exception as e:
70
+ logger.error(f"Error extracting Excel: {str(e)}")
71
+ raise ValueError(f"Failed to process Excel file: {str(e)}")
72
  return "\n".join(all_text)
73
 
74
+ def split_text_into_chunks(text: str) -> List[str]:
75
+ """Split text into chunks respecting MAX_CHUNK_TOKENS and PROMPT_OVERHEAD"""
76
+ effective_max = MAX_CHUNK_TOKENS - PROMPT_OVERHEAD
77
+ if effective_max <= 0:
78
+ raise ValueError("Effective max tokens must be positive")
79
+
80
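+ # Greedily pack whole lines into each chunk until the effective budget (MAX_CHUNK_TOKENS - PROMPT_OVERHEAD) is reached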
  lines = text.split("\n")
81
  chunks = []
82
  current_chunk = []
 
84
 
85
  for line in lines:
86
  line_tokens = estimate_tokens(line)
87
+ if current_tokens + line_tokens > effective_max:
88
  if current_chunk:
89
  chunks.append("\n".join(current_chunk))
90
  current_chunk = [line]
 
95
 
96
  if current_chunk:
97
  chunks.append("\n".join(current_chunk))
98
+
99
+ logger.info(f"Split text into {len(chunks)} chunks")
100
  return chunks
101
 
102
  def build_prompt_from_text(chunk: str) -> str:
 
103
  return f"""
104
  ### Unstructured Clinical Records
105
 
 
111
 
112
  {chunk}
113
 
114
+ Please analyze the above and provide concise responses (max {MAX_NEW_TOKENS} tokens):
115
  - Diagnostic Patterns
116
  - Medication Issues
117
  - Missed Opportunities
 
120
  """
121
 
122
  def init_agent():
 
123
  default_tool_path = os.path.abspath("data/new_tool.json")
124
  target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
125
 
 
134
  enable_checker=True,
135
  step_rag_num=4,
136
  seed=100,
137
+ additional_default_tools=[],
138
+ max_model_tokens=MAX_MODEL_TOKENS # Pass the updated token limit
139
  )
140
  agent.init_model()
141
  return agent
142
 
143
+ async def process_chunk(agent: TxAgent, chunk: str, chunk_idx: int) -> Tuple[int, str]:
144
145
  try:
146
+ prompt = build_prompt_from_text(chunk)
147
+ prompt_tokens = estimate_tokens(prompt)
148
+
149
+ if prompt_tokens > MAX_MODEL_TOKENS:
150
+ logger.warning(f"Chunk {chunk_idx} prompt too long ({prompt_tokens} tokens)")
151
+ return chunk_idx, ""
152
+
153
+ response = ""
154
  for result in agent.run_gradio_chat(
155
  message=prompt,
156
  history=[],
 
168
  for r in result:
169
  if hasattr(r, "content"):
170
  response += r.content
171
+
172
+ return chunk_idx, clean_response(response)
173
 
174
+ except Exception as e:
175
+ logger.error(f"Error processing chunk {chunk_idx}: {str(e)}")
176
+ return chunk_idx, ""
177
 
178
+ async def process_file(agent: TxAgent, file_path: str) -> AsyncGenerator[Tuple[List[Dict[str, str]], Union[str, None]], None]:
179
+ """Process the entire file and yield progress updates"""
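+ # Each yield hands the updated chat history (and the report path once written) back to Gradio for streaming display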
180
+ messages = []
181
  report_path = None
182
+
183
  try:
184
+ # Initial messages
185
+ messages.append({"role": "user", "content": f"Processing file: {os.path.basename(file_path)}"})
186
+ messages.append({"role": "assistant", "content": "⏳ Extracting data from Excel..."})
187
+ yield messages, None
188
+
189
+ # Extract and chunk text
190
  start_time = time.time()
191
+ text = extract_text_from_excel(file_path)
192
+ chunks = split_text_into_chunks(text)
193
+ messages.append({"role": "assistant", "content": f"✅ Extracted {len(chunks)} chunks in {time.time()-start_time:.1f}s"})
194
+ yield messages, None
195
+
196
+ # Process chunks in parallel
197
  chunk_responses = [None] * len(chunks)
198
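+ # Each worker thread drives the async process_chunk to completion with its own asyncio.run call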
+ with ThreadPoolExecutor(max_workers=MAX_CONCURRENT) as executor:
199
+ futures = []
200
+ for idx, chunk in enumerate(chunks):
201
+ futures.append(executor.submit(
202
+ lambda c, i: asyncio.run(process_chunk(agent, c, i)),
203
+ chunk, idx
204
+ ))
205
+ messages.append({"role": "assistant", "content": f"🔍 Processing chunk {idx+1}/{len(chunks)}..."})
206
+ yield messages, None
207
+
208
+ for future in as_completed(futures):
209
+ idx, response = future.result()
210
+ chunk_responses[idx] = response
211
+ messages.append({"role": "assistant", "content": f"✅ Chunk {idx+1} processed"})
212
+ yield messages, None
213
+
214
+ # Combine and summarize
215
+ combined = "\n\n".join([r for r in chunk_responses if r])
216
  messages.append({"role": "assistant", "content": "📊 Generating final report..."})
217
+ yield messages, None
218
+
219
+ final_response = ""
220
+ for result in agent.run_gradio_chat(
221
+ message=f"Summarize these clinical findings:\n\n{combined}",
222
+ history=[],
223
+ temperature=0.2,
224
+ max_new_tokens=MAX_NEW_TOKENS*2, # Allow more tokens for summary
225
+ max_token=MAX_MODEL_TOKENS,
226
+ call_agent=False,
227
+ conversation=[],
228
+ ):
229
+ if isinstance(result, str):
230
+ final_response += result
231
+ elif hasattr(result, "content"):
232
+ final_response += result.content
233
+ elif isinstance(result, list):
234
+ for r in result:
235
+ if hasattr(r, "content"):
236
+ final_response += r.content
237
+
238
+ messages[-1]["content"] = f"📊 Generating final report...\n\n{clean_response(final_response)}"
239
+ yield messages, None
240
+
241
+ # Save report
242
+ final_report = f"# Final Clinical Report\n\n{clean_response(final_response)}"
 
 
243
  timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
244
  report_path = os.path.join(report_dir, f"report_{timestamp}.md")
245
 
246
  with open(report_path, 'w') as f:
247
  f.write(final_report)
248
+
249
+ messages.append({"role": "assistant", "content": f"✅ Report saved: report_{timestamp}.md"})
250
+ yield messages, report_path
251
+
 
 
252
  except Exception as e:
 
253
  logger.error(f"Processing failed: {str(e)}")
254
+ messages.append({"role": "assistant", "content": f"❌ Error: {str(e)}"})
255
+ yield messages, None
256
 
257
+ def create_ui(agent: TxAgent):
258
+ """Create the Gradio interface"""
259
+ with gr.Blocks(title="Clinical Analysis", css=".gradio-container {max-width: 900px}") as demo:
260
+ gr.Markdown("## 🏥 Clinical Data Analysis (TxAgent)")
261
 
262
  with gr.Row():
263
  with gr.Column(scale=3):
264
  chatbot = gr.Chatbot(
265
+ label="Analysis Progress",
266
  show_copy_button=True,
267
  height=600,
268
+ type="messages"
269
  )
270
  with gr.Column(scale=1):
271
+ file_input = gr.File(
272
  label="Upload Excel File",
273
  file_types=[".xlsx"],
274
  height=100
275
  )
276
  analyze_btn = gr.Button(
277
+ "🧠 Analyze Data",
278
  variant="primary"
279
  )
280
  report_output = gr.File(
281
  label="Download Report",
282
+ visible=False
 
283
  )
284
+
285
+ async def analyze(file):
+ # Gradio streams only from generator event handlers, so wrap process_file instead of returning it from a lambda
+ if file is None:
+ yield [{"role": "assistant", "content": "❌ Please upload a file"}], gr.update(visible=False)
+ return
+ async for messages, report_path in process_file(agent, file.name):
+ yield messages, gr.update(visible=report_path is not None, value=report_path)
+
  analyze_btn.click(
+ fn=analyze,
+ inputs=[file_input],
+ outputs=[chatbot, report_output]
  )
290
+
291
  return demo
292
 
293
  if __name__ == "__main__":
294
  try:
295
  agent = init_agent()
296
+ demo = create_ui(agent)
297
  demo.launch(
298
  server_name="0.0.0.0",
299
  server_port=7860,
300
  show_error=True,
301
+ allowed_paths=[report_dir],
302
+ share=False
 
 
303
  )
304
  except Exception as e:
305
+ logger.error(f"Application failed: {str(e)}")
306
  sys.exit(1)