Spaces:
Sleeping
Sleeping
from tokenizers.bpe import Tokenizer, get_stats, merge | |
import json | |
class BasicTokenizer(Tokenizer):
    """Byte-level BPE tokenizer that also keeps readable token/id mappings.

    Training starts from the 256 single-byte tokens and repeatedly merges the
    most frequent adjacent pair of ids into a new id.  Besides ``merges`` and
    ``vocab``, the instance maintains ``token_to_id`` / ``id_to_token`` and
    writes them, together with the merges, to a JSON file once training ends.
    """

    def __init__(self):
        super().__init__()
        # Display string -> id and id -> display string; filled by train().
        self.token_to_id = {}
        self.id_to_token = {}

    def train(self, text, vocab_size, verbose=False):
        """Learn up to ``vocab_size - 256`` merges from ``text``.

        Args:
            text: Training text; it is encoded to UTF-8 bytes before merging.
            vocab_size: Target vocabulary size, between 256 and 5000 inclusive.
            verbose: When true, print a progress line every 20th merge.

        Side effects:
            Replaces ``self.token_to_id`` / ``self.id_to_token``, sets
            ``self.merges`` and ``self.vocab``, and writes the vocabulary file
            via ``_save_vocabulary()`` (default path ``vocabulary.json``).
        """
        assert vocab_size >= 256 and vocab_size <= 5000
        num_merges = vocab_size - 256
        ids = list(text.encode('utf-8'))
        merges = {}
        vocab = {idx: bytes([idx]) for idx in range(256)}

        # Initialize token mappings for the 256 single-byte tokens.
        # NOTE: decoding with errors='replace' turns every byte sequence that
        # is not valid UTF-8 on its own (e.g. each of 0x80-0xFF) into U+FFFD,
        # so several ids can share one display string.  ``id_to_token`` keeps
        # every id, but ``token_to_id`` retains only the last id written for
        # such a string.  Use ``vocab`` / ``merges`` when exact bytes matter.
        self.token_to_id = {}
        self.id_to_token = {}
        for idx in range(256):
            token = vocab[idx].decode('utf-8', errors='replace')
            self.token_to_id[token] = idx
            self.id_to_token[idx] = token

        for i in range(num_merges):
            # stats is used as a mapping of adjacent id pair -> count.
            stats = get_stats(ids)
            if not stats:
                # Fewer than two ids remain, so nothing is left to merge.
                break
            pair = max(stats, key=stats.get)
            idx = 256 + i
            # presumably merge() replaces each occurrence of `pair` in `ids`
            # with `idx` -- helper is defined in tokenizers.bpe.
            ids = merge(ids, pair, idx)
            merges[pair] = idx
            vocab[idx] = vocab[pair[0]] + vocab[pair[1]]

            # Update token mappings with the newly created token.
            token = vocab[idx].decode('utf-8', errors='replace')
            self.token_to_id[token] = idx
            self.id_to_token[idx] = token

            if verbose and i % 20 == 0:
                print(f"merge {i + 1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")

        self.merges = merges
        self.vocab = vocab
        self._save_vocabulary()

    def _save_vocabulary(self, path='vocabulary.json'):
        """Write the token mappings and merges to ``path`` as UTF-8 JSON.

        Args:
            path: Destination file; defaults to ``vocabulary.json`` in the
                current working directory, which is what ``train()`` uses.

        The file holds three objects: ``token_to_id``, ``id_to_token`` and
        ``merges``.  Merge pairs are stored under ``"a,b"`` string keys because
        JSON object keys must be strings; for the same reason the integer keys
        of ``id_to_token`` are written out as strings by ``json.dump``.
        """
        vocabulary = {
            'token_to_id': self.token_to_id,
            'id_to_token': self.id_to_token,
            'merges': {f"{k[0]},{k[1]}": v for k, v in self.merges.items()},
        }
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(vocabulary, f, ensure_ascii=False, indent=2)

    def decode(self, ids):
        """Return the text for a sequence of token ids.

        The byte strings of all ids are looked up in ``self.vocab`` and
        concatenated, then decoded as UTF-8; invalid sequences become U+FFFD.
        """
        text_bytes = b"".join(self.vocab[idx] for idx in ids)
        return text_bytes.decode("utf-8", errors="replace")

    def encode(self, text):
        """Return the list of token ids for ``text``.

        Starts from the raw UTF-8 bytes and keeps applying learned merges
        until no adjacent pair in the sequence has a merge entry.
        """
        ids = list(text.encode("utf-8"))
        while len(ids) >= 2:
            stats = get_stats(ids)
            # Pick the present pair with the lowest merge id, i.e. the one
            # learned earliest; pairs without a merge rank as infinity.
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            if pair not in self.merges:
                # Even the best candidate is unknown: nothing more to merge.
                break
            ids = merge(ids, pair, self.merges[pair])
        return ids