File size: 3,143 Bytes
c23c7fd
 
4b87ee5
c0a3229
c23c7fd
c3b53de
c2d61c7
 
c23c7fd
 
71a36a2
aa45a03
312d734
97377d8
312d734
97377d8
 
 
 
b4d98ff
aa45a03
c23c7fd
 
 
 
ccf784a
 
 
 
c23c7fd
 
 
 
c7d91ee
 
 
c23c7fd
ac94852
c23c7fd
ac94852
 
 
 
 
66ca485
ac94852
c23c7fd
66ca485
ac94852
c23c7fd
f1fb568
c23c7fd
 
 
 
 
c2d61c7
 
 
 
434ea14
c23c7fd
 
 
 
c2d61c7
 
b6b5a2d
c2d61c7
 
 
 
 
 
 
97377d8
4af7d1a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification
from safetensors.torch import load_file as safe_load

# arXiv top-level category codes, listed in model label-index order
# (position i here is output index i of the classifier head).
_TARGETS = ('cs', 'econ', 'eess', 'math', 'phys', 'q-bio', 'q-fin', 'stat')

# Category code -> model output index.
target_to_ind = {target: ind for ind, target in enumerate(_TARGETS)}

# Category code -> human-readable field name shown to the user.
target_to_label = {
    'cs': 'Computer Science',
    'econ': 'Economics',
    'eess': 'Electrical Engineering and Systems Science',
    'math': 'Mathematics',
    'phys': 'Physics',
    'q-bio': 'Quantitative Biology',
    'q-fin': 'Quantitative Finance',
    'stat': 'Statistics',
}

# Model output index -> category code (inverse of target_to_ind).
ind_to_target = dict(enumerate(_TARGETS))



# Page header and usage instructions shown above the input form.
st.title('papers_classifier 🤓')

# A plain "..." literal cannot span physical lines, so the text is built from
# adjacent string literals, which the parser concatenates inside the parens.
# Blank lines ("\n\n") separate the paragraphs.
st.text(
    "Hey! I'm papers_classifier and I'm here to help you with answering the question "
    "'WTF is this paper about?'\n\n"
    "According to arXiv there are 8 different fields of study - Computer Science, Economics, "
    "Electrical Engineering and Systems Science, Mathematics, Physics, Quantitative Biology, "
    "Quantitative Finance and Statistics. So, everything I'll tell you will be about these "
    "eight gentlemen.\n\n"
    "You need to give me paper's title and (if you have one) its abstract. Also you need to "
    "choose classification mode - there are 2 of them:\n"
    "best prediction and top 95% which means that you'll see as many classes as I need to "
    "show you to be confident with probability at least 0.95 that the correct one is among "
    "them.\n\n"
    "After that you need to press the Get prediction button and I'll tell you to which "
    "fields of study this paper is related.\n"
)

@st.cache_resource
def load_model_and_tokenizer():
    """Build the sequence classifier and its tokenizer.

    The tokenizer and the model architecture come from the
    ``distilbert/distilbert-base-cased`` checkpoint. The classification head
    has one output per category in ``target_to_ind``. All model weights are
    then replaced by the state dict stored in the local ``model.safetensors``
    file (presumably the fine-tuned classifier — confirm against training).

    Decorated with ``st.cache_resource`` so Streamlit reruns reuse the loaded
    objects instead of rebuilding them.

    Returns:
        A ``(model, tokenizer)`` pair.
    """
    base_checkpoint = 'distilbert/distilbert-base-cased'
    tok = AutoTokenizer.from_pretrained(base_checkpoint)
    clf = AutoModelForSequenceClassification.from_pretrained(
        base_checkpoint,
        num_labels=len(target_to_ind),
    )

    # load_state_dict is strict by default: every key in the file must match
    # the architecture built above, or loading fails loudly.
    clf.load_state_dict(safe_load("model.safetensors"))

    return clf, tok


# Module-level handles used by get_predict; the cached loader above makes
# repeated script reruns return the same objects.
model, tokenizer = load_model_and_tokenizer()


def get_predict(title: str, abstract: str, max_abstract_chars: int = 128) -> "list[tuple[float, str]]":
    """Score a paper against every arXiv category.

    The model input is ``title``, then the tokenizer's separator token, then
    the first ``max_abstract_chars`` characters of ``abstract``. It is run
    through the module-level ``model`` and ``tokenizer``.

    Args:
        title: Paper title.
        abstract: Paper abstract; may be empty.
        max_abstract_chars: Number of leading abstract *characters* (not
            tokens) to keep. Defaults to 128.
            NOTE(review): presumably matches the truncation used when the
            model was trained — confirm before changing the default.

    Returns:
        ``(probability, category_code)`` pairs for all categories, sorted from
        most to least likely. Probabilities are a softmax over the logits.
    """
    text = [title + tokenizer.sep_token + abstract[:max_abstract_chars]]

    tokens_info = tokenizer(
        text,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = model(**tokens_info).logits

    # The batch holds a single text, so row 0 holds its per-category scores.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0].tolist()

    # Category codes are unique, so descending sort is the same as sorting
    # ascending and reversing (the previous implementation).
    return sorted(((p, ind_to_target[i]) for i, p in enumerate(probs)), reverse=True)


def _select_labels(predict, threshold):
    """Return display names for the leading categories of ``predict``.

    ``predict`` is a list of ``(probability, category_code)`` pairs, assumed
    sorted from most to least likely (as ``get_predict`` returns them).
    Categories are taken in order until their cumulative probability reaches
    ``threshold``. When ``predict`` is non-empty, at least one category is
    always included.
    """
    labels = []
    cumulative = 0
    for p, tag in predict:
        cumulative += p
        labels.append(target_to_label[tag])
        if cumulative >= threshold:
            break
    return labels


title = st.text_area("Title ", "", height=100)
abstract = st.text_area("Abstract ", "", height=150)


mode = st.radio("Mode: ", ("Best prediction", "Top 95%"))

if st.button("Get prediction", key="manual"):
    # Treat a whitespace-only title as missing; otherwise the model would be
    # queried with effectively no title.
    if not title.strip():
        st.error("Please, provide paper's title")
    else:
        with st.spinner("Be patient, I'm doing my best"):
            predict = get_predict(title, abstract)

        # A threshold of 0 stops after the single most likely category. A
        # threshold of 0.95 keeps adding categories until their combined
        # probability is at least 0.95.
        threshold = 0 if mode == "Best prediction" else 0.95
        st.success('\n\n'.join(_select_labels(predict, threshold)))