File size: 3,787 Bytes
106e870
 
 
 
 
 
0c47f30
d5fbae4
0c47f30
106e870
 
0c47f30
 
f2b6c86
106e870
0c47f30
 
 
d5fbae4
0c47f30
106e870
0c47f30
106e870
 
 
 
 
 
 
 
 
 
 
d5fbae4
 
106e870
d5fbae4
 
 
 
106e870
 
 
 
 
 
 
 
 
 
 
d5fbae4
 
 
 
 
106e870
 
 
 
 
 
0c47f30
d5fbae4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106e870
 
 
 
d5fbae4
106e870
 
 
 
d5fbae4
106e870
d5fbae4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import streamlit as st
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from model_SingleLabelClassifier import SingleLabelClassifier
from safetensors.torch import load_file
import json
import re


MODEL_NAME = "allenai/scibert_scivocab_uncased"
CHECKPOINT_PATH = "checkpoint-23985"
NUM_CLASSES = 65
MAX_LEN = 250

# Загрузка меток
with open("label_mappings.json", "r") as f:
    mappings = json.load(f)
label2id = mappings["label2id"]
id2label = {int(k): v for k, v in mappings["id2label"].items()}

# Загрузка модели и токенизатора
@st.cache_resource
def load_model_and_tokenizer():
    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_PATH)
    model = SingleLabelClassifier(MODEL_NAME, num_labels=NUM_CLASSES)
    state_dict = load_file(f"{CHECKPOINT_PATH}/model.safetensors")
    model.load_state_dict(state_dict)
    model.eval()
    return model, tokenizer

model, tokenizer = load_model_and_tokenizer()

# Обновлённая функция предсказания
def predict(title, summary, model, tokenizer, id2label, max_length=MAX_LEN, top_k=3):
    model.eval()

    # Удаляем лишние точки, пробелы и объединяем текст
    title = re.sub(r"\.+$", "", title.strip())
    summary = re.sub(r"\.+$", "", summary.strip())
    text = title + ". " + summary

    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=max_length
    )

    with torch.no_grad():
        outputs = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            token_type_ids=inputs.get("token_type_ids")
        )
        logits = outputs["logits"]
        probs = F.softmax(logits, dim=1).squeeze().numpy()

    top_indices = probs.argsort()[::-1][:top_k]
    return [(id2label[i], round(probs[i], 3)) for i in top_indices]

# Интерфейс Streamlit
st.title("🧠 ArXiv Tag Predictor")

with st.expander("ℹ️ Описание модели"):
    st.markdown("""
    Данная модель обучена на основе [SciBERT](https://huggingface.co/allenai/scibert_scivocab_uncased) для классификации научных статей с сайта [arXiv.org](https://arxiv.org).
    
    - Использует **65 различных тегов** из тематик arXiv (например: `cs.CL`, `math.CO`, `stat.ML`, и т.д.)
    - Модель обучена на **заголовках и аннотациях** научных публикаций
    - На вход принимает **англоязычный текст**
    - Предсказывает **топ-3 наиболее вероятных тега** для каждой статьи
    
    Ниже вы можете посмотреть полный список возможных тегов 👇
    """)

with st.expander("📄 Список всех тегов"):
    tag_list = sorted(label2id.keys())
    st.markdown("\n".join([f"- `{tag}`" for tag in tag_list]))

st.write("Введите заголовок и аннотацию научной статьи (на английском):")

title = st.text_input("**Title**")
summary = st.text_area("**Summary**", height=200)

if st.button("📌 Предсказать теги"):
    if not title or not summary:
        st.warning("Пожалуйста, введите и заголовок, и аннотацию!")
    else:
        preds = predict(title, summary, model, tokenizer, id2label)
        st.subheader("📚 Предсказанные теги:")
        for tag, prob in preds:
            st.write(f"**{tag}** — вероятность: `{prob:.3f}`")