In [1]:
import dpp_nets.my_torch as my_torch
import torch
import numpy as np
def set_seed(seed):
    torch.manual_seed(seed)
    np.random.seed(seed)

network_params = {'emb_in': 300, 'emb_h': 200, 'emb_out': 100,
                  'pred_in': 150, 'pred_h': 100, 'pred_out': 1,
                  'set_size': 40}
dtype = torch.DoubleTensor
train_iter = 10000
batch_size = 25
sample_iter = 1
alpha_iter = 0
lr = 1e-5
weight_decay = 0
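The training routine `train_with_baseline` is defined inside dpp_nets and not shown in this notebook. For orientation only, a score-function (REINFORCE) update with a baseline typically has the shape sketched below; the Bernoulli sampling, the `reward_fn` argument, and the moving-average baseline are illustrative assumptions, not the dpp_nets implementation (which samples subsets from a DPP).

import torch

def reinforce_step(logits, reward_fn, baseline, beta=0.9):
    """One REINFORCE update with a moving-average baseline (illustrative sketch).

    logits    -- learnable per-element scores (requires_grad=True)
    reward_fn -- maps a sampled 0/1 mask to a scalar reward (e.g. negative loss)
    baseline  -- running average of past rewards, subtracted for variance reduction
    """
    probs = torch.sigmoid(logits)
    mask = torch.bernoulli(probs)                       # sample a subset
    log_prob = (mask * torch.log(probs + 1e-12)
                + (1 - mask) * torch.log(1 - probs + 1e-12)).sum()
    reward = reward_fn(mask)
    # Score-function gradient: grad E[R] is estimated by (R - baseline) * grad log p(mask)
    (-(reward - baseline) * log_prob).backward()
    return reward.item(), beta * baseline + (1 - beta) * reward.item()

# Toy usage: the reward favours subsets of roughly 5 out of 20 candidates.
logits = torch.zeros(20, requires_grad=True)
opt = torch.optim.SGD([logits], lr=0.1)
baseline = 0.0
for _ in range(200):
    opt.zero_grad()
    _, baseline = reinforce_step(logits, lambda m: -(m.sum() - 5.0) ** 2, baseline)
    opt.step()

Subtracting the baseline leaves the gradient estimator unbiased while reducing its variance, which is presumably the point of the `_with_baseline` variant.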
In [2]:
# No regularization (reg_exp = 0, reg_var = 0)
reg_exp = 0
reg_var = 0
set_seed(10)
reg_0_0 = my_torch.DPPRegressor(network_params, dtype)
reg_0_0.sample()
set_seed(12)
reg_0_0.train_with_baseline(train_iter, batch_size, sample_iter, alpha_iter, lr, weight_decay, reg_exp, reg_var)
#reg_0_0.evaluate(100)
Target sampled was: 16.0
Network predicted:
-5.9872
[torch.DoubleTensor of size 1]
Resulting loss is: Variable containing:
483.4362
[torch.DoubleTensor of size 1]
Prediction was based on 29.0 observations.
50 388.73456237145854
[... loss logged every 50 iterations; intermediate output omitted ...]
10000 6.10870984933535
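A note on what the two regularizers plausibly target: for a DPP with L-ensemble kernel L, the sample size |Y| is distributed as a sum of independent Bernoulli(lam / (1 + lam)) draws over the eigenvalues lam of L, so the expected subset size and its variance both have closed forms. The sketch below computes them (assuming a symmetric PSD kernel and the modern `torch.linalg` API); that `reg_exp` and `reg_var` penalize exactly these moments is an assumption based on their names and on the evaluation output further down.

import torch

def dpp_size_moments(L):
    """Mean and variance of the subset size |Y| under a DPP with L-ensemble kernel L.

    |Y| is distributed as a sum of independent Bernoulli(lam / (1 + lam))
    draws over the eigenvalues lam of the symmetric PSD kernel L.
    """
    lam = torch.linalg.eigvalsh(L)         # eigenvalues of a symmetric matrix
    p = lam / (1.0 + lam)
    return p.sum(), (p * (1.0 - p)).sum()  # E[|Y|], Var[|Y|]

# Example with a random PSD kernel over 40 items (the notebook's set_size).
X = torch.randn(40, 40, dtype=torch.float64)
mean_size, var_size = dpp_size_moments(X @ X.T / 40)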
In [3]:
# Variance regularization only (reg_exp = 0, reg_var = -10)
reg_exp = 0
reg_var = -10
set_seed(10)
reg_0_1 = my_torch.DPPRegressor(network_params, dtype)
reg_0_1.sample()
set_seed(12)
reg_0_1.train_with_baseline(train_iter, batch_size, sample_iter, alpha_iter, lr, weight_decay, reg_exp, reg_var)
#reg_0_1.evaluate(100)
Target sampled was: 16.0
Network predicted:
-5.9872
[torch.DoubleTensor of size 1]
Resulting loss is: Variable containing:
483.4362
[torch.DoubleTensor of size 1]
Prediction was based on 29.0 observations.
50 388.73456237145854
[... loss logged every 50 iterations; intermediate output omitted ...]
10000 21.61125329314114
In [4]:
# Variance regularization only (reg_exp = 0, reg_var = -100)
reg_exp = 0
reg_var = -100
set_seed(10)
reg_0_10 = my_torch.DPPRegressor(network_params, dtype)
reg_0_10.sample()
set_seed(12)
reg_0_10.train_with_baseline(train_iter, batch_size, sample_iter, alpha_iter, lr, weight_decay, reg_exp, reg_var)
#reg_0_10.evaluate(100)
Target sampled was: 16.0
Network predicted:
-5.9872
[torch.DoubleTensor of size 1]
Resulting loss is: Variable containing:
483.4362
[torch.DoubleTensor of size 1]
Prediction was based on 29.0 observations.
50 388.73456237145854
[... loss logged every 50 iterations; intermediate output omitted ...]
10000 0.22824949883168028
In [5]:
# Mean regularization only (reg_exp = 1, reg_var = 0)
reg_exp = 1
reg_var = 0
set_seed(10)
reg_1_0 = my_torch.DPPRegressor(network_params, dtype)
reg_1_0.sample()
set_seed(12)
reg_1_0.train_with_baseline(train_iter, batch_size, sample_iter, alpha_iter, lr, weight_decay, reg_exp, reg_var)
#reg_1_0.evaluate(100)
Target sampled was: 16.0
Network predicted:
-5.9872
[torch.DoubleTensor of size 1]
Resulting loss is: Variable containing:
483.4362
[torch.DoubleTensor of size 1]
Prediction was based on 29.0 observations.
50 141.22378110307216
[... loss logged every 50 iterations; intermediate output omitted ...]
10000 3.1947789700378175
In [6]:
# Mean and variance regularization (reg_exp = 1, reg_var = -10)
reg_exp = 1
reg_var = -10
set_seed(10)
reg_1_1 = my_torch.DPPRegressor(network_params, dtype)
reg_1_1.sample()
set_seed(12)
reg_1_1.train_with_baseline(train_iter, batch_size, sample_iter, alpha_iter, lr, weight_decay, reg_exp, reg_var)
#reg_1_1.evaluate(100)
Target sampled was: 16.0
Network predicted:
-5.9872
[torch.DoubleTensor of size 1]
Resulting loss is: Variable containing:
483.4362
[torch.DoubleTensor of size 1]
Prediction was based on 29.0 observations.
50 132.35586551969052
[... loss logged every 50 iterations; intermediate output omitted ...]
10000 15.349682666785277
In [7]:
# Mean and variance regularization (reg_exp = 1, reg_var = -100)
reg_exp = 1
reg_var = -100
set_seed(10)
reg_1_10 = my_torch.DPPRegressor(network_params, dtype)
reg_1_10.sample()
set_seed(12)
reg_1_10.train_with_baseline(train_iter, batch_size, sample_iter, alpha_iter, lr, weight_decay, reg_exp, reg_var)
#reg_1_10.evaluate(100)
Target sampled was: 16.0
Network predicted:
-5.9872
[torch.DoubleTensor of size 1]
Resulting loss is: Variable containing:
483.4362
[torch.DoubleTensor of size 1]
Prediction was based on 29.0 observations.
50 132.35586551969052
[... loss logged every 50 iterations; intermediate output omitted ...]
10000 1.8246865434086594
In [8]:
# Mean regularization only (reg_exp = 10, reg_var = 0)
reg_exp = 10
reg_var = 0
set_seed(10)
reg_10_0 = my_torch.DPPRegressor(network_params, dtype)
reg_10_0.sample()
set_seed(12)
reg_10_0.train_with_baseline(train_iter, batch_size, sample_iter, alpha_iter, lr, weight_decay, reg_exp, reg_var)
#reg_10_0.evaluate(100)
Target sampled was: 16.0
Network predicted:
-5.9872
[torch.DoubleTensor of size 1]
Resulting loss is: Variable containing:
483.4362
[torch.DoubleTensor of size 1]
Prediction was based on 29.0 observations.
50 11.43464879666602
[... loss logged every 50 iterations; intermediate output omitted ...]
10000 33.58140680491616
In [9]:
# Mean and variance regularization (reg_exp = 10, reg_var = -10)
reg_exp = 10
reg_var = -10
set_seed(10)
reg_10_1 = my_torch.DPPRegressor(network_params, dtype)
reg_10_1.sample()
set_seed(12)
reg_10_1.train_with_baseline(train_iter, batch_size, sample_iter, alpha_iter, lr, weight_decay, reg_exp, reg_var)
#reg_10_1.evaluate(100)
Target sampled was: 16.0
Network predicted:
-5.9872
[torch.DoubleTensor of size 1]
Resulting loss is: Variable containing:
483.4362
[torch.DoubleTensor of size 1]
Prediction was based on 29.0 observations.
50 11.43464879666602
[... loss logged every 50 iterations; intermediate output omitted ...]
10000 10.183192306896176
In [10]:
# Mean and variance regularization (reg_exp = 10, reg_var = -100)
reg_exp = 10
reg_var = -100
set_seed(10)
reg_10_10 = my_torch.DPPRegressor(network_params, dtype)
reg_10_10.sample()
set_seed(12)
reg_10_10.train_with_baseline(train_iter, batch_size, sample_iter, alpha_iter, lr, weight_decay, reg_exp, reg_var)
#reg_10_10.evaluate(100)
Target sampled was: 16.0
Network predicted:
-5.9872
[torch.DoubleTensor of size 1]
Resulting loss is: Variable containing:
483.4362
[torch.DoubleTensor of size 1]
Prediction was based on 29.0 observations.
50 11.43464879666602
[... loss logged every 50 iterations; intermediate output omitted ...]
10000 11.932626784449141
In [11]:
loss_av_0_0, subset_mean_0_0, subset_var_0_0 = reg_0_0.evaluate(1000)
loss_av_0_1, subset_mean_0_1, subset_var_0_1 = reg_0_1.evaluate(1000)
loss_av_0_10, subset_mean_0_10, subset_var_0_10 = reg_0_10.evaluate(1000)
loss_av_1_0, subset_mean_1_0, subset_var_1_0 = reg_1_0.evaluate(1000)
loss_av_1_1, subset_mean_1_1, subset_var_1_1 = reg_1_1.evaluate(1000)
loss_av_1_10, subset_mean_1_10, subset_var_1_10 = reg_1_10.evaluate(1000)
loss_av_10_0, subset_mean_10_0, subset_var_10_0 = reg_10_0.evaluate(1000)
loss_av_10_1, subset_mean_10_1, subset_var_10_1 = reg_10_1.evaluate(1000)
loss_av_10_10, subset_mean_10_10, subset_var_10_10 = reg_10_10.evaluate(1000)
Average loss, average subset size, and subset-size variance over 1000 evaluation samples:

reg_exp  reg_var  avg loss  avg subset size  subset variance
      0        0    9.0506           26.308           0.8631
      0      -10    9.8390           26.644           1.1043
      0     -100    9.0834           26.499           0.6209
      1        0   12.3702           16.570           0.4020
      1      -10   11.4883           17.522           0.3220
      1     -100    9.3844           22.544           0.7408
     10        0   14.4270           10.905           0.1494
     10      -10   15.6790           10.934           0.1414
     10     -100   20.1459           12.116           0.1533
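`evaluate(1000)` returns the three quantities tabulated above. A rough sketch of that kind of aggregation, with a hypothetical `draw()` interface standing in for one forward pass of the trained model (this is not the dpp_nets implementation):

import random
import statistics

def evaluate_sketch(draw, n=1000):
    """Plausible shape of evaluate(n): average loss plus the mean and
    variance of the sampled subset sizes over n draws."""
    losses, sizes = zip(*(draw() for _ in range(n)))
    return sum(losses) / n, statistics.mean(sizes), statistics.variance(sizes)

# Dummy sampler for illustration only.
dummy_draw = lambda: (random.random(), random.gauss(26.0, 1.0))
print(evaluate_sketch(dummy_draw, n=100))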
In [13]:
import matplotlib.pyplot as plt
# Loss plot: average MSE vs. mean regularization, one line per variance setting
var0 = [loss_av_0_0, loss_av_1_0, loss_av_10_0]
var0 = [i.data[0] for i in var0]   # extract the scalar from each Variable (PyTorch 0.x; use .item() on newer releases)
var1 = [loss_av_0_1, loss_av_1_1, loss_av_10_1]
var1 = [i.data[0] for i in var1]
var10 = [loss_av_0_10, loss_av_1_10, loss_av_10_10]
var10 = [i.data[0] for i in var10]
plot0, = plt.plot(np.arange(3), var0, 'red', label='No Var Reg')
plot1, = plt.plot(np.arange(3), var1, 'blue', label='Var Reg = -10')
plot10, = plt.plot(np.arange(3), var10, 'green', label='Var Reg = -100')
labels = ['0', '1', '10']
plt.legend(handles=[plot0, plot1, plot10])
plt.xticks(np.arange(3), labels)
plt.xlabel("Mean Regularization")
plt.title("MSE Loss")
plt.savefig('reg_loss.pdf', format='pdf')
plt.show()
In [14]:
# Subset mean plot: average sampled subset size vs. mean regularization
var0 = [subset_mean_0_0, subset_mean_1_0, subset_mean_10_0]
var1 = [subset_mean_0_1, subset_mean_1_1, subset_mean_10_1]
var10 = [subset_mean_0_10, subset_mean_1_10, subset_mean_10_10]
plot0, = plt.plot(np.arange(3), var0, 'red', label='No Var Reg')
plot1, = plt.plot(np.arange(3), var1, 'blue', label='Var Reg = -10')
plot10, = plt.plot(np.arange(3), var10, 'green', label='Var Reg = -100')
labels = ['0', '1', '10']
plt.legend(handles=[plot0, plot1, plot10])
plt.xticks(np.arange(3), labels)
plt.xlabel("Mean Regularization")
plt.title("Subset Mean")
plt.savefig('reg_subset_mean.pdf', format='pdf')
plt.show()
In [16]:
# Subset variance plot: variance of sampled subset sizes vs. mean regularization
var0 = [subset_var_0_0, subset_var_1_0, subset_var_10_0]
var1 = [subset_var_0_1, subset_var_1_1, subset_var_10_1]
var10 = [subset_var_0_10, subset_var_1_10, subset_var_10_10]
plot0, = plt.plot(np.arange(3), var0, 'red', label='No Var Reg')
plot1, = plt.plot(np.arange(3), var1, 'blue', label='Var Reg = -10')
plot10, = plt.plot(np.arange(3), var10, 'green', label='Var Reg = -100')
labels = ['0', '1', '10']
plt.legend(handles=[plot0, plot1, plot10])
plt.xticks(np.arange(3), labels)
plt.xlabel("Mean Regularization")
plt.title("Subset Variance")
plt.savefig('reg_subset_var.pdf', format='pdf')
plt.show()
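The three plotting cells above differ only in the data, title, and output filename, so a small helper (illustrative, not part of dpp_nets) removes the repetition:

import numpy as np
import matplotlib.pyplot as plt

def plot_by_mean_reg(series, title, fname):
    """Plot one metric against mean regularization, one line per variance setting.

    series -- dict mapping legend label -> three values for reg_exp = 0, 1, 10
    """
    handles = []
    for (label, values), color in zip(series.items(), ('red', 'blue', 'green')):
        line, = plt.plot(np.arange(3), values, color, label=label)
        handles.append(line)
    plt.legend(handles=handles)
    plt.xticks(np.arange(3), ['0', '1', '10'])
    plt.xlabel("Mean Regularization")
    plt.title(title)
    plt.savefig(fname, format='pdf')
    plt.show()

# Reproduces the subset-variance figure above from the evaluate() results.
plot_by_mean_reg({'No Var Reg': [subset_var_0_0, subset_var_1_0, subset_var_10_0],
                  'Var Reg = -10': [subset_var_0_1, subset_var_1_1, subset_var_10_1],
                  'Var Reg = -100': [subset_var_0_10, subset_var_1_10, subset_var_10_10]},
                 "Subset Variance", 'reg_subset_var.pdf')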