Spaces: Running
Commit 106e870
Parent(s): 133e071
Added model and interface
- app.py +65 -0
- checkpoint-23985/model.safetensors +3 -0
- checkpoint-23985/optimizer.pt +3 -0
- checkpoint-23985/rng_state.pth +3 -0
- checkpoint-23985/scheduler.pt +3 -0
- checkpoint-23985/special_tokens_map.json +7 -0
- checkpoint-23985/tokenizer.json +0 -0
- checkpoint-23985/tokenizer_config.json +57 -0
- checkpoint-23985/trainer_state.json +180 -0
- checkpoint-23985/training_args.bin +3 -0
- checkpoint-23985/vocab.txt +0 -0
- label_mappings.json +136 -0
- model_SingleLabelClassifier.py +42 -0
- requirements.txt +4 -0
app.py
ADDED
@@ -0,0 +1,65 @@
import streamlit as st
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from model_SingleLabelClassifier import SingleLabelClassifier
from safetensors.torch import load_file

# --- Settings ---
MODEL_NAME = "allenai/scibert_scivocab_uncased"
CHECKPOINT_PATH = "checkpoint-28553"
NUM_CLASSES = 7
MAX_LEN = 320

# --- Label mappings ---
label2id = {'cs.CV': 0, 'cs.LG': 1, 'cs.AI': 2, 'cs.CL': 3, 'stat.ML': 4, 'cs.NE': 5, '<OTHER>': 6}
id2label = {v: k for k, v in label2id.items()}

# --- Load model and tokenizer ---
@st.cache_resource
def load_model_and_tokenizer():
    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_PATH)
    model = SingleLabelClassifier(MODEL_NAME, num_labels=NUM_CLASSES)
    state_dict = load_file(f"{CHECKPOINT_PATH}/model.safetensors")
    model.load_state_dict(state_dict)
    model.eval()
    return model, tokenizer

model, tokenizer = load_model_and_tokenizer()

# --- Prediction function ---
def predict(title, summary, model, tokenizer, id2label, max_length=320, top_k=3):
    model.eval()
    text = title + ". " + summary

    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=max_length
    )

    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs["logits"]
        probs = F.softmax(logits, dim=1).squeeze().numpy()

    top_indices = probs.argsort()[::-1][:top_k]
    return [(id2label[i], round(probs[i], 3)) for i in top_indices]

# --- Streamlit interface ---
st.title("ArXiv Tag Predictor")
st.write("Paste a paper's title and abstract to get a predicted tag!")

title = st.text_input("**Title**")
summary = st.text_area("**Summary**", height=200)

if st.button("Predict tag"):
    if not title or not summary:
        st.warning("Please enter both a title and an abstract!")
    else:
        preds = predict(title, summary, model, tokenizer, id2label)
        st.subheader("Predicted tags:")
        for tag, prob in preds:
            st.write(f"**{tag}** — probability: {prob:.3f}")
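For a quick check outside the Streamlit UI, the predict helper can be exercised directly. The sketch below is a minimal illustration only: the title and abstract are invented, and it assumes app.py is importable with the referenced checkpoint available locally (Streamlit calls at import time then run in bare mode).

# Hypothetical standalone check of predict() from app.py; not part of the commit.
from app import predict, model, tokenizer, id2label

example_title = "A Survey of Transformer Architectures"                      # made-up input
example_summary = "We review attention-based encoder models for NLP tasks."  # made-up input

for tag, prob in predict(example_title, example_summary, model, tokenizer, id2label, top_k=3):
    print(f"{tag}: {prob}")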
checkpoint-23985/model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:51588e88f96095606894dcba3e2f3d15fc41d41eab37b68cc0b303453ac675ca
size 446466252
checkpoint-23985/optimizer.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:da681d00471785472c51b663fa5dfc09d86055e11281b2f9ca1d89f44cc450e2
size 207734077
checkpoint-23985/rng_state.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:44b8d226f925403ae3da29607377591d51678f6b810b38292a8988d53c35c49d
size 14244
checkpoint-23985/scheduler.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ddb29c7c2b214d417422c344dc6f897586016be65d4e0e163c3818c7e263f168
size 1064
checkpoint-23985/special_tokens_map.json
ADDED
@@ -0,0 +1,7 @@
{
  "cls_token": "[CLS]",
  "mask_token": "[MASK]",
  "pad_token": "[PAD]",
  "sep_token": "[SEP]",
  "unk_token": "[UNK]"
}
checkpoint-23985/tokenizer.json
ADDED
The diff for this file is too large to render.
checkpoint-23985/tokenizer_config.json
ADDED
@@ -0,0 +1,57 @@
{
  "added_tokens_decoder": {
    "0": {
      "content": "[PAD]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "101": {
      "content": "[UNK]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "102": {
      "content": "[CLS]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "103": {
      "content": "[SEP]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "104": {
      "content": "[MASK]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  },
  "clean_up_tokenization_spaces": true,
  "cls_token": "[CLS]",
  "do_basic_tokenize": true,
  "do_lower_case": true,
  "mask_token": "[MASK]",
  "model_max_length": 1000000000000000019884624838656,
  "never_split": null,
  "pad_token": "[PAD]",
  "sep_token": "[SEP]",
  "strip_accents": null,
  "tokenize_chinese_chars": true,
  "tokenizer_class": "BertTokenizer",
  "unk_token": "[UNK]"
}
checkpoint-23985/trainer_state.json
ADDED
@@ -0,0 +1,180 @@
{
  "best_metric": 0.718079007713154,
  "best_model_checkpoint": "./checkpoints/checkpoint-23985",
  "epoch": 5.0,
  "eval_steps": 500,
  "global_step": 23985,
  "is_hyper_param_search": false,
  "is_local_process_zero": true,
  "is_world_process_zero": true,
  "log_history": [
    {
      "epoch": 0.26,
      "learning_rate": 2.5e-05,
      "loss": 2.119,
      "step": 1250
    },
    {
      "epoch": 0.52,
      "learning_rate": 1.75e-05,
      "loss": 1.2117,
      "step": 2500
    },
    {
      "epoch": 0.78,
      "learning_rate": 1.75e-05,
      "loss": 1.0565,
      "step": 3750
    },
    {
      "epoch": 1.0,
      "eval_accuracy": 0.6968938920158433,
      "eval_loss": 0.930023193359375,
      "eval_runtime": 327.1586,
      "eval_samples_per_second": 117.301,
      "eval_steps_per_second": 3.668,
      "step": 4797
    },
    {
      "epoch": 1.04,
      "learning_rate": 1.2249999999999998e-05,
      "loss": 0.9975,
      "step": 5000
    },
    {
      "epoch": 1.3,
      "learning_rate": 1.2249999999999998e-05,
      "loss": 0.9029,
      "step": 6250
    },
    {
      "epoch": 1.56,
      "learning_rate": 8.574999999999998e-06,
      "loss": 0.8903,
      "step": 7500
    },
    {
      "epoch": 1.82,
      "learning_rate": 8.574999999999998e-06,
      "loss": 0.874,
      "step": 8750
    },
    {
      "epoch": 2.0,
      "eval_accuracy": 0.7115905774442359,
      "eval_loss": 0.8782150745391846,
      "eval_runtime": 300.3321,
      "eval_samples_per_second": 127.779,
      "eval_steps_per_second": 3.996,
      "step": 9594
    },
    {
      "epoch": 2.08,
      "learning_rate": 6.002499999999999e-06,
      "loss": 0.8538,
      "step": 10000
    },
    {
      "epoch": 2.35,
      "learning_rate": 6.002499999999999e-06,
      "loss": 0.8203,
      "step": 11250
    },
    {
      "epoch": 2.61,
      "learning_rate": 4.201749999999999e-06,
      "loss": 0.8195,
      "step": 12500
    },
    {
      "epoch": 2.87,
      "learning_rate": 4.201749999999999e-06,
      "loss": 0.8116,
      "step": 13750
    },
    {
      "epoch": 3.0,
      "eval_accuracy": 0.7145611840733792,
      "eval_loss": 0.8623952865600586,
      "eval_runtime": 300.1975,
      "eval_samples_per_second": 127.836,
      "eval_steps_per_second": 3.997,
      "step": 14391
    },
    {
      "epoch": 3.13,
      "learning_rate": 2.941224999999999e-06,
      "loss": 0.7909,
      "step": 15000
    },
    {
      "epoch": 3.39,
      "learning_rate": 2.941224999999999e-06,
      "loss": 0.7914,
      "step": 16250
    },
    {
      "epoch": 3.65,
      "learning_rate": 2.058857499999999e-06,
      "loss": 0.7926,
      "step": 17500
    },
    {
      "epoch": 3.91,
      "learning_rate": 2.058857499999999e-06,
      "loss": 0.7839,
      "step": 18750
    },
    {
      "epoch": 4.0,
      "eval_accuracy": 0.7179226599958307,
      "eval_loss": 0.8576174378395081,
      "eval_runtime": 300.1549,
      "eval_samples_per_second": 127.854,
      "eval_steps_per_second": 3.998,
      "step": 19188
    },
    {
      "epoch": 4.17,
      "learning_rate": 1.4412002499999993e-06,
      "loss": 0.7696,
      "step": 20000
    },
    {
      "epoch": 4.43,
      "learning_rate": 1.4412002499999993e-06,
      "loss": 0.7667,
      "step": 21250
    },
    {
      "epoch": 4.69,
      "learning_rate": 1.0088401749999995e-06,
      "loss": 0.7701,
      "step": 22500
    },
    {
      "epoch": 4.95,
      "learning_rate": 1.0088401749999995e-06,
      "loss": 0.7726,
      "step": 23750
    },
    {
      "epoch": 5.0,
      "eval_accuracy": 0.718079007713154,
      "eval_loss": 0.8557529449462891,
      "eval_runtime": 300.2736,
      "eval_samples_per_second": 127.803,
      "eval_steps_per_second": 3.996,
      "step": 23985
    }
  ],
  "logging_steps": 1250,
  "max_steps": 23985,
  "num_input_tokens_seen": 0,
  "num_train_epochs": 5,
  "save_steps": 500,
  "total_flos": 0.0,
  "train_batch_size": 32,
  "trial_name": null,
  "trial_params": null
}
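The log_history above interleaves periodic training-loss entries with one evaluation entry per epoch. A small sketch like the one below (reading the file path used in this commit) can pull out just the evaluation curve:

import json

# Read the trainer state saved with the checkpoint and list the per-epoch
# evaluation metrics (the log_history entries that carry "eval_accuracy").
with open("checkpoint-23985/trainer_state.json") as f:
    state = json.load(f)

for entry in state["log_history"]:
    if "eval_accuracy" in entry:
        print(f'epoch {entry["epoch"]}: '
              f'accuracy={entry["eval_accuracy"]:.4f}, loss={entry["eval_loss"]:.4f}')

print("best metric:", state["best_metric"])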
checkpoint-23985/training_args.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:351ee58f715ddbda098b9c7aa4f73852ef8355fd01d77a909cbdf33db04aedc4
size 4664
checkpoint-23985/vocab.txt
ADDED
The diff for this file is too large to render.
label_mappings.json
ADDED
@@ -0,0 +1,136 @@
{
  "label2id": {
    "cs.CV": 0,
    "cs.LG": 1,
    "cs.CL": 2,
    "cs.AI": 3,
    "quant-ph": 4,
    "math.CO": 5,
    "stat.ML": 6,
    "astro-ph.GA": 7,
    "hep-ph": 8,
    "hep-th": 9,
    "astro-ph.HE": 10,
    "cs.CR": 11,
    "cond-mat.mtrl-sci": 12,
    "cs.RO": 13,
    "astro-ph.SR": 14,
    "gr-qc": 15,
    "math.NT": 16,
    "math.OC": 17,
    "cs.DS": 18,
    "cs.NE": 19,
    "cs.IT": 20,
    "math.AP": 21,
    "astro-ph.CO": 22,
    "math.PR": 23,
    "eess.IV": 24,
    "physics.optics": 25,
    "cond-mat.mes-hall": 26,
    "stat.ME": 27,
    "astro-ph.EP": 28,
    "math.AG": 29,
    "eess.SP": 30,
    "eess.SY": 31,
    "cs.IR": 32,
    "math.NA": 33,
    "cs.DC": 34,
    "cs.SE": 35,
    "astro-ph.IM": 36,
    "cond-mat.str-el": 37,
    "hep-ex": 38,
    "math.DS": 39,
    "math.DG": 40,
    "cs.GT": 41,
    "math.GR": 42,
    "cond-mat.stat-mech": 43,
    "physics.flu-dyn": 44,
    "math.FA": 45,
    "cs.CY": 46,
    "cs.NI": 47,
    "cond-mat.soft": 48,
    "cs.SI": 49,
    "cs.HC": 50,
    "cs.LO": 51,
    "math-ph": 52,
    "physics.soc-ph": 53,
    "math.RT": 54,
    "physics.chem-ph": 55,
    "math.GT": 56,
    "math.ST": 57,
    "cs.SD": 58,
    "math.RA": 59,
    "stat.AP": 60,
    "eess.AS": 61,
    "cs.DB": 62,
    "math.LO": 63,
    "<OTHER>": 64
  },
  "id2label": {
    "0": "cs.CV",
    "1": "cs.LG",
    "2": "cs.CL",
    "3": "cs.AI",
    "4": "quant-ph",
    "5": "math.CO",
    "6": "stat.ML",
    "7": "astro-ph.GA",
    "8": "hep-ph",
    "9": "hep-th",
    "10": "astro-ph.HE",
    "11": "cs.CR",
    "12": "cond-mat.mtrl-sci",
    "13": "cs.RO",
    "14": "astro-ph.SR",
    "15": "gr-qc",
    "16": "math.NT",
    "17": "math.OC",
    "18": "cs.DS",
    "19": "cs.NE",
    "20": "cs.IT",
    "21": "math.AP",
    "22": "astro-ph.CO",
    "23": "math.PR",
    "24": "eess.IV",
    "25": "physics.optics",
    "26": "cond-mat.mes-hall",
    "27": "stat.ME",
    "28": "astro-ph.EP",
    "29": "math.AG",
    "30": "eess.SP",
    "31": "eess.SY",
    "32": "cs.IR",
    "33": "math.NA",
    "34": "cs.DC",
    "35": "cs.SE",
    "36": "astro-ph.IM",
    "37": "cond-mat.str-el",
    "38": "hep-ex",
    "39": "math.DS",
    "40": "math.DG",
    "41": "cs.GT",
    "42": "math.GR",
    "43": "cond-mat.stat-mech",
    "44": "physics.flu-dyn",
    "45": "math.FA",
    "46": "cs.CY",
    "47": "cs.NI",
    "48": "cond-mat.soft",
    "49": "cs.SI",
    "50": "cs.HC",
    "51": "cs.LO",
    "52": "math-ph",
    "53": "physics.soc-ph",
    "54": "math.RT",
    "55": "physics.chem-ph",
    "56": "math.GT",
    "57": "math.ST",
    "58": "cs.SD",
    "59": "math.RA",
    "60": "stat.AP",
    "61": "eess.AS",
    "62": "cs.DB",
    "63": "math.LO",
    "64": "<OTHER>"
  }
}
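Note that app.py hard-codes a 7-class label2id, while this file stores the full 65-class mapping. If the mappings were read from this file instead, loading them might look like the sketch below (illustrative only, not part of the commit):

import json

# Load the saved label mappings; JSON object keys are strings, so the
# id2label keys are converted back to integers for indexing by class id.
with open("label_mappings.json") as f:
    mappings = json.load(f)

label2id = mappings["label2id"]
id2label = {int(k): v for k, v in mappings["id2label"].items()}

print(len(label2id), "labels; class 0 is", id2label[0])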
model_SingleLabelClassifier.py
ADDED
@@ -0,0 +1,42 @@
from transformers import AutoModel, AutoTokenizer
import torch.nn as nn
import torch


class SingleLabelClassifier(nn.Module):
    def __init__(self, base_model_name, num_labels, hidden_size=1024, freeze_bert=True):
        super(SingleLabelClassifier, self).__init__()
        self.base = AutoModel.from_pretrained(base_model_name)

        # Optionally freeze the encoder, leaving only the embedding layer trainable.
        if freeze_bert:
            for name, param in self.base.named_parameters():
                if not name.startswith("embeddings"):
                    param.requires_grad = False

        # Classification head on top of the pooled [CLS] representation.
        self.intermediate = nn.Linear(self.base.config.hidden_size, hidden_size)
        self.norm = nn.BatchNorm1d(hidden_size)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(0.4)
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, labels=None):
        outputs = self.base(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True
        )
        pooled_output = outputs.pooler_output
        x = self.intermediate(pooled_output)
        x = self.norm(x)
        x = self.activation(x)
        x = self.dropout(x)
        logits = self.classifier(x)

        # Cross-entropy loss is computed only when labels are provided.
        loss = None
        if labels is not None:
            labels = labels.long()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)

        return {"logits": logits, "loss": loss}
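A rough usage sketch for this head follows (model name and max length taken from app.py; the head weights here are randomly initialised rather than loaded from the checkpoint, and the input text is a made-up example):

import torch
from transformers import AutoTokenizer
from model_SingleLabelClassifier import SingleLabelClassifier

# Build the classifier around SciBERT; app.py would load trained weights
# from model.safetensors on top of this.
model = SingleLabelClassifier("allenai/scibert_scivocab_uncased", num_labels=7)
model.eval()

tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased")
inputs = tokenizer("An example abstract about image classification.",
                   return_tensors="pt", truncation=True, max_length=320)

with torch.no_grad():
    out = model(**inputs)

print(out["logits"].shape)  # torch.Size([1, 7])
print(out["loss"])          # None, since no labels were passed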
requirements.txt
ADDED
@@ -0,0 +1,4 @@
streamlit
torch
transformers
safetensors