# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
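"""Command-line inference entry point for VACE with the LTX-Video backbone.

Loads the model from --ckpt_path / --text_encoder_path, runs generation with
optional source video, mask, and reference-image conditioning, and saves the
results. Example invocation (a minimal sketch; adjust the paths and prompt to
your local setup):

    python vace/vace_ltx_inference.py \
        --ckpt_path models/VACE-LTX-Video-0.9/ltx-video-2b-v0.9.safetensors \
        --text_encoder_path models/VACE-LTX-Video-0.9 \
        --prompt "your text prompt here"
"""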
import argparse
import os
import random
import time
import torch
import numpy as np
from models.ltx.ltx_vace import LTXVace
from annotators.utils import save_one_video, save_one_image, get_annotator
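# Maximum resolution and frame count supported by the LTX-Video pipeline.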
MAX_HEIGHT = 720
MAX_WIDTH = 1280
MAX_NUM_FRAMES = 257
def get_total_gpu_memory():
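    """Return the total memory of CUDA device 0 in GiB, or None if CUDA is unavailable."""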
if torch.cuda.is_available():
total_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
return total_memory
return None
def seed_everything(seed: int):
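    """Seed Python, NumPy, and torch RNGs for reproducible generation."""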
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
def get_parser():
parser = argparse.ArgumentParser(
description="Load models from separate directories and run the pipeline."
)
# Directories
parser.add_argument(
"--ckpt_path",
type=str,
default='models/VACE-LTX-Video-0.9/ltx-video-2b-v0.9.safetensors',
help="Path to a safetensors file that contains all model parts.",
)
parser.add_argument(
"--text_encoder_path",
type=str,
default='models/VACE-LTX-Video-0.9',
help="Path to a safetensors file that contains all model parts.",
)
parser.add_argument(
"--src_video",
type=str,
default=None,
help="The file of the source video. Default None.")
parser.add_argument(
"--src_mask",
type=str,
default=None,
help="The file of the source mask. Default None.")
parser.add_argument(
"--src_ref_images",
type=str,
default=None,
help="The file list of the source reference images. Separated by ','. Default None.")
parser.add_argument(
"--save_dir",
type=str,
default=None,
help="Path to the folder to save output video, if None will save in results/ directory.",
)
parser.add_argument("--seed", type=int, default="42")
# Pipeline parameters
parser.add_argument(
"--num_inference_steps", type=int, default=40, help="Number of inference steps"
)
parser.add_argument(
"--num_images_per_prompt",
type=int,
default=1,
help="Number of images per prompt",
)
parser.add_argument(
"--context_scale",
type=float,
default=1.0,
help="Context scale for the pipeline",
)
parser.add_argument(
"--guidance_scale",
type=float,
        default=3.0,
help="Guidance scale for the pipeline",
)
parser.add_argument(
"--stg_scale",
type=float,
        default=1.0,
help="Spatiotemporal guidance scale for the pipeline. 0 to disable STG.",
)
parser.add_argument(
"--stg_rescale",
type=float,
default=0.7,
help="Spatiotemporal guidance rescaling scale for the pipeline. 1 to disable rescale.",
)
parser.add_argument(
"--stg_mode",
type=str,
default="stg_a",
help="Spatiotemporal guidance mode for the pipeline. Can be either stg_a or stg_r.",
)
parser.add_argument(
"--stg_skip_layers",
type=str,
default="19",
help="Attention layers to skip for spatiotemporal guidance. Comma separated list of integers.",
)
parser.add_argument(
"--image_cond_noise_scale",
type=float,
default=0.15,
help="Amount of noise to add to the conditioned image",
)
parser.add_argument(
"--height",
type=int,
default=512,
help="The height of the output video only if src_video is empty.",
)
parser.add_argument(
"--width",
type=int,
default=768,
help="The width of the output video only if src_video is empty.",
)
parser.add_argument(
"--num_frames",
type=int,
default=97,
help="The frames of the output video only if src_video is empty.",
)
parser.add_argument(
"--frame_rate", type=int, default=25, help="Frame rate for the output video"
)
parser.add_argument(
"--precision",
choices=["bfloat16", "mixed_precision"],
default="bfloat16",
help="Sets the precision for the transformer and tokenizer. Default is bfloat16. If 'mixed_precision' is enabled, it moves to mixed-precision.",
)
# VAE noise augmentation
parser.add_argument(
"--decode_timestep",
type=float,
default=0.05,
help="Timestep for decoding noise",
)
parser.add_argument(
"--decode_noise_scale",
type=float,
default=0.025,
help="Noise level for decoding noise",
)
# Prompts
parser.add_argument(
"--prompt",
type=str,
required=True,
help="Text prompt to guide generation",
)
parser.add_argument(
"--negative_prompt",
type=str,
default="worst quality, inconsistent motion, blurry, jittery, distorted",
help="Negative prompt for undesired features",
)
parser.add_argument(
"--offload_to_cpu",
action="store_true",
help="Offloading unnecessary computations to CPU.",
)
parser.add_argument(
"--use_prompt_extend",
default='plain',
choices=['plain', 'ltx_en', 'ltx_en_ds'],
help="Whether to use prompt extend."
)
return parser
def main(args):
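    """Run VACE-LTX generation end to end and save the outputs.

    ``args`` may be an ``argparse.Namespace`` or a plain dict whose keys mirror
    the CLI flags. Returns a dict mapping each saved artifact to its file path.
    """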
args = argparse.Namespace(**args) if isinstance(args, dict) else args
print(f"Running generation with arguments: {args}")
    seed_everything(args.seed)
    # Offload only when the user asked for it and the GPU has under 30 GB of memory;
    # get_total_gpu_memory() returns None when CUDA is unavailable.
    total_gpu_memory = get_total_gpu_memory()
    offload_to_cpu = args.offload_to_cpu and total_gpu_memory is not None and total_gpu_memory < 30
    assert os.path.exists(args.ckpt_path), f"Checkpoint not found: {args.ckpt_path}"
    assert os.path.exists(args.text_encoder_path), f"Text encoder path not found: {args.text_encoder_path}"
ltx_vace = LTXVace(ckpt_path=args.ckpt_path,
text_encoder_path=args.text_encoder_path,
precision=args.precision,
stg_skip_layers=args.stg_skip_layers,
stg_mode=args.stg_mode,
offload_to_cpu=offload_to_cpu)
src_ref_images = args.src_ref_images.split(',') if args.src_ref_images is not None else []
    if args.use_prompt_extend != 'plain':
prompt = get_annotator(config_type='prompt', config_task=args.use_prompt_extend, return_dict=False).forward(args.prompt)
print(f"Prompt extended from '{args.prompt}' to '{prompt}'")
else:
prompt = args.prompt
output = ltx_vace.generate(src_video=args.src_video,
src_mask=args.src_mask,
src_ref_images=src_ref_images,
prompt=prompt,
negative_prompt=args.negative_prompt,
seed=args.seed,
num_inference_steps=args.num_inference_steps,
num_images_per_prompt=args.num_images_per_prompt,
context_scale=args.context_scale,
guidance_scale=args.guidance_scale,
stg_scale=args.stg_scale,
stg_rescale=args.stg_rescale,
frame_rate=args.frame_rate,
image_cond_noise_scale=args.image_cond_noise_scale,
decode_timestep=args.decode_timestep,
decode_noise_scale=args.decode_noise_scale,
output_height=args.height,
output_width=args.width,
num_frames=args.num_frames)
if args.save_dir is None:
save_dir = os.path.join('results', 'vace_ltxv', time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time())))
else:
save_dir = args.save_dir
    os.makedirs(save_dir, exist_ok=True)
frame_rate = output['info']['frame_rate']
ret_data = {}
if output['out_video'] is not None:
save_path = os.path.join(save_dir, 'out_video.mp4')
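        # Rescale from [-1, 1] to [0, 255] and reorder [C, F, H, W] -> [F, H, W, C].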
out_video = (torch.clamp(output['out_video'] / 2 + 0.5, min=0.0, max=1.0).permute(1, 2, 3, 0) * 255).cpu().numpy().astype(np.uint8)
save_one_video(save_path, out_video, fps=frame_rate)
print(f"Save out_video to {save_path}")
ret_data['out_video'] = save_path
if output['src_video'] is not None:
save_path = os.path.join(save_dir, 'src_video.mp4')
src_video = (torch.clamp(output['src_video'] / 2 + 0.5, min=0.0, max=1.0).permute(1, 2, 3, 0) * 255).cpu().numpy().astype(np.uint8)
save_one_video(save_path, src_video, fps=frame_rate)
print(f"Save src_video to {save_path}")
ret_data['src_video'] = save_path
if output['src_mask'] is not None:
save_path = os.path.join(save_dir, 'src_mask.mp4')
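        # Masks are already in [0, 1], so no rescale from [-1, 1] is needed.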
src_mask = (torch.clamp(output['src_mask'], min=0.0, max=1.0).permute(1, 2, 3, 0) * 255).cpu().numpy().astype(np.uint8)
save_one_video(save_path, src_mask, fps=frame_rate)
print(f"Save src_mask to {save_path}")
ret_data['src_mask'] = save_path
if output['src_ref_images'] is not None:
for i, ref_img in enumerate(output['src_ref_images']): # [C, F=1, H, W]
save_path = os.path.join(save_dir, f'src_ref_image_{i}.png')
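            # Drop the singleton frame axis and reorder [C, H, W] -> [H, W, C].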
ref_img = (torch.clamp(ref_img.squeeze(1), min=0.0, max=1.0).permute(1, 2, 0) * 255).cpu().numpy().astype(np.uint8)
save_one_image(save_path, ref_img, use_type='pil')
print(f"Save src_ref_image_{i} to {save_path}")
ret_data[f'src_ref_image_{i}'] = save_path
return ret_data
if __name__ == "__main__":
args = get_parser().parse_args()
main(args)