kaupane commited on
Commit
d7b421a
·
verified ·
1 Parent(s): ff687a4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -12
app.py CHANGED
@@ -62,13 +62,13 @@ class DiffusionSampler:
62
  self.alphas_cumprod_prev = torch.cat([torch.tensor([1.0]), self.alphas_cumprod[:-1]])
63
  self.posterior_variance = self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
64
 
65
- # Move to device in half precision
66
- self.sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.to(self.device).half()
67
- self.sqrt_one_minus_alpha_cumprod = self.sqrt_one_minus_alpha_cumprod.to(self.device).half()
68
- self.sqrt_recip_alphas = self.sqrt_recip_alphas.to(self.device).half()
69
- self.betas = self.betas.to(self.device).half()
70
- self.posterior_variance = self.posterior_variance.to(self.device).half()
71
-
72
  def load_vae(self):
73
  """Load VAE model (done lazily to save memory until needed)"""
74
  if self.vae is None:
@@ -102,6 +102,13 @@ class DiffusionSampler:
102
 
103
  # Start with random latents
104
  latents = torch.randn((num_samples, 4, 32, 32), device=self.device, dtype=torch.float16)
 
 
 
 
 
 
 
105
 
106
  # Use classifier-free guidance for better quality
107
  cfg_scale = 2.5
@@ -118,10 +125,10 @@ class DiffusionSampler:
118
 
119
  t = torch.full((num_samples,), t_val, device=self.device, dtype=torch.long)
120
 
121
- sqrt_recip_alphas_t = self.sqrt_recip_alphas[t].view(-1, 1, 1, 1)
122
- sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alpha_cumprod[t].view(-1, 1, 1, 1)
123
- beta_t = self.betas[t].view(-1, 1, 1, 1)
124
- posterior_variance_t = self.posterior_variance[t].view(-1, 1, 1, 1)
125
 
126
  # Get noise prediction with classifier-free guidance
127
  eps_theta_cond = model(latents, t, g_cond, s_cond)
@@ -130,7 +137,7 @@ class DiffusionSampler:
130
 
131
  # Update latents
132
  mean = sqrt_recip_alphas_t * (latents - (beta_t / sqrt_one_minus_alphas_cumprod_t) * eps_theta)
133
- noise = torch.randn_like(latents)
134
  if t_val == 0:
135
  latents = mean
136
  else:
 
62
  self.alphas_cumprod_prev = torch.cat([torch.tensor([1.0]), self.alphas_cumprod[:-1]])
63
  self.posterior_variance = self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
64
 
65
+ # Move to device
66
+ self.sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.to(self.device)
67
+ self.sqrt_one_minus_alpha_cumprod = self.sqrt_one_minus_alpha_cumprod.to(self.device)
68
+ self.sqrt_recip_alphas = self.sqrt_recip_alphas.to(self.device)
69
+ self.betas = self.betas.to(self.device)
70
+ self.posterior_variance = self.posterior_variance.to(self.device)
71
+
72
  def load_vae(self):
73
  """Load VAE model (done lazily to save memory until needed)"""
74
  if self.vae is None:
 
102
 
103
  # Start with random latents
104
  latents = torch.randn((num_samples, 4, 32, 32), device=self.device, dtype=torch.float16)
105
+
106
+ # Convert diffusion parameters to half precision for compatibility
107
+ sqrt_alphas_cumprod_half = self.sqrt_alphas_cumprod.half()
108
+ sqrt_one_minus_alpha_cumprod_half = self.sqrt_one_minus_alpha_cumprod.half()
109
+ sqrt_recip_alphas_half = self.sqrt_recip_alphas.half()
110
+ betas_half = self.betas.half()
111
+ posterior_variance_half = self.posterior_variance.half()
112
 
113
  # Use classifier-free guidance for better quality
114
  cfg_scale = 2.5
 
125
 
126
  t = torch.full((num_samples,), t_val, device=self.device, dtype=torch.long)
127
 
128
+ sqrt_recip_alphas_t = sqrt_recip_alphas_half[t].view(-1, 1, 1, 1)
129
+ sqrt_one_minus_alphas_cumprod_t = sqrt_one_minus_alpha_cumprod_half[t].view(-1, 1, 1, 1)
130
+ beta_t = betas_half[t].view(-1, 1, 1, 1)
131
+ posterior_variance_t = posterior_variance_half[t].view(-1, 1, 1, 1)
132
 
133
  # Get noise prediction with classifier-free guidance
134
  eps_theta_cond = model(latents, t, g_cond, s_cond)
 
137
 
138
  # Update latents
139
  mean = sqrt_recip_alphas_t * (latents - (beta_t / sqrt_one_minus_alphas_cumprod_t) * eps_theta)
140
+ noise = torch.randn_like(latents,dtype=torch.float16)
141
  if t_val == 0:
142
  latents = mean
143
  else: