DeepLearning101 commited on
Commit
b75ae28
·
verified ·
1 Parent(s): e56b358

Update DPTNet_eval/DPTNet_quant_sep.py

Browse files
Files changed (1) hide show
  1. DPTNet_eval/DPTNet_quant_sep.py +47 -30
DPTNet_eval/DPTNet_quant_sep.py CHANGED
@@ -61,37 +61,54 @@ def load_dpt_model():
61
  return model
62
 
63
 
64
- def dpt_sep_process(wav_path, model=None, outfilename=None):
65
- if model is None:
66
- model = load_dpt_model()
67
-
68
- x, sr = torchaudio.load(wav_path)
69
- x = x.cpu()
70
-
71
- with torch.no_grad():
72
- est_sources = model(x) # shape: (1, 2, T)
73
-
74
- # 確保 est_sources 是 (1, 2, T),再拆分
75
- est_sources = est_sources.squeeze(0) # shape: (2, T)
76
-
77
- sep_1, sep_2 = est_sources # 拆成兩個 (T, ) 的 tensor
78
-
79
- # 正規化
80
- max_abs = x[0].abs().max().item()
81
- sep_1 = sep_1 * max_abs / sep_1.abs().max().item()
82
- sep_2 = sep_2 * max_abs / sep_2.abs().max().item()
83
-
84
- # 增加 channel 維度,變為 (1, T)
85
- sep_1 = sep_1.unsqueeze(0)
86
- sep_2 = sep_2.unsqueeze(0)
87
 
88
- if outfilename is not None:
89
- torchaudio.save(outfilename.replace('.wav', '_sep1.wav'), sep_1, sr)
90
- torchaudio.save(outfilename.replace('.wav', '_sep2.wav'), sep_2, sr)
91
- torchaudio.save(outfilename.replace('.wav', '_mix.wav'), x, sr)
92
- else:
93
- torchaudio.save(wav_path.replace('.wav', '_sep1.wav'), sep_1, sr)
94
- torchaudio.save(wav_path.replace('.wav', '_sep2.wav'), sep_2, sr)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
 
97
  if __name__ == '__main__':
 
61
  return model
62
 
63
 
64
+ import torchaudio
65
+ import tempfile
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
+ def dpt_sep_process(wav_path, model=None, outfilename=None):
68
+ try:
69
+ if model is None:
70
+ model = load_dpt_model()
71
+
72
+ # 使用 torchaudio 的通用加載方法
73
+ x, sr = torchaudio.load(wav_path, format=wav_path.split('.')[-1])
74
+ x = x.mean(dim=0, keepdim=True) # 強制轉單聲道
75
+
76
+ # 自動重採樣處理
77
+ if sr != 16000:
78
+ resampler = torchaudio.transforms.Resample(sr, 16000)
79
+ x = resampler(x)
80
+ sr = 16000
81
+
82
+ with torch.no_grad():
83
+ est_sources = model(x)
84
+
85
+ # 後處理修正
86
+ est_sources = est_sources.squeeze(0)
87
+ sep_1, sep_2 = est_sources[0], est_sources[1]
88
+
89
+ # 正規化增強
90
+ peak = 0.9 * torch.max(torch.abs(x))
91
+ sep_1 = peak * sep_1 / torch.max(torch.abs(sep_1))
92
+ sep_2 = peak * sep_2 / torch.max(torch.abs(sep_2))
93
+
94
+ # 使用臨時輸出目錄
95
+ with tempfile.TemporaryDirectory() as tmp_dir:
96
+ sep1_path = os.path.join(tmp_dir, "sep1.wav")
97
+ sep2_path = os.path.join(tmp_dir, "sep2.wav")
98
+
99
+ torchaudio.save(sep1_path, sep_1.unsqueeze(0), sr)
100
+ torchaudio.save(sep2_path, sep_2.unsqueeze(0), sr)
101
+
102
+ # 移動檔案到最終位置
103
+ final_sep1 = outfilename.replace('.wav', '_sep1.wav')
104
+ final_sep2 = outfilename.replace('.wav', '_sep2.wav')
105
+ os.replace(sep1_path, final_sep1)
106
+ os.replace(sep2_path, final_sep2)
107
+
108
+ return final_sep1, final_sep2
109
+
110
+ except Exception as e:
111
+ raise RuntimeError(f"分離過程錯誤: {str(e)}") from e
112
 
113
 
114
  if __name__ == '__main__':