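"""ViT-TR: a Vision Transformer (ViT) encoder paired with a fairseq Transformer
decoder for image-to-sequence generation.

The encoder re-implements timm's ViT forward pass (so it can load timm's
pretrained checkpoints); the decoder is fairseq's stock TransformerDecoder.
"""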
import logging
from functools import partial

import torch
import torch.nn as nn
from torch.hub import load_state_dict_from_url

from fairseq import utils
from fairseq.models import (
    FairseqEncoder,
    FairseqEncoderDecoderModel,
    register_model,
    register_model_architecture,
)
from fairseq.models.transformer import TransformerDecoder, Embedding, TransformerModel
# HybridEmbed, PatchEmbed and Block are used by ViTTREncoder below, so this
# import must be active; the module path matches timm < 0.5 and may differ in
# newer timm releases.
from timm.models.vision_transformer import HybridEmbed, PatchEmbed, Block
from timm.models.layers import trunc_normal_

logger = logging.getLogger(__name__)

DEFAULT_MAX_TARGET_POSITIONS = 1024

# NOTE: the registered names 'vit_tr', 'vit_tr_base' and 'vit_tr_large' used in
# the decorators below are assumptions; adjust them to the project's naming.
@register_model('vit_tr')
class ViTTRModel(FairseqEncoderDecoderModel):
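    """Encoder-decoder model: a ViT image encoder feeding a fairseq TransformerDecoder."""
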
    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        TransformerModel.add_args(parser)
        # parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
        #                     help='decoder embedding dimension')
        parser.add_argument(
            '--vit-img-size', type=int, metavar='N',
            help='the input image size (h = w) of the ViT'
        )
        parser.add_argument(
            '--vit-patch-size', type=int, metavar='N',
            help='the patch size (h = w) of the ViT'
        )
        parser.add_argument(
            '--vit-dim', type=int, metavar='N',
            help='the hidden size of the ViT'
        )
        parser.add_argument(
            '--vit-depth', type=int, metavar='N',
            help='the number of layers of the ViT'
        )
        parser.add_argument(
            '--vit-heads', type=int, metavar='N',
            help='the number of attention heads of the ViT'
        )
        parser.add_argument(
            '--vit-channels', type=int, metavar='N', default=3,
            help='the number of input image channels of the ViT'
        )
        parser.add_argument(
            '--vit-dropout', type=float, default=0.0,
            help='the dropout ratio of the ViT'
        )
        parser.add_argument(
            '--vit-atten-dropout', type=float, default=0.0,
            help='the attention dropout ratio of the ViT'
        )
        parser.add_argument(
            '--encoder-pretrained-url', type=str,
            help='the pretrained parameter URL for the ViT encoder'
        )
    @classmethod
    def build_model(cls, args, task):
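        """Build the ViT encoder (optionally loading timm pretrained weights)
        and a standard fairseq TransformerDecoder."""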
        encoder = ViTTREncoder(
            args=args,
            dictionary=task.source_dictionary,
        )
        if args.encoder_pretrained_url:
            logger.info('load pretrained encoder parameters from: {}'.format(args.encoder_pretrained_url))
            encoder_state_dict = load_state_dict_from_url(args.encoder_pretrained_url)
            # strict=False: the timm checkpoint contains keys (e.g. the
            # classification head) that are absent from this encoder
            encoder.load_state_dict(encoder_state_dict, strict=False)
        if getattr(args, "max_target_positions", None) is None:
            args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS
        decoder_embed_tokens = cls.build_embedding(
            args, task.target_dictionary, args.decoder_embed_dim, args.decoder_embed_path
        )
        decoder = TransformerDecoder(
            args=args,
            dictionary=task.target_dictionary,
            embed_tokens=decoder_embed_tokens,
            no_encoder_attn=False,
        )
        return cls(encoder, decoder)
    @classmethod
    def build_embedding(cls, args, dictionary, embed_dim, path=None):
        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()
        emb = Embedding(num_embeddings, embed_dim, padding_idx)
        # if provided, load from preloaded dictionaries
        if path:
            embed_dict = utils.parse_embedding(path)
            utils.load_embedding(embed_dict, dictionary, emb)
        return emb
    def forward(self, imgs, prev_output_tokens, **kwargs):
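        """Encode a batch of images (B x C x H x W), then decode `prev_output_tokens`.

        Unlike a text-to-text fairseq model, the encoder consumes raw images
        rather than source token indices.
        """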
        encoder_out = self.encoder(imgs, **kwargs)
        decoder_out = self.decoder(
            prev_output_tokens, encoder_out=encoder_out, **kwargs
        )
        return decoder_out

@register_model_architecture('vit_tr', 'vit_tr_base')  # registered names are assumed, see NOTE above
def ViT_TR_base(args):
    # ViT encoder: vit_base_patch16_224
    args.vit_img_size = getattr(args, "vit_img_size", 224)
    args.resize_img_size = args.vit_img_size
    args.vit_patch_size = getattr(args, "vit_patch_size", 16)
    args.vit_dim = getattr(args, "vit_dim", 768)
    args.vit_depth = getattr(args, "vit_depth", 12)
    args.vit_heads = getattr(args, "vit_heads", 12)
    args.encoder_pretrained_url = getattr(
        args, "encoder_pretrained_url",
        "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth",
    )
    # Transformer decoder
    args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
    args.encoder_layers = getattr(args, "encoder_layers", 6)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
    args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
    args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim)
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
    args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
    args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
    args.attention_dropout = getattr(args, "attention_dropout", 0.0)
    args.activation_dropout = getattr(args, "activation_dropout", 0.0)
    args.activation_fn = getattr(args, "activation_fn", "relu")
    args.dropout = getattr(args, "dropout", 0.1)
    args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
    args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
    args.share_decoder_input_output_embed = getattr(args, "share_decoder_input_output_embed", False)
    args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
    args.no_token_positional_embeddings = getattr(args, "no_token_positional_embeddings", False)
    args.adaptive_input = getattr(args, "adaptive_input", False)
    args.no_cross_attention = getattr(args, "no_cross_attention", False)
    args.cross_self_attention = getattr(args, "cross_self_attention", False)
    args.decoder_output_dim = getattr(args, "decoder_output_dim", args.decoder_embed_dim)
    args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
    args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
    args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
    args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
    args.checkpoint_activations = getattr(args, "checkpoint_activations", False)
    args.offload_activations = getattr(args, "offload_activations", False)
    if args.offload_activations:
        args.checkpoint_activations = True
    args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None)
    args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None)
    args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0)
    args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0)
    args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
    args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8)
    args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0)

@register_model_architecture('vit_tr', 'vit_tr_large')  # registered names are assumed, see NOTE above
def large_architecture(args):
    # ViT encoder: vit_large_patch16_384
    args.vit_img_size = getattr(args, "vit_img_size", 384)
    args.vit_patch_size = getattr(args, "vit_patch_size", 16)
    args.vit_dim = getattr(args, "vit_dim", 1024)
    args.vit_depth = getattr(args, "vit_depth", 24)
    args.vit_heads = getattr(args, "vit_heads", 16)
    args.encoder_pretrained_url = getattr(
        args, "encoder_pretrained_url",
        "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth",
    )
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
    # every remaining default is identical to the base architecture
    ViT_TR_base(args)

class ViTTREncoder(FairseqEncoder):
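    """A Vision Transformer wrapped as a FairseqEncoder.

    The forward pass mirrors timm's ViT: patch embedding, a prepended [CLS]
    token, learned position embeddings, a stack of transformer Blocks, and a
    final LayerNorm.
    """
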
    def __init__(self, args, dictionary):
        super().__init__(dictionary)
        img_size = args.vit_img_size
        patch_size = args.vit_patch_size
        in_chans = args.vit_channels
        embed_dim = args.vit_dim
        depth = args.vit_depth
        num_heads = args.vit_heads
        mlp_ratio = 4.0
        qkv_bias = True
        qk_scale = None
        drop_rate = args.vit_dropout
        attn_drop_rate = args.vit_atten_dropout
        drop_path_rate = 0.0
        hybrid_backbone = None
        norm_layer = None
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        if hybrid_backbone is not None:
            self.patch_embed = HybridEmbed(
                hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
        else:
            self.patch_embed = PatchEmbed(
                img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
    def forward_features(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)  # B, num_patches, dim
        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        encoder_embedding = x  # B, n + 1, dim
        x = self.pos_drop(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)  # B, n + 1, dim
        return x, encoder_embedding
    def forward(self, imgs):
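        """Encode a batch of images into fairseq's dict-style encoder output.

        The returned dict follows the same layout as fairseq's
        TransformerEncoder, so the stock TransformerDecoder can consume it.
        """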
        x, encoder_embedding = self.forward_features(imgs)  # B, n + 1, dim
        x = x.transpose(0, 1)  # n + 1, B, dim
        # images produce no padding, so the mask is all False (B x T)
        encoder_padding_mask = torch.zeros(
            x.shape[1], x.shape[0], dtype=torch.bool, device=imgs.device
        )
        return {
            "encoder_out": [x],  # T x B x C
            "encoder_padding_mask": [encoder_padding_mask],  # B x T
            "encoder_embedding": [encoder_embedding],  # B x T x C
            "encoder_states": [],  # List[T x B x C]
            "src_tokens": [],
            "src_lengths": [],
        }
    def reorder_encoder_out(self, encoder_out, new_order):
        """
        Reorder encoder output according to `new_order`.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            `encoder_out` rearranged according to `new_order`
        """
        _encoder_out = encoder_out['encoder_out'][0]
        _encoder_padding_mask = encoder_out['encoder_padding_mask'][0]
        _encoder_embedding = encoder_out['encoder_embedding'][0]
        return {
            "encoder_out": [_encoder_out.index_select(1, new_order)],  # T x B x C
            "encoder_padding_mask": [_encoder_padding_mask.index_select(0, new_order)],  # B x T
            "encoder_embedding": [_encoder_embedding.index_select(0, new_order)],  # B x T x C
            "encoder_states": [],
            "src_tokens": [],
            "src_lengths": [],
        }

if __name__ == '__main__':
    pass
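    # Minimal smoke test: a sketch, not part of the original pipeline. It
    # instantiates the encoder with ViT-Base/16 hyperparameters (matching
    # ViT_TR_base) and checks the encoder output shape on random input.
    from argparse import Namespace
    from fairseq.data import Dictionary

    _args = Namespace(
        vit_img_size=224, vit_patch_size=16, vit_channels=3,
        vit_dim=768, vit_depth=12, vit_heads=12,
        vit_dropout=0.0, vit_atten_dropout=0.0,
    )
    _enc = ViTTREncoder(_args, Dictionary())
    with torch.no_grad():
        _out = _enc(torch.randn(2, 3, 224, 224))
    # 224 / 16 = 14 patches per side -> 196 patches + 1 [CLS] = 197 positions
    print(_out["encoder_out"][0].shape)  # expected: torch.Size([197, 2, 768])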