|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from typing import List, Tuple |
|
from torch import Tensor |
|
from open_clip import get_tokenizer, create_model_from_pretrained |
|
import torchvision.transforms as T |
|
from .utils import imagenet_templates |
|
|
|
# Per-channel mean / standard deviation normalisation; MaskClip.forward applies
# this transform to its input images before they reach the CLIP backbone.
OPENAI_NORMALIZE = T.Normalize(
    mean=(0.48145466, 0.4578275, 0.40821073),
    std=(0.26862954, 0.26130258, 0.27577711),
)
|
|
|
class MaskClip(nn.Module):
    """Dense image features and text classifiers from a pretrained CLIP model.

    The image path runs the open_clip visual backbone, captures the output of
    the second-to-last transformer block with a forward hook, and re-applies
    the last block using only its value projection (no query/key attention),
    so every patch token keeps its own feature. The features are then mapped
    into the text embedding space by a 1x1 convolution whose weights are
    copied from the backbone's visual projection.
    """

    def __init__(
        self,
        clip_model="ViT-B-16",
        pretrained="laion2b_s34b_b88k",
        patch_size=16,
        img_size=(224, 224),
        in_channels=768,
        text_channels=512,
    ):
        """Load the CLIP backbone and set up the hook and projection head.

        Args:
            clip_model (str): open_clip architecture name.
            pretrained (str): open_clip pretrained-weights tag.
            patch_size (int): Side length of one image patch in pixels.
            img_size (tuple): (height, width) used to derive the grid size of
                the stored positional embedding.
            in_channels (int): Input channels of the 1x1 projection head.
            text_channels (int): Output channels of the 1x1 projection head.
        """
        super().__init__()

        self.patch_size = patch_size
        self.img_size = img_size
        model, _ = create_model_from_pretrained(clip_model, pretrained=pretrained)
        model.eval()
        self.clip_T = OPENAI_NORMALIZE
        self.hook_features = {}
        self.backbone = model

        def hook_fn_forward(module, _inputs, output):
            # Keep the activations that feed the last transformer block.
            self.hook_features["v"] = output

        self.backbone.visual.transformer.resblocks[-2].register_forward_hook(hook_fn_forward)
        # Pristine copy of the positional embedding; extract_feat resizes from
        # this copy and restores the backbone's table from it afterwards.
        self._positional_embd = nn.Parameter(self.backbone.visual.positional_embedding.data.clone())
        self.proj = nn.Conv2d(in_channels, text_channels, 1, bias=False)
        # Reuse the backbone's visual projection matrix as 1x1 conv weights.
        self.proj.weight = nn.Parameter(model.visual.proj.t()[:, :, None, None])
        self.tokenizer = get_tokenizer(clip_model)

    @torch.no_grad()
    def extract_feat(self, inputs: Tensor) -> Tensor:
        """Extract a patch-level feature map from a batch of images.

        The backbone's positional embedding is temporarily resized to the
        patch grid of ``inputs`` and is always restored afterwards, even if
        the backbone forward pass raises.

        Args:
            inputs (Tensor): Images of shape (B, C, H, W).

        Returns:
            Tensor: Features of shape
            (B, D, H // patch_size, W // patch_size).

        Raises:
            ValueError: If a resize is needed but the stored positional
                embedding does not have ``grid_h * grid_w + 1`` entries for
                the grid derived from ``img_size`` and ``patch_size``.
        """
        visual = self.backbone.visual

        B, _, H, W = inputs.shape
        hw_shape = (H // self.patch_size, W // self.patch_size)
        x_len = hw_shape[0] * hw_shape[1]
        # Measure against the pristine copy so that a previously interrupted
        # call cannot leave a stale, resized length behind.
        pos_len = self._positional_embd.shape[0]

        # The embedding table carries one extra row (the first one, kept
        # as-is by resize_pos_embed), hence the +1 when comparing lengths.
        if x_len + 1 != pos_len:
            pos_h = self.img_size[0] // self.patch_size
            pos_w = self.img_size[1] // self.patch_size
            if pos_len != pos_h * pos_w + 1:
                raise ValueError(
                    '{}, {}'.format(x_len, pos_len))

            visual.positional_embedding.data = self.resize_pos_embed(
                self._positional_embd[None], hw_shape, (pos_h, pos_w), 'bicubic')[0]

        try:
            # Only the hooked activations are needed; the output is discarded.
            _ = self.backbone(inputs)
            v = self.hook_features["v"]
        finally:
            visual.positional_embedding.data = self._positional_embd.data

        v = self.extract_v(v, visual.transformer.resblocks[-1]).permute(1, 0, 2)
        v = visual.ln_post(v)

        # NOTE(review): this permute undoes the one above, so the slice drops
        # index 0 along axis 1. That is the class token only if the hooked
        # activations are batch-first (B, L, D) -- confirm against the
        # installed open_clip version.
        v = v.permute(1, 0, 2)[:, 1:]
        v = v.reshape(B, hw_shape[0], hw_shape[1], -1).permute(0, 3, 1, 2).contiguous()
        return v

    @torch.no_grad()
    def extract_v(self, x, block):
        """Apply a transformer block using only its value path.

        Query/key attention is skipped: each token's value projection goes
        straight through the attention output projection, followed by the
        block's usual residual connections and MLP.

        Args:
            x (Tensor): 3-D token activations with the channel axis last.
            block: Block exposing ``ln_1``, ``attn`` (with ``in_proj_weight``,
                ``in_proj_bias`` and ``out_proj``), ``ln_2`` and ``mlp``.

        Returns:
            Tensor: New tensor with the same shape as ``x``; ``x`` itself is
            not modified.
        """
        qkv = F.linear(block.ln_1(x), block.attn.in_proj_weight, block.attn.in_proj_bias)
        dim = qkv.shape[-1] // 3
        # The packed projection is laid out as [q | k | v] on the last axis;
        # only the value slice is used.
        v = F.linear(qkv[..., 2 * dim:], block.attn.out_proj.weight, block.attn.out_proj.bias)
        v = v + x
        v = v + block.mlp(block.ln_2(v))
        return v

    @staticmethod
    def resize_pos_embed(pos_embed, input_shpae, pos_shape, mode):
        """Resize positional-embedding weights to a new patch grid.

        The first entry along the token axis is kept unchanged; the trailing
        ``pos_h * pos_w`` entries are interpolated as a 2-D grid.

        Args:
            pos_embed (torch.Tensor): Position embedding weights of shape
                [B, L, C]. The grid entries are reshaped with a leading
                dimension of 1, so B is expected to be 1.
            input_shpae (tuple): Tuple for (downsampled input image height,
                downsampled input image width).
            pos_shape (tuple): The resolution of downsampled origin training
                image.
            mode (str): Interpolation algorithm forwarded to
                ``F.interpolate`` for the 4-D grid, e.g. ``'nearest'`` |
                ``'bilinear'`` | ``'bicubic'``.
        Return:
            torch.Tensor: The resized pos_embed of shape [B, L_new, C]
        """
        assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
        pos_h, pos_w = pos_shape
        cls_token_weight = pos_embed[:, :1]
        grid = pos_embed[:, (-1 * pos_h * pos_w):]
        grid = grid.reshape(1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
        # F.interpolate rejects an explicit align_corners for modes that do
        # not interpolate, so only pass it where it is meaningful.
        align_corners = None if mode in ('nearest', 'nearest-exact', 'area') else False
        grid = F.interpolate(
            grid, size=input_shpae, align_corners=align_corners, mode=mode)
        grid = torch.flatten(grid, 2).transpose(1, 2)
        return torch.cat((cls_token_weight, grid), dim=1)

    @torch.no_grad()
    def decode_head(self, x: Tensor) -> Tensor:
        """Project a visual feature map with the 1x1 convolution ``proj``."""
        feat = self.proj(x)

        return feat

    @torch.no_grad()
    def forward(self, inputs: Tensor) -> Tensor:
        """Normalise images, extract patch features and project them.

        Args:
            inputs (Tensor): Images of shape (B, C, H, W), before the
                ``clip_T`` normalisation is applied.

        Returns:
            Tensor: Projected feature map at patch resolution (one location
            per ``patch_size`` x ``patch_size`` patch), not at input size.
        """
        inputs = self.clip_T(inputs)
        x = self.extract_feat(inputs)
        feats = self.decode_head(x)
        return feats

    @torch.no_grad()
    def get_classifier(self, classnames: List[str]) -> Tensor:
        """Build L2-normalised text embeddings, one row per class name.

        Args:
            classnames (List[str]): Class names to embed.

        Returns:
            Tensor: Stacked embeddings in the order of ``classnames``.
        """
        aug_embeddings = torch.stack([self._embed_label(label) for label in classnames])
        aug_embeddings = aug_embeddings / aug_embeddings.norm(dim=-1, keepdim=True)
        # squeeze(1) only has an effect when dimension 1 has size 1.
        return aug_embeddings.squeeze(1)

    @torch.no_grad()
    def _embed_label(self, label: str) -> Tensor:
        """Encode label name into a single vector.

        Every template in ``imagenet_templates`` is filled with ``label``,
        encoded by the text encoder and L2-normalised; the result is the mean
        over templates (not re-normalised here).
        """
        all_prompts = [self.tokenizer(template.format(label)) for template in imagenet_templates]
        all_prompts = torch.cat(all_prompts)
        # Put the tokens on the same device as the backbone's weights.
        all_prompts = all_prompts.to(self.backbone.visual.positional_embedding.device)
        out = self.backbone.encode_text(all_prompts)
        out /= out.norm(dim=-1, keepdim=True)
        out = out.mean(dim=0)
        return out
|
|