In [2]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import warnings
warnings.filterwarnings("ignore", category=UserWarning)

from functools import partial

import numpy as np
import matplotlib.pyplot as plt

from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, ConstantKernel


from pypuffin.decorators import disable_typechecking
from pypuffin.numeric.mcmc.gphmc import GPHMC
from pypuffin.numeric.mcmc.hmc import ScalarMassHMC

# Seed numpy's global RNG so the notebook reproduces end-to-end on re-run.
np.random.seed(0)

# GPHMC only needs a factory that builds an HMC sampler from a starting
# point, a potential function, and a gradient-of-potential function; bind
# the remaining sampler settings (mass, leapfrog schedule) up front here.
f_construct_hmc = partial(ScalarMassHMC, mass=1, num_leapfrog_steps=500, eps=1e-2)

# Gaussian process that will be fit to the target log-probability surface.
# (A ConstantKernel(1, (1e-5, 1e5)) * RBF(10, (1e-4, 1e4)) product kernel
# was also tried; plain RBF is used for now.)
kernel = RBF(10, (1e-4, 1e4))
regressor = GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=20)

# Target (unnormalised) log probability: a spherical Gaussian shifted up by
# 20 so values near the mode are well away from zero (the GP has a zero-mean
# prior, so a target hovering around zero is hard to distinguish from no data).
# PEP 8 (E731): use a def, not a lambda assigned to a name.
def f_target_log_prob(x):
    """Return the unnormalised Gaussian log density at ``x``, shifted by +20.

    ``x`` is a 1-d numpy array; the result is the scalar ``20 - x.x``.
    """
    return 20 - x.dot(x)

# Start the sampler slightly off the target's mode at x = 0.
initial_point = np.asarray([0.1])
gphmc = GPHMC(f_target_log_prob, regressor, f_construct_hmc, initial_point)

# FIXME HACKHACKHACK
# FIXME HACKHACKHACK
# TODO initialise gphmc with several training points so it doesn't blow up

# Exploration phase: each sample_explore() call draws a new training point and
# refits the GP. Every 5th iteration, plot the current GP mean (and mean - std)
# against the true curve and the evaluations collected so far.
# NOTE(review): this loop reads GPHMC's private attributes (_regressor,
# _x_train, _y_train) — consider exposing read-only accessors instead.
for i in range(100):
    gphmc.sample_explore()
    # Show how the fitted kernel's length scale evolves as points accumulate.
    print(gphmc._regressor.kernel_)
    if i % 5 != 0:
        continue

    x_train = np.asarray(gphmc._x_train).flatten()
    y_train = np.asarray(gphmc._y_train)
    # Plotting window: at least [-3, 3], widened to cover every evaluation
    # seen so far with a 20% margin.
    limit = max(3, 1.2 * np.max(np.abs(x_train)))

    x_samples = np.linspace(-limit, limit, 500)
    y_samples, std_samples = gphmc.predict_gp(x_samples[:, np.newaxis], return_std=True)

    # Explicit Axes interface, with a title and axis labels so each figure
    # stands alone when the notebook is skimmed.
    fig, ax = plt.subplots()
    ax.plot(x_samples, y_samples, label=f'mean {i}')
    ax.plot(x_samples, y_samples - std_samples, label=f'mean - std {i}')
    # Elementwise form of the 1-d target 20 - x.x for plotting.
    ax.plot(x_samples, 20 - x_samples ** 2, label='actual')
    ax.scatter(x_train, y_train, label='evaluations')
    ax.set(xlabel='x', ylabel='log probability',
           title=f'GP fit after {i + 1} exploration steps')
    ax.legend()
    plt.show()


RBF(length_scale=10)
[ 0.1] [ 6.39006084]
RBF(length_scale=0.0001)
[ 6.39006084] [ 5.36426952]
RBF(length_scale=0.677)
[ 5.36426952] [-10.70041194]
RBF(length_scale=0.663)
[-10.70041194] [ 44.76815223]
RBF(length_scale=11.9)
[ 44.76815223] [-166.44384979]
RBF(length_scale=37.1)
/Users/thomas/Documents/Programming/pypuffin/pypuffin/sklearn/gaussian_process.py:125: RuntimeWarning: divide by zero encountered in true_divide
  result = d_variance / (2 * std[:, numpy.newaxis])
[-166.44384979] [ 643.88008584]
RBF(length_scale=133)
[ 643.88008584] [-2259.04999131]
RBF(length_scale=425)
[-2259.04999131] [-11230.66220664]
RBF(length_scale=2.05e+03)
[-11230.66220664] [-59343.03960794]
RBF(length_scale=1e+04)
[-59343.03960794] [-292549.21000451]
RBF(length_scale=1e+04)
[-292549.21000451] [-2082868.14300882]
RBF(length_scale=1e+04)
[-2082868.14300882] [-2082868.14300882]
RBF(length_scale=1e+04)
[-2082868.14300882] [-2082868.14300882]
RBF(length_scale=1e+04)
[-2082868.14300882] [-14259519.12880317]
RBF(length_scale=1e+04)
[-14259519.12880317] [ 3752890.9298541]
RBF(length_scale=1e+04)
[ 3752890.9298541] [ 3752890.92963877]
RBF(length_scale=1e+04)
[ 3752890.92963877] [ 3752890.92963877]
RBF(length_scale=1e+04)
[ 3752890.92963877] [ 3752890.92963877]
RBF(length_scale=1e+04)
[ 3752890.92963877] [ 3752890.92963877]
RBF(length_scale=1e+04)
[ 3752890.92963877] [ 3752890.93197945]
RBF(length_scale=1e+04)
[ 3752890.93197945] [ 3752890.64626371]
RBF(length_scale=1e+04)
[ 3752890.64626371] [ 3752890.46071255]
RBF(length_scale=1e+04)
[ 3752890.46071255] [ 3752178.47704806]
RBF(length_scale=1e+04)
[ 3752178.47704806] [ 3751827.5023091]
RBF(length_scale=1e+04)
[ 3751827.5023091] [ 3751338.62452059]
RBF(length_scale=1e+04)
[ 3751338.62452059] [ 3746497.05916292]
RBF(length_scale=1e+04)
[ 3746497.05916292] [ 3745835.81609377]
RBF(length_scale=1e+04)
[ 3745835.81609377] [ 3737165.25164686]
RBF(length_scale=1e+04)
[ 3737165.25164686] [ 3736979.88243073]
RBF(length_scale=1e+04)
[ 3736979.88243073] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
RBF(length_scale=1e+04)
[ 3730656.31483833] [ 3730656.31483833]
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-2-9c6934272e4b> in <module>()
     45 
     46 for i in range(100):
---> 47     gphmc.sample_explore()
     48     print(gphmc._regressor.kernel_)
     49     if i % 5 != 0:

~/Documents/Programming/pypuffin/pypuffin/numeric/mcmc/gphmc.py in sample_explore(self)
    127 
    128         # Refit the distribution using all training data, and return our current location
--> 129         self._fit_gp()
    130         return x_1
    131 

~/Documents/Programming/pypuffin/pypuffin/numeric/mcmc/gphmc.py in _fit_gp(self)
     58         y_train_array = numpy.asarray(self._y_train)
     59         mean = numpy.mean(y_train_array, axis=0)
---> 60         return self._regressor.fit(x_train_array, y_train_array - mean)
     61 
     62     def predict_gp(self, x, return_std=False):  # pylint: disable=invalid-name

~/anaconda/lib/python3.6/site-packages/sklearn/gaussian_process/gpr.py in fit(self, X, y)
    230                     optima.append(
    231                         self._constrained_optimization(obj_func, theta_initial,
--> 232                                                        bounds))
    233             # Select result from run with minimal (negative) log-marginal
    234             # likelihood

~/anaconda/lib/python3.6/site-packages/sklearn/gaussian_process/gpr.py in _constrained_optimization(self, obj_func, initial_theta, bounds)
    452         if self.optimizer == "fmin_l_bfgs_b":
    453             theta_opt, func_min, convergence_dict = \
--> 454                 fmin_l_bfgs_b(obj_func, initial_theta, bounds=bounds)
    455             if convergence_dict["warnflag"] != 0:
    456                 warnings.warn("fmin_l_bfgs_b terminated abnormally with the "

~/anaconda/lib/python3.6/site-packages/scipy/optimize/lbfgsb.py in fmin_l_bfgs_b(func, x0, fprime, args, approx_grad, bounds, m, factr, pgtol, epsilon, iprint, maxfun, maxiter, disp, callback, maxls)
    191 
    192     res = _minimize_lbfgsb(fun, x0, args=args, jac=jac, bounds=bounds,
--> 193                            **opts)
    194     d = {'grad': res['jac'],
    195          'task': res['message'],

~/anaconda/lib/python3.6/site-packages/scipy/optimize/lbfgsb.py in _minimize_lbfgsb(fun, x0, args, jac, bounds, disp, maxcor, ftol, gtol, eps, maxfun, maxiter, iprint, callback, maxls, **unknown_options)
    300         raise ValueError('maxls must be positive.')
    301 
--> 302     x = array(x0, float64)
    303     f = array(0.0, float64)
    304     g = zeros((n,), float64)

KeyboardInterrupt: