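"""Embedding-based retrieval (RAG) over tool descriptions.

ToolRAGModel embeds every tool description with a SentenceTransformer model,
caches the embeddings on disk keyed by an MD5 hash of the serialized tool
list, and returns the names of the tools most similar to a given query."""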
import os
import json

import torch
from sentence_transformers import SentenceTransformer

from .utils import get_md5
class ToolRAGModel:
    """Ranks tools for a query by similarity of their description embeddings."""

    def __init__(self, rag_model_name):
        self.rag_model_name = rag_model_name
        self.rag_model = None
        self.tool_desc_embedding = None
        self.tool_name = None
        self.tool_embedding_path = None
        self.load_rag_model()
    def load_rag_model(self):
        self.rag_model = SentenceTransformer(self.rag_model_name)
        # Tool descriptions can be long; allow a generous context window and
        # right-pad so batched encoding behaves consistently.
        self.rag_model.max_seq_length = 4096
        self.rag_model.tokenizer.padding_side = "right"
    def load_tool_desc_embedding(self, toolbox):
        self.tool_name, _ = toolbox.refresh_tool_name_desc(enable_full_desc=True)
        all_tools_str = [
            json.dumps(each) for each in toolbox.prepare_tool_prompts(toolbox.all_tools)
        ]
        # Key the on-disk cache by an MD5 of the serialized tool list, so the
        # embeddings are regenerated whenever any tool definition changes.
        md5_value = get_md5(str(all_tools_str))
        print("Computed MD5 for tool embedding:", md5_value)
        self.tool_embedding_path = os.path.join(
            os.path.dirname(__file__),
            self.rag_model_name.split("/")[-1] + f"_tool_embedding_{md5_value}.pt",
        )
        if os.path.exists(self.tool_embedding_path):
            try:
                self.tool_desc_embedding = torch.load(
                    self.tool_embedding_path, map_location="cpu"
                )
                assert len(self.tool_desc_embedding) == len(toolbox.all_tools), (
                    "Tool count mismatch with loaded embeddings."
                )
                print("\033[92mLoaded cached tool_desc_embedding.\033[0m")
                return
            except Exception as e:
                # A stale or corrupt cache is not fatal; fall through and rebuild.
                print(f"⚠️ Failed loading cached embeddings: {e}")
                self.tool_desc_embedding = None
        print("\033[93mGenerating new tool_desc_embedding...\033[0m")
        self.tool_desc_embedding = self.rag_model.encode(
            all_tools_str, prompt="", normalize_embeddings=True
        )
        torch.save(self.tool_desc_embedding, self.tool_embedding_path)
        print(f"\033[92mSaved new tool_desc_embedding to {self.tool_embedding_path}\033[0m")
    def rag_infer(self, query, top_k=5):
        # Fail fast before doing any encoding work.
        if self.tool_desc_embedding is None:
            raise RuntimeError(
                "❌ tool_desc_embedding is not initialized. "
                "Did you forget to call load_tool_desc_embedding()?"
            )
        torch.cuda.empty_cache()
        query_embeddings = self.rag_model.encode(
            [query], prompt="", normalize_embeddings=True
        )
        # scores has shape (1, num_tools); with normalized embeddings this is
        # cosine similarity.
        scores = self.rag_model.similarity(query_embeddings, self.tool_desc_embedding)
        top_k = min(top_k, len(self.tool_name))
        top_k_indices = torch.topk(scores, top_k).indices.tolist()[0]
        return [self.tool_name[i] for i in top_k_indices]
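
# --- Usage sketch (illustrative only) ---
# A minimal smoke test under stated assumptions: `_StubToolbox` is a
# hypothetical stand-in that mimics just the three members this module
# touches (all_tools, refresh_tool_name_desc, prepare_tool_prompts), and the
# model name below is a placeholder, not necessarily the one this project
# ships with. Because of the relative import above, run this as a module,
# e.g. `python -m <package>.<this_module>`, rather than as a script.
if __name__ == "__main__":
    class _StubToolbox:
        def __init__(self):
            self.all_tools = [
                {"name": "get_weather", "description": "Return the current weather for a city."},
                {"name": "search_pubmed", "description": "Search PubMed for biomedical articles."},
            ]

        def refresh_tool_name_desc(self, enable_full_desc=False):
            names = [t["name"] for t in self.all_tools]
            descs = [t["description"] for t in self.all_tools]
            return names, descs

        def prepare_tool_prompts(self, tools):
            return tools

    rag = ToolRAGModel("sentence-transformers/all-MiniLM-L6-v2")  # assumed model
    rag.load_tool_desc_embedding(_StubToolbox())
    # Likely prints ['get_weather'] given the query/description overlap.
    print(rag.rag_infer("What's the weather like in Boston?", top_k=1))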