Deep Learning: Linear Stochastic Gradient Descent


In [235]:
%matplotlib inline
import numpy as np

from matplotlib import pyplot as plt

In [236]:
# Create a random X axis points.
np.random.seed(1)
X = np.random.random(4)
print X


[  4.17022005e-01   7.20324493e-01   1.14374817e-04   3.02332573e-01]

In [237]:
def lin(a, b, x): return a*x + b

In [238]:
# We define the slope "a"
a = 10.
# We define the y-intercept var "b"
b = 3.
# We define the len of the X.
n = len(X)

In [239]:
# Create y axis points which follows a linear form.
y = lin(a, b, X)

In [240]:
plt.scatter(X, y)


Out[240]:
<matplotlib.collections.PathCollection at 0x115f6cfd0>

In [241]:
# We choose the Least Square Errors (lse) as cost function.
# But we return the sum of all those errors for measuring the performance.
def sse(y, y_pred): return ((y-y_pred)**2).sum()

In [242]:
# The loss function is the comparison of a perfect regression line with a line created using
# our guessed "a" and "b" using the cost function.
def loss(y, a, b, X): return sse(y, lin(a, b, X))

In [243]:
# This function represents the average loss per point.
def avg_loss(y, a, b, X): return np.sqrt(loss(y, a, b, X)/n)

In [244]:
# We define the random numbers that will conform the initial approximation line.
a_guess = -1
b_guess = 2

In [245]:
# As we can see we get a really bad approximation at the beginning.
avg_loss(y, a_guess, b_guess, X)


Out[245]:
5.7131690554944567

In [246]:
lr=0.6 # This is the learning rate

In [247]:
# The partial derivatives of the cost function (lse) are:
# d[(y-(a*x+b))**2,b] = 2 (b + a x - y)      = 2 (y_pred - y)
# d[(y-(a*x+b))**2,a] = 2 x (b + a x - y)    = x * dy/db

# This tells us how much is changing df with regards of "b" and "a" respectively.

In [251]:
def update():
    print '-----------------------------------------------'
    global a_guess, b_guess
    # We calculate the current "guessed" line based on the iterative process.
    y_pred = lin(a_guess, b_guess, X)
    # Now, in order to iterate again we need to calculate how much the
    # derivative of the cost function changes depending on "a" and "b".
    # d[(y-(a*x+b))**2,a] = 2 x (b + a x - y)    = x * dy/db
    # d[(y-(a*x+b))**2,b] = 2 (b + a x - y)      = 2 (y_pred - y)
    dydb = 2 * (y_pred - y)
    dyda = X * dydb
    print 'dy/da: {}'.format(dyda)
    print 'dy/db: {}'.format(dydb)
    print 'Current a_guess = {}'.format(a_guess)
    print 'Current b_guess = {}'.format(b_guess)
    # Because these changes always tell us how to go higher instead than
    # lower, we need to substract it from the previous
    # guesses in order to approach the minimum. Please notice how the speed is
    # controlled by the defined learning rate.
    print 'Formula a: a_guess -= lr * dyda.mean()'
    print 'Formula b: b_guess -= lr * dydb.mean()'
    print '* Where:'
    print 'Learning rate = {}'.format(lr)
    print 'dyda.mean() = {}'.format(dyda.mean())
    print 'dydb.mean() = {}'.format(dydb.mean())
    print '* Result:'
    a_guess -= lr * dyda.mean()
    b_guess -= lr * dydb.mean()
    print 'New a_guess = {}'.format(a_guess)
    print 'New b_guess = {}'.format(b_guess)

In [252]:
plt.plot(X, y, 'ro')
line, = plt.plot(X, lin(a_guess, b_guess, X))



In [253]:
# We can see how why are getting better.
for i in range(3):
    update()

avg_loss(y, a_guess, b_guess, X)


-----------------------------------------------
dy/da: [ -1.99349856e-03  -3.22108978e+00   7.01723736e-04   5.09227288e-01]
dy/db: [ -4.78031984e-03  -4.47172046e+00   6.13529929e+00   1.68432823e+00]
Current a_guess = 2.63616305654
Current b_guess = 6.06849188454
Formula a: a_guess -= lr * dyda.mean()
Formula b: b_guess -= lr * dydb.mean()
* Where:
Learning rate = 0.6
dyda.mean() = -0.678288566022
dydb.mean() = 0.835781686277
* Result:
New a_guess = 3.04313619615
New b_guess = 5.56702287278
-----------------------------------------------
dy/da: [ -2.78689481e-01  -3.52120043e+00   5.87023531e-04   2.80405202e-01]
dy/db: [-0.66828483 -4.88835305  5.13245437  0.92747268]
Current a_guess = 3.04313619615
Current b_guess = 5.56702287278
Formula a: a_guess -= lr * dyda.mean()
Formula b: b_guess -= lr * dydb.mean()
* Where:
Learning rate = 0.6
dyda.mean() = -0.879724421659
dydb.mean() = 0.125822292219
* Result:
New a_guess = 3.57097084915
New b_guess = 5.49152949745
-----------------------------------------------
dy/da: [ -1.58065625e-01  -3.08220752e+00   5.69768258e-04   3.31250426e-01]
dy/db: [-0.37903425 -4.27891534  4.98158836  1.09564915]
Current a_guess = 3.57097084915
Current b_guess = 5.49152949745
Formula a: a_guess -= lr * dyda.mean()
Formula b: b_guess -= lr * dydb.mean()
* Where:
Learning rate = 0.6
dyda.mean() = -0.727113238422
dydb.mean() = 0.354821978426
* Result:
New a_guess = 4.0072387922
New b_guess = 5.27863631039
Out[253]:
1.5499545846314433

In [230]:
plt.plot(X, y, 'ro')
line, = plt.plot(X, lin(a_guess, b_guess, X))



In [231]:
# And even better!
for i in range(3):
    update()

avg_loss(y, a_guess, b_guess, X)


-----------------------------------------------
Learning rate: 0.6
dy/da: [ -1.99349856e-03  -3.22108978e+00   7.01723736e-04   5.09227288e-01]
dy/da - mean: -0.678288566022
dy/db: [ -4.78031984e-03  -4.47172046e+00   6.13529929e+00   1.68432823e+00]
dy/db - mean: 0.835781686277
Current a_guess 2.63616305654
Current b_guess 6.06849188454
New a_guess 3.04313619615
New b_guess 5.56702287278
-----------------------------------------------
Learning rate: 0.6
dy/da: [ -2.78689481e-01  -3.52120043e+00   5.87023531e-04   2.80405202e-01]
dy/da - mean: -0.879724421659
dy/db: [-0.66828483 -4.88835305  5.13245437  0.92747268]
dy/db - mean: 0.125822292219
Current a_guess 3.04313619615
Current b_guess 5.56702287278
New a_guess 3.57097084915
New b_guess 5.49152949745
-----------------------------------------------
Learning rate: 0.6
dy/da: [ -1.58065625e-01  -3.08220752e+00   5.69768258e-04   3.31250426e-01]
dy/da - mean: -0.727113238422
dy/db: [-0.37903425 -4.27891534  4.98158836  1.09564915]
dy/db - mean: 0.354821978426
Current a_guess 3.57097084915
Current b_guess 5.49152949745
New a_guess 4.0072387922
New b_guess 5.27863631039
Out[231]:
1.5499545846314433

In [232]:
plt.plot(X, y, 'ro')
line, = plt.plot(X, lin(a_guess, b_guess, X))



In [233]:
# And even better!
for i in range(15):
    update()

avg_loss(y, a_guess, b_guess, X)


-----------------------------------------------
Learning rate: 0.6
dy/da: [ -1.83887506e-01  -2.93618147e+00   5.21080434e-04   2.82275465e-01]
dy/da - mean: -0.709318108107
dy/db: [-0.44095396 -4.07619274  4.55590178  0.93365879]
dy/db - mean: 0.24310346678
Current a_guess 4.0072387922
Current b_guess 5.27863631039
New a_guess 4.43282965706
New b_guess 5.13277423032
-----------------------------------------------
Learning rate: 0.6
dy/da: [ -1.57516139e-01  -2.70466710e+00   4.87725671e-04   2.71880002e-01]
dy/da - mean: -0.647453877669
dy/db: [-0.37771661 -3.75478985  4.26427497  0.8992746 ]
dy/db - mean: 0.257760775512
Current a_guess 4.43282965706
Current b_guess 5.13277423032
New a_guess 4.82130198366
New b_guess 4.97811776502
-----------------------------------------------
Learning rate: 0.6
dy/da: [ -1.51390050e-01  -2.52434155e+00   4.52358225e-04   2.49381242e-01]
dy/da - mean: -0.606474498923
dy/db: [-0.36302653 -3.50445052  3.9550509   0.82485734]
dy/db - mean: 0.228107799715
Current a_guess 4.82130198366
Current b_guess 4.97811776502
New a_guess 5.18518668302
New b_guess 4.84125308519
-----------------------------------------------
Learning rate: 0.6
dy/da: [ -1.38976767e-01  -2.34389971e+00   4.21060000e-04   2.33145691e-01]
dy/da - mean: -0.562327431515
dy/db: [-0.33326003 -3.25394976  3.68140478  0.77115638]
dy/db - mean: 0.216337842598
Current a_guess 5.18518668302
Current b_guess 4.84125308519
New a_guess 5.52258314193
New b_guess 4.71145037963
-----------------------------------------------
Learning rate: 0.6
dy/da: [ -1.29886486e-01  -2.18077182e+00   3.91376506e-04   2.16337956e-01]
dy/da - mean: -0.523482242537
dy/db: [-0.31146195 -3.0274853   3.42187655  0.71556284]
dy/db - mean: 0.199623036532
Current a_guess 5.52258314193
Current b_guess 4.71145037963
New a_guess 5.83667248745
New b_guess 4.59167655771
-----------------------------------------------
Learning rate: 0.6
dy/da: [ -1.20538232e-01  -2.02738242e+00   3.63986505e-04   2.01333564e-01]
dy/da - mean: -0.486555775987
dy/db: [-0.28904526 -2.81454045  3.18240076  0.66593408]
dy/db - mean: 0.186187283201
Current a_guess 5.83667248745
Current b_guess 4.59167655771
New a_guess 6.12860595304
New b_guess 4.47996418779
-----------------------------------------------
Learning rate: 0.6
dy/da: [ -1.12172513e-01  -1.88537123e+00   3.38439979e-04   1.87153335e-01]
dy/da - mean: -0.452512992487
dy/db: [-0.26898464 -2.61739154  2.9590428   0.61903133]
dy/db - mean: 0.172924488515
Current a_guess 6.12860595304
Current b_guess 4.47996418779
New a_guess 6.40011374853
New b_guess 4.37620949468
-----------------------------------------------
Learning rate: 0.6
dy/da: [ -1.04274089e-01  -1.75309225e+00   3.14713235e-04   1.74050820e-01]
dy/da - mean: -0.420750201718
dy/db: [-0.25004457 -2.43375349  2.75159552  0.57569325]
dy/db - mean: 0.160872674486
Current a_guess 6.40011374853
Current b_guess 4.37620949468
New a_guess 6.65256386956
New b_guess 4.27968588999
-----------------------------------------------
Learning rate: 0.6
dy/da: [ -9.69731591e-02  -1.63017262e+00   2.92640100e-04   1.61836759e-01]
dy/da - mean: -0.391254095062
dy/db: [-0.23253727 -2.26310869  2.55860606  0.53529383]
dy/db - mean: 0.149563479903
Current a_guess 6.65256386956
Current b_guess 4.27968588999
New a_guess 6.8873163266
New b_guess 4.18994780205
-----------------------------------------------
Learning rate: 0.6
dy/da: [ -9.01683172e-02  -1.51584292e+00   2.72118687e-04   1.50490355e-01]
dy/da - mean: -0.363812191784
dy/db: [-0.21621957 -2.10438898  2.37918358  0.49776428]
dy/db - mean: 0.139084828509
Current a_guess 6.8873163266
Current b_guess 4.18994780205
New a_guess 7.10560364167
New b_guess 4.10649690494
-----------------------------------------------
Learning rate: 0.6
dy/da: [ -8.38465000e-02  -1.40954204e+00   2.53035036e-04   1.39935603e-01]
dy/da - mean: -0.338299975684
dy/db: [-0.20106013 -1.95681537  2.21233172  0.46285322]
dy/db - mean: 0.129327357046
Current a_guess 7.10560364167
Current b_guess 4.10649690494
New a_guess 7.30858362708
New b_guess 4.02890049071
-----------------------------------------------
Learning rate: 0.6
dy/da: [ -7.79659007e-02  -1.31069185e+00   2.35290195e-04   1.30122521e-01]
dy/da - mean: -0.314574985335
dy/db: [-0.18695872 -1.81958529  2.05718532  0.43039531]
dy/db - mean: 0.120259154881
Current a_guess 7.30858362708
Current b_guess 4.02890049071
New a_guess 7.49732861828
New b_guess 3.95674499778
-----------------------------------------------
Learning rate: 0.6
dy/da: [ -7.24984739e-02  -1.21877535e+00   2.18789591e-04   1.20997075e-01]
dy/da - mean: -0.292514490502
dy/db: [-0.17384808 -1.691981    1.91291751  0.40021184]
dy/db - mean: 0.111825069634
Current a_guess 7.49732861828
Current b_guess 3.95674499778
New a_guess 7.67283731258
New b_guess 3.889649956
-----------------------------------------------
Learning rate: 0.6
dy/da: [ -6.74141869e-02  -1.13330429e+00   2.03446216e-04   1.12511781e-01]
dy/da - mean: -0.272000811308
dy/db: [-0.16165619 -1.57332466  1.77876757  0.37214575]
dy/db - mean: 0.103983119923
Current a_guess 7.67283731258
Current b_guess 3.889649956
New a_guess 7.83603779937
New b_guess 3.82726008405
-----------------------------------------------
Learning rate: 0.6
dy/da: [ -6.26865567e-02  -1.05382737e+00   1.89178826e-04   1.04621476e-01]
dy/da - mean: -0.252925819173
dy/db: [-0.15031954 -1.46298978  1.65402516  0.34604765]
dy/db - mean: 0.0966908716105
Current a_guess 7.83603779937
Current b_guess 3.82726008405
New a_guess 7.98779329087
New b_guess 3.76924556108
Out[233]:
0.52077385521080499

In [234]:
plt.plot(X, y, 'ro')
line, = plt.plot(X, lin(a_guess, b_guess, X))