Spaces:
Paused
Paused
Update skyreelsinfer/skyreels_video_infer.py
Browse files
skyreelsinfer/skyreels_video_infer.py
CHANGED
@@ -81,18 +81,10 @@ class SkyReelsVideoSingleGpuInfer:
|
|
81 |
):
|
82 |
self.task_type = task_type
|
83 |
self.gpu_rank = local_rank
|
84 |
-
dist.init_process_group(
|
85 |
-
backend="nccl",
|
86 |
-
init_method="tcp://127.0.0.1:23456",
|
87 |
-
timeout=timedelta(seconds=600),
|
88 |
-
world_size=world_size,
|
89 |
-
rank=local_rank,
|
90 |
-
)
|
91 |
os.environ["LOCAL_RANK"] = str(local_rank)
|
92 |
-
|
93 |
-
torch.cuda.set_device(dist.get_rank())
|
94 |
torch.backends.cuda.enable_cudnn_sdp(False)
|
95 |
-
gpu_device =
|
96 |
|
97 |
self.pipe: SkyreelsVideoPipeline = self._load_model(
|
98 |
model_id=model_id, quant_model=quant_model, gpu_device=gpu_device
|
|
|
81 |
):
|
82 |
self.task_type = task_type
|
83 |
self.gpu_rank = local_rank
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
os.environ["LOCAL_RANK"] = str(local_rank)
|
85 |
+
torch.cuda.set_device(0)
|
|
|
86 |
torch.backends.cuda.enable_cudnn_sdp(False)
|
87 |
+
gpu_device = "cuda:0"
|
88 |
|
89 |
self.pipe: SkyreelsVideoPipeline = self._load_model(
|
90 |
model_id=model_id, quant_model=quant_model, gpu_device=gpu_device
|