Neural Networks

This project was created by Brian Granger. All content is licensed under the MIT License.


Introduction

Neural networks are a class of algorithms that can learn how to compute the value of a function from previous examples of the function's output. Because neural networks learn to compute a function's output from existing data, they fall under the field of Machine Learning.

Let's say that we don't know how to compute some function $f$:

$$ f(x) \rightarrow y $$

But we do have some data about the output that $f$ produces for particular inputs $x$:

$$ f(x_1) \rightarrow y_1 $$
$$ f(x_2) \rightarrow y_2 $$
$$ \ldots $$
$$ f(x_n) \rightarrow y_n $$

A neural network learns how to use that existing data to compute the value of $f$ for inputs it has not yet seen. Neural networks get their name because their design is loosely modeled on networks of neurons in the brain.
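The basic unit of most neural networks is an artificial neuron. It forms a weighted sum of its inputs, adds a bias, and passes the result through an activation function. The minimal sketch below assumes the logistic sigmoid as the activation. The function names and the weight and bias values are hypothetical and serve only as an illustration. In a real network the weights are learned from data.

```python
import numpy as np

def sigmoid(z):
    """Logistic sigmoid: squashes any real number into the interval (0, 1)."""
    return 1.0 / (1.0 + np.exp(-z))

def neuron(x, w, b):
    """One artificial neuron: weighted sum of inputs plus a bias,
    passed through the sigmoid activation."""
    return sigmoid(np.dot(w, x) + b)

x = np.array([0.5, -1.0, 2.0])   # example inputs
w = np.array([0.1, 0.4, -0.3])   # hypothetical weights (learned in practice)
b = 0.2                          # hypothetical bias
print(neuron(x, w, b))           # ≈ 0.321, since w·x + b = -0.75
```

A network is built by feeding the outputs of many such neurons into further neurons, layer after layer.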

Work on neural networks began in the 1940s. Significant advances came in the 1970s with backpropagation. More recently, since the late 2000s, deep neural networks have driven further progress. Neural networks are now starting to be used extensively in products that you use. A great example is Flickr's recently released automated image tagging. These algorithms let Flickr determine which tags ("kitten", "puppy") should be applied to each photo, without human involvement.

In this case the function takes an image as input and outputs a set of tags for that image:

$$ f(\text{image}) \rightarrow \{\text{tag}_1, \ldots\} $$

For the purpose of this project, good introductions to neural networks can be found at:

The Project

Your general goal is to write Python code that predicts the number associated with each handwritten digit. The dataset for these digits is included in scikit-learn (sklearn):


In [6]:
%matplotlib inline
import matplotlib.pyplot as plt
# interact moved from the deprecated IPython.html.widgets module to the ipywidgets package
from ipywidgets import interact



In [15]:
from sklearn.datasets import load_digits
digits = load_digits()
print(digits.data.shape)


(1797, 64)
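Each row of digits.data holds the 64 pixels of the corresponding 8x8 image in digits.images, flattened row by row. The pixel intensities in this dataset are integers from 0 to 16. Rescaling the inputs to the range [0, 1] is a common preprocessing step for sigmoid networks. This is an assumption here, not a requirement of the project. The helper name flatten_and_scale is hypothetical. The example uses a small synthetic array so that it runs without sklearn.

```python
import numpy as np

def flatten_and_scale(images, max_value=16.0):
    """Flatten each 2-D image into a row vector and rescale pixel
    intensities from [0, max_value] to [0, 1].

    images: array of shape (n_samples, height, width)
    returns: float array of shape (n_samples, height * width)
    """
    images = np.asarray(images, dtype=float)
    return images.reshape(images.shape[0], -1) / max_value

# Synthetic stand-in for digits.images: two 8x8 "images" with values 0..16
fake_images = np.arange(2 * 8 * 8).reshape(2, 8, 8) % 17
X = flatten_and_scale(fake_images)
print(X.shape)           # (2, 64)
print(X.min(), X.max())  # 0.0 1.0
```

With the real data, flatten_and_scale(digits.images) should give a (1797, 64) array equal to digits.data / 16.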

In [8]:
def show_digit(i):
    """Display the i-th 8x8 image from the digits dataset."""
    plt.matshow(digits.images[i])

In [9]:
interact(show_digit, i=(0,100));


The known values (the digits 0 through 9) associated with each image can be found in the target array:


In [14]:
digits.target[87]


Out[14]:
4
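Suppose the network has one output neuron per digit. This is a common design and it is assumed here. Each target, such as 4, is then typically represented as a vector of ten values with a 1 in position 4 and 0 elsewhere ("one-hot" encoding). Taking the argmax of the network's ten outputs reverses the encoding and gives the predicted digit. The function name one_hot below is hypothetical.

```python
import numpy as np

def one_hot(labels, n_classes=10):
    """Convert integer class labels (0..n_classes-1) into one-hot rows.

    labels: 1-D integer array of length n_samples
    returns: float array of shape (n_samples, n_classes), a single 1.0 per row
    """
    labels = np.asarray(labels, dtype=int)
    encoded = np.zeros((labels.size, n_classes))
    encoded[np.arange(labels.size), labels] = 1.0  # set column `label` in each row
    return encoded

Y = one_hot([4, 0, 9])
print(Y.argmax(axis=1))  # [4 0 9]
```

For the whole dataset, one_hot(digits.target) gives a (1797, 10) array.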

Here are some of the things you will need to do as part of this project:

  • Split the original data set into two parts: 1) a training set that you will use to train your neural network and 2) a test set you will use to see if your trained neural network can accurately predict previously unseen data.
  • Write Python code to implement the basic building blocks of neural networks. This code should be modular and fully tested. While you may look at the code examples in the resources above, your code should be your own and substantially different from them. One way of ensuring your code is different is to make it more general.
  • Create appropriate data structures for the neural network.
  • Figure out how to initialize the weights of the neural network.
  • Write code to implement forward and back propagation.
  • Write code to train the network with the training set.
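The building blocks listed above can fit together as in the following minimal sketch, which is one possible design rather than the intended solution. It assumes the following:

  • sigmoid activations in every layer
  • a mean-squared-error loss
  • weights stored as a list of (W, b) pairs, one per layer
  • initial weights drawn from a normal distribution with standard deviation 1/sqrt(n_in), a common heuristic
  • plain full-batch gradient descent

All function names and default values are hypothetical.

```python
import numpy as np

def sigmoid(z):
    """Logistic activation, applied element-wise."""
    return 1.0 / (1.0 + np.exp(-z))

def init_params(layer_sizes, rng):
    """Return one (W, b) pair per layer, e.g. layer_sizes=[64, 30, 10].
    W has shape (n_in, n_out), drawn from N(0, 1/n_in) (assumed heuristic);
    b starts at zero."""
    params = []
    for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        W = rng.normal(0.0, 1.0 / np.sqrt(n_in), size=(n_in, n_out))
        params.append((W, np.zeros(n_out)))
    return params

def forward(params, X):
    """Forward propagation; returns the activations of every layer,
    starting with the input X of shape (n_samples, n_features)."""
    activations = [X]
    for W, b in params:
        activations.append(sigmoid(activations[-1] @ W + b))
    return activations

def loss(params, X, Y):
    """Mean over samples of 0.5 * squared error against one-hot targets Y."""
    output = forward(params, X)[-1]
    return 0.5 * np.mean(np.sum((output - Y) ** 2, axis=1))

def backprop(params, X, Y):
    """Back propagation: gradients of loss() with respect to every W and b."""
    activations = forward(params, X)
    out = activations[-1]
    # Output-layer error dLoss/dz; the sigmoid derivative is a * (1 - a)
    delta = (out - Y) * out * (1.0 - out) / X.shape[0]
    grads = []
    for layer in reversed(range(len(params))):
        W, _ = params[layer]
        a_prev = activations[layer]
        grads.append((a_prev.T @ delta, delta.sum(axis=0)))
        if layer > 0:
            # Send the error back through W and the previous layer's sigmoid
            delta = (delta @ W.T) * a_prev * (1.0 - a_prev)
    return grads[::-1]  # reorder from first layer to last

def train(params, X, Y, learning_rate=1.0, epochs=1000):
    """Full-batch gradient descent; returns new parameters."""
    for _ in range(epochs):
        grads = backprop(params, X, Y)
        params = [(W - learning_rate * dW, b - learning_rate * db)
                  for (W, b), (dW, db) in zip(params, grads)]
    return params

# Tiny synthetic problem: 2 classes decided by comparing two features
rng = np.random.default_rng(0)
X = rng.random((20, 4))
Y = np.eye(2)[(X[:, 0] > X[:, 1]).astype(int)]
params = init_params([4, 5, 2], rng)
trained = train(params, X, Y)
print(loss(params, X, Y), "->", loss(trained, X, Y))
```

For the digits, a network with one hidden layer could be created with init_params([64, 30, 10], np.random.default_rng(0)). The hidden size of 30 is an arbitrary choice. The network would then be trained on the scaled pixel data with one-hot targets, and np.argmax(forward(params, X)[-1], axis=1) gives the predicted digits. Comparing the gradients from backprop against finite differences of loss is a quick way to test the back propagation code.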

Your base question should be to get a basic version of your code working that predicts handwritten digits with an accuracy significantly better than random guessing.
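Checking this requires a held-out test set and an accuracy score. The ten digits occur roughly equally often in this dataset, so random guessing is right about 10% of the time, and that is the baseline to beat. The sketch below uses two hypothetical helpers, split_data and accuracy. It shuffles the samples before splitting so that both sets contain a mix of digits. It runs on small synthetic arrays so that it stands on its own.

```python
import numpy as np

def split_data(X, y, test_fraction=0.25, seed=0):
    """Shuffle samples, then split into training and test sets.
    Returns X_train, X_test, y_train, y_test."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(X))
    n_test = int(round(test_fraction * len(X)))
    test_idx, train_idx = order[:n_test], order[n_test:]
    return X[train_idx], X[test_idx], y[train_idx], y[test_idx]

def accuracy(predicted, actual):
    """Fraction of predicted labels that equal the known labels."""
    return np.mean(np.asarray(predicted) == np.asarray(actual))

# Synthetic stand-in data: row i is [2i, 2i + 1] with label i
X = np.arange(20).reshape(10, 2)
y = np.arange(10)
X_train, X_test, y_train, y_test = split_data(X, y, test_fraction=0.3)
print(len(X_train), len(X_test))  # 7 3

# Baseline: guessing uniformly among 10 digits
rng = np.random.default_rng(1)
labels = rng.integers(0, 10, size=10_000)
guesses = rng.integers(0, 10, size=10_000)
print(accuracy(guesses, labels))  # close to 0.1
```

scikit-learn's sklearn.model_selection.train_test_split performs the same kind of shuffled split if you prefer not to write your own.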

Here are some ideas of questions you could explore as your two additional questions:

  • How to specify, train and use networks with more hidden layers.
  • The best way to determine the initial weights.
  • Making the code fast enough to handle more layers and more neurons per layer (measure with %timeit and %%timeit).
  • Different ways of optimizing the weights, and thereby the output, of the neural network.
  • Tackle the full MNIST benchmark (60,000 training and 10,000 test digits).
  • How different sigmoid (activation) functions affect the results.
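"Sigmoid" here means an S-shaped activation function. The logistic function 1/(1 + e^(-z)) is the usual choice in introductory material. The hyperbolic tangent is another common one, and its outputs lie in (-1, 1) rather than (0, 1). The sketch below keeps a hypothetical lookup table, ACTIVATIONS, that pairs each function with its derivative. Each derivative is written in terms of the activation a = f(z), because back propagation already has those values on hand. Swapping activations then changes only the function and the derivative factor used during back propagation.

```python
import numpy as np

# name -> (f, f_prime), where f_prime takes a = f(z) rather than z
ACTIVATIONS = {
    "logistic": (lambda z: 1.0 / (1.0 + np.exp(-z)),  # outputs in (0, 1)
                 lambda a: a * (1.0 - a)),
    "tanh": (np.tanh,                                  # outputs in (-1, 1)
             lambda a: 1.0 - a ** 2),
}

z = np.linspace(-3.0, 3.0, 7)
for name, (f, f_prime) in ACTIVATIONS.items():
    a = f(z)
    print(f"{name:8s} f(z) = {np.round(a, 3)}  f'(z) = {np.round(f_prime(a), 3)}")
```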

Implementation hints

There are optimization routines in scipy.optimize that may be helpful.
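scipy.optimize.minimize works on a single one-dimensional parameter vector. Using it with a network therefore means flattening the weight matrices into one vector and rebuilding them inside the objective function. The sketch below assumes weights stored as a list of (W, b) pairs. The names pack, unpack and objective are hypothetical. The network has no hidden layer only to keep the gradient short, and L-BFGS-B is just one of the methods minimize accepts. Passing jac=True tells minimize that the objective returns both the loss and its gradient.

```python
import numpy as np
from scipy.optimize import minimize

def pack(params):
    """Flatten a list of (W, b) pairs into one 1-D vector."""
    return np.concatenate([np.concatenate([W.ravel(), b.ravel()]) for W, b in params])

def unpack(theta, layer_sizes):
    """Inverse of pack(): rebuild the (W, b) pairs from a flat vector."""
    params, start = [], 0
    for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        W = theta[start:start + n_in * n_out].reshape(n_in, n_out)
        start += n_in * n_out
        b = theta[start:start + n_out]
        start += n_out
        params.append((W, b))
    return params

def objective(theta, layer_sizes, X, Y):
    """Loss and gradient for a sigmoid network with no hidden layer,
    kept this small only to make the example short."""
    W, b = unpack(theta, layer_sizes)[0]
    A = 1.0 / (1.0 + np.exp(-(X @ W + b)))
    n = X.shape[0]
    loss = 0.5 * np.sum((A - Y) ** 2) / n
    delta = (A - Y) * A * (1.0 - A) / n
    grad = pack([(X.T @ delta, delta.sum(axis=0))])
    return loss, grad

rng = np.random.default_rng(0)
X = rng.random((50, 3))
Y = np.eye(2)[(X[:, 0] > X[:, 1]).astype(int)]
sizes = [3, 2]
theta0 = pack([(rng.normal(0.0, 0.1, size=(3, 2)), np.zeros(2))])
# jac=True: objective returns (loss, gradient) as a tuple
result = minimize(objective, theta0, args=(sizes, X, Y), jac=True, method="L-BFGS-B")
print(objective(theta0, sizes, X, Y)[0], "->", result.fun)
```

The same pack/unpack approach extends to networks with hidden layers, as long as the objective returns the full back-propagated gradient in packed form.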

Use NumPy arrays and fast, vectorized NumPy operations (such as dot) wherever possible.
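Most of the work in forward and back propagation is computing weighted sums, and these map directly onto a matrix product. The sketch below computes the same layer twice, once with Python loops and once with np.dot. It then checks that the results agree and times both with the standard-library timeit module. The function names are hypothetical. Random data stands in for 50 samples of 64 pixels and a layer of 30 neurons. In a notebook, %timeit layer_dot(X, W, b) gives the same kind of measurement.

```python
import timeit
import numpy as np

def layer_loop(X, W, b):
    """Weighted sums for one layer using explicit Python loops (slow)."""
    n_samples, n_in = X.shape
    n_out = W.shape[1]
    Z = np.zeros((n_samples, n_out))
    for i in range(n_samples):
        for j in range(n_out):
            total = b[j]
            for k in range(n_in):
                total += X[i, k] * W[k, j]
            Z[i, j] = total
    return Z

def layer_dot(X, W, b):
    """The same weighted sums as one matrix product; b is broadcast over rows."""
    return np.dot(X, W) + b

rng = np.random.default_rng(0)
X = rng.random((50, 64))
W = rng.random((64, 30))
b = rng.random(30)
assert np.allclose(layer_loop(X, W, b), layer_dot(X, W, b))
print("loop:", timeit.timeit(lambda: layer_loop(X, W, b), number=1))
print("dot: ", timeit.timeit(lambda: layer_dot(X, W, b), number=1))
```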