Spaces:
Sleeping
Sleeping
# app.py | |
import streamlit as st | |
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification | |
import torch | |
import numpy as np | |
@st.cache_resource
def load_model():
    """Load the tokenizer and the fine-tuned classifier once per server process.

    Streamlit re-executes this whole script on every widget interaction, and
    ``load_model()`` is called at module level, so without caching the model
    weights would be reloaded on every rerun. ``st.cache_resource`` keeps a
    single shared instance instead.

    Returns:
        tuple: ``(tokenizer, model)`` — the ``distilbert-base-cased``
        tokenizer and the sequence-classification model loaded from the
        local ``model/`` directory.
    """
    # NOTE(review): assumes the checkpoint in model/ was fine-tuned from
    # distilbert-base-cased so the vocabularies match — confirm.
    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
    model = DistilBertForSequenceClassification.from_pretrained('model/')
    # from_pretrained already returns the model in eval mode; state it
    # explicitly since this single instance is shared by every session.
    model.eval()
    return tokenizer, model
def _top_p_predictions(probs, id2label, threshold=0.95):
    """Return the most likely labels whose cumulative probability reaches *threshold*.

    Labels are taken in descending probability order and accumulated until
    their summed probability is at least ``threshold``. The label that crosses
    the threshold is included, so a non-empty *probs* always yields at least
    one label.

    Args:
        probs: 1-D array of class probabilities.
        id2label: mapping from integer class index to tag name.
        threshold: cumulative-probability cutoff (default 0.95).

    Returns:
        list of ``(tag, probability)`` pairs, most likely first.
    """
    selected = []
    cumulative = 0.0
    for idx in np.argsort(-probs):
        # argsort yields numpy integers; use a plain int for the label lookup.
        idx = int(idx)
        prob = float(probs[idx])
        cumulative += prob
        selected.append((id2label[idx], prob))
        if cumulative >= threshold:
            break
    return selected


tokenizer, model = load_model()

st.title('arXiv Article Classifier')
title = st.text_input('Title')
abstract = st.text_area('Abstract')
# Title first, then the abstract when one was given; blank parts are dropped.
text = ' '.join(part for part in (title.strip(), abstract.strip()) if part)

if st.button('Predict'):
    # The title is mandatory (as the message says); the abstract is optional.
    if not title.strip():
        st.error('Please enter at least a title.')
    else:
        inputs = tokenizer(
            text,
            truncation=True,
            padding=True,
            max_length=512,
            return_tensors='pt',
        )
        with torch.no_grad():
            logits = model(**inputs).logits
        # A single text was tokenized, so the batch has one row; take row 0.
        probs = torch.nn.functional.softmax(logits, dim=1).numpy()[0]
        for tag, prob in _top_p_predictions(probs, model.config.id2label):
            st.write(f'{tag}: {prob:.2%}')