Spaces:
Running
Running
import streamlit as st | |
import torch | |
import os | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
# from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification | |
# Page title (UI text is in Russian).
st.markdown('## Классификатор статей')

# Service description. The fragments are joined unchanged, so the rendered
# text is exactly the concatenation of the literals below.
intro_fragments = (
    'Данный сервис предназначен для выбора темы статьи [по таксономии arxiv.org](https://arxiv.org/category_taxonomy), \n',
    'основываясь на ее названии и краткой выжимки текста статьи. \n',
    'Сервис работает благодаря fine-tune версии модели [distil bert](https://huggingface.co/distilbert/distilbert-base-cased) [1]. \n',
    'Данные для обучения были взяты [отсюда](https://www.kaggle.com/datasets/neelshah18/arxivdataset). \n',
    'Поддерживается ввод только английского языка. \n',
)
st.write(''.join(intro_fragments))

# Prompt shown above the two input fields.
st.markdown('#### Введите название статьи и ее краткое содержание:')

# Device used by the model-loading and prediction code in this file.
device = torch.device('cpu')
def create_model_and_optimizer(model, lr=1e-4, beta1=0.9, beta2=0.999, device=device):
    """Move *model* to *device* and build an Adam optimizer for it.

    Args:
        model: object exposing ``.to()`` and ``.parameters()``.
        lr: learning rate passed to Adam.
        beta1: first Adam beta coefficient.
        beta2: second Adam beta coefficient.
        device: target device; defaults to the module-level ``device``.

    Returns:
        A ``(model, optimizer)`` tuple, where ``model`` is the value returned
        by ``model.to(device)``.
    """
    model = model.to(device)
    # Only parameters that still require gradients are handed to the optimizer.
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable, lr=lr, betas=[beta1, beta2])
    return model, optimizer
def title_summary_transform(items, tokenizer):
    """Tokenize an article's title and summary as one combined string.

    Args:
        items: mapping with ``'title'`` and ``'summary'`` entries.
        tokenizer: callable invoked with the combined text plus
            ``padding="max_length"`` and ``truncation=True``.

    Returns:
        Whatever ``tokenizer`` returns for the combined text.
    """
    title = items['title']
    summary = items['summary']
    # The two fields are glued together with a literal '[SEP]' marker.
    combined_text = title + '[SEP]' + summary
    return tokenizer(combined_text, padding="max_length", truncation=True)
def predict_category(case, model, tokenizer):
    """Return the most probable class indices for one article.

    The input is tokenized with ``title_summary_transform``, run through
    ``model`` without gradient tracking, and the softmax probabilities are
    sorted in descending order. Classes are collected from the top until
    their accumulated probability exceeds 0.95.

    Args:
        case: mapping passed to ``title_summary_transform``.
        model: callable whose result exposes ``.logits``.
        tokenizer: tokenizer passed to ``title_summary_transform``.

    Returns:
        ``(pred, pred_prob)``: two equally long lists with the selected class
        indices and their probabilities, most probable first.
    """
    encoded = title_summary_transform(case, tokenizer)

    # Every encoded field becomes a batch of one on the module-level device.
    batch = {}
    for field, values in encoded.items():
        batch[field] = torch.tensor(values).unsqueeze(0).to(device)

    with torch.no_grad():
        first_logits = model(**batch).logits[0]
        class_probs = torch.nn.functional.softmax(first_logits, dim=-1)
        ranked_probs, ranked_ids = torch.sort(class_probs, descending=True)

        pred, pred_prob = [], []
        accumulated = 0
        for class_id, class_prob in zip(ranked_ids, ranked_probs):
            pred.append(class_id.item())
            pred_prob.append(class_prob.item())
            accumulated += class_prob
            # Stop as soon as the selected classes cover more than 95 %.
            if accumulated > 0.95:
                break
    return pred, pred_prob
# Caching: Streamlit re-executes this script on every widget interaction, so
# the loader is wrapped in st.cache_resource to build the model only once.
@st.cache_resource
def load_model():
    """Load the fine-tuned DistilBERT classifier, its tokenizer and label map.

    Reads ``./model.pt``, builds a ``distilbert/distilbert-base-cased``
    sequence classifier with 358 labels, and loads the checkpoint's
    ``'model_state_dict'`` into it.

    Returns:
        ``(model, tokenizer, ind_to_cat)`` where ``ind_to_cat`` is the
        checkpoint's ``'ind_to_cat'`` entry. Because of ``st.cache_resource``
        the same objects are returned on every call; callers should treat
        them as read-only.
    """
    chkp_folder = '.'
    model_name = 'model'
    cat_count = 358

    # SECURITY NOTE(review): weights_only=False lets torch.load unpickle
    # arbitrary Python objects. That is acceptable only because the checkpoint
    # ships with the app; never point this at an untrusted file.
    checkpoint = torch.load(
        os.path.join(chkp_folder, f"{model_name}.pt"),
        weights_only=False,
        map_location=torch.device('cpu'),
    )

    # Re-create the same classes that were saved in the checkpoint.
    model_ = AutoModelForSequenceClassification.from_pretrained(
        'distilbert/distilbert-base-cased', num_labels=cat_count
    ).to(device)

    # Freeze the DistilBERT backbone, then unfreeze transformer layers 4 and 5.
    for param in model_.distilbert.parameters():
        param.requires_grad = False
    for layer_idx in range(4, 6):
        for param in model_.distilbert.transformer.layer[layer_idx].parameters():
            param.requires_grad = True

    # The optimizer is not used by this function, so it is discarded.
    model, _optimizer = create_model_and_optimizer(model_)

    # Load the saved state from the checkpoint.
    model.load_state_dict(checkpoint['model_state_dict'])
    # Make inference mode explicit so dropout is disabled during prediction.
    model.eval()

    ind_to_cat = checkpoint['ind_to_cat']
    tokenizer = AutoTokenizer.from_pretrained('distilbert/distilbert-base-cased')
    return model, tokenizer, ind_to_cat
# Obtain the classifier, tokenizer and index-to-category mapping.
model, tokenizer, ind_to_cat = load_model()

# Two free-text inputs; the dict keys are the ones predict_category's
# tokenization helper reads.
case_ = {
    'title': st.text_area("Название статьи:", value=""),
    'summary': st.text_area("Краткое содержание:", value=""),
}

# Run the prediction only when at least one field is non-empty.
if case_['title'] or case_['summary']:
    categories, probabilities = predict_category(case_, model, tokenizer)
    st.markdown('#### Возможные категории:')
    # One bullet per predicted category, in the order returned.
    for cat in categories:
        st.markdown(f"- {ind_to_cat[cat]}")

# Footnote for the "[1]" citation in the service description.
st.write(
    '''[1] DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter,
Victor Sanh and Lysandre Debut and Julien Chaumond and Thomas Wolf,
ArXiv, 2019, abs/1910.01108'''
)