The hardest possible way to fit a gaussian

How well does a neural network work for nonlinear regression? Obviously it is way overkill, but it is interesting to check!


In [1]:
# Imports
from __future__ import print_function
import numpy as np
np.random.seed(1337)  # for reproducibility

from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation
from keras.utils import np_utils
from sklearn.cross_validation import train_test_split
import matplotlib.pyplot as plt

%matplotlib inline


Using Theano backend.

In [45]:
N = 10000
batch_size = N//50
nb_epoch = 100
nb_dense = 512
nb_hidden = 3   # The number of hidden layers to use
p_dropout = 0.5

In [46]:
# Make data: y = N(0, 2)
x = np.random.uniform(-10, 10, N)
X = x[:, np.newaxis]
Y = np.exp(-X**2 / (2*2**2)) + np.random.normal(loc=0, scale=0.1, size=x.size)[:, np.newaxis]
plt.scatter(x, Y[:, 0], alpha=0.1)
#Y = y[:, np.newaxis]

X_train, X_test, Y_train, Y_test = train_test_split(X, Y)
print(X_train.shape, X_test.shape, Y_train.shape, Y_test.shape)


(7500, 1) (2500, 1) (7500, 1) (2500, 1)

In [47]:
print('Building model...')
model = Sequential()
model.add(Dense(nb_dense, input_shape=(X.shape[1],)))
model.add(Activation('tanh'))
model.add(Dropout(p_dropout))
for _ in range(nb_hidden-1):
    model.add(Dense(nb_dense))
    model.add(Activation('tanh'))
    model.add(Dropout(p_dropout))
model.add(Dense(1))
model.add(Activation('linear'))

model.compile(loss='mean_squared_error',
              optimizer='adam',
              metrics=['accuracy'])


Building model...

In [48]:
import time
t1 = time.time()
result = model.fit(X, Y,
                   nb_epoch=nb_epoch, batch_size=batch_size,
                   verbose=1, validation_split=0.1)
t2 = time.time()
print('Model training took {:.2g} minutes'.format((t2-t1)/60))


Train on 9000 samples, validate on 1000 samples
Epoch 1/100
9000/9000 [==============================] - 1s - loss: 1.2129 - acc: 0.0000e+00 - val_loss: 0.1576 - val_acc: 0.0000e+00
Epoch 2/100
9000/9000 [==============================] - 1s - loss: 0.2864 - acc: 0.0000e+00 - val_loss: 0.1335 - val_acc: 0.0000e+00
Epoch 3/100
9000/9000 [==============================] - 1s - loss: 0.2593 - acc: 0.0000e+00 - val_loss: 0.1354 - val_acc: 0.0000e+00
Epoch 4/100
9000/9000 [==============================] - 1s - loss: 0.2429 - acc: 0.0000e+00 - val_loss: 0.1361 - val_acc: 0.0000e+00
Epoch 5/100
9000/9000 [==============================] - 1s - loss: 0.2168 - acc: 0.0000e+00 - val_loss: 0.1366 - val_acc: 0.0000e+00
Epoch 6/100
9000/9000 [==============================] - 1s - loss: 0.2039 - acc: 0.0000e+00 - val_loss: 0.1331 - val_acc: 0.0000e+00
Epoch 7/100
9000/9000 [==============================] - 1s - loss: 0.1879 - acc: 0.0000e+00 - val_loss: 0.1343 - val_acc: 0.0000e+00
Epoch 8/100
9000/9000 [==============================] - 1s - loss: 0.1738 - acc: 0.0000e+00 - val_loss: 0.1338 - val_acc: 0.0000e+00
Epoch 9/100
9000/9000 [==============================] - 1s - loss: 0.1643 - acc: 0.0000e+00 - val_loss: 0.1340 - val_acc: 0.0000e+00
Epoch 10/100
9000/9000 [==============================] - 1s - loss: 0.1615 - acc: 0.0000e+00 - val_loss: 0.1331 - val_acc: 0.0000e+00
Epoch 11/100
9000/9000 [==============================] - 1s - loss: 0.1500 - acc: 0.0000e+00 - val_loss: 0.1325 - val_acc: 0.0000e+00
Epoch 12/100
9000/9000 [==============================] - 1s - loss: 0.1451 - acc: 0.0000e+00 - val_loss: 0.1339 - val_acc: 0.0000e+00
Epoch 13/100
9000/9000 [==============================] - 1s - loss: 0.1414 - acc: 0.0000e+00 - val_loss: 0.1323 - val_acc: 0.0000e+00
Epoch 14/100
9000/9000 [==============================] - 1s - loss: 0.1352 - acc: 0.0000e+00 - val_loss: 0.1316 - val_acc: 0.0000e+00
Epoch 15/100
9000/9000 [==============================] - 1s - loss: 0.1351 - acc: 0.0000e+00 - val_loss: 0.1310 - val_acc: 0.0000e+00
Epoch 16/100
9000/9000 [==============================] - 1s - loss: 0.1306 - acc: 0.0000e+00 - val_loss: 0.1297 - val_acc: 0.0000e+00
Epoch 17/100
9000/9000 [==============================] - 1s - loss: 0.1277 - acc: 0.0000e+00 - val_loss: 0.1259 - val_acc: 0.0000e+00
Epoch 18/100
9000/9000 [==============================] - 1s - loss: 0.1192 - acc: 0.0000e+00 - val_loss: 0.1009 - val_acc: 0.0000e+00
Epoch 19/100
9000/9000 [==============================] - 1s - loss: 0.0994 - acc: 0.0000e+00 - val_loss: 0.0636 - val_acc: 0.0000e+00
Epoch 20/100
9000/9000 [==============================] - 1s - loss: 0.0837 - acc: 0.0000e+00 - val_loss: 0.0438 - val_acc: 0.0000e+00
Epoch 21/100
9000/9000 [==============================] - 1s - loss: 0.0700 - acc: 0.0000e+00 - val_loss: 0.0310 - val_acc: 0.0000e+00
Epoch 22/100
9000/9000 [==============================] - 1s - loss: 0.0606 - acc: 0.0000e+00 - val_loss: 0.0261 - val_acc: 0.0000e+00
Epoch 23/100
9000/9000 [==============================] - 1s - loss: 0.0524 - acc: 0.0000e+00 - val_loss: 0.0217 - val_acc: 0.0000e+00
Epoch 24/100
9000/9000 [==============================] - 1s - loss: 0.0485 - acc: 0.0000e+00 - val_loss: 0.0200 - val_acc: 0.0000e+00
Epoch 25/100
9000/9000 [==============================] - 1s - loss: 0.0438 - acc: 0.0000e+00 - val_loss: 0.0184 - val_acc: 0.0000e+00
Epoch 26/100
9000/9000 [==============================] - 1s - loss: 0.0383 - acc: 0.0000e+00 - val_loss: 0.0158 - val_acc: 0.0000e+00
Epoch 27/100
9000/9000 [==============================] - 1s - loss: 0.0352 - acc: 0.0000e+00 - val_loss: 0.0162 - val_acc: 0.0000e+00
Epoch 28/100
9000/9000 [==============================] - 1s - loss: 0.0324 - acc: 0.0000e+00 - val_loss: 0.0141 - val_acc: 0.0000e+00
Epoch 29/100
9000/9000 [==============================] - 1s - loss: 0.0301 - acc: 0.0000e+00 - val_loss: 0.0141 - val_acc: 0.0000e+00
Epoch 30/100
9000/9000 [==============================] - 1s - loss: 0.0278 - acc: 0.0000e+00 - val_loss: 0.0131 - val_acc: 0.0000e+00
Epoch 31/100
9000/9000 [==============================] - 1s - loss: 0.0264 - acc: 0.0000e+00 - val_loss: 0.0125 - val_acc: 0.0000e+00
Epoch 32/100
9000/9000 [==============================] - 1s - loss: 0.0249 - acc: 0.0000e+00 - val_loss: 0.0128 - val_acc: 0.0000e+00
Epoch 33/100
9000/9000 [==============================] - 1s - loss: 0.0238 - acc: 0.0000e+00 - val_loss: 0.0121 - val_acc: 0.0000e+00
Epoch 34/100
9000/9000 [==============================] - 1s - loss: 0.0223 - acc: 0.0000e+00 - val_loss: 0.0129 - val_acc: 0.0000e+00
Epoch 35/100
9000/9000 [==============================] - 1s - loss: 0.0220 - acc: 0.0000e+00 - val_loss: 0.0119 - val_acc: 0.0000e+00
Epoch 36/100
9000/9000 [==============================] - 1s - loss: 0.0218 - acc: 0.0000e+00 - val_loss: 0.0121 - val_acc: 0.0000e+00
Epoch 37/100
9000/9000 [==============================] - 1s - loss: 0.0201 - acc: 0.0000e+00 - val_loss: 0.0121 - val_acc: 0.0000e+00
Epoch 38/100
9000/9000 [==============================] - 1s - loss: 0.0206 - acc: 0.0000e+00 - val_loss: 0.0118 - val_acc: 0.0000e+00
Epoch 39/100
9000/9000 [==============================] - 1s - loss: 0.0197 - acc: 0.0000e+00 - val_loss: 0.0119 - val_acc: 0.0000e+00
Epoch 40/100
9000/9000 [==============================] - 1s - loss: 0.0201 - acc: 0.0000e+00 - val_loss: 0.0140 - val_acc: 0.0000e+00
Epoch 41/100
9000/9000 [==============================] - 1s - loss: 0.0192 - acc: 0.0000e+00 - val_loss: 0.0120 - val_acc: 0.0000e+00
Epoch 42/100
9000/9000 [==============================] - 1s - loss: 0.0187 - acc: 0.0000e+00 - val_loss: 0.0115 - val_acc: 0.0000e+00
Epoch 43/100
9000/9000 [==============================] - 1s - loss: 0.0185 - acc: 0.0000e+00 - val_loss: 0.0118 - val_acc: 0.0000e+00
Epoch 44/100
9000/9000 [==============================] - 1s - loss: 0.0181 - acc: 0.0000e+00 - val_loss: 0.0126 - val_acc: 0.0000e+00
Epoch 45/100
9000/9000 [==============================] - 1s - loss: 0.0183 - acc: 0.0000e+00 - val_loss: 0.0117 - val_acc: 0.0000e+00
Epoch 46/100
9000/9000 [==============================] - 1s - loss: 0.0179 - acc: 0.0000e+00 - val_loss: 0.0131 - val_acc: 0.0000e+00
Epoch 47/100
9000/9000 [==============================] - 1s - loss: 0.0183 - acc: 0.0000e+00 - val_loss: 0.0128 - val_acc: 0.0000e+00
Epoch 48/100
9000/9000 [==============================] - 1s - loss: 0.0180 - acc: 0.0000e+00 - val_loss: 0.0115 - val_acc: 0.0000e+00
Epoch 49/100
9000/9000 [==============================] - 1s - loss: 0.0173 - acc: 0.0000e+00 - val_loss: 0.0119 - val_acc: 0.0000e+00
Epoch 50/100
9000/9000 [==============================] - 1s - loss: 0.0179 - acc: 0.0000e+00 - val_loss: 0.0114 - val_acc: 0.0000e+00
Epoch 51/100
9000/9000 [==============================] - 1s - loss: 0.0169 - acc: 0.0000e+00 - val_loss: 0.0120 - val_acc: 0.0000e+00
Epoch 52/100
9000/9000 [==============================] - 1s - loss: 0.0174 - acc: 0.0000e+00 - val_loss: 0.0115 - val_acc: 0.0000e+00
Epoch 53/100
9000/9000 [==============================] - 1s - loss: 0.0169 - acc: 0.0000e+00 - val_loss: 0.0115 - val_acc: 0.0000e+00
Epoch 54/100
9000/9000 [==============================] - 1s - loss: 0.0173 - acc: 0.0000e+00 - val_loss: 0.0115 - val_acc: 0.0000e+00
Epoch 55/100
9000/9000 [==============================] - 1s - loss: 0.0168 - acc: 0.0000e+00 - val_loss: 0.0124 - val_acc: 0.0000e+00
Epoch 56/100
9000/9000 [==============================] - 1s - loss: 0.0170 - acc: 0.0000e+00 - val_loss: 0.0116 - val_acc: 0.0000e+00
Epoch 57/100
9000/9000 [==============================] - 1s - loss: 0.0168 - acc: 0.0000e+00 - val_loss: 0.0115 - val_acc: 0.0000e+00
Epoch 58/100
9000/9000 [==============================] - 1s - loss: 0.0164 - acc: 0.0000e+00 - val_loss: 0.0121 - val_acc: 0.0000e+00
Epoch 59/100
9000/9000 [==============================] - 1s - loss: 0.0163 - acc: 0.0000e+00 - val_loss: 0.0117 - val_acc: 0.0000e+00
Epoch 60/100
9000/9000 [==============================] - 1s - loss: 0.0164 - acc: 0.0000e+00 - val_loss: 0.0115 - val_acc: 0.0000e+00
Epoch 61/100
9000/9000 [==============================] - 1s - loss: 0.0160 - acc: 0.0000e+00 - val_loss: 0.0117 - val_acc: 0.0000e+00
Epoch 62/100
9000/9000 [==============================] - 1s - loss: 0.0163 - acc: 0.0000e+00 - val_loss: 0.0116 - val_acc: 0.0000e+00
Epoch 63/100
9000/9000 [==============================] - 1s - loss: 0.0165 - acc: 0.0000e+00 - val_loss: 0.0123 - val_acc: 0.0000e+00
Epoch 64/100
9000/9000 [==============================] - 1s - loss: 0.0170 - acc: 0.0000e+00 - val_loss: 0.0114 - val_acc: 0.0000e+00
Epoch 65/100
9000/9000 [==============================] - 1s - loss: 0.0165 - acc: 0.0000e+00 - val_loss: 0.0119 - val_acc: 0.0000e+00
Epoch 66/100
9000/9000 [==============================] - 1s - loss: 0.0159 - acc: 0.0000e+00 - val_loss: 0.0118 - val_acc: 0.0000e+00
Epoch 67/100
9000/9000 [==============================] - 1s - loss: 0.0158 - acc: 0.0000e+00 - val_loss: 0.0113 - val_acc: 0.0000e+00
Epoch 68/100
9000/9000 [==============================] - 1s - loss: 0.0158 - acc: 0.0000e+00 - val_loss: 0.0132 - val_acc: 0.0000e+00
Epoch 69/100
9000/9000 [==============================] - 1s - loss: 0.0166 - acc: 0.0000e+00 - val_loss: 0.0115 - val_acc: 0.0000e+00
Epoch 70/100
9000/9000 [==============================] - 1s - loss: 0.0166 - acc: 0.0000e+00 - val_loss: 0.0120 - val_acc: 0.0000e+00
Epoch 71/100
9000/9000 [==============================] - 1s - loss: 0.0164 - acc: 0.0000e+00 - val_loss: 0.0117 - val_acc: 0.0000e+00
Epoch 72/100
9000/9000 [==============================] - 1s - loss: 0.0162 - acc: 0.0000e+00 - val_loss: 0.0129 - val_acc: 0.0000e+00
Epoch 73/100
9000/9000 [==============================] - 1s - loss: 0.0164 - acc: 0.0000e+00 - val_loss: 0.0131 - val_acc: 0.0000e+00
Epoch 74/100
9000/9000 [==============================] - 1s - loss: 0.0165 - acc: 0.0000e+00 - val_loss: 0.0112 - val_acc: 0.0000e+00
Epoch 75/100
9000/9000 [==============================] - 1s - loss: 0.0157 - acc: 0.0000e+00 - val_loss: 0.0115 - val_acc: 0.0000e+00
Epoch 76/100
9000/9000 [==============================] - 1s - loss: 0.0156 - acc: 0.0000e+00 - val_loss: 0.0113 - val_acc: 0.0000e+00
Epoch 77/100
9000/9000 [==============================] - 1s - loss: 0.0154 - acc: 0.0000e+00 - val_loss: 0.0112 - val_acc: 0.0000e+00
Epoch 78/100
9000/9000 [==============================] - 1s - loss: 0.0162 - acc: 0.0000e+00 - val_loss: 0.0117 - val_acc: 0.0000e+00
Epoch 79/100
9000/9000 [==============================] - 1s - loss: 0.0159 - acc: 0.0000e+00 - val_loss: 0.0114 - val_acc: 0.0000e+00
Epoch 80/100
9000/9000 [==============================] - 1s - loss: 0.0157 - acc: 0.0000e+00 - val_loss: 0.0126 - val_acc: 0.0000e+00
Epoch 81/100
9000/9000 [==============================] - 1s - loss: 0.0158 - acc: 0.0000e+00 - val_loss: 0.0123 - val_acc: 0.0000e+00
Epoch 82/100
9000/9000 [==============================] - 1s - loss: 0.0154 - acc: 0.0000e+00 - val_loss: 0.0114 - val_acc: 0.0000e+00
Epoch 83/100
9000/9000 [==============================] - 1s - loss: 0.0152 - acc: 0.0000e+00 - val_loss: 0.0113 - val_acc: 0.0000e+00
Epoch 84/100
9000/9000 [==============================] - 1s - loss: 0.0153 - acc: 0.0000e+00 - val_loss: 0.0114 - val_acc: 0.0000e+00
Epoch 85/100
9000/9000 [==============================] - 1s - loss: 0.0151 - acc: 0.0000e+00 - val_loss: 0.0138 - val_acc: 0.0000e+00
Epoch 86/100
9000/9000 [==============================] - 1s - loss: 0.0161 - acc: 0.0000e+00 - val_loss: 0.0137 - val_acc: 0.0000e+00
Epoch 87/100
9000/9000 [==============================] - 1s - loss: 0.0154 - acc: 0.0000e+00 - val_loss: 0.0126 - val_acc: 0.0000e+00
Epoch 88/100
9000/9000 [==============================] - 1s - loss: 0.0160 - acc: 0.0000e+00 - val_loss: 0.0123 - val_acc: 0.0000e+00
Epoch 89/100
9000/9000 [==============================] - 1s - loss: 0.0150 - acc: 0.0000e+00 - val_loss: 0.0115 - val_acc: 0.0000e+00
Epoch 90/100
9000/9000 [==============================] - 1s - loss: 0.0155 - acc: 0.0000e+00 - val_loss: 0.0114 - val_acc: 0.0000e+00
Epoch 91/100
9000/9000 [==============================] - 1s - loss: 0.0160 - acc: 0.0000e+00 - val_loss: 0.0114 - val_acc: 0.0000e+00
Epoch 92/100
9000/9000 [==============================] - 1s - loss: 0.0158 - acc: 0.0000e+00 - val_loss: 0.0112 - val_acc: 0.0000e+00
Epoch 93/100
9000/9000 [==============================] - 1s - loss: 0.0153 - acc: 0.0000e+00 - val_loss: 0.0115 - val_acc: 0.0000e+00
Epoch 94/100
9000/9000 [==============================] - 1s - loss: 0.0151 - acc: 0.0000e+00 - val_loss: 0.0113 - val_acc: 0.0000e+00
Epoch 95/100
9000/9000 [==============================] - 1s - loss: 0.0150 - acc: 0.0000e+00 - val_loss: 0.0112 - val_acc: 0.0000e+00
Epoch 96/100
9000/9000 [==============================] - 1s - loss: 0.0150 - acc: 0.0000e+00 - val_loss: 0.0115 - val_acc: 0.0000e+00
Epoch 97/100
9000/9000 [==============================] - 1s - loss: 0.0149 - acc: 0.0000e+00 - val_loss: 0.0112 - val_acc: 0.0000e+00
Epoch 98/100
9000/9000 [==============================] - 1s - loss: 0.0153 - acc: 0.0000e+00 - val_loss: 0.0122 - val_acc: 0.0000e+00
Epoch 99/100
9000/9000 [==============================] - 1s - loss: 0.0161 - acc: 0.0000e+00 - val_loss: 0.0123 - val_acc: 0.0000e+00
Epoch 100/100
9000/9000 [==============================] - 1s - loss: 0.0154 - acc: 0.0000e+00 - val_loss: 0.0124 - val_acc: 0.0000e+00
Model training took 2 minutes

In [49]:
xplot = np.linspace(np.min(X[:, 0]), np.max(X[:, 0]), 100)
Xplot = xplot[:, np.newaxis]
Yplot = model.predict(Xplot)

plt.scatter(X, Y, alpha=0.03)
plt.plot(xplot, Yplot[:, 0], 'r--', lw=3)
#plt.plot(xplot, 3*xplot - xplot**2, 'g-.', lw=2)

#plt.ylim((-0.5, 2))


Out[49]:
[<matplotlib.lines.Line2D at 0x149656550>]

In [ ]: