mobina1380 committed on
Commit
2c112a2
·
1 Parent(s): 73fce0c

First Persian SER model with SpeechBrain

Browse files
README.md ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Persian Speech Emotion Recognition with SpeechBrain (ShEMO)
2
+
3
+ This model is a fine-tuned ECAPA-TDNN using the [ShEMO](https://github.com/ashkanpourmir/shEMO-database) dataset for Persian speech emotion recognition.
4
+ Trained with [SpeechBrain](https://github.com/speechbrain/speechbrain).
5
+
6
+ **Classes**: `anger`, `sadness`, `neutral`, `surprise`, `happiness`, `fear`
7
+
8
+ To use:
9
+ ```python
10
+ from inference import predict
11
+ print(predict("yourfile.wav"))
12
+ ```
custom.py ADDED
@@ -0,0 +1 @@
 
 
1
+ /mnt/c/Users/NoteBook/Documents/fineTuningSpeechbrain/recipes/ShEMO/emotion_recognition/results(2)/content/results/ECAPA-TDNN/1968/custom.py
hyperparams.yaml ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Generated 2025-04-21 from:
2
+ # /content/test/hparams/train.yaml
3
+ # yamllint disable
4
+ # ########################################
5
+ # Emotion recognition from Persian speech using ECAPA-TDNN
6
+ # Dataset: ShEMO
7
+ # Language: Persian
8
+ # ########################################
9
+
10
+ # تنظیمات تصادفی (اختیاری)
11
+ seed: 1968
12
+ number_of_epochs: 30
13
+ # ⚠️ این خط حذف شد چون ممکنه در بعضی محیط‌ها مشکل بده:
14
+ # __set_seed: !apply:speechbrain.utils.seed_everything [!ref <seed>]
15
+
16
+ # مسیر فولدر داده‌ها (در لوکال مسیر پروژه)
17
+ data_folder: .
18
+
19
+ # مسیر خروجی مدل‌ها و لاگ‌ها
20
+ output_folder: results/ECAPA-TDNN/1968
21
+ save_folder: results/ECAPA-TDNN/1968/save
22
+ train_log: results/ECAPA-TDNN/1968/train_log.txt
23
+
24
+ # فایل‌های CSV دیتاست
25
+ csv_train: ./test/train.csv
26
+ csv_valid: ./test/valid.csv
27
+ csv_test: ./test/test.csv
28
+
29
+ # Logger برای ذخیره‌ی وضعیت آموزش
30
+ train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
31
+ save_file: results/ECAPA-TDNN/1968/train_log.txt
32
+
33
+ # ارزیابی خطا
34
+ error_stats: !name:speechbrain.utils.metric_stats.MetricStats
35
+ metric: !name:speechbrain.nnet.losses.classification_error
36
+ reduction: batch
37
+
38
+ ckpt_interval_minutes: 15
39
+
40
+ # پارامترهای آموزش
41
+
42
+ batch_size: 4
43
+ grad_accumulation_factor: 2
44
+ lr: 0.0001
45
+ weight_decay: 0.00002
46
+ base_lr: 0.000001
47
+ max_lr: 0.0001
48
+ step_size: 1088
49
+ mode: exp_range
50
+ gamma: 0.9998
51
+ shuffle: true
52
+ drop_last: false
53
+
54
+ # ویژگی‌های صوتی
55
+ n_mels: 80
56
+ left_frames: 0
57
+ right_frames: 0
58
+ deltas: false
59
+
60
+ # کلاس‌های احساسات در ShEMO
61
+ out_n_neurons: 6
62
+
63
+ # نگاشت لیبل‌ها
64
+ label_dict:
65
+ anger: 0
66
+ surprise: 1
67
+ happiness: 2
68
+ sadness: 3
69
+ neutral: 4
70
+ fear: 5
71
+
72
+ label_encoder: !new:speechbrain.dataio.encoder.CategoricalEncoder
73
+
74
+ # تنظیمات DataLoader
75
+ dataloader_options:
76
+ batch_size: 4
77
+ shuffle: true
78
+ num_workers: 2
79
+ drop_last: false
80
+
81
+ # استخراج ویژگی‌ها (Mel Spectrogram)
82
+ compute_features: &id001 !new:speechbrain.lobes.features.Fbank
83
+ n_mels: 80
84
+ left_frames: 0
85
+ right_frames: 0
86
+ deltas: false
87
+
88
+ # مدل ECAPA-TDNN
89
+ embedding_model: &id002 !new:speechbrain.lobes.models.ECAPA_TDNN.ECAPA_TDNN
90
+ input_size: 80
91
+ channels: [512, 512, 512, 512, 1536]
92
+ kernel_sizes: [5, 3, 3, 3, 1]
93
+ dilations: [1, 2, 3, 4, 1]
94
+ attention_channels: 64
95
+ lin_neurons: 96
96
+
97
+ # کلاس‌فایر خروجی
98
+ classifier: &id003 !new:speechbrain.lobes.models.ECAPA_TDNN.Classifier
99
+ input_size: 96
100
+ out_neurons: 6
101
+
102
+ # شمارنده اپوک‌ها
103
+ epoch_counter: &id005 !new:speechbrain.utils.epoch_loop.EpochCounter
104
+ limit: 30
105
+
106
+ # نرمال‌سازی ویژگی‌ها
107
+ mean_var_norm: &id004 !new:speechbrain.processing.features.InputNormalization
108
+
109
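+ # Normalization options (norm_type / std_norm); the loss function itself is configured under compute_cost below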
110
+ norm_type: sentence
111
+ std_norm: false
112
+
113
+ # ماژول‌های مدل
114
+ modules:
115
+ compute_features: *id001
116
+ embedding_model: *id002
117
+ classifier: *id003
118
+ mean_var_norm: *id004
119
+ compute_cost: !new:speechbrain.nnet.losses.LogSoftmaxWrapper
120
+ loss_fn: !new:speechbrain.nnet.losses.AdditiveAngularMargin
121
+ margin: 0.2
122
+ scale: 30
123
+
124
+ # اپتیمایزر
125
+ opt_class: !name:torch.optim.Adam
126
+ lr: 0.0001
127
+ weight_decay: 0.00002
128
+
129
+ # زمان‌بندی یادگیری
130
+ lr_annealing: !new:speechbrain.nnet.schedulers.CyclicLRScheduler
131
+ mode: exp_range
132
+ gamma: 0.9998
133
+ base_lr: 0.000001
134
+ max_lr: 0.0001
135
+ step_size: 1088
136
+
137
+ # مدیریت چک‌پوینت
138
+ checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
139
+ checkpoints_dir: results/ECAPA-TDNN/1968/save
140
+ recoverables:
141
+ embedding_model: *id002
142
+ classifier: *id003
143
+ normalizer: *id004
144
+ counter: *id005
inference.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Minimal inference for SpeechBrain ECAPA-TDNN (ShEMO fine-tuned).
5
+ """
6
+
7
+ import os
8
+ import torch
9
+ import speechbrain as sb
10
+ from hyperpyyaml import load_hyperpyyaml
11
+ from speechbrain.dataio.dataio import read_audio
12
+
13
+ # ------------------------------------------------------------------
14
+ # 1) paths
15
+ # ------------------------------------------------------------------
16
# Directory that holds `hyperparams.yaml` and the `save/` checkpoint folder.
# The repository ships both next to this script, so default to the script's
# own directory instead of a path that only exists on the author's machine.
# Set the SHEMO_EXP_DIR environment variable to point somewhere else
# (e.g. a local training output directory).
EXP_DIR = os.environ.get(
    "SHEMO_EXP_DIR",
    os.path.dirname(os.path.abspath(__file__)),
)
HP_FILE = os.path.join(EXP_DIR, "hyperparams.yaml")
CKPT_DIR = os.path.join(EXP_DIR, "save")
22
+
23
+ # ------------------------------------------------------------------
24
+ # 2) hparams & modules
25
+ # ------------------------------------------------------------------
26
# Load the hyperparameter file. HyperPyYAML instantiates every `!new:` object
# declared in it, so `hparams` holds ready-to-use module instances.
with open(HP_FILE) as f:
    hparams = load_hyperpyyaml(f)

# Modules needed by the forward pass of the inference Brain.
modules = {
    "compute_features": hparams["compute_features"],
    "mean_var_norm": hparams["mean_var_norm"],
    "embedding_model": hparams["embedding_model"],
    "classifier": hparams["classifier"],
}

# Recoverables must be keyed by the names used when the checkpoint was
# written. The training config saved the normalizer as "normalizer"
# (see `recoverables` in hyperparams.yaml and `normalizer.ckpt` on disk), so
# keying it as "mean_var_norm" would make `allow_partial_load` skip it
# silently. The feature extractor has no checkpoint file and is not listed.
recoverables = {
    "embedding_model": hparams["embedding_model"],
    "classifier": hparams["classifier"],
    "normalizer": hparams["mean_var_norm"],
}

checkpointer = sb.utils.checkpoints.Checkpointer(
    checkpoints_dir=CKPT_DIR,
    recoverables=recoverables,
    allow_partial_load=True,
)

# Pick the checkpoint with the lowest recorded validation `error` rather than
# simply the most recent one: the newest checkpoint in `save/` has a higher
# error than the earlier one.
loaded_ckpt = checkpointer.recover_if_possible(min_key="error")
if loaded_ckpt is None:
    import warnings

    # Keep the original non-fatal behaviour, but make the problem visible:
    # without a checkpoint the model runs with untrained weights.
    warnings.warn(
        "No checkpoint found in %r; predictions will come from "
        "randomly initialised weights." % CKPT_DIR
    )
42
+
43
+ # ------------------------------------------------------------------
44
+ # 3) Simple batch container (بدون PaddedBatch)
45
+ # ------------------------------------------------------------------
46
class SimpleBatch:
    """Lightweight batch holder used instead of SpeechBrain's PaddedBatch.

    Exposes a single attribute, ``sig``, holding the ``(wav, lens)`` pair
    that was passed to the constructor.
    """

    def __init__(self, wav, lens):
        self.sig = (wav, lens)

    def to(self, device):
        """Move both members of ``sig`` to ``device`` and return ``self``."""
        signal, lengths = self.sig
        moved_signal = signal.to(device)
        moved_lengths = lengths.to(device)
        self.sig = (moved_signal, moved_lengths)
        return self
54
+
55
+ # ------------------------------------------------------------------
56
+ # 4) Brain for inference
57
+ # ------------------------------------------------------------------
58
class EmoIdBrain(sb.Brain):
    """Brain subclass that only implements the forward pass for inference."""

    def compute_forward(self, batch, stage):
        """Run the module chain on ``batch.sig`` and return the classifier output.

        The chain is: ``compute_features`` -> ``mean_var_norm`` ->
        ``embedding_model`` -> ``classifier``. ``stage`` is accepted for
        interface compatibility and is not used here.
        """
        signal, rel_lens = batch.sig
        mods = self.modules
        features = mods.compute_features(signal)
        normalized = mods.mean_var_norm(features, rel_lens)
        embeddings = mods.embedding_model(normalized, rel_lens)
        return mods.classifier(embeddings)
66
+
67
# Inference device.
device = 'cpu'

# Pass `hparams` by keyword: the second positional parameter of `sb.Brain` is
# `opt_class`, so passing the hyperparameter dict positionally would bind it
# to the optimizer slot and leave `brain.hparams` unset.
brain = EmoIdBrain(
    modules=modules,
    hparams=hparams,
    run_opts={"device": device},
    checkpointer=checkpointer,
)
71
+ # ------------------------------------------------------------------
72
+ # 5) emotion labels (hard-coded)
73
+ # ------------------------------------------------------------------
74
# Lookup table: classifier output index -> emotion name (position == index).
# NOTE(review): this order differs from `label_dict` in hyperparams.yaml
# (anger=0, surprise=1, happiness=2, sadness=3, neutral=4, fear=5). Confirm
# which mapping the label encoder actually used during training before
# relying on the decoded labels.
IDX2LAB = ["anger", "sadness", "neutral", "surprise", "happiness", "fear"]
82
+
83
+ # # ------------------------------------------------------------------
84
+ # # 6) predict function
85
+ # # ------------------------------------------------------------------
86
def predict(wav_path: str) -> str:
    """Predict the emotion label for a single audio file.

    Args:
        wav_path: Path to the audio file to classify.

    Returns:
        The entry of ``IDX2LAB`` selected by the arg-max of the classifier
        output.

    Raises:
        ValueError: If the file contains no audio samples.
    """
    # `read_audio` already returns a tensor; `as_tensor` avoids the redundant
    # copy (and the UserWarning) that `torch.tensor(tensor)` produces.
    signal = torch.as_tensor(read_audio(wav_path)).float()

    # Multi-channel files are presumably returned as [time, channels] --
    # down-mix to mono so the batch keeps the [1, T] shape used below.
    if signal.dim() > 1:
        signal = signal.mean(dim=-1)

    if signal.numel() == 0:
        raise ValueError("Audio file %r contains no samples." % wav_path)

    wav = signal.unsqueeze(0)      # [1, T]
    lens = torch.tensor([1.0])     # relative length: the whole signal is valid
    batch = SimpleBatch(wav, lens).to(device)

    # Switch to eval mode (disables dropout, if any) and skip gradient tracking.
    brain.modules.eval()
    with torch.no_grad():
        logits = brain.compute_forward(batch, stage=sb.Stage.TEST)

    idx = int(logits.argmax(dim=-1))
    return IDX2LAB[idx]
97
+
98
+ # # ------------------------------------------------------------------
99
+ # # 7) run
100
+ # # ------------------------------------------------------------------
101
if __name__ == "__main__":
    import sys

    # Use the first command-line argument as the wav path when one is given;
    # otherwise fall back to the original default file name.
    WAV_FILE = sys.argv[1] if len(sys.argv) > 1 else "shortvoice.wav"
    print("Predicted emotion:", predict(WAV_FILE))
save/CKPT+2025-04-21+07-00-58+00/CKPT.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # yamllint disable
2
+ end-of-epoch: true
3
+ error: 0.20416668057441711
4
+ loss: 3.3974924573247933
5
+ unixtime: 1745218858.6184402
save/CKPT+2025-04-21+07-00-58+00/brain.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:95c0499225266093885120979ec7f686a648697370e9d71bd41f01704ca5bea7
3
+ size 49
save/CKPT+2025-04-21+07-00-58+00/classifier.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2d117e69c0ad85332c7eb0b4157f99a3b5c76d06a4acf237dcf69fb213002632
3
+ size 3627
save/CKPT+2025-04-21+07-00-58+00/counter.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b7a56873cd771f2c446d369b649430b65a756ba278ff97ec81bb6f55b2e73569
3
+ size 2
save/CKPT+2025-04-21+07-00-58+00/dataloader-TRAIN.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:84f01dd97c687fb28a296bcc2ef1801446ea7405860595924eb2b5bb634718d1
3
+ size 3
save/CKPT+2025-04-21+07-00-58+00/embedding_model.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ad046a8e1200e5755afdef0178305671823cbf306a2054476134c3a0da3a9814
3
+ size 22190908
save/CKPT+2025-04-21+07-00-58+00/normalizer.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:92244ada292c7d670d1dc88549e74ed24b3e25e70f27fe443420cf4832d6811b
3
+ size 1578
save/CKPT+2025-04-21+07-00-58+00/optimizer.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ef0f83b10bd9415fa702f62c75e81130c8f816d1b366e35655ffeffe91be97ec
3
+ size 44165498
save/CKPT+2025-04-21+07-07-30+00/CKPT.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # yamllint disable
2
+ end-of-epoch: true
3
+ error: 0.2291666716337204
4
+ loss: 3.7238532538904106
5
+ unixtime: 1745219250.8527672
save/CKPT+2025-04-21+07-07-30+00/brain.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5cd2e4ff5cb587a5c2c9804f660aa0a4f4497060feaa3a9370cc0df80b1654a4
3
+ size 49
save/CKPT+2025-04-21+07-07-30+00/classifier.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7298d2305a544f2f28efbee7a8d182f6c256eccf12f1afd95f38ea48ee126b66
3
+ size 3627
save/CKPT+2025-04-21+07-07-30+00/counter.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:624b60c58c9d8bfb6ff1886c2fd605d2adeb6ea4da576068201b6c6958ce93f4
3
+ size 2
save/CKPT+2025-04-21+07-07-30+00/dataloader-TRAIN.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:84f01dd97c687fb28a296bcc2ef1801446ea7405860595924eb2b5bb634718d1
3
+ size 3
save/CKPT+2025-04-21+07-07-30+00/embedding_model.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4e227e8ab5607485ac0a8d7bf70466f8c01ea0172af95f2d3f50987b31e6cf57
3
+ size 22190908
save/CKPT+2025-04-21+07-07-30+00/normalizer.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:92244ada292c7d670d1dc88549e74ed24b3e25e70f27fe443420cf4832d6811b
3
+ size 1578
save/CKPT+2025-04-21+07-07-30+00/optimizer.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a27a58e33bfe20d69045cf9748af9891fb3285e87b96cd1528d14cc47dabd7f0
3
+ size 44165498