import numpy as np, pymc3 as pm, theano.tensor as T, matplotlib.pyplot as plt
import theano
floatX = theano.config.floatX

# Problem dimensions for the simulated data set:
#   M    -- number of columns in X (fixed effects)
#   N    -- number of columns in L (random effects, one per subject)
#   nobs -- replicate observations per subject x condition cell
M, N, nobs = 6, 10, 10

# Generate the design matrices with patsy. Every subject (N of them) sees
# every Condi1 (2 levels) x Condi2 (3 levels) cell, with `nobs` replicates.
from patsy import dmatrices
import pandas as pd

predictors = [
    (c1 + 1, c2 + 1, s1 + 1)
    for s1 in range(N)
    for c1 in range(2)
    for c2 in range(3)
    for _ in range(nobs)
]
tbltest = pd.DataFrame(predictors, columns=['Condi1', 'Condi2', 'subj'])
for col in ('Condi1', 'Condi2', 'subj'):
    tbltest[col] = tbltest[col].astype('category')
# patsy needs a left-hand side, so draw one standard-normal value per row.
# These values are reused below as the observation noise in Pheno. The draw is
# 1-D and sized from the table itself, not from nobs*M*N, which only matched
# because M happens to equal 2*3.
tbltest['tempresp'] = np.random.normal(size=len(tbltest))

Y, X = dmatrices("tempresp ~ Condi1*Condi2", data=tbltest, return_type='matrix')
Terms = X.design_info.column_names
_, L = dmatrices('tempresp ~ -1+subj', data=tbltest, return_type='matrix')
X = np.asarray(X)  # fixed-effect design, one column per entry of Terms
L = np.asarray(L)  # random-effect design: one indicator column per subject
Y = np.asarray(Y)
if X.shape[1] != M or L.shape[1] != N:
    raise ValueError(
        f"design has {X.shape[1]} fixed / {L.shape[1]} random columns; "
        f"expected M={M} / N={N}"
    )
# Simulate the phenotype: fixed effects + per-subject random effects + noise.
w0 = [5., 1.5, 2., 3., 1.1, 1.25]  # true fixed-effect weights, one per column of X
z0 = np.random.normal(size=(N,))   # true random effects, one per column of L
if len(w0) != X.shape[1]:
    raise ValueError(f"w0 has {len(w0)} weights but X has {X.shape[1]} columns")
# Y holds the standard-normal draws made for the patsy placeholder response;
# they serve as the observation noise here.
Pheno = np.dot(X, w0) + np.dot(L, z0) + Y.flatten()

with pm.Model() as mixedEffect:
    ### hyperpriors
    # Total variance sigma2 is split into a random-effect share h2 and a
    # residual share (1 - h2); h2 uses pm.Uniform's default (0, 1) bounds.
    h2 = pm.Uniform('h2')
    sigma2 = pm.HalfCauchy('sigma2', 5)
    # Fixed effects: wide Normal(0, 100) prior, one weight per column of X.
    w = pm.Normal('w', mu=0, sd=100, shape=M)
    # Random effects: one per column of L, with variance h2 * sigma2.
    z = pm.Normal('z', mu=0, sd=(h2 * sigma2) ** 0.5, shape=N)
    g = T.dot(L, z)
    y = pm.Normal('y', mu=g + T.dot(X, w),
                  sd=((1 - h2) * sigma2) ** 0.5, observed=Pheno)


with mixedEffect:
    # `s` scales part of the ADVI gradient (PyMC3's cost_part_grad_scale).
    # Keep it at 1 while optimising, then drop it to 0 near the optimum.
    s = theano.shared(pm.floatX(1))
    inference = pm.ADVI(cost_part_grad_scale=s)
    # Stage 1: fit with the full gradient until ADVI has nearly converged.
    inference.fit(n=20000)
    # Stage 2: it is time to set `s` to zero and keep refining.
    s.set_value(0)
    approx = inference.fit(n=10000)
    trace_vi = approx.sample(3000)
    # inference.hist records the loss history; its negative is the ELBO.
    elbos1 = -inference.hist

# ADVI posterior samples, with the simulated true values drawn as reference lines.
pm.traceplot(trace_vi, lines=dict(w=w0, z=z0))

plt.plot(elbos1, alpha=0.3)  # ELBO trace over optimisation iterations

# Reference posterior from pm.sample's default sampler (NUTS for this model).
with mixedEffect:
    trace = pm.sample(draws=3000, tune=1000, njobs=2)

pm.traceplot(trace, lines=dict(w=w0, z=z0))

# Plot ADVI and NUTS marginals side by side (adapted from a PyMC3 example):
# dashed curves are the mean-field Gaussians, solid curves are KDEs of the
# NUTS draws after `burnin`.

burnin = 1000
from scipy import stats
import seaborn as sns

gbij = approx.bij
means = gbij.rmap(approx.mean.eval())
cov = approx.cov.eval()
sds = gbij.rmap(np.diag(cov)**.5)

varnames = means.keys()
fig, axs = plt.subplots(nrows=len(varnames), figsize=(12, 18))
for var, ax in zip(varnames, axs):
    mu_arr = means[var]
    sigma_arr = sds[var]
    # Slice the trace once per variable rather than once per component.
    samples = trace[burnin:][var]
    for i, (mu, sigma) in enumerate(zip(mu_arr.flatten(), sigma_arr.flatten())):
        lo, hi = mu - 4 * sigma, mu + 4 * sigma  # +/- 4 sd around the mean
        x = np.linspace(lo, hi, 300)
        y = stats.norm(mu, sigma).pdf(x)
        ax.plot(x, y / 4., '--')
        # Vector-valued variables: compare component i only.
        # Scalars: use every draw. The `else` was missing before, so the
        # component slice was always overwritten by the full 2-D array.
        if samples.ndim > 1:
            t = samples[:, i]
        else:
            t = samples
        pm.kdeplot(t, ax=ax)


# Differential-evolution Metropolis with njobs=50 workers.
with mixedEffect:
    tracede = pm.sample(draws=5000, step=pm.DEMetropolis(), tune=1000,
                        njobs=50, parallelize=True)

pm.traceplot(tracede, lines=dict(w=w0, z=z0))

# Sequential Monte Carlo run (step passed by keyword instead of position).
with mixedEffect:
    mtrace = pm.sample(draws=10000, step=pm.SMC())

pm.traceplot(mtrace, lines=dict(w=w0, z=z0))

# Evaluate output: posterior means of the fixed (w) and random (z) effects
# for every method. The means are taken directly from the draws, so we no
# longer build a full pm.summary table (HPD, n_eff, Rhat) just to read its
# 'mean' column.


def _posterior_mean(samples, varname):
    """Return the mean of `varname` over all draws (pooled across chains)."""
    return np.asarray(samples[varname]).mean(axis=0)


wpymc = _posterior_mean(trace[burnin:], 'w')
zpymc = _posterior_mean(trace[burnin:], 'z')

wpymcde = _posterior_mean(tracede[burnin:], 'w')
zpymcde = _posterior_mean(tracede[burnin:], 'z')

# The SMC trace is used in full (no burn-in), as before.
wpymc2 = _posterior_mean(mtrace, 'w')
zpymc2 = _posterior_mean(mtrace, 'z')

w_vi1 = _posterior_mean(trace_vi, 'w')
z_vi1 = _posterior_mean(trace_vi, 'z')

# Frequentist reference: statsmodels linear mixed model with subject as the
# grouping factor (default random-effects structure: an intercept per group).
import statsmodels.formula.api as smf
tbltest['Pheno'] = Pheno
md = smf.mixedlm("Pheno ~ Condi1*Condi2", tbltest, groups=tbltest["subj"])
mdf = md.fit()
fe_params = pd.DataFrame(mdf.fe_params, columns=['LMM'])
# One row per subject. NOTE(review): the 'Group' label depends on the
# statsmodels version; confirm it matches the installed release.
random_effects = (
    pd.DataFrame(mdf.random_effects)
    .transpose()
    .rename(index=str, columns={'Group': 'LMM'})
)

# Attach every Bayesian estimate as a new column next to the LMM results,
# aligned to the LMM row index.
for label, (w_est, z_est) in {
    'NUTS': (wpymc, zpymc),
    'DEM': (wpymcde, zpymcde),
    'SMC': (wpymc2, zpymc2),
    'MeanField': (w_vi1, z_vi1),
}.items():
    fe_params[label] = pd.Series(w_est, index=fe_params.index)
    random_effects[label] = pd.Series(z_est, index=random_effects.index)

# Plotting function
def plotfitted(fe_params, random_effects, X, Z, Y):
    """Plot estimates and fitted values for every estimation method.

    Parameters
    ----------
    fe_params : DataFrame
        Fixed-effect estimates, one column per method; rows align with X's columns.
    random_effects : DataFrame
        Random-effect estimates, one column per method (same names as
        fe_params); rows align with Z's columns.
    X, Z : ndarray
        Fixed- and random-effect design matrices.
    Y : array-like
        Observed response.

    Prints the MSE of each method's fitted values against Y and returns
    them as a dict keyed by method name.
    """
    # Start a fresh figure so subplot2grid does not draw over an existing one.
    plt.figure(figsize=(18, 9))
    ax1 = plt.subplot2grid((2, 2), (0, 0))
    ax2 = plt.subplot2grid((2, 2), (0, 1))
    ax3 = plt.subplot2grid((2, 2), (1, 0), colspan=2)

    fe_params.plot(ax=ax1)
    ax1.set_title('Fixed effects')
    random_effects.plot(ax=ax2)
    ax2.set_title('Random effects')

    observed = np.asarray(Y).flatten()
    ax3.plot(observed, 'o', color='k', label='Observed', alpha=.25)
    mse = {}
    # Iterate the column Index directly; Index.get_values() was removed in pandas 1.0.
    for iname in fe_params.columns:
        fitted = (np.dot(X, fe_params[iname])
                  + np.dot(Z, random_effects[iname])).flatten()
        mse[iname] = np.mean(np.square(observed - fitted))
        print(f"The MSE of {iname} is {mse[iname]}")
        ax3.plot(fitted, lw=1, label=iname, alpha=.5)
    ax3.legend(loc=0)
    return mse

# Recorded notebook output of a plotfitted(...) call (the call is not shown here):
# The MSE of LMM is 1.0181461088982078
# The MSE of NUTS is 1.0180786888926217
# The MSE of DEM is 1.0181278807943486
# The MSE of SMC is 1.0182671564091361
# The MSE of MeanField is 1.0178593524744708

