Wenzheng Chang committed on
Commit d5d6d85 · 1 Parent(s): 159559c

init gradio

Files changed (1)
  1. scripts/demo_gradio.py +760 -287
scripts/demo_gradio.py CHANGED
@@ -17,6 +17,7 @@ from diffusers import (
     CogVideoXTransformer3DModel,
 )
 from transformers import AutoTokenizer, T5EncoderModel


 rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
@@ -39,9 +40,6 @@ from aether.utils.postprocess_utils import (  # noqa: E402
 from aether.utils.visualize_utils import predictions_to_glb  # noqa: E402


-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-
 def seed_all(seed: int = 0) -> None:
     """
     Set random seeds of all components.
@@ -52,7 +50,7 @@ def seed_all(seed: int = 0) -> None:
     torch.cuda.manual_seed_all(seed)


-# Global pipeline
 cogvideox_pretrained_model_name_or_path: str = "THUDM/CogVideoX-5b-I2V"
 aether_pretrained_model_name_or_path: str = "AetherWorldModel/AetherV1"
 pipeline = AetherV1PipelineCogVideoX(
@@ -64,22 +62,45 @@ pipeline = AetherV1PipelineCogVideoX(
         cogvideox_pretrained_model_name_or_path, subfolder="text_encoder"
     ),
     vae=AutoencoderKLCogVideoX.from_pretrained(
-        cogvideox_pretrained_model_name_or_path, subfolder="vae"
     ),
     scheduler=CogVideoXDPMScheduler.from_pretrained(
         cogvideox_pretrained_model_name_or_path, subfolder="scheduler"
     ),
     transformer=CogVideoXTransformer3DModel.from_pretrained(
-        aether_pretrained_model_name_or_path, subfolder="transformer"
     ),
 )
 pipeline.vae.enable_slicing()
 pipeline.vae.enable_tiling()
-pipeline.to(device)


-def build_pipeline() -> AetherV1PipelineCogVideoX:
     """Initialize the model pipeline."""
     return pipeline

@@ -395,12 +416,29 @@ def save_output_files(
     for frame_idx in frames_to_save:
         if frame_idx >= pointmap.shape[0]:
             continue
-
         predictions = {
-            "world_points": pointmap[frame_idx : frame_idx + 1],
             "images": rgb[frame_idx : frame_idx + 1],
             "depths": 1 / np.clip(disparity[frame_idx : frame_idx + 1], 1e-8, 1e8),
-            "camera_poses": poses[frame_idx : frame_idx + 1],
         }

         glb_path = os.path.join(
@@ -423,6 +461,7 @@ def save_output_files(
     return paths


 def process_reconstruction(
     video_file,
     height,
@@ -447,11 +486,13 @@ def process_reconstruction(
     gc.collect()
     torch.cuda.empty_cache()

-    # Set random seed
     seed_all(seed)
-
-    # Build the pipeline
-    pipeline = build_pipeline()

     progress(0.1, "Loading video")
     # Check if video_file is a string or a file object
@@ -545,6 +586,7 @@ def process_reconstruction(
     return None, None, []


 def process_prediction(
     image_file,
     height,
@@ -573,9 +615,14 @@ def process_prediction(

     # Set random seed
     seed_all(seed)

     # Build the pipeline
-    pipeline = build_pipeline()

     progress(0.1, "Loading image")
     # Check if image_file is a string or a file object
@@ -671,6 +718,7 @@ def process_prediction(
     return None, None, []


 def process_planning(
     image_file,
     goal_file,
@@ -700,8 +748,13 @@ def process_planning(
     # Set random seed
     seed_all(seed)

     # Build the pipeline
-    pipeline = build_pipeline()

     progress(0.1, "Loading images")
     # Check if image_file and goal_file are strings or file objects
@@ -807,11 +860,10 @@ def update_task_ui(task):
     """Update UI elements based on selected task."""
     if task == "reconstruction":
         return (
-            gr.update(visible=True),  # video_input
-            gr.update(visible=False),  # image_input
-            gr.update(visible=False),  # goal_input
-            gr.update(visible=False),  # image_preview
-            gr.update(visible=False),  # goal_preview
             gr.update(value=4),  # num_inference_steps
             gr.update(visible=True),  # sliding_window_stride
             gr.update(visible=False),  # use_dynamic_cfg
@@ -821,11 +873,10 @@ def update_task_ui(task):
         )
     elif task == "prediction":
         return (
-            gr.update(visible=False),  # video_input
-            gr.update(visible=True),  # image_input
-            gr.update(visible=False),  # goal_input
-            gr.update(visible=True),  # image_preview
-            gr.update(visible=False),  # goal_preview
             gr.update(value=50),  # num_inference_steps
             gr.update(visible=False),  # sliding_window_stride
             gr.update(visible=True),  # use_dynamic_cfg
@@ -835,11 +886,10 @@ def update_task_ui(task):
         )
     elif task == "planning":
         return (
-            gr.update(visible=False),  # video_input
-            gr.update(visible=True),  # image_input
-            gr.update(visible=True),  # goal_input
-            gr.update(visible=True),  # image_preview
-            gr.update(visible=True),  # goal_preview
             gr.update(value=50),  # num_inference_steps
             gr.update(visible=False),  # sliding_window_stride
             gr.update(visible=True),  # use_dynamic_cfg
@@ -851,16 +901,20 @@ def update_task_ui(task):

 def update_image_preview(image_file):
     """Update the image preview."""
-    if image_file:
-        return image_file.name
-    return None


 def update_goal_preview(goal_file):
     """Update the goal preview."""
-    if goal_file:
-        return goal_file.name
-    return None


 def get_download_link(selected_frame, all_paths):
@@ -892,8 +946,17 @@ with gr.Blocks(
         min-height: 400px;
     }
     .warning {
-        color: #ff9800;
-        font-weight: bold;
     }
     .highlight {
         background-color: rgba(0, 123, 255, 0.1);
@@ -903,9 +966,9 @@ with gr.Blocks(
         margin: 10px 0;
     }
     .task-header {
-        margin-top: 10px;
-        margin-bottom: 15px;
-        font-size: 1.2em;
         font-weight: bold;
         color: #007bff;
     }
@@ -922,9 +985,9 @@ with gr.Blocks(
     }
     .input-section, .params-section, .advanced-section {
         border: 1px solid #ddd;
-        padding: 15px;
         border-radius: 8px;
-        margin-bottom: 15px;
     }
     .logo-container {
         display: flex;
@@ -935,288 +998,703 @@ with gr.Blocks(
935
  max-width: 300px;
936
  height: auto;
937
  }
938
  """,
939
  ) as demo:
940
- with gr.Row(elem_classes=["logo-container"]):
941
- gr.Image("assets/logo.png", show_label=False, elem_classes=["logo-image"])
942
-
943
- gr.Markdown(
944
- """
945
- # Aether: Geometric-Aware Unified World Modeling
946
-
947
- Aether addresses a fundamental challenge in AI: integrating geometric reconstruction with
948
- generative modeling for human-like spatial reasoning. Our framework unifies three core capabilities:
949
-
950
- 1. **4D dynamic reconstruction** - Reconstruct dynamic point clouds from videos by estimating depths and camera poses.
951
- 2. **Action-Conditioned Video Prediction** - Predict future frames based on initial observation images, with optional conditions of camera trajectory actions.
952
- 3. **Goal-Conditioned Visual Planning** - Generate planning paths from pairs of observation and goal images.
953
-
954
- Trained entirely on synthetic data, Aether achieves strong zero-shot generalization to real-world scenarios.
955
- """
956
- )
957
-
958
- with gr.Row():
959
- with gr.Column(scale=1):
960
- task = gr.Radio(
961
- ["reconstruction", "prediction", "planning"],
962
- label="Select Task",
963
- value="reconstruction",
964
- info="Choose the task you want to perform",
965
- )
966
-
967
- with gr.Group(elem_classes=["input-section"]):
968
- # Input section - changes based on task
969
- gr.Markdown("## 📥 Input", elem_classes=["task-header"])
970
-
971
- # Task-specific inputs
972
- video_input = gr.Video(
973
- label="Upload Input Video",
974
- sources=["upload"],
975
- visible=True,
976
- interactive=True,
977
- elem_id="video_input",
978
  )
979
-
980
- image_input = gr.File(
981
- label="Upload Start Image",
982
- file_count="single",
983
- file_types=["image"],
984
- visible=False,
985
- interactive=True,
986
- elem_id="image_input",
987
  )
988
 
989
- goal_input = gr.File(
990
- label="Upload Goal Image",
991
- file_count="single",
992
- file_types=["image"],
993
- visible=False,
994
- interactive=True,
995
- elem_id="goal_input",
996
  )
997
 
998
- with gr.Row(visible=False) as preview_row:
999
- image_preview = gr.Image(
1000
- label="Start Image Preview",
1001
- elem_id="image_preview",
1002
- visible=False,
1003
- )
1004
- goal_preview = gr.Image(
1005
- label="Goal Image Preview",
1006
- elem_id="goal_preview",
1007
- visible=False,
1008
  )
1009
 
1010
- with gr.Group(elem_classes=["params-section"]):
1011
- gr.Markdown("## ⚙️ Parameters", elem_classes=["task-header"])
1012
 
1013
- with gr.Row():
1014
- with gr.Column(scale=1):
1015
- height = gr.Dropdown(
1016
- choices=[480],
1017
- value=480,
1018
- label="Height",
1019
- info="Height of the output video",
1020
  )
1021
-
1022
- with gr.Column(scale=1):
1023
- width = gr.Dropdown(
1024
- choices=[720],
1025
- value=720,
1026
- label="Width",
1027
- info="Width of the output video",
1028
  )
1029
 
1030
- with gr.Row():
1031
- with gr.Column(scale=1):
1032
- num_frames = gr.Dropdown(
1033
- choices=[17, 25, 33, 41],
1034
- value=41,
1035
- label="Number of Frames",
1036
- info="Number of frames to predict",
1037
  )
1038
-
1039
- with gr.Column(scale=1):
1040
- fps = gr.Dropdown(
1041
- choices=[8, 10, 12, 15, 24],
1042
- value=12,
1043
- label="FPS",
1044
- info="Frames per second",
1045
  )
1046
 
1047
- with gr.Row():
1048
- with gr.Column(scale=1):
1049
- num_inference_steps = gr.Slider(
1050
- minimum=1,
1051
- maximum=60,
1052
- value=4,
1053
- step=1,
1054
- label="Inference Steps",
1055
- info="Number of inference step",
1056
  )
1057
 
1058
- sliding_window_stride = gr.Slider(
1059
- minimum=1,
1060
- maximum=40,
1061
- value=24,
1062
- step=1,
1063
- label="Sliding Window Stride",
1064
- info="Sliding window stride (window size equals to num_frames). Only used for 'reconstruction' task",
1065
- visible=True,
1066
- )
1067
-
1068
- use_dynamic_cfg = gr.Checkbox(
1069
- label="Use Dynamic CFG",
1070
- value=True,
1071
- info="Use dynamic CFG",
1072
- visible=False,
1073
- )
1074
-
1075
- raymap_option = gr.Radio(
1076
- choices=["backward", "forward_right", "left_forward", "right"],
1077
- label="Camera Movement Direction",
1078
- value="forward_right",
1079
- info="Direction of camera action. We offer 4 pre-defined actions for you to choose from.",
1080
- visible=False,
1081
- )
1082
 
1083
- post_reconstruction = gr.Checkbox(
1084
- label="Post-Reconstruction",
1085
- value=True,
1086
- info="Run reconstruction after prediction for better quality",
1087
- visible=False,
1088
- )
1089
 
1090
- with gr.Accordion(
1091
- "Advanced Options", open=False, visible=True
1092
- ) as advanced_options:
1093
- with gr.Group(elem_classes=["advanced-section"]):
1094
  with gr.Row():
1095
  with gr.Column(scale=1):
1096
- guidance_scale = gr.Slider(
1097
- minimum=1.0,
1098
- maximum=10.0,
1099
- value=1.0,
1100
- step=0.1,
1101
- label="Guidance Scale",
1102
- info="Guidance scale (only for prediction / planning)",
1103
  )
1104
 
1105
- with gr.Row():
1106
  with gr.Column(scale=1):
1107
- seed = gr.Number(
1108
- value=42,
1109
- label="Random Seed",
1110
- info="Set a seed for reproducible results",
1111
- precision=0,
1112
- minimum=0,
1113
- maximum=2147483647,
1114
  )
1115
 
1116
  with gr.Row():
1117
  with gr.Column(scale=1):
1118
- smooth_camera = gr.Checkbox(
1119
- label="Smooth Camera",
1120
- value=True,
1121
- info="Apply smoothing to camera trajectory",
 
1122
  )
1123
 
1124
  with gr.Column(scale=1):
1125
- align_pointmaps = gr.Checkbox(
1126
- label="Align Point Maps",
1127
- value=False,
1128
- info="Align point maps across frames",
 
1129
  )
1130
 
1131
  with gr.Row():
1132
- with gr.Column(scale=1):
1133
- max_depth = gr.Slider(
1134
- minimum=10,
1135
- maximum=200,
1136
- value=60,
1137
- step=10,
1138
- label="Max Depth",
1139
- info="Maximum depth for point cloud (higher = more distant points)",
1140
- )
1141
-
1142
- with gr.Column(scale=1):
1143
- rtol = gr.Slider(
1144
- minimum=0.01,
1145
- maximum=2.0,
1146
- value=0.03,
1147
- step=0.01,
1148
- label="Relative Tolerance",
1149
- info="Used for depth edge detection. Lower = remove more edges",
1150
- )
1151
 
1152
- pointcloud_save_frame_interval = gr.Slider(
1153
  minimum=1,
1154
- maximum=20,
1155
- value=10,
1156
  step=1,
1157
- label="Point Cloud Frame Interval",
1158
- info="Save point cloud every N frames (higher = fewer files but less complete representation)",
 
1159
  )
1160
 
1161
- run_button = gr.Button("Run Aether", variant="primary")
1162
-
1163
- with gr.Column(scale=1, elem_classes=["output-column"]):
1164
- with gr.Group():
1165
- gr.Markdown("## 📤 Output", elem_classes=["task-header"])
1166
-
1167
- gr.Markdown("### RGB Video", elem_classes=["output-subtitle"])
1168
- rgb_output = gr.Video(
1169
- label="RGB Output", interactive=False, elem_id="rgb_output"
1170
- )
1171
-
1172
- gr.Markdown("### Depth Video", elem_classes=["output-subtitle"])
1173
- depth_output = gr.Video(
1174
- label="Depth Output", interactive=False, elem_id="depth_output"
1175
- )
1176
-
1177
- gr.Markdown("### Point Clouds", elem_classes=["output-subtitle"])
1178
- with gr.Row(elem_classes=["flex-display"]):
1179
- pointcloud_frames = gr.Dropdown(
1180
- label="Select Frame",
1181
- choices=[],
1182
- value=None,
1183
- interactive=True,
1184
- elem_id="pointcloud_frames",
1185
- )
1186
- pointcloud_download = gr.DownloadButton(
1187
- label="Download Point Cloud",
1188
  visible=False,
1189
- elem_id="pointcloud_download",
1190
  )
1191
 
1192
- model_output = gr.Model3D(
1193
- label="Point Cloud Viewer", interactive=True, elem_id="model_output"
1194
- )
1195
 
1196
- with gr.Tab("About Results"):
1197
- gr.Markdown(
1198
- """
1199
- ### Understanding the Outputs
1200
 
1201
- - **RGB Video**: Shows the predicted or reconstructed RGB frames
1202
- - **Depth Video**: Visualizes the disparity maps in color (closer = red, further = blue)
1203
- - **Point Clouds**: Interactive 3D point cloud with camera positions shown as colored pyramids
1204
 
1205
- <p class="warning">Note: 3D point clouds take a long time to visualize, and we show the keyframes only.
1206
- You can control the keyframe interval by modifying the `pointcloud_save_frame_interval`.</p>
1207
- """
1208
- )
1209
 
1210
  # Event handlers
1211
  task.change(
1212
  fn=update_task_ui,
1213
  inputs=[task],
1214
  outputs=[
1215
- video_input,
1216
- image_input,
1217
- goal_input,
1218
- image_preview,
1219
- goal_preview,
1220
  num_inference_steps,
1221
  sliding_window_stride,
1222
  use_dynamic_cfg,
@@ -1227,11 +1705,15 @@ with gr.Blocks(
     )

     image_input.change(
-        fn=update_image_preview, inputs=[image_input], outputs=[image_preview]
     ).then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[preview_row])

     goal_input.change(
-        fn=update_goal_preview, inputs=[goal_input], outputs=[goal_preview]
     ).then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[preview_row])

     def update_pointcloud_frames(pointcloud_paths):
@@ -1453,17 +1935,8 @@ with gr.Blocks(
         outputs=[pointcloud_download],
     )

-    # Example Accordion
-    with gr.Accordion("Examples"):
-        gr.Markdown(
-            """
-            ### Examples will be added soon
-            Check back for example inputs for each task type.
-            """
-        )
-
     # Load the model at startup
-    demo.load(lambda: build_pipeline(), inputs=None, outputs=None)

 if __name__ == "__main__":
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
     CogVideoXTransformer3DModel,
 )
 from transformers import AutoTokenizer, T5EncoderModel
+import spaces


 rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

 from aether.utils.visualize_utils import predictions_to_glb  # noqa: E402


 def seed_all(seed: int = 0) -> None:
     """
     Set random seeds of all components.

     torch.cuda.manual_seed_all(seed)


+# # Global pipeline
 cogvideox_pretrained_model_name_or_path: str = "THUDM/CogVideoX-5b-I2V"
 aether_pretrained_model_name_or_path: str = "AetherWorldModel/AetherV1"
 pipeline = AetherV1PipelineCogVideoX(

         cogvideox_pretrained_model_name_or_path, subfolder="text_encoder"
     ),
     vae=AutoencoderKLCogVideoX.from_pretrained(
+        cogvideox_pretrained_model_name_or_path, subfolder="vae", torch_dtype=torch.bfloat16
     ),
     scheduler=CogVideoXDPMScheduler.from_pretrained(
         cogvideox_pretrained_model_name_or_path, subfolder="scheduler"
     ),
     transformer=CogVideoXTransformer3DModel.from_pretrained(
+        aether_pretrained_model_name_or_path, subfolder="transformer", torch_dtype=torch.bfloat16
     ),
 )
 pipeline.vae.enable_slicing()
 pipeline.vae.enable_tiling()
+# pipeline.to(device)


+def build_pipeline(device: torch.device) -> AetherV1PipelineCogVideoX:
     """Initialize the model pipeline."""
+    # cogvideox_pretrained_model_name_or_path: str = "THUDM/CogVideoX-5b-I2V"
+    # aether_pretrained_model_name_or_path: str = "AetherWorldModel/AetherV1"
+    # pipeline = AetherV1PipelineCogVideoX(
+    #     tokenizer=AutoTokenizer.from_pretrained(
+    #         cogvideox_pretrained_model_name_or_path,
+    #         subfolder="tokenizer",
+    #     ),
+    #     text_encoder=T5EncoderModel.from_pretrained(
+    #         cogvideox_pretrained_model_name_or_path, subfolder="text_encoder"
+    #     ),
+    #     vae=AutoencoderKLCogVideoX.from_pretrained(
+    #         cogvideox_pretrained_model_name_or_path, subfolder="vae"
+    #     ),
+    #     scheduler=CogVideoXDPMScheduler.from_pretrained(
+    #         cogvideox_pretrained_model_name_or_path, subfolder="scheduler"
+    #     ),
+    #     transformer=CogVideoXTransformer3DModel.from_pretrained(
+    #         aether_pretrained_model_name_or_path, subfolder="transformer"
+    #     ),
+    # )
+    # pipeline.vae.enable_slicing()
+    # pipeline.vae.enable_tiling()
+    pipeline.to(device)
     return pipeline
 
     for frame_idx in frames_to_save:
         if frame_idx >= pointmap.shape[0]:
             continue
+
+        # fix the problem of point cloud being upside down and left-right reversed: flip Y axis and X axis
+        flipped_pointmap = pointmap[frame_idx:frame_idx+1].copy()
+        flipped_pointmap[..., 1] = -flipped_pointmap[..., 1]  # flip Y axis (up and down)
+        flipped_pointmap[..., 0] = -flipped_pointmap[..., 0]  # flip X axis (left and right)
+
+        # flip camera poses
+        flipped_poses = poses[frame_idx:frame_idx+1].copy()
+        # flip Y axis and X axis of camera orientation
+        flipped_poses[..., 1, :3] = -flipped_poses[..., 1, :3]  # flip Y axis of camera orientation
+        flipped_poses[..., 0, :3] = -flipped_poses[..., 0, :3]  # flip X axis of camera orientation
+        flipped_poses[..., :3, 1] = -flipped_poses[..., :3, 1]  # flip Y axis of camera orientation
+        flipped_poses[..., :3, 0] = -flipped_poses[..., :3, 0]  # flip X axis of camera orientation
+        # flip Y axis and X axis of camera position
+        flipped_poses[..., 1, 3] = -flipped_poses[..., 1, 3]  # flip Y axis position
+        flipped_poses[..., 0, 3] = -flipped_poses[..., 0, 3]  # flip X axis position
+
+        # use flipped point cloud and camera poses
         predictions = {
+            "world_points": flipped_pointmap,
             "images": rgb[frame_idx : frame_idx + 1],
             "depths": 1 / np.clip(disparity[frame_idx : frame_idx + 1], 1e-8, 1e8),
+            "camera_poses": flipped_poses,
         }

         glb_path = os.path.join(
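For reference, the per-frame flip added above is equivalent to conjugating each camera-to-world pose with F = diag(-1, -1, 1) and negating the X/Y components of every world point. A minimal standalone sketch of that equivalence (illustrative only; `points` and `pose` are hypothetical inputs, not names from this script):

    import numpy as np

    F = np.diag([-1.0, -1.0, 1.0])               # negate the X and Y axes

    points = np.random.rand(480, 720, 3)         # hypothetical H x W x 3 world points
    pose = np.eye(4)                             # hypothetical 4 x 4 camera-to-world pose

    flipped_points = points @ F                  # negates points[..., 0] and points[..., 1]

    flipped_pose = pose.copy()
    flipped_pose[:3, :3] = F @ pose[:3, :3] @ F  # the row flips plus the column flips on the rotation
    flipped_pose[:3, 3] = F @ pose[:3, 3]        # the row flips on the translation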
 
     return paths


+@spaces.GPU(duration=300)
 def process_reconstruction(
     video_file,
     height,
 
     gc.collect()
     torch.cuda.empty_cache()

     seed_all(seed)
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    if not torch.cuda.is_available():
+        raise ValueError("CUDA is not available. Check your environment.")
+
+    pipeline = build_pipeline(device)

     progress(0.1, "Loading video")
     # Check if video_file is a string or a file object
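The device handling above (and in the prediction and planning handlers below) follows the Hugging Face Spaces ZeroGPU pattern: the pipeline is constructed once at import time and only moved to CUDA inside functions decorated with `@spaces.GPU`. A minimal sketch of that pattern, assuming the `spaces` package and a ZeroGPU Space; the tiny `torch.nn.Linear` stands in for the real pipeline:

    import spaces
    import torch

    # Built at import time; stays on CPU until a GPU slice is granted.
    model = torch.nn.Linear(8, 2)

    @spaces.GPU(duration=300)  # request a ZeroGPU slice for up to 300 seconds
    def run(batch: torch.Tensor) -> torch.Tensor:
        device = torch.device("cuda")
        model.to(device)       # weights reach the GPU only inside the decorated call
        with torch.no_grad():
            return model(batch.to(device))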
 
     return None, None, []


+@spaces.GPU(duration=300)
 def process_prediction(
     image_file,
     height,
 

     # Set random seed
     seed_all(seed)
+
+    # Check if CUDA is available
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    if not torch.cuda.is_available():
+        raise ValueError("CUDA is not available. Check your environment.")

     # Build the pipeline
+    pipeline = build_pipeline(device)

     progress(0.1, "Loading image")
     # Check if image_file is a string or a file object
 
     return None, None, []


+@spaces.GPU(duration=300)
 def process_planning(
     image_file,
     goal_file,
 
     # Set random seed
     seed_all(seed)

+    # Check if CUDA is available
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    if not torch.cuda.is_available():
+        raise ValueError("CUDA is not available. Check your environment.")
+
     # Build the pipeline
+    pipeline = build_pipeline(device)

     progress(0.1, "Loading images")
     # Check if image_file and goal_file are strings or file objects
 
     """Update UI elements based on selected task."""
     if task == "reconstruction":
         return (
+            gr.update(visible=True),  # reconstruction_group
+            gr.update(visible=False),  # prediction_group
+            gr.update(visible=False),  # planning_group
+            gr.update(visible=False),  # preview_row
             gr.update(value=4),  # num_inference_steps
             gr.update(visible=True),  # sliding_window_stride
             gr.update(visible=False),  # use_dynamic_cfg

         )
     elif task == "prediction":
         return (
+            gr.update(visible=False),  # reconstruction_group
+            gr.update(visible=True),  # prediction_group
+            gr.update(visible=False),  # planning_group
+            gr.update(visible=True),  # preview_row
             gr.update(value=50),  # num_inference_steps
             gr.update(visible=False),  # sliding_window_stride
             gr.update(visible=True),  # use_dynamic_cfg

         )
     elif task == "planning":
         return (
+            gr.update(visible=False),  # reconstruction_group
+            gr.update(visible=False),  # prediction_group
+            gr.update(visible=True),  # planning_group
+            gr.update(visible=True),  # preview_row
             gr.update(value=50),  # num_inference_steps
             gr.update(visible=False),  # sliding_window_stride
             gr.update(visible=True),  # use_dynamic_cfg
901
 
902
  def update_image_preview(image_file):
903
  """Update the image preview."""
904
+ if image_file is None:
905
+ return None
906
+ if isinstance(image_file, str):
907
+ return image_file
908
+ return image_file.name if hasattr(image_file, 'name') else None
909
 
910
 
911
  def update_goal_preview(goal_file):
912
  """Update the goal preview."""
913
+ if goal_file is None:
914
+ return None
915
+ if isinstance(goal_file, str):
916
+ return goal_file
917
+ return goal_file.name if hasattr(goal_file, 'name') else None
918
 
919
 
920
  def get_download_link(selected_frame, all_paths):
 
946
  min-height: 400px;
947
  }
948
  .warning {
949
+ color: #856404 !important;
950
+ font-weight: bold !important;
951
+ padding: 10px !important;
952
+ background-color: #fff3cd !important;
953
+ border-left: 4px solid #ffc107 !important;
954
+ border-radius: 4px !important;
955
+ margin: 10px 0 !important;
956
+ }
957
+ .dark .warning {
958
+ background-color: rgba(255, 193, 7, 0.1) !important;
959
+ color: #fbd38d !important;
960
  }
961
  .highlight {
962
  background-color: rgba(0, 123, 255, 0.1);
 
966
  margin: 10px 0;
967
  }
968
  .task-header {
969
+ margin-top: 15px;
970
+ margin-bottom: 20px;
971
+ font-size: 1.4em;
972
  font-weight: bold;
973
  color: #007bff;
974
  }
 
985
  }
986
  .input-section, .params-section, .advanced-section {
987
  border: 1px solid #ddd;
988
+ padding: 20px;
989
  border-radius: 8px;
990
+ margin-bottom: 20px;
991
  }
992
  .logo-container {
993
  display: flex;
 
998
  max-width: 300px;
999
  height: auto;
1000
  }
1001
+
1002
+ /* Optimize layout and spacing */
1003
+ .container {
1004
+ margin: 0 auto;
1005
+ padding: 0 15px;
1006
+ max-width: 1800px;
1007
+ }
1008
+
1009
+ .header {
1010
+ text-align: center;
1011
+ margin-bottom: 20px;
1012
+ padding: 15px;
1013
+ background: linear-gradient(to right, #f8f9fa, #e9ecef);
1014
+ border-radius: 10px;
1015
+ }
1016
+
1017
+ .dark .header {
1018
+ background: linear-gradient(to right, #2d3748, #1a202c);
1019
+ }
1020
+
1021
+ .main-title {
1022
+ font-size: 2.2em;
1023
+ font-weight: bold;
1024
+ margin: 0 auto;
1025
+ color: #2c3e50;
1026
+ max-width: 800px;
1027
+ }
1028
+
1029
+ .dark .main-title {
1030
+ color: #e2e8f0;
1031
+ }
1032
+
1033
+ .links-bar {
1034
+ display: flex;
1035
+ justify-content: center;
1036
+ gap: 15px;
1037
+ margin: 12px 0;
1038
+ }
1039
+
1040
+ .link-button {
1041
+ display: inline-flex;
1042
+ align-items: center;
1043
+ padding: 6px 12px;
1044
+ background-color: #007bff;
1045
+ color: white !important;
1046
+ text-decoration: none;
1047
+ border-radius: 5px;
1048
+ transition: background-color 0.3s;
1049
+ font-size: 0.95em;
1050
+ }
1051
+
1052
+ .link-button:hover {
1053
+ background-color: #0056b3;
1054
+ text-decoration: none;
1055
+ }
1056
+
1057
+ .features-limitations-container {
1058
+ display: flex;
1059
+ gap: 15px;
1060
+ margin: 20px 0;
1061
+ }
1062
+
1063
+ .capabilities-box, .limitations-box {
1064
+ flex: 1;
1065
+ padding: 18px;
1066
+ border-radius: 8px;
1067
+ margin-bottom: 15px;
1068
+ }
1069
+
1070
+ .capabilities-box {
1071
+ background: #f0f9ff;
1072
+ border-left: 5px solid #3498db;
1073
+ }
1074
+
1075
+ .dark .capabilities-box {
1076
+ background: #172a3a;
1077
+ border-left: 5px solid #3498db;
1078
+ }
1079
+
1080
+ .limitations-box {
1081
+ background: #f8f9fa;
1082
+ border-left: 5px solid #ffc107;
1083
+ }
1084
+
1085
+ .dark .limitations-box {
1086
+ background: #2d2a20;
1087
+ border-left: 5px solid #ffc107;
1088
+ }
1089
+
1090
+ .capabilities-text, .limitations-text {
1091
+ color: #495057;
1092
+ line-height: 1.6;
1093
+ }
1094
+
1095
+ .dark .capabilities-text, .dark .limitations-text {
1096
+ color: #cbd5e0;
1097
+ }
1098
+
1099
+ .capabilities-text h3 {
1100
+ color: #2980b9;
1101
+ margin-top: 0;
1102
+ margin-bottom: 15px;
1103
+ }
1104
+
1105
+ .dark .capabilities-text h3 {
1106
+ color: #63b3ed;
1107
+ }
1108
+
1109
+ .limitations-text h3 {
1110
+ color: #d39e00;
1111
+ margin-top: 0;
1112
+ margin-bottom: 15px;
1113
+ }
1114
+
1115
+ .dark .limitations-text h3 {
1116
+ color: #fbd38d;
1117
+ }
1118
+
1119
+ .capabilities-text blockquote, .limitations-text blockquote {
1120
+ margin: 20px 0 0 0;
1121
+ padding: 10px 20px;
1122
+ font-style: italic;
1123
+ }
1124
+
1125
+ .capabilities-text blockquote {
1126
+ border-left: 3px solid #3498db;
1127
+ background: rgba(52, 152, 219, 0.1);
1128
+ }
1129
+
1130
+ .dark .capabilities-text blockquote {
1131
+ background: rgba(52, 152, 219, 0.2);
1132
+ }
1133
+
1134
+ .limitations-text blockquote {
1135
+ border-left: 3px solid #ffc107;
1136
+ background: rgba(255, 193, 7, 0.1);
1137
+ }
1138
+
1139
+ .dark .limitations-text blockquote {
1140
+ background: rgba(255, 193, 7, 0.2);
1141
+ }
1142
+
1143
+ /* Optimize layout and spacing */
1144
+ .main-interface {
1145
+ display: flex;
1146
+ gap: 30px;
1147
+ margin-top: 20px;
1148
+ }
1149
+
1150
+ .input-column, .output-column {
1151
+ flex: 1;
1152
+ min-width: 0;
1153
+ display: flex;
1154
+ flex-direction: column;
1155
+ }
1156
+
1157
+ .output-panel {
1158
+ border: 1px solid #ddd;
1159
+ border-radius: 8px;
1160
+ padding: 20px;
1161
+ height: 100%;
1162
+ display: flex;
1163
+ flex-direction: column;
1164
+ overflow-y: auto;
1165
+ }
1166
+
1167
+ .dark .output-panel {
1168
+ border-color: #4a5568;
1169
+ }
1170
+
1171
+ .run-button-container {
1172
+ display: flex;
1173
+ justify-content: center;
1174
+ margin: 15px 0;
1175
+ }
1176
+
1177
+ .run-button {
1178
+ padding: 10px 30px;
1179
+ font-size: 1.1em;
1180
+ font-weight: bold;
1181
+ background: linear-gradient(to right, #3498db, #2980b9);
1182
+ border: none;
1183
+ border-radius: 5px;
1184
+ color: white;
1185
+ cursor: pointer;
1186
+ transition: all 0.3s;
1187
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
1188
+ }
1189
+
1190
+ .run-button:hover {
1191
+ background: linear-gradient(to right, #2980b9, #1a5276);
1192
+ box-shadow: 0 6px 8px rgba(0, 0, 0, 0.15);
1193
+ transform: translateY(-2px);
1194
+ }
1195
+
1196
+ .task-selector {
1197
+ background-color: #f8f9fa;
1198
+ padding: 12px;
1199
+ border-radius: 8px;
1200
+ margin-bottom: 15px;
1201
+ border: 1px solid #e9ecef;
1202
+ }
1203
+
1204
+ .dark .task-selector {
1205
+ background-color: #2d3748;
1206
+ border-color: #4a5568;
1207
+ }
1208
+
1209
+ /* Compact parameter settings */
1210
+ .compact-params .row {
1211
+ margin-bottom: 8px;
1212
+ }
1213
+
1214
+ .compact-params label {
1215
+ margin-bottom: 4px;
1216
+ }
1217
+
1218
+ /* More obvious advanced options */
1219
+ .advanced-options-header {
1220
+ background-color: #e9ecef;
1221
+ padding: 10px 15px;
1222
+ border-radius: 6px;
1223
+ margin-top: 10px;
1224
+ font-weight: bold;
1225
+ color: #495057;
1226
+ border-left: 4px solid #6c757d;
1227
+ cursor: pointer;
1228
+ transition: all 0.2s;
1229
+ }
1230
+
1231
+ .advanced-options-header:hover {
1232
+ background-color: #dee2e6;
1233
+ }
1234
+
1235
+ .dark .advanced-options-header {
1236
+ background-color: #2d3748;
1237
+ color: #e2e8f0;
1238
+ border-left: 4px solid #a0aec0;
1239
+ }
1240
+
1241
+ .dark .advanced-options-header:hover {
1242
+ background-color: #4a5568;
1243
+ }
1244
+
1245
+ /* Vertical arrangement of output section */
1246
+ .output-section {
1247
+ margin-bottom: 30px;
1248
+ border: 1px solid #e9ecef;
1249
+ border-radius: 8px;
1250
+ padding: 20px;
1251
+ }
1252
+
1253
+ .output-section-title {
1254
+ font-weight: bold;
1255
+ color: #495057;
1256
+ margin-bottom: 15px;
1257
+ font-size: 1.2em;
1258
+ }
1259
+
1260
+ .dark .output-section-title {
1261
+ color: #e2e8f0;
1262
+ }
1263
+
1264
+ .pointcloud-controls {
1265
+ display: flex;
1266
+ gap: 10px;
1267
+ margin-bottom: 10px;
1268
+ align-items: center;
1269
+ }
1270
+
1271
+ .note-box {
1272
+ background-color: #fff8e1 !important;
1273
+ border-left: 4px solid #ffc107 !important;
1274
+ padding: 12px !important;
1275
+ margin: 15px 0 !important;
1276
+ border-radius: 4px !important;
1277
+ color: #333 !important;
1278
+ }
1279
+
1280
+ .dark .note-box {
1281
+ background-color: rgba(255, 193, 7, 0.1) !important;
1282
+ color: #e0e0e0 !important;
1283
+ }
1284
+
1285
+ .note-box p, .note-box strong {
1286
+ color: inherit !important;
1287
+ }
1288
+
1289
+ /* Ensure warning class styles are correctly applied */
1290
+ .warning {
1291
+ color: #856404 !important;
1292
+ font-weight: bold !important;
1293
+ padding: 10px !important;
1294
+ background-color: #fff3cd !important;
1295
+ border-left: 4px solid #ffc107 !important;
1296
+ border-radius: 4px !important;
1297
+ margin: 10px 0 !important;
1298
+ }
1299
+
1300
+ .dark .warning {
1301
+ background-color: rgba(255, 193, 7, 0.1) !important;
1302
+ color: #fbd38d !important;
1303
+ }
1304
+
1305
+ .warning-box {
1306
+ background-color: #fff3cd;
1307
+ border-left: 4px solid #ffc107;
1308
+ padding: 12px;
1309
+ margin: 15px 0;
1310
+ border-radius: 4px;
1311
+ color: #856404;
1312
+ }
1313
+
1314
+ .dark .warning-box {
1315
+ background-color: rgba(255, 193, 7, 0.1);
1316
+ color: #fbd38d;
1317
+ }
1318
  """,
1319
  ) as demo:
1320
+ with gr.Column(elem_classes=["container"]):
1321
+ with gr.Row(elem_classes=["header"]):
1322
+ with gr.Column():
1323
+ gr.Markdown(
1324
+ """
1325
+ # Aether: Geometric-Aware Unified World Modeling
1326
+ """,
1327
+ elem_classes=["main-title"]
1328
  )
1329
+
1330
+ gr.Markdown(
1331
+ """
1332
+ <div class="links-bar">
1333
+ 🌐<a href="https://aether-world.github.io/" class="link-button" target="_blank"> Project Page</a>
1334
+ 📄<a href="https://arxiv.org/abs/2503.18945" class="link-button" target="_blank"> Paper</a>
1335
+ 💻<a href="https://github.com/OpenRobotLab/Aether" class="link-button" target="_blank"> Code</a>
1336
+ 🤗<a href="https://huggingface.co/AetherWorldModel/AetherV1" class="link-button" target="_blank"> Model</a>
1337
+ </div>
1338
+ """,
1339
  )
1340
 
1341
+ with gr.Row(elem_classes=["features-limitations-container"]):
1342
+ with gr.Column(elem_classes=["capabilities-box"]):
1343
+ gr.Markdown(
1344
+ """
1345
+ ### 🚀 Key Capabilities
1346
+
1347
+ Aether addresses a fundamental challenge in AI: integrating geometric reconstruction with generative modeling for human-like spatial reasoning. Our framework unifies three core capabilities:
1348
+
1349
+ - 🌏 **4D Dynamic Reconstruction**: Reconstruct dynamic point clouds from videos by estimating depths and camera poses.
1350
+
1351
+ - 🎬 **Action-Conditioned Prediction**: Predict future frames based on initial observations, with optional camera trajectory actions.
1352
+
1353
+ - 🎯 **Goal-Conditioned Planning**: Generate planning paths from pairs of observation and goal images.
1354
+
1355
+ > *Trained entirely on synthetic data, Aether achieves strong zero-shot generalization to real-world scenarios.*
1356
+ """,
1357
+ elem_classes=["capabilities-text"]
1358
+ )
1359
+
1360
+ with gr.Column(elem_classes=["limitations-box"]):
1361
+ gr.Markdown(
1362
+ """
1363
+ ### 📝 Current Limitations
1364
+
1365
+ Aether represents an initial step in our journey, trained entirely on synthetic data. While it demonstrates promising capabilities, it is important to be aware of its current limitations:
1366
+
1367
+ - 🔄 **Dynamic Scenarios**: Struggles with highly dynamic scenarios involving significant motion or dense crowds.
1368
+
1369
+ - 📸 **Camera Stability**: Camera pose estimation can be less stable in certain conditions.
1370
+
1371
+ - 📐 **Planning Range**: For visual planning tasks, we recommend keeping the observations and goals relatively close to ensure optimal performance.
1372
+
1373
+ > *We are actively working on the next generation of Aether and are committed to addressing these limitations in future releases.*
1374
+ """,
1375
+ elem_classes=["limitations-text"]
1376
  )
1377
 
1378
+ with gr.Row(elem_classes=["main-interface"]):
1379
+ with gr.Column(elem_classes=["input-column"]):
1380
+ gpu_time_warning = gr.Markdown(
1381
+ """
1382
+ <div class="warning-box">
1383
+ <strong>⚠️ Warning:</strong><br>
1384
+ Due to Hugging Face Spaces ZeroGPU quota limitations, only short video reconstruction tasks (fewer than 100 frames) can be completed online.
1385
+
1386
+ <strong>💻 Recommendation:</strong><br>
1387
+ We strongly encourage you to deploy Aether locally for:
1388
+ - Processing longer video reconstruction tasks
1389
+ - Better performance and full access to prediction and planning tasks
1390
+
1391
+ Visit our <a href="https://github.com/OpenRobotLab/Aether" target="_blank">GitHub repository</a> for local deployment instructions.
1392
+ </div>
1393
+ """,
1394
+ )
1395
+ with gr.Group(elem_classes=["task-selector"]):
1396
+ task = gr.Radio(
1397
+ ["reconstruction", "prediction", "planning"],
1398
+ label="Select Task",
1399
+ value="reconstruction",
1400
+ info="Choose the task you want to perform",
1401
  )
1402
 
1403
+ with gr.Group(elem_classes=["input-section"]):
1404
+ gr.Markdown("## 📥 Input", elem_classes=["task-header"])
1405
 
1406
+ # Task-specific inputs
1407
+ with gr.Group(visible=True) as reconstruction_group:
1408
+ video_input = gr.Video(
1409
+ label="Upload Input Video",
1410
+ sources=["upload"],
1411
+ interactive=True,
1412
+ elem_id="video_input",
1413
  )
1414
+ reconstruction_examples = gr.Examples(
1415
+ examples=[
1416
+ ["assets/example_videos/bridge.mp4"],
1417
+ ["assets/example_videos/moviegen.mp4"],
1418
+ ["assets/example_videos/nuscenes.mp4"],
1419
+ ["assets/example_videos/veo2.mp4"],
1420
+ ],
1421
+ inputs=[video_input],
1422
+ label="Reconstruction Examples",
1423
+ examples_per_page=4,
1424
  )
1425
 
1426
+ with gr.Group(visible=False) as prediction_group:
1427
+ image_input = gr.Image(
1428
+ label="Upload Start Image",
1429
+ type="filepath",
1430
+ interactive=True,
1431
+ elem_id="image_input",
 
1432
  )
1433
+ prediction_examples = gr.Examples(
1434
+ examples=[
1435
+ ["assets/example_obs/car.png"],
1436
+ ["assets/example_obs/cartoon.png"],
1437
+ ["assets/example_obs/garden.jpg"],
1438
+ ["assets/example_obs/room.jpg"],
1439
+ ],
1440
+ inputs=[image_input],
1441
+ label="Prediction Examples",
1442
+ examples_per_page=4,
1443
  )
1444
 
1445
+ with gr.Group(visible=False) as planning_group:
1446
+ with gr.Row():
1447
+ image_input_planning = gr.Image(
1448
+ label="Upload Start Image",
1449
+ type="filepath",
1450
+ interactive=True,
1451
+ elem_id="image_input_planning",
1452
+ )
1453
+ goal_input = gr.Image(
1454
+ label="Upload Goal Image",
1455
+ type="filepath",
1456
+ interactive=True,
1457
+ elem_id="goal_input",
1458
+ )
1459
+ planning_examples = gr.Examples(
1460
+ examples=[
1461
+ ["assets/example_obs_goal/01_obs.png", "assets/example_obs_goal/01_goal.png"],
1462
+ ["assets/example_obs_goal/02_obs.png", "assets/example_obs_goal/02_goal.png"],
1463
+ ["assets/example_obs_goal/03_obs.png", "assets/example_obs_goal/03_goal.png"],
1464
+ ["assets/example_obs_goal/04_obs.png", "assets/example_obs_goal/04_goal.png"],
1465
+ ],
1466
+ inputs=[image_input_planning, goal_input],
1467
+ label="Planning Examples",
1468
+ examples_per_page=4,
1469
  )
1470
 
1471
+ with gr.Row(visible=False) as preview_row:
1472
+ image_preview = gr.Image(
1473
+ label="Start Image Preview",
1474
+ elem_id="image_preview",
1475
+ visible=False,
1476
+ )
1477
+ goal_preview = gr.Image(
1478
+ label="Goal Image Preview",
1479
+ elem_id="goal_preview",
1480
+ visible=False,
1481
+ )
1482
 
1483
+ with gr.Group(elem_classes=["params-section", "compact-params"]):
1484
+ gr.Markdown("## ⚙️ Parameters", elem_classes=["task-header"])
1486
  with gr.Row():
1487
  with gr.Column(scale=1):
1488
+ height = gr.Dropdown(
1489
+ choices=[480],
1490
+ value=480,
1491
+ label="Height",
1492
+ info="Height of the output video",
 
 
1493
  )
1494
 
 
1495
  with gr.Column(scale=1):
1496
+ width = gr.Dropdown(
1497
+ choices=[720],
1498
+ value=720,
1499
+ label="Width",
1500
+ info="Width of the output video",
 
 
1501
  )
1502
 
1503
  with gr.Row():
1504
  with gr.Column(scale=1):
1505
+ num_frames = gr.Dropdown(
1506
+ choices=[17, 25, 33, 41],
1507
+ value=41,
1508
+ label="Number of Frames",
1509
+ info="Number of frames to predict",
1510
  )
1511
 
1512
  with gr.Column(scale=1):
1513
+ fps = gr.Dropdown(
1514
+ choices=[8, 10, 12, 15, 24],
1515
+ value=12,
1516
+ label="FPS",
1517
+ info="Frames per second",
1518
  )
1519
 
1520
  with gr.Row():
1521
+ num_inference_steps = gr.Slider(
1522
+ minimum=1,
1523
+ maximum=60,
1524
+ value=4,
1525
+ step=1,
1526
+ label="Inference Steps",
1527
+ info="Number of inference step",
1528
+ )
 
1529
 
1530
+ sliding_window_stride = gr.Slider(
1531
  minimum=1,
1532
+ maximum=40,
1533
+ value=24,
1534
  step=1,
1535
+ label="Sliding Window Stride",
1536
+ info="Sliding window stride (window size equals num_frames). Only used for the 'reconstruction' task",
1537
+ visible=True,
1538
  )
1539
 
1540
+ use_dynamic_cfg = gr.Checkbox(
1541
+ label="Use Dynamic CFG",
1542
+ value=True,
1543
+ info="Use dynamic CFG",
1544
  visible=False,
 
1545
  )
1546
 
1547
+ raymap_option = gr.Radio(
1548
+ choices=["backward", "forward_right", "left_forward", "right"],
1549
+ label="Camera Movement Direction",
1550
+ value="forward_right",
1551
+ info="Direction of camera action. We offer 4 pre-defined actions for you to choose from.",
1552
+ visible=False,
1553
+ )
1554
 
1555
+ post_reconstruction = gr.Checkbox(
1556
+ label="Post-Reconstruction",
1557
+ value=True,
1558
+ info="Run reconstruction after prediction for better quality",
1559
+ visible=False,
1560
+ )
1561
 
1562
+ with gr.Accordion(
1563
+ "Advanced Options", open=False, visible=True, elem_classes=["advanced-options-header"]
1564
+ ) as advanced_options:
1565
+ with gr.Group(elem_classes=["advanced-section"]):
1566
+ with gr.Row():
1567
+ guidance_scale = gr.Slider(
1568
+ minimum=1.0,
1569
+ maximum=10.0,
1570
+ value=1.0,
1571
+ step=0.1,
1572
+ label="Guidance Scale",
1573
+ info="Guidance scale (only for prediction / planning)",
1574
+ )
1575
+
1576
+ with gr.Row():
1577
+ seed = gr.Number(
1578
+ value=42,
1579
+ label="Random Seed",
1580
+ info="Set a seed for reproducible results",
1581
+ precision=0,
1582
+ minimum=0,
1583
+ maximum=2147483647,
1584
+ )
1585
+
1586
+ with gr.Row():
1587
+ with gr.Column(scale=1):
1588
+ smooth_camera = gr.Checkbox(
1589
+ label="Smooth Camera",
1590
+ value=True,
1591
+ info="Apply smoothing to camera trajectory",
1592
+ )
1593
+
1594
+ with gr.Column(scale=1):
1595
+ align_pointmaps = gr.Checkbox(
1596
+ label="Align Point Maps",
1597
+ value=False,
1598
+ info="Align point maps across frames",
1599
+ )
1600
+
1601
+ with gr.Row():
1602
+ with gr.Column(scale=1):
1603
+ max_depth = gr.Slider(
1604
+ minimum=10,
1605
+ maximum=200,
1606
+ value=60,
1607
+ step=10,
1608
+ label="Max Depth",
1609
+ info="Maximum depth for point cloud (higher = more distant points)",
1610
+ )
1611
+
1612
+ with gr.Column(scale=1):
1613
+ rtol = gr.Slider(
1614
+ minimum=0.01,
1615
+ maximum=2.0,
1616
+ value=0.2,
1617
+ step=0.01,
1618
+ label="Relative Tolerance",
1619
+ info="Used for depth edge detection. Lower = remove more edges",
1620
+ )
1621
+
1622
+ pointcloud_save_frame_interval = gr.Slider(
1623
+ minimum=1,
1624
+ maximum=20,
1625
+ value=10,
1626
+ step=1,
1627
+ label="Point Cloud Frame Interval",
1628
+ info="Save point cloud every N frames (higher = fewer files but less complete representation)",
1629
+ )
1630
 
1631
+ with gr.Group(elem_classes=["run-button-container"]):
1632
+ run_button = gr.Button("Run Aether", variant="primary", elem_classes=["run-button"])
1633
+
1634
+ with gr.Column(elem_classes=["output-column"]):
1635
+ with gr.Group(elem_classes=["output-panel"]):
1636
+ gr.Markdown("## 📤 Output", elem_classes=["task-header"])
1637
+
1638
+ with gr.Group(elem_classes=["output-section"]):
1639
+ gr.Markdown("### RGB Video", elem_classes=["output-section-title"])
1640
+ rgb_output = gr.Video(
1641
+ label="RGB Output", interactive=False, elem_id="rgb_output"
1642
+ )
1643
+
1644
+ with gr.Group(elem_classes=["output-section"]):
1645
+ gr.Markdown("### Depth Video", elem_classes=["output-section-title"])
1646
+ depth_output = gr.Video(
1647
+ label="Depth Output", interactive=False, elem_id="depth_output"
1648
+ )
1649
+
1650
+ with gr.Group(elem_classes=["output-section"]):
1651
+ gr.Markdown("### Point Clouds", elem_classes=["output-section-title"])
1652
+ with gr.Row(elem_classes=["pointcloud-controls"]):
1653
+ pointcloud_frames = gr.Dropdown(
1654
+ label="Select Frame",
1655
+ choices=[],
1656
+ value=None,
1657
+ interactive=True,
1658
+ elem_id="pointcloud_frames",
1659
+ )
1660
+ pointcloud_download = gr.DownloadButton(
1661
+ label="Download Point Cloud",
1662
+ visible=False,
1663
+ elem_id="pointcloud_download",
1664
+ )
1665
+
1666
+ model_output = gr.Model3D(
1667
+ label="Point Cloud Viewer", interactive=True, elem_id="model_output"
1668
+ )
1669
+
1670
+ gr.Markdown(
1671
+ """
1672
+ > **Note:** 3D point clouds take a long time to visualize, and we show the keyframes only.
1673
+ > You can control the keyframe interval by modifying the `pointcloud_save_frame_interval`.
1674
+ """
1675
+ )
1676
+
1677
+ with gr.Group(elem_classes=["output-section"]):
1678
+ gr.Markdown("### About Results", elem_classes=["output-section-title"])
1679
+ gr.Markdown(
1680
+ """
1681
+ #### Understanding the Outputs
1682
+
1683
+ - **RGB Video**: Shows the predicted or reconstructed RGB frames
1684
+ - **Depth Video**: Visualizes the disparity maps in color (closer = red, further = blue)
1685
+ - **Point Clouds**: Interactive 3D point cloud with camera positions shown as colored pyramids
1686
+ """
1687
+ )
1688
 
1689
  # Event handlers
1690
  task.change(
1691
  fn=update_task_ui,
1692
  inputs=[task],
1693
  outputs=[
1694
+ reconstruction_group,
1695
+ prediction_group,
1696
+ planning_group,
1697
+ preview_row,
 
1698
  num_inference_steps,
1699
  sliding_window_stride,
1700
  use_dynamic_cfg,
 
1705
  )
1706
 
1707
  image_input.change(
1708
+ fn=update_image_preview,
1709
+ inputs=[image_input],
1710
+ outputs=[image_preview]
1711
  ).then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[preview_row])
1712
 
1713
  goal_input.change(
1714
+ fn=update_goal_preview,
1715
+ inputs=[goal_input],
1716
+ outputs=[goal_preview]
1717
  ).then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[preview_row])
1718
 
1719
  def update_pointcloud_frames(pointcloud_paths):
 
1935
  outputs=[pointcloud_download],
1936
  )
1937
 
 
 
 
 
 
 
 
 
 
1938
  # Load the model at startup
1939
+ demo.load(lambda: build_pipeline(torch.device("cpu")), inputs=None, outputs=None)
1940
 
1941
  if __name__ == "__main__":
1942
  os.environ["TOKENIZERS_PARALLELISM"] = "false"