# Copyright (c) Meta Platforms, Inc. and affiliates.
import re

from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
from bytelatent.tokenizers.constants import (
    BOE_ID,
    BOS_ID,
    BPE_ID,
    BYTE_UNITS,
    EOS_ID,
    OFFSET,
    PAD_ID,
)
from bytelatent.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer


def convert_to_bytes(s):
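    """Convert a SentencePiece piece string to raw bytes.

    Byte-fallback pieces of the form <0xNN> are decoded from their hex value;
    any other piece is encoded as UTF-8.
    """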
    # Check whether the piece is a SentencePiece byte-fallback token of the form <0xNN>
    if re.match(r"<0x[0-9a-fA-F]+>", s):
        return bytes.fromhex(s[3:-1])
    else:
        return bytes(s, "utf-8", errors="ignore")


def text2bytes_bpe_delims(
    text: str,
    *,
    bpe_tokenizer,
    bpe_id: int,
    offsetting_special_char: int,
    add_bos: bool,
    add_eos: bool,
):
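    """Encode ``text`` as a byte sequence delimited at BPE-token boundaries.

    The text is first tokenized with ``bpe_tokenizer``; each BPE piece is then
    converted to its raw bytes, and ``bpe_id - offsetting_special_char`` is
    emitted before every piece's bytes to mark the boundary (the offset is
    added back in ``BltTokenizer.encode``).
    """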
    cur_bpe = bpe_tokenizer.encode(text, add_bos=add_bos, add_eos=add_eos)
    # Merge leading whitespace-only ("▁") pieces into a single token
    leading_space_tokens = []
    other_bpe_tokens = []
    leading = True
    for token in cur_bpe:
        bpe_str = bpe_tokenizer.sp_model.id_to_piece(token)
        if leading and all(c == "▁" for c in bpe_str):
            leading_space_tokens.append(bpe_str)
        else:
            leading = False
            other_bpe_tokens.append(bpe_str)
    cur_bpe_strs = ["".join(leading_space_tokens)] + other_bpe_tokens

    # Remove the '▁' characters
    bpe_strs = []
    for i, bpe_str in enumerate(cur_bpe_strs):
        if (
            len(bpe_strs) <= 1
            and all([c == " " for s in bpe_strs for c in s])
            and not all(c == "▁" for c in bpe_str)
        ):
            # Strip the leading-space marker from the first non-space token.
            bpe_str = bpe_str.replace("▁", "")
        elif i == 0 and all(c == "▁" for c in bpe_str):
            bpe_str = " " * (len(text) - len(text.lstrip(" ")))
        else:
            bpe_str = bpe_str.replace("▁", " ")
        if len(bpe_str) > 0:
            bpe_strs.append(bpe_str)
    ex_seq = []
    # Convert bpe tokens to bytes
    for s in bpe_strs:
        byte_chunk = convert_to_bytes(s)
        proc_chunk = [int(unit) for unit in byte_chunk]
        ex_seq.extend([bpe_id - offsetting_special_char] + proc_chunk)

    return ex_seq


class BltTokenizer(Tokenizer):
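    """Byte-level tokenizer.

    Text is encoded as raw UTF-8 bytes shifted by ``OFFSET`` so that ids below
    the offset stay reserved for special tokens (PAD, BOS, EOS, BOE, BPE).
    When ``bpe_delim`` is True, a SentencePiece model is additionally used to
    insert a delimiter token before the bytes of each BPE piece.
    """
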
    def __init__(
        self,
        *,
        vocab_size_unit_1: int = BYTE_UNITS,
        bpe_delim: bool = False,
        bpe_tokenizer_path: str = "/home/artidoro/tokenizers/llama_v2.tokenizer.model",
        add_bos: bool = True,
        add_eos: bool = True,
    ):
        self.add_bos = add_bos
        self.add_eos = add_eos
        self.vocab_size_unit_1 = vocab_size_unit_1
        self.boe_id = BOE_ID
        self.bos_id = BOS_ID
        self.eos_id = EOS_ID
        self.pad_id = PAD_ID
        self.bpe_id = BPE_ID
        self.bpe_tokenizer_path = bpe_tokenizer_path
        if bpe_delim:
            self.bpe_tokenizer = SentencePieceTokenizer(
                model_path=self.bpe_tokenizer_path
            )
        else:
            self.bpe_tokenizer = None
        self.bpe_delim = bpe_delim
        self.offsetting_special_char = OFFSET
        self.n_words = vocab_size_unit_1 + self.offsetting_special_char

    def get_vocab_size(self) -> int:
        return self.n_words

    def encode(
        self, text: str, add_bos: bool | None = None, add_eos: bool | None = None
    ):
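        """Return token ids for ``text``.

        Raw UTF-8 bytes (or BPE-delimited bytes when ``bpe_delim`` is set) are
        shifted by the special-token offset; BOS/EOS ids are added according
        to ``add_bos``/``add_eos``.
        """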
        if add_bos is None:
            add_bos = self.add_bos
        if add_eos is None:
            add_eos = self.add_eos

        if self.bpe_delim:
            tokens = text2bytes_bpe_delims(
                text,
                bpe_tokenizer=self.bpe_tokenizer,
                bpe_id=self.bpe_id,
                offsetting_special_char=self.offsetting_special_char,
                add_bos=False,
                add_eos=False,
            )
        else:
            tokens = bytes(text, encoding="utf-8", errors="ignore")

        # Shift byte values by the special-token offset so low ids stay reserved for special tokens
        tokens = [int(unit) + self.offsetting_special_char for unit in tokens]

        if add_bos:
            tokens.insert(0, self.bos_id)
        if add_eos:
            tokens.append(self.eos_id)

        return tokens

    def decode(self, tokens: list[int], cut_at_eos: bool = False):
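        """Decode token ids back to text.

        If ``cut_at_eos`` is True, tokens after the first EOS are discarded.
        Ids below the special-token offset are dropped; the remaining values
        are un-shifted and decoded as UTF-8, ignoring invalid byte sequences.
        """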
        if cut_at_eos:
            for k, t in enumerate(tokens):
                if t == self.eos_id:
                    tokens = tokens[: k + 1]
                    break
        return bytes(
            [
                tok - self.offsetting_special_char
                for tok in tokens
                if tok - self.offsetting_special_char >= 0
            ]
        ).decode("utf-8", errors="ignore")

    def get_token_offsets(self, text: str, tokens: list[int] | None = None):
        # TODO: Figure out what this does
        raise NotImplementedError()
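

if __name__ == "__main__":
    # Illustrative round-trip with the default byte-level configuration
    # (bpe_delim=False, so no SentencePiece model file is needed).
    tokenizer = BltTokenizer()
    ids = tokenizer.encode("hello world")
    print(ids)
    print(tokenizer.decode(ids))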