In [1]:
from scipy.stats import norm as normdist
from neatplots.predefined import four as palette
from neatplots.tools import SharedAxesGrid, Broadcast
from matplotlib.patches import Rectangle
from matplotlib.gridspec import GridSpec, SubplotSpec
from plume.behaviors import DUCB, PDUCB, GO
from plume.prediction import OnlineGP, ExponentialKernel, Matern32Kernel, RBFKernel
%pylab inline
%load_ext autoreload
%autoreload 2
# Shared RGB gray (each channel 0.3); used for the dashed target-function curve.
gray = (0.3,) * 3
In [2]:
import latexstyle
# Apply the project's LaTeX-matching plot style from the local `latexstyle` module
# (presumably sets matplotlib rcParams — not visible from this notebook).
latexstyle.setup()
# Factor used to scale line widths in the plots below; presumably converts LaTeX
# points to matplotlib points — TODO confirm in latexstyle. Imported after setup()
# exactly as before, in case setup() influences its value.
from latexstyle import ltx_to_mpl_pt_factor
In [3]:
def target_fn(x):
    """Bimodal 1-D objective used as the optimisation target in this notebook.

    Twice the sum of two Gaussian densities: one centred at -3 with standard
    deviation 2, and one centred at 4 with standard deviation 1. Evaluates
    elementwise on scalars and NumPy arrays.
    """
    wide_peak = normdist.pdf(x, -3, 2)
    narrow_peak = normdist.pdf(x, 4)
    return 2 * (wide_peak + narrow_peak)
# Preview of the objective on [-12, 12] (the acquisition plots below use [-15, 15]).
# A dedicated name instead of `x`: `x` is reused as a scalar query location
# elsewhere, so the global would change meaning depending on execution order.
x_preview = np.linspace(-12, 12, 200)
fig_preview, ax_preview = plt.subplots()
ax_preview.plot(x_preview, target_fn(x_preview))
ax_preview.set_xlabel(r'$x$')
ax_preview.set_ylabel(r'$f(x)$')
# Trailing semicolon suppresses the Text repr in the cell output.
ax_preview.set_title('Target function');
Out[3]:
In [4]:
class NoisyStateMock(object):
    """Minimal stand-in for a state object passed to acquisition functions.

    Only exposes ``position``: a one-element list holding ``x``.
    """

    def __init__(self, x):
        position = [x]
        self.position = position
class AcquisitionIterationPlotter(object):
    """Run a 1-D sequential-sampling loop with one acquisition function and plot its state.

    Each iteration does four things. It evaluates ``target_fn`` at the current location
    and adds that observation to an ``OnlineGP``. It then evaluates the acquisition
    function on ``plot_x``. Finally, it moves to the grid point with the highest utility.

    Parameters
    ----------
    gp_args : tuple
        Positional arguments forwarded to ``OnlineGP``.
    acq_fn_type : callable
        Acquisition-function class, called as ``acq_fn_type(gp, *acq_fn_args)``.
        Its ``__name__`` is used in the panel title.
    acq_fn_args : tuple
        Extra positional arguments for ``acq_fn_type``.
    """

    # Legend label of the observation markers. get_legend_handles_labels() locates
    # the observation entry by this label rather than by list position.
    OBSERVATIONS_LABEL = r'$X$'

    def __init__(self, gp_args, acq_fn_type, acq_fn_args):
        self.gp_args = gp_args
        self.acq_fn_type = acq_fn_type
        self.acq_fn_args = acq_fn_args
        # Grid on which target, GP prediction and utility are evaluated and plotted.
        self.plot_x = np.linspace(-15, 15, 200)
        # The notebook-level objective, resolved when the plotter is constructed.
        self.target_fn = target_fn
        # First location at which the target is observed.
        self.x0 = -7

    @staticmethod
    def _band_color():
        """RGBA fill colour of the +/- sigma band; shared by the plot and its legend proxy."""
        return tuple(palette.thin[1]) + (0.2,)

    @staticmethod
    def _normalize(utility):
        """Return a new array: ``utility`` shifted to zero mean and scaled into [-1, 1].

        The input array is not modified. A constant utility maps to all zeros
        instead of NaN (the original divided by a zero maximum).
        """
        centered = np.asarray(utility) - np.mean(utility)
        scale = np.max(np.abs(centered))
        return centered / scale if scale > 0 else centered

    def _run_iterations(self, iteration):
        """Run ``iteration`` observe-and-select steps.

        Returns
        -------
        tuple
            ``(gp, acq_fn, xs, ys, utility, oldx)``. These are the fitted GP and the
            acquisition function, followed by the observed locations and values. Next
            come the normalised utility from the final step, then the last observed
            location.

        Raises
        ------
        ValueError
            If ``iteration < 1``; at least one observation is needed for a utility.
        """
        if iteration < 1:
            raise ValueError('iteration must be >= 1, got {!r}'.format(iteration))
        # Column-vector view of the grid; loop-invariant, so built once.
        query = np.atleast_2d(self.plot_x).T
        xs = []
        ys = []
        gp = OnlineGP(*self.gp_args, expected_samples=iteration)
        acq_fn = self.acq_fn_type(gp, *self.acq_fn_args)
        x = self.x0
        for _ in range(iteration):
            y = self.target_fn(x)
            xs.append(x)
            ys.append(y)
            gp.add_observations(np.atleast_2d(x), np.atleast_2d(y))
            utility = self._normalize(acq_fn(query, [NoisyStateMock(x)]))
            oldx = x
            # Normalisation is monotone, so the argmax equals that of the raw utility.
            x = self.plot_x[np.argmax(utility)]
        return gp, acq_fn, xs, ys, utility, oldx

    def plot_acq_iter(self, ax, iteration):
        """Run ``iteration`` steps and draw the resulting state onto ``ax``.

        Draws the target function (dashed gray) and the GP mean with a +/- sigma
        band. Also draws the observations and the normalised utility from the last
        step, and sets the panel title.

        Parameters
        ----------
        ax : matplotlib Axes
            Axes to draw on; remembered for get_legend_handles_labels().
        iteration : int
            Number of observations to make (must be >= 1).

        Returns
        -------
        tuple
            ``(acq_fn, oldx)``: the acquisition function and the location of the last
            observation (the state passed to the final utility evaluation).

        Raises
        ------
        ValueError
            If ``iteration < 1``.
        """
        # Remembered so get_legend_handles_labels() can read this axes' entries.
        self._ax = ax
        gp, acq_fn, xs, ys, utility, oldx = self._run_iterations(iteration)
        band_color = self._band_color()

        # Target function as a dashed reference curve.
        target_line, = ax.plot(
            self.plot_x, self.target_fn(self.plot_x), '--', c=gray, label='Target function')
        target_line.set_dashes((3, 3))

        # GP prediction: mean with a band of +/- sqrt(MSE).
        mean, mse = gp.predict(np.atleast_2d(self.plot_x).T, eval_MSE=True)
        mean = np.squeeze(mean)
        sd = np.squeeze(np.sqrt(mse))
        ax.fill_between(self.plot_x, mean - sd, mean + sd, color=band_color, edgecolor='none')
        ax.plot(self.plot_x, mean, color=palette.thin[1], label=r'$\mu(x)$')
        ax.scatter(xs, ys, marker='+', s=24, color=palette.thin[1], label=self.OBSERVATIONS_LABEL)

        ax.plot(
            self.plot_x, utility, c=palette.highlight[3],
            linewidth=0.8 * ltx_to_mpl_pt_factor, label='Normalized utility function')
        ax.set_title('{} Iteration {}'.format(self.acq_fn_type.__name__, iteration))
        return acq_fn, oldx

    def get_legend_handles_labels(self):
        """Legend entries of the most recently plotted axes, in display order.

        The observations entry is moved to position 1, and a proxy patch for the
        +/- sigma band is inserted at position 3.

        Raises
        ------
        RuntimeError
            If plot_acq_iter() has not been called yet.
        ValueError
            If the axes has no observations entry.
        """
        if not hasattr(self, '_ax'):
            raise RuntimeError('call plot_acq_iter() before get_legend_handles_labels()')
        handles, labels = self._ax.get_legend_handles_labels()
        # Move the observations entry by label, not by index: the order in which
        # matplotlib returns lines vs. scatter collections may differ between versions.
        obs_idx = labels.index(self.OBSERVATIONS_LABEL)
        handles.insert(1, handles.pop(obs_idx))
        labels.insert(1, labels.pop(obs_idx))
        # fill_between has no label, so add a proxy patch for the +/- sigma band.
        handles.insert(3, Rectangle((0, 0), 1, 1, fc=self._band_color(), ec='none'))
        labels.insert(3, r'$\pm \sigma(x)$')
        return handles, labels
In [5]:
fig = plt.figure()


class AcqFnPlot(SharedAxesGrid):
    """SharedAxesGrid subclass that draws AcquisitionIterationPlotter panels.

    SharedAxesGrid (neatplots) drives the layout. It presumably calls _create_axes
    once per panel and then _plot with one entry from each of its two data
    sequences (here: iterations and plotters) — confirm in neatplots.tools.
    """

    # Figure that receives the panels. Bound at class-definition time, right after
    # `fig` is created above. The original looked up the global `fig` on every call,
    # so a later cell that rebinds `fig` would redirect new panels into another figure.
    _acq_fig = fig

    def _create_axes(self, subplot_spec, sharex, sharey):
        """Add and return a subplot at ``subplot_spec`` sharing the given axes."""
        return self._acq_fig.add_subplot(subplot_spec, sharex=sharex, sharey=sharey)

    def _plot(self, ax, iteration, acq_plotter):
        """Draw ``acq_plotter``'s state after ``iteration`` steps onto ``ax``."""
        acq_plotter.plot_acq_iter(ax, iteration)
# GP configuration shared by all three plotters: kernel plus a second OnlineGP
# argument (presumably the noise variance — confirm in plume.prediction).
gp_args = (Matern32Kernel(2), 1e-10)
# One plotter per acquisition function. Each tuple holds that function's extra
# constructor arguments (their meaning is defined in plume.behaviors).
plotters = [
    AcquisitionIterationPlotter(gp_args, DUCB, (1.25, 1, -2e-4)),
    AcquisitionIterationPlotter(gp_args, PDUCB, (1.25, 70, -2e-4, 1e-30)),
    AcquisitionIterationPlotter(gp_args, GO, (-2e-4,))]

# Outer layout: the panel grid on top, a thin legend strip below.
legend_height = 0.05
# Margins are passed by keyword. Since matplotlib 2.2 the third positional parameter
# of GridSpec is `figure`, so the positional form shifted every margin by one slot.
grid = GridSpec(
    2, 1, left=0.05, bottom=0.025, right=0.975, top=0.95,
    height_ratios=(1.0 - legend_height, legend_height), hspace=1.5 * legend_height)
p = AcqFnPlot([3, 10, 20], plotters, SubplotSpec(grid, 0), wspace=0.1, hspace=0.2)
p.axes.set_xlim(-15, 15)
p.axes.set_ylim(-1.1, 1.1)
p.axes_by_row[-1].set_xlabel(r'$x$', labelpad=0)
p.axes_by_col[0].set_ylabel(r'$y$', labelpad=0, rotation='horizontal', verticalalignment='center')
# Hide inner tick labels: x labels on all but the last row, y labels on all but the first column.
for ax in p.axes_by_row[:-1]:
    plt.setp(ax.get_xticklabels(), visible=False)
for ax in p.axes_by_col[1:]:
    plt.setp(ax.get_yticklabels(), visible=False)
latexstyle.style_axes(*p.axes)
# NOTE(review): sets a private matplotlib attribute to move titles 1/72 inch (1 pt)
# above the axes; may break on newer matplotlib (Axes.set_title(pad=...) is public).
p.axes.titleOffsetTrans._t = (0.0, 1.0/72.0)

# Legend strip below the panels: one row with a column per legend entry.
ax_legend = fig.add_subplot(grid[1])
ax_legend.set_axis_off()
legend_handles, legend_labels = plotters[0].get_legend_handles_labels()
ax_legend.legend(
    legend_handles, legend_labels, ncol=len(legend_handles), loc='upper center',
    bbox_to_anchor=(0.5, 1.0), frameon=False, columnspacing=1.5, handletextpad=0.2);
Out[5]:
In [6]:
# Export the acquisition-function comparison figure for the thesis.
# NOTE(review): saves whatever the global `fig` currently refers to, and the
# relative path is resolved against the kernel's working directory.
acqfns_pdf_path = '../../thesis/plots/acqfns.pdf'
fig.savefig(acqfns_pdf_path)
In [7]:
# Diagnostic: a DUCB acquisition-iteration plot after 7 iterations (top) and the
# derivative from eval_with_derivative on the same grid (bottom).
# Dedicated names so this cell does not rebind `fig`, `ax`, `x` or `gp_args`.
# The savefig cell exports the global `fig` and must keep seeing the thesis figure.
deriv_gp_args = (Matern32Kernel(2), 1e-10)
deriv_plotter = AcquisitionIterationPlotter(deriv_gp_args, DUCB, (1.1, 0.4, -1e-5))
deriv_fig = plt.figure(figsize=(6, 4))

util_ax = deriv_fig.add_subplot(2, 1, 1)
deriv_acq_fn, last_obs_x = deriv_plotter.plot_acq_iter(util_ax, 7)
util_ax.set_xlim(-15, 15)

deriv_ax = deriv_fig.add_subplot(2, 1, 2)
# Element 1 of eval_with_derivative's result is plotted as the derivative
# (presumably it returns (value, derivative) — confirm in plume.behaviors).
utility_deriv = deriv_acq_fn.eval_with_derivative(
    np.atleast_2d(deriv_plotter.plot_x).T, [NoisyStateMock(last_obs_x)])[1]
deriv_ax.plot(deriv_plotter.plot_x, utility_deriv)
deriv_ax.set_xlim(-15, 15)
deriv_ax.axhline(0)
deriv_ax.set_xlabel(r'$x$');
Out[7]:
In [7]: