In [2]:
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import scipy as sc
import pandas as pd

# import seaborn as sns
# sns.set(color_codes=True)

# plt.figure(figsize=(5,5))
# df = pd.read_csv(u'data/wind_tribune.csv')
# sns.jointplot(x='wind_speed', y='production', data=df);
# plt.show()

In [4]:
df_iris = pd.read_csv(u'notes/data/iris.txt',sep=' ')

In [10]:
df_iris


Out[10]:
sl sw pl pw c
0 5.1 3.5 1.4 0.2 1
1 4.9 3.0 1.4 0.2 1
2 4.7 3.2 1.3 0.2 1
3 4.6 3.1 1.5 0.2 1
4 5.0 3.6 1.4 0.2 1
5 5.4 3.9 1.7 0.4 1
6 4.6 3.4 1.4 0.3 1
7 5.0 3.4 1.5 0.2 1
8 4.4 2.9 1.4 0.2 1
9 4.9 3.1 1.5 0.1 1
10 5.4 3.7 1.5 0.2 1
11 4.8 3.4 1.6 0.2 1
12 4.8 3.0 1.4 0.1 1
13 4.3 3.0 1.1 0.1 1
14 5.8 4.0 1.2 0.2 1
15 5.7 4.4 1.5 0.4 1
16 5.4 3.9 1.3 0.4 1
17 5.1 3.5 1.4 0.3 1
18 5.7 3.8 1.7 0.3 1
19 5.1 3.8 1.5 0.3 1
20 5.4 3.4 1.7 0.2 1
21 5.1 3.7 1.5 0.4 1
22 4.6 3.6 1.0 0.2 1
23 5.1 3.3 1.7 0.5 1
24 4.8 3.4 1.9 0.2 1
25 5.0 3.0 1.6 0.2 1
26 5.0 3.4 1.6 0.4 1
27 5.2 3.5 1.5 0.2 1
28 5.2 3.4 1.4 0.2 1
29 4.7 3.2 1.6 0.2 1
... ... ... ... ... ...
120 6.9 3.2 5.7 2.3 3
121 5.6 2.8 4.9 2.0 3
122 7.7 2.8 6.7 2.0 3
123 6.3 2.7 4.9 1.8 3
124 6.7 3.3 5.7 2.1 3
125 7.2 3.2 6.0 1.8 3
126 6.2 2.8 4.8 1.8 3
127 6.1 3.0 4.9 1.8 3
128 6.4 2.8 5.6 2.1 3
129 7.2 3.0 5.8 1.6 3
130 7.4 2.8 6.1 1.9 3
131 7.9 3.8 6.4 2.0 3
132 6.4 2.8 5.6 2.2 3
133 6.3 2.8 5.1 1.5 3
134 6.1 2.6 5.6 1.4 3
135 7.7 3.0 6.1 2.3 3
136 6.3 3.4 5.6 2.4 3
137 6.4 3.1 5.5 1.8 3
138 6.0 3.0 4.8 1.8 3
139 6.9 3.1 5.4 2.1 3
140 6.7 3.1 5.6 2.4 3
141 6.9 3.1 5.1 2.3 3
142 5.8 2.7 5.1 1.9 3
143 6.8 3.2 5.9 2.3 3
144 6.7 3.3 5.7 2.5 3
145 6.7 3.0 5.2 2.3 3
146 6.3 2.5 5.0 1.9 3
147 6.5 3.0 5.2 2.0 3
148 6.2 3.4 5.4 2.3 3
149 5.9 3.0 5.1 1.8 3

150 rows × 5 columns
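The columns are presumably sepal length (sl), sepal width (sw), petal length (pl) and petal width (pw) in centimetres, with c in {1, 2, 3} encoding the three iris species.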


In [16]:
target = np.array(df_iris['c'])
features = np.array(df_iris[['sl','sw','pl','pw']])

In [17]:
features


Out[17]:
array([[ 5.1,  3.5,  1.4,  0.2],
       [ 4.9,  3. ,  1.4,  0.2],
       [ 4.7,  3.2,  1.3,  0.2],
       [ 4.6,  3.1,  1.5,  0.2],
       [ 5. ,  3.6,  1.4,  0.2],
       [ 5.4,  3.9,  1.7,  0.4],
       [ 4.6,  3.4,  1.4,  0.3],
       [ 5. ,  3.4,  1.5,  0.2],
       [ 4.4,  2.9,  1.4,  0.2],
       [ 4.9,  3.1,  1.5,  0.1],
       [ 5.4,  3.7,  1.5,  0.2],
       [ 4.8,  3.4,  1.6,  0.2],
       [ 4.8,  3. ,  1.4,  0.1],
       [ 4.3,  3. ,  1.1,  0.1],
       [ 5.8,  4. ,  1.2,  0.2],
       [ 5.7,  4.4,  1.5,  0.4],
       [ 5.4,  3.9,  1.3,  0.4],
       [ 5.1,  3.5,  1.4,  0.3],
       [ 5.7,  3.8,  1.7,  0.3],
       [ 5.1,  3.8,  1.5,  0.3],
       [ 5.4,  3.4,  1.7,  0.2],
       [ 5.1,  3.7,  1.5,  0.4],
       [ 4.6,  3.6,  1. ,  0.2],
       [ 5.1,  3.3,  1.7,  0.5],
       [ 4.8,  3.4,  1.9,  0.2],
       [ 5. ,  3. ,  1.6,  0.2],
       [ 5. ,  3.4,  1.6,  0.4],
       [ 5.2,  3.5,  1.5,  0.2],
       [ 5.2,  3.4,  1.4,  0.2],
       [ 4.7,  3.2,  1.6,  0.2],
       [ 4.8,  3.1,  1.6,  0.2],
       [ 5.4,  3.4,  1.5,  0.4],
       [ 5.2,  4.1,  1.5,  0.1],
       [ 5.5,  4.2,  1.4,  0.2],
       [ 4.9,  3.1,  1.5,  0.1],
       [ 5. ,  3.2,  1.2,  0.2],
       [ 5.5,  3.5,  1.3,  0.2],
       [ 4.9,  3.1,  1.5,  0.1],
       [ 4.4,  3. ,  1.3,  0.2],
       [ 5.1,  3.4,  1.5,  0.2],
       [ 5. ,  3.5,  1.3,  0.3],
       [ 4.5,  2.3,  1.3,  0.3],
       [ 4.4,  3.2,  1.3,  0.2],
       [ 5. ,  3.5,  1.6,  0.6],
       [ 5.1,  3.8,  1.9,  0.4],
       [ 4.8,  3. ,  1.4,  0.3],
       [ 5.1,  3.8,  1.6,  0.2],
       [ 4.6,  3.2,  1.4,  0.2],
       [ 5.3,  3.7,  1.5,  0.2],
       [ 5. ,  3.3,  1.4,  0.2],
       [ 7. ,  3.2,  4.7,  1.4],
       [ 6.4,  3.2,  4.5,  1.5],
       [ 6.9,  3.1,  4.9,  1.5],
       [ 5.5,  2.3,  4. ,  1.3],
       [ 6.5,  2.8,  4.6,  1.5],
       [ 5.7,  2.8,  4.5,  1.3],
       [ 6.3,  3.3,  4.7,  1.6],
       [ 4.9,  2.4,  3.3,  1. ],
       [ 6.6,  2.9,  4.6,  1.3],
       [ 5.2,  2.7,  3.9,  1.4],
       [ 5. ,  2. ,  3.5,  1. ],
       [ 5.9,  3. ,  4.2,  1.5],
       [ 6. ,  2.2,  4. ,  1. ],
       [ 6.1,  2.9,  4.7,  1.4],
       [ 5.6,  2.9,  3.6,  1.3],
       [ 6.7,  3.1,  4.4,  1.4],
       [ 5.6,  3. ,  4.5,  1.5],
       [ 5.8,  2.7,  4.1,  1. ],
       [ 6.2,  2.2,  4.5,  1.5],
       [ 5.6,  2.5,  3.9,  1.1],
       [ 5.9,  3.2,  4.8,  1.8],
       [ 6.1,  2.8,  4. ,  1.3],
       [ 6.3,  2.5,  4.9,  1.5],
       [ 6.1,  2.8,  4.7,  1.2],
       [ 6.4,  2.9,  4.3,  1.3],
       [ 6.6,  3. ,  4.4,  1.4],
       [ 6.8,  2.8,  4.8,  1.4],
       [ 6.7,  3. ,  5. ,  1.7],
       [ 6. ,  2.9,  4.5,  1.5],
       [ 5.7,  2.6,  3.5,  1. ],
       [ 5.5,  2.4,  3.8,  1.1],
       [ 5.5,  2.4,  3.7,  1. ],
       [ 5.8,  2.7,  3.9,  1.2],
       [ 6. ,  2.7,  5.1,  1.6],
       [ 5.4,  3. ,  4.5,  1.5],
       [ 6. ,  3.4,  4.5,  1.6],
       [ 6.7,  3.1,  4.7,  1.5],
       [ 6.3,  2.3,  4.4,  1.3],
       [ 5.6,  3. ,  4.1,  1.3],
       [ 5.5,  2.5,  4. ,  1.3],
       [ 5.5,  2.6,  4.4,  1.2],
       [ 6.1,  3. ,  4.6,  1.4],
       [ 5.8,  2.6,  4. ,  1.2],
       [ 5. ,  2.3,  3.3,  1. ],
       [ 5.6,  2.7,  4.2,  1.3],
       [ 5.7,  3. ,  4.2,  1.2],
       [ 5.7,  2.9,  4.2,  1.3],
       [ 6.2,  2.9,  4.3,  1.3],
       [ 5.1,  2.5,  3. ,  1.1],
       [ 5.7,  2.8,  4.1,  1.3],
       [ 6.3,  3.3,  6. ,  2.5],
       [ 5.8,  2.7,  5.1,  1.9],
       [ 7.1,  3. ,  5.9,  2.1],
       [ 6.3,  2.9,  5.6,  1.8],
       [ 6.5,  3. ,  5.8,  2.2],
       [ 7.6,  3. ,  6.6,  2.1],
       [ 4.9,  2.5,  4.5,  1.7],
       [ 7.3,  2.9,  6.3,  1.8],
       [ 6.7,  2.5,  5.8,  1.8],
       [ 7.2,  3.6,  6.1,  2.5],
       [ 6.5,  3.2,  5.1,  2. ],
       [ 6.4,  2.7,  5.3,  1.9],
       [ 6.8,  3. ,  5.5,  2.1],
       [ 5.7,  2.5,  5. ,  2. ],
       [ 5.8,  2.8,  5.1,  2.4],
       [ 6.4,  3.2,  5.3,  2.3],
       [ 6.5,  3. ,  5.5,  1.8],
       [ 7.7,  3.8,  6.7,  2.2],
       [ 7.7,  2.6,  6.9,  2.3],
       [ 6. ,  2.2,  5. ,  1.5],
       [ 6.9,  3.2,  5.7,  2.3],
       [ 5.6,  2.8,  4.9,  2. ],
       [ 7.7,  2.8,  6.7,  2. ],
       [ 6.3,  2.7,  4.9,  1.8],
       [ 6.7,  3.3,  5.7,  2.1],
       [ 7.2,  3.2,  6. ,  1.8],
       [ 6.2,  2.8,  4.8,  1.8],
       [ 6.1,  3. ,  4.9,  1.8],
       [ 6.4,  2.8,  5.6,  2.1],
       [ 7.2,  3. ,  5.8,  1.6],
       [ 7.4,  2.8,  6.1,  1.9],
       [ 7.9,  3.8,  6.4,  2. ],
       [ 6.4,  2.8,  5.6,  2.2],
       [ 6.3,  2.8,  5.1,  1.5],
       [ 6.1,  2.6,  5.6,  1.4],
       [ 7.7,  3. ,  6.1,  2.3],
       [ 6.3,  3.4,  5.6,  2.4],
       [ 6.4,  3.1,  5.5,  1.8],
       [ 6. ,  3. ,  4.8,  1.8],
       [ 6.9,  3.1,  5.4,  2.1],
       [ 6.7,  3.1,  5.6,  2.4],
       [ 6.9,  3.1,  5.1,  2.3],
       [ 5.8,  2.7,  5.1,  1.9],
       [ 6.8,  3.2,  5.9,  2.3],
       [ 6.7,  3.3,  5.7,  2.5],
       [ 6.7,  3. ,  5.2,  2.3],
       [ 6.3,  2.5,  5. ,  1.9],
       [ 6.5,  3. ,  5.2,  2. ],
       [ 6.2,  3.4,  5.4,  2.3],
       [ 5.9,  3. ,  5.1,  1.8]])

In [18]:
w = np.array([1,1,1,1])

def f(x, w):
    # linear model: predicted values X·w
    return x.dot(w)

def E(y, f_est):
    # sum of squared errors
    return np.sum((y-f_est)**2)

In [19]:
f_est = f(features, w)
print(E(target, f_est))


21937.18
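The update used in the next cell comes from differentiating the squared error with respect to w:

$$ E(w) = \sum_i (y_i - x_i^\top w)^2, \qquad \nabla_w E = -2\, X^\top (y - Xw) $$

The code below uses dE = -features.T.dot(e), dropping the constant factor 2, which only rescales the learning rate eta.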

In [27]:
w = np.array([1,1,1,1])
eta = 0.0001
for epoch in range(10000):
    f_est = f(features, w)        # current predictions
    e = target - f_est            # residuals
    dE = -features.T.dot(e)       # gradient of the squared error (up to a factor of 2)

    w = w - eta*dE                # gradient descent step
    if epoch%100==0:
        print(E(target, f_est))
        print(w)


21937.18
[-0.067463  0.459053  0.272187  0.761825]
10.5525666505
[-0.14851803  0.34806374  0.24380087  0.7329741 ]
10.1309151298
[-0.12584939  0.31553845  0.24217589  0.71009234]
9.80084327011
[-0.10518428  0.28768226  0.23936129  0.68924878]
9.54137138746
[-0.08680701  0.26299369  0.23675974  0.67082486]
9.33738879754
[-0.0704802   0.24107451  0.23440092  0.65456852]
9.17702137093
[-0.05597323  0.22161004  0.23226118  0.64023216]
9.05093643136
[-0.04308078  0.20432281  0.23031742  0.62759536]
8.95179876351
[-0.03162076  0.18896686  0.22854903  0.61646274]
8.87384306656
[-0.0214317   0.1753241   0.22693768  0.6066612 ]
8.81253792401
[-0.0123704   0.16320118  0.22546707  0.59803735]
8.76432170079
[-0.0043099   0.15242661  0.22412269  0.59045529]
8.72639497132
[ 0.00286243  0.14284835  0.22289162  0.58379462]
8.69655737803
[ 0.00924643  0.13433158  0.22176238  0.57794868]
8.67307940996
[ 0.01493067  0.12675672  0.22072471  0.57282298]
8.65460162677
[ 0.0199937   0.12001776  0.21976951  0.56833382]
8.64005545344
[ 0.02450517  0.11402067  0.21888862  0.56440707]
8.62860092836
[ 0.02852687  0.10868206  0.21807481  0.56097708]
8.61957777615
[ 0.0321136   0.10392799  0.21732161  0.5579857 ]
8.61246695279
[ 0.03531398  0.09969286  0.21662323  0.55538146]
8.60686042168
[ 0.03817113  0.0959185   0.21597453  0.55311878]
8.60243739869
[ 0.0407233   0.09255332  0.21537088  0.5511573 ]
8.59894568164
[ 0.04300443  0.08955155  0.21480817  0.5494613 ]
8.59618697576
[ 0.04504463  0.08687261  0.21428271  0.54799917]
8.59400535995
[ 0.04687062  0.08448047  0.2137912   0.54674294]
8.59227822148
[ 0.0485061   0.08234318  0.21333067  0.54566786]
8.59090913081
[ 0.04997211  0.08043242  0.21289847  0.54475203]
8.58982224125
[ 0.05128732  0.07872303  0.21249221  0.5439761 ]
8.58895788715
[ 0.0524683   0.07719269  0.21210976  0.54332294]
8.58826912397
[ 0.05352975  0.07582161  0.21174918  0.54277741]
8.58771900871
[ 0.05448475  0.07459222  0.21140874  0.54232615]
8.58727846234
[ 0.05534488  0.07348892  0.21108687  0.54195733]
8.58692458941
[ 0.05612045  0.07249787  0.21078217  0.54166051]
8.58663935725
[ 0.05682058  0.07160678  0.21049335  0.54142648]
8.58640855768
[ 0.05745341  0.07080475  0.21021927  0.54124711]
8.58622099077
[ 0.05802615  0.07008208  0.20995888  0.54111521]
8.5860678232
[ 0.05854521  0.06943018  0.20971122  0.54102447]
8.58594208373
[ 0.0590163   0.06884141  0.20947544  0.54096929]
8.58583826663
[ 0.05944447  0.06830898  0.20925075  0.54094478]
8.58575201975
[ 0.05983423  0.06782685  0.20903643  0.5409466 ]
8.58567989934
[ 0.06018958  0.06738968  0.20883184  0.54097095]
8.58561917715
[ 0.06051409  0.0669927   0.20863638  0.54101448]
8.58556768878
[ 0.06081093  0.06663167  0.2084495   0.54107426]
8.58552371435
[ 0.06108293  0.06630284  0.20827069  0.54114771]
8.5854858846
[ 0.06133258  0.06600285  0.2080995   0.54123257]
8.5854531071
[ 0.06156214  0.06572873  0.2079355   0.54132685]
8.58542450802
[ 0.06177359  0.06547782  0.2077783   0.54142883]
8.58539938646
[ 0.06196872  0.06524776  0.20762754  0.541537  ]
8.58537717832
[ 0.06214912  0.06503646  0.20748289  0.54165003]
8.58535742798
[ 0.06231619  0.06484203  0.20734403  0.54176679]
8.58533976588
[ 0.0624712   0.06466281  0.20721067  0.54188627]
8.58532389089
[ 0.06261528  0.06449732  0.20708254  0.54200761]
8.58530955641
[ 0.06274944  0.06434422  0.2069594   0.54213007]
8.58529655934
[ 0.06287458  0.06420233  0.20684101  0.542253  ]
8.58528473142
[ 0.06299151  0.06407059  0.20672716  0.54237585]
8.58527393236
[ 0.06310096  0.06394806  0.20661762  0.54249815]
8.58526404441
[ 0.06320357  0.06383388  0.20651222  0.54261949]
8.58525496804
[ 0.06329992  0.06372732  0.20641078  0.54273954]
8.58524661849
[ 0.06339054  0.06362767  0.20631311  0.54285801]
8.58523892309
[ 0.0634759   0.06353435  0.20621907  0.54297467]
8.58523181909
[ 0.06355642  0.06344681  0.20612849  0.5430893 ]
8.58522525189
[ 0.06363248  0.06336455  0.20604124  0.54320177]
8.58521917365
[ 0.06370442  0.06328714  0.20595717  0.54331192]
8.58521354222
[ 0.06377256  0.06321419  0.20587617  0.54341967]
8.58520832022
[ 0.06383718  0.06314534  0.2057981   0.54352494]
8.58520347426
[ 0.06389852  0.06308026  0.20572286  0.54362767]
8.58519897444
[ 0.06395683  0.06301868  0.20565032  0.54372782]
8.58519479378
[ 0.0640123   0.06296032  0.20558039  0.54382537]
8.58519090789
[ 0.06406514  0.06290496  0.20551297  0.54392032]
8.58518729457
[ 0.0641155   0.06285237  0.20544795  0.54401265]
8.58518393362
[ 0.06416355  0.06280237  0.20538525  0.54410239]
8.58518080655
[ 0.06420944  0.06275479  0.20532479  0.54418956]
8.58517789639
[ 0.0642533   0.06270945  0.20526646  0.54427418]
8.58517518757
[ 0.06429524  0.06266621  0.20521021  0.5443563 ]
8.58517266573
[ 0.06433537  0.06262494  0.20515594  0.54443594]
8.58517031764
[ 0.06437381  0.06258553  0.20510359  0.54451315]
8.58516813106
[ 0.06441064  0.06254785  0.20505309  0.54458797]
8.58516609468
[ 0.06444594  0.0625118   0.20500436  0.54466047]
8.58516419802
[ 0.06447981  0.0624773   0.20495735  0.54473068]
8.58516243137
[ 0.06451232  0.06244425  0.20491199  0.54479866]
8.58516078571
[ 0.06454352  0.06241258  0.20486822  0.54486446]
8.58515925268
[ 0.0645735   0.0623822   0.20482599  0.54492815]
8.58515782451
[ 0.0646023   0.06235306  0.20478524  0.54498977]
8.58515649398
[ 0.06462998  0.06232509  0.20474591  0.54504938]
8.58515525437
[ 0.0646566   0.06229823  0.20470796  0.54510703]
8.58515409945
[ 0.06468221  0.06227243  0.20467134  0.54516278]
8.5851530234
[ 0.06470684  0.06224764  0.204636    0.54521669]
8.58515202081
[ 0.06473054  0.0622238   0.20460189  0.54526881]
8.58515108667
[ 0.06475336  0.06220088  0.20456897  0.54531918]
8.58515021628
[ 0.06477533  0.06217882  0.2045372   0.54536788]
8.58514940528
[ 0.06479649  0.06215761  0.20450654  0.54541493]
8.58514864962
[ 0.06481687  0.06213719  0.20447694  0.54546041]
8.58514794551
[ 0.0648365   0.06211753  0.20444838  0.54550434]
8.58514728943
[ 0.06485542  0.0620986   0.20442081  0.5455468 ]
8.58514667811
[ 0.06487365  0.06208037  0.2043942   0.54558781]
8.58514610848
[ 0.06489121  0.06206281  0.20436852  0.54562743]
8.5851455777
[ 0.06490815  0.06204589  0.20434373  0.5456657 ]
8.58514508312
[ 0.06492447  0.06202959  0.2043198   0.54570267]
8.58514462227
[ 0.06494021  0.06201388  0.20429671  0.54573838]
8.58514419284
[ 0.06495539  0.06199874  0.20427442  0.54577287]
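As a sanity check (not part of the original notebook), the same least-squares problem can be solved in closed form and compared against the gradient descent result; w_ls is an illustrative name:

w_ls = np.linalg.lstsq(features, target)[0]
print(w_ls)                              # should be close to the final w printed above
print(E(target, f(features, w_ls)))      # lower bound on the squared error reached above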

In [29]:
plt.plot(features.dot(w), target, 'o')


Out[29]:
[<matplotlib.lines.Line2D at 0x1186bd518>]

In [107]:
def f(x):
    return x**2 - 2*x


def df(x):
    return 2*x - 2


x = 0
eta = 0.1
for i in range(200):
    x = x - eta*df(x)
    
    print(x, df(x))


0.2 -1.6
0.36000000000000004 -1.2799999999999998
0.488 -1.024
0.5904 -0.8191999999999999
0.67232 -0.6553599999999999
0.7378560000000001 -0.5242879999999999
0.7902848 -0.4194304
0.83222784 -0.3355443199999999
0.865782272 -0.26843545599999996
0.8926258176 -0.21474836479999992
0.9141006540800001 -0.17179869183999985
0.931280523264 -0.13743895347199997
0.9450244186112 -0.10995116277759998
0.95601953488896 -0.08796093022208007
0.9648156279111679 -0.07036874417766414
0.9718525023289344 -0.05629499534213123
0.9774820018631475 -0.045035996273705026
0.981985601490518 -0.03602879701896411
0.9855884811924144 -0.028823037615171243
0.9884707849539315 -0.02305843009213704
0.9907766279631451 -0.01844674407370972
0.9926213023705162 -0.014757395258967687
0.9940970418964129 -0.011805916207174194
0.9952776335171303 -0.009444732965739444
0.9962221068137043 -0.0075557863725914665
0.9969776854509634 -0.006044629098073129
0.9975821483607707 -0.004835703278458503
0.9980657186886166 -0.0038685626227668024
0.9984525749508932 -0.0030948500982135307
0.9987620599607145 -0.0024758800785709134
0.9990096479685716 -0.001980704062856775
0.9992077183748573 -0.0015845632502853313
0.9993661746998859 -0.001267650600228265
0.9994929397599087 -0.0010141204801825676
0.999594351807927 -0.0008112963841460097
0.9996754814463416 -0.0006490371073168966
0.9997403851570732 -0.0005192296858536061
0.9997923081256586 -0.00041538374868288486
0.9998338465005269 -0.0003323069989462635
0.9998670772004215 -0.0002658455991570996
0.9998936617603371 -0.0002126764793257685
0.9999149294082696 -0.00017014118346070362
0.9999319435266157 -0.00013611294676851848
0.9999455548212925 -0.0001088903574149036
0.9999564438570341 -8.711228593183407e-05
0.9999651550856272 -6.968982874555607e-05
0.9999721240685018 -5.5751862996444856e-05
0.9999776992548014 -4.4601490397200294e-05
0.9999821594038412 -3.568119231767142e-05
0.9999857275230729 -2.8544953854181543e-05
0.9999885820184583 -2.2835963083389643e-05
0.9999908656147667 -1.8268770466622897e-05
0.9999926924918133 -1.4615016373342726e-05
0.9999941539934507 -1.1692013098585363e-05
0.9999953231947606 -9.353610478823882e-06
0.9999962585558084 -7.482888383147923e-06
0.9999970068446468 -5.98631070647393e-06
0.9999976054757174 -4.7890485652679615e-06
0.9999980843805739 -3.83123885216996e-06
0.9999984675044591 -3.064991081824786e-06
0.9999987740035673 -2.45199286541542e-06
0.9999990192028538 -1.961594292332336e-06
0.9999992153622831 -1.5692754338214598e-06
0.9999993722898265 -1.255420347012759e-06
0.9999994978318612 -1.004336277699025e-06
0.9999995982654889 -8.034690222036289e-07
0.9999996786123911 -6.427752177184942e-07
0.9999997428899129 -5.142201742192043e-07
0.9999997943119303 -4.1137613937536344e-07
0.9999998354495443 -3.291009114114729e-07
0.9999998683596354 -2.6328072921799617e-07
0.9999998946877083 -2.1062458332998801e-07
0.9999999157501667 -1.684996666639904e-07
0.9999999326001333 -1.3479973337560125e-07
0.9999999460801067 -1.0783978665607208e-07
0.9999999568640854 -8.627182923603982e-08
0.9999999654912683 -6.901746330001401e-08
0.9999999723930146 -5.5213970728829054e-08
0.9999999779144118 -4.41711764942454e-08
0.9999999823315294 -3.53369411509874e-08
0.9999999858652235 -2.826955292078992e-08
0.9999999886921789 -2.2615642247814094e-08
0.9999999909537431 -1.8092513709433433e-08
0.9999999927629946 -1.4474010878728905e-08
0.9999999942103956 -1.1579208702983124e-08
0.9999999953683165 -9.263366962386499e-09
0.9999999962946532 -7.41069361431812e-09
0.9999999970357225 -5.928554980272338e-09
0.999999997628578 -4.7428438954000285e-09
0.9999999981028624 -3.794275116320023e-09
0.9999999984822899 -3.0354201374649392e-09
0.9999999987858319 -2.4283361987897933e-09
0.9999999990286655 -1.9426689146229137e-09
0.9999999992229325 -1.554135042880489e-09
0.9999999993783459 -1.2433081231222332e-09
0.9999999995026767 -9.946465873156285e-10
0.9999999996021414 -7.957172698525028e-10
0.999999999681713 -6.365739046998442e-10
0.9999999997453705 -5.092590793509544e-10
0.9999999997962964 -4.074072190718425e-10
0.9999999998370371 -3.2592573084855303e-10
0.9999999998696297 -2.6074054026992144e-10
0.9999999998957038 -2.0859247662485814e-10
0.999999999916563 -1.6687407011772848e-10
0.9999999999332504 -1.334992116852618e-10
0.9999999999466003 -1.0679945816605141e-10
0.9999999999572802 -8.54396553506831e-11
0.9999999999658241 -6.835176868946746e-11
0.9999999999726593 -5.4681370542652985e-11
0.9999999999781275 -4.374500761628042e-11
0.999999999982502 -3.4996006093024334e-11
0.9999999999860016 -2.7996716056577498e-11
0.9999999999888013 -2.2397417254182983e-11
0.999999999991041 -1.79178893944254e-11
0.9999999999928328 -1.4334311515540321e-11
0.9999999999942663 -1.1467493621353242e-11
0.999999999995413 -9.173994897082594e-12
0.9999999999963304 -7.33924032658706e-12
0.9999999999970643 -5.871303443427678e-12
0.9999999999976514 -4.697131572584112e-12
0.9999999999981212 -3.757660849146305e-12
0.999999999998497 -3.006039861475074e-12
0.9999999999987976 -2.404743071338089e-12
0.9999999999990381 -1.9237944570704713e-12
0.9999999999992305 -1.538991156735392e-12
0.9999999999993844 -1.2312373343092986e-12
0.9999999999995075 -9.849898674474389e-13
0.999999999999606 -7.880363028789361e-13
0.9999999999996848 -6.303846333821639e-13
0.9999999999997479 -5.042632977847461e-13
0.9999999999997983 -4.034550471487819e-13
0.9999999999998386 -3.228528555609955e-13
0.9999999999998709 -2.582378755278114e-13
0.9999999999998967 -2.0650148258027912e-13
0.9999999999999174 -1.652011860642233e-13
0.9999999999999339 -1.3211653993039363e-13
0.9999999999999472 -1.056932319443149e-13
0.9999999999999577 -8.459899447643693e-14
0.9999999999999661 -6.772360450213455e-14
0.9999999999999729 -5.417888360170764e-14
0.9999999999999784 -4.3298697960381105e-14
0.9999999999999827 -3.4638958368304884e-14
0.9999999999999861 -2.7755575615628914e-14
0.9999999999999889 -2.220446049250313e-14
0.9999999999999911 -1.7763568394002505e-14
0.9999999999999929 -1.4210854715202004e-14
0.9999999999999943 -1.1324274851176597e-14
0.9999999999999954 -9.103828801926284e-15
0.9999999999999963 -7.327471962526033e-15
0.9999999999999971 -5.773159728050814e-15
0.9999999999999977 -4.6629367034256575e-15
0.9999999999999981 -3.774758283725532e-15
0.9999999999999984 -3.1086244689504383e-15
0.9999999999999988 -2.4424906541753444e-15
0.999999999999999 -1.9984014443252818e-15
0.9999999999999992 -1.5543122344752192e-15
0.9999999999999993 -1.3322676295501878e-15
0.9999999999999994 -1.1102230246251565e-15
0.9999999999999996 -8.881784197001252e-16
0.9999999999999997 -6.661338147750939e-16
0.9999999999999998 -4.440892098500626e-16
(this line repeats unchanged for the remaining iterations: x has converged to 1, the minimizer of x**2 - 2*x, up to machine precision)

In [108]:
def f(x):
    return 3*x[0]**2 + 2*x[1]**2 - 2*x[0]*x[1] + x[0] - x[1]

def df(x):
    return np.array([6*x[0]  - 2*x[1] + 1 ,  4*x[1] - 2*x[0]  - 1])

x = np.array([0, 0])
eta = 0.1
for i in range(200):
    x = x - eta*df(x)
    
    print(x, df(x))


[-0.1  0.1] [ 0.2 -0.4]
[-0.12  0.14] [ 0.  -0.2]
[-0.12  0.16] [-0.04 -0.12]
[-0.116  0.172] [-0.04 -0.08]
[-0.112  0.18 ] [-0.032 -0.056]
[-0.1088  0.1856] [-0.024 -0.04 ]
[-0.1064  0.1896] [-0.0176 -0.0288]
[-0.10464  0.19248] [-0.0128 -0.0208]
[-0.10336  0.19456] [-0.00928 -0.01504]
[-0.102432  0.196064] [-0.00672 -0.01088]
[-0.10176   0.197152] [-0.004864 -0.007872]
[-0.1012736  0.1979392] [-0.00352  -0.005696]
[-0.1009216  0.1985088] [-0.0025472 -0.0041216]
[-0.10066688  0.19892096] [-0.0018432 -0.0029824]
[-0.10048256  0.1992192 ] [-0.00133376 -0.00215808]
[-0.10034918  0.19943501] [-0.00096512 -0.0015616 ]
[-0.10025267  0.19959117] [-0.00069837 -0.00112998]
[-0.10018284  0.19970417] [-0.00050534 -0.00081766]
[-0.1001323   0.19978593] [-0.00036567 -0.00059167]
[-0.10009573  0.1998451 ] [-0.0002646  -0.00042813]
[-0.10006927  0.19988791] [-0.00019147 -0.0003098 ]
[-0.10005013  0.19991889] [-0.00013855 -0.00022417]
[-0.10003627  0.19994131] [-0.00010025 -0.00016221]
[-0.10002625  0.19995753] [ -7.25442560e-05  -1.17379072e-04]
[-0.10001899  0.19996927] [ -5.24935168e-05  -8.49362944e-05]
[-0.10001374  0.19997776] [ -3.79846656e-05  -6.14604800e-05]
[-0.10000994  0.19998391] [ -2.74859622e-05  -4.44732211e-05]
[-0.1000072   0.19998836] [ -1.98890291e-05  -3.21811251e-05]
[-0.10000521  0.19999157] [ -1.43918367e-05  -2.32864809e-05]
[-0.10000377  0.1999939 ] [ -1.04140308e-05  -1.68502559e-05]
[-0.10000273  0.19999559] [ -7.53566351e-06  -1.21929597e-05]
[-0.10000197  0.19999681] [ -5.45285734e-06  -8.82290852e-06]
[-0.10000143  0.19999769] [ -3.94572464e-06  -6.38431658e-06]
[-0.10000103  0.19999833] [ -2.85515317e-06  -4.61973488e-06]
[-0.10000075  0.19999879] [ -2.06600824e-06  -3.34287156e-06]
[-0.10000054  0.19999912] [ -1.49497761e-06  -2.41892458e-06]
[-0.10000039  0.19999937] [ -1.08177596e-06  -1.75035027e-06]
[-0.10000028  0.19999954] [ -7.82780439e-07  -1.26656536e-06]
[-0.1000002   0.19999967] [ -5.66425247e-07  -9.16495301e-07]
[-0.10000015  0.19999976] [ -4.09869159e-07  -6.63182230e-07]
[-0.10000011  0.19999983] [ -2.96584109e-07  -4.79883170e-07]
[-0.10000008  0.19999987] [ -2.14610278e-07  -3.47246724e-07]
[-0.10000006  0.19999991] [ -1.55293456e-07  -2.51270090e-07]
[-0.10000004  0.19999993] [ -1.12371400e-07  -1.81820745e-07]
[-0.10000003  0.19999995] [ -8.13127092e-08  -1.31566727e-07]
[-0.10000002  0.19999997] [ -5.88384292e-08  -9.52025780e-08]
[-0.10000002  0.19999998] [ -4.25758873e-08  -6.88892328e-08]
[-0.10000001  0.19999998] [ -3.08082013e-08  -4.98487172e-08]
[-0.10000001  0.19999999] [ -2.22930239e-08  -3.60708705e-08]
[-0.10000001  0.19999999] [ -1.61313838e-08  -2.61011271e-08]
[-0.1         0.19999999] [ -1.16727790e-08  -1.88869530e-08]
[-0.1  0.2] [ -8.44650216e-09  -1.36667275e-08]
[-0.1  0.2] [ -6.11194650e-09  -9.88933702e-09]
[-0.1  0.2] [ -4.42264603e-09  -7.15599147e-09]
[-0.1  0.2] [ -3.20025673e-09  -5.17812415e-09]
[-0.1  0.2] [ -2.31572739e-09  -3.74692588e-09]
[-0.1  0.2] [ -1.67567604e-09  -2.71130096e-09]
[-0.1  0.2] [ -1.21253074e-09  -1.96191585e-09]
[-0.1  0.2] [ -8.77395490e-10  -1.41965562e-09]
[-0.1  0.2] [ -6.34889252e-10  -1.02727249e-09]
[-0.1  0.2] [ -4.59410288e-10  -7.43341388e-10]
[-0.1  0.2] [ -3.32432304e-10  -5.37886846e-10]
[-0.1  0.2] [ -2.40550468e-10  -3.89218546e-10]
[-0.1  0.2] [ -1.74063874e-10  -2.81641155e-10]
[-0.1  0.2] [ -1.25953692e-10  -2.03797423e-10]
[-0.1  0.2] [ -9.11408726e-11  -1.47469148e-10]
[-0.1  0.2] [ -6.59503563e-11  -1.06709641e-10]
[-0.1  0.2] [ -4.77222706e-11  -7.72159003e-11]
[-0.1  0.2] [ -3.45319329e-11  -5.58739721e-11]
[-0.1  0.2] [ -2.49875676e-11  -4.04307698e-11]
[-0.1  0.2] [ -1.80810922e-11  -2.92559310e-11]
[-0.1  0.2] [ -1.30837563e-11  -2.11697326e-11]
[-0.1  0.2] [ -9.46753786e-12  -1.53186352e-11]
[-0.1  0.2] [ -6.85074220e-12  -1.10846887e-11]
[-0.1  0.2] [ -4.95714580e-12  -8.02091726e-12]
[-0.1  0.2] [ -3.58690855e-12  -5.80402393e-12]
[-0.1  0.2] [ -2.59570143e-12  -4.19986268e-12]
[-0.1  0.2] [ -1.87827531e-12  -3.03901349e-12]
[-0.1  0.2] [ -1.35913503e-12  -2.19901874e-12]
[-0.1  0.2] [ -9.83435555e-13  -1.59128266e-12]
[-0.1  0.2] [ -7.11652959e-13  -1.15141230e-12]
[-0.1  0.2] [ -5.14921439e-13  -8.33222380e-13]
[-0.1  0.2] [ -3.72590847e-13  -6.02962125e-13]
[-0.1  0.2] [ -2.69562150e-13  -4.36317649e-13]
[-0.1  0.2] [ -1.95399252e-13  -3.15747428e-13]
[-0.1  0.2] [ -1.41220369e-13  -2.28483898e-13]
[-0.1  0.2] [ -1.02140518e-13  -1.65312208e-13]
[-0.1  0.2] [ -7.39408534e-14  -1.19571020e-13]
[-0.1  0.2] [ -5.32907052e-14  -8.64863736e-14]
[-0.1  0.2] [ -3.86357613e-14  -6.25055563e-14]
[-0.1  0.2] [ -2.81996648e-14  -4.52970994e-14]
[-0.1  0.2] [ -2.04281037e-14  -3.28626015e-14]
[-0.1  0.2] [ -1.46549439e-14  -2.37587727e-14]
[-0.1  0.2] [ -1.06581410e-14  -1.72084569e-14]
[-0.1  0.2] [ -7.54951657e-15  -1.24344979e-14]
[-0.1  0.2] [ -5.55111512e-15  -8.88178420e-15]
[-0.1  0.2] [ -3.99680289e-15  -6.43929354e-15]
[-0.1  0.2] [ -2.88657986e-15  -4.77395901e-15]
[-0.1  0.2] [ -1.99840144e-15  -3.44169138e-15]
[-0.1  0.2] [ -1.55431223e-15  -2.44249065e-15]
[-0.1  0.2] [ -1.11022302e-15  -1.77635684e-15]
[-0.1  0.2] [ -8.88178420e-16  -1.33226763e-15]
[-0.1  0.2] [ -6.66133815e-16  -9.99200722e-16]
[-0.1  0.2] [ -4.44089210e-16  -6.66133815e-16]
[-0.1  0.2] [ -2.22044605e-16  -5.55111512e-16]
[-0.1  0.2] [ -2.22044605e-16  -3.33066907e-16]
[-0.1  0.2] [  0.00000000e+00  -3.33066907e-16]
[-0.1  0.2] [ -2.22044605e-16  -2.22044605e-16]
[-0.1  0.2] [  0.00000000e+00  -1.11022302e-16]
(this line repeats unchanged for the remaining iterations: the gradient is zero to machine precision at (-0.1, 0.2))
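Setting the gradient to zero gives the linear system 6 x0 - 2 x1 = -1 and -2 x0 + 4 x1 = 1, whose solution is exactly (-0.1, 0.2), the point the iteration settles at. A quick check (not part of the original notebook):

A = np.array([[6., -2.], [-2., 4.]])
b = np.array([-1., 1.])
print(np.linalg.solve(A, b))             # -> [-0.1  0.2]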
$$ E_i(w) \equiv -l_i(w) = - y_i x_i^\top w + \text{logsumexp}(0, x_i^\top w) = - y_i z_i + \text{logsumexp}(0, z_i) $$
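A numerically stable way to evaluate this per-example loss is np.logaddexp, which computes logsumexp(0, z) without overflowing for large |z|; a minimal sketch (the function name is illustrative, not from the notebook):

def nll_logistic(z, y):
    # -y*z + logsumexp(0, z)
    return -y*z + np.logaddexp(0.0, z)

The cell below evaluates the same expression directly with np.log(1 + np.exp(...)), which can overflow for large scores.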

In [12]:
x = np.array([2,-1,5,-3,1])
y =  np.array([1, 0, 1, 0, 1])

w = 0.1

E = np.sum(-y*x*w + np.log(1+np.exp(x*w)))


---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-12-c491118d9ae9> in <module>()
      4 w = 0.1
      5 
----> 6 E = np.sum(-y*x*w + np.log(1+np.exp(x*w)))
      7 
      8 

TypeError: bad operand type for unary -: 'list'

In [15]:
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import scipy as sc
import pandas as pd

df_iris = pd.read_csv(u'/Users/cemgil/src/ipynb/notes/data/iris.txt',sep=' ')

In [24]:
y = np.array(df_iris['c'])
y[y>1] = 0              # binary target: class 1 vs. the rest (classes 2 and 3 mapped to 0)

X = np.array(df_iris[['sl','sw','pl','pw']])

In [27]:
w = np.ones(4)

E = 0
for i in range(len(y)):
    f = X[i,:].dot(w)
    E += -y[i]*f + np.log(1+np.exp(f))

In [39]:
%matplotlib inline
import matplotlib.pylab as plt
import numpy as np

import torch
import torch.autograd
from torch.autograd import Variable

y = torch.Tensor(np.array(df_iris['c'])).double()
y[y>1] = 0

X = Variable(torch.Tensor(np.array(df_iris[['sl','sw','pl','pw']])).double(), requires_grad=False)
w = Variable(20*torch.randn(4).double(), requires_grad=True)
# learning rate
eta = 0.0005

for epoch in range(1000):
    for i in range(X.shape[0]):
        f = torch.matmul(X[i,:], w)
        E = -y[i]*f + torch.log(1+torch.exp(f))
        
        # Compute the gradients by automated differentiation
        E.backward()
    
    #print(E.data)
    
    # For each adjustable parameter 
    # Move along the negative gradient direction
    w.data.add_(-eta * w.grad.data)
    
    
    # Reset the gradients, as otherwise they are accumulated in param.grad
    w.grad.zero_()
    
ws = w.data
print(ws)


 -5.8943
-22.2856
 20.1330
-28.4639
[torch.DoubleTensor of size 4]


In [ ]:
x = [4.9, 3.7, 1.4, 0.3 ]
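A small sketch (not in the original notebook) of how this measurement could be scored with the fitted weights, assuming ws is the length-4 DoubleTensor printed by the training loop above; z and p are illustrative names:

z = np.dot(np.array(x), ws.numpy())      # linear score x·w
p = 1.0/(1.0 + np.exp(-z))               # logistic function -> probability of class 1
print(p)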

In [42]:
X.shape[0]


Out[42]:
150

In [138]:
%matplotlib inline
import matplotlib.pylab as plt
import numpy as np

import torch
import torch.autograd
from torch.autograd import Variable


x = Variable(20*torch.randn(1).double(), requires_grad=True)
# learning rate
eta = 0.0005

for epoch in range(1000):
    ## Compute the forward pass
    #f = torch.matmul(A, w)
    #f = 3*x[0]**2 + 2*x[1]**2 - 2*x[0]*x[1] + x[0] - x[1]
    #f = x**2 - 2*x
    f = torch.sin(x)*x**2
    
    # Compute the gradients by automated differentiation
    f.backward()
    
    # For each adjustable parameter 
    # Move along the negative gradient direction
    x.data.add_(-eta * x.grad.data)
    
    
    # Reset the gradients, as otherwise they are accumulated in param.grad
    x.grad.zero_()
    
print(epoch,':',f.data[0])
xs = x.data
fs = f.data[0]
print(x.data)


999 : -63.634981951554465

-8.0962
[torch.DoubleTensor of size 1]


In [139]:
x = np.linspace(-30,30,100)
plt.plot(x, np.sin(x)*x**2)
plt.plot(xs[0], fs, 'ro')
plt.show()
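The point found above (x ≈ -8.1 with f ≈ -63.6) is only a local minimum of sin(x)·x²: as the plot shows, the function has much deeper minima towards the edges of the plotted range, and plain gradient descent simply ends up in whichever basin the random initialization happens to fall into.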



In [140]:
%matplotlib inline
import matplotlib.pylab as plt
import numpy as np

import torch
import torch.autograd
from torch.autograd import Variable

y = [1,2,3,-1]
x = [-2,-1,0, 1]

w = Variable(20*torch.randn(4).double(), requires_grad=True)
# learning rate
eta = 0.0005

for epoch in range(1000):
    ## Compute the forward pass
    #f = torch.matmul(A, w)
    #f = 3*x[0]**2 + 2*x[1]**2 - 2*x[0]*x[1] + x[0] - x[1]
    #f = x**2 - 2*x
    #f = torch.sin(x)*x**2
    for i in range(len(x)):
        f = w[0]*torch.sin(w[1]*x[i]) + w[2]*torch.cos(w[3]*x[i])
        E = (y[i]-f)**2
        
        # Compute the gradients by automated differentiation
        E.backward()
    
    # For each adjustable parameter 
    # Move along the negative gradient direction
    w.data.add_(-eta * w.grad.data)
    
    
    # Reset the gradients, as otherwise they are accumulated in param.grad
    w.grad.zero_()
    
ws = w.data
print(ws)


-13.6220
-24.9247
  6.1738
-17.4263
[torch.DoubleTensor of size 4]

Stochastic Gradient Descent
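In stochastic gradient descent the parameters are updated after every single example rather than once per pass over the data: in the cell below the update w.data.add_(-eta * w.grad.data) and the gradient reset are moved inside the loop over the data points, so each (x[i], y[i]) pair triggers its own small step.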


In [ ]:
%matplotlib inline
import matplotlib.pylab as plt
import numpy as np

import torch
import torch.autograd
from torch.autograd import Variable

y = [1,2,3,-1]
x = [-2,-1,0, 1]

w = Variable(20*torch.randn(3).double(), requires_grad=True)
# learning rate
eta = 0.00005

for epoch in range(1000):
    ## Compute the forward pass
    #f = torch.matmul(A, w)
    #f = 3*x[0]**2 + 2*x[1]**2 - 2*x[0]*x[1] + x[0] - x[1]
    #f = x**2 - 2*x
    #f = torch.sin(x)*x**2
    for i in range(len(x)):
        f = w[0] + w[1]*x[i] + w[2]*x[i]**2 
        E = (y[i]-f)**2
        
        # Compute the gradients by automated differentiation
        E.backward()
    
        # For each adjustable parameter 
        # Move along the negative gradient direction
        w.data.add_(-eta * w.grad.data)
    
        # Reset the gradients, as otherwise they are accumulated in param.grad
        w.grad.zero_()
    
ws = w.data
print(ws)

In [8]:
%matplotlib inline
import matplotlib.pylab as plt
import numpy as np

import torch
import torch.autograd
from torch.autograd import Variable

y = [32,29,49,13, 51, 28, 35]

w = Variable(200*torch.randn(1).double(), requires_grad=True)
# learning rate
eta = 0.005

for epoch in range(1000):
    for i in range(len(y)):
        f = w[0] 
        E = (y[i]-f)**2
        
        # Compute the gradients by automated differentiation
        E.backward()
    
        # For each adjustable parameter 
        # Move along the negative gradient direction
        w.data.add_(-eta * w.grad.data)
    
        # Reset the gradients, as otherwise they are accumulated in param.grad
        w.grad.zero_()
    
    print(float(w.data))
    
ws = w.data
print(ws)


-324.304261518025
-299.9718216623381
-277.29239764282136
-256.15369240374037
-236.45103775077305
-218.0868760869638
-200.9702773567652
-185.0164888063183
-170.14651533061152
-156.2867283296091
-143.36850113660302
-131.32786921361225
-120.10521343128988
-109.64496486509917
-99.89532964605904
-90.80803250365815
-82.33807773109206
-74.44352638924377
-67.08528864623436
-60.22693022431394
-53.834491995715865
-47.87632183420366
-42.32291788972473
-37.146782510145925
-32.322286086765125
-27.825540149430683
-23.63427908289971
-19.727749878754498
-16.08660937698407
-12.692828488423185
-9.52960292380659
-6.581269987413938
-3.8332310233094837
-1.2718791301695012
1.1154682132222442
3.3406319454155
5.414629953612223
7.3477316286805054
9.149508713992608
10.828882699864879
12.394168998272864
13.853118116572546
15.212954034099136
16.480409971664976
17.66176173106906
18.762858769698564
19.789153164088134
20.74572660584984
21.63731556364405
22.468334735780548
23.24289890957532
23.96484333569964
24.637742718404716
25.264928915652124
25.849507436791985
26.394372819477162
26.902222961952102
27.375572481682617
27.81676516647185
28.227985579713945
28.611269877248894
28.968515889378097
29.30149251796168
29.61184849512734
29.901120546959266
30.17074100258975
30.42204488636983
30.656276528235683
30.874595725002205
31.078083483091167
31.26774737112917
31.4445265089187
31.6092962174852
31.76287235322476
31.906015347613074
32.039433972478065
32.16378884948012
32.27969572117698
32.387728499869915
32.48842210932771
32.582275133459
32.66975228504806
32.75128670677782
32.82728211593376
32.898114803408035
32.964135496901925
33.025671097552376
33.083026298581295
33.13648509398259
33.186312184716996
33.232754289377546
33.276041365815516
33.31638774977555
33.35399321617805
33.389043968303696
33.42171355977809
33.45216375392165
33.480545324719905
33.506998803380114
33.53165517417089
33.55463652299013
33.576056641872704
33.59602159243121
33.61463023101947
33.6319746982193
33.648140875074176
33.66320880832875
33.67725310677991
33.6903433107019
33.70254423617462
33.71391629602014
33.72451579893647
33.73439522830982
33.74360350208581
33.75218621498646
33.76018586427219
33.76764206016682
33.77459172198741
33.78106926095007
33.78710675055688
33.79273408540774
33.79797912922328
33.802867852812
33.80742446266455
33.811671520812034
33.815630056541856
33.819319670524074
33.82275863186405
33.82596396856183
33.82895155182621
33.831736174660925
33.834331625112156
33.83675075453996
33.83900554125172
33.84110714981267
33.84306598632719
33.84489174996459
33.846593480984474
33.84817960549957
33.84965797719755
33.85103591622858
33.85232024545093
33.85351732421438
33.85463307984851
33.85567303701181
33.856642345047035
33.857545803478104
33.858387885774995
33.859172761504006
33.85990431697343
33.860586174476545
33.86122171022741
33.86181407107815
33.86236619010059
33.86288080110932
33.863360452198215
33.863807518357255
33.86422421323232
33.86461260008601
33.86497460201392
33.865312011466806
33.865626499125895
33.86591962217529
33.866192832012295
33.86644748143407
33.866684831335974
33.86690605695486
33.8671122536883
33.867304442518396
33.867483575067176
33.867650538308574
33.86780615896027
33.86795120757712
33.86808640236666
33.8682124127452
33.8683298626525
33.86843933364121
33.8685413677564
33.86863647021948
33.86872511192981
33.86880773179639
33.86888473891108
33.86895651457423
33.869023414182664
33.86908576898948
33.86914388774418
33.8691980582215
33.86924854864628
33.869295609021634
33.86933947236676
33.869380355870796
33.86941846196821
33.86945397934115
33.86948708385373
33.86951793942276
33.869546698829424
33.86957350447581
33.86959848908993
33.86962177638299
33.86964348166189
33.869663712400225
33.8696825687704
33.86970014413962
33.86971652553224
33.869731794060655
33.86974602532691
33.869759289797045
33.86977165315
33.869783176602894
33.86979391721402
33.86980392816545
33.869813259026394
33.86982195599853
33.8698300621449
33.86983761760303
33.869844659783745
33.86985122355636
33.869857341421366
33.86986304367133
33.869868358540934
33.86987331234672
33.86987792961745
33.86988223321549
33.86988624445009
33.86988998318287
33.86989346792615
33.86989671593459
33.869899743290716
33.86990256498445
33.86990519498741
33.869907646322034
33.86990993112609
33.869912060712764
33.86991404562672
33.869915895696245
33.869917620081935
33.86991922732208
33.86992072537493
33.86992212165808
33.86992342308522
33.86992463610036
33.869925766709734
33.869926820511566
33.869927802723716
33.869928718209636
33.86992957150234
33.86993036682689
33.86993110812135
33.869931799056225
33.86993244305268
33.86993304329947
33.8699336027687
33.869934124230575
33.86993461026712
33.86993506328495
33.869935485527165
33.8699358790845
33.86993624590565
33.86993658780694
33.869936906481286
33.8699372035066
33.869937480353606
33.869937738393105
33.869937978902776
33.86993820307351
33.869938412015294
33.86993860676268
33.86993878827997
33.86993895746595
33.86993911515834
33.86993926213794
33.869939399132534
33.869939526820446
33.86993964583394
33.869939756762285
33.869939860154744
33.86993995652328
33.86994004634504
33.869940130064805
33.8699402080971
33.86994028082829
33.86994034861851
33.86994041180343
33.86994047069591
33.869940525587545
33.86994057675014
33.86994062443701
33.86994066888429
33.86994071031206
33.86994074892545
33.869940784915656
33.869940818460876
33.86994084972722
33.8699408788695
33.86994090603201
33.86994093134924
33.86994095494655
33.86994097694078
33.86994099744084
33.86994101654825
33.8699410343576
33.86994105095706
33.86994106642886
33.86994108084958
33.86994109429064
33.86994110681857
33.86994111849544
33.869941129379036
33.86994113952326
33.86994114897835
33.8699411577911
33.869941166005155
33.86994117366119
33.86994118079713
33.86994118744828
33.86994119364759
33.86994119942576
33.86994120481139
33.86994120983113
33.869941214509865
33.869941218870764
33.869941222935395
33.8699412267239
33.86994123025504
33.86994123354628
33.869941236613926
33.869941239473185
33.869941242138204
33.869941244622176
33.8699412469374
33.86994124909534
33.86994125110668
33.86994125298138
33.869941254728715
33.86994125635735
33.86994125787535
33.869941259290215
33.86994126060897
33.869941261838136
33.8699412629838
33.869941264051626
33.869941265046904
33.869941265974575
33.86994126683923
33.869941267645146
33.86994126839631
33.86994126909644
33.86994126974901
33.86994127035725
33.869941270924166
33.869941271452575
33.869941271945095
33.86994127240414
33.869941272832
33.8699412732308
33.86994127360251
33.86994127394895
33.86994127427188
33.86994127457286
33.869941274853396
33.869941275114876
33.869941275358585
33.86994127558574
33.86994127579745
33.86994127599479
33.869941276178714
33.86994127635017
33.86994127650996
33.8699412766589
33.86994127679772
33.8699412769271
33.8699412770477
33.86994127716011
33.86994127726488
33.86994127736252
33.869941277453535
33.86994127753837
33.869941277617436
33.86994127769114
33.869941277759835
33.86994127782387
33.86994127788354
33.86994127793916
33.86994127799101
33.869941278039335
33.86994127808437
33.869941278126355
33.869941278165484
33.869941278201956
33.86994127823595
33.869941278267625
33.869941278297155
33.86994127832468
33.869941278350325
33.869941278374235
33.86994127839653
33.8699412784173
33.86994127843666
33.86994127845471
33.869941278471536
33.869941278487225
33.86994127850184
33.869941278515455
33.869941278528145
33.869941278539976
33.869941278551
33.86994127856127
33.86994127857085
33.86994127857979
33.869941278588115
33.86994127859587
33.8699412786031
33.86994127860983
33.86994127861611
33.86994127862196
33.869941278627415
33.8699412786325
33.86994127863724
33.869941278641654
33.86994127864577
33.869941278649605
33.86994127865319
33.86994127865652
33.86994127865963
33.86994127866253
33.86994127866524
33.86994127866775
33.86994127867011
33.86994127867231
33.86994127867435
33.86994127867625
33.86994127867803
33.86994127867968
33.86994127868122
33.86994127868265
33.86994127868399
33.86994127868524
33.86994127868639
33.86994127868747
33.86994127868848
33.86994127868942
33.86994127869029
33.86994127869111
33.869941278691876
33.86994127869259
33.86994127869326
33.86994127869388
33.86994127869445
33.86994127869498
33.869941278695485
33.86994127869595
33.86994127869637
33.869941278696786
33.869941278697155
33.86994127869752
33.86994127869784
33.86994127869815
33.86994127869842
33.86994127869868
33.86994127869894
33.86994127869916
33.86994127869937
33.869941278699585
33.86994127869976
33.869941278699926
33.86994127870009
33.86994127870025
33.869941278700395
33.869941278700516
33.86994127870063
33.86994127870074
33.86994127870086
33.86994127870097
33.86994127870106
33.86994127870115
33.86994127870121
33.869941278701276
33.86994127870134
33.869941278701404
33.86994127870147
33.86994127870153
33.869941278701596
33.86994127870166
33.86994127870171
33.86994127870175
33.86994127870179
33.86994127870182
33.86994127870186
33.86994127870188
33.869941278701894
33.86994127870191
33.86994127870192
33.86994127870194
33.86994127870195
33.869941278701965
33.86994127870198
33.869941278701994
33.86994127870201
33.86994127870202
33.869941278702036
33.86994127870205
33.869941278702065
33.86994127870208
33.86994127870209
33.86994127870211
33.86994127870212
33.869941278702136
33.86994127870215
33.869941278702164
33.86994127870218
33.86994127870219
33.86994127870221
33.86994127870222
33.869941278702235
33.86994127870225
33.869941278702264
33.86994127870228
33.86994127870229
33.869941278702306
33.86994127870232
33.869941278702335
33.86994127870235
33.869941278702356
33.86994127870236
33.86994127870237
33.86994127870238
(this value repeats unchanged for the remaining epochs: w has converged to approximately 33.8699)

 33.8699
[torch.DoubleTensor of size 1]


In [11]:
x = np.linspace(-30,30,100)

plt.plot(x, np.log(1/(np.exp(-x)+1)))
plt.show()



In [7]:
np.mean(y)


Out[7]:
33.857142857142854

In [43]:
f_est = f(features, w)

c = np.round(f_est)                 # round the real-valued prediction to the nearest class label

conf_mat = np.zeros((3,3))
for i in range(len(c)):
    conf_mat[target[i]-1, int(c[i])-1] += 1   # rows: true class, columns: predicted class

In [44]:
conf_mat


Out[44]:
array([[ 50.,   0.,   0.],
       [  0.,  48.,   2.],
       [  0.,   4.,  46.]])

In [46]:
acc = np.sum(np.diag(conf_mat))/np.sum(conf_mat)
print(acc)


0.96

MNIST example
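The cells below fetch the 70000 MNIST digit images (28×28 pixels each) with scikit-learn, keep every 50th example, and fit the same linear model f(x, w) = x·w by gradient descent, treating the digit label directly as a regression target.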


In [152]:
from matplotlib.pylab import plt
from sklearn.datasets import fetch_mldata
mnist = fetch_mldata('MNIST original')

In [153]:
mnist


Out[153]:
{'COL_NAMES': ['label', 'data'],
 'DESCR': 'mldata.org dataset: mnist-original',
 'data': array([[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ..., 
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
 'target': array([ 0.,  0.,  0., ...,  9.,  9.,  9.])}

In [161]:
idx = 50000

x = mnist['data'][idx]

#x[x>0] = 1
plt.imshow(x.reshape(28,28), cmap='gray_r')
plt.show()



In [41]:
features = mnist['data'][0:-1:50].copy()
target = mnist['target'][0:-1:50].copy()
#target[target>0] = 1

In [39]:
N = features.shape[1]   # number of pixels per image (784)
M = features.shape[0]   # number of images in the subsample

In [97]:
w = np.ones(N)
eta = 0.0001
for epoch in range(500):
    f_est = f(features, w)
    e = (target - f_est)/(M*N)      # residuals, scaled down by the problem size
    dE = -features.T.dot(e)         # (scaled) gradient of the squared error

    w = w - eta*dE
    if epoch%10==0:
        print(E(target, f_est))
        #print(w)


5.3216370447e+13
461149906829.0
297311795147.0
217905092168.0
169736867488.0
138634722304.0
117306675851.0
101881946109.0
90225645816.1
81105698935.5
73774709890.5
67755760545.3
62729713060.4
58473914811.4
54827528510.9
51671054434.3
48913683731.6
46485146542.0
44330250760.4
42405100378.2
40674404087.8
39109516494.0
37686986229.4
36387463393.4
35194866797.7
34095742158.4
33078762562.0
32134336194.5
31254295779.2
30431650843.1
29660388703.8
28935313541.4
28251915460.0
27606263331.1
26994916618.3
26414852448.7
25863404999.6
25338214890.3
24837186744.4
24358453455.2
23900345978.7
23461367703.5
23040172625.7
22635546701.1
22246391859.4
21871712256.7
21510602417.6
21162236975.1
20825861769.9
20500786105.3

In [98]:
f_est = f(features, w)

c = np.round(f_est)

conf_mat = np.zeros((2,2))
for i in range(len(c)):
    ii = min(int(c[i]), 1)
    ii = max(0, ii)
    conf_mat[int(target[i])-1, ii] += 1

In [99]:
plt.imshow(conf_mat)
plt.show()



In [101]:
conf_mat


Out[101]:
array([[ 32606.,  30491.],
       [  4582.,   2321.]])

In [100]:
plt.plot(features.dot(w), target, 'o')


Out[100]:
[<matplotlib.lines.Line2D at 0x1a1dc84a20>]

In [33]:
%matplotlib inline
import matplotlib.pylab as plt
import numpy as np

import torch
import torch.autograd
from torch.autograd import Variable



def sigmoid(x):
    return 1./(1+np.exp(-x))


sizes = [1,40,1]

x = 3

W1 = np.random.randn(sizes[1], sizes[0])   # hidden-layer weights (40 x 1)
b1 = np.random.randn(sizes[1],1)           # hidden-layer biases
W2 = np.random.randn(sizes[2], sizes[1])   # output-layer weights (1 x 40)
b2 = np.random.randn(sizes[2],1)           # output-layer bias


def nnet(x, W1, b1, W2, b2):
    # one hidden layer of 40 sigmoid units followed by a sigmoid output
    y1 = W1.dot(x) + b1
    f1 = sigmoid(y1)
    y2 = W2.dot(f1) + b2
    f2 = sigmoid(y2)
    return f2

nnet(-3.1, W1, b1, W2, b2)

X = np.linspace(-50,50,100)
F = np.zeros_like(X)

for i in range(len(X)):
    F[i] = nnet(X[i], W1, b1, W2, b2)

plt.plot(X, F)
plt.show()



In [43]:
plt.plot(target)


Out[43]:
[<matplotlib.lines.Line2D at 0x1a1701fe48>]

In [135]:
sz = (3,5)
th = np.random.randn(*sz)
c = np.random.choice(range(sz[1]),size=sz[0])

In [136]:
inp = Variable(torch.FloatTensor(th), requires_grad=True)
target = Variable(torch.LongTensor(c), requires_grad=False)

In [151]:
#cross_entropy = torch.nn.functional.cross_entropy
CE_Loss = torch.nn.CrossEntropyLoss(reduce=False)
E = CE_Loss(inp, target)
print(E)
#E.backward()


Variable containing:
 2.8778
 2.0622
 0.7726
[torch.FloatTensor of size 3]


In [148]:
from functools import reduce 

for i,j in enumerate(c):
    res = -th[i,j] + reduce(np.logaddexp, th[i,:])
    print(res)


2.87783090587
2.06221449644
0.77263602782

$$ \text{loss}(x, class) = -\log\left(\frac{\exp(x[class])}{\sum_j \exp(x[j])}\right) = -x[class] + \log\left(\sum_j \exp(x[j])\right) $$
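Equivalently, each row contributes -x[class] plus a logsumexp over the row, which is what the loop with np.logaddexp above computes. A sketch using SciPy's logsumexp (not part of the original notebook) that should reproduce the CrossEntropyLoss values:

from scipy.special import logsumexp

for i, j in enumerate(c):
    print(-th[i, j] + logsumexp(th[i, :]))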

In [129]:
torch.nn.BCELoss


Out[129]:
array([[ 0.14261972,  0.52072644,  0.25314187,  0.82356005,  0.79203528],
       [ 0.11038007,  0.27908824,  0.76113961,  0.29513401,  0.7812479 ],
       [ 0.71271026,  0.90600794,  0.62578133,  0.06017429,  0.39552562]])

In [113]:
p = [0.2,0.3,0.5]
torch.multinomial(torch.Tensor(p), 100, replacement=True)


Out[113]:
 1
 0
 0
 2
 1
 2
 2
 0
 2
 1
 0
 2
 2
 1
 1
 2
 1
 2
 2
 2
 2
 1
 0
 2
 2
 1
 2
 2
 2
 0
 1
 0
 1
 2
 2
 2
 2
 2
 2
 2
 0
 1
 2
 1
 2
 0
 0
 1
 1
 1
 2
 2
 2
 1
 2
 2
 0
 0
 1
 2
 1
 1
 2
 0
 0
 1
 2
 2
 2
 0
 0
 0
 2
 2
 1
 0
 2
 2
 2
 0
 2
 2
 1
 2
 1
 1
 0
 0
 2
 2
 1
 2
 2
 2
 2
 0
 2
 0
 0
 2
[torch.LongTensor of size 100]