# app.py
"""Streamlit app that predicts labels for an arXiv article.

The user enters a title and, optionally, an abstract. The text is run through
a DistilBERT sequence-classification model and the most probable labels are
listed, highest probability first, until their probabilities add up to
``CUMULATIVE_THRESHOLD``.
"""
import numpy as np
import streamlit as st
import torch
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer

# Labels are shown, most probable first, until their probabilities sum to this.
CUMULATIVE_THRESHOLD = 0.95

# ``st.cache`` is the legacy caching API; newer Streamlit releases provide
# ``st.cache_resource`` for objects such as models. Use it when present and
# fall back to the legacy call so older installations keep working.
if hasattr(st, 'cache_resource'):
    _cache_model = st.cache_resource
else:
    _cache_model = st.cache(allow_output_mutation=True)


@_cache_model
def load_model():
    """Load the tokenizer and the classification model.

    The tokenizer comes from the ``distilbert-base-cased`` checkpoint and the
    model weights from the local ``model/`` directory. The result is cached by
    Streamlit so the pair is not reloaded on every script rerun.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
    model = DistilBertForSequenceClassification.from_pretrained('model/')
    return tokenizer, model


def select_top_labels(probs, id2label, threshold=CUMULATIVE_THRESHOLD):
    """Return the most probable labels covering ``threshold`` of the mass.

    Labels are taken in order of decreasing probability and collection stops
    as soon as the running sum reaches ``threshold``. If the sum never reaches
    it, every label is returned.

    Args:
        probs: 1-D NumPy array of per-class probabilities.
        id2label: Mapping from integer class index to label name.
        threshold: Cumulative probability at which to stop collecting.

    Returns:
        A list of ``(label, probability)`` tuples, most probable first.
    """
    selected = []
    cumulative = 0.0
    # Negating the array makes argsort yield indices in descending order.
    for idx in np.argsort(-probs):
        # Plain Python numbers keep the dict lookup and the running sum
        # independent of NumPy scalar types; the sum is done in double
        # precision.
        prob = float(probs[idx])
        cumulative += prob
        selected.append((id2label[int(idx)], prob))
        if cumulative >= threshold:
            break
    return selected


tokenizer, model = load_model()

st.title('arXiv Article Classifier')
title = st.text_input('Title')
abstract = st.text_area('Abstract')

# The abstract is optional; without it only the title is classified.
text = f'{title} {abstract}' if abstract else title

if st.button('Predict'):
    if not text.strip():
        st.error('Please enter at least a title.')
    else:
        inputs = tokenizer(
            text,
            truncation=True,
            padding=True,
            max_length=512,
            return_tensors='pt',
        )
        # Inference only: no gradients are needed.
        with torch.no_grad():
            logits = model(**inputs).logits
        # A single text was tokenized, so row 0 holds its class probabilities.
        probs = torch.nn.functional.softmax(logits, dim=1).numpy()[0]

        result = select_top_labels(probs, model.config.id2label)
        for tag, prob in result:
            st.write(f'{tag}: {prob:.2%}')