rbao2018 committed on
Commit
7440363
·
1 Parent(s): 5472194
Files changed (1) hide show
  1. modeling_bailing_moe.py +188 -1
modeling_bailing_moe.py CHANGED
@@ -72,6 +72,81 @@ logger = logging.get_logger(__name__)
72
 
73
  _CONFIG_FOR_DOC = "BailingMoeConfig"
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
  def _get_unpad_data(attention_mask):
77
  seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
@@ -421,7 +496,7 @@ class BailingMoeSparseMoeBlock(nn.Module):
421
  y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(bsz, seq_len, h)
422
  if self.config.num_shared_experts is not None:
423
  y = y + self.shared_experts(identity)
424
- return y, (router_logits.view(bsz, seq_len, -1), topk_idx.view(bsz, seq_len, -1))
425
 
426
  @torch.no_grad()
427
  def moe_infer(self, x, topk_ids, topk_weight):
@@ -1452,6 +1527,14 @@ class BailingMoeForCausalLM(BailingMoePreTrainedModel):
1452
  loss = None
1453
  aux_loss = None
1454
 
 
 
 
 
 
 
 
 
1455
  if labels is not None:
1456
  # Shift so that tokens < n predict n
1457
  shift_logits = logits[..., :-1, :].contiguous()
@@ -1547,3 +1630,107 @@ class BailingMoeForCausalLM(BailingMoePreTrainedModel):
1547
  tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1548
  )
1549
  return reordered_past
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
  _CONFIG_FOR_DOC = "BailingMoeConfig"
74
 
75
# Adapted from transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func
def load_balancing_loss_func(
    gate_logits_and_topk: Union[Tuple[Tuple[torch.Tensor, torch.Tensor], ...], torch.Tensor, None],
    num_experts: Optional[int] = None,
    top_k=2,
    attention_mask: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, int]:
    r"""
    Computes the auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.

    See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. The loss penalizes cases where the
    routing between experts is too unbalanced.

    Unlike the Mixtral helper this function was adapted from, the selected experts are not recomputed with a
    top-k over the softmax; the indices chosen by each layer's gate are passed in alongside the logits.

    Args:
        gate_logits_and_topk:
            Tuple with one `(router_logits, topk_idx)` pair per MoE layer. `router_logits` has shape
            [batch_size X sequence_length, num_experts] and `topk_idx` has shape
            [batch_size X sequence_length, top_k]. Anything that is not a non-empty tuple yields a loss of 0.
        num_experts:
            Number of experts. When `None`, it is inferred from the last dimension of the router logits.
        top_k:
            The number of experts to route per-token, can be also interpreted as the `top-k` routing
            parameter. Only used to build the padding mask when `attention_mask` is given.
        attention_mask (`torch.Tensor`, *optional*):
            The attention_mask used in forward function
            shape [batch_size X sequence_length] if not None.

    Returns:
        The auxiliary loss: a scalar tensor, or the integer 0 when no routing information is available.
    """
    # None, a bare tensor, or an empty tuple: there is nothing to balance.
    if not isinstance(gate_logits_and_topk, tuple) or len(gate_logits_and_topk) == 0:
        return 0

    # Layers may be placed on different devices; gather everything on the first layer's device
    # before concatenating (torch.cat requires a single device).
    compute_device = gate_logits_and_topk[0][0].device
    # concatenated_gate_logits.shape = [num_layers * batch_size * seq_len, num_experts]
    concatenated_gate_logits = torch.cat(
        [layer_gate[0].to(compute_device) for layer_gate in gate_logits_and_topk], dim=0
    )
    # selected_experts.shape = [num_layers * batch_size * seq_len, top_k]
    selected_experts = torch.cat(
        [layer_gate[1].to(compute_device) for layer_gate in gate_logits_and_topk], dim=0
    )

    if num_experts is None:
        # one_hot would otherwise size itself from the largest selected index, which need not
        # match the width of the routing probabilities.
        num_experts = concatenated_gate_logits.shape[-1]

    routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)

    # expert_mask.shape = [num_layers * batch_size * seq_len, top_k, num_experts]
    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)

    if attention_mask is None:
        # Compute the percentage of tokens routed to each expert
        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
        # Compute the average probability of routing to these experts
        router_prob_per_expert = torch.mean(routing_weights, dim=0)
    else:
        batch_size, sequence_length = attention_mask.shape
        num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)

        # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
        expert_attention_mask = (
            attention_mask[None, :, :, None, None]
            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
            .reshape(-1, top_k, num_experts)
            .to(compute_device)
        )

        # Compute the percentage of tokens routed to each expert
        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
            expert_attention_mask, dim=0
        )

        # Compute the mask that masks all padding tokens as 0 with the same shape of routing_weights
        router_per_expert_attention_mask = (
            attention_mask[None, :, :, None]
            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
            .reshape(-1, num_experts)
            .to(compute_device)
        )

        # Compute the average probability of routing to these experts
        router_prob_per_expert = torch.sum(
            routing_weights * router_per_expert_attention_mask, dim=0
        ) / torch.sum(router_per_expert_attention_mask, dim=0)

    # NOTE(review): the Mixtral reference scales this sum by `num_experts`; this version does not.
    # Left unchanged so existing aux-loss coefficients keep their meaning - confirm this is intended.
    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
    return overall_loss
150
 
151
  def _get_unpad_data(attention_mask):
152
  seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
 
496
  y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(bsz, seq_len, h)
497
  if self.config.num_shared_experts is not None:
498
  y = y + self.shared_experts(identity)
499
+ return y, (router_logits, topk_idx)
500
 
501
  @torch.no_grad()
502
  def moe_infer(self, x, topk_ids, topk_weight):
 
1527
  loss = None
1528
  aux_loss = None
1529
 
1530
+ if output_router_logits:
1531
+ aux_loss = load_balancing_loss_func(
1532
+ outputs.router_logits if return_dict else outputs[-1],
1533
+ self.num_experts,
1534
+ self.num_experts_per_tok,
1535
+ attention_mask,
1536
+ )
1537
+
1538
  if labels is not None:
1539
  # Shift so that tokens < n predict n
1540
  shift_logits = logits[..., :-1, :].contiguous()
 
1630
  tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1631
  )
1632
  return reordered_past
1633
+
1634
+
1635
# Adapted from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->BailingMoe
class BailingMoeForTokenClassification(BailingMoePreTrainedModel):
    """BailingMoe backbone followed by dropout and a linear layer that scores every token position.

    When router logits are requested, the load-balancing auxiliary loss is computed and returned
    alongside the classification loss.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # Kept on the module because `forward` passes them to `load_balancing_loss_func`.
        self.num_experts = config.num_experts
        self.num_experts_per_tok = config.num_experts_per_tok

        self.model = BailingMoeModel(config)
        # Dropout precedence: `classifier_dropout`, then `hidden_dropout`, then a 0.1 default.
        if getattr(config, "classifier_dropout", None) is not None:
            classifier_dropout = config.classifier_dropout
        elif getattr(config, "hidden_dropout", None) is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        self.score = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # Delegate to the backbone so this head does not depend on the name of its embedding attribute.
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)

    @add_start_docstrings_to_model_forward(BAILINGMOE_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ...,
            config.num_labels - 1]`.
        output_router_logits (`bool`, *optional*):
            Whether to collect the router outputs of the MoE layers and compute the load-balancing
            auxiliary loss. Defaults to `config.output_router_logits`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_router_logits=output_router_logits,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.score(sequence_output)

        loss = None
        aux_loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.config)

        if output_router_logits:
            # The router outputs are the last element of the backbone's tuple output.
            # aux_loss is reported separately; it is not added to `loss` here.
            aux_loss = load_balancing_loss_func(
                outputs.router_logits if return_dict else outputs[-1],
                self.num_experts,
                self.num_experts_per_tok,
                attention_mask,
            )

        if not return_dict:
            # Tuple layout: [loss], [aux_loss], logits, then the backbone outputs after the hidden states.
            output = (logits,) + outputs[1:]
            if output_router_logits:
                output = (aux_loss,) + output
            return (loss,) + output if loss is not None else output

        return MoeCausalLMOutputWithPast(
            loss=loss,
            aux_loss=aux_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            router_logits=outputs.router_logits,
        )