File size: 2,067 Bytes
ed4d993
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""Util that calls Dataherald."""
from typing import Any, Dict, Optional

from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
from langchain_core.utils import get_from_dict_or_env


class DataheraldAPIWrapper(BaseModel):
    """Wrapper for Dataherald.

    Docs for using:

    1. Go to dataherald and sign up
    2. Create an API key
    3. Save your API key into DATAHERALD_API_KEY env variable
    4. pip install dataherald

    Attributes:
        dataherald_client: ``dataherald.Dataherald`` client built by
            ``validate_environment`` from the resolved API key; internal.
        db_connection_id: Database connection id sent with every prompt
            issued by ``run``.
        dataherald_api_key: API key; when not given it is looked up in the
            ``DATAHERALD_API_KEY`` environment variable.
    """

    dataherald_client: Any  #: :meta private:
    db_connection_id: str
    dataherald_api_key: Optional[str] = None

    class Config:
        """Configuration for this pydantic object."""

        # Reject unknown constructor kwargs so typos surface immediately.
        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment.

        Resolves the API key from ``values`` (falling back to the
        ``DATAHERALD_API_KEY`` environment variable), stores it back under
        ``dataherald_api_key``, and builds the Dataherald client, stored under
        ``dataherald_client``.

        Raises:
            ImportError: If the ``dataherald`` package is not installed.
        """
        dataherald_api_key = get_from_dict_or_env(
            values, "dataherald_api_key", "DATAHERALD_API_KEY"
        )
        values["dataherald_api_key"] = dataherald_api_key

        # Imported lazily so the module can be imported without the optional
        # SDK installed; the error below tells the user how to fix it.
        try:
            import dataherald
        except ImportError as e:
            raise ImportError(
                "dataherald is not installed. "
                "Please install it with `pip install dataherald`"
            ) from e

        client = dataherald.Dataherald(api_key=dataherald_api_key)
        values["dataherald_client"] = client

        return values

    def run(self, prompt: str) -> str:
        """Generate a sql query through Dataherald and parse result.

        Args:
            prompt: Text of the prompt sent to Dataherald's SQL generation
                endpoint for the configured ``db_connection_id``.

        Returns:
            ``"Answer: <sql>"`` when Dataherald returns a non-empty ``sql``
            value, otherwise ``"No answer"``.
        """
        from dataherald.types.sql_generation_create_params import Prompt

        prompt_obj = Prompt(text=prompt, db_connection_id=self.db_connection_id)
        res = self.dataherald_client.sql_generations.create(prompt=prompt_obj)

        answer = res.sql
        # A missing/empty ``sql`` means no query was generated; report that
        # explicitly rather than emitting an empty "Answer: ".
        if not answer:
            return "No answer"
        return f"Answer: {answer}"