from typing import Optional

import torch
from torch import Tensor
from torch.nn import Linear, Module
from transformers import PreTrainedModel

from .encoder import MarlinEncoder
from .decoder import MarlinDecoder
from .config import MarlinConfig


class Marlin(Module):
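    """Video masked-autoencoder backbone: a MarlinEncoder plus, unless
    `as_feature_extractor` is set, a MarlinDecoder and a linear projection
    bridging the two embedding widths. In feature-extractor mode only
    `extract_features` is usable; `forward` performs masked reconstruction.
    """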
def __init__(
self,
img_size: int,
patch_size: int,
n_frames: int,
encoder_embed_dim: int,
encoder_depth: int,
encoder_num_heads: int,
decoder_embed_dim: int,
decoder_depth: int,
decoder_num_heads: int,
mlp_ratio: float,
qkv_bias: bool,
qk_scale: Optional[float],
drop_rate: float,
attn_drop_rate: float,
norm_layer: str,
init_values: float,
tubelet_size: int,
as_feature_extractor: bool = True,
):
super().__init__()
self.encoder = MarlinEncoder(
img_size=img_size,
patch_size=patch_size,
n_frames=n_frames,
embed_dim=encoder_embed_dim,
depth=encoder_depth,
num_heads=encoder_num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate,
norm_layer=norm_layer,
init_values=init_values,
tubelet_size=tubelet_size,
)
self.as_feature_extractor = as_feature_extractor
self.clip_frames = n_frames
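        # The decoder and the encoder-to-decoder projection are only needed
        # for masked reconstruction; a pure feature extractor skips both.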
if as_feature_extractor:
self.enc_dec_proj = None
self.decoder = None
else:
self.decoder = MarlinDecoder(
img_size=img_size,
patch_size=patch_size,
embed_dim=decoder_embed_dim,
depth=decoder_depth,
num_heads=decoder_num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate,
norm_layer=norm_layer,
init_values=init_values,
tubelet_size=tubelet_size,
)
self.enc_dec_proj = Linear(encoder_embed_dim, decoder_embed_dim, bias=False)

    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
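        """Run masked reconstruction: encode the visible patches, project them
        to the decoder width, and decode. Requires `as_feature_extractor=False`.
        """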
        if self.as_feature_extractor:
            raise RuntimeError(
                "For feature extraction, please use `extract_features` or `extract_video`."
            )
        assert mask is not None, "A patch mask is required for reconstruction."
        x = self.encoder(x, mask)
        x = self.enc_dec_proj(x)
        x = self.decoder(x, mask)
        return x

    @property
    def device(self):
        # Report the device of an encoder parameter as the module's device.
        return self.encoder.norm.weight.device

    def extract_features(self, x: Tensor, keep_seq: bool = True) -> Tensor:
        """Extract features for one video clip.

        With `keep_seq=True` the full token sequence is returned; otherwise it
        is mean-pooled into a single clip embedding. Gradients are disabled
        outside of training.
        """
        if self.training:
            return self.encoder.extract_features(x, seq_mean_pool=not keep_seq)
        with torch.no_grad():
            return self.encoder.extract_features(x, seq_mean_pool=not keep_seq)


class MarlinModel(PreTrainedModel):
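    """Hugging Face `PreTrainedModel` wrapper that builds a feature-extractor
    Marlin from a MarlinConfig and exposes feature extraction as `forward`.
    """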
config_class = MarlinConfig

    def __init__(self, config: MarlinConfig):
super().__init__(config)
self.config = config
self.marlin = Marlin(
img_size=config.img_size,
patch_size=config.patch_size,
n_frames=config.n_frames,
encoder_embed_dim=config.encoder_embed_dim,
encoder_depth=config.encoder_depth,
encoder_num_heads=config.encoder_num_heads,
decoder_embed_dim=config.decoder_embed_dim,
decoder_depth=config.decoder_depth,
decoder_num_heads=config.decoder_num_heads,
mlp_ratio=config.mlp_ratio,
qkv_bias=config.qkv_bias,
qk_scale=config.qk_scale,
drop_rate=config.drop_rate,
attn_drop_rate=config.attn_drop_rate,
norm_layer=config.norm_layer,
init_values=config.init_values,
tubelet_size=config.tubelet_size,
)

    def forward(self, x: Tensor, keep_seq: bool = True) -> Tensor:
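        """Extract clip features; `keep_seq=False` mean-pools over the sequence."""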
return self.marlin.extract_features(x, keep_seq=keep_seq)
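

# Minimal usage sketch (hypothetical values; the input layout is assumed to be
# (batch, channels, n_frames, img_size, img_size), matching the constructor
# arguments, and is not defined in this module itself):
#
#   config = MarlinConfig(...)  # field values depend on the pretrained checkpoint
#   model = MarlinModel(config).eval()
#   clip = torch.randn(1, 3, config.n_frames, config.img_size, config.img_size)
#   features = model(clip, keep_seq=False)  # one pooled embedding per clip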