In [0]:
# This notebook illustrates how to use numpyro
# https://github.com/pyro-ppl/numpyro

# Speed comparison with TFP
# https://rlouf.github.io/post/jax-random-walk-metropolis/
# Speed comparison with pymc3
# https://www.kaggle.com/s903124/numpyro-speed-benchmark

Installation


In [5]:
# Standard Python libraries
from __future__ import absolute_import, division, print_function, unicode_literals

import os
import time
#import numpy as np
#np.set_printoptions(precision=3)
import glob
import matplotlib.pyplot as plt
import PIL
import imageio

from IPython import display
%matplotlib inline

import sklearn

import seaborn as sns;
sns.set(style="ticks", color_codes=True)

import pandas as pd
pd.set_option('precision', 2) # 2 decimal places
pd.set_option('display.max_rows', 20)
pd.set_option('display.max_columns', 30)
pd.set_option('display.width', 100) # wide windows


/usr/local/lib/python3.6/dist-packages/statsmodels/tools/_testing.py:19: FutureWarning: pandas.util.testing is deprecated. Use the functions in the public API at pandas.testing instead.
  import pandas.util.testing as tm

In [1]:
# As of 5/25/20, colab has jax=0.1.67 and jaxlib=0.1.47 builtin

import jax
import jax.numpy as np
import numpy as onp # original numpy
from jax import grad, hessian, jit, vmap, random

print("jax version {}".format(jax.__version__))


jax version 0.1.67

In [2]:
# Check if GPU is available
!nvidia-smi

# Check if JAX is using GPU

print("jax backend {}".format(jax.lib.xla_bridge.get_backend().platform))


Mon May 25 21:46:40 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.82       Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   41C    P0    27W / 250W |      0MiB / 16280MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
jax backend gpu

In [3]:
#https://github.com/pyro-ppl/numpyro/issues/531

# https://github.com/pyro-ppl/numpyro
!pip install numpyro # requires jax=0.1.57, jaxlib=0.1.37

print("jax version {}".format(jax.__version__))
print("jax backend {}".format(jax.lib.xla_bridge.get_backend().platform))


Collecting numpyro
  Downloading https://files.pythonhosted.org/packages/b8/58/54e914bb6d8ee9196f8dbf28b81057fea81871fc171dbee03b790336d0c5/numpyro-0.2.4-py3-none-any.whl (159kB)
     |████████████████████████████████| 163kB 6.4MB/s 
Collecting jaxlib==0.1.37
  Downloading https://files.pythonhosted.org/packages/24/bf/e181454464b866f30f09b5d74d1dd08e8b15e032716d8bcc531c659776ab/jaxlib-0.1.37-cp36-none-manylinux2010_x86_64.whl (25.4MB)
     |████████████████████████████████| 25.4MB 4.8MB/s 
Collecting jax==0.1.57
  Downloading https://files.pythonhosted.org/packages/ae/f2/ea981ed2659f70a1d8286ce41b5e74f1d9df844c1c6be6696144ed3f2932/jax-0.1.57.tar.gz (255kB)
     |████████████████████████████████| 256kB 42.9MB/s 
Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from numpyro) (4.41.1)
Requirement already satisfied: scipy in /usr/local/lib/python3.6/dist-packages (from jaxlib==0.1.37->numpyro) (1.4.1)
Requirement already satisfied: numpy>=1.12 in /usr/local/lib/python3.6/dist-packages (from jaxlib==0.1.37->numpyro) (1.18.4)
Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from jaxlib==0.1.37->numpyro) (1.12.0)
Requirement already satisfied: protobuf>=3.6.0 in /usr/local/lib/python3.6/dist-packages (from jaxlib==0.1.37->numpyro) (3.10.0)
Requirement already satisfied: absl-py in /usr/local/lib/python3.6/dist-packages (from jaxlib==0.1.37->numpyro) (0.9.0)
Requirement already satisfied: opt_einsum in /usr/local/lib/python3.6/dist-packages (from jax==0.1.57->numpyro) (3.2.1)
Collecting fastcache
  Downloading https://files.pythonhosted.org/packages/5f/a3/b280cba4b4abfe5f5bdc643e6c9d81bf3b9dc2148a11e5df06b6ba85a560/fastcache-1.1.0.tar.gz
Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from protobuf>=3.6.0->jaxlib==0.1.37->numpyro) (46.3.0)
Building wheels for collected packages: jax, fastcache
  Building wheel for jax (setup.py) ... done
  Created wheel for jax: filename=jax-0.1.57-cp36-none-any.whl size=297709 sha256=c2bea348ae9097f522298dfd173b13b80d87b6b3a8694218cc7f41561b5baef6
  Stored in directory: /root/.cache/pip/wheels/8a/b4/75/859bcdaf181569124306615bd9b68c747725c60bfa68826378
  Building wheel for fastcache (setup.py) ... done
  Created wheel for fastcache: filename=fastcache-1.1.0-cp36-cp36m-linux_x86_64.whl size=39211 sha256=8ccffe72cb0f057afa1a0020eba9ee41bfde1c83b251e4a4f8c5051f751b9233
  Stored in directory: /root/.cache/pip/wheels/6a/80/bf/30024738b03fa5aa521e2a2ac952a8d77d0c65e68d92bcd3b6
Successfully built jax fastcache
Installing collected packages: jaxlib, fastcache, jax, numpyro
  Found existing installation: jaxlib 0.1.47
    Uninstalling jaxlib-0.1.47:
      Successfully uninstalled jaxlib-0.1.47
  Found existing installation: jax 0.1.67
    Uninstalling jax-0.1.67:
      Successfully uninstalled jax-0.1.67
Successfully installed fastcache-1.1.0 jax-0.1.57 jaxlib-0.1.37 numpyro-0.2.4
jax version 0.1.67
jax backend gpu

In [19]:
'''
#https://github.com/pyro-ppl/numpyro/issues/531

#!pip install --upgrade jax==0.1.57  
#!pip install --upgrade jaxlib==0.1.37  

#!pip install --upgrade -q https://storage.googleapis.com/jax-releases/cuda$(echo $CUDA_VERSION | sed -e 's/\.//' -e 's/\..*//')/jaxlib-$(pip search jaxlib | grep -oP '[0-9\.]+' | head -n 1)-cp36-none-linux_x86_64.whl
#!pip install --upgrade -q jax

ver = !echo $CUDA_VERSION
print(ver)

# install jaxlib
PYTHON_VERSION='cp36'  # alternatives: cp36, cp37, cp38
CUDA_VERSION='cuda101'  # alternatives: cuda92, cuda100, cuda101, cuda102
PLATFORM='linux_x86_64'  # alternatives: linux_x86_64
BASE_URL='https://storage.googleapis.com/jax-releases'
fname = f'{BASE_URL}/{CUDA_VERSION}/jaxlib-0.1.37-{PYTHON_VERSION}-none-{PLATFORM}.whl'
print(fname)
#!pip install --upgrade $BASE_URL/$CUDA_VERSION/jaxlib-0.1.37-$PYTHON_VERSION-none-$PLATFORM.whl
!pip install --upgrade $fname

!pip install numpyro
!pip install --upgrade jax==0.1.57  
print("jax backend {}".format(jax.lib.xla_bridge.get_backend().platform))
'''


['cuda101']
https://storage.googleapis.com/jax-releases/cuda101/jaxlib-0.1.37-cp36-none-linux_x86_64.whl
Collecting jaxlib==0.1.37
  Downloading https://storage.googleapis.com/jax-releases/cuda101/jaxlib-0.1.37-cp36-none-linux_x86_64.whl (48.3MB)
     |████████████████████████████████| 48.3MB 66kB/s 
Requirement already satisfied, skipping upgrade: scipy in /usr/local/lib/python3.6/dist-packages (from jaxlib==0.1.37) (1.4.1)
Requirement already satisfied, skipping upgrade: six in /usr/local/lib/python3.6/dist-packages (from jaxlib==0.1.37) (1.12.0)
Requirement already satisfied, skipping upgrade: numpy>=1.12 in /usr/local/lib/python3.6/dist-packages (from jaxlib==0.1.37) (1.18.4)
Requirement already satisfied, skipping upgrade: absl-py in /usr/local/lib/python3.6/dist-packages (from jaxlib==0.1.37) (0.9.0)
Requirement already satisfied, skipping upgrade: protobuf>=3.6.0 in /usr/local/lib/python3.6/dist-packages (from jaxlib==0.1.37) (3.10.0)
Requirement already satisfied, skipping upgrade: setuptools in /usr/local/lib/python3.6/dist-packages (from protobuf>=3.6.0->jaxlib==0.1.37) (46.3.0)
ERROR: numpyro 0.2.4 has requirement jax>=0.1.65, but you'll have jax 0.1.57 which is incompatible.
ERROR: numpyro 0.2.4 has requirement jaxlib>=0.1.45, but you'll have jaxlib 0.1.37 which is incompatible.
Installing collected packages: jaxlib
  Found existing installation: jaxlib 0.1.47
    Uninstalling jaxlib-0.1.47:
      Successfully uninstalled jaxlib-0.1.47
Successfully installed jaxlib-0.1.37
Requirement already satisfied: numpyro in ./numpyro (0.2.4)
Processing /root/.cache/pip/wheels/3d/8d/d8/b0463ab20eb85b4ae7c602f7fbc0bd890f2af483b61e6d6096/jax-0.1.68-cp36-none-any.whl
Collecting jaxlib>=0.1.45
  Using cached https://files.pythonhosted.org/packages/ea/c0/64c0e5a2c6da1d3ffdec95da74abf14df2c7508776ff5f155461fec1ef1d/jaxlib-0.1.47-cp36-none-manylinux2010_x86_64.whl
Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from numpyro) (4.41.1)
Requirement already satisfied: numpy>=1.12 in /usr/local/lib/python3.6/dist-packages (from jax>=0.1.65->numpyro) (1.18.4)
Requirement already satisfied: opt-einsum in /usr/local/lib/python3.6/dist-packages (from jax>=0.1.65->numpyro) (3.2.1)
Requirement already satisfied: absl-py in /usr/local/lib/python3.6/dist-packages (from jax>=0.1.65->numpyro) (0.9.0)
Requirement already satisfied: scipy in /usr/local/lib/python3.6/dist-packages (from jaxlib>=0.1.45->numpyro) (1.4.1)
Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from absl-py->jax>=0.1.65->numpyro) (1.12.0)
Installing collected packages: jax, jaxlib
  Found existing installation: jax 0.1.57
    Uninstalling jax-0.1.57:
      Successfully uninstalled jax-0.1.57
  Found existing installation: jaxlib 0.1.37
    Uninstalling jaxlib-0.1.37:
      Successfully uninstalled jaxlib-0.1.37
Successfully installed jax-0.1.68 jaxlib-0.1.47
Processing /root/.cache/pip/wheels/8a/b4/75/859bcdaf181569124306615bd9b68c747725c60bfa68826378/jax-0.1.57-cp36-none-any.whl
Requirement already satisfied, skipping upgrade: fastcache in /usr/local/lib/python3.6/dist-packages (from jax==0.1.57) (1.1.0)
Requirement already satisfied, skipping upgrade: numpy>=1.12 in /usr/local/lib/python3.6/dist-packages (from jax==0.1.57) (1.18.4)
Requirement already satisfied, skipping upgrade: absl-py in /usr/local/lib/python3.6/dist-packages (from jax==0.1.57) (0.9.0)
Requirement already satisfied, skipping upgrade: opt-einsum in /usr/local/lib/python3.6/dist-packages (from jax==0.1.57) (3.2.1)
Requirement already satisfied, skipping upgrade: six in /usr/local/lib/python3.6/dist-packages (from absl-py->jax==0.1.57) (1.12.0)
ERROR: numpyro 0.2.4 has requirement jax>=0.1.65, but you'll have jax 0.1.57 which is incompatible.
Installing collected packages: jax
  Found existing installation: jax 0.1.68
    Uninstalling jax-0.1.68:
      Successfully uninstalled jax-0.1.68
Successfully installed jax-0.1.57
jax backend cpu

In [1]:
'''

# The latest version uses jax >= 0.1.65, jaxlib >= 0.1.45
# https://github.com/pyro-ppl/numpyro/blob/master/setup.py

#https://medium.com/@ashwindesilva/how-to-use-google-colaboratory-to-clone-a-github-repository-e07cf8d3d22b

!git clone https://github.com/pyro-ppl/numpyro.git 
%cd numpyro
!pip install -e .[dev]

print("jax version {}".format(jax.__version__))
print("jax backend {}".format(jax.lib.xla_bridge.get_backend().platform))
'''


fatal: destination path 'numpyro' already exists and is not an empty directory.
/content/numpyro
Obtaining file:///content/numpyro
Requirement already satisfied: jax>=0.1.65 in /usr/local/lib/python3.6/dist-packages (from numpyro==0.2.4) (0.1.68)
Requirement already satisfied: jaxlib>=0.1.45 in /usr/local/lib/python3.6/dist-packages (from numpyro==0.2.4) (0.1.47)
Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from numpyro==0.2.4) (4.41.1)
Requirement already satisfied: ipython in /usr/local/lib/python3.6/dist-packages (from numpyro==0.2.4) (5.5.0)
Requirement already satisfied: isort in /usr/local/lib/python3.6/dist-packages (from numpyro==0.2.4) (4.3.21)
Requirement already satisfied: absl-py in /usr/local/lib/python3.6/dist-packages (from jax>=0.1.65->numpyro==0.2.4) (0.9.0)
Requirement already satisfied: opt-einsum in /usr/local/lib/python3.6/dist-packages (from jax>=0.1.65->numpyro==0.2.4) (3.2.1)
Requirement already satisfied: numpy>=1.12 in /usr/local/lib/python3.6/dist-packages (from jax>=0.1.65->numpyro==0.2.4) (1.18.4)
Requirement already satisfied: scipy in /usr/local/lib/python3.6/dist-packages (from jaxlib>=0.1.45->numpyro==0.2.4) (1.4.1)
Requirement already satisfied: pickleshare in /usr/local/lib/python3.6/dist-packages (from ipython->numpyro==0.2.4) (0.7.5)
Requirement already satisfied: simplegeneric>0.8 in /usr/local/lib/python3.6/dist-packages (from ipython->numpyro==0.2.4) (0.8.1)
Requirement already satisfied: traitlets>=4.2 in /usr/local/lib/python3.6/dist-packages (from ipython->numpyro==0.2.4) (4.3.3)
Requirement already satisfied: prompt-toolkit<2.0.0,>=1.0.4 in /usr/local/lib/python3.6/dist-packages (from ipython->numpyro==0.2.4) (1.0.18)
Requirement already satisfied: decorator in /usr/local/lib/python3.6/dist-packages (from ipython->numpyro==0.2.4) (4.4.2)
Requirement already satisfied: pygments in /usr/local/lib/python3.6/dist-packages (from ipython->numpyro==0.2.4) (2.1.3)
Requirement already satisfied: pexpect; sys_platform != "win32" in /usr/local/lib/python3.6/dist-packages (from ipython->numpyro==0.2.4) (4.8.0)
Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.6/dist-packages (from ipython->numpyro==0.2.4) (46.3.0)
Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from absl-py->jax>=0.1.65->numpyro==0.2.4) (1.12.0)
Requirement already satisfied: ipython-genutils in /usr/local/lib/python3.6/dist-packages (from traitlets>=4.2->ipython->numpyro==0.2.4) (0.2.0)
Requirement already satisfied: wcwidth in /usr/local/lib/python3.6/dist-packages (from prompt-toolkit<2.0.0,>=1.0.4->ipython->numpyro==0.2.4) (0.1.9)
Requirement already satisfied: ptyprocess>=0.5 in /usr/local/lib/python3.6/dist-packages (from pexpect; sys_platform != "win32"->ipython->numpyro==0.2.4) (0.6.0)
Installing collected packages: numpyro
  Found existing installation: numpyro 0.2.4
    Can't uninstall 'numpyro'. No files were found to uninstall.
  Running setup.py develop for numpyro
Successfully installed numpyro

Distributions


In [0]:
import numpyro
import numpyro.distributions as dist
from numpyro.diagnostics import hpdi
from numpyro.distributions.transforms import AffineTransform
from numpyro.infer import MCMC, NUTS, Predictive

rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)

1d Gaussian


In [87]:
# 2 independent 1d gaussians (ie 1 diagonal Gaussian)
mu = 1.5
sigma = 2
d = dist.Normal(mu, sigma)
dir(d)


Out[87]:
['__call__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_batch_shape',
 '_event_shape',
 '_validate_args',
 '_validate_sample',
 'arg_constraints',
 'batch_shape',
 'event_shape',
 'icdf',
 'loc',
 'log_prob',
 'mean',
 'reparametrized_params',
 'sample',
 'sample_with_intermediates',
 'scale',
 'set_default_validate_args',
 'support',
 'to_event',
 'transform_with_intermediates',
 'variance']

In [88]:
#rng_key, rng_key_ = random.split(rng_key)
nsamples = 1000
ys = d.sample(rng_key_, (nsamples,))
print(ys.shape)
mu_hat = np.mean(ys,0)
print(mu_hat)
sigma_hat = np.std(ys, 0)
print(sigma_hat)


(1000,)
1.4788736
2.0460527

Multivariate Gaussian


In [0]:
mu = np.array([-1, 1])
sigma = np.array([1, 2])
Sigma = np.diag(sigma)
d2 = dist.MultivariateNormal(mu, Sigma)

In [94]:
#rng_key, rng_key_ = random.split(rng_key)
nsamples = 1000
ys = d2.sample(rng_key_, (nsamples,))
print(ys.shape)
mu_hat = np.mean(ys,0)
print(mu_hat)
Sigma_hat = onp.cov(ys, rowvar=False) #jax.np.cov not implemented
print(Sigma_hat)


(1000, 2)
[-0.9644672   0.99415004]
[[0.93275181 0.0756547 ]
 [0.0756547  1.91598212]]

Shape semantics

Numpyro, Pyro and TFP all distinguish between 'event shape' and 'batch shape'. For a D-dimensional Gaussian, the event shape is (D,), and the batch shape will be (), meaning we have a single instance of this distribution. If the covariance is diagonal, we can view this as D independent 1d Gaussians, stored along the batch dimension; this will have event shape () but batch shape (2,).

When we sample from a distribution, we also specify the sample_shape. Suppose we draw N samples from a single D-dim diagonal Gaussian, and N samples from D 1d Gaussians. These samples will have the same shape. However, the semantics of logprob differs. We illustrate this below.


In [85]:
d2 = dist.MultivariateNormal(mu, Sigma)
print(d2.event_shape)
print(d2.batch_shape) 
nsamples = 1000
ys2 = d2.sample(rng_key_, (nsamples,))
print(ys2.shape)

# 2 independent 1d gaussians (same as one 2d diagonal Gaussian)
d3 = dist.Normal(mu, np.diag(Sigma))
print(d3.event_shape)
print(d3.batch_shape)
ys3 = d3.sample(rng_key_, (nsamples,))
print(ys3.shape)

print(np.allclose(ys2, ys3))


(2,)
()
(1000, 2)
()
(2,)
(1000, 2)
True

In [86]:
y = ys2[0,:] # 2 numbers
print(d2.log_prob(y)) # log prob of a single 2d distribution on 2d input 
print(d3.log_prob(y)) # log prob of two 1d distributions on 2d input


-2.1086864
[-1.1897303 -0.9189563]

We can turn a set of independent distributions into a single product distribution using the Independent class


In [98]:
d4 = dist.Independent(d3, 1) # treat the first batch dimension as an event dimensions
print(d4.event_shape)
print(d4.batch_shape)
print(d4.log_prob(y))


(2,)
()
-2.1086864

Posterior inference with MCMC

Example: 1d Gaussian with unknown mean.

We use the simple example from the Pyro intro. The goal is to infer the weight $\theta$ of an object, given noisy measurements $y$. We assume the following model: $$ \begin{align} \theta &\sim N(\mu=8.5, \tau^2=1.0)\\ y \sim &N(\theta, \sigma^2=0.75^2) \end{align} $$

Where $\mu=8.5$ is the initial guess.

By Bayes rule for Gaussians, we know that the exact posterior, given a single observation $y=9.5$, is given by

$$ \begin{align} \theta|y &\sim N(m, s^s) \\ m &=\frac{\sigma^2 \mu + \tau^2 y}{\sigma^2 + \tau^2} = \frac{0.75^2 \times 8.5 + 1 \times 9.5}{0.75^2 + 1^2} = 9.14 \\ s^2 &= \frac{\sigma^2 \tau^2}{\sigma^2 + \tau^2} = \frac{0.75^2 \times 1^2}{0.75^2 + 1^2}= 0.6^2 \end{align} $$

In [108]:
mu = 8.5; tau = 1.0; sigma = 0.75; y = 9.5
m = (sigma**2 * mu + tau**2 * y)/(sigma**2 + tau**2)
s2 = (sigma**2 * tau**2)/(sigma**2 + tau**2)
s = np.sqrt(s2)
print(m)
print(s)


9.14
0.6

In [106]:
def model(guess, measurement=None):
    weight = numpyro.sample("weight", dist.Normal(guess, tau))
    return numpyro.sample("measurement", dist.Normal(weight, sigma), obs=measurement)


9.14

In [109]:
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_warmup=100, num_samples=1000)
guess = mu
measurement = y
mcmc.run(rng_key_, guess, measurement=measurement)

mcmc.print_summary()
samples  = mcmc.get_samples()


sample: 100%|██████████| 1100/1100 [00:04<00:00, 229.85it/s, 1 steps of size 8.84e-01. acc. prob=0.96]
                mean       std    median      5.0%     95.0%     n_eff     r_hat
    weight      9.09      0.62      9.11      8.05     10.03    325.40      1.01

Number of divergences: 0


In [0]: