import streamlit as st
import torch
import pandas as pd
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from datasets import load_dataset

device = 'cpu'  # inference runs on CPU; switch to 'cuda' if a GPU is available
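

# st.cache_resource caches the loaded model and tokenizer for the lifetime of
# the Streamlit process, so the checkpoint is read from disk only once.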
@st.cache_resource
def get_model_and_tokenizer():
    model_name = "FacebookAI/roberta-base"
    num_labels = 157

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
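
    # The checkpoint is assumed to hold the fine-tuned weights under the
    # 'model' key, e.g. saved via torch.save({'model': model.state_dict()}, ...).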
    chkp = torch.load("arxiv_roberta_final.pt", map_location=device)
    model.load_state_dict(chkp['model'])

    return model, tokenizer
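

# Category metadata comes from the TimSchopf/arxiv_categories dataset; the
# label ids assume the model was trained with categories in this exact order.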
@st.cache_data
def get_categories():
    categories = load_dataset("TimSchopf/arxiv_categories", "arxiv_category_descriptions")

    cat2id = {cat: i for i, cat in enumerate(categories['arxiv_category_descriptions']['tag'])}
    id2cat = categories['arxiv_category_descriptions']['tag']
    names = categories['arxiv_category_descriptions']['name']

    return cat2id, id2cat, names


model, tokenizer = get_model_and_tokenizer()
cat2id, id2cat, cat_names = get_categories()


@torch.no_grad()
def predict_and_decode(model, title='', abstract=''):
    model.eval()
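
    # Encode title and abstract as a sentence pair; anything beyond RoBERTa's
    # 512-token limit is truncated.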
    inputs = tokenizer(title, abstract, return_tensors='pt', truncation=True, max_length=512).to(device)
    logits = model(**inputs)['logits'][0].cpu()
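
    # Multi-label decoding: each category gets an independent sigmoid
    # probability, since a paper can belong to several arXiv categories at once.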
    df = pd.DataFrame([
        (id2cat[cat_id], cat_names[cat_id], prob.item())
        for cat_id, prob in enumerate(torch.sigmoid(logits))
    ], columns=("tag", "name", "probability"))
    df.sort_values(by="probability", ascending=False, inplace=True)

    return df.reset_index(drop=True)


st.header("Paper Category Classifier")
st.text("Input the title and/or abstract of a scientific paper to get a classification into arxiv.org categories.")

title_default = "Attention Is All You Need"
abstract_default = (
    "The dominant sequence transduction models are based on complex recurrent or convolutional neural networks "
    "in an encoder-decoder configuration. The best performing models also connect the encoder and decoder through "
    "an attention mechanism. We propose a new simple network architecture, the Transformer..."
)

line_height = 34  # approximate pixel height of one text-area line
n_lines = 10
title = st.text_input("Paper title", value=title_default, help="Type in the paper's title")
abstract = st.text_area("Paper abstract", value=abstract_default, height=line_height * n_lines, help="Type in the paper's abstract")

if title or abstract:
    result = predict_and_decode(model, title=title, abstract=abstract)

    cnt = st.container(border=True)
    with cnt:
        st.markdown("#### Top category")
        st.markdown(f"**{result.tag[0]}** -- {result['name'][0]}")
        st.markdown(f"Probability: {result.probability[0] * 100:.2f}%")
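
        # Show up to five more categories whose probability clears the
        # threshold, and always at least one runner-up.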
        threshold = 0.55
        st.text("Other top categories:")
        max_len = min(max(1, sum(result.iloc[1:].probability > threshold)), 5)
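
        # Format probabilities as percentage strings for display.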
        def format_p(example):
            example.probability = f"{example.probability * 100:.2f}%"
            return example

        st.table(result.iloc[1:1 + max_len].apply(format_p, axis=1))
else:
    st.warning("Type a title and/or an abstract to get started!")