Spaces:
Sleeping
Sleeping
import streamlit as st | |
import os | |
import torch | |
from transformers import BertTokenizer | |
from transformers import BertForSequenceClassification | |
# Caching. NOTE(review): no caching decorator is applied here; presumably one was intended.
def get_model(unzip_root: str='./'):
    """
    Build the sequence classifier and load its weights from ``model.pth``.

    unzip_root ~ in the test environment `unzip archive.zip` is run on the
    supplied archive and the path given by `realpath .` is passed to this
    function; ``model.pth`` is looked up inside that directory.

    The size of the classification head is taken from the module-level
    ``dict_tetm_to_int`` mapping, so that mapping must exist before this
    function is called.

    Returns the ``BertForSequenceClassification`` instance with the
    checkpoint's state dict loaded.
    """
    # Read the checkpoint first so a missing file fails before any model
    # is instantiated.
    # NOTE(review): torch.load unpickles the file; only use trusted archives.
    state_dict = torch.load(
        os.path.join(unzip_root, "model.pth"),
        map_location="cpu",
    )

    classifier_model = BertForSequenceClassification.from_pretrained(
        'cointegrated/rubert-tiny'
    )

    # Replace the stock head with one output per entry of dict_tetm_to_int;
    # the input width is read from an encoder layer's output projection.
    hidden_width = (
        classifier_model.bert.encoder.layer[1].output.dense.out_features
    )
    classifier_model.classifier = torch.nn.Linear(
        hidden_width, len(dict_tetm_to_int)
    )

    classifier_model.load_state_dict(state_dict)
    return classifier_model
# Caching. NOTE(review): no caching decorator is applied here; presumably one was intended.
def get_tokenizer():
    """Return the ``BertTokenizer`` loaded from 'cointegrated/rubert-tiny'."""
    return BertTokenizer.from_pretrained('cointegrated/rubert-tiny')
# Caching. NOTE(review): no caching decorator is applied here; presumably one was intended.
def get_vocab(unzip_root: str='./'):
    """
    Read the two label mappings stored in ``vocab.tsv``.

    unzip_root ~ in the test environment `unzip archive.zip` is run on the
    supplied archive and the path given by `realpath .` is passed to this
    function; ``vocab.tsv`` is looked up inside that directory.

    File layout, one value per line:
        N
        term, id        (repeated N times, two lines per pair)
        M
        id, term        (repeated M times, two lines per pair)

    Returns:
        A tuple ``(term_to_int, int_to_term)`` of dicts built from the two
        sections of the file.

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if a count or id line is not an integer.
    """
    path = os.path.join(unzip_root, "vocab.tsv")
    # Explicit UTF-8 so terms decode the same way regardless of the
    # platform's default locale encoding.
    with open(path, 'r', encoding='utf-8') as f:
        term_to_int = {}
        for _ in range(int(f.readline())):
            # rstrip('\n') rather than [:-1]: a final line without a trailing
            # newline must not lose its last character.
            term = f.readline().rstrip('\n')
            term_to_int[term] = int(f.readline())

        int_to_term = {}
        for _ in range(int(f.readline())):
            label_id = int(f.readline())
            int_to_term[label_id] = f.readline().rstrip('\n')
    return term_to_int, int_to_term
# Module-level state shared by predict().
# The vocabulary is loaded before the model because get_model() reads the
# module-level dict_tetm_to_int to size its classification head.
dict_tetm_to_int, dict_int_to_term = get_vocab()
model = get_model()
tokenizer = get_tokenizer()
# Softmax over dim 1 of the logits tensor returned by the model.
softmax = torch.nn.Softmax(dim=1)
def predict(text, device='cpu'):
    """
    Return the most probable labels for ``text``, best first.

    The text is tokenized (truncated/padded to 512 tokens), run through the
    module-level ``model``, and labels are collected in descending order of
    probability until their cumulative probability reaches 0.95.

    Args:
        text: the string to classify.
        device: device the input tensors are moved to. This function does
            not move the module-level ``model``; it must already be on the
            same device.

    Returns:
        A non-empty list of labels looked up in the module-level
        ``dict_int_to_term``.
    """
    encoding = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=512,
        return_token_type_ids=False,
        truncation=True,
        padding='max_length',
        return_attention_mask=True,
        return_tensors='pt',
    )
    # Present a single-sample batch to the model.
    input_ids = encoding['input_ids'].reshape(1, -1).to(device)
    attention_mask = encoding['attention_mask'].reshape(1, -1).to(device)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )

    probabilities = softmax(outputs.logits)[0].cpu()
    ranking = torch.argsort(outputs.logits, dim=1, descending=True).cpu()[0]
    # Probabilities reordered to match the ranking, so the running sum adds
    # the probability of the label being appended (not of class index i).
    ranked_probabilities = probabilities[ranking]

    answer = []
    cumulative = 0.0
    for class_id, probability in zip(ranking.tolist(),
                                     ranked_probabilities.tolist()):
        if cumulative >= 0.95:
            break
        cumulative += probability
        answer.append(dict_int_to_term[class_id])
    return answer
# Streamlit page: two text inputs and a button that shows the suggested labels.
st.title("We will help you determine what topic your article belongs to:)")
st.header("Please enter a title and/or introduction")
title = st.text_input(label="Title", value="")
abstract = st.text_input(label="Abstract", value="")
if st.button('Show result'):
    # Both fields are title-cased with str.title() before being joined.
    # NOTE(review): presumably this mirrors the training-time preprocessing;
    # confirm.
    # The result gets its own name so the predict() function is not
    # overwritten by its own output.
    suggested_labels = predict(title.title() + ' ' + abstract.title())
    result = 'Suggested answer:\n' + ' '.join(suggested_labels)
    st.success(result)