import logging
import os
from pathlib import Path
from typing import Optional, Union

import open_clip
import pandas as pd
import torch
import torchaudio
from torch.utils.data.dataset import Dataset

log = logging.getLogger()


class WavTextClipsDataset(Dataset):

    def __init__(
        self,
        root: Union[str, Path],
        *,
        captions_tsv: Union[str, Path],
        clips_tsv: Union[str, Path],
        sample_rate: int,
        num_samples: int,
        normalize_audio: bool = False,
        reject_silent: bool = False,
        tokenizer_id: str = 'ViT-H-14-378-quickgelu',
    ):
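        """Dataset pairing fixed-length audio clips with their text captions.

        Expected TSV schemas, inferred from how the columns are used below:
          - captions_tsv: columns 'id' and 'caption'.
          - clips_tsv: columns 'id', 'name', 'start_sample', and 'end_sample',
            where 'name' is the stem of a .wav/.flac file under `root`.
        """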
        self.root = Path(root)
        self.sample_rate = sample_rate
        self.num_samples = num_samples
        self.normalize_audio = normalize_audio
        self.reject_silent = reject_silent
        self.tokenizer = open_clip.get_tokenizer(tokenizer_id)

        # Index the audio files that exist on disk, keyed by file stem.
        audios = sorted(os.listdir(self.root))
        audios = {
            Path(audio).stem
            for audio in audios
            if audio.endswith(('.wav', '.flac'))
        }

        self.captions = {}

        # Read the captions TSV: maps clip id -> caption.
        df_list = pd.read_csv(captions_tsv, sep='\t', dtype={'id': str}).to_dict('records')
        for record in df_list:
            self.captions[record['id']] = record['caption']

        # Read the clips TSV and attach the matching caption to each clip,
        # skipping clips whose audio file or caption is missing.
        df_list = pd.read_csv(clips_tsv, sep='\t', dtype={
            'id': str,
            'name': str,
        }).to_dict('records')
        self.clips = []
        for record in df_list:
            name = record['name']
            if name not in audios:
                log.warning(f'Audio file for {name} not found in {self.root}')
                continue
            if name not in self.captions:
                log.warning(f'Audio {name} not found in {captions_tsv}')
                continue
            record['caption'] = self.captions[name]
            self.clips.append(record)

        log.info(f'Found {len(self.clips)} audio files in {self.root}')

        # Cache one Resample transform per source sample rate.
        self.resampler = {}

    def __getitem__(self, idx: int) -> Optional[dict]:
        try:
            clip = self.clips[idx]
            audio_name = clip['name']
            audio_id = clip['id']
            caption = clip['caption']
            start_sample = clip['start_sample']
            end_sample = clip['end_sample']

            audio_path = self.root / f'{audio_name}.flac'
            if not audio_path.exists():
                audio_path = self.root / f'{audio_name}.wav'
                assert audio_path.exists(), f'No .flac or .wav file found for {audio_name}'

            audio_chunk, sample_rate = torchaudio.load(audio_path)
            audio_chunk = audio_chunk.mean(dim=0)  # downmix to mono

            # Check for silence before normalizing, so all-zero audio never
            # reaches the division below.
            abs_max = audio_chunk.abs().max()
            if self.reject_silent and abs_max < 1e-6:
                log.warning(f'Rejecting silent audio: {audio_name}')
                return None
            if self.normalize_audio:
                # Peak-normalize to 0.95 full scale; clamp guards against
                # division by zero on silent audio.
                audio_chunk = audio_chunk / abs_max.clamp(min=1e-6) * 0.95

            audio_chunk = audio_chunk[start_sample:end_sample]

            # Resample to the target rate if needed (kaiser_best parameters).
            if sample_rate != self.sample_rate:
                if sample_rate not in self.resampler:
                    # https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
                    self.resampler[sample_rate] = torchaudio.transforms.Resample(
                        sample_rate,
                        self.sample_rate,
                        lowpass_filter_width=64,
                        rolloff=0.9475937167399596,
                        resampling_method='sinc_interp_kaiser',
                        beta=14.769656459379492,
                    )
                audio_chunk = self.resampler[sample_rate](audio_chunk)

            # Trim to a fixed length so every item has the same shape.
            if audio_chunk.shape[0] < self.num_samples:
                raise ValueError('Audio is too short')
            audio_chunk = audio_chunk[:self.num_samples]

            tokens = self.tokenizer([caption])[0]

            output = {
                'waveform': audio_chunk,
                'id': audio_id,
                'caption': caption,
                'tokens': tokens,
            }
            return output
        except Exception as e:
            # audio_path may be unbound if the failure happened before it was
            # assigned, so log the clip index instead.
            log.error(f'Error reading clip {idx}: {e}')
            return None

    def __len__(self):
        return len(self.clips)
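

if __name__ == '__main__':
    # Minimal usage sketch (not part of the original file). The paths, sample
    # rate, and clip length below are assumptions for illustration only.
    from torch.utils.data import DataLoader
    from torch.utils.data.dataloader import default_collate

    def skip_none_collate(batch):
        # __getitem__ returns None for silent or unreadable clips; drop those
        # entries so default_collate never sees a None.
        batch = [item for item in batch if item is not None]
        return default_collate(batch) if batch else None

    logging.basicConfig(level=logging.INFO)
    dataset = WavTextClipsDataset(
        'data/audio',                      # assumed directory of .wav/.flac files
        captions_tsv='data/captions.tsv',  # assumed path
        clips_tsv='data/clips.tsv',        # assumed path
        sample_rate=16000,
        num_samples=16000 * 8,             # 8-second clips at 16 kHz
    )
    loader = DataLoader(dataset, batch_size=4, collate_fn=skip_none_collate)
    for batch in loader:
        if batch is None:
            continue
        print(batch['waveform'].shape, batch['tokens'].shape)
        break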