Wenzheng Chang committed
Commit da3b980 · 1 Parent(s): 9562db5

add app.py

Files changed (1)
  1. app.py +1470 -0
app.py ADDED
@@ -0,0 +1,1470 @@
import gc
import os
import random
import re
from datetime import datetime
from typing import Dict, List, Optional, Tuple

import gradio as gr
import imageio.v3 as iio
import numpy as np
import PIL
import rootutils
import torch
from diffusers import (
    AutoencoderKLCogVideoX,
    CogVideoXDPMScheduler,
    CogVideoXTransformer3DModel,
)
from transformers import AutoTokenizer, T5EncoderModel


rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

from aether.pipelines.aetherv1_pipeline_cogvideox import (  # noqa: E402
    AetherV1PipelineCogVideoX,
    AetherV1PipelineOutput,
)
from aether.utils.postprocess_utils import (  # noqa: E402
    align_camera_extrinsics,
    apply_transformation,
    colorize_depth,
    compute_scale,
    get_intrinsics,
    interpolate_poses,
    postprocess_pointmap,
    project,
    raymap_to_poses,
)
from aether.utils.visualize_utils import predictions_to_glb  # noqa: E402


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def seed_all(seed: int = 0) -> None:
    """
    Set random seeds of all components.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

# Global pipeline
cogvideox_pretrained_model_name_or_path: str = "THUDM/CogVideoX-5b-I2V"
aether_pretrained_model_name_or_path: str = "AetherWorldModel/AetherV1"
pipeline = AetherV1PipelineCogVideoX(
    tokenizer=AutoTokenizer.from_pretrained(
        cogvideox_pretrained_model_name_or_path,
        subfolder="tokenizer",
    ),
    text_encoder=T5EncoderModel.from_pretrained(
        cogvideox_pretrained_model_name_or_path, subfolder="text_encoder"
    ),
    vae=AutoencoderKLCogVideoX.from_pretrained(
        cogvideox_pretrained_model_name_or_path, subfolder="vae"
    ),
    scheduler=CogVideoXDPMScheduler.from_pretrained(
        cogvideox_pretrained_model_name_or_path, subfolder="scheduler"
    ),
    transformer=CogVideoXTransformer3DModel.from_pretrained(
        aether_pretrained_model_name_or_path, subfolder="transformer"
    ),
)
pipeline.vae.enable_slicing()
pipeline.vae.enable_tiling()
pipeline.to(device)


def build_pipeline() -> AetherV1PipelineCogVideoX:
    """Initialize the model pipeline."""
    return pipeline


def get_window_starts(
    total_frames: int, sliding_window_size: int, temporal_stride: int
) -> List[int]:
    """Calculate window start indices."""
    starts = list(
        range(
            0,
            total_frames - sliding_window_size + 1,
            temporal_stride,
        )
    )
    if (
        total_frames > sliding_window_size
        and (total_frames - sliding_window_size) % temporal_stride != 0
    ):
        starts.append(total_frames - sliding_window_size)
    return starts

def blend_and_merge_window_results(
    window_results: List[AetherV1PipelineOutput], window_indices: List[int], args: Dict
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Blend and merge window results."""
    merged_rgb = None
    merged_disparity = None
    merged_poses = None
    merged_focals = None
    align_pointmaps = args.get("align_pointmaps", True)
    smooth_camera = args.get("smooth_camera", True)
    smooth_method = args.get("smooth_method", "kalman") if smooth_camera else "none"

    if align_pointmaps:
        merged_pointmaps = None

    w1 = window_results[0].disparity

    for idx, (window_result, t_start) in enumerate(zip(window_results, window_indices)):
        t_end = t_start + window_result.rgb.shape[0]
        if idx == 0:
            merged_rgb = window_result.rgb
            merged_disparity = window_result.disparity
            pointmap_dict = postprocess_pointmap(
                window_result.disparity,
                window_result.raymap,
                vae_downsample_scale=8,
                ray_o_scale_inv=0.1,
                smooth_camera=smooth_camera,
                smooth_method=smooth_method if smooth_camera else "none",
            )
            merged_poses = pointmap_dict["camera_pose"]
            merged_focals = (
                pointmap_dict["intrinsics"][:, 0, 0]
                + pointmap_dict["intrinsics"][:, 1, 1]
            ) / 2
            if align_pointmaps:
                merged_pointmaps = pointmap_dict["pointmap"]
        else:
            overlap_t = window_indices[idx - 1] + window_result.rgb.shape[0] - t_start

            window_disparity = window_result.disparity

            # Align disparity
            disp_mask = window_disparity[:overlap_t].reshape(1, -1, w1.shape[-1]) > 0.1
            scale = compute_scale(
                window_disparity[:overlap_t].reshape(1, -1, w1.shape[-1]),
                merged_disparity[-overlap_t:].reshape(1, -1, w1.shape[-1]),
                disp_mask.reshape(1, -1, w1.shape[-1]),
            )
            window_disparity = scale * window_disparity

            # Blend disparity
            result_disparity = np.ones((t_end, *w1.shape[1:]))
            result_disparity[:t_start] = merged_disparity[:t_start]
            result_disparity[t_start + overlap_t :] = window_disparity[overlap_t:]
            weight = np.linspace(1, 0, overlap_t)[:, None, None]
            result_disparity[t_start : t_start + overlap_t] = merged_disparity[
                t_start : t_start + overlap_t
            ] * weight + window_disparity[:overlap_t] * (1 - weight)
            merged_disparity = result_disparity

            # Blend RGB
            result_rgb = np.ones((t_end, *w1.shape[1:], 3))
            result_rgb[:t_start] = merged_rgb[:t_start]
            result_rgb[t_start + overlap_t :] = window_result.rgb[overlap_t:]
            weight_rgb = np.linspace(1, 0, overlap_t)[:, None, None, None]
            result_rgb[t_start : t_start + overlap_t] = merged_rgb[
                t_start : t_start + overlap_t
            ] * weight_rgb + window_result.rgb[:overlap_t] * (1 - weight_rgb)
            merged_rgb = result_rgb

            # Align poses
            window_raymap = window_result.raymap
            window_poses, window_Fov_x, window_Fov_y = raymap_to_poses(
                window_raymap, ray_o_scale_inv=0.1
            )
            rel_r, rel_t, rel_s = align_camera_extrinsics(
                torch.from_numpy(window_poses[:overlap_t]),
                torch.from_numpy(merged_poses[-overlap_t:]),
            )
            aligned_window_poses = (
                apply_transformation(
                    torch.from_numpy(window_poses),
                    rel_r,
                    rel_t,
                    rel_s,
                    return_extri=True,
                )
                .cpu()
                .numpy()
            )

            result_poses = np.ones((t_end, 4, 4))
            result_poses[:t_start] = merged_poses[:t_start]
            result_poses[t_start + overlap_t :] = aligned_window_poses[overlap_t:]

            # Interpolate poses in overlap region
            weights = np.linspace(1, 0, overlap_t)
            for t in range(overlap_t):
                weight = weights[t]
                pose1 = merged_poses[t_start + t]
                pose2 = aligned_window_poses[t]
                result_poses[t_start + t] = interpolate_poses(pose1, pose2, weight)

            merged_poses = result_poses

            # Align intrinsics
            window_intrinsics, _ = get_intrinsics(
                batch_size=window_poses.shape[0],
                h=window_result.disparity.shape[1],
                w=window_result.disparity.shape[2],
                fovx=window_Fov_x,
                fovy=window_Fov_y,
            )
            window_focals = (
                window_intrinsics[:, 0, 0] + window_intrinsics[:, 1, 1]
            ) / 2
            scale = (merged_focals[-overlap_t:] / window_focals[:overlap_t]).mean()
            window_focals = scale * window_focals
            result_focals = np.ones((t_end,))
            result_focals[:t_start] = merged_focals[:t_start]
            result_focals[t_start + overlap_t :] = window_focals[overlap_t:]
            weight = np.linspace(1, 0, overlap_t)
            result_focals[t_start : t_start + overlap_t] = merged_focals[
                t_start : t_start + overlap_t
            ] * weight + window_focals[:overlap_t] * (1 - weight)
            merged_focals = result_focals

            if align_pointmaps:
                # Align pointmaps
                window_pointmaps = postprocess_pointmap(
                    result_disparity[t_start:],
                    window_raymap,
                    vae_downsample_scale=8,
                    camera_pose=aligned_window_poses,
                    focal=window_focals,
                    ray_o_scale_inv=0.1,
                    smooth_camera=smooth_camera,
                    smooth_method=smooth_method if smooth_camera else "none",
                )
                result_pointmaps = np.ones((t_end, *w1.shape[1:], 3))
                result_pointmaps[:t_start] = merged_pointmaps[:t_start]
                result_pointmaps[t_start + overlap_t :] = window_pointmaps["pointmap"][
                    overlap_t:
                ]
                weight = np.linspace(1, 0, overlap_t)[:, None, None, None]
                result_pointmaps[t_start : t_start + overlap_t] = merged_pointmaps[
                    t_start : t_start + overlap_t
                ] * weight + window_pointmaps["pointmap"][:overlap_t] * (1 - weight)
                merged_pointmaps = result_pointmaps

    # project to pointmaps
    height = args.get("height", 480)
    width = args.get("width", 720)

    intrinsics = [
        np.array([[f, 0, 0.5 * width], [0, f, 0.5 * height], [0, 0, 1]])
        for f in merged_focals
    ]
    if align_pointmaps:
        pointmaps = merged_pointmaps
    else:
        pointmaps = np.stack(
            [
                project(
                    1 / np.clip(merged_disparity[i], 1e-8, 1e8),
                    intrinsics[i],
                    merged_poses[i],
                )
                for i in range(merged_poses.shape[0])
            ]
        )

    return merged_rgb, merged_disparity, merged_poses, pointmaps

def process_video_to_frames(
    video_path: str, fps_sample: int = 12
) -> Tuple[List[str], str]:
    """Process video into frames and save them locally."""
    # Create a unique output directory
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = f"temp_frames_{timestamp}"
    os.makedirs(output_dir, exist_ok=True)

    # Read video
    video = iio.imread(video_path)

    # Calculate frame interval based on original video fps
    if isinstance(video, np.ndarray):
        # For captured videos
        total_frames = len(video)
        frame_interval = max(
            1, round(total_frames / (fps_sample * (total_frames / 30)))
        )
    else:
        # Default if can't determine
        frame_interval = 2

    frame_paths = []
    for i, frame in enumerate(video[::frame_interval]):
        frame_path = os.path.join(output_dir, f"frame_{i:04d}.jpg")
        if isinstance(frame, np.ndarray):
            iio.imwrite(frame_path, frame)
            frame_paths.append(frame_path)

    return frame_paths, output_dir


def save_output_files(
    rgb: np.ndarray,
    disparity: np.ndarray,
    poses: Optional[np.ndarray] = None,
    raymap: Optional[np.ndarray] = None,
    pointmap: Optional[np.ndarray] = None,
    task: str = "reconstruction",
    output_dir: str = "outputs",
    **kwargs,
) -> Dict[str, str]:
    """
    Save outputs and return paths to saved files.
    """
    os.makedirs(output_dir, exist_ok=True)

    if pointmap is None and raymap is not None:
        # Generate pointmap from raymap and disparity
        smooth_camera = kwargs.get("smooth_camera", True)
        smooth_method = (
            kwargs.get("smooth_method", "kalman") if smooth_camera else "none"
        )

        pointmap_dict = postprocess_pointmap(
            disparity,
            raymap,
            vae_downsample_scale=8,
            ray_o_scale_inv=0.1,
            smooth_camera=smooth_camera,
            smooth_method=smooth_method,
        )
        pointmap = pointmap_dict["pointmap"]

    if poses is None and raymap is not None:
        poses, _, _ = raymap_to_poses(raymap, ray_o_scale_inv=0.1)

    # Create a unique filename
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    base_filename = f"{task}_{timestamp}"

    # Paths for saved files
    paths = {}

    # Save RGB video
    rgb_path = os.path.join(output_dir, f"{base_filename}_rgb.mp4")
    iio.imwrite(
        rgb_path,
        (np.clip(rgb, 0, 1) * 255).astype(np.uint8),
        fps=kwargs.get("fps", 12),
    )
    paths["rgb"] = rgb_path

    # Save depth/disparity video
    depth_path = os.path.join(output_dir, f"{base_filename}_disparity.mp4")
    iio.imwrite(
        depth_path,
        (colorize_depth(disparity) * 255).astype(np.uint8),
        fps=kwargs.get("fps", 12),
    )
    paths["disparity"] = depth_path

    # Save point cloud GLB files
    if pointmap is not None and poses is not None:
        pointcloud_save_frame_interval = kwargs.get(
            "pointcloud_save_frame_interval", 10
        )
        max_depth = kwargs.get("max_depth", 100.0)
        rtol = kwargs.get("rtol", 0.03)

        glb_paths = []
        # Determine which frames to save based on the interval
        frames_to_save = list(
            range(0, pointmap.shape[0], pointcloud_save_frame_interval)
        )

        # Always include the first and last frame
        if 0 not in frames_to_save:
            frames_to_save.insert(0, 0)
        if pointmap.shape[0] - 1 not in frames_to_save:
            frames_to_save.append(pointmap.shape[0] - 1)

        # Sort the frames to ensure they're in order
        frames_to_save = sorted(set(frames_to_save))

        for frame_idx in frames_to_save:
            if frame_idx >= pointmap.shape[0]:
                continue

            predictions = {
                "world_points": pointmap[frame_idx : frame_idx + 1],
                "images": rgb[frame_idx : frame_idx + 1],
                "depths": 1 / np.clip(disparity[frame_idx : frame_idx + 1], 1e-8, 1e8),
                "camera_poses": poses[frame_idx : frame_idx + 1],
            }

            glb_path = os.path.join(
                output_dir, f"{base_filename}_pointcloud_frame_{frame_idx}.glb"
            )

            scene_3d = predictions_to_glb(
                predictions,
                filter_by_frames="all",
                show_cam=True,
                max_depth=max_depth,
                rtol=rtol,
                frame_rel_idx=float(frame_idx) / pointmap.shape[0],
            )
            scene_3d.export(glb_path)
            glb_paths.append(glb_path)

        paths["pointcloud_glbs"] = glb_paths

    return paths

def process_reconstruction(
    video_file,
    height,
    width,
    num_frames,
    num_inference_steps,
    guidance_scale,
    sliding_window_stride,
    fps,
    smooth_camera,
    align_pointmaps,
    max_depth,
    rtol,
    pointcloud_save_frame_interval,
    seed,
    progress=gr.Progress(),
):
    """
    Process reconstruction task.
    """
    try:
        gc.collect()
        torch.cuda.empty_cache()

        # Set random seed
        seed_all(seed)

        # Build the pipeline
        pipeline = build_pipeline()

        progress(0.1, "Loading video")
        # Check if video_file is a string or a file object
        if isinstance(video_file, str):
            video_path = video_file
        else:
            video_path = video_file.name

        video = iio.imread(video_path).astype(np.float32) / 255.0

        # Setup arguments
        args = {
            "height": height,
            "width": width,
            "num_frames": num_frames,
            "sliding_window_stride": sliding_window_stride,
            "smooth_camera": smooth_camera,
            "smooth_method": "kalman" if smooth_camera else "none",
            "align_pointmaps": align_pointmaps,
            "max_depth": max_depth,
            "rtol": rtol,
            "pointcloud_save_frame_interval": pointcloud_save_frame_interval,
        }

        # Process in sliding windows
        window_results = []
        window_indices = get_window_starts(
            len(video), num_frames, sliding_window_stride
        )

        progress(0.2, f"Processing video in {len(window_indices)} windows")

        for i, start_idx in enumerate(window_indices):
            progress_val = 0.2 + (0.6 * (i / len(window_indices)))
            progress(progress_val, f"Processing window {i+1}/{len(window_indices)}")

            output = pipeline(
                task="reconstruction",
                image=None,
                goal=None,
                video=video[start_idx : start_idx + num_frames],
                raymap=None,
                height=height,
                width=width,
                num_frames=num_frames,
                fps=fps,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                use_dynamic_cfg=False,
                generator=torch.Generator(device=device).manual_seed(seed),
            )
            window_results.append(output)

        progress(0.8, "Merging results from all windows")
        # Merge window results
        (
            merged_rgb,
            merged_disparity,
            merged_poses,
            pointmaps,
        ) = blend_and_merge_window_results(window_results, window_indices, args)

        progress(0.9, "Saving output files")
        # Save output files
        output_dir = "outputs"
        os.makedirs(output_dir, exist_ok=True)
        output_paths = save_output_files(
            rgb=merged_rgb,
            disparity=merged_disparity,
            poses=merged_poses,
            pointmap=pointmaps,
            task="reconstruction",
            output_dir=output_dir,
            fps=12,
            **args,
        )

        progress(1.0, "Done!")

        # Return paths for displaying
        return (
            output_paths["rgb"],
            output_paths["disparity"],
            output_paths.get("pointcloud_glbs", []),
        )

    except Exception:
        import traceback

        traceback.print_exc()
        return None, None, []

def process_prediction(
    image_file,
    height,
    width,
    num_frames,
    num_inference_steps,
    guidance_scale,
    use_dynamic_cfg,
    raymap_option,
    post_reconstruction,
    fps,
    smooth_camera,
    align_pointmaps,
    max_depth,
    rtol,
    pointcloud_save_frame_interval,
    seed,
    progress=gr.Progress(),
):
    """
    Process prediction task.
    """
    try:
        gc.collect()
        torch.cuda.empty_cache()

        # Set random seed
        seed_all(seed)

        # Build the pipeline
        pipeline = build_pipeline()

        progress(0.1, "Loading image")
        # Check if image_file is a string or a file object
        if isinstance(image_file, str):
            image_path = image_file
        else:
            image_path = image_file.name

        image = PIL.Image.open(image_path)

        progress(0.2, "Running prediction")
        # Run prediction
        output = pipeline(
            task="prediction",
            image=image,
            video=None,
            goal=None,
            raymap=np.load(f"assets/example_raymaps/raymap_{raymap_option}.npy"),
            height=height,
            width=width,
            num_frames=num_frames,
            fps=fps,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            use_dynamic_cfg=use_dynamic_cfg,
            generator=torch.Generator(device=device).manual_seed(seed),
            return_dict=True,
        )

        # Show RGB output immediately
        rgb_output = output.rgb

        # Setup arguments for saving
        args = {
            "height": height,
            "width": width,
            "smooth_camera": smooth_camera,
            "smooth_method": "kalman" if smooth_camera else "none",
            "align_pointmaps": align_pointmaps,
            "max_depth": max_depth,
            "rtol": rtol,
            "pointcloud_save_frame_interval": pointcloud_save_frame_interval,
        }

        if post_reconstruction:
            progress(0.5, "Running post-reconstruction for better quality")
            recon_output = pipeline(
                task="reconstruction",
                video=output.rgb,
                height=height,
                width=width,
                num_frames=num_frames,
                fps=fps,
                num_inference_steps=4,
                guidance_scale=1.0,
                use_dynamic_cfg=False,
                generator=torch.Generator(device=device).manual_seed(seed),
            )

            disparity = recon_output.disparity
            raymap = recon_output.raymap
        else:
            disparity = output.disparity
            raymap = output.raymap

        progress(0.8, "Saving output files")
        # Save output files
        output_dir = "outputs"
        os.makedirs(output_dir, exist_ok=True)
        output_paths = save_output_files(
            rgb=rgb_output,
            disparity=disparity,
            raymap=raymap,
            task="prediction",
            output_dir=output_dir,
            fps=12,
            **args,
        )

        progress(1.0, "Done!")

        # Return paths for displaying
        return (
            output_paths["rgb"],
            output_paths["disparity"],
            output_paths.get("pointcloud_glbs", []),
        )

    except Exception:
        import traceback

        traceback.print_exc()
        return None, None, []

def process_planning(
    image_file,
    goal_file,
    height,
    width,
    num_frames,
    num_inference_steps,
    guidance_scale,
    use_dynamic_cfg,
    post_reconstruction,
    fps,
    smooth_camera,
    align_pointmaps,
    max_depth,
    rtol,
    pointcloud_save_frame_interval,
    seed,
    progress=gr.Progress(),
):
    """
    Process planning task.
    """
    try:
        gc.collect()
        torch.cuda.empty_cache()

        # Set random seed
        seed_all(seed)

        # Build the pipeline
        pipeline = build_pipeline()

        progress(0.1, "Loading images")
        # Check if image_file and goal_file are strings or file objects
        if isinstance(image_file, str):
            image_path = image_file
        else:
            image_path = image_file.name

        if isinstance(goal_file, str):
            goal_path = goal_file
        else:
            goal_path = goal_file.name

        image = PIL.Image.open(image_path)
        goal = PIL.Image.open(goal_path)

        progress(0.2, "Running planning")
        # Run planning
        output = pipeline(
            task="planning",
            image=image,
            video=None,
            goal=goal,
            raymap=None,
            height=height,
            width=width,
            num_frames=num_frames,
            fps=fps,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            use_dynamic_cfg=use_dynamic_cfg,
            generator=torch.Generator(device=device).manual_seed(seed),
            return_dict=True,
        )

        # Show RGB output immediately
        rgb_output = output.rgb

        # Setup arguments for saving
        args = {
            "height": height,
            "width": width,
            "smooth_camera": smooth_camera,
            "smooth_method": "kalman" if smooth_camera else "none",
            "align_pointmaps": align_pointmaps,
            "max_depth": max_depth,
            "rtol": rtol,
            "pointcloud_save_frame_interval": pointcloud_save_frame_interval,
        }

        if post_reconstruction:
            progress(0.5, "Running post-reconstruction for better quality")
            recon_output = pipeline(
                task="reconstruction",
                video=output.rgb,
                height=height,
                width=width,
                num_frames=num_frames,
                fps=12,
                num_inference_steps=4,
                guidance_scale=1.0,
                use_dynamic_cfg=False,
                generator=torch.Generator(device=device).manual_seed(seed),
            )

            disparity = recon_output.disparity
            raymap = recon_output.raymap
        else:
            disparity = output.disparity
            raymap = output.raymap

        progress(0.8, "Saving output files")
        # Save output files
        output_dir = "outputs"
        os.makedirs(output_dir, exist_ok=True)
        output_paths = save_output_files(
            rgb=rgb_output,
            disparity=disparity,
            raymap=raymap,
            task="planning",
            output_dir=output_dir,
            fps=fps,
            **args,
        )

        progress(1.0, "Done!")

        # Return paths for displaying
        return (
            output_paths["rgb"],
            output_paths["disparity"],
            output_paths.get("pointcloud_glbs", []),
        )

    except Exception:
        import traceback

        traceback.print_exc()
        return None, None, []

def update_task_ui(task):
    """Update UI elements based on selected task."""
    if task == "reconstruction":
        return (
            gr.update(visible=True),  # video_input
            gr.update(visible=False),  # image_input
            gr.update(visible=False),  # goal_input
            gr.update(visible=False),  # image_preview
            gr.update(visible=False),  # goal_preview
            gr.update(value=4),  # num_inference_steps
            gr.update(visible=True),  # sliding_window_stride
            gr.update(visible=False),  # use_dynamic_cfg
            gr.update(visible=False),  # raymap_option
            gr.update(visible=False),  # post_reconstruction
            gr.update(value=1.0),  # guidance_scale
        )
    elif task == "prediction":
        return (
            gr.update(visible=False),  # video_input
            gr.update(visible=True),  # image_input
            gr.update(visible=False),  # goal_input
            gr.update(visible=True),  # image_preview
            gr.update(visible=False),  # goal_preview
            gr.update(value=50),  # num_inference_steps
            gr.update(visible=False),  # sliding_window_stride
            gr.update(visible=True),  # use_dynamic_cfg
            gr.update(visible=True),  # raymap_option
            gr.update(visible=True),  # post_reconstruction
            gr.update(value=3.0),  # guidance_scale
        )
    elif task == "planning":
        return (
            gr.update(visible=False),  # video_input
            gr.update(visible=True),  # image_input
            gr.update(visible=True),  # goal_input
            gr.update(visible=True),  # image_preview
            gr.update(visible=True),  # goal_preview
            gr.update(value=50),  # num_inference_steps
            gr.update(visible=False),  # sliding_window_stride
            gr.update(visible=True),  # use_dynamic_cfg
            gr.update(visible=False),  # raymap_option
            gr.update(visible=True),  # post_reconstruction
            gr.update(value=3.0),  # guidance_scale
        )


def update_image_preview(image_file):
    """Update the image preview."""
    if image_file:
        return image_file.name
    return None


def update_goal_preview(goal_file):
    """Update the goal preview."""
    if goal_file:
        return goal_file.name
    return None


def get_download_link(selected_frame, all_paths):
    """Update the download button with the selected file path."""
    if not selected_frame or not all_paths:
        return gr.update(visible=False, value=None)

    frame_num = int(re.search(r"Frame (\d+)", selected_frame).group(1))

    for path in all_paths:
        if f"frame_{frame_num}" in path:
            # Make sure the file exists before setting it
            if os.path.exists(path):
                return gr.update(visible=True, value=path, interactive=True)

    return gr.update(visible=False, value=None)


# Theme setup
theme = gr.themes.Default(
    primary_hue="blue",
    secondary_hue="cyan",
)

with gr.Blocks(
    theme=theme,
    css="""
    .output-column {
        min-height: 400px;
    }
    .warning {
        color: #ff9800;
        font-weight: bold;
    }
    .highlight {
        background-color: rgba(0, 123, 255, 0.1);
        padding: 10px;
        border-radius: 8px;
        border-left: 5px solid #007bff;
        margin: 10px 0;
    }
    .task-header {
        margin-top: 10px;
        margin-bottom: 15px;
        font-size: 1.2em;
        font-weight: bold;
        color: #007bff;
    }
    .flex-display {
        display: flex;
        flex-wrap: wrap;
        gap: 10px;
    }
    .output-subtitle {
        font-size: 1.1em;
        margin-top: 5px;
        margin-bottom: 5px;
        color: #505050;
    }
    .input-section, .params-section, .advanced-section {
        border: 1px solid #ddd;
        padding: 15px;
        border-radius: 8px;
        margin-bottom: 15px;
    }
    .logo-container {
        display: flex;
        justify-content: center;
        margin-bottom: 20px;
    }
    .logo-image {
        max-width: 300px;
        height: auto;
    }
    """,
) as demo:
    with gr.Row(elem_classes=["logo-container"]):
        gr.Image("assets/logo.png", show_label=False, elem_classes=["logo-image"])

    gr.Markdown(
        """
        # Aether: Geometric-Aware Unified World Modeling

        Aether addresses a fundamental challenge in AI: integrating geometric reconstruction with
        generative modeling for human-like spatial reasoning. Our framework unifies three core capabilities:

        1. **4D Dynamic Reconstruction** - Reconstruct dynamic point clouds from videos by estimating depths and camera poses.
        2. **Action-Conditioned Video Prediction** - Predict future frames from an initial observation image, optionally conditioned on camera trajectory actions.
        3. **Goal-Conditioned Visual Planning** - Generate planning paths from pairs of observation and goal images.

        Trained entirely on synthetic data, Aether achieves strong zero-shot generalization to real-world scenarios.
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            task = gr.Radio(
                ["reconstruction", "prediction", "planning"],
                label="Select Task",
                value="reconstruction",
                info="Choose the task you want to perform",
            )

            with gr.Group(elem_classes=["input-section"]):
                # Input section - changes based on task
                gr.Markdown("## 📥 Input", elem_classes=["task-header"])

                # Task-specific inputs
                video_input = gr.Video(
                    label="Upload Input Video",
                    sources=["upload"],
                    visible=True,
                    interactive=True,
                    elem_id="video_input",
                )

                image_input = gr.File(
                    label="Upload Start Image",
                    file_count="single",
                    file_types=["image"],
                    visible=False,
                    interactive=True,
                    elem_id="image_input",
                )

                goal_input = gr.File(
                    label="Upload Goal Image",
                    file_count="single",
                    file_types=["image"],
                    visible=False,
                    interactive=True,
                    elem_id="goal_input",
                )

                with gr.Row(visible=False) as preview_row:
                    image_preview = gr.Image(
                        label="Start Image Preview",
                        elem_id="image_preview",
                        visible=False,
                    )
                    goal_preview = gr.Image(
                        label="Goal Image Preview",
                        elem_id="goal_preview",
                        visible=False,
                    )

            with gr.Group(elem_classes=["params-section"]):
                gr.Markdown("## ⚙️ Parameters", elem_classes=["task-header"])

                with gr.Row():
                    with gr.Column(scale=1):
                        height = gr.Dropdown(
                            choices=[480],
                            value=480,
                            label="Height",
                            info="Height of the output video",
                        )

                    with gr.Column(scale=1):
                        width = gr.Dropdown(
                            choices=[720],
                            value=720,
                            label="Width",
                            info="Width of the output video",
                        )

                with gr.Row():
                    with gr.Column(scale=1):
                        num_frames = gr.Dropdown(
                            choices=[17, 25, 33, 41],
                            value=41,
                            label="Number of Frames",
                            info="Number of frames to predict",
                        )

                    with gr.Column(scale=1):
                        fps = gr.Dropdown(
                            choices=[8, 10, 12, 15, 24],
                            value=12,
                            label="FPS",
                            info="Frames per second",
                        )

                with gr.Row():
                    with gr.Column(scale=1):
                        num_inference_steps = gr.Slider(
                            minimum=1,
                            maximum=60,
                            value=4,
                            step=1,
                            label="Inference Steps",
                            info="Number of inference steps",
                        )

                sliding_window_stride = gr.Slider(
                    minimum=1,
                    maximum=40,
                    value=24,
                    step=1,
                    label="Sliding Window Stride",
                    info="Sliding window stride (window size equals num_frames). Only used for the 'reconstruction' task",
                    visible=True,
                )

                use_dynamic_cfg = gr.Checkbox(
                    label="Use Dynamic CFG",
                    value=True,
                    info="Use dynamic CFG",
                    visible=False,
                )

                raymap_option = gr.Radio(
                    choices=["backward", "forward_right", "left_forward", "right"],
                    label="Camera Movement Direction",
                    value="forward_right",
                    info="Direction of camera action. We offer 4 pre-defined actions for you to choose from.",
                    visible=False,
                )

                post_reconstruction = gr.Checkbox(
                    label="Post-Reconstruction",
                    value=True,
                    info="Run reconstruction after prediction for better quality",
                    visible=False,
                )

            with gr.Accordion(
                "Advanced Options", open=False, visible=True
            ) as advanced_options:
                with gr.Group(elem_classes=["advanced-section"]):
                    with gr.Row():
                        with gr.Column(scale=1):
                            guidance_scale = gr.Slider(
                                minimum=1.0,
                                maximum=10.0,
                                value=1.0,
                                step=0.1,
                                label="Guidance Scale",
                                info="Guidance scale (only for prediction / planning)",
                            )

                    with gr.Row():
                        with gr.Column(scale=1):
                            seed = gr.Number(
                                value=42,
                                label="Random Seed",
                                info="Set a seed for reproducible results",
                                precision=0,
                                minimum=0,
                                maximum=2147483647,
                            )

                    with gr.Row():
                        with gr.Column(scale=1):
                            smooth_camera = gr.Checkbox(
                                label="Smooth Camera",
                                value=True,
                                info="Apply smoothing to camera trajectory",
                            )

                        with gr.Column(scale=1):
                            align_pointmaps = gr.Checkbox(
                                label="Align Point Maps",
                                value=False,
                                info="Align point maps across frames",
                            )

                    with gr.Row():
                        with gr.Column(scale=1):
                            max_depth = gr.Slider(
                                minimum=10,
                                maximum=200,
                                value=60,
                                step=10,
                                label="Max Depth",
                                info="Maximum depth for point cloud (higher = more distant points)",
                            )

                        with gr.Column(scale=1):
                            rtol = gr.Slider(
                                minimum=0.01,
                                maximum=2.0,
                                value=0.03,
                                step=0.01,
                                label="Relative Tolerance",
                                info="Used for depth edge detection. Lower = remove more edges",
                            )

                    pointcloud_save_frame_interval = gr.Slider(
                        minimum=1,
                        maximum=20,
                        value=10,
                        step=1,
                        label="Point Cloud Frame Interval",
                        info="Save point cloud every N frames (higher = fewer files but less complete representation)",
                    )

            run_button = gr.Button("Run Aether", variant="primary")

        with gr.Column(scale=1, elem_classes=["output-column"]):
            with gr.Group():
                gr.Markdown("## 📤 Output", elem_classes=["task-header"])

                gr.Markdown("### RGB Video", elem_classes=["output-subtitle"])
                rgb_output = gr.Video(
                    label="RGB Output", interactive=False, elem_id="rgb_output"
                )

                gr.Markdown("### Depth Video", elem_classes=["output-subtitle"])
                depth_output = gr.Video(
                    label="Depth Output", interactive=False, elem_id="depth_output"
                )

                gr.Markdown("### Point Clouds", elem_classes=["output-subtitle"])
                with gr.Row(elem_classes=["flex-display"]):
                    pointcloud_frames = gr.Dropdown(
                        label="Select Frame",
                        choices=[],
                        value=None,
                        interactive=True,
                        elem_id="pointcloud_frames",
                    )
                    pointcloud_download = gr.DownloadButton(
                        label="Download Point Cloud",
                        visible=False,
                        elem_id="pointcloud_download",
                    )

                model_output = gr.Model3D(
                    label="Point Cloud Viewer", interactive=True, elem_id="model_output"
                )

                with gr.Tab("About Results"):
                    gr.Markdown(
                        """
                        ### Understanding the Outputs

                        - **RGB Video**: Shows the predicted or reconstructed RGB frames
                        - **Depth Video**: Visualizes the disparity maps in color (closer = red, further = blue)
                        - **Point Clouds**: Interactive 3D point cloud with camera positions shown as colored pyramids

                        <p class="warning">Note: 3D point clouds take a long time to visualize, and we show the keyframes only.
                        You can control the keyframe interval by modifying the `pointcloud_save_frame_interval`.</p>
                        """
                    )

    # Event handlers
    task.change(
        fn=update_task_ui,
        inputs=[task],
        outputs=[
            video_input,
            image_input,
            goal_input,
            image_preview,
            goal_preview,
            num_inference_steps,
            sliding_window_stride,
            use_dynamic_cfg,
            raymap_option,
            post_reconstruction,
            guidance_scale,
        ],
    )

    image_input.change(
        fn=update_image_preview, inputs=[image_input], outputs=[image_preview]
    ).then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[preview_row])

    goal_input.change(
        fn=update_goal_preview, inputs=[goal_input], outputs=[goal_preview]
    ).then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[preview_row])

    def update_pointcloud_frames(pointcloud_paths):
        """Update the pointcloud frames dropdown with available frames."""
        if not pointcloud_paths:
            return gr.update(choices=[], value=None), None, gr.update(visible=False)

        # Extract frame numbers from filenames
        frame_info = []
        for path in pointcloud_paths:
            filename = os.path.basename(path)
            match = re.search(r"frame_(\d+)", filename)
            if match:
                frame_num = int(match.group(1))
                frame_info.append((f"Frame {frame_num}", path))

        # Sort by frame number
        frame_info.sort(key=lambda x: int(re.search(r"Frame (\d+)", x[0]).group(1)))

        choices = [label for label, _ in frame_info]
        paths = [path for _, path in frame_info]

        if not choices:
            return gr.update(choices=[], value=None), None, gr.update(visible=False)

        # Make download button visible when we have point cloud files
        return (
            gr.update(choices=choices, value=choices[0]),
            paths[0],
            gr.update(visible=True),
        )

    def select_pointcloud_frame(frame_label, all_paths):
        """Select a specific pointcloud frame."""
        if not frame_label or not all_paths:
            return None

        frame_num = int(re.search(r"Frame (\d+)", frame_label).group(1))

        for path in all_paths:
            if f"frame_{frame_num}" in path:
                return path

        return None

    # Then in the run button click handler:
    def process_task(task_type, *args):
        """Process selected task with appropriate function."""
        if task_type == "reconstruction":
            rgb_path, depth_path, pointcloud_paths = process_reconstruction(*args)
            # Update the pointcloud frames dropdown
            frame_dropdown, initial_path, download_visible = update_pointcloud_frames(
                pointcloud_paths
            )
            return (
                rgb_path,
                depth_path,
                initial_path,
                frame_dropdown,
                pointcloud_paths,
                download_visible,
            )
        elif task_type == "prediction":
            rgb_path, depth_path, pointcloud_paths = process_prediction(*args)
            frame_dropdown, initial_path, download_visible = update_pointcloud_frames(
                pointcloud_paths
            )
            return (
                rgb_path,
                depth_path,
                initial_path,
                frame_dropdown,
                pointcloud_paths,
                download_visible,
            )
        elif task_type == "planning":
            rgb_path, depth_path, pointcloud_paths = process_planning(*args)
            frame_dropdown, initial_path, download_visible = update_pointcloud_frames(
                pointcloud_paths
            )
            return (
                rgb_path,
                depth_path,
                initial_path,
                frame_dropdown,
                pointcloud_paths,
                download_visible,
            )
        return (
            None,
            None,
            None,
            gr.update(choices=[], value=None),
            [],
            gr.update(visible=False),
        )

    # Store all pointcloud paths for later use
    all_pointcloud_paths = gr.State([])

    run_button.click(
        fn=lambda task_type,
        video_file,
        image_file,
        goal_file,
        height,
        width,
        num_frames,
        num_inference_steps,
        guidance_scale,
        sliding_window_stride,
        use_dynamic_cfg,
        raymap_option,
        post_reconstruction,
        fps,
        smooth_camera,
        align_pointmaps,
        max_depth,
        rtol,
        pointcloud_save_frame_interval,
        seed: process_task(
            task_type,
            *(
                [
                    video_file,
                    height,
                    width,
                    num_frames,
                    num_inference_steps,
                    guidance_scale,
                    sliding_window_stride,
                    fps,
                    smooth_camera,
                    align_pointmaps,
                    max_depth,
                    rtol,
                    pointcloud_save_frame_interval,
                    seed,
                ]
                if task_type == "reconstruction"
                else [
                    image_file,
                    height,
                    width,
                    num_frames,
                    num_inference_steps,
                    guidance_scale,
                    use_dynamic_cfg,
                    raymap_option,
                    post_reconstruction,
                    fps,
                    smooth_camera,
                    align_pointmaps,
                    max_depth,
                    rtol,
                    pointcloud_save_frame_interval,
                    seed,
                ]
                if task_type == "prediction"
                else [
                    image_file,
                    goal_file,
                    height,
                    width,
                    num_frames,
                    num_inference_steps,
                    guidance_scale,
                    use_dynamic_cfg,
                    post_reconstruction,
                    fps,
                    smooth_camera,
                    align_pointmaps,
                    max_depth,
                    rtol,
                    pointcloud_save_frame_interval,
                    seed,
                ]
            ),
        ),
        inputs=[
            task,
            video_input,
            image_input,
            goal_input,
            height,
            width,
            num_frames,
            num_inference_steps,
            guidance_scale,
            sliding_window_stride,
            use_dynamic_cfg,
            raymap_option,
            post_reconstruction,
            fps,
            smooth_camera,
            align_pointmaps,
            max_depth,
            rtol,
            pointcloud_save_frame_interval,
            seed,
        ],
        outputs=[
            rgb_output,
            depth_output,
            model_output,
            pointcloud_frames,
            all_pointcloud_paths,
            pointcloud_download,
        ],
    )

    pointcloud_frames.change(
        fn=select_pointcloud_frame,
        inputs=[pointcloud_frames, all_pointcloud_paths],
        outputs=[model_output],
    ).then(
        fn=get_download_link,
        inputs=[pointcloud_frames, all_pointcloud_paths],
        outputs=[pointcloud_download],
    )

    # Example Accordion
    with gr.Accordion("Examples"):
        gr.Markdown(
            """
            ### Examples will be added soon
            Check back for example inputs for each task type.
            """
        )

    # Load the model at startup
    demo.load(lambda: build_pipeline(), inputs=None, outputs=None)

if __name__ == "__main__":
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    demo.queue(max_size=20).launch(show_error=True, share=True)
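
A minimal local-launch sketch, for reference only (not part of this commit). It assumes the repository layout implied by the imports above (a `.project-root` marker, the `aether` package, and an `assets/` directory with `logo.png` and the example raymaps), that the CogVideoX-5b-I2V and AetherV1 checkpoints can be downloaded from the Hub, and that a CUDA GPU is available. The hypothetical helper file name `local_launch.py` is illustrative.

# local_launch.py -- hypothetical helper, not part of app.py
# Importing app runs its module-level setup: tokenizer, text encoder, VAE,
# scheduler, and transformer are loaded and the pipeline is moved to the GPU,
# so the first import downloads the model weights.
from app import demo

# Reuse the same queue settings as app.py, but keep the server local.
demo.queue(max_size=20).launch(show_error=True, share=False)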