In [0]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Jacobian determinant for change of variables

Notebook originally contributed by: MarkDaoust

View source on GitHub

This notebook uses a Jacobian determinant to do a change of variables in a probability distribution, and makes some nice visualizations.

This is just a quick walkthrough; you should probably use TensorFlow Probability for any serious applications.

Setup


In [0]:
import tensorflow as tf
import numpy as np

import matplotlib as mpl
import matplotlib.pyplot as plt

# Larger default figure size so the multi-panel figures below stay readable.
mpl.rcParams.update({'figure.figsize': (12, 8)})

Create the distribution

Build an (x0, x1) grid.


In [0]:
# Axis samples with 0.1 spacing: 241 points on [-12, 12] and 201 on [-10, 10].
# x0 is kept as a column vector and x1 as a row vector.
x0 = tf.expand_dims(tf.linspace(-12.0, 12.0, 241), axis=1)
x1 = tf.expand_dims(tf.linspace(-10.0, 10.0, 201), axis=0)

# The last axis of `xs` holds the (x0, x1) coordinate pair of each grid point;
# `xs_flat` lists the same points as rows of a two-column matrix.
xs = tf.stack(tf.meshgrid(x0, x1), axis=-1)
xs_flat = tf.reshape(xs, (-1, 2))

# Grid shape without the trailing coordinate axis; used to un-flatten results.
batch_shape = xs.shape[:-1]
print(batch_shape, xs.shape, xs_flat.shape, sep='\n')

Create a multivariate normal distribution over x


In [0]:
from tensorflow_probability import distributions as tfd

# Bivariate normal over x = (x0, x1) with mean (2, 1) and covariance
# [[9, -2], [-2, 4]]. `MultivariateNormalFullCovariance` is deprecated in
# TensorFlow Probability, so the same distribution is built from the Cholesky
# factor of the covariance matrix instead.
dist = tfd.MultivariateNormalTriL(
    loc=[2., 1.],
    scale_tril=tf.linalg.cholesky([[9., -2.], [-2., 4.]]))

Calculate the probability on the x grid, and take a few samples


In [0]:
# Density of every grid point, in the same row order as `xs_flat`.
px = dist.prob(xs_flat)
# A fixed seed keeps the scatter overlay reproducible after Restart & Run All.
sample_xs = dist.sample(sample_shape=50, seed=0)

Plot the density and samples


In [0]:
def plot_sheet(coords_flat, value_flat, batch_shape, sample = None, **kwargs):
  """Draw a scalar field as a shaded sheet on the current matplotlib axes.

  Args:
    coords_flat: Coordinates whose last axis is indexed at 0 (horizontal) and
      1 (vertical); each component is reshaped to `batch_shape`.
    value_flat: Values to color the sheet with; reshaped to `batch_shape`.
    batch_shape: Shape of the grid the flat inputs are folded back into.
    sample: Optional points, indexed as `sample[:, 0]` and `sample[:, 1]`,
      drawn as white dots above the sheet.
    **kwargs: Forwarded to `plt.pcolormesh` (for example `vmax`).
  """
  grid_c0, grid_c1 = (
      tf.reshape(coords_flat[..., component], batch_shape)
      for component in (0, 1))
  grid_value = tf.reshape(value_flat, batch_shape)

  # The scatter is issued before the mesh so that the mesh is the most recent
  # pyplot artist when the caller asks for a colorbar afterwards.
  if sample is not None:
    plt.scatter(sample[:,0], sample[:,1], c='w', marker='.', zorder=1)
  plt.pcolormesh(grid_c0, grid_c1, grid_value,
                 zorder=-1, shading='gouraud', **kwargs)

  plt.axis('off')
  plt.gca().set_aspect('equal')

In [0]:
# Density of x on the original grid, with the drawn samples overlaid.
plot_sheet(xs_flat, px, batch_shape, sample=sample_xs)
plt.gca().set_title('A MultiVariate probability density')
cbar = plt.colorbar(label='Density $1/x^2$')

Transform the coordinates

Transform the x coordinates to u


In [0]:
def transform(xs):
  """Map x coordinates to u coordinates with a smooth sinusoidal warp.

  Args:
    xs: Tensor whose last axis holds the pair (x0, x1).

  Returns:
    Tensor with the pair (u0, u1) stacked on the last axis, where
    u0 = x0 + 0.9*sin(x1) and u1 = x1 + 0.9*cos(x0).
  """
  first, second = xs[..., 0], xs[..., 1]
  return tf.stack(
      [first + 0.9*tf.sin(second), second + 0.9*tf.cos(first)],
      axis=-1)

In [0]:
# Push the grid points and the samples through the same coordinate map.
us_flat, sample_us = (transform(points) for points in (xs_flat, sample_xs))

Plot the original density values at the transformed coordinates. The density map is wrong because the values have not yet been rescaled for the change in area.


In [0]:
# The x-density values are drawn at the u coordinates without any rescaling,
# which is why this picture is labelled as wrong.
plot_sheet(us_flat, px, batch_shape, sample=sample_us)
plt.gca().set_title('This is wrong.')
cbar = plt.colorbar(label='Wrong units: $1/x^2$')

Apply the jacobian determinant

Calculate the Jacobian determinant at each point, to scale the density.


In [0]:
# Record `transform` on the tape so du/dx can be evaluated at every grid point.
with tf.GradientTape() as tape:
  # xs_flat is a plain tensor (not a tf.Variable), so it must be watched
  # explicitly for the tape to track it.
  tape.watch(xs_flat)
  us_flat = transform(xs_flat)

# One 2x2 Jacobian matrix d(u0, u1)/d(x0, x1) per grid point.
js = tape.batch_jacobian(us_flat, xs_flat)
# The change-of-variables formula uses the absolute value of the determinant:
# a signed determinant would produce negative "densities" for any transform
# that flips orientation. For this `transform` the determinant is
# 1 + 0.81*sin(x0)*cos(x1) > 0, so taking the absolute value leaves the
# values unchanged.
js_scale = tf.abs(tf.linalg.det(js))

The Jacobian determinant tells you locally how much each differential area has expanded or contracted, and how much the density changed.


In [0]:
# The plotted quantity is 1/js_scale. js_scale is the determinant of du/dx
# (units u^2/x^2), so its reciprocal has units x^2/u^2. The label below uses
# those units, matching the same panel in the combined 2x2 figure.
plot_sheet(us_flat, 1/js_scale, batch_shape)
plt.title('Jacobian density change')
cbar = plt.colorbar()
cbar.set_label('Area scale factor $x^2/u^2$')

Divide the density in x by the absolute value of the Jacobian determinant to get the density in u.


In [0]:
# Change of variables: the density in u is the density in x divided by the
# Jacobian scale factor at the matching grid point.
pu = tf.divide(px, js_scale)

plot_sheet(us_flat, pu, batch_shape, sample=sample_us)
plt.gca().set_title('Transformed density')
cbar = plt.colorbar(label='Density $1/u^2$')

Here is everything together:


In [0]:
def _summary_panel(index, coords_flat, value_flat, title, label, **kwargs):
  """Draw panel `index` of the 2x2 summary figure and return its colorbar."""
  plt.subplot(2, 2, index)
  plot_sheet(coords_flat, value_flat, batch_shape, **kwargs)
  plt.title(title)
  return plt.colorbar(label=label)


# The three density panels share one color scale (the peak of the transformed
# density) so they can be compared directly; the Jacobian panel keeps its own.
vmax = pu.numpy().max()

cbar = _summary_panel(
    1, xs_flat, px,
    'A MultiVariate probability density', 'Density $1/x^2$', vmax=vmax)
cbar = _summary_panel(
    2, us_flat, px,
    'This is wrong', 'Wrong units: $1/x^2$', vmax=vmax)
cbar = _summary_panel(
    3, us_flat, 1/js_scale,
    'Jacobian density change', 'Area scale factor $x^2/u^2$')
cbar = _summary_panel(
    4, us_flat, pu,
    'Transformed density', 'Density $1/u^2$', vmax=vmax)

Re-grid the new density

Now create a new grid over u, and re-evaluate the density, to make integration easier.


In [0]:
# Regular grid over u with the same extent and 0.1 spacing as the x grid.
u0_mesh = tf.linspace(-12.0, 12.0, 241)
u1_mesh = tf.linspace(-10.0, 10.0, 201)

# Last axis holds the (u0, u1) pair; the flat view lists one point per row.
us_mesh = tf.stack(tf.meshgrid(u0_mesh, u1_mesh), axis=-1)
us_mesh_flat = tf.reshape(us_mesh, (-1, 2))
new_batch_shape = us_mesh.shape[:-1]

In [0]:
import scipy.interpolate
# Linearly interpolate the density known at the warped points `us_flat` onto
# the regular u grid; grid points that cannot be interpolated get 0.0.
pu_mesh_flat = scipy.interpolate.griddata(
    points=us_flat, values=pu, xi=us_mesh_flat,
    method='linear', fill_value=0.0)
# Fold the flat result back into the 2-D layout of the regular grid.
pu_mesh = tf.reshape(pu_mesh_flat, new_batch_shape)

In [0]:
# Top: transformed density on the warped grid.
plt.subplot(2, 1, 1)
plot_sheet(us_flat, pu, batch_shape, vmax=vmax)
plt.gca().set_title('Transformed density')
cbar = plt.colorbar(label='Density $1/u^2$')

# Bottom: the same density after resampling onto the regular u grid.
plt.subplot(2, 1, 2)
plot_sheet(us_mesh_flat, pu_mesh, new_batch_shape)
plt.gca().set_title('Transformed density, re-meshed')
cbar = plt.colorbar(label='Density $1/u^2$')

Calculate the marginal distributions

With the density in u reevaluated on a nice grid, it's possible to integrate and get nice results.

Integrate to get the two marginals.


In [0]:
import scipy.integrate
# `scipy.integrate.trapz` has been removed from recent SciPy releases;
# `trapezoid` is the supported name for the same trapezoidal rule.
# `pu_mesh` already has shape `new_batch_shape` (rows follow u1, columns
# follow u0), so the integration axis is selected with `axis=` instead of
# reshaping or transposing.
# Integrate out u0 (columns) to get p(u1).
pu1 = scipy.integrate.trapezoid(pu_mesh.numpy(), u0_mesh, axis=1)
# Integrate out u1 (rows) to get p(u0).
pu0 = scipy.integrate.trapezoid(pu_mesh.numpy(), u1_mesh, axis=0)

Integrate the marginals to sanity-check that the total probability mass is ~=1.0


In [0]:
# Each marginal should integrate to approximately 1.0.
# `trapezoid` replaces `trapz`, which has been removed from recent SciPy
# releases.
print(scipy.integrate.trapezoid(pu1, u1_mesh))
print(scipy.integrate.trapezoid(pu0, u0_mesh))

Plot the marginals; they're still surprisingly Gaussian.


In [0]:
def plot_extras(us, sheet, u0=None, v0=None, u1=None, v1=None):
  """Plot a 2-D sheet with optional 1-D curves along its top and right edges.

  Args:
    us: Coordinates of the sheet, passed to `plot_sheet`.
    sheet: 2-D values to draw; its shape is used as the grid shape.
    u0, v0: Optional abscissa and values of the curve drawn above the sheet
      (sharing the sheet's x axis). The curve is drawn only if `u0` is given.
    u1, v1: Optional ordinate and values of the curve drawn to the right of
      the sheet (sharing the sheet's y axis), plotted as `v1` against `u1`.
      The curve is drawn only if `u1` is given.

  Returns:
    A 2x2 nested list of axes laid out as on the figure:
    `[[top_axes, None], [sheet_axes, right_axes]]`, with `None` for any
    curve that was not drawn.
  """
  fig = plt.figure()

  # Main panel. `fig.add_axes` makes the new axes current, so the pyplot
  # calls inside `plot_sheet` draw into it.
  ax_joint = fig.add_axes([0.1, 0.1, 0.6, 0.6])
  plot_sheet(us, sheet, sheet.shape)
  # `plot_sheet` hides the axes; turn them back on for this figure.
  ax_joint.axis('on')
  ax_joint.set_xlim([-12,12])
  ax_joint.set_ylim([-10,10])
  box = ax_joint.get_position()

  u0_ax = None
  if u0 is not None:
    # Thin strip directly above the main panel.
    u0_ax = fig.add_axes([box.x0, box.y1+0.005, box.width, 0.15],
                         sharex=ax_joint)
    u0_ax.plot(u0, v0)
    u0_ax.axis('off')

  u1_ax = None
  if u1 is not None:
    # Thin strip directly to the right of the main panel.
    u1_ax = fig.add_axes([box.x1+0.005, box.y0, 0.15, box.height],
                         sharey=ax_joint)
    u1_ax.plot(v1, u1)
    u1_ax.axis('off')

  return [[u0_ax, None],
          [ax_joint, u1_ax]]

In [0]:
# Joint density in the main panel, p(u0) along the top and p(u1) on the right.
axes = plot_extras(
    us=us_mesh, sheet=pu_mesh,
    u0=u0_mesh, v0=pu0,
    u1=u1_mesh, v1=pu1)
_ = axes[0][0].set_title('Joint and marginals')

Calculate the conditional distributions

Divide by the marginals to calculate the conditional distribution.

Here is p(u1|u0) and a slice of it.


In [0]:
at_u0 = 1.0
# Sample the re-meshed joint density along the vertical line u0 = at_u0.
u1_slice = scipy.interpolate.griddata(
    points=us_mesh_flat, values=pu_mesh_flat,
    xi=(np.full_like(u1_mesh, at_u0), u1_mesh),
    fill_value=0.0)

# Dividing each column of the joint by p(u0) gives p(u1|u0) in the main panel.
axes = plot_extras(
    us_mesh, pu_mesh/pu0[None, :],
    u0=u0_mesh, v0=pu0,
    u1=u1_mesh, v1=u1_slice)

# Mark the sliced column with a white vertical line on the main panel.
plt.sca(axes[1][0])
plt.plot([at_u0]*2, plt.ylim(), color='w')
axes[1][0].set_xlabel('p(u1|u0)')
axes[0][0].set_title('p(u0)')
_ = axes[1][1].set_title(f'p(u1|u0={at_u0})')

Here is p(u0|u1) and a slice of it.


In [0]:
at_u1 = 2.0
# Sample the re-meshed joint density along the horizontal line u1 = at_u1.
u0_slice = scipy.interpolate.griddata(
    points=us_mesh_flat, values=pu_mesh_flat,
    xi=(u0_mesh, np.full_like(u0_mesh, at_u1)),
    fill_value=0.0)

# Dividing each row of the joint by p(u1) gives p(u0|u1) in the main panel.
axes = plot_extras(
    us_mesh, pu_mesh/pu1[:,None],
    u0=u0_mesh, v0=u0_slice,
    u1=u1_mesh, v1=pu1)

# Mark the sliced row with a white horizontal line on the main panel.
plt.sca(axes[1][0])
plt.plot(plt.xlim(), [at_u1]*2, color='w')
axes[1][0].set_xlabel('p(u0|u1)')
axes[1][1].set_title('p(u1)')
_ = axes[0][0].set_title(f'p(u0|u1={at_u1})')