import logging
import os
import time
from typing import Any
from typing import Dict

import torch
from diffusers import HunyuanVideoTransformer3DModel
from PIL import Image
from torchao.quantization import float8_weight_only
from torchao.quantization import quantize_
from transformers import LlamaModel

from . import TaskType
from .offload import Offload, OffloadConfig
from .pipelines import SkyreelsVideoPipeline

logger = logging.getLogger("SkyreelsVideoInfer")
logger.setLevel(logging.DEBUG)
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.DEBUG)
formatter = logging.Formatter(
    "%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d - %(funcName)s] - %(message)s"
)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)

class SkyReelsVideoSingleGpuInfer:
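    """Single-GPU inference wrapper around SkyreelsVideoPipeline.

    Handles model loading, optional float8 weight quantization, optional CPU
    offloading, and optional torch.compile of the transformer.
    """
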
    def _load_model(
        self,
        model_id: str,
        base_model_id: str = "hunyuanvideo-community/HunyuanVideo",
        quant_model: bool = True,
        gpu_device: str = "cuda:0",
    ) -> SkyreelsVideoPipeline:
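        """Load text encoder, transformer, and pipeline, keeping everything on the CPU.

        When quant_model is set, weights are quantized to float8 on the CPU
        first, so the full bf16 models never need to be resident on the GPU.
        Final device placement (offload vs. a plain move to gpu_device) is
        decided by the caller in __init__.
        """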
        logger.info(f"load model model_id:{model_id} quan_model:{quant_model} gpu_device:{gpu_device}")
        text_encoder = LlamaModel.from_pretrained(
            base_model_id,
            subfolder="text_encoder",
            torch_dtype=torch.bfloat16,
        ).to("cpu")
        transformer = HunyuanVideoTransformer3DModel.from_pretrained(
            model_id,
            # subfolder="transformer",
            torch_dtype=torch.bfloat16,
        ).to("cpu").eval()
        if quant_model:
            # Quantize both models to float8 weights while still on the CPU,
            # so peak memory stays low before anything touches the GPU.
            quantize_(text_encoder, float8_weight_only(), device="cpu")
            text_encoder.to("cpu")
            quantize_(transformer, float8_weight_only(), device="cpu")
            transformer.to("cpu")
        pipe = SkyreelsVideoPipeline.from_pretrained(
            base_model_id,
            transformer=transformer,
            text_encoder=text_encoder,
            torch_dtype=torch.bfloat16,
        ).to("cpu")
        pipe.vae.enable_tiling()
        torch.cuda.empty_cache()
        return pipe

    def __init__(
        self,
        task_type: TaskType,
        model_id: str,
        quant_model: bool = True,
        is_offload: bool = True,
        offload_config: OffloadConfig = OffloadConfig(),
    ):
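        """Load the pipeline, then either offload it or move it to cuda:0.

        If offload_config.compiler_transformer is set, the transformer is
        compiled with torch.compile and a one-step warm-up run populates the
        inductor cache.
        """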
        self.task_type = task_type
        # Single-GPU path: the LOCAL_RANK / torch.cuda.set_device bookkeeping
        # from the multi-GPU version is no longer needed.
        torch.backends.cuda.enable_cudnn_sdp(False)  # keep cuDNN SDPA disabled, as in the multi-GPU version
        gpu_device = "cuda:0"

        self.pipe: SkyreelsVideoPipeline = self._load_model(
            model_id=model_id, quant_model=quant_model, gpu_device=gpu_device
        )

        if is_offload:
            Offload.offload(
                pipeline=self.pipe,
                config=offload_config,
            )
        else:
            self.pipe.to(gpu_device)

        if offload_config.compiler_transformer:
            torch._dynamo.config.suppress_errors = True
            os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
            os.environ["TORCHINDUCTOR_CACHE_DIR"] = f"{offload_config.compiler_cache}_1" #_1 represents 1 gpu.
            self.pipe.transformer = torch.compile(
                self.pipe.transformer,
                mode="max-autotune-no-cudagraphs",
                dynamic=True,
            )
            self.warm_up()

    def warm_up(self):
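        """Run one single-step generation so the compiled transformer gets traced.

        Uses a fixed prompt and seed; for I2V a black placeholder image is
        supplied.
        """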
        init_kwargs = {
            "prompt": "A woman is dancing in a room",
            "height": 512,
            "width": 512,
            "guidance_scale": 6,
            "num_inference_steps": 1,
            "negative_prompt": "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion",
            "num_frames": 97,
            "generator": torch.Generator("cuda").manual_seed(42),
            "embedded_guidance_scale": 1.0,
        }
        if self.task_type == TaskType.I2V:
            init_kwargs["image"] = Image.new("RGB", (512, 512), color="black")
        self.pipe(**init_kwargs)

    def inference(self, kwargs: Dict[str, Any]):
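        """Run the pipeline with a kwargs dict and return the decoded frames.

        A "seed" entry, if present, is replaced by a seeded CUDA generator
        before the dict is forwarded to the pipeline.
        """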
        logger.info(f"kwargs: {kwargs}")
        if "seed" in kwargs:
            kwargs["generator"] = torch.Generator("cuda").manual_seed(kwargs["seed"])
            del kwargs["seed"]
        start_time = time.time()
        assert (self.task_type == TaskType.I2V and "image" in kwargs) or self.task_type == TaskType.T2V
        out = self.pipe(**kwargs).frames[0]
        logger.info(f"inference time: {time.time() - start_time}")
        return out
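
# A minimal usage sketch (not part of the original module): constructs the
# single-GPU inferencer and renders one clip. The checkpoint id, prompt, and
# output path below are illustrative placeholders, not values defined by this
# repository. Because the module uses relative imports, run it as a package
# module (python -m <package>.<module>) rather than as a bare script.
if __name__ == "__main__":
    from diffusers.utils import export_to_video  # standard diffusers helper

    infer = SkyReelsVideoSingleGpuInfer(
        task_type=TaskType.T2V,
        model_id="path/or/hub-id/of/skyreels-transformer",  # hypothetical placeholder
        quant_model=True,
        is_offload=True,
        offload_config=OffloadConfig(),
    )
    frames = infer.inference(
        {
            "prompt": "A sailboat glides across a calm bay at sunset",
            "height": 512,
            "width": 512,
            "num_frames": 97,
            "num_inference_steps": 30,
            "guidance_scale": 6.0,
            "embedded_guidance_scale": 1.0,
            "seed": 42,
        }
    )
    export_to_video(frames, "output.mp4", fps=24)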