import math  # used by llama_multigpu (math.ceil)
import time

import torch
import torch.nn as nn
import transformers  # used by load_quant (transformers.modeling_utils)

from gptq import *
from modelutils import *
from quant import *
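# gptq, modelutils, quant, and datautils are assumed to be the companion modules
# of this script (as in the GPTQ-for-LLaMa repository); they are expected to
# provide GPTQ, Quantizer, find_layers, make_quant, QuantLinear, quantize,
# get_loaders, and the default device DEV used below.
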
def get_llama(model):
    def skip(*args, **kwargs):
        pass
    # Skip the default weight initialization; the weights are overwritten by the
    # pretrained checkpoint anyway, so initializing them is wasted work.
    torch.nn.init.kaiming_uniform_ = skip
    torch.nn.init.uniform_ = skip
    torch.nn.init.normal_ = skip
    from transformers import LlamaForCausalLM
    model = LlamaForCausalLM.from_pretrained(model, torch_dtype='auto')
    model.seqlen = 2048
    return model

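# llama_sequential quantizes the decoder layers one at a time: it first captures
# the inputs to layer 0 for every calibration sample, then, for each layer,
# gathers per-module activation statistics through forward hooks, runs GPTQ
# (fasterquant) on each linear module, and finally re-runs the layer so that its
# quantized outputs become the inputs of the next layer.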
@torch.no_grad()
def llama_sequential(model, dataloader, dev):
    print('Starting ...')

    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.layers

    # Only the embeddings, the final norm, and the first decoder layer need to
    # be on the GPU while the calibration inputs are captured.
    model.model.embed_tokens = model.model.embed_tokens.to(dev)
    model.model.norm = model.model.norm.to(dev)
    layers[0] = layers[0].to(dev)

    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros(
        (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
    )
    cache = {'i': 0, 'attention_mask': None}
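    # Catcher wraps layer 0 so that each calibration forward pass records the
    # hidden states, attention mask, and position ids that reach the first
    # decoder layer, then aborts the rest of the forward pass by raising
    # ValueError (caught below); the full model never needs to run end to end.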
    class Catcher(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module
        def forward(self, inp, **kwargs):
            inps[cache['i']] = inp
            cache['i'] += 1
            cache['attention_mask'] = kwargs['attention_mask']
            cache['position_ids'] = kwargs['position_ids']
            raise ValueError
    layers[0] = Catcher(layers[0])
    for batch in dataloader:
        try:
            model(batch[0].to(dev))
        except ValueError:
            pass
    layers[0] = layers[0].module

    layers[0] = layers[0].cpu()
    model.model.embed_tokens = model.model.embed_tokens.cpu()
    model.model.norm = model.model.norm.cpu()
    torch.cuda.empty_cache()

    outs = torch.zeros_like(inps)
    attention_mask = cache['attention_mask']
    position_ids = cache['position_ids']

    print('Ready.')

    quantizers = {}
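    # Quantize layer by layer. With --true-sequential, the modules inside a
    # layer are also processed in execution order (QKV projections, then o_proj,
    # then the MLP projections), so later modules see activations produced by
    # already-quantized earlier ones; otherwise all modules in the layer are
    # handled in one pass.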
    for i in range(len(layers)):
        layer = layers[i].to(dev)
        full = find_layers(layer)

        if args.true_sequential:
            sequential = [
                ['self_attn.k_proj', 'self_attn.v_proj', 'self_attn.q_proj'],
                ['self_attn.o_proj'],
                ['mlp.up_proj', 'mlp.gate_proj'],
                ['mlp.down_proj']
            ]
        else:
            sequential = [list(full.keys())]

        for names in sequential:
            subset = {n: full[n] for n in names}

            gptq = {}
            for name in subset:
                gptq[name] = GPTQ(subset[name])
                gptq[name].quantizer = Quantizer()
                gptq[name].quantizer.configure(
                    args.wbits, perchannel=True, sym=args.sym, mse=False
                )

            def add_batch(name):
                def tmp(_, inp, out):
                    gptq[name].add_batch(inp[0].data, out.data)
                return tmp
            handles = []
            for name in subset:
                handles.append(subset[name].register_forward_hook(add_batch(name)))
            for j in range(args.nsamples):
                outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
            for h in handles:
                h.remove()

            for name in subset:
                print(f'Quantizing {name} in layer {i+1}/{len(layers)}...')
                scale, zero = gptq[name].fasterquant(
                    percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order
                )
                quantizers['model.layers.%d.%s' % (i, name)] = (gptq[name].quantizer, scale, zero)
                gptq[name].free()

        for j in range(args.nsamples):
            outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]

        layers[i] = layer.cpu()
        del layer
        del gptq
        torch.cuda.empty_cache()

        inps, outs = outs, inps

    model.config.use_cache = use_cache

    return quantizers

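# llama_eval computes perplexity on the test encodings: it pushes all segments
# through the decoder layers one layer at a time (optionally applying a simple
# round-to-nearest baseline when --nearest is set), then applies the final norm
# and lm_head and accumulates the per-segment cross-entropy loss.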
@torch.no_grad()
def llama_eval(model, testenc, dev):
    print('Evaluating ...')

    testenc = testenc.input_ids
    nsamples = testenc.numel() // model.seqlen

    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.layers

    model.model.embed_tokens = model.model.embed_tokens.to(dev)
    layers[0] = layers[0].to(dev)

    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros(
        (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
    )
    cache = {'i': 0, 'attention_mask': None}

    class Catcher(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module
        def forward(self, inp, **kwargs):
            inps[cache['i']] = inp
            cache['i'] += 1
            cache['attention_mask'] = kwargs['attention_mask']
            cache['position_ids'] = kwargs['position_ids']
            raise ValueError
    layers[0] = Catcher(layers[0])
    for i in range(nsamples):
        batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev)
        try:
            model(batch)
        except ValueError:
            pass
    layers[0] = layers[0].module

    layers[0] = layers[0].cpu()
    model.model.embed_tokens = model.model.embed_tokens.cpu()
    torch.cuda.empty_cache()

    outs = torch.zeros_like(inps)
    attention_mask = cache['attention_mask']
    position_ids = cache['position_ids']

    for i in range(len(layers)):
        print(i)
        layer = layers[i].to(dev)

        if args.nearest:
            subset = find_layers(layer)
            for name in subset:
                quantizer = Quantizer()
                quantizer.configure(
                    args.wbits, perchannel=True, sym=False, mse=False
                )
                W = subset[name].weight.data
                quantizer.find_params(W, weight=True)
                subset[name].weight.data = quantize(
                    W, quantizer.scale, quantizer.zero, quantizer.maxq
                ).to(next(iter(layer.parameters())).dtype)

        for j in range(nsamples):
            outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
        layers[i] = layer.cpu()
        del layer
        torch.cuda.empty_cache()
        inps, outs = outs, inps

    if model.model.norm is not None:
        model.model.norm = model.model.norm.to(dev)
    model.lm_head = model.lm_head.to(dev)

    testenc = testenc.to(dev)
    nlls = []
    for i in range(nsamples):
        hidden_states = inps[i].unsqueeze(0)
        if model.model.norm is not None:
            hidden_states = model.model.norm(hidden_states)
        lm_logits = model.lm_head(hidden_states)
        shift_logits = lm_logits[:, :-1, :].contiguous()
        shift_labels = testenc[
            :, (i * model.seqlen):((i + 1) * model.seqlen)
        ][:, 1:]
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        neg_log_likelihood = loss.float() * model.seqlen
        nlls.append(neg_log_likelihood)
    ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))
    print(ppl.item())

    model.config.use_cache = use_cache

# TODO: perform packing on GPU
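# llama_pack swaps every quantized nn.Linear for a QuantLinear module (via
# make_quant) and packs the original weights into the low-bit format using the
# scales and zero points collected during quantization.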
def llama_pack(model, quantizers, wbits, groupsize):
    layers = find_layers(model)
    layers = {n: layers[n] for n in quantizers}
    make_quant(model, quantizers, wbits, groupsize)
    qlayers = find_layers(model, [QuantLinear])
    print('Packing ...')
    for name in qlayers:
        print(name)
        quantizers[name], scale, zero = quantizers[name]
        quantizers[name], scale, zero = quantizers[name].cpu(), scale.cpu(), zero.cpu()
        qlayers[name].pack(layers[name], scale, zero)
    print('Done.')
    return model

def load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False):
    from transformers import LlamaConfig, LlamaForCausalLM
    config = LlamaConfig.from_pretrained(model)
    def noop(*args, **kwargs):
        pass
    torch.nn.init.kaiming_uniform_ = noop
    torch.nn.init.uniform_ = noop
    torch.nn.init.normal_ = noop

    # Build an empty fp16 model skeleton without running weight initialization,
    # then replace the target linear layers with QuantLinear before loading the
    # packed checkpoint.
    torch.set_default_dtype(torch.half)
    transformers.modeling_utils._init_weights = False
    model = LlamaForCausalLM(config)
    torch.set_default_dtype(torch.float)
    model = model.eval()
    layers = find_layers(model)
    for name in ['lm_head']:
        if name in layers:
            del layers[name]
    make_quant(model, layers, wbits, groupsize, faster=faster_kernel)
    del layers

    print('Loading model ...')
    if checkpoint.endswith('.safetensors'):
        from safetensors.torch import load_file as safe_load
        model.load_state_dict(safe_load(checkpoint))
    else:
        model.load_state_dict(torch.load(checkpoint))
    model.seqlen = 2048
    print('Done.')

    return model

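# llama_multigpu shards the decoder layers across the available GPUs in a simple
# pipeline: the embeddings sit on the first GPU, the final norm and lm_head on
# the last, and every layer is wrapped in MoveModule so the hidden states and
# attention mask are moved to that layer's device on the fly.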
def llama_multigpu(model, gpus):
    model.model.embed_tokens = model.model.embed_tokens.to(gpus[0])
    if hasattr(model.model, 'norm') and model.model.norm:
        model.model.norm = model.model.norm.to(gpus[-1])
    import copy
    model.lm_head = copy.deepcopy(model.lm_head).to(gpus[-1])

    cache = {'mask': None}

    class MoveModule(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module
            self.dev = next(iter(self.module.parameters())).device
        def forward(self, *inp, **kwargs):
            inp = list(inp)
            if inp[0].device != self.dev:
                inp[0] = inp[0].to(self.dev)
            if cache['mask'] is None or cache['mask'].device != self.dev:
                cache['mask'] = kwargs['attention_mask'].to(self.dev)
            kwargs['attention_mask'] = cache['mask']
            tmp = self.module(*inp, **kwargs)
            return tmp

    layers = model.model.layers
    pergpu = math.ceil(len(layers) / len(gpus))
    for i in range(len(layers)):
        layers[i] = MoveModule(layers[i].to(gpus[i // pergpu]))

    model.gpus = gpus

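# benchmark measures per-token generation latency: it feeds one token at a time
# with a growing attention mask and the cached past_key_values, synchronizes all
# devices after every step to get wall-clock timings, tracks peak allocated GPU
# memory, and with check=True also accumulates next-token cross-entropy so a
# perplexity figure can be reported as a sanity check.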
def benchmark(model, input_ids, check=False):
    input_ids = input_ids.to(model.gpus[0] if hasattr(model, 'gpus') else DEV)
    torch.cuda.synchronize()

    cache = {'past': None}
    def clear_past(i):
        def tmp(layer, inp, out):
            if cache['past']:
                cache['past'][i] = None
        return tmp
    for i, layer in enumerate(model.model.layers):
        layer.register_forward_hook(clear_past(i))

    print('Benchmarking ...')

    if check:
        loss = nn.CrossEntropyLoss()
        tot = 0.

    def sync():
        if hasattr(model, 'gpus'):
            for gpu in model.gpus:
                torch.cuda.synchronize(gpu)
        else:
            torch.cuda.synchronize()
    max_memory = 0
    with torch.no_grad():
        attention_mask = torch.ones((1, input_ids.numel()), device=DEV)
        times = []
        for i in range(input_ids.numel()):
            tick = time.time()
            out = model(
                input_ids[:, i:i+1],
                past_key_values=cache['past'],
                attention_mask=attention_mask[:, :(i + 1)].reshape((1, -1))
            )
            sync()
            times.append(time.time() - tick)
            print(i, times[-1])
            max_memory = max(max_memory, torch.cuda.memory_allocated() / 1024 / 1024)
            if check and i != input_ids.numel() - 1:
                tot += loss(out.logits[0].to(DEV), input_ids[:, (i + 1)].to(DEV)).float()
            cache['past'] = list(out.past_key_values)
            del out
        sync()
        import numpy as np
        print('Median:', np.median(times))
        if check:
            print('PPL:', torch.exp(tot / (input_ids.numel() - 1)).item())
        print('max memory(MiB):', max_memory)

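# Illustrative invocations (the model identifier and checkpoint names below are
# placeholders, not values prescribed by this script):
#   python llama.py <llama-hf-model-or-path> c4 --wbits 4 --groupsize 128 \
#       --true-sequential --act-order --save llama-4bit-128g.pt
#   python llama.py <llama-hf-model-or-path> c4 --wbits 4 --groupsize 128 \
#       --load llama-4bit-128g.pt --benchmark 2048 --check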
if __name__ == '__main__':
    import argparse
    from datautils import *

    parser = argparse.ArgumentParser()

    parser.add_argument(
        'model', type=str,
        help='LLaMA model to load.'
    )
    parser.add_argument(
        'dataset', type=str, choices=['wikitext2', 'ptb', 'c4'],
        help='Where to extract calibration data from.'
    )
    parser.add_argument(
        '--seed',
        type=int, default=0, help='Seed for sampling the calibration data.'
    )
    parser.add_argument(
        '--nsamples', type=int, default=128,
        help='Number of calibration data samples.'
    )
    parser.add_argument(
        '--percdamp', type=float, default=.01,
        help='Percent of the average Hessian diagonal to use for dampening.'
    )
    parser.add_argument(
        '--nearest', action='store_true',
        help='Whether to run the RTN baseline.'
    )
    parser.add_argument(
        '--wbits', type=int, default=16, choices=[2, 3, 4, 8, 16],
        help='#bits to use for quantization; use 16 for evaluating the base model.'
    )
    parser.add_argument(
        '--trits', action='store_true',
        help='Whether to use trits for quantization.'
    )
    parser.add_argument(
        '--groupsize', type=int, default=-1,
        help='Groupsize to use for quantization; default uses full row.'
    )
    parser.add_argument(
        '--eval', action='store_true',
        help='Evaluate the quantized model.'
    )
    parser.add_argument(
        '--save', type=str, default='',
        help='Save quantized checkpoint under this name.'
    )
    parser.add_argument(
        '--save_safetensors', type=str, default='',
        help='Save quantized `.safetensors` checkpoint under this name.'
    )
    parser.add_argument(
        '--load', type=str, default='',
        help='Load quantized model.'
    )
    parser.add_argument(
        '--benchmark', type=int, default=0,
        help='Number of tokens to use for benchmarking.'
    )
    parser.add_argument(
        '--check', action='store_true',
        help='Whether to compute perplexity during benchmarking for verification.'
    )
    parser.add_argument(
        '--sym', action='store_true',
        help='Whether to perform symmetric quantization.'
    )
    parser.add_argument(
        '--act-order', action='store_true',
        help='Whether to apply the activation order GPTQ heuristic.'
    )
    parser.add_argument(
        '--true-sequential', action='store_true',
        help='Whether to run in true sequential mode.'
    )
    parser.add_argument(
        '--new-eval', action='store_true',
        help='Whether to use the new PTB and C4 eval.'
    )
    parser.add_argument(
        '--faster-kernel', action='store_true',
        help='Whether to use the new faster kernel for benchmarking.'
    )

    args = parser.parse_args()

    if not isinstance(args.load, str):
        args.load = args.load.as_posix()

    if args.load:
        model = load_quant(args.model, args.load, args.wbits, args.groupsize, args.faster_kernel)
    else:
        model = get_llama(args.model)
        model.eval()

    dataloader, testloader = get_loaders(
        args.dataset, nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=model.seqlen
    )

    # Quantize with GPTQ unless a pre-quantized checkpoint was loaded, the model
    # is being evaluated in fp16 (wbits == 16), or the RTN baseline was requested.
    if not args.load and args.wbits < 16 and not args.nearest:
        tick = time.time()
        quantizers = llama_sequential(model, dataloader, DEV)
        print(time.time() - tick)

    if args.eval:
        datasets = ['wikitext2', 'ptb', 'c4']
        if args.new_eval:
            datasets = ['wikitext2', 'ptb-new', 'c4-new']
        for dataset in datasets:
            dataloader, testloader = get_loaders(
                dataset, seed=args.seed, model=args.model, seqlen=model.seqlen
            )
            print(dataset)
            llama_eval(model, testloader, DEV)

    if args.save:
        llama_pack(model, quantizers, args.wbits, args.groupsize)
        torch.save(model.state_dict(), args.save)

    if args.save_safetensors:
        llama_pack(model, quantizers, args.wbits, args.groupsize)
        from safetensors.torch import save_file as safe_save
        safe_save(model.state_dict(), args.save_safetensors)

    if args.benchmark:
        gpus = [torch.device('cuda:%d' % i) for i in range(torch.cuda.device_count())]
        if len(gpus) > 1:
            llama_multigpu(model, gpus)
        else:
            model = model.to(DEV)
        input_ids = next(iter(dataloader))[0][:, :args.benchmark]
        benchmark(model, input_ids, check=args.check)

    if args.load:
        exit()