Fashion MNIST with Keras and TPUs

Let's use tf.keras and Cloud TPUs to train a model on the Fashion MNIST dataset.

First, let's grab our dataset using tf.keras.datasets.

Taken from


In [1]:
import tensorflow as tf
import numpy as np

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()

# add empty color dimension
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)


Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
32768/29515 [=================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
26427392/26421880 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
8192/5148 [===============================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
4423680/4422102 [==============================] - 0s 0us/step
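
The np.expand_dims calls matter because the convolution layers used below expect a trailing channel axis, while the image arrays come back as (num_images, 28, 28). Here is a minimal sketch of the shape change. It uses a small zero-filled stand-in array (an assumption, so that it runs without downloading anything); fake_images is a hypothetical name, not part of the notebook.

```python
import numpy as np

# Stand-in for x_train: 6 grayscale images of 28x28 pixels (synthetic, not the real data).
fake_images = np.zeros((6, 28, 28), dtype=np.uint8)

# Append a channel axis of length 1 so each image is (height, width, channels).
with_channel = np.expand_dims(fake_images, -1)

print(fake_images.shape)   # (6, 28, 28)
print(with_channel.shape)  # (6, 28, 28, 1)
```

The per-image shape (28, 28, 1) is what x_train.shape[1:] passes to the first layer as input_shape.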

Defining our model

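We will use a standard conv-net for this example. It has three convolutional blocks, each made of batch normalization, a 5x5 convolution, max pooling, and dropout, followed by a 256-unit dense layer with dropout and a 10-way softmax output.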


In [2]:
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.BatchNormalization(input_shape=x_train.shape[1:]))
model.add(tf.keras.layers.Conv2D(64, (5, 5), padding='same', activation='elu'))
model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2,2)))
model.add(tf.keras.layers.Dropout(0.25))

model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.Conv2D(128, (5, 5), padding='same', activation='elu'))
model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2)))
model.add(tf.keras.layers.Dropout(0.25))

model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.Conv2D(256, (5, 5), padding='same', activation='elu'))
model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2,2)))
model.add(tf.keras.layers.Dropout(0.25))

model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(256))
model.add(tf.keras.layers.Activation('elu'))
model.add(tf.keras.layers.Dropout(0.5))
model.add(tf.keras.layers.Dense(10))
model.add(tf.keras.layers.Activation('softmax'))
model.summary()


_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
batch_normalization (BatchNo (None, 28, 28, 1)         4         
_________________________________________________________________
conv2d (Conv2D)              (None, 28, 28, 64)        1664      
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 14, 14, 64)        0         
_________________________________________________________________
dropout (Dropout)            (None, 14, 14, 64)        0         
_________________________________________________________________
batch_normalization_1 (Batch (None, 14, 14, 64)        256       
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 14, 14, 128)       204928    
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 7, 7, 128)         0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 7, 7, 128)         0         
_________________________________________________________________
batch_normalization_2 (Batch (None, 7, 7, 128)         512       
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 7, 7, 256)         819456    
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 3, 3, 256)         0         
_________________________________________________________________
dropout_2 (Dropout)          (None, 3, 3, 256)         0         
_________________________________________________________________
flatten (Flatten)            (None, 2304)              0         
_________________________________________________________________
dense (Dense)                (None, 256)               590080    
_________________________________________________________________
activation (Activation)      (None, 256)               0         
_________________________________________________________________
dropout_3 (Dropout)          (None, 256)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 10)                2570      
_________________________________________________________________
activation_1 (Activation)    (None, 10)                0         
=================================================================
Total params: 1,619,470
Trainable params: 1,619,084
Non-trainable params: 386
_________________________________________________________________
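
The numbers in the summary can be checked by hand. The sketch below recomputes the flattened size and the parameter totals from the layer definitions. It assumes the usual counting rules: one bias per convolution filter or dense unit, and four values per batch-normalized channel, two of them trainable. The helper names are illustrative only.

```python
def conv2d_params(kernel_h, kernel_w, in_channels, filters):
    # One kernel_h x kernel_w x in_channels kernel per filter, plus one bias per filter.
    return kernel_h * kernel_w * in_channels * filters + filters

def dense_params(inputs, units):
    # One weight per (input, unit) pair, plus one bias per unit.
    return inputs * units + units

def batch_norm_params(channels):
    # gamma and beta are trainable; the moving mean and variance are not.
    return 2 * channels, 2 * channels

side = 28          # input images are 28x28
channels = 1       # single grayscale channel
trainable = 0
non_trainable = 0
for filters in (64, 128, 256):
    bn_trainable, bn_frozen = batch_norm_params(channels)
    trainable += bn_trainable + conv2d_params(5, 5, channels, filters)
    non_trainable += bn_frozen
    channels = filters
    side //= 2     # 2x2 max pooling halves each side, rounding down (28 -> 14 -> 7 -> 3)

flat = side * side * channels
trainable += dense_params(flat, 256) + dense_params(256, 10)

print(flat)                       # 2304
print(trainable, non_trainable)   # 1619084 386
print(trainable + non_trainable)  # 1619470
```

The two largest contributors are the last convolution (819,456 parameters) and the first dense layer (590,080 parameters).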

Training on the TPU

We're ready to train! We first construct our model on the TPU, and compile it.

Here we demonstrate that we can use a generator function and fit_generator to train the model. You can also pass in x_train and y_train to tpu_model.fit() instead.
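
The generator in the next cell picks a random offset and yields a contiguous slice of the training arrays on every step. To see what such a generator produces without a TPU, here is a minimal sketch on synthetic arrays. window_gen, features, and labels are hypothetical names, and the seeded RandomState is only there to make the run repeatable.

```python
import numpy as np

def window_gen(features, labels, batch_size, rng):
    # Yield random contiguous windows of the arrays forever.
    while True:
        # randint's upper bound is exclusive; +1 keeps the final window reachable.
        offset = rng.randint(0, features.shape[0] - batch_size + 1)
        yield features[offset:offset + batch_size], labels[offset:offset + batch_size]

# Synthetic stand-ins for x_train and y_train: row i has label i.
features = np.arange(100).reshape(100, 1)
labels = np.arange(100)

gen = window_gen(features, labels, batch_size=8, rng=np.random.RandomState(0))
batch_x, batch_y = next(gen)
print(batch_x.shape, batch_y.shape)  # (8, 1) (8,)
```

Because each batch is an unshuffled slice, it holds neighbouring examples, and an "epoch" is simply steps_per_epoch random windows rather than one pass over the data.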


In [3]:
import os
tpu_model = tf.contrib.tpu.keras_to_tpu_model(
    model,
    strategy=tf.contrib.tpu.TPUDistributionStrategy(
        tf.contrib.cluster_resolver.TPUClusterResolver(tpu='grpc://' + os.environ['COLAB_TPU_ADDR'])
    )
)
tpu_model.compile(
    optimizer=tf.train.AdamOptimizer(learning_rate=1e-3),
    loss=tf.keras.losses.sparse_categorical_crossentropy,
    metrics=['sparse_categorical_accuracy']
)

def train_gen(batch_size):
  while True:
    offset = np.random.randint(0, x_train.shape[0] - batch_size)
    yield x_train[offset:offset+batch_size], y_train[offset:offset + batch_size]
    

tpu_model.fit_generator(
    train_gen(1024),
    epochs=10,
    steps_per_epoch=100,
    validation_data=(x_test, y_test),
)


INFO:tensorflow:Querying Tensorflow master (b'grpc://10.7.138.218:8470') for TPU system metadata.
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, -1, 14408847620022770600)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 17179869184, 2746972741496128654)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_GPU:0, XLA_GPU, 17179869184, 13565741426377899995)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 17179869184, 12818168040642754990)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 17179869184, 3010502991640372920)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 17179869184, 5598428933246360125)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 17179869184, 14292462930392135804)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 17179869184, 837902191328346042)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 17179869184, 4952914342205355326)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 17179869184, 15791634188471722418)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 17179869184, 1382312301466126899)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 17179869184, 5076344107432619919)
WARNING:tensorflow:tpu_model (from tensorflow.contrib.tpu.python.tpu.keras_support) is experimental and may change or be removed at any time, and without warning.
INFO:tensorflow:Connecting to: b'grpc://10.7.138.218:8470'
Epoch 1/10
INFO:tensorflow:New input shapes; (re-)compiling: mode=train, [TensorSpec(shape=(128, 28, 28, 1), dtype=tf.float32, name='batch_normalization_input0'), TensorSpec(shape=(128, 1), dtype=tf.float32, name='activation_1_target0')]
INFO:tensorflow:Overriding default placeholder.
INFO:tensorflow:Remapping placeholder for batch_normalization_input
INFO:tensorflow:Started compiling
INFO:tensorflow:Finished compiling. Time elapsed: 2.821094036102295 secs
INFO:tensorflow:Setting weights on TPU model.
 99/100 [============================>.] - ETA: 0s - loss: 0.9226 - sparse_categorical_accuracy: 0.7431
INFO:tensorflow:New input shapes; (re-)compiling: mode=eval, [TensorSpec(shape=(128, 28, 28, 1), dtype=tf.float32, name='batch_normalization_input0'), TensorSpec(shape=(128, 1), dtype=tf.float32, name='activation_1_target0')]
INFO:tensorflow:Overriding default placeholder.
INFO:tensorflow:Remapping placeholder for batch_normalization_input
INFO:tensorflow:Started compiling
INFO:tensorflow:Finished compiling. Time elapsed: 1.5474913120269775 secs
INFO:tensorflow:New input shapes; (re-)compiling: mode=eval, [TensorSpec(shape=(98, 28, 28, 1), dtype=tf.float32, name='batch_normalization_input0'), TensorSpec(shape=(98, 1), dtype=tf.float32, name='activation_1_target0')]
INFO:tensorflow:Overriding default placeholder.
INFO:tensorflow:Remapping placeholder for batch_normalization_input
INFO:tensorflow:Started compiling
INFO:tensorflow:Finished compiling. Time elapsed: 1.5241296291351318 secs
100/100 [==============================] - 18s 177ms/step - loss: 0.9185 - sparse_categorical_accuracy: 0.7441 - val_loss: 1.1408 - val_sparse_categorical_accuracy: 0.6304
Epoch 2/10
100/100 [==============================] - 8s 84ms/step - loss: 0.4337 - sparse_categorical_accuracy: 0.8507 - val_loss: 1.0953 - val_sparse_categorical_accuracy: 0.6376
Epoch 3/10
100/100 [==============================] - 8s 82ms/step - loss: 0.3591 - sparse_categorical_accuracy: 0.8743 - val_loss: 0.8236 - val_sparse_categorical_accuracy: 0.7040
Epoch 4/10
100/100 [==============================] - 8s 81ms/step - loss: 0.2888 - sparse_categorical_accuracy: 0.8970 - val_loss: 0.4052 - val_sparse_categorical_accuracy: 0.8472
Epoch 5/10
100/100 [==============================] - 8s 84ms/step - loss: 0.2583 - sparse_categorical_accuracy: 0.9050 - val_loss: 0.2967 - val_sparse_categorical_accuracy: 0.8952
Epoch 6/10
100/100 [==============================] - 8s 84ms/step - loss: 0.2396 - sparse_categorical_accuracy: 0.9121 - val_loss: 0.2633 - val_sparse_categorical_accuracy: 0.8984
Epoch 7/10
100/100 [==============================] - 8s 81ms/step - loss: 0.2075 - sparse_categorical_accuracy: 0.9233 - val_loss: 0.2607 - val_sparse_categorical_accuracy: 0.9016
Epoch 8/10
100/100 [==============================] - 8s 81ms/step - loss: 0.1946 - sparse_categorical_accuracy: 0.9278 - val_loss: 0.2273 - val_sparse_categorical_accuracy: 0.9216
Epoch 9/10
100/100 [==============================] - 8s 83ms/step - loss: 0.1688 - sparse_categorical_accuracy: 0.9380 - val_loss: 0.2341 - val_sparse_categorical_accuracy: 0.9216
Epoch 10/10
100/100 [==============================] - 8s 81ms/step - loss: 0.1596 - sparse_categorical_accuracy: 0.9410 - val_loss: 0.2264 - val_sparse_categorical_accuracy: 0.9200
Out[3]:
<tensorflow.python.keras.callbacks.History at 0x7f590be95d68>
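
The compile messages in the log report a batch dimension of 128, and later 98, rather than 1024. Those numbers are consistent with each batch being split evenly across the 8 TPU cores, and with validation running over the 10,000 test images in batches of 1024. Both are inferences from the log, not documented behaviour. A sketch of the arithmetic:

```python
NUM_CORES = 8            # "Num TPU Cores: 8" in the log
GLOBAL_BATCH = 1024      # batch size passed to train_gen
NUM_TEST_IMAGES = 10000  # size of the Fashion MNIST test split

def per_core(batch_size, num_cores=NUM_CORES):
    # Assumes the batch divides evenly across the cores.
    assert batch_size % num_cores == 0
    return batch_size // num_cores

# Number of full validation batches, and the size of the final partial one.
full_batches, remainder = divmod(NUM_TEST_IMAGES, GLOBAL_BATCH)

print(per_core(GLOBAL_BATCH))   # 128
print(full_batches, remainder)  # 9 784
print(per_core(remainder))      # 98
```

Under that reading, the second eval compilation in epoch 1 is triggered by the smaller final validation batch, which has a new input shape.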

Checking our results (inference)

Now that we're done training, let's see how well we can predict fashion categories! Prediction through the Keras/TPU model doesn't work in this TensorFlow version because of a small bug (fixed in TF 1.12), so we copy the trained weights back to a CPU model with sync_to_cpu() and predict there to see how our results look.


In [5]:
LABEL_NAMES = ['t_shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle_boots']


cpu_model = tpu_model.sync_to_cpu()

from matplotlib import pyplot
%matplotlib inline

def plot_predictions(images, predictions):
  n = images.shape[0]
  nc = int(np.ceil(n / 4))
  # squeeze=False keeps axes 2-D even when there is only one row
  f, axes = pyplot.subplots(nc, 4, squeeze=False)
  for i in range(nc * 4):
    # axes is indexed [row, column]; the grid has nc rows and 4 columns
    row = i // 4
    col = i % 4
    axes[row, col].axis('off')

    # leave unused cells in the last row blank
    if i >= n:
      continue
    label = LABEL_NAMES[np.argmax(predictions[i])]
    confidence = np.max(predictions[i])
    axes[row, col].imshow(images[i])
    axes[row, col].text(0.5, 0.5, label + '\n%.3f' % confidence, fontsize=14)

  pyplot.gcf().set_size_inches(8, 8)

plot_predictions(np.squeeze(x_test[:16]), 
                 cpu_model.predict(x_test[:16]))


INFO:tensorflow:Copying TPU weights to the CPU

Not bad!
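
To put a number on the plot, the same argmax used for the labels can be compared against the true classes. The sketch below does this on a tiny hand-written array of softmax outputs. The values are synthetic, not real model output, and the variable names are illustrative.

```python
import numpy as np

# Synthetic softmax outputs for 4 images over 3 classes (not real model output).
predictions = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.3, 0.3, 0.4],
    [0.6, 0.3, 0.1],
])
true_labels = np.array([0, 1, 2, 1])  # hypothetical ground truth

predicted = np.argmax(predictions, axis=1)    # most likely class per row
confidence = np.max(predictions, axis=1)      # probability assigned to that class
accuracy = np.mean(predicted == true_labels)  # fraction of rows predicted correctly

print(predicted)  # [0 1 2 0]
print(accuracy)   # 0.75
```

With the notebook's own objects, the same computation would take cpu_model.predict(x_test) as predictions and y_test as true_labels.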