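"""DINO ViT feature extractor.

Loads a pretrained DINO ViT-B/16 backbone from torch.hub and registers a
forward hook on the qkv projection of the last attention block, so that
per-patch query/key/value features can be read out after a forward pass.
"""
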
import torch
import torch.nn as nn
import torchvision.transforms as T

# Standard ImageNet mean/std, as used for the pretrained DINO backbone
NORMALIZE = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

class DINO(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16')
        self.hook_features = {}

        # Cache the qkv projection output of the last attention block.
        def hook_fn_forward_qkv(module, input, output):
            self.hook_features["qkv"] = output

        self.backbone.blocks[-1].attn.qkv.register_forward_hook(
            hook_fn_forward_qkv
        )

        self.patch_size = 16  # ViT-B/16 patch size
        self.enc_type_feats = "v"  # which projection to return: "q", "k", or "v"


    @torch.no_grad()
    def extract_feats(self, type_feats="k"):
        """
        Read the qkv tensor captured by the forward hook registered on the
        last attention layer in __init__, and return the requested projection.
        :param type_feats: (string) - which projection to return: "q", "k", or "v"
        """
        nh = self.backbone.blocks[-1].attn.num_heads
        nb_im, nb_tokens, C_qkv = self.hook_features["qkv"].shape

        qkv = (
            self.hook_features["qkv"]
            .reshape(nb_im, nb_tokens, 3, nh, C_qkv // nh // 3)  # 3 for q, k, v
            .permute(2, 0, 3, 1, 4)
        )  # (3, B, nh, nb_tokens, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]
        if type_feats == "q":
            return q.transpose(1, 2).float()
        elif type_feats == "k":
            return k.transpose(1, 2).float()
        elif type_feats == "v":
            return v.transpose(1, 2).float()
        else:
            raise ValueError(f"Unknown feature type: {type_feats!r}")


    @torch.no_grad()
    def forward(self, x):
        x = NORMALIZE(x)
        h_featmap = x.shape[-2] // self.patch_size
        w_featmap = x.shape[-1] // self.patch_size

        # Encoder forward pass; the hook caches the last block's qkv output
        _ = self.backbone(x)

        # Read the hooked features (value projections by default)
        feats = self.extract_feats(type_feats=self.enc_type_feats)
        num_extra_tokens = 1  # the [CLS] token

        # (B, nb_tokens, nh, head_dim) -> drop [CLS], merge heads -> (B, C, nb_patches)
        feats = feats[:, num_extra_tokens:, :, :].flatten(-2, -1).permute(0, 2, 1)
        feats = feats / feats.norm(dim=1, keepdim=True)  # L2-normalize along channel dim

        return feats, (h_featmap, w_featmap)
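

# Usage sketch: a minimal example of running the extractor on a dummy batch,
# assuming torch.hub can fetch the DINO weights and that the input height and
# width are multiples of the patch size (16). Names below are illustrative.
if __name__ == "__main__":
    model = DINO().eval()
    images = torch.rand(2, 3, 224, 224)  # dummy batch in [0, 1]; normalized in forward
    feats, (h_featmap, w_featmap) = model(images)
    print(feats.shape, h_featmap, w_featmap)  # e.g. torch.Size([2, 768, 196]) 14 14
    # Reshape the flat per-patch features into a spatial map if needed
    fmap = feats.reshape(feats.shape[0], feats.shape[1], h_featmap, w_featmap)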