"""Convert a MASt3R sparse reconstruction into COLMAP-format outputs
(cameras.txt, images.txt, points3D.ply) plus per-view images and masks."""
import matplotlib
matplotlib.use('Agg')  # headless backend; must be set before pyplot is imported

import sys
import argparse
import math
import shutil
import torch
import numpy as np
import os
import re
import cv2
import trimesh
from pathlib import Path
from PIL import Image
from typing import NamedTuple

# Make the bundled mast3r package importable from the working directory
current_dir = os.getcwd()
sys.path.append(os.path.join(current_dir, 'mast3r'))

from mast3r.model import AsymmetricMASt3R
from mast3r.cloud_opt.sparse_ga import sparse_global_alignment
from mast3r.cloud_opt.tsdf_optimizer import TSDFPostProcess
import mast3r.utils.path_to_dust3r  # noqa: F401 -- side effect: adds the vendored dust3r to sys.path
from dust3r.utils.image import load_images
from dust3r.utils.device import to_numpy
from dust3r.image_pairs import make_pairs

from plyfile import PlyData, PlyElement
from utils.dust3r_utils import visualizer, pca, upsampler

class BasicPointCloud(NamedTuple):
    points: np.ndarray
    colors: np.ndarray
    normals: np.ndarray

def invert_matrix(mat):
    """Invert a torch or numpy matrix."""
    if isinstance(mat, torch.Tensor):
        return torch.linalg.inv(mat)
    if isinstance(mat, np.ndarray):
        return np.linalg.inv(mat)
    raise ValueError(f'Unsupported matrix type: {type(mat)}')

def fov2focal(fov, pixels):
    """Convert a field of view (radians) to a focal length in pixels."""
    return pixels / (2 * math.tan(fov / 2))


def focal2fov(focal, pixels):
    """Convert a focal length in pixels to a field of view in radians."""
    return 2 * math.atan(pixels / (2 * focal))

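# A quick sanity check of the two helpers above (illustrative values, not part
# of the original script): a 90-degree FoV on a 512 px wide image corresponds
# to a 256 px focal length, and the conversion round-trips.
#
#     >>> f = fov2focal(math.pi / 2, 512)   # 256.0, since tan(45 deg) == 1
#     >>> math.isclose(focal2fov(f, 512), math.pi / 2)
#     True
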
def rotmat2qvec(R):
    """Convert a 3x3 rotation matrix to a quaternion (qw, qx, qy, qz),
    following COLMAP's convention (hemisphere with qw >= 0)."""
    Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
    K = np.array([
        [Rxx - Ryy - Rzz, 0, 0, 0],
        [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
        [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
        [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0
    eigvals, eigvecs = np.linalg.eigh(K)
    qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
    if qvec[0] < 0:
        qvec *= -1
    return qvec

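# Example (illustrative, not in the original file): the identity rotation maps
# to the unit quaternion, in COLMAP's (QW, QX, QY, QZ) ordering.
#
#     >>> rotmat2qvec(np.eye(3))
#     array([1., 0., 0., 0.])
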
def storePly(path, xyz, rgb):
    # Define the dtype for the structured array
    dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
             ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
             ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]

    normals = np.zeros_like(xyz)
    elements = np.empty(xyz.shape[0], dtype=dtype)
    attributes = np.concatenate((xyz, normals, rgb), axis=1)
    elements[:] = list(map(tuple, attributes))

    # Create the PlyData object and write to file
    vertex_element = PlyElement.describe(elements, 'vertex')
    ply_data = PlyData([vertex_element])
    ply_data.write(path)

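# storePly is not called elsewhere in this script; an illustrative standalone
# use with random data (filenames and arrays are hypothetical):
#
#     >>> xyz = np.random.rand(100, 3).astype(np.float32)
#     >>> rgb = (np.random.rand(100, 3) * 255).astype(np.uint8)
#     >>> storePly('demo_points3D.ply', xyz, rgb)
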
# Ensure save directories exist
def init_filestructure(save_path):
    save_path.mkdir(exist_ok=True, parents=True)

    images_path = save_path / 'images'
    masks_path = save_path / 'masks'
    sparse_path = save_path / 'sparse/0'

    images_path.mkdir(exist_ok=True, parents=True)
    masks_path.mkdir(exist_ok=True, parents=True)
    sparse_path.mkdir(exist_ok=True, parents=True)

    return save_path, images_path, masks_path, sparse_path

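# Resulting COLMAP-style layout under save_path (directory name illustrative):
#
#     output/
#     ├── images/        RGB frames as saved by save_images_and_masks
#     ├── masks/         confidence-thresholded per-pixel masks
#     └── sparse/0/      cameras.txt, images.txt, points3D.ply
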
# Save images and masks
def save_images_and_masks(imgs, masks, images_path, img_files, masks_path):
    for image, name, mask in zip(imgs, img_files, masks):
        imgname = Path(name).stem
        image_save_path = images_path / f"{imgname}.png"
        mask_save_path = masks_path / f"{imgname}.png"

        # scene.imgs are float RGB in [0, 1]; cv2.imwrite expects uint8 BGR
        bgr_image = cv2.cvtColor((image * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)
        cv2.imwrite(str(image_save_path), bgr_image)

        mask = np.repeat(np.expand_dims(mask, -1), 3, axis=2) * 255
        Image.fromarray(mask.astype(np.uint8)).save(mask_save_path)

# Save camera information
def save_cameras(focals, principal_points, sparse_path, imgs_shape):
    cameras_file = sparse_path / 'cameras.txt'
    with open(cameras_file, 'w') as f:
        f.write("# Camera list with one line of data per camera:\n")
        f.write("# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n")
        for i, (focal, pp) in enumerate(zip(focals, principal_points)):
            # PINHOLE params are fx, fy, cx, cy; MASt3R estimates a single focal per view
            f.write(f"{i} PINHOLE {imgs_shape[2]} {imgs_shape[1]} {focal} {focal} {pp[0]} {pp[1]}\n")

# Save image transformations
def save_images_txt(world2cam, img_files, sparse_path):
    images_file = sparse_path / 'images.txt'
    with open(images_file, 'w') as f:
        f.write("# Image list with two lines of data per image:\n")
        f.write("# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n")
        f.write("# POINTS2D[] as (X, Y, POINT3D_ID)\n")
        for i in range(world2cam.shape[0]):
            name = Path(img_files[i]).stem
            rotation_matrix = world2cam[i, :3, :3]
            qw, qx, qy, qz = rotmat2qvec(rotation_matrix)
            tx, ty, tz = world2cam[i, :3, 3]
            # the trailing blank line stands in for the (unused) POINTS2D list
            f.write(f"{i} {qw} {qx} {qy} {qz} {tx} {ty} {tz} {i} {name}.png\n\n")

# Save point cloud with normals
def save_pointcloud_with_normals(imgs, pts3d, masks, sparse_path):
    pc = get_point_cloud(imgs, pts3d, masks)
    default_normal = [0, 1, 0]
    vertices = pc.vertices
    colors = pc.colors
    normals = np.tile(default_normal, (vertices.shape[0], 1))

    save_path = sparse_path / 'points3D.ply'

    # Construct the ASCII PLY header with per-vertex color and normal
    header = """ply
format ascii 1.0
element vertex {}
property float x
property float y
property float z
property uchar red
property uchar green
property uchar blue
property float nx
property float ny
property float nz
end_header
""".format(len(vertices))

    with open(save_path, 'w') as f:
        f.write(header)
        for vertex, color, normal in zip(vertices, colors, normals):
            f.write(f"{vertex[0]} {vertex[1]} {vertex[2]} {int(color[0])} {int(color[1])} {int(color[2])} {normal[0]} {normal[1]} {normal[2]}\n")

# Generate point cloud
def get_point_cloud(imgs, pts3d, mask):
    imgs = to_numpy(imgs)
    pts3d = to_numpy(pts3d)
    mask = to_numpy(mask)

    # Keep only confident points/colors, then subsample every third point
    pts = np.concatenate([p[m] for p, m in zip(pts3d, mask.reshape(mask.shape[0], -1))])
    col = np.concatenate([p[m] for p, m in zip(imgs, mask)])
    pts = pts.reshape(-1, 3)[::3]
    col = col.reshape(-1, 3)[::3]

    normals = np.tile([0, 1, 0], (pts.shape[0], 1))

    pct = trimesh.PointCloud(pts, colors=col)
    pct.vertices_normal = normals  # placeholder normals, stored as a plain attribute
    return pct

def main(image_dir, save_dir, model_path, device, batch_size, image_size, schedule, lr, niter, min_conf_thr, tsdf_thresh):
    # Note: batch_size, schedule, lr and niter are accepted for CLI
    # compatibility but are not used by the sparse global alignment below.

    # Load model and images; frames are sorted by the first integer in their
    # stem (e.g. frame_012.png)
    model = AsymmetricMASt3R.from_pretrained(model_path).to(device)
    image_files = sorted([str(x) for x in Path(image_dir).iterdir() if x.suffix in ['.png', '.jpg']],
                         key=lambda x: int(re.search(r'\d+', Path(x).stem).group()))
    images = load_images(image_files, size=image_size)

    # Generate pairs and run inference
    pairs = make_pairs(images, scene_graph='complete', prefilter=None, symmetrize=True)

    cache_dir = os.path.join(save_dir, 'cache')
    if os.path.exists(cache_dir):
        shutil.rmtree(cache_dir)  # start from a clean alignment cache

    scene = sparse_global_alignment(image_files, pairs, cache_dir,
                                    model, lr1=0.07, niter1=500, lr2=0.014, niter2=200, device=device,
                                    opt_depth=False,  # depth is not re-optimized during alignment
                                    shared_intrinsics=False,
                                    matching_conf_thr=5.)

    # Extract scene information
    world2cam = invert_matrix(scene.get_im_poses().detach()).cpu().numpy()
    principal_points = scene.get_principal_points().detach().cpu().numpy()
    focals = scene.get_focals().detach().cpu().numpy()
    imgs = np.array(scene.imgs)

    # TSDF post-processing yields dense per-view points and confidences
    tsdf = TSDFPostProcess(scene, TSDF_thresh=tsdf_thresh)
    pts3d, _, confs = to_numpy(tsdf.get_dense_pts3d(clean_depth=True))
    masks = np.array(to_numpy([c > min_conf_thr for c in confs]))

    # Visualize MASt3R features: PCA to 64 dims, then upsample to image resolution
    _, H, W, _ = imgs.shape
    feat_dim = 64
    projected_feat = pca(scene.stacked_feat, feat_dim)
    upsampled_feat = upsampler(projected_feat, H, W)
    visualizer(upsampled_feat, images, save_dir)
    # Write COLMAP-style outputs
    save_path, images_path, masks_path, sparse_path = init_filestructure(Path(save_dir))
    save_images_and_masks(imgs, masks, images_path, image_files, masks_path)
    save_cameras(focals, principal_points, sparse_path, imgs_shape=imgs.shape)
    save_images_txt(world2cam, image_files, sparse_path)
    save_pointcloud_with_normals(imgs, pts3d, masks, sparse_path)

    print(f'[INFO] MASt3R reconstruction successfully converted to COLMAP files in: {str(sparse_path)}')

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Process images and save results.')
    parser.add_argument('--image_dir', type=str, required=True, help='Directory containing images')
    parser.add_argument('--save_dir', type=str, required=True, help='Directory to save the results')
    parser.add_argument('--model_path', type=str, required=True, help='Path to the model checkpoint')
    parser.add_argument('--device', type=str, default='cuda', help='Device to use for inference')
    parser.add_argument('--batch_size', type=int, default=1, help='Batch size for processing images')
    parser.add_argument('--image_size', type=int, default=512, help='Size to resize images')
    parser.add_argument('--schedule', type=str, default='cosine', help='Learning rate schedule')
    parser.add_argument('--lr', type=float, default=0.01, help='Learning rate')
    parser.add_argument('--niter', type=int, default=300, help='Number of iterations')
    parser.add_argument('--min_conf_thr', type=float, default=1.5, help='Minimum confidence threshold')
    parser.add_argument('--tsdf_thresh', type=float, default=0.0, help='TSDF threshold')
    args = parser.parse_args()

    main(args.image_dir, args.save_dir, args.model_path, args.device, args.batch_size, args.image_size,
         args.schedule, args.lr, args.niter, args.min_conf_thr, args.tsdf_thresh)
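
# Example invocation (paths and checkpoint name are placeholders; adjust to
# your setup):
#
#     python colmap_from_mast3r.py \
#         --image_dir data/scene/images \
#         --save_dir output/scene \
#         --model_path checkpoints/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth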