m1ngcheng rbao2018 committed on
Commit 2ea7627 · verified · 1 Parent(s): e2cad8a

add eos token and the end of assistant content (#8)


- add eos token and the end of assistant content (8dcffca4eb0f465b05b609d5684eb20d935cfb2d)
- update config.json (5472194a39c8f621f98957ca539397229f02956f)
- update (7440363c8ebb71720084ef78e020d44bf803b61f)
- update (2c3e6bec705f0d7928f2bed64c5f14a4382ac988)


Co-authored-by: Rong Bao <[email protected]>

Files changed (3)
  1. config.json +2 -1
  2. modeling_bailing_moe.py +192 -14
  3. tokenizer_config.json +1 -1
config.json CHANGED
@@ -6,7 +6,8 @@
   "auto_map": {
     "AutoConfig": "configuration_bailing_moe.BailingMoeConfig",
     "AutoModel": "modeling_bailing_moe.BailingMoeModel",
-    "AutoModelForCausalLM": "modeling_bailing_moe.BailingMoeForCausalLM"
+    "AutoModelForCausalLM": "modeling_bailing_moe.BailingMoeForCausalLM",
+    "AutoModelForTokenClassification": "modeling_bailing_moe.BailingMoeForTokenClassification"
   },
   "eos_token_id": 126081,
   "pad_token_id": 126081,
modeling_bailing_moe.py CHANGED
@@ -72,6 +72,81 @@ logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "BailingMoeConfig"
 
+# Copied from transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func
+def load_balancing_loss_func(
+    gate_logits_and_topk: Union[torch.Tensor, Tuple[torch.Tensor], None],
+    num_experts: Optional[int] = None,
+    top_k=2,
+    attention_mask: Optional[torch.Tensor] = None,
+) -> Union[torch.Tensor, int]:
+    r"""
+    Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
+
+    See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
+    function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
+    experts is too unbalanced.
+
+    Args:
+        gate_logits:
+            Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
+            shape [batch_size X sequence_length, num_experts].
+        num_experts:
+            Number of experts
+        top_k:
+            The number of experts to route per-token, can be also interpreted as the `top-k` routing
+            parameter.
+        attention_mask (`torch.Tensor`, *optional*):
+            The attention_mask used in forward function
+            shape [batch_size X sequence_length] if not None.
+
+    Returns:
+        The auxiliary loss.
+    """
+    if gate_logits_and_topk is None or not isinstance(gate_logits_and_topk, tuple):
+        return 0
+
+    if isinstance(gate_logits_and_topk, tuple):
+        # concatenated_gate_logits.shape = [batch_size * num_layers * seq_len, num_experts]
+        concatenated_gate_logits = torch.cat([layer_gate[0] for layer_gate in gate_logits_and_topk], dim=0)
+        # selected_experts.shape = [batch_size * num_layers * seq_len, top_k_experts]
+        selected_experts = torch.cat([layer_gate[1] for layer_gate in gate_logits_and_topk], dim=0)
+        selected_experts.to(concatenated_gate_logits.device)
+
+    routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
+
+    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
+
+    if attention_mask is None:
+        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
+        # Compute the average probability of routing to these experts
+        router_prob_per_expert = torch.mean(routing_weights, dim=0)
+    else:
+        batch_size, sequence_length = attention_mask.shape
+        num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
+        # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
+        expert_attention_mask = (
+            attention_mask[None, :, :, None, None]
+            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
+            .reshape(-1, top_k, num_experts)
+        )
+
+        # Compute the percentage of tokens routed to each experts
+        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
+            expert_attention_mask, dim=0
+        )
+
+        # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
+        router_per_expert_attention_mask = (
+            attention_mask[None, :, :, None]
+            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
+            .reshape(-1, num_experts)
+        )
+
+        # Compute the average probability of routing to these experts
+        router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
+            router_per_expert_attention_mask, dim=0
+        )
+
+    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
+    return overall_loss
 
 def _get_unpad_data(attention_mask):
     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
@@ -421,7 +496,7 @@ class BailingMoeSparseMoeBlock(nn.Module):
         y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(bsz, seq_len, h)
         if self.config.num_shared_experts is not None:
             y = y + self.shared_experts(identity)
-        return y, (router_logits.view(bsz, seq_len, -1), topk_idx.view(bsz, seq_len, -1))
+        return y, (router_logits, topk_idx)
 
     @torch.no_grad()
     def moe_infer(self, x, topk_ids, topk_weight):
@@ -1363,21 +1438,12 @@ class BailingMoeForCausalLM(BailingMoePreTrainedModel):
 
     def compute_logit(self, hidden_states):
         if self.norm_head:
-            if self.training:
-                norm_weight = (
-                    self.lm_head.weight / (torch.norm(self.lm_head.weight, p=2, dim=0, keepdim=True) + 1e-7).detach()
-                )
-                logits = F.linear(hidden_states, norm_weight, None)
-            else:
-                self.lm_head.weight.data = (
-                    self.lm_head.weight.data.float()
-                    / (torch.norm(self.lm_head.weight.data.float(), p=2, dim=0, keepdim=True) + 1e-7)
-                ).to(hidden_states.dtype)
-                logits = F.linear(hidden_states, self.lm_head.weight.data, None)
-                self.norm_head = False
+            weight_float = self.lm_head.weight.float()
+            norm = torch.norm(weight_float, p=2, dim=0, keepdim=True).clamp(min=1e-7)
+            norm_weight = (weight_float / norm).to(hidden_states.dtype)
+            logits = F.linear(hidden_states, norm_weight, None)
         else:
             logits = self.lm_head(hidden_states)
-        return logits
 
     @add_start_docstrings_to_model_forward(BAILINGMOE_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
@@ -1452,6 +1518,14 @@ class BailingMoeForCausalLM(BailingMoePreTrainedModel):
         loss = None
         aux_loss = None
 
+        if output_router_logits:
+            aux_loss = load_balancing_loss_func(
+                outputs.router_logits if return_dict else outputs[-1],
+                self.num_experts,
+                self.num_experts_per_tok,
+                attention_mask,
+            )
+
         if labels is not None:
             # Shift so that tokens < n predict n
             shift_logits = logits[..., :-1, :].contiguous()
@@ -1547,3 +1621,107 @@ class BailingMoeForCausalLM(BailingMoePreTrainedModel):
             tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
         )
         return reordered_past
+
+
+# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Qwen2Moe, LLAMA->QWEN2MOE
+class BailingMoeForTokenClassification(BailingMoePreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+
+        self.num_experts = config.num_experts
+        self.num_experts_per_tok = config.num_experts_per_tok
+
+        self.model = BailingMoeModel(config)
+        if getattr(config, "classifier_dropout", None) is not None:
+            classifier_dropout = config.classifier_dropout
+        elif getattr(config, "hidden_dropout", None) is not None:
+            classifier_dropout = config.hidden_dropout
+        else:
+            classifier_dropout = 0.1
+        self.dropout = nn.Dropout(classifier_dropout)
+        self.score = nn.Linear(config.hidden_size, config.num_labels)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    @add_start_docstrings_to_model_forward(BAILINGMOE_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_router_logits: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        output_router_logits = (
+            output_router_logits if output_router_logits is not None else self.config.output_router_logits
+        )
+
+        outputs = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            output_router_logits=output_router_logits,
+            return_dict=return_dict,
+        )
+        sequence_output = outputs[0]
+        sequence_output = self.dropout(sequence_output)
+        logits = self.score(sequence_output)
+
+        loss = None
+        aux_loss = None
+        if labels is not None:
+            loss = self.loss_function(logits, labels, self.config)
+
+        if output_router_logits:
+            aux_loss = load_balancing_loss_func(
+                outputs.router_logits if return_dict else outputs[-1],
+                self.num_experts,
+                self.num_experts_per_tok,
+                attention_mask,
+            )
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            if output_router_logits:
+                output = (aux_loss,) + output
+            return (loss,) + output if loss is not None else output
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return MoeCausalLMOutputWithPast(
+            loss=loss,
+            aux_loss=aux_loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            router_logits=outputs.router_logits,
+        )
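For reference, the shapes that the new load_balancing_loss_func consumes: each MoE block now returns a (router_logits, topk_idx) pair per layer, and the loss is the Switch-Transformer product of token dispatch fractions and mean router probabilities. A self-contained sketch of the unmasked case with dummy tensors (sizes are illustrative, not taken from the model's config):

```python
import torch
import torch.nn.functional as F

num_layers, tokens, num_experts, top_k = 2, 5, 8, 2  # illustrative sizes only

# One (router_logits, topk_idx) pair per layer, mirroring what each
# BailingMoeSparseMoeBlock returns after this change.
gate_logits_and_topk = tuple(
    (torch.randn(tokens, num_experts), torch.randint(0, num_experts, (tokens, top_k)))
    for _ in range(num_layers)
)

gate_logits = torch.cat([logits for logits, _ in gate_logits_and_topk], dim=0)
selected = torch.cat([idx for _, idx in gate_logits_and_topk], dim=0)

routing_weights = F.softmax(gate_logits, dim=-1)        # [layers * tokens, num_experts]
expert_mask = F.one_hot(selected, num_experts).float()  # [layers * tokens, top_k, num_experts]

# Fraction of tokens dispatched to each expert times the mean router
# probability per expert, summed over experts (eqs. 4-6 of Switch Transformer).
tokens_per_expert = expert_mask.mean(dim=0)              # [top_k, num_experts]
router_prob_per_expert = routing_weights.mean(dim=0)     # [num_experts]
aux_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
print(aux_loss)
```

In the model itself this value is only computed when output_router_logits is enabled, as shown in the forward-pass hunk above.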
tokenizer_config.json CHANGED
@@ -10,7 +10,7 @@
     "<|number_end|>"
   ],
   "bos_token": "<|startoftext|>",
-  "chat_template": "{% for message in messages %}{% set role = message['role'] | lower %}{% if role == 'user' %}{% set role = 'HUMAN' %}{% endif %}{% set role = role | upper %}{{ '<role>' + role + '</role>' + message['content'] }}{% endfor %}{% if add_generation_prompt %}{{ '<role>ASSISTANT</role>' }}{% endif %}",
+  "chat_template": "{% for message in messages %}{% set role = message['role'] | lower %}{% if role == 'user' %}{% set role = 'HUMAN' %}{% endif %}{% set role = role | upper %}{{ '<role>' + role + '</role>' + message['content'] }}{% if role == 'ASSISTANT' %}{{ eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<role>ASSISTANT</role>' }}{% endif %}",
  "clean_up_tokenization_spaces": false,
  "cls_token": "[CLS]",
  "eos_token": "<|endoftext|>",