# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel
import torch
import torch.nn.functional as F
import numpy as np
import os
import json
import torch.nn as nn
import tqdm
from einops import rearrange
os.chdir("./models/tts/debatts")
import sys
sys.path.append("./models/tts/debatts")
from utils.topk_sampling import top_k_top_p_filtering
import pickle
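# T2SLlama_new wraps a LlamaForCausalLM over a joint phone + target-token vocabulary:
# forward() builds the training sequence (prompt0 + phone + target) and sample_hf()
# performs autoregressive decoding with HuggingFace generate().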
class T2SLlama_new(nn.Module):
def __init__(
self,
phone_vocab_size=1024,
target_vocab_size=2048,
hidden_size=1024,
intermediate_size=4096,
num_hidden_layers=12,
num_attention_heads=16,
pad_token_id=3072,
bos_target_id=3073,
eos_target_id=3074,
bos_phone_id=3075,
eos_phone_id=3076,
bos_prompt0_id=3077,
eos_prompt0_id=3078,
use_lang_emb=False,
cfg=None,
):
super().__init__()
phone_vocab_size = (
cfg.phone_vocab_size
if cfg is not None and hasattr(cfg, "phone_vocab_size")
else phone_vocab_size
)
target_vocab_size = (
cfg.target_vocab_size
if cfg is not None and hasattr(cfg, "target_vocab_size")
else target_vocab_size
)
hidden_size = (
cfg.hidden_size
if cfg is not None and hasattr(cfg, "hidden_size")
else hidden_size
)
intermediate_size = (
cfg.intermediate_size
if cfg is not None and hasattr(cfg, "intermediate_size")
else intermediate_size
)
num_hidden_layers = (
cfg.num_hidden_layers
if cfg is not None and hasattr(cfg, "num_hidden_layers")
else num_hidden_layers
)
num_attention_heads = (
cfg.num_attention_heads
if cfg is not None and hasattr(cfg, "num_attention_heads")
else num_attention_heads
)
pad_token_id = (
cfg.pad_token_id
if cfg is not None and hasattr(cfg, "pad_token_id")
else pad_token_id
)
bos_target_id = (
cfg.bos_target_id
if cfg is not None and hasattr(cfg, "bos_target_id")
else bos_target_id
)
eos_target_id = (
cfg.eos_target_id
if cfg is not None and hasattr(cfg, "eos_target_id")
else eos_target_id
)
bos_phone_id = (
cfg.bos_phone_id
if cfg is not None and hasattr(cfg, "bos_phone_id")
else bos_phone_id
)
eos_phone_id = (
cfg.eos_phone_id
if cfg is not None and hasattr(cfg, "eos_phone_id")
else eos_phone_id
)
use_lang_emb = (
cfg.use_lang_emb
if cfg is not None and hasattr(cfg, "use_lang_emb")
else use_lang_emb
)
bos_prompt0_id = (
cfg.bos_prompt0_id
if cfg is not None and hasattr(cfg, "bos_prompt0_id")
else bos_prompt0_id
)
eos_prompt0_id = (
cfg.eos_prompt0_id
if cfg is not None and hasattr(cfg, "eos_prompt0_id")
else eos_prompt0_id
)
self.config = LlamaConfig(
vocab_size=phone_vocab_size + target_vocab_size + 20,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_hidden_layers=num_hidden_layers,
num_attention_heads=num_attention_heads,
pad_token_id=pad_token_id,
bos_token_id=bos_target_id,
eos_token_id=eos_target_id,
bos_prompt0_id=bos_prompt0_id,
eos_prompt0_id=eos_prompt0_id,
)
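        # Token-id layout of the joint vocabulary used by this class:
        #   [0, target_vocab_size)                                      -> target (acoustic) codes
        #   [target_vocab_size, target_vocab_size + phone_vocab_size)   -> phone ids
        #     (the offset is applied in add_phone_eos_bos_label)
        # The pad/bos/eos special ids default to 3072..3078; the "+ 20" above leaves
        # headroom in the embedding table for these special tokens.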
self.phone_vocab_size = phone_vocab_size
self.target_vocab_size = target_vocab_size
self.hidden_size = hidden_size
self.pad_token_id = pad_token_id
self.bos_target_id = bos_target_id
self.eos_target_id = eos_target_id
self.bos_phone_id = bos_phone_id
self.eos_phone_id = eos_phone_id
self.use_lang_emb = use_lang_emb
self.bos_prompt0_id = bos_prompt0_id
self.eos_prompt0_id = eos_prompt0_id
self.model = LlamaForCausalLM(self.config)
if self.use_lang_emb:
self.lang_emb = nn.Embedding(25, hidden_size, padding_idx=0)
torch.nn.init.normal_(self.lang_emb.weight, mean=0.0, std=0.02)
def forward(
self,
prompt0_ids,
prompt0_mask,
phone_ids,
phone_mask,
target_ids,
target_mask,
lang_id=None,
):
prompt0_ids, prompt0_mask, prompt0_label, prompt0_lang_mask = (
self.add_phone_eos_bos_label(
prompt0_ids,
prompt0_mask,
self.eos_prompt0_id,
self.bos_prompt0_id,
self.pad_token_id,
label="prompt0_id",
)
)
phone_ids, phone_mask, phone_label, lang_mask = self.add_phone_eos_bos_label(
phone_ids,
phone_mask,
self.eos_phone_id,
self.bos_phone_id,
self.pad_token_id,
label="phone_id",
)
target_ids, target_mask, target_label = self.add_target_eos_bos_label(
target_ids,
target_mask,
self.eos_target_id,
self.bos_target_id,
self.pad_token_id,
)
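        # Per-sample LM input:
        #   [bos_prompt0, prompt0, eos_prompt0][bos_phone, phone, eos_phone][bos_target, target, eos_target]
        # The prompt0 and phone labels are -100, so the cross-entropy loss is computed
        # on the target segment only.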
input_token_ids = torch.cat([prompt0_ids, phone_ids, target_ids], dim=-1)
attention_mask = torch.cat([prompt0_mask, phone_mask, target_mask], dim=-1)
labels = torch.cat([prompt0_label, phone_label, target_label], dim=-1)
# lang_id: (B,); lang_mask: (B, T)
if self.use_lang_emb:
lang_embedding = self.lang_emb(lang_id).unsqueeze(1) # (B, d) -> (B, 1, d)
lang_embedding = lang_embedding * torch.cat(
[prompt0_lang_mask, lang_mask, torch.zeros_like(target_mask)], dim=-1
).unsqueeze(
-1
) # (B, T, d)
input_token_embedding = self.model.model.embed_tokens(
input_token_ids
) # (B, T, d)
inputs_embeds = input_token_embedding + lang_embedding
out = self.model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
labels=labels,
return_dict=True,
)
else:
out = self.model(
input_token_ids,
attention_mask=attention_mask,
labels=labels,
return_dict=True,
)
return out
def add_phone_eos_bos_label(
self, phone_ids, phone_mask, phone_eos_id, phone_bos_id, pad_token_id, label
):
# phone_ids: [B, T]
# phone_mask: [B, T]
        # pad a 0 on the left: the BOS position is excluded from the language mask
        lang_mask = F.pad(phone_mask, (1, 0), value=0)
        # pad a 0 on the right: the EOS position is excluded as well
        lang_mask = F.pad(lang_mask, (0, 1), value=0)
if label == "phone_id":
phone_ids = phone_ids + self.target_vocab_size * phone_mask
phone_ids = phone_ids * phone_mask
"""Step-by-Step Computation:
Pad phone_ids:
After padding: [[101, 102, 103, 0]]
Invert and Pad phone_mask:
Inverted mask: [[0, 0, 0]]
Padded inverted mask: [[0, 0, 0, 1]]
Calculate EOS Insertion:
Multiply with phone_eos_id: [[0, 0, 0, 200]]
Combine:
Combined result: [[101, 102, 103, 200]]
"""
phone_ids = F.pad(phone_ids, (0, 1), value=0) + phone_eos_id * F.pad(
1 - phone_mask, (0, 1), value=1
) # make pad token eos token, add eos token at the end
phone_mask = F.pad(phone_mask, (1, 0), value=1) # add eos mask
phone_ids = phone_ids * phone_mask + pad_token_id * (
1 - phone_mask
) # restore pad token ids
phone_ids = F.pad(phone_ids, (1, 0), value=phone_bos_id) # add bos token
phone_mask = F.pad(phone_mask, (1, 0), value=1) # add bos mask
        phone_label = -100 * torch.ones_like(
            phone_ids
        )  # labels are -100, so the LM loss ignores the entire phone segment
return phone_ids, phone_mask, phone_label, lang_mask
def add_target_eos_bos_label(
self, target_ids, target_mask, target_eos_id, target_bos_id, pad_token_id
):
# target_ids: [B, T]
# target_mask: [B, T]
target_ids = target_ids * target_mask
target_ids = F.pad(target_ids, (0, 1), value=0) + target_eos_id * F.pad(
1 - target_mask, (0, 1), value=1
)
target_mask = F.pad(target_mask, (1, 0), value=1)
target_ids = target_ids * target_mask + pad_token_id * (1 - target_mask)
target_ids = F.pad(target_ids, (1, 0), value=target_bos_id)
target_mask = F.pad(target_mask, (1, 0), value=1)
target_label = target_ids * target_mask + (-100) * (
1 - target_mask
) # loss for target is computed on unmasked tokens
return target_ids, target_mask, target_label
def add_phone_middle_label(
self, prompt0_ids, prompt0_mask, eos_prompt0_id, pad_token_id
):
# prompt0_ids: [B, T]
# prompt0_mask: [B, T]
prompt0_ids = prompt0_ids * prompt0_mask
prompt0_ids = F.pad(prompt0_ids, (0, 1), value=0) + eos_prompt0_id * F.pad(
1 - prompt0_mask, (0, 1), value=1
) # Add eos_prompt0_id at the positions transitioning to padding
prompt0_mask = F.pad(
prompt0_mask, (1, 0), value=1
) # Pad the mask for the new eos_prompt0_id
prompt0_ids = prompt0_ids * prompt0_mask + pad_token_id * (
1 - prompt0_mask
) # Restore pad tokens
prompt0_ids = F.pad(
prompt0_ids, (1, 0), value=eos_prompt0_id
) # Add eos_prompt0_id at the beginning
prompt0_mask = F.pad(
prompt0_mask, (1, 0), value=1
) # Adjust the mask for the added eos_prompt0_id
prompt0_label = prompt0_ids * prompt0_mask + (-100) * (
1 - prompt0_mask
) # Set up labels for loss computation
return prompt0_ids, prompt0_mask, prompt0_label
@torch.no_grad()
def sample_hf(
self,
phone_ids, # the phones of prompt and target should be concatenated together
prompt_ids,
prompt0_ids=None,
max_length=100000,
temperature=0.3,
top_k=30,
top_p=0.7,
repeat_penalty=3.5,
lang_ids=None,
):
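        """Sample target tokens with HuggingFace generate().

        Two paths are supported: with an extra prompt0 segment prepended before the
        phone ids, and without it. In both cases the acoustic prompt (prompt_ids) is
        appended with its EOS removed, so the model continues the prompt.
        """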
if prompt0_ids is not None:
phone_mask = torch.ones_like(phone_ids)
prompt_mask = torch.ones_like(prompt_ids)
prompt_mask_prompt0 = torch.ones_like(prompt0_ids)
# downsample = DownsampleWithMask(downsample_factor=2)
# prompt0_ids, prompt_mask_prompt0 = downsample(prompt0_ids, prompt_mask_prompt0)
phone_ids, _, _, _ = self.add_phone_eos_bos_label(
phone_ids,
phone_mask,
self.eos_phone_id,
self.bos_phone_id,
self.pad_token_id,
label="phone_id",
)
prompt_ids, _, _ = self.add_target_eos_bos_label(
prompt_ids,
prompt_mask,
self.eos_target_id,
self.bos_target_id,
self.pad_token_id,
)
            prompt_ids = prompt_ids[:, :-1]  # drop the EOS so the model continues the acoustic prompt (continuation mode)
prompt0_ids, _, _ = self.add_target_eos_bos_label(
prompt0_ids,
prompt_mask_prompt0,
self.eos_prompt0_id,
self.bos_prompt0_id,
self.pad_token_id,
)
input_token_ids = torch.cat([prompt0_ids, phone_ids, prompt_ids], dim=-1)
input_length = input_token_ids.shape[1]
            if lang_ids is not None and self.use_lang_emb:
lang_ids = F.pad(F.pad(lang_ids, (1, 0), value=0), (0, 1), value=0)
input_token_embedding = self.model.model.embed_tokens(
input_token_ids
) # (B, T, d)
# lang_ids: [1,1,1,1,1,1,2,2,2,2] which means ['en','en','en','en','en','en','zh','zh','zh','zh']
lang_mask = torch.ones_like(phone_ids)
lang_mask[:, 0] = 0
lang_mask[:, -1] = 0
lang_embedding = torch.cat(
[
self.lang_emb(lang_ids),
self.lang_emb(lang_ids),
torch.zeros(
lang_ids.shape[0],
input_token_ids.shape[1] - lang_ids.shape[1],
self.hidden_size,
).to(input_token_ids.device),
],
dim=1,
) * torch.cat(
[lang_mask, torch.zeros_like(prompt_ids)], dim=-1
).unsqueeze(
-1
)
inputs_embeds = input_token_embedding + lang_embedding
# if prosody_features is not None:
#
# prosody_features = prosody_features.unsqueeze(1).expand(-1, inputs_embeds.size(1), -1)
# inputs_embeds = inputs_embeds + prosody_features
generated_ids = self.model.generate(
# input wav phone token ids + text token ids
inputs_embeds=inputs_embeds,
do_sample=True,
max_length=max_length,
pad_token_id=self.pad_token_id,
eos_token_id=self.eos_target_id,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repeat_penalty,
min_new_tokens=50,
)
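                # generate() was driven by inputs_embeds (no input_ids), so the returned
                # sequence holds only newly generated tokens; strip just the trailing EOS.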
gen_tokens = generated_ids[:, :-1]
else:
input_token_embedding = self.model.model.embed_tokens(input_token_ids)
generated_ids = self.model.generate(
input_token_ids,
do_sample=True,
max_length=max_length,
pad_token_id=self.pad_token_id,
eos_token_id=self.eos_target_id,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repeat_penalty,
min_new_tokens=50,
)
gen_tokens = generated_ids[:, input_length:-1]
return gen_tokens
else:
phone_mask = torch.ones_like(phone_ids)
prompt_mask = torch.ones_like(prompt_ids)
phone_ids, _, _, _ = self.add_phone_eos_bos_label(
phone_ids,
phone_mask,
self.eos_phone_id,
self.bos_phone_id,
self.pad_token_id,
label="phone_ids",
)
prompt_ids, _, _ = self.add_target_eos_bos_label(
prompt_ids,
prompt_mask,
self.eos_target_id,
self.bos_target_id,
self.pad_token_id,
)
            prompt_ids = prompt_ids[:, :-1]  # drop the EOS so the model continues the acoustic prompt (continuation mode)
input_token_ids = torch.cat([phone_ids, prompt_ids], dim=-1)
input_length = input_token_ids.shape[1]
            if lang_ids is not None and self.use_lang_emb:
lang_ids = F.pad(F.pad(lang_ids, (1, 0), value=0), (0, 1), value=0)
# token to vector
input_token_embedding = self.model.model.embed_tokens(
input_token_ids
) # (B, T, d)
# lang_ids: [1,1,1,1,1,1,2,2,2,2] which means ['en','en','en','en','en','en','zh','zh','zh','zh']
lang_mask = torch.ones_like(phone_ids)
lang_mask[:, 0] = 0
lang_mask[:, -1] = 0
lang_embedding = torch.cat(
[
self.lang_emb(lang_ids),
torch.zeros(
lang_ids.shape[0],
input_token_ids.shape[1] - lang_ids.shape[1],
self.hidden_size,
).to(input_token_ids.device),
],
dim=1,
) * torch.cat(
[lang_mask, torch.zeros_like(prompt_ids)], dim=-1
).unsqueeze(
-1
)
inputs_embeds = input_token_embedding + lang_embedding
generated_ids = self.model.generate(
# input wav phone token ids + text token ids
inputs_embeds=inputs_embeds,
do_sample=True,
max_length=max_length,
pad_token_id=self.pad_token_id,
eos_token_id=self.eos_target_id,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repeat_penalty,
min_new_tokens=50,
)
# assert generated_ids.size(1) > input_length, f"Generated tokens length {generated_ids.size(1)} is less than input length {input_length}, generated ids is {generated_ids}"
gen_tokens = generated_ids[:, :-1]
else:
input_token_embedding = self.model.model.embed_tokens(input_token_ids)
# if prosody_features is not None:
#
# prosody_features = prosody_features.unsqueeze(1).expand(-1, input_token_embedding.size(1), -1)
# inputs_embeds = input_token_embedding + prosody_features
# generated_ids = self.model.generate(
# inputs_embeds=inputs_embeds,
generated_ids = self.model.generate(
input_token_ids,
do_sample=True,
max_length=max_length,
pad_token_id=self.pad_token_id,
eos_token_id=self.eos_target_id,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repeat_penalty,
min_new_tokens=50,
                )
                # generate() was called with input_ids here, so the output includes the
                # prompt; strip the prompt and the trailing EOS to keep only new tokens.
                gen_tokens = generated_ids[:, input_length:-1]
return gen_tokens
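# Minimal usage sketch for sample_hf (illustrative only; the shapes and token ranges
# below are assumptions, not values shipped with this repo):
#
#   model = T2SLlama_new()                          # default hyper-parameters
#   phone_ids = torch.randint(0, 1024, (1, 32))     # hypothetical phone sequence
#   prompt_ids = torch.randint(0, 2048, (1, 150))   # hypothetical acoustic prompt codes
#   gen_tokens = model.sample_hf(phone_ids, prompt_ids)
#   # gen_tokens: (1, T_new) generated target codes with the EOS stripped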
class DownsampleWithMask(nn.Module):
def __init__(self, downsample_factor=2):
super(DownsampleWithMask, self).__init__()
self.downsample_factor = downsample_factor
def forward(self, x, mask):
# x shape: (batch_size, seq_len)
x = x.float()
x = x.unsqueeze(1) # add channel dimension: (batch_size, 1, seq_len)
x = F.avg_pool1d(
x, kernel_size=self.downsample_factor, stride=self.downsample_factor
)
x = x.squeeze(
1
) # remove channel dimension: (batch_size, seq_len // downsample_factor)
x = x.long()
# average pooling
mask = mask.float() # convert mask to float for pooling
mask = mask.unsqueeze(1) # add channel dimension: (batch_size, 1, seq_len)
mask = F.avg_pool1d(
mask, kernel_size=self.downsample_factor, stride=self.downsample_factor
)
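        # Assumed completion: mirror the handling of x above, binarize the pooled mask so
        # that any window containing at least one valid frame stays valid, then return both.
        mask = mask.squeeze(1)  # remove channel dimension: (batch_size, seq_len // downsample_factor)
        mask = (mask > 0).long()
        return x, mask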