1inkusFace commited on
Commit
3ee15e0
·
verified ·
1 Parent(s): 9f45842

Update skyreelsinfer/skyreels_video_infer.py

Browse files
skyreelsinfer/skyreels_video_infer.py CHANGED
@@ -81,18 +81,10 @@ class SkyReelsVideoSingleGpuInfer:
81
  ):
82
  self.task_type = task_type
83
  self.gpu_rank = local_rank
84
- dist.init_process_group(
85
- backend="nccl",
86
- init_method="tcp://127.0.0.1:23456",
87
- timeout=timedelta(seconds=600),
88
- world_size=world_size,
89
- rank=local_rank,
90
- )
91
  os.environ["LOCAL_RANK"] = str(local_rank)
92
- logger.info(f"rank:{local_rank} Distributed backend: {dist.get_backend()}")
93
- torch.cuda.set_device(dist.get_rank())
94
  torch.backends.cuda.enable_cudnn_sdp(False)
95
- gpu_device = f"cuda:{dist.get_rank()}"
96
 
97
  self.pipe: SkyreelsVideoPipeline = self._load_model(
98
  model_id=model_id, quant_model=quant_model, gpu_device=gpu_device
 
81
  ):
82
  self.task_type = task_type
83
  self.gpu_rank = local_rank
 
 
 
 
 
 
 
84
  os.environ["LOCAL_RANK"] = str(local_rank)
85
+ torch.cuda.set_device(0)
 
86
  torch.backends.cuda.enable_cudnn_sdp(False)
87
+ gpu_device = "cuda:0"
88
 
89
  self.pipe: SkyreelsVideoPipeline = self._load_model(
90
  model_id=model_id, quant_model=quant_model, gpu_device=gpu_device