import gc
import json
import time
from pathlib import Path
from typing import Dict, Tuple

import librosa
import numpy as np
import soundfile as sf
import spaces
import torch

from mdxnet_model import MDX, MDXModel
from utils import convert_to_stereo_and_wav

STEM_NAMING = {
    "Vocals": "Instrumental",
    "Other": "Instruments",
    "Instrumental": "Vocals",
    "Drums": "Drumless",
    "Bass": "Bassless",
}


@spaces.GPU()
def run_mdx(model_params: Dict,
            input_filename: Path,
            output_dir: Path,
            model_path: Path,
            denoise: bool = False,
            m_threads: int = 2,
            device_base: str = "cuda",
            ) -> Tuple[Path, Path]:
    """
    Separate a track with an MDX model and write both stems to disk.

    Returns the path of the primary stem and the path of the inverted
    (complementary) stem.
    """
    if device_base == "cuda":
        device = torch.device("cuda:0")
        processor_num = 0
        # Scale the worker thread count with available VRAM.
        device_properties = torch.cuda.get_device_properties(device)
        vram_gb = device_properties.total_memory / 1024**3
        m_threads = 1 if vram_gb < 8 else (8 if vram_gb > 32 else 2)
    else:
        device = torch.device("cpu")
        processor_num = -1
        m_threads = 1
    print(f"device: {device}")

    # Look up this model's parameters by the hash of its weights file.
    model_hash = MDX.get_hash(model_path)  # type: str
    mp = model_params.get(model_hash)
    model = MDXModel(
        device,
        dim_f=mp["mdx_dim_f_set"],
        dim_t=2 ** mp["mdx_dim_t_set"],
        n_fft=mp["mdx_n_fft_scale_set"],
        stem_name=mp["primary_stem"],
        compensation=mp["compensate"],
    )

    mdx_sess = MDX(model_path, model, processor=processor_num)
    wave, sr = librosa.load(input_filename, mono=False, sr=44100)

    # Normalizing the input wave to a peak of 1.0 gives better output.
    peak = max(np.max(wave), abs(np.min(wave)))
    wave /= peak
    if denoise:
        # Run the model on the waveform and its negation, then average the
        # (re-negated) results so non-phase-symmetric noise cancels out.
        wave_processed = -(mdx_sess.process_wave(-wave, m_threads)) + (mdx_sess.process_wave(wave, m_threads))  # type: np.ndarray
        wave_processed *= 0.5
    else:
        wave_processed = mdx_sess.process_wave(wave, m_threads)
    # Restore the original peak level.
    wave_processed *= peak

    stem_name = model.stem_name

    # Write the main (primary stem) track.
    main_filepath = output_dir / f"{input_filename.stem}_{stem_name}.wav"
    sf.write(main_filepath, wave_processed.T, sr)

    # Write the inverted track: the input minus the compensated primary stem.
    invert_filepath = output_dir / f"{input_filename.stem}_{stem_name}_reverse.wav"
    sf.write(invert_filepath, (-wave_processed.T * model.compensation) + wave.T, sr)

    del mdx_sess, wave_processed, wave
    gc.collect()
    torch.cuda.empty_cache()
    return main_filepath, invert_filepath
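
# For reference, model_data.json maps the hash of each model file to the
# parameters consumed above. A minimal sketch of one entry; the hash and the
# values below are illustrative placeholders, not a real model entry:
#
#   "53c4baf4d12c3e6c3831bb8f5b532b93": {
#       "compensate": 1.035,
#       "mdx_dim_f_set": 3072,
#       "mdx_dim_t_set": 8,
#       "mdx_n_fft_scale_set": 7680,
#       "primary_stem": "Vocals"
#   }
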
@spaces.GPU()
def run_mdx_return_np(model_params: Dict,
                      input_filename: Path,
                      model_path: Path,
                      denoise: bool = False,
                      m_threads: int = 2,
                      device_base: str = "cuda",
                      ) -> Tuple[np.ndarray, np.ndarray]:
    """
    Same as run_mdx, but return the separated stems as numpy arrays
    instead of writing them to disk.
    """
    if device_base == "cuda":
        device = torch.device("cuda:0")
        processor_num = 0
        # Scale the worker thread count with available VRAM.
        device_properties = torch.cuda.get_device_properties(device)
        vram_gb = device_properties.total_memory / 1024**3
        m_threads = 1 if vram_gb < 8 else (8 if vram_gb > 32 else 2)
    else:
        device = torch.device("cpu")
        processor_num = -1
        m_threads = 1
    print(f"device: {device}")

    model_hash = MDX.get_hash(model_path)  # type: str
    mp = model_params.get(model_hash)
    model = MDXModel(
        device,
        dim_f=mp["mdx_dim_f_set"],
        dim_t=2 ** mp["mdx_dim_t_set"],
        n_fft=mp["mdx_n_fft_scale_set"],
        stem_name=mp["primary_stem"],
        compensation=mp["compensate"],
    )

    mdx_sess = MDX(model_path, model, processor=processor_num)
    wave, sr = librosa.load(input_filename, mono=False, sr=44100)

    # Normalizing the input wave to a peak of 1.0 gives better output.
    peak = max(np.max(wave), abs(np.min(wave)))
    wave /= peak
    if denoise:
        # Same negation-and-average denoising trick as in run_mdx.
        wave_processed = -(mdx_sess.process_wave(-wave, m_threads)) + (mdx_sess.process_wave(wave, m_threads))  # type: np.ndarray
        wave_processed *= 0.5
    else:
        wave_processed = mdx_sess.process_wave(wave, m_threads)
    # Restore the original peak level.
    wave_processed *= peak

    # Main (primary stem) track.
    main_track = wave_processed.T
    # Inverted track: the input minus the compensated primary stem.
    invert_track = (-wave_processed.T * model.compensation) + wave.T

    del mdx_sess
    gc.collect()
    torch.cuda.empty_cache()
    return main_track, invert_track


def extract_bgm(mdx_model_params: Dict,
                input_filename: Path,
                model_bgm_path: Path,
                output_dir: Path,
                device_base: str = "cuda") -> Path:
    """
    Extract the pure background music, i.e. remove the vocals.
    """
    background_path, _ = run_mdx(model_params=mdx_model_params,
                                 input_filename=input_filename,
                                 output_dir=output_dir,
                                 model_path=model_bgm_path,
                                 denoise=False,
                                 device_base=device_base,
                                 )
    return background_path


def extract_vocal(mdx_model_params: Dict,
                  input_filename: Path,
                  model_basic_vocal_path: Path,
                  model_main_vocal_path: Path,
                  output_dir: Path,
                  main_vocals_flag: bool = False,
                  device_base: str = "cuda") -> Path:
    """
    Extract the vocals.
    """
    # First pass: UVR-MDX-NET-Voc_FT.onnx, the basic vocal separation model.
    vocals_path, _ = run_mdx(mdx_model_params,
                             input_filename,
                             output_dir,
                             model_basic_vocal_path,
                             denoise=True,
                             device_base=device_base,
                             )

    # If main_vocals_flag is enabled, use UVR_MDXNET_KARA_2.onnx to further
    # separate the main vocals (Main) from the backup/background vocals (Backup).
    if main_vocals_flag:
        time.sleep(2)
        backup_vocals_path, main_vocals_path = run_mdx(mdx_model_params,
                                                       vocals_path,
                                                       output_dir,
                                                       model_main_vocal_path,
                                                       denoise=True,
                                                       device_base=device_base,
                                                       )
        vocals_path = main_vocals_path

    # If dereverb_flag were enabled, Reverb_HQ_By_FoxJoy.onnx would be used for
    # dereverberation. Deactivated since the model license is unknown.
    # if dereverb_flag:
    #     time.sleep(2)
    #     _, vocals_dereverb_path = run_mdx(mdx_model_params,
    #                                       vocals_path,
    #                                       output_dir,
    #                                       mdxnet_models_dir / "Reverb_HQ_By_FoxJoy.onnx",
    #                                       denoise=True,
    #                                       device_base=device_base,
    #                                       )
    #     vocals_path = vocals_dereverb_path

    return vocals_path


def process_uvr_task(input_file_path: Path,
                     output_dir: Path,
                     models_path: Dict[str, Path],
                     main_vocals_flag: bool = False,  # if enabled, use UVR_MDXNET_KARA_2.onnx to further separate main and backup vocals
                     ) -> Tuple[Path, Path]:
    device_base = "cuda" if torch.cuda.is_available() else "cpu"

    # Load the MDX model definitions.
    with open("./mdx_models/model_data.json") as infile:
        mdx_model_params = json.load(infile)  # type: Dict

    output_dir.mkdir(parents=True, exist_ok=True)
    input_file_path = convert_to_stereo_and_wav(input_file_path)  # type: Path

    # 1. Extract the pure background music (remove the vocals).
    background_path = extract_bgm(mdx_model_params,
                                  input_file_path,
                                  models_path["bgm"],
                                  output_dir,
                                  device_base=device_base)

    # 2. Separate the vocals.
    vocals_path = extract_vocal(mdx_model_params,
                                input_file_path,
                                models_path["basic_vocal"],
                                models_path["main_vocal"],
                                output_dir,
                                main_vocals_flag=main_vocals_flag,
                                device_base=device_base)

    return background_path, vocals_path


def get_model_params(model_path: Path) -> Dict:
    """
    Load the model parameters stored alongside the models.
    """
    with open(model_path / "model_data.json") as infile:
        return json.load(infile)  # type: Dict
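
# A minimal usage sketch. The "bgm" ONNX filename and the input/output paths
# below are illustrative assumptions (only UVR-MDX-NET-Voc_FT.onnx and
# UVR_MDXNET_KARA_2.onnx are named in the comments above); adjust them to the
# checkpoints actually present in ./mdx_models.
if __name__ == "__main__":
    mdxnet_models_dir = Path("./mdx_models")
    models_path = {
        "bgm": mdxnet_models_dir / "UVR-MDX-NET-Inst_HQ_3.onnx",  # hypothetical bgm model
        "basic_vocal": mdxnet_models_dir / "UVR-MDX-NET-Voc_FT.onnx",
        "main_vocal": mdxnet_models_dir / "UVR_MDXNET_KARA_2.onnx",
    }
    background_path, vocals_path = process_uvr_task(
        input_file_path=Path("song.wav"),  # hypothetical input file
        output_dir=Path("./separated"),
        models_path=models_path,
        main_vocals_flag=True,  # also split main vs. backup vocals
    )
    print(f"background: {background_path}")
    print(f"vocals: {vocals_path}")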