import os

os.environ["HF_DATASETS_OFFLINE"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

from collections import defaultdict  # noqa: E402

from datasets import load_dataset  # noqa: E402
from torch.utils.data import DataLoader, IterableDataset  # noqa: E402
from tqdm.auto import tqdm  # noqa: E402
from transformers import AutoTokenizer  # noqa: E402

ds = load_dataset("sdlm/data/dolma/dolma_dataset.py", streaming=True)

text_column_name = "text"
ds = ds.select_columns([text_column_name, "source"])
ds["train"] = ds["train"].shuffle(seed=42, buffer_size=10_000)


tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    revision="26bca36bde8333b5d7f72e9ed20ccda6a618af24",
    use_fast=True,
)
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
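
# Illustrative sanity check: "[PAD]" is now the pad token. NOTE (assumption):
# any model trained with this tokenizer would also need
# model.resize_token_embeddings(len(tokenizer)) to cover the added token.
assert tokenizer.pad_token == "[PAD]"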


def tokenize_function(examples):
    """
    Adapted from sdlm/data/data_utils.py (`tokenize_data_new`).
    """
    # Drop empty and whitespace-only texts
    examples[text_column_name] = [
        line
        for line in examples[text_column_name]
        if len(line) > 0 and not line.isspace()
    ]
    return tokenizer(
        examples[text_column_name],
        # padding strategy and max_length are hard-coded for now
        padding="max_length",
        truncation=True,
        max_length=512,
        return_special_tokens_mask=True,
    )
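

# Illustrative spot-check (not part of the benchmark): whitespace-only texts
# are dropped and every surviving example comes back with exactly 512 ids.
_probe = tokenize_function({text_column_name: ["hello world", " "]})
assert all(len(ids) == 512 for ids in _probe["input_ids"])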


tokenized_datasets = ds.map(
    tokenize_function,
    batched=True,
    remove_columns=[text_column_name],
)


def simple_collate_fn(xs):
    """simple collate fn that collects key-values from dict"""
    result = defaultdict(list)
    for x in xs:
        for key, value in x.items():
            result[key].append(value)
    return result
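

# Illustrative: simple_collate_fn turns a list of per-example dicts into one
# dict of lists.
assert simple_collate_fn([{"a": 1}, {"a": 2}]) == {"a": [1, 2]}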


def source_collate_fn(xs):
    result = simple_collate_fn(xs)
    return result["source"]


def tokenize_collate_fn(xs):
    result = simple_collate_fn(xs)
    return tokenize_function(result)


# from https://github.com/huggingface/datasets/issues/6279
# related https://discuss.huggingface.co/t/slow-dataloader-with-big-batch-size/57224
class Dataset2Iterable(IterableDataset):
    """
    Wrapper to use a HF dataset as pytorch IterableDataset to speed up data loading.
    """

    def __init__(self, dataset, batch_size=1, shuffle=True):
        super().__init__()
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        if self.shuffle:
            # shuffle() returns a new (shuffled) dataset rather than mutating
            # in place, so the result must be reassigned
            self.dataset = self.dataset.shuffle()
        return self.dataset.iter(batch_size=self.batch_size)
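

# Usage note (assumption, based on the linked issue): when batch_size > 1 the
# wrapper itself yields batched dicts; passing batch_size=None to the DataLoader
# then avoids re-batching, e.g.
#   DataLoader(Dataset2Iterable(ds, batch_size=8), batch_size=None, num_workers=0)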


# returns source information
source_dataloader = DataLoader(
    ds["train"],
    batch_size=8,
    num_workers=64,
    collate_fn=source_collate_fn,
    persistent_workers=True,
    prefetch_factor=2,
)

# returns tokens; current method via ds.map (very slow);
# also freezes if num_workers is too big (> 1)
token_dataloader_v1 = DataLoader(
    tokenized_datasets["train"],
    batch_size=8,
    num_workers=32,
)

# returns tokens; grab text and tokenize in collate_fn on the fly
token_dataloader_v2 = DataLoader(
    ds["train"],
    batch_size=8,
    num_workers=32,
    collate_fn=tokenize_collate_fn,
)

token_dataloader_v3 = DataLoader(
    Dataset2Iterable(tokenized_datasets["train"]),
    batch_size=8,
    num_workers=0,  # required: with num_workers > 0 each worker would replicate the full stream
)
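

# Minimal timing sketch (illustrative helper, not from the original script):
# rough examples/sec over the first n_batches, instead of eyeballing tqdm.
def time_dataloader(dataloader, n_batches=100, batch_size=8):
    import time

    start = time.perf_counter()
    for i, _ in enumerate(dataloader):
        if i == n_batches:
            break
    return n_batches * batch_size / (time.perf_counter() - start)


# e.g. print(f"{time_dataloader(token_dataloader_v2):.1f} examples/sec")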

# change params to test
stop_iter = 1_000
dataloader_to_test = token_dataloader_v3

for i, x in enumerate(tqdm(dataloader_to_test)):
    if i == stop_iter:
        break
    # just check iteration speed
    tqdm.write(str(i))
    # check content (for source)
    # print(i, x)