# marlin_vit_large_ytf / encoder.py
# Uploaded by ControlNet ("Upload model", commit eedeabf, verified).
from torch import nn, Tensor
from torch.nn import ModuleList, LayerNorm
from .modules import PatchEmbedding3d, Block
from .positional_embedding import SinCosPositionalEmbedding
class MarlinEncoder(nn.Module):
    """Transformer encoder over 3D (tubelet) patches of a video clip.

    A 3-channel clip is cut into ``tubelet_size x patch_size x patch_size``
    patches by ``PatchEmbedding3d``, given fixed sine-cosine positional
    embeddings, passed through ``depth`` transformer ``Block``s and finally
    normalized with ``LayerNorm``.

    Args:
        img_size: Spatial side length of the (square) input frames.
        patch_size: Spatial side length of each patch.
        n_frames: Number of frames per input clip.
        embed_dim: Token embedding dimension.
        depth: Number of transformer blocks.
        num_heads, mlp_ratio, qkv_bias, qk_scale, drop_rate, attn_drop_rate,
        init_values: Forwarded unchanged to every ``Block``.
        norm_layer: Name of the normalization layer; only ``"LayerNorm"`` is
            supported.
        tubelet_size: Temporal extent (in frames) of each patch.

    Raises:
        NotImplementedError: If ``norm_layer`` is not ``"LayerNorm"``.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        n_frames=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.,
        qkv_bias=False,
        qk_scale=None,
        drop_rate=0.,
        attn_drop_rate=0.,
        norm_layer="LayerNorm",
        init_values=0.,
        tubelet_size=2,
    ):
        super().__init__()
        # Fail before building any submodule for an unsupported configuration.
        if norm_layer != "LayerNorm":
            raise NotImplementedError("Only LayerNorm is supported")
        self.norm_layer = LayerNorm

        self.embed_dim = embed_dim
        self.patch_embedding = PatchEmbedding3d(
            input_size=(3, n_frames, img_size, img_size),
            patch_size=(tubelet_size, patch_size, patch_size),
            embedding=embed_dim,
        )
        # Number of tubelet tokens: spatial grid (H/p * W/p) times temporal steps (T/tubelet).
        num_patches = (img_size // patch_size) * (img_size // patch_size) * (n_frames // tubelet_size)
        # Fixed sine-cosine positional embeddings.
        self.pos_embedding = SinCosPositionalEmbedding((num_patches, embed_dim), dropout_rate=0.)
        # Registration order (patch_embedding, pos_embedding, norm, blocks) is kept
        # so parameter ordering and init RNG consumption match existing checkpoints.
        self.norm = self.norm_layer(embed_dim)
        self.blocks = ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=self.norm_layer,
                init_values=init_values)
            for _ in range(depth)
        ])
        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(m: nn.Module) -> None:
        """Initialize Linear (Xavier-uniform weight, zero bias) and LayerNorm (1/0) modules."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # A LayerNorm built with elementwise_affine=False (or bias=False)
            # has None for these parameters; skip them instead of crashing.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def _run_blocks(self, x: Tensor) -> Tensor:
        """Apply the transformer blocks in order (no final normalization)."""
        for block in self.blocks:
            x = block(x)
        return x

    def forward_features(self, x: Tensor) -> Tensor:
        """Run every transformer block on token tensor ``x`` and apply the final norm."""
        return self.norm(self._run_blocks(x))

    def forward(self, x: Tensor, mask: Tensor) -> Tensor:
        """Encode only the visible patches of a clip.

        Args:
            x: 5-D input clip; presumably (B, 3, n_frames, img_size, img_size)
                to match the ``input_size`` given to ``PatchEmbedding3d`` —
                TODO confirm.
            mask: Boolean tensor, ``True`` -> visible, ``False`` -> masked. It is
                applied with boolean indexing to the (B, L, C) patch embeddings,
                so its shape must be (B, L) with L the number of patches.
                NOTE(review): a mask laid out as (B, T, N) must be flattened to
                (B, T*N) by the caller — confirm against callers.

        Returns:
            Normalized features of shape (B, V, C), where V is the number of
            visible patches per sample.
        """
        assert len(x.shape) == 5, "x must be 5D"
        emb = self.pos_embedding(self.patch_embedding(x))
        b, _, c = emb.shape
        # emb[mask] gathers the kept rows of all samples into one (B*V, C) tensor.
        # Regrouping with view() is only correct if every sample keeps the same
        # number V of patches; otherwise it raises or mixes patches across samples.
        visible = emb[mask].view(b, -1, c)
        return self.forward_features(visible)

    def extract_features(self, x: Tensor, seq_mean_pool: bool) -> Tensor:
        """Encode all patches of a clip (no masking).

        Args:
            x: Input clip, passed directly to the patch embedding.
            seq_mean_pool: If True, average tokens over dim 1 *before* the final
                norm, giving one feature vector per sample.

        Returns:
            Token features (presumably (B, L, C)), or mean-pooled features
            (presumably (B, C)) when ``seq_mean_pool`` is True.
        """
        x = self.pos_embedding(self.patch_embedding(x))
        x = self._run_blocks(x)
        if seq_mean_pool:
            # Pooling precedes normalization here; this order is intentionally
            # left unchanged because reordering would change the returned features.
            x = x.mean(dim=1)
        return self.norm(x)