Spaces:
Sleeping
Sleeping
Upload 2 files
Browse files- tokenizers/basic.py +73 -0
- tokenizers/bpe.py +164 -0
tokenizers/basic.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from tokenizers.bpe import Tokenizer, get_stats, merge
|
2 |
+
import json
|
3 |
+
|
4 |
+
|
5 |
+
class BasicTokenizer(Tokenizer):
|
6 |
+
def __init__(self):
|
7 |
+
super().__init__()
|
8 |
+
self.token_to_id = {}
|
9 |
+
self.id_to_token = {}
|
10 |
+
|
11 |
+
def train(self, text, vocab_size, verbose=False):
|
12 |
+
assert vocab_size >= 256 and vocab_size <= 5000
|
13 |
+
num_merges = vocab_size - 256
|
14 |
+
|
15 |
+
text_bytes = text.encode('utf-8')
|
16 |
+
ids = list(text_bytes)
|
17 |
+
|
18 |
+
merges = {}
|
19 |
+
vocab = {idx: bytes([idx]) for idx in range(256)}
|
20 |
+
|
21 |
+
# Initialize token mappings
|
22 |
+
self.token_to_id = {bytes([idx]).decode('utf-8', errors='replace'): idx for idx in range(256)}
|
23 |
+
self.id_to_token = {idx: bytes([idx]).decode('utf-8', errors='replace') for idx in range(256)}
|
24 |
+
|
25 |
+
for i in range(num_merges):
|
26 |
+
stats = get_stats(ids)
|
27 |
+
if not stats:
|
28 |
+
break
|
29 |
+
pair = max(stats, key=stats.get)
|
30 |
+
idx = 256 + i
|
31 |
+
ids = merge(ids, pair, idx)
|
32 |
+
merges[pair] = idx
|
33 |
+
vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
|
34 |
+
|
35 |
+
# Update token mappings
|
36 |
+
token = vocab[idx].decode('utf-8', errors='replace')
|
37 |
+
self.token_to_id[token] = idx
|
38 |
+
self.id_to_token[idx] = token
|
39 |
+
|
40 |
+
if verbose and i % 20 == 0:
|
41 |
+
print(f"merge {i + 1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")
|
42 |
+
|
43 |
+
self.merges = merges
|
44 |
+
self.vocab = vocab
|
45 |
+
self._save_vocabulary()
|
46 |
+
|
47 |
+
def _save_vocabulary(self):
|
48 |
+
vocabulary = {
|
49 |
+
'token_to_id': self.token_to_id,
|
50 |
+
'id_to_token': self.id_to_token,
|
51 |
+
'merges': {f"{k[0]},{k[1]}": v for k, v in self.merges.items()}
|
52 |
+
}
|
53 |
+
with open('vocabulary.json', 'w', encoding='utf-8') as f:
|
54 |
+
json.dump(vocabulary, f, ensure_ascii=False, indent=2)
|
55 |
+
|
56 |
+
def decode(self, ids):
|
57 |
+
text_bytes = b"".join(self.vocab[idx] for idx in ids)
|
58 |
+
text = text_bytes.decode("utf-8", errors="replace")
|
59 |
+
return text
|
60 |
+
|
61 |
+
def encode(self, text):
|
62 |
+
text_bytes = text.encode("utf-8")
|
63 |
+
ids = list(text_bytes)
|
64 |
+
while len(ids) >= 2:
|
65 |
+
stats = get_stats(ids)
|
66 |
+
pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
|
67 |
+
|
68 |
+
if pair not in self.merges:
|
69 |
+
break
|
70 |
+
|
71 |
+
idx = self.merges[pair]
|
72 |
+
ids = merge(ids, pair, idx)
|
73 |
+
return ids
|
tokenizers/bpe.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Contains the base Tokenizer class and a few common helper functions.
|
3 |
+
The base class also contains the (common) save/load functionality.
|
4 |
+
It would be possible to be a lot more strict about the interface and
|
5 |
+
e.g. isolating all regex/pattern parts to the RegexTokenizer, but
|
6 |
+
some concessions are made for simplicity.
|
7 |
+
"""
|
8 |
+
import unicodedata
|
9 |
+
|
10 |
+
# -----------------------------------------------------------------------------
|
11 |
+
# a few helper functions useful for both BasicTokenizer and RegexTokenizer
|
12 |
+
|
13 |
+
def get_stats(ids, counts=None):
|
14 |
+
"""
|
15 |
+
Given a list of integers, return a dictionary of counts of consecutive pairs
|
16 |
+
Example: [1, 2, 3, 1, 2] -> {(1, 2): 2, (2, 3): 1, (3, 1): 1}
|
17 |
+
Optionally allows to update an existing dictionary of counts
|
18 |
+
"""
|
19 |
+
counts = {} if counts is None else counts
|
20 |
+
for pair in zip(ids, ids[1:]): # iterate consecutive elements
|
21 |
+
counts[pair] = counts.get(pair, 0) + 1
|
22 |
+
return counts
|
23 |
+
|
24 |
+
|
25 |
+
def merge(ids, pair, idx):
|
26 |
+
"""
|
27 |
+
In the list of integers (ids), replace all consecutive occurrences
|
28 |
+
of pair with the new integer token idx
|
29 |
+
Example: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4]
|
30 |
+
"""
|
31 |
+
newids = []
|
32 |
+
i = 0
|
33 |
+
while i < len(ids):
|
34 |
+
# if not at the very last position AND the pair matches, replace it
|
35 |
+
if ids[i] == pair[0] and i < len(ids) - 1 and ids[i+1] == pair[1]:
|
36 |
+
newids.append(idx)
|
37 |
+
i += 2
|
38 |
+
else:
|
39 |
+
newids.append(ids[i])
|
40 |
+
i += 1
|
41 |
+
return newids
|
42 |
+
|
43 |
+
# first two helper functions...
|
44 |
+
def replace_control_characters(s: str) -> str:
|
45 |
+
# we don't want to print control characters
|
46 |
+
# which distort the output (e.g. \n or much worse)
|
47 |
+
# https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python/19016117#19016117
|
48 |
+
# http://www.unicode.org/reports/tr44/#GC_Values_Table
|
49 |
+
chars = []
|
50 |
+
for ch in s:
|
51 |
+
if unicodedata.category(ch)[0] != "C":
|
52 |
+
chars.append(ch) # this character is ok
|
53 |
+
else:
|
54 |
+
chars.append(f"\\u{ord(ch):04x}") # escape
|
55 |
+
return "".join(chars)
|
56 |
+
|
57 |
+
def render_token(t: bytes) -> str:
|
58 |
+
# pretty print a token, escaping control characters
|
59 |
+
s = t.decode('utf-8', errors='replace')
|
60 |
+
s = replace_control_characters(s)
|
61 |
+
return s
|
62 |
+
|
63 |
+
# -----------------------------------------------------------------------------
|
64 |
+
# the base Tokenizer class
|
65 |
+
|
66 |
+
class Tokenizer:
|
67 |
+
"""Base class for Tokenizers"""
|
68 |
+
|
69 |
+
def __init__(self):
|
70 |
+
# default: vocab size of 256 (all bytes), no merges, no patterns
|
71 |
+
self.merges = {} # (int, int) -> int
|
72 |
+
self.pattern = "" # str
|
73 |
+
self.special_tokens = {} # str -> int, e.g. {'<|endoftext|>': 100257}
|
74 |
+
self.vocab = self._build_vocab() # int -> bytes
|
75 |
+
|
76 |
+
def train(self, text, vocab_size, verbose=False):
|
77 |
+
# Tokenizer can train a vocabulary of size vocab_size from text
|
78 |
+
raise NotImplementedError
|
79 |
+
|
80 |
+
def encode(self, text):
|
81 |
+
# Tokenizer can encode a string into a list of integers
|
82 |
+
raise NotImplementedError
|
83 |
+
|
84 |
+
def decode(self, ids):
|
85 |
+
# Tokenizer can decode a list of integers into a string
|
86 |
+
raise NotImplementedError
|
87 |
+
|
88 |
+
def _build_vocab(self):
|
89 |
+
# vocab is simply and deterministically derived from merges
|
90 |
+
vocab = {idx: bytes([idx]) for idx in range(256)}
|
91 |
+
for (p0, p1), idx in self.merges.items():
|
92 |
+
vocab[idx] = vocab[p0] + vocab[p1]
|
93 |
+
for special, idx in self.special_tokens.items():
|
94 |
+
vocab[idx] = special.encode("utf-8")
|
95 |
+
return vocab
|
96 |
+
|
97 |
+
def save(self, file_prefix):
|
98 |
+
"""
|
99 |
+
Saves two files: file_prefix.vocab and file_prefix.model
|
100 |
+
This is inspired (but not equivalent to!) sentencepiece's model saving:
|
101 |
+
- model file is the critical one, intended for load()
|
102 |
+
- vocab file is just a pretty printed version for human inspection only
|
103 |
+
"""
|
104 |
+
# write the model: to be used in load() later
|
105 |
+
model_file = file_prefix + ".model"
|
106 |
+
with open(model_file, 'w') as f:
|
107 |
+
# write the version, pattern and merges, that's all that's needed
|
108 |
+
f.write("minbpe v1\n")
|
109 |
+
f.write(f"{self.pattern}\n")
|
110 |
+
# write the special tokens, first the number of them, then each one
|
111 |
+
f.write(f"{len(self.special_tokens)}\n")
|
112 |
+
for special, idx in self.special_tokens.items():
|
113 |
+
f.write(f"{special} {idx}\n")
|
114 |
+
# the merges dict
|
115 |
+
for idx1, idx2 in self.merges:
|
116 |
+
f.write(f"{idx1} {idx2}\n")
|
117 |
+
# write the vocab: for the human to look at
|
118 |
+
vocab_file = file_prefix + ".vocab"
|
119 |
+
inverted_merges = {idx: pair for pair, idx in self.merges.items()}
|
120 |
+
with open(vocab_file, "w", encoding="utf-8") as f:
|
121 |
+
for idx, token in self.vocab.items():
|
122 |
+
# note: many tokens may be partial utf-8 sequences
|
123 |
+
# and cannot be decoded into valid strings. Here we're using
|
124 |
+
# errors='replace' to replace them with the replacement char �.
|
125 |
+
# this also means that we couldn't possibly use .vocab in load()
|
126 |
+
# because decoding in this way is a lossy operation!
|
127 |
+
s = render_token(token)
|
128 |
+
# find the children of this token, if any
|
129 |
+
if idx in inverted_merges:
|
130 |
+
# if this token has children, render it nicely as a merge
|
131 |
+
idx0, idx1 = inverted_merges[idx]
|
132 |
+
f.write(f"{s}\n")
|
133 |
+
else:
|
134 |
+
# otherwise this is leaf token, just print it
|
135 |
+
# (this should just be the first 256 tokens, the bytes)
|
136 |
+
f.write(f"{s}\n")
|
137 |
+
|
138 |
+
def load(self, model_file):
|
139 |
+
"""Inverse of save() but only for the model file"""
|
140 |
+
assert model_file.endswith(".model")
|
141 |
+
# read the model file
|
142 |
+
merges = {}
|
143 |
+
special_tokens = {}
|
144 |
+
idx = 256
|
145 |
+
with open(model_file, 'r', encoding="utf-8") as f:
|
146 |
+
# read the version
|
147 |
+
version = f.readline().strip()
|
148 |
+
assert version == "minbpe v1"
|
149 |
+
# read the pattern
|
150 |
+
self.pattern = f.readline().strip()
|
151 |
+
# read the special tokens
|
152 |
+
num_special = int(f.readline().strip())
|
153 |
+
for _ in range(num_special):
|
154 |
+
special, special_idx = f.readline().strip().split()
|
155 |
+
special_tokens[special] = int(special_idx)
|
156 |
+
# read the merges
|
157 |
+
for line in f:
|
158 |
+
idx1, idx2 = map(int, line.split())
|
159 |
+
merges[(idx1, idx2)] = idx
|
160 |
+
idx += 1
|
161 |
+
self.merges = merges
|
162 |
+
print('lenmerges: ', len(self.merges))
|
163 |
+
self.special_tokens = special_tokens
|
164 |
+
self.vocab = self._build_vocab()
|