In [3]:
import matplotlib.pyplot as pl
import numpy as np
import torch
from scipy.spatial.transform import Rotation
import plotly.graph_objects as go
import plotly.subplots as sp
import trimesh
import hydra
from omegaconf import OmegaConf, DictConfig
import os
import time
import copy
import shutil
import open3d as o3d

import rootutils
rootutils.setup_root("../fast3r", indicator=".project-root", pythonpath=True)

from fast3r.models.multiview_dust3r_module import MultiViewDUSt3RLitModule
from fast3r.dust3r.inference_multiview import inference
from fast3r.dust3r.model import FlashDUSt3R
from fast3r.dust3r.utils.image import load_images, rgb
from fast3r.dust3r.viz import CAM_COLORS, OPENGL, add_scene_cam, cat_meshes, pts3d_to_trimesh


# Enable matplotlib interactive mode.
pl.ion()


def get_reconstructed_scene(
    outdir,
    model,
    device,
    silent,
    image_size,
    filelist,
    profiling=False,
    dtype=torch.float32,
    rotate_clockwise_90=False,
    crop_to_landscape=False,
):
    """Load a set of images and run one multi-view inference pass on them.

    Only the forward pass happens here: no global alignment and no 3D model
    export. The value returned is exactly what ``inference`` returns.

    Args:
        outdir: Unused; kept so existing call sites keep working.
        model: Network forwarded to ``inference``.
        device: Device forwarded to ``inference``.
        silent: When True, ``load_images`` and ``inference`` are called with
            ``verbose=False``. The elapsed-time line is printed regardless.
        image_size: Target size forwarded to ``load_images``.
        filelist: Image paths forwarded to ``load_images``.
        profiling: Forwarded to ``inference``.
        dtype: Forwarded to ``inference``.
        rotate_clockwise_90: Forwarded to ``load_images``.
        crop_to_landscape: Forwarded to ``load_images``.

    Returns:
        The output of ``inference`` for the loaded views.
    """
    views = load_images(
        filelist,
        size=image_size,
        verbose=not silent,
        rotate_clockwise_90=rotate_clockwise_90,
        crop_to_landscape=crop_to_landscape,
    )

    # Time only the inference call (image loading is excluded).
    # perf_counter is monotonic and high-resolution, so system clock
    # adjustments cannot distort the measurement the way time.time() can.
    start = time.perf_counter()
    output = inference(views, model, device, dtype=dtype, verbose=not silent, profiling=profiling)
    elapsed = time.perf_counter() - start
    print(f"Time elapsed: {elapsed}")

    return output



def plot_rgb_images(views, title="RGB Images", save_image_to_folder=None):
    """Build a one-row plotly figure showing the input image of every view.

    Args:
        views: Sequence of dicts whose ``'img'`` tensor squeezes to (3, H, W).
            Values are assumed to lie in [-1, 1] (implied by the rescaling
            below; confirm against the data loader).
        title: Figure title.
        save_image_to_folder: If given, each image is also written to
            ``<folder>/view_<i>.png``. The folder is created if missing.

    Returns:
        The plotly figure. It was previously built and then discarded; call
        ``.show()`` on the returned figure to display it.
    """
    fig = sp.make_subplots(
        rows=1,
        cols=len(views),
        subplot_titles=[f"View {i} Image" for i in range(len(views))],
    )
    if save_image_to_folder:
        # pl.imsave does not create missing directories.
        os.makedirs(save_image_to_folder, exist_ok=True)

    for i, view in enumerate(views):
        img_rgb = view['img'].cpu().numpy().squeeze().transpose(1, 2, 0)  # (H, W, 3)
        # Map [-1, 1] to [0, 255]; out-of-range values saturate via clip.
        img_rgb = ((img_rgb + 1) * 127.5).astype(int).clip(0, 255)

        fig.add_trace(go.Image(z=img_rgb), row=1, col=i + 1)

        if save_image_to_folder:
            img_path = os.path.join(save_image_to_folder, f"view_{i}.png")
            pl.imsave(img_path, img_rgb.astype(np.uint8))

    fig.update_layout(title=title, margin=dict(l=0, r=0, b=0, t=40))
    return fig

def plot_confidence_maps(preds, title="Confidence Maps", save_image_to_folder=None):
    """Build a one-row plotly figure showing each view's global confidence map.

    Args:
        preds: Sequence of per-view prediction dicts. ``pred['conf']`` is a
            tensor that squeezes to a 2D map.
        title: Figure title.
        save_image_to_folder: If given, each map is also written to
            ``<folder>/view_<i>_conf.png`` with the ``turbo`` colormap. The
            folder is created if missing.

    Returns:
        The plotly figure. It was previously built and then discarded; call
        ``.show()`` on the returned figure to display it.
    """
    fig = sp.make_subplots(
        rows=1,
        cols=len(preds),
        subplot_titles=[f"View {i} Confidence" for i in range(len(preds))],
    )
    if save_image_to_folder:
        # pl.imsave does not create missing directories.
        os.makedirs(save_image_to_folder, exist_ok=True)

    for i, pred in enumerate(preds):
        conf = pred['conf'].cpu().numpy().squeeze()
        fig.add_trace(go.Heatmap(z=conf, colorscale='turbo', showscale=False), row=1, col=i + 1)

        if save_image_to_folder:
            conf_path = os.path.join(save_image_to_folder, f"view_{i}_conf.png")
            pl.imsave(conf_path, conf, cmap='turbo')

    fig.update_layout(title=title, margin=dict(l=0, r=0, b=0, t=40))
    # Heatmaps draw row 0 at the bottom by default. Reverse every y-axis so
    # the maps have the same orientation as the input images.
    fig.update_yaxes(autorange='reversed')
    return fig

def maybe_plot_local_depth_and_conf(preds, title="Local Depth and Confidence Maps", save_image_to_folder=None):
    """Plot per-view local confidence (top row) and local depth (bottom row).

    A view gets a confidence heatmap only if its pred dict has
    ``'conf_local'``, and a depth heatmap only if it has ``'pts3d_local'``.
    Missing entries leave that subplot empty.

    Args:
        preds: Sequence of per-view prediction dicts holding tensors.
        title: Figure title.
        save_image_to_folder: If given, the maps are also written as
            ``view_<i>_conf_local.png`` / ``view_<i>_depth_local.png``. The
            folder is created if missing.

    Returns:
        The plotly figure. Call ``.show()`` on it to display.
    """
    num_views = len(preds)
    conf_titles = [
        f"View {i+1} Conf" if 'conf_local' in pred else f"View {i+1} No Conf"
        for i, pred in enumerate(preds)
    ]
    depth_titles = [
        f"View {i+1} Depth" if 'pts3d_local' in pred else f"View {i+1} No Depth"
        for i, pred in enumerate(preds)
    ]
    # make_subplots assigns titles row by row: row 1 (conf), then row 2 (depth).
    fig = sp.make_subplots(rows=2, cols=num_views, subplot_titles=conf_titles + depth_titles)

    if save_image_to_folder:
        # pl.imsave does not create missing directories.
        os.makedirs(save_image_to_folder, exist_ok=True)

    for i, pred in enumerate(preds):
        col = i + 1

        if 'conf_local' in pred:
            conf_local = pred['conf_local'].cpu().numpy().squeeze()
            fig.add_trace(go.Heatmap(z=conf_local, colorscale='Turbo', showscale=False), row=1, col=col)
            # Address the axis by (row, col). Subplot axes are numbered
            # row-major, so the old `yaxis{2*i+1}` arithmetic flipped the
            # wrong subplot whenever there was more than one view.
            fig.update_yaxes(autorange='reversed', row=1, col=col)

            if save_image_to_folder:
                conf_local_path = os.path.join(save_image_to_folder, f"view_{i}_conf_local.png")
                pl.imsave(conf_local_path, conf_local, cmap='turbo')

        if 'pts3d_local' in pred:
            # Depth is taken as the z coordinate of the view's local point map.
            depth_local = pred['pts3d_local'][..., 2].cpu().numpy().squeeze()
            fig.add_trace(go.Heatmap(z=depth_local, colorscale='Greys', showscale=False), row=2, col=col)
            fig.update_yaxes(autorange='reversed', row=2, col=col)

            if save_image_to_folder:
                depth_local_path = os.path.join(save_image_to_folder, f"view_{i}_depth_local.png")
                pl.imsave(depth_local_path, depth_local, cmap='Greys')

    fig.update_layout(title=title, margin=dict(l=0, r=0, b=0, t=40))
    return fig

def plot_3d_points_with_colors(preds, views, title="3D Points Visualization", flip_axes=False, as_mesh=False, min_conf_thr_percentile=80, export_ply_path=None):
    """Show the per-view global point maps in one interactive 3D plotly figure.

    For each view, a pixel is kept only if its ``'conf'`` exceeds the
    ``min_conf_thr_percentile``-th percentile of that view's confidence.
    Kept points are coloured with the view's input image.

    Args:
        preds: Per-view prediction dicts with ``'pts3d_in_other_view'``
            (last axis holds x, y, z) and ``'conf'`` tensors.
        views: Per-view dicts whose ``'img'`` squeezes to (3, H, W). Values are
            assumed to lie in [-1, 1] (implied by the rescaling below).
        title: Figure title.
        flip_axes: If True, map (x, y, z) -> (x, z, -y) before plotting and
            before exporting.
        as_mesh: If True, triangulate each view with ``pts3d_to_trimesh`` and
            draw one combined mesh. Otherwise draw one scatter trace per view.
        min_conf_thr_percentile: Per-view confidence percentile (0-100) used as
            the cut-off.
        export_ply_path: If given, the combined mesh / point cloud is also
            written there. The exported geometry matches what is plotted
            (including ``flip_axes``) and carries uint8 RGB colours.
    """
    fig = go.Figure()

    if as_mesh:
        meshes = []
        for i, pred in enumerate(preds):
            pts3d = pred['pts3d_in_other_view'].cpu().numpy().squeeze()
            img = views[i]['img'].cpu().numpy().squeeze().transpose(1, 2, 0)  # (H, W, 3)
            conf = pred['conf'].cpu().numpy().squeeze()
            mask = conf > np.percentile(conf, min_conf_thr_percentile)

            # Clip BEFORE the uint8 cast. Casting out-of-range floats to uint8
            # wraps around, so clipping afterwards (as before) was a no-op.
            img_rgb = ((img + 1) * 127.5).clip(0, 255).astype(np.uint8)
            meshes.append(pts3d_to_trimesh(img_rgb, pts3d, valid=mask))

        combined_mesh = trimesh.Trimesh(**cat_meshes(meshes))

        if flip_axes:
            # (x, y, z) -> (x, z, -y)
            combined_mesh.vertices[:, [1, 2]] = combined_mesh.vertices[:, [2, 1]]
            combined_mesh.vertices[:, 2] = -combined_mesh.vertices[:, 2]

        if export_ply_path:
            combined_mesh.export(export_ply_path)

        # Each face takes the mean colour of its three vertices (vectorised
        # instead of a Python loop over every face).
        vertex_colors = combined_mesh.visual.vertex_colors[:, :3]
        face_rgb = vertex_colors[combined_mesh.faces].mean(axis=1).astype(int)
        face_colors = [f'rgb({r}, {g}, {b})' for r, g, b in face_rgb]

        fig.add_trace(go.Mesh3d(
            x=combined_mesh.vertices[:, 0],
            y=combined_mesh.vertices[:, 1],
            z=combined_mesh.vertices[:, 2],
            i=combined_mesh.faces[:, 0],
            j=combined_mesh.faces[:, 1],
            k=combined_mesh.faces[:, 2],
            facecolor=face_colors,
            opacity=0.5,
            name="Combined Mesh",
        ))
    else:
        ply_points, ply_colors = [], []
        for i, pred in enumerate(preds):
            pts3d = pred['pts3d_in_other_view'].cpu().numpy().squeeze()
            img = views[i]['img'].cpu().numpy().squeeze().transpose(1, 2, 0)  # (H, W, 3)
            conf = pred['conf'].cpu().numpy().squeeze()

            # Flatten pixels row-major and keep the confident ones.
            mask = conf.reshape(-1) > np.percentile(conf, min_conf_thr_percentile)
            points = pts3d.reshape(-1, 3)[mask]
            # Map [-1, 1] to [0, 255] before the colours are used anywhere.
            # The PLY export previously received the raw [-1, 1] values.
            colors = ((img.reshape(-1, 3)[mask] + 1) * 127.5).astype(int).clip(0, 255)

            if flip_axes:
                # (x, y, z) -> (x, z, -y), same mapping as the mesh branch.
                points = np.stack([points[:, 0], points[:, 2], -points[:, 1]], axis=1)

            ply_points.append(points)
            ply_colors.append(colors.astype(np.uint8))

            fig.add_trace(go.Scatter3d(
                x=points[:, 0], y=points[:, 1], z=points[:, 2],
                mode='markers',
                marker=dict(size=2, opacity=0.8, color=[f'rgb({r}, {g}, {b})' for r, g, b in colors]),
                name=f"View {i}",
            ))

        if export_ply_path:
            point_cloud = trimesh.PointCloud(vertices=np.vstack(ply_points), colors=np.vstack(ply_colors))
            point_cloud.export(export_ply_path)

    fig.update_layout(
        title=title,
        scene=dict(
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='Z',
        ),
        margin=dict(l=0, r=0, b=0, t=40),
        height=1000,
    )

    fig.show()


import numpy as np
import torch
import plotly.graph_objects as go
from fast3r.dust3r.cloud_opt.init_im_poses import fast_pnp
from fast3r.dust3r.viz import auto_cam_size
from fast3r.dust3r.viz_plotly import SceneViz
from fast3r.dust3r.utils.image import rgb # Assuming you have this utility for image processing


# Function to visualize 3D points and camera poses with SceneViz
def plot_3d_points_with_estimated_camera_poses(
    preds,
    views,
    title="3D Points and Camera Poses",
    flip_axes=False,
    min_conf_thr_percentile=80,
    export_ply_path=None,
    export_html_path=None,
    niter_PnP=10,
):
    """Show the global point maps together with PnP-estimated cameras in a ``SceneViz``.

    Camera-to-world poses and focal lengths come from
    ``MultiViewDUSt3RLitModule.estimate_camera_poses`` (only batch element 0
    is used). Each view's confident points and its camera are then added to
    the scene.

    Args:
        preds: Per-view prediction dicts with ``'pts3d_in_other_view'`` and
            ``'conf'`` tensors.
        views: Per-view dicts whose ``'img'`` squeezes to (3, H, W). Values are
            assumed to lie in [-1, 1], like the other helpers in this file.
        title: Currently unused; ``SceneViz`` is not given a title here.
        flip_axes: If True, points are mapped (x, y, z) -> (x, z, -y) on a deep
            copy of ``preds`` before pose estimation. The caller's dicts are
            left untouched.
        min_conf_thr_percentile: Per-view percentile (0-100) of ``'conf'`` that
            a pixel must exceed to be shown or exported.
        export_ply_path: If given, the kept points are written as a PLY with
            uint8 RGB colours.
        export_html_path: If given, the scene is exported with ``viz.export_html``.
        niter_PnP: PnP iterations for pose estimation (previously fixed at 10).
    """
    viz = SceneViz()

    if flip_axes:
        preds = copy.deepcopy(preds)
        for pred in preds:
            # List indexing returns a copy, so the in-place negation below
            # only touches the new tensor.
            pts3d = pred['pts3d_in_other_view'][..., [0, 2, 1]]  # swap y and z
            pts3d[..., 2] *= -1
            pred['pts3d_in_other_view'] = pts3d

    poses_c2w, estimated_focals = MultiViewDUSt3RLitModule.estimate_camera_poses(preds, niter_PnP=niter_PnP)
    poses_c2w = poses_c2w[0]  # batch size is 1
    estimated_focals = estimated_focals[0]  # batch size is 1
    # Keep camera glyphs at least 0.05 scene units large.
    cam_size = max(auto_cam_size(poses_c2w), 0.05)

    ply_points, ply_colors = [], []
    for i, (pred, pose_c2w) in enumerate(zip(preds, poses_c2w)):
        pts3d = pred['pts3d_in_other_view'].cpu().numpy().squeeze()
        img_raw = views[i]['img'].cpu().numpy().squeeze().transpose(1, 2, 0)  # (H, W, 3)
        img_rgb = rgb(img_raw)
        conf = pred['conf'].cpu().numpy().squeeze()
        mask = conf > np.percentile(conf, min_conf_thr_percentile)

        viz.add_pointcloud(pts3d, img_rgb, mask=mask, point_size=1.0, view_idx=i)
        viz.add_camera(
            pose_c2w=pose_c2w,
            focal=estimated_focals[i],
            color=np.random.randint(0, 256, size=3),  # random RGB per camera
            image=img_rgb,
            cam_size=cam_size,
            view_idx=i,
        )

        if export_ply_path:
            # Collect export data here instead of recomputing it in a second
            # loop. Map the raw [-1, 1] image values to uint8 [0, 255] so the
            # PLY gets valid colours (raw values were exported before).
            ply_points.append(pts3d[mask])
            ply_colors.append(((img_raw[mask] + 1) * 127.5).clip(0, 255).astype(np.uint8))

    if export_ply_path:
        point_cloud = trimesh.PointCloud(vertices=np.vstack(ply_points), colors=np.vstack(ply_colors))
        point_cloud.export(export_ply_path)

    if export_html_path:
        viz.export_html(export_html_path)

    viz.show()

def save_pointmaps_and_camera_parameters_to_folder(preds, save_folder, niter_PnP=100, focal_length_estimation_method='individual'):
    """Estimate cameras from ``preds`` and write point maps plus camera parameters to one ``.npz``.

    The archive is ``<save_folder>/pointmaps_and_camera_params.npz``. The folder
    is created if missing. Every array holds one entry per view:

    - ``global_pointmaps``                  <- ``pred['pts3d_in_other_view']``
    - ``global_confidence_maps``            <- ``pred['conf']``
    - ``local_pointmaps``                   <- ``pred['pts3d_local']``
    - ``local_aligned_to_global_pointmaps`` <- ``pred['pts3d_local_aligned_to_global']``
    - ``local_confidence_maps``             <- ``pred['conf_local']``
    - ``estimated_focals``                  <- focal from ``estimate_camera_poses``
    - ``estimated_poses_c2w``               <- pose from ``estimate_camera_poses``

    Args:
        preds (list): Per-view prediction dicts holding the tensors listed above.
        save_folder (str): Directory in which the ``.npz`` file is written.
        niter_PnP (int): Forwarded to ``estimate_camera_poses``.
        focal_length_estimation_method (str): Forwarded to ``estimate_camera_poses``.
    """
    poses_c2w, estimated_focals = MultiViewDUSt3RLitModule.estimate_camera_poses(
        preds,
        niter_PnP=niter_PnP,
        focal_length_estimation_method=focal_length_estimation_method,
    )
    # Only the first (assumed only) batch element is used.
    poses_c2w, estimated_focals = poses_c2w[0], estimated_focals[0]

    # Output array name -> key of the per-view tensor inside each pred dict.
    tensor_fields = {
        'global_pointmaps': 'pts3d_in_other_view',
        'global_confidence_maps': 'conf',
        'local_pointmaps': 'pts3d_local',
        'local_aligned_to_global_pointmaps': 'pts3d_local_aligned_to_global',
        'local_confidence_maps': 'conf_local',
    }
    arrays = {name: [] for name in tensor_fields}
    arrays['estimated_focals'] = []
    arrays['estimated_poses_c2w'] = []

    for idx, pred in enumerate(preds):
        for name, key in tensor_fields.items():
            arrays[name].append(pred[key].cpu().numpy().squeeze())

        focal = estimated_focals[idx]
        if isinstance(focal, torch.Tensor):
            focal = focal.item()
        arrays['estimated_focals'].append(focal)

        pose = poses_c2w[idx]
        if isinstance(pose, torch.Tensor):
            pose = pose.cpu().numpy()
        arrays['estimated_poses_c2w'].append(pose)

    os.makedirs(save_folder, exist_ok=True)
    np.savez(os.path.join(save_folder, 'pointmaps_and_camera_params.npz'), **arrays)


def export_combined_ply(preds, views, export_ply_path=None,
                        pts3d_key_to_visualize="pts3d_local_aligned_to_global",
                        conf_key_to_visualize="conf_local",
                        min_conf_thr_percentile=0, flip_axes=False, max_num_points=None, sampling_strategy='uniform'):
    """Merge the confident points of all views into one coloured point cloud.

    Args:
        preds: Per-view prediction dicts holding ``pts3d_key_to_visualize``
            (last axis holds x, y, z) and ``conf_key_to_visualize`` tensors.
        views: Per-view dicts whose ``'img'`` squeezes to (3, H, W). Values are
            assumed to lie in [-1, 1] (implied by the rescaling below).
        export_ply_path: If given, the result is written there as a PLY.
        pts3d_key_to_visualize: Key of the point map in each pred dict.
        conf_key_to_visualize: Key of the confidence map in each pred dict.
        min_conf_thr_percentile: Per-view percentile (0-100). A pixel is kept
            only if its confidence is strictly greater than this percentile.
        flip_axes: If True, map (x, y, z) -> (x, z, -y).
        max_num_points: If set and exceeded, downsample with ``sampling_strategy``.
        sampling_strategy: ``'uniform'`` (random subset of exactly
            ``max_num_points``), ``'voxel'`` (Open3D voxel grid, approximate
            count) or ``'farthest_point'`` (Open3D FPS).

    Returns:
        ``(points, colors)``: an (N, 3) float array and an (N, 3) uint8 array.

    Raises:
        ValueError: If downsampling is needed and ``sampling_strategy`` is unknown.
    """
    all_points = []
    all_colors = []

    for i, pred in enumerate(preds):
        pts3d = pred[pts3d_key_to_visualize].cpu().numpy().squeeze()
        img_rgb = views[i]['img'].cpu().numpy().squeeze().transpose(1, 2, 0)  # (H, W, 3)
        conf = pred[conf_key_to_visualize].cpu().numpy().squeeze()

        # Flatten pixels row-major and keep the confident ones.
        mask = conf.reshape(-1) > np.percentile(conf, min_conf_thr_percentile)
        points = pts3d.reshape(-1, 3)[mask]
        # Clip BEFORE casting to uint8. Casting out-of-range floats to uint8
        # wraps around (e.g. 256.3 -> 0), so clipping afterwards was a no-op.
        colors = ((img_rgb.reshape(-1, 3)[mask] + 1) * 127.5).clip(0, 255).astype(np.uint8)

        if flip_axes:
            points = points[:, [0, 2, 1]]  # swap y and z (copy)
            points[:, 2] = -points[:, 2]  # invert the new z

        all_points.append(points)
        all_colors.append(colors)

    all_points = np.vstack(all_points)
    all_colors = np.vstack(all_colors)

    if max_num_points is not None and len(all_points) > max_num_points:
        if sampling_strategy == 'uniform':
            indices = np.random.choice(len(all_points), size=max_num_points, replace=False)
            all_points = all_points[indices]
            all_colors = all_colors[indices]
        elif sampling_strategy == 'voxel':
            pcd = o3d.geometry.PointCloud()
            pcd.points = o3d.utility.Vector3dVector(all_points)
            pcd.colors = o3d.utility.Vector3dVector(all_colors.astype(np.float64) / 255.0)

            # Heuristic: choose a voxel edge so that the bounding box holds
            # roughly max_num_points voxels. Zero-length extents (planar or
            # collinear clouds) are ignored; otherwise the volume would be 0,
            # giving voxel_size 0, which Open3D rejects.
            extent = np.asarray(pcd.get_axis_aligned_bounding_box().get_extent())
            nonzero_extent = extent[extent > 0]
            if nonzero_extent.size == 0:
                # All points coincide; any positive voxel merges them.
                voxel_size = 1.0
            else:
                voxel_size = (np.prod(nonzero_extent) / max_num_points) ** (1.0 / nonzero_extent.size)

            down_pcd = pcd.voxel_down_sample(voxel_size)
            all_points = np.asarray(down_pcd.points)
            all_colors = (np.asarray(down_pcd.colors) * 255.0).astype(np.uint8)
        elif sampling_strategy == 'farthest_point':
            # May be slow for large point clouds.
            pcd = o3d.geometry.PointCloud()
            pcd.points = o3d.utility.Vector3dVector(all_points)
            pcd.colors = o3d.utility.Vector3dVector(all_colors.astype(np.float64) / 255.0)

            down_pcd = pcd.farthest_point_down_sample(max_num_points)
            all_points = np.asarray(down_pcd.points)
            all_colors = (np.asarray(down_pcd.colors) * 255.0).astype(np.uint8)
        else:
            raise ValueError(f"Unsupported sampling strategy: {sampling_strategy}")

    if export_ply_path:
        point_cloud = trimesh.PointCloud(vertices=all_points, colors=all_colors)
        point_cloud.export(export_ply_path)

    return all_points, all_colors


In [None]:
import matplotlib.pyplot as plt
from PIL import Image
import random

# Root directory of the local datasets; the commented-out file lists below build their paths from it.
data_root = "../data"

# filelist_train = [
# f"{data_root}/co3d_subset_processed/apple/189_20393_38136/images/frame000001.jpg",
# f"{data_root}/co3d_subset_processed/apple/189_20393_38136/images/frame000002.jpg"
# ]

# apple
# filelist_test = [
# f"{data_root}/co3d_subset_processed/apple/189_20393_38136/images/frame000200.jpg",
# f"{data_root}/co3d_subset_processed/apple/189_20393_38136/images/frame000085.jpg",
# f"{data_root}/co3d_subset_processed/apple/189_20393_38136/images/frame000090.jpg",
# f"{data_root}/co3d_subset_processed/apple/189_20393_38136/images/frame000170.jpg",
# f"{data_root}/co3d_subset_processed/apple/189_20393_38136/images/frame000199.jpg",
# ]


# bench test
# filelist_test = [
# f"{data_root}/co3d_subset_processed/bench/415_57112_110099/images/frame000006.jpg",
# f"{data_root}/co3d_subset_processed/bench/415_57112_110099/images/frame000016.jpg",
# f"{data_root}/co3d_subset_processed/bench/415_57112_110099/images/frame000026.jpg",
# f"{data_root}/co3d_subset_processed/bench/415_57112_110099/images/frame000036.jpg",
# f"{data_root}/co3d_subset_processed/bench/415_57112_110099/images/frame000096.jpg",
# f"{data_root}/co3d_subset_processed/bench/415_57112_110099/images/frame000126.jpg",
# # f"{data_root}/co3d_subset_processed/bench/415_57112_110099/images/frame000156.jpg",
# f"{data_root}/co3d_subset_processed/bench/415_57112_110099/images/frame000186.jpg",
# ]

# # teddy bear train
# filelist_test = [
# f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000001.jpg",
# f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000002.jpg",
# f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000003.jpg",
# f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000004.jpg",
# # f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000012.jpg",
# # f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000022.jpg",
# # f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000032.jpg",
# ]
# teddy bear test
# filelist_test = [
# f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000006.jpg",
# f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000016.jpg",
# f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000026.jpg",
# f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000036.jpg",
# f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000096.jpg",
# f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000126.jpg",
# f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000156.jpg",
# f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000186.jpg",
# ]

# teddy bear random order
# filelist_test = [
# f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000126.jpg",
# f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000036.jpg",
# f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000096.jpg",
# f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000006.jpg",
# f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000026.jpg",
# f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000186.jpg",
# f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000016.jpg",
# f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000156.jpg",
# ]


# filelist_test = [
# f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000006.jpg",
# f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000036.jpg",
# f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000066.jpg",
# f"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000096.jpg",
# ]

# suitcase test
# filelist_test = [
# f"{data_root}/co3d_subset_processed/suitcase/50_2928_8645/images/frame000006.jpg",
# f"{data_root}/co3d_subset_processed/suitcase/50_2928_8645/images/frame000016.jpg",
# f"{data_root}/co3d_subset_processed/suitcase/50_2928_8645/images/frame000026.jpg",
# f"{data_root}/co3d_subset_processed/suitcase/50_2928_8645/images/frame000036.jpg",
# f"{data_root}/co3d_subset_processed/suitcase/50_2928_8645/images/frame000096.jpg",
# f"{data_root}/co3d_subset_processed/suitcase/50_2928_8645/images/frame000126.jpg",
# f"{data_root}/co3d_subset_processed/suitcase/50_2928_8645/images/frame000156.jpg",
# f"{data_root}/co3d_subset_processed/suitcase/50_2928_8645/images/frame000186.jpg",
# ]

# cake test
# filelist_test = [
# f"{data_root}/co3d_subset_processed/cake/374_42274_84517/images/frame000006.jpg",
# f"{data_root}/co3d_subset_processed/cake/374_42274_84517/images/frame000016.jpg",
# f"{data_root}/co3d_subset_processed/cake/374_42274_84517/images/frame000026.jpg",
# f"{data_root}/co3d_subset_processed/cake/374_42274_84517/images/frame000036.jpg",
# f"{data_root}/co3d_subset_processed/cake/374_42274_84517/images/frame000096.jpg",
# f"{data_root}/co3d_subset_processed/cake/374_42274_84517/images/frame000126.jpg",
# f"{data_root}/co3d_subset_processed/cake/374_42274_84517/images/frame000156.jpg",
# f"{data_root}/co3d_subset_processed/cake/374_42274_84517/images/frame000186.jpg",
# ]

# in-the-wild obj: book
# filelist_test = [
# f"{data_root}/unseen_book/IMG_9837.jpg",
# f"{data_root}/unseen_book/IMG_9838.jpg",
# f"{data_root}/unseen_book/IMG_9839.jpg",
# f"{data_root}/unseen_book/IMG_9840.jpg",
# f"{data_root}/unseen_book/IMG_9841.jpg",
# f"{data_root}/unseen_book/IMG_9842.jpg",
# f"{data_root}/unseen_book/IMG_9843.jpg",
# f"{data_root}/unseen_book/IMG_9844.jpg",
# ]

# in-the-wild obj: beef jerky
# filelist_test = [
# f"{data_root}/beef_jerky/IMG_0050.jpg",
# f"{data_root}/beef_jerky/IMG_0051.jpg",
# f"{data_root}/beef_jerky/IMG_0052.jpg",
# f"{data_root}/beef_jerky/IMG_0053.jpg",
# f"{data_root}/beef_jerky/IMG_0054.jpg",
# f"{data_root}/beef_jerky/IMG_0055.jpg",
# f"{data_root}/beef_jerky/IMG_0056.jpg",
# f"{data_root}/beef_jerky/IMG_0057.jpg",
# f"{data_root}/beef_jerky/IMG_0058.jpg",
# ]


# ArkitScenes
# filelist_test = [
# f"/datasets01/ARKitScenes/raw/Validation/41069021/vga_wide/41069021_312.125.png",
# f"/datasets01/ARKitScenes/raw/Validation/41069021/vga_wide/41069021_313.124.png",
# f"/datasets01/ARKitScenes/raw/Validation/41069021/vga_wide/41069021_314.124.png",
# f"/datasets01/ARKitScenes/raw/Validation/41069021/vga_wide/41069021_315.123.png",
# f"/datasets01/ARKitScenes/raw/Validation/41069021/vga_wide/41069021_316.123.png",
# f"/datasets01/ARKitScenes/raw/Validation/41069021/vga_wide/41069021_317.123.png",
# f"/datasets01/ARKitScenes/raw/Validation/41069021/vga_wide/41069021_318.122.png",
# ]

# HSSD
# filelist_test = [
# f"{data_root}/0_102344022_0/rgb/0000{i:02d}.png" for i in range(8)
# ]

# filelist_test = [
# f"{data_root}/17_102344250_4/rgb/0000{i:02d}.png" for i in range(0,15)
# ]

# unseen obj: teddy bear from co3d
# filelist_test = [
# "/datasets01/co3dv2/080422/teddybear/595_90395_180050/images/frame000006.jpg",
# "/datasets01/co3dv2/080422/teddybear/595_90395_180050/images/frame000036.jpg",
# "/datasets01/co3dv2/080422/teddybear/595_90395_180050/images/frame000056.jpg",
# "/datasets01/co3dv2/080422/teddybear/595_90395_180050/images/frame000086.jpg",
# "/datasets01/co3dv2/080422/teddybear/595_90395_180050/images/frame000096.jpg",
# "/datasets01/co3dv2/080422/teddybear/595_90395_180050/images/frame000126.jpg",
# "/datasets01/co3dv2/080422/teddybear/595_90395_180050/images/frame000156.jpg",
# "/datasets01/co3dv2/080422/teddybear/595_90395_180050/images/frame000186.jpg",
# ]


# unseen obj: keyboard from co3d
# filelist_test = [
# "/datasets01/co3dv2/080422/keyboard/604_93822_187288/images/frame000096.jpg",
# "/datasets01/co3dv2/080422/keyboard/604_93822_187288/images/frame000126.jpg",
# "/datasets01/co3dv2/080422/keyboard/604_93822_187288/images/frame000156.jpg",
# # "/datasets01/co3dv2/080422/keyboard/604_93822_187288/images/frame000186.jpg",
# # "/datasets01/co3dv2/080422/keyboard/604_93822_187288/images/frame000006.jpg",
# # "/datasets01/co3dv2/080422/keyboard/604_93822_187288/images/frame000016.jpg",
# # "/datasets01/co3dv2/080422/keyboard/604_93822_187288/images/frame000026.jpg",
# "/datasets01/co3dv2/080422/keyboard/604_93822_187288/images/frame000036.jpg",
# ]

# filelist_test = [
# "/path/to/dust3r_data/co3d_50_seqs_per_category_subset_processed/keyboard/76_7733_16196/images/frame000006.jpg",
# "/path/to/dust3r_data/co3d_50_seqs_per_category_subset_processed/keyboard/76_7733_16196/images/frame000016.jpg",
# "/path/to/dust3r_data/co3d_50_seqs_per_category_subset_processed/keyboard/76_7733_16196/images/frame000026.jpg",
# "/path/to/dust3r_data/co3d_50_seqs_per_category_subset_processed/keyboard/76_7733_16196/images/frame000036.jpg",
# "/path/to/dust3r_data/co3d_50_seqs_per_category_subset_processed/keyboard/76_7733_16196/images/frame000096.jpg",
# "/path/to/dust3r_data/co3d_50_seqs_per_category_subset_processed/keyboard/76_7733_16196/images/frame000126.jpg",
# "/path/to/dust3r_data/co3d_50_seqs_per_category_subset_processed/keyboard/76_7733_16196/images/frame000156.jpg",
# "/path/to/dust3r_data/co3d_50_seqs_per_category_subset_processed/keyboard/76_7733_16196/images/frame000186.jpg",
# ]

# DTU
# filelist_test = [
# "/path/to/dust3r_data/datasets_raw/DTU/SampleSet/MVS Data/Rectified/scan6/rect_001_max.png",
# "/path/to/dust3r_data/datasets_raw/DTU/SampleSet/MVS Data/Rectified/scan6/rect_002_max.png",
# "/path/to/dust3r_data/datasets_raw/DTU/SampleSet/MVS Data/Rectified/scan6/rect_003_max.png",
# "/path/to/dust3r_data/datasets_raw/DTU/SampleSet/MVS Data/Rectified/scan6/rect_004_max.png",
# "/path/to/dust3r_data/datasets_raw/DTU/SampleSet/MVS Data/Rectified/scan6/rect_005_max.png",
# "/path/to/dust3r_data/datasets_raw/DTU/SampleSet/MVS Data/Rectified/scan6/rect_006_max.png",
# "/path/to/dust3r_data/datasets_raw/DTU/SampleSet/MVS Data/Rectified/scan6/rect_007_max.png",
# "/path/to/dust3r_data/datasets_raw/DTU/SampleSet/MVS Data/Rectified/scan6/rect_008_max.png",
# ]


# DTU test
# filelist_test = [f"/path/to/dust3r_data/dtu_test_mvsnet_release/scan4/images/000000{i:02d}.jpg" for i in range(0, 49, 5)]
# filelist_test = [f"/path/to/dust3r_data/dtu_test_mvsnet_release/scan1/images/000000{i:02d}.jpg" for i in range(0, 49, 5)]

# NRGBD test
# filelist_test = [f"../data/neural_rgbd/kitchen/images/img{i}.png" for i in range(1, 1517, 50)]
# filelist_test = [f"../data/neural_rgbd/morning_apartment/images/img{i}.png" for i in range(1, 919, 30)]
# filelist_test = [f"../data/neural_rgbd/whiteroom/images/img{i}.png" for i in range(1, 1675, 50)]
# filelist_test = [f"../data/neural_rgbd/grey_white_room/images/img{i}.png" for i in range(1, 1492, 50)]
# filelist_test = [f"../data/neural_rgbd/green_room/images/img{i}.png" for i in range(1, 1441, 100)]
# filelist_test = [f"../data/neural_rgbd/staircase/images/img{i}.png" for i in range(0, 1148, 40)]

# 7-Scenes test
# filelist_test = [f"/path/to/dust3r_data/7_scenes_processed/redkitchen/seq-06/frame-00{i:04d}.color.png" for i in range(0, 1000, 50)]
# filelist_test = filelist_test[0:10] * 50
# filelist_test.pop(2)
# filelist_test = [f"/path/to/dust3r_data/7_scenes_processed/redkitchen/seq-06/frame-00{i:04d}.color.png" for i in range(0, 1000, 2)][:88]
# filelist_test = [f"/path/to/dust3r_data/7_scenes_processed/redkitchen/seq-03/frame-00{i:04d}.color.png" for i in range(0, 1000, 2)][:320]
# filelist_test = [f"/path/to/dust3r_data/7_scenes_processed/pumpkin/seq-02/frame-00{i:04d}.color.png" for i in range(0, 1000, 20)]
# filelist_test = [f"/path/to/dust3r_data/7_scenes_processed/office/seq-09/frame-00{i:04d}.color.png" for i in range(0, 1000, 20)]
# filelist_test = [f"/path/to/dust3r_data/7_scenes_processed/fire/seq-04/frame-00{i:04d}.color.png" for i in range(0, 1000, 30)]


# Tanks and Temples
# use all images from /home/ssax/InstantSplat/data/collated_instantsplat_data/eval/Tanks/Family/24_views/dust3r_9_views/images by walking through the folder
# filelist_test = []
# for root, dirs, files in os.walk("/home/ssax/InstantSplat/data/collated_instantsplat_data/eval/Tanks/Family/24_views/dust3r_9_views/images"):
# for file in files:
# filelist_test.append(os.path.join(root, file))
# filelist_test = sorted(filelist_test) 

# filelist_test = [f"/data/jianingy/tanks_and_temples/Barn/{i:06d}.jpg" for i in range(1, 410, 1)]
# Tanks and Temples subset: leave exactly one assignment uncommented.
# The active scene is Family (frames 1, 51, 101, 151).
# filelist_test = [f"/data/jianingy/tanks_and_temples_subset/Barn/{i:06d}.jpg" for i in range(1, 410, 2)]
# filelist_test = [f"/data/jianingy/tanks_and_temples_subset/Lighthouse/{i:05d}.jpg" for i in range(1, 309, 1)]
# filelist_test = [f"/data/jianingy/tanks_and_temples_subset/Playground/{i:05d}.jpg" for i in range(1, 307, 1)]
filelist_test = [f"/data/jianingy/tanks_and_temples_subset/Family/{i:05d}.jpg" for i in range(1, 152, 50)]
# filelist_test = [f"/data/jianingy/tanks_and_temples/Courthouse/images/{i:08d}.jpg" for i in range(1, 500, 1)]
# filelist_test = [f"/data/jianingy/tanks_and_temples/Ignatius/images/{i:08d}.jpg" for i in range(1, 262, 1)]

# RealEstate10K
# RealEstate10K: randomly sample 10 files from a test-video directory
# (originally .../videos/test/93e6c08c33206a0c; the commented-out call below uses 924ccc02891cc7df).
# def sample_random_files(directory, n=10):
# all_files = [os.path.join(directory, f) for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f))]
# return random.sample(all_files, n)

# # Sample 10 random files from the RealEstate10K directory
# filelist_test = sample_random_files("/data/jianingy/RealEstate10K/videos/test/924ccc02891cc7df", 10)


# # reverse the order
# filelist_test = filelist_test[::-1]

# filelist_test = [f"/home/ssax/InstantSplat/data/collated_instantsplat_data/eval/Tanks/Barn/images/000{i:03d}.jpg" for i in range(521, 670, 5)]


# Multi-cam dynamic scenes.
def get_sorted_file_list(directory):
    """Return full paths of the regular files directly inside *directory*.

    Sub-directories are excluded; the result is sorted lexicographically by
    full path.
    """
    candidate_paths = (os.path.join(directory, entry) for entry in os.listdir(directory))
    return sorted(path for path in candidate_paths if os.path.isfile(path))

# filelist_test = [get_sorted_file_list(f"/path/to/juggling_multicam/{cam_idx}/")[149] for cam_idx in range(0, 8)]



# display the images
def display_images(filelist, title, rotate_clockwise_90=False, crop_to_landscape=False):
    """Show the images in *filelist* side by side in one matplotlib row.

    Args:
        filelist: image paths; one subplot is created per image. An empty
            list shows nothing and returns immediately.
        title: figure title.
        rotate_clockwise_90: rotate every image 90 degrees clockwise first.
        crop_to_landscape: center-crop every image to a 4:3 aspect ratio
            (after the optional rotation).
    """
    if len(filelist) == 0:
        # plt.subplots(1, 0) would raise; there is nothing to display anyway.
        return

    # Imported locally: the module header imports pyplot as `pl` (not `plt`),
    # and PIL's Image is otherwise only imported in a later cell.
    import matplotlib.pyplot as plt
    from PIL import Image

    desired_aspect_ratio = 4 / 3

    # squeeze=False always returns a 2-D array of axes, even for one image.
    fig, axes = plt.subplots(1, len(filelist), figsize=(30, 4), squeeze=False)
    fig.suptitle(title)
    for ax, filepath in zip(axes[0], filelist):
        # The context manager closes the file handle once the pixels are used.
        with Image.open(filepath) as img:
            if rotate_clockwise_90:
                # PIL rotates counter-clockwise for positive angles.
                img = img.rotate(-90, expand=True)
            if crop_to_landscape:
                width, height = img.size
                if width / height > desired_aspect_ratio:
                    # Wider than 4:3: crop the width, keep the full height.
                    new_width = int(height * desired_aspect_ratio)
                    left = (width - new_width) // 2
                    box = (left, 0, left + new_width, height)
                else:
                    # Taller than (or exactly) 4:3: crop the height, keep the full width.
                    new_height = int(width / desired_aspect_ratio)
                    top = (height - new_height) // 2
                    box = (0, top, width, top + new_height)
                img = img.crop(box)
            ax.imshow(img)
        ax.axis('off')
    plt.show()

# Train-image display (disabled):
# display_images(filelist_train, 'Train Images')

# Show the selected test images; the rotated / cropped variants are below.
display_images(filelist=filelist_test, title='Test Images')
# display_images(filelist_test, 'Test Images', rotate_clockwise_90=True)
# display_images(filelist_test, 'Test Images', rotate_clockwise_90=True, crop_to_landscape=True)

In [None]:
# skip this
device = torch.device("cuda")
checkpoint_root = "/path/to/checkpoint_root"

# Other checkpoints that were tried (paths relative to checkpoint_root); swap
# one of them into the from_pretrained call below to use it:
#   dust3r_demo_224_longer_epochs/checkpoint-best.pth
#   dust3r_demo_224_multiview/checkpoint-best.pth
#   dust3r_demo_224_multiview_co3d_full/checkpoint-last.pth
#   dust3r_demo_224_multiview_co3d_full_100_epochs_100_samples_per_window/checkpoint-last.pth
#   dust3r_512_dpt_finetune_multiview_co3d_50_seqs_per_cat_subset/checkpoint-best.pth
#   dust3r_512_dpt_finetune_multiview_co3d_50_seqs_per_cat_subset_100_epochs_unfreeze_dec_and_head/checkpoint-last.pth
#   dust3r_512_dpt_finetune_multiview_co3d_50_seqs_per_cat_subset_100_epochs_unfreeze_everything/checkpoint-last.pth
#   dust3r_512_dpt_finetune_multiview_co3d_50_seqs_per_cat_subset_100_epochs_unfreeze_everything_large/checkpoint-last.pth
#   dust3r_512_dpt_finetune_multiview_co3d_50_seqs_per_cat_subset_100_epochs_unfreeze_everything_co3d_scannetpp_megadepth/checkpoint-last.pth
#   bf16_flash_attn_unfreeze_everything_co3d_scannetpp_megadepth_large_bs4/checkpoint-10.pth
#   bf16_flash_attn_unfreeze_everything_co3d_scannetpp_megadepth_large/checkpoint-last.pth
model = FlashDUSt3R.from_pretrained(
    f'{checkpoint_root}/dust3r_512_dpt_finetune_multiview_co3d_50_seqs_per_cat_subset_100_epochs_unfreeze_everything_co3d_scannetpp_megadepth_large/checkpoint-last.pth'
).to(device)

In [None]:
# Lightning model
%load_ext autoreload
%autoreload 2

from lightning.pytorch.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict

# device = torch.device("cuda:2")
device = torch.device("cuda")

# Placeholder root for the commented-out run directories below (absolute, like
# the placeholder used in the FlashDUSt3R cell).
checkpoint_root = "/path/to/checkpoint_root"

# Previous training runs; uncomment exactly one `checkpoint_dir` to evaluate it.
# checkpoint_dir = f"{checkpoint_root}/dust3r/fast3r/logs/train/runs/2024-08-13_04-40-37" #fp32-fancy-sun-181
# checkpoint_dir = f"{checkpoint_root}/dust3r/fast3r/logs/train/runs/2024-08-13_08-06-08" #fp32_workers11_giddy-gorge-182
# checkpoint_dir = f"{checkpoint_root}/dust3r/fast3r/logs/fp32_bs6_views4/runs/fp32_bs6_views4_3782640"
# checkpoint_dir = f"{checkpoint_root}/dust3r/fast3r/logs/fp32_bs2_views8/runs/fp32_bs2_views8_3782638"
# checkpoint_dir = f"{checkpoint_root}/dust3r/fast3r/logs/fp32_bs6_views4/runs/fp32_bs6_views4_4007485" # with random image idx embeddings
# checkpoint_dir = f"{checkpoint_root}/dust3r/fast3r/logs/fp32_bs6_views4/runs/fp32_bs6_views4_4030983" # fix Regr3D loss (wrong rotation)
# checkpoint_dir = f"{checkpoint_root}/dust3r/fast3r/logs/fp32_bs6_views4/runs/fp32_bs6_views4_4037511" # fix Regr3D loss (fixed rotation)
# checkpoint_dir = f"{checkpoint_root}/dust3r/fast3r/logs/fp32_bs6_views4_scannetpp_only/runs/fp32_bs6_views4_scannetpp_only_4060428" # ScanNet++ only no random emb
# checkpoint_dir = f"{checkpoint_root}/dust3r/fast3r/logs/fp32_bs6_views4_scannetpp_only/runs/fp32_bs6_views4_scannetpp_only_4051504" # ScanNet++ only
# checkpoint_dir = f"{checkpoint_root}/dust3r/fast3r/logs/fp32_bs6_views4_arkitscenes_only/runs/arkitscenes_only_4123064" # ARKitScenes only
# checkpoint_dir = f"{checkpoint_root}/dust3r/fast3r/logs/arkitscenes_only_no_pairs/runs/arkitscenes_only_no_pairs_4129400" # ARKitScenes only no pairs
# checkpoint_dir = f"{checkpoint_root}/dust3r/fast3r/logs/co3d_scannetpp/runs/co3d_scannetpp_4123062" # co3d_scannetpp
# checkpoint_dir = f"{checkpoint_root}/dust3r/fast3r/logs/co3d_scannetpp_arkitscenes/runs/co3d_scannetpp_arkitscenes_4123063" # co3d_scannetpp_arkitscenes
# checkpoint_dir = f"{checkpoint_root}/dust3r/fast3r/logs/co3d_scannetpp_arkitscenes_bs2_views8/runs/co3d_scannetpp_arkitscenes_bs2_views8_4155008" # co3d_scannetpp_arkitscenes 8 views
# checkpoint_dir = f"{checkpoint_root}/dust3r/fast3r/logs/co3d_scannetpp_arkitscenes_better_random_pose_emb/runs/co3d_scannetpp_arkitscenes_better_random_pose_emb_4323524" # co3d_scannetpp_arkitscenes 8 views

# checkpoint_dir = f"{checkpoint_root}/dust3r/fast3r/logs/fast3r_co3d_scannetpp_arkitscenes_better_random_pose_emb/runs/fast3r_co3d_scannetpp_arkitscenes_better_random_pose_emb_4365927"
# checkpoint_dir = f"{checkpoint_root}/dust3r/fast3r/logs/fast3r_habitat_larger_decoder_views4/runs/fast3r_habitat_larger_decoder_views4_4383740"
# checkpoint_dir = f"{checkpoint_root}/dust3r/fast3r/logs/fast3r_habitat_larger_decoder_views8/runs/fast3r_habitat_larger_decoder_views8_4383741"
# checkpoint_dir = f"{checkpoint_root}/dust3r/fast3r/logs/fast3r_larger_decoder_bs1_views4/runs/fast3r_larger_decoder_4371625"
# checkpoint_dir = f"{checkpoint_root}/dust3r/fast3r/logs/fast3r_habitat_co3d_scannetpp_arkitscenes/runs/fast3r_habitat_co3d_scannetpp_arkitscenes_4383742"

# local head
# checkpoint_dir = f"{checkpoint_root}/dust3r/fast3r/logs/fast3r_no_local_head/runs/fast3r_no_local_head_4611636"
# checkpoint_dir = f"{checkpoint_root}/dust3r/fast3r/logs/fast3r_no_local_head_habitat/runs/fast3r_no_local_head_habitat_4615832"
# checkpoint_dir = f"{checkpoint_root}/dust3r/fast3r/logs/fast3r_local_head/runs/fast3r_local_head_4638120"
# checkpoint_dir = f"{checkpoint_root}/dust3r/fast3r/logs/fast3r_local_head_habitat/runs/fast3r_local_head_habitat_4626119"

# checkpoint_dir = f"{checkpoint_root}/dust3r/fast3r/logs/fast3r_local_head_habitat_better_scannetpp/runs/fast3r_local_head_habitat_better_scannetpp_4701731"
# checkpoint_dir = f"{checkpoint_root}/dust3r/fast3r/logs/fast3r_local_head_habitat_better_scannetpp_8views/runs/fast3r_local_head_habitat_better_scannetpp_8views_4726417"
# checkpoint_dir = f"{checkpoint_root}/dust3r/fast3r/logs/fast3r_local_head_better_scannetpp_and_arkit_pretrained_16views/runs/fast3r_local_head_better_scannetpp_and_arkit_pretrained_16views_4793676"
# checkpoint_dir = f"{checkpoint_root}/dust3r/fast3r/logs/fast3r_local_head_better_scannetpp_and_arkit_pretrained_20views_small_lr/runs/fast3r_local_head_better_scannetpp_and_arkit_pretrained_20views_small_lr_4804512"

# checkpoint_dir = f"{checkpoint_root}/dust3r/fast3r/logs/more_co3d_finetune_16views/runs/more_co3d_finetune_16views_4865625"
# checkpoint_dir = f"{checkpoint_root}/dust3r/fast3r/logs/more_co3d_finetune_20views/runs/more_co3d_finetune_20views_4867088"

# checkpoint_dir = f"{checkpoint_root}/dust3r/fast3r/logs/super_long_training/runs/super_long_training_5031318"
# checkpoint_dir = f"{checkpoint_root}/dust3r/fast3r/logs/super_long_training/runs/super_long_training_5078043"
# Plain string: this path has no placeholders, so it does not need to be an f-string.
checkpoint_dir = "/data/jianingy/dust3r_data/fast3r_checkpoints/super_long_training_5175604"


print("Creating an empty lightning module to hold the weights...")
# The run's Hydra config (saved alongside the checkpoint) describes the model to build.
cfg = OmegaConf.load(os.path.join(checkpoint_dir, '.hydra/config.yaml'))

# Config strings still use pre-relocation module paths: rewrite "dust3r." to
# "fast3r.dust3r." and "src." to "fast3r." (this is due to relocation of our code).
def _map_config_strings(node, transform):
    """Apply *transform* in place to every string leaf of *node*, recursively.

    *node* may be any mutable mapping (e.g. OmegaConf ``DictConfig`` or
    ``dict``) or mutable sequence (e.g. ``ListConfig`` or ``list``). Nested
    containers of either kind are walked too, so strings inside lists are no
    longer skipped. A leaf is only reassigned when *transform* actually changes
    it. Returns *node*.
    """
    from collections.abc import MutableMapping, MutableSequence

    if isinstance(node, MutableMapping):
        # Snapshot the entries so reassigning values cannot disturb iteration.
        entries = list(node.items())
    elif isinstance(node, MutableSequence):
        entries = list(enumerate(node))
    else:
        return node

    for key, value in entries:
        if isinstance(value, (MutableMapping, MutableSequence)):
            _map_config_strings(value, transform)
        elif isinstance(value, str):
            new_value = transform(value)
            if new_value != value:
                node[key] = new_value
    return node


def replace_dust3r_in_config(cfg):
    """Prefix every "dust3r." occurrence in *cfg*'s strings with "fast3r.".

    Strings that already contain "fast3r.dust3r." are left untouched, so
    running this twice is harmless. Modifies *cfg* in place and returns it.
    """
    def _relocate(value):
        if "dust3r." in value and "fast3r.dust3r." not in value:
            return value.replace("dust3r.", "fast3r.dust3r.")
        return value

    return _map_config_strings(cfg, _relocate)


def replace_src_in_config(cfg_dict):
    """Replace every "src." in *cfg_dict*'s strings with "fast3r.".

    Modifies *cfg_dict* in place and returns it.
    """
    def _relocate(value):
        return value.replace("src.", "fast3r.") if "src." in value else value

    return _map_config_strings(cfg_dict, _relocate)

cfg.model.net = replace_dust3r_in_config(cfg.model.net)
cfg.model = replace_src_in_config(cfg.model)

# Inference-time overrides of the training config.
if "encoder_args" in cfg.model.net:
    cfg.model.net.encoder_args.patch_embed_cls = "PatchEmbedDust3R"
    cfg.model.net.head_args.landscape_only = False
else:
    # TODO: investigate what exactly this does; it seems to allow inference on portrait-orientation images
    cfg.model.net.patch_embed_cls = "PatchEmbedDust3R"
    cfg.model.net.landscape_only = False  # TODO: investigate what exactly this does

# Random image-index embeddings stay enabled; set this to False to try loading
# the model without them.
cfg.model.net.decoder_args.random_image_idx_embedding = True

# Attention biasing (for inferring more views than seen in training) is
# disabled; set this to True to enable it.
cfg.model.net.decoder_args.attn_bias_for_inference_enabled = False

lit_module = hydra.utils.instantiate(cfg.model, train_criterion=None, validation_criterion=None)


print("Loading weights from checkpoint...")

last_ckpt_path = os.path.join(checkpoint_dir, "checkpoints", "last.ckpt")
if os.path.isdir(last_ckpt_path):
    # A directory named last.ckpt is a DeepSpeed ZeRO checkpoint: consolidate it
    # once into a single fp32 checkpoint file and reuse that file afterwards.
    CKPT_PATH = os.path.join(checkpoint_dir, "checkpoints", "last_aggregated.ckpt")
    if not os.path.exists(CKPT_PATH):
        convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir=last_ckpt_path, output_file=CKPT_PATH, tag=None)
else:
    CKPT_PATH = last_ckpt_path

# Reuse the freshly instantiated net/criteria and fill them with the checkpoint weights.
lit_module = MultiViewDUSt3RLitModule.load_from_checkpoint(
    checkpoint_path=CKPT_PATH,
    net=lit_module.net,
    train_criterion=lit_module.train_criterion,
    validation_criterion=lit_module.validation_criterion,
)
lit_module.eval()
model = lit_module.net.to(device)

# model = torch.compile(model)

In [None]:
# Maximum number of views the head processes in parallel (150 was also tried).
model.set_max_parallel_views_for_head(20)

output = get_reconstructed_scene(
    outdir="./output",
    model=model,
    device=device,
    silent=False,
    image_size=512,  # alternative: 224
    filelist=filelist_test,
    profiling=True,
    dtype=torch.float32,  # alternative: torch.bfloat16
)

# Local-to-global alignment (before the fix this used min_conf_thr_percentile=0).
lit_module.align_local_pts3d_to_global(
    preds=output['preds'],
    views=output['views'],
    min_conf_thr_percentile=85,
)

In [None]:
# Camera pose evaluation on RealEstate10K
%load_ext autoreload
%autoreload 2

import os
import glob
import random
import numpy as np
import torch
from tqdm import tqdm
from PIL import Image

from fast3r.dust3r.datasets.utils.transforms import ImgNorm
from fast3r.dust3r.utils.geometry import inv
from fast3r.dust3r.utils.image import imread_cv2
import fast3r.dust3r.datasets.utils.cropping as cropping

# This cell relies on names created in earlier cells:
#   - inference(...)
#   - lit_module, which provides evaluate_camera_poses(...)
#   - model
# crop_resize_if_necessary(...) is defined below.

# Seed Python's, NumPy's and PyTorch's RNGs with the same value for reproducibility.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(42)

def crop_resize_if_necessary(
    image,
    intrinsics_3x3,
    target_resolution=(512, 288),
    rng=None,
    info=None,
):
    """Crop around the principal point, rescale, and update the intrinsics.

    1. Crops symmetrically around the principal point (cx, cy), so that it ends
       up at the center of the crop. If the principal point lies on or outside
       the image border, that crop would be empty, so the image center is used
       instead (with a warning).
    2. Rescales with ``cropping.rescale_image_depthmap`` and center-crops to
       ``target_resolution``. Width and height are swapped when the crop from
       step 1 is taller than it is wide.
    3. Returns the intrinsics updated for both operations.

    No depth map is involved; ``None`` is passed to the cropping helpers in
    its place.

    Args:
        image: PIL.Image or numpy array (H x W x 3).
        intrinsics_3x3: (3, 3) pixel-space intrinsics; entries [0, 2] and
            [1, 2] are read as the principal point (cx, cy).
        target_resolution: (W_target, H_target) used for landscape crops.
        rng: unused; kept for signature compatibility.
        info: optional label (e.g. a frame id) included in warning messages.

    Returns:
        (image_out, intrinsics_out): the final PIL.Image and its 3x3 intrinsics.
    """
    import warnings

    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    W_org, H_org = image.size
    cx, cy = int(round(intrinsics_3x3[0, 2])), int(round(intrinsics_3x3[1, 2]))
    label = info if info is not None else "image"

    # Largest symmetric half-extent around the principal point inside the image.
    min_margin_x = min(cx, W_org - cx)
    min_margin_y = min(cy, H_org - cy)

    if min_margin_x <= 0 or min_margin_y <= 0:
        # The principal point is on or outside the border, so a symmetric crop
        # would be empty. Crop around the image center instead; the crop helper
        # still shifts the principal point in the intrinsics accordingly.
        warnings.warn(
            f"{label}: principal point ({cx}, {cy}) lies outside the {W_org}x{H_org} image; "
            "falling back to a center crop"
        )
        cx, cy = W_org // 2, H_org // 2
        min_margin_x = min(cx, W_org - cx)
        min_margin_y = min(cy, H_org - cy)
    elif min_margin_x < W_org / 5 or min_margin_y < H_org / 5:
        # Still valid, but the symmetric crop discards most of the image.
        warnings.warn(
            f"{label}: principal point ({cx}, {cy}) is far from the center of the "
            f"{W_org}x{H_org} image; the crop will be small"
        )

    # Crop around the principal point, symmetric in x and y.
    crop_bbox = (cx - min_margin_x, cy - min_margin_y, cx + min_margin_x, cy + min_margin_y)
    image_c, _, intrinsics_c = cropping.crop_image_depthmap(image, None, intrinsics_3x3, crop_bbox)

    # Expected size of image_c: (2 * min_margin_x, 2 * min_margin_y).
    W_c, H_c = image_c.size

    # Portrait crop => swap to (H_target, W_target), e.g. 288x512 instead of 512x288.
    if H_c > W_c:
        target_resolution = (target_resolution[1], target_resolution[0])

    # High-quality rescale towards the target resolution...
    image_rs, _, intrinsics_rs = cropping.rescale_image_depthmap(
        image_c, None, intrinsics_c, np.array(target_resolution)
    )

    # ...then a final centered crop to exactly target_resolution.
    intrinsics2 = cropping.camera_matrix_of_crop(
        intrinsics_rs, image_rs.size, target_resolution, offset_factor=0.5
    )
    final_bbox = cropping.bbox_from_intrinsics_in_out(intrinsics_rs, intrinsics2, target_resolution)
    image_out, _, intrinsics_out = cropping.crop_image_depthmap(image_rs, None, intrinsics_rs, final_bbox)

    return image_out, intrinsics_out

re10k_video_root = "/data/jianingy/RealEstate10K/videos/test"
re10k_txt_root = "/data/jianingy/RealEstate10K/test"
# Per-video results are written here as <video_id>.txt; create the folder up
# front so the writes below cannot fail on a fresh machine.
re10k_eval_out_dir = "/home/jianingy/research/fast3r/notebooks/RealEstate10K_eval"
os.makedirs(re10k_eval_out_dir, exist_ok=True)

# video_folders = sorted(os.listdir(re10k_video_root))
# video_folders = ['9414231317ded453']
# video_folders = ['0090cc64d7b7bb24'] # worst scene
video_folders = ['0be9a0dcbfe032f1'] # worst scene

for vid_folder in tqdm(video_folders, desc="Evaluating RealEstate10K Test Videos"):
    folder_path = os.path.join(re10k_video_root, vid_folder)
    if not os.path.isdir(folder_path):
        continue

    txt_path = os.path.join(re10k_txt_root, vid_folder + ".txt")
    if not os.path.exists(txt_path):
        # No camera file for this video => skip.
        continue

    # 1) Map frame id -> whitespace-split columns of its camera line.
    #    The first line of the .txt is the video URL, so it is skipped.
    with open(txt_path, "r") as f:
        txt_lines = f.read().strip().split("\n")
    if len(txt_lines) <= 1:
        continue

    lines_map = {}
    for line in txt_lines[1:]:
        parts = line.strip().split()
        # Columns read below: [0] frame id, [1:5] fx, fy, cx, cy, [7:19] the 12
        # extrinsic values; [5:7] are not used. 1 + 4 + 2 + 12 = 19 columns.
        if len(parts) < 19:
            continue
        lines_map[parts[0]] = parts  # e.g. lines_map["308641667"] = all columns

    # 2) Gather all JPG frames in this folder.
    frame_files = sorted(glob.glob(os.path.join(folder_path, "*.jpg")))
    if len(frame_files) < 2:
        continue

    # 3) Sample up to 10 frames; each is matched to its camera line by basename.
    n_to_sample = min(10, len(frame_files))
    sampled_frames = random.sample(frame_files, n_to_sample)
    # sampled_frames = frame_files[:n_to_sample]

    # 4) Build one "view" dict per sampled frame.
    selected_views = []
    for frame_path in sorted(sampled_frames):
        # "308641667.jpg" -> "308641667"
        basename = os.path.splitext(os.path.basename(frame_path))[0]
        if basename not in lines_map:
            # The .txt may not list every frame (or names may not match) => skip.
            continue
        columns = lines_map[basename]

        # Normalized intrinsics: multiplied by the image width/height below.
        fx, fy, cx, cy = (float(v) for v in columns[1:5])

        # columns[7:19]: a 3x4 matrix in row-major order whose last COLUMN is
        # the translation. It is embedded in a 4x4 matrix and inverted to get
        # the camera-to-world pose.
        extrinsic = np.array([float(v) for v in columns[7:19]], dtype=np.float64).reshape(3, 4)
        pose_4x4 = np.eye(4, dtype=np.float32)
        pose_4x4[:3, :4] = extrinsic
        poses_c2w_gt = inv(pose_4x4)

        # NOTE(review): confirm whether imread_cv2 returns RGB or BGR channel order.
        img_rgb = imread_cv2(frame_path)
        if img_rgb is None:
            continue

        H_org, W_org = img_rgb.shape[:2]

        # RealEstate10K formula: K = [[fx*W, 0, cx*W], [0, fy*H, cy*H], [0, 0, 1]]
        K_3x3 = np.array([
            [fx * W_org, 0.0, cx * W_org],
            [0.0, fy * H_org, cy * H_org],
            [0.0, 0.0, 1.0],
        ], dtype=np.float32)

        pil_img = Image.fromarray(img_rgb)

        # Crop around the principal point and resize to 512x288 (288x512 if portrait).
        final_img_pil, final_intrinsics_3x3 = crop_resize_if_necessary(
            image=pil_img,
            intrinsics_3x3=K_3x3,
            target_resolution=(512, 288),
            rng=np.random,
            info=f"{vid_folder}_{basename}",
        )

        # Normalize to [-1, 1], channel-first (3, H, W).
        tensor_chw = ImgNorm(final_img_pil)

        view_width, view_height = final_img_pil.size
        selected_views.append({
            "img": tensor_chw.unsqueeze(0),  # (1, 3, H, W)
            "camera_pose": torch.from_numpy(poses_c2w_gt).unsqueeze(0),  # (1, 4, 4)
            "camera_intrinsics": torch.from_numpy(final_intrinsics_3x3).unsqueeze(0),  # (1, 3, 3)
            "dataset": ["RealEstate10K"],
            "true_shape": torch.tensor([[view_height, view_width]]),  # (1, 2) = (height, width)
        })

    # Fewer than 2 usable views => nothing to evaluate.
    if len(selected_views) < 2:
        continue

    # 5) Run inference.
    output = inference(
        selected_views,
        model=model,
        device=torch.device("cuda"),
        dtype=torch.float32,
        verbose=False,
        profiling=False,
    )

    # 6) Evaluate camera poses. The call returns a batch of results; the batch
    #    size here is 1, so the first entry is taken.
    cam_pose_result = lit_module.evaluate_camera_poses(
        views=output["views"],
        preds=output["preds"],
        niter_PnP=100,
        focal_length_estimation_method='first_view_from_global_head',
        # focal_length_estimation_method='first_view_from_local_head'
    )[0]

    # Tag the result with its video id and write it as text.
    cam_pose_result["video_name"] = vid_folder
    with open(os.path.join(re10k_eval_out_dir, f"{vid_folder}.txt"), "w") as f:
        f.write(str(cam_pose_result))

    # lit_module.align_local_pts3d_to_global(preds=output['preds'], views=output['views'], min_conf_thr_percentile=85)


print("All done!")

In [None]:
from tqdm import tqdm

# Reconstruct each timestep of the 8-camera juggling sequence and export, per
# timestep, the RGB inputs, confidence maps, a combined point cloud and the
# estimated camera parameters.
num_cameras = 8
num_timesteps = 149

# List every camera folder once, instead of re-listing all of them on each timestep.
juggling_cam_frames = [
    get_sorted_file_list(f"/path/to/juggling_multicam/{cam_idx}/") for cam_idx in range(num_cameras)
]
# Loop-invariant: the maximum number of parallel views for the head.
model.set_max_parallel_views_for_head(150)

for timestep in tqdm(range(num_timesteps)):
    # One frame per camera, taken at the current timestep.
    filelist_test = [cam_frames[timestep] for cam_frames in juggling_cam_frames]

    output = get_reconstructed_scene(
        outdir="./output",
        model=model,
        device=device,
        silent=False,
        image_size=512,  # alternative: 224
        filelist=filelist_test,
        profiling=True,
        dtype=torch.float32,
    )

    # Local-to-global alignment (before the fix this used min_conf_thr_percentile=0).
    lit_module.align_local_pts3d_to_global(preds=output['preds'], views=output['views'], min_conf_thr_percentile=85)

    # Start every timestep from an empty output folder.
    img_output_dir = f"./output/juggling_multicam/{timestep}"
    if os.path.exists(img_output_dir):
        shutil.rmtree(img_output_dir)
    for subdir in ("rgb_images", "global_confidence_maps", "local_depth_and_confidence_maps"):
        os.makedirs(os.path.join(img_output_dir, subdir), exist_ok=True)

    # Plot the RGB images
    plot_rgb_images(output['views'], save_image_to_folder=os.path.join(img_output_dir, "rgb_images"))

    # Plot the confidence maps
    plot_confidence_maps(output['preds'], save_image_to_folder=os.path.join(img_output_dir, "global_confidence_maps"))

    # Plot the local depth and confidence maps
    maybe_plot_local_depth_and_conf(output['preds'], save_image_to_folder=os.path.join(img_output_dir, "local_depth_and_confidence_maps"))

    export_combined_ply(
        preds=output['preds'],
        views=output['views'],
        pts3d_key_to_visualize="pts3d_local_aligned_to_global",
        conf_key_to_visualize="conf_local",
        export_ply_path=os.path.join(img_output_dir, "combined_pointcloud.ply"),
        min_conf_thr_percentile=45,
        flip_axes=True,
        max_num_points=1_000_000,  # maximum number of exported points
        sampling_strategy='uniform',  # 'uniform', 'voxel', or 'farthest_point'
    )

    save_pointmaps_and_camera_parameters_to_folder(preds=output['preds'], save_folder=img_output_dir, niter_PnP=100, focal_length_estimation_method='first_view_from_global_head')


In [None]:
# Bar chart of each view's maximum global confidence, with a reference threshold.
# Imported locally: the module header imports pyplot as `pl`, not `plt`.
import matplotlib.pyplot as plt

CONF_THRESHOLD = 1.5

# Stack per-view confidences along a new view axis; squeeze(1) drops the
# per-view batch axis (assumed to be size 1), giving [num_views, H, W].
conf = torch.stack([pred['conf'] for pred in output['preds']], dim=0).squeeze(1)

# Maximum confidence of each view. reshape (unlike view) also works on
# non-contiguous tensors. .float().cpu() allows .numpy() even for CUDA or
# bfloat16 predictions.
max_conf_per_view = conf.reshape(conf.shape[0], -1).max(dim=1).values.float().cpu()

plt.figure(figsize=(25, 4))
plt.bar(range(len(max_conf_per_view)), max_conf_per_view.numpy())
# Red dashed horizontal line at the threshold.
plt.axhline(y=CONF_THRESHOLD, color='r', linestyle='--')
plt.xlabel('View Index')
plt.ylabel('Max Confidence Score')
plt.title('Max Confidence Score per View')
plt.show()

# Number of views above the threshold, out of all views.
print(f"Number of views with confidence score > {CONF_THRESHOLD}: {torch.sum(max_conf_per_view > CONF_THRESHOLD)} out of {len(max_conf_per_view)}")

# Average and median of the per-view maximum confidence.
print(f"Average confidence score: {torch.mean(max_conf_per_view)}")
print(f"Median confidence score: {torch.median(max_conf_per_view)}")

In [None]:
# %load_ext autoreload
# %autoreload 2

# Export the current `output` into ./output/nrgbd_kitchen and its three
# sub-folders, starting from an empty folder.
img_output_dir = "./output/nrgbd_kitchen"
rgb_dir, global_conf_dir, local_depth_conf_dir = (
    os.path.join(img_output_dir, sub)
    for sub in ("rgb_images", "global_confidence_maps", "local_depth_and_confidence_maps")
)
dirs_to_create = [img_output_dir, rgb_dir, global_conf_dir, local_depth_conf_dir]

if os.path.exists(img_output_dir):
    shutil.rmtree(img_output_dir)
for folder in dirs_to_create:
    os.makedirs(folder, exist_ok=True)

# RGB inputs, global confidence maps, then local depth + confidence maps.
plot_rgb_images(output['views'], save_image_to_folder=rgb_dir)
plot_confidence_maps(output['preds'], save_image_to_folder=global_conf_dir)
maybe_plot_local_depth_and_conf(output['preds'], save_image_to_folder=local_depth_conf_dir)

export_combined_ply(
    preds=output['preds'],
    views=output['views'],
    pts3d_key_to_visualize="pts3d_local_aligned_to_global",
    conf_key_to_visualize="conf_local",
    export_ply_path=os.path.join(img_output_dir, "combined_pointcloud.ply"),
    min_conf_thr_percentile=15,
    flip_axes=True,
    max_num_points=1_000_000,  # maximum number of exported points
    sampling_strategy='uniform',  # 'uniform', 'voxel', or 'farthest_point'
)

save_pointmaps_and_camera_parameters_to_folder(
    preds=output['preds'],
    save_folder=img_output_dir,
    niter_PnP=100,
    focal_length_estimation_method='first_view_from_global_head',
)

# Interactive 3D plot of the points with estimated camera poses (disabled):
# plot_3d_points_with_estimated_camera_poses(
#     output['preds'],               # predictions containing 3D points
#     output['views'],               # views containing RGB images
#     flip_axes=True,                # swap Y and Z and flip Z
#     min_conf_thr_percentile=0,     # confidence percentile for filtering points
#     # export_ply_path='./output/combined_mesh.ply',
#     export_html_path='./output/combined_mesh.html',
# )


In [17]:
# Stop a previously started server.
# NOTE(review): no module-level `server` is defined in the visible code
# (start_visualization keeps its ViserServer in a local variable); this relies
# on a global `server` assigned in another cell — confirm before running.
server.stop()

In [None]:
# viser visualization

import time
import threading
import numpy as np
from tqdm.auto import tqdm
import imageio.v3 as iio
from matplotlib import cm

import viser
import viser.transforms as tf
from fast3r.dust3r.utils.device import to_numpy

def start_visualization(output, min_conf_thr_percentile=10, global_conf_thr_value_to_drop_view=1.5, port=8020):
    """Start an interactive viser viewer that replays the reconstruction view by view.

    For every view, the global (``'pts3d_in_other_view'``) and local
    (``'pts3d_local_aligned_to_global'``) point clouds are sorted by descending
    confidence and precomputed with three colorings: image RGB, turbo-mapped
    confidence, and one rainbow color per view. Camera frustums use the poses
    and focal lengths from ``MultiViewDUSt3RLitModule.estimate_camera_poses``.
    GUI controls cover:
    - playback;
    - point and frustum size;
    - which clouds and which (high/low confidence) views are shown;
    - the coloring;
    - a per-view confidence percentile filter.
    A button renders a GIF of all timesteps.

    Args:
        output: inference result with parallel lists ``output['preds']`` and
            ``output['views']``; batch size 1 is assumed.
        min_conf_thr_percentile: initial value of the per-view filter. That
            percentage of each view's lowest-confidence points is hidden, but
            at least one point is always kept. It is applied at startup.
        global_conf_thr_value_to_drop_view: initial threshold. A view whose
            maximum global confidence is below it counts as low-confidence.
        port: port for the viser server (bound to 127.0.0.1).

    Returns:
        The running ``viser.ViserServer``.

    Raises:
        ValueError: if ``output['preds']`` is empty.
    """
    import colorsys

    num_frames = len(output['preds'])
    if num_frames == 0:
        raise ValueError("output['preds'] is empty; nothing to visualize.")

    # Create the viser server on the specified port
    server = viser.ViserServer(host='127.0.0.1', port=port)

    # Estimate camera-to-world poses and focal lengths for all views
    poses_c2w_batch, estimated_focals = MultiViewDUSt3RLitModule.estimate_camera_poses(
        output['preds'], niter_PnP=100, focal_length_estimation_method='first_view_from_global_head'
    )
    poses_c2w = poses_c2w_batch[0]  # Assuming batch size of 1

    # Set the upward direction to negative Y-axis
    server.scene.set_up_direction((0.0, -1.0, 0.0))
    server.scene.world_axes.visible = False  # Optional: Hide world axes

    # Per-view precomputed arrays and, later, the scene-node handles
    frame_data_list = []

    def rainbow_color(n, total):
        """Return the fully saturated RGB color (components in [0, 1]) with hue n / total."""
        return colorsys.hsv_to_rgb(n / total, 1.0, 1.0)

    def to_uint8_rgb(img):
        """Map values from [-1, 1] to uint8 [0, 255].

        Out-of-range values are clipped first, because casting them directly
        to uint8 is not well-defined.
        """
        return np.clip((img + 1) * 127.5, 0, 255).astype(np.uint8)

    # Add playback UI
    with server.gui.add_folder("Playback"):
        gui_point_size = server.gui.add_slider("Point size", min=0.000001, max=0.002, step=1e-5, initial_value=0.0005)
        gui_frustum_size_percent = server.gui.add_slider("Camera Size (%)", min=0.1, max=10.0, step=0.1, initial_value=2.0)
        gui_timestep = server.gui.add_slider("Timestep", min=0, max=num_frames - 1, step=1, initial_value=0, disabled=True)
        gui_next_frame = server.gui.add_button("Next Frame", disabled=True)
        gui_prev_frame = server.gui.add_button("Prev Frame", disabled=True)
        gui_playing = server.gui.add_checkbox("Playing", True)
        gui_framerate = server.gui.add_slider("FPS", min=0.25, max=60, step=0.25, initial_value=10)
        gui_framerate_options = server.gui.add_button_group("FPS options", ("0.5", "1", "10", "20", "30", "60"))

    # Add point cloud options UI
    with server.gui.add_folder("Point Cloud Options"):
        gui_show_global = server.gui.add_checkbox("Global", False)
        gui_show_local = server.gui.add_checkbox("Local", True)

    # Add view options UI
    with server.gui.add_folder("View Options"):
        gui_show_high_conf = server.gui.add_checkbox("Show High-Conf Views", True)
        gui_show_low_conf = server.gui.add_checkbox("Show Low-Conf Views", False)
        gui_global_conf_threshold = server.gui.add_slider("High/Low Conf threshold value", min=1.0, max=12.0, step=0.1, initial_value=global_conf_thr_value_to_drop_view)
        gui_min_conf_percentile = server.gui.add_slider("Per-View conf percentile", min=0, max=100, step=1, initial_value=min_conf_thr_percentile)

    # Add color options UI
    with server.gui.add_folder("Color Options"):
        gui_show_confidence = server.gui.add_checkbox("Show Confidence", False)
        gui_rainbow_color = server.gui.add_checkbox("Rainbow Colors", False)

    button_render_gif = server.gui.add_button("Render a GIF")

    # Frame step buttons
    @gui_next_frame.on_click
    def _(_) -> None:
        gui_timestep.value = (gui_timestep.value + 1) % num_frames

    @gui_prev_frame.on_click
    def _(_) -> None:
        gui_timestep.value = (gui_timestep.value - 1) % num_frames

    # Disable frame controls when we're playing
    @gui_playing.on_update
    def _(_) -> None:
        gui_timestep.disabled = gui_playing.value
        gui_next_frame.disabled = gui_playing.value
        gui_prev_frame.disabled = gui_playing.value

    # Set the framerate when we click one of the options
    @gui_framerate_options.on_click
    def _(_) -> None:
        gui_framerate.value = float(gui_framerate_options.value)

    server.scene.add_frame("/cams", show_axes=False)

    # First pass: precompute per-view data and collect points for the scene extent
    cumulative_pts = []
    for i in tqdm(range(num_frames)):
        pred = output['preds'][i]
        view = output['views'][i]

        # Extract and flatten global/local points, confidences and colors
        pts3d_global = to_numpy(pred['pts3d_in_other_view'].cpu().squeeze()).reshape(-1, 3)
        conf_global = to_numpy(pred['conf'].cpu().squeeze()).flatten()
        pts3d_local = to_numpy(pred['pts3d_local_aligned_to_global'].cpu().squeeze()).reshape(-1, 3)
        conf_local = to_numpy(pred['conf_local'].cpu().squeeze()).flatten()
        img_rgb = to_numpy(view['img'].cpu().squeeze().permute(1, 2, 0)).reshape(-1, 3)
        height, width = view['img'].shape[2], view['img'].shape[3]

        cumulative_pts.append(pts3d_global)

        # Sort by descending confidence so that keeping the first k points
        # keeps the k most confident ones (used by the percentile filter).
        sort_indices_global = np.argsort(-conf_global)
        sort_indices_local = np.argsort(-conf_local)
        sorted_conf_global = conf_global[sort_indices_global]
        sorted_conf_local = conf_local[sort_indices_local]
        sorted_pts3d_global = pts3d_global[sort_indices_global]
        sorted_pts3d_local = pts3d_local[sort_indices_local]

        # Image colors in [0, 1]
        colors_rgb_global = to_uint8_rgb(img_rgb[sort_indices_global]) / 255.0
        colors_rgb_local = to_uint8_rgb(img_rgb[sort_indices_local]) / 255.0

        # Confidence colors: per-view min-max normalization, then the turbo colormap
        conf_norm_global = (sorted_conf_global - sorted_conf_global.min()) / (sorted_conf_global.max() - sorted_conf_global.min() + 1e-8)
        conf_norm_local = (sorted_conf_local - sorted_conf_local.min()) / (sorted_conf_local.max() - sorted_conf_local.min() + 1e-8)
        colors_confidence_global = cm.turbo(conf_norm_global)[:, :3]
        colors_confidence_local = cm.turbo(conf_norm_local)[:, :3]

        # One rainbow color per view
        rainbow_color_for_frame = rainbow_color(i, num_frames)
        colors_rainbow_global = np.tile(rainbow_color_for_frame, (sorted_pts3d_global.shape[0], 1))
        colors_rainbow_local = np.tile(rainbow_color_for_frame, (sorted_pts3d_local.shape[0], 1))

        # A view is "high confidence" if its best global confidence reaches the threshold
        max_conf_global = conf_global.max()

        frame_data_list.append({
            'sorted_pts3d_global': sorted_pts3d_global,
            'colors_rgb_global': colors_rgb_global,
            'colors_confidence_global': colors_confidence_global,
            'colors_rainbow_global': colors_rainbow_global,
            'sorted_pts3d_local': sorted_pts3d_local,
            'colors_rgb_local': colors_rgb_local,
            'colors_confidence_local': colors_confidence_local,
            'colors_rainbow_local': colors_rainbow_local,
            'max_conf_global': max_conf_global,
            'is_high_confidence': max_conf_global >= gui_global_conf_threshold.value,
            'c2w': poses_c2w[i],
            'height': height,
            'width': width,
            'focal_length': estimated_focals[0][i],
            # Frustum thumbnail: uint8 image downsampled 4x in each direction
            'img_downsampled': to_uint8_rgb(img_rgb.reshape(height, width, 3))[::4, ::4],
            'rainbow_color': rainbow_color_for_frame,
        })

    # Scene extent (from all global points) sets the frustum size scale
    cumulative_pts_combined = np.concatenate(cumulative_pts, axis=0)
    scene_extent = np.max(cumulative_pts_combined, axis=0) - np.min(cumulative_pts_combined, axis=0)
    max_extent = np.max(scene_extent)

    # Second pass: create the scene nodes
    for i in tqdm(range(num_frames)):
        frame_data = frame_data_list[i]

        frame_node = server.scene.add_frame(f"/cams/t{i}", show_axes=False)

        point_node_global = server.scene.add_point_cloud(
            name=f"/pts3d_global/t{i}",
            points=frame_data['sorted_pts3d_global'],
            colors=frame_data['colors_rgb_global'],
            point_size=gui_point_size.value,
            point_shape="rounded",
            visible=False,
        )

        point_node_local = server.scene.add_point_cloud(
            name=f"/pts3d_local/t{i}",
            points=frame_data['sorted_pts3d_local'],
            colors=frame_data['colors_rgb_local'],
            point_size=gui_point_size.value,
            point_shape="rounded",
            visible=bool(frame_data['is_high_confidence']),
        )

        # Frustum from the camera-to-world pose and estimated focal length
        c2w = frame_data['c2w']
        rotation_quaternion = tf.SO3.from_matrix(c2w[:3, :3]).wxyz
        position = c2w[:3, 3]
        fov = 2 * np.arctan2(frame_data['height'] / 2, frame_data['focal_length'])
        aspect_ratio = frame_data['width'] / frame_data['height']
        frustum_scale = max_extent * (gui_frustum_size_percent.value / 100.0)

        frustum_node = server.scene.add_camera_frustum(
            name=f"/cams/t{i}/frustum",
            fov=fov,
            aspect=aspect_ratio,
            scale=frustum_scale,
            color=frame_data['rainbow_color'],
            image=frame_data['img_downsampled'],
            wxyz=rotation_quaternion,
            position=position,
            visible=bool(frame_data['is_high_confidence']),
        )

        frame_data['frame_node'] = frame_node
        frame_data['point_node_global'] = point_node_global
        frame_data['point_node_local'] = point_node_local
        frame_data['frustum_node'] = frustum_node

    # Start fully hidden; update_visibility() below derives the real state from the GUI
    for frame_data in frame_data_list:
        frame_data['frame_node'].visible = False
        frame_data['point_node_global'].visible = False
        frame_data['point_node_local'].visible = False
        frame_data['frustum_node'].visible = False

    def update_visibility():
        """Show views 0..timestep that pass the high/low-confidence filters; hide the rest."""
        current_timestep = int(gui_timestep.value)
        with server.atomic():
            for i, frame_data in enumerate(frame_data_list):
                show_frame = False
                if i <= current_timestep:
                    is_high_confidence = frame_data['is_high_confidence']
                    show_frame = (
                        (is_high_confidence and gui_show_high_conf.value)
                        or (not is_high_confidence and gui_show_low_conf.value)
                    )
                frame_data['frame_node'].visible = show_frame
                frame_data['frustum_node'].visible = show_frame
                frame_data['point_node_global'].visible = show_frame and gui_show_global.value
                frame_data['point_node_local'].visible = show_frame and gui_show_local.value
        server.flush()

    def update_point_cloud_colors():
        """Push the selected coloring, truncated to the points currently shown."""
        with server.atomic():
            for frame_data in frame_data_list:
                num_points_to_show_global = frame_data.get('num_points_to_show_global', len(frame_data['sorted_pts3d_global']))
                num_points_to_show_local = frame_data.get('num_points_to_show_local', len(frame_data['sorted_pts3d_local']))

                # Confidence coloring takes precedence over rainbow, which takes precedence over RGB
                if gui_show_confidence.value:
                    scheme = 'confidence'
                elif gui_rainbow_color.value:
                    scheme = 'rainbow'
                else:
                    scheme = 'rgb'
                frame_data['point_node_global'].colors = frame_data[f'colors_{scheme}_global'][:num_points_to_show_global]
                frame_data['point_node_local'].colors = frame_data[f'colors_{scheme}_local'][:num_points_to_show_local]
        server.flush()

    def apply_conf_percentile():
        """Keep only the top (100 - percentile)% most confident points of each view (at least one)."""
        percentile = gui_min_conf_percentile.value
        with server.atomic():
            for frame_data in frame_data_list:
                for kind in ('global', 'local'):
                    sorted_pts = frame_data[f'sorted_pts3d_{kind}']
                    num_points_to_show = max(1, int(len(sorted_pts) * (100 - percentile) / 100))
                    frame_data[f'num_points_to_show_{kind}'] = num_points_to_show
                    frame_data[f'point_node_{kind}'].points = sorted_pts[:num_points_to_show]
            # Colors must be truncated to the same number of points
            update_point_cloud_colors()
        server.flush()

    @gui_timestep.on_update
    def _(_) -> None:
        update_visibility()

    @gui_point_size.on_update
    def _(_) -> None:
        with server.atomic():
            for frame_data in frame_data_list:
                frame_data['point_node_global'].point_size = gui_point_size.value
                frame_data['point_node_local'].point_size = gui_point_size.value
        server.flush()

    @gui_frustum_size_percent.on_update
    def _(_) -> None:
        frustum_scale = max_extent * (gui_frustum_size_percent.value / 100.0)
        with server.atomic():
            for frame_data in frame_data_list:
                frame_data['frustum_node'].scale = frustum_scale
        server.flush()

    @gui_show_confidence.on_update
    def _(_) -> None:
        update_point_cloud_colors()

    @gui_rainbow_color.on_update
    def _(_) -> None:
        update_point_cloud_colors()

    @gui_show_global.on_update
    def _(_) -> None:
        update_visibility()

    @gui_show_local.on_update
    def _(_) -> None:
        update_visibility()

    @gui_show_high_conf.on_update
    def _(_) -> None:
        update_visibility()

    @gui_show_low_conf.on_update
    def _(_) -> None:
        update_visibility()

    @gui_global_conf_threshold.on_update
    def _(_) -> None:
        # Re-classify views against the new threshold
        for frame_data in frame_data_list:
            frame_data['is_high_confidence'] = frame_data['max_conf_global'] >= gui_global_conf_threshold.value
        update_visibility()

    @gui_min_conf_percentile.on_update
    def _(_) -> None:
        apply_conf_percentile()

    # Apply the initial per-view percentile and visibility. Previously the
    # min_conf_thr_percentile argument only took effect once the slider was moved.
    apply_conf_percentile()
    update_visibility()

    def playback_loop():
        """Advance the timestep at the GUI framerate while 'Playing' is checked (never returns)."""
        while True:
            if gui_playing.value:
                gui_timestep.value = (int(gui_timestep.value) + 1) % num_frames
            time.sleep(1.0 / gui_framerate.value)

    # daemon=True: the loop never returns, so a non-daemon thread would block
    # interpreter shutdown.
    playback_thread = threading.Thread(target=playback_loop, daemon=True)
    playback_thread.start()

    @button_render_gif.on_click
    def _(event: viser.GuiEvent) -> None:
        client = event.client
        if client is None:
            print("Error: No client connected.")
            return
        original_timestep = gui_timestep.value
        original_playing = gui_playing.value
        try:
            images = []
            gui_playing.value = False
            fps = gui_framerate.value
            for i in range(num_frames):
                gui_timestep.value = i
                time.sleep(0.1)  # presumably gives the client time to apply the visibility update
                images.append(client.get_render(height=720, width=1280))
            gif_bytes = iio.imwrite("", images, extension=".gif", fps=fps, loop=0)
            client.send_file_download("visualization.gif", gif_bytes)
        except Exception as e:
            print(f"Error while rendering GIF: {e}")
        finally:
            # Restore the playback state even if rendering failed part-way
            gui_timestep.value = original_timestep
            gui_playing.value = original_playing

    print(f"Visualization setup complete. Access the viser server at http://localhost:{port}")
    public_url = server.request_share_url()
    print(f"Public URL: {public_url}")
    return server

# Start the visualization server
# Launch the viewer; keep the returned handle so a later cell can call server.stop().
server = start_visualization(
    output,
    port=8020,
    global_conf_thr_value_to_drop_view=1.5,
    min_conf_thr_percentile=10,
)


In [None]:
# Shut down the viser server returned by start_visualization above.
server.stop()

In [None]:
# save the images to jpgs
import os

from PIL import Image  # imported here so this cell does not depend on a later cell's import

os.makedirs("./output", exist_ok=True)
for i, view in enumerate(output['views']):
    # First batch element; permute CHW -> HWC. Values presumably lie in [-1, 1].
    img = view['img'][0].permute(1, 2, 0).cpu().numpy()
    # Clip before the uint8 cast so out-of-range values saturate instead of wrapping
    img = np.clip((img + 1) * 127.5, 0, 255).astype(np.uint8)
    Image.fromarray(img).save(f"./output/img_{i}.jpg")

In [None]:
import os

# Select EGL for headless offscreen rendering. PyOpenGL chooses its platform
# backend when it is first imported, so this must be set *before* pyrender
# (which imports PyOpenGL) is imported. It has no effect if pyrender/OpenGL
# was already imported earlier in the same kernel.
os.environ["PYOPENGL_PLATFORM"] = "egl"

import numpy as np
from PIL import Image
from fast3r.dust3r.utils.device import to_numpy
import pyrender

def create_camera_pose(camera_position, target_point, up_vector):
    """
    Create a camera pose matrix (camera-to-world) that positions the camera at camera_position
    and orients it to look at target_point.

    The returned 4x4 matrix has columns [right, up, forward, position]. The
    camera's local +Z axis therefore points from camera_position towards
    target_point.

    Args:
        camera_position: (3,) camera center in world coordinates.
        target_point: (3,) world point the camera looks at.
        up_vector: (3,) approximate world up direction. If it is (nearly)
            parallel to the viewing direction, a world axis that is not
            parallel is used instead.

    Returns:
        (4, 4) float ndarray.

    Raises:
        ValueError: if camera_position and target_point coincide.
    """
    # Work in float so the normalizations below also work for integer inputs.
    camera_position = np.asarray(camera_position, dtype=float)
    target_point = np.asarray(target_point, dtype=float)
    up_vector = np.asarray(up_vector, dtype=float)

    # Compute forward vector (from camera to target)
    forward_vector = target_point - camera_position
    forward_norm = np.linalg.norm(forward_vector)
    if forward_norm < 1e-12:
        raise ValueError("camera_position and target_point must differ.")
    forward_vector = forward_vector / forward_norm

    # Compute right and up vectors
    right_vector = np.cross(up_vector, forward_vector)
    if np.linalg.norm(right_vector) < 1e-6:
        # up_vector is parallel to forward_vector. Prefer Z (or X when up is +Y),
        # but fall back to the other axis if the preferred one is also parallel
        # to forward_vector. A unit vector cannot be parallel to both.
        z_axis = np.array([0.0, 0.0, 1.0])
        x_axis = np.array([1.0, 0.0, 0.0])
        candidates = (z_axis, x_axis) if up_vector[1] != 1 else (x_axis, z_axis)
        for candidate in candidates:
            right_vector = np.cross(candidate, forward_vector)
            if np.linalg.norm(right_vector) >= 1e-6:
                break
    right_vector = right_vector / np.linalg.norm(right_vector)
    up_vector = np.cross(forward_vector, right_vector)

    # Construct the camera-to-world matrix
    camera_pose = np.eye(4)
    camera_pose[:3, 0] = right_vector
    camera_pose[:3, 1] = up_vector
    camera_pose[:3, 2] = forward_vector
    camera_pose[:3, 3] = camera_position
    return camera_pose

def convert_c2w_to_opengl_view(c2w):
    """
    Turn a 4x4 camera-to-world matrix into the matrix used here as an OpenGL view.

    The world-to-camera matrix ``inv(c2w)`` is right-multiplied by
    ``diag(1, -1, -1, 1)``, which negates its second and third columns.

    NOTE(review): the usual OpenCV -> OpenGL view conversion left-multiplies
    (``diag(1, -1, -1, 1) @ inv(c2w)``). Right-multiplying flips the world Y/Z
    axes instead. The caller in this notebook also passes the result to pyrender
    as a camera *pose*. Confirm which convention is intended.
    """
    axis_flip = np.diag([1, -1, -1, 1])
    view_matrix = np.linalg.inv(c2w)
    return view_matrix @ axis_flip

def render_cumulative_pts3d_viz(preds, views, output_dir='./output', point_size=5.0, min_conf_thr_percentile=0):
    """Render one PNG per view showing the global point cloud accumulated so far.

    Frame ``i`` contains the confidence-filtered ``pts3d_in_other_view`` points
    of views ``0..i``. They are seen from a single fixed camera placed
    ``2 * max_extent`` along +Z from the centroid of all kept points and aimed
    at that centroid.

    Args:
        preds: per-view predictions with ``'pts3d_in_other_view'`` and ``'conf'`` tensors.
        views: per-view inputs whose ``'img'`` tensor is mapped to colors via
            ``(x + 1) * 127.5``, i.e. it is presumably normalized to [-1, 1].
        output_dir: directory for the ``cumulative_XXX.png`` files (created if missing).
        point_size: point size passed to ``pyrender.OffscreenRenderer``.
        min_conf_thr_percentile: per-view confidence percentile. Only points
            whose confidence is strictly above it are kept.
    """
    os.makedirs(output_dir, exist_ok=True)

    cumulative_pts = []
    cumulative_colors = []

    # Filter and color each view's points; these also determine the scene extent
    for i, pred in enumerate(preds):
        pts3d = to_numpy(pred['pts3d_in_other_view'].cpu().squeeze()).reshape(-1, 3)
        img_rgb = to_numpy(views[i]['img'].cpu().squeeze().permute(1, 2, 0)).reshape(-1, 3)

        conf = to_numpy(pred['conf'].cpu().squeeze()).flatten()
        conf_thr = np.percentile(conf, min_conf_thr_percentile)
        mask = conf > conf_thr

        cumulative_pts.append(pts3d[mask])
        # Clip before the uint8 cast so out-of-range values saturate instead of wrapping
        cumulative_colors.append(np.clip((img_rgb[mask] + 1) * 127.5, 0, 255).astype(np.uint8))

    cumulative_pts_combined = np.concatenate(cumulative_pts, axis=0)

    # Verify that we have valid points
    if cumulative_pts_combined.shape[0] == 0:
        print("No points to render. Exiting.")
        return

    # Compute the center and extents of the cumulative point cloud
    point_cloud_center = np.mean(cumulative_pts_combined, axis=0)
    min_coords = np.min(cumulative_pts_combined, axis=0)
    max_coords = np.max(cumulative_pts_combined, axis=0)
    scene_extent = max_coords - min_coords
    max_extent = np.max(scene_extent)

    # Debug: Print point cloud stats
    print(f"Point cloud center: {point_cloud_center}")
    print(f"Scene extents: {scene_extent}")
    print(f"Max extent: {max_extent}")

    # Fixed camera 2 * max_extent along +Z from the centroid, looking at it
    camera_distance = max_extent * 2  # Adjust multiplier as needed
    camera_position = point_cloud_center + np.array([0, 0, camera_distance])
    up_vector = np.array([0, 1, 0])
    camera_pose = create_camera_pose(camera_position, point_cloud_center, up_vector=up_vector)
    opengl_camera_pose = convert_c2w_to_opengl_view(camera_pose)

    print("Using stationary bird's eye view camera pose.")

    # One scene that grows by one view per frame. After adding view i it holds
    # exactly views 0..i, so earlier points are never re-concatenated.
    pyrender_scene = pyrender.Scene()
    camera = pyrender.PerspectiveCamera(yfov=np.pi / 3.0, aspectRatio=16/9)
    pyrender_scene.add(camera, pose=opengl_camera_pose)
    light = pyrender.DirectionalLight(color=np.ones(3), intensity=3.0)
    pyrender_scene.add(light, pose=opengl_camera_pose)

    # A single renderer (GL context) for all frames, always released
    renderer = pyrender.OffscreenRenderer(viewport_width=1920, viewport_height=1080, point_size=point_size)
    try:
        for i, (pts, colors) in enumerate(zip(cumulative_pts, cumulative_colors)):
            print(f"Rendering frame {i}...")
            if pts.shape[0] > 0:
                # Skip views whose points were all filtered out; an empty
                # point primitive presumably cannot be built.
                pyrender_scene.add(pyrender.Mesh.from_points(pts, colors=colors))
            color, _ = renderer.render(pyrender_scene)

            frame_filename = os.path.join(output_dir, f'cumulative_{i:03d}.png')
            Image.fromarray(color).save(frame_filename)
            print(f"Frame {i} saved as {frame_filename}")
    finally:
        renderer.delete()

    print("Rendering complete. Frames saved as PNG files.")

# Run the rendering function
render_cumulative_pts3d_viz(
 output['preds'],
 output['views'],
 output_dir='./output',
 point_size=5.0,
 min_conf_thr_percentile=0
)


In [None]:
import os

# PyOpenGL chooses its platform backend when it is first imported, so select
# EGL (headless rendering) *before* importing pyrender. This has no effect if
# pyrender/OpenGL was already imported earlier in the same kernel.
os.environ["PYOPENGL_PLATFORM"] = "egl"

import matplotlib.pyplot as plt
import numpy as np
import pyrender
import trimesh

# Load the FUZE bottle trimesh and put it in a scene
fuze_trimesh = trimesh.load('output/fuze.obj')
mesh = pyrender.Mesh.from_trimesh(fuze_trimesh)
scene = pyrender.Scene()
scene.add(mesh)

# Set up the camera -- z-axis away from the scene, x-axis right, y-axis up
camera = pyrender.PerspectiveCamera(yfov=np.pi / 3.0)
s = np.sqrt(2) / 2
camera_pose = np.array([
    [0.0, -s, s, 0.3],
    [1.0, 0.0, 0.0, 0.0],
    [0.0, s, s, 0.35],
    [0.0, 0.0, 0.0, 1.0],
])
scene.add(camera, pose=camera_pose)

# Set up the light -- a single spot light in the same spot as the camera
light = pyrender.SpotLight(color=np.ones(3), intensity=3.0,
                           innerConeAngle=np.pi / 16.0, outerConeAngle=np.pi / 6.0)
scene.add(light, pose=camera_pose)

# Render offscreen. The GL context is released even if rendering fails;
# color/depth are plain arrays, so they stay usable afterwards.
r = pyrender.OffscreenRenderer(640, 480)
try:
    color, depth = r.render(scene)
finally:
    r.delete()

# Display the images
plt.figure()
plt.subplot(1, 2, 1)
plt.axis('off')
plt.imshow(color)
plt.title("Color")

plt.subplot(1, 2, 2)
plt.axis('off')
plt.imshow(depth, cmap=plt.cm.gray_r)
plt.title("Depth")

plt.show()


In [None]:

# Plot the RGB images (not saved to disk: no save_image_to_folder is passed)
plot_rgb_images(output['views'])

# Plot the confidence maps
plot_confidence_maps(output['preds'])

# Plot the 3D points as a colored point cloud (as_mesh=False) and export them to
# ./output/combined_mesh.ply. NOTE(review): min_conf_thr_percentile=30 presumably
# drops each view's lowest-confidence 30% of points — confirm in plot_3d_points_with_colors.
plot_3d_points_with_colors(output['preds'], output['views'], flip_axes=True, as_mesh=False, min_conf_thr_percentile=30, export_ply_path='./output/combined_mesh.ply')


In [None]:
import numpy as np
import plotly.graph_objects as go
from scipy.linalg import rq

from tqdm import tqdm

def estimate_camera_matrix(world_points, image_points):
    """
    Estimate the camera matrix from 3D world points and 2D image points using DLT.

    Each correspondence contributes two rows to the homogeneous system ``A p = 0``.
    ``p`` (the 12 entries of P, row-major) is the right singular vector of ``A``
    with the smallest singular value. P is therefore only determined up to
    scale (including sign), and at least 6 non-degenerate points are needed.

    Parameters:
    world_points (np.ndarray): Array of 3D points in the world coordinates, shape (N, 3).
    image_points (np.ndarray): Array of 2D points in the image coordinates, shape (N, 2),
        ordered as (u, v) = (x, y), i.e. (column, row).

    Returns:
    np.ndarray: The 3x4 camera matrix.

    Raises:
    ValueError: if world_points and image_points have different lengths.
    """
    world_points = np.asarray(world_points, dtype=float)
    image_points = np.asarray(image_points, dtype=float)
    if world_points.shape[0] != image_points.shape[0]:
        raise ValueError("Number of points must match")
    num_points = world_points.shape[0]

    X, Y, Z = world_points[:, 0], world_points[:, 1], world_points[:, 2]
    u, v = image_points[:, 0], image_points[:, 1]
    ones = np.ones(num_points)
    zeros = np.zeros(num_points)

    # The two equations per point, built without a Python loop:
    #   [X, Y, Z, 1, 0, 0, 0, 0, -uX, -uY, -uZ, -u]
    #   [0, 0, 0, 0, X, Y, Z, 1, -vX, -vY, -vZ, -v]
    A = np.empty((2 * num_points, 12))
    A[0::2] = np.stack([X, Y, Z, ones, zeros, zeros, zeros, zeros, -u * X, -u * Y, -u * Z, -u], axis=1)
    A[1::2] = np.stack([zeros, zeros, zeros, zeros, X, Y, Z, ones, -v * X, -v * Y, -v * Z, -v], axis=1)

    # Only Vt is needed. With full_matrices=True numpy also builds the
    # (2N x 2N) U factor, which takes many GB for tens of thousands of points.
    # The full form is kept only when 2N < 12, where the reduced Vt would lack
    # the null-space rows.
    _, _, Vt = np.linalg.svd(A, full_matrices=A.shape[0] < A.shape[1])

    # The last row of Vt (smallest singular value) is the solution
    P = Vt[-1].reshape(3, 4)

    return P

def decompose_camera_matrix(P):
    """
    Decompose the camera matrix into intrinsic and extrinsic matrices.

    P is only defined up to a (possibly negative) scale. The result is
    normalized so that:
    - K is upper-triangular with a positive diagonal and K[2, 2] = 1;
    - R is a proper rotation (det(R) = +1);
    - t satisfies P ~ K @ [R | t].

    Parameters:
    P (np.ndarray): The 3x4 camera matrix.

    Returns:
    K (np.ndarray): The 3x3 intrinsic matrix.
    R (np.ndarray): The 3x3 rotation matrix.
    t (np.ndarray): The translation vector, shape (3,).
    """
    P = np.asarray(P, dtype=float)
    M = P[:, :3]  # The first 3x3 part of P

    # RQ Decomposition of M
    K, R = rq(M)

    # RQ is unique only up to the signs of K's diagonal. Force them positive;
    # T @ T = I, so K @ R is unchanged. np.where avoids a zero sign for 0.
    T = np.diag(np.where(np.diag(K) < 0, -1.0, 1.0))
    K = K @ T
    R = T @ R

    # t must be computed with the *unnormalized* K (P[:, 3] = K @ t). Using
    # K / K[2, 2] here would scale t by K[2, 2].
    t = np.linalg.solve(K, P[:, 3])

    # A negative overall scale of P shows up as det(R) = -1. Flipping R and t
    # together only negates P, which describes the same camera.
    if np.linalg.det(R) < 0:
        R = -R
        t = -t

    # Normalize K so that K[2,2] = 1
    K = K / K[2, 2]

    return K, R, t

def plot_camera_cones(fig, R, t, K, color='blue', scale=0.1):
    """
    Add one plotly cone marking a camera to ``fig``.

    The cone is anchored (``anchor="tip"``) at the camera center ``C = -R^T t``.
    Its vector is the camera's +Z axis in world coordinates, ``R^T [0, 0, 1]``,
    multiplied by ``K[0, 0] / K[2, 2]`` (the x focal length; fx and fy are
    treated as equal).

    NOTE(review): with anchor="tip", plotly presumably puts the cone's tip at the
    anchor with the body extending back along -vector, i.e. behind the camera
    rather than in its viewing direction. Confirm this is the intended look.

    Parameters:
    fig (plotly.graph_objects.Figure): The existing Plotly figure.
    R (np.ndarray): The 3x3 rotation matrix.
    t (np.ndarray): The 3x1 translation vector.
    K (np.ndarray): The 3x3 intrinsic matrix.
    color (str): Color of the camera cone.
    scale (float): Passed to plotly as the cone ``sizeref``.
    """
    fx = K[0, 0] / K[2, 2]
    center = -R.T @ t
    view_axis = fx * (R.T @ np.array([0, 0, 1]))

    cone = go.Cone(
        x=[center[0]],
        y=[center[1]],
        z=[center[2]],
        u=[view_axis[0]],
        v=[view_axis[1]],
        w=[view_axis[2]],
        colorscale=[[0, color], [1, color]],  # Single color for the cone
        showscale=False,
        sizemode="absolute",
        sizeref=scale,
        anchor="tip",
        name="Camera Cone",
    )
    fig.add_trace(cone)

def plot_3d_points_with_estimated_camera(output, fig, camera_poses, min_conf_thr_percentile=80):
    """
    Plot 3D points together with estimated camera cones in the same plot.

    For each view, the ``'pts3d_in_other_view'`` points whose confidence is
    strictly above the given percentile are added as one Scatter3d trace,
    colored by the view's image. Each (R, t, K) pose is then drawn with
    plot_camera_cones.

    Parameters:
    output (dict): Inference output with parallel lists 'preds' (3D points and
        confidences) and 'views' (images, presumably normalized to [-1, 1]).
    fig (plotly.graph_objects.Figure): The existing 3D plot.
    camera_poses (list): List of (R, t, K) tuples as returned by estimate_camera_poses.
    min_conf_thr_percentile (int): Percentile threshold for confidence values to filter points.
    """
    for i, pred in enumerate(output['preds']):
        pts3d = pred['pts3d_in_other_view'].cpu().numpy().squeeze()  # 3D points
        img_rgb = output['views'][i]['img'].cpu().numpy().squeeze().transpose(1, 2, 0)  # HWC image
        conf = pred['conf'].cpu().numpy().squeeze()  # Confidence map

        # Apply confidence threshold
        conf_thr = np.percentile(conf, min_conf_thr_percentile)
        keep = (conf > conf_thr).ravel()

        # Rescale RGB from [-1, 1] to [0, 255]. Clip *before* the uint8 cast;
        # clipping after it (as before) cannot undo an out-of-range cast.
        img_rgb = np.clip((img_rgb + 1) * 127.5, 0, 255).astype(np.uint8)

        pts = pts3d.reshape(-1, 3)[keep]
        rgb = img_rgb.reshape(-1, 3)[keep]
        colors = ['rgb({}, {}, {})'.format(r, g, b) for r, g, b in rgb]

        fig.add_trace(go.Scatter3d(
            x=pts[:, 0], y=pts[:, 1], z=pts[:, 2],
            mode='markers',
            marker=dict(size=2, opacity=0.8, color=colors),
            name=f"View {i} Points"
        ))

    # Now, plot the estimated cameras as cones
    for R, t, K in camera_poses:
        plot_camera_cones(fig, R, t, K, color='blue')

    fig.update_layout(
        scene=dict(
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='Z'
        ),
        margin=dict(l=0, r=0, b=0, t=40)
    )

def estimate_camera_poses(output, min_conf_thr_percentile=80):
    """
    Estimate camera poses from 3D points and 2D image points.

    Each pixel's predicted 3D point (``'pts3d_in_other_view'``) is paired with
    its pixel coordinate (u, v) = (column, row). Only points whose confidence
    is strictly above the given percentile are used. A camera matrix is fit by
    DLT and decomposed into K, R, t. Views with fewer than 6 such points are
    skipped, since the DLT is underdetermined there.

    Parameters:
    output (dict): The output containing 'preds' with 3D points and confidences.
    min_conf_thr_percentile (int): Percentile threshold for confidence values to filter points.

    Returns:
    list: A list of camera poses (R, t, K) where R is rotation, t is translation, and K is intrinsic matrix.
    """
    camera_poses = []

    for i, pred in enumerate(output['preds']):
        world_points = pred['pts3d_in_other_view'].cpu().numpy().squeeze()  # Shape: (H, W, 3), e.g. (272, 512, 3)
        conf = pred['conf'].cpu().numpy().squeeze()  # Confidence map

        # Determine the confidence threshold based on the percentile
        conf_thr = np.percentile(conf, min_conf_thr_percentile)
        mask = conf > conf_thr
        world_points_filtered = world_points[mask]

        # Pixel coordinates as (u, v) = (column, row), the order the DLT expects.
        # np.indices yields (row, column), so the pair is swapped here.
        h, w, _ = world_points.shape
        rows, cols = np.indices((h, w))
        image_points_filtered = np.stack([cols[mask], rows[mask]], axis=-1)

        num_points = world_points_filtered.shape[0]
        if num_points == 0:
            print(f"View {i}: No points above confidence threshold. Skipping camera estimation.")
            continue
        if num_points < 6:
            print(f"View {i}: Only {num_points} points above confidence threshold (DLT needs at least 6). Skipping camera estimation.")
            continue

        # Estimate the camera matrix
        P = estimate_camera_matrix(world_points_filtered, image_points_filtered)
        print(f"Camera matrix for view {i}:\n", P)

        # [:3] keeps this working with either decompose_camera_matrix variant in
        # this notebook: one returns (K, R, t), the other (K, R, t, camera_position).
        K, R, t = decompose_camera_matrix(P)[:3]
        print(f"Intrinsic matrix (K) for view {i}:\n", K)
        print(f"Rotation matrix (R) for view {i}:\n", R)
        print(f"Translation vector (t) for view {i}:\n", t)

        camera_poses.append((R, t, K))

    return camera_poses

# Estimate the camera poses first
camera_poses = estimate_camera_poses(output, min_conf_thr_percentile=80)

# Create a 3D plot and plot the 3D points together with the estimated cameras
fig = go.Figure()
plot_3d_points_with_estimated_camera(output, fig, camera_poses, min_conf_thr_percentile=50)

# Display the final plot with 3D points and camera cones
fig.show()


In [None]:
# Inspect the first view's image tensor shape (displayed as the cell result);
# presumably (batch, channels, height, width) — confirm.
output['views'][0]['img'].shape

# Align with DTU point cloud

In [None]:
# The 3x4 camera matrix (P = K[R|t], intrinsics times extrinsics) of the first image lives at /path/to/dust3r_data/datasets_raw/DTU/SampleSet/MVS Data/Calibration/cal18/pos_001.txt
# it looks like this:
# 2607.429996 -3.844898 1498.178098 -533936.661373
# -192.076910 2862.552532 681.798177 23434.686572
# -0.241605 -0.030951 0.969881 22.540121
# I'd like to use this to rotate input 3D points into the correct (DTU world) orientation.
# My 3D points assume the camera is at (0, 0, 0), looking along +Z (towards (0, 0, 1)).



In [None]:
import numpy as np
import trimesh
import plotly.graph_objs as go
from scipy.linalg import rq

def load_camera_matrix(filepath):
    """Loads the camera calibration matrix from the given file.

    Each non-blank line is parsed as one row of whitespace-separated floats.
    A DTU ``pos_XXX.txt`` file yields a 3x4 matrix. Blank lines, such as a
    trailing empty line, are skipped. Previously they became empty rows and
    made the rows ragged.
    """
    with open(filepath, 'r', encoding='utf-8') as f:
        rows = [[float(value) for value in line.split()] for line in f if line.strip()]
    return np.array(rows)

def decompose_camera_matrix(camera_matrix):
    """Decomposes the camera calibration matrix into intrinsic matrix (K), rotation matrix (R), and translation vector (t).

    The 3x4 matrix P ~ K [R | t] is split with an RQ decomposition of P[:, :3].
    - K: upper-triangular with a positive diagonal. It is not rescaled, so it
      carries the magnitude of P's overall scale.
    - R: a proper rotation (det(R) = +1).
    - t: K^-1 P[:, 3], with its sign made consistent with R.

    Returns:
        K, R, t, camera_position, where camera_position = -R^T t is the camera
        center in world coordinates.
    """
    camera_matrix = np.asarray(camera_matrix, dtype=float)
    M = camera_matrix[:, :3]

    # RQ decomposition to separate K and R
    K, R = rq(M)

    # Make K's diagonal positive (T @ T = I keeps K @ R unchanged). np.where
    # instead of np.sign avoids a singular T when a diagonal entry is 0.
    T = np.diag(np.where(np.diag(K) < 0, -1.0, 1.0))
    K = K @ T
    R = T @ R

    # Compute translation vector t
    t = np.linalg.solve(K, camera_matrix[:, 3])

    # A negative overall scale of P leaves R with det = -1 (a reflection).
    # Negating R and t together yields a proper rotation for the same camera.
    if np.linalg.det(R) < 0:
        R = -R
        t = -t

    # Camera position C = -R^T * t
    camera_position = -R.T @ t

    return K, R, t, camera_position

def _print_xyz_ranges(points):
    """Print ``"<axis> range: <min> - <max> = <extent>"`` for the X, Y, Z columns of ``points``."""
    for axis, label in enumerate("XYZ"):
        lo = np.min(points[:, axis])
        hi = np.max(points[:, axis])
        print(f"{label} range: {lo} - {hi} = {hi - lo}")


def apply_transformation_to_point_cloud(ply_filepath, camera_matrix_filepath):
    """Map a point cloud from the first camera's frame into the calibration's world frame.

    The calibration file holds a 3x4 projection matrix ``P = K [R | t]``,
    which maps a world point ``X_w`` to camera coordinates
    ``X_c = R @ X_w + t``. The input cloud is assumed to be expressed in
    camera coordinates (camera at the origin, looking along +Z). This
    function therefore applies the inverse mapping ``X_w = R^T @ (X_c - t)``.
    A point at the origin lands on the printed camera position ``-R^T t``.

    NOTE(review): the reconstruction may only be defined up to scale; a scale
    factor may still be needed before comparing against the DTU reference
    geometry — confirm.

    Args:
        ply_filepath: path of the .ply file to load with ``trimesh.load``.
        camera_matrix_filepath: path of the 3x4 projection-matrix text file.

    Returns:
        A new ``trimesh.PointCloud`` holding the transformed vertices and the
        input's per-vertex colours, if any were found.
    """
    # Load the point cloud
    point_cloud = trimesh.load(ply_filepath)
    vertices = np.asarray(point_cloud.vertices, dtype=float)

    # Load and decompose the camera matrix
    camera_matrix = load_camera_matrix(camera_matrix_filepath)
    _, R, t, camera_position = decompose_camera_matrix(camera_matrix)

    print("Point ranges before transformation:")
    _print_xyz_ranges(vertices)

    # Print the camera position (camera centre in world coordinates)
    print(f"Camera position: {camera_position}")

    # Invert X_c = R @ X_w + t. For row vectors, R^T @ v == v @ R.
    transformed_points = (vertices - t) @ R

    print("Point ranges after transformation:")
    _print_xyz_ranges(transformed_points)

    # A .ply with faces may load as a mesh rather than a PointCloud; such
    # objects keep per-vertex colours on ``visual.vertex_colors`` instead of
    # a ``colors`` attribute.
    colors = getattr(point_cloud, "colors", None)
    if colors is None:
        colors = getattr(getattr(point_cloud, "visual", None), "vertex_colors", None)

    return trimesh.PointCloud(vertices=transformed_points, colors=colors)

def plot_point_cloud(point_cloud, title="Transformed Point Cloud"):
    """Render ``point_cloud`` as an interactive Plotly 3-D scatter and show it.

    Args:
        point_cloud: object exposing ``vertices`` (indexed column-wise as
            ``[:, axis]``) and ``colors``. The colours are divided by 255
            before plotting, so they are presumably 0-255 values (e.g. a
            ``trimesh.PointCloud``) — confirm.
        title: title shown above the figure.

    NOTE(review): ``marker.color`` receives the colour array as-is after
    scaling; confirm Plotly renders it as per-point RGB(A).
    """
    verts = point_cloud.vertices
    x, y, z = (verts[:, axis] for axis in range(3))
    normalized_colors = point_cloud.colors / 255.0  # 0-255 -> [0, 1]

    marker_style = dict(size=2, color=normalized_colors, opacity=0.8)
    scatter = go.Scatter3d(x=x, y=y, z=z, mode='markers', marker=marker_style)

    axis_titles = dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z')

    fig = go.Figure(data=[scatter])
    fig.update_layout(
        title=title,
        scene=axis_titles,
        margin=dict(l=0, r=0, b=0, t=40),
        height=800,
    )
    fig.show()

# Example usage: transform a reconstructed point cloud with the DTU calibration
# of the first view (pos_001.txt), save the result, and visualize it.
ply_filepath = '/path/to/combined_mesh.ply'
camera_matrix_filepath = '/path/to/dust3r_data/datasets_raw/DTU/SampleSet/MVS Data/Calibration/cal18/pos_001.txt'

transformed_point_cloud = apply_transformation_to_point_cloud(ply_filepath, camera_matrix_filepath)

# Save the transformed point cloud to a new .ply file
transformed_point_cloud.export('/path/to/transformed_output.ply')

# Visualize the transformed point cloud
plot_point_cloud(transformed_point_cloud)


In [None]:
import numpy as np
import trimesh

def load_and_print_xyz_ranges(ply_filepath):
    """Load ``ply_filepath`` with trimesh and print the extent of its vertices.

    For each of X, Y and Z a line ``"<axis> range: <min> - <max> = <extent>"``
    is printed. All three ranges are computed before anything is printed.
    """
    verts = trimesh.load(ply_filepath).vertices

    bounds = [(np.min(verts[:, col]), np.max(verts[:, col])) for col in range(3)]

    for axis_name, (lo, hi) in zip("XYZ", bounds):
        print(f"{axis_name} range: {lo} - {hi} = {hi - lo}")

# Example usage: print the X/Y/Z extent of the DTU reference scan stl006_total.ply.
reference_ply_filepath = '/path/to/dust3r_data/datasets_raw/DTU/SampleSet/MVS Data/Points/stl/stl006_total.ply'

load_and_print_xyz_ranges(reference_ply_filepath)


In [None]:
import numpy as np
import plotly.graph_objs as go
import os
from scipy.linalg import rq

def load_camera_matrix(filepath):
    """Load a whitespace-separated numeric matrix (e.g. a DTU ``pos_XXX.txt``).

    Each non-blank line of the file becomes one row. Blank or whitespace-only
    lines (such as an extra empty line at the end of the file) are skipped, so
    they cannot introduce empty rows that would make the matrix ragged.

    Args:
        filepath: path to the text file.

    Returns:
        A float ``numpy.ndarray`` with one row per non-blank line.

    Raises:
        ValueError: if a token is not a number, or if the rows do not all
            contain the same number of values.
    """
    with open(filepath, 'r') as f:
        rows = [[float(token) for token in line.split()] for line in f if line.strip()]
    # np.array on ragged rows either raises or silently builds an object array
    # depending on the numpy version; fail explicitly instead.
    if len({len(row) for row in rows}) > 1:
        raise ValueError(f"{filepath}: rows have differing numbers of values")
    return np.array(rows)

def decompose_camera_matrix(camera_matrix):
    """Decompose a 3x4 projection matrix ``P = K [R | t]`` into K, R and t.

    ``P`` is only defined up to a non-zero (possibly negative) scale factor,
    so the result is normalized to the conventional form:

    * ``K`` is upper triangular with a positive diagonal and ``K[2, 2] == 1``;
    * ``R`` is a proper rotation (``det(R) == +1``);
    * ``K @ [R | t]`` equals ``camera_matrix`` up to that scale factor.

    Args:
        camera_matrix: array-like of shape (3, 4).

    Returns:
        Tuple ``(K, R, t, camera_position)``, where ``camera_position`` is the
        camera centre in world coordinates, ``C = -R^T t``.

    Raises:
        ValueError: if ``camera_matrix`` is not 3x4 or its left 3x3 block is
            singular.
    """
    P = np.asarray(camera_matrix, dtype=float)
    if P.shape != (3, 4):
        raise ValueError(f"expected a 3x4 camera matrix, got shape {P.shape}")
    M = P[:, :3]

    # RQ decomposition: M = K @ R with K upper triangular and R orthogonal.
    K, R = rq(M)

    # RQ is unique only up to the signs of K's diagonal; make them positive.
    # T is its own inverse, so (K @ T) @ (T @ R) still equals M.
    signs = np.sign(np.diag(K))
    if np.any(signs == 0):
        raise ValueError("camera matrix has a singular left 3x3 block")
    T = np.diag(signs)
    K = K @ T
    R = T @ R

    # Solve K t = p4 with the still-unnormalized K, so t is independent of
    # the overall scale of P.
    t = np.linalg.solve(K, P[:, 3])

    # A negative overall scale on P leaves R as a reflection (det = -1).
    # P and -P describe the same camera, so negate R and t to obtain a proper
    # rotation; the camera centre -R^T t is unchanged by this.
    if np.linalg.det(R) < 0:
        R = -R
        t = -t

    # Remove the arbitrary overall scale from the intrinsics.
    K = K / K[2, 2]

    camera_position = -R.T @ t

    return K, R, t, camera_position

def _load_camera_poses(base_path, pose_count):
    """Read ``pos_001.txt`` .. ``pos_{pose_count:03d}.txt`` and return camera centres and viewing directions.

    Each raw matrix and its decomposition is printed along the way.

    Returns:
        Two float arrays of shape (pose_count, 3): world-space camera centres
        and world-space unit viewing directions.
    """
    camera_positions = []
    camera_directions = []

    for i in range(1, pose_count + 1):
        filepath = os.path.join(base_path, f'pos_{i:03d}.txt')
        camera_matrix = load_camera_matrix(filepath)

        # Print the full camera matrix
        print(f"Camera Matrix {i}:\n{camera_matrix}\n")

        K, R, t, camera_position = decompose_camera_matrix(camera_matrix)

        # Print the decomposed matrices
        print(f"Intrinsic Matrix (K) {i}:\n{K}\n")
        print(f"Rotation Matrix (R) {i}:\n{R}\n")
        print(f"Translation Vector (t) {i}:\n{t}\n")
        print(f"Camera Position {i}: {camera_position}\n")

        # With P = K [R | t] and K's diagonal made positive by the
        # decomposition, a correctly signed P projects points in front of the
        # camera at positive camera-space z. The optical axis is therefore
        # camera +Z (OpenCV convention), not -Z. Rotate it into world coordinates.
        camera_directions.append(R.T @ np.array([0.0, 0.0, 1.0]))
        camera_positions.append(camera_position)

    # reshape keeps the (N, 3) layout even when no poses were loaded.
    return (
        np.array(camera_positions, dtype=float).reshape(-1, 3),
        np.array(camera_directions, dtype=float).reshape(-1, 3),
    )


def plot_camera_poses(base_path, pose_count):
    """Plot the DTU calibration cameras ``pos_001`` .. ``pos_{pose_count:03d}`` in Plotly.

    Each camera is drawn as a marker at its world-space centre. A cone points
    along its viewing direction.

    Args:
        base_path: directory holding the ``pos_XXX.txt`` projection matrices.
        pose_count: number of consecutive pose files to read, starting at 1.
    """
    camera_positions, camera_orientations = _load_camera_poses(base_path, pose_count)

    # Create the 3D scatter plot for camera positions
    scatter = go.Scatter3d(
        x=camera_positions[:, 0],
        y=camera_positions[:, 1],
        z=camera_positions[:, 2],
        mode='markers',
        marker=dict(size=5, color='blue'),
        name='Camera Positions'
    )

    # Create the 3D quiver plot for camera orientations
    quiver = go.Cone(
        x=camera_positions[:, 0],
        y=camera_positions[:, 1],
        z=camera_positions[:, 2],
        u=camera_orientations[:, 0],
        v=camera_orientations[:, 1],
        w=camera_orientations[:, 2],
        sizemode='scaled',
        sizeref=2,
        colorscale='Blues',
        name='Camera Orientations'
    )

    # Set up the layout
    layout = go.Layout(
        title='Camera Poses Visualization',
        scene=dict(
            xaxis=dict(title='X'),
            yaxis=dict(title='Y'),
            zaxis=dict(title='Z'),
        ),
        margin=dict(l=0, r=0, b=0, t=40),
        height=800
    )

    # Create the figure and show it
    fig = go.Figure(data=[scatter, quiver], layout=layout)
    fig.show()

# Example usage: visualize the DTU calibration poses pos_001.txt .. pos_049.txt in cal18.
base_path = '/path/to/dust3r_data/datasets_raw/DTU/SampleSet/MVS Data/Calibration/cal18'
pose_count = 49 # Adjust this according to the number of poses available

plot_camera_poses(base_path, pose_count)
