# Source: demo/mmaudio/data/eval/audiocaps.py
# Author: Phil Sobrepena — commit 73ed896 ("initial commit"), 1.05 kB
import logging
import os
from collections import defaultdict
from pathlib import Path
from typing import Union
import pandas as pd
import torch
from torch.utils.data.dataset import Dataset
# Module-level logger named after this module: records still propagate to the
# root logger's handlers, but this module's verbosity can be tuned on its own.
log = logging.getLogger(__name__)
class AudioCapsData(Dataset):
    """Caption-only evaluation dataset built from an AudioCaps-style CSV.

    Each item is a dict ``{'name': ..., 'caption': ...}`` copied from the
    ``name`` and ``caption`` columns of ``csv_path``; no audio is loaded here.
    ``audio_path`` is scanned for ``.wav`` / ``.flac`` files so CSV rows can be
    cross-checked against (and optionally restricted to) the audio on disk.

    Args:
        audio_path: Directory containing the reference audio files.
        csv_path: CSV file with at least ``name`` and ``caption`` columns.
        filter_missing: If True, keep only rows whose ``name`` (as a string)
            equals the stem of a ``.wav`` / ``.flac`` file in ``audio_path``.
            Defaults to False, which keeps every CSV row.
            NOTE(review): assumes ``name`` holds file stems without an
            extension — confirm against the CSV.

    Raises:
        FileNotFoundError: e.g. if ``csv_path`` or ``audio_path`` does not exist.
        KeyError: If a kept row lacks a ``name`` or ``caption`` column.
    """

    def __init__(self,
                 audio_path: Union[str, Path],
                 csv_path: Union[str, Path],
                 filter_missing: bool = False):
        self.audio_path = Path(audio_path)
        self.csv_path = Path(csv_path)

        rows = pd.read_csv(csv_path).to_dict(orient='records')

        # Stems of the available reference audio; used to count (and, with
        # filter_missing, enforce) which CSV rows have audio on disk.
        audio_stems = {
            Path(f).stem
            for f in os.listdir(audio_path) if f.endswith(('.wav', '.flac'))
        }

        self.data = []
        num_matched = 0
        for row in rows:
            # str(): pandas may parse purely numeric names as int/float.
            has_audio = str(row['name']) in audio_stems
            if has_audio:
                num_matched += 1
            elif filter_missing:
                continue
            self.data.append({
                'name': row['name'],
                'caption': row['caption'],
            })

        log.info(f'Loaded {len(self.data)} captions from {self.csv_path}; '
                 f'{num_matched}/{len(rows)} rows have matching audio files in '
                 f'{self.audio_path}')

    def __getitem__(self, idx: int) -> dict:
        """Return the ``{'name', 'caption'}`` record stored at position ``idx``."""
        return self.data[idx]

    def __len__(self) -> int:
        """Return the number of caption records kept."""
        return len(self.data)