# Copyright (c) Meta Platforms, Inc. and affiliates.
import json
import logging
import os

import torch

from bytelatent.transformer import LMTransformer, LMTransformerArgs

logger = logging.getLogger()


def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cpu"):
    with open(os.path.join(entropy_model_checkpoint_dir, "params.json")) as fr:
        reloaded = json.loads(fr.read())

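    # Build the entropy model in bfloat16 using the architecture stored in
    # params.json. Attention and sliding-window settings are currently hardcoded
    # below rather than read from the checkpoint (hence the warning).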
    torch.set_default_dtype(torch.bfloat16)
    model_params = reloaded["entropy_model"]
    logger.warning(
        "Update checkpoint to load attn and sliding window args from checkpoint"
    )
    entropy_model = LMTransformer(
        LMTransformerArgs(
            dim=model_params["dim"],
            n_layers=model_params["n_layers"],
            n_heads=model_params["n_heads"],
            max_seqlen=model_params["max_seqlen"],
            ffn_dim_multiplier=model_params["ffn_dim_multiplier"],
            vocab_size=model_params["vocab_size"],
            attn_bias_type="local_block_causal" if torch.cuda.is_available() else "causal",
            attn_impl="xformers" if torch.cuda.is_available() else "sdpa",
            sliding_window=512,
        )
    )

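    # Load the saved weights; the checkpoint stores them under the "model" key.
    # strict=False tolerates keys that do not match this configuration.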
    entropy_model.load_state_dict(
        torch.load(state_dict_path, map_location=device)["model"], strict=False
    )
    entropy_model.to(device)
    entropy_model = entropy_model.eval()
    # Disable gradients: the entropy model is used for inference only.
    for param in entropy_model.parameters():
        param.requires_grad = False
    return entropy_model
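

# Example usage (a minimal sketch; the paths below are hypothetical placeholders
# and depend on where the entropy-model checkpoint was downloaded or trained):
#
#   entropy_model = load_entropy_model(
#       entropy_model_checkpoint_dir="checkpoints/entropy_model",
#       state_dict_path="checkpoints/entropy_model/consolidated.pth",
#       device="cuda" if torch.cuda.is_available() else "cpu",
#   )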