# Copyright (c) Meta Platforms, Inc. and affiliates.

import os
import time

import torch
from omegaconf import OmegaConf
from torch import nn
from torch.nn import functional as F
from torch.nn.attention.flex_attention import create_block_mask
from tqdm import tqdm

from bytelatent.args import EvalArgs, PackedCausalTransformerGeneratorArgs, TrainArgs
from bytelatent.base_transformer import (
    Attention,
    causal_mask,
    generate_doc_mask_mod,
    lengths_to_local_ids,
    lengths_to_start_ids,
)
from bytelatent.checkpoint import (
    CONSOLIDATE_FOLDER,
    CONSOLIDATE_NAME,
    consolidate_checkpoints,
)
from bytelatent.config_parser import parse_args_to_pydantic_model
from bytelatent.data.file_util import get_fs
from bytelatent.distributed import (
    DistributedArgs,
    get_global_rank,
    setup_torch_distributed,
)
from bytelatent.model.blt import ByteLatentTransformer
from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
from bytelatent.transformer import LMTransformer

def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token

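# Illustrative sketch, not part of the original module: how sample_top_p above
# behaves on a toy distribution. The helper name and the numbers are made up.
def _demo_sample_top_p() -> torch.Tensor:
    probs = torch.tensor([[0.50, 0.30, 0.15, 0.05]])
    # With p=0.7 the nucleus is {0.50, 0.30}; index 2 is dropped because the
    # mass of strictly more likely tokens (0.80) already exceeds p, so the
    # sampled index is always 0 or 1.
    return sample_top_p(probs, p=0.7)
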
def sample_top_k(probs, k):
    topk_value, _ = torch.topk(probs, k)  # batch_sz x topk
    min_value_top_k = topk_value[:, [-1]]
    probs[probs < min_value_top_k] = 0.0
    probs.div_(probs.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs, num_samples=1)
    return next_token

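# Illustrative sketch, not part of the original module: sample_top_k keeps the
# k most likely tokens and renormalizes in place, which is why a clone is
# passed here.
def _demo_sample_top_k() -> torch.Tensor:
    probs = torch.tensor([[0.50, 0.30, 0.15, 0.05]])
    # With k=2 only indices 0 and 1 can be drawn (renormalized to 0.625/0.375).
    return sample_top_k(probs.clone(), k=2)
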
def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None):
    shape = logits.shape
    logits = logits.flatten(end_dim=-2)
    if temperature > 0.0:
        probs = torch.softmax(logits / temperature, dim=-1)
        if top_p is not None:
            next_token = sample_top_p(probs, top_p)
        elif top_k is not None:
            next_token = sample_top_k(probs, top_k)
        else:
            next_token = torch.multinomial(probs, num_samples=1)
    else:
        next_token = torch.argmax(logits, dim=-1)
    return next_token.view(shape[:-1])

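# Illustrative sketch, not part of the original module: sample_tokens flattens
# any leading dimensions, picks one token per position, then restores the
# leading shape, so it works on (batch, seq, vocab) prefill logits as well as
# single-step (batch, vocab) logits.
def _demo_sample_tokens():
    logits = torch.randn(2, 5, 32)  # (batch, seq, vocab)
    greedy = sample_tokens(logits)  # temperature=0.0 -> argmax over the vocab
    sampled = sample_tokens(logits, temperature=1.0, top_k=3)
    assert greedy.shape == sampled.shape == (2, 5)
    return greedy, sampled
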
def pack_prompts(prompts: list[list[int]]):
    res = []
    lengths = []
    for i, p in enumerate(prompts):
        p = torch.tensor(p, dtype=torch.long)
        l = p.size(0)
        res.append(p)
        lengths.append(l)
    lengths = torch.tensor(lengths, dtype=torch.long)
    res = torch.cat(res)
    return res, lengths

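# Illustrative sketch, not part of the original module: pack_prompts flattens a
# list of token-id lists into one 1D tensor plus the per-prompt lengths needed
# to split it back apart.
def _demo_pack_prompts():
    tokens, lengths = pack_prompts([[1, 2, 3], [4, 5]])
    # tokens  -> tensor([1, 2, 3, 4, 5])
    # lengths -> tensor([3, 2])
    return tokens, lengths
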
def batch_prompts(prompts, max_elements, lengths=None):
    batches = []
    current_batch = []
    current_count = 0
    for i in range(len(prompts)):
        prt = prompts[i]
        prompt_size = len(prt) if lengths is None else lengths[i]
        if current_count + prompt_size <= max_elements:
            current_batch.append(prt)
            current_count += prompt_size
        else:
            if current_batch:  # Add the current batch to batches
                batches.append(current_batch)
            # Start a new batch with the current prompt
            current_batch = [prt]
            current_count = prompt_size

    # Add the last batch if it contains any prompts
    if current_batch:
        batches.append(current_batch)

    return batches

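# Illustrative sketch, not part of the original module: batch_prompts greedily
# packs consecutive prompts into batches whose total size stays within
# max_elements (the generator passes padded lengths, i.e. prompt length plus
# max_gen_len, through the `lengths` argument).
def _demo_batch_prompts():
    prompts = [[1, 2, 3], [4, 5], [6, 7, 8, 9], [10]]
    # With max_elements=5 this yields [[[1, 2, 3], [4, 5]], [[6, 7, 8, 9], [10]]].
    return batch_prompts(prompts, max_elements=5)
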
class KVCache(nn.Module):
    def __init__(self, bsz, seqlen, n_heads, head_dim, dtype, device):
        super().__init__()
        shape = (bsz, seqlen, n_heads, head_dim)
        self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype, device=device))
        self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype, device=device))
        self.offset = 0

    def reset(self):
        self.k_cache.zero_()
        self.v_cache.zero_()
        self.offset = 0

    def update(self, k_val, v_val, tok_idx):
        # tok_idx: [S], k_val: [B, S, H, D]
        self.k_cache.index_copy_(1, self.offset + tok_idx, k_val)
        self.v_cache.index_copy_(1, self.offset + tok_idx, v_val)
        return self.k_cache, self.v_cache

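# Illustrative sketch, not part of the original module: cache entries are
# written in place at offset + tok_idx, which is how several packed sequences
# can share one buffer without overwriting each other.
def _demo_kv_cache():
    cache = KVCache(
        bsz=1, seqlen=8, n_heads=2, head_dim=4, dtype=torch.float32, device="cpu"
    )
    cache.offset = 2  # pretend this sequence's slot starts at cache position 2
    k = torch.ones(1, 3, 2, 4)
    v = torch.ones(1, 3, 2, 4)
    # Write three tokens at cache positions 2, 3 and 4.
    cache.update(k, v, tok_idx=torch.arange(3))
    return cache.k_cache, cache.v_cache
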
class PackedCausalTransformerGenerator:
    def __init__(
        self,
        cfg: PackedCausalTransformerGeneratorArgs,
        model: nn.Module,
        tokenizer: Tokenizer,
    ):
        """
        This class wraps a causal transformer model with its corresponding tokenizer
        and provides an efficient way to pack prompts together and do generation on
        the packed sequence.

        For example, given the prompts "Hello, I am a " and "Initiating calibration ",
        this class concatenates (packs) them into the single sequence
        "Hello, I am a Initiating calibration "
        and builds the attention masks needed so that each sequence only attends to
        itself during prefilling and generation.

        It allocates a fixed-size KV cache of max_tokens entries (or, if max_tokens
        is unset, the sum of the prompt sizes plus max_gen_len per sequence).
        """
        self.model = model
        self.tokenizer = tokenizer
        self.temperature = cfg.temperature
        self.top_p = cfg.top_p
        self.top_k = cfg.top_k

        self.max_gen_len = cfg.max_gen_len
        self.max_tokens = cfg.max_tokens
        self.max_prompt_len = cfg.max_prompt_len
        self.until = cfg.until
        self.max_until_size = max([len(e) for e in self.until]) if self.until else 1
        self.device = cfg.device

        # Compile if necessary
        self.prefill = torch.compile(self.prefill, disable=not cfg.compile_prefilling)
        self.generate_next_token = torch.compile(
            self.generate_next_token,
            mode="reduce-overhead",
            disable=not cfg.reduce_generation_overhead,
        )

        self.show_progress = cfg.show_progress
        self.dtype = dict(fp32=torch.float32, bf16=torch.bfloat16)[cfg.dtype]

        self.prefill_doc_id, self.prefill_tok_id = None, None
        self.padded_doc_id, self.padded_tok_id = None, None
        self.current_doc_id, self.current_tok_id = None, None
        self.padded_doc_start = None
        self.prefill_mask = None

    def clear_cache(self, offset):
        for module in self.model.modules():
            if isinstance(module, Attention):
                if not hasattr(module, "kv_cache"):
                    module.kv_cache = KVCache(
                        1,
                        self.max_tokens,
                        module.n_kv_heads,
                        module.head_dim,
                        self.dtype,
                        self.device,
                    )
                module.kv_cache.offset = offset

    def setup_prefilling(self, lengths: torch.Tensor):
        # The KV cache is a fixed size tensor of size max_tokens that we need
        # to update in order to do correct autoregressive generation.

        # Here we will generate token by token but on multiple sequences
        # at once. To do so, we need to have an attention mask that makes
        # each sequence independent.

        # Each sequence will write to its allocated space in the KV Cache.
        # We allocate len(seq) + max_gen_len to each sequence in the cache.

        # We will generate max_gen_len for each document
        padded_lengths = lengths + self.max_gen_len
        max_tokens = self.max_tokens or padded_lengths.sum().item()

        # The last document might have more padding to fill up to max_tokens
        padded_lengths[-1] += max_tokens - padded_lengths.sum()

        # This is the start index in the cache for each document
        self.padded_doc_start = lengths_to_start_ids(padded_lengths)
        # For example with ab--123--cdef--
        # this would be 0, 4, 9 if max_gen_len is 2

        # We repeat_interleave to get one offset per prompt token, so that
        # prefilling writes each token to the right slot in the cache.
        # Ex: prompt tokens ab123cdef
        #     offsets       004449999
        prefill_offset = torch.repeat_interleave(self.padded_doc_start, lengths)
        # This offset will make sure the tokens are written to the
        # correct positions in the cache during prefilling

        # We either init the cache or clear it by resetting the offset to prefill_offset
        self.clear_cache(prefill_offset)

        # The prefilling mask looks like the following for
        # the two packed sequences ab and 123 : ab123
        # Where dashes are empty cache positions
        #                  keys
        #                  ab---123---
        #   queries    a   10000000000
        #              b   11000000000
        #              1   00000100000
        #              2   00000110000
        #              3   00000111000
        # We make sure to skip the empty cache positions
        # and only attend to positions within the same sequence
        doc_mask_mod = generate_doc_mask_mod(causal_mask, lengths, padded_lengths)
        self.prefill_mask = create_block_mask(
            doc_mask_mod, 1, None, lengths.sum(), max_tokens
        )

        # This creates the prefilling token ids which look like
        # the following for the packed sequence abcdefg1234
        #   abcdefg1234
        #   01234560123
        # The token id gives us the position within each sequence
        # This is used to compute ROPE and to update the cache
        # At each forward pass the current tokens are written to
        # offset + tok_id
        self.prefill_doc_id, self.prefill_tok_id = lengths_to_local_ids(lengths)

        # This creates the padded token and document ids
        # which look like the following for the packed sequence ab123
        #                 ab---123---                 ab---123---
        #  padded_doc_id  00000111111  padded_tok_id  01234012345
        # This will later be useful for the attention mask at generation
        self.padded_doc_id, self.padded_tok_id = lengths_to_local_ids(padded_lengths)

    def setup_generation(self, lengths):
        # The KV cache offset is set to the start of each padded document
        for module in self.model.modules():
            if isinstance(module, Attention):
                module.kv_cache.offset = self.padded_doc_start
        # The token ids during generation correspond to the lengths of each doc
        # current_tok_id will be incremented during generation
        self.current_tok_id = lengths.clone()
        # Since we're generating one token per document
        # the document id is just an arange
        self.current_doc_id = torch.arange(lengths.size(0), device=lengths.device)

    # From here on, methods used during generation
    def prefill(self, tokens: torch.Tensor, lengths: torch.Tensor):
        # Prefilling is done by taking multiple packed sequences and
        # doing block diagonal attention on them so they remain independent
        self.setup_prefilling(lengths=lengths)
        prefill_out = self.model.forward(
            tokens,
            tok_idx=self.prefill_tok_id,
            mask=self.prefill_mask,
            attn_impl="flex_attention",
        )
        self.setup_generation(lengths=lengths)
        return prefill_out

    def generate_next_token(self, current_token):
        # Since we're doing generation with multiple sequences at once
        # we need to ignore tokens and cache entries from other sequences
        # or in the future.
        # Example mask:
        #                  keys
        #                  abc--1234--
        #   queries    c   11100000000
        #              4   00000111100
        # mask shape: (n_seqs, cache_size)
        doc_mask = self.current_doc_id.unsqueeze(1) == self.padded_doc_id.unsqueeze(0)
        caus_mask = self.current_tok_id.unsqueeze(1) >= self.padded_tok_id.unsqueeze(0)
        mask = doc_mask & caus_mask
        out = self.model.forward(
            current_token,
            tok_idx=self.current_tok_id,  # n_seqs
            mask=mask,
            attn_impl="sdpa",
        )
        self.current_tok_id += 1
        return out

    def generate(self, prompts):
        # Tokenize
        prompts = [
            self.tokenizer.encode(p, add_bos=True, add_eos=False) for p in prompts
        ]
        # Truncate
        max_seqlen = (
            self.max_tokens
            if not hasattr(self.model, "max_seqlen")
            else self.model.max_seqlen
        )
        max_prompt_len = self.max_prompt_len or min(
            max_seqlen - self.max_gen_len, self.max_tokens - self.max_gen_len
        )
        prompts = [p[-max_prompt_len:] for p in prompts]
        # Account for the generation in lengths
        padded_lengths = [len(p) + self.max_gen_len for p in prompts]
        generation = []
        loglikelihood = []
        greedy = []
        it = batch_prompts(prompts, self.max_tokens, lengths=padded_lengths)
        if self.show_progress:
            it = tqdm(it)
        for batch in it:
            n_seqs = len(batch)
            generated_tokens = [[] for _ in range(n_seqs)]
            is_done = [False for _ in range(n_seqs)]
            packed_batch, lengths = pack_prompts(batch)
            packed_batch, lengths = packed_batch.cuda(), lengths.cuda()
            n_seqs = lengths.size(0)

            # Prefilling cache
            prompt_logits = self.prefill(packed_batch.unsqueeze(0), lengths)

            # Selecting last token in each prompt
            all_tokens = sample_tokens(
                prompt_logits, self.temperature, self.top_p, self.top_k
            )
            start_token = all_tokens[:, lengths.cumsum(0) - 1]

            for seq_id, tok in enumerate(start_token.squeeze(0).tolist()):
                generated_tokens[seq_id].append(tok)

            current_token = start_token
            for i in range(1, self.max_gen_len):
                next_logits = self.generate_next_token(current_token)
                next_token = sample_tokens(
                    next_logits.clone(), self.temperature, self.top_p, self.top_k
                )

                for seq_id, tok in enumerate(next_token.squeeze(0).tolist()):
                    if not is_done[seq_id]:
                        generated_tokens[seq_id].append(tok)
                        current_end_str = self.tokenizer.decode(
                            generated_tokens[seq_id][-self.max_until_size :]
                        )
                        contains_end_string = any(
                            [e in current_end_str for e in self.until]
                        )
                        is_done[seq_id] = (
                            contains_end_string or tok == self.tokenizer.eos_id
                        )
                if all(is_done):
                    break

                current_token = next_token

            generation.extend([self.tokenizer.decode(g) for g in generated_tokens])

            for p, logit in zip(
                batch, prompt_logits.squeeze(0).split(lengths.tolist())
            ):
                x = logit[:-1]
                y = torch.tensor(p[1:], device=x.device)
                loglikelihood.append(-F.cross_entropy(x, y, reduction="none").cpu())
                greedy.append((x.argmax(dim=-1) == y).cpu())

        return generation, loglikelihood, greedy

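# Illustrative sketch, not part of the original module: the per-step mask that
# generate_next_token builds from the document and token ids. Each query row
# may only read cache entries that belong to its own document and that sit at
# or before its current position. The numbers below are made up; the real ids
# come from lengths_to_local_ids on the padded lengths.
def _demo_generation_mask() -> torch.Tensor:
    # Two packed documents occupying cache slots 0..4 and 5..10.
    padded_doc_id = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1])
    padded_tok_id = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 5])
    current_doc_id = torch.tensor([0, 1])  # one query per document
    current_tok_id = torch.tensor([2, 3])  # position currently being generated
    doc_mask = current_doc_id.unsqueeze(1) == padded_doc_id.unsqueeze(0)
    caus_mask = current_tok_id.unsqueeze(1) >= padded_tok_id.unsqueeze(0)
    return doc_mask & caus_mask  # shape (n_seqs, cache_size)
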
def load_consolidated_model_and_tokenizer(consolidated_path, init_distributed=False):
    if init_distributed:
        distributed_args = DistributedArgs()
        distributed_args.configure_world()
        if not torch.distributed.is_initialized():
            setup_torch_distributed(distributed_args)
    train_args_path = os.path.join(consolidated_path, "params.json")
    fs = get_fs(train_args_path)
    train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))

    if train_args.train_entropy_model:
        model_args = train_args.entropy_model
        model = LMTransformer(model_args)
    else:
        model_args = train_args.model
        model = ByteLatentTransformer(model_args)

    param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[
        train_args.distributed.model_dtype
    ]
    tokenizer = train_args.data.tokenizer_args.build()
    with fs.open(os.path.join(consolidated_path, CONSOLIDATE_NAME)) as f:
        st_dict = torch.load(f, weights_only=True)
    model.load_state_dict(st_dict["model"])
    model = model.cuda().eval()
    for param in model.parameters():
        param.data = param.data.to(dtype=param_dtype)
    return model, tokenizer, train_args

def main():
    # Load CLI arguments (overrides) and combine with a YAML config
    eval_args = parse_args_to_pydantic_model(EvalArgs)

    fs = get_fs(eval_args.ckpt_dir, s3_profile=eval_args.s3_profile)
    if (
        fs.exists(eval_args.ckpt_dir)
        and fs.exists(os.path.join(eval_args.ckpt_dir, "params.json"))
        and len(fs.glob(os.path.join(eval_args.ckpt_dir, "*.pth"))) != 0
    ):
        consolidate_path = eval_args.ckpt_dir
    else:
        consolidate_path = os.path.join(eval_args.ckpt_dir, CONSOLIDATE_FOLDER)
        if not fs.exists(consolidate_path) and get_global_rank() == 0:
            consolidate_path = consolidate_checkpoints(fs, eval_args.ckpt_dir)

    model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
        consolidate_path
    )

    generator = PackedCausalTransformerGenerator(eval_args.generator, model, tokenizer)

    # Allow multiple prompts
    prompts = []
    while True:
        prompt = input("Enter a prompt (or press enter to finish): ")
        if not prompt:
            break
        prompts.append(prompt)

    # Start generation
    start_time = time.time()
    generation, loglikelihood, greedy = generator.generate(prompts)
    end_time = time.time()

    # Calculate tokens per second
    total_tokens = sum(len(tokenizer.encode(gen, False, False)) for gen in generation)
    tokens_per_second = total_tokens / (end_time - start_time)

    # Display the results
    for i, gen in enumerate(generation):
        print(f"\nPrompt {i+1}: {prompts[i]}")
        print(f"Generated Text: {gen}")

    print(f"\nTokens per second: {tokens_per_second:.2f}")


if __name__ == "__main__":
    main()