File size: 5,498 Bytes
5915064
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import torch
import typing as tp
from audiocraft.models import MusicGen, CompressionModel, LMModel
import audiocraft.quantization as qt
from .autoencoders import AudioAutoencoder
from .bottleneck import DACRVQBottleneck, DACRVQVAEBottleneck

from audiocraft.modules.codebooks_patterns import (
    DelayedPatternProvider,
    MusicLMPattern,
    ParallelPatternProvider,
    UnrolledPatternProvider,
    VALLEPattern,
)

from audiocraft.modules.conditioners import (
    ConditionFuser,
    ConditioningProvider,
    T5Conditioner,
)

def create_musicgen_from_config(config):
    """Create a MusicGen model from a config dict.

    The config must contain a ``model`` section. Two paths are supported:

    * ``model.pretrained`` set: load the named pretrained MusicGen checkpoint
      (optionally re-initializing the LM weights when ``model.reinit_lm`` is set)
      and return it immediately.
    * otherwise: build the model from scratch from the ``model.compression``,
      ``model.lm`` and ``model.conditioning`` sub-configs.

    Args:
        config (dict): full training/model config; must contain ``model`` and,
            for the ``dac_rvq_ae`` compression path, a top-level ``sample_rate``.

    Returns:
        MusicGen: the assembled (or pretrained) model, on CPU.

    Raises:
        ValueError: if ``model.compression.type`` is not a recognized type.
        AssertionError: if required config sections are missing.
    """
    model_config = config.get('model', None)
    assert model_config is not None, 'model config must be specified in config'

    if model_config.get("pretrained", False):
        model = MusicGen.get_pretrained(model_config["pretrained"], device="cpu")

        if model_config.get("reinit_lm", False):
            model.lm._init_weights("gaussian", "current", True)

        return model

    # Create MusicGen model from scratch
    compression_config = model_config.get('compression', None)
    assert compression_config is not None, 'compression config must be specified in model config'

    compression_type = compression_config.get('type', None)
    assert compression_type is not None, 'type must be specified in compression config'

    if compression_type == 'pretrained':
        compression_model = CompressionModel.get_pretrained(compression_config["config"]["name"])
    elif compression_type == "dac_rvq_ae":
        from .autoencoders import create_autoencoder_from_config
        autoencoder = create_autoencoder_from_config({"model": compression_config["config"], "sample_rate": config["sample_rate"]})
        autoencoder.load_state_dict(torch.load(compression_config["ckpt_path"], map_location="cpu")["state_dict"])
        compression_model = DACRVQCompressionModel(autoencoder)
    else:
        # Fail fast with a clear message instead of a NameError further down.
        raise ValueError(f"Unknown compression type: {compression_type}")

    lm_config = model_config.get('lm', None)
    assert lm_config is not None, 'lm config must be specified in model config'

    # Pop so the pattern name is not forwarded to LMModel via **lm_config below.
    codebook_pattern = lm_config.pop("codebook_pattern", "delay")

    pattern_providers = {
        'parallel': ParallelPatternProvider,
        'delay': DelayedPatternProvider,
        'unroll': UnrolledPatternProvider,
        'valle': VALLEPattern,
        'musiclm': MusicLMPattern,
    }

    pattern_provider = pattern_providers[codebook_pattern](n_q=compression_model.num_codebooks)

    conditioning_config = model_config.get("conditioning", {})

    condition_output_dim = conditioning_config.get("output_dim", 768)

    # Text conditioning via a frozen T5 encoder; fused into the LM as cross-attention.
    condition_provider = ConditioningProvider(
        conditioners = {
            "description": T5Conditioner(
                name="t5-base",
                output_dim=condition_output_dim,
                word_dropout=0.3,
                normalize_text=False,
                finetune=False,
                device="cpu"
            )
        }
    )

    condition_fuser = ConditionFuser(fuse2cond={
        "cross": ["description"],
        "prepend": [],
        "sum": []
        })

    lm = LMModel(
        pattern_provider = pattern_provider,
        condition_provider = condition_provider,
        fuser = condition_fuser,
        n_q = compression_model.num_codebooks,
        card = compression_model.cardinality,
        **lm_config
    )

    model = MusicGen(
        name = model_config.get("name", "musicgen-scratch"),
        compression_model = compression_model,
        lm = lm,
        max_duration=30
    )

    return model

class DACRVQCompressionModel(CompressionModel):
    """Adapter exposing an ``AudioAutoencoder`` with a DAC-style RVQ bottleneck
    through audiocraft's ``CompressionModel`` interface, so it can back a
    MusicGen language model. Inference-only: training via ``forward`` is not
    supported.
    """

    def __init__(self, autoencoder: AudioAutoencoder):
        super().__init__()
        # Wrapped model is kept in eval mode; this adapter is inference-only.
        self.model = autoencoder.eval()

        assert isinstance(self.model.bottleneck, DACRVQBottleneck) or isinstance(self.model.bottleneck, DACRVQVAEBottleneck), "Autoencoder must have a DACRVQBottleneck or DACRVQVAEBottleneck"

        # Active number of quantizers; may be lowered via set_num_codebooks().
        self.n_quantizers = self.model.bottleneck.num_quantizers

    def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
        raise NotImplementedError("Forward and training with DAC RVQ not supported")

    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        """Encode audio to discrete codes. Returns (codes, scale); scale is
        always None since no per-chunk rescaling is used."""
        _, info = self.model.encode(x, return_info=True, n_quantizers=self.n_quantizers)
        codes = info["codes"]
        return codes, None

    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
        """Decode discrete codes back to audio. ``scale`` must be None."""
        assert scale is None
        z_q = self.decode_latent(codes)
        return self.model.decode(z_q)

    def decode_latent(self, codes: torch.Tensor):
        """Decode from the discrete codes to continuous latent space."""
        return self.model.bottleneck.quantizer.from_codes(codes)[0]

    @property
    def channels(self) -> int:
        return self.model.io_channels

    @property
    def frame_rate(self) -> float:
        # Latent frames per second of audio.
        return self.model.sample_rate / self.model.downsampling_ratio

    @property
    def sample_rate(self) -> int:
        return self.model.sample_rate

    @property
    def cardinality(self) -> int:
        # Number of entries per codebook.
        return self.model.bottleneck.quantizer.codebook_size

    @property
    def num_codebooks(self) -> int:
        return self.n_quantizers

    @property
    def total_codebooks(self) -> int:
        # BUG FIX: the original was missing `return` and always yielded None,
        # which made set_num_codebooks() raise TypeError on `n <= None`.
        return self.model.bottleneck.num_quantizers

    def set_num_codebooks(self, n: int):
        """Set the active number of codebooks used by the quantizer.

        Args:
            n (int): must satisfy 1 <= n <= total_codebooks.
        """
        assert n >= 1
        assert n <= self.total_codebooks
        self.n_quantizers = n