In [1]:
'''Trains a Neural network with batch normalization 
on the MNIST dataset.
Uses tensorboard callback to write cross entropy and 
accuracy at epoch which can be visualized using tensorboard
'''


from __future__ import print_function
import numpy as np
np.random.seed(20)


import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Activation,BatchNormalization
from keras.utils import np_utils


# Model parameters
batch_size = 256
nb_classes = 10
nb_epoch = 20


# Load MNIST data and shuffle & split into train & test
(X_train, y_train), (X_test, y_test) = mnist.load_data()

X_train = X_train.reshape(X_train.shape[0], 784)
X_test = X_test.reshape(X_test.shape[0], 784)

X_train = X_train.astype('float32')
X_test = X_test.astype('float32')

X_train /= 255
X_test /= 255

# Convert to one hot encoding of classes
Y_train = np_utils.to_categorical(y_train, nb_classes)
Y_test = np_utils.to_categorical(y_test, nb_classes)


# Neural network with no dropouts only Batch normalization
model = Sequential()
model.add(Dense(512, input_shape=(784,)))
model.add(BatchNormalization())
model.add(Activation('relu'))

model.add(Dense(256))
model.add(BatchNormalization())
model.add(Activation('relu'))

model.add(Dense(10))
model.add(Activation('softmax'))

print(model.summary())


model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

# Callback to tensorboard to write logs 
# Visualize model and its performance using tensorboard --logdir=_mnist
# write_graph=True if you want to visualize model
tensorboard = keras.callbacks.TensorBoard(log_dir="/tmp/mnist",write_graph=False,write_images=True)

model.fit(X_train, Y_train,
                    batch_size=batch_size, nb_epoch=nb_epoch,
                    verbose=1, validation_data= (X_test,Y_test),callbacks=[tensorboard])


test_performance = model.evaluate(X_test, Y_test, verbose=0)
print('Test Categorical crossentropy:', test_performance[0])
print('Test accuracy:', test_performance[1])


Using TensorFlow backend.
/home/ubuntu/anaconda2/lib/python2.7/site-packages/keras/models.py:826: UserWarning: The `nb_epoch` argument in `fit` has been renamed `epochs`.
  warnings.warn('The `nb_epoch` argument in `fit` '
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_1 (Dense)              (None, 512)               401920    
_________________________________________________________________
batch_normalization_1 (Batch (None, 512)               2048      
_________________________________________________________________
activation_1 (Activation)    (None, 512)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 256)               131328    
_________________________________________________________________
batch_normalization_2 (Batch (None, 256)               1024      
_________________________________________________________________
activation_2 (Activation)    (None, 256)               0         
_________________________________________________________________
dense_3 (Dense)              (None, 10)                2570      
_________________________________________________________________
activation_3 (Activation)    (None, 10)                0         
=================================================================
Total params: 538,890.0
Trainable params: 537,354.0
Non-trainable params: 1,536.0
_________________________________________________________________
None
Train on 60000 samples, validate on 10000 samples
Epoch 1/20
60000/60000 [==============================] - 13s - loss: 0.2001 - acc: 0.9420 - val_loss: 0.1283 - val_acc: 0.9665
Epoch 2/20
60000/60000 [==============================] - 13s - loss: 0.0684 - acc: 0.9794 - val_loss: 0.0817 - val_acc: 0.9751
Epoch 3/20
60000/60000 [==============================] - 12s - loss: 0.0382 - acc: 0.9892 - val_loss: 0.0759 - val_acc: 0.9756
Epoch 4/20
60000/60000 [==============================] - 13s - loss: 0.0258 - acc: 0.9925 - val_loss: 0.0736 - val_acc: 0.9779
Epoch 5/20
60000/60000 [==============================] - 12s - loss: 0.0158 - acc: 0.9959 - val_loss: 0.0745 - val_acc: 0.9779
Epoch 6/20
60000/60000 [==============================] - 13s - loss: 0.0130 - acc: 0.9964 - val_loss: 0.0689 - val_acc: 0.9815
Epoch 7/20
60000/60000 [==============================] - 13s - loss: 0.0110 - acc: 0.9968 - val_loss: 0.0644 - val_acc: 0.9814
Epoch 8/20
60000/60000 [==============================] - 13s - loss: 0.0085 - acc: 0.9978 - val_loss: 0.0826 - val_acc: 0.9751
Epoch 9/20
60000/60000 [==============================] - 13s - loss: 0.0085 - acc: 0.9975 - val_loss: 0.0810 - val_acc: 0.9767
Epoch 10/20
60000/60000 [==============================] - 13s - loss: 0.0108 - acc: 0.9966 - val_loss: 0.1106 - val_acc: 0.9709
Epoch 11/20
60000/60000 [==============================] - 13s - loss: 0.0110 - acc: 0.9967 - val_loss: 0.0710 - val_acc: 0.9793
Epoch 12/20
60000/60000 [==============================] - 13s - loss: 0.0063 - acc: 0.9983 - val_loss: 0.0667 - val_acc: 0.9823
Epoch 13/20
60000/60000 [==============================] - 13s - loss: 0.0050 - acc: 0.9986 - val_loss: 0.0848 - val_acc: 0.9785
Epoch 14/20
60000/60000 [==============================] - 13s - loss: 0.0046 - acc: 0.9989 - val_loss: 0.0980 - val_acc: 0.9761
Epoch 15/20
60000/60000 [==============================] - 13s - loss: 0.0050 - acc: 0.9987 - val_loss: 0.0810 - val_acc: 0.9798
Epoch 16/20
60000/60000 [==============================] - 13s - loss: 0.0059 - acc: 0.9983 - val_loss: 0.0843 - val_acc: 0.9791
Epoch 17/20
60000/60000 [==============================] - 13s - loss: 0.0093 - acc: 0.9971 - val_loss: 0.0982 - val_acc: 0.9766
Epoch 18/20
60000/60000 [==============================] - 13s - loss: 0.0074 - acc: 0.9976 - val_loss: 0.0817 - val_acc: 0.9800
Epoch 19/20
60000/60000 [==============================] - 13s - loss: 0.0055 - acc: 0.9983 - val_loss: 0.0780 - val_acc: 0.9817
Epoch 20/20
60000/60000 [==============================] - 17s - loss: 0.0032 - acc: 0.9991 - val_loss: 0.0745 - val_acc: 0.9819
Test Categorical crossentropy: 0.0745038796153
Test accuracy: 0.9819

In [ ]: