In [ ]:
# Force matplotlib to use inline rendering
%matplotlib inline
import os
import sys
# add path to local libraries for IPython
sys.path.append(os.path.expanduser("~/libs"))
import numpy as np
import tensorflow as tf
import tensorlight as light
In [ ]:
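# Hyperparameters for this demo: the learning rate starts at 1e-3 and is
# halved (LR_DECAY_RATE = 0.5) every 10,000 steps; L2 weight decay is 5e-4.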
TRAIN_DIR = "train-examples/runtime-demo"
BATCH_SIZE = 24
WEIGHT_DECAY = 5e-4
INITIAL_LR = 0.001
LR_DECAY_STEP_INTERVAL = 10000
LR_DECAY_RATE = 0.5
In [ ]:
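# The MNIST train/validation/test splits, rooted at ./data.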
DATA_ROOT = "data"
dataset_train = light.datasets.mnist.MNISTTrainDataset(DATA_ROOT)
dataset_valid = light.datasets.mnist.MNISTValidDataset(DATA_ROOT)
dataset_test = light.datasets.mnist.MNISTTestDataset(DATA_ROOT)
In [ ]:
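# A minimal fully-connected autoencoder: flatten the image, squeeze it through
# a 64-unit ReLU bottleneck, and reconstruct with a sigmoid layer so the
# outputs stay in [0, 1] like the MNIST pixels.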
class SimpleAutoencoderModel(light.model.AbstractModel):
    def __init__(self, weight_decay=0.0):
        super(SimpleAutoencoderModel, self).__init__(weight_decay)

    @light.utils.attr.override
    def inference(self, inputs, targets, feeds,
                  is_training, device_scope, memory_device):
        x = tf.contrib.layers.flatten(inputs)
        encoded = light.network.fc("FC_Enc", x, 64,
                                   weight_init=tf.contrib.layers.xavier_initializer(),
                                   bias_init=0.0,
                                   regularizer=tf.contrib.layers.l2_regularizer(self.weight_decay),
                                   activation=tf.nn.relu,
                                   device=memory_device)
        representation = encoded
        decoded = light.network.fc("FC_Dec", representation, x.get_shape()[1],
                                   weight_init=tf.contrib.layers.xavier_initializer(),
                                   bias_init=0.0,
                                   regularizer=tf.contrib.layers.l2_regularizer(self.weight_decay),
                                   activation=tf.nn.sigmoid,
                                   device=memory_device)
        return tf.reshape(decoded, [-1] + targets.get_shape().as_list()[1:])

    @light.utils.attr.override
    def loss(self, predictions, targets, device_scope):
        loss1 = light.loss.mse(predictions, targets)
        loss2 = light.loss.bce(predictions, targets)
        # log both loss components individually as well
        tf.add_to_collection(light.core.LOG_LOSSES, loss1)
        tf.add_to_collection(light.core.LOG_LOSSES, loss2)
        # 50/50 blend of MSE and BCE; the name reflects the actual weighting
        return tf.add(0.5 * loss1, 0.5 * loss2, name="50mse_50bce")

    @light.utils.attr.override
    def evaluation(self, predictions, targets, device_scope=None):
        psnr = light.image.psnr(predictions, targets)
        sharpdiff = light.image.sharp_diff(predictions, targets)
        ssim = light.image.ssim(predictions, targets, L=1.0)
        return {"psnr": psnr, "sharpdiff": sharpdiff, "ssim": ssim}
In [ ]:
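# A fully-convolutional variant: two stride-2 convolutions encode the image,
# and two stride-2 transposed convolutions (bilinearly initialized) decode it.
# With no dense layer, the weights are not tied to a fixed input size, which
# the rebuilding steps further below rely on.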
class SimpleFullyConvolutionalAutoencoderModel(light.model.AbstractModel):
    def __init__(self, weight_decay=0.0):
        super(SimpleFullyConvolutionalAutoencoderModel, self).__init__(weight_decay)

    @light.utils.attr.override
    def inference(self, inputs, targets, feeds,
                  is_training, device_scope, memory_device):
        with tf.variable_scope("Encoder"):
            # 1: Conv
            conv1 = light.network.conv2d("Conv1", inputs,
                                         8, (5, 5), (2, 2),
                                         weight_init=tf.contrib.layers.xavier_initializer_conv2d(),
                                         bias_init=0.01,
                                         regularizer=tf.contrib.layers.l2_regularizer(self.weight_decay),
                                         activation=tf.nn.relu)
            # test-summary for the first conv-layer:
            light.board.activation_summary(conv1, True, scope="Conv1")
            light.board.conv_image_summary("conv1_out", conv1)
            with tf.variable_scope("Conv1", reuse=True):
                # hack to access the kernel-weights
                kernel = tf.get_variable("W")
                light.board.conv_filter_image_summary("conv1_filters", kernel)
            # 2: Conv
            conv2 = light.network.conv2d("Conv2", conv1,
                                         16, (3, 3), (2, 2),
                                         weight_init=tf.contrib.layers.xavier_initializer_conv2d(),
                                         bias_init=0.01,
                                         regularizer=tf.contrib.layers.l2_regularizer(self.weight_decay),
                                         activation=tf.nn.relu)
            encoder_out = conv2

        with tf.variable_scope("Decoder"):
            # 3: Deconv
            conv3t = light.network.conv2d_transpose("Deconv1", encoder_out,
                                                    8, (3, 3), (2, 2),
                                                    weight_init=light.init.bilinear_initializer(),
                                                    bias_init=0.01,
                                                    regularizer=tf.contrib.layers.l2_regularizer(self.weight_decay),
                                                    activation=tf.nn.relu)
            # 4: Deconv
            conv4t = light.network.conv2d_transpose("Deconv2", conv3t,
                                                    1, (5, 5), (2, 2),
                                                    weight_init=light.init.bilinear_initializer(),
                                                    bias_init=0.01,
                                                    regularizer=tf.contrib.layers.l2_regularizer(self.weight_decay),
                                                    activation=tf.nn.sigmoid)
            decoder_out = conv4t
        return decoder_out

    @light.utils.attr.override
    def loss(self, predictions, targets, device_scope):
        # BCE plus a gradient-difference term to encourage sharp reconstructions
        return light.loss.bce(predictions, targets) + light.loss.mgdl(predictions, targets)

    @light.utils.attr.override
    def evaluation(self, predictions, targets, device_scope=None):
        psnr = light.image.psnr(predictions, targets)
        sharpdiff = light.image.sharp_diff(predictions, targets)
        ssim = light.image.ssim(predictions, targets, L=1.0)
        return {"psnr": psnr, "sharpdiff": sharpdiff, "ssim": ssim}
In [ ]:
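# Single-GPU runtime pinned to device 4; the multi-GPU variant below is left
# commented out as an alternative.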
runtime = light.core.DefaultRuntime(train_dir=TRAIN_DIR, gpu_devices=[4])
#runtime = light.core.MultiGpuRuntime(train_dir=TRAIN_DIR, gpu_devices=[0, 1])
In [ ]:
model = SimpleAutoencoderModel(weight_decay=WEIGHT_DECAY)
#model = SimpleFullyConvolutionalAutoencoderModel(weight_decay=WEIGHT_DECAY)
model.print_params()
runtime.register_model(model)
In [ ]:
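# Adam with step-wise exponential learning-rate decay (halved every 10k steps).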
optimizer = light.training.Optimizer('adam', INITIAL_LR,
                                     LR_DECAY_STEP_INTERVAL, LR_DECAY_RATE)
optimizer.print_params()
runtime.register_optimizer(optimizer)
In [ ]:
runtime.register_datasets(dataset_train, dataset_valid, dataset_test)
In [ ]:
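# Build the compute graph; is_autoencoder=True makes the runtime feed each
# batch of inputs as its own reconstruction target.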
runtime.build(is_autoencoder=True, verbose=True)
print("Starting with global step: {}".format(runtime.gstep))
In [ ]:
runtime.list_params()
In [ ]:
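# Optional hook, invoked by the runtime after every validation cycle.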
def on_valid(rt, gstep):
    print("On-Validate Hook...")
In [ ]:
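# Train for 1,000 steps, reporting every 25. Checkpoints are written, but
# TensorBoard summaries are disabled for this short demo run.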
runtime.train(BATCH_SIZE, steps=1000, on_validate=on_valid,
              display_steps=25, do_summary=False, do_checkpoints=True)
In [ ]:
runtime.validate(batch_size=50)
In [ ]:
runtime.test(batch_size=50)
In [ ]:
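# Helper: print the value ranges of inputs and predictions, then display both
# batches inline.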
def show(inputs, predictions):
    print("Inputs-Range     : [{}, {}]".format(inputs.min(), inputs.max()))
    print("Predictions-Range: [{}, {}]".format(predictions.min(), predictions.max()))
    light.visualization.display_batch(inputs, title="Inputs")
    light.visualization.display_batch(predictions, title="Predictions")
In [ ]:
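# Sanity check: run prediction on random noise with the dataset's input shape.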
image_shape = dataset_train.input_shape
fake_inputs = np.random.rand(4, image_shape[-3], image_shape[-2], image_shape[-1])
predictions = runtime.predict(fake_inputs)
show(fake_inputs, predictions)
In [ ]:
inputs, _ = dataset_train.get_batch(4)
predictions = runtime.predict(inputs)
show(inputs, predictions)
In [ ]:
SIZE_FACTOR = 2.0
image_shape = dataset_train.input_shape
changed_height = int(image_shape[-3] * SIZE_FACTOR)
changed_width = int(image_shape[-2] * SIZE_FACTOR)
channels = image_shape[-1]
print("Changed-Shape: [{}, {}, {}]".format(changed_height, changed_width, channels))
In [ ]:
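# Rebuild the graph for the enlarged input size, restoring the EMA-averaged
# weights from the checkpoint. Note that only the fully-convolutional model
# can be rebuilt this way, since dense layers fix the input dimensionality
# at build time.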
runtime.unregister_datasets()
runtime.build(is_autoencoder=True, track_ema_variables=False, restore_ema_variables=True,
              input_shape=[changed_height, changed_width, channels])
In [ ]:
fake_inputs = np.random.rand(4, changed_height, changed_width, channels)
predictions = runtime.predict(fake_inputs)
show(fake_inputs, predictions)
In [ ]:
inputs, _ = dataset_train.get_batch(4)
# pad the batch up to the enlarged shape
inputs = light.utils.image.pad_or_crop(inputs, [changed_height, changed_width, channels])
print(inputs.shape)
predictions = runtime.predict(inputs)
show(inputs, predictions)
In [ ]:
inputs, _ = dataset_train.get_batch(1)
# upscale the image by SIZE_FACTOR to reach the changed shape
scaled = light.utils.image.resize(inputs[0], scale=SIZE_FACTOR)
scaled = np.expand_dims(scaled, 0)
print(scaled.shape)
predictions = runtime.predict(scaled)
show(scaled, predictions)
In [ ]:
runtime.close()