Srinivasan Iyer sviyer committed on
Commit
9d907fe
·
unverified ·
1 Parent(s): 48e4ad0

disable reshard after forward (#56)

Browse files

Co-authored-by: Srini Iyer <[email protected]>

Files changed (1) hide show
  1. bytelatent/transformer.py +6 -6
bytelatent/transformer.py CHANGED
@@ -146,16 +146,16 @@ def build_fsdp_grouping_plan(model_args: LMTransformerArgs):
146
  group_plan.append(("output", True))
147
  else:
148
  for i in range(model_args.n_layers_local_encoder):
149
- group_plan.append((f"local_encoder.layers.{i}", True))
150
- group_plan.append((f"local_encoder.cross_attn_layers.{i}", True))
151
  for i in range(model_args.n_layers_local_decoder):
152
- group_plan.append((f"local_decoder.layers.{i}", True))
153
- group_plan.append((f"local_decoder.cross_attn_layers.{i}", True))
154
  for i in range(model_args.n_layers_global):
155
- group_plan.append((f"global_transformer.layers.{i}", True))
156
 
157
  for i in range(len(model_args.encoder_hash_byte_group_size)):
158
- group_plan.append((f"encoder_hash_tok_embedding.{i}", True))
159
 
160
  return group_plan
161
 
 
146
  group_plan.append(("output", True))
147
  else:
148
  for i in range(model_args.n_layers_local_encoder):
149
+ group_plan.append((f"local_encoder.layers.{i}", False))
150
+ group_plan.append((f"local_encoder.cross_attn_layers.{i}", False))
151
  for i in range(model_args.n_layers_local_decoder):
152
+ group_plan.append((f"local_decoder.layers.{i}", False))
153
+ group_plan.append((f"local_decoder.cross_attn_layers.{i}", False))
154
  for i in range(model_args.n_layers_global):
155
+ group_plan.append((f"global_transformer.layers.{i}", False))
156
 
157
  for i in range(len(model_args.encoder_hash_byte_group_size)):
158
+ group_plan.append((f"encoder_hash_tok_embedding.{i}", False))
159
 
160
  return group_plan
161