import torch
import torch.nn as nn
import torchvision.transforms as T

NORMALIZE = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

class DINO(nn.Module):
    def __init__(self):
        super().__init__()
        # Pretrained DINO ViT-B/16 backbone from torch.hub
        self.backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16')
        self.hook_features = {}

        # Cache the fused qkv projection of the last attention block at every forward pass
        def hook_fn_forward_qkv(module, input, output):
            self.hook_features["qkv"] = output

        self.backbone._modules["blocks"][-1]._modules["attn"]._modules[
            "qkv"
        ].register_forward_hook(hook_fn_forward_qkv)

        self.patch_size = 16
        self.enc_type_feats = "v"

    @torch.no_grad()
    def extract_feats(self, type_feats="k"):
        """
        DINO feature extractor. Reads the qkv tensor cached by the forward hook
        registered on the last attention layer.
        :param type_feats: (string) - type of features from DINO ViT ("q", "k" or "v")
        """
        nh = self.backbone.blocks[-1].attn.num_heads
        nb_im, nb_tokens, C_qkv = self.hook_features["qkv"].shape

        # Split the fused qkv projection into per-head q, k, v tensors
        qkv = (
            self.hook_features["qkv"]
            .reshape(nb_im, nb_tokens, 3, nh, C_qkv // nh // 3)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]
        # Each returned tensor has shape (batch, tokens, heads, head_dim)
        if type_feats == "q":
            return q.transpose(1, 2).float()
        elif type_feats == "k":
            return k.transpose(1, 2).float()
        elif type_feats == "v":
            return v.transpose(1, 2).float()
        else:
            raise ValueError("Unknown features")

    @torch.no_grad()
    def forward(self, x):
        x = NORMALIZE(x)
        # Spatial dimensions of the patch-token feature map
        h_featmap = x.shape[-2] // self.patch_size
        w_featmap = x.shape[-1] // self.patch_size

        # Forward pass is only needed to trigger the qkv hook; its output is unused
        _ = self.backbone(x)

        feats = self.extract_feats(type_feats=self.enc_type_feats)
        num_extra_tokens = 1  # the [CLS] token

        # Drop the [CLS] token, merge heads and move channels first: (B, C, h_featmap * w_featmap)
        feats = feats[:, num_extra_tokens:, :, :].flatten(-2, -1).permute(0, 2, 1)

        # L2-normalize features along the channel dimension
        feats = feats / feats.norm(dim=1, keepdim=True)

        return feats, (h_featmap, w_featmap)
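

# Minimal usage sketch (illustrative, not part of the original module). It assumes
# input images are batched float tensors in [0, 1] whose side lengths are divisible
# by the patch size (16); the torch.hub call in __init__ downloads the DINO weights,
# so network access is required on first use.
if __name__ == "__main__":
    model = DINO().eval()
    x = torch.rand(1, 3, 224, 224)  # dummy RGB batch
    feats, (h, w) = model(x)
    print(feats.shape)  # (1, 768, 196) for ViT-B/16 at 224x224 resolution
    print(h, w)  # 14 14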