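"""Demo: entropy-based patching of a single prompt with a BLT model.

Reads the training config (params.json) stored alongside the downloaded
weights, rebuilds the tokenizer and entropy-model patcher it describes,
patches one prompt, and plots the per-byte entropy scores together with
the resulting patch boundaries.
"""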
import os

import typer

from bytelatent.data.file_util import get_fs
from bytelatent.generate_patcher import patcher_nocache
from bytelatent.plotting.entropy_figure_via_matplot_lib import plot_entropies
from bytelatent.tokenizers.blt_tokenizer import BltTokenizer


def main(prompt: str, model_name: str = "blt-1b"):
    from bytelatent.args import TrainArgs
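
    # Read the training config (params.json) that ships with the weights.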
    consolidated_path = os.path.join("hf-weights", model_name)
    train_args_path = os.path.join(consolidated_path, "params.json")
    fs = get_fs(train_args_path)
    train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))

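    # Rebuild the byte-level tokenizer exactly as configured at training time.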
    tokenizer = train_args.data.tokenizer_args.build()
    assert isinstance(tokenizer, BltTokenizer)
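    # Copy the patcher config so local tweaks don't mutate train_args.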
    patcher_args = train_args.data.patcher_args.model_copy(deep=True)
    patcher_args.realtime_patching = True
    # NOTE: CPU is currently unsupported due to the reliance on xformers.
    patcher_args.patching_device = "cuda"
    patcher_args.device = "cuda"
    print("Loading entropy model and patcher")
    patcher_args.entropy_model_checkpoint_dir = os.path.join(
        consolidated_path, "entropy_model"
    )
    patcher = patcher_args.build()
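    # patcher_nocache expects a batch of prompts, so wrap the single prompt in a list.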
    prompts = [prompt]
    results = patcher_nocache(
        prompts, tokenizer=tokenizer, patcher=patcher
    )
    if not results:
        raise ValueError("Patching produced no results for the given prompt")
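    # results holds per-prompt patch lengths, entropy scores, and token ids.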
    batch_patch_lengths, batch_scores, batch_tokens = results
    decoded_chars = [
        tokenizer.decode(row_tokens.tolist()) for row_tokens in batch_tokens
    ]
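    # Plot entropies and patch boundaries for the first (and only) prompt in the batch.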
    plot_entropies(
        batch_patch_lengths[0],
        batch_scores[0],
        decoded_chars[0],
        threshold=patcher.threshold
    )


if __name__ == "__main__":
    typer.run(main)
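
# Example invocation (the script filename is an assumption; weights are
# expected under hf-weights/<model_name>, as constructed in main()):
#   python demo_patcher.py "The quick brown fox jumps over the lazy dog"
# typer exposes model_name as a CLI option:
#   python demo_patcher.py "Some prompt" --model-name blt-1b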