import typing as tp

import torch

from .config import DataConfig


def build_delay_indices(B: int, T: int, C: int, delay_pattern: tp.List[int]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
    """
    Precompute (t_idx_BxTxC, indices_BTCx3) so that out[t, c] = in[t - delay[c], c].
    Negative t_idx => BOS; t_idx >= T => PAD.
    """
    delay_arr = torch.tensor(delay_pattern, dtype=torch.int32)

    t_idx_BxT = torch.broadcast_to(
        torch.arange(T, dtype=torch.int32)[None, :],
        [B, T],
    )
    t_idx_BxTx1 = t_idx_BxT[..., None]
    t_idx_BxTxC = t_idx_BxTx1 - delay_arr.view(1, 1, C)

    b_idx_BxTxC = torch.broadcast_to(
        torch.arange(B, dtype=torch.int32).view(B, 1, 1),
        [B, T, C],
    )
    c_idx_BxTxC = torch.broadcast_to(
        torch.arange(C, dtype=torch.int32).view(1, 1, C),
        [B, T, C],
    )

    # Clamp time indices into [0, T-1] so the gather never reads out of bounds;
    # out-of-range positions are masked to BOS/PAD later using the unclamped t_idx_BxTxC.
    t_clamped_BxTxC = torch.clamp(t_idx_BxTxC, 0, T - 1)

    # Flattened [B*T*C, 3] (batch, time, channel) indices for advanced indexing.
    indices_BTCx3 = torch.stack(
        [
            b_idx_BxTxC.reshape(-1),
            t_clamped_BxTxC.reshape(-1),
            c_idx_BxTxC.reshape(-1),
        ],
        dim=1,
    ).long()

    return t_idx_BxTxC, indices_BTCx3


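# Illustrative sketch (not part of the original module): for T=4 and delay_pattern=[0, 2],
# channel 1 reads two steps into the past, so its time indices are [-2, -1, 0, 1]; the first
# two positions are negative and will be filled with BOS by apply_audio_delay below.
#
#   t_idx, idx = build_delay_indices(B=1, T=4, C=2, delay_pattern=[0, 2])
#   # t_idx[0, :, 0] == tensor([0, 1, 2, 3])
#   # t_idx[0, :, 1] == tensor([-2, -1, 0, 1])

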
def apply_audio_delay(
    audio_BxTxC: torch.Tensor,
    pad_value: int,
    bos_value: int,
    precomp: tp.Tuple[torch.Tensor, torch.Tensor],
) -> torch.Tensor:
    """
    Applies the delay pattern to batched audio tokens using precomputed indices,
    inserting BOS where t_idx < 0 and PAD where t_idx >= T.

    Args:
        audio_BxTxC: [B, T, C] int16 audio tokens (or int32/float)
        pad_value: the padding token
        bos_value: the BOS token
        precomp: (t_idx_BxTxC, indices_BTCx3) from build_delay_indices

    Returns:
        result_BxTxC: [B, T, C] delayed audio tokens
    """
    device = audio_BxTxC.device
    t_idx_BxTxC, indices_BTCx3 = precomp
    t_idx_BxTxC = t_idx_BxTxC.to(device)
    indices_BTCx3 = indices_BTCx3.to(device)

    # Gather per-(batch, time, channel) values from the clamped source positions.
    gathered_flat = audio_BxTxC[indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]]
    gathered_BxTxC = gathered_flat.view(audio_BxTxC.shape)

    # Positions that fall before the start of the sequence become BOS; past the end become PAD.
    mask_bos = t_idx_BxTxC < 0
    mask_pad = t_idx_BxTxC >= audio_BxTxC.shape[1]

    # Scalar fill values on the right device/dtype so torch.where can broadcast them.
    bos_tensor = torch.tensor(bos_value, dtype=audio_BxTxC.dtype, device=device)
    pad_tensor = torch.tensor(pad_value, dtype=audio_BxTxC.dtype, device=device)

    # BOS where t_idx < 0, PAD where t_idx >= T, otherwise the gathered (delayed) token.
    result_BxTxC = torch.where(mask_bos, bos_tensor, torch.where(mask_pad, pad_tensor, gathered_BxTxC))

    return result_BxTxC


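# Minimal usage sketch (toy tensor and placeholder BOS/PAD ids, not from the original file):
#
#   codes = torch.arange(8, dtype=torch.int32).reshape(1, 4, 2)   # [B=1, T=4, C=2]
#   precomp = build_delay_indices(B=1, T=4, C=2, delay_pattern=[0, 2])
#   delayed = apply_audio_delay(codes, pad_value=-1, bos_value=-2, precomp=precomp)
#   # delayed[0, :, 0] is unchanged; delayed[0, 0:2, 1] == bos_value, and
#   # delayed[0, 2:, 1] holds channel 1's original tokens from t = 0 and t = 1.

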
@torch.no_grad()
@torch.inference_mode()
def audio_to_codebook(
    model,
    input_values,
    data_config: DataConfig,
    padding_mask=None,
    sample_rate=44100,
):
    """
    Encodes the input audio waveform into discrete codes and applies the delay pattern.

    Args:
        model: The codec model to use for encoding.
        input_values (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
            Float values of the input audio waveform.
        data_config: DataConfig providing the channel count, delay pattern, and BOS/PAD values.
        padding_mask (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`, *optional*):
            Padding mask for `input_values`. Currently not forwarded to the encoder.
        sample_rate (`int`, *optional*):
            Sampling rate of the input signal.

    Returns:
        encoded_frame (`torch.Tensor` of shape `[batch_size, frames, num_codebooks]`):
            Delay-patterned codebook indices for the input audio waveform.
    """
    audio_data = model.preprocess(input_values, sample_rate)

    if padding_mask is None:
        padding_mask = torch.ones_like(input_values).bool()

    # The encoder returns a 5-tuple; only the second element (codebook indices, [B, C, T]) is used.
    _, encoded_frame, _, _, _ = model.encode(audio_data, n_quantizers=None)
    seq_length = encoded_frame.shape[2]

    t_idx_BxTxC, indices_BTCx3 = build_delay_indices(
        B=1,
        T=seq_length,
        C=data_config.channels,
        delay_pattern=data_config.delay_pattern,
    )

    encoded_frame = apply_audio_delay(
        audio_BxTxC=encoded_frame.transpose(1, 2),  # [B, T, C]
        pad_value=data_config.audio_pad_value,
        bos_value=data_config.audio_bos_value,
        precomp=(t_idx_BxTxC, indices_BTCx3),
    )

    return encoded_frame


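# Hedged usage sketch (assumes `codec_model` exposes the preprocess/encode interface used
# above and `data_config` is a DataConfig; the names are illustrative, not from this file):
#
#   waveform = torch.randn(1, 1, 44100)                       # ~1 second of mono audio
#   delayed_codes = audio_to_codebook(codec_model, waveform, data_config)
#   # delayed_codes: [1, frames, data_config.channels], with BOS at the top of each delayed channel.

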
def build_revert_indices(B: int, T: int, C: int, delay_pattern: tp.List[int]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
    """
    Precompute indices for the revert (un-delay) operation using PyTorch.

    Returns:
        A tuple (t_idx_BxTxC, indices_BTCx3) where:
            - t_idx_BxTxC is a tensor of shape [B, T, C] holding the time indices plus the
              per-channel delay, clamped to T - 1.
            - indices_BTCx3 is a tensor of shape [B*T*C, 3] used for gathering, built from the
              batch indices, clamped time indices, and channel indices.
    """
    # Indices are built on CPU; revert_audio_delay moves them to the audio tensor's device.
    delay_arr = torch.tensor(delay_pattern, dtype=torch.int32)

    t_idx_BT1 = torch.broadcast_to(torch.arange(T).unsqueeze(0), [B, T])
    t_idx_BT1 = t_idx_BT1.unsqueeze(-1)

    # The source position for out[t, c] is in[t + delay[c], c], clamped to the last frame.
    t_idx_BxTxC = torch.minimum(
        t_idx_BT1 + delay_arr.view(1, 1, C),
        torch.tensor(T - 1),
    )
    b_idx_BxTxC = torch.broadcast_to(torch.arange(B).view(B, 1, 1), [B, T, C])
    c_idx_BxTxC = torch.broadcast_to(torch.arange(C).view(1, 1, C), [B, T, C])

    indices_BTCx3 = torch.stack(
        [
            b_idx_BxTxC.reshape(-1),
            t_idx_BxTxC.reshape(-1),
            c_idx_BxTxC.reshape(-1),
        ],
        dim=1,
    ).long()

    return t_idx_BxTxC, indices_BTCx3


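# Illustrative sketch (not in the original module): with T=4 and delay_pattern=[0, 2],
# channel 1 reads two frames ahead and the last positions clamp to T - 1:
#
#   t_idx, idx = build_revert_indices(B=1, T=4, C=2, delay_pattern=[0, 2])
#   # t_idx[0, :, 0] == tensor([0, 1, 2, 3])
#   # t_idx[0, :, 1] == tensor([2, 3, 3, 3])

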
def revert_audio_delay(
    audio_BxTxC: torch.Tensor,
    pad_value: int,
    precomp: tp.Tuple[torch.Tensor, torch.Tensor],
    T: int,
) -> torch.Tensor:
    """
    Reverts a delay pattern from batched audio tokens using precomputed indices (PyTorch version).

    Args:
        audio_BxTxC: Input delayed audio tensor
        pad_value: Padding value for out-of-bounds indices
        precomp: Precomputed revert indices tuple containing:
            - t_idx_BxTxC: Time offset indices tensor
            - indices_BTCx3: Gather indices tensor for original audio
        T: Original sequence length before padding

    Returns:
        Reverted audio tensor with the same shape as the input
    """
    t_idx_BxTxC, indices_BTCx3 = precomp
    device = audio_BxTxC.device

    # Move the precomputed indices to the same device as the audio tensor.
    t_idx_BxTxC = t_idx_BxTxC.to(device)
    indices_BTCx3 = indices_BTCx3.to(device)

    # Gather the future (t + delay) positions back into place.
    gathered_flat = audio_BxTxC[indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]]
    gathered_BxTxC = gathered_flat.view(audio_BxTxC.size())

    # Positions whose source index falls at or beyond T are replaced with the pad value.
    pad_tensor = torch.tensor(pad_value, dtype=audio_BxTxC.dtype, device=device)
    T_tensor = torch.tensor(T, device=device)

    result_BxTxC = torch.where(t_idx_BxTxC >= T_tensor, pad_tensor, gathered_BxTxC)

    return result_BxTxC


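# Round-trip sketch (illustrative, not part of the original file): applying the delay and then
# reverting it reproduces the original tokens at every frame whose source position stays inside
# the sequence; the last max(delay_pattern) frames may differ (BOS/clamped values).
#
#   codes = torch.arange(8, dtype=torch.int32).reshape(1, 4, 2)        # [B=1, T=4, C=2]
#   fwd = build_delay_indices(B=1, T=4, C=2, delay_pattern=[0, 2])
#   rev = build_revert_indices(B=1, T=4, C=2, delay_pattern=[0, 2])
#   delayed = apply_audio_delay(codes, pad_value=-1, bos_value=-2, precomp=fwd)
#   restored = revert_audio_delay(delayed, pad_value=-1, precomp=rev, T=4)
#   # restored[:, : 4 - 2, :] equals codes[:, : 4 - 2, :]

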
@torch.no_grad()
@torch.inference_mode()
def decode(
    model,
    audio_codes,
):
    """
    Decodes the given frames into an output audio waveform.
    """
    if len(audio_codes) != 1:
        raise ValueError(f"Expected one frame, got {len(audio_codes)}")

    try:
        # quantizer.from_codes maps codebook indices back to latents; only its first returned
        # element is passed on to model.decode.
        audio_values = model.quantizer.from_codes(audio_codes)
        audio_values = model.decode(audio_values[0])

        return audio_values
    except Exception as e:
        print(f"Error in decode method: {str(e)}")
        raise


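# Usage note (an assumption based on how decode is called below in codebook_to_audio):
# `audio_codes` is expected to be a [1, num_codebooks, frames] tensor of valid codebook
# indices, and the return value is the decoded waveform tensor produced by the codec model.

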
def codebook_to_audio(generated_codes: torch.Tensor, model, delay_pattern, B=1, T=2600, C=9):
    """Convert a single [C, T] tensor of generated, delay-patterned codes into an audio waveform."""
    # Drop the first generated frame and cap the sequence length at T.
    generated_codes = generated_codes[:, 1:]

    if generated_codes.shape[1] > T:
        generated_codes = generated_codes[:, :T]

    seq_length = generated_codes.shape[1]

    t_idx_BxTxC, indices_BTCx3 = build_revert_indices(B=B, T=seq_length, C=C, delay_pattern=delay_pattern)

    # [C, T] -> [1, T, C] so the revert indices can be applied.
    audio_BxTxC = generated_codes.transpose(1, 0).unsqueeze(0)
    reverted_codebook = revert_audio_delay(
        audio_BxTxC=audio_BxTxC,
        pad_value=0,
        precomp=(t_idx_BxTxC, indices_BTCx3),
        T=seq_length,
    )
    # Drop the trailing 30 frames of the reverted sequence.
    reverted_codebook = reverted_codebook[:, :-30, :]

    codebook = reverted_codebook.transpose(1, 2)  # [1, C, T']

    # Clamp any indices outside the valid codebook range [0, 1023] to 0 before decoding.
    min_valid_index = 0
    max_valid_index = 1023
    invalid_mask = (codebook < min_valid_index) | (codebook > max_valid_index)

    num_invalid = torch.sum(invalid_mask).item()
    if num_invalid > 0:
        print(f"Warning: Clamping {num_invalid} indices outside range [{min_valid_index}, {max_valid_index}] to 0.")

    codebook[invalid_mask] = 0
    audio_array = decode(model, codebook)

    return audio_array


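# Hedged end-to-end sketch (names are illustrative; `generated_codes` would come from the
# model's sampling loop as a [C, T] tensor of delayed tokens):
#
#   waveform = codebook_to_audio(
#       generated_codes,
#       codec_model,
#       delay_pattern=data_config.delay_pattern,
#       C=data_config.channels,
#   )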