import os
import json

import torch
from sentence_transformers import SentenceTransformer

from .utils import get_md5


class ToolRAGModel:
    """Retrieves relevant tools for a query by embedding tool descriptions
    with a SentenceTransformer model and ranking them by similarity."""

    def __init__(self, rag_model_name):
        self.rag_model_name = rag_model_name
        self.rag_model = None
        self.tool_desc_embedding = None
        self.tool_name = None
        self.tool_embedding_path = None
        self.load_rag_model()

    def load_rag_model(self):
        """Load the SentenceTransformer model used for retrieval."""
        self.rag_model = SentenceTransformer(self.rag_model_name)
        self.rag_model.max_seq_length = 4096
        self.rag_model.tokenizer.padding_side = "right"

    def load_tool_desc_embedding(self, toolbox):
        """Load cached tool-description embeddings if available; otherwise
        encode all tool descriptions and cache the result to disk."""
        self.tool_name, _ = toolbox.refresh_tool_name_desc(enable_full_desc=True)
        all_tools_str = [
            json.dumps(each) for each in toolbox.prepare_tool_prompts(toolbox.all_tools)
        ]

        # The cache file name is keyed on the MD5 of the serialized tool list,
        # so any change to the toolbox invalidates the cached embeddings.
        md5_value = get_md5(str(all_tools_str))
        print("Computed MD5 for tool embedding:", md5_value)
        self.tool_embedding_path = os.path.join(
            os.path.dirname(__file__),
            self.rag_model_name.split("/")[-1] + f"_tool_embedding_{md5_value}.pt",
        )

        if os.path.exists(self.tool_embedding_path):
            try:
                self.tool_desc_embedding = torch.load(
                    self.tool_embedding_path, map_location="cpu"
                )
                assert len(self.tool_desc_embedding) == len(toolbox.all_tools), \
                    "Tool count mismatch with loaded embeddings."
                print("\033[92mLoaded cached tool_desc_embedding.\033[0m")
                return
            except Exception as e:
                # Fall through and regenerate if the cache is corrupt or stale.
                print(f"⚠️ Failed loading cached embeddings: {e}")
                self.tool_desc_embedding = None

        print("\033[93mGenerating new tool_desc_embedding...\033[0m")
        self.tool_desc_embedding = self.rag_model.encode(
            all_tools_str, prompt="", normalize_embeddings=True
        )
        torch.save(self.tool_desc_embedding, self.tool_embedding_path)
        print(f"\033[92mSaved new tool_desc_embedding to {self.tool_embedding_path}\033[0m")

    def rag_infer(self, query, top_k=5):
        """Return the names of the top_k tools most similar to the query."""
        torch.cuda.empty_cache()
        queries = [query]
        query_embeddings = self.rag_model.encode(
            queries, prompt="", normalize_embeddings=True
        )
        if self.tool_desc_embedding is None:
            raise RuntimeError(
                "❌ tool_desc_embedding is not initialized. "
                "Did you forget to call load_tool_desc_embedding()?"
            )

        scores = self.rag_model.similarity(query_embeddings, self.tool_desc_embedding)
        top_k = min(top_k, len(self.tool_name))
        top_k_indices = torch.topk(scores, top_k).indices.tolist()[0]
        top_k_tool_names = [self.tool_name[i] for i in top_k_indices]
        return top_k_tool_names
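

# --- Usage sketch (illustrative, not part of the module's API) ---
# A minimal example of how ToolRAGModel might be driven. The `toolbox`
# object and its import are hypothetical: they stand in for any object
# exposing refresh_tool_name_desc(), prepare_tool_prompts(), and all_tools,
# as assumed by load_tool_desc_embedding() above. The model name is also
# only an example of a SentenceTransformer checkpoint.
#
#     from .toolbox import ToolBox  # hypothetical import
#
#     toolbox = ToolBox()
#     rag = ToolRAGModel("sentence-transformers/all-MiniLM-L6-v2")
#     rag.load_tool_desc_embedding(toolbox)  # builds or loads cached embeddings
#     tools = rag.rag_infer("find drug interactions for aspirin", top_k=5)
#     print(tools)  # names of the top-5 most relevant tools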