1inkusFace committed on
Commit 40b1047 · verified · 1 Parent(s): 42d98b4

revert to repo

Files changed (1)
  1. skyreelsinfer/offload.py +310 -496
skyreelsinfer/offload.py CHANGED
@@ -1,519 +1,333 @@
- import functools
- import gc
  import os
- import time
- from dataclasses import dataclass

  import torch
- from diffusers.pipelines import DiffusionPipeline
- from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor


- @dataclass
  class OffloadConfig:
-     # high_cpu_memory: Whether to use pinned memory for offload optimization. This can effectively prevent increased model offload latency caused by memory swapping.
-     high_cpu_memory: bool = True
-     # parameters_level: Whether to enable parameter-level offload. This further reduces VRAM requirements but may result in increased latency.
-     parameters_level: bool = False
-     # compiler_transformer: Whether to enable compilation optimization for the transformer.
-     compiler_transformer: bool = False
-     compiler_cache: str = "/tmp/compile_cache"


- class HfHook:
-     def __init__(self):
-         device_id = os.environ.get("LOCAL_RANK", 0)
-         self.execution_device = f"cuda:{device_id}"
-
-     def detach_hook(self, module):
-         pass
-
-
- class Offload:
-     def __init__(self) -> None:
-         self.active_models = []
-         self.active_models_ids = []
-         self.active_subcaches = {}
-         self.models = {}
-         self.verboseLevel = 0
-         self.models_to_quantize = []
-         self.pinned_modules_data = {}
-         self.blocks_of_modules = {}
-         self.blocks_of_modules_sizes = {}
-         self.compile = False
-         self.device_mem_capacity = torch.cuda.get_device_properties(0).total_memory
-         self.last_reserved_mem_check = 0
-         self.loaded_blocks = {}
-         self.prev_blocks_names = {}
-         self.next_blocks_names = {}
-         device_id = os.environ.get("LOCAL_RANK", 0)
-         self.device_id = f"cuda:{device_id}"
-         self.default_stream = torch.cuda.default_stream(self.device_id) # torch.cuda.current_stream()
-         self.transfer_stream = torch.cuda.Stream()
-         self.async_transfers = False
-         self.last_run_model = None
-
-     def check_empty_cuda_cache(self): # Now a method of Offload
-         if torch.cuda.is_available():
-             torch.cuda.empty_cache()

-     @classmethod
-     def offload(cls, pipeline: DiffusionPipeline, config: OffloadConfig = OffloadConfig()):
-         """
-         Enable offloading for multiple models in the pipeline, supporting video generation inference on user-level GPUs.
-         pipe: the pipeline object
-         config: offload strategy configuration
-         """
-         self = cls()
-         self.pinned_modules_data = {}
-         if config.parameters_level:
-             model_budgets = {
-                 "transformer": 600 * 1024 * 1024,
-                 "text_encoder": 3 * 1024 * 1024 * 1024,
-                 "text_encoder_2": 3 * 1024 * 1024 * 1024,
-             }
-             self.async_transfers = True
-         else:
-             model_budgets = {}
-
-         device_id = os.getenv("LOCAL_RANK", 0)
-         torch.set_default_device(f"cuda:{device_id}")
-         pipeline.hf_device_map = torch.device(f"cuda:{device_id}")
-         pipe_or_dict_of_modules = pipeline.components
-         if config.compiler_transformer:
-             pipeline.transformer.to("cuda")
-         models = {
-             k: v
-             for k, v in pipe_or_dict_of_modules.items()
-             if isinstance(v, torch.nn.Module) and not (config.compiler_transformer and k == "transformer")
-         }
-         print_info = {k: type(v) for k, v in models.items()}
-         print(f"offload models: {print_info}")
-         if config.compiler_transformer:
-             pipeline.text_encoder.to("cpu")
-             pipeline.text_encoder_2.to("cpu")
-             torch.cuda.empty_cache()
-             pipeline.transformer.to("cuda")
-             pipeline.vae.to("cuda")
-
-             def move_text_encoder_to_gpu(pipe):
-                 torch.cuda.empty_cache()
-                 pipe.text_encoder.to("cuda")
-                 pipe.text_encoder_2.to("cuda")
-
-             def move_text_encoder_to_cpu(pipe):
-                 pipe.text_encoder.to("cpu")
-                 pipe.text_encoder_2.to("cpu")
-                 torch.cuda.empty_cache()
-
-             setattr(pipeline, "text_encoder_to_cpu", functools.partial(move_text_encoder_to_cpu, pipeline))
-             setattr(pipeline, "text_encoder_to_gpu", functools.partial(move_text_encoder_to_gpu, pipeline))
-
-             for k, module in pipe_or_dict_of_modules.items():
-                 if isinstance(module, torch.nn.Module):
-                     for submodule_name, submodule in module.named_modules():
-                         if not hasattr(submodule, "_hf_hook"):
-                             setattr(submodule, "_hf_hook", HfHook())
-             return self
-
-         sizeofbfloat16 = torch.bfloat16.itemsize
-         modelPinned = config.high_cpu_memory
-         # Pin in RAM models
-         # Calculate the VRAM requirements of the computational modules to determine whether parameters-level offload is necessary.
-         for model_name, curr_model in models.items():
-             curr_model.to("cpu").eval()
-             pinned_parameters_data = {}
-             current_model_size = 0
-             print(f"{model_name} move to pinned memory:{modelPinned}")
-             for p in curr_model.parameters():
-                 if isinstance(p, AffineQuantizedTensor):
-                     if not modelPinned and p.tensor_impl.scale.dtype == torch.float32:
-                         p.tensor_impl.scale = p.tensor_impl.scale.to(torch.bfloat16)
-                     current_model_size += torch.numel(p.tensor_impl.scale) * sizeofbfloat16
-                     current_model_size += torch.numel(p.tensor_impl.float8_data) * sizeofbfloat16 / 2
-                     if modelPinned:
-                         p.tensor_impl.float8_data = p.tensor_impl.float8_data.pin_memory()
-                         p.tensor_impl.scale = p.tensor_impl.scale.pin_memory()
-                         pinned_parameters_data[p] = [p.tensor_impl.float8_data, p.tensor_impl.scale]
-                 else:
-                     p.data = p.data.to(torch.bfloat16) if p.data.dtype == torch.float32 else p.data.to(p.data.dtype)
-                     current_model_size += torch.numel(p.data) * p.data.element_size()
-                     if modelPinned:
-                         p.data = p.data.pin_memory()
-                         pinned_parameters_data[p] = p.data
-
-             for buffer in curr_model.buffers():
-                 buffer.data = (
-                     buffer.data.to(torch.bfloat16)
-                     if buffer.data.dtype == torch.float32
-                     else buffer.data.to(buffer.data.dtype)
-                 )
-                 current_model_size += torch.numel(buffer.data) * buffer.data.element_size()
-                 if modelPinned:
-                     buffer.data = buffer.data.pin_memory()
-
-             if model_name not in self.models:
-                 self.models[model_name] = curr_model
-
-             curr_model_budget = model_budgets.get(model_name, 0)
-             if curr_model_budget > 0 and curr_model_budget > current_model_size:
-                 model_budgets[model_name] = 0
-
-             if modelPinned:
-                 pinned_buffers_data = {b: b.data for b in curr_model.buffers()}
-                 pinned_parameters_data.update(pinned_buffers_data)
-                 self.pinned_modules_data[model_name] = pinned_parameters_data
-             gc.collect()
-             torch.cuda.empty_cache()

-         # if config.compiler_transformer:
-         #     module = pipeline.transformer
-         #     print("wrap transformer forward")
-         #     # gpu model wrap
-         #     for submodule_name, submodule in module.named_modules():
-         #         if not hasattr(submodule, "_hf_hook"):
-         #             setattr(submodule, "_hf_hook", HfHook())
-         #
-         #     forward_method = getattr(module, "forward")
-         #
-         #     def wrap_unload_all(*args, **kwargs):
-         #         self.unload_all("transformer")
-         #         return forward_method(*args, **kwargs)
-         #
-         #     setattr(module, "forward", functools.update_wrapper(wrap_unload_all, forward_method))
-
-         # wrap forward methods
-         for model_name, curr_model in models.items():
-             current_budget = model_budgets.get(model_name, 0)
-             current_size = 0
-             self.loaded_blocks[model_name] = None
-             cur_blocks_prefix, prev_blocks_name, cur_blocks_name, cur_blocks_seq = None, None, None, -1
-
-             for submodule_name, submodule in curr_model.named_modules():
-                 # create a fake accelerate parameter so that the _execution_device property returns always "cuda"
-                 if not hasattr(submodule, "_hf_hook"):
-                     setattr(submodule, "_hf_hook", HfHook())
-
-                 if not submodule_name:
-                     continue
-
-                 # usr parameters-level offload
-                 if current_budget > 0:
-                     if isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
-                         if cur_blocks_prefix == None:
-                             cur_blocks_prefix = submodule_name + "."
-                         else:
-                             if not submodule_name.startswith(cur_blocks_prefix):
-                                 cur_blocks_prefix = submodule_name + "."
-                                 cur_blocks_name, cur_blocks_seq = None, -1
-                     else:
-                         if cur_blocks_prefix is not None:
-                             if submodule_name.startswith(cur_blocks_prefix):
-                                 num = int(submodule_name[len(cur_blocks_prefix) :].split(".")[0])
-                                 if num != cur_blocks_seq and (cur_blocks_name == None or current_size > current_budget):
-                                     prev_blocks_name = cur_blocks_name
-                                     cur_blocks_name = cur_blocks_prefix + str(num)
-                                 cur_blocks_seq = num
-                             else:
-                                 cur_blocks_prefix = None
-                                 prev_blocks_name = None
-                                 cur_blocks_name = None
-                                 cur_blocks_seq = -1
-
-                 if hasattr(submodule, "forward"):
-                     submodule_forward = getattr(submodule, "forward")
-                     if not callable(submodule_forward):
-                         print("***")
-                         continue
-                     if len(submodule_name.split(".")) == 1:
-                         self.hook_me(submodule, curr_model, model_name, submodule_name, submodule_forward)
-                     else:
-                         self.hook_me_light(
-                             submodule, model_name, cur_blocks_name, submodule_forward, context=submodule_name
-                         )
-                     current_size = self.add_module_to_blocks(model_name, cur_blocks_name, submodule, prev_blocks_name)
-
-         gc.collect()
-         torch.cuda.empty_cache()
          return self

-     def add_module_to_blocks(self, model_name, blocks_name, submodule, prev_block_name):

-         entry_name = model_name if blocks_name is None else model_name + "/" + blocks_name
-         if entry_name in self.blocks_of_modules:
-             blocks_params = self.blocks_of_modules[entry_name]
-             blocks_params_size = self.blocks_of_modules_sizes[entry_name]
-         else:
-             blocks_params = []
-             self.blocks_of_modules[entry_name] = blocks_params
-             blocks_params_size = 0
-             if blocks_name != None:
-                 prev_entry_name = None if prev_block_name == None else model_name + "/" + prev_block_name
-                 self.prev_blocks_names[entry_name] = prev_entry_name
-                 if not prev_block_name == None:
-                     self.next_blocks_names[prev_entry_name] = entry_name
-
-         for p in submodule.parameters(recurse=False):
-             blocks_params.append(p)
-             if isinstance(p, AffineQuantizedTensor):
-                 blocks_params_size += p.tensor_impl.float8_data.nbytes
-                 blocks_params_size += p.tensor_impl.scale.nbytes
-             else:
-                 blocks_params_size += p.data.nbytes
-
-         for p in submodule.buffers(recurse=False):
-             blocks_params.append(p)
-             blocks_params_size += p.data.nbytes
-
-         self.blocks_of_modules_sizes[entry_name] = blocks_params_size
-
-         return blocks_params_size
-
-     def can_model_be_cotenant(self, model_name):
-         cotenants_map = {
-             "text_encoder": ["vae", "text_encoder_2"],
-             "text_encoder_2": ["vae", "text_encoder"],
-         }
-         potential_cotenants = cotenants_map.get(model_name, None)
-         if potential_cotenants is None:
-             return False
-         for existing_cotenant in self.active_models_ids:
-             if existing_cotenant not in potential_cotenants:
-                 return False
-         return True
-
-     @torch.compiler.disable()
-     def gpu_load_blocks(self, model_name, blocks_name, async_load=False):
-         if blocks_name != None:
-             self.loaded_blocks[model_name] = blocks_name
-
-         def cpu_to_gpu(stream_to_use, blocks_params, record_for_stream=None):
-             with torch.cuda.stream(stream_to_use):
-                 for p in blocks_params:
-                     if isinstance(p, AffineQuantizedTensor):
-                         p.tensor_impl.float8_data = p.tensor_impl.float8_data.cuda(
-                             non_blocking=True, device=self.device_id
-                         )
-                         p.tensor_impl.scale = p.tensor_impl.scale.cuda(non_blocking=True, device=self.device_id)
-                     else:
-                         p.data = p.data.cuda(non_blocking=True, device=self.device_id)
-
-                     if record_for_stream != None:
-                         if isinstance(p, AffineQuantizedTensor):
-                             p.tensor_impl.float8_data.record_stream(record_for_stream)
-                             p.tensor_impl.scale.record_stream(record_for_stream)
-                         else:
-                             p.data.record_stream(record_for_stream)
-
-         entry_name = model_name if blocks_name is None else model_name + "/" + blocks_name
-         if self.verboseLevel >= 2:
-             model = self.models[model_name]
-             model_name = model._get_name()
-             print(f"Loading model {entry_name} ({model_name}) in GPU")
-
-         if self.async_transfers and blocks_name != None:
-             first = self.prev_blocks_names[entry_name] == None
-             next_blocks_entry = self.next_blocks_names[entry_name] if entry_name in self.next_blocks_names else None
-             if first:
-                 cpu_to_gpu(torch.cuda.current_stream(), self.blocks_of_modules[entry_name])
-             torch.cuda.synchronize()
-
-             if next_blocks_entry != None:
-                 cpu_to_gpu(self.transfer_stream, self.blocks_of_modules[next_blocks_entry])

-         else:
-             cpu_to_gpu(self.default_stream, self.blocks_of_modules[entry_name])
-             torch.cuda.synchronize()
-
- @torch.compiler.disable()
329
- def gpu_unload_blocks(self, model_name, blocks_name):
330
- if blocks_name != None:
331
- self.loaded_blocks[model_name] = None
332
-
333
- blocks_name = model_name if blocks_name is None else model_name + "/" + blocks_name
334
-
335
- if self.verboseLevel >= 2:
336
- model = self.models[model_name]
337
- model_name = model._get_name()
338
- print(f"Unloading model {blocks_name} ({model_name}) from GPU")
339
-
340
- blocks_params = self.blocks_of_modules[blocks_name]
341
-
342
- if model_name in self.pinned_modules_data:
343
- pinned_parameters_data = self.pinned_modules_data[model_name]
344
- for p in blocks_params:
345
- if isinstance(p, AffineQuantizedTensor):
346
- data = pinned_parameters_data[p]
347
- p.tensor_impl.float8_data = data[0]
348
- p.tensor_impl.scale = data[1]
349
- else:
350
- p.data = pinned_parameters_data[p]
351
- else:
352
- for p in blocks_params:
353
- if isinstance(p, AffineQuantizedTensor):
354
- p.tensor_impl.float8_data = p.tensor_impl.float8_data.cpu()
355
- p.tensor_impl.scale = p.tensor_impl.scale.cpu()
356
- else:
357
- p.data = p.data.cpu()
358
-
359
- @torch.compiler.disable()
360
- def gpu_load(self, model_name):
361
- model = self.models[model_name]
362
- self.active_models.append(model)
363
- self.active_models_ids.append(model_name)
364
-
365
- self.gpu_load_blocks(model_name, None)
366
-
367
- # torch.cuda.current_stream().synchronize()
368
-
369
- @torch.compiler.disable()
370
- def unload_all(self, model_name: str):
371
- if len(self.active_models_ids) == 0 and self.last_run_model == model_name:
372
- self.last_run_model = model_name
373
- return
374
- for model_name in self.active_models_ids:
375
- self.gpu_unload_blocks(model_name, None)
376
- loaded_block = self.loaded_blocks[model_name]
377
- if loaded_block != None:
378
- self.gpu_unload_blocks(model_name, loaded_block)
379
- self.loaded_blocks[model_name] = None
380
-
381
- self.active_models = []
382
- self.active_models_ids = []
383
- self.active_subcaches = []
384
  torch.cuda.empty_cache()
385
- gc.collect()
386
- self.last_reserved_mem_check = time.time()
387
- self.last_run_model = model_name
388
-
389
- def move_args_to_gpu(self, *args, **kwargs):
390
- new_args = []
391
- new_kwargs = {}
392
- for arg in args:
393
- if torch.is_tensor(arg):
394
- if arg.dtype == torch.float32:
395
- arg = arg.to(torch.bfloat16).cuda(non_blocking=True, device=self.device_id)
396
- else:
397
- arg = arg.cuda(non_blocking=True, device=self.device_id)
398
- new_args.append(arg)
399
-
400
- for k in kwargs:
401
- arg = kwargs[k]
402
- if torch.is_tensor(arg):
403
- if arg.dtype == torch.float32:
404
- arg = arg.to(torch.bfloat16).cuda(non_blocking=True, device=self.device_id)
405
- else:
406
- arg = arg.cuda(non_blocking=True, device=self.device_id)
407
- new_kwargs[k] = arg
408
-
409
- return new_args, new_kwargs
410
-
411
- def ready_to_check_mem(self):
412
- if self.compile:
413
- return
414
- cur_clock = time.time()
415
- # can't check at each call if we can empty the cuda cache as quering the reserved memory value is a time consuming operation
416
- if (cur_clock - self.last_reserved_mem_check) < 0.200:
417
- return False
418
- self.last_reserved_mem_check = cur_clock
419
- return True
420
-
421
- def empty_cache_if_needed(self):
422
- mem_reserved = torch.cuda.memory_reserved()
423
- mem_threshold = 0.9 * self.device_mem_capacity
424
- if mem_reserved >= mem_threshold:
425
- mem_allocated = torch.cuda.memory_allocated()
426
- if mem_allocated <= 0.70 * mem_reserved:
427
- torch.cuda.empty_cache()
428
- tm = time.time()
429
- if self.verboseLevel >= 2:
430
- print(f"Empty Cuda cache at {tm}")
431
-
432
- def any_param_or_buffer(self, target_module: torch.nn.Module):
433
-
434
- for _ in target_module.parameters(recurse=False):
435
- return True
436
-
437
- for _ in target_module.buffers(recurse=False):
438
- return True
439
-
440
- return False
441
-
442
- def hook_me_light(self, target_module, model_name, blocks_name, previous_method, context):
443
-
444
- anyParam = self.any_param_or_buffer(target_module)
445
-
446
- def check_empty_cuda_cache(module, *args, **kwargs):
447
- if self.ready_to_check_mem():
448
- self.empty_cache_if_needed()
449
- return previous_method(*args, **kwargs)
450
-
451
- def load_module_blocks(module, *args, **kwargs):
452
- if blocks_name == None:
453
- if self.ready_to_check_mem():
454
- self.empty_cache_if_needed()
455
- else:
456
- loaded_block = self.loaded_blocks[model_name]
457
- if loaded_block == None or loaded_block != blocks_name:
458
- if loaded_block != None:
459
- self.gpu_unload_blocks(model_name, loaded_block)
460
- if self.ready_to_check_mem():
461
- self.empty_cache_if_needed()
462
- self.loaded_blocks[model_name] = blocks_name
463
- self.gpu_load_blocks(model_name, blocks_name)
464
- return previous_method(*args, **kwargs)
465
-
466
- if hasattr(target_module, "_mm_id"):
467
- orig_model_name = getattr(target_module, "_mm_id")
468
- if self.verboseLevel >= 2:
469
- print(
470
- f"Model '{model_name}' shares module '{target_module._get_name()}' with module '{orig_model_name}' "
471
- )
472
- assert not anyParam
473
  return
474
- setattr(target_module, "_mm_id", model_name)
475
 
476
- if blocks_name != None and anyParam:
477
- setattr(
478
- target_module,
479
- "forward",
480
- functools.update_wrapper(functools.partial(load_module_blocks, target_module), previous_method),
481
- )
482
- # print(f"new cache:{blocks_name}")
 
483
  else:
484
- setattr(
485
- target_module,
486
- "forward",
487
- functools.update_wrapper(functools.partial(check_empty_cuda_cache, target_module), previous_method),
 
 
 
 
488
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
489
 
490
- def hook_me(self, target_module, model, model_name, module_id, previous_method):
491
- def check_change_module(module, *args, **kwargs):
492
- performEmptyCacheTest = False
493
- if not model_name in self.active_models_ids:
494
- new_model_name = getattr(module, "_mm_id")
495
- if not self.can_model_be_cotenant(new_model_name):
496
- self.unload_all(model_name)
497
- performEmptyCacheTest = False
498
- self.gpu_load(new_model_name)
499
- args, kwargs = self.move_args_to_gpu(*args, **kwargs)
500
- if performEmptyCacheTest:
501
- self.empty_cache_if_needed()
502
- return previous_method(*args, **kwargs)
503
-
504
- if hasattr(target_module, "_mm_id"):
505
- return
506
- setattr(target_module, "_mm_id", model_name)
507
-
508
- setattr(
509
- target_module,
510
- "forward",
511
- functools.update_wrapper(functools.partial(check_change_module, target_module), previous_method),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
512
  )
 
513
 
514
- if not self.verboseLevel >= 1:
515
- return
516
 
517
- if module_id == None or module_id == "":
518
- model_name = model._get_name()
519
- print(f"Hooked in model '{model_name}' ({model_name})")
 
+ import spaces
+ import gradio as gr
+ import argparse
+ import sys
  import os
+ import random
+ import subprocess
+ from PIL import Image
+ import numpy as np
+
+ # Removed environment-specific lines
+ from diffusers.utils import export_to_video
+ from diffusers.utils import load_image

  import torch
+ import logging
+ from collections import OrderedDict
+
+ torch.backends.cuda.matmul.allow_tf32 = False
+ torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
+ torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
+ torch.backends.cudnn.allow_tf32 = False
+ torch.backends.cudnn.deterministic = False
+ torch.backends.cudnn.benchmark = False
+ torch.set_float32_matmul_precision("highest")
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+ logger = logging.getLogger(__name__)


+ # --- Dummy Classes (Keep for standalone execution) ---
  class OffloadConfig:
+     def __init__(
+         self,
+         high_cpu_memory: bool = False,
+         parameters_level: bool = False,
+         compiler_transformer: bool = False,
+         compiler_cache: str = "",
+     ):
+         self.high_cpu_memory = high_cpu_memory
+         self.parameters_level = parameters_level
+         self.compiler_transformer = compiler_transformer
+         self.compiler_cache = compiler_cache
+
+
+ class TaskType: # Keep here for infer
+     T2V = 0
+     I2V = 1
+
+
+ class LlamaModel:
+     @staticmethod
+     def from_pretrained(*args, **kwargs):
+         return LlamaModel()
+
+     def to(self, device):
+         return self


+ class HunyuanVideoTransformer3DModel:
+     @staticmethod
+     def from_pretrained(*args, **kwargs):
+         return HunyuanVideoTransformer3DModel()

+     def to(self, device):
+         return self

+
+ class SkyreelsVideoPipeline:
+     @staticmethod
+     def from_pretrained(*args, **kwargs):
+         return SkyreelsVideoPipeline()
+
+     def to(self, device):
          return self

+     def __call__(self, *args, **kwargs):
+         num_frames = kwargs.get("num_frames", 16) # Default to 16 frames
+         height = kwargs.get("height", 512)
+         width = kwargs.get("width", 512)

+         if "image" in kwargs: # I2V
+             image = kwargs["image"]
+             # Convert PIL Image to PyTorch tensor (and normalize to [0, 1])
+             image_tensor = torch.from_numpy(np.array(image)).float() / 255.0
+             image_tensor = image_tensor.permute(2, 0, 1).unsqueeze(0) # (H, W, C) -> (1, C, H, W)

+             # Create video by repeating the image
+             frames = image_tensor.repeat(1, 1, num_frames, 1, 1) # (1, C, T, H, W)
+             frames = frames + torch.randn_like(frames) * 0.05 # Add a little noise
+             # Correct shape: (1, C, T, H, W) - NO PERMUTE HERE
+
+         else: # T2V
+             frames = torch.randn(1, 3, num_frames, height, width) # (1, C, T, H, W) - Correct!
+
+         return type("obj", (object,), {"frames": frames})() # No longer a list!
+
+     def __init__(self):
+         super().__init__()
+         self._modules = OrderedDict()
+         self.vae = self.VAE()
+         self._modules["vae"] = self.vae
+
+     def named_children(self):
+         return self._modules.items()
+
+     class VAE:
+         def enable_tiling(self):
+             pass
+
+
+ def quantize_(*args, **kwargs):
+     return
+
+
+ def float8_weight_only():
+     return
+
+
+ # --- End Dummy Classes ---
+
+
+ class SkyReelsVideoSingleGpuInfer:
+     def _load_model(
+         self, model_id: str, base_model_id: str = "hunyuanvideo-community/HunyuanVideo", quant_model: bool = True
+     ):
+         logger.info(f"load model model_id:{model_id} quan_model:{quant_model}")
+         text_encoder = LlamaModel.from_pretrained(
+             base_model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16
+         ).to("cpu")
+         transformer = HunyuanVideoTransformer3DModel.from_pretrained(
+             model_id, torch_dtype=torch.bfloat16, device="cpu"
+         ).to("cpu")
+
+         if quant_model:
+             quantize_(text_encoder, float8_weight_only())
+             text_encoder.to("cpu")
+             torch.cuda.empty_cache()
+             quantize_(transformer, float8_weight_only())
+             transformer.to("cpu")
+             torch.cuda.empty_cache()
+
+         pipe = SkyreelsVideoPipeline.from_pretrained(
+             base_model_id, transformer=transformer, text_encoder=text_encoder, torch_dtype=torch.bfloat16
+         ).to("cpu")
+         pipe.vae.enable_tiling()
          torch.cuda.empty_cache()
+         return pipe
+
+     def __init__(
+         self,
+         task_type: TaskType,
+         model_id: str,
+         quant_model: bool = True,
+         is_offload: bool = True,
+         offload_config: OffloadConfig = OffloadConfig(),
+         enable_cfg_parallel: bool = True,
+     ):
+         self.task_type = task_type
+         self.model_id = model_id
+         self.quant_model = quant_model
+         self.is_offload = is_offload
+         self.offload_config = offload_config
+         self.enable_cfg_parallel = enable_cfg_parallel
+         self.pipe = None
+         self.is_initialized = False
+         self.gpu_device = None
+
+     def initialize(self):
+         """Initializes the model and moves it to the GPU."""
+         if self.is_initialized:
              return

+         if not torch.cuda.is_available():
+             raise RuntimeError("CUDA is not available. Cannot initialize model.")
+
+         self.gpu_device = "cuda:0"
+         self.pipe = self._load_model(model_id=self.model_id, quant_model=self.quant_model)
+
+         if self.is_offload:
+             pass
          else:
+             self.pipe.to(self.gpu_device)
+
+         if self.offload_config.compiler_transformer:
+             torch._dynamo.config.suppress_errors = True
+             os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
+             os.environ["TORCHINDUCTOR_CACHE_DIR"] = f"{self.offload_config.compiler_cache}"
+             self.pipe.transformer = torch.compile(
+                 self.pipe.transformer, mode="max-autotune-no-cudagraphs", dynamic=True
              )
+         if self.offload_config.compiler_transformer:
+             self.warm_up()
+         self.is_initialized = True
+
+     def warm_up(self):
+         if not self.is_initialized:
+             raise RuntimeError("Model must be initialized before warm-up.")
+
+         init_kwargs = {
+             "prompt": "A woman is dancing in a room",
+             "height": 544,
+             "width": 960,
+             "guidance_scale": 6,
+             "num_inference_steps": 1,
+             "negative_prompt": "bad quality",
+             "num_frames": 16,
+             "generator": torch.Generator(self.gpu_device).manual_seed(42),
+             "embedded_guidance_scale": 1.0,
+         }
+         if self.task_type == TaskType.I2V:
+             init_kwargs["image"] = Image.new("RGB", (544, 960), color="black")
+         self.pipe(**init_kwargs)
+         logger.info("Warm-up complete.")
+
+     def infer(self, **kwargs):
+         """Handles inference requests."""
+         if not self.is_initialized:
+             self.initialize()
+         if "seed" in kwargs:
+             kwargs["generator"] = torch.Generator(self.gpu_device).manual_seed(kwargs["seed"])
+             del kwargs["seed"]
+         assert (self.task_type == TaskType.I2V and "image" in kwargs) or self.task_type == TaskType.T2V
+         result = self.pipe(**kwargs).frames # Return the tensor directly
+         return result
+
+
+ _predictor = None
+
+
+ @spaces.GPU(duration=90)
+ def generate_video(prompt: str, seed: int, image: str = None) -> tuple[str, dict]:
+     """Generates a video based on the given prompt and seed.
+
+     Args:
+         prompt: The text prompt to guide video generation.
+         seed: The random seed for reproducibility.
+         image: Optional path to an image for Image-to-Video.
+
+     Returns:
+         A tuple containing the path to the generated video and the parameters used.
+     """
+     global _predictor
+
+     if seed == -1:
+         random.seed()
+         seed = int(random.randrange(4294967294))
+
+     if image is None:
+         task_type = TaskType.T2V
+         model_id = "Skywork/SkyReels-V1-Hunyuan-T2V"
+         kwargs = {
+             "prompt": prompt,
+             "height": 512,
+             "width": 512,
+             "num_frames": 16,
+             "num_inference_steps": 30,
+             "seed": seed,
+             "guidance_scale": 7.5,
+             "negative_prompt": "bad quality, worst quality",
+         }
+     else:
+         task_type = TaskType.I2V
+         model_id = "Skywork/SkyReels-V1-Hunyuan-I2V"
+         kwargs = {
+             "prompt": prompt,
+             "image": load_image(image),
+             "height": 512,
+             "width": 512,
+             "num_frames": 97,
+             "num_inference_steps": 30,
+             "seed": seed,
+             "guidance_scale": 6.0,
+             "embedded_guidance_scale": 1.0,
+             "negative_prompt": "Aerial view, low quality, bad hands",
+             "cfg_for": False,
+         }

+     if _predictor is None:
+         _predictor = SkyReelsVideoSingleGpuInfer(
+             task_type=task_type,
+             model_id=model_id,
+             quant_model=True,
+             is_offload=True,
+             offload_config=OffloadConfig(
+                 high_cpu_memory=True,
+                 parameters_level=True,
+                 compiler_transformer=False,
+             ),
+         )
+         _predictor.initialize()
+         logger.info("Predictor initialized")
+
+     with torch.no_grad():
+         output = _predictor.infer(**kwargs)
+     '''
+     output = (output.numpy() * 255).astype(np.uint8)
+     # Correct Transpose: (1, C, T, H, W) -> (1, T, H, W, C)
+     output = output.transpose(0, 2, 3, 4, 1)
+     output = output[0] # Remove batch dimension: (T, H, W, C)
+     '''
+
+     save_dir = f"./result"
+     os.makedirs(save_dir, exist_ok=True)
+     video_out_file = f"{save_dir}/{seed}.mp4"
+     print(f"generate video, local path: {video_out_file}")
+     export_to_video(output, video_out_file, fps=24)
+     return video_out_file, kwargs
+
+
+ def create_gradio_interface():
+     with gr.Blocks() as demo:
+         with gr.Row():
+             with gr.Column():
+                 image = gr.Image(label="Upload Image", type="filepath")
+                 prompt = gr.Textbox(label="Input Prompt")
+                 seed = gr.Number(label="Random Seed", value=-1)
+             with gr.Column():
+                 submit_button = gr.Button("Generate Video")
+                 output_video = gr.Video(label="Generated Video")
+                 output_params = gr.Textbox(label="Output Parameters")
+
+         submit_button.click(
+             fn=generate_video,
+             inputs=[prompt, seed, image],
+             outputs=[output_video, output_params],
          )
+     return demo


+ if __name__ == "__main__":
+     demo = create_gradio_interface()
+     demo.queue().launch()