|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Provide utility functions to write simple tests.""" |
|
import functools |
|
|
|
import numpy as np |
|
import tensorflow as tf |
|
|
|
|
|
# Keras batch-normalization layer classes (the synchronized and the regular
# variant), grouped in a tuple.
# NOTE(review): presumably intended for `isinstance` checks against model
# layers -- confirm against callers.
NORMALIZATION_LAYERS = (
    tf.keras.layers.experimental.SyncBatchNormalization,
    tf.keras.layers.BatchNormalization,
)
|
|
|
|
|
def create_strategy():
  """Builds a distribution strategy matching the devices that are visible.

  Logical TPU devices take precedence over logical GPU devices; the CPU is
  used when neither kind is listed.

  Does NOT work with local_multiworker_tpu_test tests!

  Returns:
    A `tf.distribute.TPUStrategy` when at least one logical TPU is listed,
    otherwise a `tf.distribute.OneDeviceStrategy` placed on '/gpu:0' (if a
    logical GPU is listed) or on '/cpu:0'.
  """
  tpu_devices = tf.config.list_logical_devices(device_type='TPU')
  gpu_devices = tf.config.list_logical_devices(device_type='GPU')

  if not tpu_devices:
    # Single-device fallback: prefer the first GPU, else run on the CPU.
    device_name = '/gpu:0' if gpu_devices else '/cpu:0'
    return tf.distribute.OneDeviceStrategy(device_name)

  # Connect to the cluster and initialize the TPU system before the strategy
  # object is constructed.
  cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver('')
  tf.config.experimental_connect_to_cluster(cluster_resolver)
  tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
  return tf.distribute.TPUStrategy(cluster_resolver)
|
|
|
|
|
def test_all_strategies(func):
  """Decorator that hands a distribution strategy to a test method.

  The wrapped method is invoked as `func(self, strategy)`, where `strategy` is
  the value returned by `create_strategy()` at the time the test runs. Each
  run therefore exercises the single strategy (TPU, GPU or CPU) that matches
  the devices available.

  Args:
    func: Test method taking `(self, strategy)`.

  Returns:
    A method taking only `self` that forwards to `func` and returns its
    result.
  """

  @functools.wraps(func)
  def run_with_strategy(self):
    return func(self, create_strategy())

  return run_with_strategy
|
|
|
|
|
def create_test_input(batch, height, width, channels):
  """Creates a deterministic float32 test input tensor.

  The element at index [b, h, w, c] equals h + w; the same [height, width]
  pattern is repeated for every batch entry and every channel.

  Args:
    batch: Number of repetitions along the first axis.
    height: Size of the second axis.
    width: Size of the third axis.
    channels: Number of repetitions along the last axis.

  Returns:
    A `tf.float32` tensor of shape [batch, height, width, channels].
  """
  row_offsets = np.arange(height).reshape(height, 1)
  col_offsets = np.arange(width).reshape(1, width)
  # Broadcasting a column vector against a row vector yields the
  # [height, width] grid of h + w values; singleton batch and channel axes are
  # then added so the grid can be tiled.
  pattern = (row_offsets + col_offsets).reshape(1, height, width, 1)
  tiled = np.tile(pattern, [batch, 1, 1, channels])
  return tf.convert_to_tensor(tiled, dtype=tf.float32)
|
|