Srinivasan Iyer (sviyer) committed
Commit aebdc48 · unverified · 1 Parent(s): 936d943

Fix init and repro (#48)


* Fix init and repro

* comment + black

---------

Co-authored-by: Srini Iyer <[email protected]>

bytelatent/base_transformer.py CHANGED
@@ -445,7 +445,7 @@ class Attention(nn.Module):
         return output
 
     def reset_parameters(self, init_std=None, factor=1.0):
-        init_std = init_std or (self.dim ** (-0.5))
+        init_std = init_std or (self.dim ** (-0.5)) / factor
 
         for w in [self.wq, self.wk, self.wv]:
             nn.init.trunc_normal_(
@@ -459,7 +459,7 @@ class Attention(nn.Module):
         nn.init.trunc_normal_(
             self.wo.weight,
             mean=0.0,
-            std=init_std / factor,
+            std=init_std,
             a=-3 * init_std,
             b=3 * init_std,
         )
@@ -509,18 +509,16 @@ class FeedForward(nn.Module):
         return output
 
     def reset_parameters(self, init_std=None, factor=1.0):
-        in_init_std = init_std or (self.dim ** (-0.5))
-        out_init_std = init_std or (self.hidden_dim ** (-0.5))
-        in_init_std = in_init_std
-        out_init_std = out_init_std / factor
-        for w in [self.w1, self.w3]:
-            nn.init.trunc_normal_(
-                w.weight,
-                mean=0.0,
-                std=in_init_std,
-                a=-3 * in_init_std,
-                b=3 * in_init_std,
-            )
+        in_init_std = init_std or (self.dim ** (-0.5)) / factor
+        out_init_std = init_std or (self.hidden_dim ** (-0.5)) / factor
+
+        nn.init.trunc_normal_(
+            self.w1.weight,
+            mean=0.0,
+            std=in_init_std,
+            a=-3 * in_init_std,
+            b=3 * in_init_std,
+        )
         nn.init.trunc_normal_(
             self.w2.weight,
             mean=0.0,
@@ -528,6 +526,13 @@ class FeedForward(nn.Module):
             a=-3 * out_init_std,
             b=3 * out_init_std,
         )
+        nn.init.trunc_normal_(
+            self.w3.weight,
+            mean=0.0,
+            std=in_init_std,
+            a=-3 * in_init_std,
+            b=3 * in_init_std,
+        )
 
 
 class TransformerBlock(nn.Module):
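
The same initialization pattern recurs throughout this commit: a truncated normal draw whose std defaults to dim ** -0.5, optionally shrunk by a depth-dependent factor, and clipped at ±3 std. A minimal sketch of that pattern as a standalone helper (the name trunc_init_ is illustrative, not part of the repo):

import torch.nn as nn

def trunc_init_(weight, dim, factor=1.0, std=None):
    # Default std is dim ** -0.5 divided by the depth factor; the draw is
    # truncated at +/- 3 standard deviations, as in the diff above.
    std = std or (dim ** (-0.5)) / factor
    nn.init.trunc_normal_(weight, mean=0.0, std=std, a=-3 * std, b=3 * std)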
bytelatent/distributed.py CHANGED
@@ -463,13 +463,21 @@ def parallelize_model(
         raise ValueError(f"Invalid fsdp_type: {distributed_args.fsdp_type}")
 
     if distributed_args.selective_activation_checkpointing:
-        model = checkpoint_wrapper(
-            model,
-            context_fn=partial(
-                create_selective_checkpoint_contexts,
-                get_default_policy(no_recompute_ops),
-            ),
-        )
+        # only works for blt models
+        # assuming that entropy models will not use checkpointing
+        for module in [
+            model.global_transformer,
+            model.local_encoder,
+            model.local_decoder,
+        ]:
+            for i in range(len(module.layers)):
+                module.layers[i] = checkpoint_wrapper(
+                    module.layers[i],
+                    context_fn=partial(
+                        create_selective_checkpoint_contexts,
+                        get_default_policy(no_recompute_ops),
+                    ),
+                )
 
     if distributed_args.compile:
         torch._dynamo.config.cache_size_limit = (
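
This change wraps each transformer block individually rather than wrapping the whole model. A minimal sketch of per-layer wrapping, assuming a module that exposes a .layers ModuleList and omitting the selective-checkpoint context_fn used in the diff:

import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)

def wrap_layers(stack: nn.Module) -> None:
    # Replace every block in-place with a checkpoint-wrapped version, so
    # activations are recomputed per layer during backward.
    for i in range(len(stack.layers)):
        stack.layers[i] = checkpoint_wrapper(stack.layers[i])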
bytelatent/model/blt.py CHANGED
@@ -825,12 +825,6 @@ class ByteLatentTransformer(nn.Module):
             local_encoder_dim=self.local_encoder.dim,
             encoder_hash_byte_group_size=None,
         )
-        self.tok_embeddings = torch.nn.Embedding(args.vocab_size, args.dim)
-
-        # Transformer layers
-        self.layers = nn.ModuleList(
-            [TransformerBlock(args) for _ in range(args.n_layers)]
-        )
 
         # Encoder ngram embedding tables
         self.encoder_ngram_embedding = None
@@ -848,9 +842,6 @@ class ByteLatentTransformer(nn.Module):
 
         # Output layer
         assert args.vocab_size > 0, "vocab_size must be greater than 0"
-        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
-        if args.weight_tying:
-            self.output.weight = self.tok_embeddings.weight
 
         # Patcher module
         if args.patch_in_forward:
@@ -954,11 +945,10 @@ class ByteLatentTransformer(nn.Module):
             local_encoder_embeds = local_encoder_embeds + ngram_embeds
 
         # Local encoder
-        h_cross = None
         (h_encoder, h_cross), cache_encoder = self.local_encoder(
             tokens=local_encoder_tokens,
             embeds=local_encoder_embeds,
-            patch_embeds=h_cross if self.cross_attn_encoder else None,
+            patch_embeds=None,
             cross_mask=cross_attn_mask_enc,
             num_patches=patch_lengths.shape[1],
             patch_ids=patch_ids,
@@ -1033,47 +1023,17 @@ class ByteLatentTransformer(nn.Module):
         )
         return output
 
-    def reset_parameters(self, init_std=None):
-        # Either use fixed base std or sqrt model dim
-        init_std = init_std or (self.dim ** (-0.5))
-        nn.init.trunc_normal_(
-            self.tok_embeddings.weight,
-            mean=0.0,
-            std=init_std,
-            a=-3 * init_std,
-            b=3 * init_std,
-        )
-        if not self.weight_tying:
-            nn.init.trunc_normal_(
-                self.output.weight,
-                mean=0.0,
-                std=init_std,
-                a=-3 * init_std,
-                b=3 * init_std,
-            )
-
     def init_weights(self):
-        self.reset_parameters()
-        self.init_base_std = self.init_base_std or (self.dim ** (-0.5))
-        for depth, layer in enumerate(self.layers):
-            factor = {
-                InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
-                InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5,
-                InitStdFactor.DIM_RATIO: self.dim / 4096,
-                InitStdFactor.DISABLED: 1.0,
-            }[self.init_std_factor]
-
-            layer.init_weights(self.init_base_std, factor)
-
-        self.local_decoder.init_weights(self.init_base_std)
-        self.global_transformer.init_weights(self.init_base_std)
-        self.local_encoder.init_weights(self.init_base_std)
+        self.local_encoder.init_weights()
+        self.global_transformer.init_weights()
+        self.local_decoder.init_weights()
 
+        emb_std = self.local_encoder.dim ** (-0.5)
         for emb in self.encoder_hash_tok_embedding:
             nn.init.trunc_normal_(
                 emb.weight,
                 mean=0.0,
-                std=self.init_base_std,
-                a=-3 * self.init_base_std,
-                b=3 * self.init_base_std,
+                std=emb_std,
+                a=-3 * emb_std,
+                b=3 * emb_std,
             )
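
After this change, ByteLatentTransformer.init_weights delegates to the sub-models and only initializes the shared hash-token embedding tables itself, using the local encoder width for the std. A sketch of that flow as a free function over a hypothetical model object with the same attributes:

import torch.nn as nn

def init_blt_weights(model) -> None:
    # Each sub-model owns its own initialization.
    for sub in (model.local_encoder, model.global_transformer, model.local_decoder):
        sub.init_weights()

    # Hash-token embeddings are initialized from the local encoder width.
    emb_std = model.local_encoder.dim ** (-0.5)
    for emb in model.encoder_hash_tok_embedding:
        nn.init.trunc_normal_(
            emb.weight, mean=0.0, std=emb_std, a=-3 * emb_std, b=3 * emb_std
        )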
bytelatent/model/latent_transformer.py CHANGED
@@ -78,10 +78,10 @@ class CrossAttention(nn.Module):
         # B S D
         bsz, seq_len, _ = x.shape
         _, slen_kv, _ = kv.shape
-        x = self.cross_attn_norm_q(x)
+        x_norm = self.cross_attn_norm_q(x)
         kv = self.cross_attn_norm_kv(kv)
 
-        xq = self.wq(x)
+        xq = self.wq(x_norm)
         xk = self.wk(kv)
         xv = self.wv(kv)
 
@@ -104,7 +104,7 @@ class CrossAttention(nn.Module):
         return x + output
 
     def init_weights(self, base_std: float, factor: float = 1.0):
-        std = base_std * factor
+        std = base_std or (self.dim ** (-0.5)) / factor
 
         nn.init.trunc_normal_(
             self.wq.weight,
@@ -130,13 +130,12 @@ class CrossAttention(nn.Module):
             b=3 * std,
         )
 
-        output_std = std / (2**0.5)
         nn.init.trunc_normal_(
             self.wo.weight,
             mean=0.0,
-            std=output_std,
-            a=-3 * output_std,
-            b=3 * output_std,
+            std=std,
+            a=-3 * std,
+            b=3 * std,
        )
        self.cross_attn_norm_q.reset_parameters()
        self.cross_attn_norm_kv.reset_parameters()
@@ -147,6 +146,7 @@ class GlobalTransformer(BaseTransformer):
         super().__init__(args)
         self.dropout = args.dropout
         self.eos_id = args.eos_id
+        self.dim_token_emb = args.dim_token_emb
 
         self.token_embedding_projection = None
         if args.dim_token_emb is not None and args.dim_token_emb != self.dim:
@@ -192,13 +192,14 @@ class GlobalTransformer(BaseTransformer):
         h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=self.attn_impl)
         return h, cache
 
-    def init_weights(self, init_base_std: float):
+    def init_weights(self):
         super().init_weights()
+        std = self.dim_token_emb ** (-0.5)
         if self.token_embedding_projection is not None:
             nn.init.trunc_normal_(
                 self.token_embedding_projection.weight,
                 mean=0.0,
-                std=init_base_std,
-                a=-3 * init_base_std,
-                b=3 * init_base_std,
+                std=std,
+                a=-3 * std,
+                b=3 * std,
             )
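
The CrossAttention fix normalizes a copy of the query stream while leaving the residual connection on the original x. A self-contained sketch of that pre-norm residual pattern, simplified to a single head and using LayerNorm in place of the repo's norm layers:

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyCrossAttention(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.norm_q = nn.LayerNorm(dim)
        self.norm_kv = nn.LayerNorm(dim)
        self.wq = nn.Linear(dim, dim, bias=False)
        self.wk = nn.Linear(dim, dim, bias=False)
        self.wv = nn.Linear(dim, dim, bias=False)
        self.wo = nn.Linear(dim, dim, bias=False)

    def forward(self, x: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
        x_norm = self.norm_q(x)   # normalized copy feeds the query projection
        kv_norm = self.norm_kv(kv)
        q, k, v = self.wq(x_norm), self.wk(kv_norm), self.wv(kv_norm)
        out = F.scaled_dot_product_attention(q, k, v)
        return x + self.wo(out)   # residual path keeps the un-normalized x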
bytelatent/model/local_models.py CHANGED
@@ -34,7 +34,7 @@ class LocalModelArgs(BaseTransformerArgs):
     # Local encoder specific dimensions
     dropout: float
     vocab_size: int
-    patch_size: int
+    patch_size: float
     sliding_window: int | None
     use_rope: bool
     cross_attn_encoder: bool | None
@@ -61,6 +61,7 @@ class LocalModelBase(nn.Module):
         self.dropout = args.dropout
         self.vocab_size = args.vocab_size
         self.patch_size = args.patch_size
+        self.dim_patch_emb = args.dim_patch_emb
 
         self.attn_impl = args.attn_impl
         self.sliding_window = args.sliding_window
@@ -130,6 +131,7 @@ class LocalModelBase(nn.Module):
 
     def init_weights(self, init_std=None):
         self.rope.reset_parameters()
+        self.norm.reset_parameters()
 
         init_std = init_std or (self.dim ** (-0.5))
         nn.init.trunc_normal_(
@@ -156,33 +158,34 @@ class LocalModelBase(nn.Module):
                 InitStdFactor.DISABLED: 1.0,
             }[self.init_std_factor]
 
-            layer.init_weights(init_std, factor)
+            layer.init_weights(None, factor)
 
-        if self.token_embedding_projection is not None:
+        if hasattr(self, "output"):
             nn.init.trunc_normal_(
-                self.token_embedding_projection.weight,
+                self.output.weight,
                 mean=0.0,
                 std=init_std,
                 a=-3 * init_std,
                 b=3 * init_std,
             )
 
-        if self.patch_embedding_projection is not None:
+        if self.token_embedding_projection is not None:
             nn.init.trunc_normal_(
-                self.patch_embedding_projection.weight,
+                self.token_embedding_projection.weight,
                 mean=0.0,
                 std=init_std,
                 a=-3 * init_std,
                 b=3 * init_std,
             )
 
-        if hasattr(self, "output"):
+        if self.patch_embedding_projection is not None:
+            patch_emb_std = self.dim_patch_emb ** (-0.5)
             nn.init.trunc_normal_(
-                self.output.weight,
+                self.patch_embedding_projection.weight,
                 mean=0.0,
-                std=init_std,
-                a=-3 * init_std,
-                b=3 * init_std,
+                std=patch_emb_std,
+                a=-3 * patch_emb_std,
+                b=3 * patch_emb_std,
             )
 
         if self.cross_attn_layers is not None:
@@ -194,7 +197,7 @@ class LocalModelBase(nn.Module):
                 InitStdFactor.DISABLED: 1.0,
             }[self.init_std_factor]
 
-            layer.init_weights(init_std, factor)
+            layer.init_weights(None, factor)
 
 
 class LocalEncoder(LocalModelBase):
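
For context, the factor argument that layer.init_weights(None, factor) receives comes from a small table keyed by InitStdFactor, visible in the removed blt.py code and in the context lines above. A standalone sketch of that depth scaling, with an illustrative enum whose member values are assumptions:

from enum import Enum

class InitStdFactor(str, Enum):
    CURRENT_DEPTH = "current_depth"
    GLOBAL_DEPTH = "global_depth"
    DIM_RATIO = "dim_ratio"
    DISABLED = "disabled"

def init_factor(kind: InitStdFactor, depth: int, n_layers: int, dim: int) -> float:
    # Depth-dependent scaling that divides the per-layer init std.
    return {
        InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
        InitStdFactor.GLOBAL_DEPTH: (2 * (n_layers + 1)) ** 0.5,
        InitStdFactor.DIM_RATIO: dim / 4096,
        InitStdFactor.DISABLED: 1.0,
    }[kind]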
bytelatent/transformer.py CHANGED
@@ -137,14 +137,25 @@ def get_no_recompute_ops():
 def build_fsdp_grouping_plan(model_args: LMTransformerArgs):
     group_plan: Tuple[int, bool] = []
 
-    # Grouping and output seperately
-    group_plan.append(("tok_embeddings", False))
-
-    # Grouping by layers
-    for i in range(model_args.n_layers):
-        group_plan.append((f"layers.{i}", False))
-
-    group_plan.append(("output", True))
+    if isinstance(model_args, LMTransformerArgs):
+        group_plan.append(("tok_embeddings", False))
+
+        for i in range(model_args.n_layers):
+            group_plan.append((f"layers.{i}", False))
+
+        group_plan.append(("output", True))
+    else:
+        for i in range(model_args.n_layers_local_encoder):
+            group_plan.append((f"local_encoder.layers.{i}", True))
+            group_plan.append((f"local_encoder.cross_attn_layers.{i}", True))
+        for i in range(model_args.n_layers_local_decoder):
+            group_plan.append((f"local_decoder.layers.{i}", True))
+            group_plan.append((f"local_decoder.cross_attn_layers.{i}", True))
+        for i in range(model_args.n_layers_global):
+            group_plan.append((f"global_transformer.layers.{i}", True))
+
+        for i in range(len(model_args.encoder_hash_byte_group_size)):
+            group_plan.append((f"encoder_hash_tok_embedding.{i}", True))
 
     return group_plan
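
The grouping plan is an ordered list of (module path, flag) pairs consumed by the FSDP wrapping code. A minimal illustration of the BLT branch, using a hypothetical args dataclass in place of the real model args:

from dataclasses import dataclass
from typing import List, Tuple

@dataclass
class ToyBLTArgs:
    # Hypothetical stand-in for the BLT args referenced in the diff.
    n_layers_local_encoder: int = 1
    n_layers_local_decoder: int = 1
    n_layers_global: int = 2

def toy_grouping_plan(args: ToyBLTArgs) -> List[Tuple[str, bool]]:
    # One entry per transformer block, mirroring the module paths above.
    plan: List[Tuple[str, bool]] = []
    for i in range(args.n_layers_local_encoder):
        plan.append((f"local_encoder.layers.{i}", True))
    for i in range(args.n_layers_local_decoder):
        plan.append((f"local_decoder.layers.{i}", True))
    for i in range(args.n_layers_global):
        plan.append((f"global_transformer.layers.{i}", True))
    return plan

print(toy_grouping_plan(ToyBLTArgs()))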