File size: 8,786 Bytes
ed4d993
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
"""Chain that makes API calls and summarizes the responses to answer a question."""
from __future__ import annotations

import json
from typing import Any, Dict, List, NamedTuple, Optional, cast

from langchain.chains.api.openapi.requests_chain import APIRequesterChain
from langchain.chains.api.openapi.response_chain import APIResponderChain
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain_core.callbacks import CallbackManagerForChainRun, Callbacks
from langchain_core.language_models import BaseLanguageModel
from langchain_core.pydantic_v1 import BaseModel, Field
from requests import Response

from langchain_community.tools.openapi.utils.api_models import APIOperation
from langchain_community.utilities.requests import Requests


class _ParamMapping(NamedTuple):
    """Mapping from parameter name to parameter value."""

    query_params: List[str]
    body_params: List[str]
    path_params: List[str]


class OpenAPIEndpointChain(Chain, BaseModel):
    """Chain interacts with an OpenAPI endpoint using natural language."""

    api_request_chain: LLMChain
    api_response_chain: Optional[LLMChain]
    api_operation: APIOperation
    requests: Requests = Field(exclude=True, default_factory=Requests)
    param_mapping: _ParamMapping = Field(alias="param_mapping")
    return_intermediate_steps: bool = False
    instructions_key: str = "instructions"  #: :meta private:
    output_key: str = "output"  #: :meta private:
    max_text_length: Optional[int] = Field(ge=0)  #: :meta private:

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return [self.instructions_key]

    @property
    def output_keys(self) -> List[str]:
        """Expect output key.

        :meta private:
        """
        if not self.return_intermediate_steps:
            return [self.output_key]
        else:
            return [self.output_key, "intermediate_steps"]

    def _construct_path(self, args: Dict[str, str]) -> str:
        """Construct the path from the deserialized input."""
        path = self.api_operation.base_url + self.api_operation.path
        for param in self.param_mapping.path_params:
            path = path.replace(f"{{{param}}}", str(args.pop(param, "")))
        return path

    def _extract_query_params(self, args: Dict[str, str]) -> Dict[str, str]:
        """Extract the query params from the deserialized input."""
        query_params = {}
        for param in self.param_mapping.query_params:
            if param in args:
                query_params[param] = args.pop(param)
        return query_params

    def _extract_body_params(self, args: Dict[str, str]) -> Optional[Dict[str, str]]:
        """Extract the request body params from the deserialized input."""
        body_params = None
        if self.param_mapping.body_params:
            body_params = {}
            for param in self.param_mapping.body_params:
                if param in args:
                    body_params[param] = args.pop(param)
        return body_params

    def deserialize_json_input(self, serialized_args: str) -> dict:
        """Use the serialized typescript dictionary.

        Resolve the path, query params dict, and optional requestBody dict.
        """
        args: dict = json.loads(serialized_args)
        path = self._construct_path(args)
        body_params = self._extract_body_params(args)
        query_params = self._extract_query_params(args)
        return {
            "url": path,
            "data": body_params,
            "params": query_params,
        }

    def _get_output(self, output: str, intermediate_steps: dict) -> dict:
        """Return the output from the API call."""
        if self.return_intermediate_steps:
            return {
                self.output_key: output,
                "intermediate_steps": intermediate_steps,
            }
        else:
            return {self.output_key: output}

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        intermediate_steps = {}
        instructions = inputs[self.instructions_key]
        instructions = instructions[: self.max_text_length]
        _api_arguments = self.api_request_chain.predict_and_parse(
            instructions=instructions, callbacks=_run_manager.get_child()
        )
        api_arguments = cast(str, _api_arguments)
        intermediate_steps["request_args"] = api_arguments
        _run_manager.on_text(
            api_arguments, color="green", end="\n", verbose=self.verbose
        )
        if api_arguments.startswith("ERROR"):
            return self._get_output(api_arguments, intermediate_steps)
        elif api_arguments.startswith("MESSAGE:"):
            return self._get_output(
                api_arguments[len("MESSAGE:") :], intermediate_steps
            )
        try:
            request_args = self.deserialize_json_input(api_arguments)
            method = getattr(self.requests, self.api_operation.method.value)
            api_response: Response = method(**request_args)
            if api_response.status_code != 200:
                method_str = str(self.api_operation.method.value)
                response_text = (
                    f"{api_response.status_code}: {api_response.reason}"
                    + f"\nFor {method_str.upper()}  {request_args['url']}\n"
                    + f"Called with args: {request_args['params']}"
                )
            else:
                response_text = api_response.text
        except Exception as e:
            response_text = f"Error with message {str(e)}"
        response_text = response_text[: self.max_text_length]
        intermediate_steps["response_text"] = response_text
        _run_manager.on_text(
            response_text, color="blue", end="\n", verbose=self.verbose
        )
        if self.api_response_chain is not None:
            _answer = self.api_response_chain.predict_and_parse(
                response=response_text,
                instructions=instructions,
                callbacks=_run_manager.get_child(),
            )
            answer = cast(str, _answer)
            _run_manager.on_text(answer, color="yellow", end="\n", verbose=self.verbose)
            return self._get_output(answer, intermediate_steps)
        else:
            return self._get_output(response_text, intermediate_steps)

    @classmethod
    def from_url_and_method(
        cls,
        spec_url: str,
        path: str,
        method: str,
        llm: BaseLanguageModel,
        requests: Optional[Requests] = None,
        return_intermediate_steps: bool = False,
        **kwargs: Any,
        # TODO: Handle async
    ) -> "OpenAPIEndpointChain":
        """Create an OpenAPIEndpoint from a spec at the specified url."""
        operation = APIOperation.from_openapi_url(spec_url, path, method)
        return cls.from_api_operation(
            operation,
            requests=requests,
            llm=llm,
            return_intermediate_steps=return_intermediate_steps,
            **kwargs,
        )

    @classmethod
    def from_api_operation(
        cls,
        operation: APIOperation,
        llm: BaseLanguageModel,
        requests: Optional[Requests] = None,
        verbose: bool = False,
        return_intermediate_steps: bool = False,
        raw_response: bool = False,
        callbacks: Callbacks = None,
        **kwargs: Any,
        # TODO: Handle async
    ) -> "OpenAPIEndpointChain":
        """Create an OpenAPIEndpointChain from an operation and a spec."""
        param_mapping = _ParamMapping(
            query_params=operation.query_params,
            body_params=operation.body_params,
            path_params=operation.path_params,
        )
        requests_chain = APIRequesterChain.from_llm_and_typescript(
            llm,
            typescript_definition=operation.to_typescript(),
            verbose=verbose,
            callbacks=callbacks,
        )
        if raw_response:
            response_chain = None
        else:
            response_chain = APIResponderChain.from_llm(
                llm, verbose=verbose, callbacks=callbacks
            )
        _requests = requests or Requests()
        return cls(
            api_request_chain=requests_chain,
            api_response_chain=response_chain,
            api_operation=operation,
            requests=_requests,
            param_mapping=param_mapping,
            verbose=verbose,
            return_intermediate_steps=return_intermediate_steps,
            callbacks=callbacks,
            **kwargs,
        )