Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -18,6 +18,9 @@ latent_scale_factor = 0.18215 # Same as in DiTTrainer
|
|
18 |
# For tracking progress in UI
|
19 |
global_progress = 0
|
20 |
|
|
|
|
|
|
|
21 |
def load_dit_model(dit_size):
|
22 |
"""Load DiT model of specified size"""
|
23 |
#ckpt_path = f"./ckpts/DiT_{dit_size}_final.pth"
|
@@ -41,15 +44,13 @@ def load_dit_model(dit_size):
|
|
41 |
# Load checkpoint
|
42 |
checkpoint = torch.load(ckpt_path, map_location="cpu")
|
43 |
model.load_state_dict(checkpoint["model_state_dict"])
|
44 |
-
|
45 |
-
# Use half precision to speed up sampling
|
46 |
-
model = model.half()
|
47 |
|
48 |
return model
|
49 |
|
50 |
class DiffusionSampler:
|
51 |
-
def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu"):
|
52 |
self.device = device
|
|
|
53 |
self.vae = None
|
54 |
|
55 |
# Pre-compute diffusion parameters
|
@@ -68,11 +69,20 @@ class DiffusionSampler:
|
|
68 |
self.sqrt_recip_alphas = self.sqrt_recip_alphas.to(self.device)
|
69 |
self.betas = self.betas.to(self.device)
|
70 |
self.posterior_variance = self.posterior_variance.to(self.device)
|
71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
def load_vae(self):
|
73 |
"""Load VAE model (done lazily to save memory until needed)"""
|
74 |
if self.vae is None:
|
75 |
self.vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").to(self.device)
|
|
|
76 |
self.vae.eval()
|
77 |
|
78 |
@spaces.GPU(duration=120)
|
@@ -91,7 +101,10 @@ class DiffusionSampler:
|
|
91 |
torch.cuda.manual_seed(seed)
|
92 |
torch.cuda.manual_seed_all(seed)
|
93 |
|
|
|
94 |
model.to(self.device)
|
|
|
|
|
95 |
model.eval()
|
96 |
|
97 |
# Convert genre and style to tensors
|
@@ -100,15 +113,10 @@ class DiffusionSampler:
|
|
100 |
g_null = torch.tensor([model.num_genres] * num_samples, device=self.device, dtype=torch.long)
|
101 |
s_null = torch.tensor([model.num_styles] * num_samples, device=self.device, dtype=torch.long)
|
102 |
|
103 |
-
# Start with random latents
|
104 |
-
latents = torch.randn((num_samples, 4, 32, 32), device=self.device
|
105 |
-
|
106 |
-
|
107 |
-
sqrt_alphas_cumprod_half = self.sqrt_alphas_cumprod.half()
|
108 |
-
sqrt_one_minus_alpha_cumprod_half = self.sqrt_one_minus_alpha_cumprod.half()
|
109 |
-
sqrt_recip_alphas_half = self.sqrt_recip_alphas.half()
|
110 |
-
betas_half = self.betas.half()
|
111 |
-
posterior_variance_half = self.posterior_variance.half()
|
112 |
|
113 |
# Use classifier-free guidance for better quality
|
114 |
cfg_scale = 2.5
|
@@ -125,10 +133,11 @@ class DiffusionSampler:
|
|
125 |
|
126 |
t = torch.full((num_samples,), t_val, device=self.device, dtype=torch.long)
|
127 |
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
|
|
132 |
|
133 |
# Get noise prediction with classifier-free guidance
|
134 |
eps_theta_cond = model(latents, t, g_cond, s_cond)
|
@@ -137,18 +146,23 @@ class DiffusionSampler:
|
|
137 |
|
138 |
# Update latents
|
139 |
mean = sqrt_recip_alphas_t * (latents - (beta_t / sqrt_one_minus_alphas_cumprod_t) * eps_theta)
|
140 |
-
|
|
|
|
|
|
|
141 |
if t_val == 0:
|
142 |
latents = mean
|
143 |
else:
|
144 |
latents = mean + torch.sqrt(posterior_variance_t) * noise
|
145 |
|
146 |
-
# Decode latents to images
|
147 |
self.load_vae()
|
148 |
-
|
149 |
-
# Convert latents
|
150 |
-
|
151 |
-
|
|
|
|
|
152 |
|
153 |
progress(0.95, desc="Decoding images...")
|
154 |
with torch.no_grad():
|
@@ -172,16 +186,16 @@ class DiffusionSampler:
|
|
172 |
return gallery_images
|
173 |
|
174 |
# Initialize sampler globally
|
175 |
-
sampler = DiffusionSampler()
|
176 |
|
177 |
def generate_random_seed():
|
178 |
"""Generate a random seed between 0 and 2^32 - 1"""
|
179 |
return random.randint(0, 2**32 - 1)
|
180 |
|
181 |
MODEL_SAMPLE_LIMITS = {
|
182 |
-
"S": {"min":1, "max":
|
183 |
-
"B": {"min":1, "max":
|
184 |
-
"L": {"min":1, "max":
|
185 |
}
|
186 |
|
187 |
def update_sample_slider(dit_size):
|
@@ -264,6 +278,10 @@ with gr.Blocks(title="DiT Diffusion Model Generator", theme=gr.themes.Soft()) as
|
|
264 |
seed = gr.Number(label="Seed", value=generate_random_seed(), precision=0, info="Set for reproducible results")
|
265 |
reset_seed_btn = gr.Button("🎲 New Seed")
|
266 |
|
|
|
|
|
|
|
|
|
267 |
with gr.Row():
|
268 |
generate_btn = gr.Button("Generate Images", variant="primary")
|
269 |
clear_btn = gr.Button("🗑️ Clear Gallery")
|
@@ -282,6 +300,17 @@ with gr.Blocks(title="DiT Diffusion Model Generator", theme=gr.themes.Soft()) as
|
|
282 |
# Clear gallery button functionality
|
283 |
clear_btn.click(clear_gallery, inputs=[], outputs=[output_gallery, error_message])
|
284 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
285 |
# Connect components
|
286 |
generate_btn.click(
|
287 |
fn=generate_samples,
|
@@ -290,6 +319,5 @@ with gr.Blocks(title="DiT Diffusion Model Generator", theme=gr.themes.Soft()) as
|
|
290 |
)
|
291 |
|
292 |
|
293 |
-
|
294 |
if __name__ == "__main__":
|
295 |
app.launch()
|
|
|
18 |
# For tracking progress in UI
|
19 |
global_progress = 0
|
20 |
|
21 |
+
# Set to True to enable half-precision inference
|
22 |
+
USE_HALF_PRECISION = True
|
23 |
+
|
24 |
def load_dit_model(dit_size):
|
25 |
"""Load DiT model of specified size"""
|
26 |
#ckpt_path = f"./ckpts/DiT_{dit_size}_final.pth"
|
|
|
44 |
# Load checkpoint
|
45 |
checkpoint = torch.load(ckpt_path, map_location="cpu")
|
46 |
model.load_state_dict(checkpoint["model_state_dict"])
|
|
|
|
|
|
|
47 |
|
48 |
return model
|
49 |
|
50 |
class DiffusionSampler:
|
51 |
+
def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu", use_half=USE_HALF_PRECISION):
|
52 |
self.device = device
|
53 |
+
self.use_half = use_half
|
54 |
self.vae = None
|
55 |
|
56 |
# Pre-compute diffusion parameters
|
|
|
69 |
self.sqrt_recip_alphas = self.sqrt_recip_alphas.to(self.device)
|
70 |
self.betas = self.betas.to(self.device)
|
71 |
self.posterior_variance = self.posterior_variance.to(self.device)
|
72 |
+
|
73 |
+
# Convert diffusion parameters to half precision if needed
|
74 |
+
if self.use_half:
|
75 |
+
self.sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.half()
|
76 |
+
self.sqrt_one_minus_alpha_cumprod = self.sqrt_one_minus_alpha_cumprod.half()
|
77 |
+
self.sqrt_recip_alphas = self.sqrt_recip_alphas.half()
|
78 |
+
self.betas = self.betas.half()
|
79 |
+
self.posterior_variance = self.posterior_variance.half()
|
80 |
+
|
81 |
def load_vae(self):
    """Instantiate the VAE on first use; later calls reuse the cached model."""
    if self.vae is not None:
        # Already loaded on an earlier call; nothing to do.
        return

    # VAE should always remain in full precision, so it is only moved to the
    # target device and never cast with .half().
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema")
    self.vae = vae.to(self.device)
    self.vae.eval()
|
87 |
|
88 |
@spaces.GPU(duration=120)
|
|
|
101 |
torch.cuda.manual_seed(seed)
|
102 |
torch.cuda.manual_seed_all(seed)
|
103 |
|
104 |
+
# Move model to device and convert to half precision if enabled
|
105 |
model.to(self.device)
|
106 |
+
if self.use_half:
|
107 |
+
model.half()
|
108 |
model.eval()
|
109 |
|
110 |
# Convert genre and style to tensors
|
|
|
113 |
g_null = torch.tensor([model.num_genres] * num_samples, device=self.device, dtype=torch.long)
|
114 |
s_null = torch.tensor([model.num_styles] * num_samples, device=self.device, dtype=torch.long)
|
115 |
|
116 |
+
# Start with random latents (in appropriate precision)
|
117 |
+
latents = torch.randn((num_samples, 4, 32, 32), device=self.device)
|
118 |
+
if self.use_half:
|
119 |
+
latents = latents.half()
|
|
|
|
|
|
|
|
|
|
|
120 |
|
121 |
# Use classifier-free guidance for better quality
|
122 |
cfg_scale = 2.5
|
|
|
133 |
|
134 |
t = torch.full((num_samples,), t_val, device=self.device, dtype=torch.long)
|
135 |
|
136 |
+
# Get diffusion parameters for current timestep in proper precision
|
137 |
+
sqrt_recip_alphas_t = self.sqrt_recip_alphas[t].view(-1, 1, 1, 1)
|
138 |
+
sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alpha_cumprod[t].view(-1, 1, 1, 1)
|
139 |
+
beta_t = self.betas[t].view(-1, 1, 1, 1)
|
140 |
+
posterior_variance_t = self.posterior_variance[t].view(-1, 1, 1, 1)
|
141 |
|
142 |
# Get noise prediction with classifier-free guidance
|
143 |
eps_theta_cond = model(latents, t, g_cond, s_cond)
|
|
|
146 |
|
147 |
# Update latents
|
148 |
mean = sqrt_recip_alphas_t * (latents - (beta_t / sqrt_one_minus_alphas_cumprod_t) * eps_theta)
|
149 |
+
|
150 |
+
# Generate noise with same precision as latents
|
151 |
+
noise = torch.randn_like(latents)
|
152 |
+
|
153 |
if t_val == 0:
|
154 |
latents = mean
|
155 |
else:
|
156 |
latents = mean + torch.sqrt(posterior_variance_t) * noise
|
157 |
|
158 |
+
# Decode latents to images - VAE needs full precision
|
159 |
self.load_vae()
|
160 |
+
|
161 |
+
# Convert latents to full precision for VAE if needed
|
162 |
+
if self.use_half:
|
163 |
+
latents = latents.float()
|
164 |
+
|
165 |
+
latents = latents / self.vae.config.scaling_factor
|
166 |
|
167 |
progress(0.95, desc="Decoding images...")
|
168 |
with torch.no_grad():
|
|
|
186 |
return gallery_images
|
187 |
|
188 |
# Initialize sampler globally
|
189 |
+
sampler = DiffusionSampler(use_half=USE_HALF_PRECISION)
|
190 |
|
191 |
def generate_random_seed():
    """Draw a fresh seed uniformly from the inclusive range [0, 2**32 - 1]."""
    largest_seed = (1 << 32) - 1
    return random.randint(0, largest_seed)
|
194 |
|
195 |
# Per-model-size limits for the "number of samples" control.
# Each row is (size key, maximum sample count, default sample count);
# the minimum is 1 for every size.
MODEL_SAMPLE_LIMITS = {
    size: {"min": 1, "max": max_samples, "default": default_samples}
    for size, max_samples, default_samples in (
        ("S", 18, 4),
        ("B", 9, 4),
        ("L", 3, 1),
    )
}
|
200 |
|
201 |
def update_sample_slider(dit_size):
|
|
|
278 |
seed = gr.Number(label="Seed", value=generate_random_seed(), precision=0, info="Set for reproducible results")
|
279 |
reset_seed_btn = gr.Button("🎲 New Seed")
|
280 |
|
281 |
+
# Add option to toggle half-precision
|
282 |
+
use_half_precision = gr.Checkbox(label="Use half-precision (faster)", value=USE_HALF_PRECISION,
|
283 |
+
info="Use FP16 for model (faster, less memory, slightly lower quality)")
|
284 |
+
|
285 |
with gr.Row():
|
286 |
generate_btn = gr.Button("Generate Images", variant="primary")
|
287 |
clear_btn = gr.Button("🗑️ Clear Gallery")
|
|
|
300 |
# Clear gallery button functionality
|
301 |
clear_btn.click(clear_gallery, inputs=[], outputs=[output_gallery, error_message])
|
302 |
|
303 |
+
# Update half-precision setting when checkbox is changed
|
304 |
+
def update_half_precision(value):
    """Switch the app between half- and full-precision sampling.

    Stores the checkbox state in the module-level ``USE_HALF_PRECISION``
    flag and replaces the module-level ``sampler`` with a new
    ``DiffusionSampler`` built with that setting.

    Args:
        value: Current state of the half-precision checkbox.

    Returns:
        None. The event is wired with no output components.
    """
    # Both names are module-level, so the new setting applies to any code
    # that reads `sampler` after this callback has run.
    global USE_HALF_PRECISION, sampler
    USE_HALF_PRECISION = value
    # Recreate sampler with new setting: DiffusionSampler decides in __init__
    # whether to cast its precomputed diffusion parameters to half precision.
    sampler = DiffusionSampler(use_half=value)


# This callback updates no components, so `outputs` is None. A list that
# contains None is not a valid list of output components.
use_half_precision.change(update_half_precision, inputs=[use_half_precision], outputs=None)
|
313 |
+
|
314 |
# Connect components
|
315 |
generate_btn.click(
|
316 |
fn=generate_samples,
|
|
|
319 |
)
|
320 |
|
321 |
|
|
|
322 |
if __name__ == "__main__":
|
323 |
app.launch()
|