Spaces:
Running
Running
Update generate.py
Browse files- generate.py +5 -4
generate.py
CHANGED
@@ -7,6 +7,7 @@ from datetime import datetime
|
|
7 |
from tqdm import tqdm
|
8 |
|
9 |
def generate(args):
|
|
|
10 |
rank = int(os.getenv("RANK", 0))
|
11 |
world_size = int(os.getenv("WORLD_SIZE", 1))
|
12 |
local_rank = int(os.getenv("LOCAL_RANK", 0))
|
@@ -14,11 +15,11 @@ def generate(args):
|
|
14 |
# Set device: use CPU if specified, else use GPU based on rank
|
15 |
if args.t5_cpu or args.dit_fsdp: # Use CPU if specified in arguments
|
16 |
device = torch.device("cpu")
|
17 |
-
|
18 |
else:
|
19 |
device = local_rank
|
20 |
torch.cuda.set_device(local_rank) # Ensure proper device assignment if using GPU
|
21 |
-
|
22 |
|
23 |
_init_logging(rank)
|
24 |
|
@@ -108,7 +109,7 @@ def generate(args):
|
|
108 |
t5_cpu=args.t5_cpu,
|
109 |
)
|
110 |
|
111 |
-
|
112 |
try:
|
113 |
video = wan_t2v.generate(
|
114 |
args.prompt,
|
@@ -165,7 +166,7 @@ def generate(args):
|
|
165 |
t5_cpu=args.t5_cpu,
|
166 |
)
|
167 |
|
168 |
-
|
169 |
try:
|
170 |
video = wan_i2v.generate(
|
171 |
args.prompt,
|
|
|
7 |
from tqdm import tqdm
|
8 |
|
9 |
def generate(args):
|
10 |
+
print("call generate")
|
11 |
rank = int(os.getenv("RANK", 0))
|
12 |
world_size = int(os.getenv("WORLD_SIZE", 1))
|
13 |
local_rank = int(os.getenv("LOCAL_RANK", 0))
|
|
|
15 |
# Set device: use CPU if specified, else use GPU based on rank
|
16 |
if args.t5_cpu or args.dit_fsdp: # Use CPU if specified in arguments
|
17 |
device = torch.device("cpu")
|
18 |
+
print("Using CPU for model inference.")
|
19 |
else:
|
20 |
device = local_rank
|
21 |
torch.cuda.set_device(local_rank) # Ensure proper device assignment if using GPU
|
22 |
+
print(f"Using GPU: {device}")
|
23 |
|
24 |
_init_logging(rank)
|
25 |
|
|
|
109 |
t5_cpu=args.t5_cpu,
|
110 |
)
|
111 |
|
112 |
+
print(f"Generating {'image' if 't2i' in args.task else 'video'} ...")
|
113 |
try:
|
114 |
video = wan_t2v.generate(
|
115 |
args.prompt,
|
|
|
166 |
t5_cpu=args.t5_cpu,
|
167 |
)
|
168 |
|
169 |
+
print("Generating video ..6666666666666666")
|
170 |
try:
|
171 |
video = wan_i2v.generate(
|
172 |
args.prompt,
|