import os
import time
import gradio as gr
import pandas as pd
from classifier import classify
from statistics import mean
from qa_summary import generate_answer
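# Hugging Face Inference API token, expected as an environment variable / Space secret.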
HFTOKEN = os.environ["HF_TOKEN"]
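# JS snippet run on app load: injects Twitter's widgets.js and exposes a global
# reloadTwitterWidgets() helper so tweet embeds can be (re)rendered later.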
loadTwitterWidgets_js = """
async () => {
// Load Twitter Widgets script
const script = document.createElement("script");
script.onload = () => console.log("Twitter Widgets.js loaded");
script.src = "https://platform.twitter.com/widgets.js";
document.head.appendChild(script);
// Define a global function to reload Twitter widgets
globalThis.reloadTwitterWidgets = () => {
// Reload Twitter widgets
if (window.twttr && twttr.widgets) {
twttr.widgets.load();
}
};
}
"""
def T_on_select(evt: gr.SelectData):
return evt.value
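# Classify a single text with the chosen event model and return its label and score.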
def single_classification(text, event_model, threshold):
res = classify(text, event_model, HFTOKEN, threshold)
return res["event"], res["score"]
def load_and_classify_csv(file, text_field, event_model, threshold):
text_field = text_field.strip()
filepath = file.name
if ".csv" in filepath:
df = pd.read_csv(filepath)
else:
df = pd.read_table(filepath)
if text_field not in df.columns:
raise gr.Error(f"Error: Enter text column'{text_field}' not in CSV file.")
labels, scores = [], []
for post in df[text_field].to_list():
res = classify(post, event_model, HFTOKEN, threshold)
labels.append(res["event"])
scores.append(res["score"])
df["event_label"] = labels
df["model_score"] = scores
# model_confidence = round(mean(scores), 5)
model_confidence = mean(scores)
fire_related = gr.CheckboxGroup(choices=df[df["event_label"]=="fire"][text_field].to_list())
flood_related = gr.CheckboxGroup(choices=df[df["event_label"]=="flood"][text_field].to_list())
not_related = gr.CheckboxGroup(choices=df[df["event_label"]=="none"][text_field].to_list())
return flood_related, fire_related, not_related, model_confidence, len(df[text_field].to_list()), df, gr.update(interactive=True), gr.update(interactive=True)
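# Load a CSV/TSV file, classify every post in the text column, and return a results
# dataframe together with label-filter choices for the Event Type Classification tab.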
def load_and_classify_csv_dataframe(file, text_field, event_model, threshold):
text_field = text_field.strip()
filepath = file.name
if ".csv" in filepath:
df = pd.read_csv(filepath)
else:
df = pd.read_table(filepath)
if text_field not in df.columns:
raise gr.Error(f"Error: Enter text column'{text_field}' not in CSV file.")
labels, scores = [], []
for post in df[text_field].to_list():
res = classify(post, event_model, HFTOKEN, threshold)
labels.append(res["event"])
scores.append(round(res["score"], 5))
df["event_label"] = labels
df["model_score"] = scores
result_df = df[[text_field, "event_label", "model_score", "tweet_id"]].copy()
result_df["tweet_id"] = result_df["tweet_id"].astype(str)
filters = list(result_df["event_label"].unique())
extra_filters = ['Not-'+x for x in filters]+['All']
return result_df, result_df, gr.update(choices=sorted(filters+extra_filters),
value='All',
label="Filter data by label",
visible=True), gr.update(interactive=True), gr.update(interactive=True)
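# Compute accuracy from the posts the user marked as incorrectly classified and
# write the annotated dataframe to output.csv for download.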
def calculate_accuracy(flood_selections, fire_selections, none_selections, num_posts, text_field, data_df):
text_field = text_field.strip()
posts = data_df[text_field].to_list()
selections = flood_selections + fire_selections + none_selections
evaluation = []
for post in posts:
if post in selections:
evaluation.append("incorrect")
else:
evaluation.append("correct")
data_df["model_eval"] = evaluation
incorrect = len(selections)
correct = num_posts - incorrect
accuracy = (correct/num_posts)*100
data_df.to_csv("output.csv")
return incorrect, correct, accuracy, data_df, gr.DownloadButton(label="Download CSV", value="output.csv", visible=True)
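# Populate the query checkbox group with the predefined disaster-response questions
# the first time the Question Answering tab is opened.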
def init_queries(history):
history = history or []
if not history:
history = [
"What areas are being evacuated?",
"What areas are predicted to be impacted?",
"What areas are without power?",
"What barriers are hindering response efforts?",
"What events have been canceled?",
"What preparations are being made?",
"What regions have announced a state of emergency?",
"What roads are blocked / closed?",
"What services have been closed?",
"What warnings are currently in effect?",
"Where are emergency services deployed?",
"Where are emergency services needed?",
"Where are evacuations needed?",
"Where are people needing rescued?",
"Where are recovery efforts taking place?",
"Where has building or infrastructure damage occurred?",
"Where has flooding occured?"
"Where are volunteers being requested?",
"Where has road damage occured?",
"What area has the wildfire burned?",
"Where have homes been damaged or destroyed?"]
return gr.CheckboxGroup(choices=history), history
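# Append a user-defined query to the checkbox group if it is not already present.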
def add_query(to_add, history):
if to_add not in history:
history.append(to_add)
return gr.CheckboxGroup(choices=history), history
#def qa_summarise(selected_queries, qa_llm_model, text_field, data_df):
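# Summarise the posts labelled as relevant (event_label != "none") with the selected
# QA model, answering the chosen queries in the requested response language.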
def qa_summarise(selected_queries, qa_llm_model, text_field, response_lang, data_df):
if not selected_queries:
raise gr.Error(f"Error: You have to select one or more queries to ask.")
qa_input_df = data_df[data_df["event_label"] != "none"].reset_index()
texts = qa_input_df[text_field].to_list()
# summary = generate_answer(qa_llm_model, texts, selected_queries[0], selected_queries, mode="multi_summarize")
summary = generate_answer(qa_llm_model,
texts,
selected_queries[0],
selected_queries,
response_lang,
mode="multi_summarize")
doc_df = pd.DataFrame()
doc_df["number"] = [i+1 for i in range(len(texts))]
doc_df["text"] = texts
doc_df["IDs"] = qa_input_df["tweet_id"].to_list()
return summary, doc_df
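# Gradio UI: three tabs (single text classification, batch event type classification,
# and RAG-based question answering over the classified posts).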
with gr.Blocks(fill_width=True) as demo:
demo.load(None,None,None,js=loadTwitterWidgets_js)
event_models = ["jayebaku/XLMRoberta-twitter-crexdata-flood-wildfire-detector",
"jayebaku/distilbert-base-multilingual-cased-crexdata-relevance-classifier",]
T_data_ss_state = gr.State(value=pd.DataFrame())
with gr.Tab("Single Text Classification"):
gr.Markdown(
"""
# Single Text Classifier Demo
In this section you can test the relevance classifier on individual, tweet-like texts.\n
Usage:\n
- Type a tweet-like text in the textbox.\n
- Press Enter to classify it.\n
""")
with gr.Group():
with gr.Row():
with gr.Column(scale=3):
model_sing_classify = gr.Dropdown(event_models, value=event_models[0], label="Select classification model")
with gr.Column(scale=7):
with gr.Accordion("Prediction threshold", open=False):
threshold_sing_classify = gr.Slider(0, 1, value=0, step=0.01, label="Prediction threshold", show_label=False,
info="This value sets a threshold by which texts classified flood or fire are accepted, \
higher values makes the classifier stricter (CAUTION: A value of 1 will set all predictions as none)", interactive=True)
text_to_classify = gr.Textbox(label="Text", info="Enter tweet-like text", submit_btn=True)
text_to_classify_examples = gr.Examples([["The streets are flooded, I can't leave #BostonStorm"],
["Controlado el incendio de Rodezno que ha obligado a desalojar a varias bodegas de la zona."],
["Cambrils:estació Renfe inundada 19 persones dins d'un tren. FGC a Capellades, petit descarrilament 5 passatgers #Inuncat @emergenciescat"],
["Anscheinend steht die komplette Neckarwiese unter Wasser! #Hochwasser"]], text_to_classify)
with gr.Group():
with gr.Row():
with gr.Column():
classification = gr.Textbox(label="Classification")
with gr.Column():
classification_score = gr.Number(label="Classification Score")
with gr.Tab("Event Type Classification"):
gr.Markdown(
"""
# Relevance Classifier Demo
This demo explores flood and wildfire classification in social media posts.\n
Upload a .tsv or .csv file (it must contain a text column with social media posts), enter the name of the text column, choose a classifier model, and click 'Start Prediction'.
""")
with gr.Group():
with gr.Row(equal_height=True):
with gr.Column():
T_file_input = gr.File(label="Upload CSV or TSV File", file_types=['.tsv', '.csv'])
with gr.Column():
T_text_field = gr.Textbox(label="Text field name", value="tweet_text")
T_event_model = gr.Dropdown(event_models, value=event_models[0], label="Select classification model")
with gr.Accordion("Prediction threshold", open=False):
T_threshold = gr.Slider(0, 1, value=0, step=0.01, label="Prediction threshold", show_label=False,
info="This value sets a threshold by which texts classified flood or fire are accepted, \
higher values makes the classifier stricter (CAUTION: A value of 1 will set all predictions as none)", interactive=True)
T_predict_button = gr.Button("Start Prediction")
T_examples = gr.Examples([["./samples.tsv", "tweet_content", "jayebaku/XLMRoberta-twitter-crexdata-flood-wildfire-detector", 0.00]],
inputs=[T_file_input, T_text_field, T_event_model, T_threshold])
gr.Markdown("""Select an ID cell in dataframe to view Embedded tweet""")
T_tweetID = gr.Textbox(visible=False)
with gr.Group():
with gr.Row():
with gr.Column(scale=3):
T_data_filter = gr.Dropdown(visible=False)
T_tweet_embed = gr.HTML("""<div id="tweet-container"></div>""")
with gr.Column(scale=7):
T_data = gr.DataFrame(#headers=["Texts", "event_label", "model_score", "IDs"],
wrap=True,
show_fullscreen_button=True,
show_copy_button=True,
show_row_numbers=True,
show_search="filter",
max_height=1000,
column_widths=["49%","17%","17%","17%"])
qa_tab = gr.Tab("Question Answering")
with qa_tab:
gr.Markdown(
"""
# Question Answering Demo
This section uses retrieval-augmented generation (RAG) to answer questions about the relevant social media posts identified by the relevance classifier.\n
Usage:\n
- Select one or more of the predefined queries below.\n
- QA parameters can be edited in the 'Parameters' accordion.\n
Note: QA is disabled until the relevance classification step has completed.
""")
with gr.Group():
with gr.Accordion("Parameters", open=False):
with gr.Row():
with gr.Column():
qa_llm_model = gr.Dropdown(["mistral", "solar", "phi3mini"], label="QA model", value="phi3mini", interactive=True)
aggregator = gr.Dropdown(["linear", "outrank"], label="Aggregation method", value="linear", interactive=True)
with gr.Column():
batch_size = gr.Slider(50, 500, value=150, step=1, label="Batch size", info="Choose between 50 and 500", interactive=True)
topk = gr.Slider(1, 10, value=5, step=1, label="Number of results to retrieve", info="Choose between 1 and 10", interactive=True)
response_lang = gr.Dropdown(["english", "german", "catalan", "spanish"], label="Response language", value="english", interactive=True)
selected_queries = gr.CheckboxGroup(label="Select at least one query using the checkboxes", interactive=True)
queries_state = gr.State()
qa_tab.select(init_queries, inputs=queries_state, outputs=[selected_queries, queries_state])
query_inp = gr.Textbox(label="Add custom queries like the ones above, one at a time")
QA_addqry_button = gr.Button("Add to queries", interactive=False)
QA_run_button = gr.Button("Start QA", interactive=False)
hsummary = gr.Textbox(label="Summary")
qa_tweetID = gr.Textbox(visible=False)
with gr.Group():
with gr.Row():
with gr.Column(scale=7):
qa_df = gr.DataFrame(wrap=True,
show_fullscreen_button=True,
show_copy_button=True,
show_search="filter",
max_height=1000,
column_widths=["10%","70%","20%"])
with gr.Column(scale=3):
qa_tweet_embed = gr.HTML("""<div id="tweet-container2"></div>""")
# with gr.Tab("Event Type Classification Eval"):
# gr.Markdown(
# """
# # T4.5 Relevance Classifier Demo
# This is a demo created to explore floods and wildfire classification in social media posts.\n
# Usage:\n
# - Upload .tsv or .csv data file (must contain a text column with social media posts).\n
# - Next, type the name of the text column.\n
# - Then, choose a BERT classifier model from the drop down.\n
# - Finally, click the 'Start Prediction' button.\n
# Evaluation:\n
# - To evaluate the model's accuracy select the INCORRECT classifications using the checkboxes in front of each post.\n
# - Then, click on the 'Calculate Accuracy' button.\n
# - Then, click on the 'Download data as CSV' to get the classifications and evaluation data as a .csv file.
# """)
# with gr.Row():
# with gr.Column(scale=4):
# file_input = gr.File(label="Upload CSV or TSV File", file_types=['.tsv', '.csv'])
# with gr.Column(scale=6):
# text_field = gr.Textbox(label="Text field name", value="tweet_text")
# event_model = gr.Dropdown(event_models, value=event_models[0], label="Select classification model")
# ETCE_predict_button = gr.Button("Start Prediction")
# with gr.Accordion("Prediction threshold", open=False):
# threshold = gr.Slider(0, 1, value=0, step=0.01, label="Prediction threshold", show_label=False,
# info="This value sets a threshold by which texts classified flood or fire are accepted, \
# higher values makes the classifier stricter (CAUTION: A value of 1 will set all predictions as none)", interactive=True)
# with gr.Row():
# with gr.Column():
# gr.Markdown("""### Flood-related""")
# flood_checkbox_output = gr.CheckboxGroup(label="Select ONLY incorrect classifications", interactive=True)
# with gr.Column():
# gr.Markdown("""### Fire-related""")
# fire_checkbox_output = gr.CheckboxGroup(label="Select ONLY incorrect classifications", interactive=True)
# with gr.Column():
# gr.Markdown("""### None""")
# none_checkbox_output = gr.CheckboxGroup(label="Select ONLY incorrect classifications", interactive=True)
# with gr.Row():
# with gr.Column(scale=5):
# gr.Markdown(r"""
# Accuracy: the model's ability to make correct predictions.
# It is the fraction of correct predictions out of the total predictions.
# $$
# \text{Accuracy} = \frac{\text{Correct predictions}}{\text{All predictions}} * 100
# $$
# Model Confidence: the mean probability of each case
# belonging to its assigned class. A value of 1 is best.
# """, latex_delimiters=[{ "left": "$$", "right": "$$", "display": True }])
# gr.Markdown("\n\n\n")
# model_confidence = gr.Number(label="Model Confidence")
# with gr.Column(scale=5):
# correct = gr.Number(label="Number of correct classifications")
# incorrect = gr.Number(label="Number of incorrect classifications")
# accuracy = gr.Number(label="Model Accuracy (%)")
# ETCE_accuracy_button = gr.Button("Calculate Accuracy")
# download_csv = gr.DownloadButton(visible=False)
# num_posts = gr.Number(visible=False)
# data = gr.DataFrame(visible=False)
# data_eval = gr.DataFrame(visible=False)
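# JS template that embeds a tweet by ID inside a placeholder div; "<=CONTAINER-NAME=>"
# is substituted per tab before the snippet is attached to a change event.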
createEmbedding_js = """ (x) =>
{
reloadTwitterWidgets();
const tweetContainer = document.getElementById("<=CONTAINER-NAME=>");
tweetContainer.innerHTML = "";
twttr.widgets.createTweet(x,tweetContainer,{theme: 'dark', dnt: true, align: 'center'});
}
"""
# Event listeners for the Event Type Classification tab
T_predict_button.click(
load_and_classify_csv_dataframe,
inputs=[T_file_input, T_text_field, T_event_model, T_threshold],
outputs=[T_data, T_data_ss_state, T_data_filter, QA_addqry_button, QA_run_button]
)
T_data.select(T_on_select, None, T_tweetID)
T_tweetID.change(fn=None, inputs=T_tweetID, outputs=None, js=createEmbedding_js.replace("<=CONTAINER-NAME=>", "tweet-container"))
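# Filter the displayed dataframe by the selected label, its "Not-" complement, or "All".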
@T_data_filter.input(inputs=[T_data_ss_state, T_data_filter], outputs=T_data)
def filter_df(df, label_filter):
if label_filter == "All":
result_df = df.copy()
elif label_filter.startswith("Not"):
result_df = df[df["event_label"]!=label_filter.split('-')[1]].copy()
else:
result_df = df[df["event_label"]==label_filter].copy()
return result_df
# Button clicks ETC Eval
# ETCE_predict_button.click(
# load_and_classify_csv,
# inputs=[file_input, text_field, event_model, threshold],
# outputs=[flood_checkbox_output, fire_checkbox_output, none_checkbox_output, model_confidence, num_posts, data, QA_addqry_button, QA_run_button])
# ETCE_accuracy_button.click(
# calculate_accuracy,
# inputs=[flood_checkbox_output, fire_checkbox_output, none_checkbox_output, num_posts, text_field, data],
# outputs=[incorrect, correct, accuracy, data_eval, download_csv])
# Button clicks QA
QA_addqry_button.click(add_query, inputs=[query_inp, queries_state], outputs=[selected_queries, queries_state])
QA_run_button.click(qa_summarise,
inputs=[selected_queries, qa_llm_model, T_text_field, response_lang, T_data_ss_state],
outputs=[hsummary, qa_df])
qa_df.select(T_on_select, None, qa_tweetID)
qa_tweetID.change(fn=None, inputs=qa_tweetID, outputs=None, js=createEmbedding_js.replace("<=CONTAINER-NAME=>", "tweet-container2"))
# Event listener for single text classification
text_to_classify.submit(
single_classification,
inputs=[text_to_classify, model_sing_classify, threshold_sing_classify],
outputs=[classification, classification_score])
demo.launch() |