Spaces:
Sleeping
Sleeping
import os | |
import time | |
import gradio as gr | |
import pandas as pd | |
from classifier import classify | |
from statistics import mean | |
# from genra_incremental import GenraPipeline | |
from qa_summary import generate_answer | |
# Access token (presumably a Hugging Face token, per the variable name) that is
# passed to ``classify`` for every post. It is read once at import time, so if
# HF_TOKEN is unset, importing this module raises KeyError and the app never starts.
HFTOKEN = os.environ["HF_TOKEN"]
def load_and_classify_csv(file, text_field, event_model):
    """Read an uploaded CSV/TSV file and classify every post in it.

    Args:
        file: Uploaded file object from ``gr.File``. Only its ``.name`` (a
            filesystem path) is used.
        text_field: Name of the column holding the post text.
        event_model: Model identifier forwarded to ``classify``.

    Returns:
        A tuple ``(flood_related, fire_related, not_related, model_confidence,
        num_posts, df)``:

        - The first three items are ``gr.CheckboxGroup`` updates listing the
          posts labelled ``"flood"``, ``"fire"`` and ``"none"``.
        - ``model_confidence`` is the mean classifier score, rounded to
          5 decimal places.
        - ``num_posts`` is the number of rows.
        - ``df`` is the input table with ``model_label`` and ``model_score``
          columns appended.

    Raises:
        gr.Error: if no file was uploaded, no model was selected, the text
            column is missing, or the file contains no rows.
    """
    if file is None:
        raise gr.Error("Error: Please upload a CSV or TSV file first.")
    if not event_model:
        raise gr.Error("Error: Please select a classification model.")

    filepath = file.name
    # Choose the parser from the real (case-insensitive) extension. A substring
    # test would misread names such as "DATA.CSV" or "export.csv.tsv".
    if filepath.lower().endswith(".csv"):
        df = pd.read_csv(filepath)
    else:
        df = pd.read_table(filepath)

    if text_field not in df.columns:
        raise gr.Error(f"Error: Text column '{text_field}' not found in the uploaded file.")
    if df.empty:
        # mean() below would raise StatisticsError on an empty score list.
        raise gr.Error("Error: The uploaded file contains no posts.")

    labels, scores = [], []
    for post in df[text_field].to_list():
        res = classify(post, event_model, HFTOKEN)
        labels.append(res["event"])
        scores.append(res["score"])
    df["model_label"] = labels
    df["model_score"] = scores

    model_confidence = round(mean(scores), 5)

    def _posts_labelled(label):
        # Post texts whose predicted label equals ``label``, in file order.
        return df.loc[df["model_label"] == label, text_field].to_list()

    fire_related = gr.CheckboxGroup(choices=_posts_labelled("fire"))
    flood_related = gr.CheckboxGroup(choices=_posts_labelled("flood"))
    not_related = gr.CheckboxGroup(choices=_posts_labelled("none"))
    return flood_related, fire_related, not_related, model_confidence, len(df), df
def calculate_accuracy(flood_selections, fire_selections, none_selections, num_posts, text_field, data_df):
    """Score the classifier against the posts the user marked as incorrect.

    Every post the user ticked in any of the three checkbox groups counts as
    misclassified. Every other post counts as correct.

    Args:
        flood_selections, fire_selections, none_selections: Lists of post
            texts the user ticked as incorrectly classified. Any of them may
            be None.
        num_posts: Post count produced by ``load_and_classify_csv``. It is
            kept for compatibility with the event wiring, but the total is
            taken from ``data_df`` so that it always agrees with the rows
            being evaluated.
        text_field: Name of the column holding the post text.
        data_df: Classified table produced by ``load_and_classify_csv``.

    Returns:
        ``(incorrect, correct, accuracy, data_df, download_button)``:

        - ``accuracy`` is a percentage.
        - ``data_df`` is a copy of the input with a ``model_eval`` column of
          ``"correct"``/``"incorrect"``.
        - ``download_button`` is a visible ``gr.DownloadButton`` pointing at
          the exported CSV.

    Raises:
        gr.Error: if no classified data is available yet.
    """
    import tempfile

    if data_df is None or text_field not in data_df.columns or data_df.empty:
        raise gr.Error("Error: Run 'Start Prediction' before calculating accuracy.")

    data_df = data_df.copy()  # do not mutate the caller's frame
    # Build a set so each membership check is O(1) instead of a list scan.
    selected = set(flood_selections or []) | set(fire_selections or []) | set(none_selections or [])
    data_df["model_eval"] = [
        "incorrect" if post in selected else "correct"
        for post in data_df[text_field].to_list()
    ]

    # Count per row so the totals match the exported model_eval column,
    # even when the same post text appears more than once.
    incorrect = int((data_df["model_eval"] == "incorrect").sum())
    total = len(data_df)
    correct = total - incorrect
    accuracy = (correct / total) * 100

    # Use a per-call temporary directory so concurrent users do not
    # overwrite, or download, each other's export.
    out_path = os.path.join(tempfile.mkdtemp(), "output.csv")
    data_df.to_csv(out_path)
    return incorrect, correct, accuracy, data_df, gr.DownloadButton(label="Download CSV", value=out_path, visible=True)
# Queries offered the first time the QA tab is opened. This is a tuple so that
# no session's history list can ever mutate the shared default.
DEFAULT_QUERIES = (
    "What areas are being evacuated?",
    "What areas are predicted to be impacted?",
    "What areas are without power?",
    "What barriers are hindering response efforts?",
    "What events have been canceled?",
    "What preparations are being made?",
    "What regions have announced a state of emergency?",
    "What roads are blocked / closed?",
    "What services have been closed?",
    "What warnings are currently in effect?",
    "Where are emergency services deployed?",
    "Where are emergency services needed?",
    "Where are evacuations needed?",
    "Where are people needing rescued?",
    "Where are recovery efforts taking place?",
    "Where has building or infrastructure damage occurred?",
    "Where has flooding occured?",  # this comma was missing: the two queries were merged
    "Where are volunteers being requested?",
    "Where has road damage occured?",
    "What area has the wildfire burned?",
    "Where have homes been damaged or destroyed?",
)


def init_queries(history):
    """Populate the query checkboxes, falling back to DEFAULT_QUERIES.

    Args:
        history: The session's current query list from ``gr.State``. It may
            be None or empty.

    Returns:
        ``(checkbox_update, history)``, where ``history`` is the unchanged
        list if it was non-empty, otherwise a fresh list copy of
        ``DEFAULT_QUERIES``.
    """
    history = history or list(DEFAULT_QUERIES)
    return gr.CheckboxGroup(choices=history), history
def add_query(to_add, history):
    """Append a custom query to the session's list and refresh the checkboxes.

    Surrounding whitespace is stripped. Blank input and duplicates are ignored.

    Args:
        to_add: Text typed by the user. It may be None or empty.
        history: The session's current query list from ``gr.State``. It may
            be None.

    Returns:
        ``(checkbox_update, history)`` with the (possibly extended) list.
    """
    history = list(history or [])
    query = (to_add or "").strip()
    if query and query not in history:
        history.append(query)
    return gr.CheckboxGroup(choices=history), history
# def qa_process(selected_queries, qa_llm_model, aggregator, | |
# batch_size, topk, text_field, data_df): | |
# emb_model = 'multi-qa-mpnet-base-dot-v1' | |
# contexts = [] | |
# queries_df = pd.DataFrame({'id':[j for j in range(len(selected_queries))],'query': selected_queries}) | |
# qa_input_df = data_df[data_df["model_label"] != "none"].reset_index() | |
# tweets_df = qa_input_df[[text_field]] | |
# tweets_df.reset_index(inplace=True) | |
# tweets_df.rename(columns={"index": "order", text_field: "text"},inplace=True) | |
# gr.Info("Loading GENRA pipeline....") | |
# genra = GenraPipeline(qa_llm_model, emb_model, aggregator, contexts) | |
# gr.Info("Waiting for data...") | |
# batches = [tweets_df[i:i+batch_size] for i in range(0,len(tweets_df),batch_size)] | |
# genra_answers = [] | |
# summarize_batch = True | |
# for batch_number, tweets in enumerate(batches): | |
# gr.Info(f"Populating index for batch {batch_number}") | |
# genra.qa_indexer.index_dataframe(tweets) | |
# gr.Info(f"Performing retrieval for batch {batch_number}") | |
# genra.retrieval(batch_number, queries_df, topk, summarize_batch) | |
# gr.Info("Processed all batches!") | |
# gr.Info("Getting summary...") | |
# summary = genra.summarize_history(queries_df) | |
# gr.Info("Preparing results...") | |
# results = genra.answers_store | |
# final_answers, q_a = [], [] | |
# for q, g_answers in results.items(): | |
# for answer in g_answers: | |
# final_answers.append({'question':q, "tweets":answer['tweets'], "batch":answer['batch_number'], "summary":answer['summary'] }) | |
# for t in answer['tweets']: | |
# q_a.append((q,t)) | |
# answers_df = pd.DataFrame.from_dict(final_answers) | |
# q_a = list(set(q_a)) | |
# q_a_df = pd.DataFrame(q_a, columns =['question', 'tweet']) | |
# q_a_df = q_a_df.sort_values(by=["question"], ascending=False) | |
# return q_a_df, answers_df, summary | |
def qa_summarise(selected_queries, qa_llm_model, text_field, response_lang, data_df):
    """Summarise the relevant (non-"none") posts with respect to the selected queries.

    Args:
        selected_queries: Query strings ticked by the user. The first one is
            passed to ``generate_answer`` as the primary query, and the full
            list is passed as well.
        qa_llm_model: LLM identifier forwarded to ``generate_answer``.
        text_field: Name of the column holding the post text.
        response_lang: Response language forwarded to ``generate_answer``.
        data_df: Classified table with a ``model_label`` column.

    Returns:
        ``(summary, doc_df)``:

        - ``summary`` is whatever ``generate_answer`` returns.
        - ``doc_df`` lists the posts that were sent to it, in a ``number``
          column (starting at 1) and a ``text`` column.

    Raises:
        gr.Error: if no query is selected, classification has not been run,
            or no post was classified as relevant.
    """
    if not selected_queries:
        raise gr.Error("Error: Select at least one query before starting QA.")
    if data_df is None or "model_label" not in data_df.columns or text_field not in data_df.columns:
        raise gr.Error("Error: Run 'Start Prediction' in the classification tab first.")

    texts = data_df.loc[data_df["model_label"] != "none", text_field].to_list()
    if not texts:
        raise gr.Error("Error: No posts were classified as relevant, so there is nothing to summarise.")

    summary = generate_answer(qa_llm_model,
                              texts,
                              selected_queries[0],
                              selected_queries,
                              response_lang,
                              mode="multi_summarize")
    doc_df = pd.DataFrame({"number": list(range(1, len(texts) + 1)), "text": texts})
    return summary, doc_df
with gr.Blocks() as demo:
    event_models = [
        "jayebaku/distilbert-base-multilingual-cased-crexdata-relevance-classifier",
        "jayebaku/distilbert-base-multilingual-cased-weather-classifier-2",
        "jayebaku/twitter-xlm-roberta-base-crexdata-relevance-classifier",
        "jayebaku/twhin-bert-base-crexdata-relevance-classifier",
    ]

    # Tab 1: classify uploaded posts and let the user mark misclassifications.
    with gr.Tab("Event Type Classification"):
        gr.Markdown(
            """
            # T4.5 Relevance Classifier Demo
            This is a demo created to explore floods and wildfire classification in social media posts.\n
            Usage:\n
            - Upload a .csv or .tsv data file (must contain a text column with social media posts).\n
            - Next, type the name of the text column.\n
            - Then, choose a BERT classifier model from the drop down.\n
            - Finally, click the 'Start Prediction' button.\n
            Evaluation:\n
            - To evaluate the model's accuracy select the INCORRECT classifications using the checkboxes in front of each post.\n
            - Then, click on the 'Calculate Accuracy' button.\n
            - Then, click on the 'Download CSV' button to get the classifications and evaluation data as a .csv file.
            """)
        with gr.Row(equal_height=True):
            with gr.Column(scale=4):
                file_input = gr.File(label="Upload CSV or TSV File", file_types=['.tsv', '.csv'])
            with gr.Column(scale=6):
                text_field = gr.Textbox(label="Text field name", value="tweet_text")
                event_model = gr.Dropdown(event_models, label="Select classification model")
                predict_button = gr.Button("Start Prediction")

        with gr.Row():  # XXX confirm this is not a problem later --equal_height=True
            with gr.Column():
                gr.Markdown("""### Flood-related""")
                flood_checkbox_output = gr.CheckboxGroup(label="Select ONLY incorrect classifications", interactive=True)
            with gr.Column():
                gr.Markdown("""### Fire-related""")
                fire_checkbox_output = gr.CheckboxGroup(label="Select ONLY incorrect classifications", interactive=True)
            with gr.Column():
                gr.Markdown("""### None""")
                none_checkbox_output = gr.CheckboxGroup(label="Select ONLY incorrect classifications", interactive=True)

        with gr.Row(equal_height=True):
            with gr.Column(scale=5):
                gr.Markdown(r"""
                Accuracy: is the model's ability to make correct predictions.
                It is the fraction of correct predictions out of the total predictions.
                $$
                \text{Accuracy} = \frac{\text{Correct predictions}}{\text{All predictions}} \times 100
                $$
                Model Confidence: is the mean probability of each case
                belonging to their assigned classes. A value of 1 is best.
                """, latex_delimiters=[{"left": "$$", "right": "$$", "display": True}])
                gr.Markdown("\n\n\n")
                model_confidence = gr.Number(label="Model Confidence")
            with gr.Column(scale=5):
                correct = gr.Number(label="Number of correct classifications")
                incorrect = gr.Number(label="Number of incorrect classifications")
                accuracy = gr.Number(label="Model Accuracy (%)")

        accuracy_button = gr.Button("Calculate Accuracy")
        download_csv = gr.DownloadButton(visible=False)
        # Hidden components that carry results from one event to the next.
        num_posts = gr.Number(visible=False)
        data = gr.DataFrame(visible=False)
        data_eval = gr.DataFrame(visible=False)

        predict_button.click(
            load_and_classify_csv,
            inputs=[file_input, text_field, event_model],
            outputs=[flood_checkbox_output, fire_checkbox_output, none_checkbox_output, model_confidence, num_posts, data])
        accuracy_button.click(
            calculate_accuracy,
            inputs=[flood_checkbox_output, fire_checkbox_output, none_checkbox_output, num_posts, text_field, data],
            outputs=[incorrect, correct, accuracy, data_eval, download_csv])

    # Tab 2: summarise the relevant posts against user-selected queries.
    qa_tab = gr.Tab("Question Answering")
    with qa_tab:
        # XXX Add some button disabling here, if the classification process is not completed first XXX
        gr.Markdown(
            """
            # Question Answering Demo
            This section uses RAG to answer questions about the relevant social media posts identified by the relevance classifier\n
            Usage:\n
            - Select queries from the predefined list, or add your own below.\n
            - Parameters for QA can be edited in the 'Parameters' panel.\n
            """)
        with gr.Accordion("Parameters", open=False):
            with gr.Row(equal_height=True):
                with gr.Column():
                    qa_llm_model = gr.Dropdown(["mistral", "solar", "phi3mini"], label="QA model", value="phi3mini", interactive=True)
                    # NOTE: aggregator, batch_size and topk are only consumed by the
                    # disabled GENRA pipeline (qa_process). qa_summarise ignores them.
                    aggregator = gr.Dropdown(["linear", "outrank"], label="Aggregation method", value="linear", interactive=True)
                with gr.Column():
                    batch_size = gr.Slider(50, 500, value=150, step=1, label="Batch size", info="Choose between 50 and 500", interactive=True)
                    topk = gr.Slider(1, 10, value=5, step=1, label="Number of results to retrieve", info="Choose between 1 and 10", interactive=True)
                    response_lang = gr.Dropdown(["english", "german", "catalan", "spanish"], label="Response language", value="english", interactive=True)

        selected_queries = gr.CheckboxGroup(label="Select at least one query using the checkboxes", interactive=True)
        queries_state = gr.State()
        # On every tab selection, refresh the checkboxes from the session's
        # query list. init_queries supplies defaults while that list is empty.
        qa_tab.select(init_queries, inputs=queries_state, outputs=[selected_queries, queries_state])
        query_inp = gr.Textbox(label="Add custom queries like the one above, one at a time")
        addqry_button = gr.Button("Add to queries")
        qa_button = gr.Button("Start QA")
        hsummary = gr.Textbox(label="Summary")
        qa_df = gr.DataFrame()
        # answers_df = gr.DataFrame()
        addqry_button.click(add_query, inputs=[query_inp, queries_state], outputs=[selected_queries, queries_state])
        # qa_button.click(qa_process,
        #                 inputs=[selected_queries, qa_llm_model, aggregator, batch_size, topk, text_field, data],
        #                 outputs=[qa_df, answers_df, hsummary])
        qa_button.click(qa_summarise,
                        inputs=[selected_queries, qa_llm_model, text_field, response_lang, data],
                        outputs=[hsummary, qa_df])

# Only start the server when run as a script, so importing the module
# (e.g. for tests or `gradio` reload mode) has no side effects.
if __name__ == "__main__":
    demo.launch()