import torch
import torch.nn as nn

from basicsr.archs.ddcolor_arch_utils.unet import Hook, CustomPixelShuffle_ICNR, UnetBlockWide, NormType, custom_conv_layer
from basicsr.archs.ddcolor_arch_utils.convnext import ConvNeXt
from basicsr.archs.ddcolor_arch_utils.transformer_utils import SelfAttentionLayer, CrossAttentionLayer, FFNLayer, MLP
from basicsr.archs.ddcolor_arch_utils.position_encoding import PositionEmbeddingSine


class DDColor(nn.Module):

    def __init__(
        self,
        encoder_name='convnext-l',
        decoder_name='MultiScaleColorDecoder',
        num_input_channels=3,
        input_size=(256, 256),
        nf=512,
        num_output_channels=3,
        last_norm='Weight',
        do_normalize=False,
        num_queries=256,
        num_scales=3,
        dec_layers=9,
    ):
        super().__init__()

        self.encoder = ImageEncoder(encoder_name, ['norm0', 'norm1', 'norm2', 'norm3'])
        self.encoder.eval()

        # Run a dummy forward pass so the encoder hooks record the feature
        # shapes that the decoder needs in order to build its layers.
        test_input = torch.randn(1, num_input_channels, *input_size)
        self.encoder(test_input)

        self.decoder = DualDecoder(
            self.encoder.hooks,
            nf=nf,
            last_norm=last_norm,
            num_queries=num_queries,
            num_scales=num_scales,
            dec_layers=dec_layers,
            decoder_name=decoder_name,
        )

        # Fuse the per-query color maps with the input image to produce the final output.
        self.refine_net = nn.Sequential(
            custom_conv_layer(num_queries + 3, num_output_channels, ks=1, use_activ=False, norm_type=NormType.Spectral)
        )

        self.do_normalize = do_normalize
        self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def normalize(self, img):
        return (img - self.mean) / self.std

    def denormalize(self, img):
        return img * self.std + self.mean

    def forward(self, x):
        if x.shape[1] == 3:
            x = self.normalize(x)

        self.encoder(x)
        out_feat = self.decoder()
        coarse_input = torch.cat([out_feat, x], dim=1)
        out = self.refine_net(coarse_input)

        if self.do_normalize:
            out = self.denormalize(out)
        return out


class ImageEncoder(nn.Module):

    def __init__(self, encoder_name, hook_names):
        super().__init__()
        assert encoder_name in ('convnext-t', 'convnext-l')

        if encoder_name == 'convnext-t':
            self.arch = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768])
        elif encoder_name == 'convnext-l':
            self.arch = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536])
        else:
            raise NotImplementedError

        self.encoder_name = encoder_name
        self.hook_names = hook_names
        self.hooks = self.setup_hooks()

    def setup_hooks(self):
        # Register a forward hook on each named norm layer to capture multi-scale features.
        hooks = [Hook(self.arch._modules[name]) for name in self.hook_names]
        return hooks

    def forward(self, x):
        return self.arch(x)


class DualDecoder(nn.Module):

    def __init__(
        self,
        hooks,
        nf=512,
        blur=True,
        last_norm='Weight',
        num_queries=256,
        num_scales=3,
        dec_layers=9,
        decoder_name='MultiScaleColorDecoder',
    ):
        super().__init__()
        self.hooks = hooks
        self.nf = nf
        self.blur = blur
        self.last_norm = getattr(NormType, last_norm)
        self.decoder_name = decoder_name

        # Pixel decoder: U-Net style upsampling blocks driven by the encoder hooks.
        self.layers = self.make_layers()
        embed_dim = nf // 2
        self.last_shuf = CustomPixelShuffle_ICNR(embed_dim, embed_dim, blur=self.blur, norm_type=self.last_norm, scale=4)

        # Color decoder: query-based transformer operating on multi-scale features.
        assert decoder_name == 'MultiScaleColorDecoder'
        self.color_decoder = MultiScaleColorDecoder(
            in_channels=[512, 512, 256],
            num_queries=num_queries,
            num_scales=num_scales,
            dec_layers=dec_layers,
        )

    def make_layers(self):
        decoder_layers = []

        in_c = self.hooks[-1].feature.shape[1]
        out_c = self.nf
        setup_hooks = self.hooks[-2::-1]
        for layer_index, hook in enumerate(setup_hooks):
            feature_c = hook.feature.shape[1]
            if layer_index == len(setup_hooks) - 1:
                out_c = out_c // 2
            decoder_layers.append(
                UnetBlockWide(
                    in_c, feature_c, out_c, hook, blur=self.blur, self_attention=False, norm_type=NormType.Spectral))
            in_c = out_c
        return nn.Sequential(*decoder_layers)

    def forward(self):
        encode_feat = self.hooks[-1].feature
        out0 = self.layers[0](encode_feat)
        out1 = self.layers[1](out0)
        out2 = self.layers[2](out1)
        out3 = self.last_shuf(out2)
        return self.color_decoder([out0, out1, out2], out3)


class MultiScaleColorDecoder(nn.Module):

    def __init__(
        self,
        in_channels,
        hidden_dim=256,
        num_queries=100,
        nheads=8,
        dim_feedforward=2048,
        dec_layers=9,
        pre_norm=False,
        color_embed_dim=256,
        enforce_input_project=True,
        num_scales=3,
    ):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_queries = num_queries
        self.num_layers = dec_layers
        self.num_feature_levels = num_scales

        # Positional encoding layer
        self.pe_layer = PositionEmbeddingSine(hidden_dim // 2, normalize=True)

        # Learnable query features and embeddings
        self.query_feat = nn.Embedding(num_queries, hidden_dim)
        self.query_embed = nn.Embedding(num_queries, hidden_dim)

        # Learnable level embeddings
        self.level_embed = nn.Embedding(num_scales, hidden_dim)

        # Input projection layers
        self.input_proj = nn.ModuleList(
            [self._make_input_proj(in_ch, hidden_dim, enforce_input_project) for in_ch in in_channels]
        )

        # Transformer layers
        self.transformer_self_attention_layers = nn.ModuleList()
        self.transformer_cross_attention_layers = nn.ModuleList()
        self.transformer_ffn_layers = nn.ModuleList()
        for _ in range(dec_layers):
            self.transformer_self_attention_layers.append(
                SelfAttentionLayer(
                    d_model=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )
            self.transformer_cross_attention_layers.append(
                CrossAttentionLayer(
                    d_model=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )
            self.transformer_ffn_layers.append(
                FFNLayer(
                    d_model=hidden_dim,
                    dim_feedforward=dim_feedforward,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )

        # Layer normalization for the decoder output
        self.decoder_norm = nn.LayerNorm(hidden_dim)

        # Output embedding layer
        self.color_embed = MLP(hidden_dim, hidden_dim, color_embed_dim, 3)

    def forward(self, x, img_features):
        assert len(x) == self.num_feature_levels
        src, pos = self._get_src_and_pos(x)
        bs = src[0].shape[1]

        # Prepare query embeddings (QxNxC)
        query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
        output = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)

        for i in range(self.num_layers):
            level_index = i % self.num_feature_levels
            # attention: cross-attention first
            output = self.transformer_cross_attention_layers[i](
                output, src[level_index],
                memory_mask=None,
                memory_key_padding_mask=None,
                pos=pos[level_index],
                query_pos=query_embed
            )
            output = self.transformer_self_attention_layers[i](
                output,
                tgt_mask=None,
                tgt_key_padding_mask=None,
                query_pos=query_embed
            )
            # FFN
            output = self.transformer_ffn_layers[i](output)

        decoder_output = self.decoder_norm(output).transpose(0, 1)
        color_embed = self.color_embed(decoder_output)
        out = torch.einsum("bqc,bchw->bqhw", color_embed, img_features)
        return out

    def _make_input_proj(self, in_ch, hidden_dim, enforce):
        if in_ch != hidden_dim or enforce:
            proj = nn.Conv2d(in_ch, hidden_dim, kernel_size=1)
            nn.init.kaiming_uniform_(proj.weight, a=1)
            if proj.bias is not None:
                nn.init.constant_(proj.bias, 0)
            return proj
        return nn.Sequential()

    def _get_src_and_pos(self, x):
        src, pos = [], []
        for i, feature in enumerate(x):
            # flatten NxCxHxW to HWxNxC
            pos.append(self.pe_layer(feature).flatten(2).permute(2, 0, 1))
            src.append((self.input_proj[i](feature).flatten(2) + self.level_embed.weight[i][None, :, None]).permute(2, 0, 1))
        return src, pos
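

# A minimal smoke-test sketch, not part of the original module: it assumes the
# basicsr.archs.ddcolor_arch_utils package is importable and simply checks that a
# random image passes through the full encoder / dual-decoder / refine pipeline.
if __name__ == '__main__':
    model = DDColor(encoder_name='convnext-t', input_size=(256, 256))
    model.eval()
    with torch.no_grad():
        img = torch.rand(1, 3, 256, 256)  # dummy input in [0, 1]; normalized inside forward()
        out = model(img)
    print(out.shape)  # expected: (1, num_output_channels, 256, 256)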