|
import torch
|
|
from transformers import AutoTokenizer, AutoModel, Wav2Vec2FeatureExtractor
|
|
|
|
class AstralQuantizer(torch.nn.Module):
    """Discretize 16 kHz speech: truncated SSL model -> ``encoder`` -> ``quantizer``.

    Waveforms are turned into hidden states by an SSL model whose encoder is
    cut after ``ssl_output_layer`` layers. These states are passed through
    ``encoder``, and the result is quantized by ``quantizer``.
    """

    # Number of input samples per SSL output frame. It converts the
    # sample-level attention mask into per-item frame counts. Presumably this
    # is the total conv stride of a wav2vec2/HuBERT-style feature encoder
    # (320 samples = 20 ms at 16 kHz). Override it in a subclass if the
    # chosen SSL model downsamples differently.
    SSL_HOP_SIZE: int = 320

    def __init__(
        self,
        tokenizer_name: str,
        ssl_model_name: str,
        ssl_output_layer: int,
        encoder: torch.nn.Module,
        quantizer: torch.nn.Module,
        skip_ssl: bool = False,
    ):
        """Build the quantizer and (optionally) load the SSL model.

        Args:
            tokenizer_name: HF identifier passed to ``AutoTokenizer.from_pretrained``;
                the tokenizer is stored as ``self.tokenizer`` (not used by the
                methods of this class).
            ssl_model_name: HF identifier used for both the
                ``Wav2Vec2FeatureExtractor`` and the ``AutoModel`` SSL backbone.
            ssl_output_layer: number of SSL encoder layers to keep
                (``encoder.layers[:ssl_output_layer]``).
            encoder: module called as ``encoder(features, feature_lens)`` on
                channels-first SSL features.
            quantizer: module whose first two outputs are taken as
                ``(x_quantized, indices)``.
            skip_ssl: if True, do not load the SSL model here; one must then be
                passed to ``forward`` via ``ssl_model``.
        """
        super().__init__()
        self.encoder = encoder
        self.quantizer = quantizer

        self.tokenizer_name = tokenizer_name
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

        self.ssl_model_name = ssl_model_name
        self.ssl_output_layer = ssl_output_layer
        self.ssl_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(ssl_model_name)

        if skip_ssl:
            self.ssl_model = None
        else:
            self.ssl_model = AutoModel.from_pretrained(ssl_model_name).eval()
            # Drop every layer after ``ssl_output_layer`` so that
            # ``last_hidden_state`` presumably becomes that layer's output.
            self.ssl_model.encoder.layers = self.ssl_model.encoder.layers[:ssl_output_layer]
            # Disable the encoder-level layer norm.
            # NOTE(review): in some HF architectures this norm is applied
            # before the transformer layers rather than after them. Confirm
            # this is intended for the chosen SSL model.
            self.ssl_model.encoder.layer_norm = torch.nn.Identity()

    @staticmethod
    def _strip_module_prefix(state_dict, prefix="module."):
        """Return a copy of ``state_dict`` with a leading ``prefix`` removed from keys.

        When both ``"module.x"`` and ``"x"`` are present, the prefixed entry
        wins, matching the original in-place rewrite.
        """
        stripped = {k: v for k, v in state_dict.items() if not k.startswith(prefix)}
        stripped.update(
            {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
        )
        return stripped

    def load_separate_checkpoint(self, checkpoint_path):
        """Load per-submodule weights from a checkpoint shaped ``{'net': {...}}``.

        ``net['encoder']`` is loaded into ``self.encoder`` and ``net['vq']``
        into ``self.quantizer``. ``net['decoder']`` and ``net['predictor']``
        are loaded only when the instance has a non-None ``decoder`` or
        ``asr_decoder`` attribute. This class never creates those attributes,
        so they are treated as optional (e.g. provided by a subclass). A
        leading ``"module."`` is stripped from every parameter name in each
        dict entry; that prefix is typically added by DataParallel/DDP wrapping.

        Args:
            checkpoint_path: path passed to ``torch.load``.

        Raises:
            KeyError: if ``'net'`` or a required sub-entry is missing.
        """
        # NOTE(review): torch.load unpickles the file. Only load trusted
        # checkpoints (consider weights_only=True where supported).
        raw = torch.load(checkpoint_path, map_location='cpu')['net']
        params = {
            name: self._strip_module_prefix(sd) if isinstance(sd, dict) else sd
            for name, sd in raw.items()
        }

        self.encoder.load_state_dict(params['encoder'])
        self.quantizer.load_state_dict(params['vq'])

        # nn.Module.__getattr__ raises AttributeError for unset attributes, so
        # a getattr default is needed here; plain ``self.decoder`` would crash.
        decoder = getattr(self, 'decoder', None)
        if decoder is not None:
            decoder.load_state_dict(params['decoder'])

        asr_decoder = getattr(self, 'asr_decoder', None)
        if asr_decoder is not None:
            asr_decoder.load_state_dict(params['predictor'], strict=False)

    def forward(self, waves_16k, wave_16k_lens, ssl_model=None):
        """Quantize a padded batch of 16 kHz waveforms.

        Args:
            waves_16k: padded waveform batch; item ``b`` is
                ``waves_16k[b, :wave_16k_lens[b]]``.
            wave_16k_lens: number of valid samples per batch item.
            ssl_model: SSL model used only when this instance holds none
                (i.e. it was built with ``skip_ssl=True``).

        Returns:
            ``(x_quantized, indices, feature_lens)``, where ``feature_lens``
            holds the number of valid SSL frames per item.

        Raises:
            ValueError: if neither an in-class nor an external SSL model is available.
        """
        ssl_fn = self.ssl_model if self.ssl_model is not None else ssl_model
        if ssl_fn is None:
            raise ValueError(
                "In case in-class SSL model loading is skipped, external ssl_model must be provided"
            )

        # The HF feature extractor takes one unpadded numpy array per item.
        # detach() lets this work on inputs that require grad; the numpy
        # round-trip breaks the autograd graph anyway.
        waves_16k_input_list = [
            waves_16k[bib, :wave_16k_lens[bib]].detach().cpu().numpy()
            for bib in range(len(waves_16k))
        ]
        alt_inputs = self.ssl_feature_extractor(
            waves_16k_input_list,
            return_tensors='pt',
            return_attention_mask=True,
            padding=True,
            sampling_rate=16000,
        ).to(waves_16k.device)
        feature_lens = alt_inputs.data['attention_mask'].sum(-1) // self.SSL_HOP_SIZE

        outputs = ssl_fn(
            alt_inputs.input_values,
            attention_mask=alt_inputs.attention_mask,
        )
        last_hidden_states = outputs.last_hidden_state
        # Trim trailing frames beyond the longest item, and keep lengths within
        # the number of frames the model actually produced.
        last_hidden_states = last_hidden_states[:, :feature_lens.max(), :]
        feature_lens = feature_lens.clamp(max=last_hidden_states.size(1))

        # Swap the last two axes before and after the encoder. Presumably this
        # converts HF (batch, frames, hidden) to channels-first for ``encoder``
        # and back. TODO confirm the encoder's expected layout.
        x_hidden = self.encoder(last_hidden_states.transpose(1, 2), feature_lens)
        x_hidden = x_hidden.transpose(1, 2)
        x_quantized, indices = self.quantizer(x_hidden)[:2]
        return x_quantized, indices, feature_lens