import gradio as gr import torch from transformers import AutoTokenizer, AutoModelForSequenceClassification import numpy as np import plotly.graph_objects as go # Initialize model and tokenizer MODEL_OPTIONS = { "waleko/roberta-arxiv-tags": "RoBERTa Arxiv Tags" } def load_model(model_name): tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForSequenceClassification.from_pretrained(model_name) return model, tokenizer current_model = None current_tokenizer = None def get_model_and_tokenizer(model_name): global current_model, current_tokenizer if current_model is None or current_tokenizer is None: current_model, current_tokenizer = load_model(model_name) return current_model, current_tokenizer def create_visualization(probs, labels): return go.Figure(data=[go.Pie( labels=labels + ['Others'] if sum(probs) < 1 else labels, values=list(probs) + [1 - sum(probs)] if sum(probs) < 1 else list(probs), textinfo='percent', textposition='inside', hole=.3, showlegend=True )]) def classify_text(title, abstract, model_name): if not title and not abstract: return "Error: At least one of title or abstract must be provided.", None model, tokenizer = get_model_and_tokenizer(model_name) text = 'Title: ' + (title or '') + '\n\nAbstract: ' + (abstract or '') inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=1024) with torch.no_grad(): outputs = model(**inputs) logits = outputs.logits probs = torch.nn.functional.softmax(logits[0], dim=0) probs = probs.numpy() sorted_idx = np.argsort(probs)[::-1] sorted_probs = probs[sorted_idx] cumsum = np.cumsum(sorted_probs) k = 1 if sorted_probs[0] < 0.95: k = np.argmax(cumsum >= 0.95) + 1 id2label = model.config.id2label tags = [id2label[idx] for idx in sorted_idx[:k]] compact_pred = f'{tags[0]}' + (f" {' '.join(tags[1:])}" if len(tags) > 1 else "") viz_data = create_visualization( sorted_probs[:k], [id2label[idx] for idx in sorted_idx[:k]] ) html_output = f"""
{compact_pred}