Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -41,6 +41,9 @@ def load_dit_model(dit_size):
|
|
41 |
# Load checkpoint
|
42 |
checkpoint = torch.load(ckpt_path, map_location="cpu")
|
43 |
model.load_state_dict(checkpoint["model_state_dict"])
|
|
|
|
|
|
|
44 |
|
45 |
return model
|
46 |
|
@@ -98,7 +101,7 @@ class DiffusionSampler:
|
|
98 |
s_null = torch.tensor([model.num_styles] * num_samples, device=self.device, dtype=torch.long)
|
99 |
|
100 |
# Start with random latents
|
101 |
-
latents = torch.randn((num_samples, 4, 32, 32), device=self.device)
|
102 |
|
103 |
# Use classifier-free guidance for better quality
|
104 |
cfg_scale = 2.5
|
@@ -135,7 +138,9 @@ class DiffusionSampler:
|
|
135 |
|
136 |
# Decode latents to images
|
137 |
self.load_vae()
|
138 |
-
|
|
|
|
|
139 |
latents = latents.to(self.device)
|
140 |
|
141 |
progress(0.95, desc="Decoding images...")
|
@@ -167,9 +172,9 @@ def generate_random_seed():
|
|
167 |
return random.randint(0, 2**32 - 1)
|
168 |
|
169 |
MODEL_SAMPLE_LIMITS = {
|
170 |
-
"S": {"min":1, "max":
|
171 |
-
"B": {"min":1, "max":
|
172 |
-
"L": {"min":1, "max":
|
173 |
}
|
174 |
|
175 |
def update_sample_slider(dit_size):
|
|
|
41 |
# Load checkpoint
|
42 |
checkpoint = torch.load(ckpt_path, map_location="cpu")
|
43 |
model.load_state_dict(checkpoint["model_state_dict"])
|
44 |
+
|
45 |
+
# Use half precision to speed up sampling
|
46 |
+
model = model.half()
|
47 |
|
48 |
return model
|
49 |
|
|
|
101 |
s_null = torch.tensor([model.num_styles] * num_samples, device=self.device, dtype=torch.long)
|
102 |
|
103 |
# Start with random latents
|
104 |
+
latents = torch.randn((num_samples, 4, 32, 32), device=self.device, dtype=torch.float16)
|
105 |
|
106 |
# Use classifier-free guidance for better quality
|
107 |
cfg_scale = 2.5
|
|
|
138 |
|
139 |
# Decode latents to images
|
140 |
self.load_vae()
|
141 |
+
|
142 |
+
# Convert latents back to float32 for vae decoding
|
143 |
+
latents = latents.to(dtype=torch.float16) / self.vae.config.scaling_factor
|
144 |
latents = latents.to(self.device)
|
145 |
|
146 |
progress(0.95, desc="Decoding images...")
|
|
|
172 |
return random.randint(0, 2**32 - 1)
|
173 |
|
174 |
MODEL_SAMPLE_LIMITS = {
|
175 |
+
"S": {"min":1, "max": 16, "default": 4},
|
176 |
+
"B": {"min":1, "max": 12, "default": 4},
|
177 |
+
"L": {"min":1, "max": 4, "default": 1}
|
178 |
}
|
179 |
|
180 |
def update_sample_slider(dit_size):
|