# AudioMorphix/src/data/dataset.py
import os
import random
from typing import Optional, Union

import torch
from torch.utils.data import Dataset

from .processors import NaiveAudioProcessor, WaveformAudioProcessor, FbankAudioProcessor
from ..utils import load_json

def label2caption(label, background_sound=None, template="{} can be heard"):
r"""This is a helper function converting list of labels to captions."""
if background_sound is None:
return [template.format(", ".join(l)) for l in label]
if isinstance(background_sound, str):
background_sound = [[background_sound]] * len(label)
assert len(label) == len(
background_sound
), "the number of `background_sound` should match the number of `label`."
caption = []
for l, bg in zip(label, background_sound):
cap = template.format(", ".join(l))
cap += " with the background sounds of {}".format(", ".join(bg))
caption.append(cap)
return caption
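
# Illustrative usage of `label2caption` (hypothetical labels, not from a dataset):
#   label2caption([["dog barking", "rain"]])
#   -> ["dog barking, rain can be heard"]
#   label2caption([["speech"]], background_sound="traffic")
#   -> ["speech can be heard with the background sounds of traffic"]
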
class AudioDataset(Dataset):
    def __init__(
        self,
        metadata_root: str = "/mnt/bn/lqhaoheliu/metadata/processed/dataset_root.json",
        dataset_name: Union[str, list, None] = None,
        split: str = "train",
        include_caption: bool = True,
        enable_mixup: bool = False,
        audio_processor: Optional[NaiveAudioProcessor] = None,
    ):
"""
Dataset that manages audio recordings.
:param audio_conf: Dictionary containing the audio loading and preprocessing settings
:param dataset_json_file
"""
        self.metadata_root = load_json(metadata_root)
        # Resolve defaults here to avoid mutable/stateful default arguments.
        self.dataset_name = dataset_name if dataset_name is not None else ["audioset"]
        self.split = split
        self.include_caption = include_caption
        self.audio_processor = (
            audio_processor if audio_processor is not None else NaiveAudioProcessor()
        )
        self.enable_mixup = enable_mixup
        self.mixture_caption_template = "{} | {}"
if self.enable_mixup:
print(
f"Template for the caption of mixture is: {self.mixture_caption_template}"
)
self.build_dataset()
print("Dataset initialization finished.")
    def __getitem__(self, index):
        datum = self.data[index]
        fname = datum["wav"]  # absolute path of the wav file
        mix_datum = {"wav": None}
        if self.enable_mixup:
            # With probability 0.5, mix the sample with a second random recording.
            if random.random() > 0.5:
                mix_datum = self.data[random.randint(0, len(self.data) - 1)]
                fname += " " + mix_datum["wav"]
data = {"fname": fname}
if self.include_caption:
caption = self.get_caption_from_datum(
datum,
mix_datum,
template_description=self.mixture_caption_template,
)
data.update({"caption": caption})
data.update(self.audio_processor(datum["wav"], mix_datum["wav"]))
return data
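
    # A returned sample is a dict shaped roughly as below; keys other than
    # "fname"/"caption" depend on the configured audio processor ("waveform"
    # and "log_mel_spec" are what FbankAudioProcessor is expected to yield):
    #   {"fname": "/abs/a.wav /abs/b.wav", "caption": "dog barking | rain",
    #    "waveform": Tensor, "log_mel_spec": Tensor, ...}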
def text_to_filename(self, text):
return text.replace(" ", "_").replace("'", "_").replace('"', "_")
def get_dataset_root_path(self, dataset):
assert dataset in self.metadata_root.keys()
return self.metadata_root[dataset]
    def get_dataset_metadata_path(self, dataset, key):
        # key: train, test, val, class_label_indices
        try:
            return self.metadata_root["metadata"]["path"][dataset][key]
        except KeyError:
            raise ValueError(
                'Dataset %s does not have metadata "%s" specified' % (dataset, key)
            )
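
    # The dataset-root JSON is expected to look roughly like the sketch below
    # (illustrative layout inferred from the accesses above, not a shipped file):
    #   {
    #     "audiocaps": "/data/audiocaps",
    #     "metadata": {"path": {"audiocaps": {"train": "/meta/audiocaps_train.json"}}}
    #   }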
def __len__(self):
return len(self.data)
def _relative_path_to_absolute_path(self, metadata, dataset_name):
root_path = self.get_dataset_root_path(dataset_name)
for i in range(len(metadata["data"])):
assert "wav" in metadata["data"][i].keys(), metadata["data"][i]
assert metadata["data"][i]["wav"][0] != "/", (
"The dataset metadata should only contain relative path to the audio file: "
+ str(metadata["data"][i]["wav"])
)
metadata["data"][i]["wav"] = os.path.join(
root_path, metadata["data"][i]["wav"]
)
return metadata
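
    # Each per-split metadata file is expected to look roughly like this
    # (illustrative; fields such as "caption"/"label" are optional per datum):
    #   {"data": [{"wav": "clips/0001.wav", "caption": "a dog barks"}, ...]}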
def build_dataset(self):
self.data = []
print("Build dataset split %s from %s" % (self.split, self.dataset_name))
        if isinstance(self.dataset_name, str):
data_json = load_json(
self.get_dataset_metadata_path(self.dataset_name, key=self.split)
)
data_json = self._relative_path_to_absolute_path(
data_json, self.dataset_name
)
self.data = data_json["data"]
        elif isinstance(self.dataset_name, list):
for dataset_name in self.dataset_name:
data_json = load_json(
self.get_dataset_metadata_path(dataset_name, key=self.split)
)
data_json = self._relative_path_to_absolute_path(
data_json, dataset_name
)
self.data += data_json["data"]
        else:
            raise TypeError("`dataset_name` must be a string or a list of strings")
print("Data size: {}".format(len(self.data)))
    def is_contain_caption(self, datum):
        # A datum "contains a caption" if any of its keys mentions "caption".
        if datum is None:
            return False
        return any("caption" in key for key in datum.keys())
def _read_datum_caption(self, datum):
if datum is not None:
caption_keys = [x for x in datum.keys() if ("caption" in x)]
random_index = torch.randint(0, len(caption_keys), (1,))[0].item()
return datum[caption_keys[random_index]]
else:
return "" # NOTE: return empty string if datum is not provided
def label_indices_to_text(
self,
datum,
label_indices,
template_description: str = "{}", # e.g., "This audio contains the sound of {}"
):
if self.is_contain_caption(datum):
return self._read_datum_caption(datum)
elif "label" in datum.keys():
name_indices = torch.where(label_indices > 0.1)[0]
labels = ""
for id, each in enumerate(name_indices):
if id == len(name_indices) - 1:
labels += "%s." % self.num2label[int(each)]
else:
labels += "%s, " % self.num2label[int(each)]
return template_description.format(labels)
else:
return "" # NOTE: return empty string if both label and caption are not provided
def get_sample_text_caption(self, datum, mix_datum, label_indices):
text = self.label_indices_to_text(datum, label_indices)
if mix_datum is not None:
text += " " + self.label_indices_to_text(mix_datum, label_indices)
return text
def get_caption_from_datum(
self, datum, mix_datum=None, template_description="{} {}"
):
caption = ""
if self.is_contain_caption(datum):
caption += self._read_datum_caption(datum)
# Mixup the caption if `mix_datum` is not None
if mix_datum is not None and self.is_contain_caption(mix_datum):
mix_caption = self._read_datum_caption(mix_datum)
caption = template_description.format(caption, mix_caption)
return caption

if __name__ == "__main__":
    from tqdm import tqdm
    from torch.utils.data import DataLoader

    dataset = AudioDataset(
        dataset_name=["audiocaps"],
        include_caption=True,
        enable_mixup=True,
        audio_processor=FbankAudioProcessor(),
    )
    loader = DataLoader(dataset, batch_size=2, num_workers=0, shuffle=True)
    for cnt, each in tqdm(enumerate(loader)):
        # e.g., inspect each["waveform"].size() and each["log_mel_spec"].size()
        print(each["fname"], each["caption"])
        if cnt >= 2:  # smoke-test a few batches rather than dropping into ipdb
            break