In [0]:
import warnings
warnings.filterwarnings('ignore')

In [2]:
import tensorflow as tf
tf.logging.set_verbosity(tf.logging.ERROR)

import numpy as np

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()


Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
32768/29515 [=================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
26427392/26421880 [==============================] - 1s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
8192/5148 [===============================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
4423680/4422102 [==============================] - 0s 0us/step

In [3]:
x_train.shape


Out[3]:
(60000, 28, 28)

In [0]:
# add empty color dimension
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)

In [5]:
x_train.shape


Out[5]:
(60000, 28, 28, 1)

In [0]:
# reduce memory and compute time by training on only a subset of the data
NUMBER_OF_SAMPLES = 50000

In [0]:
x_train_samples = x_train[:NUMBER_OF_SAMPLES]

In [0]:
y_train_samples = y_train[:NUMBER_OF_SAMPLES]

In [0]:
import skimage.data
import skimage.transform

# resize each 28x28 image to 32x32 (note: the x_train_224 name is kept even though the target size is 32x32)
x_train_224 = np.array([skimage.transform.resize(image, (32, 32)) for image in x_train_samples])

In [32]:
x_train_224.shape


Out[32]:
(50000, 32, 32, 1)
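Note that skimage.transform.resize returns floating-point images and, for uint8 input, rescales pixel values from [0, 255] to [0.0, 1.0] by default, so the resized array also doubles as normalized network input. A minimal sanity check (just a sketch, nothing below depends on it):

print(x_train_224.dtype)                     # expected: float64
print(x_train_224.min(), x_train_224.max())  # expected: roughly 0.0 and 1.0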

Alternative: ResNet

http://arxiv.org/abs/1512.03385


In [33]:
from tensorflow.keras.applications.resnet50 import ResNet50

# https://keras.io/applications/#mobilenet
# https://arxiv.org/pdf/1704.04861.pdf
from tensorflow.keras.applications.mobilenet import MobileNet

# model = ResNet50(classes=10, weights=None, input_shape=(32, 32, 1))
model = MobileNet(classes=10, weights=None, input_shape=(32, 32, 1))

model.summary()


_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_4 (InputLayer)         (None, 32, 32, 1)         0         
_________________________________________________________________
conv1_pad (ZeroPadding2D)    (None, 33, 33, 1)         0         
_________________________________________________________________
conv1 (Conv2D)               (None, 16, 16, 32)        288       
_________________________________________________________________
conv1_bn (BatchNormalization (None, 16, 16, 32)        128       
_________________________________________________________________
conv1_relu (ReLU)            (None, 16, 16, 32)        0         
_________________________________________________________________
conv_dw_1 (DepthwiseConv2D)  (None, 16, 16, 32)        288       
_________________________________________________________________
conv_dw_1_bn (BatchNormaliza (None, 16, 16, 32)        128       
_________________________________________________________________
conv_dw_1_relu (ReLU)        (None, 16, 16, 32)        0         
_________________________________________________________________
conv_pw_1 (Conv2D)           (None, 16, 16, 64)        2048      
_________________________________________________________________
conv_pw_1_bn (BatchNormaliza (None, 16, 16, 64)        256       
_________________________________________________________________
conv_pw_1_relu (ReLU)        (None, 16, 16, 64)        0         
_________________________________________________________________
conv_pad_2 (ZeroPadding2D)   (None, 17, 17, 64)        0         
_________________________________________________________________
conv_dw_2 (DepthwiseConv2D)  (None, 8, 8, 64)          576       
_________________________________________________________________
conv_dw_2_bn (BatchNormaliza (None, 8, 8, 64)          256       
_________________________________________________________________
conv_dw_2_relu (ReLU)        (None, 8, 8, 64)          0         
_________________________________________________________________
conv_pw_2 (Conv2D)           (None, 8, 8, 128)         8192      
_________________________________________________________________
conv_pw_2_bn (BatchNormaliza (None, 8, 8, 128)         512       
_________________________________________________________________
conv_pw_2_relu (ReLU)        (None, 8, 8, 128)         0         
_________________________________________________________________
conv_dw_3 (DepthwiseConv2D)  (None, 8, 8, 128)         1152      
_________________________________________________________________
conv_dw_3_bn (BatchNormaliza (None, 8, 8, 128)         512       
_________________________________________________________________
conv_dw_3_relu (ReLU)        (None, 8, 8, 128)         0         
_________________________________________________________________
conv_pw_3 (Conv2D)           (None, 8, 8, 128)         16384     
_________________________________________________________________
conv_pw_3_bn (BatchNormaliza (None, 8, 8, 128)         512       
_________________________________________________________________
conv_pw_3_relu (ReLU)        (None, 8, 8, 128)         0         
_________________________________________________________________
conv_pad_4 (ZeroPadding2D)   (None, 9, 9, 128)         0         
_________________________________________________________________
conv_dw_4 (DepthwiseConv2D)  (None, 4, 4, 128)         1152      
_________________________________________________________________
conv_dw_4_bn (BatchNormaliza (None, 4, 4, 128)         512       
_________________________________________________________________
conv_dw_4_relu (ReLU)        (None, 4, 4, 128)         0         
_________________________________________________________________
conv_pw_4 (Conv2D)           (None, 4, 4, 256)         32768     
_________________________________________________________________
conv_pw_4_bn (BatchNormaliza (None, 4, 4, 256)         1024      
_________________________________________________________________
conv_pw_4_relu (ReLU)        (None, 4, 4, 256)         0         
_________________________________________________________________
conv_dw_5 (DepthwiseConv2D)  (None, 4, 4, 256)         2304      
_________________________________________________________________
conv_dw_5_bn (BatchNormaliza (None, 4, 4, 256)         1024      
_________________________________________________________________
conv_dw_5_relu (ReLU)        (None, 4, 4, 256)         0         
_________________________________________________________________
conv_pw_5 (Conv2D)           (None, 4, 4, 256)         65536     
_________________________________________________________________
conv_pw_5_bn (BatchNormaliza (None, 4, 4, 256)         1024      
_________________________________________________________________
conv_pw_5_relu (ReLU)        (None, 4, 4, 256)         0         
_________________________________________________________________
conv_pad_6 (ZeroPadding2D)   (None, 5, 5, 256)         0         
_________________________________________________________________
conv_dw_6 (DepthwiseConv2D)  (None, 2, 2, 256)         2304      
_________________________________________________________________
conv_dw_6_bn (BatchNormaliza (None, 2, 2, 256)         1024      
_________________________________________________________________
conv_dw_6_relu (ReLU)        (None, 2, 2, 256)         0         
_________________________________________________________________
conv_pw_6 (Conv2D)           (None, 2, 2, 512)         131072    
_________________________________________________________________
conv_pw_6_bn (BatchNormaliza (None, 2, 2, 512)         2048      
_________________________________________________________________
conv_pw_6_relu (ReLU)        (None, 2, 2, 512)         0         
_________________________________________________________________
conv_dw_7 (DepthwiseConv2D)  (None, 2, 2, 512)         4608      
_________________________________________________________________
conv_dw_7_bn (BatchNormaliza (None, 2, 2, 512)         2048      
_________________________________________________________________
conv_dw_7_relu (ReLU)        (None, 2, 2, 512)         0         
_________________________________________________________________
conv_pw_7 (Conv2D)           (None, 2, 2, 512)         262144    
_________________________________________________________________
conv_pw_7_bn (BatchNormaliza (None, 2, 2, 512)         2048      
_________________________________________________________________
conv_pw_7_relu (ReLU)        (None, 2, 2, 512)         0         
_________________________________________________________________
conv_dw_8 (DepthwiseConv2D)  (None, 2, 2, 512)         4608      
_________________________________________________________________
conv_dw_8_bn (BatchNormaliza (None, 2, 2, 512)         2048      
_________________________________________________________________
conv_dw_8_relu (ReLU)        (None, 2, 2, 512)         0         
_________________________________________________________________
conv_pw_8 (Conv2D)           (None, 2, 2, 512)         262144    
_________________________________________________________________
conv_pw_8_bn (BatchNormaliza (None, 2, 2, 512)         2048      
_________________________________________________________________
conv_pw_8_relu (ReLU)        (None, 2, 2, 512)         0         
_________________________________________________________________
conv_dw_9 (DepthwiseConv2D)  (None, 2, 2, 512)         4608      
_________________________________________________________________
conv_dw_9_bn (BatchNormaliza (None, 2, 2, 512)         2048      
_________________________________________________________________
conv_dw_9_relu (ReLU)        (None, 2, 2, 512)         0         
_________________________________________________________________
conv_pw_9 (Conv2D)           (None, 2, 2, 512)         262144    
_________________________________________________________________
conv_pw_9_bn (BatchNormaliza (None, 2, 2, 512)         2048      
_________________________________________________________________
conv_pw_9_relu (ReLU)        (None, 2, 2, 512)         0         
_________________________________________________________________
conv_dw_10 (DepthwiseConv2D) (None, 2, 2, 512)         4608      
_________________________________________________________________
conv_dw_10_bn (BatchNormaliz (None, 2, 2, 512)         2048      
_________________________________________________________________
conv_dw_10_relu (ReLU)       (None, 2, 2, 512)         0         
_________________________________________________________________
conv_pw_10 (Conv2D)          (None, 2, 2, 512)         262144    
_________________________________________________________________
conv_pw_10_bn (BatchNormaliz (None, 2, 2, 512)         2048      
_________________________________________________________________
conv_pw_10_relu (ReLU)       (None, 2, 2, 512)         0         
_________________________________________________________________
conv_dw_11 (DepthwiseConv2D) (None, 2, 2, 512)         4608      
_________________________________________________________________
conv_dw_11_bn (BatchNormaliz (None, 2, 2, 512)         2048      
_________________________________________________________________
conv_dw_11_relu (ReLU)       (None, 2, 2, 512)         0         
_________________________________________________________________
conv_pw_11 (Conv2D)          (None, 2, 2, 512)         262144    
_________________________________________________________________
conv_pw_11_bn (BatchNormaliz (None, 2, 2, 512)         2048      
_________________________________________________________________
conv_pw_11_relu (ReLU)       (None, 2, 2, 512)         0         
_________________________________________________________________
conv_pad_12 (ZeroPadding2D)  (None, 3, 3, 512)         0         
_________________________________________________________________
conv_dw_12 (DepthwiseConv2D) (None, 1, 1, 512)         4608      
_________________________________________________________________
conv_dw_12_bn (BatchNormaliz (None, 1, 1, 512)         2048      
_________________________________________________________________
conv_dw_12_relu (ReLU)       (None, 1, 1, 512)         0         
_________________________________________________________________
conv_pw_12 (Conv2D)          (None, 1, 1, 1024)        524288    
_________________________________________________________________
conv_pw_12_bn (BatchNormaliz (None, 1, 1, 1024)        4096      
_________________________________________________________________
conv_pw_12_relu (ReLU)       (None, 1, 1, 1024)        0         
_________________________________________________________________
conv_dw_13 (DepthwiseConv2D) (None, 1, 1, 1024)        9216      
_________________________________________________________________
conv_dw_13_bn (BatchNormaliz (None, 1, 1, 1024)        4096      
_________________________________________________________________
conv_dw_13_relu (ReLU)       (None, 1, 1, 1024)        0         
_________________________________________________________________
conv_pw_13 (Conv2D)          (None, 1, 1, 1024)        1048576   
_________________________________________________________________
conv_pw_13_bn (BatchNormaliz (None, 1, 1, 1024)        4096      
_________________________________________________________________
conv_pw_13_relu (ReLU)       (None, 1, 1, 1024)        0         
_________________________________________________________________
global_average_pooling2d_2 ( (None, 1024)              0         
_________________________________________________________________
reshape_1 (Reshape)          (None, 1, 1, 1024)        0         
_________________________________________________________________
dropout (Dropout)            (None, 1, 1, 1024)        0         
_________________________________________________________________
conv_preds (Conv2D)          (None, 1, 1, 10)          10250     
_________________________________________________________________
act_softmax (Activation)     (None, 1, 1, 10)          0         
_________________________________________________________________
reshape_2 (Reshape)          (None, 10)                0         
=================================================================
Total params: 3,238,538
Trainable params: 3,216,650
Non-trainable params: 21,888
_________________________________________________________________
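The labels in y_train_samples are integer class ids (0-9), which is why the sparse variant of the cross-entropy loss is used in the next cell. As a minimal sketch of the equivalent one-hot setup (not used here; the variable name y_train_onehot is purely illustrative):

# Hypothetical alternative: one-hot encode the labels and switch to the
# non-sparse loss; this notebook keeps integer labels and the sparse loss.
y_train_onehot = tf.keras.utils.to_categorical(y_train_samples, num_classes=10)
# model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])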

In [34]:
BATCH_SIZE = 10
EPOCHS = 20

model.compile(loss='sparse_categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

%time history = model.fit(x_train_224, y_train_samples, epochs=EPOCHS, batch_size=BATCH_SIZE, validation_split=0.2, verbose=1)


Train on 40000 samples, validate on 10000 samples
Epoch 1/20
40000/40000 [==============================] - 203s 5ms/step - loss: 0.9197 - acc: 0.6855 - val_loss: 0.6144 - val_acc: 0.7491
Epoch 2/20
40000/40000 [==============================] - 195s 5ms/step - loss: 0.6013 - acc: 0.7987 - val_loss: 0.4832 - val_acc: 0.8303
Epoch 3/20
40000/40000 [==============================] - 196s 5ms/step - loss: 0.4854 - acc: 0.8367 - val_loss: 0.3939 - val_acc: 0.8611
Epoch 4/20
40000/40000 [==============================] - 196s 5ms/step - loss: 0.4145 - acc: 0.8579 - val_loss: 0.3753 - val_acc: 0.8725
Epoch 5/20
40000/40000 [==============================] - 196s 5ms/step - loss: 0.3604 - acc: 0.8746 - val_loss: 0.3443 - val_acc: 0.8802
Epoch 6/20
40000/40000 [==============================] - 196s 5ms/step - loss: 0.3277 - acc: 0.8851 - val_loss: 0.3258 - val_acc: 0.8873
Epoch 7/20
40000/40000 [==============================] - 196s 5ms/step - loss: 0.3013 - acc: 0.8957 - val_loss: 0.3238 - val_acc: 0.8847
Epoch 8/20
40000/40000 [==============================] - 198s 5ms/step - loss: 0.2797 - acc: 0.9025 - val_loss: 0.2958 - val_acc: 0.8979
Epoch 9/20
40000/40000 [==============================] - 196s 5ms/step - loss: 0.2629 - acc: 0.9084 - val_loss: 0.2651 - val_acc: 0.9106
Epoch 10/20
40000/40000 [==============================] - 195s 5ms/step - loss: 0.2503 - acc: 0.9123 - val_loss: 0.2733 - val_acc: 0.9094
Epoch 11/20
40000/40000 [==============================] - 196s 5ms/step - loss: 0.2321 - acc: 0.9181 - val_loss: 0.2605 - val_acc: 0.9113
Epoch 12/20
40000/40000 [==============================] - 195s 5ms/step - loss: 0.2169 - acc: 0.9241 - val_loss: 0.2670 - val_acc: 0.9116
Epoch 13/20
40000/40000 [==============================] - 195s 5ms/step - loss: 0.2086 - acc: 0.9260 - val_loss: 0.2703 - val_acc: 0.9094
Epoch 14/20
40000/40000 [==============================] - 196s 5ms/step - loss: 0.1962 - acc: 0.9303 - val_loss: 0.2372 - val_acc: 0.9179
Epoch 15/20
40000/40000 [==============================] - 196s 5ms/step - loss: 0.1862 - acc: 0.9329 - val_loss: 0.2487 - val_acc: 0.9153
Epoch 16/20
40000/40000 [==============================] - 197s 5ms/step - loss: 0.1789 - acc: 0.9374 - val_loss: 0.2880 - val_acc: 0.9027
Epoch 17/20
40000/40000 [==============================] - 196s 5ms/step - loss: 0.1730 - acc: 0.9386 - val_loss: 0.2329 - val_acc: 0.9197
Epoch 18/20
40000/40000 [==============================] - 196s 5ms/step - loss: 0.1619 - acc: 0.9422 - val_loss: 0.2403 - val_acc: 0.9135
Epoch 19/20
40000/40000 [==============================] - 195s 5ms/step - loss: 0.1565 - acc: 0.9454 - val_loss: 0.2699 - val_acc: 0.9018
Epoch 20/20
40000/40000 [==============================] - 195s 5ms/step - loss: 0.1504 - acc: 0.9469 - val_loss: 0.2421 - val_acc: 0.9163
CPU times: user 1h 5min 10s, sys: 9min 47s, total: 1h 14min 58s
Wall time: 1h 5min 31s

In [0]:
import pandas as pd
from matplotlib import pyplot as plt
%matplotlib inline

In [36]:
def plot_history(history, samples=10):
    epochs = history.params['epochs']

    acc = history.history['acc']
    val_acc = history.history['val_acc']

    # plot only every n-th epoch so the markers stay readable
    every_sample = max(1, int(epochs / samples))
    acc = pd.DataFrame(acc).iloc[::every_sample, :]
    val_acc = pd.DataFrame(val_acc).iloc[::every_sample, :]

    fig, ax = plt.subplots(figsize=(20, 5))

    ax.plot(acc, 'bo', label='Training acc')
    ax.plot(val_acc, 'b', label='Validation acc')
    ax.set_title('Training and validation accuracy')
    ax.legend()

plot_history(history)


Checking our results (inference)


In [0]:
x_test_224 = np.array([skimage.transform.resize(image, (32, 32)) for image in x_test])

In [39]:
LABEL_NAMES = ['t_shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle_boot']


def plot_predictions(images, predictions):
  n = images.shape[0]
  nc = int(np.ceil(n / 4))
  f, axes = plt.subplots(nc, 4)
  for i in range(nc * 4):
    y = i // 4  # row
    x = i % 4   # column
    axes[y, x].axis('off')

    # skip the padding cells when the number of images is not a multiple of 4
    if i >= n:
      continue
    label = LABEL_NAMES[np.argmax(predictions[i])]
    confidence = np.max(predictions[i])
    axes[y, x].imshow(images[i])
    axes[y, x].text(0.5, 0.5, label + '\n%.3f' % confidence, fontsize=14)

  plt.gcf().set_size_inches(8, 8)

plot_predictions(np.squeeze(x_test_224[:16]),
                 model.predict(x_test_224[:16]))



In [40]:
train_loss, train_accuracy = model.evaluate(x_train_224, y_train_samples, batch_size=BATCH_SIZE)
train_accuracy


50000/50000 [==============================] - 42s 849us/step
Out[40]:
0.948899992954731

In [41]:
test_loss, test_accuracy = model.evaluate(x_test_224, y_test, batch_size=BATCH_SIZE)
test_accuracy


10000/10000 [==============================] - 8s 850us/step
Out[41]:
0.9085999925136566
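The gap between training accuracy (about 0.949) and test accuracy (about 0.909) points to mild overfitting. A per-class breakdown can make this more concrete; a minimal sketch, assuming scikit-learn is available in the environment:

from sklearn.metrics import classification_report

# predict class ids for the resized test images and compare against the true labels
y_pred = np.argmax(model.predict(x_test_224, batch_size=BATCH_SIZE), axis=-1)
print(classification_report(y_test, y_pred, target_names=LABEL_NAMES))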

In [0]: