Spaces:
Running
Running
app fixes_v4
Browse files
app.py
CHANGED
@@ -83,16 +83,16 @@ class DebertPaperClassifierV5(torch.nn.Module):
|
|
83 |
else:
|
84 |
self.loss_fct = torch.nn.BCEWithLogitsLoss()
|
85 |
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
|
97 |
def _init_weights(self):
|
98 |
for module in self.classifier.modules():
|
|
|
83 |
else:
|
84 |
self.loss_fct = torch.nn.BCEWithLogitsLoss()
|
85 |
|
86 |
+
def forward(self, input_ids, attention_mask, labels=None):
|
87 |
+
outputs = self.deberta(
|
88 |
+
input_ids=input_ids,
|
89 |
+
attention_mask=attention_mask
|
90 |
+
)
|
91 |
+
logits = self.classifier(outputs.last_hidden_state[:, 0, :])
|
92 |
+
loss = None
|
93 |
+
if labels is not None:
|
94 |
+
loss = self.loss_fct(logits, labels)
|
95 |
+
return (loss, logits) if loss is not None else logits
|
96 |
|
97 |
def _init_weights(self):
|
98 |
for module in self.classifier.modules():
|