import torch
import spaces
import gradio as gr
import numpy as np
import os
import random
from mapping import reduced_genre_mapping, reduced_style_mapping, reverse_reduced_genre_mapping, reverse_reduced_style_mapping
from diffusers import AutoencoderKL
from huggingface_hub import hf_hub_download
from models.DiT import DiT

# Global settings
num_timesteps = 1000
beta_start = 1e-4
beta_end = 0.02
latent_scale_factor = 0.18215  # Same as in DiTTrainer (unused here; decoding uses vae.config.scaling_factor)

# For tracking progress in UI
global_progress = 0


def load_dit_model(dit_size):
    """Load DiT model of specified size"""
    # ckpt_path = f"./ckpts/DiT_{dit_size}_final.pth"
    ckpt_path = hf_hub_download(
        repo_id="kaupane/DiT-Wikiart",
        filename=f"DiT_{dit_size}_final.pth"
    )
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found at {ckpt_path}")

    # Configure model based on size
    if dit_size == "S":
        model = DiT(num_blocks=8, hidden_size=384, num_heads=6)
    elif dit_size == "B":
        model = DiT(num_blocks=12, hidden_size=640, num_heads=10)
    elif dit_size == "L":
        model = DiT(num_blocks=16, hidden_size=896, num_heads=14)
    else:
        raise ValueError(f"Invalid DiT size: {dit_size}")

    # Load checkpoint
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    return model


class DiffusionSampler:
    def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu"):
        self.device = device
        self.vae = None

        # Pre-compute diffusion parameters
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alpha_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)
        self.alphas_cumprod_prev = torch.cat([torch.tensor([1.0]), self.alphas_cumprod[:-1]])
        self.posterior_variance = self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)

        # Move to device
        self.sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.to(self.device)
        self.sqrt_one_minus_alpha_cumprod = self.sqrt_one_minus_alpha_cumprod.to(self.device)
        self.sqrt_recip_alphas = self.sqrt_recip_alphas.to(self.device)
        self.betas = self.betas.to(self.device)
        self.posterior_variance = self.posterior_variance.to(self.device)

    def load_vae(self):
        """Load VAE model (done lazily to save memory until needed)"""
        if self.vae is None:
            self.vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").to(self.device)
            self.vae.eval()

    @spaces.GPU
    def generate_images(self, model, num_samples, genre, style, seed, progress=gr.Progress()):
        """Generate images with the DiT model"""
        global global_progress
        global_progress = 0

        # Set random seed for reproducibility
        if seed is not None:
            torch.manual_seed(seed)
            np.random.seed(seed)
            random.seed(seed)
            # Also set CUDA seed if using GPU
            if torch.cuda.is_available():
                torch.cuda.manual_seed(seed)
                torch.cuda.manual_seed_all(seed)

        model.to(self.device)
        model.eval()

        # Convert genre and style to tensors; the extra class index (num_genres /
        # num_styles) is the "null" label used for the unconditional pass
        g_cond = torch.tensor([genre] * num_samples, device=self.device, dtype=torch.long)
        s_cond = torch.tensor([style] * num_samples, device=self.device, dtype=torch.long)
        g_null = torch.tensor([model.num_genres] * num_samples, device=self.device, dtype=torch.long)
        s_null = torch.tensor([model.num_styles] * num_samples, device=self.device, dtype=torch.long)

        # Start with random latents
        latents = torch.randn((num_samples, 4, 32, 32), device=self.device)

        # Use classifier-free guidance for better quality
        cfg_scale = 2.5
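        # For reference, the loop below is the standard DDPM reverse update:
        #   mu = (1 / sqrt(alpha_t)) * (x_t - (beta_t / sqrt(1 - alpha_bar_t)) * eps_theta)
        #   x_{t-1} = mu + sqrt(posterior_variance_t) * z,  z ~ N(0, I)   (no noise at t = 0)
        # with the noise prediction combined via classifier-free guidance:
        #   eps_theta = eps_uncond + cfg_scale * (eps_cond - eps_uncond)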
        # Go through the reverse diffusion process
        timesteps = torch.arange(num_timesteps - 1, -1, -1, device=self.device)
        total_steps = len(timesteps)
        with torch.no_grad():
            for i, t_val in enumerate(timesteps):
                # Update progress
                global_progress = int(100 * i / total_steps)
                progress(global_progress / 100, desc="Generating images...")

                t = torch.full((num_samples,), t_val, device=self.device, dtype=torch.long)
                sqrt_recip_alphas_t = self.sqrt_recip_alphas[t].view(-1, 1, 1, 1)
                sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alpha_cumprod[t].view(-1, 1, 1, 1)
                beta_t = self.betas[t].view(-1, 1, 1, 1)
                posterior_variance_t = self.posterior_variance[t].view(-1, 1, 1, 1)

                # Get noise prediction with classifier-free guidance
                eps_theta_cond = model(latents, t, g_cond, s_cond)
                eps_theta_uncond = model(latents, t, g_null, s_null)
                eps_theta = eps_theta_uncond + cfg_scale * (eps_theta_cond - eps_theta_uncond)

                # Update latents
                mean = sqrt_recip_alphas_t * (latents - (beta_t / sqrt_one_minus_alphas_cumprod_t) * eps_theta)
                noise = torch.randn_like(latents)
                if t_val == 0:
                    latents = mean
                else:
                    latents = mean + torch.sqrt(posterior_variance_t) * noise

        # Decode latents to images
        self.load_vae()
        latents = latents / self.vae.config.scaling_factor
        latents = latents.to(self.device)
        progress(0.95, desc="Decoding images...")
        with torch.no_grad():
            images = self.vae.decode(latents).sample
        images = (images / 2 + 0.5).clamp(0, 1)
        images = images.permute(0, 2, 3, 1).cpu().numpy()

        progress(1.0, desc="Done!")
        global_progress = 100

        # Create image gallery with labels
        gallery_images = []
        for i in range(num_samples):
            # Convert the float array to a uint8 image for display
            img = (images[i] * 255).astype(np.uint8)
            caption = f"Genre: {reverse_reduced_genre_mapping[genre]}, Style: {reverse_reduced_style_mapping[style]}"
            if seed is not None:
                caption += f" (Seed: {seed})"
            gallery_images.append((img, caption))

        return gallery_images


# Initialize sampler globally
sampler = DiffusionSampler()


def generate_random_seed():
    """Generate a random seed between 0 and 2^32 - 1"""
    return random.randint(0, 2**32 - 1)


@spaces.GPU
def generate_samples(num_samples, dit_size, genre_name, style_name, seed, progress=gr.Progress()):
    """Main function for Gradio interface"""
    if num_samples < 1 or num_samples > 16:
        return None, gr.update(value="Number of samples must be between 1 and 16", visible=True)

    # Get genre and style IDs from mappings
    genre_id = reduced_genre_mapping.get(genre_name)
    style_id = reduced_style_mapping.get(style_name)
    if genre_id is None:
        return None, gr.update(value=f"Unknown genre: {genre_name}", visible=True)
    if style_id is None:
        return None, gr.update(value=f"Unknown style: {style_name}", visible=True)

    try:
        # Load model
        progress(0.05, desc="Loading DiT model...")
        model = load_dit_model(dit_size)
        # Generate images
        gallery_images = sampler.generate_images(model, num_samples, genre_id, style_id, seed, progress)
        return gallery_images, gr.update(value="", visible=False)
    except Exception as e:
        error_msg = f"Error: {str(e)}"
        return None, gr.update(value=error_msg, visible=True)


def clear_gallery():
    """Clear the gallery display"""
    return None, gr.update(value="", visible=False)


# Create the Gradio interface
with gr.Blocks(title="DiT Diffusion Model Generator", theme=gr.themes.Soft()) as app:
    gr.Markdown("# DiT Diffusion Model Generator")
    gr.Markdown("Generate art images using a Diffusion Transformer (DiT) model")

    with gr.Row():
        with gr.Column(scale=1):
            num_samples = gr.Slider(minimum=1, maximum=16, value=4, step=1, label="Number of Samples",
                                    info="How many images to generate (1-16)")
generate (1-16)") dit_size = gr.Radio(choices=["S", "B", "L"], value="S", label="DiT Model Size", info="Larger models produce better quality but take longer") genre_names = list(reduced_genre_mapping.keys()) style_names = list(reduced_style_mapping.keys()) # Sort alphabetically, ensuring 'None' is at top genre_names.sort() style_names.sort() genre = gr.Dropdown(choices=genre_names, value="landscape", label="Art Genre") style = gr.Dropdown(choices=style_names, value="impressionism", label="Art Style") with gr.Row(): seed = gr.Number(label="Seed", value=generate_random_seed(), precision=0, info="Set for reproducible results") reset_seed_btn = gr.Button("🎲 New Seed") with gr.Row(): generate_btn = gr.Button("Generate Images", variant="primary") clear_btn = gr.Button("🗑️ Clear Gallery") progress_bar = gr.Progress(track_tqdm=True) with gr.Column(scale=2): output_gallery = gr.Gallery(label="Generated Images", columns=4, rows=4, object_fit="contain", height=600) error_message = gr.Textbox(label="Error", visible=False, max_lines=3, container=True, elem_id="error_box") # Seed reset button functionality reset_seed_btn.click(generate_random_seed, inputs=[], outputs=[seed]) # Clear gallery button functionality clear_btn.click(clear_gallery, inputs=[], outputs=[output_gallery, error_message]) # Connect components generate_btn.click( fn=generate_samples, inputs=[num_samples, dit_size, genre, style, seed], outputs=[output_gallery, error_message], ) if __name__ == "__main__": app.launch()