par-meta committed · Commit 083656c · unverified · Parent: f84ee63

Update ppl evals to work with blt model, in addition to entropy model (#82)


Summary:

Generalize the ppl evals so they work with the BLT model in addition to the entropy model: `eval_ppl_on_path` now receives `PackingArgs`, patcher args, and `add_patches` derived from the training config, passes patch lengths to `ByteLatentTransformer`, and reads validation files directly via `ArrowFileIterator` (json or arrow). `PackingIterator` pads the final partial batch with masked rows instead of requiring a full batch, `ValidationArgs` gains a `batch_size` field, and the unused `eval_on_val` helper is removed.

Test Plan:

Run
```
python -m bytelatent.eval config=../internal-blt/configs/eval_blt.yaml validation.max_n_docs=null
python -m bytelatent.eval config=../internal-blt/configs/eval_entropy.yaml validation.max_n_docs=null
```
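For reference, a programmatic equivalent of the test-plan commands, sketched under assumptions: the field names below are the ones visible in this diff, all paths and the source name are placeholders, and any other `EvalArgs` fields are assumed to have workable defaults.

```python
# Hedged sketch, not part of this commit: a programmatic equivalent of the
# test-plan commands above. Paths and the source name are placeholders.
from bytelatent.args import EvalArgs, ValidationArgs
from bytelatent.eval import launch_eval

eval_args = EvalArgs(
    dump_dir="/tmp/eval_dump",        # launch_eval asserts this is not None
    ckpt_dir="/path/to/checkpoints",  # launch_eval asserts this is not None
    run_ppl=True,
    validation=ValidationArgs(
        root_dir="/path/to/val_data",
        sources=["val_source"],       # hypothetical source directory name
        max_n_docs=None,              # evaluate all documents, as in the test plan
        batch_size=8,                 # new field introduced by this commit
    ),
)
launch_eval(eval_args)
```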

bytelatent/args.py CHANGED
@@ -263,6 +263,7 @@ class ValidationArgs(BaseModel):
     use_val_from_train_src: bool = True # Use the validation set from training sources
     root_dir: str = ""
     sources: list[str] = [] # Other sources to eval on
+    batch_size: int = 8
 
 
 class EvalArgs(BaseModel):
bytelatent/data/iterators/packing_iterator.py CHANGED
@@ -221,6 +221,7 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
         enable_byte_ngrams = self.packing_args.enable_byte_ngrams
         max_length = self.packing_args.max_length
         assert max_length is not None
+        final_leftover_batch = False
         while True:
             tokens: list[list[int]] = []
             masks: list[list[bool]] = []
@@ -252,6 +253,9 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
                     break
 
             x_patch_lengths = np.array(patch_lengths)
+            assert (
+                x_patch_lengths.shape[1] == seq_len
+            ), f"{x_patch_lengths.shape[1]} vs {seq_len}"
             # pad batch to same length
             tok_seq_len = max([len(toks) for toks in tokens]) - 1
             x = np.full((batch_size, tok_seq_len), fill_value=pad_id)
@@ -263,7 +267,30 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
                 # Adjust patch lengths to match x
                 x_patch_lengths[i, -1] += tok_seq_len - (len(tok_seq) - 1)
 
-            assert x_patch_lengths.shape == (batch_size, seq_len)
+            if x_patch_lengths.shape[0] < batch_size:
+                if final_leftover_batch:
+                    raise ValueError(
+                        "There should only be one partial batch, but found multiple"
+                    )
+                final_leftover_batch = True
+                assert len(masks) == len(x_patch_lengths)
+                n_missing = batch_size - x_patch_lengths.shape[0]
+                # Repeat the last patch length to validly pad it out, but
+                # update the mask to ignore the row
+                x_patch_lengths = np.vstack(
+                    [
+                        x_patch_lengths,
+                        np.repeat(x_patch_lengths[-1:, :], n_missing, axis=0),
+                    ]
+                )
+                for _ in range(n_missing):
+                    masks.append([0] * tok_seq_len)
+                assert len(masks) == batch_size
+
+            assert x_patch_lengths.shape == (
+                batch_size,
+                seq_len,
+            ), f"{x_patch_lengths.shape} vs {(batch_size, seq_len)}"
 
             if enable_byte_ngrams:
                 raise NotImplementedError()
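To make the new leftover-batch handling concrete, here is a minimal standalone sketch of the padding scheme with toy shapes and plain numpy, outside the `PackingIterator` API: the last patch-length row is repeated so the batch keeps a valid shape, while all-zero mask rows ensure the padded rows contribute nothing downstream.

```python
# Standalone sketch of the partial-batch padding added above (toy data only).
import numpy as np

batch_size, seq_len, tok_seq_len = 4, 3, 6
x_patch_lengths = np.array([[2, 2, 2], [3, 2, 1]])  # only 2 of 4 rows arrived
masks = [[1] * tok_seq_len for _ in range(x_patch_lengths.shape[0])]

n_missing = batch_size - x_patch_lengths.shape[0]
# Repeat the last patch-length row so shapes stay valid for the model...
x_patch_lengths = np.vstack(
    [x_patch_lengths, np.repeat(x_patch_lengths[-1:, :], n_missing, axis=0)]
)
# ...but append all-zero mask rows so the padding is ignored.
for _ in range(n_missing):
    masks.append([0] * tok_seq_len)

assert x_patch_lengths.shape == (batch_size, seq_len)
assert len(masks) == batch_size
```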
bytelatent/eval.py CHANGED
@@ -148,35 +148,25 @@ def eval_ppl_on_path(
     model: LMTransformer | ByteLatentTransformer,
     tokenizer_args: TokenizerArgs,
     patcher_args: PatcherArgs,
+    packing_args: PackingArgs,
     add_patches: bool,
     path: str,
-    batch_size: int,
     arrow_batch_size: int,
     max_n_docs: int | None,
     s3_profile: str | None = None,
 ):
     model.eval()
-    tokenizer = tokenizer_args.build()
     seq_len = model.get_output_seq_len()
-    chunks = find_and_sanitize_chunks(
-        path,
-        world_size=1,
-        file_pattern="*.val.jsonl",
-        s3_profile=s3_profile,
-    )
-    assert (
-        len(chunks) == 1
-    ), f"There should be only 1 chunk per validation file, but found: {chunks}"
-    chunk = chunks[0]
     arrow_iterator = ArrowFileIterator(
-        file_path=chunk,
-        preprocess_dir=None,
+        file_path=None,
+        dataset_files=[path],
         entropy_model_name=None,
         worker_id=world_rank,
         num_workers=world_size,
         arrow_batch_size=arrow_batch_size,
+        preprocess_dir=None,
         s3_profile=s3_profile,
-        file_format="json",
+        file_format="arrow" if path.endswith("arrow") else "json",
     )
     if max_n_docs is not None:
         arrow_iterator = LimitIterator(arrow_iterator, limit=max_n_docs)
@@ -195,16 +185,6 @@ def eval_ppl_on_path(
         ),
         rng_state=None,
     )
-    packing_args = PackingArgs(
-        batch_size=batch_size,
-        seq_len=seq_len,
-        # TODO: make these seq lens worth with blt
-        max_length=seq_len,
-        pad_to_max_length=True,
-        enable_byte_ngrams=False,
-        pad_id=tokenizer.boe_id,
-        packing_mode=PackingMode.BYTES,
-    )
     packing_iterator = PackingIterator(sequence_iterator, packing_args=packing_args)
     total_loss = 0.0
     n_bytes = 0
@@ -213,9 +193,16 @@ def eval_ppl_on_path(
         x = torch.from_numpy(batch.x).cuda()
         y = torch.from_numpy(batch.y).cuda()
         mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda()
+        patch_lengths = batch.patch_lengths
+        if patch_lengths is not None:
+            patch_lengths = torch.from_numpy(patch_lengths).cuda()
+
         if tokenizer_args.name in ["bytes", "blt"]:
             n_bytes += y.numel() if mask is None else mask.sum().item()
-            pred = model(x)
+            if isinstance(model, ByteLatentTransformer):
+                pred = model(x, patch_lengths=patch_lengths)
+            else:
+                pred = model(x)
             loss = F.cross_entropy(pred.flatten(0, 1), y.flatten(0, 1), reduction="sum")
             total_loss += loss.item()
         else:
@@ -234,82 +221,6 @@ def eval_ppl_on_path(
     }
 
 
-def eval_on_val(generator, val_args: ValidationArgs, train_cfg: TrainArgs):
-    srcs = []
-    for src in val_args.sources:
-        path = os.path.join(val_args.root_dir, src)
-        srcs.append(path)
-
-    for src in train_cfg.data.sources:
-        path = os.path.join(train_cfg.data.root_dir, src)
-        srcs.append(path)
-
-    path_to_iter = {}
-    for path in srcs:
-        chunks = find_and_sanitize_chunks(
-            path,
-            world_size=1,
-            file_pattern="*.val.jsonl",
-            s3_profile=train_cfg.data.s3_profile,
-        )
-        assert (
-            len(chunks) == 1
-        ), f"There should be only 1 chunk per validation file, but found: {chunks}"
-        chunk = chunks[0]
-        iterator = ArrowFileIterator(
-            dataset_files=[chunk],
-            file_path=None,
-            preprocess_dir=None,
-            entropy_model_name=None,
-            worker_id=0,
-            num_workers=1,
-            arrow_batch_size=train_cfg.data.arrow_batch_size,
-            s3_profile=train_cfg.data.s3_profile,
-            file_format="json",
-        )
-        path_to_iter[path] = iterator
-
-    max_gen_len = generator.max_gen_len
-    # We temporarily lower max gen len
-    generator.max_gen_len = 1
-
-    all_val_metrics = {}
-    for src in path_to_iter:
-        example_iterator = path_to_iter[src].create_iter()
-        texts = []
-        logger.info(f"Running validation on {src}...")
-        for step, example in enumerate(example_iterator):
-            texts.append(example.text)
-
-        _, loglikelihood, _ = generator.generate(texts)
-
-        metrics = defaultdict(list)
-        for i, ll in enumerate(loglikelihood):
-            tmp = ll.sum().item()
-            metrics["nll"].append(tmp)
-            metrics["nll_per_token"].append(tmp / len(ll))
-            metrics["nll_per_char"].append(tmp / len(texts[i]))
-
-            metrics["avg_seqlen"].append(len(ll))
-
-        for m in metrics:
-            metrics[m] = sum(metrics[m]) / len(metrics[m])
-        metrics.update(dist_mean_dict(metrics))
-        logger.info(f"Validation on {src} done. Metrics: {metrics}")
-
-        name = os.path.basename(src)
-        if name in all_val_metrics:
-            logger.warning(
-                f"Duplicate source name {name}, path {src} in validation sources, renaming to {name}_1"
-            )
-            name = f"{name}_1"
-        all_val_metrics[name] = metrics
-
-    generator.max_gen_len = max_gen_len
-
-    return all_val_metrics
-
-
 def launch_eval(eval_args: EvalArgs):
     assert eval_args.dump_dir is not None
     assert eval_args.ckpt_dir is not None
@@ -342,17 +253,29 @@ def launch_eval(eval_args: EvalArgs):
 
     torch.distributed.barrier()
     logger.info("Loading model")
-    # TODO: Make this general so that it works with either
-    # LMTransformer or Blt, similar with args
     model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
         consolidate_path,
     )
+    pad_id = 0 if train_cfg.data.tokenizer_args.name == "bytes" else tokenizer.boe_id
     model.eval()
    logger.info("Model loaded")
 
     ppl_results = None
     if eval_args.run_ppl:
         assert eval_args.validation is not None
+        packing_args = PackingArgs(
+            batch_size=eval_args.validation.batch_size,
+            seq_len=train_cfg.data.seq_len,
+            max_length=train_cfg.data.max_encoder_seq_length,
+            pad_to_max_length=True,
+            enable_byte_ngrams=False,
+            pad_id=pad_id,
+            packing_mode=(
+                PackingMode.BYTES
+                if train_cfg.data.patcher_args.patching_mode == PatchingModeEnum.byte
+                else PackingMode.PATCHING
+            ),
+        )
         if len(eval_args.validation.sources) > 0:
             ppl_results = {}
             logger.info("Starting PPL evaluation on validation sets")
@@ -362,14 +285,13 @@ def launch_eval(eval_args: EvalArgs):
                     world_size=world_size,
                     model=model,
                     tokenizer_args=train_cfg.data.tokenizer_args,
-                    # TODO: Don't hardcode, modify based on model
-                    patcher_args=PatcherArgs(patching_mode=PatchingModeEnum.byte),
-                    add_patches=False,
+                    patcher_args=train_cfg.data.patcher_args,
+                    packing_args=packing_args,
+                    add_patches=train_cfg.data.add_patches,
                    path=os.path.join(eval_args.validation.root_dir, source),
                     max_n_docs=eval_args.validation.max_n_docs,
-                    batch_size=8,
-                    arrow_batch_size=100,
-                    s3_profile="blt",
+                    arrow_batch_size=20,
+                    s3_profile=eval_args.s3_profile,
                 )
 
     task_results = None
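The eval loop above only accumulates `total_loss` (summed cross-entropy, in nats) and `n_bytes`; the exact keys `eval_ppl_on_path` returns are outside this diff. A sketch of how such sums typically convert to byte-level metrics, with illustrative (assumed) key names:

```python
# Hedged sketch: converting the accumulated sums into byte-level metrics.
# The real return keys of eval_ppl_on_path are not shown in this diff, so the
# names below are illustrative assumptions.
import math

def byte_level_metrics(total_loss: float, n_bytes: int) -> dict[str, float]:
    nll_per_byte = total_loss / n_bytes  # mean negative log-likelihood (nats/byte)
    return {
        "nll_per_byte": nll_per_byte,
        "ppl_per_byte": math.exp(nll_per_byte),  # byte-level perplexity
        "bpb": nll_per_byte / math.log(2),       # bits per byte
    }
```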