In [371]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact

from sklearn.datasets import load_digits
digits = load_digits()

In [372]:
def sigmoid(x):
    return 1/(1 + np.exp(-x))

#np.exp already broadcasts over arrays, so vectorize is redundant here,
#but the _v aliases are kept because later cells use them
sigmoid_v = np.vectorize(sigmoid)

def sigmoidprime(x):
    #derivative of the sigmoid, written in terms of the sigmoid itself
    return sigmoid(x) * (1 - sigmoid(x))

sigmoidprime_v = np.vectorize(sigmoidprime)
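
A quick sanity check (not in the original notebook): sigmoidprime should agree with a centered finite-difference estimate of the sigmoid's slope.

In [ ]:
#compare the analytic derivative to a numerical one
eps = 1e-6
xs = np.linspace(-5, 5, 11)
numerical = (sigmoid(xs + eps) - sigmoid(xs - eps)) / (2 * eps)
print(np.allclose(sigmoidprime(xs), numerical))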

In [373]:
size = [64, 20, 10]

#weights[n] connects layer n to layer n+1; uniform random values in [-1, 1)
weights = []
for n in range(1, len(size)):
    weights.append(np.random.rand(size[n-1], size[n]) * 2 - 1)

biases = []
for n in range(1, len(size)):
    biases.append(np.random.rand(size[n]) * 2 - 1)

trainingdata = digits.data[0:1200]
traininganswers = digits.target[0:1200]
lc = 0.02

#convert the integer answers into one-hot 10-dimensional vectors
#(the digits dataset has 1797 samples, so a hard-coded 1796 missed the last one)
traininganswervectors = np.zeros((len(digits.target), 10))
for n in range(len(digits.target)):
    traininganswervectors[n][digits.target[n]] = 1
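
As a spot check (added here, not in the original): each row of digits.data is a flattened 8x8 image, and the one-hot vector should have its 1 at the target digit.

In [ ]:
#display the first digit and its one-hot answer vector
plt.matshow(digits.data[0].reshape(8, 8), cmap='gray')
plt.show()
print(digits.target[0], traininganswervectors[0])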

In [374]:
def feedforward(a, weights, biases):
    #b[0] is the input layer; each later entry is one layer's activations
    b = [a]
    for n in range(1, len(size)):
        #weights[n-1] is (prev, this), so a whole layer is one dot product
        b.append(sigmoid_v(np.dot(b[n-1], weights[n-1]) + biases[n-1]))
    return b

In [375]:
feedforward(trainingdata[0], weights, biases)


Out[375]:
[array([  0.,   0.,   5.,  13.,   9.,   1.,   0.,   0.,   0.,   0.,  13.,
         15.,  10.,  15.,   5.,   0.,   0.,   3.,  15.,   2.,   0.,  11.,
          8.,   0.,   0.,   4.,  12.,   0.,   0.,   8.,   8.,   0.,   0.,
          5.,   8.,   0.,   0.,   9.,   8.,   0.,   0.,   4.,  11.,   0.,
          1.,  12.,   7.,   0.,   0.,   2.,  14.,   5.,  10.,  12.,   0.,
          0.,   0.,   0.,   6.,  13.,  10.,   0.,   0.,   0.]),
 array([  9.99999748e-01,   8.98553537e-04,   2.69396153e-23,
          1.00000000e+00,   1.36074434e-32,   8.18070554e-23,
          1.00000000e+00,   1.00000000e+00,   3.97928335e-20,
          1.78323070e-18,   1.53818826e-24,   9.99994182e-01,
          9.96131193e-01,   3.20678543e-01,   9.99949835e-01,
          1.00000000e+00,   1.26558022e-11,   1.50519342e-03,
          1.61018918e-13,   9.72589484e-17]),
 array([ 0.34792334,  0.26626077,  0.72389612,  0.10852067,  0.32030888,
         0.47422726,  0.9414225 ,  0.38432016,  0.94768218,  0.46550206])]
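
The last array above is the output layer. To turn it into a prediction, take the index of the largest activation; with the weights still random the guess is meaningless, but the mechanics are worth checking:

In [ ]:
output = feedforward(trainingdata[0], weights, biases)[-1]
print("predicted:", np.argmax(output), " actual:", traininganswers[0])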

In [376]:
def GradientDescent(inputs, answers, weights, biases, batchsize, lc, epochs):
    for n in range(epochs):
        #pick random locations for input/answer data
        locations = np.random.randint(0, len(inputs), batchsize)
        minibatch = []
        #create tuples (input, answer) based on the random locations
        for n2 in range(batchsize):
            minibatch.append((inputs[locations[n2]], answers[locations[n2]]))
        #one update per minibatch: train already averages over its members
        train(minibatch, weights, biases, lc)

        results = []
        for n3 in range(len(inputs)):
            results.append(feedforward(inputs[n3], weights, biases)[-1])

        accresult = accuracy(inputs, results, answers)
        print("Epoch ", n, " : ", accresult)

In [397]:
def train(minibatch, weights, biases, lc):
    #gradient accumulators with the same shapes as the parameters
    nb = [np.zeros(b.shape) for b in biases]
    nw = [np.zeros(w.shape) for w in weights]
    #largely taken from Michael Nielsen's implementation
    for i, r in minibatch:
        dnb, dnw = backprop(i, r)
        nb = [a+b for a, b in zip(nb, dnb)]
        nw = [a+b for a, b in zip(nw, dnw)]

    #update the arrays in place: rebinding the local names
    #(weights = [...]) leaves the caller's lists untouched, which is
    #why the weights never changed in earlier runs
    for w, nw_ in zip(weights, nw):
        w -= (lc/len(minibatch)) * nw_
    for b, nb_ in zip(biases, nb):
        b -= (lc/len(minibatch)) * nb_
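
The rebinding pitfall above is easy to reproduce in isolation. A minimal illustration (added here, not part of the original notebook): assignment inside a function only rebinds the local name, while in-place operations write through to the caller's arrays.

In [ ]:
def rebind(ws):
    ws = [w * 0 for w in ws]  #rebinds the local name only

def mutate(ws):
    for w in ws:
        w *= 0                #writes into the caller's arrays

demo = [np.ones(2)]
rebind(demo)
print(demo[0])  #unchanged: [ 1.  1.]
mutate(demo)
print(demo[0])  #zeroed: [ 0.  0.]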

In [378]:
def backprop(inputs, answers):
    #gradient accumulators with the same shapes as the parameters
    nb = [np.zeros(b.shape) for b in biases]
    nw = [np.zeros(w.shape) for w in weights]
    #forward pass, keeping the weighted inputs (zs) and the activations
    #(the earlier version swapped the two lists, feeding sigmoidprime
    #activations instead of weighted inputs)
    activations = [inputs]
    zs = []
    for n in range(1, len(size)):
        z = np.dot(activations[-1], weights[n-1]) + biases[n-1]
        zs.append(z)
        activations.append(sigmoid_v(z))

    #output error: dC/da at the output times sigma'(z) at the output
    delta = costderivative(activations[-1], answers) * sigmoidprime_v(zs[-1])
    nb[-1] = delta
    #weights are (prev, this), so the gradient is outer(prev activation, delta)
    nw[-1] = np.outer(activations[-2], delta)

    #propagate the error backwards through the remaining layers
    for n in range(2, len(size)):
        delta = np.dot(weights[-n+1], delta) * sigmoidprime_v(zs[-n])
        nb[-n] = delta
        nw[-n] = np.outer(activations[-n-1], delta)

    return (nb, nw)

In [379]:
def costderivative(output, answers):
    #gradient of the quadratic cost 0.5*||output - answers||^2
    return (output - answers)
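
With costderivative in place, a finite-difference check is the standard way to validate the backprop cell above: nudge one weight, measure the change in the quadratic cost, and compare against the analytic gradient. A sketch (not in the original notebook; numericalgrad_w0 is a made-up helper name):

In [ ]:
def numericalgrad_w0(i, r, row, col, eps=1e-5):
    #quadratic cost as a function of a single weight in weights[0]
    def cost():
        out = feedforward(i, weights, biases)[-1]
        return 0.5 * np.sum((out - r) ** 2)
    weights[0][row, col] += eps
    cplus = cost()
    weights[0][row, col] -= 2 * eps
    cminus = cost()
    weights[0][row, col] += eps  #restore the original weight
    return (cplus - cminus) / (2 * eps)

nb, nw = backprop(trainingdata[0], traininganswervectors[0])
print(nw[0][5, 3], numericalgrad_w0(trainingdata[0], traininganswervectors[0], 5, 3))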

In [380]:
def accuracy(inputs, results, answers):
    correct = 0
    for n in range(len(results)):
        #the network's guess is the index of the largest output; compare
        #it with the position of the 1 in the one-hot answer vector
        #(argmax avoids the earlier version's in-place mutation of results)
        if np.argmax(results[n]) == np.argmax(answers[n]):
            correct += 1
    return correct / len(results)

In [381]:
#shrink the training set to the first 100 samples for quicker experiments
trainingdata = digits.data[0:100]
traininganswers = digits.target[0:100]

traininganswervectors = np.zeros((100, 10))
for n in range(100):
    traininganswervectors[n][digits.target[n]] = 1

In [398]:
GradientDescent(trainingdata, traininganswervectors, weights, biases, 5, 0.1, 10)


Epoch  0  :  0.09
Epoch  1  :  0.09
Epoch  2  :  0.09
Epoch  3  :  0.09
Epoch  4  :  0.09
Epoch  5  :  0.09
Epoch  6  :  0.09
Epoch  7  :  0.09
Epoch  8  :  0.09
Epoch  9  :  0.09
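
Training accuracy alone says little about generalization. A held-out check on digits the network never saw is a natural next step; this sketch (not in the original, and the 100:200 slice is an arbitrary choice) reuses feedforward and accuracy:

In [ ]:
testdata = digits.data[100:200]
testvectors = np.zeros((100, 10))
for n in range(100):
    testvectors[n][digits.target[100 + n]] = 1
testresults = [feedforward(x, weights, biases)[-1] for x in testdata]
print(accuracy(testdata, testresults, testvectors))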
