Nikita Pogadaev committed on
Commit
c2c8638
·
1 Parent(s): 6da6312

adding model runner, first commit

Browse files
Files changed (3) hide show
  1. app.py +174 -0
  2. model_info/label_to_theme.json +1 -0
  3. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+
3
+ import streamlit as st
4
+ import json
5
+ import numpy as np
6
+ import torch
7
+ from transformers import (
8
+ DebertaV2Config,
9
+ DebertaV2Model,
10
+ DebertaV2Tokenizer,
11
+ )
12
+
13
# Hugging Face checkpoint identifier used for the tokenizer below and by
# DebertPaperClassifier when it builds its encoder.
model_name = "microsoft/deberta-v3-base"

# Created once at import time; load_model() hands this same object back to
# its callers.
tokenizer = DebertaV2Tokenizer.from_pretrained(model_name)
15
+
16
def preprocess_text(text, tokenizer, max_length=512):
    """Tokenize ``text`` with fixed-length padding and truncation.

    Args:
        text: The raw text to encode; passed to ``tokenizer`` unchanged.
        tokenizer: A callable accepting the text plus the keyword options
            ``padding``, ``truncation``, ``max_length`` and
            ``return_tensors``.
        max_length: Length forwarded to the tokenizer as ``max_length``.

    Returns:
        Whatever ``tokenizer`` returns for the call (requested with
        ``return_tensors="pt"``).
    """
    encode_options = {
        "padding": "max_length",
        "truncation": True,
        "max_length": max_length,
        "return_tensors": "pt",
    }
    return tokenizer(text, **encode_options)
25
+
26
+
27
def classify_text(text, model, tokenizer, device, threshold=0.5):
    """Run the classifier on one text and return per-label scores.

    Args:
        text: Raw text to classify.
        model: Callable model invoked as ``model(input_ids, attention_mask)``
            and expected to return logits.
        tokenizer: Tokenizer forwarded to ``preprocess_text``.
        device: Torch device the input tensors are moved to before the
            forward pass.
        threshold: Scores strictly above this value are marked as 1 in the
            returned predictions.

    Returns:
        A ``(probabilities, predictions)`` tuple of NumPy arrays: the
        sigmoid of the logits, and the integer 0/1 mask produced by
        comparing those scores with ``threshold``.
    """
    inputs = preprocess_text(text, tokenizer)
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    model.eval()
    with torch.no_grad():
        logits = model(input_ids, attention_mask)
        probs = torch.sigmoid(logits)

    # Tensor.numpy() only works for CPU tensors, so bring the scores back to
    # the host first; this is a no-op when ``device`` is already the CPU.
    probs = probs.cpu()
    predictions = (probs > threshold).int().numpy()

    return probs.numpy(), predictions
38
+
39
def get_themes(text, model, tokenizer, label_to_theme, device, limit=5):
    """Return the ``limit`` highest-scoring themes for ``text``.

    The scores from ``classify_text`` are divided by their total, so the
    returned values are shares of the overall score mass rather than raw
    sigmoid outputs.

    Args:
        text: Raw text to classify.
        model: Model forwarded to ``classify_text``.
        tokenizer: Tokenizer forwarded to ``classify_text``.
        label_to_theme: Mapping from the label index rendered as a string
            (e.g. ``"5"``) to the theme name.
        device: Torch device forwarded to ``classify_text``.
        limit: Maximum number of themes to return.

    Returns:
        A list of ``(theme, share)`` tuples ordered from the lowest to the
        highest share among the selected themes (the best match is last).
        Empty when ``limit`` is zero or negative.
    """
    # A slice of [-0:] would select *every* label and a negative limit would
    # slice from the wrong end, so handle non-positive limits explicitly.
    if limit <= 0:
        return []

    probabilities, _ = classify_text(text, model, tokenizer, device)
    probabilities = probabilities / probabilities.sum()

    # Only the first row is used; presumably the array is (1, num_labels)
    # because a single text is classified - confirm against classify_text.
    scores = probabilities[0]
    top_labels = scores.argsort()[-limit:]
    return [(label_to_theme[str(label)], scores[label]) for label in top_labels]
46
+
47
class DebertPaperClassifier(torch.nn.Module):
    """DeBERTa encoder followed by a two-layer classification head.

    The encoder is loaded from the module-level ``model_name`` checkpoint.
    The head maps the encoder's hidden state at sequence position 0 to
    ``num_labels`` logits, and its linear layers are re-initialised in
    ``__init__``.
    """

    def __init__(self, num_labels, device, dropout_rate=0.1, class_weights=None):
        """Build the encoder, the classification head and the loss.

        Args:
            num_labels: Number of output logits.
            device: Device that ``class_weights`` is moved to when given.
            dropout_rate: Dropout probability used twice inside the head.
            class_weights: Optional tensor passed as ``weight`` to
                ``BCEWithLogitsLoss``.
        """
        super().__init__()
        self.config = DebertaV2Config.from_pretrained(model_name)
        self.deberta = DebertaV2Model.from_pretrained(model_name, config=self.config)

        self.classifier = torch.nn.Sequential(
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(self.config.hidden_size, 512),
            torch.nn.LayerNorm(512),
            torch.nn.GELU(),
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(512, num_labels)
        )

        self._init_weights()
        if class_weights is not None:
            self.loss_fct = torch.nn.BCEWithLogitsLoss(weight=class_weights.to(device))
        else:
            self.loss_fct = torch.nn.BCEWithLogitsLoss()

    def _init_weights(self):
        """Re-initialise every linear layer of the classification head.

        Weights are drawn from a normal distribution (mean 0, std 0.02) and
        biases are set to zero. Defined on this class because ``__init__``
        calls it; without it construction fails with ``AttributeError``.
        """
        for module in self.classifier.modules():
            if isinstance(module, torch.nn.Linear):
                module.weight.data.normal_(mean=0.0, std=0.02)
                if module.bias is not None:
                    module.bias.data.zero_()

    def forward(self, input_ids, attention_mask, labels=None):
        """Compute logits and, when ``labels`` is given, the loss.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Attention mask forwarded to the encoder.
            labels: Optional targets passed to the loss together with the
                logits.

        Returns:
            ``(loss, logits)`` when ``labels`` is provided, otherwise just
            ``logits``.
        """
        outputs = self.deberta(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        # Use the hidden state of the first sequence position as the
        # summary vector fed to the head.
        cls_output = outputs.last_hidden_state[:, 0, :]
        logits = self.classifier(cls_output)

        if labels is None:
            return logits
        return self.loss_fct(logits, labels), logits
67
+
68
class DebertPaperClassifierV5(torch.nn.Module):
    """DeBERTa-v3-base encoder with a two-layer classification head.

    The head maps the encoder's hidden state at sequence position 0 to
    ``num_labels`` logits. Unlike ``_init_weights`` suggests, ``__init__``
    does not re-initialise the head; the helper is kept for callers that
    want to do so explicitly.
    """

    def __init__(self, device, num_labels=47, dropout_rate=0.1, class_weights=None):
        """Build the encoder, the classification head and the loss.

        Args:
            device: Device that ``class_weights`` is moved to when given.
            num_labels: Number of output logits.
            dropout_rate: Dropout probability used twice inside the head.
            class_weights: Optional tensor passed as ``weight`` to
                ``BCEWithLogitsLoss``.
        """
        super().__init__()
        checkpoint = "microsoft/deberta-v3-base"
        self.config = DebertaV2Config.from_pretrained(checkpoint)
        self.deberta = DebertaV2Model.from_pretrained(checkpoint, config=self.config)

        self.classifier = torch.nn.Sequential(
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(self.config.hidden_size, 512),
            torch.nn.LayerNorm(512),
            torch.nn.GELU(),
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(512, num_labels)
        )

        if class_weights is not None:
            self.loss_fct = torch.nn.BCEWithLogitsLoss(weight=class_weights.to(device))
        else:
            self.loss_fct = torch.nn.BCEWithLogitsLoss()

    def _init_weights(self):
        """Re-initialise every linear layer of the classification head.

        Weights are drawn from a normal distribution (mean 0, std 0.02) and
        biases are set to zero. Not invoked by ``__init__``.
        """
        for module in self.classifier.modules():
            if isinstance(module, torch.nn.Linear):
                module.weight.data.normal_(mean=0.0, std=0.02)
                if module.bias is not None:
                    module.bias.data.zero_()

    # The class previously defined ``forward`` twice with equivalent bodies;
    # Python silently kept only the later one, so a single definition stays.
    def forward(self, input_ids, attention_mask, labels=None):
        """Compute logits and, when ``labels`` is given, the loss.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Attention mask forwarded to the encoder.
            labels: Optional targets passed to the loss together with the
                logits.

        Returns:
            ``(loss, logits)`` when ``labels`` is provided, otherwise just
            ``logits``.
        """
        outputs = self.deberta(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        # Use the hidden state of the first sequence position as the
        # summary vector fed to the head.
        cls_output = outputs.last_hidden_state[:, 0, :]
        logits = self.classifier(cls_output)

        if labels is None:
            return logits
        return self.loss_fct(logits, labels), logits
124
+
125
@st.cache_resource
def load_model():
    """Load the label map and the trained classifier.

    Reads ``model_info/label_to_theme.json``, builds a
    ``DebertPaperClassifierV5`` with one output per entry in that mapping,
    places it on CUDA when available (CPU otherwise) and loads the weights
    from ``model_info/deberta_v3.pth``. The result is wrapped in
    ``st.cache_resource``.

    Returns:
        A ``(model, tokenizer, label_to_theme, device)`` tuple, where
        ``tokenizer`` is the module-level tokenizer.
    """
    with open('model_info/label_to_theme.json', 'r') as mapping_file:
        label_to_theme = json.load(mapping_file)

    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    classifier = DebertPaperClassifierV5(device=device, num_labels=len(label_to_theme))
    classifier = classifier.to(device)

    # map_location lets the checkpoint load onto whichever device was picked.
    state_dict = torch.load("model_info/deberta_v3.pth", map_location=device)
    classifier.load_state_dict(state_dict)
    return classifier, tokenizer, label_to_theme, device
134
+
135
def kek():
    """Render the Streamlit page and classify the submitted paper.

    Shows a title, two images, inputs for the paper title and abstract and
    a "top N" selector. When the CLASSIFY button is pressed with at least
    one text field filled in, the text is passed to ``get_themes`` and the
    returned themes are listed with their scores formatted as percentages.
    """
    st.title("arXiv Paper Classifier")
    # The markup is kept flush-left so it reaches st.markdown exactly as
    # written, without leading indentation inside the string.
    st.markdown("""
<style>
.image-row {
display: flex;
flex-direction: row;
gap: 10px;
}
</style>

<div class="image-row">
<img width=100px src='https://storage.yandexcloud.net/lms-vault/media/cache/c9/a7/c9a754ba1b2bb5b34e1f178d4ec26f24.jpg'>
<img width=300px src='https://pic.rutubelist.ru/video/ba/b6/bab6ab515c15837e28eb6c99df192cae.jpg'>
</div>
""", unsafe_allow_html=True)
    st.write("write the title or abstract to classify topic theme")

    title = st.text_input("title")
    abstract = st.text_area("abstract")
    # Integer-only input starting at 1: without explicit bounds the widget
    # starts at 0, which asks for "top 0" themes.
    lim = int(st.number_input("top ? themes", min_value=1, value=5, step=1))

    if not st.button("CLASSIFY"):
        return
    if not title and not abstract:
        st.warning("empty abstract!!!")
        return

    text = f"{title}\n\n{abstract}" if title and abstract else title or abstract
    model, tokenizer, label_to_theme, device = load_model()

    with st.spinner("classifying..."):
        themes = get_themes(text, model, tokenizer, label_to_theme, device, lim)

    # Number the rows from the count actually returned, which can be smaller
    # than ``lim`` when fewer labels exist. Rows are listed in the order
    # get_themes returns them, counting down so the last row is number 1.
    shown = len(themes)
    st.success(f"top {shown} results:")
    for position, (theme, share) in enumerate(themes):
        st.write(f"{shown - position}. - {theme}: {share:.1%}")
172
+
173
# Build the page only when this file is executed as a script, not on import.
if __name__ == "__main__":
    kek()
model_info/label_to_theme.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"0": "cs.AI", "1": "physics.soc-ph", "2": "stat.ML", "3": "cs.CE", "4": "cs.DB", "5": "cs.CL", "6": "cs.NA", "7": "cs.CY", "8": "cs.GT", "9": "cs.SI", "10": "stat.AP", "11": "cs.DL", "12": "math.ST", "13": "nlin.AO", "14": "cs.LO", "15": "cs.MM", "16": "cond-mat.dis-nn", "17": "cs.DM", "18": "cs.CC", "19": "stat.CO", "20": "cs.DC", "21": "cs.IT", "22": "cs.DS", "23": "cs.SY", "24": "q-bio.QM", "25": "cs.PL", "26": "cs.RO", "27": "cs.NE", "28": "cs.CR", "29": "cs.MA", "30": "q-bio.NC", "31": "cs.LG", "32": "cs.GR", "33": "physics.data-an", "34": "quant-ph", "35": "cs.IR", "36": "math.NA", "37": "math.PR", "38": "stat.ME", "39": "cs.SE", "40": "math.OC", "41": "math.IT", "42": "cs.HC", "43": "stat.TH", "44": "cs.NI", "45": "cs.CV", "46": "cs.SD"}
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ streamlit
2
+ torch
3
+ transformers
4
+ numpy
5
+ sentencepiece