Spaces:
Runtime error
Runtime error
# coding=utf-8 | |
# Copyright 2021 The Deeplab2 Authors. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Tests for build_cityscapes_data.""" | |
import os

from absl import flags
from absl.testing import flagsaver
import numpy as np
from PIL import Image
import tensorflow as tf

from deeplab2.data import build_cityscapes_data
# Process-wide absl flag registry.
FLAGS = flags.FLAGS

# Location of the dummy Cityscapes-style fixtures. This is a relative path, so
# it is resolved against the current working directory when the tests run.
_TEST_DATA_DIR = 'deeplab2/data/testdata'

# Key of the dummy sample in the dict returned by `_read_segments`.
_TEST_FILE_PREFIX = 'dummy_000000_000000'
class BuildCityscapesDataTest(tf.test.TestCase):
  """Tests the Cityscapes panoptic helpers in build_cityscapes_data."""

  def _read_dummy_segments(self):
    """Reads the annotation entry of the dummy test image.

    Returns:
      The two-element `(annotation_file_name, segments)` entry stored under
      `_TEST_FILE_PREFIX` in the dict returned by `_read_segments` for the
      'dummy' split.
    """
    segments_dict = build_cityscapes_data._read_segments(
        _TEST_DATA_DIR, dataset_split='dummy')
    # Fail with a readable assertion instead of a bare KeyError.
    self.assertIn(_TEST_FILE_PREFIX, segments_dict)
    return segments_dict[_TEST_FILE_PREFIX]

  def test_read_segments(self):
    _, segments = self._read_dummy_segments()
    self.assertLen(segments, 10)

  # Test a more complicated setting (crowd regions are not treated as ignore).
  # flagsaver restores the flag when the test finishes, so this override does
  # not leak into other tests that share the global FLAGS registry.
  @flagsaver.flagsaver(treat_crowd_as_ignore=False)
  def test_generate_panoptic_label(self):
    annotation_file_name, segments = self._read_dummy_segments()
    panoptic_annotation_file = build_cityscapes_data._get_panoptic_annotation(
        _TEST_DATA_DIR, dataset_split='dummy',
        annotation_file_name=annotation_file_name)
    panoptic_label = build_cityscapes_data._generate_panoptic_label(
        panoptic_annotation_file, segments)

    # Check panoptic label matches golden file.
    golden_file_path = os.path.join(_TEST_DATA_DIR, 'dummy_gt_for_vps.png')
    with tf.io.gfile.GFile(golden_file_path, 'rb') as f:
      # Decode while the file is still open: PIL may read pixel data lazily.
      # The PNG stores each segmentId across the RGB channels, low byte first:
      #   color = [segmentId % 256,
      #            (segmentId // 256) % 256,
      #            segmentId // 256 // 256]
      # so a dot product with [1, 256, 256 * 256] recovers segmentId.
      golden_label = np.dot(
          np.asarray(Image.open(f)), [1, 256, 256 * 256])

    np.testing.assert_array_equal(panoptic_label, golden_label)
if __name__ == '__main__':
  # Executed as a script: hand control to TensorFlow's test entry point.
  tf.test.main()