Ali2206 committed
Commit 6d47188 · verified · 1 Parent(s): 8c814cd

Update src/txagent/toolrag.py

Files changed (1)
  1. src/txagent/toolrag.py +67 -60
src/txagent/toolrag.py CHANGED
@@ -1,60 +1,67 @@
- from sentence_transformers import SentenceTransformer
- import torch
- import json
- from .utils import get_md5
-
-
- class ToolRAGModel:
-     def __init__(self, rag_model_name):
-         self.rag_model_name = rag_model_name
-         self.rag_model = None
-         self.tool_desc_embedding = None
-         self.tool_name = None
-         self.tool_embedding_path = None
-         self.load_rag_model()
-
-     def load_rag_model(self):
-         self.rag_model = SentenceTransformer(self.rag_model_name)
-         self.rag_model.max_seq_length = 4096
-         self.rag_model.tokenizer.padding_side = "right"
-
-     def load_tool_desc_embedding(self, toolbox):
-         self.tool_name, _ = toolbox.refresh_tool_name_desc(
-             enable_full_desc=True)
-         all_tools_str = [json.dumps(
-             each) for each in toolbox.prepare_tool_prompts(toolbox.all_tools)]
-         md5_value = get_md5(str(all_tools_str))
-         print("get the md value of tools:", md5_value)
-         self.tool_embedding_path = self.rag_model_name.split(
-             '/')[-1] + "tool_embedding_" + md5_value + ".pt"
-         try:
-             self.tool_desc_embedding = torch.load(
-                 self.tool_embedding_path, weights_only=False)
-             assert len(self.tool_desc_embedding) == len(
-                 toolbox.all_tools), "The number of tools in the toolbox is not equal to the number of tool_desc_embedding."
-         except:
-             self.tool_desc_embedding = None
-             print("\033[92mInferring the tool_desc_embedding.\033[0m")
-             self.tool_desc_embedding = self.rag_model.encode(
-                 all_tools_str, prompt="", normalize_embeddings=True
-             )
-             torch.save(self.tool_desc_embedding, self.tool_embedding_path)
-             print("\033[92mFinished inferring the tool_desc_embedding.\033[0m")
-             print("\033[91mExiting. Please rerun the code to avoid the OOM issue.\033[0m")
-             exit()
-
-     def rag_infer(self, query, top_k=5):
-         torch.cuda.empty_cache()
-         queries = [query]
-         query_embeddings = self.rag_model.encode(
-             queries, prompt="", normalize_embeddings=True
-         )
-         if self.tool_desc_embedding is None:
-             print("No tool_desc_embedding")
-             exit()
-         scores = self.rag_model.similarity(
-             query_embeddings, self.tool_desc_embedding)
-         top_k = min(top_k, len(self.tool_name))
-         top_k_indices = torch.topk(scores, top_k).indices.tolist()[0]
-         top_k_tool_names = [self.tool_name[i] for i in top_k_indices]
-         return top_k_tool_names
+ import os
+ import json
+ import torch
+ from sentence_transformers import SentenceTransformer
+ from .utils import get_md5
+
+
+ class ToolRAGModel:
+     def __init__(self, rag_model_name):
+         self.rag_model_name = rag_model_name
+         self.rag_model = None
+         self.tool_desc_embedding = None
+         self.tool_name = None
+         self.tool_embedding_path = None
+         self.load_rag_model()
+
+     def load_rag_model(self):
+         self.rag_model = SentenceTransformer(self.rag_model_name)
+         self.rag_model.max_seq_length = 4096
+         self.rag_model.tokenizer.padding_side = "right"
+
+     def load_tool_desc_embedding(self, toolbox):
+         self.tool_name, _ = toolbox.refresh_tool_name_desc(enable_full_desc=True)
+         all_tools_str = [json.dumps(each) for each in toolbox.prepare_tool_prompts(toolbox.all_tools)]
+         md5_value = get_md5(str(all_tools_str))
+         print("Computed MD5 for tool embedding:", md5_value)
+
+         self.tool_embedding_path = os.path.join(
+             os.path.dirname(__file__),
+             self.rag_model_name.split("/")[-1] + f"_tool_embedding_{md5_value}.pt"
+         )
+
+         if os.path.exists(self.tool_embedding_path):
+             try:
+                 self.tool_desc_embedding = torch.load(self.tool_embedding_path, map_location="cpu")
+                 assert len(self.tool_desc_embedding) == len(toolbox.all_tools), \
+                     "Tool count mismatch with loaded embeddings."
+                 print("\033[92mLoaded cached tool_desc_embedding.\033[0m")
+                 return
+             except Exception as e:
+                 print(f"⚠️ Failed loading cached embeddings: {e}")
+                 self.tool_desc_embedding = None
+
+         print("\033[93mGenerating new tool_desc_embedding...\033[0m")
+         self.tool_desc_embedding = self.rag_model.encode(
+             all_tools_str, prompt="", normalize_embeddings=True
+         )
+
+         torch.save(self.tool_desc_embedding, self.tool_embedding_path)
+         print(f"\033[92mSaved new tool_desc_embedding to {self.tool_embedding_path}\033[0m")
+
+     def rag_infer(self, query, top_k=5):
+         torch.cuda.empty_cache()
+         queries = [query]
+         query_embeddings = self.rag_model.encode(
+             queries, prompt="", normalize_embeddings=True
+         )
+         if self.tool_desc_embedding is None:
+             raise RuntimeError("❌ tool_desc_embedding is not initialized. Did you forget to call load_tool_desc_embedding()?")
+
+         scores = self.rag_model.similarity(
+             query_embeddings, self.tool_desc_embedding
+         )
+         top_k = min(top_k, len(self.tool_name))
+         top_k_indices = torch.topk(scores, top_k).indices.tolist()[0]
+         top_k_tool_names = [self.tool_name[i] for i in top_k_indices]
+         return top_k_tool_names
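
The cache filename embeds an MD5 of the serialized tool list, so any edit to a tool description changes the path and forces re-encoding instead of serving stale embeddings. The `get_md5` helper comes from `.utils` and is not shown in this diff; a minimal sketch of what it presumably does, assuming it hashes a plain string:

import hashlib

def get_md5(text: str) -> str:
    # Hex digest of the UTF-8 encoded string; used here only as a
    # cache-invalidation key, not for anything security-sensitive.
    return hashlib.md5(text.encode("utf-8")).hexdigest()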
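
A minimal usage sketch under stated assumptions: `DummyToolbox` is a hypothetical stand-in for the real toolbox interface the code relies on (`refresh_tool_name_desc()`, `prepare_tool_prompts()`, `.all_tools`), and the model id is a placeholder for whatever SentenceTransformer checkpoint the project actually uses:

from txagent.toolrag import ToolRAGModel

class DummyToolbox:
    # Hypothetical stand-in matching the interface the diff expects.
    all_tools = [{"name": "drug_interaction_checker",
                  "description": "Check interactions between two drugs."}]

    def refresh_tool_name_desc(self, enable_full_desc=False):
        return [t["name"] for t in self.all_tools], None

    def prepare_tool_prompts(self, tools):
        return tools

rag = ToolRAGModel("sentence-transformers/all-MiniLM-L6-v2")  # placeholder model id
rag.load_tool_desc_embedding(DummyToolbox())  # encodes once, then loads from cache
print(rag.rag_infer("check interactions between two drugs", top_k=1))

Note the design shift in `rag_infer`: raising a `RuntimeError` when embeddings are missing, rather than calling `exit()` as the old version did, lets the caller handle the failure instead of killing the process.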