# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel
import torch
import torch.nn.functional as F
import numpy as np
import os
import json
import torch.nn as nn
import tqdm
from einops import rearrange

os.chdir("./models/tts/debatts")
import sys

sys.path.append("./models/tts/debatts")
from utils.topk_sampling import top_k_top_p_filtering
import pickle


class T2SLlama_new(nn.Module):
    def __init__(
        self,
        phone_vocab_size=1024,
        target_vocab_size=2048,
        hidden_size=1024,
        intermediate_size=4096,
        num_hidden_layers=12,
        num_attention_heads=16,
        pad_token_id=3072,
        bos_target_id=3073,
        eos_target_id=3074,
        bos_phone_id=3075,
        eos_phone_id=3076,
        bos_prompt0_id=3077,
        eos_prompt0_id=3078,
        use_lang_emb=False,
        cfg=None,
    ):
        super().__init__()

        # Every constructor default can be overridden by an attribute of the
        # same name on cfg.
        phone_vocab_size = (
            cfg.phone_vocab_size
            if cfg is not None and hasattr(cfg, "phone_vocab_size")
            else phone_vocab_size
        )
        target_vocab_size = (
            cfg.target_vocab_size
            if cfg is not None and hasattr(cfg, "target_vocab_size")
            else target_vocab_size
        )
        hidden_size = (
            cfg.hidden_size
            if cfg is not None and hasattr(cfg, "hidden_size")
            else hidden_size
        )
        intermediate_size = (
            cfg.intermediate_size
            if cfg is not None and hasattr(cfg, "intermediate_size")
            else intermediate_size
        )
        num_hidden_layers = (
            cfg.num_hidden_layers
            if cfg is not None and hasattr(cfg, "num_hidden_layers")
            else num_hidden_layers
        )
        num_attention_heads = (
            cfg.num_attention_heads
            if cfg is not None and hasattr(cfg, "num_attention_heads")
            else num_attention_heads
        )
        pad_token_id = (
            cfg.pad_token_id
            if cfg is not None and hasattr(cfg, "pad_token_id")
            else pad_token_id
        )
        bos_target_id = (
            cfg.bos_target_id
            if cfg is not None and hasattr(cfg, "bos_target_id")
            else bos_target_id
        )
        eos_target_id = (
            cfg.eos_target_id
            if cfg is not None and hasattr(cfg, "eos_target_id")
            else eos_target_id
        )
        bos_phone_id = (
            cfg.bos_phone_id
            if cfg is not None and hasattr(cfg, "bos_phone_id")
            else bos_phone_id
        )
        eos_phone_id = (
            cfg.eos_phone_id
            if cfg is not None and hasattr(cfg, "eos_phone_id")
            else eos_phone_id
        )
        use_lang_emb = (
            cfg.use_lang_emb
            if cfg is not None and hasattr(cfg, "use_lang_emb")
            else use_lang_emb
        )
        bos_prompt0_id = (
            cfg.bos_prompt0_id
            if cfg is not None and hasattr(cfg, "bos_prompt0_id")
            else bos_prompt0_id
        )
        eos_prompt0_id = (
            cfg.eos_prompt0_id
            if cfg is not None and hasattr(cfg, "eos_prompt0_id")
            else eos_prompt0_id
        )

        self.config = LlamaConfig(
            vocab_size=phone_vocab_size + target_vocab_size + 20,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            pad_token_id=pad_token_id,
            bos_token_id=bos_target_id,
            eos_token_id=eos_target_id,
            bos_prompt0_id=bos_prompt0_id,
            eos_prompt0_id=eos_prompt0_id,
        )
        self.phone_vocab_size = phone_vocab_size
        self.target_vocab_size = target_vocab_size
        self.hidden_size = hidden_size
        self.pad_token_id = pad_token_id
        self.bos_target_id = bos_target_id
        self.eos_target_id = eos_target_id
        self.bos_phone_id = bos_phone_id
        self.eos_phone_id = eos_phone_id
        self.use_lang_emb = use_lang_emb
        self.bos_prompt0_id = bos_prompt0_id
        self.eos_prompt0_id = eos_prompt0_id

        self.model = LlamaForCausalLM(self.config)

        if self.use_lang_emb:
            self.lang_emb = nn.Embedding(25, hidden_size, padding_idx=0)
            torch.nn.init.normal_(self.lang_emb.weight, mean=0.0, std=0.02)
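
    # Token-id layout with the defaults above: acoustic (target) tokens occupy
    # [0, target_vocab_size); phone tokens are shifted by target_vocab_size in
    # add_phone_eos_bos_label, so they occupy
    # [target_vocab_size, target_vocab_size + phone_vocab_size); ids 3072-3078
    # are the pad/bos/eos specials; the LlamaConfig vocab adds 20 spare slots.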

    def forward(
        self,
        prompt0_ids,
        prompt0_mask,
        phone_ids,
        phone_mask,
        target_ids,
        target_mask,
        lang_id=None,
    ):
        prompt0_ids, prompt0_mask, prompt0_label, prompt0_lang_mask = (
            self.add_phone_eos_bos_label(
                prompt0_ids,
                prompt0_mask,
                self.eos_prompt0_id,
                self.bos_prompt0_id,
                self.pad_token_id,
                label="prompt0_id",
            )
        )
        phone_ids, phone_mask, phone_label, lang_mask = self.add_phone_eos_bos_label(
            phone_ids,
            phone_mask,
            self.eos_phone_id,
            self.bos_phone_id,
            self.pad_token_id,
            label="phone_id",
        )
        target_ids, target_mask, target_label = self.add_target_eos_bos_label(
            target_ids,
            target_mask,
            self.eos_target_id,
            self.bos_target_id,
            self.pad_token_id,
        )
        input_token_ids = torch.cat([prompt0_ids, phone_ids, target_ids], dim=-1)
        attention_mask = torch.cat([prompt0_mask, phone_mask, target_mask], dim=-1)
        labels = torch.cat([prompt0_label, phone_label, target_label], dim=-1)

        # lang_id: (B,); lang_mask: (B, T)
        if self.use_lang_emb:
            lang_embedding = self.lang_emb(lang_id).unsqueeze(1)  # (B, d) -> (B, 1, d)
            lang_embedding = lang_embedding * torch.cat(
                [prompt0_lang_mask, lang_mask, torch.zeros_like(target_mask)], dim=-1
            ).unsqueeze(-1)  # (B, T, d)
            input_token_embedding = self.model.model.embed_tokens(
                input_token_ids
            )  # (B, T, d)
            inputs_embeds = input_token_embedding + lang_embedding
            out = self.model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                labels=labels,
                return_dict=True,
            )
        else:
            out = self.model(
                input_token_ids,
                attention_mask=attention_mask,
                labels=labels,
                return_dict=True,
            )
        return out

    def add_phone_eos_bos_label(
        self, phone_ids, phone_mask, phone_eos_id, phone_bos_id, pad_token_id, label
    ):
        # phone_ids: [B, T]
        # phone_mask: [B, T]

        # add a 0 on the left
        lang_mask = F.pad(phone_mask, (1, 0), value=0)
        # add a 0 on the right
        lang_mask = F.pad(lang_mask, (0, 1), value=0)

        if label == "phone_id":
            # shift phone tokens out of the target-token id range
            phone_ids = phone_ids + self.target_vocab_size * phone_mask
        phone_ids = phone_ids * phone_mask
        """Step-by-step computation (illustrative values):
        Pad phone_ids:
            after padding: [[101, 102, 103, 0]]
        Invert and pad phone_mask:
            inverted mask: [[0, 0, 0]]
            padded inverted mask: [[0, 0, 0, 1]]
        Calculate EOS insertion:
            multiply by phone_eos_id: [[0, 0, 0, 200]]
        Combine:
            combined result: [[101, 102, 103, 200]]
        """
        phone_ids = F.pad(phone_ids, (0, 1), value=0) + phone_eos_id * F.pad(
            1 - phone_mask, (0, 1), value=1
        )  # make pad tokens eos tokens, add an eos token at the end
        phone_mask = F.pad(phone_mask, (1, 0), value=1)  # add eos mask
        phone_ids = phone_ids * phone_mask + pad_token_id * (
            1 - phone_mask
        )  # restore pad token ids
        phone_ids = F.pad(phone_ids, (1, 0), value=phone_bos_id)  # add bos token
        phone_mask = F.pad(phone_mask, (1, 0), value=1)  # add bos mask
        phone_label = -100 * torch.ones_like(
            phone_ids
        )  # loss for the entire phone sequence is not computed (-100 labels are ignored by llama)
        return phone_ids, phone_mask, phone_label, lang_mask

    def add_target_eos_bos_label(
        self, target_ids, target_mask, target_eos_id, target_bos_id, pad_token_id
    ):
        # target_ids: [B, T]
        # target_mask: [B, T]
        target_ids = target_ids * target_mask
        target_ids = F.pad(target_ids, (0, 1), value=0) + target_eos_id * F.pad(
            1 - target_mask, (0, 1), value=1
        )
        target_mask = F.pad(target_mask, (1, 0), value=1)
        target_ids = target_ids * target_mask + pad_token_id * (1 - target_mask)
        target_ids = F.pad(target_ids, (1, 0), value=target_bos_id)
        target_mask = F.pad(target_mask, (1, 0), value=1)
        target_label = target_ids * target_mask + (-100) * (
            1 - target_mask
        )  # loss for the target is computed only on unmasked tokens
        return target_ids, target_mask, target_label
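
    # Worked example for add_target_eos_bos_label (illustrative values, not
    # taken from a real batch): with target_ids = [[5, 6, 7]] and
    # target_mask = [[1, 1, 1]], the method returns
    #     ids    = [[bos_target_id, 5, 6, 7, eos_target_id]]
    #     labels = [[bos_target_id, 5, 6, 7, eos_target_id]]
    # Padded positions instead receive pad_token_id and a -100 label, so the
    # LM loss is computed only over the acoustic target tokens.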

    def add_phone_middle_label(
        self, prompt0_ids, prompt0_mask, eos_prompt0_id, pad_token_id
    ):
        # prompt0_ids: [B, T]
        # prompt0_mask: [B, T]
        prompt0_ids = prompt0_ids * prompt0_mask
        prompt0_ids = F.pad(prompt0_ids, (0, 1), value=0) + eos_prompt0_id * F.pad(
            1 - prompt0_mask, (0, 1), value=1
        )  # add eos_prompt0_id at the positions transitioning to padding
        prompt0_mask = F.pad(
            prompt0_mask, (1, 0), value=1
        )  # pad the mask for the new eos_prompt0_id
        prompt0_ids = prompt0_ids * prompt0_mask + pad_token_id * (
            1 - prompt0_mask
        )  # restore pad tokens
        prompt0_ids = F.pad(
            prompt0_ids, (1, 0), value=eos_prompt0_id
        )  # add eos_prompt0_id at the beginning
        prompt0_mask = F.pad(
            prompt0_mask, (1, 0), value=1
        )  # adjust the mask for the added eos_prompt0_id
        prompt0_label = prompt0_ids * prompt0_mask + (-100) * (
            1 - prompt0_mask
        )  # set up labels for loss computation
        return prompt0_ids, prompt0_mask, prompt0_label
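
    # sample_hf below concatenates [prompt0 tokens | phone tokens | acoustic
    # prompt tokens] (or just [phone | prompt] when prompt0_ids is None),
    # drops the prompt's trailing EOS so the model continues the prompt, and
    # lets generation stop once eos_target_id is sampled.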

    @torch.no_grad()
    def sample_hf(
        self,
        phone_ids,  # the phones of prompt and target should be concatenated together
        prompt_ids,
        prompt0_ids=None,
        max_length=100000,
        temperature=0.3,
        top_k=30,
        top_p=0.7,
        repeat_penalty=3.5,
        lang_ids=None,
    ):
        if prompt0_ids is not None:
            phone_mask = torch.ones_like(phone_ids)
            prompt_mask = torch.ones_like(prompt_ids)
            prompt_mask_prompt0 = torch.ones_like(prompt0_ids)
            # downsample = DownsampleWithMask(downsample_factor=2)
            # prompt0_ids, prompt_mask_prompt0 = downsample(prompt0_ids, prompt_mask_prompt0)

            phone_ids, _, _, _ = self.add_phone_eos_bos_label(
                phone_ids,
                phone_mask,
                self.eos_phone_id,
                self.bos_phone_id,
                self.pad_token_id,
                label="phone_id",
            )
            prompt_ids, _, _ = self.add_target_eos_bos_label(
                prompt_ids,
                prompt_mask,
                self.eos_target_id,
                self.bos_target_id,
                self.pad_token_id,
            )
            prompt_ids = prompt_ids[:, :-1]  # remove end token. Make it continue mode

            prompt0_ids, _, _ = self.add_target_eos_bos_label(
                prompt0_ids,
                prompt_mask_prompt0,
                self.eos_prompt0_id,
                self.bos_prompt0_id,
                self.pad_token_id,
            )

            input_token_ids = torch.cat([prompt0_ids, phone_ids, prompt_ids], dim=-1)
            input_length = input_token_ids.shape[1]

            if lang_ids is not None and self.use_lang_emb:
                lang_ids = F.pad(F.pad(lang_ids, (1, 0), value=0), (0, 1), value=0)
                input_token_embedding = self.model.model.embed_tokens(
                    input_token_ids
                )  # (B, T, d)
                # lang_ids: [1,1,1,1,1,1,2,2,2,2] which means ['en','en','en','en','en','en','zh','zh','zh','zh']
                lang_mask = torch.ones_like(phone_ids)
                lang_mask[:, 0] = 0
                lang_mask[:, -1] = 0
                lang_embedding = torch.cat(
                    [
                        self.lang_emb(lang_ids),
                        self.lang_emb(lang_ids),
                        torch.zeros(
                            lang_ids.shape[0],
                            input_token_ids.shape[1] - lang_ids.shape[1],
                            self.hidden_size,
                        ).to(input_token_ids.device),
                    ],
                    dim=1,
                ) * torch.cat(
                    [lang_mask, torch.zeros_like(prompt_ids)], dim=-1
                ).unsqueeze(-1)
                inputs_embeds = input_token_embedding + lang_embedding

                # if prosody_features is not None:
                #     # prosody_features = prosody_features.unsqueeze(1).expand(-1, inputs_embeds.size(1), -1)
                #     inputs_embeds = inputs_embeds + prosody_features

                generated_ids = self.model.generate(
                    # input wav phone token ids + text token ids
                    inputs_embeds=inputs_embeds,
                    do_sample=True,
                    max_length=max_length,
                    pad_token_id=self.pad_token_id,
                    eos_token_id=self.eos_target_id,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    repetition_penalty=repeat_penalty,
                    min_new_tokens=50,
                )
                gen_tokens = generated_ids[:, :-1]
            else:
                input_token_embedding = self.model.model.embed_tokens(input_token_ids)
                generated_ids = self.model.generate(
                    input_token_ids,
                    do_sample=True,
                    max_length=max_length,
                    pad_token_id=self.pad_token_id,
                    eos_token_id=self.eos_target_id,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    repetition_penalty=repeat_penalty,
                    min_new_tokens=50,
                )
                gen_tokens = generated_ids[:, input_length:-1]

            return gen_tokens

        else:
            phone_mask = torch.ones_like(phone_ids)
            prompt_mask = torch.ones_like(prompt_ids)
            phone_ids, _, _, _ = self.add_phone_eos_bos_label(
                phone_ids,
                phone_mask,
                self.eos_phone_id,
                self.bos_phone_id,
                self.pad_token_id,
                label="phone_id",  # must match the check in add_phone_eos_bos_label so phone ids are shifted
            )
            prompt_ids, _, _ = self.add_target_eos_bos_label(
                prompt_ids,
                prompt_mask,
                self.eos_target_id,
                self.bos_target_id,
                self.pad_token_id,
            )
            prompt_ids = prompt_ids[:, :-1]  # remove end token. Make it continue mode

            input_token_ids = torch.cat([phone_ids, prompt_ids], dim=-1)
            input_length = input_token_ids.shape[1]

            if lang_ids is not None and self.use_lang_emb:
                lang_ids = F.pad(F.pad(lang_ids, (1, 0), value=0), (0, 1), value=0)
                # token to vector
                input_token_embedding = self.model.model.embed_tokens(
                    input_token_ids
                )  # (B, T, d)
                # lang_ids: [1,1,1,1,1,1,2,2,2,2] which means ['en','en','en','en','en','en','zh','zh','zh','zh']
                lang_mask = torch.ones_like(phone_ids)
                lang_mask[:, 0] = 0
                lang_mask[:, -1] = 0
                lang_embedding = torch.cat(
                    [
                        self.lang_emb(lang_ids),
                        torch.zeros(
                            lang_ids.shape[0],
                            input_token_ids.shape[1] - lang_ids.shape[1],
                            self.hidden_size,
                        ).to(input_token_ids.device),
                    ],
                    dim=1,
                ) * torch.cat(
                    [lang_mask, torch.zeros_like(prompt_ids)], dim=-1
                ).unsqueeze(-1)
                inputs_embeds = input_token_embedding + lang_embedding

                generated_ids = self.model.generate(
                    # input wav phone token ids + text token ids
                    inputs_embeds=inputs_embeds,
                    do_sample=True,
                    max_length=max_length,
                    pad_token_id=self.pad_token_id,
                    eos_token_id=self.eos_target_id,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    repetition_penalty=repeat_penalty,
                    min_new_tokens=50,
                )
                # assert generated_ids.size(1) > input_length, f"Generated tokens length {generated_ids.size(1)} is less than input length {input_length}, generated ids is {generated_ids}"
                gen_tokens = generated_ids[:, :-1]
            else:
                input_token_embedding = self.model.model.embed_tokens(input_token_ids)
                # if prosody_features is not None:
                #     # prosody_features = prosody_features.unsqueeze(1).expand(-1, input_token_embedding.size(1), -1)
                #     inputs_embeds = input_token_embedding + prosody_features
                # generated_ids = self.model.generate(
                #     inputs_embeds=inputs_embeds,
                generated_ids = self.model.generate(
                    input_token_ids,
                    do_sample=True,
                    max_length=max_length,
                    pad_token_id=self.pad_token_id,
                    eos_token_id=self.eos_target_id,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    repetition_penalty=repeat_penalty,
                    min_new_tokens=50,
                )
                gen_tokens = generated_ids[:, input_length:-1]

            return gen_tokens


class DownsampleWithMask(nn.Module):
    def __init__(self, downsample_factor=2):
        super(DownsampleWithMask, self).__init__()
        self.downsample_factor = downsample_factor

    def forward(self, x, mask):
        # x shape: (batch_size, seq_len)
        x = x.float()
        x = x.unsqueeze(1)  # add channel dimension: (batch_size, 1, seq_len)
        x = F.avg_pool1d(
            x, kernel_size=self.downsample_factor, stride=self.downsample_factor
        )
        x = x.squeeze(
            1
        )  # remove channel dimension: (batch_size, seq_len // downsample_factor)
        x = x.long()  # back to integer token values after average pooling

        mask = mask.float()  # convert mask to float for pooling
        mask = mask.unsqueeze(1)  # add channel dimension: (batch_size, 1, seq_len)
        mask = F.avg_pool1d(
            mask, kernel_size=self.downsample_factor, stride=self.downsample_factor
        )
        mask = mask.squeeze(
            1
        )  # remove channel dimension: (batch_size, seq_len // downsample_factor)
        # assumed completion: a pooled position stays valid if any frame in its
        # window was valid
        mask = (mask > 0).long()
        return x, mask
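

# Minimal smoke-test sketch (illustrative only, not part of the training or
# inference pipeline). It assumes the module is launched from the Amphion repo
# root so the os.chdir/sys.path lines at the top resolve, and it uses a tiny
# hypothetical configuration so the randomly initialized model builds quickly;
# real checkpoints use the defaults declared in T2SLlama_new.__init__.
if __name__ == "__main__":
    model = T2SLlama_new(
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
    )
    B = 1
    prompt0_ids = torch.randint(0, 1024, (B, 6))  # dummy prompt0 tokens
    phone_ids = torch.randint(0, 1024, (B, 10))  # dummy phone tokens
    target_ids = torch.randint(0, 2048, (B, 12))  # dummy target acoustic tokens
    out = model(
        prompt0_ids,
        torch.ones_like(prompt0_ids),
        phone_ids,
        torch.ones_like(phone_ids),
        target_ids,
        torch.ones_like(target_ids),
    )
    print("loss:", out.loss.item())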