# ---------------------------------------------------------------------------------------------------
# CLIP-DINOiser
# authors: Monika Wysoczanska, Warsaw University of Technology
# Copyright (c) OpenMMLab. All rights reserved.
# Modified version of the original MaskCLIP code: https://github.com/chongzhou96/MaskCLIP/tree/master
# ---------------------------------------------------------------------------------------------------

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple
from torch import Tensor
from open_clip import get_tokenizer, create_model_from_pretrained
import torchvision.transforms as T
from .utils import imagenet_templates

OPENAI_NORMALIZE = T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))


class MaskClip(nn.Module):
    """Dense feature extractor built on top of an ``open_clip`` model.

    The module loads a pretrained model through
    ``create_model_from_pretrained``, captures the output of the
    second-to-last visual transformer block with a forward hook, re-runs the
    last block through :meth:`extract_v`, and projects the resulting patch
    tokens with a 1x1 convolution whose weight is copied from
    ``model.visual.proj``.
    """

    def __init__(
            self,
            clip_model="ViT-B-16",
            pretrained="laion2b_s34b_b88k",
            patch_size=16,
            img_size=(224, 224),
            in_channels=768,
            text_channels=512,
    ):
        """Load the backbone and set up the hook and projection layer.

        Args:
            clip_model: Model name passed to ``create_model_from_pretrained``
                and ``get_tokenizer``.
            pretrained: Pretrained tag passed to
                ``create_model_from_pretrained``.
            patch_size: Divisor used to turn image height/width into the patch
                grid size.
            img_size: ``(height, width)`` used, together with ``patch_size``,
                to derive the grid of the stored positional embedding.
            in_channels: Input channels of the 1x1 projection convolution.
            text_channels: Output channels of the 1x1 projection convolution.
                NOTE(review): the convolution weight is replaced right away by
                ``model.visual.proj`` transposed, so these two values
                presumably have to match that matrix's shape - confirm.
        """
        super().__init__()

        self.patch_size = patch_size
        self.img_size = img_size
        model, _ = create_model_from_pretrained(clip_model, pretrained=pretrained)
        model.eval()
        self.clip_T = OPENAI_NORMALIZE
        # Filled by the forward hook below on every backbone forward pass.
        self.hook_features = {}
        self.backbone = model

        def hook_fn_forward(_module, _inputs, output):
            # Keep the output of the hooked block for extract_feat.
            self.hook_features["v"] = output

        self.backbone.visual.transformer.resblocks[-2].register_forward_hook(hook_fn_forward)
        # Private copy of the positional embedding; extract_feat resizes from
        # this copy and restores the backbone's embedding from it afterwards.
        self._positional_embd = nn.Parameter(self.backbone.visual.positional_embedding.data.clone())
        self.proj = nn.Conv2d(in_channels, text_channels, 1, bias=False)
        # Reuse the backbone's visual projection matrix as 1x1 conv weights.
        self.proj.weight = nn.Parameter(model.visual.proj.t()[:, :, None, None])
        self.tokenizer = get_tokenizer(clip_model)

    @torch.no_grad()
    def extract_feat(self, inputs: Tensor) -> Tensor:
        """Extract a patch-level feature map from a batch of images.

        The backbone's positional embedding is temporarily replaced by a copy
        interpolated to the input's patch grid and is always restored before
        this method returns, including when the forward pass raises.

        Args:
            inputs: 4-D image batch; unpacked as ``(B, C, H, W)``.

        Returns:
            Tensor reshaped to ``(B, D, H // patch_size, W // patch_size)``,
            where ``D`` is whatever channel size the backbone produces.

        Raises:
            ValueError: If the length of the backbone's positional embedding
                is not ``(img_size[0] // patch_size) * (img_size[1] //
                patch_size) + 1``.
        """
        visual = self.backbone.visual
        pos_embed = visual.positional_embedding

        B, _, H, W = inputs.shape
        hw_shape = (H // self.patch_size, W // self.patch_size)
        x_len, pos_len = hw_shape[0] * hw_shape[1], pos_embed.shape[0]

        # NOTE(review): x_len counts patches only, while the check below
        # expects pos_len to carry one extra entry, so this branch looks like
        # it is taken for every input size - confirm that is intended.
        if x_len != pos_len:
            pos_h = self.img_size[0] // self.patch_size
            pos_w = self.img_size[1] // self.patch_size
            if pos_len != pos_h * pos_w + 1:
                raise ValueError(
                    'Unexpected positional embedding length: input has {} patches, '
                    'positional embedding has {} entries (expected {})'.format(
                        x_len, pos_len, pos_h * pos_w + 1))

            visual.positional_embedding.data = self.resize_pos_embed(
                self._positional_embd[None], hw_shape, (pos_h, pos_w), 'bicubic')[0]

        try:
            # Only run for its side effect: the hook stores the block output.
            _ = self.backbone(inputs)
            v = self.hook_features["v"]
            v = self.extract_v(v, visual.transformer.resblocks[-1]).permute(1, 0, 2)
            v = visual.ln_post(v)
            # v = v[:, 1:]  # was there in original code
            v = v.permute(1, 0, 2)[:, 1:]  # put this as per https://github.com/wysoczanska/clip_dinoiser/issues/10
            v = v.reshape(B, hw_shape[0], hw_shape[1], -1).permute(0, 3, 1, 2).contiguous()
        finally:
            # Restore even on failure; otherwise the next call would see the
            # resized embedding and fail the length check above.
            visual.positional_embedding.data = self._positional_embd
        return v

    @torch.no_grad()
    def extract_v(self, x, block):
        """Run ``block`` on ``x`` using only the value path of its attention.

        ``x`` is normalised with ``block.ln_1``, projected with the
        attention's ``in_proj`` parameters, split into three equal chunks, and
        each chunk is passed through ``out_proj``. Only the last chunk is
        kept; the residual ``x`` and the block's MLP branch are then added.

        Args:
            x: 3-D tensor fed to ``block``.
                NOTE(review): the meaning of its first two axes depends on the
                backbone implementation - confirm against the hooked block.
            block: Transformer block exposing ``ln_1``, ``attn``, ``ln_2`` and
                ``mlp``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        y = block.ln_1(x)
        y = F.linear(y, block.attn.in_proj_weight, block.attn.in_proj_bias)
        B, N, C = y.shape
        # Stack the three projected chunks along the first axis.
        y = y.view(B, N, 3, C // 3).permute(2, 0, 1, 3).reshape(3 * B, N, C // 3)
        y = F.linear(y, block.attn.out_proj.weight, block.attn.out_proj.bias)
        _, _, v = y.tensor_split(3, dim=0)
        v += x
        v += block.mlp(block.ln_2(v))
        return v

    @staticmethod
    def resize_pos_embed(pos_embed, input_shpae, pos_shape, mode):
        """Resize positional embedding weights by interpolation.

        The first entry along dimension 1 is kept untouched and prepended to
        the interpolated grid entries.

        Args:
            pos_embed (torch.Tensor): Position embedding weights of shape
                [B, L, C]. The grid part is reshaped with a leading dimension
                of 1, so B is expected to be 1.
            input_shpae (tuple): Target grid size (height, width) passed as
                ``size`` to ``F.interpolate``. The misspelled name is kept
                because it is part of the call signature.
            pos_shape (tuple): Grid size (height, width) of the stored
                embedding.
            mode (str): Interpolation mode forwarded to ``F.interpolate``.
                It is required; there is no default. ``align_corners=False``
                is always passed along with it.

        Return:
            torch.Tensor: The resized pos_embed of shape [B, L_new, C]
        """
        assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
        pos_h, pos_w = pos_shape
        cls_token_weight = pos_embed[:, 0]
        # The last pos_h * pos_w entries are the grid part of the embedding.
        pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
        pos_embed_weight = pos_embed_weight.reshape(
            1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
        pos_embed_weight = F.interpolate(
            pos_embed_weight, size=input_shpae, align_corners=False, mode=mode)
        cls_token_weight = cls_token_weight.unsqueeze(1)
        pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
        pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1)
        return pos_embed

    @torch.no_grad()
    def decode_head(self, x: Tensor) -> Tensor:
        """Apply the 1x1 projection convolution ``self.proj`` to ``x``."""
        feat = self.proj(x)
        return feat

    @torch.no_grad()
    def forward(self, inputs: Tensor) -> Tensor:
        """Normalise images, extract patch features and project them.

        Args:
            inputs: Image batch; normalised here with ``self.clip_T``.

        Returns:
            Projected feature map at patch resolution, i.e. with spatial size
            ``(H // patch_size, W // patch_size)``, not the input size.
        """
        inputs = self.clip_T(inputs)
        x = self.extract_feat(inputs)
        feats = self.decode_head(x)
        return feats

    @torch.no_grad()
    def get_classifier(self, classnames: List[str]) -> Tensor:
        """Build one L2-normalised text embedding per class name.

        Args:
            classnames: Class names, each embedded with :meth:`_embed_label`.

        Returns:
            Stacked, normalised embeddings with one row per class name.
        """
        aug_embeddings = torch.stack([self._embed_label(label) for label in classnames])
        aug_embeddings = aug_embeddings / aug_embeddings.norm(dim=-1, keepdim=True)
        # NOTE(review): squeeze(1) only has an effect when dimension 1 has
        # size 1; kept as in the original.
        return aug_embeddings.squeeze(1)

    @torch.no_grad()
    def _embed_label(self, label: str) -> Tensor:
        """Encode a label name into a single vector.

        The label is inserted into every template of ``imagenet_templates``,
        each prompt is tokenised and encoded with the backbone's text encoder,
        the encodings are L2-normalised and then averaged over the templates.
        """
        all_prompts = [self.tokenizer(template.format(label)) for template in imagenet_templates]
        all_prompts = torch.cat(all_prompts)
        # Use the device of a backbone parameter as the target device.
        all_prompts = all_prompts.to(self.backbone.visual.positional_embedding.device)
        out = self.backbone.encode_text(all_prompts)
        out /= out.norm(dim=-1, keepdim=True)
        out = out.mean(dim=0)
        return out