In [1]:
import tensorflow as tf
import rfho as rf

from rfho.datasets import load_mnist


Experiment save directory is  /media/luca/DATA/EXPERIMENTS
Data folder is /media/luca/DATA/DATASETS

In [2]:
# Partition MNIST: 5% of the data goes to the training set, 1% to the
# validation set and the remainder to the test set. Change these fractions
# to see their effect on the regularization hyperparameter.
mnist = load_mnist(
    partitions=(0.05, 0.01),
)


Extracting /media/luca/DATA/DATASETS/mnist_data/train-images-idx3-ubyte.gz
Extracting /media/luca/DATA/DATASETS/mnist_data/train-labels-idx1-ubyte.gz
Extracting /media/luca/DATA/DATASETS/mnist_data/t10k-images-idx3-ubyte.gz
Extracting /media/luca/DATA/DATASETS/mnist_data/t10k-labels-idx1-ubyte.gz
datasets.redivide_data:, computed partitions numbers - [0, 3500, 4200, 70000] len all 70000 DONE

In [3]:
# placeholders for the inputs (x) and the targets (y)
x = tf.placeholder(tf.float32, name='x')
y = tf.placeholder(tf.float32, name='y')

# the model (here a linear model from rfho.models)
model = rf.LinearModel(x, mnist.train.dim_data, mnist.train.dim_target)

# Vectorize the model and build the state vector. The state is augmented by 1
# since the weights are going to be optimized with momentum.
# (This call also prints some tensorflow infos and warnings about variables
# collections... we'll solve this.)
s, out, w_matrix = rf.vectorize_model(
    model.var_list,
    model.inp[-1],
    model.Ws[0],
    augment=1,
)

In [4]:
# error: mean of the cross-entropy loss between the targets y and the model output
error = tf.reduce_mean(
    rf.cross_entropy_loss(labels=y, logits=out),
    name='error',
)

# training error: error plus an L2 penalty on the weights, scaled by the
# regularization hyperparameter rho
rho = tf.Variable(0., name='rho')
training_error = error + rho * tf.reduce_sum(tf.pow(w_matrix, 2))

# constraints on the hyperparameters; the regularization coefficient should be positive
constraints = [rf.positivity(rho)]

# accuracy: fraction of examples whose arg-max output matches the arg-max target
accuracy = tf.reduce_mean(
    tf.cast(
        tf.equal(tf.argmax(out, 1), tf.argmax(y, 1)),
        "float",
    ),
    name='accuracy',
)

# learning rate and momentum factor are variables, so that they can be optimized too
eta = tf.Variable(.01, name='eta')
mu = tf.Variable(.5, name='mu')

# training dynamics (similar to tf.train.Optimizer)
optimizer = rf.MomentumOptimizer.create(s, eta, mu, loss=training_error)

# constraints for the learning rate and the momentum factor
constraints.extend(optimizer.get_natural_hyperparameter_constraints())

In [5]:
# The weights are optimized w.r.t. training_error, while the hyperparameters
# are optimized w.r.t. the validation error (in this case, `error` evaluated
# on the validation set). The hypergradient method used is rf.ReverseHG
# (reverse mode).
hyper_dict = {
    error: [rho, eta, mu],
}
hyper_opt = rf.HyperOptimizer(
    optimizer,
    hyper_dict,
    method=rf.ReverseHG,
)

In [6]:
# helper for stochastic descent over the training set
ev_data = rf.ExampleVisiting(
    mnist.train,
    batch_size=2 ** 8,
    epochs=200,
)

# suppliers bound to the placeholders x and y, one per data partition
tr_suppl = ev_data.create_supplier(x, y)
val_supplier = mnist.validation.create_supplier(x, y)
test_supplier = mnist.test.create_supplier(x, y)

In [7]:
# Run all for some hyper-iterations
def run(hyper_iterations):
    """Run `hyper_iterations` hyper-iterations of hyperparameter optimization.

    Relies on the notebook globals `ev_data`, `hyper_opt`, `tr_suppl`,
    `val_supplier` and `constraints` defined in the previous cells.

    Each hyper-iteration calls `hyper_opt.initialize()` followed by
    `hyper_opt.run(...)` with `ev_data.T` as first argument, the training and
    validation suppliers, and the hyperparameter constraints.

    Nothing is printed here: in this notebook, progress is reported by the
    `rf.Saver.record` context that wraps the call to this function.

    :param hyper_iterations: number of hyper-iterations to execute.
    """
    # `with tf.Session():` installs the session as the default one, exactly as
    # `tf.Session().as_default()` did, but additionally closes it when the
    # block exits (also on error). `as_default()` alone never closes the
    # session, so its resources were only released whenever the object
    # happened to be garbage collected.
    with tf.Session():
        ev_data.generate_visiting_scheme()  # needed for remembering the example visited in forward pass
        for _ in range(hyper_iterations):
            hyper_opt.initialize()  # initializes all variables or reset weights to initial state
            hyper_opt.run(ev_data.T, train_feed_dict_supplier=tr_suppl,
                          val_feed_dict_suppliers=val_supplier,
                          hyper_constraints_ops=constraints)

In [8]:
# NOTE(review): 'Staring example' is probably a typo for 'Starting example';
# left unchanged here because it is the experiment name passed to rf.Saver.
saver = rf.Saver('Staring example', collect_data=False)

# A context to print some statistics: error and accuracy on the validation and
# test sets (in that order), then hyperparameters and hypergradients.
# If you execute again any cell containing the model construction, restart the
# notebook or reset the tensorflow graph in order to prevent errors due to
# tensor namings.
with saver.record(
        *[rf.Records.tensors(tensor_name, fd=('x', 'y', dataset), rec_name=rec_name)
          for tensor_name in ('error', 'accuracy')
          for rec_name, dataset in (('valid', mnist.validation), ('test', mnist.test))],
        rf.Records.hyperparameters(),
        rf.Records.hypergradients(),
):
    # this will take some time... run it for fewer hyper-iterations for a quicker look
    run(50)


Step 0                Values
------------------  --------
valid::error         2.30259
test::error          2.30259
valid::accuracy      0.09143
test::accuracy       0.09828
rho                  0.00000
eta                  0.01000
mu                   0.50000
grad::rho            0.00000
grad::eta            0.00000
grad::mu             0.00000
Elapsed time (sec)   0.00000
Step 1                Values
------------------  --------
valid::error         0.47800
test::error          0.41440
valid::accuracy      0.87286
test::accuracy       0.88319
rho                  0.00000
eta                  0.01100
mu                   0.50100
grad::rho            5.67378
grad::eta           -5.25401
grad::mu            -0.10519
Elapsed time (sec)  34.00000
Step 2                Values
------------------  --------
valid::error         0.47308
test::error          0.40872
valid::accuracy      0.87429
test::accuracy       0.88409
rho                  0.00000
eta                  0.01199
mu                   0.50200
grad::rho            5.58267
grad::eta           -4.43401
grad::mu            -0.09784
Elapsed time (sec)  69.00000
Step 3                 Values
------------------  ---------
valid::error          0.46891
test::error           0.40393
valid::accuracy       0.87571
test::accuracy        0.88506
rho                   0.00000
eta                   0.01297
mu                    0.50299
grad::rho             5.46390
grad::eta            -3.80384
grad::mu             -0.09169
Elapsed time (sec)  104.00000
Step 4                 Values
------------------  ---------
valid::error          0.46534
test::error           0.39986
valid::accuracy       0.87571
test::accuracy        0.88576
rho                   0.00000
eta                   0.01394
mu                    0.50398
grad::rho             5.32266
grad::eta            -3.30910
grad::mu             -0.08645
Elapsed time (sec)  139.00000
Step 5                 Values
------------------  ---------
valid::error          0.46226
test::error           0.39637
valid::accuracy       0.87714
test::accuracy        0.88641
rho                   0.00000
eta                   0.01489
mu                    0.50496
grad::rho             5.16324
grad::eta            -2.91322
grad::mu             -0.08193
Elapsed time (sec)  174.00000
Step 6                 Values
------------------  ---------
valid::error          0.45957
test::error           0.39335
valid::accuracy       0.87714
test::accuracy        0.88682
rho                   0.00000
eta                   0.01582
mu                    0.50594
grad::rho             4.98908
grad::eta            -2.59104
grad::mu             -0.07798
Elapsed time (sec)  209.00000
Step 7                 Values
------------------  ---------
valid::error          0.45722
test::error           0.39073
valid::accuracy       0.87857
test::accuracy        0.88725
rho                   0.00000
eta                   0.01673
mu                    0.50691
grad::rho             4.80299
grad::eta            -2.32485
grad::mu             -0.07449
Elapsed time (sec)  244.00000
Step 8                 Values
------------------  ---------
valid::error          0.45513
test::error           0.38843
valid::accuracy       0.87857
test::accuracy        0.88763
rho                   0.00000
eta                   0.01762
mu                    0.50787
grad::rho             4.60730
grad::eta            -2.10197
grad::mu             -0.07136
Elapsed time (sec)  278.00000
Step 9                 Values
------------------  ---------
valid::error          0.45328
test::error           0.38641
valid::accuracy       0.87714
test::accuracy        0.88786
rho                   0.00000
eta                   0.01849
mu                    0.50883
grad::rho             4.40388
grad::eta            -1.91310
grad::mu             -0.06853
Elapsed time (sec)  313.00000
Step 10                Values
------------------  ---------
valid::error          0.45163
test::error           0.38463
valid::accuracy       0.87714
test::accuracy        0.88822
rho                   0.00000
eta                   0.01933
mu                    0.50977
grad::rho             4.19430
grad::eta            -1.75135
grad::mu             -0.06596
Elapsed time (sec)  347.00000
Step 11                Values
------------------  ---------
valid::error          0.45014
test::error           0.38305
valid::accuracy       0.87714
test::accuracy        0.88869
rho                   0.00000
eta                   0.02016
mu                    0.51071
grad::rho             3.97987
grad::eta            -1.61151
grad::mu             -0.06359
Elapsed time (sec)  382.00000
Step 12                Values
------------------  ---------
valid::error          0.44880
test::error           0.38165
valid::accuracy       0.87714
test::accuracy        0.88883
rho                   0.00000
eta                   0.02096
mu                    0.51164
grad::rho             3.76173
grad::eta            -1.48957
grad::mu             -0.06140
Elapsed time (sec)  416.00000
Step 13                Values
------------------  ---------
valid::error          0.44760
test::error           0.38039
valid::accuracy       0.87714
test::accuracy        0.88913
rho                   0.00000
eta                   0.02174
mu                    0.51256
grad::rho             3.54076
grad::eta            -1.38245
grad::mu             -0.05936
Elapsed time (sec)  451.00000
Step 14                Values
------------------  ---------
valid::error          0.44650
test::error           0.37926
valid::accuracy       0.87714
test::accuracy        0.88935
rho                   0.00000
eta                   0.02250
mu                    0.51347
grad::rho             3.31781
grad::eta            -1.28768
grad::mu             -0.05746
Elapsed time (sec)  485.00000
Step 15                Values
------------------  ---------
valid::error          0.44550
test::error           0.37824
valid::accuracy       0.87714
test::accuracy        0.88951
rho                   0.00000
eta                   0.02324
mu                    0.51438
grad::rho             3.09350
grad::eta            -1.20333
grad::mu             -0.05567
Elapsed time (sec)  520.00000
Step 16                Values
------------------  ---------
valid::error          0.44459
test::error           0.37733
valid::accuracy       0.87714
test::accuracy        0.88965
rho                   0.00000
eta                   0.02396
mu                    0.51527
grad::rho             2.86849
grad::eta            -1.12783
grad::mu             -0.05399
Elapsed time (sec)  554.00000
Step 17                Values
------------------  ---------
valid::error          0.44376
test::error           0.37650
valid::accuracy       0.87571
test::accuracy        0.88966
rho                   0.00000
eta                   0.02465
mu                    0.51615
grad::rho             2.64326
grad::eta            -1.05991
grad::mu             -0.05239
Elapsed time (sec)  588.00000
Step 18                Values
------------------  ---------
valid::error          0.44300
test::error           0.37575
valid::accuracy       0.87571
test::accuracy        0.89003
rho                   0.00000
eta                   0.02533
mu                    0.51702
grad::rho             2.41825
grad::eta            -0.99852
grad::mu             -0.05088
Elapsed time (sec)  623.00000
Step 19                Values
------------------  ---------
valid::error          0.44230
test::error           0.37507
valid::accuracy       0.87571
test::accuracy        0.89012
rho                   0.00000
eta                   0.02598
mu                    0.51789
grad::rho             2.19384
grad::eta            -0.94279
grad::mu             -0.04944
Elapsed time (sec)  656.00000
Step 20                Values
------------------  ---------
valid::error          0.44166
test::error           0.37445
valid::accuracy       0.87571
test::accuracy        0.89021
rho                   0.00000
eta                   0.02662
mu                    0.51874
grad::rho             1.97038
grad::eta            -0.89201
grad::mu             -0.04807
Elapsed time (sec)  691.00000
Step 21                Values
------------------  ---------
valid::error          0.44107
test::error           0.37389
valid::accuracy       0.87857
test::accuracy        0.89030
rho                   0.00000
eta                   0.02723
mu                    0.51959
grad::rho             1.74815
grad::eta            -0.84556
grad::mu             -0.04676
Elapsed time (sec)  725.00000
Step 22                Values
------------------  ---------
valid::error          0.44052
test::error           0.37337
valid::accuracy       0.87857
test::accuracy        0.89058
rho                   0.00000
eta                   0.02783
mu                    0.52042
grad::rho             1.52742
grad::eta            -0.80294
grad::mu             -0.04551
Elapsed time (sec)  760.00000
Step 23                Values
------------------  ---------
valid::error          0.44001
test::error           0.37290
valid::accuracy       0.87857
test::accuracy        0.89070
rho                   0.00000
eta                   0.02841
mu                    0.52124
grad::rho             1.30838
grad::eta            -0.76371
grad::mu             -0.04431
Elapsed time (sec)  794.00000
Step 24                Values
------------------  ---------
valid::error          0.43955
test::error           0.37247
valid::accuracy       0.87857
test::accuracy        0.89088
rho                   0.00000
eta                   0.02897
mu                    0.52206
grad::rho             1.09124
grad::eta            -0.72749
grad::mu             -0.04316
Elapsed time (sec)  828.00000
Step 25                Values
------------------  ---------
valid::error          0.43911
test::error           0.37208
valid::accuracy       0.87714
test::accuracy        0.89096
rho                   0.00000
eta                   0.02952
mu                    0.52286
grad::rho             0.87616
grad::eta            -0.69397
grad::mu             -0.04205
Elapsed time (sec)  862.00000
Step 26                Values
------------------  ---------
valid::error          0.43871
test::error           0.37172
valid::accuracy       0.87714
test::accuracy        0.89118
rho                   0.00000
eta                   0.03005
mu                    0.52365
grad::rho             0.66327
grad::eta            -0.66285
grad::mu             -0.04098
Elapsed time (sec)  896.00000
Step 27                Values
------------------  ---------
valid::error          0.43834
test::error           0.37138
valid::accuracy       0.87857
test::accuracy        0.89125
rho                   0.00000
eta                   0.03056
mu                    0.52444
grad::rho             0.45269
grad::eta            -0.63391
grad::mu             -0.03996
Elapsed time (sec)  930.00000
Step 28                Values
------------------  ---------
valid::error          0.43799
test::error           0.37107
valid::accuracy       0.87857
test::accuracy        0.89123
rho                   0.00000
eta                   0.03105
mu                    0.52521
grad::rho             0.24452
grad::eta            -0.60693
grad::mu             -0.03897
Elapsed time (sec)  964.00000
Step 29                Values
------------------  ---------
valid::error          0.43766
test::error           0.37079
valid::accuracy       0.87857
test::accuracy        0.89137
rho                   0.00000
eta                   0.03153
mu                    0.52597
grad::rho             0.03884
grad::eta            -0.58172
grad::mu             -0.03802
Elapsed time (sec)  999.00000
Step 30                 Values
------------------  ----------
valid::error           0.43736
test::error            0.37053
valid::accuracy        0.87714
test::accuracy         0.89141
rho                    0.00000
eta                    0.03200
mu                     0.52673
grad::rho             -0.16431
grad::eta             -0.55811
grad::mu              -0.03709
Elapsed time (sec)  1033.00000
Step 31                 Values
------------------  ----------
valid::error           0.43708
test::error            0.37029
valid::accuracy        0.87714
test::accuracy         0.89146
rho                    0.00000
eta                    0.03245
mu                     0.52747
grad::rho             -0.36484
grad::eta             -0.53597
grad::mu              -0.03620
Elapsed time (sec)  1065.00000
Step 32                 Values
------------------  ----------
valid::error           0.43681
test::error            0.37007
valid::accuracy        0.87714
test::accuracy         0.89137
rho                    0.00000
eta                    0.03289
mu                     0.52820
grad::rho             -0.56276
grad::eta             -0.51517
grad::mu              -0.03534
Elapsed time (sec)  1097.00000
Step 33                 Values
------------------  ----------
valid::error           0.43657
test::error            0.36986
valid::accuracy        0.87714
test::accuracy         0.89146
rho                    0.00000
eta                    0.03332
mu                     0.52893
grad::rho             -0.75800
grad::eta             -0.49559
grad::mu              -0.03451
Elapsed time (sec)  1129.00000
Step 34                 Values
------------------  ----------
valid::error           0.43633
test::error            0.36967
valid::accuracy        0.87714
test::accuracy         0.89141
rho                    0.00000
eta                    0.03373
mu                     0.52964
grad::rho             -0.95057
grad::eta             -0.47712
grad::mu              -0.03370
Elapsed time (sec)  1162.00000
Step 35                 Values
------------------  ----------
valid::error           0.43612
test::error            0.36949
valid::accuracy        0.87714
test::accuracy         0.89149
rho                    0.00000
eta                    0.03413
mu                     0.53035
grad::rho             -1.14045
grad::eta             -0.45969
grad::mu              -0.03292
Elapsed time (sec)  1196.00000
Step 36                 Values
------------------  ----------
valid::error           0.43591
test::error            0.36933
valid::accuracy        0.87714
test::accuracy         0.89156
rho                    0.00000
eta                    0.03452
mu                     0.53104
grad::rho             -1.32767
grad::eta             -0.44320
grad::mu              -0.03216
Elapsed time (sec)  1226.00000
Step 37                 Values
------------------  ----------
valid::error           0.43572
test::error            0.36918
valid::accuracy        0.87714
test::accuracy         0.89153
rho                    0.00000
eta                    0.03490
mu                     0.53173
grad::rho             -1.51220
grad::eta             -0.42758
grad::mu              -0.03142
Elapsed time (sec)  1258.00000
Step 38                 Values
------------------  ----------
valid::error           0.43554
test::error            0.36903
valid::accuracy        0.87714
test::accuracy         0.89170
rho                    0.00001
eta                    0.03527
mu                     0.53241
grad::rho             -1.69405
grad::eta             -0.41277
grad::mu              -0.03071
Elapsed time (sec)  1289.00000
Step 39                 Values
------------------  ----------
valid::error           0.43534
test::error            0.36894
valid::accuracy        0.87714
test::accuracy         0.89173
rho                    0.00008
eta                    0.03563
mu                     0.53307
grad::rho             -1.63431
grad::eta             -0.40491
grad::mu              -0.03048
Elapsed time (sec)  1322.00000
Step 40                 Values
------------------  ----------
valid::error           0.43510
test::error            0.36902
valid::accuracy        0.87714
test::accuracy         0.89175
rho                    0.00017
eta                    0.03597
mu                     0.53374
grad::rho             -0.71248
grad::eta             -0.41806
grad::mu              -0.03185
Elapsed time (sec)  1354.00000
Step 41                 Values
------------------  ----------
valid::error           0.43492
test::error            0.36922
valid::accuracy        0.87571
test::accuracy         0.89187
rho                    0.00023
eta                    0.03632
mu                     0.53440
grad::rho              0.47116
grad::eta             -0.43383
grad::mu              -0.03343
Elapsed time (sec)  1387.00000
Step 42                 Values
------------------  ----------
valid::error           0.43480
test::error            0.36938
valid::accuracy        0.87429
test::accuracy         0.89185
rho                    0.00024
eta                    0.03665
mu                     0.53506
grad::rho              1.27779
grad::eta             -0.43913
grad::mu              -0.03422
Elapsed time (sec)  1421.00000
Step 43                 Values
------------------  ----------
valid::error           0.43465
test::error            0.36931
valid::accuracy        0.87429
test::accuracy         0.89190
rho                    0.00021
eta                    0.03699
mu                     0.53572
grad::rho              1.36042
grad::eta             -0.43105
grad::mu              -0.03395
Elapsed time (sec)  1455.00000
Step 44                 Values
------------------  ----------
valid::error           0.43445
test::error            0.36901
valid::accuracy        0.87571
test::accuracy         0.89194
rho                    0.00015
eta                    0.03731
mu                     0.53638
grad::rho              0.71862
grad::eta             -0.41094
grad::mu              -0.03269
Elapsed time (sec)  1490.00000
Step 45                 Values
------------------  ----------
valid::error           0.43428
test::error            0.36865
valid::accuracy        0.87571
test::accuracy         0.89198
rho                    0.00011
eta                    0.03763
mu                     0.53704
grad::rho             -0.31706
grad::eta             -0.38320
grad::mu              -0.03079
Elapsed time (sec)  1524.00000
Step 46                 Values
------------------  ----------
valid::error           0.43418
test::error            0.36841
valid::accuracy        0.87714
test::accuracy         0.89214
rho                    0.00011
eta                    0.03795
mu                     0.53768
grad::rho             -1.14520
grad::eta             -0.35884
grad::mu              -0.02911
Elapsed time (sec)  1558.00000
Step 47                 Values
------------------  ----------
valid::error           0.43404
test::error            0.36833
valid::accuracy        0.87857
test::accuracy         0.89213
rho                    0.00016
eta                    0.03825
mu                     0.53832
grad::rho             -1.22059
grad::eta             -0.35014
grad::mu              -0.02867
Elapsed time (sec)  1592.00000
Step 48                 Values
------------------  ----------
valid::error           0.43387
test::error            0.36838
valid::accuracy        0.87571
test::accuracy         0.89213
rho                    0.00022
eta                    0.03855
mu                     0.53896
grad::rho             -0.52621
grad::eta             -0.35675
grad::mu              -0.02950
Elapsed time (sec)  1626.00000
Step 49                 Values
------------------  ----------
valid::error           0.43374
test::error            0.36852
valid::accuracy        0.87857
test::accuracy         0.89216
rho                    0.00026
eta                    0.03885
mu                     0.53960
grad::rho              0.41368
grad::eta             -0.36630
grad::mu              -0.03058
Elapsed time (sec)  1660.00000
Step 50                 Values
------------------  ----------
valid::error           0.43364
test::error            0.36860
valid::accuracy        0.87857
test::accuracy         0.89217
rho                    0.00026
eta                    0.03914
mu                     0.54023
grad::rho              0.99658
grad::eta             -0.36845
grad::mu              -0.03104
Elapsed time (sec)  1694.00000

In [ ]: