Update inference.py
inference.py +110 -103
inference.py CHANGED
@@ -1,103 +1,110 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-"""
-Minimal inference for SpeechBrain ECAPA-TDNN (ShEMO fine-tuned).
-"""
-
-import os
-import torch
-import speechbrain as sb
-from hyperpyyaml import load_hyperpyyaml
-from speechbrain.dataio.dataio import read_audio
-
-# ------------------------------------------------------------------
-# 1)
-# ------------------------------------------------------------------
-EXP_DIR = (
-    "/mnt/c/Users/NoteBook/Documents/fineTuningSpeechbrain/recipes/ShEMO/"
-    "emotion_recognition/results(2)/content/results/ECAPA-TDNN/1968"
-)
-HP_FILE = os.path.join(EXP_DIR, "hyperparams.yaml")
-CKPT_DIR = os.path.join(EXP_DIR, "save")
-
-# ------------------------------------------------------------------
-# 2)
-# ------------------------------------------------------------------
-with open(HP_FILE) as f:
-    hparams = load_hyperpyyaml(f)
-
-modules = {
-    "compute_features": hparams["compute_features"],
-    "mean_var_norm"   : hparams["mean_var_norm"],
-    "embedding_model" : hparams["embedding_model"],
-    "classifier"      : hparams["classifier"],
-}
-
-
-
-
-
-)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-"
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Minimal inference for SpeechBrain ECAPA-TDNN (ShEMO fine-tuned).
+"""
+
+import os
+import torch
+import speechbrain as sb
+from hyperpyyaml import load_hyperpyyaml
+from speechbrain.dataio.dataio import read_audio
+
+# ------------------------------------------------------------------
+# 1) Paths
+# ------------------------------------------------------------------
+EXP_DIR = (
+    "/mnt/c/Users/NoteBook/Documents/fineTuningSpeechbrain/recipes/ShEMO/"
+    "emotion_recognition/results(2)/content/results/ECAPA-TDNN/1968"
+)
+HP_FILE = os.path.join(EXP_DIR, "hyperparams.yaml")
+CKPT_DIR = os.path.join(EXP_DIR, "save")
+
+# ------------------------------------------------------------------
+# 2) Load hyperparams and modules
+# ------------------------------------------------------------------
+with open(HP_FILE) as f:
+    hparams = load_hyperpyyaml(f)
+
+modules = {
+    "compute_features": hparams["compute_features"],
+    "mean_var_norm"   : hparams["mean_var_norm"],
+    "embedding_model" : hparams["embedding_model"],
+    "classifier"      : hparams["classifier"],
+}
+
+# ------------------------------------------------------------------
+# 3) Device setup
+# ------------------------------------------------------------------
+device = "cuda" if torch.cuda.is_available() else "cpu"
+print(f"Using device: {device}")
+
+checkpointer = sb.utils.checkpoints.Checkpointer(
+    checkpoints_dir=CKPT_DIR,
+    recoverables=modules,
+    allow_partial_load=True,
+)
+checkpointer.recover_if_possible()
+
+# ------------------------------------------------------------------
+# 4) Simple batch container
+# ------------------------------------------------------------------
+class SimpleBatch:
+    def __init__(self, wav, lens):
+        self.sig = (wav, lens)
+
+    def to(self, device):
+        wav, lens = self.sig
+        self.sig = (wav.to(device), lens.to(device))
+        return self
+
+# ------------------------------------------------------------------
+# 5) Brain class
+# ------------------------------------------------------------------
+class EmoIdBrain(sb.Brain):
+    def compute_forward(self, batch, stage):
+        wavs, lens = batch.sig
+        feats = self.modules.compute_features(wavs)
+        feats = self.modules.mean_var_norm(feats, lens)
+        emb = self.modules.embedding_model(feats, lens)
+        out = self.modules.classifier(emb)
+        return out
+
+brain = EmoIdBrain(
+    modules=modules,
+    hparams=hparams,
+    run_opts={"device": device},
+    checkpointer=checkpointer
+)
+
+# ------------------------------------------------------------------
+# 6) Emotion labels
+# ------------------------------------------------------------------
+IDX2LAB = [
+    "anger", "sadness", "neutral",
+    "surprise", "happiness", "fear"
+]
+
+# ------------------------------------------------------------------
+# 7) Prediction function
+# ------------------------------------------------------------------
+def predict(wav_path: str) -> str:
+    wav_raw = read_audio(wav_path)
+    wav = wav_raw.clone().detach().float().unsqueeze(0) if isinstance(wav_raw, torch.Tensor) else torch.tensor(wav_raw, dtype=torch.float32).unsqueeze(0)
+    lens = torch.tensor([1.0])
+
+    batch = SimpleBatch(wav, lens).to(device)
+    brain.modules.eval()
+
+    with torch.no_grad():
+        logits = brain.compute_forward(batch, stage=sb.Stage.TEST)
+
+    idx = int(logits.argmax(dim=-1))
+    return IDX2LAB[idx]
+
+# ------------------------------------------------------------------
+# 8) Run
+# ------------------------------------------------------------------
+if __name__ == "__main__":
+    WAV_FILE = "shortvoice.wav"
+    print("Predicted emotion:", predict(WAV_FILE))