rahul7star commited on
Commit
45fd710
·
verified ·
1 Parent(s): c4fcfaf

Update generate.py

Browse files
Files changed (1) hide show
  1. generate.py +5 -4
generate.py CHANGED
@@ -7,6 +7,7 @@ from datetime import datetime
7
  from tqdm import tqdm
8
 
9
  def generate(args):
 
10
  rank = int(os.getenv("RANK", 0))
11
  world_size = int(os.getenv("WORLD_SIZE", 1))
12
  local_rank = int(os.getenv("LOCAL_RANK", 0))
@@ -14,11 +15,11 @@ def generate(args):
14
  # Set device: use CPU if specified, else use GPU based on rank
15
  if args.t5_cpu or args.dit_fsdp: # Use CPU if specified in arguments
16
  device = torch.device("cpu")
17
- logging.info("Using CPU for model inference.")
18
  else:
19
  device = local_rank
20
  torch.cuda.set_device(local_rank) # Ensure proper device assignment if using GPU
21
- logging.info(f"Using GPU: {device}")
22
 
23
  _init_logging(rank)
24
 
@@ -108,7 +109,7 @@ def generate(args):
108
  t5_cpu=args.t5_cpu,
109
  )
110
 
111
- logging.info(f"Generating {'image' if 't2i' in args.task else 'video'} ...")
112
  try:
113
  video = wan_t2v.generate(
114
  args.prompt,
@@ -165,7 +166,7 @@ def generate(args):
165
  t5_cpu=args.t5_cpu,
166
  )
167
 
168
- logging.info("Generating video ...")
169
  try:
170
  video = wan_i2v.generate(
171
  args.prompt,
 
7
  from tqdm import tqdm
8
 
9
  def generate(args):
10
+ print("call generate")
11
  rank = int(os.getenv("RANK", 0))
12
  world_size = int(os.getenv("WORLD_SIZE", 1))
13
  local_rank = int(os.getenv("LOCAL_RANK", 0))
 
15
  # Set device: use CPU if specified, else use GPU based on rank
16
  if args.t5_cpu or args.dit_fsdp: # Use CPU if specified in arguments
17
  device = torch.device("cpu")
18
+ print("Using CPU for model inference.")
19
  else:
20
  device = local_rank
21
  torch.cuda.set_device(local_rank) # Ensure proper device assignment if using GPU
22
+ print(f"Using GPU: {device}")
23
 
24
  _init_logging(rank)
25
 
 
109
  t5_cpu=args.t5_cpu,
110
  )
111
 
112
+ print(f"Generating {'image' if 't2i' in args.task else 'video'} ...")
113
  try:
114
  video = wan_t2v.generate(
115
  args.prompt,
 
166
  t5_cpu=args.t5_cpu,
167
  )
168
 
169
+ print("Generating video ..6666666666666666")
170
  try:
171
  video = wan_i2v.generate(
172
  args.prompt,