Gradient Descent

We will construct a simple linear model and see how the gradient descent algorithm helps us find the best parameters.


In [77]:
import torch

torch.manual_seed(42) # fix the seed so runs are reproducible and easier to debug


Out[77]:
<torch._C.Generator at 0x10d376c30>

Vectorized Operations


In [78]:
u = torch.randint(10, (5, 1))
v = torch.randint(10, (5, 1))
print(u)
print(v)


tensor([[ 2.],
        [ 7.],
        [ 6.],
        [ 4.],
        [ 6.]])
tensor([[ 5.],
        [ 0.],
        [ 4.],
        [ 0.],
        [ 3.]])

In [79]:
u.size()


Out[79]:
torch.Size([5, 1])

In [80]:
u * v


Out[80]:
tensor([[ 10.],
        [  0.],
        [ 24.],
        [  0.],
        [ 18.]])

In [81]:
u.pow(v)


Out[81]:
tensor([[   32.],
        [    1.],
        [ 1296.],
        [    1.],
        [  216.]])
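
Elementwise operations also combine with reductions. As a small sketch (not part of the original cells; it reuses the values of u and v printed above), the inner product that gradient_w relies on later can be computed in two equivalent ways:

import torch

u = torch.tensor([[2.], [7.], [6.], [4.], [6.]])
v = torch.tensor([[5.], [0.], [4.], [0.], [3.]])

print((u * v).sum())     # elementwise product, then sum: tensor(52.)
print(torch.t(u).mm(v))  # same inner product as a (1,5) x (5,1) matmul: tensor([[52.]])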

Ground Truth Model


In [82]:
n = 5
a = 2.0
b = 3.0
epsilon = 0.02

# Specify our ground truth model
X = torch.randint(10, (n, 1))
ps = torch.ones(n)/2.0
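# exponent is 1 - epsilon/2 when the Bernoulli draw is 0, and 1 + epsilon/2 when it is 1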
exponents = torch.bernoulli(ps) * epsilon + (torch.ones(n) - epsilon/2)
e = exponents.reshape(n, 1)
T = a * (X.pow(e)) + b

In [83]:
print(X)
print(e)
print(T)


tensor([[ 8.],
        [ 4.],
        [ 0.],
        [ 4.],
        [ 1.]])
tensor([[ 1.0100],
        [ 0.9900],
        [ 1.0100],
        [ 0.9900],
        [ 0.9900]])
tensor([[ 19.3362],
        [ 10.8899],
        [  3.0000],
        [ 10.8899],
        [  5.0000]])

In [84]:
# verify that the vectorized computation matches a scalar evaluation of the first row
print(a * X[0][0] ** e[0][0] + b)


tensor(19.3362)

The above output shows that vectorized operations on tensors behave as expected.

You should be able to write down the ground truth model mathematically (the answer is given at the end of this notebook).

Fitting a Linear Model

$$\hat{\mathbf{y}} = \mathbf{X}\mathbf{w} + \mathbf{b}$$

In [85]:
# note that all operations here are vectorized.

def forward(x):  # caution: relies on the globals w and b
    return x.mm(w) + b

# Per-sample squared-error loss (elementwise; defined for reference, unused below)
def SSEloss(y_pred, y):
    return (y_pred - y) * (y_pred - y) / 2.0

# Total sum-of-squared-errors loss over the batch
def loss(x, t):
    y_pred = forward(x)
    return (y_pred - t).pow(2).sum() / 2.0

def gradient_w(x, t):
    # dL/dw = (y_pred - t)^T x, a (1, 1) matrix here
    return torch.t(forward(x) - t).mm(x)

def gradient_b(x, t):
    # dL/db = sum of the residuals; without sum(), each b_i would receive a
    # different gradient and the entries of b would drift apart.
    return (forward(x) - t).sum()
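
These gradients follow from differentiating the sum-of-squared-errors loss:

$$ L(\mathbf{w}, \mathbf{b}) = \frac{1}{2}\sum_{i=1}^{n}(\hat{y}_i - t_i)^2, \qquad \frac{\partial L}{\partial \mathbf{w}} = (\hat{\mathbf{y}} - \mathbf{t})^\top \mathbf{X}, \qquad \frac{\partial L}{\partial b_i} = \hat{y}_i - t_i, $$

and summing the per-component bias gradients (as gradient_b does) keeps all entries of $\mathbf{b}$ tied to a single shared value.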

In [103]:
# Training loop
MAX_EPOCHS = 500
learning_rate = 0.01 # it is tricky to find a good learning rate.


w = torch.randn(1, 1)
b = torch.ones(n, 1)

for epoch in range(MAX_EPOCHS):
    grad_w = gradient_w(X, T)
    w = w - learning_rate * grad_w
    grad_b = gradient_b(X, T)
    b = b - learning_rate * grad_b
    
    print(".. grad: ", w.item(), b[0].item())  # these are the updated parameters, not the gradients
    l = loss(X, T)
    print("epoch = ", epoch, ", loss=", l.item(), "learning rate = ", learning_rate)
    
    # Adaptive learning rate on: 500 epochs => loss = 0.04446
    #                       off: 500 epochs => loss = 0.04179
    #learning_rate = 1.0 / (100 + epoch)
    
# After training
print(forward(X))
print(T)


.. grad:  2.3410696983337402 1.0431773662567139
epoch =  0 , loss= 3.5474693775177 learning rate =  0.01
.. grad:  2.360976457595825 1.080811619758606
epoch =  1 , loss= 3.3889691829681396 learning rate =  0.01
.. grad:  2.355175733566284 1.1175503730773926
epoch =  2 , loss= 3.255636215209961 learning rate =  0.01
.. grad:  2.3487560749053955 1.153543472290039
epoch =  3 , loss= 3.1272027492523193 learning rate =  0.01
.. grad:  2.342444658279419 1.188809871673584
epoch =  4 , loss= 3.003889322280884 learning rate =  0.01
...
.. grad:  2.036092519760132 2.9004404544830322
epoch =  498 , loss= 0.041785407811403275 learning rate =  0.01
.. grad:  2.0360922813415527 2.900441884994507
epoch =  499 , loss= 0.041785500943660736 learning rate =  0.01
tensor([[ 19.1892],
        [ 11.0448],
        [  2.9004],
        [ 11.0448],
        [  4.9365]])
tensor([[ 19.3362],
        [ 10.8899],
        [  3.0000],
        [ 10.8899],
        [  5.0000]])
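
To quantify the fit (a small sketch, not in the original notebook; the tensors are copied from the output above), the worst-case absolute error is:

import torch

pred = torch.tensor([[19.1892], [11.0448], [2.9004], [11.0448], [4.9365]])
target = torch.tensor([[19.3362], [10.8899], [3.0000], [10.8899], [5.0000]])
print((pred - target).abs().max())  # about 0.15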

In [87]:
print(w)
print(b)


tensor([[ 2.0361]])
tensor([[ 2.9004],
        [ 2.9004],
        [ 2.9004],
        [ 2.9004],
        [ 2.9004]])

These values are fairly close to the ground-truth parameters $a = 2.0$ and $b = 3.0$; the remaining gap comes from the small random perturbation in the exponent, which makes a perfect linear fit impossible.
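
As a cross-check (a minimal sketch, not part of the original notebook), the same model can be fit by letting PyTorch's autograd compute the gradients instead of the hand-derived gradient_w and gradient_b; the X and T values are copied from the run above:

import torch

X = torch.tensor([[8.], [4.], [0.], [4.], [1.]])
T = torch.tensor([[19.3362], [10.8899], [3.0000], [10.8899], [5.0000]])

w = torch.randn(1, 1, requires_grad=True)
b = torch.zeros(1, requires_grad=True)  # a single shared bias

learning_rate = 0.01
for epoch in range(500):
    l = (X.mm(w) + b - T).pow(2).sum() / 2.0
    l.backward()                      # autograd fills w.grad and b.grad
    with torch.no_grad():
        w -= learning_rate * w.grad
        b -= learning_rate * b.grad
        w.grad.zero_()
        b.grad.zero_()

print(w, b)  # should land near w = 2.04, b = 2.90, matching the manual loop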


Ground Truth Model

The ground truth model is

$$ y = a\, x^{1 + \delta/2} + b, $$

where $\delta$ is a random variable that takes the value $\epsilon$ with probability 0.5 and the value $-\epsilon$ with probability 0.5 (so the exponent is $1 \pm \epsilon/2$; the code above uses $\epsilon = 0.02$).

