In [19]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import unittest
In [20]:
def create_input_adapter(input_shape, size=16, depth=40, activation=None):
    """Creates an input adapter module for the input image.

    The adapter rearranges spatial blocks of the image into channels
    (space-to-depth) and then applies a 1x1 convolution that produces
    `depth` output channels.

    Arguments:
      input_shape: shape of input image (H, W, C). H and W must both be at
        least `size`; the smaller of them must be a multiple of `size`, and
        both must be divisible by the resulting block size. H*W*C is
        expected to be less than or equal to size*size*depth (not checked).
      size: height/width of the output tensor after the space2depth
        operation. The smaller spatial dimension ends up equal to `size`.
      depth: number of channels in the output tensor. Must be divisible by
        the square of the computed block size.
      activation: conv layer activation function.

    Returns:
      A `keras.Model` named 'in_adapter' mapping an image batch to the
      adapted tensor.

    Raises:
      ValueError: if the input is smaller than `size`, if the spatial
        dimensions are not compatible with an integer block size, or if
        `depth` is not divisible by the squared block size.
    """
    h, w, _ = input_shape
    if h < size or w < size:
        raise ValueError('Input height and width should be greater than `size`.')

    # `block_size` of space2depth. It has to be a whole number: a float such
    # as 1.5 is either rejected by the op or silently truncated, so validate
    # explicitly and keep the value as a Python int.
    smaller_side = min(h, w)
    if smaller_side % size != 0:
        raise ValueError(
            'The smaller of input height and width must be divisible by '
            '`size`.')
    block_size = smaller_side // size
    if h % block_size != 0 or w % block_size != 0:
        raise ValueError(
            'Input height and width must be divisible by the computed '
            'block size.')
    if depth % (block_size * block_size) != 0:
        raise ValueError('depth value is not devisible by the computed block size')

    # creating an adapter model
    inputs = keras.Input(shape=input_shape)
    s2d = tf.nn.space_to_depth(inputs, block_size)
    outputs = layers.Conv2D(filters=depth,
                            kernel_size=1,
                            activation=activation)(s2d)
    return keras.Model(inputs, outputs, name='in_adapter')
In [21]:
def create_output_adapter(input_shape, block_size=None, pool_stride=None,
                          activation='swish', depthwise=True):
    """Creates an output adapter module that processes tensors before
    passing them to fully connected layers.

    The adapter average-pools the input, applies either a 1x1 depthwise
    convolution with `activation` or the activation alone, flattens the
    result and reshapes it to (batch, 1, 1, features).

    Arguments:
      input_shape: shape of the input tensor (H, W, C).
      block_size: tensor height and width after average pooling. Default
        value is 4 (also used when a falsy value such as 0 is passed).
      pool_stride: stride of average pooling, or None to let the pooling
        layer use its default.
      activation: activation function.
      depthwise: whether to use depthwise convolution.

    Returns:
      A `keras.Model` named 'out_adapter'.

    Raises:
      ValueError: if `block_size` or `pool_stride` is not a positive
        integer, if `input_shape` does not have three entries, or if the
        input is too small for the requested `block_size`.
    """
    if not block_size:
        block_size = 4
    if not isinstance(block_size, int) or block_size < 1:
        raise ValueError("block_size must be a positive integer.")
    if pool_stride is not None and (not isinstance(pool_stride, int)
                                    or pool_stride < 1):
        raise ValueError("pool_stride be a positive integer or None.")
    if len(input_shape) != 3:
        raise ValueError("input_shape must be a tuple of size 3.")

    h, w, _ = input_shape
    # Pooling window as plain Python ints (round-half-to-even, like tf.round)
    # rather than float tensors, which is what the pooling layer expects.
    pool_size = (int(round(h / block_size)), int(round(w / block_size)))
    if min(pool_size) < 1:
        raise ValueError(
            "input height and width are too small for the given block_size.")

    inputs = keras.Input(shape=input_shape)
    x = tf.keras.layers.AveragePooling2D(pool_size=pool_size,
                                         strides=pool_stride,
                                         padding='valid')(inputs)
    if depthwise:
        x = tf.keras.layers.DepthwiseConv2D(kernel_size=1,
                                            activation=activation)(x)
    else:
        x = tf.keras.layers.Activation(activation)(x)
    x = tf.keras.layers.Flatten(data_format='channels_last')(x)
    # Re-insert two singleton spatial axes: (batch, F) -> (batch, 1, 1, F).
    outputs = tf.expand_dims(tf.expand_dims(x, axis=1), axis=1)
    return keras.Model(inputs, outputs, name='out_adapter')
In [22]:
# Smoke check: push a random batch of 32 feature maps through the output
# adapter and print the shape of what comes out.
input_shape = (32, 32, 40)
input_tensor = tf.Variable(np.random.rand(32, *input_shape))
out_adapter = create_output_adapter(
    input_shape,
    depthwise=True,
    block_size=4,
)
out_tensor = out_adapter(input_tensor)
print(tf.shape(out_tensor))
In [23]:
# Show the layer-by-layer summary of the `out_adapter` model built above.
out_adapter.summary()
In [24]:
class InputAdapterTest(tf.test.TestCase):
    """Unit tests for `create_input_adapter`."""

    def setUp(self):
        super().setUp()
        # Target spatial size and channel count used by every test.
        self.default_size = 32
        self.default_depth = 64

    def _create_default_adapter(self, input_shape):
        """Builds an input adapter with the default size and depth."""
        return create_input_adapter(input_shape,
                                    size=self.default_size,
                                    depth=self.default_depth)

    def test_output_shape(self):
        """The adapter output has shape (batch, size, size, depth)."""
        input_shape = (64, 64, 3)
        batch_size = 16
        expected_out_shape = (batch_size, self.default_size,
                              self.default_size, self.default_depth)
        adapter = self._create_default_adapter(input_shape)
        out = adapter(np.zeros((batch_size, *input_shape)))
        self.assertShapeEqual(np.zeros(expected_out_shape), out)

    def test_small_in_shape(self):
        """Inputs smaller than `size` are rejected."""
        input_shape = (28, 28, 3)
        # ValueError specifically: a broad `Exception` would also pass on
        # unrelated failures such as a TypeError or NameError.
        with self.assertRaises(ValueError):
            self._create_default_adapter(input_shape)

    def test_non_divisible(self):
        """Inputs whose sides are not a multiple of `size` are rejected."""
        input_shape = (50, 50, 3)
        with self.assertRaises(ValueError):
            self._create_default_adapter(input_shape)
In [25]:
class OutputAdapterTest(tf.test.TestCase):
    """Checks output shape and argument validation of `create_output_adapter`."""

    def setUp(self):
        super().setUp()

    def test_out_shape(self):
        """Output is (batch, 1, 1, C * block_size**2) for a 32x32x40 input."""
        feature_shape = (32, 32, 40)
        n_samples = 32
        features = tf.random.normal([n_samples, *feature_shape])
        grid = 4
        adapter = create_output_adapter(feature_shape, block_size=grid)
        result = adapter(features)
        flat_channels = feature_shape[2] * grid * grid
        self.assertAllEqual((n_samples, 1, 1, flat_channels), result.shape)

    def test_bad_block_size(self):
        """A non-integer block_size raises ValueError."""
        with self.assertRaises(ValueError):
            create_output_adapter((32, 32, 40), block_size=3.5)

    def test_bad_pool_stride(self):
        """A non-integer pool_stride raises ValueError."""
        with self.assertRaises(ValueError):
            create_output_adapter((32, 32, 40), pool_stride='3')

    def test_bad_input_shape(self):
        """An input_shape without three entries raises ValueError."""
        with self.assertRaises(ValueError):
            create_output_adapter((32, 32), block_size=4)
In [26]:
# Run the test cases defined in this notebook. An explicit `argv` is passed
# instead of relying on the process arguments, and `exit=False` keeps
# unittest from calling sys.exit() once the run finishes.
if __name__ == '__main__':
    unittest.main(
        exit=False,
        argv=['first-arg-is-ignored'],
    )
In [ ]: