from typing import Any, Callable, List, Mapping, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Field

from langchain_community.llms.utils import enforce_stop_tokens


def _display_prompt(prompt: str) -> None:
    """Displays the given prompt to the user."""
    print(f"\n{prompt}")  # noqa: T201


def _collect_user_input(
    separator: Optional[str] = None, stop: Optional[List[str]] = None
) -> str:
    """Collects and returns user input as a single string."""
    separator = separator or "\n"
    lines = []

    while True:
        line = input()
        # An empty line ends the multi-line input.
        if not line:
            break
        lines.append(line)

        # A line containing any stop sequence also ends the input; the stop
        # sequence itself is trimmed later by ``enforce_stop_tokens``.
        if stop and any(seq in line for seq in stop):
            break
    # Combine all lines into a single string
    multi_line_input = separator.join(lines)
    return multi_line_input


class HumanInputLLM(LLM):
    """User input as the response."""

    input_func: Callable = Field(default_factory=lambda: _collect_user_input)
    prompt_func: Callable[[str], None] = Field(default_factory=lambda: _display_prompt)
    separator: str = "\n"
    input_kwargs: Mapping[str, Any] = {}
    prompt_kwargs: Mapping[str, Any] = {}

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """
        Returns an empty dictionary as there are no identifying parameters.
        """
        return {}

    @property
    def _llm_type(self) -> str:
        """Returns the type of LLM."""
        return "human-input"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """
        Displays the prompt to the user and returns their input as a response.

        Args:
            prompt (str): The prompt to be displayed to the user.
            stop (Optional[List[str]]): A list of stop strings.
            run_manager (Optional[CallbackManagerForLLMRun]): Currently not used.

        Returns:
            str: The user's input as a response.
        """
        self.prompt_func(prompt, **self.prompt_kwargs)
        user_input = self.input_func(
            separator=self.separator, stop=stop, **self.input_kwargs
        )

        if stop is not None:
            # Truncate at the first stop token: the human cannot be expected
            # to apply stop sequences themselves, so enforce them here.
            user_input = enforce_stop_tokens(user_input, stop)
        return user_input