File size: 3,759 Bytes
ff3f523
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import numpy as np
import plotly.graph_objects as go

# Models selectable in the UI: Hugging Face model id -> human-readable name.
# The ids (keys) are what get passed to from_pretrained.
MODEL_OPTIONS = dict([
    ("waleko/roberta-arxiv-tags", "RoBERTa Arxiv Tags"),
])

def load_model(model_name):
    """Load the tokenizer and sequence-classification model for ``model_name``.

    Both are obtained with Hugging Face ``from_pretrained``; the tokenizer is
    loaded first.

    Returns:
        ``(model, tokenizer)`` -- note the model comes first.
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    return AutoModelForSequenceClassification.from_pretrained(model_name), tok

# Single-entry cache for the most recently loaded model/tokenizer pair and the
# model id they were loaded from.
current_model = None
current_tokenizer = None
current_model_name = None

def get_model_and_tokenizer(model_name):
    """Return ``(model, tokenizer)`` for ``model_name``, loading on demand.

    The last loaded pair is cached in module globals. It is reused only when
    the same ``model_name`` is requested again; asking for a different model
    replaces the cached pair (previously the first model loaded was returned
    for every name).
    """
    global current_model, current_tokenizer, current_model_name
    if (
        current_model is None
        or current_tokenizer is None
        or current_model_name != model_name
    ):
        current_model, current_tokenizer = load_model(model_name)
        current_model_name = model_name
    return current_model, current_tokenizer

def create_visualization(probs, labels, others_tol=1e-6):
    """Build a donut (pie with a hole) chart of the selected tag probabilities.

    Args:
        probs: Probabilities of the selected tags, aligned with ``labels``.
        labels: Tag names for ``probs``; any sequence (list, tuple, array).
        others_tol: If the probability mass not covered by ``probs`` is at or
            below this value, it is treated as floating-point noise and no
            "Others" slice is drawn.

    Returns:
        A ``plotly.graph_objects.Figure`` with one pie trace. When ``probs``
        sums to less than 1 by more than ``others_tol``, an extra "Others"
        slice holds the remaining mass.
    """
    slice_labels = list(labels)  # copy: never mutate the caller's sequence
    slice_values = [float(p) for p in probs]
    remainder = 1.0 - sum(slice_values)  # computed once, reused below
    if remainder > others_tol:
        slice_labels.append('Others')
        slice_values.append(remainder)
    return go.Figure(data=[go.Pie(
        labels=slice_labels,
        values=slice_values,
        textinfo='percent',
        textposition='inside',
        hole=.3,
        showlegend=True
    )])

def classify_text(title, abstract, model_name):
    """Predict arXiv tags for a paper and build the two UI outputs.

    The title and abstract are combined into one "Title: ... / Abstract: ..."
    prompt and classified with the selected model. The reported tags are the
    smallest set of top-ranked tags whose softmax probabilities add up to at
    least 95% (just the top tag if it alone reaches 95%).

    Args:
        title: Paper title; may be empty or ``None``.
        abstract: Paper abstract; may be empty or ``None``.
        model_name: Model id forwarded to ``get_model_and_tokenizer``.

    Returns:
        ``(html, figure)``: an HTML snippet listing the selected tags (the
        most likely one in bold) and the chart from ``create_visualization``.
        If both inputs are empty or whitespace-only, returns an error string
        and ``None`` instead.
    """
    # Whitespace-only fields count as missing.
    title = (title or '').strip()
    abstract = (abstract or '').strip()
    if not title and not abstract:
        return "Error: At least one of title or abstract must be provided.", None

    model, tokenizer = get_model_and_tokenizer(model_name)
    text = 'Title: ' + title + '\n\nAbstract: ' + abstract
    # Never truncate beyond the tokenizer's declared limit: a fixed 1024
    # overflows the position embeddings of 512-token models (e.g. RoBERTa-base)
    # and crashes on long abstracts.
    max_length = min(1024, tokenizer.model_max_length)
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length)
    with torch.no_grad():
        logits = model(**inputs).logits
    # Softmax over the class logits of the single input in the batch.
    probs = torch.nn.functional.softmax(logits[0], dim=0).cpu().numpy()

    # Rank classes by probability, highest first.
    sorted_idx = np.argsort(probs)[::-1]
    sorted_probs = probs[sorted_idx]
    # Shortest prefix of ranked tags whose cumulative probability reaches 95%.
    k = 1
    if sorted_probs[0] < 0.95:
        k = int(np.argmax(np.cumsum(sorted_probs) >= 0.95)) + 1
    id2label = model.config.id2label
    tags = [id2label[int(idx)] for idx in sorted_idx[:k]]
    # Bold the most likely tag; the rest follow separated by spaces.
    compact_pred = ' '.join([f'<span style="font-weight: 800;">{tags[0]}</span>', *tags[1:]])
    viz_data = create_visualization(sorted_probs[:k], tags)
    html_output = f"""
    <div>
        <h3>Predicted Tags</h3>
        <p>{compact_pred}</p>
    </div>
    """
    return html_output, viz_data

# Create Gradio interface
# Layout: a two-column row (inputs on the left, results on the right) followed
# by a full-width Classify button. Gradio places components in the order they
# are created inside each layout context, so statement order defines the UI.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # Arxiv Tags Classification
    Classify academic papers into arXiv categories using state-of-the-art language models.
    """)
    
    with gr.Row():
        # Left column: model selection and paper inputs.
        with gr.Column(scale=1):
            model_dropdown = gr.Dropdown(
                # NOTE(review): choices are the raw model ids; the display
                # names stored as MODEL_OPTIONS values are not shown anywhere.
                choices=list(MODEL_OPTIONS.keys()),
                # Default to the first configured model.
                value=list(MODEL_OPTIONS.keys())[0],
                label="Select Model",
                info="Choose the model for classification"
            )
            title_input = gr.Textbox(
                lines=1,
                label="Title",
                placeholder="Enter paper title (optional if abstract is provided)"
            )
            abstract_input = gr.Textbox(
                lines=5,
                label="Abstract",
                placeholder="Enter paper abstract (optional if title is provided)"
            )
        # Right column: HTML tag list and probability chart.
        with gr.Column(scale=1):
            output_html = gr.HTML(
                label="Predicted Tags"
            )
            output_plot = gr.Plot(
                label="Probability Distribution",
                show_label=True
            )
    # Order must match classify_text(title, abstract, model_name).
    inputs = [title_input, abstract_input, model_dropdown]
    btn = gr.Button("Classify", variant="primary")
    # classify_text returns (html, figure), routed to these two outputs in order.
    btn.click(fn=classify_text, inputs=inputs, outputs=[output_html, output_plot])

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()