In [1]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns
import time
sns.set()

In [2]:
# Load MNIST through the TF 1.x tutorial helper. The first argument ('') is the
# data directory, i.e. the current working directory; validation_size=0 asks for
# no separate validation split. The labels are used as integer class ids further
# down in this notebook, where the training loop one-hot encodes them itself.
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('', validation_size = 0)


Extracting train-images-idx3-ubyte.gz
Extracting train-labels-idx1-ubyte.gz
Extracting t10k-images-idx3-ubyte.gz
Extracting t10k-labels-idx1-ubyte.gz

In [3]:
def squash(X, epsilon = 1e-9):
    """Apply the capsule "squash" non-linearity along axis -2 of ``X``.

    Every vector taken along the second-to-last axis is rescaled by
    ``n / (1 + n) / sqrt(n + epsilon)``, where ``n`` is that vector's squared
    norm. The direction is kept while the length is mapped into [0, 1).

    Args:
        X: tensor whose second-to-last axis holds the capsule vectors.
        epsilon: small constant added under the square root so that an
            all-zero vector does not cause a division by zero.

    Returns:
        A tensor with the same shape as ``X``.
    """
    sq_norm = tf.reduce_sum(tf.square(X), -2, keep_dims=True)
    # Same operation order as n / (1 + n) / sqrt(n + eps).
    shrink = sq_norm / (1 + sq_norm)
    scale = shrink / tf.sqrt(sq_norm + epsilon)
    return scale * X

def conv_layer(X, num_output, num_vector, kernel=None, stride=None):
    """Convolutional capsule layer: conv2d + ReLU, regrouped into squashed vectors.

    A single VALID-padded convolution produces ``num_output * num_vector``
    channels. The result is reshaped to ``(batch_size, -1, num_vector, 1)`` so
    that each capsule is a ``num_vector``-long column, and then squashed.

    ``batch_size`` is read from the module-level variable of that name.

    Args:
        X: input feature map passed to ``tf.contrib.layers.conv2d``.
        num_output: number of capsule channels.
        num_vector: length of each capsule vector.
        kernel: convolution kernel size.
        stride: convolution stride.

    Returns:
        Squashed capsules of shape ``(batch_size, -1, num_vector, 1)``.
    """
    conv_out = tf.contrib.layers.conv2d(
        inputs=X,
        num_outputs=num_output * num_vector,
        kernel_size=kernel,
        stride=stride,
        padding="VALID",
        activation_fn=tf.nn.relu,
    )
    capsule_shape = (batch_size, -1, num_vector, 1)
    return squash(tf.reshape(conv_out, capsule_shape))

def _static_dim(tensor, axis, default):
    """Return the statically known size of ``axis`` of ``tensor``, else ``default``."""
    shape = tensor.shape
    if shape.ndims is None or shape.ndims <= axis:
        return default
    size = shape[axis].value
    return default if size is None else size


def routing(X, b_IJ, routing_times = 2, out_dim = 16):
    """Route lower-level capsules to higher-level capsules by agreement.

    ``batch_size`` is read from the module-level variable of that name.

    Args:
        X: input capsules, laid out by the caller as
            ``(batch_size, num_in, 1, in_dim, 1)``.
        b_IJ: initial routing logits of shape
            ``(batch_size, num_in, num_out, 1, 1)``. ``num_in`` and ``num_out``
            are taken from its static shape, so the number of output capsules
            follows whatever the caller put there instead of being fixed at 10.
        routing_times: number of routing iterations; must be >= 1.
        out_dim: length of each output capsule vector (16 by default, the value
            that was previously hard-coded).

    Returns:
        Output capsules ``v_J`` of shape ``(batch_size, 1, num_out, out_dim, 1)``.

    Raises:
        ValueError: if ``routing_times`` is smaller than 1 (there would be no
            output to return).
    """
    if routing_times < 1:
        raise ValueError("routing_times must be at least 1, got %r" % (routing_times,))

    # No-ops for tensors; lets the static-shape lookup below work on any input
    # that TensorFlow can convert.
    X = tf.convert_to_tensor(X)
    b_IJ = tf.convert_to_tensor(b_IJ)

    # Fall back to the original hard-coded sizes whenever a dimension is not
    # statically known, so existing callers see exactly the same graph.
    num_in = _static_dim(b_IJ, 1, 1152)
    num_out = _static_dim(b_IJ, 2, 10)
    in_dim = _static_dim(X, 3, 8)

    # One (in_dim x out_dim) transform per (input capsule, output capsule) pair,
    # shared across the batch by tiling.
    w = tf.Variable(tf.truncated_normal([1, num_in, num_out, in_dim, out_dim], stddev=1e-1))
    X = tf.tile(X, [1, 1, num_out, 1, 1])
    w = tf.tile(w, [batch_size, 1, 1, 1, 1])
    u_hat = tf.matmul(w, X, transpose_a=True)
    # The agreement updates use a gradient-stopped copy; only the final pass
    # below lets gradients flow back into ``w``.
    u_hat_stopped = tf.stop_gradient(u_hat)

    for i in range(routing_times):
        is_last = i == routing_times - 1
        c_IJ = tf.nn.softmax(b_IJ, dim=2)
        votes = u_hat if is_last else u_hat_stopped
        s_J = tf.reduce_sum(tf.multiply(c_IJ, votes), axis=1, keep_dims=True)
        v_J = squash(s_J)
        if not is_last:
            # Raise the logits by the dot product between each prediction and
            # the current output capsule.
            v_J_tiled = tf.tile(v_J, [1, num_in, 1, 1, 1])
            u_produce_v = tf.matmul(u_hat_stopped, v_J_tiled, transpose_a=True)
            b_IJ += u_produce_v
    return v_J

def fully_conn_layer(X, num_output):
    """Capsule-to-capsule layer built on ``routing``.

    Reshapes ``X`` to ``(batch_size, -1, 1, in_dim, 1)``, where ``in_dim`` is the
    static size of ``X``'s second-to-last axis. It then creates all-zero routing
    logits of shape ``(batch_size, 1152, num_output, 1, 1)`` and runs two routing
    iterations.

    ``batch_size`` is read from the module-level variable of that name; the
    number of input capsules is fixed at 1152.

    Args:
        X: input capsules whose second-to-last axis has a static size.
        num_output: number of output capsules placed in the routing logits.

    Returns:
        The routed capsules with the singleton axis 1 removed.
    """
    in_dim = X.shape[-2].value
    caps_in = tf.reshape(X, shape=(batch_size, -1, 1, in_dim, 1))
    zero_logits = np.zeros([batch_size, 1152, num_output, 1, 1], dtype=np.float32)
    routed = routing(caps_in, tf.constant(zero_logits), routing_times = 2)
    return tf.squeeze(routed, axis=1)

class CapsuleNetwork:
    """Capsule network for 28x28x1 images and 10 classes, with a reconstruction decoder.

    The constructor builds the whole TF 1.x graph and exposes:

    * ``X``, ``Y``  - placeholders for images ``(None, 28, 28, 1)`` and one-hot
      labels ``(None, 10)``.
    * ``logits``    - softmax over the 10 capsule lengths, shape ``(batch, 10)``
      (these are probabilities despite the attribute name).
    * ``cost``      - margin loss plus the scaled reconstruction error.
    * ``optimizer`` - Adam training op minimising ``cost``.
    * ``accuracy``  - fraction of rows where the argmax of ``logits`` matches ``Y``.

    NOTE: ``conv_layer`` / ``fully_conn_layer`` read the module-level
    ``batch_size`` while this class reshapes with its ``batch_size`` argument,
    so the two values must be equal, and every fed batch must have exactly
    that many rows.
    """

    def __init__(self, batch_size, learning_rate, regularization_scale=0.392,
                 epsilon=1e-8, m_plus=0.9, m_minus=0.1, lambda_val=0.5):
        """Build the graph.

        Args:
            batch_size: fixed number of rows per fed batch.
            learning_rate: Adam learning rate.
            regularization_scale: weight of the reconstruction error in the
                total cost (presumably 0.0005 * 784 - TODO confirm).
            epsilon: constant added under the square root of the capsule length.
            m_plus: margin for the capsule of the true class.
            m_minus: margin for the capsules of the other classes.
            lambda_val: down-weighting of the loss for absent classes.
        """
        self.X = tf.placeholder(tf.float32, shape=(None, 28, 28, 1))
        self.Y = tf.placeholder(tf.float32, shape=(None, 10))

        # Encoder: plain conv -> primary capsules -> 10 routed class capsules.
        conv1 = tf.contrib.layers.conv2d(self.X, num_outputs=256,
                                         kernel_size=9, stride=1,
                                         padding='VALID')
        caps1 = conv_layer(conv1, 32, 8, 9, 2)
        caps2 = fully_conn_layer(caps1, 10)

        # Capsule length along the vector axis; built once and shared by the
        # class probabilities and the margin loss.
        v_length = tf.sqrt(tf.reduce_sum(tf.square(caps2), axis=2, keep_dims=True) + epsilon)
        self.logits = tf.nn.softmax(v_length, dim=1)[:,:,0,0]

        # Decoder input: keep only the capsule of the true class. Only the
        # trailing singleton axis is squeezed, so the batch axis can never be
        # dropped by accident.
        masked_v = tf.multiply(tf.squeeze(caps2, axis=3), tf.reshape(self.Y, (-1, 10, 1)))
        vector_j = tf.reshape(masked_v, shape=(batch_size, -1))
        fc1 = tf.contrib.layers.fully_connected(vector_j, num_outputs=512)
        fc2 = tf.contrib.layers.fully_connected(fc1, num_outputs=1024)
        decoded = tf.contrib.layers.fully_connected(fc2, num_outputs=784, activation_fn=tf.sigmoid)

        # Margin loss: present classes are pushed above m_plus, absent classes
        # below m_minus (the latter weighted by lambda_val).
        max_l = tf.square(tf.maximum(0., m_plus - v_length))
        max_r = tf.square(tf.maximum(0., v_length - m_minus))
        max_l = tf.reshape(max_l, shape=(batch_size, -1))
        max_r = tf.reshape(max_r, shape=(batch_size, -1))
        L_c = self.Y * max_l + lambda_val * (1 - self.Y) * max_r
        margin_loss = tf.reduce_mean(tf.reduce_sum(L_c, axis=1))

        # Reconstruction error between the decoder output and the flattened input.
        origin = tf.reshape(self.X, shape=(batch_size, -1))
        squared = tf.reduce_mean(tf.square(decoded - origin))

        self.cost = margin_loss + regularization_scale * squared
        self.optimizer = tf.train.AdamOptimizer(learning_rate).minimize(self.cost)
        correct_prediction = tf.equal(tf.argmax(self.logits, 1), tf.argmax(self.Y, 1))
        self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

In [ ]:
# Hyperparameters. `batch_size` is also read as a module-level name by the
# layer helper functions, so it must be defined before the model is built.
batch_size = 128
learning_rate = 0.001
epoch = 5
random_seed = 42

tf.reset_default_graph()
# Graph-level seed so weight initialisation is repeatable between runs. It has
# to be set after the graph reset and before any op is created.
tf.set_random_seed(random_seed)
sess = tf.InteractiveSession()
model = CapsuleNetwork(batch_size, learning_rate)
sess.run(tf.global_variables_initializer())

In [ ]:
# Per-epoch history: mean training loss, mean training accuracy, test accuracy.
LOSS, ACC_TRAIN, ACC_TEST = [], [], []

# Only whole batches are used because the graph is built for a fixed batch size.
train_batches = mnist.train.images.shape[0] // batch_size
# Evaluation looks at the first 1000 test images only, rounded down to whole batches.
test_batches = mnist.test.images[:1000,:].shape[0] // batch_size
batch_rows = np.arange(batch_size)

for i in range(epoch):
    total_loss, total_acc = 0, 0
    for n in range(0, train_batches * batch_size, batch_size):
        batch_x = mnist.train.images[n: n + batch_size, :].reshape((-1, 28, 28, 1))
        # One-hot encode the integer labels in a single vectorised assignment.
        batch_y = np.zeros((batch_size, 10))
        batch_y[batch_rows, mnist.train.labels[n: n + batch_size]] = 1.0
        # Fetch the accuracy in the same run as the loss and the update, so the
        # forward pass is not executed a second time for every batch.
        cost, acc, _ = sess.run([model.cost, model.accuracy, model.optimizer],
                                feed_dict = {model.X : batch_x,
                                             model.Y : batch_y})
        total_acc += acc
        total_loss += cost
    total_loss /= train_batches
    total_acc /= train_batches
    LOSS.append(total_loss)
    ACC_TRAIN.append(total_acc)

    total_acc = 0
    for n in range(0, test_batches * batch_size, batch_size):
        batch_x = mnist.test.images[n: n + batch_size, :].reshape((-1, 28, 28, 1))
        batch_y = np.zeros((batch_size, 10))
        batch_y[batch_rows, mnist.test.labels[n: n + batch_size]] = 1.0
        total_acc += sess.run(model.accuracy,
                              feed_dict = {model.X : batch_x,
                                           model.Y : batch_y})
    total_acc /= test_batches
    ACC_TEST.append(total_acc)
    print('epoch: %d, accuracy train: %f, accuracy testing: %f'%(i+1, ACC_TRAIN[-1],ACC_TEST[-1]))


epoch: 1, accuracy train: 0.949820, accuracy testing: 0.989955
epoch: 2, accuracy train: 0.991186, accuracy testing: 0.989955
epoch: 3, accuracy train: 0.994474, accuracy testing: 0.989955
epoch: 4, accuracy train: 0.996444, accuracy testing: 0.989955