# Spaces: Running on Zero
import os
# Set before cv2 is imported so the OpenEXR codec is enabled for the .exr writes below.
os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
from pathlib import Path
import sys
# Put the parent of this script's directory on sys.path (presumably so `moge` resolves
# when the script is run directly -- confirm against the repository layout).
sys.path.append(str(Path(__file__).absolute().parents[1]))
from typing import *
import itertools
import json
import warnings
import cv2
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
import trimesh
import trimesh.visual
import click
from moge.model import MoGeModel
from moge.utils.io import save_glb, save_ply
from moge.utils.vis import colorize_depth, colorize_normal
import utils3d
# Extensions (lower-cased, with leading dot) treated as images when scanning a folder.
_IMAGE_SUFFIXES = frozenset({'.jpg', '.jpeg', '.png'})


def _find_image_paths(input_path: str) -> List[Path]:
    """Return the image files designated by *input_path*, sorted.

    A directory is searched recursively for regular files with a jpg/jpeg/png
    extension (compared case-insensitively, so each file is listed once even
    on platforms whose globbing ignores case). Any other path is returned
    as the single candidate without checking that it exists.
    """
    root = Path(input_path)
    if not root.is_dir():
        return [root]
    return sorted(
        path for path in root.rglob('*')
        if path.is_file() and path.suffix.lower() in _IMAGE_SUFFIXES
    )


def _load_rgb(image_path: Path, resize_to: Optional[int]) -> Optional[np.ndarray]:
    """Read *image_path* as an RGB array, optionally resized; None if unreadable."""
    bgr = cv2.imread(str(image_path))
    if bgr is None:
        # cv2.imread signals failure by returning None instead of raising.
        return None
    image = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    if resize_to is not None:
        height, width = image.shape[:2]
        # The longer side becomes `resize_to`; the shorter one is scaled by the
        # same ratio. Clamp to 1 so extreme aspect ratios never yield a 0 size.
        new_height = max(1, min(resize_to, int(resize_to * height / width)))
        new_width = max(1, min(resize_to, int(resize_to * width / height)))
        # `interpolation` must be passed by keyword: the third positional
        # parameter of cv2.resize is `dst`, not the interpolation flag.
        image = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
    return image


def _write_maps(save_path: Path, image, depth, mask, points, intrinsics) -> None:
    """Write the image, depth, mask and point maps plus fov.json into *save_path*."""
    exr_float = [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT]
    cv2.imwrite(str(save_path / 'image.jpg'), cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
    cv2.imwrite(str(save_path / 'depth_vis.png'), cv2.cvtColor(colorize_depth(depth), cv2.COLOR_RGB2BGR))
    cv2.imwrite(str(save_path / 'depth.exr'), depth, exr_float)
    cv2.imwrite(str(save_path / 'mask.png'), (mask * 255).astype(np.uint8))
    # OpenCV writes channels in reverse order, so swap them to store x, y, z in order.
    cv2.imwrite(str(save_path / 'points.exr'), cv2.cvtColor(points, cv2.COLOR_RGB2BGR), exr_float)
    fov_x, fov_y = utils3d.numpy.intrinsics_to_fov(intrinsics)
    with open(save_path / 'fov.json', 'w') as f:
        json.dump({
            'fov_x': round(float(np.rad2deg(fov_x)), 2),
            'fov_y': round(float(np.rad2deg(fov_y)), 2),
        }, f)


def _export_mesh(
    save_path: Path,
    image,
    points,
    depth,
    mask,
    threshold: float,
    save_glb_: bool,
    save_ply_: bool,
    show: bool,
) -> None:
    """Triangulate the point map and save and/or display the resulting mesh."""
    height, width = image.shape[:2]
    normals, normals_mask = utils3d.numpy.points_to_normals(points, mask=mask)
    # Drop pixels flagged as BOTH a depth edge and a normal edge.
    depth_edges = utils3d.numpy.depth_edge(depth, rtol=threshold, mask=mask)
    normal_edges = utils3d.numpy.normals_edge(normals, tol=5, mask=normals_mask)
    faces, vertices, vertex_colors, vertex_uvs = utils3d.numpy.image_mesh(
        points,
        image.astype(np.float32) / 255,
        utils3d.numpy.image_uv(width=width, height=height),
        mask=mask & ~(depth_edges & normal_edges),
        tri=True
    )
    # When exporting the model, follow the OpenGL coordinate conventions:
    # - world coordinate system: x right, y up, z backward.
    # - texture coordinate system: (0, 0) for left-bottom, (1, 1) for right-top.
    vertices, vertex_uvs = vertices * [1, -1, -1], vertex_uvs * [1, -1] + [0, 1]

    if save_glb_:
        save_glb(save_path / 'mesh.glb', vertices, faces, vertex_uvs, image)
    if save_ply_:
        save_ply(save_path / 'mesh.ply', vertices, faces, vertex_colors)
    if show:
        trimesh.Trimesh(
            vertices=vertices,
            vertex_colors=vertex_colors,
            faces=faces,
            process=False
        ).show()


def main(
    input_path: str,
    fov_x_: Optional[float],
    output_path: str,
    pretrained_model_name_or_path: str,
    device_name: str,
    resize_to: Optional[int],
    resolution_level: int,
    threshold: float,
    save_maps_: bool,
    save_glb_: bool,
    save_ply_: bool,
    show: bool,
):
    """Run MoGe inference on one image or a folder of images and write the results.

    Args:
        input_path: An image file, or a directory that is searched recursively
            for jpg/jpeg/png files.
        fov_x_: Forwarded to ``model.infer`` as ``fov_x``.
            NOTE(review): presumably ``None`` lets the model estimate it -- confirm
            against ``MoGeModel.infer``.
        output_path: Root output directory. Each image gets its own folder
            ``<output_path>/<path relative to input_path>/<image stem>/``.
        pretrained_model_name_or_path: Passed to ``MoGeModel.from_pretrained``.
        device_name: Passed to ``torch.device``.
        resize_to: If not ``None``, the image is resized so that its longer
            side has this many pixels before inference.
        resolution_level: Accepted but currently not used by this function.
        threshold: Relative tolerance (``rtol``) used for depth-edge detection
            when building the mesh.
        save_maps_: Write image.jpg, depth_vis.png, depth.exr, mask.png,
            points.exr and fov.json.
        save_glb_: Write mesh.glb.
        save_ply_: Write mesh.ply.
        show: Open the mesh in a trimesh viewer.

    Raises:
        FileNotFoundError: If ``input_path`` is a directory without any image.
    """
    # Cheap validation first, before touching the device or loading weights.
    image_paths = _find_image_paths(input_path)
    if not image_paths:
        raise FileNotFoundError(f'No image files found in {input_path}')
    if not any([save_maps_, save_glb_, save_ply_]):
        warnings.warn('No output format specified. Please use "--maps", "--glb", or "--ply" to specify the output.')

    device = torch.device(device_name)
    model = MoGeModel.from_pretrained(pretrained_model_name_or_path).to(device).eval()

    for image_path in tqdm(image_paths, desc='Inference', disable=len(image_paths) <= 1):
        image = _load_rgb(image_path, resize_to)
        if image is None:
            # Keep a batch run alive when a single file is corrupt or unreadable.
            warnings.warn(f'Could not read image {image_path}; skipping.')
            continue
        image_tensor = torch.tensor(image / 255, dtype=torch.float32, device=device).permute(2, 0, 1)

        # Inference
        output = model.infer(image_tensor, fov_x=fov_x_)
        points = output['points'].cpu().numpy()
        depth = output['depth'].cpu().numpy()
        mask = output['mask'].cpu().numpy()
        intrinsics = output['intrinsics'].cpu().numpy()

        # Write outputs
        save_path = Path(output_path, image_path.relative_to(input_path).parent, image_path.stem)
        save_path.mkdir(exist_ok=True, parents=True)

        if save_maps_:
            _write_maps(save_path, image, depth, mask, points, intrinsics)

        # Export mesh & visualization
        if save_glb_ or save_ply_ or show:
            _export_mesh(save_path, image, points, depth, mask, threshold, save_glb_, save_ply_, show)
if __name__ == '__main__':
    # main() has twelve required parameters, so a bare `main()` call raises
    # TypeError. Build a click command around it instead; main itself stays a
    # plain function that can still be called directly.
    # NOTE(review): "--maps", "--glb" and "--ply" come from the warning text in
    # main(); the remaining option names and all defaults are reviewer-chosen
    # and should be confirmed against the project's documented CLI.
    _default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    _options = [
        click.option('--input', 'input_path', type=click.Path(exists=True), required=True,
                     help='Input image or folder path.'),
        click.option('--fov_x', 'fov_x_', type=float, default=None,
                     help='Horizontal field of view passed to the model, if known.'),
        click.option('--output', 'output_path', type=click.Path(), required=True,
                     help='Output folder path.'),
        click.option('--pretrained', 'pretrained_model_name_or_path', type=str, required=True,
                     help='Pretrained model name or path.'),
        click.option('--device', 'device_name', type=str, default=_default_device, show_default=True,
                     help='Torch device name, e.g. "cuda", "cuda:0" or "cpu".'),
        click.option('--resize', 'resize_to', type=int, default=None,
                     help='Resize the longer image side to this many pixels before inference.'),
        click.option('--resolution_level', 'resolution_level', type=int, default=9, show_default=True,
                     help='Resolution level (currently not used by main).'),
        click.option('--threshold', 'threshold', type=float, default=0.03, show_default=True,
                     help='Relative depth tolerance for removing mesh edges.'),
        click.option('--maps', 'save_maps_', is_flag=True,
                     help='Save the image, depth, mask and point maps plus fov.json.'),
        click.option('--glb', 'save_glb_', is_flag=True, help='Save the mesh as mesh.glb.'),
        click.option('--ply', 'save_ply_', is_flag=True, help='Save the mesh as mesh.ply.'),
        click.option('--show', 'show', is_flag=True, help='Show the mesh in a viewer window.'),
    ]
    _cli = main
    # Apply in reverse so --help lists the options in the order written above,
    # exactly as a stack of @click.option decorators would.
    for _option in reversed(_options):
        _cli = _option(_cli)
    click.command(help='Inference script for the MoGe model.')(_cli)()