from enum import Enum
import typing as tp

from .diffusion import ConditionedDiffusionModelWrapper
from ..inference.generation import generate_diffusion_cond
from ..inference.utils import prepare_audio

import torch
from torch.nn import functional as F
from torchaudio import transforms as T

# Define prior types enum
class PriorType(Enum):
    MonoToStereo = 1
    SourceSeparation = 2

class DiffusionPrior(ConditionedDiffusionModelWrapper):
    def __init__(self, *args, prior_type: tp.Optional[PriorType] = None, **kwargs):
        super().__init__(*args, **kwargs)
        self.prior_type = prior_type

class MonoToStereoDiffusionPrior(DiffusionPrior):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, prior_type=PriorType.MonoToStereo, **kwargs)

    def stereoize(
        self,
        audio: torch.Tensor,  # (batch, channels, time)
        in_sr: int,
        steps: int,
        sampler_kwargs: tp.Optional[dict] = None,
    ):
""" | |
Generate stereo audio from mono audio using a pre-trained diffusion prior | |
Args: | |
audio: The mono audio to convert to stereo | |
in_sr: The sample rate of the input audio | |
steps: The number of diffusion steps to run | |
sampler_kwargs: Keyword arguments to pass to the diffusion sampler | |
""" | |

        device = audio.device
        sample_rate = self.sample_rate

        # Resample input audio if necessary
        if in_sr != sample_rate:
            resample_tf = T.Resample(in_sr, sample_rate).to(device)
            audio = resample_tf(audio)

        audio_length = audio.shape[-1]

        # Compute the padded length: the smallest multiple of the model's
        # minimum input length that is at least the audio length
        min_length = self.min_input_length
        padded_input_length = audio_length + (min_length - (audio_length % min_length)) % min_length
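        # Worked example (illustrative numbers; min_input_length depends on the
        # model and its pretransform): with min_length = 2048 and
        # audio_length = 3000, the pad is (2048 - 3000 % 2048) % 2048 = 1096,
        # so padded_input_length = 4096.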

        # Pad input audio to be compatible with the model
        if padded_input_length > audio_length:
            audio = F.pad(audio, (0, padded_input_length - audio_length))

        # Make audio mono, duplicate to stereo
        dual_mono = audio.mean(1, keepdim=True).repeat(1, 2, 1)

        if self.pretransform is not None:
            dual_mono = self.pretransform.encode(dual_mono)
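
        # The dual-mono signal (now in latent space if a pretransform is set)
        # is passed to the diffusion model under the "source" conditioning key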
        conditioning = {"source": [dual_mono]}

        stereo_audio = generate_diffusion_cond(
            self,
            conditioning_tensors=conditioning,
            steps=steps,
            sample_size=padded_input_length,
            sample_rate=sample_rate,
            device=device,
            **(sampler_kwargs or {}),
        )

        return stereo_audio
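
# Example usage for stereoize, as a sketch. The loading step is an assumption:
# `create_model_from_config` stands in for however the surrounding project
# builds the wrapped model; only the stereoize call itself is defined here.
#
#     model = create_model_from_config(model_config).to("cuda")  # hypothetical loading step
#     assert isinstance(model, MonoToStereoDiffusionPrior)
#     mono = torch.randn(1, 1, model.sample_rate * 5)  # 5 seconds of mono audio
#     stereo = model.stereoize(mono.to("cuda"), in_sr=model.sample_rate, steps=50)
#     # stereo has shape (1, 2, time); time may exceed the input length due to padding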

class SourceSeparationDiffusionPrior(DiffusionPrior):
    """
    A diffusion prior model made for conditioned source separation
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, prior_type=PriorType.SourceSeparation, **kwargs)

    def separate(
        self,
        mixed_audio: torch.Tensor,  # (batch, channels, time)
        in_sr: int,
        steps: int,
        conditioning: tp.Optional[dict] = None,
        conditioning_tensors: tp.Optional[dict] = None,
        sampler_kwargs: tp.Optional[dict] = None,
    ):
""" | |
Separate audio sources based on conditioning using a pre-trained diffusion prior | |
Args: | |
mixed_audio: The mixed audio to separate | |
in_sr: The sample rate of the input audio | |
steps: The number of diffusion steps to run | |
conditioning: The conditioning to use for source separation | |
conditioning_tensors: Pre-computed conditioning tensors to use for source separation. If provided, conditioning is ignored. | |
sampler_kwargs: Keyword arguments to pass to the diffusion sampler | |
""" | |

        device = mixed_audio.device
        sample_rate = self.sample_rate

        # Resample input audio if necessary
        if in_sr != sample_rate:
            resample_tf = T.Resample(in_sr, sample_rate).to(device)
            mixed_audio = resample_tf(mixed_audio)

        audio_length = mixed_audio.shape[-1]

        # Compute the padded length: the smallest multiple of the model's
        # minimum input length that is at least the audio length
        min_length = self.min_input_length
        padded_input_length = audio_length + (min_length - (audio_length % min_length)) % min_length

        # Pad input audio to be compatible with the model
        if padded_input_length > audio_length:
            mixed_audio = F.pad(mixed_audio, (0, padded_input_length - audio_length))

        if self.pretransform is not None:
            mixed_audio = self.pretransform.encode(mixed_audio)

        # Conditioning
        assert conditioning is not None or conditioning_tensors is not None, "Must provide either conditioning or conditioning_tensors for conditioned source separation"
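
        # The contents of `conditioning` are model-specific; for a
        # text-conditioned separator it might look like {"prompt": ["vocals"]}
        # (an assumed shape for illustration; the actual keys are defined by
        # the model's conditioning config)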
        if conditioning_tensors is None:
            conditioning_tensors = self.conditioner(conditioning, device)

        # Pass in the mixture audio as conditioning
        conditioning_tensors["source"] = [mixed_audio]

        separated_audio = generate_diffusion_cond(
            self,
            conditioning_tensors=conditioning_tensors,
            steps=steps,
            sample_size=padded_input_length,
            sample_rate=sample_rate,
            device=device,
            **(sampler_kwargs or {}),
        )

        return separated_audio
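
# Example usage for separate, as a sketch. The loading helper, checkpoint, and
# conditioning keys below are assumptions, not part of this module.
#
#     model = create_model_from_config(model_config).to("cuda")  # hypothetical loading step
#     assert isinstance(model, SourceSeparationDiffusionPrior)
#     mix = torch.randn(1, 2, 44100 * 10)  # 10 seconds of stereo mixture
#     stems = model.separate(
#         mix.to("cuda"),
#         in_sr=44100,
#         steps=100,
#         conditioning={"prompt": ["vocals"]},  # assumed conditioning format
#     )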