Probability density functions in Pints can be defined via models and problems, but they can also be defined directly.
In this example, we implement the Rosenbrock function and run an optimisation using it.
The Rosenbrock function is a two-dimensional function defined as
f(x,y) = (a - x)^2 + b(y - x^2)^2
where a and b are constants and x and y are variables. In analogy with typical Pints models, x and y are our parameters.
First, take a look at the LogPDF interface. It tells us two things:
1. A LogPDF must have a method n_parameters that tells Pints the dimension of the parameter space.
2. A LogPDF must be callable, i.e. it must have a method __call__.
The input to this method should be a vector, so we should rewrite f as
f(p) = (a - p[0])^2 + b(p[1] - p[0]^2)^2
The result of calling this method should be the logarithm of a probability density (which need not be normalised). That means we should (1) take the logarithm of f instead of returning it directly, and (2) negate the result, so that the minimum of f becomes a clearly defined maximum that we can search for.
So we should create an object that evaluates
-log(f(p))
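Before wrapping this in a Pints class, the expression can be tried out in plain Python. The sketch below is only an illustration: the helper names rosenbrock and neg_log_rosenbrock are hypothetical and not part of Pints, and f is taken as the non-negative sum (a - p[0])^2 + b(p[1] - p[0]^2)^2, which is the form the logarithm requires. Because math.log raises an error at zero, the point where f vanishes is handled explicitly.

```python
import math


def rosenbrock(p, a=1, b=100):
    """Rosenbrock function f(p) for a parameter vector p = [x, y]."""
    return (a - p[0])**2 + b * (p[1] - p[0]**2)**2


def neg_log_rosenbrock(p, a=1, b=100):
    """Score -log(f(p)); infinite where f(p) is exactly zero."""
    f = rosenbrock(p, a, b)
    return math.inf if f == 0 else -math.log(f)


# The score grows as p approaches the minimum of f at [1, 1]
for p in ([0, 0], [0.5, 0.25], [0.9, 0.81], [1, 1]):
    print(p, rosenbrock(p), neg_log_rosenbrock(p))
```

The sample points lie on the curve y = x^2, so only the (a - x)^2 term contributes and the score rises steadily towards [1, 1].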
We now have all we need to implement a Rosenbrock class:
In [1]:
import numpy as np
import pints


class Rosenbrock(pints.LogPDF):
    def __init__(self, a=1, b=100):
        self._a = a
        self._b = b

    def __call__(self, x):
        # Negative log of the Rosenbrock function at the parameter vector x
        return -np.log((self._a - x[0])**2 + self._b * (x[1] - x[0]**2)**2)

    def n_parameters(self):
        # The parameter space is two-dimensional
        return 2
We can test our class by creating an object and calling it with a few parameters:
In [2]:
r = Rosenbrock()
print(r([0, 0]))
print(r([0.1, 0.1]))
print(r([0.4, 0.2]))
Wikipedia tells us that for a = 1 and b = 100 the minimum of the Rosenbrock function is at [1, 1]. We can test this by inspecting the value of our LogPDF at that point:
In [3]:
r([1, 1])
Out[3]:
NumPy emits a warning here, because the Rosenbrock function is zero at this point and the code therefore evaluates log(0). The method still returns the correct value: inf.
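If the message is distracting, it can be silenced locally with NumPy's np.errstate context manager. The following is a minimal sketch using a standalone helper (quiet_neg_log_rosenbrock is a hypothetical name, not part of Pints); the same with-block could be placed inside the __call__ method of the class.

```python
import numpy as np


def quiet_neg_log_rosenbrock(x, a=1, b=100):
    """Evaluate -log(f(x)) without a divide-by-zero warning where f(x) is 0."""
    f = (a - x[0])**2 + b * (x[1] - x[0]**2)**2
    # Only the "divide" category is ignored; other floating-point
    # warnings are still reported as usual.
    with np.errstate(divide='ignore'):
        return -np.log(f)


print(quiet_neg_log_rosenbrock([1, 1]))
```

Restricting the change to a with-block keeps NumPy's warnings active for the rest of the session.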
Now let's try an optimisation:
In [4]:
# Define some boundaries
boundaries = pints.RectangularBoundaries([-5, -5], [5, 5])
# Pick an initial point
x0 = [2, 2]
# And run!
xbest, fbest = pints.optimise(r, x0, boundaries=boundaries)
Finally, print the returned point. If it worked, we should be at [1, 1]:
In [5]:
print(xbest)
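As a cross-check that does not depend on the optimiser, the same search region can be scanned on a coarse grid. This is only a sketch under stated assumptions: the helper names score and grid_search and the step size of 0.5 are illustrative choices, not part of Pints, and the code uses plain Python so that it runs without any extra packages.

```python
import math


def score(p, a=1, b=100):
    """Return -log(f(p)) for the Rosenbrock function f, or inf where f is 0."""
    f = (a - p[0])**2 + b * (p[1] - p[0]**2)**2
    return math.inf if f == 0 else -math.log(f)


def grid_search(lower=-5.0, upper=5.0, step=0.5):
    """Return the grid point in [lower, upper]^2 with the highest score."""
    n = int(round((upper - lower) / step)) + 1
    points = [lower + i * step for i in range(n)]
    return max(((x, y) for x in points for y in points), key=score)


best = grid_search()
print(best)
```

The grid happens to contain the point [1, 1] exactly, so the scan can land on it; with a step size that skips this point it would only return the nearest best grid point, which is why an optimiser is the better tool in general.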