Test of a Generalized Metropolis-Hastings MCMC to explore the parameter space for a simple Spring-Mass ODE model with spring stiffness K and mass M.

If the number of proposals per iteration equals 1, this runner behaves as Standard Metropolis-Hastings (M-H). If the number of proposals is greater than 1, it behaves as the Generalized Metropolis-Hastings algorithm.
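
For orientation, the single-proposal case can be sketched as a toy random-walk Metropolis loop. This is an illustration only, not the implementation in GeneralizedMetropolisHastings: the names logtarget, acceptprob and toychain are hypothetical, the target is a standard normal rather than the spring-mass posterior, and the proposal is assumed symmetric so that the acceptance probability reduces to min(1, p(new)/p(old)). It uses only Base functions.

```julia
# Log-density of a toy target (standard normal, up to an additive constant)
logtarget(x) = -0.5 * x^2

# Acceptance probability for a symmetric proposal (plain Metropolis rule)
acceptprob(logp_new, logp_old) = min(1.0, exp(logp_new - logp_old))

# Run n iterations of random-walk Metropolis with proposal step size `step`
function toychain(x0, n, step)
    chain = zeros(n)
    x = x0
    for i in 1:n
        xnew = x + step * randn()      # one symmetric Gaussian proposal
        if rand() < acceptprob(logtarget(xnew), logtarget(x))
            x = xnew                   # accept; otherwise keep the current x
        end
        chain[i] = x
    end
    return chain
end

chain = toychain(0.0, 1000, 0.5)
println("Chain length: ", length(chain))
```

With more than one proposal per iteration, the accept/reject step is replaced by a draw among the current point and all proposals, which is what allows the proposals to be evaluated in parallel.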

OPERATION:

  • Run a cell by pressing the black triangle in the toolbar above.
  • Note that the execution of a cell may take a while, and will be confirmed by a printout.
  • If a cell prints output in a pink box, re-run the cell to see whether it disappears. If it does not, close and re-open the notebook, or select "Kernel/Restart" at the top.
  • To remove all printed output and figures, select "Cell/All Output/Clear" at the top.

In [ ]:
###Load the PyPlot package (only on the main process)
import PyPlot
println("PyPlot package loaded successfully")

In the following cell, you can specify the number of parallel processes used to run the MCMC. The setup differs between running the notebook on a single computer and running it on a cluster of computers (for more information on clusters, see Preparing an AWS Cluster).

  1. To run the MCMC serially (in a single Julia process), set RUNPARALLEL=false.

  2. To run the MCMC in parallel on a single machine, set RUNPARALLEL=true and RUNONCLUSTER=false. The NPROCS variable sets the total number of Julia processes (including the main process); only the difference between NPROCS and the number of processes already running is added. It is recommended not to make NPROCS larger than the total number of CPU cores on your machine (given by the Julia global variable Sys.CPU_CORES).

  3. When running this notebook on a cluster, set RUNPARALLEL=true and RUNONCLUSTER=true. Set the xxx.xxx.xxx.xxx values to the private IP addresses of the slave machines you have started (add as many slaveip entries to machvec as required).
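
The process arithmetic used in the cell below can be sketched in isolation. The helper workers_to_add is a hypothetical name introduced here for illustration; it mirrors the NPROCS-nprocs() expression together with its nprocs() < NPROCS guard, under the assumption that NPROCS counts all processes including the main one.

```julia
# Number of processes still to be added so that the total reaches `target`.
# Returns 0 when enough processes are already running.
workers_to_add(target, current) = max(target - current, 0)

println("Fresh session, target 3:  ", workers_to_add(3, 1))
println("Already 4 processes:      ", workers_to_add(3, 4))
```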


In [ ]:
RUNPARALLEL = true
RUNONCLUSTER = false

if RUNPARALLEL
    println("Starting additional Julia processes")
    NPROCS = min(3,Sys.CPU_CORES) #do not make larger than CPU_CORES
    if nprocs() < NPROCS
        addprocs(NPROCS-nprocs(),topology=:master_slave)
    end
    println("Number of Julia processes: ",nprocs())

    if RUNONCLUSTER 
        println("Starting additional Julia processes on the cluster")
        slaveip1 = "ubuntu@xxx.xxx.xxx.xxx"
        slaveip2 = "ubuntu@xxx.xxx.xxx.xxx"
        machvec = [(slaveip1,:auto),(slaveip2,:auto)]
        addprocs(machvec,topology=:master_slave)
        println("Total number of Julia processes in cluster: ",nprocs())
    end
end

In [ ]:
import GeneralizedMetropolisHastings
import GMHExamples
    
###The following statement makes the GeneralizedMetropolisHastings core code available on all processes
@everywhere using GeneralizedMetropolisHastings
@everywhere using GMHExamples
println("GMH modules loaded successfully")

In [ ]:
#Standard M-H for nproposals == 1
#Generalized M-H for nproposals > 1
nproposals = 30

#MCMC iteration specifications
nburnin = 200
niterations = 1000
ntunerperiod = 40

#Time points to simulate the spring-mass ODE
timepoints = 0.0:0.1:10.0

###Initial conditions for the spring-mass ODE (position and velocity)
initialposition = -1.0 #in meters
initialvelocity = 1.0 #in meters/second

###Values of the model parameters (spring stiffness K and mass M)
K = 50.0 #in Newton/meter
M = 10.0 #in kg
#Lower and upper parameter bounds: 20% below and above the true values
lows = [K-K/5,M-M/5]
highs = [K+K/5,M+M/5]

###The variances of the normal noise on the measurement data (position and velocity)
variance = [0.01,0.09]

println("==========================================")
println("Simulation parameters defined successfully")
println("==========================================")
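
As a rough guide to what these parameters describe, the sketch below evaluates the closed-form solution of an undamped, unforced spring-mass system, M*x'' = -K*x. That the GMHExamples model has exactly this form is an assumption, and the function springmass is a hypothetical helper, not part of the package. Under this assumption the trajectory depends on K and M only through the ratio K/M, which is why the results further down are compared against K/M.

```julia
# Closed-form solution of M*x'' = -K*x (undamped and unforced: an assumption)
# Returns the tuple (position, velocity) at time t
function springmass(t, x0, v0, K, M)
    w = sqrt(K / M)                                  # angular frequency in rad/s
    x = x0 * cos(w * t) + (v0 / w) * sin(w * t)      # position
    v = -x0 * w * sin(w * t) + v0 * cos(w * t)       # velocity
    return (x, v)
end

ts = 0.0:0.1:10.0
traj = [springmass(t, -1.0, 1.0, 50.0, 10.0) for t in ts]
# Doubling both K and M leaves K/M, and therefore the trajectory, unchanged
traj_scaled = [springmass(t, -1.0, 1.0, 100.0, 20.0) for t in ts]

println("Initial state: ", traj[1])
println("Trajectories identical: ", traj == traj_scaled)
```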

In [ ]:
###Create a Spring-Mass model with measurement data and ODE function
m = springmassmodel(timepoints,[initialposition,initialvelocity],[K,M],variance,lows,highs)

###Show the model
println("==========================")
println("Model defined successfully")
println("==========================")
show(m)

In [ ]:
###Plot the measurement data (simulated data + noise)
PyPlot.figure("SpringMass-Measurement")
PyPlot.plot(dataindex(m),measurements(m)[:,1];label="location")
PyPlot.plot(dataindex(m),measurements(m)[:,2];label="velocity")
PyPlot.xlabel("Time")
PyPlot.ylabel("Amplitude")
PyPlot.title("Spring-Mass measurement data")
PyPlot.grid("on")
PyPlot.legend(loc="upper right",fancybox="true")

Now specify the sampler (a Metropolis-Hastings sampler with a Gaussian proposal density) and the runner that runs the Generalized Metropolis-Hastings algorithm. Remember that the choice between Standard and Generalized M-H is made by setting nproposals to 1 or to a value greater than 1.


In [ ]:
###Create a Metropolis sampler with a Normal proposal density
s = sampler(:mh,:normal,0.1,ones(2))
println("============================")
println("Sampler defined successfully")
println("============================")
show(s)

###Create a tuner that scales the proposal density
t = tuner(:scale,ntunerperiod,0.5,:erf)
println("==========================")
println("Tuner defined successfully")
println("==========================")
show(t)

###Create a Generalized Metropolis-Hastings runner (which will default to Standard MH when nproposals=1)
p = policy(:mh,nproposals;initialize=:prior)
r = runner(p,niterations,nproposals;numburnin = nburnin)
println("===========================")
println("Runner defined successfully")
println("===========================")
show(r)

Now run the simulation using the runner, the model and the sampler specified above. A printout will appear when the simulation is finished.


In [ ]:
###Run the MCMC (can take quite a bit of time)
println("=======================")
println("Run the MCMC simulation")
println("=======================")
@time c = run!(r,m,s,t)
println("=========================")
println("Completed MCMC simulation")
println("=========================")

In [ ]:
###Show the result of the simulations
show(c)

nparas = numparas(m)
meanparamvals = mean(samples(c),2)
stdparamvals = std(samples(c),2)

println("Results of the MCMC simulation:")
println(" mean K: ",meanparamvals[1])
println(" mean M: ",meanparamvals[2])
println(" mean K / mean M: ",meanparamvals[1]/meanparamvals[2])
println("Mean K/M should be close to $(K/M)")
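
The cell above reports the ratio of the two posterior means. A related but generally different summary is the mean of the per-sample ratios K/M. The sketch below contrasts the two on a small hypothetical sample matrix (toysamples is made-up illustrative data, laid out like the parameters-by-samples matrix indexed in this notebook, and is not output of the MCMC run). It uses only Base functions.

```julia
# Hypothetical 2 x n sample matrix: row 1 holds K draws, row 2 holds M draws
toysamples = [45.0 50.0 55.0;
              10.0 10.0 11.0]
n = size(toysamples, 2)

# Mean of the per-sample ratios K/M
ratios = [toysamples[1, j] / toysamples[2, j] for j in 1:n]
mean_of_ratios = sum(ratios) / n

# Ratio of the means, as printed in the cell above
ratio_of_means = (sum(toysamples[1, :]) / n) / (sum(toysamples[2, :]) / n)

println("Mean of ratios: ", mean_of_ratios)
println("Ratio of means: ", ratio_of_means)
```

The two numbers agree only approximately; when K and M are strongly correlated along the K/M line, the mean of the per-sample ratios is usually the more direct estimate of K/M.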

In [ ]:
###Plot the log-posterior values across samples
###After an initial few low values, this should remain relatively high
PyPlot.plot(1:numsamples(c),logposterior(c,m))
PyPlot.title("Log-Posterior values across samples")
PyPlot.xlabel("Samples")
PyPlot.ylabel("Log-Posterior")

In [ ]:
###Plot a scatter plot of K vs M values
###These should be spread around the K/M == 5.0 line (the diagonal in the figure)
ax3 = PyPlot.subplot(111)
ax3[:set_xlim]([lows[1],highs[1]])
ax3[:set_ylim]([lows[2],highs[2]])
PyPlot.scatter(vec(getindex(samples(c),1,:)),vec(getindex(samples(c),2,:)),marker=".",color="blue")
ax3[:set_aspect](abs(highs[1]-lows[1])/abs(highs[2]-lows[2]))
PyPlot.title("MCMC samples of Spring-Mass ODE parameters")
PyPlot.xlabel("Stiffness K (N/m)")
PyPlot.ylabel("Mass M (kg)")
PyPlot.grid("on")

In [ ]:
###Finally, plot the model output at the mean parameter values together with the measurement data
modeldata = evaluate!(m,vec(meanparamvals))
PyPlot.plot(dataindex(m),modeldata[:,1];label="model location")
PyPlot.plot(dataindex(m),modeldata[:,2];label="model velocity")
PyPlot.plot(dataindex(m),measurements(m)[:,1];label="location")
PyPlot.plot(dataindex(m),measurements(m)[:,2];label="velocity")
PyPlot.xlabel("Time")
PyPlot.ylabel("Amplitude")
PyPlot.title("Spring-Mass model output vs. measurement data")
PyPlot.grid("on")
PyPlot.legend(loc="lower right",fancybox="true")

In [ ]:
###Plot a histogram of K/M values, which should peak around the true ratio of K/M
kml,kmu = K/M-K/M/10.0,K/M+K/M/10.0
ax4 = PyPlot.subplot(111)
ax4[:set_xlim]([kml,kmu])
nbins = linspace(kml,kmu,100)
h = PyPlot.plt[:hist](vec(getindex(samples(c),1,:))./vec(getindex(samples(c),2,:)),nbins)
PyPlot.grid("on")
PyPlot.xlabel("K/M")
PyPlot.ylabel("Number of Samples")
PyPlot.title("Histogram of K/M values")

In [ ]:
println("Number of processes running: ",nprocs())
println("Number of workers running: ",nworkers())
println("Process IDs: ",procs())

In [ ]:
###Only run this box if you want to shut down all worker processes
println("Pre processes running: ",procs())
if nprocs() > 1
    rmprocs(workers())
    sleep(1.0)
end
println("Post processes running: ",procs())
