import torch
from diffusers import (
    DiffusionPipeline,
    AutoencoderKL,
)
from diffusers.schedulers import *


def load_common():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    
    # VAE and refiner (fp16-safe SDXL VAE from madebyollin)
    sdxl_vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to(device)
    refiner = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-refiner-1.0", vae=sdxl_vae, torch_dtype=torch.float16, use_safetensors=True, variant="fp16")
    # enable_model_cpu_offload() manages device placement itself, so the
    # pipeline is not moved with .to(device) beforehand.
    refiner.enable_model_cpu_offload()
    
    return refiner, sdxl_vae

refiner, sdxl_vae = load_common()
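

# Hedged usage sketch (not part of the original file): shows how the refiner
# returned by load_common() might be applied to an existing base-stage image.
# The input path "base_output.png", the prompt, and the strength/step values
# are illustrative assumptions, not taken from the source.
if __name__ == "__main__":
    from PIL import Image

    base_image = Image.open("base_output.png").convert("RGB")
    refined = refiner(
        prompt="a photo of an astronaut riding a horse",
        image=base_image,
        strength=0.3,            # how strongly the refiner reworks the input
        num_inference_steps=30,  # illustrative step count
    ).images[0]
    refined.save("refined_output.png")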