kaupane commited on
Commit
ee39364
·
verified ·
1 Parent(s): b8d8f05

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -7
app.py CHANGED
@@ -72,7 +72,7 @@ class DiffusionSampler:
72
  self.vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").to(self.device)
73
  self.vae.eval()
74
 
75
- @spaces.GPU
76
  def generate_images(self, model, num_samples, genre, style, seed, progress=gr.Progress()):
77
  """Generate images with the DiT model"""
78
  global global_progress
@@ -166,11 +166,11 @@ def generate_random_seed():
166
  """Generate a random seed between 0 and 2^32 - 1"""
167
  return random.randint(0, 2**32 - 1)
168
 
169
- @spaces.GPU
170
  def generate_samples(num_samples, dit_size, genre_name, style_name, seed, progress=gr.Progress()):
171
  """Main function for Gradio interface"""
172
- if num_samples < 1 or num_samples > 16:
173
- return None, gr.update(value="Number of samples must be between 1 and 16", visible=True)
174
 
175
  # Get genre and style IDs from mappings
176
  genre_id = reduced_genre_mapping.get(genre_name)
@@ -205,8 +205,8 @@ with gr.Blocks(title="DiT Diffusion Model Generator", theme=gr.themes.Soft()) as
205
 
206
  with gr.Row():
207
  with gr.Column(scale=1):
208
- num_samples = gr.Slider(minimum=1, maximum=16, value=4, step=1, label="Number of Samples", info="How many images to generate (1-16)")
209
- dit_size = gr.Radio(choices=["S", "B", "L"], value="S", label="DiT Model Size", info="Larger models produce better quality but take longer")
210
 
211
  genre_names = list(reduced_genre_mapping.keys())
212
  style_names = list(reduced_style_mapping.keys())
@@ -230,7 +230,7 @@ with gr.Blocks(title="DiT Diffusion Model Generator", theme=gr.themes.Soft()) as
230
  progress_bar = gr.Progress(track_tqdm=True)
231
 
232
  with gr.Column(scale=2):
233
- output_gallery = gr.Gallery(label="Generated Images", columns=4, rows=4, object_fit="contain", height=600)
234
  error_message = gr.Textbox(label="Error", visible=False, max_lines=3, container=True, elem_id="error_box")
235
 
236
  # Seed reset button functionality
 
72
  self.vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").to(self.device)
73
  self.vae.eval()
74
 
75
+ @spaces.GPU(duration=120)
76
  def generate_images(self, model, num_samples, genre, style, seed, progress=gr.Progress()):
77
  """Generate images with the DiT model"""
78
  global global_progress
 
166
  """Generate a random seed between 0 and 2^32 - 1"""
167
  return random.randint(0, 2**32 - 1)
168
 
169
+ @spaces.GPU(duration=120)
170
  def generate_samples(num_samples, dit_size, genre_name, style_name, seed, progress=gr.Progress()):
171
  """Main function for Gradio interface"""
172
+ if num_samples < 1 or num_samples > 12:
173
+ return None, gr.update(value="Number of samples must be between 1 and 12", visible=True)
174
 
175
  # Get genre and style IDs from mappings
176
  genre_id = reduced_genre_mapping.get(genre_name)
 
205
 
206
  with gr.Row():
207
  with gr.Column(scale=1):
208
+ num_samples = gr.Slider(minimum=1, maximum=12, value=2, step=1, label="Number of Samples", info="How many images to generate (1-16)")
209
+ dit_size = gr.Radio(choices=["S", "B", "L"], value="B", label="DiT Model Size", info="Larger models produce better quality but take longer")
210
 
211
  genre_names = list(reduced_genre_mapping.keys())
212
  style_names = list(reduced_style_mapping.keys())
 
230
  progress_bar = gr.Progress(track_tqdm=True)
231
 
232
  with gr.Column(scale=2):
233
+ output_gallery = gr.Gallery(label="Generated Images", columns=4, rows=3, object_fit="contain", height=600)
234
  error_message = gr.Textbox(label="Error", visible=False, max_lines=3, container=True, elem_id="error_box")
235
 
236
  # Seed reset button functionality