File size: 4,381 Bytes
73ed896
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import logging
import os
from pathlib import Path
from typing import Union

import open_clip
import pandas as pd
import torch
import torchaudio
from torch.utils.data.dataset import Dataset

log = logging.getLogger()


class WavTextClipsDataset(Dataset):
    """Map-style dataset of captioned clips cut from .flac/.wav files.

    Audio files live directly under ``root`` as ``<name>.flac`` or
    ``<name>.wav`` (.flac is tried first). Two tab-separated tables describe
    the data:

    * ``captions_tsv`` -- columns ``id`` and ``caption``. Its ``id`` is
      matched against the ``name`` column of ``clips_tsv``.
    * ``clips_tsv`` -- columns ``id``, ``name``, ``start_sample`` and
      ``end_sample``, one row per clip. Clips whose ``name`` has no caption
      are skipped with a warning.

    Each item is a dict with ``waveform`` (mono, exactly ``num_samples``
    samples at ``sample_rate``), ``id``, ``caption`` and ``tokens`` (the
    caption encoded by the open_clip tokenizer), or ``None`` when the clip
    could not be produced (see ``__getitem__``).
    """

    def __init__(
        self,
        root: Union[str, Path],
        *,
        captions_tsv: Union[str, Path],
        clips_tsv: Union[str, Path],
        sample_rate: int,
        num_samples: int,
        normalize_audio: bool = False,
        reject_silent: bool = False,
        tokenizer_id: str = 'ViT-H-14-378-quickgelu',
    ):
        """Read both TSV tables and build the list of captioned clips.

        Args:
            root: directory containing the audio files.
            captions_tsv: TSV with ``id`` and ``caption`` columns.
            clips_tsv: TSV with ``id``, ``name``, ``start_sample`` and
                ``end_sample`` columns.
            sample_rate: sample rate of the returned waveforms.
            num_samples: length (at ``sample_rate``) of every returned waveform.
            normalize_audio: scale each file so its peak amplitude is 0.95.
            reject_silent: return ``None`` for files whose peak amplitude is
                below 1e-6.
            tokenizer_id: open_clip model name whose tokenizer encodes captions.

        Raises:
            FileNotFoundError: if ``root`` is not an existing directory.
        """
        self.root = Path(root)
        # The original code listed the directory (and discarded the result);
        # its only effect was failing fast on a bad root, kept explicitly here.
        if not self.root.is_dir():
            raise FileNotFoundError(f'Audio root {self.root} is not an existing directory')

        self.sample_rate = sample_rate
        self.num_samples = num_samples
        self.normalize_audio = normalize_audio
        self.reject_silent = reject_silent
        self.tokenizer = open_clip.get_tokenizer(tokenizer_id)

        # Caption lookup keyed by the captions table's `id`; later rows win on duplicates.
        caption_records = pd.read_csv(captions_tsv, sep='\t', dtype={'id': str}).to_dict('records')
        self.captions = {record['id']: record['caption'] for record in caption_records}

        clip_records = pd.read_csv(clips_tsv, sep='\t', dtype={
            'id': str,
            'name': str
        }).to_dict('records')
        self.clips = []
        for record in clip_records:
            name = record['name']
            if name not in self.captions:
                log.warning(f'Audio {name} not found in {captions_tsv}')
                continue
            record['caption'] = self.captions[name]
            self.clips.append(record)

        log.info(f'Found {len(self.clips)} audio files in {self.root}')

        # Resample transforms cached per source sample rate (see _get_resampler).
        self.resampler = {}

    def _get_resampler(self, orig_sample_rate: int):
        """Return a cached Resample transform from ``orig_sample_rate`` to ``self.sample_rate``."""
        if orig_sample_rate not in self.resampler:
            # "kaiser_best" settings from
            # https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
            self.resampler[orig_sample_rate] = torchaudio.transforms.Resample(
                orig_sample_rate,
                self.sample_rate,
                lowpass_filter_width=64,
                rolloff=0.9475937167399596,
                resampling_method='sinc_interp_kaiser',
                beta=14.769656459379492,
            )
        return self.resampler[orig_sample_rate]

    def __getitem__(self, idx: int) -> Union[dict, None]:
        """Load, trim, resample and tokenize one clip.

        ``start_sample``/``end_sample`` index the file at its native sample
        rate. The slice is then resampled to ``self.sample_rate`` and cut to
        exactly ``self.num_samples`` samples. Normalization and the silence
        check use the peak of the whole file, not just the clip.

        Returns:
            A dict with keys ``waveform``, ``id``, ``caption`` and ``tokens``,
            or ``None`` when the audio is missing, unreadable, too short, or
            rejected as silent (the reason is logged). Callers presumably drop
            ``None`` items in their collate function -- confirm with caller.

        Raises:
            IndexError: if ``idx`` is out of range. This lets plain Python
                iteration over the dataset terminate.
            KeyError: if the clip record lacks a required column.
        """
        # Metadata lookups stay outside the try: a bad index or schema must raise,
        # not be logged and silently turned into None.
        clip = self.clips[idx]
        audio_name = clip['name']
        audio_id = clip['id']
        caption = clip['caption']
        start_sample = clip['start_sample']
        end_sample = clip['end_sample']

        audio_path = self.root / f'{audio_name}.flac'
        try:
            if not audio_path.exists():
                audio_path = self.root / f'{audio_name}.wav'
                if not audio_path.exists():
                    # Explicit raise rather than `assert`, which is stripped under -O.
                    raise FileNotFoundError(f'No .flac or .wav file for {audio_name}')

            audio_chunk, sample_rate = torchaudio.load(audio_path)
            audio_chunk = audio_chunk.mean(dim=0)  # mono
            abs_max = audio_chunk.abs().max()

            if self.reject_silent and abs_max < 1e-6:
                log.warning(f'Rejecting silent audio: {audio_path}')
                return None
            # Skip normalization of an all-zero signal, which would yield NaNs (0 / 0).
            if self.normalize_audio and abs_max > 0:
                audio_chunk = audio_chunk / abs_max * 0.95

            # Slice in the file's native sample rate, before resampling.
            audio_chunk = audio_chunk[start_sample:end_sample]

            if sample_rate != self.sample_rate:
                audio_chunk = self._get_resampler(sample_rate)(audio_chunk)

            if audio_chunk.shape[0] < self.num_samples:
                raise ValueError('Audio is too short')
            audio_chunk = audio_chunk[:self.num_samples]

            tokens = self.tokenizer([caption])[0]

            return {
                'waveform': audio_chunk,
                'id': audio_id,
                'caption': caption,
                'tokens': tokens,
            }
        except Exception as e:
            # Best-effort loading: a single bad file is logged and skipped rather than
            # aborting the whole epoch. `audio_path` is always bound at this point.
            log.error(f'Error reading {audio_path}: {e}')
            return None

    def __len__(self):
        """Number of captioned clips."""
        return len(self.clips)