Update app.py
Browse files
app.py
CHANGED
@@ -119,36 +119,53 @@ Please analyze these clinical notes and provide:
|
|
119 |
Provide a structured response with clear medical reasoning.
|
120 |
"""
|
121 |
|
122 |
-
def validate_tool_file(tool_name: str, tool_path: str) ->
|
123 |
-
"""Validate the structure of a tool JSON file."""
|
124 |
try:
|
125 |
if not os.path.exists(tool_path):
|
126 |
-
|
|
|
127 |
|
128 |
with open(tool_path, 'r') as f:
|
129 |
tool_data = json.load(f)
|
130 |
|
131 |
logger.info(f"Contents of {tool_name} ({tool_path}): {tool_data}")
|
132 |
|
133 |
-
if isinstance(tool_data,
|
|
|
|
|
|
|
134 |
for item in tool_data:
|
135 |
-
if not isinstance(item, dict)
|
136 |
-
|
|
|
|
|
|
|
|
|
137 |
elif isinstance(tool_data, dict):
|
138 |
if 'tools' in tool_data:
|
139 |
if not isinstance(tool_data['tools'], list):
|
140 |
-
|
|
|
141 |
for item in tool_data['tools']:
|
142 |
-
if not isinstance(item, dict)
|
143 |
-
|
|
|
|
|
|
|
|
|
144 |
else:
|
145 |
if 'name' not in tool_data:
|
146 |
-
|
|
|
147 |
else:
|
148 |
-
|
|
|
|
|
|
|
149 |
except Exception as e:
|
150 |
logger.error(f"Error validating tool file {tool_name} ({tool_path}): {str(e)}")
|
151 |
-
|
152 |
|
153 |
def init_agent() -> TxAgent:
|
154 |
tool_path = os.path.join(tool_cache_dir, "new_tool.json")
|
@@ -177,17 +194,27 @@ def init_agent() -> TxAgent:
|
|
177 |
'new_tool': tool_path
|
178 |
}
|
179 |
|
180 |
-
# Validate all tool files
|
|
|
181 |
for tool_name, tool_path in tool_files_dict.items():
|
182 |
-
validate_tool_file(tool_name, tool_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
183 |
|
184 |
# Initialize TxAgent
|
185 |
try:
|
186 |
-
logger.info(f"Initializing TxAgent with tool_files_dict: {
|
187 |
agent = TxAgent(
|
188 |
model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
|
189 |
rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
|
190 |
-
tool_files_dict=
|
191 |
force_finish=True,
|
192 |
enable_checker=True,
|
193 |
step_rag_num=4,
|
|
|
119 |
Provide a structured response with clear medical reasoning.
|
120 |
"""
|
121 |
|
122 |
+
def validate_tool_file(tool_name: str, tool_path: str) -> bool:
|
123 |
+
"""Validate the structure of a tool JSON file. Return True if valid, False if invalid."""
|
124 |
try:
|
125 |
if not os.path.exists(tool_path):
|
126 |
+
logger.error(f"Tool file not found: {tool_path}")
|
127 |
+
return False
|
128 |
|
129 |
with open(tool_path, 'r') as f:
|
130 |
tool_data = json.load(f)
|
131 |
|
132 |
logger.info(f"Contents of {tool_name} ({tool_path}): {tool_data}")
|
133 |
|
134 |
+
if isinstance(tool_data, str):
|
135 |
+
logger.error(f"Invalid tool file {tool_name}: JSON root is a string, expected list or dict")
|
136 |
+
return False
|
137 |
+
elif isinstance(tool_data, list):
|
138 |
for item in tool_data:
|
139 |
+
if not isinstance(item, dict):
|
140 |
+
logger.error(f"Invalid tool format in {tool_name}: each item must be a dict, got {type(item)}: {item}")
|
141 |
+
return False
|
142 |
+
if 'name' not in item:
|
143 |
+
logger.error(f"Invalid tool format in {tool_name}: each dict must have a 'name' key, got {item}")
|
144 |
+
return False
|
145 |
elif isinstance(tool_data, dict):
|
146 |
if 'tools' in tool_data:
|
147 |
if not isinstance(tool_data['tools'], list):
|
148 |
+
logger.error(f"'tools' field in {tool_name} must be a list, got {type(tool_data['tools'])}")
|
149 |
+
return False
|
150 |
for item in tool_data['tools']:
|
151 |
+
if not isinstance(item, dict):
|
152 |
+
logger.error(f"Invalid tool format in {tool_name}: each tool must be a dict, got {type(item)}: {item}")
|
153 |
+
return False
|
154 |
+
if 'name' not in item:
|
155 |
+
logger.error(f"Invalid tool format in {tool_name}: each tool dict must have a 'name' key, got {item}")
|
156 |
+
return False
|
157 |
else:
|
158 |
if 'name' not in tool_data:
|
159 |
+
logger.error(f"Invalid tool format in {tool_name}: dict must have a 'name' key or 'tools' field, got {tool_data}")
|
160 |
+
return False
|
161 |
else:
|
162 |
+
logger.error(f"Invalid tool file {tool_name}: must be a list or dict, got {type(tool_data)}")
|
163 |
+
return False
|
164 |
+
|
165 |
+
return True
|
166 |
except Exception as e:
|
167 |
logger.error(f"Error validating tool file {tool_name} ({tool_path}): {str(e)}")
|
168 |
+
return False
|
169 |
|
170 |
def init_agent() -> TxAgent:
|
171 |
tool_path = os.path.join(tool_cache_dir, "new_tool.json")
|
|
|
194 |
'new_tool': tool_path
|
195 |
}
|
196 |
|
197 |
+
# Validate all tool files and filter invalid ones
|
198 |
+
valid_tool_files = {}
|
199 |
for tool_name, tool_path in tool_files_dict.items():
|
200 |
+
if validate_tool_file(tool_name, tool_path):
|
201 |
+
valid_tool_files[tool_name] = tool_path
|
202 |
+
else:
|
203 |
+
logger.warning(f"Skipping invalid tool file: {tool_name} ({tool_path})")
|
204 |
+
|
205 |
+
if not valid_tool_files:
|
206 |
+
raise ValueError("No valid tool files found after validation")
|
207 |
+
|
208 |
+
# For testing, you can use only new_tool.json to isolate the issue
|
209 |
+
# valid_tool_files = {'new_tool': tool_path}
|
210 |
|
211 |
# Initialize TxAgent
|
212 |
try:
|
213 |
+
logger.info(f"Initializing TxAgent with tool_files_dict: {valid_tool_files}")
|
214 |
agent = TxAgent(
|
215 |
model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
|
216 |
rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
|
217 |
+
tool_files_dict=valid_tool_files,
|
218 |
force_finish=True,
|
219 |
enable_checker=True,
|
220 |
step_rag_num=4,
|