""" Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved. This source code is licensed under the license found in the LICENSE file in the root directory of this source tree. """ import json from typing import Callable, Optional import torch import torch.nn as nn from einops import rearrange from einops.layers.torch import Rearrange from model.guide import GuideTransformer from model.modules.audio_encoder import Wav2VecEncoder from model.modules.rotary_embedding_torch import RotaryEmbedding from model.modules.transformer_modules import ( DecoderLayerStack, FiLMTransformerDecoderLayer, RegressionTransformer, TransformerEncoderLayerRotary, ) from model.utils import ( init_weight, PositionalEncoding, prob_mask_like, setup_lip_regressor, SinusoidalPosEmb, ) from model.vqvae import setup_tokenizer from torch.nn import functional as F from utils.misc import prGreen, prRed class Audio2LipRegressionTransformer(torch.nn.Module): def __init__( self, n_vertices: int = 338, causal: bool = False, train_wav2vec: bool = False, transformer_encoder_layers: int = 2, transformer_decoder_layers: int = 4, ): super().__init__() self.n_vertices = n_vertices self.audio_encoder = Wav2VecEncoder() if not train_wav2vec: self.audio_encoder.eval() for param in self.audio_encoder.parameters(): param.requires_grad = False self.regression_model = RegressionTransformer( transformer_encoder_layers=transformer_encoder_layers, transformer_decoder_layers=transformer_decoder_layers, d_model=512, d_cond=512, num_heads=4, causal=causal, ) self.project_output = torch.nn.Linear(512, self.n_vertices * 3) def forward(self, audio): """ :param audio: tensor of shape B x T x 1600 :return: tensor of shape B x T x n_vertices x 3 containing reconstructed lip geometry """ B, T = audio.shape[0], audio.shape[1] cond = self.audio_encoder(audio) x = torch.zeros(B, T, 512, device=audio.device) x = self.regression_model(x, cond) x = self.project_output(x) verts = x.view(B, T, self.n_vertices, 3) return verts class FiLMTransformer(nn.Module): def __init__( self, args, nfeats: int, latent_dim: int = 512, ff_size: int = 1024, num_layers: int = 4, num_heads: int = 4, dropout: float = 0.1, cond_feature_dim: int = 4800, activation: Callable[[torch.Tensor], torch.Tensor] = F.gelu, use_rotary: bool = True, cond_mode: str = "audio", split_type: str = "train", device: str = "cuda", **kwargs, ) -> None: super().__init__() self.nfeats = nfeats self.cond_mode = cond_mode self.cond_feature_dim = cond_feature_dim self.add_frame_cond = args.add_frame_cond self.data_format = args.data_format self.split_type = split_type self.device = device # positional embeddings self.rotary = None self.abs_pos_encoding = nn.Identity() # if rotary, replace absolute embedding with a rotary embedding instance (absolute becomes an identity) if use_rotary: self.rotary = RotaryEmbedding(dim=latent_dim) else: self.abs_pos_encoding = PositionalEncoding( latent_dim, dropout, batch_first=True ) # time embedding processing self.time_mlp = nn.Sequential( SinusoidalPosEmb(latent_dim), nn.Linear(latent_dim, latent_dim * 4), nn.Mish(), ) self.to_time_cond = nn.Sequential( nn.Linear(latent_dim * 4, latent_dim), ) self.to_time_tokens = nn.Sequential( nn.Linear(latent_dim * 4, latent_dim * 2), Rearrange("b (r d) -> b r d", r=2), ) # null embeddings for guidance dropout self.seq_len = args.max_seq_length emb_len = 1998 # hardcoded for now self.null_cond_embed = nn.Parameter(torch.randn(1, emb_len, latent_dim)) self.null_cond_hidden = nn.Parameter(torch.randn(1, latent_dim)) 
        self.norm_cond = nn.LayerNorm(latent_dim)
        self.setup_audio_models()

        # set up pose/face specific parts of the model
        self.input_projection = nn.Linear(self.nfeats, latent_dim)
        if self.data_format == "pose":
            cond_feature_dim = 1024
            key_feature_dim = 104
            self.step = 30
            self.use_cm = True
            self.setup_guide_models(args, latent_dim, key_feature_dim)
            self.post_pose_layers = self._build_single_pose_conv(self.nfeats)
            self.post_pose_layers.apply(init_weight)
            self.final_conv = torch.nn.Conv1d(self.nfeats, self.nfeats, kernel_size=1)
            self.receptive_field = 25
        elif self.data_format == "face":
            self.use_cm = False
            cond_feature_dim = 1024 + 1014
            self.setup_lip_models()
            self.cond_encoder = nn.Sequential()
            for _ in range(2):
                self.cond_encoder.append(
                    TransformerEncoderLayerRotary(
                        d_model=latent_dim,
                        nhead=num_heads,
                        dim_feedforward=ff_size,
                        dropout=dropout,
                        activation=activation,
                        batch_first=True,
                        rotary=self.rotary,
                    )
                )
            self.cond_encoder.apply(init_weight)
        self.cond_projection = nn.Linear(cond_feature_dim, latent_dim)
        self.non_attn_cond_projection = nn.Sequential(
            nn.LayerNorm(latent_dim),
            nn.Linear(latent_dim, latent_dim),
            nn.SiLU(),
            nn.Linear(latent_dim, latent_dim),
        )

        # decoder
        decoderstack = nn.ModuleList([])
        for _ in range(num_layers):
            decoderstack.append(
                FiLMTransformerDecoderLayer(
                    latent_dim,
                    num_heads,
                    dim_feedforward=ff_size,
                    dropout=dropout,
                    activation=activation,
                    batch_first=True,
                    rotary=self.rotary,
                    use_cm=self.use_cm,
                )
            )
        self.seqTransDecoder = DecoderLayerStack(decoderstack)
        self.seqTransDecoder.apply(init_weight)
        self.final_layer = nn.Linear(latent_dim, self.nfeats)
        self.final_layer.apply(init_weight)

    def _build_single_pose_conv(self, nfeats: int) -> nn.ModuleList:
        post_pose_layers = torch.nn.ModuleList(
            [
                torch.nn.Conv1d(nfeats, max(256, nfeats), kernel_size=3, dilation=1),
                torch.nn.Conv1d(max(256, nfeats), nfeats, kernel_size=3, dilation=2),
                torch.nn.Conv1d(nfeats, nfeats, kernel_size=3, dilation=3),
                torch.nn.Conv1d(nfeats, nfeats, kernel_size=3, dilation=1),
                torch.nn.Conv1d(nfeats, nfeats, kernel_size=3, dilation=2),
                torch.nn.Conv1d(nfeats, nfeats, kernel_size=3, dilation=3),
            ]
        )
        return post_pose_layers

    def _run_single_pose_conv(self, output: torch.Tensor) -> torch.Tensor:
        output = torch.nn.functional.pad(output, pad=[self.receptive_field - 1, 0])
        for _, layer in enumerate(self.post_pose_layers):
            y = torch.nn.functional.leaky_relu(layer(output), negative_slope=0.2)
            if self.split_type == "train":
                y = torch.nn.functional.dropout(y, 0.2)
            if output.shape[1] == y.shape[1]:
                output = (output[:, :, -y.shape[-1] :] + y) / 2.0  # skip connection
            else:
                output = y
        return output

    def setup_guide_models(self, args, latent_dim: int, key_feature_dim: int) -> None:
        # set up conditioning info
        max_keyframe_len = len(list(range(self.seq_len))[:: self.step])
        self.null_pose_embed = nn.Parameter(
            torch.randn(1, max_keyframe_len, latent_dim)
        )
        prGreen(f"using keyframes: {self.null_pose_embed.shape}")
        self.frame_cond_projection = nn.Linear(key_feature_dim, latent_dim)
        self.frame_norm_cond = nn.LayerNorm(latent_dim)
        # for test time set up keyframe transformer
        self.resume_trans = None
        if self.split_type == "test":
            if hasattr(args, "resume_trans") and args.resume_trans is not None:
                self.resume_trans = args.resume_trans
                self.setup_guide_predictor(args.resume_trans)
            else:
                prRed("not using transformer, just using ground truth")

    def setup_guide_predictor(self, cp_path: str) -> None:
        cp_dir = cp_path.split("checkpoints/iter-")[0]
        with open(f"{cp_dir}/args.json") as f:
            trans_args = json.load(f)
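        # Note: args.json is assumed to sit next to the guide-transformer checkpoint;
        # its fields (resume_pth, layers, dim, num_audio_layers) are read below to
        # rebuild the tokenizer and GuideTransformer before loading the weights.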
        # set up tokenizer based on the trans_args load point
        self.tokenizer = setup_tokenizer(trans_args["resume_pth"])
        # set up transformer
        self.transformer = GuideTransformer(
            tokens=self.tokenizer.n_clusters,
            num_layers=trans_args["layers"],
            dim=trans_args["dim"],
            emb_len=1998,
            num_audio_layers=trans_args["num_audio_layers"],
        )
        for param in self.transformer.parameters():
            param.requires_grad = False
        prGreen("loading TRANSFORMER checkpoint from {}".format(cp_path))
        cp = torch.load(cp_path)
        missing_keys, unexpected_keys = self.transformer.load_state_dict(
            cp["model_state_dict"], strict=False
        )
        assert len(missing_keys) == 0, missing_keys
        assert len(unexpected_keys) == 0, unexpected_keys

    def setup_audio_models(self) -> None:
        self.audio_model, self.audio_resampler = setup_lip_regressor()

    def setup_lip_models(self) -> None:
        self.lip_model = Audio2LipRegressionTransformer()
        cp_path = "./assets/iter-0200000.pt"
        cp = torch.load(cp_path, map_location=torch.device(self.device))
        self.lip_model.load_state_dict(cp["model_state_dict"])
        for param in self.lip_model.parameters():
            param.requires_grad = False
        prGreen(f"adding lip conditioning {cp_path}")

    def parameters_w_grad(self):
        return [p for p in self.parameters() if p.requires_grad]

    def encode_audio(self, raw_audio: torch.Tensor) -> torch.Tensor:
        device = next(self.parameters()).device
        a0 = self.audio_resampler(raw_audio[:, :, 0].to(device))
        a1 = self.audio_resampler(raw_audio[:, :, 1].to(device))
        with torch.no_grad():
            z0 = self.audio_model.feature_extractor(a0)
            z1 = self.audio_model.feature_extractor(a1)
            emb = torch.cat((z0, z1), axis=1).permute(0, 2, 1)
        return emb

    def encode_lip(
        self, audio: torch.Tensor, cond_embed: torch.Tensor
    ) -> torch.Tensor:
        reshaped_audio = audio.reshape((audio.shape[0], -1, 1600, 2))[..., 0]
        # processes 4 seconds at a time
        B, T, _ = reshaped_audio.shape
        lip_cond = torch.zeros(
            (audio.shape[0], T, 338, 3),
            device=audio.device,
            dtype=audio.dtype,
        )
        for i in range(0, T, 120):
            lip_cond[:, i : i + 120, ...] = self.lip_model(
                reshaped_audio[:, i : i + 120, ...]
            )
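        # The per-chunk vertex predictions are flattened to (B, 338 * 3, T) below and
        # resampled along time to match the audio-feature length before being
        # concatenated onto the audio conditioning.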
        lip_cond = lip_cond.permute(0, 2, 3, 1).reshape((B, 338 * 3, -1))
        lip_cond = torch.nn.functional.interpolate(
            lip_cond, size=cond_embed.shape[1], mode="nearest-exact"
        ).permute(0, 2, 1)
        cond_embed = torch.cat((cond_embed, lip_cond), dim=-1)
        return cond_embed

    def encode_keyframes(
        self, y: dict, cond_drop_prob: float, batch_size: int
    ) -> torch.Tensor:
        pred = y["keyframes"]
        new_mask = y["mask"][..., :: self.step].squeeze((1, 2))
        pred[~new_mask] = 0.0  # pad the unknown
        pose_hidden = self.frame_cond_projection(pred.detach().clone().cuda())
        pose_embed = self.abs_pos_encoding(pose_hidden)
        pose_tokens = self.frame_norm_cond(pose_embed)
        # do conditional dropout for guide poses
        key_cond_drop_prob = cond_drop_prob
        keep_mask_pose = prob_mask_like(
            (batch_size,), 1 - key_cond_drop_prob, device=pose_tokens.device
        )
        keep_mask_pose_embed = rearrange(keep_mask_pose, "b -> b 1 1")
        null_pose_embed = self.null_pose_embed.to(pose_tokens.dtype)
        pose_tokens = torch.where(
            keep_mask_pose_embed,
            pose_tokens,
            null_pose_embed[:, : pose_tokens.shape[1], :],
        )
        return pose_tokens

    def forward(
        self,
        x: torch.Tensor,
        times: torch.Tensor,
        y: Optional[dict] = None,
        cond_drop_prob: float = 0.0,
    ) -> torch.Tensor:
        if x.dim() == 4:
            x = x.permute(0, 3, 1, 2).squeeze(-1)
        batch_size, device = x.shape[0], x.device

        if self.cond_mode == "uncond":
            cond_embed = torch.zeros(
                (x.shape[0], x.shape[1], self.cond_feature_dim),
                dtype=x.dtype,
                device=x.device,
            )
        else:
            cond_embed = y["audio"]
            cond_embed = self.encode_audio(cond_embed)
            if self.data_format == "face":
                cond_embed = self.encode_lip(y["audio"], cond_embed)
        pose_tokens = None
        if self.data_format == "pose":
            pose_tokens = self.encode_keyframes(y, cond_drop_prob, batch_size)
        assert cond_embed is not None, "cond emb should not be none"

        # process conditioning information
        x = self.input_projection(x)
        x = self.abs_pos_encoding(x)
        audio_cond_drop_prob = cond_drop_prob
        keep_mask = prob_mask_like(
            (batch_size,), 1 - audio_cond_drop_prob, device=device
        )
        keep_mask_embed = rearrange(keep_mask, "b -> b 1 1")
        keep_mask_hidden = rearrange(keep_mask, "b -> b 1")
        cond_tokens = self.cond_projection(cond_embed)
        cond_tokens = self.abs_pos_encoding(cond_tokens)
        if self.data_format == "face":
            cond_tokens = self.cond_encoder(cond_tokens)

        null_cond_embed = self.null_cond_embed.to(cond_tokens.dtype)
        cond_tokens = torch.where(
            keep_mask_embed, cond_tokens, null_cond_embed[:, : cond_tokens.shape[1], :]
        )
        mean_pooled_cond_tokens = cond_tokens.mean(dim=-2)
        cond_hidden = self.non_attn_cond_projection(mean_pooled_cond_tokens)

        # create t conditioning
        t_hidden = self.time_mlp(times)
        t = self.to_time_cond(t_hidden)
        t_tokens = self.to_time_tokens(t_hidden)
        null_cond_hidden = self.null_cond_hidden.to(t.dtype)
        cond_hidden = torch.where(keep_mask_hidden, cond_hidden, null_cond_hidden)
        t += cond_hidden

        # cross-attention conditioning
        c = torch.cat((cond_tokens, t_tokens), dim=-2)
        cond_tokens = self.norm_cond(c)

        # Pass through the transformer decoder
        output = self.seqTransDecoder(x, cond_tokens, t, memory2=pose_tokens)
        output = self.final_layer(output)
        if self.data_format == "pose":
            output = output.permute(0, 2, 1)
            output = self._run_single_pose_conv(output)
            output = self.final_conv(output)
            output = output.permute(0, 2, 1)
        return output
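

if __name__ == "__main__":
    # Minimal shape-check sketch, not part of the training/inference pipeline:
    # it exercises the standalone lip regressor on random audio chunks of shape
    # B x T x 1600, matching the forward() docstring above. Batch/time sizes are
    # arbitrary, and it assumes the wav2vec weights used by Wav2VecEncoder are
    # available locally.
    model = Audio2LipRegressionTransformer()
    model.eval()
    audio = torch.randn(2, 8, 1600)
    with torch.no_grad():
        verts = model(audio)
    print(verts.shape)  # expected: torch.Size([2, 8, 338, 3])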