Ali2206 committed · Commit a53de3c · verified · 1 Parent(s): 589b0c2

Update app.py

Files changed (1): app.py (+20 -19)
app.py CHANGED
@@ -1,14 +1,16 @@
-import sys, os, json, shutil, re, time, gc, hashlib
+import sys, os, json, shutil, re, time, gc
 import pandas as pd
 from datetime import datetime
 from typing import List, Tuple, Dict, Union
 import gradio as gr
+from concurrent.futures import ThreadPoolExecutor
 
 # Constants
 MAX_MODEL_TOKENS = 131072
 MAX_NEW_TOKENS = 4096
-MAX_CHUNK_TOKENS = 8192  # IMPORTANT: Split input into 8k tokens chunks
+MAX_CHUNK_TOKENS = 8192
 PROMPT_OVERHEAD = 300
+BATCH_SIZE = 2  # NEW: batch 2 prompts together for faster processing
 
 # Paths
 persistent_dir = "/data/hf_cache"
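A quick sanity check on the new constants (illustrative arithmetic only, not part of the commit): with BATCH_SIZE = 2, even a worst-case batched prompt of two full chunks plus per-prompt overhead stays far below the 131k model window after reserving room for generation.

```python
# Illustrative budget check for the new constants (not part of app.py).
MAX_MODEL_TOKENS = 131072
MAX_NEW_TOKENS = 4096
MAX_CHUNK_TOKENS = 8192
PROMPT_OVERHEAD = 300
BATCH_SIZE = 2

# Worst case: BATCH_SIZE full chunks, each wrapped in its own prompt template.
worst_case_prompt = BATCH_SIZE * (MAX_CHUNK_TOKENS + PROMPT_OVERHEAD)  # 16984
assert worst_case_prompt + MAX_NEW_TOKENS <= MAX_MODEL_TOKENS  # 21080 <= 131072
```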
@@ -41,20 +43,17 @@ def clean_response(text: str) -> str:
 def extract_text_from_excel(path: str) -> str:
     all_text = []
     xls = pd.ExcelFile(path)
-
     for sheet_name in xls.sheet_names:
         try:
             df = xls.parse(sheet_name).astype(str).fillna("")
         except Exception:
             continue
-
         for idx, row in df.iterrows():
-            non_empty = [cell.strip() for cell in row if cell.strip() != ""]
+            non_empty = [cell.strip() for cell in row if cell.strip()]
             if len(non_empty) >= 2:
                 text_line = " | ".join(non_empty)
                 if len(text_line) > 15:
                     all_text.append(f"[{sheet_name}] {text_line}")
-
     return "\n".join(all_text)
 
 def split_text(text: str, max_tokens=MAX_CHUNK_TOKENS) -> List[str]:
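For context on the touched row filter: a row survives only if it has at least two non-empty cells and the joined line is longer than 15 characters; the change from `cell.strip() != ""` to `cell.strip()` is purely stylistic. A minimal illustration with invented data (not from the repo):

```python
# Hypothetical data illustrating the row filter in extract_text_from_excel.
import pandas as pd

df = pd.DataFrame([
    ["2023-01-04", "Metformin 500mg", "BID"],  # kept: 3 non-empty cells, long enough
    ["", "note", ""],                          # dropped: only 1 non-empty cell
]).astype(str)

for _, row in df.iterrows():
    non_empty = [cell.strip() for cell in row if cell.strip()]
    if len(non_empty) >= 2 and len(" | ".join(non_empty)) > 15:
        print(f"[Sheet1] {' | '.join(non_empty)}")
# -> [Sheet1] 2023-01-04 | Metformin 500mg | BID
```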
@@ -73,6 +72,9 @@ def split_text(text: str, max_tokens=MAX_CHUNK_TOKENS) -> List[str]:
         chunks.append("\n".join(current))
     return chunks
 
+def batch_chunks(chunks: List[str], batch_size: int = 2) -> List[List[str]]:
+    return [chunks[i:i+batch_size] for i in range(0, len(chunks), batch_size)]
+
 def build_prompt(chunk: str) -> str:
     return f"""### Unstructured Clinical Records\n\nAnalyze the clinical notes below and summarize with:\n- Diagnostic Patterns\n- Medication Issues\n- Missed Opportunities\n- Inconsistencies\n- Follow-up Recommendations\n\n---\n\n{chunk}\n\n---\nRespond concisely in bullet points with clinical reasoning."""
 
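The new batch_chunks helper is a plain stride slice; when the chunk count is not a multiple of batch_size, the final batch is simply shorter:

```python
>>> batch_chunks(["a", "b", "c", "d", "e"], batch_size=2)
[['a', 'b'], ['c', 'd'], ['e']]
```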
@@ -92,19 +94,16 @@ def init_agent() -> TxAgent:
     agent.init_model()
     return agent
 
-def analyze_serial(agent, chunks: List[str]) -> List[str]:
+def analyze_batches(agent, batches: List[List[str]]) -> List[str]:
     results = []
-    for idx, chunk in enumerate(chunks):
-        prompt = build_prompt(chunk)
-        if estimate_tokens(prompt) > MAX_MODEL_TOKENS:
-            results.append(f"❌ Chunk {idx+1} too long. Skipped.")
-            continue
+    for batch in batches:
+        prompt = "\n\n".join(build_prompt(chunk) for chunk in batch)
         response = ""
         try:
             for r in agent.run_gradio_chat(
                 message=prompt,
                 history=[],
-                temperature=0.2,
+                temperature=0.0,
                 max_new_tokens=MAX_NEW_TOKENS,
                 max_token=MAX_MODEL_TOKENS,
                 call_agent=False,
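Note what the new prompt construction actually sends: BATCH_SIZE complete templates joined with blank lines, so the model sees the "### Unstructured Clinical Records" header and the closing instruction once per chunk in a single request, and will likely return one merged summary rather than per-chunk results. A quick illustration reusing build_prompt from this file:

```python
# Illustration: a 2-chunk batched prompt repeats the full template.
merged = "\n\n".join(build_prompt(chunk) for chunk in ["chunk A", "chunk B"])
print(merged.count("### Unstructured Clinical Records"))  # 2
```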
@@ -120,7 +119,8 @@ def analyze_serial(agent, chunks: List[str]) -> List[str]:
                 response += r.content
             results.append(clean_response(response))
         except Exception as e:
-            results.append(f"❌ Error in chunk {idx+1}: {str(e)}")
+            results.append(f"❌ Error in batch: {str(e)}")
+        torch.cuda.empty_cache()
         gc.collect()
     return results
 
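Two things to watch in this hunk: the old per-prompt estimate_tokens guard is dropped entirely, and torch.cuda.empty_cache() needs torch in scope, which the import line touched above does not provide (presumably it is imported in a part of app.py outside these hunks). A hedged sketch of the loop with the guard restored, reusing the file's own helpers (run_gradio_chat may take more arguments than these hunks show):

```python
# Hypothetical sketch, not the committed code: analyze_batches with the old
# length guard restored. Relies on build_prompt, estimate_tokens,
# clean_response and the MAX_* constants defined in app.py.
from typing import List

def analyze_batches_guarded(agent, batches: List[List[str]]) -> List[str]:
    results = []
    for idx, batch in enumerate(batches):
        prompt = "\n\n".join(build_prompt(chunk) for chunk in batch)
        # Restored guard: skip batches whose prompt cannot fit the window.
        if estimate_tokens(prompt) + MAX_NEW_TOKENS > MAX_MODEL_TOKENS:
            results.append(f"❌ Batch {idx+1} too long. Skipped.")
            continue
        response = ""
        try:
            for r in agent.run_gradio_chat(
                message=prompt, history=[], temperature=0.0,
                max_new_tokens=MAX_NEW_TOKENS, max_token=MAX_MODEL_TOKENS,
                call_agent=False,
            ):
                response += r.content
            results.append(clean_response(response))
        except Exception as e:
            results.append(f"❌ Error in batch {idx+1}: {str(e)}")
    return results
```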
@@ -130,7 +130,7 @@ def generate_final_summary(agent, combined: str) -> str:
     for r in agent.run_gradio_chat(
         message=final_prompt,
         history=[],
-        temperature=0.2,
+        temperature=0.0,
         max_new_tokens=MAX_NEW_TOKENS,
         max_token=MAX_MODEL_TOKENS,
         call_agent=False,
@@ -155,13 +155,14 @@ def process_report(agent, file, messages: List[Dict[str, str]]) -> Tuple[List[Di
     try:
         extracted = extract_text_from_excel(file.name)
         chunks = split_text(extracted)
-        messages.append({"role": "assistant", "content": f"🔍 Split into {len(chunks)} chunks. Analyzing..."})
+        batches = batch_chunks(chunks, batch_size=BATCH_SIZE)
+        messages.append({"role": "assistant", "content": f"🔍 Split into {len(batches)} batches. Analyzing..."})
 
-        chunk_results = analyze_serial(agent, chunks)
-        valid = [res for res in chunk_results if not res.startswith("❌")]
+        batch_results = analyze_batches(agent, batches)
+        valid = [res for res in batch_results if not res.startswith("❌")]
 
         if not valid:
-            messages.append({"role": "assistant", "content": "❌ No valid chunk outputs."})
+            messages.append({"role": "assistant", "content": "❌ No valid batch outputs."})
             return messages, None
 
         summary = generate_final_summary(agent, "\n\n".join(valid))
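One loose end: the newly imported ThreadPoolExecutor is never used in the hunks shown, so batches are still analyzed serially. If parallelism is the eventual goal, a sketch along these lines would use it; this is hypothetical, and it assumes agent.run_gradio_chat is thread-safe, which this commit does not establish:

```python
# Hypothetical sketch only: parallel batch analysis via the new import.
# Assumes build_prompt, clean_response and the MAX_* constants from app.py,
# plus a thread-safe agent.
from concurrent.futures import ThreadPoolExecutor
from typing import List

def analyze_batches_parallel(agent, batches: List[List[str]],
                             max_workers: int = 2) -> List[str]:
    def run_one(batch: List[str]) -> str:
        prompt = "\n\n".join(build_prompt(chunk) for chunk in batch)
        response = ""
        for r in agent.run_gradio_chat(
            message=prompt, history=[], temperature=0.0,
            max_new_tokens=MAX_NEW_TOKENS, max_token=MAX_MODEL_TOKENS,
            call_agent=False,
        ):
            response += r.content
        return clean_response(response)

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(run_one, batches))  # preserves batch order
```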
 