Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -6,11 +6,12 @@ from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
|
6 |
|
7 |
|
8 |
st.markdown('## Классификатор статей')
|
9 |
-
st.write('Данный сервис предназначен для выбора темы
|
10 |
-
'основываясь на ее названии и краткой
|
11 |
-
'Сервис работает благодаря fine-tune версии модели distil bert. \n' \
|
12 |
'Данные для обучения были взяты [отсюда](https://www.kaggle.com/datasets/neelshah18/arxivdataset). \n' \
|
13 |
-
'Поддерживается ввод только английского языка.')
|
|
|
14 |
st.markdown('#### Введите название статьи и ее краткое содержание:')
|
15 |
|
16 |
device = torch.device('cpu')
|
@@ -60,7 +61,7 @@ def load_model():
|
|
60 |
model_name = 'model'
|
61 |
cat_count = 358
|
62 |
|
63 |
-
checkpoint = torch.load(os.path.join(chkp_folder, f"{model_name}.pt"), weights_only=False, map_location=
|
64 |
|
65 |
# Создаём те же классы, что и внутри чекпоинта
|
66 |
|
@@ -92,6 +93,12 @@ case_['summary'] = st.text_area("Краткое содержание:", value=""
|
|
92 |
|
93 |
if case_['title'] or case_['summary']:
|
94 |
categories, probabilities = predict_category(case_, model, tokenizer)
|
95 |
-
st.
|
96 |
for i, cat in enumerate(categories):
|
97 |
-
st.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
|
8 |
st.markdown('## Классификатор статей')
|
9 |
+
st.write('Данный сервис предназначен для выбора темы статьи [по таксономии arxiv.org](https://arxiv.org/category_taxonomy), \n' \
|
10 |
+
'основываясь на ее названии и краткой выжимки текста статьи. \n' \
|
11 |
+
'Сервис работает благодаря fine-tune версии модели [distil bert](https://huggingface.co/distilbert/distilbert-base-cased) [1]. \n' \
|
12 |
'Данные для обучения были взяты [отсюда](https://www.kaggle.com/datasets/neelshah18/arxivdataset). \n' \
|
13 |
+
'Поддерживается ввод только английского языка. \n')
|
14 |
+
|
15 |
st.markdown('#### Введите название статьи и ее краткое содержание:')
|
16 |
|
17 |
device = torch.device('cpu')
|
|
|
61 |
model_name = 'model'
|
62 |
cat_count = 358
|
63 |
|
64 |
+
checkpoint = torch.load(os.path.join(chkp_folder, f"{model_name}.pt"), weights_only=False, map_location=device)
|
65 |
|
66 |
# Создаём те же классы, что и внутри чекпоинта
|
67 |
|
|
|
93 |
|
94 |
if case_['title'] or case_['summary']:
|
95 |
categories, probabilities = predict_category(case_, model, tokenizer)
|
96 |
+
st.markdown('#### Возможные категории:')
|
97 |
for i, cat in enumerate(categories):
|
98 |
+
st.markdown("- " + f'{ind_to_cat[cat]}')
|
99 |
+
|
100 |
+
st.write(
|
101 |
+
'''[1] DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter,
|
102 |
+
Victor Sanh and Lysandre Debut and Julien Chaumond and Thomas Wolf,
|
103 |
+
ArXiv, 2019, abs/1910.01108'''
|
104 |
+
)
|