Ali2206 commited on
Commit
de75e20
·
verified ·
1 Parent(s): 73810ec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -16
app.py CHANGED
@@ -119,36 +119,53 @@ Please analyze these clinical notes and provide:
119
  Provide a structured response with clear medical reasoning.
120
  """
121
 
122
- def validate_tool_file(tool_name: str, tool_path: str) -> None:
123
- """Validate the structure of a tool JSON file."""
124
  try:
125
  if not os.path.exists(tool_path):
126
- raise FileNotFoundError(f"Tool file not found: {tool_path}")
 
127
 
128
  with open(tool_path, 'r') as f:
129
  tool_data = json.load(f)
130
 
131
  logger.info(f"Contents of {tool_name} ({tool_path}): {tool_data}")
132
 
133
- if isinstance(tool_data, list):
 
 
 
134
  for item in tool_data:
135
- if not isinstance(item, dict) or 'name' not in item:
136
- raise ValueError(f"Invalid tool format in {tool_name}: each item must be a dict with a 'name' key, got {item}")
 
 
 
 
137
  elif isinstance(tool_data, dict):
138
  if 'tools' in tool_data:
139
  if not isinstance(tool_data['tools'], list):
140
- raise ValueError(f"'tools' field in {tool_name} must be a list, got {type(tool_data['tools'])}")
 
141
  for item in tool_data['tools']:
142
- if not isinstance(item, dict) or 'name' not in item:
143
- raise ValueError(f"Invalid tool format in {tool_name}: each tool must be a dict with a 'name' key, got {item}")
 
 
 
 
144
  else:
145
  if 'name' not in tool_data:
146
- raise ValueError(f"Invalid tool format in {tool_name}: dict must have a 'name' key or 'tools' field, got {tool_data}")
 
147
  else:
148
- raise ValueError(f"Invalid tool file {tool_name}: must be a list or dict, got {type(tool_data)}")
 
 
 
149
  except Exception as e:
150
  logger.error(f"Error validating tool file {tool_name} ({tool_path}): {str(e)}")
151
- raise
152
 
153
  def init_agent() -> TxAgent:
154
  tool_path = os.path.join(tool_cache_dir, "new_tool.json")
@@ -177,17 +194,27 @@ def init_agent() -> TxAgent:
177
  'new_tool': tool_path
178
  }
179
 
180
- # Validate all tool files
 
181
  for tool_name, tool_path in tool_files_dict.items():
182
- validate_tool_file(tool_name, tool_path)
 
 
 
 
 
 
 
 
 
183
 
184
  # Initialize TxAgent
185
  try:
186
- logger.info(f"Initializing TxAgent with tool_files_dict: {tool_files_dict}")
187
  agent = TxAgent(
188
  model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
189
  rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
190
- tool_files_dict=tool_files_dict,
191
  force_finish=True,
192
  enable_checker=True,
193
  step_rag_num=4,
 
119
  Provide a structured response with clear medical reasoning.
120
  """
121
 
122
+ def validate_tool_file(tool_name: str, tool_path: str) -> bool:
123
+ """Validate the structure of a tool JSON file. Return True if valid, False if invalid."""
124
  try:
125
  if not os.path.exists(tool_path):
126
+ logger.error(f"Tool file not found: {tool_path}")
127
+ return False
128
 
129
  with open(tool_path, 'r') as f:
130
  tool_data = json.load(f)
131
 
132
  logger.info(f"Contents of {tool_name} ({tool_path}): {tool_data}")
133
 
134
+ if isinstance(tool_data, str):
135
+ logger.error(f"Invalid tool file {tool_name}: JSON root is a string, expected list or dict")
136
+ return False
137
+ elif isinstance(tool_data, list):
138
  for item in tool_data:
139
+ if not isinstance(item, dict):
140
+ logger.error(f"Invalid tool format in {tool_name}: each item must be a dict, got {type(item)}: {item}")
141
+ return False
142
+ if 'name' not in item:
143
+ logger.error(f"Invalid tool format in {tool_name}: each dict must have a 'name' key, got {item}")
144
+ return False
145
  elif isinstance(tool_data, dict):
146
  if 'tools' in tool_data:
147
  if not isinstance(tool_data['tools'], list):
148
+ logger.error(f"'tools' field in {tool_name} must be a list, got {type(tool_data['tools'])}")
149
+ return False
150
  for item in tool_data['tools']:
151
+ if not isinstance(item, dict):
152
+ logger.error(f"Invalid tool format in {tool_name}: each tool must be a dict, got {type(item)}: {item}")
153
+ return False
154
+ if 'name' not in item:
155
+ logger.error(f"Invalid tool format in {tool_name}: each tool dict must have a 'name' key, got {item}")
156
+ return False
157
  else:
158
  if 'name' not in tool_data:
159
+ logger.error(f"Invalid tool format in {tool_name}: dict must have a 'name' key or 'tools' field, got {tool_data}")
160
+ return False
161
  else:
162
+ logger.error(f"Invalid tool file {tool_name}: must be a list or dict, got {type(tool_data)}")
163
+ return False
164
+
165
+ return True
166
  except Exception as e:
167
  logger.error(f"Error validating tool file {tool_name} ({tool_path}): {str(e)}")
168
+ return False
169
 
170
  def init_agent() -> TxAgent:
171
  tool_path = os.path.join(tool_cache_dir, "new_tool.json")
 
194
  'new_tool': tool_path
195
  }
196
 
197
+ # Validate all tool files and filter invalid ones
198
+ valid_tool_files = {}
199
  for tool_name, tool_path in tool_files_dict.items():
200
+ if validate_tool_file(tool_name, tool_path):
201
+ valid_tool_files[tool_name] = tool_path
202
+ else:
203
+ logger.warning(f"Skipping invalid tool file: {tool_name} ({tool_path})")
204
+
205
+ if not valid_tool_files:
206
+ raise ValueError("No valid tool files found after validation")
207
+
208
+ # For testing, you can use only new_tool.json to isolate the issue
209
+ # valid_tool_files = {'new_tool': tool_path}
210
 
211
  # Initialize TxAgent
212
  try:
213
+ logger.info(f"Initializing TxAgent with tool_files_dict: {valid_tool_files}")
214
  agent = TxAgent(
215
  model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
216
  rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
217
+ tool_files_dict=valid_tool_files,
218
  force_finish=True,
219
  enable_checker=True,
220
  step_rag_num=4,