# Copyright (c) Meta Platforms, Inc. and affiliates.
from typing import Any, Generator

import torch
from pydantic import BaseModel, ConfigDict

from bytelatent.data.data_types import BltExample
from bytelatent.data.iterators.abstract_iterator import (
    PydanticIteratorState,
    StatefulIterator,
)
from bytelatent.data.iterators.arrow_iterator import (
    ArrowFileIterator,
    ArrowFileIteratorState,
)
from bytelatent.data.iterators.limit_iterator import LimitIterator, LimitIteratorState
from bytelatent.data.iterators.looping_iterator import (
    LoopingIterator,
    LoopingIteratorState,
)
from bytelatent.data.patcher import Patcher, PatcherArgs, PatchingModeEnum
from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs


class PreprocessIteratorState(PydanticIteratorState):
    model_config = ConfigDict(extra="forbid")
    arrow_file_iterator_state: (
        ArrowFileIteratorState | LoopingIteratorState | LimitIteratorState
    )
    add_tokens: bool
    add_patches: bool
    tokenizer_args: TokenizerArgs
    patcher_args: PatcherArgs

    def build(self):
        arrow_iterator = self.arrow_file_iterator_state.build()
        return PreprocessIterator(
            arrow_iterator,
            patcher_args=self.patcher_args,
            tokenizer_args=self.tokenizer_args,
            add_tokens=self.add_tokens,
            add_patches=self.add_patches,
        )


class PreprocessIterator(StatefulIterator):
    """
    Take BltExamples with fields filled in only from ArrowFileIterator, and fill in fields that require
    preprocessing like tokenization and patching
    """

    def __init__(
        self,
        arrow_iterator: ArrowFileIterator | LoopingIterator | LimitIterator,
        *,
        patcher_args: PatcherArgs,
        tokenizer_args: TokenizerArgs,
        add_tokens: bool = True,
        add_patches: bool = True,
    ):
        self.arrow_iterator = arrow_iterator
        self.tokenizer_args = tokenizer_args
        self.patcher_args = patcher_args
        self.add_tokens = add_tokens
        self.add_patches = add_patches
        self.tokenizer: BltTokenizer | None = None
        self.patcher: Patcher | None = None

    def get_state(self) -> PreprocessIteratorState:
        """
        The only state to maintain here is from arrow, there
        isn't any internal state on this iterator.
        """
        return PreprocessIteratorState(
            arrow_file_iterator_state=self.arrow_iterator.get_state(),
            tokenizer_args=self.tokenizer_args,
            patcher_args=self.patcher_args,
            add_tokens=self.add_tokens,
            add_patches=self.add_patches,
        )

    def create_iter(self) -> Generator[BltExample, Any, None]:
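        # Build the tokenizer and patcher lazily, the first time an iterator is
        # created, rather than in __init__ or when restoring from state.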
        if self.tokenizer is None and self.add_tokens:
            self.tokenizer = self.tokenizer_args.build()
        if self.patcher is None and self.add_patches:
            self.patcher = self.patcher_args.build()

        example_iter = self.arrow_iterator.create_iter()
        for example in example_iter:
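            # Either tokenize the raw text here, or reuse tokens that the
            # upstream iterator already provided.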
            if self.add_tokens:
                tokens = self.tokenizer.encode(example.text)
            else:
                tokens = example.tokens
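            # Entropy-based patching needs per-token entropies supplied by the
            # upstream iterator; batch them as a 1 x seq_len tensor.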
            if (
                self.patcher is not None
                and self.patcher.patching_mode == PatchingModeEnum.entropy
            ):
                assert (
                    example.entropies is not None
                ), "For patching, entropies cannot be None"
                entropies = torch.tensor(example.entropies).unsqueeze(0)
            else:
                entropies = None
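            # When a patcher is configured, compute patch lengths for this
            # single (batch-of-one) sequence; the [0][0] indexing pulls the
            # per-sequence patch lengths out of the batched return value.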
            if self.patcher is None:
                patch_lengths = None
            else:
                patch_lengths = self.patcher.patch(
                    torch.tensor(tokens).unsqueeze(0),
                    include_next_token=False,
                    entropies=entropies,
                )[0][0].tolist()
            yield BltExample(
                sample_id=example.sample_id,
                text=example.text,
                tokens=tokens,
                mask=[True] * len(tokens),
                patch_lengths=patch_lengths,
                entropies=example.entropies,
            )
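

# Hedged usage sketch (illustrative only, not part of the module): assumes an
# already-constructed arrow-style iterator plus TokenizerArgs / PatcherArgs
# instances built elsewhere; the variable names below are placeholders.
#
#   preprocessor = PreprocessIterator(
#       arrow_iterator,
#       patcher_args=patcher_args,
#       tokenizer_args=tokenizer_args,
#       add_tokens=True,
#       add_patches=True,
#   )
#   for example in preprocessor.create_iter():
#       ...  # example.tokens, example.mask, example.patch_lengths are now populated
#
#   # Checkpoint/resume: get_state() returns a Pydantic model whose build()
#   # reconstructs an equivalent iterator.
#   state = preprocessor.get_state()
#   resumed = state.build()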