# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import numpy as np
import torch.nn as nn
from functools import partial
import torch.nn.functional as F
from typing import Callable, Dict

from funasr_detach.models.emotion2vec.fairseq_modules import (
    LayerNorm,
    SamePad,
    TransposeLast,
    ConvFeatureExtractionModel,
)
from funasr_detach.models.emotion2vec.modules import Modality, BlockEncoder, Decoder1d
from funasr_detach.models.emotion2vec.base import (
    ModalitySpecificEncoder,
    get_alibi_bias,
)


class AudioEncoder(ModalitySpecificEncoder):
    """Audio-specific encoder: a convolutional waveform feature extractor,
    a convolutional relative positional encoder, and a stack of pre-net
    blocks built by ``make_block``, wired together via ModalitySpecificEncoder."""

    def __init__(
        self,
        modality_cfg,
        embed_dim: int,
        make_block: Callable[[float], nn.ModuleList],
        norm_layer: Callable[[int], nn.LayerNorm],
        layer_norm_first: bool,
        alibi_biases: Dict,
    ):
        # feature_encoder_spec is a string like "[(dim, kernel, stride), ...]"
        # describing the layers of the conv feature extractor
        self.feature_enc_layers = eval(modality_cfg.feature_encoder_spec)
        feature_embed_dim = self.feature_enc_layers[-1][0]

        local_encoder = ConvFeatureExtractionModel(
            conv_layers=self.feature_enc_layers,
            dropout=0.0,
            mode=modality_cfg.extractor_mode,
            conv_bias=False,
        )

        # project conv features (B, C, T) -> (B, T, embed_dim)
        project_features = nn.Sequential(
            TransposeLast(),
            nn.LayerNorm(feature_embed_dim),
            nn.Linear(feature_embed_dim, embed_dim),
        )

        num_pos_layers = modality_cfg.conv_pos_depth
        k = max(3, modality_cfg.conv_pos_width // num_pos_layers)

        # convolutional relative positional encoder built from grouped 1d convs
        positional_encoder = nn.Sequential(
            TransposeLast(),
            *[
                nn.Sequential(
                    nn.Conv1d(
                        embed_dim,
                        embed_dim,
                        kernel_size=k,
                        padding=k // 2,
                        groups=modality_cfg.conv_pos_groups,
                    ),
                    SamePad(k),
                    TransposeLast(),
                    LayerNorm(embed_dim, elementwise_affine=False),
                    TransposeLast(),
                    nn.GELU(),
                )
                for _ in range(num_pos_layers)
            ],
            TransposeLast(),
        )

        if modality_cfg.conv_pos_pre_ln:
            positional_encoder = nn.Sequential(LayerNorm(embed_dim), positional_encoder)

        # per-block drop-path rates for the pre-net blocks
        dpr = np.linspace(
            modality_cfg.start_drop_path_rate,
            modality_cfg.end_drop_path_rate,
            modality_cfg.prenet_depth,
        )

        context_encoder = BlockEncoder(
            nn.ModuleList(make_block(dpr[i]) for i in range(modality_cfg.prenet_depth)),
            norm_layer(embed_dim) if not layer_norm_first else None,
            layer_norm_first,
            modality_cfg.prenet_layerdrop,
            modality_cfg.prenet_dropout,
        )

        decoder = (
            Decoder1d(modality_cfg.decoder, embed_dim)
            if modality_cfg.decoder is not None
            else None
        )

        alibi_bias_fn = partial(get_alibi_bias, alibi_biases=alibi_biases)

        super().__init__(
            modality_cfg=modality_cfg,
            embed_dim=embed_dim,
            local_encoder=local_encoder,
            project_features=project_features,
            fixed_positional_encoder=None,
            relative_positional_encoder=positional_encoder,
            context_encoder=context_encoder,
            decoder=decoder,
            get_alibi_bias=alibi_bias_fn,
        )

    def convert_padding_mask(self, x, padding_mask):
        def get_feat_extract_output_lengths(input_lengths: torch.LongTensor):
            """
            Computes the output length of the convolutional layers
            """

            def _conv_out_length(input_length, kernel_size, stride):
                return torch.floor((input_length - kernel_size) / stride + 1)

            for i in range(len(self.feature_enc_layers)):
                input_lengths = _conv_out_length(
                    input_lengths,
                    self.feature_enc_layers[i][1],
                    self.feature_enc_layers[i][2],
                )

            return input_lengths.to(torch.long)

        if padding_mask is not None:
            input_lengths = (1 - padding_mask.long()).sum(-1)

            # apply conv formula to get real output_lengths
            output_lengths = get_feat_extract_output_lengths(input_lengths)

            if padding_mask.any():
                padding_mask = torch.zeros(x.shape[:2], dtype=x.dtype, device=x.device)

                # these two operations make sure that all values
                # before the output length indices are attended to
                padding_mask[
                    (
                        torch.arange(padding_mask.shape[0], device=padding_mask.device),
                        output_lengths - 1,
                    )
                ] = 1
                padding_mask = (
                    1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])
                ).bool()
            else:
                padding_mask = torch.zeros(
                    x.shape[:2], dtype=torch.bool, device=x.device
                )

        return padding_mask

    def reset_parameters(self):
        super().reset_parameters()
        for mod in self.project_features.children():
            if isinstance(mod, nn.Linear):
                mod.reset_parameters()
        if self.decoder is not None:
            self.decoder.reset_parameters()
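

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): how the conv
# output-length formula in convert_padding_mask maps raw waveform lengths to
# feature-frame counts. The spec below is an assumed wav2vec 2.0-style
# default; the real value comes from modality_cfg.feature_encoder_spec.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    _example_spec = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2

    def _example_out_length(num_samples: int) -> int:
        length = num_samples
        for _dim, kernel_size, stride in _example_spec:
            # same formula as _conv_out_length above, in integer arithmetic
            length = (length - kernel_size) // stride + 1
        return length

    # one second of 16 kHz audio collapses to roughly 50 feature frames
    print(_example_out_length(16000))  # -> 49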