bvd757 commited on
Commit
fb76733
·
1 Parent(s): 31a1651

app fixes_v4

Browse files
Files changed (1) hide show
  1. app.py +10 -10
app.py CHANGED
@@ -83,16 +83,16 @@ class DebertPaperClassifierV5(torch.nn.Module):
83
  else:
84
  self.loss_fct = torch.nn.BCEWithLogitsLoss()
85
 
86
- def forward(self, input_ids, attention_mask, labels=None):
87
- outputs = self.deberta(
88
- input_ids=input_ids,
89
- attention_mask=attention_mask
90
- )
91
- logits = self.classifier(outputs.last_hidden_state[:, 0, :])
92
- loss = None
93
- if labels is not None:
94
- loss = self.loss_fct(logits, labels)
95
- return (loss, logits) if loss is not None else logits
96
 
97
  def _init_weights(self):
98
  for module in self.classifier.modules():
 
83
  else:
84
  self.loss_fct = torch.nn.BCEWithLogitsLoss()
85
 
86
+ def forward(self, input_ids, attention_mask, labels=None):
87
+ outputs = self.deberta(
88
+ input_ids=input_ids,
89
+ attention_mask=attention_mask
90
+ )
91
+ logits = self.classifier(outputs.last_hidden_state[:, 0, :])
92
+ loss = None
93
+ if labels is not None:
94
+ loss = self.loss_fct(logits, labels)
95
+ return (loss, logits) if loss is not None else logits
96
 
97
  def _init_weights(self):
98
  for module in self.classifier.modules():