LPX55 commited on
Commit
77fb855
·
verified ·
1 Parent(s): b8b83a5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -0
app.py CHANGED
@@ -9,6 +9,7 @@ from typing import List, Optional, Tuple, Union
9
  from PIL import Image
10
  import io
11
  from io import BytesIO
 
12
  from diffusers import HunyuanVideoPipeline, FlowMatchEulerDiscreteScheduler
13
  from diffusers.models.transformers.transformer_hunyuan_video import HunyuanVideoPatchEmbed, HunyuanVideoTransformer3DModel
14
  from diffusers.utils import export_to_video
@@ -39,7 +40,20 @@ video_transforms = transforms.Compose(
39
  transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
40
  ]
41
  )
 
 
 
 
 
 
 
42
 
 
 
 
 
 
 
43
  model_id = "hunyuanvideo-community/HunyuanVideo"
44
  lora_path = hf_hub_download("dashtoon/hunyuan-video-keyframe-control-lora", "i2v.sft")
45
  # lora_path = "./cache/i2v.sft"
 
9
  from PIL import Image
10
  import io
11
  from io import BytesIO
12
+ from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
13
  from diffusers import HunyuanVideoPipeline, FlowMatchEulerDiscreteScheduler
14
  from diffusers.models.transformers.transformer_hunyuan_video import HunyuanVideoPatchEmbed, HunyuanVideoTransformer3DModel
15
  from diffusers.utils import export_to_video
 
40
  transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
41
  ]
42
  )
43
+ quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
44
+ transformer_8bit = HunyuanVideoTransformer3DModel.from_pretrained(
45
+ "hunyuanvideo-community/HunyuanVideo",
46
+ subfolder="transformer",
47
+ quantization_config=quant_config,
48
+ torch_dtype=torch.bfloat16,
49
+ )
50
 
51
+ pipeline = HunyuanVideoPipeline.from_pretrained(
52
+ "hunyuanvideo-community/HunyuanVideo",
53
+ transformer=transformer_8bit,
54
+ torch_dtype=torch.float16,
55
+ device_map="balanced",
56
+ )
57
  model_id = "hunyuanvideo-community/HunyuanVideo"
58
  lora_path = hf_hub_download("dashtoon/hunyuan-video-keyframe-control-lora", "i2v.sft")
59
  # lora_path = "./cache/i2v.sft"