File size: 12,729 Bytes
fef8b06
5575829
 
 
 
fef8b06
8fca602
736168d
9317009
5575829
fef8b06
 
 
aeada86
f988ad7
 
 
6890e02
f988ad7
5575829
 
 
fef8b06
0c3391b
fef8b06
 
0c3391b
8fca602
 
0c3391b
 
 
8fca602
0c3391b
 
 
8fca602
26e5625
5575829
cf0c8b1
 
 
 
 
 
 
 
 
 
 
 
5662680
 
475b809
 
ac89270
9865cf1
10698b8
984f521
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d6dd9fc
 
2097977
 
984f521
 
2097977
47df43c
736168d
 
47df43c
736168d
 
47df43c
736168d
 
5c1cf41
736168d
 
 
 
 
 
 
 
47df43c
736168d
 
 
 
 
 
 
47df43c
736168d
fddb849
736168d
 
47df43c
736168d
 
 
 
 
 
 
 
 
 
 
 
47df43c
736168d
47df43c
4da2d54
9442705
 
 
 
4da2d54
 
 
 
 
 
93413ed
 
 
 
9442705
93413ed
9442705
9865cf1
5575829
93413ed
 
 
 
f8d4a0e
5575829
8ce9236
 
 
 
 
de89642
 
 
 
8ce9236
de89642
 
 
8ce9236
5575829
 
aeada86
5575829
 
fef8b06
58cbd2d
5575829
 
 
 
 
069449a
5575829
 
 
069449a
5575829
 
 
069449a
8fca602
e5996e7
e3e91ee
e5996e7
 
 
234945b
05d925b
e5996e7
05d925b
234945b
e5996e7
 
05d925b
 
4067906
6f5a3f7
e3e91ee
0c3391b
f95341b
 
4067906
1085ea0
21731c5
ab2e768
cf0c8b1
ac89270
0c3391b
cf0c8b1
aeada86
cf0c8b1
 
 
 
3c8dac4
21731c5
1d42bf8
 
4da2d54
410e6af
1d42bf8
5575829
77e3da1
de89642
 
 
 
 
 
 
 
77e3da1
 
 
 
 
 
 
 
 
 
4da2d54
 
77e3da1
946dac8
2bb9658
8a64462
5c41875
2097977
5c41875
cc1117d
ca8a1b9
d6dd9fc
93413ed
9442705
d6dd9fc
2097977
9442705
 
 
01d8202
4da2d54
93413ed
47df43c
5575829
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
import os
import tempfile
import time
from statistics import mean

import gradio as gr
import pandas as pd

from classifier import classify
# from genra_incremental import GenraPipeline
from qa_summary import generate_answer


# Hugging Face access token, forwarded to `classify` for every post.
# Read at import time: if HF_TOKEN is not set in the environment this line
# raises KeyError before the UI is built.
HFTOKEN = os.environ["HF_TOKEN"]

def load_and_classify_csv(file, text_field, event_model):
    """Read an uploaded CSV/TSV file and classify every post in it.

    Args:
        file: Value of the ``gr.File`` input. Either a path string or an
            object exposing the path via ``.name``.
        text_field: Name of the column holding the post text.
        event_model: Model identifier forwarded to ``classify``.

    Returns:
        ``(flood_related, fire_related, not_related, model_confidence,
        num_posts, df)``. The first three are ``gr.CheckboxGroup`` updates
        listing the posts assigned the "flood", "fire" and "none" labels.
        ``model_confidence`` is the mean classifier score rounded to 5
        decimals. ``num_posts`` is the row count. ``df`` is the input frame
        with ``model_label`` and ``model_score`` columns added.

    Raises:
        gr.Error: If no file or model was chosen, the text column is missing,
            or the file has no rows.
    """
    if file is None:
        raise gr.Error("Error: Please upload a CSV or TSV file first.")
    if not event_model:
        raise gr.Error("Error: Please select a classification model.")

    # Gradio may pass a plain path string or a wrapper carrying the path in .name.
    filepath = getattr(file, "name", file)
    # Decide by extension, not substring: "data.csv.tsv" must be read as TSV.
    if filepath.lower().endswith(".csv"):
        df = pd.read_csv(filepath)
    else:
        df = pd.read_table(filepath)

    if text_field not in df.columns:
        raise gr.Error(f"Error: Text column '{text_field}' not found in the uploaded file.")
    if df.empty:
        # statistics.mean() raises on an empty list of scores.
        raise gr.Error("Error: The uploaded file contains no rows.")

    labels, scores = [], []
    for post in df[text_field].to_list():
        res = classify(post, event_model, HFTOKEN)
        labels.append(res["event"])
        scores.append(res["score"])

    df["model_label"] = labels
    df["model_score"] = scores

    model_confidence = round(mean(scores), 5)

    def _posts_labelled(label):
        """Return the texts of all posts the model assigned to *label*."""
        return df.loc[df["model_label"] == label, text_field].to_list()

    flood_related = gr.CheckboxGroup(choices=_posts_labelled("flood"))
    fire_related = gr.CheckboxGroup(choices=_posts_labelled("fire"))
    not_related = gr.CheckboxGroup(choices=_posts_labelled("none"))

    return flood_related, fire_related, not_related, model_confidence, len(df), df

def calculate_accuracy(flood_selections, fire_selections, none_selections, num_posts, text_field, data_df):
    """Score the classifier against the user's manual review and export a CSV.

    Every post the user ticked in any of the three checkbox groups is treated
    as an incorrect prediction. All other posts count as correct.

    Args:
        flood_selections: Flood-group post texts the user marked as wrong.
        fire_selections: Fire-group post texts the user marked as wrong.
        none_selections: None-group post texts the user marked as wrong.
        num_posts: Number of classified posts (from the prediction step).
        text_field: Name of the column holding the post text.
        data_df: Classified DataFrame produced by ``load_and_classify_csv``.

    Returns:
        ``(incorrect, correct, accuracy, data_df, download_button)``.
        ``accuracy`` is a percentage. ``data_df`` gains a ``model_eval``
        column ("correct"/"incorrect"). ``download_button`` is a visible
        ``gr.DownloadButton`` pointing at the exported ``output.csv``.

    Raises:
        gr.Error: If prediction has not been run yet.
    """
    if data_df is None or text_field not in data_df.columns or not num_posts:
        raise gr.Error("Error: Run 'Start Prediction' before calculating accuracy.")

    # A set gives O(1) membership tests. Unticked groups may arrive as None.
    selections = (set(flood_selections or [])
                  | set(fire_selections or [])
                  | set(none_selections or []))
    evaluation = [
        "incorrect" if post in selections else "correct"
        for post in data_df[text_field].to_list()
    ]
    data_df["model_eval"] = evaluation

    # Count from the per-row labels so the totals agree with the exported
    # column even when the same post text occurs on several rows.
    incorrect = evaluation.count("incorrect")
    correct = num_posts - incorrect
    accuracy = (correct / num_posts) * 100

    # A fresh directory per request keeps concurrent sessions from
    # overwriting each other's export. The file is still named output.csv.
    output_path = os.path.join(tempfile.mkdtemp(), "output.csv")
    data_df.to_csv(output_path)
    return incorrect, correct, accuracy, data_df, gr.DownloadButton(label="Download CSV", value=output_path, visible=True)
    
# Default questions offered in the QA tab. Kept as a tuple so the shared
# default cannot be mutated; init_queries hands each session its own list.
DEFAULT_QUERIES = (
    "What areas are being evacuated?",
    "What areas are predicted to be impacted?",
    "What areas are without power?",
    "What barriers are hindering response efforts?",
    "What events have been canceled?",
    "What preparations are being made?",
    "What regions have announced a state of emergency?",
    "What roads are blocked / closed?",
    "What services have been closed?",
    "What warnings are currently in effect?",
    "Where are emergency services deployed?",
    "Where are emergency services needed?",
    "Where are evacuations needed?",
    "Where are people needing rescued?",
    "Where are recovery efforts taking place?",
    "Where has building or infrastructure damage occurred?",
    "Where has flooding occurred?",
    "Where are volunteers being requested?",
    "Where has road damage occurred?",
    "What area has the wildfire burned?",
    "Where have homes been damaged or destroyed?",
)


def init_queries(history):
    """Ensure the session has a query list, seeding it with the defaults.

    Args:
        history: The session's current query list (``gr.State``). ``None``
            or an empty list is replaced by a copy of ``DEFAULT_QUERIES``.

    Returns:
        ``(checkbox_update, history)`` where ``checkbox_update`` is a
        ``gr.CheckboxGroup`` whose choices are the (possibly seeded) list.
    """
    if not history:
        # Copy so that add_query's in-place appends never touch the default.
        history = list(DEFAULT_QUERIES)
    return gr.CheckboxGroup(choices=history), history

def add_query(to_add, history):
    """Append a user-written query to the session's query list.

    Surrounding whitespace is stripped. Blank input is ignored, and a query
    already in the list is not added again.

    Args:
        to_add: Text from the "Add custom queries" textbox.
        history: The session's current list of queries (``gr.State``).
            ``None`` is treated as an empty list.

    Returns:
        ``(checkbox_update, history)`` with the CheckboxGroup choices
        refreshed to the updated list.
    """
    if history is None:
        history = []
    query = (to_add or "").strip()
    if query and query not in history:
        history.append(query)
    return gr.CheckboxGroup(choices=history), history

# def qa_process(selected_queries, qa_llm_model, aggregator,
#                batch_size, topk, text_field, data_df):
    
#     emb_model = 'multi-qa-mpnet-base-dot-v1'
#     contexts = []

#     queries_df = pd.DataFrame({'id':[j for j in range(len(selected_queries))],'query': selected_queries})
#     qa_input_df = data_df[data_df["model_label"] != "none"].reset_index()
    
#     tweets_df = qa_input_df[[text_field]]
#     tweets_df.reset_index(inplace=True)
#     tweets_df.rename(columns={"index": "order", text_field: "text"},inplace=True)

#     gr.Info("Loading GENRA pipeline....")
#     genra = GenraPipeline(qa_llm_model, emb_model, aggregator, contexts)
#     gr.Info("Waiting for data...")
#     batches = [tweets_df[i:i+batch_size] for i in range(0,len(tweets_df),batch_size)]
    
#     genra_answers = []
#     summarize_batch = True
#     for batch_number, tweets in enumerate(batches):
#         gr.Info(f"Populating index for batch {batch_number}")
#         genra.qa_indexer.index_dataframe(tweets)
#         gr.Info(f"Performing retrieval for batch {batch_number}")
#         genra.retrieval(batch_number, queries_df, topk, summarize_batch) 
    
#     gr.Info("Processed all batches!")
    
#     gr.Info("Getting summary...")
#     summary = genra.summarize_history(queries_df)
    
#     gr.Info("Preparing results...")
#     results = genra.answers_store
#     final_answers, q_a = [], []
#     for q, g_answers in results.items():
#         for answer in g_answers:
#             final_answers.append({'question':q, "tweets":answer['tweets'], "batch":answer['batch_number'], "summary":answer['summary'] })
#             for t in answer['tweets']:
#                 q_a.append((q,t))
#     answers_df = pd.DataFrame.from_dict(final_answers)
#     q_a = list(set(q_a))
#     q_a_df = pd.DataFrame(q_a, columns =['question', 'tweet'])
#     q_a_df = q_a_df.sort_values(by=["question"], ascending=False)
    
#     return q_a_df, answers_df, summary

def qa_summarise(selected_queries, qa_llm_model, text_field, response_lang, data_df):
    """Summarise the relevant posts with respect to the selected queries.

    Posts the classifier labelled "none" are excluded. The remaining texts
    are passed to ``generate_answer`` in ``multi_summarize`` mode, with the
    first selected query as the primary query and all selected queries as
    the query list.

    Args:
        selected_queries: Queries ticked in the QA tab; must be non-empty.
        qa_llm_model: LLM name forwarded to ``generate_answer``.
        text_field: Name of the column holding the post text.
        response_lang: Language requested for the summary.
        data_df: Classified DataFrame produced by ``load_and_classify_csv``.

    Returns:
        ``(summary, doc_df)``. ``summary`` is whatever ``generate_answer``
        returns. ``doc_df`` lists the summarised posts with a 1-based
        ``number`` column and a ``text`` column.

    Raises:
        gr.Error: If no query is selected or classification has not been run.
    """
    if not selected_queries:
        raise gr.Error("Error: Select at least one query before starting QA.")
    if (data_df is None
            or "model_label" not in data_df.columns
            or text_field not in data_df.columns):
        raise gr.Error("Error: Run the classification in the 'Event Type Classification' tab first.")

    texts = data_df.loc[data_df["model_label"] != "none", text_field].to_list()

    summary = generate_answer(qa_llm_model,
                              texts,
                              selected_queries[0],
                              selected_queries,
                              response_lang,
                              mode="multi_summarize")

    doc_df = pd.DataFrame({"number": range(1, len(texts) + 1), "text": texts})
    return summary, doc_df

            
with gr.Blocks() as demo:
    # Hugging Face model ids offered in the classification dropdown.
    event_models = ["jayebaku/distilbert-base-multilingual-cased-crexdata-relevance-classifier",
                    "jayebaku/distilbert-base-multilingual-cased-weather-classifier-2",
                    "jayebaku/twitter-xlm-roberta-base-crexdata-relevance-classifier",
                    "jayebaku/twhin-bert-base-crexdata-relevance-classifier"]

    with gr.Tab("Event Type Classification"):
        gr.Markdown(
        """
        # T4.5 Relevance Classifier Demo
        This is a demo created to explore floods and wildfire classification in social media posts.\n
        Usage:\n
            - Upload a .csv or .tsv data file (must contain a text column with social media posts).\n
            - Next, type the name of the text column.\n
            - Then, choose a BERT classifier model from the drop down.\n
            - Finally, click the 'Start Prediction' button.\n
        Evaluation:\n
            - To evaluate the model's accuracy select the INCORRECT classifications using the checkboxes in front of each post.\n
            - Then, click on the 'Calculate Accuracy' button.\n
            - Then, click on the 'Download CSV' button to get the classifications and evaluation data as a .csv file.
        """)
        with gr.Row(equal_height=True):
            with gr.Column(scale=4):
                file_input = gr.File(label="Upload CSV or TSV File", file_types=['.tsv', '.csv'])

            with gr.Column(scale=6):
                text_field = gr.Textbox(label="Text field name", value="tweet_text")
                event_model = gr.Dropdown(event_models, label="Select classification model")
                predict_button = gr.Button("Start Prediction")

        # One checkbox group per predicted label; the user ticks the posts
        # the model got wrong.
        with gr.Row(): # XXX confirm this is not a problem later --equal_height=True
            with gr.Column():
                gr.Markdown("""### Flood-related""")
                flood_checkbox_output = gr.CheckboxGroup(label="Select ONLY incorrect classifications", interactive=True)

            with gr.Column():
                gr.Markdown("""### Fire-related""")
                fire_checkbox_output = gr.CheckboxGroup(label="Select ONLY incorrect classifications", interactive=True)

            with gr.Column():
                gr.Markdown("""### None""")
                none_checkbox_output = gr.CheckboxGroup(label="Select ONLY incorrect classifications", interactive=True)

        with gr.Row(equal_height=True):
            with gr.Column(scale=5):
                gr.Markdown(r"""
                Accuracy: is the model's ability to make correct predictions.
                It is the fraction of correct predictions out of the total predictions.
                
                $$
                \text{Accuracy} = \frac{\text{Correct predictions}}{\text{All predictions}} \times 100
                $$
                
                Model Confidence: is the mean probability of each case 
                belonging to their assigned classes. A value of 1 is best.
                """, latex_delimiters=[{ "left": "$$", "right": "$$", "display": True }])
                gr.Markdown("\n\n\n")
                model_confidence = gr.Number(label="Model Confidence")

            with gr.Column(scale=5):
                correct = gr.Number(label="Number of correct classifications")
                incorrect = gr.Number(label="Number of incorrect classifications")
                accuracy = gr.Number(label="Model Accuracy (%)")

        accuracy_button = gr.Button("Calculate Accuracy")
        # Hidden until calculate_accuracy returns a visible button for the CSV.
        download_csv = gr.DownloadButton(visible=False)
        # Hidden components that carry results from prediction to evaluation.
        num_posts = gr.Number(visible=False)
        data = gr.DataFrame(visible=False)
        data_eval = gr.DataFrame(visible=False)

        predict_button.click(
            load_and_classify_csv,
            inputs=[file_input, text_field, event_model],
            outputs=[flood_checkbox_output, fire_checkbox_output, none_checkbox_output, model_confidence, num_posts, data])
        accuracy_button.click(
            calculate_accuracy,
            inputs=[flood_checkbox_output, fire_checkbox_output, none_checkbox_output, num_posts, text_field, data],
            outputs=[incorrect, correct, accuracy, data_eval, download_csv])

    qa_tab = gr.Tab("Question Answering")

    with qa_tab:
        # XXX Add some button disabling here, if the classification process is not completed first XXX

        gr.Markdown(
        """
        # Question Answering Demo
        This section uses RAG to answer questions about the relevant social media posts identified by the relevance classifier\n 
        Usage:\n
            - Select queries from the predefined list, or add your own\n
            - Parameters for QA can be edited in the Parameters panel\n
        """)

        with gr.Accordion("Parameters", open=False):
            with gr.Row(equal_height=True):
                with gr.Column():
                    qa_llm_model = gr.Dropdown(["mistral", "solar", "phi3mini"], label="QA model", value="phi3mini", interactive=True)
                    # NOTE: aggregator, batch_size and topk were inputs of the
                    # disabled qa_process pipeline; qa_summarise does not read them.
                    aggregator = gr.Dropdown(["linear", "outrank"], label="Aggregation method", value="linear", interactive=True)

                with gr.Column():
                    batch_size = gr.Slider(50, 500, value=150, step=1, label="Batch size", info="Choose between 50 and 500", interactive=True)
                    topk = gr.Slider(1, 10, value=5, step=1, label="Number of results to retrieve", info="Choose between 1 and 10", interactive=True)

            response_lang = gr.Dropdown(["english", "german", "catalan", "spanish"], label="Response language", value="english", interactive=True)

        selected_queries = gr.CheckboxGroup(label="Select at least one query using the checkboxes", interactive=True)
        queries_state = gr.State()
        # Each time the tab is selected, init_queries seeds the defaults if
        # the session has no query list yet; an existing list is kept.
        qa_tab.select(init_queries, inputs=queries_state, outputs=[selected_queries, queries_state])

        query_inp = gr.Textbox(label="Add custom queries like the one above, one at a time")
        addqry_button = gr.Button("Add to queries")
        qa_button = gr.Button("Start QA")
        hsummary = gr.Textbox(label="Summary")

        qa_df = gr.DataFrame()
        # answers_df = gr.DataFrame()

        addqry_button.click(add_query, inputs=[query_inp, queries_state], outputs=[selected_queries, queries_state])
        # qa_button.click(qa_process, 
        #                 inputs=[selected_queries, qa_llm_model, aggregator, batch_size, topk, text_field, data], 
        #                 outputs=[qa_df, answers_df, hsummary])
        qa_button.click(qa_summarise,
                        inputs=[selected_queries, qa_llm_model, text_field, response_lang, data],
                        outputs=[hsummary, qa_df])


demo.launch()