naonauno's picture
Upload 855 files
d66c48f verified
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
#
# This source code is licensed under the CC BY-NC license found in the
# LICENSE file in the root directory of this source tree.
# Based on AudioDec (https://github.com/facebookresearch/AudioDec)
import torch
import torch.nn as nn
import torch.nn.functional as F
class VectorQuantize(nn.Module):
"""Vector quantization w/ exponential moving averages (EMA)"""
def __init__(
self,
dim: int,
codebook_size: int,
decay=0.8,
commitment=1.0,
eps=1e-5,
n_embed=None,
):
super().__init__()
n_embed = self.default(n_embed, codebook_size)
self.dim = dim
self.n_embed = n_embed
self.decay = decay
self.eps = eps
self.commitment = commitment
embed = torch.randn(dim, n_embed)
self.register_buffer("embed", embed)
self.register_buffer("cluster_size", torch.zeros(n_embed))
self.register_buffer("embed_avg", embed.clone())
@property
def codebook(self):
return self.embed.transpose(0, 1)
def exists(self, val):
return val is not None
def default(self, val, d):
return val if self.exists(val) else d
def ema_inplace(self, moving_avg, new, decay):
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
def laplace_smoothing(self, x, n_categories, eps=1e-5):
return (x + eps) / (x.sum() + n_categories * eps)
def forward(self, input):
dtype = input.dtype
flatten = input.reshape(-1, self.dim)
dist = (
flatten.pow(2).sum(1, keepdim=True)
- 2 * flatten @ self.embed
+ self.embed.pow(2).sum(0, keepdim=True)
)
_, embed_ind = (-dist).max(1)
embed_onehot = F.one_hot(embed_ind, self.n_embed).type(dtype)
embed_ind = embed_ind.view(*input.shape[:-1])
quantize = F.embedding(embed_ind, self.embed.transpose(0, 1))
if self.training:
self.ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
embed_sum = flatten.transpose(0, 1) @ embed_onehot
self.ema_inplace(self.embed_avg, embed_sum, self.decay)
cluster_size = (
self.laplace_smoothing(self.cluster_size, self.n_embed, self.eps)
* self.cluster_size.sum()
)
embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
self.embed.data.copy_(embed_normalized)
loss = F.mse_loss(quantize.detach(), input) * self.commitment
quantize = input + (quantize - input).detach()
avg_probs = torch.mean(embed_onehot, dim=0)
perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
return quantize, loss, perplexity
def forward_index(self, input):
dtype = input.dtype
flatten = input.reshape(-1, self.dim)
dist = (
flatten.pow(2).sum(1, keepdim=True)
- 2 * flatten @ self.embed
+ self.embed.pow(2).sum(0, keepdim=True)
)
_, embed_ind = (-dist).max(1)
embed_onehot = F.one_hot(embed_ind, self.n_embed).type(dtype)
embed_ind = embed_ind.view(*input.shape[:-1])
quantize = F.embedding(embed_ind, self.embed.transpose(0, 1))
quantize = input + (quantize - input).detach()
return quantize, embed_ind
class ResidualVQ(nn.Module):
"""Residual VQ following algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf"""
def __init__(self, *, num_quantizers, **kwargs):
super().__init__()
self.layers = nn.ModuleList(
[VectorQuantize(**kwargs) for _ in range(num_quantizers)]
)
def forward(self, x):
quantized_out = 0.0
residual = x
all_losses = []
all_perplexities = []
for layer in self.layers:
quantized, loss, perplexity = layer(residual)
# Issue: https://github.com/lucidrains/vector-quantize-pytorch/issues/33
# We found considering only the 1st layer VQ's graident results in better performance
# residual = residual - quantized.detach() # considering all layers' graidents
residual = (
residual - quantized
) # considering only the first layer's graident
quantized_out = quantized_out + quantized
all_losses.append(loss)
all_perplexities.append(perplexity)
all_losses, all_perplexities = map(torch.stack, (all_losses, all_perplexities))
return quantized_out, all_losses, all_perplexities
def forward_index(self, x, flatten_idx=False):
"""
all_indices: [num_of_quantizers, B, T]
"""
quantized_out = 0.0
residual = x
all_indices = []
for i, layer in enumerate(self.layers):
quantized, indices = layer.forward_index(residual)
# residual = residual - quantized.detach()
residual = residual - quantized
quantized_out = quantized_out + quantized
if flatten_idx:
indices += self.codebook_size * i
all_indices.append(indices)
all_indices = torch.stack(all_indices)
return quantized_out, all_indices
def initial(self):
self.codebook = []
for layer in self.layers:
self.codebook.append(layer.codebook)
self.codebook_size = self.codebook[0].size(0)
self.codebook = torch.stack(self.codebook)
self.codebook = self.codebook.reshape(-1, self.codebook.size(-1))
def lookup(self, indices):
quantized_out = F.embedding(indices, self.codebook) # Num x T x C
return torch.sum(quantized_out, dim=0, keepdim=True)
class Quantizer(nn.Module):
def __init__(
self,
code_dim: int,
codebook_num: int,
codebook_size: int,
):
super().__init__()
self.codebook = ResidualVQ(
dim=code_dim, num_quantizers=codebook_num, codebook_size=codebook_size
)
def initial(self):
self.codebook.initial()
def forward(self, z):
zq, vqloss, perplexity = self.codebook(z.transpose(2, 1))
zq = zq.transpose(2, 1)
return zq, vqloss, perplexity
def inference(self, z):
zq, indices = self.codebook.forward_index(z.transpose(2, 1))
zq = zq.transpose(2, 1)
return zq, indices
def encode(self, z):
zq, indices = self.codebook.forward_index(z.transpose(2, 1), flatten_idx=True)
return zq, indices
def decode(self, indices):
z = self.codebook.lookup(indices)
return z
class Conv1d1x1(nn.Conv1d):
"""1x1 Conv1d."""
def __init__(self, in_channels, out_channels, bias=True):
super(Conv1d1x1, self).__init__(
in_channels, out_channels, kernel_size=1, bias=bias
)
class Conv1d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
padding: int = -1,
dilation: int = 1,
groups: int = 1,
bias: bool = True,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
if padding < 0:
padding = (kernel_size - 1) // 2 * dilation
self.dilation = dilation
self.conv = nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
)
def forward(self, x):
"""
Args:
x (Tensor): Float tensor variable with the shape (B, C, T).
Returns:
Tensor: Float tensor variable with the shape (B, C, T).
"""
x = self.conv(x)
return x
class ConvTranspose1d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int,
padding=-1,
output_padding=-1,
groups=1,
bias=True,
):
super().__init__()
if padding < 0:
padding = (stride + 1) // 2
if output_padding < 0:
output_padding = 1 if stride % 2 else 0
self.deconv = nn.ConvTranspose1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
bias=bias,
)
def forward(self, x):
"""
Args:
x (Tensor): Float tensor variable with the shape (B, C, T).
Returns:
Tensor: Float tensor variable with the shape (B, C', T').
"""
x = self.deconv(x)
return x
class ResidualUnit(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size=3,
dilation=1,
bias=False,
nonlinear_activation="ELU",
nonlinear_activation_params={},
):
super().__init__()
self.activation = getattr(nn, nonlinear_activation)(
**nonlinear_activation_params
)
self.conv1 = Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=1,
dilation=dilation,
bias=bias,
)
self.conv2 = Conv1d1x1(out_channels, out_channels, bias)
def forward(self, x):
y = self.conv1(self.activation(x))
y = self.conv2(self.activation(y))
return x + y
class Projector(nn.Module):
def __init__(
self, input_channels: int, code_dim: int, kernel_size=3, stride=1, bias=False
):
super().__init__()
self.project = Conv1d(
input_channels, code_dim, kernel_size=kernel_size, stride=stride, bias=bias
)
def forward(self, x):
return self.project(x)
class EncoderBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
stride: int,
dilations=(1, 1),
unit_kernel_size=3,
bias=True,
):
super().__init__()
self.res_units = torch.nn.ModuleList()
for dilation in dilations:
self.res_units += [
ResidualUnit(
in_channels,
in_channels,
kernel_size=unit_kernel_size,
dilation=dilation,
)
]
self.num_res = len(self.res_units)
self.conv = Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=(
3 if stride == 1 else (2 * stride)
), # special case: stride=1, do not use kernel=2
stride=stride,
bias=bias,
)
def forward(self, x):
for idx in range(self.num_res):
x = self.res_units[idx](x)
x = self.conv(x)
return x
class Encoder(nn.Module):
def __init__(
self,
input_channels: int,
encode_channels: int,
channel_ratios=(1, 1),
strides=(1, 1),
kernel_size=3,
bias=True,
block_dilations=(1, 1),
unit_kernel_size=3,
):
super().__init__()
assert len(channel_ratios) == len(strides)
self.conv = Conv1d(
in_channels=input_channels,
out_channels=encode_channels,
kernel_size=kernel_size,
stride=1,
bias=False,
)
self.conv_blocks = torch.nn.ModuleList()
in_channels = encode_channels
for idx, stride in enumerate(strides):
out_channels = int(encode_channels * channel_ratios[idx]) # could be float
self.conv_blocks += [
EncoderBlock(
in_channels,
out_channels,
stride,
dilations=block_dilations,
unit_kernel_size=unit_kernel_size,
bias=bias,
)
]
in_channels = out_channels
self.num_blocks = len(self.conv_blocks)
self.out_channels = out_channels
def forward(self, x):
x = self.conv(x)
for i in range(self.num_blocks):
x = self.conv_blocks[i](x)
return x
class DecoderBlock(nn.Module):
"""Decoder block (no up-sampling)"""
def __init__(
self,
in_channels: int,
out_channels: int,
stride: int,
dilations=(1, 1),
unit_kernel_size=3,
bias=True,
):
super().__init__()
if stride == 1:
self.conv = Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3, # fix kernel=3 when stride=1 for unchanged shape
stride=stride,
bias=bias,
)
else:
self.conv = ConvTranspose1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=(2 * stride),
stride=stride,
bias=bias,
)
self.res_units = torch.nn.ModuleList()
for idx, dilation in enumerate(dilations):
self.res_units += [
ResidualUnit(
out_channels,
out_channels,
kernel_size=unit_kernel_size,
dilation=dilation,
)
]
self.num_res = len(self.res_units)
def forward(self, x):
x = self.conv(x)
for idx in range(self.num_res):
x = self.res_units[idx](x)
return x
class Decoder(nn.Module):
def __init__(
self,
code_dim: int,
output_channels: int,
decode_channels: int,
channel_ratios=(1, 1),
strides=(1, 1),
kernel_size=3,
bias=True,
block_dilations=(1, 1),
unit_kernel_size=3,
):
super().__init__()
assert len(channel_ratios) == len(strides)
self.conv1 = Conv1d(
in_channels=code_dim,
out_channels=int(decode_channels * channel_ratios[0]),
kernel_size=kernel_size,
stride=1,
bias=False,
)
self.conv_blocks = torch.nn.ModuleList()
for idx, stride in enumerate(strides):
in_channels = int(decode_channels * channel_ratios[idx])
if idx < (len(channel_ratios) - 1):
out_channels = int(decode_channels * channel_ratios[idx + 1])
else:
out_channels = decode_channels
self.conv_blocks += [
DecoderBlock(
in_channels,
out_channels,
stride,
dilations=block_dilations,
unit_kernel_size=unit_kernel_size,
bias=bias,
)
]
self.num_blocks = len(self.conv_blocks)
self.conv2 = Conv1d(out_channels, output_channels, kernel_size, 1, bias=False)
def forward(self, z):
x = self.conv1(z)
for i in range(self.num_blocks):
x = self.conv_blocks[i](x)
x = self.conv2(x)
return x
class VevoRepCodec(nn.Module):
def __init__(
self,
input_channels=768,
output_channels=768,
encode_channels=768,
decode_channels=768,
code_dim=768,
codebook_num=1,
codebook_size=1024,
bias=True,
enc_ratios=(1, 1),
dec_ratios=(1, 1),
enc_strides=(1, 1),
dec_strides=(1, 1),
enc_kernel_size=3,
dec_kernel_size=3,
enc_block_dilations=(1, 1),
enc_block_kernel_size=3,
dec_block_dilations=(1, 1),
dec_block_kernel_size=3,
):
super().__init__()
self.input_channels = input_channels
self.encoder = Encoder(
input_channels=input_channels,
encode_channels=encode_channels,
channel_ratios=enc_ratios,
strides=enc_strides,
kernel_size=enc_kernel_size,
bias=bias,
block_dilations=enc_block_dilations,
unit_kernel_size=enc_block_kernel_size,
)
self.decoder = Decoder(
code_dim=code_dim,
output_channels=output_channels,
decode_channels=decode_channels,
channel_ratios=dec_ratios,
strides=dec_strides,
kernel_size=dec_kernel_size,
bias=bias,
block_dilations=dec_block_dilations,
unit_kernel_size=dec_block_kernel_size,
)
self.projector = Projector(
input_channels=self.encoder.out_channels,
code_dim=code_dim,
kernel_size=3,
stride=1,
bias=False,
)
self.quantizer = Quantizer(
code_dim=code_dim, codebook_num=codebook_num, codebook_size=codebook_size
)
def forward(self, x):
x = self.encoder(x)
z = self.projector(x)
zq, vqloss, perplexity = self.quantizer(z)
y = self.decoder(zq)
return y, zq, z, vqloss, perplexity