In [1]:
import numpy as np
import matplotlib
%matplotlib inline
import matplotlib.pyplot as plt

from matplotlib import animation
from matplotlib import rc
rc('animation', html='html5')

1. Test function

  • Rosenbrock function
$$ f(x) = 100(x_2 - x_1^2)^2 + (1-x_1)^2 $$
  • Optimal: $x^* = (1, 1)^T$

In [2]:
def rsb(x):
    x = np.asarray(x, np.float32)
    return 100 * (x[..., 1] - x[..., 0] ** 2) ** 2 + (1 - x[..., 0]) ** 2

Heapmap


In [3]:
xx, yy = np.linspace(-10, 10, 200), np.linspace(-30, 70, 200)
xx, yy = np.meshgrid(xx, yy)
fx = rsb(np.stack((xx, yy), axis=-1))

fig, ax = plt.subplots(figsize=(8, 6))
im = ax.pcolormesh(xx, yy, fx ** 0.3, cmap="autumn")
fig.colorbar(im)
plt.scatter([1], [1])


Out[3]:
<matplotlib.collections.PathCollection at 0x22632f870f0>

2. Compute Gradient and Hessian

  • Gradient
$$ \nabla f(x) = \left(\begin{array}{c} -400x_1(x_2-x_1^2) - 2(1 - x_1) \\ 200(x_2-x_1^2) \end{array} \right) $$
  • Hessian
$$ \nabla^2f(x) = \left(\begin{array}{cc} -400x_2 + 1200x_1^2 + 2 & -400x_1 \\ -400x_1 & 200 \end{array} \right) $$

In [4]:
def rsb_g(x):
    """ Gradient """
    x = np.asarray(x, np.float32)
    g1 = -400 * x[..., 0] * (x[..., 1] - x[..., 0] ** 2)- 2 * (1 - x[..., 0])
    g2 = 200 * (x[..., 1] - x[..., 0] ** 2)
    return np.stack((g1, g2), axis=-1)
    
def rsb_h(x):
    """ Hessian """
    x = np.asarray(x, np.float32)
    h11 = -400 * x[..., 1] + 1200 * x[..., 0] ** 2 + 2
    h12 = h21 = -400 * x[..., 0]
    h22 = 200 * np.ones_like(x[..., 0])
    h1 = np.stack((h11, h21), axis=-1)
    h2 = np.stack((h12, h22), axis=-1)
    h = np.stack((h1, h2), axis=-1)
    return h

def rsb_q(xk, s):
    """ Lagrange """
    return rsb(xk) + np.sum(rsb_g(xk) * s) + 0.5 * np.sum(np.sum(s[:, None] * rsb_h(xk), axis=0) * s)

3. Create Itearation Step


In [5]:
def step(xk, uk, tol=1e-10, verbose=False):
    Gk = rsb_h(xk)
    counter = 0
    while True:
        if counter > 10:
            raise ValueError("Error")
        summ = Gk + uk * np.eye(*Gk.shape)
        if np.all(np.linalg.eigvals(summ) > 0):
            break
        uk *= 4
        counter += 1
    sk = np.linalg.solve(summ, -rsb_g(xk))
    
    fk = rsb(xk)
    fkpsk = rsb(xk + sk)
    delta_f = fk - fkpsk
    if abs(delta_f) < tol:
        # We set a tolerance to avoid 0/0
        if verbose:
            print("Finish with delta_f =", delta_f)
        return xk, uk, delta_f
    delta_q = fk - rsb_q(xk, sk)
    rk = delta_f / delta_q
    if rk < 0.25:
        ukp1 = 4 * uk
    elif rk > 0.75:
        ukp1 = uk / 2.
    else:
        ukp1 = uk
    
    if rk <= 0:
        xkp1 = xk
    else:
        xkp1 = xk + sk
    return xkp1, ukp1, delta_f

3.1 Simulation


In [6]:
k, xk, uk, tol = 0, [5, 50], 1, 1e-10

In [7]:
while True:
    xk, uk, df = step(xk, uk, tol, verbose=True)
    k += 1
    if abs(df) < tol:
        break
print("Total steps:", k)
print("Found solution:", xk)


Finish with delta_f = 2.3305801732931286e-12
Total steps: 63
Found solution: [1.000001   1.00000205]

4. Create Animation


In [8]:
class RSBViewer(object):
    def __init__(self, step_fn, x0, u0, tol=1e-10, sln=None):
        self.step_fn = step_fn
        self.x0 = self.xk = x0
        self.u0 = self.uk = u0
        self.tol = tol
        self.sln = sln
        self.all_pts = [self.xk]
        
    def step(self):
        self.xk, self.uk, self.df = self.step_fn(self.xk, self.uk, self.tol)
        self.all_pts.append(self.xk)
        
    def draw(self, *args):
        """Draws the array and any other elements.

        contour: boolean, whether to draw contours
        """
        self.hm = plt.pcolormesh(*args, cmap="autumn")
        if self.sln is not None:
            plt.scatter(self.sln[:1], self.sln[1:], c=["yellow"])
        self.draw_points()

    def draw_points(self, color=None):
        self.path_collection = plt.scatter(*np.array(self.all_pts).T, c="blue")
        
    def init_func(self):
        pass
        
    def animate(self, *args, frames=20, interval=200, contour=False):
        """Creates an animation.

        frames: number of frames to draw
        interval: time between frames in ms
        """
        fig = plt.figure(figsize=(10, 8))
        plt.axis("off")
        self.draw(*args)
        anim = animation.FuncAnimation(fig, self.animate_func,
                                       init_func=self.init_func,
                                       frames=frames, interval=interval)
        return anim
    
    def animate_func(self, i):
        if i > 0:
            self.step()
        self.path_collection.set_offsets(np.array(self.all_pts))
        return self.path_collection

In [9]:
viewer = RSBViewer(step, [5, 50], 1, sln=[1, 1])
anim = viewer.animate(xx, yy, fx ** 0.3, frames=70, interval=100)



In [10]:
anim


Out[10]:

In [ ]: