In [0]:
#@title ##### Licensed under the Apache License, Version 2.0 (the "License"); { display-mode: "form" }
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Variational Autoencoder


1 Imports


In [0]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import sys
import time

import numpy as np
import matplotlib.pyplot as plt

import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()

import tensorflow_datasets as tfds
import tensorflow_probability as tfp

# Globally Enable XLA.
# tf.config.optimizer.set_jit(True)

try:
  physical_devices = tf.config.list_physical_devices('GPU')
  tf.config.experimental.set_memory_growth(physical_devices[0], True)
except (IndexError, ValueError, RuntimeError):
  # No GPU present, invalid device, or virtual devices already initialized.
  pass

tfb = tfp.bijectors
tfd = tfp.distributions
tfn = tfp.experimental.nn

2 Load Dataset


In [0]:
[train_dataset, eval_dataset], datasets_info = tfds.load(
    name='binarized_mnist',
    split=['train', 'test'],
    with_info=True,
    shuffle_files=True)

def _preprocess(sample):
  return tf.cast(sample['image'], tf.float32)

train_size = datasets_info.splits['train'].num_examples
batch_size = 32

train_dataset = tfn.util.tune_dataset(
    train_dataset,
    batch_size=batch_size,
    shuffle_size=int(train_size / 7),
    preprocess_fn=_preprocess)

eval_dataset = tfn.util.tune_dataset(
    eval_dataset,
    repeat_count=1,
    preprocess_fn=_preprocess)

x = next(iter(eval_dataset.batch(10)))
tfn.util.display_imgs(x);
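
Under the hood, tfn.util.tune_dataset presumably chains the standard tf.data
transformations. The cell below sketches a hand-rolled equivalent of the
training pipeline under that assumption; the real helper's defaults and
ordering may differ.

In [0]:
# A sketch of what tune_dataset does for the training split, assuming the
# usual map/shuffle/repeat/batch/prefetch chain.
raw_train_dataset = tfds.load(name='binarized_mnist', split='train')
manual_train_dataset = (
    raw_train_dataset
    .map(_preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .shuffle(int(train_size / 7))
    .repeat()          # The training loop below draws ~50k batches.
    .batch(batch_size)
    .prefetch(tf.data.experimental.AUTOTUNE))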


3 Define Model


In [0]:
input_shape = datasets_info.features['image'].shape
encoded_size = 16
base_depth = 32

In [0]:
prior = tfd.Sample(tfd.Normal(loc=0, scale=1), sample_shape=encoded_size)
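
The prior is a single 16-dimensional standard normal: tfd.Sample wraps the
scalar Normal so the 16 dimensions form one event rather than a batch of
independent distributions. A quick shape check makes this concrete.

In [0]:
# event_shape is [16] and batch_shape is [], so log_prob of a [batch, 16]
# tensor returns one scalar per example.
print(prior.event_shape)                      # [16]
print(prior.sample(2).shape)                  # (2, 16)
print(prior.log_prob(prior.sample(3)).shape)  # (3,)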

In [0]:
Conv = functools.partial(
    tfn.Convolution,
    init_kernel_fn=tf.initializers.he_uniform())  # Better for leaky_relu.

encoder = tfn.Sequential([
    lambda x: 2. * tf.cast(x, tf.float32) - 1.,  # Center.
    Conv(1, 1 * base_depth, 5, strides=1, padding='same'),
    tf.nn.leaky_relu,
    Conv(1 * base_depth, 1 * base_depth, 5, strides=2, padding='same'),
    tf.nn.leaky_relu,
    Conv(1 * base_depth, 2 * base_depth, 5, strides=1, padding='same'),
    tf.nn.leaky_relu,
    Conv(2 * base_depth, 2 * base_depth, 5, strides=2, padding='same'),
    tf.nn.leaky_relu,
    Conv(2 * base_depth, 4 * encoded_size, 7, strides=1, padding='valid'),
    tf.nn.leaky_relu,
    tfn.util.flatten_rightmost(ndims=3),
    tfn.Affine(64, encoded_size + encoded_size * (encoded_size + 1) // 2),
    lambda x: tfd.MultivariateNormalTriL(
        loc=x[..., :encoded_size],
        scale_tril=tfb.FillScaleTriL()(x[..., encoded_size:]))
], name='encoder')

print(encoder.summary())


=== encoder ==================================================
  SIZE SHAPE                TRAIN NAME                                    
    32 [32]                 True  bias:0                                  
   800 [5, 5, 1, 32]        True  kernel:0                                
    32 [32]                 True  bias:0                                  
 25600 [5, 5, 32, 32]       True  kernel:0                                
    64 [64]                 True  bias:0                                  
 51200 [5, 5, 32, 64]       True  kernel:0                                
    64 [64]                 True  bias:0                                  
102400 [5, 5, 64, 64]       True  kernel:0                                
    64 [64]                 True  bias:0                                  
200704 [7, 7, 64, 64]       True  kernel:0                                
   152 [152]                True  bias:0                                  
  9728 [64, 152]            True  kernel:0                                
trainable size: 390840  /  1.491 MiB  /  {float32: 390840}
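
The 152 units of the final Affine layer split into encoded_size = 16 entries
for the posterior mean plus 16 * 17 / 2 = 136 entries for the lower-triangular
scale factor that FillScaleTriL assembles.

In [0]:
# 16 means + 136 lower-triangular scale entries = 152 units, matching the
# [152] bias row in the summary above.
assert encoded_size + encoded_size * (encoded_size + 1) // 2 == 152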

In [0]:
DeConv = functools.partial(
    tfn.ConvolutionTranspose,
    init_kernel_fn=tf.initializers.he_uniform())  # Better for leaky_relu.
    
decoder = tfn.Sequential([
    lambda x: x[..., tf.newaxis, tf.newaxis, :],
    DeConv(encoded_size, 2 * base_depth, 7, strides=1, padding='valid'),
    tf.nn.leaky_relu,
    DeConv(2 * base_depth, 2 * base_depth, 5, strides=1, padding='same'),
    tf.nn.leaky_relu,
    DeConv(2 * base_depth, 2 * base_depth, 5, strides=2, padding='same'),
    tf.nn.leaky_relu,
    DeConv(2 * base_depth, 1 * base_depth, 5, strides=1, padding='same'),
    tf.nn.leaky_relu,
    DeConv(1 * base_depth, 1 * base_depth, 5, strides=2, padding='same'),
    tf.nn.leaky_relu,
    DeConv(1 * base_depth, 1 * base_depth, 5, strides=1, padding='same'),
    tf.nn.leaky_relu,
    Conv(1 * base_depth, 1, 5, strides=1, padding='same'),
    tfn.util.flatten_rightmost(ndims=3),
    tfb.Reshape(input_shape),
    lambda x: tfd.Independent(tfd.Bernoulli(logits=x),
                              reinterpreted_batch_ndims=len(input_shape)),
], name='decoder')

print(decoder.summary())


=== decoder ==================================================
  SIZE SHAPE                TRAIN NAME                                    
    64 [64]                 True  bias:0                                  
 50176 [7, 7, 64, 16]       True  kernel:0                                
    64 [64]                 True  bias:0                                  
102400 [5, 5, 64, 64]       True  kernel:0                                
    64 [64]                 True  bias:0                                  
102400 [5, 5, 64, 64]       True  kernel:0                                
    32 [32]                 True  bias:0                                  
 51200 [5, 5, 32, 64]       True  kernel:0                                
    32 [32]                 True  bias:0                                  
 25600 [5, 5, 32, 32]       True  kernel:0                                
    32 [32]                 True  bias:0                                  
 25600 [5, 5, 32, 32]       True  kernel:0                                
     1 [1]                  True  bias:0                                  
   800 [5, 5, 32, 1]        True  kernel:0                                
trainable size: 358465  /  1.367 MiB  /  {float32: 358465}
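
As a sanity check, decoding draws from the prior should produce a pixel-wise
Bernoulli whose event shape matches the 28x28x1 input images.

In [0]:
# Decode a few latent draws and confirm the output distribution's shapes.
z = prior.sample(3)
p = decoder(z)
print(p.event_shape)     # [28, 28, 1]
print(p.sample().shape)  # (3, 28, 28, 1)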

4 Loss / Eval


In [0]:
def compute_loss(x, beta=1.):
  q = encoder(x)
  z = q.sample()
  p = decoder(z)
  kl = tf.reduce_mean(q.log_prob(z) - prior.log_prob(z), axis=-1)
  # Note: we could use the exact KL divergence instead, e.g.:
  #   kl = tf.reduce_mean(tfd.kl_divergence(q, prior))
  # but we generally find that the single-sample Monte Carlo approximation
  # yields lower-variance gradients in practice.
  nll = -tf.reduce_mean(p.log_prob(x), axis=-1)
  loss = nll + beta * kl
  return loss, (nll, kl), (q, z, p)
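
In symbols, compute_loss returns a single-sample Monte Carlo estimate of the
beta-weighted negative ELBO, averaged over the batch:

$$\mathcal{L}(x) = -\mathbb{E}_{z \sim q(z \mid x)}\big[\log p(x \mid z)\big] + \beta\,\mathrm{KL}\big(q(z \mid x) \,\|\, p(z)\big) \approx -\log p(x \mid z) + \beta \big(\log q(z \mid x) - \log p(z)\big), \quad z \sim q(z \mid x).$$

With beta = 1 this is the standard VAE objective; beta > 1 gives the beta-VAE
variant, which trades reconstruction quality for a more strongly regularized
latent code.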

In [0]:
train_iter = iter(train_dataset)

def loss():
  x = next(train_iter)
  loss, (nll, kl), _ = compute_loss(x)
  return loss, (nll, kl)

opt = tf.optimizers.Adam(learning_rate=1e-3)

fit = tfn.util.make_fit_op(
    loss,
    opt,
    decoder.trainable_variables + encoder.trainable_variables,
    grad_summary_fn=lambda gs: tf.nest.map_structure(tf.norm, gs))
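
tfn.util.make_fit_op presumably compiles a single optimizer step. The cell
below is a hypothetical hand-rolled equivalent built on tf.GradientTape; the
real helper may do extra bookkeeping, but this shows where grad_summary_fn
plugs in.

In [0]:
trainable = decoder.trainable_variables + encoder.trainable_variables

@tf.function
def manual_fit():
  # One step: differentiate the scalar loss w.r.t. all weights, apply Adam,
  # and report per-variable gradient norms (the grad_summary_fn above).
  with tf.GradientTape() as tape:
    loss_value, extra = loss()
  grads = tape.gradient(loss_value, trainable)
  opt.apply_gradients(zip(grads, trainable))
  return loss_value, extra, [tf.norm(g) for g in grads]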

In [0]:
eval_iter = iter(eval_dataset.batch(5000).repeat())

@tfn.util.tfcompile
def evaluate():  # Named `evaluate` to avoid shadowing the Python builtin.
  x = next(eval_iter)
  loss, (nll, kl), _ = compute_loss(x)
  return loss, (nll, kl)

5 Train


In [0]:
DEBUG_MODE = False
# When True, tf.functions run eagerly so each step can be inspected in a
# debugger (at a substantial speed cost).
tf.config.experimental_run_functions_eagerly(DEBUG_MODE)

In [0]:
num_train_epochs = 1.  # @param { isTemplate: true}
num_evals = 200        # @param { isTemplate: true}

dur_sec = dur_num = 0
# Each step consumes one batch, so `train_size` steps amount to `batch_size`
# full passes over the training set per nominal epoch.
num_train_steps = int(num_train_epochs * train_size)
for i in range(num_train_steps):
  start = time.time()
  trn_loss, (trn_nll, trn_kl), g = fit()
  stop = time.time()
  dur_sec += stop - start
  dur_num += 1
  if i % int(num_train_steps / num_evals) == 0 or i == num_train_steps - 1:
    tst_loss, (tst_nll, tst_kl) = evaluate()
    f, x = zip(*[
        ('it:{:5}', opt.iterations),
        ('ms/it:{:6.4f}', dur_sec / max(1., dur_num) * 1000.),
        ('trn_loss:{:6.4f}', trn_loss),
        ('tst_loss:{:6.4f}', tst_loss),
        ('tst_nll:{:6.4f}', tst_nll),
        ('tst_kl:{:6.4f}', tst_kl),
        ('sum_norm_grad:{:6.4f}', sum(g)),
    ])
    print('   '.join(f).format(*[getattr(x_, 'numpy', lambda: x_)()
                                 for x_ in x]))
    sys.stdout.flush()
    dur_sec = dur_num = 0
  # if i % 1000 == 0 or i == num_train_steps - 1:
  #   encoder.save('/tmp/encoder.npz')
  #   decoder.save('/tmp/decoder.npz')


it:    1   ms/it:8167.5928   trn_loss:571.6944   tst_loss:542.9343   tst_nll:532.3325   tst_kl:10.6018   sum_norm_grad:2056.6118
it:  251   ms/it:7.7388   trn_loss:188.3120   tst_loss:174.4140   tst_nll:166.6604   tst_kl:7.7536   sum_norm_grad:561.9462
it:  501   ms/it:7.7590   trn_loss:148.8059   tst_loss:148.2165   tst_nll:137.6826   tst_kl:10.5339   sum_norm_grad:796.4335
it:  751   ms/it:7.7292   trn_loss:141.5680   tst_loss:142.2636   tst_nll:131.5428   tst_kl:10.7208   sum_norm_grad:490.8536
it: 1001   ms/it:7.7422   trn_loss:130.4454   tst_loss:135.6136   tst_nll:125.0441   tst_kl:10.5695   sum_norm_grad:442.8789
it: 1251   ms/it:7.7527   trn_loss:137.6278   tst_loss:132.9517   tst_nll:121.3120   tst_kl:11.6397   sum_norm_grad:442.4716
it: 1501   ms/it:7.7564   trn_loss:130.9861   tst_loss:130.4965   tst_nll:119.9839   tst_kl:10.5125   sum_norm_grad:524.5187
it: 1751   ms/it:7.7942   trn_loss:123.1555   tst_loss:122.9402   tst_nll:110.2459   tst_kl:12.6943   sum_norm_grad:551.7631
it: 2001   ms/it:7.7790   trn_loss:123.9931   tst_loss:120.8062   tst_nll:107.7721   tst_kl:13.0341   sum_norm_grad:511.3472
it: 2251   ms/it:7.8637   trn_loss:103.8219   tst_loss:123.7093   tst_nll:108.0471   tst_kl:15.6622   sum_norm_grad:464.8754
it: 2501   ms/it:7.8613   trn_loss:123.9166   tst_loss:118.8462   tst_nll:104.1422   tst_kl:14.7040   sum_norm_grad:572.0829
it: 2751   ms/it:7.8550   trn_loss:126.4342   tst_loss:116.0461   tst_nll:101.6044   tst_kl:14.4417   sum_norm_grad:638.6761
it: 3001   ms/it:7.8510   trn_loss:110.7648   tst_loss:116.2834   tst_nll:102.1414   tst_kl:14.1420   sum_norm_grad:367.6692
it: 3251   ms/it:7.8399   trn_loss:114.0675   tst_loss:115.1729   tst_nll:100.2288   tst_kl:14.9441   sum_norm_grad:368.8145
it: 3501   ms/it:7.8515   trn_loss:110.6161   tst_loss:114.5887   tst_nll:98.9514   tst_kl:15.6373   sum_norm_grad:402.8748
it: 3751   ms/it:7.8610   trn_loss:111.0085   tst_loss:114.3741   tst_nll:100.2768   tst_kl:14.0973   sum_norm_grad:401.1852
it: 4001   ms/it:7.8447   trn_loss:106.8356   tst_loss:112.8442   tst_nll:97.0991   tst_kl:15.7451   sum_norm_grad:409.7707
it: 4251   ms/it:7.8881   trn_loss:115.3077   tst_loss:114.2563   tst_nll:97.9633   tst_kl:16.2930   sum_norm_grad:629.7654
it: 4501   ms/it:7.9070   trn_loss:127.3311   tst_loss:112.6550   tst_nll:97.2963   tst_kl:15.3586   sum_norm_grad:490.1459
it: 4751   ms/it:7.9127   trn_loss:124.0865   tst_loss:113.1525   tst_nll:97.0626   tst_kl:16.0898   sum_norm_grad:509.3741
it: 5001   ms/it:7.9179   trn_loss:109.1007   tst_loss:110.2191   tst_nll:94.4971   tst_kl:15.7220   sum_norm_grad:397.1542
it: 5251   ms/it:7.9178   trn_loss:104.5153   tst_loss:111.0835   tst_nll:94.6629   tst_kl:16.4207   sum_norm_grad:428.8897
it: 5501   ms/it:7.8908   trn_loss:104.2857   tst_loss:109.6805   tst_nll:93.3886   tst_kl:16.2920   sum_norm_grad:415.7182
it: 5751   ms/it:7.9090   trn_loss:115.1790   tst_loss:109.1490   tst_nll:93.1569   tst_kl:15.9920   sum_norm_grad:432.4279
it: 6001   ms/it:7.9119   trn_loss:104.6124   tst_loss:107.9034   tst_nll:90.9999   tst_kl:16.9036   sum_norm_grad:493.7137
it: 6251   ms/it:7.9066   trn_loss:99.7216   tst_loss:106.7395   tst_nll:89.0871   tst_kl:17.6524   sum_norm_grad:556.7726
it: 6501   ms/it:7.9099   trn_loss:105.5314   tst_loss:106.6788   tst_nll:89.3161   tst_kl:17.3627   sum_norm_grad:358.8821
it: 6751   ms/it:7.9106   trn_loss:108.3402   tst_loss:107.4677   tst_nll:90.2000   tst_kl:17.2677   sum_norm_grad:420.5543
it: 7001   ms/it:7.9042   trn_loss:117.6976   tst_loss:106.7938   tst_nll:89.3268   tst_kl:17.4670   sum_norm_grad:859.7291
it: 7251   ms/it:7.9174   trn_loss:100.5208   tst_loss:107.8533   tst_nll:90.8074   tst_kl:17.0460   sum_norm_grad:415.5740
it: 7501   ms/it:7.9589   trn_loss:105.1086   tst_loss:105.6644   tst_nll:88.7536   tst_kl:16.9108   sum_norm_grad:578.0068
it: 7751   ms/it:7.9450   trn_loss:109.4627   tst_loss:107.2615   tst_nll:89.6180   tst_kl:17.6435   sum_norm_grad:420.5908
it: 8001   ms/it:7.9866   trn_loss:111.7679   tst_loss:105.2662   tst_nll:87.2149   tst_kl:18.0514   sum_norm_grad:509.2001
it: 8251   ms/it:7.9858   trn_loss:108.6847   tst_loss:105.5514   tst_nll:87.9637   tst_kl:17.5877   sum_norm_grad:438.8681
it: 8501   ms/it:7.9721   trn_loss:103.6953   tst_loss:105.0167   tst_nll:88.0758   tst_kl:16.9409   sum_norm_grad:424.3508
it: 8751   ms/it:7.9588   trn_loss:95.5592   tst_loss:104.7522   tst_nll:87.5408   tst_kl:17.2114   sum_norm_grad:389.7538
it: 9001   ms/it:7.9823   trn_loss:98.4096   tst_loss:106.6480   tst_nll:88.3408   tst_kl:18.3072   sum_norm_grad:427.1542
it: 9251   ms/it:7.9650   trn_loss:93.9016   tst_loss:103.6268   tst_nll:85.6080   tst_kl:18.0189   sum_norm_grad:446.4226
it: 9501   ms/it:7.9577   trn_loss:101.2909   tst_loss:102.8728   tst_nll:85.1667   tst_kl:17.7061   sum_norm_grad:457.3659
it: 9751   ms/it:7.9687   trn_loss:107.5655   tst_loss:103.1900   tst_nll:85.4349   tst_kl:17.7550   sum_norm_grad:505.9274
it:10001   ms/it:7.9752   trn_loss:110.6296   tst_loss:104.9651   tst_nll:86.3768   tst_kl:18.5883   sum_norm_grad:417.9517
it:10251   ms/it:7.9542   trn_loss:98.3785   tst_loss:103.7080   tst_nll:85.0731   tst_kl:18.6348   sum_norm_grad:404.2544
it:10501   ms/it:8.0073   trn_loss:104.4330   tst_loss:102.5719   tst_nll:85.2622   tst_kl:17.3097   sum_norm_grad:429.2713
it:10751   ms/it:8.2205   trn_loss:100.1614   tst_loss:103.6863   tst_nll:84.6275   tst_kl:19.0587   sum_norm_grad:473.0844
it:11001   ms/it:8.2890   trn_loss:101.6195   tst_loss:102.5362   tst_nll:84.9307   tst_kl:17.6055   sum_norm_grad:398.6784
it:11251   ms/it:8.3585   trn_loss:107.2305   tst_loss:102.2765   tst_nll:84.2955   tst_kl:17.9810   sum_norm_grad:397.1293
it:11501   ms/it:8.5587   trn_loss:103.8816   tst_loss:102.7018   tst_nll:83.9922   tst_kl:18.7096   sum_norm_grad:515.5089
it:11751   ms/it:8.2993   trn_loss:102.9577   tst_loss:102.9828   tst_nll:85.3075   tst_kl:17.6753   sum_norm_grad:481.9806
it:12001   ms/it:8.2880   trn_loss:102.1254   tst_loss:101.5072   tst_nll:82.8090   tst_kl:18.6982   sum_norm_grad:720.7101
it:12251   ms/it:8.2124   trn_loss:93.5831   tst_loss:103.6871   tst_nll:86.1490   tst_kl:17.5381   sum_norm_grad:483.7248
it:12501   ms/it:8.1652   trn_loss:106.4728   tst_loss:102.3646   tst_nll:84.6786   tst_kl:17.6860   sum_norm_grad:710.8688
it:12751   ms/it:8.1655   trn_loss:95.7906   tst_loss:103.7332   tst_nll:84.8222   tst_kl:18.9110   sum_norm_grad:427.0826
it:13001   ms/it:8.1497   trn_loss:103.2529   tst_loss:101.0650   tst_nll:82.8124   tst_kl:18.2526   sum_norm_grad:358.3357
it:13251   ms/it:8.1346   trn_loss:104.1242   tst_loss:101.2779   tst_nll:82.7441   tst_kl:18.5338   sum_norm_grad:462.1422
it:13501   ms/it:8.0507   trn_loss:102.6927   tst_loss:100.8160   tst_nll:82.7331   tst_kl:18.0828   sum_norm_grad:449.6999
it:13751   ms/it:7.9958   trn_loss:86.5824   tst_loss:100.8928   tst_nll:82.9390   tst_kl:17.9538   sum_norm_grad:326.1215
it:14001   ms/it:8.0597   trn_loss:101.3285   tst_loss:101.2132   tst_nll:82.9017   tst_kl:18.3115   sum_norm_grad:375.9016
it:14251   ms/it:8.0435   trn_loss:91.6417   tst_loss:101.7879   tst_nll:83.7527   tst_kl:18.0352   sum_norm_grad:413.0688
it:14501   ms/it:8.0374   trn_loss:91.9398   tst_loss:102.1389   tst_nll:83.9429   tst_kl:18.1960   sum_norm_grad:328.8266
it:14751   ms/it:8.0733   trn_loss:106.3819   tst_loss:101.8671   tst_nll:84.0552   tst_kl:17.8119   sum_norm_grad:509.4388
it:15001   ms/it:8.0907   trn_loss:102.1323   tst_loss:100.7362   tst_nll:82.7936   tst_kl:17.9425   sum_norm_grad:417.7497
it:15251   ms/it:8.0734   trn_loss:102.3906   tst_loss:101.5540   tst_nll:82.8907   tst_kl:18.6633   sum_norm_grad:405.2325
it:15501   ms/it:8.0804   trn_loss:103.5540   tst_loss:100.2225   tst_nll:81.6333   tst_kl:18.5892   sum_norm_grad:442.2773
it:15751   ms/it:8.0910   trn_loss:98.5541   tst_loss:102.5174   tst_nll:83.2855   tst_kl:19.2320   sum_norm_grad:586.1303
it:16001   ms/it:8.1574   trn_loss:104.2341   tst_loss:101.2050   tst_nll:82.7563   tst_kl:18.4488   sum_norm_grad:641.9893
it:16251   ms/it:8.2550   trn_loss:99.4588   tst_loss:101.4318   tst_nll:82.1074   tst_kl:19.3244   sum_norm_grad:544.8028
it:16501   ms/it:8.2738   trn_loss:101.8940   tst_loss:99.7064   tst_nll:81.0831   tst_kl:18.6233   sum_norm_grad:401.9953
it:16751   ms/it:8.2277   trn_loss:93.5180   tst_loss:101.1547   tst_nll:81.5737   tst_kl:19.5810   sum_norm_grad:456.8716
it:17001   ms/it:8.1098   trn_loss:105.9864   tst_loss:100.3267   tst_nll:81.3302   tst_kl:18.9965   sum_norm_grad:480.0489
it:17251   ms/it:8.1619   trn_loss:89.7933   tst_loss:100.5733   tst_nll:81.6203   tst_kl:18.9530   sum_norm_grad:349.7754
it:17501   ms/it:8.1657   trn_loss:99.5181   tst_loss:99.7858   tst_nll:80.0372   tst_kl:19.7486   sum_norm_grad:675.3879
it:17751   ms/it:8.1532   trn_loss:90.1636   tst_loss:100.5669   tst_nll:80.6271   tst_kl:19.9398   sum_norm_grad:464.5888
it:18001   ms/it:8.1601   trn_loss:87.1849   tst_loss:100.3409   tst_nll:81.4788   tst_kl:18.8621   sum_norm_grad:548.5415
it:18251   ms/it:8.1596   trn_loss:96.3656   tst_loss:99.7642   tst_nll:80.4470   tst_kl:19.3172   sum_norm_grad:568.6088
it:18501   ms/it:8.1601   trn_loss:96.4203   tst_loss:98.6329   tst_nll:79.5682   tst_kl:19.0647   sum_norm_grad:556.0184
it:18751   ms/it:8.1582   trn_loss:99.5522   tst_loss:100.0960   tst_nll:81.1789   tst_kl:18.9171   sum_norm_grad:375.0641
it:19001   ms/it:8.1543   trn_loss:92.3773   tst_loss:99.7789   tst_nll:80.2575   tst_kl:19.5214   sum_norm_grad:395.3218
it:19251   ms/it:8.0936   trn_loss:95.3878   tst_loss:99.0224   tst_nll:79.6327   tst_kl:19.3897   sum_norm_grad:482.4856
it:19501   ms/it:8.1621   trn_loss:100.4647   tst_loss:99.2252   tst_nll:78.5188   tst_kl:20.7064   sum_norm_grad:479.1700
it:19751   ms/it:8.1458   trn_loss:95.9030   tst_loss:99.7306   tst_nll:80.2872   tst_kl:19.4433   sum_norm_grad:504.3706
it:20001   ms/it:8.1910   trn_loss:95.3447   tst_loss:98.6879   tst_nll:79.9386   tst_kl:18.7493   sum_norm_grad:430.0017
it:20251   ms/it:8.1732   trn_loss:98.5342   tst_loss:98.9194   tst_nll:78.6354   tst_kl:20.2839   sum_norm_grad:449.0607
it:20501   ms/it:8.2234   trn_loss:100.6554   tst_loss:97.9047   tst_nll:77.9066   tst_kl:19.9981   sum_norm_grad:409.5266
it:20751   ms/it:8.1860   trn_loss:97.2872   tst_loss:98.0931   tst_nll:78.7699   tst_kl:19.3232   sum_norm_grad:423.0030
it:21001   ms/it:8.1695   trn_loss:99.2308   tst_loss:97.3468   tst_nll:78.1596   tst_kl:19.1872   sum_norm_grad:720.2530
it:21251   ms/it:8.1776   trn_loss:92.1392   tst_loss:98.5678   tst_nll:79.1944   tst_kl:19.3734   sum_norm_grad:408.9140
it:21501   ms/it:8.1605   trn_loss:98.8631   tst_loss:98.4129   tst_nll:78.8784   tst_kl:19.5345   sum_norm_grad:627.6019
it:21751   ms/it:8.1882   trn_loss:97.7209   tst_loss:98.1413   tst_nll:78.8780   tst_kl:19.2634   sum_norm_grad:481.7331
it:22001   ms/it:8.1737   trn_loss:93.9814   tst_loss:97.0028   tst_nll:77.0947   tst_kl:19.9081   sum_norm_grad:399.8244
it:22251   ms/it:8.1458   trn_loss:103.6036   tst_loss:98.1384   tst_nll:78.0249   tst_kl:20.1135   sum_norm_grad:481.9254
it:22501   ms/it:8.1774   trn_loss:94.8821   tst_loss:97.6829   tst_nll:77.5557   tst_kl:20.1272   sum_norm_grad:492.4880
it:22751   ms/it:8.1611   trn_loss:94.7568   tst_loss:98.7982   tst_nll:77.4158   tst_kl:21.3823   sum_norm_grad:551.9900
it:23001   ms/it:8.2206   trn_loss:101.1991   tst_loss:97.1871   tst_nll:76.1145   tst_kl:21.0726   sum_norm_grad:490.0321
it:23251   ms/it:8.2294   trn_loss:94.7145   tst_loss:97.1758   tst_nll:76.8328   tst_kl:20.3430   sum_norm_grad:488.6527
it:23501   ms/it:8.2576   trn_loss:106.8952   tst_loss:96.7854   tst_nll:76.9890   tst_kl:19.7964   sum_norm_grad:550.9918
it:23751   ms/it:8.2320   trn_loss:88.6979   tst_loss:97.4744   tst_nll:76.2718   tst_kl:21.2026   sum_norm_grad:588.6791
it:24001   ms/it:8.2377   trn_loss:99.6044   tst_loss:98.1709   tst_nll:77.6664   tst_kl:20.5045   sum_norm_grad:404.7585
it:24251   ms/it:8.2496   trn_loss:95.8460   tst_loss:98.2426   tst_nll:78.6304   tst_kl:19.6122   sum_norm_grad:586.2623
it:24501   ms/it:8.2374   trn_loss:105.8137   tst_loss:97.1946   tst_nll:76.9831   tst_kl:20.2115   sum_norm_grad:486.7053
it:24751   ms/it:8.1393   trn_loss:93.5827   tst_loss:96.9893   tst_nll:77.2875   tst_kl:19.7019   sum_norm_grad:502.4047
it:25001   ms/it:8.0884   trn_loss:93.0325   tst_loss:96.9008   tst_nll:77.3248   tst_kl:19.5760   sum_norm_grad:614.4216
it:25251   ms/it:8.1797   trn_loss:105.5030   tst_loss:97.0390   tst_nll:76.9735   tst_kl:20.0654   sum_norm_grad:451.8987
it:25501   ms/it:8.2334   trn_loss:90.2355   tst_loss:96.4251   tst_nll:76.2132   tst_kl:20.2119   sum_norm_grad:414.7156
it:25751   ms/it:8.2680   trn_loss:96.4100   tst_loss:97.2122   tst_nll:76.7706   tst_kl:20.4416   sum_norm_grad:414.3313
it:26001   ms/it:8.2876   trn_loss:92.2220   tst_loss:96.9308   tst_nll:76.2899   tst_kl:20.6410   sum_norm_grad:454.7311
it:26251   ms/it:8.2632   trn_loss:107.7578   tst_loss:97.0452   tst_nll:76.9183   tst_kl:20.1269   sum_norm_grad:753.7256
it:26501   ms/it:8.2665   trn_loss:94.1039   tst_loss:97.0599   tst_nll:76.2740   tst_kl:20.7859   sum_norm_grad:430.0159
it:26751   ms/it:8.2221   trn_loss:93.9554   tst_loss:96.8069   tst_nll:76.8602   tst_kl:19.9467   sum_norm_grad:694.3672
it:27001   ms/it:8.0894   trn_loss:97.2534   tst_loss:96.5289   tst_nll:75.7308   tst_kl:20.7981   sum_norm_grad:569.3329
it:27251   ms/it:8.1047   trn_loss:98.7436   tst_loss:98.1664   tst_nll:78.3807   tst_kl:19.7857   sum_norm_grad:512.0961
it:27501   ms/it:8.0659   trn_loss:93.8678   tst_loss:96.9302   tst_nll:76.6646   tst_kl:20.2656   sum_norm_grad:550.1595
it:27751   ms/it:8.1921   trn_loss:92.3248   tst_loss:96.4746   tst_nll:76.6164   tst_kl:19.8582   sum_norm_grad:515.1254
it:28001   ms/it:8.2133   trn_loss:86.0317   tst_loss:95.7819   tst_nll:74.8646   tst_kl:20.9173   sum_norm_grad:441.3182
it:28251   ms/it:8.2641   trn_loss:96.6930   tst_loss:96.8639   tst_nll:76.0705   tst_kl:20.7934   sum_norm_grad:653.7979
it:28501   ms/it:8.2619   trn_loss:97.0918   tst_loss:95.5800   tst_nll:75.3970   tst_kl:20.1830   sum_norm_grad:423.9479
it:28751   ms/it:8.2592   trn_loss:100.7624   tst_loss:97.2540   tst_nll:76.5098   tst_kl:20.7442   sum_norm_grad:451.6195
it:29001   ms/it:8.2559   trn_loss:97.7655   tst_loss:96.7856   tst_nll:76.4287   tst_kl:20.3570   sum_norm_grad:636.2123
it:29251   ms/it:8.2444   trn_loss:99.4683   tst_loss:96.2001   tst_nll:75.3047   tst_kl:20.8954   sum_norm_grad:532.6425
it:29501   ms/it:8.2219   trn_loss:103.1982   tst_loss:95.8101   tst_nll:75.4957   tst_kl:20.3145   sum_norm_grad:737.2990
it:29751   ms/it:8.2706   trn_loss:88.1869   tst_loss:95.5334   tst_nll:75.0396   tst_kl:20.4939   sum_norm_grad:435.1315
it:30001   ms/it:8.2572   trn_loss:88.9855   tst_loss:95.1945   tst_nll:74.7884   tst_kl:20.4061   sum_norm_grad:532.7263
it:30251   ms/it:8.2379   trn_loss:93.3942   tst_loss:95.9549   tst_nll:75.8224   tst_kl:20.1325   sum_norm_grad:537.1184
it:30501   ms/it:8.1763   trn_loss:95.9863   tst_loss:95.2355   tst_nll:74.6904   tst_kl:20.5450   sum_norm_grad:443.2344
it:30751   ms/it:8.1587   trn_loss:103.2956   tst_loss:95.7448   tst_nll:74.8156   tst_kl:20.9292   sum_norm_grad:639.9230
it:31001   ms/it:8.2373   trn_loss:99.6281   tst_loss:96.0330   tst_nll:75.3437   tst_kl:20.6892   sum_norm_grad:509.9691
it:31251   ms/it:8.2642   trn_loss:93.2277   tst_loss:96.5316   tst_nll:74.8011   tst_kl:21.7304   sum_norm_grad:565.1442
it:31501   ms/it:8.2399   trn_loss:98.8137   tst_loss:94.8659   tst_nll:74.2961   tst_kl:20.5699   sum_norm_grad:487.9633
it:31751   ms/it:8.2390   trn_loss:89.9007   tst_loss:95.9717   tst_nll:74.6950   tst_kl:21.2767   sum_norm_grad:323.7951
it:32001   ms/it:8.2598   trn_loss:92.0165   tst_loss:95.5847   tst_nll:74.3418   tst_kl:21.2429   sum_norm_grad:544.3785
it:32251   ms/it:8.2922   trn_loss:91.8557   tst_loss:95.3442   tst_nll:74.9442   tst_kl:20.4000   sum_norm_grad:354.7590
it:32501   ms/it:8.2118   trn_loss:93.8079   tst_loss:94.9228   tst_nll:74.1481   tst_kl:20.7746   sum_norm_grad:496.3012
it:32751   ms/it:8.2230   trn_loss:92.0644   tst_loss:96.2702   tst_nll:75.1148   tst_kl:21.1554   sum_norm_grad:502.8069
it:33001   ms/it:8.2497   trn_loss:86.3031   tst_loss:94.8656   tst_nll:73.5262   tst_kl:21.3394   sum_norm_grad:398.3834
it:33251   ms/it:8.2422   trn_loss:90.9470   tst_loss:95.6094   tst_nll:74.4927   tst_kl:21.1166   sum_norm_grad:518.0490
it:33501   ms/it:8.1715   trn_loss:94.1972   tst_loss:95.5603   tst_nll:74.5368   tst_kl:21.0235   sum_norm_grad:491.9893
it:33751   ms/it:8.1740   trn_loss:84.3660   tst_loss:96.0006   tst_nll:75.3154   tst_kl:20.6852   sum_norm_grad:399.3826
it:34001   ms/it:8.2510   trn_loss:96.6165   tst_loss:95.5739   tst_nll:74.1825   tst_kl:21.3914   sum_norm_grad:598.1465
it:34251   ms/it:8.2853   trn_loss:91.1956   tst_loss:95.6995   tst_nll:74.0909   tst_kl:21.6086   sum_norm_grad:722.1900
it:34501   ms/it:8.2397   trn_loss:85.3465   tst_loss:94.9647   tst_nll:74.0027   tst_kl:20.9620   sum_norm_grad:372.8212
it:34751   ms/it:8.2395   trn_loss:95.9603   tst_loss:95.6924   tst_nll:74.0264   tst_kl:21.6659   sum_norm_grad:512.9751
it:35001   ms/it:8.2757   trn_loss:89.3138   tst_loss:94.2024   tst_nll:73.4701   tst_kl:20.7323   sum_norm_grad:539.6708
it:35251   ms/it:8.2877   trn_loss:104.4182   tst_loss:95.7220   tst_nll:74.9508   tst_kl:20.7712   sum_norm_grad:591.6014
it:35501   ms/it:8.1780   trn_loss:89.0770   tst_loss:94.6851   tst_nll:73.9579   tst_kl:20.7272   sum_norm_grad:493.3982
it:35751   ms/it:8.2371   trn_loss:96.2822   tst_loss:95.2431   tst_nll:73.9900   tst_kl:21.2531   sum_norm_grad:616.2849
it:36001   ms/it:8.2805   trn_loss:104.2974   tst_loss:95.0059   tst_nll:74.5361   tst_kl:20.4698   sum_norm_grad:598.4661
it:36251   ms/it:8.2790   trn_loss:92.3718   tst_loss:95.6937   tst_nll:73.8834   tst_kl:21.8104   sum_norm_grad:540.7161
it:36501   ms/it:8.2373   trn_loss:95.7532   tst_loss:94.5186   tst_nll:73.2902   tst_kl:21.2284   sum_norm_grad:544.5507
it:36751   ms/it:8.2614   trn_loss:92.3881   tst_loss:95.6637   tst_nll:74.2255   tst_kl:21.4381   sum_norm_grad:556.7919
it:37001   ms/it:8.2704   trn_loss:89.5498   tst_loss:94.5648   tst_nll:73.0089   tst_kl:21.5559   sum_norm_grad:669.0724
it:37251   ms/it:8.2494   trn_loss:81.2119   tst_loss:94.8074   tst_nll:73.6568   tst_kl:21.1506   sum_norm_grad:470.7941
it:37501   ms/it:8.2374   trn_loss:91.9064   tst_loss:94.9572   tst_nll:74.0324   tst_kl:20.9249   sum_norm_grad:526.6972
it:37751   ms/it:8.2555   trn_loss:95.1651   tst_loss:95.0565   tst_nll:74.1256   tst_kl:20.9309   sum_norm_grad:482.9742
it:38001   ms/it:8.2687   trn_loss:94.9950   tst_loss:95.4203   tst_nll:74.6256   tst_kl:20.7947   sum_norm_grad:685.8509
it:38251   ms/it:8.2352   trn_loss:87.7089   tst_loss:95.3314   tst_nll:74.1835   tst_kl:21.1479   sum_norm_grad:484.6642
it:38501   ms/it:8.2450   trn_loss:96.3948   tst_loss:95.3057   tst_nll:74.1131   tst_kl:21.1926   sum_norm_grad:477.4860
it:38751   ms/it:8.2741   trn_loss:91.0143   tst_loss:96.1284   tst_nll:74.4671   tst_kl:21.6613   sum_norm_grad:516.8703
it:39001   ms/it:8.2332   trn_loss:93.1803   tst_loss:94.5465   tst_nll:73.6825   tst_kl:20.8640   sum_norm_grad:546.8208
it:39251   ms/it:8.2147   trn_loss:94.1186   tst_loss:95.2316   tst_nll:74.7130   tst_kl:20.5187   sum_norm_grad:585.1824
it:39501   ms/it:8.2274   trn_loss:85.8579   tst_loss:95.2314   tst_nll:73.7965   tst_kl:21.4348   sum_norm_grad:454.3126
it:39751   ms/it:8.2873   trn_loss:89.4100   tst_loss:95.4110   tst_nll:74.6880   tst_kl:20.7229   sum_norm_grad:574.0444
it:40001   ms/it:8.2554   trn_loss:93.0305   tst_loss:94.0303   tst_nll:73.2089   tst_kl:20.8214   sum_norm_grad:370.2800
it:40251   ms/it:8.2352   trn_loss:94.4651   tst_loss:95.8122   tst_nll:73.9440   tst_kl:21.8682   sum_norm_grad:518.2649
it:40501   ms/it:8.2360   trn_loss:93.1322   tst_loss:94.2561   tst_nll:73.1101   tst_kl:21.1460   sum_norm_grad:443.4222
it:40751   ms/it:8.2455   trn_loss:94.1579   tst_loss:95.1274   tst_nll:73.9682   tst_kl:21.1592   sum_norm_grad:590.7770
it:41001   ms/it:8.2458   trn_loss:94.5556   tst_loss:93.4816   tst_nll:72.9166   tst_kl:20.5649   sum_norm_grad:401.8537
it:41251   ms/it:8.2261   trn_loss:82.3316   tst_loss:95.4182   tst_nll:74.5317   tst_kl:20.8865   sum_norm_grad:666.6778
it:41501   ms/it:8.2904   trn_loss:107.5355   tst_loss:94.3706   tst_nll:73.0945   tst_kl:21.2761   sum_norm_grad:633.9945
it:41751   ms/it:8.2533   trn_loss:102.3969   tst_loss:94.1628   tst_nll:73.5858   tst_kl:20.5769   sum_norm_grad:727.9357
it:42001   ms/it:8.2422   trn_loss:87.7422   tst_loss:94.1707   tst_nll:73.1421   tst_kl:21.0286   sum_norm_grad:473.5380
it:42251   ms/it:8.2345   trn_loss:90.3090   tst_loss:94.8594   tst_nll:74.3117   tst_kl:20.5477   sum_norm_grad:480.7428
it:42501   ms/it:8.2752   trn_loss:90.7338   tst_loss:94.3450   tst_nll:73.5954   tst_kl:20.7496   sum_norm_grad:552.8487
it:42751   ms/it:8.2624   trn_loss:97.9351   tst_loss:94.3638   tst_nll:73.5087   tst_kl:20.8551   sum_norm_grad:566.5553
it:43001   ms/it:8.1799   trn_loss:93.4709   tst_loss:94.7935   tst_nll:73.6575   tst_kl:21.1360   sum_norm_grad:589.3351
it:43251   ms/it:8.2390   trn_loss:93.8279   tst_loss:94.8600   tst_nll:73.7775   tst_kl:21.0824   sum_norm_grad:819.6432
it:43501   ms/it:8.2791   trn_loss:95.1309   tst_loss:94.9538   tst_nll:73.8149   tst_kl:21.1389   sum_norm_grad:641.3374
it:43751   ms/it:8.2810   trn_loss:88.6112   tst_loss:94.5549   tst_nll:73.6116   tst_kl:20.9433   sum_norm_grad:782.4581
it:44001   ms/it:8.2794   trn_loss:92.6073   tst_loss:93.8861   tst_nll:72.9349   tst_kl:20.9513   sum_norm_grad:444.0027
it:44251   ms/it:8.3472   trn_loss:90.3468   tst_loss:94.8443   tst_nll:73.8619   tst_kl:20.9823   sum_norm_grad:496.5469
it:44501   ms/it:8.2447   trn_loss:94.1757   tst_loss:94.2362   tst_nll:73.9904   tst_kl:20.2458   sum_norm_grad:657.1694
it:44751   ms/it:8.2415   trn_loss:91.1105   tst_loss:94.8063   tst_nll:73.1909   tst_kl:21.6155   sum_norm_grad:372.2267
it:45001   ms/it:8.2673   trn_loss:91.7606   tst_loss:94.0004   tst_nll:73.1480   tst_kl:20.8524   sum_norm_grad:719.2126
it:45251   ms/it:8.2901   trn_loss:97.6344   tst_loss:94.6105   tst_nll:73.5537   tst_kl:21.0568   sum_norm_grad:524.6968
it:45501   ms/it:8.2223   trn_loss:96.7768   tst_loss:94.1077   tst_nll:73.0954   tst_kl:21.0123   sum_norm_grad:557.7588
it:45751   ms/it:8.2532   trn_loss:88.5243   tst_loss:94.5430   tst_nll:73.4593   tst_kl:21.0837   sum_norm_grad:549.5432
it:46001   ms/it:8.2846   trn_loss:95.0143   tst_loss:94.1733   tst_nll:72.7262   tst_kl:21.4471   sum_norm_grad:721.0759
it:46251   ms/it:8.2655   trn_loss:92.8926   tst_loss:95.6013   tst_nll:73.6601   tst_kl:21.9412   sum_norm_grad:847.6617
it:46501   ms/it:8.2790   trn_loss:89.3002   tst_loss:94.7222   tst_nll:72.9084   tst_kl:21.8138   sum_norm_grad:647.9052
it:46751   ms/it:8.3193   trn_loss:89.6552   tst_loss:94.6328   tst_nll:73.5012   tst_kl:21.1316   sum_norm_grad:453.8050
it:47001   ms/it:8.2748   trn_loss:91.9704   tst_loss:94.1757   tst_nll:73.4060   tst_kl:20.7697   sum_norm_grad:817.0945
it:47251   ms/it:8.2459   trn_loss:93.0032   tst_loss:93.8077   tst_nll:72.8963   tst_kl:20.9114   sum_norm_grad:604.6274
it:47501   ms/it:8.2840   trn_loss:93.3490   tst_loss:94.9580   tst_nll:73.0902   tst_kl:21.8678   sum_norm_grad:429.7482
it:47751   ms/it:8.2879   trn_loss:87.7900   tst_loss:94.3695   tst_nll:72.8856   tst_kl:21.4839   sum_norm_grad:539.5405
it:48001   ms/it:8.2563   trn_loss:98.2043   tst_loss:93.3153   tst_nll:72.1551   tst_kl:21.1602   sum_norm_grad:623.7249
it:48251   ms/it:8.2457   trn_loss:100.6152   tst_loss:94.2872   tst_nll:73.7673   tst_kl:20.5199   sum_norm_grad:677.5202
it:48501   ms/it:8.2771   trn_loss:88.7600   tst_loss:93.7643   tst_nll:72.6830   tst_kl:21.0813   sum_norm_grad:454.2632
it:48751   ms/it:8.3455   trn_loss:96.5667   tst_loss:94.6928   tst_nll:73.5563   tst_kl:21.1365   sum_norm_grad:605.6883
it:49001   ms/it:8.2370   trn_loss:97.5158   tst_loss:94.2065   tst_nll:73.1295   tst_kl:21.0770   sum_norm_grad:442.5939
it:49251   ms/it:8.2352   trn_loss:96.8186   tst_loss:94.0567   tst_nll:72.9796   tst_kl:21.0771   sum_norm_grad:573.3347
it:49501   ms/it:8.2618   trn_loss:95.7419   tst_loss:94.3354   tst_nll:73.4407   tst_kl:20.8948   sum_norm_grad:905.2293
it:49751   ms/it:8.2658   trn_loss:93.3351   tst_loss:95.1076   tst_nll:73.0878   tst_kl:22.0198   sum_norm_grad:422.7823
it:50000   ms/it:8.2503   trn_loss:96.9131   tst_loss:93.8293   tst_nll:72.3012   tst_kl:21.5281   sum_norm_grad:2663.0032

6 Evaluate


In [0]:
# We'll just examine 100 random digits.
x = next(iter(eval_dataset.batch(100)))
xhat = decoder(encoder(x).sample())
assert isinstance(xhat, tfd.Distribution)

In [0]:
print('Originals:')
tfn.util.display_imgs(x);

print('Decoded Random Samples:')
tfn.util.display_imgs(xhat.sample());

print('Decoded Modes:')
tfn.util.display_imgs(xhat.mode());

print('Decoded Means:')
tfn.util.display_imgs(xhat.mean());


Originals:
Decoded Random Samples:
Decoded Modes:
Decoded Means:
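
For the pixel-wise Bernoulli decoder, the three decodings relate to the
final-layer logits in a simple way: the mean is the sigmoid of the logits (a
grayscale probability image), the mode thresholds those probabilities at 0.5
(a hard binary image), and each sample draws every pixel independently. The
cell below assumes TFP's standard Bernoulli parameterization.

In [0]:
logits = xhat.distribution.logits          # The decoder's final-layer output.
probs = tf.sigmoid(logits)                 # Same values as xhat.mean().
binary = tf.cast(probs > 0.5, xhat.dtype)  # Same values as xhat.mode().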