import os
import copy
import sys
import json
import importlib
import argparse
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
import utils3d
from tqdm import tqdm
from easydict import EasyDict as edict
from concurrent.futures import ThreadPoolExecutor
from queue import Queue
from torchvision import transforms
from PIL import Image

torch.set_grad_enabled(False)

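# Yield one dict per rendered view of the object identified by `sha256`: the RGBA
# render composited onto black and resized to DINOv2's 518x518 input resolution,
# plus camera extrinsics/intrinsics rebuilt from the view's transform matrix.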
def get_data(frames, sha256):
    with ThreadPoolExecutor(max_workers=16) as executor:
        def worker(view):
            image_path = os.path.join(opt.output_dir, 'renders', sha256, view['file_path'])
            try:
                image = Image.open(image_path)
            except Exception:
                print(f"Error loading image {image_path}")
                return None
            image = image.resize((518, 518), Image.Resampling.LANCZOS)
            image = np.array(image).astype(np.float32) / 255
            image = image[:, :, :3] * image[:, :, 3:]
            image = torch.from_numpy(image).permute(2, 0, 1).float()

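            # The transform matrix is a camera-to-world pose in the OpenGL/Blender
            # convention used by NeRF-style transforms.json files; flipping the y and
            # z columns converts it to the OpenCV convention before inverting it into
            # world-to-camera extrinsics.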
            c2w = torch.tensor(view['transform_matrix'])
            c2w[:3, 1:3] *= -1
            extrinsics = torch.inverse(c2w)
            fov = view['camera_angle_x']
            intrinsics = utils3d.torch.intrinsics_from_fov_xy(torch.tensor(fov), torch.tensor(fov))

            return {
                'image': image,
                'extrinsics': extrinsics,
                'intrinsics': intrinsics
            }

        datas = executor.map(worker, frames)
        for data in datas:
            if data is not None:
                yield data


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Directory to save the metadata')
    parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
                        help='Filter objects with aesthetic score lower than this value')
    parser.add_argument('--model', type=str, default='dinov2_vitl14_reg',
                        help='Feature extraction model')
    parser.add_argument('--instances', type=str, default=None,
                        help='Instances to process')
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--rank', type=int, default=0)
    parser.add_argument('--world_size', type=int, default=1)
    opt = parser.parse_args()
    opt = edict(vars(opt))

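    # Example invocation (the script filename and dataset directory here are
    # illustrative, not prescribed by this file):
    #   python extract_feature.py --output_dir /path/to/dataset --rank 0 --world_size 1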
    feature_name = opt.model
    os.makedirs(os.path.join(opt.output_dir, 'features', feature_name), exist_ok=True)

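    # Load the DINOv2 backbone from torch.hub. Inputs are ImageNet-normalized and
    # 518x518, giving a 37x37 grid of 14x14 patches; note that the reshape further
    # down hard-codes a 1024-dim token width, i.e. it assumes a ViT-L variant such
    # as the default dinov2_vitl14_reg.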
    dinov2_model = torch.hub.load('facebookresearch/dinov2', opt.model)
    dinov2_model.eval().cuda()
    transform = transforms.Compose([
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    n_patch = 518 // 14

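    # Build the work list from metadata.csv: optionally restrict to an explicit
    # instance list, drop objects below the aesthetic-score threshold, skip objects
    # already marked as having this feature, and keep only voxelized, rendered ones.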
    if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
        metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
    else:
        raise ValueError('metadata.csv not found')
    if opt.instances is not None:
        with open(opt.instances, 'r') as f:
            instances = f.read().splitlines()
        metadata = metadata[metadata['sha256'].isin(instances)]
    else:
        if opt.filter_low_aesthetic_score is not None:
            metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
        if f'feature_{feature_name}' in metadata.columns:
            metadata = metadata[metadata[f'feature_{feature_name}'] == False]
        metadata = metadata[metadata['voxelized'] == True]
        metadata = metadata[metadata['rendered'] == True]

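    # Shard the remaining objects evenly across ranks so several processes (or GPUs)
    # can run this script in parallel with different --rank values.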
    start = len(metadata) * opt.rank // opt.world_size
    end = len(metadata) * (opt.rank + 1) // opt.world_size
    metadata = metadata[start:end]
    records = []

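    # Objects that already have a saved .npz feature file are recorded as done and
    # removed from the work list.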
    sha256s = list(metadata['sha256'].values)
    for sha256 in copy.copy(sha256s):
        if os.path.exists(os.path.join(opt.output_dir, 'features', feature_name, f'{sha256}.npz')):
            records.append({'sha256': sha256, f'feature_{feature_name}': True})
            sha256s.remove(sha256)

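    # Producer/consumer pipeline: loader threads read the rendered views and voxel
    # positions for each object and push them into a bounded queue (so only a few
    # objects are held in memory at once), the main thread runs the GPU forward
    # passes, and saver threads aggregate and write the features asynchronously.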
    load_queue = Queue(maxsize=4)
    try:
        with ThreadPoolExecutor(max_workers=8) as loader_executor, \
             ThreadPoolExecutor(max_workers=8) as saver_executor:
            def loader(sha256):
                try:
                    with open(os.path.join(opt.output_dir, 'renders', sha256, 'transforms.json'), 'r') as f:
                        metadata = json.load(f)
                    frames = metadata['frames']
                    data = []
                    for datum in get_data(frames, sha256):
                        datum['image'] = transform(datum['image'])
                        data.append(datum)
                    positions = utils3d.io.read_ply(os.path.join(opt.output_dir, 'voxels', f'{sha256}.ply'))[0]
                    load_queue.put((sha256, data, positions))
                except Exception as e:
                    print(f"Error loading data for {sha256}: {e}")

            loader_executor.map(loader, sha256s)

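            # Aggregate per-voxel features: sample each view's patch-token map at the
            # voxel's projected (u, v) location with bilinear grid_sample, giving a
            # (views, voxels, channels) array, then average over views and store the
            # result as float16 alongside the voxel indices.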
            def saver(sha256, pack, patchtokens, uv):
                pack['patchtokens'] = F.grid_sample(
                    patchtokens,
                    uv.unsqueeze(1),
                    mode='bilinear',
                    align_corners=False,
                ).squeeze(2).permute(0, 2, 1).cpu().numpy()
                pack['patchtokens'] = np.mean(pack['patchtokens'], axis=0).astype(np.float16)
                save_path = os.path.join(opt.output_dir, 'features', feature_name, f'{sha256}.npz')
                np.savez_compressed(save_path, **pack)
                records.append({'sha256': sha256, f'feature_{feature_name}': True})

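            # Per-object loop: voxel centers in [-0.5, 0.5]^3 are mapped to integer
            # indices in a 64^3 grid, the views are run through DINOv2 in batches,
            # and each voxel is projected into every view to get normalized (u, v)
            # coordinates in [-1, 1] for grid_sample.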
            for _ in tqdm(range(len(sha256s)), desc="Extracting features"):
                sha256, data, positions = load_queue.get()
                positions = torch.from_numpy(positions).float().cuda()
                indices = ((positions + 0.5) * 64).long()
                assert torch.all(indices >= 0) and torch.all(indices < 64), "Some vertices are out of bounds"
                n_views = len(data)
                N = positions.shape[0]
                pack = {
                    'indices': indices.cpu().numpy().astype(np.uint8),
                }
                patchtokens_lst = []
                uv_lst = []
                for i in range(0, n_views, opt.batch_size):
                    batch_data = data[i:i+opt.batch_size]
                    bs = len(batch_data)
                    batch_images = torch.stack([d['image'] for d in batch_data]).cuda()
                    batch_extrinsics = torch.stack([d['extrinsics'] for d in batch_data]).cuda()
                    batch_intrinsics = torch.stack([d['intrinsics'] for d in batch_data]).cuda()
                    features = dinov2_model(batch_images, is_training=True)
                    uv = utils3d.torch.project_cv(positions, batch_extrinsics, batch_intrinsics)[0] * 2 - 1
                    patchtokens = features['x_prenorm'][:, dinov2_model.num_register_tokens + 1:].permute(0, 2, 1).reshape(bs, 1024, n_patch, n_patch)
                    patchtokens_lst.append(patchtokens)
                    uv_lst.append(uv)
                patchtokens = torch.cat(patchtokens_lst, dim=0)
                uv = torch.cat(uv_lst, dim=0)

                saver_executor.submit(saver, sha256, pack, patchtokens, uv)

            saver_executor.shutdown(wait=True)
    except Exception as e:
        print(f"Error happened during processing: {e}")

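    # Each rank writes its own CSV of processed objects so parallel runs do not
    # overwrite each other's records.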
    records = pd.DataFrame.from_records(records)
    records.to_csv(os.path.join(opt.output_dir, f'feature_{feature_name}_{opt.rank}.csv'), index=False)