File size: 4,379 Bytes
64ceedd
 
 
 
cf73d23
 
38d7181
d2c5b31
 
38d7181
 
64ceedd
d8be50a
64ceedd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38d7181
 
 
 
 
 
 
 
 
 
 
 
 
 
64ceedd
38d7181
 
cf73d23
38d7181
76875bc
 
 
 
 
 
 
 
 
 
cf73d23
d8be50a
64ceedd
 
38d7181
b75ae28
 
64ceedd
b75ae28
 
dad0fd1
 
 
 
 
 
 
 
 
b75ae28
dad0fd1
b75ae28
 
dad0fd1
 
 
 
 
 
 
 
 
 
 
b75ae28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8467c17
 
 
b75ae28
 
dad0fd1
 
 
 
 
64ceedd
 
b6c45cb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import os
import torch
import numpy as np
import torchaudio
import yaml
from . import asteroid_test
from huggingface_hub import hf_hub_download
import logging
logger = logging.getLogger(__name__)

# Prefer the SoX-based I/O backend. `set_audio_backend` is deprecated in
# torchaudio 2.x (backend selection became automatic) and may be absent from
# newer releases, so only call it when it exists rather than failing at import.
if hasattr(torchaudio, "set_audio_backend"):
    torchaudio.set_audio_backend("sox_io")


def get_conf():
    """Return the DPTNet constructor kwargs as ``(filterbank, masknet)`` dicts.

    Both dicts are built fresh on every call, so callers may mutate them
    without affecting later calls.
    """
    filterbank_kwargs = dict(n_filters=64, kernel_size=16, stride=8)
    masknet_kwargs = dict(
        in_chan=64,
        n_src=2,
        out_chan=64,
        ff_hid=256,
        ff_activation="relu",
        norm_type="gLN",
        chunk_size=100,
        hop_size=50,
        n_repeats=2,
        mask_act="sigmoid",
        bidirectional=True,
        dropout=0,
    )
    return filterbank_kwargs, masknet_kwargs


def load_dpt_model():
    """Download the int8 DPTNet checkpoint from the Hugging Face Hub and build the model.

    The Hub token is read from the ``SpeechSeparation`` environment variable.
    The model is constructed from ``get_conf()`` and dynamically quantized
    (``nn.LSTM`` and ``nn.Linear`` -> ``qint8``) *before* the state dict is
    loaded. Presumably the checkpoint was saved from a model quantized the
    same way (its filename says int8), so the module structure has to match.

    Returns:
        The quantized DPTNet model, in eval mode, with weights mapped to CPU.

    Raises:
        EnvironmentError: If ``SpeechSeparation`` is not set or empty.
        RuntimeError: If the checkpoint cannot be unpickled.
    """
    # Required by the except clause below; `pickle` is not imported at module level.
    import pickle

    print('Load Separation Model...')

    # Hub token for the model repository, taken from the environment.
    speech_sep_token = os.getenv("SpeechSeparation")
    if not speech_sep_token:
        raise EnvironmentError("環境變數 SpeechSeparation 未設定!")

    # Download (or reuse the locally cached copy of) the model weights.
    model_path = hf_hub_download(
        repo_id="DeepLearning101/speech-separation",  # replace with your own repo name
        filename="train_dptnet_aishell_partOverlap_B2_300epoch_quan-int8.p",
        token=speech_sep_token,
    )

    conf_filterbank, conf_masknet = get_conf()
    model = asteroid_test.DPTNet(**conf_filterbank, **conf_masknet)
    # Quantize first so the module layout matches the quantized checkpoint.
    model = torch.quantization.quantize_dynamic(
        model, {torch.nn.LSTM, torch.nn.Linear}, dtype=torch.qint8
    )

    # SECURITY: weights_only=False lets torch.load unpickle arbitrary objects,
    # which can execute code. That is only acceptable for a trusted checkpoint
    # such as the one from the repository above; never point it at user files.
    try:
        state_dict = torch.load(model_path, map_location="cpu", weights_only=False)
    except pickle.UnpicklingError as e:
        raise RuntimeError(
            "模型載入失敗!請確認:\n"
            "1. 模型來源是否可信\n"
            "2. 是否為舊版 PyTorch 儲存的模型\n"
            "3. 嘗試鎖定 PyTorch 版本為 2.5.x"
        ) from e

    model.load_state_dict(state_dict)
    model.eval()
    return model


import torchaudio
import tempfile

def dpt_sep_process(wav_path, model=None, outfilename=None):
    """Separate a two-speaker WAV recording into two WAV files.

    The input is down-mixed to mono, resampled to 16 kHz if needed and fed to
    ``model`` in chunks of at most 60 seconds to bound memory use. Each
    separated source is rescaled so its peak equals 0.9x the input's peak.

    Args:
        wav_path: Input WAV path (or any source ``torchaudio.load`` accepts).
        model: Separation model, e.g. from ``load_dpt_model()``. When None,
            ``load_dpt_model()`` is called to obtain one.
        outfilename: Base output path. ``out.wav`` yields ``out_sep1.wav`` and
            ``out_sep2.wav``. A name not ending in ``.wav`` gets ``_sep1.wav``
            / ``_sep2.wav`` appended. When None, ``wav_path`` is used, which
            must then be a filesystem path.

    Returns:
        Tuple ``(sep1_path, sep2_path)`` of the written files.

    Raises:
        ValueError: If no output path can be derived or the audio is empty.
        RuntimeError: Re-raised from the model; CUDA out-of-memory is
            re-raised with a shorter, user-facing message.
    """
    if model is None:
        # Avoid crashing on None.to(...): fall back to the default checkpoint.
        model = load_dpt_model()
    if outfilename is None:
        if not isinstance(wav_path, (str, os.PathLike)):
            raise ValueError("outfilename is required when wav_path is not a file path")
        outfilename = os.fspath(wav_path)
    final_sep1, final_sep2 = _sep_output_paths(outfilename)

    try:
        device = _select_device(model)
        model = model.to(device)

        # Load as WAV and average all channels into one: shape (1, n_samples).
        x, sr = torchaudio.load(wav_path, format="wav")
        x = x.mean(dim=0, keepdim=True).to(device)

        # The pipeline runs at 16 kHz; resample anything else.
        if sr != 16000:
            resampler = torchaudio.transforms.Resample(sr, 16000).to(device)
            x = resampler(x)
            sr = 16000

        n_samples = x.shape[1]
        if n_samples == 0:
            raise ValueError(f"No audio samples in {wav_path!r}")

        # Separate in chunks of at most one minute to avoid OOM.
        # NOTE(review): chunks are separated independently, so the speaker
        # order may swap at chunk boundaries for inputs longer than one
        # chunk. Confirm whether cross-chunk alignment is needed.
        separated = []
        with torch.no_grad():
            for start, end in _chunk_bounds(n_samples, chunk_size=sr * 60, min_tail=sr):
                est = model(x[:, start:end])
                separated.append(est.cpu())

        # Assumes the model output is (batch=1, n_src, time). TODO confirm.
        # Concatenate along time, drop the batch axis, take the first two sources.
        est_sources = torch.cat(separated, dim=2).squeeze(0)
        sep_1, sep_2 = est_sources[0], est_sources[1]

        # `.item()` yields a plain float, so this also works when `x` is on
        # CUDA while the separated tensors are on CPU.
        peak = 0.9 * x.abs().max().item()
        sep_1 = _scale_to_peak(sep_1, peak)
        sep_2 = _scale_to_peak(sep_2, peak)

        # Stage the files in a temp dir *inside* the destination directory, so
        # os.replace is a same-filesystem rename. A temp dir under the system
        # tmp may live on another device, where os.replace fails.
        out_dir = os.path.dirname(os.path.abspath(final_sep1))
        os.makedirs(out_dir, exist_ok=True)
        with tempfile.TemporaryDirectory(dir=out_dir) as tmp_dir:
            sep1_path = os.path.join(tmp_dir, "sep1.wav")
            sep2_path = os.path.join(tmp_dir, "sep2.wav")

            torchaudio.save(sep1_path, sep_1.unsqueeze(0), sr)
            torchaudio.save(sep2_path, sep_2.unsqueeze(0), sr)

            os.replace(sep1_path, final_sep1)
            os.replace(sep2_path, final_sep2)

        logger.info("💾 寫入輸出檔案至: %s, %s", final_sep1, final_sep2)
        return final_sep1, final_sep2

    except RuntimeError as e:
        if "CUDA out of memory" in str(e):
            raise RuntimeError("記憶體不足,請縮短音訊長度") from e
        raise


def _select_device(model):
    """Pick CUDA when available, but keep dynamically-quantized models on CPU.

    PyTorch's eager-mode dynamic quantized modules (as produced by
    ``quantize_dynamic`` in ``load_dpt_model``) only have CPU kernels.
    """
    if not torch.cuda.is_available():
        return torch.device("cpu")
    for module in model.modules():
        # e.g. torch.ao.nn.quantized.dynamic.modules.linear.Linear
        if ".quantized" in type(module).__module__:
            return torch.device("cpu")
    return torch.device("cuda")


def _sep_output_paths(outfilename):
    """Return the ``(sep1, sep2)`` output paths derived from ``outfilename``.

    Only a trailing ``.wav`` (any case) is replaced. Other names keep their
    full text and get a ``.wav`` suffix appended, so the two outputs never
    collide.
    """
    outfilename = os.fspath(outfilename)
    stem, ext = os.path.splitext(outfilename)
    if ext.lower() != ".wav":
        stem, ext = outfilename, ".wav"
    return f"{stem}_sep1{ext}", f"{stem}_sep2{ext}"


def _chunk_bounds(n_samples, chunk_size, min_tail):
    """Return ``(start, end)`` sample ranges that tile ``[0, n_samples)``.

    A trailing remainder shorter than ``min_tail`` is merged into the previous
    chunk, so the model never receives a sliver shorter than its encoder kernel.
    """
    starts = list(range(0, n_samples, chunk_size))
    if len(starts) > 1 and n_samples - starts[-1] < min_tail:
        starts.pop()
    ends = starts[1:] + [n_samples]
    return list(zip(starts, ends))


def _scale_to_peak(signal, peak, eps=1e-8):
    """Rescale ``signal`` so its absolute maximum equals ``peak``.

    ``eps`` avoids a division by zero (NaN output) for an all-silent source.
    """
    return peak * signal / signal.abs().max().clamp_min(eps)

if __name__ == "__main__":
    # No command-line entry point: the functions above are meant to be called
    # from a web front end.
    usage_hint = "This module should be used via Flask or Gradio."
    print(usage_hint)