import torch
import torch.nn.functional as F


def add_first_frame_conditioning(
    latent_model_input,
    first_frame,
    vae
):
    """
    Adds first frame conditioning to a video diffusion model input.

    Args:
        latent_model_input: Original latent input (bs, channels, num_frames, height, width)
        first_frame: Tensor of first frame to condition on (bs, channels, height, width)
        vae: VAE model for encoding the conditioning

    Returns:
        conditioned_latent: The complete conditioned latent input (bs, 36, num_frames, height, width)
    """
    device = latent_model_input.device
    dtype = latent_model_input.dtype
    vae_scale_factor_temporal = 2 ** sum(vae.temperal_downsample)

    # Get number of latent frames from the latent model input
    _, _, num_latent_frames, _, _ = latent_model_input.shape

    # Calculate original number of frames
    # For n original frames, there are (n - 1) // 4 + 1 latent frames
    # So to get n: n = (num_latent_frames - 1) * 4 + 1
    num_frames = (num_latent_frames - 1) * 4 + 1
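    # e.g. 21 latent frames -> (21 - 1) * 4 + 1 = 81 pixel frames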
    if len(first_frame.shape) == 3:
        # we have a single image, add a batch dimension
        first_frame = first_frame.unsqueeze(0)

    # if it doesn't match the batch size, we need to expand it
    if first_frame.shape[0] != latent_model_input.shape[0]:
        first_frame = first_frame.expand(latent_model_input.shape[0], -1, -1, -1)

    # resize first frame to match the latent model input
    vae_scale_factor = 8
    first_frame = F.interpolate(
        first_frame,
        size=(latent_model_input.shape[3] * vae_scale_factor, latent_model_input.shape[4] * vae_scale_factor),
        mode='bilinear',
        align_corners=False
    )
    # Add temporal dimension to first frame
    first_frame = first_frame.unsqueeze(2)

    # Create video condition with first frame and zeros for remaining frames
    zero_frame = torch.zeros_like(first_frame)
    video_condition = torch.cat([
        first_frame,
        *[zero_frame for _ in range(num_frames - 1)]
    ], dim=2)
    # video_condition is already laid out as (bs, channels, num_frames, height, width),
    # so no permute is needed before encoding
    # Encode with VAE
    latent_condition = vae.encode(
        video_condition.to(device, dtype)
    ).latent_dist.sample()
    latent_condition = latent_condition.to(device, dtype)

    # Create mask: 1 for conditioning frames, 0 for frames to generate
    batch_size = first_frame.shape[0]
    latent_height = latent_condition.shape[3]
    latent_width = latent_condition.shape[4]

    # Initialize mask for all frames
    mask_lat_size = torch.ones(
        batch_size, 1, num_frames, latent_height, latent_width)
    # Set all non-first frames to 0
    mask_lat_size[:, :, list(range(1, num_frames))] = 0

    # Special handling for first frame
    first_frame_mask = mask_lat_size[:, :, 0:1]
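    # Repeating the first-frame mask 4x makes the total mask length
    # 4 + (num_frames - 1) = 4 * num_latent_frames, so it can be folded into
    # one group of 4 per latent frame below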
    first_frame_mask = torch.repeat_interleave(
        first_frame_mask, dim=2, repeats=vae_scale_factor_temporal)

    # Combine first frame mask with rest
    mask_lat_size = torch.concat(
        [first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)

    # Reshape and transpose for model input
    mask_lat_size = mask_lat_size.view(
        batch_size, -1, vae_scale_factor_temporal, latent_height, latent_width)
    mask_lat_size = mask_lat_size.transpose(1, 2)
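    # mask_lat_size is now (bs, vae_scale_factor_temporal, num_latent_frames, latent_height, latent_width)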
    mask_lat_size = mask_lat_size.to(device, dtype)

    # Combine conditioning with latent input
    first_frame_condition = torch.concat(
        [mask_lat_size, latent_condition], dim=1)
    conditioned_latent = torch.cat(
        [latent_model_input, first_frame_condition], dim=1)
    return conditioned_latent
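For reference, here is a minimal usage sketch. It assumes the Wan 2.1 VAE from diffusers (AutoencoderKLWan) with 16 latent channels and 4x temporal / 8x spatial compression, an 81-frame 480x832 generation, and a random stand-in for the real first frame; the checkpoint name and shapes are illustrative, not part of the original function.

import torch
from diffusers import AutoencoderKLWan

# Illustrative checkpoint; any Wan 2.1 VAE with the same compression factors should behave the same
vae = AutoencoderKLWan.from_pretrained(
    "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers", subfolder="vae", torch_dtype=torch.float32
)

# 81 pixel frames -> (81 - 1) // 4 + 1 = 21 latent frames; 480x832 pixels -> 60x104 latents
latent_model_input = torch.randn(1, 16, 21, 60, 104)

# Stand-in for a real (normalized) first frame of shape (channels, height, width)
first_frame = torch.rand(3, 480, 832)

conditioned = add_first_frame_conditioning(latent_model_input, first_frame, vae)
print(conditioned.shape)  # expected: torch.Size([1, 36, 21, 60, 104])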