Spaces:
Sleeping
Sleeping
import torch | |
import spaces | |
import gradio as gr | |
import numpy as np | |
import os | |
import random | |
from mapping import reduced_genre_mapping, reduced_style_mapping, reverse_reduced_genre_mapping, reverse_reduced_style_mapping | |
from diffusers import AutoencoderKL | |
from huggingface_hub import hf_hub_download | |
from models.DiT import DiT | |
# ---------------------------------------------------------------------------
# Global settings
# ---------------------------------------------------------------------------
# Diffusion schedule: `num_timesteps` betas spaced linearly from
# `beta_start` to `beta_end`.
num_timesteps = 1000
beta_start, beta_end = 1e-4, 0.02

# Latent scaling factor; same value as in DiTTrainer.
latent_scale_factor = 0.18215

# Progress of the current generation as a percentage (0-100), for the UI.
global_progress = 0

# Request float16 inference for the DiT sampling loop.
USE_HALF_PRECISION = True
# Hugging Face Hub repository for each DiT size accepted by load_dit_model.
_DIT_REPOS = {
    "S": "kaupane/DiT-Wikiart-Small",
    "B": "kaupane/DiT-Wikiart-Base",
    "L": "kaupane/DiT-Wikiart-Large",
}

# Single-slot cache: {dit_size: model} for the most recently loaded size.
_dit_model_cache = {}


def load_dit_model(dit_size):
    """Load the pretrained DiT model of the given size, reusing the last one.

    The most recently loaded model is kept so that repeated generations with
    the same size do not reload the weights. Callers may cast/move the model
    in place; those changes persist in the cached instance.

    Args:
        dit_size: One of "S", "B" or "L".

    Returns:
        The model returned by ``DiT.from_pretrained`` for that size.

    Raises:
        ValueError: If ``dit_size`` is not a known size.
    """
    repo_id = _DIT_REPOS.get(dit_size)
    if repo_id is None:
        raise ValueError(f"Invalid DiT size: {dit_size}")

    model = _dit_model_cache.get(dit_size)
    if model is None:
        # Drop the previously cached size *before* loading the new one so two
        # full models are never held in memory at the same time.
        _dit_model_cache.clear()
        model = DiT.from_pretrained(repo_id)
        _dit_model_cache[dit_size] = model
    return model
class DiffusionSampler:
    """Ancestral DDPM sampler for the genre/style-conditioned DiT models.

    The noise schedule is precomputed once in ``__init__``. The VAE that turns
    latents into pixels is loaded lazily on first use and then reused.
    """

    def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu", use_half=USE_HALF_PRECISION):
        """Precompute the diffusion schedule on ``device``.

        Args:
            device: Torch device for the schedule, the latents and the model.
            use_half: Request float16 sampling. This is honoured only on CUDA
                devices. Elsewhere the sampler stays in float32, because
                PyTorch's float16 support on CPU is incomplete and slow for
                many ops.
        """
        self.device = device
        self.use_half = bool(use_half) and torch.device(device).type == "cuda"
        self.vae = None

        # Linear beta schedule and derived quantities (float32, on CPU).
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alpha_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)
        self.alphas_cumprod_prev = torch.cat([torch.tensor([1.0]), self.alphas_cumprod[:-1]])
        # DDPM posterior variance. It is exactly 0 at t=0 because
        # alphas_cumprod_prev[0] == 1.
        self.posterior_variance = self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)

        # Tensors indexed inside the sampling loop live on the target device,
        # in the sampling dtype.
        dtype = torch.float16 if self.use_half else torch.float32
        self.sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.to(device=self.device, dtype=dtype)
        self.sqrt_one_minus_alpha_cumprod = self.sqrt_one_minus_alpha_cumprod.to(device=self.device, dtype=dtype)
        self.sqrt_recip_alphas = self.sqrt_recip_alphas.to(device=self.device, dtype=dtype)
        self.betas = self.betas.to(device=self.device, dtype=dtype)
        self.posterior_variance = self.posterior_variance.to(device=self.device, dtype=dtype)

    def load_vae(self):
        """Load the Stable Diffusion VAE on the first call; later calls are no-ops."""
        if self.vae is None:
            self.vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").to(self.device)
            self.vae.eval()

    def generate_images(self, model, num_samples, genre, style, seed, progress=gr.Progress()):
        """Generate ``num_samples`` images of one genre/style with ``model``.

        Args:
            model: DiT called as ``model(latents, t, genre_ids, style_ids)``.
                Its ``num_genres``/``num_styles`` attributes are used as the
                null labels for classifier-free guidance. The model is
                cast/moved in place.
            num_samples: Number of images to generate.
            genre: Integer genre id (a key of ``reverse_reduced_genre_mapping``).
            style: Integer style id (a key of ``reverse_reduced_style_mapping``).
            seed: Integer seed for reproducible sampling, or None to leave the
                RNGs untouched.
            progress: Progress callback invoked as ``progress(fraction, desc=...)``.

        Returns:
            A list of ``(image, caption)`` pairs, where ``image`` is an
            ``(H, W, C)`` uint8 NumPy array, ready for ``gr.Gallery``.
        """
        global global_progress
        global_progress = 0
        num_samples = int(num_samples)
        if seed is not None:
            seed = int(seed)

        # Build the caption up front so an unknown id fails before the
        # expensive sampling loop rather than after it.
        caption = f"Genre: {reverse_reduced_genre_mapping[genre]}, Style: {reverse_reduced_style_mapping[style]}"
        if seed is not None:
            caption += f" (Seed: {seed})"
            self._seed_everything(seed)

        if self.use_half:
            model.half()
        model.to(self.device)
        model.eval()

        latents = self._sample_latents(model, num_samples, genre, style, progress)

        self.load_vae()
        progress(0.95, desc="Decoding images...")
        images = self._decode_latents(latents)
        progress(1.0, desc="Done!")
        global_progress = 100

        # Round (rather than truncate) when quantizing [0, 1] floats to 8 bits.
        pixels = (images * 255).round().astype(np.uint8)
        return [(img, caption) for img in pixels]

    @staticmethod
    def _seed_everything(seed):
        """Seed torch (all devices), NumPy and ``random`` with ``seed``."""
        torch.manual_seed(seed)
        # NumPy only accepts seeds in [0, 2**32). Fold other integers into that
        # range so a negative or very large user seed doesn't abort generation.
        np.random.seed(seed % 2**32)
        random.seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    def _sample_latents(self, model, num_samples, genre, style, progress, cfg_scale=2.5):
        """Run the reverse diffusion chain from pure noise and return the final latents.

        Classifier-free guidance combines the two predictions as
        ``eps = eps_uncond + cfg_scale * (eps_cond - eps_uncond)``.
        """
        global global_progress

        def batch_of(value):
            # One long-typed entry per sample, all equal to `value`.
            return torch.full((num_samples,), int(value), device=self.device, dtype=torch.long)

        g_cond, s_cond = batch_of(genre), batch_of(style)
        # model.num_genres / model.num_styles serve as the unconditional labels.
        g_null, s_null = batch_of(model.num_genres), batch_of(model.num_styles)

        # Draw the noise in float32 and cast afterwards, so the random stream
        # for a given seed does not depend on use_half.
        latents = torch.randn((num_samples, 4, 32, 32), device=self.device)
        if self.use_half:
            latents = latents.half()

        with torch.no_grad():
            # Iterate plain ints: no device sync is needed for the t == 0 check.
            for step, t_val in enumerate(range(num_timesteps - 1, -1, -1)):
                global_progress = int(100 * step / num_timesteps)
                progress(global_progress / 100, desc="Generating images...")

                t = batch_of(t_val)
                sqrt_recip_alpha_t = self.sqrt_recip_alphas[t].view(-1, 1, 1, 1)
                sqrt_one_minus_acp_t = self.sqrt_one_minus_alpha_cumprod[t].view(-1, 1, 1, 1)
                beta_t = self.betas[t].view(-1, 1, 1, 1)
                var_t = self.posterior_variance[t].view(-1, 1, 1, 1)

                eps_cond = model(latents, t, g_cond, s_cond)
                eps_uncond = model(latents, t, g_null, s_null)
                eps = eps_uncond + cfg_scale * (eps_cond - eps_uncond)

                # Posterior mean of x_{t-1} given the guided noise prediction.
                mean = sqrt_recip_alpha_t * (latents - (beta_t / sqrt_one_minus_acp_t) * eps)
                if t_val > 0:
                    latents = mean + torch.sqrt(var_t) * torch.randn_like(latents)
                else:
                    # The last step is deterministic (posterior variance is 0 at t=0).
                    latents = mean
        return latents

    def _decode_latents(self, latents):
        """Decode latents with the already-loaded VAE.

        Returns:
            A float NumPy array in [0, 1], shaped (N, H, W, C).
        """
        # The VAE is loaded with default float32 weights, so undo any
        # half-precision cast on the latents first.
        latents = latents.float().to(self.device)
        # Undo the latent scaling applied at training time (same constant as DiTTrainer).
        latents = latents / latent_scale_factor
        with torch.no_grad():
            images = self.vae.decode(latents).sample
        # Map the decoder output (presumably in [-1, 1]) to [0, 1].
        images = (images / 2 + 0.5).clamp(0, 1)
        return images.permute(0, 2, 3, 1).cpu().numpy()
# One module-level sampler is shared by every request. Its precomputed
# schedule and its lazily loaded VAE are therefore created only once.
sampler = DiffusionSampler(use_half=USE_HALF_PRECISION)
def generate_random_seed():
    """Return a uniformly random integer seed in the closed range [0, 2**32 - 1]."""
    upper_exclusive = 1 << 32
    return random.randrange(upper_exclusive)
# Bounds and default for the "Number of Samples" slider, per DiT size.
# Larger models get smaller maximum batches.
MODEL_SAMPLE_LIMITS = {
    size: {"min": 1, "max": max_samples, "default": default}
    for size, max_samples, default in (
        ("S", 24, 6),
        ("B", 16, 4),
        ("L", 8, 2),
    )
}
def update_sample_slider(dit_size):
    """Return a Gradio update that rescales the sample-count slider for ``dit_size``."""
    lo, hi, default = (MODEL_SAMPLE_LIMITS[dit_size][key] for key in ("min", "max", "default"))
    return gr.update(
        minimum=lo,
        maximum=hi,
        value=default,
        info=f"How many images to generate ({lo}-{hi})",
    )
def generate_samples(num_samples, dit_size, genre_name, style_name, seed, progress=gr.Progress()):
    """Gradio handler for the "Generate Images" button.

    Validates the form values, loads the requested DiT model and samples images.

    Args:
        num_samples: Requested number of images (the slider value).
        dit_size: Model size key ("S", "B" or "L").
        genre_name: Genre label, looked up in ``reduced_genre_mapping``.
        style_name: Style label, looked up in ``reduced_style_mapping``.
        seed: Seed from the number box, or None if the box is empty.
        progress: Gradio progress tracker, injected by Gradio.

    Returns:
        ``(gallery_images, error_update)``. On success this is the gallery
        items plus an update that hides the error box. On failure it is
        ``None`` plus an update that shows the error text.
    """
    import traceback

    # Gradio numeric components may hand back floats; the sampler needs ints.
    num_samples = int(num_samples)
    if seed is not None:
        seed = int(seed)

    limits = MODEL_SAMPLE_LIMITS[dit_size]
    if not limits["min"] <= num_samples <= limits["max"]:
        return None, gr.update(value=f"Number of samples for {dit_size} model must be between {limits['min']} and {limits['max']}", visible=True)

    # Translate the human-readable labels into the model's class ids.
    genre_id = reduced_genre_mapping.get(genre_name)
    style_id = reduced_style_mapping.get(style_name)
    if genre_id is None:
        return None, gr.update(value=f"Unknown genre: {genre_name}", visible=True)
    if style_id is None:
        return None, gr.update(value=f"Unknown style: {style_name}", visible=True)

    try:
        progress(0.05, desc="Loading DiT model...")
        model = load_dit_model(dit_size)
        gallery_images = sampler.generate_images(model, num_samples, genre_id, style_id, seed, progress)
    except Exception as e:
        # Top-level UI boundary: show the message to the user, but keep the
        # full traceback in the server log so failures can be diagnosed.
        traceback.print_exc()
        return None, gr.update(value=f"Error: {e}", visible=True)
    return gallery_images, gr.update(value="", visible=False)
def clear_gallery():
    """Empty the output gallery and hide the error box."""
    hidden_error = gr.update(value="", visible=False)
    return None, hidden_error
# Create the Gradio interface.
# DiT size preselected in the UI; its limits also initialise the slider.
DEFAULT_DIT_SIZE = "B"


def _none_first(name):
    """Sort key: alphabetical order, but with the literal option "None" first."""
    return (name != "None", name)


with gr.Blocks(title="DiT Diffusion Model Generator", theme=gr.themes.Soft()) as app:
    gr.Markdown("# DiT Diffusion Model Generator")
    gr.Markdown("Generate art images using a Diffusion Transformer (DiT) model")

    with gr.Row():
        with gr.Column(scale=1):
            dit_size = gr.Radio(
                choices=["S", "B", "L"],
                value=DEFAULT_DIT_SIZE,
                label="DiT Model Size",
                info="S: Small (fastest), B: Base (balanced), L: Large (best quality but slowest)",
            )
            default_limits = MODEL_SAMPLE_LIMITS[DEFAULT_DIT_SIZE]
            num_samples = gr.Slider(
                minimum=default_limits["min"],
                maximum=default_limits["max"],
                value=default_limits["default"],
                step=1,
                label="Number of Samples",
                info=f"How many images to generate ({default_limits['min']}-{default_limits['max']})",
            )

            # Alphabetical order, with the "None" option (if present) pinned
            # to the top regardless of letter case.
            genre_names = sorted(reduced_genre_mapping, key=_none_first)
            style_names = sorted(reduced_style_mapping, key=_none_first)
            genre = gr.Dropdown(choices=genre_names, value="landscape", label="Art Genre")
            style = gr.Dropdown(choices=style_names, value="impressionism", label="Art Style")

            with gr.Row():
                # A callable default is re-evaluated on every page load. Each
                # visitor therefore starts with a fresh seed, instead of one
                # fixed when the app was built.
                seed = gr.Number(label="Seed", value=generate_random_seed, precision=0, info="Set for reproducible results")
                reset_seed_btn = gr.Button("🎲 New Seed")

            with gr.Row():
                generate_btn = gr.Button("Generate Images", variant="primary")
                clear_btn = gr.Button("🗑️ Clear Gallery")

        with gr.Column(scale=2):
            output_gallery = gr.Gallery(
                label="Generated Images",
                columns=6,
                rows=4,
                height=600,
                object_fit="contain",
                allow_preview=True,
                show_download_button=True,
            )
            error_message = gr.Textbox(label="Error", visible=False, max_lines=3, container=True, elem_id="error_box")

    # Rescale the sample slider whenever the model size changes.
    dit_size.change(update_sample_slider, inputs=[dit_size], outputs=[num_samples])
    # Seed reset button functionality
    reset_seed_btn.click(generate_random_seed, inputs=[], outputs=[seed])
    # Clear gallery button functionality
    clear_btn.click(clear_gallery, inputs=[], outputs=[output_gallery, error_message])
    # Main generation action
    generate_btn.click(
        fn=generate_samples,
        inputs=[num_samples, dit_size, genre, style, seed],
        outputs=[output_gallery, error_message],
    )
def _main():
    """Start the Gradio server for the app built above."""
    app.launch()


# Launch only when executed as a script, not when imported.
if __name__ == "__main__":
    _main()