Configuring Sonnet's BatchNorm Module

This colab walks you through the different modes of operation of Sonnet's BatchNorm module.

The module's behaviour is determined by three main parameters: one constructor argument (update_ops_collection) and two arguments that are passed when the module is connected to the graph (is_training and test_local_stats).

bn = BatchNorm(update_ops_collection)
bn(inputs, is_training, test_local_stats)

The following diagram visualizes how different parameter settings lead to different modes of operation. Bold arrows mark the current default values of the arguments.

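is_training=True, update_ops_collection=None:
  • Normalize output using local batch statistics.
  • Update moving averages in each forward pass.
is_training=True, update_ops_collection=tf.GraphKeys.UPDATE_OPS (default):
  • Normalize output using local batch statistics.
  • Update ops for the moving averages are placed in a named collection. They are not executed automatically.
is_training=False, test_local_stats=False:
  • Normalize output using stored moving averages.
  • No update ops are created.
is_training=False, test_local_stats=True (default):
  • Normalize output using local batch statistics.
  • No update ops are created.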
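To make the four modes concrete, here is a minimal NumPy sketch of the forward pass each setting implies. It is not Sonnet's implementation: the function batch_norm_forward and its return values are hypothetical, it omits the learnable offset and scale, it uses the usual exponential-moving-average update moving = decay * moving + (1 - decay) * batch, and decay=0.999 and eps=1e-3 are assumed defaults.

```python
import numpy as np


def batch_norm_forward(x, moving_mean, moving_var, is_training,
                       test_local_stats=True,
                       update_ops_collection="update_ops",
                       decay=0.999, eps=1e-3):
  """Hypothetical NumPy model of one BatchNorm forward pass on a 2-D batch.

  Returns (output, moving_mean, moving_var, pending_updates), where
  pending_updates holds updates that were created but not applied.
  """
  batch_mean = x.mean(axis=0, keepdims=True)
  batch_var = x.var(axis=0, keepdims=True)
  pending_updates = []

  if is_training:
    # Training always normalizes with the local batch statistics.
    mean, var = batch_mean, batch_var
    new_mean = decay * moving_mean + (1 - decay) * batch_mean
    new_var = decay * moving_var + (1 - decay) * batch_var
    if update_ops_collection is None:
      # The update is applied as part of the forward pass.
      moving_mean, moving_var = new_mean, new_var
    else:
      # The update is only recorded; the caller must apply it later.
      pending_updates.append((new_mean, new_var))
  elif test_local_stats:
    # Default test mode: batch statistics, moving averages ignored.
    mean, var = batch_mean, batch_var
  else:
    # Test mode with stored statistics; no updates are created.
    mean, var = moving_mean, moving_var

  output = (x - mean) / np.sqrt(var + eps)
  return output, moving_mean, moving_var, pending_updates


rng = np.random.default_rng(0)
x = rng.normal([10.0, 10.0], [1.0, 2.0], size=(10, 2))
init_mean, init_var = np.zeros((1, 2)), np.ones((1, 2))

# Default training mode: an update is created but not applied.
out_train, m_train, v_train, pending = batch_norm_forward(
    x, init_mean, init_var, is_training=True)
# update_ops_collection=None: the moving averages change immediately.
_, m_auto, v_auto, pending_auto = batch_norm_forward(
    x, init_mean, init_var, is_training=True, update_ops_collection=None)
# Default test mode, and test mode with the stored moving averages.
out_test_local, _, _, _ = batch_norm_forward(
    x, init_mean, init_var, is_training=False)
out_test_moving, _, _, _ = batch_norm_forward(
    x, init_mean, init_var, is_training=False, test_local_stats=False)
print(len(pending), len(pending_auto))
```

Only the default training mode leaves work for the caller: the entry in pending_updates plays the role of the update ops that end up in the named collection.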
import numpy as np
import tensorflow as tf
import sonnet as snt
import matplotlib.pyplot as plt
from matplotlib import patches
%matplotlib inline

def run_and_visualize(inputs, outputs, bn_module):
  """Runs 1000 forward passes, prints statistics and plots the results."""
  init = tf.global_variables_initializer()
  with tf.Session() as sess:
    sess.run(init)

    inputs_collection = []
    outputs_collection = []

    for _ in range(1000):
      current_inputs, current_outputs = sess.run([inputs, outputs])
      inputs_collection.append(current_inputs)
      outputs_collection.append(current_outputs)

    # Read the aggregated (moving) statistics after all forward passes.
    bn_mean, bn_var = sess.run([bn_module._moving_mean,
                                bn_module._moving_variance])

  inputs_collection = np.concatenate(inputs_collection, axis=0)
  outputs_collection = np.concatenate(outputs_collection, axis=0)

  print("Number of update ops in collection: {}".format(
      len(tf.get_collection(tf.GraphKeys.UPDATE_OPS))))
  print("Input mean: {}".format(np.mean(inputs_collection, axis=0)))
  print("Input variance: {}".format(np.var(inputs_collection, axis=0)))
  print("Moving mean: {}".format(bn_mean))
  print("Moving variance: {}".format(bn_var))

  # Plot the learned Gaussian distribution as its one-standard-deviation
  # contour. Ellipse width and height are full diameters, hence 2 * std.
  ellipse = patches.Ellipse(xy=bn_mean[0], width=2 * np.sqrt(bn_var[0, 0]),
                            height=2 * np.sqrt(bn_var[0, 1]), angle=0,
                            edgecolor='g', fc='None', zorder=1000,
                            linestyle='solid')
  # Plot the input distribution.
  input_ax = plt.scatter(inputs_collection[:, 0], inputs_collection[:, 1],
                         c='r', alpha=0.1, zorder=1)
  # Plot the output distribution.
  output_ax = plt.scatter(outputs_collection[:, 0], outputs_collection[:, 1],
                          c='b', alpha=0.1, zorder=1)
  ax = plt.gca()
  ellipse_ax = ax.add_patch(ellipse)
  plt.legend((input_ax, output_ax, ellipse_ax),
             ("Inputs", "Outputs", "Aggregated statistics"),
             loc="lower right")

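def get_inputs():
  """Returns a batch of 10 two-dimensional samples.

  Both dimensions have mean 10; their standard deviations are 1 and 2.
  """
  return tf.concat([
      tf.random_normal((10, 1), 10, 1),
      tf.random_normal((10, 1), 10, 2)], axis=1)
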

Default mode

tf.reset_default_graph()  # Start from an empty graph for this experiment.
inputs = get_inputs()
bn = snt.BatchNorm()
outputs = bn(inputs, is_training=True)

run_and_visualize(inputs, outputs, bn)


  1. The outputs have been normalized. This is indicated by the blue isotropic Gaussian distribution.
  2. Update ops have been created and placed in a collection.
  3. No moving statistics have been collected. The green circle shows the learned Gaussian distribution. It is initialized to have mean 0 and standard deviation 1. Because the update ops were created but not executed, these statistics have not been updated.
  4. The "boxy" shape of the normalized data points comes from the rather small batch size of 10. Because the batch statistics are only computed over 10 data points, they are very noisy.
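The noise in batch statistics can be quantified with a small NumPy experiment. This is a sketch independent of Sonnet, using 1000 batches of 10 samples drawn like the second input dimension (mean 10, standard deviation 2) and an assumed eps of 1e-3. The batch mean scatters around the true mean with a standard deviation of about 2 / sqrt(10), yet every normalized batch has exactly zero mean. In addition, normalizing n points by their own (biased) standard deviation bounds every value by sqrt(n - 1), which is 3 for n = 10 (Samuelson's inequality); this hard bound is one way to see the box-like edges of the normalized cloud.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, num_batches, eps = 10, 1000, 1e-3

# 1000 batches of the second input dimension: mean 10, standard deviation 2.
batches = rng.normal(10.0, 2.0, size=(num_batches, batch_size))
batch_means = batches.mean(axis=1, keepdims=True)
batch_vars = batches.var(axis=1, keepdims=True)

# Normalize every batch with its own statistics, as in training mode.
normalized = (batches - batch_means) / np.sqrt(batch_vars + eps)

# Theory: the batch means scatter with standard deviation 2 / sqrt(10).
print("std of batch means:", batch_means.std())
# Samuelson's inequality: no value can exceed sqrt(batch_size - 1) = 3.
print("largest |normalized value|:", np.abs(normalized).max())
print("largest |mean| of a normalized batch:",
      np.abs(normalized.mean(axis=1)).max())
```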

Collecting statistics during training

First option: Update statistics automatically on every forward pass

tf.reset_default_graph()  # Start from an empty graph for this experiment.
inputs = get_inputs()
bn = snt.BatchNorm(update_ops_collection=None)
outputs = bn(inputs, is_training=True)

run_and_visualize(inputs, outputs, bn)


  1. The outputs have been normalized as we can tell from the blue isotropic Gaussian distribution.
  2. Update ops have been created and executed. We can see that the moving statistics no longer have their default values (i.e. the green ellipse has changed). However, the aggregated statistics don't represent the input distribution yet: each forward pass moves the moving averages only a small step toward the batch statistics, and 1000 forward passes are not enough for them to converge.
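To get a feel for how slowly the moving averages move, the following plain-Python sketch applies the exponential-moving-average update moving = decay * moving + (1 - decay) * batch_mean for 1000 steps. It assumes decay = 0.999, which we believe is the default decay_rate of snt.BatchNorm (check the constructor of your version). Starting from 0, the moving mean reaches only about 1 - 0.999**1000, roughly 63%, of the true mean 10, and about 4600 steps are needed to get within 1% of it.

```python
import math

import numpy as np

decay = 0.999       # assumed default decay_rate of snt.BatchNorm
num_steps = 1000    # number of forward passes in run_and_visualize
true_mean = 10.0

rng = np.random.default_rng(0)
moving_mean = 0.0   # initial value of the moving mean
for _ in range(num_steps):
  # Mean of one batch of 10 samples from the first input dimension.
  batch_mean = rng.normal(true_mean, 1.0, size=10).mean()
  moving_mean = decay * moving_mean + (1 - decay) * batch_mean

# Expected value after num_steps updates, starting from 0.
expected = true_mean * (1 - decay ** num_steps)
print("simulated:", moving_mean, "expected:", expected)

# Steps until the weight of the initial value drops below 1%.
steps_to_1_percent = math.ceil(math.log(0.01) / math.log(decay))
print("steps to get within 1% of the target:", steps_to_1_percent)
```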

Second option: Explicitly add update ops as control dependencies

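tf.reset_default_graph()  # Start from an empty graph for this experiment.
inputs = get_inputs()
# Keep the default update_ops_collection so that the update ops are placed in
# tf.GraphKeys.UPDATE_OPS instead of being executed automatically.
bn = snt.BatchNorm()
outputs = bn(inputs, is_training=True)

# Add the update ops as control dependencies. In a real model this is
# usually done when defining the gradient descent ops.
update_ops = tf.group(*tf.get_collection(tf.GraphKeys.UPDATE_OPS))
with tf.control_dependencies([update_ops]):
  outputs = tf.identity(outputs)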
run_and_visualize(inputs, outputs, bn)


The results match the previous run. This time, however, the update ops were not executed automatically on every forward pass: we had to make them an explicit dependency of our output using tf.control_dependencies. In practice, we would usually attach these dependencies to our training ops instead.

Using statistics at test time

Default mode

tf.reset_default_graph()  # Start from an empty graph for this experiment.
inputs = get_inputs()
bn = snt.BatchNorm()
outputs = bn(inputs, is_training=False)

run_and_visualize(inputs, outputs, bn)


  1. No update ops have been created and the moving statistics still have their initial values (mean 0, standard deviation 1).
  2. The inputs have been normalized using the batch statistics as we can tell from the blue isotropic Gaussian distribution.

This means that in the default testing mode (test_local_stats=True), the inputs are normalized using the batch statistics and the aggregated statistics are ignored.
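One consequence of the default testing mode is that an example's output depends on which other examples share its batch. The NumPy sketch below is independent of Sonnet; the helper normalize and the eps of 1e-3 are assumptions for illustration. It places the same example in two different batches: with batch statistics its normalized value changes, while with fixed stored statistics, as with test_local_stats=False, it does not.

```python
import numpy as np

eps = 1e-3
rng = np.random.default_rng(0)


def normalize(x, mean, var):
  """Hypothetical helper: normalize x with the given statistics."""
  return (x - mean) / np.sqrt(var + eps)


example = np.array([[11.0, 12.0]])  # the same test example in both batches
batch_a = np.concatenate(
    [example, rng.normal([10.0, 10.0], [1.0, 2.0], size=(9, 2))], axis=0)
batch_b = np.concatenate(
    [example, rng.normal([10.0, 10.0], [1.0, 2.0], size=(9, 2))], axis=0)

# test_local_stats=True: statistics are computed from the current batch.
local_a = normalize(batch_a, batch_a.mean(axis=0), batch_a.var(axis=0))[0]
local_b = normalize(batch_b, batch_b.mean(axis=0), batch_b.var(axis=0))[0]

# test_local_stats=False: fixed stored statistics (here the true moments).
moving_mean, moving_var = np.array([10.0, 10.0]), np.array([1.0, 4.0])
stored_a = normalize(batch_a, moving_mean, moving_var)[0]
stored_b = normalize(batch_b, moving_mean, moving_var)[0]

print("batch statistics: ", local_a, local_b)
print("stored statistics:", stored_a, stored_b)
```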

Using moving averages at test time

def hacky_np_initializer(array):
  """Allows us to initialize a tf variable with a numpy array."""
  def _init(shape, dtype, partition_info):
    return tf.constant(np.asarray(array, dtype='float32'))
  return _init


tf.reset_default_graph()  # Start from an empty graph for this experiment.
inputs = get_inputs()
# We initialize the moving mean and variance to non-standard values
# so we can see the effect of this setting.
bn = snt.BatchNorm(initializers={
    "moving_mean": hacky_np_initializer([[10, 10]]),
    "moving_variance": hacky_np_initializer([[1, 4]])})
outputs = bn(inputs, is_training=False, test_local_stats=False)

run_and_visualize(inputs, outputs, bn)


We have now manually initialized the moving statistics to the moments of the input distribution. Because test_local_stats=False, the inputs are normalized with these stored statistics rather than with the batch statistics, and no update ops are created.