In [57]:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
import tensorflow as tf
import numpy as np

# training parameters
training_epochs = 30
batch_size = 128

# architecture parameters
n_labels = 10
image_pixels = 28 * 28
bottleneck = 128


Extracting /tmp/data/train-images-idx3-ubyte.gz
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz
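
The tutorial loader splits MNIST into 55,000 training, 5,000 validation, and 10,000 test images; train_and_test below folds the validation images back into the training stream. A quick check, using the mnist object loaded above:

    print(mnist.train.num_examples, mnist.validation.num_examples, mnist.test.num_examples)
    # 55000 5000 10000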

In [107]:
def train_and_test(mode="c_is_softmax_prob", seed=100, learning_rate=0.001):
    '''
    modes:
      c_is_softmax_prob         -- confidence is the max softmax probability; no extra training
      c_is_trained_softmax_prob -- max softmax probability, with the caution loss added during training
      c_is_cotrained_sigmoid    -- sigmoid confidence head trained jointly with the classifier
      c_is_auxiliary_sigmoid    -- sigmoid confidence head trained after, and apart from, the classifier
    '''
    
    graph = tf.Graph()
    with graph.as_default():
        tf.set_random_seed(seed)  # graph-level seed must be set during graph construction; runs still vary in practice

        x = tf.placeholder(dtype=tf.float32, shape=[None, image_pixels])
        y = tf.placeholder(dtype=tf.float32, shape=[None, n_labels])

        def gelu(x):
            # Gaussian Error Linear Unit, GELU(x) = x * Phi(x), via its tanh approximation
            return 0.5 * x * (1 + tf.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3))))
        f = gelu

        W = {}
        b = {}

        with tf.variable_scope("classifier"):
            W['1'] = tf.Variable(tf.nn.l2_normalize(tf.random_normal([image_pixels, 256]), 0))
            W['2'] = tf.Variable(tf.nn.l2_normalize(tf.random_normal([256, 256]), 0))
            W['3'] = tf.Variable(tf.nn.l2_normalize(tf.random_normal([256, 256]), 0))
            W['logits'] = tf.Variable(tf.nn.l2_normalize(tf.random_normal([256, n_labels]), 0))

            b['1'] = tf.Variable(tf.zeros([256]))
            b['2'] = tf.Variable(tf.zeros([256]))
            b['3'] = tf.Variable(tf.zeros([256]))
            b['logits'] = tf.Variable(tf.zeros([n_labels]))

        with tf.variable_scope("confidence_scorer"):
            W['hidden_to_conf1'] = tf.Variable(tf.nn.l2_normalize(tf.random_normal([256, 512]), 0))
            W['logits_to_conf1'] = tf.Variable(tf.nn.l2_normalize(tf.random_normal([n_labels, 512]), 0))
            W['conf2'] = tf.Variable(tf.nn.l2_normalize(tf.random_normal([512, 128]), 0))
            W['conf'] = tf.Variable(tf.nn.l2_normalize(tf.random_normal([128, 1]), 0))

            b['conf1'] = tf.Variable(tf.zeros([512]))
            b['conf2'] = tf.Variable(tf.zeros([128]))
            b['conf'] = tf.Variable(tf.zeros([1]))

        # forward pass: a three-hidden-layer classifier trunk plus a confidence head
        # that reads the second hidden layer h2 together with the class logits
        def cautious_fcn(x):
            h1 = f(tf.matmul(x, W['1']) + b['1'])
            h2 = f(tf.matmul(h1, W['2']) + b['2'])
            h3 = f(tf.matmul(h2, W['3']) + b['3'])
            logits_out = tf.matmul(h3, W['logits']) + b['logits']

            conf1 = f(tf.matmul(logits_out, W['logits_to_conf1']) +
                        tf.matmul(h2, W['hidden_to_conf1']) + b['conf1'])
            conf2 = f(tf.matmul(conf1, W['conf2']) + b['conf2'])
            conf_out = tf.matmul(conf2, W['conf']) + b['conf']

            return logits_out, tf.squeeze(conf_out)

        logits, confidence_logit = cautious_fcn(x)

        # 0/1 indicator of whether the classifier's prediction matches the label
        right_answer = tf.stop_gradient(tf.to_float(tf.equal(tf.argmax(logits, 1), tf.argmax(y, 1))))
        compute_error = 100 * tf.reduce_mean(1 - right_answer)

        classification_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))
        if "softmax" in mode:
            confidence_logit = tf.reduce_max(tf.nn.softmax(logits), reduction_indices=[1])
            caution_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(confidence_logit, right_answer))
            
            # cc_loss is cautious classification loss
            if mode == "c_is_trained_softmax_prob":
                cc_loss = classification_loss + caution_loss
            else:
                cc_loss = classification_loss
        
        elif mode == "c_is_cotrained_sigmoid":
            caution_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(confidence_logit, right_answer))
            cc_loss = classification_loss + caution_loss
            confidence = tf.sigmoid(confidence_logit)
        elif mode == "c_is_auxiliary_sigmoid":
            caution_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(confidence_logit, right_answer))
            cc_loss = classification_loss  # we use caution_loss after training normal classifier
        else:
            assert False, "Invalid mode specified"
        
        # calibration score in [-1, 1]: +1 when the confidence exactly tracks correctness
        cc_calibration_score = tf.reduce_mean((2 * right_answer - 1) * (2 * tf.sigmoid(confidence_logit) - 1))
        cc_model_score = tf.reduce_mean(right_answer * ((2 * right_answer - 1) * (2 * tf.sigmoid(confidence_logit) - 1) + 1) / 2)

        # cautious classification perplexities (exp of the average cross-entropies)
        cc_calibration_perplexity = tf.exp(caution_loss)
        cc_model_perplexity = tf.exp(caution_loss + classification_loss)
        
        # lr is built as a constant but overridden through feed_dict at run time
        lr = tf.constant(learning_rate)
        optimizer = tf.train.AdamOptimizer(learning_rate=lr).minimize(cc_loss)

    sess = tf.InteractiveSession(graph=graph)
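    # InteractiveSession makes graph the default graph, so the tf.get_collection
    # calls below resolve against this graph rather than a fresh default one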
    
    if "softmax" in mode:
        sess.run(tf.initialize_all_variables())
    
    elif mode == "c_is_cotrained_sigmoid":
        sess.run(tf.initialize_all_variables())
    
    elif mode == "c_is_auxiliary_sigmoid":
        thawed_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "classifier")
        frozen_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "confidence_scorer")
        sess.run(tf.initialize_variables(set(tf.all_variables()) - set(frozen_vars)))
    
    # exponential moving averages of the training metrics (rough initial values)
    err_ema = 90
    cc_calibration_perp_ema = 10
    cc_model_perp_ema = 10
    cc_calibration_score_ema = -1
    cc_model_score_ema = -1
    num_batches = (mnist.train.num_examples + mnist.validation.num_examples) // batch_size
    
    for epoch in range(1, training_epochs + 1):
        if epoch >= 20:
            learning_rate *= 0.1  # compounds every epoch past 20, so the step size rapidly shrinks toward zero
        for i in range(num_batches):
            if i < mnist.train.num_examples // batch_size:
                bx, by = mnist.train.next_batch(batch_size)
            else:  # nothing is tuned on the validation set, so fold it into training
                bx, by = mnist.validation.next_batch(batch_size)
                
            if mode != "c_is_auxiliary_sigmoid":
                _, err, cc_model_score_curr, cc_calibration_score_curr,\
                cc_model_perp_curr, cc_calibration_perp_curr = sess.run([
                        optimizer, compute_error, cc_model_score, cc_calibration_score,
                        cc_model_perplexity, cc_calibration_perplexity],
                     feed_dict={x: bx, y: by, lr: learning_rate})
                
                err_ema = err_ema * 0.95 + 0.05 * err
                cc_calibration_perp_ema = cc_calibration_perp_ema * 0.95 + 0.05 * cc_calibration_perp_curr
                cc_model_perp_ema = cc_model_perp_ema * 0.95 + 0.05 * cc_model_perp_curr
                cc_calibration_score_ema = cc_calibration_score_ema * 0.95 + 0.05 * cc_calibration_score_curr
                cc_model_score_ema = cc_model_score_ema * 0.95 + 0.05 * cc_model_score_curr
            else:
                _, err, l = sess.run([optimizer, compute_error, cc_loss],
                                     feed_dict={x: bx, y: by, lr: learning_rate})
                err_ema = err_ema * 0.95 + 0.05 * err
        
        if epoch % 10 == 0:
            print('Epoch', epoch, ' | ', 'Current Classification Error (%)', err_ema)
            if mode != "c_is_auxiliary_sigmoid":
                print('Epoch', epoch, ' | ', 'Cautious Classification Calibration Perp', cc_calibration_perp_ema)
                print('Epoch', epoch, ' | ', 'Cautious Classification Model Perp', cc_model_perp_ema)
                print('Epoch', epoch, ' | ', 'Cautious Classification Calibration Score', cc_calibration_score_ema)
                print('Epoch', epoch, ' | ', 'Cautious Classification Model Score', cc_model_score_ema)

    if mode == "c_is_auxiliary_sigmoid":
        # train sigmoid separately from the classifier
        phase2_vars = list(set(tf.all_variables()) - set(thawed_vars))
        optimizer2 = tf.train.AdamOptimizer(learning_rate=0.001).minimize(caution_loss, var_list=phase2_vars)
        sess.run(tf.initialize_variables(set(tf.all_variables()) - set(thawed_vars)))
        
        for epoch in range(5):
            for i in range(num_batches):
                if i < mnist.train.num_examples // batch_size:
                    bx, by = mnist.train.next_batch(batch_size)
                else:  # fold the validation set into training again
                    bx, by = mnist.validation.next_batch(batch_size)

                sess.run([optimizer2], feed_dict={x: bx, y: by})

    err, cc_model_score_test, cc_calibration_score_test,\
    cc_model_perp_test, cc_calibration_perp_test = sess.run([
                    compute_error, cc_model_score, cc_calibration_score,
                    cc_model_perplexity, cc_calibration_perplexity],
                                  feed_dict={x: mnist.test.images, y: mnist.test.labels})

    print('Test Classification Error (%)', err)
    print('Test Cautious Classification Calibration Perp', cc_calibration_perp_test)
    print('Test Cautious Classification Model Perp', cc_model_perp_test)
    print('Test Cautious Classification Calibration Score', cc_calibration_score_test)
    print('Test Cautious Classification Model Score', cc_model_score_test)

    sess.close()
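
The trained confidence head supports selective classification: predict only when the sigmoid of the confidence logit clears a threshold, and abstain otherwise. A minimal sketch, assuming a live session sess and the tensors x, logits, and confidence_logit from the graph above (tau is a hypothetical abstention threshold, e.g. tuned on held-out data):

    def predict_or_abstain(sess, bx, tau=0.5):
        # returns predicted labels, with -1 marking abstentions
        lg, cl = sess.run([logits, confidence_logit], feed_dict={x: bx})
        conf = 1.0 / (1.0 + np.exp(-cl))  # sigmoid of the confidence logit
        preds = np.argmax(lg, axis=1)
        preds[conf < tau] = -1            # abstain on low-confidence inputs
        return preds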

In [112]:
train_and_test()
train_and_test()
train_and_test()


Epoch 10  |  Current Classification Error (%) 0.238838601712
Epoch 10  |  Cautious Classification Calibration Perp 1.3716098518
Epoch 10  |  Cautious Classification Model Perp 1.38537145556
Epoch 10  |  Cautious Classification Calibration Score 0.459150262912
Epoch 10  |  Cautious Classification Model Score 0.728819428557
Epoch 20  |  Current Classification Error (%) 4.74314965953e-05
Epoch 20  |  Cautious Classification Calibration Perp 1.36802734231
Epoch 20  |  Cautious Classification Model Perp 1.36858499174
Epoch 20  |  Cautious Classification Calibration Score 0.461961028479
Epoch 20  |  Cautious Classification Model Score 0.730980158868
Epoch 30  |  Current Classification Error (%) 5.34356365369e-07
Epoch 30  |  Cautious Classification Calibration Perp 1.3679960514
Epoch 30  |  Cautious Classification Model Perp 1.36843520308
Epoch 30  |  Cautious Classification Calibration Score 0.461993655648
Epoch 30  |  Cautious Classification Model Score 0.730996612043
Test Classification Error (%) 1.53
Test Cautious Classification Calibration Perp 1.38784
Test Cautious Classification Model Perp 1.49883
Test Cautious Classification Calibration Score 0.447779
Test Cautious Classification Model Score 0.719284
Epoch 10  |  Current Classification Error (%) 0.416679292296
Epoch 10  |  Cautious Classification Calibration Perp 1.37385376202
Epoch 10  |  Cautious Classification Model Perp 1.39062503716
Epoch 10  |  Cautious Classification Calibration Score 0.457356303276
Epoch 10  |  Cautious Classification Model Score 0.727309472492
Epoch 20  |  Current Classification Error (%) 0.000780590460416
Epoch 20  |  Cautious Classification Calibration Perp 1.36797648831
Epoch 20  |  Cautious Classification Model Perp 1.36832275065
Epoch 20  |  Cautious Classification Calibration Score 0.462016639195
Epoch 20  |  Cautious Classification Model Score 0.73100533955
Epoch 30  |  Current Classification Error (%) 0.000281799938929
Epoch 30  |  Cautious Classification Calibration Perp 1.36795983827
Epoch 30  |  Cautious Classification Model Perp 1.36825805159
Epoch 30  |  Cautious Classification Calibration Score 0.462033336176
Epoch 30  |  Cautious Classification Model Score 0.73101562866
Test Classification Error (%) 1.66
Test Cautious Classification Calibration Perp 1.3891
Test Cautious Classification Model Perp 1.4917
Test Cautious Classification Calibration Score 0.4469
Test Cautious Classification Model Score 0.718406
Epoch 10  |  Current Classification Error (%) 0.622645483511
Epoch 10  |  Cautious Classification Calibration Perp 1.37618056136
Epoch 10  |  Cautious Classification Model Perp 1.40128698743
Epoch 10  |  Cautious Classification Calibration Score 0.455695356711
Epoch 10  |  Cautious Classification Model Score 0.725811888692
Epoch 20  |  Current Classification Error (%) 6.81468317733e-05
Epoch 20  |  Cautious Classification Calibration Perp 1.36794962297
Epoch 20  |  Cautious Classification Model Perp 1.36821193983
Epoch 20  |  Cautious Classification Calibration Score 0.462042976158
Epoch 20  |  Cautious Classification Model Score 0.731021042065
Epoch 30  |  Current Classification Error (%) 0.00476900581373
Epoch 30  |  Cautious Classification Calibration Perp 1.3680108517
Epoch 30  |  Cautious Classification Model Perp 1.36839370513
Epoch 30  |  Cautious Classification Calibration Score 0.461995886019
Epoch 30  |  Cautious Classification Model Score 0.73098246707
Test Classification Error (%) 1.5
Test Cautious Classification Calibration Perp 1.3878
Test Cautious Classification Model Perp 1.50388
Test Cautious Classification Calibration Score 0.447997
Test Cautious Classification Model Score 0.7196

In [113]:
train_and_test("c_is_cotrained_sigmoid")
train_and_test("c_is_cotrained_sigmoid")
train_and_test("c_is_cotrained_sigmoid")


Epoch 10  |  Current Classification Error (%) 0.492964803515
Epoch 10  |  Cautious Classification Calibration Perp 1.01624450795
Epoch 10  |  Cautious Classification Model Perp 1.03024725482
Epoch 10  |  Cautious Classification Calibration Score 0.985736966119
Epoch 10  |  Cautious Classification Model Score 0.991994368456
Epoch 20  |  Current Classification Error (%) 0.000387889086603
Epoch 20  |  Cautious Classification Calibration Perp 1.00065019317
Epoch 20  |  Cautious Classification Model Perp 1.00108927638
Epoch 20  |  Cautious Classification Calibration Score 0.998852761851
Epoch 20  |  Cautious Classification Model Score 0.999425916576
Epoch 30  |  Current Classification Error (%) 5.99268018009e-08
Epoch 30  |  Cautious Classification Calibration Perp 1.00033587642
Epoch 30  |  Cautious Classification Model Perp 1.00057627473
Epoch 30  |  Cautious Classification Calibration Score 0.999371581918
Epoch 30  |  Cautious Classification Model Score 0.999685801202
Test Classification Error (%) 1.51
Test Cautious Classification Calibration Perp 1.07178
Test Cautious Classification Model Perp 1.15772
Test Cautious Classification Calibration Score 0.968612
Test Cautious Classification Model Score 0.983022
Epoch 10  |  Current Classification Error (%) 0.385521686141
Epoch 10  |  Cautious Classification Calibration Perp 1.01345305793
Epoch 10  |  Cautious Classification Model Perp 1.02873438673
Epoch 10  |  Cautious Classification Calibration Score 0.987455548967
Epoch 10  |  Cautious Classification Model Score 0.993077165628
Epoch 20  |  Current Classification Error (%) 8.81764889558e-06
Epoch 20  |  Cautious Classification Calibration Perp 1.00027256867
Epoch 20  |  Cautious Classification Model Perp 1.00052996979
Epoch 20  |  Cautious Classification Calibration Score 0.999480506572
Epoch 20  |  Cautious Classification Model Score 0.999740258354
Epoch 30  |  Current Classification Error (%) 1.86084553124e-14
Epoch 30  |  Cautious Classification Calibration Perp 1.00025121018
Epoch 30  |  Cautious Classification Model Perp 1.00046695358
Epoch 30  |  Cautious Classification Calibration Score 0.999518018019
Epoch 30  |  Cautious Classification Model Score 0.99975901285
Test Classification Error (%) 1.5
Test Cautious Classification Calibration Perp 1.07125
Test Cautious Classification Model Perp 1.15947
Test Cautious Classification Calibration Score 0.969402
Test Cautious Classification Model Score 0.983227
Epoch 10  |  Current Classification Error (%) 0.622331476209
Epoch 10  |  Cautious Classification Calibration Perp 1.01745984271
Epoch 10  |  Cautious Classification Model Perp 1.03617324213
Epoch 10  |  Cautious Classification Calibration Score 0.981872093257
Epoch 10  |  Cautious Classification Model Score 0.989661248126
Epoch 20  |  Current Classification Error (%) 8.3148174879e-05
Epoch 20  |  Cautious Classification Calibration Perp 1.00049382802
Epoch 20  |  Cautious Classification Model Perp 1.00091350948
Epoch 20  |  Cautious Classification Calibration Score 0.999129851366
Epoch 20  |  Cautious Classification Model Score 0.999564708541
Epoch 30  |  Current Classification Error (%) 1.18056436581e-05
Epoch 30  |  Cautious Classification Calibration Perp 1.00029681831
Epoch 30  |  Cautious Classification Model Perp 1.00055416415
Epoch 30  |  Cautious Classification Calibration Score 0.999438835167
Epoch 30  |  Cautious Classification Model Score 0.999719374213
Test Classification Error (%) 1.47
Test Cautious Classification Calibration Perp 1.06906
Test Cautious Classification Model Perp 1.15397
Test Cautious Classification Calibration Score 0.968602
Test Cautious Classification Model Score 0.982903

In [110]:
train_and_test("c_is_auxiliary_sigmoid")
# train_and_test("c_is_auxiliary_sigmoid")
# train_and_test("c_is_auxiliary_sigmoid")

# perhaps we need held-out validation data or an auxiliary autoencoder


Epoch 10  |  Current Classification Error (%) 0.473885899655
Epoch 20  |  Current Classification Error (%) 0.00904122401227
Epoch 30  |  Current Classification Error (%) 6.72494346557e-08
Test Classification Error (%) 1.6
Test Cautious Classification Calibration Perp 1.88311
Test Cautious Classification Model Perp 2.04299
Test Cautious Classification Calibration Score 0.967986
Test Cautious Classification Model Score 0.983993
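
Acting on the note above: with the classifier reaching near-zero training error by epoch 30, the auxiliary confidence head almost never sees a mistake during phase 2, which may explain its poor calibration perplexity here. A hedged sketch of one fix, assuming phase 1 is changed to skip the validation batches so they remain held out:

    # hypothetical phase-2 variant: fit the confidence head on held-out data,
    # where the frozen classifier still makes errors
    for epoch in range(5):
        for i in range(mnist.validation.num_examples // batch_size):
            bx, by = mnist.validation.next_batch(batch_size)
            sess.run([optimizer2], feed_dict={x: bx, y: by})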