Respair committed · Commit fe76d2f · verified · 1 parent: bdb9479

Create meldataset.py

Files changed (1): meldataset.py (+202, -0)
meldataset.py ADDED
import os
import os.path as osp
import time
import random
import numpy as np
import soundfile as sf

import torch
from torch import nn
import torch.nn.functional as F
import torchaudio
from torch.utils.data import DataLoader
# from cotlet.phon import phonemize
# from g2p_en import G2p
import librosa

import logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
# from text_utils import TextCleaner
np.random.seed(1)
random.seed(1)
# DEFAULT_DICT_PATH = osp.join(osp.dirname(__file__), 'word_index_dict.txt')

SPECT_PARAMS = {
    "n_fft": 2048,
    "win_length": 2048,
    "hop_length": 512
}
MEL_PARAMS = {
    "n_mels": 128,
    "sample_rate": 44_100,
    "n_fft": 2048,
    "win_length": 2048,
    "hop_length": 512
}


_pad = "$"
_punctuation = ';:,.!?¡¿—…"«»“” '
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
# Combining double diacritics U+035C..U+0362 are appended via chr() so they stay visible in source.
_additions = "ー()-~_+=0123456789[]<>/%&*#@◌" + "".join(chr(c) for c in range(860, 867))
# Export all symbols:
symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa) + list(_additions)

# Map every symbol to its integer index.
dicts = {symbol: i for i, symbol in enumerate(symbols)}

class TextCleaner:
    def __init__(self, dummy=None):
        self.word_index_dictionary = dicts

    def __call__(self, text):
        indexes = []
        for char in text:
            try:
                indexes.append(self.word_index_dictionary[char])
            except KeyError:
                # Unknown symbol: print the offending text and skip the character.
                print(text)
        return indexes

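# NOTE (addition, not in the original commit): MelDataset below references
# BetaBinomialInterpolator, but this file neither defines nor imports it.
# The following is a minimal sketch of the usual beta-binomial alignment prior
# used by TTS aligners (RAD-TTS / NeMo style); treat it as an assumption about
# the intended helper, not as the author's implementation.
import functools
from scipy.stats import betabinom


def beta_binomial_prior_distribution(phoneme_count, mel_count, scaling_factor=1.0):
    # Returns a (mel_count, phoneme_count) matrix whose i-th row is a
    # beta-binomial pmf concentrated near phoneme index i * phoneme_count / mel_count.
    P, M = phoneme_count, mel_count
    x = np.arange(0, P)
    mel_text_probs = []
    for i in range(1, M + 1):
        a, b = scaling_factor * i, scaling_factor * (M + 1 - i)
        rv = betabinom(P - 1, a, b)
        mel_text_probs.append(rv.pmf(x))
    return np.array(mel_text_probs)


class BetaBinomialInterpolator:
    # Caches priors per (mel_len, text_len) pair so repeated shapes are cheap.
    def __init__(self, scaling_factor=1.0, cache_size=512):
        self.scaling_factor = scaling_factor
        self.bank = functools.lru_cache(maxsize=cache_size)(self._generate)

    def _generate(self, mel_len, text_len):
        return beta_binomial_prior_distribution(text_len, mel_len,
                                                scaling_factor=self.scaling_factor)

    def __call__(self, mel_len, text_len):
        return self.bank(mel_len, text_len)

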
class MelDataset(torch.utils.data.Dataset):
    def __init__(self,
                 data_list,
                 sr=44100,
                 scaling_factor=1.0  # sharpness of the beta-binomial attention prior
                 ):

        # Each line is "wave_path|text|speaker_id"; default the speaker id to 0 if missing.
        _data_list = [l.rstrip('\n').split('|') for l in data_list]
        self.data_list = [data if len(data) == 3 else (*data, 0) for data in _data_list]
        self.text_cleaner = TextCleaner()
        self.sr = sr

        self.to_melspec = torchaudio.transforms.MelSpectrogram(**MEL_PARAMS)
        self.mean, self.std = -4, 4

        # Beta-binomial interpolator for the monotonic attention prior
        self.beta_binomial_interpolator = BetaBinomialInterpolator(scaling_factor=scaling_factor)

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        data = self.data_list[idx]
        wave, text_tensor, speaker_id = self._load_tensor(data)
        wave_tensor = torch.from_numpy(wave).float()
        mel_tensor = self.to_melspec(wave_tensor)

        # If the mel is too short for the text, stretch it so each token gets roughly 3 frames.
        if (text_tensor.size(0) + 1) >= (mel_tensor.size(1) // 3):
            mel_tensor = F.interpolate(
                mel_tensor.unsqueeze(0), size=(text_tensor.size(0) + 1) * 3, align_corners=False,
                mode='linear').squeeze(0)

        acoustic_feature = (torch.log(1e-5 + mel_tensor) - self.mean) / self.std

        # Trim to an even number of frames.
        length_feature = acoustic_feature.size(1)
        acoustic_feature = acoustic_feature[:, :(length_feature - length_feature % 2)]

        # Generate the attention prior matrix (mel_len x text_len).
        text_len = text_tensor.size(0)
        mel_len = acoustic_feature.size(1)
        attn_prior = torch.from_numpy(self.beta_binomial_interpolator(mel_len, text_len)).float()

        return wave_tensor, acoustic_feature, text_tensor, attn_prior, data[0]

    def _load_tensor(self, data):
        wave_path, text, speaker_id = data
        speaker_id = int(speaker_id)
        wave, sr = sf.read(wave_path)
        if wave.ndim == 2:
            # Stereo file: keep the first channel only.
            wave = wave[:, 0]
        if sr != 44100:
            wave = librosa.resample(wave, orig_sr=sr, target_sr=44100)

        text = self.text_cleaner(text)

        # Pad the token sequence with the index of "$" on both sides.
        text.insert(0, 0)
        text.append(0)

        text = torch.LongTensor(text)

        return wave, text, speaker_id

# Collate a batch: pad mels, texts and attention priors to the longest item.
class Collater(object):
    def __init__(self, return_wave=False):
        self.text_pad_index = 0
        self.return_wave = return_wave

    def __call__(self, batch):
        batch_size = len(batch)

        # sort by mel length (longest first)
        lengths = [b[1].shape[1] for b in batch]
        batch_indexes = np.argsort(lengths)[::-1]
        batch = [batch[bid] for bid in batch_indexes]

        nmels = batch[0][1].size(0)
        max_mel_length = max([b[1].shape[1] for b in batch])
        max_text_length = max([b[2].shape[0] for b in batch])

        mels = torch.zeros((batch_size, nmels, max_mel_length)).float()
        texts = torch.zeros((batch_size, max_text_length)).long()
        input_lengths = torch.zeros(batch_size).long()
        output_lengths = torch.zeros(batch_size).long()

        # Attention priors, zero-padded to (max_mel_length, max_text_length)
        attn_priors = torch.zeros((batch_size, max_mel_length, max_text_length)).float()

        paths = ['' for _ in range(batch_size)]

        for bid, (_, mel, text, attn_prior, path) in enumerate(batch):
            mel_size = mel.size(1)
            text_size = text.size(0)
            mels[bid, :, :mel_size] = mel
            texts[bid, :text_size] = text
            input_lengths[bid] = text_size
            output_lengths[bid] = mel_size

            # Copy the per-sample attention prior into the padded tensor
            attn_priors[bid, :mel_size, :text_size] = attn_prior

            paths[bid] = path
            assert text_size < (mel_size // 2)

        if self.return_wave:
            waves = [b[0] for b in batch]
            return texts, input_lengths, mels, output_lengths, attn_priors, paths, waves

        return texts, input_lengths, mels, output_lengths, attn_priors

# Build a DataLoader over MelDataset using the attention-prior-aware Collater.
def build_dataloader(path_list,
                     validation=False,
                     batch_size=4,
                     num_workers=1,
                     device='cpu',
                     collate_config=None,
                     dataset_config=None):

    # Avoid mutable default arguments.
    collate_config = collate_config or {}
    dataset_config = dataset_config or {}

    dataset = MelDataset(path_list, **dataset_config)
    collate_fn = Collater(**collate_config)
    data_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             shuffle=(not validation),
                             num_workers=num_workers,
                             drop_last=(not validation),
                             collate_fn=collate_fn,
                             pin_memory=(device != 'cpu'))

    return data_loader
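
For reference, a minimal usage sketch; the file name train_list.txt and its pipe-delimited wave_path|text|speaker_id rows are assumptions, not part of this commit:

# Hypothetical example: build a training loader from a pipe-delimited list file.
with open('train_list.txt', 'r', encoding='utf-8') as f:
    train_list = f.readlines()

train_loader = build_dataloader(train_list,
                                validation=False,
                                batch_size=4,
                                num_workers=2,
                                device='cuda')

texts, input_lengths, mels, output_lengths, attn_priors = next(iter(train_loader))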