|
import sys |
|
import os |
|
import shutil |
|
|
|
|
|
# Prepend <this file's directory>/../src to the module search path so the
# in-repo ``txagent`` package is found ahead of any installed copy.
sys.path[:0] = [os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src"))]
|
|
|
from txagent.txagent import TxAgent |
|
|
|
def init_agent(*, base_dir="/data", default_tool_path=None):
    """Build and initialize a TxAgent whose caches live under *base_dir*.

    Creates ``<base_dir>/hf_cache`` and ``<base_dir>/tool_cache``, points the
    Hugging Face cache environment variables at ``hf_cache``, seeds
    ``tool_cache/new_tool.json`` from *default_tool_path* if no cached copy
    exists yet, then constructs the agent and calls ``init_model()`` on it.

    Args:
        base_dir: Root directory for the model and tool caches. Defaults to
            ``"/data"``, the previously hard-coded location.
        default_tool_path: Tool-definition file copied into the tool cache on
            first run. Defaults to ``data/new_tool.json``. Relative paths are
            resolved against the current working directory, not this file.

    Returns:
        The ``TxAgent`` instance after ``init_model()`` has been called.

    Raises:
        FileNotFoundError: If no cached tool file exists and
            *default_tool_path* does not point to a file.
    """
    model_cache_dir = os.path.join(base_dir, "hf_cache")
    tool_cache_dir = os.path.join(base_dir, "tool_cache")
    os.makedirs(model_cache_dir, exist_ok=True)
    os.makedirs(tool_cache_dir, exist_ok=True)

    # NOTE(review): huggingface_hub / transformers generally read these
    # variables when they are first imported. If ``txagent`` imports them at
    # module load time, these assignments may come too late to take effect;
    # confirm. TRANSFORMERS_CACHE is deprecated in recent transformers releases
    # in favour of HF_HOME, and is kept here only for older versions.
    os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
    os.environ["HF_HOME"] = model_cache_dir

    if default_tool_path is None:
        default_tool_path = "data/new_tool.json"
    default_tool_path = os.path.abspath(default_tool_path)
    target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")

    # Seed the tool cache only once, so later edits to the cached copy persist.
    if not os.path.exists(target_tool_path):
        if not os.path.isfile(default_tool_path):
            raise FileNotFoundError(
                f"Default tool file not found: {default_tool_path}"
            )
        # Copy under a temporary name and atomically rename into place. An
        # interrupted copy then never leaves a truncated file that the
        # existence check above would accept on every later run.
        tmp_path = target_tool_path + ".tmp"
        shutil.copy(default_tool_path, tmp_path)
        os.replace(tmp_path, target_tool_path)

    model_name = "mims-harvard/TxAgent-T1-Llama-3.1-8B"
    rag_model_name = "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B"
    tool_files_dict = {
        "new_tool": target_tool_path,
    }

    agent = TxAgent(
        model_name=model_name,
        rag_model_name=rag_model_name,
        tool_files_dict=tool_files_dict,
        force_finish=True,
        enable_checker=True,
        step_rag_num=10,
        seed=100,
        additional_default_tools=[],
    )
    agent.init_model()
    return agent
|
|