import logging
import os
from pathlib import Path
from typing import Optional, Union

import open_clip
import pandas as pd
import torch
import torchaudio
from torch.utils.data.dataset import Dataset

log = logging.getLogger()


class WavTextClipsDataset(Dataset):

    def __init__(
        self,
        root: Union[str, Path],
        *,
        captions_tsv: Union[str, Path],
        clips_tsv: Union[str, Path],
        sample_rate: int,
        num_samples: int,
        normalize_audio: bool = False,
        reject_silent: bool = False,
        tokenizer_id: str = 'ViT-H-14-378-quickgelu',
    ):
        self.root = Path(root)
        self.sample_rate = sample_rate
        self.num_samples = num_samples
        self.normalize_audio = normalize_audio
        self.reject_silent = reject_silent
        self.tokenizer = open_clip.get_tokenizer(tokenizer_id)

        # stems of the audio files available under root
        audios = sorted(os.listdir(self.root))
        audios = set(
            Path(audio).stem for audio in audios
            if audio.endswith('.wav') or audio.endswith('.flac'))
        self.captions = {}

        # read the caption tsv; maps audio id -> caption
        df_list = pd.read_csv(captions_tsv, sep='\t', dtype={'id': str}).to_dict('records')
        for record in df_list:
            id = record['id']
            caption = record['caption']
            self.captions[id] = caption

        # read the clip tsv; each record carries a clip id, the name of the
        # source audio, and the start/end sample of the clip within it
        df_list = pd.read_csv(clips_tsv, sep='\t', dtype={
            'id': str,
            'name': str
        }).to_dict('records')
        self.clips = []
        for record in df_list:
            name = record['name']
            if name not in self.captions:
                log.warning(f'Audio {name} not found in {captions_tsv}')
                continue
            record['caption'] = self.captions[name]
            self.clips.append(record)

        log.info(f'Found {len(self.clips)} audio files in {self.root}')

        # cache one resampler per source sample rate encountered
        self.resampler = {}

    def __getitem__(self, idx: int) -> Optional[dict]:
        audio_path = None
        try:
            clip = self.clips[idx]
            audio_name = clip['name']
            audio_id = clip['id']
            caption = clip['caption']
            start_sample = clip['start_sample']
            end_sample = clip['end_sample']

            audio_path = self.root / f'{audio_name}.flac'
            if not audio_path.exists():
                audio_path = self.root / f'{audio_name}.wav'
                assert audio_path.exists()

            audio_chunk, sample_rate = torchaudio.load(audio_path)
            audio_chunk = audio_chunk.mean(dim=0)  # downmix to mono
            abs_max = audio_chunk.abs().max()

            # reject silent audio before normalizing, so we never divide by ~0
            if self.reject_silent and abs_max < 1e-6:
                log.warning(f'Rejecting silent audio {audio_name}')
                return None

            if self.normalize_audio:
                audio_chunk = audio_chunk / abs_max * 0.95

            audio_chunk = audio_chunk[start_sample:end_sample]

            # resample to the target rate if needed
            if sample_rate != self.sample_rate:
                if sample_rate not in self.resampler:
                    # https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
                    self.resampler[sample_rate] = torchaudio.transforms.Resample(
                        sample_rate,
                        self.sample_rate,
                        lowpass_filter_width=64,
                        rolloff=0.9475937167399596,
                        resampling_method='sinc_interp_kaiser',
                        beta=14.769656459379492,
                    )
                audio_chunk = self.resampler[sample_rate](audio_chunk)

            if audio_chunk.shape[0] < self.num_samples:
                raise ValueError('Audio is too short')
            audio_chunk = audio_chunk[:self.num_samples]

            tokens = self.tokenizer([caption])[0]

            output = {
                'waveform': audio_chunk,
                'id': audio_id,
                'caption': caption,
                'tokens': tokens,
            }

            return output
        except Exception as e:
            log.error(f'Error reading {audio_path}: {e}')
            return None

    def __len__(self):
        return len(self.clips)
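

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the dataset). Because
# __getitem__ returns None for unreadable, silent, or too-short clips, a
# DataLoader needs a collate function that drops None entries before batching.
# The tsv paths, sample rate, and clip length below are assumptions chosen
# for demonstration; substitute your own.
# ---------------------------------------------------------------------------
def _collate_drop_none(batch):
    # Filter out failed samples so one bad file does not crash the batch.
    batch = [b for b in batch if b is not None]
    if len(batch) == 0:
        return None  # caller should skip entirely-failed batches
    return torch.utils.data.default_collate(batch)


if __name__ == '__main__':
    from torch.utils.data import DataLoader

    logging.basicConfig(level=logging.INFO)

    # Hypothetical paths and parameters, shown only to illustrate the API.
    dataset = WavTextClipsDataset(
        'data/audio',
        captions_tsv='data/captions.tsv',
        clips_tsv='data/clips.tsv',
        sample_rate=16000,
        num_samples=16000 * 8,  # 8-second clips at 16 kHz
        normalize_audio=True,
        reject_silent=True,
    )
    loader = DataLoader(dataset, batch_size=4, collate_fn=_collate_drop_none)
    for batch in loader:
        if batch is None:
            continue
        print(batch['waveform'].shape, batch['tokens'].shape)
        break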