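"""Extract colored meshes from sampled triplane latents.

For each latent .npy in a sample folder: decode it to an explicit triplane with
the first-stage VAE, query densities on a dense 128^3 grid, run marching cubes,
bake per-vertex colors by rendering the triplane from six axis-aligned
viewpoints, and export the result as a .ply.
"""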
import os
import torch
import argparse
import mcubes
import trimesh
import numpy as np
from tqdm import tqdm
from omegaconf import OmegaConf
from utility.initialize import instantiate_from_config, get_obj_from_str
from utility.triplane_renderer.eg3d_renderer import sample_from_planes, generate_planes
# load model
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default=None, required=True)
parser.add_argument("--ckpt", type=str, default=None, required=True)
args = parser.parse_args()
configs = OmegaConf.load(args.config)
device = 'cuda'
vae = get_obj_from_str(configs.model.params.first_stage_config['target'])(
    **configs.model.params.first_stage_config['params']
)
vae = vae.to(device)
vae.eval()
model = get_obj_from_str(configs.model["target"]).load_from_checkpoint(
    args.ckpt, map_location='cpu', strict=False, **configs.model.params
)
model = model.to(device)
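# `vae` is the first-stage triplane autoencoder, used directly below for
# density and color queries; `model` is the latent diffusion wrapper whose
# decode_first_stage() maps a sampled latent back to an explicit triplane.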
def extract_mesh(triplane_fname, save_name=None):
    latent = torch.from_numpy(np.load(triplane_fname)).to(device)
    with torch.no_grad():
        with model.ema_scope():
            triplane = model.decode_first_stage(latent)

    # prepare volume for marching cubes
    res = 128
    c_list = torch.linspace(-1.2, 1.2, steps=res)
    grid_x, grid_y, grid_z = torch.meshgrid(
        c_list, c_list, c_list, indexing='ij'
    )
    coords = torch.stack([grid_x, grid_y, grid_z], -1).to(device)  # res x res x res x 3
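    # project each grid point onto the three planes to fetch per-point features;
    # box_warp=2.4 tells the sampler the planes cover [-1.2, 1.2] per axis,
    # matching the grid above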
    plane_axes = generate_planes()
    feats = sample_from_planes(
        plane_axes, triplane.reshape(1, 3, -1, 256, 256), coords.reshape(1, -1, 3),
        padding_mode='zeros', box_warp=2.4
    )
    # only the density is needed here, so a constant dummy view direction suffices
    fake_dirs = torch.zeros_like(coords)
    fake_dirs[..., 0] = 1
    with torch.no_grad():
        out = vae.triplane_decoder.decoder(feats, fake_dirs)
    u = out['sigma'].reshape(res, res, res).detach().cpu().numpy()
    del out
    # run marching cubes on the density grid, then map vertex coordinates
    # from index space [0, res - 1] back to the [-1.2, 1.2]^3 volume
    vertices, triangles = mcubes.marching_cubes(u, 8)
    min_bound = np.array([-1.2, -1.2, -1.2])
    max_bound = np.array([1.2, 1.2, 1.2])
    vertices = vertices / (res - 1) * (max_bound - min_bound)[None, :] + min_bound[None, :]
    pt_vertices = torch.from_numpy(vertices).to(device)

    # extract vertex colors
    res_triplane = 256
    render_kwargs = {
        'depth_resolution': 128,
        'disparity_space_sampling': False,
        'box_warp': 2.4,
        'depth_resolution_importance': 128,
        'clamp_mode': 'softplus',
        'white_back': True,
        'det': True
    }
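    # volume rendering settings: 128 coarse plus 128 importance samples per ray
    # over the same 2.4-wide box; 'det': True presumably disables stratified
    # jitter so repeated queries return identical colors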
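    # bake per-vertex colors: from each of six axis-aligned origins at radius 2,
    # cast one ray through every vertex and keep, per vertex, the color from the
    # view whose rendered depth best agrees with the camera-to-vertex distance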
    rays_o_list = [
        np.array([0, 0, 2]),
        np.array([0, 0, -2]),
        np.array([0, 2, 0]),
        np.array([0, -2, 0]),
        np.array([2, 0, 0]),
        np.array([-2, 0, 0]),
    ]
    rgb_final = None
    diff_final = None
    for rays_o in tqdm(rays_o_list):
        rays_o = torch.from_numpy(rays_o.reshape(1, 3)).repeat(vertices.shape[0], 1).float().to(device)
        rays_d = pt_vertices.reshape(-1, 3) - rays_o
        rays_d = rays_d / torch.norm(rays_d, dim=-1).reshape(-1, 1)
        dist = torch.norm(pt_vertices.reshape(-1, 3) - rays_o, dim=-1).cpu().numpy().reshape(-1)
        with torch.no_grad():
            render_out = vae.triplane_decoder(
                triplane.reshape(1, 3, -1, res_triplane, res_triplane),
                rays_o.unsqueeze(0), rays_d.unsqueeze(0), render_kwargs,
                whole_img=False, tvloss=False
            )
        rgb = render_out['rgb_marched'].reshape(-1, 3).detach().cpu().numpy()
        depth = render_out['depth_final'].reshape(-1).detach().cpu().numpy()
        depth_diff = np.abs(dist - depth)
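        # a small |dist - depth| means the ray terminated at the vertex itself,
        # so this viewpoint sees the vertex unoccluded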
        if rgb_final is None:
            rgb_final = rgb.copy()
            diff_final = depth_diff.copy()
        else:
            ind = diff_final > depth_diff
            rgb_final[ind] = rgb[ind]
            diff_final[ind] = depth_diff[ind]
    # bgr to rgb
    rgb_final = np.stack([
        rgb_final[:, 2], rgb_final[:, 1], rgb_final[:, 0]
    ], -1)

    # export to ply
    mesh = trimesh.Trimesh(vertices, triangles, vertex_colors=(rgb_final * 255).astype(np.uint8))
    if save_name:
        trimesh.exchange.export.export_mesh(mesh, save_name, file_type='ply')
    else:
        trimesh.exchange.export.export_mesh(mesh, triplane_fname[:-4] + '.ply', file_type='ply')
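# Example of calling the helper on a single latent (hypothetical paths):
#   extract_mesh('samples/triplane_0_0.npy', save_name='samples/mesh_0_0.ply')
# If save_name is omitted, the mesh is written next to the input .npy.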
# gather the sampled triplane latents in `folder`: each '_sample*.mp4' preview
# is mapped to its matching 'triplane*.npy' latent; the [30:60] prompt slice
# presumably matches the '30_60' range in the folder name
folder = '/mnt/lustre/hongfangzhou.p/AE3D/log/diff_res32ch8_preprocess_ca_text_new_triplane_96_full_openaimodel_only_cap3d_high_quality_7w/sample_demo_424_prompts_for_demo_30_60_10'
save_folder = folder + '_extract_mesh'
os.makedirs(save_folder, exist_ok=True)
fnames = [f.replace('_sample', 'triplane').replace('mp4', 'npy') for f in os.listdir(folder) if f.startswith('_')]
prompts = [l.strip() for l in open('test/prompts_for_demo_2.txt', 'r').readlines()][30:60]
fnames = sorted(fnames)
def extract_number(s):
    # prompt index from filenames like 'triplane_<prompt_idx>_<sample_id>.npy'
    return int(s.split('_')[-2])

def extract_id(s):
    # sample id from the same filename pattern
    return s.split('_')[-1].split('.')[0]
for fname in fnames:
    try:
        print(fname)
        extract_mesh(
            os.path.join(folder, fname),
            os.path.join(save_folder, prompts[extract_number(fname)].replace(' ', '_') + '_' + extract_id(fname) + '.ply')
        )
    except Exception as e:
        print(e)
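# Typical invocation (config and checkpoint paths are placeholders):
#   python utility/mcubes_from_latent.py --config <config.yaml> --ckpt <model.ckpt>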