kaupane commited on
Commit
9bb4b8c
·
verified ·
1 Parent(s): d7b421a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -29
app.py CHANGED
@@ -18,6 +18,9 @@ latent_scale_factor = 0.18215 # Same as in DiTTrainer
18
  # For tracking progress in UI
19
  global_progress = 0
20
 
 
 
 
21
  def load_dit_model(dit_size):
22
  """Load DiT model of specified size"""
23
  #ckpt_path = f"./ckpts/DiT_{dit_size}_final.pth"
@@ -41,15 +44,13 @@ def load_dit_model(dit_size):
41
  # Load checkpoint
42
  checkpoint = torch.load(ckpt_path, map_location="cpu")
43
  model.load_state_dict(checkpoint["model_state_dict"])
44
-
45
- # Use half precision to speed up sampling
46
- model = model.half()
47
 
48
  return model
49
 
50
  class DiffusionSampler:
51
- def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu"):
52
  self.device = device
 
53
  self.vae = None
54
 
55
  # Pre-compute diffusion parameters
@@ -68,11 +69,20 @@ class DiffusionSampler:
68
  self.sqrt_recip_alphas = self.sqrt_recip_alphas.to(self.device)
69
  self.betas = self.betas.to(self.device)
70
  self.posterior_variance = self.posterior_variance.to(self.device)
71
-
 
 
 
 
 
 
 
 
72
  def load_vae(self):
73
  """Load VAE model (done lazily to save memory until needed)"""
74
  if self.vae is None:
75
  self.vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").to(self.device)
 
76
  self.vae.eval()
77
 
78
  @spaces.GPU(duration=120)
@@ -91,7 +101,10 @@ class DiffusionSampler:
91
  torch.cuda.manual_seed(seed)
92
  torch.cuda.manual_seed_all(seed)
93
 
 
94
  model.to(self.device)
 
 
95
  model.eval()
96
 
97
  # Convert genre and style to tensors
@@ -100,15 +113,10 @@ class DiffusionSampler:
100
  g_null = torch.tensor([model.num_genres] * num_samples, device=self.device, dtype=torch.long)
101
  s_null = torch.tensor([model.num_styles] * num_samples, device=self.device, dtype=torch.long)
102
 
103
- # Start with random latents
104
- latents = torch.randn((num_samples, 4, 32, 32), device=self.device, dtype=torch.float16)
105
-
106
- # Convert diffusion parameters to half precision for compatibility
107
- sqrt_alphas_cumprod_half = self.sqrt_alphas_cumprod.half()
108
- sqrt_one_minus_alpha_cumprod_half = self.sqrt_one_minus_alpha_cumprod.half()
109
- sqrt_recip_alphas_half = self.sqrt_recip_alphas.half()
110
- betas_half = self.betas.half()
111
- posterior_variance_half = self.posterior_variance.half()
112
 
113
  # Use classifier-free guidance for better quality
114
  cfg_scale = 2.5
@@ -125,10 +133,11 @@ class DiffusionSampler:
125
 
126
  t = torch.full((num_samples,), t_val, device=self.device, dtype=torch.long)
127
 
128
- sqrt_recip_alphas_t = sqrt_recip_alphas_half[t].view(-1, 1, 1, 1)
129
- sqrt_one_minus_alphas_cumprod_t = sqrt_one_minus_alpha_cumprod_half[t].view(-1, 1, 1, 1)
130
- beta_t = betas_half[t].view(-1, 1, 1, 1)
131
- posterior_variance_t = posterior_variance_half[t].view(-1, 1, 1, 1)
 
132
 
133
  # Get noise prediction with classifier-free guidance
134
  eps_theta_cond = model(latents, t, g_cond, s_cond)
@@ -137,18 +146,23 @@ class DiffusionSampler:
137
 
138
  # Update latents
139
  mean = sqrt_recip_alphas_t * (latents - (beta_t / sqrt_one_minus_alphas_cumprod_t) * eps_theta)
140
- noise = torch.randn_like(latents,dtype=torch.float16)
 
 
 
141
  if t_val == 0:
142
  latents = mean
143
  else:
144
  latents = mean + torch.sqrt(posterior_variance_t) * noise
145
 
146
- # Decode latents to images
147
  self.load_vae()
148
-
149
- # Convert latents back to float32 for vae decoding
150
- latents = latents.to(dtype=torch.float16) / self.vae.config.scaling_factor
151
- latents = latents.to(self.device)
 
 
152
 
153
  progress(0.95, desc="Decoding images...")
154
  with torch.no_grad():
@@ -172,16 +186,16 @@ class DiffusionSampler:
172
  return gallery_images
173
 
174
  # Initialize sampler globally
175
- sampler = DiffusionSampler()
176
 
177
  def generate_random_seed():
178
  """Generate a random seed between 0 and 2^32 - 1"""
179
  return random.randint(0, 2**32 - 1)
180
 
181
  MODEL_SAMPLE_LIMITS = {
182
- "S": {"min":1, "max": 16, "default": 4},
183
- "B": {"min":1, "max": 12, "default": 4},
184
- "L": {"min":1, "max": 4, "default": 1}
185
  }
186
 
187
  def update_sample_slider(dit_size):
@@ -264,6 +278,10 @@ with gr.Blocks(title="DiT Diffusion Model Generator", theme=gr.themes.Soft()) as
264
  seed = gr.Number(label="Seed", value=generate_random_seed(), precision=0, info="Set for reproducible results")
265
  reset_seed_btn = gr.Button("🎲 New Seed")
266
 
 
 
 
 
267
  with gr.Row():
268
  generate_btn = gr.Button("Generate Images", variant="primary")
269
  clear_btn = gr.Button("🗑️ Clear Gallery")
@@ -282,6 +300,17 @@ with gr.Blocks(title="DiT Diffusion Model Generator", theme=gr.themes.Soft()) as
282
  # Clear gallery button functionality
283
  clear_btn.click(clear_gallery, inputs=[], outputs=[output_gallery, error_message])
284
 
 
 
 
 
 
 
 
 
 
 
 
285
  # Connect components
286
  generate_btn.click(
287
  fn=generate_samples,
@@ -290,6 +319,5 @@ with gr.Blocks(title="DiT Diffusion Model Generator", theme=gr.themes.Soft()) as
290
  )
291
 
292
 
293
-
294
  if __name__ == "__main__":
295
  app.launch()
 
18
  # For tracking progress in UI
19
  global_progress = 0
20
 
21
+ # Set to True to enable half-precision inference
22
+ USE_HALF_PRECISION = True
23
+
24
  def load_dit_model(dit_size):
25
  """Load DiT model of specified size"""
26
  #ckpt_path = f"./ckpts/DiT_{dit_size}_final.pth"
 
44
  # Load checkpoint
45
  checkpoint = torch.load(ckpt_path, map_location="cpu")
46
  model.load_state_dict(checkpoint["model_state_dict"])
 
 
 
47
 
48
  return model
49
 
50
  class DiffusionSampler:
51
+ def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu", use_half=USE_HALF_PRECISION):
52
  self.device = device
53
+ self.use_half = use_half
54
  self.vae = None
55
 
56
  # Pre-compute diffusion parameters
 
69
  self.sqrt_recip_alphas = self.sqrt_recip_alphas.to(self.device)
70
  self.betas = self.betas.to(self.device)
71
  self.posterior_variance = self.posterior_variance.to(self.device)
72
+
73
+ # Convert diffusion parameters to half precision if needed
74
+ if self.use_half:
75
+ self.sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.half()
76
+ self.sqrt_one_minus_alpha_cumprod = self.sqrt_one_minus_alpha_cumprod.half()
77
+ self.sqrt_recip_alphas = self.sqrt_recip_alphas.half()
78
+ self.betas = self.betas.half()
79
+ self.posterior_variance = self.posterior_variance.half()
80
+
81
  def load_vae(self):
82
  """Load VAE model (done lazily to save memory until needed)"""
83
  if self.vae is None:
84
  self.vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").to(self.device)
85
+ # VAE should always remain in full precision
86
  self.vae.eval()
87
 
88
  @spaces.GPU(duration=120)
 
101
  torch.cuda.manual_seed(seed)
102
  torch.cuda.manual_seed_all(seed)
103
 
104
+ # Move model to device and convert to half precision if enabled
105
  model.to(self.device)
106
+ if self.use_half:
107
+ model.half()
108
  model.eval()
109
 
110
  # Convert genre and style to tensors
 
113
  g_null = torch.tensor([model.num_genres] * num_samples, device=self.device, dtype=torch.long)
114
  s_null = torch.tensor([model.num_styles] * num_samples, device=self.device, dtype=torch.long)
115
 
116
+ # Start with random latents (in appropriate precision)
117
+ latents = torch.randn((num_samples, 4, 32, 32), device=self.device)
118
+ if self.use_half:
119
+ latents = latents.half()
 
 
 
 
 
120
 
121
  # Use classifier-free guidance for better quality
122
  cfg_scale = 2.5
 
133
 
134
  t = torch.full((num_samples,), t_val, device=self.device, dtype=torch.long)
135
 
136
+ # Get diffusion parameters for current timestep in proper precision
137
+ sqrt_recip_alphas_t = self.sqrt_recip_alphas[t].view(-1, 1, 1, 1)
138
+ sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alpha_cumprod[t].view(-1, 1, 1, 1)
139
+ beta_t = self.betas[t].view(-1, 1, 1, 1)
140
+ posterior_variance_t = self.posterior_variance[t].view(-1, 1, 1, 1)
141
 
142
  # Get noise prediction with classifier-free guidance
143
  eps_theta_cond = model(latents, t, g_cond, s_cond)
 
146
 
147
  # Update latents
148
  mean = sqrt_recip_alphas_t * (latents - (beta_t / sqrt_one_minus_alphas_cumprod_t) * eps_theta)
149
+
150
+ # Generate noise with same precision as latents
151
+ noise = torch.randn_like(latents)
152
+
153
  if t_val == 0:
154
  latents = mean
155
  else:
156
  latents = mean + torch.sqrt(posterior_variance_t) * noise
157
 
158
+ # Decode latents to images - VAE needs full precision
159
  self.load_vae()
160
+
161
+ # Convert latents to full precision for VAE if needed
162
+ if self.use_half:
163
+ latents = latents.float()
164
+
165
+ latents = latents / self.vae.config.scaling_factor
166
 
167
  progress(0.95, desc="Decoding images...")
168
  with torch.no_grad():
 
186
  return gallery_images
187
 
188
  # Initialize sampler globally
189
+ sampler = DiffusionSampler(use_half=USE_HALF_PRECISION)
190
 
191
  def generate_random_seed():
192
  """Generate a random seed between 0 and 2^32 - 1"""
193
  return random.randint(0, 2**32 - 1)
194
 
195
  MODEL_SAMPLE_LIMITS = {
196
+ "S": {"min":1, "max": 18, "default": 4},
197
+ "B": {"min":1, "max": 9, "default": 4},
198
+ "L": {"min":1, "max": 3, "default": 1}
199
  }
200
 
201
  def update_sample_slider(dit_size):
 
278
  seed = gr.Number(label="Seed", value=generate_random_seed(), precision=0, info="Set for reproducible results")
279
  reset_seed_btn = gr.Button("🎲 New Seed")
280
 
281
+ # Add option to toggle half-precision
282
+ use_half_precision = gr.Checkbox(label="Use half-precision (faster)", value=USE_HALF_PRECISION,
283
+ info="Use FP16 for model (faster, less memory, slightly lower quality)")
284
+
285
  with gr.Row():
286
  generate_btn = gr.Button("Generate Images", variant="primary")
287
  clear_btn = gr.Button("🗑️ Clear Gallery")
 
300
  # Clear gallery button functionality
301
  clear_btn.click(clear_gallery, inputs=[], outputs=[output_gallery, error_message])
302
 
303
+ # Update half-precision setting when checkbox is changed
304
+ def update_half_precision(value):
305
+ global USE_HALF_PRECISION
306
+ USE_HALF_PRECISION = value
307
+ # Recreate sampler with new setting
308
+ global sampler
309
+ sampler = DiffusionSampler(use_half=value)
310
+ return None
311
+
312
+ use_half_precision.change(update_half_precision, inputs=[use_half_precision], outputs=[None])
313
+
314
  # Connect components
315
  generate_btn.click(
316
  fn=generate_samples,
 
319
  )
320
 
321
 
 
322
  if __name__ == "__main__":
323
  app.launch()