Spaces:
Running
Running
Commit
·
d5fbae4
1
Parent(s):
c38b845
Исправлен путь к модели
Browse files
app.py
CHANGED
@@ -5,6 +5,7 @@ from transformers import AutoTokenizer
|
|
5 |
from model_SingleLabelClassifier import SingleLabelClassifier
|
6 |
from safetensors.torch import load_file
|
7 |
import json
|
|
|
8 |
|
9 |
|
10 |
MODEL_NAME = "allenai/scibert_scivocab_uncased"
|
@@ -15,7 +16,7 @@ MAX_LEN = 250
|
|
15 |
# Загрузка меток
|
16 |
with open("label_mappings.json", "r") as f:
|
17 |
mappings = json.load(f)
|
18 |
-
|
19 |
id2label = {int(k): v for k, v in mappings["id2label"].items()}
|
20 |
|
21 |
# Загрузка модели и токенизатора
|
@@ -30,9 +31,13 @@ def load_model_and_tokenizer():
|
|
30 |
|
31 |
model, tokenizer = load_model_and_tokenizer()
|
32 |
|
33 |
-
#
|
34 |
-
def predict(title, summary, model, tokenizer, id2label, max_length=
|
35 |
model.eval()
|
|
|
|
|
|
|
|
|
36 |
text = title + ". " + summary
|
37 |
|
38 |
inputs = tokenizer(
|
@@ -44,7 +49,11 @@ def predict(title, summary, model, tokenizer, id2label, max_length=320, top_k=3)
|
|
44 |
)
|
45 |
|
46 |
with torch.no_grad():
|
47 |
-
outputs = model(
|
|
|
|
|
|
|
|
|
48 |
logits = outputs["logits"]
|
49 |
probs = F.softmax(logits, dim=1).squeeze().numpy()
|
50 |
|
@@ -52,17 +61,34 @@ def predict(title, summary, model, tokenizer, id2label, max_length=320, top_k=3)
|
|
52 |
return [(id2label[i], round(probs[i], 3)) for i in top_indices]
|
53 |
|
54 |
# Интерфейс Streamlit
|
55 |
-
st.title("ArXiv Tag Predictor")
|
56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
|
58 |
title = st.text_input("**Title**")
|
59 |
summary = st.text_area("**Summary**", height=200)
|
60 |
|
61 |
-
if st.button("Предсказать
|
62 |
if not title or not summary:
|
63 |
st.warning("Пожалуйста, введите и заголовок, и аннотацию!")
|
64 |
else:
|
65 |
preds = predict(title, summary, model, tokenizer, id2label)
|
66 |
-
st.subheader("Предсказанные теги:")
|
67 |
for tag, prob in preds:
|
68 |
-
st.write(f"**{tag}** — вероятность: {prob:.3f}")
|
|
|
5 |
from model_SingleLabelClassifier import SingleLabelClassifier
|
6 |
from safetensors.torch import load_file
|
7 |
import json
|
8 |
+
import re
|
9 |
|
10 |
|
11 |
MODEL_NAME = "allenai/scibert_scivocab_uncased"
|
|
|
16 |
# Загрузка меток
|
17 |
with open("label_mappings.json", "r") as f:
|
18 |
mappings = json.load(f)
|
19 |
+
label2id = mappings["label2id"]
|
20 |
id2label = {int(k): v for k, v in mappings["id2label"].items()}
|
21 |
|
22 |
# Загрузка модели и токенизатора
|
|
|
31 |
|
32 |
model, tokenizer = load_model_and_tokenizer()
|
33 |
|
34 |
+
# Обновлённая функция предсказания
|
35 |
+
def predict(title, summary, model, tokenizer, id2label, max_length=MAX_LEN, top_k=3):
|
36 |
model.eval()
|
37 |
+
|
38 |
+
# Удаляем лишние точки, пробелы и объединяем текст
|
39 |
+
title = re.sub(r"\.+$", "", title.strip())
|
40 |
+
summary = re.sub(r"\.+$", "", summary.strip())
|
41 |
text = title + ". " + summary
|
42 |
|
43 |
inputs = tokenizer(
|
|
|
49 |
)
|
50 |
|
51 |
with torch.no_grad():
|
52 |
+
outputs = model(
|
53 |
+
input_ids=inputs["input_ids"],
|
54 |
+
attention_mask=inputs["attention_mask"],
|
55 |
+
token_type_ids=inputs.get("token_type_ids")
|
56 |
+
)
|
57 |
logits = outputs["logits"]
|
58 |
probs = F.softmax(logits, dim=1).squeeze().numpy()
|
59 |
|
|
|
61 |
return [(id2label[i], round(probs[i], 3)) for i in top_indices]
|
62 |
|
63 |
# Интерфейс Streamlit
|
64 |
+
st.title("🧠 ArXiv Tag Predictor")
|
65 |
+
|
66 |
+
with st.expander("ℹ️ Описание модели"):
|
67 |
+
st.markdown("""
|
68 |
+
Данная модель обучена на основе [SciBERT](https://huggingface.co/allenai/scibert_scivocab_uncased) для классификации научных статей с сайта [arXiv.org](https://arxiv.org).
|
69 |
+
|
70 |
+
- Использует **65 различных тегов** из тематик arXiv (например: `cs.CL`, `math.CO`, `stat.ML`, и т.д.)
|
71 |
+
- Модель обучена на **заголовках и аннотациях** научных публикаций
|
72 |
+
- На вход принимает **англоязычный текст**
|
73 |
+
- Предсказывает **топ-3 наиболее вероятных тега** для каждой статьи
|
74 |
+
|
75 |
+
Ниже вы можете посмотреть полный список возможных тегов 👇
|
76 |
+
""")
|
77 |
+
|
78 |
+
with st.expander("📄 Список всех тегов"):
|
79 |
+
tag_list = sorted(label2id.keys())
|
80 |
+
st.markdown("\n".join([f"- `{tag}`" for tag in tag_list]))
|
81 |
+
|
82 |
+
st.write("Введите заголовок и аннотацию научной статьи (на английском):")
|
83 |
|
84 |
title = st.text_input("**Title**")
|
85 |
summary = st.text_area("**Summary**", height=200)
|
86 |
|
87 |
+
if st.button("📌 Предсказать теги"):
|
88 |
if not title or not summary:
|
89 |
st.warning("Пожалуйста, введите и заголовок, и аннотацию!")
|
90 |
else:
|
91 |
preds = predict(title, summary, model, tokenizer, id2label)
|
92 |
+
st.subheader("📚 Предсказанные теги:")
|
93 |
for tag, prob in preds:
|
94 |
+
st.write(f"**{tag}** — вероятность: `{prob:.3f}`")
|