import torch.nn as nn
from fairseq.models import FairseqEncoder, register_model, FairseqEncoderDecoderModel, register_model_architecture
from fairseq.models.transformer import TransformerDecoder, Embedding, TransformerModel
from fairseq import utils
# NOTE: HybridEmbed, PatchEmbed and Block are required by ViTTREncoder below.
# This import assumes a timm version that still exposes them from
# timm.models.vision_transformer (e.g. timm <= 0.4.x); newer releases moved
# PatchEmbed and HybridEmbed into timm.models.layers.
from timm.models.vision_transformer import HybridEmbed, PatchEmbed, Block
from timm.models.layers import trunc_normal_
import torch
from torch.hub import load_state_dict_from_url
from functools import partial
import logging
logger = logging.getLogger(__name__)
DEFAULT_MAX_TARGET_POSITIONS = 1024
@register_model('ViT_TR')
class ViTTRModel(FairseqEncoderDecoderModel):
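    """
    Encoder-decoder model that pairs a timm-style Vision Transformer image
    encoder with a standard fairseq TransformerDecoder, mapping images to
    token sequences.
    """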
@staticmethod
def add_args(parser):
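        """Add model-specific arguments: the standard Transformer arguments
        plus the ViT encoder's geometry, size and dropout options."""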
TransformerModel.add_args(parser)
# parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
# help='decoder embedding dimension')
        parser.add_argument(
            '--vit-img-size', type=int, metavar='N',
            help='the input image size (height = width) of the ViT'
        )
        parser.add_argument(
            '--vit-patch-size', type=int, metavar='N',
            help='the patch size (height = width) of the ViT'
        )
        parser.add_argument(
            '--vit-dim', type=int, metavar='N',
            help='the hidden size of the ViT'
        )
        parser.add_argument(
            '--vit-depth', type=int, metavar='N',
            help='the number of layers of the ViT'
        )
        parser.add_argument(
            '--vit-heads', type=int, metavar='N',
            help='the number of attention heads of the ViT'
        )
        parser.add_argument(
            '--vit-channels', type=int, metavar='N', default=3,
            help='the number of input image channels of the ViT'
        )
parser.add_argument(
'--vit-dropout', type=float, default=0.0,
help='the dropout ratio of the ViT'
)
        parser.add_argument(
            '--vit-atten-dropout', type=float, default=0.0,
            help='the attention dropout ratio of the ViT'
        )
parser.add_argument(
'--encoder-pretrained-url', type=str,
help='the pretrained parameter url for the ViT encoder'
)
@classmethod
def build_model(cls, args, task):
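        """Build the ViT encoder and Transformer decoder from `args` and
        `task`, optionally loading pretrained ViT weights from a URL."""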
        encoder = ViTTREncoder(
            args=args,
            dictionary=task.source_dictionary,
        )
if args.encoder_pretrained_url:
            logger.info('load pretrained encoder parameters from: {}'.format(args.encoder_pretrained_url))
encoder_state_dict = load_state_dict_from_url(args.encoder_pretrained_url)
encoder.load_state_dict(encoder_state_dict, strict=False)
if getattr(args, "max_target_positions", None) is None:
args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS
decoder_embed_tokens = cls.build_embedding(
args, task.target_dictionary, args.decoder_embed_dim, args.decoder_embed_path
)
        decoder = TransformerDecoder(
            args=args,
            dictionary=task.target_dictionary,
            embed_tokens=decoder_embed_tokens,
            no_encoder_attn=False,
        )
model = cls(encoder, decoder)
return model
@classmethod
def build_embedding(cls, args, dictionary, embed_dim, path=None):
num_embeddings = len(dictionary)
padding_idx = dictionary.pad()
emb = Embedding(num_embeddings, embed_dim, padding_idx)
# if provided, load from preloaded dictionaries
if path:
embed_dict = utils.parse_embedding(path)
utils.load_embedding(embed_dict, dictionary, emb)
return emb
def forward(self, imgs, prev_output_tokens, **kwargs):
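        """Encode images of shape (B, C, H, W), then decode conditioned on
        `prev_output_tokens` (teacher forcing during training)."""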
encoder_out = self.encoder(imgs, **kwargs)
decoder_out = self.decoder(
prev_output_tokens, encoder_out=encoder_out, **kwargs
)
return decoder_out
@register_model_architecture('ViT_TR', 'ViT_TR_base')
def ViT_TR_base(args):
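    """Base configuration: vit_base_patch16_224 encoder (768-d, 12 layers,
    12 heads) with a 6-layer Transformer decoder; every option not set on
    the command line falls back to the defaults below."""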
# ViT Encoder vit_base_patch16_224
args.vit_img_size = getattr(args, "vit_img_size", 224)
args.resize_img_size = args.vit_img_size
args.vit_patch_size = getattr(args, "vit_patch_size", 16)
args.vit_dim = getattr(args, "vit_dim", 768)
args.vit_depth = getattr(args, "vit_depth", 12)
args.vit_heads = getattr(args, "vit_heads", 12)
args.encoder_pretrained_url = getattr(args, "encoder_pretrained_url",
"https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth")
# Transformer Decoder
args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
args.encoder_layers = getattr(args, "encoder_layers", 6)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
args.decoder_ffn_embed_dim = getattr(
args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
)
args.decoder_layers = getattr(args, "decoder_layers", 6)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
args.attention_dropout = getattr(args, "attention_dropout", 0.0)
args.activation_dropout = getattr(args, "activation_dropout", 0.0)
args.activation_fn = getattr(args, "activation_fn", "relu")
args.dropout = getattr(args, "dropout", 0.1)
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
args.share_decoder_input_output_embed = getattr(
args, "share_decoder_input_output_embed", False
)
args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
args.no_token_positional_embeddings = getattr(
args, "no_token_positional_embeddings", False
)
args.adaptive_input = getattr(args, "adaptive_input", False)
args.no_cross_attention = getattr(args, "no_cross_attention", False)
args.cross_self_attention = getattr(args, "cross_self_attention", False)
args.decoder_output_dim = getattr(
args, "decoder_output_dim", args.decoder_embed_dim
)
args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
args.checkpoint_activations = getattr(args, "checkpoint_activations", False)
args.offload_activations = getattr(args, "offload_activations", False)
if args.offload_activations:
args.checkpoint_activations = True
args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None)
args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None)
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0)
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0)
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8)
args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0)
@register_model_architecture('ViT_TR', 'ViT_TR_large')
def large_architecture(args):
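    """Large configuration: vit_large_patch16_384 encoder (1024-d, 24 layers,
    16 heads) with the same 6-layer Transformer decoder defaults as the base."""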
    # ViT Encoder vit_large_patch16_384
args.vit_img_size = getattr(args, "vit_img_size", 384)
args.resize_img_size = args.vit_img_size
args.vit_patch_size = getattr(args, "vit_patch_size", 16)
args.vit_dim = getattr(args, "vit_dim", 1024)
args.vit_depth = getattr(args, "vit_depth", 24)
args.vit_heads = getattr(args, "vit_heads", 16)
args.encoder_pretrained_url = getattr(args, "encoder_pretrained_url",
"https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth")
# Transformer Decoder
args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
args.encoder_layers = getattr(args, "encoder_layers", 6)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
args.decoder_ffn_embed_dim = getattr(
args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
)
args.decoder_layers = getattr(args, "decoder_layers", 6)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
args.attention_dropout = getattr(args, "attention_dropout", 0.0)
args.activation_dropout = getattr(args, "activation_dropout", 0.0)
args.activation_fn = getattr(args, "activation_fn", "relu")
args.dropout = getattr(args, "dropout", 0.1)
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
args.share_decoder_input_output_embed = getattr(
args, "share_decoder_input_output_embed", False
)
args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
args.no_token_positional_embeddings = getattr(
args, "no_token_positional_embeddings", False
)
args.adaptive_input = getattr(args, "adaptive_input", False)
args.no_cross_attention = getattr(args, "no_cross_attention", False)
args.cross_self_attention = getattr(args, "cross_self_attention", False)
args.decoder_output_dim = getattr(
args, "decoder_output_dim", args.decoder_embed_dim
)
args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
args.checkpoint_activations = getattr(args, "checkpoint_activations", False)
args.offload_activations = getattr(args, "offload_activations", False)
if args.offload_activations:
args.checkpoint_activations = True
args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None)
args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None)
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0)
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0)
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8)
args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0)
class ViTTREncoder(FairseqEncoder):
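    """
    A timm-style Vision Transformer wrapped as a FairseqEncoder: patch
    embedding + [CLS] token + learned positional embeddings, followed by
    `depth` Transformer blocks and a final LayerNorm.
    """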
def __init__(self, args, dictionary):
super().__init__(dictionary)
img_size = args.vit_img_size
patch_size = args.vit_patch_size
in_chans = args.vit_channels
embed_dim = args.vit_dim
depth = args.vit_depth
num_heads = args.vit_heads
        mlp_ratio = 4.0
        qkv_bias = True
        qk_scale = None
        drop_rate = args.vit_dropout
        attn_drop_rate = args.vit_atten_dropout
        drop_path_rate = 0.0
        hybrid_backbone = None
        norm_layer = None
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
if hybrid_backbone is not None:
self.patch_embed = HybridEmbed(
hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
else:
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.blocks = nn.ModuleList([
Block(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
for i in range(depth)])
self.norm = norm_layer(embed_dim)
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward_features(self, x):
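        """Embed patches, prepend the CLS token, add positional embeddings and
        run the Transformer blocks; returns the normalized features together
        with the pre-dropout input embedding."""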
        B = x.shape[0]
        x = self.patch_embed(x)  # (bs, num_patches, dim)
cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.pos_embed
encoder_embedding = x # bs, n + 1, dim
x = self.pos_drop(x)
for blk in self.blocks:
x = blk(x)
x = self.norm(x) # bs, n + 1, dim
return x, encoder_embedding
def forward(self, imgs):
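        """Encode images into the dict format expected by fairseq decoders:
        T x B x C features plus an all-False padding mask (images carry no
        padding)."""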
x, encoder_embedding = self.forward_features(imgs) # bs, n + 1, dim
x = x.transpose(0, 1) # n + 1, bs, dim
        # Images have no padding, so the mask is all False (B x T).
        encoder_padding_mask = x.new_zeros((x.size(1), x.size(0)), dtype=torch.bool)
return {
"encoder_out": [x], # T x B x C
"encoder_padding_mask": [encoder_padding_mask], # B x T
"encoder_embedding": [encoder_embedding], # B x T x C
"encoder_states": [], # List[T x B x C]
"src_tokens": [],
"src_lengths": [],
}
def reorder_encoder_out(self, encoder_out, new_order):
"""
Reorder encoder output according to `new_order`.
Args:
encoder_out: output from the ``forward()`` method
new_order (LongTensor): desired order
Returns:
`encoder_out` rearranged according to `new_order`
"""
_encoder_out = encoder_out['encoder_out'][0]
_encoder_padding_mask = encoder_out['encoder_padding_mask'][0]
_encoder_embedding = encoder_out['encoder_embedding'][0]
return {
"encoder_out": [_encoder_out.index_select(1, new_order)],
"encoder_padding_mask": [_encoder_padding_mask.index_select(0, new_order)], # B x T
"encoder_embedding": [_encoder_padding_mask.index_select(0, new_order)], # B x T x C
"encoder_states": [],
"src_tokens": [],
"src_lengths": [],
}
if __name__ == '__main__':
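    # A minimal smoke test, not part of the original training flow.
    # Assumptions: fairseq and a compatible timm (see the import note above)
    # are installed; the encoder is exercised standalone with random inputs
    # and without downloading pretrained weights.
    from argparse import Namespace
    from fairseq.data import Dictionary

    _args = Namespace(
        vit_img_size=224, vit_patch_size=16, vit_channels=3,
        vit_dim=768, vit_depth=12, vit_heads=12,
        vit_dropout=0.0, vit_atten_dropout=0.0,
    )
    _encoder = ViTTREncoder(_args, Dictionary())
    _imgs = torch.randn(2, 3, 224, 224)
    _out = _encoder(_imgs)
    # 224/16 = 14 patches per side -> 196 patches + 1 CLS token = 197 steps
    print(_out["encoder_out"][0].shape)           # torch.Size([197, 2, 768])
    print(_out["encoder_padding_mask"][0].shape)  # torch.Size([2, 197])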