# DDColor / ddcolor_model.py
import torch
import torch.nn as nn
from basicsr.archs.ddcolor_arch_utils.unet import Hook, CustomPixelShuffle_ICNR, UnetBlockWide, NormType, custom_conv_layer
from basicsr.archs.ddcolor_arch_utils.convnext import ConvNeXt
from basicsr.archs.ddcolor_arch_utils.transformer_utils import SelfAttentionLayer, CrossAttentionLayer, FFNLayer, MLP
from basicsr.archs.ddcolor_arch_utils.position_encoding import PositionEmbeddingSine
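
# Full DDColor model: a ConvNeXt encoder, a dual decoder (pixel decoder plus
# query-based color decoder), and a 1x1 convolutional fusion head.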
class DDColor(nn.Module):
def __init__(
self,
encoder_name='convnext-l',
decoder_name='MultiScaleColorDecoder',
num_input_channels=3,
input_size=(256, 256),
nf=512,
num_output_channels=3,
last_norm='Weight',
do_normalize=False,
num_queries=256,
num_scales=3,
dec_layers=9,
):
super().__init__()
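        # Backbone encoder; hooks on the four ConvNeXt stage norms expose multi-scale features.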
self.encoder = ImageEncoder(encoder_name, ['norm0', 'norm1', 'norm2', 'norm3'])
self.encoder.eval()
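        # A dummy forward pass fills the hooks with feature maps, so the decoder
        # can infer per-stage channel counts at construction time.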
test_input = torch.randn(1, num_input_channels, *input_size)
self.encoder(test_input)
        self.decoder = DualDecoder(
self.encoder.hooks,
nf=nf,
last_norm=last_norm,
num_queries=num_queries,
num_scales=num_scales,
dec_layers=dec_layers,
decoder_name=decoder_name
)
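        # 1x1 spectral-norm conv that fuses the per-query color maps with the input image.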
self.refine_net = nn.Sequential(
custom_conv_layer(num_queries + 3, num_output_channels, ks=1, use_activ=False, norm_type=NormType.Spectral)
)
self.do_normalize = do_normalize
        # ImageNet normalization statistics.
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
def normalize(self, img):
return (img - self.mean) / self.std
def denormalize(self, img):
return img * self.std + self.mean
def forward(self, x):
if x.shape[1] == 3:
x = self.normalize(x)
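        # The encoder is run for its side effects: the hooks capture the multi-scale
        # features, which the decoder then reads back.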
self.encoder(x)
out_feat = self.decoder()
coarse_input = torch.cat([out_feat, x], dim=1)
out = self.refine_net(coarse_input)
if self.do_normalize:
out = self.denormalize(out)
return out
class ImageEncoder(nn.Module):
def __init__(self, encoder_name, hook_names):
super().__init__()
        assert encoder_name in ('convnext-t', 'convnext-l'), f'unsupported encoder: {encoder_name}'
        if encoder_name == 'convnext-t':
            self.arch = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768])
        else:  # convnext-l
            self.arch = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536])
self.encoder_name = encoder_name
self.hook_names = hook_names
self.hooks = self.setup_hooks()
def setup_hooks(self):
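        # Register a forward hook on each named stage norm so its output feature map can be read later.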
hooks = [Hook(self.arch._modules[name]) for name in self.hook_names]
return hooks
def forward(self, x):
return self.arch(x)
class DualDecoder(nn.Module):
def __init__(
self,
hooks,
nf=512,
blur=True,
last_norm='Weight',
num_queries=256,
num_scales=3,
dec_layers=9,
decoder_name='MultiScaleColorDecoder',
):
super().__init__()
self.hooks = hooks
self.nf = nf
self.blur = blur
self.last_norm = getattr(NormType, last_norm)
self.decoder_name = decoder_name
self.layers = self.make_layers()
embed_dim = nf // 2
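        # Final 4x pixel-shuffle upsampling back to the input resolution.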
self.last_shuf = CustomPixelShuffle_ICNR(embed_dim, embed_dim, blur=self.blur, norm_type=self.last_norm, scale=4)
        assert decoder_name == 'MultiScaleColorDecoder', f'unsupported decoder: {decoder_name}'
self.color_decoder = MultiScaleColorDecoder(
in_channels=[512, 512, 256],
num_queries=num_queries,
num_scales=num_scales,
dec_layers=dec_layers,
)
def make_layers(self):
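        # Build one UnetBlockWide per skip connection, walking the hooks from
        # deep to shallow; the channel width is halved for the shallowest block.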
decoder_layers = []
in_c = self.hooks[-1].feature.shape[1]
out_c = self.nf
setup_hooks = self.hooks[-2::-1]
for layer_index, hook in enumerate(setup_hooks):
feature_c = hook.feature.shape[1]
if layer_index == len(setup_hooks) - 1:
out_c = out_c // 2
decoder_layers.append(
UnetBlockWide(
in_c, feature_c, out_c, hook, blur=self.blur, self_attention=False, norm_type=NormType.Spectral))
in_c = out_c
return nn.Sequential(*decoder_layers)
def forward(self):
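        # Upsample from the deepest hooked feature stage by stage; the intermediate
        # outputs double as the multi-scale inputs to the color decoder.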
encode_feat = self.hooks[-1].feature
out0 = self.layers[0](encode_feat)
out1 = self.layers[1](out0)
out2 = self.layers[2](out1)
out3 = self.last_shuf(out2)
return self.color_decoder([out0, out1, out2], out3)
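
# Query-based color decoder: learnable color queries cross-attend to multi-scale
# image features and are finally projected onto the full-resolution features.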
class MultiScaleColorDecoder(nn.Module):
def __init__(
self,
in_channels,
hidden_dim=256,
num_queries=100,
nheads=8,
dim_feedforward=2048,
dec_layers=9,
pre_norm=False,
color_embed_dim=256,
enforce_input_project=True,
num_scales=3,
):
super().__init__()
self.hidden_dim = hidden_dim
self.num_queries = num_queries
self.num_layers = dec_layers
self.num_feature_levels = num_scales
# Positional encoding layer
self.pe_layer = PositionEmbeddingSine(hidden_dim // 2, normalize=True)
# Learnable query features and embeddings
self.query_feat = nn.Embedding(num_queries, hidden_dim)
self.query_embed = nn.Embedding(num_queries, hidden_dim)
# Learnable level embeddings
self.level_embed = nn.Embedding(num_scales, hidden_dim)
# Input projection layers
self.input_proj = nn.ModuleList(
[self._make_input_proj(in_ch, hidden_dim, enforce_input_project) for in_ch in in_channels]
)
# Transformer layers
self.transformer_self_attention_layers = nn.ModuleList()
self.transformer_cross_attention_layers = nn.ModuleList()
self.transformer_ffn_layers = nn.ModuleList()
for _ in range(dec_layers):
self.transformer_self_attention_layers.append(
SelfAttentionLayer(
d_model=hidden_dim,
nhead=nheads,
dropout=0.0,
normalize_before=pre_norm,
)
)
self.transformer_cross_attention_layers.append(
CrossAttentionLayer(
d_model=hidden_dim,
nhead=nheads,
dropout=0.0,
normalize_before=pre_norm,
)
)
self.transformer_ffn_layers.append(
FFNLayer(
d_model=hidden_dim,
dim_feedforward=dim_feedforward,
dropout=0.0,
normalize_before=pre_norm,
)
)
# Layer normalization for the decoder output
self.decoder_norm = nn.LayerNorm(hidden_dim)
# Output embedding layer
self.color_embed = MLP(hidden_dim, hidden_dim, color_embed_dim, 3)
def forward(self, x, img_features):
assert len(x) == self.num_feature_levels
src, pos = self._get_src_and_pos(x)
bs = src[0].shape[1]
# Prepare query embeddings (QxNxC)
query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
output = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)
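        # Decoder layers cycle round-robin through the feature levels.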
for i in range(self.num_layers):
level_index = i % self.num_feature_levels
# attention: cross-attention first
output = self.transformer_cross_attention_layers[i](
output, src[level_index],
memory_mask=None,
memory_key_padding_mask=None,
pos=pos[level_index], query_pos=query_embed
)
output = self.transformer_self_attention_layers[i](
output, tgt_mask=None,
tgt_key_padding_mask=None,
query_pos=query_embed
)
# FFN
            output = self.transformer_ffn_layers[i](output)
decoder_output = self.decoder_norm(output).transpose(0, 1)
color_embed = self.color_embed(decoder_output)
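        # Project the per-query color embeddings onto the image features: one HxW map per query.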
out = torch.einsum("bqc,bchw->bqhw", color_embed, img_features)
return out
def _make_input_proj(self, in_ch, hidden_dim, enforce):
if in_ch != hidden_dim or enforce:
proj = nn.Conv2d(in_ch, hidden_dim, kernel_size=1)
nn.init.kaiming_uniform_(proj.weight, a=1)
if proj.bias is not None:
nn.init.constant_(proj.bias, 0)
return proj
return nn.Sequential()
def _get_src_and_pos(self, x):
src, pos = [], []
        for i, feature in enumerate(x):
            # flatten NxCxHxW to HWxNxC
            pos.append(self.pe_layer(feature).flatten(2).permute(2, 0, 1))
            # project to hidden_dim and add the learnable per-level embedding
            src.append((self.input_proj[i](feature).flatten(2)
                        + self.level_embed.weight[i][None, :, None]).permute(2, 0, 1))
return src, pos
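

# --- Minimal usage sketch (illustrative; not part of the original file). ---
# Assumes the basicsr DDColor utilities imported above are available and that,
# as in the DDColor pipeline, the grayscale input is an L channel replicated to
# 3 channels; num_output_channels=2 here stands for the ab chrominance pair.
if __name__ == '__main__':
    model = DDColor(encoder_name='convnext-l', input_size=(256, 256), num_output_channels=2)
    model.eval()
    with torch.no_grad():
        gray = torch.rand(1, 1, 256, 256).repeat(1, 3, 1, 1)  # stand-in grayscale batch
        ab = model(gray)  # predicted chrominance, shape (1, 2, 256, 256)
    print(ab.shape)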