import streamlit as st
import os
import torch
from transformers import BertTokenizer
from transformers import BertForSequenceClassification


@st.cache_resource  # cached: the vocabulary file is parsed once per root
def get_vocab(unzip_root: str = './'):
    """Load the label vocabularies from ``vocab.tsv`` under *unzip_root*.

    unzip_root ~ in the test environment `unzip archive.zip` is run on the
    provided archive and the path given by `realpath .` is passed here.

    The file layout read by this function is:

    * one line with the number of entries ``N``;
    * ``N`` pairs of lines: a term, then its integer id;
    * one line with the number of entries ``M``;
    * ``M`` pairs of lines: an integer id, then its term.

    Returns:
        A tuple ``(dict_tetm_to_int, dict_int_to_term)``.
    """
    path = os.path.join(unzip_root, "vocab.tsv")
    with open(path, 'r') as f:
        size_dict = int(f.readline())
        dict_tetm_to_int = dict()
        for _ in range(size_dict):
            # rstrip("\n") instead of [:-1]: a final line without a trailing
            # newline must not lose its last real character.
            key = f.readline().rstrip("\n")
            dict_tetm_to_int[key] = int(f.readline())

        size_dict = int(f.readline())
        dict_int_to_term = dict()
        for _ in range(size_dict):
            key = int(f.readline())
            dict_int_to_term[key] = f.readline().rstrip("\n")
    return dict_tetm_to_int, dict_int_to_term


@st.cache_resource  # cached: the model is built and loaded once per root
def get_model(unzip_root: str = './'):
    """Build the classifier and load fine-tuned weights from ``model.pth``.

    unzip_root ~ in the test environment `unzip archive.zip` is run on the
    provided archive and the path given by `realpath .` is passed here.

    The classification head is replaced by a linear layer with one output
    per vocabulary term before the checkpoint is loaded.

    Returns:
        The ``BertForSequenceClassification`` model with weights on CPU.
    """
    checkpoint_path = os.path.join(unzip_root, "model.pth")
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    model_path = 'cointegrated/rubert-tiny'
    model = BertForSequenceClassification.from_pretrained(model_path)

    # Take the number of classes from the vocabulary stored next to the
    # checkpoint rather than from a module-level global defined later.
    term_to_int, _ = get_vocab(unzip_root)
    out_features = model.bert.encoder.layer[1].output.dense.out_features
    model.classifier = torch.nn.Linear(out_features, len(term_to_int))
    model.load_state_dict(checkpoint)
    return model


@st.cache_resource  # cached: the tokenizer is created once
def get_tokenizer():
    """Return the ``cointegrated/rubert-tiny`` tokenizer."""
    tokenizer_path = 'cointegrated/rubert-tiny'
    tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
    return tokenizer


softmax = torch.nn.Softmax(dim=1)
dict_tetm_to_int, dict_int_to_term = get_vocab()
model = get_model()
tokenizer = get_tokenizer()


def predict(text, device='cpu', threshold=0.95):
    """Return the most probable topic labels for *text*.

    Labels are taken in order of decreasing model score until their
    cumulative softmax probability reaches *threshold*.

    Args:
        text: Input text to classify.
        device: Device the input tensors are moved to. The model itself is
            not moved by this function, so it must already be on *device*.
        threshold: Cumulative probability at which to stop adding labels.

    Returns:
        A list of label strings, most probable first.
    """
    encoding = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=512,
        return_token_type_ids=False,
        truncation=True,
        padding='max_length',
        return_attention_mask=True,
        return_tensors='pt',
    )
    # return_tensors='pt' already yields a batch of one, so no reshaping.
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

    probabilities = softmax(outputs.logits)[0].cpu()
    ranking = torch.argsort(outputs.logits, dim=1, descending=True).cpu()[0]

    answer = []
    cumulative = 0.0
    for class_id in ranking.tolist():
        if cumulative >= threshold:
            break
        # Accumulate the probability of the class being added, not of the
        # class whose index happens to equal the rank position.
        cumulative += probabilities[class_id].item()
        answer.append(dict_int_to_term[class_id])
    return answer


st.title("We will help you determine what topic your article belongs to:)")
st.header("Please enter a title and/or introduction")

title = st.text_input(label="Title", value="")
abstract = st.text_input(label="Abstract", value="")

if st.button('Show result'):
    # Keep the result in its own name so the predict() function is not
    # shadowed by a string.
    topics = ' '.join(predict(title.title() + ' ' + abstract.title()))
    result = 'Suggested answer:\n' + topics
    st.success(result)