DeepLearning101 committed on
Commit cf73d23 · verified · 1 Parent(s): 67c1496

Update DPTNet_eval/DPTNet_quant_sep.py

Files changed (1)
  1. DPTNet_eval/DPTNet_quant_sep.py +33 -66
DPTNet_eval/DPTNet_quant_sep.py CHANGED
@@ -1,25 +1,12 @@
-# DPTNet_quant_sep.py
-
-import warnings
-warnings.filterwarnings("ignore", message="Failed to initialize NumPy: _ARRAY_API not found")
-
 import os
 import torch
 import numpy as np
 import torchaudio
-from huggingface_hub import hf_hub_download
-
-# Dynamically import DPTNet from asteroid_test
-try:
-    from . import asteroid_test
-except ImportError as e:
-    raise ImportError("Cannot load the asteroid_test module; make sure it matches the one used during training") from e
-
-torchaudio.set_audio_backend("sox_io")
+import yaml
+from . import asteroid_test
 
 
 def get_conf():
-    """Return the model parameter settings"""
     conf_filterbank = {
         'n_filters': 64,
         'kernel_size': 16,
@@ -45,61 +32,19 @@ def get_conf():
 
 def load_dpt_model():
     print('Load Separation Model...')
-
-    speech_sep_token = os.getenv("SpeechSeparation")
-    if not speech_sep_token:
-        raise EnvironmentError("Environment variable SpeechSeparation is not set!")
-
-    model_path = hf_hub_download(
-        repo_id="DeepLearning101/speech-separation",
-        filename="train_dptnet_aishell_partOverlap_B2_300epoch_quan-int8.p",
-        token=speech_sep_token
-    )
-
+    now_path = os.path.split(os.path.realpath(__file__))[0]
     conf_filterbank, conf_masknet = get_conf()
-
-    try:
-        model_class = getattr(asteroid_test, "DPTNet")
-        model = model_class(**conf_filterbank, **conf_masknet)
-    except Exception as e:
-        raise RuntimeError("Model structure error: make sure asteroid_test.py matches the one used during training") from e
-
-    model = torch.quantization.quantize_dynamic(
-        model,
-        {torch.nn.LSTM, torch.nn.Linear},
-        dtype=torch.qint8
-    )
-
+    model_path = os.path.join(now_path, "trained_model/train_dptnet_aishell_partOverlap_B2_300epoch_quan-int8.p")
+    model = getattr(asteroid_test, "DPTNet")(**conf_filterbank, **conf_masknet)
+    model = torch.quantization.quantize_dynamic(model, {torch.nn.LSTM, torch.nn.Linear}, dtype=torch.qint8)
     state_dict = torch.load(model_path, map_location="cpu")
-    own_state = model.state_dict()
-
-    # Keep only key-value pairs whose values are torch.Tensor
-    filtered_state_dict = {}
-    for k, v in state_dict.items():
-        if k in own_state:
-            if isinstance(v, torch.Tensor) and isinstance(own_state[k], torch.Tensor):
-                if v.shape == own_state[k].shape:
-                    filtered_state_dict[k] = v
-                else:
-                    print(f"Skip '{k}': shape mismatch")
-            else:
-                print(f"Skip '{k}': not a tensor")
-
-    missing_keys, unexpected_keys = model.load_state_dict(filtered_state_dict, strict=False)
-
-    if missing_keys:
-        print("⚠️ Missing keys:", missing_keys)
-    if unexpected_keys:
-        print("ℹ️ Unexpected keys:", unexpected_keys)
-
+    model.load_state_dict(state_dict)
     model.eval()
     return model
 
-
 def dpt_sep_process(wav_path, model=None, outfilename=None):
-    """Run speech separation"""
     if model is None:
         model = load_dpt_model()
 
     x, sr = torchaudio.load(wav_path)
     x = x.cpu()
@@ -107,8 +52,10 @@ def dpt_sep_process(wav_path, model=None, outfilename=None):
     with torch.no_grad():
         est_sources = model(x)  # shape: (1, 2, T)
 
+    # Make sure est_sources is (1, 2, T) before splitting
     est_sources = est_sources.squeeze(0)  # shape: (2, T)
-    sep_1, sep_2 = est_sources  # split into two (T,) tensors
+
+    sep_1, sep_2 = est_sources  # split into two (T,) tensors
 
     # Normalize
     max_abs = x[0].abs().max().item()
@@ -119,7 +66,6 @@ def dpt_sep_process(wav_path, model=None, outfilename=None):
     sep_1 = sep_1.unsqueeze(0)
     sep_2 = sep_2.unsqueeze(0)
 
-    # Save results
     if outfilename is not None:
         torchaudio.save(outfilename.replace('.wav', '_sep1.wav'), sep_1, sr)
         torchaudio.save(outfilename.replace('.wav', '_sep2.wav'), sep_2, sr)
@@ -127,7 +73,28 @@ def dpt_sep_process(wav_path, model=None, outfilename=None):
     else:
         torchaudio.save(wav_path.replace('.wav', '_sep1.wav'), sep_1, sr)
         torchaudio.save(wav_path.replace('.wav', '_sep2.wav'), sep_2, sr)
-
+
+# def dpt_sep_process(wav_path, model=None, outfilename=None):
+#     if model == None:
+#         model = load_model()
+#     x, sr = torchaudio.load(wav_path)
+#     x = x.cpu()
+#     with torch.no_grad():
+#         est_sources = model(x)
+#
+#     est_sources_np = est_sources.squeeze(0)
+#
+#     sep_1, sep_2 = est_sources_np
+#     sep_1 = sep_1 * x[0].abs().max().item() / sep_1.abs().max().item()
+#     sep_2 = sep_2 * x[0].abs().max().item() / sep_2.abs().max().item()
+#
+#     if outfilename != None:
+#         torchaudio.save(outfilename.replace('.wav', '_sep1.wav'), sep_1, sr)
+#         torchaudio.save(outfilename.replace('.wav', '_sep2.wav'), sep_2, sr)
+#         torchaudio.save(outfilename.replace('.wav', '_mix.wav'), x, sr)
+#     else:
+#         torchaudio.save(wav_path.replace('.wav', '_sep1.wav'), sep_1, sr)
+#         torchaudio.save(wav_path.replace('.wav', '_sep2.wav'), sep_2, sr)
 
 if __name__ == '__main__':
     print("This module should be used via Flask or Gradio.")
 
 
 
 
 
 