Spaces:
Running
Running
with class_weights
Browse files
app.py
CHANGED
@@ -120,10 +120,11 @@ def load_model(test=False):
|
|
120 |
with open(f'{path}/label_to_theme.json', 'r') as f:
|
121 |
label_to_theme = json.load(f)
|
122 |
|
123 |
-
#class_weights = torch.load('model_info/class_weights.pth').to(device)
|
124 |
# model = DebertPaperClassifier(device=device, num_labels=len(label_to_theme), class_weights=class_weights).to(device)
|
125 |
# model.load_state_dict(torch.load("model_info/full_model_v4.pth", map_location=device))
|
126 |
-
|
|
|
|
|
127 |
model.load_state_dict(torch.load(f"{path}/full_model_v4.pth", map_location=device))
|
128 |
if test:
|
129 |
print(device)
|
|
|
120 |
with open(f'{path}/label_to_theme.json', 'r') as f:
|
121 |
label_to_theme = json.load(f)
|
122 |
|
|
|
123 |
# model = DebertPaperClassifier(device=device, num_labels=len(label_to_theme), class_weights=class_weights).to(device)
|
124 |
# model.load_state_dict(torch.load("model_info/full_model_v4.pth", map_location=device))
|
125 |
+
|
126 |
+
class_weights = torch.load(f'{path}/class_weights.pth').to(device)
|
127 |
+
model = DebertPaperClassifierV5(device=device, num_labels=47, class_weights=class_weights).to(device)
|
128 |
model.load_state_dict(torch.load(f"{path}/full_model_v4.pth", map_location=device))
|
129 |
if test:
|
130 |
print(device)
|