Update backend/agent_instance.py
Browse files
backend/agent_instance.py
CHANGED
@@ -7,27 +7,28 @@ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..",
|
|
7 |
from txagent.txagent import TxAgent
|
8 |
|
9 |
def init_agent():
|
10 |
-
# ✅ Use Hugging Face persistent storage
|
11 |
base_dir = "/data"
|
12 |
model_cache_dir = os.path.join(base_dir, "hf_cache")
|
13 |
tool_cache_dir = os.path.join(base_dir, "tool_cache")
|
14 |
|
15 |
-
# ✅ Ensure the folders exist
|
16 |
os.makedirs(model_cache_dir, exist_ok=True)
|
17 |
os.makedirs(tool_cache_dir, exist_ok=True)
|
18 |
|
19 |
-
# ✅ Set environment variables so models stay cached after restart
|
20 |
os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
|
21 |
os.environ["HF_HOME"] = model_cache_dir
|
22 |
|
23 |
-
#
|
|
|
|
|
|
|
|
|
|
|
24 |
model_name = "mims-harvard/TxAgent-T1-Llama-3.1-8B"
|
25 |
rag_model_name = "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B"
|
26 |
tool_files_dict = {
|
27 |
-
"new_tool":
|
28 |
}
|
29 |
|
30 |
-
# ✅ Init agent with config
|
31 |
agent = TxAgent(
|
32 |
model_name=model_name,
|
33 |
rag_model_name=rag_model_name,
|
|
|
7 |
from txagent.txagent import TxAgent
|
8 |
|
9 |
def init_agent():
|
|
|
10 |
base_dir = "/data"
|
11 |
model_cache_dir = os.path.join(base_dir, "hf_cache")
|
12 |
tool_cache_dir = os.path.join(base_dir, "tool_cache")
|
13 |
|
|
|
14 |
os.makedirs(model_cache_dir, exist_ok=True)
|
15 |
os.makedirs(tool_cache_dir, exist_ok=True)
|
16 |
|
|
|
17 |
os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
|
18 |
os.environ["HF_HOME"] = model_cache_dir
|
19 |
|
20 |
+
# 🧠 Copy default tool file into persistent storage if not already there
|
21 |
+
default_tool_path = os.path.abspath("data/new_tool.json")
|
22 |
+
target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
|
23 |
+
if not os.path.exists(target_tool_path):
|
24 |
+
shutil.copy(default_tool_path, target_tool_path)
|
25 |
+
|
26 |
model_name = "mims-harvard/TxAgent-T1-Llama-3.1-8B"
|
27 |
rag_model_name = "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B"
|
28 |
tool_files_dict = {
|
29 |
+
"new_tool": target_tool_path
|
30 |
}
|
31 |
|
|
|
32 |
agent = TxAgent(
|
33 |
model_name=model_name,
|
34 |
rag_model_name=rag_model_name,
|