# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from cmath import inf
import io
import librosa
import torch
import json
import tqdm
import numpy as np
import logging
import pickle
import os
import time
from torch.utils.data import Dataset
import torch.nn as nn
import torch.nn.functional as F
from multiprocessing import Pool
import concurrent.futures
from pathlib import Path
from transformers import SeamlessM4TFeatureExtractor
from transformers import Wav2Vec2BertModel

os.chdir("./models/tts/debatts")

import sys

sys.path.append("./models/tts/debatts")
from utils.g2p_new.g2p_new import new_g2p
from torch.nn.utils.rnn import pad_sequence

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class WarningFilter(logging.Filter):
    def filter(self, record):
        # Silence noisy third-party loggers while keeping everything else.
        if record.name == "phonemizer" and record.levelno == logging.WARNING:
            return False
        if record.name == "qcloud_cos.cos_client" and record.levelno == logging.INFO:
            return False
        if record.name == "jieba" and record.levelno == logging.DEBUG:
            return False
        return True


warning_filter = WarningFilter()  # named to avoid shadowing the built-in filter()
logging.getLogger("phonemizer").addFilter(warning_filter)
logging.getLogger("qcloud_cos.cos_client").addFilter(warning_filter)
logging.getLogger("jieba").addFilter(warning_filter)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class T2SDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        cfg=None,
    ):
        self.cfg = cfg

        self.meta_info_path = "Debatts-Data Summary Json"
        with open(self.meta_info_path, "r") as f:
            self.meta_info_data = json.load(f)

        self.wav_paths = []
        self.prompt0_paths = []  # paths of the prompt0 reference utterances
        self.wav_path_index2duration = []
        self.wav_path_index2phonelen = []
        self.wav_path_index2spkid = []
        self.wav_path_index2phoneid = []
        self.index2num_frames = []
        self.index2lang = []
        self.lang2id = {"en": 1, "zh": 2, "ja": 3, "fr": 4, "ko": 5, "de": 6}

        for info in self.meta_info_data:
            # Skip samples that have no prompt0 recording.
            if info["prompt0_wav_path"] is None:
                continue
            self.wav_paths.append(info["wav_path"])
            self.prompt0_paths.append(info["prompt0_wav_path"])
            self.wav_path_index2duration.append(info["duration"])
            self.wav_path_index2phonelen.append(info["phone_count"])
            self.wav_path_index2spkid.append(info["speaker_id"])
            self.wav_path_index2phoneid.append(info["phone_id"])
            # Rough sequence length: semantic features run at ~50 frames/s,
            # plus one position per phone token.
            self.index2num_frames.append(info["duration"] * 50 + len(info["phone_id"]))
            lang_id = self.lang2id[info["language"]]
            self.index2lang.append(lang_id)
            # self.index2num_frames.append(info["duration"] * self.cfg.preprocess.sample_rate)

        # Sample indices sorted by estimated length, useful for
        # length-bucketed batch sampling.
        self.num_frame_indices = np.array(
            sorted(
                range(len(self.index2num_frames)),
                key=lambda k: self.index2num_frames[k],
            )
        )

        self.processor = SeamlessM4TFeatureExtractor.from_pretrained("./w2v-bert-2")
    def new_g2p(self, text, language):
        return new_g2p(text, language)

    def __len__(self):
        return len(self.wav_paths)

    def get_num_frames(self, index):
        # Same estimate as index2num_frames: ~50 frames/s plus phone count.
        return (
            self.wav_path_index2duration[index] * 50
            + self.wav_path_index2phonelen[index]
        )
    def __getitem__(self, idx):
        wav_path = self.wav_paths[idx]
        speech, sr = librosa.load(wav_path, sr=self.cfg.preprocess.sample_rate)
        # Pad to a whole number of hops (no-op if already aligned).
        speech = np.pad(
            speech,
            (0, (-len(speech)) % self.cfg.preprocess.hop_size),
            mode="constant",
        )
        # Resample to 16 kHz for w2v-BERT feature extraction.
        if self.cfg.preprocess.sample_rate != 16000:
            speech_16k = librosa.resample(
                speech, orig_sr=self.cfg.preprocess.sample_rate, target_sr=16000
            )
        else:
            speech_16k = speech
        # Extract w2v-BERT input features and the matching attention mask.
        inputs = self.processor(speech_16k, sampling_rate=16000)
        input_features = inputs["input_features"][0]
        attention_mask = inputs["attention_mask"][0]

        # Same processing for the prompt0 reference utterance.
        prompt0_wav_path = self.prompt0_paths[idx]
        speech_prompt0, sr_prompt0 = librosa.load(
            prompt0_wav_path, sr=self.cfg.preprocess.sample_rate
        )
        speech_prompt0 = np.pad(
            speech_prompt0,
            (0, (-len(speech_prompt0)) % self.cfg.preprocess.hop_size),
            mode="constant",
        )
        if self.cfg.preprocess.sample_rate != 16000:
            speech_16k_prompt0 = librosa.resample(
                speech_prompt0, orig_sr=self.cfg.preprocess.sample_rate, target_sr=16000
            )
        else:
            speech_16k_prompt0 = speech_prompt0
        inputs_prompt0 = self.processor(speech_16k_prompt0, sampling_rate=16000)
        input_features_prompt0 = inputs_prompt0["input_features"][0]
        attention_mask_prompt0 = inputs_prompt0["attention_mask"][0]

        # Frame-level speech masks at the model's hop resolution.
        speech_frames = len(speech) // self.cfg.preprocess.hop_size
        mask = np.ones(speech_frames)
        speech_frames_prompt0 = len(speech_prompt0) // self.cfg.preprocess.hop_size
        mask_prompt0 = np.ones(speech_frames_prompt0)
        del speech, speech_16k, speech_prompt0, speech_16k_prompt0

        lang_id = self.index2lang[idx]
        phone_id = self.wav_path_index2phoneid[idx]
        phone_id = torch.tensor(phone_id, dtype=torch.long)
        phone_mask = np.ones(len(phone_id))
        spk_id = self.wav_path_index2spkid[idx]

        single_feature = {
            "spk_id": spk_id,
            "lang_id": lang_id,
            "phone_id": phone_id,
            "phone_mask": phone_mask,
            "input_features": input_features,
            "attention_mask": attention_mask,
            "mask": mask,
            "input_features_prompt0": input_features_prompt0,
            "attention_mask_prompt0": attention_mask_prompt0,
            "mask_prompt0": mask_prompt0,
        }
        return single_feature
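# NOTE: each entry of the summary JSON loaded in T2SDataset.__init__ is
# assumed to look like the sketch below. The field names are taken from the
# loop above; the values are purely illustrative.
#
# {
#     "wav_path": "/path/to/utterance.wav",
#     "prompt0_wav_path": "/path/to/prompt0.wav",   # may be null; such entries are skipped
#     "duration": 3.21,             # seconds
#     "phone_count": 42,
#     "phone_id": [17, 301, ...],   # token ids from the g2p frontend
#     "speaker_id": 7,
#     "language": "zh"              # one of the keys of self.lang2id
# }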
class T2SCollator(object):
    def __init__(self, cfg):
        self.cfg = cfg

    def __call__(self, batch):
        packed_batch_features = dict()

        def _pad(values, dtype):
            # Right-pad a list of per-utterance arrays/tensors into a batch.
            tensors = [
                v.to(dtype) if isinstance(v, torch.Tensor) else torch.tensor(v).to(dtype)
                for v in values
            ]
            return pad_sequence(tensors, batch_first=True)

        for key in batch[0].keys():
            if "input_features" in key or "semantic_code" in key:
                packed_batch_features[key] = _pad(
                    [utt[key] for utt in batch], torch.float
                )
            elif key == "phone_id":
                packed_batch_features[key] = pad_sequence(
                    [utt[key].long() for utt in batch],
                    batch_first=True,
                    padding_value=1023,  # phone vocab size is 1024
                )
            elif "mask" in key:
                # Covers mask, phone_mask, attention_mask, their prompt0
                # variants, and spk_emb_attention_mask; all padded as long.
                packed_batch_features[key] = _pad(
                    [utt[key] for utt in batch], torch.long
                )
            elif key in ("lang_id", "spk_id"):
                packed_batch_features[key] = torch.tensor(
                    [utt[key] for utt in batch]
                ).long()

        return packed_batch_features
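# Minimal smoke test of T2SCollator padding, invoked from the __main__ guard
# at the bottom of this file. The feature dimension (80) and the frame and
# phone counts are arbitrary assumptions, not values from the real pipeline.
def _demo_collator():
    dummy_batch = []
    for frames, phones in [(120, 12), (90, 9)]:
        dummy_batch.append(
            {
                "spk_id": 0,
                "lang_id": 2,
                "phone_id": torch.randint(0, 1023, (phones,)),
                "phone_mask": np.ones(phones),
                "input_features": np.random.randn(frames, 80).astype(np.float32),
                "attention_mask": np.ones(frames),
                "mask": np.ones(frames // 2),
                "input_features_prompt0": np.random.randn(frames, 80).astype(np.float32),
                "attention_mask_prompt0": np.ones(frames),
                "mask_prompt0": np.ones(frames // 2),
            }
        )
    collated = T2SCollator(cfg=None)(dummy_batch)
    # Every value is padded to the longest utterance in the batch.
    for k, v in collated.items():
        print(k, tuple(v.shape), v.dtype)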
class DownsampleWithMask(nn.Module):
    def __init__(self, downsample_factor=2):
        super().__init__()
        self.downsample_factor = downsample_factor

    def forward(self, x, mask):
        # Accept numpy input and promote to tensors.
        if isinstance(x, np.ndarray):
            x = torch.tensor(x, dtype=torch.float32)
        if isinstance(mask, np.ndarray):
            mask = torch.tensor(mask, dtype=torch.float32)

        x = x.float()
        x = x.permute(1, 0)  # (timestep, feature_dim) -> (feature_dim, timestep)
        x = x.unsqueeze(1)  # add channel dim: (feature_dim, 1, timestep)
        if x.size(-1) < self.downsample_factor:
            raise ValueError("Input size must be larger than downsample factor")
        x = F.avg_pool1d(x, kernel_size=self.downsample_factor)
        x = x.squeeze(1)  # drop channel dim: (feature_dim, timestep // downsample_factor)
        x = x.long()  # truncate back to long; inputs are presumably discrete codes
        x = x.permute(1, 0)  # back to (timestep // downsample_factor, feature_dim)

        mask = mask.float()  # convert mask to float for pooling
        mask = mask.unsqueeze(0).unsqueeze(0)  # (timestep,) -> (1, 1, timestep)
        if mask.size(-1) < self.downsample_factor:
            raise ValueError("Mask size must be larger than downsample factor")
        mask = F.avg_pool1d(
            mask, kernel_size=self.downsample_factor, stride=self.downsample_factor
        )
        mask = mask.squeeze(0).squeeze(0)  # back to (timestep // downsample_factor,)
        mask = (mask >= 0.5).long()  # if average >= 0.5 -> 1, else 0
        return x, mask
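# Hedged usage sketch. Building a real T2SDataset needs the Debatts metadata
# JSON, the audio files it references, and the ./w2v-bert-2 checkpoint (plus
# the repo layout expected by the os.chdir above), so only the lightweight,
# data-free components are exercised here; all shapes are arbitrary.
if __name__ == "__main__":
    _demo_collator()

    # Halve the temporal resolution of a dummy feature/mask pair.
    down = DownsampleWithMask(downsample_factor=2)
    feats = np.random.randn(50, 80).astype(np.float32)  # (timestep, feature_dim)
    frame_mask = np.ones(50, dtype=np.float32)
    feats_ds, mask_ds = down(feats, frame_mask)
    print(feats_ds.shape, mask_ds.shape)  # expect torch.Size([25, 80]) and torch.Size([25])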