In [1]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import lasagne
import theano
import theano.tensor as T
import seaborn

import logging
logging.basicConfig(level=logging.INFO)


from robo.models.lcnet.basis_functions import vapor_pressure, exponential, hill_3, log_power, pow_func
from robo.models.lcnet import LCNet, get_lc_net


seaborn.set_style(style='whitegrid')

# NOTE(review): usetex=True requires a working LaTeX installation on this
# machine — confirm before re-running on a different host.
plt.rc('text', usetex=True)
plt.rc('font', size=15.0, family='serif')
plt.rcParams['figure.figsize'] = (10.0, 8.0)
# Use a plain string for the preamble: matplotlib >= 3.3 rejects the old
# list form, while a single no-comma string also validates on old versions.
plt.rcParams['text.latex.preamble'] = r"\usepackage{amsmath}"

In [2]:
n_epochs = 100
N = 500
t_grid = np.linspace(1, n_epochs, n_epochs) / n_epochs

# One panel per parametric basis function of LC-Net.  Each entry is
# (function, number of parameters it takes after t, panel title).
basis_functions = [
    (vapor_pressure, 3, "vapor pressure"),
    (pow_func, 2, "pow-3"),
    (log_power, 3, "log-power"),
    (exponential, 2, "exponential"),
    (hill_3, 3, "hill-3"),
]

fig = plt.figure(figsize=(28, 4))

for panel, (func, n_params, title) in enumerate(basis_functions):
    ax = fig.add_subplot(1, 5, panel + 1)

    # N curves per panel, parameters drawn uniformly from [0, 1).
    for _ in range(N):
        theta = np.random.rand(n_params)
        phi = [func(t, *theta) for t in t_grid]
        ax.plot(t_grid, phi)

    # Labels/titles set once per panel (the original re-set them on
    # every one of the N iterations).
    ax.set_xlabel(r"$t$", fontsize=20)
    ax.set_ylabel(r"$\phi(\boldsymbol{x}, t)$", fontsize=20)
    ax.set_title(title, fontsize=25)
In [3]:
def toy_example(t, a, b):
    """Synthetic log-shaped learning curve.

    Evaluates (10 + a * log(b * t)) / 10 and adds a single small uniform
    random offset (the same scalar for every entry when ``t`` is an array).

    :param t: epoch fraction(s) in (0, 1], scalar or numpy array
    :param a: slope of the logarithmic growth
    :param b: horizontal scaling of the curve
    :return: curve value(s), same shape as ``t``
    """
    curve = (10 + a * np.log(b * t)) / 10.
    noise = 10e-3 * np.random.rand()
    return curve + noise

current_palette = seaborn.color_palette("Paired", 10)
seaborn.set_palette(current_palette)

observed = 20   # percentage of each learning curve that is observed
N = 200         # number of random configurations
n_epochs = 100
observed_t = int(n_epochs * (observed / 100.))

# Divide by a float so this is correct even under Python 2 integer
# division (the surrounding theano/lasagne stack suggests an old interpreter).
t_idx = np.arange(1, observed_t + 1) / float(n_epochs)
t_grid = np.arange(1, n_epochs + 1) / float(n_epochs)

configs = np.random.rand(N, 2)
learning_curves = [toy_example(t_grid, configs[i, 0], configs[i, 1]) for i in range(N)]

# Collect the per-configuration pieces in lists and concatenate once at the
# end: repeated np.concatenate inside the loop is quadratic in N.
X_train_parts, y_train_parts = [], []
X_test_parts, y_test_parts = [], []

for i in range(N):
    # Rows of [config_0, config_1, normalized epoch] for the observed prefix.
    x = np.repeat(configs[i, None, :], t_idx.shape[0], axis=0)
    x = np.concatenate((x, t_idx[:, None]), axis=1)
    X_train_parts.append(x)

    lc = learning_curves[i][:observed_t]
    y_train_parts.append(lc)

    # Test point: the same configuration at t = 1 (final epoch), with the
    # final curve value as the target.
    X_test_parts.append(np.concatenate((configs[i, None, :], np.array([[1]])), axis=1))
    y_test_parts.append(np.array([learning_curves[i][-1]]))

    plt.plot(t_idx * n_epochs, lc)

X_train = np.concatenate(X_train_parts, 0)
y_train = np.concatenate(y_train_parts, 0)
X_test = np.concatenate(X_test_parts, 0)
y_test = np.concatenate(y_test_parts, 0)

plt.xlabel("Epochs", fontsize=20)
plt.ylabel("Validation Accuracy", fontsize=20)
plt.title("Training Data", fontsize=20)
plt.xlim(1, n_epochs)
plt.show()



In [4]:
# Bayesian neural network over learning curves; weights are sampled with
# stochastic-gradient HMC ("sghmc").  Hyperparameter meanings inferred from
# the argument names — confirm against the LCNet implementation.
model = LCNet(sampling_method="sghmc",
              l_rate=np.sqrt(1e-4),    # SGHMC step size
              mdecay=.05,              # momentum decay
              n_nets=100,              # posterior weight samples to keep
              burn_in=5000,            # iterations before sampling starts
              n_iters=30000,           # total sampler iterations
              get_net=get_lc_net,      # architecture with the parametric basis functions
              precondition=True)

# Trains on the observed learning-curve prefixes built in the previous cell.
model.train(X_train, y_train)


WARNING (theano.tensor.blas): We did not found a dynamic library into the library_dir of the library we use for blas. If you use ATLAS, make sure to compile it with dynamics library.
WARNING:theano.tensor.blas:We did not found a dynamic library into the library_dir of the library we use for blas. If you use ATLAS, make sure to compile it with dynamics library.
INFO:root:Starting sampling
... compiling theano function
INFO:root:Iter        0 : NLL =  2.8682e+00 MSE = 3.2122e-01 Time = 12.32
INFO:root:Iter      512 : NLL = -3.6275e+00 MSE = 2.8227e-04 Time = 13.85
INFO:root:Iter     1024 : NLL = -4.0308e+00 MSE = 1.6477e-04 Time = 15.43
INFO:root:Iter     1536 : NLL = -4.2248e+00 MSE = 2.0352e-04 Time = 17.27
INFO:root:Iter     2048 : NLL = -4.1989e+00 MSE = 1.4375e-04 Time = 18.87
INFO:root:Iter     2560 : NLL = -4.3585e+00 MSE = 1.4744e-04 Time = 20.54
INFO:root:Iter     3072 : NLL = -4.3997e+00 MSE = 9.3837e-05 Time = 22.17
INFO:root:Iter     3584 : NLL = -4.4361e+00 MSE = 8.8999e-05 Time = 24.61
INFO:root:Iter     4096 : NLL = -4.2700e+00 MSE = 1.4923e-04 Time = 26.70
INFO:root:Iter     4608 : NLL = -4.1860e+00 MSE = 1.4302e-04 Time = 28.48
... compiling theano function
INFO:root:Iter     5000 : NLL = -4.2933e+00 MSE = 8.1963e-05 Samples= 1 Time = 35.20
INFO:root:Iter     5100 : NLL = -4.4471e+00 MSE = 7.4091e-05 Samples= 2 Time = 35.48
INFO:root:Iter     5200 : NLL = -4.5344e+00 MSE = 7.5038e-05 Samples= 3 Time = 35.77
INFO:root:Iter     5300 : NLL = -4.4351e+00 MSE = 1.4156e-04 Samples= 4 Time = 36.03
INFO:root:Iter     5400 : NLL = -4.4876e+00 MSE = 8.4800e-05 Samples= 5 Time = 36.31
INFO:root:Iter     5500 : NLL = -4.2676e+00 MSE = 2.2038e-04 Samples= 6 Time = 36.59
INFO:root:Iter     5600 : NLL = -3.9226e+00 MSE = 1.7672e-04 Samples= 7 Time = 36.87
INFO:root:Iter     5700 : NLL = -4.4651e+00 MSE = 9.3725e-05 Samples= 8 Time = 37.15
INFO:root:Iter     5800 : NLL = -4.5362e+00 MSE = 9.0954e-05 Samples= 9 Time = 37.42
INFO:root:Iter     5900 : NLL = -4.5964e+00 MSE = 7.5791e-05 Samples= 10 Time = 37.70
INFO:root:Iter     6000 : NLL = -4.4253e+00 MSE = 9.8829e-05 Samples= 11 Time = 37.98
INFO:root:Iter     6100 : NLL = -4.4162e+00 MSE = 9.4607e-05 Samples= 12 Time = 38.27
INFO:root:Iter     6200 : NLL = -4.4082e+00 MSE = 7.7889e-05 Samples= 13 Time = 38.55
INFO:root:Iter     6300 : NLL = -4.2985e+00 MSE = 1.1524e-04 Samples= 14 Time = 38.85
INFO:root:Iter     6400 : NLL = -4.5506e+00 MSE = 5.9104e-05 Samples= 15 Time = 39.16
INFO:root:Iter     6500 : NLL = -4.5594e+00 MSE = 6.0905e-05 Samples= 16 Time = 39.48
INFO:root:Iter     6600 : NLL = -4.1209e+00 MSE = 9.0220e-05 Samples= 17 Time = 39.78
INFO:root:Iter     6700 : NLL = -4.6139e+00 MSE = 7.6193e-05 Samples= 18 Time = 40.08
INFO:root:Iter     6800 : NLL = -4.0367e+00 MSE = 2.3808e-04 Samples= 19 Time = 40.36
INFO:root:Iter     6900 : NLL = -3.0919e+00 MSE = 1.9727e-04 Samples= 20 Time = 40.63
INFO:root:Iter     7000 : NLL = -4.1279e+00 MSE = 1.3434e-04 Samples= 21 Time = 40.92
INFO:root:Iter     7100 : NLL = -4.4278e+00 MSE = 7.1402e-05 Samples= 22 Time = 41.20
INFO:root:Iter     7200 : NLL = -4.3834e+00 MSE = 1.1969e-04 Samples= 23 Time = 41.48
INFO:root:Iter     7300 : NLL = -4.4360e+00 MSE = 8.4297e-05 Samples= 24 Time = 41.81
INFO:root:Iter     7400 : NLL = -4.6364e+00 MSE = 5.8254e-05 Samples= 25 Time = 42.37
INFO:root:Iter     7500 : NLL = -4.5338e+00 MSE = 8.2219e-05 Samples= 26 Time = 42.79
INFO:root:Iter     7600 : NLL = -4.1636e+00 MSE = 7.0663e-05 Samples= 27 Time = 43.21
INFO:root:Iter     7700 : NLL = -4.3722e+00 MSE = 9.1811e-05 Samples= 28 Time = 43.62
INFO:root:Iter     7800 : NLL = -4.1711e+00 MSE = 1.2261e-04 Samples= 29 Time = 44.11
INFO:root:Iter     7900 : NLL = -4.4304e+00 MSE = 4.6947e-05 Samples= 30 Time = 44.55
INFO:root:Iter     8000 : NLL = -4.1830e+00 MSE = 1.1042e-04 Samples= 31 Time = 44.94
INFO:root:Iter     8100 : NLL = -3.9499e+00 MSE = 2.3441e-04 Samples= 32 Time = 45.40
INFO:root:Iter     8200 : NLL = -4.2443e+00 MSE = 7.6210e-05 Samples= 33 Time = 45.82
INFO:root:Iter     8300 : NLL = -4.4676e+00 MSE = 6.6756e-05 Samples= 34 Time = 46.17
INFO:root:Iter     8400 : NLL = -4.4358e+00 MSE = 5.9845e-05 Samples= 35 Time = 46.50
INFO:root:Iter     8500 : NLL = -4.6232e+00 MSE = 1.5525e-05 Samples= 36 Time = 46.83
INFO:root:Iter     8600 : NLL = -4.4263e+00 MSE = 3.5861e-05 Samples= 37 Time = 47.17
INFO:root:Iter     8700 : NLL = -4.1311e+00 MSE = 1.0475e-04 Samples= 38 Time = 47.51
INFO:root:Iter     8800 : NLL = -4.4962e+00 MSE = 3.8080e-05 Samples= 39 Time = 47.87
INFO:root:Iter     8900 : NLL = -4.4321e+00 MSE = 1.2741e-04 Samples= 40 Time = 48.21
INFO:root:Iter     9000 : NLL = -4.4408e+00 MSE = 7.0874e-05 Samples= 41 Time = 48.54
INFO:root:Iter     9100 : NLL = -4.4328e+00 MSE = 1.0648e-04 Samples= 42 Time = 48.88
INFO:root:Iter     9200 : NLL = -4.5025e+00 MSE = 3.9651e-05 Samples= 43 Time = 49.23
INFO:root:Iter     9300 : NLL = -4.3178e+00 MSE = 7.4502e-05 Samples= 44 Time = 49.58
INFO:root:Iter     9400 : NLL = -4.2762e+00 MSE = 5.7705e-05 Samples= 45 Time = 49.91
INFO:root:Iter     9500 : NLL = -4.1647e+00 MSE = 7.1158e-05 Samples= 46 Time = 50.22
INFO:root:Iter     9600 : NLL = -4.5250e+00 MSE = 2.7285e-05 Samples= 47 Time = 50.51
INFO:root:Iter     9700 : NLL = -4.4925e+00 MSE = 4.2066e-05 Samples= 48 Time = 50.83
INFO:root:Iter     9800 : NLL = -4.4105e+00 MSE = 9.0610e-05 Samples= 49 Time = 51.13
INFO:root:Iter     9900 : NLL = -4.4329e+00 MSE = 6.0613e-05 Samples= 50 Time = 51.40
INFO:root:Iter    10000 : NLL = -4.3171e+00 MSE = 4.5552e-05 Samples= 51 Time = 51.69
INFO:root:Iter    10100 : NLL = -4.6382e+00 MSE = 1.7664e-05 Samples= 52 Time = 51.98
INFO:root:Iter    10200 : NLL = -4.4677e+00 MSE = 3.5776e-05 Samples= 53 Time = 52.26
INFO:root:Iter    10300 : NLL = -4.5164e+00 MSE = 4.7030e-05 Samples= 54 Time = 52.53
INFO:root:Iter    10400 : NLL = -4.6012e+00 MSE = 2.1792e-05 Samples= 55 Time = 52.81
INFO:root:Iter    10500 : NLL = -3.6881e+00 MSE = 1.5650e-04 Samples= 56 Time = 53.08
INFO:root:Iter    10600 : NLL = -4.3508e+00 MSE = 6.7612e-05 Samples= 57 Time = 53.38
INFO:root:Iter    10700 : NLL = -4.2989e+00 MSE = 5.4308e-05 Samples= 58 Time = 53.65
INFO:root:Iter    10800 : NLL = -4.5162e+00 MSE = 2.6591e-05 Samples= 59 Time = 53.97
INFO:root:Iter    10900 : NLL = -4.5525e+00 MSE = 2.5642e-05 Samples= 60 Time = 54.30
INFO:root:Iter    11000 : NLL = -4.5200e+00 MSE = 3.1306e-05 Samples= 61 Time = 54.62
INFO:root:Iter    11100 : NLL = -4.3727e+00 MSE = 4.2244e-05 Samples= 62 Time = 54.97
INFO:root:Iter    11200 : NLL = -4.5424e+00 MSE = 3.0845e-05 Samples= 63 Time = 55.28
INFO:root:Iter    11300 : NLL = -4.2831e+00 MSE = 1.2262e-04 Samples= 64 Time = 55.60
INFO:root:Iter    11400 : NLL = -4.6202e+00 MSE = 1.8350e-05 Samples= 65 Time = 55.94
INFO:root:Iter    11500 : NLL = -4.5490e+00 MSE = 2.6208e-05 Samples= 66 Time = 56.28
INFO:root:Iter    11600 : NLL = -4.2252e+00 MSE = 5.5760e-05 Samples= 67 Time = 56.61
INFO:root:Iter    11700 : NLL = -4.5001e+00 MSE = 2.8553e-05 Samples= 68 Time = 56.95
INFO:root:Iter    11800 : NLL = -4.4028e+00 MSE = 5.7712e-05 Samples= 69 Time = 57.28
INFO:root:Iter    11900 : NLL = -4.4923e+00 MSE = 4.3243e-05 Samples= 70 Time = 57.62
INFO:root:Iter    12000 : NLL = -4.4185e+00 MSE = 3.8547e-05 Samples= 71 Time = 58.03
INFO:root:Iter    12100 : NLL = -4.4109e+00 MSE = 3.8465e-05 Samples= 72 Time = 58.52
INFO:root:Iter    12200 : NLL = -4.0333e+00 MSE = 1.2494e-04 Samples= 73 Time = 59.03
INFO:root:Iter    12300 : NLL = -4.4652e+00 MSE = 3.3722e-05 Samples= 74 Time = 59.55
INFO:root:Iter    12400 : NLL = -4.2621e+00 MSE = 5.2959e-05 Samples= 75 Time = 60.06
INFO:root:Iter    12500 : NLL = -4.5096e+00 MSE = 2.5886e-05 Samples= 76 Time = 60.54
INFO:root:Iter    12600 : NLL = -4.2840e+00 MSE = 7.4629e-05 Samples= 77 Time = 61.03
INFO:root:Iter    12700 : NLL = -4.5105e+00 MSE = 3.2895e-05 Samples= 78 Time = 61.51
INFO:root:Iter    12800 : NLL = -4.5495e+00 MSE = 2.3725e-05 Samples= 79 Time = 61.97
INFO:root:Iter    12900 : NLL = -4.2356e+00 MSE = 1.0540e-04 Samples= 80 Time = 62.37
INFO:root:Iter    13000 : NLL = -4.4730e+00 MSE = 3.3421e-05 Samples= 81 Time = 62.78
INFO:root:Iter    13100 : NLL = -4.5666e+00 MSE = 2.2240e-05 Samples= 82 Time = 63.19
INFO:root:Iter    13200 : NLL = -4.3571e+00 MSE = 4.7572e-05 Samples= 83 Time = 63.59
INFO:root:Iter    13300 : NLL = -4.4436e+00 MSE = 3.1463e-05 Samples= 84 Time = 64.03
INFO:root:Iter    13400 : NLL = -4.5205e+00 MSE = 2.7791e-05 Samples= 85 Time = 64.45
INFO:root:Iter    13500 : NLL = -4.3454e+00 MSE = 5.0152e-05 Samples= 86 Time = 64.85
INFO:root:Iter    13600 : NLL = -4.3748e+00 MSE = 4.1866e-05 Samples= 87 Time = 65.25
INFO:root:Iter    13700 : NLL = -3.9806e+00 MSE = 8.7512e-05 Samples= 88 Time = 65.65
INFO:root:Iter    13800 : NLL = -4.4204e+00 MSE = 4.0643e-05 Samples= 89 Time = 66.04
INFO:root:Iter    13900 : NLL = -4.3617e+00 MSE = 5.0376e-05 Samples= 90 Time = 66.37
INFO:root:Iter    14000 : NLL = -3.6648e+00 MSE = 1.6252e-04 Samples= 91 Time = 66.72
INFO:root:Iter    14100 : NLL = -4.4028e+00 MSE = 3.8752e-05 Samples= 92 Time = 67.07
INFO:root:Iter    14200 : NLL = -4.5969e+00 MSE = 1.8915e-05 Samples= 93 Time = 67.40
INFO:root:Iter    14300 : NLL = -4.4849e+00 MSE = 2.7982e-05 Samples= 94 Time = 67.76
INFO:root:Iter    14400 : NLL = -4.4399e+00 MSE = 3.7365e-05 Samples= 95 Time = 68.09
INFO:root:Iter    14500 : NLL = -4.1367e+00 MSE = 1.6701e-04 Samples= 96 Time = 68.44
INFO:root:Iter    14600 : NLL = -4.4482e+00 MSE = 3.4460e-05 Samples= 97 Time = 68.78
INFO:root:Iter    14700 : NLL = -4.0484e+00 MSE = 6.5320e-05 Samples= 98 Time = 69.12
INFO:root:Iter    14800 : NLL = -4.4678e+00 MSE = 3.2723e-05 Samples= 99 Time = 69.46
INFO:root:Iter    14900 : NLL = -3.9420e+00 MSE = 2.8586e-04 Samples= 100 Time = 69.81

In [13]:
test_config = 9  # which of the N training configurations to inspect

x = configs[test_config, None, :]
epochs = np.arange(1, n_epochs + 1)
# Float divisor: robust against Python 2 integer division.
idx = epochs / float(n_epochs)
x = np.repeat(x, idx.shape[0], axis=0)
x = np.concatenate((x, idx[:, None]), axis=1)
# NOTE(review): this shadows the full y_test array built in the
# data-generation cell; later cells depend on this single-curve version.
y_test = learning_curves[test_config].flatten()

# Posterior mean prediction (variance v is unused here; the original also
# computed an unused np.sqrt(v)).
m, v = model.predict(x)
plt.plot(epochs, y_test, color="black", label="True Learning Curve", linewidth=3)

# Individual predictions of each sampled network, drawn faintly.
f, noise = model.predict(x, return_individual_predictions=True)
for fi in f:
    plt.plot(epochs, fi, color="blue", alpha=0.08)

plt.plot(epochs, m, color="red", label="LC-Net", linewidth=3)
plt.legend(loc=4, fontsize=20)
plt.xlabel("Epochs", fontsize=20)
plt.ylabel("Validation Accuracy", fontsize=20)
plt.xlim(1, n_epochs)
# Vertical line marks the last observed epoch of the training prefixes.
plt.axvline(observed_t, linestyle="--", color="grey")
plt.ylim(0, 1)
plt.show()



In [6]:
def _layer_output(net, inputs, drop_last):
    """Evaluate the network truncated `drop_last` layers before the output."""
    layers = lasagne.layers.get_all_layers(net)[:-drop_last]
    return lasagne.layers.get_output(layers, inputs)[-1].eval()


# Columns of the layer 7-from-the-end are the five parametric basis
# functions (ordering assumed from the labels below — confirm against
# get_lc_net).
bf = _layer_output(model.net, x, 7)

plt.plot(epochs, y_test, color="black", label="True Learning Curve")
for k, label in enumerate(["vapor pressure", "pow-func", "log power",
                           "exponential", "hill-3"]):
    plt.plot(epochs, bf[:, k], label=label)

# Presumably the predicted asymptotic value \hat{y}_inf of the curve.
infty = _layer_output(model.net, x, 3)
plt.plot(epochs, infty[:, 0], label=r"$\hat{y}_{\infty}$")

# The mixture mean over the basis functions.
mean = _layer_output(model.net, x, 2)
plt.plot(epochs, mean[:, 0], label="mean")

plt.legend(loc=4, fontsize=20)
plt.xlabel("Epochs", fontsize=20)
plt.ylabel("Validation Accuracy", fontsize=20)
plt.xlim(1, n_epochs)
plt.ylim(-2, 2)
plt.show()



In [7]:
# Mixture weights the network assigns to each basis function for this input
# (layer 6 from the end of the LC-Net architecture).
weight_layers = lasagne.layers.get_all_layers(model.net)[:-6]
weights = lasagne.layers.get_output(weight_layers, x)[-1].eval()

basis_names = ["vapor pressure", "pow-func", "log-power", "exponential", "hill-3"]
positions = np.arange(5)
plt.bar(positions, weights[0])
# The +.4 offset centers labels on left-aligned bars — assumes
# pre-2.0 matplotlib; on newer versions bars are already centered.
plt.xticks(positions + .4, basis_names)
plt.show()