from typing import (
    Any,
    List,
    Optional,
    Tuple,
    Type,
    TypedDict,
    Union,
)

from langchain_core.documents import Document
from langchain_core.runnables import (
    RunnableConfig,
    RunnableSerializable,
    ensure_config,
)

from langchain_ai21.ai21_base import AI21Base

ANSWER_NOT_IN_CONTEXT_RESPONSE = "Answer not in context"

ContextType = Union[str, List[Union[Document, str]]]


class ContextualAnswerInput(TypedDict):
    context: ContextType
    question: str


class AI21ContextualAnswers(
    RunnableSerializable[ContextualAnswerInput, str], AI21Base
):
    """Runnable for AI21's Contextual Answers task.

    Given a context (a string, or a list of strings and/or Documents) and a
    question, returns the model's answer grounded in that context, or the
    ``response_if_no_answer_found`` string when no answer is found.
    """

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @property
    def InputType(self) -> Type[ContextualAnswerInput]:
        """Get the input type for this runnable."""
        return ContextualAnswerInput

    @property
    def OutputType(self) -> Type[str]:
        """Get the input type for this runnable."""
        return str

    def invoke(
        self,
        input: ContextualAnswerInput,
        config: Optional[RunnableConfig] = None,
        response_if_no_answer_found: str = ANSWER_NOT_IN_CONTEXT_RESPONSE,
        **kwargs: Any,
    ) -> str:
        """Answer ``input["question"]`` using ``input["context"]``.

        Returns ``response_if_no_answer_found`` when the model reports that
        the answer is not present in the given context.
        """
        config = ensure_config(config)
        return self._call_with_config(
            func=lambda inner_input: self._call_contextual_answers(
                inner_input, response_if_no_answer_found
            ),
            input=input,
            config=config,
            run_type="llm",
        )

    def _call_contextual_answers(
        self,
        input: ContextualAnswerInput,
        response_if_no_answer_found: str,
    ) -> str:
        context, question = self._convert_input(input)
        response = self.client.answer.create(context=context, question=question)

        if response.answer is None:
            return response_if_no_answer_found

        return response.answer

    def _convert_input(self, input: ContextualAnswerInput) -> Tuple[str, str]:
        context, question = self._extract_context_and_question(input)

        context = self._parse_context(context)

        return context, question

    def _extract_context_and_question(
        self,
        input: ContextualAnswerInput,
    ) -> Tuple[ContextType, str]:
        context = input.get("context")
        question = input.get("question")

        if not context or not question:
            raise ValueError(
                f"Input must contain a 'context' and 'question' fields. Got {input}"
            )

        if not isinstance(context, (list, str)):
            raise ValueError(
                "Expected 'context' to be a string or a list of strings or"
                f" Documents. Received {type(context)}"
            )

        return context, question

    def _parse_context(self, context: ContextType) -> str:
        """Return the context as a single newline-joined string."""
        if isinstance(context, str):
            return context

        docs = [
            item.page_content if isinstance(item, Document) else item
            for item in context
        ]

        return "\n".join(docs)