Ali2206 commited on
Commit
aa559b4
·
verified ·
1 Parent(s): a4b1ab0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -60
app.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import sys
2
  import os
3
  import pandas as pd
@@ -12,16 +14,18 @@ import logging
12
 
13
  # Logging
14
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
15
- logger = logging.getLogger(__name__)
16
 
17
  # Cleanup
 
18
  def cleanup():
19
  if dist.is_initialized():
20
  logger.info("Cleaning up PyTorch distributed process group")
21
  dist.destroy_process_group()
 
22
  atexit.register(cleanup)
23
 
24
- # Cache dirs
25
  persistent_dir = "/data/hf_cache"
26
  os.makedirs(persistent_dir, exist_ok=True)
27
  model_cache_dir = os.path.join(persistent_dir, "txagent_models")
@@ -33,6 +37,7 @@ for d in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir]:
33
  os.environ["HF_HOME"] = model_cache_dir
34
  os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
35
 
 
36
  sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "src")))
37
  from txagent.txagent import TxAgent
38
 
@@ -41,15 +46,18 @@ MAX_CHUNK_TOKENS = 8192
41
  MAX_NEW_TOKENS = 2048
42
  PROMPT_OVERHEAD = 500
43
 
 
44
  def clean_response(text: str) -> str:
45
  text = re.sub(r"\[.*?\]|\bNone\b", "", text, flags=re.DOTALL)
46
  text = re.sub(r"\n{3,}", "\n\n", text)
47
  text = re.sub(r"[^\n#\-\*\w\s\.,:\(\)]+", "", text)
48
  return text.strip()
49
 
 
50
  def estimate_tokens(text: str) -> int:
51
  return len(text) // 3.5 + 1
52
 
 
53
  def extract_text_from_excel(file_obj: Union[str, Dict[str, Any]]) -> str:
54
  if isinstance(file_obj, dict) and 'name' in file_obj:
55
  file_path = file_obj['name']
@@ -71,6 +79,7 @@ def extract_text_from_excel(file_obj: Union[str, Dict[str, Any]]) -> str:
71
  logger.warning(f"Failed to parse {sheet}: {e}")
72
  return "\n".join(all_text)
73
 
 
74
  def split_text_into_chunks(text: str) -> List[str]:
75
  lines = text.split("\n")
76
  chunks, current, current_tokens = [], [], 0
@@ -87,6 +96,7 @@ def split_text_into_chunks(text: str) -> List[str]:
87
  chunks.append("\n".join(current))
88
  return chunks
89
 
 
90
  def build_prompt_from_text(chunk: str) -> str:
91
  return f"""
92
  ### Clinical Records Analysis
@@ -105,17 +115,36 @@ Please analyze these clinical notes and provide:
105
  Provide a structured response with clear medical reasoning.
106
  """
107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  def init_agent() -> TxAgent:
109
  new_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
110
  if not os.path.exists(new_tool_path):
111
  with open(new_tool_path, 'w') as f:
112
- json.dump({
113
- "name": "new_tool",
114
- "description": "Default tool",
115
- "tools": [{"name": "dummy_tool", "description": "test", "version": "1.0"}]
116
- }, f)
117
 
118
- tool_files = {
119
  'opentarget': '/home/user/.pyenv/versions/3.10.17/lib/python3.10/site-packages/tooluniverse/data/opentarget_tools.json',
120
  'fda_drug_label': '/home/user/.pyenv/versions/3.10.17/lib/python3.10/site-packages/tooluniverse/data/fda_drug_labeling_tools.json',
121
  'special_tools': '/home/user/.pyenv/versions/3.10.17/lib/python3.10/site-packages/tooluniverse/data/special_tools.json',
@@ -123,34 +152,19 @@ def init_agent() -> TxAgent:
123
  'new_tool': new_tool_path
124
  }
125
 
126
- validated = {}
127
- for name, path in tool_files.items():
128
- try:
129
- with open(path, 'r') as f:
130
- data = json.load(f)
131
- if isinstance(data, dict) and 'tools' in data:
132
- tools = data['tools']
133
- elif isinstance(data, list):
134
- tools = data
135
- elif isinstance(data, dict) and 'name' in data:
136
- tools = [data]
137
- else:
138
- logger.warning(f"Skipping {name}: bad structure")
139
- continue
140
- if all(isinstance(t, dict) and 'name' in t for t in tools):
141
- validated[name] = path
142
- else:
143
- logger.warning(f"Skipping {name}: items malformed")
144
- except Exception as e:
145
- logger.error(f"Invalid tool {name}: {e}")
146
 
147
- if not validated:
148
- raise ValueError("No valid tools to load")
149
 
150
  agent = TxAgent(
151
  model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
152
  rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
153
- tool_files_dict=validated,
154
  force_finish=True,
155
  enable_checker=True,
156
  step_rag_num=4,
@@ -159,42 +173,42 @@ def init_agent() -> TxAgent:
159
  agent.init_model()
160
  return agent
161
 
 
162
  def stream_report(agent: TxAgent, input_file: Union[str, Dict[str, Any]], full_output: str) -> Generator[Tuple[str, Union[str, None], str], None, None]:
163
  accumulated = ""
164
- if input_file is None:
165
- yield "❌ Upload an Excel file.", None, ""
166
- return
167
  try:
 
 
 
168
  text = extract_text_from_excel(input_file)
169
  chunks = split_text_into_chunks(text)
170
- except Exception as e:
171
- yield f"❌ Error: {str(e)}", None, ""
172
- return
173
- for i, chunk in enumerate(chunks):
174
- prompt = build_prompt_from_text(chunk)
175
- result = ""
 
 
 
 
 
 
 
176
  for out in agent.run_gradio_chat(
177
- message=prompt, history=[], temperature=0.2,
178
  max_new_tokens=MAX_NEW_TOKENS, max_token=MAX_MODEL_TOKENS,
179
- call_agent=False, conversation=[]
180
- ):
181
- result += out if isinstance(out, str) else out.content
182
- cleaned = clean_response(result)
183
- accumulated += f"\n\n📄 Part {i+1}:\n{cleaned}"
184
- yield accumulated, None, ""
185
- summary_prompt = f"Summarize this analysis:\n\n{accumulated}"
186
- summary = ""
187
- for out in agent.run_gradio_chat(
188
- message=summary_prompt, history=[], temperature=0.2,
189
- max_new_tokens=MAX_NEW_TOKENS, max_token=MAX_MODEL_TOKENS,
190
- call_agent=False, conversation=[]
191
- ):
192
- summary += out if isinstance(out, str) else out.content
193
- final = clean_response(summary)
194
- report_path = os.path.join(report_dir, f"report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md")
195
- with open(report_path, 'w') as f:
196
- f.write(f"# Clinical Report\n\n{final}")
197
- yield f"{accumulated}\n\n📊 Final Summary:\n{final}", report_path, final
198
 
199
  def create_ui(agent: TxAgent) -> gr.Blocks:
200
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
@@ -211,6 +225,7 @@ def create_ui(agent: TxAgent) -> gr.Blocks:
211
  analyze_btn.click(fn=stream_report, inputs=[file_upload, full_output], outputs=[report_output, report_file, full_output])
212
  return demo
213
 
 
214
  if __name__ == "__main__":
215
  try:
216
  agent = init_agent()
 
1
+ # ✅ Fully updated app.py for TxAgent with strict tool validation to prevent runtime errors
2
+
3
  import sys
4
  import os
5
  import pandas as pd
 
14
 
15
  # Logging
16
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
17
+ logger = logging.getLogger("app")
18
 
19
  # Cleanup
20
+
21
  def cleanup():
22
  if dist.is_initialized():
23
  logger.info("Cleaning up PyTorch distributed process group")
24
  dist.destroy_process_group()
25
+
26
  atexit.register(cleanup)
27
 
28
+ # Directories
29
  persistent_dir = "/data/hf_cache"
30
  os.makedirs(persistent_dir, exist_ok=True)
31
  model_cache_dir = os.path.join(persistent_dir, "txagent_models")
 
37
  os.environ["HF_HOME"] = model_cache_dir
38
  os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
39
 
40
+ # Import TxAgent
41
  sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "src")))
42
  from txagent.txagent import TxAgent
43
 
 
46
  MAX_NEW_TOKENS = 2048
47
  PROMPT_OVERHEAD = 500
48
 
49
+
50
  def clean_response(text: str) -> str:
51
  text = re.sub(r"\[.*?\]|\bNone\b", "", text, flags=re.DOTALL)
52
  text = re.sub(r"\n{3,}", "\n\n", text)
53
  text = re.sub(r"[^\n#\-\*\w\s\.,:\(\)]+", "", text)
54
  return text.strip()
55
 
56
+
57
  def estimate_tokens(text: str) -> int:
58
  return len(text) // 3.5 + 1
59
 
60
+
61
  def extract_text_from_excel(file_obj: Union[str, Dict[str, Any]]) -> str:
62
  if isinstance(file_obj, dict) and 'name' in file_obj:
63
  file_path = file_obj['name']
 
79
  logger.warning(f"Failed to parse {sheet}: {e}")
80
  return "\n".join(all_text)
81
 
82
+
83
  def split_text_into_chunks(text: str) -> List[str]:
84
  lines = text.split("\n")
85
  chunks, current, current_tokens = [], [], 0
 
96
  chunks.append("\n".join(current))
97
  return chunks
98
 
99
+
100
  def build_prompt_from_text(chunk: str) -> str:
101
  return f"""
102
  ### Clinical Records Analysis
 
115
  Provide a structured response with clear medical reasoning.
116
  """
117
 
118
+
119
+ def clean_and_rewrite_tool_file(original_path: str, cleaned_path: str) -> bool:
120
+ try:
121
+ with open(original_path, "r") as f:
122
+ data = json.load(f)
123
+ if isinstance(data, dict) and "tools" in data:
124
+ tools = data["tools"]
125
+ elif isinstance(data, list):
126
+ tools = data
127
+ elif isinstance(data, dict) and "name" in data:
128
+ tools = [data]
129
+ else:
130
+ return False
131
+ if not all(isinstance(t, dict) and "name" in t for t in tools):
132
+ return False
133
+ with open(cleaned_path, "w") as out:
134
+ json.dump(tools, out)
135
+ return True
136
+ except Exception as e:
137
+ logger.error(f"Failed to clean tool {original_path}: {e}")
138
+ return False
139
+
140
+
141
  def init_agent() -> TxAgent:
142
  new_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
143
  if not os.path.exists(new_tool_path):
144
  with open(new_tool_path, 'w') as f:
145
+ json.dump([{"name": "dummy_tool", "description": "test", "version": "1.0"}], f)
 
 
 
 
146
 
147
+ raw_tool_files = {
148
  'opentarget': '/home/user/.pyenv/versions/3.10.17/lib/python3.10/site-packages/tooluniverse/data/opentarget_tools.json',
149
  'fda_drug_label': '/home/user/.pyenv/versions/3.10.17/lib/python3.10/site-packages/tooluniverse/data/fda_drug_labeling_tools.json',
150
  'special_tools': '/home/user/.pyenv/versions/3.10.17/lib/python3.10/site-packages/tooluniverse/data/special_tools.json',
 
152
  'new_tool': new_tool_path
153
  }
154
 
155
+ validated_paths = {}
156
+ for name, original_path in raw_tool_files.items():
157
+ cleaned_path = os.path.join(tool_cache_dir, f"{name}_cleaned.json")
158
+ if clean_and_rewrite_tool_file(original_path, cleaned_path):
159
+ validated_paths[name] = cleaned_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
 
161
+ if not validated_paths:
162
+ raise ValueError("No valid tools found after sanitizing.")
163
 
164
  agent = TxAgent(
165
  model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
166
  rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
167
+ tool_files_dict=validated_paths,
168
  force_finish=True,
169
  enable_checker=True,
170
  step_rag_num=4,
 
173
  agent.init_model()
174
  return agent
175
 
176
+
177
  def stream_report(agent: TxAgent, input_file: Union[str, Dict[str, Any]], full_output: str) -> Generator[Tuple[str, Union[str, None], str], None, None]:
178
  accumulated = ""
 
 
 
179
  try:
180
+ if input_file is None:
181
+ yield "❌ Upload a valid Excel file.", None, ""
182
+ return
183
  text = extract_text_from_excel(input_file)
184
  chunks = split_text_into_chunks(text)
185
+ for i, chunk in enumerate(chunks):
186
+ prompt = build_prompt_from_text(chunk)
187
+ result = ""
188
+ for out in agent.run_gradio_chat(
189
+ message=prompt, history=[], temperature=0.2,
190
+ max_new_tokens=MAX_NEW_TOKENS, max_token=MAX_MODEL_TOKENS,
191
+ call_agent=False, conversation=[]):
192
+ result += out if isinstance(out, str) else out.content
193
+ cleaned = clean_response(result)
194
+ accumulated += f"\n\n📄 Part {i+1}:\n{cleaned}"
195
+ yield accumulated, None, ""
196
+ summary_prompt = f"Summarize this analysis:\n\n{accumulated}"
197
+ summary = ""
198
  for out in agent.run_gradio_chat(
199
+ message=summary_prompt, history=[], temperature=0.2,
200
  max_new_tokens=MAX_NEW_TOKENS, max_token=MAX_MODEL_TOKENS,
201
+ call_agent=False, conversation=[]):
202
+ summary += out if isinstance(out, str) else out.content
203
+ final = clean_response(summary)
204
+ report_path = os.path.join(report_dir, f"report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md")
205
+ with open(report_path, 'w') as f:
206
+ f.write(f"# Clinical Report\n\n{final}")
207
+ yield f"{accumulated}\n\n📊 Final Summary:\n{final}", report_path, final
208
+ except Exception as e:
209
+ logger.error(f"Stream error: {e}", exc_info=True)
210
+ yield f"❌ Error: {str(e)}", None, ""
211
+
 
 
 
 
 
 
 
 
212
 
213
  def create_ui(agent: TxAgent) -> gr.Blocks:
214
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
 
225
  analyze_btn.click(fn=stream_report, inputs=[file_upload, full_output], outputs=[report_output, report_file, full_output])
226
  return demo
227
 
228
+
229
  if __name__ == "__main__":
230
  try:
231
  agent = init_agent()