kaupane committed on
Commit
1a8413a
·
verified ·
1 Parent(s): c95f19f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -5
app.py CHANGED
@@ -41,6 +41,9 @@ def load_dit_model(dit_size):
41
  # Load checkpoint
42
  checkpoint = torch.load(ckpt_path, map_location="cpu")
43
  model.load_state_dict(checkpoint["model_state_dict"])
 
 
 
44
 
45
  return model
46
 
@@ -98,7 +101,7 @@ class DiffusionSampler:
98
  s_null = torch.tensor([model.num_styles] * num_samples, device=self.device, dtype=torch.long)
99
 
100
  # Start with random latents
101
- latents = torch.randn((num_samples, 4, 32, 32), device=self.device)
102
 
103
  # Use classifier-free guidance for better quality
104
  cfg_scale = 2.5
@@ -135,7 +138,9 @@ class DiffusionSampler:
135
 
136
  # Decode latents to images
137
  self.load_vae()
138
- latents = latents / self.vae.config.scaling_factor
 
 
139
  latents = latents.to(self.device)
140
 
141
  progress(0.95, desc="Decoding images...")
@@ -167,9 +172,9 @@ def generate_random_seed():
167
  return random.randint(0, 2**32 - 1)
168
 
169
  MODEL_SAMPLE_LIMITS = {
170
- "S": {"min":1, "max": 18, "default": 4},
171
- "B": {"min":1, "max": 9, "default": 4},
172
- "L": {"min":1, "max": 3, "default": 1}
173
  }
174
 
175
  def update_sample_slider(dit_size):
 
41
  # Load checkpoint
42
  checkpoint = torch.load(ckpt_path, map_location="cpu")
43
  model.load_state_dict(checkpoint["model_state_dict"])
44
+
45
+ # Use half precision to speed up sampling
46
+ model = model.half()
47
 
48
  return model
49
 
 
101
  s_null = torch.tensor([model.num_styles] * num_samples, device=self.device, dtype=torch.long)
102
 
103
  # Start with random latents
104
+ latents = torch.randn((num_samples, 4, 32, 32), device=self.device, dtype=torch.float16)
105
 
106
  # Use classifier-free guidance for better quality
107
  cfg_scale = 2.5
 
138
 
139
  # Decode latents to images
140
  self.load_vae()
141
+
142
+ # Convert latents back to float32 for vae decoding
143
+ latents = latents.to(dtype=torch.float32) / self.vae.config.scaling_factor
144
  latents = latents.to(self.device)
145
 
146
  progress(0.95, desc="Decoding images...")
 
172
  return random.randint(0, 2**32 - 1)
173
 
174
  MODEL_SAMPLE_LIMITS = {
175
+ "S": {"min":1, "max": 16, "default": 4},
176
+ "B": {"min":1, "max": 12, "default": 4},
177
+ "L": {"min":1, "max": 4, "default": 1}
178
  }
179
 
180
  def update_sample_slider(dit_size):