Spaces:
Running
Running
import os | |
import typer | |
from bytelatent.data.file_util import get_fs | |
from bytelatent.distributed import DistributedArgs, setup_torch_distributed | |
from bytelatent.generate_patcher import patcher_nocache | |
from bytelatent.tokenizers.blt_tokenizer import BltTokenizer | |
from bytelatent.plotting.entropy_figure_via_matplot_lib import plot_entropies | |
def main(prompt: str, model_name: str = "blt-1b"): | |
from bytelatent.args import TrainArgs | |
consolidated_path = os.path.join("hf-weights", model_name) | |
train_args_path = os.path.join(consolidated_path, "params.json") | |
fs = get_fs(train_args_path) | |
train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path)) | |
tokenizer = train_args.data.tokenizer_args.build() | |
assert isinstance(tokenizer, BltTokenizer) | |
patcher_args = train_args.data.patcher_args.model_copy(deep=True) | |
patcher_args.realtime_patching = True | |
# NOTE: CPU currently unsupported due to reliance of xformers | |
patcher_args.patching_device = "cpu" | |
patcher_args.device = "cpu" | |
print("Loading entropy model and patcher") | |
patcher_args.entropy_model_checkpoint_dir = os.path.join( | |
consolidated_path, "entropy_model" | |
) | |
patcher = patcher_args.build() | |
prompts = [prompt] | |
results = patcher_nocache( | |
prompts, tokenizer=tokenizer, patcher=patcher | |
) | |
if not results: | |
raise Exception("Ruh roh") | |
batch_patch_lengths, batch_scores, batch_tokens = results | |
decoded_chars = [tokenizer.decode(row_tokens.tolist()) for row_tokens in batch_tokens] | |
plot_entropies( | |
batch_patch_lengths[0], | |
batch_scores[0], | |
decoded_chars[0], | |
threshold=patcher.threshold | |
) | |
if __name__ == "__main__": | |
typer.run(main) | |