Update iterator inheritance, pass file format args, limit iterator (#63)
- Create a common base class for all iterator states to inherit from
- Add a limit iterator that we can use in evals
- Modify ArrowFileIterator so it skips arrow path inference when file_format='json'
- Make EvalArgs valid by default (dump_dir and ckpt_dir are now optional)
- Move the testing iterators into a common module so they can be used from multiple test files
- Make SequenceIterator accept a None rng_state, which disables all rng ops (mainly for eval)
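
Taken together, these changes let an eval build a bounded pipeline over raw jsonl. A minimal sketch of the intended wiring (the file path and limit are illustrative, not taken from the diff):

```python
from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
from bytelatent.data.iterators.limit_iterator import LimitIterator

# file_format="json" reads the jsonl file as-is instead of inferring
# preprocessed arrow shard paths from its name.
base = ArrowFileIterator(
    file_path="fixtures/test_docs.jsonl",
    file_format="json",
    worker_id=0,
    num_workers=1,
    preprocess_dir=None,
    entropy_model_name=None,
    arrow_batch_size=100,
)
# Cap the stream for a quick eval; states stay serializable pydantic
# models throughout, so checkpoint/resume keeps working.
limited = LimitIterator(base, limit=100)
for example in limited.create_iter():
    ...  # score each BltExample
```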
Test Plan:
- `pytest bytelatent`
- `python -m bytelatent.train config=../internal-blt/configs/entropy_model.yaml logging.wandb=null eval=null`
Files changed:
- .gitignore +2 -0
- bytelatent/args.py +7 -3
- bytelatent/data/iterators/abstract_iterator.py +6 -0
- bytelatent/data/iterators/arrow_iterator.py +47 -32
- bytelatent/data/iterators/dev_iterators.py +78 -0
- bytelatent/data/iterators/limit_iterator.py +47 -0
- bytelatent/data/iterators/looping_iterator.py +5 -3
- bytelatent/data/iterators/multiprocess_iterator.py +7 -3
- bytelatent/data/iterators/packing_iterator.py +5 -2
- bytelatent/data/iterators/preprocess_iterator.py +14 -5
- bytelatent/data/iterators/sampling_iterator.py +6 -3
- bytelatent/data/iterators/sequence_iterator.py +20 -8
- bytelatent/data/iterators/test_arrow_iterator.py +20 -1
- bytelatent/data/iterators/test_iters.py +4 -72
- bytelatent/data/iterators/test_limit_iterator.py +45 -0
- fixtures/test_docs.jsonl +3 -0
.gitignore
CHANGED
@@ -168,3 +168,5 @@ figures/
 internal/
 jobs_parallel-copy/
 wandb/
+*.ipynb
+
bytelatent/args.py
CHANGED
@@ -72,6 +72,7 @@ def distribute_data_to_rank(
     arrow_batch_size: int,
     rank: int,
     world_size: int,
+    file_format: str,
     s3_profile: str | None = None,
     file_pattern: str = TRAIN_DATA_FILE_PATTERN,
 ) -> ArrowFileIterator:
@@ -85,6 +86,7 @@ def distribute_data_to_rank(
         rank_to_arrow_iterator_params.append(
             ArrowFileIterator(
                 file_path=chunk_path,
+                file_format=file_format,
                 worker_id=worker_id,
                 num_workers=n_workers_per_chunk,
                 preprocess_dir=preprocess_dir,
@@ -130,6 +132,7 @@ class DataloaderArgs(BaseModel):
     entropy_model_name: str | None = "transformer_100m"
     arrow_batch_size: int = 100
     buffer_size: int = 64
+    file_format: str = "arrow"
 
     pad_to_max_length: bool = True
     max_encoder_seq_length: int = 12288
@@ -151,6 +154,7 @@ class DataloaderArgs(BaseModel):
         for dataset_path in self.sources:
             shuffle_rng_state = get_rng_state(self.seed + 1, rank, world_size)
             arrow_iterator = distribute_data_to_rank(
+                file_format=self.file_format,
                 dataset_path=os.path.join(self.root_dir, dataset_path),
                 preprocess_dir=self.preprocess_dir,
                 entropy_model_name=self.entropy_model_name,
@@ -238,7 +242,7 @@ class LMHarnessArgs(BaseModel):
 
 class ValidationArgs(BaseModel):
     model_config = ConfigDict(extra="forbid")
-    max_steps: int | None = (
+    max_n_docs: int | None = (
        None  # If None the whole validation file is used -> /!\ This number of steps is gpu dependent (100 max steps on 8 gpus = 800 steps on 1 gpu)
     )
     use_val_from_train_src: bool = True  # Use the validation set from training sources
@@ -248,8 +252,8 @@ class ValidationArgs(BaseModel):
 
 class EvalArgs(BaseModel):
     model_config = ConfigDict(extra="forbid")
-    dump_dir: str
-    ckpt_dir: str
+    dump_dir: str | None = None
+    ckpt_dir: str | None = None
     metric_log_dir: str | None = None
     generator: PackedCausalTransformerGeneratorArgs = (
         PackedCausalTransformerGeneratorArgs()
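
The last hunk is what "Make EvalArgs valid" refers to: with `dump_dir` and `ckpt_dir` optional, the model now validates with no arguments. A quick check, assuming the EvalArgs fields outside this hunk also have defaults:

```python
from bytelatent.args import EvalArgs

# Before this change, EvalArgs() raised a pydantic ValidationError because
# dump_dir and ckpt_dir were required; now it constructs cleanly.
eval_args = EvalArgs()
assert eval_args.dump_dir is None
assert eval_args.ckpt_dir is None
```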
bytelatent/data/iterators/abstract_iterator.py
CHANGED
@@ -2,6 +2,8 @@
 import abc
 from typing import Any, Generator, Generic, TypeVar
 
+import pydantic
+
 T = TypeVar("T")
 C = TypeVar("C")
 
@@ -23,6 +25,10 @@ class IteratorState(Generic[C]):
     pass
 
 
+class PydanticIteratorState(pydantic.BaseModel, IteratorState):
+    model_config = pydantic.ConfigDict(extra="forbid")
+
+
 def get_state_and_refresh(iterator: StatefulIterator):
     # Re-init dataloader and iterator is necessary since get_state()
     # on mp iterator shuts down MP to correctly persist state and it needs
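
`PydanticIteratorState` is the common state class from the summary: it combines `pydantic.BaseModel` (strict `extra="forbid"` validation) with the `IteratorState` contract, so every state in the pipeline can be dumped to plain data and rebuilt the same way. A minimal sketch, with a hypothetical `CounterState` standing in for the real states:

```python
from bytelatent.data.iterators.abstract_iterator import PydanticIteratorState

class CounterState(PydanticIteratorState):
    position: int

    def build(self):
        # A real state would reconstruct its StatefulIterator here;
        # this sketch just resumes a plain range.
        return iter(range(self.position, 10))

state = CounterState(position=3)
payload = state.model_dump()        # plain dict, safe to checkpoint
restored = CounterState(**payload)  # validated on the way back in
assert restored == state
```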
bytelatent/data/iterators/arrow_iterator.py
CHANGED
@@ -15,13 +15,16 @@ from pydantic import BaseModel, ConfigDict
 from bytelatent import ByteLatentError
 from bytelatent.data.data_types import BltExample
 from bytelatent.data.file_util import get_fs
-from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
+from bytelatent.data.iterators.abstract_iterator import (
+    PydanticIteratorState,
+    StatefulIterator,
+)
 from bytelatent.preprocess.preprocess_entropies import get_id_key, get_text
 
 logger = getLogger(__name__)
 
 
-class ArrowFileIteratorState(BaseModel, IteratorState):
+class ArrowFileIteratorState(PydanticIteratorState):
     model_config = ConfigDict(extra="forbid")
     file_path: str | None
     row_num: int
@@ -110,39 +113,51 @@ class ArrowFileIterator(StatefulIterator):
         logger.info("Arrow iterator using fs=%s", self.fs)
 
         if dataset_files is None:
-            # Prepare arrow shards
-            jsonl_file = file_path
-            parts = re.match(
-                r"(.+)\.chunk\.[0-9]+\.jsonl", os.path.basename(jsonl_file)
-            )
-            assert parts is not None
-            dataset = parts.group(1)
-            data_dir = os.path.join(preprocess_dir, dataset, entropy_model_name)
-            data_dir_with_glob = os.path.join(
-                data_dir, f"{os.path.basename(jsonl_file)}.shard_*.arrow"
-            )
-            if self.fs is None:
-                self.fs = get_fs(data_dir_with_glob, s3_profile=s3_profile)
-            if isinstance(self.fs, s3fs.S3FileSystem):
-                self.filesystem_type = "s3"
-            else:
-                self.filesystem_type = "file"
-            shard_files = self.fs.glob(data_dir_with_glob)
-
-            for s in shard_files:
-                complete_file = os.path.join(
-                    data_dir, f"{os.path.basename(s)}.complete"
-                )
-
-                if not self.fs.exists(complete_file):
-                    raise ValueError(f"Missing .complete for input file: {s}")
-
-            shard_files = sorted(shard_files, key=shard_sort_key)
-            if len(shard_files) == 0:
-                raise ByteLatentError(
-                    f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow"
-                )
-            self.dataset_files = [f for f in shard_files]
+            assert (
+                file_path is not None
+            ), "Must specify file_Path if dataset_files is None"
+            if file_format == "json":
+                if self.fs is None:
+                    self.fs = get_fs(file_path, s3_profile=s3_profile)
+                if isinstance(self.fs, s3fs.S3FileSystem):
+                    self.filesystem_type = "s3"
+                else:
+                    self.filesystem_type = "file"
+                self.dataset_files = [file_path]
+            else:
+                # Prepare arrow shards
+                jsonl_file = file_path
+                parts = re.match(
+                    r"(.+)\.chunk\.[0-9]+\.jsonl", os.path.basename(jsonl_file)
+                )
+                assert parts is not None
+                dataset = parts.group(1)
+                data_dir = os.path.join(preprocess_dir, dataset, entropy_model_name)
+                data_dir_with_glob = os.path.join(
+                    data_dir, f"{os.path.basename(jsonl_file)}.shard_*.arrow"
+                )
+                if self.fs is None:
+                    self.fs = get_fs(data_dir_with_glob, s3_profile=s3_profile)
+                if isinstance(self.fs, s3fs.S3FileSystem):
+                    self.filesystem_type = "s3"
+                else:
+                    self.filesystem_type = "file"
+                shard_files = self.fs.glob(data_dir_with_glob)
+
+                for s in shard_files:
+                    complete_file = os.path.join(
+                        data_dir, f"{os.path.basename(s)}.complete"
+                    )
+
+                    if not self.fs.exists(complete_file):
+                        raise ValueError(f"Missing .complete for input file: {s}")
+
+                shard_files = sorted(shard_files, key=shard_sort_key)
+                if len(shard_files) == 0:
+                    raise ByteLatentError(
+                        f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow"
+                    )
+                self.dataset_files = [f for f in shard_files]
         else:
             self.preprocess_dir = None
             self.dataset_files = dataset_files
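
For the default arrow format, shard inference is unchanged: the jsonl basename determines the dataset, and shards are looked up under the preprocessing directory. A worked sketch of the path arithmetic above, with hypothetical inputs:

```python
import os
import re

# Hypothetical inputs mirroring the logic above.
file_path = "data/stackexchange.chunk.00.jsonl"
preprocess_dir = "preprocess"
entropy_model_name = "transformer_100m"

parts = re.match(r"(.+)\.chunk\.[0-9]+\.jsonl", os.path.basename(file_path))
dataset = parts.group(1)  # "stackexchange"
data_dir = os.path.join(preprocess_dir, dataset, entropy_model_name)
glob_pattern = os.path.join(
    data_dir, f"{os.path.basename(file_path)}.shard_*.arrow"
)
# -> preprocess/stackexchange/transformer_100m/stackexchange.chunk.00.jsonl.shard_*.arrow
# Each matching shard must also have a sibling "<shard>.complete" marker,
# otherwise the iterator raises ValueError.
print(glob_pattern)
```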
bytelatent/data/iterators/dev_iterators.py
ADDED
@@ -0,0 +1,78 @@
+import pandas as pd
+from pydantic import ConfigDict
+
+from bytelatent.data.data_types import BltExample
+from bytelatent.data.iterators.abstract_iterator import (
+    PydanticIteratorState,
+    StatefulIterator,
+)
+
+
+class BltTestIteratorState(PydanticIteratorState):
+    model_config = ConfigDict(extra="forbid")
+    position: int
+    total: int
+
+    def build(self):
+        blt_iter = BltTestIteratorState(total=self.total)
+        blt_iter.position = self.position
+        return blt_iter
+
+
+class BltTestIterator(StatefulIterator):
+    def __init__(self, total: int):
+        self.position = 0
+        self.total = total
+
+    def get_state(self):
+        return BltTestIteratorState(position=self.position, total=self.total)
+
+    def create_iter(self):
+        for i in range(self.total):
+            self.position += 1
+            yield BltExample(
+                sample_id=f"test_{i}",
+                text=f"This is some test {i} text.",
+                tokens=None,
+                mask=None,
+                entropies=None,
+                patch_lengths=None,
+            )
+
+
+class BltTestWithEntropiesIteratorState(PydanticIteratorState):
+    model_config = ConfigDict(extra="forbid")
+    position: int
+    total: int
+
+    def build(self):
+        blt_iter = BltTestWithEntropiesIteratorState(total=self.total)
+        blt_iter.position = self.position
+        return blt_iter
+
+
+class BltTestWithEntropiesIterator(StatefulIterator):
+    def __init__(self, total: int):
+        self.position = 0
+        self.total = total
+
+    def get_state(self):
+        return BltTestIteratorState(position=self.position, total=self.total)
+
+    def create_iter(self):
+        text = "Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin."
+        df = pd.read_json("fixtures/tokens_with_entropies.json")
+        tokens = df["token_ids"].tolist()
+        entropies = df["entropies"].tolist()
+        # BOS and EOS
+        assert len(tokens) == len(text) + 2
+        for i in range(self.total):
+            self.position += 1
+            yield BltExample(
+                sample_id=f"test_{i}",
+                text=text,
+                tokens=tokens,
+                mask=[True] * len(tokens),
+                entropies=entropies,
+                patch_lengths=None,
+            )
bytelatent/data/iterators/limit_iterator.py
ADDED
@@ -0,0 +1,47 @@
+from pydantic import ConfigDict
+
+from bytelatent.data.iterators.abstract_iterator import (
+    PydanticIteratorState,
+    StatefulIterator,
+)
+from bytelatent.data.iterators.arrow_iterator import ArrowFileIteratorState
+from bytelatent.data.iterators.dev_iterators import BltTestIteratorState
+
+
+class LimitIteratorState(PydanticIteratorState):
+    model_config = ConfigDict(extra="forbid")
+    base_iterator_state: (
+        BltTestIteratorState | ArrowFileIteratorState | PydanticIteratorState
+    )
+    n_yielded: int
+    limit: int
+
+    def build(self) -> "LimitIterator":
+        return LimitIterator(
+            base_iterator=self.base_iterator_state.build(),
+            n_yielded=self.n_yielded,
+            limit=self.limit,
+        )
+
+
+class LimitIterator(StatefulIterator):
+    def __init__(self, base_iterator: StatefulIterator, limit: int, n_yielded: int = 0):
+        self.base_iterator = base_iterator
+        self.n_yielded = n_yielded
+        self.limit = limit
+
+    def get_state(self):
+        return LimitIteratorState(
+            base_iterator_state=self.base_iterator.get_state(),
+            n_yielded=self.n_yielded,
+            limit=self.limit,
+        )
+
+    def create_iter(self):
+        iterator = self.base_iterator.create_iter()
+        try:
+            while self.n_yielded < self.limit or self.limit < 0:
+                yield next(iterator)
+                self.n_yielded += 1
+        except StopIteration:
+            pass
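
Note the loop condition `self.n_yielded < self.limit or self.limit < 0`: any negative limit disables the cap, and because `n_yielded` is part of the persisted state, a rebuilt iterator resumes its count rather than starting over. A small sketch against the dev iterator (the same behavior the new test file exercises):

```python
from bytelatent.data.iterators.dev_iterators import BltTestIterator
from bytelatent.data.iterators.limit_iterator import LimitIterator

# A positive limit caps the stream even if the base has more to give.
capped = LimitIterator(BltTestIterator(total=10), limit=4)
assert len(list(capped.create_iter())) == 4

# A negative limit means "no limit": iteration ends only when the
# base iterator is exhausted.
uncapped = LimitIterator(BltTestIterator(total=10), limit=-1)
assert len(list(uncapped.create_iter())) == 10
```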
bytelatent/data/iterators/looping_iterator.py
CHANGED
@@ -1,14 +1,16 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
-from pydantic import BaseModel
 
-from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
+from bytelatent.data.iterators.abstract_iterator import (
+    PydanticIteratorState,
+    StatefulIterator,
+)
 from bytelatent.data.iterators.arrow_iterator import (
     ArrowFileIterator,
     ArrowFileIteratorState,
 )
 
 
-class LoopingIteratorState(BaseModel, IteratorState):
+class LoopingIteratorState(PydanticIteratorState):
     file_iterator_state: ArrowFileIteratorState
     epoch: int
 
bytelatent/data/iterators/multiprocess_iterator.py
CHANGED
@@ -6,16 +6,20 @@ from multiprocessing.synchronize import Event as EventClass
 from queue import Empty, Full
 
 import numpy as np
-from pydantic import BaseModel, ConfigDict
+from pydantic import ConfigDict
 
 from bytelatent.data.data_types import Batch
-from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
+from bytelatent.data.iterators.abstract_iterator import (
+    IteratorState,
+    PydanticIteratorState,
+    StatefulIterator,
+)
 from bytelatent.data.iterators.packing_iterator import PackingIteratorState
 
 logger = logging.getLogger()
 
 
-class MultiprocessIteratorState(BaseModel, IteratorState):
+class MultiprocessIteratorState(PydanticIteratorState):
     model_config = ConfigDict(extra="forbid")
     base_iterator_state: PackingIteratorState
     n_batches_to_prefetch: int
bytelatent/data/iterators/packing_iterator.py
CHANGED
@@ -5,7 +5,10 @@ import numpy as np
 from pydantic import BaseModel, ConfigDict
 
 from bytelatent.data.data_types import Batch, BltSequence
-from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
+from bytelatent.data.iterators.abstract_iterator import (
+    PydanticIteratorState,
+    StatefulIterator,
+)
 from bytelatent.data.iterators.sampling_iterator import SamplingIteratorState
 
 
@@ -20,7 +23,7 @@ class PackingArgs(BaseModel):
     tokenizer_name: str
 
 
-class PackingIteratorState(BaseModel, IteratorState):
+class PackingIteratorState(PydanticIteratorState):
     model_config = ConfigDict(extra="forbid")
     sequence_iterator_state: SamplingIteratorState
     packing_args: PackingArgs
bytelatent/data/iterators/preprocess_iterator.py
CHANGED
@@ -5,20 +5,29 @@ import torch
 from pydantic import BaseModel, ConfigDict
 
 from bytelatent.data.data_types import BltExample
-from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
+from bytelatent.data.iterators.abstract_iterator import (
+    PydanticIteratorState,
+    StatefulIterator,
+)
 from bytelatent.data.iterators.arrow_iterator import (
     ArrowFileIterator,
     ArrowFileIteratorState,
 )
-from bytelatent.data.iterators.looping_iterator import LoopingIterator, LoopingIteratorState
+from bytelatent.data.iterators.limit_iterator import LimitIterator, LimitIteratorState
+from bytelatent.data.iterators.looping_iterator import (
+    LoopingIterator,
+    LoopingIteratorState,
+)
 from bytelatent.data.patcher import Patcher, PatcherArgs, PatchingModeEnum
 from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
 from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
 
 
-class PreprocessIteratorState(BaseModel, IteratorState):
+class PreprocessIteratorState(PydanticIteratorState):
     model_config = ConfigDict(extra="forbid")
-    arrow_file_iterator_state: ArrowFileIteratorState | LoopingIteratorState
+    arrow_file_iterator_state: (
+        ArrowFileIteratorState | LoopingIteratorState | LimitIteratorState
+    )
     add_tokens: bool
     add_patches: bool
     tokenizer_args: TokenizerArgs
@@ -43,7 +52,7 @@ class PreprocessIterator(StatefulIterator):
 
     def __init__(
         self,
-        arrow_iterator: ArrowFileIterator,
+        arrow_iterator: ArrowFileIterator | LoopingIterator | LimitIterator,
         *,
        patcher_args: PatcherArgs,
        tokenizer_args: TokenizerArgs,
bytelatent/data/iterators/sampling_iterator.py
CHANGED
@@ -2,13 +2,16 @@
 from typing import Any
 
 import numpy as np
-from pydantic import BaseModel, ConfigDict
+from pydantic import ConfigDict
 
-from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
+from bytelatent.data.iterators.abstract_iterator import (
+    PydanticIteratorState,
+    StatefulIterator,
+)
 from bytelatent.data.iterators.sequence_iterator import SequenceIteratorState
 
 
-class SamplingIteratorState(BaseModel, IteratorState):
+class SamplingIteratorState(PydanticIteratorState):
     model_config = ConfigDict(extra="forbid")
     rng_state: dict[str, Any]
     source_to_weight: dict[str, float]
bytelatent/data/iterators/sequence_iterator.py
CHANGED
@@ -6,7 +6,10 @@ import numpy as np
 from pydantic import BaseModel, ConfigDict
 
 from bytelatent.data.data_types import BltSequence
-from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
+from bytelatent.data.iterators.abstract_iterator import (
+    PydanticIteratorState,
+    StatefulIterator,
+)
 from bytelatent.data.iterators.preprocess_iterator import (
     PreprocessIterator,
     PreprocessIteratorState,
@@ -21,11 +24,12 @@ class SequencePackingArgs(BaseModel):
     buffer_size: int
 
 
-class SequenceIteratorState(BaseModel, IteratorState):
+class SequenceIteratorState(PydanticIteratorState):
     model_config = ConfigDict(extra="forbid")
     sequence_packing_args: SequencePackingArgs
     preprocess_iterator_state: PreprocessIteratorState
-    rng_state: dict[str, Any]
+    # If None, rng is disabled.
+    rng_state: dict[str, Any] | None
 
     def build(self):
         preprocess_iterator = self.preprocess_iterator_state.build()
@@ -41,22 +45,25 @@ class SequenceIterator(StatefulIterator):
         self,
         preprocess_iterator: PreprocessIterator,
         *,
-        rng_state: dict[str, Any],
+        rng_state: dict[str, Any] | None,
         sequence_packing_args: SequencePackingArgs,
     ):
         self.preprocess_iterator = preprocess_iterator
         self.sequence_packing_args = sequence_packing_args
         self.output_seq_len = sequence_packing_args.output_seq_len
         self.buffer_size = sequence_packing_args.buffer_size
-        self.rng = np.random.default_rng()
-        self.rng.bit_generator.state = rng_state
+        if rng_state is None:
+            self.rng = None
+        else:
+            self.rng = np.random.default_rng()
+            self.rng.bit_generator.state = rng_state
 
     def get_state(self):
         # TODO: need to also perist the current shuffle buffer
         return SequenceIteratorState(
             sequence_packing_args=self.sequence_packing_args,
             preprocess_iterator_state=self.preprocess_iterator.get_state(),
-            rng_state=self.rng.bit_generator.state,
+            rng_state=None if self.rng is None else self.rng.bit_generator.state,
         )
 
     def create_iter(self):
@@ -114,7 +121,12 @@ class SequenceIterator(StatefulIterator):
 
             seq_patch_lengths: list[list[int]] = x_patches.tolist()
             assert len(seq_patch_lengths) == self.buffer_size
-            for idx in self.rng.permutation(len(seq_patch_lengths)):
+            if self.rng is None:
+                permutations = list(range(len(seq_patch_lengths)))
+            else:
+                permutations = self.rng.permutation(len(seq_patch_lengths))
+
+            for idx in permutations:
                 assert len(seq_patch_lengths[idx]) == self.output_seq_len
                 assert (
                     sum(seq_patch_lengths[idx])
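
With `rng_state=None`, `SequenceIterator` never constructs a numpy generator and the buffer permutation falls back to identity order, so eval passes see sequences deterministically. A hypothetical helper that mirrors the branch added above:

```python
import numpy as np

def pick_order(rng: np.random.Generator | None, n: int) -> list[int]:
    # Mirrors the new branch: identity order when rng is disabled,
    # a shuffled order when training with a live generator.
    if rng is None:
        return list(range(n))
    return [int(i) for i in rng.permutation(n)]

assert pick_order(None, 4) == [0, 1, 2, 3]       # eval: deterministic
print(pick_order(np.random.default_rng(0), 4))   # train: shuffled
```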
bytelatent/data/iterators/test_arrow_iterator.py
CHANGED
@@ -6,7 +6,10 @@ import pyarrow as pa
 import pyarrow.dataset  # pyright: ignore
 
 from bytelatent.constants import BLT_DATA
-from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator, ArrowFileIteratorState
+from bytelatent.data.iterators.arrow_iterator import (
+    ArrowFileIterator,
+    ArrowFileIteratorState,
+)
 
 ENTROPY_MODEL = "transformer_100m"
 ARROW_TEST_DATA_1 = str(BLT_DATA / "stackexchange.chunk.00.jsonl.shard_00.arrow")
@@ -93,3 +96,19 @@ def test_basic_arrow_file():
         i += 1
         if i >= len(expected_ids):
             break
+
+
+def test_read_jsonl_from_arrow():
+    arrow_iterator = ArrowFileIterator(
+        file_path="fixtures/test_docs.jsonl",
+        num_workers=1,
+        worker_id=0,
+        preprocess_dir=None,
+        entropy_model_name=None,
+        file_format="json",
+        arrow_batch_size=100,
+    )
+    iterator = arrow_iterator.create_iter()
+    for i, example in enumerate(iterator):
+        assert example.sample_id == str(i)
+        assert example.text == f"test_{i}"
bytelatent/data/iterators/test_iters.py
CHANGED
@@ -1,83 +1,15 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
-import pandas as pd
-from pydantic import BaseModel
 
 from bytelatent.constants import BLT_DATA
-from bytelatent.data.data_types import BltExample
-from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
+from bytelatent.data.iterators.dev_iterators import (
+    BltTestIterator,
+    BltTestWithEntropiesIterator,
+)
 from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator
 from bytelatent.data.patcher import PatcherArgs, PatchingModeEnum
 from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
 
 
-class BltTestIteratorState(BaseModel, IteratorState):
-    position: int
-    total: int
-
-    def build(self):
-        blt_iter = BltTestIteratorState(total=self.total)
-        blt_iter.position = self.position
-        return blt_iter
-
-
-class BltTestIterator(StatefulIterator):
-    def __init__(self, total: int):
-        self.position = 0
-        self.total = total
-
-    def get_state(self):
-        return BltTestIteratorState(position=self.position, total=self.total)
-
-    def create_iter(self):
-        for i in range(self.total):
-            self.position += 1
-            yield BltExample(
-                sample_id=f"test_{i}",
-                text=f"This is some test {i} text.",
-                tokens=None,
-                mask=None,
-                entropies=None,
-                patch_lengths=None,
-            )
-
-
-class BltTestWithEntropiesIteratorState(BaseModel, IteratorState):
-    position: int
-    total: int
-
-    def build(self):
-        blt_iter = BltTestWithEntropiesIteratorState(total=self.total)
-        blt_iter.position = self.position
-        return blt_iter
-
-
-class BltTestWithEntropiesIterator(StatefulIterator):
-    def __init__(self, total: int):
-        self.position = 0
-        self.total = total
-
-    def get_state(self):
-        return BltTestIteratorState(position=self.position, total=self.total)
-
-    def create_iter(self):
-        text = "Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin."
-        df = pd.read_json("fixtures/tokens_with_entropies.json")
-        tokens = df["token_ids"].tolist()
-        entropies = df["entropies"].tolist()
-        # BOS and EOS
-        assert len(tokens) == len(text) + 2
-        for i in range(self.total):
-            self.position += 1
-            yield BltExample(
-                sample_id=f"test_{i}",
-                text=text,
-                tokens=tokens,
-                mask=[True] * len(tokens),
-                entropies=entropies,
-                patch_lengths=None,
-            )
-
-
 def test_preprocess_iter():
     total = 3
     tokenizer_args = TokenizerArgs(
bytelatent/data/iterators/test_limit_iterator.py
ADDED
@@ -0,0 +1,45 @@
+from bytelatent.data.iterators.dev_iterators import BltTestIterator
+from bytelatent.data.iterators.limit_iterator import LimitIterator
+
+
+def test_limit_iterator():
+    total = 10
+    limit = 5
+    base_iterator = BltTestIterator(total=total)
+    limit_iterator = LimitIterator(base_iterator, limit=limit)
+    iterator = limit_iterator.create_iter()
+    n = 0
+    for example in iterator:
+        assert example.sample_id == f"test_{n}"
+        n += 1
+    assert n == limit
+
+    limit = 10
+    base_iterator = BltTestIterator(total=total)
+    limit_iterator = LimitIterator(base_iterator, limit=limit)
+    iterator = limit_iterator.create_iter()
+    n = 0
+    for example in iterator:
+        assert example.sample_id == f"test_{n}"
+        n += 1
+    assert n == limit == total
+
+    limit = 20
+    base_iterator = BltTestIterator(total=total)
+    limit_iterator = LimitIterator(base_iterator, limit=limit)
+    iterator = limit_iterator.create_iter()
+    n = 0
+    for example in iterator:
+        assert example.sample_id == f"test_{n}"
+        n += 1
+    assert n == total
+
+    limit = -1
+    base_iterator = BltTestIterator(total=total)
+    limit_iterator = LimitIterator(base_iterator, limit=limit)
+    iterator = limit_iterator.create_iter()
+    n = 0
+    for example in iterator:
+        assert example.sample_id == f"test_{n}"
+        n += 1
+    assert n == total
fixtures/test_docs.jsonl
ADDED
@@ -0,0 +1,3 @@
+{"sample_id": "0", "text": "test_0"}
+{"sample_id": "1", "text": "test_1"}
+{"sample_id": "2", "text": "test_2"}