In [5]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import datetime as dt
import numpy as np
import matplotlib.pyplot as plt
import unittest
In [2]:
# Show the GPUs TensorFlow can see; an empty list means no GPU was detected.
print(tf.config.list_physical_devices(device_type='GPU'))
This code is based on the TensorFlow tutorial.
In [9]:
class ResBlock(tf.keras.Model):
    """A residual block with expansion, depthwise convolution and projection.

    In this ResBlock, standard 2D convolutions are replaced by:
      1. a 1x1 convolution that expands the input tensor along the channel
         dimension,
      2. a depthwise convolution, and
      3. a 1x1 convolution that projects the tensor back to the original
         number of channels.
    Each convolution is followed by batch normalization; the activation is
    applied after the first two only. The block input is added to the result
    (identity shortcut), so the output has the same shape as the input.

    Args:
        kernel_size: size of the depthwise convolution kernel.
        expansion_factor: expansion factor of the first 1x1 convolution,
            e.g. if the input tensor has N channels, the first 1x1
            convolution layer expands it to expansion_factor*N channels.
        activation: activation function name. Supported: 'relu', 'relu6',
            'lrelu', 'swish'.

    Raises:
        ValueError: if ``expansion_factor`` is smaller than one, or if
            ``activation`` is not one of the supported names.
    """

    def __init__(self, kernel_size=3, expansion_factor=6, activation='relu'):
        super(ResBlock, self).__init__(name='')
        if expansion_factor < 1:
            raise ValueError('The expansion factor value should be '
                             'greater than or equal to one.')
        self.expansion_factor = expansion_factor
        self.activation = self.set_activation_fn(activation)
        self.kernel_size = kernel_size

    def build(self, input_shape):
        """Creates the sub-layers once the number of input channels is known.

        Args:
            input_shape: shape of the block input; its last entry is the
                number of channels and must be defined.

        Raises:
            ValueError: if the channel dimension is undefined (None).
        """
        input_channel = self._get_input_channel(input_shape)
        self.expanded_channel = input_channel * self.expansion_factor
        self.conv1 = tf.keras.layers.Conv2D(self.expanded_channel, kernel_size=1,
                                            strides=(1, 1), padding='same')
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.conv2 = tf.keras.layers.DepthwiseConv2D(kernel_size=self.kernel_size,
                                                     strides=(1, 1), padding='same')
        self.bn2 = tf.keras.layers.BatchNormalization()
        # Project back to the input channel count so the shortcut add in
        # call() has matching shapes.
        self.conv3 = tf.keras.layers.Conv2D(input_channel, kernel_size=1,
                                            strides=(1, 1), padding='same')
        self.bn3 = tf.keras.layers.BatchNormalization()

    def call(self, input_tensor, training=True):
        """Applies expand -> depthwise -> project and adds the input back.

        Args:
            input_tensor: block input; assumed 4-D (batch, height, width,
                channels) as required by the Conv2D sub-layers.
            training: forwarded to the batch-normalization layers.

        NOTE(review): ``training`` defaults to True, so calling the block
        directly without the flag runs batch normalization in training mode.
        Keras convention is ``training=None``; the default is kept unchanged
        for backward compatibility.
        """
        x = self.conv1(input_tensor)
        x = self.bn1(x, training=training)
        x = self.activation(x)
        x = self.conv2(x)
        x = self.bn2(x, training=training)
        x = self.activation(x)
        # No activation after the projection; the shortcut is added directly.
        x = self.conv3(x)
        x = self.bn3(x, training=training)
        x += input_tensor
        return x

    def set_activation_fn(self, activation):
        """Returns the TensorFlow op for the activation name ``activation``.

        Raises:
            ValueError: if ``activation`` is not a supported name. ValueError
                subclasses Exception, so callers catching Exception still work.
        """
        switcher = {'relu': tf.nn.relu,
                    'relu6': tf.nn.relu6,
                    'lrelu': tf.nn.leaky_relu,
                    'swish': tf.nn.swish}
        res = switcher.get(activation)
        if res is None:
            raise ValueError("Given activation function is not supported: "
                             "{!r}. Supported: {}.".format(
                                 activation, ', '.join(sorted(switcher))))
        return res

    def _get_input_channel(self, input_shape):
        """Returns the number of input channels as a Python int.

        Accepts a ``tf.TensorShape``, a tuple/list, or an eager 1-D shape
        tensor such as ``tf.shape(x)``; the last entry is the channel count.

        Raises:
            ValueError: if the channel dimension is undefined (None).
        """
        # dimension_value unwraps TF1-style Dimension objects and passes
        # plain ints / None through unchanged.
        channel = tf.compat.dimension_value(input_shape[-1])
        if channel is None:
            raise ValueError('The channel dimension of the inputs '
                             'should be defined. Found `None`.')
        return int(channel)
In [14]:
class BlockTest(tf.test.TestCase):
    """Unit tests for ResBlock: shape preservation, argument validation and
    the all-zeros input case."""

    def _run_standard_block(self, input_tensor, activation='relu'):
        """Applies a fresh ResBlock (kernel 3, expansion 6) to ``input_tensor``.

        The block is built by Keras from the input's shape on the first call,
        so no explicit build() with a shape tensor is needed.
        """
        block = ResBlock(kernel_size=3, expansion_factor=6,
                         activation=activation)
        return block(input_tensor)

    def test_basic(self):
        """Checking if the input and output tensors shapes match."""
        # Small batch: with a batch of 32, every 6x-expanded activation of a
        # 128x128x64 input would be 32*128*128*384 floats (~0.8 GB each).
        input_shape = (2, 128, 128, 64)
        input_val = tf.random.normal(input_shape)
        out = self._run_standard_block(input_val)
        self.assertShapeEqual(input_val.numpy(), out)

    def test_standard_input(self):
        """Checking that input / output shapes match on input (8, 16, 16, 40)."""
        input_shape = (8, 16, 16, 40)
        input_val = tf.random.normal(input_shape)
        out = self._run_standard_block(input_val)
        self.assertShapeEqual(input_val.numpy(), out)

    def test_supported_activations(self):
        """Every supported activation keeps the input shape."""
        input_val = tf.random.normal((2, 8, 8, 16))
        for activation in ('relu', 'relu6', 'lrelu', 'swish'):
            with self.subTest(activation=activation):
                out = self._run_standard_block(input_val, activation)
                self.assertShapeEqual(input_val.numpy(), out)

    def test_expansion_wrong_val(self):
        """An expansion factor below one is rejected."""
        with self.assertRaises(ValueError):
            ResBlock(kernel_size=3, expansion_factor=0, activation='relu')

    def test_zeros_input(self):
        """An all-zeros input comes out unchanged.

        Presumably relies on the sub-layers mapping zeros to zeros (Keras'
        default zero bias/beta initializers) — TODO confirm if those change.
        """
        input_shape = (8, 16, 16, 40)
        input_val = tf.zeros(input_shape)
        out = self._run_standard_block(input_val)
        self.assertAllEqual(input_val, out)

    def test_wrong_activation(self):
        """An unsupported activation name is rejected."""
        with self.assertRaises(Exception):
            ResBlock(kernel_size=3, expansion_factor=6, activation='sigmoid')
In [15]:
if __name__ == '__main__':
    unittest.main(
        # argv replaces sys.argv; its first entry is taken as the program
        # name and ignored, so no notebook/kernel arguments reach unittest.
        argv=['first-arg-is-ignored'],
        # Do not call sys.exit() when the tests finish.
        exit=False,
    )
In [16]:
# CIFAR-10 image geometry: 32x32 pixels with 3 colour channels.
h = w = 32
c = 3
# Images per batch for both the training and test pipelines.
BATCH_SIZE = 64
In [17]:
# Load the CIFAR-10 train and test splits (images and labels).
(x_train, y_train), (x_test, y_test) = (
    tf.keras.datasets.cifar10.load_data()
)
In [18]:
# Shuffle individual examples *before* batching so every epoch sees new batch
# compositions. Shuffling after .batch() only permutes whole batches whose
# contents never change. The buffer spans the whole training set, giving a
# uniform shuffle, and .repeat() after .shuffle() reshuffles on every pass.
train_dataset = (tf.data.Dataset.from_tensor_slices((x_train, y_train))
                 .shuffle(buffer_size=len(x_train))
                 .batch(BATCH_SIZE, drop_remainder=True)
                 .repeat())
# drop_remainder=True keeps every batch at exactly BATCH_SIZE images.
test_dataset = (tf.data.Dataset.from_tensor_slices((x_test, y_test))
                .batch(BATCH_SIZE, drop_remainder=True)
                .repeat())
# Sanity check: shapes of one test batch.
for image, label in test_dataset.take(1):
    print(image.shape, label.shape)
In [19]:
def normalize(x, y):
    """Standardizes every image to zero mean / unit variance; labels are untouched."""
    return tf.image.per_image_standardization(x), y
def augment(x, y):
    """Randomly augments a batch of images; labels pass through unchanged.

    Each image is padded to (h + 8, w + 8), randomly cropped back to (h, w),
    randomly flipped left-right, and given a random brightness shift with
    max_delta=0.5.

    Args:
        x: batch of images, assumed shape (batch, h, w, c) — uses the
            notebook-level h, w, c.
        y: labels for the batch, returned as-is.
    """
    x = tf.image.resize_with_crop_or_pad(x, h + 8, w + 8)
    # Crop every image independently. A single tf.image.random_crop over the
    # whole batch (size [BATCH_SIZE, h, w, c]) draws one spatial offset that
    # all images share, and it fails for batches not of size BATCH_SIZE.
    x = tf.map_fn(lambda image: tf.image.random_crop(image, [h, w, c]), x)
    x = tf.image.random_flip_left_right(x)
    # NOTE(review): if x is uint8 (presumably, as loaded from CIFAR-10),
    # random_brightness presumably shifts on a [0, 1] float scale, so
    # max_delta=0.5 is a very strong shift that saturates many pixels —
    # confirm this is intended.
    x = tf.image.random_brightness(x, max_delta=0.5)
    return x, y
# Attach preprocessing to the pipelines. Mapping in parallel and prefetching
# lets input preparation overlap with training (tf.data performance idiom);
# element order is unchanged.
# NOTE: this cell rebinds train_dataset / test_dataset, so re-running it
# without first re-running the dataset-construction cell applies the
# transforms twice.
test_dataset = (test_dataset
                .map(normalize,
                     num_parallel_calls=tf.data.experimental.AUTOTUNE)
                .prefetch(tf.data.experimental.AUTOTUNE))
train_dataset = (train_dataset
                 .map(augment,
                      num_parallel_calls=tf.data.experimental.AUTOTUNE)
                 .map(normalize,
                      num_parallel_calls=tf.data.experimental.AUTOTUNE)
                 .prefetch(tf.data.experimental.AUTOTUNE))
In [20]:
# How many ResBlocks are stacked in the body of the model.
num_res_blocks: int = 10
In [21]:
# Stem: a 7x7 convolution followed by 2x2 max pooling.
inputs = keras.Input(shape=(h, w, c))
x = layers.Conv2D(filters=64, kernel_size=7, activation='swish')(inputs)
x = layers.MaxPooling2D(pool_size=(2, 2))(x)

# Body: num_res_blocks identical, shape-preserving residual blocks.
for i in range(num_res_blocks):
    x = ResBlock(kernel_size=3, expansion_factor=6, activation='swish')(x)

# Head: 3x3 convolution, global average pooling, then a dropout-regularized
# dense classifier ending in a 10-way softmax.
x = layers.Conv2D(filters=64, kernel_size=3, activation='swish')(x)
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dense(units=256, activation='swish')(x)
x = layers.Dropout(rate=0.3)(x)
outputs = layers.Dense(units=10, activation='softmax')(x)

model = keras.Model(inputs=inputs, outputs=outputs)
In [22]:
# Layer-by-layer overview of the model (print_fn=None -> standard output).
model.summary(print_fn=None)
In [23]:
# Log to TensorBoard under a per-run, timestamped directory.
callbacks = [
    keras.callbacks.TensorBoard(
        log_dir='./log/{}'.format(
            dt.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")),
        write_images=True),
]
model.compile(optimizer=keras.optimizers.Adam(),
              loss='sparse_categorical_crossentropy',
              metrics=['acc'])

# Both datasets repeat forever and drop partial batches, so one full pass is
# len(data) // BATCH_SIZE steps. Derive the step counts from the data rather
# than hard-coding them: with BATCH_SIZE=64, 195 steps covered only ~1/4 of
# the training set per epoch, and 3 validation steps used just 192 test
# images, which made val_acc very noisy. Each epoch now takes ~4x longer.
STEPS_PER_EPOCH = len(x_train) // BATCH_SIZE
VALIDATION_STEPS = len(x_test) // BATCH_SIZE
history = model.fit(train_dataset, epochs=30,
                    steps_per_epoch=STEPS_PER_EPOCH,
                    validation_data=test_dataset,
                    validation_steps=VALIDATION_STEPS,
                    callbacks=callbacks)
In [24]:
# Plot training vs. validation accuracy per epoch. The validation data here
# is the test split, hence the 'test' legend entry.
fig, ax = plt.subplots()
for metric in ('acc', 'val_acc'):
    ax.plot(history.history[metric])
ax.set_title('ResNet accuracy')
ax.set_ylabel('accuracy')
ax.set_xlabel('epoch')
ax.legend(['train', 'test'], loc='lower right')
plt.show()
In [25]:
# Evaluate on exactly one pass over the test set. test_dataset repeats
# forever, so `steps` must be given; a hard-coded steps=5000 would cycle
# through 5000 * 64 = 320,000 images, i.e. the same test images many times.
test_steps = len(x_test) // BATCH_SIZE
results = model.evaluate(test_dataset, steps=test_steps)
print('test loss, test acc:', results)