File size: 19,734 Bytes
b223991
d2cdbf2
b223991
 
 
 
 
 
 
 
 
 
 
 
26536f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b223991
 
 
 
 
 
 
5a752ac
b223991
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1517eaf
b223991
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dbf585c
b223991
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26536f5
 
 
 
 
 
 
 
 
 
ea2b41b
a14fb5f
ea2b41b
 
 
 
 
 
26536f5
 
 
 
 
a14fb5f
b223991
 
 
 
 
 
 
 
 
 
 
 
b849a22
 
 
c050c84
09c3b86
b849a22
9101813
 
 
4ddc454
9101813
4ddc454
 
 
 
 
 
 
3f98cc2
 
29933e1
 
 
 
3f98cc2
29933e1
 
 
 
 
 
 
 
 
 
b223991
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d110060
b507586
 
 
d110060
2d3291f
b223991
 
 
 
 
 
 
 
 
dbf585c
b223991
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
780faa4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
import os
import time
import gradio as gr
import pandas as pd

from classifier import classify
from statistics import mean
from qa_summary import generate_answer


# Hugging Face access token forwarded to `classify` for every prediction.
# It is read at import time, so the app fails immediately (KeyError) when the
# HF_TOKEN environment variable is not set.
HFTOKEN = os.environ["HF_TOKEN"]



# js = """
#     async () => {
#         // Load Twitter Widgets script
#         const script = document.createElement("script");
#         script.onload = () => console.log("Twitter Widgets.js loaded");
#         script.src = "https://platform.twitter.com/widgets.js";
#         document.head.appendChild(script);

#         // Define a global function to reload Twitter widgets
#         globalThis.reloadTwitterWidgets = () => {
#             if (window.twttr && twttr.widgets) {
#                 twttr.widgets.load();
#             }
#         };
#     }
# """

def T_on_select(evt: gr.SelectData):
    """Build the HTML preview for a clicked cell of the classification table.

    Clicking a cell in column index 3 (the ``tweet_id`` column of the table
    produced by ``load_and_classify_csv_dataframe``) yields a Twitter embed
    blockquote; any other cell is echoed back as an ``<h2>`` heading.

    Args:
        evt: Gradio selection event; ``evt.index[1]`` is the column index and
            ``evt.value`` the cell content.

    Returns:
        An HTML snippet for the ``T_tweet_embed`` component.
    """
    from html import escape

    # Cell values originate from a user-uploaded file, so escape them before
    # interpolating into HTML to prevent markup/script injection.
    cell_value = escape(str(evt.value), quote=True)
    if evt.index[1] == 3:
        return (
            """<blockquote class="twitter-tweet" data-dnt="true" data-theme="dark">"""
            f"""\n<a href="https://twitter.com/anyuser/status/{cell_value}"></a></blockquote>"""
        )
    return f"""<h2>{cell_value}</h2>"""

def single_classification(text, event_model, threshold):
    """Classify one piece of text with the selected event model.

    Args:
        text: The text to classify.
        event_model: Model id passed through to ``classify``.
        threshold: Prediction threshold passed through to ``classify``.

    Returns:
        A ``(label, score)`` pair taken from the ``"event"`` and ``"score"``
        fields of the ``classify`` result.
    """
    prediction = classify(text, event_model, HFTOKEN, threshold)
    label, score = prediction["event"], prediction["score"]
    return label, score
    
def load_and_classify_csv(file, text_field, event_model, threshold):
    """Classify every post of an uploaded CSV/TSV file for the evaluation tab.

    Args:
        file: Gradio file object for the upload (``None`` if nothing uploaded).
        text_field: Name of the column holding the post text.
        event_model: Model id passed through to ``classify``.
        threshold: Prediction threshold passed through to ``classify``.

    Returns:
        A tuple matching the ``ETCE_predict_button`` outputs:
        the flood, fire and none CheckboxGroups (choices are the post texts
        predicted with that label), the mean model score, the number of
        posts, the DataFrame annotated with ``model_label``/``model_score``,
        and two updates that enable the QA buttons.

    Raises:
        gr.Error: if no file was uploaded, ``text_field`` is not a column, or
            the file has no rows.
    """
    if file is None:
        raise gr.Error("Error: Please upload a CSV or TSV file first.")

    filepath = file.name
    # Choose the parser from the file extension (case-insensitive) rather than
    # a substring match anywhere in the path.
    if os.path.splitext(filepath)[1].lower() == ".csv":
        df = pd.read_csv(filepath)
    else:
        df = pd.read_table(filepath)

    if text_field not in df.columns:
        raise gr.Error(f"Error: Entered text column '{text_field}' not in CSV file.")
    if df.empty:
        # statistics.mean() below raises on an empty sequence.
        raise gr.Error("Error: The uploaded file contains no rows.")

    posts = df[text_field].to_list()
    labels, scores = [], []
    for post in posts:
        res = classify(post, event_model, HFTOKEN, threshold)
        labels.append(res["event"])
        scores.append(res["score"])

    df["model_label"] = labels
    df["model_score"] = scores

    model_confidence = mean(scores)

    def _posts_with_label(label):
        # Texts predicted as `label`, offered as checkbox choices for review.
        return df.loc[df["model_label"] == label, text_field].to_list()

    flood_related = gr.CheckboxGroup(choices=_posts_with_label("flood"))
    fire_related = gr.CheckboxGroup(choices=_posts_with_label("fire"))
    not_related = gr.CheckboxGroup(choices=_posts_with_label("none"))

    return (flood_related, fire_related, not_related, model_confidence,
            len(posts), df, gr.update(interactive=True), gr.update(interactive=True))

def load_and_classify_csv_dataframe(file, text_field, event_model, threshold):
    """Classify every post of an uploaded CSV/TSV file for the table view.

    Args:
        file: Gradio file object for the upload (``None`` if nothing uploaded).
        text_field: Name of the column holding the post text.
        event_model: Model id passed through to ``classify``.
        threshold: Prediction threshold passed through to ``classify``.

    Returns:
        ``(result_df, result_df, dropdown_update)`` for the ``T_data`` table,
        the ``T_data_ss_state`` copy used by the label filter, and the filter
        dropdown. ``result_df`` has the text, ``event_label``, ``model_score``
        (rounded to 5 decimals) and ``tweet_id`` (as str) columns in that
        order; ``T_on_select`` treats column 3 as the tweet ID.

    Raises:
        gr.Error: if no file was uploaded or a required column is missing.
    """
    if file is None:
        raise gr.Error("Error: Please upload a CSV or TSV file first.")

    filepath = file.name
    # Choose the parser from the file extension (case-insensitive) rather than
    # a substring match anywhere in the path.
    if os.path.splitext(filepath)[1].lower() == ".csv":
        df = pd.read_csv(filepath)
    else:
        df = pd.read_table(filepath)

    if text_field not in df.columns:
        raise gr.Error(f"Error: Entered text column '{text_field}' not in CSV file.")
    if "tweet_id" not in df.columns:
        # tweet_id is part of the result table and feeds the tweet embed.
        raise gr.Error("Error: Required column 'tweet_id' not in CSV file.")

    labels, scores = [], []
    for post in df[text_field].to_list():
        res = classify(post, event_model, HFTOKEN, threshold)
        labels.append(res["event"])
        scores.append(round(res["score"], 5))

    df["event_label"] = labels
    df["model_score"] = scores

    result_df = df[[text_field, "event_label", "model_score", "tweet_id"]].copy()
    # IDs are shown/used as text (presumably to avoid numeric formatting of
    # long IDs in the table — confirm).
    result_df["tweet_id"] = result_df["tweet_id"].astype(str)

    # Filter choices: every predicted label, its "Not-<label>" complement, and "All".
    filters = list(result_df["event_label"].unique())
    extra_filters = ["Not-" + label for label in filters] + ["All"]

    return result_df, result_df, gr.update(choices=sorted(filters + extra_filters),
                                           value="All",
                                           label="Filter data by label",
                                           visible=True)


def calculate_accuracy(flood_selections, fire_selections, none_selections, num_posts, text_field, data_df):
    """Score the model against the user's "incorrect" checkbox selections.

    Every post the user ticked in any of the three groups is counted as an
    incorrect prediction; all other posts are counted as correct.

    Args:
        flood_selections, fire_selections, none_selections: Post texts ticked
            as incorrectly classified in each label group.
        num_posts: Total number of classified posts (set by the prediction step).
        text_field: Name of the text column in ``data_df``.
        data_df: Classified data; a ``model_eval`` column is added in place.

    Returns:
        ``(incorrect, correct, accuracy_percent, data_df, download_button)``;
        ``data_df`` is also written to ``output.csv`` for download.

    Raises:
        gr.Error: if no prediction has been run yet (``num_posts`` empty/zero).
    """
    if not num_posts:
        # Avoid dividing by zero/None when clicked before any prediction.
        raise gr.Error("Error: Run 'Start Prediction' before calculating accuracy.")

    selections = (flood_selections or []) + (fire_selections or []) + (none_selections or [])
    selected = set(selections)  # O(1) membership per post instead of a list scan
    data_df["model_eval"] = [
        "incorrect" if post in selected else "correct"
        for post in data_df[text_field].to_list()
    ]

    incorrect = len(selections)
    correct = num_posts - incorrect
    accuracy = (correct / num_posts) * 100

    # NOTE(review): a fixed filename in the working directory is shared by all
    # sessions, so concurrent users can overwrite each other's download.
    data_df.to_csv("output.csv")
    return incorrect, correct, accuracy, data_df, gr.DownloadButton(label="Download CSV", value="output.csv", visible=True)
    
def init_queries(history):
    """Populate the QA query checklist, seeding default queries on first use.

    Args:
        history: Current query list held in ``queries_state`` (``None`` or
            empty before the QA tab is first opened).

    Returns:
        A CheckboxGroup whose choices are the queries, and the (possibly
        seeded) query list to store back into ``queries_state``.
    """
    if not history:
        # Every entry needs its trailing comma: a missing one makes Python
        # silently concatenate two adjacent queries into a single string.
        history = [
            "What areas are being evacuated?",
            "What areas are predicted to be impacted?",
            "What areas are without power?",
            "What barriers are hindering response efforts?",
            "What events have been canceled?",
            "What preparations are being made?",
            "What regions have announced a state of emergency?",
            "What roads are blocked / closed?",
            "What services have been closed?",
            "What warnings are currently in effect?",
            "Where are emergency services deployed?",
            "Where are emergency services needed?",
            "Where are evacuations needed?",
            "Where are people needing rescued?",
            "Where are recovery efforts taking place?",
            "Where has building or infrastructure damage occurred?",
            "Where has flooding occurred?",
            "Where are volunteers being requested?",
            "Where has road damage occurred?",
            "What area has the wildfire burned?",
            "Where have homes been damaged or destroyed?",
        ]

    return gr.CheckboxGroup(choices=history), history

def add_query(to_add, history):
    """Append a custom query to the QA checklist if it is new.

    Args:
        to_add: Text typed by the user; surrounding whitespace is stripped and
            blank input is ignored.
        history: Current query list from ``queries_state`` (may be ``None``
            if the state has not been initialised yet). Appended to in place.

    Returns:
        A CheckboxGroup with the updated choices, and the updated query list.
    """
    if history is None:
        history = []
    query = (to_add or "").strip()
    if query and query not in history:
        history.append(query)
    return gr.CheckboxGroup(choices=history), history

def qa_summarise(selected_queries, qa_llm_model, text_field, data_df):
    """Summarise answers to the selected queries over the relevant posts.

    Posts whose ``model_label`` is ``"none"`` are excluded; the remaining
    texts are passed to ``generate_answer`` in ``"multi_summarize"`` mode with
    the first selected query as the primary query.

    Args:
        selected_queries: Queries ticked in the QA checklist.
        qa_llm_model: Name of the QA model passed to ``generate_answer``.
        text_field: Name of the text column in ``data_df``.
        data_df: Output of the evaluation-tab classification (must contain
            ``model_label``).

    Returns:
        ``(summary, doc_df)`` where ``doc_df`` lists the input texts with a
        1-based ``number`` column.

    Raises:
        gr.Error: if no query is selected or no classification data exists.
    """
    if not selected_queries:
        raise gr.Error("Error: Select at least one query before starting QA.")
    if data_df is None or "model_label" not in data_df.columns:
        raise gr.Error("Error: Run the relevance classification before starting QA.")

    relevant = data_df[data_df["model_label"] != "none"]
    texts = relevant[text_field].to_list()

    summary = generate_answer(qa_llm_model, texts, selected_queries[0], selected_queries, mode="multi_summarize")

    doc_df = pd.DataFrame({"number": range(1, len(texts) + 1), "text": texts})
    return summary, doc_df

            
# Top-level UI. Components are laid out tab by tab; all event listeners are
# attached afterwards, still inside this Blocks context.
with gr.Blocks(fill_width=True) as demo:
    # Client-side JS run once on page load (via demo.load below). It injects
    # Twitter's widgets.js and defines a global reloadTwitterWidgets(), which
    # the T_tweet_embed.change listener calls to render newly inserted embeds.
    # NOTE(review): getElementById("twitter-tweet") looks up an element *id*,
    # while T_on_select emits a blockquote with *class* "twitter-tweet", so the
    # clearing step looks like a no-op — confirm intent before changing it.
    js = """
        async () => {
            // Load Twitter Widgets script
            const script = document.createElement("script");
            script.onload = () => console.log("Twitter Widgets.js loaded");
            script.src = "https://platform.twitter.com/widgets.js";
            document.head.appendChild(script);
    
            // Define a global function to reload Twitter widgets
            globalThis.reloadTwitterWidgets = () => {
                // Select the container where tweets are inserted
                const tweetContainer = document.getElementById("twitter-tweet");
    
                if (tweetContainer) {
                    tweetContainer.innerHTML = ""; // Clear previous tweets
                }
    
                // Reload Twitter widgets
                if (window.twttr && twttr.widgets) {
                    twttr.widgets.load();
                }
            };
        }
    """ #tweet-container 
    
    demo.load(None,None,None,js=js)
    
    # Hugging Face model ids offered in every classification-model dropdown.
    event_models = ["jayebaku/distilbert-base-multilingual-cased-crexdata-relevance-classifier",
                    "jayebaku/distilbert-base-multilingual-cased-weather-classifier-2",
                    "jayebaku/twitter-xlm-roberta-base-crexdata-relevance-classifier",
                    "jayebaku/twhin-bert-base-crexdata-relevance-classifier"]
    
    # Full, unfiltered result table of the "Event Type Classification" tab;
    # filter_df re-filters from this copy whenever the label filter changes.
    T_data_ss_state = gr.State(value=pd.DataFrame())
    
    
    # Tab: classify an uploaded file, browse the results in a filterable
    # table, and preview the embedded tweet for a clicked ID cell.
    with gr.Tab("Event Type Classification"):
        gr.Markdown(
        """
        # T4.5 Relevance Classifier Demo 
        This is a demo created to explore floods and wildfire classification in social media posts.\n
        Upload .tsv or .csv data file (must contain a text column with social media posts), next enter the name of the text column, choose classifier model, and click 'start prediction'.
        """)
        # Inputs for load_and_classify_csv_dataframe.
        with gr.Group():
            with gr.Row(equal_height=True):
                with gr.Column():
                    T_file_input = gr.File(label="Upload CSV or TSV File", file_types=['.tsv', '.csv']) 
                with gr.Column():
                    T_text_field = gr.Textbox(label="Text field name", value="tweet_text")
                    T_event_model = gr.Dropdown(event_models, value=event_models[0], label="Select classification model")
                    with gr.Accordion("Prediction threshold", open=False):        
                        T_threshold = gr.Slider(0, 1, value=0, step=0.01, label="Prediction threshold", show_label=False, 
                                          info="This value sets a threshold by which texts classified flood or fire are accepted, \
                                              higher values makes the classifier stricter (CAUTION: A value of 1 will set all predictions as none)", interactive=True)                
                    T_predict_button = gr.Button("Start Prediction")
                    
        gr.Markdown("""Select an ID cell in dataframe to view Embedded tweet""")
        with gr.Group():
            with gr.Row():
                with gr.Column(scale=4):
                    # Hidden until load_and_classify_csv_dataframe supplies its choices.
                    T_data_filter = gr.Dropdown(visible=False)
                    # Receives the HTML snippet built by T_on_select.
                    T_tweet_embed = gr.HTML()
                    
                with gr.Column(scale=6):
                    # Column order mirrors result_df of load_and_classify_csv_dataframe;
                    # T_on_select treats column index 3 (IDs) as the tweet ID.
                    T_data = gr.DataFrame(headers=["Texts", "event_label", "model_score", "IDs"], 
                                          wrap=True,
                                          show_fullscreen_button=True, 
                                          show_copy_button=True,
                                          show_row_numbers=True,
                                          show_search="filter",
                                          max_height=1000,
                                          column_widths=["49%","17%","17%","17%"])

    

    # Tab: classify an uploaded file, let the user tick misclassified posts,
    # and compute accuracy / model confidence from those selections.
    with gr.Tab("Event Type Classification Eval"):
        gr.Markdown(
        """
        # T4.5 Relevance Classifier Demo 
        This is a demo created to explore floods and wildfire classification in social media posts.\n
        Usage:\n
            - Upload .tsv or .csv data file (must contain a text column with social media posts).\n
            - Next, type the name of the text column.\n
            - Then, choose a BERT classifier model from the drop down.\n
            - Finally, click the 'start prediction' button.\n
        Evaluation:\n
            - To evaluate the model's accuracy select the INCORRECT classifications using the checkboxes in front of each post.\n
            - Then, click on the 'Calculate Accuracy' button.\n
            - Then, click on the 'Download CSV' button to get the classifications and evaluation data as a .csv file.
        """)
        # Inputs for load_and_classify_csv.
        with gr.Row():
            with gr.Column(scale=4):
                file_input = gr.File(label="Upload CSV or TSV File", file_types=['.tsv', '.csv'])

            with gr.Column(scale=6):
                text_field = gr.Textbox(label="Text field name", value="tweet_text")
                event_model = gr.Dropdown(event_models, value=event_models[0], label="Select classification model")
                ETCE_predict_button = gr.Button("Start Prediction")
        with gr.Accordion("Prediction threshold", open=False):
            threshold = gr.Slider(0, 1, value=0, step=0.01, label="Prediction threshold", show_label=False,
                              info="This value sets a threshold by which texts classified flood or fire are accepted, \
                                  higher values makes the classifier stricter (CAUTION: A value of 1 will set all predictions as none)", interactive=True)

        # One checkbox group per predicted label; load_and_classify_csv fills
        # the choices and calculate_accuracy reads the ticked (incorrect) posts.
        with gr.Row(): # XXX confirm this is not a problem later --equal_height=True
            with gr.Column():
                gr.Markdown("""### Flood-related""")
                flood_checkbox_output = gr.CheckboxGroup(label="Select ONLY incorrect classifications", interactive=True)

            with gr.Column():
                gr.Markdown("""### Fire-related""")
                fire_checkbox_output = gr.CheckboxGroup(label="Select ONLY incorrect classifications", interactive=True)

            with gr.Column():
                gr.Markdown("""### None""")
                none_checkbox_output = gr.CheckboxGroup(label="Select ONLY incorrect classifications", interactive=True)

        with gr.Row():
            with gr.Column(scale=5):
                gr.Markdown(r"""
                Accuracy: is the model's ability to make correct predictions.
                It is the fraction of correct predictions out of the total predictions.
                
                $$
                \text{Accuracy} = \frac{\text{Correct predictions}}{\text{All predictions}} * 100
                $$
                
                Model Confidence: is the mean probability of each case 
                belonging to their assigned classes. A value of 1 is best.
                """, latex_delimiters=[{ "left": "$$", "right": "$$", "display": True }])
                gr.Markdown("\n\n\n")
                model_confidence = gr.Number(label="Model Confidence")

            with gr.Column(scale=5):
                correct = gr.Number(label="Number of correct classifications")
                incorrect = gr.Number(label="Number of incorrect classifications")
                accuracy = gr.Number(label="Model Accuracy (%)")

        ETCE_accuracy_button = gr.Button("Calculate Accuracy")
        # Hidden plumbing: calculate_accuracy reveals download_csv; num_posts and
        # data carry the prediction results to the accuracy and QA steps;
        # data_eval receives the evaluated frame.
        download_csv = gr.DownloadButton(visible=False)
        num_posts = gr.Number(visible=False)
        data = gr.DataFrame(visible=False)
        data_eval = gr.DataFrame(visible=False)
        

    # Tab: run QA/summarisation over the posts classified in the Eval tab.
    qa_tab = gr.Tab("Question Answering")
    with qa_tab:
        gr.Markdown(
        """
        # Question Answering Demo
        This section uses RAG to answer questions about the relevant social media posts identified by the relevance classifier\n 
        Usage:\n
            - Select queries from the predefined list\n
            - Parameters for QA can be edited in the 'Parameters' panel\n
            
        Note: QA is disabled until relevance classification has been run in the 'Event Type Classification Eval' tab
        """)

        # NOTE(review): only qa_llm_model is passed to qa_summarise; aggregator,
        # batch_size and topk are not wired to any listener in this file.
        with gr.Accordion("Parameters", open=False):
            with gr.Row():
                with gr.Column():
                    qa_llm_model = gr.Dropdown(["mistral", "solar", "phi3mini"], label="QA model", value="phi3mini", interactive=True)
                    aggregator = gr.Dropdown(["linear", "outrank"], label="Aggregation method", value="linear", interactive=True)
                with gr.Column():
                    batch_size = gr.Slider(50, 500, value=150, step=1, label="Batch size", info="Choose between 50 and 500", interactive=True)
                    topk = gr.Slider(1, 10, value=5, step=1, label="Number of results to retrieve", info="Choose between 1 and 10", interactive=True)

        selected_queries = gr.CheckboxGroup(label="Select at least one query using the checkboxes", interactive=True)
        # Per-session query list; init_queries seeds it with defaults when the
        # tab is selected and the list is still empty.
        queries_state = gr.State()
        qa_tab.select(init_queries, inputs=queries_state, outputs=[selected_queries, queries_state])

        query_inp = gr.Textbox(label="Add custom queries like the ones above, one at a time")
        # Both buttons start disabled; load_and_classify_csv enables them.
        QA_addqry_button = gr.Button("Add to queries", interactive=False)
        QA_run_button = gr.Button("Start QA", interactive=False)
        hsummary = gr.Textbox(label="Summary")

        qa_df = gr.DataFrame()


    # Tab: classify a single typed text with a chosen model and threshold.
    with gr.Tab("Single Text Classification"):
        gr.Markdown(
        """
        # Event Type Prediction Demo
        In this section you can test the relevance classifier with written texts.\n 
        Usage:\n
            - Type a tweet-like text in the textbox.\n
            - Then press Enter or click the submit button.\n
        """)
        with gr.Row():
            with gr.Column(scale=3):
                model_sing_classify = gr.Dropdown(event_models, value=event_models[0], label="Select classification model")
            with gr.Column(scale=7):
                threshold_sing_classify = gr.Slider(0, 1, value=0, step=0.01, label="Prediction threshold",
                              info="This value sets a threshold by which texts classified flood or fire are accepted, \
                                  higher values makes the classifier stricter (CAUTION: A value of 1 will set all predictions as none)", interactive=True)

        text_to_classify = gr.Textbox(label="Text", info="Enter tweet-like text", submit_btn=True)
        # Clickable sample inputs that fill text_to_classify (several languages).
        text_to_classify_examples = gr.Examples([["The streets are flooded, I can't leave #BostonStorm"], 
                                                 ["Controlado el incendio de Rodezno que ha obligado a desalojar a varias bodegas de la zona."], 
                                                 ["Cambrils:estació Renfe inundada 19 persones dins d'un tren. FGC a Capellades, petit descarrilament 5 passatgers #Inuncat @emergenciescat"],
                                                 ["Anscheinend steht die komplette Neckarwiese unter Wasser! #Hochwasser"]], text_to_classify)

        # Outputs of single_classification.
        with gr.Row():
            with gr.Column():
                classification = gr.Textbox(label="Classification")
            with gr.Column():
                classification_score = gr.Number(label="Classification Score")
 
        

                


         
        
    # Event listeners for the "Event Type Classification" tab (T_* components).
    # Prediction fills the table, the unfiltered-state copy and the filter dropdown.
    T_predict_button.click(
        load_and_classify_csv_dataframe, 
        inputs=[T_file_input, T_text_field, T_event_model, T_threshold],  
        outputs=[T_data, T_data_ss_state, T_data_filter]
        )
    
    # Clicking a table cell renders its preview HTML; T_on_select receives the
    # gr.SelectData event through its type annotation.
    T_data.select(T_on_select, None, T_tweet_embed)#.then(fn=None, js="reloadTwitterWidgets()")
    # Earlier variant kept for reference: clear the embed first, then render and reload.
    # T_data.select(
    #     fn=lambda: gr.update(value=""),
    #     outputs=T_tweet_embed).then(T_on_select, None, T_tweet_embed).then(fn=None, js="reloadTwitterWidgets()")

    # Whenever new embed HTML arrives, ask widgets.js (loaded on page load) to render it.
    T_tweet_embed.change(fn=None, scroll_to_output=True, js="reloadTwitterWidgets()")
    
    @T_data_filter.input(inputs=[T_data_ss_state, T_data_filter], outputs=T_data)
    def filter_df(df, filter):
        """Return the rows of the full result table that match the dropdown choice.

        Args:
            df: Unfiltered result table held in ``T_data_ss_state``.
            filter: Dropdown value — ``"All"``, ``"Not-<label>"`` or ``<label>``.

        Returns:
            A copy of ``df`` restricted to the matching ``event_label`` rows.
        """
        choice = filter  # local alias so the body doesn't use the builtin's name
        if not choice or choice == "All":
            # A cleared dropdown (None/"") shows everything, like "All".
            return df.copy()
        if choice.startswith("Not-"):
            # Strip only the prefix: split('-')[1] truncated labels containing
            # "-" and raised IndexError for values starting "Not" without a dash.
            excluded = choice[len("Not-"):]
            return df[df["event_label"] != excluded].copy()
        return df[df["event_label"] == choice].copy()


    # Event listeners for the "Event Type Classification Eval" tab.
    # Prediction also enables the two QA buttons (last two outputs).
    ETCE_predict_button.click(
        load_and_classify_csv, 
        inputs=[file_input, text_field, event_model, threshold], 
        outputs=[flood_checkbox_output, fire_checkbox_output, none_checkbox_output, model_confidence, num_posts, data, QA_addqry_button, QA_run_button])
    
    ETCE_accuracy_button.click(
        calculate_accuracy, 
        inputs=[flood_checkbox_output, fire_checkbox_output, none_checkbox_output, num_posts, text_field, data], 
        outputs=[incorrect, correct, accuracy, data_eval, download_csv])
    
    
    # Event listeners for the "Question Answering" tab.
    QA_addqry_button.click(add_query, inputs=[query_inp, queries_state], outputs=[selected_queries, queries_state])

    # NOTE(review): text_field is the Eval tab's textbox; if it is edited after
    # prediction it may no longer name the column present in `data`.
    QA_run_button.click(qa_summarise, 
                    inputs=[selected_queries, qa_llm_model, text_field, data], ## XXX fix text_field 
                    outputs=[hsummary, qa_df])      
    
    
    # Event listener for single text classification
    text_to_classify.submit(
        single_classification, 
        inputs=[text_to_classify, model_sing_classify, threshold_sing_classify], 
        outputs=[classification, classification_score])
        
# Start the Gradio server for the app defined above.
demo.launch()