import math
from os import path

import cupy as cp
import torch
from torch import nn

class Quantizer(nn.Module):
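    """Product quantizer backed by custom CUDA kernels compiled with NVRTC
    (via CuPy) at construction time.

    The codebook has shape (nsq, 2**b, d): nsq sub-quantizers, each with
    2**b centroids of dimension d, so a vector of length nsq * d is encoded
    as one uint8 code per sub-quantizer. The fused kernels operate directly
    on the codes, applying RoPE and query/weight products without
    materializing the dequantized keys or values.
    """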

    def __init__(self, config, codebook):
        super().__init__()

        # Codebook of shape (nsq, 2**b, d): one sub-codebook of 2**b
        # centroids of dimension d per sub-quantizer.
        self.nsq, nc, self.d = codebook.shape
        self.b = int(math.log2(nc))
        head_dim = config.hidden_size // config.num_attention_heads
        self.head_dim = head_dim
        # Query heads per key/value head (grouped-/multi-query attention).
        qpk = config.num_attention_heads // config.num_key_value_heads
        self.window_length = getattr(config, 'window_length', 32)
        self.register_buffer('codebook', codebook)

        # Compile each CUDA kernel with NVRTC, specializing it by baking the
        # placeholder tokens in its source into compile-time constants.
        base = {'__NSQ__': self.nsq, '__B__': self.b, '__D__': self.d}
        rope = {**base, '__HEAD_DIM__': head_dim}
        mqa = {**rope, '__QPK__': qpk}

        self._quantize = self._load_kernel('quantize.cu', 'quantize', base)
        self._dequantize = self._load_kernel('dequantize.cu', 'dequantize', base)
        self._dequantize_rope = self._load_kernel('dequantize_rope.cu', 'dequantize_rope', rope)
        self._fused_rope_mult = self._load_kernel('fused_rope_mult.cu', 'fused_rope_mult', rope)
        self._fused_rope_pos_mult = self._load_kernel(
            'fused_rope_pos_mult_mqa.cu', 'fused_rope_pos_mult',
            {**mqa, '__ROPE_THETA__': config.rope_theta},
        )
        self._fused_mult = self._load_kernel('fused_mult_len.cu', 'fused_mult', mqa)

    @staticmethod
    def _load_kernel(filename, kernel_name, replacements):
        """Read a CUDA source file next to this module, substitute the
        placeholder tokens, and compile the kernel with NVRTC."""
        with open(path.join(path.dirname(__file__), filename), "r") as f:
            kernel_code = f.read()
        for token, value in replacements.items():
            kernel_code = kernel_code.replace(token, str(value))
        return cp.RawKernel(kernel_code, kernel_name, backend="nvrtc")

    def quantize(self, x):
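        """Encode x into codes.

        x is read as n = x.numel() // x.shape[-1] vectors of length
        nsq * d; returns a uint8 tensor of shape (nsq, n) holding one code
        per sub-quantizer per vector.
        """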
        n = x.numel() // x.shape[-1]
        codes = torch.empty(self.nsq, n, dtype=torch.uint8, device=x.device)
        blocks_per_grid = (self.nsq,)
        threads_per_block = (1024,)

        # One block per sub-quantizer; shared memory holds that block's
        # sub-codebook: 2**b centroids * d dims * 2 bytes (fp16).
        self._quantize(grid=blocks_per_grid, block=threads_per_block, shared_mem=(2 ** self.b) * self.d * 2, args=[
            self.codebook.data_ptr(),
            x.data_ptr(),
            codes.data_ptr(),
            n,
        ])

        return codes

    def dequantize(self, codes):
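        """Decode codes of shape (nsq, n) to fp16 vectors of shape (n, nsq * d)."""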
        n = codes.numel() // codes.shape[0]
        x = torch.zeros(n, self.nsq * self.d, dtype=torch.float16, device=codes.device)
        blocks_per_grid = (self.nsq,)
        threads_per_block = (1024,)

        self._dequantize(grid=blocks_per_grid, block=threads_per_block, shared_mem=(2 ** self.b) * self.d * 2, args=[
            self.codebook.data_ptr(),
            codes.data_ptr(),
            x.data_ptr(),
            n,
        ])

        return x

    def dequantize_rope(self, codes):
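        """Decode codes of shape (nsq, batch, seq) to fp16 vectors of shape
        (batch * seq, nsq * d), applying rotary position embeddings (RoPE)
        inside the kernel while decoding."""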
        _, batch_size, seq_len = codes.shape
        n = batch_size * seq_len
        x = torch.zeros(n, self.nsq * self.d, dtype=torch.float16, device=codes.device)
        blocks_per_grid = (self.nsq,)
        threads_per_block = (1024,)

        self._dequantize_rope(grid=blocks_per_grid, block=threads_per_block, shared_mem=(2 ** self.b) * self.d * 2, args=[
            self.codebook.data_ptr(),
            codes.data_ptr(),
            x.data_ptr(),
            batch_size, seq_len,
        ])

        return x

    def fused_rope_mult(self, codes, queries):
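        """Compute attention scores against quantized keys in one kernel:
        decode, apply RoPE, and multiply with the queries without
        materializing the keys. Returns fp16 (batch, heads, q_len, k_len)."""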
        _, batch_size, k_len = codes.shape
        _, n_heads, q_len, _ = queries.shape
        out = torch.zeros(batch_size, n_heads, q_len, k_len, dtype=torch.float16, device=codes.device)
        blocks_per_grid = (self.nsq,)
        threads_per_block = (1024,)

        self._fused_rope_mult(grid=blocks_per_grid, block=threads_per_block, shared_mem=(2 ** self.b) * self.d * 2, args=[
            self.codebook.data_ptr(),
            codes.data_ptr(),
            queries.data_ptr(),
            out.data_ptr(),
            batch_size, q_len, k_len,
        ])

        return out

    def fused_rope_pos_mult(self, codes, queries, position_ids):
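        """Like fused_rope_mult, but rotates keys by absolute positions: the
        kernel receives a per-batch position offset recovered from
        position_ids. Accumulates in fp32, returning float32 scores."""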
        _, batch_size, k_len = codes.shape
        _, n_heads, q_len, _ = queries.shape
        # Per-batch position of the first cached key, recovered from the
        # last position id.
        position_offsets = position_ids[:, -1] - k_len + 1
        out = torch.zeros(batch_size, n_heads, q_len, k_len, dtype=torch.float32, device=codes.device)
        blocks_per_grid = (self.nsq,)
        threads_per_block = (1024,)

        self._fused_rope_pos_mult(grid=blocks_per_grid, block=threads_per_block, shared_mem=(2 ** self.b) * self.d * 2, args=[
            self.codebook.data_ptr(),
            codes.data_ptr(),
            position_offsets.data_ptr(),
            queries.data_ptr(),
            out.data_ptr(),
            batch_size, q_len, k_len,
        ])

        return out

    def fused_mult(self, codes, weights, skip_last=0):
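        """Apply attention weights to quantized values in one kernel,
        producing fp16 output of shape (batch, heads, q_len, head_dim); the
        last skip_last key positions are excluded from the weighted sum."""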
        batch_size, n_heads, q_len, k_len = weights.shape
        out = torch.zeros(batch_size, n_heads, q_len, self.head_dim, dtype=torch.float16, device=codes.device)
        blocks_per_grid = (self.nsq,)
        threads_per_block = (min(1024, batch_size),)

        # The final argument is the effective key length: positions beyond
        # k_len - skip_last are ignored by the kernel.
        self._fused_mult(grid=blocks_per_grid, block=threads_per_block, shared_mem=(2 ** self.b) * self.d * 2, args=[
            self.codebook.data_ptr(),
            codes.data_ptr(),
            weights.data_ptr(),
            out.data_ptr(),
            batch_size, q_len, k_len, k_len - skip_last,
        ])

        return out
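

# Illustrative usage sketch (not part of the module's API): assumes an
# HF-style config exposing hidden_size, num_attention_heads,
# num_key_value_heads, and rope_theta, a trained fp16 codebook of shape
# (nsq, 2**b, d) on the GPU, and the *.cu kernel sources next to this file.
#
#   quantizer = Quantizer(config, codebook)
#   keys = torch.randn(batch, seq, quantizer.nsq * quantizer.d,
#                      dtype=torch.float16, device='cuda')
#   codes = quantizer.quantize(keys)        # uint8, (nsq, batch * seq)
#   keys_hat = quantizer.dequantize(codes)  # fp16,  (batch * seq, nsq * d)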