"""Minimal inference for a SpeechBrain ECAPA-TDNN emotion-recognition model
fine-tuned on ShEMO."""

import os

import torch
import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml
from speechbrain.dataio.dataio import read_audio
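
# Training-output locations: a SpeechBrain recipe typically writes its
# resolved hyperparameters to hyperparams.yaml and its checkpoints under
# save/ inside the experiment's output folder.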
EXP_DIR = (
    "/mnt/c/Users/NoteBook/Documents/fineTuningSpeechbrain/recipes/ShEMO/"
    "emotion_recognition/results(2)/content/results/ECAPA-TDNN/1968"
)
HP_FILE = os.path.join(EXP_DIR, "hyperparams.yaml")
CKPT_DIR = os.path.join(EXP_DIR, "save")
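
# Rebuild the model exactly as defined at training time: load_hyperpyyaml
# instantiates every object declared in the recipe's YAML.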
with open(HP_FILE) as f:
    hparams = load_hyperpyyaml(f)
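
# The dict keys must match the names under which these modules were
# registered with the checkpointer during training, or their saved
# parameters will not be mapped back onto them.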
modules = {
    "compute_features": hparams["compute_features"],
    "mean_var_norm": hparams["mean_var_norm"],
    "embedding_model": hparams["embedding_model"],
    "classifier": hparams["classifier"],
}

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
checkpointer = sb.utils.checkpoints.Checkpointer(
    checkpoints_dir=CKPT_DIR,
    recoverables=modules,
    allow_partial_load=True,
)
checkpointer.recover_if_possible()
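

# A tiny stand-in for SpeechBrain's PaddedBatch: compute_forward below only
# needs batch.sig to hold (waveforms, relative lengths) and the batch to be
# movable between devices.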
class SimpleBatch:
    def __init__(self, wav, lens):
        self.sig = (wav, lens)

    def to(self, device):
        wav, lens = self.sig
        self.sig = (wav.to(device), lens.to(device))
        return self
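

# The forward pass mirrors the training recipe: filterbank features ->
# mean/variance normalization -> ECAPA-TDNN embedding -> classifier head.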
class EmoIdBrain(sb.Brain):
    def compute_forward(self, batch, stage):
        wavs, lens = batch.sig
        feats = self.modules.compute_features(wavs)
        feats = self.modules.mean_var_norm(feats, lens)
        emb = self.modules.embedding_model(feats, lens)
        out = self.modules.classifier(emb)
        return out


# sb.Brain wraps the modules in a ModuleDict and moves them to the device
# given in run_opts.
brain = EmoIdBrain(
    modules=modules,
    hparams=hparams,
    run_opts={"device": device},
    checkpointer=checkpointer,
)
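
# NOTE (assumption): this index-to-label order must match the label encoder
# produced during training (often saved as label_encoder.txt under save/);
# a mismatched order silently returns wrong emotion names.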
IDX2LAB = [
    "anger", "sadness", "neutral",
    "surprise", "happiness", "fear",
]


def predict(wav_path: str) -> str:
    """Return the predicted emotion label for one audio file."""
    # NOTE: the file's sample rate must match training (commonly 16 kHz in
    # SpeechBrain recipes); resample beforehand if it differs.
    wav_raw = read_audio(wav_path)
    # read_audio normally returns a tensor; handle array-like output too.
    if isinstance(wav_raw, torch.Tensor):
        wav = wav_raw.clone().detach().float()
    else:
        wav = torch.tensor(wav_raw, dtype=torch.float32)
    wav = wav.unsqueeze(0)  # add a batch dimension: [1, time]
    lens = torch.tensor([1.0])  # relative length: the whole signal is valid

    batch = SimpleBatch(wav, lens).to(device)
    brain.modules.eval()

    with torch.no_grad():
        logits = brain.compute_forward(batch, stage=sb.Stage.TEST)

    idx = int(logits.argmax(dim=-1))
    return IDX2LAB[idx]
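

# A hedged sketch (not part of the original script): per-class scores for the
# same pipeline. It assumes the classifier head emits log-posteriors
# (log-softmax), as in typical SpeechBrain classification recipes; if yours
# returns raw logits, use torch.softmax(scores, dim=-1) instead of .exp().
def predict_proba(wav_path: str) -> dict:
    wav = read_audio(wav_path).float().unsqueeze(0)
    batch = SimpleBatch(wav, torch.tensor([1.0])).to(device)
    brain.modules.eval()
    with torch.no_grad():
        scores = brain.compute_forward(batch, stage=sb.Stage.TEST)
    probs = scores.exp().flatten()  # log-posteriors -> probabilities
    return {lab: float(p) for lab, p in zip(IDX2LAB, probs)}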


if __name__ == "__main__":
    WAV_FILE = "shortvoice.wav"
    print("Predicted emotion:", predict(WAV_FILE))