File size: 4,677 Bytes
ed4d993
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import logging
from typing import Any, Dict, List, Optional

from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator

logger = logging.getLogger(__name__)


class ClarifaiEmbeddings(BaseModel, Embeddings):
    """Clarifai embedding models.

    To use, you should have the ``clarifai`` python package installed, and the
    environment variable ``CLARIFAI_PAT`` set with your personal access token or pass it
    as a named parameter to the constructor.

    Example:
        .. code-block:: python

            from langchain_community.embeddings import ClarifaiEmbeddings
            clarifai = ClarifaiEmbeddings(user_id=USER_ID,
                                          app_id=APP_ID,
                                          model_id=MODEL_ID)
            # (or)
            EXAMPLE_URL = "https://clarifai.com/clarifai/main/models/BAAI-bge-base-en-v15"
            clarifai = ClarifaiEmbeddings(model_url=EXAMPLE_URL)
    """

    model_url: Optional[str] = None
    """Model url to use."""
    model_id: Optional[str] = None
    """Model id to use."""
    model_version_id: Optional[str] = None
    """Model version id to use."""
    app_id: Optional[str] = None
    """Clarifai application id to use."""
    user_id: Optional[str] = None
    """Clarifai user id to use."""
    pat: Optional[str] = Field(default=None, exclude=True)
    """Clarifai personal access token to use."""
    token: Optional[str] = Field(default=None, exclude=True)
    """Clarifai session token to use."""
    model: Any = Field(default=None, exclude=True)  #: :meta private:
    api_base: str = "https://api.clarifai.com"
    """Base url handed to the Clarifai ``Model`` client as ``base_url``."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the ``clarifai`` package is importable and build the client.

        The configured fields are forwarded to ``clarifai.client.model.Model`` and
        the resulting object is stored under ``values["model"]``.

        Args:
            values: Field values collected by pydantic for this instance.

        Returns:
            The same ``values`` dict with the ``"model"`` entry populated.

        Raises:
            ImportError: If the ``clarifai`` python package is not installed.
        """

        try:
            from clarifai.client.model import Model
        except ImportError as exc:
            raise ImportError(
                "Could not import clarifai python package. "
                "Please install it with `pip install clarifai`."
            ) from exc

        values["model"] = Model(
            url=values.get("model_url"),
            app_id=values.get("app_id"),
            user_id=values.get("user_id"),
            model_version=dict(id=values.get("model_version_id")),
            pat=values.get("pat"),
            token=values.get("token"),
            model_id=values.get("model_id"),
            base_url=values.get("api_base"),
        )

        return values

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Call out to Clarifai's embedding models.

        Texts are sent in batches of 32. This method is best-effort: if any step
        of a batch raises, the error is logged and the embeddings gathered from
        the batches completed before the failure are returned, so the result can
        be shorter than ``texts``.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text that was embedded successfully.
        """
        from clarifai.client.input import Inputs

        input_obj = Inputs.from_auth_helper(self.model.auth_helper)
        batch_size = 32
        embeddings: List[List[float]] = []

        try:
            for start in range(0, len(texts), batch_size):
                batch = texts[start : start + batch_size]
                # Input ids are the positions within the batch, so they restart
                # at "0" for every batch.
                input_batch = [
                    input_obj.get_text_input(input_id=str(idx), raw_text=inp)
                    for idx, inp in enumerate(batch)
                ]
                predict_response = self.model.predict(input_batch)
                embeddings.extend(
                    list(output.data.embeddings[0].vector)
                    for output in predict_response.outputs
                )

        except Exception as e:
            logger.error(f"Predict failed, exception: {e}")

        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Call out to Clarifai's embedding models.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.

        Raises:
            Exception: Any error raised by the prediction call, or while reading
                its outputs, is logged and re-raised unchanged.
            IndexError: If the prediction response contains no outputs.
        """

        try:
            predict_response = self.model.predict_by_bytes(
                bytes(text, "utf-8"), input_type="text"
            )
            embeddings = [
                list(op.data.embeddings[0].vector) for op in predict_response.outputs
            ]

        except Exception as e:
            logger.error(f"Predict failed, exception: {e}")
            # Re-raise the real failure; falling through would hit an unbound
            # ``embeddings`` and mask the cause with an UnboundLocalError.
            raise

        return embeddings[0]