# hw4 / app.py
# NOTE: the lines below were Hugging Face file-viewer chrome captured with the file,
# preserved here as comments so the module parses:
#   author: MikhailPugachev
#   commit f2b6c86 — "Исправлен путь к модели" (Fixed the model path)
#   raw / history / blame, 2.25 kB
import streamlit as st
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from model_SingleLabelClassifier import SingleLabelClassifier
from safetensors.torch import load_file
import json
# --- Model / inference configuration ---
MODEL_NAME = "allenai/scibert_scivocab_uncased"  # base encoder the classifier was fine-tuned from
CHECKPOINT_PATH = "checkpoint-23985"             # directory with tokenizer files + model.safetensors
NUM_CLASSES = 65
# NOTE(review): predict() below defaults to max_length=320, not MAX_LEN — confirm which limit is intended.
MAX_LEN = 250

# Load the label <-> id mappings saved at training time.
with open("label_mappings.json", "r", encoding="utf-8") as f:
    mappings = json.load(f)
# Bug fix: was `abel2id` (typo) — the label2id mapping was bound to a misspelled name.
label2id = mappings["label2id"]
# JSON object keys are always strings; restore integer class ids.
id2label = {int(k): v for k, v in mappings["id2label"].items()}
# Model and tokenizer loading (cached by Streamlit across reruns)
@st.cache_resource
def load_model_and_tokenizer():
    """Load the fine-tuned classifier and its tokenizer from the local checkpoint.

    Returns:
        (model, tokenizer): the classifier in eval mode and its tokenizer.
    """
    tok = AutoTokenizer.from_pretrained(CHECKPOINT_PATH)
    clf = SingleLabelClassifier(MODEL_NAME, num_labels=NUM_CLASSES)
    # Weights are stored in safetensors format inside the checkpoint directory.
    clf.load_state_dict(load_file(f"{CHECKPOINT_PATH}/model.safetensors"))
    clf.eval()
    return clf, tok

model, tokenizer = load_model_and_tokenizer()
# Prediction helper
def predict(title, summary, model, tokenizer, id2label, max_length=320, top_k=3):
    """Return the top_k (label, probability) pairs for an article.

    Args:
        title: article title.
        summary: article abstract; concatenated with the title as "title. summary".
        model: classifier returning a dict with a "logits" entry.
        tokenizer: HF-style tokenizer producing "pt" tensors.
        id2label: mapping from integer class id to tag name.
        max_length: tokenizer truncation/padding length.
        top_k: number of top classes to return.
    """
    model.eval()
    encoded = tokenizer(
        title + ". " + summary,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=max_length,
    )
    with torch.no_grad():
        logits = model(**encoded)["logits"]
    # Single example in the batch -> squeeze to a 1-D probability vector.
    scores = F.softmax(logits, dim=1).squeeze().numpy()
    ranked = scores.argsort()[::-1][:top_k]
    return [(id2label[idx], round(scores[idx], 3)) for idx in ranked]
# Streamlit user interface
st.title("ArXiv Tag Predictor")
st.write("Вставьте заголовок и аннотацию статьи!")

title = st.text_input("**Title**")
summary = st.text_area("**Summary**", height=200)

if st.button("Предсказать тег"):
    # Both fields are required before running inference.
    if title and summary:
        st.subheader("Предсказанные теги:")
        for tag, prob in predict(title, summary, model, tokenizer, id2label):
            st.write(f"**{tag}** — вероятность: {prob:.3f}")
    else:
        st.warning("Пожалуйста, введите и заголовок, и аннотацию!")