In [1]:
from __future__ import print_function

# to be able to see plots
%matplotlib inline  
import matplotlib.pyplot as plt

import numpy as np

import sys
sys.path.append("../tools")

from tools import collage

# Use only a fraction of the GPU memory.
# This is not needed on dedicated machines, but it allows
# several users to share the same GPU.
# This setting is specific to the TensorFlow backend.
gpu_memory_usage = 0.5
import tensorflow as tf
from keras.backend.tensorflow_backend import set_session
config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = gpu_memory_usage
set_session(tf.Session(config=config))


Using TensorFlow backend.
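
An alternative to reserving a fixed memory fraction is to let TensorFlow grow its allocation on demand. This is a minimal sketch for the same TF 1.x / Keras backend setup used above (use one approach or the other, not both):


In [ ]:
import tensorflow as tf
from keras.backend.tensorflow_backend import set_session

# Allocate GPU memory lazily instead of reserving a fixed per-process fraction.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
set_session(tf.Session(config=config))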

Read CIFAR10 dataset


In [2]:
from tools import readCIFAR, mapLabelsOneHot

# First run ../data/downloadCIFAR.sh 
# This reads the dataset
trnData, tstData, trnLabels, tstLabels = readCIFAR('../data/cifar-10-batches-py')

plt.subplot(1, 2, 1)
img = collage(trnData[:16])
print(img.shape)
plt.imshow(img)
plt.subplot(1, 2, 2)
img = collage(tstData[:16])
plt.imshow(img)
plt.show()

# Convert integer class labels to one-hot encoding, which is what
# categorical_crossentropy in Keras expects.
# This is not the only option: the same loss can also be computed
# directly from integer class IDs (see the sketch after this cell).
trnLabels = mapLabelsOneHot(trnLabels)
tstLabels = mapLabelsOneHot(tstLabels)
print('One-hot trn. labels shape:', trnLabels.shape)


('Trn data shape:', (100000, 32, 32, 3))
('Tst data shape:', (20000, 32, 32, 3))
('Trn labels shape: ', (100000,))
('Tst labels shape: ', (20000,))
(128, 128, 3)
One-hot trn. labels shape: (100000, 10)
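
As noted in the comments above, one-hot labels are not strictly required: Keras also provides losses.sparse_categorical_crossentropy, which takes integer class IDs and computes the same quantity. A quick numeric sketch of the equivalence, using hypothetical toy probabilities:


In [ ]:
# Categorical crossentropy on a one-hot label equals -log(probability of the
# true class), which is exactly what the sparse variant computes from an
# integer class ID.
p = np.array([0.1, 0.7, 0.2])        # predicted class probabilities (toy example)
one_hot = np.array([0., 1., 0.])     # one-hot label for class 1
print(-np.sum(one_hot * np.log(p)))  # categorical_crossentropy
print(-np.log(p[1]))                 # sparse_categorical_crossentropy (class ID 1)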

Normalize data

This maps all values in the training and test data to the range [-0.5, 0.5]. Some kind of value normalization is preferable, as it provides consistent behavior across different problems and datasets.


In [3]:
trnData = trnData.astype(np.float32) / 255.0 - 0.5
tstData = tstData.astype(np.float32) / 255.0 - 0.5
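
Another common choice (not used here) is per-channel standardization, where the statistics are computed on the training set only and reused for the test set. A minimal sketch, kept in separate variables so it does not overwrite the data used below:


In [ ]:
# Hypothetical alternative: zero mean, unit variance per colour channel,
# using training-set statistics for both splits.
mean = trnData.mean(axis=(0, 1, 2), keepdims=True)
std = trnData.std(axis=(0, 1, 2), keepdims=True)
trnDataStd = (trnData - mean) / std
tstDataStd = (tstData - mean) / std
print(mean.shape, trnDataStd.mean(axis=(0, 1, 2)), trnDataStd.std(axis=(0, 1, 2)))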

Define net


In [11]:
from keras.layers import Input, Reshape, Dense, Dropout, Flatten
from keras.layers import Activation, Conv2D, MaxPooling2D
from keras.models import Model
from keras import regularizers

w_decay = 0.0001
w_reg = regularizers.l2(w_decay)  # defined here, but not applied to any layer below

def get_simple_FC_network(input_data, layer_count, layer_dim):
    # Two small convolutional blocks followed by fully connected layers.
    net = Conv2D(32, 3, activation='relu')(input_data)
    net = MaxPooling2D(2, 2)(net)
    net = Conv2D(32, 3, activation='relu')(net)
    net = Flatten()(net)
    # layer_count fully connected layers, each followed by dropout.
    for i in range(layer_count):
        FC = Dense(layer_dim, activation='relu')
        net = Dropout(rate=0.33)(FC(net))

    net = Dense(10, name='out', activation='softmax')(net)

    return net
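
Note that w_reg is defined above but never attached to a layer, so the L2 weight decay has no effect on this model. A hedged sketch of how it could be applied through the standard kernel_regularizer argument (get_regularized_network is just an illustrative name, not used elsewhere in this notebook):


In [ ]:
# Hypothetical variant with the L2 regularizer actually applied to the weights.
def get_regularized_network(input_data, layer_count, layer_dim):
    net = Conv2D(32, 3, activation='relu', kernel_regularizer=w_reg)(input_data)
    net = MaxPooling2D(2, 2)(net)
    net = Conv2D(32, 3, activation='relu', kernel_regularizer=w_reg)(net)
    net = Flatten()(net)
    for i in range(layer_count):
        net = Dense(layer_dim, activation='relu', kernel_regularizer=w_reg)(net)
        net = Dropout(rate=0.33)(net)
    net = Dense(10, name='out', activation='softmax', kernel_regularizer=w_reg)(net)
    return net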

Build and compile model

Create the computation graph of the network and compile a 'model' for optimization, including the loss function and optimizer.


In [12]:
from keras import optimizers
from keras.models import Model
from keras import losses
from keras import metrics

input_data = Input(shape=(trnData.shape[1:]), name='data')
net = get_simple_FC_network(input_data, 2, 1024)
model = Model(inputs=[input_data], outputs=[net])

print('Model')
model.summary()

model.compile(
    loss=losses.categorical_crossentropy, 
    optimizer=optimizers.Adam(lr=0.001), 
    metrics=[metrics.categorical_accuracy])


Model
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
data (InputLayer)            (None, 32, 32, 3)         0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 30, 30, 32)        896       
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 15, 15, 32)        0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 13, 13, 32)        9248      
_________________________________________________________________
flatten_2 (Flatten)          (None, 5408)              0         
_________________________________________________________________
dense_3 (Dense)              (None, 1024)              5538816   
_________________________________________________________________
dropout_3 (Dropout)          (None, 1024)              0         
_________________________________________________________________
dense_4 (Dense)              (None, 1024)              1049600   
_________________________________________________________________
dropout_4 (Dropout)          (None, 1024)              0         
_________________________________________________________________
out (Dense)                  (None, 10)                10250     
=================================================================
Total params: 6,608,810
Trainable params: 6,608,810
Non-trainable params: 0
_________________________________________________________________
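
The parameter counts in the summary can be verified by hand: each layer has (kernel size x input channels + 1 bias) x output units weights. A small sanity check using the numbers from the summary above:


In [ ]:
# Sanity check of the parameter counts reported by model.summary().
conv1 = (3 * 3 * 3 + 1) * 32        # 896
conv2 = (3 * 3 * 32 + 1) * 32       # 9248
dense1 = (13 * 13 * 32 + 1) * 1024  # 5538816 (flattened 13x13x32 feature map)
dense2 = (1024 + 1) * 1024          # 1049600
out = (1024 + 1) * 10               # 10250
print(conv1 + conv2 + dense1 + dense2 + out)  # 6608810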

Define TensorBoard callback

TensorBoard can store network statistics (loss, accuracy, weight histograms, activation histograms, ...) and display them through a web interface. To view the statistics, run 'tensorboard --logdir=path/to/log-directory' and go to localhost:6006.


In [13]:
import keras
tbCallBack = keras.callbacks.TensorBoard(
    log_dir='./Graph', 
    histogram_freq=1, 
    write_graph=True, write_images=True)
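
Other commonly used callbacks can be passed to fit() in the same list, for example early stopping and checkpointing. A minimal sketch (the file name 'best_model.h5' is just an example path):


In [ ]:
# Optional extras: stop when the validation loss stops improving, and keep
# the best weights seen so far on disk.
earlyStop = keras.callbacks.EarlyStopping(monitor='val_loss', patience=5)
checkpoint = keras.callbacks.ModelCheckpoint(
    'best_model.h5', monitor='val_loss', save_best_only=True)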

In [15]:
model.fit(
    x=trnData, y=trnLabels,
    batch_size=64, epochs=100, verbose=1, 
    validation_data=[tstData, tstLabels], shuffle=True)#, callbacks=[tbCallBack])


Train on 100000 samples, validate on 20000 samples
Epoch 1/100
100000/100000 [==============================] - 13s - loss: 0.8279 - categorical_accuracy: 0.7102 - val_loss: 0.8307 - val_categorical_accuracy: 0.7142
Epoch 2/100
100000/100000 [==============================] - 13s - loss: 0.6470 - categorical_accuracy: 0.7718 - val_loss: 0.7816 - val_categorical_accuracy: 0.7345
Epoch 3/100
100000/100000 [==============================] - 13s - loss: 0.4969 - categorical_accuracy: 0.8251 - val_loss: 0.7896 - val_categorical_accuracy: 0.7346
Epoch 4/100
100000/100000 [==============================] - 13s - loss: 0.3758 - categorical_accuracy: 0.8677 - val_loss: 0.8352 - val_categorical_accuracy: 0.7370
Epoch 5/100
100000/100000 [==============================] - 13s - loss: 0.2901 - categorical_accuracy: 0.8987 - val_loss: 0.9497 - val_categorical_accuracy: 0.7345
Epoch 6/100
100000/100000 [==============================] - 13s - loss: 0.2389 - categorical_accuracy: 0.9167 - val_loss: 0.9512 - val_categorical_accuracy: 0.7388
Epoch 7/100
100000/100000 [==============================] - 13s - loss: 0.1980 - categorical_accuracy: 0.9318 - val_loss: 1.0086 - val_categorical_accuracy: 0.7351
Epoch 8/100
100000/100000 [==============================] - 13s - loss: 0.1746 - categorical_accuracy: 0.9411 - val_loss: 1.0625 - val_categorical_accuracy: 0.7337
Epoch 9/100
100000/100000 [==============================] - 13s - loss: 0.1604 - categorical_accuracy: 0.9454 - val_loss: 1.0809 - val_categorical_accuracy: 0.7329
Epoch 10/100
100000/100000 [==============================] - 13s - loss: 0.1444 - categorical_accuracy: 0.9512 - val_loss: 1.1065 - val_categorical_accuracy: 0.7381
Epoch 11/100
100000/100000 [==============================] - 13s - loss: 0.1366 - categorical_accuracy: 0.9552 - val_loss: 1.1673 - val_categorical_accuracy: 0.7355
Epoch 12/100
100000/100000 [==============================] - 13s - loss: 0.1270 - categorical_accuracy: 0.9583 - val_loss: 1.2424 - val_categorical_accuracy: 0.7239
Epoch 13/100
100000/100000 [==============================] - 13s - loss: 0.1210 - categorical_accuracy: 0.9600 - val_loss: 1.1847 - val_categorical_accuracy: 0.7355
Epoch 14/100
100000/100000 [==============================] - 13s - loss: 0.1128 - categorical_accuracy: 0.9630 - val_loss: 1.2643 - val_categorical_accuracy: 0.7304
Epoch 15/100
100000/100000 [==============================] - 13s - loss: 0.1085 - categorical_accuracy: 0.9645 - val_loss: 1.3267 - val_categorical_accuracy: 0.7273
Epoch 16/100
100000/100000 [==============================] - 13s - loss: 0.1075 - categorical_accuracy: 0.9656 - val_loss: 1.2130 - val_categorical_accuracy: 0.7363
Epoch 17/100
100000/100000 [==============================] - 13s - loss: 0.1047 - categorical_accuracy: 0.9660 - val_loss: 1.2835 - val_categorical_accuracy: 0.7340
Epoch 18/100
100000/100000 [==============================] - 13s - loss: 0.1002 - categorical_accuracy: 0.9680 - val_loss: 1.2763 - val_categorical_accuracy: 0.7369
Epoch 19/100
100000/100000 [==============================] - 13s - loss: 0.0958 - categorical_accuracy: 0.9695 - val_loss: 1.2606 - val_categorical_accuracy: 0.7317
Epoch 20/100
100000/100000 [==============================] - 13s - loss: 0.0940 - categorical_accuracy: 0.9709 - val_loss: 1.2681 - val_categorical_accuracy: 0.7342
Epoch 21/100
100000/100000 [==============================] - 13s - loss: 0.0922 - categorical_accuracy: 0.9711 - val_loss: 1.3760 - val_categorical_accuracy: 0.7288
Epoch 22/100
100000/100000 [==============================] - 13s - loss: 0.0914 - categorical_accuracy: 0.9717 - val_loss: 1.2662 - val_categorical_accuracy: 0.7298
Epoch 23/100
100000/100000 [==============================] - 13s - loss: 0.0859 - categorical_accuracy: 0.9737 - val_loss: 1.3498 - val_categorical_accuracy: 0.7307
Epoch 24/100
100000/100000 [==============================] - 13s - loss: 0.0891 - categorical_accuracy: 0.9728 - val_loss: 1.3411 - val_categorical_accuracy: 0.7348
Epoch 25/100
100000/100000 [==============================] - 13s - loss: 0.0820 - categorical_accuracy: 0.9745 - val_loss: 1.3986 - val_categorical_accuracy: 0.7335
Epoch 26/100
100000/100000 [==============================] - 13s - loss: 0.0788 - categorical_accuracy: 0.9757 - val_loss: 1.3644 - val_categorical_accuracy: 0.7322
Epoch 27/100
100000/100000 [==============================] - 13s - loss: 0.0809 - categorical_accuracy: 0.9753 - val_loss: 1.4610 - val_categorical_accuracy: 0.7314
Epoch 28/100
100000/100000 [==============================] - 13s - loss: 0.0830 - categorical_accuracy: 0.9752 - val_loss: 1.3935 - val_categorical_accuracy: 0.7329
Epoch 29/100
100000/100000 [==============================] - 13s - loss: 0.0788 - categorical_accuracy: 0.9760 - val_loss: 1.3824 - val_categorical_accuracy: 0.7288
Epoch 30/100
100000/100000 [==============================] - 13s - loss: 0.0771 - categorical_accuracy: 0.9770 - val_loss: 1.4315 - val_categorical_accuracy: 0.7239
Epoch 31/100
100000/100000 [==============================] - 14s - loss: 0.0766 - categorical_accuracy: 0.9768 - val_loss: 1.4658 - val_categorical_accuracy: 0.7343
Epoch 32/100
100000/100000 [==============================] - 13s - loss: 0.0789 - categorical_accuracy: 0.9769 - val_loss: 1.4689 - val_categorical_accuracy: 0.7281
Epoch 33/100
100000/100000 [==============================] - 13s - loss: 0.0753 - categorical_accuracy: 0.9781 - val_loss: 1.4493 - val_categorical_accuracy: 0.7231
Epoch 34/100
100000/100000 [==============================] - 13s - loss: 0.0718 - categorical_accuracy: 0.9794 - val_loss: 1.4753 - val_categorical_accuracy: 0.7264
Epoch 35/100
100000/100000 [==============================] - 13s - loss: 0.0768 - categorical_accuracy: 0.9778 - val_loss: 1.4453 - val_categorical_accuracy: 0.7310
Epoch 36/100
100000/100000 [==============================] - 13s - loss: 0.0706 - categorical_accuracy: 0.9794 - val_loss: 1.4212 - val_categorical_accuracy: 0.7304
Epoch 37/100
100000/100000 [==============================] - 13s - loss: 0.0714 - categorical_accuracy: 0.9795 - val_loss: 1.4976 - val_categorical_accuracy: 0.7356
Epoch 38/100
100000/100000 [==============================] - 13s - loss: 0.0705 - categorical_accuracy: 0.9798 - val_loss: 1.5222 - val_categorical_accuracy: 0.7280
Epoch 39/100
100000/100000 [==============================] - 14s - loss: 0.0723 - categorical_accuracy: 0.9788 - val_loss: 1.4834 - val_categorical_accuracy: 0.7319
Epoch 40/100
100000/100000 [==============================] - 13s - loss: 0.0699 - categorical_accuracy: 0.9803 - val_loss: 1.4641 - val_categorical_accuracy: 0.7260
Epoch 41/100
100000/100000 [==============================] - 13s - loss: 0.0705 - categorical_accuracy: 0.9801 - val_loss: 1.5154 - val_categorical_accuracy: 0.7336
Epoch 42/100
100000/100000 [==============================] - 13s - loss: 0.0678 - categorical_accuracy: 0.9809 - val_loss: 1.4617 - val_categorical_accuracy: 0.7272
Epoch 43/100
100000/100000 [==============================] - 13s - loss: 0.0671 - categorical_accuracy: 0.9810 - val_loss: 1.5793 - val_categorical_accuracy: 0.7281
Epoch 44/100
100000/100000 [==============================] - 13s - loss: 0.0694 - categorical_accuracy: 0.9810 - val_loss: 1.5207 - val_categorical_accuracy: 0.7317
Epoch 45/100
100000/100000 [==============================] - 13s - loss: 0.0650 - categorical_accuracy: 0.9819 - val_loss: 1.6564 - val_categorical_accuracy: 0.7269
Epoch 46/100
100000/100000 [==============================] - 13s - loss: 0.0702 - categorical_accuracy: 0.9807 - val_loss: 1.5863 - val_categorical_accuracy: 0.7287
Epoch 47/100
100000/100000 [==============================] - 13s - loss: 0.0687 - categorical_accuracy: 0.9814 - val_loss: 1.5430 - val_categorical_accuracy: 0.7279
Epoch 48/100
100000/100000 [==============================] - 13s - loss: 0.0650 - categorical_accuracy: 0.9826 - val_loss: 1.5269 - val_categorical_accuracy: 0.7302
Epoch 49/100
100000/100000 [==============================] - 13s - loss: 0.0671 - categorical_accuracy: 0.9812 - val_loss: 1.5618 - val_categorical_accuracy: 0.7312
Epoch 50/100
100000/100000 [==============================] - 13s - loss: 0.0613 - categorical_accuracy: 0.9832 - val_loss: 1.5923 - val_categorical_accuracy: 0.7278
Epoch 51/100
100000/100000 [==============================] - 13s - loss: 0.0667 - categorical_accuracy: 0.9820 - val_loss: 1.5419 - val_categorical_accuracy: 0.7278
Epoch 52/100
100000/100000 [==============================] - 13s - loss: 0.0654 - categorical_accuracy: 0.9827 - val_loss: 1.6332 - val_categorical_accuracy: 0.7146
Epoch 53/100
100000/100000 [==============================] - 13s - loss: 0.0628 - categorical_accuracy: 0.9829 - val_loss: 1.5928 - val_categorical_accuracy: 0.7288
Epoch 54/100
100000/100000 [==============================] - 13s - loss: 0.0653 - categorical_accuracy: 0.9827 - val_loss: 1.5374 - val_categorical_accuracy: 0.7310
Epoch 55/100
100000/100000 [==============================] - 13s - loss: 0.0618 - categorical_accuracy: 0.9836 - val_loss: 1.6791 - val_categorical_accuracy: 0.7239
Epoch 56/100
100000/100000 [==============================] - 13s - loss: 0.0669 - categorical_accuracy: 0.9827 - val_loss: 1.5604 - val_categorical_accuracy: 0.7263
Epoch 57/100
100000/100000 [==============================] - 13s - loss: 0.0569 - categorical_accuracy: 0.9850 - val_loss: 1.5749 - val_categorical_accuracy: 0.7288
Epoch 58/100
100000/100000 [==============================] - 13s - loss: 0.0672 - categorical_accuracy: 0.9823 - val_loss: 1.6425 - val_categorical_accuracy: 0.7245
Epoch 59/100
100000/100000 [==============================] - 13s - loss: 0.0612 - categorical_accuracy: 0.9836 - val_loss: 1.5744 - val_categorical_accuracy: 0.7319
Epoch 60/100
100000/100000 [==============================] - 13s - loss: 0.0620 - categorical_accuracy: 0.9842 - val_loss: 1.7158 - val_categorical_accuracy: 0.7255
Epoch 61/100
100000/100000 [==============================] - 13s - loss: 0.0616 - categorical_accuracy: 0.9840 - val_loss: 1.5676 - val_categorical_accuracy: 0.7251
Epoch 62/100
 50496/100000 [==============>...............] - ETA: 6s - loss: 0.0668 - categorical_accuracy: 0.9828
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
(full stack trace omitted: training was interrupted manually during epoch 62)

Predict and evaluate


In [ ]:
classProb = model.predict(x=tstData[0:2])
print('Class probabilities:', classProb, '\n')
loss, acc = model.evaluate(x=tstData, y=tstLabels, batch_size=1024)
print()
print('loss', loss)
print('acc', acc)

Compute test accuracy by hand


In [ ]:
classProb = model.predict(x=tstData)
print(classProb.shape)

correctProb = (classProb * tstLabels).sum(axis=1)
wrongProb = (classProb * (1-tstLabels)).max(axis=1)
print(correctProb.shape, wrongProb.shape)

accuracy = (correctProb > wrongProb).mean()
print('Accuracy: ',  accuracy)
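
The comparison above (probability of the correct class versus the highest wrong-class probability) is equivalent, up to ties, to the usual argmax formulation:


In [ ]:
# Equivalent accuracy computed via argmax over the class probabilities.
predClass = classProb.argmax(axis=1)
trueClass = tstLabels.argmax(axis=1)
print('Accuracy (argmax): ', (predClass == trueClass).mean())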

In [ ]: