In [1]:
%matplotlib inline
import pystan
import numpy as np
import matplotlib.pyplot as plt

PyStan: Python interface to Stan

https://github.com/stan-dev/pystan


In [2]:
# Stan program for a hierarchical model of per-school treatment effects.
# Each school's effect is built as theta[j] = mu + tau * eta[j], with the
# eta[j] given a standard-normal distribution; mu and tau get no explicit
# prior statement in the model block.
# The assignment uses `=` rather than the deprecated `<-` operator, so this
# program needs Stan 2.10 or newer.
schools_code = """
data {
    int<lower=0> J; // number of schools
    real y[J]; // estimated treatment effects
    real<lower=0> sigma[J]; // s.e. of effect estimates
}
parameters {
    real mu;
    real<lower=0> tau;
    real eta[J];
}
transformed parameters {
    real theta[J];
    for (j in 1:J)
        theta[j] = mu + tau * eta[j];
}
model {
    eta ~ normal(0, 1);
    y ~ normal(theta, sigma);
}
"""

In [3]:
# Values for the Stan program's `data` block: the number of schools (J),
# each school's estimated treatment effect (y), and the standard error of
# each estimate (sigma). Both lists hold exactly J entries.
schools_dat = dict(
    J=8,
    y=[28, 8, -3, 7, -1, 1, 18, 12],
    sigma=[15, 10, 16, 11, 9, 11, 10, 18],
)

In [4]:
# Fixed seed for the sampler so that a Restart & Run All reproduces the same
# draws (given the same platform and PyStan/Stan versions).
SCHOOLS_SEED = 42

# Build the model from `schools_code` and sample it against `schools_dat`,
# using 4 chains of 1000 iterations each.
fit = pystan.stan(model_code=schools_code, data=schools_dat,
                  iter=1000, chains=4, seed=SCHOOLS_SEED)

# Print the text summary of the fit (per-parameter statistics table).
print(fit)


Inference for Stan model: anon_model_7af71fe73cad6b7d51a4ce15383db4c6.
4 chains, each with iter=1000; warmup=500; thin=1; 
post-warmup draws per chain=500, total post-warmup draws=2000.

           mean se_mean     sd   2.5%    25%    50%    75%  97.5%  n_eff   Rhat
mu         8.19    0.31   5.44  -1.28   4.65   7.93  11.43  19.18  315.0    1.0
tau        6.73    0.31   5.59   0.29   2.52   5.42   9.51  21.29  329.0    1.0
eta[0]     0.35    0.04   0.93  -1.55  -0.27   0.37   0.99   2.14  563.0    1.0
eta[1]     0.02    0.04   0.88  -1.72  -0.54   0.02   0.58   1.84  560.0    1.0
eta[2]     -0.2    0.04   0.94  -2.05  -0.84  -0.24   0.42    1.7  534.0    1.0
eta[3]    -0.03    0.04    0.9   -1.9  -0.62  -0.01   0.55   1.71  543.0    1.0
eta[4]    -0.33    0.04   0.88  -2.11  -0.88  -0.34   0.24   1.41  537.0    1.0
eta[5]    -0.25    0.04   0.91  -2.05  -0.86  -0.27   0.36   1.54  552.0    1.0
eta[6]     0.32    0.04   0.89  -1.45  -0.25   0.31    0.9    2.1  506.0    1.0
eta[7]     0.07    0.04   0.94  -1.84  -0.55   0.12   0.71   1.95  559.0    1.0
theta[0]  11.35    0.39   8.54  -3.13   5.89  10.24  15.76  32.16  488.0    1.0
theta[1]   8.21    0.28   6.52  -4.55   4.21   8.07  11.96  21.55  548.0    1.0
theta[2]   6.15    0.37   8.25 -13.94   2.09   6.57   11.1  20.72  495.0    1.0
theta[3]    7.8    0.26   6.49  -5.19   3.62   7.75   11.8  20.98  628.0    1.0
theta[4]   5.29    0.26   6.33   -9.1   1.52   5.72   9.61  16.77  610.0    1.0
theta[5]   6.11    0.31   7.01  -9.08   1.94   6.55  10.67  18.68  522.0    1.0
theta[6]  10.96     0.3   6.97  -1.03   6.33  10.25   14.9  26.98  542.0    1.0
theta[7]   8.79    0.41   8.16  -7.45   3.95   8.33  13.23  26.95  392.0    1.0
lp__      -4.91    0.14   2.64 -10.63  -6.56  -4.75  -2.99  -0.25  336.0    1.0

Samples were drawn using NUTS(diag_e) at Mon May  9 15:49:51 2016.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).

In [5]:
# Extract the sampled values of `eta` from the fit, then average along the
# first axis; the cell's displayed result is an array with one mean per
# element of `eta`.
eta = fit.extract(permuted=True)['eta']
eta.mean(axis=0)


Out[5]:
array([ 0.34745075,  0.02324738, -0.20351687, -0.03454671, -0.33163158,
       -0.24633651,  0.31718647,  0.07475159])

In [6]:
# Draw PyStan's built-in plot for the fit (a visual summary and traceplot),
# then render the figure. This needs matplotlib, which this notebook imports
# unconditionally in its first cell, so it is not optional here.
fit.plot()
plt.show()



In [ ]: