Spaces:
Runtime error
Runtime error
Update opensora/models/ae/videobase/causal_vae/modeling_causalvae.py
Browse files
opensora/models/ae/videobase/causal_vae/modeling_causalvae.py
CHANGED
@@ -316,7 +316,8 @@ class CausalVAEModel(VideoBaseAE_PL):
|
|
316 |
self.tile_sample_min_size = 256
|
317 |
self.tile_sample_min_size_t = 65
|
318 |
self.tile_latent_min_size = int(self.tile_sample_min_size / (2 ** (len(hidden_size_mult) - 1)))
|
319 |
-
|
|
|
320 |
self.tile_overlap_factor = 0.25
|
321 |
self.use_tiling = False
|
322 |
|
@@ -374,8 +375,9 @@ class CausalVAEModel(VideoBaseAE_PL):
|
|
374 |
if self.use_tiling and (
|
375 |
x.shape[-1] > self.tile_sample_min_size
|
376 |
or x.shape[-2] > self.tile_sample_min_size
|
|
|
377 |
):
|
378 |
-
return self.
|
379 |
h = self.encoder(x)
|
380 |
moments = self.quant_conv(h)
|
381 |
posterior = DiagonalGaussianDistribution(moments)
|
@@ -385,8 +387,9 @@ class CausalVAEModel(VideoBaseAE_PL):
|
|
385 |
if self.use_tiling and (
|
386 |
z.shape[-1] > self.tile_latent_min_size
|
387 |
or z.shape[-2] > self.tile_latent_min_size
|
|
|
388 |
):
|
389 |
-
return self.
|
390 |
z = self.post_quant_conv(z)
|
391 |
dec = self.decoder(z)
|
392 |
return dec
|
@@ -554,7 +557,54 @@ class CausalVAEModel(VideoBaseAE_PL):
|
|
554 |
) + b[:, :, :, :, x] * (x / blend_extent)
|
555 |
return b
|
556 |
|
557 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
558 |
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
|
559 |
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
|
560 |
row_limit = self.tile_latent_min_size - blend_extent
|
@@ -590,7 +640,8 @@ class CausalVAEModel(VideoBaseAE_PL):
|
|
590 |
|
591 |
moments = torch.cat(result_rows, dim=3)
|
592 |
posterior = DiagonalGaussianDistribution(moments)
|
593 |
-
|
|
|
594 |
return posterior
|
595 |
|
596 |
def tiled_decode2d(self, z):
|
|
|
316 |
self.tile_sample_min_size = 256
|
317 |
self.tile_sample_min_size_t = 65
|
318 |
self.tile_latent_min_size = int(self.tile_sample_min_size / (2 ** (len(hidden_size_mult) - 1)))
|
319 |
+
t_down_ratio = [i for i in encoder_temporal_downsample if len(i) > 0]
|
320 |
+
self.tile_latent_min_size_t = int((self.tile_sample_min_size_t-1) / (2 ** len(t_down_ratio))) + 1
|
321 |
self.tile_overlap_factor = 0.25
|
322 |
self.use_tiling = False
|
323 |
|
|
|
375 |
if self.use_tiling and (
|
376 |
x.shape[-1] > self.tile_sample_min_size
|
377 |
or x.shape[-2] > self.tile_sample_min_size
|
378 |
+
or x.shape[-3] > self.tile_sample_min_size_t
|
379 |
):
|
380 |
+
return self.tiled_encode(x)
|
381 |
h = self.encoder(x)
|
382 |
moments = self.quant_conv(h)
|
383 |
posterior = DiagonalGaussianDistribution(moments)
|
|
|
387 |
if self.use_tiling and (
|
388 |
z.shape[-1] > self.tile_latent_min_size
|
389 |
or z.shape[-2] > self.tile_latent_min_size
|
390 |
+
or z.shape[-3] > self.tile_latent_min_size_t
|
391 |
):
|
392 |
+
return self.tiled_decode(z)
|
393 |
z = self.post_quant_conv(z)
|
394 |
dec = self.decoder(z)
|
395 |
return dec
|
|
|
557 |
) + b[:, :, :, :, x] * (x / blend_extent)
|
558 |
return b
|
559 |
|
560 |
+
def tiled_encode(self, x):
|
561 |
+
t = x.shape[2]
|
562 |
+
t_chunk_idx = [i for i in range(0, t, self.tile_sample_min_size_t-1)]
|
563 |
+
if len(t_chunk_idx) == 1 and t_chunk_idx[0] == 0:
|
564 |
+
t_chunk_start_end = [[0, t]]
|
565 |
+
else:
|
566 |
+
t_chunk_start_end = [[t_chunk_idx[i], t_chunk_idx[i+1]+1] for i in range(len(t_chunk_idx)-1)]
|
567 |
+
if t_chunk_start_end[-1][-1] > t:
|
568 |
+
t_chunk_start_end[-1][-1] = t
|
569 |
+
elif t_chunk_start_end[-1][-1] < t:
|
570 |
+
last_start_end = [t_chunk_idx[-1], t]
|
571 |
+
t_chunk_start_end.append(last_start_end)
|
572 |
+
moments = []
|
573 |
+
for idx, (start, end) in enumerate(t_chunk_start_end):
|
574 |
+
chunk_x = x[:, :, start: end]
|
575 |
+
if idx != 0:
|
576 |
+
moment = self.tiled_encode2d(chunk_x, return_moments=True)[:, :, 1:]
|
577 |
+
else:
|
578 |
+
moment = self.tiled_encode2d(chunk_x, return_moments=True)
|
579 |
+
moments.append(moment)
|
580 |
+
moments = torch.cat(moments, dim=2)
|
581 |
+
posterior = DiagonalGaussianDistribution(moments)
|
582 |
+
return posterior
|
583 |
+
|
584 |
+
def tiled_decode(self, x):
|
585 |
+
t = x.shape[2]
|
586 |
+
t_chunk_idx = [i for i in range(0, t, self.tile_latent_min_size_t-1)]
|
587 |
+
if len(t_chunk_idx) == 1 and t_chunk_idx[0] == 0:
|
588 |
+
t_chunk_start_end = [[0, t]]
|
589 |
+
else:
|
590 |
+
t_chunk_start_end = [[t_chunk_idx[i], t_chunk_idx[i+1]+1] for i in range(len(t_chunk_idx)-1)]
|
591 |
+
if t_chunk_start_end[-1][-1] > t:
|
592 |
+
t_chunk_start_end[-1][-1] = t
|
593 |
+
elif t_chunk_start_end[-1][-1] < t:
|
594 |
+
last_start_end = [t_chunk_idx[-1], t]
|
595 |
+
t_chunk_start_end.append(last_start_end)
|
596 |
+
dec_ = []
|
597 |
+
for idx, (start, end) in enumerate(t_chunk_start_end):
|
598 |
+
chunk_x = x[:, :, start: end]
|
599 |
+
if idx != 0:
|
600 |
+
dec = self.tiled_decode2d(chunk_x)[:, :, 1:]
|
601 |
+
else:
|
602 |
+
dec = self.tiled_decode2d(chunk_x)
|
603 |
+
dec_.append(dec)
|
604 |
+
dec_ = torch.cat(dec_, dim=2)
|
605 |
+
return dec_
|
606 |
+
|
607 |
+
def tiled_encode2d(self, x, return_moments=False):
|
608 |
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
|
609 |
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
|
610 |
row_limit = self.tile_latent_min_size - blend_extent
|
|
|
640 |
|
641 |
moments = torch.cat(result_rows, dim=3)
|
642 |
posterior = DiagonalGaussianDistribution(moments)
|
643 |
+
if return_moments:
|
644 |
+
return moments
|
645 |
return posterior
|
646 |
|
647 |
def tiled_decode2d(self, z):
|