import logging
import os
import time
from typing import Any, Dict
import torch
from diffusers import HunyuanVideoTransformer3DModel
from PIL import Image
from torchao.quantization import float8_weight_only
from torchao.quantization import quantize_
from transformers import LlamaModel
from . import TaskType
from .offload import Offload, OffloadConfig
from .pipelines import SkyreelsVideoPipeline
logger = logging.getLogger("SkyreelsVideoInfer")
logger.setLevel(logging.DEBUG)
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.DEBUG)
formatter = logging.Formatter(
    "%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d - %(funcName)s] - %(message)s"
)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
class SkyReelsVideoSingleGpuInfer:
def _load_model(
self,
model_id: str,
base_model_id: str = "hunyuanvideo-community/HunyuanVideo",
quant_model: bool = True,
gpu_device: str = "cuda:0",
) -> SkyreelsVideoPipeline:
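        """Load the text encoder, transformer, and pipeline on CPU, optionally float8-quantized."""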
logger.info(f"load model model_id:{model_id} quan_model:{quant_model} gpu_device:{gpu_device}")
text_encoder = LlamaModel.from_pretrained(
base_model_id,
subfolder="text_encoder",
torch_dtype=torch.bfloat16,
).to("cpu")
        transformer = HunyuanVideoTransformer3DModel.from_pretrained(
            model_id,  # Transformer weights live at the repo root of this checkpoint.
            torch_dtype=torch.bfloat16,
        ).to("cpu")
if quant_model:
quantize_(text_encoder, float8_weight_only(), device="cpu")
text_encoder.to("cpu")
torch.cuda.empty_cache()
quantize_(transformer, float8_weight_only(), device="cpu")
transformer.to("cpu")
torch.cuda.empty_cache()
pipe = SkyreelsVideoPipeline.from_pretrained(
base_model_id,
transformer=transformer,
text_encoder=text_encoder,
torch_dtype=torch.bfloat16,
).to("cpu")
pipe.vae.enable_tiling()
torch.cuda.empty_cache()
return pipe
def __init__(
self,
task_type: TaskType,
model_id: str,
quant_model: bool = True,
is_offload: bool = True,
offload_config: OffloadConfig = OffloadConfig(),
):
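        """Build the pipeline on a single GPU, with optional offloading and transformer compilation."""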
self.task_type = task_type
        # Single-GPU inference: LOCAL_RANK is no longer needed.
        torch.cuda.set_device(0)  # Be explicit about which device we run on.
        torch.backends.cuda.enable_cudnn_sdp(False)  # Avoid cuDNN SDPA issues with this pipeline.
gpu_device = "cuda:0"
self.pipe: SkyreelsVideoPipeline = self._load_model(
model_id=model_id, quant_model=quant_model, gpu_device=gpu_device
)
if is_offload:
Offload.offload(
pipeline=self.pipe,
config=offload_config,
)
else:
self.pipe.to(gpu_device)
        if offload_config.compiler_transformer:
            torch._dynamo.config.suppress_errors = True
            os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
            os.environ["TORCHINDUCTOR_CACHE_DIR"] = f"{offload_config.compiler_cache}_1"  # "_1" suffix marks the single-GPU cache.
self.pipe.transformer = torch.compile(
self.pipe.transformer,
mode="max-autotune-no-cudagraphs",
dynamic=True,
)
self.warm_up()
def warm_up(self):
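        """Run one small generation pass so compilation and autotuning costs are paid up front."""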
init_kwargs = {
"prompt": "A woman is dancing in a room",
"height": 512,
"width": 512,
"guidance_scale": 6,
"num_inference_steps": 1,
"negative_prompt": "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion",
"num_frames": 97,
"generator": torch.Generator("cuda").manual_seed(42),
"embedded_guidance_scale": 1.0,
}
if self.task_type == TaskType.I2V:
init_kwargs["image"] = Image.new("RGB", (512, 512), color="black")
self.pipe(**init_kwargs)
def inference(self, kwargs: Dict[str, Any]):
logger.info(f"kwargs: {kwargs}")
if "seed" in kwargs:
kwargs["generator"] = torch.Generator("cuda").manual_seed(kwargs["seed"])
del kwargs["seed"]
start_time = time.time()
        assert (
            self.task_type == TaskType.I2V and "image" in kwargs
        ) or self.task_type == TaskType.T2V, "I2V inference requires an 'image' in kwargs"
out = self.pipe(**kwargs).frames[0]
logger.info(f"inference time: {time.time() - start_time}")
        return out
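

if __name__ == "__main__":
    # Minimal usage sketch. Assumptions: the "Skywork/SkyReels-V1-Hunyuan-T2V"
    # checkpoint id and the default OffloadConfig; adjust both to your setup.
    # Because this module uses relative imports, run it as `python -m <package>.<module>`.
    from diffusers.utils import export_to_video

    infer = SkyReelsVideoSingleGpuInfer(
        task_type=TaskType.T2V,
        model_id="Skywork/SkyReels-V1-Hunyuan-T2V",
        quant_model=True,
        is_offload=True,
    )
    frames = infer.inference(
        {
            "prompt": "A cat walks on the grass, realistic style",
            "height": 512,
            "width": 512,
            "num_frames": 97,
            "num_inference_steps": 30,
            "guidance_scale": 6.0,
            "embedded_guidance_scale": 1.0,
            "seed": 42,
        }
    )
    export_to_video(frames, "output.mp4", fps=24)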