ProgramerSalar commited on
Commit
c7532a7
·
1 Parent(s): 7ae5123
LICENSE CHANGED
@@ -1,6 +1,6 @@
1
  MIT License
2
 
3
- Copyright (c) 2024 Yang Jin
4
 
5
  Permission is hereby granted, free of charge, to any person obtaining a copy
6
  of this software and associated documentation files (the "Software"), to deal
 
1
  MIT License
2
 
3
+ Copyright (c)
4
 
5
  Permission is hereby granted, free of charge, to any person obtaining a copy
6
  of this software and associated documentation files (the "Software"), to deal
pyramid_dit/pyramid_dit_for_video_gen_pipeline.py CHANGED
@@ -541,6 +541,7 @@ class PyramidDiTForVideoGeneration:
541
  self._video_guidance_scale = guidance_scale_list[unit_index]
542
 
543
  if unit_index == 0:
 
544
  past_condition_latents = [[] for _ in range(len(stages))]
545
  intermed_latents = self.generate_one_unit(
546
  latents[:,:,:1],
@@ -557,15 +558,17 @@ class PyramidDiTForVideoGeneration:
557
  generator,
558
  is_first_frame=True,
559
  )
 
560
  else:
561
- # prepare the condition latents
562
- past_condition_latents = []
563
- # if not generated_latents_list or len(generated_latents_list) == 0:
564
- # raise ValueError("No latent tensors generated - check previous steps ")
565
 
566
- print(f"know the tensor of generated_latent_list: {generated_latents_list}")
567
- print(f"know the len(stages): {len(stages) - 1}")
568
 
 
 
 
569
  clean_latents_list = self.get_pyramid_latent(torch.cat(generated_latents_list, dim=2), len(stages) - 1)
570
 
571
  for i_s in range(len(stages)):
 
541
  self._video_guidance_scale = guidance_scale_list[unit_index]
542
 
543
  if unit_index == 0:
544
+ # Generate first frame
545
  past_condition_latents = [[] for _ in range(len(stages))]
546
  intermed_latents = self.generate_one_unit(
547
  latents[:,:,:1],
 
558
  generator,
559
  is_first_frame=True,
560
  )
561
+ generated_latents_list.append(intermed_latents[-1])
562
  else:
563
+
564
+ # check if we have previous frames to condition on
565
+ if not generated_latents_list or len(generated_latents_list) == 0:
566
+ raise ValueError("No latent tensors generated - check previous steps ")
567
 
 
 
568
 
569
+
570
+ # prepare the condition latents
571
+ past_condition_latents = []
572
  clean_latents_list = self.get_pyramid_latent(torch.cat(generated_latents_list, dim=2), len(stages) - 1)
573
 
574
  for i_s in range(len(stages)):