# Source: speechbrain-persian-ser / inference.py
# Committed by: mobina1380
# Last commit message: "Update inference.py"
# Commit hash: a286bf8 (verified)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Minimal inference for SpeechBrain ECAPA-TDNN (ShEMO fine-tuned).
"""
import os
import torch
import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml
from speechbrain.dataio.dataio import read_audio
# ------------------------------------------------------------------
# 1) Paths
# ------------------------------------------------------------------
# Machine-specific absolute path to the experiment output folder; the two
# literals are joined into a single path string.
EXP_DIR = (
    "/mnt/c/Users/NoteBook/Documents/fineTuningSpeechbrain/recipes/ShEMO/"
    + "emotion_recognition/results(2)/content/results/ECAPA-TDNN/1968"
)
# The hyperparameter YAML and the checkpoint folder both sit directly
# under EXP_DIR.
HP_FILE, CKPT_DIR = (
    os.path.join(EXP_DIR, leaf) for leaf in ("hyperparams.yaml", "save")
)
# ------------------------------------------------------------------
# 2) Load hyperparams and modules
# ------------------------------------------------------------------
# Parse the experiment YAML into the hyperparameter mapping. The encoding is
# pinned to UTF-8 so that parsing does not depend on the platform's default
# locale encoding.
with open(HP_FILE, encoding="utf-8") as f:
    hparams = load_hyperpyyaml(f)

# Modules needed for inference, keyed by the attribute names that
# EmoIdBrain.compute_forward looks up on self.modules.
modules = {
    "compute_features": hparams["compute_features"],
    "mean_var_norm": hparams["mean_var_norm"],
    "embedding_model": hparams["embedding_model"],
    "classifier": hparams["classifier"],
}
# ------------------------------------------------------------------
# 3) Device setup
# ------------------------------------------------------------------
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Checkpointer that restores the inference modules from CKPT_DIR.
# allow_partial_load=True is kept from the original configuration.
checkpointer = sb.utils.checkpoints.Checkpointer(
    checkpoints_dir=CKPT_DIR,
    recoverables=modules,
    allow_partial_load=True,
)
# recover_if_possible() returns None when no checkpoint was found. In that
# case the modules keep whatever weights the YAML gave them, so report it
# instead of silently producing meaningless predictions.
if checkpointer.recover_if_possible() is None:
    print(
        f"WARNING: no checkpoint found in {CKPT_DIR!r}; "
        "modules were NOT restored from trained weights."
    )
# ------------------------------------------------------------------
# 4) Simple batch container
# ------------------------------------------------------------------
class SimpleBatch:
    """Bare-bones batch exposing only the ``sig`` field read by the brain."""

    def __init__(self, wav, lens):
        """Store the waveform and its lengths as the ``(wav, lens)`` pair."""
        self.sig = wav, lens

    def to(self, device):
        """Move both members of ``sig`` to ``device`` and return this batch.

        The batch is updated in place; the returned object is ``self`` so
        the call can be chained onto the constructor.
        """
        signal, lengths = self.sig
        signal = signal.to(device)
        lengths = lengths.to(device)
        self.sig = signal, lengths
        return self
# ------------------------------------------------------------------
# 5) Brain class
# ------------------------------------------------------------------
class EmoIdBrain(sb.Brain):
    """Brain subclass that defines only the forward pass used for inference."""

    def compute_forward(self, batch, stage):
        """Run features -> normalisation -> embedding -> classifier.

        ``batch.sig`` must unpack into a ``(waveforms, lengths)`` pair.
        ``stage`` is accepted to match the Brain interface and is not
        consulted here. Returns whatever the classifier module outputs.
        """
        signal, rel_lens = batch.sig
        mods = self.modules
        normed = mods.mean_var_norm(mods.compute_features(signal), rel_lens)
        embedding = mods.embedding_model(normed, rel_lens)
        return mods.classifier(embedding)
# Single Brain instance shared by predict(); run_opts carries the device
# chosen earlier in this file.
brain = EmoIdBrain(
    checkpointer=checkpointer,
    run_opts=dict(device=device),
    hparams=hparams,
    modules=modules,
)
# ------------------------------------------------------------------
# 6) Emotion labels
# ------------------------------------------------------------------
# Class-index -> emotion-name table: entry i names classifier output i.
# NOTE(review): the order must match the label encoding used at training
# time, which this file cannot verify — confirm against the training run.
IDX2LAB = [
    "anger",  # 0
    "sadness",  # 1
    "neutral",  # 2
    "surprise",  # 3
    "happiness",  # 4
    "fear",  # 5
]
# ------------------------------------------------------------------
# 7) Prediction function
# ------------------------------------------------------------------
def predict(wav_path: str) -> str:
    """Return the predicted emotion label for one audio file.

    Args:
        wav_path: Path to an audio file readable by ``read_audio``.

    Returns:
        The entry of ``IDX2LAB`` selected by the arg-max of the classifier
        output.

    Raises:
        IndexError: If the arg-max index has no entry in ``IDX2LAB``.

    Note:
        No resampling is performed; the file is assumed to already be at the
        sample rate the model was trained on — TODO confirm.
    """
    wav = torch.as_tensor(read_audio(wav_path), dtype=torch.float32)
    # Multi-channel files come back with a trailing channel axis; average the
    # channels so the model always receives a mono (time,) signal. Mono input
    # is left untouched.
    if wav.dim() > 1:
        wav = wav.mean(dim=-1)
    wav = wav.unsqueeze(0)  # add the batch dimension
    lens = torch.tensor([1.0])  # one utterance, using its full length
    batch = SimpleBatch(wav, lens).to(device)

    brain.modules.eval()  # switch the modules to evaluation mode
    with torch.no_grad():
        logits = brain.compute_forward(batch, stage=sb.Stage.TEST)
    idx = int(logits.argmax(dim=-1))
    return IDX2LAB[idx]
# ------------------------------------------------------------------
# 8) Run
# ------------------------------------------------------------------
if __name__ == "__main__":
    import sys

    # An optional first command-line argument selects the audio file; with
    # no argument the original default file name is used.
    WAV_FILE = sys.argv[1] if len(sys.argv) > 1 else "shortvoice.wav"
    print("Predicted emotion:", predict(WAV_FILE))