bvd757 committed on
Commit
8292fd2
·
1 Parent(s): fb76733

app fixes_v5

Browse files
Files changed (1) hide show
  1. app.py +9 -28
app.py CHANGED
@@ -83,42 +83,23 @@ class DebertPaperClassifierV5(torch.nn.Module):
83
  else:
84
  self.loss_fct = torch.nn.BCEWithLogitsLoss()
85
 
86
- def forward(self, input_ids, attention_mask, labels=None):
87
- outputs = self.deberta(
88
- input_ids=input_ids,
89
- attention_mask=attention_mask
90
- )
91
- logits = self.classifier(outputs.last_hidden_state[:, 0, :])
92
- loss = None
93
- if labels is not None:
94
- loss = self.loss_fct(logits, labels)
95
- return (loss, logits) if loss is not None else logits
96
-
97
- def _init_weights(self):
98
- for module in self.classifier.modules():
99
- if isinstance(module, torch.nn.Linear):
100
- module.weight.data.normal_(mean=0.0, std=0.02)
101
- if module.bias is not None:
102
- module.bias.data.zero_()
103
-
104
- def forward(self,
105
- input_ids,
106
- attention_mask,
107
- labels=None,
108
- ):
109
  outputs = self.deberta(
110
  input_ids=input_ids,
111
  attention_mask=attention_mask
112
  )
113
-
114
- cls_output = outputs.last_hidden_state[:, 0, :]
115
- logits = self.classifier(cls_output)
116
-
117
  loss = None
118
  if labels is not None:
119
  loss = self.loss_fct(logits, labels)
120
-
121
  return (loss, logits) if loss is not None else logits
 
 
 
 
 
 
 
122
 
123
  @st.cache_resource
124
  def load_model(test=False):
 
83
  else:
84
  self.loss_fct = torch.nn.BCEWithLogitsLoss()
85
 
86
+ def forward(self, input_ids, attention_mask, labels=None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  outputs = self.deberta(
88
  input_ids=input_ids,
89
  attention_mask=attention_mask
90
  )
91
+ logits = self.classifier(outputs.last_hidden_state[:, 0, :])
 
 
 
92
  loss = None
93
  if labels is not None:
94
  loss = self.loss_fct(logits, labels)
 
95
  return (loss, logits) if loss is not None else logits
96
+
97
+ def _init_weights(self):
98
+ for module in self.classifier.modules():
99
+ if isinstance(module, torch.nn.Linear):
100
+ module.weight.data.normal_(mean=0.0, std=0.02)
101
+ if module.bias is not None:
102
+ module.bias.data.zero_()
103
 
104
  @st.cache_resource
105
  def load_model(test=False):