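In this notebook, we explore the "funnel of hell". This refers to a posterior in which the values of a variable and the scale (variance) of its distribution are strongly dependent, so that their joint posterior has a funnel shape. When the scale is small, the variable is confined to a narrow neck; when the scale is large, the variable can spread widely. (The term "funnel of hell" comes from a blog post by Thomas Wiecki.)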
We illustrate this using a hierarchical Bayesian model for inferring Gaussian means, fit to synthetic data. The setup is similar to the eight schools example, except that we vary the sample size of each group and fix the (known) observation variance. This code is based on a notebook by Justin Bois.
In [33]:
%matplotlib inline
import sklearn
import scipy.stats as stats
import scipy.optimize
import matplotlib.pyplot as plt
import seaborn as sns
import time
import numpy as np
import os
import pandas as pd
In [26]:
!pip install pymc3==3.8
import pymc3 as pm
pm.__version__
# The arviz package (https://github.com/arviz-devs/arviz) can be used to make various plots
# of posterior samples generated by any algorithm.
!pip install arviz
import arviz as az
In [0]:
import math
import pickle
import numpy as np
import pandas as pd
import scipy.stats as st
import theano.tensor as tt
import theano
In [55]:
# Parameters of the synthetic data-generating process.
np.random.seed(0)  # fix the global RNG so group sizes and data are reproducible
# population mean of group means, spread of group means, observation noise sd
mu_val, tau_val, sigma_val = 8, 3, 10
n_groups = 10
# Observations per group: integers drawn uniformly from {3, ..., 9}.
n = np.random.randint(low=3, high=10, size=n_groups, dtype=int)
print(n)
print(n.sum())
In [57]:
# Generate the synthetic data set: for each group draw its true mean from
# N(mu_val, tau_val), then draw that group's n[i] observations from
# N(mus[i], sigma_val). `x` holds all observations back to back, and
# `group_ind[k]` is the group index of observation `x[k]`.
mus = np.zeros(n_groups)
chunks = []  # per-group samples, concatenated once (np.append in a loop is quadratic)
for i in range(n_groups):
    mus[i] = np.random.normal(mu_val, tau_val)
    chunks.append(np.random.normal(mus[i], sigma_val, size=n[i]))
x = np.concatenate(chunks) if chunks else np.array([])
print(x.shape)
# Group label per observation, e.g. n = [2, 3] -> [0, 0, 1, 1, 1].
group_ind = np.repeat(np.arange(n_groups), n)
In [58]:
# Centered parameterization: the group means theta are sampled directly from
# N(mu, tau), so their posterior is coupled to tau.
with pm.Model() as centered_model:
    # Hyperpriors
    mu = pm.Normal('mu', mu=0, sigma=5)
    tau = pm.HalfCauchy('tau', beta=2.5)
    # Track log(tau) so it can be plotted against mu later.
    log_tau = pm.Deterministic('log_tau', tt.log(tau))
    # Prior on theta
    theta = pm.Normal('theta', mu=mu, sigma=tau, shape=n_groups)
    # Likelihood: each observation uses its group's mean; the noise sd is fixed.
    x_obs = pm.Normal('x_obs',
                      mu=theta[group_ind],
                      sigma=sigma_val,
                      observed=x)

# NOTE(review): 10000 draws per chain here vs 1000 in the non-centered cell, so
# the comparison plots contain 10x more centered points than non-centered ones.
with centered_model:
    centered_trace = pm.sample(10000, chains=2, random_seed=0)
pm.summary(centered_trace).round(2)
Out[58]:
In [59]:
# Non-centered parameterization: sample standard-normal offsets var_theta and
# set theta = mu + var_theta * tau, so the sampled offsets are a priori
# independent of tau.
with pm.Model() as noncentered_model:
    # Hyperpriors
    mu = pm.Normal('mu', mu=0, sigma=5)
    tau = pm.HalfCauchy('tau', beta=2.5)
    # Track log(tau) so it can be plotted against mu later.
    log_tau = pm.Deterministic('log_tau', tt.log(tau))
    # Prior on theta, expressed through the standardized offsets
    var_theta = pm.Normal('var_theta', mu=0, sigma=1, shape=n_groups)
    theta = pm.Deterministic('theta', mu + var_theta * tau)
    # Likelihood: each observation uses its group's mean; the noise sd is fixed.
    x_obs = pm.Normal('x_obs',
                      mu=theta[group_ind],
                      sigma=sigma_val,
                      observed=x)

with noncentered_model:
    noncentered_trace = pm.sample(1000, chains=2, random_seed=0)
pm.summary(noncentered_trace).round(2)
Out[59]:
In [69]:
# Side-by-side scatter of posterior draws of (mu, tau) for both parameterizations.
# The draws are read straight from the traces instead of being assigned to `x`,
# which would overwrite the observed data used by the model cells.
fig, axs = plt.subplots(ncols=2, sharex=True, sharey=True)
panels = [('Centered', centered_trace), ('NonCentered', noncentered_trace)]
for ax, (title, trace) in zip(axs, panels):
    ax.plot(trace['mu'], trace['tau'], '.')
    ax.set(title=title, xlabel='µ', ylabel='τ')
    # Reference line at tau = 0.01; presumably marks the neck of the funnel
    # (small-tau region) - TODO confirm intent.
    ax.axhline(0.01)
# Axis limits shared by both panels; reused by the tau jointplots that follow.
xlim = axs[0].get_xlim()
ylim = axs[0].get_ylim()
In [71]:
# Joint and marginal posterior of (mu, tau) for the centered model, using the
# axis limits taken from the side-by-side scatter plot.
# Dedicated names keep the observed data `x` intact.
mu_draws = pd.Series(centered_trace['mu'], name='mu')
tau_draws = pd.Series(centered_trace['tau'], name='tau')
# seaborn >= 0.12 only accepts x/y as keyword arguments.
g = sns.jointplot(x=mu_draws, y=tau_draws, xlim=xlim, ylim=ylim)
plt.suptitle('centered')
plt.show()
In [70]:
# Joint and marginal posterior of (mu, tau) for the non-centered model, using
# the axis limits taken from the side-by-side scatter plot.
# Dedicated names keep the observed data `x` intact.
mu_draws = pd.Series(noncentered_trace['mu'], name='mu')
tau_draws = pd.Series(noncentered_trace['tau'], name='tau')
# seaborn >= 0.12 only accepts x/y as keyword arguments.
g = sns.jointplot(x=mu_draws, y=tau_draws, xlim=xlim, ylim=ylim)
plt.suptitle('noncentered')
plt.show()
In [66]:
# Side-by-side scatter of posterior draws of (mu, log tau) for both
# parameterizations. The draws are read straight from the traces instead of
# being assigned to `x`, which would overwrite the observed data.
fig, axs = plt.subplots(ncols=2, sharex=True, sharey=True)
panels = [('Centered', centered_trace), ('NonCentered', noncentered_trace)]
for ax, (title, trace) in zip(axs, panels):
    ax.plot(trace['mu'], trace['log_tau'], '.')
    ax.set(title=title, xlabel='µ', ylabel='log(τ)')
# Axis limits on the log scale; reused by the log-tau jointplots that follow.
xlim = axs[0].get_xlim()
ylim = axs[0].get_ylim()
In [67]:
#https://seaborn.pydata.org/generated/seaborn.jointplot.html
# Joint and marginal posterior of (mu, log tau) for the centered model, using
# the log-scale axis limits from the preceding scatter plot.
# Dedicated names keep the observed data `x` intact.
mu_draws = pd.Series(centered_trace['mu'], name='mu')
log_tau_draws = pd.Series(centered_trace['log_tau'], name='log_tau')
# seaborn >= 0.12 only accepts x/y as keyword arguments.
g = sns.jointplot(x=mu_draws, y=log_tau_draws, xlim=xlim, ylim=ylim)
plt.suptitle('centered')
plt.show()
In [68]:
# Joint and marginal posterior of (mu, log tau) for the non-centered model,
# using the log-scale axis limits from the preceding scatter plot.
# Dedicated names keep the observed data `x` intact.
mu_draws = pd.Series(noncentered_trace['mu'], name='mu')
log_tau_draws = pd.Series(noncentered_trace['log_tau'], name='log_tau')
# seaborn >= 0.12 only accepts x/y as keyword arguments.
g = sns.jointplot(x=mu_draws, y=log_tau_draws, xlim=xlim, ylim=ylim)
plt.suptitle('noncentered')
plt.show()
In [72]:
# Forest plot of 95% intervals for each theta under both parameterizations.
import inspect

# ArviZ >= 0.8 calls this argument `hdi_prob`; older releases use
# `credible_interval`. Pass whichever the installed version accepts.
if 'hdi_prob' in inspect.signature(az.plot_forest).parameters:
    interval_kw = {'hdi_prob': 0.95}
else:
    interval_kw = {'credible_interval': 0.95}
az.plot_forest([centered_trace, noncentered_trace], model_names=['centered', 'noncentered'],
               var_names="theta",
               combined=True, **interval_kw);