File size: 3,676 Bytes
a286bf8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Minimal inference for SpeechBrain ECAPA-TDNN (ShEMO fine-tuned).
"""

import os
import sys

import torch
import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml
from speechbrain.dataio.dataio import read_audio
from speechbrain.dataio.encoder import CategoricalEncoder

# ------------------------------------------------------------------
# 1) Paths
# ------------------------------------------------------------------
# Experiment folder of the fine-tuning run. The default is the author's
# local results folder; set the SHEMO_EXP_DIR environment variable to
# point the script at another run without editing the source.
_DEFAULT_EXP_DIR = (
    "/mnt/c/Users/NoteBook/Documents/fineTuningSpeechbrain/recipes/ShEMO/"
    "emotion_recognition/results(2)/content/results/ECAPA-TDNN/1968"
)
EXP_DIR = os.environ.get("SHEMO_EXP_DIR", _DEFAULT_EXP_DIR)
HP_FILE = os.path.join(EXP_DIR, "hyperparams.yaml")  # hyperparameters of the run
CKPT_DIR = os.path.join(EXP_DIR, "save")             # folder given to the Checkpointer

# ------------------------------------------------------------------
# 2) Load hyperparams and modules
# ------------------------------------------------------------------
# Explicit encoding so parsing does not depend on the platform's default
# locale encoding (e.g. when the YAML contains non-ASCII comments/paths).
with open(HP_FILE, encoding="utf-8") as f:
    hparams = load_hyperpyyaml(f)

# Modules required for inference, taken from the loaded hyperparameters.
# The dict keys are also used as the recoverable names when restoring the
# checkpoint. NOTE(review): they must match the names the training run
# saved its checkpoint files under — confirm against the files in CKPT_DIR.
_MODULE_NAMES = ("compute_features", "mean_var_norm", "embedding_model", "classifier")
modules = {name: hparams[name] for name in _MODULE_NAMES}

# ------------------------------------------------------------------
# 3) Device setup and checkpoint recovery
# ------------------------------------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Restore the trained parameters into the module objects in ``modules``.
# NOTE(review): with allow_partial_load=True, recoverables that have no
# saved file in the checkpoint are skipped instead of raising, so a
# naming mismatch between ``modules`` keys and the saved files would go
# unnoticed — confirm the contents of CKPT_DIR.
checkpointer = sb.utils.checkpoints.Checkpointer(
    checkpoints_dir=CKPT_DIR,
    recoverables=modules,
    allow_partial_load=True,
)
_recovered_ckpt = checkpointer.recover_if_possible()
if _recovered_ckpt is None:
    # recover_if_possible() returns None when no checkpoint is found; the
    # model would then predict with randomly initialised weights and still
    # print a (meaningless) label, so stop here instead.
    raise FileNotFoundError(f"No SpeechBrain checkpoint found in {CKPT_DIR!r}")

# ------------------------------------------------------------------
# 4) Simple batch container
# ------------------------------------------------------------------
class SimpleBatch:
    """Tiny stand-in for a SpeechBrain batch exposing ``sig = (wav, lens)``."""

    def __init__(self, wav, lens):
        self.sig = wav, lens

    def to(self, device):
        """Move both tensors held in ``sig`` to ``device``; returns ``self``."""
        self.sig = tuple(tensor.to(device) for tensor in self.sig)
        return self

# ------------------------------------------------------------------
# 5) Brain class
# ------------------------------------------------------------------
class EmoIdBrain(sb.Brain):
    """Brain subclass implementing only the forward pass used for inference."""

    def compute_forward(self, batch, stage):
        """Run features -> normalisation -> embedding -> classifier.

        Args:
            batch: object whose ``sig`` attribute is a ``(wavs, lens)`` tuple.
            stage: ``sb.Stage`` value; accepted for the Brain API, unused here.

        Returns:
            Whatever ``self.modules.classifier`` returns for the embeddings.
        """
        signal, rel_lengths = batch.sig
        normalized = self.modules.mean_var_norm(
            self.modules.compute_features(signal), rel_lengths
        )
        embeddings = self.modules.embedding_model(normalized, rel_lengths)
        return self.modules.classifier(embeddings)


brain = EmoIdBrain(
    modules=modules,
    hparams=hparams,
    run_opts=dict(device=device),
    checkpointer=checkpointer,
)

# ------------------------------------------------------------------
# 6) Emotion labels
# ------------------------------------------------------------------
# Fallback index -> label mapping, used only when the label encoder saved
# by the training run cannot be found.
# NOTE(review): this order must match the encoder used during training.
IDX2LAB = [
    "anger", "sadness", "neutral",
    "surprise", "happiness", "fear"
]

# Prefer the mapping written during training so the class indices cannot
# silently drift from the hard-coded list above. The file is assumed to be
# ``label_encoder.txt`` inside CKPT_DIR — TODO confirm for this run.
_LABEL_ENCODER_FILE = os.path.join(CKPT_DIR, "label_encoder.txt")
_label_encoder = CategoricalEncoder()
if _label_encoder.load_if_possible(_LABEL_ENCODER_FILE):
    IDX2LAB = [_label_encoder.ind2lab[i] for i in sorted(_label_encoder.ind2lab)]

# ------------------------------------------------------------------
# 7) Prediction function
# ------------------------------------------------------------------
def predict(wav_path: str) -> str:
    """Predict the emotion label of a single audio file.

    Args:
        wav_path: Path to an audio file readable by ``read_audio``.

    Returns:
        The entry of ``IDX2LAB`` at the arg-max of the classifier output.

    Raises:
        ValueError: If the file contains no audio samples.
    """
    # Per the SpeechBrain docs, read_audio yields [samples] for mono and
    # [samples, channels] for multi-channel files.
    wav = torch.as_tensor(read_audio(wav_path), dtype=torch.float32)
    if wav.dim() == 2:
        # Down-mix to mono so that unsqueeze(0) below gives a [1, time] batch
        # instead of a 3-D tensor.
        wav = wav.mean(dim=-1)
    if wav.numel() == 0:
        raise ValueError(f"Audio file {wav_path!r} contains no samples")
    # NOTE(review): no resampling is done; the file is assumed to already be
    # at the sample rate used in training — TODO confirm against hyperparams.

    wavs = wav.unsqueeze(0)       # single-utterance batch
    lens = torch.tensor([1.0])    # relative length: the whole signal is valid

    batch = SimpleBatch(wavs, lens).to(device)
    brain.modules.eval()

    with torch.no_grad():
        logits = brain.compute_forward(batch, stage=sb.Stage.TEST)

    idx = int(logits.argmax(dim=-1).item())
    return IDX2LAB[idx]

# ------------------------------------------------------------------
# 8) Run
# ------------------------------------------------------------------
if __name__ == "__main__":
    # An optional first command-line argument overrides the default file:
    #   python <this_script> path/to/clip.wav
    WAV_FILE = sys.argv[1] if len(sys.argv) > 1 else "shortvoice.wav"
    print("Predicted emotion:", predict(WAV_FILE))