Spaces · Running on Zero
Commit · 05a9e89
1 Parent(s): 6143d6b
'testing'
pyramid_dit/pyramid_dit_for_video_gen_pipeline.py
CHANGED
@@ -533,7 +533,7 @@ class PyramidDiTForVideoGeneration:
         stages = self.stages

         generated_latents_list = []    # The generated results
-        last_generated_latents = None
+        # last_generated_latents = None

         for unit_index in tqdm(range(num_units)):
             if use_linear_guidance:
@@ -542,92 +542,114 @@ class PyramidDiTForVideoGeneration:

             if unit_index == 0:
                 # Generate first frame
-                ...
                 generated_latents_list.append(intermed_latents[-1].clone())
             else:

-                ...


-                ...

-                    ...
-                        len(stages) - 1
                     )
-                except Exception as e:
-                    print(f"Error concatenating latents: {e}")
-                    print(f"Shapes: {[x.shape for x in generated_latents_list]}")
-                    raise
-
-                for i_s in range(len(stages)):
-                    last_cond_latent = clean_latents_list[i_s][:,:,-(self.frame_per_unit):]

-                    ...
-                    if cur_stage == 0 and cur_unit_ptx < cur_unit_num:
-                        cond_latents = clean_latents_list[0][:, :, :-(cur_unit_ptx * self.frame_per_unit)]
-                        stage_input.append(torch.cat([cond_latents] * 2) if self.do_classifier_free_guidance else cond_latents)
-
-                    stage_input = list(reversed(stage_input))
-                    past_condition_latents.append(stage_input)
-
-                intermed_latents = self.generate_one_unit(
-                    latents[:,:, 1 + (unit_index - 1) * self.frame_per_unit:1 + unit_index * self.frame_per_unit],
-                    past_condition_latents,
-                    prompt_embeds,
-                    prompt_attention_mask,
-                    pooled_prompt_embeds,
-                    video_num_inference_steps,
-                    height,
-                    width,
-                    self.frame_per_unit,
-                    device,
-                    dtype,
-                    generator,
-                    is_first_frame=False,
-                )

-
-

         generated_latents = torch.cat(generated_latents_list, dim=2)


             if unit_index == 0:
                 # Generate first frame
+                try:
+
+                    past_condition_latents = [[] for _ in range(len(stages))]
+                    intermed_latents = self.generate_one_unit(
+                        latents[:,:,:1],
+                        past_condition_latents,
+                        prompt_embeds,
+                        prompt_attention_mask,
+                        pooled_prompt_embeds,
+                        num_inference_steps,
+                        height,
+                        width,
+                        1,
+                        device,
+                        dtype,
+                        generator,
+                        is_first_frame=True,
+                    )
+                    # Ensure the first frame unit was actually produced
+                    if not intermed_latents:
+                        raise ValueError("First frame generation failed")
+
                     generated_latents_list.append(intermed_latents[-1].clone())
+                    print(f"Successfully generated first frame. Shape: {generated_latents_list[-1].shape}")
+
+                except Exception as e:
+                    print(f"First frame generation failed: {str(e)}")
+                    raise ValueError("Could not generate initial frame") from e
             else:

+                if not generated_latents_list:
+                    raise ValueError("No previous frames available for conditioning (this should never happen)")

+                try:
+
+                    # Prepare conditioning from the frames generated so far
+                    concatenated_latents = torch.cat(generated_latents_list, dim=2)
+                    print(f"Conditioning on {len(generated_latents_list)} frame unit(s). Concatenated shape: {concatenated_latents.shape}")
+
+                    clean_latents_list = self.get_pyramid_latent(concatenated_latents, len(stages) - 1)
+
+                    # Prepare past conditions
+                    past_condition_latents = []
+
+
+                    for i_s in range(len(stages)):
+                        last_cond_latent = clean_latents_list[i_s][:,:,-(self.frame_per_unit):]
+
+                        stage_input = [torch.cat([last_cond_latent] * 2) if self.do_classifier_free_guidance else last_cond_latent]

+                        # Pad the past clean latents
+                        cur_unit_num = unit_index
+                        cur_stage = i_s
+                        cur_unit_ptx = 1
+
+                        while cur_unit_ptx < cur_unit_num:
+                            cur_stage = max(cur_stage - 1, 0)
+                            if cur_stage == 0:
+                                break
+                            cur_unit_ptx += 1
+                            cond_latents = clean_latents_list[cur_stage][:, :, -(cur_unit_ptx * self.frame_per_unit) : -((cur_unit_ptx - 1) * self.frame_per_unit)]
+                            stage_input.append(torch.cat([cond_latents] * 2) if self.do_classifier_free_guidance else cond_latents)
+
+                        if cur_stage == 0 and cur_unit_ptx < cur_unit_num:
+                            cond_latents = clean_latents_list[0][:, :, :-(cur_unit_ptx * self.frame_per_unit)]
+                            stage_input.append(torch.cat([cond_latents] * 2) if self.do_classifier_free_guidance else cond_latents)

+                        stage_input = list(reversed(stage_input))
+                        past_condition_latents.append(stage_input)
+
+                    # Generate the current frame unit
+                    frame_slice = slice(
+                        1 + (unit_index - 1) * self.frame_per_unit,
+                        1 + unit_index * self.frame_per_unit
                     )

+                    intermed_latents = self.generate_one_unit(
+                        latents[:,:, frame_slice],
+                        past_condition_latents,
+                        prompt_embeds,
+                        prompt_attention_mask,
+                        pooled_prompt_embeds,
+                        video_num_inference_steps,
+                        height,
+                        width,
+                        self.frame_per_unit,
+                        device,
+                        dtype,
+                        generator,
+                        is_first_frame=False,
+                    )

+                    if not intermed_latents:
+                        raise ValueError(f"Frame generation failed for unit {unit_index}")
+
+
+                    generated_latents_list.append(intermed_latents[-1].clone())
+                    print(f"Successfully generated frame unit {unit_index}. Shape: {generated_latents_list[-1].shape}")
+                    # last_generated_latents = intermed_latents
+
+                except Exception as e:
+                    print(f"Frame generation failed for unit {unit_index}: {str(e)}")
+                    print(f"Current frames: {len(generated_latents_list)}")
+                    raise ValueError(f"Could not generate frame unit {unit_index}") from e
+
+        # Final processing: ensure at least one frame unit was generated
+        if not generated_latents_list:
+            raise ValueError("No frames were generated")

         generated_latents = torch.cat(generated_latents_list, dim=2)

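For reference, a minimal sketch (not part of the commit) of the frame-unit bookkeeping that both versions of this loop rely on. It uses dummy tensors in place of generate_one_unit; the sizes B, C, H, W, frame_per_unit and num_units below are illustrative assumptions, not values taken from the pipeline.

import torch

# Illustrative sizes (assumed): (batch, channels, frames, height, width)
B, C, H, W = 1, 4, 8, 8
frame_per_unit = 2
num_units = 3

# Frame layout used by the loop: one first frame, then frame_per_unit frames per later unit
total_frames = 1 + (num_units - 1) * frame_per_unit
latents = torch.randn(B, C, total_frames, H, W)

generated_latents_list = []
for unit_index in range(num_units):
    if unit_index == 0:
        # The first unit covers only latents[:, :, :1]
        unit = latents[:, :, :1]
    else:
        # The same indexing the commit factors into frame_slice = slice(...)
        frame_slice = slice(1 + (unit_index - 1) * frame_per_unit,
                            1 + unit_index * frame_per_unit)
        unit = latents[:, :, frame_slice]
    generated_latents_list.append(unit.clone())

# Concatenate along the frame (time) axis, as the pipeline does after the loop
generated_latents = torch.cat(generated_latents_list, dim=2)
print(generated_latents.shape)  # torch.Size([1, 4, 5, 8, 8]); 5 == 1 + (3 - 1) * 2

Consecutive units cover contiguous, non-overlapping frame ranges, so concatenating generated_latents_list along dim=2 at the end reconstructs the full latent sequence.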