# Spaces: Running
import streamlit as st | |
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification | |
import torch | |
import numpy as np | |
# Human-readable names for the top-level arXiv archive tags the classifier can
# emit. Tags missing from this table are shown to the user verbatim.
MAPPING = {
    'cs': 'Computer Science',
    'stat': 'Statistics',
    'math': 'Mathematics',
    'q-bio': 'Quantitative Biology',
    'physics': 'Physics',
    # NOTE(review): arXiv's archive identifier is 'cmp-lg'. The original
    # spelling 'cmpl-lg' is kept in case the model's label set uses it --
    # confirm against the id2label table in the model config.
    'cmpl-lg': 'Computation and Language',
    'cmp-lg': 'Computation and Language',
    'eess': 'Electrical Engineering and Systems Science',
    'quant-ph': 'Quantum Physics',
    'cond-mat': 'Condensed Matter',
    'astro-ph': 'Astrophysics',
    'nlin': 'Nonlinear Sciences',
    'q-fin': 'Quantitative Finance',
    ':)': 'Something else',
}
@st.cache_resource
def load_model():
    """Load the tokenizer and the fine-tuned classifier.

    The tokenizer comes from the ``distilbert-base-cased`` checkpoint and the
    classification model from the local ``model/`` directory.

    Streamlit re-executes the whole script on every widget interaction, so
    the result is cached with ``st.cache_resource``: the weights are read
    once per server process instead of once per click.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
    model = DistilBertForSequenceClassification.from_pretrained('model/')
    return tokenizer, model
# Load the model up front, then render the page header and the input widgets.
tokenizer, model = load_model()

st.title('arXiv Article Classifier')
title = st.text_input('Title')
abstract = st.text_area('Abstract')

# The classifier receives a single string: the title, followed by the
# abstract when one was entered.
text = title + ' ' + abstract if abstract else title
# Run the classifier when the user presses "Predict" and list the most likely
# categories, best first, until their combined probability reaches the
# coverage threshold.
COVERAGE_THRESHOLD = 0.95

if st.button('Predict'):
    if not text.strip():
        st.error('Please enter at least a title.')
    else:
        inputs = tokenizer(
            text,
            truncation=True,
            padding=True,
            max_length=512,
            return_tensors='pt',
        )
        # Inference only: skip gradient tracking.
        with torch.no_grad():
            logits = model(**inputs).logits
        # One input text was tokenized, so take the first (only) row.
        probs = torch.nn.functional.softmax(logits, dim=1).numpy()[0]

        cumulative = 0.0
        result = []
        for idx in np.argsort(-probs):  # indices by descending probability
            # Cast to a plain int so the id2label lookup does not depend on
            # NumPy integer scalars hashing like Python ints.
            label_id = int(idx)
            prob = float(probs[label_id])
            cumulative += prob
            result.append((model.config.id2label[label_id], prob))
            if cumulative >= COVERAGE_THRESHOLD:
                break

        for tag, prob in result:
            # Unknown tags fall back to the raw label.
            st.write(f'{MAPPING.get(tag, tag)}: {prob:.2%}')