|
|
|
import os |
|
import warnings |
|
|
|
import librosa |
|
import numpy as np |
|
import soundfile as sf |
|
import torch |
|
import torch.nn as nn |
|
import yaml |
|
from tqdm import tqdm |
|
|
|
warnings.filterwarnings("ignore") |
|
|
|
|
|
class Roformer_Loader: |
|
def get_config(self, config_path): |
|
with open(config_path, "r", encoding="utf-8") as f: |
|
|
|
config = yaml.load(f, Loader=yaml.FullLoader) |
|
return config |
|
|
|
def get_default_config(self): |
|
default_config = None |
|
if self.model_type == "bs_roformer": |
|
|
|
|
|
|
|
default_config = { |
|
"audio": {"chunk_size": 352800, "sample_rate": 44100}, |
|
"model": { |
|
"dim": 512, |
|
"depth": 12, |
|
"stereo": True, |
|
"num_stems": 1, |
|
"time_transformer_depth": 1, |
|
"freq_transformer_depth": 1, |
|
"linear_transformer_depth": 0, |
|
"freqs_per_bands": (2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 12, 12, 12, 12, 12, 12, 12, 12, 24, 24, 24, 24, 24, 24, 24, 24, 48, 48, 48, 48, 48, 48, 48, 48, 128, 129), |
|
"dim_head": 64, |
|
"heads": 8, |
|
"attn_dropout": 0.1, |
|
"ff_dropout": 0.1, |
|
"flash_attn": True, |
|
"dim_freqs_in": 1025, |
|
"stft_n_fft": 2048, |
|
"stft_hop_length": 441, |
|
"stft_win_length": 2048, |
|
"stft_normalized": False, |
|
"mask_estimator_depth": 2, |
|
"multi_stft_resolution_loss_weight": 1.0, |
|
"multi_stft_resolutions_window_sizes": (4096, 2048, 1024, 512, 256), |
|
"multi_stft_hop_size": 147, |
|
"multi_stft_normalized": False, |
|
}, |
|
"training": {"instruments": ["vocals", "other"], "target_instrument": "vocals"}, |
|
"inference": {"batch_size": 2, "num_overlap": 2}, |
|
} |
|
|
|
elif self.model_type == "mel_band_roformer": |
|
|
|
|
|
default_config = { |
|
"audio": {"chunk_size": 352800, "sample_rate": 44100}, |
|
"model": { |
|
"dim": 384, |
|
"depth": 12, |
|
"stereo": True, |
|
"num_stems": 1, |
|
"time_transformer_depth": 1, |
|
"freq_transformer_depth": 1, |
|
"linear_transformer_depth": 0, |
|
"num_bands": 60, |
|
"dim_head": 64, |
|
"heads": 8, |
|
"attn_dropout": 0.1, |
|
"ff_dropout": 0.1, |
|
"flash_attn": True, |
|
"dim_freqs_in": 1025, |
|
"sample_rate": 44100, |
|
"stft_n_fft": 2048, |
|
"stft_hop_length": 441, |
|
"stft_win_length": 2048, |
|
"stft_normalized": False, |
|
"mask_estimator_depth": 2, |
|
"multi_stft_resolution_loss_weight": 1.0, |
|
"multi_stft_resolutions_window_sizes": (4096, 2048, 1024, 512, 256), |
|
"multi_stft_hop_size": 147, |
|
"multi_stft_normalized": False, |
|
}, |
|
"training": {"instruments": ["vocals", "other"], "target_instrument": "vocals"}, |
|
"inference": {"batch_size": 2, "num_overlap": 2}, |
|
} |
|
|
|
return default_config |
|
|
|
def get_model_from_config(self): |
|
if self.model_type == "bs_roformer": |
|
from bs_roformer.bs_roformer import BSRoformer |
|
|
|
model = BSRoformer(**dict(self.config["model"])) |
|
elif self.model_type == "mel_band_roformer": |
|
from bs_roformer.mel_band_roformer import MelBandRoformer |
|
|
|
model = MelBandRoformer(**dict(self.config["model"])) |
|
else: |
|
print("Error: Unknown model: {}".format(self.model_type)) |
|
model = None |
|
return model |
|
|
|
def demix_track(self, model, mix, device): |
|
C = self.config["audio"]["chunk_size"] |
|
N = self.config["inference"]["num_overlap"] |
|
fade_size = C // 10 |
|
step = int(C // N) |
|
border = C - step |
|
batch_size = self.config["inference"]["batch_size"] |
|
|
|
length_init = mix.shape[-1] |
|
progress_bar = tqdm(total=length_init // step + 1, desc="Processing", leave=False) |
|
|
|
|
|
if length_init > 2 * border and (border > 0): |
|
mix = nn.functional.pad(mix, (border, border), mode="reflect") |
|
|
|
|
|
window_size = C |
|
fadein = torch.linspace(0, 1, fade_size) |
|
fadeout = torch.linspace(1, 0, fade_size) |
|
window_start = torch.ones(window_size) |
|
window_middle = torch.ones(window_size) |
|
window_finish = torch.ones(window_size) |
|
window_start[-fade_size:] *= fadeout |
|
window_finish[:fade_size] *= fadein |
|
window_middle[-fade_size:] *= fadeout |
|
window_middle[:fade_size] *= fadein |
|
|
|
with torch.amp.autocast("cuda"): |
|
with torch.inference_mode(): |
|
if self.config["training"]["target_instrument"] is None: |
|
req_shape = (len(self.config["training"]["instruments"]),) + tuple(mix.shape) |
|
else: |
|
req_shape = (1,) + tuple(mix.shape) |
|
|
|
result = torch.zeros(req_shape, dtype=torch.float32) |
|
counter = torch.zeros(req_shape, dtype=torch.float32) |
|
i = 0 |
|
batch_data = [] |
|
batch_locations = [] |
|
while i < mix.shape[1]: |
|
part = mix[:, i : i + C].to(device) |
|
length = part.shape[-1] |
|
if length < C: |
|
if length > C // 2 + 1: |
|
part = nn.functional.pad(input=part, pad=(0, C - length), mode="reflect") |
|
else: |
|
part = nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode="constant", value=0) |
|
if self.is_half: |
|
part = part.half() |
|
batch_data.append(part) |
|
batch_locations.append((i, length)) |
|
i += step |
|
progress_bar.update(1) |
|
|
|
if len(batch_data) >= batch_size or (i >= mix.shape[1]): |
|
arr = torch.stack(batch_data, dim=0) |
|
|
|
x = model(arr) |
|
|
|
window = window_middle |
|
if i - step == 0: |
|
window = window_start |
|
elif i >= mix.shape[1]: |
|
window = window_finish |
|
|
|
for j in range(len(batch_locations)): |
|
start, l = batch_locations[j] |
|
result[..., start : start + l] += x[j][..., :l].cpu() * window[..., :l] |
|
counter[..., start : start + l] += window[..., :l] |
|
|
|
batch_data = [] |
|
batch_locations = [] |
|
|
|
estimated_sources = result / counter |
|
estimated_sources = estimated_sources.cpu().numpy() |
|
np.nan_to_num(estimated_sources, copy=False, nan=0.0) |
|
|
|
if length_init > 2 * border and (border > 0): |
|
|
|
estimated_sources = estimated_sources[..., border:-border] |
|
|
|
progress_bar.close() |
|
|
|
if self.config["training"]["target_instrument"] is None: |
|
return {k: v for k, v in zip(self.config["training"]["instruments"], estimated_sources)} |
|
else: |
|
return {k: v for k, v in zip([self.config["training"]["target_instrument"]], estimated_sources)} |
|
|
|
def run_folder(self, input, vocal_root, others_root, format): |
|
self.model.eval() |
|
path = input |
|
os.makedirs(vocal_root, exist_ok=True) |
|
os.makedirs(others_root, exist_ok=True) |
|
file_base_name = os.path.splitext(os.path.basename(path))[0] |
|
|
|
sample_rate = 44100 |
|
if "sample_rate" in self.config["audio"]: |
|
sample_rate = self.config["audio"]["sample_rate"] |
|
|
|
try: |
|
mix, sr = librosa.load(path, sr=sample_rate, mono=False) |
|
except Exception as e: |
|
print("Can read track: {}".format(path)) |
|
print("Error message: {}".format(str(e))) |
|
return |
|
|
|
|
|
isstereo = self.config["model"].get("stereo", True) |
|
if not isstereo and len(mix.shape) != 1: |
|
mix = np.mean(mix, axis=0) |
|
print("Warning: Track has more than 1 channels, but model is mono, taking mean of all channels.") |
|
|
|
mix_orig = mix.copy() |
|
|
|
mixture = torch.tensor(mix, dtype=torch.float32) |
|
res = self.demix_track(self.model, mixture, self.device) |
|
|
|
if self.config["training"]["target_instrument"] is not None: |
|
|
|
|
|
target_instrument = self.config["training"]["target_instrument"] |
|
other_instruments = [i for i in self.config["training"]["instruments"] if i != target_instrument] |
|
other = mix_orig - res[target_instrument] |
|
|
|
path_vocal = "{}/{}_{}.wav".format(vocal_root, file_base_name, target_instrument) |
|
path_other = "{}/{}_{}.wav".format(others_root, file_base_name, other_instruments[0]) |
|
self.save_audio(path_vocal, res[target_instrument].T, sr, format) |
|
self.save_audio(path_other, other.T, sr, format) |
|
else: |
|
|
|
vocal_inst = self.config["training"]["instruments"][0] |
|
path_vocal = "{}/{}_{}.wav".format(vocal_root, file_base_name, vocal_inst) |
|
self.save_audio(path_vocal, res[vocal_inst].T, sr, format) |
|
for other in self.config["training"]["instruments"][1:]: |
|
path_other = "{}/{}_{}.wav".format(others_root, file_base_name, other) |
|
self.save_audio(path_other, res[other].T, sr, format) |
|
|
|
def save_audio(self, path, data, sr, format): |
|
|
|
if format in ["wav", "flac"]: |
|
if format == "flac": |
|
path = path[:-3] + "flac" |
|
sf.write(path, data, sr) |
|
else: |
|
sf.write(path, data, sr) |
|
os.system('ffmpeg -i "{}" -vn "{}" -q:a 2 -y'.format(path, path[:-3] + format)) |
|
try: |
|
os.remove(path) |
|
except: |
|
pass |
|
|
|
def __init__(self, model_path, config_path, device, is_half): |
|
self.device = device |
|
self.is_half = is_half |
|
self.model_type = None |
|
self.config = None |
|
|
|
|
|
if "bs_roformer" in model_path.lower() or "bsroformer" in model_path.lower(): |
|
self.model_type = "bs_roformer" |
|
elif "mel_band_roformer" in model_path.lower() or "melbandroformer" in model_path.lower(): |
|
self.model_type = "mel_band_roformer" |
|
|
|
if not os.path.exists(config_path): |
|
if self.model_type is None: |
|
|
|
raise ValueError( |
|
"Error: Unknown model type. If you are using a model without a configuration file, Ensure that your model name includes 'bs_roformer', 'bsroformer', 'mel_band_roformer', or 'melbandroformer'. Otherwise, you can manually place the model configuration file into 'tools/uvr5/uvr5w_weights' and ensure that the configuration file is named as '<model_name>.yaml' then try it again." |
|
) |
|
self.config = self.get_default_config() |
|
else: |
|
|
|
self.config = self.get_config(config_path) |
|
if self.model_type is None: |
|
|
|
if "freqs_per_bands" in self.config["model"]: |
|
|
|
self.model_type = "bs_roformer" |
|
else: |
|
|
|
self.model_type = "mel_band_roformer" |
|
|
|
print("Detected model type: {}".format(self.model_type)) |
|
model = self.get_model_from_config() |
|
state_dict = torch.load(model_path, map_location="cpu") |
|
model.load_state_dict(state_dict) |
|
|
|
if is_half == False: |
|
self.model = model.to(device) |
|
else: |
|
self.model = model.half().to(device) |
|
|
|
def _path_audio_(self, input, others_root, vocal_root, format, is_hp3=False): |
|
self.run_folder(input, vocal_root, others_root, format) |
|
|