The objective of this notebook is to demonstrate splitting a log_prob and gradient computation across multiple GPU devices. For development purposes it was prototyped in Colab, with a single GPU partitioned into several logical GPUs.

Note: Since everything runs on a single physical GPU, the timings below are not representative of what can be achieved with multiple GPUs. The tf.data input pipeline would also likely benefit from some tuning when deployed across multiple physical GPUs.

Needs a GPU: Edit > Notebook Settings: Hardware Accelerator => GPU
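
Optionally, once the imports in the first cell below have run, the GPU requirement can be checked programmatically (a hypothetical check, not part of the original setup):

assert tf.config.experimental.list_physical_devices('GPU'), (
    'No GPU found: switch the Colab runtime to GPU.')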


In [42]:
%tensorflow_version 2.x
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
tfb, tfd = tfp.bijectors, tfp.distributions

physical_gpus = tf.config.experimental.list_physical_devices('GPU')
print(physical_gpus)

# Split the single physical GPU into 4 logical GPUs (~2 GB each) so the
# distribution strategy has multiple devices to mirror across.
tf.config.experimental.set_virtual_device_configuration(
    physical_gpus[0],
    [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=2000)] * 4)
gpus = tf.config.list_logical_devices('GPU')
print(gpus)

# Mirror computation across all (logical) GPUs.
st = tf.distribute.MirroredStrategy(devices=tf.config.list_logical_devices('GPU'))
print(st.extended.worker_devices)


[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
[LogicalDevice(name='/device:GPU:0', device_type='GPU'), LogicalDevice(name='/device:GPU:1', device_type='GPU'), LogicalDevice(name='/device:GPU:2', device_type='GPU'), LogicalDevice(name='/device:GPU:3', device_type='GPU')]
WARNING:tensorflow:NCCL is not supported when using virtual GPUs, fallingback to reduction to one device
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
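
The per-replica computations later in the notebook communicate through ReplicaContext.all_reduce. A minimal sketch of that primitive, using the strategy st defined above (a toy computation, not part of the model):

@tf.function(autograph=False)
def toy_all_reduce():
  def replica_fn():
    ctx = tf.distribute.get_replica_context()
    # Each replica contributes 1.; the all-reduced sum equals the replica count.
    return ctx.all_reduce('sum', tf.constant(1.))
  # Every replica holds the same sum, so just take replica 0's copy.
  return st.run(replica_fn).values[0]

print(toy_all_reduce())  # 4.0 with four logical GPUs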

In [43]:
# Draw samples from an MVN, then sort them. This way we can easily visually
# verify the correct partition ends up on the correct GPUs.
ndim = 3

def model():
  Root = tfd.JointDistributionCoroutine.Root
  # Priors: loc is shifted to have mean 0.5; scale_tril is a fill-triangular
  # transform of a standard normal vector.
  loc = yield Root(tfb.Shift(.5)(tfd.MultivariateNormalDiag(loc=tf.zeros([ndim]))))
  scale_tril = yield Root(tfb.FillScaleTriL()(tfd.MultivariateNormalDiag(loc=tf.zeros([ndim * (ndim + 1) // 2]))))
  # Likelihood: MVN observations with the sampled loc and scale_tril.
  yield tfd.MultivariateNormalTriL(loc=loc, scale_tril=scale_tril)

dist = tfd.JointDistributionCoroutine(model)
tf.random.set_seed(1)
loc, scale_tril, _ = dist.sample(seed=2)

samples = dist.sample(value=([loc] * 1024, scale_tril, None), seed=3)[2]
samples = tf.round(samples * 1000) / 1000
for dim in reversed(range(ndim)):
  samples = tf.gather(samples, tf.argsort(samples[:,dim]))

print(samples)


tf.Tensor(
[[-4.534 -5.856  3.606]
 [-4.527 -5.875  1.671]
 [-4.269 -4.346  5.697]
 ...
 [ 2.158  5.71  -2.926]
 [ 2.302  6.658 -3.491]
 [ 2.632  5.67  -4.854]], shape=(1024, 3), dtype=float32)

In [44]:
print(loc)
print(scale_tril)
print(tf.reduce_mean(samples, 0))


tf.Tensor([-1.0574996   0.24829748  1.0737331 ], shape=(3,), dtype=float32)
tf.Tensor(
[[ 1.1475685   0.          0.        ]
 [ 1.9094281   0.5724521   0.        ]
 [-1.1899896   0.49813363  1.5088601 ]], shape=(3, 3), dtype=float32)
tf.Tensor([-0.9953702  0.3626416  1.0675195], shape=(3,), dtype=float32)

Single batch of data resident on GPU: the 1024 samples are sharded into one 256-sample batch per replica; log_prob_and_grad computes each replica's log-prob and gradient and all-reduces them, and target_log_prob exposes the result to TFP through tf.custom_gradient (see the sketch just below).
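
As a minimal sketch of the tf.custom_gradient API used here (a toy function, unrelated to the model): the decorated function returns its value together with a closure that supplies the gradient; in target_log_prob below, that closure instead hands back the already-all-reduced per-replica gradients.

@tf.custom_gradient
def toy_square(x):
  def grad_fn(dy):
    # Hand-written gradient of x**2.
    return dy * 2. * x
  return x * x, grad_fn

print(tfp.math.value_and_gradient(toy_square, tf.constant(3.)))  # (9.0, 6.0)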


In [45]:
%%time

def dataset_fn(ctx):
  # Per-replica batch size: the full 1024 samples divided across the replicas.
  batch_size = ctx.get_per_replica_batch_size(len(samples))
  d = tf.data.Dataset.from_tensor_slices(samples).batch(batch_size)
  return d.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)

ds = st.experimental_distribute_datasets_from_function(dataset_fn)

observations = next(iter(ds))
# print(observations)

@tf.function(autograph=False)
def log_prob_and_grad(loc, scale_tril, observations):
  # Runs per replica: compute the log-prob and gradient of this replica's
  # shard, then all-reduce so every replica holds the full-data values.
  ctx = tf.distribute.get_replica_context()
  with tf.GradientTape() as tape:
    tape.watch((loc, scale_tril))
    lp = tf.reduce_sum(dist.log_prob(loc, scale_tril, observations)) / len(samples)
  grad = tape.gradient(lp, (loc, scale_tril))
  return ctx.all_reduce('sum', lp), [ctx.all_reduce('sum', g) for g in grad]

@tf.function(autograph=False)
@tf.custom_gradient
def target_log_prob(loc, scale_tril):
  # Distribute the computation, then expose the all-reduced value and gradients
  # (identical on every replica, so take replica 0's copy) via custom_gradient.
  lp, grads = st.run(log_prob_and_grad, (loc, scale_tril, observations))
  return lp.values[0], lambda grad_lp: [grad_lp * g.values[0] for g in grads]

singleton_vals = tfp.math.value_and_gradient(target_log_prob, (loc, scale_tril))

kernel = tfp.mcmc.HamiltonianMonteCarlo(target_log_prob, step_size=.35, num_leapfrog_steps=2)
kernel = tfp.mcmc.TransformedTransitionKernel(kernel, bijector=[tfb.Identity(), tfb.FillScaleTriL()])

@tf.function(autograph=False)
def sample_chain():
  return tfp.mcmc.sample_chain(
      num_results=200, num_burnin_steps=100,
      current_state=[tf.ones_like(loc), tf.linalg.eye(scale_tril.shape[-1])], 
      kernel=kernel, trace_fn=lambda _, kr: kr.inner_results.is_accepted)
samps, is_accepted = sample_chain()

print(f'accept rate: {np.mean(is_accepted)}')
print(f'ess: {tfp.mcmc.effective_sample_size(samps)}')

print(tf.reduce_mean(samps[0], axis=0))
# print(tf.reduce_mean(samps[1], axis=0))

import matplotlib.pyplot as plt
for dim in range(ndim):
  plt.figure(figsize=(10,1))
  plt.hist(samps[0][:,dim], bins=50)
  plt.title(f'loc[{dim}]: prior mean = 0.5, observation = {loc[dim]}')
  plt.show()


accept rate: 0.8
ess: [<tf.Tensor: shape=(3,), dtype=float32, numpy=array([42.702564, 56.667793, 43.04328 ], dtype=float32)>, <tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[57.469185,       nan,       nan],
       [34.60636 , 43.471287,       nan],
       [13.848625, 30.807411, 72.34708 ]], dtype=float32)>]
tf.Tensor([-0.18598816  0.7396643   0.4074543 ], shape=(3,), dtype=float32)
CPU times: user 14.3 s, sys: 1.57 s, total: 15.8 s
Wall time: 12.9 s
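
As a quick visual check that the sorted samples really land in distinct contiguous blocks on the different GPUs (the reason for sorting them above), a hypothetical follow-up cell can print each replica's shard of observations:

for i, obs in enumerate(st.experimental_local_results(observations)):
  # Each replica should hold a [256, 3] shard covering its own range of values.
  print(i, obs.device, obs.shape, obs[0].numpy(), obs[-1].numpy())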

Two batches of data per log-prob eval: each replica's 256-sample shard is split into two batches of 128, and ds.reduce accumulates the log-prob and gradients across batches before the final cross-replica reduction. This handles datasets too large to evaluate in a single per-replica batch, at the cost of speed (about 4x the wall-clock time of the single-batch version in this single-GPU run).


In [46]:
%%time
batches_per_eval = 2

def dataset_fn(ctx):
  batch_size = ctx.get_per_replica_batch_size(len(samples))
  # Split each replica's shard into `batches_per_eval` smaller batches.
  d = tf.data.Dataset.from_tensor_slices(samples).batch(batch_size // batches_per_eval)
  return d.shard(ctx.num_input_pipelines, ctx.input_pipeline_id).prefetch(2)

ds = st.experimental_distribute_datasets_from_function(dataset_fn)

@tf.function(autograph=False)
def log_prob_and_grad(loc, scale_tril, observations, prev_sum_lp, prev_sum_grads):
  # Runs per replica: add this batch's log-prob and gradient to the running sums.
  with tf.GradientTape() as tape:
    tape.watch((loc, scale_tril))
    lp = tf.reduce_sum(dist.log_prob(loc, scale_tril, observations)) / len(samples)
  grad = tape.gradient(lp, (loc, scale_tril))
  return lp + prev_sum_lp, [g + pg for (g, pg) in zip(grad, prev_sum_grads)]

@tf.function(autograph=False)
@tf.custom_gradient
def target_log_prob(loc, scale_tril):
  # Zero-initialize per-replica accumulators for the log-prob and gradients.
  sum_lp = tf.zeros([])
  sum_grads = [tf.zeros_like(x) for x in (loc, scale_tril)]
  sum_lp, sum_grads = st.run(
      lambda *x: tf.nest.map_structure(tf.identity, x), (sum_lp, sum_grads))
  def reduce_fn(state, observations):
    sum_lp, sum_grads = state
    return st.run(
        log_prob_and_grad, (loc, scale_tril, observations, sum_lp, sum_grads))
  # Accumulate over the batches on each replica, then reduce across replicas.
  sum_lp, sum_grads = ds.reduce((sum_lp, sum_grads), reduce_fn)
  sum_lp = st.reduce('sum', sum_lp, None)
  sum_grads = [st.reduce('sum', sg, None) for sg in sum_grads]
  return sum_lp, lambda grad_lp: [grad_lp * sg for sg in sum_grads]

multibatch_vals = tfp.math.value_and_gradient(target_log_prob, (loc, scale_tril))

kernel = tfp.mcmc.HamiltonianMonteCarlo(target_log_prob, step_size=.35, num_leapfrog_steps=2)
kernel = tfp.mcmc.TransformedTransitionKernel(kernel, bijector=[tfb.Identity(), tfb.FillScaleTriL()])

@tf.function(autograph=False)
def sample_chain():
  return tfp.mcmc.sample_chain(
      num_results=200, num_burnin_steps=100,
      current_state=[tf.ones_like(loc), tf.linalg.eye(scale_tril.shape[-1])], 
      kernel=kernel, trace_fn=lambda _, kr: kr.inner_results.is_accepted)
samps, is_accepted = sample_chain()

print(f'accept rate: {np.mean(is_accepted)}')
print(f'ess: {tfp.mcmc.effective_sample_size(samps)}')

print(tf.reduce_mean(samps[0], axis=0))
# print(tf.reduce_mean(samps[1], axis=0))

import matplotlib.pyplot as plt
for dim in range(ndim):
  plt.figure(figsize=(10,1))
  plt.hist(samps[0][:,dim], bins=50)
  plt.title(f'loc[{dim}]: prior mean = 0.5, observation = {loc[dim]}')
  plt.show()


accept rate: 0.8
ess: [<tf.Tensor: shape=(3,), dtype=float32, numpy=array([42.702564, 56.667793, 43.043278], dtype=float32)>, <tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[57.469193,       nan,       nan],
       [34.606365, 43.471287,       nan],
       [13.848625, 30.807411, 72.34709 ]], dtype=float32)>]
tf.Tensor([-0.18598816  0.7396643   0.4074543 ], shape=(3,), dtype=float32)
CPU times: user 1min 8s, sys: 8.71 s, total: 1min 17s
Wall time: 51.5 s

Sanity check: the log-prob and gradient values from the single-batch and multi-batch implementations should agree to float32 tolerance.


In [0]:
for i, (sv, mv) in enumerate(zip(tf.nest.flatten(singleton_vals), 
                                 tf.nest.flatten(multibatch_vals))):
  np.testing.assert_allclose(sv, mv, err_msg=i, rtol=1e-5)

In [0]: