In this notebook, we fit a hierarchical Bayesian model to the "8 schools" dataset with PyMC3, and compare a centered with a non-centered parameterization of the school effects. See also https://github.com/probml/pyprobml/blob/master/scripts/schools8_pymc3.py
In [2]:
%matplotlib inline
import sklearn
import scipy.stats as stats
import scipy.optimize
import matplotlib.pyplot as plt
import seaborn as sns
import time
import numpy as np
import os
import pandas as pd
In [3]:
!pip install pymc3==3.8
import pymc3 as pm
pm.__version__
import theano.tensor as tt
import theano
!pip install arviz
import arviz as az
In [22]:
# https://github.com/probml/pyprobml/blob/master/scripts/schools8_pymc3.py
# Data of the Eight Schools Model: one observed value y[j] per school together with
# its known standard deviation sigma[j] (used as the fixed likelihood scale below).
J = 8
y = np.array([28., 8., -3., 7., -1., 1., 18., 12.])
sigma = np.array([15., 10., 16., 11., 9., 11., 10., 18.])
# Guard against editing one array without the other (or without updating J).
assert y.shape == sigma.shape == (J,)
print(np.mean(y))
print(np.median(y))
# One label per school, derived from J so it always matches the data arrays.
names = ['theta {}'.format(t) for t in range(J)]
# Plot raw data: each school's observed y with a symmetric +/- sigma error bar.
fig, ax = plt.subplots()
# One row per school; use J rather than a hard-coded 8 so rows track the data.
y_pos = np.arange(J)
ax.errorbar(y, y_pos, xerr=sigma, fmt='o')
ax.set_yticks(y_pos)
ax.set_yticklabels(names)
ax.invert_yaxis()  # labels read top-to-bottom
ax.set(title='Eight schools: raw data', xlabel='observed y (error bars: +/- sigma)')
plt.show()
In [6]:
# Centered model: each school's effect alpha[j] gets a Normal(mu_alpha, sigma_alpha)
# prior directly, so alpha and its scale sigma_alpha are sampled jointly (this is the
# geometry shown in the "funnel" plots further down).
with pm.Model() as Centered_eight:
    # Population-level priors on the mean and scale of the school effects.
    mu_alpha = pm.Normal('mu_alpha', mu=0, sigma=5)
    sigma_alpha = pm.HalfCauchy('sigma_alpha', beta=5)
    # One effect per school (shape=J).
    alpha = pm.Normal('alpha', mu=mu_alpha, sigma=sigma_alpha, shape=J)
    # Likelihood: observed y with the known, fixed per-school standard deviations sigma.
    obs = pm.Normal('obs', mu=alpha, sigma=sigma, observed=y)
    # Recorded in the trace so log(sigma_alpha) can be plotted directly.
    log_sigma_alpha = pm.Deterministic('log_sigma_alpha', tt.log(sigma_alpha))
# Sample the centered model: 10000 draws per chain x 4 chains = 40000 pooled draws.
# NumPy's global RNG is seeded first (intended to make the run repeatable).
np.random.seed(0)
with Centered_eight:
    trace_centered = pm.sample(10000, chains=4)
pm.summary(trace_centered).round(2)
# PyMC3 gives multiple warnings about divergences.
# Also, r_hat ~ 1.01 and ESS is far below the pooled draw count
# (4 chains * 10000 draws = 40000), especially for sigma_alpha.
# We can solve these problems below by using a non-centered parameterization.
# In practice, for this model, the results are very similar.
Out[6]:
In [7]:
# Display the total number and percentage of divergent transitions in the centered run.
# trace['diverging'] holds one boolean per draw, pooled over all chains.
diverging = trace_centered['diverging']
n_diverging = np.count_nonzero(diverging)
print('Number of Divergent Transitions: {}'.format(n_diverging))
# Normalise by the pooled number of draws (diverging.size). len(trace_centered) is the
# number of draws in ONE chain, which would inflate the percentage by the number of chains.
diverging_pct = n_diverging / diverging.size * 100
print('Percentage of Divergent Transitions: {:.1f}'.format(diverging_pct))
In [8]:
# Autocorrelation of the hyperparameter draws in the centered model: the curves
# decay fairly slowly, i.e. successive samples remain noticeably correlated.
az.plot_autocorr(data=trace_centered, var_names=['mu_alpha', 'sigma_alpha']);
In [9]:
# 95% credible intervals of the per-school effects alpha (chains combined), centered model.
az.plot_forest(
    data=trace_centered,
    var_names="alpha",
    combined=True,
    credible_interval=0.95,
);
In [11]:
# Non-centered parameterization: sample standardized offsets alpha_offset ~ Normal(0, 1)
# and build alpha deterministically as mu_alpha + sigma_alpha * alpha_offset, so the
# sampled variables no longer include alpha with a prior scale that is itself sampled.
with pm.Model() as NonCentered_eight:
    # Same population-level priors as the centered model.
    mu_alpha = pm.Normal('mu_alpha', mu=0, sigma=5)
    sigma_alpha = pm.HalfCauchy('sigma_alpha', beta=5)
    # One standard-normal offset per school (shape=J).
    alpha_offset = pm.Normal('alpha_offset', mu=0, sigma=1, shape=J)
    # Shift-and-scale gives alpha the same Normal(mu_alpha, sigma_alpha) prior as the
    # centered model; Deterministic keeps 'alpha' in the trace for comparison plots.
    alpha = pm.Deterministic('alpha', mu_alpha + sigma_alpha * alpha_offset)
    # Likelihood: observed y with the known, fixed per-school standard deviations sigma.
    obs = pm.Normal('obs', mu=alpha, sigma=sigma, observed=y)
    # Recorded in the trace so log(sigma_alpha) can be plotted directly.
    log_sigma_alpha = pm.Deterministic('log_sigma_alpha', tt.log(sigma_alpha))
# Sample the non-centered model with the same settings as the centered run:
# seed 0, 10000 draws per chain x 4 chains = 40000 pooled draws.
np.random.seed(0)
with NonCentered_eight:
    trace_noncentered = pm.sample(10000, chains=4)
pm.summary(trace_noncentered).round(2)
# Samples look good: r_hat = 1. Judge ESS against the pooled draw count,
# 4 chains * 10000 draws = 40000.
Out[11]:
In [12]:
# Autocorrelation of the hyperparameter draws in the non-centered model.
az.plot_autocorr(data=trace_noncentered, var_names=['mu_alpha', 'sigma_alpha']);
In [13]:
# 95% credible intervals of the per-school effects alpha (chains combined), non-centered model.
az.plot_forest(
    data=trace_noncentered,
    var_names="alpha",
    credible_interval=0.95,
    combined=True,
);
In [23]:
# Compare the alpha intervals from both parameterizations side by side. A dashed line
# marks the unweighted mean of the observed y as a reference.
az.plot_forest([trace_centered, trace_noncentered], model_names=['centered', 'noncentered'],
               var_names="alpha",
               combined=True, credible_interval=0.95)
# Trailing ';' suppresses the Line2D repr that would otherwise be echoed as cell output.
plt.axvline(np.mean(y), color='k', linestyle='--');
Out[23]:
In [20]:
# The same centered vs non-centered comparison of alpha, drawn as a ridgeplot.
az.plot_forest(
    data=[trace_centered, trace_noncentered],
    kind='ridgeplot',
    model_names=['centered', 'noncentered'],
    var_names="alpha",
    combined=True,
    credible_interval=0.95,
);
In [24]:
# Scatter the posterior draws of mu_alpha against log(sigma_alpha) for both
# parameterizations. Plot straight from the traces instead of assigning to x / y:
# reusing the name `y` would clobber the observed data array that the models
# (observed=y) and the forest-plot reference line read.
fig, axs = plt.subplots(ncols=2, sharex=True, sharey=True)
for panel_ax, panel_trace, panel_title in zip(
        axs, (trace_centered, trace_noncentered), ('Centered', 'NonCentered')):
    panel_ax.plot(panel_trace['mu_alpha'], panel_trace['log_sigma_alpha'], '.')
    panel_ax.set(title=panel_title, xlabel='µ', ylabel='log(sigma)')
# Shared axis limits, kept so the following funnel plots use the same ranges.
xlim = axs[0].get_xlim()
ylim = axs[0].get_ylim()
In [16]:
# Plot the "funnel of hell": joint and marginal distributions of mu_alpha and
# log(sigma_alpha) for each parameterization, using the xlim / ylim computed by the
# previous scatter cell so both figures share the same scale.
# Based on
# https://github.com/twiecki/WhileMyMCMCGentlySamples/blob/master/content/downloads/notebooks/GLM_hierarchical_non_centered.ipynb
for funnel_trace, funnel_title in ((trace_centered, 'centered'),
                                   (trace_noncentered, 'noncentered')):
    # Named Series so seaborn uses the names as axis labels; dedicated variable
    # names avoid overwriting the observed data `y`.
    mu_draws = pd.Series(funnel_trace['mu_alpha'], name='mu')
    log_sigma_draws = pd.Series(funnel_trace['log_sigma_alpha'], name='log sigma_alpha')
    # x= / y= as keywords: positional data variables are rejected by newer seaborn.
    sns.jointplot(x=mu_draws, y=log_sigma_draws, xlim=xlim, ylim=ylim)
    plt.suptitle(funnel_title)
Out[16]:
In [17]:
# Scatter the draws of a single school's effect alpha[group] against log(sigma_alpha)
# for both parameterizations.
group = 0  # index of the school to inspect
fig, axs = plt.subplots(ncols=2, sharex=True, sharey=True)
for panel_ax, panel_trace, panel_title in zip(
        axs, (trace_centered, trace_noncentered), ('Centered', 'NonCentered')):
    # trace['alpha'] is indexed [draw, school]; take the chosen school's column.
    # Plotting directly (not via x / y) keeps the observed data `y` intact.
    panel_ax.plot(panel_trace['alpha'][:, group], panel_trace['log_sigma_alpha'], '.')
    # The x axis shows alpha[group], not µ.
    panel_ax.set(title=panel_title, xlabel=f'alpha {group}', ylabel='log(sigma)')
xlim = axs[0].get_xlim()
ylim = axs[0].get_ylim()