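"""Custom forward pass for FluxTransformer2DModel that, in addition to the usual
image and text tokens, routes a set of condition latents (with their own positional
ids and a fixed condition timestep ``c_t``) through the double- and single-stream
transformer blocks.
"""
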
import torch
from diffusers.pipelines import FluxPipeline
from typing import List, Union, Optional, Dict, Any, Callable
from .block import block_forward, single_block_forward
from .lora_controller import enable_lora
from accelerate.utils import is_torch_version
from diffusers.models.transformers.transformer_flux import (
FluxTransformer2DModel,
Transformer2DModelOutput,
USE_PEFT_BACKEND,
scale_lora_layers,
unscale_lora_layers,
logger,
)
import numpy as np


def prepare_params(
hidden_states: torch.Tensor,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    pooled_projections: Optional[torch.Tensor] = None,
    timestep: Optional[torch.LongTensor] = None,
    img_ids: Optional[torch.Tensor] = None,
    txt_ids: Optional[torch.Tensor] = None,
    guidance: Optional[torch.Tensor] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_block_samples=None,
controlnet_single_block_samples=None,
return_dict: bool = True,
    **kwargs,
):
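    """Unpack the standard FluxTransformer2DModel.forward arguments from keyword
    parameters; any extra keyword arguments are ignored."""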
return (
hidden_states,
encoder_hidden_states,
pooled_projections,
timestep,
img_ids,
txt_ids,
guidance,
joint_attention_kwargs,
controlnet_block_samples,
controlnet_single_block_samples,
return_dict,
)


def tranformer_forward(
transformer: FluxTransformer2DModel,
condition_latents: torch.Tensor,
condition_ids: torch.Tensor,
condition_type_ids: torch.Tensor,
    model_config: Optional[Dict[str, Any]] = None,
c_t=0,
    **params,
):
    self = transformer
    model_config = model_config or {}  # avoid a shared mutable default argument
    use_condition = condition_latents is not None
(
hidden_states,
encoder_hidden_states,
pooled_projections,
timestep,
img_ids,
txt_ids,
guidance,
joint_attention_kwargs,
controlnet_block_samples,
controlnet_single_block_samples,
return_dict,
) = prepare_params(**params)
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if (
joint_attention_kwargs is not None
and joint_attention_kwargs.get("scale", None) is not None
):
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
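    # Project image latents (and, if provided, condition latents) into the transformer
    # hidden size; LoRA on the x_embedder is only active when `latent_lora` is set.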
with enable_lora((self.x_embedder,), model_config.get("latent_lora", False)):
hidden_states = self.x_embedder(hidden_states)
condition_latents = self.x_embedder(condition_latents) if use_condition else None
timestep = timestep.to(hidden_states.dtype) * 1000
    if guidance is not None:
        guidance = guidance.to(hidden_states.dtype) * 1000
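    # Timestep/pooled-text embedding for the denoising branch; the condition branch
    # reuses the same embedder with a fixed timestep `c_t` (and, when guidance
    # embedding is used, a constant guidance of 1.0), both pre-scaled by 1000.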
temb = (
self.time_text_embed(timestep, pooled_projections)
if guidance is None
else self.time_text_embed(timestep, guidance, pooled_projections)
)
cond_temb = (
self.time_text_embed(torch.ones_like(timestep) * c_t * 1000, pooled_projections)
if guidance is None
else self.time_text_embed(
torch.ones_like(timestep) * c_t * 1000, torch.ones_like(guidance) * 1000, pooled_projections
)
)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
    if txt_ids.ndim == 3:
        logger.warning(
            "Passing `txt_ids` 3d torch.Tensor is deprecated. "
            "Please remove the batch dimension and pass it as a 2d torch Tensor."
        )
        txt_ids = txt_ids[0]
    if img_ids.ndim == 3:
        logger.warning(
            "Passing `img_ids` 3d torch.Tensor is deprecated. "
            "Please remove the batch dimension and pass it as a 2d torch Tensor."
        )
        img_ids = img_ids[0]
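    # Rotary position embeddings: text and image ids are concatenated and embedded
    # together; condition tokens get a separate embedding from `condition_ids`.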
ids = torch.cat((txt_ids, img_ids), dim=0)
image_rotary_emb = self.pos_embed(ids)
if use_condition:
# condition_ids[:, :1] = condition_type_ids
cond_rotary_emb = self.pos_embed(condition_ids)
# hidden_states = torch.cat([hidden_states, condition_latents], dim=1)
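    # Double-stream blocks: text and image tokens run in separate streams, with the
    # condition latents processed alongside when present (gradient checkpointing is
    # used during training).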
for index_block, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing:
ckpt_kwargs: Dict[str, Any] = (
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
)
encoder_hidden_states, hidden_states, condition_latents = (
torch.utils.checkpoint.checkpoint(
block_forward,
self=block,
model_config=model_config,
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
condition_latents=condition_latents if use_condition else None,
temb=temb,
cond_temb=cond_temb if use_condition else None,
cond_rotary_emb=cond_rotary_emb if use_condition else None,
image_rotary_emb=image_rotary_emb,
**ckpt_kwargs,
)
)
else:
encoder_hidden_states, hidden_states, condition_latents = block_forward(
block,
model_config=model_config,
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
condition_latents=condition_latents if use_condition else None,
temb=temb,
cond_temb=cond_temb if use_condition else None,
cond_rotary_emb=cond_rotary_emb if use_condition else None,
image_rotary_emb=image_rotary_emb,
)
# controlnet residual
if controlnet_block_samples is not None:
interval_control = len(self.transformer_blocks) / len(
controlnet_block_samples
)
interval_control = int(np.ceil(interval_control))
hidden_states = (
hidden_states
+ controlnet_block_samples[index_block // interval_control]
)
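    # Merge text and image tokens into a single sequence for the single-stream blocks.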
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
for index_block, block in enumerate(self.single_transformer_blocks):
if self.training and self.gradient_checkpointing:
ckpt_kwargs: Dict[str, Any] = (
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
)
result = torch.utils.checkpoint.checkpoint(
single_block_forward,
self=block,
model_config=model_config,
hidden_states=hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
**(
{
"condition_latents": condition_latents,
"cond_temb": cond_temb,
"cond_rotary_emb": cond_rotary_emb,
}
if use_condition
else {}
),
**ckpt_kwargs,
)
else:
result = single_block_forward(
block,
model_config=model_config,
hidden_states=hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
**(
{
"condition_latents": condition_latents,
"cond_temb": cond_temb,
"cond_rotary_emb": cond_rotary_emb,
}
if use_condition
else {}
),
)
if use_condition:
hidden_states, condition_latents = result
else:
hidden_states = result
# controlnet residual
if controlnet_single_block_samples is not None:
interval_control = len(self.single_transformer_blocks) / len(
controlnet_single_block_samples
)
interval_control = int(np.ceil(interval_control))
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+ controlnet_single_block_samples[index_block // interval_control]
)
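    # Drop the text tokens, then apply the final norm and output projection.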
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
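
# Illustrative usage sketch only (not part of this module): the keyword arguments
# mirror FluxTransformer2DModel.forward, and `cond_latents` / `cond_ids` are
# placeholders for condition tensors prepared by the caller.
#
#   pipe = FluxPipeline.from_pretrained(
#       "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
#   )
#   out = tranformer_forward(
#       pipe.transformer,
#       condition_latents=cond_latents,
#       condition_ids=cond_ids,
#       condition_type_ids=None,
#       hidden_states=latents,
#       encoder_hidden_states=prompt_embeds,
#       pooled_projections=pooled_prompt_embeds,
#       timestep=timestep,
#       img_ids=latent_image_ids,
#       txt_ids=text_ids,
#       guidance=guidance,
#       return_dict=False,
#   )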