from enum import Enum
import typing as tp

from .diffusion import ConditionedDiffusionModelWrapper
from ..inference.generation import generate_diffusion_cond
from ..inference.utils import prepare_audio

import torch
from torch.nn import functional as F
from torchaudio import transforms as T
# Define prior types enum
class PriorType(Enum):
    MonoToStereo = 1
    SourceSeparation = 2
class DiffusionPrior(ConditionedDiffusionModelWrapper):
    def __init__(self, *args, prior_type: tp.Optional[PriorType] = None, **kwargs):
        super().__init__(*args, **kwargs)
        self.prior_type = prior_type
class MonoToStereoDiffusionPrior(DiffusionPrior):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, prior_type=PriorType.MonoToStereo, **kwargs)

    def stereoize(
        self,
        audio: torch.Tensor, # (batch, channels, time)
        in_sr: int,
        steps: int,
        sampler_kwargs: dict = {},
    ):
        """
        Generate stereo audio from mono audio using a pre-trained diffusion prior

        Args:
            audio: The mono audio to convert to stereo
            in_sr: The sample rate of the input audio
            steps: The number of diffusion steps to run
            sampler_kwargs: Keyword arguments to pass to the diffusion sampler
        """
        device = audio.device

        sample_rate = self.sample_rate

        # Resample input audio if necessary
        if in_sr != sample_rate:
            resample_tf = T.Resample(in_sr, sample_rate).to(audio.device)
            audio = resample_tf(audio)

        audio_length = audio.shape[-1]

        # Pad the input audio up to the next multiple of the model's minimum input length
        min_length = self.min_input_length
        padded_input_length = audio_length + (min_length - (audio_length % min_length)) % min_length
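        # e.g. audio_length = 100000, min_length = 2048: 100000 % 2048 = 1696, so
        # 352 samples of padding bring the length to 100352 (49 * 2048); the outer
        # `% min_length` makes the padding zero when the length is already a multiple.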
        if padded_input_length > audio_length:
            audio = F.pad(audio, (0, padded_input_length - audio_length))

        # Make audio mono, duplicate to stereo
        dual_mono = audio.mean(1, keepdim=True).repeat(1, 2, 1)

        if self.pretransform is not None:
            dual_mono = self.pretransform.encode(dual_mono)
        conditioning = {"source": [dual_mono]}

        stereo_audio = generate_diffusion_cond(
            self,
            conditioning_tensors=conditioning,
            steps=steps,
            sample_size=padded_input_length,
            sample_rate=sample_rate,
            device=device,
            **sampler_kwargs,
        )

        return stereo_audio
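# A minimal usage sketch (assumptions: `prior` is a loaded MonoToStereoDiffusionPrior
# and `mono` is a (batch, 1, time) tensor at 44.1 kHz; both names and the step count
# are illustrative, not part of this module):
#
#   stereo = prior.stereoize(mono, in_sr=44100, steps=100)
#   # `stereo` is two-channel audio at `prior.sample_rate`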
class SourceSeparationDiffusionPrior(DiffusionPrior):
    """
    A diffusion prior model made for conditioned source separation
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, prior_type=PriorType.SourceSeparation, **kwargs)
    def separate(
        self,
        mixed_audio: torch.Tensor, # (batch, channels, time)
        in_sr: int,
        steps: int,
        conditioning: tp.Optional[dict] = None,
        conditioning_tensors: tp.Optional[dict] = None,
        sampler_kwargs: dict = {},
    ):
        """
        Separate audio sources based on conditioning using a pre-trained diffusion prior

        Args:
            mixed_audio: The mixed audio to separate
            in_sr: The sample rate of the input audio
            steps: The number of diffusion steps to run
            conditioning: The conditioning to use for source separation
            conditioning_tensors: Pre-computed conditioning tensors to use for source separation. If provided, conditioning is ignored.
            sampler_kwargs: Keyword arguments to pass to the diffusion sampler
        """
        device = mixed_audio.device

        sample_rate = self.sample_rate

        # Resample input audio if necessary
        if in_sr != sample_rate:
            resample_tf = T.Resample(in_sr, sample_rate).to(mixed_audio.device)
            mixed_audio = resample_tf(mixed_audio)

        audio_length = mixed_audio.shape[-1]

        # Pad the input audio up to the next multiple of the model's minimum input length
        min_length = self.min_input_length
        padded_input_length = audio_length + (min_length - (audio_length % min_length)) % min_length
        if padded_input_length > audio_length:
            mixed_audio = F.pad(mixed_audio, (0, padded_input_length - audio_length))

        if self.pretransform is not None:
            mixed_audio = self.pretransform.encode(mixed_audio)
        # Conditioning
        assert conditioning is not None or conditioning_tensors is not None, "Must provide either conditioning or conditioning_tensors for conditioned source separation"

        if conditioning_tensors is None:
            conditioning_tensors = self.conditioner(conditioning, device)

        # Pass in the mixture audio as conditioning
        conditioning_tensors["source"] = [mixed_audio]
        separated_audio = generate_diffusion_cond(
            self,
            conditioning_tensors=conditioning_tensors,
            steps=steps,
            sample_size=padded_input_length,
            sample_rate=sample_rate,
            device=device,
            **sampler_kwargs,
        )

        return separated_audio
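# A minimal usage sketch (assumptions: `prior` is a loaded SourceSeparationDiffusionPrior
# and `mix` is a (batch, channels, time) mixture tensor; the structure of `conditioning`
# depends on the conditioner the model was configured with, so `my_metadata` and the
# step count are illustrative only):
#
#   separated = prior.separate(mix, in_sr=44100, steps=100, conditioning=my_metadata)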