# Copyright (c) Meta Platforms, Inc. and affiliates.

from enum import Enum, auto
from typing import Any, Optional

import torch
from pydantic import model_validator
from torch import nn
from torch.nn.attention.flex_attention import create_block_mask
from typing_extensions import Self

from bytelatent.base_transformer import (
    BaseTransformerArgs,
    InitStdFactor,
    SequenceModelWithOutput,
)
from bytelatent.data.patcher import Patcher, PatcherArgs
from bytelatent.model.latent_transformer import GlobalTransformer
from bytelatent.model.local_models import LocalDecoder, LocalEncoder, LocalModelArgs
from bytelatent.model.utils import downsample
from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID
from huggingface_hub import PyTorchModelHubMixin


def attention_flops_per_token(n_layers, seq_len, dim, causal):
    # Formula from https://github.com/Dao-AILab/flash-attention/blob/main/benchmarks/benchmark_flash_attention.py#L27-L30
    return 3.5 * (4 * n_layers * seq_len * dim // (2 if causal else 1))


def get_num_flop_per_token(
    num_non_embed_params: int, n_layers: int, dim: int, seq_len: int
) -> int:
    return 6 * num_non_embed_params + attention_flops_per_token(
        n_layers, seq_len, dim, True
    )
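

# Illustrative sketch (not part of the training pipeline): for a hypothetical
# 8-layer, 512-dim model with a 4096-token context and 50M non-embedding
# parameters, the estimate is the usual 6 * params term plus a causal-attention
# correction of roughly 1.2e8 FLOPs per token.
def _example_flop_estimate():
    n_layers, dim, seq_len = 8, 512, 4096
    num_non_embed_params = 50_000_000  # hypothetical parameter count
    attn = attention_flops_per_token(n_layers, seq_len, dim, causal=True)
    total = get_num_flop_per_token(num_non_embed_params, n_layers, dim, seq_len)
    assert total == 6 * num_non_embed_params + attn
    return total

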
def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


def setattrs(_self, **kwargs):
    for k, v in kwargs.items():
        setattr(_self, k, v)


def get_encoder_dim_token_emb(args):
    if args.dim_token is not None:
        dim_token_emb = args.dim_token
    elif args.use_local_encoder_transformer:
        dim_token_emb = args.dim_local_encoder
    else:
        dim_token_emb = args.dim_global // args.patch_size
    return dim_token_emb


def get_encoder_dim_patch_emb(args):
    dim_patch_emb = None
    if args.cross_attn_encoder:
        if args.cross_attn_init_by_pooling:
            dim_patch_emb = args.dim_local_encoder
        else:
            dim_patch_emb = args.dim_global
    return dim_patch_emb


def get_global_dim_patch_emb(args):
    dim_token_emb = get_encoder_dim_token_emb(args)
    if args.cross_attn_encoder:
        dim_patch_emb = dim_token_emb * args.cross_attn_k
    elif (
        args.downsampling_by_pooling is None
        or not args.downsampling_by_pooling
        or len(args.downsampling_by_pooling) == 0
    ):
        dim_patch_emb = dim_token_emb * args.patch_size
    else:
        dim_patch_emb = dim_token_emb * sum(
            [
                pooling in args.downsampling_by_pooling
                for pooling in ["avg", "min", "max"]
            ]
        )
    return dim_patch_emb


def get_decoder_dim_token_emb(args):
    if args.share_encoder_decoder_emb:
        dim_token_emb = get_encoder_dim_token_emb(args)
    elif args.dim_token is not None:
        dim_token_emb = args.dim_token
    else:
        dim_token_emb = args.dim_local_decoder
    return dim_token_emb
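

# Illustrative sketch with a hypothetical config (SimpleNamespace stands in for the
# real args object): a pooled cross-attention encoder at dim 512 with cross_attn_k=2
# feeds the global model 512 * 2 = 1024-dimensional patch embeddings.
def _example_patch_emb_dims():
    from types import SimpleNamespace

    args = SimpleNamespace(
        dim_token=None,
        use_local_encoder_transformer=True,
        dim_local_encoder=512,
        dim_global=768,
        patch_size=4,
        cross_attn_encoder=True,
        cross_attn_init_by_pooling=True,
        cross_attn_k=2,
        downsampling_by_pooling=None,
    )
    assert get_encoder_dim_token_emb(args) == 512
    assert get_encoder_dim_patch_emb(args) == 512
    assert get_global_dim_patch_emb(args) == 1024
    return args

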
def parse_ngram_to_size(ngram_to_size_str: str | None) -> dict[int, int] | None:
    if ngram_to_size_str is None:
        return None
    ngram_to_size = {}
    for entry in ngram_to_size_str.split(","):
        ngram, size = entry.split(":")
        ngram = int(ngram)
        size = int(size)
        ngram_to_size[ngram] = size
    return ngram_to_size
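

# Illustrative sketch: the expected string format is comma-separated "ngram:table_size"
# pairs, e.g. n-gram orders 2 and 3 with hypothetical vocabulary sizes.
def _example_parse_ngram_to_size():
    assert parse_ngram_to_size("2:38396,3:50000") == {2: 38396, 3: 50000}

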
def fill_tokens(tokens, patch_size, fill_id):
    batch_size, seq_len = tokens.shape
    if seq_len % patch_size == 0:
        return tokens
    else:
        remaining = patch_size - seq_len % patch_size
        final_padding = tokens.new(batch_size, remaining).fill_(fill_id)
        return torch.cat((tokens, final_padding), dim=1)
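

# Illustrative sketch: a batch of 10-token sequences is right-padded up to the next
# multiple of patch_size=4 with a hypothetical fill id of 0.
def _example_fill_tokens():
    tokens = torch.arange(20).reshape(2, 10)
    padded = fill_tokens(tokens, patch_size=4, fill_id=0)
    assert padded.shape == (2, 12)
    assert torch.all(padded[:, 10:] == 0)
    return padded

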
def decoder_patch_ids_from_lengths(patch_lengths, nb_boe, seq_len):
    first_patch_length = patch_lengths[0, 0]
    assert torch.all(
        first_patch_length == patch_lengths[:, 0]
    ), "first patch should always be the same size (1 for dynamic, patch_size for static)."
    assert (
        first_patch_length - nb_boe == 1
    ), f"First patch (patch length: {first_patch_length}) should have one non-boe token (boe toks: {nb_boe})"
    # Remove first patch from patch_ids for local decoder inputs and shift the last patch.
    # decoder_patch_lengths = patch_lengths[:, 1:].clone()
    # decoder_patch_lengths = add_to_last_nonzero_patch(decoder_patch_lengths, 1)
    decoder_patch_lengths = patch_lengths[:, 1:]
    assert (
        decoder_patch_lengths.sum() + (nb_boe + 1) * patch_lengths.shape[0]
        == patch_lengths.sum()
    ), f"{decoder_patch_lengths.sum() + (nb_boe + 1) * patch_lengths.shape[0]} != {patch_lengths.sum()}"
    assert torch.all(decoder_patch_lengths >= 0), f"{decoder_patch_lengths}"
    decoder_patch_ids = patch_ids_from_lengths(
        patch_lengths=decoder_patch_lengths, seq_len=seq_len
    )
    return decoder_patch_ids
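

# Illustrative sketch with hypothetical lengths: under dynamic patching the first
# patch holds a single byte, so it is dropped and token-to-patch ids are re-derived
# from the remaining patch lengths.
def _example_decoder_patch_ids():
    patch_lengths = torch.tensor([[1, 3, 2, 1]])
    decoder_patch_ids = decoder_patch_ids_from_lengths(
        patch_lengths, nb_boe=0, seq_len=6
    )
    assert torch.equal(decoder_patch_ids, torch.tensor([[0, 0, 0, 1, 1, 2]]))
    return decoder_patch_ids

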
primes = [
    1000000007,
    5915587277,
    1500450271,
    3267000013,
    5754853343,
    4093082899,
    9576890767,
    3628273133,
    2860486313,
    5463458053,
    3367900313,
]


def rolling_polynomial_hash(t, hash_func_nb: int = 0):
    prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=t.device)
    prime_powers = torch.stack([prime**i for i in range(t.shape[-1])])
    return torch.sum(t * prime_powers, dim=-1)


def get_rolling_polynomial_hash_fn(hash_func_nb: int = 0, group_size: int = 2):
    prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64)
    prime_powers = torch.stack([prime**i for i in range(group_size)])

    def rolling_polynomial_hash_fn(t):
        return torch.sum(t * prime_powers, dim=-1)

    return rolling_polynomial_hash_fn


def byte_group_hash_function(
    x: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000
):
    """
    Returns a rolling polynomial hash of the input x, mapped to a value in the range [0, max_hash).

    Expects: x of shape (batch_size, seq_len) with values that are ids in the token vocab.
    Returns: a tensor of shape (batch_size, seq_len) with values in the range [0, max_hash).

    Note: max_hash can make a big difference in the number of collisions.
    """
    with torch.no_grad():
        bs, seq_len = x.shape
        # x_numpy = x.numpy()
        # hash_values = torch.zeros(bs, seq_len, dtype=torch.int64, requires_grad=False)
        # for i in range(bs):
        #     for j in range(seq_len):
        #         start = max(j, j - group_size + 1)
        #         end = j + 1
        #         hash_values[i, j] = hash_array(x_numpy[i, start:end], max_hash)
        prefix = torch.zeros(bs, group_size - 1, dtype=torch.int64, device=x.device)
        x = torch.cat([prefix, x], dim=1)
        windows = x.unfold(1, group_size, 1)
        # hashes = get_rolling_polynomial_hash_fn(hash_func_nb, group_size)(windows)
        hashes = rolling_polynomial_hash(windows, hash_func_nb)
        hash_values_range = hashes % max_hash
    hash_values_range.requires_grad = False
    return hash_values_range
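

# Illustrative sketch: each position j is hashed together with the (group_size - 1)
# bytes before it (zero-padded at the start of the sequence), so the output keeps the
# input shape and stays below max_hash.
def _example_byte_group_hash():
    x = torch.tensor([[3, 4, 5, 6]])
    hashes = byte_group_hash_function(x, group_size=2, hash_func_nb=0, max_hash=30000)
    assert hashes.shape == x.shape
    assert torch.all(hashes >= 0) and torch.all(hashes < 30000)
    return hashes

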
def create_patch_mask_from_ids(
    patch_ids, num_patches, window=None, patches_as_queries=False
):
    """
    Creates a boolean mask of shape [bs, seq_len, num_patches] (or [bs, num_patches, seq_len]
    when patches_as_queries=True) where the element for token j and patch k is True if token j
    belongs to patch k (when window is None) or to a patch within `window` of k.

    Args:
        patch_ids (torch.Tensor): Tensor of shape [bs, seq_len] containing patch ids.
        num_patches (int): Total number of patches.
        window (int): If not None, only considers patches within a window of size `window`.
        patches_as_queries (bool): If True, the patches are used as queries.
    Returns:
        torch.Tensor: Tensor of shape [bs, q_len, kv_len] with the desired mask.
    """
    bs, seq_len = patch_ids.shape
    if not patches_as_queries:
        q_ids = patch_ids.unsqueeze(-1).expand(bs, seq_len, num_patches)
        kv_ids = (
            torch.arange(num_patches, device=patch_ids.device)
            .unsqueeze(0)
            .unsqueeze(0)
            .expand(bs, seq_len, num_patches)
        )
    else:
        kv_ids = patch_ids.unsqueeze(1).expand(bs, num_patches, seq_len)
        q_ids = (
            torch.arange(num_patches, device=patch_ids.device)
            .unsqueeze(0)
            .unsqueeze(-1)
            .expand(bs, num_patches, seq_len)
        )
    if window is None:
        mask = q_ids == kv_ids
    else:
        mask = (kv_ids <= q_ids) & (q_ids < kv_ids + window)
    return mask
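

# Illustrative sketch: with window=None each token is paired only with its own
# patch, so the mask is one-hot over the patch axis.
def _example_patch_mask():
    patch_ids = torch.tensor([[0, 0, 1, 1, 2]])
    mask = create_patch_mask_from_ids(patch_ids, num_patches=3)
    assert mask.shape == (1, 5, 3)
    assert torch.equal(mask[0, 2], torch.tensor([False, True, False]))
    return mask

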
def cross_attn_mask(
    patch_ids,
    patch_lengths,
    N,
    patches_as_queries=False,
    cross_attn_k=1,
    window=None,
    block_mask=True,
):
    bs = patch_ids.shape[0]
    with torch.no_grad():
        # Create the patch mask
        cross_mask = create_patch_mask_from_ids(
            patch_ids,
            patch_lengths.shape[1],
            window=window,
            patches_as_queries=patches_as_queries,
        ).repeat_interleave(cross_attn_k, dim=1 if patches_as_queries else -1)
        q_len = patch_lengths.shape[1] * cross_attn_k if patches_as_queries else N
        kv_len = N if patches_as_queries else patch_lengths.shape[1] * cross_attn_k
        assert cross_mask.shape == (
            bs,
            q_len,
            kv_len,
        ), f"{cross_mask.shape} != {(bs, q_len, kv_len)}"
        if block_mask:

            def patch_mask(b, h, q_idx, kv_idx):
                return cross_mask[b, q_idx, kv_idx]

            block_mask = create_block_mask(
                patch_mask,
                B=bs,
                H=None,
                Q_LEN=q_len,
                KV_LEN=kv_len,
                _compile=True,
            )
            return block_mask
        else:
            return torch.where(
                cross_mask, torch.tensor(0.0), torch.tensor(float("-inf"))
            ).unsqueeze(
                1
            )  # [bs, 1, q_len, kv_len]
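

# Illustrative sketch of the dense (block_mask=False) path with hypothetical patch
# lengths: the additive mask has shape [bs, 1, q_len, kv_len], with 0.0 on allowed
# (token, patch) pairs and -inf elsewhere.
def _example_cross_attn_mask_dense():
    patch_ids = torch.tensor([[0, 1, 1, 1, 2, 2]])  # token-to-patch ids for 3 patches
    patch_lengths = torch.tensor([[1, 3, 2]])
    mask = cross_attn_mask(
        patch_ids,
        patch_lengths,
        N=6,
        patches_as_queries=False,
        cross_attn_k=1,
        window=None,
        block_mask=False,
    )
    assert mask.shape == (1, 1, 6, 3)
    return mask

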
def get_blt_input(
    tokens: torch.Tensor,
    enforce_patch_size_multiple: bool,
    nb_boe: torch.Tensor,
    patch_size: int,
    boe_id: int,
):
    """
    This function returns X_et, X_gt and X_dt, the encoder, global, and decoder
    tokens respectively.

    Consider the input and target sequences:
    X=[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12,13]
    Y=[4,5,6,7,eos,bos,8,9,10,eos,bos,11,12,13,14]
    with patch_size=4

    Note 1: there will be no special tokens introduced at the patch level.
    Note 2: X_e needs to be trimmed to be passed to Global

    Current without boe:
    X_et = [[boe,boe,boe,boe] [3,4,5,6], [7,eos,bos,8], [9,10,eos,bos] [11,12,13, pad]]
    X_g = [[boe,boe,boe,boe] [3,4,5,6], [7,eos,bos,8], [9,10,eos,bos] [11,12,13, pad]] # remove last glob patch
    X_dt = [[3,4,5,6] [7,eos,bos,8], [9,10,eos,bos], [11,12,13]]
    Y = [[4,5,6,7] [eos,bos,8,9], [10,eos,bos,11], [12,13,14]]

    --> lag fix:
    X_et = [[boe,boe,boe,3] [4,5,6,7], [eos,bos,8,9], [10,eos,bos,11] [12,13,pad,pad]]
    X_g = [[boe,boe,boe,3] [4,5,6,7], [eos,bos,8,9], [10,eos,bos,11]]
    X_dt = [[3,4,5,6] [7,eos,bos,8], [9,10,eos,bos], [11,12,13]]
    Y = [[4,5,6,7] [eos,bos,8,9], [10,eos,bos,11], [12,13,14]]

    Dynamic (current):
    X = [3,4,5,6,7,eos,bos,8,9,10,eos,bos]
    Y = [4,5,6,7,eos,bos,8,9,10,eos,bos,11]
    entropy patching:
        input: 7, bos, 9, 10
        pred (high entropy): eos, 8, 10, eos
    X_et = [[boe,3,4,5,6,7,eos,bos,8,9,10,eos,bos]
    X_g = [[boe], [3,4,5,6], [7,eos],[bos,8],[9], [10,eos]]
    X_dt = [[3,4,5,6], [7,eos], [bos,8],[9], [10,eos],[bos]]
    Y = [4,5,6,7,eos,bos,8,9,10,eos,bos,11]

    --> lag fix no boe (force single byte first patch):
    X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12]
    X_g = [[3], [4,5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # remove last global patch
    X_dt = [[3,4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11,12]]
    Y = [4,5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12,13]
        input: 4, 7, bos, 9, 10
        pred (high entropy): 5, eos, 8, 10, eos
    X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12]
    X_g = [[3], [4], [5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # remove last global patch
    X_dt = [[3] [4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11,12]]
    Y = [4,] [5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12,13]

    Handle the last byte properly.
    patch_lengths = [1, 1, 3, 2, 2, 1, 2, 2, 1]
    X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12]
    X_g = [[3], [4], [5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # do not remove last global patch
    X_dt = [[3] [4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11] [12]]
    Y = [4,] [5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12, 13]]

    bpe delim:
    X_et = [[3,4,5,6,7,<d>,eos,bos,<d>,8,9,<d>,10,<d>,eos,bos,11,12]
    X_g = [[3], [4,5,6,7,<d>], [eos,bos,<d>], ..
    X_dt = [[3,4,5,6,7], [<d>,eos,bos], [<d>,bos,8], ..
    Y = [4,5,6,7,<d>, eos,bos,<d> 8,9,<d>, ..

    Note 1: there will be no special tokens introduced at the patch level.
    Note 2: X_e needs to be trimmed to be passed to Global
    """
    batch_size, seq_len = tokens.shape
    local_encoder_tokens = tokens
    local_decoder_tokens = tokens

    if nb_boe > 0:
        padded_patch = tokens.new(batch_size, nb_boe).fill_(boe_id)
        local_encoder_tokens = torch.cat((padded_patch, local_encoder_tokens), dim=1)
    # global_tokens = tokens.new(batch_size, ((seq_len-1) // patch_size)+1).fill_(boe_id)

    # create global tokens, contains boe tokens and eos
    # padded_local_encoder_tokens = fill_tokens(local_encoder_tokens, patch_size, boe_id)
    # patches = padded_local_encoder_tokens.view(batch_size, -1, patch_size)
    # global_tokens = (patches.eq(eos_id).any(dim=2).int() * eos_id)[:, 1:]
    # global_tokens += global_tokens.eq(0).int() * boe_id
    # TODO: fix this when we want to use block causal in the global.

    if enforce_patch_size_multiple and local_encoder_tokens.shape[-1] % patch_size != 0:
        local_encoder_tokens = fill_tokens(local_encoder_tokens, patch_size, boe_id)

    return local_encoder_tokens, None, local_decoder_tokens


def patch_ids_from_lengths(patch_lengths, seq_len):
    bs, num_patches = patch_lengths.shape
    # Create a tensor of cumulative sums of the patch lengths
    cum_d = torch.cat(
        [
            torch.zeros(bs, 1, dtype=patch_lengths.dtype, device=patch_lengths.device),
            patch_lengths.cumsum(dim=-1),
        ],
        dim=-1,
    )
    patch_ids = (cum_d.unsqueeze(-1) <= torch.arange(seq_len, device=cum_d.device)).sum(
        dim=-2
    ) - 1
    assert not (
        torch.max(patch_ids) > patch_lengths.shape[-1] or torch.min(patch_ids) < 0
    ), f"{torch.max(patch_ids)} > {patch_lengths.shape[-1]} or {torch.min(patch_ids)} < 0"
    return patch_ids
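

# Illustrative sketch: patch lengths [1, 3, 2] over a 6-token sequence map the tokens
# to patch ids [0, 1, 1, 1, 2, 2].
def _example_patch_ids_from_lengths():
    patch_lengths = torch.tensor([[1, 3, 2]])
    patch_ids = patch_ids_from_lengths(patch_lengths, seq_len=6)
    assert torch.equal(patch_ids, torch.tensor([[0, 1, 1, 1, 2, 2]]))
    return patch_ids

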
class ByteLatentTransformerArgs(BaseTransformerArgs):
    # Basic model configuration
    seed: int = 42
    vocab_size: int = -1
    dim: int = 512
    n_layers: int = 8
    n_heads: int = 8
    # TODO: What is the purpose of this parameter?
    weight_tying: bool = False
    patch_in_forward: bool = False

    # Architecture and dimensions
    dim_token: int | None = None
    dim_global: int = 512
    dim_local_decoder: int = 512
    dim_local_encoder: int = 512
    n_layers_global: int = 8
    n_layers_local_decoder: int = 8
    n_layers_local_encoder: int = 8

    # Tokenization and patching
    patch_size: float | None = None
    patching_mode: str | None = None
    patching_threshold: float | None = None
    patching_threshold_add: float | None = None
    monotonicity: bool = False
    patching_batch_size: int = 1
    patching_device: str = "cuda"
    max_patch_length: int | None = None

    # Encoder/Decoder configuration
    tie_local_encoder_decoder_logits: bool = False
    use_local_encoder_transformer: bool = False
    encoder_lm_loss: bool = False
    max_encoder_seq_length: int | None = None
    pad_to_max_length: bool = False
    encoder_enable_byte_ngrams: bool = False
    encoder_enable_byte_group_hash: bool = False
    ngram_vocab_sizes: int | None = None

    # Cross attention configurations
    cross_attn_encoder: bool = False
    cross_attn_decoder: bool = False
    cross_attn_window_encoder: int | None = None
    cross_attn_window_decoder: int | None = None
    cross_attn_k: int | None = None
    cross_attn_nheads: int | None = None
    cross_attn_all_layers_decoder: bool = False
    cross_attn_all_layers_encoder: bool = False
    cross_attn_use_flex_attention: bool = True
    cross_attn_init_by_pooling: bool = False

    # Encoder hash configurations
    encoder_hash_byte_group_size: Any | None = None
    encoder_hash_byte_group_vocab: int = 30000
    encoder_hash_byte_group_nb_functions: int = 3

    # Model behavior and optimization
    log_patch_lengths: bool = False
    non_linearity: str = "swiglu"
    use_rope: bool = True
    recompute_fc1_out: bool = False
    recompute_fc3_out: bool = False
    recompute_attn: bool = True
    custom_bwd: bool = False
    layer_ckpt: str = "all"

    # Initialization and attention
    init_use_gaussian: bool = True
    init_use_depth: str = "current"
    attn_bias_type: str = "causal"
    alpha_depth: str = "disabled"
    max_length: int = 2048

    # Norm configuration
    norm_eps: float = 1e-5
    norm_affine: bool = True
    pre_norm: bool = True
    norm_type: str = "rmsnorm"

    # Additional configurations
    multiple_of: int = 256
    ffn_dim_multiplier: float = 1.0
    dropout: float = 0
    output_size: int = -1

    # Additional parameters from ModelArgs
    architecture: str = "vanilla"
    share_encoder_decoder_emb: bool = True
    global_local_decoder_residual_layer: str | None = None
    tokenize_with_bpe_delimiter: bool = False
    patching_thresholds_str: str | None = None
    tie_local_encoder_decoder: bool = False
    encoder_preds_low_entropy_toks: float | None = None
    encoder_preds_random_toks: float | None = None
    dim_token_emb: int | None = None
    dim_patch_emb: int | None = None
    encoder_ngram_table_dir: str | None = None
    encoder_ngram_to_size_str: str | None = None

    # Model architecture params
    entropy_model_checkpoint_dir: str | None = None
    entropy_model_is_ngram_model: bool = False
    downsampling_by_pooling: str | None = None
    n_heads_global: int = 8
    n_heads_local_decoder: int = 8
    n_heads_local_encoder: int = 8
    n_kv_heads: int | None = None
    n_kv_heads_global: int | None = None
    conv_kernel_size: int | None = None
    local_attention_window_len: int | None = None

    # Performance optimization
    sequence_parallel: bool = False
    loss_parallel: bool = False
    fuse_sequence_parallel: bool = False
    use_fsdp: bool = True
    attn_to_keep: str = "all"

    # Parameter mixing
    pm_size: int = 0

    # Logging
    full_logging_n_layers: int = 4

    @model_validator(mode="after")
    def check_hash_byte_sizes(self) -> Self:
        if (
            self.encoder_hash_byte_group_size is not None
            and isinstance(self.encoder_hash_byte_group_size, str)
        ):
            self.encoder_hash_byte_group_size = [
                int(x)
                for x in self.encoder_hash_byte_group_size.split(",")
                if len(x) > 0
            ]
        return self
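

# Illustrative sketch (standalone, mirrors the normalization in check_hash_byte_sizes):
# a comma-separated string of byte-group sizes such as "3,4,8" becomes [3, 4, 8]
# before the hash embedding tables are built.
def _example_hash_byte_group_size_parsing():
    raw = "3,4,8"
    sizes = [int(x) for x in raw.split(",") if len(x) > 0]
    assert sizes == [3, 4, 8]
    return sizes

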
class GlobalTransformerArgs(ByteLatentTransformerArgs):
    # Global encoder specific dimensions
    dim_token_emb: int | None = None
    dim_patch_emb: int | None = None

    def __post_init__(self):
        # Override base args with global encoder specific values
        self.dim = self.dim_global
        self.n_layers = self.n_layers_global
        self.n_heads = self.n_heads_global
        self.n_kv_heads = self.n_kv_heads_global
        self.local_attention_window_len = None
        self.cross_attn_encoder = False
        self.cross_attn_decoder = False


class LocalDecoderArgs(ByteLatentTransformerArgs):
    # Local decoder specific dimensions
    dim_token_emb: int | None = None
    dim_patch_emb: int | None = None

    def __post_init__(self):
        # Override base args with local decoder specific values
        self.dim = self.dim_local_decoder
        self.n_layers = self.n_layers_local_decoder
        self.n_heads = self.n_heads_local_decoder
        self.cross_attn_encoder = False
        self.cross_attn_init_by_pooling = False
        self.attn_bias_type = "local_block_causal"


def create_global_transformer(args: ByteLatentTransformerArgs) -> GlobalTransformer:
    global_args = args.model_copy(
        deep=True,
        update=dict(
            dim=args.dim_global,
            n_layers=args.n_layers_global,
            n_heads=args.n_heads_global,
            n_kv_heads=args.n_kv_heads_global,
            local_attention_window_len=None,
            dim_token_emb=get_global_dim_patch_emb(args),
            dim_patch_emb=None,
            cross_attn_encoder=False,
            cross_attn_decoder=False,
        ),
    )
    return GlobalTransformer(global_args)


def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder:
    local_encoder_args = LocalModelArgs(
        # Updated args
        dim=args.dim_local_encoder,
        n_layers=args.n_layers_local_encoder,
        n_heads=args.n_heads_local_encoder,
        dim_token_emb=get_encoder_dim_token_emb(args),
        dim_patch_emb=get_encoder_dim_patch_emb(args),
        cross_attn_encoder=args.cross_attn_encoder,
        cross_attn_decoder=False,
        cross_attn_k=args.cross_attn_k if args.cross_attn_encoder else None,
        cross_attn_init_by_pooling=args.cross_attn_init_by_pooling,
        # Defaults
        head_dim=args.head_dim,
        max_seqlen=args.max_encoder_seq_length,
        dropout=args.dropout,
        vocab_size=args.vocab_size + args.pm_size,
        norm_eps=args.norm_eps,
        patch_size=args.patch_size,
        sliding_window=args.local_attention_window_len,
        use_rope=args.use_rope,
        rope_theta=args.rope_theta,
        rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
        init_base_std=args.init_base_std,
        init_std_factor=args.init_std_factor,
        n_kv_heads=args.n_kv_heads,
        attn_impl=args.attn_impl,
        attn_bias_type="local_block_causal",
        multiple_of=args.multiple_of,
        ffn_dim_multiplier=args.ffn_dim_multiplier,
        patching_mode=args.patching_mode,
        use_local_encoder_transformer=args.use_local_encoder_transformer,
        downsampling_by_pooling=args.downsampling_by_pooling,
        encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
        cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder,
        cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder,
        cross_attn_nheads=args.cross_attn_nheads,
        eos_id=args.eos_id,
    )
    return LocalEncoder(local_encoder_args)


def create_local_decoder(args: ByteLatentTransformerArgs) -> LocalDecoder:
    # First deep copy the original args
    local_decoder_args = LocalModelArgs(
        dim=args.dim_local_decoder,
        n_layers=args.n_layers_local_decoder,
        n_heads=args.n_heads_local_decoder,
        dim_token_emb=get_decoder_dim_token_emb(args),
        dim_patch_emb=args.dim_global,
        cross_attn_encoder=False,
        cross_attn_decoder=args.cross_attn_decoder,
        cross_attn_init_by_pooling=False,  # states are already defined
        cross_attn_k=args.cross_attn_k if args.cross_attn_decoder else None,
        # Defaults
        head_dim=args.head_dim,
        max_seqlen=args.max_encoder_seq_length,
        dropout=args.dropout,
        vocab_size=args.vocab_size + args.pm_size,
        norm_eps=args.norm_eps,
        patch_size=args.patch_size,
        sliding_window=args.local_attention_window_len,
        use_rope=args.use_rope,
        rope_theta=args.rope_theta,
        rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
        init_base_std=args.init_base_std,
        init_std_factor=args.init_std_factor,
        n_kv_heads=args.n_kv_heads,
        attn_impl=args.attn_impl,
        attn_bias_type="local_block_causal",
        multiple_of=args.multiple_of,
        ffn_dim_multiplier=args.ffn_dim_multiplier,
        patching_mode=args.patching_mode,
        use_local_encoder_transformer=args.use_local_encoder_transformer,
        downsampling_by_pooling=args.downsampling_by_pooling,
        encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
        cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder,
        cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder,
        cross_attn_nheads=args.cross_attn_nheads,
        eos_id=args.eos_id,
    )
    return LocalDecoder(local_decoder_args)


class EmbeddingType(Enum):
    HASH_TOK = auto()
    NGRAM = auto()


def init_embeddings(
    args,
    embedding_type: EmbeddingType,
    local_encoder_dim: int,
    encoder_hash_byte_group_size: list = None,
):
    if (
        embedding_type == EmbeddingType.HASH_TOK
        and args.encoder_hash_byte_group_size is None
    ):
        return None
    if embedding_type == EmbeddingType.NGRAM and args.encoder_ngram_to_size_str is None:
        return None

    embeddings = []

    if embedding_type == EmbeddingType.HASH_TOK:
        emb_dim = local_encoder_dim
        encoder_hash_byte_group_vocab = args.encoder_hash_byte_group_vocab
        for _ in range(args.encoder_hash_byte_group_nb_functions):
            for _ in encoder_hash_byte_group_size:
                embeddings.append(
                    nn.Embedding(
                        encoder_hash_byte_group_vocab,
                        emb_dim,
                    )
                )
    elif embedding_type == EmbeddingType.NGRAM:
        encoder_ngram_to_size = parse_ngram_to_size(args.encoder_ngram_to_size_str)
        emb_dim = local_encoder_dim
        OFFSET = 4  # This should be passed as parameter if it's variable
        for ngram_vocab_size in encoder_ngram_to_size.values():
            embeddings.append(nn.Embedding(ngram_vocab_size + OFFSET, emb_dim))

    return nn.ModuleList(embeddings)
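

# Illustrative sketch with a hypothetical config (SimpleNamespace stands in for the
# real args object): 2 hash functions over byte-group sizes [3, 4] yield 2 * 2 = 4
# hash embedding tables, each of shape (vocab, local_encoder_dim).
def _example_init_hash_embeddings():
    from types import SimpleNamespace

    args = SimpleNamespace(
        encoder_hash_byte_group_size=[3, 4],
        encoder_hash_byte_group_vocab=500,
        encoder_hash_byte_group_nb_functions=2,
    )
    tables = init_embeddings(
        args,
        EmbeddingType.HASH_TOK,
        local_encoder_dim=16,
        encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
    )
    assert len(tables) == 4 and tables[0].weight.shape == (500, 16)
    return tables

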
def compute_hash_embeddings(
    local_encoder_tokens: torch.Tensor,
    local_encoder,
    encoder_hash_tok_embedding: nn.ModuleList,
    encoder_hash_byte_group_nb_functions: int,
    encoder_hash_byte_group_size: list,
    encoder_hash_byte_group_vocab: int,
) -> torch.Tensor:
    """
    Compute embeddings using hash token embeddings.

    Args:
        local_encoder_tokens: Input tokens tensor
        local_encoder: Encoder object with tok_embeddings method
        encoder_hash_tok_embedding: ModuleList of hash token embeddings
        encoder_hash_byte_group_nb_functions: Number of hash functions
        encoder_hash_byte_group_size: List of byte group sizes
        encoder_hash_byte_group_vocab: Vocabulary size for hash embeddings

    Returns:
        torch.Tensor: Combined embeddings
    """
    if encoder_hash_tok_embedding is None:
        return None

    local_encoder_embeds = local_encoder.tok_embeddings(local_encoder_tokens)

    i = 0
    for func_nb in range(encoder_hash_byte_group_nb_functions):
        for byte_group_size in encoder_hash_byte_group_size:
            hash_ids = byte_group_hash_function(
                local_encoder_tokens,
                byte_group_size,
                hash_func_nb=func_nb,
                max_hash=encoder_hash_byte_group_vocab,
            )
            hash_tok_embedding = encoder_hash_tok_embedding[i]
            local_encoder_embeds = local_encoder_embeds + hash_tok_embedding(hash_ids)
            i += 1

    assert i == len(encoder_hash_tok_embedding)
    return local_encoder_embeds
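

# Illustrative sketch: the i-th table in the ModuleList must line up with the
# (hash function, byte-group size) pair visited in the same nested-loop order as
# above. All names and sizes below are hypothetical; SimpleNamespace stands in for
# a real LocalEncoder.
def _example_compute_hash_embeddings():
    from types import SimpleNamespace

    group_sizes = [2, 3]
    vocab, dim = 100, 8
    tables = nn.ModuleList(
        [nn.Embedding(vocab, dim) for _ in range(2 * len(group_sizes))]
    )
    local_encoder = SimpleNamespace(tok_embeddings=nn.Embedding(260, dim))
    tokens = torch.randint(0, 260, (1, 12))
    embeds = compute_hash_embeddings(
        local_encoder_tokens=tokens,
        local_encoder=local_encoder,
        encoder_hash_tok_embedding=tables,
        encoder_hash_byte_group_nb_functions=2,
        encoder_hash_byte_group_size=group_sizes,
        encoder_hash_byte_group_vocab=vocab,
    )
    assert embeds.shape == (1, 12, dim)
    return embeds

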
class ByteLatentTransformer(
    nn.Module,
    SequenceModelWithOutput,
    PyTorchModelHubMixin,
    repo_url="https://github.com/facebookresearch/blt",
    pipeline_tag="text-generation",
    license="other",
):
    """
    The ByteLatentTransformer (BLT) is a byte-level language model architecture that processes byte sequences
    by dynamically segmenting them into patches. It uses a combination of local encoders, global transformers,
    and local decoders to efficiently encode and decode byte sequences, leveraging patch-based processing for
    improved performance and inference efficiency.
    """

    def __init__(self, args: ByteLatentTransformerArgs):
        super().__init__()

        # General configuration
        self.weight_tying = args.weight_tying
        self.patch_size = args.patch_size
        self.patching_mode = args.patching_mode
        self.boe_id, self.bos_id, self.pad_id, self.eos_id = (
            BOE_ID,
            BOS_ID,
            PAD_ID,
            EOS_ID,
        )
        self.downsampling_by_pooling = args.downsampling_by_pooling
        self.patching_threshold = args.patching_threshold
        self.dim = args.dim
        self.init_base_std = args.init_base_std
        self.init_std_factor = InitStdFactor(args.init_std_factor)
        self.max_seqlen = args.max_seqlen

        # Cross attention configuration
        self.cross_attn_encoder = args.cross_attn_encoder
        self.cross_attn_decoder = args.cross_attn_decoder
        self.cross_attn_k = args.cross_attn_k
        self.cross_attn_window_encoder = args.cross_attn_window_encoder
        self.cross_attn_window_decoder = args.cross_attn_window_decoder
        self.cross_attn_use_flex_attention = args.cross_attn_use_flex_attention

        # Encoder hash configuration
        self.encoder_hash_byte_group_size = args.encoder_hash_byte_group_size
        self.encoder_hash_byte_group_vocab = args.encoder_hash_byte_group_vocab
        self.encoder_hash_byte_group_nb_functions = (
            args.encoder_hash_byte_group_nb_functions
        )

        # ByteLatent modules
        self.local_encoder = create_local_encoder(args)
        self.global_transformer = create_global_transformer(args)
        self.local_decoder = create_local_decoder(args)
        self.encoder_hash_tok_embedding = init_embeddings(
            args,
            EmbeddingType.HASH_TOK,
            local_encoder_dim=self.local_encoder.dim,
            encoder_hash_byte_group_size=self.encoder_hash_byte_group_size,
        )
        self.encoder_ngram_embedding = init_embeddings(
            args,
            EmbeddingType.NGRAM,
            local_encoder_dim=self.local_encoder.dim,
            encoder_hash_byte_group_size=None,
        )

        # Encoder ngram embedding tables
        self.encoder_ngram_embedding = None
        if args.encoder_enable_byte_ngrams:
            self.encoder_ngram_embedding = nn.ModuleList()
            assert args.ngram_vocab_sizes is not None
            self.encoder_ngram_to_size = parse_ngram_to_size(
                args.encoder_ngram_to_size_str
            )
            ngram_emb_dim = self.local_encoder.dim
            for ngram_vocab_size in self.encoder_ngram_to_size.values():
                self.encoder_ngram_embedding.append(
                    nn.Embedding(ngram_vocab_size + OFFSET, ngram_emb_dim)
                )

        # Output layer
        assert args.vocab_size > 0, "vocab_size must be greater than 0"

        # Patcher module
        if args.patch_in_forward:
            self.patcher = Patcher(
                PatcherArgs(
                    patch_size=args.patch_size,
                    patching_mode=args.patching_mode,
                    patching_threshold=args.patching_threshold,
                    patching_threshold_add=args.patching_threshold_add,
                    monotonicity=args.monotonicity,
                    max_patch_length=args.max_patch_length,
                )
            )

    def get_output_seq_len(self):
        return self.max_seqlen

    def forward(
        self,
        tokens: torch.Tensor,
        patch_lengths: Optional[torch.Tensor] = None,
        ngram_ids: Optional[torch.Tensor] = None,
    ):
        # Ensure ngram_ids is either a tensor or None
        assert (
            isinstance(ngram_ids, torch.Tensor) or ngram_ids is None
        ), f"ngram_ids must be a tensor or None, but was: {type(ngram_ids)}"

        bs, N = tokens.shape  # Batch size and sequence length

        # Get megabyte inputs
        nb_boe = int(0 if self.patching_mode != "" else self.patch_size - 1)
        local_encoder_tokens, _, local_decoder_tokens = get_blt_input(
            tokens=tokens,
            enforce_patch_size_multiple=False,
            nb_boe=nb_boe,
            patch_size=self.patch_size,
            boe_id=self.boe_id,
        )

        # Patching
        if patch_lengths is None:
            assert (
                getattr(self, "patcher", None) is not None
            ), "Patcher not defined and no patch_lengths passed."
            patch_lengths, tok_scores = self.patcher.patch(
                local_encoder_tokens,
                include_next_token=True,
                threshold=self.patcher.threshold,
            )
        else:
            if nb_boe > 0:
                patch_lengths[:, 0] += nb_boe

        assert torch.min(patch_lengths) >= 0

        # Generate patch IDs from patch_lengths
        patch_ids = patch_ids_from_lengths(
            patch_lengths, local_encoder_tokens.shape[-1]
        )
        assert torch.max(patch_ids) + 1 <= torch.max(
            (patch_lengths != 0).sum(dim=-1)
        ), f"{torch.max(patch_ids) + 1} > {torch.max((patch_lengths != 0).sum(dim=-1))}"

        cross_attn_mask_enc = None
        # Cross-attention encoder
        if self.cross_attn_encoder:
            cross_attn_mask_enc = cross_attn_mask(
                patch_ids,
                patch_lengths,
                N,
                patches_as_queries=True,
                cross_attn_k=self.cross_attn_k,
                window=self.cross_attn_window_encoder,
                block_mask=self.cross_attn_use_flex_attention,
            )

        # Hashing and embedding
        local_encoder_embeds = compute_hash_embeddings(
            local_encoder_tokens=local_encoder_tokens,
            local_encoder=self.local_encoder,
            encoder_hash_tok_embedding=self.encoder_hash_tok_embedding,
            encoder_hash_byte_group_nb_functions=self.encoder_hash_byte_group_nb_functions,
            encoder_hash_byte_group_size=self.encoder_hash_byte_group_size,
            encoder_hash_byte_group_vocab=self.encoder_hash_byte_group_vocab,
        )

        # N-gram table embeddings
        if self.encoder_ngram_embedding is not None:
            assert ngram_ids is not None, "ngram_ids must be provided"
            if local_encoder_embeds is None:
                local_encoder_embeds = self.local_encoder.tok_embeddings(
                    local_encoder_tokens
                )
            assert len(ngram_ids) == len(
                self.encoder_ngram_embedding
            ), f"ngram_ids.shape[0]={ngram_ids.shape[0]} versus len(encoder_ngram_embedding)={len(self.encoder_ngram_embedding)}, ngram_ids.shape={ngram_ids.shape}"
            for i in range(ngram_ids.shape[0]):
                ngram_embedding = self.encoder_ngram_embedding[i]
                ngram_embeds = ngram_embedding(ngram_ids[i])
                assert (
                    local_encoder_embeds.shape == ngram_embeds.shape
                ), f"Shape mismatch: {local_encoder_embeds.shape} vs {ngram_embeds.shape}, ngram_ids.shape={ngram_ids.shape}"
                local_encoder_embeds = local_encoder_embeds + ngram_embeds

        # Local encoder
        (h_encoder, h_cross), cache_encoder = self.local_encoder(
            tokens=local_encoder_tokens,
            embeds=local_encoder_embeds,
            patch_embeds=None,
            cross_mask=cross_attn_mask_enc,
            num_patches=patch_lengths.shape[1],
            patch_ids=patch_ids,
        )

        # Downsampling
        if not self.cross_attn_encoder:
            assert (
                patch_ids.shape[1] == h_encoder.shape[1]
            ), f"{patch_ids.shape[1]} != {h_encoder.shape[1]}"
            h = downsample(
                h_encoder,
                patch_lengths.shape[1],
                patch_lengths,
                patch_ids,
                downsampling_by_pooling=self.downsampling_by_pooling,
                patch_size=self.patch_size,
            )
        else:
            # Reshape h_cross
            h = h_cross.view(bs, patch_lengths.shape[1], -1)

        # Global transformer
        global_tokens = tokens.new(h.shape[0], h.shape[1]).fill_(self.boe_id)
        rows, cols = torch.where(local_encoder_tokens == self.eos_id)
        eos_patch_ids = patch_ids[rows, cols]
        global_tokens[rows, eos_patch_ids] = self.eos_id

        h, _ = self.global_transformer(
            embeds=h,
            tokens=global_tokens,
        )

        # Unpatching
        dec_embeds = h_encoder[:, nb_boe : nb_boe + N, :]

        # Generate decoder patch IDs
        decoder_patch_ids = decoder_patch_ids_from_lengths(
            patch_lengths, nb_boe, local_decoder_tokens.shape[-1]
        )
        assert (
            torch.max(decoder_patch_ids) + 1 <= h.shape[1]
        ), f"{torch.max(decoder_patch_ids) + 1} > {h.shape[1]}"
        assert (
            decoder_patch_ids.shape[1] == dec_embeds.shape[1]
        ), f"{decoder_patch_ids.shape[1]} != {dec_embeds.shape[1]}"

        # Cross-attention decoder
        if not self.cross_attn_decoder:
            h = torch.gather(
                h, 1, decoder_patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1])
            )
            cross_attn_mask_dec = None
            assert local_decoder_tokens.shape == h.shape[:-1]
        else:
            cross_attn_mask_dec = cross_attn_mask(
                decoder_patch_ids,
                patch_lengths,
                N,
                patches_as_queries=False,
                cross_attn_k=self.cross_attn_k,
                window=self.cross_attn_window_decoder,
                block_mask=self.cross_attn_use_flex_attention,
            )

        # Local decoder
        output, _ = self.local_decoder(
            embeds=dec_embeds,
            patch_embeds=h,
            tokens=local_decoder_tokens,
            cross_mask=cross_attn_mask_dec,
        )
        return output

def init_weights(self): | |
self.local_encoder.init_weights() | |
self.global_transformer.init_weights() | |
self.local_decoder.init_weights() | |
emb_std = self.local_encoder.dim ** (-0.5) | |
for emb in self.encoder_hash_tok_embedding: | |
nn.init.trunc_normal_( | |
emb.weight, | |
mean=0.0, | |
std=emb_std, | |
a=-3 * emb_std, | |
b=3 * emb_std, | |
) | |