kaupane commited on
Commit
ff687a4
·
verified ·
1 Parent(s): 1a8413a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -62,12 +62,12 @@ class DiffusionSampler:
62
  self.alphas_cumprod_prev = torch.cat([torch.tensor([1.0]), self.alphas_cumprod[:-1]])
63
  self.posterior_variance = self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
64
 
65
- # Move to device
66
- self.sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.to(self.device)
67
- self.sqrt_one_minus_alpha_cumprod = self.sqrt_one_minus_alpha_cumprod.to(self.device)
68
- self.sqrt_recip_alphas = self.sqrt_recip_alphas.to(self.device)
69
- self.betas = self.betas.to(self.device)
70
- self.posterior_variance = self.posterior_variance.to(self.device)
71
 
72
  def load_vae(self):
73
  """Load VAE model (done lazily to save memory until needed)"""
 
62
  self.alphas_cumprod_prev = torch.cat([torch.tensor([1.0]), self.alphas_cumprod[:-1]])
63
  self.posterior_variance = self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
64
 
65
+ # Move to device in half precision
66
+ self.sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.to(self.device).half()
67
+ self.sqrt_one_minus_alpha_cumprod = self.sqrt_one_minus_alpha_cumprod.to(self.device).half()
68
+ self.sqrt_recip_alphas = self.sqrt_recip_alphas.to(self.device).half()
69
+ self.betas = self.betas.to(self.device).half()
70
+ self.posterior_variance = self.posterior_variance.to(self.device).half()
71
 
72
  def load_vae(self):
73
  """Load VAE model (done lazily to save memory until needed)"""