from tokenizers.bpe import Tokenizer, get_stats, merge
import json


class BasicTokenizer(Tokenizer):
    """Minimal byte-level BPE tokenizer that learns merges directly over raw UTF-8 bytes."""

    def __init__(self):
        super().__init__()
        # human-readable vocabulary maps, populated during train()
        self.token_to_id = {}
        self.id_to_token = {}

    def train(self, text, vocab_size, verbose=False):
        assert 256 <= vocab_size <= 5000
        num_merges = vocab_size - 256

        # start from the raw UTF-8 bytes of the training text
        text_bytes = text.encode('utf-8')
        ids = list(text_bytes)

        merges = {}  # (int, int) -> int
        vocab = {idx: bytes([idx]) for idx in range(256)}  # int -> bytes

        # Initialize token mappings; the decoded strings are lossy for bytes that
        # are not valid UTF-8 on their own (the exact bytes are kept in vocab)
        self.token_to_id = {bytes([idx]).decode('utf-8', errors='replace'): idx for idx in range(256)}
        self.id_to_token = {idx: bytes([idx]).decode('utf-8', errors='replace') for idx in range(256)}

        for i in range(num_merges):
            # count adjacent pair frequencies and merge the most common pair
            # into a newly minted token id
            stats = get_stats(ids)
            if not stats:
                break
            pair = max(stats, key=stats.get)
            idx = 256 + i
            ids = merge(ids, pair, idx)
            merges[pair] = idx
            vocab[idx] = vocab[pair[0]] + vocab[pair[1]]

            # Update token mappings
            token = vocab[idx].decode('utf-8', errors='replace')
            self.token_to_id[token] = idx
            self.id_to_token[idx] = token

            if verbose and i % 20 == 0:  # report progress every 20th merge
                print(f"merge {i + 1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")

        self.merges = merges
        self.vocab = vocab
        self._save_vocabulary()

    def _save_vocabulary(self):
        # persist a human-readable snapshot of the learned vocabulary and merges
        vocabulary = {
            'token_to_id': self.token_to_id,
            'id_to_token': self.id_to_token,
            'merges': {f"{k[0]},{k[1]}": v for k, v in self.merges.items()}
        }
        with open('vocabulary.json', 'w', encoding='utf-8') as f:
            json.dump(vocabulary, f, ensure_ascii=False, indent=2)

    def decode(self, ids):
        text_bytes = b"".join(self.vocab[idx] for idx in ids)
        text = text_bytes.decode("utf-8", errors="replace")
        return text

    def encode(self, text):
        text_bytes = text.encode("utf-8")
        ids = list(text_bytes)
        while len(ids) >= 2:
            # among all adjacent pairs, pick the one that was merged earliest
            # during training (lowest merge index); if no pair appears in the
            # merge table, nothing more can be merged
            stats = get_stats(ids)
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))

            if pair not in self.merges:
                break

            idx = self.merges[pair]
            ids = merge(ids, pair, idx)
        return ids
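

# Minimal usage sketch (an assumption, not part of the original module): trains on a
# tiny in-memory string, then round-trips a piece of text through encode/decode.
# Note that train() also writes vocabulary.json to the current directory.
if __name__ == "__main__":
    tokenizer = BasicTokenizer()
    sample_text = "hello world, hello BPE tokenizers"  # hypothetical sample text
    tokenizer.train(sample_text, vocab_size=260, verbose=True)

    ids = tokenizer.encode("hello world")
    print(ids)                    # token ids after applying the learned merges
    print(tokenizer.decode(ids))  # recovers "hello world"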