File size: 4,781 Bytes
846c31c
3262ed7
846c31c
 
 
 
3262ed7
11907e7
 
 
 
 
3262ed7
 
3587696
 
3262ed7
846c31c
3262ed7
846c31c
 
 
 
 
 
 
 
 
3262ed7
846c31c
 
 
 
 
3262ed7
 
 
846c31c
 
 
 
275ebd1
846c31c
 
3262ed7
846c31c
275ebd1
 
846c31c
206d706
275ebd1
 
206d706
275ebd1
 
846c31c
 
 
 
 
275ebd1
846c31c
275ebd1
846c31c
 
3262ed7
 
 
 
 
 
 
 
 
3587696
3af83e0
3587696
3ee15e0
3262ed7
 
 
 
 
 
846c31c
 
3262ed7
 
 
 
 
 
 
 
3587696
3262ed7
 
 
 
846c31c
3262ed7
846c31c
3262ed7
 
 
134291d
 
3262ed7
 
 
 
 
 
 
846c31c
134291d
3262ed7
 
 
3587696
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import logging
import os
import time
from datetime import timedelta
from typing import Any
from typing import Dict
from typing import Optional

import torch
from diffusers import HunyuanVideoTransformer3DModel
from PIL import Image
from torchao.quantization import float8_weight_only
from torchao.quantization import quantize_
from transformers import LlamaModel

from . import TaskType  # Assuming these are still needed
from .offload import Offload, OffloadConfig
from .pipelines import SkyreelsVideoPipeline

# Module logger: DEBUG level, printing to a console stream with file/line/function context.
logger = logging.getLogger("SkyreelsVideoInfer")
logger.setLevel(logging.DEBUG)
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.DEBUG)
formatter = logging.Formatter(
    "%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d - %(funcName)s] - %(message)s"
)
console_handler.setFormatter(formatter)
# getLogger returns the same process-wide object for this name, so attach the
# console handler only once: a module reload (or another module configuring the
# same logger name) would otherwise stack handlers and duplicate every record.
if not logger.handlers:
    logger.addHandler(console_handler)

class SkyReelsVideoSingleGpuInfer:
    """Single-GPU inference wrapper around ``SkyreelsVideoPipeline``.

    Construction loads the text encoder and transformer on the CPU (optionally
    float8 weight-only quantized), assembles the pipeline, and then either hands
    it to ``Offload.offload`` or moves it onto ``cuda:0``. When the offload
    config enables ``compiler_transformer``, the transformer is wrapped in
    ``torch.compile`` and a one-step warm-up call is run.
    """

    def _load_model(
        self,
        model_id: str,
        base_model_id: str = "hunyuanvideo-community/HunyuanVideo",
        quant_model: bool = True,
        gpu_device: str = "cuda:0",
    ) -> SkyreelsVideoPipeline:
        """Build the pipeline with all components resident on the CPU.

        Args:
            model_id: Path or hub id of the transformer weights. It is loaded
                without a subfolder, so it must point directly at them.
            base_model_id: Path or hub id supplying the text encoder (from its
                ``text_encoder`` subfolder) and the remaining pipeline parts.
            quant_model: If true, apply float8 weight-only quantization to the
                text encoder and the transformer, with ``device="cpu"``.
            gpu_device: Only reported in the log message; device placement is
                done by the caller.

        Returns:
            The assembled bfloat16 pipeline on the CPU, with VAE tiling enabled.
        """
        logger.info(f"load model model_id:{model_id} quan_model:{quant_model} gpu_device:{gpu_device}")
        text_encoder = LlamaModel.from_pretrained(
            base_model_id,
            subfolder="text_encoder",
            torch_dtype=torch.bfloat16,
        ).to("cpu")
        transformer = HunyuanVideoTransformer3DModel.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            device="cpu",
        ).to("cpu")
        if quant_model:
            # Quantize each module on the CPU, keep it there, and release any
            # cached CUDA memory after each one.
            for module in (text_encoder, transformer):
                quantize_(module, float8_weight_only(), device="cpu")
                module.to("cpu")
                torch.cuda.empty_cache()
        pipe = SkyreelsVideoPipeline.from_pretrained(
            base_model_id,
            transformer=transformer,
            text_encoder=text_encoder,
            torch_dtype=torch.bfloat16,
        ).to("cpu")
        pipe.vae.enable_tiling()
        torch.cuda.empty_cache()
        return pipe

    def __init__(
        self,
        task_type: TaskType,
        model_id: str,
        quant_model: bool = True,
        is_offload: bool = True,
        offload_config: Optional[OffloadConfig] = None,
    ):
        """Load the pipeline and place it for single-GPU inference.

        Args:
            task_type: Which task ``inference`` serves (``TaskType.T2V`` or
                ``TaskType.I2V``).
            model_id: Transformer weights location, forwarded to ``_load_model``.
            quant_model: Whether to float8-quantize the text encoder and transformer.
            is_offload: If true, delegate placement to ``Offload.offload``;
                otherwise move the whole pipeline to ``cuda:0``.
            offload_config: Offload/compile settings. Defaults to a fresh
                ``OffloadConfig()`` per instance.
        """
        if offload_config is None:
            # Built per call: a default-argument instance would be shared by
            # every object of this class.
            offload_config = OffloadConfig()
        self.task_type = task_type
        # Disable the cuDNN backend for scaled-dot-product attention.
        torch.backends.cuda.enable_cudnn_sdp(False)
        gpu_device = "cuda:0"

        self.pipe: SkyreelsVideoPipeline = self._load_model(
            model_id=model_id, quant_model=quant_model, gpu_device=gpu_device
        )

        if is_offload:
            Offload.offload(pipeline=self.pipe, config=offload_config)
        else:
            self.pipe.to(gpu_device)

        if offload_config.compiler_transformer:
            self._compile_transformer(offload_config)
            # Pay the compilation cost now rather than on the first real request.
            self.warm_up()

    def _compile_transformer(self, offload_config: OffloadConfig) -> None:
        """Wrap ``self.pipe.transformer`` in ``torch.compile`` with an on-disk Inductor cache."""
        # Let Dynamo fall back to eager execution instead of raising on compile failures.
        torch._dynamo.config.suppress_errors = True
        os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
        # The "_1" suffix marks a cache directory built for a single GPU.
        os.environ["TORCHINDUCTOR_CACHE_DIR"] = f"{offload_config.compiler_cache}_1"
        self.pipe.transformer = torch.compile(
            self.pipe.transformer,
            mode="max-autotune-no-cudagraphs",
            dynamic=True,
        )

    def warm_up(self):
        """Run one denoising step on a fixed 512x512, 97-frame dummy request.

        For ``TaskType.I2V`` a black 512x512 RGB image is supplied as the
        conditioning image. The output is discarded.
        """
        init_kwargs = {
            "prompt": "A woman is dancing in a room",
            "height": 512,
            "width": 512,
            "guidance_scale": 6,
            "num_inference_steps": 1,
            "negative_prompt": "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion",
            "num_frames": 97,
            "generator": torch.Generator("cuda").manual_seed(42),
            "embedded_guidance_scale": 1.0,
        }
        if self.task_type == TaskType.I2V:
            init_kwargs["image"] = Image.new("RGB", (512, 512), color="black")
        self.pipe(**init_kwargs)

    def inference(self, kwargs: Dict[str, Any]):
        """Run the pipeline once and return the first generated video.

        Args:
            kwargs: Keyword arguments for the pipeline call. An optional
                ``"seed"`` entry is replaced by a CUDA ``torch.Generator``
                seeded with it. ``TaskType.I2V`` requires an ``"image"`` entry.
                The caller's dict is left unmodified.

        Returns:
            ``frames[0]`` of the pipeline output.

        Raises:
            ValueError: If the task type is I2V without an ``"image"`` entry,
                or is neither T2V nor I2V.
        """
        logger.info(f"kwargs: {kwargs}")
        # Work on a shallow copy so a reused dict keeps its "seed" across calls.
        call_kwargs = dict(kwargs)
        if "seed" in call_kwargs:
            call_kwargs["generator"] = torch.Generator("cuda").manual_seed(call_kwargs.pop("seed"))
        # Explicit check (not `assert`) so validation survives `python -O`.
        is_t2v = self.task_type == TaskType.T2V
        is_i2v_with_image = self.task_type == TaskType.I2V and "image" in call_kwargs
        if not (is_t2v or is_i2v_with_image):
            raise ValueError(
                f"inference requires task_type T2V, or I2V with an 'image' argument; got {self.task_type}"
            )
        start_time = time.perf_counter()
        out = self.pipe(**call_kwargs).frames[0]
        logger.info(f"inference time: {time.perf_counter() - start_time}")
        return out