from enum import Enum
import typing as tp

from .diffusion import ConditionedDiffusionModelWrapper
from ..inference.generation import generate_diffusion_cond
from ..inference.utils import prepare_audio

import torch
from torch.nn import functional as F
from torchaudio import transforms as T

# Define prior types enum
class PriorType(Enum):
    MonoToStereo = 1
    SourceSeparation = 2

class DiffusionPrior(ConditionedDiffusionModelWrapper):
    """
    Base class for pre-trained diffusion prior models, tagged with the task they were trained for
    """
    def __init__(self, *args, prior_type: tp.Optional[PriorType] = None, **kwargs):
        super().__init__(*args, **kwargs)
        self.prior_type = prior_type

class MonoToStereoDiffusionPrior(DiffusionPrior):
    """
    A diffusion prior model made for mono-to-stereo conversion
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, prior_type=PriorType.MonoToStereo, **kwargs)

    def stereoize(
        self, 
        audio: torch.Tensor, # (batch, channels, time)
        in_sr: int,
        steps: int,
        sampler_kwargs: dict = {},
    ):
        """
        Generate stereo audio from mono audio using a pre-trained diffusion prior

        Args:
            audio: The mono audio to convert to stereo
            in_sr: The sample rate of the input audio
            steps: The number of diffusion steps to run
            sampler_kwargs: Keyword arguments to pass to the diffusion sampler
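
        Example:
            A minimal usage sketch (illustrative only; assumes `prior` is a
            pretrained MonoToStereoDiffusionPrior loaded elsewhere, e.g. via the
            project's model-loading utilities):

                mono = torch.randn(1, 1, 44100)  # (batch, 1, time) mono audio
                with torch.no_grad():
                    stereo = prior.stereoize(mono, in_sr=44100, steps=50)
                # `stereo` is two-channel audio at `prior.sample_rate`; its length
                # reflects any padding applied to reach a valid model input length.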
        """

        device = audio.device

        sample_rate = self.sample_rate

        # Resample input audio if necessary
        if in_sr != sample_rate:
            resample_tf = T.Resample(in_sr, sample_rate).to(audio.device)
            audio = resample_tf(audio)

        audio_length = audio.shape[-1]

        # Round the audio length up to the nearest multiple of the model's minimum input length
        # (e.g. audio_length=100_000 with min_length=2048 gives padded_input_length=100_352)
        min_length = self.min_input_length
        padded_input_length = audio_length + (min_length - (audio_length % min_length)) % min_length

        # Zero-pad the input audio to the rounded-up length
        if padded_input_length > audio_length:
            audio = F.pad(audio, (0, padded_input_length - audio_length))

        # Collapse the input to mono, then duplicate it to two channels as dual-mono conditioning
        dual_mono = audio.mean(1, keepdim=True).repeat(1, 2, 1)

        if self.pretransform is not None:
            dual_mono = self.pretransform.encode(dual_mono)

        conditioning = {"source": [dual_mono]}

        stereo_audio = generate_diffusion_cond(
            self, 
            conditioning_tensors=conditioning,
            steps=steps,
            sample_size=padded_input_length,
            sample_rate=sample_rate,
            device=device,
            **sampler_kwargs,
        ) 

        return stereo_audio


class SourceSeparationDiffusionPrior(DiffusionPrior):
    """
    A diffusion prior model made for conditioned source separation
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, prior_type=PriorType.SourceSeparation, **kwargs)

    def separate(
        self, 
        mixed_audio: torch.Tensor, # (batch, channels, time)
        in_sr: int,
        steps: int,
        conditioning: tp.Optional[dict] = None,
        conditioning_tensors: tp.Optional[dict] = None,
        sampler_kwargs: dict = {},
    ):
        """
        Separate audio sources based on conditioning using a pre-trained diffusion prior

        Args:
            mixed_audio: The mixed audio to separate
            in_sr: The sample rate of the input audio
            steps: The number of diffusion steps to run
            conditioning: The conditioning to use for source separation
            conditioning_tensors: Pre-computed conditioning tensors to use for source separation. If provided, conditioning is ignored.
            sampler_kwargs: Keyword arguments to pass to the diffusion sampler
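
        Example:
            A minimal usage sketch (illustrative only; assumes `prior` is a pretrained
            SourceSeparationDiffusionPrior and that its conditioner accepts a "prompt"
            key; the actual conditioning structure depends on the model's conditioner
            configuration):

                mix = torch.randn(1, 2, 44100)  # (batch, channels, time) mixture
                with torch.no_grad():
                    separated = prior.separate(
                        mix,
                        in_sr=44100,
                        steps=100,
                        conditioning={"prompt": ["vocals"]},  # hypothetical conditioning
                    )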
        """

        device = mixed_audio.device

        sample_rate = self.sample_rate

        # Resample input audio if necessary
        if in_sr != sample_rate:
            resample_tf = T.Resample(in_sr, sample_rate).to(mixed_audio.device)
            mixed_audio = resample_tf(mixed_audio)

        audio_length = mixed_audio.shape[-1]

        # Round the audio length up to the nearest multiple of the model's minimum input length
        min_length = self.min_input_length
        padded_input_length = audio_length + (min_length - (audio_length % min_length)) % min_length

        # Zero-pad the mixture to the rounded-up length
        if padded_input_length > audio_length:
            mixed_audio = F.pad(mixed_audio, (0, padded_input_length - audio_length))

        if self.pretransform is not None:
            mixed_audio = self.pretransform.encode(mixed_audio)

        # Conditioning
        assert conditioning is not None or conditioning_tensors is not None, "Must provide either conditioning or conditioning_tensors for conditioned source separation"
        if conditioning_tensors is None:
            conditioning_tensors = self.conditioner(conditioning, device)
        
        # Pass in the mixture audio as conditioning
        conditioning_tensors["source"] = [mixed_audio]

        separated_audio = generate_diffusion_cond(
            self, 
            conditioning_tensors=conditioning_tensors,
            steps=steps,
            sample_size=padded_input_length,
            sample_rate=sample_rate,
            device=device,
            **sampler_kwargs,
        ) 

        return separated_audio