disable reshard after forward (#56)
Co-authored-by: Srini Iyer <[email protected]>
bytelatent/transformer.py
CHANGED
@@ -146,16 +146,16 @@ def build_fsdp_grouping_plan(model_args: LMTransformerArgs):
         group_plan.append(("output", True))
     else:
         for i in range(model_args.n_layers_local_encoder):
-            group_plan.append((f"local_encoder.layers.{i}", True))
-            group_plan.append((f"local_encoder.cross_attn_layers.{i}", True))
+            group_plan.append((f"local_encoder.layers.{i}", False))
+            group_plan.append((f"local_encoder.cross_attn_layers.{i}", False))
         for i in range(model_args.n_layers_local_decoder):
-            group_plan.append((f"local_decoder.layers.{i}", True))
-            group_plan.append((f"local_decoder.cross_attn_layers.{i}", True))
+            group_plan.append((f"local_decoder.layers.{i}", False))
+            group_plan.append((f"local_decoder.cross_attn_layers.{i}", False))
         for i in range(model_args.n_layers_global):
-            group_plan.append((f"global_transformer.layers.{i}", True))
+            group_plan.append((f"global_transformer.layers.{i}", False))

         for i in range(len(model_args.encoder_hash_byte_group_size)):
-            group_plan.append((f"encoder_hash_tok_embedding.{i}", True))
+            group_plan.append((f"encoder_hash_tok_embedding.{i}", False))

     return group_plan
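Per the commit title, the second element of each group-plan tuple is the reshard_after_forward flag used when the named submodule is wrapped for FSDP: with False, a layer's gathered parameters stay resident between forward and backward, avoiding a second all-gather during backward at the cost of extra memory. The sketch below shows how such a plan could be consumed, assuming PyTorch FSDP2's fully_shard; the helper name apply_fsdp_grouping_plan is hypothetical and is not part of this repository.

# Minimal sketch (assumption, not the repo's actual wrapping code): apply a
# grouping plan of (module_path, reshard_after_forward) pairs with FSDP2.
from torch.distributed._composable.fsdp import fully_shard


def apply_fsdp_grouping_plan(model, group_plan):
    for module_path, reshard_after_forward in group_plan:
        # Resolve a dotted path such as "local_encoder.layers.0" to its submodule.
        submodule = model.get_submodule(module_path)
        # reshard_after_forward=False keeps the gathered parameters alive after
        # forward, so backward skips the extra all-gather (more memory, less comm).
        fully_shard(submodule, reshard_after_forward=reshard_after_forward)
    # Shard whatever parameters remain at the top level as the root FSDP module.
    fully_shard(model)
    return model

Under this reading, the diff switches every per-layer group of the byte-latent model to reshard_after_forward=False, while the output projection keeps True.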