File size: 2,018 Bytes
bcc039b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# Copyright (c) Meta Platforms, Inc. and affiliates.
import shlex
import subprocess
from pathlib import Path

import luigi

# CHANGEME: Change this to point to your data
BASE_DIR: Path = Path("datasets")  # root holding one sub-directory per dataset
DATASETS = ["dclm"]  # names of the sub-directories of BASE_DIR to shard
TARGET_DIR: Path = Path("entropy_preprocess")  # shards + .COMPLETE markers go here

# Shell command template run once per chunk file. `split -C` caps shard size while
# keeping lines whole, and `-d` gives numeric suffixes, producing
# "<destination>.shard_00", "<destination>.shard_01", ...
SHARD_SCRIPT = "split -C 2500m -d {source} {destination}.shard_"


def list_dataset_shards(dataset: str, base_dir=None):
    """Return the chunk files of ``dataset``, sorted by path.

    Matches files named ``*.chunk.*.jsonl`` directly inside ``base_dir / dataset``
    (non-recursive). Despite the function name, these are the *input* chunk files
    that ``ShardDatasetChunk`` later splits into shards.

    Args:
        dataset: Name of the dataset sub-directory.
        base_dir: Root directory containing the dataset sub-directories. Defaults
            to the module-level ``BASE_DIR`` when omitted or ``None``.

    Returns:
        A sorted list of ``pathlib.Path`` objects. ``Path.glob`` yields entries in
        filesystem-dependent order, so sorting keeps the generated task list (and
        its scheduling order) reproducible across runs and machines.
    """
    # Wrap in Path so a str root (e.g. an edited BASE_DIR) still supports "/".
    root = Path(BASE_DIR if base_dir is None else base_dir)
    return sorted((root / dataset).glob("*.chunk.*.jsonl"))


class ChunkFile(luigi.ExternalTask):
    """A chunk file produced outside this pipeline.

    Being an ``ExternalTask``, it defines no ``run``; it only exposes the
    existing file as a ``LocalTarget`` so downstream tasks can depend on it.
    """

    file = luigi.Parameter()  # path of the pre-existing chunk file

    def output(self):
        chunk_target = luigi.LocalTarget(self.file)
        return chunk_target


class ShardDatasetChunk(luigi.Task):
    """Split one dataset chunk file into shards by running ``SHARD_SCRIPT``.

    Shards are written to ``TARGET_DIR / dataset_name`` and are named after the
    chunk file (``<chunk>.shard_NN``). After the command exits successfully, an
    empty ``<chunk>.shard.COMPLETE`` marker is created; that marker is this
    task's output, so its presence is what marks the task as done.
    """

    dataset_name = luigi.Parameter()
    chunk_file = luigi.Parameter()

    def _chunk_filename(self):
        """Return the base name of ``chunk_file`` (directory stripped)."""
        return Path(self.chunk_file).name

    def _destination_dir(self):
        """Return the directory receiving this dataset's shards and markers."""
        return Path(TARGET_DIR) / str(self.dataset_name)

    def requires(self):
        return ChunkFile(self.chunk_file)

    def run(self):
        destination_dir = self._destination_dir()
        destination_dir.mkdir(parents=True, exist_ok=True)
        destination = destination_dir / self._chunk_filename()
        # SHARD_SCRIPT is executed by the shell, so the interpolated paths must be
        # quoted; otherwise names with spaces or shell metacharacters would be
        # split into several arguments or interpreted by the shell.
        command = SHARD_SCRIPT.format(
            source=shlex.quote(str(self.chunk_file)),
            destination=shlex.quote(str(destination)),
        )
        # check=True raises CalledProcessError on a non-zero exit status, so the
        # COMPLETE marker below is only written after a successful split.
        subprocess.run(command, shell=True, check=True)
        # Touch exactly the path output() reports, so the two can never diverge.
        Path(self.output().path).touch()

    def output(self):
        return luigi.LocalTarget(
            self._destination_dir() / f"{self._chunk_filename()}.shard.COMPLETE"
        )


class ShardDataset(luigi.WrapperTask):
    """Require one ``ShardDatasetChunk`` per chunk file found for ``dataset_name``."""

    dataset_name = luigi.Parameter()

    def requires(self):
        chunk_paths = list_dataset_shards(self.dataset_name)
        yield from (
            ShardDatasetChunk(dataset_name=self.dataset_name, chunk_file=str(path))
            for path in chunk_paths
        )


class ShardAllDatasets(luigi.WrapperTask):
    """Top-level task: require a ``ShardDataset`` for every name in ``DATASETS``."""

    def requires(self):
        yield from (ShardDataset(dataset_name=name) for name in DATASETS)


if __name__ == "__main__":
    # Run the whole pipeline with the in-process scheduler (local_scheduler=True)
    # and up to 128 concurrent workers.
    luigi.build(
        [ShardAllDatasets()],
        workers=128,
        local_scheduler=True,
    )