In [54]:
import pymc3 as pm
import numpy as np
import matplotlib.pyplot as plt
import spacepy.plot as spp
from scipy import stats
import pandas as pd
%matplotlib inline
In [55]:
# Simulated data: sorted x values drawn uniformly from the integers [50, 100],
# and y = slope * Poisson(x), cast to int. Before the cast E[y] = slope * x,
# which is what `true_regression_line` shows; the int cast truncates toward
# zero, biasing y down by less than 1 on average.
# A fixed seed makes the data (and every fit/plot built on it) reproducible.
SEED = 42
np.random.seed(SEED)

slope = 2.3
n_points = 23
x = np.sort(np.random.randint(50, 100 + 1, n_points))
# One vectorized Poisson draw per x value (scipy uses the seeded global RNG).
y = (stats.poisson(x).rvs() * slope).astype(int)
data = pd.DataFrame({'x': x, 'y': y})

true_regression_line = slope * x

plt.scatter(x, y)
plt.plot(x, true_regression_line)
plt.title('Simulated data and true regression line')
plt.xlabel('x')
plt.ylabel('y');
Out[55]:
In [53]:
# Fit a linear GLM y ~ x using the GLM's default family (normal likelihood).
# GLM.from_formula builds the linear model, its likelihood and all of its
# parameters (e.g. 'Intercept' and 'x') and adds them to the enclosing model.
with pm.Model() as model:
    glm = pm.glm.GLM.from_formula('y ~ x', data)
    # NUTS sampling: 3000 draws per chain, i.e. 6000 in total over 2 chains.
    # A fixed random_seed makes the posterior reproducible across re-runs.
    trace = pm.sample(3000, chains=2, target_accept=0.9, random_seed=42)
In [33]:
# Trace plots for the default-family (normal) GLM fit, dropping the first 100
# draws of each chain. pm.traceplot creates its own figure, so the size is
# passed to it directly; a preceding plt.figure() would only leave an extra,
# empty figure in the output.
pm.traceplot(trace[100:], figsize=(7, 7))
plt.tight_layout();
In [45]:
# Overlay 50 randomly chosen posterior regression lines (Intercept + x * slope)
# on the data, together with the true line.
# trace[100:] drops burn-in per chain (matching the trace plot). Indices are
# drawn over the length of the extracted arrays: len(trace) is the per-chain
# draw count, whereas trace['Intercept'] concatenates all chains, so using
# len(trace) would only ever sample from the first chain.
posterior = trace[100:]
intercepts = posterior['Intercept']
slopes = posterior['x']
draw_idx = np.random.randint(0, len(intercepts), 50)

plt.figure(figsize=(7, 7))
plt.plot(x, y, 'x', label='data')
for i in draw_idx:
    plt.plot(x, intercepts[i] + slopes[i] * x, c='grey', alpha=0.3)
plt.plot(x, true_regression_line, label='true regression line', lw=3., c='y')
plt.title('Posterior predictive regression lines')
plt.legend(loc=0)
plt.xlabel('x')
plt.ylabel('y');
In [47]:
# Poisson GLM family instance.
# NOTE(review): not referenced by the Poisson model cell, which constructs its
# own pm.glm.families.Poisson() inline; candidate for removal.
family = pm.glm.families.Poisson()
In [56]:
# Refit the same y ~ x formula with a Poisson likelihood for the integer-valued
# y. PyMC3's Poisson GLM family uses a log link, so the fitted 'Intercept' and
# 'x' describe log E[y], not E[y] directly.
# This rebinds `model`, `glm` and `trace`; later cells use this fit.
with pm.Model() as model:
    glm = pm.glm.GLM.from_formula('y ~ x', data, family=pm.glm.families.Poisson())
    # NUTS sampling: 3000 draws per chain, i.e. 6000 in total over 2 chains.
    # A fixed random_seed makes the posterior reproducible across re-runs.
    trace = pm.sample(3000, chains=2, target_accept=0.9, random_seed=42)
In [57]:
# Trace plots for the Poisson GLM fit, dropping the first 100 draws of each
# chain. pm.traceplot creates its own figure, so the size is passed to it
# directly; a preceding plt.figure() would only leave an extra, empty figure
# in the output.
pm.traceplot(trace[100:], figsize=(7, 7))
plt.tight_layout();
In [58]:
# Overlay 50 randomly chosen posterior mean curves from the Poisson GLM on the
# data, together with the true line.
# The Poisson family uses a log link: Intercept + x * slope is the linear
# predictor for log E[y], so it is exponentiated to put it on the scale of y.
# trace[100:] drops burn-in per chain; indices are drawn over the concatenated
# per-variable arrays (len(trace) is only the per-chain draw count).
posterior = trace[100:]
intercepts = posterior['Intercept']
slopes = posterior['x']
draw_idx = np.random.randint(0, len(intercepts), 50)

plt.figure(figsize=(7, 7))
plt.plot(x, y, 'x', label='data')
for i in draw_idx:
    plt.plot(x, np.exp(intercepts[i] + slopes[i] * x), c='grey', alpha=0.3)
plt.plot(x, true_regression_line, label='true regression line', lw=3., c='y')
plt.title('Posterior predictive regression lines')
plt.legend(loc=0)
plt.xlabel('x')
plt.ylabel('y');
In [ ]:
In [ ]: