Test of a Generalized Metropolis-Hastings MCMC to explore the parameter space of a predator-prey model:

This notebook compares the performance of different samplers. For more elementary examples, see ode/springmass and ode/fitzhughnagumo.

OPERATION:

  • Run a cell by pressing the black triangle in the toolbar above.
  • Executing a cell may take a while; completion is confirmed by a printout.
  • If a cell results in an error or warning, re-run the cell, or select "Kernel/Restart" at the top to restart.
  • To remove all printed output and figures, select "Cell/All Output/Clear" at the top.

In [ ]:
###Load the PyPlot package (only on the main process)
import PyPlot
println("PyPlot package loaded successfully")

In the following cell, you can specify the number of parallel processes used to run the MCMC. The setup differs depending on whether the notebook runs on a single computer or on a cluster of computers (for more information on clusters, see Preparing an AWS Cluster).

  1. To run the MCMC without parallelism (in a single Julia process), set RUNPARALLEL=false.

  2. To run the MCMC in parallel on a single machine, set RUNPARALLEL=true and RUNONCLUSTER=false. The NPROCS variable sets the total number of Julia processes, including the main process; the cell only adds as many processes as are needed to reach that total. It is recommended not to make NPROCS larger than the number of CPU cores on your machine (given by the Julia global variable Sys.CPU_CORES).

  3. When running this notebook on a cluster, set RUNPARALLEL=true and RUNONCLUSTER=true. Set the xxx.xxx.xxx.xxx values to the private IP addresses of the slave machines you have started (add as many slaveip entries to machvec as required).


In [ ]:
RUNPARALLEL = true
RUNONCLUSTER = false

if RUNPARALLEL
    println("Starting additional Julia processes")
    NPROCS = min(16,Sys.CPU_CORES) #do not make larger than CPU_CORES
    if nprocs() < NPROCS
        addprocs(NPROCS-nprocs(),topology=:master_slave)
    end
    println("Number of Julia processes: ",nprocs())

    if RUNONCLUSTER 
        println("Starting additional Julia processes on the cluster")
        slaveip1 = "ubuntu@xxx.xxx.xxx.xxx"
        slaveip2 = "ubuntu@xxx.xxx.xxx.xxx"
        machvec = [(slaveip1,:auto),(slaveip2,:auto)]
        addprocs(machvec,topology=:master_slave)
        println("Total number of Julia processes in cluster: ",nprocs())
    end
end
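
The cell above uses names from older Julia releases. As a minimal sketch, assuming Julia 1.x (where the process functions live in the Distributed standard library, Sys.CPU_CORES has become Sys.CPU_THREADS, and the topology is called :master_worker), the single-machine branch could look as follows. The helper names workers_to_add and ensure_procs are hypothetical, and this notebook does not establish whether the GMH package itself runs on Julia 1.x.

```julia
using Distributed

# Number of processes to add so that the total reaches `target` (never negative).
workers_to_add(target, current) = max(target - current, 0)

# Hypothetical helper: top up the local process count to `target` in total.
function ensure_procs(target)
    nadd = workers_to_add(target, nprocs())
    if nadd > 0
        addprocs(nadd; topology=:master_worker)
    end
    return nprocs()
end

# Usage (not executed here): NPROCS counts the main process too.
# NPROCS = min(16, Sys.CPU_THREADS)
# ensure_procs(NPROCS)
```

Because the main process counts towards the total, a target of 4 on a fresh session adds 3 workers.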

In [ ]:
###Now Load the GMH package on all processes
import GeneralizedMetropolisHastings
import GMHExamples

@everywhere using GeneralizedMetropolisHastings
@everywhere using GMHExamples
println("GMH modules loaded successfully")

In [ ]:
nproposals = 300
niterations = 1000
nburnin = 500
ntunerperiod = 50

###Initial conditions for the ODE (prey and predator populations)
initial = [50.0,5.0]

###Default values of the six model parameters and prior boundaries
defaults = [0.4,107.0,0.9,53.0,0.7,0.3]
lows = zeros(6)
highs = 150*ones(6)

###The variance of the noise on the input data
variance = sqrt(10.0)*ones(2)

println("==========================================")
println("Simulation parameters defined successfully")
println("==========================================")

In [ ]:
###Create a predator-prey model with measurement data, ODE function and parameters with default values and priors
m = predatorpreymodel(initial,variance,lows,highs,defaults)

###Show the model
println("==========================")
println("Model defined successfully")
println("==========================")
show(m)

In [ ]:
###Plot the measurement data (simulated data + noise)
PyPlot.figure("PredatorPrey1")
PyPlot.plot(dataindex(m),measurements(m)[:,1];label="Prey")
PyPlot.plot(dataindex(m),measurements(m)[:,2];label="Predator")
PyPlot.xlabel("Time")
PyPlot.ylabel("Amplitude")
PyPlot.title("Predator-Prey measurement data")
PyPlot.grid("on")
PyPlot.legend(loc="upper right",fancybox="true")

In [ ]:
###Create different samplers

###Metropolis-Hastings Sampler with normal proposal density
mhsampler = sampler(:mh,:normal,0.01,6)

###Adaptive Metropolis-Hastings Sampler with normal proposal density
amhsampler = sampler(:adaptive,0.01,6)

println("============================")
println("Samplers defined successfully")
println("============================")
show(mhsampler)
show(amhsampler)
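
For readers new to the method, the following is a minimal sketch of a standard single-proposal random-walk Metropolis-Hastings chain with a normal proposal density. It is not the package's implementation: the Generalized variant used in this notebook is configured with nproposals proposals per iteration, whereas the sketch draws one. The names mh_step, mh_chain and logtarget are hypothetical.

```julia
using Random

# One random-walk Metropolis-Hastings step with an isotropic normal proposal.
# `logtarget` is a hypothetical function returning the log of the
# (unnormalized) target density at a parameter vector.
function mh_step(rng, logtarget, x, scale)
    proposal = x .+ scale .* randn(rng, length(x))
    logratio = logtarget(proposal) - logtarget(x)
    # Accept with probability min(1, exp(logratio)); the proposal is symmetric.
    if log(rand(rng)) < logratio
        return proposal, true
    else
        return x, false
    end
end

# Run `nsteps` steps; returns a parameters × samples matrix and the acceptance rate.
function mh_chain(rng, logtarget, x0, scale, nsteps)
    chain = Matrix{Float64}(undef, length(x0), nsteps)
    x = copy(x0)
    naccepted = 0
    for i in 1:nsteps
        x, accepted = mh_step(rng, logtarget, x, scale)
        naccepted += accepted
        chain[:, i] = x
    end
    return chain, naccepted / nsteps
end
```

A small proposal scale gives a high acceptance rate but slow exploration, while a large scale does the opposite. This trade-off is what a scaling tuner adjusts.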

In [ ]:
###Create a tuner that scales the proposal density (for Metropolis-Hastings sampler)
stuner = tuner(:scale,ntunerperiod,0.5,:erf)

###Create a tuner that only monitors the acceptance rate (for Adaptive Metropolis-Hastings sampler)
mtuner = tuner(:monitor,ntunerperiod)

println("==========================")
println("Tuners defined successfully")
println("==========================")
show(stuner)
show(mtuner)
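
The tuner settings can be read as follows, although this reading is an assumption: ntunerperiod is the number of iterations between tuning steps, and 0.5 is a target acceptance rate. This notebook does not show how the package's :erf option maps the acceptance rate to a scale factor, so the sketch below substitutes a simple exponential rule purely for illustration. The name tune_scale is hypothetical.

```julia
# Hypothetical tuning rule (not the package's :erf rule): grow the proposal
# scale when the observed acceptance rate is above the target and shrink it
# when the rate is below the target.
function tune_scale(scale, acceptrate, target)
    return scale * exp(acceptrate - target)
end

# Example: three tuning periods with different observed acceptance rates.
function demo_tuning()
    scale = 0.01
    for rate in (0.9, 0.7, 0.5)
        scale = tune_scale(scale, rate, 0.5)
    end
    return scale
end
```

A monitoring tuner, by contrast, would only record the acceptance rate each period and leave the scale untouched.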

In [ ]:
###Create Generalized Metropolis-Hastings runner
p = policy(:mh,nproposals;initialize=:default)
r1 = runner(p,niterations,nproposals;numburnin=nburnin)
r2 = runner(p,niterations,nproposals;numburnin=nburnin) #same burn-in period as r1 (raise numburnin here to give the adaptive sampler a longer burn-in)
println("============================")
println("Runners defined successfully")
println("============================")
show(r1)
show(r2)

In [ ]:
###Run the MCMC (can take quite a bit of time)
println("========================")
println("Run the MCMC simulations")
println("========================")
println("With Metropolis-Hastings Sampler")
@time result1 = run!(r1,m,mhsampler,stuner)
println("=========================================")
println("With Adaptive Metropolis-Hastings Sampler")
@time result2 = run!(r2,m,amhsampler,mtuner)
println("==========================")
println("Completed MCMC simulations")
println("==========================")

In [ ]:
###Show the results of the simulations
println("=========================")
println("Results of the MH Sampler")
println("=========================")
show(result1)

meanparamvals1 = mean(samples(result1),2)
stdparamvals1 = std(samples(result1),2)

println("Results of the MCMC simulation:")
for i=1:numparas(m)
    println(" parameter $(parameters(m)[i].key):  mean = ",meanparamvals1[i]," std = ",stdparamvals1[i])
end

println("==================================")
println("Results of the Adaptive-MH Sampler")
println("==================================")

show(result2)

meanparamvals2 = mean(samples(result2),2)
stdparamvals2 = std(samples(result2),2)

println("Results of the MCMC simulation:")
for i=1:numparas(m)
    println(" parameter $(parameters(m)[i].key):  mean = ",meanparamvals2[i]," std = ",stdparamvals2[i])
end
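
The calls mean(x,2) and std(x,2) above use the pre-0.7 Julia syntax for reducing along the second dimension. As a sketch, assuming Julia 1.x and a toy matrix standing in for samples(result1) (one row per parameter, one column per sample), the same summary can be computed like this:

```julia
using Statistics

# Toy stand-in for samples(result): one row per parameter, one column per sample.
toysamples = [1.0 2.0 3.0;
              10.0 20.0 30.0]

meanvals = mean(toysamples, dims=2)   # 2×1 matrix of per-parameter means
stdvals = std(toysamples, dims=2)     # 2×1 matrix of per-parameter standard deviations

for i in 1:size(toysamples, 1)
    println(" parameter $i:  mean = ", meanvals[i], " std = ", stdvals[i])
end
```

Reducing along dimension 2 collapses the sample axis, leaving one value per parameter, which is why the results are indexed by parameter number.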

In [ ]:
###Plot the measurement data against the model output evaluated at the mean parameter values of each sampler
PyPlot.figure("PredatorPrey2")
modeldata1 = evaluate!(m,vec(meanparamvals1))
modeldata2 = evaluate!(m,vec(meanparamvals2))
PyPlot.plot(dataindex(m),measurements(m)[:,1];label="Measured Prey")
PyPlot.plot(dataindex(m),measurements(m)[:,2];label="Measured Predator")
PyPlot.plot(dataindex(m),modeldata1[:,1];label="MH Sampler Prey")
PyPlot.plot(dataindex(m),modeldata1[:,2];label="MH Sampler Predator")
PyPlot.plot(dataindex(m),modeldata2[:,1];label="A-MH Sampler Prey")
PyPlot.plot(dataindex(m),modeldata2[:,2];label="A-MH Sampler Predator")
PyPlot.xlabel("Time")
PyPlot.ylabel("Amplitude")
PyPlot.title("Predator-Prey Model")
PyPlot.grid("on")
PyPlot.legend(loc="upper right",fancybox="true")

In [ ]:
###Plot the histograms of parameter values (MH sampler)
PyPlot.figure("PredatorPrey3")
for i=1:numparas(m)
    PyPlot.subplot(320 + i)
    h = PyPlot.plt[:hist](vec(getindex(samples(result1),i,:)),20)
    PyPlot.grid("on")
    PyPlot.ylabel("Parameter $(parameters(m)[i].key)")
end
println("Parameter Histograms for MH Sampler")

In [ ]:
###Plot the histograms of parameter values (Adaptive MH sampler)
PyPlot.figure("PredatorPrey4")
for i=1:numparas(m)
    PyPlot.subplot(320 + i)
    h = PyPlot.plt[:hist](vec(getindex(samples(result2),i,:)),20)
    PyPlot.grid("on")
    PyPlot.ylabel("Parameter $(parameters(m)[i].key)")
end
println("Parameter Histograms for Adaptive MH Sampler")

In [ ]:
println("Number of processes running: ",nprocs())
println("Number of workers running: ",nworkers())
println("Process IDs: ",procs())

In [ ]:
###Only run this cell if you want to shut down all worker processes
println("Pre processes running: ",procs())
if nprocs() > 1
    rmprocs(workers())
    sleep(1.0)
end
println("Post processes running: ",procs())
