DeepLearning101 committed
Commit dad0fd1 · verified · 1 Parent(s): b75ae28

Update DPTNet_eval/DPTNet_quant_sep.py

Files changed (1)
  1. DPTNet_eval/DPTNet_quant_sep.py +26 -15
DPTNet_eval/DPTNet_quant_sep.py CHANGED
@@ -66,21 +66,30 @@ import tempfile
 
 def dpt_sep_process(wav_path, model=None, outfilename=None):
     try:
-        if model is None:
-            model = load_dpt_model()
-
-        # Generic audio loading via torchaudio
-        x, sr = torchaudio.load(wav_path, format=wav_path.split('.')[-1])
-        x = x.mean(dim=0, keepdim=True)  # force mono
-
-        # Automatic resampling handling
+        # Device detection
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        model = model.to(device)
+
+        # More robust audio loading
+        x, sr = torchaudio.load(wav_path, format="wav")
+        x = x.mean(dim=0, keepdim=True).to(device)
+
+        # Automatic resampling
         if sr != 16000:
-            resampler = torchaudio.transforms.Resample(sr, 16000)
+            resampler = torchaudio.transforms.Resample(sr, 16000).to(device)
             x = resampler(x)
             sr = 16000
-
-        with torch.no_grad():
-            est_sources = model(x)
+
+        # Chunked processing to avoid OOM
+        chunk_size = sr * 60  # process one minute at a time
+        separated = []
+        for i in range(0, x.shape[1], chunk_size):
+            chunk = x[:, i:i+chunk_size]
+            with torch.no_grad():
+                est = model(chunk)
+            separated.append(est.cpu())
+
+        est_sources = torch.cat(separated, dim=2)
 
         # Post-processing fix
         est_sources = est_sources.squeeze(0)
@@ -107,9 +116,11 @@ def dpt_sep_process(wav_path, model=None, outfilename=None):
 
         return final_sep1, final_sep2
 
-    except Exception as e:
-        raise RuntimeError(f"Separation error: {str(e)}") from e
-
+    except RuntimeError as e:
+        if "CUDA out of memory" in str(e):
+            raise RuntimeError("Out of memory, please use a shorter audio clip") from e
+        else:
+            raise
 
 if __name__ == '__main__':
     print("This module should be used via Flask or Gradio.")