HReynaud commited on
Commit
65341d8
Β·
1 Parent(s): 1472edd
assets/{seg.png β†’ a4c_seg.png} RENAMED
File without changes
assets/examples/a4c_decoded.png ADDED
assets/examples/a4c_filtered.png ADDED
assets/examples/a4c_latent.png ADDED
assets/examples/plax_decoded.png ADDED
assets/examples/plax_filtered.png ADDED
assets/examples/plax_latent.png ADDED
assets/examples/psax_decoded.png ADDED
assets/examples/psax_filtered.png ADDED
assets/examples/psax_latent.png ADDED
assets/plax_seg.png ADDED
assets/psax_seg.png ADDED
demo.py CHANGED
@@ -645,6 +645,28 @@ def latent_animation_to_grayscale(latent_animation):
645
  return temp_file
646
 
647
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
648
  def create_demo():
649
  # Define the theme and layout
650
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
@@ -686,7 +708,7 @@ def create_demo():
686
 
687
  # Load the default mask image if it exists
688
  try:
689
- mask_image = Image.open("assets/seg.png").convert("L")
690
  mask_image = mask_image.resize(
691
  (400, 400), Image.Resampling.LANCZOS
692
  )
@@ -758,7 +780,8 @@ def create_demo():
758
  type="numpy",
759
  height=400,
760
  width=400,
761
- # show_label=False,
 
762
  )
763
 
764
  # Decode button (initially disabled)
@@ -774,7 +797,8 @@ def create_demo():
774
  type="numpy",
775
  height=400,
776
  width=400,
777
- # show_label=False,
 
778
  )
779
 
780
  # Column 2: Privacy Filter
@@ -869,6 +893,13 @@ def create_demo():
869
  latent_animation_state = gr.State(None)
870
 
871
  # Event handlers
 
 
 
 
 
 
 
872
  generate_btn.click(
873
  fn=generate_latent_image,
874
  inputs=[mask_input, class_selection, sampling_steps],
@@ -954,6 +985,121 @@ def create_demo():
954
  queue=True,
955
  )
956
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
957
  return demo
958
 
959
 
 
645
  return temp_file
646
 
647
 
648
+ # Add function to load view-specific mask
649
+ def load_view_mask(view):
650
+ mask_path = f"assets/{view.lower()}_seg.png"
651
+ try:
652
+ mask_image = Image.open(mask_path).convert("L")
653
+ mask_image = mask_image.resize((400, 400), Image.Resampling.LANCZOS)
654
+ # Make it binary (0 or 255)
655
+ mask_image = ImageOps.autocontrast(mask_image, cutoff=0)
656
+ mask_array = np.array(mask_image)
657
+
658
+ # Create the editor value structure
659
+ editor_value = {
660
+ "background": np.zeros((400, 400), dtype=np.uint8), # Black background
661
+ "layers": [mask_array], # The mask as an editable layer
662
+ "composite": mask_array, # The composite image
663
+ }
664
+ return editor_value
665
+ except Exception as e:
666
+ print(f"Error loading mask for view {view}: {e}")
667
+ return None
668
+
669
+
670
  def create_demo():
671
  # Define the theme and layout
672
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
 
708
 
709
  # Load the default mask image if it exists
710
  try:
711
+ mask_image = Image.open("assets/a4c_seg.png").convert("L")
712
  mask_image = mask_image.resize(
713
  (400, 400), Image.Resampling.LANCZOS
714
  )
 
780
  type="numpy",
781
  height=400,
782
  width=400,
783
+ show_download_button=True,
784
+ interactive=False, # Cannot be uploaded/edited
785
  )
786
 
787
  # Decode button (initially disabled)
 
797
  type="numpy",
798
  height=400,
799
  width=400,
800
+ show_download_button=True,
801
+ interactive=False, # Cannot be uploaded/edited
802
  )
803
 
804
  # Column 2: Privacy Filter
 
893
  latent_animation_state = gr.State(None)
894
 
895
  # Event handlers
896
+ class_selection.change(
897
+ fn=load_view_mask,
898
+ inputs=[class_selection],
899
+ outputs=[mask_input],
900
+ queue=False,
901
+ )
902
+
903
  generate_btn.click(
904
  fn=generate_latent_image,
905
  inputs=[mask_input, class_selection, sampling_steps],
 
985
  queue=True,
986
  )
987
 
988
+ # Add examples
989
+ gr.Examples(
990
+ examples=[
991
+ # Example 1: A4C view
992
+ [
993
+ # Inputs
994
+ {
995
+ "background": np.zeros((400, 400), dtype=np.uint8),
996
+ "layers": [
997
+ np.array(
998
+ Image.open("assets/a4c_seg.png")
999
+ .convert("L")
1000
+ .resize((400, 400))
1001
+ )
1002
+ ],
1003
+ "composite": np.array(
1004
+ Image.open("assets/a4c_seg.png")
1005
+ .convert("L")
1006
+ .resize((400, 400))
1007
+ ),
1008
+ },
1009
+ "A4C", # view
1010
+ 100, # sampling steps
1011
+ 65, # EF slider
1012
+ 100, # animation steps
1013
+ 1.0, # cfg scale
1014
+ # Pre-computed outputs
1015
+ Image.open("assets/examples/a4c_latent.png"), # latent image
1016
+ Image.open("assets/examples/a4c_decoded.png"), # decoded image
1017
+ "βœ… **Success:** Generated image passed privacy check.", # privacy status
1018
+ Image.open("assets/examples/a4c_filtered.png"), # filtered latent
1019
+ "assets/examples/a4c_latent.mp4", # latent animation
1020
+ "assets/examples/a4c_decoded.mp4", # decoded animation
1021
+ ],
1022
+ # Example 2: PSAX view
1023
+ [
1024
+ # Inputs
1025
+ {
1026
+ "background": np.zeros((400, 400), dtype=np.uint8),
1027
+ "layers": [
1028
+ np.array(
1029
+ Image.open("assets/psax_seg.png")
1030
+ .convert("L")
1031
+ .resize((400, 400))
1032
+ )
1033
+ ],
1034
+ "composite": np.array(
1035
+ Image.open("assets/psax_seg.png")
1036
+ .convert("L")
1037
+ .resize((400, 400))
1038
+ ),
1039
+ },
1040
+ "PSAX", # view
1041
+ 150, # sampling steps
1042
+ 75, # EF slider
1043
+ 150, # animation steps
1044
+ 2.0, # cfg scale
1045
+ # Pre-computed outputs
1046
+ Image.open("assets/examples/psax_latent.png"), # latent image
1047
+ Image.open("assets/examples/psax_decoded.png"), # decoded image
1048
+ "βœ… **Success:** Generated image passed privacy check.", # privacy status
1049
+ Image.open("assets/examples/psax_filtered.png"), # filtered latent
1050
+ "assets/examples/psax_latent.mp4", # latent animation
1051
+ "assets/examples/psax_decoded.mp4", # decoded animation
1052
+ ],
1053
+ # Example 3: PLAX view
1054
+ [
1055
+ # Inputs
1056
+ {
1057
+ "background": np.zeros((400, 400), dtype=np.uint8),
1058
+ "layers": [
1059
+ np.array(
1060
+ Image.open("assets/plax_seg.png")
1061
+ .convert("L")
1062
+ .resize((400, 400))
1063
+ )
1064
+ ],
1065
+ "composite": np.array(
1066
+ Image.open("assets/plax_seg.png")
1067
+ .convert("L")
1068
+ .resize((400, 400))
1069
+ ),
1070
+ },
1071
+ "PLAX", # view
1072
+ 200, # sampling steps
1073
+ 55, # EF slider
1074
+ 200, # animation steps
1075
+ 3.0, # cfg scale
1076
+ # Pre-computed outputs
1077
+ Image.open("assets/examples/plax_latent.png"), # latent image
1078
+ Image.open("assets/examples/plax_decoded.png"), # decoded image
1079
+ "βœ… **Success:** Generated image passed privacy check.", # privacy status
1080
+ Image.open("assets/examples/plax_filtered.png"), # filtered latent
1081
+ "assets/examples/plax_latent.mp4", # latent animation
1082
+ "assets/examples/plax_decoded.mp4", # decoded animation
1083
+ ],
1084
+ ],
1085
+ inputs=[
1086
+ mask_input,
1087
+ class_selection,
1088
+ sampling_steps,
1089
+ ef_slider,
1090
+ animation_steps,
1091
+ cfg_slider,
1092
+ latent_image_display,
1093
+ decoded_image_display,
1094
+ privacy_status,
1095
+ filtered_latent_display,
1096
+ latent_animation_display,
1097
+ decoded_animation_display,
1098
+ ],
1099
+ label="Example Configurations",
1100
+ examples_per_page=3,
1101
+ )
1102
+
1103
  return demo
1104
 
1105