HReynaud commited on
Commit
2d696f5
·
1 Parent(s): f2a9d21

lower ZeroGPU usage

Browse files
Files changed (1) hide show
  1. demo.py +12 -12
demo.py CHANGED
@@ -242,8 +242,8 @@ def preprocess_mask(mask):
242
  return np.array(mask_pil)
243
 
244
 
245
- @spaces.GPU
246
- @torch.no_grad(duration=3)
247
  def generate_latent_image(mask, class_selection, sampling_steps=50):
248
  """Generate a latent image based on mask, class selection, and sampling steps"""
249
 
@@ -306,8 +306,8 @@ def generate_latent_image(mask, class_selection, sampling_steps=50):
306
  return latent_image # B x C x H x W
307
 
308
 
309
- @spaces.GPU
310
- @torch.no_grad(duration=3)
311
  def decode_images(latents):
312
  """Decode latent representations to pixel space using a VAE.
313
 
@@ -385,8 +385,8 @@ def decode_latent_to_pixel(latent_image):
385
  return decoded_image
386
 
387
 
388
- @spaces.GPU
389
- @torch.no_grad(duration=3)
390
  def check_privacy(latent_image_numpy, class_selection):
391
  """Check if the latent image is too similar to database images"""
392
  latent_image = torch.from_numpy(latent_image_numpy).to(device, dtype=dtype)
@@ -412,8 +412,8 @@ def check_privacy(latent_image_numpy, class_selection):
412
  )
413
 
414
 
415
- @spaces.GPU
416
- @torch.no_grad(duration=3)
417
  def generate_animation(
418
  latent_image, ejection_fraction, sampling_steps=50, cfg_scale=1.0
419
  ):
@@ -501,8 +501,8 @@ def generate_animation(
501
  return synthetic_video.detach().cpu() # B x C x T x H x W
502
 
503
 
504
- @spaces.GPU
505
- @torch.no_grad(duration=3)
506
  def decode_animation(latent_animation):
507
  """Decode a latent animation to pixel space"""
508
  if latent_animation is None:
@@ -577,8 +577,8 @@ def convert_latent_to_display(latent_image):
577
  return display_image
578
 
579
 
580
- @spaces.GPU
581
- @torch.no_grad(duration=3)
582
  def latent_animation_to_grayscale(latent_animation):
583
  """Convert multi-channel latent animation to grayscale for display"""
584
  if latent_animation is None:
 
242
  return np.array(mask_pil)
243
 
244
 
245
+ @spaces.GPU(duration=3)
246
+ @torch.no_grad()
247
  def generate_latent_image(mask, class_selection, sampling_steps=50):
248
  """Generate a latent image based on mask, class selection, and sampling steps"""
249
 
 
306
  return latent_image # B x C x H x W
307
 
308
 
309
+ @spaces.GPU(duration=3)
310
+ @torch.no_grad()
311
  def decode_images(latents):
312
  """Decode latent representations to pixel space using a VAE.
313
 
 
385
  return decoded_image
386
 
387
 
388
+ @spaces.GPU(duration=3)
389
+ @torch.no_grad()
390
  def check_privacy(latent_image_numpy, class_selection):
391
  """Check if the latent image is too similar to database images"""
392
  latent_image = torch.from_numpy(latent_image_numpy).to(device, dtype=dtype)
 
412
  )
413
 
414
 
415
+ @spaces.GPU(duration=3)
416
+ @torch.no_grad()
417
  def generate_animation(
418
  latent_image, ejection_fraction, sampling_steps=50, cfg_scale=1.0
419
  ):
 
501
  return synthetic_video.detach().cpu() # B x C x T x H x W
502
 
503
 
504
+ @spaces.GPU(duration=3)
505
+ @torch.no_grad()
506
  def decode_animation(latent_animation):
507
  """Decode a latent animation to pixel space"""
508
  if latent_animation is None:
 
577
  return display_image
578
 
579
 
580
+ @spaces.GPU(duration=3)
581
+ @torch.no_grad()
582
  def latent_animation_to_grayscale(latent_animation):
583
  """Convert multi-channel latent animation to grayscale for display"""
584
  if latent_animation is None: