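"""DiTACE: Diffusion Transformer backbone for ACE, registered as a scepter BACKBONE."""
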
import re
from collections import OrderedDict
from functools import partial

import torch
import torch.nn as nn
from einops import rearrange
from torch.nn.utils.rnn import pad_sequence
from torch.utils.checkpoint import checkpoint_sequential

from scepter.modules.model.base_model import BaseModel
from scepter.modules.model.registry import BACKBONES
from scepter.modules.utils.config import dict_to_yaml
from scepter.modules.utils.file_system import FS

from .layers import (Mlp, TimestepEmbedder, PatchEmbed, DiTACEBlock,
                     T2IFinalLayer)
from .pos_embed import rope_params


@BACKBONES.register_class()
class DiTACE(BaseModel):
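    """Diffusion Transformer (DiT) backbone used by ACE.

    Target and edit latents are patchified per frame and packed, together
    with their masks, into one long token sequence; DiTACEBlock layers then
    apply self-attention with rotary position embeddings and cross-attention
    to the packed text embeddings.
    """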

    para_dict = {
        'PATCH_SIZE': {
            'value': 2,
            'description': 'Spatial patch size used by the patch embedding.'
        },
        'IN_CHANNELS': {
            'value': 4,
            'description': 'Number of input latent channels (a mask channel is added internally).'
        },
        'HIDDEN_SIZE': {
            'value': 1152,
            'description': 'Transformer hidden dimension.'
        },
        'DEPTH': {
            'value': 28,
            'description': 'Number of DiTACE blocks.'
        },
        'NUM_HEADS': {
            'value': 16,
            'description': 'Number of attention heads.'
        },
        'MLP_RATIO': {
            'value': 4.0,
            'description': 'Ratio of the MLP hidden dimension to HIDDEN_SIZE.'
        },
        'PRED_SIGMA': {
            'value': True,
            'description': 'Whether to also predict sigma; doubles the output channels.'
        },
        'DROP_PATH': {
            'value': 0.,
            'description': 'Maximum stochastic-depth rate, scaled linearly over the blocks.'
        },
        'WINDOW_SIZE': {
            'value': 0,
            'description': 'Window size for windowed-attention blocks (0 disables windowing).'
        },
        'WINDOW_BLOCK_INDEXES': {
            'value': None,
            'description': 'Indexes of the blocks that use windowed attention.'
        },
        'Y_CHANNELS': {
            'value': 4096,
            'description': 'Dimension of the input text embeddings.'
        },
        'ATTENTION_BACKEND': {
            'value': None,
            'description': 'Attention backend passed to the blocks.'
        },
        'QK_NORM': {
            'value': True,
            'description': 'Whether to use RMSNorm for query and key.',
        },
    }
    para_dict.update(BaseModel.para_dict)

    def __init__(self, cfg, logger):
        super().__init__(cfg, logger=logger)
        self.window_block_indexes = cfg.get('WINDOW_BLOCK_INDEXES', None)
        if self.window_block_indexes is None:
            self.window_block_indexes = []
        self.window_size = cfg.get('WINDOW_SIZE', 0)
        self.pred_sigma = cfg.get('PRED_SIGMA', True)
        self.in_channels = cfg.get('IN_CHANNELS', 4)
        self.out_channels = self.in_channels * 2 if self.pred_sigma else self.in_channels
        self.patch_size = cfg.get('PATCH_SIZE', 2)
        self.num_heads = cfg.get('NUM_HEADS', 16)
        self.hidden_size = cfg.get('HIDDEN_SIZE', 1152)
        self.y_channels = cfg.get('Y_CHANNELS', 4096)
        self.drop_path = cfg.get('DROP_PATH', 0.)
        self.depth = cfg.get('DEPTH', 28)
        self.mlp_ratio = cfg.get('MLP_RATIO', 4.0)
        self.use_grad_checkpoint = cfg.get('USE_GRAD_CHECKPOINT', False)
        self.attention_backend = cfg.get('ATTENTION_BACKEND', None)
        self.max_seq_len = cfg.get('MAX_SEQ_LEN', 1024)
        self.qk_norm = cfg.get('QK_NORM', False)
        self.ignore_keys = cfg.get('IGNORE_KEYS', [])
        assert (self.hidden_size % self.num_heads) == 0 and (
            self.hidden_size // self.num_heads) % 2 == 0
        # Rotary position-embedding frequencies, split into three bands
        # across the per-head dimension.
        d = self.hidden_size // self.num_heads
        self.freqs = torch.cat([
            rope_params(self.max_seq_len, d - 4 * (d // 6)),
            rope_params(self.max_seq_len, 2 * (d // 6)),
            rope_params(self.max_seq_len, 2 * (d // 6))
        ], dim=1)

        # The extra input channel carries the per-frame mask that is
        # concatenated to the latent in forward().
        self.x_embedder = PatchEmbed(self.patch_size,
                                     self.in_channels + 1,
                                     self.hidden_size,
                                     bias=True,
                                     flatten=False)
        self.t_embedder = TimestepEmbedder(self.hidden_size)
        self.y_embedder = Mlp(in_features=self.y_channels,
                              hidden_features=self.hidden_size,
                              out_features=self.hidden_size,
                              act_layer=lambda: nn.GELU(approximate='tanh'),
                              drop=0)
        self.t_block = nn.Sequential(
            nn.SiLU(),
            nn.Linear(self.hidden_size, 6 * self.hidden_size, bias=True))

        # Stochastic depth increases linearly from 0 to DROP_PATH over depth.
        drop_path = [
            x.item() for x in torch.linspace(0, self.drop_path, self.depth)
        ]
        self.blocks = nn.ModuleList([
            DiTACEBlock(self.hidden_size,
                        self.num_heads,
                        mlp_ratio=self.mlp_ratio,
                        drop_path=drop_path[i],
                        window_size=self.window_size
                        if i in self.window_block_indexes else 0,
                        backend=self.attention_backend,
                        use_condition=True,
                        qk_norm=self.qk_norm) for i in range(self.depth)
        ])
        self.final_layer = T2IFinalLayer(self.hidden_size, self.patch_size,
                                         self.out_channels)
        self.initialize_weights()

    def load_pretrained_model(self, pretrained_model):
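        """Load and remap a pretrained checkpoint.

        Keys matching ``IGNORE_KEYS`` are skipped; fused qkv/kv projections
        and legacy layer names are split or renamed to this module's attention
        layout, and the patch-embedding weight is padded when the checkpoint
        lacks the extra mask channel. Loading is non-strict.
        """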
        if pretrained_model:
            with FS.get_from(pretrained_model, wait_finish=True) as local_path:
                model = torch.load(local_path, map_location='cpu')
                if 'state_dict' in model:
                    model = model['state_dict']
                new_ckpt = OrderedDict()
                for k, v in model.items():
                    if self.ignore_keys is not None:
                        if (isinstance(self.ignore_keys, str) and re.match(self.ignore_keys, k)) or \
                                (isinstance(self.ignore_keys, list) and k in self.ignore_keys):
                            continue
                    # Rename legacy projection layers to the current layout.
                    k = k.replace('.cross_attn.q_linear.', '.cross_attn.q.')
                    k = k.replace('.cross_attn.proj.',
                                  '.cross_attn.o.').replace(
                                      '.attn.proj.', '.attn.o.')
                    if '.cross_attn.kv_linear.' in k:
                        # Split the fused KV projection into separate k/v tensors.
                        k_p, v_p = torch.split(v, v.shape[0] // 2)
                        new_ckpt[k.replace('.cross_attn.kv_linear.',
                                           '.cross_attn.k.')] = k_p
                        new_ckpt[k.replace('.cross_attn.kv_linear.',
                                           '.cross_attn.v.')] = v_p
                    elif '.attn.qkv.' in k:
                        # Split the fused QKV projection into separate q/k/v tensors.
                        q_p, k_p, v_p = torch.split(v, v.shape[0] // 3)
                        new_ckpt[k.replace('.attn.qkv.', '.attn.q.')] = q_p
                        new_ckpt[k.replace('.attn.qkv.', '.attn.k.')] = k_p
                        new_ckpt[k.replace('.attn.qkv.', '.attn.v.')] = v_p
                    elif 'y_embedder.y_proj.' in k:
                        new_ckpt[k.replace('y_embedder.y_proj.',
                                           'y_embedder.')] = v
                    elif k == 'x_embedder.proj.weight':
                        # The patch embedding here has an extra mask channel;
                        # copy the pretrained weight into the first four input
                        # channels when the shapes differ.
                        model_p = self.state_dict()[k]
                        if v.shape != model_p.shape:
                            model_p.zero_()
                            model_p[:, :4, :, :].copy_(v)
                            new_ckpt[k] = torch.nn.parameter.Parameter(model_p)
                        else:
                            new_ckpt[k] = v
                    elif k == 'x_embedder.proj.bias':
                        new_ckpt[k] = v
                    else:
                        new_ckpt[k] = v
                missing, unexpected = self.load_state_dict(new_ckpt,
                                                           strict=False)
                print(
                    f'Restored from {pretrained_model} with {len(missing)} missing and {len(unexpected)} unexpected keys'
                )
                if len(missing) > 0:
                    print(f'Missing Keys:\n {missing}')
                if len(unexpected) > 0:
                    print(f'\nUnexpected Keys:\n {unexpected}')

    def forward(self,
                x,
                t=None,
                cond=dict(),
                mask=None,
                text_position_embeddings=None,
                gc_seg=-1,
                **kwargs):
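        """Run the packed-sequence DiT forward pass.

        ``x`` holds the flattened target latents, ``t`` the per-sample
        timesteps, and ``cond`` the conditioning dict (``crossattn`` text
        embeddings, ``x_shapes``, ``x_mask`` and optional ``edit`` /
        ``edit_mask`` latents). ``mask`` lists the valid-token masks of the
        text embeddings. When gradient checkpointing is enabled, ``gc_seg``
        selects the number of checkpoint segments (``0`` uses one segment per
        block, negative values disable checkpointing).
        """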
        if self.freqs.device != x.device:
            self.freqs = self.freqs.to(x.device)
        if isinstance(cond, dict):
            context = cond.get('crossattn', None)
        else:
            context = cond
        if text_position_embeddings is not None:
            proj_position_embeddings = self.y_embedder(
                text_position_embeddings)
        else:
            proj_position_embeddings = None

        # Project the per-frame text embeddings and pack them into one
        # sequence, recording the valid token count of each frame.
        ctx_batch, txt_lens = [], []
        if mask is not None and isinstance(mask, list):
            for ctx, ctx_mask in zip(context, mask):
                for frame_id, one_ctx in enumerate(zip(ctx, ctx_mask)):
                    u, m = one_ctx
                    t_len = m.flatten().sum()
                    u = u[:t_len]
                    u = self.y_embedder(u)
                    if proj_position_embeddings is not None:
                        pe_id = len(ctx) - 1 if frame_id == 0 else frame_id - 1
                        u = u + proj_position_embeddings[pe_id]
                    ctx_batch.append(u)
                    txt_lens.append(t_len)
        else:
            raise TypeError
        y = torch.cat(ctx_batch, dim=0)
        txt_lens = torch.LongTensor(txt_lens).to(x.device, non_blocking=True)

        # Gather the target latent and any edit (reference) latents of each
        # sample, concatenating each latent with its mask along channels.
        batch_frames = []
        for u, shape, m in zip(x, cond['x_shapes'], cond['x_mask']):
            u = u[:, :shape[0] * shape[1]].view(-1, shape[0], shape[1])
            m = torch.ones_like(u[[0], :, :]) if m is None else m.squeeze(0)
            batch_frames.append([torch.cat([u, m], dim=0).unsqueeze(0)])
        if 'edit' in cond:
            for i, (edit, edit_mask) in enumerate(
                    zip(cond['edit'], cond['edit_mask'])):
                if edit is None:
                    continue
                for u, m in zip(edit, edit_mask):
                    u = u.squeeze(0)
                    m = torch.ones_like(
                        u[[0], :, :]) if m is None else m.squeeze(0)
                    batch_frames[i].append(
                        torch.cat([u, m], dim=0).unsqueeze(0))

        # Patchify every frame and pack all tokens into one long sequence;
        # self_x_len counts tokens per sample, cross_x_len per frame.
        patch_batch, shape_batch, self_x_len, cross_x_len = [], [], [], []
        for frames in batch_frames:
            patches, patch_shapes = [], []
            self_x_len.append(0)
            for frame_id, u in enumerate(frames):
                u = self.x_embedder(u)
                h, w = u.size(2), u.size(3)
                u = rearrange(u, '1 c h w -> (h w) c')
                if proj_position_embeddings is not None:
                    pe_id = len(frames) - 1 if frame_id == 0 else frame_id - 1
                    u = u + proj_position_embeddings[pe_id]
                patches.append(u)
                patch_shapes.append([h, w])
                cross_x_len.append(h * w)
                self_x_len[-1] += h * w
            patch_batch.extend(patches)
            shape_batch.append(
                torch.LongTensor(patch_shapes).to(x.device, non_blocking=True))

        # Broadcast each sample's timestep to all of its tokens.
        t = torch.cat([t[i].repeat(l) for i, l in enumerate(self_x_len)])
        self_x_len, cross_x_len = (torch.LongTensor(self_x_len).to(
            x.device, non_blocking=True), torch.LongTensor(cross_x_len).to(
                x.device, non_blocking=True))

        x = torch.cat(patch_batch, dim=0)
        x_shapes = pad_sequence(tuple(shape_batch), batch_first=True)
        t = self.t_embedder(t)
        t0 = self.t_block(t)

        kwargs = dict(y=y,
                      t=t0,
                      x_shapes=x_shapes,
                      self_x_len=self_x_len,
                      cross_x_len=cross_x_len,
                      freqs=self.freqs,
                      txt_lens=txt_lens)
        if self.use_grad_checkpoint and gc_seg >= 0:
            x = checkpoint_sequential(
                functions=[partial(block, **kwargs) for block in self.blocks],
                segments=gc_seg if gc_seg > 0 else len(self.blocks),
                input=x,
                use_reentrant=False)
        else:
            for block in self.blocks:
                x = block(x, **kwargs)
        x = self.final_layer(x, t)

        # Unpatchify the target frame of each sample and pad the batch back
        # to a dense (B, C, L) tensor.
        outs, cur_length = [], 0
        p = self.patch_size
        for seq_length, shape in zip(self_x_len, shape_batch):
            x_i = x[cur_length:cur_length + seq_length]
            h, w = shape[0].tolist()
            u = x_i[:h * w].view(h, w, p, p, -1)
            u = rearrange(u, 'h w p q c -> (h p w q) c')
            cur_length = cur_length + seq_length
            outs.append(u)
        x = pad_sequence(tuple(outs), batch_first=True).permute(0, 2, 1)
        if self.pred_sigma:
            return x.chunk(2, dim=1)[0]
        else:
            return x

    def initialize_weights(self):

        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize the patch-embedding projection like a linear layer.
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
        nn.init.normal_(self.t_block[1].weight, std=0.02)

        if hasattr(self, 'y_embedder'):
            nn.init.normal_(self.y_embedder.fc1.weight, std=0.02)
            nn.init.normal_(self.y_embedder.fc2.weight, std=0.02)

        # Zero-init the cross-attention output projections and the final
        # layer so these branches contribute nothing at initialization.
        for block in self.blocks:
            nn.init.constant_(block.cross_attn.o.weight, 0)
            nn.init.constant_(block.cross_attn.o.bias, 0)

        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    @property
    def dtype(self):
        return next(self.parameters()).dtype

    @staticmethod
    def get_config_template():
        return dict_to_yaml('BACKBONE',
                            __class__.__name__,
                            DiTACE.para_dict,
                            set_name=True)
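

# Minimal, hypothetical usage sketch (illustrative only, not part of this
# module). It assumes scepter's Config can be built from a plain dict via
# Config(cfg_dict=..., load=False) and that a standard logging.Logger is
# acceptable; the values mirror the defaults in DiTACE.para_dict.
#
#   import logging
#   from scepter.modules.utils.config import Config
#
#   cfg = Config(cfg_dict={'NAME': 'DiTACE',
#                          'PATCH_SIZE': 2,
#                          'IN_CHANNELS': 4,
#                          'HIDDEN_SIZE': 1152,
#                          'DEPTH': 28,
#                          'NUM_HEADS': 16,
#                          'QK_NORM': True},
#                load=False)
#   model = DiTACE(cfg, logger=logging.getLogger(__name__))
#   print(DiTACE.get_config_template())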