import torch


class SimpleTokenizer:
    """Minimal character-level tokenizer backed by a {character: index}
    vocabulary dictionary saved with torch.save."""

    def __init__(self, vocab_path):
        # Load the character-to-index mapping from disk.
        self.char_to_idx = torch.load(vocab_path)

        # Guarantee an '<unk>' entry so encode() always has a fallback index.
        if '<unk>' not in self.char_to_idx:
            self.char_to_idx['<unk>'] = max(self.char_to_idx.values()) + 1

        # Reverse mapping (index -> character) used by decode().
        self.idx_to_char = {i: c for c, i in self.char_to_idx.items()}

    def encode(self, text):
        # Map each character to its index; characters not in the vocabulary
        # fall back to '<unk>' (or 0 if '<unk>' is somehow absent).
        return [self.char_to_idx.get(c, self.char_to_idx.get('<unk>', 0)) for c in text]

    def decode(self, indices):
        # Map indices back to characters, silently dropping unknown indices.
        return ''.join([self.idx_to_char.get(i, '') for i in indices])
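

# The vocabulary file is assumed to be a plain {character: index} dict created
# ahead of time, for example (illustrative sketch only; 'corpus.txt' is a
# placeholder, not part of this script):
#
#     chars = sorted(set(open('corpus.txt', encoding='utf-8').read()))
#     torch.save({c: i for i, c in enumerate(chars)}, 'vocab.pth')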

# Example usage: encode and decode a short string.
vocab_path = 'vocab.pth'
tokenizer = SimpleTokenizer(vocab_path)

text = "Hello, world!"
tokens = tokenizer.encode(text)
print(tokens)

decoded_text = tokenizer.decode(tokens)
print(decoded_text)