File size: 1,605 Bytes
3d4f52d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94e1e0f
3d4f52d
 
 
 
 
 
 
 
 
 
 
 
 
e526340
 
3d4f52d
 
 
e526340
 
3d4f52d
 
e526340
 
 
 
 
3d4f52d
 
489c19d
442e6f0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import streamlit as st
from transformers import pipeline
import json
import torch
from torch.nn import functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification

@st.cache_resource
def load_dicts():
    """Load the label<->index mapping dicts from disk (cached by Streamlit).

    Returns:
        (label2ind, ind2label): two plain dicts parsed from the JSON files
        sitting next to the app.
    """
    def _read_json(path):
        # Files are small; read eagerly and let json.load build the dict.
        with open(path, "r") as fh:
            return json.load(fh)

    return _read_json("label2ind.json"), _read_json("ind2label.json")

@st.cache_resource
def load_model(
    checkpoint="checkpoint-23000",
    tokenizer_name="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
):
    """Load the tokenizer and the fine-tuned classifier (cached by Streamlit).

    Args:
        checkpoint: path/name of the fine-tuned classification checkpoint.
        tokenizer_name: Hugging Face id of the base tokenizer.

    Returns:
        (tokenizer, model) pair ready for inference.

    NOTE(review): reads the module-level ``label2ind`` dict, which is only
    bound after ``load_dicts()`` runs — this function must not be called
    before that assignment (the current script order guarantees it).
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        checkpoint,
        # The classification head size must match the label vocabulary.
        num_labels=len(label2ind),
        problem_type="single_label_classification",
    )
    return tokenizer, model

# Build the shared inference state once per session (both loaders are cached).
label2ind, ind2label = load_dicts()
tokenizer, model = load_model()

# Free-text inputs for the paper title and abstract; defaults give a working demo.
title = st.text_input("Title", value="Math")
abstract = st.text_input("Abstract", value="Random variable")

def get_logits(title, abstract):
    """Encode the title and abstract and return the model's raw logits.

    The two fields are joined with a literal "###" separator — the same
    format the classifier was trained on (presumably; confirm against the
    training pipeline).
    """
    combined = "###".join((title, abstract))
    encoded = tokenizer(combined, return_tensors="pt")
    return model(**encoded)["logits"]

def get_ans(logits, threshold=0.95):
    """Display top labels until their cumulative probability reaches *threshold*.

    Args:
        logits: raw model output of shape (1, num_labels).
        threshold: cumulative-probability cutoff (default 0.95, matching the
            original hard-coded value). The label that crosses the cutoff is
            still displayed.

    Side effects: writes one line per label via ``st.write``. Labels missing
    from ``ind2label`` render as ``None`` — unchanged from the original.
    """
    probs = F.softmax(logits, dim=1)
    # torch.sort returns both the sorted probabilities and their indices,
    # avoiding the argsort + repeated fancy-indexing of the original.
    sorted_probs, sorted_idx = torch.sort(probs[0], descending=True)
    cum_sum = 0.0
    for p, idx in zip(sorted_probs.tolist(), sorted_idx.tolist()):
        if cum_sum >= threshold:
            break
        cum_sum += p
        st.write(f"label: {ind2label.get(str(idx))} with probability: {p * 100:.2f}%")

# Run inference whenever at least one of the two inputs is non-empty
# (Streamlit re-executes the whole script on every widget change).
if title or abstract:
    logits = get_logits(title, abstract)
    get_ans(logits)