|
|
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
from einops import rearrange |
|
|
|
from .utils import convert_to_torch |
|
|
|
# Normalization layer class applied after the convolutions in this module.
norm_layer = nn.InstanceNorm2d


class ResidualBlock(nn.Module):
    """Residual unit: two reflection-padded 3x3 convolutions plus a skip path.

    The convolutional branch keeps both the channel count and the spatial
    size of its input (padding of 1 followed by an unpadded 3x3 kernel), so
    its result can be added directly to the input.

    Args:
        in_features: Channel count of the input; the output has the same
            number of channels.
    """

    def __init__(self, in_features):
        super().__init__()
        # Layer order is significant: the two convolutions sit at indices 1
        # and 5, which is what the state-dict keys are derived from.
        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            norm_layer(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            norm_layer(in_features),
        )

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        branch_out = self.conv_block(x)
        return x + branch_out
|
|
|
|
|
class ContourInference(nn.Module):
    """Convolutional encoder / residual / decoder network.

    Stages, in the order ``forward`` applies them:

    * ``model0`` - reflection-padded 7x7 convolution to 64 channels.
    * ``model1`` - two stride-2 3x3 convolutions (64 -> 128 -> 256 channels).
    * ``model2`` - ``n_residual_blocks`` ``ResidualBlock`` modules at 256
      channels.
    * ``model3`` - two stride-2 transposed convolutions (256 -> 128 -> 64).
    * ``model4`` - reflection-padded 7x7 convolution to ``output_nc``
      channels, optionally followed by a sigmoid.

    Every convolution in ``model0``, ``model1`` and ``model3`` is followed by
    ``norm_layer`` and an in-place ReLU.

    Args:
        input_nc: Number of channels of the input tensor.
        output_nc: Number of channels of the output tensor.
        n_residual_blocks: How many residual blocks ``model2`` contains.
        sigmoid: Whether to append ``nn.Sigmoid`` to the output stage.
    """

    def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
        super().__init__()

        # Stem: lift the input to 64 feature channels.
        self.model0 = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, 64, 7),
            norm_layer(64),
            nn.ReLU(inplace=True),
        )

        # Downsampling: each step halves the resolution and doubles the
        # channel count.
        channels = 64
        down_layers = []
        for _ in range(2):
            wider = channels * 2
            down_layers.extend((
                nn.Conv2d(channels, wider, 3, stride=2, padding=1),
                norm_layer(wider),
                nn.ReLU(inplace=True),
            ))
            channels = wider
        self.model1 = nn.Sequential(*down_layers)

        # Residual trunk at the bottleneck width.
        self.model2 = nn.Sequential(
            *[ResidualBlock(channels) for _ in range(n_residual_blocks)])

        # Upsampling: each step doubles the resolution and halves the
        # channel count.
        up_layers = []
        for _ in range(2):
            narrower = channels // 2
            up_layers.extend((
                nn.ConvTranspose2d(channels,
                                   narrower,
                                   3,
                                   stride=2,
                                   padding=1,
                                   output_padding=1),
                norm_layer(narrower),
                nn.ReLU(inplace=True),
            ))
            channels = narrower
        self.model3 = nn.Sequential(*up_layers)

        # Output head: project the 64 feature channels to ``output_nc``.
        head = [nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, 7)]
        if sigmoid:
            head.append(nn.Sigmoid())
        self.model4 = nn.Sequential(*head)

    def forward(self, x, cond=None):
        """Run ``x`` through ``model0`` .. ``model4`` and return the result.

        Args:
            x: Input tensor for ``model0``.
            cond: Accepted but not used by this method.
        """
        out = x
        for stage in (self.model0, self.model1, self.model2, self.model3,
                      self.model4):
            out = stage(out)
        return out
|
|
|
|
|
class ScribbleAnnotator:
    """Run a pretrained ``ContourInference`` network over images.

    Args:
        cfg: Configuration object supporting ``get`` and item access.
            ``PRETRAINED_MODEL`` (checkpoint path passed to ``torch.load``)
            is required; ``INPUT_NC`` (default 3), ``OUTPUT_NC`` (default 1),
            ``N_RESIDUAL_BLOCKS`` (default 3) and ``SIGMOID`` (default True)
            are optional.
        device: Device to run the model on. When ``None``, CUDA is used if
            available, otherwise the CPU.
    """

    def __init__(self, cfg, device=None):
        input_nc = cfg.get('INPUT_NC', 3)
        output_nc = cfg.get('OUTPUT_NC', 1)
        n_residual_blocks = cfg.get('N_RESIDUAL_BLOCKS', 3)
        sigmoid = cfg.get('SIGMOID', True)
        pretrained_model = cfg['PRETRAINED_MODEL']

        if device is None:
            device = torch.device(
                "cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

        self.model = ContourInference(input_nc, output_nc, n_residual_blocks,
                                      sigmoid)
        # Load onto the CPU first: a checkpoint saved from CUDA tensors would
        # otherwise fail to deserialize on a CPU-only host. The model is
        # moved to the target device right after the weights are copied in.
        state_dict = torch.load(pretrained_model,
                                map_location="cpu",
                                weights_only=True)
        self.model.load_state_dict(state_dict)
        self.model = self.model.eval().requires_grad_(False).to(self.device)

    # inference_mode already disables gradient tracking, so no separate
    # no_grad decorator is needed.
    @torch.inference_mode()
    @torch.autocast('cuda', enabled=False)
    def forward(self, image):
        """Compute the model output for ``image`` as a 3-channel uint8 array.

        Args:
            image: Input accepted by ``convert_to_torch``. If the converted
                tensor is 3-D it is treated as a single ``(h, w, c)`` image
                and rearranged to ``(1, c, h, w)``; any other rank is passed
                to the model unchanged.
                NOTE(review): batched input is presumably expected to be
                ``(b, c, h, w)`` already - confirm against callers.
                Values are divided by 255 before inference.

        Returns:
            A ``numpy.uint8`` array: the model output with its channel axis
            squeezed (when that axis has size 1), scaled by 255, clipped to
            ``[0, 255]`` and repeated three times along a new trailing axis.
            When the input was 3-D, all remaining size-1 axes are squeezed
            away as well.
        """
        # Decided on the raw input, before conversion.
        is_batch = len(image.shape) != 3

        image = convert_to_torch(image)
        if len(image.shape) == 3:
            image = rearrange(image, 'h w c -> 1 c h w')
        image = image.float().div(255).to(self.device)

        contour_map = self.model(image)
        contour_map = (contour_map.squeeze(dim=1) * 255.0).clip(0, 255)
        contour_map = contour_map.cpu().numpy().astype(np.uint8)
        # Replicate the map into three identical channels on a new last axis.
        contour_map = contour_map[..., None].repeat(3, -1)

        if not is_batch:
            contour_map = contour_map.squeeze()
        return contour_map
|
|
|
|
|
class ScribbleVideoAnnotator(ScribbleAnnotator):
    """Variant of ``ScribbleAnnotator`` that processes a sequence of frames."""

    def forward(self, frames):
        """Annotate each frame and return the results as a list.

        Every frame is converted with ``np.array`` and passed individually to
        ``ScribbleAnnotator.forward``; results keep the input order.
        """
        # Bind the parent method up front: zero-argument super() does not
        # work inside a comprehension before Python 3.12.
        annotate_frame = super().forward
        return [annotate_frame(np.array(frame)) for frame in frames]