Karpernik's picture
Update app.py
ee2a472 verified
import streamlit as st
import torch
import os
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification
st.markdown('## Классификатор статей')
st.write('Данный сервис предназначен для выбора темы статьи [по таксономии arxiv.org](https://arxiv.org/category_taxonomy), \n' \
'основываясь на ее названии и краткой выжимки текста статьи. \n' \
'Сервис работает благодаря fine-tune версии модели [distil bert](https://huggingface.co/distilbert/distilbert-base-cased) [1]. \n' \
'Данные для обучения были взяты [отсюда](https://www.kaggle.com/datasets/neelshah18/arxivdataset). \n' \
'Поддерживается ввод только английского языка. \n')
st.markdown('#### Введите название статьи и ее краткое содержание:')
device = torch.device('cpu')
def create_model_and_optimizer(model, lr=1e-4, beta1=0.9, beta2=0.999, device=device):
model = model.to(device)
params = []
for param in model.parameters():
if param.requires_grad:
params.append(param)
optimizer = torch.optim.Adam(params, lr, [beta1, beta2])
return model, optimizer
def title_summary_transform(items, tokenizer):
return tokenizer(
items['title'] + '[SEP]' + items['summary'],
padding="max_length",
truncation=True
)
def predict_category(case, model, tokenizer):
input_ = {key: torch.tensor(val).unsqueeze(0).to(device) for key, val in title_summary_transform(case, tokenizer).items()}
pred = []
pred_prob = []
with torch.no_grad():
logits = model(**input_).logits[0]
probs = torch.nn.functional.softmax(logits, dim=-1)
probs, indices = torch.sort(probs, descending=True)
sum_prob = 0
for i, prob_ in enumerate(probs):
pred.append(indices[i].item())
pred_prob.append(prob_.item())
sum_prob += prob_
if sum_prob > 0.95:
break
return pred, pred_prob
@st.cache_resource # кэширование
def load_model():
chkp_folder = '.'
model_name = 'model'
cat_count = 358
checkpoint = torch.load(os.path.join(chkp_folder, f"{model_name}.pt"), weights_only=False, map_location=torch.device('cpu'))
# Создаём те же классы, что и внутри чекпоинта
model_ = AutoModelForSequenceClassification.from_pretrained('distilbert/distilbert-base-cased', num_labels=cat_count).to(device)
for param in model_.distilbert.parameters():
param.requires_grad = False
for i in range(4, 6):
for param in model_.distilbert.transformer.layer[i].parameters():
param.requires_grad = True
model, optimizer = create_model_and_optimizer(model_)
# Загружаем состояния из чекпоинта
model.load_state_dict(checkpoint['model_state_dict'])
ind_to_cat = checkpoint['ind_to_cat']
tokenizer = AutoTokenizer.from_pretrained('distilbert/distilbert-base-cased')
return model, tokenizer, ind_to_cat
model, tokenizer, ind_to_cat = load_model()
case_ = {}
case_['title'] = st.text_area("Название статьи:", value="")
case_['summary'] = st.text_area("Краткое содержание:", value="")
if case_['title'] or case_['summary']:
categories, probabilities = predict_category(case_, model, tokenizer)
st.markdown('#### Возможные категории:')
for i, cat in enumerate(categories):
st.markdown("- " + f'{ind_to_cat[cat]}')
st.write(
'''[1] DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter,
Victor Sanh and Lysandre Debut and Julien Chaumond and Thomas Wolf,
ArXiv, 2019, abs/1910.01108'''
)