import random
from copy import deepcopy

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import Linear
from tqdm import tqdm

from tts.modules.ar_dur.commons.layers import Embedding, LayerNorm
from tts.modules.ar_dur.commons.nar_tts_modules import PosEmb
from tts.modules.ar_dur.commons.rot_transformer import RotTransformerDecoderLayer
from tts.modules.ar_dur.commons.transformer import SinusoidalPositionalEmbedding
from tts.modules.ar_dur.commons.rel_transformer import RelTransformerEncoder

FS_ENCODERS = {
    'rel_fft': lambda hp, dict_size: RelTransformerEncoder(
        dict_size, hp['hidden_size'], hp['hidden_size'],
        hp['ffn_hidden_size'], hp['num_heads'], hp['enc_layers'],
        hp['enc_ffn_kernel_size'], hp['dropout'], prenet=hp['enc_prenet'], pre_ln=hp['enc_pre_ln']),
}


def fill_with_neg_inf2(t):
    """FP16-compatible helper that fills a tensor with a large negative value (-1e8) standing in for -inf."""
    return t.float().fill_(-1e8).type_as(t)


def expand_states(h, mel2token):
    # Prepend a zero state so that index 0 (padding) in the alignment map gathers zeros.
    h = F.pad(h, [0, 0, 1, 0])
    mel2token_ = mel2token[..., None].repeat([1, 1, h.shape[-1]])
    h = torch.gather(h, 1, mel2token_)  # [B, T_tgt, H]
    return h
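
# Usage note (illustrative): for h of shape [B, T_src, H] and an alignment map of shape
# [B, T_tgt] holding 1-based source indices (0 = padding), expand_states returns a
# [B, T_tgt, H] tensor in which every target position receives the hidden state of the
# source position it is aligned to; padded positions gather the prepended zero row.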


class CodePredictor(nn.Module):
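    """Autoregressive predictor over discrete codes, conditioned on phone/character
    linguistic features and optional speaker information.

    This base class builds the linguistic encoder, the code embedding and the output
    projection; the decoder layers themselves are created by subclasses
    (see ARDurPredictor below).
    """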

    def __init__(self, hparams, hidden_size, dec_hidden_size, lm_num_layers, dict_size, code_size):
        super().__init__()
        self.hparams = deepcopy(hparams)
        self.hparams['hidden_size'] = hidden_size
        self.hidden_size = hidden_size
        char_dict_size = hparams.get('char_dict_size', 4000)
        if not hparams.get('lm_use_enc'):
            self.encoder = nn.Embedding(dict_size, self.hidden_size, padding_idx=0)
            if hparams.get('mega_use_char', True):
                self.char_encoder = nn.Embedding(char_dict_size,
                                                 self.hidden_size, padding_idx=0)
        else:
            self.encoder = FS_ENCODERS[self.hparams['encoder_type']](self.hparams, dict_size)
            if hparams.get('mega_use_char', True):
                self.char_encoder = FS_ENCODERS[self.hparams['encoder_type']](self.hparams, char_dict_size)
            if hparams['use_ph_pos_embed']:
                self.ph_pos_embed = PosEmb(self.hidden_size)

        self.char_empty_embed = nn.Embedding(1, self.hidden_size)
        if hparams.get('use_bert_input'):
            self.bert_input_proj = nn.Linear(768, self.hidden_size)
        self.ling_label_embed_layers = nn.ModuleDict()
        for k, s in zip(hparams['ling_labels'], hparams['ling_label_dict_size']):
            self.ling_label_embed_layers[k] = Embedding(s + 3, self.hidden_size, padding_idx=0)

        self.dec_hidden_size = dec_hidden_size
        self.enc_proj = nn.Linear(self.hidden_size, dec_hidden_size)
        self.code_emb = Embedding(code_size + 2, dec_hidden_size, 0)
        self.use_pos_embed = hparams.get('use_pos_embed', False)
        if self.use_pos_embed:
            self.embed_positions = SinusoidalPositionalEmbedding(dec_hidden_size, 0, init_size=1024)
        self.use_post_ln = hparams.get('use_post_ln', False)
        self.layers = None
        if not self.use_post_ln:
            self.layer_norm = LayerNorm(dec_hidden_size)
        self.code_size = code_size
        self.project_out_dim = Linear(dec_hidden_size, code_size + 1, bias=True)

    def forward_ling_encoder(
            self, txt_tokens, ling_feas, char_tokens, ph2char, bert_embed, spk_id, spk_embed, mels_timbre):
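        """Encode phone tokens (plus optional linguistic labels, character tokens and
        speaker/style embeddings) into decoder-sized conditioning states.

        Returns a tensor of shape [B, T_phone, dec_hidden_size]; padded positions are zeroed.
        """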
        ph_tokens = txt_tokens
        hparams = self.hparams
        ph_nonpadding = (ph_tokens > 0).float()[:, :, None]  # [B, T_phone, 1]
        x_spk = self.forward_style_embed(spk_embed, spk_id, mels_timbre)

        # Phone-level encoding.
        if not hparams.get('lm_use_enc'):
            x_ph = self.encoder(ph_tokens)
            x_ph = x_ph + sum(
                [self.ling_label_embed_layers[k](ling_feas[k]) for k in hparams['ling_labels']]) \
                if len(hparams['ling_labels']) > 0 else 0
            x_ph = x_ph + x_spk
        else:
            ph_enc_oembed = sum(
                [self.ling_label_embed_layers[k](ling_feas[k]) for k in hparams['ling_labels']]) \
                if len(hparams['ling_labels']) > 0 else 0
            ph_enc_oembed = ph_enc_oembed + self.ph_pos_embed(
                torch.arange(0, ph_tokens.shape[1])[None,].to(ph_tokens.device))
            ph_enc_oembed = ph_enc_oembed + x_spk
            ph_enc_oembed = ph_enc_oembed * ph_nonpadding
            x_ph = self.encoder(ph_tokens, other_embeds=ph_enc_oembed)

        # Character-level encoding, expanded to phone level via ph2char.
        if char_tokens is not None and ph2char is not None:
            char_nonpadding = (char_tokens > 0).float()[:, :, None]
            x_char = self.char_encoder(char_tokens)
            empty_char = (ph2char > 100000).long()
            ph2char = ph2char * (1 - empty_char)
            x_char_phlevel = \
                expand_states(x_char * char_nonpadding, ph2char) \
                * (1 - empty_char)[..., None] + \
                self.char_empty_embed(torch.zeros_like(ph_tokens)) * empty_char[..., None]
        else:
            x_char_phlevel = 0

        # Fuse phone- and char-level features, mask padding, project to decoder size.
        x_ling = x_ph + x_char_phlevel
        x_ling = x_ling * ph_nonpadding
        x_ling = self.enc_proj(x_ling)
        return x_ling

    def sample_one_step(self, vq_pred):
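        """Pick the next code from the last-step logits: top-k sampling with temperature
        when `infer_top_k` is set in hparams, otherwise greedy argmax."""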
        hparams = self.hparams
        if hparams.get('infer_top_k'):
            top_k = hparams.get('infer_top_k')
            temperature = hparams.get('infer_temperature', 1)
            vq_pred = vq_pred[:, -1] / temperature
            # Crop the logits to the top-k options.
            if top_k is not None:
                v, _ = torch.topk(vq_pred, min(top_k, vq_pred.size(-1)))
                vq_pred[vq_pred < v[:, [-1]]] = -float('Inf')
            # Convert logits to probabilities and sample.
            probs = F.softmax(vq_pred, dim=-1)
            vq_pred = torch.multinomial(probs, num_samples=1)
        else:
            vq_pred = torch.argmax(F.softmax(vq_pred[:, -1], dim=-1), 1)
        return vq_pred

    def forward_style_embed(self, spk_embed=None, spk_id=None, mel_ref=None):
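        """Sum the enabled speaker/style embeddings. The projection modules used here
        (spk_embed_proj, spk_id_proj, spk_enc) are not defined in this base class and
        are expected from subclasses or configs that enable the corresponding flags."""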
        style_embed = 0
        if self.hparams['use_spk_embed']:
            style_embed = style_embed + self.spk_embed_proj(spk_embed)[:, None, :]
        if self.hparams['use_spk_id']:
            style_embed = style_embed + self.spk_id_proj(spk_id)[:, None, :]
        if self.hparams['use_spk_enc']:
            style_embed = style_embed + self.spk_enc(mel_ref)[:, None, :]
        return style_embed

    def buffered_future_mask(self, tensor):
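        """Return (and cache) a causal upper-triangular mask filled with a large negative
        value, sized for the current sequence length and device."""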
        dim = tensor.size(0)
        if (
                not hasattr(self, '_future_mask')
                or self._future_mask is None
                or self._future_mask.device != tensor.device
                or self._future_mask.size(0) < dim
        ):
            self._future_mask = torch.triu(fill_with_neg_inf2(tensor.new(dim, dim)), 1)
        return self._future_mask[:dim, :dim]


class ARDurPredictor(CodePredictor):
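    """Autoregressive phone-level duration predictor.

    Uses rotary-embedding Transformer decoder layers and predicts either a scalar
    duration per phone (the 'ar_mse' head with Softplus) or a categorical distribution
    over quantized duration codes.
    """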

    def __init__(self, hparams, hidden_size, dec_hidden_size, lm_num_layers, dict_size, code_size, use_rot_embed=True,
                 op_version=1):
        super().__init__(hparams, hidden_size, dec_hidden_size, lm_num_layers, dict_size, code_size)
        self.use_rot_embed = use_rot_embed
        bias = hparams.get('lm_bias', True)
        if self.use_rot_embed:
            self.layers = nn.ModuleList([])
            self.layers.extend([
                RotTransformerDecoderLayer(
                    dec_hidden_size, 0.0, kernel_size=1, ffn_hidden_size=dec_hidden_size * 4,
                    post_ln=self.use_post_ln, op_version=op_version, bias=bias)
                for _ in range(lm_num_layers)
            ])
        if hparams['dur_model_type'] == 'ar_mse':
            self.project_out_dim = nn.Sequential(torch.nn.Linear(dec_hidden_size, 1), nn.Softplus())
        else:
            self.project_out_dim = torch.nn.Linear(dec_hidden_size, code_size + 1)

    def forward(self, txt_tokens, ling_feas, char_tokens, ph2char, bert_embed,
                prev_code, spk_id=None, spk_embed=None, mels_timbre=None, mel2ph=None,
                incremental_state=None, x_ling=None, attn_mask=None, spk_pos_ids_flat=None,
                prompt_length=None, cache_size=20, streaming=False):
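        """Run one decoder pass over the previously decoded codes `prev_code`.

        When `incremental_state` is given, only the newest step is processed and the
        key/value caches are reused; when `streaming` is enabled, cached keys/values
        beyond `prompt_length + cache_size` are dropped to bound memory.
        """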
        x = self.code_emb(prev_code)
        if x_ling is None:
            x_ling = self.forward_ling_encoder(
                txt_tokens, ling_feas, char_tokens, ph2char, bert_embed, spk_id, spk_embed, mels_timbre)
            x_ling = x_ling.flatten(0, 1)
            txt_tokens = txt_tokens.flatten(0, 1)
            x_ling = x_ling[txt_tokens > 0][None]

        self_attn_padding_mask = None
        if self.use_pos_embed:
            positions = self.embed_positions(
                prev_code,
                incremental_state=incremental_state
            )
        if incremental_state is not None:
            # Incremental decoding: keep only the current step.
            x_ling = x_ling[:, x.shape[1] - 1:x.shape[1]]
            if spk_pos_ids_flat is not None:
                spk_pos_ids_flat = spk_pos_ids_flat[:, x.shape[1] - 1:x.shape[1]]
            x = x[:, -1:]
            if self.use_pos_embed:
                positions = positions[:, -1:]
            if streaming:
                # Cap the query position at prompt_length + cache_size to match the trimmed KV cache.
                spk_pos_ids_flat = torch.min(torch.LongTensor([prompt_length + cache_size]).to(x.device),
                                             spk_pos_ids_flat)

        # Add positional and linguistic conditioning, then run the decoder in [T, B, H] layout.
        if self.use_pos_embed:
            x = x + positions
        x_ling = x_ling[:, :self.hparams['max_tokens']].contiguous()
        T = min(self.hparams.get('max_tokens_per_item', 1e9), x_ling.shape[1])
        x_ling = x_ling.reshape(-1, T, x_ling.shape[-1])
        x = x + x_ling
        x = x.transpose(0, 1)

        for idx, layer in enumerate(self.layers):
            if incremental_state is None:
                self_attn_mask = self.buffered_future_mask(x)
                if attn_mask is not None:
                    self_attn_mask = self_attn_mask + (1 - attn_mask.float()) * -1e8
                self_attn_mask = self_attn_mask.clamp_min(-1e8)
            else:
                self_attn_mask = None

            x, attn_weights = layer(
                x,
                incremental_state=incremental_state,
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=self_attn_padding_mask,
                spk_pos_ids_flat=spk_pos_ids_flat
            )

            if streaming and incremental_state != {}:
                # Trim cached keys/values: keep the prompt plus the most recent cache_size steps.
                for k, v in incremental_state.items():
                    if 'attn_state' in k:
                        prev_key, prev_value = incremental_state[k]['prev_key'], incremental_state[k]['prev_value']
                        cur_length = prev_key.shape[2]
                        if cur_length - prompt_length > cache_size:
                            prev_key = torch.cat((prev_key[:, :, :prompt_length], prev_key[:, :, -cache_size:]), dim=2)
                            prev_value = torch.cat((prev_value[:, :, :prompt_length], prev_value[:, :, -cache_size:]),
                                                   dim=2)
                        incremental_state[k]['prev_key'], incremental_state[k]['prev_value'] = prev_key, prev_value

        if not self.use_post_ln:
            x = self.layer_norm(x)

        x = x.transpose(0, 1)
        x = self.project_out_dim(x)
        return x

    def infer(self, txt_tokens, ling_feas, char_tokens, ph2char, bert_embed,
              spk_id=None, spk_embed=None, mels_timbre=None,
              incremental_state=None, ctx_vqcodes=None, spk_pos_ids_flat=None, return_state=False,
              first_step_min=0, return_probs=False, first_decoder_inp=None, dur_disturb=0.0, **kwargs):
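        """Autoregressively decode one duration per phone.

        Non-padding phones are flattened into a single sequence, a BOS code
        (code_size + 1) is prepended, and durations are decoded step by step,
        optionally perturbed by `dur_disturb`. Returns the durations reshaped back
        to the [B, T_phone] layout of `txt_tokens`.
        """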
        if incremental_state is None:
            incremental_state = {}
        x_ling = self.forward_ling_encoder(
            txt_tokens, ling_feas, char_tokens, ph2char, bert_embed,
            spk_id, spk_embed, mels_timbre)
        x_ling = x_ling.flatten(0, 1)
        txt_tokens_ori = txt_tokens
        txt_tokens_withpad = txt_tokens = txt_tokens.flatten(0, 1)
        x_ling = x_ling[txt_tokens > 0][None]
        txt_tokens = txt_tokens[txt_tokens > 0][None]

        # Start from a BOS code (code_size + 1); optionally prefill with prompt codes.
        decoded = torch.zeros_like(txt_tokens)
        decoded = F.pad(decoded, [1, 0], value=self.code_size + 1)
        if incremental_state != {}:
            if first_decoder_inp is None:
                assert ctx_vqcodes is not None
                decoded[:, :ctx_vqcodes.shape[1]] = ctx_vqcodes
                ctx_vqcodes = None
            else:
                decoded[:, :1] = first_decoder_inp
        probs = []
        for step in range(decoded.shape[1] - 1):
            vq_pred = self(txt_tokens, None, None, None, None,
                           decoded[:, :step + 1], None, None, None,
                           incremental_state=incremental_state, x_ling=x_ling,
                           spk_pos_ids_flat=spk_pos_ids_flat, **kwargs)
            probs.append(vq_pred.cpu())
            if ctx_vqcodes is None or step >= ctx_vqcodes.shape[1]:
                if self.hparams['dur_model_type'] == 'ar_mse':
                    d = vq_pred[:, -1, 0]
                    if dur_disturb > 0 and step >= 1:
                        # Randomly stretch or compress the predicted duration.
                        if random.random() > 0.5:
                            d = d * (1 + random.random() * dur_disturb)
                        else:
                            d = d / (1 + random.random() * dur_disturb)
                    d = torch.clamp_max(d, self.code_size - 1)
                    vq_pred = torch.round(d).long()
                else:
                    vq_pred = self.sample_one_step(vq_pred)
                decoded[:, step + 1] = torch.clamp_min(vq_pred, 1)
                if step == 0:
                    decoded[:, step + 1] = torch.clamp_min(vq_pred, first_step_min)
            else:
                # Teacher-forced prompt region: copy the given context codes.
                decoded[:, step + 1] = ctx_vqcodes[:, step]
        decoded = decoded[:, 1:]
        # Scatter the flat predictions back to the padded [B, T_phone] layout.
        decoded_2d = torch.zeros_like(txt_tokens_ori)
        decoded_2d.flatten(0, 1)[txt_tokens_withpad > 0] = decoded
        if return_state:
            return decoded_2d, incremental_state
        if return_probs:
            return decoded_2d, torch.cat(probs, 1)
        return decoded_2d

    def streaming_infer(self, txt_tokens, ling_feas, char_tokens, ph2char, bert_embed,
                        spk_id=None, spk_embed=None, mels_timbre=None,
                        incremental_state=None, ctx_vqcodes=None, spk_pos_ids_flat=None, return_state=False,
                        **kwargs):
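        """Variant of `infer` for streaming decoding: expects an `incremental_state`
        pre-filled from a prompt (its cached key length gives `prompt_length`) together
        with the prompt's `ctx_vqcodes`, and trims the key/value cache in the decoder
        (see `forward` with streaming=True)."""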
        if incremental_state is None:
            incremental_state = {}
        x_ling = self.forward_ling_encoder(
            txt_tokens, ling_feas, char_tokens, ph2char, bert_embed,
            spk_id, spk_embed, mels_timbre)
        x_ling = x_ling.flatten(0, 1)
        txt_tokens_ori = txt_tokens
        txt_tokens_withpad = txt_tokens = txt_tokens.flatten(0, 1)
        x_ling = x_ling[txt_tokens > 0][None]
        txt_tokens = txt_tokens[txt_tokens > 0][None]

        vq_decoded = torch.zeros_like(txt_tokens)
        vq_decoded = F.pad(vq_decoded, [1, 0], value=self.code_size + 1)
        if incremental_state != {}:
            assert ctx_vqcodes is not None
            vq_decoded[:, :ctx_vqcodes.shape[1]] = ctx_vqcodes
            ctx_vqcodes = None
        # Length of the prompt already held in the key/value cache.
        prompt_length = list(incremental_state.items())[0][1]['prev_key'].shape[2]
        for step in tqdm(range(vq_decoded.shape[1] - 1), desc='AR Duration Predictor inference...'):
            vq_pred = self(txt_tokens, None, None, None, None,
                           vq_decoded[:, :step + 1], None, None, None,
                           incremental_state=incremental_state, x_ling=x_ling,
                           spk_pos_ids_flat=spk_pos_ids_flat, prompt_length=prompt_length, streaming=True, **kwargs)
            if ctx_vqcodes is None or step >= ctx_vqcodes.shape[1]:
                if self.hparams['dur_model_type'] == 'ar_mse':
                    vq_pred = torch.round(vq_pred[:, -1, 0]).long()
                else:
                    vq_pred = self.sample_one_step(vq_pred)
                vq_decoded[:, step + 1] = vq_pred
            else:
                vq_decoded[:, step + 1] = ctx_vqcodes[:, step]
        vq_decoded = vq_decoded[:, 1:]
        vq_decoded_2d = torch.zeros_like(txt_tokens_ori)
        vq_decoded_2d.flatten(0, 1)[txt_tokens_withpad > 0] = vq_decoded
        if return_state:
            return vq_decoded_2d, incremental_state
        return vq_decoded_2d
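

if __name__ == '__main__':
    # Illustrative smoke test only, not part of the original pipeline. The hyper-parameter
    # values below are assumptions chosen so the module can be instantiated stand-alone;
    # real values come from the project's config files.
    hp = {
        'lm_use_enc': False,          # plain nn.Embedding phone encoder
        'mega_use_char': False,       # skip the character-level encoder
        'use_bert_input': False,
        'ling_labels': ['tone'],      # one dummy linguistic label stream
        'ling_label_dict_size': [16],
        'use_pos_embed': False,       # rely on rotary embeddings in the decoder
        'use_post_ln': False,
        'use_spk_embed': False,
        'use_spk_id': False,
        'use_spk_enc': False,
        'dur_model_type': 'ar_mse',   # scalar duration head
        'max_tokens': 10000,
    }
    model = ARDurPredictor(hp, hidden_size=64, dec_hidden_size=64,
                           lm_num_layers=2, dict_size=100, code_size=32)
    txt_tokens = torch.randint(1, 100, (1, 8))
    ling_feas = {'tone': torch.randint(1, 10, (1, 8))}
    spk_pos_ids_flat = torch.arange(16)[None]  # position ids consumed step by step during decoding
    with torch.no_grad():
        durs = model.infer(txt_tokens, ling_feas, None, None, None,
                           spk_pos_ids_flat=spk_pos_ids_flat)
    print(durs.shape)  # expected: [1, 8] integer durations per phone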