"""Mel-spectrogram dataset and dataloader utilities for TTS training.

Each line of ``data_list`` is expected to look like ``path|text|speaker_id``;
the trailing ``|speaker_id`` field is optional and defaults to 0.
"""
import logging
import os
import os.path as osp
import random
import time

import librosa
import numpy as np
import soundfile as sf
import torch
import torch.nn.functional as F
import torchaudio
from torch import nn
from torch.utils.data import DataLoader

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Fixed seeds so shuffling / any random augmentation is reproducible.
np.random.seed(1)
random.seed(1)

# STFT configuration (kept for callers that read it from this module).
SPECT_PARAMS = {
    "n_fft": 2048,
    "win_length": 2048,
    "hop_length": 512,
}
# Mel-filterbank configuration used to build the MelSpectrogram transform.
MEL_PARAMS = {
    "n_mels": 128,
    "sample_rate": 44_100,
    "n_fft": 2048,
    "win_length": 2048,
    "hop_length": 512,
}

# Symbol inventory: pad token, punctuation, Latin letters, IPA symbols, and
# extra characters (digits, brackets, combining diacritics U+035C..U+0362).
_pad = "$"
_punctuation = ';:,.!?¡¿—…"«»“” '
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
_additions = "ー()-~_+=0123456789[]<>/%&*#@◌" + "".join(chr(c) for c in range(860, 867))

# Export all symbols; index 0 is the pad token.
symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa) + list(_additions)

# Symbol -> integer id lookup table.
dicts = {symbol: index for index, symbol in enumerate(symbols)}


class TextCleaner:
    """Maps each character of a transcript to its integer symbol id."""

    def __init__(self, dummy=None):
        # ``dummy`` is unused; kept for backward compatibility with call sites
        # that still pass a dictionary path.
        self.word_index_dictionary = dicts

    def __call__(self, text):
        """Return the list of symbol ids for ``text``.

        Characters missing from the symbol table are skipped; the offending
        transcript is printed so bad data can be spotted in the logs.
        """
        indexes = []
        for char in text:
            try:
                indexes.append(self.word_index_dictionary[char])
            except KeyError:
                print(text)
        return indexes


class MelDataset(torch.utils.data.Dataset):
    """Dataset yielding ``(wave, log_mel, text_ids, path)`` tuples.

    Args:
        data_list: iterable of ``path|text|speaker_id`` lines (the speaker-id
            field is optional and defaults to 0).
        sr: target sample rate; audio read at any other rate is resampled.
    """

    def __init__(self, data_list, sr=44100):
        # rstrip('\n') instead of the former ``l[:-1]``, which silently chopped
        # the last real character off lines lacking a trailing newline.
        _data_list = [line.rstrip('\n').split('|') for line in data_list]
        # Pad a missing speaker-id field with 0.
        self.data_list = [data if len(data) == 3 else (*data, 0)
                          for data in _data_list]
        self.text_cleaner = TextCleaner()
        self.sr = sr
        # Build the mel transform from the shared MEL_PARAMS, honouring the
        # requested sample rate (previously 44_100 was hard-coded here even
        # when a different ``sr`` was passed).
        self.to_melspec = torchaudio.transforms.MelSpectrogram(
            **dict(MEL_PARAMS, sample_rate=sr))
        # Normalisation constants applied to the log-mel features.
        self.mean, self.std = -4, 4

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        data = self.data_list[idx]
        wave, text_tensor, speaker_id = self._load_tensor(data)
        wave_tensor = torch.from_numpy(wave).float()
        mel_tensor = self.to_melspec(wave_tensor)

        # If the mel sequence is too short relative to the text, stretch it so
        # there are at least 3 frames per text token (+1 for the pad tokens).
        if (text_tensor.size(0) + 1) >= (mel_tensor.size(1) // 3):
            mel_tensor = F.interpolate(
                mel_tensor.unsqueeze(0),
                size=(text_tensor.size(0) + 1) * 3,
                align_corners=False,
                mode='linear').squeeze(0)

        acoustic_feature = (torch.log(1e-5 + mel_tensor) - self.mean) / self.std

        # Trim to an even number of frames (downstream models halve the
        # time axis).
        length_feature = acoustic_feature.size(1)
        acoustic_feature = acoustic_feature[:, :(length_feature - length_feature % 2)]

        return wave_tensor, acoustic_feature, text_tensor, data[0]

    def _load_tensor(self, data):
        """Read one manifest entry: load audio, resample, encode transcript."""
        wave_path, text, speaker_id = data
        speaker_id = int(speaker_id)
        wave, sr = sf.read(wave_path)
        # Down-mix multi-channel audio to the first channel.  The former check
        # ``wave.shape[-1] == 2`` also matched a 2-sample mono clip and then
        # crashed on the 2-D indexing; ``ndim`` is unambiguous.
        if wave.ndim == 2:
            wave = wave[:, 0].squeeze()
        if sr != self.sr:
            wave = librosa.resample(wave, orig_sr=sr, target_sr=self.sr)
        text = self.text_cleaner(text)
        # Surround the id sequence with the pad/boundary token (index 0).
        text.insert(0, 0)
        text.append(0)
        text = torch.LongTensor(text)
        return wave, text, speaker_id


class Collater(object):
    """Pads a batch of MelDataset items to common text/mel lengths.

    Args:
        return_wave (bool): if true, also return the raw waveforms.
    """

    def __init__(self, return_wave=False):
        self.text_pad_index = 0
        self.return_wave = return_wave

    def __call__(self, batch):
        batch_size = len(batch)

        # Sort by mel length, longest first.
        lengths = [b[1].shape[1] for b in batch]
        batch_indexes = np.argsort(lengths)[::-1]
        batch = [batch[bid] for bid in batch_indexes]

        nmels = batch[0][1].size(0)
        max_mel_length = max([b[1].shape[1] for b in batch])
        max_text_length = max([b[2].shape[0] for b in batch])

        mels = torch.zeros((batch_size, nmels, max_mel_length)).float()
        texts = torch.zeros((batch_size, max_text_length)).long()
        input_lengths = torch.zeros(batch_size).long()
        output_lengths = torch.zeros(batch_size).long()
        paths = ['' for _ in range(batch_size)]

        for bid, (_, mel, text, path) in enumerate(batch):
            mel_size = mel.size(1)
            text_size = text.size(0)
            mels[bid, :, :mel_size] = mel
            texts[bid, :text_size] = text
            input_lengths[bid] = text_size
            output_lengths[bid] = mel_size
            paths[bid] = path
            # Sanity check: strictly more than 2 mel frames per text token.
            # NOTE(review): ``assert`` is stripped under ``python -O``; kept
            # as-is for parity with the original behaviour.
            assert (text_size < (mel_size // 2))

        if self.return_wave:
            waves = [b[0] for b in batch]
            return texts, input_lengths, mels, output_lengths, paths, waves

        return texts, input_lengths, mels, output_lengths


def build_dataloader(path_list,
                     validation=False,
                     batch_size=4,
                     num_workers=1,
                     device='cpu',
                     collate_config=None,
                     dataset_config=None):
    """Build a DataLoader over a MelDataset.

    Shuffles and drops the last incomplete batch when training
    (``validation=False``); pins memory when ``device`` is not ``'cpu'``.

    Args:
        path_list: manifest lines passed through to :class:`MelDataset`.
        validation: disable shuffling/drop_last for evaluation.
        batch_size / num_workers / device: usual DataLoader knobs.
        collate_config: kwargs for :class:`Collater` (default: none).
        dataset_config: kwargs for :class:`MelDataset` (default: none).
    """
    # ``None`` sentinels instead of mutable ``{}`` default arguments, which
    # are shared across calls.
    collate_config = {} if collate_config is None else collate_config
    dataset_config = {} if dataset_config is None else dataset_config

    dataset = MelDataset(path_list, **dataset_config)
    collate_fn = Collater(**collate_config)
    data_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             shuffle=(not validation),
                             num_workers=num_workers,
                             drop_last=(not validation),
                             collate_fn=collate_fn,
                             pin_memory=(device != 'cpu'))

    return data_loader