waleko committed on
Commit
ff3f523
·
1 Parent(s): fdeefe0

init commit

Browse files
Files changed (2) hide show
  1. app.py +107 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
+ import numpy as np
5
+ import plotly.graph_objects as go
6
+
7
# Models offered in the UI, keyed by the identifier handed to
# ``from_pretrained``; each value is a display name for that model.
MODEL_OPTIONS = {
    "waleko/roberta-arxiv-tags": "RoBERTa Arxiv Tags",
}
11
+
12
def load_model(model_name):
    """Load the classifier and its tokenizer for ``model_name``.

    Args:
        model_name: Identifier passed unchanged to ``from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # The tokenizer is fetched first, then the model, as in the original.
    tok = AutoTokenizer.from_pretrained(model_name)
    classifier = AutoModelForSequenceClassification.from_pretrained(model_name)
    return classifier, tok
16
+
17
# Process-wide cache of the most recently loaded model/tokenizer pair.
current_model = None
current_tokenizer = None
# Name the cached pair was loaded from; lets the cache notice a model switch.
current_model_name = None


def get_model_and_tokenizer(model_name):
    """Return the model and tokenizer for ``model_name``, loading on demand.

    The most recently loaded pair is kept in module-level globals. It is
    reused only while the same ``model_name`` is requested; asking for a
    different name reloads and replaces the cached pair, so callers never
    receive a model other than the one they asked for.

    Args:
        model_name: Identifier forwarded to ``load_model``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    global current_model, current_tokenizer, current_model_name
    cache_is_stale = (
        current_model is None
        or current_tokenizer is None
        or current_model_name != model_name
    )
    if cache_is_stale:
        current_model, current_tokenizer = load_model(model_name)
        current_model_name = model_name
    return current_model, current_tokenizer
25
+
26
def create_visualization(probs, labels):
    """Build a donut chart of the given tag probabilities.

    Args:
        probs: Probabilities of the selected tags, in display order.
        labels: List of tag names, one per entry of ``probs``.

    Returns:
        A Plotly ``Figure`` holding a single pie trace. When ``probs`` sums
        to less than 1, an extra "Others" slice carries the remainder.
    """
    total = sum(probs)
    slice_values = list(probs)
    slice_labels = labels
    if total < 1:
        # Build new lists so the caller's ``labels`` is never mutated.
        slice_labels = labels + ['Others']
        slice_values = slice_values + [1 - total]

    donut = go.Pie(
        labels=slice_labels,
        values=slice_values,
        textinfo='percent',
        textposition='inside',
        hole=.3,
        showlegend=True,
    )
    return go.Figure(data=[donut])
35
+
36
def classify_text(title, abstract, model_name):
    """Predict tags for a paper from its title and/or abstract.

    Args:
        title: Paper title; may be empty or ``None`` if ``abstract`` is given.
        abstract: Paper abstract; may be empty or ``None`` if ``title`` is
            given.
        model_name: Identifier forwarded to ``get_model_and_tokenizer``.

    Returns:
        A ``(html, figure)`` tuple. ``html`` lists the predicted tags, most
        probable first with the top tag in bold; ``figure`` is the result of
        ``create_visualization`` for those tags. When neither field contains
        any non-whitespace text, the tuple is an error message and ``None``.
    """
    # Whitespace-only fields carry no information, so treat them as missing.
    if not (title or '').strip() and not (abstract or '').strip():
        return "Error: At least one of title or abstract must be provided.", None

    model, tokenizer = get_model_and_tokenizer(model_name)
    text = 'Title: ' + (title or '') + '\n\nAbstract: ' + (abstract or '')

    # Never request more tokens than the tokenizer reports the model accepts;
    # standard RoBERTa checkpoints, for instance, take fewer than 1024.
    max_length = min(1024, tokenizer.model_max_length)
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length)
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.nn.functional.softmax(logits[0], dim=0).numpy()

    # Rank classes from most to least probable.
    sorted_idx = np.argsort(probs)[::-1]
    sorted_probs = probs[sorted_idx]

    # Keep the shortest prefix whose cumulative probability reaches 0.95
    # (a single tag when the top class alone already does).
    k = int(np.argmax(np.cumsum(sorted_probs) >= 0.95)) + 1

    id2label = model.config.id2label
    tags = [id2label[idx] for idx in sorted_idx[:k]]

    compact_pred = f'<span style="font-weight: 800;">{tags[0]}</span>'
    if len(tags) > 1:
        compact_pred += f" {' '.join(tags[1:])}"

    viz_data = create_visualization(sorted_probs[:k], tags)
    html_output = f"""
<div>
<h3>Predicted Tags</h3>
<p>{compact_pred}</p>
</div>
"""
    return html_output, viz_data
68
+
69
# Create Gradio interface
# NOTE(review): the nesting below was reconstructed from flattened text; the
# button is assumed to sit under the two-column row — confirm.
_model_ids = list(MODEL_OPTIONS)

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# Arxiv Tags Classification
Classify academic papers into arXiv categories using state-of-the-art language models.
""")

    with gr.Row():
        # Left column: model choice and the paper text.
        with gr.Column(scale=1):
            model_dropdown = gr.Dropdown(
                choices=_model_ids,
                value=_model_ids[0],
                label="Select Model",
                info="Choose the model for classification",
            )
            title_input = gr.Textbox(
                lines=1,
                label="Title",
                placeholder="Enter paper title (optional if abstract is provided)",
            )
            abstract_input = gr.Textbox(
                lines=5,
                label="Abstract",
                placeholder="Enter paper abstract (optional if title is provided)",
            )
        # Right column: predicted tags and their probability chart.
        with gr.Column(scale=1):
            output_html = gr.HTML(label="Predicted Tags")
            output_plot = gr.Plot(label="Probability Distribution", show_label=True)

    # Order matches the positional parameters of classify_text.
    inputs = [title_input, abstract_input, model_dropdown]
    btn = gr.Button("Classify", variant="primary")
    btn.click(
        fn=classify_text,
        inputs=inputs,
        outputs=[output_html, output_plot],
    )

if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ gradio>=4.0.0
2
+ torch>=2.0.0
3
+ transformers>=4.30.0
4
+ numpy>=1.24.0
5
+ plotly