Uses Conv2D and Deconv2D operations to create a simple auto-encoder. In this example the network may well learn a trivial, identity-like mapping; its intention is rather to demonstrate how the light.network.conv2d_transpose() function works.
An image scale of [0, 1] is used here.
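For orientation: with stride-2 convolutions and "SAME" padding, the encoder downsamples the 28x28 MNIST images to 14x14 and then 7x7, and the two stride-2 transposed convolutions of the decoder reverse this (7x7 -> 14x14 -> 28x28). The following standalone cell (plain TensorFlow, independent of the tensorlight API) sketches how a single stride-2 tf.nn.conv2d_transpose doubles the spatial resolution:
In [ ]:
# Standalone sketch: a stride-2 transposed convolution doubles the
# spatial size (7x7 -> 14x14 here), assuming "SAME" padding.
import tensorflow as tf
x = tf.zeros([1, 7, 7, 32])    # [batch, height, width, in_channels]
w = tf.zeros([3, 3, 16, 32])   # [kh, kw, out_channels, in_channels]
y = tf.nn.conv2d_transpose(x, w, output_shape=[1, 14, 14, 16],
                           strides=[1, 2, 2, 1], padding="SAME")
print(y.get_shape())           # prints (1, 14, 14, 16)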
In [ ]:
# Force matplotlib to use inline rendering
%matplotlib inline
import os
import sys
# add path to libraries for ipython
sys.path.append(os.path.expanduser("~/libs"))
import numpy as np
import tensorflow as tf
import tensorlight as light
In [ ]:
BATCH_SIZE = 64
WEIGHT_DECAY = 1e-8
INITIAL_LR = 0.01
DATA_ROOT = "data"
TRAIN_DIR = "train-test/mnist-cnn"
In [ ]:
dataset_train = light.datasets.mnist.MNISTTrainDataset(DATA_ROOT)
dataset_valid = light.datasets.mnist.MNISTValidDataset(DATA_ROOT)
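A quick sanity check of the data format can be useful at this point. The cell below assumes, as the prediction cell further down does, that get_batch() returns NumPy arrays with images scaled to [0, 1]:
In [ ]:
# Sanity check (assumption: get_batch() returns NumPy arrays, as it is
# used further below): MNIST images should be 28x28x1 in [0, 1].
x, y = dataset_train.get_batch(2)
print(x.shape, x.min(), x.max())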
In [ ]:
INPUT_BN = False
ENC_BN = False
DEC_BN = False

class SimpleCNNAutoencoderModel(light.model.AbstractModel):
    def __init__(self, weight_decay=0.0):
        super(SimpleCNNAutoencoderModel, self).__init__(weight_decay)

    @light.utils.attr.override
    def inference(self, inputs, targets, feeds,
                  is_training, device_scope, memory_device):
        if INPUT_BN:
            # pass is_training so batch norm uses its moving averages at test time
            inputs = tf.contrib.layers.batch_norm(inputs, is_training=is_training)

        with tf.variable_scope("Encoder"):
            # 1: Conv, 28x28x1 -> 14x14x16 (stride 2)
            conv1 = light.network.conv2d("Conv1", inputs,
                                         16, (5, 5), (2, 2),
                                         weight_init=tf.contrib.layers.xavier_initializer_conv2d(),
                                         bias_init=0.01,
                                         regularizer=tf.contrib.layers.l2_regularizer(self.weight_decay),
                                         activation=tf.nn.relu)
            if ENC_BN:
                conv1 = tf.contrib.layers.batch_norm(conv1, is_training=is_training)
            # 2: Conv, 14x14x16 -> 7x7x32 (stride 2)
            conv2 = light.network.conv2d("Conv2", conv1,
                                         32, (3, 3), (2, 2),
                                         weight_init=tf.contrib.layers.xavier_initializer_conv2d(),
                                         bias_init=0.01,
                                         regularizer=tf.contrib.layers.l2_regularizer(self.weight_decay),
                                         activation=tf.nn.relu)
            if ENC_BN:
                conv2 = tf.contrib.layers.batch_norm(conv2, is_training=is_training)
            encoder_out = conv2

        with tf.variable_scope("Decoder"):
            # 3: Deconv, 7x7x32 -> 14x14x16 (stride 2)
            conv3t = light.network.conv2d_transpose("Deconv1", encoder_out,
                                                    16, (3, 3), (2, 2),
                                                    weight_init=light.init.bilinear_initializer(),
                                                    bias_init=0.01,
                                                    regularizer=tf.contrib.layers.l2_regularizer(self.weight_decay),
                                                    activation=tf.nn.relu)
            if DEC_BN:
                conv3t = tf.contrib.layers.batch_norm(conv3t, is_training=is_training)
            # 4: Deconv, 14x14x16 -> 28x28x1 (stride 2); sigmoid keeps the
            # output in [0, 1] to match the input scale
            conv4t = light.network.conv2d_transpose("Deconv2", conv3t,
                                                    1, (5, 5), (2, 2),
                                                    weight_init=light.init.bilinear_initializer(),
                                                    bias_init=0.01,
                                                    regularizer=tf.contrib.layers.l2_regularizer(self.weight_decay),
                                                    activation=tf.nn.sigmoid)
            decoder_out = conv4t

        return decoder_out

    @light.utils.attr.override
    def loss(self, predictions, targets, device_scope):
        # pixel-wise binary cross-entropy plus a gradient difference term
        # that encourages sharp edges in the reconstruction
        return light.loss.bce(predictions, targets) + light.loss.mgdl(predictions, targets)
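light.loss.bce is a pixel-wise binary cross-entropy, which fits the sigmoid output and the [0, 1] image scale. light.loss.mgdl is presumably a (mean) gradient difference loss, which penalizes differences between the image gradients of prediction and target so that edges stay sharp. The cell below is a minimal plain-TensorFlow sketch of that idea, not tensorlight's actual implementation:
In [ ]:
# Minimal sketch of a gradient difference loss (illustration only, not the
# tensorlight implementation): compare finite differences of prediction and
# target along height and width for NHWC tensors.
def gdl_sketch(pred, target):
    dy_p = tf.abs(pred[:, 1:, :, :] - pred[:, :-1, :, :])
    dy_t = tf.abs(target[:, 1:, :, :] - target[:, :-1, :, :])
    dx_p = tf.abs(pred[:, :, 1:, :] - pred[:, :, :-1, :])
    dx_t = tf.abs(target[:, :, 1:, :] - target[:, :, :-1, :])
    return tf.reduce_mean(tf.abs(dy_p - dy_t)) + tf.reduce_mean(tf.abs(dx_p - dx_t))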
In [ ]:
light.hardware.set_cuda_devices([1])
runtime = light.core.DefaultRuntime(TRAIN_DIR)
runtime.register_datasets(dataset_train, dataset_valid, None)
runtime.register_model(SimpleCNNAutoencoderModel(weight_decay=WEIGHT_DECAY))
runtime.register_optimizer(light.training.Optimizer('adam', INITIAL_LR))
runtime.build(is_autoencoder=True, verbose=True)
In [ ]:
runtime.train(batch_size=BATCH_SIZE, steps=1000, display_steps=50, do_checkpoints=False, do_summary=False)
In [ ]:
x, _ = dataset_valid.get_batch(4)
light.visualization.display_batch(x, nrows=2, ncols=2, title="Input")
pred = runtime.predict(x)
light.visualization.display_batch(pred, nrows=2, ncols=2, title="Reconstruction")
In [ ]:
runtime.test(BATCH_SIZE)
In [ ]:
runtime.close()