mobina1380 committed
Commit a286bf8 · verified · 1 parent: c7a5220

Update inference.py

Files changed (1)
  1. inference.py +110 -103
inference.py CHANGED
@@ -1,103 +1,110 @@
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- Minimal inference for SpeechBrain ECAPA-TDNN (ShEMO fine-tuned).
- """
-
- import os
- import torch
- import speechbrain as sb
- from hyperpyyaml import load_hyperpyyaml
- from speechbrain.dataio.dataio import read_audio
-
- # ------------------------------------------------------------------
- # 1) paths
- # ------------------------------------------------------------------
- EXP_DIR = (
-     "/mnt/c/Users/NoteBook/Documents/fineTuningSpeechbrain/recipes/ShEMO/"
-     "emotion_recognition/results(2)/content/results/ECAPA-TDNN/1968"
- )
- HP_FILE = os.path.join(EXP_DIR, "hyperparams.yaml")
- CKPT_DIR = os.path.join(EXP_DIR, "save")
-
- # ------------------------------------------------------------------
- # 2) hparams & modules
- # ------------------------------------------------------------------
- with open(HP_FILE) as f:
-     hparams = load_hyperpyyaml(f)
-
- modules = {
-     "compute_features": hparams["compute_features"],
-     "mean_var_norm"   : hparams["mean_var_norm"],
-     "embedding_model" : hparams["embedding_model"],
-     "classifier"      : hparams["classifier"],
- }
-
- checkpointer = sb.utils.checkpoints.Checkpointer(
-     checkpoints_dir=CKPT_DIR,
-     recoverables=modules,
-     allow_partial_load=True,
- )
- checkpointer.recover_if_possible()
-
- # ------------------------------------------------------------------
- # 3) Simple batch container (without PaddedBatch)
- # ------------------------------------------------------------------
- class SimpleBatch:
-     def __init__(self, wav, lens):
-         self.sig = (wav, lens)
-
-     def to(self, device):
-         wav, lens = self.sig
-         self.sig = (wav.to(device), lens.to(device))
-         return self
-
- # ------------------------------------------------------------------
- # 4) Brain for inference
- # ------------------------------------------------------------------
- class EmoIdBrain(sb.Brain):
-     def compute_forward(self, batch, stage):
-         wavs, lens = batch.sig
-         feats = self.modules.compute_features(wavs)
-         feats = self.modules.mean_var_norm(feats, lens)
-         emb = self.modules.embedding_model(feats, lens)
-         out = self.modules.classifier(emb)
-         return out
-
- device = 'cpu'
- brain = EmoIdBrain(modules, hparams, run_opts={"device": device},
-                    checkpointer=checkpointer)
- print('dddddddddddddddd')
- # ------------------------------------------------------------------
- # 5) emotion labels (hard-coded)
- # ------------------------------------------------------------------
- IDX2LAB = [
-     "anger",      # 0
-     "sadness",    # 1
-     "neutral",    # 2
-     "surprise",   # 3
-     "happiness",  # 4
-     "fear",       # 5
- ]
-
- # # ------------------------------------------------------------------
- # # 6) predict function
- # # ------------------------------------------------------------------
- def predict(wav_path: str) -> str:
-     wav = torch.tensor(read_audio(wav_path)).float().unsqueeze(0)  # [1,T]
-     lens = torch.tensor([1.0])  # full length
-     batch = SimpleBatch(wav, lens).to(device)
-
-     brain.modules.eval()
-     # disable dropout if any
-     with torch.no_grad():
-         logits = brain.compute_forward(batch, stage=sb.Stage.TEST)
-     idx = int(logits.argmax(dim=-1))
-     return IDX2LAB[idx]
-
- # # ------------------------------------------------------------------
- # # 7) run
- # # ------------------------------------------------------------------
- if __name__ == "__main__":
-     WAV_FILE = "shortvoice.wav"  # change to your wav
-     print("Predicted emotion:", predict(WAV_FILE))
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ """
+ Minimal inference for SpeechBrain ECAPA-TDNN (ShEMO fine-tuned).
+ """
+
+ import os
+ import torch
+ import speechbrain as sb
+ from hyperpyyaml import load_hyperpyyaml
+ from speechbrain.dataio.dataio import read_audio
+
+ # ------------------------------------------------------------------
+ # 1) Paths
+ # ------------------------------------------------------------------
+ EXP_DIR = (
+     "/mnt/c/Users/NoteBook/Documents/fineTuningSpeechbrain/recipes/ShEMO/"
+     "emotion_recognition/results(2)/content/results/ECAPA-TDNN/1968"
+ )
+ HP_FILE = os.path.join(EXP_DIR, "hyperparams.yaml")
+ CKPT_DIR = os.path.join(EXP_DIR, "save")
+
+ # ------------------------------------------------------------------
+ # 2) Load hyperparams and modules
+ # ------------------------------------------------------------------
+ with open(HP_FILE) as f:
+     hparams = load_hyperpyyaml(f)
+
+ modules = {
+     "compute_features": hparams["compute_features"],
+     "mean_var_norm"   : hparams["mean_var_norm"],
+     "embedding_model" : hparams["embedding_model"],
+     "classifier"      : hparams["classifier"],
+ }
+
+ # ------------------------------------------------------------------
+ # 3) Device setup
+ # ------------------------------------------------------------------
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f"Using device: {device}")
+
+ checkpointer = sb.utils.checkpoints.Checkpointer(
+     checkpoints_dir=CKPT_DIR,
+     recoverables=modules,
+     allow_partial_load=True,
+ )
+ checkpointer.recover_if_possible()
+
+ # ------------------------------------------------------------------
+ # 4) Simple batch container
+ # ------------------------------------------------------------------
+ class SimpleBatch:
+     def __init__(self, wav, lens):
+         self.sig = (wav, lens)
+
+     def to(self, device):
+         wav, lens = self.sig
+         self.sig = (wav.to(device), lens.to(device))
+         return self
+
+ # ------------------------------------------------------------------
+ # 5) Brain class
+ # ------------------------------------------------------------------
+ class EmoIdBrain(sb.Brain):
+     def compute_forward(self, batch, stage):
+         wavs, lens = batch.sig
+         feats = self.modules.compute_features(wavs)
+         feats = self.modules.mean_var_norm(feats, lens)
+         emb = self.modules.embedding_model(feats, lens)
+         out = self.modules.classifier(emb)
+         return out
+
+ brain = EmoIdBrain(
+     modules=modules,
+     hparams=hparams,
+     run_opts={"device": device},
+     checkpointer=checkpointer
+ )
+
+ # ------------------------------------------------------------------
+ # 6) Emotion labels
+ # ------------------------------------------------------------------
+ IDX2LAB = [
+     "anger", "sadness", "neutral",
+     "surprise", "happiness", "fear"
+ ]
+
+ # ------------------------------------------------------------------
+ # 7) Prediction function
+ # ------------------------------------------------------------------
+ def predict(wav_path: str) -> str:
+     wav_raw = read_audio(wav_path)
+     wav = wav_raw.clone().detach().float().unsqueeze(0) if isinstance(wav_raw, torch.Tensor) else torch.tensor(wav_raw, dtype=torch.float32).unsqueeze(0)
+     lens = torch.tensor([1.0])
+
+     batch = SimpleBatch(wav, lens).to(device)
+     brain.modules.eval()
+
+     with torch.no_grad():
+         logits = brain.compute_forward(batch, stage=sb.Stage.TEST)
+
+     idx = int(logits.argmax(dim=-1))
+     return IDX2LAB[idx]
+
+ # ------------------------------------------------------------------
+ # 8) Run
+ # ------------------------------------------------------------------
+ if __name__ == "__main__":
+     WAV_FILE = "shortvoice.wav"
+     print("Predicted emotion:", predict(WAV_FILE))