import torch
import torch.nn as nn
from torchvision import transforms as T
from omegaconf import OmegaConf
from typing import List
from mmseg import datasets as mmseg_datasets
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
import numpy as np
from PIL import Image
from detectron2.data import MetadataCatalog
from detectron2.utils.visualizer import Visualizer
# TCL
from models import build_model
from models.tcl.pamr import PAMR
from datasets.builder import build_text_transform
from segmentation.evaluation.builder import build_dataset_class_tokens

# color palette (PASCAL VOC + COCO-Stuff), repeated so long custom class lists still get colors
PALETTE = mmseg_datasets.PascalVOCDataset.PALETTE + mmseg_datasets.COCOStuffDataset.PALETTE
PALETTE *= 5


def build_demo_model(ckpt_path="./tcl.pth", size=224):
    # Load TCL model
    print(f"Load {ckpt_path} ...")
    ckpt = torch.load(ckpt_path)
    cfg = OmegaConf.load("./tcl/configs/tcl.yml")
    model = build_model(cfg.model)
    # The (minimal) checkpoint only contains the learned parameters;
    # the frozen CLIP parameters are not included, hence strict=False.
    model.load_state_dict(ckpt['model'], strict=False)
    model.eval()

    # build TCLDemo
    demo = TCLDemo(model, size)

    return demo


def _convert_image_to_rgb(image):
    return image.convert("RGB")


def _transform(n_px):
    return T.Compose([
        T.Resize(n_px, interpolation=T.InterpolationMode.BICUBIC),
        _convert_image_to_rgb,
        T.ToTensor(),
        T.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
    ])


class TCLDemo(nn.Module):
    """
    Args:
        model: TCL model
        size: resize shorter side of image to `size`
    """
    def __init__(self, model, size=224):
        super().__init__()
        self.model = model
        self.size = size
        self.preprocess = _transform(size)
        self.tokenizer = build_text_transform()
        # PAMR mask refinement: 10 iterations over multi-scale dilations
        self.pamr = PAMR(10, [1, 2, 4, 8, 12, 24]).eval()

    @property
    def device(self):
        return next(self.model.parameters()).device

    def build_text_embedding(self, texts: List[str]):
        text_tokens = build_dataset_class_tokens(self.tokenizer, "custom", texts)
        text_embeddings = self.model.build_text_embedding(text_tokens)
        return text_embeddings

    def forward(self, image, texts: List[str], apply_pamr=True):
        """
        Args:
            image: PIL.Image
            texts: List[str]; if the first item is "bg" or "background",
                a thresholded background channel is prepended to the output mask.
        """
        with_bg = False
        if texts[0] in ["bg", "background"]:
            with_bg = True
            texts = texts[1:]

        # preprocess
        image = self.preprocess(image).unsqueeze(0).to(self.device)
        text_embs = self.build_text_embedding(texts)

        # forward
        mask, simmap = self.model.generate_masks(
            image,
            text_embs,
        )

        # refinement
        if apply_pamr:
            mask = self.pamr(image, mask)

        I, T, H, W = mask.shape
        if with_bg:
            # prepend a constant background channel: pixels whose best class score
            # falls below bg_thresh are assigned to background by the later argmax
            bg_thresh = 0.4 if apply_pamr else 0.5
            bg = torch.full(
                [I, 1, H, W],
                bg_thresh,
                dtype=torch.float,
                device=mask.device
            )
            mask = torch.cat([bg, mask], dim=1)

        return mask

    def visualize(self, image, texts, mask):
        """
        Args:
            image (PIL.Image)
            texts (List[str])
            mask (Tensor)
        """
        with_bg = texts[0] in ["bg", "background"]
        N = len(texts)
        if with_bg:
            palette = PALETTE
        else:
            # skip the first palette color, which is reserved for background
            palette = PALETTE[1:]

        # register temporary detectron2 metadata for the custom class names/colors
        MetadataCatalog.pop("__unused", None)
        md = MetadataCatalog.get("__unused")
        md.set(
            thing_classes=texts,
            thing_colors=palette,
            stuff_classes=texts,
            stuff_colors=palette,
        )
        # per-pixel argmax over the (background +) class channels
        seg_res = mask.squeeze(0).argmax(0).cpu()
        if with_bg:
            # push background (index 0) outside the class range so it is left unpainted
            seg_res[seg_res == 0] = N + 10

        # resize the input image to the mask resolution (PIL expects (W, H))
        image = image.resize(mask.shape[2:][::-1])
        image = np.asarray(image)
        visualizer = Visualizer(image, md)

        r = visualizer.draw_sem_seg(seg_res)
        res = Image.fromarray(r.get_image())

        return res

    def forward_vis(self, image, texts, apply_pamr=True):
        mask = self(image, texts, apply_pamr=apply_pamr)
        res = self.visualize(image, texts, mask)
        return res
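

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original demo). It assumes the
    # checkpoint/config paths used above ("./tcl.pth", "./tcl/configs/tcl.yml");
    # the image path and class names below are placeholders.
    demo = build_demo_model(ckpt_path="./tcl.pth", size=224)
    # demo = demo.cuda()  # optional: move to GPU if available

    img = Image.open("example.jpg")  # hypothetical input image
    # Prepending "background" adds the thresholded background channel (see forward()).
    classes = ["background", "cat", "dog"]

    vis = demo.forward_vis(img, classes, apply_pamr=True)
    vis.save("tcl_demo_result.png")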