Using find_MAP on models with discrete variables

Maximum a posteriori (MAP) estimation can be difficult in models that have discrete stochastic variables. Here we demonstrate the problem with a simple model and present a few possible workarounds.


In [1]:
import pymc3 as mc

We define a simple model of a survey with one data point. We use a $Beta$ distribution for the $p$ parameter in a binomial. We would like to know both the posterior distribution for $p$ and the posterior predictive distribution over the survey parameter.


In [2]:
alpha = 4
beta = 4
n = 20
yes = 15

with mc.Model() as model:
    p = mc.Beta('p', alpha, beta)
    surv_sim = mc.Binomial('surv_sim', n=n, p=p)
    surv = mc.Binomial('surv', n=n, p=p, observed=yes)

First, let's try using find_MAP.


In [3]:
with model:
    print(mc.find_MAP())


{'p': array(0.6086956533498806)}

find_MAP defaults to finding the MAP for only the continuous variables; we have to specify the discrete variables explicitly if we would like them included.


In [4]:
with model:
    print(mc.find_MAP(vars=model.vars, disp=True))


Warning: vars contains discrete variables. MAP estimates may not be accurate for the default parameters. Defaulting to non-gradient minimization fmin_powell.
Optimization terminated successfully.
         Current function value: 3.111511
         Iterations: 3
         Function evaluations: 95
{'surv_sim': 14.0, 'p': array(0.695652178810167)}

We set the disp variable to display a warning that we are using a non-gradient minimization technique, as discrete variables do not provide gradient information. To demonstrate this, if we use a gradient-based minimizer, fmin_bfgs, with various starting points, we see that the MAP estimate does not converge.


In [5]:
with model:
    for i in range(n+1):
        s = {'p':0.5, 'surv_sim':i}
        map_est = mc.find_MAP(start=s, vars=model.vars, fmin=mc.starting.optimize.fmin_bfgs)
        print('surv_sim: %i->%i, p: %f->%f, LogP:%f'%(s['surv_sim'],
                                                      map_est['surv_sim'],
                                                      s['p'],
                                                      map_est['p'],
                                                      model.logpc(map_est)))


surv_sim: 0->-1, p: 0.500000->0.391133, LogP:-inf
surv_sim: 1->1, p: 0.500000->0.500000, LogP:-14.298540
surv_sim: 2->2, p: 0.500000->0.500000, LogP:-12.047249
surv_sim: 3->3, p: 0.500000->0.500000, LogP:-10.255489
surv_sim: 4->4, p: 0.500000->0.500000, LogP:-8.808570
surv_sim: 5->5, p: 0.500000->0.500000, LogP:-7.645419
surv_sim: 6->6, p: 0.500000->0.500000, LogP:-6.729129
surv_sim: 7->7, p: 0.500000->0.500000, LogP:-6.035981
surv_sim: 8->8, p: 0.500000->0.500000, LogP:-5.550474
surv_sim: 9->8, p: 0.500000->0.558888, LogP:-5.161793
surv_sim: 10->9, p: 0.500000->0.587384, LogP:-4.563607
surv_sim: 11->10, p: 0.500000->0.500000, LogP:-5.167477
surv_sim: 12->11, p: 0.500000->0.500000, LogP:-5.262785
surv_sim: 13->12, p: 0.500000->0.500000, LogP:-5.550465
surv_sim: 14->13, p: 0.500000->0.500000, LogP:-6.035970
surv_sim: 15->14, p: 0.500000->0.681157, LogP:-3.133947
surv_sim: 16->16, p: 0.500000->0.500000, LogP:-8.808570
surv_sim: 17->17, p: 0.500000->0.500000, LogP:-10.255489
surv_sim: 18->18, p: 0.500000->0.500000, LogP:-12.047249
surv_sim: 19->19, p: 0.500000->0.500000, LogP:-14.298540
surv_sim: 20->20, p: 0.500000->0.500000, LogP:-17.294273
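
Before interpreting these results, it helps to know where the joint maximum actually lies. The following brute-force grid check is a quick sketch, not part of the original workflow: it assumes numpy is available and reuses the model.logpc call used elsewhere in this notebook.


In [ ]:
import numpy as np

with model:
    # Evaluate the joint log probability on a coarse grid of (surv_sim, p)
    # values and report the best grid point.
    best = max(((model.logpc({'surv_sim': k, 'p': pv}), k, pv)
                for k in range(n + 1)
                for pv in np.linspace(0.01, 0.99, 99)),
               key=lambda t: t[0])
    print('grid maximum: surv_sim=%i, p=%.2f, LogP=%f' % (best[1], best[2], best[0]))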

Once again, because the gradient of surv_sim provides no information to the fmin routine, surv_sim is only changed in a few cases, most of which are not correct. Looking at the log probability (as the grid check above shows), we can see that the maximum is somewhere around surv_sim$=14$ and p$\approx 0.7$. If we employ a non-gradient minimization, such as fmin_powell (the default when discrete variables are detected), we might be able to get a better estimate.


In [6]:
with model:
    for i in range(n+1):
        s = {'p':0.5, 'surv_sim':i}
        map_est = mc.find_MAP(start=s, vars=model.vars)
        print('surv_sim: %i->%i, p: %f->%f, LogP:%f'%(s['surv_sim'],
                                                      map_est['surv_sim'],
                                                      s['p'],
                                                      map_est['p'],
                                                      model.logpc(map_est)))


surv_sim: 0->2, p: 0.500000->0.434783, LogP:-11.654827
surv_sim: 1->3, p: 0.500000->0.456522, LogP:-10.081356
surv_sim: 2->6, p: 0.500000->0.521739, LogP:-6.685637
surv_sim: 3->7, p: 0.500000->0.543478, LogP:-5.861849
surv_sim: 4->14, p: 0.500000->0.695652, LogP:-3.111511
surv_sim: 5->14, p: 0.500000->0.674290, LogP:-3.159870
surv_sim: 6->14, p: 0.500000->0.695652, LogP:-3.111511
surv_sim: 7->14, p: 0.500000->0.695652, LogP:-3.111511
surv_sim: 8->14, p: 0.500000->0.695652, LogP:-3.111511
surv_sim: 9->14, p: 0.500000->0.695652, LogP:-3.111511
surv_sim: 10->14, p: 0.500000->0.695652, LogP:-3.111511
surv_sim: 11->14, p: 0.500000->0.695652, LogP:-3.111511
surv_sim: 12->14, p: 0.500000->0.695652, LogP:-3.111511
surv_sim: 13->14, p: 0.500000->0.695652, LogP:-3.111511
surv_sim: 14->14, p: 0.500000->0.695652, LogP:-3.111511
surv_sim: 15->15, p: 0.500000->0.717392, LogP:-3.149062
surv_sim: 16->14, p: 0.500000->0.695652, LogP:-3.111511
surv_sim: 17->15, p: 0.500000->0.717391, LogP:-3.149062
surv_sim: 18->14, p: 0.500000->0.695652, LogP:-3.111511
surv_sim: 19->14, p: 0.500000->0.695652, LogP:-3.111511
surv_sim: 20->14, p: 0.500000->0.712421, LogP:-3.142725

For most starting values this converges to the maximum log probability of $\approx -3.11$, but for particularly low starting values of surv_sim, and for a few values near surv_sim$=14$, there is still some noise. The scipy optimize package contains some more general 'global' minimization functions that we can utilize. The basinhopping algorithm restarts the optimization at points near previously found minima. Because it has a slightly different interface from the other minimization schemes, we have to define a wrapper function.


In [7]:
from numpy import floor  # basinhopping can return a non-integer value for surv_sim

def bh(*args, **kwargs):
    result = mc.starting.optimize.basinhopping(*args, **kwargs)
    # basinhopping returns an OptimizeResult object; the minimizing
    # parameter values are stored under 'x'
    return result['x']

with model:
    for i in range(n+1):
        s = {'p':0.5, 'surv_sim':i}
        map_est = mc.find_MAP(start=s, vars=model.vars, fmin=bh)
        print('surv_sim: %i->%i, p: %f->%f, LogP:%f'%(s['surv_sim'],
                                                      floor(map_est['surv_sim']),
                                                      s['p'],
                                                      map_est['p'],
                                                      model.logpc(map_est)))


surv_sim: 0->5, p: 0.500000->0.500000, LogP:-7.645419
surv_sim: 1->7, p: 0.500000->0.543478, LogP:-5.861849
surv_sim: 2->10, p: 0.500000->0.608696, LogP:-4.071797
surv_sim: 3->8, p: 0.500000->0.565217, LogP:-5.158052
surv_sim: 4->10, p: 0.500000->0.608696, LogP:-4.071797
surv_sim: 5->7, p: 0.500000->0.543478, LogP:-5.861849
surv_sim: 6->12, p: 0.500000->0.652174, LogP:-3.385867
surv_sim: 7->8, p: 0.500000->0.565217, LogP:-5.158052
surv_sim: 8->11, p: 0.500000->0.630435, LogP:-3.679320
surv_sim: 9->10, p: 0.500000->0.608696, LogP:-4.071797
surv_sim: 10->12, p: 0.500000->0.652174, LogP:-3.385867
surv_sim: 11->13, p: 0.500000->0.673913, LogP:-3.194359
surv_sim: 12->14, p: 0.500000->0.695652, LogP:-3.111511
surv_sim: 13->14, p: 0.500000->0.695652, LogP:-3.111511
surv_sim: 14->14, p: 0.500000->0.695652, LogP:-3.111511
surv_sim: 15->14, p: 0.500000->0.695652, LogP:-3.111511
surv_sim: 16->15, p: 0.500000->0.717391, LogP:-3.149062
surv_sim: 17->14, p: 0.500000->0.695652, LogP:-3.111511
surv_sim: 18->15, p: 0.500000->0.717391, LogP:-3.149062
surv_sim: 19->18, p: 0.500000->0.782609, LogP:-4.247450
surv_sim: 20->18, p: 0.500000->0.782609, LogP:-4.247450

By default, basinhopping uses a gradient-based minimization technique, fmin_bfgs, which often results in inaccurate estimates. If we force basinhopping to use a non-gradient technique, we get much better results.


In [8]:
with model:
    for i in range(n+1):
        s = {'p':0.5, 'surv_sim':i}
        map_est = mc.find_MAP(start=s, vars=model.vars, fmin=bh, minimizer_kwargs={"method": "Powell"})
        print('surv_sim: %i->%i, p: %f->%f, LogP:%f'%(s['surv_sim'],
                                                      map_est['surv_sim'],
                                                      s['p'],
                                                      map_est['p'],
                                                      model.logpc(map_est)))


surv_sim: 0->14, p: 0.500000->0.695652, LogP:-3.111511
surv_sim: 1->14, p: 0.500000->0.695652, LogP:-3.111511
surv_sim: 2->14, p: 0.500000->0.695652, LogP:-3.111511
surv_sim: 3->14, p: 0.500000->0.695652, LogP:-3.111511
surv_sim: 4->14, p: 0.500000->0.695652, LogP:-3.111511
surv_sim: 5->14, p: 0.500000->0.695652, LogP:-3.111511
surv_sim: 6->14, p: 0.500000->0.695652, LogP:-3.111511
surv_sim: 7->14, p: 0.500000->0.695652, LogP:-3.111511
surv_sim: 8->14, p: 0.500000->0.695652, LogP:-3.111511
surv_sim: 9->14, p: 0.500000->0.695652, LogP:-3.111511
surv_sim: 10->14, p: 0.500000->0.695652, LogP:-3.111511
surv_sim: 11->14, p: 0.500000->0.695652, LogP:-3.111511
surv_sim: 12->14, p: 0.500000->0.695652, LogP:-3.111511
surv_sim: 13->14, p: 0.500000->0.695652, LogP:-3.111511
surv_sim: 14->14, p: 0.500000->0.695652, LogP:-3.111511
surv_sim: 15->14, p: 0.500000->0.695652, LogP:-3.111511
surv_sim: 16->14, p: 0.500000->0.695652, LogP:-3.111511
surv_sim: 17->14, p: 0.500000->0.695652, LogP:-3.111511
surv_sim: 18->14, p: 0.500000->0.695652, LogP:-3.111511
surv_sim: 19->14, p: 0.500000->0.695652, LogP:-3.111511
surv_sim: 20->14, p: 0.500000->0.695652, LogP:-3.111511

Confident in our MAP estimate, we can sample from the posterior, making sure to use the Metropolis method for our discrete variables.


In [9]:
with model:
    step1 = mc.step_methods.HamiltonianMC(vars=[p])
    step2 = mc.step_methods.Metropolis(vars=[surv_sim])

In [10]:
with model:
    trace = mc.sample(25000,[step1,step2],start=map_est)


 [-----------------100%-----------------] 25000 of 25000 complete in 37.4 sec

In [11]:
mc.traceplot(trace);
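
As a final sanity check, we can compare the sampled posterior with the MAP estimate found above. This is a quick sketch: the 5000-sample burn-in is an arbitrary illustrative choice, and it assumes the trace can be indexed by variable name, as is usual for pymc3 traces.


In [ ]:
burn = 5000  # arbitrary burn-in, for illustration only
print('posterior mean of p:        %f' % trace['p'][burn:].mean())
print('posterior mean of surv_sim: %f' % trace['surv_sim'][burn:].mean())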


