# coding=utf-8 | |
# Copyright 2021 The Deeplab2 Authors. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Tests for sample_generator.""" | |
import os | |
from absl import flags | |
import numpy as np | |
from PIL import Image | |
import tensorflow as tf | |
from deeplab2 import common | |
from deeplab2.data import data_utils | |
from deeplab2.data import dataset | |
from deeplab2.data import sample_generator | |
# Alias for the Keras image preprocessing helpers; used below to convert a
# PIL image into a numpy array (image_utils.img_to_array).
image_utils = tf.keras.preprocessing.image

flags.DEFINE_string(
    'panoptic_annotation_data',
    'deeplab2/data/testdata/',
    'Path to annotated test image.')
flags.DEFINE_bool('update_golden_data', False,
                  'Whether or not to update the golden data for testing.')

FLAGS = flags.FLAGS

# Filename prefix shared by the dummy test image and its annotations.
_FILENAME_PREFIX = 'dummy_000000_000000'
# Sub-folder (under the test data dir) holding the input RGB image.
_IMAGE_FOLDER = 'leftImg8bit/'
# Sub-folder holding the golden target files the tests compare against.
_TARGET_FOLDER = 'targets/'
def _get_groundtruth_image(computed_image_array, groundtruth_image_filename):
  """Returns the golden image, optionally regenerating it first.

  If --update_golden_data is set, `computed_image_array` is saved as the new
  golden PNG and returned unchanged. Otherwise the golden PNG is loaded from
  `groundtruth_image_filename` and returned as a squeezed numpy array.

  Args:
    computed_image_array: Image tensor produced by the code under test.
    groundtruth_image_filename: Path of the golden PNG file.

  Returns:
    The golden image data to compare against.
  """
  if not FLAGS.update_golden_data:
    with tf.io.gfile.GFile(groundtruth_image_filename, mode='rb') as fp:
      loaded = data_utils.read_image(fp.read())
    # A 3-channel image comes back as [height, width, 3]; a single-channel
    # image is squeezed down to [height, width].
    return np.squeeze(image_utils.img_to_array(loaded))
  pil_image = Image.fromarray(tf.squeeze(computed_image_array).numpy())
  with tf.io.gfile.GFile(groundtruth_image_filename, mode='wb') as fp:
    pil_image.save(fp)
  return computed_image_array
def _get_groundtruth_array(computed_image_array, groundtruth_image_filename):
  """Returns the golden numpy array, optionally regenerating it first.

  If --update_golden_data is set, `computed_image_array` is saved as the new
  golden .npy file and returned unchanged. Otherwise the golden array is
  loaded from `groundtruth_image_filename` and returned squeezed.

  Args:
    computed_image_array: Array produced by the code under test.
    groundtruth_image_filename: Path of the golden .npy file.

  Returns:
    The golden array data to compare against.
  """
  if not FLAGS.update_golden_data:
    with tf.io.gfile.GFile(groundtruth_image_filename, mode='rb') as fp:
      # C>1 channels load as [height, width, C]; a single channel is
      # squeezed down to [height, width].
      return np.squeeze(np.load(fp))
  with tf.io.gfile.GFile(groundtruth_image_filename, mode='wb') as fp:
    np.save(fp, computed_image_array)
  return computed_image_array
class PanopticSampleGeneratorTest(tf.test.TestCase):
  """Golden-data tests for sample_generator.PanopticSampleGenerator."""

  def setUp(self):
    """Loads the dummy RGB image and panoptic label used by both tests."""
    super().setUp()
    self._test_img_data_dir = os.path.join(
        FLAGS.test_srcdir,
        FLAGS.panoptic_annotation_data,
        _IMAGE_FOLDER)
    self._test_gt_data_dir = os.path.join(
        FLAGS.test_srcdir,
        FLAGS.panoptic_annotation_data)
    self._test_target_data_dir = os.path.join(
        FLAGS.test_srcdir,
        FLAGS.panoptic_annotation_data,
        _TARGET_FOLDER)
    image_path = self._test_img_data_dir + _FILENAME_PREFIX + '_leftImg8bit.png'
    with tf.io.gfile.GFile(image_path, 'rb') as image_file:
      rgb_image = data_utils.read_image(image_file.read())
    self._rgb_image = tf.convert_to_tensor(np.array(rgb_image))
    label_path = self._test_gt_data_dir + 'dummy_gt_for_vps.png'
    with tf.io.gfile.GFile(label_path, 'rb') as label_file:
      label = data_utils.read_image(label_file.read())
    # Combine the label PNG's three channels into a single panoptic ID per
    # pixel (channel0 + channel1 * 256 + channel2 * 256^2) and add a trailing
    # channel dimension, giving a [height, width, 1] label tensor.
    self._label = tf.expand_dims(tf.convert_to_tensor(
        np.dot(np.array(label), [1, 256, 256 * 256])), -1)

  def test_input_generator(self):
    """Checks training-mode sample generation against golden data."""
    # Fix both TF and numpy seeds so the random training augmentations are
    # deterministic and match the stored golden targets.
    tf.random.set_seed(0)
    np.random.seed(0)
    # Instances smaller than `threshold` pixels get loss weight `weight`.
    small_instances = {'threshold': 4096, 'weight': 3.0}
    generator = sample_generator.PanopticSampleGenerator(
        dataset.CITYSCAPES_PANOPTIC_INFORMATION._asdict(),
        focus_small_instances=small_instances,
        is_training=True,
        crop_size=[769, 769],
        thing_id_mask_annotations=True)
    input_sample = {
        'image': self._rgb_image,
        'image_name': 'test_image',
        'label': self._label,
        'height': 800,
        'width': 800
    }
    sample = generator(input_sample)
    # The training sample must contain all ground-truth and loss-weight keys.
    self.assertIn(common.IMAGE, sample)
    self.assertIn(common.GT_SEMANTIC_KEY, sample)
    self.assertIn(common.GT_PANOPTIC_KEY, sample)
    self.assertIn(common.GT_INSTANCE_CENTER_KEY, sample)
    self.assertIn(common.GT_INSTANCE_REGRESSION_KEY, sample)
    self.assertIn(common.GT_IS_CROWD, sample)
    self.assertIn(common.GT_THING_ID_MASK_KEY, sample)
    self.assertIn(common.GT_THING_ID_CLASS_KEY, sample)
    self.assertIn(common.SEMANTIC_LOSS_WEIGHT_KEY, sample)
    self.assertIn(common.CENTER_LOSS_WEIGHT_KEY, sample)
    self.assertIn(common.REGRESSION_LOSS_WEIGHT_KEY, sample)
    # All spatial outputs are cropped to the requested 769x769 size.
    self.assertListEqual(sample[common.IMAGE].shape.as_list(), [769, 769, 3])
    self.assertListEqual(sample[common.GT_SEMANTIC_KEY].shape.as_list(),
                         [769, 769])
    self.assertListEqual(sample[common.GT_PANOPTIC_KEY].shape.as_list(),
                         [769, 769])
    self.assertListEqual(sample[common.GT_INSTANCE_CENTER_KEY].shape.as_list(),
                         [769, 769])
    # Offset regression has two channels (y and x offsets — TODO confirm
    # channel order against sample_generator).
    self.assertListEqual(
        sample[common.GT_INSTANCE_REGRESSION_KEY].shape.as_list(),
        [769, 769, 2])
    self.assertListEqual(sample[common.GT_IS_CROWD].shape.as_list(), [769, 769])
    self.assertListEqual(sample[common.GT_THING_ID_MASK_KEY].shape.as_list(),
                         [769, 769])
    # Thing ID classes are padded to a fixed length of 128 entries.
    self.assertListEqual(sample[common.GT_THING_ID_CLASS_KEY].shape.as_list(),
                         [128])
    self.assertListEqual(
        sample[common.SEMANTIC_LOSS_WEIGHT_KEY].shape.as_list(), [769, 769])
    self.assertListEqual(sample[common.CENTER_LOSS_WEIGHT_KEY].shape.as_list(),
                         [769, 769])
    self.assertListEqual(
        sample[common.REGRESSION_LOSS_WEIGHT_KEY].shape.as_list(),
        [769, 769])
    gt_sem = sample[common.GT_SEMANTIC_KEY]
    gt_pan = sample[common.GT_PANOPTIC_KEY]
    # Scale float targets/weights into uint8 range so they can be stored and
    # compared as PNG golden images.
    gt_center = tf.cast(sample[common.GT_INSTANCE_CENTER_KEY] * 255, tf.uint8)
    gt_is_crowd = sample[common.GT_IS_CROWD]
    gt_thing_id_mask = sample[common.GT_THING_ID_MASK_KEY]
    gt_thing_id_class = sample[common.GT_THING_ID_CLASS_KEY]
    image = tf.cast(sample[common.IMAGE], tf.uint8)
    # semantic weights can be in range of [0, 3] in this example.
    semantic_weights = tf.cast(sample[common.SEMANTIC_LOSS_WEIGHT_KEY] * 85,
                               tf.uint8)
    center_weights = tf.cast(sample[common.CENTER_LOSS_WEIGHT_KEY] * 255,
                             tf.uint8)
    offset_weights = tf.cast(sample[common.REGRESSION_LOSS_WEIGHT_KEY] * 255,
                             tf.uint8)
    np.testing.assert_almost_equal(
        image.numpy(),
        _get_groundtruth_image(
            image,
            self._test_target_data_dir + 'rgb_target.png'))
    np.testing.assert_almost_equal(
        gt_sem.numpy(),
        _get_groundtruth_image(
            gt_sem,
            self._test_target_data_dir + 'semantic_target.png'))
    # Save gt as png. Pillow is currently unable to correctly save the image as
    # 32bit, but uses 16bit which overflows. The comparison below therefore
    # uses the .npy golden file instead.
    _ = _get_groundtruth_image(
        gt_pan, self._test_target_data_dir + 'panoptic_target.png')
    np.testing.assert_almost_equal(
        gt_pan.numpy(),
        _get_groundtruth_array(
            gt_pan,
            self._test_target_data_dir + 'panoptic_target.npy'))
    np.testing.assert_almost_equal(
        gt_thing_id_mask.numpy(),
        _get_groundtruth_array(
            gt_thing_id_mask,
            self._test_target_data_dir + 'thing_id_mask_target.npy'))
    np.testing.assert_almost_equal(
        gt_thing_id_class.numpy(),
        _get_groundtruth_array(
            gt_thing_id_class,
            self._test_target_data_dir + 'thing_id_class_target.npy'))
    np.testing.assert_almost_equal(
        gt_center.numpy(),
        _get_groundtruth_image(
            gt_center,
            self._test_target_data_dir + 'center_target.png'))
    np.testing.assert_almost_equal(
        sample[common.GT_INSTANCE_REGRESSION_KEY].numpy(),
        _get_groundtruth_array(
            sample[common.GT_INSTANCE_REGRESSION_KEY].numpy(),
            self._test_target_data_dir + 'offset_target.npy'))
    np.testing.assert_array_equal(
        gt_is_crowd.numpy(),
        _get_groundtruth_array(gt_is_crowd.numpy(),
                               self._test_target_data_dir + 'is_crowd.npy'))
    np.testing.assert_almost_equal(
        semantic_weights.numpy(),
        _get_groundtruth_image(
            semantic_weights,
            self._test_target_data_dir + 'semantic_weights.png'))
    np.testing.assert_almost_equal(
        center_weights.numpy(),
        _get_groundtruth_image(
            center_weights,
            self._test_target_data_dir + 'center_weights.png'))
    np.testing.assert_almost_equal(
        offset_weights.numpy(),
        _get_groundtruth_image(
            offset_weights,
            self._test_target_data_dir + 'offset_weights.png'))

  def test_input_generator_eval(self):
    """Checks eval-mode sample generation against golden data."""
    # Seeds are fixed for parity with the training test; eval mode is expected
    # to be deterministic regardless.
    tf.random.set_seed(0)
    np.random.seed(0)
    small_instances = {'threshold': 4096, 'weight': 3.0}
    generator = sample_generator.PanopticSampleGenerator(
        dataset.CITYSCAPES_PANOPTIC_INFORMATION._asdict(),
        focus_small_instances=small_instances,
        is_training=False,
        crop_size=[800, 800])
    input_sample = {
        'image': self._rgb_image,
        'image_name': 'test_image',
        'label': self._label,
        'height': 800,
        'width': 800
    }
    sample = generator(input_sample)
    # Eval mode exposes the raw (uncropped, unaugmented) ground truths.
    self.assertIn(common.GT_SEMANTIC_RAW, sample)
    self.assertIn(common.GT_PANOPTIC_RAW, sample)
    self.assertIn(common.GT_IS_CROWD_RAW, sample)
    gt_sem_raw = sample[common.GT_SEMANTIC_RAW]
    gt_pan_raw = sample[common.GT_PANOPTIC_RAW]
    gt_is_crowd_raw = sample[common.GT_IS_CROWD_RAW]
    # Raw targets keep the full 800x800 input resolution.
    self.assertListEqual(gt_sem_raw.shape.as_list(), [800, 800])
    self.assertListEqual(gt_pan_raw.shape.as_list(), [800, 800])
    self.assertListEqual(gt_is_crowd_raw.shape.as_list(), [800, 800])
    np.testing.assert_almost_equal(
        gt_sem_raw.numpy(),
        _get_groundtruth_image(
            gt_sem_raw,
            self._test_target_data_dir + 'eval_semantic_target.png'))
    np.testing.assert_almost_equal(
        gt_pan_raw.numpy(),
        _get_groundtruth_array(
            gt_pan_raw,
            self._test_target_data_dir + 'eval_panoptic_target.npy'))
    np.testing.assert_almost_equal(
        gt_is_crowd_raw.numpy(),
        _get_groundtruth_array(gt_is_crowd_raw, self._test_target_data_dir +
                               'eval_is_crowd.npy'))
# Standard TensorFlow test entry point.
if __name__ == '__main__':
  tf.test.main()