# Scraped page header (not code): author marlenezw, commit 22257c4,
# "changing face alignment and removing its docker file."
import torch
import torch.nn as nn
import torch.nn.functional as F
# Layer sizes shared by every module below:
#   dim_enc  - channel width of the encoder conv stack (also its LSTM input size)
#   dim_freq - per-frame width of the mel input/output (encoder input, decoder
#              projection output, postnet in/out channels)
#   dim_f0   - per-frame width of the f0 features concatenated into the decoder input
#   num_grp  - number of groups for the 512-channel GroupNorm layers
#   dim_dec  - hidden size of the decoder LSTM
dim_enc, dim_freq, dim_f0, num_grp, dim_dec = 512, 80, 257, 32, 512
class LinearNorm(torch.nn.Module):
    """Fully connected layer whose weight is Xavier-uniform initialised.

    Args:
        in_dim: size of each input sample.
        out_dim: size of each output sample.
        bias: whether the wrapped ``torch.nn.Linear`` has a bias term.
        w_init_gain: nonlinearity name given to ``calculate_gain``; the
            resulting gain scales the Xavier initialisation.
    """

    def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
        super().__init__()
        layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
        gain = torch.nn.init.calculate_gain(w_init_gain)
        torch.nn.init.xavier_uniform_(layer.weight, gain=gain)
        self.linear_layer = layer

    def forward(self, x):
        """Apply the wrapped linear map to ``x``."""
        return self.linear_layer(x)
class ConvNorm(torch.nn.Module):
    """1-d convolution whose weight is Xavier-uniform initialised.

    When ``padding`` is None it is derived from ``kernel_size`` and
    ``dilation`` so that, for stride 1, the output length equals the input
    length; this requires an odd ``kernel_size``.

    Args:
        in_channels, out_channels, kernel_size, stride, dilation, bias:
            forwarded to ``torch.nn.Conv1d``.
        padding: explicit padding, or None to compute symmetric padding.
        w_init_gain: nonlinearity name given to ``calculate_gain``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, dilation=1, bias=True, w_init_gain='linear'):
        super().__init__()
        if padding is None:
            # Symmetric "same" padding only exists for odd kernel sizes.
            assert kernel_size % 2 == 1
            padding = int(dilation * (kernel_size - 1) / 2)
        self.conv = torch.nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )
        torch.nn.init.xavier_uniform_(
            self.conv.weight,
            gain=torch.nn.init.calculate_gain(w_init_gain),
        )

    def forward(self, signal):
        """Convolve ``signal`` with the wrapped ``Conv1d``."""
        return self.conv(signal)
class Encoder(nn.Module):
    """Encoder module: conv stack + bidirectional LSTM, downsampled into codes.

    Args:
        dim_neck: hidden size of each LSTM direction; every emitted code has
            width ``2 * dim_neck``.
        dim_emb: extra input channels on top of ``dim_freq`` (the first
            convolution takes ``dim_freq + dim_emb`` channels).
        freq: downsampling factor; one code is emitted per ``freq`` frames.
    """

    def __init__(self, dim_neck, dim_emb, freq):
        super(Encoder, self).__init__()
        self.dim_neck = dim_neck
        self.freq = freq

        convolutions = []
        for i in range(3):
            conv_layer = nn.Sequential(
                ConvNorm(dim_freq + dim_emb if i == 0 else dim_enc,
                         dim_enc,
                         kernel_size=5, stride=1,
                         padding=2,
                         dilation=1, w_init_gain='relu'),
                nn.GroupNorm(num_grp, dim_enc))
            convolutions.append(conv_layer)
        self.convolutions = nn.ModuleList(convolutions)

        self.lstm = nn.LSTM(dim_enc, dim_neck, 2, batch_first=True, bidirectional=True)

    def forward(self, x):
        """Encode ``x`` into a list of downsampled codes.

        Args:
            x: channels-first tensor (batch, dim_freq + dim_emb, time), as
                required by the ``Conv1d`` layers.

        Returns:
            A list of ``ceil(time / freq)`` tensors of shape
            (batch, 2 * dim_neck). Code k concatenates the forward-direction
            LSTM output at the last frame of chunk k with the
            backward-direction output at the first frame of chunk k.
        """
        for conv in self.convolutions:
            x = F.relu(conv(x))
        # The convs keep the time length (kernel 5, padding 2, stride 1);
        # the LSTM is batch_first, so switch to (batch, time, channels).
        x = x.transpose(1, 2)

        outputs, _ = self.lstm(x)
        out_forward = outputs[:, :, :self.dim_neck]
        out_backward = outputs[:, :, self.dim_neck:]

        # A trailing chunk shorter than ``freq`` used to index past the end
        # of the sequence (IndexError); clamp it to the last real frame.
        last_frame = outputs.size(1) - 1
        codes = []
        for i in range(0, outputs.size(1), self.freq):
            end = min(i + self.freq - 1, last_frame)
            codes.append(torch.cat((out_forward[:, end, :], out_backward[:, i, :]), dim=-1))
        return codes
class Decoder(nn.Module):
    """Decoder module: 3-layer LSTM followed by a linear projection to dim_freq.

    Args:
        dim_neck: per-direction code size; codes contribute ``2 * dim_neck``
            input features per frame.
        dim_emb: width of the embedding concatenated to each frame.
        dim_pre: accepted for interface compatibility; not used here.
    """

    def __init__(self, dim_neck, dim_emb, dim_pre):
        super().__init__()
        # Per-frame input = codes (2 * dim_neck) + embedding (dim_emb) + f0 (dim_f0).
        input_size = dim_neck * 2 + dim_emb + dim_f0
        self.lstm = nn.LSTM(input_size, dim_dec, 3, batch_first=True)
        self.linear_projection = LinearNorm(dim_dec, dim_freq)

    def forward(self, x):
        """Run the LSTM over batch-first ``x`` and project each frame to dim_freq."""
        hidden, _ = self.lstm(x)
        return self.linear_projection(hidden)
class Postnet(nn.Module):
    """Postnet: five 1-d convolutions with kernel size 5.

    Channel layout: dim_freq -> 512, three 512 -> 512 layers, then
    512 -> dim_freq. Every convolution is followed by GroupNorm; in
    ``forward`` all layers except the last are passed through tanh.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, init gain, GroupNorm groups) per layer,
        # in construction order.
        specs = [(dim_freq, 512, 'tanh', num_grp)]
        specs += [(512, 512, 'tanh', num_grp)] * 3
        specs.append((512, dim_freq, 'linear', 5))
        self.convolutions = nn.ModuleList(
            nn.Sequential(
                ConvNorm(c_in, c_out,
                         kernel_size=5, stride=1,
                         padding=2,
                         dilation=1, w_init_gain=gain),
                nn.GroupNorm(groups, c_out),
            )
            for c_in, c_out, gain, groups in specs
        )

    def forward(self, x):
        """Apply the conv stack to channels-first ``x``; the final layer has no tanh."""
        *hidden_layers, final_layer = self.convolutions
        for layer in hidden_layers:
            x = torch.tanh(layer(x))
        return final_layer(x)
class Generator(nn.Module):
    """Generator network: Encoder -> Decoder -> Postnet with a residual refinement.

    Args:
        dim_neck: forwarded to ``Encoder`` and ``Decoder``.
        dim_emb: forwarded to ``Encoder`` and ``Decoder``.
        dim_pre: forwarded to ``Decoder``.
        freq: forwarded to ``Encoder``; also the number of frames each code
            is repeated for when upsampling before decoding.
    """

    def __init__(self, dim_neck, dim_emb, dim_pre, freq):
        super(Generator, self).__init__()
        self.encoder = Encoder(dim_neck, dim_emb, freq)
        self.decoder = Decoder(dim_neck, dim_emb, dim_pre)
        self.postnet = Postnet()
        self.freq = freq

    def forward(self, x, c_org, f0_org=None, c_trg=None, f0_trg=None, enc_on=False):
        """Encode ``x`` and, unless ``enc_on``, decode it with the target conditioning.

        Args:
            x: (batch, time, dim_freq) frames; transposed to channels-first
                for the encoder.
            c_org: (batch, dim_emb) embedding broadcast over time and
                concatenated to ``x`` before encoding.
            f0_org: accepted for interface compatibility; not used here.
            c_trg: (batch, dim_emb) embedding broadcast over time and fed to
                the decoder. Required unless ``enc_on``.
            f0_trg: (batch, time, dim_f0) features fed to the decoder.
                Required unless ``enc_on``.
            enc_on: if True, return only the concatenated codes.

        Returns:
            If ``enc_on``: the encoder codes concatenated along the last dim.
            Otherwise ``(mel_outputs, mel_outputs_postnet, codes)``, where
            ``mel_outputs`` is the decoder output (batch, time, dim_freq),
            ``mel_outputs_postnet`` adds the postnet residual to it, and
            ``codes`` is the concatenated encoder codes.

        Raises:
            ValueError: if ``enc_on`` is False and ``c_trg`` or ``f0_trg`` is None.
        """
        x = x.transpose(2, 1)
        num_frames = x.size(-1)
        c_org = c_org.unsqueeze(-1).expand(-1, -1, num_frames)
        x = torch.cat((x, c_org), dim=1)

        codes = self.encoder(x)
        codes_cat = torch.cat(codes, dim=-1)
        if enc_on:
            return codes_cat

        if c_trg is None or f0_trg is None:
            raise ValueError('c_trg and f0_trg are required when enc_on is False')

        # Repeat every code ``freq`` times along time, then trim to the input
        # length so a trailing partial chunk does not overshoot ``num_frames``.
        code_exp = torch.cat(
            [code.unsqueeze(1).expand(-1, self.freq, -1) for code in codes],
            dim=1)[:, :num_frames]
        encoder_outputs = torch.cat(
            (code_exp,
             c_trg.unsqueeze(1).expand(-1, num_frames, -1),
             f0_trg),
            dim=-1)

        mel_outputs = self.decoder(encoder_outputs)
        # Postnet runs channels-first; its output is added back as a residual.
        residual = self.postnet(mel_outputs.transpose(2, 1)).transpose(2, 1)
        mel_outputs_postnet = mel_outputs + residual
        return mel_outputs, mel_outputs_postnet, codes_cat