NielsRogge committed on
Commit 1b67cbe · unverified · 1 Parent(s): 96d51b5

Improve HF integration (#98)


* Add mixin

* Update license

Files changed (1)
  1. bytelatent/model/blt.py +7 -3
bytelatent/model/blt.py CHANGED
@@ -4,7 +4,7 @@ from enum import Enum, auto
 from typing import Any, Optional
 
 import torch
-from pydantic import ConfigDict, model_validator
+from pydantic import model_validator
 from torch import nn
 from torch.nn.attention.flex_attention import create_block_mask
 from typing_extensions import Self
@@ -13,7 +13,6 @@ from bytelatent.base_transformer import (
     BaseTransformerArgs,
     InitStdFactor,
     SequenceModelWithOutput,
-    TransformerBlock,
 )
 from bytelatent.data.patcher import Patcher, PatcherArgs
 from bytelatent.model.latent_transformer import GlobalTransformer
@@ -21,6 +20,8 @@ from bytelatent.model.local_models import LocalDecoder, LocalEncoder, LocalModel
 from bytelatent.model.utils import downsample
 from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID
 
+from huggingface_hub import PyTorchModelHubMixin
+
 
 def attention_flops_per_token(n_layers, seq_len, dim, causal):
     # Formula from https://github.com/Dao-AILab/flash-attention/blob/main/benchmarks/benchmark_flash_attention.py#L27-L30
@@ -767,7 +768,10 @@ def compute_hash_embeddings(
     return local_encoder_embeds
 
 
-class ByteLatentTransformer(nn.Module, SequenceModelWithOutput):
+class ByteLatentTransformer(nn.Module, SequenceModelWithOutput, PyTorchModelHubMixin,
+    repo_url="https://github.com/facebookresearch/blt",
+    pipeline_tag="text-generation",
+    license="other"):
     """
     The ByteLatentTransformer (BLT) is a byte-level language model architecture that processes byte sequences
     by dynamically segmenting them into patches. It uses a combination of local encoders, global transformers,
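
For context, the PyTorchModelHubMixin added here is what gives the model save_pretrained, push_to_hub, and from_pretrained support on the Hugging Face Hub. The sketch below shows the general pattern this commit applies to ByteLatentTransformer; ToyModel and its fields are illustrative stand-ins, not part of the bytelatent codebase, and the class keyword arguments simply mirror the ones in the diff (they feed the auto-generated model card metadata).

import torch
from torch import nn
from huggingface_hub import PyTorchModelHubMixin


class ToyModel(
    nn.Module,
    PyTorchModelHubMixin,
    repo_url="https://github.com/facebookresearch/blt",  # model-card metadata, as in the diff
    pipeline_tag="text-generation",
    license="other",
):
    def __init__(self, dim: int = 16, vocab_size: int = 256):
        super().__init__()
        # __init__ kwargs should be JSON-serializable so the mixin can record them in config.json
        self.embed = nn.Embedding(vocab_size, dim)
        self.proj = nn.Linear(dim, vocab_size)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.proj(self.embed(tokens))


model = ToyModel()
model.save_pretrained("toy-checkpoint")       # writes model.safetensors + config.json locally
# model.push_to_hub("username/toy-model")     # uploading requires an authenticated HF account
reloaded = ToyModel.from_pretrained("toy-checkpoint")

The same three methods become available on ByteLatentTransformer after this change; how its ByteLatentTransformerArgs configuration is serialized depends on the rest of the codebase and is not shown here.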