Ali2206 committed
Commit a046927 · verified · 1 Parent(s): 1dd5b3f

Update app.py

Files changed (1)
  1. app.py +44 -32
app.py CHANGED
@@ -1,24 +1,27 @@
  import sys
  import os
- import gc
  import json
  import shutil
  import re
  import time
- import pandas as pd
- import gradio as gr
- import torch
- from typing import List, Tuple, Dict, Union
  from concurrent.futures import ThreadPoolExecutor, as_completed
  from datetime import datetime

  # Constants
  MAX_MODEL_TOKENS = 131072
  MAX_NEW_TOKENS = 4096
  MAX_CHUNK_TOKENS = 8192
  PROMPT_OVERHEAD = 300
- BATCH_SIZE = 4   # 4 chunks per batch
- MAX_WORKERS = 6  # 6 parallel batches

  # Paths
  persistent_dir = "/data/hf_cache"
@@ -39,6 +42,7 @@ sys.path.insert(0, src_path)

  from txagent.txagent import TxAgent

  def estimate_tokens(text: str) -> int:
      return len(text) // 4 + 1

@@ -56,12 +60,12 @@ def extract_text_from_excel(path: str) -> str:
              df = xls.parse(sheet_name).astype(str).fillna("")
          except Exception:
              continue
-         for _, row in df.iterrows():
              non_empty = [cell.strip() for cell in row if cell.strip()]
              if len(non_empty) >= 2:
-                 line = " | ".join(non_empty)
-                 if len(line) > 15:
-                     all_text.append(f"[{sheet_name}] {line}")
      return "\n".join(all_text)

  def split_text(text: str, max_tokens=MAX_CHUNK_TOKENS) -> List[str]:
@@ -80,7 +84,7 @@ def split_text(text: str, max_tokens=MAX_CHUNK_TOKENS) -> List[str]:
          chunks.append("\n".join(current))
      return chunks

- def batch_chunks(chunks: List[str], batch_size: int = 4) -> List[List[str]]:
      return [chunks[i:i+batch_size] for i in range(0, len(chunks), batch_size)]

  def build_prompt(chunk: str) -> str:
@@ -102,12 +106,13 @@ def init_agent() -> TxAgent:
      agent.init_model()
      return agent

- def analyze_batch(agent, batch: List[str]) -> str:
-     prompt = "\n\n".join(build_prompt(chunk) for chunk in batch)
      response = ""
      try:
          for r in agent.run_gradio_chat(
-             message=prompt,
              history=[],
              temperature=0.0,
              max_new_tokens=MAX_NEW_TOKENS,
@@ -123,19 +128,21 @@ def analyze_batch(agent, batch: List[str]) -> str:
                  response += m.content
              elif hasattr(r, "content"):
                  response += r.content
      except Exception as e:
-         return f"❌ Error in batch: {str(e)}"
-     finally:
-         torch.cuda.empty_cache()
-         gc.collect()
-     return clean_response(response)

  def analyze_batches_parallel(agent, batches: List[List[str]]) -> List[str]:
      results = []
-     with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
-         futures = [executor.submit(analyze_batch, agent, batch) for batch in batches]
          for future in as_completed(futures):
              results.append(future.result())
      return results

  def generate_final_summary(agent, combined: str) -> str:
@@ -170,7 +177,7 @@ def process_report(agent, file, messages: List[Dict[str, str]]) -> Tuple[List[Di
      extracted = extract_text_from_excel(file.name)
      chunks = split_text(extracted)
      batches = batch_chunks(chunks, batch_size=BATCH_SIZE)
-     messages.append({"role": "assistant", "content": f"🔍 Split into {len(batches)} batches. Analyzing in parallel..."})

      batch_results = analyze_batches_parallel(agent, batches)
      valid = [res for res in batch_results if not res.startswith("❌")]
@@ -194,20 +201,25 @@ def process_report(agent, file, messages: List[Dict[str, str]]) -> Tuple[List[Di

  def create_ui(agent):
      with gr.Blocks(css="""
-         html, body, .gradio-container {background-color: #0e1621; color: #e0e0e0; font-family: 'Inter', sans-serif;}
-         h2, h3, h4 {color: #89b4fa; font-weight: 600;}
-         button.gr-button-primary {background-color: #007bff !important; color: white !important;}
-         .gr-chatbot, .gr-markdown, .gr-file-upload {border-radius: 16px; background-color: #1b2533;}
-         .gr-chatbot .message {font-size: 16px; padding: 12px; border-radius: 18px;}
-         .gr-chatbot .message.user {background-color: #334155;}
-         .gr-chatbot .message.assistant {background-color: #1e293b;}
      """) as demo:
-         gr.Markdown("""<h2>📄 CPS: Clinical Patient Support System</h2><p>Upload a file and analyze medical notes.</p>""")
          with gr.Column():
              chatbot = gr.Chatbot(label="CPS Assistant", height=700, type="messages")
              upload = gr.File(label="Upload Medical File", file_types=[".xlsx"])
              analyze = gr.Button("🧠 Analyze", variant="primary")
              download = gr.File(label="Download Report", visible=False, interactive=False)
          state = gr.State(value=[])

          def handle_analysis(file, chat):
@@ -225,4 +237,4 @@ if __name__ == "__main__":
          ui.launch(server_name="0.0.0.0", server_port=7860, allowed_paths=["/data/hf_cache/reports"], share=False)
      except Exception as err:
          print(f"Startup failed: {err}")
-         sys.exit(1)
 
@@ -1,24 +1,27 @@
+ # Optimized app.py for A100 GPU (safe parallel batching, no hangs, maximum throughput)
+
  import sys
  import os
  import json
  import shutil
  import re
  import time
+ import gc
+ import threading
  from concurrent.futures import ThreadPoolExecutor, as_completed
+ from typing import List, Tuple, Dict, Union
  from datetime import datetime
+ import pandas as pd
+ import gradio as gr
+ import torch  # needed for torch.cuda.empty_cache() in analyze_batches_parallel below

  # Constants
  MAX_MODEL_TOKENS = 131072
  MAX_NEW_TOKENS = 4096
  MAX_CHUNK_TOKENS = 8192
  PROMPT_OVERHEAD = 300
+ BATCH_SIZE = 2            # Safer for vLLM
+ MAX_PARALLEL_JOBS = 2     # Max threads launched in parallel
+ SLEEP_BETWEEN_JOBS = 0.5  # Seconds between job submissions

  # Paths
  persistent_dir = "/data/hf_cache"
 
@@ -39,6 +42,7 @@ sys.path.insert(0, src_path)

  from txagent.txagent import TxAgent

+ # Utility functions
  def estimate_tokens(text: str) -> int:
      return len(text) // 4 + 1
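The token estimate is a simple characters-per-token heuristic, roughly four characters per token, for example:

```python
def estimate_tokens(text: str) -> int:
    return len(text) // 4 + 1

print(estimate_tokens("x" * 1000))  # 251 -- about 1000 characters / 4, plus 1
print(estimate_tokens(""))          # 1   -- never returns zero, even for empty text
```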
@@ -56,12 +60,12 @@ def extract_text_from_excel(path: str) -> str:
              df = xls.parse(sheet_name).astype(str).fillna("")
          except Exception:
              continue
+         for idx, row in df.iterrows():
              non_empty = [cell.strip() for cell in row if cell.strip()]
              if len(non_empty) >= 2:
+                 text_line = " | ".join(non_empty)
+                 if len(text_line) > 15:
+                     all_text.append(f"[{sheet_name}] {text_line}")
      return "\n".join(all_text)

  def split_text(text: str, max_tokens=MAX_CHUNK_TOKENS) -> List[str]:
@@ -80,7 +84,7 @@ def split_text(text: str, max_tokens=MAX_CHUNK_TOKENS) -> List[str]:
          chunks.append("\n".join(current))
      return chunks

+ def batch_chunks(chunks: List[str], batch_size: int = 2) -> List[List[str]]:
      return [chunks[i:i+batch_size] for i in range(0, len(chunks), batch_size)]

  def build_prompt(chunk: str) -> str:
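Quick usage example of the batching helper above with the new default of two chunks per batch (illustrative values only):

```python
from typing import List

def batch_chunks(chunks: List[str], batch_size: int = 2) -> List[List[str]]:
    return [chunks[i:i + batch_size] for i in range(0, len(chunks), batch_size)]

print(batch_chunks(["c1", "c2", "c3", "c4", "c5"]))
# [['c1', 'c2'], ['c3', 'c4'], ['c5']] -- the last batch may be smaller than batch_size
```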
 
@@ -102,12 +106,13 @@ def init_agent() -> TxAgent:
      agent.init_model()
      return agent

+ def process_single_batch(agent, batch: List[str]) -> str:
+     prompts = [build_prompt(chunk) for chunk in batch]
+     joined_prompt = "\n\n".join(prompts)
      response = ""
      try:
          for r in agent.run_gradio_chat(
+             message=joined_prompt,
              history=[],
              temperature=0.0,
              max_new_tokens=MAX_NEW_TOKENS,
 
@@ -123,19 +128,21 @@ def analyze_batch(agent, batch: List[str]) -> str:
                  response += m.content
              elif hasattr(r, "content"):
                  response += r.content
+         return clean_response(response)
      except Exception as e:
+         return f"❌ Error: {str(e)}"

  def analyze_batches_parallel(agent, batches: List[List[str]]) -> List[str]:
      results = []
+     with ThreadPoolExecutor(max_workers=MAX_PARALLEL_JOBS) as executor:
+         futures = []
+         for batch in batches:
+             futures.append(executor.submit(process_single_batch, agent, batch))
+             time.sleep(SLEEP_BETWEEN_JOBS)
          for future in as_completed(futures):
              results.append(future.result())
+     torch.cuda.empty_cache()
+     gc.collect()
      return results

  def generate_final_summary(agent, combined: str) -> str:
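One property of the parallel loop above worth noting: `as_completed` yields futures in completion order, so the results list may not line up with the order of the input batches. If ordered output matters, a small variant can map each future back to its batch index. This is a sketch assuming `process_single_batch`, `MAX_PARALLEL_JOBS`, and `SLEEP_BETWEEN_JOBS` as defined above; it is not part of the commit:

```python
from concurrent.futures import ThreadPoolExecutor, as_completed
import time

def analyze_batches_parallel_ordered(agent, batches):
    results = [None] * len(batches)
    with ThreadPoolExecutor(max_workers=MAX_PARALLEL_JOBS) as executor:
        index_of = {}
        for i, batch in enumerate(batches):
            # Stagger submissions, mirroring SLEEP_BETWEEN_JOBS in the committed version.
            index_of[executor.submit(process_single_batch, agent, batch)] = i
            time.sleep(SLEEP_BETWEEN_JOBS)
        # as_completed returns futures as they finish; the index map restores input order.
        for future in as_completed(index_of):
            results[index_of[future]] = future.result()
    return results
```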
 
@@ -170,7 +177,7 @@ def process_report(agent, file, messages: List[Dict[str, str]]) -> Tuple[List[Di
      extracted = extract_text_from_excel(file.name)
      chunks = split_text(extracted)
      batches = batch_chunks(chunks, batch_size=BATCH_SIZE)
+     messages.append({"role": "assistant", "content": f"🔍 Split into {len(batches)} batches. Parallel analyzing..."})

      batch_results = analyze_batches_parallel(agent, batches)
      valid = [res for res in batch_results if not res.startswith("❌")]
@@ -194,20 +201,25 @@ def process_report(agent, file, messages: List[Dict[str, str]]) -> Tuple[List[Di

  def create_ui(agent):
      with gr.Blocks(css="""
+         html, body, .gradio-container {
+             background-color: #0e1621;
+             color: #e0e0e0;
+             font-family: 'Inter', sans-serif;
+         }
+         h2, h3, h4 { color: #89b4fa; font-weight: 600; }
+         button.gr-button-primary {
+             background-color: #007bff !important;
+             color: white !important;
+             font-weight: bold;
+         }
      """) as demo:
+         gr.Markdown("""<h2>📄 CPS: Clinical Patient Support System</h2>""")
          with gr.Column():
              chatbot = gr.Chatbot(label="CPS Assistant", height=700, type="messages")
              upload = gr.File(label="Upload Medical File", file_types=[".xlsx"])
              analyze = gr.Button("🧠 Analyze", variant="primary")
              download = gr.File(label="Download Report", visible=False, interactive=False)
+
          state = gr.State(value=[])

          def handle_analysis(file, chat):
@@ -225,4 +237,4 @@ if __name__ == "__main__":
          ui.launch(server_name="0.0.0.0", server_port=7860, allowed_paths=["/data/hf_cache/reports"], share=False)
      except Exception as err:
          print(f"Startup failed: {err}")
+         sys.exit(1)