vanilla-DNN

Author: Justin Tan

Vanilla neural network. Do anything from MNIST to signal classification.

Update 20/03: Added batch normalization, TensorBoard visualization

Update 19/06: Added cosine annealing, exponential moving average

Update 22/09: Moved input pipeline to tfrecords


In [1]:
import tensorflow as tf
import numpy as np
import glob, time, os
import selu
from diagnostics import *

class config(object):
    # Set network parameters
    mode = 'pi0veto'
    channel = 'Bu2Xsy'
    nFeatures = 44
    keep_prob = 0.72
    num_epochs = 2
    batch_size = 512
    n_layers = 5
    hidden_layer_nodes = [1024, 1024, 512, 512, 256]
    ema_decay = 0.999
    learning_rate = 1e-4
    cycles = 8 # Number of annealing cycles
    n_classes = 2
    builder = 'selu'

class directories(object):
    train ='/home/jtan/gpu/jtan/spark/spark2tf/examples/example_train.tfrecords' #'/var/local/tmp/tfrecords/example_train.tfrecords'
    test = '/home/jtan/gpu/jtan/spark/spark2tf/examples/example_test.tfrecords'#'/var/local/tmp/tfrecords/example_test.tfrecords'
    tensorboard = 'tensorboard'
    checkpoints = 'checkpoints'
    
architecture = '{} - {} | Layers: {} | Dropout: {} | Base LR: {} | Epochs: {}'.format(
    config.channel, config.mode, config.n_layers, config.keep_prob, config.learning_rate, config.num_epochs)
nTrainExamples=5520000#sum(1 for fn in glob.glob(directories.train+'/*') for record in tf.python_io.tf_record_iterator(fn))
get_available_gpus()


Available GPUs:
['/device:GPU:0', '/device:GPU:1', '/device:GPU:2', '/device:GPU:3']

In [2]:
def dataset_train(dataDirectory, batchSize, numEpochs, nFeatures, training=True):
    filenames = glob.glob('{}/part*'.format(dataDirectory))
    dataset = tf.contrib.data.TFRecordDataset(filenames)

    # Extract data from `tf.Example` protocol buffer
    def parser(record, batchSize=128):
        keys_to_features = {
            "features": tf.FixedLenFeature([nFeatures], tf.float32),
            "labels": tf.FixedLenFeature((), tf.float32,
            default_value=tf.zeros([], dtype=tf.float32)),
        }
        parsed = tf.parse_single_example(record, keys_to_features)
        label = tf.cast(parsed['labels'], tf.int32)

        return parsed['features'], label

    # Transform into feature, label tensor pair
    dataset = dataset.map(parser)
    dataset = dataset.shuffle(buffer_size=16384)
    dataset = dataset.batch(batchSize)
    dataset = dataset.repeat(numEpochs) if training else dataset

    return dataset

def dense_builder(x, shape, name, keep_prob, training=True, actv=tf.nn.relu):
    init=tf.contrib.layers.xavier_initializer()
    kwargs = {'center': True, 'scale': True, 'training': training, 'fused': True, 'renorm': True}

    with tf.variable_scope(name, initializer=init) as scope:
        layer = tf.layers.dense(x, units=shape[1], activation=actv)
        bn = tf.layers.batch_normalization(layer, **kwargs)
        layer_out = tf.layers.dropout(bn, keep_prob, training=training)

    return layer_out

def selu_builder(x, shape, name, keep_prob, training=True):
    init = tf.contrib.layers.variance_scaling_initializer(factor=1.0, mode='FAN_IN')

    with tf.variable_scope(name) as scope:
        W = tf.get_variable("weights", shape = shape, initializer=init)
        b = tf.get_variable("biases", shape = [shape[1]], initializer=tf.random_normal_initializer(stddev=0.1))
        actv = selu.selu(tf.add(tf.matmul(x, W), b))
        layer_output = selu.dropout_selu(actv, rate=1-keep_prob, training=training)

    return layer_output

def dense_model(x, n_layers, hidden_layer_nodes, keep_prob, builder=selu_builder, reuse=False, training=True):
    # Extensible dense model
    SELU_initializer = tf.contrib.layers.variance_scaling_initializer(factor=1.0, mode='FAN_IN')
    init = SELU_initializer if builder==selu_builder else tf.contrib.layers.xavier_initializer()
    assert n_layers == len(hidden_layer_nodes), 'Specified layer nodes and number of layers do not correspond.'
    layers = [x]

    with tf.variable_scope('dense_model', reuse=reuse):
        hidden_0 = builder(x, shape=[config.nFeatures, hidden_layer_nodes[0]], name='hidden0',
                                keep_prob = keep_prob, training=training)
        layers.append(hidden_0)
        for n in range(0,n_layers-1):
            hidden_n = builder(layers[-1], shape=[hidden_layer_nodes[n], hidden_layer_nodes[n+1]], name='hidden{}'.format(n+1),
                                keep_prob=keep_prob, training=training)
            layers.append(hidden_n)

        readout = tf.layers.dense(hidden_n, units=config.n_classes, kernel_initializer=init)

    return readout

def dense_SELU(x, n_layers, hidden_layer_nodes, keep_prob, reuse=False,
    training=True, actv=selu.selu):
    SELU_initializer = tf.contrib.layers.variance_scaling_initializer(factor=1.0, mode='FAN_IN')

    with tf.variable_scope('seluNet', reuse=reuse):
        l0 = tf.layers.dense(x, units=hidden_layer_nodes[0], activation=actv,
        kernel_initializer=SELU_initializer)
        d0 = selu.dropout_selu(l0, rate=1-keep_prob, training=training)

        l1 = tf.layers.dense(d0, units=hidden_layer_nodes[1], activation=actv,
        kernel_initializer=SELU_initializer)
        d1 = selu.dropout_selu(l1, rate=1-keep_prob, training=training)

        l2 = tf.layers.dense(d1, units=hidden_layer_nodes[2], activation=actv,
        kernel_initializer=SELU_initializer)
        d2 = selu.dropout_selu(l2, rate=1-keep_prob, training=training)

        l3 = tf.layers.dense(d2, units=hidden_layer_nodes[3], activation=actv,
        kernel_initializer=SELU_initializer)
        d3 = selu.dropout_selu(l3, rate=1-keep_prob, training=training)

        l4 = tf.layers.dense(d3, units=hidden_layer_nodes[4], activation=actv,
        kernel_initializer=SELU_initializer)
        d4 = selu.dropout_selu(l4, rate=1-keep_prob, training=training)

        # Readout layer
        readout = tf.layers.dense(d4, units=config.n_classes,
        kernel_initializer=SELU_initializer)

    return readout

def cosine_anneal(initial_lr, t, T, M):
    from math import ceil
    beta = initial_lr/2 * (np.cos(np.pi* (t % ceil(T/M))/ceil(T/M)) + 1)
    return beta

In [3]:
class vanillaDNN():
    # Builds the computational graph
    def __init__(self, config, training=True, cyclical=False):
        
        self.global_step = tf.Variable(0, trainable=False)
        self.handle = tf.placeholder(tf.string, shape=[])
        self.training_phase = tf.placeholder(tf.bool)
        self.beta = tf.placeholder(tf.float32) if cyclical else config.learning_rate
#         self.beta = tf.train.exponential_decay(config.learning_rate, self.global_step, 
#                                                decay_steps = config.steps_per_epoch, decay_rate = config.lr_epoch_decay, staircase=True)
   
        # Feedable iterator defined by handle placeholder and structure.
        trainDataset = dataset_train(directories.train, batchSize=config.batch_size, numEpochs=config.num_epochs, nFeatures=config.nFeatures)
        testDataset = dataset_train(directories.test, batchSize=config.batch_size, numEpochs=config.num_epochs, nFeatures=config.nFeatures)
        self.iterator = tf.contrib.data.Iterator.from_string_handle(self.handle, trainDataset.output_types, 
                                                               trainDataset.output_shapes)

        
        self.train_iterator = trainDataset.make_one_shot_iterator()
        self.test_iterator = testDataset.make_one_shot_iterator()

        self.example, self.label = self.iterator.get_next()
        # self.readout = dense_SELU(self.example, config.n_layers, [1024, 1024, 512, 512, 256], config.keep_prob, training=self.training_phase)
        self.readout = dense_model(self.example, config.n_layers, config.hidden_layer_nodes, config.keep_prob, training=self.training_phase)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        self.cross_entropy = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits = self.readout, labels = self.label))

        with tf.control_dependencies(update_ops):
            # Ensures that we execute the update_ops before performing the train_step
            self.opt_op = tf.train.AdamOptimizer(self.beta).minimize(self.cross_entropy, name = 'optimizer',
                                                                     global_step = self.global_step)


        self.ema = tf.train.ExponentialMovingAverage(decay = config.ema_decay, num_updates = self.global_step)
        maintain_averages_op = self.ema.apply(tf.trainable_variables())
        
        with tf.control_dependencies([self.opt_op]):
            self.train_op = tf.group(maintain_averages_op)

        # Evaluation metrics
        self.p = tf.nn.softmax(self.readout)
        correct_prediction = tf.equal(tf.cast(tf.argmax(self.readout, 1), tf.int32), self.label)
        self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        _, self.auc_op = tf.metrics.auc(predictions = tf.argmax(self.readout,1), labels = self.label, num_thresholds = 1024)
        tf.summary.scalar('accuracy', self.accuracy)
        tf.summary.scalar('auc', self.auc_op)
        tf.summary.scalar('learning_rate', self.beta)
        tf.summary.scalar('cross_entropy', self.cross_entropy)
        
        self.merge_op = tf.summary.merge_all()
        self.train_writer = tf.summary.FileWriter(
            os.path.join(directories.tensorboard, 'train_{}'.format(time.strftime('%d-%m_%I:%M'))), graph = tf.get_default_graph())
        self.test_writer = tf.summary.FileWriter(
            os.path.join(directories.tensorboard, 'test_{}'.format(time.strftime('%d-%m_%I:%M'))))

    def predict(self, ckpt):
        pin_cpu = tf.ConfigProto(allow_soft_placement=True, log_device_placement=True, device_count = {'GPU':0})
        start_time = time.time()
        
        # Restore the moving average version of the learned variables for eval.
        variables_to_restore = self.ema.variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)
        valDataset = dataset_train(directories.test, batchSize=262144, numEpochs=config.num_epochs, 
                                    nFeatures=config.nFeatures, training=False)

        val_iterator = valDataset.make_one_shot_iterator()
        concatLabels = tf.cast(self.label, tf.int32)
        concatPreds = tf.cast(tf.argmax(self.readout,1), tf.int32)
        concatOutput = self.p[:,1]

        with tf.Session(config=pin_cpu) as sess:
            # Initialize variables
            init_op = tf.global_variables_initializer()
            sess.run(init_op)
            sess.run(tf.local_variables_initializer())
            assert (ckpt.model_checkpoint_path), 'Missing checkpoint file!'    
            saver.restore(sess, ckpt.model_checkpoint_path)
            print('{} restored.'.format(ckpt.model_checkpoint_path))
            val_handle = sess.run(val_iterator.string_handle())
            labels, preds, outputs = [], [], []
            while True:
                try:
                    l, p, o = sess.run([concatLabels, concatPreds, concatOutput], 
                                       feed_dict = {vDNN.training_phase: False, vDNN.handle: val_handle})
                    labels.append(l), preds.append(p), outputs.append(o)
                except tf.errors.OutOfRangeError:
                    labels, preds, outputs = np.concatenate(labels), np.concatenate(preds), np.concatenate(outputs)
                    break
            acc = np.mean(np.equal(labels,preds))
            print("Validation accuracy: {:.3f}".format(acc))
            
            plot_ROC_curve(network_output=outputs, y_true=labels, identifier=config.mode+config.channel,
                           meta=architecture + ' | Test accuracy: {:.3f}'.format(acc))
            delta_t = time.time() - start_time
            print("Inference complete. Duration: %g s" %(delta_t))
            
            return labels, preds, outputs

In [4]:
def train(config, restore = False):
    # Executes training operations
    print('Architecture: {}'.format(architecture))
    vDNN = vanillaDNN(config, training=True)
    start_time = time.time()
    global_step, n_checkpoints, v_auc_best = 0, 0, 0.
    saver = tf.train.Saver()
    ckpt = tf.train.get_checkpoint_state(directories.checkpoints)

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True)) as sess:
        # Initialize variables
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        train_handle = sess.run(vDNN.train_iterator.string_handle())
        test_handle = sess.run(vDNN.test_iterator.string_handle())
        
        if restore and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            print('{} restored.'.format(ckpt.model_checkpoint_path))
        
        while True:
            try:
                # Run X steps on training dataset
                for _ in range(4096):
                    sess.run(vDNN.train_op, feed_dict = {vDNN.training_phase: True, vDNN.handle: train_handle})
                    global_step+=1

                # Single pass over validation dataset
                for _ in range(1):
                    epoch, v_auc_best = run_diagnostics(vDNN, config, directories, sess, saver, train_handle, test_handle, 
                                                        global_step, nTrainExamples, start_time, v_auc_best, n_checkpoints)
                    
            except tf.errors.OutOfRangeError:
                break

        save_path = saver.save(sess, os.path.join(directories.checkpoints, 'vDNN_{}_{}_end.ckpt'.format(config.mode, config.channel)), 
                               global_step=epoch)
    
    print("Training Complete. Model saved to file: {} Time elapsed: {:.3f} s".format(save_path, time.time()-start_time))

In [6]:
train(config)


Architecture: Bu2Xsy - pi0veto | Layers: 5 | Dropout: 0.72 | Base LR: 0.0001 | Epochs: 2
Epoch 0, Step 2097152 | Training Acc: 0.896 | Test Acc: 0.734 | Test Loss: 1.384 | Test AUC: 0.832 | Rate: 839 examples/s (121.75 s) [*]
Epoch 0, Step 4194304 | Training Acc: 0.832 | Test Acc: 0.744 | Test Loss: 1.275 | Test AUC: 0.771 | Rate: 6494 examples/s (241.07 s) 
Epoch 1, Step 771456 | Training Acc: 0.811 | Test Acc: 0.799 | Test Loss: 0.869 | Test AUC: 0.777 | Rate: 6476 examples/s (361.55 s) 
Epoch 1, Step 2868608 | Training Acc: 0.807 | Test Acc: 0.773 | Test Loss: 1.078 | Test AUC: 0.797 | Rate: 6213 examples/s (482.85 s) 
Epoch 1, Step 4965760 | Training Acc: 0.953 | Test Acc: 0.791 | Test Loss: 0.811 | Test AUC: 0.807 | Rate: 6415 examples/s (605.07 s) 
Training Complete. Model saved to file: checkpoints/vDNN_pi0veto_Bu2Xsy_end.ckpt-1 Time elapsed: 606.224 s

Making Predictions

Classification on a new instance is given by the softmax of the output of the final readout layer.


In [5]:
ckpt = tf.train.get_checkpoint_state(directories.checkpoints)
vDNN = vanillaDNN(config, training = False)
labels, preds, output = vDNN.predict(ckpt)


INFO:tensorflow:Restoring parameters from checkpoints/vDNN_pi0veto_Bu2Xsy_end.ckpt-1
checkpoints/vDNN_pi0veto_Bu2Xsy_end.ckpt-1 restored.
Validation accuracy: 0.848
AUC: 0.8942641055827576
Plotting signal efficiency versus background rejection
<matplotlib.figure.Figure at 0x2ba1f28cf358>
Inference complete. Duration: 26.0885 s

In [8]:
import seaborn as sns
import matplotlib.pyplot as plt

In [13]:
output.tolist()


Out[13]:
[0.9907813668251038,
 0.05039473995566368,
 0.884621262550354,
 0.9996405839920044,
 0.00015342087135650218,
 0.6280800104141235,
 0.0582210011780262,
 0.9877579212188721,
 0.9872825145721436,
 0.9987074136734009,
 0.16737699508666992,
 0.7736568450927734,
 5.1007084725118e-30,
 0.5163956880569458,
 3.3337519349902964e-14,
 0.7787963151931763,
 0.11351200938224792,
 0.8796751499176025,
 0.9239558577537537,
 0.9999392032623291,
 0.8913349509239197,
 0.99637371301651,
 0.990965723991394,
 4.648340450330579e-07,
 0.9956802129745483,
 0.028188219293951988,
 0.9954485297203064,
 0.9992848038673401,
 0.037943340837955475,
 0.7568750381469727,
 0.9053255319595337,
 0.9674948453903198,
 0.7167554497718811,
 5.4415402089172525e-21,
 0.7716109156608582,
 0.002667702967301011,
 0.5553229451179504,
 7.811603808249856e-08,
 0.5511792898178101,
 1.7536036010312653e-29,
 0.0005530606140382588,
 0.9588214159011841,
 0.9998102784156799,
 0.012071918696165085,
 0.8917171359062195,
 0.9365729689598083,
 0.720715343952179,
 0.991726279258728,
 0.9827690720558167,
 0.9945017099380493,
 0.9615130424499512,
 0.9285325407981873,
 0.9885231852531433,
 0.7459745407104492,
 0.6195662617683411,
 0.9661865830421448,
 6.131432346579276e-20,
 0.9998378753662109,
 2.357361791204352e-32,
 0.9926441311836243,
 0.9999946355819702,
 0.8710817694664001,
 3.0064913339005482e-21,
 4.709630463395342e-09,
 0.9944555163383484,
 0.9858365058898926,
 0.8135520815849304,
 0.9779592752456665,
 1.535405891536339e-23,
 0.544379711151123,
 0.2946131229400635,
 0.7425923347473145,
 0.9590376019477844,
 0.002682980615645647,
 2.199954782096822e-23,
 0.3929561376571655,
 0.9261753559112549,
 0.9779912829399109,
 0.38148224353790283,
 0.0014391717268154025,
 0.7791245579719543,
 0.1015225350856781,
 0.9811270833015442,
 0.9997485280036926,
 4.265167490296196e-15,
 0.5008424520492554,
 0.9680300951004028,
 0.41957467794418335,
 0.9997144341468811,
 0.9961099028587341,
 0.999764621257782,
 0.9973865151405334,
 0.7848142385482788,
 0.0,
 0.810299277305603,
 0.9568296670913696,
 9.646196122048423e-05,
 0.831780195236206,
 0.9904994368553162,
 7.591350481561676e-07,
 0.9233107566833496,
 3.712369270043913e-19,
 0.0,
 6.76633055141096e-11,
 0.6192763447761536,
 1.265534343986019e-08,
 0.0003953269624616951,
 0.0022120261564850807,
 0.9995892643928528,
 0.26718589663505554,
 2.9033130259531914e-24,
 7.081282995790672e-22,
 0.9375547766685486,
 0.9570314288139343,
 0.9991719722747803,
 0.9942975640296936,
 0.9739757776260376,
 0.9849735498428345,
 3.475628110779705e-19,
 0.759407103061676,
 0.00010810273670358583,
 0.8955947756767273,
 0.9749541878700256,
 0.9996697902679443,
 0.986570417881012,
 0.0015807270538061857,
 0.8418996334075928,
 1.2221555407736584e-28,
 4.039817369516641e-09,
 0.007675881497561932,
 0.9833166003227234,
 1.0410824415885145e-06,
 0.6593395471572876,
 0.00015444538439624012,
 4.6040190682106186e-07,
 0.9491873979568481,
 7.218450264190324e-07,
 0.9489790201187134,
 0.03673708066344261,
 0.0006113869021646678,
 0.12209878116846085,
 0.659618616104126,
 0.05909285321831703,
 0.9822865128517151,
 0.801372766494751,
 0.9777550101280212,
 0.8866784572601318,
 0.4223741292953491,
 1.044875644340415e-11,
 0.004027624148875475,
 4.126535143066302e-16,
 0.8756007552146912,
 1.8324218039748393e-19,
 0.7082085609436035,
 0.7645165324211121,
 0.9793766736984253,
 0.9860014915466309,
 0.9940457344055176,
 0.7592846155166626,
 0.008363292552530766,
 0.012399738654494286,
 0.5586126446723938,
 2.7428434550740358e-15,
 0.858317494392395,
 2.881017202582825e-08,
 0.9784666895866394,
 0.000822880887426436,
 0.745221734046936,
 0.2434251308441162,
 0.0,
 1.3136339238606354e-10,
 0.9978982210159302,
 0.9371930360794067,
 0.9761151671409607,
 0.9980661273002625,
 0.8493164777755737,
 0.7025588750839233,
 0.043169908225536346,
 0.004673896823078394,
 0.9960021376609802,
 0.8119084239006042,
 5.500857702228471e-20,
 0.8916662335395813,
 1.7380834836249975e-17,
 0.7966852188110352,
 2.2194049265067406e-08,
 9.630909971747315e-07,
 0.7602645754814148,
 0.0015051465015858412,
 0.99992835521698,
 0.013933666981756687,
 0.8558840155601501,
 2.246723340704193e-08,
 0.22200198471546173,
 4.323469699784197e-22,
 0.0038578093517571688,
 0.007960334420204163,
 0.9598391056060791,
 1.0415649410163055e-18,
 0.0027555618435144424,
 0.8813332319259644,
 0.00028592764283530414,
 0.015043025836348534,
 0.9652674794197083,
 0.9958525896072388,
 0.9765846133232117,
 0.9998169541358948,
 0.060206569731235504,
 0.9135982394218445,
 1.060915591488687e-10,
 0.999995231628418,
 3.519766938961836e-33,
 0.4028857350349426,
 0.999729573726654,
 4.999218072043732e-05,
 0.9981227517127991,
 8.689795998689807e-19,
 0.9998879432678223,
 0.9990764856338501,
 8.768494069701582e-16,
 0.9732479453086853,
 0.02382492460310459,
 0.00021827328600920737,
 6.342576671158895e-05,
 3.9988745470509457e-07,
 0.0015175238950178027,
 0.5931071639060974,
 0.9962588548660278,
 0.11453244090080261,
 0.00016353891987819225,
 0.9992496371269226,
 0.11318626999855042,
 0.9999903440475464,
 0.9986196756362915,
 0.9897170066833496,
 9.751202014740556e-05,
 0.9349307417869568,
 0.9822587966918945,
 0.2402026206254959,
 0.40713170170783997,
 1.548848871965429e-14,
 9.87343355518533e-07,
 0.00387214170768857,
 1.1867580529401067e-10,
 3.91834182664752e-05,
 2.42371653712995e-33,
 2.1981691165517637e-22,
 0.00024620257318019867,
 0.8016248345375061,
 0.0031905807554721832,
 1.0306398393234521e-27,
 0.42112651467323303,
 0.0,
 6.093631565834666e-21,
 6.896569243311998e-20,
 0.6254987120628357,
 5.1007084725118e-30,
 0.8510136008262634,
 0.9975417852401733,
 0.6157392859458923,
 0.8706943392753601,
 0.5684421062469482,
 0.04249873012304306,
 0.8336203098297119,
 0.03498150035738945,
 0.9983281493186951,
 6.194324669195339e-06,
 0.029114678502082825,
 0.8253659605979919,
 0.9998766183853149,
 0.802760124206543,
 0.9784427881240845,
 0.8102816939353943,
 0.9986832737922668,
 0.9998313188552856,
 0.000642026774585247,
 0.8242563605308533,
 4.0463295908343966e-11,
 0.9777751564979553,
 0.0008019189699552953,
 0.4455636143684387,
 4.897754441408432e-28,
 0.2920900881290436,
 0.5998949408531189,
 0.7792483568191528,
 0.01155032403767109,
 0.12851369380950928,
 0.006178895942866802,
 0.017612027004361153,
 2.5370134153490653e-06,
 0.9652674794197083,
 0.999934196472168,
 0.8710686564445496,
 1.1115166133422033e-13,
 3.6973465955547624e-11,
 0.9999966621398926,
 0.8651831150054932,
 6.030527297840317e-20,
 0.9559729099273682,
 0.717013418674469,
 0.9944887161254883,
 0.8615578413009644,
 0.9585404396057129,
 0.7410907745361328,
 4.062527165666006e-24,
 0.9769747257232666,
 3.0779008231007765e-09,
 1.5822284368472534e-33,
 2.9150971386115998e-05,
 0.7943753004074097,
 0.0,
 4.601688203820231e-07,
 0.08663008362054825,
 0.9981878399848938,
 2.227686397209422e-32,
 2.495349752275839e-30,
 0.7230226993560791,
 0.05130835995078087,
 0.5764269828796387,
 0.6842902302742004,
 0.011941835284233093,
 0.9875714778900146,
 0.9560330510139465,
 0.06145534664392471,
 0.9999697208404541,
 0.24742081761360168,
 0.8225245475769043,
 0.9450522661209106,
 0.7825350165367126,
 1.4858706157578673e-36,
 0.9970955848693848,
 0.9771530628204346,
 0.0,
 0.9460485577583313,
 0.9786696434020996,
 0.7848706245422363,
 0.9925262928009033,
 0.0110146040096879,
 0.9892610311508179,
 0.003403616836294532,
 0.8734604716300964,
 0.2094828486442566,
 0.3637388050556183,
 0.3362067639827728,
 0.2973164916038513,
 0.9588133096694946,
 0.8259446024894714,
 0.9992200136184692,
 0.9959646463394165,
 0.4850097596645355,
 0.27458685636520386,
 5.999570467312303e-21,
 0.0,
 0.5597532391548157,
 0.9593170881271362,
 0.9932804703712463,
 0.9447270035743713,
 0.9960491061210632,
 6.608412377318018e-07,
 0.46863484382629395,
 0.030984103679656982,
 0.9700325727462769,
 0.00029735316638834774,
 0.9592794179916382,
 0.23963457345962524,
 0.9677055478096008,
 0.9664562344551086,
 7.499497587559745e-05,
 0.9694162607192993,
 0.9999785423278809,
 0.9488241076469421,
 4.527186749224718e-24,
 0.03387337550520897,
 0.9999394416809082,
 1.2578545982577793e-09,
 0.01495937630534172,
 0.2101057916879654,
 0.8264070153236389,
 0.42660149931907654,
 5.8796466460538036e-21,
 6.341258267639205e-05,
 0.547993004322052,
 0.003332463325932622,
 0.0,
 0.9590609669685364,
 0.0,
 0.03521322086453438,
 0.021296463906764984,
 0.004806661047041416,
 0.833922803401947,
 0.9999607801437378,
 0.36968740820884705,
 0.14276263117790222,
 0.7532050013542175,
 0.0010272098006680608,
 0.5707389116287231,
 0.9652359485626221,
 7.339324148425419e-10,
 0.9999300241470337,
 0.9575607180595398,
 0.1891431361436844,
 0.9945530891418457,
 2.7571553928851245e-09,
 4.554574173277958e-31,
 0.0001120438682846725,
 0.0006266694399528205,
 0.9690019488334656,
 0.3955489695072174,
 0.9948490262031555,
 6.795051632479954e-08,
 0.08603546023368835,
 0.9479504823684692,
 0.72103351354599,
 0.8501611351966858,
 1.1597768341965087e-23,
 0.9489769339561462,
 1.1671597686173302e-27,
 0.9942030310630798,
 0.9939957857131958,
 0.12910333275794983,
 0.6372085809707642,
 0.994023859500885,
 6.0868431334733936e-15,
 0.787786602973938,
 2.0406628209457267e-06,
 4.145977982261684e-06,
 0.19106823205947876,
 0.9994114637374878,
 3.260212277752816e-17,
 0.08096954226493835,
 0.9988874793052673,
 0.6735149025917053,
 0.1611880660057068,
 0.720715343952179,
 0.9873773455619812,
 0.9984242916107178,
 6.371621452672116e-07,
 5.430320774832654e-29,
 0.9999982118606567,
 0.5588401556015015,
 0.882703423500061,
 0.9987962245941162,
 0.8670592904090881,
 0.003932415507733822,
 0.9857215881347656,
 0.0,
 0.9785526394844055,
 0.8839476108551025,
 0.9486073851585388,
 1.3410406934639049e-22,
 1.5456837215997764e-16,
 0.6656900644302368,
 0.8821372985839844,
 1.5222029301340428e-30,
 0.25554582476615906,
 1.2834578863506057e-22,
 0.0,
 0.019911082461476326,
 0.10530281811952591,
 0.9341623783111572,
 0.8281369209289551,
 0.0579177625477314,
 2.984824283404567e-29,
 0.0038777512963861227,
 0.03521322086453438,
 3.895363249739603e-07,
 0.25227853655815125,
 0.999923586845398,
 1.1282206516425192e-34,
 0.9996262788772583,
 1.9160822928299126e-10,
 0.28529849648475647,
 0.9421833753585815,
 0.9928326606750488,
 0.0,
 1.1577475333979237e-06,
 0.3991885483264923,
 0.999950647354126,
 0.018166987225413322,
 2.5008671400428284e-06,
 0.7080745697021484,
 0.9137699007987976,
 0.00038303283508867025,
 0.08542847633361816,
 4.824863991952262e-14,
 0.08077766746282578,
 5.214564589550719e-05,
 1.638071717998327e-36,
 0.9985688924789429,
 0.9736044406890869,
 0.0005121970898471773,
 0.9923048615455627,
 1.6353831893654834e-33,
 0.9992594122886658,
 0.0,
 0.7934457063674927,
 0.9956035614013672,
 3.947549904377206e-33,
 2.984690539165058e-08,
 0.0027295283507555723,
 0.02836478501558304,
 0.5558547377586365,
 0.20213228464126587,
 0.004185635130852461,
 0.5536692142486572,
 0.9720465540885925,
 8.377122211820903e-21,
 0.9991573095321655,
 0.9935774207115173,
 0.9670475125312805,
 1.0493875834244065e-22,
 0.9999302625656128,
 0.5772183537483215,
 0.9978570342063904,
 1.750313792011393e-31,
 0.9143568873405457,
 0.9999953508377075,
 0.022976821288466454,
 0.3703846335411072,
 0.8553131818771362,
 0.3118666112422943,
 0.9997316002845764,
 0.0038403065409511328,
 0.9985540509223938,
 0.9652183055877686,
 0.26009517908096313,
 1.0109811654501755e-08,
 0.6234498620033264,
 2.558906864180699e-09,
 0.8661386370658875,
 4.384905602967281e-16,
 0.00013539136853069067,
 0.10577860474586487,
 0.8794561624526978,
 0.04320254549384117,
 5.8504706430539954e-06,
 0.8855071663856506,
 3.12923778783035e-22,
 1.6186676588386796e-27,
 3.720761654335547e-08,
 0.9819730520248413,
 0.677070677280426,
 0.42159536480903625,
 0.07161189615726471,
 0.4704573154449463,
 0.943310558795929,
 0.830120861530304,
 0.6075963973999023,
 0.0,
 4.058067928982159e-31,
 0.0006774079520255327,
 0.8174142241477966,
 3.146441940771183e-06,
 0.16226331889629364,
 8.399389272980074e-37,
 0.7796065807342529,
 0.972859799861908,
 0.7327096462249756,
 0.999504804611206,
 0.26227113604545593,
 0.9888494610786438,
 7.988106517586857e-06,
 0.9802026748657227,
 0.0021191765554249287,
 0.4052523970603943,
 4.784438921028077e-15,
 0.9502782821655273,
 0.0,
 0.9999130964279175,
 0.3097434937953949,
 0.576941967010498,
 0.9334170818328857,
 0.9709354043006897,
 3.423306755993095e-21,
 0.3140771687030792,
 0.0,
 0.0010068987030535936,
 0.7170798182487488,
 3.5304790912960016e-07,
 5.897283540491971e-35,
 0.9980645775794983,
 3.1833032577852893e-18,
 0.3140549659729004,
 0.9999856948852539,
 0.7047194838523865,
 0.8416957855224609,
 0.9559873938560486,
 0.9997934699058533,
 0.003031548112630844,
 0.5460397005081177,
 0.9999861717224121,
 0.8191222548484802,
 0.9797284007072449,
 0.8270653486251831,
 0.8852736353874207,
 1.0065951755677816e-05,
 0.94650799036026,
 0.001020408933982253,
 0.880054771900177,
 0.023853763937950134,
 0.868408739566803,
 4.026805798185213e-18,
 0.0021449304185807705,
 0.9832797050476074,
 0.981919527053833,
 0.9815242886543274,
 0.7569570541381836,
 1.182383082822779e-19,
 0.001184663618914783,
 0.9167937636375427,
 0.0014829194406047463,
 0.9862558245658875,
 0.0005756296450272202,
 0.0023589369375258684,
 0.3500722348690033,
 1.2731445082023635e-22,
 0.9997175335884094,
 0.9996579885482788,
 1.8678536434890702e-05,
 1.2990471166814995e-38,
 0.7385057210922241,
 0.0003625985118560493,
 0.0020198000129312277,
 0.8054090738296509,
 0.21632149815559387,
 1.4321719118441118e-30,
 0.9981979727745056,
 3.426684361262583e-17,
 0.9661266207695007,
 0.997576892375946,
 0.9992687106132507,
 0.3064834475517273,
 0.014109540730714798,
 0.9999779462814331,
 0.9890110492706299,
 0.6999473571777344,
 6.65991931261764e-29,
 0.999336302280426,
 0.9534661769866943,
 0.9946445226669312,
 0.9996304512023926,
 5.149036903375134e-37,
 0.5029244422912598,
 3.4575743484310806e-05,
 0.894957423210144,
 0.9993371367454529,
 3.539962467628759e-27,
 0.9930921792984009,
 0.804103672504425,
 0.995342493057251,
 2.490286078682402e-06,
 0.6184222102165222,
 0.05659213289618492,
 0.9751229286193848,
 0.867205023765564,
 0.044890325516462326,
 0.9973137974739075,
 0.0,
 0.22165486216545105,
 0.013079768046736717,
 0.16108736395835876,
 8.907869641916477e-08,
 0.9415039420127869,
 0.07338161766529083,
 0.9990262985229492,
 0.9307928681373596,
 0.9434071779251099,
 2.1579362510237843e-05,
 0.9939159750938416,
 0.2512285113334656,
 0.9662351608276367,
 0.04488743469119072,
 0.9998834133148193,
 0.0,
 0.7230562567710876,
 0.9287179708480835,
 1.441231811357465e-14,
 0.023209035396575928,
 0.5799762606620789,
 0.0036413269117474556,
 3.8347559893736616e-05,
 0.9921659231185913,
 0.9995904564857483,
 0.5246551632881165,
 0.9884707927703857,
 0.0011417652713134885,
 0.825749397277832,
 0.046105820685625076,
 0.41533163189888,
 0.1698428988456726,
 3.322014890727587e-05,
 5.46997384275414e-27,
 0.6169763207435608,
 0.11824487894773483,
 4.0164415111099366e-30,
 0.38148224353790283,
 0.9949501752853394,
 0.912842869758606,
 0.03421127051115036,
 0.9977001547813416,
 0.7494027614593506,
 0.9999703168869019,
 0.8182739019393921,
 2.3557529260642696e-09,
 0.9997790455818176,
 0.75749272108078,
 0.006093466188758612,
 0.8672223687171936,
 0.9846515655517578,
 0.9889950156211853,
 0.10005716979503632,
 0.9060381650924683,
 0.7834839224815369,
 1.4951007955643055e-33,
 0.152635857462883,
 3.2876877753551526e-07,
 0.9999984502792358,
 1.1649732414298342e-06,
 0.9999791383743286,
 0.8081140518188477,
 0.8821372985839844,
 0.9944339394569397,
 1.6199319361476228e-05,
 0.9153187274932861,
 0.5561962723731995,
 0.9886294603347778,
 0.12437250465154648,
 2.63453785009915e-05,
 5.513774328588491e-29,
 0.8917171359062195,
 0.07729538530111313,
 0.8904231786727905,
 3.565530134325406e-15,
 0.9948091506958008,
 0.9784489870071411,
 0.9460821747779846,
 0.9966655373573303,
 0.9990142583847046,
 0.5875645279884338,
 0.907046377658844,
 0.9998465776443481,
 1.709375510472455e-06,
 0.9836069345474243,
 0.011615133844316006,
 0.4456110894680023,
 0.999350368976593,
 0.9463541507720947,
 0.04840086027979851,
 6.740538992744405e-06,
 8.001472451724112e-05,
 1.943379629665287e-06,
 0.0,
 1.201220015900617e-06,
 6.322676603750909e-31,
 8.893400809029117e-05,
 0.9887922406196594,
 0.7466632127761841,
 0.9988046884536743,
 0.8853980302810669,
 0.8034935593605042,
 0.6559690833091736,
 0.6058226227760315,
 0.9910516738891602,
 0.7370922565460205,
 2.0200130544638827e-19,
 0.0,
 0.4470866322517395,
 0.10882315784692764,
 0.7765461802482605,
 0.3717862367630005,
 0.971238911151886,
 0.5617297887802124,
 2.4078310691056843e-24,
 8.351759674951609e-07,
 0.9315819144248962,
 0.6937048435211182,
 0.0028619503136724234,
 0.7380889654159546,
 0.48726382851600647,
 0.17919597029685974,
 0.10078345239162445,
 9.564673058584307e-29,
 0.27100640535354614,
 3.1459941965295e-05,
 7.325355599885341e-13,
 0.00026253503165207803,
 0.9597129225730896,
 0.6054542064666748,
 0.9953752756118774,
 0.5648270845413208,
 0.9602113366127014,
 0.9975390434265137,
 0.9973170161247253,
 0.003827891545370221,
 1.8098261816407017e-16,
 0.6198335289955139,
 0.9999475479125977,
 4.2431489418959245e-05,
 0.7766709923744202,
 1.6110660133481831e-27,
 0.8913708329200745,
 1.4285501492850017e-07,
 0.0,
 0.8029621839523315,
 0.7017467021942139,
 0.07981622219085693,
 0.5998949408531189,
 0.0009404151351191103,
 0.9980327486991882,
 0.8752758502960205,
 0.9236344695091248,
 3.123353498816755e-25,
 0.9999805688858032,
 0.036076877266168594,
 0.18882544338703156,
 0.7606800198554993,
 0.04344206675887108,
 2.6153322255654e-31,
 0.9702962636947632,
 0.5986231565475464,
 0.9749791026115417,
 0.6034210324287415,
 9.354359945765071e-25,
 0.2617012858390808,
 0.2215651273727417,
 1.2117721379025403e-18,
 0.9691234827041626,
 0.13636483252048492,
 0.9999909400939941,
 2.47999713441069e-21,
 0.9841530919075012,
 0.9891931414604187,
 0.9255242943763733,
 0.9857726693153381,
 0.99704509973526,
 0.5819037556648254,
 0.0016225839499384165,
 0.995672881603241,
 6.75356131978333e-05,
 3.732815457624383e-05,
 0.9915958046913147,
 2.8931444668899296e-22,
 0.7068904638290405,
 0.945927619934082,
 2.973321967054878e-15,
 0.7751308083534241,
 0.717013418674469,
 0.8348455429077148,
 0.9871243238449097,
 2.517885011599219e-09,
 0.9966208934783936,
 3.687125221380789e-24,
 0.9505706429481506,
 0.9747278094291687,
 0.999984860420227,
 0.9994864463806152,
 0.926215648651123,
 0.0,
 0.37997928261756897,
 0.9919061660766602,
 5.8319958043284714e-05,
 0.8140643239021301,
 0.9920456409454346,
 0.7080745697021484,
 6.12292148914364e-36,
 0.9946593642234802,
 0.9997327923774719,
 1.3004510391722429e-14,
 0.9985044002532959,
 2.984690539165058e-08,
 9.318662108626086e-08,
 0.8747587203979492,
 0.6983076333999634,
 0.9993816614151001,
 5.836198736860846e-31,
 0.9399607181549072,
 0.02739552967250347,
 0.9388659596443176,
 0.9740647077560425,
 0.9660518765449524,
 0.0014023033436387777,
 0.9881183505058289,
 0.8594507575035095,
 0.0028343377634882927,
 0.9885421991348267,
 0.3476802110671997,
 0.08978930115699768,
 0.0017469897866249084,
 0.0031960750930011272,
 3.3745455002633056e-27,
 0.8847326636314392,
 6.643111322215007e-17,
 0.7509255409240723,
 0.4879135489463806,
 2.471523885105853e-06,
 5.1537474064389244e-05,
 2.822654554348594e-38,
 0.9904273748397827,
 4.622986289399769e-29,
 6.345290254705688e-24,
 0.9585162997245789,
 0.7174072265625,
 0.00015241200162563473,
 1.6240259128608159e-06,
 0.9652526378631592,
 0.9982983469963074,
 6.045617010917845e-19,
 0.9784564971923828,
 0.9785888195037842,
 0.9520010948181152,
 0.6086170077323914,
 0.0,
 0.03792516887187958,
 0.6598922610282898,
 7.260825512882263e-13,
 0.9982226490974426,
 0.00034666384453885257,
 0.0004886348033323884,
 0.10106366872787476,
 8.77886039639977e-18,
 0.9144126772880554,
 0.001961778150871396,
 0.9653871655464172,
 0.9733970165252686,
 0.9922731518745422,
 0.6990535855293274,
 0.8294259309768677,
 0.9999638795852661,
 0.9936954379081726,
 0.03707238659262657,
 0.0,
 0.9904273748397827,
 0.11500763148069382,
 0.9819712042808533,
 0.0,
 0.9697737693786621,
 0.9996860027313232,
 0.5659996271133423,
 0.6416997313499451,
 0.6127889156341553,
 6.018216480718763e-10,
 7.933967552423225e-35,
 0.6648234724998474,
 0.896718442440033,
 0.0003887229831889272,
 0.9929038882255554,
 0.13801316916942596,
 0.3836430013179779,
 6.379738692885439e-07,
 0.021613990887999535,
 0.20872539281845093,
 0.807988703250885,
 6.382665132166908e-35,
 4.3356740893091706e-15,
 3.641804358078815e-10,
 0.00042686922824941576,
 9.94984456849005e-17,
 0.00015123678895179182,
 3.098825118286186e-06,
 0.039450887590646744,
 0.421340674161911,
 0.877007007598877,
 0.009903046302497387,
 0.6279163956642151,
 0.9990721940994263,
 0.2122027724981308,
 0.046542368829250336,
 0.9995357990264893,
 0.8429683446884155,
 1.1216359041554824e-07,
 0.8571765422821045,
 0.9999974966049194,
 0.8840159773826599,
 0.9999003410339355,
 0.9901665449142456,
 0.999944806098938,
 0.0,
 0.9933109283447266,
 6.668481091764988e-06,
 2.4254233361194343e-29,
 0.8422706723213196,
 2.6660706121963925e-16,
 0.0008365473477169871,
 0.010523132048547268,
 0.0080649359151721,
 0.7100608944892883,
 0.9999903440475464,
 0.008162623271346092,
 0.7613001465797424,
 1.108030863652993e-16,
 2.171421239420719e-25,
 0.9856776595115662,
 0.1986595094203949,
 0.3272842466831207,
 0.9980852603912354,
 0.8561562895774841,
 0.1407991349697113,
 5.482133844103565e-11,
 1.8785727520684e-15,
 0.9687801003456116,
 ...]

In [12]:
sns.distplot(output)
plt.show()



In [6]:
ckpt = tf.train.get_checkpoint_state(directories.checkpoints)
vDNN = vanillaDNN(config, training = False)
labels, preds, output = vDNN.predict(ckpt)


Building SELU architecture
INFO:tensorflow:Restoring parameters from checkpoints/vDNN_pi0veto_Bu2Xsy_end.ckpt-0
checkpoints/vDNN_pi0veto_Bu2Xsy_end.ckpt-0 restored.
Validation accuracy: 0.857
AUC: 0.8904646780489629
Plotting signal efficiency versus background rejection
<matplotlib.figure.Figure at 0x2af6f132ac50>
Inference complete. Duration: 42.6341 s

In [ ]: