"""Wan2.1 generation entry point (generate.py).

Runs text-to-video / text-to-image / image-to-video jobs, with optional
distributed (NCCL + xfuser sequence-parallel) execution and optional
prompt extension via DashScope or a local Qwen model.
"""
import logging
import os
import torch
import torch.distributed as dist
from PIL import Image
from datetime import datetime
from tqdm import tqdm
def _extend_prompt(prompt_expander, args, rank, image=None):
    """Expand ``args.prompt`` on rank 0 and broadcast it to every rank.

    Returns the extended prompt, or the original prompt when the expander
    reports failure. ``image`` is forwarded to the expander for i2v tasks.
    """
    if rank == 0:
        expand_kwargs = {
            "tar_lang": args.prompt_extend_target_lang,
            "seed": args.base_seed,
        }
        if image is not None:
            expand_kwargs["image"] = image
        prompt_output = prompt_expander(args.prompt, **expand_kwargs)
        if not prompt_output.status:
            logging.info(f"Prompt extension failed: {prompt_output.message}")
            extended = args.prompt
        else:
            extended = prompt_output.prompt
        # BUG FIX: broadcast_object_list requires a list on *every* rank.
        # The previous code left a bare string here, so the broadcast was
        # malformed in distributed runs and input_prompt[0] below truncated
        # the prompt to its first character.
        input_prompt = [extended]
    else:
        input_prompt = [None]
    if dist.is_initialized():
        dist.broadcast_object_list(input_prompt, src=0)
    return input_prompt[0]


def generate(args):
    """Run a Wan2.1 generation job (t2v / t2i / i2v) described by ``args``.

    Handles (optionally distributed) device setup, optional prompt
    extension, pipeline construction, sampling, and saving of the result
    (rank 0 only). Re-raises any pipeline/saving exception after logging it.
    """
    rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    local_rank = int(os.getenv("LOCAL_RANK", 0))

    # NOTE(review): routing the whole pipeline to CPU when dit_fsdp is set
    # looks suspicious (FSDP normally implies multi-GPU); preserved as-is —
    # confirm against the launch scripts.
    if args.t5_cpu or args.dit_fsdp:
        device = torch.device("cpu")
        print("Using CPU for model inference.")
    else:
        device = local_rank
        torch.cuda.set_device(local_rank)  # pin this process to its GPU
        print(f"Using GPU: {device}")

    _init_logging(rank)

    # Distributed setup: NCCL process group when launched with >1 ranks.
    if world_size > 1:
        dist.init_process_group(
            backend="nccl",
            init_method="env://",
            rank=rank,
            world_size=world_size)
    else:
        assert not (
            args.t5_fsdp or args.dit_fsdp
        ), "t5_fsdp and dit_fsdp are not supported in non-distributed environments."

    # Sequence-parallel (xfuser) initialisation.
    if args.ulysses_size > 1 or args.ring_size > 1:
        assert args.ulysses_size * args.ring_size == world_size, \
            "The product of ulysses_size and ring_size must equal the world size."
        from xfuser.core.distributed import (initialize_model_parallel,
                                             init_distributed_environment)
        init_distributed_environment(
            rank=dist.get_rank(), world_size=dist.get_world_size())
        initialize_model_parallel(
            sequence_parallel_degree=dist.get_world_size(),
            ring_degree=args.ring_size,
            ulysses_degree=args.ulysses_size,
        )

    # Build the prompt expander once, if prompt extension is requested.
    prompt_expander = None
    if args.use_prompt_extend:
        if args.prompt_extend_method == "dashscope":
            prompt_expander = DashScopePromptExpander(
                model_name=args.prompt_extend_model, is_vl="i2v" in args.task)
        elif args.prompt_extend_method == "local_qwen":
            prompt_expander = QwenPromptExpander(
                model_name=args.prompt_extend_model,
                is_vl="i2v" in args.task,
                device=rank)
        else:
            raise NotImplementedError(
                f"Unsupported prompt_extend_method: {args.prompt_extend_method}")

    cfg = WAN_CONFIGS[args.task]
    print(f"Generation job args: {args}")
    print(f"Generation model config: {cfg}")

    # Keep the sampling seed identical on every rank.
    if dist.is_initialized():
        base_seed = [args.base_seed] if rank == 0 else [None]
        dist.broadcast_object_list(base_seed, src=0)
        args.base_seed = base_seed[0]

    frame_num = 33  # fixed clip length used by this deployment

    if "t2v" in args.task or "t2i" in args.task:
        if args.prompt is None:
            args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
        print(f"Input prompt: {args.prompt}")

        if args.use_prompt_extend:
            logging.info("Extending prompt ...")
            args.prompt = _extend_prompt(prompt_expander, args, rank)
            logging.info(f"Extended prompt: {args.prompt}")

        logging.info("Creating WanT2V pipeline.")
        wan_t2v = wan.WanT2V(
            config=cfg,
            checkpoint_dir=args.ckpt_dir,
            device_id=device,
            rank=rank,
            t5_fsdp=args.t5_fsdp,
            dit_fsdp=args.dit_fsdp,
            use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
            t5_cpu=args.t5_cpu,
        )

        print(f"Generating {'image' if 't2i' in args.task else 'video'} ...")
        try:
            video = wan_t2v.generate(
                args.prompt,
                size=SIZE_CONFIGS[args.size],
                frame_num=frame_num,
                shift=args.sample_shift,
                sample_solver=args.sample_solver,
                sampling_steps=args.sample_steps,
                guide_scale=args.sample_guide_scale,
                seed=args.base_seed,
                offload_model=args.offload_model)
        except Exception as e:
            logging.error(f"Error during video generation: {e}")
            raise
    else:  # image-to-video
        if args.prompt is None:
            args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
        if args.image is None:
            args.image = EXAMPLE_PROMPT[args.task]["image"]
        logging.info(f"Input prompt: {args.prompt}")
        logging.info(f"Input image: {args.image}")

        img = Image.open(args.image).convert("RGB")
        if args.use_prompt_extend:
            logging.info("Extending prompt ...")
            args.prompt = _extend_prompt(prompt_expander, args, rank, image=img)
            logging.info(f"Extended prompt: {args.prompt}")

        logging.info("Creating WanI2V pipeline.")
        wan_i2v = wan.WanI2V(
            config=cfg,
            checkpoint_dir=args.ckpt_dir,
            device_id=device,
            rank=rank,
            t5_fsdp=args.t5_fsdp,
            dit_fsdp=args.dit_fsdp,
            use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
            t5_cpu=args.t5_cpu,
        )

        print("Generating video ...")
        try:
            video = wan_i2v.generate(
                args.prompt,
                img,
                max_area=MAX_AREA_CONFIGS[args.size],
                frame_num=frame_num,
                shift=args.sample_shift,
                sample_solver=args.sample_solver,
                sampling_steps=args.sample_steps,
                guide_scale=args.sample_guide_scale,
                seed=args.base_seed,
                offload_model=args.offload_model)
        except Exception as e:
            logging.error(f"Error during video generation: {e}")
            raise

    # Save the output video or image (rank 0 only).
    if rank == 0:
        if args.save_file is None:
            args.save_file = "generated_video.mp4"
        try:
            if "t2i" in args.task:
                logging.info(f"Saving generated image to {args.save_file}")
                cache_image(
                    tensor=video.squeeze(1)[None],
                    save_file=args.save_file,
                    nrow=1,
                    normalize=True)
            else:
                logging.info(f"Saving generated video to {args.save_file}")
                cache_video(tensor=video, save_file=args.save_file)
        except Exception as e:
            logging.error(f"Error saving output: {e}")
            raise