rahul7star committed
Commit ddacd23 · verified · 1 Parent(s): 711b244

Update generate.py

Files changed (1):
  1. generate.py (+24 −254)
generate.py CHANGED
@@ -1,224 +1,21 @@
-# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
-import argparse
-from datetime import datetime
-import logging
-import os
-import sys
-import warnings
-
-warnings.filterwarnings('ignore')
-
-import torch, random
-import torch.distributed as dist
-from PIL import Image
-
-import wan
-from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES
-from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
-from wan.utils.utils import cache_video, cache_image, str2bool
-
-EXAMPLE_PROMPT = {
-    "t2v-1.3B": {
-        "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
-    },
-    "t2v-14B": {
-        "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
-    },
-    "t2i-14B": {
-        "prompt": "一个朴素端庄的美人",
-    },
-    "i2v-14B": {
-        "prompt":
-            "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
-        "image":
-            "examples/i2v_input.JPG",
-    },
-}
-
-
-def _validate_args(args):
-    # Basic check
-    assert args.ckpt_dir is not None, "Please specify the checkpoint directory."
-    assert args.task in WAN_CONFIGS, f"Unsupport task: {args.task}"
-    assert args.task in EXAMPLE_PROMPT, f"Unsupport task: {args.task}"
-
-    # The default sampling steps are 40 for image-to-video tasks and 50 for text-to-video tasks.
-    if args.sample_steps is None:
-        args.sample_steps = 40 if "i2v" in args.task else 50
-
-    if args.sample_shift is None:
-        args.sample_shift = 5.0
-        if "i2v" in args.task and args.size in ["832*480", "480*832"]:
-            args.sample_shift = 3.0
-
-    # The default number of frames are 1 for text-to-image tasks and 81 for other tasks.
-    if args.frame_num is None:
-        args.frame_num = 1 if "t2i" in args.task else 81
-
-    # T2I frame_num check
-    if "t2i" in args.task:
-        assert args.frame_num == 1, f"Unsupport frame_num {args.frame_num} for task {args.task}"
-
-    args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(
-        0, sys.maxsize)
-    # Size check
-    assert args.size in SUPPORTED_SIZES[
-        args.
-        task], f"Unsupport size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}"
-
-
-def _parse_args():
-    parser = argparse.ArgumentParser(
-        description="Generate a image or video from a text prompt or image using Wan"
-    )
-    parser.add_argument(
-        "--task",
-        type=str,
-        default="t2v-14B",
-        choices=list(WAN_CONFIGS.keys()),
-        help="The task to run.")
-    parser.add_argument(
-        "--size",
-        type=str,
-        default="1280*720",
-        choices=list(SIZE_CONFIGS.keys()),
-        help="The area (width*height) of the generated video. For the I2V task, the aspect ratio of the output video will follow that of the input image."
-    )
-    parser.add_argument(
-        "--frame_num",
-        type=int,
-        default=None,
-        help="How many frames to sample from a image or video. The number should be 4n+1"
-    )
-    parser.add_argument(
-        "--ckpt_dir",
-        type=str,
-        default=None,
-        help="The path to the checkpoint directory.")
-    parser.add_argument(
-        "--offload_model",
-        type=str2bool,
-        default=None,
-        help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage."
-    )
-    parser.add_argument(
-        "--ulysses_size",
-        type=int,
-        default=1,
-        help="The size of the ulysses parallelism in DiT.")
-    parser.add_argument(
-        "--ring_size",
-        type=int,
-        default=1,
-        help="The size of the ring attention parallelism in DiT.")
-    parser.add_argument(
-        "--t5_fsdp",
-        action="store_true",
-        default=False,
-        help="Whether to use FSDP for T5.")
-    parser.add_argument(
-        "--t5_cpu",
-        action="store_true",
-        default=False,
-        help="Whether to place T5 model on CPU.")
-    parser.add_argument(
-        "--dit_fsdp",
-        action="store_true",
-        default=False,
-        help="Whether to use FSDP for DiT.")
-    parser.add_argument(
-        "--save_file",
-        type=str,
-        default=None,
-        help="The file to save the generated image or video to.")
-    parser.add_argument(
-        "--prompt",
-        type=str,
-        default=None,
-        help="The prompt to generate the image or video from.")
-    parser.add_argument(
-        "--use_prompt_extend",
-        action="store_true",
-        default=False,
-        help="Whether to use prompt extend.")
-    parser.add_argument(
-        "--prompt_extend_method",
-        type=str,
-        default="local_qwen",
-        choices=["dashscope", "local_qwen"],
-        help="The prompt extend method to use.")
-    parser.add_argument(
-        "--prompt_extend_model",
-        type=str,
-        default=None,
-        help="The prompt extend model to use.")
-    parser.add_argument(
-        "--prompt_extend_target_lang",
-        type=str,
-        default="ch",
-        choices=["ch", "en"],
-        help="The target language of prompt extend.")
-    parser.add_argument(
-        "--base_seed",
-        type=int,
-        default=-1,
-        help="The seed to use for generating the image or video.")
-    parser.add_argument(
-        "--image",
-        type=str,
-        default=None,
-        help="The image to generate the video from.")
-    parser.add_argument(
-        "--sample_solver",
-        type=str,
-        default='unipc',
-        choices=['unipc', 'dpm++'],
-        help="The solver used to sample.")
-    parser.add_argument(
-        "--sample_steps", type=int, default=None, help="The sampling steps.")
-    parser.add_argument(
-        "--sample_shift",
-        type=float,
-        default=None,
-        help="Sampling shift factor for flow matching schedulers.")
-    parser.add_argument(
-        "--sample_guide_scale",
-        type=float,
-        default=5.0,
-        help="Classifier free guidance scale.")
-
-    args = parser.parse_args()
-
-    _validate_args(args)
-
-    return args
-
-
-def _init_logging(rank):
-    # logging
-    if rank == 0:
-        # set format
-        logging.basicConfig(
-            level=logging.INFO,
-            format="[%(asctime)s] %(levelname)s: %(message)s",
-            handlers=[logging.StreamHandler(stream=sys.stdout)])
-    else:
-        logging.basicConfig(level=logging.ERROR)
-
-
 def generate(args):
     rank = int(os.getenv("RANK", 0))
     world_size = int(os.getenv("WORLD_SIZE", 1))
     local_rank = int(os.getenv("LOCAL_RANK", 0))
-    device = local_rank
+
+    # Set device: use CPU if specified, else use GPU based on rank
+    if args.t5_cpu or args.dit_fsdp:  # Use CPU if specified in arguments
+        device = torch.device("cpu")
+        logging.info("Using CPU for model inference.")
+    else:
+        device = local_rank
+        torch.cuda.set_device(local_rank)  # Ensure proper device assignment if using GPU
+        logging.info(f"Using GPU: {device}")
+
     _init_logging(rank)
 
-    if args.offload_model is None:
-        args.offload_model = False if world_size > 1 else True
-        logging.info(
-            f"offload_model is not specified, set to {args.offload_model}.")
+    # Distributed setup
     if world_size > 1:
-        torch.cuda.set_device(local_rank)
         dist.init_process_group(
             backend="nccl",
             init_method="env://",
@@ -228,9 +25,6 @@ def generate(args):
         assert not (
             args.t5_fsdp or args.dit_fsdp
         ), f"t5_fsdp and dit_fsdp are not supported in non-distributed environments."
-        assert not (
-            args.ulysses_size > 1 or args.ring_size > 1
-        ), f"context parallel are not supported in non-distributed environments."
 
     if args.ulysses_size > 1 or args.ring_size > 1:
         assert args.ulysses_size * args.ring_size == world_size, f"The number of ulysses_size and ring_size should be equal to the world size."
@@ -245,6 +39,7 @@ def generate(args):
             ulysses_degree=args.ulysses_size,
         )
 
+    # Handle prompt extension if needed
     if args.use_prompt_extend:
         if args.prompt_extend_method == "dashscope":
             prompt_expander = DashScopePromptExpander(
@@ -255,25 +50,24 @@ def generate(args):
                 is_vl="i2v" in args.task,
                 device=rank)
         else:
-            raise NotImplementedError(
-                f"Unsupport prompt_extend_method: {args.prompt_extend_method}")
+            raise NotImplementedError(f"Unsupported prompt_extend_method: {args.prompt_extend_method}")
 
     cfg = WAN_CONFIGS[args.task]
-    if args.ulysses_size > 1:
-        assert cfg.num_heads % args.ulysses_size == 0, f"`num_heads` must be divisible by `ulysses_size`."
-
     logging.info(f"Generation job args: {args}")
     logging.info(f"Generation model config: {cfg}")
 
+    # Broadcast base seed across distributed workers
     if dist.is_initialized():
         base_seed = [args.base_seed] if rank == 0 else [None]
         dist.broadcast_object_list(base_seed, src=0)
         args.base_seed = base_seed[0]
 
+    # Set prompt and task details
     if "t2v" in args.task or "t2i" in args.task:
         if args.prompt is None:
             args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
         logging.info(f"Input prompt: {args.prompt}")
+
         if args.use_prompt_extend:
             logging.info("Extending prompt ...")
             if rank == 0:
@@ -282,13 +76,10 @@ def generate(args):
                     tar_lang=args.prompt_extend_target_lang,
                     seed=args.base_seed)
                 if prompt_output.status == False:
-                    logging.info(
-                        f"Extending prompt failed: {prompt_output.message}")
-                    logging.info("Falling back to original prompt.")
+                    logging.info(f"Prompt extension failed: {prompt_output.message}")
                     input_prompt = args.prompt
                 else:
                     input_prompt = prompt_output.prompt
-                input_prompt = [input_prompt]
             else:
                 input_prompt = [None]
             if dist.is_initialized():
@@ -308,8 +99,7 @@ def generate(args):
             t5_cpu=args.t5_cpu,
         )
 
-        logging.info(
-            f"Generating {'image' if 't2i' in args.task else 'video'} ...")
+        logging.info(f"Generating {'image' if 't2i' in args.task else 'video'} ...")
         video = wan_t2v.generate(
             args.prompt,
             size=SIZE_CONFIGS[args.size],
@@ -321,7 +111,7 @@ def generate(args):
             seed=args.base_seed,
             offload_model=args.offload_model)
 
-    else:
+    else:  # image-to-video
         if args.prompt is None:
             args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
         if args.image is None:
@@ -339,13 +129,10 @@ def generate(args):
                     image=img,
                     seed=args.base_seed)
                 if prompt_output.status == False:
-                    logging.info(
-                        f"Extending prompt failed: {prompt_output.message}")
-                    logging.info("Falling back to original prompt.")
+                    logging.info(f"Prompt extension failed: {prompt_output.message}")
                     input_prompt = args.prompt
                 else:
                     input_prompt = prompt_output.prompt
-                input_prompt = [input_prompt]
             else:
                 input_prompt = [None]
             if dist.is_initialized():
@@ -378,34 +165,17 @@ def generate(args):
             seed=args.base_seed,
             offload_model=args.offload_model)
 
+    # Save the output video or image
     if rank == 0:
        if args.save_file is None:
            formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
-            formatted_prompt = args.prompt.replace(" ", "_").replace("/",
-                                                                     "_")[:50]
+            formatted_prompt = args.prompt.replace(" ", "_").replace("/", "_")[:50]
            suffix = '.png' if "t2i" in args.task else '.mp4'
            args.save_file = f"{args.task}_{args.size}_{args.ulysses_size}_{args.ring_size}_{formatted_prompt}_{formatted_time}" + suffix
 
        if "t2i" in args.task:
            logging.info(f"Saving generated image to {args.save_file}")
-            cache_image(
-                tensor=video.squeeze(1)[None],
-                save_file=args.save_file,
-                nrow=1,
-                normalize=True,
-                value_range=(-1, 1))
+            cache_image(tensor=video.squeeze(1)[None], save_file=args.save_file, nrow=1, normalize=True)
        else:
            logging.info(f"Saving generated video to {args.save_file}")
-            cache_video(
-                tensor=video[None],
-                save_file=args.save_file,
-                fps=cfg.sample_fps,
-                nrow=1,
-                normalize=True,
-                value_range=(-1, 1))
-    logging.info("Finished.")
-
-
-if __name__ == "__main__":
-    args = _parse_args()
-    generate(args)
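Note: this commit removes the argparse CLI (`_parse_args`, `_validate_args`) and the `if __name__ == "__main__"` entry point, so `generate()` must now be invoked programmatically with a fully populated args object. Below is a minimal sketch of such a caller (hypothetical, not part of the commit): the attribute names mirror the old argparse flags that `generate()` still reads, the values are illustrative, and because `_validate_args` no longer runs, defaults such as sample_steps and base_seed have to be set by hand.

# Hypothetical driver for the stripped-down generate.py; every attribute
# below corresponds to a former CLI flag that generate() still accesses.
from types import SimpleNamespace

from generate import generate  # assumes generate.py is importable

args = SimpleNamespace(
    task="t2v-14B",               # must be a key of WAN_CONFIGS / EXAMPLE_PROMPT
    size="1280*720",              # must be a key of SIZE_CONFIGS
    frame_num=81,                 # 4n+1 frames; 1 for t2i tasks
    ckpt_dir="./Wan2.1-T2V-14B",  # illustrative checkpoint path
    offload_model=True,           # no longer auto-defaulted in generate(); set explicitly
    ulysses_size=1,
    ring_size=1,
    t5_fsdp=False,
    t5_cpu=False,                 # True (or dit_fsdp=True) now routes inference to CPU
    dit_fsdp=False,
    save_file=None,               # None -> auto-named .png/.mp4 output
    prompt=None,                  # None -> EXAMPLE_PROMPT for the task
    use_prompt_extend=False,
    prompt_extend_method="local_qwen",
    prompt_extend_model=None,
    prompt_extend_target_lang="ch",
    base_seed=42,                 # _validate_args is gone, so seed it yourself
    image=None,                   # input image path, required for i2v tasks
    sample_solver="unipc",
    sample_steps=50,              # old defaults: 40 for i2v, 50 otherwise
    sample_shift=5.0,
    sample_guide_scale=5.0,
)

generate(args)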