Spaces:
Running
Running
import logging | |
import os | |
import torch | |
import torch.distributed as dist | |
from PIL import Image | |
from datetime import datetime | |
from tqdm import tqdm | |
def _extend_prompt(args, prompt_expander, rank, image=None):
    """Extend ``args.prompt`` on rank 0 and share the result with all ranks.

    Only rank 0 invokes ``prompt_expander``. If the expander reports failure
    (falsy ``status``), the original prompt is kept. When a process group is
    initialized, the chosen prompt is broadcast from rank 0 so every rank ends
    up with the same ``args.prompt``.

    Args:
        args: Parsed CLI namespace. Reads ``prompt``,
            ``prompt_extend_target_lang`` and ``base_seed``; overwrites
            ``prompt``.
        prompt_expander: Expander built in ``generate``; its result is read
            through ``status``, ``prompt`` and ``message`` attributes.
        rank: Global rank of the current process.
        image: Optional PIL image forwarded to the expander (image-to-video
            only); the ``image`` keyword is omitted entirely when ``None``.
    """
    logging.info("Extending prompt ...")
    if rank == 0:
        expand_kwargs = {"tar_lang": args.prompt_extend_target_lang}
        if image is not None:
            expand_kwargs["image"] = image
        expand_kwargs["seed"] = args.base_seed
        prompt_output = prompt_expander(args.prompt, **expand_kwargs)
        if not prompt_output.status:
            logging.info(f"Prompt extension failed: {prompt_output.message}")
            extended_prompt = args.prompt
        else:
            extended_prompt = prompt_output.prompt
        # broadcast_object_list expects a list of the same length on every
        # rank. Wrapping the string is also what makes ``[0]`` below return
        # the whole prompt instead of only its first character.
        input_prompt = [extended_prompt]
    else:
        input_prompt = [None]
    if dist.is_initialized():
        dist.broadcast_object_list(input_prompt, src=0)
    args.prompt = input_prompt[0]
    logging.info(f"Extended prompt: {args.prompt}")


def generate(args):
    """Run one Wan text/image-to-video (or text-to-image) generation job.

    The job is configured from ``args`` plus the ``RANK``, ``WORLD_SIZE`` and
    ``LOCAL_RANK`` environment variables (defaulting to a single process):

    1. pick the device and, when ``WORLD_SIZE > 1``, join an NCCL process
       group;
    2. when ``ulysses_size`` or ``ring_size`` > 1, initialise xfuser sequence
       parallelism;
    3. optionally extend the prompt on rank 0 and broadcast it;
    4. build ``wan.WanT2V`` (tasks containing "t2v" or "t2i") or
       ``wan.WanI2V`` (every other task) and generate;
    5. on rank 0, write the result with ``cache_image`` (t2i) or
       ``cache_video``.

    Args:
        args: Parsed CLI namespace. Fields read include ``task``, ``prompt``,
            ``image``, ``size``, ``ckpt_dir``, ``save_file``, ``base_seed``,
            ``sample_shift``/``sample_solver``/``sample_steps``/
            ``sample_guide_scale``, ``offload_model``, ``t5_cpu``,
            ``t5_fsdp``, ``dit_fsdp``, ``ulysses_size``, ``ring_size``,
            ``use_prompt_extend`` and the ``prompt_extend_*`` options.
            ``prompt``, ``image``, ``base_seed`` and ``save_file`` may be
            overwritten in place.

    Raises:
        AssertionError: FSDP requested without a distributed launch, or
            ``ulysses_size * ring_size`` differs from ``WORLD_SIZE``.
        NotImplementedError: unknown ``prompt_extend_method``.
        Exception: errors from generation or saving are logged and re-raised.

    NOTE(review): ``WAN_CONFIGS``, ``EXAMPLE_PROMPT``, ``SIZE_CONFIGS``,
    ``MAX_AREA_CONFIGS``, ``wan``, ``DashScopePromptExpander``,
    ``QwenPromptExpander``, ``cache_image``, ``cache_video`` and
    ``_init_logging`` are not imported in the visible file header; they must
    be provided at module level elsewhere — confirm.
    """
    print("call generate")
    rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    local_rank = int(os.getenv("LOCAL_RANK", 0))
    use_usp = args.ulysses_size > 1 or args.ring_size > 1
    is_t2i = "t2i" in args.task

    # Device: CPU when requested via t5_cpu / dit_fsdp, otherwise the GPU
    # matching this process's local rank.
    # NOTE(review): the process group below always uses the NCCL backend,
    # which only handles CUDA tensors, yet dit_fsdp (distributed-only) routes
    # to the CPU branch here — confirm this is intended. Presumably the
    # pipelines accept either an int GPU index or a torch.device as
    # ``device_id`` — confirm.
    if args.t5_cpu or args.dit_fsdp:
        device = torch.device("cpu")
        print("Using CPU for model inference.")
    else:
        device = local_rank
        torch.cuda.set_device(local_rank)  # bind this process to its own GPU
        print(f"Using GPU: {device}")
    _init_logging(rank)

    # Distributed setup.
    if world_size > 1:
        dist.init_process_group(
            backend="nccl",
            init_method="env://",
            rank=rank,
            world_size=world_size)
    else:
        assert not (
            args.t5_fsdp or args.dit_fsdp
        ), f"t5_fsdp and dit_fsdp are not supported in non-distributed environments."

    # Sequence (context) parallelism via xfuser. The parallel degrees must
    # cover exactly the launched processes, so a single-process run with
    # either degree > 1 fails this assert.
    if use_usp:
        assert args.ulysses_size * args.ring_size == world_size, f"The number of ulysses_size and ring_size should be equal to the world size."
        from xfuser.core.distributed import (initialize_model_parallel,
                                             init_distributed_environment)
        init_distributed_environment(
            rank=dist.get_rank(), world_size=dist.get_world_size())
        initialize_model_parallel(
            sequence_parallel_degree=dist.get_world_size(),
            ring_degree=args.ring_size,
            ulysses_degree=args.ulysses_size,
        )

    # Build the prompt expander up front; it is applied once the prompt (and,
    # for image-to-video, the input image) is known.
    prompt_expander = None
    if args.use_prompt_extend:
        is_vl = "i2v" in args.task
        if args.prompt_extend_method == "dashscope":
            prompt_expander = DashScopePromptExpander(
                model_name=args.prompt_extend_model, is_vl=is_vl)
        elif args.prompt_extend_method == "local_qwen":
            # local_rank indexes this node's GPUs; the global rank only
            # coincides with it on single-node runs.
            prompt_expander = QwenPromptExpander(
                model_name=args.prompt_extend_model,
                is_vl=is_vl,
                device=local_rank)
        else:
            raise NotImplementedError(
                f"Unsupported prompt_extend_method: {args.prompt_extend_method}")

    cfg = WAN_CONFIGS[args.task]
    print(f"Generation job args: {args}")
    print(f"Generation model config: {cfg}")

    # Every rank must sample with the same seed, so adopt rank 0's.
    if dist.is_initialized():
        base_seed = [args.base_seed] if rank == 0 else [None]
        dist.broadcast_object_list(base_seed, src=0)
        args.base_seed = base_seed[0]

    # An image is a single frame (the t2i save path squeezes the frame axis).
    # Videos use a fixed 33 frames that is hard-coded rather than read from
    # args — NOTE(review): presumably to bound generation time; confirm.
    frame_num = 1 if is_t2i else 33

    if "t2v" in args.task or is_t2i:
        if args.prompt is None:
            args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
        print(f"Input prompt: {args.prompt}")
        if args.use_prompt_extend:
            _extend_prompt(args, prompt_expander, rank)

        logging.info("Creating WanT2V pipeline.")
        wan_t2v = wan.WanT2V(
            config=cfg,
            checkpoint_dir=args.ckpt_dir,
            device_id=device,
            rank=rank,
            t5_fsdp=args.t5_fsdp,
            dit_fsdp=args.dit_fsdp,
            use_usp=use_usp,
            t5_cpu=args.t5_cpu,
        )

        print(f"Generating {'image' if is_t2i else 'video'} ...")
        try:
            video = wan_t2v.generate(
                args.prompt,
                size=SIZE_CONFIGS[args.size],
                frame_num=frame_num,
                shift=args.sample_shift,
                sample_solver=args.sample_solver,
                sampling_steps=args.sample_steps,
                guide_scale=args.sample_guide_scale,
                seed=args.base_seed,
                offload_model=args.offload_model)
        except Exception as e:
            logging.exception(f"Error during video generation: {e}")
            raise
    else:  # image-to-video
        if args.prompt is None:
            args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
        if args.image is None:
            args.image = EXAMPLE_PROMPT[args.task]["image"]
        logging.info(f"Input prompt: {args.prompt}")
        logging.info(f"Input image: {args.image}")
        img = Image.open(args.image).convert("RGB")
        if args.use_prompt_extend:
            _extend_prompt(args, prompt_expander, rank, image=img)

        logging.info("Creating WanI2V pipeline.")
        wan_i2v = wan.WanI2V(
            config=cfg,
            checkpoint_dir=args.ckpt_dir,
            device_id=device,
            rank=rank,
            t5_fsdp=args.t5_fsdp,
            dit_fsdp=args.dit_fsdp,
            use_usp=use_usp,
            t5_cpu=args.t5_cpu,
        )

        print("Generating video ...")
        try:
            video = wan_i2v.generate(
                args.prompt,
                img,
                max_area=MAX_AREA_CONFIGS[args.size],
                frame_num=frame_num,
                shift=args.sample_shift,
                sample_solver=args.sample_solver,
                sampling_steps=args.sample_steps,
                guide_scale=args.sample_guide_scale,
                seed=args.base_seed,
                offload_model=args.offload_model)
        except Exception as e:
            logging.exception(f"Error during video generation: {e}")
            raise

    # Only rank 0 writes the output.
    if rank == 0:
        if args.save_file is None:
            # An image must not default to a video container extension.
            args.save_file = "generated_image.png" if is_t2i else "generated_video.mp4"
        try:
            if is_t2i:
                logging.info(f"Saving generated image to {args.save_file}")
                # Drop the single-frame axis, then add a leading batch axis.
                cache_image(tensor=video.squeeze(1)[None], save_file=args.save_file, nrow=1, normalize=True)
            else:
                logging.info(f"Saving generated video to {args.save_file}")
                cache_video(tensor=video, save_file=args.save_file)
        except Exception as e:
            logging.exception(f"Error saving output: {e}")
            raise