"""
Lookup Free Quantization

Proposed in https://arxiv.org/abs/2310.05737

In the simplest setup, each dimension is quantized into {-1, 1}.
An entropy penalty is used to encourage utilization.
"""

from math import log2, ceil
from functools import partial, cache
from collections import namedtuple
from contextlib import nullcontext

import torch.distributed as dist
from torch.distributed import nn as dist_nn

import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.nn import Module
from torch.amp import autocast

from einops import rearrange, reduce, pack, unpack

Return = namedtuple('Return', ['quantized', 'indices', 'entropy_aux_loss'])

LossBreakdown = namedtuple('LossBreakdown', ['per_sample_entropy', 'batch_entropy', 'commitment'])

@cache
def is_distributed():
    return dist.is_initialized() and dist.get_world_size() > 1

def maybe_distributed_mean(t):
    if not is_distributed():
        return t

    dist_nn.all_reduce(t)
    t = t / dist.get_world_size()
    return t

def exists(v):
    return v is not None

def identity(t):
    return t

def default(*args):
    for arg in args:
        if exists(arg):
            return arg() if callable(arg) else arg
    return None

def pack_one(t, pattern):
    return pack([t], pattern)

def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

def l2norm(t):
    return F.normalize(t, dim = -1)

def log(t, eps = 1e-5):
    return t.clamp(min = eps).log()

def entropy(prob):
    return (-prob * log(prob)).sum(dim = -1)

class CosineSimLinear(Module):
    def __init__(
        self,
        dim_in,
        dim_out,
        scale = 1.
    ):
        super().__init__()
        self.scale = scale
        self.weight = nn.Parameter(torch.randn(dim_in, dim_out))

    def forward(self, x):
        x = F.normalize(x, dim = -1)
        w = F.normalize(self.weight, dim = 0)
        return (x @ w) * self.scale

def soft_entropy_loss(u, tau = 1.0, gamma = 1.0):
    """
    Compute the soft entropy loss for Binary Spherical Quantization (BSQ).

    Args:
        u (torch.Tensor): Input latent embeddings with the code dimension last, e.g. (batch_size, L).
        tau (float): Temperature scaling factor.
        gamma (float): Weight for the second entropy term.

    Returns:
        torch.Tensor: Soft entropy loss.
    """
    # implicit codebook corners, scaled so each code lies on the unit hypersphere

    L = u.size(-1)
    corners = torch.tensor([-1.0, 1.0], device = u.device) / (L ** 0.5)

    # soft assignment probabilities q(c|u) over the two corners, per dimension - shape (..., L, 2)

    prob_matrix = torch.sigmoid(2 * tau * u.unsqueeze(-1) * corners)

    # per-sample entropy E[H(q(c|u))] - encourages confident assignments
    # (clamped log for numerical stability)

    entropy_per_dim = -torch.sum(prob_matrix * log(prob_matrix), dim = -1)
    entropy_term1 = entropy_per_dim.mean()

    # entropy of the averaged assignment H(E[q(c|u)]) - encourages both corners to be used

    expected_probs = reduce(prob_matrix, '... l c -> l c', 'mean')
    entropy_term2 = -torch.sum(expected_probs * log(expected_probs), dim = -1).mean()

    loss = entropy_term1 - gamma * entropy_term2
    return loss
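
# A minimal usage sketch for the soft entropy loss (shapes below are illustrative assumptions):
#
#   latents = torch.randn(8, 16)                              # (batch, L) latent embeddings
#   loss = soft_entropy_loss(latents, tau = 1., gamma = 1.)   # scalar, differentiable w.r.t. latents
#   loss.backward()
#
# The loss is E[H(q(c|u))] - gamma * H(E[q(c|u)]), computed independently per dimension:
# the first term sharpens each soft assignment toward one of the two corners, while the
# subtracted second term spreads the batch-averaged assignment across both corners.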

class BinarySphericalQuantize(Module):
    def __init__(
        self,
        *,
        dim = None,
        codebook_size = None,
        entropy_loss_weight = 0.1,
        commitment_loss_weight = 0.,
        diversity_gamma = 1.,
        straight_through_activation = nn.Identity(),
        num_codebooks = 1,
        keep_num_codebooks_dim = None,
        codebook_scale = 1.,
        frac_per_sample_entropy = 0.25,             # when < 1., compute the per-sample entropy over a sampled subset of bit positions
        has_projections = None,
        projection_has_bias = True,
        soft_clamp_input_value = None,
        cosine_sim_project_in = False,
        cosine_sim_project_in_scale = None,
        channel_first = None,
        experimental_softplus_entropy_loss = False,
        entropy_loss_offset = 5.,                   # how much to shift the loss before the softplus
        spherical = True,                           # l2-normalize the codes onto the hypersphere (binary spherical quantization)
        force_quantization_f32 = True,              # force the quantization step to run in full precision
        enable_entropy_loss = True,
        soft_entropy_loss = True,
    ):
        super().__init__()

        assert exists(dim) or exists(codebook_size), 'either dim or codebook_size must be specified for LFQ'
        assert not exists(codebook_size) or log2(codebook_size).is_integer(), f'your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})'

        codebook_size = default(codebook_size, lambda: 2 ** dim)
        self.codebook_size = codebook_size

        codebook_dim = int(log2(codebook_size))
        codebook_dims = codebook_dim * num_codebooks
        dim = default(dim, codebook_dims)

        has_projections = default(has_projections, dim != codebook_dims)

        if cosine_sim_project_in:
            cosine_sim_project_in = default(cosine_sim_project_in_scale, codebook_scale)
            project_in_klass = partial(CosineSimLinear, scale = cosine_sim_project_in)
        else:
            project_in_klass = partial(nn.Linear, bias = projection_has_bias)

        self.project_in = project_in_klass(dim, codebook_dims) if has_projections else nn.Identity()
        self.project_out = nn.Linear(codebook_dims, dim, bias = projection_has_bias) if has_projections else nn.Identity()
        self.has_projections = has_projections

        self.dim = dim
        self.codebook_dim = codebook_dim
        self.num_codebooks = num_codebooks

        keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
        assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
        self.keep_num_codebooks_dim = keep_num_codebooks_dim

        # channel first

        self.channel_first = channel_first

        # straight through activation

        self.activation = straight_through_activation

        # whether to l2-normalize the codes (binary spherical quantization)

        self.spherical = spherical
        self.maybe_l2norm = (lambda t: l2norm(t) * self.codebook_scale) if spherical else identity

        # entropy aux loss related weights

        assert 0 < frac_per_sample_entropy <= 1.
        self.frac_per_sample_entropy = frac_per_sample_entropy

        self.diversity_gamma = diversity_gamma
        self.entropy_loss_weight = entropy_loss_weight

        # codebook scale

        self.codebook_scale = codebook_scale

        # commitment loss

        self.commitment_loss_weight = commitment_loss_weight

        # whether to soft clamp the input value from -value to value

        self.soft_clamp_input_value = soft_clamp_input_value
        assert not exists(soft_clamp_input_value) or soft_clamp_input_value >= codebook_scale

        # whether to make the entropy loss positive through a (shifted) softplus

        self.entropy_loss_offset = entropy_loss_offset
        self.experimental_softplus_entropy_loss = experimental_softplus_entropy_loss

        # bit mask for converting codes to indices, and a zero constant for dummy losses

        self.register_buffer('mask', 2 ** torch.arange(codebook_dim - 1, -1, -1))
        self.register_buffer('zero', torch.tensor(0.), persistent = False)

        # whether to force quantization step to be f32

        self.force_quantization_f32 = force_quantization_f32

        # entropy loss variants

        self.enable_entropy_loss = enable_entropy_loss
        self.soft_entropy_loss = soft_entropy_loss

        # materialize the implicit codebook for the (hard) entropy aux loss

        if codebook_size <= 100000:
            all_codes = torch.arange(codebook_size)
            bits = ((all_codes[..., None].int() & self.mask) != 0).float()
            codebook = self.bits_to_codes(bits)

            self.register_buffer('codebook', codebook.float(), persistent = False)
        else:
            # codebook too large to materialize in full - keep a 2^16 sub-codebook over 16 bit positions

            all_codes = torch.arange(pow(2, 16))
            mask = 2 ** torch.arange(16 - 1, -1, -1)
            bits = ((all_codes[..., None].int() & mask) != 0).float()
            codebook = self.bits_to_codes(bits)

            self.register_buffer('codebook', codebook.float(), persistent = False)

    def bits_to_codes(self, bits):
        return bits * self.codebook_scale * 2 - self.codebook_scale

    @property
    def dtype(self):
        return self.codebook.dtype

    def indices_to_codes(
        self,
        indices,
        project_out = True
    ):
        is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
        should_transpose = default(self.channel_first, is_img_or_video)

        if not self.keep_num_codebooks_dim:
            indices = rearrange(indices, '... -> ... 1')

        # indices to codes, which are bits of either -1 or 1

        bits = ((indices[..., None].int() & self.mask) != 0).to(self.dtype)

        codes = self.bits_to_codes(bits)

        codes = self.maybe_l2norm(codes)

        codes = rearrange(codes, '... c d -> ... (c d)')

        # whether to project codes out to original dimensions
        # if the input feature dimensions were not log2(codebook size)

        if project_out:
            codes = self.project_out(codes)

        # rearrange codes back to original shape

        if should_transpose:
            codes = rearrange(codes, 'b ... d -> b d ...')

        return codes
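
    # Worked example for indices_to_codes (illustrative values: codebook_dim = 4, codebook_scale = 1.):
    #   index 5 -> bitwise-and with mask [8, 4, 2, 1] -> bits [0, 1, 0, 1]
    #   bits_to_codes maps {0, 1} -> {-1., +1.}, giving codes [-1., +1., -1., +1.]
    #   with spherical = True, the codes are then l2-normalized to [-.5, +.5, -.5, +.5]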

    def bits_to_z(self, bits):
        assert torch.all(bits.abs() == 1)
        quantized = bits.float()
        quantized = self.maybe_l2norm(quantized)
        z = self.project_out(quantized)
        return z

    def forward(
        self,
        x,
        inv_temperature = 100.,
        return_loss_breakdown = False,
        mask = None,
        return_bits = False
    ):
        """
        einstein notation
        b - batch
        n - sequence (or flattened spatial dimensions)
        d - feature dimension, which is also log2(codebook size)
        c - number of codebooks
        """

        is_img_or_video = x.ndim >= 4
        should_transpose = default(self.channel_first, is_img_or_video)

        # standardize image or video into (batch, seq, dimension)

        if should_transpose:
            x = rearrange(x, 'b d ... -> b ... d')
            x, ps = pack_one(x, 'b * d')

        assert x.shape[-1] == self.dim, f'expected dimension of {self.dim} but received {x.shape[-1]}'

        x = self.project_in(x)

        # maybe soft clamp

        if exists(self.soft_clamp_input_value):
            clamp_value = self.soft_clamp_input_value
            x = (x / clamp_value).tanh() * clamp_value

        # split out number of codebooks

        x = rearrange(x, 'b n (c d) -> b n c d', c = self.num_codebooks)

        # maybe l2norm

        x = self.maybe_l2norm(x)

        # whether to force quantization step to be full precision or not

        force_f32 = self.force_quantization_f32

        quantization_context = partial(autocast, 'cuda', enabled = False) if force_f32 else nullcontext

        with quantization_context():

            if force_f32:
                orig_dtype = x.dtype
                x = x.float()

            original_input = x

            # quantize each dimension to the nearest corner {-codebook_scale, codebook_scale}

            codebook_value = torch.ones_like(x) * self.codebook_scale
            quantized = torch.where(x > 0, codebook_value, -codebook_value)

            if return_bits:
                return quantized

            # calculate indices by reading the sign pattern as a binary number

            indices = reduce((quantized > 0).int() * self.mask.int(), 'b n c d -> b n c', 'sum')

            # maybe l2norm

            quantized = self.maybe_l2norm(quantized)

            # straight-through estimator (optionally through a custom activation) when training

            if self.training:
                x = self.activation(x)
                x = x + (quantized - x).detach()
            else:
                x = quantized

            # entropy aux loss

            if self.soft_entropy_loss:
                entropy_aux_loss = soft_entropy_loss(x, tau = 1.0, gamma = 1.0)

                # the two terms of the soft loss are not tracked separately
                per_sample_entropy = codebook_entropy = self.zero

            elif self.training and self.enable_entropy_loss:
                # hard (lookup-free quantization) entropy loss, computed against the materialized codebook

                codebook = self.codebook

                if force_f32:
                    codebook = codebook.float()

                codebook = self.maybe_l2norm(codebook)

                # whether to only use a subset of the bit dimensions, for reducing memory

                if self.frac_per_sample_entropy < 1.:
                    # account for mask

                    if exists(mask):
                        original_input = original_input[mask]

                    original_input = rearrange(original_input, 'b n ... -> (b n) ...')

                    # sample (up to) 16 bit positions to score against the codebook

                    rand_mask = torch.randn(self.codebook_dim).argsort(dim = -1) < 16

                    sampled_input = original_input[..., rand_mask]

                    sampled_distance = -2 * einsum('... i d, j d -> ... i j', sampled_input, codebook)

                    sampled_prob = (-sampled_distance * inv_temperature).softmax(dim = -1)

                    per_sample_probs = sampled_prob
                else:
                    if exists(mask):
                        original_input = original_input[mask]

                    original_input = rearrange(original_input, 'b n ... -> (b n) ...')

                    # the same as euclidean distance up to a constant

                    distance = -2 * einsum('... i d, j d -> ... i j', original_input, codebook)

                    prob = (-distance * inv_temperature).softmax(dim = -1)

                    per_sample_probs = prob

                # calculate per-sample entropy

                per_sample_entropy = entropy(per_sample_probs).mean()

                # distribution over all available tokens in the batch

                avg_prob = reduce(per_sample_probs, '... c d -> c d', 'mean')

                avg_prob = maybe_distributed_mean(avg_prob)

                codebook_entropy = entropy(avg_prob).mean()

                # 1. per-sample entropy is pushed down, so each prediction commits to a code
                # 2. codebook entropy is pushed up, so all codes get used across the batch

                entropy_aux_loss = per_sample_entropy - self.diversity_gamma * codebook_entropy
            else:
                # when not training (or entropy loss disabled), return dummy zeros

                entropy_aux_loss = per_sample_entropy = codebook_entropy = self.zero

            # whether to make the entropy loss positive through a (shifted) softplus

            if self.training and self.experimental_softplus_entropy_loss:
                entropy_aux_loss = F.softplus(entropy_aux_loss + self.entropy_loss_offset)

            # commit loss

            if self.training and self.commitment_loss_weight > 0.:
                commit_loss = F.mse_loss(original_input, quantized.detach(), reduction = 'none')

                if exists(mask):
                    commit_loss = commit_loss[mask]

                commit_loss = commit_loss.mean()
            else:
                commit_loss = self.zero

            # cast back to original dtype if needed

            if force_f32:
                x = x.type(orig_dtype)

        # merge back codebook dim

        x = rearrange(x, 'b n c d -> b n (c d)')

        # project out to feature dimension if needed

        x = self.project_out(x)

        # reconstitute image or video dimensions

        if should_transpose:
            x = unpack_one(x, ps, 'b * d')
            x = rearrange(x, 'b ... d -> b d ...')

            indices = unpack_one(indices, ps, 'b * c')

        # whether to remove single codebook dim

        if not self.keep_num_codebooks_dim:
            indices = rearrange(indices, '... 1 -> ...')

        # complete aux loss

        aux_loss = entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight

        ret = Return(x, indices, aux_loss)

        if not return_loss_breakdown:
            return ret

        return ret, LossBreakdown(per_sample_entropy, codebook_entropy, commit_loss)

class GroupedResidualBSQ(Module):
    def __init__(
        self,
        *,
        dim,
        groups = 1,
        accept_image_fmap = False,
        **kwargs
    ):
        super().__init__()
        self.dim = dim
        self.groups = groups
        assert (dim % groups) == 0
        dim_per_group = dim // groups

        self.accept_image_fmap = accept_image_fmap

        self.rvqs = nn.ModuleList([])

        for _ in range(groups):
            self.rvqs.append(BinarySphericalQuantize(
                dim = dim_per_group,
                **kwargs
            ))

        self.codebook_size = self.rvqs[0].codebook_size

    @property
    def codebooks(self):
        return torch.stack(tuple(rvq.codebook for rvq in self.rvqs))

    @property
    def split_dim(self):
        return 1 if self.accept_image_fmap else -1

    def get_codes_from_indices(self, indices):
        # delegate to each quantizer's indices_to_codes, keeping the raw (unprojected) codes
        codes = tuple(rvq.indices_to_codes(chunk_indices, project_out = False) for rvq, chunk_indices in zip(self.rvqs, indices))
        return torch.stack(codes)

    def get_output_from_indices(self, indices):
        outputs = tuple(rvq.indices_to_codes(chunk_indices) for rvq, chunk_indices in zip(self.rvqs, indices))
        return torch.cat(outputs, dim = self.split_dim)

    def forward(
        self,
        x,
        return_all_codes = False
    ):
        shape, split_dim = x.shape, self.split_dim
        assert shape[split_dim] == self.dim

        # split the feature dimension into groups

        x = x.chunk(self.groups, dim = split_dim)

        forward_kwargs = dict()

        # invoke each quantizer on its group

        out = tuple(rvq(chunk, **forward_kwargs) for rvq, chunk in zip(self.rvqs, x))
        out = tuple(zip(*out))

        # combine the zipped outputs

        quantized, all_indices, *maybe_aux_loss = out

        quantized = torch.cat(quantized, dim = split_dim)
        all_indices = torch.stack(all_indices)

        ret = (quantized, all_indices, *maybe_aux_loss)
        return ret
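
# Minimal smoke-test sketch (illustrative hyperparameters, not canonical settings):
# runs a random batch through BinarySphericalQuantize and checks the round-trip shapes.

if __name__ == '__main__':
    quantizer = BinarySphericalQuantize(
        dim = 256,                  # input feature dimension, projected down to log2(codebook_size) bits
        codebook_size = 2 ** 16,    # 16 bits per token
        entropy_loss_weight = 0.1
    )

    features = torch.randn(2, 1024, 256)    # (batch, seq, dim)

    quantized, indices, aux_loss = quantizer(features)

    assert quantized.shape == features.shape
    assert indices.shape == (2, 1024)

    print(f'entropy aux loss: {aux_loss.item():.4f}')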