4_homework / app.py
Kotann's picture
Upload 3 files
076c251 verified
import streamlit as st
import os
import torch
from transformers import BertTokenizer
from transformers import BertForSequenceClassification
@st.cache_resource  # load once and reuse across Streamlit reruns
def get_model(unzip_root: str = './'):
    """Build the rubert-tiny classifier and load the fine-tuned weights.

    The base ``cointegrated/rubert-tiny`` model is instantiated, its
    classification head is replaced by a linear layer sized to the
    checkpoint, and the state dict stored in ``model.pth`` is loaded.

    Args:
        unzip_root: Directory that contains ``model.pth``. In the test
            environment ``unzip archive.zip`` is run on the submitted
            archive and the result of ``realpath .`` is passed here.

    Returns:
        The ``BertForSequenceClassification`` model, on CPU, in eval mode.
    """
    checkpoint_path = os.path.join(unzip_root, "model.pth")
    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    model = BertForSequenceClassification.from_pretrained('cointegrated/rubert-tiny')

    # Size the new head from the checkpoint itself rather than from a
    # module-level vocabulary: the strict load_state_dict below requires
    # "classifier.weight" to be present with exactly this shape, so the
    # function no longer depends on globals defined elsewhere in the file.
    num_labels = checkpoint["classifier.weight"].shape[0]
    # The replacement head takes the same input width as the stock head.
    hidden_size = model.classifier.in_features
    model.classifier = torch.nn.Linear(hidden_size, num_labels)

    model.load_state_dict(checkpoint)
    # Inference only: make sure dropout is disabled.
    model.eval()
    return model
@st.cache_resource  # create once and reuse across Streamlit reruns
def get_tokenizer():
    """Return the ``BertTokenizer`` for the ``cointegrated/rubert-tiny`` checkpoint."""
    return BertTokenizer.from_pretrained('cointegrated/rubert-tiny')
@st.cache_resource  # parse once and reuse across Streamlit reruns
def get_vocab(unzip_root: str = './'):
    """Load the label vocabularies from ``vocab.tsv``.

    The file is read line by line in this layout:

    * a line with an entry count ``N``, followed by ``N`` pairs of lines
      (term, then its integer id);
    * a line with an entry count ``M``, followed by ``M`` pairs of lines
      (integer id, then its term).

    Args:
        unzip_root: Directory that contains ``vocab.tsv``. In the test
            environment ``unzip archive.zip`` is run on the submitted
            archive and the result of ``realpath .`` is passed here.

    Returns:
        A tuple ``(dict_tetm_to_int, dict_int_to_term)``: term -> id and
        id -> term mappings.
    """
    path = os.path.join(unzip_root, "vocab.tsv")
    # Explicit encoding so parsing does not depend on the platform default.
    # NOTE(review): assumes the file was written as UTF-8 -- confirm.
    with open(path, 'r', encoding='utf-8') as f:
        term_count = int(f.readline())
        dict_tetm_to_int = {}
        for _ in range(term_count):
            # rstrip('\n') instead of [:-1]: a final line without a trailing
            # newline must not lose its last character.
            term = f.readline().rstrip('\n')
            dict_tetm_to_int[term] = int(f.readline())

        id_count = int(f.readline())
        dict_int_to_term = {}
        for _ in range(id_count):
            term_id = int(f.readline())
            dict_int_to_term[term_id] = f.readline().rstrip('\n')
    return dict_tetm_to_int, dict_int_to_term
# Softmax over dim=1 of the logits; predict() uses it to turn logits into
# probabilities.
softmax = torch.nn.Softmax(dim=1)
# Label vocabularies in both directions, read from vocab.tsv under the default
# root './'. Loaded before the model; keep this order.
dict_tetm_to_int, dict_int_to_term = get_vocab()
# Model and tokenizer shared by predict(); both getters are wrapped in
# st.cache_resource.
model = get_model()
tokenizer = get_tokenizer()
def predict(text, device='cpu'):
    """Return the most probable topic labels for *text*.

    Labels are taken in descending order of predicted probability until
    their cumulative probability reaches 0.95, so at least one label is
    always returned.

    Args:
        text: Raw text to classify; truncated to 512 tokens.
        device: Torch device the input tensors are moved to. The module-level
            ``model`` is not moved by this function, so this must be the
            device the model already lives on.

    Returns:
        A list of label strings looked up in ``dict_int_to_term``, most
        probable first.
    """
    # A single sequence needs no padding; padding it to 512 tokens only added
    # compute, since padded positions are masked out by the attention mask.
    encoding = tokenizer(
        text,
        add_special_tokens=True,
        max_length=512,
        return_token_type_ids=False,
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    # return_tensors='pt' already yields tensors with a leading batch
    # dimension of 1, which is what the model expects.
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

    probs = softmax(outputs.logits)[0].cpu()
    ranking = torch.argsort(probs, descending=True).tolist()

    answer = []
    cumulative = 0.0
    # Iterating over the ranking bounds the loop by the number of classes.
    for class_idx in ranking:
        if cumulative >= 0.95:
            break
        # Accumulate the probability of the ranked class itself (indexing by
        # rank position instead would sum unrelated classes' probabilities).
        cumulative += probs[class_idx].item()
        answer.append(dict_int_to_term[class_idx])
    return answer
# Page layout: heading, two free-text inputs, and a button that triggers the
# classification.
st.title("We will help you determine what topic your article belongs to:)")
st.header("Please enter a title and/or introduction")
title = st.text_input(label="Title", value="")
abstract = st.text_input(label="Abstract", value="")
if st.button('Show result'):
    # NOTE(review): str.title() Title-Cases every word of the input before it
    # reaches the model; kept as-is because it may mirror the preprocessing
    # used at training time -- confirm.
    query = title.title() + ' ' + abstract.title()
    if not query.strip():
        # Nothing to classify: ask for input instead of predicting on blanks.
        st.warning('Please enter a title and/or an abstract first.')
    else:
        # Use a distinct name so the predict() function is not shadowed by
        # its own result.
        topics = ' '.join(predict(query))
        result = 'Suggested answer:\n' + topics
        st.success(result)