#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
Unidirectional deep FSMN (feedforward sequential memory network) layer.

References:
https://github.com/modelscope/modelscope/blob/master/modelscope/models/audio/ans/layers/uni_deep_fsmn.py
https://huggingface.co/spaces/alibabasglab/ClearVoice/blob/main/models/mossformer2_se/fsmn.py
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


class UniDeepFsmn(nn.Module):
    """Unidirectional (causal) deep FSMN block with a residual connection.

    Each frame is passed through ``Linear -> ReLU -> Linear (no bias)``. The
    projected sequence is then filtered along the time axis by a depthwise
    convolution with ``lorder`` taps that only sees the current and past
    frames. The filtered memory is added to the projection, and the result is
    added back to the input.

    :param input_dim: feature size ``h`` of the input and output.
    :param hidden_size: width of the intermediate ReLU layer.
    :param lorder: number of taps of the memory filter (the current frame plus
        ``lorder - 1`` past frames); must be >= 1.
    :raises ValueError: if ``lorder`` is smaller than 1.
    """

    def __init__(self,
                 input_dim: int,
                 hidden_size: int,
                 lorder: int = 1,
                 ):
        super().__init__()
        if lorder < 1:
            # A zero/negative kernel length yields an empty conv weight and a
            # negative (cropping) pad in forward(); reject it up front.
            raise ValueError(f"lorder must be >= 1, got {lorder}")
        self.input_dim = input_dim
        self.hidden_size = hidden_size
        self.lorder = lorder

        self.linear = nn.Linear(input_dim, hidden_size)
        self.project = nn.Linear(hidden_size, input_dim, bias=False)
        # Depthwise filter (groups=input_dim): one kernel of length `lorder`
        # over the time axis for each feature channel.
        self.conv1 = nn.Conv2d(
            input_dim, input_dim,
            kernel_size=(lorder, 1), stride=(1, 1),
            groups=input_dim, bias=False,
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Apply the FSMN block.

        :param inputs: torch.Tensor, shape: [b, t, h]
        :return: torch.Tensor, shape: [b, t, h]
        """
        x = F.relu(self.linear(inputs))
        x = self.project(x)
        x = torch.unsqueeze(x, 1)
        # x shape: [b, 1, t, h]
        x = x.permute(0, 3, 2, 1)
        # x shape: [b, h, t, 1]
        # Left-pad the time axis by lorder - 1 so the convolution is causal and
        # preserves length t: output frame i depends on frames i-lorder+1 .. i.
        y = F.pad(x, [0, 0, self.lorder - 1, 0])
        x = x + self.conv1(y)
        x = x.permute(0, 3, 2, 1)
        # x shape: [b, 1, t, h]
        # Remove only the singleton channel axis. A bare squeeze() would also
        # drop b, t or h whenever one of them is 1, and the residual add below
        # would then broadcast to a wrong shape.
        x = x.squeeze(1)

        result = inputs + x
        return result


def main():
    """Smoke test: run a random [1, 200, 32] batch through the layer and print the output shape."""
    x = torch.rand(size=(1, 200, 32))

    fsmn = UniDeepFsmn(
        input_dim=32,
        hidden_size=64,
        lorder=3,
    )
    # Call the module rather than .forward() so nn.Module hooks are honoured.
    result = fsmn(x)
    print(result.shape)


if __name__ == "__main__":
    main()