# LPOSS/models/dino.py
# (initial commit 06d49db, author: stojnvla)
import torch
import torch.nn as nn
import torchvision.transforms as T
# Standard ImageNet channel statistics; inputs are expected in [0, 1].
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)
NORMALIZE = T.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD)
class DINO(nn.Module):
    """Frozen DINO ViT-B/16 backbone used as a dense patch-feature extractor.

    A forward hook on the last transformer block's qkv projection captures
    the raw query/key/value activations; :meth:`extract_feats` re-assembles
    them into per-head features and :meth:`forward` returns L2-normalized
    per-patch features.
    """

    def __init__(self):
        super().__init__()
        # Downloads/loads the self-supervised DINO ViT-B/16 weights.
        self.backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16')
        # Inference-only extractor (both entry points run under no_grad):
        # switch off train-time behavior and freeze the weights.
        self.backbone.eval()
        for p in self.backbone.parameters():
            p.requires_grad_(False)

        self.hook_features = {}

        def hook_fn_forward_qkv(module, input, output):
            # Stash the last block's qkv projection output, shape (B, tokens, 3*C).
            self.hook_features["qkv"] = output

        self.backbone._modules["blocks"][-1]._modules["attn"]._modules[
            "qkv"
        ].register_forward_hook(hook_fn_forward_qkv)

        self.patch_size = 16
        # Which of q/k/v forward() returns by default.
        self.enc_type_feats = "v"

    @torch.no_grad()
    def extract_feats(self, type_feats="k"):
        """Re-assemble the hooked qkv activations into per-head features.

        Must be called after a backbone forward pass has populated
        ``self.hook_features["qkv"]``.

        :param type_feats: (string) one of ``"q"``, ``"k"`` or ``"v"``
        :return: float tensor of shape (B, tokens, num_heads, head_dim)
        :raises ValueError: if ``type_feats`` is not "q", "k" or "v"
        """
        nh = self.backbone.blocks[-1].attn.num_heads
        nb_im, nb_tokens, C_qkv = self.hook_features["qkv"].shape
        # (B, tokens, 3*C) -> (3, B, heads, tokens, head_dim);
        # the leading 3 axis indexes |q, k, v|.
        qkv = (
            self.hook_features["qkv"]
            .reshape(nb_im, nb_tokens, 3, nh, C_qkv // nh // 3)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]
        try:
            feats = {"q": q, "k": k, "v": v}[type_feats]
        except KeyError:
            raise ValueError("Unknown features") from None
        # (B, heads, tokens, head_dim) -> (B, tokens, heads, head_dim)
        return feats.transpose(1, 2).float()

    @torch.no_grad()
    def forward(self, x):
        """Extract L2-normalized patch features from an image batch.

        :param x: (B, 3, H, W) image batch; presumably un-normalized values
            in [0, 1] since ImageNet normalization is applied here —
            TODO(review): confirm against callers.
        :return: tuple ``(feats, (h_featmap, w_featmap))`` where ``feats``
            has shape (B, C, num_patches) with unit-norm channel vectors.
        """
        x = NORMALIZE(x)
        h_featmap = x.shape[-2] // self.patch_size
        w_featmap = x.shape[-1] // self.patch_size
        # Encoder forward pass; the return value is discarded — we only
        # need the side effect of the qkv hook firing.
        _ = self.backbone(x)
        feats = self.extract_feats(type_feats=self.enc_type_feats)
        # Drop the CLS token, then merge heads:
        # (B, tokens, nh, head_dim) -> (B, C, num_patches)
        num_extra_tokens = 1
        feats = feats[:, num_extra_tokens:, :, :].flatten(-2, -1).permute(0, 2, 1)
        feats = feats / feats.norm(dim=1, keepdim=True)  # L2-normalize per patch
        return feats, (h_featmap, w_featmap)