In [4]:
import helpers as hp
import tensorflow as tf
import AutoBrake as ab

In [5]:
data = hp.get_data()
training_inds, test_inds = hp.get_train_test_split(70000, 10000)

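The two calls above come from a local helpers module whose implementation isn't shown in this notebook. As a rough sketch of the interface the later cells appear to assume (hypothetical; the real helpers.py may differ), get_data() would return a dict of numpy arrays keyed 'input_glimpses', 'input_gazes', 'input_sequences', and 'outputs', while the split and minibatch helpers would just shuffle and slice index arrays:

import numpy as np

def get_train_test_split(n_train, n_test, seed=0):
    # Hypothetical: shuffle all example indices once, then carve out
    # disjoint train / test index arrays of the requested sizes.
    perm = np.random.RandomState(seed).permutation(n_train + n_test)
    return perm[:n_train], perm[n_train:n_train + n_test]

def minibatch(indices, batch_size, n_total):
    # Hypothetical: reshuffle the training indices each epoch and split them
    # into batches of at most batch_size (70000 indices at 10000 per batch
    # gives 7 batches, matching the "Batches: 7" line in the output below).
    shuffled = np.random.permutation(indices)
    return [shuffled[i:i + batch_size] for i in range(0, n_total, batch_size)]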
In [ ]:
with tf.Graph().as_default():
    # Slice out the held-out test examples (only the first glimpse channel is used)
    test_glimpses = data['input_glimpses'][test_inds, :, :, :1]
    test_gazes = data['input_gazes'][test_inds]
    test_seq = data['input_sequences'][test_inds]
    test_output = data['outputs'][test_inds]

    CONVNET_FILE_NAME = "model/convnet_.ckpt"
    
    with tf.Session() as sess:
        # Define ops
        infer_op = ab.inference()
        loss_op = ab.loss(infer_op)
        train_op = ab.train(infer_op)
        acc_op = ab.accuracy(infer_op)
        
        # Initialize all variables; the checkpoint restore below then overwrites them
        init = tf.global_variables_initializer()
        sess.run(init)
        saver = tf.train.Saver()
        saver.restore(sess, CONVNET_FILE_NAME)
        print("Restored model to full power...")
        # Baseline evaluation of the restored model on the held-out test set
        loss, acc = sess.run([loss_op, acc_op],
                             feed_dict={
                                 "images:0": test_glimpses,
                                 "gazes:0": test_gazes,
                                 "brake_sequences:0": test_seq,
                                 "expected_braking:0": test_output
                             })
        print("\tCross-entropy: {:.3f}\tAccuracy: {:.3f}".format(loss, acc))
        for epoch in range(100):
            batches = hp.minibatch(training_inds, 10000, len(training_inds))
            print("(Epoch {0}) Batches: {1}".format(epoch + 1, len(batches)))
            for batch_num, index_batch in enumerate(batches):
                print("\tProcessing batch {0}".format(batch_num + 1))
                glimpses = data['input_glimpses'][index_batch, :, :, :1]
                gazes = data['input_gazes'][index_batch]
                seq = data['input_sequences'][index_batch]
                output = data['outputs'][index_batch]

                # Single optimization step on this minibatch
                sess.run(train_op,
                         feed_dict={
                             "images:0": glimpses,
                             "gazes:0": gazes,
                             "brake_sequences:0": seq,
                             "expected_braking:0": output
                         })
            
            # Re-evaluate on the held-out test set after each epoch
            loss, acc = sess.run([loss_op, acc_op],
                                 feed_dict={
                                     "images:0": test_glimpses,
                                     "gazes:0": test_gazes,
                                     "brake_sequences:0": test_seq,
                                     "expected_braking:0": test_output
                                 })
            print("\tCross-entropy: {:.3f}\tAccuracy: {:.3f}".format(loss, acc))
            
            save_path = saver.save(sess, CONVNET_FILE_NAME)
            print("\tModel saved in file: %s" % save_path)


Restored model to full power...
	Cross-entropy: 11.287	Accuracy: 0.510
(Epoch 1) Batches: 7
	Processing batch 1
	Processing batch 2
	Processing batch 3

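The feed_dict keys in the cell above ("images:0", "gazes:0", "brake_sequences:0", "expected_braking:0") resolve tensors by name, which means the graph built by AutoBrake must create placeholders with exactly those names. A minimal sketch of how that wiring might look in TF 1.x style (hypothetical; the shapes and the stand-in linear layer are assumptions, not the project's actual architecture):

import tensorflow as tf

def inference():
    # Placeholders must carry these names so the "<name>:0" feed_dict keys resolve.
    images = tf.placeholder(tf.float32, [None, 32, 32, 1], name="images")        # assumed shape
    gazes = tf.placeholder(tf.float32, [None, 2], name="gazes")                  # assumed shape
    sequences = tf.placeholder(tf.float32, [None, 10], name="brake_sequences")   # assumed shape
    # Stand-in for the real conv / recurrent stack: flatten and combine all inputs.
    flat = tf.concat([tf.reshape(images, [-1, 32 * 32]), gazes, sequences], axis=1)
    weights = tf.Variable(tf.truncated_normal([32 * 32 + 2 + 10, 2], stddev=0.01))
    bias = tf.Variable(tf.zeros([2]))
    return tf.matmul(flat, weights) + bias  # brake / no-brake logits

def loss(logits):
    labels = tf.placeholder(tf.int64, [None], name="expected_braking")
    xent = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)
    return tf.reduce_mean(xent)

ab.train and ab.accuracy would presumably wrap an optimizer step and an argmax comparison over the same logits in the same fashion.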