Commit: update
Browse files
modeling_bailing_moe.py  (+4 −13)
modeling_bailing_moe.py — CHANGED
@@ -1438,21 +1438,12 @@ class BailingMoeForCausalLM(BailingMoePreTrainedModel):

Before (removed lines marked with `-`):

1438
1439      def compute_logit(self, hidden_states):
1440          if self.norm_head:
1441 -            [removed line — content not captured in this rendering]
1442 -            [removed line — content not captured in this rendering]
1443 -            [removed line — content not captured in this rendering]
1444 -            [removed line — content not captured in this rendering]
1445 -            logits = F.linear(hidden_states, norm_weight, None)
1446 -        else:
1447 -            self.lm_head.weight.data = (
1448 -                self.lm_head.weight.data.float()
1449 -                / (torch.norm(self.lm_head.weight.data.float(), p=2, dim=0, keepdim=True) + 1e-7)
1450 -            ).to(hidden_states.dtype)
1451 -            logits = F.linear(hidden_states, self.lm_head.weight.data, None)
1452 -            self.norm_head = False
1453          else:
1454              logits = self.lm_head(hidden_states)
1455 -        return logits
1456
1457      @add_start_docstrings_to_model_forward(BAILINGMOE_INPUTS_DOCSTRING)
1458      @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
After (added lines marked with `+`):

1438
1439      def compute_logit(self, hidden_states):
1440          if self.norm_head:
1441 +            weight_float = self.lm_head.weight.float()
1442 +            norm = torch.norm(weight_float, p=2, dim=0, keepdim=True).clamp(min=1e-7)
1443 +            norm_weight = (weight_float / norm).to(hidden_states.dtype)
1444 +            logits = F.linear(hidden_states, norm_weight, None)
1445          else:
1446              logits = self.lm_head(hidden_states)
1447
1448      @add_start_docstrings_to_model_forward(BAILINGMOE_INPUTS_DOCSTRING)
1449      @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)

[NOTE(review): the new side of this capture shows no `return logits` statement after line 1446, and the `+4 −13` count is consistent with it having been removed — but this may also be an artifact of the page extraction. Verify against the actual commit that `compute_logit` still returns `logits`.]