File size: 11,517 Bytes
f688422
d21cd8b
41f3d00
 
f688422
84bdc0f
 
f688422
d21cd8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc61879
 
41f3d00
f688422
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33d270b
 
 
 
f688422
 
33d270b
 
f688422
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d21cd8b
 
e441653
33d270b
84bdc0f
105935e
d21cd8b
105935e
 
 
 
 
 
 
 
 
 
 
 
 
f688422
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d21cd8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105935e
 
bc61879
41f3d00
 
d21cd8b
 
 
 
 
 
 
bc61879
f688422
d21cd8b
 
 
 
 
 
 
bc61879
d21cd8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f688422
d21cd8b
 
 
 
 
 
 
bc61879
d21cd8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f688422
d21cd8b
 
 
 
 
 
f688422
d21cd8b
 
 
 
 
 
 
 
 
 
f688422
d21cd8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41f3d00
 
d21cd8b
41f3d00
d21cd8b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
import gradio as gr
from typing import TypedDict, List, Optional
import os
import pandas as pd

from climateqa.engine.talk_to_data.main import ask_drias
from climateqa.engine.talk_to_data.config import DRIAS_MODELS, DRIAS_UI_TEXT

# Typed mapping of every component created by ``create_drias_ui``. The event
# wiring looks components up by these string keys, so key names must match.
DriasUIElements = TypedDict(
    "DriasUIElements",
    {
        # Tab container and its explanatory accordion
        "tab": gr.Tab,
        "details_accordion": gr.Accordion,
        # Example questions and the hidden textbox they write into
        "examples_hidden": gr.Textbox,
        "examples": gr.Examples,
        # Free-text question input and the "no result" message box
        "drias_direct_question": gr.Textbox,
        "result_text": gr.Textbox,
        # Clickable list of relevant indicators (one column per result)
        "table_names_display": gr.DataFrame,
        # Generated SQL for the displayed result
        "query_accordion": gr.Accordion,
        "drias_sql_query": gr.Textbox,
        # Chart with its model filter
        "chart_accordion": gr.Accordion,
        "model_selection": gr.Dropdown,
        "drias_display": gr.Plot,
        # Raw data behind the chart
        "table_accordion": gr.Accordion,
        "drias_table": gr.DataFrame,
        # Pagination label and navigation buttons
        "pagination_display": gr.Markdown,
        "prev_button": gr.Button,
        "next_button": gr.Button,
    },
)


async def ask_drias_query(query: str, index_state: int, user_id: str):
    """Forward a question to the DRIAS talk-to-data engine.

    Args:
        query: Natural-language question typed or picked by the user.
        index_state: Current result index, passed through to ``ask_drias``.
        user_id: Identifier of the requesting user, passed through unchanged.

    Returns:
        Whatever ``ask_drias`` returns, untouched. The event wiring unpacks it
        into the result components and state holders.
    """
    return await ask_drias(query, index_state, user_id)


def show_results(sql_queries_state, dataframes_state, plots_state):
    """Compute visibility updates for the result components after a query.

    Args:
        sql_queries_state: SQL queries produced by the last question.
        dataframes_state: Result DataFrames produced by the last question.
        plots_state: Plot builders produced by the last question.

    Returns:
        A tuple of eight ``gr.update`` objects.
        - The first targets the "no result" text and is visible only when there
          is nothing to show.
        - The remaining seven target the result components and are visible only
          when there are results.
    """
    # Results count as present only when NONE of the three lists is empty;
    # a single empty list already switches to the "no result" view.
    has_results = all((sql_queries_state, dataframes_state, plots_state))
    no_result_update = gr.update(visible=not has_results)
    result_updates = tuple(gr.update(visible=has_results) for _ in range(7))
    return (no_result_update, *result_updates)


def filter_by_model(dataframes, figures, index_state, model_selection):
    """Filter the currently displayed result by model and rebuild its chart.

    Args:
        dataframes: Result DataFrames, one per result page.
        figures: Callables parallel to ``dataframes``; each builds a figure
            from a DataFrame.
        index_state: Index of the result page currently displayed.
        model_selection: Value of the model dropdown. ``"ALL"`` disables
            filtering.

    Returns:
        ``(df, figure)``: the (possibly filtered) DataFrame and its figure.
        ``figure`` is ``None`` when there is nothing to plot.
    """
    # The dropdown can change before any question was asked, or after a query
    # that returned nothing. Indexing the empty state used to raise IndexError.
    if not dataframes or not 0 <= index_state < len(dataframes):
        return pd.DataFrame(), None

    df = dataframes[index_state]
    if df.empty:
        return df, None

    # Only results that carry a per-model column can be filtered.
    if "model" in df.columns and model_selection != "ALL":
        df = df[df["model"] == model_selection]
        if df.empty:
            return df, None

    return df, figures[index_state](df)


def update_pagination(index, sql_queries):
    """Return the ``"current/total"`` pagination label.

    Args:
        index: Zero-based index of the displayed result.
        sql_queries: All SQL queries of the last question (may be empty/None).

    Returns:
        ``"<index+1>/<count>"``, or an empty string when there are no queries.
    """
    if not sql_queries:
        return ""
    return f"{index + 1}/{len(sql_queries)}"


def show_previous(index, sql_queries, dataframes, plots):
    """Move one result page back, staying put on the first page.

    Args:
        index: Currently displayed page index.
        sql_queries: SQL query per page.
        dataframes: Result DataFrame per page.
        plots: Figure-building callable per page.

    Returns:
        ``(sql, dataframe, figure, new_index)`` for the page now displayed.
    """
    target = index - 1 if index > 0 else index
    page_sql = sql_queries[target]
    page_df = dataframes[target]
    return page_sql, page_df, plots[target](page_df), target


def show_next(index, sql_queries, dataframes, plots):
    """Move one result page forward, staying put on the last page.

    Args:
        index: Currently displayed page index.
        sql_queries: SQL query per page.
        dataframes: Result DataFrame per page.
        plots: Figure-building callable per page.

    Returns:
        ``(sql, dataframe, figure, new_index)`` for the page now displayed.
    """
    last_page = len(sql_queries) - 1
    target = index + 1 if index < last_page else index
    page_sql = sql_queries[target]
    page_df = dataframes[target]
    return page_sql, page_df, plots[target](page_df), target


def display_table_names(table_names):
    """Wrap the indicator names as the single row of a table.

    The result is rendered into a ``gr.DataFrame``, which takes a list of rows.
    NOTE(review): ``table_names`` is presumably a list, so each name becomes
    one cell/column; confirm against ``ask_drias``.
    """
    single_row = [table_names]
    return single_row


def on_table_click(evt: gr.SelectData, table_names, sql_queries, dataframes, plots):
    """Display the result matching the clicked indicator cell.

    The ``gr.SelectData`` annotation is what makes Gradio inject the selection
    event, so it must stay. The clicked cell's column index (``evt.index[1]``)
    is used as the result index. ``table_names`` is received but not used.

    Returns:
        ``(sql, dataframe, figure, index)`` for the selected result.
    """
    column = evt.index[1]
    chart = plots[column](dataframes[column])
    return (
        sql_queries[column],
        dataframes[column],
        chart,
        column,
    )


def create_drias_ui() -> DriasUIElements:
    """Create and return all UI elements for the DRIAS tab.

    Lays out the following inside a new ``gr.Tab`` titled "France - Talk to DRIAS":
    - the "Details" accordion,
    - the example questions,
    - the direct-question textbox,
    - the result components.

    Only ``result_text`` is visible initially. The indicator list, the SQL,
    chart and data accordions, the pagination label and the navigation
    buttons are all created hidden. No event handlers are attached here.

    Returns:
        DriasUIElements: mapping from element name to the created component.
    """
    with gr.Tab("France - Talk to DRIAS", elem_id="tab-vanna", id=6) as tab:
        # Static explanatory text for the tab
        with gr.Accordion(label="Details") as details_accordion:
            gr.Markdown(DRIAS_UI_TEXT)

        # Add examples for common questions.
        # Clicking an example fills the hidden textbox below
        # (inputs=[examples_hidden]). The query itself is triggered by that
        # textbox's change event, wired in setup_drias_events.
        examples_hidden = gr.Textbox(visible=False, elem_id="drias-examples-hidden")
        examples = gr.Examples(
            examples=[
                ["What will the temperature be like in Paris?"],
                ["What will be the total rainfall in France in 2030?"],
                ["How frequent will extreme events be in Lyon?"],
                ["Comment va évoluer la température en France entre 2030 et 2050 ?"]
            ],
            label="Example Questions",
            inputs=[examples_hidden],
            outputs=[examples_hidden],
        )

        # Free-text question input
        with gr.Row():
            drias_direct_question = gr.Textbox(
                label="Direct Question",
                placeholder="You can write direct question here",
                elem_id="direct-question",
                interactive=True,
            )

        # Message box (elem_id "no-result-label"); the only result component visible at start
        result_text = gr.Textbox(
            label="", elem_id="no-result-label", interactive=False, visible=True
        )

        # Clickable list of relevant indicators, hidden until results arrive
        table_names_display = gr.DataFrame(
            [], label="List of relevant indicators", headers=None, interactive=False, elem_id="table-names", visible=False
        )

        # Generated SQL for the displayed result
        with gr.Accordion(label="SQL Query Used", visible=False) as query_accordion:
            drias_sql_query = gr.Textbox(
                label="", elem_id="sql-query", interactive=False
            )

        # Chart plus a model filter dropdown defaulting to "ALL"
        with gr.Accordion(label="Chart", visible=False) as chart_accordion:
            model_selection = gr.Dropdown(
                label="Model", choices=DRIAS_MODELS, value="ALL", interactive=True
            )
            drias_display = gr.Plot(elem_id="vanna-plot")

        # Raw data behind the chart; collapsed by default
        with gr.Accordion(
            label="Data used", open=False, visible=False
        ) as table_accordion:
            drias_table = gr.DataFrame([], elem_id="vanna-table")

        # "current/total" label for paging between results
        pagination_display = gr.Markdown(
            value="", visible=False, elem_id="pagination-display"
        )

        with gr.Row():
            prev_button = gr.Button("Previous", visible=False)
            next_button = gr.Button("Next", visible=False)

        return DriasUIElements(
            tab=tab,
            details_accordion=details_accordion,
            examples_hidden=examples_hidden,
            examples=examples,
            drias_direct_question=drias_direct_question,
            result_text=result_text,
            table_names_display=table_names_display,
            query_accordion=query_accordion,
            drias_sql_query=drias_sql_query,
            chart_accordion=chart_accordion,
            model_selection=model_selection,
            drias_display=drias_display,
            table_accordion=table_accordion,
            drias_table=drias_table,
            pagination_display=pagination_display,
            prev_button=prev_button,
            next_button=next_button
        )



def setup_drias_events(ui_elements: DriasUIElements, share_client=None, user_id=None) -> None:
    """Set up all event handlers for the DRIAS tab.

    There are two ways to ask a question: picking an example (the hidden
    textbox changes) or submitting the direct-question textbox. Both run the
    same pipeline:
    1. collapse the details accordion,
    2. query DRIAS,
    3. toggle result visibility,
    4. refresh the pagination label,
    5. fill the indicator list.

    The model dropdown re-filters the current result. The Previous/Next
    buttons and clicks on the indicator list switch between results, each
    followed by a pagination refresh.

    Args:
        ui_elements: Components returned by ``create_drias_ui``.
        share_client: Currently unused; kept for interface compatibility.
        user_id: Initial value of the per-session user-id state forwarded to
            ``ask_drias_query``.
    """
    # Per-session state shared between handlers
    sql_queries_state = gr.State([])
    dataframes_state = gr.State([])
    plots_state = gr.State([])
    index_state = gr.State(0)
    table_names_list = gr.State([])
    # Separate name so the ``user_id`` argument is not shadowed
    user_id_state = gr.State(user_id)

    def _page_outputs():
        """Fresh list of the components that display one result page."""
        return [
            ui_elements["drias_sql_query"],
            ui_elements["drias_table"],
            ui_elements["drias_display"],
        ]

    def _then_update_pagination(step):
        """Chain a refresh of the pagination label after *step*."""
        return step.then(
            update_pagination,
            inputs=[index_state, sql_queries_state],
            outputs=[ui_elements["pagination_display"]],
        )

    def _chain_query_pipeline(first_step, question_input):
        """Chain query -> visibility -> pagination -> indicator list after *first_step*."""
        after_query = first_step.then(
            ask_drias_query,
            inputs=[question_input, index_state, user_id_state],
            outputs=[
                *_page_outputs(),
                sql_queries_state,
                dataframes_state,
                plots_state,
                index_state,
                table_names_list,
                ui_elements["result_text"],
            ],
        ).then(
            show_results,
            inputs=[sql_queries_state, dataframes_state, plots_state],
            outputs=[
                ui_elements["result_text"],
                ui_elements["query_accordion"],
                ui_elements["table_accordion"],
                ui_elements["chart_accordion"],
                ui_elements["prev_button"],
                ui_elements["next_button"],
                ui_elements["pagination_display"],
                ui_elements["table_names_display"],
            ],
        )
        return _then_update_pagination(after_query).then(
            display_table_names,
            inputs=[table_names_list],
            outputs=[ui_elements["table_names_display"]],
        )

    # Example selection: collapse details, mirror the example into the
    # direct-question box, then run the query on the example text.
    _chain_query_pipeline(
        ui_elements["examples_hidden"].change(
            lambda x: (gr.Accordion(open=False), gr.Textbox(value=x)),
            inputs=[ui_elements["examples_hidden"]],
            outputs=[ui_elements["details_accordion"], ui_elements["drias_direct_question"]],
        ),
        ui_elements["examples_hidden"],
    )

    # Direct question submission: collapse details, then run the query.
    _chain_query_pipeline(
        ui_elements["drias_direct_question"].submit(
            lambda: gr.Accordion(open=False),
            inputs=None,
            outputs=[ui_elements["details_accordion"]],
        ),
        ui_elements["drias_direct_question"],
    )

    # Model selection re-filters the currently displayed result
    ui_elements["model_selection"].change(
        filter_by_model,
        inputs=[dataframes_state, plots_state, index_state, ui_elements["model_selection"]],
        outputs=[ui_elements["drias_table"], ui_elements["drias_display"]],
    )

    # Pagination buttons
    _then_update_pagination(
        ui_elements["prev_button"].click(
            show_previous,
            inputs=[index_state, sql_queries_state, dataframes_state, plots_state],
            outputs=[*_page_outputs(), index_state],
        )
    )
    _then_update_pagination(
        ui_elements["next_button"].click(
            show_next,
            inputs=[index_state, sql_queries_state, dataframes_state, plots_state],
            outputs=[*_page_outputs(), index_state],
        )
    )

    # Clicking an indicator in the list jumps to its result
    _then_update_pagination(
        ui_elements["table_names_display"].select(
            fn=on_table_click,
            inputs=[table_names_list, sql_queries_state, dataframes_state, plots_state],
            outputs=[*_page_outputs(), index_state],
        )
    )

def create_drias_tab(share_client=None, user_id=None):
    """Build the DRIAS tab: lay out its components, then attach handlers.

    NOTE(review): presumably called inside an enclosing ``gr.Blocks`` context,
    since it creates Gradio components and states. Confirm at the call site.

    Args:
        share_client: Forwarded unchanged to ``setup_drias_events``.
        user_id: Forwarded unchanged to ``setup_drias_events``.
    """
    setup_drias_events(create_drias_ui(), share_client=share_client, user_id=user_id)