In this notebook, we fit a hierarchical Bayesian model to the "8 schools" dataset (estimated effects of SAT coaching programs in eight schools, from Rubin, 1981). See also https://github.com/probml/pyprobml/blob/master/scripts/schools8_pymc3.py
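
Concretely, each school $j$ reports an estimated treatment effect $y_j$ with a known standard error $\sigma_j$, and the per-school effects $\alpha_j$ are shrunk toward a common mean by a Gaussian prior:

$$y_j \sim \mathcal{N}(\alpha_j, \sigma_j^2), \qquad \alpha_j \sim \mathcal{N}(\mu_\alpha, \sigma_\alpha^2), \qquad j = 1, \ldots, 8,$$

with hyperpriors $\mu_\alpha \sim \mathcal{N}(0, 5^2)$ and $\sigma_\alpha \sim \mathrm{HalfCauchy}(5)$, matching the code below.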


In [2]:
%matplotlib inline
import sklearn
import scipy.stats as stats
import scipy.optimize
import matplotlib.pyplot as plt
import seaborn as sns
import time
import numpy as np
import os
import pandas as pd



In [3]:
!pip install pymc3==3.8
import pymc3 as pm
print(pm.__version__)
import theano.tensor as tt
import theano

!pip install arviz
import arviz as az


Collecting pymc3==3.8
Collecting arviz>=0.4.1
Collecting netcdf4
Collecting cftime
Installing collected packages: cftime, netcdf4, arviz, pymc3
  Found existing installation: pymc3 3.7
    Uninstalling pymc3-3.7:
      Successfully uninstalled pymc3-3.7
Successfully installed arviz-0.7.0 cftime-1.1.3 netcdf4-1.5.3 pymc3-3.8
3.8
Requirement already satisfied: arviz in /usr/local/lib/python3.6/dist-packages (0.7.0)

In [22]:
# https://github.com/probml/pyprobml/blob/master/scripts/schools8_pymc3.py

# Data for the eight schools model
J = 8
y = np.array([28., 8., -3., 7., -1., 1., 18., 12.])
sigma = np.array([15., 10., 16., 11., 9., 11., 10., 18.])
print(np.mean(y))
print(np.median(y))

names = ['theta {}'.format(t) for t in range(J)]

# Plot the raw data: each school's estimate +/- one standard error
fig, ax = plt.subplots()
y_pos = np.arange(J)
ax.errorbar(y, y_pos, xerr=sigma, fmt='o')
ax.set_yticks(y_pos)
ax.set_yticklabels(names)
ax.invert_yaxis()  # labels read top-to-bottom
plt.show()


8.75
7.5
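
As a point of reference (an added aside, not in the original script): under complete pooling, where all schools are assumed to share one effect, the maximum-likelihood estimate is the precision-weighted mean of the y_j. The hierarchical posterior means should fall between this and the raw per-school estimates:

w = 1.0 / sigma**2                     # precision of each school's estimate
mu_pooled = np.sum(w * y) / np.sum(w)  # precision-weighted (fixed-effect) mean
print(mu_pooled)                       # ~7.69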

In [6]:
# Centered parameterization
with pm.Model() as Centered_eight:
    mu_alpha = pm.Normal('mu_alpha', mu=0, sigma=5)
    sigma_alpha = pm.HalfCauchy('sigma_alpha', beta=5)
    alpha = pm.Normal('alpha', mu=mu_alpha, sigma=sigma_alpha, shape=J)
    obs = pm.Normal('obs', mu=alpha, sigma=sigma, observed=y)
    log_sigma_alpha = pm.Deterministic('log_sigma_alpha', tt.log(sigma_alpha))

np.random.seed(0)
with Centered_eight:
    trace_centered = pm.sample(10000, chains=4)

pm.summary(trace_centered).round(2)
# PyMC3 gives multiple warnings about divergences.
# Also note r_hat ~ 1.01 and ESS << nchains*nsamples, especially for sigma_alpha.
# We solve these problems below with a non-centered parameterization.
# In practice, for this model, the posterior estimates are very similar.


Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Sequential sampling (4 chains in 1 job)
NUTS: [alpha, sigma_alpha, mu_alpha]
Sampling chain 0, 197 divergences: 100%|██████████| 10500/10500 [00:18<00:00, 579.22it/s]
Sampling chain 1, 175 divergences: 100%|██████████| 10500/10500 [00:15<00:00, 661.88it/s]
Sampling chain 2, 223 divergences: 100%|██████████| 10500/10500 [00:18<00:00, 569.41it/s]
Sampling chain 3, 157 divergences: 100%|██████████| 10500/10500 [00:15<00:00, 660.66it/s]
There were 197 divergences after tuning. Increase `target_accept` or reparameterize.
There were 372 divergences after tuning. Increase `target_accept` or reparameterize.
There were 595 divergences after tuning. Increase `target_accept` or reparameterize.
There were 752 divergences after tuning. Increase `target_accept` or reparameterize.
The number of effective samples is smaller than 10% for some parameters.
Out[6]:
                  mean    sd  hpd_3%  hpd_97%  mcse_mean  mcse_sd  ess_mean  ess_sd  ess_bulk  ess_tail  r_hat
mu_alpha          4.38  3.39   -2.16    10.40       0.07     0.05    2550.0  2550.0    2565.0    2486.0   1.00
alpha[0]          6.41  5.78   -3.99    17.47       0.09     0.06    4178.0  4178.0    3511.0    3751.0   1.00
alpha[1]          5.02  4.91   -4.14    14.28       0.07     0.05    4643.0  4643.0    4229.0    4345.0   1.00
alpha[2]          3.88  5.46   -6.73    13.76       0.07     0.05    5705.0  5705.0    4898.0   10517.0   1.00
alpha[3]          4.76  4.97   -4.53    14.15       0.06     0.05    5788.0  5788.0    5229.0   15616.0   1.00
alpha[4]          3.51  4.82   -5.53    12.59       0.07     0.05    4813.0  4813.0    4422.0    5995.0   1.00
alpha[5]          3.94  5.02   -5.72    13.27       0.07     0.05    5001.0  5001.0    4573.0    8022.0   1.00
alpha[6]          6.53  5.31   -2.65    17.25       0.08     0.06    3958.0  3958.0    3452.0    1117.0   1.00
alpha[7]          4.89  5.57   -5.51    15.64       0.07     0.05    6505.0  6505.0    5318.0   15296.0   1.00
sigma_alpha       4.05  3.10    0.58     9.52       0.07     0.05    2017.0  2017.0     815.0     419.0   1.01
log_sigma_alpha   1.14  0.73   -0.28     2.34       0.02     0.02     946.0   946.0     815.0     419.0   1.01
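
As the warnings themselves suggest, one mitigation short of reparameterizing is to raise target_accept, which makes NUTS take smaller steps. A minimal sketch (not run here; in practice this tends to reduce, but rarely eliminate, the divergences for this model):

with Centered_eight:
    trace_centered_ta = pm.sample(10000, chains=4, target_accept=0.99)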

In [7]:
# Display the total number and percentage of divergent samples
diverging = trace_centered['diverging']
n_divergent = diverging.nonzero()[0].size
print('Number of Divergent Samples: {}'.format(n_divergent))
# diverging has one entry per kept sample across all chains, so
# normalize by its total length (nchains * nsamples), not len(trace).
diverging_pct = n_divergent / diverging.size * 100
print('Percentage of Divergent Samples: {:.1f}'.format(diverging_pct))


Number of Divergent Samples: 752
Percentage of Divergent Samples: 1.9
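
It is also instructive to see where in parameter space the divergences occur; a sketch using ArviZ's pair plot (arviz 0.7 supports a divergences flag that highlights divergent draws). We would expect them to cluster at small sigma_alpha, the neck of the "funnel" plotted later:

az.plot_pair(trace_centered, var_names=['mu_alpha', 'sigma_alpha'], divergences=True);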

In [8]:
# Note the fairly high autocorrelation of the samples from the centered model
az.plot_autocorr(trace_centered, var_names=['mu_alpha', 'sigma_alpha']);



In [9]:
az.plot_forest(trace_centered, var_names="alpha", 
               credible_interval=0.95, combined=True);



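The non-centered version below samples standardized offsets and reconstructs alpha deterministically:

$$\tilde{\alpha}_j \sim \mathcal{N}(0, 1), \qquad \alpha_j = \mu_\alpha + \sigma_\alpha \tilde{\alpha}_j.$$

This defines exactly the same prior on $\alpha_j$, but the variables the sampler actually explores (mu_alpha, sigma_alpha, alpha_offset) are a priori independent, which removes the funnel-shaped dependence between alpha and sigma_alpha.
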
In [11]:
# Non-centered parameterization

with pm.Model() as NonCentered_eight:
    mu_alpha = pm.Normal('mu_alpha', mu=0, sigma=5)
    sigma_alpha = pm.HalfCauchy('sigma_alpha', beta=5)
    alpha_offset = pm.Normal('alpha_offset', mu=0, sigma=1, shape=J)
    alpha = pm.Deterministic('alpha', mu_alpha + sigma_alpha * alpha_offset)
    #alpha = pm.Normal('alpha', mu=mu_alpha, sigma=sigma_alpha, shape=J)  # the centered version this replaces
    obs = pm.Normal('obs', mu=alpha, sigma=sigma, observed=y)
    log_sigma_alpha = pm.Deterministic('log_sigma_alpha', tt.log(sigma_alpha))
    
np.random.seed(0)
with NonCentered_eight:
    trace_noncentered = pm.sample(10000, chains=4)
    
pm.summary(trace_noncentered).round(2)
# Samples look good: r_hat = 1, ESS ~= nchains*nsamples


Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Sequential sampling (4 chains in 1 job)
NUTS: [alpha_offset, sigma_alpha, mu_alpha]
Sampling chain 0, 6 divergences: 100%|██████████| 10500/10500 [00:10<00:00, 1020.17it/s]
Sampling chain 1, 5 divergences: 100%|██████████| 10500/10500 [00:09<00:00, 1078.90it/s]
Sampling chain 2, 12 divergences: 100%|██████████| 10500/10500 [00:09<00:00, 1072.70it/s]
Sampling chain 3, 3 divergences: 100%|██████████| 10500/10500 [00:09<00:00, 1106.74it/s]
There were 6 divergences after tuning. Increase `target_accept` or reparameterize.
There were 11 divergences after tuning. Increase `target_accept` or reparameterize.
There were 23 divergences after tuning. Increase `target_accept` or reparameterize.
There were 26 divergences after tuning. Increase `target_accept` or reparameterize.
Out[11]:
                  mean    sd  hpd_3%  hpd_97%  mcse_mean  mcse_sd  ess_mean   ess_sd  ess_bulk  ess_tail  r_hat
mu_alpha          4.39  3.30   -1.80    10.65       0.02     0.01   39155.0  35768.0   39478.0   23211.0    1.0
alpha_offset[0]   0.32  0.99   -1.51     2.21       0.00     0.00   47855.0  18067.0   47787.0   26722.0    1.0
alpha_offset[1]   0.10  0.94   -1.67     1.89       0.00     0.00   50056.0  17673.0   50105.0   28230.0    1.0
alpha_offset[2]  -0.08  0.96   -1.96     1.67       0.00     0.00   53705.0  16047.0   53602.0   28294.0    1.0
alpha_offset[3]   0.06  0.94   -1.72     1.81       0.00     0.00   52013.0  18101.0   52042.0   29063.0    1.0
alpha_offset[4]  -0.16  0.92   -1.86     1.64       0.00     0.00   46654.0  18725.0   46660.0   28955.0    1.0
alpha_offset[5]  -0.07  0.95   -1.82     1.75       0.00     0.00   48738.0  16644.0   48750.0   27425.0    1.0
alpha_offset[6]   0.36  0.96   -1.46     2.17       0.00     0.00   45378.0  19489.0   45415.0   29028.0    1.0
alpha_offset[7]   0.08  0.98   -1.79     1.87       0.00     0.00   47547.0  17858.0   47510.0   29225.0    1.0
sigma_alpha       3.62  3.18    0.00     9.33       0.02     0.02   24727.0  20164.0   22515.0   17879.0    1.0
alpha[0]          6.26  5.62   -3.71    17.36       0.03     0.02   33406.0  24960.0   37825.0   27209.0    1.0
alpha[1]          4.96  4.69   -3.98    13.92       0.02     0.02   44581.0  32659.0   46400.0   31381.0    1.0
alpha[2]          3.96  5.24   -6.29    13.61       0.03     0.02   38935.0  29705.0   42101.0   29244.0    1.0
alpha[3]          4.77  4.79   -4.25    14.00       0.02     0.02   42894.0  31321.0   44819.0   29789.0    1.0
alpha[4]          3.56  4.66   -5.29    12.25       0.02     0.02   40816.0  30764.0   42631.0   29189.0    1.0
alpha[5]          4.03  4.86   -5.45    13.02       0.02     0.02   40275.0  29366.0   42168.0   27783.0    1.0
alpha[6]          6.33  5.10   -2.96    16.24       0.03     0.02   37531.0  28689.0   39906.0   30022.0    1.0
alpha[7]          4.87  5.33   -5.52    14.80       0.03     0.02   37795.0  27133.0   40186.0   29823.0    1.0
log_sigma_alpha   0.82  1.15   -1.33     2.67       0.01     0.01   18933.0  17559.0   22515.0   17879.0    1.0

In [12]:
az.plot_autocorr(trace_noncentered, var_names=['mu_alpha', 'sigma_alpha']);



In [13]:
az.plot_forest(trace_noncentered, var_names="alpha",
               combined=True, credible_interval=0.95);



In [23]:
az.plot_forest([trace_centered, trace_noncentered], model_names=['centered', 'noncentered'],
               var_names="alpha",
               combined=True, credible_interval=0.95);
plt.axvline(np.mean(y), color='k', linestyle='--')  # sample mean of the raw y values


Out[23]:
<matplotlib.lines.Line2D at 0x7f4b31ade748>

In [20]:
az.plot_forest([trace_centered, trace_noncentered], model_names=['centered', 'noncentered'],
               var_names="alpha", kind='ridgeplot',
               combined=True, credible_interval=0.95);



In [24]:
fig, axs = plt.subplots(ncols=2, sharex=True, sharey=True)
x = pd.Series(trace_centered['mu_alpha'], name='mu_alpha')
y  = pd.Series(trace_centered['log_sigma_alpha'], name='log_sigma_alpha')
axs[0].plot(x, y, '.');
axs[0].set(title='Centered', xlabel='µ', ylabel='log(sigma)');
#axs[0].axhline(0.01)

x = pd.Series(trace_noncentered['mu_alpha'], name='mu')
y  = pd.Series(trace_noncentered['log_sigma_alpha'], name='log_sigma_alpha')
axs[1].plot(x, y, '.');
axs[1].set(title='NonCentered', xlabel='µ', ylabel='log(sigma)');
#axs[1].axhline(0.01)

xlim = axs[0].get_xlim()
ylim = axs[0].get_ylim()



In [16]:
# Plot the "funnel of hell"
# Based on
# https://github.com/twiecki/WhileMyMCMCGentlySamples/blob/master/content/downloads/notebooks/GLM_hierarchical_non_centered.ipynb

x = pd.Series(trace_centered['mu_alpha'], name='mu')
y = pd.Series(trace_centered['log_sigma_alpha'], name='log sigma_alpha')
sns.jointplot(x, y, xlim=xlim, ylim=ylim);
plt.suptitle('centered')

x = pd.Series(trace_noncentered['mu_alpha'], name='mu')
y = pd.Series(trace_noncentered['log_sigma_alpha'], name='log sigma_alpha')
sns.jointplot(x, y, xlim=xlim, ylim=ylim);
plt.suptitle('noncentered')


Out[16]:
Text(0.5, 0.98, 'noncentered')

In [17]:
group = 0
fig, axs = plt.subplots(ncols=2, sharex=True, sharey=True)
x = pd.Series(trace_centered['alpha'][:, group], name=f'alpha {group}')
y  = pd.Series(trace_centered['log_sigma_alpha'], name='log_sigma_alpha')
axs[0].plot(x, y, '.');
axs[0].set(title='Centered', xlabel='µ', ylabel='log(sigma)');

x = pd.Series(trace_noncentered['alpha'][:,group], name=f'alpha {group}')
y  = pd.Series(trace_noncentered['log_sigma_alpha'], name='log_sigma_alpha')
axs[1].plot(x, y, '.');
axs[1].set(title='NonCentered', xlabel='µ', ylabel='log(sigma)');

xlim = axs[0].get_xlim()
ylim = axs[0].get_ylim()
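
The saved limits can be reused just as in the earlier jointplot cell; a sketch of the analogous per-group comparison (an assumed continuation, not a cell from the original notebook):

x = pd.Series(trace_centered['alpha'][:, group], name=f'alpha {group}')
y = pd.Series(trace_centered['log_sigma_alpha'], name='log sigma_alpha')
sns.jointplot(x, y, xlim=xlim, ylim=ylim);
plt.suptitle('centered')

x = pd.Series(trace_noncentered['alpha'][:, group], name=f'alpha {group}')
y = pd.Series(trace_noncentered['log_sigma_alpha'], name='log sigma_alpha')
sns.jointplot(x, y, xlim=xlim, ylim=ylim);
plt.suptitle('noncentered')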