File size: 2,846 Bytes
3e3b258
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os
import json
import torch
from sentence_transformers import SentenceTransformer
from .utils import get_md5


class ToolRAGModel:
    """Embedding-based retriever that ranks toolbox tools against a text query.

    Tool descriptions are embedded with a SentenceTransformer model and cached
    on disk next to this module. The cache file name combines the model name
    with an MD5 of the serialized tool prompts. Queries are scored against the
    cached embedding matrix to pick the most similar tools.
    """

    def __init__(self, rag_model_name):
        """Create the retriever and immediately load the embedding model.

        Args:
            rag_model_name: Model identifier or local path passed to
                ``SentenceTransformer``.
        """
        self.rag_model_name = rag_model_name
        self.rag_model = None
        # Filled by load_tool_desc_embedding(); rag_infer maps embedding
        # row i to self.tool_name[i].
        self.tool_desc_embedding = None
        self.tool_name = None
        self.tool_embedding_path = None
        self.load_rag_model()

    def load_rag_model(self):
        """Instantiate the SentenceTransformer and configure its tokenizer."""
        self.rag_model = SentenceTransformer(self.rag_model_name)
        self.rag_model.max_seq_length = 4096
        self.rag_model.tokenizer.padding_side = "right"

    @staticmethod
    def _model_cache_stem(model_name):
        """Return the last path component of *model_name* for use in a file name.

        Both forward and back slashes are treated as separators, and trailing
        separators are ignored. Hub ids ("org/model"), directory paths with a
        trailing slash, and Windows-style paths all reduce to the final
        component instead of an empty string or a full path.
        """
        normalized = model_name.replace("\\", "/").rstrip("/")
        return normalized.split("/")[-1]

    def load_tool_desc_embedding(self, toolbox):
        """Load cached tool embeddings for *toolbox*, or compute and cache them.

        Also refreshes ``self.tool_name`` from the toolbox. The cache file is
        keyed by an MD5 of the serialized tool prompts, so any change to the
        tool set produces a different file.

        Args:
            toolbox: Object providing ``refresh_tool_name_desc``,
                ``prepare_tool_prompts`` and ``all_tools``.
        """
        self.tool_name, _ = toolbox.refresh_tool_name_desc(enable_full_desc=True)
        all_tools_str = [json.dumps(each) for each in toolbox.prepare_tool_prompts(toolbox.all_tools)]
        md5_value = get_md5(str(all_tools_str))
        print("Computed MD5 for tool embedding:", md5_value)

        self.tool_embedding_path = os.path.join(
            os.path.dirname(__file__),
            self._model_cache_stem(self.rag_model_name) + f"_tool_embedding_{md5_value}.pt"
        )

        if os.path.exists(self.tool_embedding_path):
            try:
                cached = torch.load(self.tool_embedding_path, map_location="cpu")
                # Explicit check instead of `assert`, which is stripped under -O.
                if len(cached) != len(toolbox.all_tools):
                    raise ValueError("Tool count mismatch with loaded embeddings.")
                # Older cache files may hold a numpy array; normalise to a tensor.
                self.tool_desc_embedding = torch.as_tensor(cached)
                print("\033[92mLoaded cached tool_desc_embedding.\033[0m")
                return
            except Exception as e:
                # Deliberately best-effort: any unreadable or stale cache
                # falls through to regeneration below.
                print(f"⚠️ Failed loading cached embeddings: {e}")
                self.tool_desc_embedding = None

        print("\033[93mGenerating new tool_desc_embedding...\033[0m")
        embeddings = self.rag_model.encode(
            all_tools_str, prompt="", normalize_embeddings=True
        )
        # Keep (and persist) a plain CPU tensor. torch.load defaults to
        # weights_only=True in recent releases, and that mode rejects pickled
        # numpy arrays, so a numpy cache would never be reused.
        self.tool_desc_embedding = torch.as_tensor(embeddings).cpu()

        try:
            torch.save(self.tool_desc_embedding, self.tool_embedding_path)
        except (OSError, RuntimeError) as e:
            # Caching is only an optimisation: a read-only install directory
            # must not discard the embeddings that were just computed.
            print(f"⚠️ Failed saving tool_desc_embedding to {self.tool_embedding_path}: {e}")
            return
        print(f"\033[92mSaved new tool_desc_embedding to {self.tool_embedding_path}\033[0m")

    def rag_infer(self, query, top_k=5):
        """Return the names of the tools most similar to *query*.

        Args:
            query: Free-text query to embed.
            top_k: Maximum number of tool names to return. It is clamped to
                the number of available tools and to at least 0.

        Returns:
            List of tool names, ordered from most to least similar.

        Raises:
            RuntimeError: If ``load_tool_desc_embedding`` has not been called.
        """
        # Check first so a missing setup step fails before any encoding work.
        if self.tool_desc_embedding is None:
            raise RuntimeError("❌ tool_desc_embedding is not initialized. Did you forget to call load_tool_desc_embedding()?")

        torch.cuda.empty_cache()
        query_embeddings = self.rag_model.encode(
            [query], prompt="", normalize_embeddings=True
        )
        scores = self.rag_model.similarity(
            query_embeddings, self.tool_desc_embedding
        )
        # Never ask torch.topk for more columns than were scored, or for a
        # negative count.
        top_k = max(0, min(top_k, scores.shape[-1], len(self.tool_name)))
        top_k_indices = torch.topk(scores, top_k).indices.tolist()[0]
        return [self.tool_name[i] for i in top_k_indices]