In [1]:
from scipy.stats import norm as normdist

from neatplots.predefined import four as palette
from neatplots.tools import SharedAxesGrid, Broadcast
from matplotlib.patches import Rectangle
from matplotlib.gridspec import GridSpec, SubplotSpec

from plume.behaviors import DUCB, PDUCB, GO
from plume.prediction import OnlineGP, ExponentialKernel, Matern32Kernel, RBFKernel

%pylab inline
%load_ext autoreload
%autoreload 2

# Neutral dark-grey RGB colour; used for the dashed target-function curve in the acquisition plots.
gray = (0.3, 0.3, 0.3)


Populating the interactive namespace from numpy and matplotlib

In [2]:
import latexstyle
# Project-local style helper; presumably applies LaTeX-matched matplotlib settings
# globally, so it must run before any figure is created -- TODO confirm in latexstyle.
latexstyle.setup()
from latexstyle import ltx_to_mpl_pt_factor

In [3]:
def target_fn(x):
    """Bimodal test objective: twice the sum of two Gaussian densities.

    A broad peak centred at -3 (scale 2) and a narrow peak centred at 4
    (scale 1). Evaluated elementwise, so ``x`` may be a scalar or an array.
    """
    broad_peak = normdist.pdf(x, loc=-3, scale=2)
    narrow_peak = normdist.pdf(x, loc=4)
    return 2 * (broad_peak + narrow_peak)

# Quick visual sanity check of the target function on a dense grid.
x = np.linspace(start=-12, stop=12, num=200)
plot(x, target_fn(x))


Out[3]:
[<matplotlib.lines.Line2D at 0x1085faa10>]

In [4]:
class NoisyStateMock(object):
    """Stand-in for the state objects handed to acquisition functions.

    Only a ``position`` attribute is provided: a one-element list holding the
    scalar coordinate ``x``.
    """

    def __init__(self, x):
        coordinate = x
        self.position = [coordinate]

class AcquisitionIterationPlotter(object):
    """Run an acquisition function for a few greedy steps and plot the outcome.

    Each call to ``plot_acq_iter`` builds a fresh ``OnlineGP``. It samples
    ``target_fn`` first at ``x0``, then repeatedly at the maximiser of the
    acquisition function on ``plot_x``. It then draws the target, the GP
    mean +/- one standard deviation, the observations and the last normalised
    utility into the given axes.

    Parameters
    ----------
    gp_args : tuple
        Positional arguments forwarded to ``OnlineGP``.
    acq_fn_type : callable
        Acquisition-function class, instantiated as ``acq_fn_type(gp, *acq_fn_args)``.
    acq_fn_args : tuple
        Extra positional arguments for ``acq_fn_type``.
    """

    # Legend labels, shared by plot_acq_iter and the legend re-ordering so the
    # two cannot drift apart.
    _TARGET_LABEL = 'Target function'
    _OBS_LABEL = r'$X$'
    _MEAN_LABEL = r'$\mu(x)$'
    _SD_LABEL = r'$\pm \sigma(x)$'
    _UTILITY_LABEL = 'Normalized utility function'
    # Order of entries in the shared legend.
    _LEGEND_ORDER = (_TARGET_LABEL, _OBS_LABEL, _MEAN_LABEL, _SD_LABEL, _UTILITY_LABEL)

    def __init__(self, gp_args, acq_fn_type, acq_fn_args):
        self.gp_args = gp_args
        self.acq_fn_type = acq_fn_type
        self.acq_fn_args = acq_fn_args

        self.plot_x = np.linspace(-15, 15, 200)
        self.target_fn = target_fn
        self.x0 = -7
        # Axes of the most recent plot_acq_iter call; source of legend entries.
        self._ax = None

    @staticmethod
    def _normalize_utility(utility):
        """Return ``utility`` shifted to zero mean and scaled so max |value| is 1.

        A new array is returned and the input is left untouched. A constant
        utility (zero spread after centring) is returned as all zeros instead
        of dividing by zero.
        """
        centred = np.asarray(utility, dtype=float) - np.mean(utility)
        scale = np.max(np.abs(centred))
        if scale > 0:
            centred = centred / scale
        return centred

    @staticmethod
    def _sd_fill_color():
        """RGBA colour of the +/- sigma band: palette colour at 20 % opacity."""
        return tuple(palette.thin[1]) + (0.2,)

    @classmethod
    def _order_legend_entries(cls, handles, labels, sd_handle):
        """Return ``(handles, labels)`` arranged in ``_LEGEND_ORDER``.

        ``sd_handle`` is inserted under ``_SD_LABEL`` because the fill_between
        band is drawn without a label. Labels not in ``_LEGEND_ORDER`` keep
        their relative order and are appended at the end.
        """
        by_label = {}
        extra_handles, extra_labels = [], []
        for handle, label in zip(handles, labels):
            if label in cls._LEGEND_ORDER:
                by_label[label] = handle
            else:
                extra_handles.append(handle)
                extra_labels.append(label)
        by_label[cls._SD_LABEL] = sd_handle

        ordered_labels = [label for label in cls._LEGEND_ORDER if label in by_label]
        ordered_handles = [by_label[label] for label in ordered_labels]
        return ordered_handles + extra_handles, ordered_labels + extra_labels

    def plot_acq_iter(self, ax, iteration):
        """Take ``iteration`` greedy acquisition steps from ``x0`` and plot into ``ax``.

        Parameters
        ----------
        ax : matplotlib Axes
            Target axes; remembered for ``get_legend_handles_labels``.
        iteration : int
            Number of observations to take; must be at least 1.

        Returns
        -------
        tuple
            ``(acq_fn, last_x)``: the acquisition-function instance and the x
            of the final observation, i.e. the position given to the last
            utility evaluation.

        Raises
        ------
        ValueError
            If ``iteration < 1``; there would be no utility to plot.
        """
        if iteration < 1:
            raise ValueError(
                'iteration must be at least 1, got {}'.format(iteration))
        self._ax = ax

        grid = np.atleast_2d(self.plot_x).T  # candidate points as a column
        gp = OnlineGP(*self.gp_args, expected_samples=iteration)
        acq_fn = self.acq_fn_type(gp, *self.acq_fn_args)

        xs = []
        ys = []
        x = self.x0
        for _ in range(iteration):
            y = self.target_fn(x)
            xs.append(x)
            ys.append(y)

            gp.add_observations(np.atleast_2d(x), np.atleast_2d(y))
            utility = self._normalize_utility(acq_fn(grid, [NoisyStateMock(x)]))
            last_x = x
            # Greedy step: the next sample is the utility maximiser on the grid.
            x = self.plot_x[np.argmax(utility)]

        sd_color = self._sd_fill_color()
        line, = ax.plot(
            self.plot_x, self.target_fn(self.plot_x), '--', c=gray,
            label=self._TARGET_LABEL)
        line.set_dashes((3, 3))
        mean, mse = gp.predict(grid, eval_MSE=True)
        mean = np.squeeze(mean)
        sd = np.squeeze(np.sqrt(mse))
        ax.fill_between(
            self.plot_x, mean - sd, mean + sd, color=sd_color, edgecolor='none')
        ax.plot(self.plot_x, mean, color=palette.thin[1], label=self._MEAN_LABEL)
        ax.scatter(xs, ys, marker='+', s=24, color=palette.thin[1], label=self._OBS_LABEL)
        ax.plot(
            self.plot_x, utility, c=palette.highlight[3],
            linewidth=0.8 * ltx_to_mpl_pt_factor, label=self._UTILITY_LABEL)
        ax.set_title('{} Iteration {}'.format(self.acq_fn_type.__name__, iteration))
        return acq_fn, last_x

    def get_legend_handles_labels(self):
        """Return ``(handles, labels)`` for a shared legend in a fixed order.

        Entries come from the axes of the most recent ``plot_acq_iter`` call,
        plus a proxy ``Rectangle`` for the +/- sigma band. Order: target,
        observations, mean, +/- sigma, utility. Entries are matched by label
        rather than list position, so the result does not depend on how
        matplotlib orders lines versus collections.

        Raises
        ------
        RuntimeError
            If called before ``plot_acq_iter``.
        """
        ax = getattr(self, '_ax', None)
        if ax is None:
            raise RuntimeError(
                'plot_acq_iter must be called before get_legend_handles_labels')
        handles, labels = ax.get_legend_handles_labels()
        sd_handle = Rectangle((0, 0), 1, 1, fc=self._sd_fill_color(), ec='none')
        return self._order_legend_entries(handles, labels, sd_handle)

In [5]:
# Figure that hosts the acquisition-function comparison grid.
fig = plt.figure()


class AcqFnPlot(SharedAxesGrid):
    """SharedAxesGrid that draws one acquisition-function snapshot per cell.

    ``_plot`` receives one iteration count and one AcquisitionIterationPlotter
    per cell. How the two constructor sequences map to rows and columns is
    defined by neatplots.SharedAxesGrid and is not visible here.
    """

    def _create_axes(self, subplot_spec, sharex, sharey):
        # NOTE(review): relies on the notebook-global `fig`, so it must be bound
        # to the intended figure when the grid is constructed.
        sharing = dict(sharex=sharex, sharey=sharey)
        return fig.add_subplot(subplot_spec, **sharing)

    def _plot(self, ax, iteration, acq_plotter):
        # The (acq_fn, last_x) result is not needed for the grid figure.
        acq_plotter.plot_acq_iter(ax, iteration)

# GP configuration shared by all panels. NOTE(review): the meaning of the
# Matern32Kernel argument and of 1e-10 is defined in plume.prediction; they are
# presumably a length scale and a noise level -- TODO confirm.
gp_args = (Matern32Kernel(2), 1e-10)
# One plotter per acquisition function. Each tuple is forwarded verbatim after
# the GP to the DUCB / PDUCB / GO constructor.
plotters = [
    AcquisitionIterationPlotter(gp_args, DUCB, (1.25, 1, -2e-4)),
    AcquisitionIterationPlotter(gp_args, PDUCB, (1.25, 70, -2e-4, 1e-30)),
    AcquisitionIterationPlotter(gp_args, GO, (-2e-4,))]

# Outer layout: a tall row for the panel grid and a thin row for the legend.
# Margins are passed by keyword because since matplotlib 2.2 the third
# positional GridSpec parameter is `figure`, which would misassign them.
legend_height = 0.05
grid = GridSpec(
    2, 1, left=0.05, bottom=0.025, right=0.975, top=0.95,
    height_ratios=(1.0 - legend_height, legend_height),
    hspace=1.5 * legend_height)

# Panel grid; every cell runs a fresh GP/acquisition loop via plot_acq_iter.
p = AcqFnPlot([3, 10, 20], plotters, SubplotSpec(grid, 0), wspace=0.1, hspace=0.2)
# NOTE(review): `p.axes` appears to broadcast each call to every panel
# (neatplots Broadcast) -- TODO confirm.
p.axes.set_xlim(-15, 15)
p.axes.set_ylim(-1.1, 1.1)
p.axes_by_row[-1].set_xlabel(r'$x$', labelpad=0)
p.axes_by_col[0].set_ylabel(r'$y$', labelpad=0, rotation='horizontal', verticalalignment='center')
# Hide inner tick labels: x labels only on the last row, y labels only on the
# first column. Distinct loop names avoid clobbering the global `ax`.
for row_axes in p.axes_by_row[:-1]:
    plt.setp(row_axes.get_xticklabels(), visible=False)
for col_axes in p.axes_by_col[1:]:
    plt.setp(col_axes.get_yticklabels(), visible=False)
latexstyle.style_axes(*p.axes)
# Private matplotlib attribute: moves titles to 1/72 inch (1 pt) above the
# axes. May break on other matplotlib versions.
p.axes.titleOffsetTrans._t = (0.0, 1.0/72.0)

# Invisible axes in the bottom row hosting a single-row legend. The first
# plotter is used because every plotter draws the same labelled artists.
ax_legend = fig.add_subplot(grid[1])
ax_legend.set_axis_off()
ax_legend.legend(
    *plotters[0].get_legend_handles_labels(), ncol=5, loc='upper center',
    bbox_to_anchor=(0.5, 1.0), frameon=False, columnspacing=1.5, handletextpad=0.2)


Out[5]:
<matplotlib.legend.Legend at 0x108727490>

In [6]:
# Export the acquisition-function comparison figure; the path is relative to the
# notebook's working directory.
# NOTE(review): `fig` is a notebook global. Run this directly after the figure
# cell above so a later reassignment of `fig` is not exported instead.
fig.savefig('../../thesis/plots/acqfns.pdf')

In [7]:
# Diagnostic: DUCB after 7 iterations (top) and the output of
# eval_with_derivative at the last observed position (bottom).
# Distinct names keep the thesis figure's `fig` and the global sample grid `x`
# intact, so re-running the savefig cell afterwards still exports the right figure.
gp_args = (Matern32Kernel(2), 1e-10)
plotter = AcquisitionIterationPlotter(gp_args, DUCB, (1.1, 0.4, -1e-5))
fig_deriv = plt.figure(figsize=(6, 4))
ax_acq = fig_deriv.add_subplot(2, 1, 1)
acq_fn_ducb, last_x = plotter.plot_acq_iter(ax_acq, 7)
ax_acq.set_xlim(-15, 15)
ax_deriv = fig_deriv.add_subplot(2, 1, 2)
# NOTE(review): element [1] is presumably the derivative. It belongs to the raw
# acquisition function, whereas the top panel shows the normalised utility.
deriv = acq_fn_ducb.eval_with_derivative(
    np.atleast_2d(plotter.plot_x).T, [NoisyStateMock(last_x)])[1]
ax_deriv.plot(plotter.plot_x, deriv)
ax_deriv.set_xlim(-15, 15)
ax_deriv.axhline(0)


Out[7]:
<matplotlib.lines.Line2D at 0x10f2ffa50>

In [7]: