DeepLearning101 committed on
Commit
e56b358
·
verified ·
1 Parent(s): 9c54cd9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -36
app.py CHANGED
@@ -1,47 +1,57 @@
1
  import gradio as gr
2
  import torch
3
- from DPTNet_eval.DPTNet_quant_sep import load_dpt_model, dpt_sep_process
4
  import os
5
  import soundfile as sf
6
  import numpy as np
7
  import librosa
8
  import warnings
 
 
9
 
10
- # 加載模型
 
 
 
 
11
  model = load_dpt_model()
12
 
13
  def separate_audio(input_wav):
14
- """
15
- Gradio Audio(filepath) → 處理 → 回傳兩個分離後的音檔路徑
16
- """
17
- file_extension = os.path.splitext(input_wav)[1].lower()
18
-
19
- # 如果是 MP3 或其他格式,先轉成 WAV
20
- if file_extension != ".wav":
21
- data, sr = sf.read(input_wav)
22
-
23
- # 轉單聲道
24
- if len(data.shape) > 1:
25
- data = data.mean(axis=1)
26
-
27
- # 重採樣到 16kHz
28
  if sr != 16000:
29
  data = librosa.resample(data, orig_sr=sr, target_sr=16000)
30
-
31
- # 存成 WAV
32
- sf.write("input.wav", data, 16000)
33
- wav_path = "input.wav"
34
- else:
35
- wav_path = input_wav
36
-
37
- # 分離語音
38
- outfilename = "output.wav"
39
- dpt_sep_process(wav_path, model=model, outfilename=outfilename)
40
-
41
- return (
42
- outfilename.replace('.wav', '_sep1.wav'),
43
- outfilename.replace('.wav', '_sep2.wav')
44
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  # 🎯 你提供的 description 內容(已轉為 HTML)
47
  description_html = """
@@ -87,16 +97,31 @@ description_html = """
87
  """
88
 
89
  if __name__ == "__main__":
 
90
  interface = gr.Interface(
91
  fn=separate_audio,
92
- inputs=gr.Audio(type="filepath", label="請上傳混音音檔 (.mp3/.wav)"),
 
 
 
 
93
  outputs=[
94
- gr.Audio(label="語音 1"),
95
- gr.Audio(label="語音 2")
96
  ],
97
  title="🎙️ 語音分離 Demo - Deep Learning 101",
98
  description=description_html,
99
- allow_flagging="never"
 
 
 
 
100
  )
101
 
102
- interface.launch(debug=True)
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import torch
 
3
  import os
4
  import soundfile as sf
5
  import numpy as np
6
  import librosa
7
  import warnings
8
+ import tempfile
9
+ from DPTNet_eval.DPTNet_quant_sep import load_dpt_model, dpt_sep_process
10
 
11
# Silence UserWarning and FutureWarning messages for the whole process.
warnings.filterwarnings(action="ignore", category=UserWarning)
warnings.filterwarnings(action="ignore", category=FutureWarning)

# Load the separation model once at import time; it is kept as a
# module-level global and reused by every call to separate_audio().
model = load_dpt_model()
17
 
18
  def separate_audio(input_wav):
19
+ """處理音訊分離的主要函數"""
20
+ try:
21
+ # 步驟 1:讀取音訊並標準化格式
22
+ data, sr = librosa.load(input_wav, sr=None, mono=True)
23
+
24
+ # 步驟 2:強制重採樣到 16kHz
 
 
 
 
 
 
 
 
25
  if sr != 16000:
26
  data = librosa.resample(data, orig_sr=sr, target_sr=16000)
27
+ sr = 16000
28
+
29
+ # 步驟 3:生成唯一臨時檔案
30
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
31
+ temp_wav = tmp_file.name
32
+ sf.write(temp_wav, data, sr, subtype='PCM_16')
33
+
34
+ # 步驟 4:執行語音分離
35
+ outfilename = "output.wav"
36
+ dpt_sep_process(temp_wav, model=model, outfilename=outfilename)
37
+
38
+ # 步驟 5:清理臨時檔案
39
+ os.remove(temp_wav)
40
+
41
+ # 步驟 6:驗證輸出檔案存在
42
+ output_files = [
43
+ outfilename.replace('.wav', '_sep1.wav'),
44
+ outfilename.replace('.wav', '_sep2.wav')
45
+ ]
46
+ if not all(os.path.exists(f) for f in output_files):
47
+ raise gr.Error("分離過程中發生錯誤,請檢查輸入檔案格式!")
48
+
49
+ return output_files
50
+
51
+ except Exception as e:
52
+ # 錯誤處理
53
+ error_msg = f"處理失敗:{str(e)}"
54
+ raise gr.Error(error_msg) from e
55
 
56
  # 🎯 你提供的 description 內容(已轉為 HTML)
57
  description_html = """
 
97
  """
98
 
99
  if __name__ == "__main__":
100
+ # 配置 Gradio 介面
101
  interface = gr.Interface(
102
  fn=separate_audio,
103
+ inputs=gr.Audio(
104
+ type="filepath",
105
+ label="請上傳混音音檔 (支援格式:mp3/wav/ogg)",
106
+ max_length=300 # 限制 5 分鐘長度
107
+ ),
108
  outputs=[
109
+ gr.Audio(label="語音軌道 1"),
110
+ gr.Audio(label="語音軌道 2")
111
  ],
112
  title="🎙️ 語音分離 Demo - Deep Learning 101",
113
  description=description_html,
114
+ allow_flagging="never",
115
+ examples=[
116
+ [os.path.join("examples", "sample1.wav")],
117
+ [os.path.join("examples", "sample2.mp3")]
118
+ ]
119
  )
120
 
121
+ # 啟動服務
122
+ interface.launch(
123
+ server_name="0.0.0.0",
124
+ server_port=7860,
125
+ share=False,
126
+ debug=False
127
+ )