import os
import time
import gradio as gr
import pandas as pd

from classifier import classify
from statistics import mean
from qa_summary import generate_answer


HFTOKEN = os.environ["HF_TOKEN"]



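# JavaScript run once on app load: injects Twitter's widgets.js and defines a global
# reloadTwitterWidgets() helper used to re-render embedded tweets after UI updates.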
loadTwitterWidgets_js = """
    async () => {
        // Load Twitter Widgets script
        const script = document.createElement("script");
        script.onload = () => console.log("Twitter Widgets.js loaded");
        script.src = "https://platform.twitter.com/widgets.js";
        document.head.appendChild(script);

        // Define a global function to reload Twitter widgets
        globalThis.reloadTwitterWidgets = () => {
            // Reload Twitter widgets
            if (window.twttr && twttr.widgets) {
                twttr.widgets.load();
            }
                
        };
    }
"""

def T_on_select(evt: gr.SelectData):
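    """Return the value of the selected dataframe cell (used to grab a tweet ID)."""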
    return evt.value 

def single_classification(text, event_model, threshold):
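    """Classify a single text with the selected model and return its label and score."""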
    res = classify(text, event_model, HFTOKEN, threshold)
    return res["event"], res["score"]
    
def load_and_classify_csv(file, text_field, event_model, threshold):
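    """Load a CSV/TSV file, classify each post, and return per-label checkbox groups,
    the mean model score, the post count, the labelled dataframe, and UI updates."""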
    text_field = text_field.strip()
    filepath = file.name
    if ".csv" in filepath:
        df = pd.read_csv(filepath)
    else:
        df = pd.read_table(filepath)
    
    if text_field not in df.columns:
        raise gr.Error(f"Error: text column '{text_field}' not found in the uploaded file.")

    labels, scores = [], []
    for post in df[text_field].to_list():
        res = classify(post, event_model, HFTOKEN, threshold)
        labels.append(res["event"])
        scores.append(res["score"])

    df["event_label"] = labels
    df["model_score"] = scores
    
    # model_confidence = round(mean(scores), 5)   
    model_confidence = mean(scores)
    fire_related = gr.CheckboxGroup(choices=df[df["event_label"]=="fire"][text_field].to_list())
    flood_related = gr.CheckboxGroup(choices=df[df["event_label"]=="flood"][text_field].to_list())
    not_related = gr.CheckboxGroup(choices=df[df["event_label"]=="none"][text_field].to_list())
    
    return flood_related, fire_related, not_related, model_confidence, len(df[text_field].to_list()), df, gr.update(interactive=True), gr.update(interactive=True)

def load_and_classify_csv_dataframe(file, text_field, event_model, threshold): 
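    """Load a CSV/TSV file, classify each post, and return a labelled dataframe
    together with the label-filter dropdown choices and UI updates."""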
    text_field = text_field.strip()
    filepath = file.name
    if ".csv" in filepath:
        df = pd.read_csv(filepath)
    else:
        df = pd.read_table(filepath)
    
    if text_field not in df.columns:
        raise gr.Error(f"Error: text column '{text_field}' not found in the uploaded file.")

    labels, scores = [], []
    for post in df[text_field].to_list():
        res = classify(post, event_model, HFTOKEN, threshold)
        labels.append(res["event"])
        scores.append(round(res["score"], 5))

    df["event_label"] = labels
    df["model_score"] = scores
    
    result_df = df[[text_field, "event_label", "model_score", "tweet_id"]].copy()
    result_df["tweet_id"] = result_df["tweet_id"].astype(str)
    
    filters = list(result_df["event_label"].unique())
    extra_filters = ['Not-'+x for x in filters]+['All']
    
    return (result_df, result_df,
            gr.update(choices=sorted(filters + extra_filters),
                      value='All',
                      label="Filter data by label",
                      visible=True),
            gr.update(interactive=True), gr.update(interactive=True))


def calculate_accuracy(flood_selections, fire_selections, none_selections, num_posts, text_field, data_df):
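    """Compute accuracy from the posts flagged as incorrect, write the evaluated
    dataframe to output.csv, and return counts, accuracy, and a download button."""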
    text_field = text_field.strip()
    posts = data_df[text_field].to_list()
    selections = flood_selections + fire_selections + none_selections
    evaluations = []
    for post in posts:
        if post in selections:
            evaluations.append("incorrect")
        else:
            evaluations.append("correct")

    data_df["model_eval"] = evaluations
    incorrect = len(selections)
    correct = num_posts - incorrect
    accuracy = (correct/num_posts)*100

    data_df.to_csv("output.csv")
    return incorrect, correct, accuracy, data_df, gr.DownloadButton(label="Download CSV", value="output.csv", visible=True)
    
def init_queries(history):
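    """Initialise the query checkbox group with the predefined questions on first use."""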
    history = history or []
    if not history:
        history = [
        "What areas are being evacuated?",
        "What areas are predicted to be impacted?",
        "What areas are without power?",
        "What barriers are hindering response efforts?",
        "What events have been canceled?",
        "What preparations are being made?",
        "What regions have announced a state of emergency?",
        "What roads are blocked / closed?",
        "What services have been closed?",
        "What warnings are currently in effect?",
        "Where are emergency services deployed?",
        "Where are emergency services needed?",
        "Where are evacuations needed?",
        "Where are people needing rescue?",
        "Where are recovery efforts taking place?",
        "Where has building or infrastructure damage occurred?",
        "Where has flooding occurred?",
        "Where are volunteers being requested?",
        "Where has road damage occurred?",
        "What area has the wildfire burned?",
        "Where have homes been damaged or destroyed?"]
    
    return gr.CheckboxGroup(choices=history), history

def add_query(to_add, history):
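    """Append a custom query to the history and refresh the checkbox choices."""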
    if to_add not in history:
        history.append(to_add)
    return gr.CheckboxGroup(choices=history), history

#def qa_summarise(selected_queries, qa_llm_model, text_field, data_df):
def qa_summarise(selected_queries, qa_llm_model, text_field, response_lang, data_df):
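    """Run RAG-style QA over the posts not labelled 'none' and return the summary
    plus a numbered dataframe of the source texts and their tweet IDs."""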
    
    if not selected_queries:
        raise gr.Error("Error: select at least one query to ask.")
    
    qa_input_df = data_df[data_df["event_label"] != "none"].reset_index()
    texts = qa_input_df[text_field].to_list()
    
    # summary = generate_answer(qa_llm_model, texts, selected_queries[0], selected_queries, mode="multi_summarize")

    summary = generate_answer(qa_llm_model, 
                              texts, 
                              selected_queries[0], 
                              selected_queries, 
                              response_lang, 
                              mode="multi_summarize")
    
    doc_df = pd.DataFrame()
    doc_df["number"] = [i+1 for i in range(len(texts))]
    doc_df["text"] = texts
    doc_df["IDs"] = qa_input_df["tweet_id"].to_list()
    
    return summary, doc_df

            
with gr.Blocks(fill_width=True) as demo:
    demo.load(None,None,None,js=loadTwitterWidgets_js)
    
    event_models = ["jayebaku/XLMRoberta-twitter-crexdata-flood-wildfire-detector",
                    "jayebaku/distilbert-base-multilingual-cased-crexdata-relevance-classifier",]
    
    T_data_ss_state = gr.State(value=pd.DataFrame())
    

    with gr.Tab("Single Text Classification"):
        gr.Markdown(
        """
        # Single Text Classifier Demo
        In this section you can test the relevance classifier on your own text.\n
        Usage:\n
            - Type a tweet-like text in the textbox.\n
            - Then press Enter.\n
        """)
        with gr.Group():
            with gr.Row():
                with gr.Column(scale=3):
                    model_sing_classify = gr.Dropdown(event_models, value=event_models[0], label="Select classification model")
                with gr.Column(scale=7):
                    with gr.Accordion("Prediction threshold", open=False):        
                        threshold_sing_classify = gr.Slider(0, 1, value=0, step=0.01, label="Prediction threshold", show_label=False, 
                                          info="This value sets the threshold at which texts classified as flood or fire are accepted; \
                                              higher values make the classifier stricter (CAUTION: a value of 1 will set all predictions to none)", interactive=True)
                
        text_to_classify = gr.Textbox(label="Text", info="Enter tweet-like text", submit_btn=True)
        text_to_classify_examples = gr.Examples([["The streets are flooded, I can't leave #BostonStorm"], 
                                                 ["Controlado el incendio de Rodezno que ha obligado a desalojar a varias bodegas de la zona."], 
                                                 ["Cambrils:estació Renfe inundada 19 persones dins d'un tren. FGC a Capellades, petit descarrilament 5 passatgers #Inuncat @emergenciescat"],
                                                 ["Anscheinend steht die komplette Neckarwiese unter Wasser! #Hochwasser"]], text_to_classify)

        with gr.Group():
            with gr.Row():
                with gr.Column():
                    classification = gr.Textbox(label="Classification")
                with gr.Column():
                    classification_score = gr.Number(label="Classification Score")
                    
    
    with gr.Tab("Event Type Classification"):
        gr.Markdown(
        """
        # Relevance Classifier Demo 
        This is a demo created to explore flood and wildfire classification in social media posts.\n
        Upload a .tsv or .csv data file (it must contain a text column with social media posts), then enter the name of the text column, choose a classifier model, and click 'Start Prediction'.
        """)
        with gr.Group():
            with gr.Row(equal_height=True):
                with gr.Column():
                    T_file_input = gr.File(label="Upload CSV or TSV File", file_types=['.tsv', '.csv']) 
                with gr.Column():
                    T_text_field = gr.Textbox(label="Text field name", value="tweet_text")
                    T_event_model = gr.Dropdown(event_models, value=event_models[0], label="Select classification model")
                    with gr.Accordion("Prediction threshold", open=False):        
                        T_threshold = gr.Slider(0, 1, value=0, step=0.01, label="Prediction threshold", show_label=False, 
                                          info="This value sets the threshold at which texts classified as flood or fire are accepted; \
                                              higher values make the classifier stricter (CAUTION: a value of 1 will set all predictions to none)", interactive=True)                
                    T_predict_button = gr.Button("Start Prediction")

        T_examples = gr.Examples([["./samples.tsv", "tweet_content", "jayebaku/XLMRoberta-twitter-crexdata-flood-wildfire-detector", 0.00]], 
                    inputs=[T_file_input, T_text_field, T_event_model, T_threshold])
                    
        gr.Markdown("""Select an ID cell in the dataframe to view the embedded tweet""")
        T_tweetID = gr.Textbox(visible=False)
        with gr.Group():
            with gr.Row():
                with gr.Column(scale=3):
                    T_data_filter = gr.Dropdown(visible=False)
                    T_tweet_embed = gr.HTML("""<div id="tweet-container"></div>""")
                    
                with gr.Column(scale=7):
                    T_data = gr.DataFrame(#headers=["Texts", "event_label", "model_score", "IDs"], 
                                          wrap=True,
                                          show_fullscreen_button=True, 
                                          show_copy_button=True,
                                          show_row_numbers=True,
                                          show_search="filter",
                                          max_height=1000,
                                          column_widths=["49%","17%","17%","17%"])

    

    qa_tab = gr.Tab("Question Answering")
    with qa_tab:
        gr.Markdown(
        """
        # Question Answering Demo
        This section uses RAG to answer questions about the relevant social media posts identified by the relevance classifier.\n 
        Usage:\n
            - Select one or more queries from the predefined list.\n
            - Parameters for QA can be edited in the 'Parameters' section.\n
            
        Note: The QA process is disabled until the relevance classification is done
        """)
        with gr.Group():
            with gr.Accordion("Parameters", open=False):
                with gr.Row():
                    with gr.Column():
                        qa_llm_model = gr.Dropdown(["mistral", "solar", "phi3mini"], label="QA model", value="phi3mini", interactive=True)
                        aggregator = gr.Dropdown(["linear", "outrank"], label="Aggregation method", value="linear", interactive=True)
                    with gr.Column():
                        batch_size = gr.Slider(50, 500, value=150, step=1, label="Batch size", info="Choose between 50 and 500", interactive=True)
                        topk = gr.Slider(1, 10, value=5, step=1, label="Number of results to retrieve", info="Choose between 1 and 10", interactive=True)
            response_lang = gr.Dropdown(["english", "german", "catalan", "spanish"], label="Response language", value="english", interactive=True)
                        
            selected_queries = gr.CheckboxGroup(label="Select at least one query using the checkboxes", interactive=True)
            queries_state = gr.State()
            qa_tab.select(init_queries, inputs=queries_state, outputs=[selected_queries, queries_state])
            query_inp = gr.Textbox(label="Add custom queries like the ones above, one at a time")
            QA_addqry_button = gr.Button("Add to queries", interactive=False)
            
        QA_run_button = gr.Button("Start QA", interactive=False)
        hsummary = gr.Textbox(label="Summary")

        qa_tweetID = gr.Textbox(visible=False)
        with gr.Group():
            with gr.Row():
                with gr.Column(scale=7):
                    qa_df = gr.DataFrame(wrap=True,
                                        show_fullscreen_button=True,
                                        show_copy_button=True,
                                        show_search="filter",
                                        max_height=1000,
                                        column_widths=["10%","70%","20%"])
                with gr.Column(scale=3):
                    qa_tweet_embed = gr.HTML("""<div id="tweet-container2"></div>""")
 
        
    # with gr.Tab("Event Type Classification Eval"):
    #     gr.Markdown(
    #     """
    #     # T4.5 Relevance Classifier Demo 
    #     This is a demo created to explore floods and wildfire classification in social media posts.\n
    #     Usage:\n
    #         - Upload .tsv or .csv data file (must contain a text column with social media posts).\n
    #         - Next, type the name of the text column.\n
    #         - Then, choose a BERT classifier model from the drop down.\n
    #         - Finally, click the 'start prediction' button.\n
    #     Evaluation:\n
    #         - To evaluate the model's accuracy select the INCORRECT classifications using the checkboxes in front of each post.\n
    #         - Then, click on the 'Calculate Accuracy' button.\n
    #         - Then, click on the 'Download data as CSV' to get the classifications and evaluation data as a .csv file.
    #     """)
    #     with gr.Row():
    #         with gr.Column(scale=4):
    #             file_input = gr.File(label="Upload CSV or TSV File", file_types=['.tsv', '.csv'])
                
    #         with gr.Column(scale=6):
    #             text_field = gr.Textbox(label="Text field name", value="tweet_text")
    #             event_model = gr.Dropdown(event_models, value=event_models[0], label="Select classification model")
    #             ETCE_predict_button = gr.Button("Start Prediction")
    #     with gr.Accordion("Prediction threshold", open=False):        
    #         threshold = gr.Slider(0, 1, value=0, step=0.01, label="Prediction threshold", show_label=False, 
    #                           info="This value sets a threshold by which texts classified flood or fire are accepted, \
    #                               higher values makes the classifier stricter (CAUTION: A value of 1 will set all predictions as none)", interactive=True)
        
    #     with gr.Row(): 
    #         with gr.Column():
    #             gr.Markdown("""### Flood-related""")
    #             flood_checkbox_output = gr.CheckboxGroup(label="Select ONLY incorrect classifications", interactive=True)
                
    #         with gr.Column():
    #             gr.Markdown("""### Fire-related""")
    #             fire_checkbox_output = gr.CheckboxGroup(label="Select ONLY incorrect classifications", interactive=True)
                
    #         with gr.Column():
    #             gr.Markdown("""### None""")
    #             none_checkbox_output = gr.CheckboxGroup(label="Select ONLY incorrect classifications", interactive=True) 

    #     with gr.Row():
    #         with gr.Column(scale=5):
    #             gr.Markdown(r"""
    #             Accuracy is the model's ability to make correct predictions.
    #             It is the fraction of correct predictions out of the total predictions.
                
    #             $$
    #             \text{Accuracy} = \frac{\text{Correct predictions}}{\text{All predictions}} * 100
    #             $$
                
    #             Model Confidence is the mean probability of each case 
    #             belonging to its assigned class. A value of 1 is best.
    #             """, latex_delimiters=[{ "left": "$$", "right": "$$", "display": True }])
    #             gr.Markdown("\n\n\n")
    #             model_confidence = gr.Number(label="Model Confidence")
                
    #         with gr.Column(scale=5):
    #             correct = gr.Number(label="Number of correct classifications")
    #             incorrect = gr.Number(label="Number of incorrect classifications")
    #             accuracy = gr.Number(label="Model Accuracy (%)")

    #     ETCE_accuracy_button = gr.Button("Calculate Accuracy")
    #     download_csv = gr.DownloadButton(visible=False)
    #     num_posts = gr.Number(visible=False)
    #     data = gr.DataFrame(visible=False) 
    #     data_eval = gr.DataFrame(visible=False)

        



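    # JavaScript template for rendering an embedded tweet; "<=CONTAINER-NAME=>" is
    # replaced with the target div id before the snippet is attached to a component.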
    createEmbedding_js = """ (x) =>
            {
                reloadTwitterWidgets();
                const tweetContainer = document.getElementById("<=CONTAINER-NAME=>");
                tweetContainer.innerHTML = "";
                twttr.widgets.createTweet(x,tweetContainer,{theme: 'dark', dnt: true, align: 'center'});
            }
        
        """
    
    # Test event listeners
    T_predict_button.click(
        load_and_classify_csv_dataframe, 
        inputs=[T_file_input, T_text_field, T_event_model, T_threshold],  
        outputs=[T_data, T_data_ss_state, T_data_filter, QA_addqry_button, QA_run_button]
        )
    
    T_data.select(T_on_select, None, T_tweetID)
    T_tweetID.change(fn=None, inputs=T_tweetID, outputs=None, js=createEmbedding_js.replace("<=CONTAINER-NAME=>", "tweet-container"))
    
    @T_data_filter.input(inputs=[T_data_ss_state, T_data_filter], outputs=T_data)
    def filter_df(df, filter): 
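        """Filter the classified dataframe by the selected label ('All', a label, or 'Not-<label>')."""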
        if filter == "All":
            result_df = df.copy()
        elif filter.startswith("Not"):
            result_df = df[df["event_label"]!=filter.split('-')[1]].copy()
        else: 
            result_df = df[df["event_label"]==filter].copy()     
        return result_df 


    # Button clicks ETC Eval
    # ETCE_predict_button.click(
    #     load_and_classify_csv, 
    #     inputs=[file_input, text_field, event_model, threshold], 
    #     outputs=[flood_checkbox_output, fire_checkbox_output, none_checkbox_output, model_confidence, num_posts, data, QA_addqry_button, QA_run_button])
    
    # ETCE_accuracy_button.click(
    #     calculate_accuracy, 
    #     inputs=[flood_checkbox_output, fire_checkbox_output, none_checkbox_output, num_posts, text_field, data], 
    #     outputs=[incorrect, correct, accuracy, data_eval, download_csv])
    
    
    # Button clicks QA
    QA_addqry_button.click(add_query, inputs=[query_inp, queries_state], outputs=[selected_queries, queries_state])

    QA_run_button.click(qa_summarise, 
                    inputs=[selected_queries, qa_llm_model, T_text_field, response_lang, T_data_ss_state], 
                    outputs=[hsummary, qa_df])   

    qa_df.select(T_on_select, None, qa_tweetID)
    qa_tweetID.change(fn=None, inputs=qa_tweetID, outputs=None, js=createEmbedding_js.replace("<=CONTAINER-NAME=>", "tweet-container2"))
    
    
    # Event listener for single text classification
    text_to_classify.submit(
        single_classification, 
        inputs=[text_to_classify, model_sing_classify, threshold_sing_classify], 
        outputs=[classification, classification_score])
        
demo.launch()