Commit 936d943 (unverified) · committed by par-meta · 1 parent: afedb16

Allow ArrowIterator to read from json (#45)


Summary:

Currently, the arrow iterator can only read arrow files. However, the pyarrow library can read
other formats, including jsonlines. This change allows the same ArrowIterator to read from jsonlines,
so we can read from the original source data and simply omit the entropy column when doing so.

Test Plan:

Run the train script until the dataloader starts.
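
For orientation, a minimal usage sketch of the new behaviour. This is not part of the commit: the
file path is hypothetical, and it assumes create_iter() yields BltExample objects as the diff below
suggests; keyword arguments are the ones visible in the ArrowFileIterator constructor.

from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator

# Point the iterator at raw jsonlines source data; for this format the
# entropies field is simply left as None.
iterator = ArrowFileIterator(
    dataset_files=["data/demo.chunk.00.jsonl"],  # hypothetical path
    worker_id=0,
    num_workers=1,
    arrow_batch_size=100,
    filesystem_type="file",
    file_format="json",
)
for example in iterator.create_iter():
    print(example.sample_id, example.text[:80], example.entropies)
    break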

bytelatent/args.py CHANGED
@@ -1,8 +1,10 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+import json
 import logging
 import os
 from typing import Any
 
+import fsspec
 import numpy as np
 import yaml
 from omegaconf import OmegaConf
@@ -10,11 +12,9 @@ from pydantic import BaseModel, ConfigDict
 
 from bytelatent.checkpoint import CheckpointArgs
 from bytelatent.data.data_types import Batch
+from bytelatent.data.file_util import get_fs
 from bytelatent.data.iterators.abstract_iterator import StatefulIterator
-from bytelatent.data.iterators.arrow_iterator import (
-    ArrowFileIterator,
-    find_and_sanitize_chunks,
-)
+from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
 from bytelatent.data.iterators.looping_iterator import LoopingIterator
 from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator
 from bytelatent.data.iterators.packing_iterator import PackingArgs, PackingIterator
@@ -53,6 +53,33 @@ def parse_args(args_cls):
     return pydantic_args
 
 
+TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl"
+
+
+def find_and_sanitize_chunks(
+    dataset_path: str,
+    world_size: int,
+    file_pattern: str,
+    s3_profile: str | None = None,
+):
+    fs = get_fs(dataset_path, s3_profile=s3_profile)
+    path_with_glob = os.path.join(dataset_path, file_pattern)
+    dataset_chunks = fs.glob(path_with_glob)
+    n_chunks = len(dataset_chunks)
+
+    if n_chunks > world_size:
+        n_discard = n_chunks - world_size
+        dataset_chunks = dataset_chunks[:world_size]
+    else:
+        assert (
+            world_size % n_chunks == 0
+        ), "World size should be a multiple of number of chunks"
+
+    assert n_chunks > 0, f"No valid chunks in {dataset_path}"
+
+    return dataset_chunks
+
+
 def distribute_data_to_rank(
     *,
     dataset_path: str,
@@ -62,9 +89,10 @@ def distribute_data_to_rank(
     rank: int,
     world_size: int,
     s3_profile: str | None = None,
+    file_pattern: str = TRAIN_DATA_FILE_PATTERN,
 ) -> ArrowFileIterator:
     dataset_chunks = find_and_sanitize_chunks(
-        dataset_path, world_size, s3_profile=s3_profile
+        dataset_path, world_size, file_pattern, s3_profile=s3_profile
     )
     n_workers_per_chunk = world_size // len(dataset_chunks)
     rank_to_arrow_iterator_params = []
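
The chunk-distribution logic above is unchanged in spirit: after globbing with the new file_pattern
argument, chunks beyond world_size are discarded, otherwise world_size must be a multiple of the
chunk count so each chunk is shared by world_size // n_chunks workers. A small standalone
restatement of that arithmetic (illustrative only, not a call into bytelatent):

def workers_per_chunk(n_chunks: int, world_size: int) -> int:
    # Mirrors find_and_sanitize_chunks: truncate extra chunks, otherwise require divisibility.
    if n_chunks > world_size:
        n_chunks = world_size
    assert world_size % n_chunks == 0, "World size should be a multiple of number of chunks"
    return world_size // n_chunks

assert workers_per_chunk(n_chunks=2, world_size=8) == 4    # four ranks share each chunk
assert workers_per_chunk(n_chunks=16, world_size=8) == 1   # chunks beyond world_size are dropped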
bytelatent/data/iterators/arrow_iterator.py CHANGED
@@ -16,6 +16,7 @@ from bytelatent import ByteLatentError
 from bytelatent.data.data_types import BltExample
 from bytelatent.data.file_util import get_fs
 from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
+from bytelatent.preprocess.preprocess_entropies import get_id_key, get_text
 
 logger = getLogger(__name__)
 
@@ -32,6 +33,7 @@ class ArrowFileIteratorState(BaseModel, IteratorState):
     arrow_batch_size: int = 100
     s3_profile: str | None
     filesystem_type: str | None = None
+    file_format: str
 
     def build(self) -> "ArrowFileIterator":
         arrow_file = ArrowFileIterator(
@@ -44,6 +46,7 @@ class ArrowFileIteratorState(BaseModel, IteratorState):
             dataset_files=self.dataset_files,
             s3_profile=self.s3_profile,
             filesystem_type=self.filesystem_type,
+            file_format=self.file_format,
         )
         if self.row_num != 0:
             arrow_file._set_row_num(self.row_num)
@@ -70,6 +73,7 @@ class ArrowFileIterator(StatefulIterator):
         dataset_files: list[str] | None = None,
         s3_profile: str | None = None,
         filesystem_type: str | None = None,
+        file_format: str = "arrow",
     ):
         assert 0 <= worker_id < num_workers, (worker_id, num_workers)
         if file_path is None and dataset_files is None:
@@ -87,12 +91,16 @@ class ArrowFileIterator(StatefulIterator):
         self.arrow_batch_size = arrow_batch_size
         self.s3_profile = s3_profile
         self.filesystem_type = filesystem_type
+        self.file_format = file_format
         self.fs = None
         if self.filesystem_type is not None:
             if self.filesystem_type == "file":
                 self.fs = fsspec.filesystem("file")
             elif self.filesystem_type == "s3":
                 self.fs = fsspec.filesystem("s3", profile=s3_profile)
+            else:
+                raise ValueError("Unknown filesystem")
+            logger.info("Arrow iterator using fs=%s", self.fs)
 
         if dataset_files is None:
             # Prepare arrow shards
@@ -153,6 +161,7 @@ class ArrowFileIterator(StatefulIterator):
             dataset_files=self.dataset_files,
             s3_profile=self.s3_profile,
             filesystem_type=self.filesystem_type,
+            file_format=self.file_format,
         )
 
     def create_iter(
@@ -164,7 +173,7 @@ class ArrowFileIterator(StatefulIterator):
         else:
             filesystem = None
         self.dataset = pa.dataset.dataset(
-            self.dataset_files, format="arrow", filesystem=filesystem
+            self.dataset_files, format=self.file_format, filesystem=filesystem
         )
         self.batch_iterator = self.dataset.to_batches(
             batch_size=self.arrow_batch_size
@@ -173,13 +182,22 @@ class ArrowFileIterator(StatefulIterator):
         if self.batch_to_consume is not None:
             batch_columns: dict[str, list] = self.batch_to_consume
             self.batch_to_consume = None
-            sample_ids = batch_columns["sample_id"]
-            texts = batch_columns["text"]
-            entropies = batch_columns["entropies"]
+            if self.file_format == "arrow":
+                sample_ids = batch_columns["sample_id"]
+                texts = batch_columns["text"]
+                entropies = batch_columns["entropies"]
+            elif self.file_format == "json":
+                # This data hasn't been preprocessed to a uniform format,
+                # so we have to do it now and omit entropies
+                sample_ids = batch_columns[get_id_key(batch_columns)]
+                texts = get_text(batch_columns)
+                entropies = None
+            else:
+                raise ValueError(f"Unknown file format: {self.file_format}")
             for i in range(len(sample_ids)):
                 out = BltExample(
                     sample_id=sample_ids[i],
-                    entropies=entropies[i],
+                    entropies=entropies[i] if entropies is not None else None,
                     text=texts[i],
                     tokens=None,
                     mask=None,
@@ -191,13 +209,22 @@ class ArrowFileIterator(StatefulIterator):
 
         for batch in self.batch_iterator:
             batch_columns = batch.to_pydict()
-            sample_ids = batch_columns["sample_id"]
-            texts = batch_columns["text"]
-            entropies = batch_columns["entropies"]
+            if self.file_format == "arrow":
+                sample_ids = batch_columns["sample_id"]
+                texts = batch_columns["text"]
+                entropies = batch_columns["entropies"]
+            elif self.file_format == "json":
+                # This data hasn't been preprocessed to a uniform format,
+                # so we have to do it now and omit entropies
+                sample_ids = batch_columns[get_id_key(batch_columns)]
+                texts = get_text(batch_columns)
+                entropies = None
+            else:
+                raise ValueError(f"Unknown file format: {self.file_format}")
             for i in range(len(sample_ids)):
                 out = BltExample(
                     sample_id=sample_ids[i],
-                    entropies=entropies[i],
+                    entropies=entropies[i] if entropies is not None else None,
                     text=texts[i],
                     tokens=None,
                     mask=None,
@@ -231,13 +258,24 @@ class ArrowFileIterator(StatefulIterator):
             for batch in self.batch_iterator:
                 if len(batch) > curr_remaining:
                     batch_columns: dict[str, list] = batch.to_pydict()
-                    batch_columns["sample_id"] = batch_columns["sample_id"][
-                        curr_remaining:
-                    ]
-                    batch_columns["entropies"] = batch_columns["entropies"][
-                        curr_remaining:
-                    ]
-                    batch_columns["text"] = batch_columns["text"][curr_remaining:]
+                    if self.file_format == "arrow":
+                        leftover_sample_ids = batch_columns["sample_id"][
+                            curr_remaining:
+                        ]
+                        leftover_entropies = batch_columns["entropies"][curr_remaining:]
+                        leftover_texts = batch_columns["text"][curr_remaining:]
+                    elif self.file_format == "json":
+                        leftover_sample_ids = batch_columns[get_id_key(batch_columns)][
+                            curr_remaining:
+                        ]
+                        leftover_entropies = None
+                        leftover_texts = get_text(batch_columns)[curr_remaining:]
+                    else:
+                        raise ValueError(f"Unknown file format: {self.file_format}")
+
+                    batch_columns["sample_id"] = leftover_sample_ids
+                    batch_columns["entropies"] = leftover_entropies
+                    batch_columns["text"] = leftover_texts
                     self.batch_to_consume = batch_columns
                     break
                 elif len(batch) == curr_remaining:
@@ -250,30 +288,3 @@ class ArrowFileIterator(StatefulIterator):
         logger.info(
             f"Finished setting arrow position to {target_row_num} for {self.dataset_files}"
         )
-
-
-TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl"
-
-
-def find_and_sanitize_chunks(
-    dataset_path: str,
-    world_size: int,
-    file_pattern: str = TRAIN_DATA_FILE_PATTERN,
-    s3_profile: str | None = None,
-):
-    fs = get_fs(dataset_path, s3_profile=s3_profile)
-    path_with_glob = os.path.join(dataset_path, file_pattern)
-    dataset_chunks = fs.glob(path_with_glob)
-    n_chunks = len(dataset_chunks)
-
-    if n_chunks > world_size:
-        n_discard = n_chunks - world_size
-        dataset_chunks = dataset_chunks[:world_size]
-    else:
-        assert (
-            world_size % n_chunks == 0
-        ), "World size should be a multiple of number of chunks"
-
-    assert n_chunks > 0, f"No valid chunks in {dataset_path}"
-
-    return dataset_chunks
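The change above leans on the fact that pyarrow's dataset API is format-aware: the same
pa.dataset.dataset(...) call that reads arrow/IPC files can read newline-delimited json when
format="json" is passed (supported in recent pyarrow releases). A standalone sketch of that
behaviour, using a hypothetical temporary file rather than bytelatent data:

import json

import pyarrow as pa
import pyarrow.dataset  # makes pa.dataset available, as in arrow_iterator.py

# Write a tiny jsonlines file, then read it back through the dataset API.
with open("/tmp/demo.chunk.00.jsonl", "w") as f:
    for i in range(3):
        f.write(json.dumps({"sample_id": str(i), "text": f"doc {i}"}) + "\n")

dataset = pa.dataset.dataset(["/tmp/demo.chunk.00.jsonl"], format="json")
for batch in dataset.to_batches(batch_size=100):
    columns = batch.to_pydict()
    # Raw jsonlines has no "entropies" column, which is why the iterator emits
    # BltExample(entropies=None) when file_format == "json".
    print(columns["sample_id"], columns["text"])
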
bytelatent/preprocess/preprocess_entropies.py CHANGED
@@ -15,29 +15,37 @@ from bytelatent.entropy_model import load_entropy_model
 from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
 
 
-def get_id_from_doc(doc: dict) -> int:
+def get_id_key(doc: dict) -> int:
     """
     We need a reliable way to ensure that samples from jsonl
     and arrow are the same, but there is no unique id field,
     so derive the best possible
     """
     if "sample_id" in doc:
-        sample_id = doc["sample_id"]
+        return "sample_id"
     elif "title" in doc:
-        sample_id = doc["title"]
+        return "title"
     elif "qid" in doc:
-        sample_id = doc["qid"]
+        return "qid"
     elif "paper_id" in doc:
-        sample_id = doc["paper_id"]
+        return "paper_id"
     elif "path" in doc:
-        sample_id = doc["path"]
+        return "path"
     elif "url" in doc:
-        sample_id = doc["url"]
+        return "url"
     elif "id" in doc:
-        sample_id = doc["id"]
+        return "id"
     else:
         raise ValueError(f"Could not find a id key from: {doc.keys()}")
-    return str(sample_id)
+
+
+def get_id_from_doc(doc: dict) -> int:
+    """
+    We need a reliable way to ensure that samples from jsonl
+    and arrow are the same, but there is no unique id field,
+    so derive the best possible
+    """
+    return str(doc[get_id_key(doc)])
 
 
 def get_text(doc: dict):
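
Splitting the id lookup into get_id_key and get_id_from_doc lets the arrow iterator resolve the id
column for an entire pydict batch (batch_columns[get_id_key(batch_columns)]) while keeping the old
per-document helper intact. A small illustration with a hypothetical doc:

doc = {"title": "Some article", "text": "..."}
assert get_id_key(doc) == "title"               # first matching candidate key wins
assert get_id_from_doc(doc) == "Some article"   # same value as before the refactor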
bytelatent/stool.py CHANGED
@@ -4,10 +4,10 @@ import json
 import os
 import shutil
 import subprocess
-from pydantic import BaseModel
 from typing import Any, Dict
 
 from omegaconf import OmegaConf
+from pydantic import BaseModel
 
 
 class StoolArgs(BaseModel):