from model import ExLlama, ExLlamaCache, ExLlamaConfig
from tokenizer import ExLlamaTokenizer
from generator import ExLlamaGenerator
from lora import ExLlamaLora
import perplexity
from perplexity import Perplexity
import time
import torch
import torch.nn.functional as F
import argparse
import json
import math
import sys
import os
import glob
import model_init
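
# Force CUDA initialization up front, presumably so startup cost isn't counted in the timed passes below;
# the TF32 / reduced-precision matmul options are left disabled.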
torch.cuda._lazy_init()
# torch.backends.cuda.matmul.allow_tf32 = True
# torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
torch.set_printoptions(precision = 10)
torch_devices = [f"cuda:{i}" for i in range(torch.cuda.device_count())]

cache = None
model = None
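
# (Re)initialize the cache: create it on first use, otherwise rewind it to the start of the sequence.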
def begin():
    global model, cache

    if cache is None: cache = ExLlamaCache(model)
    else: cache.current_seq_len = 0
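
# Forward pass through the model, returning logits (for the last position only by default).
# The commented-out loop is an alternative that feeds the prompt in 2048-token chunks.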
def next_logits(input_ids, apply_lora, last_id_only = True, input_mask = None):
    global model, cache

    # n_logits = None
    # a = 0
    # while a < input_ids.shape[-1]:
    #     b = min(input_ids.shape[-1], a + 2048)
    #     n_logits = model.forward(input_ids[:, a:b], cache, last_id_only, lora = apply_lora, input_mask = input_mask)
    #     a = b

    n_logits = model.forward(input_ids, cache, last_id_only, lora = apply_lora, input_mask = input_mask)
    return n_logits
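
# Encode text to token IDs with the global tokenizer.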
def tokenize(text):
    global tokenizer

    return tokenizer.encode(text)
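
# Run func(), print how long it took, and return its result.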
def timer(name, func):
    t = time.time()
    ret = func()
    t = time.time() - t
    print(f" ** Time, {name}: {t:.2f} seconds")
    return ret
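
# Record a per-device VRAM baseline so mem() can report deltas.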
mem_base = {}
mem_last = {}
for dev in torch_devices:
    torch.cuda.reset_peak_memory_stats(dev)
    mem_base[dev] = mem_last[dev] = torch.cuda.max_memory_allocated(dev)
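
# Print peak VRAM allocated per device since the previous call, or since the baseline when total = True.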
def mem(name, total = False):
    global mem_base, mem_last

    res = f" ** VRAM, {name}: "
    first = True

    for device in torch_devices:
        mem_c = torch.cuda.max_memory_allocated(device)
        mem_this = mem_c - mem_last[device] if not total else mem_c - mem_base[device]
        mem_last[device] = mem_c

        if not first: res += " - "
        first = False
        res += f"[{device}] {mem_this / (1024 ** 2):,.2f} MB"

    print(res)

# Parse arguments

parser = argparse.ArgumentParser(description = "Benchmark tests for ExLlama")

model_init.add_args(parser)
perplexity.add_args(parser)

parser.add_argument("-p", "--perf", action = "store_true", help = "Benchmark speed and VRAM usage")
parser.add_argument("-v", "--validate", action = "count", help = "Run validation check and generate some sample output; specify twice for a more thorough test")
parser.add_argument("-lora", "--lora", type = str, help = "Path to LoRA binary to use during benchmark")
parser.add_argument("-loracfg", "--lora_config", type = str, help = "Path to LoRA config to use during benchmark")
parser.add_argument("-ld", "--lora_dir", type = str, help = "Path to LoRA config and binary. to use during benchmark") | |

args = parser.parse_args()

model_init.post_parse(args)
perplexity.post_parse(args)
model_init.get_model_files(args)

# Paths

if args.lora_dir is not None:
    args.lora_config = os.path.join(args.lora_dir, "adapter_config.json")
    args.lora = os.path.join(args.lora_dir, "adapter_model.bin")

# Feedback

print_opts = []
if args.perf: print_opts.append("perf")
if args.validate: print_opts.append("validate")
if args.perplexity: print_opts.append("perplexity")
if args.perplexity_token: print_opts.append("perplexity_token")

model_init.print_options(args, print_opts)

# Globals

model_init.set_globals(args)

# Instantiate model

config = model_init.make_config(args)

model = timer("Load model", lambda: ExLlama(config))
tokenizer = timer("Load tokenizer", lambda: ExLlamaTokenizer(args.tokenizer))

model_init.print_stats(model)

torch.cuda.reset_peak_memory_stats("cuda")
mem("Model")

cache = ExLlamaCache(model)
mem("Cache")

# Load LoRA

lora = None
if args.lora:
    print(f" -- LoRA config: {args.lora_config}")
    print(f" -- Loading LoRA: {args.lora}")
    if args.lora_config is None:
        print(" ## Error: please specify a path to adapter_config.json with --lora_config")
        sys.exit()
    lora = ExLlamaLora(model, args.lora_config, args.lora)
    if lora.bias_ignored:
        print(" !! Warning: LoRA zero bias ignored")

# Test sequence

gen_tokens = 128
max_seq_len = args.length
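# Random token IDs stand in for a real prompt; the range stays inside the 32000-entry Llama vocabulary.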
ids = torch.randint(0, 31999, (1, max_seq_len - gen_tokens)).cuda()

# Benchmark memory and performance

if args.perf:

    # Warming up apparently makes a huge difference

    for i in range(1, 3):
        print(f" -- Warmup pass {i}...")
        begin()
        logits = timer("Warmup", lambda: next_logits(ids, lora))

    # Do the actual benchmark

    begin()

    t = time.time()

    print(" -- Inference, first pass.")
    logits = timer("Inference", lambda: next_logits(ids, lora))

    t = time.time() - t
    print(f" ** Speed: {ids.shape[-1] / t:.2f} tokens/second")

    for j in range(2):

        t = time.time()
        print(f" -- Generating {gen_tokens} tokens, {ids.shape[-1]} token prompt...")
        for i in range(gen_tokens):
            logits = logits[0, -1, :]
            token = torch.argmax(logits)
            next_id = token.unsqueeze(0).unsqueeze(0)
            logits = next_logits(next_id, lora)

        t = time.time() - t
        print(f" ** Speed: {gen_tokens / t:.2f} tokens/second")

        # Truncate the prompt so the second pass measures generation speed with a short (4-token) context

        ids = ids[:, :4]
        cache.current_seq_len = 4

    mem("Inference")
    mem("Total", total = True)

# Benchmark perplexity

if args.perplexity:

    ppl = Perplexity(args.perplexity, model, cache, tokenizer)

    print(" -- Loading dataset...")

    ppl.load(dataset_path = args.perplexity_dataset,
             chunk_size = args.perplexity_chunk_size,
             chunk_truncate = args.perplexity_chunk_truncate,
             overlap = args.perplexity_chunk_overlap,
             minlength = args.perplexity_chunk_min,
             json_key = args.perplexity_json_key)

    begin()

    ppl.test(args.perplexity_chunk_num,
             lora = lora,
             ppl_token = args.perplexity_token)

# Validate file

if args.validate:

    ppl = Perplexity(args.perplexity, model, cache, tokenizer)

    ppl.load(dataset_path = "datasets/wikitext2_val_sample.jsonl",
             chunk_size = 2048,
             chunk_truncate = 2048,
             overlap = 0,
             minlength = 50,
             json_key = "text")

    # Short perplexity tests in switched and quant mode, should produce roughly equal results

    begin()

    ppl.cache.zero()
    model.config.matmul_recons_thd = 1
    ppl.test(8, lora = lora, tag = " (reconstruct)")
    ppl.cache.zero()
    model.config.matmul_recons_thd = 0
    ppl.test(8, lora = lora, tag = " (quant, token)", ppl_token = True)

    # Do a short, easy topk=1 completion to see if we're generating garbage. Should run in switched mode
    # for the prompt and quant for individual tokens

    model.config.matmul_recons_thd = 4
    generator = ExLlamaGenerator(model, tokenizer, cache)
    generator.settings.top_k = 1
    generator.lora = lora
    text = generator.generate_simple("To be or not to be, that is the", max_new_tokens = 20 * args.validate)
    print(f" ** Generation: {repr(text)}")

    if args.validate > 1:

        # Test batched generation

        bsz = 8
        gen_len = 20
        torch.manual_seed(42)
        torch.cuda.manual_seed_all(42)

        # Bigger cache for the batch

        del cache
        cache = ExLlamaCache(model, batch_size = bsz)

        # Create tokenized batch and attention mask

        identical_batch_prompt = "When you have eliminated the impossible, whatever remains,"
        continuations = [
            " must be considered",
            " ought to be",
            " (and some scholars say this is",
            " however improbable, is a banana.",
        ]

        prompts = [identical_batch_prompt] * (bsz - len(continuations))
        for cont in continuations:
            prompts.append(identical_batch_prompt + cont)

        ids = tokenizer.encode(prompts)
        assert ids.shape[1] < model.config.max_seq_len, f"Max length {ids.shape[1]} exceeds model limit {model.config.max_seq_len}"
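
        # encode() pads the batch to a common length; the mask marks real (non-pad) positions so the
        # padding is ignored during the forward pass.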
        mask = ids.ne(tokenizer.pad_token_id)

        # Batched generation with greedy sampling

        sequence = torch.empty((bsz, 0), dtype = torch.long, device = "cpu")
        logits = next_logits(ids, lora, input_mask = mask)

        for i in range(gen_len):
            logits = logits[:, -1, :]
            id_per_batch = torch.argmax(logits, dim=-1)
            assert id_per_batch.shape == (bsz,), f"{id_per_batch.shape} != {(bsz,)}"
            next_id_per_batch = id_per_batch.unsqueeze(-1)
            sequence = torch.cat((sequence, next_id_per_batch), dim = -1)
            logits = next_logits(next_id_per_batch, lora)

        # Print output batch

        print(f"\n ** Batching sanity check: 1-{bsz - len(continuations)} should be identical. All should be reasonable for the model you're using.\n")

        outputs = tokenizer.decode(sequence)
        for b in range(bsz):
            print(f"{b + 1} {repr(prompts[b])} -> {repr(outputs[b])}")

        # TODO Save the logits and then rerun each prompt with a batch size of 1, same input. The logits should be identical.