# ControlNet's picture
# Upload model
# eedeabf verified
from typing import Optional
import torch
from torch import Tensor
from torch.nn import Linear, Module
from transformers import PreTrainedModel
from .encoder import MarlinEncoder
from .decoder import MarlinDecoder
from .config import MarlinConfig
class Marlin(Module):
    """Encoder/decoder video model built from ``MarlinEncoder`` and ``MarlinDecoder``.

    The encoder is always created. The decoder and the bias-free linear
    projection from the encoder width to the decoder width are created only
    when ``as_feature_extractor`` is ``False``; otherwise both attributes are
    ``None`` and only :meth:`extract_features` can be used.

    Args:
        img_size: Forwarded to the encoder and the decoder.
        patch_size: Forwarded to the encoder and the decoder.
        n_frames: Forwarded to the encoder and stored as ``clip_frames``.
        encoder_embed_dim: Encoder ``embed_dim``; also the input width of
            the encoder-to-decoder projection.
        encoder_depth: Encoder ``depth``.
        encoder_num_heads: Encoder ``num_heads``.
        decoder_embed_dim: Decoder ``embed_dim``; also the output width of
            the encoder-to-decoder projection.
        decoder_depth: Decoder ``depth``.
        decoder_num_heads: Decoder ``num_heads``.
        mlp_ratio: Forwarded to the encoder and the decoder.
        qkv_bias: Forwarded to the encoder and the decoder.
        qk_scale: Forwarded to the encoder and the decoder.
        drop_rate: Forwarded to the encoder and the decoder.
        attn_drop_rate: Forwarded to the encoder and the decoder.
        norm_layer: Forwarded to the encoder and the decoder.
        init_values: Forwarded to the encoder and the decoder.
        tubelet_size: Forwarded to the encoder and the decoder.
        as_feature_extractor: When ``True`` (the default) no decoder is
            built and :meth:`forward` is disabled.
    """

    def __init__(
        self,
        img_size: int,
        patch_size: int,
        n_frames: int,
        encoder_embed_dim: int,
        encoder_depth: int,
        encoder_num_heads: int,
        decoder_embed_dim: int,
        decoder_depth: int,
        decoder_num_heads: int,
        mlp_ratio: float,
        qkv_bias: bool,
        qk_scale: Optional[float],
        drop_rate: float,
        attn_drop_rate: float,
        norm_layer: str,
        init_values: float,
        tubelet_size: int,
        as_feature_extractor: bool = True,
    ):
        super().__init__()
        self.encoder = MarlinEncoder(
            img_size=img_size,
            patch_size=patch_size,
            n_frames=n_frames,
            embed_dim=encoder_embed_dim,
            depth=encoder_depth,
            num_heads=encoder_num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            norm_layer=norm_layer,
            init_values=init_values,
            tubelet_size=tubelet_size,
        )
        self.as_feature_extractor = as_feature_extractor
        self.clip_frames = n_frames

        if as_feature_extractor:
            # Feature-extraction mode: the encoder is the whole model.
            self.enc_dec_proj = None
            self.decoder = None
        else:
            # The decoder is registered before the projection so that the
            # submodule order (and therefore state-dict key order) is stable.
            self.decoder = MarlinDecoder(
                img_size=img_size,
                patch_size=patch_size,
                embed_dim=decoder_embed_dim,
                depth=decoder_depth,
                num_heads=decoder_num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop_rate=drop_rate,
                attn_drop_rate=attn_drop_rate,
                norm_layer=norm_layer,
                init_values=init_values,
                tubelet_size=tubelet_size,
            )
            self.enc_dec_proj = Linear(
                encoder_embed_dim, decoder_embed_dim, bias=False
            )

    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        """Run encoder -> projection -> decoder on ``x`` using ``mask``.

        Args:
            x: Input passed to the encoder together with ``mask``.
            mask: Mask passed to both the encoder and the decoder. Required.

        Returns:
            The decoder output.

        Raises:
            RuntimeError: If the model was built with
                ``as_feature_extractor=True`` (there is no decoder to run).
            ValueError: If ``mask`` is ``None``.
        """
        if self.as_feature_extractor:
            raise RuntimeError(
                "For feature extraction, please use `extract_features`."
            )
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if mask is None:
            raise ValueError("`mask` is required when running the full model.")

        x = self.encoder(x, mask)
        x = self.enc_dec_proj(x)
        x = self.decoder(x, mask)
        return x

    @property
    def device(self):
        """Device of the encoder's ``norm.weight`` parameter."""
        return self.encoder.norm.weight.device

    def extract_features(self, x: Tensor, keep_seq: bool = True):
        """Return ``self.encoder.extract_features`` for ``x``.

        Gradient tracking is left as the caller configured it while the
        module is in training mode and is switched off in eval mode.

        Args:
            x: Input forwarded to the encoder's ``extract_features``.
            keep_seq: Forwarded inverted as ``seq_mean_pool=not keep_seq``.

        Returns:
            Whatever the encoder's ``extract_features`` returns.
        """
        # training -> keep the ambient grad mode; eval -> force it off.
        track_grad = torch.is_grad_enabled() and self.training
        with torch.set_grad_enabled(track_grad):
            return self.encoder.extract_features(x, seq_mean_pool=not keep_seq)
class MarlinModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper around ``Marlin`` driven by a ``MarlinConfig``.

    Calling the model delegates to ``Marlin.extract_features``.
    """

    config_class = MarlinConfig

    def __init__(self, config: MarlinConfig):
        """Build the wrapped ``Marlin`` network from the fields of ``config``."""
        super().__init__(config)
        self.config = config

        # Each name is read from ``config`` and passed to ``Marlin`` under the
        # same keyword. ``as_feature_extractor`` is not passed, so ``Marlin``
        # uses its own default for it.
        forwarded_fields = (
            "img_size",
            "patch_size",
            "n_frames",
            "encoder_embed_dim",
            "encoder_depth",
            "encoder_num_heads",
            "decoder_embed_dim",
            "decoder_depth",
            "decoder_num_heads",
            "mlp_ratio",
            "qkv_bias",
            "qk_scale",
            "drop_rate",
            "attn_drop_rate",
            "norm_layer",
            "init_values",
            "tubelet_size",
        )
        marlin_kwargs = {
            field: getattr(config, field) for field in forwarded_fields
        }
        self.marlin = Marlin(**marlin_kwargs)

    def forward(self, x: Tensor, keep_seq: bool = True):
        """Return ``self.marlin.extract_features(x, keep_seq=keep_seq)``."""
        features = self.marlin.extract_features(x, keep_seq=keep_seq)
        return features