Commit fc3399e (unverified) by par-meta
Parent: b0956bd

Update iterator inheritance, pass file format args, limit iterator (#63)

- Create a common base class, PydanticIteratorState, for all iterator state classes to inherit from
- Add a limit iterator that we can use in evals (see the sketch after this list)
- Modify ArrowFileIterator to skip arrow path inference when file_format='json'
- Make EvalArgs valid to construct with defaults by making dump_dir and ckpt_dir optional
- Move testing iterators to a common directory so they can be used in multiple test files
- Allow SequenceIterator to take a None rng_state, which disables all rng ops (mainly for eval)
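
For context, a minimal sketch (mine, not part of the diff) of how these pieces compose into a capped eval stream; the `ArrowFileIterator` arguments mirror the new `test_read_jsonl_from_arrow` test and the JSONL fixture added in this commit:

```python
from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
from bytelatent.data.iterators.limit_iterator import LimitIterator

# Read raw JSONL directly; file_format="json" skips arrow shard path inference.
docs = ArrowFileIterator(
    file_path="fixtures/test_docs.jsonl",
    file_format="json",
    num_workers=1,
    worker_id=0,
    preprocess_dir=None,
    entropy_model_name=None,
    arrow_batch_size=100,
)
# Cap how many documents the eval consumes; limit=-1 disables the cap.
eval_docs = LimitIterator(docs, limit=100)
for example in eval_docs.create_iter():
    print(example.sample_id, example.text)
```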

Test Plan:

- `pytest bytelatent`
- `python -m bytelatent.train config=../internal-blt/configs/entropy_model.yaml logging.wandb=null eval=null`

.gitignore CHANGED
@@ -168,3 +168,5 @@ figures/
 internal/
 jobs_parallel-copy/
 wandb/
+*.ipynb
+
bytelatent/args.py CHANGED
@@ -72,6 +72,7 @@ def distribute_data_to_rank(
     arrow_batch_size: int,
     rank: int,
     world_size: int,
+    file_format: str,
     s3_profile: str | None = None,
     file_pattern: str = TRAIN_DATA_FILE_PATTERN,
 ) -> ArrowFileIterator:
@@ -85,6 +86,7 @@ def distribute_data_to_rank(
         rank_to_arrow_iterator_params.append(
             ArrowFileIterator(
                 file_path=chunk_path,
+                file_format=file_format,
                 worker_id=worker_id,
                 num_workers=n_workers_per_chunk,
                 preprocess_dir=preprocess_dir,
@@ -130,6 +132,7 @@ class DataloaderArgs(BaseModel):
     entropy_model_name: str | None = "transformer_100m"
     arrow_batch_size: int = 100
     buffer_size: int = 64
+    file_format: str = "arrow"
 
     pad_to_max_length: bool = True
     max_encoder_seq_length: int = 12288
@@ -151,6 +154,7 @@ class DataloaderArgs(BaseModel):
         for dataset_path in self.sources:
             shuffle_rng_state = get_rng_state(self.seed + 1, rank, world_size)
             arrow_iterator = distribute_data_to_rank(
+                file_format=self.file_format,
                 dataset_path=os.path.join(self.root_dir, dataset_path),
                 preprocess_dir=self.preprocess_dir,
                 entropy_model_name=self.entropy_model_name,
@@ -238,7 +242,7 @@ class LMHarnessArgs(BaseModel):
 
 class ValidationArgs(BaseModel):
     model_config = ConfigDict(extra="forbid")
-    max_steps: int | None = (
+    max_n_docs: int | None = (
         None  # If None the whole validation file is used -> /!\ This number of steps is gpu dependent (100 max steps on 8 gpus = 800 steps on 1 gpu)
     )
     use_val_from_train_src: bool = True  # Use the validation set from training sources
@@ -248,8 +252,8 @@ class ValidationArgs(BaseModel):
 
 class EvalArgs(BaseModel):
     model_config = ConfigDict(extra="forbid")
-    dump_dir: str
-    ckpt_dir: str
+    dump_dir: str | None = None
+    ckpt_dir: str | None = None
     metric_log_dir: str | None = None
     generator: PackedCausalTransformerGeneratorArgs = (
         PackedCausalTransformerGeneratorArgs()
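
Example (mine, not in the diff): with `dump_dir` and `ckpt_dir` optional, `EvalArgs` now validates with defaults, assuming its remaining fields also carry defaults:

```python
from bytelatent.args import EvalArgs

args = EvalArgs()  # previously raised a validation error: dump_dir and ckpt_dir were required
assert args.dump_dir is None and args.ckpt_dir is None
```
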
bytelatent/data/iterators/abstract_iterator.py CHANGED
@@ -2,6 +2,8 @@
 import abc
 from typing import Any, Generator, Generic, TypeVar
 
+import pydantic
+
 T = TypeVar("T")
 C = TypeVar("C")
 
@@ -23,6 +25,10 @@ class IteratorState(Generic[C]):
     pass
 
 
+class PydanticIteratorState(pydantic.BaseModel, IteratorState):
+    model_config = pydantic.ConfigDict(extra="forbid")
+
+
 def get_state_and_refresh(iterator: StatefulIterator):
     # Re-init dataloader and iterator is necessary since get_state()
     # on mp iterator shuts down MP to correctly persist state and it needs
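
To illustrate what the new base class buys (example mine, with a hypothetical `ToyState`): states subclass `PydanticIteratorState` once instead of mixing `BaseModel` with `IteratorState`, and `extra="forbid"` makes misspelled or stale fields fail loudly when a persisted state is reloaded:

```python
import pydantic

from bytelatent.data.iterators.abstract_iterator import PydanticIteratorState


class ToyState(PydanticIteratorState):  # hypothetical state, for illustration only
    position: int

    def build(self):
        raise NotImplementedError


ToyState(position=0)  # validates fine
try:
    ToyState(position=0, positon=1)  # misspelled field name
except pydantic.ValidationError:
    print("extra='forbid' rejected the unknown field")
```
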
bytelatent/data/iterators/arrow_iterator.py CHANGED
@@ -15,13 +15,16 @@ from pydantic import BaseModel, ConfigDict
 from bytelatent import ByteLatentError
 from bytelatent.data.data_types import BltExample
 from bytelatent.data.file_util import get_fs
-from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
+from bytelatent.data.iterators.abstract_iterator import (
+    PydanticIteratorState,
+    StatefulIterator,
+)
 from bytelatent.preprocess.preprocess_entropies import get_id_key, get_text
 
 logger = getLogger(__name__)
 
 
-class ArrowFileIteratorState(BaseModel, IteratorState):
+class ArrowFileIteratorState(PydanticIteratorState):
     model_config = ConfigDict(extra="forbid")
     file_path: str | None
     row_num: int
@@ -110,39 +113,51 @@ class ArrowFileIterator(StatefulIterator):
         logger.info("Arrow iterator using fs=%s", self.fs)
 
         if dataset_files is None:
-            # Prepare arrow shards
-            jsonl_file = file_path
-            parts = re.match(
-                r"(.+)\.chunk\.[0-9]+\.jsonl", os.path.basename(jsonl_file)
-            )
-            assert parts is not None
-            dataset = parts.group(1)
-            data_dir = os.path.join(preprocess_dir, dataset, entropy_model_name)
-            data_dir_with_glob = os.path.join(
-                data_dir, f"{os.path.basename(jsonl_file)}.shard_*.arrow"
-            )
-            if self.fs is None:
-                self.fs = get_fs(data_dir_with_glob, s3_profile=s3_profile)
-                if isinstance(self.fs, s3fs.S3FileSystem):
-                    self.filesystem_type = "s3"
-                else:
-                    self.filesystem_type = "file"
-            shard_files = self.fs.glob(data_dir_with_glob)
-
-            for s in shard_files:
-                complete_file = os.path.join(
-                    data_dir, f"{os.path.basename(s)}.complete"
-                )
-
-                if not self.fs.exists(complete_file):
-                    raise ValueError(f"Missing .complete for input file: {s}")
-
-            shard_files = sorted(shard_files, key=shard_sort_key)
-            if len(shard_files) == 0:
-                raise ByteLatentError(
-                    f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow"
-                )
-            self.dataset_files = [f for f in shard_files]
+            assert (
+                file_path is not None
+            ), "Must specify file_path if dataset_files is None"
+            if file_format == "json":
+                if self.fs is None:
+                    self.fs = get_fs(file_path, s3_profile=s3_profile)
+                    if isinstance(self.fs, s3fs.S3FileSystem):
+                        self.filesystem_type = "s3"
+                    else:
+                        self.filesystem_type = "file"
+                self.dataset_files = [file_path]
+            else:
+                # Prepare arrow shards
+                jsonl_file = file_path
+                parts = re.match(
+                    r"(.+)\.chunk\.[0-9]+\.jsonl", os.path.basename(jsonl_file)
+                )
+                assert parts is not None
+                dataset = parts.group(1)
+                data_dir = os.path.join(preprocess_dir, dataset, entropy_model_name)
+                data_dir_with_glob = os.path.join(
+                    data_dir, f"{os.path.basename(jsonl_file)}.shard_*.arrow"
+                )
+                if self.fs is None:
+                    self.fs = get_fs(data_dir_with_glob, s3_profile=s3_profile)
+                    if isinstance(self.fs, s3fs.S3FileSystem):
+                        self.filesystem_type = "s3"
+                    else:
+                        self.filesystem_type = "file"
+                shard_files = self.fs.glob(data_dir_with_glob)
+
+                for s in shard_files:
+                    complete_file = os.path.join(
+                        data_dir, f"{os.path.basename(s)}.complete"
+                    )
+
+                    if not self.fs.exists(complete_file):
+                        raise ValueError(f"Missing .complete for input file: {s}")
+
+                shard_files = sorted(shard_files, key=shard_sort_key)
+                if len(shard_files) == 0:
+                    raise ByteLatentError(
+                        f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow"
+                    )
+                self.dataset_files = [f for f in shard_files]
         else:
             self.preprocess_dir = None
             self.dataset_files = dataset_files
bytelatent/data/iterators/dev_iterators.py ADDED
@@ -0,0 +1,78 @@
+import pandas as pd
+from pydantic import ConfigDict
+
+from bytelatent.data.data_types import BltExample
+from bytelatent.data.iterators.abstract_iterator import (
+    PydanticIteratorState,
+    StatefulIterator,
+)
+
+
+class BltTestIteratorState(PydanticIteratorState):
+    model_config = ConfigDict(extra="forbid")
+    position: int
+    total: int
+
+    def build(self):
+        blt_iter = BltTestIteratorState(total=self.total)
+        blt_iter.position = self.position
+        return blt_iter
+
+
+class BltTestIterator(StatefulIterator):
+    def __init__(self, total: int):
+        self.position = 0
+        self.total = total
+
+    def get_state(self):
+        return BltTestIteratorState(position=self.position, total=self.total)
+
+    def create_iter(self):
+        for i in range(self.total):
+            self.position += 1
+            yield BltExample(
+                sample_id=f"test_{i}",
+                text=f"This is some test {i} text.",
+                tokens=None,
+                mask=None,
+                entropies=None,
+                patch_lengths=None,
+            )
+
+
+class BltTestWithEntropiesIteratorState(PydanticIteratorState):
+    model_config = ConfigDict(extra="forbid")
+    position: int
+    total: int
+
+    def build(self):
+        blt_iter = BltTestWithEntropiesIteratorState(total=self.total)
+        blt_iter.position = self.position
+        return blt_iter
+
+
+class BltTestWithEntropiesIterator(StatefulIterator):
+    def __init__(self, total: int):
+        self.position = 0
+        self.total = total
+
+    def get_state(self):
+        return BltTestIteratorState(position=self.position, total=self.total)
+
+    def create_iter(self):
+        text = "Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin."
+        df = pd.read_json("fixtures/tokens_with_entropies.json")
+        tokens = df["token_ids"].tolist()
+        entropies = df["entropies"].tolist()
+        # BOS and EOS
+        assert len(tokens) == len(text) + 2
+        for i in range(self.total):
+            self.position += 1
+            yield BltExample(
+                sample_id=f"test_{i}",
+                text=text,
+                tokens=tokens,
+                mask=[True] * len(tokens),
+                entropies=entropies,
+                patch_lengths=None,
+            )
bytelatent/data/iterators/limit_iterator.py ADDED
@@ -0,0 +1,47 @@
+from pydantic import ConfigDict
+
+from bytelatent.data.iterators.abstract_iterator import (
+    PydanticIteratorState,
+    StatefulIterator,
+)
+from bytelatent.data.iterators.arrow_iterator import ArrowFileIteratorState
+from bytelatent.data.iterators.dev_iterators import BltTestIteratorState
+
+
+class LimitIteratorState(PydanticIteratorState):
+    model_config = ConfigDict(extra="forbid")
+    base_iterator_state: (
+        BltTestIteratorState | ArrowFileIteratorState | PydanticIteratorState
+    )
+    n_yielded: int
+    limit: int
+
+    def build(self) -> "LimitIterator":
+        return LimitIterator(
+            base_iterator=self.base_iterator_state.build(),
+            n_yielded=self.n_yielded,
+            limit=self.limit,
+        )
+
+
+class LimitIterator(StatefulIterator):
+    def __init__(self, base_iterator: StatefulIterator, limit: int, n_yielded: int = 0):
+        self.base_iterator = base_iterator
+        self.n_yielded = n_yielded
+        self.limit = limit
+
+    def get_state(self):
+        return LimitIteratorState(
+            base_iterator_state=self.base_iterator.get_state(),
+            n_yielded=self.n_yielded,
+            limit=self.limit,
+        )
+
+    def create_iter(self):
+        iterator = self.base_iterator.create_iter()
+        try:
+            while self.n_yielded < self.limit or self.limit < 0:
+                yield next(iterator)
+                self.n_yielded += 1
+        except StopIteration:
+            pass
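
A quick usage sketch (mine, not part of the diff), matching the semantics tested below: the wrapper caps any stateful iterator, a negative limit disables the cap, and `get_state()` nests the base iterator's state so the cap survives checkpointing:

```python
from bytelatent.data.iterators.dev_iterators import BltTestIterator
from bytelatent.data.iterators.limit_iterator import LimitIterator

# Cap a 10-example stream at 5 examples.
limited = LimitIterator(BltTestIterator(total=10), limit=5)
consumed = list(limited.create_iter())
assert len(consumed) == 5
assert limited.get_state().n_yielded == 5  # state also nests BltTestIteratorState
```
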
bytelatent/data/iterators/looping_iterator.py CHANGED
@@ -1,14 +1,16 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
-from pydantic import BaseModel
 
-from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
+from bytelatent.data.iterators.abstract_iterator import (
+    PydanticIteratorState,
+    StatefulIterator,
+)
 from bytelatent.data.iterators.arrow_iterator import (
     ArrowFileIterator,
     ArrowFileIteratorState,
 )
 
 
-class LoopingIteratorState(BaseModel, IteratorState):
+class LoopingIteratorState(PydanticIteratorState):
     file_iterator_state: ArrowFileIteratorState
     epoch: int
 
bytelatent/data/iterators/multiprocess_iterator.py CHANGED
@@ -6,16 +6,20 @@ from multiprocessing.synchronize import Event as EventClass
 from queue import Empty, Full
 
 import numpy as np
-from pydantic import BaseModel, ConfigDict
+from pydantic import ConfigDict
 
 from bytelatent.data.data_types import Batch
-from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
+from bytelatent.data.iterators.abstract_iterator import (
+    IteratorState,
+    PydanticIteratorState,
+    StatefulIterator,
+)
 from bytelatent.data.iterators.packing_iterator import PackingIteratorState
 
 logger = logging.getLogger()
 
 
-class MultiprocessIteratorState(BaseModel, IteratorState):
+class MultiprocessIteratorState(PydanticIteratorState):
     model_config = ConfigDict(extra="forbid")
     base_iterator_state: PackingIteratorState
     n_batches_to_prefetch: int
bytelatent/data/iterators/packing_iterator.py CHANGED
@@ -5,7 +5,10 @@ import numpy as np
 from pydantic import BaseModel, ConfigDict
 
 from bytelatent.data.data_types import Batch, BltSequence
-from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
+from bytelatent.data.iterators.abstract_iterator import (
+    PydanticIteratorState,
+    StatefulIterator,
+)
 from bytelatent.data.iterators.sampling_iterator import SamplingIteratorState
 
 
@@ -20,7 +23,7 @@ class PackingArgs(BaseModel):
     tokenizer_name: str
 
 
-class PackingIteratorState(BaseModel, IteratorState):
+class PackingIteratorState(PydanticIteratorState):
     model_config = ConfigDict(extra="forbid")
     sequence_iterator_state: SamplingIteratorState
     packing_args: PackingArgs
bytelatent/data/iterators/preprocess_iterator.py CHANGED
@@ -5,20 +5,29 @@ import torch
 from pydantic import BaseModel, ConfigDict
 
 from bytelatent.data.data_types import BltExample
-from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
+from bytelatent.data.iterators.abstract_iterator import (
+    PydanticIteratorState,
+    StatefulIterator,
+)
 from bytelatent.data.iterators.arrow_iterator import (
     ArrowFileIterator,
     ArrowFileIteratorState,
 )
-from bytelatent.data.iterators.looping_iterator import LoopingIteratorState
+from bytelatent.data.iterators.limit_iterator import LimitIterator, LimitIteratorState
+from bytelatent.data.iterators.looping_iterator import (
+    LoopingIterator,
+    LoopingIteratorState,
+)
 from bytelatent.data.patcher import Patcher, PatcherArgs, PatchingModeEnum
 from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
 from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
 
 
-class PreprocessIteratorState(BaseModel, IteratorState):
+class PreprocessIteratorState(PydanticIteratorState):
     model_config = ConfigDict(extra="forbid")
-    arrow_file_iterator_state: ArrowFileIteratorState | LoopingIteratorState
+    arrow_file_iterator_state: (
+        ArrowFileIteratorState | LoopingIteratorState | LimitIteratorState
+    )
     add_tokens: bool
     add_patches: bool
     tokenizer_args: TokenizerArgs
@@ -43,7 +52,7 @@ class PreprocessIterator(StatefulIterator):
 
     def __init__(
         self,
-        arrow_iterator: ArrowFileIterator,
+        arrow_iterator: ArrowFileIterator | LoopingIterator | LimitIterator,
         *,
        patcher_args: PatcherArgs,
         tokenizer_args: TokenizerArgs,
bytelatent/data/iterators/sampling_iterator.py CHANGED
@@ -2,13 +2,16 @@
 from typing import Any
 
 import numpy as np
-from pydantic import BaseModel, ConfigDict
+from pydantic import ConfigDict
 
-from bytelatent.data.iterators.abstract_iterator import StatefulIterator
+from bytelatent.data.iterators.abstract_iterator import (
+    PydanticIteratorState,
+    StatefulIterator,
+)
 from bytelatent.data.iterators.sequence_iterator import SequenceIteratorState
 
 
-class SamplingIteratorState(BaseModel):
+class SamplingIteratorState(PydanticIteratorState):
     model_config = ConfigDict(extra="forbid")
     rng_state: dict[str, Any]
     source_to_weight: dict[str, float]
bytelatent/data/iterators/sequence_iterator.py CHANGED
@@ -6,7 +6,10 @@ import numpy as np
 from pydantic import BaseModel, ConfigDict
 
 from bytelatent.data.data_types import BltSequence
-from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
+from bytelatent.data.iterators.abstract_iterator import (
+    PydanticIteratorState,
+    StatefulIterator,
+)
 from bytelatent.data.iterators.preprocess_iterator import (
     PreprocessIterator,
     PreprocessIteratorState,
@@ -21,11 +24,12 @@ class SequencePackingArgs(BaseModel):
     buffer_size: int
 
 
-class SequenceIteratorState(BaseModel, IteratorState):
+class SequenceIteratorState(PydanticIteratorState):
     model_config = ConfigDict(extra="forbid")
     sequence_packing_args: SequencePackingArgs
     preprocess_iterator_state: PreprocessIteratorState
-    rng_state: dict[str, Any]
+    # If None, rng is disabled.
+    rng_state: dict[str, Any] | None
 
     def build(self):
         preprocess_iterator = self.preprocess_iterator_state.build()
@@ -41,22 +45,25 @@ class SequenceIterator(StatefulIterator):
         self,
         preprocess_iterator: PreprocessIterator,
         *,
-        rng_state: dict[str, Any],
+        rng_state: dict[str, Any] | None,
         sequence_packing_args: SequencePackingArgs,
     ):
         self.preprocess_iterator = preprocess_iterator
         self.sequence_packing_args = sequence_packing_args
         self.output_seq_len = sequence_packing_args.output_seq_len
         self.buffer_size = sequence_packing_args.buffer_size
-        self.rng = np.random.default_rng()
-        self.rng.bit_generator.state = rng_state
+        if rng_state is None:
+            self.rng = None
+        else:
+            self.rng = np.random.default_rng()
+            self.rng.bit_generator.state = rng_state
 
     def get_state(self):
         # TODO: need to also persist the current shuffle buffer
         return SequenceIteratorState(
             sequence_packing_args=self.sequence_packing_args,
             preprocess_iterator_state=self.preprocess_iterator.get_state(),
-            rng_state=self.rng.bit_generator.state,
+            rng_state=None if self.rng is None else self.rng.bit_generator.state,
         )
 
     def create_iter(self):
@@ -114,7 +121,12 @@ class SequenceIterator(StatefulIterator):
 
         seq_patch_lengths: list[list[int]] = x_patches.tolist()
         assert len(seq_patch_lengths) == self.buffer_size
-        for idx in self.rng.permutation(len(seq_patch_lengths)):
+        if self.rng is None:
+            permutations = list(range(len(seq_patch_lengths)))
+        else:
+            permutations = self.rng.permutation(len(seq_patch_lengths))
+
+        for idx in permutations:
             assert len(seq_patch_lengths[idx]) == self.output_seq_len
             assert (
                 sum(seq_patch_lengths[idx])
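
Sketch of the rng-disabled path (mine, not part of the diff; `preprocess_iterator` is assumed to be built as in `test_iters.py`, and the `SequencePackingArgs` values are illustrative):

```python
from bytelatent.data.iterators.sequence_iterator import (
    SequenceIterator,
    SequencePackingArgs,
)

sequence_iterator = SequenceIterator(
    preprocess_iterator,
    rng_state=None,  # disables all rng ops: sequences are emitted in buffer order
    sequence_packing_args=SequencePackingArgs(output_seq_len=4096, buffer_size=64),
)
assert sequence_iterator.get_state().rng_state is None
```
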
bytelatent/data/iterators/test_arrow_iterator.py CHANGED
@@ -6,7 +6,10 @@ import pyarrow as pa
 import pyarrow.dataset  # pyright: ignore
 
 from bytelatent.constants import BLT_DATA
-from bytelatent.data.iterators.arrow_iterator import ArrowFileIteratorState
+from bytelatent.data.iterators.arrow_iterator import (
+    ArrowFileIterator,
+    ArrowFileIteratorState,
+)
 
 ENTROPY_MODEL = "transformer_100m"
 ARROW_TEST_DATA_1 = str(BLT_DATA / "stackexchange.chunk.00.jsonl.shard_00.arrow")
@@ -93,3 +96,19 @@ def test_basic_arrow_file():
             i += 1
             if i >= len(expected_ids):
                 break
+
+
+def test_read_jsonl_from_arrow():
+    arrow_iterator = ArrowFileIterator(
+        file_path="fixtures/test_docs.jsonl",
+        num_workers=1,
+        worker_id=0,
+        preprocess_dir=None,
+        entropy_model_name=None,
+        file_format="json",
+        arrow_batch_size=100,
+    )
+    iterator = arrow_iterator.create_iter()
+    for i, example in enumerate(iterator):
+        assert example.sample_id == str(i)
+        assert example.text == f"test_{i}"
bytelatent/data/iterators/test_iters.py CHANGED
@@ -1,83 +1,15 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
-import pandas as pd
-from pydantic import BaseModel
 
 from bytelatent.constants import BLT_DATA
-from bytelatent.data.data_types import BltExample
-from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
+from bytelatent.data.iterators.dev_iterators import (
+    BltTestIterator,
+    BltTestWithEntropiesIterator,
+)
 from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator
 from bytelatent.data.patcher import PatcherArgs, PatchingModeEnum
 from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
 
 
-class BltTestIteratorState(BaseModel, IteratorState):
-    position: int
-    total: int
-
-    def build(self):
-        blt_iter = BltTestIteratorState(total=self.total)
-        blt_iter.position = self.position
-        return blt_iter
-
-
-class BltTestIterator(StatefulIterator):
-    def __init__(self, total: int):
-        self.position = 0
-        self.total = total
-
-    def get_state(self):
-        return BltTestIteratorState(position=self.position, total=self.total)
-
-    def create_iter(self):
-        for i in range(self.total):
-            self.position += 1
-            yield BltExample(
-                sample_id=f"test_{i}",
-                text=f"This is some test {i} text.",
-                tokens=None,
-                mask=None,
-                entropies=None,
-                patch_lengths=None,
-            )
-
-
-class BltTestWithEntropiesIteratorState(BaseModel, IteratorState):
-    position: int
-    total: int
-
-    def build(self):
-        blt_iter = BltTestWithEntropiesIteratorState(total=self.total)
-        blt_iter.position = self.position
-        return blt_iter
-
-
-class BltTestWithEntropiesIterator(StatefulIterator):
-    def __init__(self, total: int):
-        self.position = 0
-        self.total = total
-
-    def get_state(self):
-        return BltTestIteratorState(position=self.position, total=self.total)
-
-    def create_iter(self):
-        text = "Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin."
-        df = pd.read_json("fixtures/tokens_with_entropies.json")
-        tokens = df["token_ids"].tolist()
-        entropies = df["entropies"].tolist()
-        # BOS and EOS
-        assert len(tokens) == len(text) + 2
-        for i in range(self.total):
-            self.position += 1
-            yield BltExample(
-                sample_id=f"test_{i}",
-                text=text,
-                tokens=tokens,
-                mask=[True] * len(tokens),
-                entropies=entropies,
-                patch_lengths=None,
-            )
-
-
 def test_preprocess_iter():
     total = 3
     tokenizer_args = TokenizerArgs(
bytelatent/data/iterators/test_limit_iterator.py ADDED
@@ -0,0 +1,45 @@
+from bytelatent.data.iterators.dev_iterators import BltTestIterator
+from bytelatent.data.iterators.limit_iterator import LimitIterator
+
+
+def test_limit_iterator():
+    total = 10
+    limit = 5
+    base_iterator = BltTestIterator(total=total)
+    limit_iterator = LimitIterator(base_iterator, limit=limit)
+    iterator = limit_iterator.create_iter()
+    n = 0
+    for example in iterator:
+        assert example.sample_id == f"test_{n}"
+        n += 1
+    assert n == limit
+
+    limit = 10
+    base_iterator = BltTestIterator(total=total)
+    limit_iterator = LimitIterator(base_iterator, limit=limit)
+    iterator = limit_iterator.create_iter()
+    n = 0
+    for example in iterator:
+        assert example.sample_id == f"test_{n}"
+        n += 1
+    assert n == limit == total
+
+    limit = 20
+    base_iterator = BltTestIterator(total=total)
+    limit_iterator = LimitIterator(base_iterator, limit=limit)
+    iterator = limit_iterator.create_iter()
+    n = 0
+    for example in iterator:
+        assert example.sample_id == f"test_{n}"
+        n += 1
+    assert n == total
+
+    limit = -1
+    base_iterator = BltTestIterator(total=total)
+    limit_iterator = LimitIterator(base_iterator, limit=limit)
+    iterator = limit_iterator.create_iter()
+    n = 0
+    for example in iterator:
+        assert example.sample_id == f"test_{n}"
+        n += 1
+    assert n == total
fixtures/test_docs.jsonl ADDED
@@ -0,0 +1,3 @@
+{"sample_id": "0", "text": "test_0"}
+{"sample_id": "1", "text": "test_1"}
+{"sample_id": "2", "text": "test_2"}