kishkath committed on
Commit 5740893 · verified · 1 Parent(s): ad067b3

Upload 2 files

Files changed (2)
  1. tokenizers/basic.py +73 -0
  2. tokenizers/bpe.py +164 -0
tokenizers/basic.py ADDED
@@ -0,0 +1,73 @@
+from tokenizers.bpe import Tokenizer, get_stats, merge
+import json
+
+
+class BasicTokenizer(Tokenizer):
+    def __init__(self):
+        super().__init__()
+        self.token_to_id = {}
+        self.id_to_token = {}
+
+    def train(self, text, vocab_size, verbose=False):
+        assert 256 <= vocab_size <= 5000
+        num_merges = vocab_size - 256
+
+        text_bytes = text.encode('utf-8')
+        ids = list(text_bytes)
+
+        merges = {}
+        vocab = {idx: bytes([idx]) for idx in range(256)}
+
+        # Initialize token mappings
+        self.token_to_id = {bytes([idx]).decode('utf-8', errors='replace'): idx for idx in range(256)}
+        self.id_to_token = {idx: bytes([idx]).decode('utf-8', errors='replace') for idx in range(256)}
+
+        for i in range(num_merges):
+            stats = get_stats(ids)
+            if not stats:
+                break
+            pair = max(stats, key=stats.get)
+            idx = 256 + i
+            ids = merge(ids, pair, idx)
+            merges[pair] = idx
+            vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
+
+            # Update token mappings
+            token = vocab[idx].decode('utf-8', errors='replace')
+            self.token_to_id[token] = idx
+            self.id_to_token[idx] = token
+
+            if verbose and i % 20 == 0:
+                print(f"merge {i + 1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")
+
+        self.merges = merges
+        self.vocab = vocab
+        self._save_vocabulary()
+
+    def _save_vocabulary(self):
+        vocabulary = {
+            'token_to_id': self.token_to_id,
+            'id_to_token': self.id_to_token,
+            'merges': {f"{k[0]},{k[1]}": v for k, v in self.merges.items()}
+        }
+        with open('vocabulary.json', 'w', encoding='utf-8') as f:
+            json.dump(vocabulary, f, ensure_ascii=False, indent=2)
+
+    def decode(self, ids):
+        text_bytes = b"".join(self.vocab[idx] for idx in ids)
+        text = text_bytes.decode("utf-8", errors="replace")
+        return text
+
+    def encode(self, text):
+        text_bytes = text.encode("utf-8")
+        ids = list(text_bytes)
+        while len(ids) >= 2:
+            stats = get_stats(ids)
+            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
+
+            if pair not in self.merges:
+                break
+
+            idx = self.merges[pair]
+            ids = merge(ids, pair, idx)
+        return ids
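
For reference, a minimal usage sketch of the BasicTokenizer above. It assumes the repository root is on the Python path so the tokenizers package resolves; the training text and vocab_size are illustrative. Note that train() also writes vocabulary.json to the current working directory via _save_vocabulary().

from tokenizers.basic import BasicTokenizer

text = "low lower lowest newer newest " * 100   # illustrative training text
tok = BasicTokenizer()
tok.train(text, vocab_size=300, verbose=True)   # learns up to 300 - 256 = 44 merges

ids = tok.encode("lowest newer")
print(ids)               # byte and merge token ids
print(tok.decode(ids))   # round-trips back to "lowest newer"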
tokenizers/bpe.py ADDED
@@ -0,0 +1,164 @@
+"""
+Contains the base Tokenizer class and a few common helper functions.
+The base class also contains the (common) save/load functionality.
+It would be possible to be a lot more strict about the interface and
+e.g. isolate all regex/pattern parts in the RegexTokenizer, but
+some concessions are made for simplicity.
+"""
+import unicodedata
+
+# -----------------------------------------------------------------------------
+# a few helper functions useful for both BasicTokenizer and RegexTokenizer
+
+def get_stats(ids, counts=None):
+    """
+    Given a list of integers, return a dictionary of counts of consecutive pairs.
+    Example: [1, 2, 3, 1, 2] -> {(1, 2): 2, (2, 3): 1, (3, 1): 1}
+    Optionally updates an existing dictionary of counts.
+    """
+    counts = {} if counts is None else counts
+    for pair in zip(ids, ids[1:]):  # iterate consecutive elements
+        counts[pair] = counts.get(pair, 0) + 1
+    return counts
+
+
+def merge(ids, pair, idx):
+    """
+    In the list of integers (ids), replace all consecutive occurrences
+    of pair with the new integer token idx.
+    Example: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4]
+    """
+    newids = []
+    i = 0
+    while i < len(ids):
+        # if not at the very last position AND the pair matches, replace it
+        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
+            newids.append(idx)
+            i += 2
+        else:
+            newids.append(ids[i])
+            i += 1
+    return newids
+
+# two more helper functions, used below for pretty-printing tokens when saving the vocab
+def replace_control_characters(s: str) -> str:
+    # we don't want to print control characters
+    # which distort the output (e.g. \n or much worse)
+    # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python/19016117#19016117
+    # http://www.unicode.org/reports/tr44/#GC_Values_Table
+    chars = []
+    for ch in s:
+        if unicodedata.category(ch)[0] != "C":
+            chars.append(ch)  # this character is ok
+        else:
+            chars.append(f"\\u{ord(ch):04x}")  # escape
+    return "".join(chars)
+
+def render_token(t: bytes) -> str:
+    # pretty print a token, escaping control characters
+    s = t.decode('utf-8', errors='replace')
+    s = replace_control_characters(s)
+    return s
+
+# -----------------------------------------------------------------------------
+# the base Tokenizer class
+
+class Tokenizer:
+    """Base class for Tokenizers"""
+
+    def __init__(self):
+        # default: vocab size of 256 (all bytes), no merges, no patterns
+        self.merges = {}  # (int, int) -> int
+        self.pattern = ""  # str
+        self.special_tokens = {}  # str -> int, e.g. {'<|endoftext|>': 100257}
+        self.vocab = self._build_vocab()  # int -> bytes
+
+    def train(self, text, vocab_size, verbose=False):
+        # Tokenizer can train a vocabulary of size vocab_size from text
+        raise NotImplementedError
+
+    def encode(self, text):
+        # Tokenizer can encode a string into a list of integers
+        raise NotImplementedError
+
+    def decode(self, ids):
+        # Tokenizer can decode a list of integers into a string
+        raise NotImplementedError
+
+    def _build_vocab(self):
+        # vocab is simply and deterministically derived from merges
+        vocab = {idx: bytes([idx]) for idx in range(256)}
+        for (p0, p1), idx in self.merges.items():
+            vocab[idx] = vocab[p0] + vocab[p1]
+        for special, idx in self.special_tokens.items():
+            vocab[idx] = special.encode("utf-8")
+        return vocab
+
+    def save(self, file_prefix):
+        """
+        Saves two files: file_prefix.vocab and file_prefix.model
+        This is inspired by (but not equivalent to!) sentencepiece's model saving:
+        - the model file is the critical one, intended for load()
+        - the vocab file is just a pretty-printed version for human inspection only
+        """
+        # write the model: to be used in load() later
+        model_file = file_prefix + ".model"
+        with open(model_file, 'w') as f:
+            # write the version, pattern and merges, that's all that's needed
+            f.write("minbpe v1\n")
+            f.write(f"{self.pattern}\n")
+            # write the special tokens, first the number of them, then each one
+            f.write(f"{len(self.special_tokens)}\n")
+            for special, idx in self.special_tokens.items():
+                f.write(f"{special} {idx}\n")
+            # the merges dict
+            for idx1, idx2 in self.merges:
+                f.write(f"{idx1} {idx2}\n")
+        # write the vocab: for the human to look at
+        vocab_file = file_prefix + ".vocab"
+        inverted_merges = {idx: pair for pair, idx in self.merges.items()}
+        with open(vocab_file, "w", encoding="utf-8") as f:
+            for idx, token in self.vocab.items():
+                # note: many tokens may be partial utf-8 sequences
+                # and cannot be decoded into valid strings. Here we're using
+                # errors='replace' to replace them with the replacement char �.
+                # this also means that we couldn't possibly use .vocab in load()
+                # because decoding in this way is a lossy operation!
+                s = render_token(token)
+                # find the children of this token, if any
+                if idx in inverted_merges:
+                    # if this token has children, render it nicely as a merge
+                    idx0, idx1 = inverted_merges[idx]
+                    f.write(f"[{render_token(self.vocab[idx0])}][{render_token(self.vocab[idx1])}] -> [{s}] {idx}\n")
+                else:
+                    # otherwise this is a leaf token, just print it
+                    # (this should just be the first 256 tokens, the bytes)
+                    f.write(f"[{s}] {idx}\n")
+
+    def load(self, model_file):
+        """Inverse of save() but only for the model file"""
+        assert model_file.endswith(".model")
+        # read the model file
+        merges = {}
+        special_tokens = {}
+        idx = 256
+        with open(model_file, 'r', encoding="utf-8") as f:
+            # read the version
+            version = f.readline().strip()
+            assert version == "minbpe v1"
+            # read the pattern
+            self.pattern = f.readline().strip()
+            # read the special tokens
+            num_special = int(f.readline().strip())
+            for _ in range(num_special):
+                special, special_idx = f.readline().strip().split()
+                special_tokens[special] = int(special_idx)
+            # read the merges
+            for line in f:
+                idx1, idx2 = map(int, line.split())
+                merges[(idx1, idx2)] = idx
+                idx += 1
+        self.merges = merges
+        print('lenmerges: ', len(self.merges))
+        self.special_tokens = special_tokens
+        self.vocab = self._build_vocab()
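
For reference, a minimal sketch of the helpers and the base-class save/load round trip above. The file prefix "my_tok", the sample ids, and the training text are illustrative. load() reassigns merge ids sequentially from 256 in file order, which matches training order because Python dicts preserve insertion order.

from tokenizers.bpe import get_stats, merge
from tokenizers.basic import BasicTokenizer

# helper semantics, matching the docstrings
print(get_stats([1, 2, 3, 1, 2]))         # {(1, 2): 2, (2, 3): 1, (3, 1): 1}
print(merge([1, 2, 3, 1, 2], (1, 2), 4))  # [4, 3, 4]

# save/load round trip through the base Tokenizer machinery
tok = BasicTokenizer()
tok.train("some illustrative training text " * 50, vocab_size=280)
tok.save("my_tok")          # writes my_tok.model and my_tok.vocab

tok2 = BasicTokenizer()
tok2.load("my_tok.model")   # restores pattern, special tokens, merges, and vocab
assert tok2.decode(tok2.encode("illustrative text")) == "illustrative text"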