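"""Evaluation helpers for generated audio (adapted from audioldm_eval).

Computes FAD, FD (Frechet distance on PANNs features), IS, KID, KL,
PSNR, SSIM, SSIM-STFT, and LSD between a folder of generated .wav files
and a folder of ground-truth .wav files.
"""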
import os
import time
import argparse

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from ssr_eval.metrics import AudioMetrics

import audioldm_eval.audio as Audio
from audioldm_eval import calculate_fid, calculate_isc, calculate_kid, calculate_kl
from audioldm_eval.audio.tools import save_pickle, load_pickle, write_json, load_json
from audioldm_eval.datasets.load_mel import load_npy_data, MelPairedDataset, WaveDataset
from audioldm_eval.feature_extractors.panns import Cnn14
from audioldm_eval.metrics.fad import FrechetAudioDistance


class EvaluationHelper:
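    """Compute objective metrics between generated and ground-truth audio folders."""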

    def __init__(self, sampling_rate, device, backbone="cnn14") -> None:
        self.device = device
        self.backbone = backbone
        self.sampling_rate = sampling_rate
        self.frechet = FrechetAudioDistance(
            use_pca=False,
            use_activation=False,
            verbose=False,
        )
        self.lsd_metric = AudioMetrics(self.sampling_rate)
        self.frechet.model = self.frechet.model.to(device)
        features_list = ["2048", "logits"]
        # PANNs Cnn14 feature extractor and the matching STFT configuration.
        if self.sampling_rate == 16000:
            self.mel_model = Cnn14(
                features_list=features_list,
                sample_rate=16000,
                window_size=512,
                hop_size=160,
                mel_bins=64,
                fmin=50,
                fmax=8000,
                classes_num=527,
            )
            self._stft = Audio.TacotronSTFT(512, 160, 512, 64, 16000, 50, 8000)
        elif self.sampling_rate == 32000:
            self.mel_model = Cnn14(
                features_list=features_list,
                sample_rate=32000,
                window_size=1024,
                hop_size=320,
                mel_bins=64,
                fmin=50,
                fmax=14000,
                classes_num=527,
            )
            self._stft = Audio.TacotronSTFT(1024, 320, 1024, 64, 32000, 50, 14000)
        else:
            raise ValueError(
                "Only 16 kHz and 32 kHz sampling rates are supported for evaluation."
            )
        self.mel_model.eval()
        self.mel_model.to(self.device)
        self.fbin_mean, self.fbin_std = None, None

    def main(
        self,
        generate_files_path,
        groundtruth_path,
        limit_num=None,
    ):
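        """Run the full evaluation and return a dict of metrics.

        generate_files_path and groundtruth_path are folders of .wav files;
        limit_num optionally caps how many files are evaluated.
        """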
        self.file_init_check(generate_files_path)
        self.file_init_check(groundtruth_path)
        same_name = self.get_filename_intersection_ratio(
            generate_files_path, groundtruth_path, limit_num=limit_num
        )
        metrics = self.calculate_metrics(
            generate_files_path, groundtruth_path, same_name, limit_num
        )
        return metrics

    def file_init_check(self, path):
        assert os.path.exists(path), "The path does not exist: %s" % path
        assert len(os.listdir(path)) > 1, "There are no files in %s" % path

    def get_filename_intersection_ratio(
        self, dir1, dir2, threshold=0.99, limit_num=None
    ):
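        """Return True when the two folders share (almost) all filenames.

        If more than `threshold` of the .wav basenames appear in both folders,
        the files are treated as paired (same_name=True) so that the paired
        metrics (LSD, PSNR, SSIM) can be computed.
        """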
        self.datalist1 = sorted(
            os.path.join(dir1, x) for x in os.listdir(dir1) if x.endswith(".wav")
        )
        self.datalist2 = sorted(
            os.path.join(dir2, x) for x in os.listdir(dir2) if x.endswith(".wav")
        )
        keyset1 = set(os.path.basename(x) for x in self.datalist1)
        keyset2 = set(os.path.basename(x) for x in self.datalist2)
        intersect_keys = keyset1.intersection(keyset2)
        if (
            len(intersect_keys) / len(keyset1) > threshold
            and len(intersect_keys) / len(keyset2) > threshold
        ):
            # print(
            #     "+The two paths share %s files out of %s & %s total. Processing with same_name=True"
            #     % (len(intersect_keys), len(keyset1), len(keyset2))
            # )
            return True
        else:
            # print(
            #     "-The two paths share %s files out of %s & %s total. Processing with same_name=False"
            #     % (len(intersect_keys), len(keyset1), len(keyset2))
            # )
            return False

    def calculate_lsd(self, pairedloader, same_name=True, time_offset=160 * 7):
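        """Compute log-spectral distance (LSD) and STFT-domain SSIM on paired files.

        time_offset trims the head of the generated audio; vocoders such as
        HiFi-GAN introduce a fixed offset (seven frames of 160 samples here).
        """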
        if not same_name:
            return {
                "lsd": -1,
                "ssim_stft": -1,
            }
        # print("Calculating LSD using a time offset of %s ..." % time_offset)
        lsd_avg = []
        ssim_stft_avg = []
        for _, _, filename, (audio1, audio2) in tqdm(pairedloader, leave=False):
            audio1 = audio1.cpu().numpy()[0, 0]
            audio2 = audio2.cpu().numpy()[0, 0]
            # HiFi-GAN (verified on 2023-01-12) needs a seven-frame offset.
            audio1 = audio1[time_offset:]
            # Remove DC offset and peak-normalize before comparison.
            audio1 = audio1 - np.mean(audio1)
            audio2 = audio2 - np.mean(audio2)
            audio1 = audio1 / np.max(np.abs(audio1))
            audio2 = audio2 / np.max(np.abs(audio2))
            min_len = min(audio1.shape[0], audio2.shape[0])
            audio1, audio2 = audio1[:min_len], audio2[:min_len]
            try:
                result = self.lsd(audio1, audio2)
                lsd_avg.append(result["lsd"])
                ssim_stft_avg.append(result["ssim"])
            except Exception:
                continue
        return {"lsd": np.mean(lsd_avg), "ssim_stft": np.mean(ssim_stft_avg)}

    def lsd(self, audio1, audio2):
        return self.lsd_metric.evaluation(audio1, audio2, None)

    def calculate_psnr_ssim(self, pairedloader, same_name=True):
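        """Compute PSNR and SSIM between paired mel spectrograms."""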
        if not same_name:
            return {"psnr": -1, "ssim": -1}
        psnr_avg = []
        ssim_avg = []
        for mel_gen, mel_target, filename, _ in tqdm(pairedloader, leave=False):
            mel_gen = mel_gen.cpu().numpy()[0]
            mel_target = mel_target.cpu().numpy()[0]
            psnrval = psnr(mel_gen, mel_target)
            if np.isinf(psnrval):
                print("Infinite value encountered in psnr %s " % filename)
                continue
            psnr_avg.append(psnrval)
            # Note: recent scikit-image versions require an explicit data_range
            # for float inputs; pass data_range here if psnr/ssim raise.
            ssim_avg.append(ssim(mel_gen, mel_target))
        return {"psnr": np.mean(psnr_avg), "ssim": np.mean(ssim_avg)}

    def calculate_metrics(
        self, generate_files_path, groundtruth_path, same_name, limit_num=None
    ):
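        """Compute all metrics and write them to a JSON file next to the generations."""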
        # Generation first, then target.
        torch.manual_seed(0)
        num_workers = 0
        outputloader = DataLoader(
            WaveDataset(
                generate_files_path,
                self.sampling_rate,
                limit_num=limit_num,
            ),
            batch_size=1,
            sampler=None,
            num_workers=num_workers,
        )
        resultloader = DataLoader(
            WaveDataset(
                groundtruth_path,
                self.sampling_rate,
                limit_num=limit_num,
            ),
            batch_size=1,
            sampler=None,
            num_workers=num_workers,
        )
        pairedloader = DataLoader(
            MelPairedDataset(
                generate_files_path,
                groundtruth_path,
                self._stft,
                self.sampling_rate,
                self.fbin_mean,
                self.fbin_std,
                limit_num=limit_num,
            ),
            batch_size=1,
            sampler=None,
            num_workers=16,
        )
        out = {}
        metric_lsd = self.calculate_lsd(pairedloader, same_name=same_name)
        out.update(metric_lsd)
        featuresdict_2 = self.get_featuresdict(resultloader)
        featuresdict_1 = self.get_featuresdict(outputloader)
        metric_psnr_ssim = self.calculate_psnr_ssim(pairedloader, same_name=same_name)
        out.update(metric_psnr_ssim)
        metric_kl, kl_ref, paths_1 = calculate_kl(
            featuresdict_1, featuresdict_2, "logits", same_name
        )
        out.update(metric_kl)
        metric_isc = calculate_isc(
            featuresdict_1,
            feat_layer_name="logits",
            splits=10,
            samples_shuffle=True,
            rng_seed=2020,
        )
        out.update(metric_isc)
        metric_fid = calculate_fid(
            featuresdict_1, featuresdict_2, feat_layer_name="2048"
        )
        out.update(metric_fid)
        fad_score = self.frechet.score(
            generate_files_path, groundtruth_path, limit_num=limit_num
        )
        out.update(fad_score)
        metric_kid = calculate_kid(
            featuresdict_1,
            featuresdict_2,
            feat_layer_name="2048",
            subsets=100,
            subset_size=1000,
            degree=3,
            gamma=None,
            coef0=1,
            rng_seed=2020,
        )
        out.update(metric_kid)
        # print("\n".join((f"{k}: {v:.7f}" for k, v in out.items())))
        # print(
        #     f'KL_Sigmoid: {out.get("kullback_leibler_divergence_sigmoid", float("nan")):8.5f};',
        #     f'KL: {out.get("kullback_leibler_divergence_softmax", float("nan")):8.5f};',
        #     f'PSNR: {out.get("psnr", float("nan")):.5f}',
        #     f'SSIM: {out.get("ssim", float("nan")):.5f}',
        #     f'ISc: {out.get("inception_score_mean", float("nan")):8.5f} ({out.get("inception_score_std", float("nan")):.5f});',
        #     f'KID: {out.get("kernel_inception_distance_mean", float("nan")):.5f}',
        #     f'({out.get("kernel_inception_distance_std", float("nan")):.5f})',
        #     f'FD: {out.get("frechet_distance", float("nan")):8.5f};',
        #     f'FAD: {out.get("frechet_audio_distance", float("nan")):.5f}',
        #     f'LSD: {out.get("lsd", float("nan")):.5f}',
        #     f'SSIM_STFT: {out.get("ssim_stft", float("nan")):.5f}',
        # )
        result = {
            "frechet_distance": out.get("frechet_distance", float("nan")),
            "frechet_audio_distance": out.get("frechet_audio_distance", float("nan")),
            "kl_sigmoid": out.get("kullback_leibler_divergence_sigmoid", float("nan")),
            "kl_softmax": out.get("kullback_leibler_divergence_softmax", float("nan")),
            "lsd": out.get("lsd", float("nan")),
            "psnr": out.get("psnr", float("nan")),
            "ssim": out.get("ssim", float("nan")),
            "ssim_stft": out.get("ssim_stft", float("nan")),
            "is_mean": out.get("inception_score_mean", float("nan")),
            "is_std": out.get("inception_score_std", float("nan")),
            "kid_mean": out.get("kernel_inception_distance_mean", float("nan")),
            "kid_std": out.get("kernel_inception_distance_std", float("nan")),
        }
        result = {k: round(v, 4) for k, v in result.items()}
        json_path = generate_files_path + "_evaluation_results.json"
        write_json(result, json_path)
        return result

    def get_featuresdict(self, dataloader):
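        """Run the PANNs backbone over a loader and collect features per file.

        Returns a dict mapping each feature name ("2048", "logits") to a stacked
        tensor, plus "file_path_" with the corresponding file paths.
        """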
        out = None
        out_meta = None
        for waveform, filename in tqdm(dataloader, leave=False):
            try:
                metadict = {
                    "file_path_": filename,
                }
                waveform = waveform.squeeze(1)
                waveform = waveform.float().to(self.device)
                with torch.no_grad():
                    featuresdict = self.mel_model(waveform)
                featuresdict = {k: [v.cpu()] for k, v in featuresdict.items()}
                if out is None:
                    out = featuresdict
                else:
                    out = {k: out[k] + featuresdict[k] for k in out.keys()}
                if out_meta is None:
                    out_meta = metadict
                else:
                    out_meta = {k: out_meta[k] + metadict[k] for k in out_meta.keys()}
            except Exception as e:
                print("PANNs inference error: ", e)
                continue
        out = {k: torch.cat(v, dim=0) for k, v in out.items()}
        return {**out, **out_meta}

    def sample_from(self, samples, number_to_use):
        assert samples.shape[0] >= number_to_use
        rand_order = np.random.permutation(samples.shape[0])
        # Take the first `number_to_use` rows of the shuffled order (the original
        # sliced with samples.shape[0], which returned every row).
        return samples[rand_order[:number_to_use], :]

'''
if __name__ == "__main__":
    import argparse

    import torch

    from audioldm_eval import EvaluationHelper

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-g",
        "--generation_result_path",
        type=str,
        required=False,
        help="Path to the folder of generated audio files",
        default="/mnt/fast/datasets/audio/audioset/2million_audioset_wav/balanced_train_segments",
    )
    parser.add_argument(
        "-t",
        "--target_audio_path",
        type=str,
        required=False,
        help="Path to the folder of ground-truth audio files",
        default="/mnt/fast/datasets/audio/audioset/2million_audioset_wav/eval_segments",
    )
    parser.add_argument(
        "-sr",
        "--sampling_rate",
        type=int,
        required=False,
        help="Audio sampling rate during evaluation",
        default=16000,
    )
    parser.add_argument(
        "-l",
        "--limit_num",
        type=int,
        required=False,
        help="Maximum number of audio clips to evaluate",
        default=None,
    )
    args = parser.parse_args()
    device = torch.device("cuda:0")
    evaluator = EvaluationHelper(args.sampling_rate, device)
    metrics = evaluator.main(
        args.generation_result_path,
        args.target_audio_path,
        limit_num=args.limit_num,
    )
    print(metrics)
'''