import gradio as gr import torch import torchaudio import os import tempfile import logging import traceback from datetime import datetime # 設定日誌系統 logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s' ) logger = logging.getLogger(__name__) # 檢查 Hugging Face 環境變數 if not os.getenv("SpeechSeparation"): logger.warning("⚠️ 環境變數 SpeechSeparation 未設定!請在 Hugging Face Space 的 Secrets 中設定 HF_TOKEN") # 載入模型模組 try: logger.info("🔧 開始載入語音分離模型...") from DPTNet_eval.DPTNet_quant_sep import load_dpt_model, dpt_sep_process logger.info("✅ 模型模組載入成功") except ImportError as e: logger.error(f"❌ 模組載入失敗: {str(e)}") raise RuntimeError("本地模組路徑配置錯誤") from e # 全域模型初始化 try: logger.info("🔄 初始化模型中...") model = load_dpt_model() logger.info(f"🧠 模型載入完成,運行設備: {'GPU' if torch.cuda.is_available() else 'CPU'}") except Exception as e: logger.error(f"💣 模型初始化失敗: {str(e)}") raise RuntimeError("模型載入異常終止") from e def validate_audio(path): """驗證音檔格式與內容有效性""" try: info = torchaudio.info(path) logger.info(f"🔊 音檔資訊: 采樣率={info.sample_rate}Hz, 通道數={info.num_channels}") if info.num_channels not in [1, 2]: raise gr.Error("❌ 不支援的音檔通道數(僅支援單聲道或立體聲)") if info.sample_rate < 8000 or info.sample_rate > 48000: raise gr.Error("❌ 不支援的采樣率(需介於 8kHz~48kHz)") return info.sample_rate except Exception as e: logger.error(f"⚠️ 音檔驗證失敗: {str(e)}") raise gr.Error("❌ 無效的音訊檔案格式") def convert_to_wav(input_path): """統一轉換為 16kHz WAV 格式""" try: # 使用 torchaudio 保持一致性 waveform, sample_rate = torchaudio.load(input_path) # 單聲道轉換 if waveform.shape[0] > 1: waveform = torch.mean(waveform, dim=0, keepdim=True) # 重采樣至 16kHz if sample_rate != 16000: resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000) waveform = resampler(waveform) # 建立臨時 WAV 檔案 with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile: torchaudio.save(tmpfile.name, waveform, 16000, bits_per_sample=16) logger.info(f"📝 已生成標準 WAV 檔案: {tmpfile.name}") return tmpfile.name except Exception as e: logger.error(f"⚠️ 音檔轉換失敗: {str(e)}") raise gr.Error("❌ 音訊格式轉換失敗") def separate_audio(input_audio): process_id = datetime.now().strftime("%Y%m%d%H%M%S%f") temp_wav = None try: logger.info(f"[{process_id}] 🚀 收到新請求: {input_audio}") # 1️⃣ 檢查檔案大小 if os.path.getsize(input_audio) > 50 * 1024 * 1024: raise gr.Error("❌ 檔案超過 50MB 限制") # 2️⃣ 轉換為標準格式 logger.info(f"[{process_id}] 🔁 轉換標準音檔...") temp_wav = convert_to_wav(input_audio) validate_audio(temp_wav) # 3️⃣ 建立固定輸出目錄 output_dir = os.path.join("/tmp/gradio_outputs", process_id) os.makedirs(output_dir) outfilename = os.path.join(output_dir, "output.wav") # 4️⃣ 執行語音分離 logger.info(f"[{process_id}] 🧠 開始分離...") sep_files = dpt_sep_process(temp_wav, model=model, outfilename=outfilename) # 5️⃣ 驗證輸出 for f in sep_files: if not os.path.exists(f): raise gr.Error(f"❌ 缺失輸出檔案: {f}") validate_audio(f) logger.info(f"[{process_id}] ✅ 處理完成") return sep_files except RuntimeError as e: if "CUDA out of memory" in str(e): logger.error(f"[{process_id}] 💥 GPU 記憶體不足") raise gr.Error("⚠️ 請縮短音檔長度") from e else: raise except Exception as e: logger.error(f"[{process_id}] ❌ 處理失敗: {str(e)}\n{traceback.format_exc()}") raise gr.Error(f"⚠️ 處理失敗: {str(e)}") from e finally: # 清理臨時檔案 if temp_wav and os.path.exists(temp_wav): os.unlink(temp_wav) logger.info(f"[{process_id}] 🧹 臨時檔案已清理") # 🎯 description 內容(轉為 HTML) description_html = """