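"""Streamlit app that classifies a paper by its title and abstract.

It loads a fine-tuned PubMedBERT sequence-classification checkpoint together
with label <-> index mappings from JSON files, then reports the most probable
labels until their cumulative probability reaches 95%.
"""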
import json

import streamlit as st
import torch
from torch.nn import functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification


@st.cache_resource
def load_dicts():
    # Label name <-> class index mappings saved alongside the checkpoint.
    with open("label2ind.json", "r") as file:
        label2ind = json.load(file)
    with open("ind2label.json", "r") as file:
        ind2label = json.load(file)
    return label2ind, ind2label


@st.cache_resource
def load_model():
    # The tokenizer comes from the base PubMedBERT model; the classifier
    # weights come from a local fine-tuned checkpoint directory.
    tokenizer = AutoTokenizer.from_pretrained(
        "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        "checkpoint-23000",
        num_labels=len(label2ind),  # label2ind is loaded at module level before this runs
        problem_type="single_label_classification",
    )
    return tokenizer, model


# Load the cached label mappings and model, then build the input widgets.
label2ind, ind2label = load_dicts()
tokenizer, model = load_model()
title = st.text_input("Title", value="Math")
abstract = st.text_input("Abstract", value="Random variable")


def get_logits(title, abstract):
    # Join title and abstract with the "###" separator and run a forward pass.
    text = title + "###" + abstract
    # Truncate to the model's maximum input length so long abstracts don't fail.
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():  # inference only, no gradients needed
        logits = model(**inputs)["logits"]
    return logits


def get_ans(logits):
    # Convert logits to probabilities and sort class indices from most to
    # least probable.
    probs = F.softmax(logits, dim=1)
    ind = torch.argsort(probs, dim=1, descending=True).flatten()
    # Print the top labels until they cover 95% of the probability mass.
    cum_sum = 0.0
    i = 0
    while cum_sum < 0.95 and i < len(ind):
        idx = ind[i].item()
        cum_sum += probs[0][idx].item()
        st.write(f"label: {ind2label.get(str(idx))} with probability: {probs[0][idx].item() * 100:.2f}%")
        i += 1


# Classify as soon as the user provides a title or an abstract.
if title or abstract:
    logits = get_logits(title, abstract)
    get_ans(logits)
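

# To try the app locally (assuming Streamlit is installed and the fine-tuned
# checkpoint plus the two JSON mapping files sit next to this script):
#   streamlit run app.py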