kaupane committed on
Commit
ed77fec
·
verified ·
1 Parent(s): 3085d39

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -1
app.py CHANGED
@@ -18,6 +18,9 @@ latent_scale_factor = 0.18215 # Same as in DiTTrainer
18
  # For tracking progress in UI
19
  global_progress = 0
20
 
 
 
 
21
  def load_dit_model(dit_size):
22
  """Load DiT model of specified size"""
23
  #ckpt_path = f"./ckpts/DiT_{dit_size}_final.pth"
@@ -45,8 +48,9 @@ def load_dit_model(dit_size):
45
  return model
46
 
47
  class DiffusionSampler:
48
- def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu"):
49
  self.device = device
 
50
  self.vae = None
51
 
52
  # Pre-compute diffusion parameters
@@ -65,6 +69,14 @@ class DiffusionSampler:
65
  self.sqrt_recip_alphas = self.sqrt_recip_alphas.to(self.device)
66
  self.betas = self.betas.to(self.device)
67
  self.posterior_variance = self.posterior_variance.to(self.device)
 
 
 
 
 
 
 
 
68
 
69
  def load_vae(self):
70
  """Load VAE model (done lazily to save memory until needed)"""
@@ -99,6 +111,8 @@ class DiffusionSampler:
99
 
100
  # Start with random latents
101
  latents = torch.randn((num_samples, 4, 32, 32), device=self.device)
 
 
102
 
103
  # Use classifier-free guidance for better quality
104
  cfg_scale = 2.5
@@ -135,6 +149,10 @@ class DiffusionSampler:
135
 
136
  # Decode latents to images
137
  self.load_vae()
 
 
 
 
138
  latents = latents / self.vae.config.scaling_factor
139
  latents = latents.to(self.device)
140
 
 
18
  # For tracking progress in UI
19
  global_progress = 0
20
 
21
+ # Enable half precision inference
22
+ USE_HALF_PRECISION = True
23
+
24
  def load_dit_model(dit_size):
25
  """Load DiT model of specified size"""
26
  #ckpt_path = f"./ckpts/DiT_{dit_size}_final.pth"
 
48
  return model
49
 
50
  class DiffusionSampler:
51
+ def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu", use_half = USE_HALF_PRECISION):
52
  self.device = device
53
+ self.use_half = use_half
54
  self.vae = None
55
 
56
  # Pre-compute diffusion parameters
 
69
  self.sqrt_recip_alphas = self.sqrt_recip_alphas.to(self.device)
70
  self.betas = self.betas.to(self.device)
71
  self.posterior_variance = self.posterior_variance.to(self.device)
72
+
73
+ # Convert to half precision if needed
74
+ if self.use_half:
75
+ self.sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.half()
76
+ self.sqrt_one_minus_alpha_cumprod = self.sqrt_one_minus_alpha_cumprod.half()
77
+ self.sqrt_recip_alphas = self.sqrt_recip_alphas.half()
78
+ self.betas = self.betas.half()
79
+ self.posterior_variance = self.posterior_variance.half()
80
 
81
  def load_vae(self):
82
  """Load VAE model (done lazily to save memory until needed)"""
 
111
 
112
  # Start with random latents
113
  latents = torch.randn((num_samples, 4, 32, 32), device=self.device)
114
+ if self.use_half:
115
+ latents = latents.half()
116
 
117
  # Use classifier-free guidance for better quality
118
  cfg_scale = 2.5
 
149
 
150
  # Decode latents to images
151
  self.load_vae()
152
+
153
+ # Convert back to float
154
+ if self.use_half:
155
+ latents = latents.float()
156
  latents = latents / self.vae.config.scaling_factor
157
  latents = latents.to(self.device)
158