rbao2018 committed
Commit 2c3e6be · 1 Parent(s): 7440363
Files changed (1)
  1. modeling_bailing_moe.py +4 -13
modeling_bailing_moe.py CHANGED
@@ -1438,21 +1438,12 @@ class BailingMoeForCausalLM(BailingMoePreTrainedModel):
 
     def compute_logit(self, hidden_states):
         if self.norm_head:
-            if self.training:
-                norm_weight = (
-                    self.lm_head.weight / (torch.norm(self.lm_head.weight, p=2, dim=0, keepdim=True) + 1e-7).detach()
-                )
-                logits = F.linear(hidden_states, norm_weight, None)
-            else:
-                self.lm_head.weight.data = (
-                    self.lm_head.weight.data.float()
-                    / (torch.norm(self.lm_head.weight.data.float(), p=2, dim=0, keepdim=True) + 1e-7)
-                ).to(hidden_states.dtype)
-                logits = F.linear(hidden_states, self.lm_head.weight.data, None)
-                self.norm_head = False
+            weight_float = self.lm_head.weight.float()
+            norm = torch.norm(weight_float, p=2, dim=0, keepdim=True).clamp(min=1e-7)
+            norm_weight = (weight_float / norm).to(hidden_states.dtype)
+            logits = F.linear(hidden_states, norm_weight, None)
         else:
             logits = self.lm_head(hidden_states)
-        return logits
 
     @add_start_docstrings_to_model_forward(BAILINGMOE_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
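
For reference, below is a minimal, self-contained sketch of the simplified norm-head logit computation this commit introduces. The `NormHeadLM` wrapper, its constructor arguments, and the usage snippet are illustrative assumptions, not the actual BailingMoe class; only the body of the `if self.norm_head:` branch mirrors the committed code. The new version replaces the old training/eval split (which added 1e-7 to the norm and, in eval, mutated `lm_head.weight.data` in place) with a single branch-free path that floors the norm via `clamp(min=1e-7)`. Note the hunk above also removes the `return logits` line; the sketch keeps it so the method returns a value.

import torch
import torch.nn as nn
import torch.nn.functional as F

class NormHeadLM(nn.Module):
    """Illustrative wrapper (assumed names/shapes), not the BailingMoe class."""

    def __init__(self, hidden_size: int, vocab_size: int, norm_head: bool = True):
        super().__init__()
        # lm_head.weight has shape [vocab_size, hidden_size].
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        self.norm_head = norm_head

    def compute_logit(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.norm_head:
            # Work in float32 for numerical stability, as the commit does.
            weight_float = self.lm_head.weight.float()
            # dim=0 spans the vocab axis of the [vocab_size, hidden_size]
            # weight, so each hidden-dimension column is scaled to unit L2
            # norm; clamp floors the norm at 1e-7 instead of adding an
            # epsilon as the old code did.
            norm = torch.norm(weight_float, p=2, dim=0, keepdim=True).clamp(min=1e-7)
            norm_weight = (weight_float / norm).to(hidden_states.dtype)
            logits = F.linear(hidden_states, norm_weight, None)
        else:
            logits = self.lm_head(hidden_states)
        return logits  # kept in this sketch; the committed method drops it

# Hypothetical usage:
head = NormHeadLM(hidden_size=16, vocab_size=32)
h = torch.randn(2, 5, 16)
logits = head.compute_logit(h)  # shape [2, 5, 32]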