import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

import numpy as np
import unittest

Input adapter

This model transforms the input image into a tensor of the desired shape (16x16x40 in our setup).

def create_input_adapter(input_shape, size=16, depth=40, activation=None):
  """Creates an input adapter module for the input image.

  The input adapter transforms an input image of the given shape into a
  tensor of the target shape: a space-to-depth rearrangement followed by a
  1x1 convolution that maps the channels to `depth`.

  Args:
    input_shape: shape of the input image (HxWxC). Image height and width
      must be divisible by `size`. H*W*C is expected to be less than or
      equal to size*size*depth (not enforced here).
    size: height and width of the output tensor after the space2depth
      operation (for square inputs).
    depth: number of channels in the output tensor.
    activation: activation function of the 1x1 conv layer.

  Returns:
    A `keras.Model` named 'in_adapter'.

  Raises:
    ValueError: if the input is smaller than `size`, its height or width is
      not divisible by `size` (or by the derived block size), or `depth` is
      not divisible by the squared block size.
  """
  h, w, _ = input_shape
  if h < size or w < size:
    raise ValueError('Input height and width should be greater than `size`.')
  if h % size != 0 or w % size != 0:
    raise ValueError('Input height and width must be divisible by `size`.')
  # `block_size` of space2depth. Integer division: tf.nn.space_to_depth
  # truncates its block size to int, so a fractional ratio would silently
  # produce an output of the wrong spatial size.
  block_size = min(h // size, w // size)
  # For non-square inputs the smaller ratio must still tile the other side,
  # otherwise space_to_depth fails deep inside TF with a less clear error.
  if h % block_size != 0 or w % block_size != 0:
    raise ValueError('Input height and width must be divisible by the '
                     'computed block size.')
  if depth % (block_size * block_size) != 0:
    raise ValueError('depth value is not devisible by the computed block size')
  # creating an adapter model
  inputs = keras.Input(shape=input_shape)
  s2d = tf.nn.space_to_depth(inputs, block_size)
  outputs = layers.Conv2D(filters=depth,
                          kernel_size=1, activation=activation)(s2d)
  return keras.Model(inputs, outputs, name='in_adapter')

Output adapter

def create_output_adapter(input_shape, block_size=None, pool_stride=None,
                        activation='swish', depthwise=True):
  """Creates an output adapter module that processes tensors before
  passing them to fully connected layers.

  The model average-pools the input, applies a 1x1 depthwise convolution
  when `depthwise` is True, flattens the result and reshapes it to
  (batch, 1, 1, N).

  Args:
    input_shape: shape of the input tensor (HxWxC); must have exactly 3
      entries.
    block_size: tensor height and width after average pooling. A falsy
      value (None or 0) falls back to the default of 4.
    pool_stride: stride of average pooling; None or a positive int.
    activation: activation function (presumably applied after the
      depthwise conv; the surrounding lines are truncated -- confirm).
    depthwise: whether to use depthwise convolution.

  Returns:
    A `keras.Model` named 'out_adapter'.

  Raises:
    ValueError: if `block_size` is not a positive int, `pool_stride` is
      neither None nor a positive int, or `input_shape` does not have
      length 3.
  """
  # Falsy values (None, 0) fall back to the default block size.
  if not block_size: 
    block_size = 4
  if not isinstance(block_size, int) or block_size < 1:
    raise ValueError("block_size must be a positive integer.")

  # NOTE(review): prefer `pool_stride is not None` over `!= None`.
  if pool_stride != None and (not isinstance(pool_stride, int) or
                              pool_stride < 1):
    raise ValueError("pool_stride be a positive integer or None.")
  if len(input_shape) != 3:
    raise ValueError("input_shape must be a tuple of size 3.")

  h, w, _ = input_shape
  inputs = keras.Input(shape=input_shape)
  # Pool window presumably chosen so the pooled map is about
  # block_size x block_size.
  # NOTE(review): tf.round returns float tensors, while Keras pooling layers
  # expect integer pool sizes -- confirm this works, or use
  # (round(h / block_size), round(w / block_size)).
  kernel_size = (tf.round(h / block_size), tf.round(w / block_size))

  # NOTE(review): the AveragePooling2D and DepthwiseConv2D calls below are
  # truncated in this copy (their argument lists are never closed); restore
  # the missing arguments from the original source.
  x = tf.keras.layers.AveragePooling2D(pool_size=kernel_size, 
  if depthwise:
    x = tf.keras.layers.DepthwiseConv2D(kernel_size=1,
    x = tf.keras.layers.Activation(activation)(x)

  # Flatten to (batch, N), then insert two unit axes -> (batch, 1, 1, N).
  x = tf.keras.layers.Flatten(data_format='channels_last')(x)
  outputs = tf.expand_dims(tf.expand_dims(x, axis=1), axis=1)
  model = keras.Model(inputs, outputs, name='out_adapter')
  return model

# Smoke check: run the output adapter on a random batch of 32 inputs.
input_shape = (32, 32, 40)
input_tensor = tf.Variable(np.random.rand(32, *input_shape))
# NOTE(review): this call appears truncated (its argument list is never
# closed); restore the remaining arguments from the original notebook cell.
out_adapter = create_output_adapter(input_shape, block_size=4, 
out_tensor = out_adapter(input_tensor)

tf.Tensor([ 32   1   1 640], shape=(4,), dtype=int32)

Model: "out_adapter"
Layer (type)                 Output Shape              Param #   
input_9 (InputLayer)         [(None, 32, 32, 40)]      0         
average_pooling2d (AveragePo (None, 4, 4, 40)          0         
depthwise_conv2d (DepthwiseC (None, 4, 4, 40)          80        
flatten (Flatten)            (None, 640)               0         
tf_op_layer_ExpandDims_10 (T [(None, 1, 640)]          0         
tf_op_layer_ExpandDims_11 (T [(None, 1, 1, 640)]       0         
Total params: 80
Trainable params: 80
Non-trainable params: 0

I/O adapter tests

class InputAdapterTest(tf.test.TestCase):
  """Tests for `create_input_adapter`."""

  def setUp(self):
    super(InputAdapterTest, self).setUp()
    # Target output size/depth shared by every test in this case.
    self.default_size = 32
    self.default_depth = 64

  def test_output_shape(self):
    """The adapter output has shape (batch, size, size, depth)."""
    input_shape = (64, 64, 3)
    batch_size = 16
    expected_out_shape = (batch_size, self.default_size,
                          self.default_size, self.default_depth)
    adapter = self._create_default_adapter(input_shape)
    out = adapter(np.zeros((batch_size, *input_shape)))
    self.assertShapeEqual(np.zeros(expected_out_shape), out)

  def test_small_in_shape(self):
    """Inputs smaller than `size` are rejected."""
    input_shape = (28, 28, 3)
    with self.assertRaises(ValueError):
      self._create_default_adapter(input_shape)

  def test_non_divisible(self):
    """Inputs whose sides are not a multiple of `size` are rejected."""
    input_shape = (50, 50, 3)
    with self.assertRaises(ValueError):
      self._create_default_adapter(input_shape)

  def _create_default_adapter(self, input_shape):
    """Builds an input adapter with this test case's default size/depth."""
    return create_input_adapter(input_shape,
                                size=self.default_size,
                                depth=self.default_depth)

class OutputAdapterTest(tf.test.TestCase):
  """Shape and argument-validation checks for `create_output_adapter`."""

  def test_out_shape(self):
    """Output is (batch, 1, 1, C * block_size**2) for a square input."""
    input_shape = (32, 32, 40)
    batch = 32
    input_tensor = tf.random.normal([batch, *input_shape])
    block_size = 4
    adapter = create_output_adapter(input_shape, block_size=block_size)
    result = adapter(input_tensor)
    channels = input_shape[-1]
    self.assertAllEqual((batch, 1, 1, channels * block_size ** 2),
                        result.shape)

  def test_bad_block_size(self):
    """A non-integer block size is rejected."""
    with self.assertRaises(ValueError):
      create_output_adapter((32, 32, 40), block_size=3.5)

  def test_bad_pool_stride(self):
    """A non-integer pool stride is rejected."""
    with self.assertRaises(ValueError):
      create_output_adapter((32, 32, 40), pool_stride='3')

  def test_bad_input_shape(self):
    """An input shape without exactly three entries is rejected."""
    with self.assertRaises(ValueError):
      create_output_adapter((32, 32), block_size=4)

if __name__ == '__main__':
  # exit=False keeps the host interpreter (e.g. a notebook kernel) alive after
  # the run; the placeholder argv stops unittest from parsing sys.argv.
  unittest.main(exit=False, argv=['first-arg-is-ignored'])

Ran 9 tests in 0.156s

OK (skipped=2)

