Spaces:
Running on Zero
from typing import List | |
from langchain_huggingface import HuggingFaceEmbeddings | |
from sandbox.light_rag.utils import get_device | |
class HFEmbedding:
    """Wrapper around LangChain's ``HuggingFaceEmbeddings``.

    Exposes ``embed_documents`` / ``embed_query`` (the LangChain embeddings
    method names) plus an ``embed_batch`` alias for batch callers.
    """

    def __init__(
        self,
        model_id: str,
        device: str = "cpu",
    ):
        """Load the embedding model.

        Args:
            model_id: Model name or path, forwarded to
                ``HuggingFaceEmbeddings`` as ``model_name``.
            device: Device string forwarded via ``model_kwargs``. Defaults to
                ``"cpu"``, matching the previously hard-coded behaviour. Pass
                ``"auto"`` to pick the device with ``get_device()`` instead.
        """
        # The original code unconditionally forced "cpu" as a "hack for
        # zeroGPU" (TODO) and discarded get_device()'s result. Keep CPU as the
        # default so existing callers are unaffected; auto-detection is opt-in.
        if device == "auto":
            device = get_device()
        print(f"Using device: {device}")
        if device == "cpu":
            print("Using CPU might be too slow")
        self.model_name = model_id
        print(f"Loading embeddings model from: {self.model_name}")
        self.embeddings_service = HuggingFaceEmbeddings(
            model_name=self.model_name,
            model_kwargs={"device": device},
        )

    def embed_batch(self, batch: list[str]) -> list[list[float]]:
        """Embed a batch of texts; alias of :meth:`embed_documents`."""
        return self.embed_documents(batch)

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Return the wrapped service's embedding vector for each text."""
        return self.embeddings_service.embed_documents(texts)

    def embed_query(self, text: str) -> list[float]:
        """Embed a single text by embedding it as a one-element batch."""
        return self.embed_documents([text])[0]