# From https://github.com/Fanghua-Yu/SUPIR/blob/master/SUPIR/modules/SUPIR_v0.py
import torch
import torch as th
import torch.nn as nn


class GroupNorm32(nn.GroupNorm):
    def forward(self, x):
        # return super().forward(x.float()).type(x.dtype)
        return super().forward(x)


def normalization(channels):
    """
    Make a standard normalization layer.
    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    return GroupNorm32(32, channels)


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
class ZeroSFT(nn.Module):
    """
    SPADE-style spatial feature transform block with optional zero-initialized
    projections, used to inject control features `c` into a feature map `h`.
    """

    def __init__(self, label_nc, norm_nc, nhidden=128, norm=True, mask=False, zero_init=True):
        super().__init__()

        # param_free_norm_type = str(parsed.group(1))
        ks = 3
        pw = ks // 2

        self.norm = norm
        if self.norm:
            self.param_free_norm = normalization(norm_nc)
        else:
            self.param_free_norm = nn.Identity()

        # Shared conv + activation that maps control features to a hidden representation.
        self.mlp_shared = nn.Sequential(
            nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
            nn.SiLU()
        )

        # With zero_init=True the scale/shift convolutions start at zero, so the
        # block initially applies no modulation (gamma = beta = 0).
        if zero_init:
            self.zero_mul = zero_module(nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw))
            self.zero_add = zero_module(nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw))
        else:
            self.zero_mul = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
            self.zero_add = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)

    def forward(self, c, h, control_scale=1):
        # c: control feature map (label_nc channels); h: feature map to modulate (norm_nc channels).
        h_raw = h
        actv = self.mlp_shared(c)
        gamma = self.zero_mul(actv)
        beta = self.zero_add(actv)
        h = self.param_free_norm(h) * (gamma + 1) + beta
        # Blend modulated features with the unmodified input according to control_scale.
        return h * control_scale + h_raw * (1 - control_scale)
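

# Minimal usage sketch (not part of the original SUPIR file): the channel counts and
# spatial size below are illustrative assumptions, chosen only to show how ZeroSFT
# blends a control feature map into a backbone feature map of matching spatial size.
if __name__ == "__main__":
    control = torch.randn(1, 64, 32, 32)    # c: label_nc = 64 channels
    features = torch.randn(1, 320, 32, 32)  # h: norm_nc = 320 channels
    sft = ZeroSFT(label_nc=64, norm_nc=320)
    out = sft(control, features, control_scale=1.0)
    print(out.shape)  # torch.Size([1, 320, 32, 32])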