import torch
import torch.nn as nn
import torchvision.transforms as T

# ImageNet channel statistics matching the pretrained DINO backbone's training.
NORMALIZE = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))


class DINO(nn.Module):
    """Frozen DINO ViT-B/16 feature extractor.

    Registers a forward hook on the qkv projection of the last attention
    block and exposes per-token q/k/v features from a forward pass.
    """

    def __init__(self):
        super().__init__()
        # Downloads the pretrained weights on first use (requires network).
        self.backbone = torch.hub.load(
            "facebookresearch/dino:main", "dino_vitb16"
        )
        # hub.load leaves the module in train mode; this is a frozen
        # extractor, so switch to eval (no-op for DINO's zero-dropout ViT,
        # but correct hygiene).
        self.backbone.eval()

        self.hook_features = {}

        def hook_fn_forward_qkv(module, input, output):
            # Stash the last block's qkv projection output, shape (B, 1+N, 3*C).
            self.hook_features["qkv"] = output

        # Attribute access is equivalent to digging through `_modules` and
        # matches how extract_feats addresses the same submodule.
        self.backbone.blocks[-1].attn.qkv.register_forward_hook(
            hook_fn_forward_qkv
        )

        self.patch_size = 16
        self.enc_type_feats = "v"

    @torch.no_grad()
    def extract_feats(self, type_feats="k"):
        """Return q, k or v token features captured by the forward hook.

        Must be called after a backbone forward pass has populated
        ``self.hook_features["qkv"]``.

        :param type_feats: (string) - one of "q", "k", "v"
        :return: float tensor of shape (B, 1+N, num_heads, head_dim)
        :raises ValueError: if ``type_feats`` is not "q", "k" or "v"
        """
        nh = self.backbone.blocks[-1].attn.num_heads
        nb_im, nb_tokens, C_qkv = self.hook_features["qkv"].shape
        # (B, 1+N, 3*C) -> (3, B, nh, 1+N, C//nh); leading axis indexes q/k/v.
        qkv = (
            self.hook_features["qkv"]
            .reshape(nb_im, nb_tokens, 3, nh, C_qkv // 3 // nh)
            .permute(2, 0, 3, 1, 4)
        )
        selected = {"q": qkv[0], "k": qkv[1], "v": qkv[2]}
        if type_feats not in selected:
            raise ValueError(f"Unknown features: {type_feats!r}")
        # (B, nh, 1+N, head_dim) -> (B, 1+N, nh, head_dim)
        return selected[type_feats].transpose(1, 2).float()

    @torch.no_grad()
    def forward(self, x):
        """Extract L2-normalized patch features for an image batch.

        :param x: image tensor (B, 3, H, W); NORMALIZE applies ImageNet
                  statistics, so x is presumably in [0, 1] — confirm at caller.
        :return: (feats, (h_featmap, w_featmap)) where feats has shape
                 (B, C, N), N = h_featmap * w_featmap, normalized over C.
        """
        x = NORMALIZE(x)
        h_featmap = x.shape[-2] // self.patch_size
        w_featmap = x.shape[-1] // self.patch_size

        # Forward pass populates self.hook_features via the registered hook.
        _ = self.backbone(x)

        # Get hooked decoder features of the configured type (default "v").
        feats = self.extract_feats(type_feats=self.enc_type_feats)

        # Drop the CLS token, merge heads: (B, 1+N, nh, d) -> (B, C, N).
        num_extra_tokens = 1
        feats = (
            feats[:, num_extra_tokens:, :, :].flatten(-2, -1).permute(0, 2, 1)
        )
        # L2-normalize along the channel dimension.
        feats = feats / feats.norm(dim=1, keepdim=True)
        return feats, (h_featmap, w_featmap)