import os
import logging
import pickle
import shutil
import tempfile

import torch
import torchaudio
from huggingface_hub import hf_hub_download

from . import asteroid_test

logger = logging.getLogger(__name__)

# Note: set_audio_backend is deprecated in recent torchaudio releases;
# kept here for compatibility with older versions.
torchaudio.set_audio_backend("sox_io")


def get_conf():
    """Return the filterbank and masknet configurations for DPTNet."""
    conf_filterbank = {
        'n_filters': 64,
        'kernel_size': 16,
        'stride': 8
    }

    conf_masknet = {
        'in_chan': 64,
        'n_src': 2,
        'out_chan': 64,
        'ff_hid': 256,
        'ff_activation': "relu",
        'norm_type': "gLN",
        'chunk_size': 100,
        'hop_size': 50,
        'n_repeats': 2,
        'mask_act': 'sigmoid',
        'bidirectional': True,
        'dropout': 0
    }

    return conf_filterbank, conf_masknet


def load_dpt_model():
    logger.info('Loading separation model...')

    # 👇 Read the HF token from an environment variable
    speech_sep_token = os.getenv("SpeechSeparation")
    if not speech_sep_token:
        raise EnvironmentError("Environment variable SpeechSeparation is not set!")

    # 👇 Download the model weights from the Hugging Face Hub
    model_path = hf_hub_download(
        repo_id="DeepLearning101/speech-separation",  # replace with your own repo name
        filename="train_dptnet_aishell_partOverlap_B2_300epoch_quan-int8.p",
        token=speech_sep_token
    )

    conf_filterbank, conf_masknet = get_conf()

    model_class = getattr(asteroid_test, "DPTNet")
    model = model_class(**conf_filterbank, **conf_masknet)
    model = torch.quantization.quantize_dynamic(
        model, {torch.nn.LSTM, torch.nn.Linear}, dtype=torch.qint8
    )

    try:
        state_dict = torch.load(model_path, map_location="cpu", weights_only=False)
    except pickle.UnpicklingError as e:
        raise RuntimeError(
            "Failed to load the model! Please check:\n"
            "1. Whether the model source is trusted\n"
            "2. Whether the model was saved by an older PyTorch version\n"
            "3. Try pinning PyTorch to version 2.5.x"
        ) from e

    model.load_state_dict(state_dict)
    model.eval()
    return model


def dpt_sep_process(wav_path, model=None, outfilename=None):
    try:
        if model is None:
            model = load_dpt_model()
        if outfilename is None:
            raise ValueError("outfilename must be provided")

        # Device detection
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device)

        # Load the audio and downmix to mono
        x, sr = torchaudio.load(wav_path, format="wav")
        x = x.mean(dim=0, keepdim=True).to(device)

        # Resample automatically if not 16 kHz
        if sr != 16000:
            resampler = torchaudio.transforms.Resample(sr, 16000).to(device)
            x = resampler(x)
            sr = 16000

        # Process in chunks to avoid OOM
        chunk_size = sr * 60  # process one minute at a time
        separated = []
        for i in range(0, x.shape[1], chunk_size):
            chunk = x[:, i:i + chunk_size]
            with torch.no_grad():
                est = model(chunk)  # (batch, n_src, time)
            separated.append(est.cpu())
        est_sources = torch.cat(separated, dim=2)  # concatenate along time

        # Post-processing: drop the batch dimension and split the two sources
        est_sources = est_sources.squeeze(0)
        sep_1, sep_2 = est_sources[0], est_sources[1]

        # Peak normalization (peak moved to CPU to match the separated tensors)
        peak = 0.9 * torch.max(torch.abs(x)).cpu()
        sep_1 = peak * sep_1 / torch.max(torch.abs(sep_1))
        sep_2 = peak * sep_2 / torch.max(torch.abs(sep_2))

        # Write into a temporary output directory first
        with tempfile.TemporaryDirectory() as tmp_dir:
            sep1_path = os.path.join(tmp_dir, "sep1.wav")
            sep2_path = os.path.join(tmp_dir, "sep2.wav")

            torchaudio.save(sep1_path, sep_1.unsqueeze(0), sr)
            torchaudio.save(sep2_path, sep_2.unsqueeze(0), sr)

            # Move the files to their final locations; shutil.move also works
            # across filesystems, unlike os.replace
            final_sep1 = outfilename.replace('.wav', '_sep1.wav')
            final_sep2 = outfilename.replace('.wav', '_sep2.wav')
            shutil.move(sep1_path, final_sep1)
            shutil.move(sep2_path, final_sep2)

        logger.info(f"💾 Wrote output files to: {final_sep1}, {final_sep2}")
        return final_sep1, final_sep2

    except RuntimeError as e:
        if "CUDA out of memory" in str(e):
            raise RuntimeError("Out of memory; please shorten the audio") from e
        else:
            raise


if __name__ == '__main__':
    print("This module should be used via Flask or Gradio.")
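
# A minimal usage sketch. It assumes the SpeechSeparation environment variable
# holds a valid Hugging Face token and that "input.wav" / "output.wav" are
# illustrative paths, not names defined by this module:
#
#     model = load_dpt_model()
#     sep1, sep2 = dpt_sep_process("input.wav", model=model, outfilename="output.wav")
#     # -> writes output_sep1.wav and output_sep2.wav and returns their paths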