Spaces:
Sleeping
Sleeping
File size: 2,478 Bytes
5740893 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
from tokenizers.bpe import Tokenizer, get_stats, merge
import json
class BasicTokenizer(Tokenizer):
    """Byte-level BPE tokenizer trained directly on raw UTF-8 bytes.

    ``merges`` and ``vocab`` (set by :meth:`train`) drive :meth:`encode` and
    :meth:`decode`. Training additionally builds display-string maps
    ``token_to_id`` / ``id_to_token`` that are only used for the JSON
    vocabulary file written by :meth:`_save_vocabulary`.
    """

    # Largest vocabulary ``train`` accepts (256 byte tokens + learned merges).
    # A class attribute so subclasses/instances can raise the limit.
    max_vocab_size = 5000

    def __init__(self):
        super().__init__()
        # Display-string <-> id maps; rebuilt from scratch by ``train``.
        self.token_to_id = {}
        self.id_to_token = {}

    def train(self, text, vocab_size, verbose=False, save_path="vocabulary.json"):
        """Learn up to ``vocab_size - 256`` byte-pair merges from *text*.

        Args:
            text: Training corpus; it is encoded to UTF-8 bytes before training.
            vocab_size: Target vocabulary size, from 256 to ``max_vocab_size``
                (5000 by default) inclusive.
            verbose: If true, print progress for every 20th merge.
            save_path: File the JSON vocabulary is written to after training.
                Defaults to ``"vocabulary.json"`` in the current directory;
                pass ``None`` to skip writing the file.

        Raises:
            ValueError: If *vocab_size* is outside the accepted range.

        Training stops early when ``get_stats`` reports no adjacent pairs, so
        the learned vocabulary can be smaller than *vocab_size*.
        """
        if not 256 <= vocab_size <= self.max_vocab_size:
            # Explicit check rather than ``assert``: asserts vanish under -O.
            raise ValueError(
                f"vocab_size must be between 256 and {self.max_vocab_size}, "
                f"got {vocab_size}"
            )
        num_merges = vocab_size - 256
        ids = list(text.encode("utf-8"))
        merges = {}
        vocab = {idx: bytes([idx]) for idx in range(256)}

        # NOTE(review): single bytes >= 0x80 (and merged tokens that are not
        # valid UTF-8) all decode to U+FFFD under errors="replace", so these
        # display strings collide and token_to_id keeps only the last id for
        # each string. id_to_token is complete. Left as-is because the strings
        # are part of the saved file format.
        self.token_to_id = {}
        self.id_to_token = {}
        for idx, token_bytes in vocab.items():
            self._record_token(idx, token_bytes)

        for i in range(num_merges):
            stats = get_stats(ids)
            if not stats:
                break  # no adjacent pairs left to merge
            pair = max(stats, key=stats.get)  # most frequent adjacent pair
            idx = 256 + i
            ids = merge(ids, pair, idx)
            merges[pair] = idx
            vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
            self._record_token(idx, vocab[idx])
            if verbose and i % 20 == 0:
                print(f"merge {i + 1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")

        self.merges = merges
        self.vocab = vocab
        if save_path is not None:
            self._save_vocabulary(save_path)

    def _record_token(self, idx, token_bytes):
        """Register *token_bytes* (decoded for display) under *idx*."""
        token = token_bytes.decode("utf-8", errors="replace")
        self.token_to_id[token] = idx
        self.id_to_token[idx] = token

    def _save_vocabulary(self, path="vocabulary.json"):
        """Write the token maps and merges to *path* as indented UTF-8 JSON.

        JSON object keys are always strings, so the integer keys of
        ``id_to_token`` are written as strings. Each merge pair ``(a, b)``
        is written under the key ``"a,b"``.
        """
        vocabulary = {
            'token_to_id': self.token_to_id,
            'id_to_token': self.id_to_token,
            'merges': {f"{k[0]},{k[1]}": v for k, v in self.merges.items()},
        }
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(vocabulary, f, ensure_ascii=False, indent=2)

    def decode(self, ids):
        """Convert token *ids* back into text.

        The tokens' bytes are concatenated and decoded as UTF-8 with
        ``errors="replace"``. Invalid byte sequences therefore become U+FFFD
        instead of raising. An id missing from ``self.vocab`` raises KeyError.
        """
        text_bytes = b"".join(self.vocab[idx] for idx in ids)
        return text_bytes.decode("utf-8", errors="replace")

    def encode(self, text):
        """Convert *text* into token ids using the learned merges.

        Starts from the raw UTF-8 bytes. It then repeatedly merges the adjacent
        pair with the lowest merge id, which is the earliest learned by
        ``train``. It stops when no adjacent pair has a merge.
        """
        merges = self.merges
        ids = list(text.encode("utf-8"))
        while len(ids) >= 2:
            stats = get_stats(ids)
            pair = min(stats, key=lambda p: merges.get(p, float("inf")))
            if pair not in merges:
                break  # no remaining pair is mergeable
            ids = merge(ids, pair, merges[pair])
        return ids
|