import torch
import spaces
import gradio as gr
import numpy as np
import os
import random
from mapping import reduced_genre_mapping, reduced_style_mapping, reverse_reduced_genre_mapping, reverse_reduced_style_mapping
from diffusers import AutoencoderKL
from huggingface_hub import hf_hub_download
from models.DiT import DiT

# Global settings 
num_timesteps = 1000
beta_start = 1e-4
beta_end = 0.02
latent_scale_factor = 0.18215  # Same as in DiTTrainer

# For tracking progress in UI
global_progress = 0

# Enable half precision inference
USE_HALF_PRECISION = True

def load_dit_model(dit_size):
    """Load DiT model of specified size"""
    
    # Configure model based on size
    if dit_size == "S":
        model = DiT.from_pretrained("kaupane/DiT-Wikiart-Small")
    elif dit_size == "B":
        model = DiT.from_pretrained("kaupane/DiT-Wikiart-Base")
    elif dit_size == "L":
        model = DiT.from_pretrained("kaupane/DiT-Wikiart-Large")
    else:
        raise ValueError(f"Invalid DiT size: {dit_size}")
    
    return model

class DiffusionSampler:
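    """DDPM sampler that denoises DiT latents and decodes them to images with the SD VAE."""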
    def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu", use_half=USE_HALF_PRECISION):
        self.device = device
        self.use_half = use_half
        self.vae = None
        
        # Pre-compute diffusion parameters
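        # Linear beta schedule (DDPM): alpha_t = 1 - beta_t, alpha_bar_t = prod_{s<=t} alpha_s,
        # and posterior variance beta_tilde_t = beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t),
        # used when sampling x_{t-1} from x_t in the reverse process.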
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alpha_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)
        self.alphas_cumprod_prev = torch.cat([torch.tensor([1.0]), self.alphas_cumprod[:-1]])
        self.posterior_variance = self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        
        # Move to device
        self.sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.to(self.device)
        self.sqrt_one_minus_alpha_cumprod = self.sqrt_one_minus_alpha_cumprod.to(self.device)
        self.sqrt_recip_alphas = self.sqrt_recip_alphas.to(self.device)
        self.betas = self.betas.to(self.device)
        self.posterior_variance = self.posterior_variance.to(self.device)

        # Convert to half precision if needed
        if self.use_half:
            self.sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.half()
            self.sqrt_one_minus_alpha_cumprod = self.sqrt_one_minus_alpha_cumprod.half()
            self.sqrt_recip_alphas = self.sqrt_recip_alphas.half()
            self.betas = self.betas.half()
            self.posterior_variance = self.posterior_variance.half()
        
    def load_vae(self):
        """Load VAE model (done lazily to save memory until needed)"""
        if self.vae is None:
            self.vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").to(self.device)
            self.vae.eval()

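    # spaces.GPU requests a ZeroGPU slot (here up to 120 seconds) for the duration of this call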
    @spaces.GPU(duration=120)
    def generate_images(self, model, num_samples, genre, style, seed, progress=gr.Progress()):
        """Generate images with the DiT model"""
        global global_progress
        global_progress = 0
        
        # Set random seed for reproducibility
        if seed is not None:
            torch.manual_seed(seed)
            np.random.seed(seed)
            random.seed(seed)
            # Also set CUDA seed if using GPU
            if torch.cuda.is_available():
                torch.cuda.manual_seed(seed)
                torch.cuda.manual_seed_all(seed)

        if self.use_half:
            model.half()
        model.to(self.device)
        model.eval()
        
        # Convert genre and style to tensors
        g_cond = torch.tensor([genre] * num_samples, device=self.device, dtype=torch.long)
        s_cond = torch.tensor([style] * num_samples, device=self.device, dtype=torch.long)
        g_null = torch.tensor([model.num_genres] * num_samples, device=self.device, dtype=torch.long)
        s_null = torch.tensor([model.num_styles] * num_samples, device=self.device, dtype=torch.long)
        
        # Start with random latents
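        # 4x32x32 latents decode to 256x256 RGB images through the SD VAE (8x spatial upsampling)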
        latents = torch.randn((num_samples, 4, 32, 32), device=self.device)
        if self.use_half:
            latents = latents.half()
        
        # Use classifier-free guidance for better quality
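        # cfg_scale > 1 amplifies the conditional signal; values around 2-4 are typical for latent diffusion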
        cfg_scale = 2.5
        
        # Go through the reverse diffusion process
        timesteps = torch.arange(num_timesteps - 1, -1, -1, device=self.device)
        total_steps = len(timesteps)
        
        with torch.no_grad():
            for i, t_val in enumerate(timesteps):
                # Update progress
                global_progress = int(100 * i / total_steps)
                progress(global_progress / 100, desc="Generating images...")
                
                t = torch.full((num_samples,), t_val, device=self.device, dtype=torch.long)

                sqrt_recip_alphas_t = self.sqrt_recip_alphas[t].view(-1, 1, 1, 1)
                sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alpha_cumprod[t].view(-1, 1, 1, 1)
                beta_t = self.betas[t].view(-1, 1, 1, 1)
                posterior_variance_t = self.posterior_variance[t].view(-1, 1, 1, 1)

                # Get noise prediction with classifier-free guidance
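                # eps = eps_uncond + cfg_scale * (eps_cond - eps_uncond): extrapolate away from the unconditional prediction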
                eps_theta_cond = model(latents, t, g_cond, s_cond)
                eps_theta_uncond = model(latents, t, g_null, s_null)
                eps_theta = eps_theta_uncond + cfg_scale * (eps_theta_cond - eps_theta_uncond)

                # Update latents
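                # DDPM ancestral step: mu_theta = 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta);
                # x_{t-1} = mu_theta + sqrt(posterior_variance_t) * z, with no noise added at the final step (t = 0)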
                mean = sqrt_recip_alphas_t * (latents - (beta_t / sqrt_one_minus_alphas_cumprod_t) * eps_theta)
                noise = torch.randn_like(latents)
                if t_val == 0:
                    latents = mean
                else:
                    latents = mean + torch.sqrt(posterior_variance_t) * noise
        
        # Decode latents to images
        self.load_vae()

        # Convert back to float
        if self.use_half:
            latents = latents.float()
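        # Undo the SD latent scaling (scaling_factor = 0.18215) before decoding back to pixel space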
        latents = latents / self.vae.config.scaling_factor
        latents = latents.to(self.device)
        
        progress(0.95, desc="Decoding images...")
        with torch.no_grad():
            images = self.vae.decode(latents).sample
        images = (images / 2 + 0.5).clamp(0, 1)
        images = images.permute(0, 2, 3, 1).cpu().numpy()
        
        progress(1.0, desc="Done!")
        global_progress = 100
        
        # Create image gallery with labels
        gallery_images = []
        for i in range(num_samples):
            # Convert numpy array to PIL Image
            img = (images[i] * 255).astype(np.uint8)
            caption = f"Genre: {reverse_reduced_genre_mapping[genre]}, Style: {reverse_reduced_style_mapping[style]}"
            if seed is not None:
                caption += f" (Seed: {seed})"
            gallery_images.append((img, caption))
        
        return gallery_images

# Initialize sampler globally
sampler = DiffusionSampler()

def generate_random_seed():
    """Generate a random seed between 0 and 2^32 - 1"""
    return random.randint(0, 2**32 - 1)

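# Per-size sample caps: larger DiT variants need more GPU memory, so fewer images per request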
MODEL_SAMPLE_LIMITS = {
    "S": {"min":1, "max": 24, "default": 6},
    "B": {"min":1, "max": 16, "default": 4},
    "L": {"min":1, "max": 8, "default": 2}
}

def update_sample_slider(dit_size):
    limits = MODEL_SAMPLE_LIMITS[dit_size]
    return gr.update(
        minimum=limits["min"],
        maximum=limits["max"],
        value=limits["default"],
        info=f"How many images to generate ({limits['min']}-{limits['max']})"        
    )

@spaces.GPU(duration=120)
def generate_samples(num_samples, dit_size, genre_name, style_name, seed, progress=gr.Progress()):
    """Main function for Gradio interface"""
    limits = MODEL_SAMPLE_LIMITS[dit_size]
    if num_samples < limits["min"] or num_samples > limits["max"]:
        return None, gr.update(value=f"Number of samples for {dit_size} model must be between {limits['min']} and {limits['max']}", visible=True)
    
    # Get genre and style IDs from mappings
    genre_id = reduced_genre_mapping.get(genre_name)
    style_id = reduced_style_mapping.get(style_name)
    
    if genre_id is None:
        return None, gr.update(value=f"Unknown genre: {genre_name}", visible=True)
    if style_id is None:
        return None, gr.update(value=f"Unknown style: {style_name}", visible=True)
    
    try:
        # Load model
        progress(0.05, desc="Loading DiT model...")
        model = load_dit_model(dit_size)
        
        # Generate images
        gallery_images = sampler.generate_images(model, num_samples, genre_id, style_id, seed, progress)
        
        return gallery_images, gr.update(value="", visible=False)
    except Exception as e:
        error_msg = f"Error: {str(e)}"
        return None, gr.update(value=error_msg, visible=True)

def clear_gallery():
    """Clear the gallery display"""
    return None, gr.update(value="", visible=False)

# Create the Gradio interface
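# Left column: model size, sample count, genre/style, and seed controls; right column: gallery and error box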
with gr.Blocks(title="DiT Diffusion Model Generator", theme=gr.themes.Soft()) as app:
    gr.Markdown("# DiT Diffusion Model Generator")
    gr.Markdown("Generate art images using a Diffusion Transformer (DiT) model")
    
    with gr.Row():
        with gr.Column(scale=1):
            dit_size = gr.Radio(
                choices=["S", "B", "L"], 
                value="B", 
                label="DiT Model Size", 
                info="S: Small (fastest), B: Base (balanced), L: Large (best quality but slowest)"
            )

            num_samples = gr.Slider(
                minimum=MODEL_SAMPLE_LIMITS["B"]["min"],
                maximum=MODEL_SAMPLE_LIMITS["B"]["max"],
                value=MODEL_SAMPLE_LIMITS["B"]["default"],
                step=1,
                label="Number of Samples",
                info=f"How many images to generate ({MODEL_SAMPLE_LIMITS['B']['min']}-{MODEL_SAMPLE_LIMITS['B']['max']})"
            )
            
            genre_names = list(reduced_genre_mapping.keys())
            style_names = list(reduced_style_mapping.keys())
            
            # Sort alphabetically; if a 'None' option exists in the mappings, keep it at the top
            genre_names.sort()
            if "None" in genre_names:
                genre_names.insert(0, genre_names.pop(genre_names.index("None")))
            style_names.sort()
            if "None" in style_names:
                style_names.insert(0, style_names.pop(style_names.index("None")))
            
            genre = gr.Dropdown(choices=genre_names, value="landscape", label="Art Genre")
            style = gr.Dropdown(choices=style_names, value="impressionism", label="Art Style")
            
            with gr.Row():
                seed = gr.Number(label="Seed", value=generate_random_seed(), precision=0, info="Set for reproducible results")
                reset_seed_btn = gr.Button("🎲 New Seed")
            
            with gr.Row():
                generate_btn = gr.Button("Generate Images", variant="primary")
                clear_btn = gr.Button("🗑️ Clear Gallery")
            
            progress_bar = gr.Progress(track_tqdm=True)
            
        with gr.Column(scale=2):
            output_gallery = gr.Gallery(
                label="Generated Images",
                columns=6,
                rows=4,
                height=600,
                object_fit="contain",
                allow_preview=True,
                show_download_button=True
            )
            error_message = gr.Textbox(label="Error", visible=False, max_lines=3, container=True, elem_id="error_box")

    
    dit_size.change(update_sample_slider, inputs=[dit_size], outputs=[num_samples])
    
    # Seed reset button functionality
    reset_seed_btn.click(generate_random_seed, inputs=[], outputs=[seed])
    
    # Clear gallery button functionality
    clear_btn.click(clear_gallery, inputs=[], outputs=[output_gallery, error_message])
    
    # Connect components
    generate_btn.click(
        fn=generate_samples,
        inputs=[num_samples, dit_size, genre, style, seed],
        outputs=[output_gallery, error_message],
    )



if __name__ == "__main__":
    app.launch()