import torch
import torch.nn as nn

from modules.audio_tokenizer.quantize import ResidualVQ
from modules.audio_tokenizer.vocos import VocosBackbone
from modules.audio_tokenizer.transformer import TransformerEncoder


def init_weights(m):
    """Initialise a ``Conv1d`` or ``Linear`` module in place.

    Weights are drawn from a truncated normal with ``std=0.02`` and the bias,
    when the layer has one, is set to zero. Any other module type is left
    untouched, which makes the function safe to pass to ``nn.Module.apply``.
    """
    if isinstance(m, (nn.Conv1d, nn.Linear)):
        nn.init.trunc_normal_(m.weight, std=0.02)
        # Layers built with ``bias=False`` expose ``bias`` as ``None``.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)


class RepCodec(nn.Module):
    """Encoder / residual-VQ / decoder codec with optional timbre conditioning.

    The input is encoded by a ``VocosBackbone`` followed by a linear layer,
    quantized by a ``ResidualVQ`` and decoded by a second backbone + linear
    stack. When ``use_timbre_encoder`` is true, the encoder output is layer
    normalised before quantization, and the quantized features are normalised
    and then scaled/shifted by a ``(gamma, beta)`` pair predicted from the
    time-averaged output of a small ``TransformerEncoder``.

    Args:
        codebook_size: Forwarded to ``ResidualVQ`` as ``codebook_size``.
        hidden_size: Channel size of the codec input/output; also the
            quantizer's ``input_dim``.
        codebook_dim: Forwarded to ``ResidualVQ`` as ``codebook_dim``.
        vocos_dim: Stored on the instance (see NOTE in ``_build_backbone``).
        vocos_intermediate_dim: Stored on the instance (see NOTE).
        vocos_num_layers: Stored on the instance (see NOTE).
        num_quantizers: Forwarded to ``ResidualVQ`` as ``num_quantizers``.
        use_timbre_encoder: Enables the timbre branch described above.
        cfg: Optional config object. Any attribute on it whose name matches
            one of the arguments above overrides that argument.
    """

    def __init__(
        self,
        codebook_size=8192,
        hidden_size=1024,
        codebook_dim=8,
        vocos_dim=384,
        vocos_intermediate_dim=2048,
        vocos_num_layers=12,
        num_quantizers=1,
        use_timbre_encoder=False,
        cfg=None,
    ):
        super().__init__()

        def from_cfg(name, default):
            # Each setting is looked up under its OWN name; a missing
            # attribute (or ``cfg is None``) falls back to the argument.
            if cfg is None:
                return default
            return getattr(cfg, name, default)

        self.codebook_size = from_cfg("codebook_size", codebook_size)
        self.codebook_dim = from_cfg("codebook_dim", codebook_dim)
        self.hidden_size = from_cfg("hidden_size", hidden_size)
        self.vocos_dim = from_cfg("vocos_dim", vocos_dim)
        self.vocos_intermediate_dim = from_cfg(
            "vocos_intermediate_dim", vocos_intermediate_dim
        )
        self.vocos_num_layers = from_cfg("vocos_num_layers", vocos_num_layers)
        self.num_quantizers = from_cfg("num_quantizers", num_quantizers)
        self.use_timbre_encoder = from_cfg(
            "use_timbre_encoder", use_timbre_encoder
        )

        # Construction order (encoder, decoder, quantizer, timbre modules) is
        # kept stable so parameter names and RNG consumption do not change.
        self.encoder = self._build_backbone()
        self.decoder = self._build_backbone()

        self.quantizer = ResidualVQ(
            input_dim=self.hidden_size,
            num_quantizers=self.num_quantizers,
            codebook_size=self.codebook_size,
            codebook_dim=self.codebook_dim,
            quantizer_type="fvq",
            quantizer_dropout=0.0,
            commitment=0.15,
            codebook_loss_weight=1.0,
            use_l2_normlize=True,
        )

        if self.use_timbre_encoder:
            # TODO: write encoder hidden (256) as a hyparam
            timbre_hidden = 256
            self.timbre_in = nn.Linear(self.hidden_size, timbre_hidden)
            self.timbre_encoder = TransformerEncoder(
                enc_emb_tokens=None,
                encoder_layer=4,
                encoder_hidden=timbre_hidden,
                encoder_head=4,
                conv_filter_size=1024,
                conv_kernel_size=5,
                encoder_dropout=0.1,
                use_pe=False,
                cfg=None,
            )
            self.timbre_out = nn.Linear(timbre_hidden, self.hidden_size)
            # Predicts 2 * hidden_size values: first half gamma, second beta.
            # Its bias is initialised in ``reset_parameters``.
            self.timbre_linear = nn.Linear(
                self.hidden_size, self.hidden_size * 2
            )
            self.timbre_norm = nn.LayerNorm(
                self.hidden_size, elementwise_affine=False
            )
            self.enc_ln = nn.LayerNorm(
                self.hidden_size, elementwise_affine=False
            )

        self.reset_parameters()

    def _build_backbone(self):
        """Build one backbone + projection stack (used for encoder and decoder)."""
        # NOTE(review): the backbone sizes are fixed at 384 / 2048 / 12 and do
        # not follow ``self.vocos_dim``, ``self.vocos_intermediate_dim`` or
        # ``self.vocos_num_layers``. Kept as-is because changing it could
        # break loading of existing checkpoints - confirm before wiring the
        # settings through.
        return nn.Sequential(
            VocosBackbone(
                input_channels=self.hidden_size,
                dim=384,
                intermediate_dim=2048,
                num_layers=12,
                adanorm_num_embeddings=None,
            ),
            nn.Linear(384, self.hidden_size),
        )

    def _encode(self, x):
        """Run the encoder and return ``(quantizer_input, raw_encoding)``.

        ``raw_encoding`` is the encoder output with dims 1 and 2 swapped back.
        ``quantizer_input`` is the same tensor, additionally layer-normalised
        over the channel axis when the timbre encoder is enabled.
        """
        encoded = self.encoder(x.transpose(1, 2)).transpose(1, 2)
        if not self.use_timbre_encoder:
            return encoded, encoded
        # LayerNorm acts on the last axis, hence the transpose round trip.
        normed = self.enc_ln(encoded.transpose(1, 2)).transpose(1, 2)
        return normed, encoded

    def forward(self, x):
        """Encode, quantize and reconstruct ``x``.

        Args:
            x: Input features. Presumably ``(batch, time, hidden_size)`` -
                TODO confirm against ``VocosBackbone``'s expected layout.

        Returns:
            A tuple ``(x_rec, codebook_loss, all_indices)`` where ``x_rec`` is
            the decoder output, ``codebook_loss`` is the mean of the summed
            codebook and commitment losses from the quantizer, and
            ``all_indices`` are the code indices returned by the quantizer.
        """
        quantizer_input, raw_encoding = self._encode(x)

        (
            quantized_out,
            all_indices,
            all_commit_losses,
            all_codebook_losses,
            _,
        ) = self.quantizer(quantizer_input)

        if self.use_timbre_encoder:
            # The timbre branch reads the encoding from BEFORE enc_ln.
            timbre = self.timbre_in(raw_encoding.transpose(1, 2))
            timbre = self.timbre_encoder(timbre, None, None)
            timbre = self.timbre_out(timbre).transpose(1, 2)

            # Average over the last axis to get one embedding per item.
            spk_embs = torch.mean(timbre, dim=2)
            style = self.timbre_linear(spk_embs).unsqueeze(2)  # (B, 2d, 1)
            gamma, beta = style.chunk(2, 1)  # (B, d, 1)

            quantized_out = self.timbre_norm(
                quantized_out.transpose(1, 2)
            ).transpose(1, 2)
            # gamma / beta broadcast along the last axis.
            quantized_out = quantized_out * gamma + beta

        x_rec = self.decoder(quantized_out)
        codebook_loss = (all_codebook_losses + all_commit_losses).mean()

        return x_rec, codebook_loss, all_indices

    def quantize(self, x):
        """Encode ``x`` and return its code indices and quantized features.

        Timbre conditioning is not applied here; only the optional encoder
        layer norm is.

        Returns:
            A tuple ``(all_indices, quantized)`` where ``quantized`` is the
            quantizer output with dims 1 and 2 swapped. When the first
            dimension of ``all_indices`` has size 1 it is squeezed away.
        """
        quantizer_input, _ = self._encode(x)
        quantized_out, all_indices, _, _, _ = self.quantizer(quantizer_input)
        quantized = quantized_out.transpose(1, 2)

        if all_indices.shape[0] == 1:
            return all_indices.squeeze(0), quantized
        return all_indices, quantized

    def reset_parameters(self):
        """Re-initialise all ``Conv1d`` / ``Linear`` layers of the model.

        ``init_weights`` zeroes every linear bias, so the gamma=1 / beta=0
        bias of ``timbre_linear`` is (re)applied afterwards; otherwise the
        timbre modulation would start out multiplying features by ~0.
        """
        self.apply(init_weights)
        if self.use_timbre_encoder:
            with torch.no_grad():
                self.timbre_linear.bias[: self.hidden_size].fill_(1)
                self.timbre_linear.bias[self.hidden_size :].fill_(0)