Licensed under the Apache License, Version 2.0.


In [0]:
import os

import tensorflow as tf

from experimental.attentive_uncertainty import attention  # local file import
from experimental.attentive_uncertainty.contextual_bandits.pretrain import train  # local file import

In [0]:
# Base directory for the checkpoint path that is built in the hparams cell.
savedir = '/tmp/wheel_bandit/models/multitask'

# --- Data -------------------------------------------------------------------
# The context/target set sizes are reused by both the data hparams here and
# the training hparams in a later cell, so they are named once.
num_context = 512
num_target = 50
data_hparams = tf.contrib.training.HParams(
    context_dim=2,
    num_actions=5,
    num_target=num_target,
    num_context=num_context)

# --- Network layer sizes ----------------------------------------------------
X_HIDDEN_SIZE = 100
HIDDEN_SIZE = 64
latent_units = 32

x_encoder_sizes = 2 * [X_HIDDEN_SIZE]
x_y_encoder_sizes = 3 * [HIDDEN_SIZE]
# Final layer is 2 * latent_units wide -- presumably a mean and a scale per
# latent unit; TODO confirm against the model code.
global_latent_net_sizes = 2 * [HIDDEN_SIZE] + [2 * latent_units]
# Final layer has 2 outputs -- NOTE(review): meaning is not visible from this
# notebook; verify against the model code.
local_latent_net_sizes = 3 * [HIDDEN_SIZE] + [2]
heteroskedastic_net_sizes = None

# --- Attention --------------------------------------------------------------
# The same Laplace attention function is used for the mean and both scale
# attention slots.
mean_att_type = attention.laplace_attention
scale_att_type_1 = attention.laplace_attention
scale_att_type_2 = attention.laplace_attention
att_type = 'multihead'
att_heads = 8

data_uncertainty = False

Prior Predictive + Freeform


In [0]:
# Settings specific to this "Prior Predictive + Freeform" run; everything else
# comes from the shared configuration cell.
uncertainty_type = 'attentive_freeform'
local_variational = False

model_hparams = tf.contrib.training.HParams(
    # Activations.
    activation=tf.nn.relu,
    output_activation=tf.nn.relu,
    # Layer sizes for each sub-network.
    x_encoder_sizes=x_encoder_sizes,
    x_y_encoder_sizes=x_y_encoder_sizes,
    global_latent_net_sizes=global_latent_net_sizes,
    local_latent_net_sizes=local_latent_net_sizes,
    heteroskedastic_net_sizes=heteroskedastic_net_sizes,
    # Uncertainty / attention configuration.
    uncertainty_type=uncertainty_type,
    att_type=att_type,
    att_heads=att_heads,
    mean_att_type=mean_att_type,
    scale_att_type_1=scale_att_type_1,
    scale_att_type_2=scale_att_type_2,
    data_uncertainty=data_uncertainty,
    local_variational=local_variational)

# Checkpoint file for this run, placed under `savedir`.
save_path = os.path.join(savedir, 'best_prior_freeform_mse_unclipped.ckpt')

training_hparams = tf.contrib.training.HParams(
    lr=0.01,
    optimizer=tf.train.RMSPropOptimizer,
    num_iterations=10000,
    batch_size=10,
    num_context=num_context,
    num_target=num_target,
    print_every=50,
    save_path=save_path,
    # NOTE(review): the checkpoint name says "unclipped", so this large value
    # presumably makes gradient clipping effectively inactive -- confirm
    # against the training code.
    max_grad_norm=1000.0)

In [0]:
# Launch training with the data, model and training hparams built above.
# Per the captured output, this reports train/validation NLL, MSE and KL terms
# every 50 iterations and logs whenever a new lowest validation MSE is saved.
train(data_hparams, model_hparams, training_hparams)


it: 0, train nll: 175.937515259, mse: 326.51953125, local kl: 0.0 global kl: 7.74599175202e-05 valid nll: 201.171005249, mse: 393.021575928, local kl: 0.0 global kl: 0.000304171262542
Saving best model with MSE 393.02158
it: 50, train nll: 22.2970085144, mse: 111.704582214, local kl: 0.0 global kl: 2.50619686994e-05 valid nll: 21.1540718079, mse: 92.6402206421, local kl: 0.0 global kl: 2.43014455918e-05
Saving best model with MSE 92.64022
it: 100, train nll: 26.0193710327, mse: 79.1269226074, local kl: 0.0 global kl: 0.133565917611 valid nll: 27.6986160278, mse: 102.077888489, local kl: 0.0 global kl: 2.97966962535e-05
it: 150, train nll: 25.1414089203, mse: 65.0398178101, local kl: 0.0 global kl: 0.0110144298524 valid nll: 30.1620178223, mse: 98.4341812134, local kl: 0.0 global kl: 5.3404179198e-05
it: 200, train nll: 28.0258808136, mse: 81.0215988159, local kl: 0.0 global kl: 0.00028004439082 valid nll: 26.6753025055, mse: 81.8777999878, local kl: 0.0 global kl: 0.00149854901247
Saving best model with MSE 81.8778
it: 250, train nll: 23.7632274628, mse: 58.2351493835, local kl: 0.0 global kl: 0.000783208000939 valid nll: 38.8488388062, mse: 83.2493591309, local kl: 0.0 global kl: 0.000727659557015
it: 300, train nll: 19.205083847, mse: 41.1444244385, local kl: 0.0 global kl: 0.000405898375902 valid nll: 43.1957397461, mse: 77.4813232422, local kl: 0.0 global kl: 0.000416289956775
Saving best model with MSE 77.48132
it: 350, train nll: 39.9142684937, mse: 68.071723938, local kl: 0.0 global kl: 0.000708944164217 valid nll: 54.0199546814, mse: 86.8032989502, local kl: 0.0 global kl: 0.000189295475138
it: 400, train nll: 57.8735427856, mse: 82.9639053345, local kl: 0.0 global kl: 0.00455050496385 valid nll: 31.676984787, mse: 59.6367530823, local kl: 0.0 global kl: 0.000369597342797
Saving best model with MSE 59.636753
it: 450, train nll: 23.1233024597, mse: 34.5200767517, local kl: 0.0 global kl: 0.00112121377606 valid nll: 21.4771327972, mse: 37.7946510315, local kl: 0.0 global kl: 0.000349988229573
Saving best model with MSE 37.79465
it: 500, train nll: 31.0630187988, mse: 42.0417366028, local kl: 0.0 global kl: 0.000311297393637 valid nll: 64.5235290527, mse: 74.870803833, local kl: 0.0 global kl: 0.00184763874859
it: 550, train nll: 49.3603515625, mse: 43.594455719, local kl: 0.0 global kl: 0.00463156308979 valid nll: 67.2754440308, mse: 55.7209091187, local kl: 0.0 global kl: 0.000934383482672
it: 600, train nll: 91.9542236328, mse: 71.0717468262, local kl: 0.0 global kl: 0.00425210827962 valid nll: 65.0645828247, mse: 48.9892082214, local kl: 0.0 global kl: 0.000277478247881
it: 650, train nll: 67.4039077759, mse: 34.7135696411, local kl: 0.0 global kl: 0.00161040143576 valid nll: 164.669448853, mse: 41.9747467041, local kl: 0.0 global kl: 0.00184986332897
it: 700, train nll: 53.8370170593, mse: 42.5842475891, local kl: 0.0 global kl: 0.00239921198227 valid nll: 48.6750946045, mse: 35.552154541, local kl: 0.0 global kl: 0.000183649855899
Saving best model with MSE 35.552155
it: 750, train nll: 30.3498401642, mse: 22.5199356079, local kl: 0.0 global kl: 0.00145898701157 valid nll: 66.5486450195, mse: 53.7017974854, local kl: 0.0 global kl: 0.000259432912571
it: 800, train nll: 39.1559181213, mse: 32.0955352783, local kl: 0.0 global kl: 0.00161583290901 valid nll: 53.4604263306, mse: 36.5113716125, local kl: 0.0 global kl: 0.00176369841211
it: 850, train nll: 93.4288787842, mse: 76.8169021606, local kl: 0.0 global kl: 0.00163547787815 valid nll: 62.0754814148, mse: 39.76379776, local kl: 0.0 global kl: 0.00331573560834
it: 900, train nll: 374.378570557, mse: 43.763633728, local kl: 0.0 global kl: 0.0222362969071 valid nll: 82.7130889893, mse: 48.3118286133, local kl: 0.0 global kl: 0.000515884777997
it: 950, train nll: 108.048110962, mse: 52.1763725281, local kl: 0.0 global kl: 0.00107746920548 valid nll: 167.689193726, mse: 54.8975868225, local kl: 0.0 global kl: 0.000573885336053
it: 1000, train nll: 107.441390991, mse: 46.8828811646, local kl: 0.0 global kl: 0.00650559831411 valid nll: 122.758407593, mse: 51.6467514038, local kl: 0.0 global kl: 0.00058289890876
it: 1050, train nll: 64.0786437988, mse: 23.6010456085, local kl: 0.0 global kl: 0.000454349385109 valid nll: 193.08114624, mse: 48.9903793335, local kl: 0.0 global kl: 0.000291681528324
it: 1100, train nll: 104.451019287, mse: 25.7002296448, local kl: 0.0 global kl: 0.000133388501126 valid nll: 203.481552124, mse: 42.4076766968, local kl: 0.0 global kl: 0.00016043188225
it: 1150, train nll: 144.373535156, mse: 25.8542308807, local kl: 0.0 global kl: 0.000649867695756 valid nll: 107.315376282, mse: 27.9535694122, local kl: 0.0 global kl: 0.00114819279406
Saving best model with MSE 27.95357
it: 1200, train nll: 173.85244751, mse: 40.015045166, local kl: 0.0 global kl: 0.000858941755723 valid nll: 146.709823608, mse: 49.6794128418, local kl: 0.0 global kl: 0.00156588945538
it: 1250, train nll: 214.810455322, mse: 36.4093971252, local kl: 0.0 global kl: 0.00105423945934 valid nll: 179.104324341, mse: 66.3396072388, local kl: 0.0 global kl: 0.00155196967535
it: 1300, train nll: 158.274002075, mse: 25.240480423, local kl: 0.0 global kl: 0.00102779874578 valid nll: 122.023406982, mse: 35.894695282, local kl: 0.0 global kl: 0.000609370996244
it: 1350, train nll: 82.1165390015, mse: 14.3614940643, local kl: 0.0 global kl: 0.00199447828345 valid nll: 143.962493896, mse: 29.3612594604, local kl: 0.0 global kl: 0.00232246774249
it: 1400, train nll: 15729341.0, mse: 42.9917640686, local kl: 0.0 global kl: 0.00628226343542 valid nll: 155.054290771, mse: 44.7137947083, local kl: 0.0 global kl: 0.00454936083406
it: 1450, train nll: 159.817367554, mse: 35.5269241333, local kl: 0.0 global kl: 0.00690871616825 valid nll: 208.922241211, mse: 41.3580665588, local kl: 0.0 global kl: 0.000236136445892
it: 1500, train nll: 48.072353363, mse: 8.7271490097, local kl: 0.0 global kl: 0.00427036499605 valid nll: 122.146614075, mse: 28.1235237122, local kl: 0.0 global kl: 0.000438428542111
it: 1550, train nll: 55.7161636353, mse: 21.9884262085, local kl: 0.0 global kl: 0.0025177304633 valid nll: 144.678482056, mse: 30.3698978424, local kl: 0.0 global kl: 0.00039319723146
it: 1600, train nll: 50.1324539185, mse: 19.5146865845, local kl: 0.0 global kl: 0.00671525811777 valid nll: 57.4658584595, mse: 17.0358753204, local kl: 0.0 global kl: 0.00238785753027
Saving best model with MSE 17.035875
it: 1650, train nll: 94.8223571777, mse: 19.5802383423, local kl: 0.0 global kl: 0.00256366981193 valid nll: 87.6506118774, mse: 31.6490821838, local kl: 0.0 global kl: 0.00365137448534
it: 1700, train nll: 58.8786392212, mse: 17.052986145, local kl: 0.0 global kl: 0.00181298074313 valid nll: 181.882461548, mse: 31.7792205811, local kl: 0.0 global kl: 0.0148546798155
it: 1750, train nll: 108.737731934, mse: 20.699678421, local kl: 0.0 global kl: 0.00290898606181 valid nll: 133.079498291, mse: 40.5060157776, local kl: 0.0 global kl: 0.00213967030868
it: 1800, train nll: 248.159515381, mse: 36.8738822937, local kl: 0.0 global kl: 0.00533189903945 valid nll: 249.039886475, mse: 32.6633644104, local kl: 0.0 global kl: 0.00171226006933
it: 1850, train nll: 112.717048645, mse: 14.9048213959, local kl: 0.0 global kl: 0.00423793215305 valid nll: 134.260559082, mse: 25.1026859283, local kl: 0.0 global kl: 0.000477991299704
it: 1900, train nll: 179.266815186, mse: 24.4300422668, local kl: 0.0 global kl: 0.0187510959804 valid nll: 228.461196899, mse: 35.6555671692, local kl: 0.0 global kl: 0.000614507880528
it: 1950, train nll: 181.938430786, mse: 30.425819397, local kl: 0.0 global kl: 0.000419783464167 valid nll: 156.213027954, mse: 22.6093521118, local kl: 0.0 global kl: 0.00150703755207
it: 2000, train nll: 129.426025391, mse: 16.8515701294, local kl: 0.0 global kl: 0.0109331822023 valid nll: 147.028121948, mse: 28.5352725983, local kl: 0.0 global kl: 0.000748521531932
it: 2050, train nll: 117.733482361, mse: 20.9740638733, local kl: 0.0 global kl: 0.00100230320822 valid nll: 162.604949951, mse: 22.3817062378, local kl: 0.0 global kl: 0.000512873695698
it: 2100, train nll: 179.282546997, mse: 25.4756660461, local kl: 0.0 global kl: 0.0189961418509 valid nll: 128.173416138, mse: 33.7380027771, local kl: 0.0 global kl: 0.00189703493379
it: 2150, train nll: 668.248168945, mse: 23.7144184113, local kl: 0.0 global kl: 0.000674539944157 valid nll: 219.587356567, mse: 23.6968250275, local kl: 0.0 global kl: 0.00166850909591
it: 2200, train nll: 83.2928466797, mse: 21.2647094727, local kl: 0.0 global kl: 0.00170300656464 valid nll: 76.0718078613, mse: 11.3744277954, local kl: 0.0 global kl: 0.00302705448121
Saving best model with MSE 11.374428
it: 2250, train nll: 384.847320557, mse: 17.7979106903, local kl: 0.0 global kl: 0.000613991054706 valid nll: 344.137329102, mse: 21.456905365, local kl: 0.0 global kl: 0.00177285447717
it: 2300, train nll: 502.164154053, mse: 23.3764705658, local kl: 0.0 global kl: 0.0101530542597 valid nll: 148.600341797, mse: 19.0851745605, local kl: 0.0 global kl: 0.00195488496684
it: 2350, train nll: 212.490432739, mse: 18.0885105133, local kl: 0.0 global kl: 0.0144316311926 valid nll: 427.730682373, mse: 17.95977211, local kl: 0.0 global kl: 0.000540162553079
it: 2400, train nll: 368.462738037, mse: 13.811709404, local kl: 0.0 global kl: 0.00224884646013 valid nll: 183.857635498, mse: 13.765209198, local kl: 0.0 global kl: 0.00254346942529
it: 2450, train nll: 353.172485352, mse: 23.302154541, local kl: 0.0 global kl: 0.00576397730038 valid nll: 111.62789917, mse: 28.5533866882, local kl: 0.0 global kl: 0.000604407687206
it: 2500, train nll: 304.672119141, mse: 24.8727359772, local kl: 0.0 global kl: 0.0140555705875 valid nll: 243.122375488, mse: 24.7560214996, local kl: 0.0 global kl: 0.00183136144187
it: 2550, train nll: 897.093078613, mse: 24.0963401794, local kl: 0.0 global kl: 0.0113578857854 valid nll: 413.810852051, mse: 50.1311187744, local kl: 0.0 global kl: 0.000211963008041
it: 2600, train nll: 394.649169922, mse: 15.9926195145, local kl: 0.0 global kl: 0.000780132075306 valid nll: 208.862838745, mse: 16.1162910461, local kl: 0.0 global kl: 0.00038759369636
it: 2650, train nll: 460.61831665, mse: 36.3305625916, local kl: 0.0 global kl: 0.00234078429639 valid nll: 158.262863159, mse: 19.0801105499, local kl: 0.0 global kl: 0.00161935202777
it: 2700, train nll: 579.330688477, mse: 37.0733718872, local kl: 0.0 global kl: 0.000393015565351 valid nll: 319.370391846, mse: 22.8380203247, local kl: 0.0 global kl: 0.000571494572796
it: 2750, train nll: 109.791069031, mse: 13.8254823685, local kl: 0.0 global kl: 0.000727368984371 valid nll: 47907.0390625, mse: 28.8177471161, local kl: 0.0 global kl: 0.00122091802768
it: 2800, train nll: 298.697631836, mse: 25.802942276, local kl: 0.0 global kl: 0.00204421230592 valid nll: 209.830490112, mse: 44.2017250061, local kl: 0.0 global kl: 0.00166806881316
it: 2850, train nll: 350.934417725, mse: 21.6358623505, local kl: 0.0 global kl: 0.000892611395102 valid nll: 223.445556641, mse: 26.770778656, local kl: 0.0 global kl: 0.000714973895811
it: 2900, train nll: 219.150482178, mse: 11.5131816864, local kl: 0.0 global kl: 0.000296456128126 valid nll: 172.302017212, mse: 24.3013629913, local kl: 0.0 global kl: 0.000933325965889
it: 2950, train nll: 180.33744812, mse: 24.4229640961, local kl: 0.0 global kl: 0.00494115334004 valid nll: 157.082366943, mse: 17.5903587341, local kl: 0.0 global kl: 0.00214592181146
it: 3000, train nll: 745.557678223, mse: 37.7700843811, local kl: 0.0 global kl: 0.00280232634395 valid nll: 203.621566772, mse: 24.5339221954, local kl: 0.0 global kl: 0.000508516968694
it: 3050, train nll: 291.633605957, mse: 22.2962532043, local kl: 0.0 global kl: 0.0140312537551 valid nll: 400.011688232, mse: 23.4796485901, local kl: 0.0 global kl: 0.00149891432375
it: 3100, train nll: 210.67074585, mse: 21.4663448334, local kl: 0.0 global kl: 0.00792724452913 valid nll: 356.457366943, mse: 31.593870163, local kl: 0.0 global kl: 0.00426168181002
it: 3150, train nll: 256.41192627, mse: 20.3176212311, local kl: 0.0 global kl: 0.00452661700547 valid nll: 183.033248901, mse: 22.9176368713, local kl: 0.0 global kl: 0.00279369344935
it: 3200, train nll: 349.191894531, mse: 23.3969039917, local kl: 0.0 global kl: 0.00405071349815 valid nll: 152.4634552, mse: 12.9167032242, local kl: 0.0 global kl: 0.0028993752785
it: 3250, train nll: 193.05569458, mse: 15.411485672, local kl: 0.0 global kl: 0.0318362861872 valid nll: 220.579650879, mse: 15.6433515549, local kl: 0.0 global kl: 0.00136447360273
it: 3300, train nll: 639.169799805, mse: 39.9528846741, local kl: 0.0 global kl: 0.00879805162549 valid nll: 318.141784668, mse: 21.775598526, local kl: 0.0 global kl: 0.00322509673424
it: 3350, train nll: 282.56829834, mse: 17.1449317932, local kl: 0.0 global kl: 0.00656696688384 valid nll: 187.910125732, mse: 19.2117881775, local kl: 0.0 global kl: 0.00217941263691
it: 3400, train nll: 288.361206055, mse: 14.0278396606, local kl: 0.0 global kl: 0.0324049815536 valid nll: 696.762329102, mse: 23.0473575592, local kl: 0.0 global kl: 0.00265448261052
it: 3450, train nll: 459.971466064, mse: 16.0699253082, local kl: 0.0 global kl: 0.0102316224948 valid nll: 255.927062988, mse: 21.6609916687, local kl: 0.0 global kl: 0.00198068562895
it: 3500, train nll: 1348.44970703, mse: 25.4045715332, local kl: 0.0 global kl: 0.00131788302679 valid nll: 579.089660645, mse: 28.993812561, local kl: 0.0 global kl: 0.00300658331253
it: 3550, train nll: 390.452545166, mse: 16.2007007599, local kl: 0.0 global kl: 0.00712493527681 valid nll: 564.828430176, mse: 18.5121822357, local kl: 0.0 global kl: 0.00248844875023
it: 3600, train nll: 491.96685791, mse: 16.2878494263, local kl: 0.0 global kl: 0.000587065413129 valid nll: 3130.90698242, mse: 27.8105297089, local kl: 0.0 global kl: 0.00340137956664
it: 3650, train nll: 915.187011719, mse: 27.5548210144, local kl: 0.0 global kl: 0.00768698379397 valid nll: 676.643920898, mse: 34.7733421326, local kl: 0.0 global kl: 0.00377276958898
it: 3700, train nll: 866.558959961, mse: 12.0348129272, local kl: 0.0 global kl: 0.00223161862232 valid nll: 665.394592285, mse: 26.5769996643, local kl: 0.0 global kl: 0.00163154769689
it: 3750, train nll: 773.940368652, mse: 29.3226394653, local kl: 0.0 global kl: 0.00473322626203 valid nll: 996.989868164, mse: 19.5097084045, local kl: 0.0 global kl: 0.00433530518785
it: 3800, train nll: 304.49987793, mse: 9.96506309509, local kl: 0.0 global kl: 0.00437247939408 valid nll: 809.164733887, mse: 17.6647014618, local kl: 0.0 global kl: 0.00617053406313
it: 3850, train nll: 236.793655396, mse: 15.2822589874, local kl: 0.0 global kl: 0.0122405011207 valid nll: 1264.58581543, mse: 25.7225379944, local kl: 0.0 global kl: 0.00366962142289
it: 3900, train nll: 1105.01599121, mse: 22.9985427856, local kl: 0.0 global kl: 0.00373829016462 valid nll: 505.254058838, mse: 27.4022064209, local kl: 0.0 global kl: 0.00556020019576
it: 3950, train nll: 963.687561035, mse: 17.4834289551, local kl: 0.0 global kl: 0.00390942348167 valid nll: 2211.31469727, mse: 27.2531547546, local kl: 0.0 global kl: 0.00794936995953
it: 4000, train nll: 939.034851074, mse: 10.2903175354, local kl: 0.0 global kl: 0.00516132544726 valid nll: 1118.68884277, mse: 25.1669883728, local kl: 0.0 global kl: 0.00656151399016
it: 4050, train nll: 1383.69934082, mse: 16.8525943756, local kl: 0.0 global kl: 0.00226892181672 valid nll: 1750.10632324, mse: 23.0165367126, local kl: 0.0 global kl: 0.0129666104913
it: 4100, train nll: 1273.30371094, mse: 19.7939796448, local kl: 0.0 global kl: 0.00571082578972 valid nll: 1192.47961426, mse: 16.88697052, local kl: 0.0 global kl: 0.01808960177
it: 4150, train nll: 530.357971191, mse: 18.0563030243, local kl: 0.0 global kl: 0.00442619249225 valid nll: 2310.63037109, mse: 43.6963806152, local kl: 0.0 global kl: 0.0214758906513
it: 4200, train nll: 409.84286499, mse: 19.7053508759, local kl: 0.0 global kl: 0.00539143104106 valid nll: 4183.92626953, mse: 37.7939758301, local kl: 0.0 global kl: 0.0263043008745
it: 4250, train nll: 1909.58984375, mse: 23.6546192169, local kl: 0.0 global kl: 0.00825089588761 valid nll: 833.581176758, mse: 21.9529800415, local kl: 0.0 global kl: 0.0101474383846
it: 4300, train nll: 1435.78271484, mse: 20.6416816711, local kl: 0.0 global kl: 0.0104697616771 valid nll: 1467.21289062, mse: 20.6475467682, local kl: 0.0 global kl: 0.0202958025038
it: 4350, train nll: 742.608032227, mse: 23.0429058075, local kl: 0.0 global kl: 0.209233954549 valid nll: 1965.51428223, mse: 25.4598789215, local kl: 0.0 global kl: 0.0108548477292
it: 4400, train nll: 789.785705566, mse: 24.1470298767, local kl: 0.0 global kl: 0.0271407403052 valid nll: 492.891845703, mse: 10.454457283, local kl: 0.0 global kl: 0.0109418835491
Saving best model with MSE 10.454457
it: 4450, train nll: 2398.18603516, mse: 33.0084762573, local kl: 0.0 global kl: 0.0136565864086 valid nll: 1343.68261719, mse: 27.429725647, local kl: 0.0 global kl: 0.0271213110536
it: 4500, train nll: 720.459228516, mse: 11.3155126572, local kl: 0.0 global kl: 0.0200979337096 valid nll: 469.962646484, mse: 17.0670051575, local kl: 0.0 global kl: 0.0264908783138
it: 4550, train nll: 442.988586426, mse: 7.57100200653, local kl: 0.0 global kl: 0.00413058744743 valid nll: 420.335906982, mse: 12.2616853714, local kl: 0.0 global kl: 0.0129356533289
it: 4600, train nll: 396.816986084, mse: 9.20459938049, local kl: 0.0 global kl: 0.00245746574365 valid nll: 383.3956604, mse: 21.0967845917, local kl: 0.0 global kl: 0.0196737013757
it: 4650, train nll: 1364.63525391, mse: 21.7371406555, local kl: 0.0 global kl: 0.0064976871945 valid nll: 1174.58215332, mse: 32.5401763916, local kl: 0.0 global kl: 0.016362125054
it: 4700, train nll: 787.429382324, mse: 16.3875904083, local kl: 0.0 global kl: 0.00654621561989 valid nll: 2175.51269531, mse: 22.4966239929, local kl: 0.0 global kl: 0.0184947550297
it: 4750, train nll: 1736.61755371, mse: 72.6729125977, local kl: 0.0 global kl: 0.0285313632339 valid nll: 902.406738281, mse: 59.8774147034, local kl: 0.0 global kl: 0.0779447183013
it: 4800, train nll: 253.949966431, mse: 28.0697479248, local kl: 0.0 global kl: 0.0730386599898 valid nll: 557.613586426, mse: 38.7616424561, local kl: 0.0 global kl: 0.112974189222
it: 4850, train nll: 1250.76843262, mse: 13.3250865936, local kl: 0.0 global kl: 0.0130481701344 valid nll: 497.204376221, mse: 15.599738121, local kl: 0.0 global kl: 0.0576331987977
it: 4900, train nll: 279.938598633, mse: 9.04510402679, local kl: 0.0 global kl: 0.0104266107082 valid nll: 354.964172363, mse: 14.9960184097, local kl: 0.0 global kl: 0.0473644658923
it: 4950, train nll: 637.328308105, mse: 26.9567127228, local kl: 0.0 global kl: 0.0122087355703 valid nll: 701.312255859, mse: 11.278427124, local kl: 0.0 global kl: 0.0849697217345
it: 5000, train nll: 2514.86523438, mse: 16.1035671234, local kl: 0.0 global kl: 0.0102535840124 valid nll: 767.283752441, mse: 19.190990448, local kl: 0.0 global kl: 0.180517062545
it: 5050, train nll: 4031.28662109, mse: 40.5675239563, local kl: 0.0 global kl: 0.0714102834463 valid nll: 697.469726562, mse: 53.8911094666, local kl: 0.0 global kl: 0.0611666850746
it: 5100, train nll: 1156.02978516, mse: 24.2204265594, local kl: 0.0 global kl: 0.054726742208 valid nll: 747521.8125, mse: 17.7136669159, local kl: 0.0 global kl: 0.131854251027
it: 5150, train nll: 607.781860352, mse: 16.5714511871, local kl: 0.0 global kl: 0.0366778187454 valid nll: 1093.25695801, mse: 26.9516906738, local kl: 0.0 global kl: 0.214570209384
it: 5200, train nll: 1584.96826172, mse: 25.55818367, local kl: 0.0 global kl: 0.0841591060162 valid nll: 4162.12695312, mse: 13.9037389755, local kl: 0.0 global kl: 0.116646483541
it: 5250, train nll: 885.013122559, mse: 21.5114917755, local kl: 0.0 global kl: 0.0413937978446 valid nll: 1040.90307617, mse: 30.6172904968, local kl: 0.0 global kl: 0.0765329450369
it: 5300, train nll: 1176.51477051, mse: 15.9496307373, local kl: 0.0 global kl: 0.0328361392021 valid nll: 506.936767578, mse: 50.511844635, local kl: 0.0 global kl: 0.116803824902
it: 5350, train nll: 529.439025879, mse: 13.3948945999, local kl: 0.0 global kl: 0.0379370190203 valid nll: 1103.96728516, mse: 17.8897838593, local kl: 0.0 global kl: 0.152429401875
it: 5400, train nll: 957.305053711, mse: 24.9576663971, local kl: 0.0 global kl: 0.0501258596778 valid nll: 1753.42163086, mse: 23.9927387238, local kl: 0.0 global kl: 0.18190073967
it: 5450, train nll: 450.831176758, mse: 8.40579128265, local kl: 0.0 global kl: 0.0524239055812 valid nll: 504.920196533, mse: 27.4973125458, local kl: 0.0 global kl: 0.178640574217
it: 5500, train nll: 507.469848633, mse: 10.2284183502, local kl: 0.0 global kl: 0.0856596529484 valid nll: 399.988769531, mse: 13.2090263367, local kl: 0.0 global kl: 0.169554233551
it: 5550, train nll: 527.580688477, mse: 16.9135456085, local kl: 0.0 global kl: 0.115009047091 valid nll: 1030.33666992, mse: 22.7659378052, local kl: 0.0 global kl: 0.264447808266
it: 5600, train nll: 733.916259766, mse: 20.7030086517, local kl: 0.0 global kl: 0.0421675853431 valid nll: 866.53302002, mse: 26.2867298126, local kl: 0.0 global kl: 0.336214959621
it: 5650, train nll: 852.827026367, mse: 10.8238630295, local kl: 0.0 global kl: 0.0395294353366 valid nll: 3452.15991211, mse: 21.1967811584, local kl: 0.0 global kl: 0.414111942053
it: 5700, train nll: 437.791687012, mse: 13.1226129532, local kl: 0.0 global kl: 0.0834678262472 valid nll: 778.353027344, mse: 15.4664936066, local kl: 0.0 global kl: 0.273341953754
it: 5750, train nll: 236.656509399, mse: 11.4700937271, local kl: 0.0 global kl: 0.0371824428439 valid nll: 16834.6230469, mse: 25.5470867157, local kl: 0.0 global kl: 0.234191209078
it: 5800, train nll: 1607.98364258, mse: 17.9083747864, local kl: 0.0 global kl: 0.139628410339 valid nll: 7342.60791016, mse: 16.3807659149, local kl: 0.0 global kl: 0.219621926546
it: 5850, train nll: 118.701690674, mse: 4.92975473404, local kl: 0.0 global kl: 0.0390865691006 valid nll: 1337343.625, mse: 17.8901824951, local kl: 0.0 global kl: 0.388295978308
it: 5900, train nll: 686.525390625, mse: 18.1351547241, local kl: 0.0 global kl: 0.102634295821 valid nll: 507.656463623, mse: 18.0318241119, local kl: 0.0 global kl: 0.268114119768
it: 5950, train nll: 192.665786743, mse: 14.3864469528, local kl: 0.0 global kl: 0.0120946289971 valid nll: 27813.1835938, mse: 23.5041923523, local kl: 0.0 global kl: 0.146941512823
it: 6000, train nll: 166.421569824, mse: 7.31975030899, local kl: 0.0 global kl: 0.0962375700474 valid nll: 26171.8671875, mse: 37.2981147766, local kl: 0.0 global kl: 0.42278534174
it: 6050, train nll: 34915.1679688, mse: 23.8268928528, local kl: 0.0 global kl: 0.0786458030343 valid nll: 11545.5273438, mse: 21.260017395, local kl: 0.0 global kl: 0.237184613943
it: 6100, train nll: 5421.18359375, mse: 12.1578121185, local kl: 0.0 global kl: 0.0312141180038 valid nll: 37670.0429688, mse: 22.4740848541, local kl: 0.0 global kl: 0.205271080136
it: 6150, train nll: 710.78704834, mse: 19.4936714172, local kl: 0.0 global kl: 0.137616887689 valid nll: 3605.58496094, mse: 17.2633781433, local kl: 0.0 global kl: 0.21603910625
it: 6200, train nll: 467.395629883, mse: 8.67010974884, local kl: 0.0 global kl: 0.170949220657 valid nll: 1533.81762695, mse: 16.6785640717, local kl: 0.0 global kl: 0.224763989449
it: 6250, train nll: 9587.65820312, mse: 13.6195716858, local kl: 0.0 global kl: 0.0363052077591 valid nll: 1997.39038086, mse: 22.0066642761, local kl: 0.0 global kl: 0.14998999238
it: 6300, train nll: 13359.2529297, mse: 20.4335842133, local kl: 0.0 global kl: 0.0344087593257 valid nll: 135699.78125, mse: 18.0063819885, local kl: 0.0 global kl: 0.116489648819
it: 6350, train nll: 1412.41137695, mse: 6.48348379135, local kl: 0.0 global kl: 0.0528158061206 valid nll: 6316.9453125, mse: 19.0712928772, local kl: 0.0 global kl: 0.231977343559
it: 6400, train nll: 9013.58691406, mse: 19.903339386, local kl: 0.0 global kl: 0.0354501716793 valid nll: 109683.90625, mse: 14.2014970779, local kl: 0.0 global kl: 0.0879608616233
it: 6450, train nll: 631114.625, mse: 20.3567256927, local kl: 0.0 global kl: 0.0418804846704 valid nll: 2184390.25, mse: 23.6429634094, local kl: 0.0 global kl: 0.0727515369654
it: 6500, train nll: 29851.0917969, mse: 15.7776947021, local kl: 0.0 global kl: 0.0302359405905 valid nll: 951439.0, mse: 25.1213722229, local kl: 0.0 global kl: 0.152031511068
it: 6550, train nll: 658192.25, mse: 24.1463508606, local kl: 0.0 global kl: 0.103060148656 valid nll: 3702127.0, mse: 26.5307846069, local kl: 0.0 global kl: 0.309918701649
it: 6600, train nll: 4995.40966797, mse: 5.39249801636, local kl: 0.0 global kl: 0.0615633241832 valid nll: 31078.328125, mse: 22.570224762, local kl: 0.0 global kl: 0.377029746771
it: 6650, train nll: 1802969.5, mse: 18.2160701752, local kl: 0.0 global kl: 0.0325171686709 valid nll: 105127.960938, mse: 23.1363334656, local kl: 0.0 global kl: 0.122711338103
it: 6700, train nll: 42211.6367188, mse: 18.2319927216, local kl: 0.0 global kl: 0.0325404889882 valid nll: 37298.0117188, mse: 23.7340126038, local kl: 0.0 global kl: 0.105163455009
it: 6750, train nll: 34341.1875, mse: 8.81659603119, local kl: 0.0 global kl: 0.053913615644 valid nll: 20985.6523438, mse: 18.1665897369, local kl: 0.0 global kl: 0.203858226538
it: 6800, train nll: 934.164978027, mse: 30.7336883545, local kl: 0.0 global kl: 0.273960500956 valid nll: 2077049.5, mse: 27.9662456512, local kl: 0.0 global kl: 0.254088282585
it: 6850, train nll: 1261.98596191, mse: 23.8556308746, local kl: 0.0 global kl: 0.105684414506 valid nll: 404.570617676, mse: 30.0735626221, local kl: 0.0 global kl: 0.190547257662
it: 6900, train nll: 902654.0, mse: 6.36376094818, local kl: 0.0 global kl: 0.090479157865 valid nll: 839.419616699, mse: 13.9310379028, local kl: 0.0 global kl: 0.136617556214
it: 6950, train nll: 160.469329834, mse: 27.5663318634, local kl: 0.0 global kl: 0.157789543271 valid nll: 429.332763672, mse: 10.5130434036, local kl: 0.0 global kl: 0.0630095005035
it: 7000, train nll: 6962.73388672, mse: 22.4663734436, local kl: 0.0 global kl: 0.06052390486 valid nll: 3055.46875, mse: 32.1362838745, local kl: 0.0 global kl: 0.391781330109
it: 7050, train nll: 313.399230957, mse: 6.2939491272, local kl: 0.0 global kl: 0.0625967979431 valid nll: 6399.10839844, mse: 12.3227376938, local kl: 0.0 global kl: 0.368358135223
it: 7100, train nll: 13255.0878906, mse: 17.916841507, local kl: 0.0 global kl: 0.0278756432235 valid nll: 16269.0380859, mse: 18.6936206818, local kl: 0.0 global kl: 0.278957098722
it: 7150, train nll: 16713.2324219, mse: 17.6528625488, local kl: 0.0 global kl: 0.0207240097225 valid nll: 265804.53125, mse: 19.86236763, local kl: 0.0 global kl: 0.246056348085
it: 7200, train nll: 78139.3359375, mse: 22.9784946442, local kl: 0.0 global kl: 0.0539334788918 valid nll: 979674.1875, mse: 26.7743644714, local kl: 0.0 global kl: 0.108745232224
it: 7250, train nll: 364.159667969, mse: 10.0845994949, local kl: 0.0 global kl: 0.0277416165918 valid nll: 23120.7480469, mse: 38.0313224792, local kl: 0.0 global kl: 0.0436664298177
it: 7300, train nll: 2024125.625, mse: 14.9559764862, local kl: 0.0 global kl: 0.0802547931671 valid nll: 1842662.75, mse: 21.4315776825, local kl: 0.0 global kl: 0.187949299812
it: 7350, train nll: 6849.51806641, mse: 15.1358242035, local kl: 0.0 global kl: 0.0336205251515 valid nll: 2389500.25, mse: 15.5526351929, local kl: 0.0 global kl: 0.103199258447
it: 7400, train nll: 21196.5195312, mse: 31.058298111, local kl: 0.0 global kl: 0.252956718206 valid nll: 2891203.25, mse: 18.7614650726, local kl: 0.0 global kl: 0.0900831818581
it: 7450, train nll: 2503.78100586, mse: 23.7704219818, local kl: 0.0 global kl: 0.056113421917 valid nll: 1725508.625, mse: 24.8692893982, local kl: 0.0 global kl: 0.105130836368
it: 7500, train nll: 3792.39233398, mse: 9.73682785034, local kl: 0.0 global kl: 0.0403264313936 valid nll: 10239.8339844, mse: 21.7587776184, local kl: 0.0 global kl: 0.141787618399
it: 7550, train nll: 13075214.0, mse: 17.9838352203, local kl: 0.0 global kl: 0.0163481645286 valid nll: 2300823.75, mse: 22.0383739471, local kl: 0.0 global kl: 0.0764746516943
it: 7600, train nll: 4004424.75, mse: 18.9393634796, local kl: 0.0 global kl: 0.0154679547995 valid nll: 46573880.0, mse: 56.4321556091, local kl: 0.0 global kl: 0.0125160012394
it: 7650, train nll: 2491779.75, mse: 16.4662475586, local kl: 0.0 global kl: 0.115334652364 valid nll: 14804643.0, mse: 30.0032558441, local kl: 0.0 global kl: 0.304028689861
it: 7700, train nll: 2277959.25, mse: 15.6593313217, local kl: 0.0 global kl: 0.0339917317033 valid nll: 14977404.0, mse: 21.4159011841, local kl: 0.0 global kl: 0.105074599385
it: 7750, train nll: 2381949.0, mse: 24.3228626251, local kl: 0.0 global kl: 0.0662150382996 valid nll: 2283072.0, mse: 36.4692573547, local kl: 0.0 global kl: 0.340666085482
it: 7800, train nll: 2380547.25, mse: 8.82314968109, local kl: 0.0 global kl: 0.0630853697658 valid nll: 45824288.0, mse: 35.2816352844, local kl: 0.0 global kl: 0.366210728884
it: 7850, train nll: 2319417.5, mse: 14.9442224503, local kl: 0.0 global kl: 0.0348273292184 valid nll: 2488253.5, mse: 25.013053894, local kl: 0.0 global kl: 0.097087867558
it: 7900, train nll: 3009620.5, mse: 7.58460378647, local kl: 0.0 global kl: 0.0179281011224 valid nll: 21408058.0, mse: 18.5392742157, local kl: 0.0 global kl: 0.208560273051
it: 7950, train nll: 2351123.75, mse: 9.97539520264, local kl: 0.0 global kl: 0.0527716204524 valid nll: 2927881.25, mse: 18.4111061096, local kl: 0.0 global kl: 0.450086116791
it: 8000, train nll: 2547210.25, mse: 11.0298452377, local kl: 0.0 global kl: 0.101192951202 valid nll: 5691472.0, mse: 27.6529903412, local kl: 0.0 global kl: 0.234378427267
it: 8050, train nll: 2380733.25, mse: 12.5153503418, local kl: 0.0 global kl: 0.0387340895832 valid nll: 28529758.0, mse: 22.0554637909, local kl: 0.0 global kl: 0.185115545988
it: 8100, train nll: 2695642.0, mse: 8.55829429626, local kl: 0.0 global kl: 0.0294670220464 valid nll: 14749848.0, mse: 20.5022830963, local kl: 0.0 global kl: 0.103909030557
it: 8150, train nll: 15132498.0, mse: 19.0595321655, local kl: 0.0 global kl: 0.103580772877 valid nll: 14839877.0, mse: 25.1956768036, local kl: 0.0 global kl: 0.228136271238
it: 8200, train nll: 2531444.0, mse: 20.5781116486, local kl: 0.0 global kl: 0.0274085700512 valid nll: 29189558.0, mse: 33.052066803, local kl: 0.0 global kl: 0.198031663895
it: 8250, train nll: 12719537.0, mse: 13.9667606354, local kl: 0.0 global kl: 0.0219631977379 valid nll: 200132560.0, mse: 95.8951263428, local kl: 0.0 global kl: 0.105412684381
it: 8300, train nll: 150275.265625, mse: 17.0385799408, local kl: 0.0 global kl: 0.0549528114498 valid nll: 14650820.0, mse: 16.6947498322, local kl: 0.0 global kl: 0.0623693354428
it: 8350, train nll: 1662.29248047, mse: 12.6486377716, local kl: 0.0 global kl: 0.0655858963728 valid nll: 36643.125, mse: 11.8394346237, local kl: 0.0 global kl: 0.130972385406
it: 8400, train nll: 25279.0683594, mse: 30.1570663452, local kl: 0.0 global kl: 0.140379920602 valid nll: 474818.25, mse: 17.5504875183, local kl: 0.0 global kl: 0.0992715880275
it: 8450, train nll: 11050292.0, mse: 26.9032516479, local kl: 0.0 global kl: 0.100783765316 valid nll: 19091024.0, mse: 32.7924385071, local kl: 0.0 global kl: 0.10952784121
it: 8500, train nll: 1746.33300781, mse: 6.70943546295, local kl: 0.0 global kl: 0.0306603908539 valid nll: 3529693.75, mse: 15.3397140503, local kl: 0.0 global kl: 0.101229906082
it: 8550, train nll: 34221.5742188, mse: 13.9174642563, local kl: 0.0 global kl: 0.0484689883888 valid nll: 4248662.5, mse: 14.5163021088, local kl: 0.0 global kl: 0.119409367442
it: 8600, train nll: 1505493.5, mse: 10.7800483704, local kl: 0.0 global kl: 0.0648574978113 valid nll: 16592002.0, mse: 32.5623474121, local kl: 0.0 global kl: 0.108749531209
it: 8650, train nll: 881163.6875, mse: 8.80975437164, local kl: 0.0 global kl: 0.0635388866067 valid nll: 3446437.5, mse: 20.5146808624, local kl: 0.0 global kl: 0.0847040563822
it: 8700, train nll: 13344.2412109, mse: 12.7824668884, local kl: 0.0 global kl: 0.158439666033 valid nll: 3423254.75, mse: 17.2902793884, local kl: 0.0 global kl: 0.106349825859
it: 8750, train nll: 95136.7421875, mse: 14.4196329117, local kl: 0.0 global kl: 0.0910966545343 valid nll: 478026.71875, mse: 13.798283577, local kl: 0.0 global kl: 0.186292380095
it: 8800, train nll: 2008942.125, mse: 13.9394798279, local kl: 0.0 global kl: 0.0495833940804 valid nll: 6592732.5, mse: 21.7443161011, local kl: 0.0 global kl: 0.0708025321364
it: 8850, train nll: 566091.8125, mse: 13.6820402145, local kl: 0.0 global kl: 0.0607424005866 valid nll: 18693254.0, mse: 25.565114975, local kl: 0.0 global kl: 0.0985813215375
it: 8900, train nll: 34005.5234375, mse: 13.6273097992, local kl: 0.0 global kl: 0.0267247259617 valid nll: 14804383.0, mse: 14.6500120163, local kl: 0.0 global kl: 0.0746844187379
it: 8950, train nll: 14233465.0, mse: 26.3350143433, local kl: 0.0 global kl: 0.0385752730072 valid nll: 5060774.0, mse: 10.2200651169, local kl: 0.0 global kl: 0.0908330529928
Saving best model with MSE 10.220065
it: 9000, train nll: 3042041.0, mse: 10.2526855469, local kl: 0.0 global kl: 0.0621616765857 valid nll: 4486588.5, mse: 25.5089664459, local kl: 0.0 global kl: 0.111813798547
it: 9050, train nll: 127954.976562, mse: 22.5594215393, local kl: 0.0 global kl: 0.0401459932327 valid nll: 611434.3125, mse: 35.1765785217, local kl: 0.0 global kl: 0.142476916313
it: 9100, train nll: 402855.78125, mse: 8.87015914917, local kl: 0.0 global kl: 0.0297452453524 valid nll: 2837077.5, mse: 15.6829547882, local kl: 0.0 global kl: 0.183349862695
it: 9150, train nll: 7185562.5, mse: 27.6995429993, local kl: 0.0 global kl: 0.21633644402 valid nll: 8061281.5, mse: 15.9937362671, local kl: 0.0 global kl: 0.217425197363
it: 9200, train nll: 4235051.0, mse: 11.3300352097, local kl: 0.0 global kl: 0.012836557813 valid nll: 10758486.0, mse: 24.1349105835, local kl: 0.0 global kl: 0.213553518057
it: 9250, train nll: 13699389.0, mse: 10.9829435349, local kl: 0.0 global kl: 0.041811350733 valid nll: 8244169.5, mse: 19.4873523712, local kl: 0.0 global kl: 0.203503966331
it: 9300, train nll: 3307581.5, mse: 20.1069374084, local kl: 0.0 global kl: 0.0994302779436 valid nll: 6307099.0, mse: 15.265381813, local kl: 0.0 global kl: 0.248556375504
it: 9350, train nll: 1934851.75, mse: 10.6064062119, local kl: 0.0 global kl: 0.109821215272 valid nll: 12540639.0, mse: 29.4972019196, local kl: 0.0 global kl: 0.275861918926
it: 9400, train nll: 1376417.375, mse: 7.36630964279, local kl: 0.0 global kl: 0.0360636897385 valid nll: 4919950.5, mse: 15.4620580673, local kl: 0.0 global kl: 0.270707070827
it: 9450, train nll: 10949536.0, mse: 26.1859512329, local kl: 0.0 global kl: 0.0479173026979 valid nll: 5909456.5, mse: 17.1117801666, local kl: 0.0 global kl: 0.222762912512
it: 9500, train nll: 12109905.0, mse: 19.2543907166, local kl: 0.0 global kl: 0.144899740815 valid nll: 22467766.0, mse: 11.1016550064, local kl: 0.0 global kl: 0.232563972473
it: 9550, train nll: 652189.4375, mse: 16.4946784973, local kl: 0.0 global kl: 0.0468709841371 valid nll: 25719274.0, mse: 19.0577411652, local kl: 0.0 global kl: 0.183100923896
it: 9600, train nll: 296760.96875, mse: 13.1121397018, local kl: 0.0 global kl: 0.0223705861717 valid nll: 16396853.0, mse: 12.7314510345, local kl: 0.0 global kl: 0.150465458632
it: 9650, train nll: 114294.320312, mse: 23.3783073425, local kl: 0.0 global kl: 0.113870181143 valid nll: 15564239.0, mse: 29.4932956696, local kl: 0.0 global kl: 0.169445574284
it: 9700, train nll: 1028978.75, mse: 16.7677192688, local kl: 0.0 global kl: 0.0435428693891 valid nll: 2579786.0, mse: 19.1349220276, local kl: 0.0 global kl: 0.224936932325
it: 9750, train nll: 2387513.25, mse: 10.4229059219, local kl: 0.0 global kl: 0.0554197058082 valid nll: 1001522.1875, mse: 28.9705696106, local kl: 0.0 global kl: 0.218313485384
it: 9800, train nll: 7283546.0, mse: 12.2872781754, local kl: 0.0 global kl: 0.0440026000142 valid nll: 14838749.0, mse: 24.9554977417, local kl: 0.0 global kl: 0.204919531941
it: 9850, train nll: 18010100.0, mse: 16.4089355469, local kl: 0.0 global kl: 0.0536355562508 valid nll: 8611419.0, mse: 21.3106327057, local kl: 0.0 global kl: 0.302130401134
it: 9900, train nll: 937147.5, mse: 15.6648712158, local kl: 0.0 global kl: 0.0998610258102 valid nll: 3389211.0, mse: 20.5598011017, local kl: 0.0 global kl: 0.331852227449
it: 9950, train nll: 455042.875, mse: 10.5701332092, local kl: 0.0 global kl: 0.0215912815183 valid nll: 5805827.0, mse: 8.51126861572, local kl: 0.0 global kl: 0.0895229354501
Saving best model with MSE 8.511269

Posterior Predictive + Freeform


In [0]:
# Hyperparameters for the "posterior predictive + freeform" run.
uncertainty_type = 'attentive_freeform'
# The "Prior Predictive + Freeform" cell earlier in the notebook sets this
# flag to False; this section enables it.
local_variational = True

# Model configuration. The layer-size lists and attention settings are the
# shared values defined in the configuration cell near the top of the notebook.
model_hparams = tf.contrib.training.HParams(
    activation=tf.nn.relu,
    output_activation=tf.nn.relu,
    x_encoder_sizes=x_encoder_sizes,
    x_y_encoder_sizes=x_y_encoder_sizes,
    global_latent_net_sizes=global_latent_net_sizes,
    local_latent_net_sizes=local_latent_net_sizes,
    heteroskedastic_net_sizes=heteroskedastic_net_sizes,
    uncertainty_type=uncertainty_type,
    att_type=att_type,
    att_heads=att_heads,
    mean_att_type=mean_att_type,
    scale_att_type_1=scale_att_type_1,
    scale_att_type_2=scale_att_type_2,
    data_uncertainty=data_uncertainty,
    local_variational=local_variational,
)

# Checkpoint file for this run, placed under the shared `savedir`.
save_path = os.path.join(savedir, 'best_posterior_freeform_mse_unclipped.ckpt')

# Optimisation settings handed to `train`.
# NOTE(review): max_grad_norm=1000.0 presumably leaves gradient clipping
# effectively inactive (cf. "unclipped" in the checkpoint name) -- confirm
# against the `train` implementation.
training_hparams = tf.contrib.training.HParams(
    lr=0.01,
    optimizer=tf.train.RMSPropOptimizer,
    num_iterations=10000,
    batch_size=10,
    num_context=num_context,
    num_target=num_target,
    print_every=50,
    save_path=save_path,
    max_grad_norm=1000.0,
)

In [0]:
# Launch training using the data, model and training hyperparameters built above.
train(data_hparams, model_hparams, training_hparams)


it: 0, train nll: 291.345306396, mse: 327.003387451, local kl: 0.0251190047711 global kl: 8.42864028527e-05 valid nll: 446.904907227, mse: 410.268127441, local kl: 0.0530770793557 global kl: 0.00013029173715
Saving best model with MSE 410.26813
it: 50, train nll: 11.5967645645, mse: 131.161407471, local kl: 0.0747972652316 global kl: 1.34051947498e-06 valid nll: 8.17937660217, mse: 128.501358032, local kl: 0.0416157543659 global kl: 1.41697219078e-05
Saving best model with MSE 128.50136
it: 100, train nll: 39.7724533081, mse: 252.00189209, local kl: 0.0145239755511 global kl: 34902859776.0 valid nll: 20.704328537, mse: 202.928283691, local kl: 0.0488733649254 global kl: 9.995731034e-07
it: 150, train nll: 3.64748597145, mse: 105.372077942, local kl: 0.0107508469373 global kl: 0.0003227260313 valid nll: 3.00736522675, mse: 94.4865570068, local kl: 0.0274358782917 global kl: 3.69362642232e-05
Saving best model with MSE 94.48656
it: 200, train nll: 2.67741012573, mse: 80.642074585, local kl: 0.00961234234273 global kl: 3.98564661737e-05 valid nll: 3.64317774773, mse: 105.332954407, local kl: 0.00791214127094 global kl: 0.000248979602475
it: 250, train nll: 2.60094356537, mse: 65.8145065308, local kl: 0.00571423815563 global kl: 0.000385096704122 valid nll: 2.92429161072, mse: 69.0826416016, local kl: 0.00229561864398 global kl: 0.000332172261551
Saving best model with MSE 69.08264
it: 300, train nll: 2.51189446449, mse: 50.0561943054, local kl: 0.00446181977168 global kl: 0.000299677660223 valid nll: 3.19988584518, mse: 81.19165802, local kl: 0.00478346832097 global kl: 0.000881981686689
it: 350, train nll: 2.6372487545, mse: 60.708782196, local kl: 0.0043863854371 global kl: 0.00047045154497 valid nll: 3.12556695938, mse: 96.4865264893, local kl: 0.0021287270356 global kl: 0.000280496227788
it: 400, train nll: 3.4186372757, mse: 61.1954650879, local kl: 0.00151869817637 global kl: 0.00273065734655 valid nll: 2.95010519028, mse: 54.8949890137, local kl: 0.00258051604033 global kl: 0.000149737330503
Saving best model with MSE 54.89499
it: 450, train nll: 2.65680146217, mse: 31.7283992767, local kl: 0.004473506473 global kl: 0.00141486572102 valid nll: 2.76946759224, mse: 34.9076194763, local kl: 0.00779176922515 global kl: 0.000532339152414
Saving best model with MSE 34.90762
it: 500, train nll: 2.84436249733, mse: 28.8289546967, local kl: 0.000912165676709 global kl: 0.000209362522583 valid nll: 3.04574298859, mse: 55.667881012, local kl: 0.00377391534857 global kl: 0.000918705947697
it: 550, train nll: 2.94405555725, mse: 35.8628044128, local kl: 0.00296138087288 global kl: 0.000571158132516 valid nll: 2.9929857254, mse: 30.1273708344, local kl: 0.00378059525974 global kl: 0.000174187356606
Saving best model with MSE 30.12737
it: 600, train nll: 3.05358862877, mse: 81.316734314, local kl: 0.00119542051107 global kl: 0.00357438321225 valid nll: 2.97930264473, mse: 64.6176452637, local kl: 0.00134507776238 global kl: 5.94359116803e-05
it: 650, train nll: 2.8011841774, mse: 36.6520881653, local kl: 0.00110211607534 global kl: 0.000202487921342 valid nll: 2.92088246346, mse: 33.9958267212, local kl: 0.0014550111955 global kl: 0.000242142705247
it: 700, train nll: 2.82405519485, mse: 25.8108768463, local kl: 0.000960217905231 global kl: 0.00104307197034 valid nll: 3.23678398132, mse: 62.3176803589, local kl: 0.0132199507207 global kl: 0.000191384664504
it: 750, train nll: 2.81659126282, mse: 53.8953056335, local kl: 0.00138731813058 global kl: 0.00221583852544 valid nll: 3.48421144485, mse: 99.7099609375, local kl: 0.0137205021456 global kl: 0.00045604351908
it: 800, train nll: 2.61663913727, mse: 16.0111789703, local kl: 0.000823241716716 global kl: 0.000843234243803 valid nll: 2.76911878586, mse: 32.9360237122, local kl: 0.000820125162136 global kl: 0.000425020552939
it: 850, train nll: 2.96369409561, mse: 55.8621253967, local kl: 0.0025464752689 global kl: 0.00290138483979 valid nll: 2.85267829895, mse: 38.6920814514, local kl: 0.00134498532861 global kl: 0.000878678984009
it: 900, train nll: 2.5163731575, mse: 27.7878189087, local kl: 0.000996504211798 global kl: 0.00891763158143 valid nll: 2.6389811039, mse: 32.0222091675, local kl: 0.00559860887006 global kl: 0.000343302439433
it: 950, train nll: 2.62014341354, mse: 40.3008537292, local kl: 0.00287264352664 global kl: 0.000437197915744 valid nll: 2.76213955879, mse: 51.5242500305, local kl: 0.00417012581602 global kl: 0.000486800214276
it: 1000, train nll: 2.74395418167, mse: 21.0181674957, local kl: 0.000865022942889 global kl: 0.00128460000269 valid nll: 3.10299420357, mse: 38.8220939636, local kl: 0.00538825197145 global kl: 0.000470564060379
it: 1050, train nll: 2.88940191269, mse: 62.7669372559, local kl: 0.00168646953534 global kl: 0.000315257406328 valid nll: 3.00718450546, mse: 36.8663139343, local kl: 0.00224413280375 global kl: 0.000285701826215
it: 1100, train nll: 2.35056734085, mse: 22.4295978546, local kl: 0.00169741315767 global kl: 0.000670440786052 valid nll: 2.37664556503, mse: 30.3299312592, local kl: 0.0066316020675 global kl: 0.000337898294674
it: 1150, train nll: 2.96894478798, mse: 26.6173038483, local kl: 0.00773007096723 global kl: 0.000673171249218 valid nll: 2.37721705437, mse: 29.3483123779, local kl: 0.00354988547042 global kl: 0.000122638215544
Saving best model with MSE 29.348312
it: 1200, train nll: 2.3450191021, mse: 20.9891891479, local kl: 0.00198826775886 global kl: 0.000850608572364 valid nll: 2.39182734489, mse: 27.5819568634, local kl: 0.0361358746886 global kl: 0.00025385862682
Saving best model with MSE 27.581957
it: 1250, train nll: 2.40282917023, mse: 52.2122077942, local kl: 0.000284214562271 global kl: 0.00031949422555 valid nll: 2.39294886589, mse: 53.7004623413, local kl: 0.000701526878402 global kl: 9.25885542529e-05
it: 1300, train nll: 2.35534143448, mse: 39.7444000244, local kl: 0.000547896779608 global kl: 0.00150490598753 valid nll: 2.75899744034, mse: 43.2087936401, local kl: 0.00223369593732 global kl: 0.000620151229668
it: 1350, train nll: 2.79547691345, mse: 30.8212203979, local kl: 0.000955872121267 global kl: 0.00765049597248 valid nll: 2.54061055183, mse: 23.6779155731, local kl: 0.00848384015262 global kl: 0.000654832809232
Saving best model with MSE 23.677916
it: 1400, train nll: 2.60293459892, mse: 28.7342815399, local kl: 0.00385494320653 global kl: 0.011420735158 valid nll: 2.39216923714, mse: 22.2636127472, local kl: 0.00889316480607 global kl: 0.0022947229445
Saving best model with MSE 22.263613
it: 1450, train nll: 2.59702944756, mse: 43.5488433838, local kl: 0.0114409960806 global kl: 0.00297050341032 valid nll: 2.40495800972, mse: 17.5143451691, local kl: 0.00588196842 global kl: 0.00103218609001
Saving best model with MSE 17.514345
it: 1500, train nll: 2.30496740341, mse: 12.8915700912, local kl: 0.0008739453624 global kl: 0.00278417696245 valid nll: 2.40698480606, mse: 24.785900116, local kl: 0.00214071036316 global kl: 0.00175092217978
it: 1550, train nll: 2.33063817024, mse: 14.4761781693, local kl: 0.00266512017697 global kl: 0.000863075081725 valid nll: 2.44622206688, mse: 23.9642734528, local kl: 0.00636985432357 global kl: 0.00151938886847
it: 1600, train nll: 2.59053373337, mse: 19.1136360168, local kl: 0.000674323178828 global kl: 0.0026238122955 valid nll: 2.70743441582, mse: 25.2774925232, local kl: 0.0525820143521 global kl: 0.00129474943969
it: 1650, train nll: 2.34118509293, mse: 11.5763092041, local kl: 0.00152996671386 global kl: 0.00159648060799 valid nll: 2.99649572372, mse: 29.9945144653, local kl: 0.00108032592107 global kl: 0.000111742898298
it: 1700, train nll: 2.4707403183, mse: 13.9622888565, local kl: 0.0169810391963 global kl: 0.000169004837517 valid nll: 2.42957234383, mse: 15.1448793411, local kl: 0.00190440402366 global kl: 0.000486245931825
Saving best model with MSE 15.144879
it: 1750, train nll: 2.43301296234, mse: 24.5653247833, local kl: 0.00143469648901 global kl: 0.00172938371543 valid nll: 2.44495081902, mse: 29.7146816254, local kl: 0.00166551442817 global kl: 0.000439729134087
it: 1800, train nll: 2.78773188591, mse: 35.8248939514, local kl: 0.00254252669401 global kl: 0.00749176833779 valid nll: 2.68622016907, mse: 33.3149185181, local kl: 0.0027628459502 global kl: 0.00763872917742
it: 1850, train nll: 2.3675699234, mse: 16.2261333466, local kl: 0.00240637501702 global kl: 0.00158184673637 valid nll: 2.45278859138, mse: 18.8933639526, local kl: 0.0118119902909 global kl: 0.00473066512495
it: 1900, train nll: 2.31531095505, mse: 16.8307342529, local kl: 0.00104976352304 global kl: 0.00435461103916 valid nll: 2.49889492989, mse: 28.0936717987, local kl: 0.00582839362323 global kl: 0.00142832437996
it: 1950, train nll: 2.47364163399, mse: 27.7578258514, local kl: 0.00517945829779 global kl: 0.00231808098033 valid nll: 2.50075960159, mse: 23.1265583038, local kl: 0.0177239347249 global kl: 0.00392420683056
it: 2000, train nll: 2.57717990875, mse: 14.7230195999, local kl: 0.00375575036742 global kl: 0.00791614037007 valid nll: 2.60651350021, mse: 34.9138221741, local kl: 0.00656816270202 global kl: 0.00438077747822
it: 2050, train nll: 2.42046284676, mse: 8.56861400604, local kl: 0.00159290246665 global kl: 0.00143575575203 valid nll: 2.49306559563, mse: 15.4081106186, local kl: 0.0700053572655 global kl: 0.00125039182603
it: 2100, train nll: 2.46702647209, mse: 20.6496620178, local kl: 0.0027049286291 global kl: 0.0150745706633 valid nll: 2.50151491165, mse: 15.0660715103, local kl: 0.0111338123679 global kl: 0.0173133518547
Saving best model with MSE 15.0660715
it: 2150, train nll: 2.52846431732, mse: 29.6249370575, local kl: 0.00588985346258 global kl: 0.00825631152838 valid nll: 2.56753087044, mse: 30.0506134033, local kl: 0.0105274859816 global kl: 0.0112245483324
it: 2200, train nll: 2.47795724869, mse: 19.3266563416, local kl: 0.0196130704135 global kl: 0.00146511010826 valid nll: 2.54358649254, mse: 23.6669464111, local kl: 0.00208494416438 global kl: 0.0134705081582
it: 2250, train nll: 2.58073735237, mse: 14.650935173, local kl: 0.00149685877841 global kl: 0.00284803682007 valid nll: 2.53010392189, mse: 22.2828521729, local kl: 0.0054753809236 global kl: 0.0114807123318
it: 2300, train nll: 2.40911078453, mse: 14.2861509323, local kl: 0.00363506516442 global kl: 0.00813251826912 valid nll: 2.52348399162, mse: 24.511390686, local kl: 0.0111244414002 global kl: 0.00341435451992
it: 2350, train nll: 2.94738411903, mse: 23.3691196442, local kl: 0.00793772190809 global kl: 0.0136693445966 valid nll: 2.62140727043, mse: 29.2572479248, local kl: 0.00369190936908 global kl: 0.0744479000568
it: 2400, train nll: 2.46762490273, mse: 11.3510951996, local kl: 0.00488542951643 global kl: 0.00289469724521 valid nll: 2.54398775101, mse: 17.9919242859, local kl: 0.00761729991063 global kl: 0.00805459823459
it: 2450, train nll: 2.37919569016, mse: 13.9371347427, local kl: 0.0086047174409 global kl: 0.00405557872728 valid nll: 2.55813241005, mse: 22.0309772491, local kl: 0.0241238307208 global kl: 0.00446869153529
it: 2500, train nll: 2.47219848633, mse: 20.5306930542, local kl: 0.00448359781876 global kl: 0.00466383853927 valid nll: 2.53869009018, mse: 17.0420646667, local kl: 0.00533671444282 global kl: 0.00649101566523
it: 2550, train nll: 2.4729847908, mse: 5.40275812149, local kl: 0.00194074434694 global kl: 0.0030366149731 valid nll: 2.51541209221, mse: 13.4940099716, local kl: 0.0107161104679 global kl: 0.00823755934834
Saving best model with MSE 13.49401
it: 2600, train nll: 2.49932336807, mse: 18.2961807251, local kl: 0.00779443187639 global kl: 0.00272206170484 valid nll: 2.560577631, mse: 27.795879364, local kl: 0.00875684339553 global kl: 0.00407565245405
it: 2650, train nll: 2.50320482254, mse: 27.2781505585, local kl: 0.143328353763 global kl: 0.0102568631992 valid nll: 2.57410216331, mse: 31.1093769073, local kl: 0.00362401106395 global kl: 0.0182110182941
it: 2700, train nll: 2.62189912796, mse: 21.9468975067, local kl: 0.0132144102827 global kl: 0.00678028538823 valid nll: 2.57898664474, mse: 21.0495166779, local kl: 0.00798303354532 global kl: 0.00652960920706
it: 2750, train nll: 2.46993160248, mse: 9.72295665741, local kl: 0.00478815101087 global kl: 0.0341749116778 valid nll: 2.57926893234, mse: 19.4066047668, local kl: 0.0211440417916 global kl: 0.0513148903847
it: 2800, train nll: 2.54162931442, mse: 16.9489421844, local kl: 0.00751974526793 global kl: 0.189293652773 valid nll: 2.61188697815, mse: 30.8014354706, local kl: 0.0322404280305 global kl: 0.0645402818918
it: 2850, train nll: 2.574644804, mse: 16.7780799866, local kl: 0.00417951447889 global kl: 0.00206483714283 valid nll: 2.60175657272, mse: 22.1192073822, local kl: 0.0167703982443 global kl: 0.00830320548266
it: 2900, train nll: 2.51973199844, mse: 4.31895780563, local kl: 0.0125242741778 global kl: 0.0254504680634 valid nll: 2.73805332184, mse: 18.1748771667, local kl: 0.0668108686805 global kl: 0.0245495848358
it: 2950, train nll: 2.54754734039, mse: 12.0817289352, local kl: 0.0168859474361 global kl: 0.0155866695568 valid nll: 2.65614819527, mse: 25.892572403, local kl: 0.0264678299427 global kl: 0.0130820367485
it: 3000, train nll: 2.57695245743, mse: 25.3093605042, local kl: 0.0429888926446 global kl: 0.0136408079416 valid nll: 2.63911366463, mse: 29.0673465729, local kl: 0.0342177040875 global kl: 0.00690666213632
it: 3050, train nll: 2.65172028542, mse: 12.5806436539, local kl: 0.0496012978256 global kl: 0.0123634198681 valid nll: 2.77756881714, mse: 12.1680335999, local kl: 0.0192226003855 global kl: 0.00943814124912
Saving best model with MSE 12.168034
it: 3100, train nll: 2.46443414688, mse: 12.4936132431, local kl: 0.00270610069856 global kl: 0.0156515687704 valid nll: 2.62835788727, mse: 26.0794258118, local kl: 0.0152306333184 global kl: 0.0217686649412
it: 3150, train nll: 2.53963637352, mse: 15.9298334122, local kl: 0.00328033906408 global kl: 0.02127783373 valid nll: 2.61653757095, mse: 17.5342082977, local kl: 0.0075105256401 global kl: 0.0634019598365
it: 3200, train nll: 2.48110628128, mse: 15.9902973175, local kl: 0.00433429935947 global kl: 0.0100740455091 valid nll: 2.62114572525, mse: 17.5896530151, local kl: 0.0810562372208 global kl: 0.01577960141
it: 3250, train nll: 2.70260667801, mse: 18.7855091095, local kl: 0.000937557430007 global kl: 0.0304933041334 valid nll: 2.69343113899, mse: 25.3397579193, local kl: 0.079218506813 global kl: 0.108063593507
it: 3300, train nll: 2.65449976921, mse: 19.7170696259, local kl: 0.0586947090924 global kl: 0.0258237831295 valid nll: 2.62198424339, mse: 21.4073581696, local kl: 0.0250332225114 global kl: 0.0276337657124
it: 3350, train nll: 2.59044218063, mse: 14.7564220428, local kl: 0.0146230384707 global kl: 0.00763814523816 valid nll: 2.65125584602, mse: 25.5037879944, local kl: 0.0299689359963 global kl: 0.0384705811739
it: 3400, train nll: 2.61518025398, mse: 12.7418813705, local kl: 0.0197584014386 global kl: 0.0278146024793 valid nll: 2.82124686241, mse: 23.7298202515, local kl: 0.0734874010086 global kl: 0.00140420696698
it: 3450, train nll: 2.63219952583, mse: 12.0720157623, local kl: 0.00188001466449 global kl: 0.00670815864578 valid nll: 2.71564888954, mse: 36.9443588257, local kl: 0.0170368943363 global kl: 0.00821665115654
it: 3500, train nll: 2.61087560654, mse: 13.4469766617, local kl: 0.00954456999898 global kl: 0.011909856461 valid nll: 2.92645263672, mse: 25.2701854706, local kl: 0.0174440871924 global kl: 0.0591816417873
it: 3550, train nll: 2.59580254555, mse: 13.0923061371, local kl: 0.0103477723897 global kl: 0.0296295490116 valid nll: 2.68735551834, mse: 22.5811958313, local kl: 0.00378346187063 global kl: 0.0531838536263
it: 3600, train nll: 2.6458799839, mse: 22.251247406, local kl: 0.0462309904397 global kl: 0.0140952123329 valid nll: 2.68787002563, mse: 26.8065338135, local kl: 0.0452044196427 global kl: 0.00725852465257
it: 3650, train nll: 2.64064836502, mse: 18.7274074554, local kl: 0.072366528213 global kl: 0.0541830174625 valid nll: 2.86003780365, mse: 21.279499054, local kl: 0.0133959287778 global kl: 0.0717800408602
it: 3700, train nll: 2.61168026924, mse: 18.4725074768, local kl: 0.000758105015848 global kl: 0.00756677612662 valid nll: 2.6791908741, mse: 26.2814769745, local kl: 0.0114509621635 global kl: 0.0179509855807
it: 3750, train nll: 2.57178211212, mse: 16.2769241333, local kl: 0.00577462790534 global kl: 0.00526669342071 valid nll: 2.67859745026, mse: 25.5186595917, local kl: 0.0254149157554 global kl: 0.00856514368206
it: 3800, train nll: 2.61618065834, mse: 23.1563129425, local kl: 0.00766077218577 global kl: 0.0107753071934 valid nll: 2.76654052734, mse: 28.288854599, local kl: 0.00818346720189 global kl: 0.00210567237809
it: 3850, train nll: 2.59256291389, mse: 7.12050676346, local kl: 0.0211983267218 global kl: 0.0426607690752 valid nll: 2.71593308449, mse: 23.5755729675, local kl: 0.00840928498656 global kl: 0.0208580102772
it: 3900, train nll: 2.61349058151, mse: 11.0286893845, local kl: 0.000659076147713 global kl: 0.0216381195933 valid nll: 2.83840370178, mse: 27.2306270599, local kl: 0.00851466413587 global kl: 0.0359230712056
it: 3950, train nll: 2.59235310555, mse: 15.2357749939, local kl: 0.0024426386226 global kl: 0.0112474858761 valid nll: 2.88779830933, mse: 15.7800855637, local kl: 0.0892772674561 global kl: 0.0119901504368
it: 4000, train nll: 2.69325947762, mse: 20.8851833344, local kl: 0.0128421653062 global kl: 0.025415873155 valid nll: 2.70082259178, mse: 31.9218444824, local kl: 0.0403598546982 global kl: 0.0208986215293
it: 4050, train nll: 2.56171250343, mse: 11.4233503342, local kl: 0.00766167556867 global kl: 0.0170599147677 valid nll: 2.69381785393, mse: 18.6808757782, local kl: 0.00342918536626 global kl: 0.0694663077593
it: 4100, train nll: 2.53484678268, mse: 32.0483589172, local kl: 0.00715421559289 global kl: 0.00815795361996 valid nll: 2.73874211311, mse: 24.4477729797, local kl: 0.008594321087 global kl: 0.0535900071263
it: 4150, train nll: 2.64502334595, mse: 24.0049266815, local kl: 0.00367682450451 global kl: 0.0456834062934 valid nll: 2.8367061615, mse: 17.9727802277, local kl: 0.173340752721 global kl: 0.0542820617557
it: 4200, train nll: 2.68278574944, mse: 16.0135116577, local kl: 0.0223084315658 global kl: 0.0347361080348 valid nll: 2.70238375664, mse: 19.4282550812, local kl: 0.0157237593085 global kl: 0.285714805126
it: 4250, train nll: 2.63750195503, mse: 23.6704368591, local kl: 0.00259907217696 global kl: 0.0124472947791 valid nll: 2.70150661469, mse: 16.0474205017, local kl: 0.0135667705908 global kl: 0.025761032477
it: 4300, train nll: 2.72995519638, mse: 25.6469650269, local kl: 0.058481130749 global kl: 0.0465701557696 valid nll: 2.72795081139, mse: 25.6099300385, local kl: 0.0171713531017 global kl: 0.0747309476137
it: 4350, train nll: 2.746986866, mse: 21.0114746094, local kl: 0.00244362419471 global kl: 0.1742438972 valid nll: 2.73611521721, mse: 22.0561294556, local kl: 0.0191158466041 global kl: 0.280290514231
it: 4400, train nll: 2.79092431068, mse: 33.2612838745, local kl: 0.00476312404498 global kl: 0.0630312412977 valid nll: 2.74933028221, mse: 26.0236244202, local kl: 0.00216077384539 global kl: 0.0102644730359
it: 4450, train nll: 2.69214200974, mse: 30.176279068, local kl: 0.005426004529 global kl: 0.0280549116433 valid nll: 2.73987293243, mse: 24.0776252747, local kl: 0.0104513829574 global kl: 0.0932472199202
it: 4500, train nll: 2.67781496048, mse: 12.7429704666, local kl: 0.00304553983733 global kl: 0.0153461126611 valid nll: 2.71658563614, mse: 21.7702960968, local kl: 0.0443647392094 global kl: 0.055806118995
it: 4550, train nll: 2.71798658371, mse: 5.54773664474, local kl: 0.00674819294363 global kl: 0.00910129304975 valid nll: 2.77380347252, mse: 16.0537090302, local kl: 0.0277444496751 global kl: 0.0261196196079
it: 4600, train nll: 2.6912176609, mse: 4.06627702713, local kl: 0.000626313732937 global kl: 0.0251348502934 valid nll: 2.72169518471, mse: 18.2289142609, local kl: 0.00606571557 global kl: 0.0364216342568
it: 4650, train nll: 2.66604542732, mse: 15.0681858063, local kl: 0.00897775962949 global kl: 0.0104592721909 valid nll: 2.75544667244, mse: 26.320022583, local kl: 0.00279357819818 global kl: 0.00768982479349
it: 4700, train nll: 4.07238531113, mse: 28.4954223633, local kl: 0.000287349888822 global kl: 0.0505554489791 valid nll: 3.62909221649, mse: 17.9724884033, local kl: 0.0027002862189 global kl: 0.0512250959873
it: 4750, train nll: 2.75852489471, mse: 25.1100902557, local kl: 0.0023006182164 global kl: 0.0407937057316 valid nll: 2.7667632103, mse: 27.6326313019, local kl: 0.00238920422271 global kl: 0.03876645118
it: 4800, train nll: 2.78014540672, mse: 32.0520553589, local kl: 0.0054907463491 global kl: 0.198168128729 valid nll: 2.77545285225, mse: 25.7045001984, local kl: 0.0289721321315 global kl: 0.0174572281539
it: 4850, train nll: 2.71429181099, mse: 10.8439855576, local kl: 0.00114455469884 global kl: 0.0152880875394 valid nll: 2.76907587051, mse: 16.6321163177, local kl: 0.00407478632405 global kl: 0.0176534838974
it: 4900, train nll: 2.64708852768, mse: 14.3052053452, local kl: 0.00162832648493 global kl: 0.0145740928128 valid nll: 2.75026035309, mse: 13.0830821991, local kl: 0.00247530406341 global kl: 0.0283730719239
it: 4950, train nll: 2.70045924187, mse: 17.7586460114, local kl: 0.000640687707346 global kl: 0.0109068322927 valid nll: 2.76387119293, mse: 22.346698761, local kl: 0.0119039574638 global kl: 0.0177107136697
it: 5000, train nll: 2.58997011185, mse: 10.3800430298, local kl: 0.000905578490347 global kl: 0.0118724424392 valid nll: 2.75841593742, mse: 12.9968070984, local kl: 0.00532778073102 global kl: 0.0265851374716
it: 5050, train nll: 2.72375369072, mse: 31.5546092987, local kl: 0.00284968060441 global kl: 0.0539113990963 valid nll: 2.83723139763, mse: 24.3424625397, local kl: 0.00429217657074 global kl: 0.0210260506719
it: 5100, train nll: 2.78734922409, mse: 42.1940956116, local kl: 0.00143723306246 global kl: 0.010385915637 valid nll: 2.79947113991, mse: 33.0544548035, local kl: 0.000617799407337 global kl: 0.032013989985
it: 5150, train nll: 2.72789716721, mse: 9.40992164612, local kl: 0.00205707037821 global kl: 0.0456276461482 valid nll: 2.89206767082, mse: 16.6297264099, local kl: 0.00203389744274 global kl: 0.0356936082244
it: 5200, train nll: 2.68910503387, mse: 10.2929973602, local kl: 0.000989065272734 global kl: 0.0738404542208 valid nll: 2.76638484001, mse: 18.541683197, local kl: 0.00663538090885 global kl: 0.0668793469667
it: 5250, train nll: 2.73348331451, mse: 12.2809724808, local kl: 0.103905893862 global kl: 0.0336748324335 valid nll: 2.76973700523, mse: 14.3584432602, local kl: 0.0332577452064 global kl: 0.03463697806
it: 5300, train nll: 2.72455596924, mse: 9.00101280212, local kl: 0.00798766501248 global kl: 0.0284857209772 valid nll: 2.76737713814, mse: 15.305565834, local kl: 0.00515746790916 global kl: 0.0423748344183
it: 5350, train nll: 2.66517901421, mse: 7.81434249878, local kl: 0.00339583377354 global kl: 0.0191937610507 valid nll: 2.76371669769, mse: 11.6187601089, local kl: 0.016769753769 global kl: 0.0356229059398
Saving best model with MSE 11.61876
it: 5400, train nll: 2.88328528404, mse: 25.6117229462, local kl: 0.00197748653591 global kl: 0.0419338867068 valid nll: 2.78491640091, mse: 25.5235309601, local kl: 0.0145403370261 global kl: 0.0309488866478
it: 5450, train nll: 2.66002964973, mse: 14.3259925842, local kl: 0.00149500509724 global kl: 0.0835936665535 valid nll: 2.79185032845, mse: 19.9911403656, local kl: 0.00155710987747 global kl: 0.0413441546261
it: 5500, train nll: 2.80235052109, mse: 12.0371780396, local kl: 0.00974051561207 global kl: 0.0230333674699 valid nll: 2.77603936195, mse: 17.7604923248, local kl: 0.00573800876737 global kl: 0.023296887055
it: 5550, train nll: 4.95069646835, mse: 14.5355367661, local kl: 0.00408285437152 global kl: 0.0424922592938 valid nll: 4.96784973145, mse: 13.5038881302, local kl: 0.0283757653087 global kl: 0.0503282658756
it: 5600, train nll: 2.68896842003, mse: 21.2719688416, local kl: 0.00504843564704 global kl: 0.0415825843811 valid nll: 2.7749402523, mse: 7.16472434998, local kl: 0.0181875247508 global kl: 0.132401227951
Saving best model with MSE 7.1647243
it: 5650, train nll: 2.74232363701, mse: 12.4239711761, local kl: 0.00779293524101 global kl: 0.0376534946263 valid nll: 2.90320515633, mse: 27.5276107788, local kl: 0.00616130931303 global kl: 0.0454192832112
it: 5700, train nll: 2.78346133232, mse: 6.62168264389, local kl: 0.00237763533369 global kl: 0.0463013127446 valid nll: 2.90390443802, mse: 15.9099245071, local kl: 0.0289078671485 global kl: 0.0977253764868
it: 5750, train nll: 2.70260167122, mse: 1.9406119585, local kl: 0.00111153081525 global kl: 0.0139396395534 valid nll: 2.78104496002, mse: 8.29404449463, local kl: 0.0156631842256 global kl: 0.0411910638213
it: 5800, train nll: 2.68649148941, mse: 11.0330524445, local kl: 0.00305675319396 global kl: 0.0747916325927 valid nll: 2.80885601044, mse: 12.3182353973, local kl: 0.000357895245543 global kl: 0.0740087032318
it: 5850, train nll: 2.74264788628, mse: 11.2998094559, local kl: 0.00332026230171 global kl: 0.0753999650478 valid nll: 3.00888371468, mse: 13.4880819321, local kl: 0.00936244428158 global kl: 0.0255372859538
it: 5900, train nll: 2.76030564308, mse: 11.568312645, local kl: 0.00188742321916 global kl: 0.0523453354836 valid nll: 2.80580472946, mse: 21.5583133698, local kl: 0.00897171441466 global kl: 0.0425422862172
it: 5950, train nll: 2.80222535133, mse: 21.3470191956, local kl: 0.00119900971185 global kl: 0.0529309287667 valid nll: 2.94664859772, mse: 34.7335319519, local kl: 0.000142284479807 global kl: 0.0778629481792
it: 6000, train nll: 2.81260490417, mse: 18.2238388062, local kl: 0.00230176234618 global kl: 0.0761987417936 valid nll: 2.82145547867, mse: 32.4690513611, local kl: 0.00399597035721 global kl: 0.0761086195707
it: 6050, train nll: 2.97938728333, mse: 18.3174057007, local kl: 0.00126880849712 global kl: 0.0505222156644 valid nll: 2.80367350578, mse: 13.7049684525, local kl: 0.0768940299749 global kl: 0.0392194464803
it: 6100, train nll: 2.79560136795, mse: 12.4480056763, local kl: 0.00197729910724 global kl: 0.0159472022206 valid nll: 2.81419324875, mse: 12.0392131805, local kl: 0.00740273064002 global kl: 0.0573878996074
it: 6150, train nll: 2.78810739517, mse: 25.1553421021, local kl: 0.00664205243811 global kl: 0.0210897661746 valid nll: 2.84878611565, mse: 23.2594337463, local kl: 0.000262305809883 global kl: 0.0923013910651
it: 6200, train nll: 2.87341880798, mse: 16.2186355591, local kl: 0.000470902130473 global kl: 0.120770439506 valid nll: 2.92613077164, mse: 21.1844291687, local kl: 0.00517733860761 global kl: 0.0731789320707
it: 6250, train nll: 2.68941235542, mse: 9.61262321472, local kl: 0.162393420935 global kl: 0.0332526490092 valid nll: 2.80407738686, mse: 6.6322889328, local kl: 0.00351465889253 global kl: 0.058262757957
Saving best model with MSE 6.632289
it: 6300, train nll: 2.69935560226, mse: 14.9480667114, local kl: 0.00164527026936 global kl: 0.0463248565793 valid nll: 2.8085873127, mse: 7.52628755569, local kl: 0.00325667741708 global kl: 0.0613498277962
it: 6350, train nll: 2.78359031677, mse: 8.19171619415, local kl: 0.000960687873885 global kl: 0.0327511020005 valid nll: 2.8371925354, mse: 32.3574066162, local kl: 0.000394859118387 global kl: 0.06416798383
it: 6400, train nll: 2.75625300407, mse: 15.2724151611, local kl: 0.000220929257921 global kl: 0.0446313731372 valid nll: 2.81974959373, mse: 11.3385381699, local kl: 0.00599813647568 global kl: 0.0827974081039
it: 6450, train nll: 2.76294088364, mse: 14.9837970734, local kl: 0.00211084983312 global kl: 0.0434785559773 valid nll: 2.83452272415, mse: 21.8817481995, local kl: 0.0052539203316 global kl: 0.0387694761157
it: 6500, train nll: 2.73363113403, mse: 11.0036411285, local kl: 0.000988200539723 global kl: 0.0291097071022 valid nll: 2.91926240921, mse: 17.5549488068, local kl: 0.0258030630648 global kl: 0.0753561705351
it: 6550, train nll: 2.76539182663, mse: 21.7285060883, local kl: 0.00103681802284 global kl: 0.0494267642498 valid nll: 2.82010364532, mse: 20.0483283997, local kl: 0.00154732377268 global kl: 0.0895797535777
it: 6600, train nll: 2.72827625275, mse: 9.06008911133, local kl: 0.000460368668428 global kl: 0.0477240793407 valid nll: 2.82965016365, mse: 15.336233139, local kl: 0.0536775738001 global kl: 0.0725664198399
it: 6650, train nll: 2.75487160683, mse: 6.32866191864, local kl: 0.00655464455485 global kl: 0.0426371172071 valid nll: 2.81798410416, mse: 17.0837230682, local kl: 0.0349081233144 global kl: 0.0618298761547
it: 6700, train nll: 2.79217720032, mse: 19.7873249054, local kl: 0.00307120522484 global kl: 0.0110433567315 valid nll: 2.81849122047, mse: 17.558927536, local kl: 0.0808154866099 global kl: 0.0337835028768
it: 6750, train nll: 2.81867814064, mse: 11.6734495163, local kl: 0.00269082491286 global kl: 0.046463996172 valid nll: 2.8327190876, mse: 23.2228927612, local kl: 0.00539172952995 global kl: 0.114917695522
it: 6800, train nll: 2.84422683716, mse: 18.952091217, local kl: 0.00313552189618 global kl: 0.167936310172 valid nll: 2.92452430725, mse: 25.261177063, local kl: 0.029863147065 global kl: 0.114605210721
it: 6850, train nll: 2.86677098274, mse: 12.2882537842, local kl: 0.00871941912919 global kl: 0.0276928581297 valid nll: 2.8352189064, mse: 28.1357555389, local kl: 0.00159381108824 global kl: 0.085940875113
it: 6900, train nll: 2.79276371002, mse: 5.70539236069, local kl: 0.00155736913439 global kl: 0.0743656903505 valid nll: 2.93475627899, mse: 19.6905078888, local kl: 0.00176511087921 global kl: 0.0989913195372
it: 6950, train nll: 2.81609869003, mse: 16.2112159729, local kl: 0.000757159665227 global kl: 0.201698705554 valid nll: 2.93222427368, mse: 14.7997951508, local kl: 0.00146736612078 global kl: 0.0876756161451
it: 7000, train nll: 2.91675543785, mse: 17.1415367126, local kl: 0.00108833669219 global kl: 0.0184418484569 valid nll: 2.84083008766, mse: 10.9305763245, local kl: 0.00131658802275 global kl: 0.126639515162
it: 7050, train nll: 2.77847552299, mse: 8.0937871933, local kl: 0.00193248910364 global kl: 0.0234671495855 valid nll: 2.84192061424, mse: 13.4707365036, local kl: 0.00599278882146 global kl: 0.124354325235
it: 7100, train nll: 2.87545704842, mse: 20.7058601379, local kl: 0.00199253461324 global kl: 0.0278400070965 valid nll: 2.94457387924, mse: 20.906671524, local kl: 0.00309689855203 global kl: 0.0681609213352
it: 7150, train nll: 2.81856560707, mse: 14.9233455658, local kl: 0.00218980899081 global kl: 0.0300269387662 valid nll: 2.85391688347, mse: 23.4420032501, local kl: 0.00116756616626 global kl: 0.10280059278
it: 7200, train nll: 2.84785199165, mse: 18.5454540253, local kl: 0.000913111609407 global kl: 0.0607531256974 valid nll: 2.86027550697, mse: 15.0301923752, local kl: 0.000476488785353 global kl: 0.0477636083961
it: 7250, train nll: 2.85916113853, mse: 21.5832271576, local kl: 0.00104763440322 global kl: 0.0526106767356 valid nll: 2.98452830315, mse: 34.9607658386, local kl: 0.00267794332467 global kl: 0.0513248331845
it: 7300, train nll: 2.8571562767, mse: 15.5919733047, local kl: 0.00102413573768 global kl: 0.0715577676892 valid nll: 4.03435230255, mse: 18.379863739, local kl: 0.00560913840309 global kl: 0.120574615896
it: 7350, train nll: 2.77013325691, mse: 4.92543458939, local kl: 0.00166896625888 global kl: 0.0471291206777 valid nll: 2.86306500435, mse: 12.3865261078, local kl: 0.00111133081373 global kl: 0.122145570815
it: 7400, train nll: 2.83342957497, mse: 25.1543750763, local kl: 0.00198926497251 global kl: 0.174277812243 valid nll: 2.92411613464, mse: 15.4928398132, local kl: 0.00691838329658 global kl: 0.0672508701682
it: 7450, train nll: 2.90838527679, mse: 17.301404953, local kl: 0.000778499408625 global kl: 0.087894923985 valid nll: 3.1418607235, mse: 43.9711837769, local kl: 0.00104145240039 global kl: 0.182911619544
it: 7500, train nll: 2.85583353043, mse: 6.86529827118, local kl: 0.000357381475624 global kl: 0.0300987660885 valid nll: 2.8756570816, mse: 19.1919460297, local kl: 0.000367692176951 global kl: 0.0465393587947
it: 7550, train nll: 2.88809561729, mse: 10.6118993759, local kl: 0.000186982448213 global kl: 0.0243381094187 valid nll: 2.88185763359, mse: 17.9987220764, local kl: 0.000266393763013 global kl: 0.0433799177408
it: 7600, train nll: 2.84260106087, mse: 12.9933996201, local kl: 0.0102587658912 global kl: 0.129697471857 valid nll: 2.9127869606, mse: 17.0832576752, local kl: 0.0231297239661 global kl: 0.0823982805014
it: 7650, train nll: 2.93743872643, mse: 11.2691679001, local kl: 0.000697998213582 global kl: 0.130347043276 valid nll: 2.98154425621, mse: 21.2627944946, local kl: 0.105365283787 global kl: 0.199651330709
it: 7700, train nll: 2.85644698143, mse: 14.7865085602, local kl: 0.000714791123755 global kl: 0.0587358362973 valid nll: 2.96379995346, mse: 15.11739254, local kl: 0.000638213241473 global kl: 0.0751625224948
it: 7750, train nll: 2.90949869156, mse: 14.5116691589, local kl: 0.00181035208516 global kl: 0.0177369918674 valid nll: 2.88449883461, mse: 18.4165592194, local kl: 0.00683477800339 global kl: 0.1361143291
it: 7800, train nll: 2.92489433289, mse: 8.76372337341, local kl: 0.00129598996136 global kl: 0.0721649378538 valid nll: 2.97298192978, mse: 12.718996048, local kl: 0.00285088270903 global kl: 0.0707115978003
it: 7850, train nll: 2.87219691277, mse: 10.8134775162, local kl: 0.000154888461111 global kl: 0.0596166849136 valid nll: 2.90465307236, mse: 18.6365814209, local kl: 0.00177571596578 global kl: 0.0757354944944
it: 7900, train nll: 2.8716366291, mse: 6.84365320206, local kl: 0.000294502067845 global kl: 0.0298784915358 valid nll: 2.89459133148, mse: 10.3385686874, local kl: 0.0043379580602 global kl: 0.0891936272383
it: 7950, train nll: 2.82290673256, mse: 9.61950969696, local kl: 0.00110657827463 global kl: 0.0273420121521 valid nll: 2.89581465721, mse: 12.9724750519, local kl: 0.00409218808636 global kl: 0.0248970743269
it: 8000, train nll: 2.92507719994, mse: 11.2989711761, local kl: 0.00254253181629 global kl: 0.0582532510161 valid nll: 2.89576625824, mse: 18.9427070618, local kl: 0.00055822514696 global kl: 0.100237466395
it: 8050, train nll: 2.93101167679, mse: 13.6367502213, local kl: 0.00155271973927 global kl: 0.0352201499045 valid nll: 2.98416113853, mse: 12.5439558029, local kl: 0.00113043212332 global kl: 0.0753442496061
it: 8100, train nll: 2.98620891571, mse: 24.3721218109, local kl: 0.00487171392888 global kl: 0.0257049053907 valid nll: 2.99545192719, mse: 14.0225200653, local kl: 0.000657471711747 global kl: 0.0552011355758
it: 8150, train nll: 2.93265628815, mse: 12.9956703186, local kl: 0.000440688279923 global kl: 0.0226471759379 valid nll: 2.92086553574, mse: 19.0875740051, local kl: 0.000167424208485 global kl: 0.129665151238
it: 8200, train nll: 2.89876770973, mse: 8.86827945709, local kl: 0.00542398169637 global kl: 0.041543006897 valid nll: 2.90234613419, mse: 10.6870002747, local kl: 0.000818692205939 global kl: 0.169620245695
it: 8250, train nll: 2.95095944405, mse: 11.6267967224, local kl: 0.0021087185014 global kl: 0.106647633016 valid nll: 2.99757504463, mse: 19.2928085327, local kl: 0.0132223470137 global kl: 0.118721053004
it: 8300, train nll: 2.82975411415, mse: 10.243024826, local kl: 0.0037503647618 global kl: 0.0406527705491 valid nll: 2.90341520309, mse: 10.21575737, local kl: 0.00224472559057 global kl: 0.158806622028
it: 8350, train nll: 2.86359667778, mse: 25.6806144714, local kl: 0.000667435175274 global kl: 0.0551279857755 valid nll: 2.99583768845, mse: 14.5531997681, local kl: 0.00107875559479 global kl: 0.169570639729
it: 8400, train nll: 3.14110660553, mse: 36.3177566528, local kl: 0.00101124146022 global kl: 0.0499974861741 valid nll: 2.91709041595, mse: 17.9494609833, local kl: 0.0051752794534 global kl: 0.242678195238
it: 8450, train nll: 2.86045384407, mse: 19.3798007965, local kl: 0.00375230843201 global kl: 0.12807957828 valid nll: 3.01350712776, mse: 21.5379962921, local kl: 0.00302310870029 global kl: 0.10569883883
it: 8500, train nll: 2.82951855659, mse: 14.7812681198, local kl: 0.00461621861905 global kl: 0.0375213250518 valid nll: 2.93657708168, mse: 21.5433311462, local kl: 0.11234036833 global kl: 0.116898551583
it: 8550, train nll: 2.90910553932, mse: 13.419921875, local kl: 0.00176299456507 global kl: 0.102993890643 valid nll: 2.92288136482, mse: 10.0578899384, local kl: 0.000593163596932 global kl: 0.0806280225515
it: 8600, train nll: 2.92384791374, mse: 4.95807933807, local kl: 0.00059067318216 global kl: 0.0563807599247 valid nll: 2.94455957413, mse: 25.3780136108, local kl: 0.0016224947758 global kl: 0.109943248332
it: 8650, train nll: 2.8820374012, mse: 9.34777641296, local kl: 0.000756928697228 global kl: 0.0244181193411 valid nll: 2.95444536209, mse: 26.6486759186, local kl: 0.00344894593582 global kl: 0.125179752707
it: 8700, train nll: 2.90816569328, mse: 3.23829007149, local kl: 0.00308323954232 global kl: 0.2223726511 valid nll: 3.00672101974, mse: 16.5672225952, local kl: 0.0725807920098 global kl: 0.283537685871
it: 8750, train nll: 3.150359869, mse: 25.2703323364, local kl: 0.000872881384566 global kl: 0.0534724406898 valid nll: 3.021961689, mse: 21.4738807678, local kl: 0.00138256093487 global kl: 0.190621927381
it: 8800, train nll: 2.89375567436, mse: 8.62604522705, local kl: 0.00407453719527 global kl: 0.0794796571136 valid nll: 2.93699264526, mse: 15.5390892029, local kl: 0.0100700473413 global kl: 0.249614477158
it: 8850, train nll: 2.92344880104, mse: 8.80011558533, local kl: 0.00189246062655 global kl: 0.0536990240216 valid nll: 2.95056629181, mse: 13.9232521057, local kl: 0.00835805758834 global kl: 0.127717211843
it: 8900, train nll: 2.8635969162, mse: 7.17995977402, local kl: 0.00165800307877 global kl: 0.0344936177135 valid nll: 3.02617573738, mse: 23.3184013367, local kl: 0.0112868947908 global kl: 0.150423839688
it: 8950, train nll: 4.94181346893, mse: 27.2403736115, local kl: 0.00404550367966 global kl: 0.0691023916006 valid nll: 4.99633455276, mse: 19.1184959412, local kl: 0.00191490468569 global kl: 0.144712999463
it: 9000, train nll: 3.01739144325, mse: 24.9424533844, local kl: 0.00246736058034 global kl: 0.0800837799907 valid nll: 3.03729104996, mse: 26.4122409821, local kl: 0.00576774589717 global kl: 0.307545900345
it: 9050, train nll: 2.91635227203, mse: 18.5760860443, local kl: 0.00199292786419 global kl: 0.0415155589581 valid nll: 2.9881913662, mse: 29.5068283081, local kl: 0.00161291321274 global kl: 0.350832045078
it: 9100, train nll: 2.84134578705, mse: 5.55648374557, local kl: 0.000197938497877 global kl: 0.0332720838487 valid nll: 2.96108818054, mse: 19.5304641724, local kl: 0.000757570087444 global kl: 0.142863035202
it: 9150, train nll: 3.15748977661, mse: 27.6981163025, local kl: 0.00502405920997 global kl: 0.197786480188 valid nll: 2.99482941628, mse: 17.2269153595, local kl: 0.00394222326577 global kl: 0.191272467375
it: 9200, train nll: 3.20471191406, mse: 33.7951393127, local kl: 0.00140073022339 global kl: 0.0753173679113 valid nll: 3.10072278976, mse: 29.0391235352, local kl: 0.00133845896926 global kl: 0.0854525342584
it: 9250, train nll: 2.94998979568, mse: 7.67171382904, local kl: 0.000233292928897 global kl: 0.0400648899376 valid nll: 3.04610800743, mse: 13.2476959229, local kl: 0.00392587017268 global kl: 0.104500077665
it: 9300, train nll: 2.97900152206, mse: 28.4338111877, local kl: 0.000433157547377 global kl: 0.0903421193361 valid nll: 2.959815979, mse: 17.6085147858, local kl: 0.00806088000536 global kl: 0.0616888590157
it: 9350, train nll: 2.95973396301, mse: 22.5097465515, local kl: 0.00386127410457 global kl: 0.0515857636929 valid nll: 3.03640031815, mse: 19.0836334229, local kl: 0.00158660090528 global kl: 0.0794937238097
it: 9400, train nll: 2.92537140846, mse: 7.00175476074, local kl: 0.00321875279769 global kl: 0.0566972978413 valid nll: 4.81221103668, mse: 18.2943477631, local kl: 0.00123093882576 global kl: 0.0507823117077
it: 9450, train nll: 2.97328567505, mse: 9.4620923996, local kl: 0.000263881025603 global kl: 0.0297016445547 valid nll: 2.99260878563, mse: 20.7616386414, local kl: 0.00147782545537 global kl: 0.0709002241492
it: 9500, train nll: 2.95713663101, mse: 17.5633831024, local kl: 0.000339032470947 global kl: 0.0867501571774 valid nll: 2.96598720551, mse: 9.01369762421, local kl: 0.000389911117963 global kl: 0.046994227916
it: 9550, train nll: 2.91861486435, mse: 17.0716819763, local kl: 0.000385452265618 global kl: 0.0395494475961 valid nll: 3.18196034431, mse: 23.6904468536, local kl: 0.0070534539409 global kl: 0.118608869612
it: 9600, train nll: 2.87248396873, mse: 18.2736206055, local kl: 0.00144983804785 global kl: 0.0358879491687 valid nll: 3.02214741707, mse: 23.4814624786, local kl: 0.00294861826114 global kl: 0.0702672153711
it: 9650, train nll: 2.92430734634, mse: 14.1031942368, local kl: 0.000758342735935 global kl: 0.028947468847 valid nll: 3.06124281883, mse: 28.1839179993, local kl: 0.0156695153564 global kl: 0.136490106583
it: 9700, train nll: 2.97900342941, mse: 30.5376052856, local kl: 0.00218865298666 global kl: 0.0445156991482 valid nll: 3.00968980789, mse: 27.0379657745, local kl: 0.00361331249587 global kl: 0.0799587219954
it: 9750, train nll: 2.92439031601, mse: 9.23939800262, local kl: 0.00151974055916 global kl: 0.0580890290439 valid nll: 3.04930734634, mse: 18.5313644409, local kl: 0.00188106391579 global kl: 0.185203149915
it: 9800, train nll: 2.8667371273, mse: 5.03585386276, local kl: 5.51202720089e-05 global kl: 0.034987129271 valid nll: 3.0304453373, mse: 19.5715179443, local kl: 0.00131210708059 global kl: 0.161949664354
it: 9850, train nll: 2.99008321762, mse: 21.0021495819, local kl: 0.00186575925909 global kl: 0.0468719825149 valid nll: 3.01696920395, mse: 21.1073112488, local kl: 0.00342134828679 global kl: 0.0498446971178
it: 9900, train nll: 2.98803544044, mse: 26.2247009277, local kl: 0.000654980132822 global kl: 0.081779293716 valid nll: 3.02636265755, mse: 13.8944711685, local kl: 0.000963420956396 global kl: 0.0334536768496
it: 9950, train nll: 2.89250612259, mse: 9.36975002289, local kl: 0.000381960911909 global kl: 0.0176936089993 valid nll: 3.00395154953, mse: 15.7712621689, local kl: 0.0016575162299 global kl: 0.0290727950633

Prior Predictive + GP


In [0]:
# Settings specific to the "Prior Predictive + GP" run.
uncertainty_type = 'attentive_gp'
local_variational = False

# Checkpoint file for this run, placed under the shared `savedir`.
save_path = os.path.join(savedir, 'best_prior_gp_mse_unclipped.ckpt')

# Model hyperparameters. Apart from `uncertainty_type` and
# `local_variational` (set above), every value comes from the shared settings
# cell near the top of the notebook.
model_hparams = tf.contrib.training.HParams(
    activation=tf.nn.relu,
    output_activation=tf.nn.relu,
    x_encoder_sizes=x_encoder_sizes,
    x_y_encoder_sizes=x_y_encoder_sizes,
    global_latent_net_sizes=global_latent_net_sizes,
    local_latent_net_sizes=local_latent_net_sizes,
    heteroskedastic_net_sizes=heteroskedastic_net_sizes,
    uncertainty_type=uncertainty_type,
    att_type=att_type,
    att_heads=att_heads,
    mean_att_type=mean_att_type,
    scale_att_type_1=scale_att_type_1,
    scale_att_type_2=scale_att_type_2,
    data_uncertainty=data_uncertainty,
    local_variational=local_variational,
)

# Training hyperparameters for this run.
# NOTE(review): `max_grad_norm` presumably sets a gradient-norm clipping
# threshold inside `train`; confirm against the pretrain module.
training_hparams = tf.contrib.training.HParams(
    lr=0.01,
    optimizer=tf.train.RMSPropOptimizer,
    num_iterations=10000,
    batch_size=10,
    num_context=num_context,
    num_target=num_target,
    print_every=50,
    save_path=save_path,
    max_grad_norm=1000.0,
)

In [0]:
# Run training with the data, model and training hyperparameters defined in
# the cells above; the progress log appears as this cell's output.
train(data_hparams, model_hparams, training_hparams)


it: 0, train nll: 133.707901001, mse: 265.018371582, local kl: 0.0 global kl: 4.86915232614e-05 valid nll: 155.182296753, mse: 307.859466553, local kl: 0.0 global kl: 6.83785183355e-05
Saving best model with MSE 307.85947
it: 50, train nll: 30.6259536743, mse: 47.5792503357, local kl: 0.0 global kl: 0.00189643469639 valid nll: 44.8390197754, mse: 69.9601364136, local kl: 0.0 global kl: 0.00633073132485
Saving best model with MSE 69.96014
it: 100, train nll: 30.8255844116, mse: 46.4106445312, local kl: 0.0 global kl: 0.000299640523735 valid nll: 50.1438789368, mse: 65.9298400879, local kl: 0.0 global kl: 0.00054213067051
Saving best model with MSE 65.92984
it: 150, train nll: 20.5862350464, mse: 29.4456310272, local kl: 0.0 global kl: 2.68571693596e-05 valid nll: 45.0638008118, mse: 60.0532035828, local kl: 0.0 global kl: 8.25333918328e-05
Saving best model with MSE 60.053204
it: 200, train nll: 23.5716304779, mse: 32.0456619263, local kl: 0.0 global kl: 3.63523827218e-06 valid nll: 45.6291809082, mse: 58.6537971497, local kl: 0.0 global kl: 5.65321306567e-06
Saving best model with MSE 58.653797
it: 250, train nll: 18.016166687, mse: 25.7392730713, local kl: 0.0 global kl: 4.58404574601e-05 valid nll: 40.7628974915, mse: 55.1388969421, local kl: 0.0 global kl: 3.49385882146e-05
Saving best model with MSE 55.138897
it: 300, train nll: 41.2081756592, mse: 38.6637573242, local kl: 0.0 global kl: 8.45341310196e-06 valid nll: 35.4655838013, mse: 49.5577850342, local kl: 0.0 global kl: 5.52751225769e-06
Saving best model with MSE 49.557785
it: 350, train nll: 8.72609710693, mse: 12.6746664047, local kl: 0.0 global kl: 2.20449237531e-06 valid nll: 38.7598762512, mse: 47.9833145142, local kl: 0.0 global kl: 1.8187099613e-06
Saving best model with MSE 47.983315
it: 400, train nll: 24.7499294281, mse: 29.2195663452, local kl: 0.0 global kl: 5.32537853815e-07 valid nll: 35.0146522522, mse: 47.5217514038, local kl: 0.0 global kl: 8.35251000808e-07
Saving best model with MSE 47.52175
it: 450, train nll: 12.9923734665, mse: 20.5920257568, local kl: 0.0 global kl: 4.39723862655e-07 valid nll: 33.1168632507, mse: 46.9113998413, local kl: 0.0 global kl: 4.01970282837e-07
Saving best model with MSE 46.9114
it: 500, train nll: 15.672208786, mse: 23.3050708771, local kl: 0.0 global kl: 2.01237824626e-07 valid nll: 31.0984916687, mse: 43.5356864929, local kl: 0.0 global kl: 1.99262402134e-07
Saving best model with MSE 43.535686
it: 550, train nll: 13.4539604187, mse: 19.9612503052, local kl: 0.0 global kl: 1.89163685604e-07 valid nll: 30.3498954773, mse: 41.2062225342, local kl: 0.0 global kl: 2.0106945442e-07
Saving best model with MSE 41.206223
it: 600, train nll: 13.4120025635, mse: 19.9294090271, local kl: 0.0 global kl: 1.51686066374e-07 valid nll: 32.4266433716, mse: 45.8623199463, local kl: 0.0 global kl: 1.76191491619e-07
it: 650, train nll: 19.6834888458, mse: 25.2140617371, local kl: 0.0 global kl: 1.03306831534e-07 valid nll: 34.755355835, mse: 41.76039505, local kl: 0.0 global kl: 1.3436209656e-07
it: 700, train nll: 18.2550258636, mse: 21.0077037811, local kl: 0.0 global kl: 7.57425624442e-08 valid nll: 31.4194564819, mse: 45.1046066284, local kl: 0.0 global kl: 1.56485640446e-07
it: 750, train nll: 14.6508827209, mse: 22.1003475189, local kl: 0.0 global kl: 7.55245395112e-08 valid nll: 28.1495838165, mse: 39.8295249939, local kl: 0.0 global kl: 9.28334102923e-08
Saving best model with MSE 39.829525
it: 800, train nll: 19.4836196899, mse: 16.5695991516, local kl: 0.0 global kl: 8.55725730275e-08 valid nll: 31.8805885315, mse: 38.5065536499, local kl: 0.0 global kl: 8.18531802338e-08
Saving best model with MSE 38.506554
it: 850, train nll: 15.0241937637, mse: 22.7491073608, local kl: 0.0 global kl: 9.41051112591e-08 valid nll: 31.928817749, mse: 38.8029556274, local kl: 0.0 global kl: 1.04106206322e-07
it: 900, train nll: 30.8498516083, mse: 23.4114131927, local kl: 0.0 global kl: 8.65372840053e-08 valid nll: 37.7762451172, mse: 47.0648994446, local kl: 0.0 global kl: 8.30436874821e-08
it: 950, train nll: 9.85110855103, mse: 12.2681093216, local kl: 0.0 global kl: 7.14673191737e-08 valid nll: 34.3194580078, mse: 37.842086792, local kl: 0.0 global kl: 7.06960676666e-08
Saving best model with MSE 37.842087
it: 1000, train nll: 10.2736816406, mse: 14.7614088058, local kl: 0.0 global kl: 6.40388151396e-08 valid nll: 34.2499084473, mse: 43.8121871948, local kl: 0.0 global kl: 7.91856749061e-08
it: 1050, train nll: 18.1204376221, mse: 25.7743835449, local kl: 0.0 global kl: 5.29122345938e-08 valid nll: 33.0834236145, mse: 41.0989227295, local kl: 0.0 global kl: 6.06423355975e-08
it: 1100, train nll: 10.3615226746, mse: 16.1424732208, local kl: 0.0 global kl: 4.07174489681e-08 valid nll: 39.7120857239, mse: 44.8182411194, local kl: 0.0 global kl: 6.7565501638e-08
it: 1150, train nll: 10.7096090317, mse: 13.5021820068, local kl: 0.0 global kl: 3.05241236731e-08 valid nll: 27.2186107635, mse: 32.1571083069, local kl: 0.0 global kl: 5.27610843903e-08
Saving best model with MSE 32.15711
it: 1200, train nll: 8.67172622681, mse: 13.2727470398, local kl: 0.0 global kl: 2.72822262559e-08 valid nll: 30.629070282, mse: 36.4547157288, local kl: 0.0 global kl: 4.23661390414e-08
it: 1250, train nll: 13.5819349289, mse: 17.9439239502, local kl: 0.0 global kl: 4.16720418173e-08 valid nll: 31.1222343445, mse: 34.1750144958, local kl: 0.0 global kl: 4.17115373352e-08
it: 1300, train nll: 9.83561134338, mse: 14.1067743301, local kl: 0.0 global kl: 2.93339788016e-08 valid nll: 35.0054702759, mse: 44.1275367737, local kl: 0.0 global kl: 4.13156122647e-08
it: 1350, train nll: 11.993812561, mse: 14.6567287445, local kl: 0.0 global kl: 3.5957054223e-08 valid nll: 30.3739891052, mse: 35.7063522339, local kl: 0.0 global kl: 3.71942689981e-08
it: 1400, train nll: 33.9864387512, mse: 29.0754814148, local kl: 0.0 global kl: 6.0547122871e-08 valid nll: 32.0140151978, mse: 42.7676620483, local kl: 0.0 global kl: 5.0724416667e-08
it: 1450, train nll: 9.88674736023, mse: 13.3375654221, local kl: 0.0 global kl: 4.76728807541e-08 valid nll: 31.813659668, mse: 34.6614074707, local kl: 0.0 global kl: 5.11119822022e-08
it: 1500, train nll: 4.05601167679, mse: 5.15707540512, local kl: 0.0 global kl: 2.05161381217e-08 valid nll: 28.7104663849, mse: 34.9150886536, local kl: 0.0 global kl: 2.74185119054e-08
it: 1550, train nll: 15.9245119095, mse: 16.4678268433, local kl: 0.0 global kl: 2.96152791179e-08 valid nll: 26.6316242218, mse: 33.5362434387, local kl: 0.0 global kl: 3.84528675568e-08
it: 1600, train nll: 27.7523365021, mse: 23.1754779816, local kl: 0.0 global kl: 3.10772705348e-08 valid nll: 36.3902282715, mse: 33.663646698, local kl: 0.0 global kl: 4.29101056909e-08
it: 1650, train nll: 11.1874952316, mse: 13.476102829, local kl: 0.0 global kl: 2.75456351062e-08 valid nll: 34.1864776611, mse: 39.4754638672, local kl: 0.0 global kl: 2.88325825437e-08
it: 1700, train nll: 14.4262914658, mse: 19.3412570953, local kl: 0.0 global kl: 2.93791888595e-08 valid nll: 31.0226325989, mse: 35.7445831299, local kl: 0.0 global kl: 2.98971265522e-08
it: 1750, train nll: 8.05744743347, mse: 11.5430498123, local kl: 0.0 global kl: 2.68915556489e-08 valid nll: 41.5743408203, mse: 41.5426445007, local kl: 0.0 global kl: 3.15435997322e-08
it: 1800, train nll: 13.4556875229, mse: 17.5921764374, local kl: 0.0 global kl: 1.77207049035e-08 valid nll: 34.6112747192, mse: 42.4671058655, local kl: 0.0 global kl: 2.58084043026e-08
it: 1850, train nll: 24.8260707855, mse: 15.7062101364, local kl: 0.0 global kl: 3.25479092567e-08 valid nll: 32.8421173096, mse: 32.5204925537, local kl: 0.0 global kl: 3.27724976046e-08
it: 1900, train nll: 10.4088487625, mse: 10.2746162415, local kl: 0.0 global kl: 2.22410072581e-08 valid nll: 34.2458877563, mse: 33.6699180603, local kl: 0.0 global kl: 2.4964521117e-08
it: 1950, train nll: 8.21937274933, mse: 11.3279161453, local kl: 0.0 global kl: 2.34721344583e-08 valid nll: 31.6118659973, mse: 34.645690918, local kl: 0.0 global kl: 2.20800213668e-08
it: 2000, train nll: 9.8427734375, mse: 13.6427879333, local kl: 0.0 global kl: 1.62503930312e-08 valid nll: 35.9252357483, mse: 34.856842041, local kl: 0.0 global kl: 2.40339268487e-08
it: 2050, train nll: 5.18579530716, mse: 7.26883029938, local kl: 0.0 global kl: 1.81266592847e-08 valid nll: 32.2024765015, mse: 33.3473014832, local kl: 0.0 global kl: 2.1174399123e-08
it: 2100, train nll: 6.55301904678, mse: 8.11044979095, local kl: 0.0 global kl: 3.91561556512e-08 valid nll: 32.6022491455, mse: 41.010925293, local kl: 0.0 global kl: 2.43687203749e-08
it: 2150, train nll: 40.2286109924, mse: 22.7188873291, local kl: 0.0 global kl: 1.72996585945e-08 valid nll: 34.0046310425, mse: 32.7598266602, local kl: 0.0 global kl: 2.3685595707e-08
it: 2200, train nll: 16.4540958405, mse: 13.8568353653, local kl: 0.0 global kl: 3.17386437132e-08 valid nll: 30.2243709564, mse: 28.8910865784, local kl: 0.0 global kl: 2.49574156896e-08
Saving best model with MSE 28.891087
it: 2250, train nll: 11.0381546021, mse: 14.3690872192, local kl: 0.0 global kl: 2.22500204927e-08 valid nll: 25.8503894806, mse: 28.8446540833, local kl: 0.0 global kl: 2.52671021883e-08
Saving best model with MSE 28.844654
it: 2300, train nll: 12.9101238251, mse: 13.5522356033, local kl: 0.0 global kl: 1.46129508494e-08 valid nll: 24.8513870239, mse: 25.3401966095, local kl: 0.0 global kl: 2.28482157638e-08
Saving best model with MSE 25.340197
it: 2350, train nll: 8.23562145233, mse: 11.2613811493, local kl: 0.0 global kl: 1.49314711706e-08 valid nll: 39.0460166931, mse: 41.2587623596, local kl: 0.0 global kl: 2.24169944829e-08
it: 2400, train nll: 14.6814393997, mse: 10.1427602768, local kl: 0.0 global kl: 1.23632828419e-08 valid nll: 27.7793254852, mse: 29.1750526428, local kl: 0.0 global kl: 1.42595295571e-08
it: 2450, train nll: 9.72875499725, mse: 10.1328077316, local kl: 0.0 global kl: 1.66424349857e-08 valid nll: 30.6333179474, mse: 38.6847801208, local kl: 0.0 global kl: 1.7072379066e-08
it: 2500, train nll: 23.6245937347, mse: 17.7804374695, local kl: 0.0 global kl: 1.62883075916e-08 valid nll: 36.429901123, mse: 34.0731582642, local kl: 0.0 global kl: 2.12359250185e-08
it: 2550, train nll: 7.50979757309, mse: 8.72747516632, local kl: 0.0 global kl: 1.41141782706e-08 valid nll: 34.7818374634, mse: 37.7508735657, local kl: 0.0 global kl: 1.63045754675e-08
it: 2600, train nll: 23.8032817841, mse: 23.3341827393, local kl: 0.0 global kl: 1.79453785165e-08 valid nll: 24.590171814, mse: 26.9764976501, local kl: 0.0 global kl: 1.78196071232e-08
it: 2650, train nll: 81.9334259033, mse: 25.0416793823, local kl: 0.0 global kl: 1.6530764313e-08 valid nll: 47.4432525635, mse: 35.3413391113, local kl: 0.0 global kl: 1.89324627087e-08
it: 2700, train nll: 6.93671607971, mse: 8.64537143707, local kl: 0.0 global kl: 1.4619141453e-08 valid nll: 26.3637466431, mse: 31.0046691895, local kl: 0.0 global kl: 1.61967861345e-08
it: 2750, train nll: 5.06015396118, mse: 6.19331216812, local kl: 0.0 global kl: 2.0591073735e-08 valid nll: 41.1184692383, mse: 33.3299827576, local kl: 0.0 global kl: 1.76310273048e-08
it: 2800, train nll: 16.4461135864, mse: 20.3944664001, local kl: 0.0 global kl: 1.0692820851e-08 valid nll: 33.5689544678, mse: 32.5938491821, local kl: 0.0 global kl: 1.75019660986e-08
it: 2850, train nll: 16.8449745178, mse: 21.153137207, local kl: 0.0 global kl: 1.96463787461e-08 valid nll: 51.9497413635, mse: 42.3756904602, local kl: 0.0 global kl: 1.79487749108e-08
it: 2900, train nll: 18.4105472565, mse: 20.242893219, local kl: 0.0 global kl: 1.32798332331e-08 valid nll: 33.3428039551, mse: 39.6739959717, local kl: 0.0 global kl: 1.27731230037e-08
it: 2950, train nll: 30.2812347412, mse: 34.3992958069, local kl: 0.0 global kl: 1.70049645476e-08 valid nll: 60.046875, mse: 36.2716064453, local kl: 0.0 global kl: 1.7689913534e-08
it: 3000, train nll: 14.4753379822, mse: 17.461774826, local kl: 0.0 global kl: 1.86726119011e-08 valid nll: 40.4033355713, mse: 39.6373710632, local kl: 0.0 global kl: 1.54041011058e-08
it: 3050, train nll: 9.53898715973, mse: 12.4021253586, local kl: 0.0 global kl: 1.09327107367e-08 valid nll: 43.3763198853, mse: 34.9481315613, local kl: 0.0 global kl: 8.60110738188e-09
it: 3100, train nll: 14.9289073944, mse: 19.2128448486, local kl: 0.0 global kl: 2.32856969262e-08 valid nll: 39.5546875, mse: 39.9126243591, local kl: 0.0 global kl: 1.97213303466e-08
it: 3150, train nll: 10.0171003342, mse: 13.4681997299, local kl: 0.0 global kl: 1.75346226428e-08 valid nll: 35.0317001343, mse: 38.147064209, local kl: 0.0 global kl: 1.73075864751e-08
it: 3200, train nll: 4.8725938797, mse: 6.32702445984, local kl: 0.0 global kl: 1.02044195316e-08 valid nll: 36.123966217, mse: 32.2299957275, local kl: 0.0 global kl: 1.84580759566e-08
it: 3250, train nll: 14.2127599716, mse: 9.33010959625, local kl: 0.0 global kl: 1.40318832109e-08 valid nll: 35.5183868408, mse: 29.5917453766, local kl: 0.0 global kl: 1.18230163437e-08
it: 3300, train nll: 22.9382133484, mse: 22.6170501709, local kl: 0.0 global kl: 1.0328006006e-08 valid nll: 34.2001419067, mse: 32.8183937073, local kl: 0.0 global kl: 8.1155535625e-09
it: 3350, train nll: 10.3433094025, mse: 13.1893854141, local kl: 0.0 global kl: 1.86387367762e-08 valid nll: 129.866592407, mse: 38.0105133057, local kl: 0.0 global kl: 2.35881749688e-08
it: 3400, train nll: 9.33993625641, mse: 10.7642765045, local kl: 0.0 global kl: 1.25302506149e-08 valid nll: 56.1065788269, mse: 37.2219276428, local kl: 0.0 global kl: 1.18553087347e-08
it: 3450, train nll: 9.54908180237, mse: 12.5324211121, local kl: 0.0 global kl: 7.48710959897e-09 valid nll: 61.3345413208, mse: 39.4972114563, local kl: 0.0 global kl: 1.30268178467e-08
it: 3500, train nll: 10.4155483246, mse: 15.522482872, local kl: 0.0 global kl: 1.01748138803e-08 valid nll: 48.3018531799, mse: 31.3606777191, local kl: 0.0 global kl: 9.56178247691e-09
it: 3550, train nll: 9.03656959534, mse: 12.5879077911, local kl: 0.0 global kl: 7.77755371217e-09 valid nll: 40.2242698669, mse: 34.8332328796, local kl: 0.0 global kl: 1.00232666611e-08
it: 3600, train nll: 12.4592647552, mse: 19.3305835724, local kl: 0.0 global kl: 1.37288029833e-08 valid nll: 51.1787757874, mse: 38.2089729309, local kl: 0.0 global kl: 1.40781200031e-08
it: 3650, train nll: 27.0539169312, mse: 19.4224128723, local kl: 0.0 global kl: 1.21599912362e-08 valid nll: 41.6930274963, mse: 37.0232849121, local kl: 0.0 global kl: 1.0540498252e-08
it: 3700, train nll: 11.1929960251, mse: 12.1724071503, local kl: 0.0 global kl: 1.2998768284e-08 valid nll: 44.6670684814, mse: 34.6347579956, local kl: 0.0 global kl: 1.1739611061e-08
it: 3750, train nll: 36.3348197937, mse: 9.48810768127, local kl: 0.0 global kl: 7.16623649311e-09 valid nll: 28.2574577332, mse: 25.6822376251, local kl: 0.0 global kl: 6.96501434305e-09
it: 3800, train nll: 29.4124622345, mse: 22.093328476, local kl: 0.0 global kl: 1.07217648093e-08 valid nll: 49.7998046875, mse: 31.5421123505, local kl: 0.0 global kl: 1.13708491511e-08
it: 3850, train nll: 5.64269590378, mse: 6.97565507889, local kl: 0.0 global kl: 1.27954491447e-08 valid nll: 53.1980056763, mse: 31.5298519135, local kl: 0.0 global kl: 1.06314477222e-08
it: 3900, train nll: 15.177564621, mse: 17.2344284058, local kl: 0.0 global kl: 1.46205163531e-08 valid nll: 34.8319015503, mse: 28.9723968506, local kl: 0.0 global kl: 1.85472810443e-08
it: 3950, train nll: 8.01558589935, mse: 9.83223819733, local kl: 0.0 global kl: 1.18909939673e-08 valid nll: 51.327381134, mse: 28.7799358368, local kl: 0.0 global kl: 1.20876135767e-08
it: 4000, train nll: 4.11389780045, mse: 5.46667289734, local kl: 0.0 global kl: 6.46786935121e-09 valid nll: 42.5624275208, mse: 31.1987247467, local kl: 0.0 global kl: 8.50684767073e-09
it: 4050, train nll: 18.2897491455, mse: 23.2408809662, local kl: 0.0 global kl: 8.15043676994e-09 valid nll: 63.0100708008, mse: 28.6555404663, local kl: 0.0 global kl: 1.07173745434e-08
it: 4100, train nll: 34.6209373474, mse: 25.6527175903, local kl: 0.0 global kl: 8.59945803455e-09 valid nll: 46.1460227966, mse: 32.5124206543, local kl: 0.0 global kl: 6.96260382682e-09
it: 4150, train nll: 3.36250209808, mse: 3.54414272308, local kl: 0.0 global kl: 7.32369676015e-09 valid nll: 39.6586074829, mse: 29.6904850006, local kl: 0.0 global kl: 7.01900315647e-09
it: 4200, train nll: 20.4037494659, mse: 17.3960952759, local kl: 0.0 global kl: 8.23753776302e-09 valid nll: 52.4722442627, mse: 30.8331451416, local kl: 0.0 global kl: 9.11973163653e-09
it: 4250, train nll: 14.3954744339, mse: 16.8423309326, local kl: 0.0 global kl: 6.68639810186e-09 valid nll: 36.7682647705, mse: 30.1677036285, local kl: 0.0 global kl: 7.52791784464e-09
it: 4300, train nll: 8.25340080261, mse: 12.1669483185, local kl: 0.0 global kl: 8.63356408587e-09 valid nll: 58.0368461609, mse: 36.1484336853, local kl: 0.0 global kl: 7.38528305178e-09
it: 4350, train nll: 31.411108017, mse: 15.5078353882, local kl: 0.0 global kl: 7.23326110119e-09 valid nll: 43.2302894592, mse: 33.3290557861, local kl: 0.0 global kl: 6.07496897231e-09
it: 4400, train nll: 64.1333465576, mse: 18.8841323853, local kl: 0.0 global kl: 1.04560937686e-08 valid nll: 44.1144981384, mse: 39.9702606201, local kl: 0.0 global kl: 8.98513441427e-09
it: 4450, train nll: 28.7255001068, mse: 31.3911113739, local kl: 0.0 global kl: 5.55726575868e-09 valid nll: 26.9469909668, mse: 26.9031238556, local kl: 0.0 global kl: 8.22283219293e-09
it: 4500, train nll: 8.67902946472, mse: 9.86683273315, local kl: 0.0 global kl: 1.19347038918e-08 valid nll: 32.4253082275, mse: 27.1181983948, local kl: 0.0 global kl: 6.55394449822e-09
it: 4550, train nll: 12.7058134079, mse: 15.8891563416, local kl: 0.0 global kl: 5.71682168271e-09 valid nll: 33.503288269, mse: 29.7857723236, local kl: 0.0 global kl: 6.41490771613e-09
it: 4600, train nll: 4.4546456337, mse: 5.76465272903, local kl: 0.0 global kl: 7.51235518237e-09 valid nll: 40.6656036377, mse: 31.5246372223, local kl: 0.0 global kl: 7.11635816941e-09
it: 4650, train nll: 12.6505184174, mse: 17.4095115662, local kl: 0.0 global kl: 8.76108607883e-09 valid nll: 44.3150787354, mse: 38.3963279724, local kl: 0.0 global kl: 7.90025822539e-09
it: 4700, train nll: 31.1618442535, mse: 20.6423873901, local kl: 0.0 global kl: 7.60402762978e-09 valid nll: 41.5319480896, mse: 27.2414188385, local kl: 0.0 global kl: 6.30034913129e-09
it: 4750, train nll: 9.35857963562, mse: 12.9515218735, local kl: 0.0 global kl: 5.68303715198e-09 valid nll: 35.4557609558, mse: 27.3513221741, local kl: 0.0 global kl: 5.96785421081e-09
it: 4800, train nll: 12.265255928, mse: 11.4960985184, local kl: 0.0 global kl: 6.91733692548e-09 valid nll: 34.386806488, mse: 25.3666229248, local kl: 0.0 global kl: 5.10902831152e-09
it: 4850, train nll: 14.539106369, mse: 16.1018981934, local kl: 0.0 global kl: 4.39793179652e-09 valid nll: 27.822052002, mse: 30.8235778809, local kl: 0.0 global kl: 4.92143659159e-09
it: 4900, train nll: 16.9510726929, mse: 17.9136772156, local kl: 0.0 global kl: 5.88937343338e-09 valid nll: 42.0698509216, mse: 30.6344928741, local kl: 0.0 global kl: 5.56633716897e-09
it: 4950, train nll: 18.5551738739, mse: 23.1467914581, local kl: 0.0 global kl: 3.4627611889e-09 valid nll: 41.4296264648, mse: 39.1106491089, local kl: 0.0 global kl: 6.10888406527e-09
it: 5000, train nll: 19.6543197632, mse: 10.6738901138, local kl: 0.0 global kl: 4.63834703979e-09 valid nll: 47.5431556702, mse: 28.4380931854, local kl: 0.0 global kl: 4.61287941178e-09
it: 5050, train nll: 10.2509527206, mse: 14.5198411942, local kl: 0.0 global kl: 6.97928381754e-09 valid nll: 46.8182106018, mse: 35.7438659668, local kl: 0.0 global kl: 6.21306650572e-09
it: 5100, train nll: 8.76458072662, mse: 10.9291152954, local kl: 0.0 global kl: 4.06342204329e-09 valid nll: 49.8320732117, mse: 31.2230415344, local kl: 0.0 global kl: 5.01875696557e-09
it: 5150, train nll: 44.1456031799, mse: 27.8630332947, local kl: 0.0 global kl: 4.68351135652e-09 valid nll: 43.1842193604, mse: 33.4552078247, local kl: 0.0 global kl: 7.22676585241e-09
it: 5200, train nll: 9.29118347168, mse: 8.88558292389, local kl: 0.0 global kl: 4.10131306694e-09 valid nll: 28.697637558, mse: 25.6173381805, local kl: 0.0 global kl: 4.39854108691e-09
it: 5250, train nll: 9.84863185883, mse: 12.7890701294, local kl: 0.0 global kl: 4.09704714599e-09 valid nll: 39.6927490234, mse: 28.960319519, local kl: 0.0 global kl: 4.28176649692e-09
it: 5300, train nll: 15.531826973, mse: 13.9144983292, local kl: 0.0 global kl: 2.72348898989e-09 valid nll: 30.2736225128, mse: 26.4474411011, local kl: 0.0 global kl: 3.82968146084e-09
it: 5350, train nll: 6.6112909317, mse: 9.05863571167, local kl: 0.0 global kl: 3.53318374557e-09 valid nll: 48.0161018372, mse: 33.5932044983, local kl: 0.0 global kl: 4.21728652e-09
it: 5400, train nll: 23.5656223297, mse: 22.2717380524, local kl: 0.0 global kl: 3.93853394343e-09 valid nll: 27.7980384827, mse: 24.2045440674, local kl: 0.0 global kl: 3.74656838886e-09
Saving best model with MSE 24.204544
it: 5450, train nll: 2.92077732086, mse: 3.31508779526, local kl: 0.0 global kl: 3.03236413757e-09 valid nll: 60.0353279114, mse: 30.205499649, local kl: 0.0 global kl: 4.11445233439e-09
it: 5500, train nll: 3.14115977287, mse: 3.88418364525, local kl: 0.0 global kl: 5.43142864018e-09 valid nll: 39.5337333679, mse: 28.999546051, local kl: 0.0 global kl: 5.98093086168e-09
it: 5550, train nll: 5.85880184174, mse: 7.69276714325, local kl: 0.0 global kl: 4.8399249053e-09 valid nll: 37.5405082703, mse: 37.4496307373, local kl: 0.0 global kl: 4.36988667474e-09
it: 5600, train nll: 13.3457012177, mse: 17.2879314423, local kl: 0.0 global kl: 6.19115425593e-09 valid nll: 42.5615653992, mse: 32.9732017517, local kl: 0.0 global kl: 5.61886626116e-09
it: 5650, train nll: 8.92033481598, mse: 9.97661590576, local kl: 0.0 global kl: 4.63074334434e-09 valid nll: 60.8525390625, mse: 33.0636444092, local kl: 0.0 global kl: 3.9653405004e-09
it: 5700, train nll: 9.55978298187, mse: 11.7281951904, local kl: 0.0 global kl: 3.80559139757e-09 valid nll: 38.8418807983, mse: 29.0648040771, local kl: 0.0 global kl: 5.67888491787e-09
it: 5750, train nll: 2.3836042881, mse: 2.16212582588, local kl: 0.0 global kl: 3.00318125923e-09 valid nll: 32.0130500793, mse: 31.638835907, local kl: 0.0 global kl: 4.52866943945e-09
it: 5800, train nll: 7.03562736511, mse: 9.11522674561, local kl: 0.0 global kl: 5.11145836768e-09 valid nll: 51.2593307495, mse: 28.1582946777, local kl: 0.0 global kl: 5.03497377125e-09
it: 5850, train nll: 3.62049889565, mse: 4.93066692352, local kl: 0.0 global kl: 3.53942652964e-09 valid nll: 59.869934082, mse: 33.5695648193, local kl: 0.0 global kl: 5.08688424716e-09
it: 5900, train nll: 5.61607694626, mse: 7.85008192062, local kl: 0.0 global kl: 2.7096969113e-09 valid nll: 41.8254127502, mse: 30.2642879486, local kl: 0.0 global kl: 3.65802899083e-09
it: 5950, train nll: 13.2283582687, mse: 17.2664108276, local kl: 0.0 global kl: 2.95721802601e-09 valid nll: 43.2171974182, mse: 35.8172073364, local kl: 0.0 global kl: 2.96344060402e-09
it: 6000, train nll: 3.2725212574, mse: 3.27885007858, local kl: 0.0 global kl: 3.22226645366e-09 valid nll: 91.7629699707, mse: 40.0873985291, local kl: 0.0 global kl: 3.54973361816e-09
it: 6050, train nll: 14.4691324234, mse: 19.821472168, local kl: 0.0 global kl: 2.33169838992e-09 valid nll: 32.4550704956, mse: 23.4914932251, local kl: 0.0 global kl: 2.42887510105e-09
Saving best model with MSE 23.491493
it: 6100, train nll: 33.1699790955, mse: 25.0458984375, local kl: 0.0 global kl: 4.22138901612e-09 valid nll: 46.321056366, mse: 32.0545120239, local kl: 0.0 global kl: 4.23658796933e-09
it: 6150, train nll: 9.59853458405, mse: 12.7495450974, local kl: 0.0 global kl: 3.13603409907e-09 valid nll: 51.4027748108, mse: 29.6591339111, local kl: 0.0 global kl: 3.32236416156e-09
it: 6200, train nll: 16.6133785248, mse: 20.218963623, local kl: 0.0 global kl: 2.71721734002e-09 valid nll: 43.4887428284, mse: 36.3592224121, local kl: 0.0 global kl: 2.62725663447e-09
it: 6250, train nll: 18.6435203552, mse: 11.8024196625, local kl: 0.0 global kl: 3.30990879149e-09 valid nll: 40.5587348938, mse: 30.0720081329, local kl: 0.0 global kl: 2.81646039824e-09
it: 6300, train nll: 7.64785242081, mse: 9.49889469147, local kl: 0.0 global kl: 2.54689469514e-09 valid nll: 37.3204193115, mse: 27.6396579742, local kl: 0.0 global kl: 2.39803954472e-09
it: 6350, train nll: 2.32910609245, mse: 2.32053899765, local kl: 0.0 global kl: 3.06046321619e-09 valid nll: 49.1421775818, mse: 33.9996948242, local kl: 0.0 global kl: 3.23048610085e-09
it: 6400, train nll: 7.0587387085, mse: 8.22596549988, local kl: 0.0 global kl: 2.69011923848e-09 valid nll: 31.6308517456, mse: 24.975933075, local kl: 0.0 global kl: 3.18190984672e-09
it: 6450, train nll: 2.81026697159, mse: 2.92264461517, local kl: 0.0 global kl: 2.52412535318e-09 valid nll: 59.3208694458, mse: 34.0675849915, local kl: 0.0 global kl: 3.50044548902e-09
it: 6500, train nll: 49.1738471985, mse: 22.9584884644, local kl: 0.0 global kl: 4.74144723484e-09 valid nll: 52.3899040222, mse: 36.1823196411, local kl: 0.0 global kl: 2.96374347286e-09
it: 6550, train nll: 14.8164224625, mse: 19.3031082153, local kl: 0.0 global kl: 2.5661002212e-09 valid nll: 39.0296173096, mse: 34.0272979736, local kl: 0.0 global kl: 2.3525845716e-09
it: 6600, train nll: 2.56479144096, mse: 2.70333504677, local kl: 0.0 global kl: 1.21031451528e-09 valid nll: 42.9418945312, mse: 29.9500160217, local kl: 0.0 global kl: 1.69864122768e-09
it: 6650, train nll: 9.68043518066, mse: 10.9781064987, local kl: 0.0 global kl: 4.05190681008e-09 valid nll: 52.2582702637, mse: 41.7855911255, local kl: 0.0 global kl: 3.42360206851e-09
it: 6700, train nll: 13.0077943802, mse: 14.1411457062, local kl: 0.0 global kl: 1.78155801223e-09 valid nll: 34.6217842102, mse: 25.0645046234, local kl: 0.0 global kl: 2.1345250012e-09
it: 6750, train nll: 5.11444759369, mse: 6.65851974487, local kl: 0.0 global kl: 1.81184456327e-09 valid nll: 65.9827575684, mse: 35.958770752, local kl: 0.0 global kl: 2.36286479272e-09
it: 6800, train nll: 34.7099838257, mse: 23.2275676727, local kl: 0.0 global kl: 1.58198287803e-09 valid nll: 44.5312194824, mse: 27.2303123474, local kl: 0.0 global kl: 2.56268672949e-09
it: 6850, train nll: 10.9391431808, mse: 11.6632099152, local kl: 0.0 global kl: 1.99459604389e-09 valid nll: 36.0390739441, mse: 29.6321754456, local kl: 0.0 global kl: 1.71114322711e-09
it: 6900, train nll: 4.14309453964, mse: 4.36735391617, local kl: 0.0 global kl: 2.29581487154e-09 valid nll: 43.0072860718, mse: 37.9940490723, local kl: 0.0 global kl: 1.71276115513e-09
it: 6950, train nll: 18.7013988495, mse: 19.4181957245, local kl: 0.0 global kl: 1.84082293853e-09 valid nll: 40.1774520874, mse: 26.2285118103, local kl: 0.0 global kl: 2.09851758193e-09
it: 7000, train nll: 50.3948631287, mse: 19.9143753052, local kl: 0.0 global kl: 1.62573832174e-09 valid nll: 51.3471374512, mse: 33.0167503357, local kl: 0.0 global kl: 1.99114813526e-09
it: 7050, train nll: 3.60745620728, mse: 4.73484039307, local kl: 0.0 global kl: 1.61639324148e-09 valid nll: 42.7377662659, mse: 37.1534233093, local kl: 0.0 global kl: 2.17467999164e-09
it: 7100, train nll: 3.22598338127, mse: 4.18896627426, local kl: 0.0 global kl: 1.69943814576e-09 valid nll: 81.1204605103, mse: 36.8131484985, local kl: 0.0 global kl: 1.27823251983e-09
it: 7150, train nll: 106.313858032, mse: 15.8776931763, local kl: 0.0 global kl: 2.03762562379e-09 valid nll: 62.649394989, mse: 29.1287555695, local kl: 0.0 global kl: 1.93775262503e-09
it: 7200, train nll: 25.4333534241, mse: 16.168970108, local kl: 0.0 global kl: 2.42906761372e-09 valid nll: 41.5430717468, mse: 27.7566642761, local kl: 0.0 global kl: 1.71493697021e-09
it: 7250, train nll: 9.99769878387, mse: 14.0965328217, local kl: 0.0 global kl: 1.98805771845e-09 valid nll: 46.6924324036, mse: 29.2385940552, local kl: 0.0 global kl: 1.82880399713e-09
it: 7300, train nll: 15.5268478394, mse: 20.9963111877, local kl: 0.0 global kl: 2.44019648932e-09 valid nll: 105.245574951, mse: 35.0302124023, local kl: 0.0 global kl: 2.77284306627e-09
it: 7350, train nll: 2.19816327095, mse: 2.27156376839, local kl: 0.0 global kl: 1.58404578343e-09 valid nll: 47.544052124, mse: 29.2437934875, local kl: 0.0 global kl: 1.59883450923e-09
it: 7400, train nll: 94.6355056763, mse: 33.3068695068, local kl: 0.0 global kl: 2.19317319861e-09 valid nll: 64.2739486694, mse: 40.0029258728, local kl: 0.0 global kl: 1.85415394149e-09
it: 7450, train nll: 17.2624073029, mse: 14.650434494, local kl: 0.0 global kl: 2.12865702842e-09 valid nll: 84.1756134033, mse: 41.9098434448, local kl: 0.0 global kl: 3.04904124171e-09
it: 7500, train nll: 12.9596805573, mse: 10.2071619034, local kl: 0.0 global kl: 2.54305598801e-09 valid nll: 57.241355896, mse: 27.9785423279, local kl: 0.0 global kl: 2.25512120089e-09
it: 7550, train nll: 12.8947324753, mse: 17.8392086029, local kl: 0.0 global kl: 1.73273606574e-09 valid nll: 54.1040077209, mse: 29.748632431, local kl: 0.0 global kl: 1.55370916133e-09
it: 7600, train nll: 67.1555557251, mse: 17.9020328522, local kl: 0.0 global kl: 2.00138350337e-09 valid nll: 81.820892334, mse: 36.0740013123, local kl: 0.0 global kl: 2.38776154404e-09
it: 7650, train nll: 3.81499314308, mse: 4.59018087387, local kl: 0.0 global kl: 1.92934734855e-09 valid nll: 52.5395317078, mse: 34.8464851379, local kl: 0.0 global kl: 1.69364233749e-09
it: 7700, train nll: 29.7394008636, mse: 15.9520177841, local kl: 0.0 global kl: 3.31138849674e-09 valid nll: 79.6802597046, mse: 34.4355049133, local kl: 0.0 global kl: 2.81987633244e-09
it: 7750, train nll: 10.5509872437, mse: 12.0710287094, local kl: 0.0 global kl: 2.02007299777e-09 valid nll: 62.9580612183, mse: 32.8720397949, local kl: 0.0 global kl: 2.37242581136e-09
it: 7800, train nll: 6.90877437592, mse: 8.80881214142, local kl: 0.0 global kl: 2.74373057607e-09 valid nll: 80.1445007324, mse: 39.233505249, local kl: 0.0 global kl: 2.6274240561e-09
it: 7850, train nll: 7.97787714005, mse: 8.11868953705, local kl: 0.0 global kl: 2.51738163648e-09 valid nll: 58.6545753479, mse: 32.6371040344, local kl: 0.0 global kl: 2.44602871291e-09
it: 7900, train nll: 3.48235797882, mse: 4.86089897156, local kl: 0.0 global kl: 1.65647340289e-09 valid nll: 55.3811378479, mse: 27.7708492279, local kl: 0.0 global kl: 2.27476948389e-09
it: 7950, train nll: 13.0779371262, mse: 13.983133316, local kl: 0.0 global kl: 2.22685092588e-09 valid nll: 51.0279426575, mse: 31.6236248016, local kl: 0.0 global kl: 1.33814659353e-09
it: 8000, train nll: 6.02258300781, mse: 8.60881328583, local kl: 0.0 global kl: 2.20353180147e-09 valid nll: 49.4096641541, mse: 28.6799373627, local kl: 0.0 global kl: 1.96979699219e-09
it: 8050, train nll: 3.82658243179, mse: 4.89877128601, local kl: 0.0 global kl: 1.55132018342e-09 valid nll: 32.3883743286, mse: 23.9492263794, local kl: 0.0 global kl: 1.30405153342e-09
it: 8100, train nll: 45.5364646912, mse: 16.9558753967, local kl: 0.0 global kl: 1.04009345581e-09 valid nll: 44.2011795044, mse: 27.1222057343, local kl: 0.0 global kl: 9.19307407976e-10
it: 8150, train nll: 9.21020889282, mse: 5.9818520546, local kl: 0.0 global kl: 2.5036961393e-09 valid nll: 101.782028198, mse: 37.3456192017, local kl: 0.0 global kl: 1.41703859757e-09
it: 8200, train nll: 6.45548725128, mse: 9.44457530975, local kl: 0.0 global kl: 1.72075687033e-09 valid nll: 49.6312103271, mse: 23.0856952667, local kl: 0.0 global kl: 1.95198923691e-09
Saving best model with MSE 23.085695
it: 8250, train nll: 10.1837282181, mse: 10.8819494247, local kl: 0.0 global kl: 1.79168391234e-09 valid nll: 62.2912368774, mse: 32.4922294617, local kl: 0.0 global kl: 1.162026253e-09
it: 8300, train nll: 4.578166008, mse: 5.88134002686, local kl: 0.0 global kl: 1.68339742146e-09 valid nll: 127.134651184, mse: 34.1096763611, local kl: 0.0 global kl: 1.62182778318e-09
it: 8350, train nll: 4.08218669891, mse: 5.35207414627, local kl: 0.0 global kl: 1.13473119789e-09 valid nll: 52.2611808777, mse: 31.5287342072, local kl: 0.0 global kl: 1.02930930446e-09
it: 8400, train nll: 39.3859291077, mse: 29.0614185333, local kl: 0.0 global kl: 2.13769046908e-09 valid nll: 67.2971954346, mse: 31.3014850616, local kl: 0.0 global kl: 1.52417434229e-09
it: 8450, train nll: 10.4989881516, mse: 11.7206058502, local kl: 0.0 global kl: 2.36092945194e-09 valid nll: 50.9021759033, mse: 26.3825187683, local kl: 0.0 global kl: 1.90103865982e-09
it: 8500, train nll: 6.17145586014, mse: 7.8491101265, local kl: 0.0 global kl: 1.76753034431e-09 valid nll: 45.2976531982, mse: 23.7768421173, local kl: 0.0 global kl: 8.7582663344e-10
it: 8550, train nll: 45.2926712036, mse: 15.7119531631, local kl: 0.0 global kl: 9.9967456535e-10 valid nll: 48.234249115, mse: 26.8103637695, local kl: 0.0 global kl: 7.3955652713e-10
it: 8600, train nll: 1.93641376495, mse: 2.35026574135, local kl: 0.0 global kl: 2.16824203036e-09 valid nll: 47.0045127869, mse: 32.1393013, local kl: 0.0 global kl: 2.11075845691e-09
it: 8650, train nll: 5.9277381897, mse: 7.3838429451, local kl: 0.0 global kl: 8.70681415854e-10 valid nll: 147.046676636, mse: 29.2151603699, local kl: 0.0 global kl: 1.06112041376e-09
it: 8700, train nll: 6.48049068451, mse: 6.0764708519, local kl: 0.0 global kl: 5.92621896001e-10 valid nll: 39.9483184814, mse: 23.8938026428, local kl: 0.0 global kl: 6.00524852068e-10
it: 8750, train nll: 16.2592449188, mse: 17.5030593872, local kl: 0.0 global kl: 1.28975818914e-09 valid nll: 145.006454468, mse: 31.0372753143, local kl: 0.0 global kl: 1.13056963791e-09
it: 8800, train nll: 14.6296262741, mse: 18.3042545319, local kl: 0.0 global kl: 9.34008648201e-10 valid nll: 178.034683228, mse: 27.574382782, local kl: 0.0 global kl: 1.03057140599e-09
it: 8850, train nll: 6.07125997543, mse: 6.99719333649, local kl: 0.0 global kl: 1.35256172928e-09 valid nll: 45.2764358521, mse: 30.1669406891, local kl: 0.0 global kl: 1.46873901929e-09
it: 8900, train nll: 61.0513458252, mse: 18.8518772125, local kl: 0.0 global kl: 1.52387880092e-09 valid nll: 33.0860290527, mse: 22.7692394257, local kl: 0.0 global kl: 1.19740894977e-09
Saving best model with MSE 22.76924
it: 8950, train nll: 9.40464496613, mse: 12.5724477768, local kl: 0.0 global kl: 9.59024637481e-10 valid nll: 74.3321838379, mse: 34.6384887695, local kl: 0.0 global kl: 1.02289277049e-09
it: 9000, train nll: 3.53249597549, mse: 4.87058067322, local kl: 0.0 global kl: 1.00603148034e-09 valid nll: 53.8079490662, mse: 30.2422561646, local kl: 0.0 global kl: 6.60832610855e-10
it: 9050, train nll: 13.7281923294, mse: 16.8059616089, local kl: 0.0 global kl: 1.38351596846e-09 valid nll: 38.7359848022, mse: 27.5046195984, local kl: 0.0 global kl: 1.17418874623e-09
it: 9100, train nll: 15.0895013809, mse: 15.6250629425, local kl: 0.0 global kl: 1.05375752568e-09 valid nll: 32.2164459229, mse: 22.2594680786, local kl: 0.0 global kl: 9.87938286734e-10
Saving best model with MSE 22.259468
it: 9150, train nll: 205.803878784, mse: 16.4065437317, local kl: 0.0 global kl: 1.08967745938e-09 valid nll: 79.0534744263, mse: 35.9300918579, local kl: 0.0 global kl: 9.65879598525e-10
it: 9200, train nll: 6.09910011292, mse: 9.15531730652, local kl: 0.0 global kl: 5.05384567528e-10 valid nll: 60.5390625, mse: 38.0648040771, local kl: 0.0 global kl: 7.62290119916e-10
it: 9250, train nll: 2.92119073868, mse: 3.46736764908, local kl: 0.0 global kl: 7.99951493935e-10 valid nll: 49.0031013489, mse: 29.6317882538, local kl: 0.0 global kl: 8.76741956812e-10
it: 9300, train nll: 18.257894516, mse: 16.3212375641, local kl: 0.0 global kl: 9.64372914858e-10 valid nll: 32.5050430298, mse: 26.4424705505, local kl: 0.0 global kl: 1.00612540521e-09
it: 9350, train nll: 10.728430748, mse: 12.4898700714, local kl: 0.0 global kl: 9.90242776666e-10 valid nll: 76.154800415, mse: 29.1075344086, local kl: 0.0 global kl: 1.1762858465e-09
it: 9400, train nll: 8.06761646271, mse: 9.68877506256, local kl: 0.0 global kl: 8.1445356015e-10 valid nll: 65.0922470093, mse: 27.0622253418, local kl: 0.0 global kl: 7.21168624818e-10
it: 9450, train nll: 4.24491167068, mse: 5.99360227585, local kl: 0.0 global kl: 1.11788589496e-09 valid nll: 121.555717468, mse: 36.2668838501, local kl: 0.0 global kl: 1.2197126642e-09
it: 9500, train nll: 19.7063274384, mse: 25.0707416534, local kl: 0.0 global kl: 1.23526167073e-09 valid nll: 55.862121582, mse: 36.7813224792, local kl: 0.0 global kl: 1.15945042456e-09
it: 9550, train nll: 6.45845317841, mse: 8.77624893188, local kl: 0.0 global kl: 7.45667194657e-10 valid nll: 151.555160522, mse: 28.3232269287, local kl: 0.0 global kl: 9.78505831917e-10
it: 9600, train nll: 7.91549015045, mse: 9.93492794037, local kl: 0.0 global kl: 1.22584942197e-09 valid nll: 51.0618400574, mse: 30.656785965, local kl: 0.0 global kl: 1.4765545453e-09
it: 9650, train nll: 9.88807582855, mse: 13.0828619003, local kl: 0.0 global kl: 7.18609116657e-10 valid nll: 45.3902816772, mse: 28.2014923096, local kl: 0.0 global kl: 7.77503395e-10
it: 9700, train nll: 41.6026344299, mse: 27.2033367157, local kl: 0.0 global kl: 1.41012157506e-09 valid nll: 76.4292984009, mse: 41.4961929321, local kl: 0.0 global kl: 8.99194607662e-10
it: 9750, train nll: 21.4572067261, mse: 16.1757354736, local kl: 0.0 global kl: 8.58671689308e-10 valid nll: 70.1112136841, mse: 28.8667469025, local kl: 0.0 global kl: 9.61837054447e-10
it: 9800, train nll: 9.56504058838, mse: 12.3788528442, local kl: 0.0 global kl: 7.16008197177e-10 valid nll: 63.1881370544, mse: 31.8765563965, local kl: 0.0 global kl: 9.83052972359e-10
it: 9850, train nll: 13.479475975, mse: 14.9182987213, local kl: 0.0 global kl: 7.73792419029e-10 valid nll: 35.4259147644, mse: 27.638999939, local kl: 0.0 global kl: 1.0837736264e-09
it: 9900, train nll: 12.6211690903, mse: 8.07109832764, local kl: 0.0 global kl: 8.03363431334e-10 valid nll: 49.6588134766, mse: 25.058139801, local kl: 0.0 global kl: 1.53978274575e-09
it: 9950, train nll: 19.7585315704, mse: 13.3992853165, local kl: 0.0 global kl: 7.65344954079e-10 valid nll: 53.0265426636, mse: 32.7918739319, local kl: 0.0 global kl: 1.08405617816e-09

Posterior Predictive + GP


In [0]:
# Configuration for the "posterior predictive + GP" run: selects the
# 'attentive_gp' uncertainty type and sets local_variational to True.
uncertainty_type = 'attentive_gp'
local_variational = True

# Checkpoint location for this run, under the shared `savedir`.
save_path = os.path.join(savedir, 'best_posterior_gp_mse_unclipped.ckpt')

# Model hyperparameters. The layer sizes and attention settings are the
# notebook-level values defined in the shared configuration cell; only the
# two flags set above are specific to this cell.
model_hparams = tf.contrib.training.HParams(
    activation=tf.nn.relu,
    output_activation=tf.nn.relu,
    x_encoder_sizes=x_encoder_sizes,
    x_y_encoder_sizes=x_y_encoder_sizes,
    global_latent_net_sizes=global_latent_net_sizes,
    local_latent_net_sizes=local_latent_net_sizes,
    heteroskedastic_net_sizes=heteroskedastic_net_sizes,
    uncertainty_type=uncertainty_type,
    att_type=att_type,
    att_heads=att_heads,
    mean_att_type=mean_att_type,
    scale_att_type_1=scale_att_type_1,
    scale_att_type_2=scale_att_type_2,
    data_uncertainty=data_uncertainty,
    local_variational=local_variational,
)

# Optimisation settings handed to train().
# NOTE(review): max_grad_norm=1000.0 is presumably large enough that gradient
# clipping is effectively off (the checkpoint name says "unclipped") -- confirm
# against the train() implementation.
training_hparams = tf.contrib.training.HParams(
    lr=0.01,
    optimizer=tf.train.RMSPropOptimizer,
    num_iterations=10000,
    batch_size=10,
    num_context=num_context,
    num_target=num_target,
    print_every=50,
    save_path=save_path,
    max_grad_norm=1000.0,
)

In [0]:
# Run training with the data, model and optimisation hyperparameters built in
# the preceding cells; progress lines are printed to the cell output.
train(data_hparams, model_hparams, training_hparams)


it: 0, train nll: 131.736602783, mse: 261.122589111, local kl: 0.0353629663587 global kl: 8.2984319306e-05 valid nll: 152.57359314, mse: 302.690948486, local kl: 0.0492436401546 global kl: 0.000122616358567
Saving best model with MSE 302.69095
it: 50, train nll: 8.20581817627, mse: 9.99204444885, local kl: 12.4604902267 global kl: 0.000217912674998 valid nll: 12.7324056625, mse: 17.2057094574, local kl: 14.8639354706 global kl: 0.000617719313595
Saving best model with MSE 17.20571
it: 100, train nll: 9.82409477234, mse: 13.3075666428, local kl: 7.85672092438 global kl: 0.000187600715435 valid nll: 10.4615631104, mse: 13.1943149567, local kl: 16.4872779846 global kl: 6.16395409452e-05
Saving best model with MSE 13.194315
it: 150, train nll: 5.35595321655, mse: 6.58239555359, local kl: 7.75857448578 global kl: 1.5606310626e-05 valid nll: 10.0951595306, mse: 12.7748575211, local kl: 15.3519239426 global kl: 2.25215171668e-05
Saving best model with MSE 12.7748575
it: 200, train nll: 7.05510044098, mse: 9.10587215424, local kl: 7.37109661102 global kl: 2.78741354123e-05 valid nll: 9.97383117676, mse: 12.6873283386, local kl: 14.1055603027 global kl: 1.10009423224e-05
Saving best model with MSE 12.687328
it: 250, train nll: 5.61041021347, mse: 6.92930650711, local kl: 6.33448934555 global kl: 7.73314332037e-06 valid nll: 8.70890426636, mse: 10.5828504562, local kl: 15.1027021408 global kl: 4.67329755338e-06
Saving best model with MSE 10.58285
it: 300, train nll: 6.29046487808, mse: 7.78862714767, local kl: 7.83928871155 global kl: 1.39571034197e-06 valid nll: 9.77124023438, mse: 12.8659296036, local kl: 12.0133943558 global kl: 1.24156633774e-06
it: 350, train nll: 2.69561052322, mse: 2.67313218117, local kl: 2.64422035217 global kl: 4.99532688991e-07 valid nll: 7.08707618713, mse: 8.11425209045, local kl: 15.3343276978 global kl: 4.49836392136e-07
Saving best model with MSE 8.114252
it: 400, train nll: 5.71114349365, mse: 7.38367176056, local kl: 4.8832821846 global kl: 1.68673139456e-07 valid nll: 7.82870721817, mse: 9.35809707642, local kl: 14.013133049 global kl: 1.53688759497e-07
it: 450, train nll: 3.07934904099, mse: 3.34903359413, local kl: 4.20472669601 global kl: 5.36260138517e-08 valid nll: 6.95073413849, mse: 7.90281152725, local kl: 14.580078125 global kl: 7.91208378814e-08
Saving best model with MSE 7.9028115
it: 500, train nll: 4.68724155426, mse: 5.69263219833, local kl: 3.51335024834 global kl: 5.96213496351e-08 valid nll: 6.60188627243, mse: 7.72890901566, local kl: 11.5125684738 global kl: 4.88946518828e-08
Saving best model with MSE 7.728909
it: 550, train nll: 4.32224082947, mse: 5.31283283234, local kl: 3.58725428581 global kl: 4.57838460477e-08 valid nll: 6.34783554077, mse: 7.16313505173, local kl: 13.6487398148 global kl: 4.99773946672e-08
Saving best model with MSE 7.163135
it: 600, train nll: 3.03542399406, mse: 3.03466439247, local kl: 3.43493700027 global kl: 3.55587346235e-08 valid nll: 6.48825311661, mse: 7.25276136398, local kl: 12.7321271896 global kl: 3.53664866282e-08
it: 650, train nll: 3.39887905121, mse: 3.32939887047, local kl: 5.24714565277 global kl: 2.63459583039e-08 valid nll: 6.35402297974, mse: 7.18432712555, local kl: 12.4536495209 global kl: 2.88066814846e-08
it: 700, train nll: 3.63143110275, mse: 3.73080849648, local kl: 5.2684264183 global kl: 1.86710202854e-08 valid nll: 6.81429386139, mse: 7.57431554794, local kl: 13.2048768997 global kl: 3.35218715009e-08
it: 750, train nll: 4.12898015976, mse: 4.76116037369, local kl: 4.80679893494 global kl: 1.37299291936e-08 valid nll: 5.90150785446, mse: 6.44697475433, local kl: 12.0433597565 global kl: 2.13998170295e-08
Saving best model with MSE 6.4469748
it: 800, train nll: 2.47894620895, mse: 2.19765329361, local kl: 3.70741391182 global kl: 9.5571692782e-09 valid nll: 5.87560987473, mse: 6.30558204651, local kl: 12.7674093246 global kl: 1.9050323985e-08
Saving best model with MSE 6.305582
it: 850, train nll: 3.44647264481, mse: 3.29909729958, local kl: 5.937479496 global kl: 1.30612063387e-08 valid nll: 4.96073484421, mse: 5.22900390625, local kl: 11.7237052917 global kl: 1.99432808046e-08
Saving best model with MSE 5.229004
it: 900, train nll: 3.88086867332, mse: 4.13504076004, local kl: 4.57558774948 global kl: 1.1192027749e-08 valid nll: 5.2080078125, mse: 5.6439371109, local kl: 11.5488758087 global kl: 1.14278710939e-08
it: 950, train nll: 1.81255936623, mse: 1.1996628046, local kl: 3.90155434608 global kl: 1.48444030401e-08 valid nll: 4.64729404449, mse: 4.87053537369, local kl: 10.3159770966 global kl: 9.87788695284e-09
Saving best model with MSE 4.8705354
it: 1000, train nll: 3.44300174713, mse: 3.67007303238, local kl: 5.33276748657 global kl: 5.22680032589e-09 valid nll: 4.61440658569, mse: 4.73942899704, local kl: 10.948515892 global kl: 6.88702561646e-09
Saving best model with MSE 4.739429
it: 1050, train nll: 3.75643754005, mse: 3.64799284935, local kl: 6.48091173172 global kl: 4.59252103013e-09 valid nll: 5.58371591568, mse: 5.9868221283, local kl: 10.8117074966 global kl: 6.92004054059e-09
it: 1100, train nll: 2.39800024033, mse: 2.17956280708, local kl: 3.72692131996 global kl: 6.30427621218e-09 valid nll: 5.32682561874, mse: 5.51526069641, local kl: 14.1158752441 global kl: 6.70376021361e-09
it: 1150, train nll: 2.22844338417, mse: 1.93241214752, local kl: 2.30653262138 global kl: 8.31947488678e-09 valid nll: 4.74414157867, mse: 4.99185800552, local kl: 10.3436374664 global kl: 7.44397388175e-09
it: 1200, train nll: 2.15593314171, mse: 1.59348213673, local kl: 4.32652330399 global kl: 7.53664952668e-09 valid nll: 4.84252882004, mse: 4.78090047836, local kl: 12.4369478226 global kl: 5.60825341722e-09
it: 1250, train nll: 2.52617883682, mse: 2.07575154305, local kl: 6.26399755478 global kl: 5.09393638382e-09 valid nll: 4.94428825378, mse: 5.13966274261, local kl: 13.1241579056 global kl: 6.33104590975e-09
it: 1300, train nll: 2.63909554482, mse: 2.43624615669, local kl: 3.64323496819 global kl: 3.06589109655e-09 valid nll: 5.18234300613, mse: 5.25909900665, local kl: 13.744395256 global kl: 3.56996077144e-09
it: 1350, train nll: 2.60648846626, mse: 2.42937922478, local kl: 3.21365237236 global kl: 4.60620963594e-09 valid nll: 5.66870260239, mse: 6.10067939758, local kl: 11.0710744858 global kl: 5.11858910812e-09
it: 1400, train nll: 4.08694839478, mse: 3.70918440819, local kl: 22.0923976898 global kl: 3.11167625e-09 valid nll: 8.74178028107, mse: 11.9377288818, local kl: 9.66919803619 global kl: 3.17818438234e-09
it: 1450, train nll: 2.71613907814, mse: 2.47394752502, local kl: 4.04584598541 global kl: 2.19492757303e-09 valid nll: 5.53877067566, mse: 5.76447486877, local kl: 12.5028343201 global kl: 4.09253653189e-09
it: 1500, train nll: 1.88387703896, mse: 1.46845114231, local kl: 2.32921910286 global kl: 1.8475914132e-09 valid nll: 5.36864566803, mse: 5.73973941803, local kl: 11.7066612244 global kl: 3.82797571419e-09
it: 1550, train nll: 3.23831319809, mse: 3.33292913437, local kl: 4.4075512886 global kl: 2.22368079505e-09 valid nll: 4.58206939697, mse: 4.67618513107, local kl: 10.2097740173 global kl: 3.00324165536e-09
Saving best model with MSE 4.676185
it: 1600, train nll: 3.91090321541, mse: 3.94123721123, local kl: 6.7273850441 global kl: 4.27727808727e-09 valid nll: 4.0688624382, mse: 3.91679739952, local kl: 10.1703996658 global kl: 2.81810952352e-09
Saving best model with MSE 3.9167974
it: 1650, train nll: 2.45914721489, mse: 2.17145681381, local kl: 3.62189602852 global kl: 2.32385022336e-09 valid nll: 4.69405889511, mse: 4.69360733032, local kl: 11.4580440521 global kl: 2.52239695797e-09
it: 1700, train nll: 3.09966754913, mse: 2.88982224464, local kl: 4.82898044586 global kl: 2.23864193849e-09 valid nll: 4.45939922333, mse: 4.46049451828, local kl: 9.9505109787 global kl: 3.02424618681e-09
it: 1750, train nll: 2.70715522766, mse: 2.55384230614, local kl: 3.38270306587 global kl: 3.6466585307e-09 valid nll: 4.65081262589, mse: 4.56149721146, local kl: 15.2960805893 global kl: 6.12311357173e-09
it: 1800, train nll: 3.37514972687, mse: 3.28492593765, local kl: 5.19713592529 global kl: 2.3500184021e-09 valid nll: 5.28727436066, mse: 5.7941570282, local kl: 11.9488477707 global kl: 3.38123240518e-09
it: 1850, train nll: 2.07458424568, mse: 1.61195647717, local kl: 4.51951265335 global kl: 3.12202219632e-09 valid nll: 4.16334199905, mse: 4.0090675354, local kl: 10.5434007645 global kl: 3.38382699638e-09
it: 1900, train nll: 1.50460803509, mse: 0.896158516407, local kl: 2.4068903923 global kl: 1.75653358525e-09 valid nll: 4.6939997673, mse: 4.7660984993, local kl: 12.409453392 global kl: 2.71221023418e-09
it: 1950, train nll: 3.06692957878, mse: 2.75973892212, local kl: 6.07640266418 global kl: 3.48463968791e-09 valid nll: 3.96153950691, mse: 3.82208251953, local kl: 10.3825998306 global kl: 2.05602246339e-09
Saving best model with MSE 3.8220825
it: 2000, train nll: 1.92526817322, mse: 1.42849099636, local kl: 4.28264188766 global kl: 3.81170517372e-09 valid nll: 4.84716510773, mse: 4.7479133606, local kl: 15.9660692215 global kl: 2.60658272744e-09
it: 2050, train nll: 2.33061671257, mse: 1.9631100893, local kl: 3.43156266212 global kl: 1.44008371894e-09 valid nll: 4.5454454422, mse: 4.63463497162, local kl: 11.4467077255 global kl: 2.84091461467e-09
it: 2100, train nll: 2.11916589737, mse: 1.76117956638, local kl: 3.19627976418 global kl: 2.73698952391e-09 valid nll: 4.62873506546, mse: 4.56786155701, local kl: 12.669424057 global kl: 2.05535721776e-09
it: 2150, train nll: 3.31860089302, mse: 3.04737305641, local kl: 6.66125249863 global kl: 1.42470002462e-09 valid nll: 4.56219291687, mse: 4.53662109375, local kl: 10.7046880722 global kl: 2.21968687875e-09
it: 2200, train nll: 2.49544906616, mse: 1.90037190914, local kl: 5.26071023941 global kl: 3.40241390617e-09 valid nll: 4.15907335281, mse: 3.99009680748, local kl: 9.9608001709 global kl: 2.5197512965e-09
it: 2250, train nll: 2.36067771912, mse: 1.90362083912, local kl: 4.46517562866 global kl: 1.90973148406e-09 valid nll: 4.51600646973, mse: 4.6380200386, local kl: 10.9523563385 global kl: 2.66830246787e-09
it: 2300, train nll: 3.03853535652, mse: 3.04530405998, local kl: 4.08507823944 global kl: 1.69609981615e-09 valid nll: 4.06745910645, mse: 3.99640226364, local kl: 8.78305721283 global kl: 2.52904119868e-09
it: 2350, train nll: 3.25078821182, mse: 3.5885052681, local kl: 5.44250631332 global kl: 3.72856145958e-09 valid nll: 4.5141825676, mse: 4.72780036926, local kl: 10.1647844315 global kl: 3.77553366349e-09
it: 2400, train nll: 2.53925585747, mse: 2.24067544937, local kl: 7.01959276199 global kl: 1.57862634076e-09 valid nll: 4.25796413422, mse: 4.42850923538, local kl: 8.25619220734 global kl: 1.69101777026e-09
it: 2450, train nll: 2.36239552498, mse: 1.94361543655, local kl: 6.2202539444 global kl: 2.14852602376e-09 valid nll: 4.82958602905, mse: 4.88020515442, local kl: 11.2568445206 global kl: 2.70231059751e-09
it: 2500, train nll: 2.4736366272, mse: 2.05748987198, local kl: 6.84222698212 global kl: 1.41197920023e-09 valid nll: 4.64536857605, mse: 4.62501525879, local kl: 12.4686231613 global kl: 2.78720380109e-09
it: 2550, train nll: 1.45271039009, mse: 0.762513875961, local kl: 3.01276516914 global kl: 1.16781029291e-09 valid nll: 3.83596682549, mse: 3.69909405708, local kl: 9.97217273712 global kl: 2.05337058468e-09
Saving best model with MSE 3.699094
it: 2600, train nll: 3.87288379669, mse: 4.18238115311, local kl: 6.00857496262 global kl: 1.2527832105e-09 valid nll: 4.02640914917, mse: 4.00827455521, local kl: 11.278175354 global kl: 1.63085933647e-09
it: 2650, train nll: 4.97355604172, mse: 5.83392715454, local kl: 7.97796010971 global kl: 3.17921644566e-09 valid nll: 4.75985813141, mse: 4.97953367233, local kl: 10.4845180511 global kl: 2.2712067782e-09
it: 2700, train nll: 2.38679337502, mse: 2.13036179543, local kl: 3.94111680984 global kl: 1.15034537451e-09 valid nll: 4.35486698151, mse: 4.3182926178, local kl: 10.1790390015 global kl: 1.18598986187e-09
it: 2750, train nll: 1.29641997814, mse: 0.568443953991, local kl: 3.22796225548 global kl: 2.2129966748e-09 valid nll: 3.77965784073, mse: 3.43875837326, local kl: 14.2688264847 global kl: 1.41857336988e-09
Saving best model with MSE 3.4387584
it: 2800, train nll: 2.45617961884, mse: 2.07347774506, local kl: 5.74192333221 global kl: 1.46461898165e-09 valid nll: 6.07408475876, mse: 7.02613782883, local kl: 11.5405445099 global kl: 1.23114873851e-09
it: 2850, train nll: 3.8526494503, mse: 3.98510122299, local kl: 6.68023490906 global kl: 1.57789581401e-09 valid nll: 4.02369403839, mse: 3.8042447567, local kl: 13.028506279 global kl: 3.2139468864e-09
it: 2900, train nll: 1.38851416111, mse: 0.686342298985, local kl: 3.40847539902 global kl: 1.27726518251e-09 valid nll: 5.1445145607, mse: 4.88982439041, local kl: 16.3605823517 global kl: 1.31814725801e-09
it: 2950, train nll: 3.26388692856, mse: 3.13573050499, local kl: 11.8873052597 global kl: 2.55124033011e-09 valid nll: 6.37606334686, mse: 7.64383506775, local kl: 8.86728000641 global kl: 1.64404989622e-09
it: 3000, train nll: 2.3651227951, mse: 2.18109989166, local kl: 5.63298034668 global kl: 2.07157158094e-09 valid nll: 4.21606636047, mse: 4.16706895828, local kl: 11.7830104828 global kl: 2.56387688857e-09
it: 3050, train nll: 2.95615267754, mse: 2.97031807899, local kl: 2.90378332138 global kl: 1.98797467377e-09 valid nll: 4.1741399765, mse: 4.22867441177, local kl: 10.2920446396 global kl: 1.73651937274e-09
it: 3100, train nll: 3.00106978416, mse: 2.39401340485, local kl: 8.85858249664 global kl: 1.38080336054e-09 valid nll: 4.83169746399, mse: 4.74328660965, local kl: 13.9995536804 global kl: 1.51164558648e-09
it: 3150, train nll: 3.59941935539, mse: 3.92273139954, local kl: 3.48936486244 global kl: 1.40500855395e-09 valid nll: 5.64675664902, mse: 5.7505197525, local kl: 14.1613645554 global kl: 1.01120423146e-09
it: 3200, train nll: 1.44001233578, mse: 0.721992135048, local kl: 2.55765080452 global kl: 1.01428010435e-09 valid nll: 3.87623381615, mse: 3.76794028282, local kl: 10.0409297943 global kl: 1.47124901151e-09
it: 3250, train nll: 2.27241706848, mse: 2.12126517296, local kl: 1.90355849266 global kl: 1.37131395128e-09 valid nll: 4.27187347412, mse: 4.27880096436, local kl: 10.42374897 global kl: 1.15040388327e-09
it: 3300, train nll: 3.50218272209, mse: 3.11656475067, local kl: 10.0049734116 global kl: 1.07003450545e-09 valid nll: 4.0939707756, mse: 3.96336555481, local kl: 11.159406662 global kl: 8.48050074609e-10
it: 3350, train nll: 2.17135190964, mse: 1.69401919842, local kl: 5.02556562424 global kl: 8.39721958634e-10 valid nll: 4.27283334732, mse: 4.27044343948, local kl: 10.5741243362 global kl: 1.69646319215e-09
it: 3400, train nll: 2.69792723656, mse: 2.32995319366, local kl: 6.8985877037 global kl: 1.52672274822e-09 valid nll: 4.87890338898, mse: 4.99009037018, local kl: 10.6829557419 global kl: 1.49339030031e-09
it: 3450, train nll: 3.15796685219, mse: 3.43522548676, local kl: 3.11684894562 global kl: 1.15738818529e-09 valid nll: 3.907361269, mse: 3.80206894875, local kl: 11.7211751938 global kl: 2.30204699747e-09
it: 3500, train nll: 1.74512338638, mse: 1.21209800243, local kl: 3.33957505226 global kl: 1.15508813625e-09 valid nll: 3.99217104912, mse: 3.62590837479, local kl: 12.9483385086 global kl: 1.88072446505e-09
it: 3550, train nll: 3.65675616264, mse: 3.76567196846, local kl: 4.9027967453 global kl: 1.92768712104e-09 valid nll: 4.33539485931, mse: 4.12127494812, local kl: 11.0982255936 global kl: 9.13530251445e-10
it: 3600, train nll: 2.17300105095, mse: 1.83893632889, local kl: 2.75658369064 global kl: 1.25121280004e-09 valid nll: 4.09754514694, mse: 3.81708097458, local kl: 11.961066246 global kl: 1.65225910731e-09
it: 3650, train nll: 3.62460970879, mse: 3.31024861336, local kl: 9.17100429535 global kl: 1.05051134458e-09 valid nll: 3.84129023552, mse: 3.58301949501, local kl: 12.7405080795 global kl: 9.36003163865e-10
it: 3700, train nll: 3.86430835724, mse: 4.44609022141, local kl: 2.4672806263 global kl: 1.98375826876e-09 valid nll: 3.29777121544, mse: 3.03102564812, local kl: 8.69805526733 global kl: 2.04539629678e-09
Saving best model with MSE 3.0310256
it: 3750, train nll: 3.48757576942, mse: 3.60281324387, local kl: 8.34064102173 global kl: 1.08671860399e-09 valid nll: 5.44010543823, mse: 6.58800458908, local kl: 11.8530778885 global kl: 1.45247391892e-09
it: 3800, train nll: 3.84451770782, mse: 4.06337451935, local kl: 4.88033008575 global kl: 9.26123178147e-10 valid nll: 3.86793398857, mse: 3.79121828079, local kl: 10.7968063354 global kl: 2.12004325206e-09
it: 3850, train nll: 2.23476338387, mse: 1.88667857647, local kl: 8.00851631165 global kl: 1.15202913875e-09 valid nll: 3.60002684593, mse: 3.37333321571, local kl: 11.1810235977 global kl: 7.2646944016e-10
it: 3900, train nll: 2.51075148582, mse: 2.34521484375, local kl: 3.63939547539 global kl: 1.02911257294e-09 valid nll: 4.44952630997, mse: 4.638215065, local kl: 8.42218399048 global kl: 1.09111097935e-09
it: 3950, train nll: 3.46069574356, mse: 3.70876312256, local kl: 2.9957151413 global kl: 1.1107694764e-09 valid nll: 3.66536259651, mse: 3.48088002205, local kl: 8.3839263916 global kl: 1.28762656093e-09
it: 4000, train nll: 2.16980862617, mse: 1.59991180897, local kl: 3.93858671188 global kl: 9.18636999803e-10 valid nll: 4.11113595963, mse: 4.14938879013, local kl: 9.26665496826 global kl: 1.47409595641e-09
it: 4050, train nll: 3.14738607407, mse: 3.10217547417, local kl: 5.48441934586 global kl: 1.16821730067e-09 valid nll: 2.8574256897, mse: 2.4355597496, local kl: 8.2306470871 global kl: 1.35383049216e-09
Saving best model with MSE 2.4355597
it: 4100, train nll: 4.06129312515, mse: 4.15263080597, local kl: 7.53188228607 global kl: 6.83138101643e-10 valid nll: 4.50344991684, mse: 4.6411190033, local kl: 9.50436019897 global kl: 1.08361841722e-09
it: 4150, train nll: 1.84961175919, mse: 1.5723528862, local kl: 1.17811596394 global kl: 5.70221814211e-10 valid nll: 3.22940206528, mse: 2.61783409119, local kl: 13.4688472748 global kl: 9.44369915601e-10
it: 4200, train nll: 2.02102565765, mse: 1.53292715549, local kl: 2.63537144661 global kl: 1.01911290518e-09 valid nll: 4.64539384842, mse: 4.43478536606, local kl: 13.5652656555 global kl: 7.2358352643e-10
it: 4250, train nll: 2.77392482758, mse: 2.48984932899, local kl: 4.88301134109 global kl: 4.36077340904e-10 valid nll: 3.73970603943, mse: 3.56205248833, local kl: 8.2781419754 global kl: 7.89348419961e-10
it: 4300, train nll: 3.16291069984, mse: 3.09675145149, local kl: 4.7958984375 global kl: 8.53363602005e-10 valid nll: 4.02314138412, mse: 3.83296585083, local kl: 12.5991849899 global kl: 1.20769771961e-09
it: 4350, train nll: 3.28175544739, mse: 3.21865081787, local kl: 5.19243717194 global kl: 5.34953525921e-10 valid nll: 3.75128245354, mse: 3.5544552803, local kl: 9.96827983856 global kl: 7.86233744776e-10
it: 4400, train nll: 3.67316222191, mse: 3.39814734459, local kl: 8.51923274994 global kl: 6.98559710077e-10 valid nll: 4.4811668396, mse: 4.1530880928, local kl: 11.3059682846 global kl: 8.98656260517e-10
it: 4450, train nll: 4.04071855545, mse: 3.89607810974, local kl: 9.43712520599 global kl: 3.83380494018e-10 valid nll: 4.88265895844, mse: 5.26137399673, local kl: 9.45113754272 global kl: 1.04026676162e-09
it: 4500, train nll: 1.57711660862, mse: 0.893361628056, local kl: 3.32351660728 global kl: 9.29916477155e-10 valid nll: 3.52919340134, mse: 3.16569137573, local kl: 9.95921897888 global kl: 1.26774701847e-09
it: 4550, train nll: 2.07659292221, mse: 1.41264474392, local kl: 5.46616601944 global kl: 5.24814580682e-10 valid nll: 3.42093038559, mse: 3.02983808517, local kl: 10.0055885315 global kl: 8.1521350781e-10
it: 4600, train nll: 1.58298528194, mse: 1.15580105782, local kl: 3.00464510918 global kl: 1.91061166888e-09 valid nll: 4.22465896606, mse: 3.92071723938, local kl: 12.2244796753 global kl: 9.58660262285e-10
it: 4650, train nll: 1.99643528461, mse: 1.48756051064, local kl: 3.90720677376 global kl: 7.04461877721e-10 valid nll: 3.59495639801, mse: 3.01713418961, local kl: 13.7252759933 global kl: 9.36654975803e-10
it: 4700, train nll: 3.65860128403, mse: 3.67136526108, local kl: 5.80149745941 global kl: 8.98253582626e-10 valid nll: 3.32724738121, mse: 2.76537561417, local kl: 9.85548305511 global kl: 5.5090332296e-10
it: 4750, train nll: 3.62339663506, mse: 3.68590426445, local kl: 7.65880918503 global kl: 1.50031165269e-09 valid nll: 3.68296027184, mse: 3.24758911133, local kl: 10.6526374817 global kl: 9.3491703268e-10
it: 4800, train nll: 2.60529780388, mse: 2.38498044014, local kl: 6.05395841599 global kl: 7.43781147783e-10 valid nll: 3.6269595623, mse: 3.33858633041, local kl: 10.3293275833 global kl: 8.93910390154e-10
it: 4850, train nll: 3.57655262947, mse: 3.36385250092, local kl: 6.25039958954 global kl: 3.23791382595e-10 valid nll: 4.27270650864, mse: 4.27758932114, local kl: 8.85856533051 global kl: 8.38790037427e-10
it: 4900, train nll: 3.48903226852, mse: 3.60102534294, local kl: 3.82938241959 global kl: 7.88999254819e-10 valid nll: 3.86649703979, mse: 3.58935976028, local kl: 10.8177137375 global kl: 1.12377995798e-09
it: 4950, train nll: 2.96677899361, mse: 2.79052567482, local kl: 3.99836468697 global kl: 1.22039023331e-09 valid nll: 5.25368309021, mse: 5.28339529037, local kl: 13.1573915482 global kl: 9.00291230455e-10
it: 5000, train nll: 2.09687042236, mse: 1.78596901894, local kl: 2.02269911766 global kl: 1.15850951055e-09 valid nll: 3.76012849808, mse: 3.4222111702, local kl: 10.2368011475 global kl: 9.98184868095e-10
it: 5050, train nll: 2.63499093056, mse: 2.60017108917, local kl: 4.14310455322 global kl: 8.16200274034e-10 valid nll: 3.27463006973, mse: 2.54759955406, local kl: 12.9336175919 global kl: 5.31851174213e-10
it: 5100, train nll: 2.33630228043, mse: 1.77985966206, local kl: 5.16076421738 global kl: 1.02855823858e-09 valid nll: 2.87910199165, mse: 2.34421873093, local kl: 11.5488233566 global kl: 1.14160325637e-09
Saving best model with MSE 2.3442187
it: 5150, train nll: 2.71345376968, mse: 2.41480588913, local kl: 5.35407972336 global kl: 9.15508557853e-10 valid nll: 3.69021916389, mse: 3.47985267639, local kl: 9.50863933563 global kl: 8.33973612391e-10
it: 5200, train nll: 2.63120365143, mse: 2.69170713425, local kl: 2.0253136158 global kl: 7.29000748656e-10 valid nll: 3.41695165634, mse: 3.01484894753, local kl: 9.58011627197 global kl: 1.45724354805e-09
it: 5250, train nll: 2.41797065735, mse: 1.78132343292, local kl: 9.38363647461 global kl: 8.39948444131e-10 valid nll: 3.48725748062, mse: 3.00888180733, local kl: 10.8286066055 global kl: 6.53076481782e-10
it: 5300, train nll: 2.27741503716, mse: 1.66141927242, local kl: 8.53205680847 global kl: 1.27137256278e-09 valid nll: 5.12731790543, mse: 5.43461990356, local kl: 13.1771306992 global kl: 7.91947230017e-10
it: 5350, train nll: 3.22304105759, mse: 3.39647126198, local kl: 2.95677852631 global kl: 7.45297157323e-10 valid nll: 3.37583017349, mse: 2.9475569725, local kl: 11.1631746292 global kl: 3.86156795429e-09
it: 5400, train nll: 2.94102263451, mse: 2.80761027336, local kl: 5.96521282196 global kl: 1.70865299687e-09 valid nll: 3.19899153709, mse: 2.62388086319, local kl: 10.1855869293 global kl: 6.1793004047e-10
it: 5450, train nll: 1.27941286564, mse: 0.464536935091, local kl: 3.96495723724 global kl: 1.10902542705e-09 valid nll: 2.78656220436, mse: 2.26448082924, local kl: 10.6311416626 global kl: 9.07576569453e-10
Saving best model with MSE 2.2644808
it: 5500, train nll: 1.77996051311, mse: 1.0541690588, local kl: 3.23787403107 global kl: 8.40814640135e-10 valid nll: 2.99714922905, mse: 2.55126595497, local kl: 9.62326335907 global kl: 7.79036724019e-10
it: 5550, train nll: 1.75706481934, mse: 1.17818915844, local kl: 3.53139305115 global kl: 3.09713088509e-10 valid nll: 3.32949662209, mse: 2.62991666794, local kl: 11.5204277039 global kl: 6.37721320196e-10
it: 5600, train nll: 3.6792576313, mse: 4.0980758667, local kl: 3.61453914642 global kl: 2.79553213911e-10 valid nll: 4.11863231659, mse: 3.90645694733, local kl: 10.1377811432 global kl: 5.15001707946e-10
it: 5650, train nll: 3.44946956635, mse: 3.86488819122, local kl: 3.39962148666 global kl: 5.95452076535e-10 valid nll: 3.34094929695, mse: 2.83319568634, local kl: 11.4251928329 global kl: 1.73456660146e-09
it: 5700, train nll: 3.41205120087, mse: 3.19113230705, local kl: 5.50469493866 global kl: 5.70535896305e-10 valid nll: 3.15475869179, mse: 2.61545228958, local kl: 11.1830701828 global kl: 9.50915013398e-10
it: 5750, train nll: 1.27761030197, mse: 0.653386294842, local kl: 1.16925275326 global kl: 3.11155823329e-10 valid nll: 3.08083939552, mse: 2.60474038124, local kl: 7.91691732407 global kl: 7.07821357082e-10
it: 5800, train nll: 2.75886368752, mse: 2.45148825645, local kl: 6.18365478516 global kl: 3.0567028908e-09 valid nll: 4.14396762848, mse: 4.06536626816, local kl: 8.59971809387 global kl: 1.57818624835e-09
it: 5850, train nll: 2.23724222183, mse: 1.91285228729, local kl: 2.72879958153 global kl: 7.03921920753e-10 valid nll: 3.28871631622, mse: 2.87024927139, local kl: 11.0769147873 global kl: 1.20286558492e-09
it: 5900, train nll: 2.79527664185, mse: 2.70623779297, local kl: 3.60476374626 global kl: 4.62736821083e-10 valid nll: 3.21903729439, mse: 2.8094522953, local kl: 9.5217666626 global kl: 9.84415438054e-10
it: 5950, train nll: 3.26594376564, mse: 3.638890028, local kl: 3.84650802612 global kl: 1.47481094004e-09 valid nll: 3.06488347054, mse: 2.59947609901, local kl: 10.2916946411 global kl: 1.33450439588e-09
it: 6000, train nll: 2.47231340408, mse: 2.27864694595, local kl: 2.9454703331 global kl: 2.47925280306e-09 valid nll: 4.1739692688, mse: 3.63035607338, local kl: 18.6142864227 global kl: 1.97242755462e-09
it: 6050, train nll: 4.21246957779, mse: 4.85241270065, local kl: 5.37868738174 global kl: 2.11205986034e-09 valid nll: 3.50339508057, mse: 3.19772648811, local kl: 11.2180833817 global kl: 9.22851128848e-10
it: 6100, train nll: 3.59024763107, mse: 3.83889484406, local kl: 3.68613958359 global kl: 9.53437107043e-10 valid nll: 4.56752347946, mse: 4.65503168106, local kl: 11.221654892 global kl: 1.81013015688e-09
it: 6150, train nll: 2.80382657051, mse: 2.64702010155, local kl: 5.14874267578 global kl: 8.48043413271e-10 valid nll: 2.99646186829, mse: 2.5961868763, local kl: 9.70889091492 global kl: 1.10473519221e-09
it: 6200, train nll: 3.14185667038, mse: 3.22244024277, local kl: 3.77373218536 global kl: 1.25503762938e-09 valid nll: 3.40765571594, mse: 3.05175971985, local kl: 9.83208942413 global kl: 8.55182702431e-10
it: 6250, train nll: 2.4152879715, mse: 1.82285368443, local kl: 8.29429912567 global kl: 8.24733503713e-10 valid nll: 3.67969965935, mse: 3.29828643799, local kl: 12.5241231918 global kl: 1.12012688014e-09
it: 6300, train nll: 3.13824033737, mse: 3.09655380249, local kl: 4.90710687637 global kl: 1.76400083429e-09 valid nll: 3.41036009789, mse: 2.97448992729, local kl: 9.36726570129 global kl: 8.8021218092e-10
it: 6350, train nll: 1.80746638775, mse: 1.35459530354, local kl: 2.635617733 global kl: 7.90489451674e-10 valid nll: 3.68803858757, mse: 3.33361959457, local kl: 10.4489002228 global kl: 1.11221898358e-09
it: 6400, train nll: 3.30414223671, mse: 3.52240133286, local kl: 3.7443087101 global kl: 7.48396011829e-10 valid nll: 3.36716532707, mse: 2.84590816498, local kl: 11.3029088974 global kl: 9.33802590808e-10
it: 6450, train nll: 2.75754952431, mse: 2.88611888885, local kl: 2.41000103951 global kl: 6.82826073461e-10 valid nll: 3.68022203445, mse: 3.29160451889, local kl: 12.0995149612 global kl: 1.06191200278e-09
it: 6500, train nll: 3.05412626266, mse: 3.01057672501, local kl: 6.80566215515 global kl: 1.22313348339e-09 valid nll: 4.33622646332, mse: 4.22319793701, local kl: 12.6758928299 global kl: 2.36087527306e-09
it: 6550, train nll: 2.90661931038, mse: 2.76278877258, local kl: 3.60973143578 global kl: 4.96118646165e-10 valid nll: 3.03750896454, mse: 2.49851870537, local kl: 10.0339641571 global kl: 8.6038837166e-10
it: 6600, train nll: 1.32821929455, mse: 0.676372945309, local kl: 1.57670152187 global kl: 8.11810008106e-10 valid nll: 3.51284408569, mse: 3.22599363327, local kl: 10.2219867706 global kl: 7.07553904356e-10
it: 6650, train nll: 2.19439268112, mse: 1.75317156315, local kl: 3.40716481209 global kl: 2.93682034247e-09 valid nll: 4.16032886505, mse: 3.8424448967, local kl: 14.8509273529 global kl: 1.83662529629e-09
it: 6700, train nll: 3.66118073463, mse: 3.79794621468, local kl: 6.1193728447 global kl: 8.07686695303e-10 valid nll: 3.89958882332, mse: 3.6256582737, local kl: 9.59320449829 global kl: 7.08370362368e-10
it: 6750, train nll: 1.84297895432, mse: 1.26939415932, local kl: 3.31234622002 global kl: 6.89141632648e-10 valid nll: 3.53631496429, mse: 3.04869771004, local kl: 13.1155157089 global kl: 7.52676476701e-10
it: 6800, train nll: 3.17748308182, mse: 3.17454195023, local kl: 7.68555212021 global kl: 1.64382141232e-09 valid nll: 3.48997974396, mse: 3.01353859901, local kl: 11.7409067154 global kl: 1.09084008493e-09
it: 6850, train nll: 2.9109313488, mse: 3.18432283401, local kl: 2.02010035515 global kl: 9.88539250457e-10 valid nll: 3.71339678764, mse: 3.46361041069, local kl: 9.63112926483 global kl: 5.47368927961e-10
it: 6900, train nll: 1.7099083662, mse: 1.27985405922, local kl: 2.01248550415 global kl: 6.56846577129e-10 valid nll: 3.50092720985, mse: 3.15035533905, local kl: 9.52360725403 global kl: 7.05914271482e-10
it: 6950, train nll: 3.11444354057, mse: 3.01907229424, local kl: 4.91388320923 global kl: 6.65370203379e-10 valid nll: 3.32830119133, mse: 2.78073954582, local kl: 11.4146881104 global kl: 6.41842579086e-10
it: 7000, train nll: 2.86613512039, mse: 2.66341876984, local kl: 4.35570001602 global kl: 6.81051104401e-10 valid nll: 3.60964250565, mse: 3.37243008614, local kl: 8.72378063202 global kl: 5.20324616726e-10
it: 7050, train nll: 1.7147629261, mse: 1.35680270195, local kl: 1.91201245785 global kl: 3.8931635693e-10 valid nll: 4.16426086426, mse: 4.31368494034, local kl: 9.76200771332 global kl: 4.32632013547e-10
it: 7100, train nll: 2.16448402405, mse: 2.0668027401, local kl: 1.65962016582 global kl: 4.45441378227e-10 valid nll: 4.03715801239, mse: 3.95956349373, local kl: 10.968328476 global kl: 4.21087886782e-10
it: 7150, train nll: 2.23464345932, mse: 1.81334233284, local kl: 3.95186400414 global kl: 4.61141930197e-10 valid nll: 4.55120515823, mse: 4.8364481926, local kl: 9.72647476196 global kl: 4.22883589257e-10
it: 7200, train nll: 2.69217324257, mse: 2.80716848373, local kl: 3.65474224091 global kl: 2.91539847819e-10 valid nll: 3.98034882545, mse: 3.92009735107, local kl: 10.1876106262 global kl: 5.72478453531e-10
it: 7250, train nll: 1.92393779755, mse: 1.29791140556, local kl: 8.14931297302 global kl: 6.04577388152e-10 valid nll: 6.37118339539, mse: 7.68345165253, local kl: 8.5719833374 global kl: 5.56899748538e-10
it: 7300, train nll: 4.00156450272, mse: 4.37254571915, local kl: 6.25479030609 global kl: 5.38819211471e-10 valid nll: 3.49928689003, mse: 3.19143033028, local kl: 10.8591661453 global kl: 7.74793451619e-10
it: 7350, train nll: 2.17011380196, mse: 2.04306578636, local kl: 2.01232242584 global kl: 8.20498169407e-10 valid nll: 3.6085395813, mse: 3.33162927628, local kl: 11.1947336197 global kl: 4.87157369999e-10
it: 7400, train nll: 3.06307029724, mse: 2.80940699577, local kl: 7.91753911972 global kl: 6.82501721805e-10 valid nll: 3.66544318199, mse: 3.49821186066, local kl: 9.53761863708 global kl: 8.50073456071e-10
it: 7450, train nll: 1.86030912399, mse: 1.00541591644, local kl: 5.59879255295 global kl: 5.76161451882e-10 valid nll: 4.27589082718, mse: 4.08790588379, local kl: 12.9692001343 global kl: 5.03665220641e-10
it: 7500, train nll: 1.68099558353, mse: 0.902087211609, local kl: 4.63493585587 global kl: 3.40847516878e-10 valid nll: 3.96178412437, mse: 3.96817564964, local kl: 8.64488983154 global kl: 3.53264889563e-10
it: 7550, train nll: 2.49657464027, mse: 2.00993919373, local kl: 7.28549289703 global kl: 6.07513095385e-10 valid nll: 3.114828825, mse: 2.74341249466, local kl: 7.02352380753 global kl: 2.71971417609e-10
it: 7600, train nll: 2.21653699875, mse: 1.74573731422, local kl: 4.99923229218 global kl: 4.54108528558e-10 valid nll: 4.21462678909, mse: 4.19893074036, local kl: 9.3118724823 global kl: 5.69950309171e-10
it: 7650, train nll: 1.26014924049, mse: 0.419284313917, local kl: 4.68749523163 global kl: 7.75581987522e-10 valid nll: 3.4387409687, mse: 2.92281413078, local kl: 12.5109233856 global kl: 8.71984817685e-10
it: 7700, train nll: 2.6668150425, mse: 2.34547781944, local kl: 5.34000778198 global kl: 8.42933001177e-10 valid nll: 3.24580144882, mse: 2.73441076279, local kl: 10.5591936111 global kl: 1.23176591149e-09
it: 7750, train nll: 2.45311570168, mse: 2.56059169769, local kl: 2.41297745705 global kl: 2.29370122895e-09 valid nll: 3.28629183769, mse: 2.76605439186, local kl: 10.6810626984 global kl: 2.0676302892e-09
it: 7800, train nll: 2.15078425407, mse: 1.79948997498, local kl: 6.1606388092 global kl: 1.30212685079e-09 valid nll: 3.91551017761, mse: 3.5036842823, local kl: 15.4577951431 global kl: 2.71787325978e-09
it: 7850, train nll: 2.24000525475, mse: 1.86896848679, local kl: 2.75127530098 global kl: 5.11664544067e-10 valid nll: 3.7874109745, mse: 3.8189432621, local kl: 7.67694997787 global kl: 2.20620521851e-09
it: 7900, train nll: 1.30840229988, mse: 0.626443326473, local kl: 1.7983263731 global kl: 3.22837728772e-10 valid nll: 3.27418875694, mse: 2.86347413063, local kl: 10.3221073151 global kl: 8.67687366402e-10
it: 7950, train nll: 2.93783640862, mse: 2.6931579113, local kl: 3.97721290588 global kl: 1.06348874152e-09 valid nll: 3.50981283188, mse: 3.36922693253, local kl: 7.94185352325 global kl: 2.36549713151e-09
it: 8000, train nll: 1.92638850212, mse: 1.42388212681, local kl: 2.92993855476 global kl: 6.69638233752e-10 valid nll: 3.20309352875, mse: 2.69366312027, local kl: 10.4154586792 global kl: 5.83050940861e-10
it: 8050, train nll: 1.60881221294, mse: 0.926987707615, local kl: 4.67454242706 global kl: 9.66012603243e-10 valid nll: 3.47798657417, mse: 3.22163414955, local kl: 10.9405193329 global kl: 9.03650931861e-10
it: 8100, train nll: 3.2312669754, mse: 2.91602134705, local kl: 6.51536798477 global kl: 7.58475338092e-10 valid nll: 3.33488416672, mse: 2.74648976326, local kl: 12.4726514816 global kl: 2.05745198656e-09
it: 8150, train nll: 1.76873385906, mse: 1.17514061928, local kl: 2.72181987762 global kl: 2.32879870943e-09 valid nll: 3.75602388382, mse: 3.35943579674, local kl: 12.1314134598 global kl: 1.71271063998e-09
it: 8200, train nll: 2.38133811951, mse: 2.10014271736, local kl: 5.1867017746 global kl: 3.49639828201e-09 valid nll: 3.65261745453, mse: 3.43627882004, local kl: 9.97092437744 global kl: 1.77567893722e-09
it: 8250, train nll: 3.17837142944, mse: 3.64318561554, local kl: 2.44040560722 global kl: 9.41009381528e-10 valid nll: 3.44740796089, mse: 2.99872350693, local kl: 10.9735841751 global kl: 1.07514197545e-09
it: 8300, train nll: 1.93181455135, mse: 1.44064569473, local kl: 3.13438510895 global kl: 6.14467143834e-10 valid nll: 3.19580245018, mse: 2.89920735359, local kl: 11.4221925735 global kl: 1.07731579213e-09
it: 8350, train nll: 2.4936041832, mse: 2.0073390007, local kl: 5.66533946991 global kl: 6.66242117031e-10 valid nll: 3.38683128357, mse: 3.19096493721, local kl: 9.02793598175 global kl: 3.69932695587e-10
it: 8400, train nll: 4.69994020462, mse: 5.53297281265, local kl: 4.41665410995 global kl: 5.17171250269e-10 valid nll: 3.29082345963, mse: 3.05549526215, local kl: 8.11459827423 global kl: 3.95354249338e-10
it: 8450, train nll: 3.08195066452, mse: 3.09625482559, local kl: 4.97462415695 global kl: 4.3128442484e-10 valid nll: 3.59739589691, mse: 3.17198300362, local kl: 10.932097435 global kl: 3.1951652435e-10
it: 8500, train nll: 2.23256492615, mse: 2.13962864876, local kl: 2.76070737839 global kl: 1.57645257959e-10 valid nll: 3.01295089722, mse: 2.74962830544, local kl: 7.74421024323 global kl: 2.20968063291e-10
it: 8550, train nll: 3.59132623672, mse: 3.96561717987, local kl: 3.72806191444 global kl: 3.78999859274e-10 valid nll: 3.23285746574, mse: 2.8614885807, local kl: 9.04355430603 global kl: 4.66694127788e-10
it: 8600, train nll: 1.53695297241, mse: 0.906897902489, local kl: 2.21706414223 global kl: 3.95416061005e-10 valid nll: 3.57534480095, mse: 3.21763253212, local kl: 12.0911874771 global kl: 4.64525445887e-10
it: 8650, train nll: 2.31271386147, mse: 1.86365818977, local kl: 4.78085279465 global kl: 2.5039131879e-10 valid nll: 3.37086987495, mse: 3.04189252853, local kl: 10.9612874985 global kl: 2.51803827789e-10
it: 8700, train nll: 2.63099861145, mse: 2.35971426964, local kl: 3.84178566933 global kl: 4.15595641234e-10 valid nll: 3.25890207291, mse: 2.86081361771, local kl: 12.039513588 global kl: 4.831152145e-10
it: 8750, train nll: 2.01875925064, mse: 1.51553535461, local kl: 4.49067354202 global kl: 5.06404418399e-10 valid nll: 3.01314067841, mse: 2.74367594719, local kl: 11.804028511 global kl: 4.71780059463e-10
it: 8800, train nll: 3.50148963928, mse: 3.75416207314, local kl: 6.60580396652 global kl: 5.97623006637e-10 valid nll: 3.44474172592, mse: 3.16294741631, local kl: 10.3213701248 global kl: 2.44437081687e-10
it: 8850, train nll: 1.94041836262, mse: 1.43501615524, local kl: 3.35340642929 global kl: 7.89636522835e-10 valid nll: 3.23341393471, mse: 3.02596068382, local kl: 9.18964576721 global kl: 5.49949030759e-10
it: 8900, train nll: 2.83556413651, mse: 2.52674245834, local kl: 5.83864545822 global kl: 4.80505024658e-10 valid nll: 3.25868701935, mse: 2.95699882507, local kl: 7.86099100113 global kl: 3.50668105664e-10
it: 8950, train nll: 4.57670783997, mse: 4.87769079208, local kl: 5.97867012024 global kl: 1.17638221386e-09 valid nll: 3.56153297424, mse: 3.26567173004, local kl: 10.8853282928 global kl: 1.88591897654e-09
it: 9000, train nll: 1.69375145435, mse: 1.00462305546, local kl: 3.03306603432 global kl: 1.11207354436e-09 valid nll: 3.10650205612, mse: 2.82980084419, local kl: 10.421043396 global kl: 1.0669354289e-09
it: 9050, train nll: 3.14947795868, mse: 2.7417178154, local kl: 7.78482532501 global kl: 3.17525145066e-10 valid nll: 3.84056806564, mse: 3.65118837357, local kl: 9.745262146 global kl: 6.83807954704e-10
it: 9100, train nll: 2.2898747921, mse: 1.97034311295, local kl: 2.50639677048 global kl: 3.42398526199e-10 valid nll: 3.48942756653, mse: 3.25228500366, local kl: 9.23949623108 global kl: 1.99343180851e-10
it: 9150, train nll: 2.56123971939, mse: 1.92089962959, local kl: 12.9385480881 global kl: 4.14951684125e-10 valid nll: 4.53203582764, mse: 4.37469577789, local kl: 15.1609487534 global kl: 6.2807725687e-10
it: 9200, train nll: 4.43377113342, mse: 5.24020195007, local kl: 3.86767268181 global kl: 1.18932308446e-09 valid nll: 3.0411863327, mse: 2.54791712761, local kl: 9.67111492157 global kl: 1.7148508169e-09
it: 9250, train nll: 1.68690180779, mse: 1.05513799191, local kl: 3.06415772438 global kl: 1.06631947716e-09 valid nll: 3.39461731911, mse: 3.06781339645, local kl: 9.52678585052 global kl: 9.81118852827e-10
it: 9300, train nll: 4.08004236221, mse: 4.48691749573, local kl: 5.09087276459 global kl: 1.75644765399e-09 valid nll: 3.26706027985, mse: 2.81430840492, local kl: 10.3545608521 global kl: 3.93643423413e-10
it: 9350, train nll: 3.29975581169, mse: 3.59968137741, local kl: 2.4876203537 global kl: 2.09821895969e-10 valid nll: 3.56029391289, mse: 3.27871489525, local kl: 11.216591835 global kl: 4.67188565612e-10
it: 9400, train nll: 2.68590521812, mse: 2.30397748947, local kl: 6.74857950211 global kl: 8.27199642117e-10 valid nll: 3.68748998642, mse: 3.5043926239, local kl: 11.4553442001 global kl: 7.112284095e-10
it: 9450, train nll: 1.55681180954, mse: 0.934122800827, local kl: 3.06697893143 global kl: 1.89660820382e-09 valid nll: 3.17162251472, mse: 2.77382135391, local kl: 12.1841964722 global kl: 2.59755617016e-09
it: 9500, train nll: 3.02712106705, mse: 2.89475655556, local kl: 6.68263721466 global kl: 4.8412107656e-10 valid nll: 3.1791074276, mse: 2.83290672302, local kl: 9.07976818085 global kl: 5.43251554852e-10
it: 9550, train nll: 1.98640024662, mse: 1.48763298988, local kl: 3.1218650341 global kl: 2.82469075907e-10 valid nll: 3.39599728584, mse: 3.01965498924, local kl: 9.10259342194 global kl: 4.54391330118e-10
it: 9600, train nll: 2.33390665054, mse: 2.01977968216, local kl: 4.34323501587 global kl: 8.46715406122e-11 valid nll: 4.24430418015, mse: 4.23021793365, local kl: 9.05621242523 global kl: 1.96367852534e-10
it: 9650, train nll: 2.8458340168, mse: 2.71290802956, local kl: 4.48560380936 global kl: 1.34284461328e-09 valid nll: 3.39825439453, mse: 3.006118536, local kl: 10.5922927856 global kl: 1.20444032525e-09
it: 9700, train nll: 2.31331920624, mse: 1.6164958477, local kl: 9.09985542297 global kl: 4.98435959173e-10 valid nll: 3.68887352943, mse: 3.33310270309, local kl: 11.8291854858 global kl: 5.9717375489e-10
it: 9750, train nll: 2.67627763748, mse: 2.36939024925, local kl: 5.13096857071 global kl: 4.39755704074e-10 valid nll: 3.2809817791, mse: 2.75127840042, local kl: 10.8973016739 global kl: 5.2708892806e-10
it: 9800, train nll: 1.75025296211, mse: 1.37081134319, local kl: 3.51549839973 global kl: 3.99357658054e-09 valid nll: 3.53291463852, mse: 3.01734256744, local kl: 12.8550701141 global kl: 4.0333274498e-09
it: 9850, train nll: 2.50188207626, mse: 2.2229282856, local kl: 4.31409263611 global kl: 4.08008682395e-10 valid nll: 3.1952559948, mse: 2.75312590599, local kl: 9.40912055969 global kl: 3.70123542925e-10
it: 9900, train nll: 3.98516988754, mse: 3.92397904396, local kl: 5.57134914398 global kl: 2.36119179764e-10 valid nll: 3.08349847794, mse: 2.5702047348, local kl: 8.65666294098 global kl: 2.34656849507e-10
it: 9950, train nll: 1.89738988876, mse: 1.42257785797, local kl: 3.59115695953 global kl: 9.47623868264e-10 valid nll: 3.21281242371, mse: 2.63018918037, local kl: 11.5001916885 global kl: 1.07683684192e-09

Prior predictive + freeform


In [0]:
# ---------------------------------------------------------------------------
# Configuration for the "prior predictive + freeform" run.
# The cell builds three HParams bundles (data, model, training) that are
# handed to `train(...)` in the next cell.
# ---------------------------------------------------------------------------

# --- Sizes shared by the data and the training hparams ---------------------
num_context = 512
num_target = 50

# --- Layer widths ------------------------------------------------------------
X_HIDDEN_SIZE = 100
HIDDEN_SIZE = 64
latent_units = 32

x_encoder_sizes = 3 * [X_HIDDEN_SIZE]
x_y_encoder_sizes = 3 * [HIDDEN_SIZE]
freeform_decoder_sizes = [HIDDEN_SIZE, HIDDEN_SIZE, HIDDEN_SIZE, 2]
# Final layer is twice `latent_units` wide
# (presumably a mean and a scale per latent unit -- TODO confirm).
global_decoder_sizes = [HIDDEN_SIZE, HIDDEN_SIZE, 2 * latent_units]

# Components that are switched off for this run.
global2local_decoder_sizes = None
heteroskedastic_sizes = None
# NOTE(review): `uncertainty_type` is None although this section is titled
# "freeform" and the checkpoint is named 'best_prior_freeform_...'; confirm
# that None selects the intended model variant.
uncertainty_type = None
data_uncertainty = False

# --- Attention ---------------------------------------------------------------
# The same attention function is used for the mean and both scale terms.
mean_att_type = scale_att_type_1 = scale_att_type_2 = attention.laplace_attention
att_type = 'multihead'
att_heads = 8

# --- Data hparams ------------------------------------------------------------
data_hparams = tf.contrib.training.HParams(
    context_dim=2,
    num_actions=5,
    num_target=num_target,
    num_context=num_context)

# --- Model hparams -----------------------------------------------------------
# NOTE(review): the key names used here (`freeform_decoder_sizes`,
# `global_decoder_sizes`, `global2local_decoder_sizes`,
# `heteroskedastic_sizes`) differ from the ones in the earlier
# "Prior Predictive + Freeform" cell near the top of this notebook
# (`local_latent_net_sizes`, `global_latent_net_sizes`,
# `heteroskedastic_net_sizes`); verify which set the model code expects.
model_hparams = tf.contrib.training.HParams(
    activation=tf.nn.relu,
    output_activation=tf.nn.relu,
    x_encoder_sizes=x_encoder_sizes,
    x_y_encoder_sizes=x_y_encoder_sizes,
    freeform_decoder_sizes=freeform_decoder_sizes,
    global_decoder_sizes=global_decoder_sizes,
    global2local_decoder_sizes=global2local_decoder_sizes,
    heteroskedastic_sizes=heteroskedastic_sizes,
    uncertainty_type=uncertainty_type,
    att_type=att_type,
    att_heads=att_heads,
    mean_att_type=mean_att_type,
    scale_att_type_1=scale_att_type_1,
    scale_att_type_2=scale_att_type_2,
    meta_learn=False,
    data_uncertainty=data_uncertainty)

# --- Training hparams --------------------------------------------------------
pred_type = 'prior_predictive'
# `savedir` is defined in the setup cell at the top of the notebook.
save_path = os.path.join(savedir, 'best_prior_freeform_mse_unclipped.ckpt')
training_hparams = tf.contrib.training.HParams(
    lr=0.01,
    optimizer=tf.train.RMSPropOptimizer,
    num_iterations=10000,
    batch_size=10,
    num_context=num_context,
    num_target=num_target,
    print_every=50,
    save_path=save_path,
    pred_type=pred_type,
    # Very large norm; given the 'unclipped' checkpoint name this is
    # presumably meant to make gradient clipping a no-op -- TODO confirm.
    # NOTE(review): the recorded output below shows occasional huge train
    # NLL values (e.g. at it 4600), so a tighter bound may be worth trying.
    max_grad_norm=1000.0)

In [0]:
# Launch training with the data / model / training hparams built in the
# previous cell; progress lines and "Saving best model" messages are printed
# to the cell output.
train(data_hparams, model_hparams, training_hparams)


it: 0, train nll: 233.244369507, mse: 330.893981934, local kl: 0.0 global kl: 7.03447440173e-05 valid nll: 299.086914062, mse: 424.603118896, local kl: 7.28687268747e-07 global kl: 9.62082631304e-05
Saving best model with MSE 424.60312
it: 50, train nll: 37.6227493286, mse: 136.156738281, local kl: 0.0 global kl: 0.000122057666886 valid nll: 119.781051636, mse: 853.618164062, local kl: 0.141291186213 global kl: 0.000135921232868
it: 100, train nll: 35.6790542603, mse: 121.717926025, local kl: 0.0 global kl: 0.000265129288891 valid nll: 26.3708591461, mse: 106.190765381, local kl: 0.0135101545602 global kl: 3.95455208491e-05
Saving best model with MSE 106.190765
it: 150, train nll: 18.2585353851, mse: 67.5693893433, local kl: 0.0 global kl: 6.13677766523e-05 valid nll: 14.4377069473, mse: 92.0325088501, local kl: 0.0030499540735 global kl: 2.05152064154e-05
Saving best model with MSE 92.03251
it: 200, train nll: 12.2136974335, mse: 107.714637756, local kl: 0.0 global kl: 1.22998130792e-06 valid nll: 73.5984954834, mse: 191.983596802, local kl: 0.00251626269892 global kl: 1.44754676512e-06
it: 250, train nll: 15.0469512939, mse: 90.2373657227, local kl: 0.0 global kl: 4.51964922377e-07 valid nll: 34.2468452454, mse: 141.247741699, local kl: 0.00095150800189 global kl: 1.32043737722e-06
it: 300, train nll: 13.8157997131, mse: 133.684585571, local kl: 0.0 global kl: 2.78156221611e-07 valid nll: 54.7261161804, mse: 156.676239014, local kl: 0.0116384141147 global kl: 7.69219582253e-07
it: 350, train nll: 5.65085029602, mse: 37.2897872925, local kl: 0.0 global kl: 1.42243408163e-07 valid nll: 22.2030105591, mse: 76.5795974731, local kl: 0.0114105734974 global kl: 1.70319438553e-07
Saving best model with MSE 76.5796
it: 400, train nll: 47.2676391602, mse: 108.256637573, local kl: 0.0 global kl: 5.31903765477e-08 valid nll: 23.890504837, mse: 92.3499908447, local kl: 0.00311789382249 global kl: 3.61468482879e-06
it: 450, train nll: 7.60969114304, mse: 54.4280052185, local kl: 0.0 global kl: 7.97343062686e-08 valid nll: 17.6112346649, mse: 63.3818092346, local kl: 0.0375552587211 global kl: 1.4156117345e-07
Saving best model with MSE 63.38181
it: 500, train nll: 16.7301864624, mse: 43.0992469788, local kl: 0.0 global kl: 8.8171718815e-09 valid nll: 86.5731887817, mse: 48.0000648499, local kl: 0.433847844601 global kl: 1.74158056865e-08
Saving best model with MSE 48.000065
it: 550, train nll: 14.477180481, mse: 58.5050964355, local kl: 0.0 global kl: 8.9769347511e-09 valid nll: 31.4165534973, mse: 43.1019287109, local kl: 0.0635034814477 global kl: 6.50302300897e-09
Saving best model with MSE 43.10193
it: 600, train nll: 4.97070837021, mse: 34.4896888733, local kl: 0.0 global kl: 1.77623832087e-07 valid nll: 10.7650489807, mse: 37.2305641174, local kl: 0.140524327755 global kl: 3.39440830999e-08
Saving best model with MSE 37.230564
it: 650, train nll: 6.05850839615, mse: 49.3097305298, local kl: 0.0 global kl: 1.7858971546e-07 valid nll: 20.3483524323, mse: 68.3444976807, local kl: 0.0310782399029 global kl: 1.16322258492e-08
it: 700, train nll: 6.21667003632, mse: 24.9758834839, local kl: 0.0 global kl: 2.93956929909e-07 valid nll: 12.6280088425, mse: 37.2315216064, local kl: 0.0213788244873 global kl: 0.0
it: 750, train nll: 5.12508249283, mse: 24.9651031494, local kl: 0.0 global kl: 2.12253081777e-09 valid nll: 71.3837051392, mse: 29.0237007141, local kl: 0.062432192266 global kl: 2.1343951051e-09
Saving best model with MSE 29.0237
it: 800, train nll: 6.33889627457, mse: 34.5256195068, local kl: 0.0 global kl: 0.0 valid nll: 18.2455215454, mse: 40.4446411133, local kl: 0.0808407366276 global kl: 4.72290651032e-10
it: 850, train nll: 42.2264022827, mse: 79.4913787842, local kl: 0.0 global kl: 4.51870150187e-08 valid nll: 21.8098049164, mse: 65.4720153809, local kl: 0.0715529695153 global kl: 4.09951361746e-09
it: 900, train nll: 65862.71875, mse: 31.3406715393, local kl: 0.0 global kl: 1.19751007333e-07 valid nll: 1120.91357422, mse: 20.4919166565, local kl: 5.77571630478 global kl: 3.69552957125e-08
Saving best model with MSE 20.491917
it: 950, train nll: 8.33448219299, mse: 21.6629009247, local kl: 0.0 global kl: 6.05603645454e-07 valid nll: 3.59411096573, mse: 15.8790016174, local kl: 0.10721115768 global kl: 2.30949672186e-06
Saving best model with MSE 15.879002
it: 1000, train nll: 8.67390632629, mse: 36.3382072449, local kl: 0.0 global kl: 1.28364916918e-07 valid nll: 19.0403366089, mse: 40.6859588623, local kl: 0.0848646759987 global kl: 2.82965658016e-07
it: 1050, train nll: 13.6320466995, mse: 49.5122413635, local kl: 0.0 global kl: 5.02664825319e-08 valid nll: 7.22431564331, mse: 30.4677391052, local kl: 0.0265022739768 global kl: 7.097031407e-09
it: 1100, train nll: 1.82934117317, mse: 13.2768888474, local kl: 0.0 global kl: 0.0 valid nll: 5.7890586853, mse: 14.9057855606, local kl: 0.408183276653 global kl: 3.48051198973e-10
Saving best model with MSE 14.905786
it: 1150, train nll: 139.888198853, mse: 23.2411136627, local kl: 0.0 global kl: 1.8252226397e-08 valid nll: 125.353233337, mse: 26.5057582855, local kl: 0.0285704117268 global kl: 1.907752889e-08
it: 1200, train nll: 2.30484962463, mse: 18.480632782, local kl: 0.0 global kl: 0.0 valid nll: 7.67459869385, mse: 37.3096199036, local kl: 0.00622555520386 global kl: 0.0
it: 1250, train nll: 5.04485750198, mse: 22.0676231384, local kl: 0.0 global kl: 0.0 valid nll: 8.83702087402, mse: 29.6349334717, local kl: 0.0492613054812 global kl: 0.0
it: 1300, train nll: 34.6150817871, mse: 35.4246253967, local kl: 0.0 global kl: 0.0 valid nll: 31.6258678436, mse: 35.0287017822, local kl: 0.0551678910851 global kl: 0.0
it: 1350, train nll: 5.63910531998, mse: 18.3532619476, local kl: 0.0 global kl: 0.0 valid nll: 10.0290184021, mse: 35.2382736206, local kl: 0.178735792637 global kl: 0.0
it: 1400, train nll: 2102.71826172, mse: 51.9947280884, local kl: 0.0 global kl: 4.54630298918e-06 valid nll: 215.222915649, mse: 201.954071045, local kl: 0.0869696810842 global kl: 0.0
it: 1450, train nll: 2.88330554962, mse: 21.9685840607, local kl: 0.0 global kl: 0.0 valid nll: 8.56879615784, mse: 26.1177330017, local kl: 0.129461139441 global kl: 0.0
it: 1500, train nll: 14.2461690903, mse: 13.8321342468, local kl: 0.0 global kl: 0.0 valid nll: 5.82502698898, mse: 22.742225647, local kl: 0.0260955542326 global kl: 0.0
it: 1550, train nll: 2.21354985237, mse: 16.0892410278, local kl: 0.0 global kl: 0.0 valid nll: 283.46307373, mse: 24.7471942902, local kl: 0.031999617815 global kl: 0.0
it: 1600, train nll: 8106.48095703, mse: 26.4370613098, local kl: 0.0 global kl: 0.0 valid nll: 8.59003067017, mse: 21.8368473053, local kl: 0.0250834524632 global kl: 0.0
it: 1650, train nll: 59.5075874329, mse: 25.5304508209, local kl: 0.0 global kl: 0.0 valid nll: 11.52277565, mse: 34.344783783, local kl: 0.0906481519341 global kl: 0.0
it: 1700, train nll: 9.31398868561, mse: 28.8361988068, local kl: 0.0 global kl: 3.83456599806e-09 valid nll: 11.5344133377, mse: 81.0085678101, local kl: 0.0441405288875 global kl: 3.07385164433e-07
it: 1750, train nll: 7.47542047501, mse: 29.7755565643, local kl: 0.0 global kl: 0.0 valid nll: 72.5062942505, mse: 44.8535842896, local kl: 0.502634167671 global kl: 0.0
it: 1800, train nll: 6.17585372925, mse: 36.287311554, local kl: 0.0 global kl: 1.01079287163e-07 valid nll: 2.09490394592, mse: 23.1277103424, local kl: 0.0241507180035 global kl: 1.15984374882e-10
it: 1850, train nll: 27.0159187317, mse: 29.6779003143, local kl: 0.0 global kl: 0.0 valid nll: 12.5993089676, mse: 28.2772960663, local kl: 0.0111934691668 global kl: 1.37193922001e-06
it: 1900, train nll: 6.47953796387, mse: 12.7410697937, local kl: 0.0 global kl: 0.0 valid nll: 5.23857545853, mse: 22.3287906647, local kl: 0.0540931001306 global kl: 0.0
it: 1950, train nll: 6.01980113983, mse: 24.2641506195, local kl: 0.0 global kl: 0.0 valid nll: 2.26339268684, mse: 26.0249080658, local kl: 0.259679853916 global kl: 0.0
it: 2000, train nll: 154.521331787, mse: 28.2832374573, local kl: 0.0 global kl: 0.0 valid nll: 33.4602546692, mse: 55.1908454895, local kl: 0.0467975065112 global kl: 0.0
it: 2050, train nll: 1.62547147274, mse: 7.95186424255, local kl: 0.0 global kl: 0.0 valid nll: 7.65046262741, mse: 30.8763084412, local kl: 0.0359988920391 global kl: 0.0
it: 2100, train nll: 2.36467862129, mse: 34.3113594055, local kl: 0.0 global kl: 2.92662960533e-09 valid nll: 7.60727071762, mse: 40.5357933044, local kl: 0.151594340801 global kl: 0.0
it: 2150, train nll: 4.75442886353, mse: 15.9797801971, local kl: 0.0 global kl: 0.0 valid nll: 11.1762266159, mse: 29.3393135071, local kl: 0.0361705645919 global kl: 0.0
it: 2200, train nll: 10434.7597656, mse: 28.1880664825, local kl: 0.0 global kl: 0.0 valid nll: 378.204467773, mse: 36.708732605, local kl: 0.0673946663737 global kl: 0.0
it: 2250, train nll: 1.34213602543, mse: 15.1614208221, local kl: 0.0 global kl: 0.0 valid nll: 1.93026733398, mse: 17.7843856812, local kl: 0.056662324816 global kl: 0.0
it: 2300, train nll: 454.350036621, mse: 27.0793361664, local kl: 0.0 global kl: 0.0 valid nll: 28419.8300781, mse: 20.9405708313, local kl: 0.0364913344383 global kl: 0.0
it: 2350, train nll: 2.65757989883, mse: 23.8089313507, local kl: 0.0 global kl: 0.0 valid nll: 2.7785551548, mse: 21.7100582123, local kl: 0.105783767998 global kl: 0.0
it: 2400, train nll: 1.19875276089, mse: 16.6403083801, local kl: 0.0 global kl: 0.0 valid nll: 6.45111942291, mse: 31.0643501282, local kl: 0.0425965040922 global kl: 0.0
it: 2450, train nll: 1.02945649624, mse: 16.8554706573, local kl: 0.0 global kl: 0.0 valid nll: 23.032245636, mse: 40.3064308167, local kl: 0.0240885205567 global kl: 0.0
it: 2500, train nll: 1.89354610443, mse: 20.0032405853, local kl: 0.0 global kl: 1.41903377937e-09 valid nll: 5.4370174408, mse: 25.2991790771, local kl: 0.0584361441433 global kl: 9.10161168655e-10
it: 2550, train nll: 3.34073996544, mse: 17.6382579803, local kl: 0.0 global kl: 8.9074852383e-08 valid nll: 11.8425931931, mse: 32.056854248, local kl: 0.0047827004455 global kl: 0.0
it: 2600, train nll: 1.5958044529, mse: 23.3365287781, local kl: 0.0 global kl: 0.0 valid nll: 143.515106201, mse: 55.9755554199, local kl: 0.0296175032854 global kl: 0.0
it: 2650, train nll: 3.66012883186, mse: 24.9185752869, local kl: 0.0 global kl: 0.0 valid nll: 6.75148296356, mse: 29.7686500549, local kl: 0.0253840927035 global kl: 0.0
it: 2700, train nll: 65.9508132935, mse: 27.8647937775, local kl: 0.0 global kl: 2.18623714687e-08 valid nll: 12.2623081207, mse: 39.3610610962, local kl: 0.0184395834804 global kl: 2.82183833633e-07
it: 2750, train nll: 2.2846801281, mse: 25.8481559753, local kl: 0.0 global kl: 2.52035370352e-08 valid nll: 8.2478427887, mse: 25.2132987976, local kl: 0.0173885133117 global kl: 9.01063259562e-08
it: 2800, train nll: 9.50282096863, mse: 16.4344825745, local kl: 0.0 global kl: 5.07467170507e-08 valid nll: 50.2045326233, mse: 54.9338722229, local kl: 0.0269495733082 global kl: 2.05401207154e-08
it: 2850, train nll: 447.619873047, mse: 47.7191200256, local kl: 0.0 global kl: 7.73448505242e-09 valid nll: 18.9033794403, mse: 33.5833892822, local kl: 22.9659976959 global kl: 6.40827280129e-10
it: 2900, train nll: 1.57312238216, mse: 12.4582700729, local kl: 0.0 global kl: 0.0 valid nll: 15.7066745758, mse: 35.3842353821, local kl: 0.137478470802 global kl: 4.46071357629e-09
it: 2950, train nll: 5.25469923019, mse: 26.604013443, local kl: 0.0 global kl: 1.17462950477e-09 valid nll: 2.57432889938, mse: 19.0374355316, local kl: 0.0637611448765 global kl: 3.28164162511e-10
it: 3000, train nll: 5.11238622665, mse: 18.933391571, local kl: 0.0 global kl: 0.0 valid nll: 28.4895191193, mse: 41.7391166687, local kl: 0.0134165110067 global kl: 0.0
it: 3050, train nll: 3.17509031296, mse: 23.9957523346, local kl: 0.0 global kl: 0.0 valid nll: 3.55092597008, mse: 22.3400230408, local kl: 0.0609841383994 global kl: 3.32352374788e-11
it: 3100, train nll: 4.52778148651, mse: 32.8353729248, local kl: 0.0 global kl: 3.87477168406e-11 valid nll: 4.7680439949, mse: 27.2863864899, local kl: 0.0277507584542 global kl: 0.0
it: 3150, train nll: 1.07570338249, mse: 15.0637454987, local kl: 0.0 global kl: 0.0 valid nll: 12.7037687302, mse: 30.9927406311, local kl: 0.0433233045042 global kl: 0.0
it: 3200, train nll: 1.0216524601, mse: 13.6344003677, local kl: 0.0 global kl: 8.73557013392e-08 valid nll: 13.7893733978, mse: 27.8015632629, local kl: 0.0147842587903 global kl: 0.0
it: 3250, train nll: 1.59636616707, mse: 22.7622241974, local kl: 0.0 global kl: 0.0 valid nll: 4.99854850769, mse: 29.1298103333, local kl: 0.0317456461489 global kl: 0.0
it: 3300, train nll: 9.23516178131, mse: 25.5169296265, local kl: 0.0 global kl: 0.0 valid nll: 818.128173828, mse: 30.6358318329, local kl: 0.148217320442 global kl: 0.0
it: 3350, train nll: 5.12079191208, mse: 22.1641998291, local kl: 0.0 global kl: 0.0 valid nll: 1.5925770998, mse: 12.7292966843, local kl: 0.0511931627989 global kl: 0.0
Saving best model with MSE 12.729297
it: 3400, train nll: 5.72056388855, mse: 27.058057785, local kl: 0.0 global kl: 2.51893368386e-07 valid nll: 17.7485809326, mse: 22.5653648376, local kl: 0.0381637513638 global kl: 1.51494464262e-06
it: 3450, train nll: 5.81204271317, mse: 15.8281927109, local kl: 0.0 global kl: 1.9033219445e-08 valid nll: 4.7759847641, mse: 12.6486101151, local kl: 0.0276202782989 global kl: 4.18645420552e-08
Saving best model with MSE 12.64861
it: 3500, train nll: 32.0869369507, mse: 26.6989498138, local kl: 0.0 global kl: 1.96495392402e-07 valid nll: 22.0628433228, mse: 32.0675849915, local kl: 0.035602979362 global kl: 6.92762000654e-07
it: 3550, train nll: 10.5251369476, mse: 25.9774055481, local kl: 0.0 global kl: 1.24541458035e-07 valid nll: 4.29034471512, mse: 21.2529201508, local kl: 0.0343338064849 global kl: 1.64732455232e-07
it: 3600, train nll: 2.03373265266, mse: 13.6510000229, local kl: 0.0 global kl: 2.05331884828e-09 valid nll: 3.1159954071, mse: 24.118396759, local kl: 0.111681528389 global kl: 1.77659664757e-08
it: 3650, train nll: 220.870498657, mse: 44.6408996582, local kl: 0.0 global kl: 3.65125139012e-08 valid nll: 133.523406982, mse: 30.4759731293, local kl: 19.3832206726 global kl: 0.0
it: 3700, train nll: 0.954360544682, mse: 19.3967342377, local kl: 0.0 global kl: 0.0 valid nll: 8.82184505463, mse: 23.5649623871, local kl: 0.0148523729295 global kl: 0.0
it: 3750, train nll: 3.30299448967, mse: 22.9248123169, local kl: 0.0 global kl: 5.89055071387e-09 valid nll: 10.9020814896, mse: 33.1045951843, local kl: 0.0173363089561 global kl: 0.0
it: 3800, train nll: 147.286697388, mse: 28.329000473, local kl: 0.0 global kl: 0.0 valid nll: 3.10918998718, mse: 28.1637458801, local kl: 0.0460172444582 global kl: 0.0
it: 3850, train nll: 10.9477243423, mse: 26.6470832825, local kl: 0.0 global kl: 0.0 valid nll: 24.5908794403, mse: 45.2291069031, local kl: 0.0351930297911 global kl: 0.0
it: 3900, train nll: 26.0254840851, mse: 43.2875785828, local kl: 0.0 global kl: 0.0 valid nll: 18.7724170685, mse: 36.1152420044, local kl: 0.00379215553403 global kl: 0.0
it: 3950, train nll: 2.65315818787, mse: 18.6509971619, local kl: 0.0 global kl: 0.0 valid nll: 3.59364056587, mse: 13.6977443695, local kl: 0.104000195861 global kl: 0.0
it: 4000, train nll: 7.15398359299, mse: 15.3281106949, local kl: 0.0 global kl: 0.0 valid nll: 19.6527900696, mse: 33.9775543213, local kl: 0.140761420131 global kl: 0.0
it: 4050, train nll: 1.11931312084, mse: 12.6540546417, local kl: 0.0 global kl: 0.0 valid nll: 8.69804668427, mse: 28.3869304657, local kl: 0.00228888634592 global kl: 0.0
it: 4100, train nll: 25.4342136383, mse: 22.654340744, local kl: 0.0 global kl: 0.0 valid nll: 34.3195877075, mse: 23.9875659943, local kl: 2.2498819828 global kl: 0.0
it: 4150, train nll: 27.3880271912, mse: 23.8297424316, local kl: 0.0 global kl: 0.0 valid nll: 20.1134605408, mse: 27.0867786407, local kl: 0.484777271748 global kl: 0.0
it: 4200, train nll: 12.1199512482, mse: 18.7533836365, local kl: 0.0 global kl: 0.0 valid nll: 7.28889322281, mse: 16.2031974792, local kl: 0.0470508523285 global kl: 0.0
it: 4250, train nll: 80.995880127, mse: 127.906806946, local kl: 0.0 global kl: 0.0 valid nll: 1.59252238274, mse: 8.03712844849, local kl: 0.0560566820204 global kl: 0.0
Saving best model with MSE 8.037128
it: 4300, train nll: 63.1550064087, mse: 35.0959701538, local kl: 0.0 global kl: 0.0 valid nll: 5.36284303665, mse: 18.6983375549, local kl: 0.0389610975981 global kl: 0.0
it: 4350, train nll: 5.58499765396, mse: 24.4351787567, local kl: 0.0 global kl: 1.17196852223e-09 valid nll: 81.3706436157, mse: 19.1508369446, local kl: 0.0974433422089 global kl: 2.4612078868e-12
it: 4400, train nll: 48.9047355652, mse: 27.4930496216, local kl: 0.0 global kl: 0.0 valid nll: 1.89903926849, mse: 18.397901535, local kl: 0.153975114226 global kl: 0.0
it: 4450, train nll: 5.26292467117, mse: 23.1527252197, local kl: 0.0 global kl: 1.45026934906e-09 valid nll: 33.7193908691, mse: 49.003944397, local kl: 0.0259026940912 global kl: 1.18134593663e-07
it: 4500, train nll: 4.45231151581, mse: 20.5475788116, local kl: 0.0 global kl: 0.0 valid nll: 22.1097183228, mse: 39.9495735168, local kl: 0.0674315467477 global kl: 0.0
it: 4550, train nll: 3.90131688118, mse: 7.27128601074, local kl: 0.0 global kl: 0.0 valid nll: 23.1879768372, mse: 15.6217956543, local kl: 0.0386722236872 global kl: 0.0
it: 4600, train nll: 40137716.0, mse: 385.514038086, local kl: 0.0 global kl: 0.0261666662991 valid nll: 283.311889648, mse: 391.733947754, local kl: 0.0 global kl: 43328622592.0
it: 4650, train nll: 201.75378418, mse: 37.2736320496, local kl: 0.0 global kl: 1.96480502979e-08 valid nll: 15.2506036758, mse: 53.3676757812, local kl: 0.0342746861279 global kl: 1.94577978618e-07
it: 4700, train nll: 135.842376709, mse: 29.0305786133, local kl: 0.0 global kl: 6.53070975076e-09 valid nll: 17.7319698334, mse: 26.9241790771, local kl: 0.0129907680675 global kl: 1.02138471902e-07
it: 4750, train nll: 6.01775312424, mse: 26.1000385284, local kl: 0.0 global kl: 8.44301979441e-08 valid nll: 22.5010948181, mse: 27.1566677094, local kl: 0.0630404725671 global kl: 8.09669558066e-08
it: 4800, train nll: 488210.90625, mse: 81.6690673828, local kl: 0.0 global kl: 9.2933569249e-06 valid nll: 219624.078125, mse: 121.532188416, local kl: 2093.92211914 global kl: 2.35766401602e-06
it: 4850, train nll: 416701.21875, mse: 39.8082351685, local kl: 0.0 global kl: 7.54602602626e-09 valid nll: 4590.8671875, mse: 101.945747375, local kl: 0.840952575207 global kl: 1.28237642727e-08
it: 4900, train nll: 44969.1757812, mse: 20.8517684937, local kl: 0.0 global kl: 5.37857136607e-09 valid nll: 15178.7050781, mse: 54.7034339905, local kl: 0.597939133644 global kl: 2.4152704281e-09
it: 4950, train nll: 34.3280563354, mse: 23.1307296753, local kl: 0.0 global kl: 0.0 valid nll: 59.8054847717, mse: 21.4316959381, local kl: 0.212557762861 global kl: 3.31068679416e-11
it: 5000, train nll: 11.3179893494, mse: 22.7790489197, local kl: 0.0 global kl: 0.0 valid nll: 54.3446998596, mse: 34.1228485107, local kl: 2.66187119484 global kl: 0.0
it: 5050, train nll: 1523.23059082, mse: 39.4788475037, local kl: 0.0 global kl: 9.24267773428e-09 valid nll: 42.532459259, mse: 34.277961731, local kl: 0.103112287819 global kl: 1.5033826406e-09
it: 5100, train nll: 1130.02380371, mse: 51.2634811401, local kl: 0.0 global kl: 2.54510069908e-07 valid nll: 9.3821105957, mse: 33.4230995178, local kl: 0.0461427047849 global kl: 1.01458264012e-08
it: 5150, train nll: 42.0236129761, mse: 14.7067756653, local kl: 0.0 global kl: 8.97240397535e-08 valid nll: 22.1814804077, mse: 37.2963218689, local kl: 0.046852145344 global kl: 5.97723470719e-08
it: 5200, train nll: 391.647064209, mse: 14.0031585693, local kl: 0.0 global kl: 2.40809292507e-08 valid nll: 8.89374923706, mse: 19.4171218872, local kl: 0.222296342254 global kl: 7.73943753529e-09
it: 5250, train nll: 9.65467834473, mse: 19.1412887573, local kl: 0.0 global kl: 5.25718757416e-09 valid nll: 409.783782959, mse: 31.9007606506, local kl: 0.0520876161754 global kl: 9.21384035735e-09
it: 5300, train nll: 2.36807823181, mse: 8.20053100586, local kl: 0.0 global kl: 8.22814150148e-10 valid nll: 3.51892733574, mse: 36.0230751038, local kl: 0.03403397277 global kl: 2.9631801457e-09
it: 5350, train nll: 1187.99902344, mse: 11.9266729355, local kl: 0.0 global kl: 1.7623387194e-09 valid nll: 23.3653450012, mse: 32.4268455505, local kl: 0.0213634800166 global kl: 0.0
it: 5400, train nll: 4.01410627365, mse: 21.309841156, local kl: 0.0 global kl: 0.0 valid nll: 53.1333122253, mse: 20.6384449005, local kl: 0.329431563616 global kl: 1.54664198115e-11
it: 5450, train nll: 46.9495735168, mse: 7.02305173874, local kl: 0.0 global kl: 0.0 valid nll: 11.4606666565, mse: 22.4382801056, local kl: 0.00575887970626 global kl: 1.2738536892e-09
it: 5500, train nll: 28118.8925781, mse: 12.9472351074, local kl: 0.0 global kl: 2.8629151827e-10 valid nll: 14.8126907349, mse: 16.0752010345, local kl: 0.631988227367 global kl: 0.0
it: 5550, train nll: 5.66637325287, mse: 17.353515625, local kl: 0.0 global kl: 0.0 valid nll: 185.026504517, mse: 36.241394043, local kl: 0.0542495846748 global kl: 0.0
it: 5600, train nll: 56.7830581665, mse: 23.6226711273, local kl: 0.0 global kl: 0.0 valid nll: 14022.4199219, mse: 25.4456367493, local kl: 0.0444868654013 global kl: 0.0
it: 5650, train nll: 14.7370023727, mse: 8.24208259583, local kl: 0.0 global kl: 0.0 valid nll: 23.8322181702, mse: 21.3396053314, local kl: 0.384601712227 global kl: 0.0
it: 5700, train nll: 5.51596403122, mse: 27.7465553284, local kl: 0.0 global kl: 0.0 valid nll: 11.1960134506, mse: 15.1722898483, local kl: 0.43504050374 global kl: 0.0
it: 5750, train nll: 98.1274185181, mse: 23.5830402374, local kl: 0.0 global kl: 1.82863395537e-08 valid nll: 7.39990663528, mse: 20.6549835205, local kl: 3.12832689285 global kl: 2.35715447161e-10
it: 5800, train nll: 12.8858146667, mse: 14.792840004, local kl: 0.0 global kl: 0.0 valid nll: 14.3315486908, mse: 22.902513504, local kl: 0.0551751665771 global kl: 0.0
it: 5850, train nll: 2.84488010406, mse: 13.9020023346, local kl: 0.0 global kl: 0.0 valid nll: 7434.26123047, mse: 18.8012962341, local kl: 0.400448322296 global kl: 4.56012103078e-11
it: 5900, train nll: 16.0777130127, mse: 29.807220459, local kl: 0.0 global kl: 0.0 valid nll: 21.7782917023, mse: 37.6907043457, local kl: 0.0354166775942 global kl: 0.0
it: 5950, train nll: 41.3239784241, mse: 11.4136037827, local kl: 0.0 global kl: 0.0 valid nll: 96.4115905762, mse: 20.8053703308, local kl: 0.108574025333 global kl: 0.0
it: 6000, train nll: 16.5551929474, mse: 4.83228731155, local kl: 0.0 global kl: 0.0 valid nll: 5.23260354996, mse: 20.037519455, local kl: 0.325040131807 global kl: 1.92460287707e-08
it: 6050, train nll: 9434.42480469, mse: 38.6820068359, local kl: 0.0 global kl: 0.0 valid nll: 6.15696811676, mse: 34.5688209534, local kl: 0.0405591875315 global kl: 0.0
it: 6100, train nll: 2.25782489777, mse: 16.6463527679, local kl: 0.0 global kl: 0.0 valid nll: 13.7718925476, mse: 18.8703556061, local kl: 0.0775294005871 global kl: 0.0
it: 6150, train nll: 70.3609237671, mse: 15.1641225815, local kl: 0.0 global kl: 0.0 valid nll: 3.76423096657, mse: 16.2609558105, local kl: 0.0679179280996 global kl: 0.0
it: 6200, train nll: 8.86699962616, mse: 13.4168624878, local kl: 0.0 global kl: 0.0 valid nll: 37.3514060974, mse: 32.5388450623, local kl: 0.0846313610673 global kl: 0.0
it: 6250, train nll: 715058112.0, mse: 286.023284912, local kl: 0.0 global kl: 0.0 valid nll: 1064783680.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 6300, train nll: 714903104.0, mse: 285.961242676, local kl: 0.0 global kl: 0.0 valid nll: 1064783680.0, mse: 425.913482666, local kl: 0.0 global kl: 0.0
it: 6350, train nll: 1002413248.0, mse: 400.965301514, local kl: 0.0 global kl: 0.0 valid nll: 1064783680.0, mse: 425.913482666, local kl: 0.0 global kl: 0.0
it: 6400, train nll: 827398336.0, mse: 330.959350586, local kl: 0.0 global kl: 0.0 valid nll: 1064783744.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 6450, train nll: 864880256.0, mse: 345.952056885, local kl: 0.0 global kl: 0.0 valid nll: 1064783744.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 6500, train nll: 839909952.0, mse: 335.963989258, local kl: 0.0 global kl: 0.0 valid nll: 1064783680.0, mse: 425.913482666, local kl: 0.0 global kl: 0.0
it: 6550, train nll: 827455872.0, mse: 330.98236084, local kl: 0.0 global kl: 0.0 valid nll: 1064783680.0, mse: 425.913482666, local kl: 0.0 global kl: 0.0
it: 6600, train nll: 689928768.0, mse: 275.971496582, local kl: 0.0 global kl: 0.0 valid nll: 1064783680.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 6650, train nll: 864846528.0, mse: 345.93862915, local kl: 0.0 global kl: 0.0 valid nll: 1064783744.0, mse: 425.913482666, local kl: 0.0 global kl: 0.0
it: 6700, train nll: 952381504.0, mse: 380.952636719, local kl: 0.0 global kl: 0.0 valid nll: 1064783744.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 6750, train nll: 1002357504.0, mse: 400.943023682, local kl: 0.0 global kl: 0.0 valid nll: 1064783680.0, mse: 425.913452148, local kl: 0.0 global kl: 0.0
it: 6800, train nll: 1177336832.0, mse: 470.934783936, local kl: 0.0 global kl: 0.0 valid nll: 1064783808.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 6850, train nll: 1177172608.0, mse: 470.869049072, local kl: 0.0 global kl: 0.0 valid nll: 1064783808.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 6900, train nll: 1002372288.0, mse: 400.948883057, local kl: 0.0 global kl: 0.0 valid nll: 1064783680.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 6950, train nll: 1002352512.0, mse: 400.940979004, local kl: 0.0 global kl: 0.0 valid nll: 1064783680.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 7000, train nll: 1014799808.0, mse: 405.919952393, local kl: 0.0 global kl: 0.0 valid nll: 1064783808.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 7050, train nll: 902371456.0, mse: 360.948638916, local kl: 0.0 global kl: 0.0 valid nll: 1064783680.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 7100, train nll: 802402048.0, mse: 320.96081543, local kl: 0.0 global kl: 0.0 valid nll: 1064783680.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 7150, train nll: 989877056.0, mse: 395.950805664, local kl: 0.0 global kl: 0.0 valid nll: 1064783744.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 7200, train nll: 1027331712.0, mse: 410.932678223, local kl: 0.0 global kl: 0.0 valid nll: 1064783744.0, mse: 425.913482666, local kl: 0.0 global kl: 0.0
it: 7250, train nll: 789939200.0, mse: 315.975708008, local kl: 0.0 global kl: 0.0 valid nll: 1064783744.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 7300, train nll: 1039840448.0, mse: 415.936096191, local kl: 0.0 global kl: 0.0 valid nll: 1064783744.0, mse: 425.913452148, local kl: 0.0 global kl: 0.0
it: 7350, train nll: 789852928.0, mse: 315.941162109, local kl: 0.0 global kl: 0.0 valid nll: 1064783680.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 7400, train nll: 789923648.0, mse: 315.969512939, local kl: 0.0 global kl: 0.0 valid nll: 1064783744.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 7450, train nll: 902427904.0, mse: 360.971191406, local kl: 0.0 global kl: 0.0 valid nll: 1064783808.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 7500, train nll: 1014844032.0, mse: 405.937591553, local kl: 0.0 global kl: 0.0 valid nll: 1064783680.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 7550, train nll: 1139879808.0, mse: 455.951873779, local kl: 0.0 global kl: 0.0 valid nll: 1064783744.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 7600, train nll: 927456448.0, mse: 370.982543945, local kl: 0.0 global kl: 0.0 valid nll: 1064783808.0, mse: 425.913543701, local kl: 0.0 global kl: 0.0
it: 7650, train nll: 1264712192.0, mse: 505.884918213, local kl: 0.0 global kl: 0.0 valid nll: 1064783680.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 7700, train nll: 927412928.0, mse: 370.965179443, local kl: 0.0 global kl: 0.0 valid nll: 1064783680.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 7750, train nll: 1139783936.0, mse: 455.913543701, local kl: 0.0 global kl: 0.0 valid nll: 1064783744.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 7800, train nll: 902404416.0, mse: 360.961791992, local kl: 0.0 global kl: 0.0 valid nll: 1064783808.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 7850, train nll: 1027395904.0, mse: 410.958374023, local kl: 0.0 global kl: 0.0 valid nll: 1064783808.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 7900, train nll: 1014761664.0, mse: 405.904602051, local kl: 0.0 global kl: 0.0 valid nll: 1064783808.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 7950, train nll: 802429056.0, mse: 320.971618652, local kl: 0.0 global kl: 0.0 valid nll: 1064783680.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 8000, train nll: 1139760128.0, mse: 455.904022217, local kl: 0.0 global kl: 0.0 valid nll: 1064783680.0, mse: 425.913543701, local kl: 0.0 global kl: 0.0
it: 8050, train nll: 852405440.0, mse: 340.962188721, local kl: 0.0 global kl: 0.0 valid nll: 1064783680.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 8100, train nll: 964841920.0, mse: 385.936737061, local kl: 0.0 global kl: 0.0 valid nll: 1064783744.0, mse: 425.913543701, local kl: 0.0 global kl: 0.0
it: 8150, train nll: 1152203008.0, mse: 460.881195068, local kl: 0.0 global kl: 0.0 valid nll: 1064783680.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 8200, train nll: 1014813440.0, mse: 405.925384521, local kl: 0.0 global kl: 0.0 valid nll: 1064783808.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 8250, train nll: 927362496.0, mse: 370.945007324, local kl: 0.0 global kl: 0.0 valid nll: 1064783680.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 8300, train nll: 777418048.0, mse: 310.967193604, local kl: 0.0 global kl: 0.0 valid nll: 1064783680.0, mse: 425.913482666, local kl: 0.0 global kl: 0.0
it: 8350, train nll: 889806208.0, mse: 355.922485352, local kl: 0.0 global kl: 0.0 valid nll: 1064783744.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 8400, train nll: 977303360.0, mse: 390.921356201, local kl: 0.0 global kl: 0.0 valid nll: 1064783744.0, mse: 425.913482666, local kl: 0.0 global kl: 0.0
it: 8450, train nll: 839914624.0, mse: 335.965820312, local kl: 0.0 global kl: 0.0 valid nll: 1064783680.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 8500, train nll: 777395008.0, mse: 310.958007812, local kl: 0.0 global kl: 0.0 valid nll: 1064783744.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 8550, train nll: 1027311104.0, mse: 410.924468994, local kl: 0.0 global kl: 0.0 valid nll: 1064783808.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 8600, train nll: 1139769088.0, mse: 455.907684326, local kl: 0.0 global kl: 0.0 valid nll: 1064783744.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 8650, train nll: 939898880.0, mse: 375.959564209, local kl: 0.0 global kl: 0.0 valid nll: 1064783744.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 8700, train nll: 1052350656.0, mse: 420.940307617, local kl: 0.0 global kl: 0.0 valid nll: 1064783680.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 8750, train nll: 1039834304.0, mse: 415.933746338, local kl: 0.0 global kl: 0.0 valid nll: 1064783744.0, mse: 425.913482666, local kl: 0.0 global kl: 0.0
it: 8800, train nll: 952426048.0, mse: 380.970397949, local kl: 0.0 global kl: 0.0 valid nll: 1064783744.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 8850, train nll: 1027315776.0, mse: 410.926300049, local kl: 0.0 global kl: 0.0 valid nll: 1064783680.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 8900, train nll: 802416832.0, mse: 320.966705322, local kl: 0.0 global kl: 0.0 valid nll: 1064783744.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 8950, train nll: 927340608.0, mse: 370.936218262, local kl: 0.0 global kl: 0.0 valid nll: 1064783680.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 9000, train nll: 952332672.0, mse: 380.933105469, local kl: 0.0 global kl: 0.0 valid nll: 1064783808.0, mse: 425.913482666, local kl: 0.0 global kl: 0.0
it: 9050, train nll: 902358208.0, mse: 360.94329834, local kl: 0.0 global kl: 0.0 valid nll: 1064783808.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 9100, train nll: 702442496.0, mse: 280.976989746, local kl: 0.0 global kl: 0.0 valid nll: 1064783616.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 9150, train nll: 1164681600.0, mse: 465.872619629, local kl: 0.0 global kl: 0.0 valid nll: 1064783744.0, mse: 425.913543701, local kl: 0.0 global kl: 0.0
it: 9200, train nll: 1102337664.0, mse: 440.935028076, local kl: 0.0 global kl: 0.0 valid nll: 1064783616.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 9250, train nll: 1027342336.0, mse: 410.936981201, local kl: 0.0 global kl: 0.0 valid nll: 1064783744.0, mse: 425.913482666, local kl: 0.0 global kl: 0.0
it: 9300, train nll: 602453760.0, mse: 240.981506348, local kl: 0.0 global kl: 0.0 valid nll: 1064783808.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 9350, train nll: 739933760.0, mse: 295.973480225, local kl: 0.0 global kl: 0.0 valid nll: 1064783680.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 9400, train nll: 877413888.0, mse: 350.965515137, local kl: 0.0 global kl: 0.0 valid nll: 1064783680.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 9450, train nll: 989838912.0, mse: 395.935577393, local kl: 0.0 global kl: 0.0 valid nll: 1064783680.0, mse: 425.913482666, local kl: 0.0 global kl: 0.0
it: 9500, train nll: 964896832.0, mse: 385.958648682, local kl: 0.0 global kl: 0.0 valid nll: 1064783744.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 9550, train nll: 640005440.0, mse: 256.002166748, local kl: 0.0 global kl: 0.0 valid nll: 1064783808.0, mse: 425.913482666, local kl: 0.0 global kl: 0.0
it: 9600, train nll: 602469248.0, mse: 240.987686157, local kl: 0.0 global kl: 0.0 valid nll: 1064783808.0, mse: 425.913482666, local kl: 0.0 global kl: 0.0
it: 9650, train nll: 864814016.0, mse: 345.925628662, local kl: 0.0 global kl: 0.0 valid nll: 1064783744.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 9700, train nll: 914804672.0, mse: 365.921875, local kl: 0.0 global kl: 0.0 valid nll: 1064783744.0, mse: 425.913482666, local kl: 0.0 global kl: 0.0
it: 9750, train nll: 839890752.0, mse: 335.956298828, local kl: 0.0 global kl: 0.0 valid nll: 1064783744.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0
it: 9800, train nll: 689973504.0, mse: 275.989379883, local kl: 0.0 global kl: 0.0 valid nll: 1064783744.0, mse: 425.913482666, local kl: 0.0 global kl: 0.0
it: 9850, train nll: 827404992.0, mse: 330.962005615, local kl: 0.0 global kl: 0.0 valid nll: 1064783808.0, mse: 425.913482666, local kl: 0.0 global kl: 0.0
it: 9900, train nll: 802453440.0, mse: 320.981384277, local kl: 0.0 global kl: 0.0 valid nll: 1064783680.0, mse: 425.913482666, local kl: 0.0 global kl: 0.0
it: 9950, train nll: 789961472.0, mse: 315.984588623, local kl: 0.0 global kl: 0.0 valid nll: 1064783680.0, mse: 425.913513184, local kl: 0.0 global kl: 0.0

Prior Predictive + GP


In [0]:
# ---------------------------------------------------------------------------
# Data hyperparameters (same values as in the first configuration cell).
# ---------------------------------------------------------------------------
num_context = 512
num_target = 50
data_hparams = tf.contrib.training.HParams(
    context_dim=2,
    num_actions=5,
    num_target=num_target,
    num_context=num_context)

# ---------------------------------------------------------------------------
# Network widths and the layer-size lists derived from them.
# ---------------------------------------------------------------------------
X_HIDDEN_SIZE = 100
HIDDEN_SIZE = 64
latent_units = 32

x_encoder_sizes = [X_HIDDEN_SIZE] * 3
x_y_encoder_sizes = [HIDDEN_SIZE] * 3
# Last layer is 2 * latent_units wide -- presumably a mean and a scale per
# latent unit; confirm against the model definition.
global_decoder_sizes = [HIDDEN_SIZE] * 2 + [2 * latent_units]
global2local_decoder_sizes = [HIDDEN_SIZE] * 3 + [2]
# No freeform decoder and no heteroskedastic network for this configuration.
freeform_decoder_sizes = None
heteroskedastic_sizes = None

# ---------------------------------------------------------------------------
# Uncertainty / attention settings.
# ---------------------------------------------------------------------------
uncertainty_type = 'attentive_gp'
att_type = 'multihead'
att_heads = 8
mean_att_type = attention.laplace_attention
scale_att_type_1 = attention.laplace_attention
scale_att_type_2 = attention.laplace_attention
data_uncertainty = False

# NOTE(review): the size-related keys below (`freeform_decoder_sizes`,
# `global_decoder_sizes`, `global2local_decoder_sizes`,
# `heteroskedastic_sizes`) are named differently from the ones used in the
# first model_hparams cell (`global_latent_net_sizes`,
# `local_latent_net_sizes`, `heteroskedastic_net_sizes`). Confirm which set
# of names the model actually reads.
model_hparams = tf.contrib.training.HParams(
    activation=tf.nn.relu,
    output_activation=tf.nn.relu,
    x_encoder_sizes=x_encoder_sizes,
    x_y_encoder_sizes=x_y_encoder_sizes,
    freeform_decoder_sizes=freeform_decoder_sizes,
    global_decoder_sizes=global_decoder_sizes,
    global2local_decoder_sizes=global2local_decoder_sizes,
    heteroskedastic_sizes=heteroskedastic_sizes,
    uncertainty_type=uncertainty_type,
    att_type=att_type,
    att_heads=att_heads,
    mean_att_type=mean_att_type,
    scale_att_type_1=scale_att_type_1,
    scale_att_type_2=scale_att_type_2,
    meta_learn=False,
    data_uncertainty=data_uncertainty)

# ---------------------------------------------------------------------------
# Training hyperparameters.
# ---------------------------------------------------------------------------
pred_type = 'prior_predictive'
# Checkpoint path handed to the trainer via `save_path`.
save_path = os.path.join(savedir, 'best_prior_gp_mse_unclipped.ckpt')
training_hparams = tf.contrib.training.HParams(
    lr=0.01,
    optimizer=tf.train.RMSPropOptimizer,
    num_iterations=10000,
    batch_size=10,
    num_context=num_context,
    num_target=num_target,
    print_every=50,
    save_path=save_path,
    pred_type=pred_type,
    # Presumably a gradient-norm clipping threshold; the checkpoint name says
    # "unclipped", so 1000.0 looks intended to be effectively no clipping --
    # confirm in the trainer.
    max_grad_norm=1000.0)

In [0]:
# Run training with the data, model and training hyperparameters configured above.
train(data_hparams, model_hparams, training_hparams)


it: 0, train nll: 135.67288208, mse: 268.965209961, local kl: 0.0 global kl: 2.97951009998e-05 valid nll: 156.140930176, mse: 309.864196777, local kl: 0.0464261583984 global kl: 3.66583008145e-05
Saving best model with MSE 309.8642
it: 50, train nll: 38.713054657, mse: 47.3657722473, local kl: 0.0 global kl: 7.08528241375e-05 valid nll: 7.87959337234, mse: 9.00686168671, local kl: 20.2937602997 global kl: 0.000171415551449
Saving best model with MSE 9.006862
it: 100, train nll: 39.7649269104, mse: 56.1764793396, local kl: 0.0 global kl: 1.59876290127e-05 valid nll: 5.57811546326, mse: 5.43897247314, local kl: 24.2452487946 global kl: 2.80916428892e-05
Saving best model with MSE 5.4389725
it: 150, train nll: 22.0698070526, mse: 27.5763893127, local kl: 0.0 global kl: 1.70928542502e-05 valid nll: 5.52220392227, mse: 5.64932394028, local kl: 25.2032299042 global kl: 4.46284757345e-05
it: 200, train nll: 25.4558467865, mse: 34.0109519958, local kl: 0.0 global kl: 8.34837101138e-06 valid nll: 5.73757171631, mse: 5.74629878998, local kl: 23.9382076263 global kl: 9.3124626801e-06
it: 250, train nll: 17.4858455658, mse: 23.6737976074, local kl: 0.0 global kl: 9.34611875891e-07 valid nll: 4.19640207291, mse: 3.81868958473, local kl: 18.156583786 global kl: 7.27839676529e-07
Saving best model with MSE 3.8186896
it: 300, train nll: 44.40001297, mse: 41.5089111328, local kl: 0.0 global kl: 2.37443188666e-07 valid nll: 4.88419532776, mse: 4.78721570969, local kl: 17.0979824066 global kl: 2.67433392764e-07
it: 350, train nll: 4.98600816727, mse: 6.48660326004, local kl: 0.0 global kl: 1.6774728806e-07 valid nll: 3.63469481468, mse: 3.00728702545, local kl: 31.9590950012 global kl: 1.61089445783e-07
Saving best model with MSE 3.007287
it: 400, train nll: 17.6276931763, mse: 21.9941921234, local kl: 0.0 global kl: 6.57210179611e-08 valid nll: 3.13389086723, mse: 2.19624376297, local kl: 42.2717208862 global kl: 7.32079357135e-08
Saving best model with MSE 2.1962438
it: 450, train nll: 10.5739336014, mse: 15.16157341, local kl: 0.0 global kl: 7.87948835068e-08 valid nll: 3.15881633759, mse: 2.45544719696, local kl: 29.0998897552 global kl: 6.77230005408e-08
it: 500, train nll: 10.1896390915, mse: 13.8419141769, local kl: 0.0 global kl: 6.57685745864e-08 valid nll: 3.08007359505, mse: 2.25175094604, local kl: 27.1248092651 global kl: 9.39631092933e-08
it: 550, train nll: 13.055516243, mse: 18.999879837, local kl: 0.0 global kl: 5.37782227639e-08 valid nll: 2.84575152397, mse: 1.95389282703, local kl: 29.5342178345 global kl: 5.45162137655e-08
Saving best model with MSE 1.9538928
it: 600, train nll: 9.98043251038, mse: 14.123547554, local kl: 0.0 global kl: 3.57013405505e-08 valid nll: 3.1010415554, mse: 1.94351387024, local kl: 25.7098255157 global kl: 2.54009364653e-08
Saving best model with MSE 1.9435139
it: 650, train nll: 12.0002708435, mse: 15.2650356293, local kl: 0.0 global kl: 2.18747171488e-08 valid nll: 2.10267496109, mse: 1.18513834476, local kl: 23.9093055725 global kl: 2.09624602121e-08
Saving best model with MSE 1.1851383
it: 700, train nll: 23.4943408966, mse: 31.2075195312, local kl: 0.0 global kl: 2.9210053043e-08 valid nll: 2.19499707222, mse: 1.23601675034, local kl: 22.3066501617 global kl: 1.98671621376e-08
it: 750, train nll: 20.1542434692, mse: 28.5167713165, local kl: 0.0 global kl: 2.67520530173e-08 valid nll: 2.72223591805, mse: 2.01130390167, local kl: 17.6100902557 global kl: 3.57894194281e-08
it: 800, train nll: 25.0166854858, mse: 16.184381485, local kl: 0.0 global kl: 3.41682131477e-08 valid nll: 2.18673849106, mse: 1.30634498596, local kl: 43.9515380859 global kl: 3.35992851319e-08
it: 850, train nll: 20.4032001495, mse: 25.6418991089, local kl: 0.0 global kl: 2.57557246641e-08 valid nll: 1.61301779747, mse: 0.860883831978, local kl: 26.6890125275 global kl: 3.60803120714e-08
Saving best model with MSE 0.86088383
it: 900, train nll: 22.1447029114, mse: 16.9675254822, local kl: 0.0 global kl: 3.35757945891e-08 valid nll: 1.28108823299, mse: 0.503101944923, local kl: 29.516702652 global kl: 2.90000574665e-08
Saving best model with MSE 0.50310194
it: 950, train nll: 5.3908290863, mse: 6.24831199646, local kl: 0.0 global kl: 2.61158454862e-08 valid nll: 1.79969274998, mse: 1.02008450031, local kl: 29.3878078461 global kl: 3.2601867872e-08
it: 1000, train nll: 12.829744339, mse: 18.1550350189, local kl: 0.0 global kl: 1.90515994092e-08 valid nll: 2.28042387962, mse: 1.600923419, local kl: 42.8468894958 global kl: 2.42496618341e-08
it: 1050, train nll: 24.7326583862, mse: 33.0672531128, local kl: 0.0 global kl: 3.25047082583e-08 valid nll: 1.90831363201, mse: 1.08301317692, local kl: 22.3630561829 global kl: 4.09852844996e-08
it: 1100, train nll: 6.96532773972, mse: 9.69468975067, local kl: 0.0 global kl: 1.61521249709e-08 valid nll: 1.51637113094, mse: 0.682213485241, local kl: 65.3598556519 global kl: 2.35271215843e-08
it: 1150, train nll: 14.2707157135, mse: 15.9548168182, local kl: 0.0 global kl: 1.27205739275e-08 valid nll: 2.27527928352, mse: 1.62292528152, local kl: 35.3907623291 global kl: 1.18414504868e-08
it: 1200, train nll: 9.80038738251, mse: 8.24425983429, local kl: 0.0 global kl: 6.76070266437e-09 valid nll: 1.65432798862, mse: 0.746800005436, local kl: 24.9352436066 global kl: 1.14633644799e-08
it: 1250, train nll: 19.7135601044, mse: 16.7542591095, local kl: 0.0 global kl: 1.90406357348e-08 valid nll: 2.34033346176, mse: 1.78525328636, local kl: 29.6850757599 global kl: 2.25904379647e-08
it: 1300, train nll: 10.8902511597, mse: 14.6638717651, local kl: 0.0 global kl: 9.44318134799e-09 valid nll: 1.97045433521, mse: 1.16465783119, local kl: 29.8939342499 global kl: 1.35387043798e-08
it: 1350, train nll: 10.8603286743, mse: 8.50722885132, local kl: 0.0 global kl: 9.47540179652e-09 valid nll: 1.71049964428, mse: 0.834987342358, local kl: 22.8211326599 global kl: 1.98332923418e-08
it: 1400, train nll: 532.386901855, mse: 24.333814621, local kl: 0.0 global kl: 6.28361007671e-09 valid nll: 2.75197029114, mse: 2.15423130989, local kl: 25.4823284149 global kl: 8.94503049409e-09
it: 1450, train nll: 9.21810150146, mse: 13.0568265915, local kl: 0.0 global kl: 3.84112297525e-08 valid nll: 1.7244540453, mse: 0.968671321869, local kl: 32.3720932007 global kl: 4.64933869182e-08
it: 1500, train nll: 4.80863332748, mse: 5.69072008133, local kl: 0.0 global kl: 7.81681492867e-08 valid nll: 1.82010126114, mse: 1.09318590164, local kl: 25.900800705 global kl: 3.54183136153e-08
it: 1550, train nll: 12.4981193542, mse: 12.9707317352, local kl: 0.0 global kl: 2.38676669539e-08 valid nll: 1.76128959656, mse: 1.11679542065, local kl: 28.2069454193 global kl: 3.11775210093e-08
it: 1600, train nll: 32.9268455505, mse: 18.6304302216, local kl: 0.0 global kl: 1.68519331822e-08 valid nll: 1.82810032368, mse: 1.37209177017, local kl: 42.8485717773 global kl: 5.2424343977e-08
it: 1650, train nll: 11.856546402, mse: 13.7787742615, local kl: 0.0 global kl: 6.79483624921e-09 valid nll: 1.29575288296, mse: 0.630173325539, local kl: 28.1765270233 global kl: 1.17495506657e-08
it: 1700, train nll: 13.3365850449, mse: 15.0206136703, local kl: 0.0 global kl: 1.21398056052e-08 valid nll: 1.69369995594, mse: 0.952570617199, local kl: 17.6260890961 global kl: 1.15756737529e-08
it: 1750, train nll: 7.2779917717, mse: 10.5942277908, local kl: 0.0 global kl: 1.14614762126e-08 valid nll: 1.47669649124, mse: 0.863148570061, local kl: 44.0124092102 global kl: 8.2718925043e-09
it: 1800, train nll: 13.7344255447, mse: 17.8315181732, local kl: 0.0 global kl: 9.4055305766e-09 valid nll: 1.63631618023, mse: 1.09454369545, local kl: 34.8970909119 global kl: 1.68476166351e-08
it: 1850, train nll: 7.24004793167, mse: 9.79350566864, local kl: 0.0 global kl: 6.62853727462e-09 valid nll: 1.1788008213, mse: 0.58881008625, local kl: 33.4731941223 global kl: 3.38677219602e-09
it: 1900, train nll: 8.92973136902, mse: 9.64473819733, local kl: 0.0 global kl: 8.6661797738e-09 valid nll: 1.03229951859, mse: 0.389713287354, local kl: 43.9711074829 global kl: 4.85062612299e-09
Saving best model with MSE 0.3897133
it: 1950, train nll: 9.62401676178, mse: 13.4962034225, local kl: 0.0 global kl: 8.22302137493e-09 valid nll: 1.02113521099, mse: 0.524263679981, local kl: 39.0698890686 global kl: 8.25572676888e-09
it: 2000, train nll: 9.10063648224, mse: 10.6248846054, local kl: 0.0 global kl: 8.24752621753e-09 valid nll: 1.12876737118, mse: 0.592695534229, local kl: 43.0520133972 global kl: 8.83857609324e-09
it: 2050, train nll: 14.3722677231, mse: 11.3133382797, local kl: 0.0 global kl: 8.76962236163e-09 valid nll: 0.965671062469, mse: 0.40548568964, local kl: 29.0190124512 global kl: 1.11020037608e-08
it: 2100, train nll: 3.03452253342, mse: 2.7699110508, local kl: 0.0 global kl: 5.86240167522e-09 valid nll: 1.42459177971, mse: 0.755770266056, local kl: 30.2621421814 global kl: 1.16759961699e-08
it: 2150, train nll: 17.0855751038, mse: 21.267118454, local kl: 0.0 global kl: 7.5580341985e-09 valid nll: 1.31066846848, mse: 0.79665261507, local kl: 27.8705425262 global kl: 5.22449283835e-09
it: 2200, train nll: 25.1353549957, mse: 13.9426088333, local kl: 0.0 global kl: 6.37715036333e-09 valid nll: 1.44227540493, mse: 0.726765930653, local kl: 34.5564994812 global kl: 9.76554481724e-09
it: 2250, train nll: 32.7330131531, mse: 6.37265539169, local kl: 0.0 global kl: 1.26342554196e-08 valid nll: 1.38506007195, mse: 0.686924159527, local kl: 27.2734184265 global kl: 1.04184545435e-08
it: 2300, train nll: 15.2237510681, mse: 13.6877536774, local kl: 0.0 global kl: 1.55455808226e-08 valid nll: 0.77672880888, mse: 0.302642613649, local kl: 37.9379463196 global kl: 1.5875272652e-08
Saving best model with MSE 0.3026426
it: 2350, train nll: 17.242729187, mse: 22.6949958801, local kl: 0.0 global kl: 2.68113042878e-09 valid nll: 1.15183353424, mse: 0.565146267414, local kl: 21.8525905609 global kl: 3.34309069316e-09
it: 2400, train nll: 65.4532623291, mse: 9.06063652039, local kl: 0.0 global kl: 1.15507825527e-08 valid nll: 1.43206417561, mse: 0.874848425388, local kl: 41.0783195496 global kl: 6.49460352165e-09
it: 2450, train nll: 48.4601364136, mse: 16.4113063812, local kl: 0.0 global kl: 1.31480888399e-08 valid nll: 0.794924020767, mse: 0.291959136724, local kl: 35.042678833 global kl: 7.13651493456e-09
Saving best model with MSE 0.29195914
it: 2500, train nll: 22.0635948181, mse: 23.0963478088, local kl: 0.0 global kl: 1.93935832726e-07 valid nll: 1.23714518547, mse: 0.555161952972, local kl: 38.7894096375 global kl: 1.92030398694e-07
it: 2550, train nll: 1.67632687092, mse: 1.0826035738, local kl: 0.0 global kl: 3.14309600569e-09 valid nll: 0.772527813911, mse: 0.305853366852, local kl: 24.2130489349 global kl: 2.94608271112e-09
it: 2600, train nll: 33.3498916626, mse: 18.3238754272, local kl: 0.0 global kl: 5.2210831214e-09 valid nll: 0.401161313057, mse: 0.0887761935592, local kl: 50.0205001831 global kl: 6.07915229267e-09
Saving best model with MSE 0.08877619
it: 2650, train nll: 21.2488689423, mse: 21.0498142242, local kl: 0.0 global kl: 4.62703786397e-09 valid nll: 0.927809894085, mse: 0.46798825264, local kl: 34.4223213196 global kl: 4.91190954577e-09
it: 2700, train nll: 19.2231998444, mse: 12.523273468, local kl: 0.0 global kl: 5.49122258775e-09 valid nll: 0.957778096199, mse: 0.438966006041, local kl: 25.6892948151 global kl: 2.83178569482e-09
it: 2750, train nll: 4.28415203094, mse: 5.94087696075, local kl: 0.0 global kl: 2.57805368165e-07 valid nll: 0.762319028378, mse: 0.265781074762, local kl: 45.2685203552 global kl: 4.19925072492e-08
it: 2800, train nll: 32.2546730042, mse: 14.9935045242, local kl: 0.0 global kl: 0.0 valid nll: 0.871641039848, mse: 0.307507008314, local kl: 36.3893280029 global kl: 0.0
it: 2850, train nll: 13.2314825058, mse: 18.3900947571, local kl: 0.0 global kl: 0.0 valid nll: 1.10180902481, mse: 0.510904252529, local kl: 33.4876823425 global kl: 0.0
it: 2900, train nll: 14.3235998154, mse: 16.7537384033, local kl: 0.0 global kl: 0.0 valid nll: 0.737176537514, mse: 0.18620377779, local kl: 22.2613544464 global kl: 0.0
it: 2950, train nll: 24.4667053223, mse: 29.3658943176, local kl: 0.0 global kl: 0.0 valid nll: 1.57248878479, mse: 0.877755999565, local kl: 25.3010139465 global kl: 0.0
it: 3000, train nll: 19.6643009186, mse: 15.5699148178, local kl: 0.0 global kl: 0.0 valid nll: 0.889896690845, mse: 0.382617145777, local kl: 26.6562538147 global kl: 0.0
it: 3050, train nll: 4.03701019287, mse: 4.88511228561, local kl: 0.0 global kl: 0.0 valid nll: 0.612292468548, mse: 0.162031233311, local kl: 32.7294654846 global kl: 0.0
it: 3100, train nll: 17.8829479218, mse: 22.8830280304, local kl: 0.0 global kl: 0.0 valid nll: 0.819652795792, mse: 0.270338386297, local kl: 31.1775894165 global kl: 0.0
it: 3150, train nll: 10.8478317261, mse: 13.8416604996, local kl: 0.0 global kl: 0.0 valid nll: 0.649574816227, mse: 0.279452562332, local kl: 44.9304504395 global kl: 0.0
it: 3200, train nll: 3.02262115479, mse: 3.66352462769, local kl: 0.0 global kl: 0.0 valid nll: 0.678210735321, mse: 0.143701344728, local kl: 32.3190803528 global kl: 0.0
it: 3250, train nll: 3.14674520493, mse: 3.26892399788, local kl: 0.0 global kl: 0.0 valid nll: 0.660642564297, mse: 0.177934601903, local kl: 92.3591461182 global kl: 0.0
it: 3300, train nll: 44.4920005798, mse: 28.2844085693, local kl: 0.0 global kl: 0.0 valid nll: 0.648121535778, mse: 0.107744172215, local kl: 33.0481033325 global kl: 0.0
it: 3350, train nll: 9.75230884552, mse: 12.6673021317, local kl: 0.0 global kl: 6.26801153203e-08 valid nll: 0.942423760891, mse: 0.40170648694, local kl: 86.8008804321 global kl: 3.34933396573e-08
it: 3400, train nll: 167.604949951, mse: 13.5978555679, local kl: 0.0 global kl: 0.0 valid nll: 0.5482853055, mse: 0.182373180985, local kl: 62.2076835632 global kl: 0.0
it: 3450, train nll: 8.24667072296, mse: 10.6210451126, local kl: 0.0 global kl: 0.0 valid nll: 0.528516113758, mse: 0.143158048391, local kl: 108.999893188 global kl: 0.0
it: 3500, train nll: 12.7469158173, mse: 18.8595867157, local kl: 0.0 global kl: 0.0 valid nll: 0.757270872593, mse: 0.21701541543, local kl: 262.090026855 global kl: 0.0
it: 3550, train nll: 18.8503932953, mse: 15.4892187119, local kl: 0.0 global kl: 0.0 valid nll: 0.678811788559, mse: 0.173564791679, local kl: 100.598960876 global kl: 0.0
it: 3600, train nll: 4.07899045944, mse: 5.76709938049, local kl: 0.0 global kl: 0.0 valid nll: 0.649443745613, mse: 0.228646844625, local kl: 61.3081893921 global kl: 0.0
it: 3650, train nll: 70.5420532227, mse: 21.3098926544, local kl: 0.0 global kl: 0.0 valid nll: 0.748411476612, mse: 0.213907390833, local kl: 85.6216201782 global kl: 0.0
it: 3700, train nll: 11.7111034393, mse: 11.9440422058, local kl: 0.0 global kl: 0.0 valid nll: 0.728141307831, mse: 0.236078694463, local kl: 39.6701087952 global kl: 0.0
it: 3750, train nll: 190.362060547, mse: 19.001581192, local kl: 0.0 global kl: 0.0 valid nll: 0.475658208132, mse: 0.133007869124, local kl: 25.6355705261 global kl: 0.0
it: 3800, train nll: 13.3809175491, mse: 15.2693357468, local kl: 0.0 global kl: 0.0 valid nll: 0.975448668003, mse: 0.527097404003, local kl: 42.2584266663 global kl: 0.0
it: 3850, train nll: 68.8415222168, mse: 20.0454978943, local kl: 0.0 global kl: 0.0 valid nll: 0.498117536306, mse: 0.190657824278, local kl: 28.330827713 global kl: 0.0
it: 3900, train nll: 6.0312461853, mse: 7.49583482742, local kl: 0.0 global kl: 0.0 valid nll: 0.785881698132, mse: 0.351758360863, local kl: 47.4314994812 global kl: 0.0
it: 3950, train nll: 8.02746200562, mse: 8.58711051941, local kl: 0.0 global kl: 0.0 valid nll: 0.910224735737, mse: 0.595754563808, local kl: 38.1561813354 global kl: 0.0
it: 4000, train nll: 17.7047557831, mse: 11.6804714203, local kl: 0.0 global kl: 0.0 valid nll: 0.519778072834, mse: 0.11802738905, local kl: 29.7209243774 global kl: 0.0
it: 4050, train nll: 23.7838249207, mse: 28.5435237885, local kl: 0.0 global kl: 0.0 valid nll: 0.385556697845, mse: 0.0731834322214, local kl: 17.7685337067 global kl: 0.0
Saving best model with MSE 0.07318343
it: 4100, train nll: 69.2422103882, mse: 28.9557304382, local kl: 0.0 global kl: 0.0 valid nll: 0.671221792698, mse: 0.38440734148, local kl: 29.3675842285 global kl: 0.0
it: 4150, train nll: 4.79461479187, mse: 6.29106378555, local kl: 0.0 global kl: 0.0 valid nll: 0.487108975649, mse: 0.194384679198, local kl: 42.8515319824 global kl: 0.0
it: 4200, train nll: 6.46383523941, mse: 7.95132875443, local kl: 0.0 global kl: 0.0 valid nll: 0.960959076881, mse: 0.518565714359, local kl: 24.8157157898 global kl: 0.0
it: 4250, train nll: 30.6752567291, mse: 19.0483913422, local kl: 0.0 global kl: 0.0 valid nll: 0.360718458891, mse: 0.108757920563, local kl: 50.7504425049 global kl: 0.0
it: 4300, train nll: 11.6510868073, mse: 13.9357585907, local kl: 0.0 global kl: 0.0 valid nll: 0.591583013535, mse: 0.231441482902, local kl: 52.3128623962 global kl: 0.0
it: 4350, train nll: 14.5436754227, mse: 16.730463028, local kl: 0.0 global kl: 0.0 valid nll: 0.622538685799, mse: 0.185143217444, local kl: 37.6928062439 global kl: 0.0
it: 4400, train nll: 52.3740921021, mse: 29.4204883575, local kl: 0.0 global kl: 0.0 valid nll: 0.520299732685, mse: 0.195781707764, local kl: 47.7467727661 global kl: 0.0
it: 4450, train nll: 37.4780158997, mse: 21.5979747772, local kl: 0.0 global kl: 0.0 valid nll: 0.267373710871, mse: 0.0733546987176, local kl: 50.4918327332 global kl: 0.0
it: 4500, train nll: 6.01602125168, mse: 6.58507537842, local kl: 0.0 global kl: 0.0 valid nll: 0.675513327122, mse: 0.336170732975, local kl: 38.5327644348 global kl: 0.0
it: 4550, train nll: 7.13419389725, mse: 9.38906097412, local kl: 0.0 global kl: 0.0 valid nll: 0.217198386788, mse: 0.0643740743399, local kl: 110.252449036 global kl: 0.0
Saving best model with MSE 0.064374074
it: 4600, train nll: 4.20773649216, mse: 5.71775817871, local kl: 0.0 global kl: 0.0 valid nll: 0.30002617836, mse: 0.0488403737545, local kl: 66.4875793457 global kl: 0.0
Saving best model with MSE 0.048840374
it: 4650, train nll: 9.04526424408, mse: 13.1717824936, local kl: 0.0 global kl: 0.0 valid nll: 0.309748828411, mse: 0.110865518451, local kl: 64.5202178955 global kl: 0.0
it: 4700, train nll: 8.81709861755, mse: 11.8904752731, local kl: 0.0 global kl: 0.0 valid nll: 0.436405271292, mse: 0.184258937836, local kl: 35.9930458069 global kl: 0.0
it: 4750, train nll: 11.2636318207, mse: 11.188123703, local kl: 0.0 global kl: 0.0 valid nll: 0.354118585587, mse: 0.0905349999666, local kl: 51.6694908142 global kl: 0.0
it: 4800, train nll: 11.8270330429, mse: 17.3811149597, local kl: 0.0 global kl: 0.0 valid nll: 0.138171285391, mse: 0.0464979149401, local kl: 43.3686676025 global kl: 0.0
Saving best model with MSE 0.046497915
it: 4850, train nll: 17.7627677917, mse: 8.49855232239, local kl: 0.0 global kl: 0.0 valid nll: 0.276249885559, mse: 0.0687968581915, local kl: 24.6384944916 global kl: 0.0
it: 4900, train nll: 18.16314888, mse: 20.1385860443, local kl: 0.0 global kl: 0.0 valid nll: 0.361848920584, mse: 0.115614749491, local kl: 38.258972168 global kl: 0.0
it: 4950, train nll: 15.4627981186, mse: 16.7423400879, local kl: 0.0 global kl: 0.0 valid nll: 0.664833962917, mse: 0.280036985874, local kl: 50.9451026917 global kl: 0.0
it: 5000, train nll: 8.14297103882, mse: 9.19327259064, local kl: 0.0 global kl: 0.0 valid nll: 1.01974153519, mse: 0.681376874447, local kl: 36.7171020508 global kl: 0.0
it: 5050, train nll: 12.6986045837, mse: 15.2093696594, local kl: 0.0 global kl: 0.0 valid nll: 0.19777494669, mse: 0.0585678294301, local kl: 27.7733249664 global kl: 0.0
it: 5100, train nll: 7.16447210312, mse: 7.91525363922, local kl: 0.0 global kl: 0.0 valid nll: 0.187136292458, mse: 0.0765735656023, local kl: 24.3987922668 global kl: 0.0
it: 5150, train nll: 10.5720844269, mse: 15.9469490051, local kl: 0.0 global kl: 0.0 valid nll: 0.664290726185, mse: 0.250223219395, local kl: 62.5214195251 global kl: 0.0
it: 5200, train nll: 4.48386192322, mse: 5.68594932556, local kl: 0.0 global kl: 0.0 valid nll: 0.363571375608, mse: 0.0882641524076, local kl: 19.9330730438 global kl: 0.0
it: 5250, train nll: 24.8049182892, mse: 14.2355527878, local kl: 0.0 global kl: 0.0 valid nll: 0.464312195778, mse: 0.251261681318, local kl: 113.42325592 global kl: 0.0
it: 5300, train nll: 24.3277378082, mse: 18.4686088562, local kl: 0.0 global kl: 0.0 valid nll: 0.0420930497348, mse: 0.0230998639017, local kl: 913.272460938 global kl: 0.0
Saving best model with MSE 0.023099864
it: 5350, train nll: 2.56082415581, mse: 2.57475781441, local kl: 0.0 global kl: 0.0 valid nll: 0.0811261981726, mse: 0.0146845858544, local kl: 27.4068717957 global kl: 0.0
Saving best model with MSE 0.014684586
it: 5400, train nll: 12.6974020004, mse: 10.9486036301, local kl: 0.0 global kl: 0.0 valid nll: 0.828587472439, mse: 0.410275399685, local kl: 46.5762367249 global kl: 0.0
it: 5450, train nll: 8.61384773254, mse: 12.6422109604, local kl: 0.0 global kl: 0.0 valid nll: 0.475454121828, mse: 0.158815026283, local kl: 22.3523578644 global kl: 0.0
it: 5500, train nll: 7.72694444656, mse: 6.55122184753, local kl: 0.0 global kl: 0.0 valid nll: 0.279172986746, mse: 0.16195769608, local kl: 112.629753113 global kl: 0.0
it: 5550, train nll: 4.35671615601, mse: 5.37067508698, local kl: 0.0 global kl: 0.0 valid nll: 0.496717125177, mse: 0.187304407358, local kl: 40.9086303711 global kl: 0.0
it: 5600, train nll: 22.2403907776, mse: 18.7036571503, local kl: 0.0 global kl: 0.0 valid nll: 0.728114008904, mse: 0.431292295456, local kl: 29.844619751 global kl: 0.0
it: 5650, train nll: 12.6454248428, mse: 14.9674634933, local kl: 0.0 global kl: 0.0 valid nll: 0.458410918713, mse: 0.141432538629, local kl: 35.5283508301 global kl: 0.0
it: 5700, train nll: 79.5557937622, mse: 8.74200248718, local kl: 0.0 global kl: 0.0 valid nll: 0.0328171551228, mse: 0.0357264280319, local kl: 120.836013794 global kl: 0.0
it: 5750, train nll: 3.34545898438, mse: 3.91176128387, local kl: 0.0 global kl: 0.0 valid nll: 0.114516437054, mse: 0.0650213211775, local kl: 63.8450965881 global kl: 0.0
it: 5800, train nll: 118.347419739, mse: 18.4449100494, local kl: 0.0 global kl: 0.0 valid nll: 0.309643030167, mse: 0.110810354352, local kl: 21.5921802521 global kl: 0.0
it: 5850, train nll: 5.41095638275, mse: 6.22319364548, local kl: 0.0 global kl: 0.0 valid nll: 0.194172441959, mse: 0.088344797492, local kl: 63.188117981 global kl: 0.0
it: 5900, train nll: 21.0755996704, mse: 12.1778059006, local kl: 0.0 global kl: 0.0 valid nll: 0.0930842459202, mse: 0.0212290044874, local kl: 27.9643478394 global kl: 0.0
it: 5950, train nll: 12.3289070129, mse: 17.9484672546, local kl: 0.0 global kl: 0.0 valid nll: 0.313672810793, mse: 0.0630095899105, local kl: 29.3130607605 global kl: 0.0
it: 6000, train nll: 1.89957809448, mse: 1.53467750549, local kl: 0.0 global kl: 0.0 valid nll: 0.739322006702, mse: 0.416579604149, local kl: 32.0859870911 global kl: 0.0
it: 6050, train nll: 82.3761978149, mse: 32.27551651, local kl: 0.0 global kl: 0.0 valid nll: 0.713897228241, mse: 0.530344069004, local kl: 290.030426025 global kl: 0.0
it: 6100, train nll: 17.312379837, mse: 13.5887212753, local kl: 0.0 global kl: 0.0 valid nll: 0.217863470316, mse: 0.0574299246073, local kl: 39.0650291443 global kl: 0.0
it: 6150, train nll: 7.21775150299, mse: 9.88391590118, local kl: 0.0 global kl: 0.0 valid nll: 0.0103831104934, mse: 0.0275967326015, local kl: 49.0590591431 global kl: 0.0
it: 6200, train nll: 3.67263817787, mse: 3.83211684227, local kl: 0.0 global kl: 0.0 valid nll: 0.36027687788, mse: 0.34011888504, local kl: 26.4660015106 global kl: 0.0
it: 6250, train nll: 51.3519668579, mse: 21.1924571991, local kl: 0.0 global kl: 0.0 valid nll: 0.295661866665, mse: 0.205094575882, local kl: 41.4537506104 global kl: 0.0
it: 6300, train nll: 38.6716461182, mse: 19.3161315918, local kl: 0.0 global kl: 2.54136556244e-09 valid nll: 0.00359613983892, mse: 0.0193141121417, local kl: 61.9747810364 global kl: 0.0
it: 6350, train nll: 4.79793357849, mse: 6.40180206299, local kl: 0.0 global kl: 0.0 valid nll: 0.103637449443, mse: 0.0266434047371, local kl: 28.1779747009 global kl: 0.0
it: 6400, train nll: 16.1473426819, mse: 18.4485416412, local kl: 0.0 global kl: 0.0 valid nll: 0.0914581269026, mse: 0.0604567117989, local kl: 38.5506057739 global kl: 0.0
it: 6450, train nll: 1.85471069813, mse: 1.36480474472, local kl: 0.0 global kl: 0.0 valid nll: 0.264219641685, mse: 0.0760953426361, local kl: 60.6284332275 global kl: 0.0
it: 6500, train nll: 94.5267791748, mse: 19.4239997864, local kl: 0.0 global kl: 0.0 valid nll: 0.502601206303, mse: 0.154880002141, local kl: 35.4002037048 global kl: 0.0
it: 6550, train nll: 12.5939769745, mse: 13.6239719391, local kl: 0.0 global kl: 0.0 valid nll: 0.171389222145, mse: 0.0419179238379, local kl: 27.6049346924 global kl: 0.0
it: 6600, train nll: 3.31055903435, mse: 4.17938756943, local kl: 0.0 global kl: 0.0 valid nll: 0.0836106091738, mse: 0.0138732092455, local kl: 94.2180480957 global kl: 0.0
Saving best model with MSE 0.013873209
it: 6650, train nll: 10947768.0, mse: 22.1083049774, local kl: 0.0 global kl: 0.0 valid nll: 0.346638560295, mse: 0.0723776742816, local kl: 43.9190635681 global kl: 0.0
it: 6700, train nll: 16.8882389069, mse: 22.2330570221, local kl: 0.0 global kl: 0.0 valid nll: 0.181410491467, mse: 0.0374855734408, local kl: 18.2445373535 global kl: 0.0
it: 6750, train nll: 2.31386017799, mse: 2.05940914154, local kl: 0.0 global kl: 0.0 valid nll: 0.0876859128475, mse: 0.0293555036187, local kl: 37.2882919312 global kl: 0.0
it: 6800, train nll: 42.9523010254, mse: 27.2865695953, local kl: 0.0 global kl: 0.0 valid nll: 0.222671240568, mse: 0.300532519817, local kl: 62.2840538025 global kl: 0.0
it: 6850, train nll: 11.5527029037, mse: 9.4832239151, local kl: 0.0 global kl: 0.0 valid nll: 0.122034415603, mse: 0.0619608312845, local kl: 27.9019680023 global kl: 0.0
it: 6900, train nll: 2.70521187782, mse: 2.37272834778, local kl: 0.0 global kl: 0.0 valid nll: 0.31624147296, mse: 0.0905210673809, local kl: 28.9618091583 global kl: 0.0
it: 6950, train nll: 11.3715343475, mse: 12.7519884109, local kl: 0.0 global kl: 0.0 valid nll: -0.0582815408707, mse: 0.0256973579526, local kl: 37.5166320801 global kl: 0.0
it: 7000, train nll: 15.4968137741, mse: 20.6603889465, local kl: 0.0 global kl: 0.0 valid nll: 0.176808923483, mse: 0.0693604275584, local kl: 888.353942871 global kl: 0.0
it: 7050, train nll: 0.99615240097, mse: 0.557232379913, local kl: 0.0 global kl: 0.0 valid nll: 0.377445489168, mse: 0.182582944632, local kl: 30.1433162689 global kl: 0.0
it: 7100, train nll: 3.68472123146, mse: 4.44814682007, local kl: 0.0 global kl: 0.0 valid nll: 0.339302748442, mse: 0.212706327438, local kl: 62.6370353699 global kl: 0.0
it: 7150, train nll: 12.5832071304, mse: 13.7182970047, local kl: 0.0 global kl: 0.0 valid nll: -0.0162086375058, mse: 0.0399602837861, local kl: 1832.62988281 global kl: 0.0
it: 7200, train nll: 18.4166069031, mse: 12.3811836243, local kl: 0.0 global kl: 0.0 valid nll: 0.227905422449, mse: 0.11953818053, local kl: 72.2518768311 global kl: 0.0
it: 7250, train nll: 157.305465698, mse: 27.0544719696, local kl: 0.0 global kl: 0.0 valid nll: 0.114396311343, mse: 0.0393659472466, local kl: 41.4475440979 global kl: 0.0
it: 7300, train nll: 21.3018112183, mse: 24.7862358093, local kl: 0.0 global kl: 0.0 valid nll: 0.359025359154, mse: 0.215255871415, local kl: 62.1644287109 global kl: 0.0
it: 7350, train nll: 5.5211482048, mse: 5.93625164032, local kl: 0.0 global kl: 0.0 valid nll: -0.0174660217017, mse: 0.0389247536659, local kl: 36.7948989868 global kl: 0.0
it: 7400, train nll: 44.2389335632, mse: 20.7762050629, local kl: 0.0 global kl: 0.0 valid nll: 0.111174531281, mse: 0.0565474294126, local kl: 39.728187561 global kl: 0.0
it: 7450, train nll: 10.0841407776, mse: 10.2586250305, local kl: 0.0 global kl: 0.0 valid nll: 0.274009108543, mse: 0.168050348759, local kl: 259.318450928 global kl: 0.0
it: 7500, train nll: 534.809265137, mse: 6.6686091423, local kl: 0.0 global kl: 0.0 valid nll: -0.0498784556985, mse: 0.00454327045009, local kl: 37.7799186707 global kl: 0.0
Saving best model with MSE 0.0045432705
it: 7550, train nll: 8.7248506546, mse: 12.7267580032, local kl: 0.0 global kl: 0.0 valid nll: 0.11188955605, mse: 0.0354894250631, local kl: 42.6978416443 global kl: 0.0
it: 7600, train nll: 77.7779922485, mse: 13.0122127533, local kl: 0.0 global kl: 0.0 valid nll: 0.503135442734, mse: 0.351943999529, local kl: 39.7210540771 global kl: 0.0
it: 7650, train nll: 3.74051451683, mse: 5.05398368835, local kl: 0.0 global kl: 0.0 valid nll: 0.429058372974, mse: 0.231047123671, local kl: 105.342140198 global kl: 0.0
it: 7700, train nll: 141.330734253, mse: 12.6501226425, local kl: 0.0 global kl: 0.0 valid nll: -0.0873204171658, mse: 0.0437855236232, local kl: 70.1931381226 global kl: 0.0
it: 7750, train nll: 19.4714660645, mse: 12.6586074829, local kl: 0.0 global kl: 0.0 valid nll: -0.11208973825, mse: 0.0218198560178, local kl: 42.6104431152 global kl: 0.0
it: 7800, train nll: 18.5738124847, mse: 9.88386917114, local kl: 0.0 global kl: 0.0 valid nll: 0.0484111905098, mse: 0.067242577672, local kl: 52.847946167 global kl: 0.0
it: 7850, train nll: 7.57524108887, mse: 5.54613208771, local kl: 0.0 global kl: 0.0 valid nll: 0.203512325883, mse: 0.106117367744, local kl: 59.8207511902 global kl: 0.0
it: 7900, train nll: 5.17933797836, mse: 6.95267868042, local kl: 0.0 global kl: 0.0 valid nll: 0.189949408174, mse: 0.029133355245, local kl: 47.5339698792 global kl: 0.0
it: 7950, train nll: 8.931432724, mse: 9.29685306549, local kl: 0.0 global kl: 0.0 valid nll: 0.390719383955, mse: 0.161461085081, local kl: 46.2626457214 global kl: 0.0
it: 8000, train nll: 1.36241364479, mse: 0.95297563076, local kl: 0.0 global kl: 0.0 valid nll: -0.0714138671756, mse: 0.00862880237401, local kl: 49.7891540527 global kl: 0.0
it: 8050, train nll: 4.1858587265, mse: 5.98797941208, local kl: 0.0 global kl: 0.0 valid nll: 0.587609469891, mse: 0.345277518034, local kl: 44.0082321167 global kl: 0.0
it: 8100, train nll: 2.95833110809, mse: 2.90672898293, local kl: 0.0 global kl: 0.0 valid nll: 0.199976593256, mse: 0.0602283477783, local kl: 55.5089569092 global kl: 0.0
it: 8150, train nll: 46.4507026672, mse: 16.7125759125, local kl: 0.0 global kl: 0.0 valid nll: 0.101908124983, mse: 0.074765637517, local kl: 309.726165771 global kl: 0.0
it: 8200, train nll: 10.6575698853, mse: 13.0728254318, local kl: 0.0 global kl: 0.0 valid nll: 0.182992652059, mse: 0.0847501903772, local kl: 42.4263000488 global kl: 0.0
it: 8250, train nll: 19.2107429504, mse: 9.7255115509, local kl: 0.0 global kl: 0.0 valid nll: 0.396989673376, mse: 0.167534977198, local kl: 61.4095611572 global kl: 0.0
it: 8300, train nll: 6.05707073212, mse: 7.70176649094, local kl: 0.0 global kl: 0.0 valid nll: 0.614295423031, mse: 0.434649705887, local kl: 32.2700576782 global kl: 0.0
it: 8350, train nll: 8.5415391922, mse: 11.719203949, local kl: 0.0 global kl: 0.0 valid nll: 0.252622663975, mse: 0.124038569629, local kl: 54.642791748 global kl: 0.0
it: 8400, train nll: 127.641471863, mse: 18.475151062, local kl: 0.0 global kl: 0.0 valid nll: 0.34201285243, mse: 0.189043179154, local kl: 93.1075515747 global kl: 0.0
it: 8450, train nll: 8.15616226196, mse: 10.0182476044, local kl: 0.0 global kl: 0.0 valid nll: 0.0243486408144, mse: 0.052542321384, local kl: 72.8054885864 global kl: 0.0
it: 8500, train nll: 17.1113204956, mse: 10.5110254288, local kl: 0.0 global kl: 0.0 valid nll: 0.365533471107, mse: 0.165926501155, local kl: 43.1362571716 global kl: 0.0
it: 8550, train nll: 39.6881980896, mse: 11.8220157623, local kl: 0.0 global kl: 0.0 valid nll: 0.0903825238347, mse: 0.0518640652299, local kl: 31.92445755 global kl: 0.0
it: 8600, train nll: 0.612336158752, mse: 0.0477507710457, local kl: 0.0 global kl: 0.0 valid nll: 0.731662869453, mse: 0.435653448105, local kl: 20.0320549011 global kl: 0.0
it: 8650, train nll: 23.8933315277, mse: 11.0483074188, local kl: 0.0 global kl: 0.0 valid nll: 0.240048915148, mse: 0.14531300962, local kl: 55.5776863098 global kl: 0.0
it: 8700, train nll: 2.39568686485, mse: 2.74567580223, local kl: 0.0 global kl: 0.0 valid nll: 0.170260742307, mse: 0.103032119572, local kl: 98.9401245117 global kl: 0.0
it: 8750, train nll: 12.3204622269, mse: 12.3829421997, local kl: 0.0 global kl: 0.0 valid nll: 0.184437200427, mse: 0.0618651397526, local kl: 57.62758255 global kl: 0.0
it: 8800, train nll: 533.718383789, mse: 25.5261135101, local kl: 0.0 global kl: 0.0 valid nll: 0.25421756506, mse: 0.283810406923, local kl: 41.5943984985 global kl: 0.0
it: 8850, train nll: 2.36769771576, mse: 2.22573590279, local kl: 0.0 global kl: 0.0 valid nll: 0.377896219492, mse: 0.153994485736, local kl: 85.1687088013 global kl: 0.0
it: 8900, train nll: 3.97271418571, mse: 4.14866638184, local kl: 0.0 global kl: 0.0 valid nll: 0.060447640717, mse: 0.0657506659627, local kl: 114.911987305 global kl: 0.0
it: 8950, train nll: 9.78967380524, mse: 12.5131454468, local kl: 0.0 global kl: 0.0 valid nll: 0.393507987261, mse: 0.139269739389, local kl: 33.6035308838 global kl: 0.0
it: 9000, train nll: 3.53051996231, mse: 4.51103687286, local kl: 0.0 global kl: 0.0 valid nll: 0.0732915326953, mse: 0.0607601702213, local kl: 242.753875732 global kl: 0.0
it: 9050, train nll: 16.1500587463, mse: 17.2854366302, local kl: 0.0 global kl: 0.0 valid nll: 0.0805616006255, mse: 0.0440275482833, local kl: 41.1641616821 global kl: 0.0
it: 9100, train nll: 9.43925857544, mse: 10.0430774689, local kl: 0.0 global kl: 0.0 valid nll: 0.270087450743, mse: 0.185763478279, local kl: 296.815032959 global kl: 0.0
it: 9150, train nll: 56733.5039062, mse: 18.0277061462, local kl: 0.0 global kl: 0.0 valid nll: 0.163556292653, mse: 0.0571777150035, local kl: 45.8987731934 global kl: 0.0
it: 9200, train nll: 8.06288051605, mse: 11.138461113, local kl: 0.0 global kl: 0.0 valid nll: 0.257373392582, mse: 0.107091404498, local kl: 42.6111717224 global kl: 0.0
it: 9250, train nll: 15.0291070938, mse: 7.35913562775, local kl: 0.0 global kl: 0.0 valid nll: 0.213386446238, mse: 0.071647003293, local kl: 27.0039730072 global kl: 0.0
it: 9300, train nll: 23.7939090729, mse: 20.9052600861, local kl: 0.0 global kl: 0.0 valid nll: 0.187842532992, mse: 0.0787667185068, local kl: 52.4015007019 global kl: 0.0
it: 9350, train nll: 9.3624830246, mse: 10.7744474411, local kl: 0.0 global kl: 0.0 valid nll: 0.282780170441, mse: 0.159105926752, local kl: 68.7294311523 global kl: 0.0
it: 9400, train nll: 3.75494623184, mse: 3.6859805584, local kl: 0.0 global kl: 0.0 valid nll: 0.150326162577, mse: 0.0526873581111, local kl: 51.0460739136 global kl: 0.0
it: 9450, train nll: 4.14857816696, mse: 4.792927742, local kl: 0.0 global kl: 0.0 valid nll: 0.214047923684, mse: 0.0941213443875, local kl: 28.7827205658 global kl: 0.0
it: 9500, train nll: 12.5249900818, mse: 17.1225337982, local kl: 0.0 global kl: 0.0 valid nll: 0.359704107046, mse: 0.223016321659, local kl: 24.1487770081 global kl: 0.0
it: 9550, train nll: 4.94543743134, mse: 6.30193662643, local kl: 0.0 global kl: 0.0 valid nll: 0.155060470104, mse: 0.184480249882, local kl: 45.3946914673 global kl: 0.0
it: 9600, train nll: 8.58651351929, mse: 8.78809165955, local kl: 0.0 global kl: 0.0 valid nll: -0.0236851200461, mse: 0.06224180758, local kl: 33.1709251404 global kl: 0.0
it: 9650, train nll: 2.60676646233, mse: 2.45419859886, local kl: 0.0 global kl: 0.0 valid nll: 0.131714731455, mse: 0.0998166874051, local kl: 49.6479187012 global kl: 0.0
it: 9700, train nll: 16.5164833069, mse: 16.6641559601, local kl: 0.0 global kl: 0.0 valid nll: 0.0966908112168, mse: 0.0557645596564, local kl: 27.2583312988 global kl: 0.0
it: 9750, train nll: 90.7753753662, mse: 18.3127098083, local kl: 0.0 global kl: 0.0 valid nll: 0.0732259899378, mse: 0.0520556867123, local kl: 224.938293457 global kl: 0.0
it: 9800, train nll: 10.6305913925, mse: 8.16889572144, local kl: 0.0 global kl: 0.0 valid nll: 0.733917057514, mse: 0.657127439976, local kl: 51.0800094604 global kl: 0.0
it: 9850, train nll: 4.50726366043, mse: 5.6487827301, local kl: 0.0 global kl: 0.0 valid nll: 0.314571261406, mse: 0.0967169776559, local kl: 30.6570663452 global kl: 0.0
it: 9900, train nll: 12.2864122391, mse: 10.8941879272, local kl: 0.0 global kl: 0.0 valid nll: 0.477919489145, mse: 0.142212763429, local kl: 27.766254425 global kl: 0.0
it: 9950, train nll: 6.44628667831, mse: 8.77133464813, local kl: 0.0 global kl: 0.0 valid nll: 0.468151032925, mse: 0.222285300493, local kl: 61.3763046265 global kl: 0.0

Archive


In [0]:
def sample_training_wheel_bandit_data(num_total_states,
                                      num_actions,
                                      context_dim,
                                      delta,
                                      mean_v,
                                      std_v,
                                      mu_large,
                                      std_large):
  """Samples (context, random action) pairs and rewards from a Wheel bandit.

  See https://arxiv.org/abs/1802.09127 for the Wheel bandit game.

  Contexts are drawn uniformly from [-1, 1]^context_dim and kept only when
  their norm is at most 1 (rejection sampling). For every kept context a
  reward is drawn for each action, one action is picked uniformly at random,
  and only the reward of that action is recorded.

  Args:
    num_total_states: Number of points to sample; must be at least 1.
    num_actions: Number of actions. Contexts with norm >= delta overwrite one
      of the first four action rewards, so at least four actions are needed
      whenever such contexts can occur.
    context_dim: Number of dimensions in the context. The first two
      coordinates select which action receives the large reward.
    delta: Exploration parameter: high reward in one region if norm above delta.
    mean_v: Mean reward for each action if context norm is below delta.
    std_v: Gaussian reward std for each action if context norm is below delta.
    mu_large: Mean reward for optimal action if context norm is above delta.
    std_large: Reward std for optimal action if context norm is above delta.

  Returns:
    state_action_pairs: Array of shape
      (num_total_states, context_dim + num_actions); each row is a context
      followed by the one-hot encoding of the randomly chosen action.
    rewards: Array of shape (num_total_states, 1) holding the reward of the
      chosen action for the matching row of `state_action_pairs`.
    Both arrays are shuffled with the same random permutation.

  Raises:
    ValueError: If `num_total_states` is smaller than 1.
  """
  if num_total_states < 1:
    raise ValueError('num_total_states must be at least 1, got {}'.format(
        num_total_states))

  data = []
  actions = []
  rewards = []

  # Sample uniform contexts in the unit ball. The candidate batch keeps its
  # historical size of num_total_states / 3 (so seeded runs draw the same
  # random stream) but is never allowed to be 0, which would loop forever
  # when fewer than three states are requested.
  batch_size = max(1, int(num_total_states / 3))
  while len(data) < num_total_states:
    raw_data = np.random.uniform(-1, 1, (batch_size, context_dim))

    for i in range(raw_data.shape[0]):
      if np.linalg.norm(raw_data[i, :]) <= 1:
        data.append(raw_data[i, :])

  # The last batch may overshoot, so trim to the requested count.
  states = np.stack(data)[:num_total_states, :]

  # Sample rewards and random actions.
  for i in range(num_total_states):
    r = [np.random.normal(mean_v[j], std_v[j]) for j in range(num_actions)]
    if np.linalg.norm(states[i, :]) >= delta:
      # Large reward in the right region for the context: the signs of the
      # first two coordinates pick which of actions 0..3 is optimal.
      r_big = np.random.normal(mu_large, std_large)
      if states[i, 0] > 0:
        if states[i, 1] > 0:
          r[0] = r_big
        else:
          r[1] = r_big
      else:
        if states[i, 1] > 0:
          r[2] = r_big
        else:
          r[3] = r_big
    # Size the one-hot vector by num_actions rather than a hard-coded 5 so
    # the encoding is valid for any number of actions.
    one_hot_vector = np.zeros(num_actions)
    random_action = np.random.randint(num_actions)
    one_hot_vector[random_action] = 1
    actions.append(one_hot_vector)
    rewards.append(r[random_action])

  rewards = np.expand_dims(np.array(rewards), -1)
  state_action_pairs = np.hstack([states, actions])
  perm = np.random.permutation(len(rewards))
  return state_action_pairs[perm, :], rewards[perm, :]


def get_training_wheel_data(num_total_states, num_actions, context_dim, delta):
  """Samples wheel-bandit training data with fixed reward parameters.

  Delegates to `sample_training_wheel_bandit_data` using small-reward means of
  1.0 for the first four actions and 1.2 for the fifth (std 0.01 each), and a
  large reward with mean 50 and std 0.01.

  Args:
    num_total_states: Number of points to sample.
    num_actions: Number of actions.
    context_dim: Number of dimensions in the context.
    delta: Exploration parameter forwarded to the sampler.

  Returns:
    The (state_action_pairs, rewards) pair produced by the sampler.
  """
  small_reward_means = [1.0, 1.0, 1.0, 1.0, 1.2]
  small_reward_stds = [0.01, 0.01, 0.01, 0.01, 0.01]
  large_reward_mean = 50
  large_reward_std = 0.01

  pairs, chosen_rewards = sample_training_wheel_bandit_data(
      num_total_states,
      num_actions,
      context_dim,
      delta,
      small_reward_means,
      small_reward_stds,
      large_reward_mean,
      large_reward_std)
  return pairs, chosen_rewards

In [0]:
def procure_dataset(hparams, num_wheels, seed=0):
  """Builds a stack of independently sampled wheel-bandit datasets.

  Seeds NumPy's global random generator with `seed`, then, for every wheel,
  draws an exploration parameter `delta` with `np.random.uniform()` and
  samples `hparams.num_target + hparams.num_context` points for it.

  Args:
    hparams: Object exposing `num_target`, `num_context`, `num_actions` and
      `context_dim` attributes.
    num_wheels: Number of wheel problems (datasets) to sample.
    seed: Seed passed to `np.random.seed`; note this resets the global NumPy
      random state.

  Returns:
    all_state_action_pairs: `np.stack` of the per-wheel state-action arrays
      (the leading axis indexes wheels).
    all_rewards: `np.stack` of the per-wheel reward arrays, aligned with
      `all_state_action_pairs`.
  """
  np.random.seed(seed)

  all_state_action_pairs, all_rewards = [], []
  for _ in range(num_wheels):
    # Each wheel gets its own exploration parameter.
    delta = np.random.uniform()
    state_action_pairs, rewards = get_training_wheel_data(
        hparams.num_target + hparams.num_context,
        hparams.num_actions,
        hparams.context_dim,
        delta)
    all_state_action_pairs.append(state_action_pairs)
    all_rewards.append(rewards)

  all_state_action_pairs = np.stack(all_state_action_pairs)
  all_rewards = np.stack(all_rewards)

  return all_state_action_pairs, all_rewards

@tf.function
def step(model, data, optimizer_config, num_context):
  """Runs one gradient step and returns the metrics of the batch.

  The minimised objective is `mse + local_kl + global_kl`; the NLL is computed
  and returned for monitoring but is not part of the loss.

  Args:
    model: Callable as `model(context_x, context_y, target_x, target_y)` that
      returns `(prior_prediction, posterior_prediction)` and exposes `losses`
      and `trainable_variables`. The last entry of `model.losses` is used as
      the local KL term and the second-to-last as the global KL term.
    data: 5-tuple `(context_x, context_y, target_x, target_y, unseen_targets)`;
      the fifth element is ignored and recomputed from `target_y`.
    optimizer_config: Optimizer exposing `apply_gradients`.
    num_context: Points at index `num_context` and beyond along axis 1 are
      the ones scored.

  Returns:
    Tuple `(nll, mse, local_kl, global_kl)`.
  """
  context_x, context_y, target_x, target_y, _ = data
  with tf.GradientTape() as tape:
    _, posterior = model(context_x, context_y, target_x, target_y)
    # Only the points past the context split are evaluated.
    held_out_y = target_y[:, num_context:]
    held_out_pred = posterior[:, num_context:]
    nll = utils.nll(held_out_y, held_out_pred)
    mse = utils.mse(held_out_y, held_out_pred)
    local_kl = tf.reduce_mean(model.losses[-1][:, num_context:])
    global_kl = tf.reduce_mean(model.losses[-2])
    loss = mse + local_kl + global_kl
  grads = tape.gradient(loss, model.trainable_variables)
  optimizer_config.apply_gradients(zip(grads, model.trainable_variables))
  return nll, mse, local_kl, global_kl

def training_loop(train_dataset,
                  valid_dataset,
                  model,
                  hparams):
  """Trains `model`, validating periodically and saving the best weights.

  Every iteration draws a random training batch and calls `step`. Every
  `hparams.print_every` iterations a validation batch is evaluated, metrics
  are printed, and the weights are saved to `hparams.save_path` whenever the
  validation MSE improves on the best value seen so far.

  Args:
    train_dataset: Tuple `(x, y)` of arrays indexed as [dataset, point, ...].
    valid_dataset: Tuple `(x, y)` with the same layout, used for validation.
    model: Callable as `model(context_x, context_y, target_x, target_y)` that
      returns `(prior_prediction, posterior_prediction)` and exposes `losses`
      and `save_weights`.
    hparams: Object exposing `optimizer`, `lr`, `num_context`,
      `num_iterations`, `batch_size`, `print_every` and `save_path`.
  """
  optimizer_config = hparams.optimizer(hparams.lr)
  num_context = hparams.num_context
  best_mse = np.inf

  def _get_splits(dataset, n_context, batch_size, points_perm=True):
    """Draws a random batch of datasets and splits it into context/target.

    Args:
      dataset: Tuple `(x, y)` of arrays indexed as [dataset, point, ...].
      n_context: Number of leading points (after any shuffling) used as
        context.
      batch_size: Maximum number of datasets in the batch.
      points_perm: If True, shuffle the points within the batch (the same
        permutation is applied to every dataset); otherwise keep their order.

    Returns:
      Tuple `(context_x, context_y, target_x, target_y, unseen_targets)` of
      float32 tensors; the targets contain all points, the context the first
      `n_context` of them and `unseen_targets` the remaining labels.
    """
    full_x, full_y = dataset
    dataset_perm = np.random.permutation(len(full_x))[:batch_size]
    if points_perm:
      datapoints_perm = np.random.permutation(full_x.shape[1])
    else:
      datapoints_perm = np.arange(full_x.shape[1])

    # tf.cast(..., tf.float32) is the supported replacement for the
    # deprecated tf.to_float and produces the same tensor.
    target_x = tf.cast(full_x[dataset_perm[:, None], datapoints_perm],
                       tf.float32)
    target_y = tf.cast(full_y[dataset_perm[:, None], datapoints_perm],
                       tf.float32)
    context_x = target_x[:, :n_context, :]
    context_y = target_y[:, :n_context, :]
    unseen_targets = target_y[:, n_context:]

    return context_x, context_y, target_x, target_y, unseen_targets

  for it in range(hparams.num_iterations):
    batch_train_data = _get_splits(
        train_dataset, num_context, hparams.batch_size, points_perm=True)
    nll, mse, local_z_kl, global_z_kl = step(
        model,
        batch_train_data,
        optimizer_config,
        num_context)

    if it % hparams.print_every == 0:
      (batch_context_x, batch_context_y, batch_target_x, batch_target_y,
       batch_unseen_targets) = _get_splits(
           valid_dataset, num_context, hparams.batch_size, points_perm=False)
      _, posterior_prediction = model(
          batch_context_x,
          batch_context_y,
          batch_target_x,
          batch_target_y)

      valid_unseen_predictions = posterior_prediction[:, num_context:]
      valid_nll = utils.nll(batch_unseen_targets, valid_unseen_predictions)
      valid_mse = utils.mse(batch_unseen_targets, valid_unseen_predictions)
      # model.losses is read right after the validation forward pass above.
      valid_local_kl = tf.reduce_mean(model.losses[-1][:, num_context:])
      valid_global_kl = tf.reduce_mean(model.losses[-2])

      print('it: {}, train nll: {}, mse: {}, local kl: {} global kl: {} '
            'valid nll: {}, mse: {}, local kl: {} global kl: {}'
            .format(it, nll, mse, local_z_kl, global_z_kl,
                    valid_nll, valid_mse, valid_local_kl, valid_global_kl))
      current_mse = valid_mse.numpy()
      if current_mse < best_mse:
        print('Saving best model')
        best_mse = current_mse
        model.save_weights(hparams.save_path)

  print('Best MSE is', best_mse)
      
def pretrain(data_hparams,
             model_hparams,
             training_hparams,
             num_train_wheels=100,
             num_valid_wheels=10,
             train_seed=0,
             valid_seed=42):
  """Builds train/validation datasets and a `Regressor`, then trains it.

  Two datasets are generated with `procure_dataset` (different seeds so the
  validation wheels are not the training wheels), a `Regressor` is configured
  from `model_hparams`, and everything is handed to `training_loop`.

  Args:
    data_hparams: Hyperparameters forwarded to `procure_dataset`. Its
      `context_dim` and `num_actions` fields are also summed to obtain the
      model's input dimension.
    model_hparams: Hyperparameters read field by field to configure
      `Regressor` (encoder/decoder sizes, attention settings, activations,
      `uncertainty_type`, `meta_learn`).
    training_hparams: Passed unchanged to `training_loop`.
    num_train_wheels: `num_wheels` used for the training dataset.
    num_valid_wheels: `num_wheels` used for the validation dataset.
    train_seed: Seed passed to `procure_dataset` for the training dataset.
    valid_seed: Seed passed to `procure_dataset` for the validation dataset.

  Returns:
    None. Any persistence of weights is left to `training_loop`.
  """
  train_inputs, train_rewards = procure_dataset(data_hparams,
                                                num_wheels=num_train_wheels,
                                                seed=train_seed)
  train_dataset = (train_inputs, train_rewards)

  valid_inputs, valid_rewards = procure_dataset(data_hparams,
                                                num_wheels=num_valid_wheels,
                                                seed=valid_seed)
  valid_dataset = (valid_inputs, valid_rewards)

  # Each model input is built from a context vector and an action encoding,
  # hence the summed width. NOTE(review): the exact layout (e.g. one-hot
  # actions) is decided by `procure_dataset` -- confirm there.
  input_dim = data_hparams.context_dim + data_hparams.num_actions

  model = Regressor(
      input_dim=input_dim,
      output_dim=1,
      x_encoder_sizes=model_hparams.x_encoder_sizes,
      x_y_encoder_sizes=model_hparams.x_y_encoder_sizes,
      freeform_decoder_sizes=model_hparams.freeform_decoder_sizes,
      global_decoder_sizes=model_hparams.global_decoder_sizes,
      global2local_decoder_sizes=model_hparams.global2local_decoder_sizes,
      heteroskedastic_sizes=model_hparams.heteroskedastic_sizes,
      att_type=model_hparams.att_type,
      att_heads=model_hparams.att_heads,
      uncertainty_type=model_hparams.uncertainty_type,
      mean_att_type=model_hparams.mean_att_type,
      scale_att_type_1=model_hparams.scale_att_type_1,
      scale_att_type_2=model_hparams.scale_att_type_2,
      activation=model_hparams.activation,
      output_activation=model_hparams.output_activation,
      meta_learn=model_hparams.meta_learn)

  training_loop(train_dataset,
                valid_dataset,
                model,
                training_hparams)
  
  # # check if weights are saved correctly
  # valid_context_x, valid_context_y, valid_target_x, valid_target_y, valid_unseen_targets = valid_data
  # model.load_weights(training_hparams.save_path)

  # prior_prediction, posterior_prediction = model(
  #         valid_context_x, 
  #         valid_context_y, 
  #         valid_target_x, 
  #         valid_target_y)
  
  # valid_unseen_predictions = posterior_prediction[:, num_context:]
  # valid_nll = utils.nll(valid_unseen_targets, valid_unseen_predictions)
  # valid_mse = utils.mse(valid_unseen_targets, valid_unseen_predictions)
  # print('Verified best MSE is', valid_mse.numpy())

In [0]:
# Launch pre-training with the hyperparameter sets defined in earlier cells.
pretrain(data_hparams=data_hparams,
         model_hparams=model_hparams,
         training_hparams=training_hparams)

Prior Predictive + MSE


In [0]:
# Launch pre-training with the hyperparameter sets currently in scope.
pretrain(data_hparams=data_hparams,
         model_hparams=model_hparams,
         training_hparams=training_hparams)


it: 0, train nll: 80.64142608642578, mse: 264.84783935546875, local kl: 2.2282488346099854 global kl: 0.012998933903872967valid nll: 73.07066345214844, mse: 239.69493103027344, local kl: 3.808199167251587 global kl: 0.020804043859243393
it: 50, train nll: 80.03545379638672, mse: 263.78887939453125, local kl: 3.2589476108551025 global kl: 0.0017950760666280985valid nll: 77.01897430419922, mse: 253.8538055419922, local kl: 1.3728536367416382 global kl: 0.0006710884626954794
it: 100, train nll: 76.54899597167969, mse: 252.11932373046875, local kl: 3.2033069133758545 global kl: 0.00012088955554645509valid nll: 70.98995208740234, mse: 233.72201538085938, local kl: 3.41615891456604 global kl: 0.000149666826473549
it: 150, train nll: 86.5409927368164, mse: 284.36871337890625, local kl: 6.026893138885498 global kl: 0.00034067913657054305valid nll: 64.36710357666016, mse: 210.2588653564453, local kl: 2.018472194671631 global kl: 0.00029383436776697636
it: 200, train nll: 83.56436157226562, mse: 276.3795166015625, local kl: 2.478466033935547 global kl: 6.839125603619323e-07valid nll: 67.9030990600586, mse: 223.70187377929688, local kl: 6.548462390899658 global kl: 7.911902457635733e-07
it: 250, train nll: 75.58776092529297, mse: 250.43666076660156, local kl: 1.7319374084472656 global kl: 1.5221696969547338e-07valid nll: 79.78518676757812, mse: 264.43365478515625, local kl: 2.751905679702759 global kl: 1.308293917645642e-07
it: 300, train nll: 76.14912414550781, mse: 253.78453063964844, local kl: 1.473401427268982 global kl: 3.913009294365111e-08valid nll: 58.048980712890625, mse: 193.0487823486328, local kl: 0.9498961567878723 global kl: 5.4925571646435856e-08
it: 350, train nll: 82.3572006225586, mse: 272.67333984375, local kl: 4.358160972595215 global kl: 4.587663582356072e-08valid nll: 74.87318420410156, mse: 248.21527099609375, local kl: 1.5207240581512451 global kl: 6.403710983704514e-08
it: 400, train nll: 91.03268432617188, mse: 301.0337829589844, local kl: 3.0707430839538574 global kl: 4.292673594363805e-08valid nll: 68.14114379882812, mse: 223.31600952148438, local kl: 5.099541187286377 global kl: 3.649966018315354e-08
it: 450, train nll: 75.67906188964844, mse: 250.17552185058594, local kl: 2.144524574279785 global kl: 2.1869551503073126e-08valid nll: 84.20165252685547, mse: 278.6016845703125, local kl: 2.7449562549591064 global kl: 3.175713203518171e-08
it: 500, train nll: 83.32353973388672, mse: 274.6894836425781, local kl: 3.476597309112549 global kl: 7.114813627140393e-08valid nll: 94.11646270751953, mse: 310.83953857421875, local kl: 3.4066193103790283 global kl: 4.7398525993003204e-08
it: 550, train nll: 84.16699981689453, mse: 279.2153625488281, local kl: 1.779749870300293 global kl: 1.857047138287271e-08valid nll: 61.619712829589844, mse: 203.1038360595703, local kl: 3.0211451053619385 global kl: 3.604542442303682e-08
it: 600, train nll: 96.12476348876953, mse: 317.0376281738281, local kl: 2.8844528198242188 global kl: 3.161008166330248e-08valid nll: 65.885498046875, mse: 216.1623992919922, local kl: 2.752321720123291 global kl: 5.722105811400979e-08
it: 650, train nll: 72.33695220947266, mse: 239.97520446777344, local kl: 2.402458906173706 global kl: 4.251641172459131e-08valid nll: 81.32144927978516, mse: 270.96014404296875, local kl: 2.6584548950195312 global kl: 2.2020312684389864e-08
it: 700, train nll: 73.26473236083984, mse: 243.494140625, local kl: 2.592867612838745 global kl: 6.969043564453159e-09valid nll: 89.51174926757812, mse: 298.3213806152344, local kl: 2.3439621925354004 global kl: 8.706960485937998e-09
it: 750, train nll: 89.19873046875, mse: 296.8776550292969, local kl: 2.609757661819458 global kl: 1.7568432042480708e-08valid nll: 90.67761993408203, mse: 302.61328125, local kl: 2.717984676361084 global kl: 8.707740306590495e-09
it: 800, train nll: 92.50891876220703, mse: 311.3219299316406, local kl: 1.1518354415893555 global kl: 6.968345012126065e-09valid nll: 79.88585662841797, mse: 268.59503173828125, local kl: 1.4761244058609009 global kl: 8.498346915075672e-09
it: 850, train nll: 66.6796646118164, mse: 223.8875274658203, local kl: 2.828697681427002 global kl: 1.2602218824042666e-08valid nll: 83.82502746582031, mse: 281.8943786621094, local kl: 1.7822562456130981 global kl: 1.1295233193209242e-08
it: 900, train nll: 89.82132720947266, mse: 302.96832275390625, local kl: 5.1185173988342285 global kl: 2.4734566395068214e-08valid nll: 91.43539428710938, mse: 309.444580078125, local kl: 1.8595776557922363 global kl: 1.1394554633170628e-08
it: 950, train nll: 71.63065338134766, mse: 242.0308837890625, local kl: 2.200707197189331 global kl: 4.9147486080869385e-09valid nll: 64.91653442382812, mse: 218.77122497558594, local kl: 2.184328317642212 global kl: 9.00603858156046e-09
it: 1000, train nll: 76.28456115722656, mse: 258.980224609375, local kl: 3.0768754482269287 global kl: 7.047096683976406e-09valid nll: 62.93938446044922, mse: 213.03277587890625, local kl: 0.9447728991508484 global kl: 8.388167493933452e-09
it: 1050, train nll: 75.24250030517578, mse: 257.0177001953125, local kl: 3.523057222366333 global kl: 2.9439317650314933e-09valid nll: 86.8646011352539, mse: 297.564208984375, local kl: 3.637453317642212 global kl: 4.320206414831773e-09
it: 1100, train nll: 74.27992248535156, mse: 254.22525024414062, local kl: 2.4260270595550537 global kl: 2.019649114615163e-09valid nll: 55.66709518432617, mse: 189.8374786376953, local kl: 1.605698585510254 global kl: 2.6771420635895993e-09
it: 1150, train nll: 67.89299774169922, mse: 230.09605407714844, local kl: 1.3890490531921387 global kl: 9.598144501410388e-09valid nll: 83.02552795410156, mse: 282.309326171875, local kl: 2.3077287673950195 global kl: 7.523248690688433e-09
it: 1200, train nll: 75.0975570678711, mse: 256.5108337402344, local kl: 1.4370931386947632 global kl: 3.1186579985131857e-09valid nll: 85.58000946044922, mse: 292.895751953125, local kl: 2.953248977661133 global kl: 5.895404164846241e-09
it: 1250, train nll: 81.43197631835938, mse: 278.9182434082031, local kl: 2.0099353790283203 global kl: 5.032402050630935e-09valid nll: 79.94624328613281, mse: 273.7984313964844, local kl: 9.319388389587402 global kl: 9.265770373190207e-09
it: 1300, train nll: 93.51103210449219, mse: 320.107177734375, local kl: 1.5832518339157104 global kl: 2.5319315533067766e-09valid nll: 93.16075134277344, mse: 318.7287902832031, local kl: 2.6946945190429688 global kl: 1.647670999638251e-09
it: 1350, train nll: 58.33557891845703, mse: 197.8896942138672, local kl: 1.9116010665893555 global kl: 1.7595962464866943e-09valid nll: 88.18049621582031, mse: 301.27545166015625, local kl: 2.0494961738586426 global kl: 1.947336070173833e-09
it: 1400, train nll: 81.88078308105469, mse: 280.81048583984375, local kl: 2.060009479522705 global kl: 2.8211339930805934e-09valid nll: 50.84994888305664, mse: 172.82359313964844, local kl: 1.3129647970199585 global kl: 2.5467534747747322e-09
it: 1450, train nll: 82.51393127441406, mse: 282.51495361328125, local kl: 1.8248894214630127 global kl: 2.7163331584034722e-09valid nll: 81.47759246826172, mse: 279.76165771484375, local kl: 1.116531252861023 global kl: 1.3680135912963465e-09
it: 1500, train nll: 79.3572006225586, mse: 273.8508605957031, local kl: 3.106325626373291 global kl: 3.8699625726223985e-09valid nll: 75.25865173339844, mse: 259.172119140625, local kl: 1.8615806102752686 global kl: 3.319569952253687e-09
it: 1550, train nll: 77.62882995605469, mse: 268.33197021484375, local kl: 4.153712272644043 global kl: 6.104295291464723e-09valid nll: 74.6070327758789, mse: 257.7219543457031, local kl: 1.7608585357666016 global kl: 3.8320422390825115e-09
it: 1600, train nll: 61.538211822509766, mse: 212.19032287597656, local kl: 2.307532787322998 global kl: 1.5075489745441928e-09valid nll: 71.24417877197266, mse: 246.42567443847656, local kl: 2.3156440258026123 global kl: 1.3135088572369114e-09
it: 1650, train nll: 89.43278503417969, mse: 310.0826110839844, local kl: 2.499373197555542 global kl: 2.6448536694090308e-09valid nll: 77.13996124267578, mse: 267.09320068359375, local kl: 2.5657002925872803 global kl: 2.339651139493526e-09
it: 1700, train nll: 70.66138458251953, mse: 244.83082580566406, local kl: 1.888595700263977 global kl: 1.2137414628909937e-08valid nll: 54.04180908203125, mse: 186.0751190185547, local kl: 1.357717514038086 global kl: 5.237378974953799e-09
it: 1750, train nll: 85.68579864501953, mse: 296.1614990234375, local kl: 2.8248543739318848 global kl: 6.498547033828572e-09valid nll: 80.56611633300781, mse: 278.2190856933594, local kl: 0.6889492869377136 global kl: 4.627778160681828e-09
it: 1800, train nll: 63.71576690673828, mse: 219.190185546875, local kl: 3.778057336807251 global kl: 1.4419785365760163e-08valid nll: 73.79653930664062, mse: 255.356689453125, local kl: 2.7846057415008545 global kl: 6.284953002477778e-09
it: 1850, train nll: 59.452178955078125, mse: 204.8419647216797, local kl: 3.453061580657959 global kl: 1.851964981369747e-08valid nll: 66.94640350341797, mse: 231.25009155273438, local kl: 1.3058580160140991 global kl: 1.7358361859010074e-08
it: 1900, train nll: 86.95545196533203, mse: 299.2074279785156, local kl: 1.0110441446304321 global kl: 1.965948293047859e-09valid nll: 94.6871566772461, mse: 326.21136474609375, local kl: 3.820974588394165 global kl: 2.0737160877359884e-09
it: 1950, train nll: 69.46497344970703, mse: 239.2873077392578, local kl: 5.587080955505371 global kl: 9.2384260241829e-09valid nll: 70.7953872680664, mse: 244.19085693359375, local kl: 4.670518398284912 global kl: 6.561188037323973e-09

Prior Predictive + NLL


In [0]:
# Launch pre-training with the hyperparameter sets currently in scope.
pretrain(data_hparams=data_hparams,
         model_hparams=model_hparams,
         training_hparams=training_hparams)


it: 0, train nll: 73.67438507080078, mse: 247.44496154785156, local kl: 0.8368964195251465 global kl: 0.002964801387861371valid nll: 61.58158493041992, mse: 223.068359375, local kl: 3.387556314468384 global kl: 0.011656454764306545
it: 50, train nll: 4.328117847442627, mse: 263.22772216796875, local kl: 3.845857620239258 global kl: 0.0013694826047867537valid nll: 4.294598579406738, mse: 259.070556640625, local kl: 3.2342352867126465 global kl: 0.0010290180798619986
it: 100, train nll: 3.747567892074585, mse: 292.353271484375, local kl: 5.592900276184082 global kl: 0.00027055011014454067valid nll: 3.7165119647979736, mse: 283.5585021972656, local kl: 2.8959457874298096 global kl: 0.00016625048010610044
it: 150, train nll: 3.6381709575653076, mse: 282.8884582519531, local kl: 3.68673038482666 global kl: 4.3551444832701236e-05valid nll: 3.6170289516448975, mse: 273.54779052734375, local kl: 3.87436842918396 global kl: 3.7984536902513355e-05
it: 200, train nll: 3.6732494831085205, mse: 267.37176513671875, local kl: 10615.046875 global kl: 8.492462802678347e-05valid nll: 3.5100505352020264, mse: 275.1871643066406, local kl: 9075.75390625 global kl: 3.428794298088178e-05
it: 250, train nll: 3.7324070930480957, mse: 314.98101806640625, local kl: 9.35401725769043 global kl: 2.0329113681327726e-07valid nll: 3.5243709087371826, mse: 240.9940643310547, local kl: 150.04251098632812 global kl: 1.4697643280214834e-07
it: 300, train nll: 3.4207077026367188, mse: 270.11419677734375, local kl: 249.7897186279297 global kl: 0.0valid nll: 3.560389280319214, mse: 301.7614440917969, local kl: 1227.628662109375 global kl: 0.0
it: 350, train nll: 5.344493389129639, mse: 243.17274475097656, local kl: 12940.1484375 global kl: 0.0valid nll: 3.5747857093811035, mse: 283.2994689941406, local kl: 799.0734252929688 global kl: 0.0
it: 400, train nll: 3.4565186500549316, mse: 308.4736022949219, local kl: 7066.10205078125 global kl: 0.0valid nll: 3.457401752471924, mse: 308.86749267578125, local kl: 4931.3916015625 global kl: 0.0
it: 450, train nll: 3.3996012210845947, mse: 302.0494689941406, local kl: 13149.12890625 global kl: 0.0valid nll: 3.343344211578369, mse: 235.35800170898438, local kl: 8093.013671875 global kl: 0.0
it: 500, train nll: 3.4189577102661133, mse: 278.50439453125, local kl: 815.7020874023438 global kl: 0.0valid nll: 3.414430618286133, mse: 283.69879150390625, local kl: 3044.797607421875 global kl: 0.0
it: 550, train nll: 4.058550834655762, mse: 293.2491455078125, local kl: 12455.0224609375 global kl: 0.0valid nll: 3.6099648475646973, mse: 239.6857147216797, local kl: 7805.18115234375 global kl: 0.0
it: 600, train nll: 3.382012367248535, mse: 253.41275024414062, local kl: 6556.66845703125 global kl: 0.0valid nll: 3.819643259048462, mse: 263.34906005859375, local kl: 34487.796875 global kl: 0.0
it: 650, train nll: 3.54530930519104, mse: 304.23541259765625, local kl: 29233.61328125 global kl: 0.0valid nll: 3.3641436100006104, mse: 209.0580291748047, local kl: 23959.650390625 global kl: 0.0
it: 700, train nll: 3.6240041255950928, mse: 315.43768310546875, local kl: 8626.8681640625 global kl: 0.0valid nll: 3.2750163078308105, mse: 181.36688232421875, local kl: 27081.173828125 global kl: 0.0
it: 750, train nll: 3.282343864440918, mse: 174.92100524902344, local kl: 26899.990234375 global kl: 0.0valid nll: 3.501723051071167, mse: 328.9114074707031, local kl: 25049.23046875 global kl: 0.0
it: 800, train nll: 3.3277060985565186, mse: 292.45318603515625, local kl: 26054.833984375 global kl: 0.0valid nll: 3.285163164138794, mse: 277.833251953125, local kl: 89896.90625 global kl: 0.0
it: 850, train nll: 3.658747434616089, mse: 304.57098388671875, local kl: 19510.169921875 global kl: 0.0valid nll: 3.5376126766204834, mse: 272.6236267089844, local kl: 20868.80078125 global kl: 0.0
it: 900, train nll: 3.422032117843628, mse: 261.6676025390625, local kl: 31940.009765625 global kl: 0.0valid nll: 3.415494918823242, mse: 269.5493469238281, local kl: 19789.0703125 global kl: 0.0
it: 950, train nll: 3.5429325103759766, mse: 288.5662536621094, local kl: 16307.6962890625 global kl: 0.0valid nll: 3.4671847820281982, mse: 249.974853515625, local kl: 19998.134765625 global kl: 0.0
it: 1000, train nll: 4.553430557250977, mse: 281.8724670410156, local kl: 40621.16015625 global kl: 0.0valid nll: 3.6854357719421387, mse: 306.9922790527344, local kl: 25967.947265625 global kl: 0.0
it: 1050, train nll: 3.3629562854766846, mse: 250.0182342529297, local kl: 13710.30859375 global kl: 0.0valid nll: 3.2494049072265625, mse: 207.43910217285156, local kl: 34388.359375 global kl: 0.0
it: 1100, train nll: 3.559967041015625, mse: 248.835205078125, local kl: 12408.8017578125 global kl: 0.0valid nll: 3.511564254760742, mse: 231.4795379638672, local kl: 26562.681640625 global kl: 0.0
it: 1150, train nll: 3.3173398971557617, mse: 285.4549560546875, local kl: 13595.9970703125 global kl: 0.0valid nll: 3.165963649749756, mse: 193.72756958007812, local kl: 58005.76953125 global kl: 0.0
it: 1200, train nll: 3.25348162651062, mse: 224.94085693359375, local kl: 3596.0400390625 global kl: 0.0valid nll: 3.250171422958374, mse: 239.1678924560547, local kl: 19046.216796875 global kl: 0.0
it: 1250, train nll: 3.3800406455993652, mse: 257.1576232910156, local kl: 6057.19384765625 global kl: 0.0valid nll: 3.330843925476074, mse: 250.9578857421875, local kl: 31005.404296875 global kl: 0.0
it: 1300, train nll: 3.6411006450653076, mse: 286.04364013671875, local kl: 27481.365234375 global kl: 0.0valid nll: 3.5975308418273926, mse: 270.7287902832031, local kl: 19458.806640625 global kl: 0.0
it: 1350, train nll: 3.310220241546631, mse: 248.2986602783203, local kl: 31432.6171875 global kl: 0.0valid nll: 3.3577637672424316, mse: 214.0373992919922, local kl: 31716.4140625 global kl: 0.0
it: 1400, train nll: 3.517072916030884, mse: 237.61749267578125, local kl: 23360.294921875 global kl: 0.0valid nll: 3.282334566116333, mse: 210.0449676513672, local kl: 23318.158203125 global kl: 0.0
it: 1450, train nll: 3.4225621223449707, mse: 305.62762451171875, local kl: 32183.8671875 global kl: 0.0valid nll: 3.402050733566284, mse: 306.8847351074219, local kl: 34493.296875 global kl: 0.0
it: 1500, train nll: 3.414238214492798, mse: 298.3011474609375, local kl: 12280.8798828125 global kl: 0.0valid nll: 3.3410964012145996, mse: 246.4779815673828, local kl: 26816.84375 global kl: 0.0
it: 1550, train nll: 3.5883591175079346, mse: 259.4383239746094, local kl: 21575.314453125 global kl: 0.0valid nll: 3.5218935012817383, mse: 269.5672302246094, local kl: 26368.8671875 global kl: 0.0
it: 1600, train nll: 3.3647680282592773, mse: 267.7734680175781, local kl: 25599.93359375 global kl: 0.0valid nll: 3.5629594326019287, mse: 331.22125244140625, local kl: 66531.671875 global kl: 0.0
it: 1650, train nll: 3.348034620285034, mse: 240.50941467285156, local kl: 41227.42578125 global kl: 0.0valid nll: 3.43697190284729, mse: 279.41656494140625, local kl: 20541.47265625 global kl: 0.0
it: 1700, train nll: 3.28129506111145, mse: 225.59510803222656, local kl: 37044.15234375 global kl: 0.0valid nll: 3.295877695083618, mse: 261.3631286621094, local kl: 50870.046875 global kl: 0.0
it: 1750, train nll: 3.55202579498291, mse: 257.55517578125, local kl: 30228.78515625 global kl: 0.0valid nll: 3.5055313110351562, mse: 308.5923156738281, local kl: 35936.20703125 global kl: 0.0
it: 1800, train nll: 3.2675533294677734, mse: 237.079345703125, local kl: 54332.90234375 global kl: 0.0valid nll: 3.273933172225952, mse: 279.4231262207031, local kl: 64749.0546875 global kl: 0.0
it: 1850, train nll: 3.116328477859497, mse: 241.0929412841797, local kl: 88280.1328125 global kl: 0.0valid nll: 3.196551561355591, mse: 272.121826171875, local kl: 110814.453125 global kl: 0.0
it: 1900, train nll: 3.0804624557495117, mse: 210.13845825195312, local kl: 73917.1953125 global kl: 0.0valid nll: 3.1712958812713623, mse: 287.32763671875, local kl: 134843.296875 global kl: 0.0
it: 1950, train nll: 3.178898334503174, mse: 289.71014404296875, local kl: 110371.0078125 global kl: 0.0valid nll: 3.308182716369629, mse: 236.907958984375, local kl: 130569.5078125 global kl: 0.0

Hide Run


In [0]:



Initializing model NeuroLinear-bnn.
Initializing model SNP - Attentive GP.
Initializing model SNP - Freeform.
Training NeuroLinear-bnn for 50 steps...
20 Training SNP - Freeform for 50 steps...
Average nll: 13.975089073181152, mse: 442.6430969238281, local kl: 0.0027215962763875723 global kl: 0.008305639028549194
Training NeuroLinear-bnn for 50 steps...
40 Training SNP - Freeform for 50 steps...
Average nll: 4.444005966186523, mse: 424.95819091796875, local kl: 0.05070192366838455 global kl: 0.006373311392962933
Training NeuroLinear-bnn for 50 steps...
60 Training SNP - Freeform for 50 steps...
Average nll: 4.522835731506348, mse: 495.9842834472656, local kl: 0.008277200162410736 global kl: 0.00125440105330199
Training NeuroLinear-bnn for 50 steps...
80 Training SNP - Freeform for 50 steps...
Average nll: 3.5025625228881836, mse: 274.245849609375, local kl: 0.023440886288881302 global kl: 0.0004953107563778758
Training NeuroLinear-bnn for 50 steps...
100 Training SNP - Freeform for 50 steps...
Average nll: 2.769019365310669, mse: 108.15879821777344, local kl: 0.009365015663206577 global kl: 1.700696157058701e-05
Training NeuroLinear-bnn for 50 steps...
120 Training SNP - Freeform for 50 steps...
Average nll: 2.9683678150177, mse: 64.54940795898438, local kl: 0.016004765406250954 global kl: 4.643177817342803e-06
Training NeuroLinear-bnn for 50 steps...
140 Training SNP - Freeform for 50 steps...
Average nll: 2.7586734294891357, mse: 58.559608459472656, local kl: 0.08529890328645706 global kl: 5.486677991939359e-07
Training NeuroLinear-bnn for 50 steps...
160 Training SNP - Freeform for 50 steps...
Average nll: 2.7814579010009766, mse: 50.45353698730469, local kl: 0.019919991493225098 global kl: 3.500949969748035e-05
Training NeuroLinear-bnn for 50 steps...
180 Training SNP - Freeform for 50 steps...
Average nll: 2.5930330753326416, mse: 32.366127014160156, local kl: 0.004955447278916836 global kl: 3.5279495023132768e-06
Training NeuroLinear-bnn for 50 steps...
200 Training SNP - Freeform for 50 steps...
Average nll: 2.4776883125305176, mse: 35.766151428222656, local kl: 17.357881546020508 global kl: 0.003143706126138568
Training NeuroLinear-bnn for 50 steps...
220 Training SNP - Freeform for 50 steps...
Average nll: 4.932832717895508, mse: 18.889719009399414, local kl: 0.8095366954803467 global kl: 0.00036865740548819304
Training NeuroLinear-bnn for 50 steps...
240 Training SNP - Freeform for 50 steps...
Average nll: 3.280557155609131, mse: 27.871706008911133, local kl: 0.041029900312423706 global kl: 3.5253407986601815e-05
Training NeuroLinear-bnn for 50 steps...
260 Training SNP - Freeform for 50 steps...
Average nll: 2.4722414016723633, mse: 32.39585876464844, local kl: 4.293435573577881 global kl: 0.0011978293769061565
Training NeuroLinear-bnn for 50 steps...
280 Training SNP - Freeform for 50 steps...
Average nll: 2.5823006629943848, mse: 53.762760162353516, local kl: 0.020922159776091576 global kl: 0.0010096329497173429
Training NeuroLinear-bnn for 50 steps...
300 Training SNP - Freeform for 50 steps...
Average nll: 2.115488052368164, mse: 26.188283920288086, local kl: 6.6019697189331055 global kl: 0.0037223524414002895
Training NeuroLinear-bnn for 50 steps...
320 Training SNP - Freeform for 50 steps...
Average nll: 2.193331003189087, mse: 23.847461700439453, local kl: 0.07723837345838547 global kl: 0.0013325790641829371
Training NeuroLinear-bnn for 50 steps...
340 Training SNP - Freeform for 50 steps...
Average nll: 2.1533865928649902, mse: 36.54912567138672, local kl: 8.138694763183594 global kl: 0.0015361069235950708
Training NeuroLinear-bnn for 50 steps...
360 Training SNP - Freeform for 50 steps...
Average nll: 1.9415419101715088, mse: 31.112167358398438, local kl: 0.07373294234275818 global kl: 0.00014575956447515637
Training NeuroLinear-bnn for 50 steps...
380 Training SNP - Freeform for 50 steps...
Average nll: 2.040236234664917, mse: 25.81932830810547, local kl: 0.035705287009477615 global kl: 0.0022807735949754715
Training NeuroLinear-bnn for 50 steps...
400 Training SNP - Freeform for 50 steps...
Average nll: 3.1283910274505615, mse: 37.97145080566406, local kl: 18.858545303344727 global kl: 0.002568727359175682
Training NeuroLinear-bnn for 50 steps...
420 Training SNP - Freeform for 50 steps...
Average nll: 2.300389528274536, mse: 44.353050231933594, local kl: 0.014566467143595219 global kl: 3.501701939967461e-05
Training NeuroLinear-bnn for 50 steps...
440 Training SNP - Freeform for 50 steps...
Average nll: 1.9846007823944092, mse: 39.1985969543457, local kl: 0.020711174234747887 global kl: 0.00032125250436365604
Training NeuroLinear-bnn for 50 steps...
460 Training SNP - Freeform for 50 steps...
Average nll: 2.0283517837524414, mse: 35.242366790771484, local kl: 0.02512827329337597 global kl: 3.1139006750890985e-05
Training NeuroLinear-bnn for 50 steps...
480 Training SNP - Freeform for 50 steps...
Average nll: 1.7352584600448608, mse: 21.958864212036133, local kl: 1.7142999172210693 global kl: 76.54103088378906
Training NeuroLinear-bnn for 50 steps...
500 Training SNP - Freeform for 50 steps...
Average nll: 1.9911788702011108, mse: 40.360435485839844, local kl: 0.29962867498397827 global kl: 0.00011550405906746164
Training NeuroLinear-bnn for 50 steps...
520 Training SNP - Freeform for 50 steps...
Average nll: 1.6980948448181152, mse: 19.357664108276367, local kl: 0.04216445982456207 global kl: 19648.212890625
Training NeuroLinear-bnn for 50 steps...
540 Training SNP - Freeform for 50 steps...
Average nll: 1.5290945768356323, mse: 18.527320861816406, local kl: 7.796933174133301 global kl: 0.00019222711853217334
Training NeuroLinear-bnn for 50 steps...
560 Training SNP - Freeform for 50 steps...
Average nll: 1.7833492755889893, mse: 20.800573348999023, local kl: 9.195223808288574 global kl: 2.387622833251953
Training NeuroLinear-bnn for 50 steps...
580 Training SNP - Freeform for 50 steps...
Average nll: 1.5191916227340698, mse: 20.902973175048828, local kl: 3.443289279937744 global kl: 0.0018843275029212236
Training NeuroLinear-bnn for 50 steps...
600 Training SNP - Freeform for 50 steps...
Average nll: 1.5750218629837036, mse: 22.257349014282227, local kl: 0.019717397168278694 global kl: 5.0978520448552445e-05
Training NeuroLinear-bnn for 50 steps...
620 Training SNP - Freeform for 50 steps...
Average nll: 1.476322054862976, mse: 27.610210418701172, local kl: 0.13884371519088745 global kl: 0.002854557242244482
Training NeuroLinear-bnn for 50 steps...
640 Training SNP - Freeform for 50 steps...
Average nll: 1.4065780639648438, mse: 24.740699768066406, local kl: 0.043482422828674316 global kl: 0.00029008020646870136
Training NeuroLinear-bnn for 50 steps...
660 Training SNP - Freeform for 50 steps...
Average nll: 1.5457658767700195, mse: 22.425884246826172, local kl: 0.018515516072511673 global kl: 1.4171633893056423e-06
Training NeuroLinear-bnn for 50 steps...
680 Training SNP - Freeform for 50 steps...
Average nll: 1.434037446975708, mse: 25.857213973999023, local kl: 0.028475992381572723 global kl: 1.3347479352887603e-06
Training NeuroLinear-bnn for 50 steps...
700 Training SNP - Freeform for 50 steps...
Average nll: 1.482252836227417, mse: 36.4525032043457, local kl: 0.02174651063978672 global kl: 0.0003475952835287899
Training NeuroLinear-bnn for 50 steps...
720 Training SNP - Freeform for 50 steps...
Average nll: 1.698920488357544, mse: 37.95124435424805, local kl: 0.08231242001056671 global kl: 1.842356869019568e-05
Training NeuroLinear-bnn for 50 steps...
740 Training SNP - Freeform for 50 steps...
Average nll: 1.4375579357147217, mse: 39.94460678100586, local kl: 0.12832635641098022 global kl: 0.002874696860089898
Training NeuroLinear-bnn for 50 steps...
760 Training SNP - Freeform for 50 steps...
Average nll: 1.3292423486709595, mse: 35.457244873046875, local kl: 0.042247071862220764 global kl: 5.653735570376739e-06
Training NeuroLinear-bnn for 50 steps...
780 Training SNP - Freeform for 50 steps...
Average nll: 1.5089184045791626, mse: 34.108665466308594, local kl: 0.03777702897787094 global kl: 3.890620519086951e-06
Training NeuroLinear-bnn for 50 steps...
800 Training SNP - Freeform for 50 steps...
Average nll: 1.2812187671661377, mse: 33.24834442138672, local kl: 0.22972632944583893 global kl: 2.8213597033754922e-05
Training NeuroLinear-bnn for 50 steps...
820 Training SNP - Freeform for 50 steps...
Average nll: 1.6621899604797363, mse: 38.94581985473633, local kl: 3.5695743560791016 global kl: 3.9418060623575e-05
Training NeuroLinear-bnn for 50 steps...
840 Training SNP - Freeform for 50 steps...
Average nll: 1.3448542356491089, mse: 40.36164855957031, local kl: 0.02576899342238903 global kl: 2.408310137980152e-06
Training NeuroLinear-bnn for 50 steps...
860 Training SNP - Freeform for 50 steps...
Average nll: 1.3344424962997437, mse: 44.45014572143555, local kl: 14.481250762939453 global kl: 7.077213922457304e-06
Training NeuroLinear-bnn for 50 steps...
880 Training SNP - Freeform for 50 steps...
Average nll: 1.624318242073059, mse: 44.65461730957031, local kl: 0.08831737190485 global kl: 0.0062095969915390015
Training NeuroLinear-bnn for 50 steps...
900 Training SNP - Freeform for 50 steps...
Average nll: 1.2640711069107056, mse: 47.49481964111328, local kl: 0.0587751567363739 global kl: 0.015977028757333755
Training NeuroLinear-bnn for 50 steps...
920 Training SNP - Freeform for 50 steps...
Average nll: 1.2552456855773926, mse: 44.54985809326172, local kl: 0.014569725841283798 global kl: 6.81055962559185e-06
Training NeuroLinear-bnn for 50 steps...
940 Training SNP - Freeform for 50 steps...
Average nll: 1.2278969287872314, mse: 41.48384094238281, local kl: 0.05156959593296051 global kl: 0.00035147301969118416
Training NeuroLinear-bnn for 50 steps...
960 Training SNP - Freeform for 50 steps...
Average nll: 1.2384288311004639, mse: 46.36746597290039, local kl: 0.13710714876651764 global kl: 0.0006778037059120834
Training NeuroLinear-bnn for 50 steps...
980 Training SNP - Freeform for 50 steps...
Average nll: 1.2712172269821167, mse: 47.43877410888672, local kl: 9.098155975341797 global kl: 0.8227764368057251
Training NeuroLinear-bnn for 50 steps...
1000 Training SNP - Freeform for 50 steps...
Average nll: 1.547909140586853, mse: 47.162628173828125, local kl: 0.055024076253175735 global kl: 0.3800041973590851
Training NeuroLinear-bnn for 50 steps...
1020 Training SNP - Freeform for 50 steps...
Average nll: 1.4069914817810059, mse: 45.119903564453125, local kl: 0.3066556453704834 global kl: 0.000619206577539444
Training NeuroLinear-bnn for 50 steps...
1040 Training SNP - Freeform for 50 steps...
Average nll: 1.4591575860977173, mse: 50.55404281616211, local kl: 0.017278315499424934 global kl: 0.0011456134961917996
Training NeuroLinear-bnn for 50 steps...
1060 Training SNP - Freeform for 50 steps...
Average nll: 1.2877830266952515, mse: 47.76518249511719, local kl: 0.05982470512390137 global kl: 27.61163902282715
Training NeuroLinear-bnn for 50 steps...
1080 Training SNP - Freeform for 50 steps...
Average nll: 1.229244589805603, mse: 49.284767150878906, local kl: 0.05252294987440109 global kl: 0.0005247529479674995
Training NeuroLinear-bnn for 50 steps...
1100 Training SNP - Freeform for 50 steps...
Average nll: 1.2072218656539917, mse: 45.87445831298828, local kl: 0.1934794783592224 global kl: 5.1214126870036125e-05
Training NeuroLinear-bnn for 50 steps...
1120 Training SNP - Freeform for 50 steps...
Average nll: 1.413419246673584, mse: 45.570796966552734, local kl: 8.4515962600708 global kl: 0.01385357417166233
Training NeuroLinear-bnn for 50 steps...
1140 Training SNP - Freeform for 50 steps...
Average nll: 4.240945816040039, mse: 62.059669494628906, local kl: 0.02337832935154438 global kl: 8.015317298770697e-09
Training NeuroLinear-bnn for 50 steps...
1160 Training SNP - Freeform for 50 steps...
Average nll: 1.3677847385406494, mse: 71.3122787475586, local kl: 0.04970349743962288 global kl: 0.00012125635839765891
Training NeuroLinear-bnn for 50 steps...
1180 Training SNP - Freeform for 50 steps...
Average nll: 1.238936185836792, mse: 60.37493133544922, local kl: 0.2969072461128235 global kl: 2.0242771370249102e-07
Training NeuroLinear-bnn for 50 steps...
1200 Training SNP - Freeform for 50 steps...
Average nll: 1.2066348791122437, mse: 68.67889404296875, local kl: 0.03136260807514191 global kl: 8.478330073558027e-07
Training NeuroLinear-bnn for 50 steps...
1220 Training SNP - Freeform for 50 steps...
Average nll: 1.2348915338516235, mse: 59.682289123535156, local kl: 0.054219502955675125 global kl: 0.004792483523488045
Training NeuroLinear-bnn for 50 steps...
1240 Training SNP - Freeform for 50 steps...
Average nll: 1.3336334228515625, mse: 50.19935607910156, local kl: 0.08662517368793488 global kl: 0.0014361967332661152
Training NeuroLinear-bnn for 50 steps...
1260 Training SNP - Freeform for 50 steps...
Average nll: 1.1693809032440186, mse: 55.28459167480469, local kl: 0.03194306790828705 global kl: 0.0001920744398375973
Training NeuroLinear-bnn for 50 steps...
1280 Training SNP - Freeform for 50 steps...
Average nll: 1.1875426769256592, mse: 52.477935791015625, local kl: 7.427121639251709 global kl: 4.3095657019875944e-05
Training NeuroLinear-bnn for 50 steps...
1300 Training SNP - Freeform for 50 steps...
Average nll: 1.1021891832351685, mse: 52.7716064453125, local kl: 0.29179251194000244 global kl: 1.4963989158900404e-08
Training NeuroLinear-bnn for 50 steps...
1320 Training SNP - Freeform for 50 steps...
Average nll: 1.0249452590942383, mse: 48.38215255737305, local kl: 0.03165148198604584 global kl: 3.2156851375475526e-05
Training NeuroLinear-bnn for 50 steps...
1340 Training SNP - Freeform for 50 steps...
Average nll: 1.0865187644958496, mse: 47.765869140625, local kl: 0.03037659078836441 global kl: 0.002690743189305067
Training NeuroLinear-bnn for 50 steps...
1360 Training SNP - Freeform for 50 steps...
Average nll: 1.0322026014328003, mse: 50.596229553222656, local kl: 678.6976318359375 global kl: 245940976.0
Training NeuroLinear-bnn for 50 steps...
1380 Training SNP - Freeform for 50 steps...
Average nll: 1.6090835332870483, mse: 52.163795471191406, local kl: 0.5968203544616699 global kl: 0.0002958212571684271
Training NeuroLinear-bnn for 50 steps...
1400 Training SNP - Freeform for 50 steps...
Average nll: 1.5129588842391968, mse: 79.08045196533203, local kl: 3.73641300201416 global kl: 0.007267542649060488
Training NeuroLinear-bnn for 50 steps...
1420 Training SNP - Freeform for 50 steps...
Average nll: 1.3456060886383057, mse: 57.1939697265625, local kl: 0.3890998959541321 global kl: 2.1252659280435182e-05
Training NeuroLinear-bnn for 50 steps...
1440 Training SNP - Freeform for 50 steps...
Average nll: 1.1238950490951538, mse: 66.13631439208984, local kl: 0.027808919548988342 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1460 Training SNP - Freeform for 50 steps...
Average nll: 1.5148496627807617, mse: 59.38129425048828, local kl: 0.053711604326963425 global kl: 0.0003901697346009314
Training NeuroLinear-bnn for 50 steps...
1480 Training SNP - Freeform for 50 steps...
Average nll: 1.1032111644744873, mse: 65.77779388427734, local kl: 0.20427539944648743 global kl: 0.0007842466002330184
Training NeuroLinear-bnn for 50 steps...
1500 Training SNP - Freeform for 50 steps...
Average nll: 1.251294732093811, mse: 61.354732513427734, local kl: 11.086437225341797 global kl: 0.18574115633964539
Training NeuroLinear-bnn for 50 steps...
1520 Training SNP - Freeform for 50 steps...
Average nll: 1.1537410020828247, mse: 75.30718994140625, local kl: 0.039187896996736526 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1540 Training SNP - Freeform for 50 steps...
Average nll: 1.1046643257141113, mse: 61.3782844543457, local kl: 4.843801498413086 global kl: 4.996725692762993e-05
Training NeuroLinear-bnn for 50 steps...
1560 Training SNP - Freeform for 50 steps...
Average nll: 1.1422967910766602, mse: 62.642337799072266, local kl: 0.05624907836318016 global kl: 2.9117538815626176e-06
Training NeuroLinear-bnn for 50 steps...
1580 Training SNP - Freeform for 50 steps...
Average nll: 1.1193926334381104, mse: 62.20458984375, local kl: 0.4646046757698059 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1600 Training SNP - Freeform for 50 steps...
Average nll: 1.0760167837142944, mse: 65.58612060546875, local kl: 2.625370740890503 global kl: 0.17062586545944214
Training NeuroLinear-bnn for 50 steps...
1620 Training SNP - Freeform for 50 steps...
Average nll: 1.1114128828048706, mse: 65.02207946777344, local kl: 0.023677146062254906 global kl: 2.2462938886746997e-06
Training NeuroLinear-bnn for 50 steps...
1640 Training SNP - Freeform for 50 steps...
Average nll: 1.0373274087905884, mse: 68.8541488647461, local kl: 0.13282358646392822 global kl: 2.837321744664223e-06
Training NeuroLinear-bnn for 50 steps...
1660 Training SNP - Freeform for 50 steps...
Average nll: 1.142235279083252, mse: 67.85621643066406, local kl: 0.024016940966248512 global kl: 0.004331720061600208
Training NeuroLinear-bnn for 50 steps...
1680 Training SNP - Freeform for 50 steps...
Average nll: 0.9843287467956543, mse: 62.755191802978516, local kl: 0.08569977432489395 global kl: 0.00014275932335294783
Training NeuroLinear-bnn for 50 steps...
1700 Training SNP - Freeform for 50 steps...
Average nll: 0.9326106905937195, mse: 62.681846618652344, local kl: 0.11164706945419312 global kl: 1.3081088923172501e-07
Training NeuroLinear-bnn for 50 steps...
1720 Training SNP - Freeform for 50 steps...
Average nll: 0.9784389734268188, mse: 58.6071662902832, local kl: 38.8541259765625 global kl: 5396131.0
Training NeuroLinear-bnn for 50 steps...
1740 Training SNP - Freeform for 50 steps...
Average nll: 0.9167380332946777, mse: 61.759483337402344, local kl: 36.67591094970703 global kl: 24051586.0
Training NeuroLinear-bnn for 50 steps...
1760 Training SNP - Freeform for 50 steps...
Average nll: 1.2178163528442383, mse: 64.04535675048828, local kl: 0.07610049843788147 global kl: 2.3965669697645353e-06
Training NeuroLinear-bnn for 50 steps...
1780 Training SNP - Freeform for 50 steps...
Average nll: 0.914000928401947, mse: 62.16566848754883, local kl: 0.01609613373875618 global kl: 3.357833975314861e-06
Training NeuroLinear-bnn for 50 steps...
1800 Training SNP - Freeform for 50 steps...
Average nll: 1.011247992515564, mse: 68.3225326538086, local kl: 0.02484613098204136 global kl: 1.16744240585831e-06
Training NeuroLinear-bnn for 50 steps...
1820 Training SNP - Freeform for 50 steps...
Average nll: 0.8883697390556335, mse: 63.564605712890625, local kl: 0.01767549104988575 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1840 Training SNP - Freeform for 50 steps...
Average nll: 0.9359519481658936, mse: 67.36384582519531, local kl: 0.40362972021102905 global kl: 1.0938882155642204e-07
Training NeuroLinear-bnn for 50 steps...
1860 Training SNP - Freeform for 50 steps...
Average nll: 0.9329845309257507, mse: 68.76295471191406, local kl: 77.73342895507812 global kl: 0.0008181382436305285
Training NeuroLinear-bnn for 50 steps...
1880 Training SNP - Freeform for 50 steps...
Average nll: 0.8714841604232788, mse: 62.087406158447266, local kl: 0.6080942749977112 global kl: 3.969682802562602e-05
Training NeuroLinear-bnn for 50 steps...
1900 Training SNP - Freeform for 50 steps...
Average nll: 0.8890625238418579, mse: 67.75786590576172, local kl: 0.048196956515312195 global kl: 0.0021836552768945694
Training NeuroLinear-bnn for 50 steps...
1920 Training SNP - Freeform for 50 steps...
Average nll: 0.8646968007087708, mse: 63.70701217651367, local kl: 0.02181943692266941 global kl: 1.0360118096741644e-08
Training NeuroLinear-bnn for 50 steps...
1940 Training SNP - Freeform for 50 steps...
Average nll: 0.8572739362716675, mse: 65.78986358642578, local kl: 1.0096087455749512 global kl: 6.059671401977539
Training NeuroLinear-bnn for 50 steps...
1960 Training SNP - Freeform for 50 steps...
Average nll: 1.2867435216903687, mse: 60.223960876464844, local kl: 0.06367649137973785 global kl: 1.759452459282329e-07
Training NeuroLinear-bnn for 50 steps...
1980 Training SNP - Freeform for 50 steps...
Average nll: 0.8369308710098267, mse: 63.93562316894531, local kl: 0.24088259041309357 global kl: 1.798391281226941e-06
Training NeuroLinear-bnn for 50 steps...
2000 Training SNP - Freeform for 50 steps...
Average nll: 0.8510993123054504, mse: 65.51788330078125, local kl: 0.026131683960556984 global kl: 5.146942072542515e-08
---------------------------------------------------
---------------------------------------------------
0.5_0 bandit completed after 517.9900827407837 seconds.
---------------------------------------------------
  0) NeuroLinear         |		 cummulative regret = 4009.4379190833724.
  1) SNP - Freeform      |		 cummulative regret = 5000.133844757513.
  2) MultitaskGP         |		 cummulative regret = 20509.78251206908.
  3) SNP - Attentive GP  |		 cummulative regret = 46110.88540529558.
  4) Uniform Sampling    |		 cummulative regret = 58334.638635839496.
---------------------------------------------------
---------------------------------------------------
  0) NeuroLinear         |		 simple regret = 410.3035823011785.
  1) SNP - Freeform      |		 simple regret = 698.4054956932118.
  2) MultitaskGP         |		 simple regret = 4707.528671919163.
  3) SNP - Attentive GP  |		 simple regret = 5673.204352687191.
  4) Uniform Sampling    |		 simple regret = 14899.927971777552.
---------------------------------------------------
---------------------------------------------------
Initializing model NeuroLinear-bnn.
Initializing model SNP - Attentive GP.
Initializing model SNP - Freeform.
Training NeuroLinear-bnn for 50 steps...
20 Training SNP - Freeform for 50 steps...
Average nll: 15.542445182800293, mse: 425.5062255859375, local kl: 0.00953707192093134 global kl: 0.019159924238920212
Training NeuroLinear-bnn for 50 steps...
40 Training SNP - Freeform for 50 steps...
Average nll: 4.5157599449157715, mse: 487.74578857421875, local kl: 0.008029019460082054 global kl: 0.003521548816934228
Training NeuroLinear-bnn for 50 steps...
60 Training SNP - Freeform for 50 steps...
Average nll: 4.46603536605835, mse: 443.0834045410156, local kl: 0.0140712670981884 global kl: 0.00040697638178244233
Training NeuroLinear-bnn for 50 steps...
80 Training SNP - Freeform for 50 steps...
Average nll: 3.0992956161499023, mse: 245.68540954589844, local kl: 1.2591826915740967 global kl: 8.79609797266312e-05
Training NeuroLinear-bnn for 50 steps...
100 Training SNP - Freeform for 50 steps...
Average nll: 1.7336217164993286, mse: 94.79383850097656, local kl: 4.591380596160889 global kl: 0.0001121577515732497
Training NeuroLinear-bnn for 50 steps...
120 Training SNP - Freeform for 50 steps...
Average nll: 46.1240119934082, mse: 146.36257934570312, local kl: 7.198675632476807 global kl: 2.2163812900544144e-05
Training NeuroLinear-bnn for 50 steps...
140 Training SNP - Freeform for 50 steps...
Average nll: 3.887735366821289, mse: 100.16881561279297, local kl: 0.4797371029853821 global kl: 4.516316039371304e-06
Training NeuroLinear-bnn for 50 steps...
160 Training SNP - Freeform for 50 steps...
Average nll: 3.3308680057525635, mse: 111.55757141113281, local kl: 6.446475028991699 global kl: 0.00031759421108290553
Training NeuroLinear-bnn for 50 steps...
180 Training SNP - Freeform for 50 steps...
Average nll: 2.537485361099243, mse: 123.11776733398438, local kl: 22.155160903930664 global kl: 0.00018240173812955618
Training NeuroLinear-bnn for 50 steps...
200 Training SNP - Freeform for 50 steps...
Average nll: 2.7867801189422607, mse: 103.87457275390625, local kl: 1.602138638496399 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
220 Training SNP - Freeform for 50 steps...
Average nll: 2.434086799621582, mse: 109.19169616699219, local kl: 2.946852207183838 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
240 Training SNP - Freeform for 50 steps...
Average nll: 2.5118660926818848, mse: 117.19535827636719, local kl: 1.372156023979187 global kl: 0.001934200176037848
Training NeuroLinear-bnn for 50 steps...
260 Training SNP - Freeform for 50 steps...
Average nll: 2.5934371948242188, mse: 120.79740142822266, local kl: 0.6277145147323608 global kl: 5.302341037349834e-07
Training NeuroLinear-bnn for 50 steps...
280 Training SNP - Freeform for 50 steps...
Average nll: 8.01050090789795, mse: 121.71282958984375, local kl: 0.4481833577156067 global kl: 1.1478559258648602e-07
Training NeuroLinear-bnn for 50 steps...
300 Training SNP - Freeform for 50 steps...
Average nll: 2.5684478282928467, mse: 129.2842254638672, local kl: 1.419995903968811 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
320 Training SNP - Freeform for 50 steps...
Average nll: 2.505589485168457, mse: 123.39518737792969, local kl: 0.7410065531730652 global kl: 0.0002807065029628575
Training NeuroLinear-bnn for 50 steps...
340 Training SNP - Freeform for 50 steps...
Average nll: 2.4901835918426514, mse: 120.65950012207031, local kl: 0.5472608804702759 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
360 Training SNP - Freeform for 50 steps...
Average nll: 2.4408977031707764, mse: 111.14115905761719, local kl: 0.7920466661453247 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
380 Training SNP - Freeform for 50 steps...
Average nll: 2.5151147842407227, mse: 115.68700408935547, local kl: 49319.125 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
400 Training SNP - Freeform for 50 steps...
Average nll: 2.3899805545806885, mse: 98.93582916259766, local kl: 0.4856698513031006 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
420 Training SNP - Freeform for 50 steps...
Average nll: 3.2279696464538574, mse: 114.90121459960938, local kl: 0.2538774609565735 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
440 Training SNP - Freeform for 50 steps...
Average nll: 3.278707265853882, mse: 109.36270141601562, local kl: 3058.288330078125 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
460 Training SNP - Freeform for 50 steps...
Average nll: 2.5420312881469727, mse: 101.93122100830078, local kl: 0.4104650020599365 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
480 Training SNP - Freeform for 50 steps...
Average nll: 2.598275661468506, mse: 111.90116119384766, local kl: 0.09307189285755157 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
500 Training SNP - Freeform for 50 steps...
Average nll: 2.5460731983184814, mse: 100.91336059570312, local kl: 131.8262481689453 global kl: 5.953659751867235e-07
Training NeuroLinear-bnn for 50 steps...
520 Training SNP - Freeform for 50 steps...
Average nll: 2.494614839553833, mse: 88.48028564453125, local kl: 6.043760776519775 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
540 Training SNP - Freeform for 50 steps...
Average nll: 2.533694267272949, mse: 89.06641387939453, local kl: 130.84466552734375 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
560 Training SNP - Freeform for 50 steps...
Average nll: 2.5802383422851562, mse: 85.60690307617188, local kl: 0.0875980481505394 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
580 Training SNP - Freeform for 50 steps...
Average nll: 2.488624334335327, mse: 84.81887817382812, local kl: 2.4249513149261475 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
600 Training SNP - Freeform for 50 steps...
Average nll: 2.5054638385772705, mse: 82.95972442626953, local kl: 10.620789527893066 global kl: 9.530267561785877e-05
Training NeuroLinear-bnn for 50 steps...
620 Training SNP - Freeform for 50 steps...
Average nll: 2.4664158821105957, mse: 86.43817138671875, local kl: 0.24324306845664978 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
640 Training SNP - Freeform for 50 steps...
Average nll: 2.466892957687378, mse: 82.25923156738281, local kl: 0.09455734491348267 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
660 Training SNP - Freeform for 50 steps...
Average nll: 2.6082980632781982, mse: 90.4946517944336, local kl: 0.10287132114171982 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
680 Training SNP - Freeform for 50 steps...
Average nll: 2.3994710445404053, mse: 79.16094970703125, local kl: 1.1385483741760254 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
700 Training SNP - Freeform for 50 steps...
Average nll: 2.4836831092834473, mse: 92.56512451171875, local kl: 7.845297813415527 global kl: 1.4748296734978794e-06
Training NeuroLinear-bnn for 50 steps...
720 Training SNP - Freeform for 50 steps...
Average nll: 2.536323308944702, mse: 79.20594024658203, local kl: 0.09799874573945999 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
740 Training SNP - Freeform for 50 steps...
Average nll: 2.4911441802978516, mse: 84.968505859375, local kl: 0.14120064675807953 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
760 Training SNP - Freeform for 50 steps...
Average nll: 2.484999418258667, mse: 78.57018280029297, local kl: 0.4608880877494812 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
780 Training SNP - Freeform for 50 steps...
Average nll: 2.4570565223693848, mse: 75.46786499023438, local kl: 0.12342169880867004 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
800 Training SNP - Freeform for 50 steps...
Average nll: 2.6004462242126465, mse: 86.81710815429688, local kl: 1.469112515449524 global kl: 0.013154028914868832
Training NeuroLinear-bnn for 50 steps...
820 Training SNP - Freeform for 50 steps...
Average nll: 2.4179601669311523, mse: 75.82429504394531, local kl: 0.16803784668445587 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
840 Training SNP - Freeform for 50 steps...
Average nll: 6.012154579162598, mse: 89.14623260498047, local kl: 0.13017645478248596 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
860 Training SNP - Freeform for 50 steps...
Average nll: 2.545712471008301, mse: 90.52790832519531, local kl: 0.03500144183635712 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
880 Training SNP - Freeform for 50 steps...
Average nll: 2.885044574737549, mse: 86.21686553955078, local kl: 0.4555032253265381 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
900 Training SNP - Freeform for 50 steps...
Average nll: 2.697780132293701, mse: 98.08269500732422, local kl: 0.07928407192230225 global kl: 7.580258056805178e-07
Training NeuroLinear-bnn for 50 steps...
920 Training SNP - Freeform for 50 steps...
Average nll: 2.5451886653900146, mse: 82.41370391845703, local kl: 2.0340566635131836 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
940 Training SNP - Freeform for 50 steps...
Average nll: 2.660517692565918, mse: 90.90184783935547, local kl: 2167.78076171875 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
960 Training SNP - Freeform for 50 steps...
Average nll: 2.4652822017669678, mse: 79.43191528320312, local kl: 0.2305336594581604 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
980 Training SNP - Freeform for 50 steps...
Average nll: 2.5180838108062744, mse: 80.74398803710938, local kl: 0.36226534843444824 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1000 Training SNP - Freeform for 50 steps...
Average nll: 2.466574192047119, mse: 75.80061340332031, local kl: 0.555678129196167 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1020 Training SNP - Freeform for 50 steps...
Average nll: 2.498965263366699, mse: 77.97142028808594, local kl: 0.09502742439508438 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1040 Training SNP - Freeform for 50 steps...
Average nll: 2.5466558933258057, mse: 87.09977722167969, local kl: 0.3061549961566925 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1060 Training SNP - Freeform for 50 steps...
Average nll: 2.5566511154174805, mse: 78.4106216430664, local kl: 1.0774052143096924 global kl: 0.0001986495772143826
Training NeuroLinear-bnn for 50 steps...
1080 Training SNP - Freeform for 50 steps...
Average nll: 2.6367347240448, mse: 92.28193664550781, local kl: 0.06911370158195496 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1100 Training SNP - Freeform for 50 steps...
Average nll: 2.5494544506073, mse: 78.69898223876953, local kl: 0.28762558102607727 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1120 Training SNP - Freeform for 50 steps...
Average nll: 2.581153154373169, mse: 85.97676849365234, local kl: 2301.600830078125 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1140 Training SNP - Freeform for 50 steps...
Average nll: 2.5969817638397217, mse: 82.68778991699219, local kl: 0.0886351615190506 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1160 Training SNP - Freeform for 50 steps...
Average nll: 2.5127952098846436, mse: 78.7823257446289, local kl: 0.09041357785463333 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1180 Training SNP - Freeform for 50 steps...
Average nll: 2.5766589641571045, mse: 83.60899353027344, local kl: 0.052699994295835495 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1200 Training SNP - Freeform for 50 steps...
Average nll: 2.4736194610595703, mse: 81.26473999023438, local kl: 0.06683254987001419 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1220 Training SNP - Freeform for 50 steps...
Average nll: 2.5148568153381348, mse: 77.37602233886719, local kl: 0.29025277495384216 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1240 Training SNP - Freeform for 50 steps...
Average nll: 2.6067633628845215, mse: 91.18827056884766, local kl: 159.32923889160156 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1260 Training SNP - Freeform for 50 steps...
Average nll: 2.5329275131225586, mse: 82.15402221679688, local kl: 0.8537045121192932 global kl: 6.220401246537222e-06
Training NeuroLinear-bnn for 50 steps...
1280 Training SNP - Freeform for 50 steps...
Average nll: 2.4504551887512207, mse: 80.07794189453125, local kl: 0.34913304448127747 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1300 Training SNP - Freeform for 50 steps...
Average nll: 2.524010181427002, mse: 83.48165893554688, local kl: 0.8546801209449768 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1320 Training SNP - Freeform for 50 steps...
Average nll: 2.4958341121673584, mse: 78.27450561523438, local kl: 0.8634567856788635 global kl: 0.012403187341988087
Training NeuroLinear-bnn for 50 steps...
1340 Training SNP - Freeform for 50 steps...
Average nll: 2.472973346710205, mse: 79.78190612792969, local kl: 0.09306897968053818 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1360 Training SNP - Freeform for 50 steps...
Average nll: 2.460934638977051, mse: 79.5693130493164, local kl: 0.17202846705913544 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1380 Training SNP - Freeform for 50 steps...
Average nll: 2.470592737197876, mse: 84.10237121582031, local kl: 0.7621386647224426 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1400 Training SNP - Freeform for 50 steps...
Average nll: 2.5739030838012695, mse: 84.83343505859375, local kl: 0.15908227860927582 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1420 Training SNP - Freeform for 50 steps...
Average nll: 2.5134711265563965, mse: 83.8968276977539, local kl: 2.2949817180633545 global kl: 0.0012195032322779298
Training NeuroLinear-bnn for 50 steps...
1440 Training SNP - Freeform for 50 steps...
Average nll: 2.5783307552337646, mse: 91.21034240722656, local kl: 437.72479248046875 global kl: 9.160035230326713e-11
Training NeuroLinear-bnn for 50 steps...
1460 Training SNP - Freeform for 50 steps...
Average nll: 2.496212959289551, mse: 87.36735534667969, local kl: 0.20834974944591522 global kl: 1.7147292510344414e-06
Training NeuroLinear-bnn for 50 steps...
1480 Training SNP - Freeform for 50 steps...
Average nll: 2.4920928478240967, mse: 82.41979217529297, local kl: 0.6421018242835999 global kl: 0.019356176257133484
Training NeuroLinear-bnn for 50 steps...
1500 Training SNP - Freeform for 50 steps...
Average nll: 2.466522455215454, mse: 82.34548950195312, local kl: 0.06472976505756378 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1520 Training SNP - Freeform for 50 steps...
Average nll: 2.508551836013794, mse: 76.9550552368164, local kl: 0.1921239048242569 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1540 Training SNP - Freeform for 50 steps...
Average nll: 2.5523040294647217, mse: 92.19522094726562, local kl: 0.2722793221473694 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1560 Training SNP - Freeform for 50 steps...
Average nll: 2.509793519973755, mse: 85.75205993652344, local kl: 0.09388987720012665 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1580 Training SNP - Freeform for 50 steps...
Average nll: 2.369264841079712, mse: 79.1649398803711, local kl: 0.16086989641189575 global kl: 2.1141463548701722e-06
Training NeuroLinear-bnn for 50 steps...
1600 Training SNP - Freeform for 50 steps...
Average nll: 2.5804455280303955, mse: 85.17894744873047, local kl: 0.469111829996109 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1620 Training SNP - Freeform for 50 steps...
Average nll: 2.4833147525787354, mse: 87.66256713867188, local kl: 0.17973394691944122 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1640 Training SNP - Freeform for 50 steps...
Average nll: 2.517977237701416, mse: 80.85000610351562, local kl: 1.0780458450317383 global kl: 1.2266938938410021e-05
Training NeuroLinear-bnn for 50 steps...
1660 Training SNP - Freeform for 50 steps...
Average nll: 2.5702948570251465, mse: 86.52263641357422, local kl: 0.1527978479862213 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1680 Training SNP - Freeform for 50 steps...
Average nll: 2.521615743637085, mse: 79.60993957519531, local kl: 2.958216905593872 global kl: 8.045671506806684e-07
Training NeuroLinear-bnn for 50 steps...
1700 Training SNP - Freeform for 50 steps...
Average nll: 2.548192024230957, mse: 83.77892303466797, local kl: 0.0827503427863121 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1720 Training SNP - Freeform for 50 steps...
Average nll: 2.541680097579956, mse: 81.85768127441406, local kl: 0.17460164427757263 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1740 Training SNP - Freeform for 50 steps...
Average nll: 2.521469831466675, mse: 81.28446960449219, local kl: 0.15289825201034546 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1760 Training SNP - Freeform for 50 steps...
Average nll: 2.4790079593658447, mse: 77.48767852783203, local kl: 0.08261960744857788 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1780 Training SNP - Freeform for 50 steps...
Average nll: 2.490173816680908, mse: 77.20011901855469, local kl: 0.5875442028045654 global kl: 0.024830427020788193
Training NeuroLinear-bnn for 50 steps...
1800 Training SNP - Freeform for 50 steps...
Average nll: 2.506850481033325, mse: 79.236328125, local kl: 0.09400438517332077 global kl: 8.847889088059446e-10
Training NeuroLinear-bnn for 50 steps...
1820 Training SNP - Freeform for 50 steps...
Average nll: 2.4625885486602783, mse: 77.95173645019531, local kl: 0.2568739056587219 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1840 Training SNP - Freeform for 50 steps...
Average nll: 2.4700162410736084, mse: 72.1074447631836, local kl: 0.2930769622325897 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1860 Training SNP - Freeform for 50 steps...
Average nll: 2.4429636001586914, mse: 76.0475082397461, local kl: 1.5933254957199097 global kl: 1.167979007732356e-05
Training NeuroLinear-bnn for 50 steps...
1880 Training SNP - Freeform for 50 steps...
Average nll: 2.4206085205078125, mse: 75.86341094970703, local kl: 0.0910179391503334 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1900 Training SNP - Freeform for 50 steps...
Average nll: 2.519989013671875, mse: 76.76382446289062, local kl: 0.11446163058280945 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1920 Training SNP - Freeform for 50 steps...
Average nll: 2.4756720066070557, mse: 79.7070541381836, local kl: 0.06238197162747383 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1940 Training SNP - Freeform for 50 steps...
Average nll: 2.3841474056243896, mse: 72.25299835205078, local kl: 0.12665769457817078 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1960 Training SNP - Freeform for 50 steps...
Average nll: 2.4721803665161133, mse: 80.3380355834961, local kl: 0.16634725034236908 global kl: 1.1109867728009704e-06
Training NeuroLinear-bnn for 50 steps...
1980 Training SNP - Freeform for 50 steps...
Average nll: 2.4127581119537354, mse: 73.75731658935547, local kl: 0.25556716322898865 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
2000 Training SNP - Freeform for 50 steps...
Average nll: 2.4823999404907227, mse: 84.21133422851562, local kl: 0.0958174541592598 global kl: 8.30409480840899e-05
---------------------------------------------------
---------------------------------------------------
0.5_1 bandit completed after 537.8340775966644 seconds.
---------------------------------------------------
  0) SNP - Freeform      |		 cummulative regret = 4400.197274833656.
  1) NeuroLinear         |		 cummulative regret = 6166.849028298944.
  2) MultitaskGP         |		 cummulative regret = 7391.461052616658.
  3) SNP - Attentive GP  |		 cummulative regret = 47862.754600343214.
  4) Uniform Sampling    |		 cummulative regret = 58336.0644026861.
---------------------------------------------------
---------------------------------------------------
  0) NeuroLinear         |		 simple regret = 70.41167503351433.
  1) SNP - Freeform      |		 simple regret = 155.08649122381692.
  2) MultitaskGP         |		 simple regret = 219.8613631487624.
  3) SNP - Attentive GP  |		 simple regret = 8207.887245995793.
  4) Uniform Sampling    |		 simple regret = 14557.405510451586.
---------------------------------------------------
---------------------------------------------------
Initializing model NeuroLinear-bnn.
Initializing model SNP - Attentive GP.
Initializing model SNP - Freeform.
Training NeuroLinear-bnn for 50 steps...
20 Training SNP - Freeform for 50 steps...
Average nll: 6.762989044189453, mse: 214.7725830078125, local kl: 0.009014076553285122 global kl: 0.079348623752594
Training NeuroLinear-bnn for 50 steps...
40 Training SNP - Freeform for 50 steps...
Average nll: 4.3497514724731445, mse: 351.95074462890625, local kl: 0.017598111182451248 global kl: 0.005329011473804712
Training NeuroLinear-bnn for 50 steps...
60 Training SNP - Freeform for 50 steps...
Average nll: 4.304250240325928, mse: 321.8254699707031, local kl: 0.02956279367208481 global kl: 0.0026277475990355015
Training NeuroLinear-bnn for 50 steps...
80 Training SNP - Freeform for 50 steps...
Average nll: 3.315439224243164, mse: 241.95907592773438, local kl: 0.33192184567451477 global kl: 0.001214540214277804
Training NeuroLinear-bnn for 50 steps...
100 Training SNP - Freeform for 50 steps...
Average nll: 5.689150333404541, mse: 105.4231948852539, local kl: 1.4692792892456055 global kl: 0.0006233938620425761
Training NeuroLinear-bnn for 50 steps...
120 Training SNP - Freeform for 50 steps...
Average nll: 3.065526008605957, mse: 135.17575073242188, local kl: 0.1069241464138031 global kl: 2.9135262593626976e-06
Training NeuroLinear-bnn for 50 steps...
140 Training SNP - Freeform for 50 steps...
Average nll: 2.770930528640747, mse: 131.21035766601562, local kl: 2.2439160346984863 global kl: 0.00014998266124166548
Training NeuroLinear-bnn for 50 steps...
160 Training SNP - Freeform for 50 steps...
Average nll: 2.4490301609039307, mse: 139.2729949951172, local kl: 0.07485726475715637 global kl: 4.286198773684191e-08
Training NeuroLinear-bnn for 50 steps...
180 Training SNP - Freeform for 50 steps...
Average nll: 2.561333656311035, mse: 150.57211303710938, local kl: 1.9248172044754028 global kl: 1.9805418105534045e-07
Training NeuroLinear-bnn for 50 steps...
200 Training SNP - Freeform for 50 steps...
Average nll: 2.4237701892852783, mse: 133.4380645751953, local kl: 0.37163153290748596 global kl: 2.2819945932894825e-09
Training NeuroLinear-bnn for 50 steps...
220 Training SNP - Freeform for 50 steps...
Average nll: 2.369955062866211, mse: 111.08694458007812, local kl: 0.06259825080633163 global kl: 4.376176576670332e-08
Training NeuroLinear-bnn for 50 steps...
240 Training SNP - Freeform for 50 steps...
Average nll: 3.899188280105591, mse: 95.82970428466797, local kl: 0.09448011219501495 global kl: 4.0313960880666855e-07
Training NeuroLinear-bnn for 50 steps...
260 Training SNP - Freeform for 50 steps...
Average nll: 2.382448434829712, mse: 105.3316421508789, local kl: 0.6598722338676453 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
280 Training SNP - Freeform for 50 steps...
Average nll: 2.434317111968994, mse: 98.8801498413086, local kl: 0.08823877573013306 global kl: 1.0838520836387033e-07
Training NeuroLinear-bnn for 50 steps...
300 Training SNP - Freeform for 50 steps...
Average nll: 2.4469103813171387, mse: 98.95320129394531, local kl: 0.44223907589912415 global kl: 2.2927457621335634e-07
Training NeuroLinear-bnn for 50 steps...
320 Training SNP - Freeform for 50 steps...
Average nll: 2.4142885208129883, mse: 91.06682586669922, local kl: 0.039241887629032135 global kl: 1.371990236975762e-07
Training NeuroLinear-bnn for 50 steps...
340 Training SNP - Freeform for 50 steps...
Average nll: 2.550179958343506, mse: 99.29718780517578, local kl: 0.03716794773936272 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
360 Training SNP - Freeform for 50 steps...
Average nll: 2.4591920375823975, mse: 85.4695816040039, local kl: 0.06256521493196487 global kl: 4.29877536589629e-06
Training NeuroLinear-bnn for 50 steps...
380 Training SNP - Freeform for 50 steps...
Average nll: 2.418712854385376, mse: 81.98641967773438, local kl: 0.2310071587562561 global kl: 1.4130506542642252e-08
Training NeuroLinear-bnn for 50 steps...
400 Training SNP - Freeform for 50 steps...
Average nll: 2.599780321121216, mse: 86.7420425415039, local kl: 0.061984118074178696 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
420 Training SNP - Freeform for 50 steps...
Average nll: 2.470144748687744, mse: 83.24882507324219, local kl: 0.05476594343781471 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
440 Training SNP - Freeform for 50 steps...
Average nll: 2.4837048053741455, mse: 88.95023345947266, local kl: 0.0405513197183609 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
460 Training SNP - Freeform for 50 steps...
Average nll: 2.528843879699707, mse: 84.18107604980469, local kl: 0.06610378623008728 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
480 Training SNP - Freeform for 50 steps...
Average nll: 2.4815561771392822, mse: 85.84759521484375, local kl: 0.038523975759744644 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
500 Training SNP - Freeform for 50 steps...
Average nll: 2.437922716140747, mse: 76.85533142089844, local kl: 0.11667510122060776 global kl: 2.2652882307738764e-06
Training NeuroLinear-bnn for 50 steps...
520 Training SNP - Freeform for 50 steps...
Average nll: 2.4311342239379883, mse: 76.2635498046875, local kl: 0.17845426499843597 global kl: 8.83914344740333e-06
Training NeuroLinear-bnn for 50 steps...
540 Training SNP - Freeform for 50 steps...
Average nll: 2.5301084518432617, mse: 87.07701110839844, local kl: 0.0325067937374115 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
560 Training SNP - Freeform for 50 steps...
Average nll: 2.496647596359253, mse: 77.90460968017578, local kl: 0.04132075607776642 global kl: 1.245910425495822e-06
Training NeuroLinear-bnn for 50 steps...
580 Training SNP - Freeform for 50 steps...
Average nll: 2.496426582336426, mse: 79.35222625732422, local kl: 0.0521785207092762 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
600 Training SNP - Freeform for 50 steps...
Average nll: 2.5144505500793457, mse: 79.97642517089844, local kl: 0.3343513011932373 global kl: 0.0001738459977786988
Training NeuroLinear-bnn for 50 steps...
620 Training SNP - Freeform for 50 steps...
Average nll: 2.519240140914917, mse: 76.24942779541016, local kl: 0.12258781492710114 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
640 Training SNP - Freeform for 50 steps...
Average nll: 2.506235361099243, mse: 82.65391540527344, local kl: 0.08341799676418304 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
660 Training SNP - Freeform for 50 steps...
Average nll: 2.5034279823303223, mse: 78.31695556640625, local kl: 0.03859831392765045 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
680 Training SNP - Freeform for 50 steps...
Average nll: 2.4392788410186768, mse: 87.88469696044922, local kl: 0.055660806596279144 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
700 Training SNP - Freeform for 50 steps...
Average nll: 2.4806089401245117, mse: 84.92369079589844, local kl: 0.029394112527370453 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
720 Training SNP - Freeform for 50 steps...
Average nll: 2.5007386207580566, mse: 90.47624206542969, local kl: 0.023544101044535637 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
740 Training SNP - Freeform for 50 steps...
Average nll: 2.4870729446411133, mse: 86.62966918945312, local kl: 0.0626775473356247 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
760 Training SNP - Freeform for 50 steps...
Average nll: 2.5209805965423584, mse: 87.1420669555664, local kl: 0.016190510243177414 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
780 Training SNP - Freeform for 50 steps...
Average nll: 2.4729764461517334, mse: 82.00289916992188, local kl: 0.19952057301998138 global kl: 5.93329532421194e-05
Training NeuroLinear-bnn for 50 steps...
800 Training SNP - Freeform for 50 steps...
Average nll: 2.5433578491210938, mse: 81.24165344238281, local kl: 0.04212842136621475 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
820 Training SNP - Freeform for 50 steps...
Average nll: 2.5218355655670166, mse: 83.71631622314453, local kl: 0.1260073035955429 global kl: 8.805084689811338e-06
Training NeuroLinear-bnn for 50 steps...
840 Training SNP - Freeform for 50 steps...
Average nll: 2.503514051437378, mse: 77.02679443359375, local kl: 0.032520465552806854 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
860 Training SNP - Freeform for 50 steps...
Average nll: 2.4501771926879883, mse: 79.34534454345703, local kl: 0.07871861755847931 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
880 Training SNP - Freeform for 50 steps...
Average nll: 2.5461297035217285, mse: 82.17642211914062, local kl: 0.19630132615566254 global kl: 3.561715493560769e-05
Training NeuroLinear-bnn for 50 steps...
900 Training SNP - Freeform for 50 steps...
Average nll: 2.544585943222046, mse: 89.6808090209961, local kl: 0.30691957473754883 global kl: 1.690315912128426e-05
Training NeuroLinear-bnn for 50 steps...
920 Training SNP - Freeform for 50 steps...
Average nll: 2.5919082164764404, mse: 84.73133850097656, local kl: 0.16433033347129822 global kl: 2.0417632185854018e-05
Training NeuroLinear-bnn for 50 steps...
940 Training SNP - Freeform for 50 steps...
Average nll: 2.461014747619629, mse: 76.85515594482422, local kl: 0.11747957020998001 global kl: 0.00021173486311454326
Training NeuroLinear-bnn for 50 steps...
960 Training SNP - Freeform for 50 steps...
Average nll: 2.6102118492126465, mse: 82.84262084960938, local kl: 0.02741057239472866 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
980 Training SNP - Freeform for 50 steps...
Average nll: 2.48056960105896, mse: 83.97003936767578, local kl: 0.04159300774335861 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1000 Training SNP - Freeform for 50 steps...
Average nll: 2.4719676971435547, mse: 83.337158203125, local kl: 0.048561498522758484 global kl: 9.533576303510927e-06
Training NeuroLinear-bnn for 50 steps...
1020 Training SNP - Freeform for 50 steps...
Average nll: 2.475013494491577, mse: 79.22117614746094, local kl: 0.09612203389406204 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1040 Training SNP - Freeform for 50 steps...
Average nll: 2.4653687477111816, mse: 78.95515441894531, local kl: 0.031009778380393982 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1060 Training SNP - Freeform for 50 steps...
Average nll: 2.4677786827087402, mse: 82.10443115234375, local kl: 0.08503440022468567 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1080 Training SNP - Freeform for 50 steps...
Average nll: 2.3404245376586914, mse: 77.38145446777344, local kl: 7.2137532234191895 global kl: 2.370094534853706e-06
Training NeuroLinear-bnn for 50 steps...
1100 Training SNP - Freeform for 50 steps...
Average nll: 2.3989615440368652, mse: 77.04975891113281, local kl: 0.09496324509382248 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1120 Training SNP - Freeform for 50 steps...
Average nll: 2.436159133911133, mse: 80.02452850341797, local kl: 0.09671230614185333 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1140 Training SNP - Freeform for 50 steps...
Average nll: 2.371269464492798, mse: 73.65385437011719, local kl: 0.4455569088459015 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1160 Training SNP - Freeform for 50 steps...
Average nll: 2.403400182723999, mse: 76.93775939941406, local kl: 0.2384614497423172 global kl: 0.0023917010985314846
Training NeuroLinear-bnn for 50 steps...
1180 Training SNP - Freeform for 50 steps...
Average nll: 2.458580255508423, mse: 79.45972442626953, local kl: 0.1202429011464119 global kl: 1.1136118700960651e-05
Training NeuroLinear-bnn for 50 steps...
1200 Training SNP - Freeform for 50 steps...
Average nll: 2.325488805770874, mse: 74.63617706298828, local kl: 0.38315218687057495 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1220 Training SNP - Freeform for 50 steps...
Average nll: 2.2881572246551514, mse: 68.76351165771484, local kl: 0.02426811307668686 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1240 Training SNP - Freeform for 50 steps...
Average nll: 2.3926384449005127, mse: 75.2972183227539, local kl: 0.12629875540733337 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1260 Training SNP - Freeform for 50 steps...
Average nll: 2.3967151641845703, mse: 77.35427856445312, local kl: 0.05266840383410454 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1280 Training SNP - Freeform for 50 steps...
Average nll: 2.340921640396118, mse: 78.1362075805664, local kl: 0.642682671546936 global kl: 0.03747912868857384
Training NeuroLinear-bnn for 50 steps...
1300 Training SNP - Freeform for 50 steps...
Average nll: 2.3940470218658447, mse: 76.64071655273438, local kl: 0.3409910500049591 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1320 Training SNP - Freeform for 50 steps...
Average nll: 2.4251232147216797, mse: 81.34961700439453, local kl: 544.973388671875 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1340 Training SNP - Freeform for 50 steps...
Average nll: 2.3872358798980713, mse: 72.29169464111328, local kl: 0.245911106467247 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1360 Training SNP - Freeform for 50 steps...
Average nll: 2.4676589965820312, mse: 84.63390350341797, local kl: 0.04746698960661888 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1380 Training SNP - Freeform for 50 steps...
Average nll: 2.3154044151306152, mse: 71.84961700439453, local kl: 0.13843174278736115 global kl: 2.8162554510657856e-09
Training NeuroLinear-bnn for 50 steps...
1400 Training SNP - Freeform for 50 steps...
Average nll: 2.338890552520752, mse: 73.0473861694336, local kl: 0.5962059497833252 global kl: 0.03882807120680809
Training NeuroLinear-bnn for 50 steps...
1420 Training SNP - Freeform for 50 steps...
Average nll: 2.312070369720459, mse: 75.87211608886719, local kl: 0.05452711135149002 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1440 Training SNP - Freeform for 50 steps...
Average nll: 2.400763988494873, mse: 80.98008728027344, local kl: 0.06152292340993881 global kl: 6.164119099594245e-07
Training NeuroLinear-bnn for 50 steps...
1460 Training SNP - Freeform for 50 steps...
Average nll: 2.364076614379883, mse: 72.4758071899414, local kl: 0.06580064445734024 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1480 Training SNP - Freeform for 50 steps...
Average nll: 2.380014657974243, mse: 70.10009765625, local kl: 0.040874890983104706 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1500 Training SNP - Freeform for 50 steps...
Average nll: 2.3729159832000732, mse: 76.74051666259766, local kl: 0.09608286619186401 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1520 Training SNP - Freeform for 50 steps...
Average nll: 2.3933136463165283, mse: 75.1297836303711, local kl: 0.04167759045958519 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1540 Training SNP - Freeform for 50 steps...
Average nll: 2.306975841522217, mse: 74.13919067382812, local kl: 0.07555060088634491 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1560 Training SNP - Freeform for 50 steps...
Average nll: 2.2829926013946533, mse: 74.32135009765625, local kl: 0.04649477079510689 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1580 Training SNP - Freeform for 50 steps...
Average nll: 2.3565750122070312, mse: 74.21533966064453, local kl: 0.033183760941028595 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1600 Training SNP - Freeform for 50 steps...
Average nll: 2.222683906555176, mse: 73.7978515625, local kl: 0.02767210453748703 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1620 Training SNP - Freeform for 50 steps...
Average nll: 2.288482666015625, mse: 72.7662582397461, local kl: 0.040828824043273926 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1640 Training SNP - Freeform for 50 steps...
Average nll: 2.2241344451904297, mse: 69.76924133300781, local kl: 0.043106451630592346 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1660 Training SNP - Freeform for 50 steps...
Average nll: 2.4423744678497314, mse: 71.35629272460938, local kl: 0.6000491976737976 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1680 Training SNP - Freeform for 50 steps...
Average nll: 2.3203840255737305, mse: 71.55393981933594, local kl: 0.03126922994852066 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1700 Training SNP - Freeform for 50 steps...
Average nll: 2.1111583709716797, mse: 71.5715103149414, local kl: 0.09920699149370193 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1720 Training SNP - Freeform for 50 steps...
Average nll: 2.187767744064331, mse: 69.27469635009766, local kl: 0.025130001828074455 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1740 Training SNP - Freeform for 50 steps...
Average nll: 2.1399054527282715, mse: 68.20899200439453, local kl: 0.1247720792889595 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1760 Training SNP - Freeform for 50 steps...
Average nll: 2.1665897369384766, mse: 70.6439437866211, local kl: 0.17893457412719727 global kl: 1.1703091473691529e-07
Training NeuroLinear-bnn for 50 steps...
1780 Training SNP - Freeform for 50 steps...
Average nll: 2.060899257659912, mse: 65.0951919555664, local kl: 0.14709998667240143 global kl: 2.3335512651101453e-06
Training NeuroLinear-bnn for 50 steps...
1800 Training SNP - Freeform for 50 steps...
Average nll: 2.145251989364624, mse: 72.89887237548828, local kl: 0.18982811272144318 global kl: 0.0023250365629792213
Training NeuroLinear-bnn for 50 steps...
1820 Training SNP - Freeform for 50 steps...
Average nll: 2.094878673553467, mse: 69.4060287475586, local kl: 0.03262694552540779 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1840 Training SNP - Freeform for 50 steps...
Average nll: 2.0678935050964355, mse: 65.8034896850586, local kl: 0.035021159797906876 global kl: 2.26676752390631e-06
Training NeuroLinear-bnn for 50 steps...
1860 Training SNP - Freeform for 50 steps...
Average nll: 1.9390099048614502, mse: 72.90364837646484, local kl: 0.15293775498867035 global kl: 1.2139686987211462e-05
Training NeuroLinear-bnn for 50 steps...
1880 Training SNP - Freeform for 50 steps...
Average nll: 2.0140655040740967, mse: 79.13536834716797, local kl: 1.3491075038909912 global kl: 1.7953186670638388e-06
Training NeuroLinear-bnn for 50 steps...
1900 Training SNP - Freeform for 50 steps...
Average nll: 1.8685945272445679, mse: 74.06376647949219, local kl: 0.09578002244234085 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1920 Training SNP - Freeform for 50 steps...
Average nll: 1.8063064813613892, mse: 76.61641693115234, local kl: 0.21998757123947144 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1940 Training SNP - Freeform for 50 steps...
Average nll: 1.7261358499526978, mse: 71.34044647216797, local kl: 1.629981517791748 global kl: 0.0003638435446191579
Training NeuroLinear-bnn for 50 steps...
1960 Training SNP - Freeform for 50 steps...
Average nll: 1.7536935806274414, mse: 75.50250244140625, local kl: 0.39727577567100525 global kl: 1.0716850738390349e-05
Training NeuroLinear-bnn for 50 steps...
1980 Training SNP - Freeform for 50 steps...
Average nll: 1.7293601036071777, mse: 71.44905090332031, local kl: 0.3579493463039398 global kl: 0.11866024881601334
Training NeuroLinear-bnn for 50 steps...
2000 Training SNP - Freeform for 50 steps...
Average nll: 1.599573016166687, mse: 67.64166259765625, local kl: 0.084197998046875 global kl: 2.8525390050049282e-08
---------------------------------------------------
---------------------------------------------------
0.5_2 bandit completed after 512.6714601516724 seconds.
---------------------------------------------------
  0) MultitaskGP         |		 cummulative regret = 3710.090904048783.
  1) NeuroLinear         |		 cummulative regret = 4059.235858555016.
  2) SNP - Freeform      |		 cummulative regret = 4831.258420053347.
  3) SNP - Attentive GP  |		 cummulative regret = 39669.19544935711.
  4) Uniform Sampling    |		 cummulative regret = 58568.64997320331.
---------------------------------------------------
---------------------------------------------------
  0) SNP - Freeform      |		 simple regret = 298.7471948398794.
  1) MultitaskGP         |		 simple regret = 363.17361270717834.
  2) NeuroLinear         |		 simple regret = 559.5982303602179.
  3) SNP - Attentive GP  |		 simple regret = 6986.682675139774.
  4) Uniform Sampling    |		 simple regret = 14753.26243437628.
---------------------------------------------------
---------------------------------------------------
Initializing model NeuroLinear-bnn.
Initializing model SNP - Attentive GP.
Initializing model SNP - Freeform.
Training NeuroLinear-bnn for 50 steps...
20 Training SNP - Freeform for 50 steps...
Average nll: 8.469804763793945, mse: 385.7143859863281, local kl: 0.03697182238101959 global kl: 0.11189010739326477
Training NeuroLinear-bnn for 50 steps...
40 Training SNP - Freeform for 50 steps...
Average nll: 4.1665496826171875, mse: 246.1144561767578, local kl: 0.051761727780103683 global kl: 0.002506398828700185
Training NeuroLinear-bnn for 50 steps...
60 Training SNP - Freeform for 50 steps...
Average nll: 4.1206889152526855, mse: 219.07855224609375, local kl: 0.018485594540834427 global kl: 0.00021156406728550792
Training NeuroLinear-bnn for 50 steps...
80 Training SNP - Freeform for 50 steps...
Average nll: 3.6417298316955566, mse: 330.8223876953125, local kl: 0.05639087036252022 global kl: 0.005333320703357458
Training NeuroLinear-bnn for 50 steps...
100 Training SNP - Freeform for 50 steps...
Average nll: 7.065913200378418, mse: 494.8811340332031, local kl: 5.42623233795166 global kl: 0.0013536631595343351
Training NeuroLinear-bnn for 50 steps...
120 Training SNP - Freeform for 50 steps...
Average nll: 2.814758539199829, mse: 511.0781555175781, local kl: 0.20083731412887573 global kl: 0.00042601354653015733
Training NeuroLinear-bnn for 50 steps...
140 Training SNP - Freeform for 50 steps...
Average nll: 15.709429740905762, mse: 582.5437622070312, local kl: 0.021206267178058624 global kl: 2.4377426598221064e-05
Training NeuroLinear-bnn for 50 steps...
160 Training SNP - Freeform for 50 steps...
Average nll: 20.598247528076172, mse: 637.6438598632812, local kl: 0.20252883434295654 global kl: 8.071483171079308e-05
Training NeuroLinear-bnn for 50 steps...
180 Training SNP - Freeform for 50 steps...
Average nll: 1.016862392425537, mse: 614.4688720703125, local kl: 0.17632007598876953 global kl: 2.1784708224004135e-05
Training NeuroLinear-bnn for 50 steps...
200 Training SNP - Freeform for 50 steps...
Average nll: 30.336631774902344, mse: 640.20166015625, local kl: 30.765247344970703 global kl: 0.0008678616723045707
Training NeuroLinear-bnn for 50 steps...
220 Training SNP - Freeform for 50 steps...
Average nll: 15.11189079284668, mse: 620.1327514648438, local kl: 0.03058759495615959 global kl: 4.275182163837599e-06
Training NeuroLinear-bnn for 50 steps...
240 Training SNP - Freeform for 50 steps...
Average nll: 10.910540580749512, mse: 649.9132690429688, local kl: 0.023568982258439064 global kl: 1.8912431187345646e-05
Training NeuroLinear-bnn for 50 steps...
260 Training SNP - Freeform for 50 steps...
Average nll: 8.605607032775879, mse: 630.341796875, local kl: 0.016612574458122253 global kl: 2.8264519187359838e-06
Training NeuroLinear-bnn for 50 steps...
280 Training SNP - Freeform for 50 steps...
Average nll: 0.3862566351890564, mse: 605.6221313476562, local kl: 0.009912621229887009 global kl: 1.7594770724826958e-06
Training NeuroLinear-bnn for 50 steps...
300 Training SNP - Freeform for 50 steps...
Average nll: 6.106593132019043, mse: 687.014892578125, local kl: 0.06444050371646881 global kl: 1.6811911336844787e-05
Training NeuroLinear-bnn for 50 steps...
320 Training SNP - Freeform for 50 steps...
Average nll: 0.5274044275283813, mse: 677.330322265625, local kl: 0.01867414079606533 global kl: 1.3446224045310373e-07
Training NeuroLinear-bnn for 50 steps...
340 Training SNP - Freeform for 50 steps...
Average nll: 0.7587118744850159, mse: 699.2109375, local kl: 0.018125412985682487 global kl: 2.157458993679029e-06
Training NeuroLinear-bnn for 50 steps...
360 Training SNP - Freeform for 50 steps...
Average nll: 0.8856523633003235, mse: 688.7711181640625, local kl: 0.04228678345680237 global kl: 4.2430679059179965e-07
Training NeuroLinear-bnn for 50 steps...
380 Training SNP - Freeform for 50 steps...
Average nll: 0.27647605538368225, mse: 672.0131225585938, local kl: 0.010061000473797321 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
400 Training SNP - Freeform for 50 steps...
Average nll: 0.30159667134284973, mse: 703.2664794921875, local kl: 0.0199296735227108 global kl: 1.1211265871224896e-07
Training NeuroLinear-bnn for 50 steps...
420 Training SNP - Freeform for 50 steps...
Average nll: 0.2558387517929077, mse: 686.112548828125, local kl: 11.708724021911621 global kl: 0.00033384139533154666
Training NeuroLinear-bnn for 50 steps...
440 Training SNP - Freeform for 50 steps...
Average nll: 0.34075790643692017, mse: 721.41357421875, local kl: 0.0436224490404129 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
460 Training SNP - Freeform for 50 steps...
Average nll: 0.21254920959472656, mse: 693.0042114257812, local kl: 0.015062539838254452 global kl: 3.260296580265276e-06
Training NeuroLinear-bnn for 50 steps...
480 Training SNP - Freeform for 50 steps...
Average nll: 12.153940200805664, mse: 709.9912719726562, local kl: 0.08812890946865082 global kl: 1.1703828022291418e-05
Training NeuroLinear-bnn for 50 steps...
500 Training SNP - Freeform for 50 steps...
Average nll: 0.20203231275081635, mse: 689.7432250976562, local kl: 6.702854633331299 global kl: 0.0005063716089352965
Training NeuroLinear-bnn for 50 steps...
520 Training SNP - Freeform for 50 steps...
Average nll: 2.1329941749572754, mse: 679.914306640625, local kl: 0.12381181120872498 global kl: 8.918105595512316e-05
Training NeuroLinear-bnn for 50 steps...
540 Training SNP - Freeform for 50 steps...
Average nll: 0.3669187128543854, mse: 710.7288818359375, local kl: 0.025872519239783287 global kl: 4.19378096694345e-07
Training NeuroLinear-bnn for 50 steps...
560 Training SNP - Freeform for 50 steps...
Average nll: 22.571847915649414, mse: 687.9475708007812, local kl: 0.29399001598358154 global kl: 9.961795876733959e-05
Training NeuroLinear-bnn for 50 steps...
580 Training SNP - Freeform for 50 steps...
Average nll: 0.06673765927553177, mse: 693.5540771484375, local kl: 0.10529302805662155 global kl: 1.6714579032850452e-05
Training NeuroLinear-bnn for 50 steps...
600 Training SNP - Freeform for 50 steps...
Average nll: 0.19089734554290771, mse: 733.1436767578125, local kl: 0.06817366927862167 global kl: 1.5880134014878422e-05
Training NeuroLinear-bnn for 50 steps...
620 Training SNP - Freeform for 50 steps...
Average nll: 0.20964315533638, mse: 740.6573486328125, local kl: 0.03152942284941673 global kl: 1.7307482949036057e-06
Training NeuroLinear-bnn for 50 steps...
640 Training SNP - Freeform for 50 steps...
Average nll: 0.18468773365020752, mse: 746.8875122070312, local kl: 0.01204878930002451 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
660 Training SNP - Freeform for 50 steps...
Average nll: 0.08196114748716354, mse: 729.2963256835938, local kl: 351.2696228027344 global kl: 0.0014794159214943647
Training NeuroLinear-bnn for 50 steps...
680 Training SNP - Freeform for 50 steps...
Average nll: 0.25487491488456726, mse: 780.6353149414062, local kl: 0.01487412117421627 global kl: 2.2376018815606358e-09
Training NeuroLinear-bnn for 50 steps...
700 Training SNP - Freeform for 50 steps...
Average nll: 0.2337239384651184, mse: 749.6587524414062, local kl: 7.351313591003418 global kl: 0.0004282761365175247
Training NeuroLinear-bnn for 50 steps...
720 Training SNP - Freeform for 50 steps...
Average nll: 0.20116034150123596, mse: 745.599609375, local kl: 0.3370068371295929 global kl: 2.3631412204849767e-07
Training NeuroLinear-bnn for 50 steps...
740 Training SNP - Freeform for 50 steps...
Average nll: 0.1486469805240631, mse: 742.18896484375, local kl: 0.04076932743191719 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
760 Training SNP - Freeform for 50 steps...
Average nll: 0.12430570274591446, mse: 733.4408569335938, local kl: 255.0469207763672 global kl: 0.0009351305780000985
Training NeuroLinear-bnn for 50 steps...
780 Training SNP - Freeform for 50 steps...
Average nll: 0.18314588069915771, mse: 718.9054565429688, local kl: 0.20975780487060547 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
800 Training SNP - Freeform for 50 steps...
Average nll: 0.035707227885723114, mse: 726.5261840820312, local kl: 361.3553161621094 global kl: 0.0010061279172077775
Training NeuroLinear-bnn for 50 steps...
820 Training SNP - Freeform for 50 steps...
Average nll: 0.04773031547665596, mse: 739.018310546875, local kl: 0.07304809987545013 global kl: 7.874069524405058e-06
Training NeuroLinear-bnn for 50 steps...
840 Training SNP - Freeform for 50 steps...
Average nll: 0.15424048900604248, mse: 763.5274047851562, local kl: 0.024140607565641403 global kl: 3.728839459427036e-08
Training NeuroLinear-bnn for 50 steps...
860 Training SNP - Freeform for 50 steps...
Average nll: 0.2233135998249054, mse: 775.7427368164062, local kl: 0.024403659626841545 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
880 Training SNP - Freeform for 50 steps...
Average nll: 0.20524728298187256, mse: 782.9637451171875, local kl: 344.95294189453125 global kl: 0.0009018712444230914
Training NeuroLinear-bnn for 50 steps...
900 Training SNP - Freeform for 50 steps...
Average nll: 0.06452694535255432, mse: 771.7349853515625, local kl: 209.72154235839844 global kl: 0.000844875699840486
Training NeuroLinear-bnn for 50 steps...
920 Training SNP - Freeform for 50 steps...
Average nll: 0.16956889629364014, mse: 791.4046630859375, local kl: 0.12535814940929413 global kl: 6.789213512092829e-05
Training NeuroLinear-bnn for 50 steps...
940 Training SNP - Freeform for 50 steps...
Average nll: 0.15319541096687317, mse: 788.3909912109375, local kl: 3.432826519012451 global kl: 0.0003927871584892273
Training NeuroLinear-bnn for 50 steps...
960 Training SNP - Freeform for 50 steps...
Average nll: 0.1369677037000656, mse: 790.5662231445312, local kl: 0.020518573001027107 global kl: 3.077410104523892e-09
Training NeuroLinear-bnn for 50 steps...
980 Training SNP - Freeform for 50 steps...
Average nll: 0.10261232405900955, mse: 784.2221069335938, local kl: 0.019505584612488747 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1000 Training SNP - Freeform for 50 steps...
Average nll: 0.11958623677492142, mse: 796.3883056640625, local kl: 0.02145487628877163 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1020 Training SNP - Freeform for 50 steps...
Average nll: 0.02492440678179264, mse: 770.4276733398438, local kl: 0.026451867073774338 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1040 Training SNP - Freeform for 50 steps...
Average nll: 0.01983478292822838, mse: 761.3025512695312, local kl: 0.0480802096426487 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1060 Training SNP - Freeform for 50 steps...
Average nll: 0.15613719820976257, mse: 826.8607177734375, local kl: 0.3528585731983185 global kl: 6.141266203485429e-05
Training NeuroLinear-bnn for 50 steps...
1080 Training SNP - Freeform for 50 steps...
Average nll: 0.13035529851913452, mse: 827.709716796875, local kl: 0.008243096061050892 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1100 Training SNP - Freeform for 50 steps...
Average nll: -0.025763647630810738, mse: 772.1277465820312, local kl: 0.013016884215176105 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1120 Training SNP - Freeform for 50 steps...
Average nll: 0.0343334898352623, mse: 794.9816284179688, local kl: 0.047606296837329865 global kl: 2.4770915842964314e-06
Training NeuroLinear-bnn for 50 steps...
1140 Training SNP - Freeform for 50 steps...
Average nll: 0.05412505194544792, mse: 799.9197387695312, local kl: 0.08855707198381424 global kl: 1.158480671392681e-07
Training NeuroLinear-bnn for 50 steps...
1160 Training SNP - Freeform for 50 steps...
Average nll: 0.13974064588546753, mse: 786.65771484375, local kl: 0.01721012406051159 global kl: 3.024620137681566e-11
Training NeuroLinear-bnn for 50 steps...
1180 Training SNP - Freeform for 50 steps...
Average nll: 0.12245149910449982, mse: 801.967041015625, local kl: 0.03466854244470596 global kl: 6.106257188775999e-08
Training NeuroLinear-bnn for 50 steps...
1200 Training SNP - Freeform for 50 steps...
Average nll: 0.10301366448402405, mse: 811.2642211914062, local kl: 522.4153442382812 global kl: 0.0015709481667727232
Training NeuroLinear-bnn for 50 steps...
1220 Training SNP - Freeform for 50 steps...
Average nll: 0.15809296071529388, mse: 836.8775024414062, local kl: 1.2007520198822021 global kl: 0.00022568601707462221
Training NeuroLinear-bnn for 50 steps...
1240 Training SNP - Freeform for 50 steps...
Average nll: 0.030462348833680153, mse: 790.6102905273438, local kl: 0.0036838341038674116 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1260 Training SNP - Freeform for 50 steps...
Average nll: 0.034711893647909164, mse: 808.2363891601562, local kl: 0.01649845950305462 global kl: 6.2992821767693385e-06
Training NeuroLinear-bnn for 50 steps...
1280 Training SNP - Freeform for 50 steps...
Average nll: 0.0948353111743927, mse: 828.3158569335938, local kl: 0.012591954320669174 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1300 Training SNP - Freeform for 50 steps...
Average nll: 0.12569725513458252, mse: 830.8082885742188, local kl: 322.61163330078125 global kl: 0.0008940672851167619
Training NeuroLinear-bnn for 50 steps...
1320 Training SNP - Freeform for 50 steps...
Average nll: 0.17099082469940186, mse: 864.4717407226562, local kl: 0.012253552675247192 global kl: 8.418724917191867e-08
Training NeuroLinear-bnn for 50 steps...
1340 Training SNP - Freeform for 50 steps...
Average nll: 0.13808423280715942, mse: 845.8095092773438, local kl: 0.006422713398933411 global kl: 8.884167357336992e-08
Training NeuroLinear-bnn for 50 steps...
1360 Training SNP - Freeform for 50 steps...
Average nll: 0.004681806080043316, mse: 799.718505859375, local kl: 276.61553955078125 global kl: 0.0008395884069614112
Training NeuroLinear-bnn for 50 steps...
1380 Training SNP - Freeform for 50 steps...
Average nll: -0.0007882103091105819, mse: 807.2809448242188, local kl: 444.1826171875 global kl: 0.0015861805295571685
Training NeuroLinear-bnn for 50 steps...
1400 Training SNP - Freeform for 50 steps...
Average nll: -0.02669759839773178, mse: 813.4656982421875, local kl: 0.11705166101455688 global kl: 0.00013048438995610923
Training NeuroLinear-bnn for 50 steps...
1420 Training SNP - Freeform for 50 steps...
Average nll: -0.1022678092122078, mse: 782.0054931640625, local kl: 27.216474533081055 global kl: 0.0006369594484567642
Training NeuroLinear-bnn for 50 steps...
1440 Training SNP - Freeform for 50 steps...
Average nll: 0.00838814489543438, mse: 809.1652221679688, local kl: 0.01096532866358757 global kl: 3.901376643966614e-08
Training NeuroLinear-bnn for 50 steps...
1460 Training SNP - Freeform for 50 steps...
Average nll: -0.055504366755485535, mse: 788.7086181640625, local kl: 0.010156511329114437 global kl: 2.5397186576014974e-08
Training NeuroLinear-bnn for 50 steps...
1480 Training SNP - Freeform for 50 steps...
Average nll: -0.07218753546476364, mse: 802.5265502929688, local kl: 0.03222681209445 global kl: 1.7754032910488604e-08
Training NeuroLinear-bnn for 50 steps...
1500 Training SNP - Freeform for 50 steps...
Average nll: -0.019432691857218742, mse: 814.4749145507812, local kl: 158.7078399658203 global kl: 0.0007657367968931794
Training NeuroLinear-bnn for 50 steps...
1520 Training SNP - Freeform for 50 steps...
Average nll: -0.05015223100781441, mse: 805.677490234375, local kl: 3.474156379699707 global kl: 9.539146958559286e-06
Training NeuroLinear-bnn for 50 steps...
1540 Training SNP - Freeform for 50 steps...
Average nll: 0.00651239650323987, mse: 802.998291015625, local kl: 2.0893964767456055 global kl: 0.00019911144045181572
Training NeuroLinear-bnn for 50 steps...
1560 Training SNP - Freeform for 50 steps...
Average nll: 0.04143686592578888, mse: 816.07666015625, local kl: 158.56597900390625 global kl: 0.0007183463894762099
Training NeuroLinear-bnn for 50 steps...
1580 Training SNP - Freeform for 50 steps...
Average nll: -0.018533287569880486, mse: 817.668701171875, local kl: 48.18753433227539 global kl: 3.248643042752519e-05
Training NeuroLinear-bnn for 50 steps...
1600 Training SNP - Freeform for 50 steps...
Average nll: -0.09990566968917847, mse: 794.029052734375, local kl: 0.012857798486948013 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1620 Training SNP - Freeform for 50 steps...
Average nll: 0.0239001102745533, mse: 830.0220336914062, local kl: 0.008902414701879025 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1640 Training SNP - Freeform for 50 steps...
Average nll: 0.10714969784021378, mse: 851.4545288085938, local kl: 0.016431408002972603 global kl: 2.1883437284486718e-07
Training NeuroLinear-bnn for 50 steps...
1660 Training SNP - Freeform for 50 steps...
Average nll: -0.06721429526805878, mse: 799.8536987304688, local kl: 159.67190551757812 global kl: 3.3346834243275225e-05
Training NeuroLinear-bnn for 50 steps...
1680 Training SNP - Freeform for 50 steps...
Average nll: -0.04666401073336601, mse: 815.4701538085938, local kl: 187.19296264648438 global kl: 0.0008260360336862504
Training NeuroLinear-bnn for 50 steps...
1700 Training SNP - Freeform for 50 steps...
Average nll: 1.4499342441558838, mse: 852.8408813476562, local kl: 0.01766195334494114 global kl: 1.6339408714927117e-09
Training NeuroLinear-bnn for 50 steps...
1720 Training SNP - Freeform for 50 steps...
Average nll: -0.07618095725774765, mse: 799.8136596679688, local kl: 0.12047290056943893 global kl: 8.243651245720685e-05
Training NeuroLinear-bnn for 50 steps...
1740 Training SNP - Freeform for 50 steps...
Average nll: 0.2046176940202713, mse: 859.5377197265625, local kl: 0.011587405577301979 global kl: 8.895976577605325e-08
Training NeuroLinear-bnn for 50 steps...
1760 Training SNP - Freeform for 50 steps...
Average nll: -0.042931266129016876, mse: 810.35205078125, local kl: 0.1258738934993744 global kl: 4.6372248263537585e-09
Training NeuroLinear-bnn for 50 steps...
1780 Training SNP - Freeform for 50 steps...
Average nll: -0.0032116728834807873, mse: 829.4442749023438, local kl: 0.00958559475839138 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1800 Training SNP - Freeform for 50 steps...
Average nll: -0.10371999442577362, mse: 802.870361328125, local kl: 0.01586952805519104 global kl: 5.1461506167527205e-09
Training NeuroLinear-bnn for 50 steps...
1820 Training SNP - Freeform for 50 steps...
Average nll: -0.005165600683540106, mse: 824.972412109375, local kl: 0.014603164978325367 global kl: 1.6819306836168835e-08
Training NeuroLinear-bnn for 50 steps...
1840 Training SNP - Freeform for 50 steps...
Average nll: 0.0007208072929643095, mse: 825.565185546875, local kl: 0.025233568623661995 global kl: 8.862914313567671e-08
Training NeuroLinear-bnn for 50 steps...
1860 Training SNP - Freeform for 50 steps...
Average nll: 10.957118034362793, mse: 825.9156494140625, local kl: 0.008402535691857338 global kl: 1.2250557901438697e-08
Training NeuroLinear-bnn for 50 steps...
1880 Training SNP - Freeform for 50 steps...
Average nll: 0.11294365674257278, mse: 854.5052490234375, local kl: 0.07494930922985077 global kl: 2.2683971110382117e-06
Training NeuroLinear-bnn for 50 steps...
1900 Training SNP - Freeform for 50 steps...
Average nll: 0.0005558079574257135, mse: 823.525634765625, local kl: 0.0171327106654644 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1920 Training SNP - Freeform for 50 steps...
Average nll: 0.03548343479633331, mse: 841.7930297851562, local kl: 0.01598227024078369 global kl: 6.619264247831325e-09
Training NeuroLinear-bnn for 50 steps...
1940 Training SNP - Freeform for 50 steps...
Average nll: 0.11785765737295151, mse: 856.412353515625, local kl: 0.019249582663178444 global kl: 1.053248752214131e-06
Training NeuroLinear-bnn for 50 steps...
1960 Training SNP - Freeform for 50 steps...
Average nll: 0.06859877705574036, mse: 851.1386108398438, local kl: 6.315991401672363 global kl: 0.0002743309596553445
Training NeuroLinear-bnn for 50 steps...
1980 Training SNP - Freeform for 50 steps...
Average nll: -0.019947126507759094, mse: 829.24169921875, local kl: 0.006706395652145147 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
2000 Training SNP - Freeform for 50 steps...
Average nll: -0.07115408033132553, mse: 818.594970703125, local kl: 0.012491252273321152 global kl: 0.0
---------------------------------------------------
---------------------------------------------------
0.5_3 bandit completed after 516.5220892429352 seconds.
---------------------------------------------------
  0) MultitaskGP         |		 cummulative regret = 4198.503588094754.
  1) NeuroLinear         |		 cummulative regret = 5429.774082329617.
  2) SNP - Freeform      |		 cummulative regret = 41017.60177582976.
  3) Uniform Sampling    |		 cummulative regret = 57834.78634855501.
  4) SNP - Attentive GP  |		 cummulative regret = 63553.97863533606.
---------------------------------------------------
---------------------------------------------------
  0) NeuroLinear         |		 simple regret = 265.9541908957134.
  1) MultitaskGP         |		 simple regret = 313.39311251015874.
  2) SNP - Freeform      |		 simple regret = 9568.463404914946.
  3) Uniform Sampling    |		 simple regret = 14849.097237935111.
  4) SNP - Attentive GP  |		 simple regret = 14985.755151359128.
---------------------------------------------------
---------------------------------------------------
Initializing model NeuroLinear-bnn.
Initializing model SNP - Attentive GP.
Initializing model SNP - Freeform.
Training NeuroLinear-bnn for 50 steps...
20 Training SNP - Freeform for 50 steps...
Average nll: 9.187387466430664, mse: 326.8687438964844, local kl: 0.0069914525374770164 global kl: 0.024303998798131943
Training NeuroLinear-bnn for 50 steps...
40 Training SNP - Freeform for 50 steps...
Average nll: 4.431732654571533, mse: 412.63104248046875, local kl: 0.029889430850744247 global kl: 0.07819225639104843
Training NeuroLinear-bnn for 50 steps...
60 Training SNP - Freeform for 50 steps...
Average nll: 4.494774341583252, mse: 468.8158264160156, local kl: 0.014713422395288944 global kl: 5.28622986166738e-05
Training NeuroLinear-bnn for 50 steps...
80 Training SNP - Freeform for 50 steps...
Average nll: 3.969992160797119, mse: 317.946533203125, local kl: 0.010595445521175861 global kl: 1.1282138075330295e-05
Training NeuroLinear-bnn for 50 steps...
100 Training SNP - Freeform for 50 steps...
Average nll: 2.416447877883911, mse: 129.27886962890625, local kl: 10.286591529846191 global kl: 0.0007502158987335861
Training NeuroLinear-bnn for 50 steps...
120 Training SNP - Freeform for 50 steps...
Average nll: 1.9345959424972534, mse: 79.11003112792969, local kl: 0.12935400009155273 global kl: 2.2998572148935637e-06
Training NeuroLinear-bnn for 50 steps...
140 Training SNP - Freeform for 50 steps...
Average nll: 2.1198625564575195, mse: 58.99985885620117, local kl: 0.10302992910146713 global kl: 3.8622513898189936e-07
Training NeuroLinear-bnn for 50 steps...
160 Training SNP - Freeform for 50 steps...
Average nll: 11.667994499206543, mse: 43.21083068847656, local kl: 0.06067385524511337 global kl: 5.047954942938304e-08
Training NeuroLinear-bnn for 50 steps...
180 Training SNP - Freeform for 50 steps...
Average nll: 1.8954777717590332, mse: 35.42949295043945, local kl: 6.730868339538574 global kl: 0.0006986829685047269
Training NeuroLinear-bnn for 50 steps...
200 Training SNP - Freeform for 50 steps...
Average nll: 1.444779396057129, mse: 22.14361572265625, local kl: 0.09808387607336044 global kl: 1.2222131751116194e-08
Training NeuroLinear-bnn for 50 steps...
220 Training SNP - Freeform for 50 steps...
Average nll: 6.255839824676514, mse: 30.908246994018555, local kl: 5.387927055358887 global kl: 0.0009755798382684588
Training NeuroLinear-bnn for 50 steps...
240 Training SNP - Freeform for 50 steps...
Average nll: 1.4140074253082275, mse: 27.47895050048828, local kl: 0.06740927696228027 global kl: 6.762293480733206e-08
Training NeuroLinear-bnn for 50 steps...
260 Training SNP - Freeform for 50 steps...
Average nll: 1.4406418800354004, mse: 28.319826126098633, local kl: 0.03518187254667282 global kl: 6.450637357602318e-08
Training NeuroLinear-bnn for 50 steps...
280 Training SNP - Freeform for 50 steps...
Average nll: 1.0797902345657349, mse: 22.94424819946289, local kl: 2.1650192737579346 global kl: 2.5292843019997235e-07
Training NeuroLinear-bnn for 50 steps...
300 Training SNP - Freeform for 50 steps...
Average nll: 1.0582540035247803, mse: 18.403514862060547, local kl: 0.29329150915145874 global kl: 9.56494403681063e-08
Training NeuroLinear-bnn for 50 steps...
320 Training SNP - Freeform for 50 steps...
Average nll: 2.846999406814575, mse: 35.1051025390625, local kl: 2.9217584133148193 global kl: 0.0009380662231706083
Training NeuroLinear-bnn for 50 steps...
340 Training SNP - Freeform for 50 steps...
Average nll: 1.0715335607528687, mse: 39.52189254760742, local kl: 0.12553748488426208 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
360 Training SNP - Freeform for 50 steps...
Average nll: 8.645820617675781, mse: 33.02529525756836, local kl: 31.4650936126709 global kl: 1.1307891327305697e-05
Training NeuroLinear-bnn for 50 steps...
380 Training SNP - Freeform for 50 steps...
Average nll: 1.1019527912139893, mse: 29.261886596679688, local kl: 7.654660224914551 global kl: 6.238990977180947e-07
Training NeuroLinear-bnn for 50 steps...
400 Training SNP - Freeform for 50 steps...
Average nll: 0.9376949071884155, mse: 23.969562530517578, local kl: 0.07794355601072311 global kl: 7.645888899787678e-07
Training NeuroLinear-bnn for 50 steps...
420 Training SNP - Freeform for 50 steps...
Average nll: 28.79251480102539, mse: 27.8355770111084, local kl: 79.95054626464844 global kl: 4.963601895724423e-06
Training NeuroLinear-bnn for 50 steps...
440 Training SNP - Freeform for 50 steps...
Average nll: 0.512233316898346, mse: 21.64344024658203, local kl: 0.17591504752635956 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
460 Training SNP - Freeform for 50 steps...
Average nll: 0.6284840106964111, mse: 27.36605453491211, local kl: 0.13901004195213318 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
480 Training SNP - Freeform for 50 steps...
Average nll: 23.341821670532227, mse: 14.710654258728027, local kl: 0.10027329623699188 global kl: 1.3915826002630638e-07
Training NeuroLinear-bnn for 50 steps...
500 Training SNP - Freeform for 50 steps...
Average nll: 0.5199447274208069, mse: 18.8010311126709, local kl: 0.1357332170009613 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
520 Training SNP - Freeform for 50 steps...
Average nll: 0.6779772043228149, mse: 19.312026977539062, local kl: 0.10403099656105042 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
540 Training SNP - Freeform for 50 steps...
Average nll: 1.2281795740127563, mse: 21.532209396362305, local kl: 0.089422807097435 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
560 Training SNP - Freeform for 50 steps...
Average nll: 0.42046183347702026, mse: 23.167654037475586, local kl: 0.06338537484407425 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
580 Training SNP - Freeform for 50 steps...
Average nll: 0.4428926408290863, mse: 26.381446838378906, local kl: 0.050988875329494476 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
600 Training SNP - Freeform for 50 steps...
Average nll: 0.4706740975379944, mse: 26.289072036743164, local kl: 0.05944948270916939 global kl: 6.180145817324956e-08
Training NeuroLinear-bnn for 50 steps...
620 Training SNP - Freeform for 50 steps...
Average nll: 0.3146068751811981, mse: 23.814733505249023, local kl: 111.82166290283203 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
640 Training SNP - Freeform for 50 steps...
Average nll: 0.3272228240966797, mse: 19.4210147857666, local kl: 0.3911190330982208 global kl: 4.7457487539759313e-07
Training NeuroLinear-bnn for 50 steps...
660 Training SNP - Freeform for 50 steps...
Average nll: 0.23126697540283203, mse: 28.572935104370117, local kl: 0.04601437970995903 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
680 Training SNP - Freeform for 50 steps...
Average nll: 0.18077541887760162, mse: 21.422279357910156, local kl: 1.7388794422149658 global kl: 0.0031148837879300117
Training NeuroLinear-bnn for 50 steps...
700 Training SNP - Freeform for 50 steps...
Average nll: 0.25404077768325806, mse: 22.11577606201172, local kl: 2.205124855041504 global kl: 0.008920364081859589
Training NeuroLinear-bnn for 50 steps...
720 Training SNP - Freeform for 50 steps...
Average nll: 0.13508033752441406, mse: 22.131811141967773, local kl: 0.0601552277803421 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
740 Training SNP - Freeform for 50 steps...
Average nll: 0.08062992990016937, mse: 20.97577667236328, local kl: 0.04520312324166298 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
760 Training SNP - Freeform for 50 steps...
Average nll: 0.07472066581249237, mse: 21.898391723632812, local kl: 228.5687713623047 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
780 Training SNP - Freeform for 50 steps...
Average nll: 0.12186294794082642, mse: 23.347583770751953, local kl: 0.07686148583889008 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
800 Training SNP - Freeform for 50 steps...
Average nll: -0.0016085715033113956, mse: 20.97601318359375, local kl: 1.1562159061431885 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
820 Training SNP - Freeform for 50 steps...
Average nll: 0.05235043540596962, mse: 21.8834171295166, local kl: 0.15145523846149445 global kl: 1.0187862244492862e-05
Training NeuroLinear-bnn for 50 steps...
840 Training SNP - Freeform for 50 steps...
Average nll: 0.08974037319421768, mse: 21.891897201538086, local kl: 0.5400445461273193 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
860 Training SNP - Freeform for 50 steps...
Average nll: -0.1437300741672516, mse: 17.8566951751709, local kl: 0.05723773315548897 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
880 Training SNP - Freeform for 50 steps...
Average nll: 0.02478533238172531, mse: 30.686864852905273, local kl: 0.2943938076496124 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
900 Training SNP - Freeform for 50 steps...
Average nll: -0.08602061122655869, mse: 17.519180297851562, local kl: 0.15503916144371033 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
920 Training SNP - Freeform for 50 steps...
Average nll: 0.04457126557826996, mse: 19.730606079101562, local kl: 0.040007248520851135 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
940 Training SNP - Freeform for 50 steps...
Average nll: -0.0012710165465250611, mse: 21.72254180908203, local kl: 0.062231048941612244 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
960 Training SNP - Freeform for 50 steps...
Average nll: -0.10777643322944641, mse: 20.128446578979492, local kl: 0.07828892767429352 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
980 Training SNP - Freeform for 50 steps...
Average nll: -0.18143931031227112, mse: 19.154930114746094, local kl: 0.13704447448253632 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1000 Training SNP - Freeform for 50 steps...
Average nll: 374.4565734863281, mse: 28.469385147094727, local kl: 0.02792995050549507 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1020 Training SNP - Freeform for 50 steps...
Average nll: -0.029962707310914993, mse: 23.704187393188477, local kl: 149.27047729492188 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1040 Training SNP - Freeform for 50 steps...
Average nll: 179.25363159179688, mse: 25.361764907836914, local kl: 0.058572910726070404 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1060 Training SNP - Freeform for 50 steps...
Average nll: 145.45448303222656, mse: 25.921850204467773, local kl: 0.23276029527187347 global kl: 2.1289116602929425e-07
Training NeuroLinear-bnn for 50 steps...
1080 Training SNP - Freeform for 50 steps...
Average nll: 0.10087783634662628, mse: 31.88949966430664, local kl: 0.05782628059387207 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1100 Training SNP - Freeform for 50 steps...
Average nll: -0.14008890092372894, mse: 21.40066909790039, local kl: 0.24900703132152557 global kl: 7.124644525902113e-06
Training NeuroLinear-bnn for 50 steps...
1120 Training SNP - Freeform for 50 steps...
Average nll: -0.20568092167377472, mse: 22.233240127563477, local kl: 0.05950876325368881 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1140 Training SNP - Freeform for 50 steps...
Average nll: -0.2568369209766388, mse: 21.808576583862305, local kl: 0.14970222115516663 global kl: 3.918528454960324e-05
Training NeuroLinear-bnn for 50 steps...
1160 Training SNP - Freeform for 50 steps...
Average nll: -0.2225847840309143, mse: 25.18083953857422, local kl: 0.03947337716817856 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1180 Training SNP - Freeform for 50 steps...
Average nll: -0.2432912439107895, mse: 26.284788131713867, local kl: 0.053501687943935394 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1200 Training SNP - Freeform for 50 steps...
Average nll: -0.1702476143836975, mse: 28.55919647216797, local kl: 0.06239301711320877 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1220 Training SNP - Freeform for 50 steps...
Average nll: -0.20540353655815125, mse: 25.514694213867188, local kl: 0.06859847903251648 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1240 Training SNP - Freeform for 50 steps...
Average nll: -0.23402249813079834, mse: 28.616426467895508, local kl: 144.16444396972656 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1260 Training SNP - Freeform for 50 steps...
Average nll: -0.2971948981285095, mse: 28.175565719604492, local kl: 0.0871858224272728 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1280 Training SNP - Freeform for 50 steps...
Average nll: -0.26679080724716187, mse: 27.046754837036133, local kl: 161.2941131591797 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1300 Training SNP - Freeform for 50 steps...
Average nll: -0.270793080329895, mse: 27.867338180541992, local kl: 0.05747829005122185 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1320 Training SNP - Freeform for 50 steps...
Average nll: 26.05276107788086, mse: 30.320716857910156, local kl: 1.6468974351882935 global kl: 0.01113545149564743
Training NeuroLinear-bnn for 50 steps...
1340 Training SNP - Freeform for 50 steps...
Average nll: -0.26193541288375854, mse: 25.41699981689453, local kl: 0.17830026149749756 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1360 Training SNP - Freeform for 50 steps...
Average nll: -0.3395145535469055, mse: 28.827829360961914, local kl: 0.04099560156464577 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1380 Training SNP - Freeform for 50 steps...
Average nll: -0.2892926335334778, mse: 29.457429885864258, local kl: 0.0781722217798233 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1400 Training SNP - Freeform for 50 steps...
Average nll: -0.3447295129299164, mse: 28.550466537475586, local kl: 155.68521118164062 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1420 Training SNP - Freeform for 50 steps...
Average nll: -0.3509666323661804, mse: 27.470762252807617, local kl: 0.2811853289604187 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1440 Training SNP - Freeform for 50 steps...
Average nll: -0.32402679324150085, mse: 29.73196792602539, local kl: 0.09108322858810425 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1460 Training SNP - Freeform for 50 steps...
Average nll: 24.6529483795166, mse: 24.161117553710938, local kl: 0.15413261950016022 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1480 Training SNP - Freeform for 50 steps...
Average nll: -0.3896304965019226, mse: 29.319896697998047, local kl: 3.555264949798584 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1500 Training SNP - Freeform for 50 steps...
Average nll: -0.3563030958175659, mse: 28.892860412597656, local kl: 0.027223708108067513 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1520 Training SNP - Freeform for 50 steps...
Average nll: -0.3691776394844055, mse: 28.850502014160156, local kl: 126.1156234741211 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1540 Training SNP - Freeform for 50 steps...
Average nll: -0.13054504990577698, mse: 28.123003005981445, local kl: 0.06873542070388794 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1560 Training SNP - Freeform for 50 steps...
Average nll: -0.45290669798851013, mse: 25.02523422241211, local kl: 0.05895678699016571 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1580 Training SNP - Freeform for 50 steps...
Average nll: -0.43492820858955383, mse: 27.17365074157715, local kl: 0.0400475412607193 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1600 Training SNP - Freeform for 50 steps...
Average nll: -0.44019564986228943, mse: 29.831777572631836, local kl: 0.11937319487333298 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1620 Training SNP - Freeform for 50 steps...
Average nll: -0.40614044666290283, mse: 27.237524032592773, local kl: 0.1620001196861267 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1640 Training SNP - Freeform for 50 steps...
Average nll: -0.4053007960319519, mse: 29.265974044799805, local kl: 0.05345738306641579 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1660 Training SNP - Freeform for 50 steps...
Average nll: -0.3877968490123749, mse: 29.214365005493164, local kl: 0.20501074194908142 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1680 Training SNP - Freeform for 50 steps...
Average nll: 0.06277947127819061, mse: 29.039390563964844, local kl: 114.1205825805664 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1700 Training SNP - Freeform for 50 steps...
Average nll: -0.5166326761245728, mse: 26.76428985595703, local kl: 380.4644470214844 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1720 Training SNP - Freeform for 50 steps...
Average nll: -0.48723897337913513, mse: 27.234580993652344, local kl: 0.035163696855306625 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1740 Training SNP - Freeform for 50 steps...
Average nll: -0.46268582344055176, mse: 28.70880126953125, local kl: 0.07645389437675476 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1760 Training SNP - Freeform for 50 steps...
Average nll: -0.5171773433685303, mse: 26.510446548461914, local kl: 0.08349506556987762 global kl: 6.234242899694209e-09
Training NeuroLinear-bnn for 50 steps...
1780 Training SNP - Freeform for 50 steps...
Average nll: -0.501615583896637, mse: 31.287931442260742, local kl: 0.14274245500564575 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1800 Training SNP - Freeform for 50 steps...
Average nll: -0.6038743853569031, mse: 26.213790893554688, local kl: 0.09368620067834854 global kl: 1.0728021493378037e-07
Training NeuroLinear-bnn for 50 steps...
1820 Training SNP - Freeform for 50 steps...
Average nll: -0.5139739513397217, mse: 24.797704696655273, local kl: 0.07689665257930756 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1840 Training SNP - Freeform for 50 steps...
Average nll: -0.4836488664150238, mse: 26.5329532623291, local kl: 112.90855407714844 global kl: 4.1127717054223467e-07
Training NeuroLinear-bnn for 50 steps...
1860 Training SNP - Freeform for 50 steps...
Average nll: -0.5742719769477844, mse: 25.455608367919922, local kl: 0.03553317114710808 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1880 Training SNP - Freeform for 50 steps...
Average nll: -0.63866126537323, mse: 25.104272842407227, local kl: 138.9413299560547 global kl: 1.1772735888371244e-05
Training NeuroLinear-bnn for 50 steps...
1900 Training SNP - Freeform for 50 steps...
Average nll: -0.5626521706581116, mse: 23.34984588623047, local kl: 0.14496782422065735 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
1920 Training SNP - Freeform for 50 steps...
Average nll: -0.5527907013893127, mse: 25.632373809814453, local kl: 0.2055022418498993 global kl: 5.0172371146572914e-08
Training NeuroLinear-bnn for 50 steps...
1940 Training SNP - Freeform for 50 steps...
Average nll: -0.3548074960708618, mse: 51.621524810791016, local kl: 0.102879598736763 global kl: 6.155987080092018e-07
Training NeuroLinear-bnn for 50 steps...
1960 Training SNP - Freeform for 50 steps...
Average nll: -0.5209757089614868, mse: 30.1167049407959, local kl: 0.0990341454744339 global kl: 4.140202918279101e-07
Training NeuroLinear-bnn for 50 steps...
1980 Training SNP - Freeform for 50 steps...
Average nll: -0.6739566922187805, mse: 25.65570640563965, local kl: 0.2185898721218109 global kl: 0.0
Training NeuroLinear-bnn for 50 steps...
2000 Training SNP - Freeform for 50 steps...
Average nll: -0.6113163232803345, mse: 25.466415405273438, local kl: 0.1045212522149086 global kl: 0.0
---------------------------------------------------
---------------------------------------------------
0.5_4 bandit completed after 514.9875190258026 seconds.
---------------------------------------------------
  0) MultitaskGP         |		 cummulative regret = 2887.9186566437174.
  1) NeuroLinear         |		 cummulative regret = 4358.036801496788.
  2) SNP - Freeform      |		 cummulative regret = 5102.862282421773.
  3) SNP - Attentive GP  |		 cummulative regret = 56781.20674954462.
  4) Uniform Sampling    |		 cummulative regret = 57844.579710834354.
---------------------------------------------------
---------------------------------------------------
  0) NeuroLinear         |		 simple regret = 168.35791431336338.
  1) MultitaskGP         |		 simple regret = 266.9638545066582.
  2) SNP - Freeform      |		 simple regret = 490.0036740243685.
  3) SNP - Attentive GP  |		 simple regret = 12498.004975616983.
  4) Uniform Sampling    |		 simple regret = 14606.396531735016.
---------------------------------------------------
---------------------------------------------------
Overall Summary for delta =  0.5
---------------------------------------------------
---------------------------------------------------
0.5 bandit completed after 2600.0052287578583 seconds.
---------------------------------------------------
  0) NeuroLinear         |		 cummulative regret = 4804.666737952748.
  1) MultitaskGP         |		 cummulative regret = 7739.551342694598.
  2) SNP - Freeform      |		 cummulative regret = 12070.41071957921.
  3) SNP - Attentive GP  |		 cummulative regret = 50795.60416797532.
  4) Uniform Sampling    |		 cummulative regret = 58183.74381422365.
---------------------------------------------------
---------------------------------------------------
  0) NeuroLinear         |		 simple regret = 294.9251185807975.
  1) MultitaskGP         |		 simple regret = 1174.184122958384.
  2) SNP - Freeform      |		 simple regret = 2242.1412521392444.
  3) SNP - Attentive GP  |		 simple regret = 9670.306880159773.
  4) Uniform Sampling    |		 simple regret = 14733.21793725511.
---------------------------------------------------
---------------------------------------------------

Posterior Predictive + MSE


In [0]:
# Run pretraining with the data, model and training hyperparameter sets
# currently in scope; train/validation metrics are printed as the cell output.
# NOTE(review): `pretrain` is not among the imports visible at the top of the
# notebook -- confirm it is defined in an earlier cell before Run All.
pretrain(data_hparams, model_hparams, training_hparams)


it: 0, train nll: 69.30583953857422, mse: 227.4987335205078, local kl: 4.271792888641357 global kl: 0.01877990923821926valid nll: 86.39154815673828, mse: 283.9880676269531, local kl: 2.247591972351074 global kl: 0.00802246667444706
it: 50, train nll: 29.94801139831543, mse: 67.4618148803711, local kl: 65.69648742675781 global kl: 0.0041467431001365185valid nll: 43.102577209472656, mse: 105.57545471191406, local kl: 47.18109893798828 global kl: 0.002302606124430895
it: 100, train nll: 24.34380340576172, mse: 51.35904312133789, local kl: 84.17845916748047 global kl: 0.0002516347449272871valid nll: 37.54330825805664, mse: 84.63033294677734, local kl: 49.208221435546875 global kl: 0.0002653984120115638
it: 150, train nll: 20.86395835876465, mse: 41.2745361328125, local kl: 66.63032531738281 global kl: 1.979325679712929e-05valid nll: 39.090572357177734, mse: 83.65245056152344, local kl: 63.73513412475586 global kl: 5.718140891985968e-05
it: 200, train nll: 33.20714569091797, mse: 68.56368255615234, local kl: 58.058048248291016 global kl: 3.916582954843761e-06valid nll: 22.147907257080078, mse: 41.65616989135742, local kl: 59.766761779785156 global kl: 8.537031681044027e-05
it: 250, train nll: 21.23867416381836, mse: 38.73579025268555, local kl: 67.01470947265625 global kl: 3.1736174150864827e-06valid nll: 36.73422622680664, mse: 75.30731201171875, local kl: 69.11150360107422 global kl: 2.305381713085808e-06
it: 300, train nll: 31.774721145629883, mse: 61.49626541137695, local kl: 64.43514251708984 global kl: 7.460013762283779e-08valid nll: 28.418113708496094, mse: 50.04291915893555, local kl: 95.21513366699219 global kl: 0.0
it: 350, train nll: 22.904325485229492, mse: 38.93213653564453, local kl: 75.84784698486328 global kl: 0.0valid nll: 29.16703987121582, mse: 52.75983810424805, local kl: 80.30721282958984 global kl: 0.0
it: 400, train nll: 28.000587463378906, mse: 55.720306396484375, local kl: 76.77484130859375 global kl: 0.0valid nll: 19.1330509185791, mse: 39.75444793701172, local kl: 75.59008026123047 global kl: 0.0
it: 450, train nll: 23.72516441345215, mse: 48.036773681640625, local kl: 57.20126724243164 global kl: 0.0valid nll: 15.291404724121094, mse: 32.179569244384766, local kl: 55.25490951538086 global kl: 0.0
it: 500, train nll: 35.24778747558594, mse: 82.07746124267578, local kl: 65.26183319091797 global kl: 0.0valid nll: 14.804286003112793, mse: 25.49320411682129, local kl: 91.5669937133789 global kl: 0.0
it: 550, train nll: 21.232065200805664, mse: 36.15296936035156, local kl: 91.66233825683594 global kl: 0.0valid nll: 28.37128448486328, mse: 55.3062629699707, local kl: 65.65888214111328 global kl: 0.0
it: 600, train nll: 29.4051513671875, mse: 48.82991027832031, local kl: 60.25905227661133 global kl: 0.0valid nll: 25.59354019165039, mse: 37.3077507019043, local kl: 87.95748138427734 global kl: 0.0
it: 650, train nll: 38.93806076049805, mse: 58.58415222167969, local kl: 80.17646026611328 global kl: 0.0valid nll: 29.371347427368164, mse: 39.080909729003906, local kl: 74.8018798828125 global kl: 0.0
it: 700, train nll: 43.44169998168945, mse: 56.090816497802734, local kl: 85.95675659179688 global kl: 0.0valid nll: 40.28554153442383, mse: 52.5775146484375, local kl: 79.6629409790039 global kl: 0.0
it: 750, train nll: 41.188480377197266, mse: 56.1728401184082, local kl: 80.34385681152344 global kl: 0.0valid nll: 34.9385871887207, mse: 46.54536819458008, local kl: 86.816162109375 global kl: 0.0
it: 800, train nll: 26.055906295776367, mse: 30.275094985961914, local kl: 82.80529022216797 global kl: 0.0valid nll: 34.07087707519531, mse: 45.78014373779297, local kl: 63.88542556762695 global kl: 0.0
it: 850, train nll: 20.97116470336914, mse: 43.475791931152344, local kl: 71.66441345214844 global kl: 0.0valid nll: 17.463787078857422, mse: 44.53580856323242, local kl: 75.4577407836914 global kl: 0.0
it: 900, train nll: 22.86075210571289, mse: 36.446678161621094, local kl: 69.28800964355469 global kl: 0.0valid nll: 28.10452651977539, mse: 45.34553527832031, local kl: 64.94532012939453 global kl: 0.0
it: 950, train nll: 28.00560760498047, mse: 33.98417282104492, local kl: 75.48758697509766 global kl: 0.0valid nll: 44.37925338745117, mse: 63.00627517700195, local kl: 60.78720474243164 global kl: 0.0
it: 1000, train nll: 33.520172119140625, mse: 41.06328582763672, local kl: 78.46305847167969 global kl: 0.0valid nll: 37.71023178100586, mse: 49.0592155456543, local kl: 63.6708869934082 global kl: 0.0
it: 1050, train nll: 29.068836212158203, mse: 42.17941665649414, local kl: 66.37188720703125 global kl: 0.0valid nll: 24.72761344909668, mse: 33.97840118408203, local kl: 71.38043975830078 global kl: 0.0
it: 1100, train nll: 21.439247131347656, mse: 41.96303939819336, local kl: 53.84902572631836 global kl: 0.0valid nll: 21.238576889038086, mse: 37.31914138793945, local kl: 76.74258422851562 global kl: 0.0
it: 1150, train nll: 32.289241790771484, mse: 40.86528396606445, local kl: 86.77616119384766 global kl: 0.0valid nll: 44.338932037353516, mse: 60.579444885253906, local kl: 65.52584075927734 global kl: 0.0
it: 1200, train nll: 41.40840530395508, mse: 62.760128021240234, local kl: 73.64424133300781 global kl: 0.0valid nll: 13.048828125, mse: 16.772323608398438, local kl: 55.721412658691406 global kl: 0.0
it: 1250, train nll: 36.03093719482422, mse: 49.96666717529297, local kl: 78.03782653808594 global kl: 0.0valid nll: 29.623641967773438, mse: 39.05404281616211, local kl: 73.60641479492188 global kl: 0.0
it: 1300, train nll: 31.021387100219727, mse: 37.029930114746094, local kl: 65.031494140625 global kl: 0.0valid nll: 37.55267333984375, mse: 44.82294464111328, local kl: 73.28715515136719 global kl: 0.0
it: 1350, train nll: 32.677791595458984, mse: 51.81110382080078, local kl: 79.99293518066406 global kl: 0.0valid nll: 25.669965744018555, mse: 40.876338958740234, local kl: 74.72438049316406 global kl: 0.0
it: 1400, train nll: 38.5272216796875, mse: 50.32453536987305, local kl: 67.0750732421875 global kl: 0.0valid nll: 36.18844985961914, mse: 45.85405349731445, local kl: 83.95060729980469 global kl: 0.0
it: 1450, train nll: 37.3214225769043, mse: 50.397117614746094, local kl: 63.450443267822266 global kl: 0.0valid nll: 19.63437843322754, mse: 23.767974853515625, local kl: 53.80479431152344 global kl: 0.0
it: 1500, train nll: 32.72158432006836, mse: 42.11119842529297, local kl: 81.94157409667969 global kl: 0.0valid nll: 36.165283203125, mse: 52.14964294433594, local kl: 57.7842903137207 global kl: 0.0
it: 1550, train nll: 34.11354446411133, mse: 47.89038848876953, local kl: 70.75119018554688 global kl: 0.0valid nll: 30.036724090576172, mse: 40.27745056152344, local kl: 67.69834899902344 global kl: 0.0
it: 1600, train nll: 42.45426940917969, mse: 61.28825378417969, local kl: 65.98312377929688 global kl: 0.0valid nll: 27.21445083618164, mse: 32.31903839111328, local kl: 78.9754638671875 global kl: 0.0
it: 1650, train nll: 39.3293342590332, mse: 51.06004333496094, local kl: 85.98977661132812 global kl: 0.0valid nll: 28.863624572753906, mse: 41.290794372558594, local kl: 55.6628532409668 global kl: 0.0
it: 1700, train nll: 24.410934448242188, mse: 32.078182220458984, local kl: 79.46546173095703 global kl: 0.0valid nll: 30.19243812561035, mse: 44.44026184082031, local kl: 58.374183654785156 global kl: 0.0
it: 1750, train nll: 39.4742431640625, mse: 52.843902587890625, local kl: 55.28364562988281 global kl: 0.0valid nll: 31.819664001464844, mse: 37.29852294921875, local kl: 71.95195007324219 global kl: 0.0
it: 1800, train nll: 37.16926956176758, mse: 49.43376159667969, local kl: 65.60565948486328 global kl: 0.0valid nll: 31.897480010986328, mse: 38.86524200439453, local kl: 75.82583618164062 global kl: 0.0
it: 1850, train nll: 29.74041748046875, mse: 37.98810958862305, local kl: 79.7176742553711 global kl: 0.0valid nll: 32.18796920776367, mse: 42.441349029541016, local kl: 51.496273040771484 global kl: 0.0
it: 1900, train nll: 28.462055206298828, mse: 43.101173400878906, local kl: 71.54998016357422 global kl: 0.0valid nll: 30.10837173461914, mse: 44.906673431396484, local kl: 58.77400588989258 global kl: 0.0
it: 1950, train nll: 21.406444549560547, mse: 32.361270904541016, local kl: 54.86785888671875 global kl: 0.0valid nll: 23.144145965576172, mse: 34.48278045654297, local kl: 59.01050567626953 global kl: 0.0

In [0]:
# Launch pretraining with the data/model/training hparams currently defined in
# the kernel. NOTE(review): this cell's call is identical to other pretrain
# cells in the notebook, so the configuration presumably changed in between --
# confirm which hparams produced the recorded output.
pretrain(data_hparams, model_hparams, training_hparams)


it: 0, train nll: 63.06072235107422, mse: 215.38909912109375, local kl: 3.032419443130493 global kl: 0.00995991937816143valid nll: 63.42512512207031, mse: 216.81292724609375, local kl: 3.3935022354125977 global kl: 0.01095118559896946
it: 50, train nll: 40.56180953979492, mse: 113.71380615234375, local kl: 146.83770751953125 global kl: 0.007505028508603573valid nll: 41.103233337402344, mse: 122.31623077392578, local kl: 49.39350128173828 global kl: 0.018857663497328758
it: 100, train nll: 30.77800178527832, mse: 79.66529083251953, local kl: 68.61537170410156 global kl: 6.1731977462768555valid nll: 26.246097564697266, mse: 72.5315933227539, local kl: 60.54109573364258 global kl: 0.02694229781627655
it: 150, train nll: 28.668663024902344, mse: 76.17350006103516, local kl: 70.8505859375 global kl: 0.005348356906324625valid nll: 23.36986541748047, mse: 59.470455169677734, local kl: 67.83472442626953 global kl: 0.003085080999881029
it: 200, train nll: 23.498933792114258, mse: 63.863677978515625, local kl: 80.21842193603516 global kl: 0.011260751634836197valid nll: 19.0505428314209, mse: 49.77246856689453, local kl: 74.64579010009766 global kl: 0.00029110434115864336
it: 250, train nll: 18.654150009155273, mse: 55.639400482177734, local kl: 57.40072250366211 global kl: 1.2435530152288266e-05valid nll: 14.957308769226074, mse: 40.67765426635742, local kl: 83.767822265625 global kl: 1.5849669580347836e-05
it: 300, train nll: 21.32248878479004, mse: 58.43162155151367, local kl: 91.81661987304688 global kl: 7.551752787549049e-06valid nll: 20.819887161254883, mse: 55.885345458984375, local kl: 69.2139892578125 global kl: 4.3652761405610363e-07
it: 350, train nll: 19.137954711914062, mse: 54.441650390625, local kl: 76.36198425292969 global kl: 0.0003430648357607424valid nll: 16.16408920288086, mse: 47.8390007019043, local kl: 84.67859649658203 global kl: 0.00027700065402314067
it: 400, train nll: 17.036861419677734, mse: 49.65110397338867, local kl: 84.13072967529297 global kl: 0.00991512555629015valid nll: 15.419821739196777, mse: 43.523521423339844, local kl: 69.07366943359375 global kl: 0.00014124861627351493
it: 450, train nll: 15.528565406799316, mse: 41.445030212402344, local kl: 72.95622253417969 global kl: 0.00830586813390255valid nll: 10.864349365234375, mse: 30.566251754760742, local kl: 58.890525817871094 global kl: 0.015370051376521587
it: 500, train nll: 17.105642318725586, mse: 43.68159103393555, local kl: 65.41458892822266 global kl: 0.35858353972435valid nll: 19.303903579711914, mse: 50.49518966674805, local kl: 69.23296356201172 global kl: 0.00046432079398073256
it: 550, train nll: 16.252029418945312, mse: 36.84682083129883, local kl: 80.3216781616211 global kl: 0.00010965151886921376valid nll: 23.830089569091797, mse: 60.15428161621094, local kl: 67.03443908691406 global kl: 3.876365008181892e-05
it: 600, train nll: 18.94751739501953, mse: 48.666168212890625, local kl: 70.30752563476562 global kl: 5.595894435828086e-06valid nll: 17.612213134765625, mse: 43.04922866821289, local kl: 62.61797332763672 global kl: 7.93003709986806e-06
it: 650, train nll: 15.35086441040039, mse: 35.641761779785156, local kl: 74.67696380615234 global kl: 0.0005205316701903939valid nll: 19.906301498413086, mse: 51.28835678100586, local kl: 67.33728790283203 global kl: 0.0007968117715790868
it: 700, train nll: 18.98784637451172, mse: 45.037967681884766, local kl: 69.05419158935547 global kl: 0.0009166345116682351valid nll: 21.652528762817383, mse: 49.44622802734375, local kl: 78.89848327636719 global kl: 0.0008927981252782047
it: 750, train nll: 20.857524871826172, mse: 48.76597213745117, local kl: 73.71446228027344 global kl: 0.0005265462095849216valid nll: 21.250926971435547, mse: 47.61959457397461, local kl: 70.82953643798828 global kl: 0.0008207792416214943
it: 800, train nll: 20.117109298706055, mse: 41.11979675292969, local kl: 89.0832748413086 global kl: 0.023082416504621506valid nll: 25.491455078125, mse: 61.59838104248047, local kl: 76.7528305053711 global kl: 0.023254219442605972
it: 850, train nll: 23.3465576171875, mse: 48.44394302368164, local kl: 70.60987091064453 global kl: 0.01324918307363987valid nll: 14.959359169006348, mse: 30.976282119750977, local kl: 60.10578155517578 global kl: 0.014420250430703163
it: 900, train nll: 27.094751358032227, mse: 49.87214279174805, local kl: 78.73623657226562 global kl: 0.01850232109427452valid nll: 19.000595092773438, mse: 35.675655364990234, local kl: 43.66059875488281 global kl: 0.007812763564288616
it: 950, train nll: 22.543577194213867, mse: 53.37714767456055, local kl: 58.06190872192383 global kl: 0.020948141813278198valid nll: 11.154351234436035, mse: 21.450176239013672, local kl: 63.98209762573242 global kl: 0.017730267718434334
it: 1000, train nll: 15.90113353729248, mse: 36.102699279785156, local kl: 78.30152893066406 global kl: 0.031107719987630844valid nll: 20.71133804321289, mse: 57.20733642578125, local kl: 67.56367492675781 global kl: 0.0849929079413414
it: 1050, train nll: 21.748165130615234, mse: 46.58993911743164, local kl: 81.75639343261719 global kl: 0.01320687960833311valid nll: 19.864776611328125, mse: 42.53876495361328, local kl: 63.61568069458008 global kl: 0.04385796934366226
it: 1100, train nll: 21.85807228088379, mse: 47.47929000854492, local kl: 66.69312286376953 global kl: 0.05580480024218559valid nll: 20.900001525878906, mse: 45.4069709777832, local kl: 76.64691162109375 global kl: 0.029934663325548172
it: 1150, train nll: 17.11998748779297, mse: 42.203861236572266, local kl: 67.03917694091797 global kl: 0.07873336970806122valid nll: 17.93390464782715, mse: 42.20318603515625, local kl: 73.91483306884766 global kl: 0.018082860857248306
it: 1200, train nll: 14.706640243530273, mse: 54.87645721435547, local kl: 78.0022964477539 global kl: 0.01570923998951912valid nll: 8.866144180297852, mse: 34.26356506347656, local kl: 68.5001449584961 global kl: 0.047442883253097534
it: 1250, train nll: 20.452871322631836, mse: 51.52635192871094, local kl: 77.93737030029297 global kl: 0.00849766656756401valid nll: 16.49269676208496, mse: 43.50374221801758, local kl: 70.3230972290039 global kl: 0.08757313340902328
it: 1300, train nll: 15.722665786743164, mse: 43.80097961425781, local kl: 80.64022064208984 global kl: 0.0541851706802845valid nll: 20.698888778686523, mse: 59.22097396850586, local kl: 61.189308166503906 global kl: 0.0038620177656412125
it: 1350, train nll: 17.793832778930664, mse: 47.30776596069336, local kl: 45.614532470703125 global kl: 0.0035906233824789524valid nll: 17.27609634399414, mse: 40.48519515991211, local kl: 87.75505065917969 global kl: 0.02456105500459671
it: 1400, train nll: 20.350553512573242, mse: 52.3165397644043, local kl: 77.65715026855469 global kl: 0.024992898106575012valid nll: 16.890371322631836, mse: 42.578155517578125, local kl: 73.55687713623047 global kl: 0.09656679630279541
it: 1450, train nll: 20.695810317993164, mse: 48.436683654785156, local kl: 62.238773345947266 global kl: 0.011770632117986679valid nll: 16.12401008605957, mse: 36.350730895996094, local kl: 74.71292877197266 global kl: 0.12354423850774765
it: 1500, train nll: 15.811275482177734, mse: 40.31724166870117, local kl: 71.28557586669922 global kl: 0.026020819321274757valid nll: 12.741310119628906, mse: 32.97138214111328, local kl: 54.22089767456055 global kl: 0.03361105918884277
it: 1550, train nll: 18.48696517944336, mse: 47.48664855957031, local kl: 64.03622436523438 global kl: 0.030964320525527valid nll: 13.357494354248047, mse: 31.898530960083008, local kl: 61.610313415527344 global kl: 0.0738820880651474
it: 1600, train nll: 14.672139167785645, mse: 40.33184051513672, local kl: 73.37171173095703 global kl: 0.008316589519381523valid nll: 17.17024803161621, mse: 46.99702453613281, local kl: 54.83610534667969 global kl: 0.012921598739922047
it: 1650, train nll: 17.702058792114258, mse: 43.11305618286133, local kl: 74.45330810546875 global kl: 0.0037357057444751263valid nll: 19.431203842163086, mse: 50.35197448730469, local kl: 62.12971115112305 global kl: 0.02335527166724205
it: 1700, train nll: 11.930219650268555, mse: 29.624582290649414, local kl: 68.88066864013672 global kl: 0.026260558515787125valid nll: 19.816434860229492, mse: 55.670894622802734, local kl: 56.81549072265625 global kl: 0.011375464498996735
it: 1750, train nll: 18.136728286743164, mse: 48.868621826171875, local kl: 66.09882354736328 global kl: 0.00523568969219923valid nll: 15.383318901062012, mse: 40.45002365112305, local kl: 66.66216278076172 global kl: 0.062216829508543015
it: 1800, train nll: 15.651430130004883, mse: 42.9111213684082, local kl: 61.234317779541016 global kl: 0.005749610718339682valid nll: 13.813420295715332, mse: 36.73366928100586, local kl: 64.31543731689453 global kl: 0.02718406356871128
it: 1850, train nll: 19.14667320251465, mse: 45.937950134277344, local kl: 62.14559555053711 global kl: 0.03130928799510002valid nll: 13.502644538879395, mse: 31.93645668029785, local kl: 60.933197021484375 global kl: 0.0614127591252327
it: 1900, train nll: 16.176328659057617, mse: 42.21115493774414, local kl: 60.54005813598633 global kl: 0.022890115156769753valid nll: 13.757333755493164, mse: 33.29761505126953, local kl: 60.425567626953125 global kl: 0.0699986070394516
it: 1950, train nll: 10.875977516174316, mse: 25.967384338378906, local kl: 44.72900390625 global kl: 0.028215229511260986valid nll: 10.244438171386719, mse: 25.933094024658203, local kl: 44.04553985595703 global kl: 0.04023250564932823
it: 2000, train nll: 17.34665298461914, mse: 42.35016632080078, local kl: 64.67743682861328 global kl: 0.026365751400589943valid nll: 11.36293888092041, mse: 27.81946563720703, local kl: 60.65430450439453 global kl: 0.06782550364732742
it: 2050, train nll: 13.195025444030762, mse: 28.76435661315918, local kl: 65.02043151855469 global kl: 0.016737673431634903valid nll: 14.863659858703613, mse: 37.96871566772461, local kl: 53.22877883911133 global kl: 0.06048823520541191
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-86-6dda4255b542> in <module>()
      1 pretrain(data_hparams,
      2          model_hparams,
----> 3          training_hparams)

<ipython-input-84-8ed220a2dc47> in pretrain(data_hparams, model_hparams, training_hparams)
    120   training_loop(data_hparams,
    121                 model,
--> 122                 training_hparams)
    123 
    124   # initialize SNP model using model_hparams

<ipython-input-84-8ed220a2dc47> in training_loop(data_hparams, model, hparams)
     54   num_context=hparams.num_context_states
     55   for it in range(hparams.num_iterations):
---> 56     target_x, target_y = procure_dataset(data_hparams)
     57     # perm = np.random.permutation(all_state_action_pairs.shape[1])
     58     # perm = perm[:hparams.num_context_states+hparams.num_target_states]

<ipython-input-84-8ed220a2dc47> in procure_dataset(hparams)
      9         hparams.num_actions,
     10         hparams.context_dim,
---> 11         delta)
     12     all_state_action_pairs.append(state_action_pairs)
     13     all_rewards.append(rewards)

<ipython-input-51-8578ad8e9206> in get_training_wheel_data(num_total_states, num_actions, context_dim, delta)
     78                                               std_v,
     79                                               mu_large,
---> 80                                               std_large)
     81   return state_action_pairs, rewards

<ipython-input-51-8578ad8e9206> in sample_training_wheel_bandit_data(num_total_states, num_actions, context_dim, delta, mean_v, std_v, mu_large, std_large)
     37       if np.linalg.norm(raw_data[i, :]) <= 1:
     38         for j in range(num_actions):
---> 39           one_hot_vector = np.zeros((5))
     40           one_hot_vector[j]=1
     41           data.append(np.hstack([raw_data[i, :], one_hot_vector]))

KeyboardInterrupt: 

In [0]:
# Launch pretraining with the data/model/training hparams currently defined in
# the kernel. NOTE(review): the recorded output for this cell shows local kl
# growing to ~1e9 before the run was interrupted by hand -- verify the hparams
# used here before relying on these numbers.
pretrain(data_hparams, model_hparams, training_hparams)


it: 0, train nll: 58.961856842041016, mse: 205.11830139160156, local kl: 2.90657639503479 global kl: 0.003845751751214266valid nll: 83.07482147216797, mse: 290.5240783691406, local kl: 1.5624231100082397 global kl: 0.0018705299589782953
it: 50, train nll: 2.700871467590332, mse: 3.628453493118286, local kl: 4039669.5 global kl: 0.012357242405414581valid nll: 3.038377523422241, mse: 4.814786911010742, local kl: 9913327.0 global kl: 0.009815489873290062
it: 100, train nll: 2.113877058029175, mse: 0.2591399848461151, local kl: 63664288.0 global kl: 0.0012066202471032739valid nll: 2.106165885925293, mse: 0.5803396105766296, local kl: 73767528.0 global kl: 0.00068276422098279
it: 150, train nll: 2.9890506267547607, mse: 0.21710118651390076, local kl: 368845952.0 global kl: 0.011972357518970966valid nll: 2.979311466217041, mse: 0.062382180243730545, local kl: 452203360.0 global kl: 0.010854621417820454
it: 200, train nll: 2.993539333343506, mse: 0.12942281365394592, local kl: 174734832.0 global kl: 0.004149547312408686valid nll: 2.992830753326416, mse: 0.07247623801231384, local kl: 293517536.0 global kl: 0.0070689162239432335
it: 250, train nll: 3.0182852745056152, mse: 0.09846416860818863, local kl: 235085376.0 global kl: 0.005863801576197147valid nll: 2.9740915298461914, mse: 0.3675422668457031, local kl: 303114848.0 global kl: 0.004511209670454264
it: 300, train nll: 3.044343948364258, mse: 0.3071638345718384, local kl: 280666816.0 global kl: 0.001510246773250401valid nll: 3.049323081970215, mse: 0.461736798286438, local kl: 374701056.0 global kl: 0.0006327595910988748
it: 350, train nll: 3.1857004165649414, mse: 0.15492163598537445, local kl: 557111872.0 global kl: 0.0003688436408992857valid nll: 3.213104724884033, mse: 0.015167862176895142, local kl: 639127104.0 global kl: 0.0006376059609465301
it: 400, train nll: 3.470541477203369, mse: 0.008499015122652054, local kl: 1123064448.0 global kl: 0.016306210309267044valid nll: 3.4990851879119873, mse: 0.03802347183227539, local kl: 1709697408.0 global kl: 0.006358642131090164
it: 450, train nll: 3.4800734519958496, mse: 0.001351524842903018, local kl: 970648128.0 global kl: 0.005563330836594105valid nll: 3.480504274368286, mse: 0.048660796135663986, local kl: 1078734592.0 global kl: 7426.1201171875
it: 500, train nll: 3.433819055557251, mse: 0.13115379214286804, local kl: 1052457984.0 global kl: 0.007772705052047968valid nll: 3.4342823028564453, mse: 0.0066461097449064255, local kl: 995501696.0 global kl: 0.010738678276538849
it: 550, train nll: 3.43011212348938, mse: 0.07306498289108276, local kl: 1063606784.0 global kl: 0.005441803019493818valid nll: 3.4245309829711914, mse: 0.08151636272668839, local kl: 1381059328.0 global kl: 0.009001446887850761
it: 600, train nll: 3.423692226409912, mse: 0.0008446523570455611, local kl: 782076032.0 global kl: 0.0037483717314898968valid nll: 3.4287703037261963, mse: 0.31432658433914185, local kl: 1444983808.0 global kl: 0.009851839393377304
it: 650, train nll: 3.4238674640655518, mse: 0.011824171058833599, local kl: 1141892480.0 global kl: 0.004763867240399122valid nll: 3.44429612159729, mse: 0.002929924288764596, local kl: 951267520.0 global kl: 0.006768574006855488
it: 700, train nll: 3.1067888736724854, mse: 0.08848026394844055, local kl: 943941376.0 global kl: 0.0030246630776673555valid nll: 3.0909547805786133, mse: 0.06820278614759445, local kl: 1215621376.0 global kl: 0.008580348454415798
it: 750, train nll: 3.1398229598999023, mse: 0.08521401137113571, local kl: 1089146752.0 global kl: 0.0033222143538296223valid nll: 3.0665879249572754, mse: 0.12499111890792847, local kl: 1024599488.0 global kl: 0.0025402747560292482
it: 800, train nll: 3.0691065788269043, mse: 0.1225980743765831, local kl: 918171712.0 global kl: 0.002317312639206648valid nll: 3.082725763320923, mse: 0.03551483154296875, local kl: 994146176.0 global kl: 0.002466950798407197
it: 850, train nll: 3.1082541942596436, mse: 0.11260442435741425, local kl: 671770944.0 global kl: 0.0025485376827418804valid nll: 3.088731288909912, mse: 0.0007540767546743155, local kl: 1030668352.0 global kl: 0.0055611166171729565
it: 900, train nll: 3.0891458988189697, mse: 0.05464208498597145, local kl: 933485888.0 global kl: 0.0018561702454462647valid nll: 3.1032211780548096, mse: 0.015298591926693916, local kl: 863290624.0 global kl: 0.002679202239960432
it: 950, train nll: 3.0695769786834717, mse: 0.1590917706489563, local kl: 847808832.0 global kl: 0.0030159971211105585valid nll: 3.0736191272735596, mse: 0.0870058685541153, local kl: 988006016.0 global kl: 0.0011307161767035723
it: 1000, train nll: 3.0775296688079834, mse: 0.05836066976189613, local kl: 901235968.0 global kl: 0.0017684970516711473valid nll: 3.1164259910583496, mse: 0.3452860414981842, local kl: 1060234560.0 global kl: 0.0010175046045333147
it: 1050, train nll: 3.130429267883301, mse: 0.07328684628009796, local kl: 1140564864.0 global kl: 0.0009032817324623466valid nll: 3.0677831172943115, mse: 0.019454998895525932, local kl: 1000410496.0 global kl: 0.0033826064318418503
it: 1100, train nll: 3.0964462757110596, mse: 0.013735372573137283, local kl: 1027349056.0 global kl: 0.00045964503078721464valid nll: 3.1805520057678223, mse: 0.11486658453941345, local kl: 1101294848.0 global kl: 0.0005226265639066696
it: 1150, train nll: 3.094805955886841, mse: 0.2790409028530121, local kl: 776404672.0 global kl: 0.0009114152053371072valid nll: 3.109269142150879, mse: 0.4805215299129486, local kl: 706564096.0 global kl: 0.0016452738782390952
it: 1200, train nll: 3.1286168098449707, mse: 0.07749082893133163, local kl: 909116608.0 global kl: 0.0007736885454505682valid nll: 3.0801174640655518, mse: 0.06907245516777039, local kl: 937104512.0 global kl: 0.00037128885742276907
it: 1250, train nll: 3.063004970550537, mse: 0.18585866689682007, local kl: 1005983424.0 global kl: 0.00026404429809190333valid nll: 3.080387592315674, mse: 0.017852095887064934, local kl: 1116789248.0 global kl: 0.0008198929135687649
it: 1300, train nll: 3.0354232788085938, mse: 0.14205317199230194, local kl: 836057536.0 global kl: 0.0007279487326741219valid nll: 2.964357852935791, mse: 0.012241141870617867, local kl: 635555904.0 global kl: 0.0004966954002156854
it: 1350, train nll: 2.877556562423706, mse: 0.16847871243953705, local kl: 1038326400.0 global kl: 0.0015540809836238623valid nll: 2.8493077754974365, mse: 0.05686056241393089, local kl: 1243919488.0 global kl: 0.0020621002186089754
it: 1400, train nll: 2.7882795333862305, mse: 0.0009136111475527287, local kl: 1448658432.0 global kl: 0.0012470472138375044valid nll: 2.782350778579712, mse: 0.0005831235903315246, local kl: 1776232576.0 global kl: 0.0006379054975695908
it: 1450, train nll: 2.598015069961548, mse: 1.3299325019033859e-07, local kl: 1518688512.0 global kl: 0.008897090330719948valid nll: 2.48555850982666, mse: 0.0002692076959647238, local kl: 2107680640.0 global kl: 0.005310703534632921
it: 1500, train nll: 2.332942008972168, mse: 2.2750831703888252e-05, local kl: 2272262656.0 global kl: 0.00020643696188926697valid nll: 2.017017364501953, mse: 2.4314047664120153e-07, local kl: 2434450176.0 global kl: 0.00043935925350524485
it: 1550, train nll: 2.3033933639526367, mse: 2.9773555070278235e-06, local kl: 1775099904.0 global kl: 0.0005600734730251133valid nll: 2.4924519062042236, mse: 3.4110214652827153e-09, local kl: 2489405184.0 global kl: 0.00044099544174969196
it: 1600, train nll: 2.3718173503875732, mse: 5.930925374300969e-14, local kl: 2204733440.0 global kl: 1.1635678674792871e-05valid nll: 2.4745607376098633, mse: 1.1233877033356115e-17, local kl: 2391021824.0 global kl: 5.6687977121328e-06
it: 1650, train nll: 2.505502700805664, mse: 3.373777993087579e-11, local kl: 2127624320.0 global kl: 0.0001696811377769336valid nll: 2.522768974304199, mse: 7.875902241494259e-08, local kl: 2177239552.0 global kl: 1.3791242963634431e-05
it: 1700, train nll: 2.5558266639709473, mse: 6.216160919336744e-10, local kl: 1672364288.0 global kl: 3.601110802264884e-05valid nll: 2.495659589767456, mse: 1.6374562296361805e-12, local kl: 2102218368.0 global kl: 2.356537697778549e-05
it: 1750, train nll: 2.4511044025421143, mse: 3.3827807328634663e-06, local kl: 2105492352.0 global kl: 5.1466595323290676e-05valid nll: 2.3498077392578125, mse: 6.951488863704558e-12, local kl: 1778120064.0 global kl: 3.215510514564812e-05
it: 1800, train nll: 2.606743574142456, mse: 4.2092418880201876e-05, local kl: 2403887616.0 global kl: 6.827048309787642e-06valid nll: 2.6228256225585938, mse: 0.003250936744734645, local kl: 3217786112.0 global kl: 5.760492967965547e-06
it: 1850, train nll: 2.6368582248687744, mse: 7.723290253637176e-18, local kl: 1734014976.0 global kl: 5.884365691599669e-06valid nll: 2.4046335220336914, mse: 1.1740953596017789e-05, local kl: 2216025344.0 global kl: 6.895207661727909e-06
it: 1900, train nll: 2.460270404815674, mse: 5.859206175795606e-11, local kl: 2019072640.0 global kl: 8.47535102366237e-06valid nll: 2.5227890014648438, mse: 1.1360011740180198e-06, local kl: 2158126080.0 global kl: 1.578358387632761e-05
it: 1950, train nll: 2.5193448066711426, mse: 9.478583494604135e-16, local kl: 2273384448.0 global kl: 6.78981959936209e-05valid nll: 2.616962432861328, mse: 0.0004332281823735684, local kl: 2830836224.0 global kl: 8.094704389804974e-05
it: 2000, train nll: 2.7021846771240234, mse: 8.495619361718955e-17, local kl: 2297443584.0 global kl: 6.690119334962219e-05valid nll: 2.749786138534546, mse: 2.8084692583390287e-18, local kl: 2601519360.0 global kl: 7.110213482519612e-05
it: 2050, train nll: 2.821805953979492, mse: 6.509195848991567e-09, local kl: 2330140928.0 global kl: 7.074712584653753e-07valid nll: 2.7516050338745117, mse: 0.0, local kl: 2078516736.0 global kl: 2.2420724690164207e-06
it: 2100, train nll: 2.705542802810669, mse: 2.8084692583390287e-18, local kl: 1849972608.0 global kl: 2.27208215619612e-06valid nll: 2.7615389823913574, mse: 6.319056038057968e-18, local kl: 2033111168.0 global kl: 1.3486427405950963e-06
it: 2150, train nll: 2.784679651260376, mse: 2.1998794181854464e-06, local kl: 2133003392.0 global kl: 2.0075487555004656e-05valid nll: 2.969350576400757, mse: 0.0, local kl: 1792392320.0 global kl: 1.7975695300265215e-05
it: 2200, train nll: 2.8892734050750732, mse: 5.857434762918112e-13, local kl: 1544996480.0 global kl: 2.6153481940127676e-06valid nll: 2.826018810272217, mse: 0.0, local kl: 2970508544.0 global kl: 7.690678103244863e-07
it: 2250, train nll: 2.934501886367798, mse: 2.8084692583390287e-18, local kl: 1818158208.0 global kl: 9.116223509408883e-07valid nll: 2.9886929988861084, mse: 1.4557701154575484e-14, local kl: 2457421312.0 global kl: 6.628865776292514e-07
it: 2300, train nll: 2.8207285404205322, mse: 1.284136396861868e-05, local kl: 2034204544.0 global kl: 1.1890509767908952e-06valid nll: 2.9602644443511963, mse: 2.8084692583390287e-18, local kl: 2473306624.0 global kl: 1.2526255659395247e-06
it: 2350, train nll: 2.9117088317871094, mse: 0.027650346979498863, local kl: 2301513472.0 global kl: 8.345746209670324e-06valid nll: 2.859774112701416, mse: 2.0448118753790823e-09, local kl: 2466383872.0 global kl: 5.3678813856095076e-06
it: 2400, train nll: 2.752645492553711, mse: 0.003390224650502205, local kl: 2306483456.0 global kl: 1.6292817235807888e-05valid nll: 2.8458988666534424, mse: 5.770761646317624e-10, local kl: 1993733248.0 global kl: 5.262333434075117e-05
it: 2450, train nll: 2.9372775554656982, mse: 8.179998985724524e-05, local kl: 2543592960.0 global kl: 7.299009303096682e-05valid nll: 2.8079171180725098, mse: 0.0, local kl: 2467917312.0 global kl: 1.9945746316807345e-05
it: 2500, train nll: 2.710695266723633, mse: 0.0, local kl: 1571160192.0 global kl: 1.0611651077852002e-06valid nll: 2.476489782333374, mse: 1.7131662227713891e-16, local kl: 2590287360.0 global kl: 7.832656478967692e-07
it: 2550, train nll: 2.7061614990234375, mse: 4.493550813342446e-17, local kl: 1430648704.0 global kl: 1.4626851907451055e-06valid nll: 2.7589643001556396, mse: 2.439523960617862e-08, local kl: 3059962112.0 global kl: 3.307378847239306e-06
it: 2600, train nll: 2.725351333618164, mse: 5.038842232352181e-08, local kl: 2065620992.0 global kl: 1.4345171166496584e-06valid nll: 2.6068549156188965, mse: 2.2509208039145356e-12, local kl: 2691090688.0 global kl: 6.794092541895225e-07
it: 2650, train nll: 2.9241483211517334, mse: 1.297871810024276e-10, local kl: 1725149952.0 global kl: 7.814751734258607e-05valid nll: 2.867985963821411, mse: 0.0, local kl: 1833444608.0 global kl: 0.00037239064113236964
it: 2700, train nll: 2.7822747230529785, mse: 0.0, local kl: 1434435840.0 global kl: 5.580125616688747e-06valid nll: 2.8413383960723877, mse: 8.211044288941594e-09, local kl: 2206632960.0 global kl: 5.065775440016296e-06
it: 2750, train nll: 2.826726198196411, mse: 0.0, local kl: 2220877312.0 global kl: 6.107293302193284e-06valid nll: 2.7536699771881104, mse: 0.0, local kl: 1994932480.0 global kl: 8.570971658627968e-06
it: 2800, train nll: 2.732779026031494, mse: 2.695722534085121e-10, local kl: 2056230784.0 global kl: 3.4416104881529463e-06valid nll: 2.811826229095459, mse: 0.0, local kl: 2455563776.0 global kl: 1.4909177480149083e-05
it: 2850, train nll: 2.722261428833008, mse: 1.1233877033356115e-17, local kl: 2331831296.0 global kl: 3.836934865830699e-06valid nll: 2.693962335586548, mse: 0.0, local kl: 2896114688.0 global kl: 3.5214943636674434e-06
it: 2900, train nll: 2.729822874069214, mse: 0.0, local kl: 2049404032.0 global kl: 4.767774441916117e-07valid nll: 2.6972458362579346, mse: 0.0, local kl: 2808681984.0 global kl: 1.5360359384430922e-06
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-80-6dda4255b542> in <module>()
      1 pretrain(data_hparams,
      2          model_hparams,
----> 3          training_hparams)

<ipython-input-78-27a17933eb9f> in pretrain(data_hparams, model_hparams, training_hparams)
    120   training_loop(data_hparams,
    121                 model,
--> 122                 training_hparams)
    123 
    124   # initialize SNP model using model_hparams

<ipython-input-78-27a17933eb9f> in training_loop(data_hparams, model, hparams)
     54   num_context=hparams.num_context_states
     55   for it in range(hparams.num_iterations):
---> 56     target_x, target_y = procure_dataset(data_hparams)
     57     # perm = np.random.permutation(all_state_action_pairs.shape[1])
     58     # perm = perm[:hparams.num_context_states+hparams.num_target_states]

<ipython-input-78-27a17933eb9f> in procure_dataset(hparams)
      9         hparams.num_actions,
     10         hparams.context_dim,
---> 11         delta)
     12     all_state_action_pairs.append(state_action_pairs)
     13     all_rewards.append(rewards)

<ipython-input-51-8578ad8e9206> in get_training_wheel_data(num_total_states, num_actions, context_dim, delta)
     78                                               std_v,
     79                                               mu_large,
---> 80                                               std_large)
     81   return state_action_pairs, rewards

<ipython-input-51-8578ad8e9206> in sample_training_wheel_bandit_data(num_total_states, num_actions, context_dim, delta, mean_v, std_v, mu_large, std_large)
     39           one_hot_vector = np.zeros((5))
     40           one_hot_vector[j]=1
---> 41           data.append(np.hstack([raw_data[i, :], one_hot_vector]))
     42 
     43   state_action_pairs = np.stack(data)[:num_actions * num_total_states, :]

KeyboardInterrupt: 

In [0]: