Tutorial on self-normalizing networks on the MNIST data set: convolutional neural networks

Author: Guenter Klambauer, 2017

Tested under Python 3.5 and TensorFlow 1.1

Derived from: Aymeric Damien


In [1]:
from __future__ import absolute_import, division, print_function

import tensorflow as tf
import numpy as np
import numbers
from tensorflow.contrib import layers
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.layers import utils

from sklearn.preprocessing import StandardScaler
from scipy.special import erf, erfc

# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)


Extracting /tmp/data/train-images-idx3-ubyte.gz
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz

(1) Definition of scaled exponential linear units (SELUs)


In [2]:
def selu(x):
    with ops.name_scope('selu') as scope:
        # Fixed-point constants from Klambauer et al. (2017), "Self-Normalizing Neural Networks"
        alpha = 1.6732632423543772848170429916717
        scale = 1.0507009873554804934193349852946
        return scale * tf.where(x >= 0.0, x, alpha * tf.nn.elu(x))
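As a quick sanity check (not part of the original notebook), the two constants are chosen so that for standard-normal inputs the SELU output again has approximately zero mean and unit variance. A minimal NumPy sketch of this property:

import numpy as np

alpha = 1.6732632423543772848170429916717
scale = 1.0507009873554804934193349852946

z = np.random.randn(1000000)                       # inputs at the fixed point (0, 1)
out = scale * np.where(z >= 0.0, z, alpha * (np.exp(z) - 1.0))
print(out.mean(), out.var())                       # both should be close to 0 and 1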

(2) Definition of dropout variant for SNNs


In [3]:
def dropout_selu(x, rate, alpha=-1.7580993408473766, fixedPointMean=0.0, fixedPointVar=1.0,
                 noise_shape=None, seed=None, name=None, training=False):
    """Dropout to a value with rescaling ("alpha dropout" for SELU networks)."""

    def dropout_selu_impl(x, rate, alpha, noise_shape, seed, name):
        keep_prob = 1.0 - rate
        x = ops.convert_to_tensor(x, name="x")
        if isinstance(keep_prob, numbers.Real) and not 0 < keep_prob <= 1:
            raise ValueError("keep_prob must be a scalar tensor or a float in the "
                             "range (0, 1], got %g" % keep_prob)
        keep_prob = ops.convert_to_tensor(keep_prob, dtype=x.dtype, name="keep_prob")
        keep_prob.get_shape().assert_is_compatible_with(tensor_shape.scalar())

        alpha = ops.convert_to_tensor(alpha, dtype=x.dtype, name="alpha")
        alpha.get_shape().assert_is_compatible_with(tensor_shape.scalar())

        if tensor_util.constant_value(keep_prob) == 1:
            return x

        noise_shape = noise_shape if noise_shape is not None else array_ops.shape(x)
        # Bernoulli mask: 1 with probability keep_prob, 0 otherwise
        random_tensor = keep_prob
        random_tensor += random_ops.random_uniform(noise_shape, seed=seed, dtype=x.dtype)
        binary_tensor = math_ops.floor(random_tensor)
        # Dropped units are set to alpha (the SELU saturation value), not to zero
        ret = x * binary_tensor + alpha * (1 - binary_tensor)

        # Affine correction a*ret + b restores the fixed-point mean and variance
        a = tf.sqrt(fixedPointVar / (keep_prob * ((1 - keep_prob) * tf.pow(alpha - fixedPointMean, 2) + fixedPointVar)))
        b = fixedPointMean - a * (keep_prob * fixedPointMean + (1 - keep_prob) * alpha)
        ret = a * ret + b
        ret.set_shape(x.get_shape())
        return ret

    with ops.name_scope(name, "dropout", [x]) as name:
        return utils.smart_cond(training,
            lambda: dropout_selu_impl(x, rate, alpha, noise_shape, seed, name),
            lambda: array_ops.identity(x))
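The affine correction a*ret + b is what distinguishes this "alpha dropout" from standard dropout: dropped units are set to the SELU saturation value alpha' ≈ -1.758 instead of zero, and a and b are chosen so that mean and variance stay at the fixed point (fixedPointMean, fixedPointVar). A small NumPy simulation of the same formulas (illustrative only, using the default rate of 0.05):

import numpy as np

rate, alpha_p = 0.05, -1.7580993408473766
q = 1.0 - rate                                     # keep probability
x = np.random.randn(1000000)                       # activations at the fixed point (0, 1)
mask = np.random.rand(x.size) < q                  # keep with probability q
dropped = np.where(mask, x, alpha_p)               # dropped units go to alpha', not 0
a = np.sqrt(1.0 / (q * ((1.0 - q) * alpha_p ** 2 + 1.0)))
b = -a * (1.0 - q) * alpha_p
out = a * dropped + b
print(out.mean(), out.var())                       # both should stay close to 0 and 1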

(3) Scale input to zero mean and unit variance


In [4]:
scaler = StandardScaler().fit(mnist.train.images)
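As an illustrative check (not in the original notebook): after fitting, the transformed training images have roughly zero mean per pixel, non-constant pixel columns get unit variance, and all-zero border pixels are simply left at zero by StandardScaler.

Xs = scaler.transform(mnist.train.images)
print(Xs.mean(), Xs.std(axis=0).max())             # overall mean close to 0; scaled columns have std 1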

In [5]:
# Parameters
learning_rate = 0.025
training_iters = 50
batch_size = 128
display_step = 1

# Network Parameters
n_input = 784 # MNIST data input (img shape: 28*28)
n_classes = 10 # MNIST total classes (0-9 digits)
keep_prob_ReLU = 0.5 # Dropout, probability to keep units
dropout_prob_SNN = 0.05 # Dropout, probability to dropout units

# tf Graph input
x = tf.placeholder(tf.float32, [None, n_input])
y = tf.placeholder(tf.float32, [None, n_classes])
keep_prob = tf.placeholder(tf.float32) #dropout (keep probability for ReLU)
dropout_prob =  tf.placeholder(tf.float32) #dropout (dropout probability for SNN)
is_training = tf.placeholder(tf.bool)

In [6]:
# Create some wrappers for simplicity
def conv2d(x, W, b, strides=1):
    # Conv2D wrapper, with bias and relu activation
    x = tf.nn.conv2d(x, W, strides=[1, strides, strides, 1], padding='SAME')
    x = tf.nn.bias_add(x, b)
    return tf.nn.relu(x)

def conv2d_SNN(x, W, b, strides=1):
    # Conv2D wrapper, with bias and SELU activation
    x = tf.nn.conv2d(x, W, strides=[1, strides, strides, 1], padding='SAME')
    x = tf.nn.bias_add(x, b)
    return selu(x)

def maxpool2d(x, k=2):
    # MaxPool2D wrapper
    return tf.nn.max_pool(x, ksize=[1, k, k, 1], strides=[1, k, k, 1],
                          padding='SAME')

In [7]:
# Create model
def conv_net_ReLU(x, weights, biases, keep_prob):
    # Reshape input picture
    x = tf.reshape(x, shape=[-1, 28, 28, 1])

    # Convolution Layer
    conv1 = conv2d(x, weights['wc1'], biases['bc1'])
    # Max Pooling (down-sampling)
    conv1 = maxpool2d(conv1, k=2)

    # Convolution Layer
    conv2 = conv2d(conv1, weights['wc2'], biases['bc2'])
    # Max Pooling (down-sampling)
    conv2 = maxpool2d(conv2, k=2)

    # Fully connected layer
    # Reshape conv2 output to fit fully connected layer input
    fc1 = tf.reshape(conv2, [-1, weights['wd1'].get_shape().as_list()[0]])
    fc1 = tf.add(tf.matmul(fc1, weights['wd1']), biases['bd1'])
    fc1 = tf.nn.relu(fc1)
    
    # Apply Dropout
    fc1 = tf.nn.dropout(fc1, keep_prob)

    # Output, class prediction
    out = tf.add(tf.matmul(fc1, weights['out']), biases['out'])
    return out

In [8]:
# Create model
def conv_net_SNN(x, weights, biases, dropout_prob, is_training):
    # Reshape input picture
    x = tf.reshape(x, shape=[-1, 28, 28, 1])

    # Convolution Layer
    conv1 = conv2d_SNN(x, weights['wc1'], biases['bc1'])
    # Max Pooling (down-sampling)
    conv1 = maxpool2d(conv1, k=2)

    # Convolution Layer
    conv2 = conv2d_SNN(conv1, weights['wc2'], biases['bc2'])
    # Max Pooling (down-sampling)
    conv2 = maxpool2d(conv2, k=2)

    # Fully connected layer
    # Reshape conv2 output to fit fully connected layer input
    fc1 = tf.reshape(conv2, [-1, weights['wd1'].get_shape().as_list()[0]])
    fc1 = tf.add(tf.matmul(fc1, weights['wd1']), biases['bd1'])
    fc1 = selu(fc1)
    
    # Apply Dropout
    fc1 = dropout_selu(fc1, dropout_prob, training=is_training)

    # Output, class prediction
    out = tf.add(tf.matmul(fc1, weights['out']), biases['out'])
    return out

In [9]:
# ReLU: Store layers' weights & biases
## Improved with MSRA (He) initialization, stddev = sqrt(2/fan-in)

weights = {
    # 5x5 conv, 1 input, 32 outputs
    'wc1': tf.Variable(tf.random_normal([5, 5, 1, 32],stddev=np.sqrt(2/25)) ),
    # 5x5 conv, 32 inputs, 64 outputs
    'wc2': tf.Variable(tf.random_normal([5, 5, 32, 64],stddev=np.sqrt(2/(25*32)))),
    # fully connected, 7*7*64 inputs, 1024 outputs
    'wd1': tf.Variable(tf.random_normal([7*7*64, 1024],stddev=np.sqrt(2/(7*7*64)))),
    # 1024 inputs, 10 outputs (class prediction)
    'out': tf.Variable(tf.random_normal([1024, n_classes],stddev=np.sqrt(2/(1024))))
}

biases = {
    'bc1': tf.Variable(tf.random_normal([32],stddev=0)),
    'bc2': tf.Variable(tf.random_normal([64],stddev=0)),
    'bd1': tf.Variable(tf.random_normal([1024],stddev=0)),
    'out': tf.Variable(tf.random_normal([n_classes],stddev=0))
}

(4) Initialization with STDDEV of sqrt(1/n)


In [10]:
# SNN: Store layers weight & bias
weights2 = {
    # 5x5 conv, 1 input, 32 outputs
    'wc1': tf.Variable(tf.random_normal([5, 5, 1, 32],stddev=np.sqrt(1/25)) ),
    # 5x5 conv, 32 inputs, 64 outputs
    'wc2': tf.Variable(tf.random_normal([5, 5, 32, 64],stddev=np.sqrt(1/(25*32)))),
    # fully connected, 7*7*64 inputs, 1024 outputs
    'wd1': tf.Variable(tf.random_normal([7*7*64, 1024],stddev=np.sqrt(1/(7*7*64)))),
    # 1024 inputs, 10 outputs (class prediction)
    'out': tf.Variable(tf.random_normal([1024, n_classes],stddev=np.sqrt(1/(1024))))
}

biases2 = {
    'bc1': tf.Variable(tf.random_normal([32],stddev=0)),
    'bc2': tf.Variable(tf.random_normal([64],stddev=0)),
    'bd1': tf.Variable(tf.random_normal([1024],stddev=0)),
    'out': tf.Variable(tf.random_normal([n_classes],stddev=0))
}
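To see why sqrt(1/n) is the appropriate scale for SELU layers, here is a rough NumPy sketch (an illustration under the assumption of a plain dense layer with inputs already at the fixed point): with weights drawn from N(0, 1/n_in), the pre-activations have roughly zero mean and unit variance, so the SELU output stays close to the fixed point as well.

alpha_np = 1.6732632423543772848170429916717
scale_np = 1.0507009873554804934193349852946

n_in, n_out = 7*7*64, 1024
W = np.random.randn(n_in, n_out) * np.sqrt(1.0 / n_in)
h = np.random.randn(512, n_in)                     # inputs at the fixed point (0, 1)
pre = h.dot(W)                                     # pre-activations: mean ~0, variance ~1
act = scale_np * np.where(pre >= 0.0, pre, alpha_np * (np.exp(pre) - 1.0))
print(act.mean(), act.var())                       # should stay close to 0 and 1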

In [11]:
# Construct model
pred_ReLU = conv_net_ReLU(x, weights, biases, keep_prob)
pred_SNN = conv_net_SNN(x, weights2, biases2, dropout_prob,is_training)

# Define loss and optimizer
cost_ReLU = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred_ReLU, labels=y))
cost_SNN = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred_SNN, labels=y))

optimizer_ReLU = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(cost_ReLU)
optimizer_SNN = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(cost_SNN)

# Evaluate ReLU model
correct_pred_ReLU = tf.equal(tf.argmax(pred_ReLU, 1), tf.argmax(y, 1))
accuracy_ReLU = tf.reduce_mean(tf.cast(correct_pred_ReLU, tf.float32))

# Evaluate SNN model
correct_pred_SNN = tf.equal(tf.argmax(pred_SNN, 1), tf.argmax(y, 1))
accuracy_SNN = tf.reduce_mean(tf.cast(correct_pred_SNN, tf.float32))


# Initializing the variables
init = tf.global_variables_initializer()

In [12]:
training_loss_protocol_ReLU = []
training_loss_protocol_SNN = []

In [13]:
# Launch the graph
gpu_options = tf.GPUOptions(allow_growth=True)
with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
    sess.run(init)
    step = 0
    # Keep training until reach max iterations
    while step < training_iters:
        batch_x, batch_y = mnist.train.next_batch(batch_size)
        batch_x_norm = scaler.transform(batch_x)
        # Run optimization op (backprop)
        sess.run(optimizer_ReLU, feed_dict={x: batch_x, y: batch_y,
                                       keep_prob: keep_prob_ReLU})
        sess.run(optimizer_SNN, feed_dict={x: batch_x_norm, y: batch_y,
                                       dropout_prob: dropout_prob_SNN,is_training:True})
        
        
        if step % display_step == 0:
            #batch_x, batch_y = mnist.test.next_batch(batch_size)
            #batch_x_norm = scaler.transform(batch_x)
            # Calculate batch loss and accuracy
            loss_ReLU, acc_ReLU = sess.run([cost_ReLU, accuracy_ReLU], feed_dict={x: batch_x,
                                                              y: batch_y,
                                                              keep_prob: 1.0})
            training_loss_protocol_ReLU.append(loss_ReLU)
            
            loss_SNN, acc_SNN = sess.run([cost_SNN, accuracy_SNN], feed_dict={x: batch_x_norm,
                                                              y: batch_y,
                                                              dropout_prob: 0.0, is_training:False})
            training_loss_protocol_SNN.append(loss_SNN)
            
            print( "RELU: Nbr of updates: " + str(step+1) + ", Minibatch Loss= " + \
                  "{:.6f}".format(loss_ReLU) + ", Training Accuracy= " + \
                  "{:.5f}".format(acc_ReLU))
            
            print( "SNN: Nbr of updates: " + str(step+1) + ", Minibatch Loss= " + \
                  "{:.6f}".format(loss_SNN) + ", Training Accuracy= " + \
                  "{:.5f}".format(acc_SNN))
        step += 1
    print("Optimization Finished!\n")

    # Calculate accuracy for 512 MNIST test images
    print("ReLU: Testing Accuracy:", sess.run(accuracy_ReLU, feed_dict={x: mnist.test.images[:512],
                                      y: mnist.test.labels[:512],
                                      keep_prob: 1.0}))
    print("SNN: Testing Accuracy:", sess.run(accuracy_SNN, feed_dict={x: scaler.transform(mnist.test.images[:512]),
                                      y: mnist.test.labels[:512],
                                      dropout_prob: 0.0, is_training:False}))


RELU: Nbr of updates: 1, Minibatch Loss= 2.008754, Training Accuracy= 0.35156
SNN: Nbr of updates: 1, Minibatch Loss= 2.451197, Training Accuracy= 0.39844
RELU: Nbr of updates: 2, Minibatch Loss= 1.732718, Training Accuracy= 0.37500
SNN: Nbr of updates: 2, Minibatch Loss= 2.160526, Training Accuracy= 0.65625
RELU: Nbr of updates: 3, Minibatch Loss= 1.752255, Training Accuracy= 0.39844
SNN: Nbr of updates: 3, Minibatch Loss= 1.601260, Training Accuracy= 0.53125
RELU: Nbr of updates: 4, Minibatch Loss= 1.605729, Training Accuracy= 0.53906
SNN: Nbr of updates: 4, Minibatch Loss= 0.978341, Training Accuracy= 0.70312
RELU: Nbr of updates: 5, Minibatch Loss= 1.555425, Training Accuracy= 0.51562
SNN: Nbr of updates: 5, Minibatch Loss= 0.645711, Training Accuracy= 0.82031
RELU: Nbr of updates: 6, Minibatch Loss= 1.313229, Training Accuracy= 0.67969
SNN: Nbr of updates: 6, Minibatch Loss= 0.401476, Training Accuracy= 0.90625
RELU: Nbr of updates: 7, Minibatch Loss= 1.203895, Training Accuracy= 0.77344
SNN: Nbr of updates: 7, Minibatch Loss= 0.453578, Training Accuracy= 0.92188
RELU: Nbr of updates: 8, Minibatch Loss= 1.089910, Training Accuracy= 0.85938
SNN: Nbr of updates: 8, Minibatch Loss= 0.297481, Training Accuracy= 0.95312
RELU: Nbr of updates: 9, Minibatch Loss= 1.017870, Training Accuracy= 0.79688
SNN: Nbr of updates: 9, Minibatch Loss= 0.365949, Training Accuracy= 0.91406
RELU: Nbr of updates: 10, Minibatch Loss= 1.070305, Training Accuracy= 0.76562
SNN: Nbr of updates: 10, Minibatch Loss= 0.405422, Training Accuracy= 0.90625
RELU: Nbr of updates: 11, Minibatch Loss= 0.985618, Training Accuracy= 0.79688
SNN: Nbr of updates: 11, Minibatch Loss= 0.460914, Training Accuracy= 0.88281
RELU: Nbr of updates: 12, Minibatch Loss= 0.875668, Training Accuracy= 0.72656
SNN: Nbr of updates: 12, Minibatch Loss= 0.349492, Training Accuracy= 0.90625
RELU: Nbr of updates: 13, Minibatch Loss= 1.041480, Training Accuracy= 0.76562
SNN: Nbr of updates: 13, Minibatch Loss= 0.436600, Training Accuracy= 0.89062
RELU: Nbr of updates: 14, Minibatch Loss= 0.836483, Training Accuracy= 0.83594
SNN: Nbr of updates: 14, Minibatch Loss= 0.356240, Training Accuracy= 0.92188
RELU: Nbr of updates: 15, Minibatch Loss= 0.824995, Training Accuracy= 0.81250
SNN: Nbr of updates: 15, Minibatch Loss= 0.407508, Training Accuracy= 0.87500
RELU: Nbr of updates: 16, Minibatch Loss= 0.739613, Training Accuracy= 0.85156
SNN: Nbr of updates: 16, Minibatch Loss= 0.289174, Training Accuracy= 0.92969
RELU: Nbr of updates: 17, Minibatch Loss= 0.782138, Training Accuracy= 0.80469
SNN: Nbr of updates: 17, Minibatch Loss= 0.314916, Training Accuracy= 0.91406
RELU: Nbr of updates: 18, Minibatch Loss= 0.687675, Training Accuracy= 0.85156
SNN: Nbr of updates: 18, Minibatch Loss= 0.243602, Training Accuracy= 0.94531
RELU: Nbr of updates: 19, Minibatch Loss= 0.647239, Training Accuracy= 0.82812
SNN: Nbr of updates: 19, Minibatch Loss= 0.205704, Training Accuracy= 0.96094
RELU: Nbr of updates: 20, Minibatch Loss= 0.673955, Training Accuracy= 0.78906
SNN: Nbr of updates: 20, Minibatch Loss= 0.293074, Training Accuracy= 0.92188
RELU: Nbr of updates: 21, Minibatch Loss= 0.643871, Training Accuracy= 0.84375
SNN: Nbr of updates: 21, Minibatch Loss= 0.305403, Training Accuracy= 0.92969
RELU: Nbr of updates: 22, Minibatch Loss= 0.577555, Training Accuracy= 0.91406
SNN: Nbr of updates: 22, Minibatch Loss= 0.225528, Training Accuracy= 0.96875
RELU: Nbr of updates: 23, Minibatch Loss= 0.539012, Training Accuracy= 0.90625
SNN: Nbr of updates: 23, Minibatch Loss= 0.207042, Training Accuracy= 0.96094
RELU: Nbr of updates: 24, Minibatch Loss= 0.595193, Training Accuracy= 0.85938
SNN: Nbr of updates: 24, Minibatch Loss= 0.297265, Training Accuracy= 0.89844
RELU: Nbr of updates: 25, Minibatch Loss= 0.610190, Training Accuracy= 0.83594
SNN: Nbr of updates: 25, Minibatch Loss= 0.255643, Training Accuracy= 0.95312
RELU: Nbr of updates: 26, Minibatch Loss= 0.708689, Training Accuracy= 0.69531
SNN: Nbr of updates: 26, Minibatch Loss= 0.161673, Training Accuracy= 0.98438
RELU: Nbr of updates: 27, Minibatch Loss= 0.702952, Training Accuracy= 0.79688
SNN: Nbr of updates: 27, Minibatch Loss= 0.215801, Training Accuracy= 0.94531
RELU: Nbr of updates: 28, Minibatch Loss= 0.470672, Training Accuracy= 0.88281
SNN: Nbr of updates: 28, Minibatch Loss= 0.269345, Training Accuracy= 0.91406
RELU: Nbr of updates: 29, Minibatch Loss= 0.554051, Training Accuracy= 0.83594
SNN: Nbr of updates: 29, Minibatch Loss= 0.296727, Training Accuracy= 0.92188
RELU: Nbr of updates: 30, Minibatch Loss= 0.504638, Training Accuracy= 0.84375
SNN: Nbr of updates: 30, Minibatch Loss= 0.227030, Training Accuracy= 0.93750
RELU: Nbr of updates: 31, Minibatch Loss= 0.566984, Training Accuracy= 0.85938
SNN: Nbr of updates: 31, Minibatch Loss= 0.212100, Training Accuracy= 0.96875
RELU: Nbr of updates: 32, Minibatch Loss= 0.505076, Training Accuracy= 0.86719
SNN: Nbr of updates: 32, Minibatch Loss= 0.224962, Training Accuracy= 0.92188
RELU: Nbr of updates: 33, Minibatch Loss= 0.487980, Training Accuracy= 0.87500
SNN: Nbr of updates: 33, Minibatch Loss= 0.192593, Training Accuracy= 0.96094
RELU: Nbr of updates: 34, Minibatch Loss= 0.377008, Training Accuracy= 0.93750
SNN: Nbr of updates: 34, Minibatch Loss= 0.164228, Training Accuracy= 0.96094
RELU: Nbr of updates: 35, Minibatch Loss= 0.468827, Training Accuracy= 0.89062
SNN: Nbr of updates: 35, Minibatch Loss= 0.222637, Training Accuracy= 0.92969
RELU: Nbr of updates: 36, Minibatch Loss= 0.456475, Training Accuracy= 0.90625
SNN: Nbr of updates: 36, Minibatch Loss= 0.223814, Training Accuracy= 0.92969
RELU: Nbr of updates: 37, Minibatch Loss= 0.521786, Training Accuracy= 0.83594
SNN: Nbr of updates: 37, Minibatch Loss= 0.289590, Training Accuracy= 0.91406
RELU: Nbr of updates: 38, Minibatch Loss= 0.512233, Training Accuracy= 0.80469
SNN: Nbr of updates: 38, Minibatch Loss= 0.254801, Training Accuracy= 0.92188
RELU: Nbr of updates: 39, Minibatch Loss= 0.462405, Training Accuracy= 0.84375
SNN: Nbr of updates: 39, Minibatch Loss= 0.192647, Training Accuracy= 0.95312
RELU: Nbr of updates: 40, Minibatch Loss= 0.398073, Training Accuracy= 0.89844
SNN: Nbr of updates: 40, Minibatch Loss= 0.127224, Training Accuracy= 0.97656
RELU: Nbr of updates: 41, Minibatch Loss= 0.454393, Training Accuracy= 0.85156
SNN: Nbr of updates: 41, Minibatch Loss= 0.204394, Training Accuracy= 0.92969
RELU: Nbr of updates: 42, Minibatch Loss= 0.455688, Training Accuracy= 0.88281
SNN: Nbr of updates: 42, Minibatch Loss= 0.198009, Training Accuracy= 0.95312
RELU: Nbr of updates: 43, Minibatch Loss= 0.402138, Training Accuracy= 0.89062
SNN: Nbr of updates: 43, Minibatch Loss= 0.170651, Training Accuracy= 0.96094
RELU: Nbr of updates: 44, Minibatch Loss= 0.430634, Training Accuracy= 0.89062
SNN: Nbr of updates: 44, Minibatch Loss= 0.216837, Training Accuracy= 0.96094
RELU: Nbr of updates: 45, Minibatch Loss= 0.389273, Training Accuracy= 0.91406
SNN: Nbr of updates: 45, Minibatch Loss= 0.180505, Training Accuracy= 0.96094
RELU: Nbr of updates: 46, Minibatch Loss= 0.409469, Training Accuracy= 0.91406
SNN: Nbr of updates: 46, Minibatch Loss= 0.193067, Training Accuracy= 0.94531
RELU: Nbr of updates: 47, Minibatch Loss= 0.368824, Training Accuracy= 0.89062
SNN: Nbr of updates: 47, Minibatch Loss= 0.158238, Training Accuracy= 0.97656
RELU: Nbr of updates: 48, Minibatch Loss= 0.388534, Training Accuracy= 0.89844
SNN: Nbr of updates: 48, Minibatch Loss= 0.229685, Training Accuracy= 0.93750
RELU: Nbr of updates: 49, Minibatch Loss= 0.321354, Training Accuracy= 0.94531
SNN: Nbr of updates: 49, Minibatch Loss= 0.143143, Training Accuracy= 0.96875
RELU: Nbr of updates: 50, Minibatch Loss= 0.356414, Training Accuracy= 0.90625
SNN: Nbr of updates: 50, Minibatch Loss= 0.160477, Training Accuracy= 0.96094
Optimization Finished!

ReLU: Testing Accuracy: 0.859375
SNN: Testing Accuracy: 0.916016

In [14]:
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot( training_loss_protocol_ReLU, label='Loss ReLU-CNN')
ax.plot( training_loss_protocol_SNN, label='Loss SNN')
ax.set_yscale('log')  # log scale
ax.set_xlabel('iterations/updates')
ax.set_ylabel('training loss')
fig.tight_layout()
ax.legend()
fig


Out[14]: [figure: training loss (log scale) over iterations/updates for the ReLU-CNN and the SNN]