Fix init and repro (#48)
* Fix init and repro
* comment + black
---------
Co-authored-by: Srini Iyer <[email protected]>
- bytelatent/base_transformer.py +19 -14
- bytelatent/distributed.py +15 -7
- bytelatent/model/blt.py +8 -48
- bytelatent/model/latent_transformer.py +12 -11
- bytelatent/model/local_models.py +15 -12
- bytelatent/transformer.py +19 -8
bytelatent/base_transformer.py
CHANGED
@@ -445,7 +445,7 @@ class Attention(nn.Module):
         return output
 
     def reset_parameters(self, init_std=None, factor=1.0):
-        init_std = init_std or (self.dim ** (-0.5))
+        init_std = init_std or (self.dim ** (-0.5)) / factor
 
         for w in [self.wq, self.wk, self.wv]:
             nn.init.trunc_normal_(
@@ -459,7 +459,7 @@ class Attention(nn.Module):
         nn.init.trunc_normal_(
             self.wo.weight,
             mean=0.0,
-            std=init_std
+            std=init_std,
             a=-3 * init_std,
             b=3 * init_std,
         )
@@ -509,18 +509,16 @@ class FeedForward(nn.Module):
         return output
 
     def reset_parameters(self, init_std=None, factor=1.0):
-        in_init_std = init_std or (self.dim ** (-0.5))
-        out_init_std = init_std or (self.hidden_dim ** (-0.5))
-            b=3 * in_init_std,
-        )
+        in_init_std = init_std or (self.dim ** (-0.5)) / factor
+        out_init_std = init_std or (self.hidden_dim ** (-0.5)) / factor
+
+        nn.init.trunc_normal_(
+            self.w1.weight,
+            mean=0.0,
+            std=in_init_std,
+            a=-3 * in_init_std,
+            b=3 * in_init_std,
+        )
         nn.init.trunc_normal_(
             self.w2.weight,
             mean=0.0,
@@ -528,6 +526,13 @@ class FeedForward(nn.Module):
             a=-3 * out_init_std,
             b=3 * out_init_std,
         )
+        nn.init.trunc_normal_(
+            self.w3.weight,
+            mean=0.0,
+            std=in_init_std,
+            a=-3 * in_init_std,
+            b=3 * in_init_std,
+        )
 
 
 class TransformerBlock(nn.Module):
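For orientation, below is a minimal standalone sketch of the initialization convention these hunks converge on: a truncated normal clipped at plus or minus three standard deviations, with the depth/width factor folded into the default std. The dimensions and factor value are illustrative, not the repo's defaults. Note that, by operator precedence, the `/ factor` in `init_std or (self.dim ** (-0.5)) / factor` only scales the default term, not an explicitly passed `init_std`.

import torch.nn as nn

# Illustrative sizes only; the repo derives these from its model args.
dim, hidden_dim, depth = 512, 2048, 12
factor = (2 * (depth + 1)) ** 0.5  # e.g. a CURRENT_DEPTH-style scaling factor

w1 = nn.Linear(dim, hidden_dim, bias=False)

# Default std is dim ** -0.5 divided by the factor; samples are clipped to +/- 3 std.
std = dim ** (-0.5) / factor
nn.init.trunc_normal_(w1.weight, mean=0.0, std=std, a=-3 * std, b=3 * std)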
bytelatent/distributed.py
CHANGED
@@ -463,13 +463,21 @@ def parallelize_model(
         raise ValueError(f"Invalid fsdp_type: {distributed_args.fsdp_type}")
 
    if distributed_args.selective_activation_checkpointing:
+        # only works for blt models
+        # assuming that entropy models will not use checkpointing
+        for module in [
+            model.global_transformer,
+            model.local_encoder,
+            model.local_decoder,
+        ]:
+            for i in range(len(module.layers)):
+                module.layers[i] = checkpoint_wrapper(
+                    module.layers[i],
+                    context_fn=partial(
+                        create_selective_checkpoint_contexts,
+                        get_default_policy(no_recompute_ops),
+                    ),
+                )
 
    if distributed_args.compile:
        torch._dynamo.config.cache_size_limit = (
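As a rough sketch of what the wrapping loop above does, the snippet below applies PyTorch's checkpoint_wrapper to each layer of a toy module list. It uses plain (non-selective) checkpointing, since get_default_policy and no_recompute_ops are helpers defined in this repo rather than part of the PyTorch API; the real code additionally passes context_fn=partial(create_selective_checkpoint_contexts, ...) to make recomputation selective.

import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)

class TinyBlock(nn.Module):
    """Stand-in for a transformer layer."""

    def __init__(self, dim: int = 64):
        super().__init__()
        self.ff = nn.Linear(dim, dim)

    def forward(self, x):
        return x + self.ff(x)

layers = nn.ModuleList([TinyBlock() for _ in range(4)])

# Wrap each layer in place, mirroring the loop in the hunk above.
for i in range(len(layers)):
    layers[i] = checkpoint_wrapper(layers[i])

out = layers[0](torch.randn(2, 64, requires_grad=True))  # used like the original layer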
bytelatent/model/blt.py
CHANGED
@@ -825,12 +825,6 @@ class ByteLatentTransformer(nn.Module):
             local_encoder_dim=self.local_encoder.dim,
             encoder_hash_byte_group_size=None,
         )
-        self.tok_embeddings = torch.nn.Embedding(args.vocab_size, args.dim)
-
-        # Transformer layers
-        self.layers = nn.ModuleList(
-            [TransformerBlock(args) for _ in range(args.n_layers)]
-        )
 
         # Encoder ngram embedding tables
         self.encoder_ngram_embedding = None
@@ -848,9 +842,6 @@ class ByteLatentTransformer(nn.Module):
 
         # Output layer
         assert args.vocab_size > 0, "vocab_size must be greater than 0"
-        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
-        if args.weight_tying:
-            self.output.weight = self.tok_embeddings.weight
 
         # Patcher module
         if args.patch_in_forward:
@@ -954,11 +945,10 @@ class ByteLatentTransformer(nn.Module):
             local_encoder_embeds = local_encoder_embeds + ngram_embeds
 
         # Local encoder
-        h_cross = None
         (h_encoder, h_cross), cache_encoder = self.local_encoder(
             tokens=local_encoder_tokens,
             embeds=local_encoder_embeds,
-            patch_embeds=
+            patch_embeds=None,
             cross_mask=cross_attn_mask_enc,
             num_patches=patch_lengths.shape[1],
             patch_ids=patch_ids,
@@ -1033,47 +1023,17 @@ class ByteLatentTransformer(nn.Module):
         )
         return output
 
-    def reset_parameters(self, init_std=None):
-        # Either use fixed base std or sqrt model dim
-        init_std = init_std or (self.dim ** (-0.5))
-        nn.init.trunc_normal_(
-            self.tok_embeddings.weight,
-            mean=0.0,
-            std=init_std,
-            a=-3 * init_std,
-            b=3 * init_std,
-        )
-        if not self.weight_tying:
-            nn.init.trunc_normal_(
-                self.output.weight,
-                mean=0.0,
-                std=init_std,
-                a=-3 * init_std,
-                b=3 * init_std,
-            )
-
     def init_weights(self):
-        self.
-        self.
-
-        factor = {
-            InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
-            InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5,
-            InitStdFactor.DIM_RATIO: self.dim / 4096,
-            InitStdFactor.DISABLED: 1.0,
-        }[self.init_std_factor]
-
-            layer.init_weights(self.init_base_std, factor)
-
-        self.local_decoder.init_weights(self.init_base_std)
-        self.global_transformer.init_weights(self.init_base_std)
-        self.local_encoder.init_weights(self.init_base_std)
+        self.local_encoder.init_weights()
+        self.global_transformer.init_weights()
+        self.local_decoder.init_weights()
 
+        emb_std = self.local_encoder.dim ** (-0.5)
         for emb in self.encoder_hash_tok_embedding:
             nn.init.trunc_normal_(
                 emb.weight,
                 mean=0.0,
-                std=
-                a=-3 *
-                b=3 *
+                std=emb_std,
+                a=-3 * emb_std,
+                b=3 * emb_std,
             )
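The init path above now delegates to the sub-models and only keeps the hash-embedding tables at the top level, with their std tied to the local encoder width. A self-contained sketch of that pattern follows; the HashEmbeddings class and sizes are hypothetical, not part of the repo.

import torch.nn as nn

class HashEmbeddings(nn.Module):
    """Illustrative parent module that initializes its own embedding tables."""

    def __init__(self, dim: int = 256, vocab: int = 1000, n_tables: int = 3):
        super().__init__()
        self.dim = dim
        self.tables = nn.ModuleList([nn.Embedding(vocab, dim) for _ in range(n_tables)])

    def init_weights(self):
        # One std for every hash table, derived from the embedding width.
        emb_std = self.dim ** (-0.5)
        for emb in self.tables:
            nn.init.trunc_normal_(
                emb.weight, mean=0.0, std=emb_std, a=-3 * emb_std, b=3 * emb_std
            )

HashEmbeddings().init_weights()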
bytelatent/model/latent_transformer.py
CHANGED
@@ -78,10 +78,10 @@ class CrossAttention(nn.Module):
         # B S D
         bsz, seq_len, _ = x.shape
         _, slen_kv, _ = kv.shape
-
+        x_norm = self.cross_attn_norm_q(x)
         kv = self.cross_attn_norm_kv(kv)
 
-        xq = self.wq(
+        xq = self.wq(x_norm)
         xk = self.wk(kv)
         xv = self.wv(kv)
 
@@ -104,7 +104,7 @@ class CrossAttention(nn.Module):
         return x + output
 
     def init_weights(self, base_std: float, factor: float = 1.0):
-        std = base_std
+        std = base_std or (self.dim ** (-0.5)) / factor
 
        nn.init.trunc_normal_(
            self.wq.weight,
@@ -130,13 +130,12 @@ class CrossAttention(nn.Module):
            b=3 * std,
        )
 
-        output_std = std / (2**0.5)
        nn.init.trunc_normal_(
            self.wo.weight,
            mean=0.0,
-            std=
-            a=-3 *
-            b=3 *
+            std=std,
+            a=-3 * std,
+            b=3 * std,
        )
        self.cross_attn_norm_q.reset_parameters()
        self.cross_attn_norm_kv.reset_parameters()
@@ -147,6 +146,7 @@ class GlobalTransformer(BaseTransformer):
         super().__init__(args)
         self.dropout = args.dropout
         self.eos_id = args.eos_id
+        self.dim_token_emb = args.dim_token_emb
 
         self.token_embedding_projection = None
         if args.dim_token_emb is not None and args.dim_token_emb != self.dim:
@@ -192,13 +192,14 @@ class GlobalTransformer(BaseTransformer):
         h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=self.attn_impl)
         return h, cache
 
-    def init_weights(self
+    def init_weights(self):
         super().init_weights()
+        std = self.dim_token_emb ** (-0.5)
         if self.token_embedding_projection is not None:
             nn.init.trunc_normal_(
                 self.token_embedding_projection.weight,
                 mean=0.0,
-                std=
-                a=-3 *
-                b=3 *
+                std=std,
+                a=-3 * std,
+                b=3 * std,
             )
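The first hunk fixes the query path so that cross_attn_norm_q is actually applied before the wq projection. Below is a compact, self-contained sketch of that pre-norm cross-attention ordering, using nn.LayerNorm and toy sizes in place of the repo's norm class and model dimensions; it is an illustration, not the repo's module.

import torch
import torch.nn as nn

class TinyCrossAttn(nn.Module):
    """Pre-norm cross-attention: both streams are normalized before projection."""

    def __init__(self, dim: int = 64, heads: int = 4):
        super().__init__()
        self.norm_q = nn.LayerNorm(dim)   # stands in for cross_attn_norm_q
        self.norm_kv = nn.LayerNorm(dim)  # stands in for cross_attn_norm_kv
        self.wq = nn.Linear(dim, dim, bias=False)
        self.wk = nn.Linear(dim, dim, bias=False)
        self.wv = nn.Linear(dim, dim, bias=False)
        self.wo = nn.Linear(dim, dim, bias=False)
        self.heads = heads

    def forward(self, x, kv):
        x_norm = self.norm_q(x)   # query stream is normalized first, as in the fix above
        kv = self.norm_kv(kv)
        q, k, v = self.wq(x_norm), self.wk(kv), self.wv(kv)
        b, s, d = q.shape
        split = lambda t: t.view(b, -1, self.heads, d // self.heads).transpose(1, 2)
        out = nn.functional.scaled_dot_product_attention(split(q), split(k), split(v))
        out = out.transpose(1, 2).reshape(b, s, d)
        return x + self.wo(out)   # residual on the un-normalized query stream

x, kv = torch.randn(2, 8, 64), torch.randn(2, 5, 64)
print(TinyCrossAttn()(x, kv).shape)  # torch.Size([2, 8, 64])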
bytelatent/model/local_models.py
CHANGED
@@ -34,7 +34,7 @@ class LocalModelArgs(BaseTransformerArgs):
     # Local encoder specific dimensions
     dropout: float
     vocab_size: int
-    patch_size:
+    patch_size: float
     sliding_window: int | None
     use_rope: bool
     cross_attn_encoder: bool | None
@@ -61,6 +61,7 @@ class LocalModelBase(nn.Module):
         self.dropout = args.dropout
         self.vocab_size = args.vocab_size
         self.patch_size = args.patch_size
+        self.dim_patch_emb = args.dim_patch_emb
 
         self.attn_impl = args.attn_impl
         self.sliding_window = args.sliding_window
@@ -130,6 +131,7 @@ class LocalModelBase(nn.Module):
 
     def init_weights(self, init_std=None):
         self.rope.reset_parameters()
+        self.norm.reset_parameters()
 
         init_std = init_std or (self.dim ** (-0.5))
         nn.init.trunc_normal_(
@@ -156,33 +158,34 @@ class LocalModelBase(nn.Module):
                 InitStdFactor.DISABLED: 1.0,
             }[self.init_std_factor]
 
-            layer.init_weights(
+            layer.init_weights(None, factor)
 
-        if self
+        if hasattr(self, "output"):
            nn.init.trunc_normal_(
-                self.
+                self.output.weight,
                mean=0.0,
                std=init_std,
                a=-3 * init_std,
                b=3 * init_std,
            )
 
-        if self.
+        if self.token_embedding_projection is not None:
            nn.init.trunc_normal_(
-                self.
+                self.token_embedding_projection.weight,
                mean=0.0,
                std=init_std,
                a=-3 * init_std,
                b=3 * init_std,
            )
 
-        if
+        if self.patch_embedding_projection is not None:
+            patch_emb_std = self.dim_patch_emb ** (-0.5)
            nn.init.trunc_normal_(
-                self.
+                self.patch_embedding_projection.weight,
                mean=0.0,
-                std=
-                a=-3 *
-                b=3 *
+                std=patch_emb_std,
+                a=-3 * patch_emb_std,
+                b=3 * patch_emb_std,
            )
 
         if self.cross_attn_layers is not None:
@@ -194,7 +197,7 @@ class LocalModelBase(nn.Module):
                 InitStdFactor.DISABLED: 1.0,
             }[self.init_std_factor]
 
-            layer.init_weights(
+            layer.init_weights(None, factor)
 
 
 class LocalEncoder(LocalModelBase):
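The init_weights changes above guard each optional head or projection and give the patch projection a std derived from its own input width (dim_patch_emb). A small illustrative sketch of that guarded-init pattern, with hypothetical class, names, and sizes:

import torch.nn as nn

class TinyLocalModel(nn.Module):
    """Sketch of guarded init for optional projections."""

    def __init__(self, dim: int = 128, dim_patch_emb: int = 256, has_output: bool = False):
        super().__init__()
        self.dim = dim
        self.dim_patch_emb = dim_patch_emb
        # Only some variants own an output head, hence the guard below.
        self.output = nn.Linear(dim, 260, bias=False) if has_output else None
        self.patch_embedding_projection = nn.Linear(dim_patch_emb, dim, bias=False)

    def init_weights(self):
        init_std = self.dim ** (-0.5)
        if self.output is not None:
            nn.init.trunc_normal_(
                self.output.weight, mean=0.0, std=init_std, a=-3 * init_std, b=3 * init_std
            )
        # The patch projection maps from dim_patch_emb, so its std follows that width.
        if self.patch_embedding_projection is not None:
            patch_emb_std = self.dim_patch_emb ** (-0.5)
            nn.init.trunc_normal_(
                self.patch_embedding_projection.weight,
                mean=0.0,
                std=patch_emb_std,
                a=-3 * patch_emb_std,
                b=3 * patch_emb_std,
            )

TinyLocalModel().init_weights()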
bytelatent/transformer.py
CHANGED
@@ -137,14 +137,25 @@ def get_no_recompute_ops():
 def build_fsdp_grouping_plan(model_args: LMTransformerArgs):
     group_plan: Tuple[int, bool] = []
 
+    if isinstance(model_args, LMTransformerArgs):
+        group_plan.append(("tok_embeddings", False))
+
+        for i in range(model_args.n_layers):
+            group_plan.append((f"layers.{i}", False))
+
+        group_plan.append(("output", True))
+    else:
+        for i in range(model_args.n_layers_local_encoder):
+            group_plan.append((f"local_encoder.layers.{i}", True))
+            group_plan.append((f"local_encoder.cross_attn_layers.{i}", True))
+        for i in range(model_args.n_layers_local_decoder):
+            group_plan.append((f"local_decoder.layers.{i}", True))
+            group_plan.append((f"local_decoder.cross_attn_layers.{i}", True))
+        for i in range(model_args.n_layers_global):
+            group_plan.append((f"global_transformer.layers.{i}", True))
+
+        for i in range(len(model_args.encoder_hash_byte_group_size)):
+            group_plan.append((f"encoder_hash_tok_embedding.{i}", True))
 
     return group_plan
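For the BLT branch, the grouping plan is just a list of (module path, bool) entries, one per wrapped submodule; note the Tuple[int, bool] annotation kept from the old code appears not to match the string paths actually appended. A toy reconstruction of that branch with hypothetical layer counts:

from typing import List, Tuple

# Hypothetical layer counts, for illustration only.
n_layers_local_encoder, n_layers_local_decoder, n_layers_global = 1, 9, 25
n_hash_embeddings = 4

group_plan: List[Tuple[str, bool]] = []
for i in range(n_layers_local_encoder):
    group_plan.append((f"local_encoder.layers.{i}", True))
    group_plan.append((f"local_encoder.cross_attn_layers.{i}", True))
for i in range(n_layers_local_decoder):
    group_plan.append((f"local_decoder.layers.{i}", True))
    group_plan.append((f"local_decoder.cross_attn_layers.{i}", True))
for i in range(n_layers_global):
    group_plan.append((f"global_transformer.layers.{i}", True))
for i in range(n_hash_embeddings):
    group_plan.append((f"encoder_hash_tok_embedding.{i}", True))

print(len(group_plan), group_plan[0])  # 49 ('local_encoder.layers.0', True)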