VQ-VAE training example

Demonstration of how to train the model specified in https://arxiv.org/abs/1711.00937, using TF 2 / Sonnet 2.

On Mac and Linux, simply execute each cell in turn.


In [1]:
!pip install dm-sonnet dm-tree


Requirement already satisfied: dm-sonnet in /tmp/sonnet-nb-env/lib/python3.7/site-packages (2.0.0)
Requirement already satisfied: dm-tree in /tmp/sonnet-nb-env/lib/python3.7/site-packages (0.1.5)
Requirement already satisfied: six>=1.12.0 in /tmp/sonnet-nb-env/lib/python3.7/site-packages (from dm-sonnet) (1.14.0)
Requirement already satisfied: tabulate>=0.7.5 in /tmp/sonnet-nb-env/lib/python3.7/site-packages (from dm-sonnet) (0.8.7)
Requirement already satisfied: absl-py>=0.7.1 in /tmp/sonnet-nb-env/lib/python3.7/site-packages (from dm-sonnet) (0.9.0)
Requirement already satisfied: numpy>=1.16.3 in /tmp/sonnet-nb-env/lib/python3.7/site-packages (from dm-sonnet) (1.18.3)
Requirement already satisfied: wrapt>=1.11.1 in /tmp/sonnet-nb-env/lib/python3.7/site-packages (from dm-sonnet) (1.12.1)

In [2]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
import tree

try:
  import sonnet.v2 as snt
  tf.enable_v2_behavior()
except ImportError:
  import sonnet as snt

print("TensorFlow version {}".format(tf.__version__))
print("Sonnet version {}".format(snt.__version__))


TensorFlow version 2.1.0
Sonnet version 2.0.0

Download Cifar10 data

This requires a connection to the internet and will download ~160MB.


In [3]:
cifar10 = tfds.as_numpy(tfds.load("cifar10:3.0.2", split="train+test", batch_size=-1))
cifar10.pop("id", None)
cifar10.pop("label")
tree.map_structure(lambda x: f'{x.dtype.name}{list(x.shape)}', cifar10)


Out[3]:
{'image': 'uint8[60000, 32, 32, 3]'}

Load the data into Numpy

We compute the variance of the whole training set to normalise the Mean Squared Error below.


In [4]:
train_data_dict = tree.map_structure(lambda x: x[:40000], cifar10)
valid_data_dict = tree.map_structure(lambda x: x[40000:50000], cifar10)
test_data_dict = tree.map_structure(lambda x: x[50000:], cifar10)

In [5]:
def cast_and_normalise_images(data_dict):
  """Convert images to floating point with the range [-0.5, 0.5]"""
  images = data_dict['image']
  data_dict['image'] = (tf.cast(images, tf.float32) / 255.0) - 0.5
  return data_dict

train_data_variance = np.var(train_data_dict['image'] / 255.0)
print('train data variance: %s' % train_data_variance)


train data variance: 0.06327039811675479

Encoder & Decoder Architecture


In [6]:
class ResidualStack(snt.Module):
  def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
               name=None):
    super(ResidualStack, self).__init__(name=name)
    self._num_hiddens = num_hiddens
    self._num_residual_layers = num_residual_layers
    self._num_residual_hiddens = num_residual_hiddens

    self._layers = []
    for i in range(num_residual_layers):
      conv3 = snt.Conv2D(
          output_channels=num_residual_hiddens,
          kernel_shape=(3, 3),
          stride=(1, 1),
          name="res3x3_%d" % i)
      conv1 = snt.Conv2D(
          output_channels=num_hiddens,
          kernel_shape=(1, 1),
          stride=(1, 1),
          name="res1x1_%d" % i)
      self._layers.append((conv3, conv1))

  def __call__(self, inputs):
    h = inputs
    for conv3, conv1 in self._layers:
      conv3_out = conv3(tf.nn.relu(h))
      conv1_out = conv1(tf.nn.relu(conv3_out))
      h += conv1_out
    return tf.nn.relu(h)  # Resnet V1 style


class Encoder(snt.Module):
  def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
               name=None):
    super(Encoder, self).__init__(name=name)
    self._num_hiddens = num_hiddens
    self._num_residual_layers = num_residual_layers
    self._num_residual_hiddens = num_residual_hiddens

    self._enc_1 = snt.Conv2D(
        output_channels=self._num_hiddens // 2,
        kernel_shape=(4, 4),
        stride=(2, 2),
        name="enc_1")
    self._enc_2 = snt.Conv2D(
        output_channels=self._num_hiddens,
        kernel_shape=(4, 4),
        stride=(2, 2),
        name="enc_2")
    self._enc_3 = snt.Conv2D(
        output_channels=self._num_hiddens,
        kernel_shape=(3, 3),
        stride=(1, 1),
        name="enc_3")
    self._residual_stack = ResidualStack(
        self._num_hiddens,
        self._num_residual_layers,
        self._num_residual_hiddens)

  def __call__(self, x):
    h = tf.nn.relu(self._enc_1(x))
    h = tf.nn.relu(self._enc_2(h))
    h = tf.nn.relu(self._enc_3(h))
    return self._residual_stack(h)


class Decoder(snt.Module):
  def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
               name=None):
    super(Decoder, self).__init__(name=name)
    self._num_hiddens = num_hiddens
    self._num_residual_layers = num_residual_layers
    self._num_residual_hiddens = num_residual_hiddens

    self._dec_1 = snt.Conv2D(
        output_channels=self._num_hiddens,
        kernel_shape=(3, 3),
        stride=(1, 1),
        name="dec_1")
    self._residual_stack = ResidualStack(
        self._num_hiddens,
        self._num_residual_layers,
        self._num_residual_hiddens)
    self._dec_2 = snt.Conv2DTranspose(
        output_channels=self._num_hiddens // 2,
        output_shape=None,
        kernel_shape=(4, 4),
        stride=(2, 2),
        name="dec_2")
    self._dec_3 = snt.Conv2DTranspose(
        output_channels=3,
        output_shape=None,
        kernel_shape=(4, 4),
        stride=(2, 2),
        name="dec_3")
    
  def __call__(self, x):
    h = self._dec_1(x)
    h = self._residual_stack(h)
    h = tf.nn.relu(self._dec_2(h))
    x_recon = self._dec_3(h)
    return x_recon
    

class VQVAEModel(snt.Module):
  def __init__(self, encoder, decoder, vqvae, pre_vq_conv1, 
               data_variance, name=None):
    super(VQVAEModel, self).__init__(name=name)
    self._encoder = encoder
    self._decoder = decoder
    self._vqvae = vqvae
    self._pre_vq_conv1 = pre_vq_conv1
    self._data_variance = data_variance

  def __call__(self, inputs, is_training):
    z = self._pre_vq_conv1(self._encoder(inputs))
    vq_output = self._vqvae(z, is_training=is_training)
    x_recon = self._decoder(vq_output['quantize'])
    recon_error = tf.reduce_mean((x_recon - inputs) ** 2) / self._data_variance
    loss = recon_error + vq_output['loss']
    return {
        'z': z,
        'x_recon': x_recon,
        'loss': loss,
        'recon_error': recon_error,
        'vq_output': vq_output,
    }

Build Model and train


In [7]:
%%time

# Set hyper-parameters.
batch_size = 32
image_size = 32

# 100k steps should take < 30 minutes on a modern (>= 2017) GPU.
# 10k steps gives reasonable accuracy with VQVAE on Cifar10.
num_training_updates = 10000

num_hiddens = 128
num_residual_hiddens = 32
num_residual_layers = 2
# These hyper-parameters define the size of the model (number of parameters and layers).
# The hyper-parameters in the paper were (For ImageNet):
# batch_size = 128
# image_size = 128
# num_hiddens = 128
# num_residual_hiddens = 32
# num_residual_layers = 2

# This value is not that important, usually 64 works.
# This will not change the capacity in the information-bottleneck.
embedding_dim = 64

# The higher this value, the higher the capacity in the information bottleneck.
num_embeddings = 512

# commitment_cost should be set appropriately. It's often useful to try a couple
# of values. It mostly depends on the scale of the reconstruction cost
# (log p(x|z)). So if the reconstruction cost is 100x higher, the
# commitment_cost should also be multiplied with the same amount.
commitment_cost = 0.25

# Use EMA updates for the codebook (instead of the Adam optimizer).
# This typically converges faster, and makes the model less dependent on choice
# of the optimizer. In the VQ-VAE paper EMA updates were not used (but was
# developed afterwards). See Appendix of the paper for more details.
vq_use_ema = True

# This is only used for EMA updates.
decay = 0.99

learning_rate = 3e-4


# # Data Loading.
train_dataset = (
    tf.data.Dataset.from_tensor_slices(train_data_dict)
    .map(cast_and_normalise_images)
    .shuffle(10000)
    .repeat(-1)  # repeat indefinitely
    .batch(batch_size, drop_remainder=True)
    .prefetch(-1))

valid_dataset = (
    tf.data.Dataset.from_tensor_slices(valid_data_dict)
    .map(cast_and_normalise_images)
    .repeat(1)  # 1 epoch
    .batch(batch_size)
    .prefetch(-1))

# # Build modules.
encoder = Encoder(num_hiddens, num_residual_layers, num_residual_hiddens)
decoder = Decoder(num_hiddens, num_residual_layers, num_residual_hiddens)
pre_vq_conv1 = snt.Conv2D(output_channels=embedding_dim,
    kernel_shape=(1, 1),
    stride=(1, 1),
    name="to_vq")

if vq_use_ema:
  vq_vae = snt.nets.VectorQuantizerEMA(
      embedding_dim=embedding_dim,
      num_embeddings=num_embeddings,
      commitment_cost=commitment_cost,
      decay=decay)
else:
  vq_vae = snt.nets.VectorQuantizer(
      embedding_dim=embedding_dim,
      num_embeddings=num_embeddings,
      commitment_cost=commitment_cost)
  
model = VQVAEModel(encoder, decoder, vq_vae, pre_vq_conv1,
                   data_variance=train_data_variance)

optimizer = snt.optimizers.Adam(learning_rate=learning_rate)

@tf.function
def train_step(data):
  with tf.GradientTape() as tape:
    model_output = model(data['image'], is_training=True)
  trainable_variables = model.trainable_variables
  grads = tape.gradient(model_output['loss'], trainable_variables)
  optimizer.apply(grads, trainable_variables)

  return model_output

train_losses = []
train_recon_errors = []
train_perplexities = []
train_vqvae_loss = []

for step_index, data in enumerate(train_dataset):
  train_results = train_step(data)
  train_losses.append(train_results['loss'])
  train_recon_errors.append(train_results['recon_error'])
  train_perplexities.append(train_results['vq_output']['perplexity'])
  train_vqvae_loss.append(train_results['vq_output']['loss'])

  if (step_index + 1) % 100 == 0:
    print('%d train loss: %f ' % (step_index + 1,
                                   np.mean(train_losses[-100:])) +
          ('recon_error: %.3f ' % np.mean(train_recon_errors[-100:])) +
          ('perplexity: %.3f ' % np.mean(train_perplexities[-100:])) +
          ('vqvae loss: %.3f' % np.mean(train_vqvae_loss[-100:])))
  if step_index == num_training_updates:
    break


WARNING:tensorflow:AutoGraph could not transform <function train_step at 0x7f1016cb5f80> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: Unable to locate the source code of <function train_step at 0x7f1016cb5f80>. Note that functions defined in certain environments, like the interactive Python shell do not expose their source code. If that is the case, you should to define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.do_not_convert. Original error: could not get source code
WARNING:tensorflow:AutoGraph could not transform <function train_step at 0x7f1016cb5f80> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: Unable to locate the source code of <function train_step at 0x7f1016cb5f80>. Note that functions defined in certain environments, like the interactive Python shell do not expose their source code. If that is the case, you should to define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.do_not_convert. Original error: could not get source code
WARNING: AutoGraph could not transform <function train_step at 0x7f1016cb5f80> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: Unable to locate the source code of <function train_step at 0x7f1016cb5f80>. Note that functions defined in certain environments, like the interactive Python shell do not expose their source code. If that is the case, you should to define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.do_not_convert. Original error: could not get source code
WARNING:tensorflow:From /tmp/sonnet-nb-env/lib/python3.7/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1786: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
WARNING:tensorflow:From /tmp/sonnet-nb-env/lib/python3.7/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1786: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
100 train loss: 0.523625 recon_error: 0.483 perplexity: 10.356 vqvae loss: 0.041
200 train loss: 0.248232 recon_error: 0.223 perplexity: 18.294 vqvae loss: 0.026
300 train loss: 0.215068 recon_error: 0.190 perplexity: 23.106 vqvae loss: 0.025
400 train loss: 0.191891 recon_error: 0.164 perplexity: 29.139 vqvae loss: 0.028
500 train loss: 0.180945 recon_error: 0.147 perplexity: 34.253 vqvae loss: 0.033
600 train loss: 0.167115 recon_error: 0.134 perplexity: 39.961 vqvae loss: 0.033
700 train loss: 0.157724 recon_error: 0.124 perplexity: 46.521 vqvae loss: 0.033
800 train loss: 0.153761 recon_error: 0.119 perplexity: 53.559 vqvae loss: 0.035
900 train loss: 0.145033 recon_error: 0.112 perplexity: 62.442 vqvae loss: 0.033
1000 train loss: 0.137589 recon_error: 0.105 perplexity: 71.831 vqvae loss: 0.033
1100 train loss: 0.133044 recon_error: 0.101 perplexity: 79.135 vqvae loss: 0.032
1200 train loss: 0.129990 recon_error: 0.098 perplexity: 87.959 vqvae loss: 0.032
1300 train loss: 0.126507 recon_error: 0.095 perplexity: 96.704 vqvae loss: 0.031
1400 train loss: 0.122403 recon_error: 0.092 perplexity: 104.202 vqvae loss: 0.031
1500 train loss: 0.122003 recon_error: 0.091 perplexity: 112.476 vqvae loss: 0.031
1600 train loss: 0.120192 recon_error: 0.089 perplexity: 122.269 vqvae loss: 0.032
1700 train loss: 0.117041 recon_error: 0.086 perplexity: 129.887 vqvae loss: 0.031
1800 train loss: 0.115004 recon_error: 0.083 perplexity: 138.603 vqvae loss: 0.032
1900 train loss: 0.114134 recon_error: 0.082 perplexity: 147.545 vqvae loss: 0.032
2000 train loss: 0.112840 recon_error: 0.081 perplexity: 153.993 vqvae loss: 0.032
2100 train loss: 0.108815 recon_error: 0.077 perplexity: 161.729 vqvae loss: 0.031
2200 train loss: 0.108596 recon_error: 0.078 perplexity: 171.971 vqvae loss: 0.031
2300 train loss: 0.108132 recon_error: 0.077 perplexity: 181.157 vqvae loss: 0.031
2400 train loss: 0.106273 recon_error: 0.076 perplexity: 186.200 vqvae loss: 0.031
2500 train loss: 0.105936 recon_error: 0.075 perplexity: 194.301 vqvae loss: 0.031
2600 train loss: 0.103880 recon_error: 0.073 perplexity: 201.674 vqvae loss: 0.030
2700 train loss: 0.101655 recon_error: 0.072 perplexity: 207.131 vqvae loss: 0.030
2800 train loss: 0.102564 recon_error: 0.072 perplexity: 216.983 vqvae loss: 0.030
2900 train loss: 0.101613 recon_error: 0.072 perplexity: 219.649 vqvae loss: 0.030
3000 train loss: 0.101227 recon_error: 0.071 perplexity: 226.789 vqvae loss: 0.030
3100 train loss: 0.100786 recon_error: 0.071 perplexity: 235.522 vqvae loss: 0.030
3200 train loss: 0.100130 recon_error: 0.070 perplexity: 243.282 vqvae loss: 0.030
3300 train loss: 0.097764 recon_error: 0.067 perplexity: 249.584 vqvae loss: 0.030
3400 train loss: 0.100630 recon_error: 0.069 perplexity: 260.551 vqvae loss: 0.031
3500 train loss: 0.099929 recon_error: 0.068 perplexity: 266.012 vqvae loss: 0.032
3600 train loss: 0.099245 recon_error: 0.067 perplexity: 272.031 vqvae loss: 0.032
3700 train loss: 0.097812 recon_error: 0.066 perplexity: 279.691 vqvae loss: 0.032
3800 train loss: 0.097137 recon_error: 0.064 perplexity: 284.240 vqvae loss: 0.033
3900 train loss: 0.099217 recon_error: 0.066 perplexity: 293.507 vqvae loss: 0.034
4000 train loss: 0.098570 recon_error: 0.065 perplexity: 300.891 vqvae loss: 0.034
4100 train loss: 0.099238 recon_error: 0.065 perplexity: 306.762 vqvae loss: 0.034
4200 train loss: 0.098172 recon_error: 0.064 perplexity: 311.918 vqvae loss: 0.035
4300 train loss: 0.096449 recon_error: 0.063 perplexity: 316.246 vqvae loss: 0.034
4400 train loss: 0.096487 recon_error: 0.062 perplexity: 319.591 vqvae loss: 0.034
4500 train loss: 0.096092 recon_error: 0.062 perplexity: 322.313 vqvae loss: 0.034
4600 train loss: 0.096474 recon_error: 0.062 perplexity: 324.620 vqvae loss: 0.035
4700 train loss: 0.097075 recon_error: 0.063 perplexity: 324.357 vqvae loss: 0.035
4800 train loss: 0.094709 recon_error: 0.060 perplexity: 326.024 vqvae loss: 0.034
4900 train loss: 0.096557 recon_error: 0.061 perplexity: 327.701 vqvae loss: 0.035
5000 train loss: 0.096185 recon_error: 0.061 perplexity: 326.664 vqvae loss: 0.035
5100 train loss: 0.095646 recon_error: 0.060 perplexity: 327.617 vqvae loss: 0.035
5200 train loss: 0.094689 recon_error: 0.059 perplexity: 328.692 vqvae loss: 0.035
5300 train loss: 0.097047 recon_error: 0.061 perplexity: 327.988 vqvae loss: 0.036
5400 train loss: 0.096259 recon_error: 0.060 perplexity: 327.075 vqvae loss: 0.036
5500 train loss: 0.094588 recon_error: 0.059 perplexity: 327.083 vqvae loss: 0.036
5600 train loss: 0.095947 recon_error: 0.060 perplexity: 328.213 vqvae loss: 0.036
5700 train loss: 0.095466 recon_error: 0.059 perplexity: 329.375 vqvae loss: 0.036
5800 train loss: 0.094849 recon_error: 0.059 perplexity: 326.821 vqvae loss: 0.036
5900 train loss: 0.093799 recon_error: 0.058 perplexity: 328.409 vqvae loss: 0.036
6000 train loss: 0.095373 recon_error: 0.059 perplexity: 326.791 vqvae loss: 0.036
6100 train loss: 0.093989 recon_error: 0.059 perplexity: 325.959 vqvae loss: 0.035
6200 train loss: 0.095549 recon_error: 0.059 perplexity: 330.829 vqvae loss: 0.036
6300 train loss: 0.094730 recon_error: 0.058 perplexity: 330.906 vqvae loss: 0.036
6400 train loss: 0.095038 recon_error: 0.058 perplexity: 329.353 vqvae loss: 0.037
6500 train loss: 0.095891 recon_error: 0.059 perplexity: 330.197 vqvae loss: 0.037
6600 train loss: 0.094342 recon_error: 0.058 perplexity: 331.240 vqvae loss: 0.036
6700 train loss: 0.095096 recon_error: 0.058 perplexity: 330.618 vqvae loss: 0.037
6800 train loss: 0.095581 recon_error: 0.059 perplexity: 324.493 vqvae loss: 0.037
6900 train loss: 0.094467 recon_error: 0.058 perplexity: 328.868 vqvae loss: 0.037
7000 train loss: 0.092967 recon_error: 0.057 perplexity: 328.276 vqvae loss: 0.036
7100 train loss: 0.094339 recon_error: 0.058 perplexity: 327.318 vqvae loss: 0.037
7200 train loss: 0.095227 recon_error: 0.058 perplexity: 326.306 vqvae loss: 0.037
7300 train loss: 0.093832 recon_error: 0.057 perplexity: 328.262 vqvae loss: 0.037
7400 train loss: 0.093331 recon_error: 0.057 perplexity: 327.987 vqvae loss: 0.037
7500 train loss: 0.094718 recon_error: 0.058 perplexity: 328.948 vqvae loss: 0.037
7600 train loss: 0.094199 recon_error: 0.058 perplexity: 328.468 vqvae loss: 0.037
7700 train loss: 0.094603 recon_error: 0.058 perplexity: 327.501 vqvae loss: 0.037
7800 train loss: 0.092299 recon_error: 0.056 perplexity: 327.630 vqvae loss: 0.037
7900 train loss: 0.095228 recon_error: 0.058 perplexity: 329.946 vqvae loss: 0.037
8000 train loss: 0.094291 recon_error: 0.058 perplexity: 326.790 vqvae loss: 0.037
8100 train loss: 0.094481 recon_error: 0.057 perplexity: 328.667 vqvae loss: 0.037
8200 train loss: 0.093992 recon_error: 0.057 perplexity: 329.655 vqvae loss: 0.037
8300 train loss: 0.093976 recon_error: 0.057 perplexity: 323.950 vqvae loss: 0.037
8400 train loss: 0.093422 recon_error: 0.057 perplexity: 324.523 vqvae loss: 0.036
8500 train loss: 0.092898 recon_error: 0.056 perplexity: 325.402 vqvae loss: 0.037
8600 train loss: 0.094298 recon_error: 0.057 perplexity: 329.251 vqvae loss: 0.037
8700 train loss: 0.094489 recon_error: 0.057 perplexity: 331.027 vqvae loss: 0.037
8800 train loss: 0.093022 recon_error: 0.056 perplexity: 327.495 vqvae loss: 0.037
8900 train loss: 0.093427 recon_error: 0.057 perplexity: 328.008 vqvae loss: 0.037
9000 train loss: 0.094884 recon_error: 0.058 perplexity: 327.057 vqvae loss: 0.037
9100 train loss: 0.093559 recon_error: 0.056 perplexity: 331.800 vqvae loss: 0.037
9200 train loss: 0.093282 recon_error: 0.056 perplexity: 328.689 vqvae loss: 0.037
9300 train loss: 0.092217 recon_error: 0.056 perplexity: 323.903 vqvae loss: 0.036
9400 train loss: 0.093902 recon_error: 0.057 perplexity: 326.350 vqvae loss: 0.037
9500 train loss: 0.093772 recon_error: 0.057 perplexity: 325.627 vqvae loss: 0.037
9600 train loss: 0.093123 recon_error: 0.056 perplexity: 327.352 vqvae loss: 0.037
9700 train loss: 0.092934 recon_error: 0.056 perplexity: 328.674 vqvae loss: 0.037
9800 train loss: 0.093284 recon_error: 0.056 perplexity: 329.437 vqvae loss: 0.037
9900 train loss: 0.094147 recon_error: 0.057 perplexity: 330.146 vqvae loss: 0.037
10000 train loss: 0.092876 recon_error: 0.056 perplexity: 326.349 vqvae loss: 0.037
CPU times: user 1h 47min 46s, sys: 14min 12s, total: 2h 1min 59s
Wall time: 4min 29s

Plot loss


In [8]:
f = plt.figure(figsize=(16,8))
ax = f.add_subplot(1,2,1)
ax.plot(train_recon_errors)
ax.set_yscale('log')
ax.set_title('NMSE.')

ax = f.add_subplot(1,2,2)
ax.plot(train_perplexities)
ax.set_title('Average codebook usage (perplexity).')


Out[8]:
Text(0.5, 1.0, 'Average codebook usage (perplexity).')

View reconstructions


In [9]:
# Reconstructions
train_batch = next(iter(train_dataset))
valid_batch = next(iter(valid_dataset))

# Put data through the model with is_training=False, so that in the case of 
# using EMA the codebook is not updated.
train_reconstructions = model(train_batch['image'],
                              is_training=False)['x_recon'].numpy()
valid_reconstructions = model(valid_batch['image'],
                              is_training=False)['x_recon'].numpy()


def convert_batch_to_image_grid(image_batch):
  reshaped = (image_batch.reshape(4, 8, 32, 32, 3)
              .transpose(0, 2, 1, 3, 4)
              .reshape(4 * 32, 8 * 32, 3))
  return reshaped + 0.5



f = plt.figure(figsize=(16,8))
ax = f.add_subplot(2,2,1)
ax.imshow(convert_batch_to_image_grid(train_batch['image'].numpy()),
          interpolation='nearest')
ax.set_title('training data originals')
plt.axis('off')

ax = f.add_subplot(2,2,2)
ax.imshow(convert_batch_to_image_grid(train_reconstructions),
          interpolation='nearest')
ax.set_title('training data reconstructions')
plt.axis('off')

ax = f.add_subplot(2,2,3)
ax.imshow(convert_batch_to_image_grid(valid_batch['image'].numpy()),
          interpolation='nearest')
ax.set_title('validation data originals')
plt.axis('off')

ax = f.add_subplot(2,2,4)
ax.imshow(convert_batch_to_image_grid(valid_reconstructions),
          interpolation='nearest')
ax.set_title('validation data reconstructions')
plt.axis('off')


WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Out[9]:
(-0.5, 255.5, 127.5, -0.5)