ProgramerSalar commited on
Commit
6143d6b
·
1 Parent(s): c7532a7
pyramid_dit/pyramid_dit_for_video_gen_pipeline.py CHANGED
@@ -558,18 +558,32 @@ class PyramidDiTForVideoGeneration:
558
  generator,
559
  is_first_frame=True,
560
  )
561
- generated_latents_list.append(intermed_latents[-1])
 
 
 
562
  else:
 
563
 
564
- # check if we have previous frames to condition on
565
- if not generated_latents_list or len(generated_latents_list) == 0:
566
- raise ValueError("No latent tensors generated - check previous steps ")
567
 
568
 
 
 
 
 
 
 
569
 
570
- # prepare the condition latents
571
- past_condition_latents = []
572
- clean_latents_list = self.get_pyramid_latent(torch.cat(generated_latents_list, dim=2), len(stages) - 1)
 
 
 
 
 
 
 
573
 
574
  for i_s in range(len(stages)):
575
  last_cond_latent = clean_latents_list[i_s][:,:,-(self.frame_per_unit):]
 
558
  generator,
559
  is_first_frame=True,
560
  )
561
+ # Ensure
562
+ if not intermed_latents or len(intermed_latents) == 0:
563
+ raise ValueError("First frame generation failed")
564
+ generated_latents_list.append(intermed_latents[-1].clone())
565
  else:
566
+
567
 
 
 
 
568
 
569
 
570
+ # Subsequent frames
571
+ if len(generated_latents_list) == 0:
572
+ raise ValueError("No previous frames available for conditioning")
573
+
574
+ # Debug print
575
+ print(f"Conditioning on {len(generated_latents_list)} existing frames")
576
 
577
+ # Get pyramid latents from existing frames
578
+ try:
579
+ clean_latents_list = self.get_pyramid_latent(
580
+ torch.cat(generated_latents_list, dim=2),
581
+ len(stages) - 1
582
+ )
583
+ except Exception as e:
584
+ print(f"Error concatenating latents: {e}")
585
+ print(f"Shapes: {[x.shape for x in generated_latents_list]}")
586
+ raise
587
 
588
  for i_s in range(len(stages)):
589
  last_cond_latent = clean_latents_list[i_s][:,:,-(self.frame_per_unit):]