|
import os |
|
import torch |
|
import argparse |
|
import mcubes |
|
import trimesh |
|
import numpy as np |
|
from tqdm import tqdm |
|
from omegaconf import OmegaConf |
|
from utility.initialize import instantiate_from_config, get_obj_from_str |
|
from utility.triplane_renderer.eg3d_renderer import sample_from_planes, generate_planes |
|
|
|
|
|
# Command-line interface: a model config file and a checkpoint path, both mandatory.
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, required=True)
parser.add_argument("--ckpt", type=str, required=True)
args = parser.parse_args()

device = 'cuda'
configs = OmegaConf.load(args.config)

# First-stage model: the class named by ``target`` is built directly from the
# ``params`` in the config (no checkpoint is passed to it here), moved to the
# device and put in eval mode.
_first_stage_cfg = configs.model.params.first_stage_config
vae = get_obj_from_str(_first_stage_cfg['target'])(**_first_stage_cfg['params']).to(device)
vae.eval()

# Main model: restored from ``--ckpt`` onto the CPU with non-strict loading,
# then moved to the device.
_model_cls = get_obj_from_str(configs.model["target"])
model = _model_cls.load_from_checkpoint(
    args.ckpt, map_location='cpu', strict=False, **configs.model.params
).to(device)
|
|
|
def extract_mesh(triplane_fname, save_name=None):
    """Convert a saved latent into a vertex-coloured PLY mesh.

    Steps:
      1. Load the latent from ``triplane_fname`` and decode it with the
         module-level ``model`` (inside ``model.ema_scope()``).
      2. Sample the decoded triplane on a regular ``res**3`` grid covering
         ``[-1.2, 1.2]`` on each axis and read ``'sigma'`` from the
         module-level ``vae``'s triplane decoder.
      3. Run marching cubes on that volume and map the resulting vertices
         from grid indices back to ``[-1.2, 1.2]`` coordinates.
      4. For each of six axis-aligned ray origins, cast one ray at every
         vertex and keep, per vertex, the colour from the origin whose
         rendered depth is closest to the real origin-to-vertex distance.
      5. Export the mesh as PLY.

    Args:
        triplane_fname: Path of the file handed to ``np.load``.
        save_name: Output path of the PLY file. When falsy, the mesh is
            written next to ``triplane_fname`` with its extension replaced
            by ``.ply``.

    Returns:
        None. The mesh is written to disk.
    """
    res = 128            # grid samples per axis for marching cubes
    res_triplane = 256   # spatial size the triplane is reshaped to
    box_warp = 2.4       # passed to the plane sampler and to the renderer
    density_threshold = 8  # iso-value handed to marching cubes

    latent = torch.from_numpy(np.load(triplane_fname)).to(device)
    with torch.no_grad():
        with model.ema_scope():
            triplane = model.decode_first_stage(latent)
    planes = triplane.reshape(1, 3, -1, res_triplane, res_triplane)

    # --- Density volume on a regular grid -------------------------------
    c_list = torch.linspace(-1.2, 1.2, steps=res)
    grid_x, grid_y, grid_z = torch.meshgrid(c_list, c_list, c_list, indexing='ij')
    coords = torch.stack([grid_x, grid_y, grid_z], -1).to(device)
    plane_axes = generate_planes()
    feats = sample_from_planes(
        plane_axes, planes, coords.reshape(1, -1, 3), padding_mode='zeros', box_warp=box_warp
    )
    # The decoder takes a direction argument; a constant (1, 0, 0) is supplied
    # for every grid point since only 'sigma' is read from the output.
    fake_dirs = torch.zeros_like(coords)
    fake_dirs[..., 0] = 1
    with torch.no_grad():
        out = vae.triplane_decoder.decoder(feats, fake_dirs)
        u = out['sigma'].reshape(res, res, res).detach().cpu().numpy()
    del out

    # --- Surface extraction ---------------------------------------------
    vertices, triangles = mcubes.marching_cubes(u, density_threshold)
    min_bound = np.array([-1.2, -1.2, -1.2])
    max_bound = np.array([1.2, 1.2, 1.2])
    # Rescale from grid-index units (0 .. res-1) to the sampled coordinate range.
    vertices = vertices / (res - 1) * (max_bound - min_bound)[None, :] + min_bound[None, :]
    pt_vertices = torch.from_numpy(vertices).to(device)

    # --- Per-vertex colour from six axis-aligned ray origins -------------
    render_kwargs = {
        'depth_resolution': 128,
        'disparity_space_sampling': False,
        'box_warp': box_warp,
        'depth_resolution_importance': 128,
        'clamp_mode': 'softplus',
        'white_back': True,
        'det': True,
    }
    rays_o_list = [
        np.array([0, 0, 2]),
        np.array([0, 0, -2]),
        np.array([0, 2, 0]),
        np.array([0, -2, 0]),
        np.array([2, 0, 0]),
        np.array([-2, 0, 0]),
    ]
    rgb_final = None
    diff_final = None
    for origin in tqdm(rays_o_list):
        rays_o = torch.from_numpy(origin.reshape(1, 3)).repeat(vertices.shape[0], 1).float().to(device)
        # One ray per vertex: the same offset yields both the unit direction
        # and the reference distance, so its norm is computed only once.
        offsets = pt_vertices.reshape(-1, 3) - rays_o
        offset_norm = torch.norm(offsets, dim=-1)
        rays_d = offsets / offset_norm.reshape(-1, 1)
        dist = offset_norm.cpu().numpy().reshape(-1)

        with torch.no_grad():
            render_out = vae.triplane_decoder(
                planes, rays_o.unsqueeze(0), rays_d.unsqueeze(0), render_kwargs,
                whole_img=False, tvloss=False,
            )
        rgb = render_out['rgb_marched'].reshape(-1, 3).detach().cpu().numpy()
        depth = render_out['depth_final'].reshape(-1).detach().cpu().numpy()
        # Small difference => the rendered depth agrees with the distance to
        # the vertex for this origin.
        depth_diff = np.abs(dist - depth)

        if rgb_final is None:
            rgb_final = rgb.copy()
            diff_final = depth_diff.copy()
        else:
            # Overwrite only the vertices this origin matches better.
            ind = diff_final > depth_diff
            rgb_final[ind] = rgb[ind]
            diff_final[ind] = depth_diff[ind]

    # Reverse the channel order (presumably the renderer's channels are in the
    # opposite order to what the PLY colours need — TODO confirm).
    rgb_final = rgb_final[:, ::-1]
    # Clip before the uint8 cast so slightly out-of-range values cannot wrap.
    vertex_colors = (np.clip(rgb_final, 0.0, 1.0) * 255).astype(np.uint8)

    # --- Export ----------------------------------------------------------
    mesh = trimesh.Trimesh(vertices, triangles, vertex_colors=vertex_colors)
    if not save_name:
        # splitext handles any extension length, unlike slicing off 4 chars.
        save_name = os.path.splitext(triplane_fname)[0] + '.ply'
    trimesh.exchange.export.export_mesh(mesh, save_name, file_type='ply')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Input directory with the sampled outputs; meshes are written to a sibling
# directory named "<folder>_extract_mesh".
folder = '/mnt/lustre/hongfangzhou.p/AE3D/log/diff_res32ch8_preprocess_ca_text_new_triplane_96_full_openaimodel_only_cap3d_high_quality_7w/sample_demo_424_prompts_for_demo_30_60_10'
save_folder = folder + '_extract_mesh'
os.makedirs(save_folder, exist_ok=True)

# Each directory entry starting with "_" is turned into a file name by
# substituting "_sample" -> "triplane" and "mp4" -> "npy"; the result is sorted.
fnames = sorted(
    f.replace('_sample', 'triplane').replace('mp4', 'npy')
    for f in os.listdir(folder)
    if f.startswith('_')
)

# Lines 30..59 of the prompt list, stripped of surrounding whitespace. The
# context manager closes the file, which a bare open(...).readlines() did not.
with open('test/prompts_for_demo_2.txt', 'r') as prompt_file:
    prompts = [line.strip() for line in prompt_file][30:60]
|
|
|
def extract_number(s):
    """Return the second-to-last ``_``-separated field of ``s`` as an int.

    Raises:
        IndexError: if ``s`` contains no underscore.
        ValueError: if that field is not an integer literal.
    """
    # Only the last two separators matter, so split from the right.
    fields = s.rsplit('_', 2)
    return int(fields[-2])
|
|
|
def extract_id(s):
    """Return the text after the last ``_`` in ``s``, cut at its first ``.``.

    If ``s`` has no underscore the whole string is treated as the last field.
    """
    last_field = s.rsplit('_', 1)[-1]
    stem, _, _ = last_field.partition('.')
    return stem
|
|
|
# Convert every file in turn. A failure on one file is printed and the loop
# moves on to the next one.
for fname in fnames:
    try:
        print(fname)
        # Output name: "<prompt with spaces as underscores>_<id>.ply", where
        # the prompt is looked up by the number embedded in the file name.
        prompt_slug = prompts[extract_number(fname)].replace(' ', '_')
        ply_name = prompt_slug + '_' + extract_id(fname) + '.ply'
        extract_mesh(os.path.join(folder, fname), os.path.join(save_folder, ply_name))
    except Exception as err:
        print(err)
|
|