In [0]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Notebook originally contributed by: MarkDaoust
|
|
This notebook uses a Jacobian determinant to perform a change of variables on a probability distribution, and makes some nice visualizations along the way.
This is just a quick walkthrough; for any serious application you should probably use TensorFlow Probability.
In [0]:
import tensorflow as tf
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
# Larger default figures so the 2-D density maps are readable.
mpl.rcParams.update({'figure.figsize': (12, 8)})
In [0]:
# Regular evaluation grid in x-space: 241 points over x0 in [-12, 12] and
# 201 points over x1 in [-10, 10], i.e. a spacing of 0.1 on both axes.
x0 = tf.linspace(-12.0, 12.0, 240 + 1)[:, tf.newaxis]
x1 = tf.linspace(-10.0, 10.0, 200 + 1)[tf.newaxis, :]
# Trailing axis holds the (x0, x1) coordinates of each grid point.
xs = tf.stack(tf.meshgrid(x0, x1), axis=-1)
# One row per grid point: the form the distribution and transform consume.
xs_flat = tf.reshape(xs, [-1, 2])
batch_shape = xs.shape[:-1]
print(batch_shape, xs.shape, xs_flat.shape, sep='\n')
Create a multivariate normal distribution over x
In [0]:
from tensorflow_probability import distributions as tfd

# A correlated 2-D Gaussian over x with mean (2, 1) and covariance
# [[9, -2], [-2, 4]].
# MultivariateNormalFullCovariance is deprecated in TFP; its documented
# replacement is MultivariateNormalTriL parameterized by the Cholesky factor
# of the covariance. Float literals keep the parameters float32, matching
# the float32 grid built with tf.linspace.
loc = tf.constant([2.0, 1.0])
covariance = tf.constant([[9.0, -2.0], [-2.0, 4.0]])
dist = tfd.MultivariateNormalTriL(
    loc=loc, scale_tril=tf.linalg.cholesky(covariance))
Calculate the probability on the x grid, and take a few samples
In [0]:
# Density p(x) at every grid point.
px = dist.prob(xs_flat)
# 50 samples to overlay on the plots. An explicit op seed makes the scatter
# reproducible when the notebook is re-run from a fresh kernel.
sample_xs = dist.sample(sample_shape=50, seed=42)
Plot the density and samples
In [0]:
def plot_sheet(coords_flat, value_flat, batch_shape, sample=None, **kwargs):
  """Draw a scalar field that was evaluated on a (possibly warped) 2-D grid.

  Args:
    coords_flat: coordinates of the grid points; `[..., 0]` is used as the
      horizontal and `[..., 1]` as the vertical coordinate. Each is reshaped
      to `batch_shape`.
    value_flat: one value per grid point, reshaped to `batch_shape`.
    batch_shape: shape of the underlying grid.
    sample: optional points (`sample[:, 0]`, `sample[:, 1]`) drawn as white
      dots on top of the field.
    **kwargs: forwarded to `plt.pcolormesh` (e.g. `vmax`).
  """
  horizontal, vertical = (
      tf.reshape(coords_flat[..., axis], batch_shape) for axis in (0, 1))
  grid_values = tf.reshape(value_flat, batch_shape)

  if sample is not None:
    plt.scatter(sample[:, 0], sample[:, 1], c='w', marker='.', zorder=1)

  # Gouraud shading interpolates colours between the (warped) grid vertices.
  plt.pcolormesh(horizontal, vertical, grid_values,
                 zorder=-1, shading='gouraud', **kwargs)
  plt.axis('off')
  plt.gca().set_aspect('equal')
In [0]:
# The original density p(x) on the regular x grid, with samples overlaid.
plot_sheet(xs_flat, px, batch_shape, sample=sample_xs)
plt.title('A MultiVariate probability density')
cbar = plt.colorbar(label='Density $1/x^2$')
In [0]:
def transform(xs):
  """Smoothly warp points from x-space into u-space.

      u0 = x0 + 0.9 * sin(x1)
      u1 = x1 + 0.9 * cos(x0)

  Args:
    xs: tensor whose last axis holds the (x0, x1) coordinates
      (assumes that axis has size 2).

  Returns:
    Tensor whose last axis holds (u0, u1); leading axes match `xs`.
  """
  x0, x1 = xs[..., 0], xs[..., 1]
  return tf.stack([x0 + 0.9 * tf.sin(x1), x1 + 0.9 * tf.cos(x0)], axis=-1)
In [0]:
# Map both the grid and the samples into u-space.
us_flat, sample_us = transform(xs_flat), transform(sample_xs)
Plotting the original density values at the transformed coordinates gives a wrong density map.
In [0]:
# Re-using the p(x) values at the warped positions ignores how areas changed.
plot_sheet(us_flat, px, batch_shape, sample=sample_us)
plt.title('This is wrong.')
cbar = plt.colorbar(label='Wrong units: $1/x^2$')
In [0]:
# Record the forward transform so its per-point Jacobian du/dx can be taken.
with tf.GradientTape() as tape:
  tape.watch(xs_flat)
  us_flat = transform(xs_flat)
# Take the Jacobian outside the tape context so the Jacobian computation
# itself is not recorded on the tape. Result shape: [n_points, 2, 2].
js = tape.batch_jacobian(us_flat, xs_flat)
# |det J| is the local area scale factor (u-area per x-area). The change of
# variables needs the absolute value; it only differs from det J where the
# transform flips orientation (det < 0).
js_scale = tf.abs(tf.linalg.det(js))
The Jacobian determinant tells you, locally, how much each differential area has expanded or contracted, and therefore how much the density has changed.
In [0]:
# 1/det(du/dx) is the factor by which the density changes at each point.
plot_sheet(us_flat, 1/js_scale, batch_shape)
plt.title('Jacobian density change')
cbar = plt.colorbar()
# The plotted quantity is 1/det(du/dx) = x-area per u-area, so its units
# are x^2/u^2 (the old label had the ratio inverted).
cbar.set_label('Area scale factor $x^2/u^2$')
Divide the density in x by the (absolute value of the) Jacobian determinant to get the density in u: $p_u(u) = p_x(x) / \left|\det \frac{\partial u}{\partial x}\right|$.
In [0]:
# Change of variables: density in u = density in x / Jacobian determinant.
pu = px / js_scale
plot_sheet(us_flat, pu, batch_shape, sample=sample_us)
plt.title('Transformed density')
cbar = plt.colorbar(label='Density $1/u^2$')
Here is everything together:
In [0]:
# Shared colour ceiling so the three density panels are directly comparable.
vmax = pu.numpy().max()
# One entry per subplot: coordinates, values, title, colorbar label, and
# extra pcolormesh keyword arguments (the Jacobian panel has its own scale).
panels = [
    (xs_flat, px, 'A MultiVariate probability density',
     'Density $1/x^2$', {'vmax': vmax}),
    (us_flat, px, 'This is wrong',
     'Wrong units: $1/x^2$', {'vmax': vmax}),
    (us_flat, 1/js_scale, 'Jacobian density change',
     'Area scale factor $x^2/u^2$', {}),
    (us_flat, pu, 'Transformed density',
     'Density $1/u^2$', {'vmax': vmax}),
]
for index, (coords, values, title, label, extra) in enumerate(panels, start=1):
  plt.subplot(2, 2, index)
  plot_sheet(coords, values, batch_shape, **extra)
  plt.title(title)
  cbar = plt.colorbar()
  cbar.set_label(label)
In [0]:
# A fresh regular grid in u-space, covering the same window as the x grid.
u0_mesh = tf.linspace(-12.0, 12.0, 240 + 1)
u1_mesh = tf.linspace(-10.0, 10.0, 200 + 1)
us_mesh = tf.stack(tf.meshgrid(u0_mesh, u1_mesh), axis=-1)
us_mesh_flat = tf.reshape(us_mesh, [-1, 2])
new_batch_shape = us_mesh.shape[:-1]
In [0]:
import scipy.interpolate

# The warped u-points are scattered, so interpolate p(u) back onto the
# regular u mesh; mesh points the interpolation cannot reach are set to 0.
pu_mesh_flat = scipy.interpolate.griddata(
    us_flat, pu, us_mesh_flat, fill_value=0.0)
pu_mesh = tf.reshape(pu_mesh_flat, new_batch_shape)
In [0]:
# Top: density on the warped grid. Bottom: the same density re-meshed.
plt.subplot(2, 1, 1)
plot_sheet(us_flat, pu, batch_shape, vmax=vmax)
plt.title('Transformed density')
cbar = plt.colorbar(label='Density $1/u^2$')

plt.subplot(2, 1, 2)
plot_sheet(us_mesh_flat, pu_mesh, new_batch_shape)
plt.title('Transformed density, re-meshed')
cbar = plt.colorbar(label='Density $1/u^2$')
With the density in u re-evaluated on a regular grid, it's possible to integrate it numerically and get meaningful results.
Integrate to get the two marginals.
In [0]:
import scipy.integrate

# scipy.integrate.trapz was deprecated and removed in SciPy 1.14;
# scipy.integrate.trapezoid is the same trapezoidal rule under its current name.
pu_mesh_np = np.asarray(pu_mesh)
# pu_mesh rows follow u1 and columns follow u0 (the meshgrid(u0, u1) layout).
# Integrating along the last axis (over u0) leaves the marginal p(u1).
# Integrating along axis 0 (over u1) leaves p(u0).
pu1 = scipy.integrate.trapezoid(pu_mesh_np, u0_mesh, axis=-1)
pu0 = scipy.integrate.trapezoid(pu_mesh_np, u1_mesh, axis=0)
Integrate the marginals to sanity-check that the total probability mass is $\approx 1.0$.
In [0]:
# Both totals should come out close to 1.0. A small shortfall is expected
# from the finite grid window and from the zero fill used when re-meshing.
# (trapezoid replaces scipy.integrate.trapz, which was removed in SciPy 1.14.)
print(scipy.integrate.trapezoid(pu1, u1_mesh))
print(scipy.integrate.trapezoid(pu0, u0_mesh))
Plot the marginals; they're still surprisingly Gaussian.
In [0]:
def plot_extras(us, sheet, u0=None, v0=None, u1=None, v1=None,
                xlim=(-12, 12), ylim=(-10, 10)):
  """Plot a 2-D field with optional 1-D curves along its top and right edges.

  Args:
    us: grid coordinates with a trailing (u0, u1) axis; passed to
      `plot_sheet` together with `sheet.shape` as the grid shape.
    sheet: values on the grid.
    u0, v0: optional curve drawn above the main panel as `plot(u0, v0)`,
      sharing the main panel's horizontal axis.
    u1, v1: optional curve drawn to the right of the main panel as
      `plot(v1, u1)`, sharing the main panel's vertical axis.
    xlim, ylim: axis limits of the main panel. The defaults match the grid
      window used throughout this notebook.

  Returns:
    A 2x2 nested list of axes: `[1][0]` is the main panel, `[0][0]` the top
    curve and `[1][1]` the right-hand curve. Curves that were not requested,
    and `[0][1]`, stay None.
  """
  axes = [[None, None],
          [None, None]]

  fig = plt.figure()
  ax_joint = fig.add_axes([0.1, 0.1, 0.6, 0.6])
  axes[1][0] = ax_joint
  plot_sheet(us, sheet, sheet.shape)
  # plot_sheet hides the axes; turn them back on for the main panel.
  plt.axis('on')
  plt.xlim(list(xlim))
  plt.ylim(list(ylim))

  position = ax_joint.get_position()
  if u0 is not None:
    # Thin strip directly above the main panel; add_axes makes it current,
    # so the following pyplot calls draw into it.
    axes[0][0] = fig.add_axes(
        [position.x0, position.y1 + 0.005, position.width, 0.15],
        sharex=ax_joint)
    plt.plot(u0, v0)
    plt.axis('off')
  if u1 is not None:
    # Thin strip directly to the right; the curve is drawn sideways.
    axes[1][1] = fig.add_axes(
        [position.x1 + 0.005, position.y0, 0.15, position.height],
        sharey=ax_joint)
    plt.plot(v1, u1)
    plt.axis('off')
  return axes
In [0]:
# Joint density on the regular u mesh, with both marginals along its edges.
axes = plot_extras(us_mesh, pu_mesh,
                   u0=u0_mesh, v0=pu0,
                   u1=u1_mesh, v1=pu1)
_ = axes[0][0].set_title('Joint and marginals')
In [0]:
# Conditional density p(u1 | u0): each u0-column of the joint divided by p(u0).
at_u0 = 1.0
# Interpolating the joint along the vertical line u0 = at_u0 gives
# p(u0=at_u0, u1), which is the joint, not the conditional.
joint_slice = scipy.interpolate.griddata(
    us_mesh_flat, pu_mesh_flat,
    (at_u0 * np.ones_like(u1_mesh), u1_mesh), fill_value=0.0)
# Normalize by the marginal p(u0=at_u0) so the side curve really is the
# conditional p(u1 | u0=at_u0) named in its title.
u1_slice = joint_slice / np.interp(at_u0, u0_mesh, pu0)

axes = plot_extras(us_mesh, pu_mesh / pu0[None, :],
                   u0_mesh, pu0, u1_mesh, u1_slice)
plt.sca(axes[1][0])
# Mark where the slice was taken.
plt.plot([at_u0, at_u0], plt.ylim(), color='w')
axes[1][0].set_xlabel('p(u1|u0)')
axes[0][0].set_title('p(u0)')
_ = axes[1][1].set_title(f'p(u1|u0={at_u0})')
Here is $p(u_0|u_1)$ and a slice of it at a fixed $u_1$.
In [0]:
# Conditional density p(u0 | u1): each u1-row of the joint divided by p(u1).
at_u1 = 2.0
# Interpolating the joint along the horizontal line u1 = at_u1 gives
# p(u0, u1=at_u1), which is the joint, not the conditional.
joint_slice = scipy.interpolate.griddata(
    us_mesh_flat, pu_mesh_flat,
    (u0_mesh, at_u1 * np.ones_like(u0_mesh)), fill_value=0.0)
# Normalize by the marginal p(u1=at_u1) so the top curve really is the
# conditional p(u0 | u1=at_u1) named in its title.
u0_slice = joint_slice / np.interp(at_u1, u1_mesh, pu1)

axes = plot_extras(us_mesh, pu_mesh / pu1[:, None],
                   u0_mesh, u0_slice, u1_mesh, pu1)
plt.sca(axes[1][0])
# Mark where the slice was taken.
plt.plot(plt.xlim(), [at_u1, at_u1], color='w')
axes[1][0].set_xlabel('p(u0|u1)')
axes[1][1].set_title('p(u1)')
_ = axes[0][0].set_title(f'p(u0|u1={at_u1})')