File size: 1,778 Bytes
1763424
 
 
 
 
a246192
1763424
 
 
 
 
 
a246192
 
 
 
 
1763424
 
a246192
1763424
a246192
 
1763424
a246192
 
 
 
 
 
1763424
a246192
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1763424
a246192
1763424
a246192
 
 
1763424
a246192
 
1763424
a246192
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#!/usr/bin/env python3
import sys
import torch

from transformers import AutoModelForCTC, AutoProcessor
from datasets import load_dataset, load_metric
import torchaudio.functional as F

# Command-line interface (all positional):
#   MODEL_ID      checkpoint id/path handed to ``from_pretrained``
#   LANG          dataset configuration name passed to ``load_dataset``
#   LANG_PHONEME  forwarded to the tokenizer as ``phonemizer_lang`` when
#                 non-empty; pass '' to tokenize references without it
#   NUM_SAMPLES   number of test samples to evaluate
# Fail early with a usage hint instead of an IndexError on a missing argument,
# and before any model or dataset is loaded.
if len(sys.argv) < 5:
    sys.exit(
        f"usage: {sys.argv[0]} MODEL_ID LANG LANG_PHONEME NUM_SAMPLES\n"
        "  (pass '' as LANG_PHONEME to skip the phonemizer language)"
    )

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

model_id = sys.argv[1]
lang = sys.argv[2]
lang_phoneme = sys.argv[3]
num_samples = int(sys.argv[4])

# Model and processor come from the same checkpoint; only the model is moved
# to the selected device.
model = AutoModelForCTC.from_pretrained(model_id).to(device)
processor = AutoProcessor.from_pretrained(model_id)

# The test split is opened in streaming mode and consumed through a single
# iterator, one sample at a time.
ds = load_dataset("common_voice", lang, split="test", streaming=True)
sample_iter = iter(ds)

# Word error rate and character error rate metric objects.
wer = load_metric("wer")
cer = load_metric("cer")

# Token-id sequences collected per sample; they are decoded in bulk after the
# loop so predictions and references go through the same decoder.
targets_ids = []
predictions_ids = []

# Tokenizer options do not depend on the sample, so build them once.
# An empty lang_phoneme means no phonemizer language is passed.
kwargs = {"phonemizer_lang": lang_phoneme} if lang_phoneme else {}

# zip() with range() caps the loop at num_samples and also ends it cleanly if
# the dataset runs out first (a bare next() would raise StopIteration and
# discard everything evaluated so far). range() is listed first so no extra
# sample is pulled from the stream once the cap is reached.
for _, sample in zip(range(num_samples), sample_iter):
    audio = sample["audio"]
    # Resample from the clip's own rate rather than assuming 48 kHz, so clips
    # stored at another rate are not silently stretched or squeezed.
    waveform = torch.tensor(audio["array"])
    resampled_audio = F.resample(waveform, audio["sampling_rate"], 16_000).numpy()
    input_values = processor(resampled_audio, return_tensors="pt").input_values

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(input_values.to(device)).logits

    # Greedy decoding: take the highest-scoring token at every position.
    prediction_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(prediction_ids)

    print(f"Correct: {sample['sentence']}")
    print(f"Predict: {transcription}")
    print(20 * '-')

    # A single clip was fed in, so index 0 is its id sequence.
    predictions_ids.append(prediction_ids[0].tolist())

    # Tokenize the reference sentence into ids with the same tokenizer.
    targets_ids.append(processor.tokenizer(sample["sentence"], **kwargs).input_ids)

print("Compute metrics.....")

# Turn the collected id sequences back into strings. The leftover interactive
# debugger breakpoint that used to sit here was removed: it halted every run
# and required the third-party ipdb package.
transcriptions = processor.batch_decode(predictions_ids)
# References are decoded with group_tokens=False.
# NOTE(review): presumably so repeated tokens in the references are kept
# rather than collapsed like CTC output -- confirm against the tokenizer docs.
targets_str = processor.batch_decode(targets_ids, group_tokens=False)

# NOTE: from here on `wer` and `cer` hold the computed scores, not the metric
# objects they were loaded as.
wer = wer.compute(predictions=transcriptions, references=targets_str)
cer = cer.compute(predictions=transcriptions, references=targets_str)

print("wer", wer)
print("cer", cer)