par-meta committed · Commit ff36aa8 · unverified · 1 Parent(s): a6ed14f

Add vocab and seq len abstract fields (#66)

bytelatent/base_transformer.py CHANGED
@@ -1,4 +1,5 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+import abc
 import logging
 import os
 from enum import Enum
@@ -572,7 +573,13 @@ class TransformerBlock(nn.Module):
         self.ffn_norm.reset_parameters()
 
 
-class BaseTransformer(nn.Module):
+class SequenceModelWithOutput(abc.ABC):
+    @abc.abstractmethod
+    def get_output_seq_len(self) -> int:
+        pass
+
+
+class BaseTransformer(nn.Module, SequenceModelWithOutput):
     def __init__(self, args: BaseTransformerArgs):
         super().__init__()
         self.dim = args.dim
@@ -593,6 +600,9 @@ class BaseTransformer(nn.Module):
         for _ in range(args.n_layers):
             self.layers.append(TransformerBlock(args))
 
+    def get_output_seq_len(self):
+        return self.max_seqlen
+
     def forward(
         self,
         h,
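For context, a minimal sketch (not repo code) of how an interface like SequenceModelWithOutput is consumed downstream: callers ask any conforming model for its output sequence length instead of reaching into a concrete class. ToyModel, packing_buffer_size, and the max_seqlen default below are hypothetical illustrations.

import abc


class SequenceModelWithOutput(abc.ABC):
    @abc.abstractmethod
    def get_output_seq_len(self) -> int:
        pass


class ToyModel(SequenceModelWithOutput):
    # Hypothetical model; a real subclass would set max_seqlen from its args.
    def __init__(self, max_seqlen: int = 1024):
        self.max_seqlen = max_seqlen

    def get_output_seq_len(self) -> int:
        return self.max_seqlen


def packing_buffer_size(model: SequenceModelWithOutput, batch_size: int) -> int:
    # Callers size buffers from the interface, not from a concrete model class.
    return batch_size * model.get_output_seq_len()


print(packing_buffer_size(ToyModel(), batch_size=4))  # 4096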
bytelatent/model/blt.py CHANGED
@@ -12,6 +12,7 @@ from typing_extensions import Self
 from bytelatent.base_transformer import (
     BaseTransformerArgs,
     InitStdFactor,
+    SequenceModelWithOutput,
     TransformerBlock,
 )
 from bytelatent.data.patcher import Patcher, PatcherArgs
@@ -766,7 +767,7 @@ def compute_hash_embeddings(
     return local_encoder_embeds
 
 
-class ByteLatentTransformer(nn.Module):
+class ByteLatentTransformer(nn.Module, SequenceModelWithOutput):
     """
     The ByteLatentTransformer (BLT) is a byte-level language model architecture that processes byte sequences
     by dynamically segmenting them into patches. It uses a combination of local encoders, global transformers,
@@ -856,6 +857,9 @@ class ByteLatentTransformer(nn.Module):
             )
         )
 
+    def get_output_seq_len(self):
+        return self.max_seqlen
+
     def forward(
         self,
         tokens: torch.Tensor,
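A hedged toy example (again, not repo code) of the inheritance pattern used here for BaseTransformer and ByteLatentTransformer: mixing nn.Module with an abc.ABC works because ABCMeta is a subclass of type, so abstract-method enforcement still applies to the combined class. TinyModel and its sizes are invented for illustration.

import abc

import torch
from torch import nn


class SequenceModelWithOutput(abc.ABC):
    @abc.abstractmethod
    def get_output_seq_len(self) -> int:
        pass


class TinyModel(nn.Module, SequenceModelWithOutput):
    # Hypothetical model mixing torch behavior with the abstract interface.
    def __init__(self, max_seqlen: int = 8, dim: int = 4):
        super().__init__()
        self.max_seqlen = max_seqlen
        self.proj = nn.Linear(dim, dim)

    def get_output_seq_len(self) -> int:
        return self.max_seqlen

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        return self.proj(h)


model = TinyModel()
print(model.get_output_seq_len())         # 8
print(model(torch.zeros(1, 8, 4)).shape)  # torch.Size([1, 8, 4])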
bytelatent/tokenizers/abstract_tokenizer.py CHANGED
@@ -17,3 +17,7 @@ class Tokenizer(abc.ABC):
     ) -> tuple[list[str], list[int]]:
         """Return the offsets of the tokens in the original text. Only used for evaluation."""
         pass
+
+    @abc.abstractmethod
+    def get_vocab_size(self) -> int:
+        pass
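To make the new requirement concrete, here is a hypothetical minimal Tokenizer subclass (not from the repo, with the encode signature simplified) showing what get_vocab_size must report: the number of ids the tokenizer can emit, which callers can then use to size embedding tables.

import abc


class Tokenizer(abc.ABC):
    # Simplified stand-in for bytelatent's abstract tokenizer.
    @abc.abstractmethod
    def encode(self, text: str) -> list[int]:
        pass

    @abc.abstractmethod
    def get_vocab_size(self) -> int:
        pass


class ByteTokenizer(Tokenizer):
    def encode(self, text: str) -> list[int]:
        return list(text.encode("utf-8"))

    def get_vocab_size(self) -> int:
        return 256  # one id per possible byte value


print(ByteTokenizer().get_vocab_size())  # 256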
bytelatent/tokenizers/blt_tokenizer.py CHANGED
@@ -101,6 +101,9 @@ class BltTokenizer(Tokenizer):
         self.vocab_size_unit_1 = vocab_size_unit_1
         self.n_words = vocab_size_unit_1 + self.offsetting_special_char
 
+    def get_vocab_size(self) -> int:
+        return self.n_words
+
     def encode(
         self, text: str, add_bos: bool | None = None, add_eos: bool | None = None
     ):
bytelatent/tokenizers/sentence_piece_tokenizer.py CHANGED
@@ -35,6 +35,9 @@ class SentencePieceTokenizer(Tokenizer):
         )
         assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
 
+    def get_vocab_size(self) -> int:
+        return self.n_words
+
     def encode(self, s: str, add_bos: bool | None = None, add_eos: bool | None = None):
         if add_bos is None:
             add_bos = self.add_bos
bytelatent/tokenizers/tiktoken_tokenizer.py CHANGED
@@ -53,6 +53,9 @@ class TikTokenTokenizer(Tokenizer):
             f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
         )
 
+    def get_vocab_size(self) -> int:
+        return self.n_words
+
     def encode(self, s: str, add_bos: bool, add_eos: bool):
         assert isinstance(s, str)
 
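Taken together, the two accessors let downstream code size model components without hard-coding constants. The sketch below is an assumption about intended usage, not code from this commit; vocab_size and max_seqlen are stand-ins for tokenizer.get_vocab_size() and model.get_output_seq_len().

import torch
from torch import nn


def build_embedding(vocab_size: int, dim: int) -> nn.Embedding:
    # Embedding rows must match the tokenizer's reported vocabulary size.
    return nn.Embedding(vocab_size, dim)


vocab_size = 260   # stand-in for tokenizer.get_vocab_size()
max_seqlen = 4096  # stand-in for model.get_output_seq_len()

emb = build_embedding(vocab_size, dim=512)
tokens = torch.randint(0, vocab_size, (1, max_seqlen))
print(emb(tokens).shape)  # torch.Size([1, 4096, 512])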