HReynaud commited on
Commit
40159fb
·
1 Parent(s): 65341d8

move examples

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
.gitignore CHANGED
@@ -177,4 +177,4 @@ tmp/
177
  .vscode/
178
  .gradio/
179
  .cursor/
180
- *.mp4
 
177
  .vscode/
178
  .gradio/
179
  .cursor/
180
+ temp_*
assets/examples/a4c_decoded.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ac03378eb37b461050c952e3766044e7dabeb4e1d55068f9ad07f990d6b1e05a
3
+ size 313650
assets/examples/a4c_latent.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f854cbd44e2e151098de2699875c8eb582254e61e782394f5c002522cd0ac644
3
+ size 249646
assets/examples/plax_decoded.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d91827182f647975cc68dd09ee0e927285b90ea6c408e8b190a75077fda0281e
3
+ size 216872
assets/examples/plax_latent.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:18c6b8c86ca1c2acd16d37097bed855d84ef58d0cfd7ad8392d2d967a5bdec57
3
+ size 171456
assets/examples/psax_decoded.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:04eaa3f6827cd111915fe32fbc379db1033b74d72bfbd89361deb3dde4e06011
3
+ size 199467
assets/examples/psax_latent.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cf9f1c520f74f5142a49bcc476b98f81fb854946a2d24f0a27ff43182a83a097
3
+ size 232138
assets/plax_seg.png CHANGED
demo.py CHANGED
@@ -668,6 +668,108 @@ def load_view_mask(view):
668
 
669
 
670
  def create_demo():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
671
  # Define the theme and layout
672
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
673
  gr.Markdown("# EchoFlow Demo")
@@ -689,6 +791,153 @@ def create_demo():
689
  """
690
  )
691
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
692
  # Main container with 4 columns
693
  with gr.Row():
694
  # Column 1: Latent Image Generation
@@ -701,7 +950,6 @@ def create_demo():
701
  with gr.Row():
702
  # Input mask (binary image)
703
  with gr.Column(scale=1):
704
- # gr.Markdown("#### Mask Condition")
705
  gr.Markdown("Draw the LV mask (white = region of interest)")
706
  # Create a black background for the canvas
707
  black_background = np.zeros((400, 400), dtype=np.uint8)
@@ -730,59 +978,16 @@ def create_demo():
730
  # Fall back to empty canvas
731
  editor_value = black_background
732
 
733
- mask_input = gr.ImageEditor(
734
- label="Binary Mask",
735
- height=400,
736
- width=400,
737
- image_mode="L",
738
- value=editor_value,
739
- type="numpy",
740
- brush=gr.Brush(
741
- colors=["#ffffff"],
742
- color_mode="fixed",
743
- default_size=20,
744
- default_color="#ffffff",
745
- ),
746
- eraser=gr.Eraser(default_size=20),
747
- # show_label=False,
748
- show_download_button=True,
749
- sources=[],
750
- canvas_size=(400, 400),
751
- fixed_canvas=True,
752
- layers=False, # Enable layers to make the mask editable
753
- )
754
-
755
- # # Class selection
756
- # with gr.Column(scale=1):
757
- # gr.Markdown("#### View Condition")
758
- class_selection = gr.Radio(
759
- choices=["A4C", "PSAX", "PLAX"],
760
- label="View Class",
761
- value="A4C",
762
- )
763
-
764
- # gr.Markdown("#### Sampling Steps")
765
- sampling_steps = gr.Slider(
766
- minimum=1,
767
- maximum=200,
768
- value=100,
769
- step=1,
770
- label="Number of Sampling Steps",
771
- info="Higher values = better quality but slower generation",
772
- )
773
 
774
  # Generate button
775
  generate_btn = gr.Button("Generate Latent Image", variant="primary")
776
 
777
  # Display area for latent image (grayscale visualization)
778
- latent_image_display = gr.Image(
779
- label="Latent Image",
780
- type="numpy",
781
- height=400,
782
- width=400,
783
- show_download_button=True,
784
- interactive=False, # Cannot be uploaded/edited
785
- )
786
 
787
  # Decode button (initially disabled)
788
  decode_btn = gr.Button(
@@ -792,14 +997,7 @@ def create_demo():
792
  )
793
 
794
  # Display area for decoded image
795
- decoded_image_display = gr.Image(
796
- label="Decoded Image",
797
- type="numpy",
798
- height=400,
799
- width=400,
800
- show_download_button=True,
801
- interactive=False, # Cannot be uploaded/edited
802
- )
803
 
804
  # Column 2: Privacy Filter
805
  with gr.Column():
@@ -817,12 +1015,10 @@ def create_demo():
817
  )
818
 
819
  # Display area for privacy result status
820
- privacy_status = gr.Markdown("No image processed yet")
821
 
822
  # Display area for privacy-filtered latent image
823
- filtered_latent_display = gr.Image(
824
- label="Filtered Latent Image", type="numpy", height=400, width=400
825
- )
826
 
827
  # Column 3: Animation
828
  with gr.Column():
@@ -832,33 +1028,9 @@ def create_demo():
832
  gr.Markdown("### Latent Video Generation")
833
 
834
  # Ejection Fraction slider
835
- ef_slider = gr.Slider(
836
- minimum=0,
837
- maximum=100,
838
- value=65,
839
- label="Ejection Fraction (%)",
840
- info="Higher values = stronger contraction",
841
- )
842
-
843
- # Add sampling steps slider for animation
844
- animation_steps = gr.Slider(
845
- minimum=1,
846
- maximum=200,
847
- value=100,
848
- step=1,
849
- label="Number of Sampling Steps",
850
- info="Higher values = better quality but slower generation",
851
- )
852
-
853
- # Add CFG slider
854
- cfg_slider = gr.Slider(
855
- minimum=0,
856
- maximum=10,
857
- value=1,
858
- step=1,
859
- label="Classifier-Free Guidance Scale",
860
- # info="Higher values = better quality but slower generation",
861
- )
862
 
863
  # Animate button
864
  animate_btn = gr.Button(
@@ -866,9 +1038,7 @@ def create_demo():
866
  )
867
 
868
  # Display area for latent animation (grayscale)
869
- latent_animation_display = gr.Video(
870
- label="Latent Video", format="mp4", autoplay=True, loop=True
871
- )
872
 
873
  # Column 4: Video Decoding
874
  with gr.Column():
@@ -883,9 +1053,7 @@ def create_demo():
883
  )
884
 
885
  # Display area for decoded animation
886
- decoded_animation_display = gr.Video(
887
- label="Decoded Video", format="mp4", autoplay=True, loop=True
888
- )
889
 
890
  # Hidden state variables to store the full latent representations
891
  latent_image_state = gr.State(None)
@@ -985,121 +1153,6 @@ def create_demo():
985
  queue=True,
986
  )
987
 
988
- # Add examples
989
- gr.Examples(
990
- examples=[
991
- # Example 1: A4C view
992
- [
993
- # Inputs
994
- {
995
- "background": np.zeros((400, 400), dtype=np.uint8),
996
- "layers": [
997
- np.array(
998
- Image.open("assets/a4c_seg.png")
999
- .convert("L")
1000
- .resize((400, 400))
1001
- )
1002
- ],
1003
- "composite": np.array(
1004
- Image.open("assets/a4c_seg.png")
1005
- .convert("L")
1006
- .resize((400, 400))
1007
- ),
1008
- },
1009
- "A4C", # view
1010
- 100, # sampling steps
1011
- 65, # EF slider
1012
- 100, # animation steps
1013
- 1.0, # cfg scale
1014
- # Pre-computed outputs
1015
- Image.open("assets/examples/a4c_latent.png"), # latent image
1016
- Image.open("assets/examples/a4c_decoded.png"), # decoded image
1017
- "✅ **Success:** Generated image passed privacy check.", # privacy status
1018
- Image.open("assets/examples/a4c_filtered.png"), # filtered latent
1019
- "assets/examples/a4c_latent.mp4", # latent animation
1020
- "assets/examples/a4c_decoded.mp4", # decoded animation
1021
- ],
1022
- # Example 2: PSAX view
1023
- [
1024
- # Inputs
1025
- {
1026
- "background": np.zeros((400, 400), dtype=np.uint8),
1027
- "layers": [
1028
- np.array(
1029
- Image.open("assets/psax_seg.png")
1030
- .convert("L")
1031
- .resize((400, 400))
1032
- )
1033
- ],
1034
- "composite": np.array(
1035
- Image.open("assets/psax_seg.png")
1036
- .convert("L")
1037
- .resize((400, 400))
1038
- ),
1039
- },
1040
- "PSAX", # view
1041
- 150, # sampling steps
1042
- 75, # EF slider
1043
- 150, # animation steps
1044
- 2.0, # cfg scale
1045
- # Pre-computed outputs
1046
- Image.open("assets/examples/psax_latent.png"), # latent image
1047
- Image.open("assets/examples/psax_decoded.png"), # decoded image
1048
- "✅ **Success:** Generated image passed privacy check.", # privacy status
1049
- Image.open("assets/examples/psax_filtered.png"), # filtered latent
1050
- "assets/examples/psax_latent.mp4", # latent animation
1051
- "assets/examples/psax_decoded.mp4", # decoded animation
1052
- ],
1053
- # Example 3: PLAX view
1054
- [
1055
- # Inputs
1056
- {
1057
- "background": np.zeros((400, 400), dtype=np.uint8),
1058
- "layers": [
1059
- np.array(
1060
- Image.open("assets/plax_seg.png")
1061
- .convert("L")
1062
- .resize((400, 400))
1063
- )
1064
- ],
1065
- "composite": np.array(
1066
- Image.open("assets/plax_seg.png")
1067
- .convert("L")
1068
- .resize((400, 400))
1069
- ),
1070
- },
1071
- "PLAX", # view
1072
- 200, # sampling steps
1073
- 55, # EF slider
1074
- 200, # animation steps
1075
- 3.0, # cfg scale
1076
- # Pre-computed outputs
1077
- Image.open("assets/examples/plax_latent.png"), # latent image
1078
- Image.open("assets/examples/plax_decoded.png"), # decoded image
1079
- "✅ **Success:** Generated image passed privacy check.", # privacy status
1080
- Image.open("assets/examples/plax_filtered.png"), # filtered latent
1081
- "assets/examples/plax_latent.mp4", # latent animation
1082
- "assets/examples/plax_decoded.mp4", # decoded animation
1083
- ],
1084
- ],
1085
- inputs=[
1086
- mask_input,
1087
- class_selection,
1088
- sampling_steps,
1089
- ef_slider,
1090
- animation_steps,
1091
- cfg_slider,
1092
- latent_image_display,
1093
- decoded_image_display,
1094
- privacy_status,
1095
- filtered_latent_display,
1096
- latent_animation_display,
1097
- decoded_animation_display,
1098
- ],
1099
- label="Example Configurations",
1100
- examples_per_page=3,
1101
- )
1102
-
1103
  return demo
1104
 
1105
 
 
668
 
669
 
670
  def create_demo():
671
+ # Define all components first
672
+ mask_input = gr.ImageEditor(
673
+ label="Binary Mask",
674
+ height=400,
675
+ width=400,
676
+ image_mode="L",
677
+ type="numpy",
678
+ brush=gr.Brush(
679
+ colors=["#ffffff"],
680
+ color_mode="fixed",
681
+ default_size=20,
682
+ default_color="#ffffff",
683
+ ),
684
+ eraser=gr.Eraser(default_size=20),
685
+ show_download_button=True,
686
+ sources=[],
687
+ canvas_size=(400, 400),
688
+ fixed_canvas=True,
689
+ layers=False,
690
+ render=False,
691
+ )
692
+
693
+ class_selection = gr.Radio(
694
+ choices=["A4C", "PSAX", "PLAX"],
695
+ label="View Class",
696
+ value="A4C",
697
+ render=False,
698
+ )
699
+
700
+ sampling_steps = gr.Slider(
701
+ minimum=1,
702
+ maximum=200,
703
+ value=100,
704
+ step=1,
705
+ label="Number of Sampling Steps",
706
+ render=False,
707
+ )
708
+
709
+ ef_slider = gr.Slider(
710
+ minimum=0,
711
+ maximum=100,
712
+ value=65,
713
+ label="Ejection Fraction (%)",
714
+ render=False,
715
+ )
716
+
717
+ animation_steps = gr.Slider(
718
+ minimum=1,
719
+ maximum=200,
720
+ value=100,
721
+ step=1,
722
+ label="Number of Sampling Steps",
723
+ render=False,
724
+ )
725
+
726
+ cfg_slider = gr.Slider(
727
+ minimum=0,
728
+ maximum=10,
729
+ value=1,
730
+ step=1,
731
+ label="Classifier-Free Guidance Scale",
732
+ render=False,
733
+ )
734
+
735
+ latent_image_display = gr.Image(
736
+ label="Latent Image",
737
+ type="numpy",
738
+ height=400,
739
+ width=400,
740
+ render=False,
741
+ )
742
+
743
+ decoded_image_display = gr.Image(
744
+ label="Decoded Image",
745
+ type="numpy",
746
+ height=400,
747
+ width=400,
748
+ render=False,
749
+ )
750
+
751
+ privacy_status = gr.Markdown(render=False)
752
+
753
+ filtered_latent_display = gr.Image(
754
+ label="Filtered Latent Image",
755
+ type="numpy",
756
+ height=400,
757
+ width=400,
758
+ render=False,
759
+ )
760
+
761
+ latent_animation_display = gr.Video(
762
+ label="Latent Video",
763
+ format="mp4",
764
+ render=False,
765
+ )
766
+
767
+ decoded_animation_display = gr.Video(
768
+ label="Decoded Video",
769
+ format="mp4",
770
+ render=False,
771
+ )
772
+
773
  # Define the theme and layout
774
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
775
  gr.Markdown("# EchoFlow Demo")
 
791
  """
792
  )
793
 
794
+ def load_example(
795
+ mask,
796
+ view,
797
+ steps,
798
+ ef,
799
+ anim_steps,
800
+ cfg,
801
+ latent,
802
+ decoded,
803
+ status,
804
+ filtered,
805
+ latent_vid,
806
+ decoded_vid,
807
+ ):
808
+ # This function will be called when an example is clicked
809
+ # It returns all values in order they should be loaded into components
810
+ return [
811
+ mask,
812
+ view,
813
+ steps,
814
+ ef,
815
+ anim_steps,
816
+ cfg,
817
+ latent,
818
+ decoded,
819
+ status,
820
+ filtered,
821
+ latent_vid,
822
+ decoded_vid,
823
+ ]
824
+
825
+ # Add examples using the components
826
+ examples = gr.Examples(
827
+ examples=[
828
+ # Example 1: A4C view
829
+ [
830
+ # Inputs
831
+ {
832
+ "background": np.zeros((400, 400), dtype=np.uint8),
833
+ "layers": [
834
+ np.array(
835
+ Image.open("assets/a4c_seg.png")
836
+ .convert("L")
837
+ .resize((400, 400))
838
+ )
839
+ ],
840
+ "composite": np.array(
841
+ Image.open("assets/a4c_seg.png")
842
+ .convert("L")
843
+ .resize((400, 400))
844
+ ),
845
+ },
846
+ "A4C", # view
847
+ 100, # sampling steps
848
+ 65, # EF slider
849
+ 100, # animation steps
850
+ 1.0, # cfg scale
851
+ # Pre-computed outputs
852
+ Image.open("assets/examples/a4c_latent.png"), # latent image
853
+ Image.open("assets/examples/a4c_decoded.png"), # decoded image
854
+ "✅ **Success:** Generated image passed privacy check.", # privacy status
855
+ Image.open("assets/examples/a4c_filtered.png"), # filtered latent
856
+ "assets/examples/a4c_latent.mp4", # latent animation
857
+ "assets/examples/a4c_decoded.mp4", # decoded animation
858
+ ],
859
+ # Example 2: PSAX view
860
+ [
861
+ # Inputs
862
+ {
863
+ "background": np.zeros((400, 400), dtype=np.uint8),
864
+ "layers": [
865
+ np.array(
866
+ Image.open("assets/psax_seg.png")
867
+ .convert("L")
868
+ .resize((400, 400))
869
+ )
870
+ ],
871
+ "composite": np.array(
872
+ Image.open("assets/psax_seg.png")
873
+ .convert("L")
874
+ .resize((400, 400))
875
+ ),
876
+ },
877
+ "PSAX", # view
878
+ 100, # sampling steps
879
+ 65, # EF slider
880
+ 100, # animation steps
881
+ 1.0, # cfg scale
882
+ # Pre-computed outputs
883
+ Image.open("assets/examples/psax_latent.png"), # latent image
884
+ Image.open("assets/examples/psax_decoded.png"), # decoded image
885
+ "✅ **Success:** Generated image passed privacy check.", # privacy status
886
+ Image.open("assets/examples/psax_filtered.png"), # filtered latent
887
+ "assets/examples/psax_latent.mp4", # latent animation
888
+ "assets/examples/psax_decoded.mp4", # decoded animation
889
+ ],
890
+ # Example 3: PLAX view
891
+ [
892
+ # Inputs
893
+ {
894
+ "background": np.zeros((400, 400), dtype=np.uint8),
895
+ "layers": [
896
+ np.array(
897
+ Image.open("assets/plax_seg.png")
898
+ .convert("L")
899
+ .resize((400, 400))
900
+ )
901
+ ],
902
+ "composite": np.array(
903
+ Image.open("assets/plax_seg.png")
904
+ .convert("L")
905
+ .resize((400, 400))
906
+ ),
907
+ },
908
+ "PLAX", # view
909
+ 100, # sampling steps
910
+ 65, # EF slider
911
+ 100, # animation steps
912
+ 1.0, # cfg scale
913
+ # Pre-computed outputs
914
+ Image.open("assets/examples/plax_latent.png"), # latent image
915
+ Image.open("assets/examples/plax_decoded.png"), # decoded image
916
+ "✅ **Success:** Generated image passed privacy check.", # privacy status
917
+ Image.open("assets/examples/plax_filtered.png"), # filtered latent
918
+ "assets/examples/plax_latent.mp4", # latent animation
919
+ "assets/examples/plax_decoded.mp4", # decoded animation
920
+ ],
921
+ ],
922
+ inputs=[
923
+ mask_input,
924
+ class_selection,
925
+ sampling_steps,
926
+ ef_slider,
927
+ animation_steps,
928
+ cfg_slider,
929
+ latent_image_display,
930
+ decoded_image_display,
931
+ privacy_status,
932
+ filtered_latent_display,
933
+ latent_animation_display,
934
+ decoded_animation_display,
935
+ ],
936
+ fn=load_example,
937
+ label="Click on an example to see the results immediately.",
938
+ examples_per_page=3,
939
+ )
940
+
941
  # Main container with 4 columns
942
  with gr.Row():
943
  # Column 1: Latent Image Generation
 
950
  with gr.Row():
951
  # Input mask (binary image)
952
  with gr.Column(scale=1):
 
953
  gr.Markdown("Draw the LV mask (white = region of interest)")
954
  # Create a black background for the canvas
955
  black_background = np.zeros((400, 400), dtype=np.uint8)
 
978
  # Fall back to empty canvas
979
  editor_value = black_background
980
 
981
+ mask_input.value = editor_value
982
+ mask_input.render()
983
+ class_selection.render()
984
+ sampling_steps.render()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
985
 
986
  # Generate button
987
  generate_btn = gr.Button("Generate Latent Image", variant="primary")
988
 
989
  # Display area for latent image (grayscale visualization)
990
+ latent_image_display.render()
 
 
 
 
 
 
 
991
 
992
  # Decode button (initially disabled)
993
  decode_btn = gr.Button(
 
997
  )
998
 
999
  # Display area for decoded image
1000
+ decoded_image_display.render()
 
 
 
 
 
 
 
1001
 
1002
  # Column 2: Privacy Filter
1003
  with gr.Column():
 
1015
  )
1016
 
1017
  # Display area for privacy result status
1018
+ privacy_status.render()
1019
 
1020
  # Display area for privacy-filtered latent image
1021
+ filtered_latent_display.render()
 
 
1022
 
1023
  # Column 3: Animation
1024
  with gr.Column():
 
1028
  gr.Markdown("### Latent Video Generation")
1029
 
1030
  # Ejection Fraction slider
1031
+ ef_slider.render()
1032
+ animation_steps.render()
1033
+ cfg_slider.render()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1034
 
1035
  # Animate button
1036
  animate_btn = gr.Button(
 
1038
  )
1039
 
1040
  # Display area for latent animation (grayscale)
1041
+ latent_animation_display.render()
 
 
1042
 
1043
  # Column 4: Video Decoding
1044
  with gr.Column():
 
1053
  )
1054
 
1055
  # Display area for decoded animation
1056
+ decoded_animation_display.render()
 
 
1057
 
1058
  # Hidden state variables to store the full latent representations
1059
  latent_image_state = gr.State(None)
 
1153
  queue=True,
1154
  )
1155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1156
  return demo
1157
 
1158