Upload pipeline.py
Browse files- pipeline.py +3 -3
pipeline.py
CHANGED
@@ -102,9 +102,9 @@ class SuperDiffSDXLPipeline(DiffusionPipeline, ConfigMixin):
|
|
102 |
dtype = torch.float16
|
103 |
|
104 |
vae.to(device)
|
105 |
-
unet.to(device)
|
106 |
-
text_encoder.to(device)
|
107 |
-
text_encoder_2.to(device)
|
108 |
|
109 |
self.register_modules(
|
110 |
unet=unet,
|
|
|
102 |
dtype = torch.float16
|
103 |
|
104 |
vae.to(device)
|
105 |
+
unet.to(device, dtype=dtype)
|
106 |
+
text_encoder.to(device, dtype=dtype)
|
107 |
+
text_encoder_2.to(device, dtype=dtype)
|
108 |
|
109 |
self.register_modules(
|
110 |
unet=unet,
|