import typing as tp

import torch

from .config import DataConfig


def build_delay_indices(B: int, T: int, C: int, delay_pattern: tp.List[int]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
    """
    Precompute (t_idx_BxTxC, indices_BTCx3) so that out[t, c] = in[t - delay[c], c].
    Negative t_idx => BOS; t_idx >= T => PAD.
    """
    delay_arr = torch.tensor(delay_pattern, dtype=torch.int32)

    t_idx_BxT = torch.broadcast_to(
        torch.arange(T, dtype=torch.int32)[None, :],
        [B, T],
    )
    t_idx_BxTx1 = t_idx_BxT[..., None]
    t_idx_BxTxC = t_idx_BxTx1 - delay_arr.view(1, 1, C)

    b_idx_BxTxC = torch.broadcast_to(
        torch.arange(B, dtype=torch.int32).view(B, 1, 1),
        [B, T, C],
    )
    c_idx_BxTxC = torch.broadcast_to(
        torch.arange(C, dtype=torch.int32).view(1, 1, C),
        [B, T, C],
    )

    # Clamp time indices to [0, T-1] so the gather below never reads out of
    # bounds; out-of-range positions are masked to BOS/PAD afterwards.
    t_clamped_BxTxC = torch.clamp(t_idx_BxTxC, 0, T - 1)

    indices_BTCx3 = torch.stack(
        [
            b_idx_BxTxC.reshape(-1),
            t_clamped_BxTxC.reshape(-1),
            c_idx_BxTxC.reshape(-1),
        ],
        dim=1,
    ).long()  # Ensure indices are long type for indexing

    return t_idx_BxTxC, indices_BTCx3
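# Illustrative values (a small sketch, not part of the original module): with
# delay_pattern=[0, 1], channel 1 reads one step into the past, so t=0 maps to
# -1 and is later masked to BOS:
#
#     t_idx, _ = build_delay_indices(B=1, T=4, C=2, delay_pattern=[0, 1])
#     t_idx[0, :, 0]  # tensor([0, 1, 2, 3], dtype=torch.int32)
#     t_idx[0, :, 1]  # tensor([-1, 0, 1, 2], dtype=torch.int32)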
def apply_audio_delay(
    audio_BxTxC: torch.Tensor,
    pad_value: int,
    bos_value: int,
    precomp: tp.Tuple[torch.Tensor, torch.Tensor],
) -> torch.Tensor:
    """
    Applies the delay pattern to batched audio tokens using precomputed indices,
    inserting BOS where t_idx < 0 and PAD where t_idx >= T.
    Args:
        audio_BxTxC: [B, T, C] int16 audio tokens (or int32/float)
        pad_value: the padding token
        bos_value: the BOS token
        precomp: (t_idx_BxTxC, indices_BTCx3) from build_delay_indices
    Returns:
        result_BxTxC: [B, T, C] delayed audio tokens
    """
    device = audio_BxTxC.device  # Get device from input tensor
    t_idx_BxTxC, indices_BTCx3 = precomp
    t_idx_BxTxC = t_idx_BxTxC.to(device)  # Move precomputed indices to device
    indices_BTCx3 = indices_BTCx3.to(device)

    # Gather with advanced indexing (the PyTorch equivalent of tf.gather_nd);
    # the indices were already cast to long in build_delay_indices.
    gathered_flat = audio_BxTxC[indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]]
    gathered_BxTxC = gathered_flat.view(audio_BxTxC.shape)

    # Create masks on the correct device
    mask_bos = t_idx_BxTxC < 0  # => place bos_value
    mask_pad = t_idx_BxTxC >= audio_BxTxC.shape[1]  # => place pad_value

    # Create scalar tensors on the correct device
    bos_tensor = torch.tensor(bos_value, dtype=audio_BxTxC.dtype, device=device)
    pad_tensor = torch.tensor(pad_value, dtype=audio_BxTxC.dtype, device=device)

    # If mask_bos, BOS; else if mask_pad, PAD; else the gathered original token
    result_BxTxC = torch.where(mask_bos, bos_tensor, torch.where(mask_pad, pad_tensor, gathered_BxTxC))

    return result_BxTxC
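# Illustrative usage (made-up token values; delay_pattern=[0, 1] as above):
# channel 0 passes through unchanged while channel 1 shifts right by one step,
# with BOS filling the vacated first position:
#
#     tokens = torch.tensor([[[10, 20], [11, 21], [12, 22], [13, 23]]])  # [1, 4, 2]
#     precomp = build_delay_indices(B=1, T=4, C=2, delay_pattern=[0, 1])
#     apply_audio_delay(tokens, pad_value=-2, bos_value=-1, precomp=precomp)
#     # channel 0: [10, 11, 12, 13]; channel 1: [-1, 20, 21, 22]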
@torch.inference_mode()  # implies no_grad, so a single decorator suffices
def audio_to_codebook(
    model,
    input_values,
    data_config: DataConfig,
    padding_mask=None,
    sample_rate=44100,
):
    """
    Encodes the input audio waveform into discrete codes and applies the delay pattern.
    Args:
        model: The codec model to use for encoding.
        input_values (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
            Float values of the input audio waveform.
        data_config: `DataConfig` providing the channel count, delay pattern, and BOS/PAD token values.
        padding_mask (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`, *optional*):
            Padding mask used to pad the `input_values`. Currently constructed but not
            consumed by the encode call below.
        sample_rate (`int`, *optional*):
            Sampling rate of the input signal.
    Returns:
        `torch.Tensor` of shape `(batch_size, T, C)`: the encoded codes with the delay
        pattern applied (BOS/PAD tokens inserted per channel).
    """
    audio_data = model.preprocess(input_values, sample_rate)

    if padding_mask is None:
        padding_mask = torch.ones_like(input_values).bool()

    _, encoded_frame, _, _, _ = model.encode(audio_data, n_quantizers=None)  # 1, C, T
    seq_length = encoded_frame.shape[2]

    t_idx_BxTxC, indices_BTCx3 = build_delay_indices(
        B=1,
        T=seq_length,
        C=data_config.channels,
        delay_pattern=data_config.delay_pattern,
    )

    encoded_frame = apply_audio_delay(
        audio_BxTxC=encoded_frame.transpose(1, 2),  # 1, T, C
        pad_value=data_config.audio_pad_value,
        bos_value=data_config.audio_bos_value,
        precomp=(t_idx_BxTxC, indices_BTCx3),
    )

    return encoded_frame
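# Hedged usage sketch (assumes a descript-audio-codec style model whose
# encode() returns a 5-tuple with codes in second position, as unpacked above;
# the variable names here are illustrative):
#
#     waveform = torch.randn(1, 1, 44100)  # one second of audio at 44.1 kHz
#     codes = audio_to_codebook(dac_model, waveform, data_config)
#     codes.shape  # [1, T, data_config.channels]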
def build_revert_indices(B: int, T: int, C: int, delay_pattern: tp.List[int]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
    """
    Precompute indices for the revert operation using PyTorch.
    Returns:
        A tuple (t_idx_BxTxC, indices_BTCx3) where:
        - t_idx_BxTxC is a tensor of shape [B, T, C] computed as time indices plus the delay.
        - indices_BTCx3 is a tensor of shape [B*T*C, 3] used for gathering, computed from:
          batch indices, clamped time indices, and channel indices.
    """
    # Build on the default device; revert_audio_delay moves these indices to
    # the audio tensor's device before use.
    device = None
    delay_arr = torch.tensor(delay_pattern, dtype=torch.int32, device=device)

    t_idx_BxT = torch.broadcast_to(torch.arange(T, device=device).unsqueeze(0), [B, T])
    t_idx_BxTx1 = t_idx_BxT.unsqueeze(-1)
    t_idx_BxTxC = torch.minimum(
        t_idx_BxTx1 + delay_arr.view(1, 1, C),
        torch.tensor(T - 1, device=device),
    )
    b_idx_BxTxC = torch.broadcast_to(torch.arange(B, device=device).view(B, 1, 1), [B, T, C])
    c_idx_BxTxC = torch.broadcast_to(torch.arange(C, device=device).view(1, 1, C), [B, T, C])

    indices_BTCx3 = torch.stack(
        [
            b_idx_BxTxC.reshape(-1),
            t_idx_BxTxC.reshape(-1),
            c_idx_BxTxC.reshape(-1),
        ],
        dim=1,
    ).long()  # Ensure indices are long type

    return t_idx_BxTxC, indices_BTCx3
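# Illustrative values (mirroring the earlier sketch): with delay_pattern=[0, 1]
# the revert indices read one step ahead on channel 1, clamped to the last frame:
#
#     t_idx, _ = build_revert_indices(B=1, T=4, C=2, delay_pattern=[0, 1])
#     t_idx[0, :, 1]  # tensor([1, 2, 3, 3])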
def revert_audio_delay(
    audio_BxTxC: torch.Tensor,
    pad_value: int,
    precomp: tp.Tuple[torch.Tensor, torch.Tensor],
    T: int,
) -> torch.Tensor:
    """
    Reverts a delay pattern from batched audio tokens using precomputed indices (PyTorch version).
    Args:
        audio_BxTxC: Input delayed audio tensor
        pad_value: Padding value for out-of-bounds indices
        precomp: Precomputed revert indices tuple containing:
            - t_idx_BxTxC: Time offset indices tensor
            - indices_BTCx3: Gather indices tensor for original audio
        T: Original sequence length before padding
    Returns:
        Reverted audio tensor with the same shape as the input
    """
    t_idx_BxTxC, indices_BTCx3 = precomp
    device = audio_BxTxC.device  # Get device from input tensor

    # Move precomputed indices to the same device as audio_BxTxC if needed
    t_idx_BxTxC = t_idx_BxTxC.to(device)
    indices_BTCx3 = indices_BTCx3.to(device)

    # Gather with PyTorch advanced indexing (equivalent of tf.gather_nd)
    gathered_flat = audio_BxTxC[indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]]
    gathered_BxTxC = gathered_flat.view(audio_BxTxC.size())

    # Create pad_tensor and the T threshold on the correct device for comparison
    pad_tensor = torch.tensor(pad_value, dtype=audio_BxTxC.dtype, device=device)
    T_tensor = torch.tensor(T, device=device)

    # Positions whose source index falls beyond the original length are padding
    result_BxTxC = torch.where(t_idx_BxTxC >= T_tensor, pad_tensor, gathered_BxTxC)

    return result_BxTxC
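# Round-trip sketch (continuing the made-up example above): reverting a delayed
# sequence recovers the original tokens everywhere except the last max(delay)
# frames, which is why codebook_to_audio trims the tail below:
#
#     delayed = apply_audio_delay(tokens, pad_value=-2, bos_value=-1,
#                                 precomp=build_delay_indices(1, 4, 2, [0, 1]))
#     revert_audio_delay(delayed, pad_value=-2,
#                        precomp=build_revert_indices(1, 4, 2, [0, 1]), T=4)
#     # channel 1 comes back as [20, 21, 22, 22]; only the final frame differs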
@torch.inference_mode()  # implies no_grad, so a single decorator suffices
def decode(
    model,
    audio_codes,
):
    """
    Decodes the given frames into an output audio waveform.
    """
    if len(audio_codes) != 1:
        raise ValueError(f"Expected one frame, got {len(audio_codes)}")

    try:
        audio_values = model.quantizer.from_codes(audio_codes)
        audio_values = model.decode(audio_values[0])
        return audio_values
    except Exception as e:
        print(f"Error in decode method: {e}")
        raise
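# Expected shapes, as a hedged sketch (assuming a DAC-style codec where
# quantizer.from_codes maps codes back to latents and model.decode returns a
# waveform; exact signatures depend on the codec library):
#
#     codebook.shape           # [1, C, T_codes]
#     decode(model, codebook)  # waveform tensor, e.g. [1, 1, n_samples]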
def codebook_to_audio(generated_codes: torch.Tensor, model, delay_pattern, B=1, T=2600, C=9):
    """Process a single generated codebook tensor into an audio waveform."""
    # Remove BOS token
    generated_codes = generated_codes[:, 1:]

    if generated_codes.shape[1] > T:
        generated_codes = generated_codes[:, :T]

    seq_length = generated_codes.shape[1]

    # Build revert indices
    t_idx_BxTxC, indices_BTCx3 = build_revert_indices(B=B, T=seq_length, C=C, delay_pattern=delay_pattern)

    # Transpose [C, T] -> [T, C] and add a batch dimension
    audio_BxTxC = generated_codes.transpose(1, 0).unsqueeze(0)
    reverted_codebook = revert_audio_delay(
        audio_BxTxC=audio_BxTxC,
        pad_value=0,
        precomp=(t_idx_BxTxC, indices_BTCx3),
        T=seq_length,
    )
    # Trim a fixed 30-frame tail: a safety margin covering the delay tail,
    # where reverted positions no longer correspond to real tokens
    reverted_codebook = reverted_codebook[:, :-30, :]

    codebook = reverted_codebook.transpose(1, 2)

    min_valid_index = 0
    max_valid_index = 1023
    invalid_mask = (codebook < min_valid_index) | (codebook > max_valid_index)
    num_invalid = torch.sum(invalid_mask).item()
    if num_invalid > 0:
        print(f"Warning: Clamping {num_invalid} indices outside range [{min_valid_index}, {max_valid_index}] to 0.")
        # Set invalid values to 0 (modifies the tensor in place)
        codebook[invalid_mask] = 0

    audio_array = decode(model, codebook)

    return audio_array
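# Self-contained smoke test of the delay round trip (an added sketch: it needs
# no codec model and uses made-up token values; run with `python -m <package>.audio`
# so the relative import above resolves):
if __name__ == "__main__":
    B, T, C = 1, 6, 3
    delay = [0, 1, 2]
    tokens = torch.arange(B * T * C).view(B, T, C)
    delayed = apply_audio_delay(
        tokens,
        pad_value=-2,
        bos_value=-1,
        precomp=build_delay_indices(B, T, C, delay),
    )
    restored = revert_audio_delay(
        delayed,
        pad_value=-2,
        precomp=build_revert_indices(B, T, C, delay),
        T=T,
    )
    # Everything before the last max(delay) frames should round-trip exactly.
    valid = T - max(delay)
    assert torch.equal(restored[:, :valid, :], tokens[:, :valid, :])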