Backpropagation

This notebook is meant to deepen the understanding of backpropagation and (stochastic) gradient descent in neural networks.

Softmax Linear Classifier

We start with a linear classifier, then move to a 2-layer neural network.


In [62]:
import numpy as np
from matplotlib import pyplot as plt

Normally we would want to preprocess the dataset so that each feature has zero mean and unit standard deviation, but in this case the features are already in a nice range from -1 to 1, so we skip this step.
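
For reference, a minimal sketch of the standardization we would otherwise apply (illustrative only, assuming a data matrix X with one example per row):

In [ ]:
# standardize each feature: subtract its mean, divide by its standard deviation
X = (X - X.mean(axis=0)) / X.std(axis=0)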


In [75]:
# Generate a spiral dataset
N = 100  # number of points per class
D = 2  # dimensionality
K = 3  # number of classes
X = np.zeros((N * K, D))  # data matrix (each row = single example)
y = np.zeros(N * K, dtype='uint8')  # class labels
for j in range(K):
    ix = range(N * j, N * (j + 1))
    r = np.linspace(0.0, 1, N)  # radius
    t = np.linspace(j * 4, (j + 1) * 4, N) + np.random.randn(N) * 0.2  # theta
    X[ix] = np.c_[r * np.sin(t), r * np.cos(t)]
    y[ix] = j
# let's visualize the data:
plt.scatter(X[:, 0], X[:, 1], c=y, s=40, cmap=plt.cm.Spectral)
plt.show()



In [76]:
# initialize parameters randomly
W = 0.01 * np.random.randn(D, K)
b = np.zeros((1, K))

In [77]:
scores = np.dot(X, W) + b
scores


Out[77]:
array([[  0.00000000e+00,   0.00000000e+00,   0.00000000e+00],
       [ -7.81081207e-05,  -4.05397930e-06,   1.26890690e-04],
       [ -1.43432414e-04,  -1.61111533e-05,   2.28869536e-04],
       ...,
       [ -7.94829171e-03,   3.00303647e-04,   1.32532441e-02],
       [ -8.03795126e-03,   2.71183671e-04,   1.33872024e-02]])

Compute the Loss for the Softmax classifier:

$L_i = -\log\left(\frac{e^{f_{y_i}}}{ \sum_j e^{f_j} }\right)$

The Softmax classifier interprets each element of $f$ as holding the (unnormalized) log probability of one of the three classes. We exponentiate these to get (unnormalized) probabilities and then normalize them to obtain proper probabilities.

Since $-\log(x)$ goes to infinity as $x \to 0$ and to $0$ as $x \to 1$, the loss is high when the probability of the correct class is small and low when it is large.
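
A quick numeric illustration (my addition): $-\log(0.9) \approx 0.105$ while $-\log(0.1) \approx 2.303$, so assigning high probability to the correct class is cheap and assigning low probability is expensive.

In [ ]:
# loss for a confident-correct vs. a confident-wrong prediction
-np.log(0.9), -np.log(0.1)  # ≈ (0.105, 2.303)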

The full Softmax classifier loss is then defined as the average cross-entropy loss over all training examples:

$ L = \underbrace{ \frac{1}{N} \sum_i L_i }_\text{data loss} + \underbrace{ \frac{1}{2} \lambda \sum_k\sum_l W_{k,l}^2 }_\text{regularization loss} $


In [78]:
scores.shape


Out[78]:
(300, 3)

In [79]:
scores[:4]


Out[79]:
array([[  0.00000000e+00,   0.00000000e+00,   0.00000000e+00],
       [ -7.81081207e-05,  -4.05397930e-06,   1.26890690e-04],
       [ -1.43432414e-04,  -1.61111533e-05,   2.28869536e-04],
       [ -1.91544072e-04,  -3.35339727e-05,   2.99892906e-04]])

In [80]:
# compute loss of the scores
num_examples = X.shape[0]
# get unnormalized probabilities
exp_scores = np.exp(scores)
# normalize them for each example
probs = exp_scores / np.sum(exp_scores, axis=1, keepdims=True)
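
A side note (my addition, not in the original cell): np.exp can overflow for large scores. Since the softmax is invariant to shifting each row of scores by a constant, the standard fix is to subtract the per-row maximum before exponentiating:

In [ ]:
# numerically safer softmax: shift each row so its maximum is 0 (softmax is shift-invariant)
shifted = scores - np.max(scores, axis=1, keepdims=True)
probs_stable = np.exp(shifted) / np.sum(np.exp(shifted), axis=1, keepdims=True)
np.allclose(probs, probs_stable)  # -> True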

In [81]:
# each row contains the class probabilities
probs[:4]


Out[81]:
array([[ 0.33333333,  0.33333333,  0.33333333],
       [ 0.33330233,  0.33332701,  0.33337066],
       [ 0.33327782,  0.33332026,  0.33340192],
       [ 0.33326117,  0.33331384,  0.33342499]])

In [82]:
# get the log probabilities of the actual classes
# the paired-index ("fancy") indexing works as follows: np.array([...])[[ROW_INDICES], [COL_INDICES]]
correct_probs = probs[range(num_examples), y]
correct_logprobs = -np.log(correct_probs)
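
A tiny illustration of this paired-index lookup (the example values are mine):

In [ ]:
a = np.array([[10, 11], [20, 21]])
a[[0, 1], [1, 0]]  # -> array([11, 20]), i.e. the elements a[0, 1] and a[1, 0]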

In [83]:
reg = 0.5 # regularization strength
# compute the loss: average cross-entropy loss and regularization
data_loss = np.sum(correct_logprobs) / num_examples
reg_loss = 0.5*reg*np.sum(W*W)
loss = data_loss + reg_loss
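
A useful sanity check (my addition): with small random weights the class probabilities are nearly uniform, so the initial data loss should be close to $-\log(1/3) \approx 1.0986$ for three classes; the training loop below indeed starts at about 1.098.

In [ ]:
# at initialization both values should be ≈ 1.0986
print(data_loss, -np.log(1.0 / K))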

Computing the analytic gradient with backpropagation.

The loss for a single example is:

$ p_k = \frac{e^{f_k}}{ \sum_j e^{f_j} } \hspace{1in} L_i =-\log\left(p_{y_i}\right) $

We now want to understand how the computed scores inside $f$ should change to decrease the loss $L_i$. In other words, we want to derive the gradient $ \partial L_i / \partial f_k $.

Chain rule:

$ \frac{\partial L_i}{\partial f_k} = \frac{\partial L_i}{\partial p_{y_i}} \frac{\partial p_{y_i}}{\partial f_k} $

Working through the derivative of the softmax, this simplifies to the remarkably clean expression

$ \frac{\partial L_i }{ \partial f_k } = p_k - \mathbb{1}(y_i = k) $

That means for probabilities $p = [0.2, 0.3, 0.5]$, with the middle class being the correct one, the gradient on the scores would be $df = [0.2, -0.7, 0.5]$.
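
This formula is easy to verify numerically. A small sketch I added (the example scores f and class index yk are illustrative):

In [ ]:
# check dL/df_k = p_k - 1(y_i = k) against a finite-difference approximation
f = np.array([1.0, 2.0, 3.0])
yk = 1  # index of the correct class
p = np.exp(f) / np.sum(np.exp(f))
analytic = p.copy()
analytic[yk] -= 1
h = 1e-6
numeric = np.zeros_like(f)
for k in range(3):
    fh = f.copy()
    fh[k] += h
    ph = np.exp(fh) / np.sum(np.exp(fh))
    numeric[k] = (-np.log(ph[yk]) + np.log(p[yk])) / h
# analytic and numeric should agree to roughly the step size h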


In [90]:
# probs holds the class probabilities of every example (one row per example)
dscores = np.copy(probs)
dscores[range(num_examples), y] -= 1  # apply the formula derived above: p_k - 1 for the correct class
# average the gradient over all examples
dscores /= num_examples

Note that the regularization gradient has the very simple form reg*W because we used the constant 0.5 in its loss contribution (i.e. $ \frac{d}{dw} \left( \frac{1}{2} \lambda w^2 \right) = \lambda w $).


In [91]:
# backpropagate into W and b
dW = np.dot(X.T, dscores)
db = np.sum(dscores, axis=0, keepdims=True)
dW += reg*W # don't forget the regularization gradient
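
A quick shape check (my addition) shows why the transpose is needed: X is (300, 2) and dscores is (300, 3), so X.T.dot(dscores) is (2, 3), matching W; summing dscores over the examples gives (1, 3), matching b.

In [ ]:
assert dW.shape == W.shape and db.shape == b.shape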

In [93]:
step_size = 1e-0
# Perform a parameter update in the negative gradient direction to decrease loss!
W += -step_size * dW
b += -step_size * db

In [94]:
# putting it all together
# initialize parameters randomly
W = 0.01 * np.random.randn(D, K)
b = np.zeros((1, K))

# some hyperparameters
step_size = 1e-0
reg = 1e-3  # regularization strength

# gradient descent loop
num_examples = X.shape[0]
for i in range(200):

    # evaluate class scores, [N x K]
    scores = np.dot(X, W) + b

    # compute the class probabilities
    exp_scores = np.exp(scores)
    probs = exp_scores / np.sum(exp_scores, axis=1, keepdims=True)  # [N x K]

    # compute the loss: average cross-entropy loss and regularization
    correct_logprobs = -np.log(probs[range(num_examples), y])
    data_loss = np.sum(correct_logprobs) / num_examples
    reg_loss = 0.5 * reg * np.sum(W * W)
    loss = data_loss + reg_loss
    if i % 10 == 0:
        print("iteration %d: loss %f" % (i, loss))

    # compute the gradient on scores
    dscores = probs
    dscores[range(num_examples), y] -= 1
    dscores /= num_examples

    # backpropagate the gradient to the parameters (W, b)
    dW = np.dot(X.T, dscores)
    db = np.sum(dscores, axis=0, keepdims=True)

    dW += reg * W  # regularization gradient

    # perform a parameter update
    W += -step_size * dW
    b += -step_size * db


iteration 0: loss 1.097798
iteration 10: loss 0.900820
iteration 20: loss 0.827351
iteration 30: loss 0.793641
iteration 40: loss 0.775893
iteration 50: loss 0.765675
iteration 60: loss 0.759424
iteration 70: loss 0.755431
iteration 80: loss 0.752798
iteration 90: loss 0.751018
iteration 100: loss 0.749791
iteration 110: loss 0.748934
iteration 120: loss 0.748327
iteration 130: loss 0.747892
iteration 140: loss 0.747579
iteration 150: loss 0.747352
iteration 160: loss 0.747186
iteration 170: loss 0.747064
iteration 180: loss 0.746974
iteration 190: loss 0.746908

In [95]:
# evaluate training set accuracy
scores = np.dot(X, W) + b
predicted_class = np.argmax(scores, axis=1)
print('training accuracy: %.2f' % (np.mean(predicted_class == y)))


training accuracy: 0.54
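
54% is not great, but it is roughly the best a linear classifier can do here: the spiral classes are not linearly separable, so three straight decision boundaries cannot follow the arms. A sketch of how the learned boundaries could be visualized (the meshgrid recipe is my addition; the variable names are illustrative):

In [ ]:
# evaluate the classifier on a grid covering the data and color by predicted class
h = 0.02  # grid step
x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5
y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
Z = np.argmax(np.dot(np.c_[xx.ravel(), yy.ravel()], W) + b, axis=1)
plt.contourf(xx, yy, Z.reshape(xx.shape), cmap=plt.cm.Spectral, alpha=0.8)
plt.scatter(X[:, 0], X[:, 1], c=y, s=40, cmap=plt.cm.Spectral)
plt.show()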