# M3Site/esm/tokenization/sequence_tokenizer.py
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.processors import TemplateProcessing
from transformers import PreTrainedTokenizerFast

from esm.tokenization.tokenizer_base import EsmTokenizerBase
from esm.utils.constants import esm3 as C


class EsmSequenceTokenizer(PreTrainedTokenizerFast, EsmTokenizerBase):
    """
    Constructs an ESM3 amino-acid sequence tokenizer: a character-level tokenizer
    over C.SEQUENCE_VOCAB, implemented as BPE with no merges.
    """

    model_input_names = ["sequence_tokens", "attention_mask"]

    def __init__(
        self,
        unk_token="<unk>",
        cls_token="<cls>",
        pad_token="<pad>",
        mask_token="<mask>",
        eos_token="<eos>",
        chainbreak_token="|",
        **kwargs,
    ):
        all_tokens = C.SEQUENCE_VOCAB
        token_to_id = {tok: ind for ind, tok in enumerate(all_tokens)}

        # a character-level tokenizer is the same as BPE with no token merges
        bpe = BPE(token_to_id, merges=[], unk_token=unk_token)
        tokenizer = Tokenizer(bpe)
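        # Illustrative (not in the original source): because there are no merges,
        # every character maps to exactly one token, so — assuming "M", "K", "V"
        # are entries in C.SEQUENCE_VOCAB — at this point
        # tokenizer.encode("MKV").tokens == ["M", "K", "V"].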
        special_tokens = [cls_token, pad_token, mask_token, eos_token, chainbreak_token]
        additional_special_tokens = [chainbreak_token]

        tokenizer.add_special_tokens(special_tokens)

        # This is where we configure the automatic addition of special tokens when we call
        # tokenizer(text, add_special_tokens=True). Note that you can also configure how two
        # sequences are merged if you want.
        tokenizer.post_processor = TemplateProcessing(  # type: ignore
            single="<cls> $A <eos>",
            special_tokens=[
                ("<cls>", tokenizer.token_to_id("<cls>")),
                ("<eos>", tokenizer.token_to_id("<eos>")),
            ],
        )
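        # Illustrative effect (not in the original source; assumes "M" and "K" are
        # in C.SEQUENCE_VOCAB): after this line,
        # tokenizer.encode("MK").tokens == ["<cls>", "M", "K", "<eos>"].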
        super().__init__(
            tokenizer_object=tokenizer,
            unk_token=unk_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            eos_token=eos_token,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )

    # The `bos` token is a footgun: we never use it anywhere, so we just alias it
    # to the `cls` token here.
    @property
    def bos_token(self):
        return self.cls_token

    @property
    def bos_token_id(self):
        return self.cls_token_id
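

# Minimal usage sketch (illustrative, not part of the original module). Assumes the
# esm package is importable and that C.SEQUENCE_VOCAB holds the standard amino-acid
# alphabet plus the special tokens; the example sequence below is arbitrary.
if __name__ == "__main__":
    tok = EsmSequenceTokenizer()
    ids = tok.encode("MKTAYIAK")  # post-processor wraps the ids in <cls> ... <eos>
    print(ids)
    print(tok.decode(ids))
    # bos is aliased to cls by the properties above.
    assert tok.bos_token_id == tok.cls_token_id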