# blt-entropy-patcher / demo_patcher.py
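"""Demo: entropy-based patching with a BLT entropy model.

Loads the entropy model shipped inside a consolidated BLT checkpoint, runs
realtime patching over a single prompt, and plots the per-byte entropies
together with the resulting patch boundaries.
"""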
import os

import typer

from bytelatent.args import TrainArgs
from bytelatent.data.file_util import get_fs
from bytelatent.generate_patcher import patcher_nocache
from bytelatent.plotting.entropy_figure_via_matplot_lib import plot_entropies
from bytelatent.tokenizers.blt_tokenizer import BltTokenizer


def main(prompt: str, model_name: str = "blt-1b"):
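    # Weights are expected under hf-weights/<model_name>, with the training
    # config serialized alongside them as params.json.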
    consolidated_path = os.path.join("hf-weights", model_name)
    train_args_path = os.path.join(consolidated_path, "params.json")
    fs = get_fs(train_args_path)
    train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))
    tokenizer = train_args.data.tokenizer_args.build()
    assert isinstance(tokenizer, BltTokenizer)
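    # Clone the training-time patcher config and switch it to realtime
    # patching, so patch boundaries are computed on the fly from
    # entropy-model scores rather than read from precomputed data.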
    patcher_args = train_args.data.patcher_args.model_copy(deep=True)
    patcher_args.realtime_patching = True
    # NOTE: CPU may be unsupported due to the reliance on xformers; this
    # demo forces "cpu" for local runs, so fall back to "cuda" if it fails.
    patcher_args.patching_device = "cpu"
    patcher_args.device = "cpu"
print("Loading entropy model and patcher")
patcher_args.entropy_model_checkpoint_dir = os.path.join(
consolidated_path, "entropy_model"
)
patcher = patcher_args.build()
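    # The result is a batch triple: per-prompt patch lengths, per-byte
    # entropy scores, and the corresponding token ids.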
    prompts = [prompt]
    results = patcher_nocache(prompts, tokenizer=tokenizer, patcher=patcher)
    if not results:
        raise RuntimeError("patcher_nocache returned no results")
    batch_patch_lengths, batch_scores, batch_tokens = results
    decoded_chars = [
        tokenizer.decode(row_tokens.tolist()) for row_tokens in batch_tokens
    ]
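    # Plot entropy scores for the first (and only) prompt, alongside the
    # patch boundaries implied by the patch lengths and the patcher's
    # entropy threshold.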
    plot_entropies(
        batch_patch_lengths[0],
        batch_scores[0],
        decoded_chars[0],
        threshold=patcher.threshold,
    )


if __name__ == "__main__":
    typer.run(main)
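# Example invocation (typer exposes `prompt` as a positional argument and
# `model_name` as --model-name; assumes weights under hf-weights/blt-1b):
#   python demo_patcher.py "Patching is done on bytes, not words." --model-name blt-1b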