Spaces:
Running
Running
import streamlit as st | |
import torch | |
import torch.nn.functional as F | |
from transformers import AutoTokenizer | |
from model_SingleLabelClassifier import SingleLabelClassifier | |
from safetensors.torch import load_file | |
import json | |
MODEL_NAME = "allenai/scibert_scivocab_uncased"
CHECKPOINT_PATH = "checkpoint-23985"
NUM_CLASSES = 65
# NOTE(review): MAX_LEN is not referenced by the code in this file; predict()
# defaults to max_length=320 instead. Confirm which length the model was
# trained with and use one value consistently.
MAX_LEN = 250

# Load the label <-> id mappings saved alongside the model.
# Explicit encoding so tag names decode the same way on every platform.
with open("label_mappings.json", "r", encoding="utf-8") as f:
    mappings = json.load(f)
label2id = mappings["label2id"]
# Backward-compatible alias for the previously misspelled module-level name.
abel2id = label2id
# JSON object keys are always strings; convert them back to int so they match
# the integer class indices used to look up predicted tags.
id2label = {int(k): v for k, v in mappings["id2label"].items()}
# Load the model and tokenizer.
@st.cache_resource
def load_model_and_tokenizer():
    """Build the classifier, load its fine-tuned weights, and load the tokenizer.

    Streamlit re-executes the whole script on every widget interaction, so the
    result is cached with ``st.cache_resource``: the tokenizer and weights are
    read from disk once instead of on every button click.

    Returns:
        tuple: ``(model, tokenizer)``, with the model switched to eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_PATH)
    model = SingleLabelClassifier(MODEL_NAME, num_labels=NUM_CLASSES)
    # Fine-tuned weights are stored as safetensors inside the checkpoint directory.
    state_dict = load_file(f"{CHECKPOINT_PATH}/model.safetensors")
    model.load_state_dict(state_dict)
    model.eval()
    return model, tokenizer


model, tokenizer = load_model_and_tokenizer()
# Prediction function.
def _model_device(model):
    """Return the device of the model's first parameter, or CPU if it has none."""
    parameters = getattr(model, "parameters", None)
    if parameters is None:
        return torch.device("cpu")
    first = next(iter(parameters()), None)
    return first.device if first is not None else torch.device("cpu")


def predict(title, summary, model, tokenizer, id2label, max_length=320, top_k=3):
    """Return the ``top_k`` most probable tags for a paper.

    The title and summary are joined as ``"<title>. <summary>"``, tokenized
    (truncated and padded to ``max_length``), scored by ``model``, and the
    logits are converted to probabilities with a softmax.

    Args:
        title: Paper title.
        summary: Paper abstract.
        model: Classifier whose output supports ``outputs["logits"]``.
        tokenizer: Callable tokenizer returning PyTorch tensors.
        id2label: Mapping from integer class index to tag name.
        max_length: Token length to truncate/pad the input to.
        top_k: Number of tags to return (fewer if there are fewer classes).

    Returns:
        list[tuple[str, float]]: ``(tag, probability)`` pairs in descending
        probability order; probabilities are Python floats rounded to 3 places.
    """
    model.eval()
    text = title + ". " + summary
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=max_length,
    )
    # Put the inputs on the same device as the weights so a GPU-placed model works.
    device = _model_device(model)
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs["logits"]
    # Single-example batch: take row 0 explicitly, and move to CPU because
    # .numpy() refuses CUDA tensors.
    probs = F.softmax(logits, dim=1)[0].cpu().numpy()
    top_indices = probs.argsort()[::-1][:top_k]
    return [(id2label[int(i)], round(float(probs[i]), 3)) for i in top_indices]
# Streamlit interface.
st.title("ArXiv Tag Predictor")
st.write("Вставьте заголовок и аннотацию статьи!")
title = st.text_input("**Title**")
summary = st.text_area("**Summary**", height=200)
if st.button("Предсказать тег"):
    # Whitespace-only input is truthy but carries no content, so treat it as empty.
    if not title.strip() or not summary.strip():
        st.warning("Пожалуйста, введите и заголовок, и аннотацию!")
    else:
        preds = predict(title, summary, model, tokenizer, id2label)
        st.subheader("Предсказанные теги:")
        for tag, prob in preds:
            st.write(f"**{tag}** — вероятность: {prob:.3f}")