# kaupane's picture
# Update app.py
# 2f32e9a verified
import torch
import spaces
import gradio as gr
import numpy as np
import os
import random
from mapping import reduced_genre_mapping, reduced_style_mapping, reverse_reduced_genre_mapping, reverse_reduced_style_mapping
from diffusers import AutoencoderKL
from huggingface_hub import hf_hub_download
from models.DiT import DiT
# ---- Global settings -------------------------------------------------------
# Linear DDPM beta schedule: num_timesteps betas evenly spaced from beta_start to beta_end.
num_timesteps = 1000
beta_start, beta_end = 1e-4, 0.02

# Latent scaling constant (same as in DiTTrainer).
# NOTE: DiffusionSampler decodes with self.vae.config.scaling_factor rather than this value.
latent_scale_factor = 0.18215

# Percentage (0-100) of the current generation run, updated while sampling.
global_progress = 0

# When True, the DiT, latents and schedule tables are run in float16.
USE_HALF_PRECISION = True
def load_dit_model(dit_size):
    """Fetch the pretrained DiT checkpoint matching ``dit_size``.

    Args:
        dit_size: "S" (Small), "B" (Base) or "L" (Large).

    Returns:
        The model returned by ``DiT.from_pretrained`` for that size's Hub repo.

    Raises:
        ValueError: if ``dit_size`` is not one of the supported sizes.
    """
    size_to_repo = (
        ("S", "kaupane/DiT-Wikiart-Small"),
        ("B", "kaupane/DiT-Wikiart-Base"),
        ("L", "kaupane/DiT-Wikiart-Large"),
    )
    for size, repo_id in size_to_repo:
        if dit_size == size:
            return DiT.from_pretrained(repo_id)
    raise ValueError(f"Invalid DiT size: {dit_size}")
class DiffusionSampler:
    """DDPM sampler that turns DiT noise predictions into decoded images.

    The linear beta schedule (module-level ``num_timesteps``, ``beta_start``,
    ``beta_end``) and its derived coefficients are precomputed once. The VAE
    used to decode latents into images is loaded lazily on first use.
    """

    def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu", use_half=USE_HALF_PRECISION):
        """Precompute the diffusion schedule.

        Args:
            device: Device for the schedule tables, latents and models.
            use_half: If True, the DiT, the latents and the schedule tables used
                in the sampling loop are kept in float16.
        """
        # NOTE(review): use_half defaults to True even when device is "cpu";
        # fp16 on CPU may be slow or unsupported for some ops — confirm.
        self.device = device
        self.use_half = use_half
        self.vae = None  # populated by load_vae()

        # Schedule and derived quantities, computed in float32 on the CPU.
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alpha_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)
        self.alphas_cumprod_prev = torch.cat([torch.tensor([1.0]), self.alphas_cumprod[:-1]])
        # Posterior variance beta_t * (1 - abar_{t-1}) / (1 - abar_t); it is 0 at
        # t=0 because alphas_cumprod_prev[0] == 1.
        self.posterior_variance = self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)

        # Tables indexed inside the sampling loop live on the target device
        # (and in fp16 when half precision is enabled).
        self.sqrt_alphas_cumprod = self._to_sampling_tensor(self.sqrt_alphas_cumprod)
        self.sqrt_one_minus_alpha_cumprod = self._to_sampling_tensor(self.sqrt_one_minus_alpha_cumprod)
        self.sqrt_recip_alphas = self._to_sampling_tensor(self.sqrt_recip_alphas)
        self.betas = self._to_sampling_tensor(self.betas)
        self.posterior_variance = self._to_sampling_tensor(self.posterior_variance)

    def _to_sampling_tensor(self, tensor):
        """Move a schedule table to ``self.device``, casting to fp16 if ``use_half``."""
        tensor = tensor.to(self.device)
        return tensor.half() if self.use_half else tensor

    def load_vae(self):
        """Load the VAE on first call (lazily, to save memory until needed)."""
        if self.vae is None:
            self.vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").to(self.device)
            self.vae.eval()

    @staticmethod
    def _seed_everything(seed):
        """Seed torch (CPU and CUDA), NumPy and ``random`` for reproducibility."""
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)

    @spaces.GPU(duration=120)
    def generate_images(self, model, num_samples, genre, style, seed, progress=gr.Progress(), cfg_scale=2.5):
        """Sample ``num_samples`` images for one (genre, style) pair.

        Args:
            model: DiT to sample with. It is moved to ``self.device`` (and cast
                to fp16 when ``use_half``) in place. It must expose
                ``num_genres`` / ``num_styles`` (used as the unconditional labels
                for classifier-free guidance) and be callable as
                ``model(latents, t, genre_ids, style_ids)``.
            num_samples: Number of images to generate.
            genre: Integer genre id (a key of ``reverse_reduced_genre_mapping``).
            style: Integer style id (a key of ``reverse_reduced_style_mapping``).
            seed: Seed for torch / NumPy / ``random``; ``None`` leaves RNG state as is.
            progress: Gradio progress tracker.
            cfg_scale: Classifier-free guidance weight.

        Returns:
            A list of ``(image, caption)`` tuples for ``gr.Gallery``, where each
            image is a uint8 channels-last numpy array.
        """
        global global_progress
        global_progress = 0

        if seed is not None:
            self._seed_everything(seed)

        if self.use_half:
            model.half()
        model.to(self.device)
        model.eval()

        latents = self._denoise(model, num_samples, genre, style, cfg_scale, progress)
        images = self._decode(latents, progress)

        progress(1.0, desc="Done!")
        global_progress = 100

        # Every image in a batch shares the same conditioning, hence one caption.
        caption = f"Genre: {reverse_reduced_genre_mapping[genre]}, Style: {reverse_reduced_style_mapping[style]}"
        if seed is not None:
            caption += f" (Seed: {seed})"
        return [(image, caption) for image in images]

    def _denoise(self, model, num_samples, genre, style, cfg_scale, progress):
        """Run the reverse diffusion chain from Gaussian noise; return the final latents."""
        global global_progress

        def labels(value):
            return torch.tensor([value] * num_samples, device=self.device, dtype=torch.long)

        g_cond, s_cond = labels(genre), labels(style)
        # num_genres / num_styles serve as the "null" labels for the unconditional pass.
        g_null, s_null = labels(model.num_genres), labels(model.num_styles)

        latents = torch.randn((num_samples, 4, 32, 32), device=self.device)
        if self.use_half:
            latents = latents.half()

        with torch.no_grad():
            # Iterate over plain Python ints: looping over a device tensor yields
            # 0-d device tensors, and `t_val == 0` would then force a
            # device->host sync on every step.
            for step, t_val in enumerate(range(num_timesteps - 1, -1, -1)):
                global_progress = int(100 * step / num_timesteps)
                progress(global_progress / 100, desc="Generating images...")

                t = torch.full((num_samples,), t_val, device=self.device, dtype=torch.long)
                sqrt_recip_alphas_t = self.sqrt_recip_alphas[t].view(-1, 1, 1, 1)
                sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alpha_cumprod[t].view(-1, 1, 1, 1)
                beta_t = self.betas[t].view(-1, 1, 1, 1)
                posterior_variance_t = self.posterior_variance[t].view(-1, 1, 1, 1)

                # Classifier-free guidance.
                eps_cond = model(latents, t, g_cond, s_cond)
                eps_uncond = model(latents, t, g_null, s_null)
                eps = eps_uncond + cfg_scale * (eps_cond - eps_uncond)

                # DDPM posterior mean; noise is added on every step except the last.
                mean = sqrt_recip_alphas_t * (latents - (beta_t / sqrt_one_minus_alphas_cumprod_t) * eps)
                if t_val == 0:
                    latents = mean
                else:
                    latents = mean + torch.sqrt(posterior_variance_t) * torch.randn_like(latents)
        return latents

    def _decode(self, latents, progress):
        """Decode latents with the VAE into a uint8 (N, H, W, C) numpy array."""
        self.load_vae()
        # The VAE is never cast to half here, so decode in float32 after undoing
        # the VAE's latent scaling.
        latents = latents.float() / self.vae.config.scaling_factor
        latents = latents.to(self.device)

        progress(0.95, desc="Decoding images...")
        with torch.no_grad():
            images = self.vae.decode(latents).sample
        # Map the (presumably [-1, 1]) VAE output to [0, 1].
        images = (images / 2 + 0.5).clamp(0, 1)
        images = images.permute(0, 2, 3, 1).cpu().numpy()
        # Round rather than truncate when quantizing to 8 bits.
        return (images * 255).round().astype(np.uint8)
# One module-level sampler serves every request; its VAE is loaded on first use.
sampler = DiffusionSampler(use_half=USE_HALF_PRECISION)
def generate_random_seed():
    """Return a seed drawn uniformly from the unsigned 32-bit range [0, 2**32 - 1]."""
    max_seed = (1 << 32) - 1
    return random.randint(0, max_seed)
# Per DiT size: allowed range and default value of the "Number of Samples" slider.
MODEL_SAMPLE_LIMITS = {
    size: {"min": 1, "max": max_samples, "default": default_samples}
    for size, max_samples, default_samples in (
        ("S", 24, 6),
        ("B", 16, 4),
        ("L", 8, 2),
    )
}
def update_sample_slider(dit_size):
    """Re-range the sample-count slider for the newly selected DiT size.

    Returns a ``gr.update`` that sets the slider's min/max from
    ``MODEL_SAMPLE_LIMITS``, resets it to that size's default and refreshes
    the help text.
    """
    size_limits = MODEL_SAMPLE_LIMITS[dit_size]
    lo, hi = size_limits["min"], size_limits["max"]
    return gr.update(
        value=size_limits["default"],
        minimum=lo,
        maximum=hi,
        info=f"How many images to generate ({lo}-{hi})",
    )
@spaces.GPU(duration=120)
def generate_samples(num_samples, dit_size, genre_name, style_name, seed, progress=gr.Progress()):
    """Gradio click handler: validate inputs, load the DiT and generate a gallery.

    Args:
        num_samples: Requested number of images (from the slider).
        dit_size: DiT size key ("S", "B" or "L").
        genre_name: Key into ``reduced_genre_mapping``.
        style_name: Key into ``reduced_style_mapping``.
        seed: Seed forwarded to the sampler, or None.
        progress: Gradio progress tracker.

    Returns:
        ``(gallery_images, error_update)``: on success the gallery items and an
        update hiding the error box; on failure ``None`` and an update showing
        the error text.
    """
    limits = MODEL_SAMPLE_LIMITS.get(dit_size)
    if limits is None:
        return None, gr.update(value=f"Invalid DiT size: {dit_size}", visible=True)

    # Defensive cast: the count is used for list repetition and tensor shapes
    # downstream, both of which reject floats such as 4.0.
    try:
        num_samples = int(num_samples)
    except (TypeError, ValueError, OverflowError):
        return None, gr.update(value=f"Invalid number of samples: {num_samples}", visible=True)
    if num_samples < limits["min"] or num_samples > limits["max"]:
        return None, gr.update(value=f"Number of samples for {dit_size} model must be between {limits['min']} and {limits['max']}", visible=True)

    # Translate the UI labels into the integer ids the model was trained with.
    genre_id = reduced_genre_mapping.get(genre_name)
    style_id = reduced_style_mapping.get(style_name)
    if genre_id is None:
        return None, gr.update(value=f"Unknown genre: {genre_name}", visible=True)
    if style_id is None:
        return None, gr.update(value=f"Unknown style: {style_name}", visible=True)

    try:
        progress(0.05, desc="Loading DiT model...")
        model = load_dit_model(dit_size)
        gallery_images = sampler.generate_images(model, num_samples, genre_id, style_id, seed, progress)
        return gallery_images, gr.update(value="", visible=False)
    except Exception as e:
        # UI boundary: show the message to the user, but keep the full
        # traceback in the server log so failures can be debugged.
        import traceback
        traceback.print_exc()
        error_msg = f"Error: {str(e)}"
        return None, gr.update(value=error_msg, visible=True)
def clear_gallery():
    """Empty the gallery and hide the error box.

    Returns:
        A ``(gallery_value, error_update)`` pair: ``None`` for the gallery and
        an update that blanks and hides the error textbox.
    """
    hidden_error_box = gr.update(value="", visible=False)
    return None, hidden_error_box
# Create the Gradio interface.
with gr.Blocks(title="DiT Diffusion Model Generator", theme=gr.themes.Soft()) as app:
    gr.Markdown("# DiT Diffusion Model Generator")
    gr.Markdown("Generate art images using a Diffusion Transformer (DiT) model")

    with gr.Row():
        # Left column: generation controls.
        with gr.Column(scale=1):
            dit_size = gr.Radio(
                choices=["S", "B", "L"],
                value="B",
                label="DiT Model Size",
                info="S: Small (fastest), B: Base (balanced), L: Large (best quality but slowest)",
            )
            # Initial range matches the default "B" selection; update_sample_slider
            # re-ranges the slider whenever the size changes.
            num_samples = gr.Slider(
                minimum=MODEL_SAMPLE_LIMITS["B"]["min"],
                maximum=MODEL_SAMPLE_LIMITS["B"]["max"],
                value=MODEL_SAMPLE_LIMITS["B"]["default"],
                step=1,
                label="Number of Samples",
                info=f"How many images to generate ({MODEL_SAMPLE_LIMITS['B']['min']}-{MODEL_SAMPLE_LIMITS['B']['max']})",
            )

            # Alphabetical order, with a 'None' entry (if present) pinned to the
            # top regardless of how the other names are capitalized.
            genre_names = sorted(reduced_genre_mapping, key=lambda name: (name != "None", name))
            style_names = sorted(reduced_style_mapping, key=lambda name: (name != "None", name))
            genre = gr.Dropdown(choices=genre_names, value="landscape", label="Art Genre")
            style = gr.Dropdown(choices=style_names, value="impressionism", label="Art Style")

            with gr.Row():
                # Pass the function, not its result: Gradio calls a callable value
                # on every page load, so each visitor starts with a fresh seed
                # instead of one seed fixed when the app was built.
                seed = gr.Number(label="Seed", value=generate_random_seed, precision=0, info="Set for reproducible results")
                reset_seed_btn = gr.Button("🎲 New Seed")

            with gr.Row():
                generate_btn = gr.Button("Generate Images", variant="primary")
                clear_btn = gr.Button("🗑️ Clear Gallery")

        # Right column: results and errors.
        with gr.Column(scale=2):
            output_gallery = gr.Gallery(
                label="Generated Images",
                columns=6,
                rows=4,
                height=600,
                object_fit="contain",
                allow_preview=True,
                show_download_button=True,
            )
            error_message = gr.Textbox(label="Error", visible=False, max_lines=3, container=True, elem_id="error_box")

    # Event wiring.
    dit_size.change(update_sample_slider, inputs=[dit_size], outputs=[num_samples])
    reset_seed_btn.click(generate_random_seed, inputs=[], outputs=[seed])
    clear_btn.click(clear_gallery, inputs=[], outputs=[output_gallery, error_message])
    generate_btn.click(
        fn=generate_samples,
        inputs=[num_samples, dit_size, genre, style, seed],
        outputs=[output_gallery, error_message],
    )

if __name__ == "__main__":
    app.launch()