from __future__ import annotations

import io
import json
import time
from typing import Any, Dict, Iterator, List, Optional, Tuple

from langchain_core.documents import Document

from langchain_community.document_loaders.base import BaseLoader


class AthenaLoader(BaseLoader):
    """Load documents from `AWS Athena`.

    Each document represents one row of the query result.
    - By default, all columns are written into the document's `page_content`
      and none into its `metadata`.
    - If `metadata_columns` are provided, those columns are written into the
      document's `metadata` while the remaining columns are written into
      `page_content`.

    To authenticate, the AWS client uses the standard boto3 credential
    resolution chain described here:
    https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html

    If a specific credential profile should be used, pass the name of the
    profile from the ~/.aws/credentials file via `profile_name`.

    Make sure the credentials / roles used have the required policies to
    access Amazon Athena and the S3 location where query results are written.
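
    Example (the query, database, bucket, and profile below are placeholder
    values, not defaults of the loader):

        .. code-block:: python

            from langchain_community.document_loaders import AthenaLoader

            loader = AthenaLoader(
                query="SELECT * FROM my_table LIMIT 10",
                database="my_database",
                s3_output_uri="s3://my-bucket/athena-results/",
                profile_name="my-profile",
                metadata_columns=["id"],
            )
            documents = loader.load()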
    """

    def __init__(
        self,
        query: str,
        database: str,
        s3_output_uri: str,
        profile_name: Optional[str] = None,
        metadata_columns: Optional[List[str]] = None,
    ):
        """Initialize Athena document loader.

        Args:
            query: The query to run in Athena.
            database: Athena database
            s3_output_uri: Athena output path
            metadata_columns: Optional. Columns written to Document `metadata`.
        """
        self.query = query
        self.database = database
        self.s3_output_uri = s3_output_uri
        self.metadata_columns = metadata_columns if metadata_columns is not None else []

        try:
            import boto3
        except ImportError:
            raise ImportError(
                "Could not import boto3 python package. "
                "Please install it with `pip install boto3`."
            )

        try:
            session = (
                boto3.Session(profile_name=profile_name)
                if profile_name is not None
                else boto3.Session()
            )
        except Exception as e:
            raise ValueError(
                "Could not load credentials to authenticate with AWS client. "
                "Please check that credentials in the specified "
                "profile name are valid."
            ) from e

        self.athena_client = session.client("athena")
        self.s3_client = session.client("s3")

    def _execute_query(self) -> List[Dict[str, Any]]:
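        """Run the query in Athena and return the result rows as dicts."""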
        response = self.athena_client.start_query_execution(
            QueryString=self.query,
            QueryExecutionContext={"Database": self.database},
            ResultConfiguration={"OutputLocation": self.s3_output_uri},
        )
        query_execution_id = response["QueryExecutionId"]
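        # Poll until the query reaches a terminal state
        # (SUCCEEDED, FAILED, or CANCELLED).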
        while True:
            response = self.athena_client.get_query_execution(
                QueryExecutionId=query_execution_id
            )
            state = response["QueryExecution"]["Status"]["State"]
            if state == "SUCCEEDED":
                break
            elif state == "FAILED":
                resp_status = response["QueryExecution"]["Status"]
                state_change_reason = resp_status["StateChangeReason"]
                err = f"Query Failed: {state_change_reason}"
                raise Exception(err)
            elif state == "CANCELLED":
                raise Exception("Query was cancelled by the user.")
            time.sleep(1)

        result_set = self._get_result_set(query_execution_id)
        return json.loads(result_set.to_json(orient="records"))

    def _remove_suffix(self, input_string: str, suffix: str) -> str:
        if suffix and input_string.endswith(suffix):
            return input_string[: -len(suffix)]
        return input_string

    def _remove_prefix(self, input_string: str, prefix: str) -> str:
        if prefix and input_string.startswith(prefix):
            return input_string[len(prefix) :]
        return input_string

    def _get_result_set(self, query_execution_id: str) -> Any:
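        """Fetch the query result CSV from S3 as a pandas DataFrame."""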
        try:
            import pandas as pd
        except ImportError:
            raise ImportError(
                "Could not import pandas python package. "
                "Please install it with `pip install pandas`."
            )

        output_uri = self.s3_output_uri
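        # Athena writes the result set to the output location as
        # "<QueryExecutionId>.csv"; derive the S3 bucket and key from that URI.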
        tokens = self._remove_prefix(
            self._remove_suffix(output_uri, "/"), "s3://"
        ).split("/")
        bucket = tokens[0]
        key = "/".join(tokens[1:] + [query_execution_id]) + ".csv"

        obj = self.s3_client.get_object(Bucket=bucket, Key=key)
        df = pd.read_csv(io.BytesIO(obj["Body"].read()), encoding="utf8")
        return df

    def _get_columns(
        self, query_result: List[Dict[str, Any]]
    ) -> Tuple[List[str], List[str]]:
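        """Split result columns into content and metadata columns."""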
        content_columns = []
        metadata_columns = []
        all_columns = list(query_result[0].keys())
        for key in all_columns:
            if key in self.metadata_columns:
                metadata_columns.append(key)
            else:
                content_columns.append(key)

        return content_columns, metadata_columns

    def lazy_load(self) -> Iterator[Document]:
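        """Lazily load documents, one per row of the query result."""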
        query_result = self._execute_query()
        content_columns, metadata_columns = self._get_columns(query_result)
        for row in query_result:
            page_content = "\n".join(
                f"{k}: {v}" for k, v in row.items() if k in content_columns
            )
            metadata = {
                k: v for k, v in row.items() if k in metadata_columns and v is not None
            }
            doc = Document(page_content=page_content, metadata=metadata)
            yield doc