This example builds on the adaptive covariance MCMC example and shows how to plot the time series predicted by the sampled parameters.
Inference plots:
In [1]:
from __future__ import print_function

import matplotlib.pyplot as plt
import numpy as np

import pints
import pints.plot
import pints.toy as toy
# Seed NumPy's global generator so the noisy data below are identical on every
# Restart & Run All. NOTE(review): the MCMC run later in the notebook presumably
# also draws from np.random, so this should make the samples reproducible too — confirm.
np.random.seed(1)

# Load a forward model
model = toy.LogisticModel()

# Create some toy data: noise-free simulation at the "true" model parameters
real_parameters = [0.015, 500]
times = np.linspace(0, 1000, 100)
org_values = model.simulate(real_parameters, times)

# Add independent Gaussian noise with standard deviation `noise`
noise = 50
values = org_values + np.random.normal(0, noise, org_values.shape)
# Append the noise level so the reference vector has one entry per posterior
# parameter (the two model parameters plus the noise parameter added below)
real_parameters = np.array(real_parameters + [noise])

# Get properties of the noise sample (not used further in the cells shown here)
noise_sample_mean = np.mean(values - org_values)
noise_sample_std = np.std(values - org_values)

# Create an object with links to the model and time series
problem = pints.SingleOutputProblem(model, times, values)

# Create a log-likelihood function (adds an extra parameter: the noise level!)
log_likelihood = pints.GaussianLogLikelihood(problem)

# Create a uniform prior over both the model parameters and the new noise
# variable; each pair of bounds brackets the corresponding true value
log_prior = pints.UniformLogPrior(
    [0.01, 400, noise*0.1],
    [0.02, 600, noise*100]
)

# Create the (unnormalised) log-posterior, log(likelihood * prior)
log_posterior = pints.LogPosterior(log_likelihood, log_prior)

# Perform sampling using MCMC with a single chain, started 10% away from the
# true parameters so the sampler has to find its way to the posterior
x0 = real_parameters * 1.1
mcmc = pints.MCMCController(log_posterior, 1, [x0])
mcmc.set_max_iterations(4000)
mcmc.set_log_to_screen(False)
In [2]:
# Show the noisy observations (joined by a line, and again as individual
# markers) against the noise-free simulation they were generated from.
data_fig, data_ax = plt.subplots(figsize=(15, 7.5))
data_ax.plot(times, values, color='#1f77b4')
data_ax.plot(times, org_values, '--', color='#ff7f0e', lw=2, label='original values')
data_ax.plot(times, values, 'o', color='#7f7f7f', ms=6.5, label='data points')
data_ax.legend()
data_ax.set_xlabel('Time')
data_ax.set_ylabel('Value')
plt.show()
In [3]:
# Run MCMC
# Logging to screen was switched off when the controller was configured
# (set_log_to_screen(False)), so these two messages are the only progress output.
print('Running...')
# NOTE(review): per the PINTS docs, run() returns the samples indexed by chain
# first (a single chain here, as configured) — confirm for the installed version.
chains = mcmc.run()
print('Done!')
In [4]:
# Take the only chain (index 0) and drop its first 3000 samples as warm-up;
# with the 4000 iterations configured above, the last 1000 samples remain.
chain = chains[0][3000:]
In [5]:
# Plot the time series predicted by the post-warm-up samples in `chain`.
# `ref_parameters` asks pints.plot.series to also draw the prediction at the
# true parameters for comparison. (`pints.plot` is imported with the other
# imports at the top of the notebook.)
fig, ax = pints.plot.series(chain, problem, ref_parameters=real_parameters)

# Enlarge the figure for easier viewing
fig.set_size_inches(15, 7.5)
plt.show()