Ali2206 committed
Commit 04c881d (verified)
1 Parent(s): 4bfbcac

Update app.py

Files changed (1)
  1. app.py +98 -180
app.py CHANGED
@@ -7,7 +7,6 @@ import shutil
 import re
 from datetime import datetime
 import time
-from transformers import AutoTokenizer
 import asyncio
 import logging
 from concurrent.futures import ThreadPoolExecutor, as_completed
@@ -37,14 +36,7 @@ MAX_MODEL_TOKENS = 131072 # TxAgent's max token limit
 MAX_CHUNK_TOKENS = 32768 # Larger chunks to reduce number of chunks
 MAX_NEW_TOKENS = 512 # Optimized for fast generation
 PROMPT_OVERHEAD = 500 # Estimated tokens for prompt template
-MAX_CONCURRENT = 4 # Reduced concurrency to avoid vLLM socket issues
-
-# Initialize tokenizer for precise token counting
-try:
-    tokenizer = AutoTokenizer.from_pretrained("mims-harvard/TxAgent-T1-Llama-3.1-8B")
-except Exception as e:
-    print(f"Warning: Could not load tokenizer, falling back to heuristic: {str(e)}")
-    tokenizer = None
+MAX_CONCURRENT = 4 # Reduced concurrency to avoid vLLM issues

 # Setup logging
 logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
@@ -61,13 +53,9 @@ def clean_response(text: str) -> str:
     return text.strip()

 def estimate_tokens(text: str) -> int:
-    """Estimate tokens using tokenizer if available, else fall back to heuristic."""
-    if tokenizer:
-        return len(tokenizer.encode(text, add_special_tokens=False))
-    return len(text) // 3.5 + 1
+    return len(text) // 3.5 + 1 # Conservative estimate

 def extract_text_from_excel(file_path: str) -> str:
-    """Extract text from all sheets in an Excel file."""
     all_text = []
     try:
         xls = pd.ExcelFile(file_path)
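A note on the heuristic above: with the AutoTokenizer fallback removed, every token count in app.py now comes from the 3.5-characters-per-token estimate. One small wrinkle is that len(text) // 3.5 floor-divides by a float, so estimate_tokens returns a float even though it is annotated -> int. The comparisons against MAX_CHUNK_TOKENS and MAX_MODEL_TOKENS still behave correctly, but an int-returning variant would keep the annotation honest; a minimal editorial sketch, not part of this commit:

def estimate_tokens(text: str) -> int:
    # Heuristic only: assume roughly 3.5 characters per token and round down.
    return int(len(text) / 3.5) + 1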
@@ -82,12 +70,12 @@ def extract_text_from_excel(file_path: str) -> str:
         raise ValueError(f"Failed to process Excel file: {str(e)}")
     return "\n".join(all_text)

-def split_text_into_chunks(text: str, max_tokens: int = MAX_CHUNK_TOKENS) -> List[str]:
-    """Split text into chunks respecting MAX_CHUNK_TOKENS and PROMPT_OVERHEAD."""
-    effective_max_tokens = max_tokens - PROMPT_OVERHEAD
-    if effective_max_tokens <= 0:
-        raise ValueError(f"Effective max tokens ({effective_max_tokens}) must be positive.")
-
+def split_text_into_chunks(text: str) -> List[str]:
+    """Split text into chunks respecting MAX_CHUNK_TOKENS and PROMPT_OVERHEAD"""
+    effective_max = MAX_CHUNK_TOKENS - PROMPT_OVERHEAD
+    if effective_max <= 0:
+        raise ValueError("Effective max tokens must be positive")
+
     lines = text.split("\n")
     chunks = []
     current_chunk = []
@@ -95,7 +83,7 @@ def split_text_into_chunks(text: str, max_tokens: int = MAX_CHUNK_TOKENS) -> List[str]:

     for line in lines:
         line_tokens = estimate_tokens(line)
-        if current_tokens + line_tokens > effective_max_tokens:
+        if current_tokens + line_tokens > effective_max:
             if current_chunk:
                 chunks.append("\n".join(current_chunk))
             current_chunk = [line]
@@ -106,12 +94,11 @@ def split_text_into_chunks(text: str, max_tokens: int = MAX_CHUNK_TOKENS) -> List[str]:

     if current_chunk:
         chunks.append("\n".join(current_chunk))
-
+
     logger.info(f"Split text into {len(chunks)} chunks")
     return chunks

 def build_prompt_from_text(chunk: str) -> str:
-    """Build a prompt for analyzing a chunk of clinical data."""
     return f"""
 ### Unstructured Clinical Records
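One behavior of the chunker worth keeping in mind: lines are never split, so a single line whose estimated size exceeds the budget is still emitted as its own oversized chunk. A standalone sketch of that edge case, reimplementing the loop shown in the hunks above (function names and the sample input are illustrative, not code from the commit):

from typing import List

def estimate_tokens(text: str) -> int:
    return int(len(text) / 3.5) + 1  # same 3.5-chars-per-token heuristic

def split_lines_by_budget(text: str, budget: int) -> List[str]:
    # Greedily pack whole lines into chunks until the budget would be exceeded.
    chunks, current, used = [], [], 0
    for line in text.split("\n"):
        cost = estimate_tokens(line)
        if used + cost > budget:
            if current:
                chunks.append("\n".join(current))
            current, used = [line], cost
        else:
            current.append(line)
            used += cost
    if current:
        chunks.append("\n".join(current))
    return chunks

one_long_line = "x" * 10_000  # about 2858 estimated tokens
print([estimate_tokens(c) for c in split_lines_by_budget(one_long_line, budget=1000)])
# Prints a single chunk far above the 1000-token budget.

Very wide spreadsheet rows would need to be split within the line before chunking if that ever matters in practice.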
 
@@ -132,7 +119,7 @@ Please analyze the above and provide concise responses (max {MAX_NEW_TOKENS} tokens):
 """

 def init_agent():
-    """Initialize the TxAgent with optimized vLLM settings for A100 80GB."""
+    """Initialize TxAgent with conservative settings to avoid vLLM issues"""
     default_tool_path = os.path.abspath("data/new_tool.json")
     target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
@@ -152,19 +139,17 @@ def init_agent():
     agent.init_model()
     return agent

-async def process_chunk(agent, chunk: str, chunk_index: int, total_chunks: int) -> Tuple[int, str, str]:
-    """Process a single chunk with enhanced error handling."""
-    logger.info(f"Processing chunk {chunk_index+1}/{total_chunks}")
-    prompt = build_prompt_from_text(chunk)
-    prompt_tokens = estimate_tokens(prompt)
-
-    if prompt_tokens > MAX_MODEL_TOKENS:
-        error_msg = f"❌ Chunk {chunk_index+1} prompt too long ({prompt_tokens} tokens). Skipping..."
-        logger.warning(error_msg)
-        return chunk_index, "", error_msg
-
-    response = ""
+def process_chunk_sync(agent, chunk: str, chunk_idx: int) -> Tuple[int, str]:
+    """Synchronous wrapper for chunk processing"""
     try:
+        prompt = build_prompt_from_text(chunk)
+        prompt_tokens = estimate_tokens(prompt)
+
+        if prompt_tokens > MAX_MODEL_TOKENS:
+            logger.warning(f"Chunk {chunk_idx} prompt too long ({prompt_tokens} tokens)")
+            return chunk_idx, ""
+
+        response = ""
         for result in agent.run_gradio_chat(
             message=prompt,
             history=[],
@@ -182,143 +167,87 @@ async def process_chunk(agent, chunk: str, chunk_index: int, total_chunks: int) -> Tuple[int, str, str]:
                 for r in result:
                     if hasattr(r, "content"):
                         response += r.content
-        status = f"✅ Chunk {chunk_index+1} analysis complete"
-        logger.info(status)
+
+        return chunk_idx, clean_response(response)
     except Exception as e:
-        status = f"❌ Error analyzing chunk {chunk_index+1}: {str(e)}"
-        logger.error(status)
-        response = ""
-
-    return chunk_index, clean_response(response), status
+        logger.error(f"Error processing chunk {chunk_idx}: {str(e)}")
+        return chunk_idx, ""

-async def process_final_report(agent, file, chatbot_state: List[Dict[str, str]]) -> Tuple[List[Dict[str, str]], Union[str, None]]:
-    """Process the Excel file and generate a final report."""
-    messages = chatbot_state if chatbot_state else []
+async def process_file(agent: TxAgent, file_path: str) -> Generator[Tuple[List[Dict[str, str]], Union[str, None]], None, None]:
+    """Process the file with improved error handling and vLLM stability"""
+    messages = []
     report_path = None
-
-    if file is None or not hasattr(file, "name"):
-        messages.append({"role": "assistant", "content": "❌ Please upload a valid Excel file before analyzing."})
-        return messages, report_path
-
+
     try:
-        messages.append({"role": "user", "content": f"Processing Excel file: {os.path.basename(file.name)}"})
-        messages.append({"role": "assistant", "content": "⏳ Extracting and analyzing data..."})
-
-        # Extract text and split into chunks
+        # Initial messages
+        messages.append({"role": "user", "content": f"Processing file: {os.path.basename(file_path)}"})
+        messages.append({"role": "assistant", "content": "⏳ Extracting data from Excel..."})
+        yield messages, None
+
+        # Extract and chunk text
         start_time = time.time()
-        extracted_text = extract_text_from_excel(file.name)
-        chunks = split_text_into_chunks(extracted_text, max_tokens=MAX_CHUNK_TOKENS)
-        logger.info(f"Extracted text and split into {len(chunks)} chunks in {time.time() - start_time:.2f} seconds")
-
-        chunk_responses = [None] * len(chunks)
-        batch_size = MAX_CONCURRENT
-
-        # Process chunks in batches
-        for batch_start in range(0, len(chunks), batch_size):
-            batch_chunks = chunks[batch_start:batch_start + batch_size]
-            batch_indices = list(range(batch_start, min(batch_start + batch_size, len(chunks))))
-            logger.info(f"Processing batch {batch_start//batch_size + 1}/{(len(chunks) + batch_size - 1)//batch_size}")
-
-            with ThreadPoolExecutor(max_workers=MAX_CONCURRENT) as executor:
-                futures = [
-                    executor.submit(lambda c, i: asyncio.run(process_chunk(agent, c, i, len(chunks))), chunk, i)
-                    for i, chunk in zip(batch_indices, batch_chunks)
-                ]
-                for future in as_completed(futures):
-                    chunk_index, response, status = future.result()
-                    chunk_responses[chunk_index] = response
-                    messages.append({"role": "assistant", "content": status})
-
-        # Filter out empty responses
-        chunk_responses = [r for r in chunk_responses if r]
-        if not chunk_responses:
-            messages.append({"role": "assistant", "content": "❌ No valid chunk responses to summarize."})
-            return messages, report_path
-
-        # Summarize chunk responses incrementally
-        summary = ""
-        current_summary_tokens = 0
-        for i, response in enumerate(chunk_responses):
-            response_tokens = estimate_tokens(response)
-            if current_summary_tokens + response_tokens > MAX_MODEL_TOKENS - PROMPT_OVERHEAD - MAX_NEW_TOKENS:
-                summary_prompt = f"Summarize the following analysis:\n\n{summary}\n\nProvide a concise summary."
-                summary_response = ""
-                try:
-                    for result in agent.run_gradio_chat(
-                        message=summary_prompt,
-                        history=[],
-                        temperature=0.2,
-                        max_new_tokens=MAX_NEW_TOKENS,
-                        max_token=MAX_MODEL_TOKENS,
-                        call_agent=False,
-                        conversation=[],
-                    ):
-                        if isinstance(result, str):
-                            summary_response += result
-                        elif hasattr(result, "content"):
-                            summary_response += result.content
-                        elif isinstance(result, list):
-                            for r in result:
-                                if hasattr(r, "content"):
-                                    summary_response += r.content
-                    summary = clean_response(summary_response)
-                    current_summary_tokens = estimate_tokens(summary)
-                except Exception as e:
-                    messages.append({"role": "assistant", "content": f"❌ Error summarizing intermediate results: {str(e)}"})
-                    return messages, report_path
-
-            summary += f"\n\n### Chunk {i+1} Analysis\n{response}"
-            current_summary_tokens += response_tokens
-
-        # Final summarization
-        final_prompt = f"Summarize the key findings from the following analyses:\n\n{summary}"
+        text = extract_text_from_excel(file_path)
+        chunks = split_text_into_chunks(text)
+        messages.append({"role": "assistant", "content": f"✅ Extracted {len(chunks)} chunks in {time.time()-start_time:.1f}s"})
+        yield messages, None
+
+        # Process chunks sequentially to avoid vLLM socket issues
+        chunk_responses = []
+        for idx, chunk in enumerate(chunks):
+            messages.append({"role": "assistant", "content": f"🔍 Processing chunk {idx+1}/{len(chunks)}..."})
+            yield messages, None
+
+            _, response = process_chunk_sync(agent, chunk, idx)
+            chunk_responses.append(response)
+
+            messages.append({"role": "assistant", "content": f"✅ Chunk {idx+1} processed"})
+            yield messages, None
+
+        # Combine and summarize
+        combined = "\n\n".join([r for r in chunk_responses if r])
         messages.append({"role": "assistant", "content": "📊 Generating final report..."})
-
-        final_report_text = ""
-        try:
-            for result in agent.run_gradio_chat(
-                message=final_prompt,
-                history=[],
-                temperature=0.2,
-                max_new_tokens=MAX_NEW_TOKENS * 2,
-                max_token=MAX_MODEL_TOKENS,
-                call_agent=False,
-                conversation=[],
-            ):
-                if isinstance(result, str):
-                    final_report_text += result
-                elif hasattr(result, "content"):
-                    final_report_text += result.content
-                elif isinstance(result, list):
-                    for r in result:
-                        if hasattr(r, "content"):
-                            final_report_text += r.content
-        except Exception as e:
-            messages.append({"role": "assistant", "content": f"❌ Error generating final report: {str(e)}"})
-            return messages, report_path
-
-        final_report = f"# Final Clinical Report\n\n{clean_response(final_report_text)}"
-        messages[-1]["content"] = f"📊 Final Report:\n\n{clean_response(final_report_text)}"
-
-        # Save the report
+        yield messages, None
+
+        final_response = ""
+        for result in agent.run_gradio_chat(
+            message=f"Summarize these clinical findings:\n\n{combined}",
+            history=[],
+            temperature=0.2,
+            max_new_tokens=MAX_NEW_TOKENS*2,
+            max_token=MAX_MODEL_TOKENS,
+            call_agent=False,
+            conversation=[],
+        ):
+            if isinstance(result, str):
+                final_response += result
+            elif hasattr(result, "content"):
+                final_response += result.content
+            elif isinstance(result, list):
+                for r in result:
+                    if hasattr(r, "content"):
+                        final_response += r.content
+
+        messages[-1]["content"] = f"📊 Generating final report...\n\n{clean_response(final_response)}"
+        yield messages, None
+
+        # Save report
+        final_report = f"# Final Clinical Report\n\n{clean_response(final_response)}"
         timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
         report_path = os.path.join(report_dir, f"report_{timestamp}.md")

         with open(report_path, 'w') as f:
             f.write(final_report)
-
-        messages.append({"role": "assistant", "content": f"✅ Report generated and saved: report_{timestamp}.md"})
-        logger.info(f"Total processing time: {time.time() - start_time:.2f} seconds")
-
-        return messages, report_path
-
+
+        messages.append({"role": "assistant", "content": f"✅ Report saved: report_{timestamp}.md"})
+        yield messages, report_path
+
     except Exception as e:
-        messages.append({"role": "assistant", "content": f"❌ Error processing file: {str(e)}"})
         logger.error(f"Processing failed: {str(e)}")
-        return messages, report_path
+        messages.append({"role": "assistant", "content": f"❌ Error: {str(e)}"})
+        yield messages, None

-def create_ui(agent):
-    """Create the Gradio interface."""
+def create_ui(agent: TxAgent):
+    """Create the Gradio interface with simplified interaction"""
     with gr.Blocks(title="Clinical Analysis", css=".gradio-container {max-width: 900px}") as demo:
         gr.Markdown("## 🏥 Clinical Data Analysis (TxAgent)")
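A side note on the two streaming loops above: the same str / .content / list handling around agent.run_gradio_chat() now appears in both process_chunk_sync and the final-report pass of process_file. If that duplication ever becomes a maintenance issue, it could be collected in one small helper; a sketch (the helper name is invented here, and the duck-typed result shapes are simply the ones the diff already checks for):

def collect_stream(results) -> str:
    # Accumulate streamed items: plain strings, objects with a .content attribute,
    # or lists of such objects, exactly as the inline loops do.
    text = ""
    for result in results:
        if isinstance(result, str):
            text += result
        elif hasattr(result, "content"):
            text += result.content
        elif isinstance(result, list):
            text += "".join(r.content for r in result if hasattr(r, "content"))
    return text

Each call site would then reduce to response = collect_stream(agent.run_gradio_chat(...)) with the same keyword arguments as before.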
 
@@ -342,43 +271,32 @@ def create_ui(agent):
         )
         report_output = gr.File(
             label="Download Report",
-            visible=False,
-            interactive=False
+            visible=False
         )
-
-        # State to maintain chatbot messages
-        chatbot_state = gr.State(value=[])
-
-        async def update_ui(file, current_state):
-            if file is None or not hasattr(file, "name"):
-                messages = current_state if current_state else []
-                messages.append({"role": "assistant", "content": "❌ Please upload a valid Excel file before analyzing."})
-                return messages, None
-            messages, report_path = await process_final_report(agent, file, current_state)
-            report_update = gr.update(visible=report_path is not None, value=report_path)
-            return messages, report_update
-
+
         analyze_btn.click(
-            fn=update_ui,
-            inputs=[file_input, chatbot_state],
+            fn=lambda file: process_file(agent, file.name) if file else ([{"role": "assistant", "content": "❌ Please upload a file"}], None),
+            inputs=[file_input],
             outputs=[chatbot, report_output],
-            api_name="analyze"
+            concurrency_limit=1 # Ensure sequential processing
         )
-
+
     return demo

 if __name__ == "__main__":
     try:
+        # Initialize with conservative settings
         agent = init_agent()
         demo = create_ui(agent)
+
+        # Launch with stability optimizations
         demo.launch(
             server_name="0.0.0.0",
             server_port=7860,
             show_error=True,
             allowed_paths=[report_dir],
             share=False,
-            inline=False,
-            max_threads=40
+            max_threads=4 # Reduced thread count for stability
         )
     except Exception as e:
         logger.error(f"Application failed: {str(e)}")
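One caveat on the new click wiring: process_file is an async generator, but it is registered through a lambda, and Gradio streams intermediate yields only when the function passed to .click() is itself a sync or async generator function; a plain callable that merely returns a generator object may not be unrolled, depending on the Gradio version. If streaming progress messages is the goal, a conventional alternative is a thin named wrapper registered directly. The sketch below assumes it sits inside create_ui next to the existing components, and the name on_analyze is invented for illustration:

async def on_analyze(file):
    # Validate the upload, then re-yield every (messages, report_path) update from process_file.
    if file is None:
        yield [{"role": "assistant", "content": "❌ Please upload a file"}], None
        return
    async for messages, report_path in process_file(agent, file.name):
        yield messages, report_path

analyze_btn.click(
    fn=on_analyze,
    inputs=[file_input],
    outputs=[chatbot, report_output],
    concurrency_limit=1,
)

Note also that report_output is created with visible=False and nothing toggles it afterwards, so the saved report would stay hidden; yielding gr.update(value=report_path, visible=True) in the final step, as the removed update_ui helper did, would surface the download link.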
 