File size: 8,443 Bytes
9d9ac6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
import argparse
import os
import string
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path

import librosa
import numpy as np
import torch
from evaluate import load
from pymcd.mcd import Calculate_MCD
from tqdm import tqdm
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, Wav2Vec2FeatureExtractor, WavLMForXVector, pipeline


def convert_numbers_to_words(text):
    """Spell out single ASCII digits in ``text`` as English words.

    The text is split on whitespace and each token is handled as follows:

    * A token mixing letters and digits (e.g. ``"j4"``, ``"4j"``, ``"a1b"``) has
      every ASCII digit replaced by its word. Each word is emitted as a separate,
      space-delimited token: ``"j four"``, ``"four j"``, ``"a one b"``.
    * A token that is exactly one ASCII digit (``"7"``) becomes its word
      (``"seven"``).
    * All other tokens, including multi-digit numbers such as ``"42"``, are
      left unchanged.

    Tokens in the result are joined with single spaces.

    Args:
        text: Input string.

    Returns:
        The converted string.
    """
    number_word_map = {
        "0": "zero",
        "1": "one",
        "2": "two",
        "3": "three",
        "4": "four",
        "5": "five",
        "6": "six",
        "7": "seven",
        "8": "eight",
        "9": "nine",
    }

    converted_words = []
    for word in text.split():
        # Test membership in the map rather than str.isdigit(). isdigit() also
        # accepts characters such as "²" that have no entry here and used to
        # raise KeyError.
        has_digit = any(c in number_word_map for c in word)
        if has_digit and any(c.isalpha() for c in word):
            # Pad each digit word on both sides so it never fuses with a
            # following letter ("4j" -> "four j", not "fourj"). The
            # split/join then collapses the padding to single spaces.
            padded = "".join(f" {number_word_map[c]} " if c in number_word_map else c for c in word)
            converted_words.append(" ".join(padded.split()))
        elif word in number_word_map:
            converted_words.append(number_word_map[word])
        else:
            converted_words.append(word)

    return " ".join(converted_words)


def clean_text(text):
    """Normalize a transcript before WER scoring.

    Steps, in order:

    1. Remove ASCII punctuation (``string.punctuation``).
    2. Spell out single digits with ``convert_numbers_to_words``.
    3. Lowercase.

    Punctuation is removed *before* digit conversion. Otherwise a digit that
    carries punctuation (``"3."`` at the end of a sentence, ``"(5)"``) is not
    seen as a lone digit and survives as ``"3"``, while a bare ``"3"`` becomes
    ``"three"``. That produces spurious mismatches between reference and
    hypothesis.

    Args:
        text: Raw transcript text.

    Returns:
        The normalized text.
    """
    text = text.translate(str.maketrans("", "", string.punctuation))
    text = convert_numbers_to_words(text)
    return text.lower()


def wer_pipe(gen_dir: str, target_dir: str, model_id="openai/whisper-large-v3-turbo"):
    """Transcribe generated audio with an ASR model and return the mean per-file WER.

    1. Every ``*.wav`` in ``gen_dir`` is loaded at 16 kHz and transcribed with
       ``model_id`` (language forced to English via ``generate_kwargs``). The
       transcript is written next to the audio as ``<stem>.asrtxt``.
    2. Every ``*.txt`` reference in ``target_dir`` is paired with
       ``<gen_dir>/<stem>.asrtxt``. Both texts are normalized with
       ``clean_text`` and scored with the ``evaluate`` "wer" metric.

    Pairs whose transcript cannot be read or scored are reported and skipped.
    Pairs where either normalized text is empty are skipped silently.

    Args:
        gen_dir: Directory containing the generated ``.wav`` files.
        target_dir: Directory containing the reference ``.txt`` transcripts.
        model_id: Hugging Face id of the speech-to-text model.

    Returns:
        The mean of the per-file WER values. NumPy yields ``nan`` (with a
        RuntimeWarning) when no pair was scored.
    """
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    torch_dtype = torch.float16 if use_cuda else torch.float32
    print(f"Using Model: {model_id} for WER Evaluation")
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
    ).to(device)

    processor = AutoProcessor.from_pretrained(model_id)

    pipe = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device,
    )

    gen_list = list(Path(gen_dir).glob("*.wav"))
    for wav in tqdm(gen_list, desc="Processing audio files"):
        # glob() can return dangling symlinks; skip anything that no longer exists.
        if not wav.exists():
            continue

        text = pipe(librosa.load(wav, sr=16000)[0], generate_kwargs={"language": "english"})["text"]
        # Explicit UTF-8 so non-ASCII transcripts round-trip regardless of locale.
        wav.with_suffix(".asrtxt").write_text(text, encoding="utf-8")

    wer_metric = load("wer")

    val_list = list(Path(target_dir).glob("*.txt"))

    wer = []
    for txt in tqdm(val_list, desc="Calculating WER"):
        try:
            # The reference texts were transcribed automatically and never
            # verified by hand, so both sides get the same normalization.
            # dict.fromkeys drops duplicate lines like set() did, but keeps
            # their original order. set() order depends on per-process string
            # hash randomization, which made the joined reference (and the
            # WER) vary from run to run.
            target_lines = dict.fromkeys(txt.read_text(encoding="utf-8").splitlines())
            target_text = clean_text(" ".join(target_lines))

            gen_path = Path(gen_dir) / f"{txt.stem}.asrtxt"
            gen_text = clean_text(" ".join(gen_path.read_text(encoding="utf-8").splitlines()))

            if not target_text or not gen_text:
                continue

            wer_ = wer_metric.compute(references=[target_text], predictions=[gen_text])
        except Exception as e:
            # Best-effort: a missing transcript or a scoring failure skips
            # this pair only.
            print("Error in wer calculation: ", e)
            continue

        wer.append(wer_)

    return np.mean(wer)


def spk_sim_pipe(gen_dir, target_dir):
    """Return the mean speaker-embedding cosine similarity between target and generated audio.

    For every ``*.wav`` in ``target_dir``, the file with the same name in
    ``gen_dir`` is loaded. Both are resampled to 16 kHz and embedded with
    ``microsoft/wavlm-base-sv`` (``WavLMForXVector``). The cosine similarity of
    the two embeddings is then recorded. Pairs that fail for any reason (e.g. a
    missing generated file) are reported and skipped.

    Runs on CUDA when available and falls back to CPU otherwise, matching
    ``wer_pipe``.

    Args:
        gen_dir: Directory containing the generated ``.wav`` files.
        target_dir: Directory containing the reference ``.wav`` files.

    Returns:
        The mean similarity over scored pairs. NumPy yields ``nan`` (with a
        RuntimeWarning) when none were scored.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("microsoft/wavlm-base-sv")
    model = WavLMForXVector.from_pretrained("microsoft/wavlm-base-sv").to(device)

    cosine_sim = torch.nn.CosineSimilarity(dim=-1)

    def embed(audio):
        # Embedding of one 16 kHz waveform, as returned by the x-vector head.
        inputs = feature_extractor(audio, return_tensors="pt", sampling_rate=16000).to(device)
        return model(**inputs).embeddings

    val_list = list(Path(target_dir).glob("*.wav"))

    scos = []

    for target_wav in tqdm(val_list, desc="Calculating speaker similarity"):
        try:
            # Loading is inside the try so that one missing or unreadable
            # file skips this pair instead of aborting the whole evaluation.
            target = librosa.load(target_wav, sr=16000)[0]
            gen = librosa.load(os.path.join(gen_dir, target_wav.name), sr=16000)[0]

            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                embeddings1 = embed(gen)
                embeddings2 = embed(target)
                similarity = cosine_sim(embeddings1[0], embeddings2[0])

        except Exception as e:
            print(f"Error in {target_wav}, {e}")
            continue

        scos.append(similarity.item())

    return np.mean(scos)


def calculate_mcd_for_wav(target_wav, gen_dir, mcd_toolbox_dtw, mcd_toolbox_dtw_sl):
    """Compute both MCD variants for one target file and its generated counterpart.

    The generated file is the one in ``gen_dir`` that has the same file name
    as ``target_wav``. This function is module-level so that
    ``ProcessPoolExecutor`` can pickle it.

    Args:
        target_wav: Path-like object of the reference audio (``.name`` is used).
        gen_dir: Directory containing the generated audio.
        mcd_toolbox_dtw: Toolbox whose ``calculate_mcd`` gives the "dtw" score.
        mcd_toolbox_dtw_sl: Toolbox whose ``calculate_mcd`` gives the "dtw_sl" score.

    Returns:
        A ``(dtw_score, dtw_sl_score)`` tuple.
    """
    generated_path = os.path.join(gen_dir, target_wav.name)
    # Evaluated left to right: the dtw toolbox runs first, then dtw_sl.
    return tuple(
        toolbox.calculate_mcd(target_wav, generated_path)
        for toolbox in (mcd_toolbox_dtw, mcd_toolbox_dtw_sl)
    )


def mcd_pipe(gen_dir, target_dir, num_processes=16):
    """Return mean mel-cepstral distortion (MCD) in "dtw" and "dtw_sl" modes.

    Each ``*.wav`` in ``target_dir`` is compared with the same-named file in
    ``gen_dir``. Pairs are scored in parallel with ``calculate_mcd_for_wav`` on
    a process pool. A pair whose computation raises (e.g. a missing generated
    file) is reported and skipped, matching the WER and speaker-similarity
    pipes. A broken process pool is still raised, because that failure is
    systemic rather than per-file.

    Args:
        gen_dir: Directory containing the generated ``.wav`` files.
        target_dir: Directory containing the reference ``.wav`` files.
        num_processes: Maximum number of worker processes.

    Returns:
        ``(mean_mcd_dtw, mean_mcd_dtw_sl)``. Each is ``nan`` (with a NumPy
        RuntimeWarning) when no pair was scored.
    """
    from concurrent.futures.process import BrokenProcessPool

    mcd_toolbox_dtw = Calculate_MCD(MCD_mode="dtw")
    mcd_toolbox_dtw_sl = Calculate_MCD(MCD_mode="dtw_sl")

    val_list = list(Path(target_dir).glob("*.wav"))

    mcd_dtw = []
    mcd_dtw_sl = []

    with ProcessPoolExecutor(max_workers=num_processes) as executor:
        futures = [
            executor.submit(calculate_mcd_for_wav, target_wav, gen_dir, mcd_toolbox_dtw, mcd_toolbox_dtw_sl)
            for target_wav in val_list
        ]
        for target_wav, future in tqdm(zip(val_list, futures), total=len(futures), desc="Calculating MCD"):
            try:
                _mcd_dtw, _mcd_dtw_sl = future.result()
            except BrokenProcessPool:
                # Every remaining future would fail too; surface it instead of
                # printing one error per file and returning nan.
                raise
            except Exception as e:
                print(f"Error in {target_wav}, {e}")
                continue
            mcd_dtw.append(_mcd_dtw)
            mcd_dtw_sl.append(_mcd_dtw_sl)

    return np.mean(mcd_dtw), np.mean(mcd_dtw_sl)


def run_all_metrics(gen_dir, target_dir, whisper_model="openai/whisper-large-v3-turbo"):
    """Run WER, speaker-similarity and MCD evaluation in sequence.

    Args:
        gen_dir: Directory containing the generated audio.
        target_dir: Directory containing the reference audio and transcripts.
        whisper_model: ASR model id forwarded to ``wer_pipe``.

    Returns:
        A dict with keys ``"wer"``, ``"speaker_similarity"``, ``"mcd_dtw"`` and
        ``"mcd_dtw_sl"``, in that order.
    """
    print("Running WER evaluation...")
    wer = wer_pipe(gen_dir, target_dir, model_id=whisper_model)

    print("Running speaker similarity evaluation...")
    speaker_similarity = spk_sim_pipe(gen_dir, target_dir)

    print("Running MCD evaluation...")
    mcd_dtw, mcd_dtw_sl = mcd_pipe(gen_dir, target_dir)

    return {
        "wer": wer,
        "speaker_similarity": speaker_similarity,
        "mcd_dtw": mcd_dtw,
        "mcd_dtw_sl": mcd_dtw_sl,
    }


if __name__ == "__main__":
    # Usage:
    #   python eval.py --gen_dir path/to/generated --target_dir path/to/target
    # Generated and target files are paired by file name, so each generated
    # wav must have the same name as its target wav.
    parser = argparse.ArgumentParser(description="Audio evaluation metrics")
    parser.add_argument("--gen_dir", type=str, required=True, help="Directory containing generated audio files")
    parser.add_argument("--target_dir", type=str, required=True, help="Directory containing target audio files")
    parser.add_argument(
        "--metric",
        type=str,
        default="all",
        choices=["wer", "spk_sim", "mcd", "all"],
        help="Evaluation metric to use",
    )
    parser.add_argument(
        "--whisper_model",
        type=str,
        default="openai/whisper-large-v3-turbo",
        help="Whisper model to use for WER evaluation",
    )
    args = parser.parse_args()

    gen_dir = args.gen_dir
    target_dir = args.target_dir

    # isdir rather than exists: a regular file passed here would otherwise be
    # accepted, and the metrics would only fail (or come out as nan) much later.
    if not os.path.isdir(gen_dir):
        raise ValueError(f"Generated audio directory does not exist or is not a directory: {gen_dir}")
    if not os.path.isdir(target_dir):
        raise ValueError(f"Target audio directory does not exist or is not a directory: {target_dir}")

    if args.metric == "all":
        results = run_all_metrics(gen_dir, target_dir, args.whisper_model)
        print("\nEvaluation Results:")
        print(f"WER: {results['wer']:.4f}")
        print(f"Speaker Similarity: {results['speaker_similarity']:.4f}")
        print(f"MCD (DTW): {results['mcd_dtw']:.4f}")
        print(f"MCD (DTW-SL): {results['mcd_dtw_sl']:.4f}")

    elif args.metric == "wer":
        wer = wer_pipe(gen_dir, target_dir, model_id=args.whisper_model)
        print(f"WER: {wer:.4f}")

    elif args.metric == "spk_sim":
        spk_sim = spk_sim_pipe(gen_dir, target_dir)
        print(f"Speaker Similarity: {spk_sim:.4f}")

    elif args.metric == "mcd":
        mcd_dtw, mcd_dtw_sl = mcd_pipe(gen_dir, target_dir)
        print(f"MCD (DTW): {mcd_dtw:.4f}")
        print(f"MCD (DTW-SL): {mcd_dtw_sl:.4f}")