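"""Convert a MASt3R sparse reconstruction into a COLMAP-style dataset.

Runs MASt3R sparse global alignment on a directory of images, then writes
per-view images, confidence masks, cameras.txt, images.txt and points3D.ply
under the given save directory.
"""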
import matplotlib
matplotlib.use('Agg')
import sys
import math
import shutil
import argparse
import torch
import numpy as np
import os
import re
import cv2
import trimesh
from pathlib import Path
from PIL import Image
from typing import NamedTuple, Optional
current_dir = os.getcwd()
sys.path.append(os.path.join(current_dir, 'mast3r'))
from mast3r.model import AsymmetricMASt3R
from mast3r.cloud_opt.sparse_ga import sparse_global_alignment
from mast3r.cloud_opt.tsdf_optimizer import TSDFPostProcess
import mast3r.utils.path_to_dust3r
from dust3r.utils.image import load_images
from dust3r.utils.device import to_numpy
from dust3r.image_pairs import make_pairs
from plyfile import PlyData, PlyElement
from utils.dust3r_utils import visualizer, pca, upsampler
class BasicPointCloud(NamedTuple):
    points: np.ndarray
    colors: np.ndarray
    normals: np.ndarray
def invert_matrix(mat):
"""Invert a torch or numpy matrix."""
if isinstance(mat, torch.Tensor):
return torch.linalg.inv(mat)
if isinstance(mat, np.ndarray):
return np.linalg.inv(mat)
raise ValueError(f'Unsupported matrix type: {type(mat)}')
def fov2focal(fov, pixels):
return pixels / (2 * math.tan(fov / 2))
def focal2fov(focal, pixels):
return 2 * math.atan(pixels / (2 * focal))
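# Illustrative round trip: for a 60-degree FoV on a 512-px-wide image,
# fov2focal(math.radians(60), 512) gives roughly 443.4 px, and
# focal2fov(443.4, 512) recovers roughly 1.047 rad (60 degrees).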
def rotmat2qvec(R):
Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
K = np.array([
[Rxx - Ryy - Rzz, 0, 0, 0],
[Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
[Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
[Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0
eigvals, eigvecs = np.linalg.eigh(K)
qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
if qvec[0] < 0:
qvec *= -1
return qvec
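# Sanity check (illustrative): rotmat2qvec(np.eye(3)) returns the identity
# quaternion [1, 0, 0, 0] in COLMAP's (qw, qx, qy, qz) ordering.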
def storePly(path, xyz, rgb):
# Define the dtype for the structured array
dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]
normals = np.zeros_like(xyz)
elements = np.empty(xyz.shape[0], dtype=dtype)
attributes = np.concatenate((xyz, normals, rgb), axis=1)
elements[:] = list(map(tuple, attributes))
# Create the PlyData object and write to file
vertex_element = PlyElement.describe(elements, 'vertex')
ply_data = PlyData([vertex_element])
ply_data.write(path)
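# Note: BasicPointCloud and storePly above are utility definitions that main()
# below never calls; points3D.ply is written by save_pointcloud_with_normals.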
# Ensure save directories exist
def init_filestructure(save_path):
save_path.mkdir(exist_ok=True, parents=True)
images_path = save_path / 'images'
masks_path = save_path / 'masks'
sparse_path = save_path / 'sparse/0'
images_path.mkdir(exist_ok=True, parents=True)
masks_path.mkdir(exist_ok=True, parents=True)
sparse_path.mkdir(exist_ok=True, parents=True)
return save_path, images_path, masks_path, sparse_path
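# Resulting layout under save_path (COLMAP convention):
#   images/    input views saved as <stem>.png
#   masks/     per-view confidence masks
#   sparse/0/  cameras.txt, images.txt, points3D.ply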
# Save images and masks
def save_images_and_masks(imgs, masks, images_path, img_files, masks_path):
    for image, name, mask in zip(imgs, img_files, masks):
imgname = Path(name).stem
image_save_path = images_path / f"{imgname}.png"
mask_save_path = masks_path / f"{imgname}.png"
        # scene images are RGB floats in [0, 1]; OpenCV expects BGR uint8
        bgr_image = cv2.cvtColor((image * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)
        cv2.imwrite(str(image_save_path), bgr_image)
mask = np.repeat(np.expand_dims(mask, -1), 3, axis=2) * 255
Image.fromarray(mask.astype(np.uint8)).save(mask_save_path)
# Save camera information
def save_cameras(focals, principal_points, sparse_path, imgs_shape):
cameras_file = sparse_path / 'cameras.txt'
with open(cameras_file, 'w') as f:
f.write("# Camera list with one line of data per camera:\n")
f.write("# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n")
for i, (focal, pp) in enumerate(zip(focals, principal_points)):
f.write(f"{i} PINHOLE {imgs_shape[2]} {imgs_shape[1]} {focal} {focal} {pp[0]} {pp[1]}\n")
# Save image transformations
def save_images_txt(world2cam, img_files, sparse_path):
images_file = sparse_path / 'images.txt'
with open(images_file, 'w') as f:
f.write("# Image list with two lines of data per image:\n")
f.write("# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n")
f.write("# POINTS2D[] as (X, Y, POINT3D_ID)\n")
for i in range(world2cam.shape[0]):
name = Path(img_files[i]).stem
rotation_matrix = world2cam[i, :3, :3]
qw, qx, qy, qz = rotmat2qvec(rotation_matrix)
tx, ty, tz = world2cam[i, :3, 3]
f.write(f"{i} {qw} {qx} {qy} {qz} {tx} {ty} {tz} {i} {name}.png\n\n")
# Save point cloud with normals
def save_pointcloud_with_normals(imgs, pts3d, masks, sparse_path):
pc = get_point_cloud(imgs, pts3d, masks)
default_normal = [0, 1, 0]
vertices = pc.vertices
colors = pc.colors
normals = np.tile(default_normal, (vertices.shape[0], 1))
save_path = sparse_path / 'points3D.ply'
header = """ply
format ascii 1.0
element vertex {}
property float x
property float y
property float z
property uchar red
property uchar green
property uchar blue
property float nx
property float ny
property float nz
end_header
""".format(len(vertices))
with open(save_path, 'w') as f:
f.write(header)
for vertex, color, normal in zip(vertices, colors, normals):
f.write(f"{vertex[0]} {vertex[1]} {vertex[2]} {int(color[0])} {int(color[1])} {int(color[2])} {normal[0]} {normal[1]} {normal[2]}\n")
# Generate point cloud
def get_point_cloud(imgs, pts3d, mask):
imgs = to_numpy(imgs)
pts3d = to_numpy(pts3d)
mask = to_numpy(mask)
pts = np.concatenate([p[m] for p, m in zip(pts3d, mask.reshape(mask.shape[0], -1))])
col = np.concatenate([p[m] for p, m in zip(imgs, mask)])
pts = pts.reshape(-1, 3)[::3]
col = col.reshape(-1, 3)[::3]
    # trimesh.PointCloud has no native per-vertex normal field, so stash the
    # placeholder normals in metadata; the PLY writer regenerates them anyway.
    pct = trimesh.PointCloud(pts, colors=col)
    pct.metadata['normals'] = np.tile([0, 1, 0], (pts.shape[0], 1))
    return pct
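# Note: the [::3] stride in get_point_cloud keeps every third masked point,
# trading density for a compact points3D.ply.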
def main(image_dir, save_dir, model_path, device, batch_size, image_size, schedule, lr, niter, min_conf_thr, tsdf_thresh):
    # NOTE: batch_size, schedule, lr and niter are accepted for CLI
    # compatibility but are not used by the sparse alignment below.
    # Load model and images
model = AsymmetricMASt3R.from_pretrained(model_path).to(device)
    # assumes every image filename contains a frame number to sort by
    image_files = sorted([str(x) for x in Path(image_dir).iterdir() if x.suffix.lower() in ['.png', '.jpg', '.jpeg']],
                         key=lambda x: int(re.search(r'\d+', Path(x).stem).group()))
images = load_images(image_files, size=image_size)
# Generate pairs and run inference
pairs = make_pairs(images, scene_graph='complete', prefilter=None, symmetrize=True)
cache_dir = os.path.join(save_dir, 'cache')
    if os.path.exists(cache_dir):
        shutil.rmtree(cache_dir)  # clear stale pairwise cache from earlier runs
    scene = sparse_global_alignment(image_files, pairs, cache_dir,
                                    model, lr1=0.07, niter1=500, lr2=0.014, niter2=200, device=device,
                                    opt_depth=False,  # original "'depth' in 'refine'" always evaluated to False
                                    shared_intrinsics=False, matching_conf_thr=5.)
# Extract scene information
world2cam = invert_matrix(scene.get_im_poses().detach()).cpu().numpy()
principal_points = scene.get_principal_points().detach().cpu().numpy()
focals = scene.get_focals().detach().cpu().numpy()
imgs = np.array(scene.imgs)
tsdf = TSDFPostProcess(scene, TSDF_thresh=tsdf_thresh)
pts3d, _, confs = to_numpy(tsdf.get_dense_pts3d(clean_depth=True))
masks = np.array(to_numpy([c > min_conf_thr for c in confs]))
_, H, W, _ = imgs.shape
feat_dim = 64
projected_feat = pca(scene.stacked_feat, feat_dim)
upsampled_feat = upsampler(projected_feat, H, W)
visualizer(upsampled_feat, images, save_dir)
# Main execution
save_path, images_path, masks_path, sparse_path = init_filestructure(Path(save_dir))
save_images_and_masks(imgs, masks, images_path, image_files, masks_path)
save_cameras(focals, principal_points, sparse_path, imgs_shape=imgs.shape)
save_images_txt(world2cam, image_files, sparse_path)
    save_pointcloud_with_normals(imgs, pts3d, masks, sparse_path)
    print(f'[INFO] MASt3R reconstruction successfully converted to COLMAP files in: {sparse_path}')
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Process images and save results.')
parser.add_argument('--image_dir', type=str, required=True, help='Directory containing images')
parser.add_argument('--save_dir', type=str, required=True, help='Directory to save the results')
parser.add_argument('--model_path', type=str, required=True, help='Path to the model checkpoint')
parser.add_argument('--device', type=str, default='cuda', help='Device to use for inference')
parser.add_argument('--batch_size', type=int, default=1, help='Batch size for processing images')
parser.add_argument('--image_size', type=int, default=512, help='Size to resize images')
parser.add_argument('--schedule', type=str, default='cosine', help='Learning rate schedule')
parser.add_argument('--lr', type=float, default=0.01, help='Learning rate')
parser.add_argument('--niter', type=int, default=300, help='Number of iterations')
parser.add_argument('--min_conf_thr', type=float, default=1.5, help='Minimum confidence threshold')
parser.add_argument('--tsdf_thresh', type=float, default=0.0, help='TSDF threshold')
args = parser.parse_args()
    main(args.image_dir, args.save_dir, args.model_path, args.device, args.batch_size, args.image_size, args.schedule, args.lr, args.niter, args.min_conf_thr, args.tsdf_thresh)
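# Hypothetical invocation (script name and paths are placeholders):
#   python mast3r_to_colmap.py --image_dir data/scene/images --save_dir output/scene \
#       --model_path checkpoints/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth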