# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import copy
import os
import unittest

import torch
import torchvision

from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
from pytorch3d.implicitron.dataset.visualize import get_implicitron_sequence_pointcloud
from pytorch3d.implicitron.tools.config import expand_args_fields
from pytorch3d.implicitron.tools.point_cloud_utils import render_point_cloud_pytorch3d
from pytorch3d.vis.plotly_vis import plot_scene

if os.environ.get("INSIDE_RE_WORKER") is None:
    from visdom import Visdom

from tests.common_testing import interactive_testing_requested

from .common_resources import get_skateboard_data

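# Default port of the Visdom server; can be overridden with the VISDOM_PORT
# environment variable.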
VISDOM_PORT = int(os.environ.get("VISDOM_PORT", 8097))


class TestDatasetVisualize(unittest.TestCase):
def setUp(self):
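        """Build JsonIndexDataset variants over the "skateboard" test data.

        Runs only when interactive testing is requested, since it needs the
        dataset files and, optionally, a running Visdom server.
        """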
if not interactive_testing_requested():
return
category = "skateboard"
stack = contextlib.ExitStack()
dataset_root, path_manager = stack.enter_context(get_skateboard_data())
self.addCleanup(stack.close)
frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz")
sequence_file = os.path.join(dataset_root, category, "sequence_annotations.jgz")
self.image_size = 256
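        # expand_args_fields processes the configurable class so that it can
        # be instantiated directly with keyword arguments.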
expand_args_fields(JsonIndexDataset)
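        # Three variants: square crops, non-square crops, and no bounding-box
        # cropping, to exercise different image and camera configurations.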
self.datasets = {
"simple": JsonIndexDataset(
frame_annotations_file=frame_file,
sequence_annotations_file=sequence_file,
dataset_root=dataset_root,
image_height=self.image_size,
image_width=self.image_size,
box_crop=True,
load_point_clouds=True,
path_manager=path_manager,
),
"nonsquare": JsonIndexDataset(
frame_annotations_file=frame_file,
sequence_annotations_file=sequence_file,
dataset_root=dataset_root,
image_height=self.image_size,
image_width=self.image_size // 2,
box_crop=True,
load_point_clouds=True,
path_manager=path_manager,
),
"nocrop": JsonIndexDataset(
frame_annotations_file=frame_file,
sequence_annotations_file=sequence_file,
dataset_root=dataset_root,
image_height=self.image_size,
image_width=self.image_size // 2,
box_crop=False,
load_point_clouds=True,
path_manager=path_manager,
),
}
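        # Duplicate each variant with its annotations converted to the
        # "ndc_isotropic" intrinsics convention.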
self.datasets.update(
{
k + "_newndc": _change_annotations_to_new_ndc(dataset)
for k, dataset in self.datasets.items()
}
)
self.visdom = Visdom(port=VISDOM_PORT)
if not self.visdom.check_connection():
print("Visdom server not running! Disabling visdom visualizations.")
self.visdom = None

    def _render_one_pointcloud(self, point_cloud, cameras, render_size):
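        """Render the point cloud from a single camera, clamped to [0, 1]."""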
(_image_render, _, _) = render_point_cloud_pytorch3d(
cameras,
point_cloud,
render_size=render_size,
point_radius=1e-2,
topk=10,
bg_color=0.0,
)
return _image_render.clamp(0.0, 1.0)

    def test_one(self):
"""Test dataset visualization."""
if not interactive_testing_requested():
return
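        # Sweep over frame budgets, point-cloud sources, and dataset variants.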
for max_frames in (16, -1):
for load_dataset_point_cloud in (True, False):
for dataset_key in self.datasets:
self._gen_and_render_pointcloud(
max_frames, load_dataset_point_cloud, dataset_key
)

    def _gen_and_render_pointcloud(
self, max_frames, load_dataset_point_cloud, dataset_key
):
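        """Build a sequence point cloud and render it against the GT frames."""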
dataset = self.datasets[dataset_key]
# load the point cloud of the first sequence
sequence_show = list(dataset.seq_annots.keys())[0]
device = torch.device("cuda:0")
point_cloud, sequence_frame_data = get_implicitron_sequence_pointcloud(
dataset,
sequence_name=sequence_show,
mask_points=True,
max_frames=max_frames,
num_workers=10,
load_dataset_point_cloud=load_dataset_point_cloud,
)
        # move the point cloud and cameras onto the GPU for rendering
point_cloud = point_cloud.to(device)
cameras = sequence_frame_data.camera.to(device)
        # render the point cloud from the viewpoint of each loaded camera
images_render = torch.cat(
[
self._render_one_pointcloud(
point_cloud,
cameras[frame_i],
(
dataset.image_height,
dataset.image_width,
),
)
for frame_i in range(len(cameras))
]
).cpu()
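        # Place ground-truth frames and renders side by side (concatenated
        # along the width dimension).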
images_gt_and_render = torch.cat(
[sequence_frame_data.image_rgb, images_render], dim=3
)
imfile = os.path.join(
os.path.split(os.path.abspath(__file__))[0],
"test_dataset_visualize"
+ f"_max_frames={max_frames}"
+ f"_load_pcl={load_dataset_point_cloud}.png",
)
print(f"Exporting image {imfile}.")
torchvision.utils.save_image(images_gt_and_render, imfile, nrow=2)
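        # Optionally push the renders and an interactive 3D plot to Visdom.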
if self.visdom is not None:
test_name = f"{max_frames}_{load_dataset_point_cloud}_{dataset_key}"
self.visdom.images(
images_gt_and_render,
env="test_dataset_visualize",
win=f"pcl_renders_{test_name}",
opts={"title": f"pcl_renders_{test_name}"},
)
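            # Show the cameras and the point cloud together in a plotly figure.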
plotlyplot = plot_scene(
{
"scene_batch": {
"cameras": cameras,
"point_cloud": point_cloud,
}
},
camera_scale=1.0,
pointcloud_max_points=10000,
pointcloud_marker_size=1.0,
)
self.visdom.plotlyplot(
plotlyplot,
env="test_dataset_visualize",
win=f"pcl_{test_name}",
)


def _change_annotations_to_new_ndc(dataset):
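    """Return a deep copy of `dataset` with viewpoints converted to the
    "ndc_isotropic" intrinsics format."""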
dataset = copy.deepcopy(dataset)
for frame in dataset.frame_annots:
vp = frame["frame_annotation"].viewpoint
vp.intrinsics_format = "ndc_isotropic"
        # this assumes the focal lengths on x and y are equal (ok for a test)
max_flength = max(vp.focal_length)
vp.principal_point = (
vp.principal_point[0] * max_flength / vp.focal_length[0],
vp.principal_point[1] * max_flength / vp.focal_length[1],
)
vp.focal_length = (
max_flength,
max_flength,
)
return dataset
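

# A sketch of how this test is typically run (not part of the test itself):
# enable interactive testing (see `interactive_testing_requested` in
# tests/common_testing.py for the exact switch) and start a Visdom server on
# VISDOM_PORT before invoking unittest, e.g.:
#
#   python -m visdom.server -port 8097
#   python -m unittest tests.implicitron.test_dataset_visualize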