par-meta committed
Commit b0120da · unverified · 1 Parent(s): d4ddb95

Replace regular filesystem calls with fsspec + add s3 support (#18)

Summary:

For compatibility with either local/NFS or S3 datasets, swap regular filesystem calls to fsspec.

Also add a tool to compare local and remote (S3) filesystems.

Test Plan:

- Ran the regular train script
- Ran with a config whose data lives in S3
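
To illustrate the swap: fsspec exposes one AbstractFileSystem interface for both local paths and s3:// URIs, so call sites only need to pick a backend from the path scheme (this is what the new get_fs helper in bytelatent/data/file_util.py does). A minimal sketch of the pattern, with placeholder paths and profile name:

# Minimal sketch of the fsspec pattern this commit adopts; the example paths
# and the "example" AWS profile are placeholders, not values from the commit.
import fsspec


def open_any(path: str, s3_profile: str | None = None):
    # Pick the backend from the path scheme: s3:// -> s3fs, anything else -> local.
    if path.startswith("s3://"):
        fs = fsspec.filesystem("s3", profile=s3_profile)
    else:
        fs = fsspec.filesystem("file")
    return fs.open(path, "rb")


# The same call works against either storage backend:
# open_any("/data/dataset.chunk.00.jsonl")
# open_any("s3://my-bucket/dataset.chunk.00.jsonl", s3_profile="example")
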

.gitignore CHANGED
@@ -165,4 +165,4 @@ cython_debug/
 figures/
 .vscode/
 .DS_Store
-
+internal/
bytelatent/args.py CHANGED
@@ -46,8 +46,11 @@ def distribute_data_to_rank(
     arrow_batch_size: int,
     rank: int,
     world_size: int,
+    s3_profile: str | None = None,
 ) -> ArrowFileIterator:
-    dataset_chunks = find_and_sanitize_chunks(dataset_path, world_size)
+    dataset_chunks = find_and_sanitize_chunks(
+        dataset_path, world_size, s3_profile=s3_profile
+    )
     n_workers_per_chunk = world_size // len(dataset_chunks)
     rank_to_arrow_iterator_params = []
     for chunk_path in dataset_chunks:
@@ -61,6 +64,7 @@ def distribute_data_to_rank(
                     dataset_files=None,
                     entropy_model_name=entropy_model_name,
                     arrow_batch_size=arrow_batch_size,
+                    s3_profile=s3_profile,
                 )
             )
     return rank_to_arrow_iterator_params[rank]
@@ -68,6 +72,7 @@ def distribute_data_to_rank(
 
 class DataloaderArgs(BaseModel):
     model_config = ConfigDict(extra="forbid")
+    s3_profile: str | None = None
     root_dir: str | None = None
     sources: dict[str, float] = {}
     batch_size: int = 2
@@ -107,6 +112,7 @@ class DataloaderArgs(BaseModel):
             arrow_batch_size=self.arrow_batch_size,
             rank=rank,
             world_size=world_size,
+            s3_profile=self.s3_profile,
         )
         looping_iterator = LoopingIterator(arrow_iterator)
         preprocess_iterator = PreprocessIterator(
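
For context on the DataloaderArgs change above: the new s3_profile field is what a config sets so the same dataloader code can read from a local/NFS root_dir or an S3 one. A hypothetical instantiation, with placeholder paths, profile, and source names:

# Hypothetical usage after this change; the paths, profile name, and source
# names below are placeholders, not values from the commit.
from bytelatent.args import DataloaderArgs

local_args = DataloaderArgs(
    root_dir="/data/preprocessed",            # local or NFS root
    sources={"my_dataset": 1.0},
)
s3_args = DataloaderArgs(
    s3_profile="example",                     # AWS profile handed to s3fs
    root_dir="s3://my-bucket/preprocessed",   # same field, S3 URI
    sources={"my_dataset": 1.0},
)
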
bytelatent/data/file_util.py ADDED
@@ -0,0 +1,117 @@
+import os
+
+import fsspec
+import pyarrow as pa
+
+# pyarrow needs the initialization from this import
+import pyarrow.dataset  # pyright: ignore
+import typer
+from pyarrow.lib import ArrowInvalid
+from rich.progress import track
+
+
+def is_valid_arrow_file(path: str):
+    try:
+        dataset = pa.dataset.dataset(path, format="arrow")
+        return True
+    except ArrowInvalid:
+        return False
+
+
+app = typer.Typer()
+
+S3_PREFIX = "s3://"
+
+
+def get_fs(path: str, s3_profile: str | None = None) -> fsspec.AbstractFileSystem:
+    if path.startswith("s3://"):
+        if s3_profile is None:
+            return fsspec.filesystem("s3")
+        else:
+            return fsspec.filesystem("s3", profile=s3_profile)
+    else:
+        return fsspec.filesystem("file")
+
+
+@app.command()
+def print_local_to_delete(
+    blob_dir: str, local_dirs: list[str], s3_profile: str = "blt"
+):
+    for s in local_dirs:
+        assert s.endswith("/"), "Dirs must end with /"
+    assert blob_dir.endswith("/"), "Dirs must end with /"
+    blob_fs = fsspec.filesystem("s3", profile=s3_profile)
+    blob_files = blob_fs.find(blob_dir)
+    for f in track(blob_files):
+        size = blob_fs.info(f)["Size"]
+        if not f.lower().endswith(".complete"):
+            assert size != 0, f"Size was invalidly zero for {f}"
+
+    blob_relative_paths = {f[len(blob_dir) - len(S3_PREFIX) :] for f in blob_files}
+    local_fs = fsspec.filesystem("file")
+
+    files_to_delete = []
+    for local_dir in local_dirs:
+        local_files = local_fs.find(local_dir)
+        for f in local_files:
+            relative_path = f[len(local_dir) :]
+            if relative_path in blob_relative_paths and not os.path.islink(f):
+                files_to_delete.append(f)
+    print(len(files_to_delete))
+    with open("/tmp/files_to_delete.txt", "w") as f:
+        for file in files_to_delete:
+            f.write(f"{file}\n")
+
+
+@app.command()
+def compare_local_to_blob(
+    source_dirs: list[str], dst_dir: str, s3_profile: str = "blt"
+):
+    for s in source_dirs:
+        assert s.endswith("/"), "Dirs must end with /"
+    assert dst_dir.endswith("/"), "Dirs must end with /"
+    assert len(source_dirs) != 0
+    assert dst_dir.startswith("s3://")
+    local_fs = fsspec.filesystem("file")
+    dst_fs = fsspec.filesystem("s3", profile=s3_profile)
+    source_to_files = {}
+    all_local_files = set()
+    for s in source_dirs:
+        skipped = []
+        if s not in source_to_files:
+            source_to_files[s] = []
+        for f in local_fs.find(s):
+            if os.path.islink(f):
+                continue
+            if f.endswith(".COMPLETE") or f.endswith(".complete"):
+                is_complete_file = True
+                assert os.path.getsize(f) == 0, ".COMPLETE files should be empty"
+            else:
+                is_complete_file = False
+
+            if not is_complete_file and os.path.getsize(f) == 0:
+                skipped.append(f)
+                continue
+            if f.endswith(".arrow"):
+                if not is_valid_arrow_file(f):
+                    skipped.append(f)
+                    continue
+
+            source_to_files[s].append(f)
+            all_local_files.add(f[len(s) :])
+        print(s, len(source_to_files[s]), "skipped", len(skipped), skipped[:10])
+
+    dst_files = dst_fs.find(dst_dir)
+    print(dst_dir, len(dst_files))
+
+    dst_file_set = {f[len(dst_dir) - len(S3_PREFIX) :] for f in dst_files}
+    diff = all_local_files.symmetric_difference(dst_file_set)
+    print("Local files", len(all_local_files))
+    print("DST Files", len(dst_file_set))
+    print("Symmetric difference", len(diff))
+    dst_only_files = dst_file_set - all_local_files
+    print("DST only", len(dst_only_files), list(dst_only_files)[:10])
+
+
+if __name__ == "__main__":
+    app()
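
The new module is also a small typer CLI (the "tool to compare local and remote filesystems" from the summary). A hedged usage sketch; the directories, bucket, and profile name are placeholders, and the trailing slashes plus the s3:// prefix are required by the asserts above:

# Hedged usage sketch for the new comparison tool; directories, bucket, and
# profile name are placeholders.
from bytelatent.data.file_util import compare_local_to_blob

compare_local_to_blob(
    source_dirs=["/data/preprocessed/"],      # must end with "/"
    dst_dir="s3://my-bucket/preprocessed/",   # must start with s3:// and end with "/"
    s3_profile="example",
)
# The same functions can also be invoked through the module's typer
# command-line interface.
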
bytelatent/data/iterators/arrow_iterator.py CHANGED
@@ -1,17 +1,20 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+import os
 import re
 from logging import getLogger
-from pathlib import Path
 from typing import Any, Generator
 
+import fsspec
 import pyarrow as pa
 
 # pyarrow needs the initialization from this import
 import pyarrow.dataset  # pyright: ignore
+import s3fs
 from pydantic import BaseModel, ConfigDict
 
 from bytelatent import ByteLatentError
 from bytelatent.data.data_types import BltExample
+from bytelatent.data.file_util import get_fs
 from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
 
 logger = getLogger(__name__)
@@ -27,6 +30,8 @@ class ArrowFileIteratorState(BaseModel, IteratorState):
     dataset_files: list[str] | None
     entropy_model_name: str | None
     arrow_batch_size: int = 100
+    s3_profile: str | None
+    filesystem_type: str | None = None
 
     def build(self) -> "ArrowFileIterator":
         arrow_file = ArrowFileIterator(
@@ -37,14 +42,17 @@ class ArrowFileIteratorState(BaseModel, IteratorState):
             entropy_model_name=self.entropy_model_name,
             arrow_batch_size=self.arrow_batch_size,
             dataset_files=self.dataset_files,
+            s3_profile=self.s3_profile,
+            filesystem_type=self.filesystem_type,
         )
         if self.row_num != 0:
             arrow_file._set_row_num(self.row_num)
         return arrow_file
 
 
-def shard_sort_key(file: str | Path):
-    match = re.search(r".+\.shard_([0-9]+)\.arrow", str(file))
+def shard_sort_key(file: str):
+    assert isinstance(file, str)
+    match = re.search(r".+\.shard_([0-9]+)\.arrow", file)
     shard_number = int(match.group(1))
     return shard_number
 
@@ -60,6 +68,8 @@ class ArrowFileIterator(StatefulIterator):
         entropy_model_name: str | None,
         arrow_batch_size: int,
         dataset_files: list[str] | None = None,
+        s3_profile: str | None = None,
+        filesystem_type: str | None = None,
     ):
         assert 0 <= worker_id < num_workers, (worker_id, num_workers)
         if file_path is None and dataset_files is None:
@@ -75,16 +85,41 @@ class ArrowFileIterator(StatefulIterator):
         self.preprocess_dir = preprocess_dir
         self.entropy_model_name = entropy_model_name
         self.arrow_batch_size = arrow_batch_size
+        self.s3_profile = s3_profile
+        self.filesystem_type = filesystem_type
+        self.fs = None
+        if self.filesystem_type is not None:
+            if self.filesystem_type == "file":
+                self.fs = fsspec.filesystem("file")
+            elif self.filesystem_type == "s3":
+                self.fs = fsspec.filesystem("s3", profile=s3_profile)
+
         if dataset_files is None:
             # Prepare arrow shards
-            jsonl_file = Path(file_path)
-            parts = re.match(r"(.+)\.chunk\.[0-9]+\.jsonl", jsonl_file.name)
+            jsonl_file = file_path
+            parts = re.match(
+                r"(.+)\.chunk\.[0-9]+\.jsonl", os.path.basename(jsonl_file)
+            )
             assert parts is not None
             dataset = parts.group(1)
-            data_dir = Path(preprocess_dir) / dataset / entropy_model_name
-            shard_files = list(data_dir.glob(f"{jsonl_file.name}.shard_*.arrow"))
+            data_dir = os.path.join(preprocess_dir, dataset, entropy_model_name)
+            data_dir_with_glob = os.path.join(
+                data_dir, f"{os.path.basename(jsonl_file)}.shard_*.arrow"
+            )
+            if self.fs is None:
+                self.fs = get_fs(data_dir_with_glob, s3_profile=s3_profile)
+                if isinstance(self.fs, s3fs.S3FileSystem):
+                    self.filesystem_type = "s3"
+                else:
+                    self.filesystem_type = "file"
+            shard_files = self.fs.glob(data_dir_with_glob)
+
             for s in shard_files:
-                if not (data_dir / f"{s.name}.complete").exists():
+                complete_file = os.path.join(
+                    data_dir, f"{os.path.basename(s)}.complete"
+                )
+
+                if not self.fs.exists(complete_file):
                     raise ValueError(f"Missing .complete for input file: {s}")
 
             shard_files = sorted(shard_files, key=shard_sort_key)
@@ -92,10 +127,19 @@ class ArrowFileIterator(StatefulIterator):
                 raise ByteLatentError(
                     f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow"
                 )
-            self.dataset_files = [str(f) for f in shard_files]
+            self.dataset_files = [f for f in shard_files]
         else:
             self.preprocess_dir = None
             self.dataset_files = dataset_files
+            if dataset_files[0].startswith("s3://"):
+                for f in dataset_files:
+                    assert f.startswith("s3://")
+            if self.fs is None:
+                self.fs = get_fs(dataset_files[0], s3_profile=s3_profile)
+                if isinstance(self.fs, s3fs.S3FileSystem):
+                    self.filesystem_type = "s3"
+                else:
+                    self.filesystem_type = "file"
 
     def get_state(self) -> ArrowFileIteratorState:
         return ArrowFileIteratorState(
@@ -107,13 +151,21 @@ class ArrowFileIterator(StatefulIterator):
             entropy_model_name=self.entropy_model_name,
             arrow_batch_size=self.arrow_batch_size,
             dataset_files=self.dataset_files,
+            s3_profile=self.s3_profile,
+            filesystem_type=self.filesystem_type,
         )
 
     def create_iter(
         self,
     ) -> Generator[BltExample, Any, None]:
         if self.dataset is None:
-            self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow")
+            if isinstance(self.fs, s3fs.core.S3FileSystem):
+                filesystem = self.fs
+            else:
+                filesystem = None
+            self.dataset = pa.dataset.dataset(
+                self.dataset_files, format="arrow", filesystem=filesystem
+            )
             self.batch_iterator = self.dataset.to_batches(
                 batch_size=self.arrow_batch_size
             )
@@ -165,7 +217,13 @@ class ArrowFileIterator(StatefulIterator):
             self.batch_iterator = None
             self.batch_to_consume = None
         else:
-            self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow")
+            if isinstance(self.fs, s3fs.core.S3FileSystem):
+                filesystem = self.fs
+            else:
+                filesystem = None
+            self.dataset = pa.dataset.dataset(
+                self.dataset_files, format="arrow", filesystem=filesystem
+            )
             self.batch_iterator = self.dataset.to_batches(
                 batch_size=self.arrow_batch_size
             )
@@ -198,9 +256,14 @@ TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl"
 
 
 def find_and_sanitize_chunks(
-    dataset_path: str, world_size: int, file_pattern: str = TRAIN_DATA_FILE_PATTERN
+    dataset_path: str,
+    world_size: int,
+    file_pattern: str = TRAIN_DATA_FILE_PATTERN,
+    s3_profile: str | None = None,
 ):
-    dataset_chunks = [str(p) for p in Path(dataset_path).glob(file_pattern)]
+    fs = get_fs(dataset_path, s3_profile=s3_profile)
+    path_with_glob = os.path.join(dataset_path, file_pattern)
+    dataset_chunks = fs.glob(path_with_glob)
     n_chunks = len(dataset_chunks)
 
     if n_chunks > world_size:
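
The isinstance(self.fs, s3fs.core.S3FileSystem) branches above exist because pyarrow accepts an fsspec filesystem when it is passed explicitly, which is how the iterator now opens Arrow shards that live in S3. A minimal standalone sketch of that interaction, with a placeholder bucket, prefix, and profile:

# Minimal sketch of reading Arrow shards through s3fs with pyarrow; the bucket,
# prefix, and profile name are placeholders.
import pyarrow.dataset as ds
import s3fs

fs = s3fs.S3FileSystem(profile="example")
# s3fs.glob returns keys without the s3:// prefix; pyarrow accepts them as long
# as the filesystem object is passed alongside.
shard_files = fs.glob("my-bucket/preprocessed/dataset.chunk.00.jsonl.shard_*.arrow")
dataset = ds.dataset(shard_files, format="arrow", filesystem=fs)
for batch in dataset.to_batches(batch_size=100):
    print(batch.num_rows)
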
bytelatent/logger.py CHANGED
@@ -91,7 +91,7 @@ def init_logger(
     log_file: str | None = None,
     *,
     name: str | None = None,
-    level: str = "NOTSET",
+    level: str = "INFO",
 ):
     """
     Setup logging.
requirements.txt CHANGED
@@ -20,3 +20,4 @@ altair
 submitit
 typer
 rich
+fsspec[full]
setup/download_prepare_hf_data.py CHANGED
@@ -5,6 +5,7 @@ import os
 import subprocess
 import time
 
+import fsspec
 import requests
 from huggingface_hub import snapshot_download
 
@@ -38,11 +39,21 @@ def download_dataset(repo_id, local_dir, allow_patterns):
     print(f"Dataset downloaded to {local_dir}")
 
 
-def parquet_to_jsonl(dataset, work_dir, src_dir, tgt_dir, ntasks=64):
+def parquet_to_jsonl(
+    dataset, work_dir, src_dir, tgt_dir, ntasks=64, s3_profile: str | None = None
+):
     from datatrove.executor import LocalPipelineExecutor
     from datatrove.pipeline.readers import ParquetReader
     from datatrove.pipeline.writers import JsonlWriter
 
+    if tgt_dir.startswith("s3://"):
+        if s3_profile is None:
+            out_spec = tgt_dir
+        else:
+            out_spec = (tgt_dir, fsspec.filesystem("s3", profile=s3_profile))
+    else:
+        out_spec = tgt_dir
+
     pipeline_exec = LocalPipelineExecutor(
         pipeline=[
             ParquetReader(
@@ -52,7 +63,7 @@ def parquet_to_jsonl(dataset, work_dir, src_dir, tgt_dir, ntasks=64):
                 glob_pattern="**/*.parquet",
             ),
             JsonlWriter(
-                tgt_dir,
+                out_spec,
                 output_filename=dataset + ".chunk.${rank}.jsonl",
                 compression=None,
             ),
@@ -77,7 +88,7 @@ def setup_terashuf(work_dir):
     return terashuf_dir
 
 
-def main(dataset, memory, data_dir, seed=42, nchunks=32):
+def main(dataset, memory, data_dir, seed=42, nchunks=32, s3_profile: str | None = None):
     # Configuration
     repo_id = {
         "fineweb_edu": "HuggingFaceFW/fineweb-edu",