numpy版のNN


In [6]:
import numpy as np

# N: バッチサイズ
N, D_in, H, D_out = 64, 1000, 100, 10

x = np.random.randn(N, D_in)   # (64, 1000)
y = np.random.randn(N, D_out)  # (64, 10)

w1 = np.random.randn(D_in, H)
w2 = np.random.randn(H, D_out)

learning_rate = 1e-6
for t in range(500):
    # forward pass
    h = x.dot(w1)              # (64, 1000) * (1000, 100) = (64, 100)
    h_relu = np.maximum(h, 0)
    y_pred = h_relu.dot(w2)    # (64, 100) * (100, 10) = (64, 10)
    
    # compute loss
    loss = np.square(y_pred - y).sum()
    print(t, loss)
    
    # backward pass
    grad_y_pred = 2.0 * (y_pred - y)
    grad_w2 = h_relu.T.dot(grad_y_pred)
    grad_h_relu = grad_y_pred.dot(w2.T)
    grad_h = grad_h_relu.copy()
    grad_h[h < 0] = 0
    grad_w1 = x.T.dot(grad_h)
    
    # update weights
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2


0 32757066.0214
1 31592742.8799
2 35457112.0985
3 37779330.1688
4 33351651.0168
5 22702054.3509
6 12124512.7044
7 5724138.93169
8 2827098.35565
9 1637826.53972
10 1123090.87116
11 862480.624867
12 703663.47134
13 591248.819214
14 504331.184857
15 434005.94821
16 375807.802794
17 326875.918559
18 285497.953608
19 250246.140994
20 220067.870583
21 194112.992556
22 171676.117389
23 152240.021369
24 135363.982324
25 120628.157731
26 107723.493533
27 96385.6866235
28 86400.3615529
29 77572.777417
30 69761.9243962
31 62826.9461829
32 56657.9767638
33 51161.7501471
34 46252.8366224
35 41863.9361727
36 37934.3624612
37 34410.2253135
38 31247.6314374
39 28402.9882934
40 25840.6186223
41 23530.4546728
42 21448.2540177
43 19566.1061876
44 17862.5104198
45 16319.7807051
46 14920.7084017
47 13650.9504462
48 12497.9545937
49 11449.5371161
50 10495.551557
51 9627.02646568
52 8835.49424324
53 8113.57177192
54 7454.72820115
55 6853.08242569
56 6303.23374228
57 5800.51633071
58 5340.52125251
59 4919.28634451
60 4533.46524697
61 4180.04897878
62 3856.40393074
63 3559.29192821
64 3286.61403872
65 3036.12475547
66 2805.85212003
67 2594.19117864
68 2399.42848341
69 2220.09263494
70 2054.98097478
71 1902.85674902
72 1762.73038135
73 1633.76190838
74 1514.81020546
75 1405.28310378
76 1304.13741537
77 1210.70833305
78 1124.34907913
79 1044.53217809
80 970.71980313
81 902.403433033
82 839.156391328
83 780.591577073
84 726.341766052
85 676.071523278
86 629.473390998
87 586.253013496
88 546.172822973
89 508.967447616
90 474.44375357
91 442.379628967
92 412.59130529
93 384.911741672
94 359.182661554
95 335.26276712
96 313.023944654
97 292.330643879
98 273.070751594
99 255.143627725
100 238.454406518
101 222.914495139
102 208.431974684
103 194.936843581
104 182.359155624
105 170.632195244
106 159.694608943
107 149.493635102
108 139.976021865
109 131.092005496
110 122.800829702
111 115.058382563
112 107.827477045
113 101.072526438
114 94.7622884138
115 88.8631548174
116 83.3476423874
117 78.1897493425
118 73.366072365
119 68.8533651779
120 64.630708973
121 60.6789764047
122 56.9790289988
123 53.5150790739
124 50.2711524204
125 47.2326368395
126 44.3860773982
127 41.7182246061
128 39.2183552115
129 36.8739916962
130 34.6761499447
131 32.6151164361
132 30.6814059951
133 28.8673280758
134 27.1652282517
135 25.5677078035
136 24.0681183488
137 22.6600694004
138 21.3379008467
139 20.096101334
140 18.9296096335
141 17.8335001063
142 16.8035782428
143 15.8354244697
144 14.9252020701
145 14.069550103
146 13.2646928107
147 12.5077488782
148 11.7959108533
149 11.1260980871
150 10.4957184341
151 9.90243700063
152 9.34401678436
153 8.81825983658
154 8.32321622734
155 7.8570220154
156 7.41804721785
157 7.00440047144
158 6.6147859367
159 6.24758340037
160 5.90149734704
161 5.57522903336
162 5.26766192414
163 4.97773178831
164 4.70425197101
165 4.44635724754
166 4.20311312016
167 3.97371726586
168 3.75716760925
169 3.55290814016
170 3.36009008778
171 3.17808672777
172 3.00626156285
173 2.84402768851
174 2.69087872149
175 2.54624641707
176 2.40966295201
177 2.28078600593
178 2.15903661692
179 2.0439721567
180 1.9352431807
181 1.83249119778
182 1.73536960455
183 1.64357096865
184 1.55675741186
185 1.47467058125
186 1.39704487156
187 1.32362269418
188 1.25417036829
189 1.1884845462
190 1.12633083149
191 1.0675226936
192 1.01187413972
193 0.959215705248
194 0.909367264041
195 0.862193635364
196 0.817532702778
197 0.775251365839
198 0.735207153593
199 0.697281903274
200 0.661366954226
201 0.627357969593
202 0.595138787696
203 0.564620845592
204 0.535703149756
205 0.508309630917
206 0.482348075268
207 0.457742966231
208 0.434425138631
209 0.412328252406
210 0.391383555785
211 0.371533746777
212 0.352705003779
213 0.334857427311
214 0.317928612828
215 0.301875741477
216 0.286652887342
217 0.272216425914
218 0.258521328563
219 0.245529890969
220 0.233204875154
221 0.221516218365
222 0.210422916726
223 0.199898552393
224 0.189910133027
225 0.180436139442
226 0.171439152385
227 0.162900787644
228 0.15479523837
229 0.147102580439
230 0.139799449487
231 0.132864656271
232 0.126280438926
233 0.12003035021
234 0.114095983517
235 0.108458656769
236 0.103106412616
237 0.098022500284
238 0.0931944831013
239 0.0886077484904
240 0.084250403025
241 0.0801109558076
242 0.0761796645331
243 0.0724437739865
244 0.0688939314626
245 0.0655211856047
246 0.0623167246446
247 0.0592710824546
248 0.0563766576692
249 0.0536264970513
250 0.0510120575834
251 0.0485273122547
252 0.0461650501961
253 0.0439196814473
254 0.0417850822354
255 0.0397565257702
256 0.0378275067875
257 0.0359930782151
258 0.0342489639559
259 0.0325914317693
260 0.0310151452461
261 0.0295155986595
262 0.0280896469188
263 0.0267333131812
264 0.0254437020064
265 0.0242167791276
266 0.0230497565333
267 0.0219397218499
268 0.0208841297223
269 0.0198798193515
270 0.0189243480075
271 0.0180154142901
272 0.0171508547839
273 0.0163282346083
274 0.0155454921764
275 0.0148007009429
276 0.0140919260299
277 0.0134176038214
278 0.0127759220247
279 0.0121651680457
280 0.0115839585084
281 0.0110308436429
282 0.0105045438191
283 0.0100034822518
284 0.00952658046992
285 0.00907268595169
286 0.00864069269444
287 0.0082295726873
288 0.0078381533324
289 0.00746554409506
290 0.0071107393746
291 0.0067730613193
292 0.00645150501556
293 0.00614534629912
294 0.00585386430738
295 0.00557637335062
296 0.00531218318455
297 0.00506056473878
298 0.00482098871225
299 0.00459289072768
300 0.00437572986565
301 0.00416886260047
302 0.00397184470664
303 0.00378424059867
304 0.00360555513372
305 0.00343547777143
306 0.00327345070054
307 0.00311909609691
308 0.00297206861544
309 0.00283203602425
310 0.00269867781702
311 0.00257163169561
312 0.00245062565736
313 0.00233533661463
314 0.00222553655857
315 0.00212093084074
316 0.00202126741121
317 0.00192631906685
318 0.0018358704349
319 0.00174972191445
320 0.00166762459732
321 0.00158940301443
322 0.00151489079603
323 0.00144389569323
324 0.00137627614934
325 0.00131181634741
326 0.00125039309738
327 0.00119186453028
328 0.0011360973396
329 0.00108297331152
330 0.00103233734281
331 0.000984077582536
332 0.00093809144922
333 0.000894275556244
334 0.000852520606344
335 0.000812719563647
336 0.000774791181474
337 0.000738641595788
338 0.000704193985757
339 0.000671361627262
340 0.000640064951547
341 0.000610235792792
342 0.000581805133948
343 0.000554716237999
344 0.00052888994282
345 0.000504269128531
346 0.00048080235133
347 0.000458433542485
348 0.000437122106325
349 0.000416807225222
350 0.000397431677904
351 0.000378961272103
352 0.000361353688142
353 0.000344575191657
354 0.000328575245914
355 0.000313321309193
356 0.000298778474298
357 0.000284914090786
358 0.000271701510529
359 0.00025910296586
360 0.000247089984852
361 0.000235636707202
362 0.000224716966116
363 0.000214308766602
364 0.000204381437131
365 0.000194915939561
366 0.000185891177845
367 0.000177286127318
368 0.000169083910908
369 0.000161260755822
370 0.000153800742443
371 0.000146687919754
372 0.000139906509627
373 0.000133441335841
374 0.000127274112522
375 0.000121392913871
376 0.000115784833977
377 0.000110436730984
378 0.000105338286376
379 0.000100475462183
380 9.58379321057e-05
381 9.14152544304e-05
382 8.71972322042e-05
383 8.31756665319e-05
384 7.93402817792e-05
385 7.56817987874e-05
386 7.21925609813e-05
387 6.88648368801e-05
388 6.56918696946e-05
389 6.26666576031e-05
390 5.97802996675e-05
391 5.70267770162e-05
392 5.44005653797e-05
393 5.18962061308e-05
394 4.95079057844e-05
395 4.72293707135e-05
396 4.5056326461e-05
397 4.29837489949e-05
398 4.10066940911e-05
399 3.91214064342e-05
400 3.73224232155e-05
401 3.5606502901e-05
402 3.39697578985e-05
403 3.24085145658e-05
404 3.0919796442e-05
405 2.94993352199e-05
406 2.81443415183e-05
407 2.6851775103e-05
408 2.56188786686e-05
409 2.444294437e-05
410 2.33210587667e-05
411 2.22506451069e-05
412 2.1229532848e-05
413 2.02554994547e-05
414 1.93264718211e-05
415 1.84402849965e-05
416 1.75946003621e-05
417 1.67878139646e-05
418 1.60181609196e-05
419 1.5283989092e-05
420 1.45837435404e-05
421 1.39154829133e-05
422 1.32779588941e-05
423 1.26696878451e-05
424 1.20893613306e-05
425 1.15358017269e-05
426 1.10076734031e-05
427 1.05038629325e-05
428 1.00230849635e-05
429 9.56434807346e-06
430 9.12670243943e-06
431 8.7092867912e-06
432 8.31083825663e-06
433 7.93067796947e-06
434 7.56794388344e-06
435 7.22184562167e-06
436 6.89168551638e-06
437 6.57673343726e-06
438 6.2761068527e-06
439 5.98924342104e-06
440 5.71553120965e-06
441 5.45437909521e-06
442 5.20530226066e-06
443 4.96752581341e-06
444 4.74063856056e-06
445 4.52415022611e-06
446 4.31756013473e-06
447 4.12045700705e-06
448 3.93240470161e-06
449 3.75291453271e-06
450 3.5816315656e-06
451 3.41818407211e-06
452 3.26223767318e-06
453 3.1134718234e-06
454 2.97146218721e-06
455 2.83592692223e-06
456 2.70659268981e-06
457 2.58316950483e-06
458 2.46539517913e-06
459 2.35305924336e-06
460 2.24580076062e-06
461 2.14343998115e-06
462 2.04577286099e-06
463 1.95257621172e-06
464 1.86363585864e-06
465 1.77876228936e-06
466 1.69774615841e-06
467 1.62042526664e-06
468 1.54663100705e-06
469 1.4762075483e-06
470 1.40900624908e-06
471 1.3448767356e-06
472 1.28365755784e-06
473 1.22523631687e-06
474 1.16948168594e-06
475 1.11627306739e-06
476 1.06549394037e-06
477 1.01702932095e-06
478 9.7076538249e-07
479 9.26610876884e-07
480 8.844738298e-07
481 8.44256270696e-07
482 8.05879330839e-07
483 7.69247422512e-07
484 7.34277068695e-07
485 7.00902197569e-07
486 6.6905019182e-07
487 6.38648190592e-07
488 6.09636435866e-07
489 5.81940245714e-07
490 5.5550258556e-07
491 5.30268355473e-07
492 5.06182229429e-07
493 4.83192603693e-07
494 4.61258357701e-07
495 4.40316305015e-07
496 4.20329092403e-07
497 4.0125249689e-07
498 3.83037674867e-07
499 3.65651668822e-07

PyTorch版のNN

  • dot()はmm()
  • Tはt()

In [30]:
import torch

dtype = torch.FloatTensor
#dtype = torch.cuda.FloatTensor  # GPUを使う場合

N, D_in, H, D_out = 64, 1000, 100, 10

x = torch.randn(N, D_in).type(dtype)
y = torch.randn(N, D_out).type(dtype)

w1 = torch.randn(D_in, H).type(dtype)
w2 = torch.randn(H, D_out).type(dtype)

learning_rate = 1e-6
for t in range(500):
    # forward pass
    h = x.mm(w1)
    h_relu = h.clamp(min=0)
    y_pred = h_relu.mm(w2)
    
    # compute loss
    loss = (y_pred - y).pow(2).sum()
    print(t, loss)
    
    # backward pass
    grad_y_pred = 2.0 * (y_pred - y)
    grad_w2 = h_relu.t().mm(grad_y_pred)
    grad_h_relu = grad_y_pred.mm(w2.t())
    grad_h = grad_h_relu.clone()
    grad_h[h < 0] = 0
    grad_w1 = x.t().mm(grad_h)

    # update weights
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2


0 33519802.33498156
1 29872658.72228745
2 28232201.35337338
3 24703782.486548424
4 18743259.388961434
5 12301520.46972239
6 7339225.924230867
7 4289060.262792304
8 2623948.781733956
9 1738355.4788811116
10 1250364.4825149775
11 960000.0030733151
12 770809.2800763324
13 637018.3387828094
14 536106.0812650071
15 456549.16201951914
16 392037.97327633994
17 338774.421579702
18 294280.9317381763
19 256822.27168787987
20 225008.64031713532
21 197854.1979252888
22 174563.00484699442
23 154477.65188670252
24 137110.3964615895
25 122028.55731431302
26 108908.6724899185
27 97435.66606277274
28 87368.50717904617
29 78514.50379502052
30 70700.82946323616
31 63785.81109264992
32 57654.10851306887
33 52201.84758821846
34 47346.202351426065
35 43011.421975365025
36 39135.1979157662
37 35668.22037300737
38 32557.61342888174
39 29760.179726280614
40 27238.576571210047
41 24961.97025868032
42 22901.39659521862
43 21034.524166789844
44 19341.120538000076
45 17803.695989330685
46 16405.221539200844
47 15130.877037832786
48 13968.637909965037
49 12908.214642094323
50 11938.225804394555
51 11050.179460635196
52 10236.167405875029
53 9489.699038609278
54 8804.027246562982
55 8173.568189415775
56 7593.495377849729
57 7059.015239429675
58 6566.352209980032
59 6111.722953456567
60 5691.828351729249
61 5303.516790770943
62 4944.552688133783
63 4612.42488731187
64 4305.014616950379
65 4020.517186165558
66 3756.5913072718176
67 3511.6253212844895
68 3284.001973283487
69 3072.491143470772
70 2875.764008138751
71 2692.710372787027
72 2522.3483963156095
73 2363.512948321696
74 2215.5062951483374
75 2077.468431940022
76 1948.6996744666044
77 1828.5315082894908
78 1716.3475498296975
79 1611.5619967153334
80 1513.5804146571572
81 1421.9668709724501
82 1336.2881122715717
83 1256.1029595788636
84 1181.1155914484286
85 1110.8827252898031
86 1045.1109380788444
87 983.4857359643902
88 925.7063886522792
89 871.5368997392743
90 820.7067957720695
91 773.0138523681353
92 728.2629342635973
93 686.2484187627883
94 646.7777273730248
95 609.6910839690959
96 574.8625995534114
97 542.1427133062471
98 511.37564006008233
99 482.45592585270225
100 455.23835666932496
101 429.6280589082219
102 405.52115269667047
103 382.83764203270437
104 361.4711824249994
105 341.34671845841575
106 322.3998385770832
107 304.5397011648304
108 287.70950190458746
109 271.85801053352225
110 256.91919793808665
111 242.81978517964853
112 229.52795349198647
113 217.0013162608135
114 205.19536080411348
115 194.05134244606614
116 183.53424445041907
117 173.61571591694474
118 164.24937177533522
119 155.4054598228123
120 147.0518068214376
121 139.1662687408965
122 131.72013496333818
123 124.67867331828745
124 118.02743528906606
125 111.74034815928825
126 105.7997225062199
127 100.18519109710223
128 94.87613610923313
129 89.85589081046828
130 85.10986856720892
131 80.62336308388143
132 76.37980418004281
133 72.36363553493554
134 68.56376213835077
135 64.97132825067274
136 61.57130783043726
137 58.35190060984206
138 55.305587690606046
139 52.42261044526819
140 49.69513661929439
141 47.110524626463516
142 44.665524624032685
143 42.349981503875085
144 40.15628883028944
145 38.07747291623849
146 36.110199472350544
147 34.24680050047711
148 32.480536078759016
149 30.80863319654153
150 29.224562626544653
151 27.722061008230483
152 26.298250884249327
153 24.9509557124058
154 23.672997361038014
155 22.461437938071988
156 21.313381332603186
157 20.225119276838683
158 19.193492473376992
159 18.215034778688448
160 17.287105979140463
161 16.40767214145066
162 15.574121821305496
163 14.78291046816236
164 14.033709710601716
165 13.321914954874686
166 12.646939385085673
167 12.00763038726511
168 11.400059535014307
169 10.82401817070344
170 10.277222061465451
171 9.758835081144756
172 9.26689070972401
173 8.800124226351898
174 8.357480283997731
175 7.937170848890063
176 7.538396725467962
177 7.159980754513274
178 6.800445242195629
179 6.459281726066148
180 6.135551793489757
181 5.828257093256116
182 5.536657122151608
183 5.2597046798038
184 4.996922543141668
185 4.747115358530909
186 4.510253233628763
187 4.285423822176373
188 4.071712151497795
189 3.8685265143526557
190 3.6759494609904273
191 3.4933234111516462
192 3.3194657057064667
193 3.1545607800465696
194 2.997738677142255
195 2.8488001415179305
196 2.707485322087198
197 2.5731482750972567
198 2.4456551395446695
199 2.3244612465027217
200 2.209416854197091
201 2.1000828414629815
202 1.9960707427361264
203 1.897405586940895
204 1.8036137432074426
205 1.7144088469777978
206 1.6300393965344782
207 1.5496374360588518
208 1.4731530616555446
209 1.400467126149259
210 1.3314230912168519
211 1.2658992371496005
212 1.2035014934292967
213 1.1443823625389449
214 1.088008828798202
215 1.0344559697030604
216 0.9837045991331799
217 0.9353523808115192
218 0.8894426482482709
219 0.8456990422910815
220 0.8042526332623021
221 0.7648493124167546
222 0.7273545535303905
223 0.6917245944842012
224 0.6578412107760538
225 0.6255846956087385
226 0.594976937131122
227 0.5658803213163432
228 0.5382802557708111
229 0.511855870936671
230 0.48687567799981046
231 0.4630910910095649
232 0.440481989930511
233 0.4189334309637971
234 0.3984860903696372
235 0.37904597668418916
236 0.3605872538678412
237 0.342979231188572
238 0.32623096405635543
239 0.3103292441913985
240 0.2952522868434979
241 0.2808761467575245
242 0.2672395551943829
243 0.2542329302825195
244 0.24185235603490796
245 0.23010958502612144
246 0.21893314030409616
247 0.2082691333532738
248 0.19811513600342412
249 0.18851705508817407
250 0.1793663447614664
251 0.17066952613346675
252 0.16238682692936557
253 0.15455473496330185
254 0.14704199981200627
255 0.13987552676271076
256 0.1331164321113818
257 0.12664805577326188
258 0.12051277105793123
259 0.11471389881193095
260 0.10915678832475995
261 0.10387376503102819
262 0.09885721101357259
263 0.0940406063147805
264 0.0894931676048174
265 0.08519269009475328
266 0.08107479506894966
267 0.07715255610776905
268 0.07343482980597149
269 0.06990078217594164
270 0.06651666056487526
271 0.06331097042918499
272 0.060274372802818776
273 0.05737231944926746
274 0.05461181003964133
275 0.0519748240208453
276 0.049477235008537734
277 0.04710807024965913
278 0.044822643068431756
279 0.04266542287941566
280 0.04061355934779609
281 0.03866774391509442
282 0.03680159653670789
283 0.035036291637655426
284 0.03334682429629088
285 0.03175325688441011
286 0.030217937863354694
287 0.028780689050950103
288 0.027396455469019365
289 0.026077049012601883
290 0.024833266367955575
291 0.023642804044988752
292 0.02250997650086295
293 0.02144935095992606
294 0.02041746972021763
295 0.019434793266880926
296 0.018523770977934784
297 0.017645228126187087
298 0.016804942966873848
299 0.01600757041011658
300 0.015250588415205338
301 0.01453388649481803
302 0.013842117765789363
303 0.013180763267992016
304 0.012559404281203501
305 0.011969626436850414
306 0.011407594157821888
307 0.010868728288580876
308 0.010361678678714847
309 0.009878460563185643
310 0.009413484962257357
311 0.008978939933071861
312 0.0085578097151916
313 0.008162058362414903
314 0.007781866826276129
315 0.007424791676298392
316 0.0070832515459768874
317 0.006755848838269296
318 0.006444993213156502
319 0.0061544923813399155
320 0.005869158886097958
321 0.005605744890589515
322 0.005350296656660314
323 0.005109207326088994
324 0.004880517271764306
325 0.004659021912903016
326 0.004454611044912404
327 0.004253656196197386
328 0.004066557601510867
329 0.0038853800510941783
330 0.0037187784422974546
331 0.0035569630829315746
332 0.0034016790123304053
333 0.0032583654981401255
334 0.003117920004445829
335 0.002984025565382853
336 0.002854016989381869
337 0.0027348234696771834
338 0.0026192599941011196
339 0.0025095443389809846
340 0.002405041536743313
341 0.00230687359272036
342 0.002214928249156145
343 0.0021234736929978126
344 0.002039180638685245
345 0.0019552143643697395
346 0.0018745736055206241
347 0.0018035952704212321
348 0.00173087127624183
349 0.0016639477228803656
350 0.0016006842994862325
351 0.00153838967733666
352 0.0014794865637011156
353 0.0014226339676941535
354 0.0013703713614470758
355 0.001318847620898289
356 0.0012695277298178653
357 0.0012232220821966622
358 0.0011793379575229523
359 0.0011369445119511767
360 0.0010960899135473157
361 0.0010557352139236587
362 0.0010190521209586523
363 0.0009816684287030086
364 0.0009483489301430614
365 0.0009155840117395231
366 0.0008831587410558783
367 0.0008529282894703671
368 0.0008241091835046732
369 0.0007965066485314254
370 0.0007701554599624516
371 0.0007455710941129401
372 0.000721726058673211
373 0.0006984609953770704
374 0.0006756728690309188
375 0.0006543395272482488
376 0.0006334997625743743
377 0.0006133135986232663
378 0.0005940285408930213
379 0.0005768163984323227
380 0.0005582955879497309
381 0.0005403396776233471
382 0.000524907132466057
383 0.0005081582513181693
384 0.0004932083544134658
385 0.0004789402307995849
386 0.0004657454269794692
387 0.0004515671884509076
388 0.0004391324009427322
389 0.0004264239467155484
390 0.0004139616038950633
391 0.000402449667937288
392 0.000391100533594424
393 0.0003802354751193532
394 0.0003695609222785362
395 0.0003594383931329742
396 0.000348877747375978
397 0.00033964784308526674
398 0.0003298401872761647
399 0.00032147704635587804
400 0.000312450517303664
401 0.00030449965913548205
402 0.0002967216174721843
403 0.0002893919426896763
404 0.00028203448829108857
405 0.0002743978663064528
406 0.0002676972807276745
407 0.0002609849706238232
408 0.0002548175835291322
409 0.00024811153767811955
410 0.00024250765672600982
411 0.00023650274327942367
412 0.00023047716687268904
413 0.00022471379447019935
414 0.00021914975305849238
415 0.00021423687253542545
416 0.00020897214132226116
417 0.00020414794454927387
418 0.0001993863547718605
419 0.00019475557957809864
420 0.00019035105116484152
421 0.0001864379290024698
422 0.0001822087874031597
423 0.00017828980146461504
424 0.0001748117836740254
425 0.00017067258914045536
426 0.00016659315227303406
427 0.00016304911062160754
428 0.00015953894128478696
429 0.00015640512905601422
430 0.00015333295964406468
431 0.00015016095880160396
432 0.00014714279263838836
433 0.0001440810641674256
434 0.0001410526051056904
435 0.0001380402029223654
436 0.0001352770360503741
437 0.00013224261271289894
438 0.00013011410101099186
439 0.0001271161862473648
440 0.00012488698017590338
441 0.0001217863333489444
442 0.00012007163922342357
443 0.00011781583684736252
444 0.00011559607739315692
445 0.00011358314188421315
446 0.0001114155523467969
447 0.00010889654849662034
448 0.00010646967275358687
449 0.00010482432479164139
450 0.00010286320594761478
451 0.00010125045926358267
452 9.960929570414223e-05
453 9.799235906490789e-05
454 9.620844305581466e-05
455 9.459530710949349e-05
456 9.273197955957102e-05
457 9.127427589890325e-05
458 9.007831064523908e-05
459 8.817761231945387e-05
460 8.684583397740309e-05
461 8.5335020382174e-05
462 8.371255083484963e-05
463 8.22974811159477e-05
464 8.094163969969703e-05
465 7.960938639257897e-05
466 7.842215220583981e-05
467 7.728297942995177e-05
468 7.593238006604885e-05
469 7.479653909998885e-05
470 7.359414504477801e-05
471 7.261324151659754e-05
472 7.13220807885967e-05
473 7.025350290558799e-05
474 6.922546995237477e-05
475 6.810574995476382e-05
476 6.680429225564e-05
477 6.610008566809711e-05
478 6.499050056003874e-05
479 6.392327166034051e-05
480 6.244427063628599e-05
481 6.160321193622587e-05
482 6.0822358790735276e-05
483 5.985063854610506e-05
484 5.907645988752208e-05
485 5.816978297835951e-05
486 5.725530716667315e-05
487 5.6669281581392394e-05
488 5.599471629270936e-05
489 5.516636369974626e-05
490 5.4241925234579935e-05
491 5.353131823862545e-05
492 5.277215221846654e-05
493 5.206747732634798e-05
494 5.139430257125599e-05
495 5.090189738672646e-05
496 5.012338018581253e-05
497 4.9430048894433254e-05
498 4.884355437909105e-05
499 4.8229601279398127e-05

In [20]:
mat1 = torch.randn(2, 3)
mat2 = torch.randn(3, 3)
print(torch.mm(mat1, mat2))
print(mat1.mm(mat2))


 1.3635  3.2792 -2.0203
 0.7865  1.0616 -0.5123
[torch.FloatTensor of size 2x3]


 1.3635  3.2792 -2.0203
 0.7865  1.0616 -0.5123
[torch.FloatTensor of size 2x3]


In [23]:
a = torch.randn(4)
print(a)
print(torch.clamp(a, min=-0.5, max=0.5))


 1.3481
 0.5239
-0.4739
 0.1594
[torch.FloatTensor of size 4]


 0.5000
 0.5000
-0.4739
 0.1594
[torch.FloatTensor of size 4]


In [25]:
a = torch.randn(4)
print(a)
print(torch.pow(a, 2))


-1.2419
-1.1250
-0.3517
-1.1164
[torch.FloatTensor of size 4]


 1.5423
 1.2656
 0.1237
 1.2462
[torch.FloatTensor of size 4]


In [29]:
a = torch.randn(2, 3)
print(a)
print(torch.t(a))
print(a.t())


-1.8005  0.2771 -0.2920
-0.2676 -0.0230 -1.3677
[torch.FloatTensor of size 2x3]


-1.8005 -0.2676
 0.2771 -0.0230
-0.2920 -1.3677
[torch.FloatTensor of size 3x2]


-1.8005 -0.2676
 0.2771 -0.0230
-0.2920 -1.3677
[torch.FloatTensor of size 3x2]

Autograd

  • Tensorはnumpy.arraに当たる
  • TensorをくるんだVariableは計算グラフのノードに当たる
  • xがVariableだとするとx.dataがTensor、x.gradが勾配になる
  • TensorのメソッドはほとんどVariableでも使える
  • Tensorとの違いは勾配を自動計算できる点

In [66]:
import torch
from torch.autograd import Variable

dtype = torch.FloatTensor
#dtype = torch.cuda.FloatTensor  # GPU使う場合

N, D_in, H, D_out = 64, 1000, 100, 10

# 入力は勾配計算不要
# CNNの可視化をするときはTrueにする?
x = Variable(torch.randn(N, D_in).type(dtype), requires_grad=False)
y = Variable(torch.randn(N, D_out).type(dtype), requires_grad=False)

w1 = Variable(torch.randn(D_in, H).type(dtype), requires_grad=True)
w2 = Variable(torch.randn(H, D_out).type(dtype), requires_grad=True)

learning_rate = 1e-6
for t in range(500):
    # forward pass
    y_pred = x.mm(w1).clamp(min=0).mm(w2)
    
    # compute loss
    loss = (y_pred - y).pow(2).sum()
    # loss: (1,) のVariable
    # loss.data: (1,)のTensor
    # loss.data[0]: Scalar
    print(t, loss.data[0])
    
    # backward pass
    # lossに関する各Variableの勾配が求まる
    # requires_grad=TrueのVariableが勾配計算の対象
    # w1.gradとw2.gradに値が入る
    loss.backward()
    
    # update weights
    # w1.gradはVariableなのでTensor計算はdataを使う
    w1.data -= learning_rate * w1.grad.data
    w2.data -= learning_rate * w2.grad.data

    # 重み更新したあとは蓄積した勾配をクリアする
    w1.grad.data.zero_()
    w2.grad.data.zero_()


0 37870032.0
1 34478608.0
2 32695228.0
3 27592806.0
4 19437664.0
5 11586074.0
6 6358191.5
7 3563654.25
8 2198189.25
9 1516179.0
10 1143097.875
11 912395.375
12 752971.0
13 633663.0
14 539798.625
15 463610.125
16 400742.5
17 348235.65625
18 304007.5
19 266465.09375
20 234415.390625
21 206922.6875
22 183230.328125
23 162741.234375
24 144937.5
25 129392.6328125
26 115781.6328125
27 103833.2265625
28 93326.7421875
29 84068.0234375
30 75869.8671875
31 68592.0234375
32 62121.3671875
33 56352.0234375
34 51195.7109375
35 46581.4765625
36 42440.4453125
37 38719.12890625
38 35367.265625
39 32345.12890625
40 29614.98828125
41 27145.8984375
42 24910.3203125
43 22882.93359375
44 21040.947265625
45 19366.619140625
46 17843.998046875
47 16456.990234375
48 15192.2353515625
49 14039.095703125
50 12987.9814453125
51 12026.560546875
52 11146.494140625
53 10339.19140625
54 9597.78515625
55 8916.2001953125
56 8289.037109375
57 7711.7705078125
58 7179.63671875
59 6689.1181640625
60 6235.869140625
61 5817.21044921875
62 5430.04296875
63 5071.8046875
64 4740.1455078125
65 4432.67724609375
66 4147.5458984375
67 3882.796142578125
68 3636.97509765625
69 3408.57861328125
70 3196.184326171875
71 2998.408935546875
72 2814.2763671875
73 2642.755126953125
74 2482.88037109375
75 2333.728515625
76 2194.524658203125
77 2064.54541015625
78 1943.135986328125
79 1829.6522216796875
80 1723.5206298828125
81 1624.172607421875
82 1531.16357421875
83 1444.034912109375
84 1362.3807373046875
85 1285.844482421875
86 1214.0477294921875
87 1146.6597900390625
88 1083.3560791015625
89 1023.8919677734375
90 968.0101928710938
91 915.4775390625
92 866.0459594726562
93 819.55908203125
94 775.8433837890625
95 734.6747436523438
96 695.8858642578125
97 659.318115234375
98 624.8422241210938
99 592.3289794921875
100 561.6595458984375
101 532.7123413085938
102 505.3792419433594
103 479.5636901855469
104 455.18402099609375
105 432.142578125
106 410.3601989746094
107 389.7490539550781
108 370.25726318359375
109 351.8101806640625
110 334.35235595703125
111 317.8269958496094
112 302.1764221191406
113 287.35003662109375
114 273.3050231933594
115 259.99664306640625
116 247.38290405273438
117 235.42591857910156
118 224.0830841064453
119 213.3343505859375
120 203.12445068359375
121 193.43618774414062
122 184.2409210205078
123 175.5115203857422
124 167.2244110107422
125 159.35760498046875
126 151.87757873535156
127 144.77052307128906
128 138.0179443359375
129 131.59751892089844
130 125.49559783935547
131 119.69213104248047
132 114.17386627197266
133 108.93020629882812
134 103.93760681152344
135 99.1849136352539
136 94.66242980957031
137 90.35808563232422
138 86.26091003417969
139 82.35966491699219
140 78.64615631103516
141 75.10803985595703
142 71.7395248413086
143 68.52716064453125
144 65.4678726196289
145 62.553401947021484
146 59.77412033081055
147 57.12360763549805
148 54.59821701049805
149 52.18914794921875
150 49.8921012878418
151 47.700260162353516
152 45.61074447631836
153 43.616188049316406
154 41.71348571777344
155 39.89878845214844
156 38.166934967041016
157 36.513816833496094
158 34.93598556518555
159 33.428306579589844
160 31.988161087036133
161 30.613733291625977
162 29.301597595214844
163 28.047916412353516
164 26.849842071533203
165 25.705211639404297
166 24.612102508544922
167 23.56743621826172
168 22.57015037536621
169 21.61608123779297
170 20.703645706176758
171 19.832048416137695
172 18.99795150756836
173 18.201446533203125
174 17.4395694732666
175 16.710905075073242
176 16.01413917541504
177 15.34815788269043
178 14.710489273071289
179 14.100008010864258
180 13.516216278076172
181 12.957781791687012
182 12.423283576965332
183 11.911429405212402
184 11.42110538482666
185 10.95229434967041
186 10.503878593444824
187 10.074260711669922
188 9.66293716430664
189 9.268757820129395
190 8.891319274902344
191 8.529993057250977
192 8.1835298538208
193 7.851962089538574
194 7.53405237197876
195 7.230029582977295
196 6.93812894821167
197 6.658735752105713
198 6.391169548034668
199 6.134161472320557
200 5.888348579406738
201 5.652589321136475
202 5.426687717437744
203 5.210164546966553
204 5.002322673797607
205 4.80317497253418
206 4.612399578094482
207 4.429353713989258
208 4.253411293029785
209 4.084941387176514
210 3.923436403274536
211 3.768491744995117
212 3.6200287342071533
213 3.47733473777771
214 3.340663433074951
215 3.2094614505767822
216 3.08345890045166
217 2.9627106189727783
218 2.8466126918792725
219 2.7355270385742188
220 2.62874174118042
221 2.526240825653076
222 2.4278273582458496
223 2.3333935737609863
224 2.2427544593811035
225 2.1557223796844482
226 2.0722615718841553
227 1.9920085668563843
228 1.9149503707885742
229 1.8412480354309082
230 1.770180344581604
231 1.7019550800323486
232 1.6365866661071777
233 1.5737887620925903
234 1.513175129890442
235 1.4552029371261597
236 1.3994910717010498
237 1.3458523750305176
238 1.2944509983062744
239 1.2450075149536133
240 1.1975328922271729
241 1.1518539190292358
242 1.108028531074524
243 1.0658938884735107
244 1.0253745317459106
245 0.98646479845047
246 0.9490171074867249
247 0.9131610989570618
248 0.8786035776138306
249 0.845293402671814
250 0.8133932948112488
251 0.782757043838501
252 0.7531759142875671
253 0.7248224020004272
254 0.6975113153457642
255 0.671303927898407
256 0.6460397243499756
257 0.6218372583389282
258 0.59846431016922
259 0.576130747795105
260 0.5544269680976868
261 0.5337149500846863
262 0.5138075351715088
263 0.49458929896354675
264 0.4761113226413727
265 0.45834872126579285
266 0.44127872586250305
267 0.4248211085796356
268 0.4090152382850647
269 0.3938395082950592
270 0.3792349696159363
271 0.3651185929775238
272 0.3515734076499939
273 0.33853527903556824
274 0.3259619176387787
275 0.3138852119445801
276 0.3022891581058502
277 0.29116347432136536
278 0.2803885340690613
279 0.2700197398662567
280 0.260076642036438
281 0.25046107172966003
282 0.24128766357898712
283 0.23238122463226318
284 0.2238440364599228
285 0.21561096608638763
286 0.20767469704151154
287 0.20003192126750946
288 0.1927458643913269
289 0.18564538657665253
290 0.1788666993379593
291 0.17232288420200348
292 0.16597211360931396
293 0.15994659066200256
294 0.15411631762981415
295 0.14846326410770416
296 0.14304590225219727
297 0.1378142237663269
298 0.1327887922525406
299 0.12794089317321777
300 0.12328280508518219
301 0.11880230158567429
302 0.11447075754404068
303 0.11031948775053024
304 0.10629594326019287
305 0.1024278923869133
306 0.0986919030547142
307 0.09514228999614716
308 0.09169638901948929
309 0.0883614718914032
310 0.08514798432588577
311 0.08207543939352036
312 0.07913155108690262
313 0.07627607882022858
314 0.07351433485746384
315 0.07084961980581284
316 0.06830120831727982
317 0.06585732847452164
318 0.06346788257360458
319 0.06118668243288994
320 0.05896943807601929
321 0.0568489134311676
322 0.0548076257109642
323 0.052835144102573395
324 0.05094197764992714
325 0.049116749316453934
326 0.04735731706023216
327 0.04566659405827522
328 0.04404222220182419
329 0.042475927621126175
330 0.0409347303211689
331 0.039473727345466614
332 0.03807145357131958
333 0.036708444356918335
334 0.03541264683008194
335 0.03414784371852875
336 0.03293178230524063
337 0.031742971390485764
338 0.03061896190047264
339 0.029534315690398216
340 0.028492679819464684
341 0.02747279591858387
342 0.026508735492825508
343 0.025570206344127655
344 0.024662354961037636
345 0.023791035637259483
346 0.022958965972065926
347 0.022154852747917175
348 0.02137317694723606
349 0.020614001899957657
350 0.019883684813976288
351 0.019185859709978104
352 0.018517162650823593
353 0.017862223088741302
354 0.017233485355973244
355 0.016630860045552254
356 0.01604647748172283
357 0.015494116581976414
358 0.014952479861676693
359 0.014432663097977638
360 0.013926276005804539
361 0.013451041653752327
362 0.012983896769583225
363 0.01253658439964056
364 0.012102684937417507
365 0.011680260300636292
366 0.011279590427875519
367 0.010890264995396137
368 0.010516932234168053
369 0.010156218893826008
370 0.009808727540075779
371 0.009476130828261375
372 0.009147753939032555
373 0.008837738074362278
374 0.008536809124052525
375 0.008246975019574165
376 0.007975606247782707
377 0.007706505246460438
378 0.007443699520081282
379 0.0071933688595891
380 0.00695499312132597
381 0.006723369006067514
382 0.006494562607258558
383 0.0062788971699774265
384 0.006075172685086727
385 0.005868594162166119
386 0.0056759207509458065
387 0.0054891896434128284
388 0.005308847408741713
389 0.005139398854225874
390 0.004966994747519493
391 0.004804356023669243
392 0.0046522971242666245
393 0.004500924609601498
394 0.004354889504611492
395 0.0042158872820436954
396 0.004080577287822962
397 0.003948947414755821
398 0.003824489889666438
399 0.003701734123751521
400 0.003587280632928014
401 0.0034729016479104757
402 0.0033653422724455595
403 0.0032620830461382866
404 0.0031601497903466225
405 0.0030606903601437807
406 0.002965899184346199
407 0.002873847261071205
408 0.0027869725599884987
409 0.002700116718187928
410 0.0026208411436527967
411 0.002541628200560808
412 0.0024645712692290545
413 0.002393615897744894
414 0.002318289829418063
415 0.002248723292723298
416 0.0021844441071152687
417 0.0021217181347310543
418 0.0020596361719071865
419 0.0019973155576735735
420 0.0019406459759920835
421 0.0018820848781615496
422 0.0018288587452843785
423 0.0017751905834302306
424 0.001726602902635932
425 0.0016775239491835237
426 0.0016285416204482317
427 0.001584334415383637
428 0.001538667711429298
429 0.0014962786808609962
430 0.0014547640457749367
431 0.0014151422074064612
432 0.0013780660228803754
433 0.0013405780773609877
434 0.0013046213425695896
435 0.0012677927734330297
436 0.0012345182476565242
437 0.0012011844664812088
438 0.001167984795756638
439 0.0011366186663508415
440 0.001108652912080288
441 0.001080093439668417
442 0.0010519508505240083
443 0.001024694531224668
444 0.000999605399556458
445 0.0009735235362313688
446 0.0009507114300504327
447 0.0009254908072762191
448 0.0009019618155434728
449 0.0008802818483673036
450 0.0008585642208345234
451 0.0008375356555916369
452 0.0008156609837897122
453 0.0007956983172334731
454 0.000776375294663012
455 0.000758312875404954
456 0.0007407242665067315
457 0.000723997363820672
458 0.0007060692296363413
459 0.0006889899959787726
460 0.0006733941263519228
461 0.0006581827183254063
462 0.0006437848205678165
463 0.0006282631657086313
464 0.0006135209114290774
465 0.0005995312239974737
466 0.0005868279840797186
467 0.0005738042527809739
468 0.0005611792439594865
469 0.000549312389921397
470 0.0005370269063860178
471 0.0005258559831418097
472 0.0005145852919667959
473 0.0005032621556892991
474 0.0004927353584207594
475 0.00048276627785526216
476 0.0004718914278782904
477 0.00046243672841228545
478 0.0004521625814959407
479 0.00044225770398043096
480 0.0004333268734626472
481 0.0004239983973093331
482 0.0004159581148996949
483 0.0004076466429978609
484 0.0003996106097474694
485 0.0003921102616004646
486 0.0003834785311482847
487 0.0003759614482987672
488 0.00036885563167743385
489 0.00036155630368739367
490 0.0003546611696947366
491 0.00034819930442608893
492 0.00034078393946401775
493 0.00033466453896835446
494 0.00032749975798651576
495 0.00032224113238044083
496 0.0003160535707138479
497 0.0003099766909144819
498 0.0003042948665097356
499 0.00029854648164473474

In [68]:
w2.grad


Out[68]:
Variable containing:
    0     0     0  ...      0     0     0
    0     0     0  ...      0     0     0
    0     0     0  ...      0     0     0
       ...          ⋱          ...       
    0     0     0  ...      0     0     0
    0     0     0  ...      0     0     0
    0     0     0  ...      0     0     0
[torch.FloatTensor of size 100x10]

New Function


In [70]:
import torch
from torch.autograd import Variable

class MyReLU(torch.autograd.Function):
    def forward(self, input):
        # Tensorを受け取ってTensorを返す
        self.save_for_backward(input)
        return input.clamp(min=0)

    def backward(self, grad_output):
        # lossのこのユニットの出力に対する勾配を受け取る
        # lossのこのユニットの入力に対する購買を返す
        input, = self.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input

dtype = torch.FloatTensor

N, D_in, H, D_out = 64, 1000, 100, 10

x = Variable(torch.randn(N, D_in).type(dtype), requires_grad=False)
y = Variable(torch.randn(N, D_out).type(dtype), requires_grad=False)

w1 = Variable(torch.randn(D_in, H).type(dtype), requires_grad=True)
w2 = Variable(torch.randn(H, D_out).type(dtype), requires_grad=True)

learning_rate = 1e-6
for t in range(500):
    relu = MyReLU()
    y_pred = relu(x.mm(w1)).mm(w2)    
    loss = (y_pred - y).pow(2).sum()
    print(t, loss.data[0])
    loss.backward()
    w1.data -= learning_rate * w1.grad.data
    w2.data -= learning_rate * w2.grad.data
    w1.grad.data.zero_()
    w2.grad.data.zero_()


0 28923664.0
1 21498286.0
2 18640954.0
3 17113786.0
4 15503434.0
5 13195715.0
6 10424069.0
7 7645163.0
8 5322102.0
9 3596325.25
10 2423709.0
11 1659939.375
12 1172385.25
13 858418.6875
14 652287.375
15 512329.5625
16 413720.1875
17 341474.09375
18 286561.5625
19 243556.640625
20 209006.109375
21 180666.84375
22 157067.765625
23 137187.0625
24 120301.1015625
25 105850.234375
26 93404.109375
27 82638.6953125
28 73295.3984375
29 65162.32421875
30 58063.671875
31 51841.375
32 46374.22265625
33 41566.15625
34 37321.76171875
35 33570.828125
36 30258.173828125
37 27318.052734375
38 24701.7265625
39 22367.705078125
40 20283.693359375
41 18418.294921875
42 16745.623046875
43 15242.7880859375
44 13891.72265625
45 12675.8466796875
46 11579.302734375
47 10588.8564453125
48 9693.3740234375
49 8882.716796875
50 8147.6982421875
51 7480.39501953125
52 6874.16259765625
53 6323.63720703125
54 5822.29150390625
55 5365.216796875
56 4947.8427734375
57 4566.49072265625
58 4217.5107421875
59 3898.070068359375
60 3605.19580078125
61 3336.48046875
62 3089.747314453125
63 2863.11279296875
64 2655.2666015625
65 2463.986083984375
66 2287.767333984375
67 2125.36083984375
68 1975.6036376953125
69 1837.2276611328125
70 1709.358154296875
71 1591.2010498046875
72 1481.877197265625
73 1380.703125
74 1286.96728515625
75 1200.1004638671875
76 1119.5401611328125
77 1044.79345703125
78 975.436279296875
79 911.0248413085938
80 851.1502075195312
81 795.57275390625
82 743.9124145507812
83 695.8291015625
84 651.05029296875
85 609.3773803710938
86 570.5382080078125
87 534.3441772460938
88 500.5921936035156
89 469.1074523925781
90 439.7309265136719
91 412.3121032714844
92 386.7171325683594
93 362.801513671875
94 340.4529113769531
95 319.55841064453125
96 300.01861572265625
97 281.7463073730469
98 264.64361572265625
99 248.6382293701172
100 233.6525421142578
101 219.6236572265625
102 206.48904418945312
103 194.17471313476562
104 182.6298828125
105 171.8096466064453
106 161.65875244140625
107 152.13641357421875
108 143.20468139648438
109 134.82394409179688
110 126.95314025878906
111 119.56464385986328
112 112.6275405883789
113 106.11439514160156
114 99.98987579345703
115 94.23602294921875
116 88.8243408203125
117 83.73809814453125
118 78.95480346679688
119 74.45486450195312
120 70.22343444824219
121 66.24131774902344
122 62.494178771972656
123 58.96847152709961
124 55.647743225097656
125 52.520992279052734
126 49.57637405395508
127 46.80282974243164
128 44.1906852722168
129 41.72941970825195
130 39.40969467163086
131 37.223087310791016
132 35.16261291503906
133 33.21999740600586
134 31.38882064819336
135 29.66090202331543
136 28.031343460083008
137 26.49390983581543
138 25.0439510345459
139 23.675399780273438
140 22.384441375732422
141 21.165136337280273
142 20.01498031616211
143 18.928407669067383
144 17.90375328063965
145 16.935808181762695
146 16.020719528198242
147 15.157450675964355
148 14.341598510742188
149 13.570789337158203
150 12.84309196472168
151 12.154644966125488
152 11.504240036010742
153 10.889629364013672
154 10.309209823608398
155 9.760289192199707
156 9.241244316101074
157 8.750333786010742
158 8.28641414642334
159 7.847211837768555
160 7.432059288024902
161 7.039107799530029
162 6.668087959289551
163 6.316605567932129
164 5.984142303466797
165 5.669780731201172
166 5.372228622436523
167 5.090503692626953
168 4.823985576629639
169 4.571763515472412
170 4.33288049697876
171 4.106768608093262
172 3.892662763595581
173 3.6901235580444336
174 3.498110055923462
175 3.3165032863616943
176 3.144544839859009
177 2.981487989425659
178 2.8270070552825928
179 2.680805206298828
180 2.5423574447631836
181 2.4111075401306152
182 2.2868545055389404
183 2.1690704822540283
184 2.057396650314331
185 1.9516382217407227
186 1.851377248764038
187 1.7564656734466553
188 1.6665834188461304
189 1.5812724828720093
190 1.5002679824829102
191 1.4235821962356567
192 1.3509317636489868
193 1.282096266746521
194 1.2166204452514648
195 1.1546019315719604
196 1.0959266424179077
197 1.0402212142944336
198 0.987381637096405
199 0.9372941851615906
200 0.8898499011993408
201 0.8447314500808716
202 0.8019359111785889
203 0.7614567875862122
204 0.7228724956512451
205 0.686340868473053
206 0.6517183780670166
207 0.618864119052887
208 0.5876443982124329
209 0.5581151843070984
210 0.5300303101539612
211 0.5034107565879822
212 0.47816023230552673
213 0.45404255390167236
214 0.43130481243133545
215 0.40969619154930115
216 0.38912779092788696
217 0.36962446570396423
218 0.35121211409568787
219 0.3336177170276642
220 0.31695878505706787
221 0.3011242151260376
222 0.28612715005874634
223 0.27187058329582214
224 0.2583119869232178
225 0.24545729160308838
226 0.23322661221027374
227 0.22162245213985443
228 0.2106011062860489
229 0.20013532042503357
230 0.1902325451374054
231 0.18080182373523712
232 0.1718176007270813
233 0.16333018243312836
234 0.15522423386573792
235 0.14756996929645538
236 0.14025260508060455
237 0.13333715498447418
238 0.12674397230148315
239 0.12050984799861908
240 0.11453907191753387
241 0.10885920375585556
242 0.10352634638547897
243 0.09842785447835922
244 0.09359283745288849
245 0.08896742016077042
246 0.08459114283323288
247 0.08043776452541351
248 0.07649579644203186
249 0.07275176793336868
250 0.06917980313301086
251 0.06577007472515106
252 0.0625392273068428
253 0.05948027968406677
254 0.05657285451889038
255 0.05381157249212265
256 0.05117188021540642
257 0.048677459359169006
258 0.04629391059279442
259 0.04403146356344223
260 0.0418638214468956
261 0.03982250392436981
262 0.037890806794166565
263 0.03605346009135246
264 0.034287575632333755
265 0.03263368457555771
266 0.03105229325592518
267 0.029539020732045174
268 0.028111014515161514
269 0.026739779859781265
270 0.025443635880947113
271 0.024208400398492813
272 0.02303745411336422
273 0.021934960037469864
274 0.020885055884718895
275 0.019873928278684616
276 0.018905743956565857
277 0.0180071908980608
278 0.017139796167612076
279 0.01632714457809925
280 0.015535750426352024
281 0.014795918948948383
282 0.014085466042160988
283 0.01341425720602274
284 0.012768176384270191
285 0.012165223248302937
286 0.01158919744193554
287 0.011045937426388264
288 0.010523861274123192
289 0.010019428096711636
290 0.009553903713822365
291 0.009097878821194172
292 0.00867906492203474
293 0.008272059261798859
294 0.007886571809649467
295 0.00751827098429203
296 0.0071678063832223415
297 0.006832011044025421
298 0.006516290828585625
299 0.006213999353349209
300 0.005932728294283152
301 0.005661071743816137
302 0.0054058353416621685
303 0.005157188978046179
304 0.0049183908849954605
305 0.004695625975728035
306 0.004484082106500864
307 0.004279893822968006
308 0.004091349896043539
309 0.003913091029971838
310 0.0037346675526350737
311 0.003569413209334016
312 0.0034109530970454216
313 0.0032603638246655464
314 0.003118695691227913
315 0.0029870946891605854
316 0.002855949802324176
317 0.0027328492142260075
318 0.002614900702610612
319 0.0025024819187819958
320 0.0023970911279320717
321 0.0022979453206062317
322 0.002199960872530937
323 0.0021069732028990984
324 0.002022555796429515
325 0.001938356552273035
326 0.001857837662100792
327 0.0017827516421675682
328 0.0017107422463595867
329 0.0016423908527940512
330 0.0015781833790242672
331 0.0015150345861911774
332 0.0014520447002723813
333 0.0013963355449959636
334 0.0013424447970464826
335 0.001291140099056065
336 0.0012414485681802034
337 0.00119434529915452
338 0.0011498663807287812
339 0.0011097491951659322
340 0.0010663957800716162
341 0.001028326223604381
342 0.0009905204642564058
343 0.0009535758872516453
344 0.0009200025233440101
345 0.0008878379594534636
346 0.0008560301503166556
347 0.000826462113764137
348 0.0007977659115567803
349 0.000769733393099159
350 0.0007435690495185554
351 0.000718369148671627
352 0.0006941225728951395
353 0.0006694131880067289
354 0.0006492062238976359
355 0.0006273124599829316
356 0.0006078107398934662
357 0.0005864134873263538
358 0.000568610557820648
359 0.0005507656605914235
360 0.000533539627213031
361 0.0005167921190150082
362 0.0005024982965551317
363 0.00048616030835546553
364 0.00047021987847983837
365 0.00045697440509684384
366 0.0004431252309586853
367 0.0004304220783524215
368 0.000418556242948398
369 0.0004073760355822742
370 0.0003948859521187842
371 0.0003838999255094677
372 0.00037256604991853237
373 0.0003621844807639718
374 0.0003524279163684696
375 0.00034316873643547297
376 0.00033260128111578524
377 0.0003246633568778634
378 0.0003165012749377638
379 0.0003075235290452838
380 0.00029933860059827566
381 0.0002911289921030402
382 0.0002833910984918475
383 0.00027626942028291523
384 0.0002688698295969516
385 0.00026204081950709224
386 0.00025500249466858804
387 0.00024849624605849385
388 0.0002434398338664323
389 0.00023773561406414956
390 0.00023197421978693455
391 0.00022621991229243577
392 0.00022067167446948588
393 0.0002156850678147748
394 0.00021074930555187166
395 0.00020565732847899199
396 0.0002011816977756098
397 0.0001959902874659747
398 0.00019220240938011557
399 0.00018781483231578022
400 0.00018341114628128707
401 0.0001796314027160406
402 0.00017604885215405375
403 0.00017195241525769234
404 0.00016855249123182148
405 0.00016483916260767728
406 0.00016163161490112543
407 0.00015797774540260434
408 0.0001546357525512576
409 0.00015133628039620817
410 0.00014865194680169225
411 0.00014545678277499974
412 0.00014237905270420015
413 0.00013943710655439645
414 0.00013693858636543155
415 0.00013418371963780373
416 0.00013155430497135967
417 0.00012901709123980254
418 0.00012668329873122275
419 0.00012408624752424657
420 0.00012158006575191393
421 0.00011936425289604813
422 0.00011732535494957119
423 0.00011523292778292671
424 0.000113444693852216
425 0.00011123613512609154
426 0.00010967015259666368
427 0.00010774750990094617
428 0.00010583215771475807
429 0.0001042165094986558
430 0.00010212985216639936
431 9.998519817600027e-05
432 9.803407010622323e-05
433 9.686112025519833e-05
434 9.532038529869169e-05
435 9.377708920510486e-05
436 9.235588368028402e-05
437 9.107199730351567e-05
438 8.94090480869636e-05
439 8.810464350972325e-05
440 8.64808025653474e-05
441 8.50931101012975e-05
442 8.400014485232532e-05
443 8.266579970950261e-05
444 8.142495789797977e-05
445 8.007594442460686e-05
446 7.901286880951375e-05
447 7.761896995361894e-05
448 7.623060810146853e-05
449 7.520755025325343e-05
450 7.400189497275278e-05
451 7.291082147276029e-05
452 7.184280548244715e-05
453 7.060560892568901e-05
454 6.964625208638608e-05
455 6.865674367872998e-05
456 6.775430665584281e-05
457 6.665982800768688e-05
458 6.565610237885267e-05
459 6.474139081547037e-05
460 6.378690159181133e-05
461 6.284342816798016e-05
462 6.215638859430328e-05
463 6.121573824202642e-05
464 6.0722824855474755e-05
465 5.9917143516940996e-05
466 5.913156928727403e-05
467 5.805127148050815e-05
468 5.735135346185416e-05
469 5.673127816407941e-05
470 5.5778677051421255e-05
471 5.5280055676121265e-05
472 5.4465144785353914e-05
473 5.382458766689524e-05
474 5.3100760851521045e-05
475 5.2223942475393414e-05
476 5.160309956409037e-05
477 5.1030045142397285e-05
478 5.0649734475882724e-05
479 5.0006408855551854e-05
480 4.934615935781039e-05
481 4.8492078349227086e-05
482 4.77825706184376e-05
483 4.726068436866626e-05
484 4.676151002058759e-05
485 4.611333497450687e-05
486 4.571661338559352e-05
487 4.5202796172816306e-05
488 4.466272366698831e-05
489 4.393340350361541e-05
490 4.362184336059727e-05
491 4.312535747885704e-05
492 4.247122706146911e-05
493 4.18940071540419e-05
494 4.144340709899552e-05
495 4.0874820115277544e-05
496 4.0433264075545594e-05
497 4.0167695260606706e-05
498 3.976369043812156e-05
499 3.912360989488661e-05

TensorFlow版


In [77]:
import tensorflow as tf
import numpy as np

N, D_in, H, D_out = 64, 1000, 100, 10

# placeholderは計算グラフの実行時に実際のデータで満たされる
x = tf.placeholder(tf.float32, shape=(None, D_in))
y = tf.placeholder(tf.float32, shape=(None, D_out))

# TensorFlowのVariableもPyTorchと同じ(計算グラフのノード)
w1 = tf.Variable(tf.random_normal((D_in, H)))
w2 = tf.Variable(tf.random_normal((H, D_out)))

# forward pass
# この段階では計算グラフを構築するだけで実際の数値計算はしない
h = tf.matmul(x, w1)
h_relu = tf.maximum(h, tf.zeros(1))
y_pred = tf.matmul(h_relu, w2)

# compute loss
loss = tf.reduce_sum((y - y_pred) ** 2.0)

# 損失のw1とw2に関する勾配を計算
grad_w1, grad_w2 = tf.gradients(loss, [w1, w2])

# TensorFlowではパラメータ更新も計算グラフ内で行われる
learning_rate = 1e-6
new_w1 = w1.assign(w1 - learning_rate * grad_w1)
new_w2 = w2.assign(w2 - learning_rate * grad_w2)

# 計算グラフができたので実際のデータを入れて計算
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    
    x_value = np.random.randn(N, D_in)
    y_value = np.random.randn(N, D_out)
    
    for _ in range(500):
        # 計算グラフのloss, new_w1, new_w2を実行
        loss_value, _, _ = sess.run([loss, new_w1, new_w2],
                                    # placeholderにデータを入力
                                    feed_dict={x: x_value, y: y_value})
        print(loss_value)


3.12259e+07
2.45127e+07
2.12229e+07
1.85246e+07
1.5593e+07
1.21715e+07
8.81838e+06
6.06698e+06
4.07357e+06
2.75875e+06
1.91932e+06
1.38465e+06
1.03923e+06
810277.0
652668.0
539671.0
455276.0
389974.0
337956.0
295558.0
260221.0
230367.0
204837.0
182813.0
163670.0
146935.0
132244.0
119298.0
107845.0
97687.2
88647.5
80577.3
73363.4
66892.4
61076.4
55837.9
51115.2
46850.1
42996.8
39503.2
36334.0
33452.2
30829.5
28439.1
26259.9
24267.5
22446.4
20777.2
19247.7
17847.8
16561.9
15379.7
14292.3
13290.8
12368.0
11517.2
10731.7
10005.9
9334.87
8713.51
8137.94
7604.45
7109.44
6649.92
6223.16
5826.59
5457.7
5114.57
4795.11
4497.48
4219.99
3961.19
3719.85
3494.45
3283.91
3087.24
2903.42
2731.38
2570.43
2419.7
2278.57
2146.35
2022.38
1906.18
1797.21
1694.92
1598.91
1508.76
1424.05
1344.43
1269.61
1199.28
1133.1
1070.82
1012.21
957.044
905.087
856.149
810.006
766.531
725.558
686.908
650.451
616.076
583.614
552.974
524.041
496.714
470.906
446.53
423.478
401.689
381.113
361.664
343.269
325.868
309.398
293.804
279.04
265.062
251.822
239.286
227.404
216.146
205.473
195.353
185.766
176.665
168.035
159.848
152.086
144.716
137.722
131.081
124.775
118.788
113.101
107.7
102.571
97.6951
93.0629
88.6594
84.4754
80.4949
76.7113
73.1137
69.6932
66.4384
63.3427
60.3976
57.5932
54.9269
52.3889
49.9726
47.6719
45.481
43.3958
41.41
39.518
37.7165
35.9996
34.3641
32.8056
31.3208
29.9045
28.5556
27.27
26.0436
24.8748
23.76
22.6974
21.6832
20.7161
19.7937
18.9134
18.0745
17.2727
16.5087
15.7791
15.0826
14.4185
13.7841
13.1784
12.5999
12.0481
11.5212
11.0178
10.5368
10.0776
9.63935
9.22012
8.81994
8.43806
8.07331
7.725
7.39202
7.07368
6.76924
6.47831
6.20044
5.93466
5.68013
5.43744
5.20527
4.98308
4.77094
4.56768
4.37346
4.18748
4.01012
3.83996
3.67733
3.52161
3.37297
3.2305
3.09434
2.96377
2.83897
2.71963
2.6053
2.49609
2.39144
2.29138
2.19549
2.10359
2.01566
1.93154
1.85101
1.77377
1.69997
1.62915
1.56137
1.49655
1.43454
1.37501
1.31806
1.26349
1.21115
1.16104
1.1131
1.06714
1.02319
0.980885
0.940506
0.90178
0.864706
0.8291
0.795017
0.762388
0.731129
0.701157
0.672433
0.644829
0.618365
0.593099
0.56887
0.545696
0.523329
0.502022
0.481575
0.461933
0.443075
0.425056
0.407747
0.391173
0.375273
0.35996
0.345376
0.33129
0.317888
0.305014
0.292673
0.280822
0.269394
0.258522
0.248063
0.238008
0.228405
0.219202
0.210368
0.201864
0.193725
0.185903
0.178434
0.17125
0.16434
0.157711
0.151382
0.145277
0.139451
0.13385
0.128453
0.123314
0.118364
0.113607
0.109067
0.104685
0.10049
0.0964625
0.0925924
0.0889038
0.0853352
0.0819322
0.0786523
0.0755121
0.0724907
0.0696019
0.0668173
0.0641467
0.0615877
0.0591339
0.0567862
0.0545261
0.0523596
0.0503012
0.048281
0.0463696
0.0445134
0.0427629
0.0410576
0.0394348
0.0378668
0.0363769
0.0349312
0.0335377
0.0322034
0.030943
0.029726
0.0285517
0.027428
0.0263517
0.0253066
0.0243123
0.0233662
0.0224423
0.021563
0.02071
0.0199027
0.019123
0.018371
0.0176562
0.0169766
0.0163167
0.0156742
0.0150733
0.0144807
0.0139172
0.0133728
0.0128574
0.0123556
0.0118898
0.0114325
0.0109845
0.0105625
0.0101555
0.00976228
0.00939761
0.00904117
0.0086931
0.00836674
0.00804485
0.00774056
0.00745105
0.00716341
0.00689666
0.00663837
0.00638786
0.0061505
0.00592202
0.00570274
0.00549705
0.00528993
0.00509349
0.00490535
0.00473019
0.00455528
0.00438924
0.00422735
0.00407817
0.00392684
0.0037833
0.00364893
0.00351683
0.00339254
0.00327194
0.00315613
0.00304356
0.00293407
0.00283251
0.00273307
0.00263919
0.00254875
0.00246127
0.0023765
0.00229603
0.00221815
0.00214335
0.00206983
0.00200048
0.00193458
0.00186904
0.00180901
0.00174723
0.00169226
0.001637
0.00158179
0.00153255
0.00148276
0.00143713
0.0013915
0.00134637
0.00130486
0.0012637
0.00122757
0.00118874
0.0011504
0.00111735
0.00108154
0.00104862
0.00101935
0.000988923
0.000960252
0.00093141
0.000903866
0.000878356
0.000853453
0.000827215
0.000803581
0.000782643
0.000760103
0.000739048
0.000718424
0.000698606
0.000678461
0.000660542
0.000641875
0.000624269
0.00060734
0.000592116
0.000577136
0.000561861
0.000546885
0.00053333
0.000519737
0.000506077
0.000492634
0.000479982
0.000467659
0.000455963
0.000442856
0.000432194
0.000421502
0.000410793
0.000400616
0.000391827
0.000382203
0.000372312
0.000363635
0.000354807
0.000346683
0.000338851
0.000331158
0.000322919
0.000315322
0.000308096
0.000300936
0.000294161
0.000287886
0.00028121
0.000274682
0.00026844
0.00026263
0.000256791
0.000250888
0.000245954
0.000240712
0.000235374
0.000230329
0.000225437
0.000220439
0.000216414
0.000211611
0.000207383
0.000203174
0.000199157
0.000194643
0.000190399
0.000187064
0.000183586
0.00017909
0.000176096
0.000172516
0.000169191
0.000165981
0.00016296
0.000159766
0.000156856
0.000153806
0.000151083
0.00014881

nnパッケージ


In [82]:
import torch
from torch.autograd import Variable

N, D_in, H, D_out = 64, 1000, 100, 10

x = Variable(torch.randn(N, D_in))
y = Variable(torch.randn(N, D_out), requires_grad=False)

model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)

loss_fn = torch.nn.MSELoss(size_average=False)

learning_rate = 1e-4
for t in range(500):
    # forward pass
    y_pred = model(x)
    
    # compute loss
    loss = loss_fn(y_pred, y)
    print(t, loss.data[0])
    
    # backwardする前に勾配をクリア
    model.zero_grad()
    
    # backward pass
    loss.backward()
    
    # update the weights
    # paramはVariableなのでdataでTensorを取り出して更新する
    for param in model.parameters():
        param.data -= learning_rate * param.grad.data


0 675.65478515625
1 625.3120727539062
2 581.5741577148438
3 543.2457275390625
4 509.56689453125
5 479.6416931152344
6 452.5718994140625
7 427.7896728515625
8 404.82525634765625
9 383.4107971191406
10 363.2756652832031
11 344.36639404296875
12 326.5947265625
13 309.7552795410156
14 293.7693786621094
15 278.57525634765625
16 264.117919921875
17 250.3040008544922
18 237.11444091796875
19 224.51162719726562
20 212.4835205078125
21 200.97837829589844
22 190.0261993408203
23 179.53216552734375
24 169.5443115234375
25 159.99258422851562
26 150.90789794921875
27 142.3028564453125
28 134.12548828125
29 126.3597183227539
30 118.99972534179688
31 112.01719665527344
32 105.41255187988281
33 99.15910339355469
34 93.26506042480469
35 87.70309448242188
36 82.46214294433594
37 77.52316284179688
38 72.86669921875
39 68.46904754638672
40 64.32349395751953
41 60.43075180053711
42 56.78474044799805
43 53.35789489746094
44 50.141658782958984
45 47.12574005126953
46 44.28652572631836
47 41.62451171875
48 39.132266998291016
49 36.79706954956055
50 34.6125602722168
51 32.56613540649414
52 30.64849853515625
53 28.854799270629883
54 27.171537399291992
55 25.593002319335938
56 24.115636825561523
57 22.732067108154297
58 21.4384765625
59 20.219188690185547
60 19.07802391052246
61 18.009052276611328
62 17.00579071044922
63 16.06410789489746
64 15.18037223815918
65 14.349960327148438
66 13.570079803466797
67 12.837751388549805
68 12.148946762084961
69 11.50184154510498
70 10.893783569335938
71 10.322190284729004
72 9.783084869384766
73 9.274497985839844
74 8.79552173614502
75 8.34421443939209
76 7.918089866638184
77 7.51570987701416
78 7.135971546173096
79 6.777321815490723
80 6.438596725463867
81 6.1185407638549805
82 5.815964221954346
83 5.529472827911377
84 5.2586541175842285
85 5.002779960632324
86 4.761168003082275
87 4.5323076248168945
88 4.315114498138428
89 4.109409809112549
90 3.914238929748535
91 3.729203939437866
92 3.55362868309021
93 3.387044668197632
94 3.2288804054260254
95 3.078555107116699
96 2.935783624649048
97 2.799914836883545
98 2.670793294906616
99 2.5480520725250244
100 2.4313580989837646
101 2.3205208778381348
102 2.2152884006500244
103 2.1150104999542236
104 2.0196075439453125
105 1.9287279844284058
106 1.842155933380127
107 1.7597295045852661
108 1.6812456846237183
109 1.6064707040786743
110 1.5348340272903442
111 1.4666717052459717
112 1.4016685485839844
113 1.3397374153137207
114 1.280611276626587
115 1.2242895364761353
116 1.1705615520477295
117 1.1192907094955444
118 1.0703858137130737
119 1.0237667560577393
120 0.9792618751525879
121 0.9368405342102051
122 0.8962995409965515
123 0.8576265573501587
124 0.8207319974899292
125 0.785493791103363
126 0.7518104910850525
127 0.7196504473686218
128 0.6889349818229675
129 0.6595973968505859
130 0.6315714120864868
131 0.6048287749290466
132 0.5792746543884277
133 0.5548452138900757
134 0.5314984917640686
135 0.5092183947563171
136 0.48789167404174805
137 0.46749186515808105
138 0.44798561930656433
139 0.42932364344596863
140 0.41148555278778076
141 0.3944176733493805
142 0.378089964389801
143 0.3624555468559265
144 0.34751299023628235
145 0.33320820331573486
146 0.3195238709449768
147 0.30642589926719666
148 0.2938777804374695
149 0.28187739849090576
150 0.2703791856765747
151 0.2593702971935272
152 0.24883301556110382
153 0.23874923586845398
154 0.22908471524715424
155 0.2198312133550644
156 0.2109670639038086
157 0.20247633755207062
158 0.1943383663892746
159 0.18655134737491608
160 0.17907896637916565
161 0.17192113399505615
162 0.16506071388721466
163 0.15848278999328613
164 0.15218143165111542
165 0.14614365994930267
166 0.14035852253437042
167 0.13480634987354279
168 0.1294853240251541
169 0.12438629567623138
170 0.11949548125267029
171 0.1148068755865097
172 0.11031004041433334
173 0.10599800199270248
174 0.10186056792736053
175 0.09788781404495239
176 0.0940728411078453
177 0.09041343629360199
178 0.08690321445465088
179 0.08353481441736221
180 0.080302894115448
181 0.07720248401165009
182 0.07422599196434021
183 0.07136943191289902
184 0.06862975656986237
185 0.06599953770637512
186 0.06347211450338364
187 0.061046384274959564
188 0.05871780961751938
189 0.05648224055767059
190 0.05433623120188713
191 0.05227389931678772
192 0.05029395595192909
193 0.048391636461019516
194 0.0465649738907814
195 0.044810257852077484
196 0.04312582314014435
197 0.041506603360176086
198 0.039953526109457016
199 0.03846193477511406
200 0.037026483565568924
201 0.035646989941596985
202 0.034322481602430344
203 0.03304952383041382
204 0.03182518854737282
205 0.03064839169383049
206 0.02951742336153984
207 0.028429796919226646
208 0.027383876964449883
209 0.026378002017736435
210 0.025411011651158333
211 0.024480892345309258
212 0.02358631044626236
213 0.02272668667137623
214 0.021898888051509857
215 0.02110256813466549
216 0.020336996763944626
217 0.019600175321102142
218 0.018891675397753716
219 0.018209509551525116
220 0.01755305379629135
221 0.01692158356308937
222 0.01631336845457554
223 0.015728304162621498
224 0.015165157616138458
225 0.014623591676354408
226 0.014101977460086346
227 0.013599781319499016
228 0.013116545043885708
229 0.012651651166379452
230 0.012203389778733253
231 0.011771902441978455
232 0.011356188915669918
233 0.010955728590488434
234 0.010570071637630463
235 0.010198958218097687
236 0.009841236285865307
237 0.00949665904045105
238 0.009165083989501
239 0.008845184929668903
240 0.008536982350051403
241 0.008240079507231712
242 0.007953978143632412
243 0.0076781571842730045
244 0.007412464823573828
245 0.007156394887715578
246 0.006909938063472509
247 0.0066719455644488335
248 0.006442675832659006
249 0.006221471354365349
250 0.006008249707520008
251 0.005802683066576719
252 0.005604545585811138
253 0.00541332783177495
254 0.005229064263403416
255 0.005051286891102791
256 0.00487998453900218
257 0.004714705049991608
258 0.0045552426017820835
259 0.004401328973472118
260 0.00425296276807785
261 0.004109869245439768
262 0.0039718556217849255
263 0.0038384953513741493
264 0.0037100135814398527
265 0.003585967468097806
266 0.0034662478137761354
267 0.0033507714979350567
268 0.0032392472494393587
269 0.003131563775241375
270 0.003027670318260789
271 0.0029273247346282005
272 0.002830467652529478
273 0.002736972877755761
274 0.0026466655544936657
275 0.002559497021138668
276 0.0024753850884735584
277 0.0023941511753946543
278 0.0023157005198299885
279 0.0022398876026272774
280 0.0021667135879397392
281 0.0020959926769137383
282 0.002027706243097782
283 0.0019617509096860886
284 0.0018980937311425805
285 0.001836535520851612
286 0.0017771474085748196
287 0.0017198980785906315
288 0.0016645110445097089
289 0.0016109741991385818
290 0.0015592684503644705
291 0.0015092750545591116
292 0.0014609573408961296
293 0.0014143000589683652
294 0.0013692221837118268
295 0.0013255940284579992
296 0.0012833960354328156
297 0.001242607831954956
298 0.0012032317463308573
299 0.001165146240964532
300 0.0011282827472314239
301 0.0010926255490630865
302 0.0010581769747659564
303 0.0010248497128486633
304 0.0009926458587870002
305 0.0009614650625735521
306 0.000931316229980439
307 0.0009021806181408465
308 0.0008739638142287731
309 0.0008466633153147995
310 0.0008202516473829746
311 0.0007947049452923238
312 0.0007699744310230017
313 0.0007460600463673472
314 0.0007229230250231922
315 0.0007005200604908168
316 0.0006788333994336426
317 0.0006578530301339924
318 0.0006375610246323049
319 0.0006179025513119996
320 0.0005988873890601099
321 0.0005804764223285019
322 0.0005626647616736591
323 0.0005454242927953601
324 0.0005287190433591604
325 0.0005125394091010094
326 0.0004968796856701374
327 0.00048173224786296487
328 0.0004670596099458635
329 0.00045284797670319676
330 0.00043908556108362973
331 0.0004257683758623898
332 0.00041286146733909845
333 0.0004003603244200349
334 0.0003882618620991707
335 0.0003765389265026897
336 0.0003651783335953951
337 0.00035416873288340867
338 0.00034350602072663605
339 0.00033317855559289455
340 0.00032317329896613955
341 0.00031347450567409396
342 0.00030409282771870494
343 0.0002949955814983696
344 0.00028616978670470417
345 0.0002776322071440518
346 0.0002693466085474938
347 0.00026132422499358654
348 0.0002535436942707747
349 0.0002460115938447416
350 0.00023870430595707148
351 0.00023162650177255273
352 0.00022476111189462245
353 0.00021810887847095728
354 0.00021166047372389585
355 0.00020541061530821025
356 0.0001993484765989706
357 0.0001934729953063652
358 0.00018778124649543315
359 0.00018225608801003546
360 0.00017690236563794315
361 0.00017170653154607862
362 0.00016667389718350023
363 0.00016179443628061563
364 0.00015706251724623144
365 0.00015247023839037865
366 0.00014801543147768825
367 0.00014369473501574248
368 0.00013950809079688042
369 0.00013544103421736509
370 0.00013150025915820152
371 0.0001276823750231415
372 0.0001239723205799237
373 0.0001203784195240587
374 0.00011689134407788515
375 0.00011350496060913429
376 0.00011022111721104011
377 0.00010703295265557244
378 0.00010394264973001555
379 0.00010094643221236765
380 9.803617285797372e-05
381 9.521000902168453e-05
382 9.247224807040766e-05
383 8.98118523764424e-05
384 8.723392966203392e-05
385 8.472863555653021e-05
386 8.230158709920943e-05
387 7.994394400157034e-05
388 7.765409100102261e-05
389 7.543452375102788e-05
390 7.327692583203316e-05
391 7.118735084077343e-05
392 6.915871927049011e-05
393 6.718366785207763e-05
394 6.52735834592022e-05
395 6.341632251860574e-05
396 6.161308556329459e-05
397 5.986208998365328e-05
398 5.816398697788827e-05
399 5.651264655170962e-05
400 5.491240881383419e-05
401 5.3354891861090437e-05
402 5.1843919209204614e-05
403 5.037944356445223e-05
404 4.895559686701745e-05
405 4.75741908303462e-05
406 4.623055428965017e-05
407 4.492874722927809e-05
408 4.366083157947287e-05
409 4.243276998749934e-05
410 4.123699545743875e-05
411 4.007991083199158e-05
412 3.895299232681282e-05
413 3.786019442486577e-05
414 3.679931614897214e-05
415 3.576635936042294e-05
416 3.476571146165952e-05
417 3.378982000867836e-05
418 3.2846386602614075e-05
419 3.192783333361149e-05
420 3.1035531719680876e-05
421 3.0167444492690265e-05
422 2.932778443209827e-05
423 2.8510778065538034e-05
424 2.7715730539057404e-05
425 2.694242175493855e-05
426 2.619397491798736e-05
427 2.546514770074282e-05
428 2.475718611094635e-05
429 2.4069044229690917e-05
430 2.3402219085255638e-05
431 2.2752899894840084e-05
432 2.212225263065193e-05
433 2.1510073565877974e-05
434 2.0914008928230032e-05
435 2.0334855435066856e-05
436 1.9771716324612498e-05
437 1.922602314152755e-05
438 1.8695414837566204e-05
439 1.8179316612076946e-05
440 1.7677652067504823e-05
441 1.718914063530974e-05
442 1.6715950550860725e-05
443 1.625478580535855e-05
444 1.5807947420398705e-05
445 1.537232856207993e-05
446 1.4950637705624104e-05
447 1.4539710718963761e-05
448 1.414019880030537e-05
449 1.375093142996775e-05
450 1.3374229638429824e-05
451 1.3008473615627736e-05
452 1.2651163160626311e-05
453 1.2304378287808504e-05
454 1.1967954378633294e-05
455 1.163997694675345e-05
456 1.1321482816128992e-05
457 1.1011932656401768e-05
458 1.0711302820709534e-05
459 1.0418658348498866e-05
460 1.0134505828318652e-05
461 9.85783390206052e-06
462 9.589181900082622e-06
463 9.328544365416747e-06
464 9.074116860574577e-06
465 8.8266624516109e-06
466 8.586634976381902e-06
467 8.352999429916963e-06
468 8.126248758344445e-06
469 7.905037819000427e-06
470 7.69027610658668e-06
471 7.4816721280512866e-06
472 7.278579232661286e-06
473 7.080734121700516e-06
474 6.888229563628556e-06
475 6.702490736643085e-06
476 6.520349870697828e-06
477 6.344431767502101e-06
478 6.172551366034895e-06
479 6.004639544698875e-06
480 5.8424197959539015e-06
481 5.684456482413225e-06
482 5.53123391000554e-06
483 5.381974006013479e-06
484 5.236488050286425e-06
485 5.094996595289558e-06
486 4.958324552717386e-06
487 4.824224561161827e-06
488 4.694158633355983e-06
489 4.567688392853597e-06
490 4.4444996092352085e-06
491 4.32559954788303e-06
492 4.208764948998578e-06
493 4.0954282667371444e-06
494 3.985463990829885e-06
495 3.878044481098186e-06
496 3.774114020416164e-06
497 3.6722524328069994e-06
498 3.574280299289967e-06
499 3.4780355235852767e-06

Optimizer


In [83]:
import torch
from torch.autograd import Variable

N, D_in, H, D_out = 64, 1000, 100, 10

x = Variable(torch.randn(N, D_in))
y = Variable(torch.randn(N, D_out), requires_grad=False)

model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)

loss_fn = torch.nn.MSELoss(size_average=False)

learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for t in range(500):
    # forward pass
    y_pred = model(x)
    
    # compute loss
    loss = loss_fn(y_pred, y)
    print(t, loss.data[0])
    
    # backwardする前に勾配をクリア
    optimizer.zero_grad()
    
    # backward pass
    loss.backward()
    
    # update the weights
    optimizer.step()


0 630.1945190429688
1 613.5613403320312
2 597.468505859375
3 581.8590698242188
4 566.6857299804688
5 551.9625244140625
6 537.625244140625
7 523.65478515625
8 510.0739440917969
9 496.94525146484375
10 484.1810302734375
11 471.8095397949219
12 459.8274230957031
13 448.1696472167969
14 436.78521728515625
15 425.6701354980469
16 414.8320617675781
17 404.3284606933594
18 394.0894470214844
19 384.1022644042969
20 374.3945007324219
21 364.9898986816406
22 355.8160400390625
23 346.88946533203125
24 338.2266845703125
25 329.8339538574219
26 321.614501953125
27 313.6141357421875
28 305.8409423828125
29 298.2760925292969
30 290.8974304199219
31 283.6495666503906
32 276.5545349121094
33 269.619873046875
34 262.8145446777344
35 256.151123046875
36 249.61766052246094
37 243.227783203125
38 236.96031188964844
39 230.82102966308594
40 224.8285369873047
41 218.98483276367188
42 213.2899627685547
43 207.694580078125
44 202.20387268066406
45 196.83082580566406
46 191.58071899414062
47 186.45713806152344
48 181.4352264404297
49 176.52072143554688
50 171.6978302001953
51 166.96670532226562
52 162.33815002441406
53 157.80386352539062
54 153.37136840820312
55 149.03167724609375
56 144.779052734375
57 140.61338806152344
58 136.5396728515625
59 132.5596466064453
60 128.67791748046875
61 124.87992858886719
62 121.16690063476562
63 117.5396499633789
64 113.9938735961914
65 110.52212524414062
66 107.1170654296875
67 103.80551147460938
68 100.58183288574219
69 97.43888854980469
70 94.37068176269531
71 91.3759994506836
72 88.45703125
73 85.61814880371094
74 82.8473892211914
75 80.14981079101562
76 77.52640533447266
77 74.9667739868164
78 72.46195983886719
79 70.02333068847656
80 67.65304565429688
81 65.34642791748047
82 63.107177734375
83 60.9239501953125
84 58.80186080932617
85 56.73649597167969
86 54.72855758666992
87 52.77726364135742
88 50.882015228271484
89 49.04623031616211
90 47.26160430908203
91 45.525779724121094
92 43.84543991088867
93 42.2119255065918
94 40.62570571899414
95 39.09185028076172
96 37.601219177246094
97 36.160648345947266
98 34.767574310302734
99 33.4160270690918
100 32.11177062988281
101 30.846912384033203
102 29.621932983398438
103 28.43855094909668
104 27.294870376586914
105 26.18939971923828
106 25.12156105041504
107 24.089462280273438
108 23.092212677001953
109 22.130855560302734
110 21.204469680786133
111 20.309350967407227
112 19.447277069091797
113 18.614978790283203
114 17.814571380615234
115 17.043516159057617
116 16.302490234375
117 15.5911226272583
118 14.906179428100586
119 14.246766090393066
120 13.612051010131836
121 13.001005172729492
122 12.41418743133545
123 11.849424362182617
124 11.306818008422852
125 10.786190032958984
126 10.285964012145996
127 9.806589126586914
128 9.345783233642578
129 8.904109954833984
130 8.481026649475098
131 8.075965881347656
132 7.688319683074951
133 7.316683769226074
134 6.9603495597839355
135 6.619009017944336
136 6.292989253997803
137 5.981649398803711
138 5.68381929397583
139 5.399041175842285
140 5.1266961097717285
141 4.866694927215576
142 4.618686676025391
143 4.381949424743652
144 4.156022548675537
145 3.940596342086792
146 3.735361099243164
147 3.5398104190826416
148 3.3535878658294678
149 3.1763265132904053
150 3.0077478885650635
151 2.8473355770111084
152 2.6944327354431152
153 2.5493741035461426
154 2.4113717079162598
155 2.280428409576416
156 2.1560885906219482
157 2.037994146347046
158 1.9258633852005005
159 1.8192625045776367
160 1.7181320190429688
161 1.6223467588424683
162 1.5313549041748047
163 1.4451098442077637
164 1.3634165525436401
165 1.2859675884246826
166 1.2126035690307617
167 1.1432349681854248
168 1.077603816986084
169 1.0154258012771606
170 0.9566268920898438
171 0.9009003639221191
172 0.8483271598815918
173 0.7986063361167908
174 0.7516224384307861
175 0.707212507724762
176 0.6652708649635315
177 0.625709056854248
178 0.5884259343147278
179 0.5531680583953857
180 0.5198379755020142
181 0.4884292185306549
182 0.4588363766670227
183 0.43094804883003235
184 0.4046817123889923
185 0.3799101710319519
186 0.3565402030944824
187 0.3345586359500885
188 0.31385868787765503
189 0.2944129407405853
190 0.2760780453681946
191 0.2588368058204651
192 0.24260365962982178
193 0.2273435890674591
194 0.21299250423908234
195 0.19950471818447113
196 0.18682949244976044
197 0.17492569983005524
198 0.1637284755706787
199 0.15323537588119507
200 0.14338171482086182
201 0.13414250314235687
202 0.12545029819011688
203 0.11731269210577011
204 0.10968709737062454
205 0.10252837091684341
206 0.09581895917654037
207 0.08953533321619034
208 0.08363942801952362
209 0.07812152057886124
210 0.0729486420750618
211 0.06810378283262253
212 0.06357648968696594
213 0.059324365109205246
214 0.05535179004073143
215 0.05163753405213356
216 0.04816288873553276
217 0.04491092264652252
218 0.04187195375561714
219 0.03905327618122101
220 0.03642173856496811
221 0.03396441042423248
222 0.03166750818490982
223 0.029525145888328552
224 0.02752656862139702
225 0.025663506239652634
226 0.0239203330129385
227 0.02229551412165165
228 0.020781464874744415
229 0.019368035718798637
230 0.018050743266940117
231 0.01682242751121521
232 0.015676243230700493
233 0.014608747325837612
234 0.013611276634037495
235 0.01268316712230444
236 0.011817984282970428
237 0.011011829599738121
238 0.010260993614792824
239 0.009562554769217968
240 0.008913544937968254
241 0.00830860249698162
242 0.00774555467069149
243 0.007221052423119545
244 0.006733111571520567
245 0.006278149783611298
246 0.005854339804500341
247 0.0054598841816186905
248 0.005092403385788202
249 0.004750038031488657
250 0.004431240726262331
251 0.00413420470431447
252 0.0038575211074203253
253 0.0035998765379190445
254 0.0033595950808376074
255 0.0031358206178992987
256 0.002927411813288927
257 0.0027332925237715244
258 0.002552251098677516
259 0.002383608603850007
260 0.0022267410531640053
261 0.002080458216369152
262 0.0019437019946053624
263 0.0018165220972150564
264 0.0016979664796963334
265 0.0015874411910772324
266 0.0014843952376395464
267 0.0013882778584957123
268 0.0012986381771042943
269 0.0012150410329923034
270 0.001137156505137682
271 0.0010644937865436077
272 0.0009966132929548621
273 0.0009332672343589365
274 0.0008741738274693489
275 0.000819024455267936
276 0.0007675326196476817
277 0.0007194558274932206
278 0.0006745619466528296
279 0.0006326131406240165
280 0.0005934522487223148
281 0.0005568322376348078
282 0.0005226214416325092
283 0.0004906183457933366
284 0.00046070240205153823
285 0.0004327188653405756
286 0.0004065471584908664
287 0.000382053607609123
288 0.00035913227475248277
289 0.00033765280386433005
290 0.0003175501769874245
291 0.0002987212792504579
292 0.00028106942772865295
293 0.0002645330096129328
294 0.000249034957960248
295 0.00023449136642739177
296 0.00022085418459028006
297 0.00020806473912671208
298 0.00019604936824180186
299 0.00018477090634405613
300 0.00017418162315152586
301 0.0001642337447265163
302 0.0001548907021060586
303 0.00014609952631872147
304 0.00013784288603346795
305 0.0001300732110394165
306 0.00012276857160031796
307 0.00011588576307985932
308 0.00010941663640551269
309 0.0001033254447975196
310 9.758208761923015e-05
311 9.217702609021217e-05
312 8.70811563800089e-05
313 8.227964281104505e-05
314 7.774641562718898e-05
315 7.347796781687066e-05
316 6.945259519852698e-05
317 6.565266085090116e-05
318 6.20685750618577e-05
319 5.86816095164977e-05
320 5.548636181629263e-05
321 5.2469280490186065e-05
322 4.9619251512922347e-05
323 4.6928202209528536e-05
324 4.438465475686826e-05
325 4.197944508632645e-05
326 3.9707803807687014e-05
327 3.755931174964644e-05
328 3.5530545574147254e-05
329 3.3610824175411835e-05
330 3.179489431204274e-05
331 3.0078113923082128e-05
332 2.8454483981477097e-05
333 2.6918640287476592e-05
334 2.546560972405132e-05
335 2.4090297301881947e-05
336 2.278858664794825e-05
337 2.155650145141408e-05
338 2.0392402802826837e-05
339 1.928773235704284e-05
340 1.8244307284476236e-05
341 1.725783477013465e-05
342 1.6322137526003644e-05
343 1.5435520253959112e-05
344 1.4598536836274434e-05
345 1.3805557500745635e-05
346 1.3054035662207752e-05
347 1.2343106391199399e-05
348 1.1671763786580414e-05
349 1.103409977076808e-05
350 1.0432370800117496e-05
351 9.860838872555178e-06
352 9.32103557715891e-06
353 8.809937753539998e-06
354 8.32551177154528e-06
355 7.867448402976152e-06
356 7.434103736159159e-06
357 7.024054866633378e-06
358 6.635828412981937e-06
359 6.26808287051972e-06
360 5.9207195590715855e-06
361 5.5918467296578456e-06
362 5.2803411563218106e-06
363 4.9863592721521854e-06
364 4.7079411160666496e-06
365 4.4439311750466e-06
366 4.19533080275869e-06
367 3.9599281080882065e-06
368 3.736906819540309e-06
369 3.5268828924017726e-06
370 3.3277940474363277e-06
371 3.138963165838504e-06
372 2.9612897378683556e-06
373 2.7932592274737544e-06
374 2.634071506690816e-06
375 2.4840735477482667e-06
376 2.3421430341841187e-06
377 2.208300429629162e-06
378 2.081590764646535e-06
379 1.961908765224507e-06
380 1.8485781083654729e-06
381 1.7420876474716351e-06
382 1.641558469600568e-06
383 1.5461259863513988e-06
384 1.4563238437403925e-06
385 1.3714890201299568e-06
386 1.2913695854877005e-06
387 1.2158698154962622e-06
388 1.1448250916146208e-06
389 1.077567048923811e-06
390 1.0144253792532254e-06
391 9.543064152239822e-07
392 8.977622769634763e-07
393 8.445140906587767e-07
394 7.946380833345756e-07
395 7.472024776689068e-07
396 7.025252557468775e-07
397 6.60717432765523e-07
398 6.211953404999804e-07
399 5.839162326992664e-07
400 5.488516876539506e-07
401 5.15746933160699e-07
402 4.843615215577302e-07
403 4.550025209937303e-07
404 4.2749755380100396e-07
405 4.014368073512742e-07
406 3.7699780364164326e-07
407 3.5385005503485445e-07
408 3.3234056218134356e-07
409 3.118458380413358e-07
410 2.9269233436934883e-07
411 2.7469002361613093e-07
412 2.5763148414625903e-07
413 2.4161826672752795e-07
414 2.2663171250769665e-07
415 2.1254840021356358e-07
416 1.9919065152862458e-07
417 1.8681191704672528e-07
418 1.7505854543742316e-07
419 1.6412582226621453e-07
420 1.5381657192392595e-07
421 1.4410596804736997e-07
422 1.3501914963853778e-07
423 1.264213409513104e-07
424 1.1836489477445866e-07
425 1.1088091866895411e-07
426 1.0382520088114688e-07
427 9.718741011965903e-08
428 9.096545738884743e-08
429 8.514714266993906e-08
430 7.968642989908403e-08
431 7.449838790307695e-08
432 6.973886002015206e-08
433 6.523558937487905e-08
434 6.09584986932532e-08
435 5.7028831434990934e-08
436 5.32774855344087e-08
437 4.981649581736747e-08
438 4.662164698743254e-08
439 4.3566657836890954e-08
440 4.071048920195608e-08
441 3.804375126037485e-08
442 3.550315241795943e-08
443 3.314581320523757e-08
444 3.098860901218359e-08
445 2.895527906332518e-08
446 2.7060053753302782e-08
447 2.5250475488292068e-08
448 2.358770068155991e-08
449 2.2054702952800653e-08
450 2.0597271443989484e-08
451 1.9234288828329227e-08
452 1.7958790010652592e-08
453 1.6795670632063775e-08
454 1.568242424809796e-08
455 1.4611400978026268e-08
456 1.3645490071212407e-08
457 1.2746747657388369e-08
458 1.1898390717135499e-08
459 1.1114631881525838e-08
460 1.0367790181931014e-08
461 9.673892797934514e-09
462 9.037096404540534e-09
463 8.420978581114014e-09
464 7.864539242063984e-09
465 7.342435104362721e-09
466 6.871385238582661e-09
467 6.400473040457655e-09
468 5.988429307990373e-09
469 5.585047979650426e-09
470 5.225257559970942e-09
471 4.872999781468934e-09
472 4.544762788327716e-09
473 4.242266093967828e-09
474 3.969378603585483e-09
475 3.699218042996222e-09
476 3.447852225946235e-09
477 3.221829691923972e-09
478 3.0183739951894495e-09
479 2.8199333979017638e-09
480 2.628056661180267e-09
481 2.455221803643326e-09
482 2.3040063190649107e-09
483 2.1539248162838476e-09
484 2.019747036285935e-09
485 1.890100964629937e-09
486 1.7747944225376955e-09
487 1.6616578113470837e-09
488 1.5563869082413362e-09
489 1.4517328450480704e-09
490 1.3655054864614158e-09
491 1.2798022641646867e-09
492 1.2051144526736834e-09
493 1.1355625328945962e-09
494 1.0623663060371769e-09
495 1.003696459278558e-09
496 9.517264754066446e-10
497 8.859022404550387e-10
498 8.406968454721664e-10
499 7.958352865600204e-10

Custom nn Modeules


In [84]:
import torch
from torch.autograd import Variable

# 独自モデルの定義
class TwoLayerNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        super(TwoLayerNet, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)
    
    def forward(self, x):
        h_relu = self.linear1(x).clamp(min=0)
        y_pred = self.linear2(h_relu)
        return y_pred

N, D_in, H, D_out = 64, 1000, 100, 10

x = Variable(torch.randn(N, D_in))
y = Variable(torch.randn(N, D_out), requires_grad=False)

model = TwoLayerNet(D_in, H, D_out)

criterion = torch.nn.MSELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
for t in range(500):
    # forward pass
    y_pred = model(x)
    
    # compute loss
    loss = criterion(y_pred, y)
    print(t, loss.data[0])
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


0 628.8768920898438
1 582.905029296875
2 543.0150146484375
3 507.8055114746094
4 476.51544189453125
5 448.16473388671875
6 422.3822326660156
7 398.69195556640625
8 376.92254638671875
9 356.5376281738281
10 337.4390563964844
11 319.57373046875
12 302.7062072753906
13 286.7491760253906
14 271.6932373046875
15 257.373779296875
16 243.7882080078125
17 230.88551330566406
18 218.60121154785156
19 206.88059997558594
20 195.67245483398438
21 184.96205139160156
22 174.74896240234375
23 165.0063934326172
24 155.72891235351562
25 146.8719024658203
26 138.43077087402344
27 130.381103515625
28 122.75699615478516
29 115.50647735595703
30 108.64201354980469
31 102.15310668945312
32 96.0034408569336
33 90.20628356933594
34 84.72351837158203
35 79.5626220703125
36 74.6899185180664
37 70.1153564453125
38 65.8133773803711
39 61.767574310302734
40 57.97158432006836
41 54.41753005981445
42 51.0764274597168
43 47.945804595947266
44 45.01655197143555
45 42.27145004272461
46 39.703224182128906
47 37.3044319152832
48 35.059078216552734
49 32.95682907104492
50 30.98513412475586
51 29.141529083251953
52 27.414379119873047
53 25.799152374267578
54 24.28508186340332
55 22.86610221862793
56 21.535982131958008
57 20.28965950012207
58 19.121274948120117
59 18.025775909423828
60 16.997135162353516
61 16.032447814941406
62 15.128386497497559
63 14.279285430908203
64 13.483600616455078
65 12.740087509155273
66 12.039639472961426
67 11.382251739501953
68 10.764351844787598
69 10.184585571289062
70 9.639841079711914
71 9.127924919128418
72 8.646318435668945
73 8.192237854003906
74 7.764461994171143
75 7.361912727355957
76 6.982723236083984
77 6.624800205230713
78 6.2867536544799805
79 5.967916011810303
80 5.667055606842041
81 5.382901668548584
82 5.114840507507324
83 4.861160755157471
84 4.620994567871094
85 4.39376974105835
86 4.1792120933532715
87 3.975693702697754
88 3.7831358909606934
89 3.6005656719207764
90 3.427530527114868
91 3.2639000415802
92 3.108649253845215
93 2.961369752883911
94 2.8217713832855225
95 2.6892292499542236
96 2.563187837600708
97 2.443455696105957
98 2.329859733581543
99 2.2220115661621094
100 2.119577407836914
101 2.0220882892608643
102 1.9289782047271729
103 1.840386152267456
104 1.7561156749725342
105 1.6759458780288696
106 1.599685788154602
107 1.5273051261901855
108 1.4584053754806519
109 1.3927888870239258
110 1.3302958011627197
111 1.2709208726882935
112 1.2143330574035645
113 1.1604400873184204
114 1.1090588569641113
115 1.0600712299346924
116 1.013375163078308
117 0.9688601493835449
118 0.9264160990715027
119 0.8859167098999023
120 0.8473114371299744
121 0.8104737997055054
122 0.7752606272697449
123 0.7416868209838867
124 0.7096778154373169
125 0.6791211366653442
126 0.6499359607696533
127 0.6220793724060059
128 0.5954534411430359
129 0.5700303316116333
130 0.545718252658844
131 0.5225005745887756
132 0.500317394733429
133 0.47912320494651794
134 0.4588737189769745
135 0.4395332634449005
136 0.4210309088230133
137 0.40334346890449524
138 0.38645002245903015
139 0.3703002631664276
140 0.3548395335674286
141 0.3400569260120392
142 0.3259228467941284
143 0.31241491436958313
144 0.29947420954704285
145 0.2870945334434509
146 0.2752539813518524
147 0.26391682028770447
148 0.2531098425388336
149 0.24276626110076904
150 0.23286984860897064
151 0.223390594124794
152 0.21430960297584534
153 0.20561061799526215
154 0.19728820025920868
155 0.1893206536769867
156 0.18168362975120544
157 0.17436788976192474
158 0.16736282408237457
159 0.1606549322605133
160 0.15422949194908142
161 0.14806926250457764
162 0.14216244220733643
163 0.13650457561016083
164 0.13108444213867188
165 0.12589198350906372
166 0.12091667950153351
167 0.11614266782999039
168 0.11157122254371643
169 0.1071794331073761
170 0.10296905785799026
171 0.09892625361680984
172 0.09505238384008408
173 0.09133795648813248
174 0.0877734124660492
175 0.08435333520174026
176 0.08107079565525055
177 0.07792165130376816
178 0.07490617036819458
179 0.07201878726482391
180 0.069246806204319
181 0.06658772379159927
182 0.0640348494052887
183 0.061581145972013474
184 0.059228185564279556
185 0.05696742609143257
186 0.05479587987065315
187 0.052710700780153275
188 0.05070723965764046
189 0.04878400266170502
190 0.046937260776758194
191 0.04516225308179855
192 0.04345610737800598
193 0.04181729257106781
194 0.04024216905236244
195 0.038728151470422745
196 0.037273891270160675
197 0.0358765535056591
198 0.03453294187784195
199 0.033241983503103256
200 0.03200026601552963
201 0.030807895585894585
202 0.029660504311323166
203 0.028557587414979935
204 0.02749626524746418
205 0.026476016268134117
206 0.02549436129629612
207 0.02455090545117855
208 0.023643119260668755
209 0.022770114243030548
210 0.021930204704403877
211 0.021123021841049194
212 0.02034611999988556
213 0.01959959976375103
214 0.018880873918533325
215 0.01818910799920559
216 0.01752306893467903
217 0.016882412135601044
218 0.016265835613012314
219 0.015672389417886734
220 0.015101303346455097
221 0.01455171313136816
222 0.01402268186211586
223 0.01351366750895977
224 0.013023781590163708
225 0.012551754713058472
226 0.012097032740712166
227 0.011659912765026093
228 0.01123831421136856
229 0.010832572355866432
230 0.010441661812365055
231 0.010066078044474125
232 0.009704281575977802
233 0.009355916641652584
234 0.009020312689244747
235 0.00869691837579012
236 0.008385451510548592
237 0.008085131645202637
238 0.007795886602252722
239 0.007517447229474783
240 0.007249071728438139
241 0.006990686524659395
242 0.006741910241544247
243 0.006501904688775539
244 0.00627076206728816
245 0.006047961302101612
246 0.0058331661857664585
247 0.005626145284622908
248 0.005426604766398668
249 0.005234295502305031
250 0.005048990715295076
251 0.004870458971709013
252 0.004698398523032665
253 0.004532509949058294
254 0.004372508265078068
255 0.004218226298689842
256 0.00406953739002347
257 0.003926203120499849
258 0.003788108704611659
259 0.0036548622883856297
260 0.0035264871548861265
261 0.003402676433324814
262 0.0032832662109285593
263 0.0031681114342063665
264 0.0030570016242563725
265 0.002949872985482216
266 0.0028465944342315197
267 0.0027469894848763943
268 0.002650908660143614
269 0.0025582527741789818
270 0.0024689065758138895
271 0.0023827922996133566
272 0.002299732994288206
273 0.0022195796482264996
274 0.0021422388963401318
275 0.002067727269604802
276 0.001995770027860999
277 0.0019263536669313908
278 0.0018593703862279654
279 0.001794783165678382
280 0.0017324825748801231
281 0.0016724382294341922
282 0.001614489359781146
283 0.0015585143119096756
284 0.0015045426553115249
285 0.0014524575090035796
286 0.0014022354735061526
287 0.0013537146151065826
288 0.0013069347478449345
289 0.001261790399439633
290 0.0012182136997580528
291 0.001176220248453319
292 0.0011356562608852983
293 0.001096562948077917
294 0.0010587762808427215
295 0.001022301148623228
296 0.000987111241556704
297 0.0009531514369882643
298 0.0009203641093336046
299 0.0008887449512258172
300 0.0008582447189837694
301 0.0008287805831059813
302 0.0008003398543223739
303 0.0007728799246251583
304 0.0007463740184903145
305 0.0007207827875390649
306 0.000696078292094171
307 0.000672238296829164
308 0.0006492208922281861
309 0.0006270170561037958
310 0.0006055858102627099
311 0.000584896479267627
312 0.000564930378459394
313 0.0005456244689412415
314 0.0005269797402434051
315 0.0005089989863336086
316 0.0004916353500448167
317 0.0004748603387270123
318 0.0004586807044688612
319 0.0004430467961356044
320 0.0004279575659893453
321 0.00041338204755447805
322 0.000399310898501426
323 0.000385710911359638
324 0.0003726017312146723
325 0.00035992389894090593
326 0.000347695080563426
327 0.0003358841349836439
328 0.0003244730760343373
329 0.0003134573344141245
330 0.00030282189254648983
331 0.00029255275148898363
332 0.00028263233252801
333 0.00027305822004564106
334 0.00026380151393823326
335 0.00025486521190032363
336 0.0002462302800267935
337 0.0002378957433393225
338 0.00022984531824477017
339 0.0002220700989710167
340 0.00021455455862451345
341 0.00020730386313516647
342 0.00020029734878335148
343 0.00019353485549800098
344 0.0001869970583356917
345 0.00018068747885990888
346 0.00017458935326430947
347 0.0001686942996457219
348 0.00016300381685141474
349 0.00015750866441521794
350 0.00015220875502564013
351 0.00014707553782500327
352 0.00014211917005013674
353 0.00013734071399085224
354 0.00013271371426526457
355 0.00012824674195144325
356 0.00012393361248541623
357 0.00011976953828707337
358 0.00011573927622521296
359 0.00011184767936356366
360 0.00010808730439748615
361 0.00010445326188346371
362 0.00010094729077536613
363 9.756282815942541e-05
364 9.428915655007586e-05
365 9.112346015172079e-05
366 8.80725565366447e-05
367 8.511535997968167e-05
368 8.226191130233929e-05
369 7.950788858579472e-05
370 7.684101728955284e-05
371 7.42672273190692e-05
372 7.1782196755521e-05
373 6.938153819646686e-05
374 6.705996202072129e-05
375 6.481441960204393e-05
376 6.26509718131274e-05
377 6.055557605577633e-05
378 5.853157927049324e-05
379 5.657784276991151e-05
380 5.468646122608334e-05
381 5.285903898766264e-05
382 5.109456833451986e-05
383 4.939021891914308e-05
384 4.7744040784891695e-05
385 4.615002035279758e-05
386 4.461118805920705e-05
387 4.312056989874691e-05
388 4.168609666521661e-05
389 4.029708725283854e-05
390 3.895556801580824e-05
391 3.76572206732817e-05
392 3.6403271224116907e-05
393 3.519393067108467e-05
394 3.402162838028744e-05
395 3.289017331553623e-05
396 3.1794152164366096e-05
397 3.073714469792321e-05
398 2.9716409699176438e-05
399 2.87269340333296e-05
400 2.7773477995651774e-05
401 2.6851466827793047e-05
402 2.5959598133340478e-05
403 2.509721707610879e-05
404 2.4263814339064993e-05
405 2.3459113435819745e-05
406 2.2681389964418486e-05
407 2.192888678109739e-05
408 2.1201736672082916e-05
409 2.049901377176866e-05
410 1.9820581655949354e-05
411 1.916450128192082e-05
412 1.8530485249357298e-05
413 1.7916505385073833e-05
414 1.7322237908956595e-05
415 1.6748395864851773e-05
416 1.619467366253957e-05
417 1.56591777340509e-05
418 1.5141456060518976e-05
419 1.4640903827967122e-05
420 1.4156329598336015e-05
421 1.368854191241553e-05
422 1.323557626164984e-05
423 1.2799095202353783e-05
424 1.2375211554171983e-05
425 1.1967484169872478e-05
426 1.157119550043717e-05
427 1.1190740224265028e-05
428 1.0821258911164477e-05
429 1.046349098032806e-05
430 1.0118698810401838e-05
431 9.786027476366144e-06
432 9.462528396397829e-06
433 9.151079211733304e-06
434 8.849759069562424e-06
435 8.557890396332368e-06
436 8.276587323052809e-06
437 8.004155461094342e-06
438 7.740901310171466e-06
439 7.486093181796605e-06
440 7.2399116106680594e-06
441 7.001933227002155e-06
442 6.77080288369325e-06
443 6.549024874402676e-06
444 6.333350938803051e-06
445 6.125902928033611e-06
446 5.924895503994776e-06
447 5.729342774429824e-06
448 5.54071402802947e-06
449 5.359770057111746e-06
450 5.183663688512752e-06
451 5.012848760088673e-06
452 4.848670869250782e-06
453 4.69002270619967e-06
454 4.536191227089148e-06
455 4.3875520532310475e-06
456 4.2432284317328595e-06
457 4.104598701815121e-06
458 3.9698511500319e-06
459 3.840378667518962e-06
460 3.7145957776374416e-06
461 3.5930158901464893e-06
462 3.475669245744939e-06
463 3.36207585860393e-06
464 3.252138412790373e-06
465 3.145035861962242e-06
466 3.042135176656302e-06
467 2.9426203127513872e-06
468 2.846916231646901e-06
469 2.7537041660252726e-06
470 2.664094836291042e-06
471 2.577009581727907e-06
472 2.4926885089371353e-06
473 2.411085915809963e-06
474 2.332766825929866e-06
475 2.256970901726163e-06
476 2.1832374841324054e-06
477 2.1121295503689907e-06
478 2.042705773419584e-06
479 1.9760439045057865e-06
480 1.912071184051456e-06
481 1.8495394442652469e-06
482 1.789323505363427e-06
483 1.731188035591913e-06
484 1.6750985878388747e-06
485 1.6204744497372303e-06
486 1.5675828990424634e-06
487 1.5165126114879968e-06
488 1.4670133623440051e-06
489 1.419428031113057e-06
490 1.3729368220083416e-06
491 1.3284346778164036e-06
492 1.285278813156765e-06
493 1.2436605629773112e-06
494 1.2030917559968657e-06
495 1.1638243222478195e-06
496 1.1262238786002854e-06
497 1.0899775588768534e-06
498 1.054150061463588e-06
499 1.0203552847087849e-06

Dynamic Model


In [85]:
import random
import torch
from torch.autograd import Variable

class DynamicNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        super(DynamicNet, self).__init__()
        self.input_linear = torch.nn.Linear(D_in, H)
        self.middle_linear = torch.nn.Linear(H, H)
        self.output_linear = torch.nn.Linear(H, D_out)
    
    def forward(self, x):
        h_relu = self.input_linear(x).clamp(min=0)
        # forward passのたびに0から2個のlinear層をランダムに追加する
        for _ in range(random.randint(0, 3)):
            h_relu = self.middle_linear(h_relu).clamp(min=0)
        y_pred = self.output_linear(h_relu)
        return y_pred

N, D_in, H, D_out = 64, 1000, 100, 10

x = Variable(torch.randn(N, D_in))
y = Variable(torch.randn(N, D_out), requires_grad=False)

model = DynamicNet(D_in, H, D_out)

criterion = torch.nn.MSELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
for t in range(500):
    # forward pass
    y_pred = model(x)
    
    # compute loss
    loss = criterion(y_pred, y)
    print(t, loss.data[0])
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


0 640.6339111328125
1 631.3048706054688
2 659.3577880859375
3 629.0328369140625
4 591.5602416992188
5 625.6828002929688
6 630.2764892578125
7 494.9000549316406
8 450.3318786621094
9 628.6624755859375
10 546.13623046875
11 627.501708984375
12 614.8070068359375
13 611.4189453125
14 515.6226196289062
15 226.6064453125
16 593.8165283203125
17 617.955322265625
18 614.0869140625
19 446.6548156738281
20 425.9581604003906
21 542.76025390625
22 525.5250854492188
23 573.9893798828125
24 480.4125671386719
25 534.5718994140625
26 507.98529052734375
27 408.9234313964844
28 390.77008056640625
29 435.8751525878906
30 421.808349609375
31 355.2283935546875
32 291.4407043457031
33 370.36181640625
34 224.03392028808594
35 202.64659118652344
36 179.03271484375
37 322.5909118652344
38 143.73648071289062
39 202.34881591796875
40 102.24514770507812
41 179.5446319580078
42 88.59590911865234
43 160.90188598632812
44 141.61976623535156
45 94.38774108886719
46 114.52020263671875
47 158.8015594482422
48 79.39371490478516
49 116.01934051513672
50 64.37411499023438
51 80.15482330322266
52 64.56410217285156
53 62.099761962890625
54 79.6227798461914
55 57.843017578125
56 74.33861541748047
57 186.13548278808594
58 133.14683532714844
59 88.22381591796875
60 57.4803352355957
61 110.00944519042969
62 81.33229064941406
63 77.33230590820312
64 72.38409423828125
65 115.05047607421875
66 52.679988861083984
67 41.9664192199707
68 29.800870895385742
69 108.56969451904297
70 40.39303970336914
71 83.87448120117188
72 63.08845138549805
73 45.89917755126953
74 39.225955963134766
75 51.602638244628906
76 40.79216003417969
77 17.08511734008789
78 28.232406616210938
79 36.58096694946289
80 23.18515968322754
81 22.41709327697754
82 19.26115608215332
83 52.43354797363281
84 21.380966186523438
85 39.96542739868164
86 80.7813949584961
87 17.51972198486328
88 39.39006805419922
89 19.12723731994629
90 39.15578842163086
91 15.568731307983398
92 19.20397186279297
93 31.091533660888672
94 14.23631763458252
95 38.91694259643555
96 27.52091407775879
97 18.257177352905273
98 11.38196086883545
99 23.20766258239746
100 18.017391204833984
101 14.954666137695312
102 11.621520042419434
103 8.696244239807129
104 13.579984664916992
105 8.875887870788574
106 15.578822135925293
107 12.324014663696289
108 11.421392440795898
109 6.599695682525635
110 9.798518180847168
111 16.59755516052246
112 11.493302345275879
113 8.739280700683594
114 34.38998031616211
115 7.232539653778076
116 13.05335807800293
117 15.574180603027344
118 12.311744689941406
119 8.564322471618652
120 11.430635452270508
121 8.69427490234375
122 24.75551986694336
123 26.471179962158203
124 16.746294021606445
125 21.24774932861328
126 10.437763214111328
127 5.116059303283691
128 6.753158092498779
129 5.536902904510498
130 19.40483856201172
131 6.1616926193237305
132 7.150239944458008
133 19.387067794799805
134 5.739974021911621
135 4.099265098571777
136 4.87880802154541
137 6.963666915893555
138 11.725775718688965
139 9.694158554077148
140 8.33269214630127
141 5.964237689971924
142 2.8801589012145996
143 16.128971099853516
144 3.6393864154815674
145 6.348008632659912
146 12.419780731201172
147 3.9682765007019043
148 7.3023858070373535
149 5.606322765350342
150 6.056889533996582
151 6.548543453216553
152 4.109741687774658
153 2.3753857612609863
154 3.042941093444824
155 4.616113662719727
156 4.168756484985352
157 14.282095909118652
158 4.939173698425293
159 1.4284929037094116
160 2.7167749404907227
161 11.325643539428711
162 49.708961486816406
163 3.0505549907684326
164 10.226749420166016
165 45.13346862792969
166 142.62899780273438
167 22.035104751586914
168 51.4554557800293
169 101.67341613769531
170 50.65621566772461
171 58.01139450073242
172 17.46959686279297
173 12.579951286315918
174 54.26023864746094
175 102.39468383789062
176 69.96268463134766
177 20.378849029541016
178 19.615020751953125
179 39.74927520751953
180 35.0086784362793
181 37.457515716552734
182 34.72481155395508
183 13.040182113647461
184 19.37213706970215
185 28.625749588012695
186 29.15053367614746
187 29.23573875427246
188 22.76650047302246
189 5.943062782287598
190 6.702073574066162
191 21.133747100830078
192 12.978866577148438
193 5.981543064117432
194 20.615291595458984
195 9.014150619506836
196 7.7108330726623535
197 15.92280101776123
198 11.963845252990723
199 6.203895568847656
200 3.7301888465881348
201 4.728057384490967
202 5.314579486846924
203 5.240014553070068
204 5.453920364379883
205 8.696304321289062
206 7.900552272796631
207 6.678144931793213
208 3.3714938163757324
209 1.5386085510253906
210 6.032501220703125
211 5.8018364906311035
212 3.881443977355957
213 2.6174204349517822
214 2.976761817932129
215 3.5310964584350586
216 3.7623395919799805
217 2.6155009269714355
218 1.396619439125061
219 3.204648017883301
220 1.8713853359222412
221 6.81362771987915
222 2.8310372829437256
223 1.7322633266448975
224 6.833327293395996
225 2.695887565612793
226 2.619550943374634
227 2.378783941268921
228 0.6700885891914368
229 10.298734664916992
230 3.470989465713501
231 4.275885581970215
232 1.9519766569137573
233 1.6557674407958984
234 18.67424201965332
235 3.171677350997925
236 3.4258246421813965
237 13.1660795211792
238 13.750452041625977
239 20.715951919555664
240 1.1175780296325684
241 4.620871543884277
242 10.80490493774414
243 8.206889152526855
244 7.947866916656494
245 2.2120327949523926
246 2.562795400619507
247 30.538103103637695
248 13.901779174804688
249 2.9181559085845947
250 23.1232852935791
251 4.2170820236206055
252 9.517720222473145
253 2.089613914489746
254 9.528303146362305
255 9.1177339553833
256 2.7187657356262207
257 3.678499698638916
258 3.7776966094970703
259 13.38383960723877
260 3.195574998855591
261 6.17686653137207
262 8.391398429870605
263 4.773057460784912
264 2.824456214904785
265 2.305602788925171
266 2.30399227142334
267 3.1287713050842285
268 3.660961151123047
269 2.533358097076416
270 9.300485610961914
271 4.682646751403809
272 1.623048186302185
273 2.4515528678894043
274 2.693984031677246
275 7.850718975067139
276 1.579115390777588
277 10.981849670410156
278 1.2006516456604004
279 5.2704691886901855
280 8.482172012329102
281 7.135919094085693
282 4.059561729431152
283 1.4069573879241943
284 2.6515307426452637
285 24.300792694091797
286 7.018038749694824
287 0.964491605758667
288 7.909121990203857
289 29.25485610961914
290 5.144656181335449
291 6.143868923187256
292 4.477543354034424
293 17.83785057067871
294 13.807804107666016
295 7.488834381103516
296 5.293219566345215
297 1.7783221006393433
298 1.465052604675293
299 2.9993293285369873
300 5.037281036376953
301 28.263219833374023
302 3.9570372104644775
303 9.329717636108398
304 9.696162223815918
305 39.5653190612793
306 2.9339380264282227
307 3.9535694122314453
308 7.880911350250244
309 10.706714630126953
310 7.586513042449951
311 4.169835090637207
312 6.35995626449585
313 5.3431854248046875
314 1.9723360538482666
315 3.598680257797241
316 1.4837133884429932
317 10.792834281921387
318 1.8036093711853027
319 1.5567809343338013
320 2.695099115371704
321 3.627091884613037
322 3.8082361221313477
323 2.468785524368286
324 1.3327337503433228
325 0.9699833989143372
326 10.822328567504883
327 1.4124585390090942
328 2.5382823944091797
329 2.320375680923462
330 2.4636099338531494
331 3.492678642272949
332 1.3891918659210205
333 1.8921149969100952
334 2.805755853652954
335 1.8069634437561035
336 1.3733097314834595
337 1.4393953084945679
338 3.465122699737549
339 1.3971080780029297
340 1.5291532278060913
341 0.4533478915691376
342 1.7823694944381714
343 3.6125588417053223
344 1.3653022050857544
345 0.816382646560669
346 3.55966854095459
347 1.836150050163269
348 1.280544400215149
349 0.8359392881393433
350 0.5177276730537415
351 1.4263097047805786
352 0.37720954418182373
353 6.211544036865234
354 0.8503292202949524
355 1.1587474346160889
356 7.564732074737549
357 3.6345913410186768
358 1.1681846380233765
359 6.441969394683838
360 1.022794485092163
361 2.310795783996582
362 1.417323350906372
363 1.743171215057373
364 3.065700054168701
365 0.46113091707229614
366 0.839004635810852
367 0.671021580696106
368 0.7674111127853394
369 0.9029204845428467
370 0.7262328267097473
371 0.8153609037399292
372 0.4804431200027466
373 1.5623407363891602
374 0.7265201807022095
375 0.5946952700614929
376 0.4926525950431824
377 0.5163369178771973
378 0.65239018201828
379 2.30501389503479
380 0.7336974740028381
381 1.3022087812423706
382 1.006253957748413
383 0.7253178358078003
384 0.4226285517215729
385 0.3634716272354126
386 0.6258991360664368
387 2.116492986679077
388 0.5364167094230652
389 1.0257192850112915
390 0.3153134882450104
391 0.5183588862419128
392 0.4579337537288666
393 0.49842023849487305
394 0.1868744194507599
395 0.17780829966068268
396 3.463909149169922
397 0.5562441349029541
398 1.5820661783218384
399 1.4303451776504517
400 1.3730618953704834
401 0.47221848368644714
402 0.38177555799484253
403 0.15416491031646729
404 6.4630513191223145
405 1.2200937271118164
406 1.5048688650131226
407 0.30329757928848267
408 10.58842658996582
409 2.374497652053833
410 2.310439109802246
411 10.687407493591309
412 1.0352970361709595
413 0.9835796356201172
414 0.8828991651535034
415 3.653787851333618
416 1.3530324697494507
417 0.5175939798355103
418 0.20744101703166962
419 1.0667517185211182
420 6.705753803253174
421 0.24381057918071747
422 1.0012201070785522
423 2.779245138168335
424 0.9200983047485352
425 0.3046930134296417
426 0.6605075597763062
427 0.4993615746498108
428 1.297109842300415
429 1.6436946392059326
430 0.47011610865592957
431 0.967345654964447
432 0.6892916560173035
433 1.2586971521377563
434 0.3433978855609894
435 1.2471317052841187
436 0.5714337825775146
437 0.9455128312110901
438 1.0444202423095703
439 0.6850330829620361
440 0.32870280742645264
441 0.14840413630008698
442 0.17985893785953522
443 1.167070746421814
444 2.55379056930542
445 0.38805091381073
446 0.18350791931152344
447 0.3380293846130371
448 0.4455258548259735
449 2.868293285369873
450 1.6683859825134277
451 0.26343196630477905
452 0.5781919360160828
453 3.98403263092041
454 0.787031352519989
455 1.1261801719665527
456 0.4379221796989441
457 1.0284756422042847
458 0.4967060089111328
459 0.37963205575942993
460 1.996079444885254
461 0.5005324482917786
462 1.0561158657073975
463 1.8731203079223633
464 1.3041950464248657
465 0.329582542181015
466 0.7239723801612854
467 1.2895078659057617
468 0.538284182548523
469 0.357747882604599
470 0.6347585916519165
471 0.5843656063079834
472 0.25580793619155884
473 0.24190807342529297
474 1.1879912614822388
475 0.3029036521911621
476 0.9910099506378174
477 0.21286898851394653
478 0.2229626625776291
479 1.0583611726760864
480 0.506009042263031
481 0.42843618988990784
482 0.6080997586250305
483 0.781005322933197
484 0.5436868071556091
485 0.2701403498649597
486 0.3564458191394806
487 1.142154335975647
488 0.5860298871994019
489 0.28172093629837036
490 1.0404759645462036
491 0.5730339884757996
492 0.42362573742866516
493 0.42784374952316284
494 0.4694884717464447
495 0.38278234004974365
496 0.23577500879764557
497 0.8770691156387329
498 0.9416893720626831
499 0.2107960432767868