# Stray HuggingFace page header ("Spaces: Running on Zero") from the web
# scrape, kept as a comment so the file parses as Python.
import itertools
import random

import numpy as np
import pandas as pd
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
def normalize_wav(waveform):
    """Zero-center a waveform and scale its peak absolute amplitude to 0.5.

    The 1e-8 epsilon keeps the division safe when the input is all zeros.
    """
    centered = waveform - torch.mean(waveform)
    peak = torch.max(torch.abs(centered))
    scaled = centered / (peak + 1e-8)
    return scaled * 0.5
def pad_wav(waveform, segment_length):
    """Force a 1-D waveform to exactly `segment_length` samples.

    Truncates when too long, right-pads with zeros when too short;
    `segment_length=None` returns the waveform unchanged.

    Fix: the zero padding now inherits the waveform's dtype as well as its
    device. The original created default-dtype (float32) zeros, so e.g. a
    float16 input came back silently promoted by `torch.cat` instead of
    keeping its dtype.
    """
    waveform_length = len(waveform)
    if segment_length is None or waveform_length == segment_length:
        return waveform
    if waveform_length > segment_length:
        return waveform[:segment_length]
    padding = torch.zeros(
        segment_length - waveform_length,
        dtype=waveform.dtype,
        device=waveform.device,
    )
    return torch.cat([waveform, padding])
def read_wav_file(filename, duration_sec, target_sr=44100):
    """Load up to `duration_sec` seconds of audio, resample, and fix length.

    Args:
        filename: path to an audio file readable by torchaudio.
        duration_sec: desired clip length in seconds.
        target_sr: output sample rate. Defaults to 44100, the value the
            original hard-coded, so existing callers are unaffected.

    Returns:
        Tensor of shape (2, n) for two-channel input, (1, n) otherwise,
        where n = int(target_sr * duration_sec).
    """
    info = torchaudio.info(filename)
    # Read only the frames we need at the file's native rate -- faster than
    # loading the whole file and slicing afterwards.
    num_frames = int(info.sample_rate * duration_sec)
    waveform, sr = torchaudio.load(filename, num_frames=num_frames)
    segment_length = int(target_sr * duration_sec)
    if waveform.shape[0] == 2:  # stereo: pad each channel separately
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)
        resampled = resampler(waveform)
        padded_left = pad_wav(resampled[0], segment_length)
        padded_right = pad_wav(resampled[1], segment_length)
        return torch.stack([padded_left, padded_right])
    # Any non-2-channel input: resample, then keep only the first channel
    # (matches the original behavior -- TODO confirm this is intended for
    # inputs with more than two channels).
    mono = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=target_sr)[0]
    return pad_wav(mono, segment_length).unsqueeze(0)
class DPOText2AudioDataset(Dataset):
    """Preference-pair (winner/loser audio) text-to-audio dataset for DPO.

    Each item is a tuple:
        (prefixed_text, winning_audio, losing_audio, duration, index).
    """

    def __init__(
        self,
        dataset,
        prefix,
        text_column,
        audio_w_column,
        audio_l_column,
        duration,
        num_examples=-1,
    ):
        raw_texts = list(dataset[text_column])
        self.inputs = [prefix + text for text in raw_texts]
        self.audios_w = list(dataset[audio_w_column])
        self.audios_l = list(dataset[audio_l_column])
        self.durations = list(dataset[duration])
        self.indices = list(range(len(self.inputs)))

        # index -> [winner audio, loser audio, duration, raw (unprefixed) text].
        # NOTE: built over the full dataset, before any num_examples truncation.
        self.mapper = {
            idx: [win, lose, dur, text]
            for idx, win, lose, dur, text in zip(
                self.indices, self.audios_w, self.audios_l, self.durations, raw_texts
            )
        }

        if num_examples != -1:
            self.inputs = self.inputs[:num_examples]
            self.audios_w = self.audios_w[:num_examples]
            self.audios_l = self.audios_l[:num_examples]
            self.durations = self.durations[:num_examples]
            self.indices = self.indices[:num_examples]

    def __len__(self):
        return len(self.inputs)

    def get_num_instances(self):
        return len(self.inputs)

    def __getitem__(self, index):
        return (
            self.inputs[index],
            self.audios_w[index],
            self.audios_l[index],
            self.durations[index],
            self.indices[index],
        )

    def collate_fn(self, data):
        # Transpose a list of item tuples into one list per field.
        frame = pd.DataFrame(data)
        return [frame[column].tolist() for column in frame]
class Text2AudioDataset(Dataset):
    """Simple text-to-audio dataset.

    Each item is a tuple: (prefixed_text, audio, duration, index).
    """

    def __init__(
        self, dataset, prefix, text_column, audio_column, duration, num_examples=-1
    ):
        raw_texts = list(dataset[text_column])
        self.inputs = [prefix + text for text in raw_texts]
        self.audios = list(dataset[audio_column])
        self.durations = list(dataset[duration])
        self.indices = list(range(len(self.inputs)))

        # index -> [audio, raw (unprefixed) text, duration].
        # NOTE: built over the full dataset, before any num_examples truncation.
        self.mapper = {
            idx: [audio, text, dur]
            for idx, audio, dur, text in zip(
                self.indices, self.audios, self.durations, raw_texts
            )
        }

        if num_examples != -1:
            self.inputs = self.inputs[:num_examples]
            self.audios = self.audios[:num_examples]
            self.durations = self.durations[:num_examples]
            self.indices = self.indices[:num_examples]

    def __len__(self):
        return len(self.inputs)

    def get_num_instances(self):
        return len(self.inputs)

    def __getitem__(self, index):
        return (
            self.inputs[index],
            self.audios[index],
            self.durations[index],
            self.indices[index],
        )

    def collate_fn(self, data):
        # Transpose a list of item tuples into one list per field.
        frame = pd.DataFrame(data)
        return [frame[column].tolist() for column in frame]