import gradio as gr import torch import torchaudio import os import tempfile import logging import traceback from datetime import datetime # 設定日誌系統 logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s' ) logger = logging.getLogger(__name__) # 檢查 Hugging Face 環境變數 if not os.getenv("SpeechSeparation"): logger.warning("⚠️ 環境變數 SpeechSeparation 未設定!請在 Hugging Face Space 的 Secrets 中設定 HF_TOKEN") # 載入模型模組 try: logger.info("🔧 開始載入語音分離模型...") from DPTNet_eval.DPTNet_quant_sep import load_dpt_model, dpt_sep_process logger.info("✅ 模型模組載入成功") except ImportError as e: logger.error(f"❌ 模組載入失敗: {str(e)}") raise RuntimeError("本地模組路徑配置錯誤") from e # 全域模型初始化 try: logger.info("🔄 初始化模型中...") model = load_dpt_model() logger.info(f"🧠 模型載入完成,運行設備: {'GPU' if torch.cuda.is_available() else 'CPU'}") except Exception as e: logger.error(f"💣 模型初始化失敗: {str(e)}") raise RuntimeError("模型載入異常終止") from e def validate_audio(path): """驗證音檔格式與內容有效性""" try: info = torchaudio.info(path) logger.info(f"🔊 音檔資訊: 采樣率={info.sample_rate}Hz, 通道數={info.num_channels}") if info.num_channels not in [1, 2]: raise gr.Error("❌ 不支援的音檔通道數(僅支援單聲道或立體聲)") if info.sample_rate < 8000 or info.sample_rate > 48000: raise gr.Error("❌ 不支援的采樣率(需介於 8kHz~48kHz)") return info.sample_rate except Exception as e: logger.error(f"⚠️ 音檔驗證失敗: {str(e)}") raise gr.Error("❌ 無效的音訊檔案格式") def convert_to_wav(input_path): """統一轉換為 16kHz WAV 格式""" try: # 使用 torchaudio 保持一致性 waveform, sample_rate = torchaudio.load(input_path) # 單聲道轉換 if waveform.shape[0] > 1: waveform = torch.mean(waveform, dim=0, keepdim=True) # 重采樣至 16kHz if sample_rate != 16000: resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000) waveform = resampler(waveform) # 建立臨時 WAV 檔案 with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile: torchaudio.save(tmpfile.name, waveform, 16000, bits_per_sample=16) logger.info(f"📝 已生成標準 WAV 檔案: {tmpfile.name}") return tmpfile.name except Exception as e: logger.error(f"⚠️ 音檔轉換失敗: {str(e)}") raise gr.Error("❌ 音訊格式轉換失敗") def separate_audio(input_audio): process_id = datetime.now().strftime("%Y%m%d%H%M%S%f") temp_wav = None try: logger.info(f"[{process_id}] 🚀 收到新請求: {input_audio}") # 1️⃣ 檢查檔案大小 if os.path.getsize(input_audio) > 50 * 1024 * 1024: raise gr.Error("❌ 檔案超過 50MB 限制") # 2️⃣ 轉換為標準格式 logger.info(f"[{process_id}] 🔁 轉換標準音檔...") temp_wav = convert_to_wav(input_audio) validate_audio(temp_wav) # 3️⃣ 建立固定輸出目錄 output_dir = os.path.join("/tmp/gradio_outputs", process_id) os.makedirs(output_dir) outfilename = os.path.join(output_dir, "output.wav") # 4️⃣ 執行語音分離 logger.info(f"[{process_id}] 🧠 開始分離...") sep_files = dpt_sep_process(temp_wav, model=model, outfilename=outfilename) # 5️⃣ 驗證輸出 for f in sep_files: if not os.path.exists(f): raise gr.Error(f"❌ 缺失輸出檔案: {f}") validate_audio(f) logger.info(f"[{process_id}] ✅ 處理完成") return sep_files except RuntimeError as e: if "CUDA out of memory" in str(e): logger.error(f"[{process_id}] 💥 GPU 記憶體不足") raise gr.Error("⚠️ 請縮短音檔長度") from e else: raise except Exception as e: logger.error(f"[{process_id}] ❌ 處理失敗: {str(e)}\n{traceback.format_exc()}") raise gr.Error(f"⚠️ 處理失敗: {str(e)}") from e finally: # 清理臨時檔案 if temp_wav and os.path.exists(temp_wav): os.unlink(temp_wav) logger.info(f"[{process_id}] 🧹 臨時檔案已清理") # 🎯 description 內容(轉為 HTML) description_html = """

中文語者分離(分割)

TonTon Huang Ph.D. | AI | 手把手帶你一起踩AI坑 | GitHub | Deep Learning 101 | YouTube


""" EXAMPLES = [ ["examples/sample1.wav"], ["examples/sample2.wav"] ] AUDIO_INPUT = gr.Audio( label="🔊 上傳混合音檔", type="filepath", sources=["upload", "microphone"], show_label=True, max_length=180 # 最大 3 分鐘 ) # 修改 Gradio 輸出設定 AUDIO_OUTPUTS = [ gr.Audio(label="🗣️ 語音軌道 1", type="filepath", format="wav"), gr.Audio(label="🗣️ 語音軌道 2", type="filepath", format="wav") ] # 🚀 啟動應用程式 interface = gr.Interface( fn=separate_audio, inputs=AUDIO_INPUT, outputs=AUDIO_OUTPUTS, title="🎙️ 語音分離,上傳一段混音音檔(支援.mp3, .wav),自動分離出兩個人的聲音;Deep Learning 101", description=description_html, examples=EXAMPLES, allow_flagging="never", cache_examples=False, theme="default" ) LAUNCH_CONFIG = { "server_name": "0.0.0.0", "server_port": int(os.environ.get("PORT", 7860)), # 預設值是給本地測試用 "share": False, "debug": True, "auth": None, "inbrowser": True, "quiet": False } if __name__ == "__main__": logger.info("🚀 啟動 Gradio 服務...") interface.launch(**LAUNCH_CONFIG)