Ali2206 commited on
Commit
a4b1ab0
·
verified ·
1 Parent(s): 1a611b9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -25
app.py CHANGED
@@ -14,14 +14,14 @@ import logging
14
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
15
  logger = logging.getLogger(__name__)
16
 
17
- # PyTorch cleanup
18
  def cleanup():
19
  if dist.is_initialized():
20
  logger.info("Cleaning up PyTorch distributed process group")
21
  dist.destroy_process_group()
22
  atexit.register(cleanup)
23
 
24
- # Directories
25
  persistent_dir = "/data/hf_cache"
26
  os.makedirs(persistent_dir, exist_ok=True)
27
  model_cache_dir = os.path.join(persistent_dir, "txagent_models")
@@ -105,25 +105,6 @@ Please analyze these clinical notes and provide:
105
  Provide a structured response with clear medical reasoning.
106
  """
107
 
108
- def validate_tool_file(tool_name: str, tool_path: str) -> bool:
109
- try:
110
- if not os.path.exists(tool_path):
111
- logger.error(f"Missing tool file: {tool_path}")
112
- return False
113
- with open(tool_path, 'r') as f:
114
- tool_data = json.load(f)
115
- if isinstance(tool_data, list):
116
- return all(isinstance(item, dict) and 'name' in item for item in tool_data)
117
- elif isinstance(tool_data, dict):
118
- if 'tools' in tool_data:
119
- return all(isinstance(item, dict) and 'name' in item for item in tool_data['tools'])
120
- return 'name' in tool_data
121
- logger.error(f"Invalid format in tool: {tool_name}")
122
- return False
123
- except Exception as e:
124
- logger.error(f"Error in {tool_name}: {e}")
125
- return False
126
-
127
  def init_agent() -> TxAgent:
128
  new_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
129
  if not os.path.exists(new_tool_path):
@@ -133,6 +114,7 @@ def init_agent() -> TxAgent:
133
  "description": "Default tool",
134
  "tools": [{"name": "dummy_tool", "description": "test", "version": "1.0"}]
135
  }, f)
 
136
  tool_files = {
137
  'opentarget': '/home/user/.pyenv/versions/3.10.17/lib/python3.10/site-packages/tooluniverse/data/opentarget_tools.json',
138
  'fda_drug_label': '/home/user/.pyenv/versions/3.10.17/lib/python3.10/site-packages/tooluniverse/data/fda_drug_labeling_tools.json',
@@ -140,13 +122,35 @@ def init_agent() -> TxAgent:
140
  'monarch': '/home/user/.pyenv/versions/3.10.17/lib/python3.10/site-packages/tooluniverse/data/monarch_tools.json',
141
  'new_tool': new_tool_path
142
  }
143
- valid_tools = {k: v for k, v in tool_files.items() if validate_tool_file(k, v)}
144
- if not valid_tools:
145
- raise ValueError("No valid tool files")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  agent = TxAgent(
147
  model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
148
  rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
149
- tool_files_dict=valid_tools,
150
  force_finish=True,
151
  enable_checker=True,
152
  step_rag_num=4,
 
14
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
15
  logger = logging.getLogger(__name__)
16
 
17
+ # Cleanup
18
  def cleanup():
19
  if dist.is_initialized():
20
  logger.info("Cleaning up PyTorch distributed process group")
21
  dist.destroy_process_group()
22
  atexit.register(cleanup)
23
 
24
+ # Cache dirs
25
  persistent_dir = "/data/hf_cache"
26
  os.makedirs(persistent_dir, exist_ok=True)
27
  model_cache_dir = os.path.join(persistent_dir, "txagent_models")
 
105
  Provide a structured response with clear medical reasoning.
106
  """
107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  def init_agent() -> TxAgent:
109
  new_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
110
  if not os.path.exists(new_tool_path):
 
114
  "description": "Default tool",
115
  "tools": [{"name": "dummy_tool", "description": "test", "version": "1.0"}]
116
  }, f)
117
+
118
  tool_files = {
119
  'opentarget': '/home/user/.pyenv/versions/3.10.17/lib/python3.10/site-packages/tooluniverse/data/opentarget_tools.json',
120
  'fda_drug_label': '/home/user/.pyenv/versions/3.10.17/lib/python3.10/site-packages/tooluniverse/data/fda_drug_labeling_tools.json',
 
122
  'monarch': '/home/user/.pyenv/versions/3.10.17/lib/python3.10/site-packages/tooluniverse/data/monarch_tools.json',
123
  'new_tool': new_tool_path
124
  }
125
+
126
+ validated = {}
127
+ for name, path in tool_files.items():
128
+ try:
129
+ with open(path, 'r') as f:
130
+ data = json.load(f)
131
+ if isinstance(data, dict) and 'tools' in data:
132
+ tools = data['tools']
133
+ elif isinstance(data, list):
134
+ tools = data
135
+ elif isinstance(data, dict) and 'name' in data:
136
+ tools = [data]
137
+ else:
138
+ logger.warning(f"Skipping {name}: bad structure")
139
+ continue
140
+ if all(isinstance(t, dict) and 'name' in t for t in tools):
141
+ validated[name] = path
142
+ else:
143
+ logger.warning(f"Skipping {name}: items malformed")
144
+ except Exception as e:
145
+ logger.error(f"Invalid tool {name}: {e}")
146
+
147
+ if not validated:
148
+ raise ValueError("No valid tools to load")
149
+
150
  agent = TxAgent(
151
  model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
152
  rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
153
+ tool_files_dict=validated,
154
  force_finish=True,
155
  enable_checker=True,
156
  step_rag_num=4,