# audio_test/stable_audio_tools/models/diffusion_prior.py
from enum import Enum
import typing as tp

import torch
from torch.nn import functional as F
from torchaudio import transforms as T

from .diffusion import ConditionedDiffusionModelWrapper
from ..inference.generation import generate_diffusion_cond


# Define prior types enum
class PriorType(Enum):
    MonoToStereo = 1
    SourceSeparation = 2


class DiffusionPrior(ConditionedDiffusionModelWrapper):
    def __init__(self, *args, prior_type: tp.Optional[PriorType] = None, **kwargs):
        super().__init__(*args, **kwargs)
        self.prior_type = prior_type


class MonoToStereoDiffusionPrior(DiffusionPrior):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, prior_type=PriorType.MonoToStereo, **kwargs)

    def stereoize(
        self,
        audio: torch.Tensor,  # (batch, channels, time)
        in_sr: int,
        steps: int,
        sampler_kwargs: tp.Optional[dict] = None,
    ):
"""
Generate stereo audio from mono audio using a pre-trained diffusion prior
Args:
audio: The mono audio to convert to stereo
in_sr: The sample rate of the input audio
steps: The number of diffusion steps to run
sampler_kwargs: Keyword arguments to pass to the diffusion sampler
"""
        if sampler_kwargs is None:
            sampler_kwargs = {}

        device = audio.device
        sample_rate = self.sample_rate

        # Resample input audio if necessary
        if in_sr != sample_rate:
            resample_tf = T.Resample(in_sr, sample_rate).to(device)
            audio = resample_tf(audio)

        audio_length = audio.shape[-1]

        # Pad input audio up to the next multiple of the model's minimum input length
        # (e.g. audio_length=100_000 with min_length=2048 pads to 100_352)
        min_length = self.min_input_length
        padded_input_length = audio_length + (min_length - (audio_length % min_length)) % min_length

        if padded_input_length > audio_length:
            audio = F.pad(audio, (0, padded_input_length - audio_length))

        # Collapse to mono, then duplicate the channel to get a dual-mono stereo signal
        dual_mono = audio.mean(1, keepdim=True).repeat(1, 2, 1)

        if self.pretransform is not None:
            dual_mono = self.pretransform.encode(dual_mono)

        # Condition the diffusion model on the dual-mono source
        conditioning = {"source": [dual_mono]}

        stereo_audio = generate_diffusion_cond(
            self,
            conditioning_tensors=conditioning,
            steps=steps,
            sample_size=padded_input_length,
            sample_rate=sample_rate,
            device=device,
            **sampler_kwargs,
        )

        return stereo_audio
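
# A minimal usage sketch for MonoToStereoDiffusionPrior.stereoize, kept as a
# comment so it is not part of the module. Assumptions: `prior` is an
# already-loaded MonoToStereoDiffusionPrior checkpoint, and the input file
# path is hypothetical.
#
#     import torchaudio
#
#     mono_audio, in_sr = torchaudio.load("mono_input.wav")  # (channels, time)
#     mono_audio = mono_audio.unsqueeze(0).to("cuda")        # add batch dim
#
#     stereo = prior.stereoize(mono_audio, in_sr=in_sr, steps=100)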


class SourceSeparationDiffusionPrior(DiffusionPrior):
    """
    A diffusion prior model made for conditioned source separation
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, prior_type=PriorType.SourceSeparation, **kwargs)

    def separate(
        self,
        mixed_audio: torch.Tensor,  # (batch, channels, time)
        in_sr: int,
        steps: int,
        conditioning: tp.Optional[dict] = None,
        conditioning_tensors: tp.Optional[dict] = None,
        sampler_kwargs: tp.Optional[dict] = None,
    ):
"""
Separate audio sources based on conditioning using a pre-trained diffusion prior
Args:
mixed_audio: The mixed audio to separate
in_sr: The sample rate of the input audio
steps: The number of diffusion steps to run
conditioning: The conditioning to use for source separation
conditioning_tensors: Pre-computed conditioning tensors to use for source separation. If provided, conditioning is ignored.
sampler_kwargs: Keyword arguments to pass to the diffusion sampler
"""
        if sampler_kwargs is None:
            sampler_kwargs = {}

        device = mixed_audio.device
        sample_rate = self.sample_rate

        # Resample input audio if necessary
        if in_sr != sample_rate:
            resample_tf = T.Resample(in_sr, sample_rate).to(device)
            mixed_audio = resample_tf(mixed_audio)

        audio_length = mixed_audio.shape[-1]

        # Pad input audio up to the next multiple of the model's minimum input length
        min_length = self.min_input_length
        padded_input_length = audio_length + (min_length - (audio_length % min_length)) % min_length

        if padded_input_length > audio_length:
            mixed_audio = F.pad(mixed_audio, (0, padded_input_length - audio_length))

        if self.pretransform is not None:
            mixed_audio = self.pretransform.encode(mixed_audio)

        # Conditioning
        assert conditioning is not None or conditioning_tensors is not None, "Must provide either conditioning or conditioning_tensors for conditioned source separation"
        if conditioning_tensors is None:
            conditioning_tensors = self.conditioner(conditioning, device)

        # Pass in the mixture audio as conditioning
        conditioning_tensors["source"] = [mixed_audio]
        separated_audio = generate_diffusion_cond(
            self,
            conditioning_tensors=conditioning_tensors,
            steps=steps,
            sample_size=padded_input_length,
            sample_rate=sample_rate,
            device=device,
            **sampler_kwargs,
        )

        return separated_audio
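
# A minimal usage sketch for SourceSeparationDiffusionPrior.separate, kept as a
# comment so it is not part of the module. Assumptions: `prior` is an
# already-loaded SourceSeparationDiffusionPrior checkpoint, the conditioning
# key "prompt" is hypothetical (the expected schema depends on how the model's
# conditioner was configured), and the file path is made up.
#
#     import torchaudio
#
#     mix, in_sr = torchaudio.load("mixture.wav")  # (channels, time)
#     mix = mix.unsqueeze(0).to("cuda")            # add batch dim
#
#     drums = prior.separate(
#         mix,
#         in_sr=in_sr,
#         steps=100,
#         conditioning={"prompt": "drums"},
#     )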