par-meta committed
Commit 08b8c7c · unverified · 1 parent: 0da051f

Pass mask in packing_iterator, correctly handle last batch, fix masking (#65)


This commit makes the following changes:

1. Adds unit tests for byte packing and patch packing to ensure both work correctly
2. Fixes a bug where, for batches that end up with fewer than max_length bytes (e.g., short patches), the mask included positions whose value was pad_id. The mask is now set to x != pad_id when it is not specified (see the sketch below).
3. Correctly handles the last batch, which previously crashed. This did not affect training, since we had enough data and/or looping iterators, but it matters for evaluation perplexity when validating on an entire file.
4. Correctly forwards the mask, if it exists, in byte packing
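
A minimal sketch of the default-mask rule from item 2 (standalone, hypothetical example values; not the repository code): when no mask is provided, exactly the non-pad positions should count as valid.

```python
import numpy as np

# Hypothetical illustration of item 2: with no explicit mask,
# treat every non-pad token as a valid (unmasked) position.
pad_id = 0
x = np.array([[5, 6, 7, pad_id, pad_id]])  # one short sequence padded to length 5
mask = x != pad_id                          # [[True, True, True, False, False]]
assert mask.sum() == np.sum(x != pad_id) == 3
```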

Test Plan:

```
pytest bytelatent/
```

These changes are tested more thoroughly in a stacked PR that fixes evals.
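
The new test module added below can also be targeted directly (standard pytest file selection; the path comes from this commit's diff):

```
pytest bytelatent/data/iterators/test_packing_iterator.py
```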

bytelatent/data/iterators/packing_iterator.py CHANGED
```diff
@@ -41,12 +41,12 @@ class PackingIteratorState(PydanticIteratorState):
     )
 
 
-def _merge_patch_seq_masks(bs, slen: int, mask_seqs: list[list[bool]]):
+def _merge_patch_seq_masks(bs: int, slen: int, mask_seqs: list[list[bool]]):
     assert len(mask_seqs) == bs
     lens = [len(m) for m in mask_seqs]
     if all(all(m) for m in mask_seqs) and all(lens[0] == l for l in lens):
-        return None
-    assert slen == max(lens) - 1
+        return np.ones((bs, slen), dtype=bool)
+    assert slen == max(lens) - 1, f"slen={slen} != max(lens)-1={max(lens) - 1}"
     mask = np.zeros((bs, slen), dtype=bool)
     for i, m in enumerate(mask_seqs):
         if m is None:
@@ -176,28 +176,41 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
         while True:
             tokens: list[list[int]] = []
             masks: list[list[bool]] = []
-
-            for _ in range(self.packing_args.batch_size):
-                sequence = next(sequence_iter)
-                _tokens = sequence.tokens
-                _mask = sequence.mask
-                assert (
-                    sequence.patch_lengths is None
-                ), "patch_lengths should not be used in byte packing"
-                tokens.append(_tokens)
-                masks.append(_mask)
+            stop_iteration = False
+            try:
+                for _ in range(self.packing_args.batch_size):
+                    sequence = next(sequence_iter)
+                    _tokens = sequence.tokens
+                    _mask = sequence.mask
+                    assert (
+                        sequence.patch_lengths is None
+                    ), "patch_lengths should not be used in byte packing"
+                    tokens.append(_tokens)
+                    masks.append(_mask)
+            except StopIteration:
+                # At this point, there will be no new sequences, so we need to stop
+                # after yielding the already accumulated data (one batch).
+                # In this case, either:
+                # 1. We have a complete batch, so things go as normal
+                # 2. We have an incomplete batch, but due to creating a right sized batch,
+                #    then filling the values in, this case is automatically handled.
+                stop_iteration = True
 
             x = np.full((batch_size, seq_len), fill_value=pad_id)
             y = np.full((batch_size, seq_len), fill_value=pad_id)
+            m = np.zeros((batch_size, seq_len), dtype=np.bool)
 
             for i, tok_seq in enumerate(tokens):
                 x[i, : len(tok_seq)] = tok_seq
                 y[i, : len(tok_seq) - 1] = tok_seq[1:]
-            batch = Batch(x=x, y=y)
+                m[i, : len(tok_seq)] = masks[i]
+            batch = Batch(x=x, y=y, mask=m)
             assert (
                 batch.mask is None or np.sum(x != pad_id) == batch.mask.sum()
             ), f"{np.sum(x != pad_id)} != {batch.mask.sum()}"
             yield batch
+            if stop_iteration:
+                break
 
     def _create_iter_from_patch_lengths(self):
         sequence_iter = self.sequence_iterator.create_iter()
@@ -207,29 +220,36 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
         pad_to_max_length = self.packing_args.pad_to_max_length
         enable_byte_ngrams = self.packing_args.enable_byte_ngrams
         max_length = self.packing_args.max_length
+        assert max_length is not None
         while True:
             tokens: list[list[int]] = []
            masks: list[list[bool]] = []
             patch_lengths: list[list[int]] = []
-
-            for _ in range(self.packing_args.batch_size):
-                sequence = next(sequence_iter)
-                _tokens = sequence.tokens
-                _mask = sequence.mask
-                _patch_lengths = sequence.patch_lengths
-                assert (
-                    _patch_lengths is not None
-                ), "patch lengths are required for packing based on patches."
-                # Reminder: seq_len is in terms of patches
-                assert len(sequence.patch_lengths) == self.packing_args.seq_len
-                last_patch_length = 0
-                if _patch_lengths[0] > 1:
-                    last_patch_length = _patch_lengths[-1]
-                    _patch_lengths[0] -= 1
-                    _patch_lengths = [1] + _patch_lengths[:-1]
-                tokens.append(_tokens[: len(_tokens) - last_patch_length])
-                masks.append(_mask[: len(_mask) - last_patch_length])
-                patch_lengths.append(_patch_lengths)
+            stop_iteration = False
+            try:
+                for _ in range(self.packing_args.batch_size):
+                    sequence = next(sequence_iter)
+                    _tokens = sequence.tokens
+                    _mask = sequence.mask
+                    _patch_lengths = sequence.patch_lengths
+                    assert (
+                        _patch_lengths is not None
+                    ), "patch lengths are required for packing based on patches."
+                    # Reminder: seq_len is in terms of patches
+                    assert len(sequence.patch_lengths) == self.packing_args.seq_len
+                    last_patch_length = 0
+                    if _patch_lengths[0] > 1:
+                        last_patch_length = _patch_lengths[-1]
+                        _patch_lengths[0] -= 1
+                        _patch_lengths = [1] + _patch_lengths[:-1]
+                    tokens.append(_tokens[: len(_tokens) - last_patch_length])
+                    masks.append(_mask[: len(_mask) - last_patch_length])
+                    patch_lengths.append(_patch_lengths)
+            except StopIteration:
+                stop_iteration = True
+
+            if len(tokens) == 0 and stop_iteration:
+                break
 
             x_patch_lengths = np.array(patch_lengths)
             # pad batch to same length
@@ -257,6 +277,7 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
                 ngram_ids=ngram_ids,
                 mask=_merge_patch_seq_masks(batch_size, tok_seq_len, masks),
             )
+
             assert (
                 x_patch_lengths.sum() == x.size + batch_size
             ), f"{x_patch_lengths.sum()} != {x.size + batch_size}"
@@ -277,3 +298,5 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
                 enable_byte_ngrams=enable_byte_ngrams,
             )
             yield batch
+            if stop_iteration:
+                break
```
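
The last-batch handling added above follows the same pattern in both packing paths: accumulate up to batch_size sequences, catch StopIteration, yield whatever was accumulated, then stop. A self-contained sketch of that pattern (a hypothetical `batched` helper, not code from this repository; it mirrors the patch-lengths path, which skips an empty final batch):

```python
def batched(items, batch_size):
    """Yield lists of up to batch_size items, including a final partial batch."""
    it = iter(items)
    while True:
        batch = []
        stop_iteration = False
        try:
            for _ in range(batch_size):
                batch.append(next(it))
        except StopIteration:
            stop_iteration = True
        if batch:            # emit the (possibly partial) last batch
            yield batch
        if stop_iteration:   # no more data after this point
            break


assert list(batched(range(10), 4)) == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```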
bytelatent/data/iterators/test_packing_iterator.py ADDED
```python
import numpy as np
import pytest

from bytelatent.data.data_types import BltSequence
from bytelatent.data.iterators.abstract_iterator import StatefulIterator
from bytelatent.data.iterators.packing_iterator import (
    PackingArgs,
    PackingIterator,
    PackingMode,
    _merge_patch_seq_masks,
)


class DummySequenceIterator(StatefulIterator):
    def __init__(
        self,
        *,
        seq_len: int,
        n_seqs: int,
        patch_lengths: list[int] | None = None,
        pad_id: int = 0,
    ):
        self.seq_len = seq_len
        self.n_seqs = n_seqs
        self.patch_lengths = patch_lengths
        self.pad_id = pad_id

    def get_state(self):
        raise NotImplementedError()

    def create_iter(self):
        for i in range(self.n_seqs):
            if self.patch_lengths is None:
                tokens = np.arange(
                    i * self.seq_len + 1, (i + 1) * self.seq_len + 1
                ).tolist()
                mask = [True] * self.seq_len  # type: ignore
                assert len(tokens) == self.seq_len
            else:
                n = sum(self.patch_lengths)
                tokens = np.arange(i * n + 1, (i + 1) * n + 1).tolist()
                assert len(tokens) == n
                mask = [True] * n
            assert len(mask) == len(tokens)
            yield BltSequence(
                tokens=tokens,
                mask=mask,
                patch_lengths=self.patch_lengths,
            )


def create_bytes_iter(*, seq_len: int, n_seqs: int, batch_size: int, pad_id: int):
    sequence_iterator = DummySequenceIterator(seq_len=seq_len, n_seqs=n_seqs)
    packing_iterator = PackingIterator(
        sequence_iterator,
        packing_args=PackingArgs(
            batch_size=batch_size,
            seq_len=seq_len,
            pad_id=pad_id,
            packing_mode=PackingMode.BYTES,
            max_length=None,
            pad_to_max_length=False,
            enable_byte_ngrams=False,
        ),
    )
    return packing_iterator.create_iter()


def create_patches_iter(
    *,
    seq_len: int,
    n_seqs: int,
    batch_size: int,
    pad_id: int,
    patch_lengths: list[int] | None,
    max_length: int,
):
    sequence_iterator = DummySequenceIterator(
        # seq_len=number of bytes, which for blt/patches, is max_length since seq_len is
        # in terms of number of patches
        seq_len=max_length,
        n_seqs=n_seqs,
        patch_lengths=patch_lengths,
    )
    packing_iterator = PackingIterator(
        sequence_iterator,
        packing_args=PackingArgs(
            batch_size=batch_size,
            seq_len=seq_len,
            pad_id=pad_id,
            packing_mode=PackingMode.PATCHING,
            max_length=max_length,
            pad_to_max_length=True,
            enable_byte_ngrams=False,
        ),
    )
    return packing_iterator.create_iter()


def test_last_batch_correctness_bytes():
    seq_len = 1024
    n_seqs = 10
    batch_size = 4
    pad_id = 0
    iterator = create_bytes_iter(
        seq_len=seq_len, n_seqs=n_seqs, batch_size=batch_size, pad_id=pad_id
    )
    batches = []
    n_nonpad = 0
    n_nonmask = 0
    for b in iterator:
        assert b.x.shape[0] == batch_size
        assert b.x.shape[1] == seq_len
        n_nonpad += (b.x != pad_id).sum()
        if b.mask is None:
            n_nonmask += b.x.size
        else:
            n_nonmask += b.mask.sum()
        batches.append(b)
    assert len(batches) == 3
    assert n_nonpad == n_nonmask == seq_len * n_seqs
    # The second half of the last batch should be all pads
    assert batches[-1].mask[2:].sum() == 0


def test_edgecase_batch_correctness_bytes():
    seq_len = 1024
    n_seqs = 10
    batch_size = 12
    pad_id = 0
    iterator = create_bytes_iter(
        seq_len=seq_len, n_seqs=n_seqs, batch_size=batch_size, pad_id=pad_id
    )
    batches = []
    n_nonpad = 0
    n_nonmask = 0
    for b in iterator:
        assert b.x.shape[0] == batch_size
        assert b.x.shape[1] == seq_len
        n_nonpad += (b.x != pad_id).sum()
        if b.mask is None:
            n_nonmask += b.x.size
        else:
            n_nonmask += b.mask.sum()
        batches.append(b)
    assert len(batches) == 1
    assert n_nonpad == n_nonmask == seq_len * n_seqs
    # The second half of the last batch should be all pads
    assert batches[0].mask[10:].sum() == 0


def test_exact_batch_correctness_bytes():
    seq_len = 1024
    n_seqs = 12
    batch_size = 4
    pad_id = 0
    iterator = create_bytes_iter(
        seq_len=seq_len, n_seqs=n_seqs, batch_size=batch_size, pad_id=pad_id
    )
    batches = []
    n_nonpad = 0
    n_nonmask = 0
    for b in iterator:
        assert b.x.shape[0] == batch_size
        assert b.x.shape[1] == seq_len
        n_nonpad += (b.x != pad_id).sum()
        if b.mask is None:
            n_nonmask += b.x.size
        else:
            n_nonmask += b.mask.sum()
        batches.append(b)
    assert len(batches) == 4
    assert n_nonpad == n_nonmask == seq_len * n_seqs


def test_exact_batch_correctness_patches():
    # First patch length is forced to be 1
    patch_lengths = [1, 255, 256, 256, 256]
    # Recall: This is in terms of bytes
    max_length = 1024
    # Recall: This is in terms of patches
    seq_len = 5
    n_seqs = 12
    batch_size = 4
    pad_id = 0
    iterator = create_patches_iter(
        seq_len=seq_len,
        n_seqs=n_seqs,
        batch_size=batch_size,
        pad_id=pad_id,
        patch_lengths=patch_lengths,
        max_length=max_length,
    )
    batches = []
    n_nonpad = 0
    n_nonmask = 0
    for batch in iterator:
        assert batch.x.shape[0] == batch_size
        assert batch.x.shape[1] == max_length
        n_nonpad += (batch.x != pad_id).sum()
        if batch.mask is None:
            n_nonmask += batch.x.size
        else:
            n_nonmask += batch.mask.sum()
        batches.append(batch)

    assert len(batches) == 3

    # max_length - 1 is due to chopping off the last byte for
    # having a y target
    assert n_nonpad == n_nonmask == (max_length - 1) * n_seqs


def test_short_batch_correctness_patches():
    # First patch length is forced to be 1
    # Total=48
    patch_lengths = [1, 11, 12, 12, 12]
    # Recall: This is in terms of bytes
    max_length = 1024
    # Recall: This is in terms of patches
    seq_len = 5
    n_seqs = 12
    batch_size = 4
    pad_id = 0
    iterator = create_patches_iter(
        seq_len=seq_len,
        n_seqs=n_seqs,
        batch_size=batch_size,
        pad_id=pad_id,
        patch_lengths=patch_lengths,
        max_length=max_length,
    )
    batches = []
    n_nonpad = 0
    n_nonmask = 0
    for batch in iterator:
        assert batch.x.shape[0] == batch_size
        assert batch.x.shape[1] == max_length
        n_nonpad += (batch.x != pad_id).sum()
        if batch.mask is None:
            n_nonmask += batch.x.size
        else:
            n_nonmask += batch.mask.sum()
        batches.append(batch)

    assert len(batches) == 3

    # We'll still always have one byte chopped off the end
    assert n_nonpad == n_nonmask == ((sum(patch_lengths) - 1) * n_seqs)


def test_long_batch_correctness_patches():
    # First patch length is forced to be 1
    # Total=48
    patch_lengths = [1, 255, 256, 256, 1024]
    # Recall: This is in terms of bytes
    max_length = 1024
    # Recall: This is in terms of patches
    seq_len = 5
    n_seqs = 12
    batch_size = 4
    pad_id = 0
    iterator = create_patches_iter(
        seq_len=seq_len,
        n_seqs=n_seqs,
        batch_size=batch_size,
        pad_id=pad_id,
        patch_lengths=patch_lengths,
        max_length=max_length,
    )
    batches = []
    n_nonpad = 0
    n_nonmask = 0
    for batch in iterator:
        assert batch.x.shape[0] == batch_size
        assert batch.x.shape[1] == max_length
        n_nonpad += (batch.x != pad_id).sum()
        if batch.mask is None:
            n_nonmask += batch.x.size
        else:
            n_nonmask += batch.mask.sum()
        batches.append(batch)

    assert len(batches) == 3

    # No chop here since the next byte is available
    assert n_nonpad == n_nonmask == max_length * n_seqs


def test_merge_patch_seq_masks():
    batch_size = 4
    seq_len = 1024
    masks = []
    masks.append([True] * 1025)
    masks.append([True] * 512)
    masks.append([True] * 256)
    masks.append([True] * 10)
    expected_mask = np.zeros((batch_size, seq_len), dtype=bool)
    expected_mask[0, :] = True
    expected_mask[1, :511] = True
    expected_mask[2, :255] = True
    expected_mask[3, :9] = True
    merged_mask = _merge_patch_seq_masks(batch_size, seq_len, masks)
    assert (merged_mask == expected_mask).all()

    with pytest.raises(AssertionError):
        masks = []
        masks.append([True] * 1024)
        masks.append([True] * 512)
        masks.append([True] * 256)
        masks.append([True] * 10)
        _merge_patch_seq_masks(batch_size, seq_len, masks)
```