cocktailpeanut commited on
Commit
8a7960d
·
1 Parent(s): 9aab28d
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -5,23 +5,25 @@ from SDXL.diff_pipe import StableDiffusionXLDiffImg2ImgPipeline
5
  from diffusers import DPMSolverMultistepScheduler
6
 
7
  NUM_INFERENCE_STEPS = 50
 
8
  if torch.cuda.is_available():
9
  device = "cuda"
10
  elif torch.backends.mps.is_available():
11
  device = "mps"
 
12
  else:
13
  device = "cpu"
14
  #device = "cuda"
15
 
16
  base = StableDiffusionXLDiffImg2ImgPipeline.from_pretrained(
17
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
18
  )
19
 
20
  refiner = StableDiffusionXLDiffImg2ImgPipeline.from_pretrained(
21
  "stabilityai/stable-diffusion-xl-refiner-1.0",
22
  text_encoder_2=base.text_encoder_2,
23
  vae=base.vae,
24
- torch_dtype=torch.float16,
25
  use_safetensors=True,
26
  variant="fp16",
27
  )
 
5
  from diffusers import DPMSolverMultistepScheduler
6
 
7
  NUM_INFERENCE_STEPS = 50
8
+ dtype = torch.float16
9
  if torch.cuda.is_available():
10
  device = "cuda"
11
  elif torch.backends.mps.is_available():
12
  device = "mps"
13
+ dtype = torch.float32
14
  else:
15
  device = "cpu"
16
  #device = "cuda"
17
 
18
  base = StableDiffusionXLDiffImg2ImgPipeline.from_pretrained(
19
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=dtype, variant="fp16", use_safetensors=True
20
  )
21
 
22
  refiner = StableDiffusionXLDiffImg2ImgPipeline.from_pretrained(
23
  "stabilityai/stable-diffusion-xl-refiner-1.0",
24
  text_encoder_2=base.text_encoder_2,
25
  vae=base.vae,
26
+ torch_dtype=dtype,
27
  use_safetensors=True,
28
  variant="fp16",
29
  )