"""Streamlit app: predict arXiv subject tags from a paper's title and abstract."""

import json

import streamlit as st
import torch
import torch.nn.functional as F
from safetensors.torch import load_file
from transformers import AutoTokenizer

from model_SingleLabelClassifier import SingleLabelClassifier

MODEL_NAME = "allenai/scibert_scivocab_uncased"
CHECKPOINT_PATH = "checkpoint-23985"
NUM_CLASSES = 65
# NOTE(review): this constant is never used — predict() defaults to
# max_length=320. Confirm which truncation length is intended.
MAX_LEN = 250

# Load label <-> id mappings produced at training time.
with open("label_mappings.json", "r") as f:
    mappings = json.load(f)
label2id = mappings["label2id"]  # fixed: was misspelled "abel2id"
id2label = {int(k): v for k, v in mappings["id2label"].items()}


# Model and tokenizer are loaded once per session (st.cache_resource).
@st.cache_resource
def load_model_and_tokenizer():
    """Load the fine-tuned classifier and its tokenizer from the checkpoint.

    Returns:
        (model, tokenizer): the classifier in eval mode and its tokenizer.
    """
    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_PATH)
    model = SingleLabelClassifier(MODEL_NAME, num_labels=NUM_CLASSES)
    state_dict = load_file(f"{CHECKPOINT_PATH}/model.safetensors")
    model.load_state_dict(state_dict)
    model.eval()
    return model, tokenizer


model, tokenizer = load_model_and_tokenizer()


def predict(title, summary, model, tokenizer, id2label, max_length=320, top_k=3):
    """Return the top_k most probable tags for a title + abstract pair.

    Args:
        title: paper title.
        summary: paper abstract.
        model: classifier whose output dict contains a "logits" tensor.
        tokenizer: Hugging Face tokenizer matching the model.
        id2label: mapping from class index to tag name.
        max_length: tokenizer truncation/padding length.
        top_k: number of (tag, probability) pairs to return.

    Returns:
        List of (tag, probability) tuples, probabilities rounded to 3 digits,
        sorted by descending probability.
    """
    model.eval()
    text = title + ". " + summary
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=max_length,
    )
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs["logits"]
        probs = F.softmax(logits, dim=1).squeeze().numpy()
    top_indices = probs.argsort()[::-1][:top_k]
    # Cast numpy scalars to plain Python int/float so callers get stdlib types.
    return [(id2label[int(i)], round(float(probs[i]), 3)) for i in top_indices]


# ----- Streamlit UI -----
st.title("ArXiv Tag Predictor")
st.write("Вставьте заголовок и аннотацию статьи!")

title = st.text_input("**Title**")
summary = st.text_area("**Summary**", height=200)

if st.button("Предсказать тег"):
    if not title or not summary:
        st.warning("Пожалуйста, введите и заголовок, и аннотацию!")
    else:
        preds = predict(title, summary, model, tokenizer, id2label)
        st.subheader("Предсказанные теги:")
        for tag, prob in preds:
            st.write(f"**{tag}** — вероятность: {prob:.3f}")