In [20]:
import numpy as np
import matplotlib.pyplot as plt
def f(x):
    """Evaluate the quadratic objective x[0]^2 + 50 * x[1]^2.

    Only the first two entries of ``x`` are read.
    """
    first, second = x[0], x[1]
    return first * first + 50 * second * second
def g(x):
    """Return the gradient of x[0]^2 + 50 * x[1]^2 as a 2-element numpy array."""
    d_first = 2 * x[0]
    d_second = 100 * x[1]
    return np.array([d_first, d_second])
# Sample the plotting window: 1000 points on [-200, 200] for the first
# coordinate and 1000 points on [-100, 100] for the second one.
xi = np.linspace(-200, 200, num=1000)
yi = np.linspace(-100, 100, num=1000)
X, Y = np.meshgrid(xi, yi)
# Objective value x^2 + 50 * y^2 at every grid node.
Z = np.square(X) + (50 * Y) * Y

%matplotlib inline
def contour(X, Y, Z, arr=None):
    """Draw contour lines of a surface and, optionally, an optimization path.

    Args:
        X, Y, Z: grid coordinates and surface values, forwarded unchanged to
            ``plt.contour``.
        arr: optional sequence of 2-D points (anything ``np.array`` turns into
            an ``(n, 2)`` array). Consecutive points are joined by a segment.

    The figure is created on the current pyplot state; nothing is returned.
    """
    plt.figure(figsize=(15, 7))
    plt.contour(X, Y, Z, colors='black')
    # Mark the origin (0, 0) with a star.
    plt.plot(0, 0, marker='*')
    if arr is not None:
        arr = np.array(arr)
        # One plot call per segment: matplotlib gives each call the next color
        # of its cycle, so individual steps stay distinguishable.
        for i in range(len(arr) - 1):
            plt.plot(arr[i:i + 2, 0], arr[i:i + 2, 1])
        
# Contour map of the objective on its own, with no optimization path overlaid.
contour(X, Y, Z, arr=None)



In [37]:
def gd(x_start, step, g, max_iter=50, tol=1e-6):
    """Minimize a function with plain fixed-step gradient descent.

    Args:
        x_start: initial point (array-like); it is copied, not modified.
        step: learning rate multiplying the gradient in every update.
        g: callable returning the gradient at a point.
        max_iter: maximum number of updates (default 50).
        tol: stop once the Euclidean norm of the gradient is below this value.

    Returns:
        ``(x, passing_dot)``: the final point and the list of all visited
        points, starting with a copy of ``x_start``.
    """
    x = np.array(x_start, dtype='float64')
    passing_dot = [x.copy()]
    for i in range(max_iter):
        grad = np.asarray(g(x), dtype='float64')
        x -= grad * step

        passing_dot.append(x.copy())
        # Single-argument print(...) is valid in both Python 2 and Python 3.
        print('[ Epoch {0} ] grad = {1}, x = {2}'.format(i, grad, x))
        # Use the gradient norm: summing the components lets entries of
        # opposite sign cancel and would stop far away from a stationary point.
        if np.linalg.norm(grad) < tol:
            break
    return x, passing_dot
# Plain gradient descent from (150, 75) with learning rate 0.016, then the
# visited points drawn over the contour map.
res, x_arr = gd(x_start=[150, 75], step=0.016, g=g)
contour(X, Y, Z, arr=x_arr)


[ Epoch 0 ] grad = [  300.  7500.], x = [ 145.2  -45. ]
[ Epoch 1 ] grad = [  290.4 -4500. ], x = [ 140.5536   27.    ]
[ Epoch 2 ] grad = [  281.1072  2700.    ], x = [ 136.0558848  -16.2      ]
[ Epoch 3 ] grad = [  272.1117696 -1620.       ], x = [ 131.70209649    9.72      ]
[ Epoch 4 ] grad = [ 263.40419297  972.        ], x = [ 127.4876294   -5.832    ]
[ Epoch 5 ] grad = [ 254.9752588 -583.2      ], x = [ 123.40802526    3.4992    ]
[ Epoch 6 ] grad = [ 246.81605052  349.92      ], x = [ 119.45896845   -2.09952   ]
[ Epoch 7 ] grad = [ 238.9179369 -209.952    ], x = [ 115.63628146    1.259712  ]
[ Epoch 8 ] grad = [ 231.27256292  125.9712    ], x = [ 111.93592045   -0.7558272 ]
[ Epoch 9 ] grad = [ 223.87184091  -75.58272   ], x = [ 108.353971      0.45349632]
[ Epoch 10 ] grad = [ 216.707942   45.349632], x = [ 104.88664393   -0.27209779]
[ Epoch 11 ] grad = [ 209.77328785  -27.2097792 ], x = [ 101.53027132    0.16325868]
[ Epoch 12 ] grad = [ 203.06054264   16.32586752], x = [  9.82813026e+01  -9.79552051e-02]
[ Epoch 13 ] grad = [ 196.56260528   -9.79552051], x = [  9.51363010e+01   5.87731231e-02]
[ Epoch 14 ] grad = [ 190.27260191    5.87731231], x = [  9.20919393e+01  -3.52638738e-02]
[ Epoch 15 ] grad = [ 184.18387865   -3.52638738], x = [  8.91449973e+01   2.11583243e-02]
[ Epoch 16 ] grad = [ 178.28999453    2.11583243], x = [  8.62923574e+01  -1.26949946e-02]
[ Epoch 17 ] grad = [ 172.58471471   -1.26949946], x = [  8.35310019e+01   7.61699675e-03]
[ Epoch 18 ] grad = [ 167.06200383    0.76169968], x = [  8.08580099e+01  -4.57019805e-03]
[ Epoch 19 ] grad = [ 161.71601971   -0.45701981], x = [  7.82705535e+01   2.74211883e-03]
[ Epoch 20 ] grad = [ 156.54110708    0.27421188], x = [  7.57658958e+01  -1.64527130e-03]
[ Epoch 21 ] grad = [ 151.53179165   -0.16452713], x = [  7.33413872e+01   9.87162779e-04]
[ Epoch 22 ] grad = [  1.46682774e+02   9.87162779e-02], x = [  7.09944628e+01  -5.92297667e-04]
[ Epoch 23 ] grad = [  1.41988926e+02  -5.92297667e-02], x = [  6.87226400e+01   3.55378600e-04]
[ Epoch 24 ] grad = [  1.37445280e+02   3.55378600e-02], x = [  6.65235155e+01  -2.13227160e-04]
[ Epoch 25 ] grad = [  1.33047031e+02  -2.13227160e-02], x = [  6.43947630e+01   1.27936296e-04]
[ Epoch 26 ] grad = [  1.28789526e+02   1.27936296e-02], x = [  6.23341306e+01  -7.67617777e-05]
[ Epoch 27 ] grad = [  1.24668261e+02  -7.67617777e-03], x = [  6.03394384e+01   4.60570666e-05]
[ Epoch 28 ] grad = [  1.20678877e+02   4.60570666e-03], x = [  5.84085764e+01  -2.76342400e-05]
[ Epoch 29 ] grad = [  1.16817153e+02  -2.76342400e-03], x = [  5.65395019e+01   1.65805440e-05]
[ Epoch 30 ] grad = [  1.13079004e+02   1.65805440e-03], x = [  5.47302379e+01  -9.94832639e-06]
[ Epoch 31 ] grad = [  1.09460476e+02  -9.94832639e-04], x = [  5.29788702e+01   5.96899583e-06]
[ Epoch 32 ] grad = [  1.05957740e+02   5.96899583e-04], x = [  5.12835464e+01  -3.58139750e-06]
[ Epoch 33 ] grad = [  1.02567093e+02  -3.58139750e-04], x = [  4.96424729e+01   2.14883850e-06]
[ Epoch 34 ] grad = [  9.92849458e+01   2.14883850e-04], x = [  4.80539138e+01  -1.28930310e-06]
[ Epoch 35 ] grad = [  9.61078276e+01  -1.28930310e-04], x = [  4.65161885e+01   7.73581860e-07]
[ Epoch 36 ] grad = [  9.30323771e+01   7.73581860e-05], x = [  4.50276705e+01  -4.64149116e-07]
[ Epoch 37 ] grad = [  9.00553410e+01  -4.64149116e-05], x = [  4.35867851e+01   2.78489470e-07]
[ Epoch 38 ] grad = [  8.71735701e+01   2.78489470e-05], x = [  4.21920079e+01  -1.67093682e-07]
[ Epoch 39 ] grad = [  8.43840159e+01  -1.67093682e-05], x = [  4.08418637e+01   1.00256209e-07]
[ Epoch 40 ] grad = [  8.16837274e+01   1.00256209e-05], x = [  3.95349240e+01  -6.01537254e-08]
[ Epoch 41 ] grad = [  7.90698481e+01  -6.01537254e-06], x = [  3.82698065e+01   3.60922353e-08]
[ Epoch 42 ] grad = [  7.65396129e+01   3.60922353e-06], x = [  3.70451727e+01  -2.16553412e-08]
[ Epoch 43 ] grad = [  7.40903453e+01  -2.16553412e-06], x = [  3.58597271e+01   1.29932047e-08]
[ Epoch 44 ] grad = [  7.17194543e+01   1.29932047e-06], x = [  3.47122159e+01  -7.79592282e-09]
[ Epoch 45 ] grad = [  6.94244317e+01  -7.79592282e-07], x = [  3.36014250e+01   4.67755369e-09]
[ Epoch 46 ] grad = [  6.72028499e+01   4.67755369e-07], x = [  3.25261794e+01  -2.80653221e-09]
[ Epoch 47 ] grad = [  6.50523587e+01  -2.80653221e-07], x = [  3.14853416e+01   1.68391933e-09]
[ Epoch 48 ] grad = [  6.29706832e+01   1.68391933e-07], x = [  3.04778107e+01  -1.01035160e-09]
[ Epoch 49 ] grad = [  6.09556214e+01  -1.01035160e-07], x = [  2.95025207e+01   6.06210958e-10]

In [41]:
# Same start point with a larger learning rate (0.019): each update multiplies
# the second coordinate by 1 - 0.019 * 100 = -0.9, so it flips sign every step
# and shrinks only slowly.
res, x_arr = gd(x_start=[150, 75], step=0.019, g=g)
contour(X, Y, Z, arr=x_arr)


[ Epoch 0 ] grad = [  300.  7500.], x = [ 144.3  -67.5]
[ Epoch 1 ] grad = [  288.6 -6750. ], x = [ 138.8166   60.75  ]
[ Epoch 2 ] grad = [  277.6332  6075.    ], x = [ 133.5415692  -54.675    ]
[ Epoch 3 ] grad = [  267.0831384 -5467.5      ], x = [ 128.46698957   49.2075    ]
[ Epoch 4 ] grad = [  256.93397914  4920.75      ], x = [ 123.58524397  -44.28675   ]
[ Epoch 5 ] grad = [  247.17048793 -4428.675     ], x = [ 118.8890047   39.858075 ]
[ Epoch 6 ] grad = [  237.77800939  3985.8075    ], x = [ 114.37122252  -35.8722675 ]
[ Epoch 7 ] grad = [  228.74244504 -3587.22675   ], x = [ 110.02511606   32.28504075]
[ Epoch 8 ] grad = [  220.05023212  3228.504075  ], x = [ 105.84416165  -29.05653667]
[ Epoch 9 ] grad = [  211.6883233 -2905.6536675], x = [ 101.82208351   26.15088301]
[ Epoch 10 ] grad = [  203.64416702  2615.08830075], x = [ 97.95284434 -23.53579471]
[ Epoch 11 ] grad = [  195.90568867 -2353.57947067], x = [ 94.23063625  21.18221524]
[ Epoch 12 ] grad = [  188.4612725   2118.22152361], x = [ 90.64987207 -19.06399371]
[ Epoch 13 ] grad = [  181.29974415 -1906.39937125], x = [ 87.20517693  17.15759434]
[ Epoch 14 ] grad = [  174.41035387  1715.75943412], x = [ 83.89138021 -15.44183491]
[ Epoch 15 ] grad = [  167.78276042 -1544.18349071], x = [ 80.70350776  13.89765142]
[ Epoch 16 ] grad = [  161.40701553  1389.76514164], x = [ 77.63677447 -12.50788627]
[ Epoch 17 ] grad = [  155.27354894 -1250.78862747], x = [ 74.68657704  11.25709765]
[ Epoch 18 ] grad = [  149.37315408  1125.70976473], x = [ 71.84848711 -10.13138788]
[ Epoch 19 ] grad = [  143.69697422 -1013.13878825], x = [ 69.1182446    9.11824909]
[ Epoch 20 ] grad = [ 138.2364892   911.82490943], x = [ 66.49175131  -8.20642418]
[ Epoch 21 ] grad = [ 132.98350261 -820.64241849], x = [ 63.96506476   7.38578177]
[ Epoch 22 ] grad = [ 127.93012951  738.57817664], x = [ 61.5343923   -6.64720359]
[ Epoch 23 ] grad = [ 123.06878459 -664.72035897], x = [ 59.19608539   5.98248323]
[ Epoch 24 ] grad = [ 118.39217078  598.24832308], x = [ 56.94663414  -5.38423491]
[ Epoch 25 ] grad = [ 113.89326829 -538.42349077], x = [ 54.78266205   4.84581142]
[ Epoch 26 ] grad = [ 109.56532409  484.58114169], x = [ 52.70092089  -4.36123028]
[ Epoch 27 ] grad = [ 105.40184178 -436.12302752], x = [ 50.69828589   3.92510725]
[ Epoch 28 ] grad = [ 101.39657179  392.51072477], x = [ 48.77175103  -3.53259652]
[ Epoch 29 ] grad = [  97.54350206 -353.25965229], x = [ 46.91842449   3.17933687]
[ Epoch 30 ] grad = [  93.83684898  317.93368706], x = [ 45.13552436  -2.86140318]
[ Epoch 31 ] grad = [  90.27104872 -286.14031836], x = [ 43.42037443   2.57526287]
[ Epoch 32 ] grad = [  86.84074887  257.52628652], x = [ 41.77040021  -2.31773658]
[ Epoch 33 ] grad = [  83.54080041 -231.77365787], x = [ 40.183125     2.08596292]
[ Epoch 34 ] grad = [  80.36625     208.59629208], x = [ 38.65616625  -1.87736663]
[ Epoch 35 ] grad = [  77.3123325  -187.73666287], x = [ 37.18723193   1.68962997]
[ Epoch 36 ] grad = [  74.37446386  168.96299659], x = [ 35.77411712  -1.52066697]
[ Epoch 37 ] grad = [  71.54823424 -152.06669693], x = [ 34.41470067   1.36860027]
[ Epoch 38 ] grad = [  68.82940133  136.86002724], x = [ 33.10694204  -1.23174025]
[ Epoch 39 ] grad = [  66.21388408 -123.17402451], x = [ 31.84887824   1.10856622]
[ Epoch 40 ] grad = [  63.69775649  110.85662206], x = [ 30.63862087  -0.9977096 ]
[ Epoch 41 ] grad = [ 61.27724174 -99.77095985], x = [ 29.47435328   0.89793864]
[ Epoch 42 ] grad = [ 58.94870656  89.79386387], x = [ 28.35432785  -0.80814477]
[ Epoch 43 ] grad = [ 56.70865571 -80.81447748], x = [ 27.27686339   0.7273303 ]
[ Epoch 44 ] grad = [ 54.55372679  72.73302973], x = [ 26.24034259  -0.65459727]
[ Epoch 45 ] grad = [ 52.48068517 -65.45972676], x = [ 25.24320957   0.58913754]
[ Epoch 46 ] grad = [ 50.48641914  58.91375408], x = [ 24.2839676   -0.53022379]
[ Epoch 47 ] grad = [ 48.56793521 -53.02237868], x = [ 23.36117684   0.47720141]
[ Epoch 48 ] grad = [ 46.72235367  47.72014081], x = [ 22.47345212  -0.42948127]
[ Epoch 49 ] grad = [ 44.94690423 -42.94812673], x = [ 21.61946094   0.38653314]

In [44]:
# Learning rate 0.02: each update multiplies the second coordinate by
# 1 - 0.02 * 100 = -1, so it bounces between +75 and -75 without shrinking.
res, x_arr = gd(x_start=[150, 75], step=0.02, g=g)
contour(X, Y, Z, arr=x_arr)


[ Epoch 0 ] grad = [  300.  7500.], x = [ 144.  -75.]
[ Epoch 1 ] grad = [  288. -7500.], x = [ 138.24   75.  ]
[ Epoch 2 ] grad = [  276.48  7500.  ], x = [ 132.7104  -75.    ]
[ Epoch 3 ] grad = [  265.4208 -7500.    ], x = [ 127.401984   75.      ]
[ Epoch 4 ] grad = [  254.803968  7500.      ], x = [ 122.30590464  -75.        ]
[ Epoch 5 ] grad = [  244.61180928 -7500.        ], x = [ 117.41366845   75.        ]
[ Epoch 6 ] grad = [  234.82733691  7500.        ], x = [ 112.71712172  -75.        ]
[ Epoch 7 ] grad = [  225.43424343 -7500.        ], x = [ 108.20843685   75.        ]
[ Epoch 8 ] grad = [  216.4168737  7500.       ], x = [ 103.88009937  -75.        ]
[ Epoch 9 ] grad = [  207.76019875 -7500.        ], x = [ 99.7248954  75.       ]
[ Epoch 10 ] grad = [  199.4497908  7500.       ], x = [ 95.73589958 -75.        ]
[ Epoch 11 ] grad = [  191.47179917 -7500.        ], x = [ 91.9064636  75.       ]
[ Epoch 12 ] grad = [  183.8129272  7500.       ], x = [ 88.23020506 -75.        ]
[ Epoch 13 ] grad = [  176.46041011 -7500.        ], x = [ 84.70099685  75.        ]
[ Epoch 14 ] grad = [  169.40199371  7500.        ], x = [ 81.31295698 -75.        ]
[ Epoch 15 ] grad = [  162.62591396 -7500.        ], x = [ 78.0604387  75.       ]
[ Epoch 16 ] grad = [  156.1208774  7500.       ], x = [ 74.93802115 -75.        ]
[ Epoch 17 ] grad = [  149.8760423 -7500.       ], x = [ 71.94050031  75.        ]
[ Epoch 18 ] grad = [  143.88100061  7500.        ], x = [ 69.06288029 -75.        ]
[ Epoch 19 ] grad = [  138.12576059 -7500.        ], x = [ 66.30036508  75.        ]
[ Epoch 20 ] grad = [  132.60073016  7500.        ], x = [ 63.64835048 -75.        ]
[ Epoch 21 ] grad = [  127.29670096 -7500.        ], x = [ 61.10241646  75.        ]
[ Epoch 22 ] grad = [  122.20483292  7500.        ], x = [ 58.6583198 -75.       ]
[ Epoch 23 ] grad = [  117.3166396 -7500.       ], x = [ 56.31198701  75.        ]
[ Epoch 24 ] grad = [  112.62397402  7500.        ], x = [ 54.05950753 -75.        ]
[ Epoch 25 ] grad = [  108.11901506 -7500.        ], x = [ 51.89712723  75.        ]
[ Epoch 26 ] grad = [  103.79425446  7500.        ], x = [ 49.82124214 -75.        ]
[ Epoch 27 ] grad = [   99.64248428 -7500.        ], x = [ 47.82839245  75.        ]
[ Epoch 28 ] grad = [   95.65678491  7500.        ], x = [ 45.91525675 -75.        ]
[ Epoch 29 ] grad = [   91.83051351 -7500.        ], x = [ 44.07864648  75.        ]
[ Epoch 30 ] grad = [   88.15729297  7500.        ], x = [ 42.31550063 -75.        ]
[ Epoch 31 ] grad = [   84.63100125 -7500.        ], x = [ 40.6228806  75.       ]
[ Epoch 32 ] grad = [   81.2457612  7500.       ], x = [ 38.99796538 -75.        ]
[ Epoch 33 ] grad = [   77.99593075 -7500.        ], x = [ 37.43804676  75.        ]
[ Epoch 34 ] grad = [   74.87609352  7500.        ], x = [ 35.94052489 -75.        ]
[ Epoch 35 ] grad = [   71.88104978 -7500.        ], x = [ 34.5029039  75.       ]
[ Epoch 36 ] grad = [   69.00580779  7500.        ], x = [ 33.12278774 -75.        ]
[ Epoch 37 ] grad = [   66.24557548 -7500.        ], x = [ 31.79787623  75.        ]
[ Epoch 38 ] grad = [   63.59575246  7500.        ], x = [ 30.52596118 -75.        ]
[ Epoch 39 ] grad = [   61.05192236 -7500.        ], x = [ 29.30492273  75.        ]
[ Epoch 40 ] grad = [   58.60984547  7500.        ], x = [ 28.13272582 -75.        ]
[ Epoch 41 ] grad = [   56.26545165 -7500.        ], x = [ 27.00741679  75.        ]
[ Epoch 42 ] grad = [   54.01483358  7500.        ], x = [ 25.92712012 -75.        ]
[ Epoch 43 ] grad = [   51.85424024 -7500.        ], x = [ 24.89003531  75.        ]
[ Epoch 44 ] grad = [   49.78007063  7500.        ], x = [ 23.8944339 -75.       ]
[ Epoch 45 ] grad = [   47.7888678 -7500.       ], x = [ 22.93865655  75.        ]
[ Epoch 46 ] grad = [   45.87731309  7500.        ], x = [ 22.02111028 -75.        ]
[ Epoch 47 ] grad = [   44.04222057 -7500.        ], x = [ 21.14026587  75.        ]
[ Epoch 48 ] grad = [   42.28053175  7500.        ], x = [ 20.29465524 -75.        ]
[ Epoch 49 ] grad = [   40.58931048 -7500.        ], x = [ 19.48286903  75.        ]

In [36]:
def momentum(x_start, step, g, discount=0.7, max_iter=50, tol=1e-6):
    """Minimize a function with gradient descent plus classical momentum.

    Each update accumulates ``velocity = velocity * discount + grad`` and then
    moves ``x`` by ``-step * velocity``.

    Args:
        x_start: initial point (array-like); it is copied, not modified.
        step: learning rate multiplying the accumulated velocity.
        g: callable returning the gradient at a point.
        discount: decay factor applied to the previous velocity.
        max_iter: maximum number of updates (default 50).
        tol: stop once the Euclidean norm of the gradient is below this value.

    Returns:
        ``(x, passing_dot)``: the final point and the list of all visited
        points, starting with a copy of ``x_start``.
    """
    x = np.array(x_start, dtype='float64')
    passing_dot = [x.copy()]
    velocity = np.zeros_like(x)
    for i in range(max_iter):
        grad = np.asarray(g(x), dtype='float64')
        velocity = velocity * discount + grad
        x -= velocity * step

        passing_dot.append(x.copy())
        # Single-argument print(...) is valid in both Python 2 and Python 3.
        print('[ Epoch {0} ] grad = {1}, x = {2}'.format(i, grad, x))
        # Use the gradient norm: summing the components lets entries of
        # opposite sign cancel and would stop far away from a stationary point.
        if np.linalg.norm(grad) < tol:
            break
    return x, passing_dot
# Momentum descent from (150, 75) with learning rate 0.016 and the default
# discount, then the visited points drawn over the contour map.
res, x_arr = momentum(x_start=[150, 75], step=0.016, g=g)
contour(X, Y, Z, arr=x_arr)


[ Epoch 0 ] grad = [  300.  7500.], x = [ 145.2  -45. ]
[ Epoch 1 ] grad = [  290.4 -4500. ], x = [ 137.1936  -57.    ]
[ Epoch 2 ] grad = [  274.3872 -5700.    ], x = [ 127.1989248   25.8      ]
[ Epoch 3 ] grad = [  254.3978496  2580.       ], x = [ 116.13228657   42.48      ]
[ Epoch 4 ] grad = [  232.26457313  4248.        ], x = [ 104.66940663  -13.812     ]
[ Epoch 5 ] grad = [  209.33881327 -1381.2       ], x = [ 93.29596967 -31.1172    ]
[ Epoch 6 ] grad = [  186.59193933 -3111.72      ], x = [ 82.34909276   6.55668   ]
[ Epoch 7 ] grad = [ 164.69818552  655.668     ], x = [ 72.05110796  22.437708  ]
[ Epoch 8 ] grad = [  144.10221592  2243.7708    ], x = [ 62.53688314  -2.3459052 ]
[ Epoch 9 ] grad = [ 125.07376629 -234.59052   ], x = [ 53.87574551 -15.94098612]
[ Epoch 10 ] grad = [  107.75149102 -1594.098612  ], x = [ 46.08892531   0.04803503]
[ Epoch 11 ] grad = [ 92.17785063   4.8035028 ], x = [ 39.16330556  11.16349379]
[ Epoch 12 ] grad = [   78.32661113  1116.34937868], x = [ 33.06214596   1.08272486]
[ Epoch 13 ] grad = [  66.12429192  108.27248591], x = [ 27.73334557  -7.70617316]
[ Epoch 14 ] grad = [  55.46669114 -770.61731649], x = [ 23.11571824  -1.52852472]
[ Epoch 15 ] grad = [  46.23143647 -152.85247178], x = [ 19.14367612   5.24146874]
[ Epoch 16 ] grad = [  38.28735224  524.14687436], x = [ 15.750649     1.59411418]
[ Epoch 17 ] grad = [  31.501298    159.41141769], x = [ 12.87150925  -3.5096167 ]
[ Epoch 18 ] grad = [  25.7430185  -350.96167028], x = [ 10.44422313  -1.46684159]
[ Epoch 19 ] grad = [  20.88844626 -146.68415941], x = [ 8.41090771  2.31004753]
[ Epoch 20 ] grad = [  16.82181541  231.00475326], x = [ 6.71843786  1.25779387]
[ Epoch 21 ] grad = [  13.43687572  125.77938691], x = [ 5.31871896 -1.49125389]
[ Epoch 22 ] grad = [  10.63743792 -149.12538859], x = [ 4.16871672 -1.0295811 ]
[ Epoch 23 ] grad = [   8.33743344 -102.9581097 ], x = [ 3.23031622  0.94091961]
[ Epoch 24 ] grad = [  6.46063244  94.09196104], x = [ 2.47006575  0.81479873]
[ Epoch 25 ] grad = [  4.9401315   81.47987289], x = [ 1.85884831 -0.57716385]
[ Epoch 26 ] grad = [  3.71769663 -57.71638544], x = [ 1.37151297 -0.6280755 ]
[ Epoch 27 ] grad = [  2.74302593 -62.80754957], x = [ 0.98648981  0.34120715]
[ Epoch 28 ] grad = [  1.97297961  34.12071485], x = [ 0.68540592  0.47377356]
[ Epoch 29 ] grad = [  1.37081184  47.37735618], x = [ 0.45271421 -0.19146765]
[ Epoch 30 ] grad = [  0.90542842 -19.14676478], x = [ 0.27534316 -0.35078826]
[ Epoch 31 ] grad = [  0.55068632 -35.07882581], x = [ 0.14237244  0.09894853]
[ Epoch 32 ] grad = [ 0.28474489  9.89485276], x = [ 0.04473702  0.25544663]
[ Epoch 33 ] grad = [  0.08947405  25.54466334], x = [-0.02503936 -0.04371931]
[ Epoch 34 ] grad = [-0.05007871 -4.3719306 ], x = [-0.07308156 -0.18318457]
[ Epoch 35 ] grad = [ -0.14616312 -18.3184574 ], x = [-0.1043725   0.01228506]
[ Epoch 36 ] grad = [-0.20874499  1.22850568], x = [-0.12293623  0.12945771]
[ Epoch 37 ] grad = [ -0.24587246  12.94577075], x = [-0.13199688  0.00434623]
[ Epoch 38 ] grad = [-0.26399377  0.4346231 ], x = [-0.13411544 -0.09018577]
[ Epoch 39 ] grad = [-0.26823088 -9.01857721], x = [-0.13130674 -0.01206094]
[ Epoch 40 ] grad = [-0.26261348 -1.20609389], x = [-0.12513883  0.06192395]
[ Epoch 41 ] grad = [-0.25027766  6.19239466], x = [-0.11681685  0.01463505]
[ Epoch 42 ] grad = [-0.2336337   1.46350519], x = [-0.10725333 -0.04188326]
[ Epoch 43 ] grad = [-0.21450666 -4.18832574], x = [-0.09712675 -0.01443286]
[ Epoch 44 ] grad = [-0.19425351 -1.44328621], x = [-0.0869301   0.02787499]
[ Epoch 45 ] grad = [-0.17386019  2.7874994 ], x = [-0.07701067  0.0128905 ]
[ Epoch 46 ] grad = [-0.15402135  1.28905028], x = [-0.06760274 -0.01822345]
[ Epoch 47 ] grad = [-0.13520547 -1.82234455], x = [-0.05885389 -0.0108457 ]
[ Epoch 48 ] grad = [-0.11770778 -1.08456965], x = [-0.05084638  0.01167184]
[ Epoch 49 ] grad = [-0.10169275  1.16718422], x = [-0.04361403  0.00875917]

In [39]:
def nesterov(x_start, step, g, discount=0.7, max_iter=50, tol=1e-6):
    """Minimize a function with Nesterov-style look-ahead momentum.

    The gradient is evaluated at the look-ahead point
    ``x - step * discount * velocity``; the velocity is then updated as
    ``velocity * discount + grad`` and ``x`` moves by ``-step * velocity``.

    Args:
        x_start: initial point (array-like); it is copied, not modified.
        step: learning rate multiplying the accumulated velocity.
        g: callable returning the gradient at a point.
        discount: decay factor used for both the look-ahead point and the
            velocity update.
        max_iter: maximum number of updates (default 50).
        tol: stop once the Euclidean norm of the look-ahead gradient is below
            this value.

    Returns:
        ``(x, passing_dot)``: the final point and the list of all visited
        points, starting with a copy of ``x_start``.
    """
    x = np.array(x_start, dtype='float64')
    passing_dot = [x.copy()]
    velocity = np.zeros_like(x)
    for i in range(max_iter):
        x_future = x - step * discount * velocity
        grad = np.asarray(g(x_future), dtype='float64')
        # Decay with `discount`, the same factor as the look-ahead, so that a
        # caller-supplied value affects both parts of the update.
        velocity = velocity * discount + grad
        x -= velocity * step

        passing_dot.append(x.copy())
        # Single-argument print(...) is valid in both Python 2 and Python 3.
        print('[ Epoch {0} ] grad = {1}, x = {2}'.format(i, grad, x))
        # Use the gradient norm: summing the components lets entries of
        # opposite sign cancel and would stop far away from a stationary point.
        if np.linalg.norm(grad) < tol:
            break
    return x, passing_dot
# Nesterov descent from (150, 75) with learning rate 0.012 and the default
# discount, then the visited points drawn over the contour map.
res, x_arr = nesterov(x_start=[150, 75], step=0.012, g=g)
contour(X, Y, Z, arr=x_arr)


[ Epoch 0 ] grad = [  300.  7500.], x = [ 146.4  -15. ]
[ Epoch 1 ] grad = [  287.76 -7800.  ], x = [ 140.42688   15.6    ]
[ Epoch 2 ] grad = [  272.491392  3702.      ], x = [ 132.9757993   -7.404    ]
[ Epoch 3 ] grad = [  255.52008561 -2350.68      ], x = [ 124.69380178    4.70136   ]
[ Epoch 4 ] grad = [  237.79280702  1317.5112    ], x = [ 116.04288983   -2.6350224 ]
[ Epoch 5 ] grad = [ 219.97450293 -777.049008  ], x = [ 107.34755743    1.55409802]
[ Epoch 6 ] grad = [ 202.5216495   448.64823072], x = [ 98.83056496  -0.89729646]
[ Epoch 7 ] grad = [ 185.73734045 -261.32725956], x = [ 90.63982214   0.52265452]
[ Epoch 8 ] grad = [ 169.81260433  151.66202055], x = [ 82.86855092  -0.30332404]
[ Epoch 9 ] grad = [ 154.85732212  -88.15090333], x = [ 75.57037319   0.17630181]
[ Epoch 10 ] grad = [ 140.92329758   51.20399001], x = [ 68.77056922  -0.10240798]
[ Epoch 11 ] grad = [ 128.02141287  -29.75048307], x = [  6.24744495e+01   5.95009661e-02]
[ Epoch 12 ] grad = [ 116.13433133   17.28372284], x = [  5.66735537e+01  -3.45674457e-02]
[ Epoch 13 ] grad = [ 105.22585326  -10.0415334 ], x = [  5.13502164e+01   2.00830668e-02]
[ Epoch 14 ] grad = [ 95.24776057   5.83384255], x = [  4.64809072e+01  -1.16676851e-02]
[ Epoch 15 ] grad = [ 86.14478139  -3.38932114], x = [  4.20386533e+01   6.77864229e-03]
[ Epoch 16 ] grad = [ 77.85815127   1.96910715], x = [  3.79947778e+01  -3.93821429e-03]
[ Epoch 17 ] grad = [ 70.32812993  -1.14400139], x = [  3.43201274e+01   2.28800278e-03]
[ Epoch 18 ] grad = [ 63.49574424   0.66463547], x = [  3.09859232e+01  -1.32927095e-03]
[ Epoch 19 ] grad = [ 57.30396047  -0.38613626], x = [  2.79643327e+01   7.72272511e-04]
[ Epoch 20 ] grad = [ 51.69843875   0.22433529], x = [  2.52288381e+01  -4.48670586e-04]
[ Epoch 21 ] grad = [ 46.62798378  -0.13033308], x = [  2.27544561e+01   2.60666151e-04]
[ Epoch 22 ] grad = [ 42.04477733   0.07572019], x = [  2.05178513e+01  -1.51440373e-04]
[ Epoch 23 ] grad = [ 37.90445603  -0.04399149], x = [  1.84973745e+01   8.79829880e-05]
[ Epoch 24 ] grad = [  3.41660816e+01   2.55579341e-02], x = [  1.66730478e+01  -5.11158682e-05]
[ Epoch 25 ] grad = [  3.07920382e+01  -1.48485068e-02], x = [  1.50265146e+01   2.96970135e-05]
[ Epoch 26 ] grad = [  2.77478828e+01   8.62660307e-03], x = [  1.35409668e+01  -1.72532061e-05]
[ Epoch 27 ] grad = [  2.50021667e+01  -5.01183599e-03], x = [  1.22010574e+01   1.00236720e-05]
[ Epoch 28 ] grad = [  2.25262414e+01   2.91174867e-03], x = [  1.09928058e+01  -5.82349733e-06]
[ Epoch 29 ] grad = [  2.02940595e+01  -1.69165159e-03], x = [  9.90350104e+00   3.38330317e-06]
[ Epoch 30 ] grad = [  1.82819754e+01   9.82806352e-04], x = [  8.92160399e+00  -1.96561270e-06]
[ Epoch 31 ] grad = [  1.64685521e+01  -5.70985382e-04], x = [  8.03665343e+00   1.14197076e-06]
[ Epoch 32 ] grad = [  1.48343761e+01   3.31727919e-04], x = [  7.23917552e+00  -6.63455838e-07]
[ Epoch 33 ] grad = [  1.33618820e+01  -1.92725446e-04], x = [  6.52059840e+00   3.85450892e-07]
[ Epoch 34 ] grad = [  1.20351888e+01   1.11968560e-04], x = [  5.87317215e+00  -2.23937121e-07]
[ Epoch 35 ] grad = [  1.08399476e+01  -6.50508729e-05], x = [  5.28989441e+00   1.30101746e-07]
[ Epoch 36 ] grad = [  9.76319997e+00   3.77928952e-05], x = [  4.76444159e+00  -7.55857905e-08]
[ Epoch 37 ] grad = [  8.79324922e+00  -2.19567066e-05], x = [  4.29110562e+00   4.39134132e-08]
[ Epoch 38 ] grad = [  7.91954089e+00   1.27562856e-05], x = [  3.86473596e+00  -2.55125711e-08]
[ Epoch 39 ] grad = [  7.13255438e+00  -7.41107602e-06], x = [  3.48068654e+00   1.48221520e-08]
[ Epoch 40 ] grad = [  6.42370389e+00   4.30564583e-06], x = [  3.13476750e+00  -8.61129165e-09]
[ Epoch 41 ] grad = [  5.78524834e+00  -2.50147022e-06], x = [  2.82320119e+00   5.00294045e-09]
[ Epoch 42 ] grad = [  5.21020954e+00   1.45329029e-06], x = [  2.54258226e+00  -2.90658058e-09]
[ Epoch 43 ] grad = [  4.69229801e+00  -8.44324531e-07], x = [  2.28984143e+00   1.68864906e-09]
[ Epoch 44 ] grad = [  4.22584570e+00   4.90530981e-07], x = [  2.06221270e+00  -9.81061962e-10]
[ Epoch 45 ] grad = [  3.80574519e+00  -2.84985968e-07], x = [  1.85720365e+00   5.69971936e-10]
[ Epoch 46 ] grad = [  3.42739463e+00   1.65569566e-07], x = [  1.67256858e+00  -3.31139133e-10]
[ Epoch 47 ] grad = [  3.08664806e+00  -9.61916881e-08], x = [  1.50628425e+00   1.92383376e-10]
[ Epoch 48 ] grad = [  2.77977045e+00   5.58849133e-08], x = [  1.35652798e+00  -1.11769827e-10]
[ Epoch 49 ] grad = [  2.50339717e+00  -3.24677068e-08], x = [  1.22165782e+00   6.49354137e-11]