# vace/models/ltx/ltx_vace.py
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
from pathlib import Path
import torch
from transformers import T5EncoderModel, T5Tokenizer
from ltx_video.models.autoencoders.causal_video_autoencoder import (
    CausalVideoAutoencoder,
)
from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
from ltx_video.schedulers.rf import RectifiedFlowScheduler
from ltx_video.utils.conditioning_method import ConditioningMethod
from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
from .models.transformers.transformer3d import VaceTransformer3DModel
from .pipelines.pipeline_ltx_video import VaceLTXVideoPipeline
from ..utils.preprocessor import VaceImageProcessor, VaceVideoProcessor
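
# LTXVace wraps the LTX-Video stack (T5 text encoder, causal video VAE,
# rectified-flow scheduler, VACE 3D transformer) behind a single
# load-once / generate-many interface.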
class LTXVace:
    def __init__(self, ckpt_path, text_encoder_path, precision='bfloat16',
                 stg_skip_layers="19", stg_mode="stg_a", offload_to_cpu=False):
        self.precision = precision
        self.offload_to_cpu = offload_to_cpu
        ckpt_path = Path(ckpt_path)

        # Load every sub-model: VAE, transformer, and scheduler from the
        # checkpoint path, text encoder and tokenizer from the T5 directory.
        vae = CausalVideoAutoencoder.from_pretrained(ckpt_path)
        transformer = VaceTransformer3DModel.from_pretrained(ckpt_path)
        scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path)
        text_encoder = T5EncoderModel.from_pretrained(text_encoder_path, subfolder="text_encoder")
        patchifier = SymmetricPatchifier(patch_size=1)
        tokenizer = T5Tokenizer.from_pretrained(text_encoder_path, subfolder="tokenizer")

        if torch.cuda.is_available():
            transformer = transformer.cuda()
            vae = vae.cuda()
            text_encoder = text_encoder.cuda()

        # The VAE always runs in bfloat16; the transformer and text encoder
        # are cast only when bfloat16 precision is requested.
        vae = vae.to(torch.bfloat16)
        if precision == "bfloat16" and transformer.dtype != torch.bfloat16:
            transformer = transformer.to(torch.bfloat16)
            text_encoder = text_encoder.to(torch.bfloat16)

        # Set spatiotemporal guidance (STG)
        self.skip_block_list = [int(x.strip()) for x in stg_skip_layers.split(",")]
        self.skip_layer_strategy = (
            SkipLayerStrategy.Attention
            if stg_mode.lower() == "stg_a"
            else SkipLayerStrategy.Residual
        )
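        # (Added note) STG runs an extra, perturbed forward pass in which the
        # blocks listed in stg_skip_layers are skipped, on the attention path
        # for "stg_a" and on the residual path otherwise, and guides the
        # sample away from that perturbed prediction by stg_scale.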

        # Use submodels for the pipeline
        submodel_dict = {
            "transformer": transformer,
            "patchifier": patchifier,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "scheduler": scheduler,
            "vae": vae,
        }
        self.pipeline = VaceLTXVideoPipeline(**submodel_dict)
        if torch.cuda.is_available():
            self.pipeline = self.pipeline.to("cuda")

        self.img_proc = VaceImageProcessor(downsample=[8, 32, 32], seq_len=384)
        self.vid_proc = VaceVideoProcessor(downsample=[8, 32, 32],
                                           min_area=512 * 768,
                                           max_area=512 * 768,
                                           min_fps=25,
                                           max_fps=25,
                                           seq_len=4992,
                                           zero_start=True,
                                           keep_last=True)
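        # (Added note) Where the sequence lengths come from, assuming the
        # causal VAE encodes the first frame separately: with downsample
        # [8, 32, 32], a 97-frame 512x768 clip yields (97 - 1) / 8 + 1 = 13
        # latent frames of (512 / 32) x (768 / 32) = 16 x 24 tokens, i.e.
        # 13 * 16 * 24 = 4992 tokens (seq_len for video); a single reference
        # image yields 16 * 24 = 384 (seq_len for images).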

    def generate(self, src_video=None, src_mask=None, src_ref_images=None, prompt="", negative_prompt="", seed=42,
                 num_inference_steps=40, num_images_per_prompt=1, context_scale=1.0, guidance_scale=3, stg_scale=1, stg_rescale=0.7,
                 frame_rate=25, image_cond_noise_scale=0.15, decode_timestep=0.05, decode_noise_scale=0.025,
                 output_height=512, output_width=768, num_frames=97):
        # src_video: [c, t, h, w] / norm [-1, 1]
        # src_mask : [c, t, h, w] / norm [0, 1]
        # src_ref_images : [[c, h, w], [c, h, w], ...] / norm [-1, 1]
        # image_size: (H, W)
        if src_ref_images is None:  # avoid the mutable-default-argument pitfall
            src_ref_images = []
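
        # Three input modes are handled below: (video, mask) for masked
        # editing, video only (an all-ones mask makes every pixel editable),
        # and neither (pure text-to-video from a zero video, all-ones mask).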
        if (src_video is not None and src_video != "") and (src_mask is not None and src_mask != ""):
            src_video, src_mask, frame_ids, image_size, frame_rate = self.vid_proc.load_video_batch(src_video, src_mask)
            if torch.all(src_mask > 0):
                # A fully positive mask means the whole frame is editable.
                src_mask = torch.ones_like(src_video[:1, :, :, :])
            else:
                # Keep a single mask channel and rescale it from [-1, 1] to [0, 1].
                src_mask = src_mask[:1, :, :, :]
                src_mask = torch.clamp((src_mask + 1) / 2, min=0, max=1)
        elif (src_video is not None and src_video != "") and (src_mask is None or src_mask == ""):
            src_video, frame_ids, image_size, frame_rate = self.vid_proc.load_video_batch(src_video)
            src_mask = torch.ones_like(src_video[:1, :, :, :])
        else:
            output_height, output_width, frame_rate, num_frames = int(output_height), int(output_width), int(frame_rate), int(num_frames)
            frame_ids = list(range(num_frames))
            image_size = (output_height, output_width)
            src_video = torch.zeros((3, num_frames, output_height, output_width))
            src_mask = torch.ones((1, num_frames, output_height, output_width))

        # Load any non-empty reference images through the image processor.
        src_ref_images_prelist = src_ref_images
        src_ref_images = []
        for ref_image in src_ref_images_prelist:
            if ref_image is not None and ref_image != "":
                src_ref_images.append(self.img_proc.load_image(ref_image)[0])

        # Prepare input for the pipeline
        num_frames = len(frame_ids)
        sample = {
            "src_video": [src_video],
            "src_mask": [src_mask],
            "src_ref_images": [src_ref_images],
            "prompt": [prompt],
            "prompt_attention_mask": None,
            "negative_prompt": [negative_prompt],
            "negative_prompt_attention_mask": None,
        }

        generator = torch.Generator(
            device="cuda" if torch.cuda.is_available() else "cpu"
        ).manual_seed(seed)

        output = self.pipeline(
            num_inference_steps=num_inference_steps,
            num_images_per_prompt=num_images_per_prompt,
            context_scale=context_scale,
            guidance_scale=guidance_scale,
            skip_layer_strategy=self.skip_layer_strategy,
            skip_block_list=self.skip_block_list,
            stg_scale=stg_scale,
            do_rescaling=stg_rescale != 1,
            rescaling_scale=stg_rescale,
            generator=generator,
            output_type="pt",
            callback_on_step_end=None,
            height=image_size[0],
            width=image_size[1],
            num_frames=num_frames,
            frame_rate=frame_rate,
            **sample,
            is_video=True,
            vae_per_channel_normalize=True,
            conditioning_method=ConditioningMethod.UNCONDITIONAL,
            image_cond_noise_scale=image_cond_noise_scale,
            decode_timestep=decode_timestep,
            decode_noise_scale=decode_noise_scale,
            mixed_precision=(self.precision == "mixed_precision"),
            offload_to_cpu=self.offload_to_cpu,
        )

        # Cast bfloat16 output back to float32 so downstream consumers can
        # handle it (e.g. conversion to numpy for saving).
        gen_video = output.images[0]
        gen_video = gen_video.to(torch.float32) if gen_video.dtype == torch.bfloat16 else gen_video
        info = output.info

        ret_data = {
            "out_video": gen_video,
            "src_video": src_video,
            "src_mask": src_mask,
            "src_ref_images": src_ref_images,
            "info": info,
        }
        return ret_data
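

# Minimal usage sketch (added for illustration; the paths are placeholders and
# the checkpoint layout is an assumption, so adapt before running).
if __name__ == "__main__":
    model = LTXVace(
        ckpt_path="path/to/ltx-video-checkpoint",       # placeholder
        text_encoder_path="path/to/text-encoder-root",  # placeholder; must
        # contain "text_encoder" and "tokenizer" subfolders, as used above
    )
    result = model.generate(
        prompt="a red fox trotting through fresh snow",
        seed=42,
        num_frames=97,
        output_height=512,
        output_width=768,
    )
    print(result["out_video"].shape)  # generated video tensor
    print(result["info"])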