# models/cogvideox_tracking.py
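"""CogVideoX transformer and pipelines conditioned on tracking maps.

Defines `CogVideoXTransformer3DModelTracking`, a CogVideoX transformer with a trainable
auxiliary branch for tracking-map latents, plus text-to-video, image-to-video, and
video-to-video pipelines that feed those tracking latents through the denoising loop.
"""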
from typing import Any, Dict, Optional, Tuple, Union, List, Callable
import json
import math
import os

import torch
from torch import nn
from PIL import Image
from tqdm import tqdm
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.transformers.cogvideox_transformer_3d import CogVideoXBlock, CogVideoXTransformer3DModel
from diffusers.pipelines.cogvideo.pipeline_cogvideox import CogVideoXPipeline, CogVideoXPipelineOutput
from diffusers.pipelines.cogvideo.pipeline_cogvideox_image2video import CogVideoXImageToVideoPipeline
from diffusers.pipelines.cogvideo.pipeline_cogvideox_video2video import CogVideoXVideoToVideoPipeline
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.pipelines.cogvideo.pipeline_cogvideox import retrieve_timesteps
from transformers import T5EncoderModel, T5Tokenizer
from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
from diffusers.pipelines import DiffusionPipeline
from diffusers.models.modeling_utils import ModelMixin
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class CogVideoXTransformer3DModelTracking(CogVideoXTransformer3DModel, ModelMixin):
"""
Add tracking maps to the CogVideoX transformer model.
Parameters:
num_tracking_blocks (`int`, defaults to `18`):
The number of tracking blocks to use. Must be less than or equal to num_layers.
"""
def __init__(
self,
        num_tracking_blocks: int = 18,
num_attention_heads: int = 30,
attention_head_dim: int = 64,
in_channels: int = 16,
out_channels: Optional[int] = 16,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
time_embed_dim: int = 512,
text_embed_dim: int = 4096,
num_layers: int = 30,
dropout: float = 0.0,
attention_bias: bool = True,
sample_width: int = 90,
sample_height: int = 60,
sample_frames: int = 49,
patch_size: int = 2,
temporal_compression_ratio: int = 4,
max_text_seq_length: int = 226,
activation_fn: str = "gelu-approximate",
timestep_activation_fn: str = "silu",
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
spatial_interpolation_scale: float = 1.875,
temporal_interpolation_scale: float = 1.0,
use_rotary_positional_embeddings: bool = False,
use_learned_positional_embeddings: bool = False,
**kwargs
):
super().__init__(
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
in_channels=in_channels,
out_channels=out_channels,
flip_sin_to_cos=flip_sin_to_cos,
freq_shift=freq_shift,
time_embed_dim=time_embed_dim,
text_embed_dim=text_embed_dim,
num_layers=num_layers,
dropout=dropout,
attention_bias=attention_bias,
sample_width=sample_width,
sample_height=sample_height,
sample_frames=sample_frames,
patch_size=patch_size,
temporal_compression_ratio=temporal_compression_ratio,
max_text_seq_length=max_text_seq_length,
activation_fn=activation_fn,
timestep_activation_fn=timestep_activation_fn,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
spatial_interpolation_scale=spatial_interpolation_scale,
temporal_interpolation_scale=temporal_interpolation_scale,
use_rotary_positional_embeddings=use_rotary_positional_embeddings,
use_learned_positional_embeddings=use_learned_positional_embeddings,
**kwargs
)
inner_dim = num_attention_heads * attention_head_dim
self.num_tracking_blocks = num_tracking_blocks
# Ensure num_tracking_blocks is not greater than num_layers
if num_tracking_blocks > num_layers:
raise ValueError("num_tracking_blocks must be less than or equal to num_layers")
# Create linear layers for combining hidden states and tracking maps
self.combine_linears = nn.ModuleList(
[nn.Linear(inner_dim, inner_dim) for _ in range(num_tracking_blocks)]
)
# Initialize weights of combine_linears to zero
for linear in self.combine_linears:
linear.weight.data.zero_()
linear.bias.data.zero_()
# Create transformer blocks for processing tracking maps
self.transformer_blocks_copy = nn.ModuleList(
[
CogVideoXBlock(
dim=inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
time_embed_dim=self.config.time_embed_dim,
dropout=self.config.dropout,
activation_fn=self.config.activation_fn,
attention_bias=self.config.attention_bias,
norm_elementwise_affine=self.config.norm_elementwise_affine,
norm_eps=self.config.norm_eps,
)
for _ in range(num_tracking_blocks)
]
)
# For initial combination of hidden states and tracking maps
self.initial_combine_linear = nn.Linear(inner_dim, inner_dim)
self.initial_combine_linear.weight.data.zero_()
self.initial_combine_linear.bias.data.zero_()
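        # With the combine layers zero-initialized, the tracking branch contributes nothing
        # at the start of training, so the model initially behaves like the frozen base
        # CogVideoX transformer (a ControlNet-style initialization).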
# Freeze all parameters
for param in self.parameters():
param.requires_grad = False
# Unfreeze parameters that need to be trained
for linear in self.combine_linears:
for param in linear.parameters():
param.requires_grad = True
for block in self.transformer_blocks_copy:
for param in block.parameters():
param.requires_grad = True
for param in self.initial_combine_linear.parameters():
param.requires_grad = True
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
tracking_maps: torch.Tensor,
timestep: Union[int, float, torch.LongTensor],
timestep_cond: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
):
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
batch_size, num_frames, channels, height, width = hidden_states.shape
# 1. Time embedding
timesteps = timestep
t_emb = self.time_proj(timesteps)
        # `time_proj` contains no weights and always returns float32 tensors, but
        # `time_embedding` might actually be running in fp16, so we need to cast here.
        # There might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=hidden_states.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
# 2. Patch embedding
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
hidden_states = self.embedding_dropout(hidden_states)
# Process tracking maps
prompt_embed = encoder_hidden_states.clone()
tracking_maps_hidden_states = self.patch_embed(prompt_embed, tracking_maps)
tracking_maps_hidden_states = self.embedding_dropout(tracking_maps_hidden_states)
del prompt_embed
text_seq_length = encoder_hidden_states.shape[1]
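        # `patch_embed` prepends the text tokens to the video tokens; split them here and keep
        # only the video tokens of the tracking branch, sharing the main branch's text tokens
        # with the tracking blocks below.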
encoder_hidden_states = hidden_states[:, :text_seq_length]
hidden_states = hidden_states[:, text_seq_length:]
tracking_maps = tracking_maps_hidden_states[:, text_seq_length:]
# Combine hidden states and tracking maps initially
combined = hidden_states + tracking_maps
tracking_maps = self.initial_combine_linear(combined)
# Process transformer blocks
for i in range(len(self.transformer_blocks)):
if self.training and self.gradient_checkpointing:
# Gradient checkpointing logic for hidden states
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.transformer_blocks[i]),
hidden_states,
encoder_hidden_states,
emb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
hidden_states, encoder_hidden_states = self.transformer_blocks[i](
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=emb,
image_rotary_emb=image_rotary_emb,
)
if i < len(self.transformer_blocks_copy):
if self.training and self.gradient_checkpointing:
# Gradient checkpointing logic for tracking maps
tracking_maps, _ = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.transformer_blocks_copy[i]),
tracking_maps,
encoder_hidden_states,
emb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
tracking_maps, _ = self.transformer_blocks_copy[i](
hidden_states=tracking_maps,
encoder_hidden_states=encoder_hidden_states,
temb=emb,
image_rotary_emb=image_rotary_emb,
)
# Combine hidden states and tracking maps
tracking_maps = self.combine_linears[i](tracking_maps)
hidden_states = hidden_states + tracking_maps
if not self.config.use_rotary_positional_embeddings:
# CogVideoX-2B
hidden_states = self.norm_final(hidden_states)
else:
# CogVideoX-5B
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
hidden_states = self.norm_final(hidden_states)
hidden_states = hidden_states[:, text_seq_length:]
# 4. Final block
hidden_states = self.norm_out(hidden_states, temb=emb)
hidden_states = self.proj_out(hidden_states)
# 5. Unpatchify
        # Note: we use `-1` instead of `channels`:
        # - Using `channels` is fine for CogVideoX-2b and CogVideoX-5b (the number of input channels equals the number of output channels).
        # - However, CogVideoX-5b-I2V also takes concatenated input image latents, so its number of input channels is twice the number of output channels.
p = self.config.patch_size
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
try:
model = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
print("Loaded DiffusionAsShader checkpoint directly.")
for param in model.parameters():
param.requires_grad = False
for linear in model.combine_linears:
for param in linear.parameters():
param.requires_grad = True
for block in model.transformer_blocks_copy:
for param in block.parameters():
param.requires_grad = True
for param in model.initial_combine_linear.parameters():
param.requires_grad = True
return model
except Exception as e:
print(f"Failed to load as DiffusionAsShader: {e}")
print("Attempting to load as CogVideoXTransformer3DModel and convert...")
base_model = CogVideoXTransformer3DModel.from_pretrained(pretrained_model_name_or_path, **kwargs)
            # Drop private config entries (e.g. `_name_or_path`) that the constructors do not accept.
            config = {k: v for k, v in base_model.config.items() if not k.startswith("_")}
            config["num_tracking_blocks"] = kwargs.pop("num_tracking_blocks", 18)
            model = cls(**config)
model.load_state_dict(base_model.state_dict(), strict=False)
model.initial_combine_linear.weight.data.zero_()
model.initial_combine_linear.bias.data.zero_()
for linear in model.combine_linears:
linear.weight.data.zero_()
linear.bias.data.zero_()
for i in range(model.num_tracking_blocks):
model.transformer_blocks_copy[i].load_state_dict(model.transformer_blocks[i].state_dict())
for param in model.parameters():
param.requires_grad = False
for linear in model.combine_linears:
for param in linear.parameters():
param.requires_grad = True
for block in model.transformer_blocks_copy:
for param in block.parameters():
param.requires_grad = True
for param in model.initial_combine_linear.parameters():
param.requires_grad = True
return model
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
is_main_process: bool = True,
save_function: Optional[Callable] = None,
safe_serialization: bool = True,
variant: Optional[str] = None,
max_shard_size: Union[int, str] = "5GB",
push_to_hub: bool = False,
**kwargs,
):
super().save_pretrained(
save_directory,
is_main_process=is_main_process,
save_function=save_function,
safe_serialization=safe_serialization,
variant=variant,
max_shard_size=max_shard_size,
push_to_hub=push_to_hub,
**kwargs,
)
if is_main_process:
config_dict = dict(self.config)
config_dict.pop("_name_or_path", None)
config_dict.pop("_use_default_values", None)
config_dict["_class_name"] = "CogVideoXTransformer3DModelTracking"
config_dict["num_tracking_blocks"] = self.num_tracking_blocks
os.makedirs(save_directory, exist_ok=True)
            with open(os.path.join(save_directory, "config.json"), "w", encoding="utf-8") as f:
                json.dump(config_dict, f, indent=2)
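# Conversion note: `CogVideoXTransformer3DModelTracking.from_pretrained` loads a
# DiffusionAsShader checkpoint directly when possible; otherwise it falls back to the base
# `CogVideoXTransformer3DModel`, copies its first `num_tracking_blocks` blocks into the
# tracking branch, and zero-initializes the combine layers. A minimal sketch (the path is a
# placeholder):
#
#     transformer = CogVideoXTransformer3DModelTracking.from_pretrained(
#         "path/to/checkpoint", subfolder="transformer", torch_dtype=torch.bfloat16
#     )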
class CogVideoXPipelineTracking(CogVideoXPipeline, DiffusionPipeline):
def __init__(
self,
tokenizer: T5Tokenizer,
text_encoder: T5EncoderModel,
vae: AutoencoderKLCogVideoX,
transformer: CogVideoXTransformer3DModelTracking,
scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
):
super().__init__(tokenizer, text_encoder, vae, transformer, scheduler)
if not isinstance(self.transformer, CogVideoXTransformer3DModelTracking):
raise ValueError("The transformer in this pipeline must be of type CogVideoXTransformer3DModelTracking")
@torch.no_grad()
def __call__(
self,
prompt: Optional[Union[str, List[str]]] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
height: int = 480,
width: int = 720,
num_frames: int = 49,
num_inference_steps: int = 50,
timesteps: Optional[List[int]] = None,
guidance_scale: float = 6,
use_dynamic_cfg: bool = False,
num_videos_per_prompt: int = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: str = "pil",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 226,
tracking_maps: Optional[torch.Tensor] = None,
) -> Union[CogVideoXPipelineOutput, Tuple]:
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
num_videos_per_prompt = 1
self.check_inputs(
prompt,
height,
width,
negative_prompt,
callback_on_step_end_tensor_inputs,
prompt_embeds,
negative_prompt_embeds,
)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._interrupt = False
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
do_classifier_free_guidance = guidance_scale > 1.0
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
negative_prompt,
do_classifier_free_guidance,
num_videos_per_prompt=num_videos_per_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
max_sequence_length=max_sequence_length,
device=device,
)
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
self._num_timesteps = len(timesteps)
latent_channels = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
latent_channels,
num_frames,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
image_rotary_emb = (
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
if self.transformer.config.use_rotary_positional_embeddings
else None
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
old_pred_original_sample = None
for i, t in enumerate(timesteps):
if self.interrupt:
continue
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
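                # Duplicate the tracking latents to match the CFG batch (uncond + cond).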
tracking_maps_latent = torch.cat([tracking_maps] * 2) if do_classifier_free_guidance else tracking_maps
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
timestep = t.expand(latent_model_input.shape[0])
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
image_rotary_emb=image_rotary_emb,
attention_kwargs=attention_kwargs,
tracking_maps=tracking_maps_latent,
return_dict=False,
)[0]
noise_pred = noise_pred.float()
if use_dynamic_cfg:
self._guidance_scale = 1 + guidance_scale * (
(1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
)
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
if not isinstance(self.scheduler, CogVideoXDPMScheduler):
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
else:
latents, old_pred_original_sample = self.scheduler.step(
noise_pred,
old_pred_original_sample,
t,
timesteps[i - 1] if i > 0 else None,
latents,
**extra_step_kwargs,
return_dict=False,
)
latents = latents.to(prompt_embeds.dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if not output_type == "latent":
video = self.decode_latents(latents)
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
else:
video = latents
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return CogVideoXPipelineOutput(frames=video)
class CogVideoXImageToVideoPipelineTracking(CogVideoXImageToVideoPipeline, DiffusionPipeline):
def __init__(
self,
tokenizer: T5Tokenizer,
text_encoder: T5EncoderModel,
vae: AutoencoderKLCogVideoX,
transformer: CogVideoXTransformer3DModelTracking,
scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
):
super().__init__(tokenizer, text_encoder, vae, transformer, scheduler)
if not isinstance(self.transformer, CogVideoXTransformer3DModelTracking):
raise ValueError("The transformer in this pipeline must be of type CogVideoXTransformer3DModelTracking")
        # Print the number of transformer blocks
print(f"Number of transformer blocks: {len(self.transformer.transformer_blocks)}")
print(f"Number of tracking transformer blocks: {len(self.transformer.transformer_blocks_copy)}")
self.transformer = torch.compile(self.transformer)
@torch.no_grad()
def __call__(
self,
image: Union[torch.Tensor, Image.Image],
prompt: Optional[Union[str, List[str]]] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_frames: int = 49,
num_inference_steps: int = 50,
timesteps: Optional[List[int]] = None,
guidance_scale: float = 6,
use_dynamic_cfg: bool = False,
num_videos_per_prompt: int = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: str = "pil",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 226,
tracking_maps: Optional[torch.Tensor] = None,
tracking_image: Optional[torch.Tensor] = None,
) -> Union[CogVideoXPipelineOutput, Tuple]:
# Most of the implementation remains the same as the parent class
# We will modify the parts that need to handle tracking_maps
        # 1. Check inputs and set default values
        height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
        width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
self.check_inputs(
image,
prompt,
height,
width,
negative_prompt,
callback_on_step_end_tensor_inputs,
prompt_embeds,
negative_prompt_embeds,
)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._interrupt = False
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
do_classifier_free_guidance=do_classifier_free_guidance,
num_videos_per_prompt=num_videos_per_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
max_sequence_length=max_sequence_length,
device=device,
)
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
del negative_prompt_embeds
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
self._num_timesteps = len(timesteps)
# 5. Prepare latents
image = self.video_processor.preprocess(image, height=height, width=width).to(
device, dtype=prompt_embeds.dtype
)
tracking_image = self.video_processor.preprocess(tracking_image, height=height, width=width).to(
device, dtype=prompt_embeds.dtype
)
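        # The I2V transformer's `in_channels` covers both the noisy video latents and the
        # concatenated image latents, so each stream uses half that many latent channels.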
if self.transformer.config.in_channels != 16:
latent_channels = self.transformer.config.in_channels // 2
else:
latent_channels = self.transformer.config.in_channels
latents, image_latents = self.prepare_latents(
image,
batch_size * num_videos_per_prompt,
latent_channels,
num_frames,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
del image
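        # Encode the tracking video's first frame in the same way; only its image latents are
        # kept and the freshly sampled noise latents are discarded.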
_, tracking_image_latents = self.prepare_latents(
tracking_image,
batch_size * num_videos_per_prompt,
latent_channels,
num_frames,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents=None,
)
del tracking_image
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Create rotary embeds if required
image_rotary_emb = (
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
if self.transformer.config.use_rotary_positional_embeddings
else None
)
# 8. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
old_pred_original_sample = None
for i, t in enumerate(timesteps):
if self.interrupt:
continue
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
latent_image_input = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents
latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2)
del latent_image_input
# Handle tracking maps
if tracking_maps is not None:
latents_tracking_image = torch.cat([tracking_image_latents] * 2) if do_classifier_free_guidance else tracking_image_latents
tracking_maps_input = torch.cat([tracking_maps] * 2) if do_classifier_free_guidance else tracking_maps
tracking_maps_input = torch.cat([tracking_maps_input, latents_tracking_image], dim=2)
del latents_tracking_image
else:
tracking_maps_input = None
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
# Predict noise
self.transformer.to(dtype=latent_model_input.dtype)
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
image_rotary_emb=image_rotary_emb,
attention_kwargs=attention_kwargs,
tracking_maps=tracking_maps_input,
return_dict=False,
)[0]
del latent_model_input
if tracking_maps_input is not None:
del tracking_maps_input
noise_pred = noise_pred.float()
# perform guidance
if use_dynamic_cfg:
self._guidance_scale = 1 + guidance_scale * (
(1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
)
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
del noise_pred_uncond, noise_pred_text
# compute the previous noisy sample x_t -> x_t-1
if not isinstance(self.scheduler, CogVideoXDPMScheduler):
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
else:
latents, old_pred_original_sample = self.scheduler.step(
noise_pred,
old_pred_original_sample,
t,
timesteps[i - 1] if i > 0 else None,
latents,
**extra_step_kwargs,
return_dict=False,
)
del noise_pred
latents = latents.to(prompt_embeds.dtype)
# call the callback, if provided
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
# 9. Post-processing
if not output_type == "latent":
video = self.decode_latents(latents)
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
else:
video = latents
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return CogVideoXPipelineOutput(frames=video)
class CogVideoXVideoToVideoPipelineTracking(CogVideoXVideoToVideoPipeline, DiffusionPipeline):
def __init__(
self,
tokenizer: T5Tokenizer,
text_encoder: T5EncoderModel,
vae: AutoencoderKLCogVideoX,
transformer: CogVideoXTransformer3DModelTracking,
scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
):
super().__init__(tokenizer, text_encoder, vae, transformer, scheduler)
if not isinstance(self.transformer, CogVideoXTransformer3DModelTracking):
raise ValueError("The transformer in this pipeline must be of type CogVideoXTransformer3DModelTracking")
@torch.no_grad()
def __call__(
self,
video: List[Image.Image] = None,
prompt: Optional[Union[str, List[str]]] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
height: int = 480,
width: int = 720,
num_inference_steps: int = 50,
timesteps: Optional[List[int]] = None,
strength: float = 0.8,
guidance_scale: float = 6,
use_dynamic_cfg: bool = False,
num_videos_per_prompt: int = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: str = "pil",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 226,
tracking_maps: Optional[torch.Tensor] = None,
) -> Union[CogVideoXPipelineOutput, Tuple]:
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
num_videos_per_prompt = 1
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt=prompt,
height=height,
width=width,
strength=strength,
negative_prompt=negative_prompt,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
video=video,
latents=latents,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._interrupt = False
# 2. Default call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
negative_prompt,
do_classifier_free_guidance,
num_videos_per_prompt=num_videos_per_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
max_sequence_length=max_sequence_length,
device=device,
)
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
self._num_timesteps = len(timesteps)
# 5. Prepare latents
if latents is None:
video = self.video_processor.preprocess_video(video, height=height, width=width)
video = video.to(device=device, dtype=prompt_embeds.dtype)
latent_channels = self.transformer.config.in_channels
latents = self.prepare_latents(
video,
batch_size * num_videos_per_prompt,
latent_channels,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
latent_timestep,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Create rotary embeds if required
image_rotary_emb = (
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
if self.transformer.config.use_rotary_positional_embeddings
else None
)
# 8. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
# for DPM-solver++
old_pred_original_sample = None
for i, t in enumerate(timesteps):
if self.interrupt:
continue
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
tracking_maps_input = torch.cat([tracking_maps] * 2) if do_classifier_free_guidance else tracking_maps
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
# predict noise model_output
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
image_rotary_emb=image_rotary_emb,
attention_kwargs=attention_kwargs,
tracking_maps=tracking_maps_input,
return_dict=False,
)[0]
noise_pred = noise_pred.float()
# perform guidance
if use_dynamic_cfg:
self._guidance_scale = 1 + guidance_scale * (
(1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
)
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
if not isinstance(self.scheduler, CogVideoXDPMScheduler):
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
else:
latents, old_pred_original_sample = self.scheduler.step(
noise_pred,
old_pred_original_sample,
t,
timesteps[i - 1] if i > 0 else None,
latents,
**extra_step_kwargs,
return_dict=False,
)
latents = latents.to(prompt_embeds.dtype)
# call the callback, if provided
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if not output_type == "latent":
video = self.decode_latents(latents)
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
else:
video = latents
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return CogVideoXPipelineOutput(frames=video)
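# Usage sketch (nothing here runs on import; the checkpoint path and inputs are placeholders,
# and `tracking_maps` is expected in latent space with the same shape as the video latents):
#
#     transformer = CogVideoXTransformer3DModelTracking.from_pretrained(
#         "path/to/checkpoint", subfolder="transformer", torch_dtype=torch.bfloat16
#     )
#     pipe = CogVideoXImageToVideoPipelineTracking.from_pretrained(
#         "path/to/checkpoint", transformer=transformer, torch_dtype=torch.bfloat16
#     ).to("cuda")
#     result = pipe(
#         image=first_frame,                      # PIL.Image of the first video frame
#         prompt="a description of the scene",
#         tracking_maps=tracking_latents,         # VAE-encoded tracking video latents
#         tracking_image=tracking_first_frame,    # first frame of the tracking video
#         num_inference_steps=50,
#     )
#     video = result.frames[0]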