import json
import os

import torch
from sentence_transformers import SentenceTransformer

from .utils import get_md5


class ToolRAGModel:
    """Embedding-based retriever that matches a user query against tool descriptions."""

    def __init__(self, rag_model_name):
        self.rag_model_name = rag_model_name
        self.rag_model = None
        self.tool_desc_embedding = None
        self.tool_name = None
        self.tool_embedding_path = None
        self.load_rag_model()

    def load_rag_model(self):
        # Load the sentence-transformer encoder and configure it for long tool descriptions.
        self.rag_model = SentenceTransformer(self.rag_model_name)
        self.rag_model.max_seq_length = 4096
        self.rag_model.tokenizer.padding_side = "right"

    def load_tool_desc_embedding(self, toolbox):
        # Serialize every tool prompt and hash the result so the cached
        # embedding file is invalidated whenever the toolbox changes.
        self.tool_name, _ = toolbox.refresh_tool_name_desc(enable_full_desc=True)
        all_tools_str = [json.dumps(each) for each in toolbox.prepare_tool_prompts(toolbox.all_tools)]
        md5_value = get_md5(str(all_tools_str))
        print("Computed MD5 for tool embedding:", md5_value)

        # Cache file lives next to this module, keyed by model name and toolbox hash.
        self.tool_embedding_path = os.path.join(
            os.path.dirname(__file__),
            self.rag_model_name.split("/")[-1] + f"_tool_embedding_{md5_value}.pt",
        )

        if os.path.exists(self.tool_embedding_path):
            try:
                self.tool_desc_embedding = torch.load(self.tool_embedding_path, map_location="cpu")
                assert len(self.tool_desc_embedding) == len(toolbox.all_tools), \
                    "Tool count mismatch with loaded embeddings."
                print("\033[92mLoaded cached tool_desc_embedding.\033[0m")
                return
            except Exception as e:
                # Fall through and regenerate if the cache is corrupt or stale.
                print(f"⚠️ Failed loading cached embeddings: {e}")
                self.tool_desc_embedding = None

print("\033[93mGenerating new tool_desc_embedding...\033[0m") |
|
self.tool_desc_embedding = self.rag_model.encode( |
|
all_tools_str, prompt="", normalize_embeddings=True |
|
) |
|
|
|
torch.save(self.tool_desc_embedding, self.tool_embedding_path) |
|
print(f"\033[92mSaved new tool_desc_embedding to {self.tool_embedding_path}\033[0m") |
|
|
|
    def rag_infer(self, query, top_k=5):
        torch.cuda.empty_cache()
        # Fail fast if the tool embeddings were never built.
        if self.tool_desc_embedding is None:
            raise RuntimeError(
                "❌ tool_desc_embedding is not initialized. "
                "Did you forget to call load_tool_desc_embedding()?"
            )
        queries = [query]
        query_embeddings = self.rag_model.encode(
            queries, prompt="", normalize_embeddings=True
        )
        # Score the query against every tool description embedding and return
        # the names of the top_k highest-scoring tools.
        scores = self.rag_model.similarity(
            query_embeddings, self.tool_desc_embedding
        )
        top_k = min(top_k, len(self.tool_name))
        top_k_indices = torch.topk(scores, top_k).indices.tolist()[0]
        top_k_tool_names = [self.tool_name[i] for i in top_k_indices]
        return top_k_tool_names
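

# Example usage (illustrative sketch, not part of the library API).
# The stand-in toolbox below only mimics the interface used above
# (refresh_tool_name_desc, prepare_tool_prompts, all_tools); the real toolbox
# lives elsewhere in the package, and the model name here is an assumption.
# Because of the relative import at the top, run this as a module inside its
# package; it will also write a cache .pt file next to this file.
if __name__ == "__main__":
    class _DemoToolbox:
        all_tools = [
            {"name": "get_weather", "description": "Return the weather for a city."},
            {"name": "search_papers", "description": "Search a literature database."},
        ]

        def refresh_tool_name_desc(self, enable_full_desc=False):
            names = [tool["name"] for tool in self.all_tools]
            descs = [tool["description"] for tool in self.all_tools]
            return names, descs

        def prepare_tool_prompts(self, tools):
            return tools

    rag = ToolRAGModel("sentence-transformers/all-MiniLM-L6-v2")
    rag.load_tool_desc_embedding(_DemoToolbox())
    print(rag.rag_infer("What is the weather in Paris today?", top_k=1))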