Examples for learning the PyMC library, starting with a Bayesian switchpoint model for yearly disaster counts.
In [1]:
import pymc
import numpy as np
import matplotlib.pyplot as plt
import seaborn
% matplotlib inline
seaborn.set_style('darkgrid')
In [8]:
from pymc import DiscreteUniform, Exponential, deterministic, Poisson, Uniform, MCMC
In [28]:
# Observed disaster counts, one integer per year (111 values), matching the
# 1851-1961 year range used for plotting further down.
disasters_array = np.asarray((
    4, 5, 4, 0, 1, 4, 3, 4, 0, 6, 3, 3, 4, 0, 2, 6,
    3, 3, 5, 4, 5, 3, 1, 4, 4, 1, 5, 5, 3, 4, 2, 5,
    2, 2, 3, 4, 2, 1, 3, 2, 2, 1, 1, 1, 1, 3, 0, 0,
    1, 0, 1, 1, 0, 0, 3, 1, 0, 3, 2, 2, 0, 1, 1, 1,
    0, 1, 0, 1, 0, 0, 0, 2, 1, 0, 0, 0, 1, 1, 0, 2,
    3, 3, 1, 1, 2, 1, 1, 1, 1, 2, 4, 2, 0, 0, 1, 4,
    0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1,
))
# Define data and stochastics.
# pymc draws the stochastics' initial values (and later the MCMC proposals)
# from NumPy's global RNG, so seed it first to make results reproducible.
SEED = 42
np.random.seed(SEED)

# Index of the first year governed by `late_mean` (see `rate`: out[s:] = l).
# The upper bound is the last valid index into the data, derived from its
# length instead of hard-coding 110 (the value it equals for this data set).
switchpoint = DiscreteUniform(
    'switchpoint',
    lower=0,
    upper=len(disasters_array) - 1,
    doc='Switchpoint[year]')
# Poisson rates before and after the switchpoint.
early_mean = Exponential('early_mean', beta=1.)
late_mean = Exponential('late_mean', beta=1.)
@deterministic(plot=False)
def rate(s=switchpoint, e=early_mean, l=late_mean):
    """Per-year Poisson mean: `e` before index `s`, `l` from index `s` onward.

    Returns a float array with one entry per element of `disasters_array`.
    """
    # Every slice [:s] / [s:] pair partitions the array, so filling with the
    # late mean and then overwriting the leading part is equivalent to
    # assigning the two halves separately.
    means = np.full(len(disasters_array), l, dtype=float)
    means[:s] = e
    return means
# Likelihood: the observed yearly counts, Poisson-distributed with mean `rate`.
disasters = Poisson('disasters', value=disasters_array, mu=rate, observed=True)

# Run the model: hand every node to the sampler, then draw N_ITER iterations,
# discarding the first N_BURN as burn-in and keeping every N_THIN-th draw.
N_ITER, N_BURN, N_THIN = 50000, 10000, 10
disaster_model = [disasters, rate, switchpoint, early_mean, late_mean]
M = MCMC(disaster_model)
M.sample(iter=N_ITER, burn=N_BURN, thin=N_THIN)
`rate` is a vector with one Poisson mean per year, so it cannot be plotted as a sample trace like the other (scalar) variables; its panel shows a single per-year curve instead.
In [39]:
# One panel per traced variable. Scalars are shown as their sample traces;
# `rate` holds one value per year, so it is summarised by its per-year
# posterior mean rather than by a single (the last) draw.
fig, ax = plt.subplots(2, 2, figsize=(10, 10))
# sorted() gives a stable panel order and, unlike .iterkeys(), also works on
# Python 3.
var_names = sorted(M.stats())
assert len(var_names) <= ax.size, 'more traced variables than subplot panels'
# ax.T.flat fills the grid column by column, like ax[idx % 2, idx // 2].
for panel, rv in zip(ax.T.flat, var_names):
    samples = M.trace(rv)[:]
    panel.set_title(rv)
    if rv == 'rate':
        panel.plot(samples.mean(axis=0))
        panel.set_xlabel('year index')
    else:
        panel.plot(samples)
        panel.set_xlabel('retained sample')
PyMC's `Matplot` submodule can produce this kind of plot automatically:
In [42]:
# Let pymc's Matplot submodule draw its ready-made summary figures for the
# sampler `M`, instead of assembling the panels by hand.
import pymc.Matplot
plot = pymc.Matplot.plot
plot(M)
Compare the observed counts with the model's yearly Poisson mean (the rate):
In [61]:
# Overlay the model on the observed counts. The mean of a Poisson is its
# rate parameter, so the posterior mean of `rate` is the model's expected
# count for each year (a single draw, e.g. the last one, is just noise).
first_year = 1851
years = np.arange(first_year, first_year + len(disasters_array))
rate_mean = M.trace('rate')[:].mean(axis=0)
# Summarise the switchpoint by its most frequently sampled index (posterior
# mode) and convert that index to a calendar year.
switch_draws = np.asarray(M.trace('switchpoint')[:], dtype=int).ravel()
change = first_year + np.bincount(switch_draws).argmax()

fig_cmp, ax_cmp = plt.subplots()
ax_cmp.scatter(years, disasters_array)
ax_cmp.plot(years, rate_mean)
ax_cmp.set_xlim(years[0], years[-1] + 1)
ax_cmp.set_xlabel('year')
ax_cmp.set_ylabel('count')
# Trailing ';' suppresses the Text repr in the cell output.
ax_cmp.set_title('change in legislation {}?'.format(change));
Out[61]: