"""Chain that interprets a prompt and executes bash operations."""
from __future__ import annotations

import logging
import warnings
from typing import Any, Dict, List, Optional

from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.schema import BasePromptTemplate, OutputParserException
from langchain_core.callbacks.manager import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel

from langchain_experimental.llm_bash.bash import BashProcess
from langchain_experimental.llm_bash.prompt import PROMPT
from langchain_experimental.pydantic_v1 import Extra, Field, root_validator

logger = logging.getLogger(__name__)


class LLMBashChain(Chain):
    """Chain that interprets a prompt and executes bash operations.

    Example:
        .. code-block:: python

            from langchain.chains import LLMBashChain
            from langchain_community.llms import OpenAI
            llm_bash = LLMBashChain.from_llm(OpenAI())
    """

    llm_chain: LLMChain
    llm: Optional[BaseLanguageModel] = None
    """[Deprecated] LLM wrapper to use."""
    input_key: str = "question"  #: :meta private:
    output_key: str = "answer"  #: :meta private:
    prompt: BasePromptTemplate = PROMPT
    """[Deprecated]"""
    bash_process: BashProcess = Field(default_factory=BashProcess)  #: :meta private:

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @root_validator(pre=True)
    def raise_deprecation(cls, values: Dict) -> Dict:
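        """Warn on the deprecated `llm` arg and build `llm_chain` from it if needed."""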
        if "llm" in values:
            warnings.warn(
                "Directly instantiating an LLMBashChain with an llm is deprecated. "
                "Please instantiate with llm_chain or using the from_llm class method."
            )
            if "llm_chain" not in values and values["llm"] is not None:
                prompt = values.get("prompt", PROMPT)
                values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
        return values

    # TODO: move away from `root_validator` since it is deprecated in pydantic v2
    #       and causes mypy type-checking failures (hence the `type: ignore`)
    @root_validator  # type: ignore[call-overload]
    def validate_prompt(cls, values: Dict) -> Dict:
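        """Require that the llm_chain prompt defines an output_parser."""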
        if values["llm_chain"].prompt.output_parser is None:
            raise ValueError(
                "The prompt used by llm_chain is expected to have an output_parser."
            )
        return values

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Expect output key.

        :meta private:
        """
        return [self.output_key]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        _run_manager.on_text(inputs[self.input_key], verbose=self.verbose)

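        # Ask the LLM to turn the natural-language question into bash commands.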
        t = self.llm_chain.predict(
            question=inputs[self.input_key], callbacks=_run_manager.get_child()
        )
        _run_manager.on_text(t, color="green", verbose=self.verbose)
        t = t.strip()
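        # Delegate parsing to the prompt's output parser; the default PROMPT
        # expects the commands inside a fenced ```bash block.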
        try:
            parser = self.llm_chain.prompt.output_parser
            command_list = parser.parse(t)  # type: ignore[union-attr]
        except OutputParserException as e:
            _run_manager.on_chain_error(e, verbose=self.verbose)
            raise e

        if self.verbose:
            _run_manager.on_text("\nCode: ", verbose=self.verbose)
            _run_manager.on_text(
                str(command_list), color="yellow", verbose=self.verbose
            )
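        # Execute the parsed commands in this chain's BashProcess.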
        output = self.bash_process.run(command_list)
        _run_manager.on_text("\nAnswer: ", verbose=self.verbose)
        _run_manager.on_text(output, color="yellow", verbose=self.verbose)
        return {self.output_key: output}

    @property
    def _chain_type(self) -> str:
        return "llm_bash_chain"

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        prompt: BasePromptTemplate = PROMPT,
        **kwargs: Any,
    ) -> LLMBashChain:
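        """Create an LLMBashChain from a language model and prompt."""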
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        return cls(llm_chain=llm_chain, **kwargs)
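

# ---------------------------------------------------------------------------
# Usage sketch (illustrative; not part of the original module): a minimal way
# to exercise `LLMBashChain.from_llm` without a real model. `FakeListLLM` is
# langchain_community's scripted test LLM; the question and the echo command
# below are hypothetical placeholders. The scripted reply uses the fenced
# ```bash format that the default PROMPT's output parser extracts commands
# from.
if __name__ == "__main__":
    from langchain_community.llms import FakeListLLM

    # Scripted response containing the bash block the output parser extracts.
    fake_llm = FakeListLLM(responses=["```bash\necho 'hello world'\n```"])
    llm_bash = LLMBashChain.from_llm(fake_llm, verbose=True)
    result = llm_bash.invoke({"question": "Please print 'hello world'."})
    print(result["answer"])  # expected output: hello world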