In [ ]:
# Force matplotlib to use inline rendering
%matplotlib inline
import os
import sys
# Add ~/libs to the module search path so libraries placed there can be imported in IPython
sys.path.append(os.path.expanduser("~/libs"))
import numpy as np
import tensorflow as tf
import tensorlight as light
In [ ]:
# Hyperparameters and output location used by the cells below.
BATCH_SIZE = 32  # passed as batch_size to runtime.train
WEIGHT_DECAY = 1e-4  # handed to the model; used there as the L2 regularizer scale
INITIAL_LR = 1e-3  # handed to light.training.Optimizer together with SGD
TRAIN_DIR = "train-test/mnist"  # directory argument of light.core.DefaultRuntime
In [ ]:
# All three MNIST splits are created from the same root directory, in the
# order train, valid, test.
DATA_ROOT = "data"
dataset_train, dataset_valid, dataset_test = (
    dataset_cls(DATA_ROOT)
    for dataset_cls in (
        light.datasets.mnist.MNISTTrainDataset,
        light.datasets.mnist.MNISTValidDataset,
        light.datasets.mnist.MNISTTestDataset,
    )
)
In [ ]:
# Key under which the dropout keep-probability placeholder is stored in the
# model's feeds; the same key is used when feeding values to the runtime.
DROPOUT_KEY = "keep_prob"
class SimpleClassificationModel(light.model.AbstractModel):
    """Small convolutional classifier assembled from tensorlight layers.

    Layer order: Conv1 -> max-pool -> Conv2 -> max-pool -> flatten -> FC1 ->
    dropout -> Out (10 units, softmax activation). The dropout keep
    probability is supplied at run time through the placeholder registered
    under ``DROPOUT_KEY``.
    """

    def __init__(self, weight_decay=0.0):
        """Pass ``weight_decay`` on to the base model.

        ``inference`` reads ``self.weight_decay`` (presumably stored by the
        base class -- confirm) as the scale of each layer's L2 regularizer.
        """
        super(SimpleClassificationModel, self).__init__(weight_decay)

    @light.utils.attr.override
    def fetch_feeds(self):
        """Return the additional placeholders of this model, keyed by name."""
        return {
            DROPOUT_KEY: tf.placeholder(tf.float32, name="dropout_keep_prob"),
        }

    @light.utils.attr.override
    def inference(self, inputs, targets, feeds, is_training, device_scope, memory_device):
        """Build the forward graph and return predictions shaped like ``targets``.

        ``is_training``, ``device_scope`` and ``memory_device`` belong to the
        overridden signature but are not used by this model; dropout is
        controlled solely through ``feeds[DROPOUT_KEY]``.
        """
        contrib_layers = tf.contrib.layers

        def layer_options(weight_init, activation):
            # Keyword arguments common to every trainable layer: bias
            # initialised to 0.1 and a fresh L2 regularizer per layer.
            return dict(
                weight_init=weight_init,
                bias_init=0.1,
                regularizer=contrib_layers.l2_regularizer(self.weight_decay),
                activation=activation,
            )

        # First convolution + pooling.
        # NOTE(review): 32, (5, 5), (1, 1) are presumably filter count, kernel
        # size and stride -- confirm against light.network.conv2d.
        net = light.network.conv2d(
            "Conv1", inputs, 32, (5, 5), (1, 1),
            **layer_options(contrib_layers.xavier_initializer_conv2d(), tf.nn.relu))
        net = light.network.max_pool2d(net)

        # Second convolution + pooling.
        net = light.network.conv2d(
            "Conv2", net, 64, (3, 3), (1, 1),
            **layer_options(contrib_layers.xavier_initializer_conv2d(), tf.nn.relu))
        net = light.network.max_pool2d(net)
        flat = contrib_layers.flatten(net)

        # First fully connected layer, followed by dropout.
        hidden = light.network.fc(
            "FC1", flat, 256,
            **layer_options(contrib_layers.xavier_initializer(), tf.nn.relu))
        hidden = tf.nn.dropout(hidden, keep_prob=feeds[DROPOUT_KEY])

        # Output layer: 10 units with softmax activation.
        probs = light.network.fc(
            "Out", hidden, 10,
            **layer_options(contrib_layers.xavier_initializer(), tf.nn.softmax))

        # Give the output the same non-batch dimensions as the targets.
        target_tail = targets.get_shape().as_list()[1:]
        return tf.reshape(probs, [-1] + target_tail)

    @light.utils.attr.override
    def loss(self, predictions, targets, device_scope):
        """Return ``light.loss.ce`` of predictions against targets.

        ``device_scope`` is part of the overridden signature and unused here.
        """
        loss_value = light.loss.ce(predictions, targets)
        return loss_value
In [ ]:
# Assemble the runtime: datasets, model and optimizer are registered first,
# then build() is called.
# NOTE(review): TRAIN_DIR is presumably where checkpoints and summaries are
# written -- confirm against light.core.DefaultRuntime.
runtime = light.core.DefaultRuntime(TRAIN_DIR)
runtime.register_datasets(dataset_train, dataset_valid, dataset_test)
runtime.register_model(SimpleClassificationModel(weight_decay=WEIGHT_DECAY))
# Optimizer is created from light.training.SGD and INITIAL_LR.
runtime.register_optimizer(light.training.Optimizer(light.training.SGD, INITIAL_LR))
runtime.build()
In [ ]:
def evaluate(dataset):
    """Display one sample from ``dataset`` and print the model's argmax output.

    Fetches a batch of size 1, shows its first input, runs the module-level
    ``runtime`` on the batch with the dropout keep probability fed as 1.0,
    and prints the index of the largest prediction value.
    """
    # The second element of the batch is not needed for this preview.
    batch_inputs, _ = dataset.get_batch(1)
    light.visualization.display_array(batch_inputs[0])
    prediction = runtime.predict(batch_inputs, feeds={DROPOUT_KEY: 1.0})
    predicted_index = np.argmax(prediction)
    print(predicted_index)
def on_valid(runtime, gstep):
    """Validation callback: preview one prediction on the validation dataset.

    ``gstep`` is accepted as part of the callback signature but is not used.
    Note that ``evaluate`` predicts with the module-level ``runtime``; the
    ``runtime`` argument is only used to look up the validation dataset.
    """
    validation_set = runtime.datasets.valid
    evaluate(validation_set)
In [ ]:
# Train for 1000 steps with batches of BATCH_SIZE. The dropout keep
# probability is fed as 0.5 for training and 1.0 for validation; on_valid is
# passed as the validation callback, with checkpoints and summaries enabled.
runtime.train(batch_size=BATCH_SIZE, steps=1000,
train_feeds={DROPOUT_KEY: 0.5}, valid_feeds={DROPOUT_KEY: 1.0},
on_validate=on_valid, do_checkpoints=True, do_summary=True)
In [ ]:
# Run the runtime's test routine with the dropout keep probability fed as 1.0.
# NOTE(review): the first argument (50) is presumably the batch size --
# confirm against DefaultRuntime.test.
runtime.test(50, feeds={DROPOUT_KEY: 1.0})
In [ ]:
# Show one sample from the test dataset together with the model's prediction.
evaluate(dataset_test)
In [ ]:
# Close the runtime once all training and evaluation cells have run.
runtime.close()