Ali2206 commited on
Commit
2c757d1
·
verified ·
1 Parent(s): 69136c9

Update backend/agent_instance.py

Browse files
Files changed (1) hide show
  1. backend/agent_instance.py +7 -6
backend/agent_instance.py CHANGED
@@ -7,27 +7,28 @@ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..",
7
  from txagent.txagent import TxAgent
8
 
9
  def init_agent():
10
- # ✅ Use Hugging Face persistent storage
11
  base_dir = "/data"
12
  model_cache_dir = os.path.join(base_dir, "hf_cache")
13
  tool_cache_dir = os.path.join(base_dir, "tool_cache")
14
 
15
- # ✅ Ensure the folders exist
16
  os.makedirs(model_cache_dir, exist_ok=True)
17
  os.makedirs(tool_cache_dir, exist_ok=True)
18
 
19
- # ✅ Set environment variables so models stay cached after restart
20
  os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
21
  os.environ["HF_HOME"] = model_cache_dir
22
 
23
- # Paths to model + tool definitions
 
 
 
 
 
24
  model_name = "mims-harvard/TxAgent-T1-Llama-3.1-8B"
25
  rag_model_name = "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B"
26
  tool_files_dict = {
27
- "new_tool": os.path.join(tool_cache_dir, "new_tool.json")
28
  }
29
 
30
- # ✅ Init agent with config
31
  agent = TxAgent(
32
  model_name=model_name,
33
  rag_model_name=rag_model_name,
 
7
  from txagent.txagent import TxAgent
8
 
9
  def init_agent():
 
10
  base_dir = "/data"
11
  model_cache_dir = os.path.join(base_dir, "hf_cache")
12
  tool_cache_dir = os.path.join(base_dir, "tool_cache")
13
 
 
14
  os.makedirs(model_cache_dir, exist_ok=True)
15
  os.makedirs(tool_cache_dir, exist_ok=True)
16
 
 
17
  os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
18
  os.environ["HF_HOME"] = model_cache_dir
19
 
20
+ # 🧠 Copy default tool file into persistent storage if not already there
21
+ default_tool_path = os.path.abspath("data/new_tool.json")
22
+ target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
23
+ if not os.path.exists(target_tool_path):
24
+ shutil.copy(default_tool_path, target_tool_path)
25
+
26
  model_name = "mims-harvard/TxAgent-T1-Llama-3.1-8B"
27
  rag_model_name = "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B"
28
  tool_files_dict = {
29
+ "new_tool": target_tool_path
30
  }
31
 
 
32
  agent = TxAgent(
33
  model_name=model_name,
34
  rag_model_name=rag_model_name,