File size: 1,358 Bytes
fb8c74f bc81715 1f64aee fb8c74f 9b7fea0 bc81715 9b7fea0 2077bae 9b7fea0 bc81715 2c757d1 bc81715 2c757d1 bc81715 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
import sys
import os
import shutil
# ✅ Add src to Python path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src")))
from txagent.txagent import TxAgent
def init_agent(
    *,
    base_dir="/data",
    model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
    rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
    default_tool_path="data/new_tool.json",
):
    """Build a ``TxAgent`` whose caches live under persistent storage, then load it.

    Creates ``<base_dir>/hf_cache`` and ``<base_dir>/tool_cache``, points the
    ``TRANSFORMERS_CACHE`` and ``HF_HOME`` environment variables at the former,
    seeds the latter with a copy of the bundled tool file (only when no copy
    exists there yet), constructs the agent and calls ``init_model()`` on it.

    All parameters are keyword-only and default to the values this function
    previously hard-coded, so ``init_agent()`` behaves as before.

    Args:
        base_dir: Root directory of the persistent storage.
        model_name: Passed to ``TxAgent(model_name=...)``.
        rag_model_name: Passed to ``TxAgent(rag_model_name=...)``.
        default_tool_path: Bundled tool file copied into the tool cache on first
            run. A relative path is resolved against the current working
            directory (not against this file's location).

    Returns:
        The ``TxAgent`` instance after ``init_model()`` has been called.

    Raises:
        FileNotFoundError: If ``<base_dir>/tool_cache/new_tool.json`` does not
            exist yet and ``default_tool_path`` is not an existing file.
    """
    model_cache_dir = os.path.join(base_dir, "hf_cache")
    tool_cache_dir = os.path.join(base_dir, "tool_cache")
    os.makedirs(model_cache_dir, exist_ok=True)
    os.makedirs(tool_cache_dir, exist_ok=True)

    # NOTE(review): huggingface_hub / transformers usually read these variables
    # when they are first imported. If ``txagent.txagent`` (imported at module
    # level, before this function runs) already imports them, these assignments
    # may take effect too late. TODO: confirm, and consider setting them before
    # that import.
    os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
    os.environ["HF_HOME"] = model_cache_dir

    # Seed persistent storage with the bundled tool file only once; an existing
    # copy is never overwritten.
    target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
    if not os.path.exists(target_tool_path):
        source_tool_path = os.path.abspath(default_tool_path)
        if not os.path.isfile(source_tool_path):
            raise FileNotFoundError(
                f"Default tool file not found: {source_tool_path} "
                "(relative paths are resolved against the current working directory)"
            )
        # Copy under a temporary name, then rename into place. os.replace is
        # atomic on the same filesystem, so an interrupted copy cannot leave a
        # truncated new_tool.json that the exists() check would keep forever.
        partial_path = target_tool_path + ".partial"
        shutil.copy(source_tool_path, partial_path)
        os.replace(partial_path, target_tool_path)

    agent = TxAgent(
        model_name=model_name,
        rag_model_name=rag_model_name,
        tool_files_dict={"new_tool": target_tool_path},
        force_finish=True,
        enable_checker=True,
        step_rag_num=10,
        seed=100,
        additional_default_tools=[],
    )
    agent.init_model()
    return agent
|