Srinivasan Iyer sviyer committed on
Commit
48e4ad0
·
unverified ·
1 Parent(s): 22c7fe1

make sure max_encoder_seq_length matches (#55)

Browse files

* make sure max_encoder_seq_length matches

* black and assert comment

---------

Co-authored-by: Srini Iyer <[email protected]>

Files changed (1) hide show
  1. bytelatent/train.py +4 -1
bytelatent/train.py CHANGED
@@ -130,6 +130,9 @@ def validate_train_args(args: TrainArgs, output_size: int):
130
  if args.model is not None:
131
  logger.info(f"Setting model output size to {args.model.vocab_size}")
132
  args.model.vocab_size = output_size
 
 
 
133
 
134
  if args.entropy_model is not None:
135
  logger.info(f"Setting model output size to {args.entropy_model.vocab_size}")
@@ -610,7 +613,7 @@ def train(args: TrainArgs):
610
  interval_total_tok_loss_across_gpus = dist_sum(
611
  interval_total_tok_loss_per_gpu, reduce_dtype=torch.bfloat16
612
  ).item()
613
- interval_total_n_bytes_per_gpu = n_bytes
614
  interval_total_n_bytes_across_gpus = dist_sum(
615
  n_bytes, reduce_dtype=torch.bfloat16
616
  ).item()
 
130
  if args.model is not None:
131
  logger.info(f"Setting model output size to {args.model.vocab_size}")
132
  args.model.vocab_size = output_size
133
+ assert (
134
+ args.model.max_encoder_seq_length == args.data.max_encoder_seq_length
135
+ ), "max_encoder_seq_length for model and data should match"
136
 
137
  if args.entropy_model is not None:
138
  logger.info(f"Setting model output size to {args.entropy_model.vocab_size}")
 
613
  interval_total_tok_loss_across_gpus = dist_sum(
614
  interval_total_tok_loss_per_gpu, reduce_dtype=torch.bfloat16
615
  ).item()
616
+ interval_total_n_bytes_per_gpu = n_bytes.item()
617
  interval_total_n_bytes_across_gpus = dist_sum(
618
  n_bytes, reduce_dtype=torch.bfloat16
619
  ).item()