bvd757 committed on
Commit
b8c8071
·
1 Parent(s): 8292fd2

app fixes_v6

Browse files
Files changed (1) hide show
  1. app.py +14 -6
app.py CHANGED
@@ -7,7 +7,13 @@ from transformers import (
7
  DebertaV2Model,
8
  DebertaV2Tokenizer,
9
  )
10
- import sentencepiece
 
 
 
 
 
 
11
 
12
  model_name = "microsoft/deberta-v3-base"
13
  tokenizer = DebertaV2Tokenizer.from_pretrained(model_name)
@@ -37,6 +43,7 @@ def classify_text(text, model, tokenizer, device, threshold=0.5):
37
 
38
  def get_themes(text, model, tokenizer, label_to_theme, device, limit=5):
39
  probabilities, _ = classify_text(text, model, tokenizer, device)
 
40
  themes = []
41
  for label in probabilities[0].argsort()[-limit:]:
42
  themes.append((label_to_theme[str(label)], probabilities[0][label]))
@@ -57,7 +64,7 @@ class DebertPaperClassifier(torch.nn.Module):
57
  torch.nn.Linear(512, num_labels)
58
  )
59
 
60
- #self._init_weights()
61
  if class_weights is not None:
62
  self.loss_fct = torch.nn.BCEWithLogitsLoss(weight=class_weights.to(device))
63
  else:
@@ -112,11 +119,11 @@ def load_model(test=False):
112
  path = '/home/user/app'
113
  with open(f'{path}/label_to_theme.json', 'r') as f:
114
  label_to_theme = json.load(f)
115
-
116
 
117
- class_weights = torch.load(f'{path}/class_weights.pth').to(device)
118
-
119
- model = DebertPaperClassifier(device=device, num_labels=len(label_to_theme), class_weights=class_weights).to(device)
 
120
  model.load_state_dict(torch.load(f"{path}/full_model_v4.pth", map_location=device))
121
  if test:
122
  print(device)
@@ -126,6 +133,7 @@ def load_model(test=False):
126
  print(get_themes(text, model, tokenizer, label_to_theme, device))
127
  return model, tokenizer, label_to_theme, device
128
 
 
129
  def kek():
130
 
131
  title = st.text_input("Title")
 
7
  DebertaV2Model,
8
  DebertaV2Tokenizer,
9
  )
10
+ import sentencepiece
11
+
12
+ try:
13
+ import sentencepiece
14
+ except ImportError:
15
+ st.error("Требуется установить SentencePiece: pip install sentencepiece")
16
+ st.stop()
17
 
18
  model_name = "microsoft/deberta-v3-base"
19
  tokenizer = DebertaV2Tokenizer.from_pretrained(model_name)
 
43
 
44
  def get_themes(text, model, tokenizer, label_to_theme, device, limit=5):
45
  probabilities, _ = classify_text(text, model, tokenizer, device)
46
+ print(probabilities)
47
  themes = []
48
  for label in probabilities[0].argsort()[-limit:]:
49
  themes.append((label_to_theme[str(label)], probabilities[0][label]))
 
64
  torch.nn.Linear(512, num_labels)
65
  )
66
 
67
+ self._init_weights()
68
  if class_weights is not None:
69
  self.loss_fct = torch.nn.BCEWithLogitsLoss(weight=class_weights.to(device))
70
  else:
 
119
  path = '/home/user/app'
120
  with open(f'{path}/label_to_theme.json', 'r') as f:
121
  label_to_theme = json.load(f)
 
122
 
123
+ #class_weights = torch.load('model_info/class_weights.pth').to(device)
124
+ # model = DebertPaperClassifier(device=device, num_labels=len(label_to_theme), class_weights=class_weights).to(device)
125
+ # model.load_state_dict(torch.load("model_info/full_model_v4.pth", map_location=device))
126
+ model = DebertPaperClassifierV5(device=device, num_labels=47).to(device)
127
  model.load_state_dict(torch.load(f"{path}/full_model_v4.pth", map_location=device))
128
  if test:
129
  print(device)
 
133
  print(get_themes(text, model, tokenizer, label_to_theme, device))
134
  return model, tokenizer, label_to_theme, device
135
 
136
+
137
  def kek():
138
 
139
  title = st.text_input("Title")