Spaces:
Running
on
Zero
Running
on
Zero
File size: 15,000 Bytes
f22f03c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 |
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
"""Improved diffusion model architecture proposed in the paper
"Analyzing and Improving the Training Dynamics of Diffusion Models"."""
import numpy as np
import torch
from torch_utils import persistence
from torch_utils import misc
#----------------------------------------------------------------------------
# Normalize given tensor to unit magnitude with respect to the given
# dimensions. Default = all dimensions except the first.
def normalize(x, dim=None, eps=1e-4):
    """Rescale `x` to unit RMS magnitude over the given dimensions.

    The L2 norm is computed in float32 over `dim` (default: every dimension
    except the first) and divided by sqrt(number of reduced elements), which
    turns it into an RMS value. `eps` is added to that RMS before dividing, so
    all-zero slices do not produce NaNs.

    Args:
        x: Tensor to normalize.
        dim: Dimension(s) to reduce over. None selects dims 1..x.ndim-1.
        eps: Small constant added to the RMS denominator.

    Returns:
        Tensor of the same shape and dtype as `x`.
    """
    reduce_dims = list(range(1, x.ndim)) if dim is None else dim
    l2 = torch.linalg.vector_norm(x, dim=reduce_dims, keepdim=True, dtype=torch.float32)
    # l2 keeps the reduced dims (size 1), so x.numel() / l2.numel() is the
    # number of reduced elements; scaling by its inverse sqrt gives the RMS.
    rms_scale = np.sqrt(l2.numel() / x.numel())
    denom = torch.add(eps, l2, alpha=rms_scale)
    return x / denom.to(x.dtype)
#----------------------------------------------------------------------------
# Upsample or downsample the given tensor with the given filter,
# or keep it as is.
def resample(x, f=(1, 1), mode='keep'):
    """Upsample or downsample the given tensor by a factor of 2, or keep it as is.

    The 1-D filter `f` is normalized to sum to 1 and expanded into a separable
    2-D kernel via an outer product. The kernel is then applied depthwise (one
    copy per channel, `groups=c`) with stride 2.

    Args:
        x: Input tensor. The channel count is read from `x.shape[1]` and 2-D
            (transposed) convolution is used, so presumably NCHW.
        f: 1-D filter taps, non-empty and of even length. Defaults to a 2-tap
            box filter.
        mode: 'keep' returns `x` unchanged, 'down' halves the spatial size
            with `conv2d`, and 'up' doubles it with `conv_transpose2d`.

    Returns:
        The resampled tensor, or `x` itself for mode='keep'.

    Raises:
        ValueError: If `mode` is not 'keep', 'down', or 'up', or if `f` is not
            a non-empty 1-D sequence of even length.
    """
    # Use real exceptions rather than `assert`: asserts are stripped under
    # `python -O`, which would let an unknown mode silently fall through to 'up'.
    if mode not in ('keep', 'down', 'up'):
        raise ValueError(f"resample(): mode must be 'keep', 'down', or 'up', got {mode!r}")
    if mode == 'keep':
        return x
    f = np.float32(f)
    if f.ndim != 1 or len(f) == 0 or len(f) % 2 != 0:
        raise ValueError(f'resample(): filter must be 1-D with a non-zero even number of taps, got shape {f.shape}')
    pad = (len(f) - 1) // 2
    f = f / f.sum()  # unit DC gain
    f = np.outer(f, f)[np.newaxis, np.newaxis, :, :]  # [1, 1, k, k] separable 2-D kernel
    f = misc.const_like(x, f)  # project helper; presumably a constant tensor matching x's device/dtype
    c = x.shape[1]
    if mode == 'down':
        return torch.nn.functional.conv2d(x, f.tile([c, 1, 1, 1]), groups=c, stride=2, padding=(pad,))
    # mode == 'up': stride-2 zero insertion spreads each input pixel over four
    # outputs, so the kernel is scaled by 4 to keep the signal's mean unchanged.
    return torch.nn.functional.conv_transpose2d(x, (f * 4).tile([c, 1, 1, 1]), groups=c, stride=2, padding=(pad,))
#----------------------------------------------------------------------------
# Magnitude-preserving SiLU (Equation 81).
def mp_silu(x):
    """Magnitude-preserving SiLU (Equation 81).

    Applies SiLU and divides by the constant 0.596. Presumably 0.596 is the
    RMS of SiLU's output for unit-variance Gaussian input, so the output keeps
    roughly unit magnitude (see the paper's Equation 81).
    """
    activated = torch.nn.functional.silu(x)
    return activated / 0.596
#----------------------------------------------------------------------------
# Magnitude-preserving sum (Equation 88).
def mp_sum(a, b, t=0.5):
    """Magnitude-preserving weighted sum (Equation 88).

    Linearly interpolates from `a` (t=0) to `b` (t=1). The result is then
    divided by sqrt((1-t)^2 + t^2), so that two uncorrelated unit-magnitude
    inputs give a unit-magnitude output.
    """
    blended = a.lerp(b, t)
    norm_factor = np.sqrt((1 - t) ** 2 + t ** 2)
    return blended / norm_factor
#----------------------------------------------------------------------------
# Magnitude-preserving concatenation (Equation 103).
def mp_cat(a, b, dim=1, t=0.5):
    """Magnitude-preserving concatenation (Equation 103).

    Concatenates `a` and `b` along `dim` after scaling each by a constant.
    `t` balances the two inputs (0 = all weight on `a`, 1 = all on `b`). The
    constants account for each input's size along `dim`, so that
    unit-magnitude inputs yield a unit-magnitude result.
    """
    size_a = a.shape[dim]
    size_b = b.shape[dim]
    balance = (1 - t) ** 2 + t ** 2
    total_scale = np.sqrt((size_a + size_b) / balance)
    scale_a = total_scale / np.sqrt(size_a) * (1 - t)
    scale_b = total_scale / np.sqrt(size_b) * t
    parts = [scale_a * a , scale_b * b]
    return torch.cat(parts, dim=dim)
#----------------------------------------------------------------------------
# Magnitude-preserving Fourier features (Equation 75).
@persistence.persistent_class
class MPFourier(torch.nn.Module):
    """Magnitude-preserving Fourier features (Equation 75).

    Embeds a 1-D tensor of scalars into `num_channels` features
    `sqrt(2) * cos(freqs * x + phases)`.

    The frequencies are drawn once from N(0, (2*pi*bandwidth)^2) and the phases
    uniformly from [0, 2*pi), both from the global torch RNG at construction
    time. Both are registered as buffers, so they are saved in the state dict
    but not trained. The cosine of a uniformly random phase has RMS 1/sqrt(2);
    the sqrt(2) factor brings the features to unit RMS.

    Args:
        num_channels: Number of output features.
        bandwidth: Multiplier on the standard deviation of the random frequencies.
    """
    def __init__(self, num_channels, bandwidth=1):
        super().__init__()
        self.register_buffer('freqs', 2 * np.pi * torch.randn(num_channels) * bandwidth)
        self.register_buffer('phases', 2 * np.pi * torch.rand(num_channels))

    def forward(self, x):
        """Return features of shape [len(x), num_channels] in x's dtype.

        `x` must be 1-D, as required by the outer product. The computation
        always runs in float32; only the result is cast back to x's dtype.
        """
        y = x.to(torch.float32)
        # torch.outer replaces the deprecated Tensor.ger; the semantics are identical.
        y = torch.outer(y, self.freqs.to(torch.float32))
        y = y + self.phases.to(torch.float32)
        y = y.cos() * np.sqrt(2)
        return y.to(x.dtype)
#----------------------------------------------------------------------------
# Magnitude-preserving convolution or fully-connected layer (Equation 47)
# with forced weight normalization (Equation 66).
@persistence.persistent_class
class MPConv(torch.nn.Module):
    """Magnitude-preserving convolution or fully-connected layer.

    `kernel` is the spatial kernel size given as a list:
      - `[]` gives a 2-D weight [out_channels, in_channels], applied as a
        linear layer (`x @ w.t()`).
      - `[kh, kw]` gives a 4-D weight, applied with `conv2d`.
    Weights are initialized from a standard normal distribution and are
    re-normalized on every forward pass (see `forward`).
    """
    def __init__(self, in_channels, out_channels, kernel):
        super().__init__()
        self.out_channels = out_channels  # exposed so owners can read the output width
        self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, *kernel))

    def forward(self, x, gain=1):
        """Apply the layer with per-output-channel normalized weights.

        Args:
            x: Input. With a 2-D weight its last dim must be `in_channels`.
                With a 4-D weight it is passed to `conv2d` (presumably NCHW).
            gain: Scalar multiplier on the normalized weight. May be a Python
                number or a tensor, e.g. a learnable scalar Parameter.

        Returns:
            `x @ w.t()` for a 2-D weight, otherwise `conv2d(x, w)` with
            `kernel_size // 2` padding.
        """
        # When the parameter is already float32, `.to` returns the same tensor,
        # so `w` aliases `self.weight` and sees the in-place update below.
        w = self.weight.to(torch.float32)
        if self.training:
            with torch.no_grad():
                # Overwrite the stored parameter with its normalized version
                # (each output channel rescaled to unit RMS by `normalize`).
                self.weight.copy_(normalize(w)) # forced weight normalization
        w = normalize(w) # traditional weight normalization
        # w[0].numel() is the fan-in (in_channels * kernel area). Dividing by
        # its sqrt gives each output filter unit L2 norm (times `gain`).
        w = w * (gain / np.sqrt(w[0].numel())) # magnitude-preserving scaling
        w = w.to(x.dtype)
        if w.ndim == 2:
            return x @ w.t()
        assert w.ndim == 4
        # Padding of kernel_size // 2 keeps the spatial size for odd square kernels.
        return torch.nn.functional.conv2d(x, w, padding=(w.shape[-1]//2,))
#----------------------------------------------------------------------------
# U-Net encoder/decoder block with optional self-attention (Figure 21).
@persistence.persistent_class
class Block(torch.nn.Module):
    """U-Net encoder/decoder block with optional self-attention (Figure 21).

    `forward` performs, in order:
      1. Optional 2x resampling of the input (`resample_mode`).
      2. Encoder flavor only: a 1x1 skip projection (only when the channel
         counts differ), then pixel norm, both applied to the main branch up
         front.
      3. Residual branch: mp_silu -> 3x3 conv -> per-channel multiplication by
         the embedding modulation -> mp_silu -> optional dropout -> 3x3 conv.
      4. Decoder flavor only: the 1x1 skip projection of the main branch.
      5. Magnitude-preserving sum of the main and residual branches.
      6. Optional multi-head self-attention, merged with a second mp_sum.
      7. Optional in-place clipping of the output activations.
    """
    def __init__(self,
        in_channels,                    # Number of input channels.
        out_channels,                   # Number of output channels.
        emb_channels,                   # Number of embedding channels.
        flavor              = 'enc',    # Flavor: 'enc' or 'dec'.
        resample_mode       = 'keep',   # Resampling: 'keep', 'up', or 'down'.
        resample_filter     = [1,1],    # Resampling filter. (Shared default; never mutated here.)
        attention           = False,    # Include self-attention?
        channels_per_head   = 64,       # Number of channels per attention head.
        dropout             = 0,        # Dropout probability.
        res_balance         = 0.3,      # Balance between main branch (0) and residual branch (1).
        attn_balance        = 0.3,      # Balance between main branch (0) and self-attention (1).
        clip_act            = 256,      # Clip output activations. None = do not clip.
    ):
        super().__init__()
        self.out_channels = out_channels
        self.flavor = flavor
        self.resample_filter = resample_filter
        self.resample_mode = resample_mode
        # 0 heads disables attention entirely (no qkv/proj layers are created).
        self.num_heads = out_channels // channels_per_head if attention else 0
        self.dropout = dropout
        self.res_balance = res_balance
        self.attn_balance = attn_balance
        self.clip_act = clip_act
        # Scalar gain on the embedding projection. It starts at 0, so the
        # modulation factor `c` in forward() is exactly 1 at initialization.
        self.emb_gain = torch.nn.Parameter(torch.zeros([]))
        # Encoder blocks project the input to out_channels before the residual
        # branch; decoder blocks feed the raw in_channels input into it.
        self.conv_res0 = MPConv(out_channels if flavor == 'enc' else in_channels, out_channels, kernel=[3,3])
        self.emb_linear = MPConv(emb_channels, out_channels, kernel=[])
        self.conv_res1 = MPConv(out_channels, out_channels, kernel=[3,3])
        self.conv_skip = MPConv(in_channels, out_channels, kernel=[1,1]) if in_channels != out_channels else None
        self.attn_qkv = MPConv(out_channels, out_channels * 3, kernel=[1,1]) if self.num_heads != 0 else None
        self.attn_proj = MPConv(out_channels, out_channels, kernel=[1,1]) if self.num_heads != 0 else None

    def forward(self, x, emb):
        """Apply the block.

        Args:
            x: 4-D feature map (the attention path reads x.shape[2] * x.shape[3]);
                presumably NCHW.
            emb: Conditioning embedding, projected by `emb_linear`. Assumed to be
                [N, emb_channels] — TODO confirm against callers.

        Returns:
            Tensor with `out_channels` channels. The spatial size is changed
            only by the resampling step.
        """
        # Main branch.
        x = resample(x, f=self.resample_filter, mode=self.resample_mode)
        if self.flavor == 'enc':
            if self.conv_skip is not None:
                x = self.conv_skip(x)
            x = normalize(x, dim=1) # pixel norm

        # Residual branch.
        y = self.conv_res0(mp_silu(x))
        # Per-channel multiplicative modulation. The +1 makes it the identity
        # while emb_gain is 0; the two unsqueezes broadcast it over spatial dims.
        c = self.emb_linear(emb, gain=self.emb_gain) + 1
        y = mp_silu(y * c.unsqueeze(2).unsqueeze(3).to(y.dtype))
        if self.training and self.dropout != 0:
            y = torch.nn.functional.dropout(y, p=self.dropout)
        y = self.conv_res1(y)

        # Connect the branches.
        if self.flavor == 'dec' and self.conv_skip is not None:
            x = self.conv_skip(x)
        x = mp_sum(x, y, t=self.res_balance)

        # Self-attention.
        # Note: torch.nn.functional.scaled_dot_product_attention() could be used here,
        # but we haven't done sufficient testing to verify that it produces identical results.
        if self.num_heads != 0:
            y = self.attn_qkv(x)
            # Split channels as (heads, channels per head, 3); the q/k/v index is
            # the fastest-varying part of the channel axis. Flatten the spatial dims.
            y = y.reshape(y.shape[0], self.num_heads, -1, 3, y.shape[2] * y.shape[3])
            # Normalize each q/k/v vector over its per-head channels, then split.
            q, k, v = normalize(y, dim=2).unbind(3) # pixel norm & split
            # Scaled dot-product weights [n, heads, query pos, key pos]; softmax over keys.
            w = torch.einsum('nhcq,nhck->nhqk', q, k / np.sqrt(q.shape[2])).softmax(dim=3)
            y = torch.einsum('nhqk,nhck->nhcq', w, v)
            y = self.attn_proj(y.reshape(*x.shape))
            x = mp_sum(x, y, t=self.attn_balance)

        # Clip activations.
        if self.clip_act is not None:
            x = x.clip_(-self.clip_act, self.clip_act)
        return x
#----------------------------------------------------------------------------
# EDM2 U-Net model (Figure 21).
@persistence.persistent_class
class UNet(torch.nn.Module):
    """EDM2 U-Net model (Figure 21).

    Encoder: resolution level `level` runs at img_resolution >> level.
      - Level 0 starts with a 3x3 input conv; every later level starts with a
        downsampling Block.
      - Each level then has `num_blocks` Blocks.
      - Every encoder module's output is saved as a skip connection.

    Decoder:
      - The lowest-resolution level starts with two Blocks: 'in0' (with
        attention) and 'in1'.
      - Every other level starts with an upsampling Block.
      - Each level then has `num_blocks + 1` Blocks, and each of them first
        concatenates one skip connection via mp_cat.
      - A final 3x3 conv maps the result to `img_channels`.

    The encoder produces len(channel_mult) * (num_blocks + 1) skips and the
    decoder consumes exactly that many, in reverse order.
    """
    def __init__(self,
        img_resolution,                     # Image resolution.
        img_channels,                       # Image channels.
        label_dim,                          # Class label dimensionality. 0 = unconditional.
        model_channels      = 192,          # Base multiplier for the number of channels.
        channel_mult        = [1,2,3,4],    # Per-resolution multipliers for the number of channels.
        channel_mult_noise  = None,         # Multiplier for noise embedding dimensionality. None = select based on channel_mult.
        channel_mult_emb    = None,         # Multiplier for final embedding dimensionality. None = select based on channel_mult.
        num_blocks          = 3,            # Number of residual blocks per resolution.
        attn_resolutions    = [16,8],       # List of resolutions with self-attention.
        label_balance       = 0.5,          # Balance between noise embedding (0) and class embedding (1).
        concat_balance      = 0.5,          # Balance between skip connections (0) and main path (1).
        **block_kwargs,                     # Arguments for Block.
    ):
        super().__init__()
        cblock = [model_channels * x for x in channel_mult]  # channel count per resolution level
        cnoise = model_channels * channel_mult_noise if channel_mult_noise is not None else cblock[0]
        cemb = model_channels * channel_mult_emb if channel_mult_emb is not None else max(cblock)
        self.label_balance = label_balance
        self.concat_balance = concat_balance
        # Output gain starts at 0, so out_conv (and the whole network) initially outputs zeros.
        self.out_gain = torch.nn.Parameter(torch.zeros([]))

        # Embedding.
        self.emb_fourier = MPFourier(cnoise)
        self.emb_noise = MPConv(cnoise, cemb, kernel=[])
        self.emb_label = MPConv(label_dim, cemb, kernel=[]) if label_dim != 0 else None

        # Encoder.
        # Module names encode resolution and role. ModuleDict preserves insertion
        # order, which forward() relies on for execution order and skip pairing.
        self.enc = torch.nn.ModuleDict()
        cout = img_channels + 1  # +1 for the constant-ones channel appended in forward()
        for level, channels in enumerate(cblock):
            res = img_resolution >> level
            if level == 0:
                cin = cout
                cout = channels
                self.enc[f'{res}x{res}_conv'] = MPConv(cin, cout, kernel=[3,3])
            else:
                self.enc[f'{res}x{res}_down'] = Block(cout, cout, cemb, flavor='enc', resample_mode='down', **block_kwargs)
            for idx in range(num_blocks):
                cin = cout
                cout = channels
                self.enc[f'{res}x{res}_block{idx}'] = Block(cin, cout, cemb, flavor='enc', attention=(res in attn_resolutions), **block_kwargs)

        # Decoder.
        self.dec = torch.nn.ModuleDict()
        # Channel counts of every encoder output, popped from the end (LIFO) so
        # they pair with the decoder blocks in the same order forward() uses.
        skips = [block.out_channels for block in self.enc.values()]
        for level, channels in reversed(list(enumerate(cblock))):
            res = img_resolution >> level
            if level == len(cblock) - 1:
                self.dec[f'{res}x{res}_in0'] = Block(cout, cout, cemb, flavor='dec', attention=True, **block_kwargs)
                self.dec[f'{res}x{res}_in1'] = Block(cout, cout, cemb, flavor='dec', **block_kwargs)
            else:
                self.dec[f'{res}x{res}_up'] = Block(cout, cout, cemb, flavor='dec', resample_mode='up', **block_kwargs)
            for idx in range(num_blocks + 1):
                cin = cout + skips.pop()  # input = main path concatenated with one skip
                cout = channels
                self.dec[f'{res}x{res}_block{idx}'] = Block(cin, cout, cemb, flavor='dec', attention=(res in attn_resolutions), **block_kwargs)
        self.out_conv = MPConv(cout, img_channels, kernel=[3,3])

    def forward(self, x, noise_labels, class_labels):
        """Run the U-Net.

        Args:
            x: Image batch with `img_channels` channels; a constant-ones channel
                is appended along dim 1 before the encoder.
            noise_labels: Noise conditioning, embedded by MPFourier (requires a
                1-D tensor).
            class_labels: Class conditioning fed to `emb_label`, or unused when
                label_dim == 0. Presumably [N, label_dim]; it is scaled by
                sqrt(label_dim), which gives one-hot labels unit RMS.

        Returns:
            Tensor with `img_channels` channels, scaled by the learned `out_gain`.
        """
        # Embedding.
        emb = self.emb_noise(self.emb_fourier(noise_labels))
        if self.emb_label is not None:
            emb = mp_sum(emb, self.emb_label(class_labels * np.sqrt(class_labels.shape[1])), t=self.label_balance)
        emb = mp_silu(emb)

        # Encoder.
        x = torch.cat([x, torch.ones_like(x[:, :1])], dim=1)
        skips = []
        for name, block in self.enc.items():
            # Only the input MPConv ('..._conv') takes no embedding.
            x = block(x) if 'conv' in name else block(x, emb)
            skips.append(x)

        # Decoder.
        for name, block in self.dec.items():
            # '..._block{idx}' modules consume a skip; 'in0'/'in1'/'up' do not.
            if 'block' in name:
                x = mp_cat(x, skips.pop(), t=self.concat_balance)
            x = block(x, emb)
        x = self.out_conv(x, gain=self.out_gain)
        return x
#----------------------------------------------------------------------------
# Preconditioning and uncertainty estimation.
@persistence.persistent_class
class Precond(torch.nn.Module):
    """Preconditioning wrapper around UNet with optional uncertainty estimation.

    Wraps the raw network F in the preconditioned denoiser
    D(x; sigma) = c_skip * x + c_out * F(c_in * x; c_noise). It can also return
    a per-sample log-variance u(sigma), predicted from c_noise by a small
    Fourier-feature + linear head.
    """
    def __init__(self,
        img_resolution,                 # Image resolution.
        img_channels,                   # Image channels.
        label_dim,                      # Class label dimensionality. 0 = unconditional.
        use_fp16        = True,         # Run the model at FP16 precision?
        sigma_data      = 0.5,          # Expected standard deviation of the training data.
        logvar_channels = 128,          # Intermediate dimensionality for uncertainty estimation.
        **unet_kwargs,                  # Keyword arguments for UNet.
    ):
        super().__init__()
        # Plain configuration attributes.
        self.img_resolution = img_resolution
        self.img_channels = img_channels
        self.label_dim = label_dim
        self.use_fp16 = use_fp16
        self.sigma_data = sigma_data
        # Submodules. Keep this construction order: their initial values are
        # drawn from the global torch RNG, so reordering would change the init.
        self.unet = UNet(img_resolution=img_resolution, img_channels=img_channels, label_dim=label_dim, **unet_kwargs)
        self.logvar_fourier = MPFourier(logvar_channels)
        self.logvar_linear = MPConv(logvar_channels, 1, kernel=[])

    def forward(self, x, sigma, class_labels=None, force_fp32=False, return_logvar=False, **unet_kwargs):
        """Denoise `x` at noise level `sigma`.

        With sd = sigma_data, computes D(x) = c_skip * x + c_out * F(c_in * x),
        where:
          c_skip  = sd^2 / (sigma^2 + sd^2)
          c_out   = sigma * sd / sqrt(sigma^2 + sd^2)
          c_in    = 1 / sqrt(sigma^2 + sd^2)
          c_noise = log(sigma) / 4

        Args:
            x: Noisy images; cast to float32.
            sigma: Noise levels, reshaped to [N, 1, 1, 1].
            class_labels: Optional labels. Ignored when label_dim == 0; replaced
                by a single all-zero row when omitted.
            force_fp32: Disable FP16 even if `use_fp16` is set.
            return_logvar: Also return the estimated log-variance.
            **unet_kwargs: Forwarded to the UNet call.

        Returns:
            D_x in float32, or (D_x, logvar) when `return_logvar` is true.
        """
        x = x.to(torch.float32)
        sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)

        # Resolve class labels.
        if self.label_dim == 0:
            class_labels = None
        elif class_labels is None:
            class_labels = torch.zeros([1, self.label_dim], device=x.device)
        else:
            class_labels = class_labels.to(torch.float32).reshape(-1, self.label_dim)

        # FP16 is used only when enabled, not overridden, and running on CUDA.
        want_fp16 = self.use_fp16 and not force_fp32 and x.device.type == 'cuda'
        dtype = torch.float16 if want_fp16 else torch.float32

        # Preconditioning weights.
        total_var = sigma ** 2 + self.sigma_data ** 2
        total_std = total_var.sqrt()
        c_skip = self.sigma_data ** 2 / total_var
        c_out = sigma * self.sigma_data / total_std
        c_in = 1 / total_std
        c_noise = sigma.flatten().log() / 4

        # Run the raw network on the scaled input and blend with the skip path.
        F_x = self.unet((c_in * x).to(dtype), c_noise, class_labels, **unet_kwargs)
        D_x = c_skip * x + c_out * F_x.to(torch.float32)

        if not return_logvar:
            return D_x
        # Estimate uncertainty: u(sigma) in Equation 21.
        logvar = self.logvar_linear(self.logvar_fourier(c_noise)).reshape(-1, 1, 1, 1)
        return D_x, logvar
#----------------------------------------------------------------------------
|