Antoni Bigata commited on
Commit
4ef25d2
·
1 Parent(s): 4fd1a69

requirements

Browse files
Files changed (1) hide show
  1. app.py +227 -231
app.py CHANGED
@@ -614,257 +614,253 @@ def process_video(video_input, audio_input, max_num_seconds):
614
  audio_input = DEFAULT_AUDIO_PATH
615
  print(f"Using default audio: {DEFAULT_AUDIO_PATH}")
616
 
617
- try:
618
- # Calculate hashes for cache keys
619
- video_path_hash = video_input
620
- audio_path_hash = audio_input
621
-
622
- # Check if we need to recompute video embeddings
623
- video_cache_hit = cache["video"]["path"] == video_path_hash
624
- audio_cache_hit = cache["audio"]["path"] == audio_path_hash
625
-
626
- if video_cache_hit and audio_cache_hit:
627
- print("Using cached video and audio computations")
628
- # Make copies of cached data to avoid modifying cache
629
- video_embedding = cache["video"]["embedding"].clone()
630
- video_frames = cache["video"]["frames"].clone()
631
- video_landmarks = cache["video"]["landmarks"].copy()
632
- raw_audio = cache["audio"]["raw_audio"].clone()
633
- raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
634
- hubert_embedding = cache["audio"]["hubert_embedding"].clone()
635
- wavlm_embedding = cache["audio"]["wavlm_embedding"].clone()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
636
 
637
- # Ensure all data is truncated to the same length if needed
638
- min_len = min(
639
- len(video_frames),
640
- len(raw_audio),
641
- len(hubert_embedding),
642
- len(wavlm_embedding),
 
 
 
 
 
 
 
643
  )
644
- video_frames = video_frames[:min_len]
 
 
 
 
 
 
 
 
645
  video_embedding = video_embedding[:min_len]
 
646
  video_landmarks = video_landmarks[:min_len]
647
- raw_audio = raw_audio[:min_len]
648
- hubert_embedding = hubert_embedding[:min_len]
649
- wavlm_embedding = wavlm_embedding[:min_len]
650
- raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
651
 
652
  else:
653
- # Process video if needed
654
- if not video_cache_hit:
655
- print("Computing video embeddings and landmarks")
656
- video_reader = decord.VideoReader(video_input)
657
- decord.bridge.set_bridge("torch")
658
-
659
- if not audio_cache_hit:
660
- # Need to process audio to determine min_len
661
- raw_audio = get_raw_audio(audio_input, 16000)
662
- if len(raw_audio) == 0 or len(video_reader) == 0:
663
- raise ValueError("Empty audio or video input")
664
-
665
- min_len = min(len(raw_audio), len(video_reader))
666
-
667
- # Store full audio in cache
668
- cache["audio"]["path"] = audio_path_hash
669
- cache["audio"]["raw_audio"] = raw_audio.clone()
670
-
671
- # Create truncated copy for processing
672
- raw_audio = raw_audio[:min_len]
673
- raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
674
- else:
675
- # Use cached audio - make a copy
676
- if cache["audio"]["raw_audio"] is None:
677
- raise ValueError("Cached audio is None")
678
-
679
- raw_audio = cache["audio"]["raw_audio"].clone()
680
- if len(raw_audio) == 0 or len(video_reader) == 0:
681
- raise ValueError("Empty cached audio or video input")
682
-
683
- min_len = min(len(raw_audio), len(video_reader))
684
-
685
- # Create truncated copy for processing
686
- raw_audio = raw_audio[:min_len]
687
- raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
688
-
689
- # Compute video embeddings and landmarks - store full version in cache
690
- video_embedding, video_frames = compute_video_embedding(
691
- video_reader, len(video_reader)
692
- )
693
- video_landmarks = extract_video_landmarks(video_frames)
694
 
695
- # Update video cache with full versions
696
- cache["video"]["path"] = video_path_hash
697
- cache["video"]["embedding"] = video_embedding
698
- cache["video"]["frames"] = video_frames
699
- cache["video"]["landmarks"] = video_landmarks
 
700
 
701
- # Create truncated copies for processing
702
- video_embedding = video_embedding[:min_len]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
703
  video_frames = video_frames[:min_len]
 
704
  video_landmarks = video_landmarks[:min_len]
705
-
706
  else:
707
- # Use cached video data - make copies
708
- print("Using cached video computations")
709
-
710
- if (
711
- cache["video"]["embedding"] is None
712
- or cache["video"]["frames"] is None
713
- or cache["video"]["landmarks"] is None
714
- ):
715
- raise ValueError("One or more video cache entries are None")
716
-
717
- if not audio_cache_hit:
718
- # New audio with cached video
719
- raw_audio = get_raw_audio(audio_input, 16000)
720
- if len(raw_audio) == 0:
721
- raise ValueError("Empty audio input")
722
-
723
- # Store full audio in cache
724
- cache["audio"]["path"] = audio_path_hash
725
- cache["audio"]["raw_audio"] = raw_audio.clone()
726
-
727
- # Make copies of video data
728
- video_embedding = cache["video"]["embedding"].clone()
729
- video_frames = cache["video"]["frames"].clone()
730
- video_landmarks = cache["video"]["landmarks"].copy()
731
-
732
- # Determine truncation length and create truncated copies
733
- min_len = min(len(raw_audio), len(video_frames))
734
- raw_audio = raw_audio[:min_len]
735
- raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
736
- video_frames = video_frames[:min_len]
737
- video_embedding = video_embedding[:min_len]
738
- video_landmarks = video_landmarks[:min_len]
739
- else:
740
- # Both video and audio are cached - should not reach here
741
- # as it's handled in the first if statement
742
- pass
743
-
744
- # Process audio if needed
745
- if not audio_cache_hit:
746
- print("Computing audio embeddings")
747
-
748
- # Compute audio embeddings with the truncated audio
749
- hubert_embedding = compute_hubert_embedding(raw_audio_reshape)
750
- wavlm_embedding = compute_wavlm_embedding(raw_audio_reshape)
751
 
752
- # Update audio cache with full embeddings
753
- # Note: raw_audio was already cached above
754
- cache["audio"]["hubert_embedding"] = hubert_embedding.clone()
755
- cache["audio"]["wavlm_embedding"] = wavlm_embedding.clone()
756
- else:
757
- # Use cached audio data - make copies
758
- if (
759
- cache["audio"]["hubert_embedding"] is None
760
- or cache["audio"]["wavlm_embedding"] is None
761
- ):
762
- raise ValueError(
763
- "One or more audio embedding cache entries are None"
764
- )
765
 
766
- hubert_embedding = cache["audio"]["hubert_embedding"].clone()
767
- wavlm_embedding = cache["audio"]["wavlm_embedding"].clone()
768
-
769
- # Make sure embeddings match the truncated video length if needed
770
- if "min_len" in locals() and (
771
- min_len < len(hubert_embedding) or min_len < len(wavlm_embedding)
772
- ):
773
- hubert_embedding = hubert_embedding[:min_len]
774
- wavlm_embedding = wavlm_embedding[:min_len]
775
-
776
- # Apply max_num_seconds limit if specified
777
- if max_num_seconds > 0:
778
- # Convert seconds to frames (assuming 25 fps)
779
- max_frames = int(max_num_seconds * 25)
780
-
781
- # Truncate all data to max_frames
782
- video_embedding = video_embedding[:max_frames]
783
- video_frames = video_frames[:max_frames]
784
- video_landmarks = video_landmarks[:max_frames]
785
- hubert_embedding = hubert_embedding[:max_frames]
786
- wavlm_embedding = wavlm_embedding[:max_frames]
787
- raw_audio = raw_audio[:max_frames]
788
- raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
789
-
790
- # Validate shapes before proceeding
791
- assert video_embedding.shape[0] == hubert_embedding.shape[0], (
792
- f"Video embedding length ({video_embedding.shape[0]}) doesn't match Hubert embedding length ({hubert_embedding.shape[0]})"
793
- )
794
- assert video_embedding.shape[0] == wavlm_embedding.shape[0], (
795
- f"Video embedding length ({video_embedding.shape[0]}) doesn't match WavLM embedding length ({wavlm_embedding.shape[0]})"
796
- )
797
- assert video_embedding.shape[0] == video_landmarks.shape[0], (
798
- f"Video embedding length ({video_embedding.shape[0]}) doesn't match landmarks length ({video_landmarks.shape[0]})"
799
- )
800
 
801
- print(f"Hubert embedding shape: {hubert_embedding.shape}")
802
- print(f"WavLM embedding shape: {wavlm_embedding.shape}")
803
- print(f"Video embedding shape: {video_embedding.shape}")
804
- print(f"Video landmarks shape: {video_landmarks.shape}")
805
-
806
- # Create pipeline inputs for models
807
- (
808
- interpolation_chunks,
809
- keyframe_chunks,
810
- audio_interpolation_chunks,
811
- audio_keyframe_chunks,
812
- emb_cond,
813
- masks_keyframe_chunks,
814
- masks_interpolation_chunks,
815
- to_remove,
816
- audio_interpolation_idx,
817
- audio_keyframe_idx,
818
- ) = create_pipeline_inputs(
819
- hubert_embedding,
820
- wavlm_embedding,
821
- 14,
822
- video_embedding,
823
- video_landmarks,
824
- overlap=1,
825
- add_zero_flag=True,
826
- mask_arms=None,
827
- nose_index=28,
828
- )
829
 
830
- complete_video = sample(
831
- audio_keyframe_chunks,
832
- keyframe_chunks,
833
- masks_keyframe_chunks,
834
- to_remove,
835
- audio_keyframe_idx,
836
- 14,
837
- "cuda",
838
- emb_cond,
839
- [],
840
- 3,
841
- 3,
842
- audio_interpolation_idx,
843
- audio_interpolation_chunks,
844
- masks_interpolation_chunks,
845
- interpolation_chunks,
846
- keyframe_model,
847
- interpolation_model,
848
- )
849
 
850
- complete_audio = rearrange(
851
- raw_audio[: complete_video.shape[0]], "f s -> () (f s)"
852
- )
853
 
854
- # 4. Convert frames to video and combine with audio
855
- with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_video:
856
- output_path = temp_video.name
857
 
858
- print("Saving video to", output_path)
859
 
860
- save_audio_video(complete_video, audio=complete_audio, save_path=output_path)
861
- torch.cuda.empty_cache()
862
- return output_path
863
 
864
- except Exception as e:
865
- raise e
866
- print(f"Error processing video: {str(e)}")
867
- return None
868
 
869
 
870
  def get_max_duration(video_input, audio_input):
 
614
  audio_input = DEFAULT_AUDIO_PATH
615
  print(f"Using default audio: {DEFAULT_AUDIO_PATH}")
616
 
617
+ # try:
618
+ # Calculate hashes for cache keys
619
+ video_path_hash = video_input
620
+ audio_path_hash = audio_input
621
+
622
+ # Check if we need to recompute video embeddings
623
+ video_cache_hit = cache["video"]["path"] == video_path_hash
624
+ audio_cache_hit = cache["audio"]["path"] == audio_path_hash
625
+
626
+ if video_cache_hit and audio_cache_hit:
627
+ print("Using cached video and audio computations")
628
+ # Make copies of cached data to avoid modifying cache
629
+ video_embedding = cache["video"]["embedding"].clone()
630
+ video_frames = cache["video"]["frames"].clone()
631
+ video_landmarks = cache["video"]["landmarks"].copy()
632
+ raw_audio = cache["audio"]["raw_audio"].clone()
633
+ raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
634
+ hubert_embedding = cache["audio"]["hubert_embedding"].clone()
635
+ wavlm_embedding = cache["audio"]["wavlm_embedding"].clone()
636
+
637
+ # Ensure all data is truncated to the same length if needed
638
+ min_len = min(
639
+ len(video_frames),
640
+ len(raw_audio),
641
+ len(hubert_embedding),
642
+ len(wavlm_embedding),
643
+ )
644
+ video_frames = video_frames[:min_len]
645
+ video_embedding = video_embedding[:min_len]
646
+ video_landmarks = video_landmarks[:min_len]
647
+ raw_audio = raw_audio[:min_len]
648
+ hubert_embedding = hubert_embedding[:min_len]
649
+ wavlm_embedding = wavlm_embedding[:min_len]
650
+ raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
651
+
652
+ else:
653
+ # Process video if needed
654
+ if not video_cache_hit:
655
+ print("Computing video embeddings and landmarks")
656
+ video_reader = decord.VideoReader(video_input)
657
+ decord.bridge.set_bridge("torch")
658
+
659
+ if not audio_cache_hit:
660
+ # Need to process audio to determine min_len
661
+ raw_audio = get_raw_audio(audio_input, 16000)
662
+ if len(raw_audio) == 0 or len(video_reader) == 0:
663
+ raise ValueError("Empty audio or video input")
664
+
665
+ min_len = min(len(raw_audio), len(video_reader))
666
+
667
+ # Store full audio in cache
668
+ cache["audio"]["path"] = audio_path_hash
669
+ cache["audio"]["raw_audio"] = raw_audio.clone()
670
+
671
+ # Create truncated copy for processing
672
+ raw_audio = raw_audio[:min_len]
673
+ raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
674
+ else:
675
+ # Use cached audio - make a copy
676
+ if cache["audio"]["raw_audio"] is None:
677
+ raise ValueError("Cached audio is None")
678
 
679
+ raw_audio = cache["audio"]["raw_audio"].clone()
680
+ if len(raw_audio) == 0 or len(video_reader) == 0:
681
+ raise ValueError("Empty cached audio or video input")
682
+
683
+ min_len = min(len(raw_audio), len(video_reader))
684
+
685
+ # Create truncated copy for processing
686
+ raw_audio = raw_audio[:min_len]
687
+ raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
688
+
689
+ # Compute video embeddings and landmarks - store full version in cache
690
+ video_embedding, video_frames = compute_video_embedding(
691
+ video_reader, len(video_reader)
692
  )
693
+ video_landmarks = extract_video_landmarks(video_frames)
694
+
695
+ # Update video cache with full versions
696
+ cache["video"]["path"] = video_path_hash
697
+ cache["video"]["embedding"] = video_embedding
698
+ cache["video"]["frames"] = video_frames
699
+ cache["video"]["landmarks"] = video_landmarks
700
+
701
+ # Create truncated copies for processing
702
  video_embedding = video_embedding[:min_len]
703
+ video_frames = video_frames[:min_len]
704
  video_landmarks = video_landmarks[:min_len]
 
 
 
 
705
 
706
  else:
707
+ # Use cached video data - make copies
708
+ print("Using cached video computations")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
709
 
710
+ if (
711
+ cache["video"]["embedding"] is None
712
+ or cache["video"]["frames"] is None
713
+ or cache["video"]["landmarks"] is None
714
+ ):
715
+ raise ValueError("One or more video cache entries are None")
716
 
717
+ if not audio_cache_hit:
718
+ # New audio with cached video
719
+ raw_audio = get_raw_audio(audio_input, 16000)
720
+ if len(raw_audio) == 0:
721
+ raise ValueError("Empty audio input")
722
+
723
+ # Store full audio in cache
724
+ cache["audio"]["path"] = audio_path_hash
725
+ cache["audio"]["raw_audio"] = raw_audio.clone()
726
+
727
+ # Make copies of video data
728
+ video_embedding = cache["video"]["embedding"].clone()
729
+ video_frames = cache["video"]["frames"].clone()
730
+ video_landmarks = cache["video"]["landmarks"].copy()
731
+
732
+ # Determine truncation length and create truncated copies
733
+ min_len = min(len(raw_audio), len(video_frames))
734
+ raw_audio = raw_audio[:min_len]
735
+ raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
736
  video_frames = video_frames[:min_len]
737
+ video_embedding = video_embedding[:min_len]
738
  video_landmarks = video_landmarks[:min_len]
 
739
  else:
740
+ # Both video and audio are cached - should not reach here
741
+ # as it's handled in the first if statement
742
+ pass
743
+
744
+ # Process audio if needed
745
+ if not audio_cache_hit:
746
+ print("Computing audio embeddings")
747
+
748
+ # Compute audio embeddings with the truncated audio
749
+ hubert_embedding = compute_hubert_embedding(raw_audio_reshape)
750
+ wavlm_embedding = compute_wavlm_embedding(raw_audio_reshape)
751
+
752
+ # Update audio cache with full embeddings
753
+ # Note: raw_audio was already cached above
754
+ cache["audio"]["hubert_embedding"] = hubert_embedding.clone()
755
+ cache["audio"]["wavlm_embedding"] = wavlm_embedding.clone()
756
+ else:
757
+ # Use cached audio data - make copies
758
+ if (
759
+ cache["audio"]["hubert_embedding"] is None
760
+ or cache["audio"]["wavlm_embedding"] is None
761
+ ):
762
+ raise ValueError("One or more audio embedding cache entries are None")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
763
 
764
+ hubert_embedding = cache["audio"]["hubert_embedding"].clone()
765
+ wavlm_embedding = cache["audio"]["wavlm_embedding"].clone()
 
 
 
 
 
 
 
 
 
 
 
766
 
767
+ # Make sure embeddings match the truncated video length if needed
768
+ if "min_len" in locals() and (
769
+ min_len < len(hubert_embedding) or min_len < len(wavlm_embedding)
770
+ ):
771
+ hubert_embedding = hubert_embedding[:min_len]
772
+ wavlm_embedding = wavlm_embedding[:min_len]
773
+
774
+ # Apply max_num_seconds limit if specified
775
+ if max_num_seconds > 0:
776
+ # Convert seconds to frames (assuming 25 fps)
777
+ max_frames = int(max_num_seconds * 25)
778
+
779
+ # Truncate all data to max_frames
780
+ video_embedding = video_embedding[:max_frames]
781
+ video_frames = video_frames[:max_frames]
782
+ video_landmarks = video_landmarks[:max_frames]
783
+ hubert_embedding = hubert_embedding[:max_frames]
784
+ wavlm_embedding = wavlm_embedding[:max_frames]
785
+ raw_audio = raw_audio[:max_frames]
786
+ raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
787
+
788
+ # Validate shapes before proceeding
789
+ assert video_embedding.shape[0] == hubert_embedding.shape[0], (
790
+ f"Video embedding length ({video_embedding.shape[0]}) doesn't match Hubert embedding length ({hubert_embedding.shape[0]})"
791
+ )
792
+ assert video_embedding.shape[0] == wavlm_embedding.shape[0], (
793
+ f"Video embedding length ({video_embedding.shape[0]}) doesn't match WavLM embedding length ({wavlm_embedding.shape[0]})"
794
+ )
795
+ assert video_embedding.shape[0] == video_landmarks.shape[0], (
796
+ f"Video embedding length ({video_embedding.shape[0]}) doesn't match landmarks length ({video_landmarks.shape[0]})"
797
+ )
 
 
 
798
 
799
+ print(f"Hubert embedding shape: {hubert_embedding.shape}")
800
+ print(f"WavLM embedding shape: {wavlm_embedding.shape}")
801
+ print(f"Video embedding shape: {video_embedding.shape}")
802
+ print(f"Video landmarks shape: {video_landmarks.shape}")
803
+
804
+ # Create pipeline inputs for models
805
+ (
806
+ interpolation_chunks,
807
+ keyframe_chunks,
808
+ audio_interpolation_chunks,
809
+ audio_keyframe_chunks,
810
+ emb_cond,
811
+ masks_keyframe_chunks,
812
+ masks_interpolation_chunks,
813
+ to_remove,
814
+ audio_interpolation_idx,
815
+ audio_keyframe_idx,
816
+ ) = create_pipeline_inputs(
817
+ hubert_embedding,
818
+ wavlm_embedding,
819
+ 14,
820
+ video_embedding,
821
+ video_landmarks,
822
+ overlap=1,
823
+ add_zero_flag=True,
824
+ mask_arms=None,
825
+ nose_index=28,
826
+ )
827
 
828
+ complete_video = sample(
829
+ audio_keyframe_chunks,
830
+ keyframe_chunks,
831
+ masks_keyframe_chunks,
832
+ to_remove,
833
+ audio_keyframe_idx,
834
+ 14,
835
+ "cuda",
836
+ emb_cond,
837
+ [],
838
+ 3,
839
+ 3,
840
+ audio_interpolation_idx,
841
+ audio_interpolation_chunks,
842
+ masks_interpolation_chunks,
843
+ interpolation_chunks,
844
+ keyframe_model,
845
+ interpolation_model,
846
+ )
847
 
848
+ complete_audio = rearrange(raw_audio[: complete_video.shape[0]], "f s -> () (f s)")
 
 
849
 
850
+ # 4. Convert frames to video and combine with audio
851
+ with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_video:
852
+ output_path = temp_video.name
853
 
854
+ print("Saving video to", output_path)
855
 
856
+ save_audio_video(complete_video, audio=complete_audio, save_path=output_path)
857
+ torch.cuda.empty_cache()
858
+ return output_path
859
 
860
+ # except Exception as e:
861
+ # raise e
862
+ # print(f"Error processing video: {str(e)}")
863
+ # return None
864
 
865
 
866
  def get_max_duration(video_input, audio_input):