import os
import gradio as gr
import pandas as pd
from classifier import classify
from statistics import mean
# from genra_incremental import GenraPipeline
from qa_summary import generate_answer
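# Hugging Face access token, passed to classify() for authenticated model calls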
HFTOKEN = os.environ["HF_TOKEN"]
def load_and_classify_csv(file, text_field, event_model):
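    """Load the uploaded CSV/TSV file, classify each post in `text_field` with the
    selected event model, and group the posts by predicted label for evaluation."""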
filepath = file.name
if ".csv" in filepath:
df = pd.read_csv(filepath)
else:
df = pd.read_table(filepath)
if text_field not in df.columns:
raise gr.Error(f"Error: Enter text column'{text_field}' not in CSV file.")
labels, scores = [], []
for post in df[text_field].to_list():
res = classify(post, event_model, HFTOKEN)
labels.append(res["event"])
scores.append(res["score"])
df["model_label"] = labels
df["model_score"] = scores
model_confidence = round(mean(scores), 5)
    fire_related = gr.CheckboxGroup(choices=df[df["model_label"] == "fire"][text_field].to_list())
    flood_related = gr.CheckboxGroup(choices=df[df["model_label"] == "flood"][text_field].to_list())
    not_related = gr.CheckboxGroup(choices=df[df["model_label"] == "none"][text_field].to_list())
return flood_related, fire_related, not_related, model_confidence, len(df[text_field].to_list()), df
def calculate_accuracy(flood_selections, fire_selections, none_selections, num_posts, text_field, data_df):
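    """Score the model against the user's selections: every selected post counts as
    incorrect, the rest as correct. Saves the annotated dataframe to output.csv
    and exposes it via the download button."""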
posts = data_df[text_field].to_list()
selections = flood_selections + fire_selections + none_selections
    evaluations = []
    for post in posts:
        if post in selections:
            evaluations.append("incorrect")
        else:
            evaluations.append("correct")
    data_df["model_eval"] = evaluations
incorrect = len(selections)
correct = num_posts - incorrect
accuracy = (correct/num_posts)*100
data_df.to_csv("output.csv")
return incorrect, correct, accuracy, data_df, gr.DownloadButton(label=f"Download CSV", value="output.csv", visible=True)
def init_queries(history):
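    """Populate the query checkbox group with a default set of
    disaster-response questions on first visit to the QA tab."""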
history = history or []
if not history:
history = [
"What areas are being evacuated?",
"What areas are predicted to be impacted?",
"What areas are without power?",
"What barriers are hindering response efforts?",
"What events have been canceled?",
"What preparations are being made?",
"What regions have announced a state of emergency?",
"What roads are blocked / closed?",
"What services have been closed?",
"What warnings are currently in effect?",
"Where are emergency services deployed?",
"Where are emergency services needed?",
"Where are evacuations needed?",
"Where are people needing rescued?",
"Where are recovery efforts taking place?",
"Where has building or infrastructure damage occurred?",
"Where has flooding occured?"
"Where are volunteers being requested?",
"Where has road damage occured?",
"What area has the wildfire burned?",
"Where have homes been damaged or destroyed?"]
return gr.CheckboxGroup(choices=history), history
def add_query(to_add, history):
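    """Add a user-written query to the list of selectable queries, skipping duplicates."""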
if to_add not in history:
history.append(to_add)
return gr.CheckboxGroup(choices=history), history
# def qa_process(selected_queries, qa_llm_model, aggregator,
# batch_size, topk, text_field, data_df):
# emb_model = 'multi-qa-mpnet-base-dot-v1'
# contexts = []
# queries_df = pd.DataFrame({'id':[j for j in range(len(selected_queries))],'query': selected_queries})
# qa_input_df = data_df[data_df["model_label"] != "none"].reset_index()
# tweets_df = qa_input_df[[text_field]]
# tweets_df.reset_index(inplace=True)
# tweets_df.rename(columns={"index": "order", text_field: "text"},inplace=True)
# gr.Info("Loading GENRA pipeline....")
# genra = GenraPipeline(qa_llm_model, emb_model, aggregator, contexts)
# gr.Info("Waiting for data...")
# batches = [tweets_df[i:i+batch_size] for i in range(0,len(tweets_df),batch_size)]
# genra_answers = []
# summarize_batch = True
# for batch_number, tweets in enumerate(batches):
# gr.Info(f"Populating index for batch {batch_number}")
# genra.qa_indexer.index_dataframe(tweets)
# gr.Info(f"Performing retrieval for batch {batch_number}")
# genra.retrieval(batch_number, queries_df, topk, summarize_batch)
# gr.Info("Processed all batches!")
# gr.Info("Getting summary...")
# summary = genra.summarize_history(queries_df)
# gr.Info("Preparing results...")
# results = genra.answers_store
# final_answers, q_a = [], []
# for q, g_answers in results.items():
# for answer in g_answers:
# final_answers.append({'question':q, "tweets":answer['tweets'], "batch":answer['batch_number'], "summary":answer['summary'] })
# for t in answer['tweets']:
# q_a.append((q,t))
# answers_df = pd.DataFrame.from_dict(final_answers)
# q_a = list(set(q_a))
# q_a_df = pd.DataFrame(q_a, columns =['question', 'tweet'])
# q_a_df = q_a_df.sort_values(by=["question"], ascending=False)
# return q_a_df, answers_df, summary
def qa_summarise(selected_queries, qa_llm_model, text_field, response_lang, data_df):
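    """Answer the selected queries by summarising all posts the classifier marked
    as relevant (label other than 'none'), in the requested response language."""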
qa_input_df = data_df[data_df["model_label"] != "none"].reset_index()
texts = qa_input_df[text_field].to_list()
summary = generate_answer(qa_llm_model,
texts,
selected_queries[0],
selected_queries,
response_lang,
mode="multi_summarize")
doc_df = pd.DataFrame()
doc_df["number"] = [i+1 for i in range(len(texts))]
doc_df["text"] = texts
return summary, doc_df
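# Gradio interface: an event-type classification tab and a question-answering tab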
with gr.Blocks() as demo:
event_models = ["jayebaku/distilbert-base-multilingual-cased-crexdata-relevance-classifier",
"jayebaku/distilbert-base-multilingual-cased-weather-classifier-2",
"jayebaku/twitter-xlm-roberta-base-crexdata-relevance-classifier",
"jayebaku/twhin-bert-base-crexdata-relevance-classifier"]
with gr.Tab("Event Type Classification"):
gr.Markdown(
"""
# T4.5 Relevance Classifier Demo
            This is a demo created to explore flood and wildfire classification in social media posts.\n
            Usage:\n
            - Upload a .csv or .tsv data file (it must contain a text column with social media posts).\n
            - Next, type the name of the text column.\n
            - Then, choose a BERT classifier model from the dropdown.\n
            - Finally, click the 'Start Prediction' button.\n
            Evaluation:\n
            - To evaluate the model's accuracy, select the INCORRECT classifications using the checkboxes in front of each post.\n
            - Then, click the 'Calculate Accuracy' button.\n
            - Finally, click 'Download CSV' to get the classifications and evaluation data as a .csv file.
""")
with gr.Row(equal_height=True):
with gr.Column(scale=4):
file_input = gr.File(label="Upload CSV or TSV File", file_types=['.tsv', '.csv'])
with gr.Column(scale=6):
text_field = gr.Textbox(label="Text field name", value="tweet_text")
event_model = gr.Dropdown(event_models, label="Select classification model")
predict_button = gr.Button("Start Prediction")
with gr.Row(): # XXX confirm this is not a problem later --equal_height=True
with gr.Column():
gr.Markdown("""### Flood-related""")
flood_checkbox_output = gr.CheckboxGroup(label="Select ONLY incorrect classifications", interactive=True)
with gr.Column():
gr.Markdown("""### Fire-related""")
fire_checkbox_output = gr.CheckboxGroup(label="Select ONLY incorrect classifications", interactive=True)
with gr.Column():
gr.Markdown("""### None""")
none_checkbox_output = gr.CheckboxGroup(label="Select ONLY incorrect classifications", interactive=True)
with gr.Row(equal_height=True):
with gr.Column(scale=5):
gr.Markdown(r"""
                Accuracy: the model's ability to make correct predictions.
                It is the fraction of correct predictions out of the total predictions.
                $$
                \text{Accuracy} = \frac{\text{Correct predictions}}{\text{All predictions}} * 100
                $$
                Model Confidence: the mean probability of the posts
                belonging to their assigned classes. A value of 1 is best.
""", latex_delimiters=[{ "left": "$$", "right": "$$", "display": True }])
gr.Markdown("\n\n\n")
model_confidence = gr.Number(label="Model Confidence")
with gr.Column(scale=5):
correct = gr.Number(label="Number of correct classifications")
incorrect = gr.Number(label="Number of incorrect classifications")
accuracy = gr.Number(label="Model Accuracy (%)")
accuracy_button = gr.Button("Calculate Accuracy")
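        # Hidden components that carry intermediate state between callbacks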
download_csv = gr.DownloadButton(visible=False)
num_posts = gr.Number(visible=False)
data = gr.DataFrame(visible=False)
data_eval = gr.DataFrame(visible=False)
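        # Button wiring: classification and evaluation callbacks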
predict_button.click(
load_and_classify_csv,
inputs=[file_input, text_field, event_model],
outputs=[flood_checkbox_output, fire_checkbox_output, none_checkbox_output, model_confidence, num_posts, data])
accuracy_button.click(
calculate_accuracy,
inputs=[flood_checkbox_output, fire_checkbox_output, none_checkbox_output, num_posts, text_field, data],
outputs=[incorrect, correct, accuracy, data_eval, download_csv])
qa_tab = gr.Tab("Question Answering")
with qa_tab:
# XXX Add some button disabling here, if the classification process is not completed first XXX
gr.Markdown(
"""
# Question Answering Demo
            This section uses RAG to answer questions about the relevant social media posts identified by the relevance classifier.\n
            Usage:\n
            - Select at least one query from the predefined list, or add your own below.\n
            - QA parameters can be edited in the 'Parameters' accordion.\n
""")
with gr.Accordion("Parameters", open=False):
with gr.Row(equal_height=True):
with gr.Column():
qa_llm_model = gr.Dropdown(["mistral", "solar", "phi3mini"], label="QA model", value="phi3mini", interactive=True)
aggregator = gr.Dropdown(["linear", "outrank"], label="Aggregation method", value="linear", interactive=True)
with gr.Column():
batch_size = gr.Slider(50, 500, value=150, step=1, label="Batch size", info="Choose between 50 and 500", interactive=True)
topk = gr.Slider(1, 10, value=5, step=1, label="Number of results to retrieve", info="Choose between 1 and 10", interactive=True)
response_lang = gr.Dropdown(["english", "german", "catalan", "spanish"], label="Response language", value="english", interactive=True)
selected_queries = gr.CheckboxGroup(label="Select at least one query using the checkboxes", interactive=True)
queries_state = gr.State()
qa_tab.select(init_queries, inputs=queries_state, outputs=[selected_queries, queries_state])
        query_inp = gr.Textbox(label="Add custom queries like the ones above, one at a time")
addqry_button = gr.Button("Add to queries")
qa_button = gr.Button("Start QA")
hsummary = gr.Textbox(label="Summary")
qa_df = gr.DataFrame()
# answers_df = gr.DataFrame()
addqry_button.click(add_query, inputs=[query_inp, queries_state], outputs=[selected_queries, queries_state])
# qa_button.click(qa_process,
# inputs=[selected_queries, qa_llm_model, aggregator, batch_size, topk, text_field, data],
# outputs=[qa_df, answers_df, hsummary])
qa_button.click(qa_summarise,
inputs=[selected_queries, qa_llm_model, text_field, response_lang, data],
outputs=[hsummary, qa_df])
demo.launch()