from functools import partial
import inspect
import weakref
import torch
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.lora_special import LoRASpecialNetwork
from diffusers import WanTransformer3DModel
from transformers import SiglipImageProcessor, SiglipVisionModel, CLIPImageProcessor, CLIPVisionModelWithProjection
from diffusers.models.attention_processor import Attention
from diffusers.models.transformers.transformer_wan import WanImageEmbedding, WanTimeTextImageEmbedding
from toolkit.util.shuffle import shuffle_tensor_along_axis
import torch.nn.functional as F

if TYPE_CHECKING:
    from toolkit.models.base_model import BaseModel
    from toolkit.config_modules import AdapterConfig, TrainConfig, ModelConfig
    from toolkit.custom_adapter import CustomAdapter
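
# Rough overview of this module: it retrofits a text-to-video WanTransformer3DModel with
# image-to-video conditioning by (1) adding a second patch embedding for the 20 extra
# first-frame/mask latent channels (FrameEmbedder), (2) adding trainable image K/V
# projections to each block's cross attention (AttentionHog), and (3) projecting the
# vision encoder's hidden states into the transformer's inner dimension via a
# WanImageEmbedding. See the individual classes below for details.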


class FrameEmbedder(torch.nn.Module):
    def __init__(
        self,
        adapter: 'I2VAdapter',
        orig_layer: torch.nn.Conv3d,
        in_channels=20,  # wan is 16 normally, and 36 with i2v, so 20 new channels
    ):
        super().__init__()
        # the input goes through a conv patch embedding first and is then flattened:
        # hidden_states = self.patch_embedding(hidden_states)
        # hidden_states = hidden_states.flatten(2).transpose(1, 2)
        inner_dim = orig_layer.out_channels
        patch_size = adapter.sd_ref().model.config.patch_size
        self.patch_embedding = torch.nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
        self.adapter_ref: weakref.ref = weakref.ref(adapter)
        self.orig_layer_ref: weakref.ref = weakref.ref(orig_layer)

    @classmethod
    def from_model(
        cls,
        model: WanTransformer3DModel,
        adapter: 'I2VAdapter',
    ):
        if model.__class__.__name__ == 'WanTransformer3DModel':
            new_channels = 20  # wan is 16 normally, and 36 with i2v, so 20 new channels
            orig_patch_embedding: torch.nn.Conv3d = model.patch_embedding
            img_embedder = cls(
                adapter,
                orig_layer=orig_patch_embedding,
                in_channels=new_channels,
            )
            # hijack the forward method
            orig_patch_embedding._orig_i2v_adapter_forward = orig_patch_embedding.forward
            orig_patch_embedding.forward = img_embedder.forward
            # update the config of the transformer, only needed when merged in
            # model.config.in_channels = model.config.in_channels + new_channels
            # model.config["in_channels"] = model.config.in_channels + new_channels
            return img_embedder
        else:
            raise ValueError("Model not supported")

    @property
    def is_active(self):
        return self.adapter_ref().is_active

    def forward(self, x):
        if not self.is_active:
            # make sure the lora is not active
            if self.adapter_ref().control_lora is not None:
                self.adapter_ref().control_lora.is_active = False
            if x.shape[1] > self.orig_layer_ref().in_channels:
                # we have i2v input, so we need to remove the extra channels
                x = x[:, :self.orig_layer_ref().in_channels, :, :, :]
            return self.orig_layer_ref()._orig_i2v_adapter_forward(x)

        # make sure the lora is active
        if self.adapter_ref().control_lora is not None:
            self.adapter_ref().control_lora.is_active = True

        # x channels are arranged as cat(orig_input=16, temporal_conditioning_mask=4, encoded_first_frame=16)
        # (16 + 4 + 16) = 36 channels
        # shape: (batch_size, 36, num_frames, latent_height, latent_width)
        orig_device = x.device
        orig_dtype = x.dtype
        orig_in = x[:, :16, :, :, :]
        orig_out = self.orig_layer_ref()._orig_i2v_adapter_forward(orig_in)
        # keep only the new conditioning channels
        x = x[:, 16:, :, :, :]
        x = x.to(self.patch_embedding.weight.device, dtype=self.patch_embedding.weight.dtype)
        x = self.patch_embedding(x)
        x = x.to(orig_device, dtype=orig_dtype)
        # add the original patch embedding output
        x = x + orig_out
        return x
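

# Helper used to hijack each cross-attention block's forward. When the adapter is active
# it exposes the extra image K/V projections (add_k_proj / add_v_proj and their norms) to
# the attention processor; when inactive it hides them again so the block behaves like
# the original text-to-video attention.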
def deactivatable_forward(
    self: 'Attention',
    *args,
    **kwargs
):
    if self._attn_hog_ref() is not None and self._attn_hog_ref().is_active:
        self.added_kv_proj_dim = None
        self.add_k_proj = self._add_k_proj
        self.add_v_proj = self._add_v_proj
        self.norm_added_q = self._norm_added_q
        self.norm_added_k = self._norm_added_k
    else:
        self.added_kv_proj_dim = self._attn_hog_ref().added_kv_proj_dim
        self.add_k_proj = None
        self.add_v_proj = None
        self.norm_added_q = None
        self.norm_added_k = None
    return self._orig_forward(*args, **kwargs)
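

# AttentionHog owns the new trainable image K/V projections for a single cross-attention
# block. The projections are initialized with very small weights so that, at the start of
# training, the added image path contributes almost nothing and the model stays close to
# the base text-to-video behavior.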
class AttentionHog(torch.nn.Module):
    def __init__(
        self,
        added_kv_proj_dim: int,
        adapter: 'I2VAdapter',
        attn_layer: Attention,
        model: 'WanTransformer3DModel',
    ):
        super().__init__()
        # To prevent circular import.
        from diffusers.models.normalization import FP32LayerNorm, LpNorm, RMSNorm
        self.added_kv_proj_dim = added_kv_proj_dim
        self.attn_layer_ref: weakref.ref = weakref.ref(attn_layer)
        self.adapter_ref: weakref.ref = weakref.ref(adapter)
        self.model_ref: weakref.ref = weakref.ref(model)
        qk_norm = model.config.qk_norm

        # layers
        self.add_k_proj = torch.nn.Linear(
            added_kv_proj_dim,
            attn_layer.inner_kv_dim,
            bias=attn_layer.added_proj_bias
        )
        self.add_k_proj.weight.data = self.add_k_proj.weight.data * 0.001
        self.add_v_proj = torch.nn.Linear(
            added_kv_proj_dim,
            attn_layer.inner_kv_dim,
            bias=attn_layer.added_proj_bias
        )
        self.add_v_proj.weight.data = self.add_v_proj.weight.data * 0.001

        # do the qk norm. It isn't stored in the class, but we can infer it from the attn layer
        self.norm_added_q = None
        self.norm_added_k = None
        if attn_layer.norm_q is not None:
            eps: float = 1e-5
            if qk_norm == "layer_norm":
                self.norm_added_q = torch.nn.LayerNorm(
                    attn_layer.norm_q.normalized_shape, eps=eps, elementwise_affine=attn_layer.norm_q.elementwise_affine)
                self.norm_added_k = torch.nn.LayerNorm(
                    attn_layer.norm_k.normalized_shape, eps=eps, elementwise_affine=attn_layer.norm_k.elementwise_affine)
            elif qk_norm == "fp32_layer_norm":
                self.norm_added_q = FP32LayerNorm(
                    attn_layer.norm_q.normalized_shape, elementwise_affine=False, bias=False, eps=eps)
                self.norm_added_k = FP32LayerNorm(
                    attn_layer.norm_k.normalized_shape, elementwise_affine=False, bias=False, eps=eps)
            elif qk_norm == "rms_norm":
                self.norm_added_q = RMSNorm(attn_layer.norm_q.dim, eps=eps)
                self.norm_added_k = RMSNorm(attn_layer.norm_k.dim, eps=eps)
            elif qk_norm == "rms_norm_across_heads":
                # Wanx applies qk norm across all heads
                self.norm_added_q = RMSNorm(attn_layer.norm_q.dim, eps=eps)
                self.norm_added_k = RMSNorm(attn_layer.norm_k.dim, eps=eps)
            else:
                raise ValueError(
                    f"unknown qk_norm: {qk_norm}. Should be one of None, 'layer_norm', 'fp32_layer_norm', 'rms_norm', 'rms_norm_across_heads'"
                )

        # add these to the attn layer in a way they can be deactivated
        attn_layer._add_k_proj = self.add_k_proj
        attn_layer._add_v_proj = self.add_v_proj
        attn_layer._norm_added_q = self.norm_added_q
        attn_layer._norm_added_k = self.norm_added_k
        # make it deactivatable
        attn_layer._attn_hog_ref = weakref.ref(self)
        attn_layer._orig_forward = attn_layer.forward
        attn_layer.forward = partial(deactivatable_forward, attn_layer)

    def forward(self, *args, **kwargs):
        if not self.adapter_ref().is_active:
            # fall back to the wrapped attention layer when the adapter is off
            return self.attn_layer_ref()(*args, **kwargs)
        # TODO implement this
        raise NotImplementedError("Attention hog not implemented")

    @property
    def is_active(self):
        return self.adapter_ref().is_active
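

# Replacement forward for WanTransformer3DModel. When the adapter is active it injects
# the adapter's image embedder into the condition embedder, supplies the cached image
# embeds (alternating conditional / shuffled-unconditional while sampling with CFG), and,
# if a frame embedder is configured, concatenates the encoded first frame onto the latent
# input before calling the original forward.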
def new_wan_forward(
    self: WanTransformer3DModel,
    hidden_states: torch.Tensor,
    timestep: torch.LongTensor,
    encoder_hidden_states: torch.Tensor,
    encoder_hidden_states_image: Optional[torch.Tensor] = None,
    return_dict: bool = True,
    attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
    # prevent circular import
    from toolkit.models.wan21.wan_utils import add_first_frame_conditioning
    adapter: 'I2VAdapter' = self._i2v_adapter_ref()
    if adapter.is_active:
        # activate the condition embedder
        self.condition_embedder.image_embedder = adapter.image_embedder

        # For wan, the image encoder embeds are also put on the unconditional, which
        # won't work here. For now, we use the embeds we have in order: we cache a
        # conditional and an unconditional embed. On sampling, the conditional is run
        # first, then the unconditional, so we just need to keep track of which one we
        # are on. This is a horrible hack.
        # TODO find a not stupid way to do this.
        if adapter.adapter_ref().is_sampling:
            if not hasattr(self, '_do_unconditional'):
                # set it to true so we alternate to false immediately
                self._do_unconditional = True
            # alternate it
            self._do_unconditional = not self._do_unconditional

            if self._do_unconditional:
                # slightly reduce strength of conditional for the unconditional
                # encoder_hidden_states_image = adapter.adapter_ref().conditional_embeds * 0.5
                # shuffle the embedding tokens so we still have all the information, but it is scrambled.
                # this prevents things like color from being overweighted by cfg, while still sharpening content.
                encoder_hidden_states_image = shuffle_tensor_along_axis(
                    adapter.adapter_ref().conditional_embeds,
                    axis=1
                )
                # encoder_hidden_states_image = adapter.adapter_ref().unconditional_embeds
            else:
                # use the conditional
                encoder_hidden_states_image = adapter.adapter_ref().conditional_embeds
        else:
            # doing a normal training run, always use the conditional embeds
            encoder_hidden_states_image = adapter.adapter_ref().conditional_embeds

        # add the first frame conditioning
        if adapter.frame_embedder is not None:
            with torch.no_grad():
                conditioning_frame = adapter.adapter_ref().cached_control_image_0_1
                if conditioning_frame is None:
                    raise ValueError("No conditioning frame found")
                # make it -1 to 1
                conditioning_frame = (conditioning_frame * 2) - 1
                conditioning_frame = conditioning_frame.to(
                    hidden_states.device, dtype=hidden_states.dtype
                )
                # if doing a full denoise, the latent input may already carry the full channels; only keep the first 16
                if hidden_states.shape[1] > 16:
                    hidden_states = hidden_states[:, :16, :, :, :]
                hidden_states = add_first_frame_conditioning(
                    latent_model_input=hidden_states,
                    first_frame=conditioning_frame,
                    vae=adapter.adapter_ref().sd_ref().vae,
                )
    else:
        # not active, deactivate the condition embedder
        self.condition_embedder.image_embedder = None
    return self._orig_i2v_adapter_forward(
        hidden_states=hidden_states,
        timestep=timestep,
        encoder_hidden_states=encoder_hidden_states,
        encoder_hidden_states_image=encoder_hidden_states_image,
        return_dict=return_dict,
        attention_kwargs=attention_kwargs,
    )
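

# I2VAdapter ties the pieces together: an optional control LoRA over the base network,
# a FrameEmbedder for the extra first-frame latent channels, one AttentionHog per
# transformer block for the image K/V path, and a WanImageEmbedding that projects the
# vision encoder's hidden states into the transformer's inner dimension.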
class I2VAdapter(torch.nn.Module):
    def __init__(
        self,
        adapter: 'CustomAdapter',
        sd: 'BaseModel',
        config: 'AdapterConfig',
        train_config: 'TrainConfig',
        image_processor: Union[SiglipImageProcessor, CLIPImageProcessor],
        vision_encoder: Union[SiglipVisionModel, CLIPVisionModelWithProjection],
    ):
        super().__init__()
        # avoid circular import
        from toolkit.models.wan21.wan_attn import WanAttnProcessor2_0
        self.adapter_ref: weakref.ref = weakref.ref(adapter)
        self.sd_ref = weakref.ref(sd)
        self.model_config: ModelConfig = sd.model_config
        self.network_config = config.lora_config
        self.train_config = train_config
        self.config = config
        self.device_torch = sd.device_torch
        self.control_lora = None
        self.image_processor_ref: weakref.ref = weakref.ref(image_processor)
        self.vision_encoder_ref: weakref.ref = weakref.ref(vision_encoder)

        ve_img_size = vision_encoder.config.image_size
        ve_patch_size = vision_encoder.config.patch_size
        num_patches = (ve_img_size // ve_patch_size) ** 2
        num_vision_tokens = num_patches
        # siglip does not have a class token; CLIP does, so add one for it
        if not vision_encoder.__class__.__name__.lower().startswith("siglip"):
            num_vision_tokens = num_patches + 1

        model_class = sd.model.__class__.__name__

        if self.network_config is not None:
            network_kwargs = {} if self.network_config.network_kwargs is None else self.network_config.network_kwargs
            if hasattr(sd, 'target_lora_modules'):
                network_kwargs['target_lin_modules'] = sd.target_lora_modules

            if 'ignore_if_contains' not in network_kwargs:
                network_kwargs['ignore_if_contains'] = []
            network_kwargs['ignore_if_contains'] += [
                'add_k_proj',
                'add_v_proj',
                'norm_added_q',
                'norm_added_k',
            ]

            if model_class == 'WanTransformer3DModel':
                # always ignore the patch_embedding
                network_kwargs['ignore_if_contains'].append('patch_embedding')

            self.control_lora = LoRASpecialNetwork(
                text_encoder=sd.text_encoder,
                unet=sd.unet,
                lora_dim=self.network_config.linear,
                multiplier=1.0,
                alpha=self.network_config.linear_alpha,
                train_unet=self.train_config.train_unet,
                train_text_encoder=self.train_config.train_text_encoder,
                conv_lora_dim=self.network_config.conv,
                conv_alpha=self.network_config.conv_alpha,
                is_sdxl=self.model_config.is_xl or self.model_config.is_ssd,
                is_v2=self.model_config.is_v2,
                is_v3=self.model_config.is_v3,
                is_pixart=self.model_config.is_pixart,
                is_auraflow=self.model_config.is_auraflow,
                is_flux=self.model_config.is_flux,
                is_lumina2=self.model_config.is_lumina2,
                is_ssd=self.model_config.is_ssd,
                is_vega=self.model_config.is_vega,
                dropout=self.network_config.dropout,
                use_text_encoder_1=self.model_config.use_text_encoder_1,
                use_text_encoder_2=self.model_config.use_text_encoder_2,
                use_bias=False,
                is_lorm=False,
                network_config=self.network_config,
                network_type=self.network_config.type,
                transformer_only=self.network_config.transformer_only,
                is_transformer=sd.is_transformer,
                base_model=sd,
                **network_kwargs
            )
            self.control_lora.force_to(self.device_torch, dtype=torch.float32)
            self.control_lora._update_torch_multiplier()
            self.control_lora.apply_to(
                sd.text_encoder,
                sd.unet,
                self.train_config.train_text_encoder,
                self.train_config.train_unet
            )
            self.control_lora.can_merge_in = False
            self.control_lora.prepare_grad_etc(sd.text_encoder, sd.unet)
            if self.train_config.gradient_checkpointing:
                self.control_lora.enable_gradient_checkpointing()

        self.frame_embedder: Optional[FrameEmbedder] = None
        if self.config.i2v_do_start_frame:
            self.frame_embedder = FrameEmbedder.from_model(
                sd.unet,
                self
            )
            self.frame_embedder.to(self.device_torch)

        # hijack the blocks so we can inject our vision encoder
        attn_hog_list = []
        if model_class == 'WanTransformer3DModel':
            added_kv_proj_dim = sd.model.config.num_attention_heads * sd.model.config.attention_head_dim
            # update the model so it can accept the new input.
            # wan's official i2v variant uses clip-h with additional k/v attention that
            # directly takes in the penultimate_hidden_states from the vision encoder;
            # the extra k/v projections live on each block's attn2
            sd.model.config.added_kv_proj_dim = added_kv_proj_dim
            sd.model.config['added_kv_proj_dim'] = added_kv_proj_dim
            transformer: WanTransformer3DModel = sd.model
            for block in transformer.blocks:
                block.attn2.added_kv_proj_dim = added_kv_proj_dim
                attn_module = AttentionHog(
                    added_kv_proj_dim,
                    self,
                    block.attn2,
                    transformer
                )
                # set the attn processor to ours, which handles a custom number of vision tokens
                block.attn2.set_processor(WanAttnProcessor2_0(num_vision_tokens))
                attn_hog_list.append(attn_module)
        else:
            raise ValueError(f"Model {model_class} not supported")

        self.attn_hog_list = torch.nn.ModuleList(attn_hog_list)
        self.attn_hog_list.to(self.device_torch)

        inner_dim = sd.model.config.num_attention_heads * sd.model.config.attention_head_dim
        image_embed_dim = vision_encoder.config.hidden_size
        self.image_embedder = WanImageEmbedding(image_embed_dim, inner_dim)

        # override the forward method
        if model_class == 'WanTransformer3DModel':
            self.sd_ref().model._orig_i2v_adapter_forward = self.sd_ref().model.forward
            self.sd_ref().model.forward = partial(
                new_wan_forward,
                self.sd_ref().model
            )
        # add the wan image embedder
        self.sd_ref().model.condition_embedder._image_embedder = self.image_embedder
        self.sd_ref().model.condition_embedder._image_embedder.to(self.device_torch)
        self.sd_ref().model._i2v_adapter_ref = weakref.ref(self)

    def get_params(self):
        if self.control_lora is not None:
            config = {
                'text_encoder_lr': self.train_config.lr,
                'unet_lr': self.train_config.lr,
            }
            sig = inspect.signature(self.control_lora.prepare_optimizer_params)
            if 'default_lr' in sig.parameters:
                config['default_lr'] = self.train_config.lr
            if 'learning_rate' in sig.parameters:
                config['learning_rate'] = self.train_config.lr
            params_net = self.control_lora.prepare_optimizer_params(
                **config
            )
            # we want only tensors here
            params = []
            for p in params_net:
                if isinstance(p, dict):
                    params += p["params"]
                elif isinstance(p, torch.Tensor):
                    params.append(p)
                elif isinstance(p, list):
                    params += p
        else:
            params = []

        if self.frame_embedder is not None:
            # make sure the embedder is float32
            self.frame_embedder.to(torch.float32)
            params += list(self.frame_embedder.parameters())

        # add the attn hogs
        for attn_hog in self.attn_hog_list:
            params += list(attn_hog.parameters())

        # add the image embedder
        if self.image_embedder is not None:
            params += list(self.image_embedder.parameters())

        return params

    def load_weights(self, state_dict, strict=True):
        lora_sd = {}
        attn_hog_sd = {}
        frame_embedder_sd = {}
        image_embedder_sd = {}
        for key, value in state_dict.items():
            if "frame_embedder" in key:
                new_key = key.replace("frame_embedder.", "")
                frame_embedder_sd[new_key] = value
            elif "attn_hog" in key:
                new_key = key.replace("attn_hog.", "")
                attn_hog_sd[new_key] = value
            elif "image_embedder" in key:
                new_key = key.replace("image_embedder.", "")
                image_embedder_sd[new_key] = value
            else:
                lora_sd[key] = value

        # todo process the state dict before loading
        if self.control_lora is not None:
            self.control_lora.load_weights(lora_sd)
        if self.frame_embedder is not None:
            self.frame_embedder.load_state_dict(
                frame_embedder_sd, strict=False)
        self.attn_hog_list.load_state_dict(
            attn_hog_sd, strict=False)
        self.image_embedder.load_state_dict(
            image_embedder_sd, strict=False)

    def get_state_dict(self):
        if self.control_lora is not None:
            lora_sd = self.control_lora.get_state_dict(dtype=torch.float32)
        else:
            lora_sd = {}
        if self.frame_embedder is not None:
            frame_embedder_sd = self.frame_embedder.state_dict()
            for key, value in frame_embedder_sd.items():
                lora_sd[f"frame_embedder.{key}"] = value
        # add the attn hogs
        attn_hog_sd = self.attn_hog_list.state_dict()
        for key, value in attn_hog_sd.items():
            lora_sd[f"attn_hog.{key}"] = value
        # add the image embedder
        image_embedder_sd = self.image_embedder.state_dict()
        for key, value in image_embedder_sd.items():
            lora_sd[f"image_embedder.{key}"] = value
        return lora_sd
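
    # The saved state dict therefore mixes plain LoRA keys with prefixed adapter keys,
    # e.g. "frame_embedder.patch_embedding.weight" or "attn_hog.0.add_k_proj.weight";
    # load_weights() above splits them back out by prefix.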

    def condition_noisy_latents(self, latents: torch.Tensor, batch: DataLoaderBatchDTO):
        # todo handle the start frame
        return latents

    def edit_batch_processed(self, batch: DataLoaderBatchDTO):
        with torch.no_grad():
            # we always want a clip image frame: if one is not passed, use the image,
            # or if it is a video, pull the first frame.
            # edit the batch to pull the first frame out of a video if we have one.
            # videos come in as (bs, num_frames, channels, height, width)
            tensor = batch.tensor
            if batch.clip_image_tensor is None:
                if len(tensor.shape) == 5:
                    # we have a video
                    first_frames = tensor[:, 0, :, :, :].clone()
                else:
                    # we have a single image
                    first_frames = tensor.clone()
                # it is -1 to 1; change it to 0 to 1
                first_frames = (first_frames + 1) / 2
                # clip image tensors are preprocessed.
                tensors_0_1 = first_frames.to(dtype=torch.float16)
                clip_out = self.adapter_ref().clip_image_processor(
                    images=tensors_0_1,
                    return_tensors="pt",
                    do_resize=True,
                    do_rescale=False,
                ).pixel_values
                batch.clip_image_tensor = clip_out.to(self.device_torch)
        return batch

    @property
    def is_active(self):
        return self.adapter_ref().is_active
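

# Typical wiring (sketch only; `custom_adapter`, `sd`, and the config objects come from
# the surrounding toolkit and are assumed here, not defined in this module):
#
#   i2v = I2VAdapter(custom_adapter, sd, adapter_config, train_config,
#                    image_processor, vision_encoder)
#   optimizer = torch.optim.AdamW(i2v.get_params(), lr=train_config.lr)
#   ...train...
#   state = i2v.get_state_dict()  # LoRA keys plus frame_embedder./attn_hog./image_embedder. prefixes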