Bocheng Li commited on
Commit
a6ed14f
·
unverified ·
1 Parent(s): fc3399e

Fix: Correct model_args usage in parallelize_model call (#69)

Browse files
Files changed (1) hide show
  1. bytelatent/train.py +1 -1
bytelatent/train.py CHANGED
@@ -317,7 +317,7 @@ def train(args: TrainArgs):
317
  model = parallelize_model(
318
  model,
319
  world_mesh,
320
- args.model,
321
  args.distributed,
322
  fsdp_grouping_plan=build_fsdp_grouping_plan(model_args),
323
  tp_parallelize=tp_parallelize,
 
317
  model = parallelize_model(
318
  model,
319
  world_mesh,
320
+ model_args,
321
  args.distributed,
322
  fsdp_grouping_plan=build_fsdp_grouping_plan(model_args),
323
  tp_parallelize=tp_parallelize,