This notebook implements basic softmax logistic regression on the MNIST dataset using NumPy.


In [3]:
import numpy as np
import time
import os
import struct
import matplotlib.pyplot as plt
%matplotlib inline

Run softmax over a simple example with m=4 samples, C=3 classes and n=3 inputs.

Compute the scores X * W + b with C classes, n inputs and m samples. Softmax turns these scores into class probabilities, which are compared against the one-hot labels Y.

X has dimensions (m,n), W has (n,C) and b has (1,C).

X * W gives us a matrix of size m x C, which is what we would expect, and Y is also of size m x C with one-hot encoding.

The m samples have inputs [0.6,0.3,0.1], [0.3,0.5,0.2], [0.0,0.1,0.9] and [0.1,0.6,0.4] with output classes of 0,1,2,1 respectively. Each output class is the index of the largest input value.


In [4]:
def InitExample():
    X=np.array([[0.6,0.3,0.1],[0.3,0.5,0.2],[0.0,0.1,0.9],[0.1,0.6,0.4]])
    labels=np.array([0,1,2,1])
    Y=np.eye(max(labels)+1)[labels]
    return X, Y, labels
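
The one-hot construction in InitExample works by indexing the rows of an identity matrix with the label vector. A minimal sketch of that idea, using the same labels (the names `one_hot`, `num_classes` and `recovered` are illustrative and not part of the notebook):

```python
import numpy as np

labels = np.array([0, 1, 2, 1])
num_classes = labels.max() + 1          # C = 3 for this example

# Row k of the identity matrix is the one-hot vector for class k,
# so indexing with `labels` picks one row per sample.
one_hot = np.eye(num_classes)[labels]   # shape (m, C) = (4, 3)

# argmax along each row recovers the original integer labels.
recovered = np.argmax(one_hot, axis=1)

print(one_hot)
print(recovered)
```

The same argmax step is what the notebook later uses to turn the one-hot MNIST labels back into a label vector.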

The function InitWeights takes the X and Y matrices as inputs and initializes W and b based on their dimensions.


In [5]:
def InitWeights(X, Y):
    m, n = X.shape
    C = Y.shape[1]
    assert m == Y.shape[0], "X and Y do not have the same number of samples"
    W=np.zeros([n, C])
    b=np.zeros([1, C])
    return W, b

The following two functions compute the score and softmax given X, W and b.


In [6]:
def score(X, W, b):
    return np.dot(X, W) + b

In [7]:
def softmax(scores):
    exp_scores = np.exp(scores)
    s = exp_scores / np.sum(exp_scores, axis=1, keepdims=True)
    return s
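
One caveat with this form of softmax is that np.exp can overflow for large scores, which turns the probabilities into nan. A common remedy, shown here as a sketch and not as part of the notebook's code, is to subtract each row's maximum before exponentiating. The result is unchanged because softmax is invariant to adding a constant to every score in a row. The name `stable_softmax` is hypothetical.

```python
import numpy as np

def stable_softmax(scores):
    # Shift each row so its largest score is 0; exp() then never exceeds 1.
    shifted = scores - np.max(scores, axis=1, keepdims=True)
    exp_scores = np.exp(shifted)
    return exp_scores / np.sum(exp_scores, axis=1, keepdims=True)

# Scores this large would overflow np.exp without the shift.
big = np.array([[1000.0, 1001.0, 1002.0],
                [0.0, 0.0, 0.0]])
probs = stable_softmax(big)
print(probs)
print(probs.sum(axis=1))
```

The runs in this notebook stay within a range where the plain version works, so this only matters as a safeguard for larger scores.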

Run a test to verify that we have the correct shape for W and b and that softmax is being computed correctly. Each row of softmax needs to add up to 1, and since W and b are initialized to all zeros, all the values should be the same.


In [8]:
X, Y, labels = InitExample()

print("X:", X)
print("X.shape:", X.shape)
print("Y:", Y)
print("Y.shape:", Y.shape)
print("labels:", labels)
print("labels.shape:", labels.shape)
m, n = X.shape
C = Y.shape[1]
print("n =", n, "m =", m, "C =", C)

W, b = InitWeights(X, Y)
print("W:", W)
print("W.shape:", W.shape)
print("b:", b)
print("b.shape:", b.shape)
softmax(score(X, W, b))


X: [[ 0.6  0.3  0.1]
 [ 0.3  0.5  0.2]
 [ 0.   0.1  0.9]
 [ 0.1  0.6  0.4]]
X.shape: (4, 3)
Y: [[ 1.  0.  0.]
 [ 0.  1.  0.]
 [ 0.  0.  1.]
 [ 0.  1.  0.]]
Y.shape: (4, 3)
labels: [0 1 2 1]
labels.shape: (4,)
n = 3 m = 4 C = 3
W: [[ 0.  0.  0.]
 [ 0.  0.  0.]
 [ 0.  0.  0.]]
W.shape: (3, 3)
b: [[ 0.  0.  0.]]
b.shape: (1, 3)
Out[8]:
array([[ 0.33333333,  0.33333333,  0.33333333],
       [ 0.33333333,  0.33333333,  0.33333333],
       [ 0.33333333,  0.33333333,  0.33333333],
       [ 0.33333333,  0.33333333,  0.33333333]])

The ComputeCost function takes W,X,Y,b as inputs and returns the current softmax probabilities, cost and gradients for W and b.
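
For reference, the quantities computed by ComputeCost can be written as follows. The symbols S (scores) and P (softmax probabilities, called `prob` in the code) are introduced here only for exposition:

```latex
S = XW + b, \qquad
P_{ij} = \frac{e^{S_{ij}}}{\sum_{k=1}^{C} e^{S_{ik}}}

J(W, b) = -\frac{1}{m} \sum_{i=1}^{m} \sum_{j=1}^{C} Y_{ij} \log P_{ij}

\frac{\partial J}{\partial W} = -\frac{1}{m} X^{\top} (Y - P), \qquad
\frac{\partial J}{\partial b} = -\frac{1}{m} \sum_{i=1}^{m} \left( Y_{i,:} - P_{i,:} \right)
```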


In [9]:
def ComputeCost(W, X, Y, b):
    m = X.shape[0]
    grads = {}
    prob = softmax(score(X, W, b))
    cost = (-1 / m) * np.sum(Y * np.log(prob))
    dW = (-1 / m) * np.dot(X.T,(Y - prob))
    db = (-1 / m) * np.sum(Y - prob, axis=0)
    grads['dW'] = dW 
    grads['db'] = db
    return prob, cost, grads
prob, cost, grads = ComputeCost(W, X, Y, b)
print("cost:", cost)
print("prob:", prob)
print("dW:", grads['dW'])
print("db:", grads['db'])


cost: 1.09861228867
prob: [[ 0.33333333  0.33333333  0.33333333]
 [ 0.33333333  0.33333333  0.33333333]
 [ 0.33333333  0.33333333  0.33333333]
 [ 0.33333333  0.33333333  0.33333333]]
dW: [[-0.06666667 -0.01666667  0.08333333]
 [ 0.05       -0.15        0.1       ]
 [ 0.10833333 -0.01666667 -0.09166667]]
db: [ 0.08333333 -0.16666667  0.08333333]
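
A quick way to gain confidence in the analytic gradient is to compare it with a finite-difference estimate. The sketch below does this for dW on the example data. The helper `cost_fn` and the step size `eps` are assumptions made for this illustration and are not part of the notebook.

```python
import numpy as np

def softmax(scores):
    exp_scores = np.exp(scores)
    return exp_scores / np.sum(exp_scores, axis=1, keepdims=True)

def cost_fn(W, b, X, Y):
    # Same cross-entropy cost as in ComputeCost.
    prob = softmax(np.dot(X, W) + b)
    return (-1 / X.shape[0]) * np.sum(Y * np.log(prob))

X = np.array([[0.6, 0.3, 0.1], [0.3, 0.5, 0.2], [0.0, 0.1, 0.9], [0.1, 0.6, 0.4]])
Y = np.eye(3)[np.array([0, 1, 2, 1])]
W = np.zeros((3, 3))
b = np.zeros((1, 3))
m = X.shape[0]

# Analytic gradient, same expression as in ComputeCost.
prob = softmax(np.dot(X, W) + b)
dW = (-1 / m) * np.dot(X.T, Y - prob)

# Central finite differences, one entry of W at a time.
eps = 1e-6
dW_numeric = np.zeros_like(W)
for i in range(W.shape[0]):
    for j in range(W.shape[1]):
        W_plus, W_minus = W.copy(), W.copy()
        W_plus[i, j] += eps
        W_minus[i, j] -= eps
        dW_numeric[i, j] = (cost_fn(W_plus, b, X, Y) - cost_fn(W_minus, b, X, Y)) / (2 * eps)

max_diff = np.max(np.abs(dW - dW_numeric))
print("max abs difference:", max_diff)
```

A difference close to zero suggests that the gradient expression and the cost agree with each other.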

The UpdateWeights function updates W and b based on the gradients and learning rate.


In [10]:
def UpdateWeights(W, b, grads, learningRate):
    W = W - (learningRate * grads['dW'])
    b = b - (learningRate * grads['db'])
    return W,b

The Predictions function takes X, W and b as inputs and outputs a vector of predicted class labels.


In [11]:
def Predictions(X, W, b ):
    probs = softmax(score(X, W, b))
    predictions = np.argmax(probs,axis=1)
    return predictions

The Accuracy function takes two vectors of class labels and computes the accuracy, in other words the fraction of entries that match.


In [12]:
def Accuracy(preds, labels):
    accuracy = np.mean(preds == labels)  # fraction of matching labels
    return accuracy

The TrainModel function takes as input X, Y, labels, the number of iterations and the learning rate. Y is a one-hot encoding of the labels vector. The function also takes a flag "verbose" which, if set to true, prints intermediate results. The function computes the cost and gradients in each iteration and updates W and b.

It outputs W, b, the cost history and the accuracy. The cost history consists of 50 samples taken at regular intervals over the iterations and is useful for plotting and for selecting the best learning rate.


In [13]:
def TrainModel(X, Y, labels, iterations=500, learningRate=1e-2, verbose=False):
    W, b = InitWeights(X, Y)
    costHistory = []
    start = time.time()
    for i in range(0,iterations):
        prob, cost, grads = ComputeCost(W, X, Y, b)
        W,b = UpdateWeights(W, b, grads, learningRate)
        # Integer interval; max(1, ...) also handles fewer than 50 iterations.
        if i % max(1, iterations // 50) == 0:
            costHistory.append(cost)
            if verbose:
                print("iteration:", i)
                print("cost =", cost)
                preds = Predictions(X, W, b)
                accuracy = Accuracy(preds, labels)
                print("accuracy =", accuracy)
                print("time elapsed =", time.time() - start)
    preds = Predictions(X, W, b)
    accuracy = Accuracy(preds, labels)
    return W, b, costHistory, accuracy

Train the model on the example inputs with 4 different learning rates to find the best one.


In [14]:
X, Y, labels = InitExample()
allCostHistory = {}
trainAccuracyHistory = {}
iterations=5000
for learningRate in [1e-4,1e-3,1e-2,1e-1]:
    W, b, costHistory, trainAccuracy = TrainModel(X, Y, labels, iterations, learningRate)
    trainAccuracyHistory[learningRate] = trainAccuracy
    allCostHistory[learningRate] = costHistory
    print("learningRate =", learningRate)
    print("W =", W)
    print("b = ", b)
    print("cost = ", costHistory[-1])
    print("training accuracy = ", trainAccuracy)
    print("\n\n")


learningRate = 0.0001
W = [[ 0.03420271  0.00601512 -0.04021782]
 [-0.02319332  0.0713422  -0.04814888]
 [-0.05166253  0.0048707   0.04679182]]
b =  [[-0.03663428  0.07415731 -0.03752303]]
cost =  1.05154860583
training accuracy =  0.5



learningRate = 0.001
W = [[ 0.37605304 -0.0546535  -0.32139953]
 [-0.14521265  0.5270905  -0.38187786]
 [-0.38987172 -0.10313858  0.4930103 ]]
b =  [[-0.12656411  0.3025828  -0.17601869]]
cost =  0.835099905223
training accuracy =  0.5



learningRate = 0.01
W = [[ 2.6082842  -1.18419409 -1.42409012]
 [-0.89369984  2.78744956 -1.89374972]
 [-1.78837749 -0.85382755  2.64220505]]
b =  [[ 0.10297433  0.37567159 -0.47864592]]
cost =  0.276258958105
training accuracy =  1.0



learningRate = 0.1
W = [[ 6.99242787 -4.28078502 -2.71164285]
 [-3.46650243  7.40122446 -3.93472203]
 [-3.68839194 -1.82434588  5.51273782]]
b =  [[ 0.13614253  0.56938868 -0.7055312 ]]
cost =  0.0364307166686
training accuracy =  1.0



From the output of the training runs it is clear that learning rates of 0.01 and 0.1 converge faster (last recorded costs of 0.276 and 0.036), while 0.001 and 0.0001 are too slow to converge even with 5000 iterations (costs of 0.835 and 1.052). The training accuracy goes to 1.0 with a learning rate of 0.01 or 0.1 and stays at 0.5 with the two smaller rates.

Let's now plot the cost history for all 4 learning rates over the 5000 iterations.


In [15]:
print(len(allCostHistory))
for ch in allCostHistory:
    print(len(allCostHistory[ch]),ch)
    plt.plot(allCostHistory[ch], label=ch)
plt.legend()
plt.show()


4
50 0.0001
50 0.001
50 0.01
50 0.1

The plot shows that with a learning rate of 0.1, the cost converges the fastest and 0.01 is also making good progress over 5000 iterations. 0.0001 is hardly converging and 0.001 is very slow.

Let us now run the model on the MNIST dataset.

We are going to cheat a little bit here and use TensorFlow to load the dataset instead of writing NumPy functions.
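
For readers who would like to avoid the TensorFlow dependency (as far as we know, the `tensorflow.examples.tutorials` module is not included in TensorFlow 2.x), the raw MNIST files use the simple IDX layout: a big-endian header with a magic number and one 32-bit size per dimension, followed by unsigned bytes. The sketch below parses that layout with `struct` and NumPy. The function name `read_idx` is hypothetical, and the demonstration uses a tiny synthetic buffer instead of the real files, which would also need to be decompressed with `gzip` first. Scaling the pixels to [0, 1] is an assumption about the input range the model expects.

```python
import struct
import numpy as np

def read_idx(raw):
    """Parse an IDX byte string of unsigned bytes into a NumPy array."""
    # Header: two zero bytes, a data-type code, and the number of dimensions.
    zero, dtype_code, ndim = struct.unpack(">HBB", raw[:4])
    if zero != 0 or dtype_code != 0x08:
        raise ValueError("not an unsigned-byte IDX buffer")
    # One big-endian 32-bit size per dimension follows the magic number.
    dims = struct.unpack(">" + "I" * ndim, raw[4:4 + 4 * ndim])
    data = np.frombuffer(raw, dtype=np.uint8, offset=4 + 4 * ndim)
    return data.reshape(dims)

# Synthetic "image file": 2 images of 2x2 pixels (magic number 0x00000803).
fake_images = struct.pack(">IIII", 0x00000803, 2, 2, 2) + bytes(range(8))
# Synthetic "label file": 2 labels (magic number 0x00000801).
fake_labels = struct.pack(">II", 0x00000801, 2) + bytes([7, 3])

images = read_idx(fake_images)
labels = read_idx(fake_labels)

# Flatten each image into a row and scale to [0, 1] to get an (m, n) matrix.
X_demo = images.reshape(images.shape[0], -1).astype(np.float32) / 255.0
print(images.shape, labels, X_demo.shape)
```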


In [16]:
from tensorflow.examples.tutorials.mnist import input_data

In [17]:
mnist = input_data.read_data_sets("../datasets/MNIST/", one_hot=True)


Extracting ../datasets/MNIST/train-images-idx3-ubyte.gz
Extracting ../datasets/MNIST/train-labels-idx1-ubyte.gz
Extracting ../datasets/MNIST/t10k-images-idx3-ubyte.gz
Extracting ../datasets/MNIST/t10k-labels-idx1-ubyte.gz

The TensorFlow data loader splits the 60000 images in the MNIST training set into 55000 training examples and 5000 validation examples.


In [18]:
print(mnist.train.num_examples)
print(mnist.validation.num_examples)
print(mnist.test.num_examples)


55000
5000
10000

In [19]:
X = mnist.train.images
Y = mnist.train.labels
labels = np.argmax(Y,axis=1)
print(X.shape,Y.shape,labels.shape)

W, b = InitWeights(X, Y)
print(W.shape,b.shape)


(55000, 784) (55000, 10) (55000,)
(784, 10) (1, 10)

InitWeights sets all the weights to zero, so every class gets the same score and np.argmax breaks the tie by returning the first index. Running Predictions with these values should therefore predict the "0" class for everything. If the images were uniformly distributed over the 10 classes, we would see an accuracy value of 0.1. Here we get a value of 0.0989, which is pretty close.
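
The tie-breaking behaviour can be checked on random stand-in data. In this sketch the arrays `X_demo` and `labels_demo` are synthetic placeholders, not MNIST, and the accuracy simply equals the share of samples labelled 0:

```python
import numpy as np

rng = np.random.default_rng(0)
m, n, C = 1000, 784, 10
X_demo = rng.random((m, n))                 # stand-in for image data
labels_demo = rng.integers(0, C, size=m)    # stand-in for digit labels

W = np.zeros((n, C))
b = np.zeros((1, C))

# With zero weights every score is 0, so all classes tie for every sample.
scores = np.dot(X_demo, W) + b
# np.argmax returns the first index among ties, hence class 0 everywhere.
preds = np.argmax(scores, axis=1)

accuracy = np.mean(preds == labels_demo)
share_of_zeros = np.mean(labels_demo == 0)
print(accuracy, share_of_zeros)
```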


In [20]:
preds = Predictions(X, W, b)
accuracy = Accuracy(preds, labels)
print("accuracy =", accuracy)
for i in range(10):
    print(preds[10000+i],labels[10000 + i])


accuracy = 0.0989818181818
0 5
0 1
0 1
0 3
0 6
0 2
0 6
0 3
0 3
0 1

Find the best learning rate among 0.1, 1.0 and 10 by running 100 iterations of the model for each of them.


In [21]:
allCostHistory = {}
trainAccuracyHistory = {}
iterations=100
for learningRate in [0.1,1,10]:
    W, b, costHistory, trainAccuracy = TrainModel(X, Y, labels, iterations, learningRate,verbose=True)
    trainAccuracyHistory[learningRate] = trainAccuracy
    allCostHistory[learningRate] = costHistory
    print("learningRate =", learningRate)
    print("W =", W)
    print("b = ", b)
    print("cost = ", costHistory[-1])
    print("training accuracy = ", trainAccuracy)
    print("\n\n")


iteration: 0
cost = 2.30258509299
accuracy = 0.662927272727
time elapsed = 1.2999680042266846
iteration: 2
cost = 2.10085950877
accuracy = 0.722036363636
time elapsed = 3.2730071544647217
iteration: 4
cost = 1.92974130319
accuracy = 0.744636363636
time elapsed = 5.312549829483032
iteration: 6
cost = 1.78266240095
accuracy = 0.758890909091
time elapsed = 7.004565954208374
iteration: 8
cost = 1.65629941081
accuracy = 0.7708
time elapsed = 8.681689023971558
iteration: 10
cost = 1.5477180063
accuracy = 0.779636363636
time elapsed = 11.538990020751953
iteration: 12
cost = 1.45421471527
accuracy = 0.787854545455
time elapsed = 14.34181809425354
iteration: 14
cost = 1.37338850518
accuracy = 0.794654545455
time elapsed = 16.281435012817383
iteration: 16
cost = 1.30317504621
accuracy = 0.799290909091
time elapsed = 18.571620225906372
iteration: 18
cost = 1.2418386614
accuracy = 0.804636363636
time elapsed = 20.20599603652954
iteration: 20
cost = 1.18794019755
accuracy = 0.809163636364
time elapsed = 21.88727617263794
iteration: 22
cost = 1.14029542161
accuracy = 0.812945454545
time elapsed = 23.67801022529602
iteration: 24
cost = 1.09793283905
accuracy = 0.8162
time elapsed = 25.36402201652527
iteration: 26
cost = 1.0600553798
accuracy = 0.819145454545
time elapsed = 27.975795030593872
iteration: 28
cost = 1.02600759824
accuracy = 0.821690909091
time elapsed = 30.263440132141113
iteration: 30
cost = 0.995248529971
accuracy = 0.823981818182
time elapsed = 32.57297897338867
iteration: 32
cost = 0.9673296645
accuracy = 0.826218181818
time elapsed = 35.06128406524658
iteration: 34
cost = 0.941877261124
accuracy = 0.828472727273
time elapsed = 39.588820934295654
iteration: 36
cost = 0.918578225284
accuracy = 0.830127272727
time elapsed = 42.47590923309326
iteration: 38
cost = 0.897168846329
accuracy = 0.832127272727
time elapsed = 44.285019874572754
iteration: 40
cost = 0.877425809647
accuracy = 0.833472727273
time elapsed = 46.13977003097534
iteration: 42
cost = 0.859159006649
accuracy = 0.835236363636
time elapsed = 48.670029163360596
iteration: 44
cost = 0.84220576323
accuracy = 0.836690909091
time elapsed = 50.665913105010986
iteration: 46
cost = 0.826426187808
accuracy = 0.838363636364
time elapsed = 53.031917095184326
iteration: 48
cost = 0.811699404864
accuracy = 0.839545454545
time elapsed = 55.065529108047485
iteration: 50
cost = 0.797920490914
accuracy = 0.840836363636
time elapsed = 57.384462118148804
iteration: 52
cost = 0.78499796982
accuracy = 0.841763636364
time elapsed = 60.115939140319824
iteration: 54
cost = 0.772851755335
accuracy = 0.842890909091
time elapsed = 62.67471480369568
iteration: 56
cost = 0.761411452842
accuracy = 0.844090909091
time elapsed = 64.64207410812378
iteration: 58
cost = 0.750614950903
accuracy = 0.8456
time elapsed = 66.68352317810059
iteration: 60
cost = 0.740407247715
accuracy = 0.846727272727
time elapsed = 68.75339102745056
iteration: 62
cost = 0.73073946887
accuracy = 0.847763636364
time elapsed = 70.64919304847717
iteration: 64
cost = 0.721568041615
accuracy = 0.848890909091
time elapsed = 72.483314037323
iteration: 66
cost = 0.712853997717
accuracy = 0.8496
time elapsed = 74.29001307487488
iteration: 68
cost = 0.704562382485
accuracy = 0.850436363636
time elapsed = 76.22498512268066
iteration: 70
cost = 0.696661751801
accuracy = 0.851054545455
time elapsed = 78.05432605743408
iteration: 72
cost = 0.689123742404
accuracy = 0.852018181818
time elapsed = 79.81446981430054
iteration: 74
cost = 0.681922703406
accuracy = 0.852836363636
time elapsed = 81.42067503929138
iteration: 76
cost = 0.675035379201
accuracy = 0.853454545455
time elapsed = 83.19182586669922
iteration: 78
cost = 0.66844063565
accuracy = 0.854127272727
time elapsed = 84.9247579574585
iteration: 80
cost = 0.662119222869
accuracy = 0.855018181818
time elapsed = 86.66642093658447
iteration: 82
cost = 0.656053569066
accuracy = 0.855872727273
time elapsed = 88.37172722816467
iteration: 84
cost = 0.650227600818
accuracy = 0.856327272727
time elapsed = 89.98332500457764
iteration: 86
cost = 0.644626585934
accuracy = 0.856872727273
time elapsed = 91.91024684906006
iteration: 88
cost = 0.639236995681
accuracy = 0.857436363636
time elapsed = 94.34327816963196
iteration: 90
cost = 0.63404638364
accuracy = 0.858
time elapsed = 97.79269909858704
iteration: 92
cost = 0.629043278932
accuracy = 0.858527272727
time elapsed = 101.23467803001404
iteration: 94
cost = 0.624217091838
accuracy = 0.859036363636
time elapsed = 103.32199001312256
iteration: 96
cost = 0.619558030191
accuracy = 0.859472727273
time elapsed = 105.76767301559448
iteration: 98
cost = 0.615057025131
accuracy = 0.860054545455
time elapsed = 107.89231204986572
learningRate = 0.1
W = [[ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 ..., 
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]]
b =  [[-0.04946891  0.10633936 -0.02627025 -0.03055497  0.0288012   0.0567122
  -0.01231484  0.04984078 -0.10961898 -0.0134656 ]]
cost =  0.615057025131
training accuracy =  0.860309090909



iteration: 0
cost = 2.30258509299
accuracy = 0.662927272727
time elapsed = 1.5737671852111816
iteration: 2
cost = 1.40019549247
accuracy = 0.505163636364
time elapsed = 3.7213292121887207
iteration: 4
cost = 1.73287619432
accuracy = 0.591290909091
time elapsed = 5.88126015663147
iteration: 6
cost = 1.39681274813
accuracy = 0.703436363636
time elapsed = 8.520270109176636
iteration: 8
cost = 0.889096711709
accuracy = 0.6682
time elapsed = 10.528398036956787
iteration: 10
cost = 1.10687543674
accuracy = 0.732490909091
time elapsed = 12.855698108673096
iteration: 12
cost = 0.774038098858
accuracy = 0.778927272727
time elapsed = 14.80822491645813
iteration: 14
cost = 0.702809050579
accuracy = 0.788272727273
time elapsed = 16.873213291168213
iteration: 16
cost = 0.638340156308
accuracy = 0.802854545455
time elapsed = 18.937928199768066
iteration: 18
cost = 0.634254149139
accuracy = 0.772818181818
time elapsed = 21.01507806777954
iteration: 20
cost = 0.630632074221
accuracy = 0.780127272727
time elapsed = 23.04680609703064
iteration: 22
cost = 0.584287464717
accuracy = 0.807418181818
time elapsed = 25.061519145965576
iteration: 24
cost = 0.542559042124
accuracy = 0.830018181818
time elapsed = 26.874672174453735
iteration: 26
cost = 0.506918796903
accuracy = 0.849909090909
time elapsed = 29.452399015426636
iteration: 28
cost = 0.477181456502
accuracy = 0.863454545455
time elapsed = 31.216227054595947
iteration: 30
cost = 0.454160427108
accuracy = 0.872890909091
time elapsed = 33.62435507774353
iteration: 32
cost = 0.437498463173
accuracy = 0.879327272727
time elapsed = 35.86733102798462
iteration: 34
cost = 0.425762528162
accuracy = 0.883109090909
time elapsed = 37.88940405845642
iteration: 36
cost = 0.417334264754
accuracy = 0.885763636364
time elapsed = 40.20981311798096
iteration: 38
cost = 0.41097232372
accuracy = 0.887181818182
time elapsed = 42.30674910545349
iteration: 40
cost = 0.405882711453
accuracy = 0.888181818182
time elapsed = 44.36179709434509
iteration: 42
cost = 0.401596514748
accuracy = 0.889272727273
time elapsed = 46.329506158828735
iteration: 44
cost = 0.397843121543
accuracy = 0.890636363636
time elapsed = 48.6947660446167
iteration: 46
cost = 0.394464791597
accuracy = 0.891636363636
time elapsed = 51.73757219314575
iteration: 48
cost = 0.391366783685
accuracy = 0.892654545455
time elapsed = 54.83999729156494
iteration: 50
cost = 0.388489772139
accuracy = 0.893490909091
time elapsed = 58.342057943344116
iteration: 52
cost = 0.38579478843
accuracy = 0.894254545455
time elapsed = 61.36375308036804
iteration: 54
cost = 0.383254923872
accuracy = 0.8946
time elapsed = 63.861733198165894
iteration: 56
cost = 0.380850658818
accuracy = 0.8952
time elapsed = 66.98782992362976
iteration: 58
cost = 0.37856716282
accuracy = 0.895909090909
time elapsed = 70.8203330039978
iteration: 60
cost = 0.376392691773
accuracy = 0.896236363636
time elapsed = 73.2712082862854
iteration: 62
cost = 0.374317612392
accuracy = 0.896818181818
time elapsed = 75.5718560218811
iteration: 64
cost = 0.372333794492
accuracy = 0.897327272727
time elapsed = 77.90771198272705
iteration: 66
cost = 0.370434223279
accuracy = 0.897745454545
time elapsed = 80.1128191947937
iteration: 68
cost = 0.368612745242
accuracy = 0.898327272727
time elapsed = 82.13282918930054
iteration: 70
cost = 0.366863896187
accuracy = 0.899018181818
time elapsed = 84.80580806732178
iteration: 72
cost = 0.365182780429
accuracy = 0.899545454545
time elapsed = 87.81743216514587
iteration: 74
cost = 0.363564982387
accuracy = 0.899854545455
time elapsed = 91.00881099700928
iteration: 76
cost = 0.362006499217
accuracy = 0.9002
time elapsed = 93.50413799285889
iteration: 78
cost = 0.360503687548
accuracy = 0.9006
time elapsed = 96.53227019309998
iteration: 80
cost = 0.359053220075
accuracy = 0.900890909091
time elapsed = 99.38571429252625
iteration: 82
cost = 0.357652049349
accuracy = 0.901127272727
time elapsed = 102.75490927696228
iteration: 84
cost = 0.35629737704
accuracy = 0.901436363636
time elapsed = 105.77681612968445
iteration: 86
cost = 0.354986627531
accuracy = 0.901745454545
time elapsed = 108.4348258972168
iteration: 88
cost = 0.353717425027
accuracy = 0.902
time elapsed = 112.04065704345703
iteration: 90
cost = 0.352487573582
accuracy = 0.902218181818
time elapsed = 114.99251413345337
iteration: 92
cost = 0.351295039583
accuracy = 0.902636363636
time elapsed = 117.54528212547302
iteration: 94
cost = 0.350137936323
accuracy = 0.902927272727
time elapsed = 119.97918891906738
iteration: 96
cost = 0.349014510372
accuracy = 0.903090909091
time elapsed = 121.94662714004517
iteration: 98
cost = 0.347923129484
accuracy = 0.9034
time elapsed = 123.97607493400574
learningRate = 1
W = [[ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 ..., 
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]]
b =  [[-0.15143402  0.24221035 -0.02417535 -0.11578452  0.06370963  0.45914952
  -0.03383453  0.2370799  -0.59492888 -0.0819921 ]]
cost =  0.347923129484
training accuracy =  0.903581818182



iteration: 0
cost = 2.30258509299
accuracy = 0.662927272727
time elapsed = 1.489305019378662
iteration: 2
cost = 20.2804727806
accuracy = 0.485381818182
time elapsed = 4.092714071273804
iteration: 4
cost = 41.1168051189
accuracy = 0.404345454545
time elapsed = 6.700016021728516
iteration: 6
cost = 40.0070815612
accuracy = 0.280381818182
time elapsed = 9.517518043518066
iteration: 8
cost = 39.1940183988
accuracy = 0.460690909091
time elapsed = 11.543292999267578
iteration: 10
cost = 26.9832349953
accuracy = 0.544218181818
time elapsed = 13.556303977966309
iteration: 12
cost = 13.709421489
accuracy = 0.732018181818
time elapsed = 15.715801000595093
iteration: 14
cost = 3.49892593543
accuracy = 0.809145454545
time elapsed = 18.30279302597046
iteration: 16
cost = 2.52176784557
accuracy = 0.838854545455
time elapsed = 20.19260811805725
iteration: 18
cost = 1.99180647781
accuracy = 0.859818181818
time elapsed = 22.329620122909546
iteration: 20
cost = 1.70513524326
accuracy = 0.8698
time elapsed = 25.011372089385986
iteration: 22
cost = 1.57079727518
accuracy = 0.875145454545
time elapsed = 27.14105987548828
iteration: 24
cost = 1.47754520425
accuracy = 0.878272727273
time elapsed = 29.15275812149048
iteration: 26
cost = 1.39924156895
accuracy = 0.881563636364
time elapsed = 31.23005509376526
iteration: 28
cost = 1.33084528759
accuracy = 0.884090909091
time elapsed = 33.40209603309631
iteration: 30
cost = 1.27202479475
accuracy = 0.886109090909
time elapsed = 35.68943810462952
iteration: 32
cost = 1.2263910727
accuracy = 0.886636363636
time elapsed = 37.7808620929718
iteration: 34
cost = 1.21055655529
accuracy = 0.880890909091
time elapsed = 40.56387996673584
iteration: 36
cost = 1.3475358968
accuracy = 0.844163636364
time elapsed = 43.22394919395447
iteration: 38
cost = 3.19951085125
accuracy = 0.733963636364
time elapsed = 45.66482591629028
iteration: 40
cost = 7.21955821233
accuracy = 0.836509090909
time elapsed = 49.22149300575256
iteration: 42
cost = 1.92240658339
accuracy = 0.796
time elapsed = 52.98125600814819
iteration: 44
cost = 4.62247402041
accuracy = 0.782236363636
time elapsed = 56.002532958984375
iteration: 46
cost = 1.91896698291
accuracy = 0.835327272727
time elapsed = 58.994791984558105
iteration: 48
cost = 2.66923091704
accuracy = 0.799636363636
time elapsed = 60.88868808746338
iteration: 50
cost = 2.99038436717
accuracy = 0.825527272727
time elapsed = 62.73026895523071
iteration: 52
cost = 2.30321724046
accuracy = 0.830163636364
time elapsed = 64.70345711708069
iteration: 54
cost = 2.11832642236
accuracy = 0.840054545455
time elapsed = 68.71608781814575
iteration: 56
cost = 1.80498666809
accuracy = 0.858854545455
time elapsed = 70.45603489875793
iteration: 58
cost = 1.42769378765
accuracy = 0.878909090909
time elapsed = 72.1028470993042
iteration: 60
cost = 1.17756534961
accuracy = 0.892527272727
time elapsed = 73.88221192359924
iteration: 62
cost = 1.06054748345
accuracy = 0.896745454545
time elapsed = 75.60078191757202
iteration: 64
cost = 0.999395499586
accuracy = 0.899472727273
time elapsed = 77.32326889038086
iteration: 66
cost = 0.957014958327
accuracy = 0.9012
time elapsed = 80.02597689628601
iteration: 68
cost = 0.923084827946
accuracy = 0.902309090909
time elapsed = 83.03808808326721
iteration: 70
cost = 0.894644326612
accuracy = 0.902709090909
time elapsed = 85.25780701637268
iteration: 72
cost = 0.872026679073
accuracy = 0.9026
time elapsed = 87.41238403320312
iteration: 74
cost = 0.858987236753
accuracy = 0.901563636364
time elapsed = 89.55886507034302
iteration: 76
cost = 0.867820955905
accuracy = 0.896290909091
time elapsed = 91.40787696838379
iteration: 78
cost = 0.955939003306
accuracy = 0.8712
time elapsed = 95.66158699989319
iteration: 80
cost = 1.82842153602
accuracy = 0.795945454545
time elapsed = 98.79932022094727
iteration: 82
cost = 2.18756704866
accuracy = 0.677781818182
time elapsed = 101.56900906562805
iteration: 84
cost = 10.6097008819
accuracy = 0.750563636364
time elapsed = 105.11271619796753
iteration: 86
cost = 3.23380987149
accuracy = 0.804563636364
time elapsed = 108.05067610740662
iteration: 88
cost = 2.15756439831
accuracy = 0.824309090909
time elapsed = 109.9307210445404
iteration: 90
cost = 3.028721887
accuracy = 0.848490909091
time elapsed = 113.25370812416077
iteration: 92
cost = 1.97713992732
accuracy = 0.817763636364
time elapsed = 116.42532300949097
iteration: 94
cost = 2.87637749534
accuracy = 0.801436363636
time elapsed = 119.12271404266357
iteration: 96
cost = 2.91636743806
accuracy = 0.818781818182
time elapsed = 122.07815790176392
iteration: 98
cost = 3.23256898863
accuracy = 0.813
time elapsed = 125.70973300933838
learningRate = 10
W = [[ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 ..., 
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]]
b =  [[-1.07677289  0.82992739  0.60228153 -0.48873757  0.37426064  3.74550035
   0.05545659  1.9182902  -5.345129   -0.61507724]]
cost =  3.23256898863
training accuracy =  0.855672727273



The plot shows that with a learning rate of 10, the cost does not go down monotonically and keeps oscillating, which points to the rate being too high. Learning rates of 0.1 and 1 show similar behaviour, with 1 slightly better in terms of lower cost (0.348 versus 0.615 at iteration 98). So we will use a learning rate of 1 to run the model for more epochs until we get good accuracy.
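
The effect of an overly large step can be seen on a toy problem. The sketch below runs gradient descent on f(w) = w**2, which is far simpler than the softmax cost and only illustrates overshooting in general. The function `descend` and the chosen rates are assumptions made for this illustration.

```python
def descend(learning_rate, steps=10, w0=1.0):
    """Gradient descent on f(w) = w**2, whose gradient is 2*w."""
    w = w0
    costs = []
    for _ in range(steps):
        costs.append(w ** 2)
        w = w - learning_rate * 2 * w
    return costs

small = descend(0.1)    # each step multiplies w by 0.8: steady decrease
large = descend(0.9)    # each step multiplies w by -0.8: w flips sign every step
too_big = descend(1.1)  # each step multiplies w by -1.2: the cost grows

print(small[-1], large[-1], too_big[-1])
```

In this toy case a rate that is too large jumps past the minimum on every step, and beyond a threshold the cost increases instead of decreasing.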


In [22]:
for ch in allCostHistory:
    plt.plot(allCostHistory[ch], label=ch)
plt.legend()
plt.show()


Run the model for 1000 epochs on the training set.


In [20]:
allCostHistory = {}
trainAccuracyHistory = {}
iterations=1000
for learningRate in [1]:
    W, b, costHistory, trainAccuracy = TrainModel(X, Y, labels, iterations, learningRate,verbose=True)
    trainAccuracyHistory[learningRate] = trainAccuracy
    allCostHistory[learningRate] = costHistory
    print("learningRate =", learningRate)
    print("W =", W)
    print("b = ", b)
    print("cost = ", costHistory[-1])
    print("training accuracy = ", trainAccuracy)
    print("\n\n")
for ch in allCostHistory:
    plt.plot(allCostHistory[ch], label=ch)
plt.legend()
plt.show()


iteration: 0
cost = 2.30258509299
accuracy = 0.662927272727
time elapsed = 1.7432122230529785
iteration: 20
cost = 0.630632074221
accuracy = 0.780127272727
time elapsed = 15.271590232849121
iteration: 40
cost = 0.405882711453
accuracy = 0.888181818182
time elapsed = 27.892492055892944
iteration: 60
cost = 0.376392691773
accuracy = 0.896236363636
time elapsed = 40.44439721107483
iteration: 80
cost = 0.359053220075
accuracy = 0.900890909091
time elapsed = 53.15896224975586
iteration: 100
cost = 0.346862271846
accuracy = 0.903818181818
time elapsed = 65.97528910636902
iteration: 120
cost = 0.337660104572
accuracy = 0.906363636364
time elapsed = 78.82346820831299
iteration: 140
cost = 0.33038492847
accuracy = 0.908272727273
time elapsed = 91.37506413459778
iteration: 160
cost = 0.324440572934
accuracy = 0.910127272727
time elapsed = 104.74938917160034
iteration: 180
cost = 0.319461734922
accuracy = 0.911581818182
time elapsed = 117.3792781829834
iteration: 200
cost = 0.315209988465
accuracy = 0.912854545455
time elapsed = 130.63322615623474
iteration: 220
cost = 0.311522052898
accuracy = 0.913818181818
time elapsed = 143.22618317604065
iteration: 240
cost = 0.308281782854
accuracy = 0.914781818182
time elapsed = 155.76766419410706
iteration: 260
cost = 0.305403979085
accuracy = 0.915527272727
time elapsed = 168.4058141708374
iteration: 280
cost = 0.302824529928
accuracy = 0.916054545455
time elapsed = 180.8030662536621
iteration: 300
cost = 0.300494144978
accuracy = 0.916636363636
time elapsed = 193.45475816726685
iteration: 320
cost = 0.298374226674
accuracy = 0.9172
time elapsed = 206.29236817359924
iteration: 340
cost = 0.296434066358
accuracy = 0.917836363636
time elapsed = 219.08929324150085
iteration: 360
cost = 0.294648889456
accuracy = 0.918218181818
time elapsed = 231.80854415893555
iteration: 380
cost = 0.292998461364
accuracy = 0.918527272727
time elapsed = 244.7593011856079
iteration: 400
cost = 0.29146607325
accuracy = 0.918927272727
time elapsed = 257.7118830680847
iteration: 420
cost = 0.290037791179
accuracy = 0.919327272727
time elapsed = 270.89046812057495
iteration: 440
cost = 0.288701891429
accuracy = 0.919672727273
time elapsed = 283.5819251537323
iteration: 460
cost = 0.287448429844
accuracy = 0.920090909091
time elapsed = 296.2090051174164
iteration: 480
cost = 0.286268909232
accuracy = 0.920654545455
time elapsed = 308.74913930892944
iteration: 500
cost = 0.285156019515
accuracy = 0.920945454545
time elapsed = 322.857360124588
iteration: 520
cost = 0.284103432597
accuracy = 0.921127272727
time elapsed = 341.2735800743103
iteration: 540
cost = 0.283105638851
accuracy = 0.9214
time elapsed = 362.0401883125305
iteration: 560
cost = 0.282157815638
accuracy = 0.921690909091
time elapsed = 379.866571187973
iteration: 580
cost = 0.281255720704
accuracy = 0.921890909091
time elapsed = 396.413702249527
iteration: 600
cost = 0.280395605105
accuracy = 0.922054545455
time elapsed = 409.76110911369324
iteration: 620
cost = 0.279574141562
accuracy = 0.922272727273
time elapsed = 422.18694710731506
iteration: 640
cost = 0.278788365142
accuracy = 0.9224
time elapsed = 434.48607325553894
iteration: 660
cost = 0.27803562382
accuracy = 0.922836363636
time elapsed = 447.0326442718506
iteration: 680
cost = 0.277313537055
accuracy = 0.923109090909
time elapsed = 459.47134733200073
iteration: 700
cost = 0.276619960882
accuracy = 0.923327272727
time elapsed = 471.79906940460205
iteration: 720
cost = 0.275952958342
accuracy = 0.923509090909
time elapsed = 484.3128662109375
iteration: 740
cost = 0.275310774322
accuracy = 0.9236
time elapsed = 498.205126285553
iteration: 760
cost = 0.27469181404
accuracy = 0.923709090909
time elapsed = 510.48304414749146
iteration: 780
cost = 0.274094624569
accuracy = 0.924018181818
time elapsed = 522.9618771076202
iteration: 800
cost = 0.273517878909
accuracy = 0.924090909091
time elapsed = 535.485894203186
iteration: 820
cost = 0.272960362195
accuracy = 0.924127272727
time elapsed = 547.8997843265533
iteration: 840
cost = 0.272420959723
accuracy = 0.924345454545
time elapsed = 560.344306230545
iteration: 860
cost = 0.271898646497
accuracy = 0.924527272727
time elapsed = 572.9929621219635
iteration: 880
cost = 0.271392478094
accuracy = 0.924690909091
time elapsed = 585.5207452774048
iteration: 900
cost = 0.270901582641
accuracy = 0.924927272727
time elapsed = 597.963561296463
iteration: 920
cost = 0.270425153757
accuracy = 0.925163636364
time elapsed = 610.4837062358856
iteration: 940
cost = 0.269962444313
accuracy = 0.925254545455
time elapsed = 622.7845861911774
iteration: 960
cost = 0.269512760912
accuracy = 0.925381818182
time elapsed = 634.9940581321716
iteration: 980
cost = 0.26907545899
accuracy = 0.925490909091
time elapsed = 647.3423731327057
learningRate = 1
W = [[ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 ..., 
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]]
b =  [[-0.52863148  0.4338832   0.15959347 -0.353872   -0.03684081  1.72875834
  -0.15982544  0.85336372 -1.78453263 -0.31189636]]
cost =  0.26907545899
training accuracy =  0.925581818182



The training accuracy got to around 92.5% and was making only very small improvements by the end. For comparison, the training accuracy hit 91.6% after 300 epochs and 92% after 460 epochs. It took 659.55 sec to run the model for 1000 epochs.

Let's look at a few predictions and try to find a misprediction.


In [25]:
preds = Predictions(X, W, b)
accuracy = Accuracy(preds, labels)
print("training accuracy =", accuracy)
for i in range(10):
    print("prediction =",preds[30020+i],"label =",labels[30020 + i])


training accuracy = 0.925581818182
prediction = 2 label = 2
prediction = 8 label = 8
prediction = 3 label = 3
prediction = 2 label = 2
prediction = 9 label = 9
prediction = 8 label = 8
prediction = 4 label = 6
prediction = 6 label = 6
prediction = 0 label = 0
prediction = 5 label = 5

Training example 30026 was incorrectly classified as a "4" instead of a "6". Looking at the image shows why the classifier might have gotten it wrong. Example 30027, which was correctly classified as a "6", is much easier to classify.
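
Instead of scanning the printed pairs by eye, mispredictions can be located directly. The sketch below uses the ten prediction and label pairs printed above as stand-in arrays. In the notebook the same expression would be applied to the full `preds` and `labels` vectors.

```python
import numpy as np

# Stand-in arrays holding the ten pairs printed for indices 30020 to 30029.
preds_demo = np.array([2, 8, 3, 2, 9, 8, 4, 6, 0, 5])
labels_demo = np.array([2, 8, 3, 2, 9, 8, 6, 6, 0, 5])

# Indices where prediction and label disagree.
wrong = np.flatnonzero(preds_demo != labels_demo)

for idx in wrong:
    print("offset", idx, "predicted", preds_demo[idx], "label", labels_demo[idx])
```

Offset 6 within this slice corresponds to training example 30026.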


In [26]:
plt.imshow(mnist.train.images[30026].reshape(28,28),cmap="Greys")
plt.show()
plt.imshow(mnist.train.images[30027].reshape(28,28),cmap="Greys")
plt.show()


Compute the accuracy of predictions on the validation and test sets.


In [27]:
XV = mnist.validation.images
YV = mnist.validation.labels
labelsV = np.argmax(YV,axis=1)
predsV = Predictions(XV, W, b)
accuracyV = Accuracy(predsV, labelsV)
print("validation accuracy =", accuracyV)


validation accuracy = 0.9274

In [28]:
XT = mnist.test.images
YT = mnist.test.labels
labelsT = np.argmax(YT,axis=1)
predsT = Predictions(XT, W, b)
accuracyT = Accuracy(predsT, labelsT)
print("Test accuracy =", accuracyT)


Test accuracy = 0.9228

The validation accuracy (0.9274) and test accuracy (0.9228) are pretty close to the training accuracy (0.9256), which indicates that there was no overfitting. Running the model longer might improve the accuracy further, but only at a very slow rate, so it does not seem worthwhile.

We did achieve the same accuracy as the TensorFlow softmax tutorial, except that it took 200X the time and number of epochs to get there.

The next step will be to implement mini-batch gradient descent to move towards the minimum faster with fewer epochs.
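
As a rough preview of that next step, here is a minimal sketch of mini-batch gradient descent on the small example from the beginning of the notebook. The function `train_minibatch`, its defaults (`epochs`, `batch_size`, `learning_rate`, `seed`) and the max-shifted softmax are assumptions made for this illustration, not the planned implementation.

```python
import numpy as np

def softmax(scores):
    # Shift by the row maximum for numerical stability; the result is unchanged.
    exp_scores = np.exp(scores - np.max(scores, axis=1, keepdims=True))
    return exp_scores / np.sum(exp_scores, axis=1, keepdims=True)

def train_minibatch(X, Y, epochs=200, batch_size=2, learning_rate=0.1, seed=0):
    rng = np.random.default_rng(seed)
    m, n = X.shape
    C = Y.shape[1]
    W = np.zeros((n, C))
    b = np.zeros((1, C))
    for _ in range(epochs):
        order = rng.permutation(m)              # reshuffle every epoch
        for start in range(0, m, batch_size):
            idx = order[start:start + batch_size]
            Xb, Yb = X[idx], Y[idx]
            prob = softmax(np.dot(Xb, W) + b)
            # Same gradient expressions as ComputeCost, averaged over the batch.
            W -= learning_rate * (-1 / len(idx)) * np.dot(Xb.T, Yb - prob)
            b -= learning_rate * (-1 / len(idx)) * np.sum(Yb - prob, axis=0)
    return W, b

X = np.array([[0.6, 0.3, 0.1], [0.3, 0.5, 0.2], [0.0, 0.1, 0.9], [0.1, 0.6, 0.4]])
labels = np.array([0, 1, 2, 1])
Y = np.eye(3)[labels]

W, b = train_minibatch(X, Y)
prob = softmax(np.dot(X, W) + b)
cost = -np.mean(np.sum(Y * np.log(prob), axis=1))
print("cost:", cost)
```

Each epoch here performs several weight updates instead of one, which is the property that mini-batching is meant to exploit on a dataset the size of MNIST.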