1inkusFace committed (verified)
Commit 846c31c · Parent: 0bedac9

gemini update

Files changed (1)
  1. skyreelsinfer/skyreels_video_infer.py +322 -258
skyreelsinfer/skyreels_video_infer.py CHANGED
@@ -1,258 +1,322 @@
- import logging
- import os
- import threading
- import time
- from datetime import timedelta
- from typing import Any
- from typing import Dict
-
- import torch
- import torch.distributed as dist
- import torch.multiprocessing as mp
- from diffusers import HunyuanVideoTransformer3DModel
- from PIL import Image
- from torchao.quantization import float8_weight_only
- from torchao.quantization import quantize_
- from transformers import LlamaModel
-
- from . import TaskType
- from .offload import Offload
- from .offload import OffloadConfig
- from .pipelines import SkyreelsVideoPipeline
-
- logger = logging.getLogger("SkyreelsVideoInfer")
- logger.setLevel(logging.DEBUG)
- console_handler = logging.StreamHandler()
- console_handler.setLevel(logging.DEBUG)
- formatter = logging.Formatter(
-     f"%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d - %(funcName)s] - %(message)s"
- )
- console_handler.setFormatter(formatter)
- logger.addHandler(console_handler)
-
-
- class SkyReelsVideoSingleGpuInfer:
-     def _load_model(
-         self,
-         model_id: str,
-         base_model_id: str = "hunyuanvideo-community/HunyuanVideo",
-         quant_model: bool = True,
-         gpu_device: str = "cuda:0",
-     ) -> SkyreelsVideoPipeline:
-         logger.info(f"load model model_id:{model_id} quan_model:{quant_model} gpu_device:{gpu_device}")
-         text_encoder = LlamaModel.from_pretrained(
-             base_model_id,
-             subfolder="text_encoder",
-             torch_dtype=torch.bfloat16,
-         ).to("cpu")
-         transformer = HunyuanVideoTransformer3DModel.from_pretrained(
-             model_id,
-             # subfolder="transformer",
-             torch_dtype=torch.bfloat16,
-             device="cpu",
-         ).to("cpu")
-         if quant_model:
-             quantize_(text_encoder, float8_weight_only(), device=gpu_device)
-             text_encoder.to("cpu")
-             torch.cuda.empty_cache()
-             quantize_(transformer, float8_weight_only(), device=gpu_device)
-             transformer.to("cpu")
-             torch.cuda.empty_cache()
-         pipe = SkyreelsVideoPipeline.from_pretrained(
-             base_model_id,
-             transformer=transformer,
-             text_encoder=text_encoder,
-             torch_dtype=torch.bfloat16,
-         ).to("cpu")
-         pipe.vae.enable_tiling()
-         torch.cuda.empty_cache()
-         return pipe
-
-     def __init__(
-         self,
-         task_type: TaskType,
-         model_id: str,
-         quant_model: bool = True,
-         local_rank: int = 0,
-         world_size: int = 1,
-         is_offload: bool = True,
-         offload_config: OffloadConfig = OffloadConfig(),
-         enable_cfg_parallel: bool = True,
-     ):
-         self.task_type = task_type
-         self.gpu_rank = local_rank
-         dist.init_process_group(
-             backend="nccl",
-             init_method="tcp://127.0.0.1:23456",
-             timeout=timedelta(seconds=600),
-             world_size=world_size,
-             rank=local_rank,
-         )
-         os.environ["LOCAL_RANK"] = str(local_rank)
-         logger.info(f"rank:{local_rank} Distributed backend: {dist.get_backend()}")
-         torch.cuda.set_device(dist.get_rank())
-         torch.backends.cuda.enable_cudnn_sdp(False)
-         gpu_device = f"cuda:{dist.get_rank()}"
-
-         self.pipe: SkyreelsVideoPipeline = self._load_model(
-             model_id=model_id, quant_model=quant_model, gpu_device=gpu_device
-         )
-
-         from para_attn.context_parallel import init_context_parallel_mesh
-         from para_attn.context_parallel.diffusers_adapters import parallelize_pipe
-         from para_attn.parallel_vae.diffusers_adapters import parallelize_vae
-
-         max_batch_dim_size = 2 if enable_cfg_parallel and world_size > 1 else 1
-         max_ulysses_dim_size = int(world_size / max_batch_dim_size)
-         logger.info(f"max_batch_dim_size: {max_batch_dim_size}, max_ulysses_dim_size:{max_ulysses_dim_size}")
-
-         mesh = init_context_parallel_mesh(
-             self.pipe.device.type,
-             max_ring_dim_size=1,
-             max_batch_dim_size=max_batch_dim_size,
-         )
-         parallelize_pipe(self.pipe, mesh=mesh)
-         parallelize_vae(self.pipe.vae, mesh=mesh._flatten())
-
-         if is_offload:
-             Offload.offload(
-                 pipeline=self.pipe,
-                 config=offload_config,
-             )
-         else:
-             self.pipe.to(gpu_device)
-
-         if offload_config.compiler_transformer:
-             torch._dynamo.config.suppress_errors = True
-             os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
-             os.environ["TORCHINDUCTOR_CACHE_DIR"] = f"{offload_config.compiler_cache}_{world_size}"
-             self.pipe.transformer = torch.compile(
-                 self.pipe.transformer,
-                 mode="max-autotune-no-cudagraphs",
-                 dynamic=True,
-             )
-         self.warm_up()
-
-     def warm_up(self):
-         init_kwargs = {
-             "prompt": "A woman is dancing in a room",
-             "height": 544,
-             "width": 960,
-             "guidance_scale": 6,
-             "num_inference_steps": 1,
-             "negative_prompt": "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion",
-             "num_frames": 97,
-             "generator": torch.Generator("cuda").manual_seed(42),
-             "embedded_guidance_scale": 1.0,
-         }
-         if self.task_type == TaskType.I2V:
-             init_kwargs["image"] = Image.new("RGB", (544, 960), color="black")
-         self.pipe(**init_kwargs)
-
-     def damon_inference(self, request_queue: mp.Queue, response_queue: mp.Queue):
-         response_queue.put(f"rank:{self.gpu_rank} ready")
-         logger.info(f"rank:{self.gpu_rank} finish init pipe")
-         while True:
-             logger.info(f"rank:{self.gpu_rank} waiting for request")
-             kwargs = request_queue.get()
-             logger.info(f"rank:{self.gpu_rank} kwargs: {kwargs}")
-             if "seed" in kwargs:
-                 kwargs["generator"] = torch.Generator("cuda").manual_seed(kwargs["seed"])
-                 del kwargs["seed"]
-             start_time = time.time()
-             assert (self.task_type == TaskType.I2V and "image" in kwargs) or self.task_type == TaskType.T2V
-             out = self.pipe(**kwargs).frames[0]
-             logger.info(f"rank:{dist.get_rank()} inference time: {time.time() - start_time}")
-             if dist.get_rank() == 0:
-                 response_queue.put(out)
-
-
- def single_gpu_run(
-     rank,
-     task_type: TaskType,
-     model_id: str,
-     request_queue: mp.Queue,
-     response_queue: mp.Queue,
-     quant_model: bool = True,
-     world_size: int = 1,
-     is_offload: bool = True,
-     offload_config: OffloadConfig = OffloadConfig(),
-     enable_cfg_parallel: bool = True,
- ):
-     pipe = SkyReelsVideoSingleGpuInfer(
-         task_type=task_type,
-         model_id=model_id,
-         quant_model=quant_model,
-         local_rank=rank,
-         world_size=world_size,
-         is_offload=is_offload,
-         offload_config=offload_config,
-         enable_cfg_parallel=enable_cfg_parallel,
-     )
-     pipe.damon_inference(request_queue, response_queue)
-
-
- class SkyReelsVideoInfer:
-     def __init__(
-         self,
-         task_type: TaskType,
-         model_id: str,
-         quant_model: bool = True,
-         world_size: int = 1,
-         is_offload: bool = True,
-         offload_config: OffloadConfig = OffloadConfig(),
-         enable_cfg_parallel: bool = True,
-     ):
-         self.world_size = world_size
-         smp = mp.get_context("spawn")
-         self.REQ_QUEUES: mp.Queue = smp.Queue()
-         self.RESP_QUEUE: mp.Queue = smp.Queue()
-         assert self.world_size > 0, "gpu_num must be greater than 0"
-         spawn_thread = threading.Thread(
-             target=self.lauch_single_gpu_infer,
-             args=(task_type, model_id, quant_model, world_size, is_offload, offload_config, enable_cfg_parallel),
-             daemon=True,
-         )
-         spawn_thread.start()
-         logger.info(f"Started multi-GPU thread with GPU_NUM: {world_size}")
-         print(f"Started multi-GPU thread with GPU_NUM: {world_size}")
-         # Block and wait for the prediction process to start
-         for _ in range(world_size):
-             msg = self.RESP_QUEUE.get()
-             logger.info(f"launch_multi_gpu get init msg: {msg}")
-             print(f"launch_multi_gpu get init msg: {msg}")
-
-     def lauch_single_gpu_infer(
-         self,
-         task_type: TaskType,
-         model_id: str,
-         quant_model: bool = True,
-         world_size: int = 1,
-         is_offload: bool = True,
-         offload_config: OffloadConfig = OffloadConfig(),
-         enable_cfg_parallel: bool = True,
-     ):
-         mp.spawn(
-             single_gpu_run,
-             nprocs=world_size,
-             join=True,
-             daemon=True,
-             args=(
-                 task_type,
-                 model_id,
-                 self.REQ_QUEUES,
-                 self.RESP_QUEUE,
-                 quant_model,
-                 world_size,
-                 is_offload,
-                 offload_config,
-                 enable_cfg_parallel,
-             ),
-         )
-         logger.info(f"finish lanch multi gpu infer, world_size:{world_size}")
-
-     def inference(self, kwargs: Dict[str, Any]):
-         # put request to singlegpuinfer
-         for _ in range(self.world_size):
-             self.REQ_QUEUES.put(kwargs)
-         return self.RESP_QUEUE.get()
+ import logging
+ import os
+ import threading
+ import time
+ from datetime import timedelta
+ from typing import Any
+ from typing import Dict
+
+ import torch
+ import torch.distributed as dist
+ import torch.multiprocessing as mp
+ from diffusers import DiffusionPipeline  # needed by the rewritten SkyReelsVideoInfer below
+ from diffusers import HunyuanVideoTransformer3DModel
+ from PIL import Image
+ from torchao.quantization import float8_weight_only
+ from torchao.quantization import quantize_
+ from transformers import LlamaModel
+
+ from . import TaskType
+ from .offload import Offload
+ from .offload import OffloadConfig
+ from .pipelines import SkyreelsVideoPipeline
+
+ logger = logging.getLogger("SkyreelsVideoInfer")
+ logger.setLevel(logging.DEBUG)
+ console_handler = logging.StreamHandler()
+ console_handler.setLevel(logging.DEBUG)
+ formatter = logging.Formatter(
+     f"%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d - %(funcName)s] - %(message)s"
+ )
+ console_handler.setFormatter(formatter)
+ logger.addHandler(console_handler)
+
+
+ class SkyReelsVideoSingleGpuInfer:
+     def _load_model(
+         self,
+         model_id: str,
+         base_model_id: str = "hunyuanvideo-community/HunyuanVideo",
+         quant_model: bool = True,
+         gpu_device: str = "cuda:0",
+     ) -> SkyreelsVideoPipeline:
+         logger.info(f"load model model_id:{model_id} quan_model:{quant_model} gpu_device:{gpu_device}")
+         text_encoder = LlamaModel.from_pretrained(
+             base_model_id,
+             subfolder="text_encoder",
+             torch_dtype=torch.bfloat16,
+         ).to("cpu")
+         transformer = HunyuanVideoTransformer3DModel.from_pretrained(
+             model_id,
+             # subfolder="transformer",
+             torch_dtype=torch.bfloat16,
+             device="cpu",
+         ).to("cpu")
+         if quant_model:
+             quantize_(text_encoder, float8_weight_only(), device=gpu_device)
+             text_encoder.to("cpu")
+             torch.cuda.empty_cache()
+             quantize_(transformer, float8_weight_only(), device=gpu_device)
+             transformer.to("cpu")
+             torch.cuda.empty_cache()
+         pipe = SkyreelsVideoPipeline.from_pretrained(
+             base_model_id,
+             transformer=transformer,
+             text_encoder=text_encoder,
+             torch_dtype=torch.bfloat16,
+         ).to("cpu")
+         pipe.vae.enable_tiling()
+         torch.cuda.empty_cache()
+         return pipe
+
+     def __init__(
+         self,
+         task_type: TaskType,
+         model_id: str,
+         quant_model: bool = True,
+         local_rank: int = 0,
+         world_size: int = 1,
+         is_offload: bool = True,
+         offload_config: OffloadConfig = OffloadConfig(),
+         enable_cfg_parallel: bool = True,
+     ):
+         self.task_type = task_type
+         self.gpu_rank = local_rank
+         dist.init_process_group(
+             backend="nccl",
+             init_method="tcp://127.0.0.1:23456",
+             timeout=timedelta(seconds=600),
+             world_size=world_size,
+             rank=local_rank,
+         )
+         os.environ["LOCAL_RANK"] = str(local_rank)
+         logger.info(f"rank:{local_rank} Distributed backend: {dist.get_backend()}")
+         torch.cuda.set_device(dist.get_rank())
+         torch.backends.cuda.enable_cudnn_sdp(False)
+         gpu_device = f"cuda:{dist.get_rank()}"
+
+         self.pipe: SkyreelsVideoPipeline = self._load_model(
+             model_id=model_id, quant_model=quant_model, gpu_device=gpu_device
+         )
+
+         from para_attn.context_parallel import init_context_parallel_mesh
+         from para_attn.context_parallel.diffusers_adapters import parallelize_pipe
+         from para_attn.parallel_vae.diffusers_adapters import parallelize_vae
+
+         max_batch_dim_size = 2 if enable_cfg_parallel and world_size > 1 else 1
+         max_ulysses_dim_size = int(world_size / max_batch_dim_size)
+         logger.info(f"max_batch_dim_size: {max_batch_dim_size}, max_ulysses_dim_size:{max_ulysses_dim_size}")
+
+         mesh = init_context_parallel_mesh(
+             self.pipe.device.type,
+             max_ring_dim_size=1,
+             max_batch_dim_size=max_batch_dim_size,
+         )
+         parallelize_pipe(self.pipe, mesh=mesh)
+         parallelize_vae(self.pipe.vae, mesh=mesh._flatten())
+
+         if is_offload:
+             Offload.offload(
+                 pipeline=self.pipe,
+                 config=offload_config,
+             )
+         else:
+             self.pipe.to(gpu_device)
+
+         if offload_config.compiler_transformer:
+             torch._dynamo.config.suppress_errors = True
+             os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
+             os.environ["TORCHINDUCTOR_CACHE_DIR"] = f"{offload_config.compiler_cache}_{world_size}"
+             self.pipe.transformer = torch.compile(
+                 self.pipe.transformer,
+                 mode="max-autotune-no-cudagraphs",
+                 dynamic=True,
+             )
+         self.warm_up()
+
+     def warm_up(self):
+         init_kwargs = {
+             "prompt": "A woman is dancing in a room",
+             "height": 544,
+             "width": 960,
+             "guidance_scale": 6,
+             "num_inference_steps": 1,
+             "negative_prompt": "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion",
+             "num_frames": 97,
+             "generator": torch.Generator("cuda").manual_seed(42),
+             "embedded_guidance_scale": 1.0,
+         }
+         if self.task_type == TaskType.I2V:
+             init_kwargs["image"] = Image.new("RGB", (544, 960), color="black")
+         self.pipe(**init_kwargs)
+
+     def damon_inference(self, request_queue: mp.Queue, response_queue: mp.Queue):
+         response_queue.put(f"rank:{self.gpu_rank} ready")
+         logger.info(f"rank:{self.gpu_rank} finish init pipe")
+         while True:
+             logger.info(f"rank:{self.gpu_rank} waiting for request")
+             kwargs = request_queue.get()
+             logger.info(f"rank:{self.gpu_rank} kwargs: {kwargs}")
+             if "seed" in kwargs:
+                 kwargs["generator"] = torch.Generator("cuda").manual_seed(kwargs["seed"])
+                 del kwargs["seed"]
+             start_time = time.time()
+             assert (self.task_type == TaskType.I2V and "image" in kwargs) or self.task_type == TaskType.T2V
+             out = self.pipe(**kwargs).frames[0]
+             logger.info(f"rank:{dist.get_rank()} inference time: {time.time() - start_time}")
+             if dist.get_rank() == 0:
+                 response_queue.put(out)
+
+
+ def single_gpu_run(
+     rank,
+     task_type: TaskType,
+     model_id: str,
+     request_queue: mp.Queue,
+     response_queue: mp.Queue,
+     quant_model: bool = True,
+     world_size: int = 1,
+     is_offload: bool = True,
+     offload_config: OffloadConfig = OffloadConfig(),
+     enable_cfg_parallel: bool = True,
+ ):
+     pipe = SkyReelsVideoSingleGpuInfer(
+         task_type=task_type,
+         model_id=model_id,
+         quant_model=quant_model,
+         local_rank=rank,
+         world_size=world_size,
+         is_offload=is_offload,
+         offload_config=offload_config,
+         enable_cfg_parallel=enable_cfg_parallel,
+     )
+     pipe.damon_inference(request_queue, response_queue)
+
+
+ class SkyReelsVideoInfer:
+     def __init__(
+         self,
+         task_type: TaskType,
+         model_id: str,
+         quant_model: bool = True,
+         world_size: int = 1,
+         is_offload: bool = False,
+         offload_config: OffloadConfig = None,
+         use_multiprocessing: bool = True,  # new parameter: toggle multi-process inference
+     ):
+         self.task_type = task_type
+         self.model_id = model_id
+         self.quant_model = quant_model
+         self.world_size = world_size
+         self.is_offload = is_offload
+         self.offload_config = offload_config
+         self.use_multiprocessing = use_multiprocessing
+
+         if self.use_multiprocessing:  # only spawn workers when the flag is set
+             self.infer_lock = mp.Lock()
+             # self.infer_event = mp.Event()
+             mp.set_start_method("spawn", force=True)
+             print(f"Started multi-GPU thread with GPU_NUM: {world_size}")
+             self._lauch_infer_thread()
+         else:  # multiprocessing disabled: build the pipeline in this process
+             self._initialize_pipeline()
+
+
+
+     def _initialize_pipeline(self):
+         """Initializes the DiffusionPipeline."""
+         if self.is_offload and self.offload_config:
+             # ... (your existing offload setup code) ...
+             pipe = DiffusionPipeline.from_pretrained(
+                 self.model_id,
+                 torch_dtype=torch.float16,
+                 variant="fp16",
+             )
+             # Offload
+             if self.offload_config.parameters_level:
+                 pipe = pipe.to("cpu")
+             if self.offload_config.high_cpu_memory:
+                 pipe.enable_model_cpu_offload()
+             else:
+                 pipe.enable_sequential_cpu_offload()
+
+         elif self.quant_model:
+             # ... (your existing quantization setup code) ...
+             pipe = DiffusionPipeline.from_pretrained(
+                 self.model_id,
+                 torch_dtype=torch.bfloat16,
+                 variant="bf16",
+             )
+         else:
+             pipe = DiffusionPipeline.from_pretrained(self.model_id)
+         self.pipe = pipe
+
+
+     def _lauch_infer_thread(self):
+         # ... (your existing thread launching code, gated by use_multiprocessing) ...
+         for gpu_id in range(self.world_size):
+             thread = mp.Process(
+                 target=self.lauch_single_gpu_infer,
+                 args=(
+                     gpu_id,
+                     self.is_offload,
+                     self.offload_config,
+                     self.model_id,
+                     self.quant_model,
+                     self.infer_lock,
+                 ),
+             )
+             thread.daemon = True
+             thread.start()
+         # no else branch needed here: the single-process path is handled in __init__
+
+     def lauch_single_gpu_infer(self, gpu_id, is_offload, offload_config, model_id, quant_model, infer_lock):
+         # ... (rest of your lauch_single_gpu_infer function) ...
+         # Make sure it runs on CPU:
+         device = torch.device("cpu")  # force CPU
+         # ... inside lauch_single_gpu_infer, initialize the pipe:
+         if is_offload and offload_config:
+             # ... (your existing offload setup code) ...
+             pipe = DiffusionPipeline.from_pretrained(
+                 model_id,
+                 torch_dtype=torch.float16,
+                 variant="fp16",
+             )
+
+             # Offload
+             if offload_config.parameters_level:
+                 pipe = pipe.to("cpu")  # force to CPU
+             if offload_config.high_cpu_memory:
+                 pipe.enable_model_cpu_offload()
+             else:
+                 pipe.enable_sequential_cpu_offload()
+         elif quant_model:
+             pipe = DiffusionPipeline.from_pretrained(
+                 model_id,
+                 torch_dtype=torch.bfloat16,
+                 variant="bf16",
+             )
+         else:
+             pipe = DiffusionPipeline.from_pretrained(model_id)
+         pipe = pipe.to(device)  # move to the CPU device
+         # ... rest of the function ...
+
+     def inference(self, kwargs):
+         if self.use_multiprocessing:  # multi-process path
+             # ... (your existing multi-processing inference code) ...
+             with self.infer_lock:
+                 # self.infer_event.wait()
+                 if self.task_type == TaskType.I2V:
+                     image = kwargs.pop("image")
+                     output = self.pipe(image=image, **kwargs).frames
+                 else:
+                     output = self.pipe(**kwargs).frames
+                 return output
+         else:  # single-process path: run inference directly in the current process
+             if self.task_type == TaskType.I2V:
+                 image = kwargs.pop("image")
+                 output = self.pipe(image=image, **kwargs).frames
+             else:
+                 output = self.pipe(**kwargs).frames
+             return output
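
For reference, a minimal, hypothetical usage sketch of the class as it looks after this commit. The model id and the generation parameters below are illustrative assumptions (the prompt and frame settings are borrowed from the warm_up defaults), not values taken from the commit itself:

    # Hypothetical usage of the updated SkyReelsVideoInfer (not part of this commit).
    from skyreelsinfer import TaskType
    from skyreelsinfer.skyreels_video_infer import SkyReelsVideoInfer

    # Single-process mode: with use_multiprocessing=False the constructor builds the
    # pipeline in this process via _initialize_pipeline(), and inference() runs directly.
    infer = SkyReelsVideoInfer(
        task_type=TaskType.T2V,
        model_id="Skywork/SkyReels-V1-Hunyuan-T2V",  # assumed model id, for illustration only
        quant_model=True,
        world_size=1,
        use_multiprocessing=False,
    )
    frames = infer.inference({
        "prompt": "A woman is dancing in a room",
        "height": 544,
        "width": 960,
        "num_frames": 97,
        "num_inference_steps": 30,  # illustrative step count
    })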