import os
from audioldm_eval.datasets.load_mel import load_npy_data, MelPairedDataset, WaveDataset
import numpy as np
import argparse
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from audioldm_eval.metrics.fad import FrechetAudioDistance
from audioldm_eval import calculate_fid, calculate_isc, calculate_kid, calculate_kl
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from audioldm_eval.feature_extractors.panns import Cnn14
from audioldm_eval.audio.tools import save_pickle, load_pickle, write_json, load_json
from ssr_eval.metrics import AudioMetrics
import audioldm_eval.audio as Audio
import time
class EvaluationHelper:
def __init__(self, sampling_rate, device, backbone="cnn14") -> None:
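        """Prepare the metric backbones for a given sampling rate.

        Instantiates the FrechetAudioDistance scorer, a PANNs CNN14 feature
        extractor (exposing "2048" embeddings and "logits"), the LSD metric,
        and a TacotronSTFT used for mel-spectrogram extraction. Only 16 kHz
        and 32 kHz sampling rates are supported.
        """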
self.device = device
self.backbone = backbone
self.sampling_rate = sampling_rate
self.frechet = FrechetAudioDistance(
use_pca=False,
use_activation=False,
verbose=False,
)
self.lsd_metric = AudioMetrics(self.sampling_rate)
self.frechet.model = self.frechet.model.to(device)
features_list = ["2048", "logits"]
if self.sampling_rate == 16000:
self.mel_model = Cnn14(
features_list=features_list,
sample_rate=16000,
window_size=512,
hop_size=160,
mel_bins=64,
fmin=50,
fmax=8000,
classes_num=527,
)
elif self.sampling_rate == 32000:
self.mel_model = Cnn14(
features_list=features_list,
sample_rate=32000,
window_size=1024,
hop_size=320,
mel_bins=64,
fmin=50,
fmax=14000,
classes_num=527,
)
else:
raise ValueError(
"We only support the evaluation on 16kHz and 32kHz sampling rate."
)
if self.sampling_rate == 16000:
self._stft = Audio.TacotronSTFT(512, 160, 512, 64, 16000, 50, 8000)
elif self.sampling_rate == 32000:
self._stft = Audio.TacotronSTFT(1024, 320, 1024, 64, 32000, 50, 14000)
else:
raise ValueError(
"We only support the evaluation on 16kHz and 32kHz sampling rate."
)
self.mel_model.eval()
self.mel_model.to(self.device)
self.fbin_mean, self.fbin_std = None, None
def main(
self,
generate_files_path,
groundtruth_path,
limit_num=None,
):
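        """Evaluate generated audio against ground-truth audio.

        Both arguments are directories of .wav files. Paired metrics are only
        computed when the two directories share (almost) the same filenames.
        Returns a dictionary of metric values.
        """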
self.file_init_check(generate_files_path)
self.file_init_check(groundtruth_path)
same_name = self.get_filename_intersection_ratio(
generate_files_path, groundtruth_path, limit_num=limit_num
)
metrics = self.calculate_metrics(generate_files_path, groundtruth_path, same_name, limit_num)
return metrics
def file_init_check(self, dir):
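        """Assert that `dir` exists and contains more than one entry."""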
        assert os.path.exists(dir), "The path does not exist: %s" % dir
        assert len(os.listdir(dir)) > 1, "There are no files in %s" % dir
def get_filename_intersection_ratio(
self, dir1, dir2, threshold=0.99, limit_num=None
):
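        """Check whether the two directories contain (mostly) the same .wav filenames.

        Returns True when the filename intersection covers more than `threshold`
        of both directories, in which case paired (same_name) metrics are computed.
        """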
self.datalist1 = [os.path.join(dir1, x) for x in os.listdir(dir1)]
self.datalist1 = sorted(self.datalist1)
self.datalist1 = [item for item in self.datalist1 if item.endswith(".wav")]
self.datalist2 = [os.path.join(dir2, x) for x in os.listdir(dir2)]
self.datalist2 = sorted(self.datalist2)
self.datalist2 = [item for item in self.datalist2 if item.endswith(".wav")]
data_dict1 = {os.path.basename(x): x for x in self.datalist1}
data_dict2 = {os.path.basename(x): x for x in self.datalist2}
keyset1 = set(data_dict1.keys())
keyset2 = set(data_dict2.keys())
intersect_keys = keyset1.intersection(keyset2)
if (
len(intersect_keys) / len(keyset1) > threshold
and len(intersect_keys) / len(keyset2) > threshold
):
'''
print(
"+Two path have %s intersection files out of total %s & %s files. Processing two folder with same_name=True"
% (len(intersect_keys), len(keyset1), len(keyset2))
)
'''
return True
else:
'''
print(
"-Two path have %s intersection files out of total %s & %s files. Processing two folder with same_name=False"
% (len(intersect_keys), len(keyset1), len(keyset2))
)
'''
return False
def calculate_lsd(self, pairedloader, same_name=True, time_offset=160 * 7):
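        """Compute Log-Spectral Distance (LSD) and STFT-domain SSIM over paired audio.

        Returns -1 placeholders when the two folders are not filename-paired.
        `time_offset` trims the start of the first waveform in each pair to
        compensate for vocoder latency (seven hop frames; see the HiFi-GAN
        note in the loop below).
        """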
        if not same_name:
return {
"lsd": -1,
"ssim_stft": -1,
}
# print("Calculating LSD using a time offset of %s ..." % time_offset)
lsd_avg = []
ssim_stft_avg = []
for _, _, filename, (audio1, audio2) in tqdm(pairedloader, leave=False):
audio1 = audio1.cpu().numpy()[0, 0]
audio2 = audio2.cpu().numpy()[0, 0]
# If you use HIFIGAN (verified on 2023-01-12), you need seven frames' offset
audio1 = audio1[time_offset:]
audio1 = audio1 - np.mean(audio1)
audio2 = audio2 - np.mean(audio2)
audio1 = audio1 / np.max(np.abs(audio1))
audio2 = audio2 / np.max(np.abs(audio2))
min_len = min(audio1.shape[0], audio2.shape[0])
audio1, audio2 = audio1[:min_len], audio2[:min_len]
try:
result = self.lsd(audio1, audio2)
lsd_avg.append(result["lsd"])
ssim_stft_avg.append(result["ssim"])
            except Exception:
                # Skip pairs where the LSD computation fails.
                continue
return {"lsd": np.mean(lsd_avg), "ssim_stft": np.mean(ssim_stft_avg)}
def lsd(self, audio1, audio2):
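        """Wrapper around ssr_eval's AudioMetrics.evaluation for a single audio pair."""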
result = self.lsd_metric.evaluation(audio1, audio2, None)
return result
def calculate_psnr_ssim(self, pairedloader, same_name=True):
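        """Compute PSNR and SSIM between paired mel spectrograms.

        Returns -1 placeholders when the two folders are not filename-paired.
        """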
        if not same_name:
return {"psnr": -1, "ssim": -1}
psnr_avg = []
ssim_avg = []
for mel_gen, mel_target, filename, _ in tqdm(pairedloader, leave=False):
mel_gen = mel_gen.cpu().numpy()[0]
mel_target = mel_target.cpu().numpy()[0]
psnrval = psnr(mel_gen, mel_target)
if np.isinf(psnrval):
print("Infinite value encountered in psnr %s " % filename)
continue
psnr_avg.append(psnrval)
ssim_avg.append(ssim(mel_gen, mel_target))
return {"psnr": np.mean(psnr_avg), "ssim": np.mean(ssim_avg)}
def calculate_metrics(self, generate_files_path, groundtruth_path, same_name, limit_num=None):
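        """Compute all metrics between the generated and ground-truth folders.

        Covers LSD, PSNR/SSIM, KL divergence, Inception Score, FID, FAD, and KID.
        Paired metrics (LSD, PSNR, SSIM) are only computed when `same_name` is True.
        The rounded results are written to "<generate_files_path>_evaluation_results.json"
        and also returned as a dictionary.
        """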
# Generation, target
torch.manual_seed(0)
num_workers = 0
outputloader = DataLoader(
WaveDataset(
generate_files_path,
self.sampling_rate,
limit_num=limit_num,
),
batch_size=1,
sampler=None,
num_workers=num_workers,
)
resultloader = DataLoader(
WaveDataset(
groundtruth_path,
self.sampling_rate,
limit_num=limit_num,
),
batch_size=1,
sampler=None,
num_workers=num_workers,
)
pairedloader = DataLoader(
MelPairedDataset(
generate_files_path,
groundtruth_path,
self._stft,
self.sampling_rate,
self.fbin_mean,
self.fbin_std,
limit_num=limit_num,
),
batch_size=1,
sampler=None,
num_workers=16,
)
out = {}
metric_lsd = self.calculate_lsd(pairedloader, same_name=same_name)
out.update(metric_lsd)
featuresdict_2 = self.get_featuresdict(resultloader)
featuresdict_1 = self.get_featuresdict(outputloader)
# if cfg.have_kl:
metric_psnr_ssim = self.calculate_psnr_ssim(pairedloader, same_name=same_name)
out.update(metric_psnr_ssim)
metric_kl, kl_ref, paths_1 = calculate_kl(
featuresdict_1, featuresdict_2, "logits", same_name
)
out.update(metric_kl)
metric_isc = calculate_isc(
featuresdict_1,
feat_layer_name="logits",
splits=10,
samples_shuffle=True,
rng_seed=2020,
)
out.update(metric_isc)
metric_fid = calculate_fid(
featuresdict_1, featuresdict_2, feat_layer_name="2048"
)
out.update(metric_fid)
# Gen, target
fad_score = self.frechet.score(generate_files_path, groundtruth_path, limit_num=limit_num)
out.update(fad_score)
metric_kid = calculate_kid(
featuresdict_1,
featuresdict_2,
feat_layer_name="2048",
subsets=100,
subset_size=1000,
degree=3,
gamma=None,
coef0=1,
rng_seed=2020,
)
out.update(metric_kid)
'''
print("\n".join((f"{k}: {v:.7f}" for k, v in out.items())))
print("\n")
print(limit_num)
print(
f'KL_Sigmoid: {out.get("kullback_leibler_divergence_sigmoid", float("nan")):8.5f};',
f'KL: {out.get("kullback_leibler_divergence_softmax", float("nan")):8.5f};',
f'PSNR: {out.get("psnr", float("nan")):.5f}',
f'SSIM: {out.get("ssim", float("nan")):.5f}',
f'ISc: {out.get("inception_score_mean", float("nan")):8.5f} ({out.get("inception_score_std", float("nan")):5f});',
f'KID: {out.get("kernel_inception_distance_mean", float("nan")):.5f}',
f'({out.get("kernel_inception_distance_std", float("nan")):.5f})',
f'FD: {out.get("frechet_distance", float("nan")):8.5f};',
f'FAD: {out.get("frechet_audio_distance", float("nan")):.5f}',
f'LSD: {out.get("lsd", float("nan")):.5f}',
f'SSIM_STFT: {out.get("ssim_stft", float("nan")):.5f}',
)
'''
result = {
"frechet_distance": out.get("frechet_distance", float("nan")),
"frechet_audio_distance": out.get("frechet_audio_distance", float("nan")),
"kl_sigmoid": out.get(
"kullback_leibler_divergence_sigmoid", float("nan")
),
"kl_softmax": out.get(
"kullback_leibler_divergence_softmax", float("nan")
),
"lsd": out.get("lsd", float("nan")),
"psnr": out.get("psnr", float("nan")),
"ssim": out.get("ssim", float("nan")),
"ssim_stft": out.get("ssim_stft", float("nan")),
"is_mean": out.get("inception_score_mean", float("nan")),
"is_std": out.get("inception_score_std", float("nan")),
"kid_mean": out.get(
"kernel_inception_distance_mean", float("nan")
),
"kid_std": out.get(
"kernel_inception_distance_std", float("nan")
),
}
result = {k: round(v, 4) for k, v in result.items()}
json_path = generate_files_path + "_evaluation_results.json"
write_json(result, json_path)
return result
def get_featuresdict(self, dataloader):
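        """Run the PANNs CNN14 model over every waveform in the dataloader.

        Accumulates the "2048" embeddings and "logits" for each file and returns
        them, together with the file paths, in a single dictionary of tensors.
        """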
out = None
out_meta = None
# transforms=StandardNormalizeAudio()
for waveform, filename in tqdm(dataloader, leave=False):
try:
metadict = {
"file_path_": filename,
}
waveform = waveform.squeeze(1)
# batch = transforms(batch)
waveform = waveform.float().to(self.device)
with torch.no_grad():
featuresdict = self.mel_model(waveform)
# featuresdict = self.mel_model.convert_features_tuple_to_dict(features)
featuresdict = {k: [v.cpu()] for k, v in featuresdict.items()}
if out is None:
out = featuresdict
else:
out = {k: out[k] + featuresdict[k] for k in out.keys()}
if out_meta is None:
out_meta = metadict
else:
out_meta = {k: out_meta[k] + metadict[k] for k in out_meta.keys()}
            except Exception as e:
                print("PANNs inference error: ", e)
                continue
out = {k: torch.cat(v, dim=0) for k, v in out.items()}
return {**out, **out_meta}
def sample_from(self, samples, number_to_use):
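        """Randomly pick `number_to_use` rows from `samples` without replacement."""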
assert samples.shape[0] >= number_to_use
rand_order = np.random.permutation(samples.shape[0])
        return samples[rand_order[:number_to_use], :]
'''
if __name__ == "__main__":
import yaml
import argparse
from audioldm_eval import EvaluationHelper
import torch
parser = argparse.ArgumentParser()
parser.add_argument(
"-g",
"--generation_result_path",
type=str,
required=False,
help="Audio sampling rate during evaluation",
default="/mnt/fast/datasets/audio/audioset/2million_audioset_wav/balanced_train_segments",
)
parser.add_argument(
"-t",
"--target_audio_path",
type=str,
required=False,
help="Audio sampling rate during evaluation",
default="/mnt/fast/datasets/audio/audioset/2million_audioset_wav/eval_segments",
)
parser.add_argument(
"-sr",
"--sampling_rate",
type=int,
required=False,
help="Audio sampling rate during evaluation",
default=16000,
)
parser.add_argument(
"-l",
"--limit_num",
type=int,
required=False,
help="Audio clip numbers limit for evaluation",
default=None,
)
args = parser.parse_args()
device = torch.device(f"cuda:{0}")
evaluator = EvaluationHelper(args.sampling_rate, device)
metrics = evaluator.main(
args.generation_result_path,
args.target_audio_path,
limit_num=args.limit_num,
)
print(metrics)
'''