import logging
import os
from datetime import datetime

import torch
import torch.distributed as dist
from PIL import Image
from tqdm import tqdm


def _extend_and_sync_prompt(args, prompt_expander, rank, image=None):
    """Extend ``args.prompt`` on rank 0 and broadcast the result to all ranks.

    Mutates ``args.prompt`` in place.  ``image`` (a PIL image) is forwarded
    to the expander for image-to-video tasks; it is ignored for t2v/t2i.
    """
    logging.info("Extending prompt ...")
    if rank == 0:
        extend_kwargs = {
            "tar_lang": args.prompt_extend_target_lang,
            "seed": args.base_seed,
        }
        if image is not None:
            extend_kwargs["image"] = image
        prompt_output = prompt_expander(args.prompt, **extend_kwargs)
        if not prompt_output.status:
            logging.info(f"Prompt extension failed: {prompt_output.message}")
            extended = args.prompt
        else:
            extended = prompt_output.prompt
        # FIX: broadcast_object_list needs a *list* container on every rank.
        # The old code left the bare string on rank 0, so in single-process
        # runs ``args.prompt = input_prompt[0]`` truncated the prompt to its
        # first character (and broadcast would have failed when distributed).
        input_prompt = [extended]
    else:
        input_prompt = [None]
    if dist.is_initialized():
        dist.broadcast_object_list(input_prompt, src=0)
    args.prompt = input_prompt[0]
    logging.info(f"Extended prompt: {args.prompt}")


def generate(args):
    """Run text-to-video/image (t2v/t2i) or image-to-video (i2v) generation.

    Reads the distributed rank/world-size from the environment, optionally
    initializes NCCL + xfuser sequence parallelism, optionally extends the
    prompt, builds the WanT2V or WanI2V pipeline for ``args.task``, generates
    the output, and (on rank 0) saves it to ``args.save_file``.

    Args:
        args: parsed CLI namespace.  Fields read include: task, ckpt_dir,
            prompt, image, size, base_seed, save_file, sample_* options,
            t5_fsdp / dit_fsdp / t5_cpu / offload_model, ulysses_size,
            ring_size, and the prompt-extension options.

    Raises:
        AssertionError: on inconsistent distributed/FSDP configuration.
        NotImplementedError: for an unknown ``prompt_extend_method``.
        Any exception raised by the underlying pipeline (logged, re-raised).
    """
    rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    local_rank = int(os.getenv("LOCAL_RANK", 0))

    _init_logging(rank)

    # FIX: the previous logic forced the whole pipeline onto CPU whenever
    # t5_cpu *or* dit_fsdp was set.  t5_cpu only offloads the T5 text
    # encoder (handled by the pipelines' own ``t5_cpu=`` flag below), and
    # dit_fsdp shards the DiT across GPUs and requires a distributed GPU
    # setup — neither implies CPU inference.  The DiT runs on the local GPU.
    device = local_rank
    torch.cuda.set_device(local_rank)
    logging.info(f"Using GPU: {device}")

    # Distributed setup.
    if world_size > 1:
        dist.init_process_group(
            backend="nccl",
            init_method="env://",
            rank=rank,
            world_size=world_size)
    else:
        assert not (args.t5_fsdp or args.dit_fsdp), \
            "t5_fsdp and dit_fsdp are not supported in non-distributed environments."

    # Optional xfuser sequence parallelism (ulysses/ring attention).
    if args.ulysses_size > 1 or args.ring_size > 1:
        assert args.ulysses_size * args.ring_size == world_size, \
            "The number of ulysses_size and ring_size should be equal to the world size."
        from xfuser.core.distributed import (initialize_model_parallel,
                                             init_distributed_environment)
        init_distributed_environment(
            rank=dist.get_rank(), world_size=dist.get_world_size())
        initialize_model_parallel(
            sequence_parallel_degree=dist.get_world_size(),
            ring_degree=args.ring_size,
            ulysses_degree=args.ulysses_size,
        )

    # Build the prompt expander up front if prompt extension is requested.
    prompt_expander = None
    if args.use_prompt_extend:
        if args.prompt_extend_method == "dashscope":
            prompt_expander = DashScopePromptExpander(
                model_name=args.prompt_extend_model, is_vl="i2v" in args.task)
        elif args.prompt_extend_method == "local_qwen":
            prompt_expander = QwenPromptExpander(
                model_name=args.prompt_extend_model,
                is_vl="i2v" in args.task,
                device=rank)
        else:
            raise NotImplementedError(
                f"Unsupported prompt_extend_method: {args.prompt_extend_method}")

    cfg = WAN_CONFIGS[args.task]
    logging.info(f"Generation job args: {args}")
    logging.info(f"Generation model config: {cfg}")

    # Make sure every worker uses rank 0's seed.
    if dist.is_initialized():
        base_seed = [args.base_seed] if rank == 0 else [None]
        dist.broadcast_object_list(base_seed, src=0)
        args.base_seed = base_seed[0]

    # frame_num was hard-coded to 33; honour args.frame_num when the CLI
    # provides it, keeping 33 as the backward-compatible default.
    frame_num = getattr(args, "frame_num", 33)
    use_usp = args.ulysses_size > 1 or args.ring_size > 1

    if "t2v" in args.task or "t2i" in args.task:
        if args.prompt is None:
            args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
        logging.info(f"Input prompt: {args.prompt}")
        if args.use_prompt_extend:
            _extend_and_sync_prompt(args, prompt_expander, rank)

        logging.info("Creating WanT2V pipeline.")
        wan_t2v = wan.WanT2V(
            config=cfg,
            checkpoint_dir=args.ckpt_dir,
            device_id=device,
            rank=rank,
            t5_fsdp=args.t5_fsdp,
            dit_fsdp=args.dit_fsdp,
            use_usp=use_usp,
            t5_cpu=args.t5_cpu,
        )

        logging.info(
            f"Generating {'image' if 't2i' in args.task else 'video'} ...")
        try:
            video = wan_t2v.generate(
                args.prompt,
                size=SIZE_CONFIGS[args.size],
                frame_num=frame_num,
                shift=args.sample_shift,
                sample_solver=args.sample_solver,
                sampling_steps=args.sample_steps,
                guide_scale=args.sample_guide_scale,
                seed=args.base_seed,
                offload_model=args.offload_model)
        except Exception as e:
            logging.error(f"Error during video generation: {e}")
            raise
    else:  # image-to-video
        if args.prompt is None:
            args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
        if args.image is None:
            args.image = EXAMPLE_PROMPT[args.task]["image"]
        logging.info(f"Input prompt: {args.prompt}")
        logging.info(f"Input image: {args.image}")

        img = Image.open(args.image).convert("RGB")
        if args.use_prompt_extend:
            _extend_and_sync_prompt(args, prompt_expander, rank, image=img)

        logging.info("Creating WanI2V pipeline.")
        wan_i2v = wan.WanI2V(
            config=cfg,
            checkpoint_dir=args.ckpt_dir,
            device_id=device,
            rank=rank,
            t5_fsdp=args.t5_fsdp,
            dit_fsdp=args.dit_fsdp,
            use_usp=use_usp,
            t5_cpu=args.t5_cpu,
        )

        logging.info("Generating video ...")
        try:
            video = wan_i2v.generate(
                args.prompt,
                img,
                max_area=MAX_AREA_CONFIGS[args.size],
                frame_num=frame_num,
                shift=args.sample_shift,
                sample_solver=args.sample_solver,
                sampling_steps=args.sample_steps,
                guide_scale=args.sample_guide_scale,
                seed=args.base_seed,
                offload_model=args.offload_model)
        except Exception as e:
            logging.error(f"Error during video generation: {e}")
            raise

    # Save the output video or image (rank 0 only).
    if rank == 0:
        if args.save_file is None:
            # FIX: pick a task-appropriate default; the old code saved t2i
            # images into a ".mp4"-named file.
            args.save_file = ("generated_image.png" if "t2i" in args.task
                              else "generated_video.mp4")
        try:
            if "t2i" in args.task:
                logging.info(f"Saving generated image to {args.save_file}")
                cache_image(
                    tensor=video.squeeze(1)[None],
                    save_file=args.save_file,
                    nrow=1,
                    normalize=True)
            else:
                logging.info(f"Saving generated video to {args.save_file}")
                cache_video(tensor=video, save_file=args.save_file)
        except Exception as e:
            logging.error(f"Error saving output: {e}")
            raise