# coding=utf-8
# Copyright 2021 The Deeplab2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains common utility functions and classes for building dataset."""

import collections
import io

import numpy as np
from PIL import Image
from PIL import ImageOps
import tensorflow as tf

from deeplab2 import common

_PANOPTIC_LABEL_FORMAT = 'raw'
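
# Note: labels in the 'raw' format are stored as the bytes of a flattened
# int32 array; create_features() below checks the byte length accordingly.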


def read_image(image_data):
  """Decodes image from in-memory data.

  Args:
    image_data: Bytes data representing encoded image.

  Returns:
    Decoded PIL.Image object.
  """
  image = Image.open(io.BytesIO(image_data))
  try:
    image = ImageOps.exif_transpose(image)
  except TypeError:
    # capture and ignore this bug:
    # https://github.com/python-pillow/Pillow/issues/3973
    pass
  return image


def get_image_dims(image_data, check_is_rgb=False):
  """Decodes image and returns its height and width.

  Args:
    image_data: Bytes data representing encoded image.
    check_is_rgb: Whether to check encoded image is RGB.

  Returns:
    Decoded image size as a tuple of (height, width).

  Raises:
    ValueError: If check_is_rgb is set and input image has other format.
  """
  image = read_image(image_data)
  if check_is_rgb and image.mode != 'RGB':
    raise ValueError('Expects RGB image data, gets mode: %s' % image.mode)
  width, height = image.size
  return height, width


def _int64_list_feature(values):
  """Returns a TF-Feature of int64_list.

  Args:
    values: A scalar or an iterable of integer values.

  Returns:
    A TF-Feature.
  """
  # collections.Iterable was removed in Python 3.10; use collections.abc.
  if not isinstance(values, collections.abc.Iterable):
    values = [values]
  return tf.train.Feature(int64_list=tf.train.Int64List(value=values))


def _bytes_list_feature(values):
  """Returns a TF-Feature of bytes.

  Args:
    values: A string or bytes.

  Returns:
    A TF-Feature.
  """
  if isinstance(values, str):
    values = values.encode()
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))


def create_features(image_data,
                    image_format,
                    filename,
                    label_data=None,
                    label_format=None):
  """Creates image/segmentation features.

  Args:
    image_data: String or byte stream of encoded image data.
    image_format: String, image data format, should be either 'jpeg' or 'png'.
    filename: String, image filename.
    label_data: String or byte stream of (potentially) encoded label data. If
      None, it is not written to the tf.train.Example.
    label_format: String, label data format, should be either 'png' or 'raw'.
      If None, it is not written to the tf.train.Example.

  Returns:
    A dictionary of feature name to tf.train.Feature mapping.

  Raises:
    ValueError: If the image or label format is unsupported, or if image and
      label shapes mismatch.
  """
  if image_format not in ('jpeg', 'png'):
    raise ValueError('Unsupported image format: %s' % image_format)

  # Check color mode, and convert grayscale image to RGB.
  image = read_image(image_data)
  if image.mode != 'RGB':
    image = image.convert('RGB')
    image_data = io.BytesIO()
    image.save(image_data, format=image_format)
    image_data = image_data.getvalue()

  height, width = get_image_dims(image_data, check_is_rgb=True)

  feature_dict = {
      common.KEY_ENCODED_IMAGE: _bytes_list_feature(image_data),
      common.KEY_IMAGE_FILENAME: _bytes_list_feature(filename),
      common.KEY_IMAGE_FORMAT: _bytes_list_feature(image_format),
      common.KEY_IMAGE_HEIGHT: _int64_list_feature(height),
      common.KEY_IMAGE_WIDTH: _int64_list_feature(width),
      common.KEY_IMAGE_CHANNELS: _int64_list_feature(3),
  }

  if label_data is None:
    return feature_dict

  if label_format == 'png':
    label_height, label_width = get_image_dims(label_data)
    if (label_height, label_width) != (height, width):
      raise ValueError('Image (%s) and label (%s) shape mismatch' %
                       ((height, width), (label_height, label_width)))
  elif label_format == 'raw':
    # Raw label encodes int32 array.
    expected_label_size = height * width * np.dtype(np.int32).itemsize
    if len(label_data) != expected_label_size:
      raise ValueError('Expects raw label data length %d, gets %d' %
                       (expected_label_size, len(label_data)))
  else:
    raise ValueError('Unsupported label format: %s' % label_format)

  feature_dict.update({
      common.KEY_ENCODED_LABEL: _bytes_list_feature(label_data),
      common.KEY_LABEL_FORMAT: _bytes_list_feature(label_format)
  })
  return feature_dict
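
# Example: encoding a panoptic label in the 'raw' format checked above (a
# minimal sketch; `panoptic_array` is a hypothetical HxW int32 numpy array):
#
#   label_data = panoptic_array.astype(np.int32).tobytes()
#   features = create_features(image_data, 'png', 'frame_0001.png',
#                              label_data=label_data, label_format='raw')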


def create_tfexample(image_data,
                     image_format,
                     filename,
                     label_data=None,
                     label_format=None):
  """Converts one image/segmentation pair to TF example.

  Args:
    image_data: String or byte stream of encoded image data.
    image_format: String, image data format, should be either 'jpeg' or 'png'.
    filename: String, image filename.
    label_data: String or byte stream of (potentially) encoded label data. If
      None, it is not written to the tf.train.Example.
    label_format: String, label data format, should be either 'png' or 'raw'.
      If None, it is not written to the tf.train.Example.

  Returns:
    TF example proto.
  """
  feature_dict = create_features(image_data, image_format, filename,
                                 label_data, label_format)
  return tf.train.Example(features=tf.train.Features(feature=feature_dict))
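
# Example: serializing one example to a TFRecord file (a minimal sketch; the
# output path is hypothetical):
#
#   example = create_tfexample(image_data, 'png', 'frame_0001.png',
#                              label_data=label_data, label_format='raw')
#   with tf.io.TFRecordWriter('/tmp/dataset.tfrecord') as writer:
#     writer.write(example.SerializeToString())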


def create_video_tfexample(image_data,
                           image_format,
                           filename,
                           sequence_id,
                           image_id,
                           label_data=None,
                           label_format=None,
                           prev_image_data=None,
                           prev_label_data=None):
  """Converts one video frame/panoptic segmentation pair to TF example.

  Args:
    image_data: String or byte stream of encoded image data.
    image_format: String, image data format, should be either 'jpeg' or 'png'.
    filename: String, image filename.
    sequence_id: ID of the video sequence as a string.
    image_id: ID of the image as a string.
    label_data: String or byte stream of (potentially) encoded label data. If
      None, it is not written to the tf.train.Example.
    label_format: String, label data format, should be either 'png' or 'raw'.
      If None, it is not written to the tf.train.Example.
    prev_image_data: An optional string or byte stream of encoded previous
      image data.
    prev_label_data: An optional string or byte stream of (potentially)
      encoded previous label data.

  Returns:
    TF example proto.
  """
  feature_dict = create_features(image_data, image_format, filename,
                                 label_data, label_format)
  feature_dict.update({
      common.KEY_SEQUENCE_ID: _bytes_list_feature(sequence_id),
      common.KEY_FRAME_ID: _bytes_list_feature(image_id)
  })
  if prev_image_data is not None:
    feature_dict[common.KEY_ENCODED_PREV_IMAGE] = _bytes_list_feature(
        prev_image_data)
  if prev_label_data is not None:
    feature_dict[common.KEY_ENCODED_PREV_LABEL] = _bytes_list_feature(
        prev_label_data)
  return tf.train.Example(features=tf.train.Features(feature=feature_dict))


def create_video_and_depth_tfexample(image_data,
                                     image_format,
                                     filename,
                                     sequence_id,
                                     image_id,
                                     label_data=None,
                                     label_format=None,
                                     next_image_data=None,
                                     next_label_data=None,
                                     depth_data=None,
                                     depth_format=None):
  """Converts an image/segmentation pair and depth of first frame to TF example.

  The image pair contains the current frame and the next frame; the depth
  label belongs to the current frame.

  Args:
    image_data: String or byte stream of encoded image data.
    image_format: String, image data format, should be either 'jpeg' or 'png'.
    filename: String, image filename.
    sequence_id: ID of the video sequence as a string.
    image_id: ID of the image as a string.
    label_data: String or byte stream of (potentially) encoded label data. If
      None, it is not written to the tf.train.Example.
    label_format: String, label data format, should be either 'png' or 'raw'.
      If None, it is not written to the tf.train.Example.
    next_image_data: An optional string or byte stream of encoded next image
      data.
    next_label_data: An optional string or byte stream of (potentially) encoded
      next label data.
    depth_data: An optional string or byte stream of encoded depth data.
    depth_format: String, depth data format, should be either 'png' or 'raw'.

  Returns:
    TF example proto.
  """
  feature_dict = create_features(image_data, image_format, filename,
                                 label_data, label_format)
  feature_dict.update({
      common.KEY_SEQUENCE_ID: _bytes_list_feature(sequence_id),
      common.KEY_FRAME_ID: _bytes_list_feature(image_id)
  })
  if next_image_data is not None:
    feature_dict[common.KEY_ENCODED_NEXT_IMAGE] = _bytes_list_feature(
        next_image_data)
  if next_label_data is not None:
    feature_dict[common.KEY_ENCODED_NEXT_LABEL] = _bytes_list_feature(
        next_label_data)
  if depth_data is not None:
    feature_dict[common.KEY_ENCODED_DEPTH] = _bytes_list_feature(depth_data)
    feature_dict[common.KEY_DEPTH_FORMAT] = _bytes_list_feature(depth_format)
  return tf.train.Example(features=tf.train.Features(feature=feature_dict))
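
# Example: building a video example with a depth map (a minimal sketch;
# `depth_png_bytes` and the IDs are hypothetical):
#
#   example = create_video_and_depth_tfexample(
#       image_data, 'png', 'frame_0001.png',
#       sequence_id='0001', image_id='000100',
#       depth_data=depth_png_bytes, depth_format='png')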


class SegmentationDecoder(object):
  """Basic parser to decode serialized tf.Example."""

  def __init__(self,
               is_panoptic_dataset=True,
               is_video_dataset=False,
               use_two_frames=False,
               use_next_frame=False,
               decode_groundtruth_label=True):
    self._is_panoptic_dataset = is_panoptic_dataset
    self._is_video_dataset = is_video_dataset
    self._use_two_frames = use_two_frames
    self._use_next_frame = use_next_frame
    self._decode_groundtruth_label = decode_groundtruth_label
    string_feature = tf.io.FixedLenFeature((), tf.string)
    int_feature = tf.io.FixedLenFeature((), tf.int64)
    self._keys_to_features = {
        common.KEY_ENCODED_IMAGE: string_feature,
        common.KEY_IMAGE_FILENAME: string_feature,
        common.KEY_IMAGE_FORMAT: string_feature,
        common.KEY_IMAGE_HEIGHT: int_feature,
        common.KEY_IMAGE_WIDTH: int_feature,
        common.KEY_IMAGE_CHANNELS: int_feature,
    }
    if decode_groundtruth_label:
      self._keys_to_features[common.KEY_ENCODED_LABEL] = string_feature
    if self._is_video_dataset:
      self._keys_to_features[common.KEY_SEQUENCE_ID] = string_feature
      self._keys_to_features[common.KEY_FRAME_ID] = string_feature
    # Two-frame specific processing.
    if self._use_two_frames:
      self._keys_to_features[common.KEY_ENCODED_PREV_IMAGE] = string_feature
      if decode_groundtruth_label:
        self._keys_to_features[common.KEY_ENCODED_PREV_LABEL] = string_feature
    # Next-frame specific processing.
    if self._use_next_frame:
      self._keys_to_features[common.KEY_ENCODED_NEXT_IMAGE] = string_feature
      if decode_groundtruth_label:
        self._keys_to_features[common.KEY_ENCODED_NEXT_LABEL] = string_feature

  def _decode_image(self, parsed_tensors, key):
    """Decodes image under key from parsed tensors."""
    image = tf.io.decode_image(
        parsed_tensors[key],
        channels=3,
        dtype=tf.dtypes.uint8,
        expand_animations=False)
    image.set_shape([None, None, 3])
    return image

  def _decode_label(self, parsed_tensors, label_key):
    """Decodes segmentation label under label_key from parsed tensors."""
    if self._is_panoptic_dataset:
      flattened_label = tf.io.decode_raw(
          parsed_tensors[label_key], out_type=tf.int32)
      label_shape = tf.stack([
          parsed_tensors[common.KEY_IMAGE_HEIGHT],
          parsed_tensors[common.KEY_IMAGE_WIDTH], 1
      ])
      label = tf.reshape(flattened_label, label_shape)
      return label

    label = tf.io.decode_image(parsed_tensors[label_key], channels=1)
    label.set_shape([None, None, 1])
    return label

  def __call__(self, serialized_example):
    parsed_tensors = tf.io.parse_single_example(
        serialized_example, features=self._keys_to_features)
    return_dict = {
        'image':
            self._decode_image(parsed_tensors, common.KEY_ENCODED_IMAGE),
        'image_name':
            parsed_tensors[common.KEY_IMAGE_FILENAME],
        'height':
            tf.cast(parsed_tensors[common.KEY_IMAGE_HEIGHT], dtype=tf.int32),
        'width':
            tf.cast(parsed_tensors[common.KEY_IMAGE_WIDTH], dtype=tf.int32),
    }
    return_dict['label'] = None
    if self._decode_groundtruth_label:
      return_dict['label'] = self._decode_label(parsed_tensors,
                                                common.KEY_ENCODED_LABEL)
    if self._is_video_dataset:
      return_dict['sequence'] = parsed_tensors[common.KEY_SEQUENCE_ID]
    if self._use_two_frames:
      return_dict['prev_image'] = self._decode_image(
          parsed_tensors, common.KEY_ENCODED_PREV_IMAGE)
      if self._decode_groundtruth_label:
        return_dict['prev_label'] = self._decode_label(
            parsed_tensors, common.KEY_ENCODED_PREV_LABEL)
    if self._use_next_frame:
      return_dict['next_image'] = self._decode_image(
          parsed_tensors, common.KEY_ENCODED_NEXT_IMAGE)
      if self._decode_groundtruth_label:
        return_dict['next_label'] = self._decode_label(
            parsed_tensors, common.KEY_ENCODED_NEXT_LABEL)
    return return_dict
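
# Example: decoding a serialized dataset with SegmentationDecoder (a minimal
# sketch; the TFRecord path is hypothetical):
#
#   decoder = SegmentationDecoder(is_panoptic_dataset=True)
#   dataset = tf.data.TFRecordDataset('/tmp/dataset.tfrecord').map(decoder)
#   for parsed in dataset.take(1):
#     print(parsed['image'].shape, parsed['height'], parsed['width'])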