File size: 16,314 Bytes
bcc039b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ffeb66
bcc039b
 
6ffeb66
bcc039b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ffeb66
bcc039b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7622d28
bcc039b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ffeb66
bcc039b
 
 
 
 
6ffeb66
bcc039b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ffeb66
 
bcc039b
6ffeb66
 
 
 
 
 
bcc039b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ffeb66
 
 
bcc039b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ffeb66
bcc039b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
# Copyright (c) Meta Platforms, Inc. and affiliates.
import os
from dataclasses import replace

import numpy as np
import pytest
import torch

from bytelatent.constants import BLT_DATA
from bytelatent.data.data_types import Batch
from bytelatent.data.ngram_processor import NgramProcessor
from bytelatent.model.blt import (
    ByteLatentTransformer,
    ByteLatentTransformerArgs,
    EmbeddingType,
    compute_hash_embeddings,
    create_global_transformer,
    create_local_decoder,
    create_local_encoder,
    cross_attn_mask,
    decoder_patch_ids_from_lengths,
    get_blt_input,
    init_embeddings,
    patch_ids_from_lengths,
)
from bytelatent.model.latent_transformer import CrossAttention
from bytelatent.model.utils import create_causal_mask
from bytelatent.optim import OptimArgs, build_optimizer
from bytelatent.tokenizers.constants import EOS_ID
from bytelatent.train import compute_loss


def batch_to_tensors_and_gpu(batch):
    """Turn the numpy arrays of a batch into torch tensors, moved to GPU if present.

    ``batch.x`` and ``batch.y`` are always converted. ``mask``,
    ``patch_lengths`` and ``ngram_ids`` are converted only when they are not
    ``None``; otherwise ``None`` is passed through unchanged.

    Returns:
        Tuple ``(x, y, mask, patch_lengths, ngram_ids)``.
    """
    inputs = torch.from_numpy(batch.x)
    targets = torch.from_numpy(batch.y)
    optional_tensors = [
        None if array is None else torch.from_numpy(array)
        for array in (batch.mask, batch.patch_lengths, batch.ngram_ids)
    ]

    if torch.cuda.is_available():
        inputs, targets = inputs.cuda(), targets.cuda()
        optional_tensors = [
            tensor if tensor is None else tensor.cuda()
            for tensor in optional_tensors
        ]

    mask, patch_lengths, ngram_ids = optional_tensors
    return inputs, targets, mask, patch_lengths, ngram_ids


def fake_batch():
    """Load the serialized test batch stored under ``BLT_DATA`` as a ``Batch``.

    The stored mapping carries extra entries (``x2``, ``y2``, ``src_names``)
    that ``Batch`` is not constructed with, so they are removed first.
    """
    fixture_path = os.path.join(BLT_DATA, "test_batch.pt")
    # weights_only=False allows arbitrary pickled objects; only point this at
    # trusted local fixtures.
    fields = torch.load(fixture_path, weights_only=False)
    for extra_key in ("x2", "y2", "src_names"):
        fields.pop(extra_key)
    return Batch(**fields)


def create_args(cross_attention=False):
    """Build the ``ByteLatentTransformerArgs`` shared by the tests in this module.

    Args:
        cross_attention: Value used for both ``cross_attn_encoder`` and
            ``cross_attn_decoder``.

    Returns:
        A ``ByteLatentTransformerArgs`` configured for the test fixtures.
    """
    model_dims = dict(
        n_heads=8,
        dim=512,
        vocab_size=260,
        dim_token=256,
        dim_local_encoder=256,
        dim_local_decoder=256,
    )
    patching = dict(
        patch_size=6,
        patching_mode="space",
        patch_in_forward=False,
        patching_threshold=3.1439168453216553,
        log_patch_lengths=True,
        downsampling_by_pooling="max",
    )
    sequence = dict(
        max_encoder_seq_length=12288,
        pad_to_max_length=True,
        max_length=256,
        max_seqlen=12288,
        local_attention_window_len=512,
    )
    hash_embeddings = dict(
        encoder_hash_byte_group_size=[4],
        encoder_hash_byte_group_vocab=50002,
        encoder_hash_byte_group_nb_functions=3,
    )
    cross_attn = dict(
        cross_attn_encoder=cross_attention,
        cross_attn_decoder=cross_attention,
        cross_attn_window_encoder=512,
        cross_attn_window_decoder=512,
        cross_attn_k=8,
        cross_attn_nheads=4,
        cross_attn_all_layers_decoder=True,
        cross_attn_all_layers_encoder=True,
        cross_attn_use_flex_attention=True,
        cross_attn_init_by_pooling=True,
    )
    architecture = dict(
        tie_local_encoder_decoder_logits=False,
        encoder_lm_loss=False,
        non_linearity="swiglu",
        use_rope=True,
        use_local_encoder_transformer=True,
        attn_bias_type="block_causal",
        attn_impl="xformers",
        eos_id=EOS_ID,
    )
    recompute_and_init = dict(
        recompute_fc1_out=False,
        recompute_fc3_out=False,
        recompute_attn=False,
        custom_bwd=False,
        layer_ckpt="none",
        init_use_gaussian=True,
        init_use_depth="current",
        alpha_depth="disabled",
    )
    # Splatting the groups as separate ** arguments makes Python reject any key
    # accidentally listed in two groups.
    return ByteLatentTransformerArgs(
        **model_dims,
        **patching,
        **sequence,
        **hash_embeddings,
        **cross_attn,
        **architecture,
        **recompute_and_init,
    )


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
class TestByteLatentTransformer:
    """GPU tests for the Byte Latent Transformer and its sub-modules.

    Tests build modules from ``create_args`` and run them on the fixture batch
    (``fake_batch``) or on tensors stored under ``BLT_DATA``, mostly checking
    output shapes. The whole class is skipped when CUDA is unavailable.
    """

    def test_local_encoder(self):
        """Local encoder without cross-attention.

        Checks the encoder input tokens against a stored reference, that no
        cross-attention output or cache is returned, and the output shape.
        """
        args = create_args()
        device = torch.device("cuda")
        local_encoder = create_local_encoder(args).to(device)

        batch = fake_batch()
        tokens, _, _, patch_lengths, _ = batch_to_tensors_and_gpu(batch)

        local_encoder_tokens, _, _ = get_blt_input(
            tokens=tokens,
            enforce_patch_size_multiple=False,
            nb_boe=0,
            patch_size=local_encoder.patch_size,
            boe_id=local_encoder.boe_id,
        )

        patch_ids = patch_ids_from_lengths(
            patch_lengths, local_encoder_tokens.shape[-1]
        )

        encoder_hash_tok_embedding = init_embeddings(
            args,
            EmbeddingType.HASH_TOK,
            local_encoder_dim=local_encoder.dim,
            encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
        ).to(device)

        local_encoder_embeds = compute_hash_embeddings(
            local_encoder_tokens=local_encoder_tokens,
            local_encoder=local_encoder,
            encoder_hash_tok_embedding=encoder_hash_tok_embedding,
            encoder_hash_byte_group_nb_functions=args.encoder_hash_byte_group_nb_functions,
            encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
            encoder_hash_byte_group_vocab=args.encoder_hash_byte_group_vocab,
        )

        reference_path = os.path.join(BLT_DATA, "local_encoder_tokens.pt")
        reference_tokens = torch.load(reference_path).to(device)
        torch.testing.assert_close(
            local_encoder_tokens,
            reference_tokens,
            msg="Generated tokens don't match reference tokens",
        )

        (h_encoder, h_cross), cache_encoder = local_encoder(
            tokens=local_encoder_tokens,
            embeds=local_encoder_embeds,
            patch_embeds=None,
            cross_mask=None,
            num_patches=patch_lengths.shape[1],
            patch_ids=patch_ids,
        )

        assert h_encoder is not None
        assert h_cross is None
        assert cache_encoder is None

        expected_shape = (
            local_encoder_tokens.shape[0],
            local_encoder_tokens.shape[1],
            local_encoder.dim,
        )
        assert h_encoder.shape == expected_shape

    def test_local_encoder_cross_attention(self):
        """Local encoder with patch-as-query cross-attention enabled.

        Checks that both the byte-level and the cross-attention outputs are
        produced, that no cache is returned, and their shapes.
        """
        args = create_args(cross_attention=True)
        device = torch.device("cuda")
        local_encoder = create_local_encoder(args).to(device)

        batch = fake_batch()
        tokens, _, _, patch_lengths, _ = batch_to_tensors_and_gpu(batch)

        local_encoder_tokens, _, _ = get_blt_input(
            tokens=tokens,
            enforce_patch_size_multiple=False,
            nb_boe=0,
            patch_size=local_encoder.patch_size,
            boe_id=local_encoder.boe_id,
        )

        patch_ids = patch_ids_from_lengths(
            patch_lengths, local_encoder_tokens.shape[-1]
        )

        encoder_hash_tok_embedding = init_embeddings(
            args,
            EmbeddingType.HASH_TOK,
            local_encoder_dim=local_encoder.dim,
            encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
        ).to(device)

        cross_attn_mask_enc = cross_attn_mask(
            patch_ids,
            patch_lengths,
            local_encoder_tokens.shape[-1],
            patches_as_queries=True,
            cross_attn_k=args.cross_attn_k,
            window=args.cross_attn_window_encoder,
            block_mask=True,
        )

        local_encoder_embeds = compute_hash_embeddings(
            local_encoder_tokens=local_encoder_tokens,
            local_encoder=local_encoder,
            encoder_hash_tok_embedding=encoder_hash_tok_embedding,
            encoder_hash_byte_group_nb_functions=args.encoder_hash_byte_group_nb_functions,
            encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
            encoder_hash_byte_group_vocab=args.encoder_hash_byte_group_vocab,
        )
        (h_encoder, h_cross), cache_encoder = local_encoder(
            tokens=local_encoder_tokens,
            embeds=local_encoder_embeds,
            patch_embeds=None,
            cross_mask=cross_attn_mask_enc,
            num_patches=patch_lengths.shape[1],
            patch_ids=patch_ids,
        )
        assert h_encoder is not None
        assert h_cross is not None
        assert cache_encoder is None
        expected_shape = (
            local_encoder_tokens.shape[0],
            local_encoder_tokens.shape[1],
            local_encoder.dim,
        )
        assert h_encoder.shape == expected_shape
        # NOTE(review): hard-coded for the fixture batch; presumably
        # num_patches * cross_attn_k (256 * 8) — confirm before generalizing.
        assert h_cross.shape == (2, 2048, local_encoder.dim)

    def test_local_decoder_cross_attention(self):
        """Local decoder with cross-attention over stored patch embeddings."""
        args = create_args(cross_attention=True)
        device = torch.device("cuda")
        local_decoder = create_local_decoder(args).to(device)

        test_files = {
            "dec_embeds": "dec_embeds.pt",
            "decoder_tokens": "local_decoder_tokens.pt",
            "patch_embeds": "decoder_patch_cross_embeds.pt",
        }
        batch = fake_batch()
        _, _, _, patch_lengths, _ = batch_to_tensors_and_gpu(batch)

        tensors = {
            name: torch.load(os.path.join(BLT_DATA, filename)).float().to(device)
            for name, filename in test_files.items()
        }
        decoder_patch_ids = decoder_patch_ids_from_lengths(
            patch_lengths, 0, tensors["decoder_tokens"].shape[-1]
        )
        cross_attn_mask_dec = cross_attn_mask(
            decoder_patch_ids,
            patch_lengths,
            tensors["decoder_tokens"].shape[-1],
            patches_as_queries=False,
            cross_attn_k=args.cross_attn_k,
            window=args.cross_attn_window_decoder,
            block_mask=True,
        )
        output, _ = local_decoder(
            embeds=tensors["dec_embeds"],
            patch_embeds=tensors["patch_embeds"],
            tokens=tensors["decoder_tokens"],
            cross_mask=cross_attn_mask_dec,
            cache=None,
        )
        assert output is not None
        assert output.shape == (2, tensors["decoder_tokens"].shape[1], args.vocab_size)

    def test_local_decoder(self):
        """Local decoder without cross-attention; checks logits shape and no cache."""
        args = create_args()
        device = torch.device("cuda")

        local_decoder = create_local_decoder(args).to(device)

        test_files = {
            "dec_embeds": "dec_embeds.pt",
            "decoder_tokens": "local_decoder_tokens.pt",
            "patch_embeds": "decoder_patch_embeds.pt",
        }

        tensors = {
            name: torch.load(os.path.join(BLT_DATA, filename)).float().to(device)
            for name, filename in test_files.items()
        }

        output, cache_decoder = local_decoder(
            embeds=tensors["dec_embeds"],
            patch_embeds=tensors["patch_embeds"],
            tokens=tensors["decoder_tokens"],
            cross_mask=None,
            cache=None,
        )
        assert output is not None
        expected_shape = (
            tensors["decoder_tokens"].shape[0],
            tensors["decoder_tokens"].shape[1],
            args.vocab_size,
        )
        assert output.shape == expected_shape
        assert cache_decoder is None

    def test_global_transformer(self):
        """Global (patch-level) transformer on stored embeddings and tokens."""
        args = create_args()
        device = torch.device("cuda")
        global_transformer = create_global_transformer(args).to(device)

        test_files = {
            "global_embeds": "global_embeds.pt",
            "global_tokens": "global_tokens.pt",
        }
        tensors = {
            name: torch.load(os.path.join(BLT_DATA, filename)).float().to(device)
            for name, filename in test_files.items()
        }
        h, cache = global_transformer(
            embeds=tensors["global_embeds"], tokens=tensors["global_tokens"]
        )
        # Previously a bare expression (`h is not None`) that checked nothing.
        assert h is not None
        assert h.shape == (2, 256, 512)
        assert cache is None

    def test_blt_transformer_init(self):
        """The full model can be constructed from the default test args."""
        args = create_args()
        model = ByteLatentTransformer(args)
        assert model is not None

    @pytest.mark.parametrize("attn_impl", ["sdpa", "xformers"])
    def test_blt_transformer_forward(self, attn_impl, monkeypatch):
        """Full forward pass for each attention implementation; checks logits shape."""
        # monkeypatch restores the variable when the test ends, so this
        # setting no longer leaks into tests that run later in the process.
        monkeypatch.setenv(
            "BLT_SUPPRESS_ATTN_ERROR", "1" if attn_impl == "sdpa" else "0"
        )

        args = create_args().model_copy(update=dict(attn_impl=attn_impl))
        model = ByteLatentTransformer(args)
        model = model.cuda()
        batch = fake_batch()
        x, _, _, patch_lengths, ngram_ids = batch_to_tensors_and_gpu(batch)

        output = model(
            tokens=x,
            patch_lengths=patch_lengths,
            ngram_ids=ngram_ids,
        )
        assert output is not None
        expected_shape = (
            x.shape[0],
            x.shape[1],
            args.vocab_size,
        )
        assert output.shape == expected_shape

    def test_blt_transformer_cross_attn_forward(self):
        """Full forward pass with encoder/decoder cross-attention enabled."""
        args = create_args(cross_attention=True)
        model = ByteLatentTransformer(args)
        model = model.cuda()
        batch = fake_batch()
        x, _, _, patch_lengths, ngram_ids = batch_to_tensors_and_gpu(batch)

        output = model(
            tokens=x,
            patch_lengths=patch_lengths,
            ngram_ids=ngram_ids,
        )
        assert output is not None
        expected_shape = (
            x.shape[0],
            x.shape[1],
            args.vocab_size,
        )
        assert output.shape == expected_shape

    def test_cross_attention_rand(self):
        """CrossAttention on random inputs with a flex-attention causal mask."""
        x = torch.randn(2, 256, 512, device="cuda")
        kv = torch.randn(2, 256, 512, device="cuda")
        cross_attention = CrossAttention(
            dim=512,
            head_dim=64,
            n_heads=8,
            n_kv_heads=4,
            norm_eps=1e-6,
        ).to("cuda")
        mask = create_causal_mask(
            x.shape[1], "flex_attention", None, sliding_window=None
        )
        output = cross_attention(x, kv, mask)
        assert output is not None
        assert output.shape == (2, 256, 512)

    def test_ngram_embeddings(self):
        """Forward pass with byte-ngram embeddings enabled in the encoder."""
        ngram_to_size = {
            2: 38396,
            3: 50000,
            4: 50000,
            5: 50000,
            6: 50000,
            7: 50000,
            8: 50000,
        }
        batch = fake_batch()
        ngram_processor = NgramProcessor(BLT_DATA, ngram_to_size)
        ngram_ids = ngram_processor.encode_token_ngrams(batch.x)
        ngram_ids = np.stack(ngram_ids, axis=0)
        batch = replace(batch, ngram_ids=ngram_ids)
        # Derive the "n:size,..." string from the dict so the two cannot drift
        # apart; dict order yields "2:38396,3:50000,...,8:50000".
        ngram_to_size_str = ",".join(
            f"{n}:{size}" for n, size in ngram_to_size.items()
        )
        args = create_args(cross_attention=True)
        args = args.model_copy(
            update=dict(
                encoder_ngram_to_size_str=ngram_to_size_str,
                encoder_enable_byte_ngrams=True,
                ngram_vocab_sizes=ngram_processor.ngram_vocab_sizes,
            )
        )
        model = ByteLatentTransformer(args)
        model = model.cuda()
        x, _, _, patch_lengths, ngram_ids = batch_to_tensors_and_gpu(batch)

        output = model(
            tokens=x,
            patch_lengths=patch_lengths,
            ngram_ids=ngram_ids,
        )
        assert output is not None
        expected_shape = (
            x.shape[0],
            x.shape[1],
            args.vocab_size,
        )
        assert output.shape == expected_shape

    def test_loss_backward(self):
        """A few optimizer steps on the fixture batch must lower the loss."""
        args = create_args()
        args = args.model_copy(update=dict(attn_impl="xformers"))
        batch = fake_batch()
        # Move the model to the GPU *before* building the optimizer, as the
        # torch.optim documentation recommends.
        model = ByteLatentTransformer(args).cuda()
        steps = 10
        optimizer, scheduler = build_optimizer(model, OptimArgs(lr=4e-04), steps)
        x, y, mask, patch_lengths, ngram_ids = batch_to_tensors_and_gpu(batch)

        losses = []
        for _ in range(steps):
            output = model(
                tokens=x,
                patch_lengths=patch_lengths,
                ngram_ids=ngram_ids,
            )
            loss, _ = compute_loss(output, y, mask, 1.0)
            losses.append(loss.item())
            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        initial_loss, final_loss = losses[0], losses[-1]
        assert (
            final_loss < initial_loss
        ), f"Training did not reduce loss: initial {initial_loss}, final {final_loss}"