File size: 9,183 Bytes
ed4d993
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
from __future__ import annotations  # allows pydantic model to reference itself

import re
from typing import Any, List, Optional, Union

from langchain_community.graphs.networkx_graph import NetworkxEntityGraph

from langchain_experimental.cpal.constants import Constant
from langchain_experimental.pydantic_v1 import (
    BaseModel,
    Field,
    PrivateAttr,
    root_validator,
    validator,
)


class NarrativeModel(BaseModel):
    """
    Narrative input as three story elements.
    """

    # The three textual parts of a narrative.
    story_outcome_question: str
    story_hypothetical: str
    story_plot: str  # causal stack of operations

    @validator("*", pre=True)
    def empty_str_to_none(cls, v: str) -> Union[str, None]:
        """Swap an empty string for ``None``; pass every other value through.

        Registered for all fields (``"*"``) as a ``pre`` validator. The
        stated intent of this hook is that empty strings are not allowed.

        Args:
            v: Raw value supplied for the field.

        Returns:
            ``None`` when ``v`` equals ``""``, otherwise ``v`` unchanged.
        """
        return None if v == "" else v


class EntityModel(BaseModel):
    """Entity in the story."""

    name: str = Field(description="entity name")
    code: str = Field(description="entity actions")
    value: float = Field(description="entity initial value")
    depends_on: List[str] = Field(default=[], description="ancestor entities")

    # TODO: generalize to multivariate math
    # TODO: acyclic graph

    class Config:
        # Run validation again whenever an attribute is assigned after creation.
        validate_assignment = True

    @validator("name")
    def lower_case_name(cls, v: str) -> str:
        """Normalize the entity name to lower case.

        Args:
            v: Entity name as provided.

        Returns:
            The lower-cased name.
        """
        return v.lower()


class CausalModel(BaseModel):
    """Casual data."""

    # NOTE(review): "Casual" in the docstring above looks like a typo for
    # "Causal". It is left untouched here because pydantic may surface the
    # class docstring as the schema description -- confirm before rewording.

    # Name of the attribute that is calculated for the entities.
    attribute: str = Field(description="name of the attribute to be calculated")
    # Entities (each an `EntityModel`) that take part in the story.
    entities: List[EntityModel] = Field(description="entities in the story")

    # TODO: root validate each `entity.depends_on` using system's entity names


class EntitySettingModel(BaseModel):
    """Entity initial conditions.

    Initial conditions for an entity

    {"name": "bud", "attribute": "pet_count", "value": 12}
    """

    name: str = Field(description="name of the entity")
    attribute: str = Field(description="name of the attribute to be calculated")
    value: float = Field(description="entity's attribute value (calculated)")

    @validator("name")
    def lower_case_transform(cls, v: str) -> str:
        """Normalize the entity name to lower case.

        Args:
            v: Entity name as provided.

        Returns:
            The lower-cased name.
        """
        return v.lower()


class SystemSettingModel(BaseModel):
    """System initial conditions.

    Initial global conditions for the system.

    {"parameter": "interest_rate", "value": .05}
    """

    # Name of the system-wide parameter (e.g. "interest_rate" in the example).
    parameter: str
    # Numeric value assigned to that parameter.
    value: float


class InterventionModel(BaseModel):
    """Intervention data of the story aka initial conditions.

    >>> intervention.dict()
    {
        entity_settings: [
            {"name": "bud", "attribute": "pet_count", "value": 12},
            {"name": "pat", "attribute": "pet_count", "value": 0},
        ],
        system_settings: None,
    }
    """

    # Per-entity initial conditions.
    entity_settings: List[EntitySettingModel]
    # System-wide settings; any non-None value is rejected by the validator below.
    system_settings: Optional[List[SystemSettingModel]] = None

    # NOTE(review): despite its name, this validator performs no lower-casing;
    # the name is kept as-is to avoid changing the class's attributes.
    @validator("system_settings")
    def lower_case_name(
        cls, v: Optional[List[SystemSettingModel]]
    ) -> Optional[List[SystemSettingModel]]:
        """Reject any provided ``system_settings`` value.

        Args:
            v: Value supplied for ``system_settings``.

        Returns:
            ``v`` unchanged when it is ``None``.

        Raises:
            NotImplementedError: If ``v`` is not ``None``.
        """
        if v is not None:
            raise NotImplementedError("system_setting is not implemented yet")
        return v


class QueryModel(BaseModel):
    """Query data of the story.

    translate a question about the story outcome into a programmatic expression"""

    # `question` is populated through the alias `Constant.narrative_input.value`.
    question: str = Field(alias=Constant.narrative_input.value)  # input
    # `StoryModel._run_query` passes `expression` to `duckdb.sql(...)`.
    expression: str  # output, part of llm completion
    # An empty string means "no error"; `StoryModel._run_query` raises otherwise.
    llm_error_msg: str  # output, part of llm completion
    # NOTE(review): annotated `str`, but `StoryModel._run_query` stores either
    # the DataFrame returned by the query or an error-message string here.
    _result_table: str = PrivateAttr()  # result of the executed query


class ResultModel(BaseModel):
    """Result of the story query."""

    # `question` is populated through the alias `Constant.narrative_input.value`.
    question: str = Field(alias=Constant.narrative_input.value)  # input
    # Private attribute (not a model field); declared without a default.
    _result_table: str = PrivateAttr()  # result of the executed query


class StoryModel(BaseModel):
    """Story data."""

    # The three story components. They are typed `Any`, but the methods below
    # use them as follows:
    #   * `causal_operations.entities` -- list of entities with `name`, `code`,
    #     `value` and `depends_on` attributes;
    #   * `intervention.entity_settings` -- list of settings with `name`/`value`;
    #   * `query` -- object with `question`, `expression`, `llm_error_msg` and a
    #     private `_result_table`.
    causal_operations: Any = Field(required=True)
    intervention: Any = Field(required=True)
    query: Any = Field(required=True)
    # Table of entity rows produced by `_forward_propagate` (a pandas DataFrame).
    _outcome_table: Any = PrivateAttr(default=None)
    # `NetworkxEntityGraph` holding the dependency edges built by `_make_graph`.
    _networkx_wrapper: Any = PrivateAttr(default=None)

    def __init__(self, **kwargs: Any):
        """Validate the story components, then immediately run the pipeline.

        Construction is not side-effect free: `_compute` mutates the entities
        in `causal_operations` and stores results on `query` and on the
        private attributes of this model.
        """
        super().__init__(**kwargs)
        self._compute()

        # TODO: when langchain adopts pydantic.v2 replace w/ `__post_init__`
        # misses hints github.com/pydantic/pydantic/issues/1729#issuecomment-1300576214

    # TODO: move away from `root_validator` since it is deprecated in pydantic v2
    #       and causes mypy type-checking failures (hence the `type: ignore`)
    @root_validator  # type: ignore[call-overload]
    def check_intervention_is_valid(cls, values: dict) -> dict:
        """Ensure every intervened entity exists in the causal model.

        Args:
            values: Field values of the model being validated.

        Returns:
            ``values`` unchanged.

        Raises:
            ValueError: If an entity setting names an entity that is not among
                ``causal_operations.entities``.
        """
        valid_names = [e.name for e in values["causal_operations"].entities]
        for setting in values["intervention"].entity_settings:
            if setting.name not in valid_names:
                error_msg = f"""
                    Hypothetical question has an invalid entity name.
                    `{setting.name}` not in `{valid_names}`
                """
                raise ValueError(error_msg)
        return values

    def _block_back_door_paths(self) -> None:
        """Detach intervened entities from their ancestors.

        Each entity named in the intervention gets an empty ``depends_on``
        list and its ``code`` replaced by ``"pass"``, so its value is no
        longer derived from other entities.
        """
        # stop intervention entities from depending on others
        intervention_entities = [
            entity_setting.name for entity_setting in self.intervention.entity_settings
        ]
        for entity in self.causal_operations.entities:
            if entity.name in intervention_entities:
                entity.depends_on = []
                entity.code = "pass"

    def _set_initial_conditions(self) -> None:
        """Copy each intervention value onto the entity with the same name."""
        for entity_setting in self.intervention.entity_settings:
            for entity in self.causal_operations.entities:
                if entity.name == entity_setting.name:
                    entity.value = entity_setting.value

    def _make_graph(self) -> None:
        """Build the dependency graph and drop entities that are not in it.

        An edge ``parent -> entity`` is added for every name in
        ``entity.depends_on``, carrying the entity's code as the ``relation``
        attribute.
        """
        self._networkx_wrapper = NetworkxEntityGraph()
        for entity in self.causal_operations.entities:
            for parent_name in entity.depends_on:
                self._networkx_wrapper._graph.add_edge(
                    parent_name, entity.name, relation=entity.code
                )

        # Compute the node set once instead of re-running the topological sort
        # for every entity in the filter below.
        graph_nodes = set(self._networkx_wrapper.get_topological_sort())

        # TODO: is it correct to drop entities with no impact on the outcome (?)
        self.causal_operations.entities = [
            entity
            for entity in self.causal_operations.entities
            if entity.name in graph_nodes
        ]

    def _sort_entities(self) -> None:
        """Order entities in place by their position in the topological sort."""
        # order the sequence of causal actions
        sorted_nodes = self._networkx_wrapper.get_topological_sort()
        # Precompute each node's rank so the sort key is a dict lookup rather
        # than a linear `list.index` scan per entity.
        position = {name: rank for rank, name in enumerate(sorted_nodes)}
        self.causal_operations.entities.sort(key=lambda x: position[x.name])

    def _forward_propagate(self) -> None:
        """Execute each entity's code in order and tabulate the outcome.

        The resulting rows (one ``entity.dict()`` per scope entry) are stored
        as a pandas DataFrame in ``_outcome_table``.

        Raises:
            ImportError: If pandas is not installed.
        """
        try:
            import pandas as pd
        except ImportError as e:
            raise ImportError(
                "Unable to import pandas, please install with `pip install pandas`."
            ) from e
        # Local namespace for `exec`: maps every entity name to its entity
        # object, so executed code can refer to entities by name.
        entity_scope = {
            entity.name: entity for entity in self.causal_operations.entities
        }
        for entity in self.causal_operations.entities:
            # "pass" marks an entity with nothing to execute (set for
            # intervened entities by `_block_back_door_paths`).
            if entity.code != "pass":
                # SECURITY: `entity.code` is executed verbatim with the module
                # globals. Only feed this model code from a trusted source.
                # gist.github.com/dean0x7d/df5ce97e4a1a05be4d56d1378726ff92
                exec(entity.code, globals(), entity_scope)
        row_values = [entity.dict() for entity in entity_scope.values()]
        self._outcome_table = pd.DataFrame(row_values)

    def _run_query(self) -> None:
        """Run the query expression against the outcome table.

        The result, or an error message when the query itself fails, is
        stored in ``self.query._result_table``.

        Raises:
            ValueError: If ``self.query.llm_error_msg`` is not empty.
            ImportError: If duckdb is not installed, or if an import fails
                while the query runs.
        """

        def humanize_sql_error_msg(error: str) -> str:
            """Turn a "column ... not found" SQL error into a friendlier message."""
            pattern = r"column\s+(.*?)\s+not found"
            col_match = re.search(pattern, error)
            if col_match:
                return (
                    "SQL error: "
                    + col_match.group(1)
                    + " is not an attribute in your story!"
                )
            else:
                return str(error)

        if self.query.llm_error_msg != "":
            msg = "LLM maybe failed to translate question to SQL query."
            raise ValueError(
                {
                    "question": self.query.question,
                    "llm_error_msg": self.query.llm_error_msg,
                    "msg": msg,
                }
            )

        # Import on its own: previously the import lived in the same `try` as
        # `except duckdb.BinderException`, so a failed import made that handler
        # reference the unbound name `duckdb` and raise UnboundLocalError
        # instead of the helpful ImportError below.
        try:
            import duckdb
        except ImportError as e:
            raise ImportError(
                "Unable to import duckdb, please install with `pip install duckdb`."
            ) from e

        # `df` must stay a local of this function: the SQL expression is
        # expected to reference it by name and duckdb resolves it from the
        # calling scope.
        df = self._outcome_table  # noqa
        try:
            query_result = duckdb.sql(self.query.expression).df()
            self.query._result_table = query_result
        except duckdb.BinderException as e:
            self.query._result_table = humanize_sql_error_msg(str(e))
        except ImportError:
            # Keep propagating import failures rather than recording them as
            # a query result.
            raise
        except Exception as e:
            # Best effort: any other query failure is reported as text.
            self.query._result_table = str(e)

    def _compute(self) -> Any:
        """Run the full pipeline; each step relies on the previous one."""
        self._block_back_door_paths()
        self._set_initial_conditions()
        self._make_graph()
        self._sort_entities()
        self._forward_propagate()
        self._run_query()

    def print_debug_report(self) -> None:
        """Pretty-print the outcome table, the query fields and the query result."""
        from pprint import pprint

        report = {
            "outcome": self._outcome_table,
            "query": self.query.dict(),
            "result": self.query._result_table,
        }
        pprint(report)