Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -72,7 +72,7 @@ class DiffusionSampler:
|
|
72 |
self.vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").to(self.device)
|
73 |
self.vae.eval()
|
74 |
|
75 |
-
@spaces.GPU
|
76 |
def generate_images(self, model, num_samples, genre, style, seed, progress=gr.Progress()):
|
77 |
"""Generate images with the DiT model"""
|
78 |
global global_progress
|
@@ -166,11 +166,11 @@ def generate_random_seed():
|
|
166 |
"""Generate a random seed between 0 and 2^32 - 1"""
|
167 |
return random.randint(0, 2**32 - 1)
|
168 |
|
169 |
-
@spaces.GPU
|
170 |
def generate_samples(num_samples, dit_size, genre_name, style_name, seed, progress=gr.Progress()):
|
171 |
"""Main function for Gradio interface"""
|
172 |
-
if num_samples < 1 or num_samples >
|
173 |
-
return None, gr.update(value="Number of samples must be between 1 and
|
174 |
|
175 |
# Get genre and style IDs from mappings
|
176 |
genre_id = reduced_genre_mapping.get(genre_name)
|
@@ -205,8 +205,8 @@ with gr.Blocks(title="DiT Diffusion Model Generator", theme=gr.themes.Soft()) as
|
|
205 |
|
206 |
with gr.Row():
|
207 |
with gr.Column(scale=1):
|
208 |
-
num_samples = gr.Slider(minimum=1, maximum=
|
209 |
-
dit_size = gr.Radio(choices=["S", "B", "L"], value="
|
210 |
|
211 |
genre_names = list(reduced_genre_mapping.keys())
|
212 |
style_names = list(reduced_style_mapping.keys())
|
@@ -230,7 +230,7 @@ with gr.Blocks(title="DiT Diffusion Model Generator", theme=gr.themes.Soft()) as
|
|
230 |
progress_bar = gr.Progress(track_tqdm=True)
|
231 |
|
232 |
with gr.Column(scale=2):
|
233 |
-
output_gallery = gr.Gallery(label="Generated Images", columns=4, rows=
|
234 |
error_message = gr.Textbox(label="Error", visible=False, max_lines=3, container=True, elem_id="error_box")
|
235 |
|
236 |
# Seed reset button functionality
|
|
|
72 |
self.vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").to(self.device)
|
73 |
self.vae.eval()
|
74 |
|
75 |
+
@spaces.GPU(duration=120)
|
76 |
def generate_images(self, model, num_samples, genre, style, seed, progress=gr.Progress()):
|
77 |
"""Generate images with the DiT model"""
|
78 |
global global_progress
|
|
|
166 |
"""Generate a random seed between 0 and 2^32 - 1"""
|
167 |
return random.randint(0, 2**32 - 1)
|
168 |
|
169 |
+
@spaces.GPU(duration=120)
|
170 |
def generate_samples(num_samples, dit_size, genre_name, style_name, seed, progress=gr.Progress()):
|
171 |
"""Main function for Gradio interface"""
|
172 |
+
if num_samples < 1 or num_samples > 12:
|
173 |
+
return None, gr.update(value="Number of samples must be between 1 and 12", visible=True)
|
174 |
|
175 |
# Get genre and style IDs from mappings
|
176 |
genre_id = reduced_genre_mapping.get(genre_name)
|
|
|
205 |
|
206 |
with gr.Row():
|
207 |
with gr.Column(scale=1):
|
208 |
+
num_samples = gr.Slider(minimum=1, maximum=12, value=2, step=1, label="Number of Samples", info="How many images to generate (1-16)")
|
209 |
+
dit_size = gr.Radio(choices=["S", "B", "L"], value="B", label="DiT Model Size", info="Larger models produce better quality but take longer")
|
210 |
|
211 |
genre_names = list(reduced_genre_mapping.keys())
|
212 |
style_names = list(reduced_style_mapping.keys())
|
|
|
230 |
progress_bar = gr.Progress(track_tqdm=True)
|
231 |
|
232 |
with gr.Column(scale=2):
|
233 |
+
output_gallery = gr.Gallery(label="Generated Images", columns=4, rows=3, object_fit="contain", height=600)
|
234 |
error_message = gr.Textbox(label="Error", visible=False, max_lines=3, container=True, elem_id="error_box")
|
235 |
|
236 |
# Seed reset button functionality
|