from functools import partial
import inspect
import weakref
import torch
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.lora_special import LoRASpecialNetwork
from diffusers import WanTransformer3DModel
from transformers import SiglipImageProcessor, SiglipVisionModel, CLIPImageProcessor, CLIPVisionModelWithProjection
from diffusers.models.attention_processor import Attention
from diffusers.models.transformers.transformer_wan import WanImageEmbedding, WanTimeTextImageEmbedding
from toolkit.util.shuffle import shuffle_tensor_along_axis
import torch.nn.functional as F

if TYPE_CHECKING:
    from toolkit.models.base_model import BaseModel
    from toolkit.config_modules import AdapterConfig, TrainConfig, ModelConfig
    from toolkit.custom_adapter import CustomAdapter


class FrameEmbedder(torch.nn.Module):
    """Extra patch embedding for the i2v conditioning channels.

    The instance owns a second ``Conv3d`` patch embedding for the channels
    that follow the first 16 input channels. ``from_model`` replaces the
    ``forward`` of the transformer's original patch embedding with
    ``FrameEmbedder.forward``, which adds the output of both convolutions.
    The adapter and the original layer are held through weak references.
    """

    def __init__(
        self,
        adapter: 'I2VAdapter',
        orig_layer: torch.nn.Conv3d,
        in_channels=20,  # wan is 16 normally, and 36 with i2v so 20 new channels
    ):
        """Build the extra patch embedding.

        Args:
            adapter: Owning adapter; only a weak reference is kept.
            orig_layer: The transformer's original patch embedding; its
                ``out_channels`` sets the output width of the new conv.
            in_channels: Number of additional input channels handled here.
        """
        super().__init__()
        # goes through a conv patch embedding first and is then flattened
        # hidden_states = self.patch_embedding(hidden_states)
        # hidden_states = hidden_states.flatten(2).transpose(1, 2)

        inner_dim = orig_layer.out_channels
        patch_size = adapter.sd_ref().model.config.patch_size

        self.patch_embedding = torch.nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)

        self.adapter_ref: weakref.ref = weakref.ref(adapter)
        self.orig_layer_ref: weakref.ref = weakref.ref(orig_layer)

    @classmethod
    def from_model(
        cls,
        model: WanTransformer3DModel,
        adapter: 'I2VAdapter',
    ):
        """Create a FrameEmbedder and hook it into ``model.patch_embedding``.

        The original ``forward`` is preserved on the layer as
        ``_orig_i2v_adapter_forward``.

        Raises:
            ValueError: If ``model`` is not a ``WanTransformer3DModel``.
        """
        if model.__class__.__name__ == 'WanTransformer3DModel':
            new_channels = 20  # wan is 16 normally, and 36 with i2v so 20 new channels

            orig_patch_embedding: torch.nn.Conv3d = model.patch_embedding

            img_embedder = cls(
                adapter,
                orig_layer=orig_patch_embedding,
                in_channels=new_channels,
            )

            # hijack the forward method
            orig_patch_embedding._orig_i2v_adapter_forward = orig_patch_embedding.forward
            orig_patch_embedding.forward = img_embedder.forward

            # update the config of the transformer, only needed when merged in
            # model.config.in_channels = model.config.in_channels + new_channels
            # model.config["in_channels"] = model.config.in_channels + new_channels

            return img_embedder
        else:
            raise ValueError("Model not supported")

    @property
    def is_active(self):
        """Whether the owning adapter reports itself as active."""
        return self.adapter_ref().is_active

    def forward(self, x):
        """Patch-embed ``x``, adding the conditioning channels when active.

        When the adapter is inactive, the control LoRA (if any) is switched
        off, any channels beyond the original layer's ``in_channels`` are
        dropped, and the original patch embedding is used alone. When it is
        active, the control LoRA is switched on, the first 16 channels go
        through the original embedding, the rest go through
        ``self.patch_embedding``, and the two outputs are summed.
        """
        adapter = self.adapter_ref()
        orig_layer = self.orig_layer_ref()

        if not self.is_active:
            # make sure lora is not active
            if adapter.control_lora is not None:
                adapter.control_lora.is_active = False

            if x.shape[1] > orig_layer.in_channels:
                # we have i2v, so we need to remove the extra channels
                x = x[:, :orig_layer.in_channels, :, :, :]
            return orig_layer._orig_i2v_adapter_forward(x)

        # make sure lora is active
        if adapter.control_lora is not None:
            adapter.control_lora.is_active = True

        # x is arranged channels cat(orig_input = 16, temporal_conditioning_mask = 4, encoded_first_frame=16)
        # (16 + 4 + 16) = 36 channels
        # (batch_size, 36, num_frames, latent_height, latent_width)

        orig_device = x.device
        orig_dtype = x.dtype

        orig_in = x[:, :16, :, :, :]
        orig_out = orig_layer._orig_i2v_adapter_forward(orig_in)

        # remove original stuff
        x = x[:, 16:, :, :, :]

        # run the extra conv in its own device/dtype, then cast back
        x = x.to(self.patch_embedding.weight.device, dtype=self.patch_embedding.weight.dtype)
        x = self.patch_embedding(x)
        x = x.to(orig_device, dtype=orig_dtype)

        # add the original out
        x = x + orig_out
        return x


def deactivatable_forward(
    self: 'Attention',
    *args,
    **kwargs
):
    """Forward wrapper that attaches or detaches the added k/v projections.

    Bound onto an ``Attention`` layer by ``AttentionHog.__init__``. Before
    delegating to the layer's saved ``_orig_forward``, it either exposes the
    hog's projection/norm modules on the layer (adapter active) or sets them
    to ``None`` (adapter inactive, or the hog has been garbage collected).
    """
    attn_hog = self._attn_hog_ref()
    # AttentionHog.is_active is a plain method, so it must be *called*;
    # reading the attribute yields an always-truthy bound method and the
    # layer could never be deactivated.
    if attn_hog is not None and attn_hog.is_active():
        # NOTE(review): clearing added_kv_proj_dim while active (and restoring
        # it while inactive) looks inverted, but is kept as-is - confirm intent.
        self.added_kv_proj_dim = None
        self.add_k_proj = self._add_k_proj
        self.add_v_proj = self._add_v_proj
        self.norm_added_q = self._norm_added_q
        self.norm_added_k = self._norm_added_k
    else:
        # the hog may be gone; only read from it when it is still alive
        if attn_hog is not None:
            self.added_kv_proj_dim = attn_hog.added_kv_proj_dim
        self.add_k_proj = None
        self.add_v_proj = None
        self.norm_added_q = None
        self.norm_added_k = None
    return self._orig_forward(*args, **kwargs)


class AttentionHog(torch.nn.Module):
    """Holds the added k/v projections (and q/k norms) for one attention layer.

    On construction the modules are also stored on the attention layer under
    underscore-prefixed names, and the layer's ``forward`` is replaced by
    ``deactivatable_forward`` so they can be switched on and off with the
    adapter.
    """

    def __init__(
        self,
        added_kv_proj_dim: int,
        adapter: 'I2VAdapter',
        attn_layer: Attention,
        model: 'WanTransformer3DModel',
    ):
        """Create the added projections and hook ``attn_layer``.

        Args:
            added_kv_proj_dim: Input width of the added k/v projections.
            adapter: Owning adapter; only a weak reference is kept.
            attn_layer: Attention layer to extend; only a weak reference is kept.
            model: Transformer whose ``config.qk_norm`` selects the norm type.

        Raises:
            ValueError: If ``attn_layer.norm_q`` is set and ``qk_norm`` is not
                one of the handled values.
        """
        super().__init__()

        # To prevent circular import.
        from diffusers.models.normalization import FP32LayerNorm, RMSNorm

        self.added_kv_proj_dim = added_kv_proj_dim
        self.attn_layer_ref: weakref.ref = weakref.ref(attn_layer)
        self.adapter_ref: weakref.ref = weakref.ref(adapter)
        self.model_ref: weakref.ref = weakref.ref(model)

        qk_norm = model.config.qk_norm

        # layers
        self.add_k_proj = torch.nn.Linear(
            added_kv_proj_dim,
            attn_layer.inner_kv_dim,
            bias=attn_layer.added_proj_bias
        )
        # scale the initial weights down by 1000x
        self.add_k_proj.weight.data = self.add_k_proj.weight.data * 0.001
        self.add_v_proj = torch.nn.Linear(
            added_kv_proj_dim,
            attn_layer.inner_kv_dim,
            bias=attn_layer.added_proj_bias
        )
        self.add_v_proj.weight.data = self.add_v_proj.weight.data * 0.001

        # do qk norm. It isnt stored in the class, but we can infer it from the attn layer
        self.norm_added_q = None
        self.norm_added_k = None

        if attn_layer.norm_q is not None:
            eps: float = 1e-5
            if qk_norm == "layer_norm":
                self.norm_added_q = torch.nn.LayerNorm(
                    attn_layer.norm_q.normalized_shape, eps=eps,
                    elementwise_affine=attn_layer.norm_q.elementwise_affine)
                self.norm_added_k = torch.nn.LayerNorm(
                    attn_layer.norm_k.normalized_shape, eps=eps,
                    elementwise_affine=attn_layer.norm_k.elementwise_affine)
            elif qk_norm == "fp32_layer_norm":
                self.norm_added_q = FP32LayerNorm(
                    attn_layer.norm_q.normalized_shape, elementwise_affine=False, bias=False, eps=eps)
                self.norm_added_k = FP32LayerNorm(
                    attn_layer.norm_k.normalized_shape, elementwise_affine=False, bias=False, eps=eps)
            elif qk_norm == "rms_norm":
                self.norm_added_q = RMSNorm(attn_layer.norm_q.dim, eps=eps)
                self.norm_added_k = RMSNorm(attn_layer.norm_k.dim, eps=eps)
            elif qk_norm == "rms_norm_across_heads":
                # Wanx applies qk norm across all heads
                self.norm_added_q = RMSNorm(attn_layer.norm_q.dim, eps=eps)
                self.norm_added_k = RMSNorm(attn_layer.norm_k.dim, eps=eps)
            else:
                raise ValueError(
                    f"unknown qk_norm: {qk_norm}. Should be one of "
                    f"`None,'layer_norm','fp32_layer_norm','rms_norm','rms_norm_across_heads'`"
                )

        # add these to the attn later in a way they can be deactivated
        attn_layer._add_k_proj = self.add_k_proj
        attn_layer._add_v_proj = self.add_v_proj
        attn_layer._norm_added_q = self.norm_added_q
        attn_layer._norm_added_k = self.norm_added_k

        # make it deactivateable
        attn_layer._attn_hog_ref = weakref.ref(self)
        attn_layer._orig_forward = attn_layer.forward
        attn_layer.forward = partial(deactivatable_forward, attn_layer)

    def forward(self, *args, **kwargs):
        """Delegate to the wrapped attention layer when the adapter is inactive.

        Raises:
            NotImplementedError: When the adapter is active (not implemented).
        """
        if not self.adapter_ref().is_active:
            # the wrapped layer is only stored as a weak reference
            return self.attn_layer_ref()(*args, **kwargs)

        # TODO implement this
        raise NotImplementedError("Attention hog not implemented")

    def is_active(self):
        """Return whether the owning adapter is active.

        Unlike ``FrameEmbedder`` and ``I2VAdapter`` this is a method, not a
        property, so callers must call it.
        """
        return self.adapter_ref().is_active


def new_wan_forward(
    self: WanTransformer3DModel,
    hidden_states: torch.Tensor,
    timestep: torch.LongTensor,
    encoder_hidden_states: torch.Tensor,
    encoder_hidden_states_image: Optional[torch.Tensor] = None,
    return_dict: bool = True,
    attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
    """Replacement ``forward`` for the Wan transformer that injects i2v conditioning.

    When the adapter is active, the adapter's image embedder is attached to
    ``condition_embedder``, ``encoder_hidden_states_image`` is overwritten
    with the cached conditional embeds (token-shuffled on every other call
    while sampling), and, if a frame embedder exists, the first-frame
    conditioning is added to ``hidden_states``. When inactive, the image
    embedder is detached. The call is then forwarded to the saved original
    forward.

    Raises:
        ValueError: If a frame embedder is present but no conditioning frame
            has been cached on the parent adapter.
    """
    # prevent circular import
    from toolkit.models.wan21.wan_utils import add_first_frame_conditioning
    adapter: 'I2VAdapter' = self._i2v_adapter_ref()

    if adapter.is_active:
        # activate the condition embedder
        self.condition_embedder.image_embedder = adapter.image_embedder

        # for wan they are putting the image encoder embeds on the unconditional
        # this needs to be fixed as that wont work. For now, we will use the embeds we have in order
        # we cache a conditional and an unconditional embed. On sampling, it samples conditional first,
        # then unconditional. So we just need to keep track of which one we are using. This is a horrible hack
        # TODO find a not stupid way to do this.

        if adapter.adapter_ref().is_sampling:
            if not hasattr(self, '_do_unconditional'):
                # set it to true so we alternate to false immediately
                self._do_unconditional = True

            # alternate it
            self._do_unconditional = not self._do_unconditional
            if self._do_unconditional:
                # slightly reduce strength of conditional for the unconditional
                # encoder_hidden_states_image = adapter.adapter_ref().conditional_embeds * 0.5
                # shuffle the embedding tokens so we still have all the information, but it is scrambled
                # this will prevent things like color from being cfg overweights, but still sharpen content.

                encoder_hidden_states_image = shuffle_tensor_along_axis(
                    adapter.adapter_ref().conditional_embeds,
                    axis=1
                )
                # encoder_hidden_states_image = adapter.adapter_ref().unconditional_embeds
            else:
                # use the conditional
                encoder_hidden_states_image = adapter.adapter_ref().conditional_embeds
        else:
            # doing a normal training run, always use conditional embeds
            encoder_hidden_states_image = adapter.adapter_ref().conditional_embeds

        # add the first frame conditioning
        if adapter.frame_embedder is not None:
            with torch.no_grad():
                # add the first frame conditioning
                conditioning_frame = adapter.adapter_ref().cached_control_image_0_1
                if conditioning_frame is None:
                    raise ValueError("No conditioning frame found")

                # make it -1 to 1
                conditioning_frame = (conditioning_frame * 2) - 1
                conditioning_frame = conditioning_frame.to(
                    hidden_states.device, dtype=hidden_states.dtype
                )

                # if doing a full denoise, the latent input may be full channels here, only get first 16
                if hidden_states.shape[1] > 16:
                    hidden_states = hidden_states[:, :16, :, :, :]

                hidden_states = add_first_frame_conditioning(
                    latent_model_input=hidden_states,
                    first_frame=conditioning_frame,
                    vae=adapter.adapter_ref().sd_ref().vae,
                )
    else:
        # not active deactivate the condition embedder
        self.condition_embedder.image_embedder = None

    return self._orig_i2v_adapter_forward(
        hidden_states=hidden_states,
        timestep=timestep,
        encoder_hidden_states=encoder_hidden_states,
        encoder_hidden_states_image=encoder_hidden_states_image,
        return_dict=return_dict,
        attention_kwargs=attention_kwargs,
    )


class I2VAdapter(torch.nn.Module):
    """Image-to-video adapter for a ``WanTransformer3DModel``.

    Combines up to four trainable parts: an optional control LoRA, an
    optional ``FrameEmbedder`` for first-frame conditioning, one
    ``AttentionHog`` per transformer block (on ``block.attn2``), and a
    ``WanImageEmbedding`` for the vision-encoder tokens. Construction patches
    the base model in place (forward methods, attention processors, config).
    """

    def __init__(
        self,
        adapter: 'CustomAdapter',
        sd: 'BaseModel',
        config: 'AdapterConfig',
        train_config: 'TrainConfig',
        image_processor: Union[SiglipImageProcessor, CLIPImageProcessor],
        vision_encoder: Union[SiglipVisionModel, CLIPVisionModelWithProjection],
    ):
        """Build the adapter parts and hook them into ``sd.model``.

        Args:
            adapter: Parent custom adapter; held by weak reference.
            sd: Base model wrapper; held by weak reference.
            config: Adapter config; ``config.lora_config`` enables the control
                LoRA and ``config.i2v_do_start_frame`` the frame embedder.
            train_config: Training config (learning rate, train flags,
                gradient checkpointing).
            image_processor: Vision image processor; held by weak reference.
            vision_encoder: Vision encoder; its config sets the number of
                vision tokens and the image embed width.

        Raises:
            ValueError: If ``sd.model`` is not a ``WanTransformer3DModel``.
        """
        super().__init__()
        # avoid circular import
        from toolkit.models.wan21.wan_attn import WanAttnProcessor2_0
        self.adapter_ref: weakref.ref = weakref.ref(adapter)
        self.sd_ref = weakref.ref(sd)
        self.model_config: ModelConfig = sd.model_config
        self.network_config = config.lora_config
        self.train_config = train_config
        self.config = config
        self.device_torch = sd.device_torch
        self.control_lora = None
        self.image_processor_ref: weakref.ref = weakref.ref(image_processor)
        self.vision_encoder_ref: weakref.ref = weakref.ref(vision_encoder)

        ve_img_size = vision_encoder.config.image_size
        ve_patch_size = vision_encoder.config.patch_size
        num_patches = (ve_img_size // ve_patch_size) ** 2
        num_vision_tokens = num_patches

        # siglip does not have a class token
        if not vision_encoder.__class__.__name__.lower().startswith("siglip"):
            num_vision_tokens = num_patches + 1

        model_class = sd.model.__class__.__name__

        if self.network_config is not None:
            # NOTE: when network_kwargs is provided this is the config's own
            # dict, so the updates below are written back into the config.
            network_kwargs = {} if self.network_config.network_kwargs is None else self.network_config.network_kwargs
            if hasattr(sd, 'target_lora_modules'):
                network_kwargs['target_lin_modules'] = sd.target_lora_modules

            if 'ignore_if_contains' not in network_kwargs:
                network_kwargs['ignore_if_contains'] = []

            # keep the LoRA off the modules this adapter adds itself
            network_kwargs['ignore_if_contains'] += [
                'add_k_proj',
                'add_v_proj',
                'norm_added_q',
                'norm_added_k',
            ]
            if model_class == 'WanTransformer3DModel':
                # always ignore patch_embedding
                network_kwargs['ignore_if_contains'].append('patch_embedding')

            self.control_lora = LoRASpecialNetwork(
                text_encoder=sd.text_encoder,
                unet=sd.unet,
                lora_dim=self.network_config.linear,
                multiplier=1.0,
                alpha=self.network_config.linear_alpha,
                train_unet=self.train_config.train_unet,
                train_text_encoder=self.train_config.train_text_encoder,
                conv_lora_dim=self.network_config.conv,
                conv_alpha=self.network_config.conv_alpha,
                is_sdxl=self.model_config.is_xl or self.model_config.is_ssd,
                is_v2=self.model_config.is_v2,
                is_v3=self.model_config.is_v3,
                is_pixart=self.model_config.is_pixart,
                is_auraflow=self.model_config.is_auraflow,
                is_flux=self.model_config.is_flux,
                is_lumina2=self.model_config.is_lumina2,
                is_ssd=self.model_config.is_ssd,
                is_vega=self.model_config.is_vega,
                dropout=self.network_config.dropout,
                use_text_encoder_1=self.model_config.use_text_encoder_1,
                use_text_encoder_2=self.model_config.use_text_encoder_2,
                use_bias=False,
                is_lorm=False,
                network_config=self.network_config,
                network_type=self.network_config.type,
                transformer_only=self.network_config.transformer_only,
                is_transformer=sd.is_transformer,
                base_model=sd,
                **network_kwargs
            )
            self.control_lora.force_to(self.device_torch, dtype=torch.float32)
            self.control_lora._update_torch_multiplier()
            self.control_lora.apply_to(
                sd.text_encoder,
                sd.unet,
                self.train_config.train_text_encoder,
                self.train_config.train_unet
            )
            self.control_lora.can_merge_in = False
            self.control_lora.prepare_grad_etc(sd.text_encoder, sd.unet)
            if self.train_config.gradient_checkpointing:
                self.control_lora.enable_gradient_checkpointing()

        self.frame_embedder: FrameEmbedder = None
        if self.config.i2v_do_start_frame:
            self.frame_embedder = FrameEmbedder.from_model(
                sd.unet,
                self
            )
            self.frame_embedder.to(self.device_torch)

        # hijack the blocks so we can inject our vision encoder
        attn_hog_list = []
        if model_class == 'WanTransformer3DModel':
            added_kv_proj_dim = sd.model.config.num_attention_heads * sd.model.config.attention_head_dim
            # update the model so it can accept the new input
            # wan has i2v with clip-h for i2v, additional k v attn that directly takes
            # in the penultimate_hidden_states from the vision encoder
            # the kv is on blocks[0].attn2
            sd.model.config.added_kv_proj_dim = added_kv_proj_dim
            sd.model.config['added_kv_proj_dim'] = added_kv_proj_dim

            transformer: WanTransformer3DModel = sd.model
            for block in transformer.blocks:
                block.attn2.added_kv_proj_dim = added_kv_proj_dim
                attn_module = AttentionHog(
                    added_kv_proj_dim,
                    self,
                    block.attn2,
                    transformer
                )
                # set the attn function to ours that handles custom number of vision tokens
                block.attn2.set_processor(WanAttnProcessor2_0(num_vision_tokens))

                attn_hog_list.append(attn_module)
        else:
            raise ValueError(f"Model {model_class} not supported")

        self.attn_hog_list = torch.nn.ModuleList(attn_hog_list)
        self.attn_hog_list.to(self.device_torch)

        inner_dim = sd.model.config.num_attention_heads * sd.model.config.attention_head_dim
        image_embed_dim = vision_encoder.config.hidden_size
        self.image_embedder = WanImageEmbedding(image_embed_dim, inner_dim)

        # override the forward method
        if model_class == 'WanTransformer3DModel':
            self.sd_ref().model._orig_i2v_adapter_forward = self.sd_ref().model.forward
            self.sd_ref().model.forward = partial(
                new_wan_forward,
                self.sd_ref().model
            )

            # add the wan image embedder
            self.sd_ref().model.condition_embedder._image_embedder = self.image_embedder
            self.sd_ref().model.condition_embedder._image_embedder.to(self.device_torch)

        # new_wan_forward looks the adapter up through this weak reference
        self.sd_ref().model._i2v_adapter_ref = weakref.ref(self)

    def get_params(self):
        """Collect the trainable parameters of every adapter part.

        Returns:
            A flat list of tensors: control-LoRA parameters (if any), then
            frame embedder, attention hogs and image embedder parameters.
        """
        if self.control_lora is not None:
            config = {
                'text_encoder_lr': self.train_config.lr,
                'unet_lr': self.train_config.lr,
            }
            # only pass the lr keyword(s) this network implementation accepts
            sig = inspect.signature(self.control_lora.prepare_optimizer_params)
            if 'default_lr' in sig.parameters:
                config['default_lr'] = self.train_config.lr
            if 'learning_rate' in sig.parameters:
                config['learning_rate'] = self.train_config.lr
            params_net = self.control_lora.prepare_optimizer_params(
                **config
            )

            # we want only tensors here
            params = []
            for p in params_net:
                if isinstance(p, dict):
                    params += p["params"]
                elif isinstance(p, torch.Tensor):
                    params.append(p)
                elif isinstance(p, list):
                    params += p
        else:
            params = []

        if self.frame_embedder is not None:
            # make sure the embedder is float32
            self.frame_embedder.to(torch.float32)
            params += list(self.frame_embedder.parameters())

        # add the attn hogs
        for attn_hog in self.attn_hog_list:
            params += list(attn_hog.parameters())

        # add the image embedder
        if self.image_embedder is not None:
            params += list(self.image_embedder.parameters())
        return params

    def load_weights(self, state_dict, strict=True):
        """Route a combined state dict to the adapter parts.

        Keys are dispatched by substring, checked in this order:
        ``frame_embedder``, ``attn_hog``, ``image_embedder``; everything else
        goes to the control LoRA. The matching ``"<name>."`` prefix is removed
        before loading.

        Args:
            state_dict: Combined weights, as produced by ``get_state_dict``.
            strict: Currently unused; sub-modules are always loaded with
                ``strict=False``.
        """
        lora_sd = {}
        attn_hog_sd = {}
        frame_embedder_sd = {}
        image_embedder_sd = {}

        for key, value in state_dict.items():
            if "frame_embedder" in key:
                new_key = key.replace("frame_embedder.", "")
                frame_embedder_sd[new_key] = value
            elif "attn_hog" in key:
                new_key = key.replace("attn_hog.", "")
                attn_hog_sd[new_key] = value
            elif "image_embedder" in key:
                new_key = key.replace("image_embedder.", "")
                image_embedder_sd[new_key] = value
            else:
                lora_sd[key] = value

        # todo process state dict before loading
        if self.control_lora is not None:
            self.control_lora.load_weights(lora_sd)
        if self.frame_embedder is not None:
            self.frame_embedder.load_state_dict(
                frame_embedder_sd, strict=False)
        self.attn_hog_list.load_state_dict(
            attn_hog_sd, strict=False)
        self.image_embedder.load_state_dict(
            image_embedder_sd, strict=False)

    def get_state_dict(self):
        """Return one flat state dict covering every adapter part.

        Control-LoRA keys are kept as returned by the network; the other
        parts are prefixed with ``frame_embedder.``, ``attn_hog.`` and
        ``image_embedder.`` respectively.
        """
        if self.control_lora is not None:
            lora_sd = self.control_lora.get_state_dict(dtype=torch.float32)
        else:
            lora_sd = {}

        if self.frame_embedder is not None:
            frame_embedder_sd = self.frame_embedder.state_dict()
            for key, value in frame_embedder_sd.items():
                lora_sd[f"frame_embedder.{key}"] = value

        # add the attn hogs
        attn_hog_sd = self.attn_hog_list.state_dict()
        for key, value in attn_hog_sd.items():
            lora_sd[f"attn_hog.{key}"] = value

        # add the image embedder
        image_embedder_sd = self.image_embedder.state_dict()
        for key, value in image_embedder_sd.items():
            lora_sd[f"image_embedder.{key}"] = value

        return lora_sd

    def condition_noisy_latents(self, latents: torch.Tensor, batch: DataLoaderBatchDTO):
        """Return ``latents`` unchanged (start-frame handling is not implemented)."""
        # todo handle start frame
        return latents

    def edit_batch_processed(self, batch: DataLoaderBatchDTO):
        """Ensure ``batch.clip_image_tensor`` is set, deriving it from the batch if missing.

        If the batch has no clip image tensor, the first frame of a 5-D
        (video) tensor, or the tensor itself otherwise, is rescaled from
        [-1, 1] to [0, 1], run through the parent adapter's clip image
        processor, and stored on the batch. Runs under ``torch.no_grad()``.

        Returns:
            The same ``batch`` object.
        """
        with torch.no_grad():
            # we will always get a clip image frame, if one is not passed, use image
            # or if video, pull from the first frame
            # edit the batch to pull the first frame out of a video if we have it
            # videos come in (bs, num_frames, channels, height, width)
            tensor = batch.tensor
            if batch.clip_image_tensor is None:
                if len(tensor.shape) == 5:
                    # we have a video
                    first_frames = tensor[:, 0, :, :, :].clone()
                else:
                    # we have a single image
                    first_frames = tensor.clone()

                # it is -1 to 1, change it to 0 to 1
                first_frames = (first_frames + 1) / 2

                # clip image tensors are preprocessed.
                tensors_0_1 = first_frames.to(dtype=torch.float16)
                clip_out = self.adapter_ref().clip_image_processor(
                    images=tensors_0_1,
                    return_tensors="pt",
                    do_resize=True,
                    do_rescale=False,
                ).pixel_values

                batch.clip_image_tensor = clip_out.to(self.device_torch)
        return batch

    @property
    def is_active(self):
        """Whether the parent custom adapter reports itself as active."""
        return self.adapter_ref().is_active