"""Wrapper around Prem's Chat API."""

from __future__ import annotations

import logging
import warnings
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Iterator,
    List,
    Optional,
    Tuple,
    Type,
    Union,
)

from langchain_core.callbacks import (
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.language_models.llms import create_base_retry_decorator
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    BaseMessageChunk,
    ChatMessage,
    ChatMessageChunk,
    HumanMessage,
    HumanMessageChunk,
    SystemMessage,
    SystemMessageChunk,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import (
    BaseModel,
    Extra,
    Field,
    SecretStr,
    root_validator,
)
from langchain_core.utils import get_from_dict_or_env

if TYPE_CHECKING:
    from premai.api.chat_completions.v1_chat_completions_create import (
        ChatCompletionResponseStream,
    )
    from premai.models.chat_completion_response import ChatCompletionResponse

logger = logging.getLogger(__name__)


class ChatPremAPIError(Exception):
    """Error with the `PremAI` API."""


def _truncate_at_stop_tokens(
    text: str,
    stop: Optional[List[str]],
) -> str:
    """Truncates text at the earliest stop token found."""
    if stop is None:
        return text

    for stop_token in stop:
        stop_token_idx = text.find(stop_token)
        if stop_token_idx != -1:
            text = text[:stop_token_idx]
    return text


def _response_to_result(
    response: ChatCompletionResponse,
    stop: Optional[List[str]],
) -> ChatResult:
    """Converts a Prem API response into a LangChain result"""

    if not response.choices:
        raise ChatPremAPIError("ChatResponse must have at least one candidate")
    generations: List[ChatGeneration] = []
    for choice in response.choices:
        role = choice.message.role
        if role is None:
            raise ChatPremAPIError(f"ChatResponse {choice} must have a role.")

        # A missing (None) content is normalized to "" before truncation,
        # so `content` is always a string at this point.
        content = _truncate_at_stop_tokens(text=choice.message.content or "", stop=stop)

        if role == "assistant":
            generations.append(
                ChatGeneration(text=content, message=AIMessage(content=content))
            )
        elif role == "user":
            generations.append(
                ChatGeneration(text=content, message=HumanMessage(content=content))
            )
        else:
            generations.append(
                ChatGeneration(
                    text=content, message=ChatMessage(role=role, content=content)
                )
            )

    if response.document_chunks is not None:
        return ChatResult(
            generations=generations,
            llm_output={
                "document_chunks": [
                    chunk.to_dict() for chunk in response.document_chunks
                ]
            },
        )
    else:
        return ChatResult(generations=generations, llm_output={"document_chunks": None})


def _convert_delta_response_to_message_chunk(
    response: ChatCompletionResponseStream, default_class: Type[BaseMessageChunk]
) -> Tuple[
    Union[BaseMessageChunk, HumanMessageChunk, AIMessageChunk, SystemMessageChunk],
    Optional[str],
]:
    """Converts delta response to message chunk"""
    _delta = response.choices[0].delta  # type: ignore
    role = _delta.get("role", "")  # type: ignore
    content = _delta.get("content", "")  # type: ignore
    additional_kwargs: Dict = {}
    finish_reasons: Optional[str] = response.choices[0].finish_reason

    if role == "user" or default_class == HumanMessageChunk:
        return HumanMessageChunk(content=content), finish_reasons
    elif role == "assistant" or default_class == AIMessageChunk:
        return (
            AIMessageChunk(content=content, additional_kwargs=additional_kwargs),
            finish_reasons,
        )
    elif role == "system" or default_class == SystemMessageChunk:
        return SystemMessageChunk(content=content), finish_reasons
    elif role or default_class == ChatMessageChunk:
        return ChatMessageChunk(content=content, role=role), finish_reasons
    else:
        return default_class(content=content), finish_reasons  # type: ignore[call-arg]


def _messages_to_prompt_dict(
    input_messages: List[BaseMessage],
) -> Tuple[Optional[str], List[Dict[str, str]]]:
    """Converts a list of LangChain Messages into a simple dict
    which is the message structure in Prem"""
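    # Illustrative mapping (hypothetical values):
    #   _messages_to_prompt_dict([SystemMessage(content="be brief"),
    #                             HumanMessage(content="hi")])
    #   -> ("be brief", [{"role": "user", "content": "hi"}])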

    system_prompt: Optional[str] = None
    examples_and_messages: List[Dict[str, str]] = []

    for input_msg in input_messages:
        if isinstance(input_msg, SystemMessage):
            system_prompt = str(input_msg.content)
        elif isinstance(input_msg, HumanMessage):
            examples_and_messages.append(
                {"role": "user", "content": str(input_msg.content)}
            )
        elif isinstance(input_msg, AIMessage):
            examples_and_messages.append(
                {"role": "assistant", "content": str(input_msg.content)}
            )
        else:
            raise ChatPremAPIError("No such role explicitly exists")
    return system_prompt, examples_and_messages


class ChatPremAI(BaseChatModel, BaseModel):
    """PremAI Chat models.

    To use, you will need to have an API key. You can find your existing API Key
    or generate a new one here: https://app.premai.io/api_keys/
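
    Example (a minimal sketch; the project ID and API key below are
    placeholders, not real values):

    .. code-block:: python

        from langchain_community.chat_models import ChatPremAI

        chat = ChatPremAI(project_id=8, premai_api_key="your-api-key")
        chat.invoke("Tell me a joke")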
    """

    # TODO: Need to add the default parameters through prem-sdk here

    project_id: int
    """The project ID in which the experiments or deployments are carried out. 
    You can find all your projects here: https://app.premai.io/projects/"""
    premai_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
    """Prem AI API Key. Get it here: https://app.premai.io/api_keys/"""

    model: Optional[str] = Field(default=None, alias="model_name")
    """Name of the model. This is an optional parameter. 
    The default model is the one deployed from Prem's LaunchPad: https://app.premai.io/projects/8/launchpad
    If model name is other than default model then it will override the calls 
    from the model deployed from launchpad."""

    temperature: Optional[float] = None
    """Model temperature. Value should be >= 0 and <= 1.0"""

    max_tokens: Optional[int] = None
    """The maximum number of tokens to generate"""

    max_retries: int = 1
    """Max number of retries to call the API"""

    system_prompt: Optional[str] = ""
    """Acts like a default instruction that helps the LLM act or generate 
    in a specific way.This is an Optional Parameter. By default the 
    system prompt would be using Prem's Launchpad models system prompt. 
    Changing the system prompt would override the default system prompt.
    """

    repositories: Optional[dict] = None
    """Add valid repository ids. This will be overriding existing connected 
    repositories (if any) and will use RAG with the connected repos. 
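
    An illustrative shape (hypothetical values; check Prem's docs for the
    exact schema): ``{"ids": [1, 2], "similarity_threshold": 0.25, "limit": 3}``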
    """

    streaming: Optional[bool] = False
    """Whether to stream the responses or not."""

    client: Any

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        allow_population_by_field_name = True
        arbitrary_types_allowed = True

    @root_validator()
    def validate_environments(cls, values: Dict) -> Dict:
        """Validate that the package is installed and that the API token is valid"""
        try:
            from premai import Prem
        except ImportError as error:
            raise ImportError(
                "Could not import Prem Python package."
                "Please install it with: `pip install premai`"
            ) from error

        try:
            premai_api_key = get_from_dict_or_env(
                values, "premai_api_key", "PREMAI_API_KEY"
            )
            values["client"] = Prem(api_key=premai_api_key)
        except Exception as error:
            raise ValueError("Your API Key is incorrect. Please try again.") from error
        return values

    @property
    def _llm_type(self) -> str:
        return "premai"

    @property
    def _default_params(self) -> Dict[str, Any]:
        return {
            "model": self.model,
            "system_prompt": self.system_prompt,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "repositories": self.repositories,
        }

    def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
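        """Merge the default parameters with call-time kwargs, dropping
        unsupported keys (with a warning) and pruning defaults left unset
        (None or "").

        Illustrative behavior (hypothetical values): with only
        ``temperature=0.2`` configured, ``self._get_all_kwargs(top_p=0.9)``
        warns about ``top_p`` and returns ``{"temperature": 0.2}``.
        """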
        kwargs_to_ignore = [
            "top_p",
            "tools",
            "frequency_penalty",
            "presence_penalty",
            "logit_bias",
            "stop",
            "seed",
        ]
        keys_to_remove = []

        for key in kwargs:
            if key in kwargs_to_ignore:
                warnings.warn(f"WARNING: Parameter {key} is not supported in kwargs.")
                keys_to_remove.append(key)

        for key in keys_to_remove:
            kwargs.pop(key)

        all_kwargs = {**self._default_params, **kwargs}
        for key in list(self._default_params.keys()):
            if all_kwargs.get(key) is None or all_kwargs.get(key) == "":
                all_kwargs.pop(key, None)
        return all_kwargs

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        system_prompt, messages_to_pass = _messages_to_prompt_dict(messages)  # type: ignore

        if system_prompt is not None and system_prompt != "":
            kwargs["system_prompt"] = system_prompt

        all_kwargs = self._get_all_kwargs(**kwargs)
        response = chat_with_retry(
            self,
            project_id=self.project_id,
            messages=messages_to_pass,
            stream=False,
            run_manager=run_manager,
            **all_kwargs,
        )

        return _response_to_result(response=response, stop=stop)

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        system_prompt, messages_to_pass = _messages_to_prompt_dict(messages)

        if stop is not None:
            logger.warning("stop is not supported in langchain streaming")

        if "system_prompt" not in kwargs:
            if system_prompt is not None and system_prompt != "":
                kwargs["system_prompt"] = system_prompt

        all_kwargs = self._get_all_kwargs(**kwargs)

        default_chunk_class = AIMessageChunk

        for streamed_response in chat_with_retry(
            self,
            project_id=self.project_id,
            messages=messages_to_pass,
            stream=True,
            run_manager=run_manager,
            **all_kwargs,
        ):
            try:
                chunk, finish_reason = _convert_delta_response_to_message_chunk(
                    response=streamed_response, default_class=default_chunk_class
                )
                generation_info = (
                    dict(finish_reason=finish_reason)
                    if finish_reason is not None
                    else None
                )
                cg_chunk = ChatGenerationChunk(
                    message=chunk, generation_info=generation_info
                )
                if run_manager:
                    run_manager.on_llm_new_token(cg_chunk.text, chunk=cg_chunk)
                yield cg_chunk
            except Exception as exc:
                logger.warning("Skipping a malformed streamed chunk: %s", exc)
                continue


def create_prem_retry_decorator(
    llm: ChatPremAI,
    *,
    max_retries: int = 1,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> Callable[[Any], Any]:
    """Create a retry decorator for PremAI API errors."""
    import premai.models

    errors = [
        premai.models.api_response_validation_error.APIResponseValidationError,
        premai.models.conflict_error.ConflictError,
        premai.models.model_not_found_error.ModelNotFoundError,
        premai.models.permission_denied_error.PermissionDeniedError,
        premai.models.provider_api_connection_error.ProviderAPIConnectionError,
        premai.models.provider_api_status_error.ProviderAPIStatusError,
        premai.models.provider_api_timeout_error.ProviderAPITimeoutError,
        premai.models.provider_internal_server_error.ProviderInternalServerError,
        premai.models.provider_not_found_error.ProviderNotFoundError,
        premai.models.rate_limit_error.RateLimitError,
        premai.models.unprocessable_entity_error.UnprocessableEntityError,
        premai.models.validation_error.ValidationError,
    ]

    decorator = create_base_retry_decorator(
        error_types=errors, max_retries=max_retries, run_manager=run_manager
    )
    return decorator


def chat_with_retry(
    llm: ChatPremAI,
    project_id: int,
    messages: List[dict],
    stream: bool = False,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Using tenacity for retry in completion call"""
    retry_decorator = create_prem_retry_decorator(
        llm, max_retries=llm.max_retries, run_manager=run_manager
    )

    @retry_decorator
    def _completion_with_retry(
        project_id: int,
        messages: List[dict],
        stream: Optional[bool] = False,
        **kwargs: Any,
    ) -> Any:
        response = llm.client.chat.completions.create(
            project_id=project_id,
            messages=messages,
            stream=stream,
            **kwargs,
        )
        return response

    return _completion_with_retry(
        project_id=project_id,
        messages=messages,
        stream=stream,
        **kwargs,
    )