|
|
|
|
|
|
|
import argparse |
|
import time |
|
from datetime import datetime |
|
import logging |
|
import os |
|
import sys |
|
import warnings |
|
|
|
# Suppress every Python warning for this process. The filter is installed
# before the torch / project imports that follow, so it is already active
# while those modules load.
warnings.filterwarnings('ignore')
|
|
|
import torch, random |
|
import torch.distributed as dist |
|
from PIL import Image |
|
|
|
import wan |
|
from wan.utils.utils import cache_video, cache_image, str2bool |
|
|
|
from models.wan import WanVace |
|
from models.wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES |
|
from annotators.utils import get_annotator |
|
|
|
# Built-in demo inputs, keyed by model name. main() falls back to the entry
# of the selected model when no prompt is supplied.
EXAMPLE_PROMPT = {
    "vace-1.3B": {
        # Comma-separated list of reference image paths.
        "src_ref_images": "./bag.jpg,./heben.png",
        # One prompt string, written as adjacent literals (one per sentence).
        "prompt": (
            "优雅的女士在精品店仔细挑选包包,她身穿一袭黑色修身连衣裙,搭配珍珠项链,展现出成熟女性的魅力。"
            "手中拿着一款复古风格的棕色皮质半月形手提包,正细致地观察其工艺与质地。"
            "店内灯光柔和,木质装潢营造出温馨而高级的氛围。"
            "中景,侧拍捕捉女士挑选瞬间,展现其品味与气质。"
        ),
    },
}
|
|
|
|
|
def validate_args(args):
    """Check the required options and fill in defaults for unset ones.

    ``args`` is modified in place and also returned.

    * ``sample_steps``, ``sample_shift`` and ``frame_num`` are set to 25, 8.0
      and 81 respectively when they are ``None``.
    * A negative ``base_seed`` is replaced by a random integer in
      ``[0, sys.maxsize]``.

    Raises:
        AssertionError: if ``ckpt_dir`` is ``None``, if ``model_name`` is not a
            key of both ``WAN_CONFIGS`` and ``EXAMPLE_PROMPT``, or if ``size``
            is not listed in ``SUPPORTED_SIZES`` for the model.
    """
    assert args.ckpt_dir is not None, "Please specify the checkpoint directory."

    # The model must have both a config and a built-in example entry.
    for registry in (WAN_CONFIGS, EXAMPLE_PROMPT):
        assert args.model_name in registry, f"Unsupport model name: {args.model_name}"

    # Options whose parser default may be None and their fallback values.
    fallbacks = (
        ("sample_steps", 25),
        ("sample_shift", 8.0),
        ("frame_num", 81),
    )
    for option, fallback in fallbacks:
        if getattr(args, option) is None:
            setattr(args, option, fallback)

    # A negative seed requests a freshly drawn one.
    if not args.base_seed >= 0:
        args.base_seed = random.randint(0, sys.maxsize)

    assert args.size in SUPPORTED_SIZES[args.model_name], (
        f"Unsupport size {args.size} for model name {args.model_name}, "
        f"supported sizes are: {', '.join(SUPPORTED_SIZES[args.model_name])}"
    )
    return args
|
|
|
|
|
def get_parser():
    """Build the command-line parser for this generation script.

    Returns:
        argparse.ArgumentParser: a parser with every generation option
        registered; call ``parse_args()`` on it to obtain the options.
    """
    parser = argparse.ArgumentParser(
        description="Generate a image or video from a text prompt or image using Wan"
    )

    # One (flag, add_argument keyword arguments) pair per option, registered
    # below in exactly this order.
    option_specs = (
        ("--model_name", dict(
            type=str,
            default="vace-1.3B",
            choices=list(WAN_CONFIGS.keys()),
            help="The model name to run.",
        )),
        ("--size", dict(
            type=str,
            default="480*832",
            choices=list(SIZE_CONFIGS.keys()),
            help="The area (width*height) of the generated video. For the I2V task, the aspect ratio of the output video will follow that of the input image.",
        )),
        ("--frame_num", dict(
            type=int,
            default=81,
            help="How many frames to sample from a image or video. The number should be 4n+1",
        )),
        ("--ckpt_dir", dict(
            type=str,
            default='models/VACE-Wan2.1-1.3B-Preview',
            help="The path to the checkpoint directory.",
        )),
        ("--offload_model", dict(
            type=str2bool,
            default=None,
            help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage.",
        )),
        ("--ulysses_size", dict(
            type=int,
            default=1,
            help="The size of the ulysses parallelism in DiT.",
        )),
        ("--ring_size", dict(
            type=int,
            default=1,
            help="The size of the ring attention parallelism in DiT.",
        )),
        ("--t5_fsdp", dict(
            action="store_true",
            default=False,
            help="Whether to use FSDP for T5.",
        )),
        ("--t5_cpu", dict(
            action="store_true",
            default=False,
            help="Whether to place T5 model on CPU.",
        )),
        ("--dit_fsdp", dict(
            action="store_true",
            default=False,
            help="Whether to use FSDP for DiT.",
        )),
        ("--save_dir", dict(
            type=str,
            default=None,
            help="The file to save the generated image or video to.",
        )),
        ("--src_video", dict(
            type=str,
            default=None,
            help="The file of the source video. Default None.",
        )),
        ("--src_mask", dict(
            type=str,
            default=None,
            help="The file of the source mask. Default None.",
        )),
        ("--src_ref_images", dict(
            type=str,
            default=None,
            help="The file list of the source reference images. Separated by ','. Default None.",
        )),
        ("--prompt", dict(
            type=str,
            default=None,
            help="The prompt to generate the image or video from.",
        )),
        ("--use_prompt_extend", dict(
            default='plain',
            choices=['plain', 'wan_zh', 'wan_en', 'wan_zh_ds', 'wan_en_ds'],
            help="Whether to use prompt extend.",
        )),
        ("--base_seed", dict(
            type=int,
            default=2025,
            help="The seed to use for generating the image or video.",
        )),
        ("--sample_solver", dict(
            type=str,
            default='unipc',
            choices=['unipc', 'dpm++'],
            help="The solver used to sample.",
        )),
        ("--sample_steps", dict(
            type=int,
            default=None,
            help="The sampling steps.",
        )),
        ("--sample_shift", dict(
            type=float,
            default=None,
            help="Sampling shift factor for flow matching schedulers.",
        )),
        ("--sample_guide_scale", dict(
            type=float,
            default=6.0,
            help="Classifier free guidance scale.",
        )),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    return parser
|
|
|
|
|
def _init_logging(rank):
    """Configure the root logger according to the process rank.

    Rank 0 logs at INFO level to stdout using a timestamped format. Every
    other rank is configured at ERROR level with ``logging.basicConfig``'s
    default handler and format.
    """
    if rank != 0:
        logging.basicConfig(level=logging.ERROR)
        return

    stdout_handler = logging.StreamHandler(stream=sys.stdout)
    logging.basicConfig(
        level=logging.INFO,
        format="[%(asctime)s] %(levelname)s: %(message)s",
        handlers=[stdout_handler],
    )
|
|
|
|
|
def _save_outputs(save_dir, video, src_video, src_mask, src_ref_images, fps):
    """Write the generated video and its prepared inputs under ``save_dir``.

    Args:
        save_dir: Output directory; created if it does not exist.
        video: Generated video tensor.
        src_video: Prepared source videos; only element 0 is saved.
        src_mask: Prepared source masks; only element 0 is saved.
        src_ref_images: Prepared reference images; element 0 is either
            ``None`` or an iterable of image tensors.
        fps: Frame rate used for the saved videos.

    Returns:
        dict: Maps each output name to the path of the file written for it.
    """
    # exist_ok=True avoids the check-then-create race of exists() + makedirs().
    os.makedirs(save_dir, exist_ok=True)

    saved = {}

    def save_video(key, tensor, value_range, label):
        save_file = os.path.join(save_dir, f'{key}.mp4')
        cache_video(
            tensor=tensor[None],
            save_file=save_file,
            fps=fps,
            nrow=1,
            normalize=True,
            value_range=value_range)
        logging.info(f"Saving {label} to {save_file}")
        saved[key] = save_file

    save_video('out_video', video, (-1, 1), 'generated video')
    save_video('src_video', src_video[0], (-1, 1), 'src_video')
    # The mask is normalised from (0, 1), unlike the videos which use (-1, 1).
    save_video('src_mask', src_mask[0], (0, 1), 'src_mask')

    ref_images = src_ref_images[0]
    if ref_images is not None:
        for i, ref_img in enumerate(ref_images):
            save_file = os.path.join(save_dir, f'src_ref_image_{i}.png')
            cache_image(
                # Index 0 along the second axis; presumably the single frame
                # of a (C, 1, H, W) tensor -- TODO confirm against prepare_source.
                tensor=ref_img[:, 0, ...],
                save_file=save_file,
                nrow=1,
                normalize=True,
                value_range=(-1, 1))
            logging.info(f"Saving src_ref_image_{i} to {save_file}")
            saved[f'src_ref_image_{i}'] = save_file

    return saved


def main(args):
    """Run one generation job and save the results.

    Args:
        args: Options as an ``argparse.Namespace``, or a ``dict`` with the same
            keys (it is converted to a namespace).

    Returns:
        dict: Maps output names (``out_video``, ``src_video``, ``src_mask``,
        ``src_ref_image_<i>``) to the saved file paths. The dict is empty on
        every rank other than 0, because only rank 0 writes files.

    Raises:
        AssertionError: from ``validate_args``; when FSDP or context-parallel
            options are used with a world size of 1; when
            ``ulysses_size * ring_size`` differs from the world size; or when
            the model's ``num_heads`` is not divisible by ``ulysses_size``.
    """
    if isinstance(args, dict):
        args = argparse.Namespace(**args)
    args = validate_args(args)

    # Process topology is read from the environment; the defaults describe a
    # single process.
    rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    local_rank = int(os.getenv("LOCAL_RANK", 0))
    device = local_rank
    _init_logging(rank)

    if args.offload_model is None:
        # Offload by default only when running as a single process.
        args.offload_model = not world_size > 1
        logging.info(
            f"offload_model is not specified, set to {args.offload_model}.")

    if world_size > 1:
        torch.cuda.set_device(local_rank)
        dist.init_process_group(
            backend="nccl",
            init_method="env://",
            rank=rank,
            world_size=world_size)
    else:
        assert not (
            args.t5_fsdp or args.dit_fsdp
        ), "t5_fsdp and dit_fsdp are not supported in non-distributed environments."
        assert not (
            args.ulysses_size > 1 or args.ring_size > 1
        ), "context parallel are not supported in non-distributed environments."

    use_usp = args.ulysses_size > 1 or args.ring_size > 1
    if use_usp:
        assert args.ulysses_size * args.ring_size == world_size, "The number of ulysses_size and ring_size should be equal to the world size."
        # Imported lazily so xfuser is only needed for context-parallel runs.
        from xfuser.core.distributed import (initialize_model_parallel,
                                             init_distributed_environment)
        init_distributed_environment(
            rank=dist.get_rank(), world_size=dist.get_world_size())

        initialize_model_parallel(
            sequence_parallel_degree=dist.get_world_size(),
            ring_degree=args.ring_size,
            ulysses_degree=args.ulysses_size,
        )

    extend_prompt = bool(args.use_prompt_extend) and args.use_prompt_extend != 'plain'
    if extend_prompt:
        prompt_expander = get_annotator(config_type='prompt', config_task=args.use_prompt_extend, return_dict=False)

    cfg = WAN_CONFIGS[args.model_name]
    if args.ulysses_size > 1:
        assert cfg.num_heads % args.ulysses_size == 0, "`num_heads` must be divisible by `ulysses_size`."

    logging.info(f"Generation job args: {args}")
    logging.info(f"Generation model config: {cfg}")

    if dist.is_initialized():
        # validate_args may have drawn a random seed independently on each
        # rank; adopt rank 0's value everywhere.
        base_seed = [args.base_seed] if rank == 0 else [None]
        dist.broadcast_object_list(base_seed, src=0)
        args.base_seed = base_seed[0]

    if args.prompt is None:
        # No prompt given: run the built-in example for this model.
        # NOTE: the example's src_video / src_mask / src_ref_images replace
        # whatever the caller passed for those options.
        example = EXAMPLE_PROMPT[args.model_name]
        args.prompt = example["prompt"]
        args.src_video = example.get("src_video", None)
        args.src_mask = example.get("src_mask", None)
        args.src_ref_images = example.get("src_ref_images", None)

    logging.info(f"Input prompt: {args.prompt}")
    if extend_prompt:
        logging.info("Extending prompt ...")
        # Only rank 0 runs the expander; the result is broadcast to the others.
        if rank == 0:
            prompt = prompt_expander.forward(args.prompt)
            logging.info(f"Prompt extended from '{args.prompt}' to '{prompt}'")
            input_prompt = [prompt]
        else:
            input_prompt = [None]
        if dist.is_initialized():
            dist.broadcast_object_list(input_prompt, src=0)
        args.prompt = input_prompt[0]
        logging.info(f"Extended prompt: {args.prompt}")

    logging.info("Creating WanVace pipeline.")
    wan_vace = WanVace(
        config=cfg,
        checkpoint_dir=args.ckpt_dir,
        device_id=device,
        rank=rank,
        t5_fsdp=args.t5_fsdp,
        dit_fsdp=args.dit_fsdp,
        use_usp=use_usp,
        t5_cpu=args.t5_cpu,
    )

    # prepare_source takes one-element lists for video, mask and reference
    # images; the reference images are given as a comma-separated string.
    ref_image_paths = None if args.src_ref_images is None else args.src_ref_images.split(',')
    src_video, src_mask, src_ref_images = wan_vace.prepare_source(
        [args.src_video],
        [args.src_mask],
        [ref_image_paths],
        args.frame_num, SIZE_CONFIGS[args.size], device)

    logging.info("Generating video...")
    video = wan_vace.generate(
        args.prompt,
        src_video,
        src_mask,
        src_ref_images,
        size=SIZE_CONFIGS[args.size],
        frame_num=args.frame_num,
        shift=args.sample_shift,
        sample_solver=args.sample_solver,
        sampling_steps=args.sample_steps,
        guide_scale=args.sample_guide_scale,
        seed=args.base_seed,
        offload_model=args.offload_model)

    ret_data = {}
    if rank == 0:
        save_dir = args.save_dir
        if save_dir is None:
            # Default to a timestamped directory (local time).
            save_dir = os.path.join('results', 'vace_wan_1.3b', time.strftime('%Y-%m-%d-%H-%M-%S'))
        ret_data = _save_outputs(
            save_dir, video, src_video, src_mask, src_ref_images, cfg.sample_fps)

    logging.info("Finished.")
    return ret_data
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse the command line, then run the generation job.
    cli_parser = get_parser()
    args = cli_parser.parse_args()
    main(args)
|
|