File size: 3,918 Bytes
64ceedd
 
 
 
cf73d23
 
38d7181
 
 
64ceedd
d8be50a
64ceedd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38d7181
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64ceedd
38d7181
 
cf73d23
38d7181
64ceedd
cf73d23
d8be50a
64ceedd
 
38d7181
b75ae28
 
64ceedd
b75ae28
 
dad0fd1
 
 
 
 
 
 
 
 
b75ae28
dad0fd1
b75ae28
 
dad0fd1
 
 
 
 
 
 
 
 
 
 
b75ae28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dad0fd1
 
 
 
 
64ceedd
 
b6c45cb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import os
import torch
import numpy as np
import torchaudio
import yaml
from . import asteroid_test
from huggingface_hub import hf_hub_download

# Import-time side effect: ask torchaudio to use the "sox_io" I/O backend for
# every later load/save call made through this module.
# NOTE(review): set_audio_backend appears to be deprecated in recent
# torchaudio releases -- confirm the pinned torchaudio version still accepts it.
torchaudio.set_audio_backend("sox_io")


def get_conf():
    """Return the hard-coded hyper-parameters used to build the DPTNet model.

    Returns:
        A ``(conf_filterbank, conf_masknet)`` tuple of two freshly created
        dicts. ``load_dpt_model`` unpacks both as keyword arguments into the
        model constructor.
    """
    # Filterbank (encoder/decoder) settings.
    filterbank_conf = dict(n_filters=64, kernel_size=16, stride=8)

    # Mask-network settings.
    masknet_conf = dict(
        in_chan=64,
        n_src=2,
        out_chan=64,
        ff_hid=256,
        ff_activation="relu",
        norm_type="gLN",
        chunk_size=100,
        hop_size=50,
        n_repeats=2,
        mask_act='sigmoid',
        bidirectional=True,
        dropout=0,
    )
    return filterbank_conf, masknet_conf


def load_dpt_model():
    """Download the int8 DPTNet checkpoint and return the model ready for inference.

    The Hugging Face access token is read from the ``SpeechSeparation``
    environment variable and checked before anything else is touched, so a
    missing token is reported immediately.

    Returns:
        The DPTNet model built from ``get_conf()``, dynamically quantized
        (LSTM and Linear layers, qint8), with the downloaded state dict
        loaded on CPU and switched to eval mode.

    Raises:
        EnvironmentError: if ``SpeechSeparation`` is unset or empty.
    """
    print('Load Separation Model...')

    # Read the Hugging Face token from the environment.
    speech_sep_token = os.getenv("SpeechSeparation")
    if not speech_sep_token:
        raise EnvironmentError("環境變數 SpeechSeparation 未設定!")

    # Download the model weights from the Hugging Face Hub (uses the
    # module-level ``hf_hub_download`` import).
    model_path = hf_hub_download(
        repo_id="DeepLearning101/speech-separation",  # replace with your own repo name
        filename="train_dptnet_aishell_partOverlap_B2_300epoch_quan-int8.p",
        token=speech_sep_token,
    )

    # Build the float model, then quantize it *before* loading the weights.
    # Presumably the checkpoint was saved from an already-quantized model
    # (its file name says "quan-int8"), so the module layout has to match.
    conf_filterbank, conf_masknet = get_conf()
    model = asteroid_test.DPTNet(**conf_filterbank, **conf_masknet)
    model = torch.quantization.quantize_dynamic(
        model, {torch.nn.LSTM, torch.nn.Linear}, dtype=torch.qint8
    )

    # NOTE(review): torch.load unpickles the downloaded file; only point
    # repo_id at a repository you trust. Newer PyTorch versions may also
    # change the default of ``weights_only`` -- confirm this checkpoint
    # still loads under the pinned torch version.
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    return model


import torchaudio
import tempfile

def _sep_output_paths(outfilename):
    """Return the ``(<stem>_sep1.wav, <stem>_sep2.wav)`` pair for *outfilename*."""
    stem, ext = os.path.splitext(outfilename)
    if ext.lower() != ".wav":
        # No .wav suffix to strip: keep the full name as the stem so the two
        # results can never collapse onto the same path.
        stem = outfilename
    return stem + "_sep1.wav", stem + "_sep2.wav"


def _rescale_to_peak(source, peak):
    """Scale *source* so that its largest absolute sample equals *peak*."""
    source_peak = torch.max(torch.abs(source))
    if source_peak.item() == 0:
        # An all-zero source would give 0/0 = NaN; leave it silent instead.
        return source
    return peak * source / source_peak


def dpt_sep_process(wav_path, model=None, outfilename=None):
    """Separate a recording into two sources and write them as WAV files.

    The audio is down-mixed to mono, resampled to 16 kHz if needed, run
    through ``model`` in one-minute chunks, peak-normalized to 0.9x the
    input peak, and saved next to ``outfilename``.

    Args:
        wav_path: Path of the input WAV file.
        model: Loaded separation model (required despite the ``None``
            default). Its output is assumed to be shaped
            ``(1, n_src, time)`` with at least two sources -- TODO confirm
            against the model definition.
        outfilename: Base output path (required despite the ``None``
            default). ``foo.wav`` yields ``foo_sep1.wav`` and
            ``foo_sep2.wav``.

    Returns:
        ``(final_sep1, final_sep2)``: the paths of the two written files.

    Raises:
        ValueError: if ``model`` or ``outfilename`` is missing.
        RuntimeError: if the audio is empty, on CUDA out-of-memory (with a
            user-facing message), or for any other torch runtime failure.
    """
    if model is None:
        raise ValueError("dpt_sep_process requires a loaded separation model")
    if not outfilename:
        raise ValueError("dpt_sep_process requires an output file name")

    try:
        # Pick the device.
        # NOTE(review): a dynamically quantized model presumably only runs
        # on CPU; moving it to CUDA here may fail -- confirm.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device)

        # Load the audio and down-mix all channels to a single one.
        x, sr = torchaudio.load(wav_path, format="wav")
        x = x.mean(dim=0, keepdim=True).to(device)

        # Resample to 16 kHz when the file uses another rate.
        if sr != 16000:
            resampler = torchaudio.transforms.Resample(sr, 16000).to(device)
            x = resampler(x)
            sr = 16000

        if x.shape[1] == 0:
            raise RuntimeError("input audio contains no samples")

        # Process in one-minute chunks to avoid running out of memory.
        chunk_size = sr * 60
        separated = []
        for i in range(0, x.shape[1], chunk_size):
            chunk = x[:, i:i + chunk_size]
            with torch.no_grad():
                est = model(chunk)
            separated.append(est.cpu())

        # Stitch the chunks back together along the time axis and drop the
        # leading batch dimension.
        est_sources = torch.cat(separated, dim=2).squeeze(0)
        sep_1, sep_2 = est_sources[0], est_sources[1]

        # Normalize both sources to 0.9x the input peak. The peak is moved
        # to CPU because the separated sources already live there.
        peak = (0.9 * torch.max(torch.abs(x))).cpu()
        sep_1 = _rescale_to_peak(sep_1, peak)
        sep_2 = _rescale_to_peak(sep_2, peak)

        final_sep1, final_sep2 = _sep_output_paths(outfilename)

        # Stage the files in a temporary directory created beside the
        # destination: os.replace only works within one filesystem, which
        # the system-wide temp dir does not guarantee.
        out_dir = os.path.dirname(os.path.abspath(final_sep1))
        with tempfile.TemporaryDirectory(dir=out_dir) as tmp_dir:
            sep1_path = os.path.join(tmp_dir, "sep1.wav")
            sep2_path = os.path.join(tmp_dir, "sep2.wav")

            torchaudio.save(sep1_path, sep_1.unsqueeze(0), sr)
            torchaudio.save(sep2_path, sep_2.unsqueeze(0), sr)

            # Move the finished files to their final location.
            os.replace(sep1_path, final_sep1)
            os.replace(sep2_path, final_sep2)

        return final_sep1, final_sep2

    except RuntimeError as e:
        # Turn CUDA OOM into a user-facing hint; re-raise everything else.
        if "CUDA out of memory" in str(e):
            raise RuntimeError("記憶體不足,請縮短音訊長度") from e
        else:
            raise

# Script entry point: this module has no CLI and only prints a usage hint.
# NOTE(review): the module uses a relative import (`from . import
# asteroid_test`), so running the file directly as a script likely fails
# before reaching this guard -- confirm.
if __name__ == '__main__':
    print("This module should be used via Flask or Gradio.")