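"""Streamlit demo: suggest topic labels for a scientific paper from its
title and abstract, using a PubMedBERT-based sequence classifier fine-tuned
for single-label classification."""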
import json

import streamlit as st
import torch
from torch.nn import functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification
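# Expected local artifacts (relative to the working directory):
#   label2ind.json, ind2label.json - label/index mappings, presumably saved
#                                    during fine-tuning
#   checkpoint-23000/              - fine-tuned model checkpoint directory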
@st.cache_resource
def load_dicts():
    # Label <-> index mappings for the classifier's output classes.
    # JSON object keys are strings, so ind2label is keyed by str(index).
    with open("label2ind.json", "r") as file:
        label2ind = json.load(file)
    with open("ind2label.json", "r") as file:
        ind2label = json.load(file)
    return label2ind, ind2label

@st.cache_resource
def load_model(num_labels):
    # Tokenizer from the base PubMedBERT model; weights from the local
    # fine-tuned checkpoint directory. num_labels is passed in explicitly
    # rather than read from a global that is only bound later in the script.
    tokenizer = AutoTokenizer.from_pretrained(
        "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        "checkpoint-23000",
        num_labels=num_labels,
        problem_type="single_label_classification",
    )
    model.eval()  # inference only: disable dropout
    return tokenizer, model

label2ind, ind2label = load_dicts()
tokenizer, model = load_model(len(label2ind))

title = st.text_input("Title", value="Math")
abstract = st.text_input("Abstract", value="Random variable")

def get_logits(title, abstract):
    # Join title and abstract with the "###" separator the classifier expects,
    # then run a single forward pass; truncation guards against inputs longer
    # than the model's maximum input length.
    text = title + "###" + abstract
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():  # inference only, no gradients needed
        logits = model(**inputs)["logits"]
    return logits

def get_ans(logits):
    # Convert logits to probabilities and report the most likely labels until
    # their cumulative probability mass reaches 95%.
    probs = F.softmax(logits, dim=1)
    ind = torch.argsort(probs, dim=1, descending=True).flatten()
    cum_sum = 0.0
    i = 0
    while cum_sum < 0.95 and i < len(ind):
        idx = ind[i].item()
        cum_sum += probs[0][idx].item()
        st.write(f"label: {ind2label.get(str(idx))} with probability: {probs[0][idx].item() * 100:.2f}%")
        i += 1
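# Example of the 95% cutoff: with sorted probabilities [0.90, 0.08, 0.02] the
# loop prints the top two labels and stops, since 0.90 + 0.08 >= 0.95.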
# Streamlit re-runs the whole script on every input change, so prediction
# happens as soon as either field is non-empty.
if title or abstract:
    logits = get_logits(title, abstract)
    get_ans(logits)
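# To launch locally (assuming this script is saved as app.py):
#   streamlit run app.py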