from typing import List, Optional

from langchain_huggingface import HuggingFaceEmbeddings

from sandbox.light_rag.utils import get_device


class HFEmbedding:
    """Thin wrapper around LangChain's ``HuggingFaceEmbeddings``.

    Loads a Hugging Face embedding model by id and exposes batch/document/query
    embedding methods that all go through ``embed_documents``.
    """

    def __init__(
        self,
        model_id: str,
        *,
        device: Optional[str] = None,
    ):
        """Load the embedding model.

        Args:
            model_id: Model name or path, forwarded to ``HuggingFaceEmbeddings``
                as ``model_name``.
            device: Device string forwarded via ``model_kwargs["device"]``
                (e.g. ``"cpu"``). When omitted, the existing behaviour is kept:
                CPU is forced regardless of what ``get_device()`` returns.
        """
        if device is None:
            # The detected device is currently discarded: the zeroGPU hack
            # below forces CPU. Pass ``device=`` explicitly to bypass both.
            device = get_device()
            # TODO: hack for zeroGPU
            device = "cpu"
        print(f"Using device: {device}")
        if device == "cpu":
            print("Using CPU might be too slow")

        self.model_name = model_id
        print(f"Loading embeddings model from: {self.model_name}")
        self.embeddings_service = HuggingFaceEmbeddings(
            model_name=self.model_name,
            model_kwargs={"device": device},
        )

    def embed_batch(self, batch: list[str]) -> list[list[float]]:
        """Embed a batch of texts; alias of ``embed_documents``.

        Delegates instead of duplicating the service call so both entry
        points always share one code path.
        """
        return self.embed_documents(batch)

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Return one embedding vector per input text, in input order."""
        return self.embeddings_service.embed_documents(texts)

    def embed_query(self, text: str) -> list[float]:
        """Embed a single query string.

        Uses the document-embedding path on a one-element batch, so queries and
        documents are embedded identically.
        """
        return self.embed_documents([text])[0]