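"""MusicGen creation from config, plus a CompressionModel wrapper for DAC RVQ autoencoders."""
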
import typing as tp

import torch

from audiocraft.models import MusicGen, CompressionModel, LMModel
import audiocraft.quantization as qt
from audiocraft.modules.codebooks_patterns import (
    DelayedPatternProvider,
    MusicLMPattern,
    ParallelPatternProvider,
    UnrolledPatternProvider,
    VALLEPattern,
)
from audiocraft.modules.conditioners import (
    ConditionFuser,
    ConditioningProvider,
    T5Conditioner,
)

from .autoencoders import AudioAutoencoder
from .bottleneck import DACRVQBottleneck, DACRVQVAEBottleneck


def create_musicgen_from_config(config):
    """Create a MusicGen model from a config dict, either pretrained or from scratch."""
    model_config = config.get('model', None)
    assert model_config is not None, 'model config must be specified in config'

    # Load a pretrained MusicGen model, optionally re-initializing the LM weights
    if model_config.get("pretrained", False):
        model = MusicGen.get_pretrained(model_config["pretrained"], device="cpu")

        if model_config.get("reinit_lm", False):
            model.lm._init_weights("gaussian", "current", True)

        return model

    # Create MusicGen model from scratch, starting with the compression model (audio tokenizer)
    compression_config = model_config.get('compression', None)
    assert compression_config is not None, 'compression config must be specified in model config'

    compression_type = compression_config.get('type', None)
    assert compression_type is not None, 'type must be specified in compression config'

    if compression_type == 'pretrained':
        compression_model = CompressionModel.get_pretrained(compression_config["config"]["name"])
    elif compression_type == "dac_rvq_ae":
        from .autoencoders import create_autoencoder_from_config
        autoencoder = create_autoencoder_from_config({
            "model": compression_config["config"],
            "sample_rate": config["sample_rate"]
        })
        autoencoder.load_state_dict(torch.load(compression_config["ckpt_path"], map_location="cpu")["state_dict"])
        compression_model = DACRVQCompressionModel(autoencoder)
    else:
        raise ValueError(f"Unknown compression type: {compression_type}")
    lm_config = model_config.get('lm', None)
    assert lm_config is not None, 'lm config must be specified in model config'

    # Choose how the residual codebooks are interleaved across LM timesteps
    codebook_pattern = lm_config.pop("codebook_pattern", "delay")

    pattern_providers = {
        'parallel': ParallelPatternProvider,
        'delay': DelayedPatternProvider,
        'unroll': UnrolledPatternProvider,
        'valle': VALLEPattern,
        'musiclm': MusicLMPattern,
    }

    pattern_provider = pattern_providers[codebook_pattern](n_q=compression_model.num_codebooks)

    conditioning_config = model_config.get("conditioning", {})
    condition_output_dim = conditioning_config.get("output_dim", 768)

    # Text conditioning via a frozen T5 encoder; its embeddings are fused into the LM by cross-attention
    condition_provider = ConditioningProvider(
        conditioners={
            "description": T5Conditioner(
                name="t5-base",
                output_dim=condition_output_dim,
                word_dropout=0.3,
                normalize_text=False,
                finetune=False,
                device="cpu"
            )
        }
    )

    condition_fuser = ConditionFuser(fuse2cond={
        "cross": ["description"],
        "prepend": [],
        "sum": []
    })
    lm = LMModel(
        pattern_provider=pattern_provider,
        condition_provider=condition_provider,
        fuser=condition_fuser,
        n_q=compression_model.num_codebooks,
        card=compression_model.cardinality,
        **lm_config
    )

    model = MusicGen(
        name=model_config.get("name", "musicgen-scratch"),
        compression_model=compression_model,
        lm=lm,
        max_duration=30
    )

    return model
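
# A minimal usage sketch with hypothetical config values (the LM keys are
# forwarded as kwargs to audiocraft's LMModel, and "facebook/encodec_32khz" is
# one of audiocraft's pretrained compression models):
#
#   config = {
#       "sample_rate": 32000,
#       "model": {
#           "name": "musicgen-scratch",
#           "compression": {"type": "pretrained", "config": {"name": "facebook/encodec_32khz"}},
#           "lm": {"codebook_pattern": "delay", "dim": 1024, "num_heads": 16, "num_layers": 24},
#           "conditioning": {"output_dim": 768},
#       },
#   }
#   model = create_musicgen_from_config(config)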


class DACRVQCompressionModel(CompressionModel):
    """Adapts a DAC RVQ autoencoder to audiocraft's CompressionModel interface."""

    def __init__(self, autoencoder: AudioAutoencoder):
        super().__init__()
        self.model = autoencoder.eval()

        assert isinstance(self.model.bottleneck, (DACRVQBottleneck, DACRVQVAEBottleneck)), \
            "Autoencoder must have a DACRVQBottleneck or DACRVQVAEBottleneck"

        self.n_quantizers = self.model.bottleneck.num_quantizers

    def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
        raise NotImplementedError("Forward and training with DAC RVQ not supported")
    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        _, info = self.model.encode(x, return_info=True, n_quantizers=self.n_quantizers)
        codes = info["codes"]
        return codes, None

    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
        assert scale is None
        z_q = self.decode_latent(codes)
        return self.model.decode(z_q)

    def decode_latent(self, codes: torch.Tensor):
        """Decode from the discrete codes to the continuous latent space."""
        return self.model.bottleneck.quantizer.from_codes(codes)[0]

    @property
    def channels(self) -> int:
        return self.model.io_channels

    @property
    def frame_rate(self) -> float:
        return self.model.sample_rate / self.model.downsampling_ratio

    @property
    def sample_rate(self) -> int:
        return self.model.sample_rate

    @property
    def cardinality(self) -> int:
        return self.model.bottleneck.quantizer.codebook_size

    @property
    def num_codebooks(self) -> int:
        return self.n_quantizers
    @property
    def total_codebooks(self) -> int:
        return self.model.bottleneck.num_quantizers
    def set_num_codebooks(self, n: int):
        """Set the active number of codebooks used by the quantizer."""
        assert n >= 1
        assert n <= self.total_codebooks
        self.n_quantizers = n
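
# A rough round-trip sketch (assumes `autoencoder` is an AudioAutoencoder with a
# DAC RVQ bottleneck already loaded; audio follows audiocraft's
# [batch, channels, samples] layout and codes come back as [batch, num_codebooks, frames]):
#
#   compression_model = DACRVQCompressionModel(autoencoder)
#   audio = torch.randn(1, compression_model.channels, compression_model.sample_rate)  # 1 second
#   codes, scale = compression_model.encode(audio)  # scale is always None here
#   reconstruction = compression_model.decode(codes)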