bvd757 commited on
Commit
1751037
·
1 Parent(s): b8c8071

with class_weights

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -120,10 +120,11 @@ def load_model(test=False):
120
  with open(f'{path}/label_to_theme.json', 'r') as f:
121
  label_to_theme = json.load(f)
122
 
123
- #class_weights = torch.load('model_info/class_weights.pth').to(device)
124
  # model = DebertPaperClassifier(device=device, num_labels=len(label_to_theme), class_weights=class_weights).to(device)
125
  # model.load_state_dict(torch.load("model_info/full_model_v4.pth", map_location=device))
126
- model = DebertPaperClassifierV5(device=device, num_labels=47).to(device)
 
 
127
  model.load_state_dict(torch.load(f"{path}/full_model_v4.pth", map_location=device))
128
  if test:
129
  print(device)
 
120
  with open(f'{path}/label_to_theme.json', 'r') as f:
121
  label_to_theme = json.load(f)
122
 
 
123
  # model = DebertPaperClassifier(device=device, num_labels=len(label_to_theme), class_weights=class_weights).to(device)
124
  # model.load_state_dict(torch.load("model_info/full_model_v4.pth", map_location=device))
125
+
126
+ class_weights = torch.load(f'{path}/class_weights.pth').to(device)
127
+ model = DebertPaperClassifierV5(device=device, num_labels=47, class_weights=class_weights).to(device)
128
  model.load_state_dict(torch.load(f"{path}/full_model_v4.pth", map_location=device))
129
  if test:
130
  print(device)