from functools import partial
import inspect
import weakref
import torch
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.lora_special import LoRASpecialNetwork
from diffusers import WanTransformer3DModel
from transformers import SiglipImageProcessor, SiglipVisionModel, CLIPImageProcessor, CLIPVisionModelWithProjection
from diffusers.models.attention_processor import Attention
from diffusers.models.transformers.transformer_wan import WanImageEmbedding, WanTimeTextImageEmbedding
from toolkit.util.shuffle import shuffle_tensor_along_axis
import torch.nn.functional as F
if TYPE_CHECKING:
from toolkit.models.base_model import BaseModel
from toolkit.config_modules import AdapterConfig, TrainConfig, ModelConfig
from toolkit.custom_adapter import CustomAdapter
class FrameEmbedder(torch.nn.Module):
def __init__(
self,
adapter: 'I2VAdapter',
orig_layer: torch.nn.Conv3d,
in_channels=20, # wan is 16 normally, and 36 with i2v so 20 new channels
):
super().__init__()
# goes through a conv patch embedding first and is then flattened
# hidden_states = self.patch_embedding(hidden_states)
# hidden_states = hidden_states.flatten(2).transpose(1, 2)
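        # e.g. (illustrative, assuming patch_size=(1, 2, 2) and inner_dim=1536) an input of
        # (bs, 20, frames, h, w) becomes (bs, 1536, frames, h // 2, w // 2) after the conv,
        # then (bs, frames * (h // 2) * (w // 2), 1536) once flattened and transposed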
inner_dim = orig_layer.out_channels
patch_size = adapter.sd_ref().model.config.patch_size
self.patch_embedding = torch.nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
self.adapter_ref: weakref.ref = weakref.ref(adapter)
self.orig_layer_ref: weakref.ref = weakref.ref(orig_layer)
@classmethod
def from_model(
cls,
model: WanTransformer3DModel,
adapter: 'I2VAdapter',
):
if model.__class__.__name__ == 'WanTransformer3DModel':
new_channels = 20 # wan is 16 normally, and 36 with i2v so 20 new channels
orig_patch_embedding: torch.nn.Conv3d = model.patch_embedding
img_embedder = cls(
adapter,
orig_layer=orig_patch_embedding,
in_channels=new_channels,
)
# hijack the forward method
orig_patch_embedding._orig_i2v_adapter_forward = orig_patch_embedding.forward
orig_patch_embedding.forward = img_embedder.forward
# update the config of the transformer, only needed when merged in
# model.config.in_channels = model.config.in_channels + new_channels
# model.config["in_channels"] = model.config.in_channels + new_channels
return img_embedder
else:
raise ValueError("Model not supported")
@property
def is_active(self):
return self.adapter_ref().is_active
def forward(self, x):
if not self.is_active:
# make sure lora is not active
if self.adapter_ref().control_lora is not None:
self.adapter_ref().control_lora.is_active = False
if x.shape[1] > self.orig_layer_ref().in_channels:
# we have i2v, so we need to remove the extra channels
x = x[:, :self.orig_layer_ref().in_channels, :, :, :]
return self.orig_layer_ref()._orig_i2v_adapter_forward(x)
# make sure lora is active
if self.adapter_ref().control_lora is not None:
self.adapter_ref().control_lora.is_active = True
        # x channels are arranged as cat(orig_input=16, temporal_conditioning_mask=4, encoded_first_frame=16)
# (16 + 4 + 16) = 36 channels
# (batch_size, 36, num_frames, latent_height, latent_width)
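        # illustrative sizes (assuming Wan 2.1 at 81 frames, 480x832): each of the three parts
        # is (bs, C, 21, 60, 104) with C = 16, 4, 16, concatenated on dim=1 to give 36 channels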
orig_device = x.device
orig_dtype = x.dtype
orig_in = x[:, :16, :, :, :]
orig_out = self.orig_layer_ref()._orig_i2v_adapter_forward(orig_in)
# remove original stuff
x = x[:, 16:, :, :, :]
x = x.to(self.patch_embedding.weight.device, dtype=self.patch_embedding.weight.dtype)
x = self.patch_embedding(x)
x = x.to(orig_device, dtype=orig_dtype)
# add the original out
x = x + orig_out
return x
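# A minimal sketch of how FrameEmbedder gets attached (mirrors the wiring in
# I2VAdapter.__init__ below; variable names here are illustrative):
#
#   frame_embedder = FrameEmbedder.from_model(transformer, i2v_adapter)
#   frame_embedder.to(device)
#
# After this, transformer.patch_embedding.forward routes through FrameEmbedder.forward,
# which falls back to the original conv whenever the adapter is inactive.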
def deactivatable_forward(
self: 'Attention',
*args,
**kwargs
):
    attn_hog = self._attn_hog_ref()
    if attn_hog is not None and attn_hog.is_active:
        self.added_kv_proj_dim = None
        self.add_k_proj = self._add_k_proj
        self.add_v_proj = self._add_v_proj
        self.norm_added_q = self._norm_added_q
        self.norm_added_k = self._norm_added_k
    else:
        # guard against a dead weakref so the inactive path cannot crash
        self.added_kv_proj_dim = attn_hog.added_kv_proj_dim if attn_hog is not None else None
self.add_k_proj = None
self.add_v_proj = None
self.norm_added_q = None
self.norm_added_k = None
return self._orig_forward(*args, **kwargs)
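# A minimal sketch of how the hijack above gets wired in (this mirrors what
# AttentionHog.__init__ below actually does):
#
#   attn_layer._attn_hog_ref = weakref.ref(attn_hog)
#   attn_layer._orig_forward = attn_layer.forward
#   attn_layer.forward = partial(deactivatable_forward, attn_layer)
#
# With the adapter inactive, the added k/v projections are swapped out so the block
# behaves like the original text-to-video cross attention.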
class AttentionHog(torch.nn.Module):
def __init__(
self,
added_kv_proj_dim: int,
adapter: 'I2VAdapter',
attn_layer: Attention,
model: 'WanTransformer3DModel',
):
super().__init__()
# To prevent circular import.
from diffusers.models.normalization import FP32LayerNorm, LpNorm, RMSNorm
self.added_kv_proj_dim = added_kv_proj_dim
self.attn_layer_ref: weakref.ref = weakref.ref(attn_layer)
self.adapter_ref: weakref.ref = weakref.ref(adapter)
self.model_ref: weakref.ref = weakref.ref(model)
qk_norm = model.config.qk_norm
# layers
self.add_k_proj = torch.nn.Linear(
added_kv_proj_dim,
attn_layer.inner_kv_dim,
bias=attn_layer.added_proj_bias
)
self.add_k_proj.weight.data = self.add_k_proj.weight.data * 0.001
self.add_v_proj = torch.nn.Linear(
added_kv_proj_dim,
attn_layer.inner_kv_dim,
bias=attn_layer.added_proj_bias
)
self.add_v_proj.weight.data = self.add_v_proj.weight.data * 0.001
        # do qk norm. It isn't stored in the class, but we can infer it from the attn layer
self.norm_added_q = None
self.norm_added_k = None
if attn_layer.norm_q is not None:
eps: float = 1e-5
if qk_norm == "layer_norm":
self.norm_added_q = torch.nn.LayerNorm(
attn_layer.norm_q.normalized_shape, eps=eps, elementwise_affine=attn_layer.norm_q.elementwise_affine)
self.norm_added_k = torch.nn.LayerNorm(
attn_layer.norm_k.normalized_shape, eps=eps, elementwise_affine=attn_layer.norm_k.elementwise_affine)
elif qk_norm == "fp32_layer_norm":
self.norm_added_q = FP32LayerNorm(
attn_layer.norm_q.normalized_shape, elementwise_affine=False, bias=False, eps=eps)
self.norm_added_k = FP32LayerNorm(
attn_layer.norm_k.normalized_shape, elementwise_affine=False, bias=False, eps=eps)
elif qk_norm == "rms_norm":
self.norm_added_q = RMSNorm(attn_layer.norm_q.dim, eps=eps)
self.norm_added_k = RMSNorm(attn_layer.norm_k.dim, eps=eps)
elif qk_norm == "rms_norm_across_heads":
# Wanx applies qk norm across all heads
self.norm_added_q = RMSNorm(attn_layer.norm_q.dim, eps=eps)
self.norm_added_k = RMSNorm(attn_layer.norm_k.dim, eps=eps)
else:
raise ValueError(
f"unknown qk_norm: {qk_norm}. Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`"
)
        # add these to the attn layer in a way they can be deactivated
attn_layer._add_k_proj = self.add_k_proj
attn_layer._add_v_proj = self.add_v_proj
attn_layer._norm_added_q = self.norm_added_q
attn_layer._norm_added_k = self.norm_added_k
        # make it deactivatable
attn_layer._attn_hog_ref = weakref.ref(self)
attn_layer._orig_forward = attn_layer.forward
attn_layer.forward = partial(deactivatable_forward, attn_layer)
    def forward(self, *args, **kwargs):
        if not self.adapter_ref().is_active:
            # fall back to the original attention forward when the adapter is inactive
            return self.attn_layer_ref()._orig_forward(*args, **kwargs)
        # TODO implement this
        raise NotImplementedError("Attention hog not implemented")
    @property
    def is_active(self):
return self.adapter_ref().is_active
def new_wan_forward(
self: WanTransformer3DModel,
hidden_states: torch.Tensor,
timestep: torch.LongTensor,
encoder_hidden_states: torch.Tensor,
encoder_hidden_states_image: Optional[torch.Tensor] = None,
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
# prevent circular import
from toolkit.models.wan21.wan_utils import add_first_frame_conditioning
adapter:'I2VAdapter' = self._i2v_adapter_ref()
if adapter.is_active:
# activate the condition embedder
self.condition_embedder.image_embedder = adapter.image_embedder
        # for wan, the image encoder embeds are also placed on the unconditional branch.
        # That won't work here, so for now we use the cached embeds in order:
        # we cache a conditional and an unconditional embed. When sampling, the conditional
        # pass runs first, then the unconditional, so we just track which one we are on.
        # This is a horrible hack.
        # TODO find a less stupid way to do this.
if adapter.adapter_ref().is_sampling:
if not hasattr(self, '_do_unconditional'):
                # set it to true so we alternate to false immediately
self._do_unconditional = True
# alternate it
self._do_unconditional = not self._do_unconditional
if self._do_unconditional:
# slightly reduce strength of conditional for the unconditional
# encoder_hidden_states_image = adapter.adapter_ref().conditional_embeds * 0.5
# shuffle the embedding tokens so we still have all the information, but it is scrambled
                # this prevents things like color from being over-weighted by cfg, while still sharpening content.
encoder_hidden_states_image = shuffle_tensor_along_axis(
adapter.adapter_ref().conditional_embeds,
axis=1
)
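                # e.g. (assumed shapes) a (bs, 257, 1280) clip-h embedding keeps all of its
                # 257 vision tokens, but their order along dim=1 is randomly permuted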
# encoder_hidden_states_image = adapter.adapter_ref().unconditional_embeds
else:
# use the conditional
encoder_hidden_states_image = adapter.adapter_ref().conditional_embeds
else:
# doing a normal training run, always use conditional embeds
encoder_hidden_states_image = adapter.adapter_ref().conditional_embeds
# add the first frame conditioning
if adapter.frame_embedder is not None:
with torch.no_grad():
# add the first frame conditioning
conditioning_frame = adapter.adapter_ref().cached_control_image_0_1
if conditioning_frame is None:
raise ValueError("No conditioning frame found")
# make it -1 to 1
conditioning_frame = (conditioning_frame * 2) - 1
conditioning_frame = conditioning_frame.to(
hidden_states.device, dtype=hidden_states.dtype
)
# if doing a full denoise, the latent input may be full channels here, only get first 16
if hidden_states.shape[1] > 16:
hidden_states = hidden_states[:, :16, :, :, :]
hidden_states = add_first_frame_conditioning(
latent_model_input=hidden_states,
first_frame=conditioning_frame,
vae=adapter.adapter_ref().sd_ref().vae,
)
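                # add_first_frame_conditioning is expected to vae-encode the frame, build the
                # 4-channel temporal mask, and cat both onto the 16 latent channels, producing
                # the 36-channel input that FrameEmbedder.forward documents above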
else:
        # not active; deactivate the condition embedder
self.condition_embedder.image_embedder = None
return self._orig_i2v_adapter_forward(
hidden_states=hidden_states,
timestep=timestep,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_states_image=encoder_hidden_states_image,
return_dict=return_dict,
attention_kwargs=attention_kwargs,
)
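# new_wan_forward is installed by I2VAdapter.__init__ roughly like this (see below):
#
#   model._orig_i2v_adapter_forward = model.forward
#   model.forward = partial(new_wan_forward, model)
#   model._i2v_adapter_ref = weakref.ref(i2v_adapter)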
class I2VAdapter(torch.nn.Module):
def __init__(
self,
adapter: 'CustomAdapter',
sd: 'BaseModel',
config: 'AdapterConfig',
train_config: 'TrainConfig',
image_processor: Union[SiglipImageProcessor, CLIPImageProcessor],
vision_encoder: Union[SiglipVisionModel, CLIPVisionModelWithProjection],
):
super().__init__()
# avoid circular import
from toolkit.models.wan21.wan_attn import WanAttnProcessor2_0
self.adapter_ref: weakref.ref = weakref.ref(adapter)
self.sd_ref = weakref.ref(sd)
self.model_config: ModelConfig = sd.model_config
self.network_config = config.lora_config
self.train_config = train_config
self.config = config
self.device_torch = sd.device_torch
self.control_lora = None
self.image_processor_ref: weakref.ref = weakref.ref(image_processor)
self.vision_encoder_ref: weakref.ref = weakref.ref(vision_encoder)
ve_img_size = vision_encoder.config.image_size
ve_patch_size = vision_encoder.config.patch_size
num_patches = (ve_img_size // ve_patch_size) ** 2
num_vision_tokens = num_patches
# siglip does not have a class token
if not vision_encoder.__class__.__name__.lower().startswith("siglip"):
num_vision_tokens = num_patches + 1
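        # e.g. (illustrative) clip vit-h/14 at 224px -> (224 // 14) ** 2 + 1 = 257 tokens,
        # while a siglip encoder at 384px with patch 16 -> (384 // 16) ** 2 = 576 tokens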
model_class = sd.model.__class__.__name__
if self.network_config is not None:
network_kwargs = {} if self.network_config.network_kwargs is None else self.network_config.network_kwargs
if hasattr(sd, 'target_lora_modules'):
network_kwargs['target_lin_modules'] = sd.target_lora_modules
if 'ignore_if_contains' not in network_kwargs:
network_kwargs['ignore_if_contains'] = []
network_kwargs['ignore_if_contains'] += [
'add_k_proj',
'add_v_proj',
'norm_added_q',
'norm_added_k',
]
if model_class == 'WanTransformer3DModel':
# always ignore patch_embedding
network_kwargs['ignore_if_contains'].append('patch_embedding')
self.control_lora = LoRASpecialNetwork(
text_encoder=sd.text_encoder,
unet=sd.unet,
lora_dim=self.network_config.linear,
multiplier=1.0,
alpha=self.network_config.linear_alpha,
train_unet=self.train_config.train_unet,
train_text_encoder=self.train_config.train_text_encoder,
conv_lora_dim=self.network_config.conv,
conv_alpha=self.network_config.conv_alpha,
is_sdxl=self.model_config.is_xl or self.model_config.is_ssd,
is_v2=self.model_config.is_v2,
is_v3=self.model_config.is_v3,
is_pixart=self.model_config.is_pixart,
is_auraflow=self.model_config.is_auraflow,
is_flux=self.model_config.is_flux,
is_lumina2=self.model_config.is_lumina2,
is_ssd=self.model_config.is_ssd,
is_vega=self.model_config.is_vega,
dropout=self.network_config.dropout,
use_text_encoder_1=self.model_config.use_text_encoder_1,
use_text_encoder_2=self.model_config.use_text_encoder_2,
use_bias=False,
is_lorm=False,
network_config=self.network_config,
network_type=self.network_config.type,
transformer_only=self.network_config.transformer_only,
is_transformer=sd.is_transformer,
base_model=sd,
**network_kwargs
)
self.control_lora.force_to(self.device_torch, dtype=torch.float32)
self.control_lora._update_torch_multiplier()
self.control_lora.apply_to(
sd.text_encoder,
sd.unet,
self.train_config.train_text_encoder,
self.train_config.train_unet
)
self.control_lora.can_merge_in = False
self.control_lora.prepare_grad_etc(sd.text_encoder, sd.unet)
if self.train_config.gradient_checkpointing:
self.control_lora.enable_gradient_checkpointing()
        self.frame_embedder: Optional[FrameEmbedder] = None
if self.config.i2v_do_start_frame:
self.frame_embedder = FrameEmbedder.from_model(
sd.unet,
self
)
self.frame_embedder.to(self.device_torch)
# hijack the blocks so we can inject our vision encoder
attn_hog_list = []
if model_class == 'WanTransformer3DModel':
added_kv_proj_dim = sd.model.config.num_attention_heads * sd.model.config.attention_head_dim
# update the model so it can accept the new input
            # wan's official i2v variant uses clip-h and adds extra k/v projections to the
            # cross attention that directly take the penultimate_hidden_states from the
            # vision encoder; the extra k/v lives on each block's attn2
sd.model.config.added_kv_proj_dim = added_kv_proj_dim
sd.model.config['added_kv_proj_dim'] = added_kv_proj_dim
transformer: WanTransformer3DModel = sd.model
for block in transformer.blocks:
block.attn2.added_kv_proj_dim = added_kv_proj_dim
attn_module = AttentionHog(
added_kv_proj_dim,
self,
block.attn2,
transformer
)
                # set the attn processor to ours, which handles a custom number of vision tokens
block.attn2.set_processor(WanAttnProcessor2_0(num_vision_tokens))
attn_hog_list.append(attn_module)
else:
raise ValueError(f"Model {model_class} not supported")
self.attn_hog_list = torch.nn.ModuleList(attn_hog_list)
self.attn_hog_list.to(self.device_torch)
inner_dim = sd.model.config.num_attention_heads * sd.model.config.attention_head_dim
image_embed_dim = vision_encoder.config.hidden_size
self.image_embedder = WanImageEmbedding(image_embed_dim, inner_dim)
# override the forward method
if model_class == 'WanTransformer3DModel':
self.sd_ref().model._orig_i2v_adapter_forward = self.sd_ref().model.forward
self.sd_ref().model.forward = partial(
new_wan_forward,
self.sd_ref().model
)
# add the wan image embedder
self.sd_ref().model.condition_embedder._image_embedder = self.image_embedder
self.sd_ref().model.condition_embedder._image_embedder.to(self.device_torch)
self.sd_ref().model._i2v_adapter_ref = weakref.ref(self)
def get_params(self):
if self.control_lora is not None:
config = {
'text_encoder_lr': self.train_config.lr,
'unet_lr': self.train_config.lr,
}
sig = inspect.signature(self.control_lora.prepare_optimizer_params)
if 'default_lr' in sig.parameters:
config['default_lr'] = self.train_config.lr
if 'learning_rate' in sig.parameters:
config['learning_rate'] = self.train_config.lr
params_net = self.control_lora.prepare_optimizer_params(
**config
)
# we want only tensors here
params = []
for p in params_net:
if isinstance(p, dict):
params += p["params"]
elif isinstance(p, torch.Tensor):
params.append(p)
elif isinstance(p, list):
params += p
else:
params = []
if self.frame_embedder is not None:
# make sure the embedder is float32
self.frame_embedder.to(torch.float32)
params += list(self.frame_embedder.parameters())
# add the attn hogs
for attn_hog in self.attn_hog_list:
params += list(attn_hog.parameters())
# add the image embedder
if self.image_embedder is not None:
params += list(self.image_embedder.parameters())
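        # (illustrative) everything above flattens into one plain list of tensors (lora
        # weights plus frame_embedder / attn_hog / image_embedder parameters) so the caller
        # can build a single optimizer param group from it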
return params
def load_weights(self, state_dict, strict=True):
lora_sd = {}
attn_hog_sd = {}
frame_embedder_sd = {}
image_embedder_sd = {}
for key, value in state_dict.items():
if "frame_embedder" in key:
new_key = key.replace("frame_embedder.", "")
frame_embedder_sd[new_key] = value
elif "attn_hog" in key:
new_key = key.replace("attn_hog.", "")
attn_hog_sd[new_key] = value
elif "image_embedder" in key:
new_key = key.replace("image_embedder.", "")
image_embedder_sd[new_key] = value
else:
lora_sd[key] = value
# todo process state dict before loading
if self.control_lora is not None:
self.control_lora.load_weights(lora_sd)
if self.frame_embedder is not None:
self.frame_embedder.load_state_dict(
frame_embedder_sd, strict=False)
self.attn_hog_list.load_state_dict(
attn_hog_sd, strict=False)
self.image_embedder.load_state_dict(
image_embedder_sd, strict=False)
def get_state_dict(self):
if self.control_lora is not None:
lora_sd = self.control_lora.get_state_dict(dtype=torch.float32)
else:
lora_sd = {}
if self.frame_embedder is not None:
frame_embedder_sd = self.frame_embedder.state_dict()
for key, value in frame_embedder_sd.items():
lora_sd[f"frame_embedder.{key}"] = value
# add the attn hogs
attn_hog_sd = self.attn_hog_list.state_dict()
for key, value in attn_hog_sd.items():
lora_sd[f"attn_hog.{key}"] = value
# add the image embedder
image_embedder_sd = self.image_embedder.state_dict()
for key, value in image_embedder_sd.items():
lora_sd[f"image_embedder.{key}"] = value
return lora_sd
def condition_noisy_latents(self, latents: torch.Tensor, batch:DataLoaderBatchDTO):
# todo handle start frame
return latents
def edit_batch_processed(self, batch: DataLoaderBatchDTO):
with torch.no_grad():
            # we always want a clip image frame. If one is not passed, use the image,
            # or if we have a video, pull the first frame out of it.
            # videos come in as (bs, num_frames, channels, height, width)
tensor = batch.tensor
if batch.clip_image_tensor is None:
if len(tensor.shape) == 5:
# we have a video
first_frames = tensor[:, 0, :, :, :].clone()
else:
# we have a single image
first_frames = tensor.clone()
# it is -1 to 1, change it to 0 to 1
first_frames = (first_frames + 1) / 2
# clip image tensors are preprocessed.
tensors_0_1 = first_frames.to(dtype=torch.float16)
clip_out = self.adapter_ref().clip_image_processor(
images=tensors_0_1,
return_tensors="pt",
do_resize=True,
do_rescale=False,
).pixel_values
batch.clip_image_tensor = clip_out.to(self.device_torch)
return batch
@property
def is_active(self):
return self.adapter_ref().is_active
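# A minimal usage sketch (hypothetical variable names; in practice CustomAdapter constructs
# and drives this adapter rather than user code):
#
#   i2v_adapter = I2VAdapter(
#       adapter=custom_adapter,
#       sd=base_model,
#       config=adapter_config,
#       train_config=train_config,
#       image_processor=clip_image_processor,
#       vision_encoder=clip_vision_encoder,
#   )
#   optimizer = torch.optim.AdamW(i2v_adapter.get_params(), lr=train_config.lr)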