"""Experimental implementation of RELLM wrapped LLM."""
from __future__ import annotations

from typing import TYPE_CHECKING, Any, List, Optional, cast

from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain_community.llms.utils import enforce_stop_tokens
from langchain_core.callbacks.manager import CallbackManagerForLLMRun

from langchain_experimental.pydantic_v1 import Field, root_validator

if TYPE_CHECKING:
    import rellm
    from regex import Pattern as RegexPattern
else:
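    # Note (assumption about intent): `RegexPattern` must also be importable at
    # runtime, because pydantic evaluates the `regex` field annotation below
    # when the RELLM class is created.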
    try:
        from regex import Pattern as RegexPattern
    except ImportError:
        pass


def import_rellm() -> rellm:
    """Lazily import of the rellm package."""
    try:
        import rellm
    except ImportError:
        raise ImportError(
            "Could not import rellm python package. "
            "Please install it with `pip install rellm`."
        )
    return rellm


class RELLM(HuggingFacePipeline):
    """RELLM wrapped LLM using HuggingFace Pipeline API."""

    regex: RegexPattern = Field(..., description="The structured format to complete.")
    max_new_tokens: int = Field(
        default=200, description="Maximum number of new tokens to generate."
    )

    # TODO: move away from `root_validator` since it is deprecated in pydantic v2
    #       and causes mypy type-checking failures (hence the `type: ignore`)
    @root_validator  # type: ignore[call-overload]
    def check_rellm_installation(cls, values: dict) -> dict:
        import_rellm()
        return values

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        rellm = import_rellm()
        from transformers import Text2TextGenerationPipeline

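        # `cast` is a no-op at runtime; it only narrows the pipeline type for
        # type-checkers so the `.tokenizer` and `.model` accesses below pass.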
        pipeline = cast(Text2TextGenerationPipeline, self.pipeline)

        text = rellm.complete_re(
            prompt,
            self.regex,
            tokenizer=pipeline.tokenizer,
            model=pipeline.model,
            max_new_tokens=self.max_new_tokens,
        )
        if stop is not None:
            # `rellm.complete_re` does not accept stop sequences, so enforce
            # them after the fact by truncating the generated text at the
            # first occurrence of a stop string.
            text = enforce_stop_tokens(text, stop)
        return text
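

# Example usage, as a minimal sketch (the model name and pattern below are
# illustrative assumptions, not part of this module). RELLM constrains each
# generation step so the completion matches the supplied `regex` pattern:
#
#     import regex  # the third-party `regex` package, not stdlib `re`
#     from transformers import pipeline
#
#     hf_pipeline = pipeline("text-generation", model="gpt2")
#     phone_pattern = regex.compile(r"\(\d{3}\) \d{3}-\d{4}")
#     llm = RELLM(pipeline=hf_pipeline, regex=phone_pattern, max_new_tokens=20)
#     print(llm.invoke("My phone number is "))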