In [20]:
import numpy as np
import matplotlib.pyplot as plt
def f(x):
    """Objective f(x) = x[0]**2 + 50 * x[1]**2 (a stretched quadratic bowl).

    Works element-wise, so ``x`` may be a pair of scalars or a pair of
    equally-shaped numpy arrays.
    """
    head, tail = x[0], x[1]
    # Keep the original evaluation order: (head*head) + ((50*tail)*tail).
    return head * head + 50 * tail * tail
def g(x):
    """Analytic gradient of f: returns np.array([2 * x[0], 100 * x[1]])."""
    d_first = 2 * x[0]
    d_second = 100 * x[1]
    return np.array([d_first, d_second])
# Sampling grid for the contour plots: 1000 x 1000 points covering
# x[0] in [-200, 200] and x[1] in [-100, 100].
xi = np.linspace(-200, 200, 1000)
yi = np.linspace(-100, 100, 1000)
X, Y = np.meshgrid(xi, yi)
# Evaluate the objective itself rather than re-typing its formula, so the
# contour map can never drift out of sync with f.
Z = f((X, Y))
%matplotlib inline
def contour(X, Y, Z, arr=None, figsize=(15, 7)):
    """Draw black contour lines of Z over the (X, Y) grid, optionally overlaying a path.

    Parameters
    ----------
    X, Y, Z : numpy.ndarray
        Grid coordinates and values passed straight to ``plt.contour``
        (in this notebook they come from ``np.meshgrid``).
    arr : sequence of points, optional
        Iterates ``[p0, p1, ...]``, each indexable as ``p[0], p[1]``.
        Consecutive points are joined by one line segment per ``plt.plot``
        call, so successive segments take successive colours from
        matplotlib's colour cycle.
    figsize : tuple of float, optional
        Figure size in inches; defaults to the previously hard-coded (15, 7).
    """
    plt.figure(figsize=figsize)
    plt.contour(X, Y, Z, colors='black')
    # Star marker at the origin (0, 0).
    plt.plot(0, 0, marker='*')
    plt.xlabel('x[0]')
    plt.ylabel('x[1]')
    if arr is not None:
        arr = np.array(arr)
        for i in range(len(arr) - 1):
            plt.plot(arr[i:i + 2, 0], arr[i:i + 2, 1])
# Contour map of the objective alone, before any descent path is drawn.
contour(X, Y, Z, arr=None)
In [37]:
def gd(x_start, step, g, max_epochs=50, tol=1e-6):
    """Plain gradient descent: repeatedly apply x <- x - step * g(x).

    Parameters
    ----------
    x_start : array-like
        Starting point; copied into a new float64 array, so the caller's
        object is not modified.
    step : float
        Learning rate.
    g : callable
        Gradient function, called as ``g(x)`` with the current float64 array.
    max_epochs : int, optional
        Maximum number of update steps (default 50, the former hard-coded value).
    tol : float, optional
        Stop after an epoch whose gradient has Euclidean norm below ``tol``.

    Returns
    -------
    x : numpy.ndarray
        Final iterate.
    passing_dot : list of numpy.ndarray
        Every iterate visited, starting with the initial point.
    """
    x = np.array(x_start, dtype='float64')
    passing_dot = [x.copy()]
    for i in range(max_epochs):
        grad = g(x)
        x -= grad * step
        passing_dot.append(x.copy())
        # Single-argument print(...) prints identically on Python 2 and 3.
        print('[ Epoch {0} ] grad = {1}, x = {2}'.format(i, grad, x))
        # Use the gradient norm: abs(sum(grad)) can be 0 when components
        # cancel (e.g. [1, -1]) and would stop far from a stationary point.
        if np.linalg.norm(grad) < tol:
            break
    return x, passing_dot
# step = 0.016: each update scales x[0] by 1 - 2*0.016 = 0.968 and x[1] by
# 1 - 100*0.016 = -0.6, so x[1] flips sign every step but shrinks quickly
# while x[0] creeps towards 0.
res, x_arr = gd(x_start=[150, 75], step=0.016, g=g)
contour(X, Y, Z, arr=x_arr)
[ Epoch 0 ] grad = [ 300. 7500.], x = [ 145.2 -45. ]
[ Epoch 1 ] grad = [ 290.4 -4500. ], x = [ 140.5536 27. ]
[ Epoch 2 ] grad = [ 281.1072 2700. ], x = [ 136.0558848 -16.2 ]
[ Epoch 3 ] grad = [ 272.1117696 -1620. ], x = [ 131.70209649 9.72 ]
[ Epoch 4 ] grad = [ 263.40419297 972. ], x = [ 127.4876294 -5.832 ]
[ Epoch 5 ] grad = [ 254.9752588 -583.2 ], x = [ 123.40802526 3.4992 ]
[ Epoch 6 ] grad = [ 246.81605052 349.92 ], x = [ 119.45896845 -2.09952 ]
[ Epoch 7 ] grad = [ 238.9179369 -209.952 ], x = [ 115.63628146 1.259712 ]
[ Epoch 8 ] grad = [ 231.27256292 125.9712 ], x = [ 111.93592045 -0.7558272 ]
[ Epoch 9 ] grad = [ 223.87184091 -75.58272 ], x = [ 108.353971 0.45349632]
[ Epoch 10 ] grad = [ 216.707942 45.349632], x = [ 104.88664393 -0.27209779]
[ Epoch 11 ] grad = [ 209.77328785 -27.2097792 ], x = [ 101.53027132 0.16325868]
[ Epoch 12 ] grad = [ 203.06054264 16.32586752], x = [ 9.82813026e+01 -9.79552051e-02]
[ Epoch 13 ] grad = [ 196.56260528 -9.79552051], x = [ 9.51363010e+01 5.87731231e-02]
[ Epoch 14 ] grad = [ 190.27260191 5.87731231], x = [ 9.20919393e+01 -3.52638738e-02]
[ Epoch 15 ] grad = [ 184.18387865 -3.52638738], x = [ 8.91449973e+01 2.11583243e-02]
[ Epoch 16 ] grad = [ 178.28999453 2.11583243], x = [ 8.62923574e+01 -1.26949946e-02]
[ Epoch 17 ] grad = [ 172.58471471 -1.26949946], x = [ 8.35310019e+01 7.61699675e-03]
[ Epoch 18 ] grad = [ 167.06200383 0.76169968], x = [ 8.08580099e+01 -4.57019805e-03]
[ Epoch 19 ] grad = [ 161.71601971 -0.45701981], x = [ 7.82705535e+01 2.74211883e-03]
[ Epoch 20 ] grad = [ 156.54110708 0.27421188], x = [ 7.57658958e+01 -1.64527130e-03]
[ Epoch 21 ] grad = [ 151.53179165 -0.16452713], x = [ 7.33413872e+01 9.87162779e-04]
[ Epoch 22 ] grad = [ 1.46682774e+02 9.87162779e-02], x = [ 7.09944628e+01 -5.92297667e-04]
[ Epoch 23 ] grad = [ 1.41988926e+02 -5.92297667e-02], x = [ 6.87226400e+01 3.55378600e-04]
[ Epoch 24 ] grad = [ 1.37445280e+02 3.55378600e-02], x = [ 6.65235155e+01 -2.13227160e-04]
[ Epoch 25 ] grad = [ 1.33047031e+02 -2.13227160e-02], x = [ 6.43947630e+01 1.27936296e-04]
[ Epoch 26 ] grad = [ 1.28789526e+02 1.27936296e-02], x = [ 6.23341306e+01 -7.67617777e-05]
[ Epoch 27 ] grad = [ 1.24668261e+02 -7.67617777e-03], x = [ 6.03394384e+01 4.60570666e-05]
[ Epoch 28 ] grad = [ 1.20678877e+02 4.60570666e-03], x = [ 5.84085764e+01 -2.76342400e-05]
[ Epoch 29 ] grad = [ 1.16817153e+02 -2.76342400e-03], x = [ 5.65395019e+01 1.65805440e-05]
[ Epoch 30 ] grad = [ 1.13079004e+02 1.65805440e-03], x = [ 5.47302379e+01 -9.94832639e-06]
[ Epoch 31 ] grad = [ 1.09460476e+02 -9.94832639e-04], x = [ 5.29788702e+01 5.96899583e-06]
[ Epoch 32 ] grad = [ 1.05957740e+02 5.96899583e-04], x = [ 5.12835464e+01 -3.58139750e-06]
[ Epoch 33 ] grad = [ 1.02567093e+02 -3.58139750e-04], x = [ 4.96424729e+01 2.14883850e-06]
[ Epoch 34 ] grad = [ 9.92849458e+01 2.14883850e-04], x = [ 4.80539138e+01 -1.28930310e-06]
[ Epoch 35 ] grad = [ 9.61078276e+01 -1.28930310e-04], x = [ 4.65161885e+01 7.73581860e-07]
[ Epoch 36 ] grad = [ 9.30323771e+01 7.73581860e-05], x = [ 4.50276705e+01 -4.64149116e-07]
[ Epoch 37 ] grad = [ 9.00553410e+01 -4.64149116e-05], x = [ 4.35867851e+01 2.78489470e-07]
[ Epoch 38 ] grad = [ 8.71735701e+01 2.78489470e-05], x = [ 4.21920079e+01 -1.67093682e-07]
[ Epoch 39 ] grad = [ 8.43840159e+01 -1.67093682e-05], x = [ 4.08418637e+01 1.00256209e-07]
[ Epoch 40 ] grad = [ 8.16837274e+01 1.00256209e-05], x = [ 3.95349240e+01 -6.01537254e-08]
[ Epoch 41 ] grad = [ 7.90698481e+01 -6.01537254e-06], x = [ 3.82698065e+01 3.60922353e-08]
[ Epoch 42 ] grad = [ 7.65396129e+01 3.60922353e-06], x = [ 3.70451727e+01 -2.16553412e-08]
[ Epoch 43 ] grad = [ 7.40903453e+01 -2.16553412e-06], x = [ 3.58597271e+01 1.29932047e-08]
[ Epoch 44 ] grad = [ 7.17194543e+01 1.29932047e-06], x = [ 3.47122159e+01 -7.79592282e-09]
[ Epoch 45 ] grad = [ 6.94244317e+01 -7.79592282e-07], x = [ 3.36014250e+01 4.67755369e-09]
[ Epoch 46 ] grad = [ 6.72028499e+01 4.67755369e-07], x = [ 3.25261794e+01 -2.80653221e-09]
[ Epoch 47 ] grad = [ 6.50523587e+01 -2.80653221e-07], x = [ 3.14853416e+01 1.68391933e-09]
[ Epoch 48 ] grad = [ 6.29706832e+01 1.68391933e-07], x = [ 3.04778107e+01 -1.01035160e-09]
[ Epoch 49 ] grad = [ 6.09556214e+01 -1.01035160e-07], x = [ 2.95025207e+01 6.06210958e-10]
In [41]:
# step = 0.019: x[1] is now scaled by 1 - 100*0.019 = -0.9 per step, so it
# zig-zags across the valley and decays much more slowly than at 0.016.
res, x_arr = gd(x_start=[150, 75], step=0.019, g=g)
contour(X, Y, Z, arr=x_arr)
[ Epoch 0 ] grad = [ 300. 7500.], x = [ 144.3 -67.5]
[ Epoch 1 ] grad = [ 288.6 -6750. ], x = [ 138.8166 60.75 ]
[ Epoch 2 ] grad = [ 277.6332 6075. ], x = [ 133.5415692 -54.675 ]
[ Epoch 3 ] grad = [ 267.0831384 -5467.5 ], x = [ 128.46698957 49.2075 ]
[ Epoch 4 ] grad = [ 256.93397914 4920.75 ], x = [ 123.58524397 -44.28675 ]
[ Epoch 5 ] grad = [ 247.17048793 -4428.675 ], x = [ 118.8890047 39.858075 ]
[ Epoch 6 ] grad = [ 237.77800939 3985.8075 ], x = [ 114.37122252 -35.8722675 ]
[ Epoch 7 ] grad = [ 228.74244504 -3587.22675 ], x = [ 110.02511606 32.28504075]
[ Epoch 8 ] grad = [ 220.05023212 3228.504075 ], x = [ 105.84416165 -29.05653667]
[ Epoch 9 ] grad = [ 211.6883233 -2905.6536675], x = [ 101.82208351 26.15088301]
[ Epoch 10 ] grad = [ 203.64416702 2615.08830075], x = [ 97.95284434 -23.53579471]
[ Epoch 11 ] grad = [ 195.90568867 -2353.57947067], x = [ 94.23063625 21.18221524]
[ Epoch 12 ] grad = [ 188.4612725 2118.22152361], x = [ 90.64987207 -19.06399371]
[ Epoch 13 ] grad = [ 181.29974415 -1906.39937125], x = [ 87.20517693 17.15759434]
[ Epoch 14 ] grad = [ 174.41035387 1715.75943412], x = [ 83.89138021 -15.44183491]
[ Epoch 15 ] grad = [ 167.78276042 -1544.18349071], x = [ 80.70350776 13.89765142]
[ Epoch 16 ] grad = [ 161.40701553 1389.76514164], x = [ 77.63677447 -12.50788627]
[ Epoch 17 ] grad = [ 155.27354894 -1250.78862747], x = [ 74.68657704 11.25709765]
[ Epoch 18 ] grad = [ 149.37315408 1125.70976473], x = [ 71.84848711 -10.13138788]
[ Epoch 19 ] grad = [ 143.69697422 -1013.13878825], x = [ 69.1182446 9.11824909]
[ Epoch 20 ] grad = [ 138.2364892 911.82490943], x = [ 66.49175131 -8.20642418]
[ Epoch 21 ] grad = [ 132.98350261 -820.64241849], x = [ 63.96506476 7.38578177]
[ Epoch 22 ] grad = [ 127.93012951 738.57817664], x = [ 61.5343923 -6.64720359]
[ Epoch 23 ] grad = [ 123.06878459 -664.72035897], x = [ 59.19608539 5.98248323]
[ Epoch 24 ] grad = [ 118.39217078 598.24832308], x = [ 56.94663414 -5.38423491]
[ Epoch 25 ] grad = [ 113.89326829 -538.42349077], x = [ 54.78266205 4.84581142]
[ Epoch 26 ] grad = [ 109.56532409 484.58114169], x = [ 52.70092089 -4.36123028]
[ Epoch 27 ] grad = [ 105.40184178 -436.12302752], x = [ 50.69828589 3.92510725]
[ Epoch 28 ] grad = [ 101.39657179 392.51072477], x = [ 48.77175103 -3.53259652]
[ Epoch 29 ] grad = [ 97.54350206 -353.25965229], x = [ 46.91842449 3.17933687]
[ Epoch 30 ] grad = [ 93.83684898 317.93368706], x = [ 45.13552436 -2.86140318]
[ Epoch 31 ] grad = [ 90.27104872 -286.14031836], x = [ 43.42037443 2.57526287]
[ Epoch 32 ] grad = [ 86.84074887 257.52628652], x = [ 41.77040021 -2.31773658]
[ Epoch 33 ] grad = [ 83.54080041 -231.77365787], x = [ 40.183125 2.08596292]
[ Epoch 34 ] grad = [ 80.36625 208.59629208], x = [ 38.65616625 -1.87736663]
[ Epoch 35 ] grad = [ 77.3123325 -187.73666287], x = [ 37.18723193 1.68962997]
[ Epoch 36 ] grad = [ 74.37446386 168.96299659], x = [ 35.77411712 -1.52066697]
[ Epoch 37 ] grad = [ 71.54823424 -152.06669693], x = [ 34.41470067 1.36860027]
[ Epoch 38 ] grad = [ 68.82940133 136.86002724], x = [ 33.10694204 -1.23174025]
[ Epoch 39 ] grad = [ 66.21388408 -123.17402451], x = [ 31.84887824 1.10856622]
[ Epoch 40 ] grad = [ 63.69775649 110.85662206], x = [ 30.63862087 -0.9977096 ]
[ Epoch 41 ] grad = [ 61.27724174 -99.77095985], x = [ 29.47435328 0.89793864]
[ Epoch 42 ] grad = [ 58.94870656 89.79386387], x = [ 28.35432785 -0.80814477]
[ Epoch 43 ] grad = [ 56.70865571 -80.81447748], x = [ 27.27686339 0.7273303 ]
[ Epoch 44 ] grad = [ 54.55372679 72.73302973], x = [ 26.24034259 -0.65459727]
[ Epoch 45 ] grad = [ 52.48068517 -65.45972676], x = [ 25.24320957 0.58913754]
[ Epoch 46 ] grad = [ 50.48641914 58.91375408], x = [ 24.2839676 -0.53022379]
[ Epoch 47 ] grad = [ 48.56793521 -53.02237868], x = [ 23.36117684 0.47720141]
[ Epoch 48 ] grad = [ 46.72235367 47.72014081], x = [ 22.47345212 -0.42948127]
[ Epoch 49 ] grad = [ 44.94690423 -42.94812673], x = [ 21.61946094 0.38653314]
In [44]:
# step = 0.02 = 2/100 is the stability limit along x[1]: the factor
# 1 - 100*0.02 = -1 makes x[1] bounce between +75 and -75 forever.
res, x_arr = gd(x_start=[150, 75], step=0.02, g=g)
contour(X, Y, Z, arr=x_arr)
[ Epoch 0 ] grad = [ 300. 7500.], x = [ 144. -75.]
[ Epoch 1 ] grad = [ 288. -7500.], x = [ 138.24 75. ]
[ Epoch 2 ] grad = [ 276.48 7500. ], x = [ 132.7104 -75. ]
[ Epoch 3 ] grad = [ 265.4208 -7500. ], x = [ 127.401984 75. ]
[ Epoch 4 ] grad = [ 254.803968 7500. ], x = [ 122.30590464 -75. ]
[ Epoch 5 ] grad = [ 244.61180928 -7500. ], x = [ 117.41366845 75. ]
[ Epoch 6 ] grad = [ 234.82733691 7500. ], x = [ 112.71712172 -75. ]
[ Epoch 7 ] grad = [ 225.43424343 -7500. ], x = [ 108.20843685 75. ]
[ Epoch 8 ] grad = [ 216.4168737 7500. ], x = [ 103.88009937 -75. ]
[ Epoch 9 ] grad = [ 207.76019875 -7500. ], x = [ 99.7248954 75. ]
[ Epoch 10 ] grad = [ 199.4497908 7500. ], x = [ 95.73589958 -75. ]
[ Epoch 11 ] grad = [ 191.47179917 -7500. ], x = [ 91.9064636 75. ]
[ Epoch 12 ] grad = [ 183.8129272 7500. ], x = [ 88.23020506 -75. ]
[ Epoch 13 ] grad = [ 176.46041011 -7500. ], x = [ 84.70099685 75. ]
[ Epoch 14 ] grad = [ 169.40199371 7500. ], x = [ 81.31295698 -75. ]
[ Epoch 15 ] grad = [ 162.62591396 -7500. ], x = [ 78.0604387 75. ]
[ Epoch 16 ] grad = [ 156.1208774 7500. ], x = [ 74.93802115 -75. ]
[ Epoch 17 ] grad = [ 149.8760423 -7500. ], x = [ 71.94050031 75. ]
[ Epoch 18 ] grad = [ 143.88100061 7500. ], x = [ 69.06288029 -75. ]
[ Epoch 19 ] grad = [ 138.12576059 -7500. ], x = [ 66.30036508 75. ]
[ Epoch 20 ] grad = [ 132.60073016 7500. ], x = [ 63.64835048 -75. ]
[ Epoch 21 ] grad = [ 127.29670096 -7500. ], x = [ 61.10241646 75. ]
[ Epoch 22 ] grad = [ 122.20483292 7500. ], x = [ 58.6583198 -75. ]
[ Epoch 23 ] grad = [ 117.3166396 -7500. ], x = [ 56.31198701 75. ]
[ Epoch 24 ] grad = [ 112.62397402 7500. ], x = [ 54.05950753 -75. ]
[ Epoch 25 ] grad = [ 108.11901506 -7500. ], x = [ 51.89712723 75. ]
[ Epoch 26 ] grad = [ 103.79425446 7500. ], x = [ 49.82124214 -75. ]
[ Epoch 27 ] grad = [ 99.64248428 -7500. ], x = [ 47.82839245 75. ]
[ Epoch 28 ] grad = [ 95.65678491 7500. ], x = [ 45.91525675 -75. ]
[ Epoch 29 ] grad = [ 91.83051351 -7500. ], x = [ 44.07864648 75. ]
[ Epoch 30 ] grad = [ 88.15729297 7500. ], x = [ 42.31550063 -75. ]
[ Epoch 31 ] grad = [ 84.63100125 -7500. ], x = [ 40.6228806 75. ]
[ Epoch 32 ] grad = [ 81.2457612 7500. ], x = [ 38.99796538 -75. ]
[ Epoch 33 ] grad = [ 77.99593075 -7500. ], x = [ 37.43804676 75. ]
[ Epoch 34 ] grad = [ 74.87609352 7500. ], x = [ 35.94052489 -75. ]
[ Epoch 35 ] grad = [ 71.88104978 -7500. ], x = [ 34.5029039 75. ]
[ Epoch 36 ] grad = [ 69.00580779 7500. ], x = [ 33.12278774 -75. ]
[ Epoch 37 ] grad = [ 66.24557548 -7500. ], x = [ 31.79787623 75. ]
[ Epoch 38 ] grad = [ 63.59575246 7500. ], x = [ 30.52596118 -75. ]
[ Epoch 39 ] grad = [ 61.05192236 -7500. ], x = [ 29.30492273 75. ]
[ Epoch 40 ] grad = [ 58.60984547 7500. ], x = [ 28.13272582 -75. ]
[ Epoch 41 ] grad = [ 56.26545165 -7500. ], x = [ 27.00741679 75. ]
[ Epoch 42 ] grad = [ 54.01483358 7500. ], x = [ 25.92712012 -75. ]
[ Epoch 43 ] grad = [ 51.85424024 -7500. ], x = [ 24.89003531 75. ]
[ Epoch 44 ] grad = [ 49.78007063 7500. ], x = [ 23.8944339 -75. ]
[ Epoch 45 ] grad = [ 47.7888678 -7500. ], x = [ 22.93865655 75. ]
[ Epoch 46 ] grad = [ 45.87731309 7500. ], x = [ 22.02111028 -75. ]
[ Epoch 47 ] grad = [ 44.04222057 -7500. ], x = [ 21.14026587 75. ]
[ Epoch 48 ] grad = [ 42.28053175 7500. ], x = [ 20.29465524 -75. ]
[ Epoch 49 ] grad = [ 40.58931048 -7500. ], x = [ 19.48286903 75. ]
In [36]:
def momentum(x_start, step, g, discount=0.7, max_epochs=50, tol=1e-6):
    """Gradient descent with classical momentum.

    Update rule: v <- discount * v + g(x);  x <- x - step * v  (v starts at 0).

    Parameters
    ----------
    x_start : array-like
        Starting point; copied into a new float64 array.
    step : float
        Learning rate.
    g : callable
        Gradient function, called as ``g(x)`` with the current float64 array.
    discount : float, optional
        Decay applied to the accumulated velocity each step (0 gives plain GD).
    max_epochs : int, optional
        Maximum number of update steps (default 50, the former hard-coded value).
    tol : float, optional
        Stop after an epoch whose gradient has Euclidean norm below ``tol``.

    Returns
    -------
    x : numpy.ndarray
        Final iterate.
    passing_dot : list of numpy.ndarray
        Every iterate visited, starting with the initial point.
    """
    x = np.array(x_start, dtype='float64')
    passing_dot = [x.copy()]
    pre_grad = np.zeros_like(x)  # accumulated velocity
    for i in range(max_epochs):
        grad = g(x)
        pre_grad = pre_grad * discount + grad
        x -= pre_grad * step
        passing_dot.append(x.copy())
        # Single-argument print(...) prints identically on Python 2 and 3.
        print('[ Epoch {0} ] grad = {1}, x = {2}'.format(i, grad, x))
        # Norm-based test: abs(sum(grad)) is 0 for cancelling components
        # such as [1, -1], which would stop the run prematurely.
        if np.linalg.norm(grad) < tol:
            break
    return x, passing_dot
# Same step size as the first plain-GD run (0.016) with the default
# velocity discount, for a side-by-side comparison of the two paths.
res, x_arr = momentum(x_start=[150, 75], step=0.016, g=g)
contour(X, Y, Z, arr=x_arr)
[ Epoch 0 ] grad = [ 300. 7500.], x = [ 145.2 -45. ]
[ Epoch 1 ] grad = [ 290.4 -4500. ], x = [ 137.1936 -57. ]
[ Epoch 2 ] grad = [ 274.3872 -5700. ], x = [ 127.1989248 25.8 ]
[ Epoch 3 ] grad = [ 254.3978496 2580. ], x = [ 116.13228657 42.48 ]
[ Epoch 4 ] grad = [ 232.26457313 4248. ], x = [ 104.66940663 -13.812 ]
[ Epoch 5 ] grad = [ 209.33881327 -1381.2 ], x = [ 93.29596967 -31.1172 ]
[ Epoch 6 ] grad = [ 186.59193933 -3111.72 ], x = [ 82.34909276 6.55668 ]
[ Epoch 7 ] grad = [ 164.69818552 655.668 ], x = [ 72.05110796 22.437708 ]
[ Epoch 8 ] grad = [ 144.10221592 2243.7708 ], x = [ 62.53688314 -2.3459052 ]
[ Epoch 9 ] grad = [ 125.07376629 -234.59052 ], x = [ 53.87574551 -15.94098612]
[ Epoch 10 ] grad = [ 107.75149102 -1594.098612 ], x = [ 46.08892531 0.04803503]
[ Epoch 11 ] grad = [ 92.17785063 4.8035028 ], x = [ 39.16330556 11.16349379]
[ Epoch 12 ] grad = [ 78.32661113 1116.34937868], x = [ 33.06214596 1.08272486]
[ Epoch 13 ] grad = [ 66.12429192 108.27248591], x = [ 27.73334557 -7.70617316]
[ Epoch 14 ] grad = [ 55.46669114 -770.61731649], x = [ 23.11571824 -1.52852472]
[ Epoch 15 ] grad = [ 46.23143647 -152.85247178], x = [ 19.14367612 5.24146874]
[ Epoch 16 ] grad = [ 38.28735224 524.14687436], x = [ 15.750649 1.59411418]
[ Epoch 17 ] grad = [ 31.501298 159.41141769], x = [ 12.87150925 -3.5096167 ]
[ Epoch 18 ] grad = [ 25.7430185 -350.96167028], x = [ 10.44422313 -1.46684159]
[ Epoch 19 ] grad = [ 20.88844626 -146.68415941], x = [ 8.41090771 2.31004753]
[ Epoch 20 ] grad = [ 16.82181541 231.00475326], x = [ 6.71843786 1.25779387]
[ Epoch 21 ] grad = [ 13.43687572 125.77938691], x = [ 5.31871896 -1.49125389]
[ Epoch 22 ] grad = [ 10.63743792 -149.12538859], x = [ 4.16871672 -1.0295811 ]
[ Epoch 23 ] grad = [ 8.33743344 -102.9581097 ], x = [ 3.23031622 0.94091961]
[ Epoch 24 ] grad = [ 6.46063244 94.09196104], x = [ 2.47006575 0.81479873]
[ Epoch 25 ] grad = [ 4.9401315 81.47987289], x = [ 1.85884831 -0.57716385]
[ Epoch 26 ] grad = [ 3.71769663 -57.71638544], x = [ 1.37151297 -0.6280755 ]
[ Epoch 27 ] grad = [ 2.74302593 -62.80754957], x = [ 0.98648981 0.34120715]
[ Epoch 28 ] grad = [ 1.97297961 34.12071485], x = [ 0.68540592 0.47377356]
[ Epoch 29 ] grad = [ 1.37081184 47.37735618], x = [ 0.45271421 -0.19146765]
[ Epoch 30 ] grad = [ 0.90542842 -19.14676478], x = [ 0.27534316 -0.35078826]
[ Epoch 31 ] grad = [ 0.55068632 -35.07882581], x = [ 0.14237244 0.09894853]
[ Epoch 32 ] grad = [ 0.28474489 9.89485276], x = [ 0.04473702 0.25544663]
[ Epoch 33 ] grad = [ 0.08947405 25.54466334], x = [-0.02503936 -0.04371931]
[ Epoch 34 ] grad = [-0.05007871 -4.3719306 ], x = [-0.07308156 -0.18318457]
[ Epoch 35 ] grad = [ -0.14616312 -18.3184574 ], x = [-0.1043725 0.01228506]
[ Epoch 36 ] grad = [-0.20874499 1.22850568], x = [-0.12293623 0.12945771]
[ Epoch 37 ] grad = [ -0.24587246 12.94577075], x = [-0.13199688 0.00434623]
[ Epoch 38 ] grad = [-0.26399377 0.4346231 ], x = [-0.13411544 -0.09018577]
[ Epoch 39 ] grad = [-0.26823088 -9.01857721], x = [-0.13130674 -0.01206094]
[ Epoch 40 ] grad = [-0.26261348 -1.20609389], x = [-0.12513883 0.06192395]
[ Epoch 41 ] grad = [-0.25027766 6.19239466], x = [-0.11681685 0.01463505]
[ Epoch 42 ] grad = [-0.2336337 1.46350519], x = [-0.10725333 -0.04188326]
[ Epoch 43 ] grad = [-0.21450666 -4.18832574], x = [-0.09712675 -0.01443286]
[ Epoch 44 ] grad = [-0.19425351 -1.44328621], x = [-0.0869301 0.02787499]
[ Epoch 45 ] grad = [-0.17386019 2.7874994 ], x = [-0.07701067 0.0128905 ]
[ Epoch 46 ] grad = [-0.15402135 1.28905028], x = [-0.06760274 -0.01822345]
[ Epoch 47 ] grad = [-0.13520547 -1.82234455], x = [-0.05885389 -0.0108457 ]
[ Epoch 48 ] grad = [-0.11770778 -1.08456965], x = [-0.05084638 0.01167184]
[ Epoch 49 ] grad = [-0.10169275 1.16718422], x = [-0.04361403 0.00875917]
In [39]:
def nesterov(x_start, step, g, discount=0.7, max_epochs=50, tol=1e-6):
    """Nesterov accelerated gradient descent.

    Each step evaluates the gradient at the look-ahead point
    x - step * discount * v, then updates
    v <- discount * v + grad;  x <- x - step * v  (v starts at 0).

    Parameters
    ----------
    x_start : array-like
        Starting point; copied into a new float64 array.
    step : float
        Learning rate.
    g : callable
        Gradient function, called as ``g(x)`` on a float64 array.
    discount : float, optional
        Velocity decay, used for BOTH the look-ahead and the velocity update
        (previously the update silently used a hard-coded 0.7).
    max_epochs : int, optional
        Maximum number of update steps (default 50, the former hard-coded value).
    tol : float, optional
        Stop after an epoch whose look-ahead gradient has norm below ``tol``.

    Returns
    -------
    x : numpy.ndarray
        Final iterate.
    passing_dot : list of numpy.ndarray
        Every iterate visited, starting with the initial point.
    """
    x = np.array(x_start, dtype='float64')
    passing_dot = [x.copy()]
    pre_grad = np.zeros_like(x)  # accumulated velocity
    for i in range(max_epochs):
        x_future = x - step * discount * pre_grad
        grad = g(x_future)
        pre_grad = pre_grad * discount + grad
        x -= pre_grad * step
        passing_dot.append(x.copy())
        # Single-argument print(...) prints identically on Python 2 and 3.
        print('[ Epoch {0} ] grad = {1}, x = {2}'.format(i, grad, x))
        # Norm-based test: abs(sum(grad)) is 0 for cancelling components.
        if np.linalg.norm(grad) < tol:
            break
    return x, passing_dot
# Nesterov run with a smaller step (0.012) and the default discount.
res, x_arr = nesterov(x_start=[150, 75], step=0.012, g=g)
contour(X, Y, Z, arr=x_arr)
[ Epoch 0 ] grad = [ 300. 7500.], x = [ 146.4 -15. ]
[ Epoch 1 ] grad = [ 287.76 -7800. ], x = [ 140.42688 15.6 ]
[ Epoch 2 ] grad = [ 272.491392 3702. ], x = [ 132.9757993 -7.404 ]
[ Epoch 3 ] grad = [ 255.52008561 -2350.68 ], x = [ 124.69380178 4.70136 ]
[ Epoch 4 ] grad = [ 237.79280702 1317.5112 ], x = [ 116.04288983 -2.6350224 ]
[ Epoch 5 ] grad = [ 219.97450293 -777.049008 ], x = [ 107.34755743 1.55409802]
[ Epoch 6 ] grad = [ 202.5216495 448.64823072], x = [ 98.83056496 -0.89729646]
[ Epoch 7 ] grad = [ 185.73734045 -261.32725956], x = [ 90.63982214 0.52265452]
[ Epoch 8 ] grad = [ 169.81260433 151.66202055], x = [ 82.86855092 -0.30332404]
[ Epoch 9 ] grad = [ 154.85732212 -88.15090333], x = [ 75.57037319 0.17630181]
[ Epoch 10 ] grad = [ 140.92329758 51.20399001], x = [ 68.77056922 -0.10240798]
[ Epoch 11 ] grad = [ 128.02141287 -29.75048307], x = [ 6.24744495e+01 5.95009661e-02]
[ Epoch 12 ] grad = [ 116.13433133 17.28372284], x = [ 5.66735537e+01 -3.45674457e-02]
[ Epoch 13 ] grad = [ 105.22585326 -10.0415334 ], x = [ 5.13502164e+01 2.00830668e-02]
[ Epoch 14 ] grad = [ 95.24776057 5.83384255], x = [ 4.64809072e+01 -1.16676851e-02]
[ Epoch 15 ] grad = [ 86.14478139 -3.38932114], x = [ 4.20386533e+01 6.77864229e-03]
[ Epoch 16 ] grad = [ 77.85815127 1.96910715], x = [ 3.79947778e+01 -3.93821429e-03]
[ Epoch 17 ] grad = [ 70.32812993 -1.14400139], x = [ 3.43201274e+01 2.28800278e-03]
[ Epoch 18 ] grad = [ 63.49574424 0.66463547], x = [ 3.09859232e+01 -1.32927095e-03]
[ Epoch 19 ] grad = [ 57.30396047 -0.38613626], x = [ 2.79643327e+01 7.72272511e-04]
[ Epoch 20 ] grad = [ 51.69843875 0.22433529], x = [ 2.52288381e+01 -4.48670586e-04]
[ Epoch 21 ] grad = [ 46.62798378 -0.13033308], x = [ 2.27544561e+01 2.60666151e-04]
[ Epoch 22 ] grad = [ 42.04477733 0.07572019], x = [ 2.05178513e+01 -1.51440373e-04]
[ Epoch 23 ] grad = [ 37.90445603 -0.04399149], x = [ 1.84973745e+01 8.79829880e-05]
[ Epoch 24 ] grad = [ 3.41660816e+01 2.55579341e-02], x = [ 1.66730478e+01 -5.11158682e-05]
[ Epoch 25 ] grad = [ 3.07920382e+01 -1.48485068e-02], x = [ 1.50265146e+01 2.96970135e-05]
[ Epoch 26 ] grad = [ 2.77478828e+01 8.62660307e-03], x = [ 1.35409668e+01 -1.72532061e-05]
[ Epoch 27 ] grad = [ 2.50021667e+01 -5.01183599e-03], x = [ 1.22010574e+01 1.00236720e-05]
[ Epoch 28 ] grad = [ 2.25262414e+01 2.91174867e-03], x = [ 1.09928058e+01 -5.82349733e-06]
[ Epoch 29 ] grad = [ 2.02940595e+01 -1.69165159e-03], x = [ 9.90350104e+00 3.38330317e-06]
[ Epoch 30 ] grad = [ 1.82819754e+01 9.82806352e-04], x = [ 8.92160399e+00 -1.96561270e-06]
[ Epoch 31 ] grad = [ 1.64685521e+01 -5.70985382e-04], x = [ 8.03665343e+00 1.14197076e-06]
[ Epoch 32 ] grad = [ 1.48343761e+01 3.31727919e-04], x = [ 7.23917552e+00 -6.63455838e-07]
[ Epoch 33 ] grad = [ 1.33618820e+01 -1.92725446e-04], x = [ 6.52059840e+00 3.85450892e-07]
[ Epoch 34 ] grad = [ 1.20351888e+01 1.11968560e-04], x = [ 5.87317215e+00 -2.23937121e-07]
[ Epoch 35 ] grad = [ 1.08399476e+01 -6.50508729e-05], x = [ 5.28989441e+00 1.30101746e-07]
[ Epoch 36 ] grad = [ 9.76319997e+00 3.77928952e-05], x = [ 4.76444159e+00 -7.55857905e-08]
[ Epoch 37 ] grad = [ 8.79324922e+00 -2.19567066e-05], x = [ 4.29110562e+00 4.39134132e-08]
[ Epoch 38 ] grad = [ 7.91954089e+00 1.27562856e-05], x = [ 3.86473596e+00 -2.55125711e-08]
[ Epoch 39 ] grad = [ 7.13255438e+00 -7.41107602e-06], x = [ 3.48068654e+00 1.48221520e-08]
[ Epoch 40 ] grad = [ 6.42370389e+00 4.30564583e-06], x = [ 3.13476750e+00 -8.61129165e-09]
[ Epoch 41 ] grad = [ 5.78524834e+00 -2.50147022e-06], x = [ 2.82320119e+00 5.00294045e-09]
[ Epoch 42 ] grad = [ 5.21020954e+00 1.45329029e-06], x = [ 2.54258226e+00 -2.90658058e-09]
[ Epoch 43 ] grad = [ 4.69229801e+00 -8.44324531e-07], x = [ 2.28984143e+00 1.68864906e-09]
[ Epoch 44 ] grad = [ 4.22584570e+00 4.90530981e-07], x = [ 2.06221270e+00 -9.81061962e-10]
[ Epoch 45 ] grad = [ 3.80574519e+00 -2.84985968e-07], x = [ 1.85720365e+00 5.69971936e-10]
[ Epoch 46 ] grad = [ 3.42739463e+00 1.65569566e-07], x = [ 1.67256858e+00 -3.31139133e-10]
[ Epoch 47 ] grad = [ 3.08664806e+00 -9.61916881e-08], x = [ 1.50628425e+00 1.92383376e-10]
[ Epoch 48 ] grad = [ 2.77977045e+00 5.58849133e-08], x = [ 1.35652798e+00 -1.11769827e-10]
[ Epoch 49 ] grad = [ 2.50339717e+00 -3.24677068e-08], x = [ 1.22165782e+00 6.49354137e-11]
Content source: hsmyy/zhihuzhuanlan
Similar notebooks: