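"""DPTNet-based speech separation.

Downloads a dynamically quantized DPTNet checkpoint from the Hugging Face Hub
and splits a mono 16 kHz WAV file into two separated sources.
"""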
import logging
import os
import pickle
import shutil
import tempfile

import torch
import torchaudio
from huggingface_hub import hf_hub_download

from . import asteroid_test

logger = logging.getLogger(__name__)

# set_audio_backend was deprecated and later removed in torchaudio 2.x;
# guard the call so the module also imports on newer versions.
if hasattr(torchaudio, "set_audio_backend"):
    torchaudio.set_audio_backend("sox_io")

def get_conf():
    conf_filterbank = {
        'n_filters': 64,
        'kernel_size': 16,
        'stride': 8
    }
    conf_masknet = {
        'in_chan': 64,
        'n_src': 2,
        'out_chan': 64,
        'ff_hid': 256,
        'ff_activation': "relu",
        'norm_type': "gLN",
        'chunk_size': 100,
        'hop_size': 50,
        'n_repeats': 2,
        'mask_act': 'sigmoid',
        'bidirectional': True,
        'dropout': 0
    }
    return conf_filterbank, conf_masknet

def load_dpt_model():
    logger.info('Loading separation model...')

    # Read the HF token from an environment variable
    speech_sep_token = os.getenv("SpeechSeparation")
    if not speech_sep_token:
        raise EnvironmentError("Environment variable SpeechSeparation is not set!")

    # Download the model weights from the Hugging Face Hub
    model_path = hf_hub_download(
        repo_id="DeepLearning101/speech-separation",  # replace with your own repo name
        filename="train_dptnet_aishell_partOverlap_B2_300epoch_quan-int8.p",
        token=speech_sep_token
    )

    conf_filterbank, conf_masknet = get_conf()
    model_class = getattr(asteroid_test, "DPTNet")
    model = model_class(**conf_filterbank, **conf_masknet)
    # The checkpoint was quantized to int8, so quantize the freshly built
    # model the same way before loading the state dict.
    model = torch.quantization.quantize_dynamic(
        model, {torch.nn.LSTM, torch.nn.Linear}, dtype=torch.qint8
    )

    try:
        state_dict = torch.load(model_path, map_location="cpu", weights_only=False)
    except pickle.UnpicklingError as e:
        raise RuntimeError(
            "Failed to load the model! Please check:\n"
            "1. that the model source is trusted,\n"
            "2. whether the checkpoint was saved with an older PyTorch version,\n"
            "3. pinning PyTorch to version 2.5.x."
        ) from e

    model.load_state_dict(state_dict)
    model.eval()
    return model

def dpt_sep_process(wav_path, model=None, outfilename=None):
    if outfilename is None:
        raise ValueError("outfilename must be provided")
    if model is None:
        model = load_dpt_model()
    try:
        # Pick a device
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device)

        # Load the audio and downmix to mono
        x, sr = torchaudio.load(wav_path, format="wav")
        x = x.mean(dim=0, keepdim=True).to(device)

        # Resample to 16 kHz if needed
        if sr != 16000:
            resampler = torchaudio.transforms.Resample(sr, 16000).to(device)
            x = resampler(x)
            sr = 16000

        # Process in chunks to avoid running out of memory
        chunk_size = sr * 60  # one minute per chunk
        separated = []
        for i in range(0, x.shape[1], chunk_size):
            chunk = x[:, i:i + chunk_size]
            with torch.no_grad():
                est = model(chunk)
            separated.append(est.cpu())
        # est has shape (batch, n_src, time); concatenate along the time axis
        est_sources = torch.cat(separated, dim=2)

        # Post-processing
        est_sources = est_sources.squeeze(0)
        sep_1, sep_2 = est_sources[0], est_sources[1]

        # Peak normalization (move the peak to CPU to match sep_1/sep_2)
        peak = 0.9 * torch.max(torch.abs(x)).cpu()
        sep_1 = peak * sep_1 / torch.max(torch.abs(sep_1))
        sep_2 = peak * sep_2 / torch.max(torch.abs(sep_2))

        # Write to a temporary directory first
        with tempfile.TemporaryDirectory() as tmp_dir:
            sep1_path = os.path.join(tmp_dir, "sep1.wav")
            sep2_path = os.path.join(tmp_dir, "sep2.wav")
            torchaudio.save(sep1_path, sep_1.unsqueeze(0), sr)
            torchaudio.save(sep2_path, sep_2.unsqueeze(0), sr)

            # Move the files to their final locations; unlike os.replace,
            # shutil.move also works when tmp_dir is on another filesystem.
            final_sep1 = outfilename.replace('.wav', '_sep1.wav')
            final_sep2 = outfilename.replace('.wav', '_sep2.wav')
            shutil.move(sep1_path, final_sep1)
            shutil.move(sep2_path, final_sep2)

        logger.info(f"Wrote output files to: {final_sep1}, {final_sep2}")
        return final_sep1, final_sep2
    except RuntimeError as e:
        if "CUDA out of memory" in str(e):
            raise RuntimeError("Out of GPU memory; try a shorter audio clip") from e
        else:
            raise

if __name__ == '__main__':
    print("This module should be used via Flask or Gradio.")
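    # A minimal usage sketch (hypothetical file names; assumes the
    # SpeechSeparation environment variable holds a valid HF token):
    #   model = load_dpt_model()
    #   sep1, sep2 = dpt_sep_process("mix.wav", model=model, outfilename="out.wav")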