Spaces:
Running
Running
app fixes_v5
Browse files
app.py
CHANGED
@@ -83,42 +83,23 @@ class DebertPaperClassifierV5(torch.nn.Module):
|
|
83 |
else:
|
84 |
self.loss_fct = torch.nn.BCEWithLogitsLoss()
|
85 |
|
86 |
-
|
87 |
-
outputs = self.deberta(
|
88 |
-
input_ids=input_ids,
|
89 |
-
attention_mask=attention_mask
|
90 |
-
)
|
91 |
-
logits = self.classifier(outputs.last_hidden_state[:, 0, :])
|
92 |
-
loss = None
|
93 |
-
if labels is not None:
|
94 |
-
loss = self.loss_fct(logits, labels)
|
95 |
-
return (loss, logits) if loss is not None else logits
|
96 |
-
|
97 |
-
def _init_weights(self):
|
98 |
-
for module in self.classifier.modules():
|
99 |
-
if isinstance(module, torch.nn.Linear):
|
100 |
-
module.weight.data.normal_(mean=0.0, std=0.02)
|
101 |
-
if module.bias is not None:
|
102 |
-
module.bias.data.zero_()
|
103 |
-
|
104 |
-
def forward(self,
|
105 |
-
input_ids,
|
106 |
-
attention_mask,
|
107 |
-
labels=None,
|
108 |
-
):
|
109 |
outputs = self.deberta(
|
110 |
input_ids=input_ids,
|
111 |
attention_mask=attention_mask
|
112 |
)
|
113 |
-
|
114 |
-
cls_output = outputs.last_hidden_state[:, 0, :]
|
115 |
-
logits = self.classifier(cls_output)
|
116 |
-
|
117 |
loss = None
|
118 |
if labels is not None:
|
119 |
loss = self.loss_fct(logits, labels)
|
120 |
-
|
121 |
return (loss, logits) if loss is not None else logits
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
|
123 |
@st.cache_resource
|
124 |
def load_model(test=False):
|
|
|
83 |
else:
|
84 |
self.loss_fct = torch.nn.BCEWithLogitsLoss()
|
85 |
|
86 |
+
def forward(self, input_ids, attention_mask, labels=None):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
outputs = self.deberta(
|
88 |
input_ids=input_ids,
|
89 |
attention_mask=attention_mask
|
90 |
)
|
91 |
+
logits = self.classifier(outputs.last_hidden_state[:, 0, :])
|
|
|
|
|
|
|
92 |
loss = None
|
93 |
if labels is not None:
|
94 |
loss = self.loss_fct(logits, labels)
|
|
|
95 |
return (loss, logits) if loss is not None else logits
|
96 |
+
|
97 |
+
def _init_weights(self):
|
98 |
+
for module in self.classifier.modules():
|
99 |
+
if isinstance(module, torch.nn.Linear):
|
100 |
+
module.weight.data.normal_(mean=0.0, std=0.02)
|
101 |
+
if module.bias is not None:
|
102 |
+
module.bias.data.zero_()
|
103 |
|
104 |
@st.cache_resource
|
105 |
def load_model(test=False):
|