File size: 9,481 Bytes
dc269e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
import matplotlib
matplotlib.use('Agg')

import argparse
import math
import os
import re
import shutil
import sys
from pathlib import Path
from typing import NamedTuple, Optional

import torch
import numpy as np
import cv2
import trimesh
from PIL import Image

# Append `<cwd>/mast3r` to the module search path. The location is resolved
# from the current working directory (not from this file's location), so the
# directory the script is launched from determines what gets found.
current_dir = os.getcwd()
sys.path.append(os.path.join(current_dir, 'mast3r'))
from mast3r.model import AsymmetricMASt3R
from mast3r.cloud_opt.sparse_ga import sparse_global_alignment
from mast3r.cloud_opt.tsdf_optimizer import TSDFPostProcess
import mast3r.utils.path_to_dust3r

from dust3r.utils.image import load_images
from dust3r.utils.device import to_numpy
from dust3r.image_pairs import make_pairs

from plyfile import PlyData, PlyElement

from utils.dust3r_utils import visualizer, pca, upsampler

class BasicPointCloud(NamedTuple):
    """Immutable bundle of the three per-point arrays that describe a point cloud.

    The fields are annotated with ``np.ndarray`` (the array type) rather than
    ``np.array`` (the factory function), so the annotations are usable by
    type checkers and introspection.
    """
    # Point positions -- presumably shaped (N, 3); TODO confirm against callers.
    points: np.ndarray
    # Per-point colors -- presumably one row per point; TODO confirm.
    colors: np.ndarray
    # Per-point normals -- presumably one row per point; TODO confirm.
    normals: np.ndarray

def invert_matrix(mat):
    """Return the inverse of ``mat`` using the backend that matches its type.

    Args:
        mat: A ``torch.Tensor`` or ``np.ndarray`` holding the matrix to invert.

    Returns:
        The inverse, computed by ``torch.linalg.inv`` or ``np.linalg.inv``.

    Raises:
        ValueError: If ``mat`` is neither a torch tensor nor a numpy array.
    """
    # Checked in order: torch first, then numpy.
    backends = (
        (torch.Tensor, torch.linalg.inv),
        (np.ndarray, np.linalg.inv),
    )
    for array_type, inverse in backends:
        if isinstance(mat, array_type):
            return inverse(mat)
    raise ValueError(f'Unsupported matrix type: {type(mat)}')

def fov2focal(fov: float, pixels: float) -> float:
    """Convert a field of view into a focal length expressed in pixels.

    Args:
        fov: Field of view in radians (it is passed straight to ``math.tan``).
        pixels: Image extent, in pixels, along the axis the field of view spans.

    Returns:
        ``pixels / (2 * tan(fov / 2))``.
    """
    # Relies on the module-level `import math`.
    return pixels / (2 * math.tan(fov / 2))

def focal2fov(focal: float, pixels: float) -> float:
    """Convert a focal length in pixels into a field of view in radians.

    Args:
        focal: Focal length, in the same pixel units as ``pixels``.
        pixels: Image extent, in pixels, along the axis the field of view spans.

    Returns:
        ``2 * atan(pixels / (2 * focal))``.
    """
    # Relies on the module-level `import math`.
    return 2 * math.atan(pixels / (2 * focal))

def rotmat2qvec(R):
    """Convert a 3x3 rotation matrix into a quaternion ordered (w, x, y, z).

    The quaternion is taken from the eigenvector that belongs to the largest
    eigenvalue of a symmetric 4x4 matrix assembled from the entries of ``R``.
    The sign is normalised so that the first component is never negative.

    Args:
        R: Array with nine elements, read in row-major order via ``R.flat``.

    Returns:
        A length-4 numpy array.
    """
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = R.flat
    # Only the lower triangle is filled in: `eigh` reads just that half by default.
    sym = np.array([
        [m00 - m11 - m22, 0, 0, 0],
        [m01 + m10, m11 - m00 - m22, 0, 0],
        [m02 + m20, m12 + m21, m22 - m00 - m11, 0],
        [m21 - m12, m02 - m20, m10 - m01, m00 + m11 + m22]]) / 3.0
    eigvals, eigvecs = np.linalg.eigh(sym)
    dominant = eigvecs[:, np.argmax(eigvals)]
    # The eigenvector's last entry becomes the quaternion's first component.
    qvec = dominant[[3, 0, 1, 2]]
    return -qvec if qvec[0] < 0 else qvec

def storePly(path, xyz, rgb):
    """Write points and colors to a PLY file with all-zero normals.

    Args:
        path: Destination handed to ``PlyData.write``.
        xyz: 2-D array of point coordinates, one row per vertex.
        rgb: 2-D array of colors with the same number of rows as ``xyz``.
    """
    # Record layout of one vertex: position, normal, then 8-bit color.
    vertex_dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
                    ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
                    ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]

    # No normals are known here, so zeros of the same shape as xyz are stored.
    zero_normals = np.zeros_like(xyz)
    rows = np.concatenate((xyz, zero_normals, rgb), axis=1)

    # A structured array is filled from one tuple per vertex.
    vertices = np.empty(xyz.shape[0], dtype=vertex_dtype)
    vertices[:] = [tuple(row) for row in rows]

    PlyData([PlyElement.describe(vertices, 'vertex')]).write(path)

# Ensure save directories exist
def init_filestructure(save_path):
    """Create the output directory tree and return its paths.

    Args:
        save_path: ``pathlib.Path`` of the output root.

    Returns:
        Tuple ``(save_path, images_path, masks_path, sparse_path)`` where the
        last three are ``images``, ``masks`` and ``sparse/0`` under the root.
    """
    images_path, masks_path, sparse_path = (
        save_path / sub for sub in ('images', 'masks', 'sparse/0')
    )
    # Root first, then the children; existing directories are left alone.
    for directory in (save_path, images_path, masks_path, sparse_path):
        directory.mkdir(exist_ok=True, parents=True)
    return save_path, images_path, masks_path, sparse_path

# Save images and masks
def save_images_and_masks(imgs, masks, images_path, img_files, masks_path):
    """Write every image and its mask as ``<stem>.png`` in the given folders.

    Args:
        imgs: Iterable of image arrays; each is multiplied by 255 before being
            written (presumably float RGB in [0, 1] -- TODO confirm).
        masks: Iterable of 2-D masks, one per image.
        images_path: Directory (``pathlib.Path``) receiving the images.
        img_files: Source file names; only their stems are used for naming.
        masks_path: Directory (``pathlib.Path``) receiving the masks.
    """
    for frame, source_name, valid in zip(imgs, img_files, masks):
        out_name = f"{Path(source_name).stem}.png"

        # Scale, then swap the first and third channels before handing the
        # array to cv2.imwrite.
        swapped = cv2.cvtColor(frame * 255, cv2.COLOR_BGR2RGB)
        cv2.imwrite(str(images_path / out_name), swapped)

        # Replicate the mask across three channels and store it as 0/255.
        mask_rgb = np.repeat(np.expand_dims(valid, -1), 3, axis=2) * 255
        Image.fromarray(mask_rgb.astype(np.uint8)).save(masks_path / out_name)

# Save camera information
def save_cameras(focals, principal_points, sparse_path, imgs_shape):
    """Write ``cameras.txt`` with one PINHOLE line per camera.

    Args:
        focals: One focal length per camera; it is written for both fx and fy.
        principal_points: One ``(cx, cy)`` pair per camera.
        sparse_path: Directory (``pathlib.Path``) that receives ``cameras.txt``.
        imgs_shape: Shape whose index 1 is written as HEIGHT and index 2 as WIDTH.
    """
    with open(sparse_path / 'cameras.txt', 'w') as out:
        print("# Camera list with one line of data per camera:", file=out)
        print("# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]", file=out)
        # The camera id is simply the position in the input sequences.
        for camera_id, (focal, pp) in enumerate(zip(focals, principal_points)):
            print(
                f"{camera_id} PINHOLE {imgs_shape[2]} {imgs_shape[1]} {focal} {focal} {pp[0]} {pp[1]}",
                file=out,
            )

# Save image transformations
def save_images_txt(world2cam, img_files, sparse_path):
    """Write ``images.txt`` with one pose line plus one empty line per image.

    Args:
        world2cam: Stack of pose matrices indexed along the first axis; the
            top-left 3x3 is converted to a quaternion and the first three
            entries of the last column are written as TX, TY, TZ.
        img_files: Source file names; entry ``i`` names pose ``i`` (stem + ``.png``).
        sparse_path: Directory (``pathlib.Path``) that receives ``images.txt``.
    """
    with open(sparse_path / 'images.txt', 'w') as out:
        out.write("# Image list with two lines of data per image:\n")
        out.write("# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n")
        out.write("# POINTS2D[] as (X, Y, POINT3D_ID)\n")
        for idx, pose in enumerate(world2cam):
            stem = Path(img_files[idx]).stem
            qw, qx, qy, qz = rotmat2qvec(pose[:3, :3])
            tx, ty, tz = pose[:3, 3]
            # The index serves as both image id and camera id; the trailing
            # blank line is the (empty) POINTS2D row.
            out.write(f"{idx} {qw} {qx} {qy} {qz} {tx} {ty} {tz} {idx} {stem}.png\n\n")

# Save point cloud with normals
def save_pointcloud_with_normals(imgs, pts3d, masks, sparse_path):
    """Write the masked point cloud to ``points3D.ply`` as ASCII PLY.

    Every vertex gets the same placeholder normal ``(0, 1, 0)``; no normals
    are estimated here.

    Args:
        imgs: Images forwarded to ``get_point_cloud`` for the colors.
        pts3d: Per-view 3-D points forwarded to ``get_point_cloud``.
        masks: Per-view masks forwarded to ``get_point_cloud``.
        sparse_path: Output directory. A ``str`` or a ``pathlib.Path`` is
            accepted; the file is written as ``<sparse_path>/points3D.ply``.
    """
    pc = get_point_cloud(imgs, pts3d, masks)
    vertices = pc.vertices
    colors = pc.colors
    # Constant placeholder normal written for every vertex.
    nx, ny, nz = 0, 1, 0
    # Coerce to Path so that a plain string directory works too; `str / str`
    # would raise TypeError.
    save_path = Path(sparse_path) / 'points3D.ply'
    header = """ply
format ascii 1.0
element vertex {}
property float x
property float y
property float z
property uchar red
property uchar green
property uchar blue
property float nx
property float ny
property float nz
end_header
""".format(len(vertices))
    with open(save_path, 'w') as f:
        f.write(header)
        # Column order matches the header: position, color, normal.
        f.writelines(
            f"{vertex[0]} {vertex[1]} {vertex[2]} {int(color[0])} {int(color[1])} {int(color[2])} {nx} {ny} {nz}\n"
            for vertex, color in zip(vertices, colors)
        )

# Generate point cloud
def get_point_cloud(imgs, pts3d, mask):
    """Build a colored ``trimesh.PointCloud`` from the masked points of all views.

    Points and colors selected by the mask are concatenated over the views,
    then every third row is kept. A constant ``(0, 1, 0)`` normal per point is
    attached as the ``vertices_normal`` attribute.

    Args:
        imgs: Per-view images supplying the colors.
        pts3d: Per-view points, indexable by the flattened per-view mask.
        mask: Per-view boolean masks; after ``to_numpy`` it must be an array
            whose first axis is the view index.

    Returns:
        The ``trimesh.PointCloud``.
    """
    imgs, pts3d, mask = to_numpy(imgs), to_numpy(pts3d), to_numpy(mask)

    # Points are indexed with a flattened mask, images with the original one.
    flat_masks = mask.reshape(mask.shape[0], -1)
    kept_points = [view_pts[keep] for view_pts, keep in zip(pts3d, flat_masks)]
    kept_colors = [view_img[keep] for view_img, keep in zip(imgs, mask)]

    # Keep every third row.
    pts = np.concatenate(kept_points).reshape(-1, 3)[::3]
    col = np.concatenate(kept_colors).reshape(-1, 3)[::3]

    cloud = trimesh.PointCloud(pts, colors=col)
    cloud.vertices_normal = np.tile([0, 1, 0], (pts.shape[0], 1))
    return cloud

def _frame_sort_key(path):
    """Sort key: first integer in the file stem; names without digits go last."""
    stem = Path(path).stem
    match = re.search(r'\d+', stem)
    if match is None:
        return (1, 0, stem)
    return (0, int(match.group()), stem)


def main(image_dir, save_dir, model_path, device, batch_size, image_size, schedule, lr, niter, min_conf_thr, tsdf_thresh):
    """Reconstruct a scene from a folder of images and export COLMAP-style files.

    Steps: load the model and the ``.png``/``.jpg`` images, build all image
    pairs, run sparse global alignment, extract poses/intrinsics/dense points,
    then write images, masks, ``cameras.txt``, ``images.txt`` and
    ``points3D.ply`` under ``save_dir``.

    Args:
        image_dir: Directory scanned (non-recursively) for input images.
        save_dir: Output root; ``<save_dir>/cache`` is wiped and reused as the
            alignment cache.
        model_path: Checkpoint handed to ``AsymmetricMASt3R.from_pretrained``.
        device: Device for the model and the alignment.
        batch_size: Accepted for interface compatibility; not used here.
        image_size: ``size`` argument of ``load_images``.
        schedule: Accepted for interface compatibility; not used here.
        lr: Accepted for interface compatibility; not used here (the alignment
            runs with fixed ``lr1``/``lr2``).
        niter: Accepted for interface compatibility; not used here (the
            alignment runs with fixed ``niter1``/``niter2``).
        min_conf_thr: Points whose confidence is not above this are masked out.
        tsdf_thresh: ``TSDF_thresh`` argument of ``TSDFPostProcess``.
    """
    # Load model and images
    model = AsymmetricMASt3R.from_pretrained(model_path).to(device)
    # Frames are ordered by the first number in their name; names without any
    # digit no longer crash the sort, they are placed at the end.
    image_files = sorted(
        (str(x) for x in Path(image_dir).iterdir() if x.suffix in ['.png', '.jpg']),
        key=_frame_sort_key)
    images = load_images(image_files, size=image_size)

    # Generate pairs and run inference
    pairs = make_pairs(images, scene_graph='complete', prefilter=None, symmetrize=True)

    # Start from an empty cache. shutil.rmtree avoids building a shell command
    # from a user-supplied path (spaces or shell metacharacters in save_dir);
    # ignore_errors mirrors the forgiving behaviour of `rm -rf`.
    cache_dir = os.path.join(save_dir, 'cache')
    if os.path.exists(cache_dir):
        shutil.rmtree(cache_dir, ignore_errors=True)
    # opt_depth was spelled `'depth' in 'refine'`, which is always False.
    scene = sparse_global_alignment(image_files, pairs, cache_dir,
                                    model, lr1=0.07, niter1=500, lr2=0.014, niter2=200, device=device,
                                    opt_depth=False, shared_intrinsics=False,
                                    matching_conf_thr=5.)

    # Extract scene information
    # The scene's image poses are inverted to obtain world-to-camera matrices.
    world2cam = invert_matrix(scene.get_im_poses().detach()).cpu().numpy()
    principal_points = scene.get_principal_points().detach().cpu().numpy()
    focals = scene.get_focals().detach().cpu().numpy()
    imgs = np.array(scene.imgs)

    tsdf = TSDFPostProcess(scene, TSDF_thresh=tsdf_thresh)
    pts3d, _, confs = to_numpy(tsdf.get_dense_pts3d(clean_depth=True))
    masks = np.array(to_numpy([c > min_conf_thr for c in confs]))

    # Feature visualisation -- presumably a PCA projection to 64 channels that
    # is upsampled to the image resolution; confirm in utils.dust3r_utils.
    _, H, W, _ = imgs.shape
    feat_dim = 64
    projected_feat = pca(scene.stacked_feat, feat_dim)
    upsampled_feat = upsampler(projected_feat, H, W)
    visualizer(upsampled_feat, images, save_dir)

    # Main execution
    save_path, images_path, masks_path, sparse_path = init_filestructure(Path(save_dir))
    save_images_and_masks(imgs, masks, images_path, image_files, masks_path)
    save_cameras(focals, principal_points, sparse_path, imgs_shape=imgs.shape)
    save_images_txt(world2cam, image_files, sparse_path)
    # The point cloud belongs next to cameras.txt / images.txt; passing the
    # string save_dir here raised TypeError on `str / str`.
    save_pointcloud_with_normals(imgs, pts3d, masks, sparse_path)

    print(f'[INFO] Mast3R Reconstruction is successfully converted to COLMAP files in: {str(sparse_path)}')

if __name__ == "__main__":
    # Command-line entry point: every option maps one-to-one onto a parameter
    # of main() with the same name.
    parser = argparse.ArgumentParser(description='Process images and save results.')

    # Mandatory string options.
    required_options = (
        ('--image_dir', 'Directory containing images'),
        ('--save_dir', 'Directory to save the results'),
        ('--model_path', 'Path to the model checkpoint'),
    )
    for flag, description in required_options:
        parser.add_argument(flag, type=str, required=True, help=description)

    # Optional settings: (flag, type, default, help).
    # NOTE: main() accepts --batch_size, --schedule, --lr and --niter but its
    # body does not read them.
    optional_options = (
        ('--device', str, 'cuda', 'Device to use for inference'),
        ('--batch_size', int, 1, 'Batch size for processing images'),
        ('--image_size', int, 512, 'Size to resize images'),
        ('--schedule', str, 'cosine', 'Learning rate schedule'),
        ('--lr', float, 0.01, 'Learning rate'),
        ('--niter', int, 300, 'Number of iterations'),
        ('--min_conf_thr', float, 1.5, 'Minimum confidence threshold'),
        ('--tsdf_thresh', float, 0.0, 'TSDF threshold'),
    )
    for flag, value_type, default, description in optional_options:
        parser.add_argument(flag, type=value_type, default=default, help=description)

    args = parser.parse_args()
    # The namespace attribute names equal main()'s parameter names.
    main(**vars(args))