import torch
from torchvision import transforms
import torch.nn as nn

from concept_attention.segmentation import SegmentationAbstractClass
import concept_attention.binary_segmentation_baselines.dino_src.vision_transformer as vits

class DINOSegmentationModel(SegmentationAbstractClass):
    
    def __init__(self, arch="vit_small", patch_size=8, image_size=480, image_path=None, device="cuda"):
        self.device = device
        # Build the DINO ViT backbone without a classification head, freeze it, and put it in eval mode
        self.image_size = image_size
        self.patch_size = patch_size
        self.model = vits.__dict__[arch](patch_size=patch_size, num_classes=0)
        for p in self.model.parameters():
            p.requires_grad = False
        self.model.eval()
        self.model.to(device)
        # Load the reference pretrained DINO weights for the chosen architecture / patch size
        url = None
        if arch == "vit_small" and patch_size == 16:
            url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
        elif arch == "vit_small" and patch_size == 8:
            url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth"  # model used for visualizations in the DINO paper
        elif arch == "vit_base" and patch_size == 16:
            url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
        elif arch == "vit_base" and patch_size == 8:
            url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"

        if url is not None:
            print("Loading the reference pretrained DINO weights.")
            state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)
            self.model.load_state_dict(state_dict, strict=True)
        else:
            print("No reference pretrained weights are available for this architecture => using random weights.")

        # Transforms
        self.transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])

    def segment_individual_image(self, image, concepts, caption, **kwargs):
        # NOTE: Concepts and caption are ignored, since DINO is not a text-conditioned approach.
        if isinstance(image, torch.Tensor):
            image = transforms.Resize(self.image_size)(image)
        else:
            image = self.transform(image)
        # Crop the image so its height and width are divisible by the patch size
        w, h = image.shape[1] - image.shape[1] % self.patch_size, image.shape[2] - image.shape[2] % self.patch_size
        image = image[:, :w, :h].unsqueeze(0)

        w_featmap = image.shape[-2] // self.patch_size
        h_featmap = image.shape[-1] // self.patch_size

        attentions = self.model.get_last_selfattention(image.to(self.device))
        nh = attentions.shape[1]  # number of attention heads

        # Keep only the attention from the [CLS] token to the patch tokens
        attentions = attentions[0, :, 0, 1:].reshape(nh, -1)
        attentions = attentions.reshape(nh, w_featmap, h_featmap)
        # Upsample the per-head maps to pixel resolution, average over heads, and
        # replicate the resulting map once per concept (DINO is not concept aware)
        attentions = nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=self.patch_size, mode="nearest")[0]
        attentions = torch.mean(attentions, dim=0, keepdim=True)
        attentions = attentions.repeat(len(concepts), 1, 1)

        return attentions, None
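
# Illustrative usage sketch (assumptions: a CUDA device is available and
# "example.jpg" is a placeholder image path). The concept names are arbitrary
# labels; DINO returns the same class-token attention map for each of them.
if __name__ == "__main__":
    from PIL import Image

    segmenter = DINOSegmentationModel(arch="vit_small", patch_size=8, device="cuda")
    image = Image.open("example.jpg").convert("RGB")
    maps, _ = segmenter.segment_individual_image(image, concepts=["cat", "background"], caption=None)
    print(maps.shape)  # (num_concepts, height, width) attention heatmaps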