Source: langchain/libs/community/langchain_community/embeddings/infinity_local.py
"""written under MIT Licence, Michael Feil 2023."""

import asyncio
from logging import getLogger
from typing import Any, Dict, List, Optional

from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator

__all__ = ["InfinityEmbeddingsLocal"]

logger = getLogger(__name__)
class InfinityEmbeddingsLocal(BaseModel, Embeddings):
    """Optimized Infinity embedding models.

    https://github.com/michaelfeil/infinity
    This class deploys a local Infinity instance to embed text.
    The class requires async usage.

    Infinity is a class to interact with Embedding Models on
    https://github.com/michaelfeil/infinity

    Example:
        .. code-block:: python

            from langchain_community.embeddings import InfinityEmbeddingsLocal

            async with InfinityEmbeddingsLocal(
                model="BAAI/bge-small-en-v1.5",
                revision=None,
                device="cpu",
            ) as embedder:
                embeddings = await embedder.aembed_documents(["text1", "text2"])
    """

    model: str
    "Underlying model id from huggingface, e.g. BAAI/bge-small-en-v1.5"
    revision: Optional[str] = None
    "Model version, the commit hash from huggingface"
    batch_size: int = 32
    "Internal batch size for inference, e.g. 32"
    device: str = "auto"
    "Device to use for inference, e.g. 'cpu' or 'cuda', or 'mps'"
    backend: str = "torch"
    "Backend for inference, e.g. 'torch' (recommended for ROCm/Nvidia)"
    " or 'optimum' for onnx/tensorrt"
    model_warmup: bool = True
    "Warmup the model with the max batch size."
    engine: Any = None  #: :meta private:
    """Infinity's AsyncEmbeddingEngine."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    # The decorator is required: without it pydantic never invokes this
    # validator and ``engine`` would silently remain ``None``.
    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the infinity_emb package is installed and
        construct the local ``AsyncEmbeddingEngine`` from the field values.

        Raises:
            ImportError: If ``infinity_emb`` is not installed.
        """
        try:
            from infinity_emb import AsyncEmbeddingEngine  # type: ignore
        except ImportError:
            raise ImportError(
                "Please install the "
                "`pip install 'infinity_emb[optimum,torch]>=0.0.24'` "
                "package to use the InfinityEmbeddingsLocal."
            )
        logger.debug(f"Using InfinityEmbeddingsLocal with kwargs {values}")
        values["engine"] = AsyncEmbeddingEngine(
            model_name_or_path=values["model"],
            device=values["device"],
            revision=values["revision"],
            model_warmup=values["model_warmup"],
            batch_size=values["batch_size"],
            engine=values["backend"],
        )
        return values

    async def __aenter__(self) -> "InfinityEmbeddingsLocal":
        """Start the background worker.

        Recommended usage is with the async with statement.

        .. code-block:: python

            async with InfinityEmbeddingsLocal(
                model="BAAI/bge-small-en-v1.5",
                revision=None,
                device="cpu",
            ) as embedder:
                embeddings = await embedder.aembed_documents(["text1", "text2"])
        """
        await self.engine.__aenter__()
        # Return self so that ``async with ... as embedder`` binds a usable
        # object instead of None.
        return self

    async def __aexit__(self, *args: Any) -> None:
        """Stop the background worker,
        required to free references to the pytorch model."""
        await self.engine.__aexit__(*args)

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Async call out to Infinity's embedding endpoint.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        if not self.engine.running:
            logger.warning(
                "Starting Infinity engine on the fly. This is not recommended. "
                "Please start the engine before using it."
            )
            async with self:
                # spawning threadpool for multithreaded encode, tokenization
                embeddings, _ = await self.engine.embed(texts)
                # stopping threadpool on exit
                logger.warning("Stopped infinity engine after usage.")
        else:
            embeddings, _ = await self.engine.embed(texts)
        return embeddings

    async def aembed_query(self, text: str) -> List[float]:
        """Async call out to Infinity's embedding endpoint.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        embeddings = await self.aembed_documents([text])
        return embeddings[0]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Sync wrapper around :meth:`aembed_documents`.

        This method is async only; prefer ``await aembed_documents``.
        """
        logger.warning(
            "This method is async only. "
            "Please use the async version `await aembed_documents`."
        )
        return asyncio.run(self.aembed_documents(texts))

    def embed_query(self, text: str) -> List[float]:
        """Sync wrapper around :meth:`aembed_query`.

        This method is async only; prefer ``await aembed_query``.
        """
        logger.warning(
            "This method is async only."
            " Please use the async version `await aembed_query`."
        )
        return asyncio.run(self.aembed_query(text))