1inkusFace commited on
Commit
b7b8102
·
verified ·
1 Parent(s): cce4ca7

Update skyreelsinfer/skyreels_video_infer.py

Browse files
skyreelsinfer/skyreels_video_infer.py CHANGED
@@ -38,12 +38,12 @@ class SkyReelsVideoSingleGpuInfer:
38
  text_encoder = LlamaModel.from_pretrained(
39
  base_model_id,
40
  subfolder="text_encoder",
41
- torch_dtype=torch.float16,
42
  ).to("cpu")
43
  transformer = HunyuanVideoTransformer3DModel.from_pretrained(
44
  model_id,
45
  # subfolder="transformer",
46
- torch_dtype=torch.float16,
47
  device="cpu",
48
  ).to("cpu")
49
  if quant_model:
@@ -57,7 +57,7 @@ class SkyReelsVideoSingleGpuInfer:
57
  base_model_id,
58
  transformer=transformer,
59
  text_encoder=text_encoder,
60
- torch_dtype=torch.float16,
61
  ).to("cpu")
62
  pipe.vae.enable_tiling()
63
  torch.cuda.empty_cache()
 
38
  text_encoder = LlamaModel.from_pretrained(
39
  base_model_id,
40
  subfolder="text_encoder",
41
+ torch_dtype=torch.bfloat16,
42
  ).to("cpu")
43
  transformer = HunyuanVideoTransformer3DModel.from_pretrained(
44
  model_id,
45
  # subfolder="transformer",
46
+ torch_dtype=torch.bfloat16,
47
  device="cpu",
48
  ).to("cpu")
49
  if quant_model:
 
57
  base_model_id,
58
  transformer=transformer,
59
  text_encoder=text_encoder,
60
+ torch_dtype=torch.bfloat16,
61
  ).to("cpu")
62
  pipe.vae.enable_tiling()
63
  torch.cuda.empty_cache()