Goodwin's oscillator toy model

This example shows how the Goodwin's Oscillator toy model can be used.

Our version of this model has five parameters and three oscillating states as described in [1].

[1] Estimating Bayes factors via thermodynamic integration and population MCMC. Ben Calderhead and Mark Girolami, 2009, Computational Statistics and Data Analysis.


In [1]:
import pints
import pints.toy
import pints.plot
import matplotlib.pyplot as plt
import numpy as np
import time

model = pints.toy.GoodwinOscillatorModel()

We can get an example set of parameters using the suggested_parameters() method:


In [2]:
real_parameters = model.suggested_parameters()
print(real_parameters)


[ 2.    4.    0.12  0.08  0.1 ]

In the same way, we can get a suggested set of sampling times:


In [3]:
times = model.suggested_times()

Now we can run a simulation:


In [4]:
values = model.simulate(real_parameters, times)

This gives us all we need to create a plot of current versus time:


In [5]:
plt.figure()
plt.subplot(3, 1, 1)
plt.plot(times, values[:, 0], 'b')
plt.subplot(3, 1, 2)
plt.plot(times, values[:, 1], 'g')
plt.subplot(3, 1, 3)
plt.plot(times, values[:, 2], 'r')
plt.show()


Now we will add some noise to generate some fake "experimental" data and try to recover the original parameters.


In [6]:
noise1 = 0.001
noise2 = 0.01
noise3 = 0.1
noisy_values = np.array(values, copy=True)
noisy_values[:, 0] += np.random.normal(0, noise1, len(times))
noisy_values[:, 1] += np.random.normal(0, noise2, len(times))
noisy_values[:, 2] += np.random.normal(0, noise3, len(times))

plt.figure()
plt.subplot(3, 1, 1)
plt.plot(times, noisy_values[:, 0], 'b')
plt.subplot(3, 1, 2)
plt.plot(times, noisy_values[:, 1], 'g')
plt.subplot(3, 1, 3)
plt.plot(times, noisy_values[:, 2], 'r')
plt.show()


Now we can try and infer the original parameters:


In [7]:
# Create an object with links to the model and time series
problem = pints.MultiOutputProblem(model, times, values)

# Create a log posterior
log_prior = pints.UniformLogPrior([1, 1, 0.01, 0.01, 0.01], [10, 10, 1, 1, 1])
log_likelihood = pints.GaussianKnownSigmaLogLikelihood(problem, [noise1, noise2, noise3])
log_posterior = pints.LogPosterior(log_likelihood, log_prior)

# Run MCMC on the noisy data
x0 = [[5, 5, 0.5, 0.5, 0.5]]*3
mcmc = pints.MCMCController(log_posterior, 3, x0)
mcmc.set_max_iterations(5000)
mcmc.set_log_to_screen(False)

start = time.time()

print('Running')
chains = mcmc.run()
print('Done!')

end = time.time()
diff = end - start


Running
Done!

Print results.


In [8]:
results = pints.MCMCSummary(chains=chains, time=diff,
                            parameter_names=["k2", "k3", "m1", "m2", "m3"])
print(results)


param    mean    std.    2.5%    25%    50%    75%    97.5%    rhat    ess    ess per sec.
-------  ------  ------  ------  -----  -----  -----  -------  ------  -----  --------------
k2       3.74    2.13    1.34    1.86   3.38   5.84   7.99     1.43    7.44   0.11
k3       4.60    2.17    1.75    2.95   3.67   6.60   8.17     1.80    5.29   0.08
m1       0.23    0.19    0.07    0.09   0.12   0.39   0.61     1.11    11.40  0.16
m2       0.09    0.05    0.02    0.04   0.08   0.11   0.18     1.21    10.28  0.15
m3       0.12    0.06    0.05    0.08   0.09   0.19   0.22     1.52    6.59   0.10

Now we can inspect the resulting chains:


In [9]:
pints.plot.trace(chains, ref_parameters=real_parameters)
plt.show()


This is a pretty hard problem!

And what about optimisation?


In [10]:
# Fit to the noisy data
parameters = []
opt = pints.OptimisationController(log_posterior, x0[0], method=pints.XNES)
opt.set_log_to_screen(False)
parameters, fbest = opt.run()

print('')
print('            p1       p2       p3       p4       p5')
print('real  ' + ' '.join(['{: 8.4g}'.format(float(x)) for x in real_parameters]))
print('found ' + ' '.join(['{: 8.4g}'.format(x) for x in parameters]))


            p1       p2       p3       p4       p5
real         2        4     0.12     0.08      0.1
found    1.999    4.002     0.12  0.07997   0.1001

Sampling using relativistic HMC

The Goodwin-oscillator model has sensitivities calculated by the forward sensitivities approach, so we can use samplers that use gradients (although they will be slower per iteration; although perhaps not by ESS per second!), like Relativistic HMC.


In [11]:
problem = pints.MultiOutputProblem(model, times, noisy_values)

# Create a log-likelihood function (adds an extra parameter!)
log_likelihood = pints.GaussianLogLikelihood(problem)

# Create a uniform prior over both the parameters and the new noise variable
log_prior = pints.UniformLogPrior(
    [0, 0, 0, 0, 0, 0, 0, 0],
    [10, 10, 1, 1, 1, 1, 1, 1]
)

# Create a posterior log-likelihood (log(likelihood * prior))
log_posterior = pints.LogPosterior(log_likelihood, log_prior)

# Choose starting points for 3 mcmc chains
real_parameters1 = np.array(real_parameters.tolist() + [noise1, noise2, noise3])
xs = [
    real_parameters1 * 1.1,
    real_parameters1 * 0.9,
    real_parameters1 * 1.15,
    real_parameters1 * 1.2,
]

# Create mcmc routine
mcmc = pints.MCMCController(log_posterior, 4, xs, method=pints.RelativisticMCMC)

# Add stopping criterion
mcmc.set_max_iterations(200)

# Run in parallel
mcmc.set_parallel(True)
mcmc.set_log_interval(1)


for sampler in mcmc.samplers():
    sampler.set_leapfrog_step_size([0.1, 0.5, 0.002, 0.002, 0.002, 0.0005, 0.001, 0.01])
    sampler.set_leapfrog_steps(10)

# time start
start = time.time()

# Run!
print('Running...')
chains = mcmc.run()
print('Done!')

# end time
end = time.time()
diff = end - start


Running...
Using Relativistic MCMC
Generating 4 chains.
Running in parallel with 4 worker processess.
Iter. Eval. Accept.   Accept.   Accept.   Accept.   Time m:s
0     4      0         0         0         0          0:00.1
1     44     0.333     0.333     0.333     0.333      0:00.7
2     84     0.5       0.5       0.25      0.5        0:01.3
3     124    0.6       0.6       0.4       0.6        0:01.8
4     164    0.5       0.667     0.5       0.5        0:02.4
5     204    0.429     0.571     0.429     0.429      0:02.9
6     244    0.5       0.5       0.5       0.375      0:03.5
7     284    0.556     0.444     0.556     0.444      0:04.0
8     324    0.6       0.4       0.5       0.4        0:04.5
9     364    0.636     0.455     0.545     0.455      0:05.1
10    404    0.583     0.5       0.583     0.5        0:05.6
11    444    0.538     0.538     0.615     0.538      0:06.2
12    484    0.571     0.571     0.571     0.571      0:06.7
13    524    0.533     0.6       0.6       0.533      0:07.2
14    564    0.5625    0.625     0.625     0.5625     0:07.8
15    604    0.588     0.647     0.647     0.529      0:08.3
16    644    0.556     0.611     0.667     0.556      0:08.8
17    684    0.526     0.632     0.632     0.526      0:09.4
18    724    0.55      0.6       0.6       0.5        0:09.9
19    764    0.524     0.571     0.571     0.476      0:10.5
20    804    0.5       0.545     0.591     0.5        0:11.0
21    844    0.478     0.565     0.565     0.522      0:11.6
22    884    0.458     0.542     0.583     0.542      0:12.1
23    924    0.48      0.56      0.6       0.52       0:12.6
24    964    0.5       0.577     0.615     0.5        0:13.2
25    1004   0.519     0.593     0.63      0.481      0:13.7
26    1044   0.536     0.571     0.643     0.5        0:14.3
27    1084   0.552     0.552     0.621     0.517      0:14.8
28    1124   0.567     0.567     0.6       0.5        0:15.3
29    1164   0.581     0.548     0.581     0.516129   0:15.9
30    1204   0.59375   0.5625    0.5625    0.53125    0:16.4
31    1244   0.606     0.576     0.576     0.515      0:17.0
32    1284   0.618     0.588     0.588     0.529      0:17.5
33    1324   0.629     0.6       0.571     0.543      0:18.1
34    1364   0.639     0.611     0.583     0.556      0:18.6
35    1404   0.649     0.622     0.595     0.568      0:19.1
36    1444   0.658     0.632     0.579     0.553      0:19.7
37    1484   0.667     0.615     0.59      0.538      0:20.2
38    1524   0.675     0.625     0.6       0.55       0:20.7
39    1564   0.683     0.634     0.61      0.537      0:21.3
40    1604   0.69      0.643     0.619     0.547619   0:21.8
41    1644   0.674     0.651     0.605     0.535      0:22.4
42    1684   0.682     0.659     0.614     0.545      0:22.9
43    1724   0.689     0.667     0.622     0.533      0:23.5
44    1764   0.696     0.673913  0.609     0.522      0:24.0
45    1804   0.702     0.681     0.617     0.532      0:24.5
46    1844   0.6875    0.6875    0.604     0.521      0:25.1
47    1884   0.694     0.694     0.592     0.51       0:25.6
48    1924   0.7       0.7       0.58      0.52       0:26.2
49    1964   0.686     0.686     0.569     0.529      0:26.7
50    2004   0.692     0.692     0.558     0.519      0:27.4
51    2044   0.679     0.698     0.547     0.528      0:27.9
52    2084   0.685     0.704     0.556     0.537037   0:28.5
53    2124   0.673     0.691     0.564     0.545      0:29.0
54    2164   0.679     0.696     0.554     0.554      0:29.6
55    2204   0.684     0.702     0.544     0.561      0:30.1
56    2244   0.69      0.707     0.552     0.552      0:30.6
57    2284   0.678     0.712     0.559322  0.559322   0:31.2
58    2324   0.683     0.717     0.55      0.567      0:31.7
59    2364   0.672     0.721     0.557377  0.574      0:32.3
60    2404   0.677     0.71      0.565     0.581      0:32.8
61    2444   0.683     0.698     0.571     0.587      0:33.3
62    2484   0.6875    0.6875    0.578125  0.59375    0:33.9
63    2524   0.677     0.692     0.569     0.585      0:34.4
64    2564   0.682     0.682     0.561     0.591      0:34.9
65    2604   0.672     0.672     0.567     0.597      0:35.5
66    2644   0.676     0.662     0.559     0.603      0:36.0
67    2684   0.681     0.652     0.551     0.594      0:36.5
68    2724   0.686     0.657     0.557     0.6        0:37.1
69    2764   0.676     0.662     0.563     0.606      0:37.6
70    2804   0.667     0.653     0.569     0.611      0:38.1
71    2844   0.671     0.658     0.562     0.603      0:38.7
72    2884   0.662     0.662     0.554     0.608      0:39.2
73    2924   0.667     0.667     0.547     0.613      0:39.7
74    2964   0.658     0.671     0.553     0.618      0:40.3
75    3004   0.649     0.662     0.558     0.623      0:40.8
76    3044   0.654     0.667     0.564     0.628      0:41.3
77    3084   0.658     0.671     0.556962  0.633      0:41.9
78    3124   0.6625    0.6625    0.55      0.625      0:42.4
79    3164   0.654321  0.667     0.556     0.617284   0:42.9
80    3204   0.646     0.659     0.549     0.622      0:43.5
81    3244   0.639     0.663     0.554     0.614      0:44.0
82    3284   0.643     0.655     0.56      0.607      0:44.5
83    3324   0.647     0.647     0.565     0.6        0:45.1
84    3364   0.651     0.64      0.558     0.593      0:45.6
85    3404   0.655     0.644     0.563     0.598      0:46.1
86    3444   0.659     0.648     0.568     0.602      0:46.6
87    3484   0.663     0.64      0.562     0.596      0:47.2
88    3524   0.667     0.633     0.556     0.589      0:47.7
89    3564   0.67      0.637     0.56      0.593      0:48.2
90    3604   0.663     0.63      0.565     0.598      0:48.8
91    3644   0.655914  0.634     0.57      0.591      0:49.3
92    3684   0.649     0.638     0.574     0.596      0:49.8
93    3724   0.653     0.642     0.579     0.6        0:50.4
94    3764   0.65625   0.646     0.573     0.59375    0:50.9
95    3804   0.649     0.649     0.577     0.588      0:51.4
96    3844   0.653     0.653     0.582     0.592      0:52.0
97    3884   0.657     0.657     0.586     0.596      0:52.5
98    3924   0.66      0.66      0.59      0.59       0:53.0
99    3964   0.663     0.663     0.594     0.584      0:53.5
100   4004   0.657     0.667     0.598     0.588      0:54.2
101   4044   0.65      0.67      0.592233  0.583      0:54.7
102   4084   0.654     0.673     0.587     0.587      0:55.3
103   4124   0.657     0.667     0.59      0.59       0:55.8
104   4164   0.66      0.67      0.594     0.594      0:56.3
105   4204   0.664     0.673     0.588785  0.598      0:56.9
106   4244   0.657     0.667     0.583     0.602      0:57.4
107   4284   0.651     0.67      0.587156  0.606      0:57.9
108   4324   0.655     0.673     0.591     0.6        0:58.5
109   4364   0.649     0.676     0.595     0.604      0:59.0
110   4404   0.652     0.67      0.598     0.607      0:59.5
111   4444   0.646     0.664     0.602     0.611      1:00.1
112   4484   0.649     0.658     0.596     0.614      1:00.6
113   4524   0.643     0.661     0.6       0.617      1:01.1
114   4564   0.647     0.664     0.595     0.612069   1:01.7
115   4604   0.65      0.667     0.598     0.615      1:02.2
116   4644   0.644     0.669     0.602     0.619      1:02.7
117   4684   0.639     0.664     0.605042  0.613      1:03.3
118   4724   0.633     0.667     0.608     0.608      1:03.8
119   4764   0.636     0.661157  0.603     0.612      1:04.3
120   4804   0.639     0.664     0.607     0.615      1:04.8
121   4844   0.642     0.667     0.61      0.618      1:05.4
122   4884   0.645     0.669     0.613     0.621      1:05.9
123   4924   0.648     0.672     0.608     0.624      1:06.4
124   4964   0.651     0.675     0.611     0.619      1:07.0
125   5004   0.654     0.677     0.614     0.614      1:07.5
126   5044   0.65625   0.68      0.617     0.609375   1:08.0
127   5084   0.659     0.682     0.620155  0.612      1:08.6
128   5124   0.654     0.685     0.615     0.615      1:09.1
129   5164   0.656     0.687     0.618     0.618      1:09.6
130   5204   0.659     0.689     0.614     0.621      1:10.2
131   5244   0.654     0.684     0.609     0.617      1:10.7
132   5284   0.657     0.687     0.612     0.619403   1:11.2
133   5324   0.652     0.689     0.615     0.622      1:11.8
134   5364   0.654     0.684     0.618     0.618      1:12.3
135   5404   0.649635  0.686     0.620438  0.613      1:12.8
136   5444   0.652     0.688     0.623     0.609      1:13.4
137   5484   0.655     0.691     0.618705  0.604      1:13.9
138   5524   0.65      0.693     0.621     0.607      1:14.5
139   5564   0.652     0.695     0.624     0.603      1:15.0
140   5604   0.648     0.697     0.627     0.599      1:15.5
141   5644   0.65      0.692     0.629     0.594      1:16.0
142   5684   0.653     0.694     0.632     0.59       1:16.6
143   5724   0.648     0.69      0.634     0.586      1:17.1
144   5764   0.651     0.692     0.637     0.582      1:17.6
145   5804   0.653     0.687     0.639     0.578      1:18.2
146   5844   0.655     0.689     0.642     0.574      1:18.7
147   5884   0.651     0.691     0.638     0.577      1:19.2
148   5924   0.647     0.687     0.633     0.58       1:19.8
149   5964   0.642     0.682     0.636     0.576      1:20.3
150   6004   0.645     0.684     0.638     0.579      1:21.0
151   6044   0.647     0.686     0.634     0.582      1:21.5
152   6084   0.649     0.688     0.636     0.584      1:22.1
153   6124   0.645     0.69      0.639     0.581      1:22.6
154   6164   0.647     0.692     0.641     0.583      1:23.1
155   6204   0.643     0.688     0.637     0.586      1:23.7
156   6244   0.646     0.69      0.639     0.589      1:24.2
157   6284   0.648     0.686     0.635     0.585      1:24.7
158   6324   0.64375   0.6875    0.6375    0.5875     1:25.3
159   6364   0.64      0.689441  0.64      0.59       1:25.8
160   6404   0.636     0.685     0.642     0.593      1:26.3
161   6444   0.638     0.681     0.644     0.589      1:26.9
162   6484   0.64      0.677     0.646     0.591      1:27.4
163   6524   0.642     0.679     0.648     0.588      1:27.9
164   6564   0.639     0.681     0.645     0.59       1:28.5
165   6604   0.641     0.677     0.647     0.593      1:29.0
166   6644   0.643     0.672619  0.643     0.595      1:29.5
167   6684   0.645     0.675     0.645     0.591716   1:30.1
168   6724   0.647     0.676     0.641     0.594      1:30.6
169   6764   0.649     0.678     0.637     0.596      1:31.1
170   6804   0.651     0.68      0.64      0.599      1:31.6
171   6844   0.647     0.682     0.642     0.601      1:32.2
172   6884   0.649     0.683908  0.644     0.598      1:32.7
173   6924   0.646     0.686     0.646     0.6        1:33.2
174   6964   0.642     0.682     0.648     0.602      1:33.8
175   7004   0.638     0.684     0.65      0.599      1:34.3
176   7044   0.635     0.685     0.646     0.596      1:34.9
177   7084   0.637     0.687     0.648     0.598      1:35.4
178   7124   0.639     0.683     0.65      0.6        1:35.9
179   7164   0.635     0.685     0.652     0.597      1:36.5
180   7204   0.637     0.687     0.654     0.593      1:37.0
181   7244   0.639     0.689     0.65      0.59       1:37.6
182   7284   0.636     0.685     0.652     0.592      1:38.1
183   7324   0.632     0.681     0.649     0.589      1:38.6
184   7364   0.629     0.683     0.651     0.591      1:39.2
185   7404   0.631016  0.684492  0.647     0.588      1:39.7
186   7444   0.633     0.686     0.649     0.59       1:40.2
187   7484   0.635     0.688     0.646     0.593      1:40.8
188   7524   0.637     0.689     0.647     0.589      1:41.3
189   7564   0.634     0.691     0.644     0.586      1:41.8
190   7604   0.635     0.6875    0.646     0.583      1:42.4
191   7644   0.632     0.689     0.648     0.58       1:42.9
192   7684   0.634     0.691     0.649     0.577      1:43.4
193   7724   0.636     0.687     0.646     0.574359   1:44.0
194   7764   0.638     0.689     0.643     0.571      1:44.5
195   7804   0.635     0.69      0.64      0.569      1:45.0
196   7844   0.636     0.687     0.636     0.571      1:45.6
197   7884   0.638191  0.683     0.638191  0.573      1:46.1
198   7924   0.635     0.685     0.635     0.575      1:46.6
199   7964   0.637     0.687     0.637     0.577      1:47.2
200   7964   0.637     0.687     0.637     0.577      1:47.2
Halting: Maximum number of iterations (200) reached.
Done!

Print results.


In [12]:
results = pints.MCMCSummary(chains=chains, time=diff,
                            parameter_names=["k2", "k3", "m1", "m2", "m3",
                                             "sigma_x", "sigma_y", "sigma_z"])
print(results)


param    mean    std.    2.5%    25%    50%    75%    97.5%    rhat    ess     ess per sec.
-------  ------  ------  ------  -----  -----  -----  -------  ------  ------  --------------
k2       2.03    0.09    1.89    1.97   2.02   2.08   2.21     1.09    42.50   0.39
k3       3.98    0.19    3.65    3.85   3.97   4.08   4.39     1.07    26.19   0.24
m1       0.12    0.00    0.11    0.12   0.12   0.12   0.13     1.08    41.92   0.39
m2       0.08    0.00    0.08    0.08   0.08   0.08   0.09     1.06    27.34   0.25
m3       0.10    0.00    0.09    0.10   0.10   0.10   0.11     1.08    23.06   0.21
sigma_x  0.00    0.00    0.00    0.00   0.00   0.00   0.00     1.00    127.58  1.19
sigma_y  0.01    0.00    0.01    0.01   0.01   0.01   0.01     1.00    117.00  1.09
sigma_z  0.10    0.01    0.09    0.10   0.10   0.10   0.13     1.01    83.57   0.78

In [13]:
pints.plot.trace(chains, ref_parameters=real_parameters1)
plt.show()


Plot posterior predictive distribution.


In [14]:
pints.plot.series(np.vstack(chains), problem)
plt.show()