ProgramerSalar committed
Commit 05a9e89 · 1 Parent(s): 6143d6b
pyramid_dit/pyramid_dit_for_video_gen_pipeline.py CHANGED
@@ -533,7 +533,7 @@ class PyramidDiTForVideoGeneration:
         stages = self.stages
 
         generated_latents_list = []  # The generated results
-        last_generated_latents = None
+        # last_generated_latents = None
 
         for unit_index in tqdm(range(num_units)):
            if use_linear_guidance:
@@ -542,92 +542,114 @@ class PyramidDiTForVideoGeneration:
 
             if unit_index == 0:
                 # Generate first frame
-                past_condition_latents = [[] for _ in range(len(stages))]
-                intermed_latents = self.generate_one_unit(
-                    latents[:,:,:1],
-                    past_condition_latents,
-                    prompt_embeds,
-                    prompt_attention_mask,
-                    pooled_prompt_embeds,
-                    num_inference_steps,
-                    height,
-                    width,
-                    1,
-                    device,
-                    dtype,
-                    generator,
-                    is_first_frame=True,
-                )
-                # Ensure
-                if not intermed_latents or len(intermed_latents) == 0:
-                    raise ValueError("First frame generation failed")
-                generated_latents_list.append(intermed_latents[-1].clone())
+                try:
+                    past_condition_latents = [[] for _ in range(len(stages))]
+                    intermed_latents = self.generate_one_unit(
+                        latents[:,:,:1],
+                        past_condition_latents,
+                        prompt_embeds,
+                        prompt_attention_mask,
+                        pooled_prompt_embeds,
+                        num_inference_steps,
+                        height,
+                        width,
+                        1,
+                        device,
+                        dtype,
+                        generator,
+                        is_first_frame=True,
+                    )
+                    # Ensure the unit actually produced latents
+                    if not intermed_latents:
+                        raise ValueError("First frame generation failed")
+
+                    generated_latents_list.append(intermed_latents[-1].clone())
+                    print(f"Successfully generated first frame. Shape: {generated_latents_list[-1].shape}")
+
+                except Exception as e:
+                    print(f"First frame generation failed: {str(e)}")
+                    raise ValueError("Could not generate initial frame") from e
             else:
-
-                # Subsequent frames
-                if len(generated_latents_list) == 0:
-                    raise ValueError("No previous frames available for conditioning")
-
-                # Debug print
-                print(f"Conditioning on {len(generated_latents_list)} existing frames")
-
-                # Get pyramid latents from existing frames
-                try:
-                    clean_latents_list = self.get_pyramid_latent(
-                        torch.cat(generated_latents_list, dim=2),
-                        len(stages) - 1
-                    )
-                except Exception as e:
-                    print(f"Error concatenating latents: {e}")
-                    print(f"Shapes: {[x.shape for x in generated_latents_list]}")
-                    raise
-
-                for i_s in range(len(stages)):
-                    last_cond_latent = clean_latents_list[i_s][:,:,-(self.frame_per_unit):]
-
-                    stage_input = [torch.cat([last_cond_latent] * 2) if self.do_classifier_free_guidance else last_cond_latent]
-
-                    # pad the past clean latents
-                    cur_unit_num = unit_index
-                    cur_stage = i_s
-                    cur_unit_ptx = 1
-
-                    while cur_unit_ptx < cur_unit_num:
-                        cur_stage = max(cur_stage - 1, 0)
-                        if cur_stage == 0:
-                            break
-                        cur_unit_ptx += 1
-                        cond_latents = clean_latents_list[cur_stage][:, :, -(cur_unit_ptx * self.frame_per_unit) : -((cur_unit_ptx - 1) * self.frame_per_unit)]
-                        stage_input.append(torch.cat([cond_latents] * 2) if self.do_classifier_free_guidance else cond_latents)
-
-                    if cur_stage == 0 and cur_unit_ptx < cur_unit_num:
-                        cond_latents = clean_latents_list[0][:, :, :-(cur_unit_ptx * self.frame_per_unit)]
-                        stage_input.append(torch.cat([cond_latents] * 2) if self.do_classifier_free_guidance else cond_latents)
-
-                    stage_input = list(reversed(stage_input))
-                    past_condition_latents.append(stage_input)
-
-                intermed_latents = self.generate_one_unit(
-                    latents[:,:, 1 + (unit_index - 1) * self.frame_per_unit:1 + unit_index * self.frame_per_unit],
-                    past_condition_latents,
-                    prompt_embeds,
-                    prompt_attention_mask,
-                    pooled_prompt_embeds,
-                    video_num_inference_steps,
-                    height,
-                    width,
-                    self.frame_per_unit,
-                    device,
-                    dtype,
-                    generator,
-                    is_first_frame=False,
-                )
-
-                generated_latents_list.append(intermed_latents[-1])
-                last_generated_latents = intermed_latents
+                if not generated_latents_list:
+                    raise ValueError("No previous frames available for conditioning (this should never happen)")
+
+                try:
+                    # Prepare conditioning from existing frames
+                    concatenated_latents = torch.cat(generated_latents_list, dim=2)
+                    print(f"Conditioning on {len(generated_latents_list)} frames. Concatenated shape: {concatenated_latents.shape}")
+
+                    clean_latents_list = self.get_pyramid_latent(concatenated_latents, len(stages) - 1)
+
+                    # Prepare past conditions
+                    past_condition_latents = []
+
+                    for i_s in range(len(stages)):
+                        last_cond_latent = clean_latents_list[i_s][:,:,-(self.frame_per_unit):]
+
+                        stage_input = [torch.cat([last_cond_latent] * 2) if self.do_classifier_free_guidance else last_cond_latent]
+
+                        # pad the past clean latents
+                        cur_unit_num = unit_index
+                        cur_stage = i_s
+                        cur_unit_ptx = 1
+
+                        while cur_unit_ptx < cur_unit_num:
+                            cur_stage = max(cur_stage - 1, 0)
+                            if cur_stage == 0:
+                                break
+                            cur_unit_ptx += 1
+                            cond_latents = clean_latents_list[cur_stage][:, :, -(cur_unit_ptx * self.frame_per_unit) : -((cur_unit_ptx - 1) * self.frame_per_unit)]
+                            stage_input.append(torch.cat([cond_latents] * 2) if self.do_classifier_free_guidance else cond_latents)
+
+                        if cur_stage == 0 and cur_unit_ptx < cur_unit_num:
+                            cond_latents = clean_latents_list[0][:, :, :-(cur_unit_ptx * self.frame_per_unit)]
+                            stage_input.append(torch.cat([cond_latents] * 2) if self.do_classifier_free_guidance else cond_latents)
+
+                        stage_input = list(reversed(stage_input))
+                        past_condition_latents.append(stage_input)
+
+                    # Generate current frame unit
+                    frame_slice = slice(
+                        1 + (unit_index - 1) * self.frame_per_unit,
+                        1 + unit_index * self.frame_per_unit
+                    )
+
+                    intermed_latents = self.generate_one_unit(
+                        latents[:,:, frame_slice],
+                        past_condition_latents,
+                        prompt_embeds,
+                        prompt_attention_mask,
+                        pooled_prompt_embeds,
+                        video_num_inference_steps,
+                        height,
+                        width,
+                        self.frame_per_unit,
+                        device,
+                        dtype,
+                        generator,
+                        is_first_frame=False,
+                    )
+
+                    if not intermed_latents:
+                        raise ValueError(f"Frame generation failed for unit {unit_index}")
+
+                    generated_latents_list.append(intermed_latents[-1].clone())
+                    print(f"Successfully generated frame unit {unit_index}. Shape: {generated_latents_list[-1].shape}")
+                    # last_generated_latents = intermed_latents
+
+                except Exception as e:
+                    print(f"Frame generation failed for unit {unit_index}: {str(e)}")
+                    print(f"Current frames: {len(generated_latents_list)}")
+                    raise ValueError(f"Could not generate frame unit {unit_index}") from e
+
+        # Final processing
+        if not generated_latents_list:
+            raise ValueError("No frames were generated")
 
         generated_latents = torch.cat(generated_latents_list, dim=2)
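The densest part of the change is the per-stage conditioning walk (the `while cur_unit_ptx < cur_unit_num` loop). As a reading aid, here is a minimal standalone sketch of the same window arithmetic; `build_stage_input` is a hypothetical helper name, and the stage count, `frame_per_unit`, and tensor shapes below are assumptions for illustration, not part of the commit:

import torch

# Hypothetical standalone version of the conditioning walk in the diff above;
# names and shapes here are illustrative assumptions, not the committed API.
def build_stage_input(clean_latents_list, unit_index, i_s, frame_per_unit, cfg):
    # Most recent unit, taken at the current stage's resolution.
    last_cond = clean_latents_list[i_s][:, :, -frame_per_unit:]
    stage_input = [torch.cat([last_cond] * 2) if cfg else last_cond]

    # Step back one unit at a time, dropping to a lower stage at each step.
    cur_stage, cur_unit_ptx = i_s, 1
    while cur_unit_ptx < unit_index:
        cur_stage = max(cur_stage - 1, 0)
        if cur_stage == 0:
            break
        cur_unit_ptx += 1
        cond = clean_latents_list[cur_stage][:, :, -(cur_unit_ptx * frame_per_unit):-((cur_unit_ptx - 1) * frame_per_unit)]
        stage_input.append(torch.cat([cond] * 2) if cfg else cond)

    # Everything older than the walk is conditioned at stage 0.
    if cur_stage == 0 and cur_unit_ptx < unit_index:
        cond = clean_latents_list[0][:, :, :-(cur_unit_ptx * frame_per_unit)]
        stage_input.append(torch.cat([cond] * 2) if cfg else cond)

    return list(reversed(stage_input))  # oldest conditioning first

# Quick shape check: 3 stages, 3 units of 2 frames after the first frame.
frame_per_unit = 2
frames = 1 + 3 * frame_per_unit
clean = [torch.randn(1, 4, frames, 8 * 2 ** s, 8 * 2 ** s) for s in range(3)]
for i_s in range(3):
    windows = build_stage_input(clean, unit_index=3, i_s=i_s, frame_per_unit=frame_per_unit, cfg=False)
    print(i_s, [w.shape[2] for w in windows])  # frame counts per window

The printout shows each stage keeping only the most recent unit at its own resolution, with older units supplied from progressively lower stages, matching the loop in the diff.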
 
 
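One smaller readability change worth noting: the inline slicing expression in the old call to `generate_one_unit` is now built as an explicit `frame_slice`. The arithmetic is easy to sanity-check in isolation (the `frame_per_unit = 2` value here is an assumption for illustration):

frame_per_unit = 2  # assumed for illustration
for unit_index in range(1, 4):
    s = slice(1 + (unit_index - 1) * frame_per_unit, 1 + unit_index * frame_per_unit)
    print(unit_index, (s.start, s.stop))
# unit 1 -> (1, 3), unit 2 -> (3, 5), unit 3 -> (5, 7)

Each unit picks up exactly where the previous one stopped, offset by the single first frame.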