import streamlit as st
import os
import torch
from transformers import BertTokenizer
from transformers import BertForSequenceClassification


@st.cache_resource  # cache the model across Streamlit reruns
def get_model(unzip_root: str = './'):
    """
    unzip_root ~ in the test environment, `unzip archive.zip` will be run on
    the submitted archive, and the path given by `realpath .` will be passed
    to this function.
    """
    checkpoint_path = os.path.join(unzip_root, "model.pth")
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    # Start from the pretrained rubert-tiny backbone and replace the
    # classification head so its size matches the topic vocabulary
    # (uses the module-level dict_term_to_int loaded via get_vocab() below).
    model_path = 'cointegrated/rubert-tiny'
    model = BertForSequenceClassification.from_pretrained(model_path)
    out_features = model.bert.encoder.layer[1].output.dense.out_features
    model.classifier = torch.nn.Linear(out_features, len(dict_term_to_int))

    model.load_state_dict(checkpoint)
    model.eval()  # inference only: disable dropout
    return model
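
# Sketch of how `model.pth` is assumed to have been produced (not part of
# this app): the checkpoint must be a plain state dict of the exact
# architecture built above, e.g.
#
#     model = BertForSequenceClassification.from_pretrained('cointegrated/rubert-tiny')
#     hidden = model.bert.encoder.layer[1].output.dense.out_features
#     model.classifier = torch.nn.Linear(hidden, num_labels)  # num_labels: topic count
#     ...  # fine-tune on (text, topic) pairs, then:
#     torch.save(model.state_dict(), "model.pth")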


@st.cache_resource  # cache the tokenizer across Streamlit reruns
def get_tokenizer():
    tokenizer_path = 'cointegrated/rubert-tiny'
    tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
    return tokenizer


@st.cache_resource  # cache the label vocabulary across Streamlit reruns
def get_vocab(unzip_root: str = './'):
    """
    unzip_root ~ in the test environment, `unzip archive.zip` will be run on
    the submitted archive, and the path given by `realpath .` will be passed
    to this function.
    """
    path = os.path.join(unzip_root, "vocab.tsv")

    with open(path, 'r', encoding='utf-8') as f:
        # Forward mapping: topic term -> class index.
        size_dict = int(f.readline())
        dict_term_to_int = dict()
        for _ in range(size_dict):
            key = f.readline().rstrip('\n')
            dict_term_to_int[key] = int(f.readline())

        # Reverse mapping: class index -> topic term.
        size_dict = int(f.readline())
        dict_int_to_term = dict()
        for _ in range(size_dict):
            key = int(f.readline())
            dict_int_to_term[key] = f.readline().rstrip('\n')

    return dict_term_to_int, dict_int_to_term
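
# Expected `vocab.tsv` layout, as implied by the parser above: a count N,
# then N (term, index) line pairs, then a count M, then M (index, term)
# line pairs. A tiny example with two hypothetical topic labels:
#
#     2
#     cs.LG
#     0
#     cs.AI
#     1
#     2
#     0
#     cs.LG
#     1
#     cs.AI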


# Module-level resources; loaded once per process thanks to @st.cache_resource.
softmax = torch.nn.Softmax(dim=1)
dict_term_to_int, dict_int_to_term = get_vocab()
model = get_model()  # relies on dict_term_to_int being defined just above
tokenizer = get_tokenizer()


def predict(text, device='cpu'):
    encoding = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=512,
        return_token_type_ids=False,
        truncation=True,
        padding='max_length',
        return_attention_mask=True,
        return_tensors='pt',
    )

    # `return_tensors='pt'` already yields batched (1, 512) tensors.
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )

    probs = softmax(outputs.logits)
    # Class indices sorted by descending logit (equivalently, probability).
    ranking = torch.argsort(outputs.logits, dim=1, descending=True).cpu()[0]

    # Return the smallest set of top-ranked topics whose cumulative
    # probability reaches 0.95.
    sum_answer = 0.0
    answer = []
    idx = 0
    while sum_answer < 0.95:
        class_id = ranking[idx].item()
        sum_answer += probs[0][class_id].item()
        answer.append(dict_int_to_term[class_id])
        idx += 1

    return answer
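
# Usage sketch -- the label names below are hypothetical; the real ones come
# from vocab.tsv:
#     predict("Attention Is All You Need")  # -> e.g. ['cs.CL', 'cs.LG']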


st.title("We will help you determine what topic your article belongs to :)")
st.header("Please enter a title and/or an abstract")


title = st.text_input(label="Title", value="")
abstract = st.text_input(label="Abstract", value="")
if st.button('Show result'):
    # Title-case the inputs (presumably matching the training-time
    # preprocessing) and join the suggested topics into one line.
    topics = ' '.join(predict(title.title() + ' ' + abstract.title()))
    result = 'Suggested answer:\n' + topics
    st.success(result)
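
# To run locally (assuming this file is saved as app.py):
#     streamlit run app.py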