In [0]:
# This notebook illustrates how to use numpyro
# https://github.com/pyro-ppl/numpyro
# Speed comparison with TFP
# https://rlouf.github.io/post/jax-random-walk-metropolis/
# Speed comparison with pymc3
# https://www.kaggle.com/s903124/numpyro-speed-benchmark
In [5]:
# Standard Python libraries
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import time
#import numpy as np
#np.set_printoptions(precision=3)
import glob
import matplotlib.pyplot as plt
import PIL
import imageio
from IPython import display
%matplotlib inline
import sklearn
import seaborn as sns
sns.set(style="ticks", color_codes=True)
import pandas as pd
pd.set_option('display.precision', 2) # 2 decimal places
pd.set_option('display.max_rows', 20)
pd.set_option('display.max_columns', 30)
pd.set_option('display.width', 100) # wide windows
In [1]:
# As of 5/25/20, colab has jax=0.1.67 and jaxlib=0.1.47 builtin
import jax
import jax.numpy as np
import numpy as onp # original numpy
from jax import grad, hessian, jit, vmap, random
print("jax version {}".format(jax.__version__))
In [2]:
# Check if GPU is available
!nvidia-smi
# Check if JAX is using GPU
print("jax backend {}".format(jax.lib.xla_bridge.get_backend().platform))
In [3]:
#https://github.com/pyro-ppl/numpyro/issues/531
# https://github.com/pyro-ppl/numpyro
!pip install numpyro # requires jax=0.1.57, jaxlib=0.1.37
print("jax version {}".format(jax.__version__))
print("jax backend {}".format(jax.lib.xla_bridge.get_backend().platform))
In [19]:
'''
#https://github.com/pyro-ppl/numpyro/issues/531
#!pip install --upgrade jax==0.1.57
#!pip install --upgrade jaxlib==0.1.37
#!pip install --upgrade -q https://storage.googleapis.com/jax-releases/cuda$(echo $CUDA_VERSION | sed -e 's/\.//' -e 's/\..*//')/jaxlib-$(pip search jaxlib | grep -oP '[0-9\.]+' | head -n 1)-cp36-none-linux_x86_64.whl
#!pip install --upgrade -q jax
ver = !echo $CUDA_VERSION
print(ver)
# install jaxlib
PYTHON_VERSION='cp36' # alternatives: cp36, cp37, cp38
CUDA_VERSION='cuda101' # alternatives: cuda92, cuda100, cuda101, cuda102
PLATFORM='linux_x86_64'
BASE_URL='https://storage.googleapis.com/jax-releases'
fname = f'{BASE_URL}/{CUDA_VERSION}/jaxlib-0.1.37-{PYTHON_VERSION}-none-{PLATFORM}.whl'
print(fname)
#!pip install --upgrade $BASE_URL/$CUDA_VERSION/jaxlib-0.1.37-$PYTHON_VERSION-none-$PLATFORM.whl
!pip install --upgrade $fname
!pip install numpyro
!pip install --upgrade jax==0.1.57
print("jax backend {}".format(jax.lib.xla_bridge.get_backend().platform))
'''
In [1]:
'''
# The latest version uses jax >= 0.1.65, jaxlib >= 0.1.45
# https://github.com/pyro-ppl/numpyro/blob/master/setup.py
#https://medium.com/@ashwindesilva/how-to-use-google-colaboratory-to-clone-a-github-repository-e07cf8d3d22b
!git clone https://github.com/pyro-ppl/numpyro.git
%cd numpyro
!pip install -e .[dev]
print("jax version {}".format(jax.__version__))
print("jax backend {}".format(jax.lib.xla_bridge.get_backend().platform))
'''
In [0]:
import numpyro
import numpyro.distributions as dist
from numpyro.diagnostics import hpdi
from numpyro.distributions.transforms import AffineTransform
from numpyro.infer import MCMC, NUTS, Predictive
rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)
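Note that JAX uses explicit, splittable PRNG keys rather than a global random state: reusing a key reproduces the same draw, and random.split derives fresh independent keys. A minimal sketch (the printed values are arbitrary):
In [ ]:
# Same key => same draw; split() => fresh, independent keys.
key = random.PRNGKey(42)
print(random.normal(key, (3,)))   # some values
print(random.normal(key, (3,)))   # identical values, since the key is reused
key1, key2 = random.split(key)
print(random.normal(key1, (3,)))  # different values from a fresh subkey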
In [87]:
# A single univariate Gaussian
mu = 1.5
sigma = 2
d = dist.Normal(mu, sigma)
dir(d)
In [88]:
#rng_key, rng_key_ = random.split(rng_key)
nsamples = 1000
ys = d.sample(rng_key_, (nsamples,))
print(ys.shape)
mu_hat = np.mean(ys,0)
print(mu_hat)
sigma_hat = np.std(ys, 0)
print(sigma_hat)
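As a sanity check, we can compare log_prob against the Gaussian log density formula $\log N(y|\mu,\sigma^2) = -\tfrac{1}{2}\log(2\pi\sigma^2) - \frac{(y-\mu)^2}{2\sigma^2}$; a minimal sketch:
In [ ]:
# Compare log_prob with the Gaussian log density computed by hand.
y0 = ys[0]
manual = -0.5 * np.log(2 * np.pi * sigma**2) - (y0 - mu)**2 / (2 * sigma**2)
print(d.log_prob(y0))
print(manual)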
In [0]:
mu = np.array([-1.0, 1.0])
sigma = np.array([1.0, 2.0]) # standard deviations
Sigma = np.diag(sigma**2) # diagonal covariance matrix
d2 = dist.MultivariateNormal(mu, Sigma)
In [94]:
#rng_key, rng_key_ = random.split(rng_key)
nsamples = 1000
ys = d2.sample(rng_key_, (nsamples,))
print(ys.shape)
mu_hat = np.mean(ys,0)
print(mu_hat)
Sigma_hat = onp.cov(ys, rowvar=False) # jax.numpy.cov is not implemented
print(Sigma_hat)
NumPyro, Pyro, and TFP all distinguish between 'event shape' and 'batch shape'. For a single D-dimensional Gaussian, the event shape is (D,) and the batch shape is (), meaning we have a single instance of this distribution. If the covariance is diagonal, we can instead view it as D independent 1d Gaussians stored along the batch dimension; this version has event shape () and batch shape (D,).
When we sample from a distribution, we also specify the sample_shape. Suppose we draw N samples from a single D-dimensional diagonal Gaussian, and N samples from D independent 1d Gaussians. These samples will have the same shape, but the semantics of log_prob differ, as we illustrate below.
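The general rule is that sample(key, sample_shape) returns an array of shape sample_shape + batch_shape + event_shape, while log_prob sums over the event dimensions and returns an array of shape sample_shape + batch_shape. A quick sketch of the rule:
In [ ]:
# Shape rule: samples have shape sample_shape + batch_shape + event_shape;
# log_prob has shape sample_shape + batch_shape.
d_batch = dist.Normal(np.zeros(2), np.ones(2))  # batch_shape (2,), event_shape ()
x = d_batch.sample(rng_key_, (5,))
print(x.shape)                    # (5, 2): sample (5,) + batch (2,)
print(d_batch.log_prob(x).shape)  # (5, 2): one log prob per 1d Gaussian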
In [85]:
d2 = dist.MultivariateNormal(mu, Sigma)
print(d2.event_shape)
print(d2.batch_shape)
nsamples = 1000
ys2 = d2.sample(rng_key_, (nsamples,))
print(ys2.shape)
# 2 independent 1d Gaussians (equivalent to one 2d diagonal Gaussian)
d3 = dist.Normal(mu, sigma) # scale = sqrt(diag(Sigma))
print(d3.event_shape)
print(d3.batch_shape)
ys3 = d3.sample(rng_key_, (nsamples,))
print(ys3.shape)
print(np.allclose(ys2, ys3))
In [86]:
y = ys2[0,:] # 2 numbers
print(d2.log_prob(y)) # log prob of a single 2d distribution on 2d input
print(d3.log_prob(y)) # log prob of two 1d distributions on 2d input
We can turn a set of independent distributions into a single product distribution using the Independent class.
In [98]:
d4 = dist.Independent(d3, 1) # reinterpret the rightmost batch dimension as an event dimension
print(d4.event_shape)
print(d4.batch_shape)
print(d4.log_prob(y))
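Since Independent simply sums the log probabilities over the reinterpreted dimensions, d4.log_prob(y) should match both the manual sum and the diagonal multivariate Gaussian:
In [ ]:
# Independent sums log probs over the reinterpreted batch dimension,
# so all three quantities agree.
print(np.sum(d3.log_prob(y)))  # manual sum of two 1d log probs
print(d2.log_prob(y))          # full 2d diagonal Gaussian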
We use the simple example from the Pyro intro. The goal is to infer the weight $\theta$ of an object, given noisy measurements $y$. We assume the following model: $$ \begin{align} \theta &\sim N(\mu=8.5, \tau^2=1.0)\\ y &\sim N(\theta, \sigma^2=0.75^2) \end{align} $$
where $\mu=8.5$ is our initial guess.
By Bayes' rule for Gaussians, the exact posterior, given a single observation $y=9.5$, is
$$ \begin{align} \theta|y &\sim N(m, s^2) \\ m &= \frac{\sigma^2 \mu + \tau^2 y}{\sigma^2 + \tau^2} = \frac{0.75^2 \times 8.5 + 1^2 \times 9.5}{0.75^2 + 1^2} = 9.14 \\ s^2 &= \frac{\sigma^2 \tau^2}{\sigma^2 + \tau^2} = \frac{0.75^2 \times 1^2}{0.75^2 + 1^2} = 0.6^2 \end{align} $$
In [108]:
mu = 8.5; tau = 1.0; sigma = 0.75; y = 9.5
m = (sigma**2 * mu + tau**2 * y)/(sigma**2 + tau**2)
s2 = (sigma**2 * tau**2)/(sigma**2 + tau**2)
s = np.sqrt(s2)
print(m)
print(s)
In [106]:
def model(guess, measurement=None):
    weight = numpyro.sample("weight", dist.Normal(guess, tau))
    return numpyro.sample("measurement", dist.Normal(weight, sigma), obs=measurement)
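Before conditioning on data, we can sample from the prior predictive distribution using the Predictive helper imported above; a minimal sketch:
In [ ]:
# Prior predictive: run the model without observations, sampling both
# the latent weight and simulated measurements.
prior_predictive = Predictive(model, num_samples=100)
prior_samples = prior_predictive(rng_key_, mu)
print(prior_samples["weight"].shape)       # (100,)
print(prior_samples["measurement"].shape)  # (100,)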
In [109]:
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_warmup=100, num_samples=1000)
guess = mu
measurement = y
mcmc.run(rng_key_, guess, measurement=measurement)
mcmc.print_summary()
samples = mcmc.get_samples()
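The MCMC estimate should agree closely with the exact posterior computed above ($m \approx 9.14$, $s \approx 0.6$):
In [ ]:
# Compare the MCMC approximation with the closed-form Gaussian posterior.
print("posterior mean: mcmc = {:.3f}, exact = {:.3f}".format(
    float(np.mean(samples["weight"])), m))
print("posterior std:  mcmc = {:.3f}, exact = {:.3f}".format(
    float(np.std(samples["weight"])), float(s)))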