"""Input reader to load segmentation dataset.""" |
|
|
|
import tensorflow as tf |
|
|
|
_NUM_INPUTS_PROCESSED_CONCURRENTLY = 32 |
|
_SHUFFLE_BUFFER_SIZE = 1000 |
|
|
|
|
|
class InputReader(object):
  """Input function that creates a dataset from files."""

  def __init__(self,
               file_pattern,
               decoder_fn,
               generator_fn=None,
               is_training=False):
    """Initializes the input reader.

    Args:
      file_pattern: The file pattern for the data examples, in TFRecord
        format.
      decoder_fn: A callable that takes a serialized tf.Example and produces
        parsed (and potentially processed / augmented) tensors.
      generator_fn: An optional `callable` that takes the decoded raw tensors
        dict and generates a ground-truth dictionary that can be consumed by
        the model. It will be executed after decoder_fn (default: None).
      is_training: Whether this dataset is used for training (default: False).
    """
    self._file_pattern = file_pattern
    self._is_training = is_training
    self._decoder_fn = decoder_fn
    self._generator_fn = generator_fn

  def __call__(self, batch_size=1, max_num_examples=-1):
    """Provides a tf.data.Dataset object.

    Args:
      batch_size: Expected batch size of the input data.
      max_num_examples: Positive integer or -1. If positive, the returned
        dataset will only take (at most) this number of examples and raise
        tf.errors.OutOfRangeError after that (default: -1).

    Returns:
      A tf.data.Dataset object.
    """
    # Shuffle the file list only during training so that evaluation is
    # deterministic. `list_files` shuffles by default (reshuffling each
    # iteration), so the flag must be set explicitly; shuffling with
    # `dataset.cardinality()` is unreliable here because the cardinality of
    # a `list_files` dataset may be statically unknown.
    dataset = tf.data.Dataset.list_files(self._file_pattern,
                                         shuffle=self._is_training)

    if self._is_training:
      dataset = dataset.repeat()

    # Read records from multiple files in parallel during training; keep a
    # single, deterministic stream for evaluation.
    dataset = dataset.interleave(
        map_func=tf.data.TFRecordDataset,
        cycle_length=(_NUM_INPUTS_PROCESSED_CONCURRENTLY
                      if self._is_training else 1),
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
        deterministic=not self._is_training)

    if self._is_training:
      dataset = dataset.shuffle(_SHUFFLE_BUFFER_SIZE)
    if max_num_examples > 0:
      dataset = dataset.take(max_num_examples)

    # Decode the raw examples and, if a generator_fn is provided, convert
    # them into the ground-truth dictionaries consumed by the model.
    dataset = dataset.map(
        self._decoder_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    if self._generator_fn is not None:
      dataset = dataset.map(
          self._generator_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    return dataset
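

# Example usage (a minimal sketch, not part of the reader itself). The
# decoder below assumes a hypothetical TFRecord schema with a JPEG image
# under 'image/encoded' and a PNG label map under
# 'image/segmentation/class/encoded'; substitute the feature keys used by
# your own dataset. A real decoder_fn would typically also resize or pad to
# a fixed shape so that batching with drop_remainder succeeds.
def _example_decoder_fn(serialized_example):
  """Decodes a serialized tf.Example into image and label tensors."""
  features = tf.io.parse_single_example(
      serialized_example,
      features={
          'image/encoded': tf.io.FixedLenFeature((), tf.string),
          'image/segmentation/class/encoded':
              tf.io.FixedLenFeature((), tf.string),
      })
  image = tf.io.decode_jpeg(features['image/encoded'], channels=3)
  label = tf.io.decode_png(
      features['image/segmentation/class/encoded'], channels=1)
  return {'image': image, 'label': label}


# reader = InputReader(
#     file_pattern='/path/to/train-*.tfrecord',
#     decoder_fn=_example_decoder_fn,
#     is_training=True)
# dataset = reader(batch_size=8)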