import asyncio
import logging
import warnings
from typing import Dict, Iterable, List, Optional

import httpx
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import (
    BaseModel,
    Extra,
    Field,
    SecretStr,
    root_validator,
)
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from tokenizers import Tokenizer  # type: ignore

logger = logging.getLogger(__name__)

MAX_TOKENS = 16_000
"""A batching parameter for the Mistral API. This is NOT the maximum number of tokens
accepted by the embedding model for each document/chunk, but rather the maximum number
of tokens that can be sent in a single request to the Mistral API (across multiple
documents/chunks)."""


class DummyTokenizer:
    """Dummy tokenizer used when the real tokenizer cannot be downloaded
    (e.g., from Hugging Face)."""

    def encode_batch(self, texts: List[str]) -> List[List[str]]:
        return [list(text) for text in texts]
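
# Illustrative behavior of the DummyTokenizer fallback (characters are counted
# as "tokens", which typically overestimates real token counts and therefore
# produces smaller, API-safe batches):
#
#     DummyTokenizer().encode_batch(["hi", "there"])
#     # -> [["h", "i"], ["t", "h", "e", "r", "e"]]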


class MistralAIEmbeddings(BaseModel, Embeddings):
    """MistralAI embedding models.

    To use, ensure the environment variable `MISTRAL_API_KEY` is set with your
    API key, or pass it as a named parameter to the constructor.

    Example:
        .. code-block:: python

            from langchain_mistralai import MistralAIEmbeddings

            mistral = MistralAIEmbeddings(
                model="mistral-embed",
                api_key="my-api-key"
            )
    """

    client: httpx.Client = Field(default=None)  #: :meta private:
    async_client: httpx.AsyncClient = Field(default=None)  #: :meta private:
    mistral_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
    endpoint: str = "https://api.mistral.ai/v1/"
    max_retries: int = 5
    timeout: int = 120
    max_concurrent_requests: int = 64
    tokenizer: Tokenizer = Field(default=None)
    model: str = "mistral-embed"

    class Config:
        extra = Extra.forbid
        arbitrary_types_allowed = True
        allow_population_by_field_name = True

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate configuration."""
        values["mistral_api_key"] = convert_to_secret_str(
            get_from_dict_or_env(
                values, "mistral_api_key", "MISTRAL_API_KEY", default=""
            )
        )
        api_key_str = values["mistral_api_key"].get_secret_value()
        # todo: handle retries
        if not values.get("client"):
            values["client"] = httpx.Client(
                base_url=values["endpoint"],
                headers={
                    "Content-Type": "application/json",
                    "Accept": "application/json",
                    "Authorization": f"Bearer {api_key_str}",
                },
                timeout=values["timeout"],
            )
        # todo: handle retries and max_concurrency
        if not values.get("async_client"):
            values["async_client"] = httpx.AsyncClient(
                base_url=values["endpoint"],
                headers={
                    "Content-Type": "application/json",
                    "Accept": "application/json",
                    "Authorization": f"Bearer {api_key_str}",
                },
                timeout=values["timeout"],
            )
        if values["tokenizer"] is None:
            try:
                values["tokenizer"] = Tokenizer.from_pretrained(
                    "mistralai/Mixtral-8x7B-v0.1"
                )
            except IOError:  # huggingface_hub GatedRepoError
                warnings.warn(
                    "Could not download mistral tokenizer from Huggingface for "
                    "calculating batch sizes. Set a Huggingface token via the "
                    "HF_TOKEN environment variable to download the real tokenizer. "
                    "Falling back to a dummy tokenizer that uses `len()`."
                )
                values["tokenizer"] = DummyTokenizer()
        return values
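
    # Note: the retry todos above are not wired in yet. One possible workaround
    # (an assumption, not part of this module) is to pass a pre-built client,
    # e.g. httpx.Client(transport=httpx.HTTPTransport(retries=5)) — httpx
    # retries connection failures, not HTTP error responses — since the
    # validator above only creates a client when none was provided.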

    def _get_batches(self, texts: List[str]) -> Iterable[List[str]]:
        """Split a list of texts into batches of less than 16k tokens
        for the Mistral API."""
        batch: List[str] = []
        batch_tokens = 0

        text_token_lengths = [
            len(encoded) for encoded in self.tokenizer.encode_batch(texts)
        ]
        for text, text_tokens in zip(texts, text_token_lengths):
            if batch_tokens + text_tokens > MAX_TOKENS:
                if len(batch) > 0:
                    # edge case where the first text alone exceeds max tokens;
                    # should not yield an empty batch.
                    yield batch
                batch = [text]
                batch_tokens = text_tokens
            else:
                batch.append(text)
                batch_tokens += text_tokens
        if batch:
            yield batch
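
    # Illustrative example of the batching behavior (with MAX_TOKENS = 16_000):
    # three texts whose token counts are [9_000, 9_000, 2_000] yield the
    # batches [t0] and [t1, t2], keeping each request's total token count
    # within the per-request budget described at MAX_TOKENS above.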

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of document texts.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        try:
            batch_responses = (
                self.client.post(
                    url="/embeddings",
                    json=dict(
                        model=self.model,
                        input=batch,
                    ),
                )
                for batch in self._get_batches(texts)
            )
            return [
                list(map(float, embedding_obj["embedding"]))
                for response in batch_responses
                for embedding_obj in response.json()["data"]
            ]
        except Exception as e:
            logger.error(f"An error occurred with MistralAI: {e}")
            raise

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of document texts.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        try:
            batch_responses = await asyncio.gather(
                *[
                    self.async_client.post(
                        url="/embeddings",
                        json=dict(
                            model=self.model,
                            input=batch,
                        ),
                    )
                    for batch in self._get_batches(texts)
                ]
            )
            return [
                list(map(float, embedding_obj["embedding"]))
                for response in batch_responses
                for embedding_obj in response.json()["data"]
            ]
        except Exception as e:
            logger.error(f"An error occurred with MistralAI: {e}")
            raise

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query text.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        return self.embed_documents([text])[0]

    async def aembed_query(self, text: str) -> List[float]:
        """Embed a single query text.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        return (await self.aembed_documents([text]))[0]
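

# Usage sketch (illustrative; assumes MISTRAL_API_KEY is set in the
# environment, or that api_key is passed explicitly as in the class docstring):
#
#     embedder = MistralAIEmbeddings(model="mistral-embed")
#     doc_vectors = embedder.embed_documents(["first chunk", "second chunk"])
#     query_vector = embedder.embed_query("what is in the first chunk?")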