# Hugging Face Space (status when this page was captured: Sleeping)
import gradio as gr | |
import torch | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
import numpy as np | |
import plotly.graph_objects as go | |
# Selectable checkpoints. Each key is the identifier handed to
# ``from_pretrained``; the value is a display name (not referenced by the
# code visible in this file, which shows the keys in the dropdown).
MODEL_OPTIONS = {
    "waleko/roberta-arxiv-tags": "RoBERTa Arxiv Tags",
}
def load_model(model_name):
    """Load the tokenizer and sequence-classification model for ``model_name``.

    Args:
        model_name: Identifier passed unchanged to both ``from_pretrained``
            calls.

    Returns:
        Tuple ``(model, tokenizer)`` -- note the model comes first even
        though the tokenizer is loaded first.
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    classifier = AutoModelForSequenceClassification.from_pretrained(model_name)
    return classifier, tok
# Module-level cache holding the most recently loaded checkpoint.
current_model = None
current_tokenizer = None
# Identifier the cached pair was loaded from; None until the first load.
current_model_name = None


def get_model_and_tokenizer(model_name):
    """Return the ``(model, tokenizer)`` pair for ``model_name``, loading on demand.

    The pair is cached in module globals so repeated requests for the same
    checkpoint skip ``load_model``. A request for a different ``model_name``
    reloads and replaces the cached pair instead of returning the stale one.

    Args:
        model_name: Identifier forwarded to ``load_model``.

    Returns:
        Tuple ``(model, tokenizer)`` as produced by ``load_model``.
    """
    global current_model, current_tokenizer, current_model_name
    cache_is_stale = (
        current_model is None
        or current_tokenizer is None
        or current_model_name != model_name
    )
    if cache_is_stale:
        # Load into locals first so a failed load leaves the old cache intact.
        model, tokenizer = load_model(model_name)
        current_model, current_tokenizer = model, tokenizer
        current_model_name = model_name
    return current_model, current_tokenizer
def create_visualization(probs, labels):
    """Build a donut chart of the given probabilities.

    Args:
        probs: Iterable of probabilities, one per entry in ``labels``.
        labels: Iterable of slice labels (expected to match ``probs`` in
            length; this is not checked here).

    Returns:
        A Plotly ``Figure`` containing a single pie trace. When the
        probabilities sum to noticeably less than 1, an extra ``'Others'``
        slice carries the remaining mass.
    """
    # Copy into plain lists so any iterable (list, tuple, ndarray) works and
    # the caller's data is never mutated.
    values = [float(p) for p in probs]
    names = list(labels)

    # Computed once; the tolerance keeps float rounding (a total such as
    # 0.9999999) from producing an empty 'Others' slice in the legend.
    remainder = 1.0 - sum(values)
    if remainder > 1e-6:
        values.append(remainder)
        names.append('Others')

    return go.Figure(data=[go.Pie(
        labels=names,
        values=values,
        textinfo='percent',
        textposition='inside',
        hole=.3,
        showlegend=True
    )])
def classify_text(title, abstract, model_name):
    """Predict tags for a paper and render them for the UI.

    The title and abstract are combined into one prompt, classified, and the
    highest-probability labels are kept until their cumulative probability
    reaches 0.95.

    Args:
        title: Paper title; may be empty or ``None`` if ``abstract`` is given.
        abstract: Paper abstract; may be empty or ``None`` if ``title`` is
            given.
        model_name: Identifier passed to ``get_model_and_tokenizer``.

    Returns:
        Tuple ``(html, figure)``: an HTML snippet listing the selected tags
        (the top one in bold) and the figure from ``create_visualization``.
        If both inputs are empty or whitespace-only, returns
        ``(error_message, None)`` instead.
    """
    # Normalize first so whitespace-only fields count as missing.
    title = (title or '').strip()
    abstract = (abstract or '').strip()
    if not title and not abstract:
        return "Error: At least one of title or abstract must be provided.", None

    model, tokenizer = get_model_and_tokenizer(model_name)
    text = 'Title: ' + title + '\n\nAbstract: ' + abstract

    # Cap at the limit the tokenizer reports so long abstracts are truncated
    # instead of overflowing the model's position embeddings.
    # NOTE(review): assumes tokenizer.model_max_length reflects the
    # checkpoint's real limit -- confirm for each entry in MODEL_OPTIONS.
    max_length = min(1024, tokenizer.model_max_length)
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length)

    with torch.no_grad():
        logits = model(**inputs).logits
        # Single input, so take row 0 and normalize across the label axis.
        probs = torch.nn.functional.softmax(logits[0], dim=0)
    # .cpu() is a no-op on CPU tensors and is required before .numpy() otherwise.
    probs = probs.cpu().numpy()

    order = np.argsort(probs)[::-1]  # label indices, most probable first
    ranked = probs[order]
    # argmax over a boolean array yields the first True, i.e. the shortest
    # prefix whose cumulative mass reaches 0.95. It yields 0 when nothing is
    # True, so k is always at least 1.
    k = int(np.argmax(np.cumsum(ranked) >= 0.95)) + 1

    id2label = model.config.id2label
    tags = [id2label[int(idx)] for idx in order[:k]]

    compact_pred = f'<span style="font-weight: 800;">{tags[0]}</span>' + (
        f" {' '.join(tags[1:])}" if len(tags) > 1 else ""
    )
    viz_data = create_visualization(ranked[:k], tags)
    html_output = f"""
    <div>
    <h3>Predicted Tags</h3>
    <p>{compact_pred}</p>
    </div>
    """
    return html_output, viz_data
# Gradio UI: inputs in the left column, results in the right column.
# NOTE(review): the nesting below was reconstructed from a source whose
# indentation had been stripped; confirm that the button belongs after the
# row rather than inside the input column.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # Arxiv Tags Classification
    Classify academic papers into arXiv categories using state-of-the-art language models.
    """)

    with gr.Row():
        # Left column: model choice and the paper text.
        with gr.Column(scale=1):
            model_dropdown = gr.Dropdown(
                label="Select Model",
                info="Choose the model for classification",
                choices=list(MODEL_OPTIONS),
                value=list(MODEL_OPTIONS)[0],  # preselect the first entry
            )
            title_input = gr.Textbox(
                label="Title",
                lines=1,
                placeholder="Enter paper title (optional if abstract is provided)",
            )
            abstract_input = gr.Textbox(
                label="Abstract",
                lines=5,
                placeholder="Enter paper abstract (optional if title is provided)",
            )

        # Right column: rendered tags and the probability chart.
        with gr.Column(scale=1):
            output_html = gr.HTML(label="Predicted Tags")
            output_plot = gr.Plot(label="Probability Distribution", show_label=True)

    # Order matches the positional parameters of classify_text.
    inputs = [title_input, abstract_input, model_dropdown]
    btn = gr.Button("Classify", variant="primary")
    btn.click(
        fn=classify_text,
        inputs=inputs,
        outputs=[output_html, output_plot],
    )

if __name__ == "__main__":
    demo.launch()