# Origin: https://github.com/ali-vilab/UniAnimate/commit/d7814fa44a0a1154524b92fce0e3133a2604d333
import math
import torch
import xformers
import xformers.ops
import torch.nn as nn
from einops import rearrange
import torch.nn.functional as F
from functools import partial  # used by forward() for the misc_dropout closure
from torch import einsum       # used by the Attention blocks below
from rotary_embedding_torch import RotaryEmbedding
from fairscale.nn.checkpoint import checkpoint_wrapper
from .util import *
# from .mha_flash import FlashAttentionBlock
from utils.registry_class import MODEL
USE_TEMPORAL_TRANSFORMER = True
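# USE_TEMPORAL_TRANSFORMER selects TemporalTransformer blocks instead of the
# TemporalAttentionMultiBlock fallback throughout the UNet below.

# Residual wrappers: LayerNorm is applied before the wrapped module and a skip
# connection is added (pre-norm convention); the _qkv variant normalizes the
# query, key and value streams separately and adds the residual to the query.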
class PreNormattention(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs) + x
class PreNormattention_qkv(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, q, k, v, **kwargs):
return self.fn(self.norm(q), self.norm(k), self.norm(v), **kwargs) + q
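# Multi-head self-attention (ViT-style): one linear layer produces q/k/v,
# followed by scaled dot-product attention and an output projection (the
# projection is skipped when it would be a no-op).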
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
b, n, _, h = *x.shape, self.heads
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
attn = self.attend(dots)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
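# Multi-head attention with separate q/k/v projections, so queries and
# keys/values may come from different token sequences (cross-attention style).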
class Attention_qkv(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_k = nn.Linear(dim, inner_dim, bias = False)
self.to_v = nn.Linear(dim, inner_dim, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, q, k, v):
b, n, _, h = *q.shape, self.heads
bk = k.shape[0]
q = self.to_q(q)
k = self.to_k(k)
v = self.to_v(v)
q = rearrange(q, 'b n (h d) -> b h n d', h = h)
k = rearrange(k, 'b n (h d) -> b h n d', b=bk, h = h)
v = rearrange(v, 'b n (h d) -> b h n d', b=bk, h = h)
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
attn = self.attend(dots)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
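# Post-norm variant: the residual connection is added first and LayerNorm is
# applied to the sum.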
class PostNormattention(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.norm(self.fn(x, **kwargs) + x)
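# Lightweight transformer adapter (pre-norm attention plus a residual
# feed-forward block); used below to refine the pose and reference-image
# condition embeddings along the frame axis.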
class Transformer_v2(nn.Module):
def __init__(self, heads=8, dim=2048, dim_head_k=256, dim_head_v=256, dropout_atte = 0.05, mlp_dim=2048, dropout_ffn = 0.05, depth=1):
super().__init__()
self.layers = nn.ModuleList([])
self.depth = depth
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNormattention(dim, Attention(dim, heads = heads, dim_head = dim_head_k, dropout = dropout_atte)),
FeedForward(dim, mlp_dim, dropout = dropout_ffn),
]))
    def forward(self, x):
        # every layer applies the same pre-norm attention + feed-forward pattern
        for attn, ff in self.layers:
            x = attn(x)
            x = ff(x) + x
        return x
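# Per-sample drop-path, used below as condition dropout (misc_dropout): each
# sample's condition tensor is zeroed with probability p, with optional masks
# to force all-zero or all-keep samples.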
class DropPath(nn.Module):
r"""DropPath but without rescaling and supports optional all-zero and/or all-keep.
"""
def __init__(self, p):
super(DropPath, self).__init__()
self.p = p
def forward(self, *args, zero=None, keep=None):
if not self.training:
return args[0] if len(args) == 1 else args
# params
x = args[0]
b = x.size(0)
n = (torch.rand(b) < self.p).sum()
# non-zero and non-keep mask
mask = x.new_ones(b, dtype=torch.bool)
if keep is not None:
mask[keep] = False
if zero is not None:
mask[zero] = False
# drop-path index
index = torch.where(mask)[0]
index = index[torch.randperm(len(index))[:n]]
if zero is not None:
index = torch.cat([index, torch.where(zero)[0]], dim=0)
# drop-path multiplier
multiplier = x.new_ones(b)
multiplier[index] = 0.0
output = tuple(u * self.broadcast(multiplier, u) for u in args)
return output[0] if len(args) == 1 else output
def broadcast(self, src, dst):
assert src.size(0) == dst.size(0)
shape = (dst.size(0), ) + (1, ) * (dst.ndim - 1)
return src.view(shape)
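# UniAnimate 3D UNet: a Stable-Diffusion-style spatial UNet interleaved with
# temporal transformer blocks. Map-like conditions (reference image, DWPose
# sequence, reference pose) are embedded by small conv + Transformer_v2
# adapters, randomly dropped via DropPath during training, and concatenated to
# the noisy latent along the channel axis; the reference-image latent is
# prepended as an extra frame and removed again at the output, while the
# optional `image` embedding is expanded into extra cross-attention context
# tokens.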
@MODEL.register_class()
class UNetSD_UniAnimate(nn.Module):
def __init__(self,
config=None,
in_dim=4,
dim=512,
y_dim=512,
context_dim=1024,
hist_dim = 156,
concat_dim = 8,
out_dim=6,
dim_mult=[1, 2, 3, 4],
num_heads=None,
head_dim=64,
num_res_blocks=3,
attn_scales=[1 / 2, 1 / 4, 1 / 8],
use_scale_shift_norm=True,
dropout=0.1,
temporal_attn_times=1,
temporal_attention = True,
use_checkpoint=False,
use_image_dataset=False,
use_fps_condition= False,
use_sim_mask = False,
misc_dropout = 0.5,
training=True,
inpainting=True,
p_all_zero=0.1,
p_all_keep=0.1,
zero_y = None,
black_image_feature = None,
adapter_transformer_layers = 1,
num_tokens=4,
**kwargs
):
embed_dim = dim * 4
        num_heads = num_heads if num_heads else dim // 32
super(UNetSD_UniAnimate, self).__init__()
self.zero_y = zero_y
self.black_image_feature = black_image_feature
self.cfg = config
self.in_dim = in_dim
self.dim = dim
self.y_dim = y_dim
self.context_dim = context_dim
self.num_tokens = num_tokens
self.hist_dim = hist_dim
self.concat_dim = concat_dim
self.embed_dim = embed_dim
self.out_dim = out_dim
self.dim_mult = dim_mult
self.num_heads = num_heads
self.head_dim = head_dim
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.use_scale_shift_norm = use_scale_shift_norm
self.temporal_attn_times = temporal_attn_times
self.temporal_attention = temporal_attention
self.use_checkpoint = use_checkpoint
self.use_image_dataset = use_image_dataset
self.use_fps_condition = use_fps_condition
self.use_sim_mask = use_sim_mask
self.training=training
self.inpainting = inpainting
self.video_compositions = self.cfg.video_compositions
self.misc_dropout = misc_dropout
self.p_all_zero = p_all_zero
self.p_all_keep = p_all_keep
use_linear_in_temporal = False
transformer_depth = 1
disabled_sa = False
# params
enc_dims = [dim * u for u in [1] + dim_mult]
dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
shortcut_dims = []
scale = 1.0
self.resolution = config.resolution
# embeddings
self.time_embed = nn.Sequential(
nn.Linear(dim, embed_dim),
nn.SiLU(),
nn.Linear(embed_dim, embed_dim))
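        # condition adapters: the 'image' embedding is expanded by an MLP into
        # num_tokens context tokens; the map-like conditions (local_image,
        # dwpose, randomref, randomref_pose) each get a small conv encoder
        # followed by a Transformer_v2 refinement over the frame axis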
if 'image' in self.video_compositions:
self.pre_image_condition = nn.Sequential(
nn.Linear(self.context_dim, self.context_dim),
nn.SiLU(),
nn.Linear(self.context_dim, self.context_dim*self.num_tokens))
if 'local_image' in self.video_compositions:
self.local_image_embedding = nn.Sequential(
nn.Conv2d(3, concat_dim * 4, 3, padding=1),
nn.SiLU(),
nn.AdaptiveAvgPool2d((self.resolution[1]//2, self.resolution[0]//2)),
nn.Conv2d(concat_dim * 4, concat_dim * 4, 3, stride=2, padding=1),
nn.SiLU(),
nn.Conv2d(concat_dim * 4, concat_dim, 3, stride=2, padding=1))
self.local_image_embedding_after = Transformer_v2(heads=2, dim=concat_dim, dim_head_k=concat_dim, dim_head_v=concat_dim, dropout_atte = 0.05, mlp_dim=concat_dim, dropout_ffn = 0.05, depth=adapter_transformer_layers)
if 'dwpose' in self.video_compositions:
self.dwpose_embedding = nn.Sequential(
nn.Conv2d(3, concat_dim * 4, 3, padding=1),
nn.SiLU(),
nn.AdaptiveAvgPool2d((self.resolution[1]//2, self.resolution[0]//2)),
nn.Conv2d(concat_dim * 4, concat_dim * 4, 3, stride=2, padding=1),
nn.SiLU(),
nn.Conv2d(concat_dim * 4, concat_dim, 3, stride=2, padding=1))
self.dwpose_embedding_after = Transformer_v2(heads=2, dim=concat_dim, dim_head_k=concat_dim, dim_head_v=concat_dim, dropout_atte = 0.05, mlp_dim=concat_dim, dropout_ffn = 0.05, depth=adapter_transformer_layers)
if 'randomref_pose' in self.video_compositions:
randomref_dim = 4
self.randomref_pose2_embedding = nn.Sequential(
nn.Conv2d(3, concat_dim * 4, 3, padding=1),
nn.SiLU(),
nn.AdaptiveAvgPool2d((self.resolution[1]//2, self.resolution[0]//2)),
nn.Conv2d(concat_dim * 4, concat_dim * 4, 3, stride=2, padding=1),
nn.SiLU(),
nn.Conv2d(concat_dim * 4, concat_dim+randomref_dim, 3, stride=2, padding=1))
self.randomref_pose2_embedding_after = Transformer_v2(heads=2, dim=concat_dim+randomref_dim, dim_head_k=concat_dim+randomref_dim, dim_head_v=concat_dim+randomref_dim, dropout_atte = 0.05, mlp_dim=concat_dim+randomref_dim, dropout_ffn = 0.05, depth=adapter_transformer_layers)
if 'randomref' in self.video_compositions:
randomref_dim = 4
self.randomref_embedding2 = nn.Sequential(
nn.Conv2d(randomref_dim, concat_dim * 4, 3, padding=1),
nn.SiLU(),
nn.Conv2d(concat_dim * 4, concat_dim * 4, 3, stride=1, padding=1),
nn.SiLU(),
nn.Conv2d(concat_dim * 4, concat_dim+randomref_dim, 3, stride=1, padding=1))
self.randomref_embedding_after2 = Transformer_v2(heads=2, dim=concat_dim+randomref_dim, dim_head_k=concat_dim+randomref_dim, dim_head_v=concat_dim+randomref_dim, dropout_atte = 0.05, mlp_dim=concat_dim+randomref_dim, dropout_ffn = 0.05, depth=adapter_transformer_layers)
### Condition Dropout
self.misc_dropout = DropPath(misc_dropout)
if temporal_attention and not USE_TEMPORAL_TRANSFORMER:
self.rotary_emb = RotaryEmbedding(min(32, head_dim))
self.time_rel_pos_bias = RelativePositionBias(heads = num_heads, max_distance = 32) # realistically will not be able to generate that many frames of video... yet
if self.use_fps_condition:
self.fps_embedding = nn.Sequential(
nn.Linear(dim, embed_dim),
nn.SiLU(),
nn.Linear(embed_dim, embed_dim))
nn.init.zeros_(self.fps_embedding[-1].weight)
nn.init.zeros_(self.fps_embedding[-1].bias)
# encoder
self.input_blocks = nn.ModuleList()
self.pre_image = nn.Sequential()
init_block = nn.ModuleList([nn.Conv2d(self.in_dim + concat_dim, dim, 3, padding=1)])
#### need an initial temporal attention?
if temporal_attention:
if USE_TEMPORAL_TRANSFORMER:
init_block.append(TemporalTransformer(dim, num_heads, head_dim, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_temporal, multiply_zero=use_image_dataset))
else:
init_block.append(TemporalAttentionMultiBlock(dim, num_heads, head_dim, rotary_emb=self.rotary_emb, temporal_attn_times=temporal_attn_times, use_image_dataset=use_image_dataset))
self.input_blocks.append(init_block)
shortcut_dims.append(dim)
for i, (in_dim, out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])):
for j in range(num_res_blocks):
block = nn.ModuleList([ResBlock(in_dim, embed_dim, dropout, out_channels=out_dim, use_scale_shift_norm=False, use_image_dataset=use_image_dataset,)])
if scale in attn_scales:
block.append(
SpatialTransformer(
out_dim, out_dim // head_dim, head_dim, depth=1, context_dim=self.context_dim,
disable_self_attn=False, use_linear=True
)
)
if self.temporal_attention:
if USE_TEMPORAL_TRANSFORMER:
block.append(TemporalTransformer(out_dim, out_dim // head_dim, head_dim, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_temporal, multiply_zero=use_image_dataset))
else:
block.append(TemporalAttentionMultiBlock(out_dim, num_heads, head_dim, rotary_emb = self.rotary_emb, use_image_dataset=use_image_dataset, use_sim_mask=use_sim_mask, temporal_attn_times=temporal_attn_times))
in_dim = out_dim
self.input_blocks.append(block)
shortcut_dims.append(out_dim)
# downsample
if i != len(dim_mult) - 1 and j == num_res_blocks - 1:
downsample = Downsample(
out_dim, True, dims=2, out_channels=out_dim
)
shortcut_dims.append(out_dim)
scale /= 2.0
self.input_blocks.append(downsample)
# middle
self.middle_block = nn.ModuleList([
ResBlock(out_dim, embed_dim, dropout, use_scale_shift_norm=False, use_image_dataset=use_image_dataset,),
SpatialTransformer(
out_dim, out_dim // head_dim, head_dim, depth=1, context_dim=self.context_dim,
disable_self_attn=False, use_linear=True
)])
if self.temporal_attention:
if USE_TEMPORAL_TRANSFORMER:
self.middle_block.append(
TemporalTransformer(
out_dim, out_dim // head_dim, head_dim, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_temporal,
multiply_zero=use_image_dataset,
)
)
else:
self.middle_block.append(TemporalAttentionMultiBlock(out_dim, num_heads, head_dim, rotary_emb = self.rotary_emb, use_image_dataset=use_image_dataset, use_sim_mask=use_sim_mask, temporal_attn_times=temporal_attn_times))
self.middle_block.append(ResBlock(out_dim, embed_dim, dropout, use_scale_shift_norm=False))
# decoder
self.output_blocks = nn.ModuleList()
for i, (in_dim, out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])):
for j in range(num_res_blocks + 1):
block = nn.ModuleList([ResBlock(in_dim + shortcut_dims.pop(), embed_dim, dropout, out_dim, use_scale_shift_norm=False, use_image_dataset=use_image_dataset, )])
if scale in attn_scales:
block.append(
SpatialTransformer(
out_dim, out_dim // head_dim, head_dim, depth=1, context_dim=1024,
disable_self_attn=False, use_linear=True
)
)
if self.temporal_attention:
if USE_TEMPORAL_TRANSFORMER:
block.append(
TemporalTransformer(
out_dim, out_dim // head_dim, head_dim, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_temporal, multiply_zero=use_image_dataset
)
)
else:
block.append(TemporalAttentionMultiBlock(out_dim, num_heads, head_dim, rotary_emb =self.rotary_emb, use_image_dataset=use_image_dataset, use_sim_mask=use_sim_mask, temporal_attn_times=temporal_attn_times))
in_dim = out_dim
# upsample
if i != len(dim_mult) - 1 and j == num_res_blocks:
upsample = Upsample(out_dim, True, dims=2.0, out_channels=out_dim)
scale *= 2.0
block.append(upsample)
self.output_blocks.append(block)
# head
self.out = nn.Sequential(
nn.GroupNorm(32, out_dim),
nn.SiLU(),
nn.Conv2d(out_dim, self.out_dim, 3, padding=1))
# zero out the last layer params
nn.init.zeros_(self.out[-1].weight)
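    # x: noisy video latent of shape (b, c, f, h, w); t: diffusion timesteps.
    # Condition tensors are optional and follow cfg.video_compositions; several
    # arguments (y, depth, motion, canny, sketch, histogram, ...) are accepted
    # but not used in this forward pass.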
def forward(self,
x,
t,
y = None,
depth = None,
image = None,
motion = None,
local_image = None,
single_sketch = None,
masked = None,
canny = None,
sketch = None,
dwpose = None,
randomref = None,
histogram = None,
fps = None,
video_mask = None,
focus_present_mask = None,
prob_focus_present = 0., # probability at which a given batch sample will focus on the present (0. is all off, 1. is completely arrested attention across time)
mask_last_frame_num = 0 # mask last frame num
):
assert self.inpainting or masked is None, 'inpainting is not supported'
        batch, c, f, h, w = x.shape
frames = f
device = x.device
self.batch = batch
#### image and video joint training, if mask_last_frame_num is set, prob_focus_present will be ignored
if mask_last_frame_num > 0:
focus_present_mask = None
video_mask[-mask_last_frame_num:] = False
else:
focus_present_mask = default(focus_present_mask, lambda: prob_mask_like((batch,), prob_focus_present, device = device))
if self.temporal_attention and not USE_TEMPORAL_TRANSFORMER:
time_rel_pos_bias = self.time_rel_pos_bias(x.shape[2], device = x.device)
else:
time_rel_pos_bias = None
# all-zero and all-keep masks
zero = torch.zeros(batch, dtype=torch.bool).to(x.device)
keep = torch.zeros(batch, dtype=torch.bool).to(x.device)
if self.training:
nzero = (torch.rand(batch) < self.p_all_zero).sum()
nkeep = (torch.rand(batch) < self.p_all_keep).sum()
index = torch.randperm(batch)
zero[index[0:nzero]] = True
keep[index[nzero:nzero + nkeep]] = True
assert not (zero & keep).any()
misc_dropout = partial(self.misc_dropout, zero = zero, keep = keep)
concat = x.new_zeros(batch, self.concat_dim, f, h, w)
# local_image_embedding (first frame)
if local_image is not None:
local_image = rearrange(local_image, 'b c f h w -> (b f) c h w')
local_image = self.local_image_embedding(local_image)
h = local_image.shape[2]
local_image = self.local_image_embedding_after(rearrange(local_image, '(b f) c h w -> (b h w) f c', b = batch))
local_image = rearrange(local_image, '(b h w) f c -> b c f h w', b = batch, h = h)
concat = concat + misc_dropout(local_image)
if dwpose is not None:
if 'randomref_pose' in self.video_compositions:
dwpose_random_ref = dwpose[:,:,:1].clone()
dwpose = dwpose[:,:,1:]
dwpose = rearrange(dwpose, 'b c f h w -> (b f) c h w')
dwpose = self.dwpose_embedding(dwpose)
h = dwpose.shape[2]
dwpose = self.dwpose_embedding_after(rearrange(dwpose, '(b f) c h w -> (b h w) f c', b = batch))
dwpose = rearrange(dwpose, '(b h w) f c -> b c f h w', b = batch, h = h)
concat = concat + misc_dropout(dwpose)
randomref_b = x.new_zeros(batch, self.concat_dim+4, 1, h, w)
if randomref is not None:
randomref = rearrange(randomref[:,:,:1,], 'b c f h w -> (b f) c h w')
randomref = self.randomref_embedding2(randomref)
h = randomref.shape[2]
randomref = self.randomref_embedding_after2(rearrange(randomref, '(b f) c h w -> (b h w) f c', b = batch))
if 'randomref_pose' in self.video_compositions:
dwpose_random_ref = rearrange(dwpose_random_ref, 'b c f h w -> (b f) c h w')
dwpose_random_ref = self.randomref_pose2_embedding(dwpose_random_ref)
dwpose_random_ref = self.randomref_pose2_embedding_after(rearrange(dwpose_random_ref, '(b f) c h w -> (b h w) f c', b = batch))
randomref = randomref + dwpose_random_ref
randomref_a = rearrange(randomref, '(b h w) f c -> b c f h w', b = batch, h = h)
randomref_b = randomref_b + randomref_a
x = torch.cat([randomref_b, torch.cat([x, concat], dim=1)], dim=2)
x = rearrange(x, 'b c f h w -> (b f) c h w')
x = self.pre_image(x)
x = rearrange(x, '(b f) c h w -> b c f h w', b = batch)
# embeddings
if self.use_fps_condition and fps is not None:
e = self.time_embed(sinusoidal_embedding(t, self.dim)) + self.fps_embedding(sinusoidal_embedding(fps, self.dim))
else:
e = self.time_embed(sinusoidal_embedding(t, self.dim))
context = x.new_zeros(batch, 0, self.context_dim)
if image is not None:
y_context = self.zero_y.repeat(batch, 1, 1)
context = torch.cat([context, y_context], dim=1)
image_context = misc_dropout(self.pre_image_condition(image).view(-1, self.num_tokens, self.context_dim)) # torch.cat([y[:,:-1,:], self.pre_image_condition(y[:,-1:,:]) ], dim=1)
context = torch.cat([context, image_context], dim=1)
else:
y_context = self.zero_y.repeat(batch, 1, 1)
context = torch.cat([context, y_context], dim=1)
image_context = torch.zeros_like(self.zero_y.repeat(batch, 1, 1))[:,:self.num_tokens]
context = torch.cat([context, image_context], dim=1)
# repeat f times for spatial e and context
e = e.repeat_interleave(repeats=f+1, dim=0)
context = context.repeat_interleave(repeats=f+1, dim=0)
## always in shape (b f) c h w, except for temporal layer
x = rearrange(x, 'b c f h w -> (b f) c h w')
# encoder
xs = []
for block in self.input_blocks:
x = self._forward_single(block, x, e, context, time_rel_pos_bias, focus_present_mask, video_mask)
xs.append(x)
# middle
for block in self.middle_block:
x = self._forward_single(block, x, e, context, time_rel_pos_bias,focus_present_mask, video_mask)
# decoder
for block in self.output_blocks:
x = torch.cat([x, xs.pop()], dim=1)
x = self._forward_single(block, x, e, context, time_rel_pos_bias,focus_present_mask, video_mask, reference=xs[-1] if len(xs) > 0 else None)
# head
x = self.out(x)
        # reshape back to (b c f h w) and drop the prepended reference frame
x = rearrange(x, '(b f) c h w -> b c f h w', b = batch)
return x[:,:,1:]
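    # Dispatch a single sub-module: temporal blocks operate on (b, c, f, h, w)
    # and are wrapped with rearranges, spatial blocks operate on (b*f, c, h, w);
    # gradient checkpointing is applied per module when use_checkpoint is set.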
def _forward_single(self, module, x, e, context, time_rel_pos_bias, focus_present_mask, video_mask, reference=None):
if isinstance(module, ResidualBlock):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = x.contiguous()
x = module(x, e, reference)
elif isinstance(module, ResBlock):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = x.contiguous()
x = module(x, e, self.batch)
elif isinstance(module, SpatialTransformer):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = module(x, context)
elif isinstance(module, TemporalTransformer):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch)
x = module(x, context)
x = rearrange(x, 'b c f h w -> (b f) c h w')
elif isinstance(module, CrossAttention):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = module(x, context)
elif isinstance(module, MemoryEfficientCrossAttention):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = module(x, context)
elif isinstance(module, BasicTransformerBlock):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = module(x, context)
elif isinstance(module, FeedForward):
x = module(x, context)
elif isinstance(module, Upsample):
x = module(x)
elif isinstance(module, Downsample):
x = module(x)
elif isinstance(module, Resample):
x = module(x, reference)
elif isinstance(module, TemporalAttentionBlock):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch)
x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
x = rearrange(x, 'b c f h w -> (b f) c h w')
elif isinstance(module, TemporalAttentionMultiBlock):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch)
x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
x = rearrange(x, 'b c f h w -> (b f) c h w')
elif isinstance(module, InitTemporalConvBlock):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch)
x = module(x)
x = rearrange(x, 'b c f h w -> (b f) c h w')
elif isinstance(module, TemporalConvBlock):
module = checkpoint_wrapper(module) if self.use_checkpoint else module
x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch)
x = module(x)
x = rearrange(x, 'b c f h w -> (b f) c h w')
elif isinstance(module, nn.ModuleList):
for block in module:
x = self._forward_single(block, x, e, context, time_rel_pos_bias, focus_present_mask, video_mask, reference)
else:
x = module(x)
return x
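if __name__ == "__main__":
    # Minimal smoke-test sketch for the self-contained building blocks above
    # (Attention and DropPath only). The shapes and hyper-parameters here are
    # illustrative assumptions, not values taken from the UniAnimate configs.
    # Because of the relative import of .util, run this as a module
    # (python -m <package.path.to.this.file>) rather than as a plain script.
    tokens = torch.randn(2, 16, 64)                   # (batch, tokens, dim)
    attn = Attention(dim=64, heads=2, dim_head=32)
    print(attn(tokens).shape)                         # torch.Size([2, 16, 64])

    drop = DropPath(p=0.5)
    drop.train()                                      # drop-path is active only in training mode
    feats = torch.randn(4, 8, 32, 32)                 # (batch, channels, h, w)
    keep = torch.tensor([True, False, False, False])  # sample 0 is never dropped
    print(drop(feats, keep=keep).shape)               # torch.Size([4, 8, 32, 32])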