# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms

from .utils import convert_to_pil


class REBNCONV(nn.Module):
    """3x3 (optionally dilated) convolution followed by BatchNorm and ReLU.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        dirate: Dilation rate; also used as the padding so that the spatial
            size of the input is preserved.
    """

    def __init__(self, in_ch=3, out_ch=3, dirate=1):
        super().__init__()
        self.conv_s1 = nn.Conv2d(
            in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate)
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply conv -> batch norm -> ReLU to ``x``."""
        return self.relu_s1(self.bn_s1(self.conv_s1(x)))


def _upsample_like(src, tar):
    """Bilinearly resize ``src`` to the spatial size (dims 2+) of ``tar``.

    Uses ``F.interpolate`` instead of the deprecated ``F.upsample``;
    ``align_corners=False`` is the value the deprecated call resolved to, so
    the numerical result is unchanged.
    """
    return F.interpolate(
        src, size=tar.shape[2:], mode='bilinear', align_corners=False)


class RSU7(nn.Module):
    """Residual U-shaped block with five pooling steps in its encoder.

    The input is first mapped to ``out_ch`` channels; an inner
    encoder/decoder with ``mid_ch`` channels and skip concatenations is then
    applied, and its output is added to the mapped input.

    Args:
        in_ch: Number of input channels.
        mid_ch: Number of channels used inside the block.
        out_ch: Number of output channels.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder path. Attribute names are kept stable because they are the
        # keys used when loading a pretrained state dict.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Bottleneck uses dilation instead of another pooling step.
        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder path; each stage consumes a concatenation (2 * mid_ch).
        self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Run the block and return ``decoder_output + mapped_input``."""
        hxin = self.rebnconvin(x)

        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(self.pool1(hx1))
        hx3 = self.rebnconv3(self.pool2(hx2))
        hx4 = self.rebnconv4(self.pool3(hx3))
        hx5 = self.rebnconv5(self.pool4(hx4))
        hx6 = self.rebnconv6(self.pool5(hx5))
        hx7 = self.rebnconv7(hx6)

        hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
        hx6dup = _upsample_like(hx6d, hx5)
        hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)
        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)
        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)
        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)
        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin


class RSU6(nn.Module):
    """Residual U-shaped block with four pooling steps in its encoder.

    Args:
        in_ch: Number of input channels.
        mid_ch: Number of channels used inside the block.
        out_ch: Number of output channels.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Dilated bottleneck.
        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Run the block and return ``decoder_output + mapped_input``."""
        hxin = self.rebnconvin(x)

        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(self.pool1(hx1))
        hx3 = self.rebnconv3(self.pool2(hx2))
        hx4 = self.rebnconv4(self.pool3(hx3))
        hx5 = self.rebnconv5(self.pool4(hx4))
        hx6 = self.rebnconv6(hx5)

        hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)
        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)
        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)
        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)
        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin


class RSU5(nn.Module):
    """Residual U-shaped block with three pooling steps in its encoder.

    Args:
        in_ch: Number of input channels.
        mid_ch: Number of channels used inside the block.
        out_ch: Number of output channels.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Dilated bottleneck.
        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Run the block and return ``decoder_output + mapped_input``."""
        hxin = self.rebnconvin(x)

        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(self.pool1(hx1))
        hx3 = self.rebnconv3(self.pool2(hx2))
        hx4 = self.rebnconv4(self.pool3(hx3))
        hx5 = self.rebnconv5(hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)
        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)
        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)
        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin


class RSU4(nn.Module):
    """Residual U-shaped block with two pooling steps in its encoder.

    Args:
        in_ch: Number of input channels.
        mid_ch: Number of channels used inside the block.
        out_ch: Number of output channels.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Dilated bottleneck.
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Run the block and return ``decoder_output + mapped_input``."""
        hxin = self.rebnconvin(x)

        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(self.pool1(hx1))
        hx3 = self.rebnconv3(self.pool2(hx2))
        hx4 = self.rebnconv4(hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)
        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)
        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin


class RSU4F(nn.Module):
    """Residual U-shaped block without pooling.

    All layers operate at the input resolution; increasing dilation rates
    (1, 2, 4, 8) are used in place of downsampling.

    Args:
        in_ch: Number of input channels.
        mid_ch: Number of channels used inside the block.
        out_ch: Number of output channels.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)

        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Run the block and return ``decoder_output + mapped_input``."""
        hxin = self.rebnconvin(x)

        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(hx1)
        hx3 = self.rebnconv3(hx2)
        hx4 = self.rebnconv4(hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
        hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))

        return hx1d + hxin


class U2NET(nn.Module):
    """Nested U-shaped network assembled from RSU blocks.

    Six encoder stages and five decoder stages are connected with skip
    concatenations. Each decoder stage (plus the deepest encoder stage)
    yields a side output; the six side outputs are fused by a 1x1 conv.

    Args:
        in_ch: Number of input image channels.
        out_ch: Number of channels of every output map.
    """

    def __init__(self, in_ch=3, out_ch=1):
        super().__init__()

        # encoder
        self.stage1 = RSU7(in_ch, 32, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage6 = RSU4F(512, 256, 512)

        # decoder
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)

        # side outputs, one per decoder stage plus the deepest encoder stage
        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)

        # fuses the six side outputs
        self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)

    def forward(self, x):
        """Compute the fused map and the six side maps.

        Returns:
            A 7-tuple ``(d0, d1, ..., d6)`` of sigmoid-activated maps, where
            ``d0`` is the fused output and ``d1``..``d6`` are the side
            outputs, all resized to the spatial size of ``d1``.
        """
        # encoder
        hx1 = self.stage1(x)
        hx2 = self.stage2(self.pool12(hx1))
        hx3 = self.stage3(self.pool23(hx2))
        hx4 = self.stage4(self.pool34(hx3))
        hx5 = self.stage5(self.pool45(hx4))
        hx6 = self.stage6(self.pool56(hx5))
        hx6up = _upsample_like(hx6, hx5)

        # decoder
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)
        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)
        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)
        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)
        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))

        # side outputs, all brought to the resolution of d1
        d1 = self.side1(hx1d)
        d2 = _upsample_like(self.side2(hx2d), d1)
        d3 = _upsample_like(self.side3(hx3d), d1)
        d4 = _upsample_like(self.side4(hx4d), d1)
        d5 = _upsample_like(self.side5(hx5d), d1)
        d6 = _upsample_like(self.side6(hx6), d1)

        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))

        return (torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2),
                torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5),
                torch.sigmoid(d6))


class SalientAnnotator:
    """Saliency annotator backed by a pretrained ``U2NET``.

    Args:
        cfg: Mapping with the following keys:
            ``PRETRAINED_MODEL`` (required): path of the state-dict
            checkpoint; ``RETURN_IMAGE`` (default ``False``): default for
            ``forward``'s ``return_image``; ``USE_CROP`` (default
            ``False``): crop image and mask to the mask's bounding box;
            ``NORM_SIZE`` (default ``[320, 320]``): size the input is
            resized to before inference.
        device: Device to run on. When ``None``, CUDA is used if available,
            otherwise the CPU.
    """

    def __init__(self, cfg, device=None):
        self.return_image = cfg.get('RETURN_IMAGE', False)
        self.use_crop = cfg.get('USE_CROP', False)
        pretrained_model = cfg['PRETRAINED_MODEL']
        if device is None:
            device = torch.device(
                "cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        self.norm_mean = [0.485, 0.456, 0.406]
        self.norm_std = [0.229, 0.224, 0.225]
        self.norm_size = cfg.get('NORM_SIZE', [320, 320])
        self.model = U2NET(3, 1)
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        self.model.load_state_dict(
            torch.load(pretrained_model, map_location='cpu'))
        self.model = self.model.to(self.device).eval()
        self.transform_input = transforms.Compose([
            transforms.Resize(self.norm_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.norm_mean, std=self.norm_std)
        ])

    def forward(self, image, return_image=None):
        """Compute the saliency map of ``image``.

        Args:
            image: Input image; passed through ``convert_to_pil`` first.
            return_image: Overrides the configured ``RETURN_IMAGE`` when not
                ``None``.

        Returns:
            When ``return_image`` is falsy, a ``uint8`` saliency map resized
            to the original image size. Otherwise a dict with ``"image"``
            (the input with non-salient pixels set to white) and ``"mask"``
            (a 0/255 ``uint8`` mask), both optionally cropped to the mask's
            bounding box.
        """
        if return_image is None:
            return_image = self.return_image
        image = convert_to_pil(image)
        img_w, img_h = image.size
        input_image = self.transform_input(image).float().unsqueeze(0)
        input_image = input_image.to(self.device)
        with torch.no_grad():
            results = self.model(input_image)
            # Fused output (d0) of the first batch item, first channel.
            data = results[0][0, 0, :, :]
            data_min = torch.min(data)
            value_range = torch.max(data) - data_min
            if value_range > 0:
                data_norm = (data - data_min) / value_range
            else:
                # A constant map would divide by zero and yield NaNs that
                # are undefined once cast to uint8; treat it as "nothing
                # salient".
                data_norm = torch.zeros_like(data)
            data_norm_np = (data_norm.cpu().numpy() * 255).astype('uint8')
            data_norm_rst = cv2.resize(data_norm_np, (img_w, img_h))

        if not return_image:
            return data_norm_rst

        image_np = np.array(image)
        # Every pixel with saliency above 1 (of 255) counts as foreground.
        _, binary_mask = cv2.threshold(
            data_norm_rst, 1, 255, cv2.THRESH_BINARY)
        white_bg = np.full_like(image_np, 255)
        # NOTE(review): assumes image_np is (H, W, C) — confirm that
        # convert_to_pil always yields a multi-channel image.
        ret_image = np.where(binary_mask[:, :, np.newaxis] == 255,
                             image_np, white_bg).astype(np.uint8)
        ret_mask = np.where(binary_mask, 255, 0).astype(np.uint8)
        if self.use_crop:
            # NOTE(review): an all-zero mask gives a (0, 0, 0, 0) box and
            # therefore empty crops — confirm callers handle that.
            x, y, w, h = cv2.boundingRect(binary_mask)
            ret_image = ret_image[y:y + h, x:x + w]
            ret_mask = ret_mask[y:y + h, x:x + w]
        return {"image": ret_image, "mask": ret_mask}


class SalientVideoAnnotator(SalientAnnotator):
    """Applies :class:`SalientAnnotator` to every frame of a sequence."""

    def forward(self, frames, return_image=None):
        """Annotate each frame and return the results as a list.

        Args:
            frames: Iterable of frames, each accepted by
                ``SalientAnnotator.forward``.
            return_image: Forwarded to ``SalientAnnotator.forward`` for each
                frame (previously this argument was silently ignored).

        Returns:
            List with one annotation per frame, in input order.
        """
        # Bind once: zero-argument super() is not usable inside a
        # comprehension scope on older Python versions.
        annotate = super().forward
        return [annotate(frame, return_image=return_image)
                for frame in frames]