Runtime MNIST CNN Classification Example

Uses convolutional (Conv) and fully connected (FC) layers to perform simple MNIST digit classification. Input images are scaled to [0, 1].


In [ ]:
# Force matplotlib to use inline rendering
%matplotlib inline

import os
import sys

# add path to libraries for ipython
sys.path.append(os.path.expanduser("~/libs"))

import numpy as np
import tensorflow as tf
import tensorlight as light

In [ ]:
# --- Training configuration -------------------------------------------------
INITIAL_LR = 0.001               # learning rate handed to the SGD optimizer
WEIGHT_DECAY = 0.0001            # scale of the L2 regularizer on every layer
BATCH_SIZE = 32                  # examples per training step
TRAIN_DIR = "train-test/mnist"   # passed to DefaultRuntime (presumably where checkpoints/summaries go)

In [ ]:
# Directory holding the MNIST files, relative to the notebook's working directory.
DATA_ROOT = "data"

# tensorlight provides one dataset class per split; all read from DATA_ROOT.
mnist_datasets = light.datasets.mnist
dataset_train = mnist_datasets.MNISTTrainDataset(DATA_ROOT)
dataset_valid = mnist_datasets.MNISTValidDataset(DATA_ROOT)
dataset_test = mnist_datasets.MNISTTestDataset(DATA_ROOT)

Model


In [ ]:
DROPOUT_KEY = "keep_prob"  # key of the dropout keep-probability entry in `feeds`

class SimpleClassificationModel(light.model.AbstractModel):
    """Small CNN for 10-class classification.

    Layout: Conv1 (ReLU) -> max-pool -> Conv2 (ReLU) -> max-pool -> flatten
    -> FC1 (256 units, ReLU) -> dropout -> Out (10 units, softmax).
    Every layer uses a Xavier initializer, ``bias_init=0.1`` and an L2
    regularizer scaled by ``self.weight_decay``.
    """

    def __init__(self, weight_decay=0.0):
        """Forward ``weight_decay`` (the L2 regularization scale) to the base model."""
        super(SimpleClassificationModel, self).__init__(weight_decay)

    @light.utils.attr.override
    def fetch_feeds(self):
        """Return the extra placeholders that must be fed at run time.

        Only the dropout keep-probability (a float32 placeholder with no
        fixed shape) is exposed, under ``DROPOUT_KEY``.
        """
        return {DROPOUT_KEY: tf.placeholder(tf.float32, name="dropout_keep_prob")}

    @light.utils.attr.override
    def inference(self, inputs, targets, feeds, is_training, device_scope, memory_device):
        """Build the forward pass and return the softmax output.

        The output is reshaped to ``[-1] + targets.shape[1:]`` so that it
        lines up with ``targets``. ``is_training``, ``device_scope`` and
        ``memory_device`` belong to the overridden interface but are unused.
        """
        def layer_opts(weight_init):
            # Options shared by every layer; only the Xavier variant differs
            # between the conv layers and the fully connected layers.
            return {
                "weight_init": weight_init,
                "bias_init": 0.1,
                "regularizer": tf.contrib.layers.l2_regularizer(self.weight_decay),
            }

        # Two conv + max-pool stages. The positional args after the input are
        # presumably (filters, kernel size, strides) -- TODO confirm against
        # light.network.conv2d.
        stage1 = light.network.max_pool2d(
            light.network.conv2d("Conv1", inputs, 32, (5, 5), (1, 1),
                                 activation=tf.nn.relu,
                                 **layer_opts(tf.contrib.layers.xavier_initializer_conv2d())))
        stage2 = light.network.max_pool2d(
            light.network.conv2d("Conv2", stage1, 64, (3, 3), (1, 1),
                                 activation=tf.nn.relu,
                                 **layer_opts(tf.contrib.layers.xavier_initializer_conv2d())))

        # Fully connected head: 256 ReLU units with dropout, then 10 softmax outputs.
        hidden = light.network.fc("FC1", tf.contrib.layers.flatten(stage2), 256,
                                  activation=tf.nn.relu,
                                  **layer_opts(tf.contrib.layers.xavier_initializer()))
        hidden = tf.nn.dropout(hidden, keep_prob=feeds[DROPOUT_KEY])
        probs = light.network.fc("Out", hidden, 10,
                                 activation=tf.nn.softmax,
                                 **layer_opts(tf.contrib.layers.xavier_initializer()))

        return tf.reshape(probs, [-1] + targets.get_shape().as_list()[1:])

    @light.utils.attr.override
    def loss(self, predictions, targets, device_scope):
        """Cross-entropy between ``predictions`` and ``targets`` via ``light.loss.ce``.

        NOTE(review): ``predictions`` are presumably the output of ``inference``,
        which already applies softmax; this assumes ``light.loss.ce`` expects
        probabilities rather than logits -- confirm.
        """
        return light.loss.ce(predictions, targets)

Training


In [ ]:
# Create the runtime, register datasets, model and optimizer (in that order),
# then call build() (presumably assembles the TensorFlow graph).
runtime = light.core.DefaultRuntime(TRAIN_DIR)
runtime.register_datasets(dataset_train, dataset_valid, dataset_test)

mnist_model = SimpleClassificationModel(weight_decay=WEIGHT_DECAY)
runtime.register_model(mnist_model)

sgd_optimizer = light.training.Optimizer(light.training.SGD, INITIAL_LR)
runtime.register_optimizer(sgd_optimizer)

runtime.build()

In [ ]:
def evaluate(dataset, rt=None):
    """Display one example from ``dataset`` and print its predicted and target class.

    Parameters
    ----------
    dataset:
        Object with ``get_batch(batch_size)`` returning ``(inputs, targets)``.
    rt:
        Runtime used for prediction. Defaults to the notebook-level
        ``runtime``; pass it explicitly to avoid relying on that global.
    """
    rt = runtime if rt is None else rt
    inputs, targets = dataset.get_batch(1)
    light.visualization.display_array(inputs[0])
    # Dropout is disabled (keep probability 1.0) for inference.
    probs = rt.predict(inputs, feeds={DROPOUT_KEY: 1.0})
    # Batch size is 1, so argmax over the flattened output is the class index.
    predicted = int(np.argmax(probs))
    # One-hot targets -> index of the hot entry; a single stored label -> that label.
    target = np.asarray(targets).ravel()
    label = int(np.argmax(target)) if target.size > 1 else int(target[0])
    print("predicted: {}, target: {}".format(predicted, label))

def on_valid(runtime, gstep):
    """Validation hook for ``runtime.train``: show one validation example.

    Uses the runtime passed in by the hook rather than the global one.
    """
    evaluate(runtime.datasets.valid, runtime)

In [ ]:
# Train with dropout active (keep probability 0.5 on FC1) and disable it
# (keep probability 1.0) for validation; on_valid shows one validation example.
train_steps = 1000
runtime.train(
    batch_size=BATCH_SIZE,
    steps=train_steps,
    train_feeds={DROPOUT_KEY: 0.5},
    valid_feeds={DROPOUT_KEY: 1.0},
    on_validate=on_valid,
    do_checkpoints=True,
    do_summary=True,
)

Evaluation


In [ ]:
# Run runtime.test with dropout disabled. NOTE(review): the meaning of the
# positional 50 (batch size vs. number of batches) depends on DefaultRuntime.test.
test_feeds = {DROPOUT_KEY: 1.0}
runtime.test(50, feeds=test_feeds)

In [ ]:
# Show a single example from the test split together with the model's prediction.
evaluate(dataset=dataset_test)

Terminate


In [ ]:
# Shut down the runtime once training and evaluation are done
# (presumably releases its TensorFlow session -- confirm in DefaultRuntime.close).
runtime.close()