# LPOSS / models/maskclip.py
# Author: stojnvla — commit 06d49db ("initial commit")
# ---------------------------------------------------------------------------------------------------
# CLIP-DINOiser
# authors: Monika Wysoczanska, Warsaw University of Technology
# Copyright (c) OpenMMLab. All rights reserved.
# Modified version of the original MaskCLIP code: https://github.com/chongzhou96/MaskCLIP/tree/master
# ---------------------------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple
from torch import Tensor
from open_clip import get_tokenizer, create_model_from_pretrained
import torchvision.transforms as T
from .utils import imagenet_templates
# Per-channel (RGB) mean/std normalization applied to images before they
# reach the CLIP backbone (MaskClip stores it as ``clip_T`` and applies it in
# ``forward``).
OPENAI_NORMALIZE = T.Normalize(
    mean=(0.48145466, 0.4578275, 0.40821073),
    std=(0.26862954, 0.26130258, 0.27577711),
)
class MaskClip(nn.Module):
    """Dense (per-patch) CLIP features following the MaskCLIP recipe.

    Wraps an ``open_clip`` model. Per-patch features are taken from the value
    path of the last visual transformer block (the query/key attention mixing
    is skipped), passed through ``ln_post`` and projected with the visual
    projection, which is implemented here as a 1x1 convolution.

    Args:
        clip_model: ``open_clip`` model name (also selects the tokenizer).
        pretrained: ``open_clip`` pretrained-weights tag.
        patch_size: ViT patch size in pixels.
        img_size: (H, W) resolution whose patch grid matches the backbone's
            positional embedding.
        in_channels: input channels of the 1x1 projection; expected to match
            the visual transformer width.
        text_channels: output channels of the 1x1 projection; expected to
            match the joint image-text embedding size.
    """

    def __init__(
        self,
        clip_model="ViT-B-16",
        pretrained="laion2b_s34b_b88k",
        patch_size=16,
        img_size=(224, 224),
        in_channels=768,
        text_channels=512,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.img_size = img_size
        model, _ = create_model_from_pretrained(clip_model, pretrained=pretrained)
        model.eval()
        self.clip_T = OPENAI_NORMALIZE
        self.hook_features = {}
        self.backbone = model

        def hook_fn_forward(_module, _inputs, output):
            # Output of the penultimate block == input of the last block,
            # which extract_v re-runs along its value path only.
            self.hook_features["v"] = output

        self._hook_handle = self.backbone.visual.transformer.resblocks[-2].register_forward_hook(
            hook_fn_forward
        )
        # Pristine copy of the positional embedding. extract_feat temporarily
        # swaps the backbone's embedding for a resized one and restores this.
        self._positional_embd = nn.Parameter(
            self.backbone.visual.positional_embedding.data.clone()
        )
        # Visual projection as a 1x1 conv: weight is proj^T viewed as
        # (text_channels, in_channels, 1, 1).
        self.proj = nn.Conv2d(in_channels, text_channels, 1, bias=False)
        self.proj.weight = nn.Parameter(model.visual.proj.t()[:, :, None, None])
        self.tokenizer = get_tokenizer(clip_model)

    @torch.no_grad()
    def extract_feat(self, inputs: Tensor) -> Tensor:
        """Extract per-patch MaskCLIP features from a normalized image batch.

        If the input's patch grid differs from the grid implied by
        ``img_size``, the backbone's positional embedding is temporarily
        replaced by a bicubically resized copy. It is restored afterwards,
        even if the backbone forward raises.

        Args:
            inputs: images of shape (B, C, H, W), already normalized.

        Returns:
            Tensor of shape (B, C', H // patch_size, W // patch_size), where
            C' is the channel size of the hooked transformer tokens.

        Raises:
            ValueError: if the stored positional embedding does not hold
                ``(img_size[0] // patch_size) * (img_size[1] // patch_size) + 1``
                entries (patch grid plus the class token).
        """
        visual = self.backbone.visual
        B, _, H, W = inputs.shape
        hw_shape = (H // self.patch_size, W // self.patch_size)
        pos_h = self.img_size[0] // self.patch_size
        pos_w = self.img_size[1] // self.patch_size
        pos_len = self._positional_embd.shape[0]
        if pos_len != pos_h * pos_w + 1:
            raise ValueError(
                f"positional embedding has {pos_len} entries, expected "
                f"{pos_h * pos_w + 1} for img_size={self.img_size} and "
                f"patch_size={self.patch_size}"
            )
        # Compare grids rather than token counts. The embedding also contains
        # the class token. A grid with the same patch count but a different
        # aspect ratio still needs resizing.
        if hw_shape != (pos_h, pos_w):
            visual.positional_embedding.data = self.resize_pos_embed(
                self._positional_embd[None], hw_shape, (pos_h, pos_w), 'bicubic')[0]
        try:
            # Run only for the forward hook's side effect (fills hook_features).
            self.backbone(inputs)
        finally:
            visual.positional_embedding.data = self._positional_embd
        v = self.hook_features["v"]
        v = self.extract_v(v, visual.transformer.resblocks[-1])
        # ln_post normalizes over the last (channel) dim only, so it can be
        # applied directly in the hooked token layout.
        v = visual.ln_post(v)
        # Drop the class token. NOTE(review): assumes the resblocks run
        # batch-first, i.e. v is (B, 1 + h*w, C'), as per
        # https://github.com/wysoczanska/clip_dinoiser/issues/10; a
        # sequence-first transformer would make this slice the wrong axis.
        v = v[:, 1:]
        v = v.reshape(B, hw_shape[0], hw_shape[1], -1).permute(0, 3, 1, 2).contiguous()
        return v

    @torch.no_grad()
    def extract_v(self, x, block):
        """Run ``block`` on ``x`` keeping only the value path of its attention.

        The attention output is replaced by ``out_proj(W_v @ ln_1(x) + b_v)``,
        with no query/key mixing. The block's residual connection and MLP are
        then applied as usual.

        Args:
            x: block input tokens; the last dimension is the embedding width.
            block: transformer block exposing ``ln_1``, ``attn`` (with packed
                ``in_proj_weight``/``in_proj_bias`` and ``out_proj``),
                ``ln_2`` and ``mlp``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        attn = block.attn
        width = attn.in_proj_weight.shape[0] // 3
        # in_proj packs [W_q; W_k; W_v] along dim 0; only the value rows are used.
        w_v = attn.in_proj_weight[2 * width:]
        b_v = None if attn.in_proj_bias is None else attn.in_proj_bias[2 * width:]
        v = F.linear(block.ln_1(x), w_v, b_v)
        v = F.linear(v, attn.out_proj.weight, attn.out_proj.bias)
        v = v + x
        v = v + block.mlp(block.ln_2(v))
        return v

    @staticmethod
    def resize_pos_embed(pos_embed, input_shpae, pos_shape, mode):
        """Resize positional-embedding weights to a new patch grid.

        The first entry (class token) is kept unchanged. The last
        ``pos_h * pos_w`` entries are reshaped to a ``(pos_h, pos_w)`` grid,
        interpolated to ``input_shpae`` and flattened back.

        Args:
            pos_embed (torch.Tensor): embedding weights of shape [B, L, C].
            input_shpae (tuple): target grid (height, width) in patches.
                The misspelled name is kept for keyword callers.
            pos_shape (tuple): grid (height, width) the weights correspond to.
            mode (str): ``F.interpolate`` mode for 4-D input, e.g.
                ``'nearest'``, ``'bilinear'``, ``'bicubic'`` or ``'area'``.
                There is no default.

        Returns:
            torch.Tensor: resized embedding of shape
            [B, 1 + input_shpae[0] * input_shpae[1], C].
        """
        assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
        pos_h, pos_w = pos_shape
        batch, _, channels = pos_embed.shape
        cls_token_weight = pos_embed[:, :1]
        grid = pos_embed[:, -pos_h * pos_w:]
        grid = grid.reshape(batch, pos_h, pos_w, channels).permute(0, 3, 1, 2)
        # F.interpolate rejects align_corners for 'nearest'/'area' modes.
        align_corners = False if mode in ('linear', 'bilinear', 'bicubic', 'trilinear') else None
        grid = F.interpolate(grid, size=input_shpae, mode=mode, align_corners=align_corners)
        grid = grid.flatten(2).transpose(1, 2)
        return torch.cat((cls_token_weight, grid), dim=1)

    @torch.no_grad()
    def decode_head(self, x: Tensor) -> Tensor:
        """Project patch features channel-wise with the 1x1 visual projection."""
        return self.proj(x)

    @torch.no_grad()
    def forward(self, inputs: Tensor) -> Tensor:
        """Normalize images and return projected per-patch features.

        Args:
            inputs: images of shape (B, 3, H, W), not yet normalized.

        Returns:
            Feature map of shape (B, text_channels, H // patch_size,
            W // patch_size); this is patch resolution, not input resolution.
        """
        inputs = self.clip_T(inputs)
        x = self.extract_feat(inputs)
        return self.decode_head(x)

    @torch.no_grad()
    def get_classifier(self, classnames: List[str]) -> Tensor:
        """Build one L2-normalized text embedding per class name.

        Returns:
            Tensor stacking one embedding per entry of ``classnames``
            (K entries), each re-normalized to unit length.
        """
        aug_embeddings = torch.stack([self._embed_label(label) for label in classnames])
        aug_embeddings = aug_embeddings / aug_embeddings.norm(dim=-1, keepdim=True)
        # squeeze(1) only has an effect when dim 1 has size 1; kept for
        # backward compatibility.
        return aug_embeddings.squeeze(1)

    @torch.no_grad()
    def _embed_label(self, label: str) -> Tensor:
        """Encode ``label`` as the mean of its per-template text embeddings.

        Every template in ``imagenet_templates`` is filled with ``label``. Each
        prompt embedding is L2-normalized before averaging; the mean itself
        is not re-normalized here.
        """
        prompts = [template.format(label) for template in imagenet_templates]
        device = self.backbone.visual.positional_embedding.device
        # Tokenize all prompts in one call (one row per prompt).
        tokens = self.tokenizer(prompts).to(device)
        out = self.backbone.encode_text(tokens)
        out = out / out.norm(dim=-1, keepdim=True)
        return out.mean(dim=0)