NielsRogge committed
Improve HF integration (#98)

* Add mixin
* Update license

bytelatent/model/blt.py  +7 -3
bytelatent/model/blt.py  CHANGED

@@ -4,7 +4,7 @@ from enum import Enum, auto
 from typing import Any, Optional
 
 import torch
-from pydantic import
+from pydantic import model_validator
 from torch import nn
 from torch.nn.attention.flex_attention import create_block_mask
 from typing_extensions import Self
@@ -13,7 +13,6 @@ from bytelatent.base_transformer import (
     BaseTransformerArgs,
     InitStdFactor,
     SequenceModelWithOutput,
-    TransformerBlock,
 )
 from bytelatent.data.patcher import Patcher, PatcherArgs
 from bytelatent.model.latent_transformer import GlobalTransformer
@@ -21,6 +20,8 @@ from bytelatent.model.local_models import LocalDecoder, LocalEncoder, LocalModel
 from bytelatent.model.utils import downsample
 from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID
 
+from huggingface_hub import PyTorchModelHubMixin
+
 
 def attention_flops_per_token(n_layers, seq_len, dim, causal):
     # Formula from https://github.com/Dao-AILab/flash-attention/blob/main/benchmarks/benchmark_flash_attention.py#L27-L30
@@ -767,7 +768,10 @@ def compute_hash_embeddings(
     return local_encoder_embeds
 
 
-class ByteLatentTransformer(nn.Module, SequenceModelWithOutput):
+class ByteLatentTransformer(nn.Module, SequenceModelWithOutput, PyTorchModelHubMixin,
+    repo_url="https://github.com/facebookresearch/blt",
+    pipeline_tag="text-generation",
+    license="other"):
     """
     The ByteLatentTransformer (BLT) is a byte-level language model architecture that processes byte sequences
     by dynamically segmenting them into patches. It uses a combination of local encoders, global transformers,
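The practical effect of adding PyTorchModelHubMixin as a base class is that ByteLatentTransformer picks up the standard Hub methods (save_pretrained, from_pretrained, push_to_hub), while the class-level kwargs set in this commit (repo_url, pipeline_tag, license) become model-card metadata when the model is pushed. Below is a minimal sketch of the same pattern on a toy module; the toy class, its dimensions, and the save path are illustrative and not part of the BLT codebase, and the sketch assumes a huggingface_hub version recent enough to accept these class-level kwargs.

# Minimal sketch of the mixin pattern this commit applies to ByteLatentTransformer.
# TinyModel and "tiny-blt-demo" are illustrative stand-ins, not BLT code.
import torch
from torch import nn
from huggingface_hub import PyTorchModelHubMixin


class TinyModel(
    nn.Module,
    PyTorchModelHubMixin,
    # Mirrors the metadata the commit attaches to ByteLatentTransformer;
    # it is written into the model card when pushing to the Hub.
    repo_url="https://github.com/facebookresearch/blt",
    pipeline_tag="text-generation",
    license="other",
):
    def __init__(self, dim: int = 32, vocab_size: int = 260):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.proj = nn.Linear(dim, vocab_size)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.proj(self.embed(tokens))


model = TinyModel(dim=64)

# The mixin serializes the JSON-friendly __init__ kwargs to config.json and the
# weights to model.safetensors, so no hand-written checkpoint I/O is needed.
model.save_pretrained("tiny-blt-demo")
reloaded = TinyModel.from_pretrained("tiny-blt-demo")

# Pushing to the Hub works the same way (requires authentication):
# model.push_to_hub("your-username/tiny-blt-demo")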
|