File size: 1,358 Bytes
fb8c74f
bc81715
1f64aee
fb8c74f
 
 
 
9b7fea0
bc81715
 
9b7fea0
 
 
 
2077bae
9b7fea0
 
bc81715
 
 
2c757d1
 
 
 
 
 
bc81715
 
 
2c757d1
bc81715
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import sys
import os
import shutil

# ✅ Add src to Python path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src")))

from txagent.txagent import TxAgent

def _find_default_tool_file(filename="new_tool.json"):
    """Return the absolute path of the bundled default tool file *filename*.

    The location ``data/<filename>`` relative to the current working
    directory is tried first (this was the only location used previously,
    so existing launches behave the same). Because that lookup breaks when
    the process is started from another directory, ``data/<filename>`` is
    then looked up relative to this script's parent directory (the same
    ``..`` root that is put on ``sys.path`` for ``src``) and relative to the
    script's own directory.

    Raises:
        FileNotFoundError: if no candidate exists; the message lists every
            path that was checked.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    candidates = [
        os.path.abspath(os.path.join("data", filename)),
        os.path.abspath(os.path.join(script_dir, "..", "data", filename)),
        os.path.abspath(os.path.join(script_dir, "data", filename)),
    ]
    for candidate in candidates:
        if os.path.isfile(candidate):
            return candidate
    raise FileNotFoundError(
        f"Default tool file {filename!r} not found; looked in: "
        + ", ".join(candidates)
    )


def init_agent(base_dir="/data"):
    """Prepare persistent cache directories, seed the tool file, and load TxAgent.

    Creates ``<base_dir>/hf_cache`` and ``<base_dir>/tool_cache``, points the
    Hugging Face cache environment variables at ``hf_cache``, copies the
    bundled ``new_tool.json`` into ``tool_cache`` on first run, then builds a
    ``TxAgent`` and calls its ``init_model()``.

    Args:
        base_dir: Root directory for persistent storage. Defaults to
            ``"/data"``, the value previously hard-coded here.

    Returns:
        The ``TxAgent`` instance, after ``init_model()`` has been called.

    Raises:
        FileNotFoundError: if ``<base_dir>/tool_cache/new_tool.json`` does not
            exist yet and no bundled default tool file could be located.
    """
    model_cache_dir = os.path.join(base_dir, "hf_cache")
    tool_cache_dir = os.path.join(base_dir, "tool_cache")

    os.makedirs(model_cache_dir, exist_ok=True)
    os.makedirs(tool_cache_dir, exist_ok=True)

    # NOTE(review): Hugging Face libraries typically read these variables when
    # they are first imported. `txagent.txagent` is imported at module load,
    # before this function runs, so if it imports transformers /
    # huggingface_hub eagerly these assignments may not change the cache
    # location. Consider setting them before that import. TODO confirm.
    os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
    os.environ["HF_HOME"] = model_cache_dir

    # Seed persistent storage with the bundled tool file only on first run,
    # so an existing (possibly edited) copy in tool_cache is never overwritten.
    target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
    if not os.path.exists(target_tool_path):
        shutil.copy(_find_default_tool_file("new_tool.json"), target_tool_path)

    model_name = "mims-harvard/TxAgent-T1-Llama-3.1-8B"
    rag_model_name = "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B"
    tool_files_dict = {
        "new_tool": target_tool_path
    }

    # Remaining keyword arguments are passed through unchanged; their meaning
    # is defined by TxAgent.
    agent = TxAgent(
        model_name=model_name,
        rag_model_name=rag_model_name,
        tool_files_dict=tool_files_dict,
        force_finish=True,
        enable_checker=True,
        step_rag_num=10,
        seed=100,
        additional_default_tools=[]
    )
    agent.init_model()
    return agent