import torch
import spaces
import gradio as gr
import numpy as np
import random

from mapping import (
    reduced_genre_mapping,
    reduced_style_mapping,
    reverse_reduced_genre_mapping,
    reverse_reduced_style_mapping,
)
from diffusers import AutoencoderKL
from models.DiT import DiT

# Global settings
num_timesteps = 1000
beta_start = 1e-4
beta_end = 0.02
latent_scale_factor = 0.18215  # Same as in DiTTrainer

# For tracking progress in the UI
global_progress = 0

# Enable half-precision inference
USE_HALF_PRECISION = True


def load_dit_model(dit_size):
    """Load a DiT model of the specified size."""
    if dit_size == "S":
        model = DiT.from_pretrained("kaupane/DiT-Wikiart-Small")
    elif dit_size == "B":
        model = DiT.from_pretrained("kaupane/DiT-Wikiart-Base")
    elif dit_size == "L":
        model = DiT.from_pretrained("kaupane/DiT-Wikiart-Large")
    else:
        raise ValueError(f"Invalid DiT size: {dit_size}")
    return model


class DiffusionSampler:
    def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu", use_half=USE_HALF_PRECISION):
        self.device = device
        self.use_half = use_half
        self.vae = None
        # Pre-compute the DDPM schedule
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alpha_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)
        self.alphas_cumprod_prev = torch.cat([torch.tensor([1.0]), self.alphas_cumprod[:-1]])
        self.posterior_variance = self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        # Move sampling buffers to the target device
        self.sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.to(self.device)
        self.sqrt_one_minus_alpha_cumprod = self.sqrt_one_minus_alpha_cumprod.to(self.device)
        self.sqrt_recip_alphas = self.sqrt_recip_alphas.to(self.device)
        self.betas = self.betas.to(self.device)
        self.posterior_variance = self.posterior_variance.to(self.device)
        # Convert to half precision if requested
        if self.use_half:
            self.sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.half()
            self.sqrt_one_minus_alpha_cumprod = self.sqrt_one_minus_alpha_cumprod.half()
            self.sqrt_recip_alphas = self.sqrt_recip_alphas.half()
            self.betas = self.betas.half()
            self.posterior_variance = self.posterior_variance.half()

    def load_vae(self):
        """Load the VAE lazily to save memory until it is needed."""
        if self.vae is None:
            self.vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").to(self.device)
            self.vae.eval()

    @spaces.GPU(duration=120)
    def generate_images(self, model, num_samples, genre, style, seed, progress=gr.Progress()):
        """Generate images with the DiT model."""
        global global_progress
        global_progress = 0

        # Set random seeds for reproducibility
        if seed is not None:
            torch.manual_seed(seed)
            np.random.seed(seed)
            random.seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(seed)
                torch.cuda.manual_seed_all(seed)

        if self.use_half:
            model.half()
        model.to(self.device)
        model.eval()

        # Conditional labels plus null labels for the unconditional branch;
        # index num_genres / num_styles acts as the null class
        g_cond = torch.tensor([genre] * num_samples, device=self.device, dtype=torch.long)
        s_cond = torch.tensor([style] * num_samples, device=self.device, dtype=torch.long)
        g_null = torch.tensor([model.num_genres] * num_samples, device=self.device, dtype=torch.long)
        s_null = torch.tensor([model.num_styles] * num_samples, device=self.device, dtype=torch.long)

        # Start from pure Gaussian noise in latent space
        latents = torch.randn((num_samples, 4, 32, 32), device=self.device)
        if self.use_half:
            latents = latents.half()

        # Classifier-free guidance scale; higher trades diversity for fidelity
        cfg_scale = 2.5
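
        # The loop below implements DDPM ancestral sampling (Ho et al., 2020):
        #   x_{t-1} = 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alphabar_t) * eps) + sigma_t * z
        # with z ~ N(0, I) dropped at the final step, and classifier-free guidance
        #   eps = eps_uncond + cfg_scale * (eps_cond - eps_uncond)
        # mixing the conditional and unconditional noise predictions.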
        # Walk the reverse diffusion process from t = T-1 down to t = 0
        timesteps = torch.arange(num_timesteps - 1, -1, -1, device=self.device)
        total_steps = len(timesteps)
        with torch.no_grad():
            for i, t_val in enumerate(timesteps):
                # Update progress
                global_progress = int(100 * i / total_steps)
                progress(global_progress / 100, desc="Generating images...")

                t = torch.full((num_samples,), t_val, device=self.device, dtype=torch.long)
                sqrt_recip_alphas_t = self.sqrt_recip_alphas[t].view(-1, 1, 1, 1)
                sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alpha_cumprod[t].view(-1, 1, 1, 1)
                beta_t = self.betas[t].view(-1, 1, 1, 1)
                posterior_variance_t = self.posterior_variance[t].view(-1, 1, 1, 1)

                # Noise prediction with classifier-free guidance
                eps_theta_cond = model(latents, t, g_cond, s_cond)
                eps_theta_uncond = model(latents, t, g_null, s_null)
                eps_theta = eps_theta_uncond + cfg_scale * (eps_theta_cond - eps_theta_uncond)

                # Posterior mean; add noise on every step except the last
                mean = sqrt_recip_alphas_t * (latents - (beta_t / sqrt_one_minus_alphas_cumprod_t) * eps_theta)
                if t_val == 0:
                    latents = mean
                else:
                    noise = torch.randn_like(latents)
                    latents = mean + torch.sqrt(posterior_variance_t) * noise

        # Decode latents to images
        self.load_vae()
        if self.use_half:
            latents = latents.float()  # VAE weights are fp32
        latents = latents / self.vae.config.scaling_factor
        progress(0.95, desc="Decoding images...")
        with torch.no_grad():
            images = self.vae.decode(latents).sample
        images = (images / 2 + 0.5).clamp(0, 1)
        images = images.permute(0, 2, 3, 1).cpu().numpy()
        progress(1.0, desc="Done!")
        global_progress = 100

        # Build the gallery with captions
        gallery_images = []
        for i in range(num_samples):
            # Scale [0, 1] floats to 8-bit RGB
            img = (images[i] * 255).astype(np.uint8)
            caption = f"Genre: {reverse_reduced_genre_mapping[genre]}, Style: {reverse_reduced_style_mapping[style]}"
            if seed is not None:
                caption += f" (Seed: {seed})"
            gallery_images.append((img, caption))
        return gallery_images


# Initialize the sampler globally
sampler = DiffusionSampler()
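
# Sketch of programmatic (non-UI) use, with hypothetical label IDs (the real
# ones come from reduced_genre_mapping / reduced_style_mapping; note the
# default gr.Progress() argument expects to run inside a Gradio event):
#
#   model = load_dit_model("S")
#   pairs = sampler.generate_images(model, num_samples=2, genre=0, style=0, seed=42)
#   # `pairs` is a list of (HxWx3 uint8 array, caption) tuples, ready for gr.Gallery.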

def generate_random_seed():
    """Generate a random seed between 0 and 2^32 - 1."""
    return random.randint(0, 2**32 - 1)


# Per-model limits on the sample count shown in the UI
MODEL_SAMPLE_LIMITS = {
    "S": {"min": 1, "max": 24, "default": 6},
    "B": {"min": 1, "max": 16, "default": 4},
    "L": {"min": 1, "max": 8, "default": 2},
}


def update_sample_slider(dit_size):
    """Rescale the sample slider when the model size changes."""
    limits = MODEL_SAMPLE_LIMITS[dit_size]
    return gr.update(
        minimum=limits["min"],
        maximum=limits["max"],
        value=limits["default"],
        info=f"How many images to generate ({limits['min']}-{limits['max']})",
    )


@spaces.GPU(duration=120)
def generate_samples(num_samples, dit_size, genre_name, style_name, seed, progress=gr.Progress()):
    """Main entry point for the Gradio interface."""
    num_samples = int(num_samples)  # slider values may arrive as floats
    limits = MODEL_SAMPLE_LIMITS[dit_size]
    if num_samples < limits["min"] or num_samples > limits["max"]:
        return None, gr.update(
            value=f"Number of samples for the {dit_size} model must be between {limits['min']} and {limits['max']}",
            visible=True,
        )

    # Map genre and style names to label IDs
    genre_id = reduced_genre_mapping.get(genre_name)
    style_id = reduced_style_mapping.get(style_name)
    if genre_id is None:
        return None, gr.update(value=f"Unknown genre: {genre_name}", visible=True)
    if style_id is None:
        return None, gr.update(value=f"Unknown style: {style_name}", visible=True)

    try:
        progress(0.05, desc="Loading DiT model...")
        model = load_dit_model(dit_size)
        gallery_images = sampler.generate_images(model, num_samples, genre_id, style_id, seed, progress)
        return gallery_images, gr.update(value="", visible=False)
    except Exception as e:
        return None, gr.update(value=f"Error: {e}", visible=True)


def clear_gallery():
    """Clear the gallery display."""
    return None, gr.update(value="", visible=False)


# Build the Gradio interface
with gr.Blocks(title="DiT Diffusion Model Generator", theme=gr.themes.Soft()) as app:
    gr.Markdown("# DiT Diffusion Model Generator")
    gr.Markdown("Generate art images using a Diffusion Transformer (DiT) model")

    with gr.Row():
        with gr.Column(scale=1):
            dit_size = gr.Radio(
                choices=["S", "B", "L"],
                value="B",
                label="DiT Model Size",
                info="S: Small (fastest), B: Base (balanced), L: Large (best quality but slowest)",
            )
            num_samples = gr.Slider(
                minimum=MODEL_SAMPLE_LIMITS["B"]["min"],
                maximum=MODEL_SAMPLE_LIMITS["B"]["max"],
                value=MODEL_SAMPLE_LIMITS["B"]["default"],
                step=1,
                label="Number of Samples",
                info=f"How many images to generate ({MODEL_SAMPLE_LIMITS['B']['min']}-{MODEL_SAMPLE_LIMITS['B']['max']})",
            )

            # Sort choices alphabetically
            genre_names = sorted(reduced_genre_mapping.keys())
            style_names = sorted(reduced_style_mapping.keys())
            genre = gr.Dropdown(choices=genre_names, value="landscape", label="Art Genre")
            style = gr.Dropdown(choices=style_names, value="impressionism", label="Art Style")

            with gr.Row():
                seed = gr.Number(label="Seed", value=generate_random_seed(), precision=0, info="Set for reproducible results")
                reset_seed_btn = gr.Button("🎲 New Seed")
            with gr.Row():
                generate_btn = gr.Button("Generate Images", variant="primary")
                clear_btn = gr.Button("🗑️ Clear Gallery")
        with gr.Column(scale=2):
            output_gallery = gr.Gallery(
                label="Generated Images",
                columns=6,
                rows=4,
                height=600,
                object_fit="contain",
                allow_preview=True,
                show_download_button=True,
            )
            error_message = gr.Textbox(label="Error", visible=False, max_lines=3, container=True, elem_id="error_box")

    # Wire up the components
    dit_size.change(update_sample_slider, inputs=[dit_size], outputs=[num_samples])
    reset_seed_btn.click(generate_random_seed, inputs=[], outputs=[seed])
    clear_btn.click(clear_gallery, inputs=[], outputs=[output_gallery, error_message])
    generate_btn.click(
        fn=generate_samples,
        inputs=[num_samples, dit_size, genre, style, seed],
        outputs=[output_gallery, error_message],
    )

if __name__ == "__main__":
    app.launch()
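
# Local usage (assumes the sibling mapping.py and models/DiT.py modules are
# present and the kaupane/DiT-Wikiart-* checkpoints are reachable on the Hub):
#   pip install torch numpy diffusers gradio spaces
#   python app.py  (or whatever this file is named)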