import torch from modules.audio_detokenizer.bigvgan_wrapper import BigVGANWrapper from modules.audio_detokenizer.semantic_fm_prefix_streaming import StreamingSemanticFMWrapper class PrefixStreamingFlowMatchingDetokenizer: def __init__(self, vocoder: BigVGANWrapper, fm: StreamingSemanticFMWrapper, look_ahead_tokens: int = 0) -> None: self.dtype = torch.bfloat16 print("Currently using bfloat16 for PrefixFlowMatchingDetokenizer") self.vocoder = vocoder self.vocoder.to_dtype(self.dtype) self.semantic_fm = fm # initialize mel_spec self.max_pos_size = 4096 self.is_timbre_semantic_token = False self.pre_mel = None self.frame_size = 480 # how many samples in a frame self.pre_wav = None self.state_dict_backup = None self.hamming_window_cache = {} self.previous_chunk_left = None self.look_ahead_tokens = look_ahead_tokens self.clear_states() @classmethod def from_pretrained(cls, vocoder_config, vocoder_ckpt, fm_config, fm_ckpt, device, look_ahead_tokens=0, max_prompt_chunk=2, max_kv_cache_tokens=900, use_cfg=False, use_cfg_rescale=True, cfg_init=1.5, cfg_scale=7.5, cfg_schedule="linear"): bigvgan = BigVGANWrapper.from_pretrained(vocoder_config, vocoder_ckpt, device) semantic_fm = StreamingSemanticFMWrapper.from_pretrained(fm_config, fm_ckpt, device, max_prompt_chunk=max_prompt_chunk, max_kv_cache_tokens=max_kv_cache_tokens, use_cfg=use_cfg, cfg_scale=cfg_scale, use_cfg_rescale=use_cfg_rescale, cfg_init=cfg_init, cfg_schedule=cfg_schedule) return cls(bigvgan, semantic_fm, look_ahead_tokens=look_ahead_tokens) @torch.inference_mode() def prefill(self, timbre_speech, timbre_semantic_token, chunk_size: int, timbre_mel=None): """ Arguments: timbre_speech: torch.Tensor, shape [B, N_speech_24k] timbre_semantic_token: torch.Tensor, shape [B, N] chunk_size: int, chunk size for prefilling timbre_mel: torch.Tensor, shape [B, N, 80], optional, if not None, use this mel spectrogram instead of extracting from timbre_speech """ if timbre_mel is None: assert timbre_speech is not None, "timbre_speech should not be None if timbre_mel is not None" assert len(timbre_semantic_token.shape) == 2 and len(timbre_speech.shape) == 2 and chunk_size > 0 assert timbre_speech.shape[0] == 1 and timbre_semantic_token.shape[0] == 1 mel_spec = self.vocoder.extract_mel_from_wav(wav_data=timbre_speech.squeeze(0)) else: assert len(timbre_mel.shape) == 3 and len(timbre_semantic_token.shape) == 2 and chunk_size > 0 assert timbre_mel.shape[0] == 1 and timbre_semantic_token.shape[0] == 1 mel_spec = timbre_mel.squeeze(0) if mel_spec.shape[0] < timbre_semantic_token.shape[1]: # pad mel_spec mel_spec = torch.nn.functional.pad(mel_spec, (0, 0, 0, timbre_semantic_token.shape[1] - mel_spec.shape[0])) elif mel_spec.shape[0] > timbre_semantic_token.shape[1]: # truncate mel_spec mel_spec = mel_spec[:timbre_semantic_token.shape[1], :] # clear all states self.semantic_fm.clear_all_states() self.semantic_fm.prefill(mel_spec, timbre_semantic_token.squeeze(0), chunk_size=chunk_size, verbose=False) self.state_dict_backup = self.semantic_fm.state_dict() @torch.inference_mode() def detokenize_streaming(self, semantic_token, ode_step=30, verbose=False, ode_solver="neural_ode_euler", is_final=False, upsample_factor=1): assert len(semantic_token.shape) == 2 and ode_step > 0 assert semantic_token.shape[0] == 1 semantic_token = semantic_token.repeat_interleave(upsample_factor, dim=1) semantic_token = semantic_token.squeeze(0) if self.look_ahead_tokens != 0 and self.previous_chunk_left is not None: semantic_token_previous = self.previous_chunk_left["semantic_token"] semantic_token = torch.cat([semantic_token_previous, semantic_token], dim=-1) x_t_chunk = torch.randn(semantic_token.shape[0], 80).to(semantic_token.device).to(self.dtype) if self.look_ahead_tokens != 0 and self.previous_chunk_left is None: self.previous_chunk_left = {"semantic_token": None} speech_mel = self.semantic_fm.infer_chunk( xt_chunk=x_t_chunk, semantic_tokens_chunk=semantic_token, start_position_id=self.semantic_fm.start_position_id, ode_steps=ode_step, verbose=verbose, look_ahead_tokens=self.look_ahead_tokens * upsample_factor if not is_final else 0, cache=self.previous_chunk_left, ode_solver=ode_solver ) chunk_size = speech_mel.shape[0] length = speech_mel.shape[0] self.semantic_fm.start_position_id += length self.semantic_fm.update_incremental_state() self.semantic_fm.reserve_kv_cache_tokens += self.semantic_fm.ode_wrapper.kv_cache_tokens # smoothing # I will maintain the history of seqlen wav # For the first chunk, I will only return the half chunk wav, and save the res half chunk in history # For the rest requests, I will concat the generated wav with the history, output one chunk of the history, save the if self.pre_mel is None: # first chunk, related to TTFB concat_mel = speech_mel concat_reconstructed_wav = self.vocoder.decode_mel(concat_mel) if is_final: self.clear_states() self.state_dict_backup = None ret_wav = concat_reconstructed_wav.float() else: reconstructed_wav = concat_reconstructed_wav[:, :int(self.frame_size * chunk_size // 2)] # return the first half chunk self.pre_wav = concat_reconstructed_wav[:, -int(self.frame_size * chunk_size // 2):] # log the last half chunk for next generation step self.pre_mel = speech_mel[-chunk_size//2:, :] ret_wav = reconstructed_wav.float() else: concat_mel = torch.cat([self.pre_mel, speech_mel], dim=0) concat_reconstructed_wav = self.vocoder.decode_mel(concat_mel) if is_final: self.clear_states() self.state_dict_backup = None ret_wav = concat_reconstructed_wav.float() else: # fetch history prev_speech_len = self.pre_wav.shape[1] if concat_reconstructed_wav.shape[1] > prev_speech_len * 2: gen_speech_len = prev_speech_len * 2 else: gen_speech_len = concat_reconstructed_wav.shape[1] // 2 reconstructed_wav = concat_reconstructed_wav[:, :gen_speech_len] # return the first half chunk if gen_speech_len not in self.hamming_window_cache: self.hamming_window_cache[gen_speech_len] = torch.hamming_window(gen_speech_len).to(self.dtype).to(semantic_token.device).unsqueeze(0) hamming_window = self.hamming_window_cache[gen_speech_len] # apply smoothing of the first half chunk reconstructed_wav[:, :int(gen_speech_len // 2 )] = self.pre_wav[:, :int(gen_speech_len // 2 )] * hamming_window[:,-int(gen_speech_len // 2):] + \ reconstructed_wav[:, :int(gen_speech_len // 2)] * hamming_window[:, :int(gen_speech_len // 2)] res_speech_len = concat_reconstructed_wav.shape[1] - gen_speech_len res_mel_len = res_speech_len // self.frame_size self.pre_wav = concat_reconstructed_wav[:, -res_speech_len:] self.pre_mel = speech_mel[-res_mel_len:, :] ret_wav = reconstructed_wav.float() if not is_final and self.semantic_fm.start_position_id + 2*chunk_size > self.max_pos_size: # out of position id, self.semantic_fm.clear_all_states() self.semantic_fm.load_state_dict(self.state_dict_backup) return ret_wav def clear_states(self): self.semantic_fm.clear_all_states() self.previous_chunk_left = None self.pre_mel = None self.pre_wav = None def get_audio_detokenizer(): fm_model_config = "resources/audio_detokenizer/config.yaml" fm_ckpt_path = "resources/audio_detokenizer/model.pt" bigvgan_config_file = "resources/vocoder/config.json" bigvgan_ckpt_path = "resources/vocoder/model.pt" device=torch.cuda.current_device() detokenizer = PrefixStreamingFlowMatchingDetokenizer.from_pretrained( vocoder_config=bigvgan_config_file, vocoder_ckpt=bigvgan_ckpt_path, max_prompt_chunk=10, # 10 * 3 = 30s fm_config=fm_model_config, fm_ckpt=fm_ckpt_path, device=device, use_cfg=False, look_ahead_tokens=12) return detokenizer def detokenize(detokenizer, tokens, ref_wav, ref_tokens, streaming=False): with torch.no_grad(): detokenizer.clear_states() detokenizer.prefill(ref_wav, ref_tokens, chunk_size=150) cache_speech_collection = [] chunk_size = 150 first_chunk_size = 100 first_chunk_tokens = tokens[:, :first_chunk_size] gen_speech = detokenizer.detokenize_streaming(first_chunk_tokens, is_final=tokens.size(1) <= first_chunk_size) if streaming: yield gen_speech else: cache_speech_collection.append(gen_speech) res_tokens = tokens[:, first_chunk_size:] for i in range(0, res_tokens.size(1), chunk_size): chunk_tokens = res_tokens[:, i:i+chunk_size] gen_speech = detokenizer.detokenize_streaming(chunk_tokens, is_final=(i+chunk_size >= res_tokens.size(1))) if streaming: yield gen_speech else: cache_speech_collection.append(gen_speech) if not streaming: gen_speech_all = torch.cat(cache_speech_collection, dim=-1) return gen_speech_all def detokenize_noref(detokenizer, tokens, streaming=False): with torch.no_grad(): detokenizer.clear_states() cache_speech_collection = [] chunk_size = 150 first_chunk_size = 100 first_chunk_tokens = tokens[:, :first_chunk_size] gen_speech = detokenizer.detokenize_streaming(first_chunk_tokens, is_final=tokens.size(1) <= first_chunk_size) if streaming: yield gen_speech else: cache_speech_collection.append(gen_speech) res_tokens = tokens[:, first_chunk_size:] for i in range(0, res_tokens.size(1), chunk_size): chunk_tokens = res_tokens[:, i:i+chunk_size] gen_speech = detokenizer.detokenize_streaming(chunk_tokens, is_final=(i+chunk_size >= res_tokens.size(1))) if streaming: yield gen_speech else: cache_speech_collection.append(gen_speech) if not streaming: gen_speech_all = torch.cat(cache_speech_collection, dim=-1) return gen_speech_all