HReynaud commited on
Commit
2a15265
·
1 Parent(s): f03832e
Files changed (1) hide show
  1. demo.py +3 -2
demo.py CHANGED
@@ -498,7 +498,7 @@ def generate_animation(
498
 
499
  print("Animation generated")
500
 
501
- return synthetic_video.detach() # B x C x T x H x W
502
 
503
 
504
  @spaces.GPU
@@ -510,7 +510,8 @@ def decode_animation(latent_animation):
510
 
511
  # Convert to torch tensor if needed
512
  if not isinstance(latent_animation, torch.Tensor):
513
- latent_animation = torch.from_numpy(latent_animation).to(device, dtype=dtype)
 
514
 
515
  # Ensure shape is B x C x T x H x W
516
  if len(latent_animation.shape) == 4: # [T, C, H, W]
 
498
 
499
  print("Animation generated")
500
 
501
+ return synthetic_video.detach().cpu() # B x C x T x H x W
502
 
503
 
504
  @spaces.GPU
 
510
 
511
  # Convert to torch tensor if needed
512
  if not isinstance(latent_animation, torch.Tensor):
513
+ latent_animation = torch.from_numpy(latent_animation)
514
+ latent_animation = latent_animation.to(device, dtype=dtype)
515
 
516
  # Ensure shape is B x C x T x H x W
517
  if len(latent_animation.shape) == 4: # [T, C, H, W]