|
|
|
|
|
|
|
import cv2 |
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torchvision import transforms |
|
|
|
from .utils import convert_to_pil |
|
|
|
|
|
class REBNCONV(nn.Module):
    """Dilated 3x3 convolution followed by BatchNorm2d and an in-place ReLU.

    The padding is set equal to the dilation rate, so with a 3x3 kernel the
    spatial size of the input is preserved; only the channel count changes.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        dirate: dilation rate of the convolution (also used as its padding).
    """

    def __init__(self, in_ch=3, out_ch=3, dirate=1):
        super().__init__()
        # Attribute names become state_dict keys and the creation order fixes
        # parameter-initialisation order; keep both stable so saved weights load.
        self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=dirate, dilation=dirate)
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        features = self.conv_s1(x)
        features = self.bn_s1(features)
        # The in-place ReLU only overwrites the BN output, never the caller's x.
        return self.relu_s1(features)
|
|
|
|
|
def _upsample_like(src, tar):
    """Bilinearly resize ``src`` to the spatial size of ``tar``.

    Args:
        src: tensor to resize; bilinear mode requires it to be 4-D
            (N, C, H, W).
        tar: tensor whose trailing spatial dimensions (``tar.shape[2:]``) give
            the target size. Only its shape is read.

    Returns:
        ``src`` resized to ``tar.shape[2:]`` using bilinear interpolation.
    """
    # F.upsample is deprecated in favour of F.interpolate. Passing
    # align_corners=False explicitly reproduces the previous default and
    # avoids the runtime warning emitted on every call.
    return F.interpolate(src, size=tar.shape[2:], mode='bilinear', align_corners=False)
|
|
|
|
|
class RSU7(nn.Module):
    """Residual U-shaped block with seven convolution levels.

    Structure (every conv is a REBNCONV: 3x3 conv -> BN -> ReLU):
      * ``rebnconvin`` maps ``in_ch`` -> ``out_ch``. Its output is kept as the
        residual branch.
      * Encoder: ``rebnconv1`` .. ``rebnconv6`` at ``mid_ch`` channels, with a
        2x2 max-pool (``ceil_mode=True``) between consecutive levels.
      * Bottom: ``rebnconv7`` uses dilation 2 instead of another pool.
      * Decoder: each ``rebnconvNd`` takes the concatenation of the deeper
        decoder feature and the matching encoder feature, hence
        ``mid_ch * 2`` input channels. ``rebnconv1d`` returns ``out_ch``.
      * The decoder output is added to the residual branch.

    Expects a 4-D input (N, in_ch, H, W). Every REBNCONV preserves spatial
    size and the decoder upsamples back to ``hx1``'s size, so the output is
    (N, out_ch, H, W).

    NOTE: submodule attribute names define the state_dict keys, and their
    construction order fixes parameter-initialisation order.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU7, self).__init__()
        # Input projection; its output is reused as the residual branch.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
        # Encoder levels, each followed by a 2x pool.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
        # Bottom level: dilation enlarges the receptive field without pooling.
        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
        # Decoder convs consume [deeper feature, skip feature] concatenations.
        self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x
        hxin = self.rebnconvin(hx)
        # Encoder: hx1..hx6 are kept as skip connections.
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)
        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)
        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)
        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)
        hx5 = self.rebnconv5(hx)
        hx = self.pool5(hx5)
        hx6 = self.rebnconv6(hx)
        # hx7 has the same spatial size as hx6 (no pool in between), so no upsample here.
        hx7 = self.rebnconv7(hx6)
        hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
        # Decoder: upsample to the skip's size, concatenate on channels, convolve.
        hx6dup = _upsample_like(hx6d, hx5)
        hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)
        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)
        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)
        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)
        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
        # Residual connection around the whole U-structure.
        return hx1d + hxin
|
|
|
|
|
class RSU6(nn.Module):
    """Residual U-shaped block with six convolution levels.

    Same layout as the 7-level variant, but with one fewer encoder level:
    ``rebnconv1`` .. ``rebnconv5`` are separated by 2x2 max-pools
    (``ceil_mode=True``), and ``rebnconv6`` (dilation 2) forms the bottom.
    Decoder convs take ``mid_ch * 2`` channels (upsampled deeper feature
    concatenated with the matching encoder feature). The result is added to
    the output of ``rebnconvin``.

    Expects a 4-D input (N, in_ch, H, W) and returns (N, out_ch, H, W).

    NOTE: submodule attribute names define the state_dict keys, and their
    construction order fixes parameter-initialisation order.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU6, self).__init__()

        # Input projection; its output is reused as the residual branch.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
        # Encoder levels, each followed by a 2x pool.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        # Bottom level: dilation instead of a further pool.
        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
        # Decoder convs consume [deeper feature, skip feature] concatenations.
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x
        hxin = self.rebnconvin(hx)
        # Encoder: hx1..hx5 are kept as skip connections.
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)
        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)
        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)
        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)
        hx5 = self.rebnconv5(hx)
        # hx6 has the same spatial size as hx5 (no pool in between).
        hx6 = self.rebnconv6(hx5)
        hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
        # Decoder: upsample to the skip's size, concatenate on channels, convolve.
        hx5dup = _upsample_like(hx5d, hx4)
        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)
        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)
        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)
        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
        # Residual connection around the whole U-structure.
        return hx1d + hxin
|
|
|
|
|
class RSU5(nn.Module):
    """Residual U-shaped block with five convolution levels.

    ``rebnconv1`` .. ``rebnconv4`` are separated by 2x2 max-pools
    (``ceil_mode=True``), and ``rebnconv5`` (dilation 2) forms the bottom.
    Decoder convs take ``mid_ch * 2`` channels (upsampled deeper feature
    concatenated with the matching encoder feature). The result is added to
    the output of ``rebnconvin``.

    Expects a 4-D input (N, in_ch, H, W) and returns (N, out_ch, H, W).

    NOTE: submodule attribute names define the state_dict keys, and their
    construction order fixes parameter-initialisation order.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU5, self).__init__()

        # Input projection; its output is reused as the residual branch.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
        # Encoder levels, each followed by a 2x pool.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        # Bottom level: dilation instead of a further pool.
        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
        # Decoder convs consume [deeper feature, skip feature] concatenations.
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x
        hxin = self.rebnconvin(hx)
        # Encoder: hx1..hx4 are kept as skip connections.
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)
        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)
        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)
        hx4 = self.rebnconv4(hx)
        # hx5 has the same spatial size as hx4 (no pool in between).
        hx5 = self.rebnconv5(hx4)
        hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
        # Decoder: upsample to the skip's size, concatenate on channels, convolve.
        hx4dup = _upsample_like(hx4d, hx3)
        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)
        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)
        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
        # Residual connection around the whole U-structure.
        return hx1d + hxin
|
|
|
|
|
class RSU4(nn.Module):
    """Residual U-shaped block with four convolution levels.

    ``rebnconv1`` .. ``rebnconv3`` are separated by 2x2 max-pools
    (``ceil_mode=True``), and ``rebnconv4`` (dilation 2) forms the bottom.
    Decoder convs take ``mid_ch * 2`` channels (upsampled deeper feature
    concatenated with the matching encoder feature). The result is added to
    the output of ``rebnconvin``.

    Expects a 4-D input (N, in_ch, H, W) and returns (N, out_ch, H, W).

    NOTE: submodule attribute names define the state_dict keys, and their
    construction order fixes parameter-initialisation order.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4, self).__init__()

        # Input projection; its output is reused as the residual branch.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
        # Encoder levels, each followed by a 2x pool.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        # Bottom level: dilation instead of a further pool.
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
        # Decoder convs consume [deeper feature, skip feature] concatenations.
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):

        hx = x
        hxin = self.rebnconvin(hx)
        # Encoder: hx1..hx3 are kept as skip connections.
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)
        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)
        hx3 = self.rebnconv3(hx)
        # hx4 has the same spatial size as hx3 (no pool in between).
        hx4 = self.rebnconv4(hx3)
        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        # Decoder: upsample to the skip's size, concatenate on channels, convolve.
        hx3dup = _upsample_like(hx3d, hx2)
        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)
        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
        # Residual connection around the whole U-structure.
        return hx1d + hxin
|
|
|
|
|
class RSU4F(nn.Module):
    """Residual U-shaped block that uses dilation instead of resampling.

    Nothing is pooled or upsampled, so every feature map keeps the input's
    spatial size:
      * ``rebnconvin`` (in_ch -> out_ch) feeds both the encoder and the
        residual sum.
      * The encoder stacks ``rebnconv1`` .. ``rebnconv4`` with dilation rates
        1, 2, 4, 8.
      * The decoder (``rebnconv3d``, ``rebnconv2d``, ``rebnconv1d``, dilation
        4, 2, 1) consumes the running decoder feature concatenated with the
        matching encoder feature, then the residual branch is added.

    Expects a 4-D input (N, in_ch, H, W) and returns (N, out_ch, H, W).
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()
        # Registration order and attribute names are load-bearing: they fix
        # the parameter-initialisation order and the state_dict keys.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        residual = self.rebnconvin(x)

        skip1 = self.rebnconv1(residual)
        skip2 = self.rebnconv2(skip1)
        skip3 = self.rebnconv3(skip2)
        decoded = self.rebnconv4(skip3)

        # Walk back up: each decoder conv sees [current decoder feature, skip].
        decoder_steps = (
            (self.rebnconv3d, skip3),
            (self.rebnconv2d, skip2),
            (self.rebnconv1d, skip1),
        )
        for stage, skip in decoder_steps:
            decoded = stage(torch.cat((decoded, skip), 1))

        return decoded + residual
|
|
|
|
|
class U2NET(nn.Module):
    """Nested U-shaped network built from the RSU blocks above.

    Encoder: ``stage1`` .. ``stage6`` (RSU7, RSU6, RSU5, RSU4, RSU4F, RSU4F),
    with a 2x2 max-pool (``ceil_mode=True``) between consecutive stages.
    Decoder: ``stage5d`` .. ``stage1d`` each take the upsampled deeper feature
    concatenated with the matching encoder feature.

    Six side heads (3x3 convs to ``out_ch`` channels) read hx1d, hx2d, hx3d,
    hx4d, hx5d and hx6. Their outputs are upsampled to d1's size and fused by
    a 1x1 conv (``outconv``).

    Args:
        in_ch: channels of the input tensor.
        out_ch: channels of every side/fused output.

    ``forward`` expects a 4-D input (N, in_ch, H, W). It returns a 7-tuple of
    sigmoid-activated maps: the fused map d0 first, then the side maps d1..d6.

    NOTE: submodule attribute names define the state_dict keys, and their
    construction order fixes parameter-initialisation order.
    """

    def __init__(self, in_ch=3, out_ch=1):
        super(U2NET, self).__init__()

        # Encoder stages; positional args are (in_ch, mid_ch, out_ch).
        self.stage1 = RSU7(in_ch, 32, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage6 = RSU4F(512, 256, 512)

        # Decoder stages; input channels = deeper decoder output + encoder skip.
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)
        # Side heads: one out_ch prediction per decoder level (plus stage6).
        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
        # 1x1 fusion over the six concatenated side predictions.
        self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)

    def forward(self, x):

        hx = x
        # Encoder: hx1..hx5 are kept as skip connections.
        hx1 = self.stage1(hx)
        hx = self.pool12(hx1)
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)

        # Decoder: concatenate the upsampled deeper feature with the skip.
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)
        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)
        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)
        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)
        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
        # Side outputs, all brought to d1's spatial size before fusion.
        d1 = self.side1(hx1d)
        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2, d1)
        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3, d1)
        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4, d1)
        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5, d1)
        d6 = self.side6(hx6)
        d6 = _upsample_like(d6, d1)
        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
        # (fused, side1, ..., side6), each passed through a sigmoid.
        return torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(
            d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(
            d5), torch.sigmoid(d6)
|
|
|
|
|
|
|
class SalientAnnotator:
    """Salient-object annotator backed by a pretrained U2NET(3, 1) model.

    Config keys read from ``cfg``:
        PRETRAINED_MODEL (required): path passed to ``torch.load``; the loaded
            object is used as the model's state dict.
        RETURN_IMAGE (default False): default for ``forward(return_image=...)``.
        USE_CROP (default False): when returning images, crop the result to the
            bounding box of the detected mask.
        NORM_SIZE (default [320, 320]): size passed to ``transforms.Resize``
            before inference.

    Args:
        cfg: mapping providing the keys above.
        device: torch device to run on; defaults to CUDA when available,
            otherwise CPU.
    """

    def __init__(self, cfg, device=None):
        self.return_image = cfg.get('RETURN_IMAGE', False)
        self.use_crop = cfg.get('USE_CROP', False)
        pretrained_model = cfg['PRETRAINED_MODEL']
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
        # ImageNet normalisation statistics.
        self.norm_mean = [0.485, 0.456, 0.406]
        self.norm_std = [0.229, 0.224, 0.225]
        self.norm_size = cfg.get('NORM_SIZE', [320, 320])
        self.model = U2NET(3, 1)
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints here.
        self.model.load_state_dict(torch.load(pretrained_model, map_location='cpu'))
        self.model = self.model.to(self.device).eval()
        self.transform_input = transforms.Compose([
            transforms.Resize(self.norm_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.norm_mean, std=self.norm_std)
        ])

    def forward(self, image, return_image=None):
        """Predict a saliency map for ``image``.

        Args:
            image: any input accepted by ``convert_to_pil``.
            return_image: overrides ``self.return_image`` when not None.

        Returns:
            If ``return_image`` is falsy: a uint8 saliency map at the original
            image size, min-max normalised to [0, 255] (all zeros when the
            prediction is constant).

            Otherwise: a dict with
              * ``"image"``: the input with non-salient pixels replaced by white.
              * ``"mask"``: a uint8 0/255 mask.
            Both are cropped to the mask's bounding box when ``USE_CROP`` is set.
        """
        return_image = return_image if return_image is not None else self.return_image
        image = convert_to_pil(image)
        img_w, img_h = image.size
        input_image = self.transform_input(image).float().unsqueeze(0).to(self.device)
        with torch.no_grad():
            results = self.model(input_image)
        # results[0] is the fused sigmoid map of the single batch item.
        data = results[0][0, 0, :, :]

        data_min = torch.min(data)
        data_range = torch.max(data) - data_min
        if data_range > 0:
            data_norm = (data - data_min) / data_range
        else:
            # A constant prediction would make min-max normalisation divide by
            # zero, producing NaNs whose uint8 cast is undefined. Treat it as
            # all background instead.
            data_norm = torch.zeros_like(data)
        data_norm_np = (data_norm.cpu().numpy() * 255).astype('uint8')
        # cv2.resize takes the target size as (width, height).
        data_norm_rst = cv2.resize(data_norm_np, (img_w, img_h))

        if not return_image:
            return data_norm_rst

        image_np = np.array(image)
        # Any saliency value above 1 (out of 255) counts as foreground.
        _, binary_mask = cv2.threshold(data_norm_rst, 1, 255, cv2.THRESH_BINARY)
        # Build the background from the converted array, with the same shape and dtype.
        white_bg = np.full_like(image_np, 255)
        ret_image = np.where(binary_mask[:, :, np.newaxis] == 255, image_np, white_bg).astype(np.uint8)
        ret_mask = np.where(binary_mask, 255, 0).astype(np.uint8)
        if self.use_crop:
            x, y, w, h = cv2.boundingRect(binary_mask)
            ret_image = ret_image[y:y + h, x:x + w]
            ret_mask = ret_mask[y:y + h, x:x + w]
        return {"image": ret_image, "mask": ret_mask}
|
|
|
|
|
class SalientVideoAnnotator(SalientAnnotator):
    """Run the single-image salient annotator independently on every frame."""

    def forward(self, frames, return_image=None):
        """Annotate each frame in ``frames``.

        Args:
            frames: iterable of frames, each accepted by
                ``SalientAnnotator.forward``.
            return_image: passed through to ``SalientAnnotator.forward`` for
                every frame. None means use ``self.return_image``.

        Returns:
            A list with one annotation per frame, in input order.
        """
        # Resolve the parent method outside the comprehension: zero-argument
        # super() fails inside a comprehension scope before Python 3.12.
        annotate = super().forward
        # Forward return_image so per-call overrides apply to every frame.
        return [annotate(frame, return_image=return_image) for frame in frames]