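"""Objective evaluation metrics for generated speech.

Computes word error rate (WER) via Whisper ASR, speaker similarity via WavLM
x-vector embeddings, and Mel-Cepstral Distortion (MCD) via pymcd. Generated
and target audio files are paired by filename across the two directories.
"""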
import argparse
import os
import string
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path

import librosa
import numpy as np
import torch
from evaluate import load
from pymcd.mcd import Calculate_MCD
from tqdm import tqdm
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, Wav2Vec2FeatureExtractor, WavLMForXVector, pipeline

def convert_numbers_to_words(text):
    """Spell out single digits, standalone or embedded in words, as English words."""
    number_word_map = {
        "0": "zero",
        "1": "one",
        "2": "two",
        "3": "three",
        "4": "four",
        "5": "five",
        "6": "six",
        "7": "seven",
        "8": "eight",
        "9": "nine",
    }
    words = text.split()
    converted_words = []
    for word in words:
        # Mixed alphanumeric words (like 'j4'): spell out each digit, padding
        # with spaces on both sides so a digit word never fuses with a
        # following letter (e.g. '4k' -> 'four k', not 'fourk').
        if any(c.isdigit() for c in word) and any(c.isalpha() for c in word):
            new_word = ""
            for c in word:
                if c.isdigit():
                    new_word += " " + number_word_map[c] + " "
                else:
                    new_word += c
            converted_words.append(" ".join(new_word.split()))
        # A word that is exactly one digit.
        elif word.isdigit() and len(word) == 1:
            converted_words.append(number_word_map[word])
        else:
            converted_words.append(word)
    return " ".join(converted_words)

def clean_text(text):
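    """Normalize text for WER scoring: spell out digits, strip punctuation, lowercase."""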
    text = convert_numbers_to_words(text)
    text = text.translate(str.maketrans("", "", string.punctuation))
    text = text.lower()
    return text

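
# Illustrative check of the normalization above (not part of the evaluation flow):
#   clean_text("Room j4 has 2 cats")  ->  "room j four has two cats"
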
def wer_pipe(gen_dir: str, target_dir: str, model_id="openai/whisper-large-v3-turbo"):
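    """Transcribe generated wavs with Whisper, then compute the mean WER against target transcripts.

    Transcriptions are cached next to the generated wavs as `.asrtxt` files;
    references are the `.txt` files in `target_dir` with matching names.
    """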
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    print(f"Using Model: {model_id} for WER Evaluation")
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
    ).to(device)
    processor = AutoProcessor.from_pretrained(model_id)
    pipe = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device,
    )
    gen_list = list(Path(gen_dir).glob("*.wav"))
    for wav in tqdm(gen_list, desc="Transcribing generated audio"):
        text = pipe(librosa.load(wav, sr=16000)[0], generate_kwargs={"language": "english"})["text"]
        with open(wav.with_suffix(".asrtxt"), "w") as fw:
            fw.write(text)
    wer_metric = load("wer")
    val_list = list(Path(target_dir).glob("*.txt"))
    wer = []
    for txt in tqdm(val_list, desc="Calculating WER"):
        try:
            # The reference transcripts were produced by ASR and have not been
            # manually verified, so both sides are normalized before scoring.
            # dict.fromkeys deduplicates repeated lines while preserving order
            # (a plain set() would make the reference word order nondeterministic).
            target_text = " ".join(dict.fromkeys(txt.read_text().splitlines()))
            target_text = clean_text(target_text)
            gen_text = " ".join(Path(os.path.join(gen_dir, txt.with_suffix(".asrtxt").name)).read_text().splitlines())
            gen_text = clean_text(gen_text)
            if target_text == "" or gen_text == "":
                continue
            wer_ = wer_metric.compute(references=[target_text], predictions=[gen_text])
        except Exception as e:
            print("Error in WER calculation:", e)
            continue
        wer.append(wer_)
    return np.mean(wer)

def spk_sim_pipe(gen_dir, target_dir):
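    """Compute the mean cosine similarity between WavLM x-vector embeddings of paired wavs."""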
    device = "cuda" if torch.cuda.is_available() else "cpu"
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("microsoft/wavlm-base-sv")
    model = WavLMForXVector.from_pretrained("microsoft/wavlm-base-sv").to(device)
    cosine_sim = torch.nn.CosineSimilarity(dim=-1)
    val_list = list(Path(target_dir).glob("*.wav"))
    scos = []
    for target_wav in tqdm(val_list, desc="Calculating speaker similarity"):
        try:
            # Load both sides at 16 kHz; skip the pair if either file is unreadable.
            target = librosa.load(target_wav, sr=16000)[0]
            gen = librosa.load(os.path.join(gen_dir, target_wav.name), sr=16000)[0]
            input1 = feature_extractor(gen, return_tensors="pt", sampling_rate=16000).to(device)
            embeddings1 = model(**input1).embeddings
            input2 = feature_extractor(target, return_tensors="pt", sampling_rate=16000).to(device)
            embeddings2 = model(**input2).embeddings
            similarity = cosine_sim(embeddings1[0], embeddings2[0])
        except Exception as e:
            print(f"Error in {target_wav}: {e}")
            continue
        scos.append(similarity.detach().cpu().numpy())
    return np.mean(scos)

def calculate_mcd_for_wav(target_wav, gen_dir, mcd_toolbox_dtw, mcd_toolbox_dtw_sl):
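    """Compute MCD between one target wav and the same-named generated wav, in both DTW modes."""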
    _mcd_dtw = mcd_toolbox_dtw.calculate_mcd(target_wav, os.path.join(gen_dir, target_wav.name))
    _mcd_dtw_sl = mcd_toolbox_dtw_sl.calculate_mcd(target_wav, os.path.join(gen_dir, target_wav.name))
    return _mcd_dtw, _mcd_dtw_sl

def mcd_pipe(gen_dir, target_dir, num_processes=16):
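    """Compute mean MCD over all wav pairs in both "dtw" and "dtw_sl" modes, parallelized across processes."""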
    mcd_toolbox_dtw = Calculate_MCD(MCD_mode="dtw")
    mcd_toolbox_dtw_sl = Calculate_MCD(MCD_mode="dtw_sl")
    val_list = list(Path(target_dir).glob("*.wav"))
    mcd_dtw = []
    mcd_dtw_sl = []
    with ProcessPoolExecutor(max_workers=num_processes) as executor:
        futures = [
            executor.submit(calculate_mcd_for_wav, target_wav, gen_dir, mcd_toolbox_dtw, mcd_toolbox_dtw_sl)
            for target_wav in val_list
        ]
        for future in tqdm(futures, desc="Calculating MCD"):
            _mcd_dtw, _mcd_dtw_sl = future.result()
            mcd_dtw.append(_mcd_dtw)
            mcd_dtw_sl.append(_mcd_dtw_sl)
    return np.mean(mcd_dtw), np.mean(mcd_dtw_sl)

def run_all_metrics(gen_dir, target_dir, whisper_model="openai/whisper-large-v3-turbo"):
    """Run all evaluation metrics and return results"""
    results = {}
    print("Running WER evaluation...")
    results["wer"] = wer_pipe(gen_dir, target_dir, model_id=whisper_model)
    print("Running speaker similarity evaluation...")
    results["speaker_similarity"] = spk_sim_pipe(gen_dir, target_dir)
    print("Running MCD evaluation...")
    mcd_dtw, mcd_dtw_sl = mcd_pipe(gen_dir, target_dir)
    results["mcd_dtw"] = mcd_dtw
    results["mcd_dtw_sl"] = mcd_dtw_sl
    return results

if __name__ == "__main__":
    # Example usage (generated and target wavs must share filenames):
    #   python eval.py --gen_dir path/to/generated --target_dir path/to/target
    parser = argparse.ArgumentParser(description="Audio evaluation metrics")
    parser.add_argument("--gen_dir", type=str, required=True, help="Directory containing generated audio files")
    parser.add_argument("--target_dir", type=str, required=True, help="Directory containing target audio files")
    parser.add_argument(
        "--metric",
        type=str,
        default="all",
        choices=["wer", "spk_sim", "mcd", "all"],
        help="Evaluation metric to use",
    )
    parser.add_argument(
        "--whisper_model",
        type=str,
        default="openai/whisper-large-v3-turbo",
        help="Whisper model to use for WER evaluation",
    )
    args = parser.parse_args()

    gen_dir = args.gen_dir
    target_dir = args.target_dir
    if not os.path.exists(gen_dir):
        raise ValueError(f"Generated audio directory does not exist: {gen_dir}")
    if not os.path.exists(target_dir):
        raise ValueError(f"Target audio directory does not exist: {target_dir}")

    if args.metric == "all":
        results = run_all_metrics(gen_dir, target_dir, args.whisper_model)
        print("\nEvaluation Results:")
        print(f"WER: {results['wer']:.4f}")
        print(f"Speaker Similarity: {results['speaker_similarity']:.4f}")
        print(f"MCD (DTW): {results['mcd_dtw']:.4f}")
        print(f"MCD (DTW-SL): {results['mcd_dtw_sl']:.4f}")
    elif args.metric == "wer":
        wer = wer_pipe(gen_dir, target_dir, model_id=args.whisper_model)
        print(f"WER: {wer:.4f}")
    elif args.metric == "spk_sim":
        spk_sim = spk_sim_pipe(gen_dir, target_dir)
        print(f"Speaker Similarity: {spk_sim:.4f}")
    elif args.metric == "mcd":
        mcd_dtw, mcd_dtw_sl = mcd_pipe(gen_dir, target_dir)
        print(f"MCD (DTW): {mcd_dtw:.4f}")
        print(f"MCD (DTW-SL): {mcd_dtw_sl:.4f}")