two_layer_net


Implementing a Neural Network

In this exercise we will develop a neural network with fully-connected layers to perform classification, and test it out on the CIFAR-10 dataset.


In [2]:
# A bit of setup

import numpy as np
import matplotlib.pyplot as plt

from cs231n.classifiers.neural_net import TwoLayerNet

%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'

# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

def rel_error(x, y):
  """ returns relative error """
  return np.max(np.abs(x - y) / (np.maximum(1e-8, np.abs(x) + np.abs(y))))

We will use the class TwoLayerNet in the file cs231n/classifiers/neural_net.py to represent instances of our network. The network parameters are stored in the instance variable self.params, where keys are string parameter names and values are numpy arrays. Below, we initialize toy data and a toy model that you will use to develop your implementation.


In [2]:
# Create a small net and some toy data to check your implementations.
# Note that we set the random seed for repeatable experiments.

input_size = 4
hidden_size = 10
num_classes = 3
num_inputs = 5

def init_toy_model():
  np.random.seed(0)
  return TwoLayerNet(input_size, hidden_size, num_classes, std=1e-1)

def init_toy_data():
  np.random.seed(1)
  X = 10 * np.random.randn(num_inputs, input_size)
  y = np.array([0, 1, 2, 2, 1])
  return X, y

net = init_toy_model()
X, y = init_toy_data()
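
As a rough picture of the parameter dictionary described above, the sketch below builds one by hand. The function name init_params and the initialization scheme (small Gaussian weights, zero biases) are assumptions for illustration; check neural_net.py for what TwoLayerNet actually does.

```python
import numpy as np

def init_params(input_size, hidden_size, output_size, std=1e-4, seed=0):
    """Hypothetical stand-in for the parameter setup inside TwoLayerNet."""
    rng = np.random.RandomState(seed)
    return {
        'W1': std * rng.randn(input_size, hidden_size),   # first-layer weights
        'b1': np.zeros(hidden_size),                      # first-layer biases
        'W2': std * rng.randn(hidden_size, output_size),  # second-layer weights
        'b2': np.zeros(output_size),                      # second-layer biases
    }

# Same sizes as the toy model: 4 inputs, 10 hidden units, 3 classes.
params = init_params(input_size=4, hidden_size=10, output_size=3, std=1e-1)
shapes = {name: value.shape for name, value in params.items()}
```

Each key maps to a plain numpy array, so a parameter can be inspected or perturbed with ordinary indexing, which is what the gradient check later relies on.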

Forward pass: compute scores

Open the file cs231n/classifiers/neural_net.py and look at the method TwoLayerNet.loss. This function is very similar to the loss functions you have written for the SVM and Softmax exercises: it takes the data and weights and computes the class scores, the loss, and the gradients on the parameters.

Implement the first part of the forward pass which uses the weights and biases to compute the scores for all inputs.
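
One possible shape for this step is sketched below. It assumes the architecture is affine, ReLU, affine; the ReLU nonlinearity and the helper name forward_scores are assumptions, not taken from this notebook.

```python
import numpy as np

def forward_scores(X, params):
    """Class scores for a two-layer net: affine -> ReLU -> affine (assumed)."""
    hidden = np.maximum(0, X.dot(params['W1']) + params['b1'])  # ReLU hidden layer
    return hidden.dot(params['W2']) + params['b2']              # one row of scores per input

rng = np.random.RandomState(0)
toy_params = {
    'W1': 0.1 * rng.randn(4, 10), 'b1': np.zeros(10),
    'W2': 0.1 * rng.randn(10, 3), 'b2': np.zeros(3),
}
toy_X = 10 * rng.randn(5, 4)
toy_scores = forward_scores(toy_X, toy_params)  # shape (num_inputs, num_classes)
```

The random values here are not the ones produced by init_toy_model, so these scores are not expected to match the reference scores in the cell below.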


In [5]:
scores = net.loss(X)
print 'Your scores:'
print scores
print
print 'correct scores:'
correct_scores = np.asarray([
  [-0.81233741, -1.27654624, -0.70335995],
  [-0.17129677, -1.18803311, -0.47310444],
  [-0.51590475, -1.01354314, -0.8504215 ],
  [-0.15419291, -0.48629638, -0.52901952],
  [-0.00618733, -0.12435261, -0.15226949]])
print correct_scores
print

# The difference should be very small. We get < 1e-7
print 'Difference between your scores and correct scores:'
print np.sum(np.abs(scores - correct_scores))


Your scores:
[[-0.81233741 -1.27654624 -0.70335995]
 [-0.17129677 -1.18803311 -0.47310444]
 [-0.51590475 -1.01354314 -0.8504215 ]
 [-0.15419291 -0.48629638 -0.52901952]
 [-0.00618733 -0.12435261 -0.15226949]]

correct scores:
[[-0.81233741 -1.27654624 -0.70335995]
 [-0.17129677 -1.18803311 -0.47310444]
 [-0.51590475 -1.01354314 -0.8504215 ]
 [-0.15419291 -0.48629638 -0.52901952]
 [-0.00618733 -0.12435261 -0.15226949]]

Difference between your scores and correct scores:
3.68027209324e-08

Forward pass: compute loss

In the same function, implement the second part that computes the data and regularization loss.
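
A minimal sketch of such a loss, assuming a softmax cross-entropy data term and an L2 penalty on the weight matrices. The 0.5 factor on the penalty and the helper name softmax_loss are assumptions; the reference loss in this notebook may use a different convention.

```python
import numpy as np

def softmax_loss(scores, y, weights, reg):
    """Mean cross-entropy plus an L2 penalty (the 0.5 factor is an assumed convention)."""
    shifted = scores - scores.max(axis=1, keepdims=True)   # subtract row max for stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    data_loss = -log_probs[np.arange(len(y)), y].mean()    # average over the batch
    reg_loss = 0.5 * reg * sum(np.sum(W * W) for W in weights)
    return data_loss + reg_loss

# With all-zero scores every class has probability 1/3, so the data loss is log(3).
uniform_loss = softmax_loss(np.zeros((5, 3)), np.array([0, 1, 2, 2, 1]),
                            [np.ones((2, 2))], reg=0.1)
```

The all-zero check is a handy sanity test: a freshly initialized network with tiny weights should report a data loss close to the log of the number of classes.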


In [6]:
loss, _ = net.loss(X, y, reg=0.1)
correct_loss = 1.30378789133

# should be very small, we get < 1e-12
print 'Difference between your loss and correct loss:'
print np.sum(np.abs(loss - correct_loss))


Difference between your loss and correct loss:
1.79412040779e-13

Backward pass

Implement the rest of the function. This will compute the gradient of the loss with respect to the variables W1, b1, W2, and b2. Now that you (hopefully!) have a correctly implemented forward pass, you can debug your backward pass using a numeric gradient check:
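
For intuition about what a numeric gradient check does, here is a small centered-difference sketch. The helper numerical_gradient is a hypothetical stand-in, not the course's eval_numerical_gradient, and the quadratic test function is chosen only because its gradient is known exactly.

```python
import numpy as np

def numerical_gradient(f, x, h=1e-5):
    """Centered-difference estimate of df/dx; x is perturbed in place and restored."""
    grad = np.zeros_like(x)
    for ix in np.ndindex(*x.shape):
        old = x[ix]
        x[ix] = old + h
        f_plus = f(x)              # f evaluated with one entry nudged up
        x[ix] = old - h
        f_minus = f(x)             # f evaluated with the same entry nudged down
        x[ix] = old                # restore the original value
        grad[ix] = (f_plus - f_minus) / (2 * h)
    return grad

# f(W) = 0.5 * sum(W^2) has the analytic gradient W.
W = np.random.RandomState(0).randn(3, 4)
num_grad = numerical_gradient(lambda w: 0.5 * np.sum(w ** 2), W)
max_rel_error = np.max(np.abs(num_grad - W) /
                       np.maximum(1e-8, np.abs(num_grad) + np.abs(W)))
```

Because the array is modified in place, the function being differentiated can ignore its argument and read the parameter from wherever it is stored, which is why the lambda in the cell below never uses W.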


In [15]:
from cs231n.gradient_check import eval_numerical_gradient

# Use numeric gradient checking to check your implementation of the backward pass.
# If your implementation is correct, the difference between the numeric and
# analytic gradients should be less than 1e-8 for each of W1, W2, b1, and b2.

loss, grads = net.loss(X, y, reg=0.1)

# these should all be less than 1e-8 or so
for param_name in grads:
  f = lambda W: net.loss(X, y, reg=0.1)[0]
  param_grad_num = eval_numerical_gradient(f, net.params[param_name], verbose=False)
  print '%s max relative error: %e' % (param_name, rel_error(param_grad_num, grads[param_name]))


W1 max relative error: 3.561318e-09
W2 max relative error: 3.440708e-09
b2 max relative error: 4.447646e-11
b1 max relative error: 2.738421e-09

Train the network

To train the network we will use stochastic gradient descent (SGD), similar to the SVM and Softmax classifiers. Look at the function TwoLayerNet.train and fill in the missing sections to implement the training procedure. This should be very similar to the training procedure you used for the SVM and Softmax classifiers. You will also have to implement TwoLayerNet.predict, as the training process periodically performs prediction to keep track of accuracy over time while the network trains.

Once you have implemented the method, run the code below to train a two-layer network on toy data. You should achieve a training loss less than 0.2.
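
The overall shape of a minibatch SGD loop can be sketched as follows. This is not TwoLayerNet.train: the function names, the sampling with replacement, and the least-squares toy objective are all assumptions chosen to keep the example self-contained.

```python
import numpy as np

def sgd(loss_and_grad, w, X, y, learning_rate=0.1, num_iters=1000, batch_size=5, seed=0):
    """Vanilla minibatch SGD; returns the final weights and the per-iteration loss."""
    rng = np.random.RandomState(seed)
    loss_history = []
    for _ in range(num_iters):
        batch = rng.choice(X.shape[0], batch_size, replace=True)  # sample a minibatch
        loss, grad = loss_and_grad(w, X[batch], y[batch])
        loss_history.append(loss)
        w = w - learning_rate * grad                              # gradient step
    return w, loss_history

def least_squares(w, X, y):
    """Toy objective (not the network loss): half mean squared residual and its gradient."""
    residual = X.dot(w) - y
    return 0.5 * np.mean(residual ** 2), X.T.dot(residual) / len(y)

rng = np.random.RandomState(1)
X_toy = rng.randn(20, 3)
w_true = np.array([1.0, -2.0, 0.5])
w_fit, history = sgd(least_squares, np.zeros(3), X_toy, X_toy.dot(w_true))
```

The network version follows the same pattern, with the loss and gradients coming from the loss method and one update per entry in the parameter dictionary.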


In [18]:
net = init_toy_model()
stats = net.train(X, y, X, y,
            learning_rate=1e-1, reg=1e-5,
            num_iters=100, verbose=False)

print 'Final training loss: ', stats['loss_history'][-1]

# plot the loss history
plt.plot(stats['loss_history'])
plt.xlabel('iteration')
plt.ylabel('training loss')
plt.title('Training Loss history')
plt.show()


Final training loss:  0.0171496079387

Load the data

Now that you have implemented a two-layer network that passes gradient checks and works on toy data, it's time to load up our favorite CIFAR-10 data so we can use it to train a classifier on a real dataset.


In [19]:
from cs231n.data_utils import load_CIFAR10

def get_CIFAR10_data(num_training=49000, num_validation=1000, num_test=1000):
    """
    Load the CIFAR-10 dataset from disk and perform preprocessing to prepare
    it for the two-layer neural net classifier. These are the same steps as
    we used for the SVM, but condensed to a single function.  
    """
    # Load the raw CIFAR-10 data
    cifar10_dir = 'cs231n/datasets/cifar-10-batches-py'
    X_train, y_train, X_test, y_test = load_CIFAR10(cifar10_dir)
        
    # Subsample the data
    mask = range(num_training, num_training + num_validation)
    X_val = X_train[mask]
    y_val = y_train[mask]
    mask = range(num_training)
    X_train = X_train[mask]
    y_train = y_train[mask]
    mask = range(num_test)
    X_test = X_test[mask]
    y_test = y_test[mask]

    # Normalize the data: subtract the mean image
    mean_image = np.mean(X_train, axis=0)
    X_train -= mean_image
    X_val -= mean_image
    X_test -= mean_image

    # Reshape data to rows
    X_train = X_train.reshape(num_training, -1)
    X_val = X_val.reshape(num_validation, -1)
    X_test = X_test.reshape(num_test, -1)

    return X_train, y_train, X_val, y_val, X_test, y_test


# Invoke the above function to get our data.
X_train, y_train, X_val, y_val, X_test, y_test = get_CIFAR10_data()
print 'Train data shape: ', X_train.shape
print 'Train labels shape: ', y_train.shape
print 'Validation data shape: ', X_val.shape
print 'Validation labels shape: ', y_val.shape
print 'Test data shape: ', X_test.shape
print 'Test labels shape: ', y_test.shape


Train data shape:  (49000, 3072)
Train labels shape:  (49000,)
Validation data shape:  (1000, 3072)
Validation labels shape:  (1000,)
Test data shape:  (1000, 3072)
Test labels shape:  (1000,)
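
The two preprocessing steps in get_CIFAR10_data can be tried on synthetic arrays without loading the dataset. The array sizes and variable names below are made up for illustration.

```python
import numpy as np

rng = np.random.RandomState(0)
fake_train = rng.uniform(0, 255, size=(8, 32, 32, 3))  # stand-in for training images
fake_val = rng.uniform(0, 255, size=(2, 32, 32, 3))    # stand-in for validation images

# The mean image is computed from the training split only and reused for the others.
mean_image = fake_train.mean(axis=0)
train_rows = (fake_train - mean_image).reshape(fake_train.shape[0], -1)
val_rows = (fake_val - mean_image).reshape(fake_val.shape[0], -1)
```

After the reshape each image is one row of 32 * 32 * 3 = 3072 values, matching the shapes printed above.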

Train a network

To train our network we will use SGD with momentum. In addition, we will adjust the learning rate with an exponential learning rate schedule as optimization proceeds; after each epoch, we will reduce the learning rate by multiplying it by a decay rate.
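
The update rule and schedule described above can be sketched on a toy problem. The function names, the momentum coefficient of 0.9, the epoch length, and the quadratic objective are assumptions for illustration; the course code may organize this differently.

```python
import numpy as np

def momentum_step(w, grad, velocity, learning_rate, momentum=0.9):
    """One SGD-with-momentum update; returns the new weights and velocity."""
    velocity = momentum * velocity - learning_rate * grad
    return w + velocity, velocity

def decayed_rate(initial_rate, decay, epoch):
    """Exponential schedule: multiply by `decay` once per completed epoch."""
    return initial_rate * decay ** epoch

# Toy objective (an assumption): f(w) = 0.5 * ||w||^2, whose gradient is w itself.
w = np.array([1.0, -1.0])
velocity = np.zeros_like(w)
iters_per_epoch = 10
for it in range(300):
    rate = decayed_rate(0.1, 0.95, it // iters_per_epoch)
    w, velocity = momentum_step(w, w, velocity, rate)
final_norm = np.linalg.norm(w)
```

With a decay of 0.95, the learning rate after two epochs is 0.95 * 0.95 = 0.9025 times its starting value.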


In [ ]:
input_size = 32 * 32 * 3
hidden_size = 50
num_classes = 10
net = TwoLayerNet(input_size, hidden_size, num_classes)

# Train the network
stats = net.train(X_train, y_train, X_val, y_val,
            num_iters=1000, batch_size=200,
            learning_rate=1e-4, learning_rate_decay=0.95,
            reg=0.5, verbose=True)

In [21]:
# Predict on the validation set
val_acc = (net.predict(X_val) == y_val).mean()
print 'Validation accuracy: ', val_acc


Validation accuracy:  0.287

Debug the training

With the default parameters we provided above, you should get an accuracy of about 0.29 on the validation set. This isn't very good.

One strategy for getting insight into what's wrong is to plot the loss function and the accuracies on the training and validation sets during optimization.

Another strategy is to visualize the weights that were learned in the first layer of the network. In most neural networks trained on visual data, the first layer weights typically show some visible structure when visualized.


In [22]:
# Plot the loss function and train / validation accuracies
plt.subplot(2, 1, 1)
plt.plot(stats['loss_history'])
plt.title('Loss history')
plt.xlabel('Iteration')
plt.ylabel('Loss')

plt.subplot(2, 1, 2)
plt.plot(stats['train_acc_history'], label='train')
plt.plot(stats['val_acc_history'], label='val')
plt.title('Classification accuracy history')
plt.xlabel('Epoch')
plt.ylabel('Classification accuracy')
plt.legend()  # show the 'train' / 'val' labels set above
plt.show()



In [23]:
from cs231n.vis_utils import visualize_grid

# Visualize the weights of the network

def show_net_weights(net):
  W1 = net.params['W1']
  W1 = W1.reshape(32, 32, 3, -1).transpose(3, 0, 1, 2)
  plt.imshow(visualize_grid(W1, padding=3).astype('uint8'))
  plt.gca().axis('off')
  plt.show()

show_net_weights(net)
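
To see what the reshape and transpose in show_net_weights do, the sketch below applies them to a random matrix of the same shape as W1. The min-max rescaling to 0..255 is an assumption added so the result can be cast to uint8; it is not necessarily what visualize_grid does.

```python
import numpy as np

hidden_size = 50
W1 = np.random.RandomState(0).randn(32 * 32 * 3, hidden_size)

# Each column of W1 holds the incoming weights of one hidden unit; turn it into an image.
filters = W1.reshape(32, 32, 3, -1).transpose(3, 0, 1, 2)  # (hidden, height, width, channel)

# Rescale all weights to the 0..255 range for display (assumed step).
low, high = filters.min(), filters.max()
images = (255.0 * (filters - low) / (high - low)).astype('uint8')
```

After the transpose, filters[j] is exactly column j of W1 laid out as a 32 x 32 x 3 image, so each tile in the grid corresponds to one hidden unit.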


Tune your hyperparameters

What's wrong? Looking at the visualizations above, we see that the loss is decreasing more or less linearly, which suggests that the learning rate may be too low. Moreover, there is no gap between the training and validation accuracy, suggesting that the model we used has low capacity and that we should increase its size. On the other hand, with a very large model we would expect to see more overfitting, which would manifest itself as a very large gap between the training and validation accuracy.

Tuning. Tuning the hyperparameters and developing intuition for how they affect the final performance is a large part of using neural networks, so we want you to get a lot of practice. Below, you should experiment with different values of the various hyperparameters, including hidden layer size, learning rate, number of training epochs, and regularization strength. You might also consider tuning the learning rate decay, but you should be able to get good performance using the default value.

Approximate results. You should aim to achieve a classification accuracy of greater than 48% on the validation set. Our best network gets over 52% on the validation set.

Experiment: Your goal in this exercise is to get as good a result on CIFAR-10 as you can with a fully-connected neural network. For every 1% above 52% on the test set we will award you one extra bonus point. Feel free to implement your own techniques (e.g. PCA to reduce dimensionality, adding dropout, or adding features to the solver).
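
One common way to organize such experiments is random search. The sketch below uses the same ranges as the search cell that follows, but samples the learning rate log-uniformly rather than uniformly, so that small rates are tried as often as large ones; that change, the function names, and the toy scoring function are all assumptions. In practice evaluate would train a network and return its validation accuracy.

```python
import numpy as np

def sample_config(rng):
    """Draw one hypothetical hyperparameter configuration."""
    return {
        'learning_rate': 10 ** rng.uniform(-5, -3),       # log-uniform in [1e-5, 1e-3]
        'reg': rng.uniform(0, 1),
        'learning_rate_decay': rng.uniform(0.9, 1.0),
        'hidden_size': int(rng.choice([50, 100, 150, 200, 250])),
    }

def random_search(evaluate, num_trials, seed=0):
    """Evaluate random configurations and return the best score, its config, and all trials."""
    rng = np.random.RandomState(seed)
    trials = []
    for _ in range(num_trials):
        config = sample_config(rng)
        trials.append((evaluate(config), config))
    best_score, best_config = max(trials, key=lambda pair: pair[0])
    return best_score, best_config, trials

def toy_score(config):
    """Stand-in for validation accuracy: prefers learning rates near 1e-3."""
    return -abs(np.log10(config['learning_rate']) + 3.0)

best_score, best_config, trials = random_search(toy_score, num_trials=20)
```

Keeping every trial, not only the best one, makes it possible to look for trends afterwards, such as how accuracy varies with the learning rate.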


In [ ]:
best_net = None # store the best model into this

learning = [1e-5, 1e-3]      # learning rate range
regularization = [0, 1]      # regularization strength range
decay = [0.9, 1]             # learning rate decay range

results = {}
best_val = -1

for num_hidden in np.arange(50, 300, 50):
    for _ in np.arange(0, 50):
        # Sample one random configuration uniformly from each range
        i = np.random.uniform(low=learning[0], high=learning[1])
        j = np.random.uniform(low=regularization[0], high=regularization[1])
        k = np.random.uniform(low=decay[0], high=decay[1])

        # Train the network
        net = TwoLayerNet(input_size, num_hidden, num_classes)
        stats = net.train(X_train, y_train, X_val, y_val,
                          num_iters=500, batch_size=200,
                          learning_rate=i, learning_rate_decay=k,
                          reg=j, verbose=False)

        # Predict on the validation set
        val_acc = (net.predict(X_val) == y_val).mean()

        # Record the result and keep the best model seen so far
        results[(num_hidden, i, j, k)] = val_acc
        if val_acc > best_val:
            best_val = val_acc
            best_net = net

In [29]:
# Print the obtained accuracies
for nh, lr, reg, dec in sorted(results):
    print 'Hidden: %d, learning rate: %f, regularisation: %f, decay: %f -> %f' % ( \
    nh, lr, reg, dec, results[nh, lr, reg, dec])


Hidden: 50, learning rate: 0.000011, regularisation: 0.363304, decay: 0.989335 -> 0.191000
Hidden: 50, learning rate: 0.000014, regularisation: 0.442092, decay: 0.911811 -> 0.191000
Hidden: 50, learning rate: 0.000016, regularisation: 0.101872, decay: 0.964965 -> 0.226000
Hidden: 50, learning rate: 0.000085, regularisation: 0.371178, decay: 0.935648 -> 0.204000
Hidden: 50, learning rate: 0.000086, regularisation: 0.298424, decay: 0.918874 -> 0.192000
Hidden: 50, learning rate: 0.000123, regularisation: 0.621787, decay: 0.939641 -> 0.245000
Hidden: 50, learning rate: 0.000135, regularisation: 0.781503, decay: 0.916323 -> 0.247000
Hidden: 50, learning rate: 0.000137, regularisation: 0.245870, decay: 0.932617 -> 0.249000
Hidden: 50, learning rate: 0.000148, regularisation: 0.656085, decay: 0.961230 -> 0.260000
Hidden: 50, learning rate: 0.000183, regularisation: 0.959219, decay: 0.921856 -> 0.269000
Hidden: 50, learning rate: 0.000189, regularisation: 0.430128, decay: 0.910072 -> 0.275000
Hidden: 50, learning rate: 0.000195, regularisation: 0.991322, decay: 0.991741 -> 0.297000
Hidden: 50, learning rate: 0.000200, regularisation: 0.482552, decay: 0.949017 -> 0.285000
Hidden: 50, learning rate: 0.000204, regularisation: 0.361121, decay: 0.932783 -> 0.288000
Hidden: 50, learning rate: 0.000233, regularisation: 0.624810, decay: 0.927317 -> 0.301000
Hidden: 50, learning rate: 0.000267, regularisation: 0.961234, decay: 0.973184 -> 0.331000
Hidden: 50, learning rate: 0.000345, regularisation: 0.272308, decay: 0.932819 -> 0.359000
Hidden: 50, learning rate: 0.000359, regularisation: 0.109756, decay: 0.937263 -> 0.357000
Hidden: 50, learning rate: 0.000373, regularisation: 0.364964, decay: 0.998249 -> 0.365000
Hidden: 50, learning rate: 0.000375, regularisation: 0.783115, decay: 0.934257 -> 0.368000
Hidden: 50, learning rate: 0.000377, regularisation: 0.285748, decay: 0.924228 -> 0.358000
Hidden: 50, learning rate: 0.000459, regularisation: 0.217852, decay: 0.954128 -> 0.379000
Hidden: 50, learning rate: 0.000469, regularisation: 0.233909, decay: 0.949225 -> 0.389000
Hidden: 50, learning rate: 0.000473, regularisation: 0.796871, decay: 0.991895 -> 0.396000
Hidden: 50, learning rate: 0.000474, regularisation: 0.195203, decay: 0.941541 -> 0.389000
Hidden: 50, learning rate: 0.000503, regularisation: 0.959743, decay: 0.993361 -> 0.404000
Hidden: 50, learning rate: 0.000576, regularisation: 0.047176, decay: 0.931908 -> 0.409000
Hidden: 50, learning rate: 0.000578, regularisation: 0.542868, decay: 0.938592 -> 0.398000
Hidden: 50, learning rate: 0.000581, regularisation: 0.425739, decay: 0.947874 -> 0.409000
Hidden: 50, learning rate: 0.000604, regularisation: 0.742574, decay: 0.930685 -> 0.405000
Hidden: 50, learning rate: 0.000614, regularisation: 0.586256, decay: 0.908312 -> 0.413000
Hidden: 50, learning rate: 0.000623, regularisation: 0.890790, decay: 0.940895 -> 0.412000
Hidden: 50, learning rate: 0.000627, regularisation: 0.998568, decay: 0.948984 -> 0.412000
Hidden: 50, learning rate: 0.000655, regularisation: 0.578601, decay: 0.969834 -> 0.425000
Hidden: 50, learning rate: 0.000667, regularisation: 0.886517, decay: 0.908658 -> 0.404000
Hidden: 50, learning rate: 0.000700, regularisation: 0.052630, decay: 0.976770 -> 0.421000
Hidden: 50, learning rate: 0.000710, regularisation: 0.297222, decay: 0.911363 -> 0.415000
Hidden: 50, learning rate: 0.000730, regularisation: 0.224748, decay: 0.969152 -> 0.411000
Hidden: 50, learning rate: 0.000739, regularisation: 0.092852, decay: 0.934343 -> 0.428000
Hidden: 50, learning rate: 0.000748, regularisation: 0.422783, decay: 0.905761 -> 0.431000
Hidden: 50, learning rate: 0.000749, regularisation: 0.763117, decay: 0.909301 -> 0.431000
Hidden: 50, learning rate: 0.000811, regularisation: 0.138834, decay: 0.947758 -> 0.440000
Hidden: 50, learning rate: 0.000870, regularisation: 0.267232, decay: 0.906252 -> 0.435000
Hidden: 50, learning rate: 0.000887, regularisation: 0.456159, decay: 0.927681 -> 0.434000
Hidden: 50, learning rate: 0.000897, regularisation: 0.652603, decay: 0.918217 -> 0.437000
Hidden: 50, learning rate: 0.000900, regularisation: 0.169740, decay: 0.954826 -> 0.446000
Hidden: 50, learning rate: 0.000939, regularisation: 0.601901, decay: 0.928653 -> 0.428000
Hidden: 50, learning rate: 0.000943, regularisation: 0.867811, decay: 0.971636 -> 0.452000
Hidden: 50, learning rate: 0.000945, regularisation: 0.103825, decay: 0.910472 -> 0.450000
Hidden: 50, learning rate: 0.000961, regularisation: 0.337382, decay: 0.953312 -> 0.442000
Hidden: 100, learning rate: 0.000026, regularisation: 0.525241, decay: 0.957196 -> 0.197000
Hidden: 100, learning rate: 0.000049, regularisation: 0.971909, decay: 0.993614 -> 0.203000
Hidden: 100, learning rate: 0.000069, regularisation: 0.154570, decay: 0.991666 -> 0.204000
Hidden: 100, learning rate: 0.000091, regularisation: 0.085586, decay: 0.912668 -> 0.209000
Hidden: 100, learning rate: 0.000095, regularisation: 0.763349, decay: 0.966351 -> 0.229000
Hidden: 100, learning rate: 0.000111, regularisation: 0.752424, decay: 0.952243 -> 0.249000
Hidden: 100, learning rate: 0.000144, regularisation: 0.767565, decay: 0.954234 -> 0.268000
Hidden: 100, learning rate: 0.000198, regularisation: 0.096014, decay: 0.932896 -> 0.300000
Hidden: 100, learning rate: 0.000200, regularisation: 0.419288, decay: 0.995093 -> 0.308000
Hidden: 100, learning rate: 0.000213, regularisation: 0.156374, decay: 0.933044 -> 0.304000
Hidden: 100, learning rate: 0.000225, regularisation: 0.087112, decay: 0.988090 -> 0.334000
Hidden: 100, learning rate: 0.000242, regularisation: 0.833380, decay: 0.949385 -> 0.317000
Hidden: 100, learning rate: 0.000251, regularisation: 0.641326, decay: 0.963999 -> 0.322000
Hidden: 100, learning rate: 0.000302, regularisation: 0.522565, decay: 0.941901 -> 0.335000
Hidden: 100, learning rate: 0.000306, regularisation: 0.181819, decay: 0.907356 -> 0.344000
Hidden: 100, learning rate: 0.000330, regularisation: 0.425905, decay: 0.993786 -> 0.366000
Hidden: 100, learning rate: 0.000337, regularisation: 0.610079, decay: 0.999148 -> 0.377000
Hidden: 100, learning rate: 0.000355, regularisation: 0.156621, decay: 0.982748 -> 0.370000
Hidden: 100, learning rate: 0.000415, regularisation: 0.074547, decay: 0.911140 -> 0.380000
Hidden: 100, learning rate: 0.000436, regularisation: 0.223116, decay: 0.920913 -> 0.394000
Hidden: 100, learning rate: 0.000440, regularisation: 0.785730, decay: 0.922896 -> 0.388000
Hidden: 100, learning rate: 0.000465, regularisation: 0.317691, decay: 0.903797 -> 0.389000
Hidden: 100, learning rate: 0.000468, regularisation: 0.498534, decay: 0.969394 -> 0.385000
Hidden: 100, learning rate: 0.000468, regularisation: 0.025249, decay: 0.914001 -> 0.401000
Hidden: 100, learning rate: 0.000483, regularisation: 0.860138, decay: 0.968463 -> 0.390000
Hidden: 100, learning rate: 0.000510, regularisation: 0.387200, decay: 0.958290 -> 0.401000
Hidden: 100, learning rate: 0.000540, regularisation: 0.728495, decay: 0.901040 -> 0.385000
Hidden: 100, learning rate: 0.000556, regularisation: 0.472995, decay: 0.927395 -> 0.391000
Hidden: 100, learning rate: 0.000569, regularisation: 0.212449, decay: 0.939040 -> 0.417000
Hidden: 100, learning rate: 0.000570, regularisation: 0.603224, decay: 0.954009 -> 0.435000
Hidden: 100, learning rate: 0.000571, regularisation: 0.795067, decay: 0.914001 -> 0.407000
Hidden: 100, learning rate: 0.000593, regularisation: 0.381867, decay: 0.907367 -> 0.423000
Hidden: 100, learning rate: 0.000638, regularisation: 0.608742, decay: 0.992501 -> 0.441000
Hidden: 100, learning rate: 0.000645, regularisation: 0.038368, decay: 0.915380 -> 0.413000
Hidden: 100, learning rate: 0.000688, regularisation: 0.051829, decay: 0.968964 -> 0.449000
Hidden: 100, learning rate: 0.000690, regularisation: 0.334620, decay: 0.909284 -> 0.428000
Hidden: 100, learning rate: 0.000696, regularisation: 0.956262, decay: 0.914229 -> 0.419000
Hidden: 100, learning rate: 0.000712, regularisation: 0.922834, decay: 0.944014 -> 0.427000
Hidden: 100, learning rate: 0.000720, regularisation: 0.237622, decay: 0.946307 -> 0.442000
Hidden: 100, learning rate: 0.000741, regularisation: 0.942698, decay: 0.915562 -> 0.423000
Hidden: 100, learning rate: 0.000761, regularisation: 0.729526, decay: 0.965566 -> 0.435000
Hidden: 100, learning rate: 0.000766, regularisation: 0.015107, decay: 0.900946 -> 0.433000
Hidden: 100, learning rate: 0.000831, regularisation: 0.401818, decay: 0.914732 -> 0.426000
Hidden: 100, learning rate: 0.000833, regularisation: 0.045080, decay: 0.941438 -> 0.450000
Hidden: 100, learning rate: 0.000840, regularisation: 0.057334, decay: 0.934159 -> 0.437000
Hidden: 100, learning rate: 0.000864, regularisation: 0.516705, decay: 0.919641 -> 0.441000
Hidden: 100, learning rate: 0.000876, regularisation: 0.560520, decay: 0.949787 -> 0.449000
Hidden: 100, learning rate: 0.000904, regularisation: 0.751927, decay: 0.907922 -> 0.436000
Hidden: 100, learning rate: 0.000954, regularisation: 0.486600, decay: 0.979251 -> 0.451000
Hidden: 100, learning rate: 0.000958, regularisation: 0.952745, decay: 0.935156 -> 0.457000
Hidden: 150, learning rate: 0.000012, regularisation: 0.417100, decay: 0.911500 -> 0.239000
Hidden: 150, learning rate: 0.000040, regularisation: 0.439002, decay: 0.902276 -> 0.183000
Hidden: 150, learning rate: 0.000053, regularisation: 0.457982, decay: 0.953286 -> 0.180000
Hidden: 150, learning rate: 0.000062, regularisation: 0.995679, decay: 0.948896 -> 0.201000
Hidden: 150, learning rate: 0.000089, regularisation: 0.066143, decay: 0.944133 -> 0.228000
Hidden: 150, learning rate: 0.000098, regularisation: 0.508066, decay: 0.934873 -> 0.244000
Hidden: 150, learning rate: 0.000103, regularisation: 0.692714, decay: 0.950727 -> 0.246000
Hidden: 150, learning rate: 0.000114, regularisation: 0.000717, decay: 0.920310 -> 0.252000
Hidden: 150, learning rate: 0.000157, regularisation: 0.080785, decay: 0.926671 -> 0.275000
Hidden: 150, learning rate: 0.000175, regularisation: 0.187289, decay: 0.986274 -> 0.295000
Hidden: 150, learning rate: 0.000205, regularisation: 0.784708, decay: 0.987382 -> 0.302000
Hidden: 150, learning rate: 0.000225, regularisation: 0.736446, decay: 0.953031 -> 0.324000
Hidden: 150, learning rate: 0.000227, regularisation: 0.201981, decay: 0.927178 -> 0.308000
Hidden: 150, learning rate: 0.000231, regularisation: 0.371191, decay: 0.935848 -> 0.316000
Hidden: 150, learning rate: 0.000241, regularisation: 0.989268, decay: 0.952152 -> 0.329000
Hidden: 150, learning rate: 0.000250, regularisation: 0.720408, decay: 0.966968 -> 0.334000
Hidden: 150, learning rate: 0.000289, regularisation: 0.983061, decay: 0.975679 -> 0.346000
Hidden: 150, learning rate: 0.000341, regularisation: 0.784400, decay: 0.921162 -> 0.360000
Hidden: 150, learning rate: 0.000385, regularisation: 0.938107, decay: 0.993547 -> 0.367000
Hidden: 150, learning rate: 0.000424, regularisation: 0.731652, decay: 0.928577 -> 0.383000
Hidden: 150, learning rate: 0.000434, regularisation: 0.097998, decay: 0.999737 -> 0.391000
Hidden: 150, learning rate: 0.000458, regularisation: 0.902791, decay: 0.957587 -> 0.388000
Hidden: 150, learning rate: 0.000476, regularisation: 0.299629, decay: 0.989933 -> 0.401000
Hidden: 150, learning rate: 0.000477, regularisation: 0.968832, decay: 0.914224 -> 0.390000
Hidden: 150, learning rate: 0.000489, regularisation: 0.683991, decay: 0.977084 -> 0.409000
Hidden: 150, learning rate: 0.000534, regularisation: 0.454714, decay: 0.948111 -> 0.423000
Hidden: 150, learning rate: 0.000559, regularisation: 0.471037, decay: 0.960400 -> 0.412000
Hidden: 150, learning rate: 0.000560, regularisation: 0.343579, decay: 0.929949 -> 0.414000
Hidden: 150, learning rate: 0.000567, regularisation: 0.723767, decay: 0.972680 -> 0.420000
Hidden: 150, learning rate: 0.000619, regularisation: 0.522062, decay: 0.932284 -> 0.399000
Hidden: 150, learning rate: 0.000629, regularisation: 0.956838, decay: 0.918907 -> 0.420000
Hidden: 150, learning rate: 0.000633, regularisation: 0.991591, decay: 0.964312 -> 0.413000
Hidden: 150, learning rate: 0.000635, regularisation: 0.198947, decay: 0.902095 -> 0.424000
Hidden: 150, learning rate: 0.000640, regularisation: 0.055402, decay: 0.906297 -> 0.412000
Hidden: 150, learning rate: 0.000674, regularisation: 0.705702, decay: 0.956029 -> 0.422000
Hidden: 150, learning rate: 0.000683, regularisation: 0.693113, decay: 0.963548 -> 0.452000
Hidden: 150, learning rate: 0.000687, regularisation: 0.157089, decay: 0.951175 -> 0.439000
Hidden: 150, learning rate: 0.000725, regularisation: 0.571389, decay: 0.959523 -> 0.433000
Hidden: 150, learning rate: 0.000764, regularisation: 0.212596, decay: 0.987081 -> 0.438000
Hidden: 150, learning rate: 0.000793, regularisation: 0.190560, decay: 0.930990 -> 0.440000
Hidden: 150, learning rate: 0.000839, regularisation: 0.570742, decay: 0.958949 -> 0.436000
Hidden: 150, learning rate: 0.000867, regularisation: 0.305280, decay: 0.999850 -> 0.447000
Hidden: 150, learning rate: 0.000916, regularisation: 0.545413, decay: 0.923442 -> 0.443000
Hidden: 150, learning rate: 0.000936, regularisation: 0.542844, decay: 0.916846 -> 0.447000
Hidden: 150, learning rate: 0.000951, regularisation: 0.188635, decay: 0.912570 -> 0.448000
Hidden: 150, learning rate: 0.000967, regularisation: 0.998075, decay: 0.954153 -> 0.419000
Hidden: 150, learning rate: 0.000975, regularisation: 0.132594, decay: 0.931893 -> 0.460000
Hidden: 150, learning rate: 0.000977, regularisation: 0.588602, decay: 0.999561 -> 0.460000
Hidden: 150, learning rate: 0.000981, regularisation: 0.834292, decay: 0.985803 -> 0.454000
Hidden: 150, learning rate: 0.000987, regularisation: 0.352451, decay: 0.908834 -> 0.438000
Hidden: 200, learning rate: 0.000025, regularisation: 0.052805, decay: 0.998130 -> 0.225000
Hidden: 200, learning rate: 0.000035, regularisation: 0.940661, decay: 0.973288 -> 0.198000
Hidden: 200, learning rate: 0.000063, regularisation: 0.552263, decay: 0.946245 -> 0.201000
Hidden: 200, learning rate: 0.000071, regularisation: 0.265604, decay: 0.952872 -> 0.207000
Hidden: 200, learning rate: 0.000071, regularisation: 0.139111, decay: 0.982636 -> 0.217000
Hidden: 200, learning rate: 0.000141, regularisation: 0.811858, decay: 0.998594 -> 0.288000
Hidden: 200, learning rate: 0.000146, regularisation: 0.461511, decay: 0.924378 -> 0.278000
Hidden: 200, learning rate: 0.000184, regularisation: 0.045106, decay: 0.932911 -> 0.303000
Hidden: 200, learning rate: 0.000194, regularisation: 0.189522, decay: 0.963319 -> 0.310000
Hidden: 200, learning rate: 0.000199, regularisation: 0.433225, decay: 0.957427 -> 0.305000
Hidden: 200, learning rate: 0.000209, regularisation: 0.787378, decay: 0.956444 -> 0.318000
Hidden: 200, learning rate: 0.000241, regularisation: 0.662336, decay: 0.921869 -> 0.318000
Hidden: 200, learning rate: 0.000267, regularisation: 0.799541, decay: 0.986073 -> 0.353000
Hidden: 200, learning rate: 0.000300, regularisation: 0.199952, decay: 0.946168 -> 0.362000
Hidden: 200, learning rate: 0.000310, regularisation: 0.330595, decay: 0.978804 -> 0.369000
Hidden: 200, learning rate: 0.000353, regularisation: 0.148643, decay: 0.929185 -> 0.370000
Hidden: 200, learning rate: 0.000400, regularisation: 0.118089, decay: 0.931082 -> 0.381000
Hidden: 200, learning rate: 0.000411, regularisation: 0.453531, decay: 0.987659 -> 0.390000
Hidden: 200, learning rate: 0.000413, regularisation: 0.590013, decay: 0.923755 -> 0.382000
Hidden: 200, learning rate: 0.000427, regularisation: 0.660502, decay: 0.992249 -> 0.387000
Hidden: 200, learning rate: 0.000454, regularisation: 0.238923, decay: 0.983711 -> 0.403000
Hidden: 200, learning rate: 0.000470, regularisation: 0.820582, decay: 0.996969 -> 0.405000
Hidden: 200, learning rate: 0.000497, regularisation: 0.734868, decay: 0.985900 -> 0.407000
Hidden: 200, learning rate: 0.000517, regularisation: 0.440662, decay: 0.955356 -> 0.412000
Hidden: 200, learning rate: 0.000523, regularisation: 0.171600, decay: 0.965449 -> 0.412000
Hidden: 200, learning rate: 0.000531, regularisation: 0.791411, decay: 0.919720 -> 0.404000
Hidden: 200, learning rate: 0.000535, regularisation: 0.747470, decay: 0.938248 -> 0.414000
Hidden: 200, learning rate: 0.000574, regularisation: 0.916590, decay: 0.936493 -> 0.401000
Hidden: 200, learning rate: 0.000588, regularisation: 0.903051, decay: 0.929304 -> 0.427000
Hidden: 200, learning rate: 0.000596, regularisation: 0.640928, decay: 0.961943 -> 0.428000
Hidden: 200, learning rate: 0.000620, regularisation: 0.165555, decay: 0.952455 -> 0.436000
Hidden: 200, learning rate: 0.000696, regularisation: 0.535948, decay: 0.988269 -> 0.444000
Hidden: 200, learning rate: 0.000723, regularisation: 0.194769, decay: 0.977227 -> 0.441000
Hidden: 200, learning rate: 0.000766, regularisation: 0.045871, decay: 0.932919 -> 0.453000
Hidden: 200, learning rate: 0.000781, regularisation: 0.229017, decay: 0.941971 -> 0.454000
Hidden: 200, learning rate: 0.000782, regularisation: 0.910479, decay: 0.956308 -> 0.428000
Hidden: 200, learning rate: 0.000796, regularisation: 0.514791, decay: 0.984602 -> 0.456000
Hidden: 200, learning rate: 0.000798, regularisation: 0.511781, decay: 0.993948 -> 0.451000
Hidden: 200, learning rate: 0.000823, regularisation: 0.536475, decay: 0.927475 -> 0.443000
Hidden: 200, learning rate: 0.000832, regularisation: 0.322150, decay: 0.978314 -> 0.450000
Hidden: 200, learning rate: 0.000837, regularisation: 0.650761, decay: 0.928345 -> 0.450000
Hidden: 200, learning rate: 0.000838, regularisation: 0.544381, decay: 0.991772 -> 0.449000
Hidden: 200, learning rate: 0.000845, regularisation: 0.054931, decay: 0.976286 -> 0.451000
Hidden: 200, learning rate: 0.000849, regularisation: 0.704150, decay: 0.946256 -> 0.437000
Hidden: 200, learning rate: 0.000857, regularisation: 0.438885, decay: 0.916391 -> 0.434000
Hidden: 200, learning rate: 0.000861, regularisation: 0.843549, decay: 0.952695 -> 0.457000
Hidden: 200, learning rate: 0.000877, regularisation: 0.110034, decay: 0.916207 -> 0.446000
Hidden: 200, learning rate: 0.000897, regularisation: 0.089111, decay: 0.984270 -> 0.442000
Hidden: 200, learning rate: 0.000932, regularisation: 0.764572, decay: 0.947404 -> 0.454000
Hidden: 200, learning rate: 0.000949, regularisation: 0.885840, decay: 0.983951 -> 0.459000
Hidden: 250, learning rate: 0.000030, regularisation: 0.792051, decay: 0.982715 -> 0.220000
Hidden: 250, learning rate: 0.000036, regularisation: 0.996424, decay: 0.971697 -> 0.199000
Hidden: 250, learning rate: 0.000087, regularisation: 0.315784, decay: 0.990090 -> 0.239000
Hidden: 250, learning rate: 0.000098, regularisation: 0.478054, decay: 0.918208 -> 0.245000
Hidden: 250, learning rate: 0.000102, regularisation: 0.323909, decay: 0.956093 -> 0.251000
Hidden: 250, learning rate: 0.000129, regularisation: 0.701754, decay: 0.956489 -> 0.268000
Hidden: 250, learning rate: 0.000150, regularisation: 0.601731, decay: 0.983770 -> 0.290000
Hidden: 250, learning rate: 0.000157, regularisation: 0.608786, decay: 0.966004 -> 0.278000
Hidden: 250, learning rate: 0.000183, regularisation: 0.830140, decay: 0.993483 -> 0.311000
Hidden: 250, learning rate: 0.000185, regularisation: 0.064388, decay: 0.958197 -> 0.314000
Hidden: 250, learning rate: 0.000244, regularisation: 0.857078, decay: 0.909931 -> 0.324000
Hidden: 250, learning rate: 0.000275, regularisation: 0.289786, decay: 0.994570 -> 0.357000
Hidden: 250, learning rate: 0.000278, regularisation: 0.328094, decay: 0.977333 -> 0.352000
Hidden: 250, learning rate: 0.000283, regularisation: 0.795640, decay: 0.901462 -> 0.343000
Hidden: 250, learning rate: 0.000296, regularisation: 0.476142, decay: 0.919533 -> 0.356000
Hidden: 250, learning rate: 0.000306, regularisation: 0.516156, decay: 0.900622 -> 0.349000
Hidden: 250, learning rate: 0.000311, regularisation: 0.811208, decay: 0.925997 -> 0.362000
Hidden: 250, learning rate: 0.000369, regularisation: 0.665320, decay: 0.936695 -> 0.394000
Hidden: 250, learning rate: 0.000384, regularisation: 0.379808, decay: 0.953761 -> 0.386000
Hidden: 250, learning rate: 0.000415, regularisation: 0.827248, decay: 0.987918 -> 0.393000
Hidden: 250, learning rate: 0.000423, regularisation: 0.953376, decay: 0.932926 -> 0.400000
Hidden: 250, learning rate: 0.000441, regularisation: 0.901468, decay: 0.955916 -> 0.380000
Hidden: 250, learning rate: 0.000469, regularisation: 0.879819, decay: 0.979375 -> 0.403000
Hidden: 250, learning rate: 0.000514, regularisation: 0.056162, decay: 0.949760 -> 0.422000
Hidden: 250, learning rate: 0.000547, regularisation: 0.667579, decay: 0.955214 -> 0.430000
Hidden: 250, learning rate: 0.000569, regularisation: 0.143896, decay: 0.980064 -> 0.423000
Hidden: 250, learning rate: 0.000577, regularisation: 0.749325, decay: 0.985898 -> 0.424000
Hidden: 250, learning rate: 0.000579, regularisation: 0.655987, decay: 0.932973 -> 0.402000
Hidden: 250, learning rate: 0.000581, regularisation: 0.133335, decay: 0.911578 -> 0.434000
Hidden: 250, learning rate: 0.000610, regularisation: 0.284044, decay: 0.981651 -> 0.425000
Hidden: 250, learning rate: 0.000613, regularisation: 0.493983, decay: 0.915265 -> 0.418000
Hidden: 250, learning rate: 0.000619, regularisation: 0.698333, decay: 0.911244 -> 0.413000
Hidden: 250, learning rate: 0.000643, regularisation: 0.338197, decay: 0.918876 -> 0.435000
Hidden: 250, learning rate: 0.000647, regularisation: 0.730527, decay: 0.969541 -> 0.438000
Hidden: 250, learning rate: 0.000698, regularisation: 0.834657, decay: 0.980837 -> 0.435000
Hidden: 250, learning rate: 0.000784, regularisation: 0.831870, decay: 0.998203 -> 0.444000
Hidden: 250, learning rate: 0.000786, regularisation: 0.242257, decay: 0.909812 -> 0.435000
Hidden: 250, learning rate: 0.000793, regularisation: 0.750899, decay: 0.924260 -> 0.442000
Hidden: 250, learning rate: 0.000794, regularisation: 0.464423, decay: 0.939778 -> 0.433000
Hidden: 250, learning rate: 0.000795, regularisation: 0.078712, decay: 0.909360 -> 0.438000
Hidden: 250, learning rate: 0.000811, regularisation: 0.547364, decay: 0.985177 -> 0.450000
Hidden: 250, learning rate: 0.000854, regularisation: 0.844599, decay: 0.927784 -> 0.431000
Hidden: 250, learning rate: 0.000885, regularisation: 0.714459, decay: 0.961539 -> 0.447000
Hidden: 250, learning rate: 0.000895, regularisation: 0.837530, decay: 0.950305 -> 0.441000
Hidden: 250, learning rate: 0.000925, regularisation: 0.548671, decay: 0.926887 -> 0.459000
Hidden: 250, learning rate: 0.000927, regularisation: 0.258387, decay: 0.938764 -> 0.453000
Hidden: 250, learning rate: 0.000934, regularisation: 0.758874, decay: 0.990726 -> 0.452000
Hidden: 250, learning rate: 0.000960, regularisation: 0.017008, decay: 0.912560 -> 0.439000
Hidden: 250, learning rate: 0.000986, regularisation: 0.969803, decay: 0.977761 -> 0.449000
Hidden: 250, learning rate: 0.000995, regularisation: 0.300747, decay: 0.959434 -> 0.448000
{(150, 0.0002247939833878815, 0.7364460445177025, 0.953031299036297): 0.32400000000000001, (100, 0.0006375095186806413, 0.6087421838920903, 0.9925007151550612): 0.441, (200, 0.00041124815918575596, 0.45353133996012074, 0.9876594533684355): 0.39000000000000001, (50, 0.0004593680331854585, 0.2178522637565904, 0.954128034369639): 0.379, (250, 0.00027828604268455325, 0.3280940956603803, 0.9773329867556975): 0.35199999999999998, (150, 0.00097516662662409, 0.1325940461706302, 0.9318933747973769): 0.46000000000000002, (250, 0.0007942413472134181, 0.464422810321249, 0.9397780268102327): 0.433, (250, 0.0005474035138660877, 0.6675794326964799, 0.9552140113654224): 0.42999999999999999, (150, 0.00025004750176814866, 0.7204082881260264, 0.9669676194886312): 0.33400000000000002, (150, 0.00022722362674220843, 0.2019806197035916, 0.927177710694509): 0.308, (150, 0.0006289938462069027, 0.9568383976577269, 0.9189068320556172): 0.41999999999999998, (100, 0.0009035422604305505, 0.7519267365734256, 0.907922169488436): 0.436, (150, 0.0009868133566928849, 0.3524510424118117, 0.908834007176346): 0.438, (50, 0.0003592888081673193, 0.10975586496763678, 0.9372627760868479): 0.35699999999999998, (250, 0.0006984652820683779, 0.8346566338128255, 0.9808369442963876): 0.435, (150, 0.000967123609198984, 0.9980751431496112, 0.9541527081060585): 0.41899999999999998, (250, 0.00014967344223554868, 0.6017310490627307, 0.9837698141289555): 0.28999999999999998, (150, 0.00047713988755129343, 0.968832155286147, 0.9142244432353864): 0.39000000000000001, (250, 0.0007862949795128859, 0.2422568419960276, 0.9098118809968987): 0.435, (200, 0.0005313709224651059, 0.7914105398685448, 0.9197197818873113): 0.40400000000000003, (100, 0.0005711334970715286, 0.7950665608583685, 0.9140005451031055): 0.40699999999999997, (150, 0.0004763189551066795, 0.2996291059067244, 0.9899328314844841): 0.40100000000000002, (50, 0.0005775494425983414, 0.5428683234159953, 0.938591690943709): 0.39800000000000002, (50, 
0.0007104805433846716, 0.2972222113438989, 0.9113631826527884): 0.41499999999999998, (250, 0.0009862022168867159, 0.9698028538381556, 0.977761071332205): 0.44900000000000001, (50, 0.000604383442134247, 0.7425738119256627, 0.9306850935111102): 0.40500000000000003, (250, 0.0008540951906398946, 0.8445991029558162, 0.9277844272153154): 0.43099999999999999, (100, 0.000876291791405967, 0.5605203843749893, 0.9497867326545432): 0.44900000000000001, (50, 0.0006136168644180342, 0.5862559585866668, 0.9083117335981792): 0.41299999999999998, (150, 0.0007643972266038543, 0.21259575884920023, 0.987080843130061): 0.438, (100, 0.00025127702365282655, 0.6413259384545547, 0.9639985955658188): 0.32200000000000001, (250, 0.0006126830832549915, 0.49398285810235876, 0.9152651329871607): 0.41799999999999998, (50, 8.577831984232158e-05, 0.29842443969445065, 0.9188743344589517): 0.192, (100, 4.8611783482978154e-05, 0.9719092189045399, 0.9936143455740749): 0.20300000000000001, (250, 0.0006469403840349027, 0.7305271757716629, 0.9695408055171547): 0.438, (100, 0.00030180051033025545, 0.5225648582279303, 0.9419009514361992): 0.33500000000000002, (100, 0.0007121164806265709, 0.922833507106365, 0.9440143623352353): 0.42699999999999999, (200, 0.00040029281842565267, 0.1180888135218382, 0.9310818128323295): 0.38100000000000001, (250, 0.0007928329642170297, 0.750898998584531, 0.9242597177586821): 0.442, (100, 0.0008305796036917698, 0.40181825914506997, 0.9147324189585065): 0.42599999999999999, (100, 0.0005558800256849773, 0.4729950269086908, 0.9273953476547337): 0.39100000000000001, (200, 6.326227449577576e-05, 0.55226320251558, 0.9462446287719106): 0.20100000000000001, (100, 0.0006956871843488254, 0.9562619128934108, 0.9142285984372984): 0.41899999999999998, (250, 0.0006427473404006292, 0.33819700872403113, 0.9188763269649909): 0.435, (250, 0.0005789735332372524, 0.6559865179076533, 0.9329731181925343): 0.40200000000000002, (100, 0.00041466446763801953, 0.0745466976395156, 0.911139709123685): 0.38, 
(150, 5.283527231299541e-05, 0.45798208910832394, 0.9532855029585243): 0.17999999999999999, (150, 0.000916001679918934, 0.545413441196814, 0.9234423945911482): 0.443, (200, 0.0002094680975498127, 0.7873779646178565, 0.9564444210668647): 0.318, (200, 0.0008765724709551406, 0.11003382349558943, 0.9162068598874508): 0.44600000000000001, (100, 0.00014402704350604482, 0.7675645399943928, 0.9542335623208419): 0.26800000000000002, (150, 0.00023107751249536957, 0.37119132315117565, 0.9358475173050531): 0.316, (250, 8.68155948420097e-05, 0.31578384983060237, 0.9900896447987833): 0.23899999999999999, (150, 0.0009808690011990308, 0.8342918586973639, 0.9858030655373405): 0.45400000000000001, (50, 0.0002669204042744615, 0.9612339243479552, 0.9731839173847462): 0.33100000000000002, (50, 0.00019542144048678958, 0.9913218285709746, 0.9917409655355288): 0.29699999999999999, (50, 0.0007296857508495531, 0.2247481074499671, 0.9691523903331729): 0.41099999999999998, (250, 0.00046889755855530375, 0.8798186677111239, 0.9793754885647288): 0.40300000000000002, (50, 0.00037660276337168013, 0.285748274746531, 0.9242282964996369): 0.35799999999999998, (150, 0.0007252340006875304, 0.5713890270050008, 0.9595229265189438): 0.433, (50, 0.00013662748490104703, 0.24586958581874307, 0.9326169139436994): 0.249, (150, 0.00043369270793290794, 0.09799838317233445, 0.9997372121967751): 0.39100000000000001, (50, 0.0004732707076467591, 0.796871388154209, 0.9918950614171429): 0.39600000000000002, (250, 0.00018540486963535972, 0.06438757140609319, 0.9581965692136069): 0.314, (200, 0.000861286982769568, 0.843548840343738, 0.952695143810642): 0.45700000000000002, (200, 0.0004965062310884537, 0.7348680339377227, 0.9859002055175674): 0.40699999999999997, (100, 0.0007414694050440216, 0.9426977412271933, 0.9155623301272203): 0.42299999999999999, (250, 0.0004149239473131685, 0.82724759313885, 0.9879176120771263): 0.39300000000000002, (50, 0.000575881537914494, 0.047175915270613045, 0.9319083182580243): 
0.40899999999999997, (250, 0.0009268575371551104, 0.2583871428769129, 0.9387640140894399): 0.45300000000000001, (150, 1.2423759369652946e-05, 0.4171004369692991, 0.9115001489736172): 0.23899999999999999, (50, 0.0009613982276352605, 0.3373816580603315, 0.9533119114192239): 0.442, (150, 0.0006191857825461, 0.5220619059358178, 0.9322841467738261): 0.39900000000000002, (150, 0.0006397123116392149, 0.05540232627430619, 0.9062972867023423): 0.41199999999999998, (200, 0.0002666025064376474, 0.7995413054710876, 0.9860730849630808): 0.35299999999999998, (200, 0.00019429350598582255, 0.189521964881199, 0.9633193963034823): 0.31, (100, 0.0004834147951524466, 0.8601382422106265, 0.9684629960455691): 0.39000000000000001, (200, 0.0004133877851818796, 0.5900134515571146, 0.923755385255177): 0.38200000000000001, (100, 0.0006884597028908369, 0.051828779508901746, 0.9689637942552236): 0.44900000000000001, (250, 0.00030607716377827504, 0.5161564597849986, 0.9006224343829742): 0.34899999999999998, (100, 0.0005099912706685373, 0.38719979418758077, 0.9582899730625613): 0.40100000000000002, (100, 0.0005929259435770236, 0.38186660879011003, 0.907367152998206): 0.42299999999999999, (150, 0.0008671950551475315, 0.3052796155121661, 0.9998501527586368): 0.44700000000000001, (150, 0.00038466314272450077, 0.9381074305661116, 0.9935470850569232): 0.36699999999999999, (100, 0.00019829520150429182, 0.09601384269255864, 0.9328960900120465): 0.29999999999999999, (250, 0.0001293387004768965, 0.7017537090459796, 0.9564894985633342): 0.26800000000000002, (50, 0.0008973261477760534, 0.6526026029509817, 0.9182170563701838): 0.437, (200, 0.0007658282498771027, 0.045870715257145256, 0.9329188351314668): 0.45300000000000001, (250, 9.806893108272417e-05, 0.4780540813888723, 0.9182080972914077): 0.245, (250, 0.00024375134931470228, 0.8570780477795718, 0.9099306437467618): 0.32400000000000001, (150, 0.0006829493690415167, 0.6931134489953596, 0.963547643670388): 0.45200000000000001, (50, 0.0001830007401283098, 
0.959219291514418, 0.9218560445924374): 0.26900000000000002, (200, 0.0007810936485877282, 0.22901671173286864, 0.9419714742863224): 0.45400000000000001, (100, 0.0004400781642779105, 0.7857300719625077, 0.9228962300031474): 0.38800000000000001, (200, 0.0008566016627256125, 0.4388850651186498, 0.9163913871533547): 0.434, (100, 0.0006451308140158548, 0.03836778078958025, 0.9153802048891467): 0.41299999999999998, (200, 0.0006957646414828252, 0.5359481855036166, 0.988269440660387): 0.44400000000000001, (50, 0.000739392855073123, 0.09285166845635884, 0.9343429671524246): 0.42799999999999999, (100, 0.0003366207107418531, 0.6100791733806473, 0.9991484854737078): 0.377, (250, 0.0001832013586075105, 0.8301399692061933, 0.9934833736956441): 0.311, (200, 0.00042666481106989136, 0.6605021300151538, 0.9922487582497443): 0.38700000000000001, (150, 0.0009505555834572882, 0.18863516542982095, 0.9125702188240992): 0.44800000000000001, (100, 0.0008325097135667625, 0.04508017815550591, 0.9414379604628639): 0.45000000000000001, (200, 0.0008451832060259951, 0.05493139020486615, 0.9762862280177351): 0.45100000000000001, (150, 0.0002052055010642781, 0.7847075302605926, 0.9873818638802132): 0.30199999999999999, (200, 0.00024104230514406725, 0.6623356167272337, 0.921869393028157): 0.318, (150, 0.0005338774926221017, 0.4547143823978105, 0.9481111878532742): 0.42299999999999999, (100, 0.0009575487185075048, 0.9527452181945446, 0.9351558329774896): 0.45700000000000002, (250, 0.00010165904464024778, 0.32390890001007944, 0.9560929325484565): 0.251, (50, 0.0008868110813758455, 0.4561589975352218, 0.927680710619586): 0.434, (50, 0.00037325422246660753, 0.3649637888971814, 0.9982488292586515): 0.36499999999999999, (200, 7.104796613233764e-05, 0.13911146723016954, 0.9826359885235976): 0.217, (250, 0.0007835200669786257, 0.8318696879719513, 0.9982025749083938): 0.44400000000000001, (150, 0.0005599098560387191, 0.3435793886458234, 0.9299486438399566): 0.41399999999999998, (250, 0.0005688808588754003, 
0.14389620842548756, 0.9800641205403267): 0.42299999999999999, (150, 0.00024116615871387683, 0.9892678120815964, 0.9521517744675292): 0.32900000000000001, (250, 0.0004413058579596207, 0.9014681870858094, 0.9559157151096662): 0.38, (100, 0.0004648815804509106, 0.31769127379709916, 0.9037971299834751): 0.38900000000000001, (200, 0.0005345302232468725, 0.7474701852683259, 0.9382483995213987): 0.41399999999999998, (150, 0.00045812995263047845, 0.9027905727618731, 0.957586807030729): 0.38800000000000001, (200, 0.0005227738609984648, 0.17159969244710105, 0.9654489623139686): 0.41199999999999998, (200, 2.5244206836710327e-05, 0.05280451378934525, 0.9981300272756248): 0.22500000000000001, (250, 0.000609750691507519, 0.28404361697268177, 0.9816507440196715): 0.42499999999999999, (150, 0.0006742933558990132, 0.7057024414046864, 0.9560293820806254): 0.42199999999999999, (100, 0.0007607024727825244, 0.7295259998452002, 0.9655659835474392): 0.435, (200, 0.0006197730929671072, 0.16555506042541535, 0.9524553294055269): 0.436, (200, 0.0008225770564557667, 0.5364748305843693, 0.9274751500735487): 0.443, (250, 0.0003688243716393594, 0.6653203388157359, 0.936695435831894): 0.39400000000000002, (200, 0.0003003010881083073, 0.19995182937117606, 0.9461680474055572): 0.36199999999999999, (100, 0.0005700936242648683, 0.6032240981287557, 0.9540088525653923): 0.435, (100, 0.00030576026745413737, 0.18181930298383253, 0.9073563334144322): 0.34399999999999997, (50, 0.0009002060037792747, 0.16974035945451904, 0.9548258050916258): 0.44600000000000001, (200, 0.0007958038709757661, 0.5147912552432676, 0.9846022493611547): 0.45600000000000002, (150, 0.0001754189019776713, 0.18728925158634058, 0.9862744565402252): 0.29499999999999998, (150, 0.0004888636763487053, 0.6839905048197579, 0.9770840099461846): 0.40899999999999997, (200, 0.00046982445546347415, 0.820582168725975, 0.9969694252034826): 0.40500000000000003, (50, 0.0008109831120793883, 0.1388335987648316, 0.9477576362929807): 0.44, (50, 
0.0009427234069538015, 0.8678110254900591, 0.9716360542964826): 0.45200000000000001, (50, 0.0007483543060048669, 0.4227827785594931, 0.9057612983568997): 0.43099999999999999, (250, 0.00027544449733409525, 0.2897855955005547, 0.9945700304370148): 0.35699999999999998, (50, 0.0007490649492049832, 0.7631169080531975, 0.9093012082994475): 0.43099999999999999, (200, 0.0004543624696447771, 0.23892336560531435, 0.9837113933338317): 0.40300000000000002, (100, 0.00032958521174508145, 0.4259045263432646, 0.9937863404830027): 0.36599999999999999, (200, 0.0007231386915828095, 0.19476873855443366, 0.9772273628535845): 0.441, (100, 0.000953988026153383, 0.4866004753605918, 0.9792512693093126): 0.45100000000000001, (200, 0.00014118605034211553, 0.811858457605124, 0.9985937458496807): 0.28799999999999998, (100, 0.0008642827200247826, 0.5167050683172761, 0.9196409808345918): 0.441, (200, 0.0008322950307439004, 0.3221495767998023, 0.9783144755115325): 0.45000000000000001, (50, 0.00013499269795546844, 0.7815029981906237, 0.9163229648060734): 0.247, (200, 0.0001844960263098198, 0.045106141144383605, 0.9329106435350927): 0.30299999999999999, (100, 0.0005685372939675206, 0.2124492437920913, 0.9390395412708413): 0.41699999999999998, (250, 0.0006191754040875002, 0.6983326207277875, 0.9112440768816529): 0.41299999999999998, (100, 0.00043613050707477986, 0.22311605766158615, 0.9209125062000549): 0.39400000000000002, (200, 0.0005875433414138977, 0.9030509646442613, 0.9293042011929942): 0.42699999999999999, (100, 0.0006896194994058481, 0.3346198610593756, 0.9092835209570845): 0.42799999999999999, (50, 0.00012329583677527151, 0.6217873282383937, 0.9396407467582513): 0.245, (200, 0.0005957340328108068, 0.6409278551195308, 0.9619428559440351): 0.42799999999999999, (150, 0.00011361390808226068, 0.0007165024925176455, 0.9203099757826935): 0.252, (200, 0.0008371043483095149, 0.6507612221206394, 0.9283454198147824): 0.45000000000000001, (200, 0.0003095033564391284, 0.33059493271764673, 
0.9788039251882327): 0.36899999999999999, (100, 0.00021305025969695548, 0.15637412898387415, 0.9330436474503089): 0.30399999999999999, (100, 0.00046824479338884, 0.025249116536612437, 0.9140005254704224): 0.40100000000000002, (50, 0.0005027293200010108, 0.9597425805166264, 0.9933607758888824): 0.40400000000000003, (200, 0.0007975010257019514, 0.5117810441845719, 0.9939475119129589): 0.45100000000000001, (50, 0.00034498559139813914, 0.27230830418205054, 0.9328186895369818): 0.35899999999999999, (250, 3.597803182931227e-05, 0.9964236392703093, 0.9716970347664919): 0.19900000000000001, (50, 1.603430235187431e-05, 0.10187161118409216, 0.964964818899473): 0.22600000000000001, (150, 0.00015734851836452058, 0.08078451732931413, 0.9266714025871324): 0.27500000000000002, (250, 0.0008845662239899973, 0.7144588510219021, 0.9615390465363725): 0.44700000000000001, (250, 0.0008110884382388486, 0.5473636650209536, 0.9851770878092154): 0.45000000000000001, (200, 0.0008968397687136302, 0.08911068444120085, 0.9842699934740153): 0.442, (250, 0.0005806324395979479, 0.13333504687331943, 0.9115776527147432): 0.434, (250, 0.0003108937950919561, 0.8112080759789924, 0.925997446533405): 0.36199999999999999, (200, 0.0009490745151102985, 0.8858403049792549, 0.9839509691771525): 0.45900000000000002, (200, 0.00019898372479897964, 0.4332250413513641, 0.9574274422600055): 0.30499999999999999, (50, 0.0008702145800956831, 0.2672323083946009, 0.9062515614515816): 0.435, (100, 0.0007661479841885424, 0.015106762085522085, 0.9009460088435678): 0.433, (100, 0.00046769022308740965, 0.4985339302066564, 0.9693942744840764): 0.38500000000000001, (250, 0.0005143956929388689, 0.05616184207255959, 0.9497600936224311): 0.42199999999999999, (150, 0.00034138825295703653, 0.7844003849119031, 0.9211622121317321): 0.35999999999999999, (200, 0.0005740811208753585, 0.9165901568756757, 0.9364931313877536): 0.40100000000000002, (50, 0.0006673385807530815, 0.8865167444384865, 0.9086584685891724): 0.40400000000000003, 
(200, 0.0008494070448055074, 0.7041497040968728, 0.9462558059185395): 0.437, (200, 0.0008383751125763001, 0.5443809450927055, 0.9917717411930388): 0.44900000000000001, (200, 0.0009316255501760393, 0.7645723999778823, 0.9474036817889233): 0.45400000000000001, (50, 0.0006274989682633989, 0.9985680682206097, 0.948983636856963): 0.41199999999999998, (50, 0.00014765601454693866, 0.656085339294214, 0.9612295045485195): 0.26000000000000001, (100, 9.48286978218442e-05, 0.7633490535325036, 0.9663512833386063): 0.22900000000000001, (50, 0.0001894329384309893, 0.43012782108550296, 0.9100716409815982): 0.27500000000000002, (100, 0.00011058863000362021, 0.7524244709167405, 0.9522431540019057): 0.249, (150, 3.9821624488974845e-05, 0.4390021011433056, 0.9022757738216686): 0.183, (50, 1.1025791798111119e-05, 0.36330448850621344, 0.9893353530173773): 0.191, (100, 0.0007203325439355195, 0.23762205347584475, 0.946306774284782): 0.442, (150, 0.00010323969073849585, 0.692713683920309, 0.9507269240514103): 0.246, (250, 0.00015656171721327374, 0.6087856696535298, 0.96600414250964): 0.27800000000000002, (250, 2.9546113382334783e-05, 0.7920505712613259, 0.9827151476281708): 0.22, (50, 0.0006548287606738536, 0.5786011104808956, 0.9698338717841761): 0.42499999999999999, (100, 9.062016609532949e-05, 0.0855861413607365, 0.9126681343318253): 0.20899999999999999, (150, 0.0006865816204991877, 0.15708888273359622, 0.9511753039234323): 0.439, (250, 0.000894831553337236, 0.8375302334338183, 0.950305050468722): 0.441, (150, 0.0005673304020280562, 0.723767022450434, 0.9726802859032926): 0.41999999999999998, (250, 0.0009252494358332768, 0.5486708196070373, 0.9268874534695013): 0.45900000000000002, (200, 7.06259915788447e-05, 0.26560421788939714, 0.9528717455786939): 0.20699999999999999, (250, 0.00029564152785579935, 0.47614180519887184, 0.9195333535063511): 0.35599999999999998, (50, 0.0001997100681070257, 0.48255209027948365, 0.9490168744786834): 0.28499999999999998, (200, 0.0007821996159082461, 
0.9104793755449505, 0.9563081096255508): 0.42799999999999999, (100, 0.0008404139142469408, 0.0573338813196137, 0.9341593095746226): 0.437, (250, 0.00042287170741041303, 0.9533764015702586, 0.932925562317469): 0.40000000000000002, (100, 0.000224945422489662, 0.08711151643867443, 0.988090221778302): 0.33400000000000002, (200, 0.000145968230081264, 0.461511117075752, 0.9243782689559793): 0.27800000000000002, (200, 0.0003534949025565561, 0.14864296160502832, 0.9291849154817675): 0.37, (150, 0.0007932478174312034, 0.19056037769585166, 0.9309897892419043): 0.44, (150, 6.156867774233974e-05, 0.9956789894427621, 0.948895843531299): 0.20100000000000001, (200, 3.505202610532547e-05, 0.9406613430697308, 0.9732881657114094): 0.19800000000000001, (200, 0.0005167900043394081, 0.4406615981587221, 0.9553563106267587): 0.41199999999999998, (150, 0.0009356998358566354, 0.542844476322939, 0.916846164280192): 0.44700000000000001, (50, 0.0007000651357199204, 0.05262992286609147, 0.9767697580487681): 0.42099999999999999, (50, 1.3644339450499241e-05, 0.44209197021881297, 0.9118106825964951): 0.191, (50, 0.00037459507820536495, 0.7831153927249478, 0.9342567908974206): 0.36799999999999999, (250, 0.0007946722289015232, 0.07871162880928473, 0.9093597223891596): 0.438, (50, 8.524105726202221e-05, 0.37117769496718145, 0.9356481691385665): 0.20399999999999999, (50, 0.0004739977759179161, 0.19520324275034073, 0.9415408420463556): 0.38900000000000001, (100, 6.931017179541206e-05, 0.15456969018830624, 0.9916657495577897): 0.20399999999999999, (250, 0.0005765914467586578, 0.7493247237539596, 0.9858984988535611): 0.42399999999999999, (50, 0.0004689554609326399, 0.23390884373576037, 0.949225449565914): 0.38900000000000001, (50, 0.00020396456614148802, 0.361120615843626, 0.9327825376547755): 0.28799999999999998, (150, 0.0006334830838511333, 0.9915906173380895, 0.9643120219247994): 0.41299999999999998, (100, 2.6450600604980042e-05, 0.5252406725962376, 0.9571964596840977): 0.19700000000000001, (250, 
0.000933857912858053, 0.7588738703505026, 0.9907260538932282): 0.45200000000000001, (250, 0.00028287509115895054, 0.7956396250785762, 0.9014619922160112): 0.34300000000000003, (50, 0.0002334685907340758, 0.6248101850835447, 0.9273170594051353): 0.30099999999999999, (250, 0.0003844095256863825, 0.3798076957622635, 0.9537606598814236): 0.38600000000000001, (150, 0.000558754352537555, 0.4710372549773634, 0.9604003541719931): 0.41199999999999998, (150, 8.890587866359085e-05, 0.06614270870747252, 0.9441333273788907): 0.22800000000000001, (100, 0.0005401938231423449, 0.728495382538007, 0.9010402167814123): 0.38500000000000001, (100, 0.00024173934327682172, 0.8333802344840741, 0.9493854403617005): 0.317, (100, 0.00035451922086374263, 0.15662125581491138, 0.982748115111801): 0.37, (150, 0.00042365780118882153, 0.7316523641101516, 0.9285773210320306): 0.38300000000000001, (50, 0.0005812566466456446, 0.4257387795447397, 0.9478739400126143): 0.40899999999999997, (250, 0.0009604253098570212, 0.01700817918002584, 0.912560220950049): 0.439, (150, 0.0006349412833752068, 0.1989467237588437, 0.9020953372152752): 0.42399999999999999, (250, 0.000995069327945829, 0.30074743613897903, 0.9594335438203866): 0.44800000000000001, (50, 0.0009388982154558996, 0.6019009685934398, 0.9286530083910258): 0.42799999999999999, (50, 0.0006230199428842103, 0.8907898632183969, 0.9408950361033024): 0.41199999999999998, (100, 0.00019995330729164704, 0.4192883528323593, 0.9950933949011073): 0.308, (50, 0.0009449680433910691, 0.10382469960454621, 0.9104717596893173): 0.45000000000000001, (150, 0.000977112356335308, 0.5886018223643927, 0.9995607434501561): 0.46000000000000002, (150, 0.00028920919764460973, 0.9830609010767556, 0.9756787046189399): 0.34599999999999997, (150, 0.0008387311861416587, 0.5707424372150834, 0.9589490921007829): 0.436, (150, 9.836338092445456e-05, 0.5080662224539926, 0.9348729676432226): 0.24399999999999999}
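
The dictionary printed above maps each `(hidden, learning rate, regularisation, decay)` tuple to its validation accuracy. As a minimal sketch, the best entry can be picked with the built-in `max` rather than an explicit loop. The `results` literal below is only a four-entry excerpt copied from the output above, with accuracies written to three decimals, and `best_config` is a hypothetical helper name, not part of the assignment code.

```python
# Small excerpt of the results dictionary printed above
# (keys copied verbatim, accuracies written to three decimals).
results = {
    (150, 0.00097516662662409, 0.1325940461706302, 0.9318933747973769): 0.460,
    (200, 0.0009490745151102985, 0.8858403049792549, 0.9839509691771525): 0.459,
    (250, 0.0009252494358332768, 0.5486708196070373, 0.9268874534695013): 0.459,
    (50, 8.577831984232158e-05, 0.29842443969445065, 0.9188743344589517): 0.192,
}

def best_config(results):
    """Hypothetical helper: return (key, accuracy) for the best validation accuracy.

    Keys are visited in sorted order and max() keeps the first maximal item,
    so ties are resolved in favour of the smallest key.
    """
    key = max(sorted(results), key=results.get)
    return key, results[key]

(best_hidden, best_lr, best_reg, best_decay), best_val = best_config(results)
```

Iterating over `sorted(results)` keeps the tie-breaking identical to a loop that walks the sorted keys and only updates on a strictly greater accuracy.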

In [49]:
# Find the best hidden size, learning rate, regularization strength and decay
best_hidden = 25
best_lr = 0.000958
best_reg = 0.952745
best_decay = 0.935156

best_val = -1
for nh, lr, reg, dec in sorted(results):
    if results[(nh, lr, reg, dec)] > best_val:
        best_val = results[(nh, lr, reg, dec)]
        best_hidden = nh
        best_lr = lr
        best_reg = reg
        best_decay = dec

# Train the best network with more iterations
best_net = TwoLayerNet(input_size, best_hidden, num_classes)
stats = best_net.train(X_train, y_train, X_val, y_val,
                  num_iters=2000, batch_size=200,
                  learning_rate=best_lr, learning_rate_decay=best_decay,
                  reg=best_reg, verbose=True)

# Predict on the validation set
val_acc = (best_net.predict(X_val) == y_val).mean()

print 'Best validation accuracy now: %f' % val_acc


iteration 0 / 2000: loss 2.302885
iteration 100 / 2000: loss 1.911262
iteration 200 / 2000: loss 1.826928
iteration 300 / 2000: loss 1.654449
iteration 400 / 2000: loss 1.621161
iteration 500 / 2000: loss 1.514830
iteration 600 / 2000: loss 1.462329
iteration 700 / 2000: loss 1.608378
iteration 800 / 2000: loss 1.393405
iteration 900 / 2000: loss 1.479433
iteration 1000 / 2000: loss 1.466119
iteration 1100 / 2000: loss 1.403799
iteration 1200 / 2000: loss 1.405784
iteration 1300 / 2000: loss 1.349541
iteration 1400 / 2000: loss 1.395470
iteration 1500 / 2000: loss 1.223357
iteration 1600 / 2000: loss 1.380818
iteration 1700 / 2000: loss 1.437497
iteration 1800 / 2000: loss 1.250355
iteration 1900 / 2000: loss 1.322252
Best validation accuracy now: 0.504000

In [ ]:
# visualize the weights of the best network
show_net_weights(best_net)

Run on the test set

When you are done experimenting, you should evaluate your final trained network on the test set; you should get above 48%.

We will give you an extra bonus point for every 1% of accuracy above 52%.


In [1]:
test_acc = (best_net.predict(X_test) == y_test).mean()
print 'Test accuracy: ', test_acc



NameError                                 Traceback (most recent call last)
<ipython-input-1-758052de9fc4> in <module>()
----> 1 test_acc = (best_net.predict(X_test) == y_test).mean()
      2 print 'Test accuracy: ', test_acc

NameError: name 'best_net' is not defined
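
The `NameError` above, together with the `In [1]` execution count, suggests the cell was run in a fresh kernel before the training cells had been re-executed, so `best_net` did not exist yet. A minimal sketch of a guard that reports a missing model instead of raising is shown below; `evaluate_if_defined` and `ConstantModel` are hypothetical names introduced only for illustration, and the demo data is made up.

```python
import numpy as np

class ConstantModel(object):
    """Hypothetical stand-in for a trained network: always predicts class 0."""
    def predict(self, X):
        return np.zeros(X.shape[0], dtype=int)

def evaluate_if_defined(namespace, model_name, X, y):
    """Return the accuracy of namespace[model_name] on (X, y), or None if it is missing."""
    model = namespace.get(model_name)
    if model is None:
        return None
    return float(np.mean(model.predict(X) == y))

# Made-up demo data: four inputs, three of which have label 0.
X_demo = np.zeros((4, 3))
y_demo = np.array([0, 0, 1, 0])

missing = evaluate_if_defined({}, 'best_net', X_demo, y_demo)
present = evaluate_if_defined({'best_net': ConstantModel()}, 'best_net', X_demo, y_demo)
```

In the notebook itself one would pass `globals()` as the namespace together with `X_test` and `y_test`; a `None` result then means the earlier cells need to be re-run first.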

In [ ]: