# Copyright (c) Meta Platforms, Inc. and affiliates.
import json
import logging
import math
import os
from datetime import datetime

import torch
from lm_eval import simple_evaluate
from lm_eval.api.instance import Instance
from lm_eval.api.model import LM
from torch.nn import functional as F

from bytelatent.args import EvalArgs
from bytelatent.checkpoint import consolidate_checkpoints
from bytelatent.config_parser import parse_args_to_pydantic_model
from bytelatent.data.file_util import get_fs
from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
from bytelatent.data.iterators.limit_iterator import LimitIterator
from bytelatent.data.iterators.packing_iterator import (
    PackingArgs,
    PackingIterator,
    PackingMode,
)
from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator
from bytelatent.data.iterators.sequence_iterator import (
    SequenceIterator,
    SequencePackingArgs,
)
from bytelatent.data.patcher import PatcherArgs, PatchingModeEnum
from bytelatent.distributed import (
    DistributedArgs,
    dist_sum,
    get_device_mesh,
    get_global_rank,
    get_world_size,
    setup_torch_distributed,
    to_py_num,
)
from bytelatent.generate import (
    PackedCausalTransformerGenerator,
    load_consolidated_model_and_tokenizer,
)
from bytelatent.model.blt import ByteLatentTransformer
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
from bytelatent.transformer import LMTransformer

EVAL_FOLDER_NAME = "{:010d}"

logger = logging.getLogger()


def all_dicts_same(dict_list):
    if not dict_list:  # An empty list counts as "all same"
        return True
    # Compare each dictionary to the first one
    first_dict = dict_list[0]
    return all(d == first_dict for d in dict_list)
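

# Some lm-eval code paths use an accelerate-style `accelerator` attribute on
# the LM for cross-rank synchronization; this minimal stand-in maps gather()
# and wait_for_everyone() directly onto torch.distributed collectives.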
class MockAccelerator:
    def gather(self, tensor):
        gathered = [torch.zeros_like(tensor) for _ in range(get_world_size())]
        torch.distributed.all_gather(gathered, tensor)
        return torch.stack(gathered)

    def wait_for_everyone(self):
        torch.distributed.barrier()


# Light wrapper around generator for lm-eval harness
class EvalHarnessLM(LM):
    def __init__(self, generator):
        super().__init__()
        self.generator = generator
        self.accelerator = MockAccelerator()
        self._rank = get_global_rank()
        self._world_size = get_world_size()
        self.device = generator.device

    def generate_until(self, requests: list[Instance]) -> list[str]:
        prompts, gen_args = zip(*[req.args for req in requests])
        assert all_dicts_same(gen_args), "Doesn't support different gen args for now"
        gen_args = gen_args[0]
        temperature = gen_args.get("temperature", 0.0)
        top_p = gen_args.get("top_p", None)
        top_k = gen_args.get("top_k", None)
        until = gen_args.get("until", [])
        self.generator.temperature = temperature
        self.generator.top_p = top_p
        self.generator.top_k = top_k
        self.generator.until = until
        generations, _, _ = self.generator.generate(prompts)
        # Strip the stop sequences from the raw generations
        filtered_gen = []
        for g in generations:
            for e in until:
                g = g.replace(e, "")
            filtered_gen.append(g)
        return filtered_gen

    def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
        prompts, continuations = zip(*[req.args for req in requests])
        inputs = [p + c for p, c in zip(prompts, continuations)]
        max_gen_len = self.generator.max_gen_len
        # We temporarily lower max gen len
        self.generator.max_gen_len = 1
        _, lls, greedy = self.generator.generate(inputs)
        results = []
        for p, ll, gr in zip(prompts, lls, greedy):
            # Score only the continuation: drop the prompt's tokens from the
            # per-token log-likelihoods and greedy-match flags.
            p_len = len(
                self.generator.tokenizer.encode(p, add_bos=False, add_eos=False)
            )
            results.append((ll[p_len:].sum().item(), gr[p_len:].all().item()))
        self.generator.max_gen_len = max_gen_len
        return results

    def loglikelihood_rolling(self, requests: list[Instance]) -> list[float]:
        prompts = [req.args[0] for req in requests]
        max_gen_len = self.generator.max_gen_len
        # We temporarily lower max gen len
        self.generator.max_gen_len = 1
        _, lls, _ = self.generator.generate(prompts)
        # One summed log-likelihood per request, matching the declared return type
        results = [ll.sum().item() for ll in lls]
        self.generator.max_gen_len = max_gen_len
        return results
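

# A minimal usage sketch for the wrapper above (ckpt_path, gen_cfg, and the
# task list are hypothetical placeholders; in launch_eval below the generator
# config comes from EvalArgs.generator):
#
#   model, tokenizer, _ = load_consolidated_model_and_tokenizer(ckpt_path)
#   generator = PackedCausalTransformerGenerator(gen_cfg, model, tokenizer)
#   results = simple_evaluate(EvalHarnessLM(generator), tasks=["hellaswag"])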


def eval_ppl_on_path(
    *,
    world_rank: int,
    world_size: int,
    model: LMTransformer | ByteLatentTransformer,
    tokenizer_args: TokenizerArgs,
    patcher_args: PatcherArgs,
    packing_args: PackingArgs,
    add_patches: bool,
    path: str,
    arrow_batch_size: int,
    max_n_docs: int | None,
    max_n_batches: int | None,
    s3_profile: str | None = None,
):
    """Compute byte-level perplexity metrics for `model` on the dataset at `path`,
    sharding documents across `world_size` ranks and reducing the sums globally."""
    model.eval()
    seq_len = model.get_output_seq_len()
    arrow_iterator = ArrowFileIterator(
        file_path=None,
        dataset_files=[path],
        entropy_model_name=None,
        worker_id=world_rank,
        num_workers=world_size,
        arrow_batch_size=arrow_batch_size,
        preprocess_dir=None,
        s3_profile=s3_profile,
        file_format="arrow" if path.endswith("arrow") else "json",
    )
    if max_n_docs is not None:
        arrow_iterator = LimitIterator(arrow_iterator, limit=max_n_docs)
    preprocess_iterator = PreprocessIterator(
        arrow_iterator,
        patcher_args=patcher_args,
        tokenizer_args=tokenizer_args,
        add_patches=add_patches,
    )
    sequence_iterator = SequenceIterator(
        preprocess_iterator,
        sequence_packing_args=SequencePackingArgs(
            output_seq_len=seq_len,
            # Effectively disables shuffling
            buffer_size=1,
        ),
        rng_state=None,
    )
    packing_iterator = PackingIterator(sequence_iterator, packing_args=packing_args)

    total_loss = 0.0
    n_bytes = 0
    batch_iterator = packing_iterator.create_iter()
    for i, batch in enumerate(batch_iterator):
        # max_n_batches=None never equals i, so None means "no limit"
        if i == max_n_batches:
            break
        x = torch.from_numpy(batch.x).cuda()
        y = torch.from_numpy(batch.y).cuda()
        mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda()
        patch_lengths = batch.patch_lengths
        if patch_lengths is not None:
            patch_lengths = torch.from_numpy(patch_lengths).cuda()
        if tokenizer_args.name in ["bytes", "blt"]:
            n_bytes += y.numel() if mask is None else mask.sum().item()
            if isinstance(model, ByteLatentTransformer):
                pred = model(x, patch_lengths=patch_lengths)
            else:
                pred = model(x)
            loss = F.cross_entropy(
                pred.flatten(0, 1), y.flatten(0, 1), reduction="sum", ignore_index=0
            )
            total_loss += loss.item()
        else:
            raise NotImplementedError()

    # Reduce the per-rank sums across all ranks before computing the metrics
    all_n_bytes = to_py_num(dist_sum(n_bytes))
    all_total_loss = to_py_num(dist_sum(total_loss))
    return {
        "n_bytes": all_n_bytes,
        "n_bytes_gpu": n_bytes,
        "loss_sum": all_total_loss,
        "loss_sum_gpu": total_loss,
        "loss_mean": all_total_loss / all_n_bytes if all_n_bytes > 0 else 0.0,
        "loss_mean_gpu": total_loss / n_bytes if n_bytes > 0 else 0.0,
        "ppl": math.exp(all_total_loss / all_n_bytes) if all_n_bytes > 0 else 0.0,
        "bpb": all_total_loss / math.log(2) / all_n_bytes if all_n_bytes > 0 else 0.0,
    }
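
# Note on the metrics above: the loss is summed in nats, so bits-per-byte is
# bpb = loss_sum / (ln(2) * n_bytes). For example, a summed loss of 693.1 nats
# over 1000 bytes gives loss_mean ≈ 0.6931, ppl = exp(0.6931) ≈ 2.0, and
# bpb ≈ 1.0.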


def launch_eval(eval_args: EvalArgs):
    assert eval_args.dump_dir is not None
    assert eval_args.ckpt_dir is not None
    distributed_args = DistributedArgs()
    distributed_args.configure_world()
    if not torch.distributed.is_initialized():
        setup_torch_distributed(distributed_args)
    world_mesh = get_device_mesh(distributed_args)
    dp_mesh = world_mesh["dp_replicate"]
    assert distributed_args.dp_shard == 1
    world_size = dp_mesh.size()
    world_rank = dp_mesh.get_local_rank()

    fs = get_fs(eval_args.ckpt_dir, s3_profile=eval_args.s3_profile)
    if (
        fs.exists(eval_args.ckpt_dir)
        and fs.exists(os.path.join(eval_args.ckpt_dir, "params.json"))
        and len(fs.glob(os.path.join(eval_args.ckpt_dir, "*.pth"))) != 0
    ):
        consolidate_path = eval_args.ckpt_dir
    else:
        if eval_args.consolidate_if_needed:
            logger.info(
                "Found a model checkpoint, but it has not been consolidated; consolidating it now"
            )
            consolidate_path = os.path.join(
                eval_args.ckpt_dir, eval_args.consolidate_folder
            )
            if not fs.exists(consolidate_path) and get_global_rank() == 0:
                consolidate_path = consolidate_checkpoints(fs, eval_args.ckpt_dir)
            logger.info("Model consolidated to: %s", consolidate_path)
        else:
            raise ValueError(
                "Did not find a consolidated checkpoint and consolidate_if_needed is False"
            )

    fs.mkdirs(eval_args.dump_dir, exist_ok=True)
    with fs.open(os.path.join(eval_args.dump_dir, "config.yaml"), "w") as f:
        f.write(eval_args.model_dump_json())
    torch.distributed.barrier()
    logger.info("Loading model")
    model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
        consolidate_path,
    )
    pad_id = 0 if train_cfg.data.tokenizer_args.name == "bytes" else tokenizer.boe_id
    model.eval()
    logger.info("Model loaded")

    ppl_results = None
    if eval_args.run_ppl:
        assert eval_args.validation is not None
        packing_args = PackingArgs(
            batch_size=eval_args.validation.batch_size,
            seq_len=train_cfg.data.seq_len,
            max_length=train_cfg.data.max_encoder_seq_length,
            pad_to_max_length=True,
            enable_byte_ngrams=False,
            pad_id=pad_id,
            packing_mode=(
                PackingMode.BYTES
                if train_cfg.data.patcher_args.patching_mode == PatchingModeEnum.byte
                else PackingMode.PATCHING
            ),
        )
        if len(eval_args.validation.sources) > 0:
            ppl_results = {}
            logger.info("Starting PPL evaluation on validation sets")
            for source in eval_args.validation.sources:
                ppl_results[source] = eval_ppl_on_path(
                    world_rank=world_rank,
                    world_size=world_size,
                    model=model,
                    tokenizer_args=train_cfg.data.tokenizer_args,
                    patcher_args=train_cfg.data.patcher_args,
                    packing_args=packing_args,
                    add_patches=train_cfg.data.add_patches,
                    path=os.path.join(eval_args.validation.root_dir, source),
                    max_n_docs=eval_args.validation.max_n_docs,
                    max_n_batches=eval_args.validation.max_n_batches,
                    arrow_batch_size=20,
                    s3_profile=eval_args.s3_profile,
                )

    task_results = None
    if eval_args.run_tasks:
        assert eval_args.generator is not None
        assert eval_args.harness is not None
        generator = PackedCausalTransformerGenerator(
            eval_args.generator, model, tokenizer
        )
        wrap = EvalHarnessLM(generator)
        # TODO: This needs to be checked/sped up
        task_results = simple_evaluate(wrap, **eval_args.harness.model_dump())
results = {"ppl": ppl_results, "tasks": task_results} | |
# TODO: Serial and Parallel yield slightly different number of bytes, debug this later, | |
# leaving this log statement here to help with that. | |
# logging.info("Rank: %s Results: %s", world_rank, results) | |
if get_global_rank() == 0: | |
with fs.open(os.path.join(eval_args.dump_dir, "results.json"), "w") as f: | |
f.write(json.dumps(results)) | |
logger.info(f"All evaluation results: {results}") | |
if ppl_results is not None: | |
with fs.open(os.path.join(eval_args.dump_dir, "validation.json"), "w") as f: | |
f.write(json.dumps(ppl_results)) | |
logger.info(f"All validation results: {ppl_results}") | |

    if eval_args.metric_log_dir and get_global_rank() == 0:
        metric_log_path = os.path.join(eval_args.metric_log_dir, "metrics.eval.jsonl")
        logger.info(f"Writing metric logs to {metric_log_path}")
        timestamp: dict[str, int | str] = {
            "created_at": datetime.utcnow().isoformat(),
        }
        if eval_args.global_step is not None:
            timestamp["global_step"] = eval_args.global_step
        # Use context managers so the appended JSONL lines are flushed and closed
        with fs.open(metric_log_path, mode="a") as f:
            print(json.dumps(timestamp | results), file=f, flush=True)
        val_log_path = os.path.join(
            eval_args.metric_log_dir, "metrics.validation.jsonl"
        )
        if ppl_results is not None:
            with fs.open(val_log_path, mode="a") as f:
                print(json.dumps(timestamp | ppl_results), file=f, flush=True)


def main():
    eval_args = parse_args_to_pydantic_model(EvalArgs)
    launch_eval(eval_args)


if __name__ == "__main__":
    main()
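

# Example invocation (a sketch; paths are placeholders, and the key=value
# override syntax assumes the repo's usual config-parsing convention via
# parse_args_to_pydantic_model):
#
#   python -m bytelatent.eval config=eval.yaml ckpt_dir=/path/to/ckpt dump_dir=/tmp/eval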