Spaces:
Running
Running
app fixes_v6
Browse files
app.py
CHANGED
@@ -7,7 +7,13 @@ from transformers import (
|
|
7 |
DebertaV2Model,
|
8 |
DebertaV2Tokenizer,
|
9 |
)
|
10 |
-
import sentencepiece
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
model_name = "microsoft/deberta-v3-base"
|
13 |
tokenizer = DebertaV2Tokenizer.from_pretrained(model_name)
|
@@ -37,6 +43,7 @@ def classify_text(text, model, tokenizer, device, threshold=0.5):
|
|
37 |
|
38 |
def get_themes(text, model, tokenizer, label_to_theme, device, limit=5):
|
39 |
probabilities, _ = classify_text(text, model, tokenizer, device)
|
|
|
40 |
themes = []
|
41 |
for label in probabilities[0].argsort()[-limit:]:
|
42 |
themes.append((label_to_theme[str(label)], probabilities[0][label]))
|
@@ -57,7 +64,7 @@ class DebertPaperClassifier(torch.nn.Module):
|
|
57 |
torch.nn.Linear(512, num_labels)
|
58 |
)
|
59 |
|
60 |
-
|
61 |
if class_weights is not None:
|
62 |
self.loss_fct = torch.nn.BCEWithLogitsLoss(weight=class_weights.to(device))
|
63 |
else:
|
@@ -112,11 +119,11 @@ def load_model(test=False):
|
|
112 |
path = '/home/user/app'
|
113 |
with open(f'{path}/label_to_theme.json', 'r') as f:
|
114 |
label_to_theme = json.load(f)
|
115 |
-
|
116 |
|
117 |
-
class_weights = torch.load(
|
118 |
-
|
119 |
-
model
|
|
|
120 |
model.load_state_dict(torch.load(f"{path}/full_model_v4.pth", map_location=device))
|
121 |
if test:
|
122 |
print(device)
|
@@ -126,6 +133,7 @@ def load_model(test=False):
|
|
126 |
print(get_themes(text, model, tokenizer, label_to_theme, device))
|
127 |
return model, tokenizer, label_to_theme, device
|
128 |
|
|
|
129 |
def kek():
|
130 |
|
131 |
title = st.text_input("Title")
|
|
|
7 |
DebertaV2Model,
|
8 |
DebertaV2Tokenizer,
|
9 |
)
|
10 |
+
import sentencepiece
|
11 |
+
|
12 |
+
try:
|
13 |
+
import sentencepiece
|
14 |
+
except ImportError:
|
15 |
+
st.error("Требуется установить SentencePiece: pip install sentencepiece")
|
16 |
+
st.stop()
|
17 |
|
18 |
model_name = "microsoft/deberta-v3-base"
|
19 |
tokenizer = DebertaV2Tokenizer.from_pretrained(model_name)
|
|
|
43 |
|
44 |
def get_themes(text, model, tokenizer, label_to_theme, device, limit=5):
|
45 |
probabilities, _ = classify_text(text, model, tokenizer, device)
|
46 |
+
print(probabilities)
|
47 |
themes = []
|
48 |
for label in probabilities[0].argsort()[-limit:]:
|
49 |
themes.append((label_to_theme[str(label)], probabilities[0][label]))
|
|
|
64 |
torch.nn.Linear(512, num_labels)
|
65 |
)
|
66 |
|
67 |
+
self._init_weights()
|
68 |
if class_weights is not None:
|
69 |
self.loss_fct = torch.nn.BCEWithLogitsLoss(weight=class_weights.to(device))
|
70 |
else:
|
|
|
119 |
path = '/home/user/app'
|
120 |
with open(f'{path}/label_to_theme.json', 'r') as f:
|
121 |
label_to_theme = json.load(f)
|
|
|
122 |
|
123 |
+
#class_weights = torch.load('model_info/class_weights.pth').to(device)
|
124 |
+
# model = DebertPaperClassifier(device=device, num_labels=len(label_to_theme), class_weights=class_weights).to(device)
|
125 |
+
# model.load_state_dict(torch.load("model_info/full_model_v4.pth", map_location=device))
|
126 |
+
model = DebertPaperClassifierV5(device=device, num_labels=47).to(device)
|
127 |
model.load_state_dict(torch.load(f"{path}/full_model_v4.pth", map_location=device))
|
128 |
if test:
|
129 |
print(device)
|
|
|
133 |
print(get_themes(text, model, tokenizer, label_to_theme, device))
|
134 |
return model, tokenizer, label_to_theme, device
|
135 |
|
136 |
+
|
137 |
def kek():
|
138 |
|
139 |
title = st.text_input("Title")
|