Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,790 Bytes
846c31c cce4ca7 846c31c cce4ca7 11907e7 cce4ca7 a05a8a4 cce4ca7 846c31c cce4ca7 846c31c cce4ca7 846c31c cce4ca7 846c31c b7b8102 cce4ca7 846c31c cce4ca7 b7b8102 cce4ca7 3b62245 846c31c cce4ca7 7a75277 cce4ca7 7a75277 846c31c b7b8102 cce4ca7 846c31c cce4ca7 846c31c cce4ca7 846c31c cce4ca7 846c31c cce4ca7 846c31c cce4ca7 846c31c cce4ca7 846c31c cce4ca7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import logging
import os
import time
from datetime import timedelta
from typing import Any
from typing import Dict
from typing import Optional

import torch
from diffusers import HunyuanVideoTransformer3DModel
from PIL import Image
from torchao.quantization import float8_weight_only
from torchao.quantization import quantize_
from transformers import LlamaModel

from . import TaskType  # Assuming these are still needed
from .offload import Offload, OffloadConfig
from .pipelines import SkyreelsVideoPipeline
# Module logger: DEBUG level, with one console handler that prints file, line and
# function for every record.
_CONSOLE_HANDLER_NAME = "skyreels_console"

logger = logging.getLogger("SkyreelsVideoInfer")
logger.setLevel(logging.DEBUG)

formatter = logging.Formatter(
    "%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d - %(funcName)s] - %(message)s"
)

# Loggers are process-global singletons. Re-executing this module (reload, or an
# import under a different module name) must not attach a second console handler.
# Otherwise every record would be printed once per attached copy.
console_handler = next(
    (h for h in logger.handlers if h.get_name() == _CONSOLE_HANDLER_NAME),
    None,
)
if console_handler is None:
    console_handler = logging.StreamHandler()
    console_handler.set_name(_CONSOLE_HANDLER_NAME)
    logger.addHandler(console_handler)
console_handler.setLevel(logging.DEBUG)
console_handler.setFormatter(formatter)
class SkyReelsVideoSingleGpuInfer:
    """Single-GPU inference wrapper around ``SkyreelsVideoPipeline``.

    Construction does the following:

    1. Loads the pipeline on CPU, optionally with float8 weight-only quantization.
    2. Places it on the GPU, either directly or through ``Offload``.
    3. Optionally wraps the transformer in ``torch.compile``.
    4. Runs one warm-up generation.
    """

    def _load_model(
        self,
        model_id: str,
        base_model_id: str = "hunyuanvideo-community/HunyuanVideo",
        quant_model: bool = True,
        gpu_device: str = "cuda:0",
    ) -> SkyreelsVideoPipeline:
        """Build the SkyReels pipeline with all modules on CPU.

        Args:
            model_id: Checkpoint passed to
                ``HunyuanVideoTransformer3DModel.from_pretrained``.
            base_model_id: Source of the Llama text encoder (``text_encoder``
                subfolder) and of the remaining pipeline components.
            quant_model: If True, apply torchao ``float8_weight_only``
                quantization (on CPU) to the text encoder and the transformer.
            gpu_device: Only reported in the log message here. The caller is
                responsible for moving the pipeline to a device.

        Returns:
            The pipeline on CPU, in bfloat16, with VAE tiling enabled.
        """
        logger.info(f"load model model_id:{model_id} quan_model:{quant_model} gpu_device:{gpu_device}")
        text_encoder = LlamaModel.from_pretrained(
            base_model_id,
            subfolder="text_encoder",
            torch_dtype=torch.bfloat16,
        ).to("cpu")
        transformer = HunyuanVideoTransformer3DModel.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            device="cpu",
        ).to("cpu").eval()
        if quant_model:
            # Quantization is performed on CPU (device="cpu"); the modules stay there.
            quantize_(text_encoder, float8_weight_only(), device="cpu")
            text_encoder.to("cpu")
            quantize_(transformer, float8_weight_only(), device="cpu")
            transformer.to("cpu")
        pipe = SkyreelsVideoPipeline.from_pretrained(
            base_model_id,
            transformer=transformer,
            text_encoder=text_encoder,
            torch_dtype=torch.bfloat16,
        ).to("cpu")
        pipe.vae.enable_tiling()
        torch.cuda.empty_cache()
        return pipe

    def __init__(
        self,
        task_type: TaskType,
        model_id: str,
        quant_model: bool = True,
        is_offload: bool = True,
        offload_config: Optional[OffloadConfig] = None,
    ):
        """Load the pipeline, place it on ``cuda:0``, optionally compile it, and warm it up.

        Args:
            task_type: ``TaskType.T2V`` or ``TaskType.I2V``. ``warm_up`` adds a
                dummy image for I2V, and ``inference`` checks it.
            model_id: Transformer checkpoint, forwarded to ``_load_model``.
            quant_model: Enable float8 weight-only quantization in ``_load_model``.
            is_offload: If True, device placement is delegated to
                ``Offload.offload``. Otherwise the whole pipeline is moved to
                ``cuda:0``.
            offload_config: Passed to ``Offload.offload``. Its
                ``compiler_transformer`` / ``compiler_cache`` fields also control
                ``torch.compile``. Defaults to a fresh ``OffloadConfig()`` for
                each instance.
        """
        if offload_config is None:
            # Create the default per instance. A default object in the signature
            # would be built once and shared by every instance.
            offload_config = OffloadConfig()
        self.task_type = task_type
        # Disable PyTorch's cuDNN scaled-dot-product-attention backend.
        torch.backends.cuda.enable_cudnn_sdp(False)
        gpu_device = "cuda:0"
        self.pipe: SkyreelsVideoPipeline = self._load_model(
            model_id=model_id, quant_model=quant_model, gpu_device=gpu_device
        )
        if is_offload:
            Offload.offload(
                pipeline=self.pipe,
                config=offload_config,
            )
        else:
            self.pipe.to(gpu_device)
        if offload_config.compiler_transformer:
            torch._dynamo.config.suppress_errors = True
            os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
            # The "_1" suffix marks a cache directory for the 1-GPU configuration.
            os.environ["TORCHINDUCTOR_CACHE_DIR"] = f"{offload_config.compiler_cache}_1"
            self.pipe.transformer = torch.compile(
                self.pipe.transformer,
                mode="max-autotune-no-cudagraphs",
                dynamic=True,
            )
        self.warm_up()

    def warm_up(self):
        """Run one fixed, single-step generation through the pipeline.

        It renders a 512x512, 97-frame video with seed 42 and 1 inference step.
        For I2V a black 512x512 image is supplied. The output is discarded.
        Presumably this front-loads lazy initialisation / compilation before
        real requests arrive; confirm before relying on it.
        """
        init_kwargs = {
            "prompt": "A woman is dancing in a room",
            "height": 512,
            "width": 512,
            "guidance_scale": 6,
            "num_inference_steps": 1,
            "negative_prompt": "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion",
            "num_frames": 97,
            "generator": torch.Generator("cuda").manual_seed(42),
            "embedded_guidance_scale": 1.0,
        }
        if self.task_type == TaskType.I2V:
            init_kwargs["image"] = Image.new("RGB", (512, 512), color="black")
        self.pipe(**init_kwargs)

    def inference(self, kwargs: Dict[str, Any]):
        """Run the pipeline and return ``frames[0]`` of its output.

        Args:
            kwargs: Keyword arguments forwarded to the pipeline. An optional
                ``seed`` entry is consumed and replaced by a CUDA
                ``torch.Generator`` seeded with that value. For I2V an
                ``image`` entry is required. The caller's dict is NOT modified,
                so it can be reused for later calls.

        Returns:
            The first element of the pipeline output's ``frames``.

        Raises:
            AssertionError: If the task is I2V and no ``image`` is given, or if
                the task type is neither I2V nor T2V (the check is an ``assert``
                and is skipped under ``python -O``).
        """
        logger.info(f"kwargs: {kwargs}")
        # Shallow copy: swapping "seed" for "generator" must not leak back to the
        # caller. Otherwise a reused dict would lose its seed and carry a
        # consumed generator.
        call_kwargs = dict(kwargs)
        if "seed" in call_kwargs:
            call_kwargs["generator"] = torch.Generator("cuda").manual_seed(call_kwargs.pop("seed"))
        start_time = time.perf_counter()
        assert (self.task_type == TaskType.I2V and "image" in call_kwargs) or self.task_type == TaskType.T2V
        out = self.pipe(**call_kwargs).frames[0]
        logger.info(f"inference time: {time.perf_counter() - start_time}")
        return out