ZeroGPU
Browse files
demo.py
CHANGED
@@ -498,7 +498,7 @@ def generate_animation(
|
|
498 |
|
499 |
print("Animation generated")
|
500 |
|
501 |
-
return synthetic_video.detach() # B x C x T x H x W
|
502 |
|
503 |
|
504 |
@spaces.GPU
|
@@ -510,7 +510,8 @@ def decode_animation(latent_animation):
|
|
510 |
|
511 |
# Convert to torch tensor if needed
|
512 |
if not isinstance(latent_animation, torch.Tensor):
|
513 |
-
latent_animation = torch.from_numpy(latent_animation)
|
|
|
514 |
|
515 |
# Ensure shape is B x C x T x H x W
|
516 |
if len(latent_animation.shape) == 4: # [T, C, H, W]
|
|
|
498 |
|
499 |
print("Animation generated")
|
500 |
|
501 |
+
return synthetic_video.detach().cpu() # B x C x T x H x W
|
502 |
|
503 |
|
504 |
@spaces.GPU
|
|
|
510 |
|
511 |
# Convert to torch tensor if needed
|
512 |
if not isinstance(latent_animation, torch.Tensor):
|
513 |
+
latent_animation = torch.from_numpy(latent_animation)
|
514 |
+
latent_animation = latent_animation.to(device, dtype=dtype)
|
515 |
|
516 |
# Ensure shape is B x C x T x H x W
|
517 |
if len(latent_animation.shape) == 4: # [T, C, H, W]
|