import os
import sys

# Prepend the repository's ``src`` directory (``../src`` relative to this
# file) to the module search path so the ``txagent`` package can be imported
# without being installed first. This must run before the import below.
_src_dir = os.path.join(os.path.dirname(__file__), "..", "src")
sys.path.insert(0, os.path.abspath(_src_dir))
del _src_dir  # keep the module namespace identical to a direct insert

from txagent.txagent import TxAgent  # noqa: E402 -- depends on the sys.path edit above
def init_agent(
    *,
    model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
    rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
    tool_files_dict=None,
    cache_dir="~/.cache/txagent_models",
    step_rag_num=10,
    seed=100,
):
    """Construct a ``TxAgent``, call its ``init_model()``, and return it.

    Calling with no arguments reproduces the original hard-coded setup; the
    keyword-only parameters allow overriding individual settings.

    Args:
        model_name: Model identifier passed to ``TxAgent(model_name=...)``.
        rag_model_name: Identifier passed to ``TxAgent(rag_model_name=...)``.
        tool_files_dict: Mapping passed to ``TxAgent(tool_files_dict=...)``.
            Defaults to ``{"new_tool": <abs path of data/new_tool.json>}``,
            where the relative path is resolved against the current working
            directory (not this file's location).
        cache_dir: Directory (``~`` is expanded) that is written into the
            ``TRANSFORMERS_CACHE`` and ``HF_HOME`` environment variables.
        step_rag_num: Passed to ``TxAgent(step_rag_num=...)``.
        seed: Passed to ``TxAgent(seed=...)``.

    Returns:
        The ``TxAgent`` instance after ``init_model()`` has been called on it.
    """
    model_cache_dir = os.path.expanduser(cache_dir)
    # NOTE(review): these variables overwrite any values already set in the
    # process environment. They only influence libraries that read them after
    # this point; if transformers / huggingface_hub were already imported
    # (e.g. transitively by the top-level txagent import), the cache location
    # may be unaffected -- confirm, and consider setting them before that import.
    os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
    os.environ["HF_HOME"] = model_cache_dir

    if tool_files_dict is None:
        # Relative to the current working directory, as in the original setup.
        tool_files_dict = {"new_tool": os.path.abspath("data/new_tool.json")}

    agent = TxAgent(
        model_name=model_name,
        rag_model_name=rag_model_name,
        tool_files_dict=tool_files_dict,
        force_finish=True,
        enable_checker=True,
        step_rag_num=step_rag_num,
        seed=seed,
        additional_default_tools=[],
    )
    agent.init_model()
    return agent