Ali2206 committed on
Commit
1a611b9
·
verified ·
1 Parent(s): de75e20

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +104 -201
app.py CHANGED
@@ -3,37 +3,33 @@ import os
3
  import pandas as pd
4
  import json
5
  import gradio as gr
6
- from typing import List, Tuple, Union, Generator, BinaryIO, Dict, Any
7
  import re
8
  from datetime import datetime
9
  import atexit
10
  import torch.distributed as dist
11
  import logging
12
 
13
- # Setup logging
14
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
15
  logger = logging.getLogger(__name__)
16
 
17
- # Cleanup for PyTorch distributed
18
  def cleanup():
19
  if dist.is_initialized():
20
  logger.info("Cleaning up PyTorch distributed process group")
21
  dist.destroy_process_group()
22
-
23
  atexit.register(cleanup)
24
 
25
- # Setup directories
26
  persistent_dir = "/data/hf_cache"
27
  os.makedirs(persistent_dir, exist_ok=True)
28
-
29
  model_cache_dir = os.path.join(persistent_dir, "txagent_models")
30
  tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
31
  file_cache_dir = os.path.join(persistent_dir, "cache")
32
  report_dir = os.path.join(persistent_dir, "reports")
33
-
34
  for d in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir]:
35
  os.makedirs(d, exist_ok=True)
36
-
37
  os.environ["HF_HOME"] = model_cache_dir
38
  os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
39
 
@@ -55,50 +51,40 @@ def estimate_tokens(text: str) -> int:
55
  return len(text) // 3.5 + 1
56
 
57
  def extract_text_from_excel(file_obj: Union[str, Dict[str, Any]]) -> str:
58
- """Handle Gradio file upload object which is a dictionary with 'name' and other keys"""
 
 
 
 
 
 
 
 
59
  all_text = []
60
- try:
61
- if isinstance(file_obj, dict) and 'name' in file_obj:
62
- file_path = file_obj['name']
63
- elif isinstance(file_obj, str):
64
- file_path = file_obj
65
- else:
66
- raise ValueError("Unsupported file input type")
67
-
68
- if not os.path.exists(file_path):
69
- raise FileNotFoundError(f"Temporary upload file not found at: {file_path}")
70
-
71
- xls = pd.ExcelFile(file_path)
72
-
73
- for sheet_name in xls.sheet_names:
74
- try:
75
- df = xls.parse(sheet_name).astype(str).fillna("")
76
- rows = df.apply(lambda row: " | ".join([cell for cell in row if cell.strip()]), axis=1)
77
- sheet_text = [f"[{sheet_name}] {line}" for line in rows if line.strip()]
78
- all_text.extend(sheet_text)
79
- except Exception as e:
80
- logger.warning(f"Could not parse sheet {sheet_name}: {e}")
81
- continue
82
-
83
- return "\n".join(all_text)
84
-
85
- except Exception as e:
86
- raise ValueError(f"❌ Error processing Excel file: {str(e)}")
87
 
88
  def split_text_into_chunks(text: str) -> List[str]:
89
- effective_max = MAX_CHUNK_TOKENS - PROMPT_OVERHEAD
90
- lines, chunks, curr_chunk, curr_tokens = text.split("\n"), [], [], 0
 
91
  for line in lines:
92
  t = estimate_tokens(line)
93
- if curr_tokens + t > effective_max:
94
- if curr_chunk:
95
- chunks.append("\n".join(curr_chunk))
96
- curr_chunk, curr_tokens = [line], t
97
  else:
98
- curr_chunk.append(line)
99
- curr_tokens += t
100
- if curr_chunk:
101
- chunks.append("\n".join(curr_chunk))
102
  return chunks
103
 
104
  def build_prompt_from_text(chunk: str) -> str:
@@ -120,196 +106,113 @@ Provide a structured response with clear medical reasoning.
120
  """
121
 
122
  def validate_tool_file(tool_name: str, tool_path: str) -> bool:
123
- """Validate the structure of a tool JSON file. Return True if valid, False if invalid."""
124
  try:
125
  if not os.path.exists(tool_path):
126
- logger.error(f"Tool file not found: {tool_path}")
127
  return False
128
-
129
  with open(tool_path, 'r') as f:
130
  tool_data = json.load(f)
131
-
132
- logger.info(f"Contents of {tool_name} ({tool_path}): {tool_data}")
133
-
134
- if isinstance(tool_data, str):
135
- logger.error(f"Invalid tool file {tool_name}: JSON root is a string, expected list or dict")
136
- return False
137
- elif isinstance(tool_data, list):
138
- for item in tool_data:
139
- if not isinstance(item, dict):
140
- logger.error(f"Invalid tool format in {tool_name}: each item must be a dict, got {type(item)}: {item}")
141
- return False
142
- if 'name' not in item:
143
- logger.error(f"Invalid tool format in {tool_name}: each dict must have a 'name' key, got {item}")
144
- return False
145
  elif isinstance(tool_data, dict):
146
  if 'tools' in tool_data:
147
- if not isinstance(tool_data['tools'], list):
148
- logger.error(f"'tools' field in {tool_name} must be a list, got {type(tool_data['tools'])}")
149
- return False
150
- for item in tool_data['tools']:
151
- if not isinstance(item, dict):
152
- logger.error(f"Invalid tool format in {tool_name}: each tool must be a dict, got {type(item)}: {item}")
153
- return False
154
- if 'name' not in item:
155
- logger.error(f"Invalid tool format in {tool_name}: each tool dict must have a 'name' key, got {item}")
156
- return False
157
- else:
158
- if 'name' not in tool_data:
159
- logger.error(f"Invalid tool format in {tool_name}: dict must have a 'name' key or 'tools' field, got {tool_data}")
160
- return False
161
- else:
162
- logger.error(f"Invalid tool file {tool_name}: must be a list or dict, got {type(tool_data)}")
163
- return False
164
-
165
- return True
166
  except Exception as e:
167
- logger.error(f"Error validating tool file {tool_name} ({tool_path}): {str(e)}")
168
  return False
169
 
170
  def init_agent() -> TxAgent:
171
- tool_path = os.path.join(tool_cache_dir, "new_tool.json")
172
- logger.info(f"Checking for tool file at: {tool_path}")
173
-
174
- # Create default tool file if it doesn't exist
175
- if not os.path.exists(tool_path):
176
- default_tool = {
177
- "name": "new_tool",
178
- "description": "Default tool configuration",
179
- "version": "1.0",
180
- "tools": [
181
- {"name": "dummy_tool", "description": "Dummy tool for testing", "version": "1.0"}
182
- ]
183
- }
184
- logger.info(f"Creating default tool file at: {tool_path}")
185
- with open(tool_path, 'w') as f:
186
- json.dump(default_tool, f)
187
-
188
- # Define tool files
189
- tool_files_dict = {
190
  'opentarget': '/home/user/.pyenv/versions/3.10.17/lib/python3.10/site-packages/tooluniverse/data/opentarget_tools.json',
191
  'fda_drug_label': '/home/user/.pyenv/versions/3.10.17/lib/python3.10/site-packages/tooluniverse/data/fda_drug_labeling_tools.json',
192
  'special_tools': '/home/user/.pyenv/versions/3.10.17/lib/python3.10/site-packages/tooluniverse/data/special_tools.json',
193
  'monarch': '/home/user/.pyenv/versions/3.10.17/lib/python3.10/site-packages/tooluniverse/data/monarch_tools.json',
194
- 'new_tool': tool_path
195
  }
196
-
197
- # Validate all tool files and filter invalid ones
198
- valid_tool_files = {}
199
- for tool_name, tool_path in tool_files_dict.items():
200
- if validate_tool_file(tool_name, tool_path):
201
- valid_tool_files[tool_name] = tool_path
202
- else:
203
- logger.warning(f"Skipping invalid tool file: {tool_name} ({tool_path})")
204
-
205
- if not valid_tool_files:
206
- raise ValueError("No valid tool files found after validation")
207
-
208
- # For testing, you can use only new_tool.json to isolate the issue
209
- # valid_tool_files = {'new_tool': tool_path}
210
-
211
- # Initialize TxAgent
212
- try:
213
- logger.info(f"Initializing TxAgent with tool_files_dict: {valid_tool_files}")
214
- agent = TxAgent(
215
- model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
216
- rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
217
- tool_files_dict=valid_tool_files,
218
- force_finish=True,
219
- enable_checker=True,
220
- step_rag_num=4,
221
- seed=100
222
- )
223
- logger.info("TxAgent initialized, calling init_model")
224
- agent.init_model()
225
- logger.info("TxAgent model initialized successfully")
226
- return agent
227
- except Exception as e:
228
- logger.error(f"Error initializing TxAgent: {str(e)}", exc_info=True)
229
- raise
230
 
231
  def stream_report(agent: TxAgent, input_file: Union[str, Dict[str, Any]], full_output: str) -> Generator[Tuple[str, Union[str, None], str], None, None]:
232
- accumulated_text = ""
 
 
 
233
  try:
234
- if input_file is None:
235
- yield "❌ Please upload a valid Excel file.", None, ""
236
- return
237
-
238
- try:
239
- text = extract_text_from_excel(input_file)
240
- chunks = split_text_into_chunks(text)
241
- except Exception as e:
242
- yield f"❌ {str(e)}", None, ""
243
- return
244
-
245
- for i, chunk in enumerate(chunks):
246
- prompt = build_prompt_from_text(chunk)
247
- partial = ""
248
- for res in agent.run_gradio_chat(
249
- message=prompt, history=[], temperature=0.2,
250
- max_new_tokens=MAX_NEW_TOKENS, max_token=MAX_MODEL_TOKENS,
251
- call_agent=False, conversation=[]
252
- ):
253
- partial += res if isinstance(res, str) else res.content
254
-
255
- cleaned = clean_response(partial)
256
- accumulated_text += f"\n\n📄 Analysis Part {i+1}:\n{cleaned}"
257
- yield accumulated_text, None, ""
258
-
259
- summary_prompt = f"Please summarize this analysis:\n\n{accumulated_text}"
260
- final_report = ""
261
- for res in agent.run_gradio_chat(
262
- message=summary_prompt, history=[], temperature=0.2,
263
  max_new_tokens=MAX_NEW_TOKENS, max_token=MAX_MODEL_TOKENS,
264
  call_agent=False, conversation=[]
265
  ):
266
- final_report += res if isinstance(res, str) else res.content
267
-
268
- cleaned = clean_response(final_report)
269
- report_path = os.path.join(report_dir, f"report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md")
270
- with open(report_path, 'w') as f:
271
- f.write(f"# Clinical Analysis Report\n\n{cleaned}")
272
-
273
- yield f"{accumulated_text}\n\n📊 Final Summary:\n{cleaned}", report_path, cleaned
274
-
275
- except Exception as e:
276
- logger.error(f"Processing error in stream_report: {str(e)}", exc_info=True)
277
- yield f"❌ Processing error: {str(e)}", None, ""
 
 
 
 
 
278
 
279
  def create_ui(agent: TxAgent) -> gr.Blocks:
280
- with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 900px !important}") as demo:
281
- gr.Markdown("""# Clinical Records Analyzer""")
282
  with gr.Row():
283
  file_upload = gr.File(label="Upload Excel File", file_types=[".xlsx"])
284
  analyze_btn = gr.Button("Analyze", variant="primary")
285
-
286
  with gr.Row():
287
  with gr.Column(scale=2):
288
  report_output = gr.Markdown()
289
  with gr.Column(scale=1):
290
- report_file = gr.File(label="Download Report", visible=False)
291
-
292
  full_output = gr.State()
293
-
294
- analyze_btn.click(
295
- fn=stream_report,
296
- inputs=[file_upload, full_output],
297
- outputs=[report_output, report_file, full_output]
298
- )
299
-
300
  return demo
301
 
302
  if __name__ == "__main__":
303
  try:
304
  agent = init_agent()
305
  demo = create_ui(agent)
306
- logger.info("Launching Gradio UI")
307
- demo.launch(
308
- server_name="0.0.0.0",
309
- server_port=7860,
310
- share=False
311
- )
312
  except Exception as e:
313
- logger.error(f"Application error: {str(e)}", exc_info=True)
314
- print(f"Application error: {str(e)}", file=sys.stderr)
315
- sys.exit(1)
 
3
  import pandas as pd
4
  import json
5
  import gradio as gr
6
+ from typing import List, Tuple, Union, Generator, Dict, Any
7
  import re
8
  from datetime import datetime
9
  import atexit
10
  import torch.distributed as dist
11
  import logging
12
 
13
# Configure root logging once at import time; all app messages share one format.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)

# Module-level logger, per stdlib convention.
logger = logging.getLogger(__name__)
16
 
17
# PyTorch cleanup
def cleanup():
    """Destroy the torch.distributed process group at interpreter exit, if one exists."""
    if not dist.is_initialized():
        return
    logger.info("Cleaning up PyTorch distributed process group")
    dist.destroy_process_group()
 
22
atexit.register(cleanup)

# Directories: persistent cache layout for models, tools, files, and reports.
persistent_dir = "/data/hf_cache"
model_cache_dir = os.path.join(persistent_dir, "txagent_models")
tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
file_cache_dir = os.path.join(persistent_dir, "cache")
report_dir = os.path.join(persistent_dir, "reports")

for d in (persistent_dir, model_cache_dir, tool_cache_dir, file_cache_dir, report_dir):
    os.makedirs(d, exist_ok=True)

# Point HuggingFace caches at the persistent volume.
os.environ["HF_HOME"] = model_cache_dir
os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
35
 
 
51
  return len(text) // 3.5 + 1
52
 
53
def extract_text_from_excel(file_obj: Union[str, Dict[str, Any]]) -> str:
    """Flatten every sheet of an uploaded Excel workbook into plain text.

    Accepts either a filesystem path or a Gradio upload dict (temp path
    under the 'name' key). Each non-empty row becomes one line of the form
    "[sheet] cell | cell | ...".

    Raises:
        ValueError: if the input is neither a path string nor an upload dict.
        FileNotFoundError: if the resolved path does not exist.
    """
    if isinstance(file_obj, dict) and 'name' in file_obj:
        file_path = file_obj['name']
    elif isinstance(file_obj, str):
        file_path = file_obj
    else:
        raise ValueError("Unsupported file input type")
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")
    xls = pd.ExcelFile(file_path)
    all_text = []
    for sheet in xls.sheet_names:
        try:
            # fillna BEFORE astype(str): casting first turns NaN into the
            # literal string "nan", which fillna("") can no longer remove.
            df = xls.parse(sheet).fillna("").astype(str)
            rows = df.apply(lambda r: " | ".join([c for c in r if c.strip()]), axis=1)
            sheet_text = [f"[{sheet}] {line}" for line in rows if line.strip()]
            all_text.extend(sheet_text)
        except Exception as e:
            # Best-effort: one unreadable sheet should not abort the whole upload.
            logger.warning(f"Failed to parse {sheet}: {e}")
    return "\n".join(all_text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
def split_text_into_chunks(text: str) -> List[str]:
    """Split text on newlines into chunks whose estimated token count stays
    within MAX_CHUNK_TOKENS minus the fixed prompt overhead.

    Lines are never broken; a single line longer than the budget becomes
    its own (oversized) chunk.
    """
    budget = MAX_CHUNK_TOKENS - PROMPT_OVERHEAD
    chunks: List[str] = []
    current: List[str] = []
    current_tokens = 0
    for line in text.split("\n"):
        t = estimate_tokens(line)
        if current_tokens + t > budget:
            # Guard: if the very first line already exceeds the budget,
            # `current` is empty and flushing it would emit an empty chunk.
            if current:
                chunks.append("\n".join(current))
            current, current_tokens = [line], t
        else:
            current.append(line)
            current_tokens += t
    if current:
        chunks.append("\n".join(current))
    return chunks
89
 
90
  def build_prompt_from_text(chunk: str) -> str:
 
106
  """
107
 
108
def validate_tool_file(tool_name: str, tool_path: str) -> bool:
    """Check that the JSON at `tool_path` is a plausible tool definition.

    Accepted shapes: a list of dicts each carrying a 'name' key, or a dict
    that either holds a 'tools' collection of such dicts or a top-level
    'name'. Any I/O, parse, or type failure is logged and reported as
    invalid (False) rather than raised.
    """
    def _named_dicts(items) -> bool:
        # Every entry must be a dict exposing a 'name' key.
        return all(isinstance(entry, dict) and 'name' in entry for entry in items)

    try:
        if not os.path.exists(tool_path):
            logger.error(f"Missing tool file: {tool_path}")
            return False
        with open(tool_path, 'r') as fh:
            data = json.load(fh)
        if isinstance(data, list):
            return _named_dicts(data)
        if isinstance(data, dict):
            return _named_dicts(data['tools']) if 'tools' in data else ('name' in data)
        logger.error(f"Invalid format in tool: {tool_name}")
        return False
    except Exception as e:
        logger.error(f"Error in {tool_name}: {e}")
        return False
126
 
127
def init_agent() -> TxAgent:
    """Build and warm up the TxAgent.

    Seeds a minimal local tool file on first run, validates every candidate
    tool JSON, and initializes the model with whatever survives.

    Raises:
        ValueError: when no tool file passes validation.
    """
    local_tool = os.path.join(tool_cache_dir, "new_tool.json")
    if not os.path.exists(local_tool):
        # Seed a minimal spec so the local entry always validates.
        default_spec = {
            "name": "new_tool",
            "description": "Default tool",
            "tools": [{"name": "dummy_tool", "description": "test", "version": "1.0"}]
        }
        with open(local_tool, 'w') as fh:
            json.dump(default_spec, fh)

    candidates = {
        'opentarget': '/home/user/.pyenv/versions/3.10.17/lib/python3.10/site-packages/tooluniverse/data/opentarget_tools.json',
        'fda_drug_label': '/home/user/.pyenv/versions/3.10.17/lib/python3.10/site-packages/tooluniverse/data/fda_drug_labeling_tools.json',
        'special_tools': '/home/user/.pyenv/versions/3.10.17/lib/python3.10/site-packages/tooluniverse/data/special_tools.json',
        'monarch': '/home/user/.pyenv/versions/3.10.17/lib/python3.10/site-packages/tooluniverse/data/monarch_tools.json',
        'new_tool': local_tool,
    }

    usable = {}
    for label, path in candidates.items():
        if validate_tool_file(label, path):
            usable[label] = path
    if not usable:
        raise ValueError("No valid tool files")

    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        tool_files_dict=usable,
        force_finish=True,
        enable_checker=True,
        step_rag_num=4,
        seed=100
    )
    agent.init_model()
    return agent
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
 
158
def stream_report(agent: TxAgent, input_file: Union[str, Dict[str, Any]], full_output: str) -> Generator[Tuple[str, Union[str, None], str], None, None]:
    """Stream a chunk-by-chunk clinical analysis of an uploaded Excel file.

    Yields (markdown_so_far, report_file_path_or_None, final_summary) tuples
    matching the Gradio outputs [report_output, report_file, full_output].
    The report path and summary are only populated on the final yield.
    """
    accumulated = ""
    if input_file is None:
        yield "❌ Upload an Excel file.", None, ""
        return
    try:
        text = extract_text_from_excel(input_file)
        chunks = split_text_into_chunks(text)
    except Exception as e:
        yield f"❌ Error: {str(e)}", None, ""
        return
    try:
        for i, chunk in enumerate(chunks):
            prompt = build_prompt_from_text(chunk)
            result = ""
            for out in agent.run_gradio_chat(
                message=prompt, history=[], temperature=0.2,
                max_new_tokens=MAX_NEW_TOKENS, max_token=MAX_MODEL_TOKENS,
                call_agent=False, conversation=[]
            ):
                # The stream may emit plain strings or message objects with .content.
                result += out if isinstance(out, str) else out.content
            cleaned = clean_response(result)
            accumulated += f"\n\n📄 Part {i+1}:\n{cleaned}"
            yield accumulated, None, ""
        summary_prompt = f"Summarize this analysis:\n\n{accumulated}"
        summary = ""
        for out in agent.run_gradio_chat(
            message=summary_prompt, history=[], temperature=0.2,
            max_new_tokens=MAX_NEW_TOKENS, max_token=MAX_MODEL_TOKENS,
            call_agent=False, conversation=[]
        ):
            summary += out if isinstance(out, str) else out.content
        final = clean_response(summary)
        report_path = os.path.join(report_dir, f"report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md")
        with open(report_path, 'w') as f:
            f.write(f"# Clinical Report\n\n{final}")
        yield f"{accumulated}\n\n📊 Final Summary:\n{final}", report_path, final
    except Exception as e:
        # Surface model/IO failures to the UI instead of crashing the generator;
        # restores the error envelope this app's pre-commit version had.
        logger.error(f"Processing error in stream_report: {str(e)}", exc_info=True)
        yield f"❌ Processing error: {str(e)}", None, ""
194
 
195
def create_ui(agent: TxAgent) -> gr.Blocks:
    """Assemble the Gradio Blocks interface wired to stream_report.

    The upload + Analyze button feed a streaming markdown report and a
    downloadable report file.
    """
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🏥 Clinical Records Analyzer")
        with gr.Row():
            file_upload = gr.File(label="Upload Excel File", file_types=[".xlsx"])
            analyze_btn = gr.Button("Analyze", variant="primary")
        with gr.Row():
            with gr.Column(scale=2):
                report_output = gr.Markdown()
            with gr.Column(scale=1):
                report_file = gr.File(label="Download", visible=False)
        full_output = gr.State()

        # stream_report takes (agent, input_file, full_output) but Gradio only
        # supplies the component inputs — bind `agent` here. A generator
        # closure (not a lambda) keeps Gradio's streaming-output detection.
        def _analyze(file_obj, state):
            yield from stream_report(agent, file_obj, state)

        analyze_btn.click(
            fn=_analyze,
            inputs=[file_upload, full_output],
            outputs=[report_output, report_file, full_output]
        )
    return demo
209
 
210
if __name__ == "__main__":
    # NOTE(review): `sys` is used below but does not appear in the visible
    # import block — import locally so the failure path cannot NameError.
    # (Harmless even if the top of the file already imports it.)
    import sys

    try:
        agent = init_agent()
        demo = create_ui(agent)
        demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
    except Exception as e:
        logger.error(f"App error: {e}", exc_info=True)
        print(f"Application error: {e}", file=sys.stderr)
        sys.exit(1)