DeepLearning101 commited on
Commit
64ceedd
·
verified ·
1 Parent(s): b6c45cb

Update DPTNet_eval/DPTNet_quant_sep.py

Browse files
Files changed (1) hide show
  1. DPTNet_eval/DPTNet_quant_sep.py +107 -107
DPTNet_eval/DPTNet_quant_sep.py CHANGED
@@ -1,108 +1,108 @@
1
- # DPTNet_quant_sep.py
2
-
3
- import os
4
- import torch
5
- import numpy as np
6
- import torchaudio
7
- from huggingface_hub import hf_hub_download
8
- from . import asteroid_test
9
-
10
- torchaudio.set_audio_backend("sox_io")
11
-
12
- def get_conf():
13
- conf_filterbank = {
14
- 'n_filters': 64,
15
- 'kernel_size': 16,
16
- 'stride': 8
17
- }
18
-
19
- conf_masknet = {
20
- 'in_chan': 64,
21
- 'n_src': 2,
22
- 'out_chan': 64,
23
- 'ff_hid': 256,
24
- 'ff_activation': "relu",
25
- 'norm_type': "gLN",
26
- 'chunk_size': 100,
27
- 'hop_size': 50,
28
- 'n_repeats': 2,
29
- 'mask_act': 'sigmoid',
30
- 'bidirectional': True,
31
- 'dropout': 0
32
- }
33
- return conf_filterbank, conf_masknet
34
-
35
-
36
- def load_dpt_model():
37
- print('Load Separation Model...')
38
-
39
- # 從環境變數取得 Hugging Face Token
40
- HF_TOKEN = os.getenv("HF_TOKEN")
41
- if not HF_TOKEN:
42
- raise EnvironmentError("環境變數 HF_TOKEN 未設定!請先執行 export HF_TOKEN=xxx")
43
-
44
- # 從 Hugging Face Hub 下載模型權重
45
- model_path = hf_hub_download(
46
- repo_id="DeepLearning101/speech-separation", # ← 替換成你的 repo 名稱
47
- filename="train_dptnet_aishell_partOverlap_B2_300epoch_quan-int8.p",
48
- token=HF_TOKEN
49
- )
50
-
51
- # 取得模型參數
52
- conf_filterbank, conf_masknet = get_conf()
53
-
54
- # 建立模型架構
55
- model_class = getattr(asteroid_test, "DPTNet")
56
- model = model_class(**conf_filterbank, **conf_masknet)
57
-
58
- # 套用量化設定
59
- model = torch.quantization.quantize_dynamic(
60
- model,
61
- {torch.nn.LSTM, torch.nn.Linear},
62
- dtype=torch.qint8
63
- )
64
-
65
- # 載入權重(忽略不匹配的 keys)
66
- state_dict = torch.load(model_path, map_location="cpu")
67
- model_state_dict = model.state_dict()
68
- filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_state_dict}
69
- model.load_state_dict(filtered_state_dict, strict=False)
70
- model.eval()
71
-
72
- return model
73
-
74
-
75
- def dpt_sep_process(wav_path, model=None, outfilename=None):
76
- if model is None:
77
- model = load_dpt_model()
78
-
79
- x, sr = torchaudio.load(wav_path)
80
- x = x.cpu()
81
-
82
- with torch.no_grad():
83
- est_sources = model(x) # shape: (1, 2, T)
84
-
85
- est_sources = est_sources.squeeze(0) # shape: (2, T)
86
- sep_1, sep_2 = est_sources # 拆成兩個 (T,) 的 tensor
87
-
88
- # 正規化
89
- max_abs = x[0].abs().max().item()
90
- sep_1 = sep_1 * max_abs / sep_1.abs().max().item()
91
- sep_2 = sep_2 * max_abs / sep_2.abs().max().item()
92
-
93
- # 增加 channel 維度,變為 (1, T)
94
- sep_1 = sep_1.unsqueeze(0)
95
- sep_2 = sep_2.unsqueeze(0)
96
-
97
- # 儲存結果
98
- if outfilename is not None:
99
- torchaudio.save(outfilename.replace('.wav', '_sep1.wav'), sep_1, sr)
100
- torchaudio.save(outfilename.replace('.wav', '_sep2.wav'), sep_2, sr)
101
- torchaudio.save(outfilename.replace('.wav', '_mix.wav'), x, sr)
102
- else:
103
- torchaudio.save(wav_path.replace('.wav', '_sep1.wav'), sep_1, sr)
104
- torchaudio.save(wav_path.replace('.wav', '_sep2.wav'), sep_2, sr)
105
-
106
-
107
- if __name__ == '__main__':
108
  print("This module should be used via Flask or Gradio.")
 
1
+ # DPTNet_quant_sep.py
2
+
3
+ import os
4
+ import torch
5
+ import numpy as np
6
+ import torchaudio
7
+ from huggingface_hub import hf_hub_download
8
+ from . import asteroid_test
9
+
10
+ torchaudio.set_audio_backend("sox_io")
11
+
12
def get_conf():
    """Return the DPTNet hyper-parameter configuration.

    Returns:
        tuple[dict, dict]: ``(conf_filterbank, conf_masknet)`` — keyword
        arguments for the encoder/decoder filterbank and for the dual-path
        transformer masking network, respectively.
    """
    # Encoder/decoder filterbank settings.
    conf_filterbank = dict(n_filters=64, kernel_size=16, stride=8)

    # Dual-path transformer masker settings (2-speaker separation).
    conf_masknet = dict(
        in_chan=64,
        n_src=2,
        out_chan=64,
        ff_hid=256,
        ff_activation="relu",
        norm_type="gLN",
        chunk_size=100,
        hop_size=50,
        n_repeats=2,
        mask_act="sigmoid",
        bidirectional=True,
        dropout=0,
    )
    return conf_filterbank, conf_masknet
34
+
35
+
36
def load_dpt_model():
    """Build the int8-quantized DPTNet separation model and load its weights.

    Downloads the quantized checkpoint from the Hugging Face Hub (a token is
    required), instantiates the DPTNet architecture from :func:`get_conf`,
    applies dynamic int8 quantization to LSTM/Linear layers so the module
    structure matches the checkpoint, then loads the matching weights.

    Returns:
        The quantized DPTNet model, in eval mode.

    Raises:
        EnvironmentError: if no Hugging Face token is found in the
            environment (checked: ``SpeechSeparation``, then ``HF_TOKEN``).
    """
    print('Load Separation Model...')

    # Token for the Hub repo. Prefer the project-specific variable; fall back
    # to the conventional HF_TOKEN for backward compatibility.
    hf_token = os.getenv("SpeechSeparation") or os.getenv("HF_TOKEN")
    if not hf_token:
        # BUG FIX: the original message told users to `export HF_TOKEN=xxx`
        # while the code only read "SpeechSeparation" — the message now
        # names the variables actually consulted.
        raise EnvironmentError(
            "環境變數 SpeechSeparation(或 HF_TOKEN)未設定!"
            "請先執行 export SpeechSeparation=xxx"
        )

    # Download the quantized checkpoint from the Hugging Face Hub.
    model_path = hf_hub_download(
        repo_id="DeepLearning101/speech-separation",
        filename="train_dptnet_aishell_partOverlap_B6_300epoch_quan-int8.p",
        token=hf_token,
    )

    # Model hyper-parameters.
    conf_filterbank, conf_masknet = get_conf()

    # Build the fp32 architecture, then apply dynamic int8 quantization so
    # that parameter names/shapes line up with the quantized state dict.
    model_class = getattr(asteroid_test, "DPTNet")
    model = model_class(**conf_filterbank, **conf_masknet)
    model = torch.quantization.quantize_dynamic(
        model,
        {torch.nn.LSTM, torch.nn.Linear},
        dtype=torch.qint8,
    )

    # Load weights, keeping only keys present in the current architecture;
    # strict=False tolerates any remaining mismatches.
    state_dict = torch.load(model_path, map_location="cpu")
    model_state_dict = model.state_dict()
    filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_state_dict}
    model.load_state_dict(filtered_state_dict, strict=False)
    model.eval()

    return model
73
+
74
+
75
def dpt_sep_process(wav_path, model=None, outfilename=None):
    """Separate a two-speaker mixture wav into two estimated source wavs.

    Args:
        wav_path: path to the input ``.wav`` mixture.
        model: optional pre-loaded DPTNet model; loaded on demand when None.
        outfilename: optional output path template — ``.wav`` is replaced by
            ``_sep1.wav`` / ``_sep2.wav`` (plus ``_mix.wav`` for the input
            copy). When None, outputs are written next to ``wav_path``.
    """
    if model is None:
        model = load_dpt_model()

    x, sr = torchaudio.load(wav_path)
    x = x.cpu()

    # Inference only — no gradients needed.
    with torch.no_grad():
        est_sources = model(x)  # expected shape: (1, n_src, T)

    est_sources = est_sources.squeeze(0)  # (n_src, T)
    sep_1, sep_2 = est_sources            # two (T,) tensors

    # Rescale each estimate to the mixture's peak amplitude.
    # BUG FIX: guard against an all-zero (silent) estimate, which previously
    # raised ZeroDivisionError; a silent source is left unscaled.
    max_abs = x[0].abs().max().item()
    peak_1 = sep_1.abs().max().item()
    peak_2 = sep_2.abs().max().item()
    if peak_1 > 0:
        sep_1 = sep_1 * max_abs / peak_1
    if peak_2 > 0:
        sep_2 = sep_2 * max_abs / peak_2

    # torchaudio.save expects (channels, T) — add the channel dimension.
    sep_1 = sep_1.unsqueeze(0)
    sep_2 = sep_2.unsqueeze(0)

    # Write results.
    if outfilename is not None:
        torchaudio.save(outfilename.replace('.wav', '_sep1.wav'), sep_1, sr)
        torchaudio.save(outfilename.replace('.wav', '_sep2.wav'), sep_2, sr)
        torchaudio.save(outfilename.replace('.wav', '_mix.wav'), x, sr)
    else:
        torchaudio.save(wav_path.replace('.wav', '_sep1.wav'), sep_1, sr)
        torchaudio.save(wav_path.replace('.wav', '_sep2.wav'), sep_2, sr)
105
+
106
+
107
# Not a standalone CLI: this module is imported (e.g. by a Flask/Gradio app)
# which calls dpt_sep_process(); running it directly only prints a notice.
if __name__ == '__main__':
    print("This module should be used via Flask or Gradio.")