import re
import torch
import torch.nn as nn
from copy import deepcopy
from torch import Tensor
from torch.nn import Module, Linear, init
from typing import Any, Mapping, Union
from diffusion.model.nets import PixArtMSBlock, PixArtMS, PixArt, MVEncoder
from diffusion.model.nets.PixArt import get_2d_sincos_pos_embed
from diffusion.model.utils import auto_grad_checkpoint
# The implementation of the ControlNet-Half architecture
# https://github.com/lllyasviel/ControlNet/discussions/188
class ControlT2IDitBlockHalf(Module):
    def __init__(self, base_block: PixArtMSBlock, block_index: int = 0, zero_init: bool = True, base_size=None) -> None:
        super().__init__()
        self.copied_block = deepcopy(base_block)
        self.block_index = block_index

        for p in self.copied_block.parameters():
            p.requires_grad_(True)
        self.copied_block.load_state_dict(base_block.state_dict())
        self.copied_block.train()

        self.hidden_size = hidden_size = base_block.hidden_size
        if self.block_index == 0:
            self.before_proj = Linear(hidden_size, hidden_size)
            # before_proj is kept zero-initialized, so at the start of training the
            # first copied block sees exactly the base stream (x + 0)
            init.zeros_(self.before_proj.weight)
            init.zeros_(self.before_proj.bias)
        self.after_proj = Linear(hidden_size, hidden_size)
        if zero_init:
            init.zeros_(self.after_proj.weight)
            init.zeros_(self.after_proj.bias)
    def forward(self, x, y, t, mask=None, c=None, epipolar_constrains=None, cam_distances=None, n_views=None):
        if self.block_index == 0:
            # the first block: inject the projected condition into the base stream
            c = self.before_proj(c)
            c = self.copied_block(x + c, y, t, mask, epipolar_constrains=epipolar_constrains, cam_distances=cam_distances, n_views=n_views)
            c_skip = self.after_proj(c)
        else:
            # consume the previous c and produce both the next c and the skip connection
            c = self.copied_block(c, y, t, mask, epipolar_constrains=epipolar_constrains, cam_distances=cam_distances, n_views=n_views)
            c_skip = self.after_proj(c)
        return c, c_skip
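
# Illustrative sketch (standalone helper, hypothetical; not used by the model):
# demonstrates the zero-init property relied on above. A zero-initialized Linear
# maps any input to zeros, so c_skip == 0 at the start of training and
# `x + c_skip` leaves the frozen base model's behavior unchanged.
def _sketch_zero_init_skip(hidden_size: int = 8) -> None:
    proj = Linear(hidden_size, hidden_size)
    init.zeros_(proj.weight)
    init.zeros_(proj.bias)
    c = torch.randn(2, 4, hidden_size)  # (batch, tokens, hidden)
    c_skip = proj(c)
    assert torch.allclose(c_skip, torch.zeros_like(c_skip))
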
# The implementation of the ControlPixArtHalf net
class ControlPixArtHalf(Module):
    # only supports the single-resolution model
    def __init__(self, base_model: PixArt, copy_blocks_num: int = 13) -> None:
        super().__init__()
        self.base_model = base_model.eval()
        self.controlnet = []
        self.copy_blocks_num = copy_blocks_num
        self.total_blocks_num = len(base_model.blocks)
        for p in self.base_model.parameters():
            p.requires_grad_(False)

        # copy the first copy_blocks_num blocks
        for i in range(copy_blocks_num):
            self.controlnet.append(ControlT2IDitBlockHalf(base_model.blocks[i], i))
        self.controlnet = nn.ModuleList(self.controlnet)
    def __getattr__(self, name: str) -> Union[Tensor, Module]:
        if name in [
            'base_model',
            'controlnet',
            'encoder',
            'controlnet_t_block',
            'noise_embedding',
        ]:
            return super().__getattr__(name)
        else:
            # delegate every other attribute lookup to the frozen base model
            return getattr(self.base_model, name)
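
    # Illustrative note: thanks to the delegation above, attributes used below
    # (patch_size, pos_embed, x_embedder, t_block, final_layer, ...) resolve on
    # the wrapped base model, e.g.
    #     model = ControlPixArtHalf(base)
    #     model.patch_size   # -> base.patch_size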
    def forward_c(self, c):
        # patch-embed the condition map and add the positional embedding
        if c is None:
            return c
        self.h, self.w = c.shape[-2] // self.patch_size, c.shape[-1] // self.patch_size
        pos_embed = torch.from_numpy(
            get_2d_sincos_pos_embed(
                self.pos_embed.shape[-1], (self.h, self.w),
                pe_interpolation=self.pe_interpolation, base_size=self.base_size
            )
        ).unsqueeze(0).to(c.device).to(self.dtype)
        return self.x_embedder(c) + pos_embed
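
    # Shape sketch (hypothetical numbers): with patch_size = 2 and a (N, C, 64, 64)
    # condition map, forward_c yields (N, 32 * 32, hidden_size) tokens, matching
    # x's token layout so that `x + c` in the first control block is well-defined.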
    # def forward(self, x, t, c, **kwargs):
    #     return self.base_model(x, t, c=self.forward_c(c), **kwargs)
    def forward(self, x, timestep, y, mask=None, data_info=None, c=None, **kwargs):
        """
        Forward pass of PixArt, modified from the original forward function.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        timestep: (N,) tensor of diffusion timesteps
        y: (N, 1, 120, C) tensor of caption token embeddings
        c: optional condition image, embedded into tokens via forward_c
        """
        if c is not None:
            c = c.to(self.dtype)
            c = self.forward_c(c)
        x = x.to(self.dtype)
        timestep = timestep.to(self.dtype)
        y = y.to(self.dtype)
        pos_embed = self.pos_embed.to(self.dtype)
        self.h, self.w = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size
        x = self.x_embedder(x) + pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
        t = self.t_embedder(timestep.to(x.dtype))  # (N, D)
        t0 = self.t_block(t)
        y = self.y_embedder(y, self.training)  # (N, 1, L, D)
        if mask is not None:
            if mask.shape[0] != y.shape[0]:
                mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
            mask = mask.squeeze(1).squeeze(1)
            y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
            y_lens = mask.sum(dim=1).tolist()
        else:
            y_lens = [y.shape[2]] * y.shape[0]
            y = y.squeeze(1).view(1, -1, x.shape[-1])
        # run the first base block outside the loop so the control branch can read its output
        x = auto_grad_checkpoint(self.base_model.blocks[0], x, y, t0, y_lens, **kwargs)  # (N, T, D); supports grad checkpointing

        if c is not None:
            # controlled half: update c and inject c_skip into the frozen base stream
            for index in range(1, self.copy_blocks_num + 1):
                c, c_skip = auto_grad_checkpoint(self.controlnet[index - 1], x, y, t0, y_lens, c, **kwargs)
                x = auto_grad_checkpoint(self.base_model.blocks[index], x + c_skip, y, t0, y_lens, **kwargs)

            # remaining half: plain frozen base blocks
            for index in range(self.copy_blocks_num + 1, self.total_blocks_num):
                x = auto_grad_checkpoint(self.base_model.blocks[index], x, y, t0, y_lens, **kwargs)
        else:
            for index in range(1, self.total_blocks_num):
                x = auto_grad_checkpoint(self.base_model.blocks[index], x, y, t0, y_lens, **kwargs)

        x = self.final_layer(x, t)  # (N, T, patch_size ** 2 * out_channels)
        x = self.unpatchify(x)  # (N, out_channels, H, W)
        return x
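
    # Wiring sketch (schematic comment, hypothetical block count): for
    # copy_blocks_num = 13 and e.g. a 28-block base model, block 0 runs frozen,
    # blocks 1..13 each receive x + after_proj(c_i) from their trainable copies,
    # and blocks 14..27 run as the plain frozen backbone. Only the copied half
    # and the zero-initialized projections receive gradients.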
    def forward_with_dpmsolver(self, x, t, y, data_info, c, **kwargs):
        model_out = self.forward(x, t, y, data_info=data_info, c=c, **kwargs)
        # the model predicts [eps, variance] stacked along the channel dim; DPM-Solver only needs eps
        return model_out.chunk(2, dim=1)[0]
    def forward_with_cfg(self, x, timestep, y, cfg_scale, data_info, c, **kwargs):
        """
        Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance.
        """
        # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
        half = x[: len(x) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = self.forward(combined, timestep, y, data_info=data_info, c=c)
        eps, rest = model_out[:, :3], model_out[:, 3:]  # following DiT, guidance is applied to the first three channels only
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        return torch.cat([eps, rest], dim=1)
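
    # Illustrative sketch (hypothetical value): with cfg_scale = 4.5,
    #     half_eps = uncond_eps + 4.5 * (cond_eps - uncond_eps)
    # cfg_scale = 1.0 recovers the conditional prediction, cfg_scale = 0.0 the
    # unconditional one; larger values extrapolate toward the condition.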
    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
        if all(k.startswith(('base_model', 'controlnet', 'encoder', 'controlnet_t_block', 'noise_embedding')) for k in state_dict.keys()):
            return super().load_state_dict(state_dict, strict)
        else:
            # plain PixArt checkpoint: remap the block keys onto the wrapped base model
            new_key = {}
            for k in state_dict.keys():
                new_key[k] = re.sub(r"(blocks\.\d+)(.*)", r"\1.base_block\2", k)
            for k, v in new_key.items():
                if k != v:
                    print(f"replace {k} with {v}")
                    state_dict[v] = state_dict.pop(k)
            return self.base_model.load_state_dict(state_dict, strict)
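
    # Illustrative sketch (hypothetical key name): the regex above rewrites
    # plain checkpoint keys to address the wrapped transformer blocks, e.g.
    #     re.sub(r"(blocks\.\d+)(.*)", r"\1.base_block\2", "blocks.3.attn.proj.weight")
    #     # -> "blocks.3.base_block.attn.proj.weight"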
    def unpatchify(self, x):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, C, H, W)
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        assert self.h * self.w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], self.h, self.w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, self.h * p, self.w * p))
        return imgs
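
    # Shape sketch (hypothetical numbers): with h = w = 32, p = 2, c = 8,
    # x: (N, 1024, 32) -> (N, 32, 32, 2, 2, 8) -> einsum 'nhwpqc->nchpwq'
    # -> (N, 8, 32, 2, 32, 2) -> imgs: (N, 8, 64, 64).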
    @property
    def dtype(self):
        # exposed as a property so `self.dtype` in forward() yields the dtype itself
        return next(self.parameters()).dtype
# The implementation of PixArtMS_Half + 1024 resolution
class ControlPixArtMSHalf(ControlPixArtHalf):
    # supports the multi-scale model (which can also be applied to single-resolution training & inference)
    def __init__(self, base_model: PixArtMS, copy_blocks_num: int = 13) -> None:
        super().__init__(base_model=base_model, copy_blocks_num=copy_blocks_num)
    def forward(self, x, timestep, y, mask=None, data_info=None, c=None, need_forward_c=True, **kwargs):
        """
        Forward pass of PixArtMS, modified from the original forward function.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        timestep: (N,) tensor of diffusion timesteps
        y: (N, 1, 120, C) tensor of caption token embeddings
        """
        if c is not None and need_forward_c:
            c = c.to(self.dtype)
            c = self.forward_c(c)
        bs = x.shape[0]
        x = x.to(self.dtype)
        timestep = timestep.to(self.dtype)
        y = y.to(self.dtype)
        self.h, self.w = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size
        pos_embed = torch.from_numpy(
            get_2d_sincos_pos_embed(
                self.pos_embed.shape[-1], (self.h, self.w), pe_interpolation=self.pe_interpolation,
                base_size=self.base_size
            )
        ).unsqueeze(0).to(x.device).to(self.dtype)
        x = self.x_embedder(x) + pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
        t = self.t_embedder(timestep)  # (N, D)
        if self.micro_conditioning:
            # micro-conditioning on image size and aspect ratio, as in PixArtMS
            c_size, ar = data_info['img_hw'].to(self.dtype), data_info['aspect_ratio'].to(self.dtype)
            csize = self.csize_embedder(c_size, bs)  # (N, D)
            ar = self.ar_embedder(ar, bs)  # (N, D)
            t = t + torch.cat([csize, ar], dim=1)
        t0 = self.t_block(t)
        y = self.y_embedder(y, self.training)  # (N, 1, L, D)
        if mask is not None:
            if mask.shape[0] != y.shape[0]:
                mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
            mask = mask.squeeze(1).squeeze(1)
            y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
            y_lens = mask.sum(dim=1).tolist()
            y_lens = [int(item) for item in y_lens]
        else:
            y_lens = [y.shape[2]] * y.shape[0]
            y = y.squeeze(1).view(1, -1, x.shape[-1])
        # run the first base block outside the loop so the control branch can read its output
        x = auto_grad_checkpoint(self.base_model.blocks[0], x, y, t0, y_lens, **kwargs)  # (N, T, D); supports grad checkpointing

        if c is not None:
            # controlled half: update c and inject c_skip into the frozen base stream
            for index in range(1, self.copy_blocks_num + 1):
                c, c_skip = auto_grad_checkpoint(self.controlnet[index - 1], x, y, t0, y_lens, c, **kwargs)
                x = auto_grad_checkpoint(self.base_model.blocks[index], x + c_skip, y, t0, y_lens, **kwargs)

            # remaining half: plain frozen base blocks
            for index in range(self.copy_blocks_num + 1, self.total_blocks_num):
                x = auto_grad_checkpoint(self.base_model.blocks[index], x, y, t0, y_lens, **kwargs)
        else:
            for index in range(1, self.total_blocks_num):
                x = auto_grad_checkpoint(self.base_model.blocks[index], x, y, t0, y_lens, **kwargs)

        x = self.final_layer(x, t)  # (N, T, patch_size ** 2 * out_channels)
        x = self.unpatchify(x)  # (N, out_channels, H, W)
        return x
# 3DEnhancer backbone
class ControlPixArtMSMVHalfWithEncoder(ControlPixArtMSHalf):
    def __init__(self, base_model: PixArtMS, copy_blocks_num: int = 13) -> None:
        super().__init__(base_model=base_model, copy_blocks_num=copy_blocks_num)
        # multi-view encoder that turns RGB images + camera pose maps into condition tokens
        self.encoder = MVEncoder(
            double_z=False,
            resolution=512,
            in_channels=9,
            ch=64,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=1,
            dropout=0.0,
            attn_resolutions=[],
            out_ch=3,  # unused
            z_channels=self.hidden_size,
            attn_kwargs={
                'n_heads': 8,
                'd_head': 64,
            },
            z_downsample_size=2,
        )
        # zero-initialized noise-level embedding: controlnet_t == t at the start of training
        self.noise_embedding = nn.Embedding(500, self.hidden_size)
        self.noise_embedding.weight.data.fill_(0)
        self.controlnet_t_block = nn.Sequential(
            nn.SiLU(),
            nn.Linear(self.hidden_size, 6 * self.hidden_size, bias=True)
        )
        self.attention_token_num = self.base_size ** 2
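
        # Note (assumption about MVEncoder's downsampling): ch_mult=[1, 2, 4, 4]
        # implies three 2x downsampling stages, and z_downsample_size=2 adds one
        # more, so a 512x512 input yields a 32x32 feature map, i.e.
        # base_size**2 = attention_token_num tokens per view.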
    def encode(self, input_img, camera_pose, n_views):
        # fuse images and pose maps along the channel dim:
        # input_img: (B, 3, H, W), camera_pose: (B, 6, H, W), with B a multiple of n_views
        z_lq = torch.cat((input_img, camera_pose), dim=1)
        z_lq = self.encoder(z_lq, n_views)
        # flatten the (B, hidden_size, h', w') feature map into per-view token sequences
        z_lq = z_lq.permute(0, 2, 3, 1).reshape(-1, self.attention_token_num, self.hidden_size)
        return z_lq
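
    # Shape sketch (hypothetical numbers): for n_views = 4 views at 512x512,
    # (4, 3, 512, 512) images and (4, 6, 512, 512) pose maps fuse into a
    # (4, 9, 512, 512) tensor, are encoded to (4, hidden_size, 32, 32), and are
    # flattened to (4, 32 * 32, hidden_size) condition token streams.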
    def forward(self, x, timestep, y, mask=None, data_info=None, input_img=None, camera_pose=None, c=None, noise_level=None, epipolar_constrains=None, cam_distances=None, n_views=None, **kwargs):
        """
        Forward pass of PixArt with the multi-view condition encoder.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        timestep: (N,) tensor of diffusion timesteps
        y: (N, 1, 120, C) tensor of caption token embeddings
        """
        # encode the multi-view images (+ camera poses) unless condition tokens are given
        c = self.encode(input_img, camera_pose, n_views).to(x.dtype) if c is None else c
        bs = x.shape[0]
        x = x.to(self.dtype)
        timestep = timestep.to(self.dtype)
        y = y.to(self.dtype)
        self.h, self.w = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size
        pos_embed = torch.from_numpy(
            get_2d_sincos_pos_embed(
                self.pos_embed.shape[-1], (self.h, self.w), pe_interpolation=self.pe_interpolation,
                base_size=self.base_size
            )
        ).unsqueeze(0).to(x.device).to(self.dtype)
        x = self.x_embedder(x) + pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
        t = self.t_embedder(timestep)  # (N, D)
        # the control branch is additionally conditioned on the input degradation noise level
        noise_level = self.noise_embedding(noise_level)
        controlnet_t = t + noise_level
        if self.micro_conditioning:
            c_size, ar = data_info['img_hw'].to(self.dtype), data_info['aspect_ratio'].to(self.dtype)
            csize = self.csize_embedder(c_size, bs)  # (N, D)
            ar = self.ar_embedder(ar, bs)  # (N, D)
            t = t + torch.cat([csize, ar], dim=1)
        t0 = self.t_block(t)
        controlnet_t0 = self.controlnet_t_block(controlnet_t)
        y = self.y_embedder(y, self.training)  # (N, 1, L, D)
        if mask is not None:
            if mask.shape[0] != y.shape[0]:
                mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
            mask = mask.squeeze(1).squeeze(1)
            y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
            y_lens = mask.sum(dim=1).tolist()
            y_lens = [int(item) for item in y_lens]
        else:
            y_lens = [y.shape[2]] * y.shape[0]
            y = y.squeeze(1).view(1, -1, x.shape[-1])
        # run the first base block outside the loop so the control branch can read its output
        x = auto_grad_checkpoint(self.base_model.blocks[0], x, y, t0, y_lens, None, None, epipolar_constrains, cam_distances, n_views, **kwargs)  # (N, T, D); supports grad checkpointing

        if c is not None:
            # controlled half: the copied blocks receive the noise-level-aware controlnet_t0
            for index in range(1, self.copy_blocks_num + 1):
                c, c_skip = auto_grad_checkpoint(self.controlnet[index - 1], x, y, controlnet_t0, y_lens, c, epipolar_constrains=epipolar_constrains, cam_distances=cam_distances, n_views=n_views, **kwargs)
                x = auto_grad_checkpoint(self.base_model.blocks[index], x + c_skip, y, t0, y_lens, None, None, epipolar_constrains, cam_distances, n_views, **kwargs)

            # remaining half: plain frozen base blocks
            for index in range(self.copy_blocks_num + 1, self.total_blocks_num):
                x = auto_grad_checkpoint(self.base_model.blocks[index], x, y, t0, y_lens, None, None, epipolar_constrains, cam_distances, n_views, **kwargs)
        else:
            for index in range(1, self.total_blocks_num):
                x = auto_grad_checkpoint(self.base_model.blocks[index], x, y, t0, y_lens, None, None, epipolar_constrains, cam_distances, n_views, **kwargs)

        x = self.final_layer(x, t)  # (N, T, patch_size ** 2 * out_channels)
        x = self.unpatchify(x)  # (N, out_channels, H, W)
        return x
    def forward_with_dpmsolver(self, x, t, y, data_info, c, noise_level, epipolar_constrains, cam_distances, n_views, **kwargs):
        model_out = self.forward(x, t, y, data_info=data_info, c=c, noise_level=noise_level, epipolar_constrains=epipolar_constrains, cam_distances=cam_distances, n_views=n_views, **kwargs)
        # the model predicts [eps, variance] stacked along the channel dim; DPM-Solver only needs eps
        return model_out.chunk(2, dim=1)[0]
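
# Usage sketch (hypothetical; the PixArtMS constructor arguments and input
# shapes are assumptions, not part of this module):
#     base = PixArtMS(...)  # pretrained multi-scale PixArt backbone
#     model = ControlPixArtMSMVHalfWithEncoder(base, copy_blocks_num=13)
#     out = model(x, timestep, y, data_info=data_info,
#                 input_img=lq_views, camera_pose=poses, noise_level=nl,
#                 epipolar_constrains=ec, cam_distances=cd, n_views=4)
#     # out: (N, out_channels, H, W), eps and variance stacked along dim 1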