Import packages


In [1]:
# import necessary packages
import tensorflow as tf
import numpy as np

Short recap: Tensorflow


In [2]:
# test tensorflow
a = tf.constant(2)
b = tf.constant(2)
mult = tf.multiply(a,b)
sess = tf.Session()
sess.run(mult)


Out[2]:
4

In [3]:
# initialize weights for feed-forward nn, 300x200 matrice with normal dist
# connect a layer with 300 neurons to a layer with 200 neurons
weights = tf.Variable(tf.random_normal([300, 200], stddev=0.5), name="weights")

In [4]:
shape = (200,300)
# Common tensors from the TensorFlow API docs
zeros = tf.zeros(shape, dtype=tf.float32, name=None)
ones = tf.ones(shape, dtype=tf.float32, name=None)
std_normal_dist = tf.random_normal(
    shape, mean=0.0, stddev=1.0, dtype=tf.float32, seed=None, name=None)
trunacted_normal = tf.truncated_normal(
    shape, mean=0.0, stddev=1.0, dtype=tf.float32, seed=None, name=None)
uniform_dist = tf.random_uniform(
    shape, minval=0, maxval=None, dtype=tf.float32, seed=None, name=None)

In [5]:
import tensorflow as tf

# set up linear model
x = tf.placeholder(tf.float32, shape=(1024, 1024))
W = tf.Variable(tf.random_uniform([1024, 10], -1, 1), name="W")
b = tf.Variable(tf.zeros([10]), name="biases")
y = tf.matmul(x, W) + b

# initialize variables
init_op = tf.global_variables_initializer()

# create session
with tf.Session() as sess:
    sess.run(init_op)
    rand_array = np.random.rand(1024, 1024)
    result = sess.run(y, feed_dict={x: rand_array})
    print(result)
    print(result.shape)


[[ 14.64269161  16.9288311  -13.44757557 ...,  -1.03777885  -3.17562819
  -10.16825294]
 [ 15.74568081  10.47799873  -9.99257278 ...,   1.789608    -1.20423007
   -7.99270439]
 [ 15.23731899   8.5260849  -14.4066515  ...,  -1.86329055  -3.36746407
   -8.9739666 ]
 ..., 
 [ 18.08404541  16.31268501 -14.92662525 ...,  -2.06884551   0.82842398
   -4.14905453]
 [ 25.08653641  12.29383469 -11.6435318  ..., -10.76433754  -0.40803289
   -8.98172569]
 [ 21.42256546   8.27137089 -10.31844807 ...,  -6.32147264 -11.87181282
  -11.07276917]]
(1024, 10)

In [6]:
# using tf-namespaces/variable scoping
def layer(input, weight_shape, bias_shape):
    weight_init = tf.random_uniform_initializer(minval=-1, maxval=1)
    bias_init = tf.constant_initializer(value=0)
    W = tf.get_variable("W", weight_shape, initializer=weight_init)
    b = tf.get_variable("b", bias_shape, initializer=bias_init)
    return tf.matmul(input, W) + b

def my_network(input):
    with tf.variable_scope("layer_1"):
        output_1 = layer(input, [784, 100], [100])
    with tf.variable_scope("layer_2"):
        output_2 = layer(output_1, [100, 50], [50])
    with tf.variable_scope("layer_3"):
        output_3 = layer(output_2, [50, 10], [10])
    return output_3

In [7]:
# By default, sharing is not allowed, but if we want to enable sharing within a variable scope
with tf.variable_scope("shared_variables") as scope:
    i_1 = tf.placeholder(tf.float32, [1000, 784], name="i_1")
    print(my_network(i_1))
    # enable variable sharing within variable scope
    scope.reuse_variables()
    i_2 = tf.placeholder(tf.float32, [1000, 784], name="i_2")
    print(my_network(i_2))


Tensor("shared_variables/layer_3/add:0", shape=(1000, 10), dtype=float32)
Tensor("shared_variables/layer_3_1/add:0", shape=(1000, 10), dtype=float32)

Logistic Regression Model

Use MNIST dataset with 28x28 pixel images as input. Target classes are [0,9], use softmax of size 10.

Steps to train and evaluate model:

  1. inference: produces a probability distribution over the output classes given a minibatch
  2. loss: computes the value of the error function (in this case, the cross-entropy loss)
  3. training: responsible for computing the gradients of the model’s parameters and updating the model
  4. evaluate: will determine the effectiveness of a model

In [8]:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)


Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.
Extracting MNIST_data/train-images-idx3-ubyte.gz
Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz

In [17]:
def inference(x):
    init = tf.constant_initializer(value=0)
    W = tf.get_variable("W", [784, 10], initializer=init)
    b = tf.get_variable("b", [10], initializer=init)
    output = tf.nn.softmax(tf.matmul(x, W) + b)
    return output

Formular for cross-entropy: $H(y, \hat{y}) = \sum_i y_i \log \frac{1}{\hat{y}_i} = -\sum_i y_i \log \hat{y}_i$


In [11]:
def loss(output, y):
    '''
    # tf.reduce_sum (takes `axis` arg):
    # 'x' is [[1, 1, 1]
    #         [1, 1, 1]]
    tf.reduce_sum(x) ==> 6
    tf.reduce_sum(x, 0) ==> [2, 2, 2]
    tf.reduce_sum(x, 1) ==> [3, 3]
    '''
    dot_product = y * tf.log(output)
    xentropy = -tf.reduce_sum(dot_product, axis=1)
    loss = tf.reduce_mean(xentropy)
    return loss

In [21]:
def training(cost, global_step):
    tf.summary.tensor_summary("cost", cost)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    train_op = optimizer.minimize(cost, global_step=global_step)
    return train_op

In [13]:
def evaluate(output, y):
    # compare indices of predicted class, if equal (correct classification) set 1 otherwise 0
    correct_prediction = tf.equal(tf.argmax(output, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    return accuracy

In [28]:
# Parameters
learning_rate = 0.01
training_epochs = 100
batch_size = 100
display_step = 1

In [33]:
from tqdm import tqdm

# program flow
with tf.Graph().as_default():
    # mnist data image of shape 28*28=784
    x = tf.placeholder("float", [None, 784])
    # 0-9 digits recognition => 10 classes
    y = tf.placeholder("float", [None, 10])
    output = inference(x)
    cost = loss(output, y)
    global_step = tf.Variable(0, name='global_step', trainable=False)
    train_op = training(cost, global_step)
    eval_op = evaluate(output, y)

    # tf.merge_all_summaries in order to collect all summary statistics
    # use a tf.train.SummaryWriter to write the log to disk.
    summary_op = tf.summary.merge_all()
    saver = tf.train.Saver()
    sess = tf.Session()
    # write to tensorboard graph api
    summary_writer = tf.summary.FileWriter(
        "logistic_logs/", graph=sess.graph)
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    # training cycle
    for epoch in tqdm(range(training_epochs)):
        avg_cost = 0.
        total_batch = int(mnist.train.num_examples / batch_size)

        # Loop over all batches
        for i in range(total_batch):
            mbatch_x, mbatch_y = mnist.train.next_batch(batch_size)

            # Fit training using batch data
            feed_dict = {x: mbatch_x, y: mbatch_y}
            sess.run(train_op, feed_dict=feed_dict)

            # Compute average loss
            minibatch_cost = sess.run(cost, feed_dict=feed_dict)
            avg_cost += minibatch_cost / total_batch
        # Display logs per epoch step
        if epoch % display_step == 0:
            val_feed_dict = {
                x: mnist.validation.images,
                y: mnist.validation.labels
            }
            accuracy = sess.run(eval_op, feed_dict=val_feed_dict)
            print("Validation Error in epoch %s: %.11f" % (epoch, 1 - accuracy))
            summary_str = sess.run(summary_op, feed_dict=feed_dict)
            summary_writer.add_summary(summary_str, sess.run(global_step))
            saver.save(
                sess,
                "logistic_logs/model-checkpoint",
                global_step=global_step)

    test_feed_dict = {x: mnist.test.images, y: mnist.test.labels}
    accuracy = sess.run(eval_op, feed_dict=test_feed_dict)
    print("Test Accuracy:", accuracy)


  0%|          | 0/100 [00:00<?, ?it/s]
Validation Error in epoch 0: 0.14819997549
  2%|▏         | 2/100 [00:01<01:28,  1.11it/s]
Validation Error in epoch 1: 0.13059997559
  3%|▎         | 3/100 [00:02<01:25,  1.14it/s]
Validation Error in epoch 2: 0.12139999866
  4%|▍         | 4/100 [00:03<01:22,  1.16it/s]
Validation Error in epoch 3: 0.11540001631
  5%|▌         | 5/100 [00:04<01:20,  1.18it/s]
Validation Error in epoch 4: 0.10939997435
  6%|▌         | 6/100 [00:05<01:18,  1.20it/s]
Validation Error in epoch 5: 0.10640001297
  7%|▋         | 7/100 [00:05<01:19,  1.18it/s]
Validation Error in epoch 6: 0.10500001907
  8%|▊         | 8/100 [00:06<01:17,  1.18it/s]
Validation Error in epoch 7: 0.10360002518
  9%|▉         | 9/100 [00:07<01:16,  1.19it/s]
Validation Error in epoch 8: 0.10219997168
 10%|█         | 10/100 [00:08<01:15,  1.18it/s]
Validation Error in epoch 9: 0.09899997711
 11%|█         | 11/100 [00:09<01:14,  1.20it/s]
Validation Error in epoch 10: 0.09680002928
 12%|█▏        | 12/100 [00:10<01:12,  1.21it/s]
Validation Error in epoch 11: 0.09640002251
 13%|█▎        | 13/100 [00:10<01:13,  1.19it/s]
Validation Error in epoch 12: 0.09439998865
 14%|█▍        | 14/100 [00:11<01:12,  1.19it/s]
Validation Error in epoch 13: 0.09380000830
 15%|█▌        | 15/100 [00:12<01:10,  1.20it/s]
Validation Error in epoch 14: 0.09100002050
 16%|█▌        | 16/100 [00:13<01:08,  1.23it/s]
Validation Error in epoch 15: 0.08980000019
 17%|█▋        | 17/100 [00:14<01:06,  1.25it/s]
Validation Error in epoch 16: 0.08999997377
 18%|█▊        | 18/100 [00:14<01:07,  1.22it/s]
Validation Error in epoch 17: 0.08880001307
 19%|█▉        | 19/100 [00:15<01:06,  1.22it/s]
Validation Error in epoch 18: 0.08920001984
 20%|██        | 20/100 [00:16<01:06,  1.20it/s]
Validation Error in epoch 19: 0.08819997311
 21%|██        | 21/100 [00:17<01:05,  1.21it/s]
Validation Error in epoch 20: 0.08700001240
 22%|██▏       | 22/100 [00:18<01:04,  1.22it/s]
Validation Error in epoch 21: 0.08579999208
 23%|██▎       | 23/100 [00:19<01:02,  1.24it/s]
Validation Error in epoch 22: 0.08880001307
 24%|██▍       | 24/100 [00:19<01:03,  1.19it/s]
Validation Error in epoch 23: 0.08579999208
 25%|██▌       | 25/100 [00:20<01:04,  1.16it/s]
Validation Error in epoch 24: 0.08560001850
 26%|██▌       | 26/100 [00:21<01:03,  1.17it/s]
Validation Error in epoch 25: 0.08539998531
 27%|██▋       | 27/100 [00:22<01:03,  1.16it/s]
Validation Error in epoch 26: 0.08579999208
 28%|██▊       | 28/100 [00:23<01:03,  1.13it/s]
Validation Error in epoch 27: 0.08539998531
 29%|██▉       | 29/100 [00:24<01:04,  1.10it/s]
Validation Error in epoch 28: 0.08480000496
 30%|███       | 30/100 [00:25<01:03,  1.10it/s]
Validation Error in epoch 29: 0.08380001783
 31%|███       | 31/100 [00:26<01:00,  1.14it/s]
Validation Error in epoch 30: 0.08279997110
 32%|███▏      | 32/100 [00:27<00:58,  1.16it/s]
Validation Error in epoch 31: 0.08380001783
 33%|███▎      | 33/100 [00:27<00:56,  1.18it/s]
Validation Error in epoch 32: 0.08380001783
 34%|███▍      | 34/100 [00:28<00:55,  1.19it/s]
Validation Error in epoch 33: 0.08259999752
 35%|███▌      | 35/100 [00:29<00:53,  1.21it/s]
Validation Error in epoch 34: 0.08219999075
 36%|███▌      | 36/100 [00:30<00:53,  1.19it/s]
Validation Error in epoch 35: 0.08279997110
 37%|███▋      | 37/100 [00:31<00:52,  1.20it/s]
Validation Error in epoch 36: 0.08359998465
 38%|███▊      | 38/100 [00:32<00:52,  1.18it/s]
Validation Error in epoch 37: 0.08079999685
 39%|███▉      | 39/100 [00:33<00:56,  1.07it/s]
Validation Error in epoch 38: 0.08139997721
 40%|████      | 40/100 [00:34<00:56,  1.06it/s]
Validation Error in epoch 39: 0.08099997044
 41%|████      | 41/100 [00:35<00:55,  1.07it/s]
Validation Error in epoch 40: 0.08020001650
 42%|████▏     | 42/100 [00:36<00:55,  1.05it/s]
Validation Error in epoch 41: 0.08060002327
 43%|████▎     | 43/100 [00:36<00:53,  1.06it/s]
Validation Error in epoch 42: 0.08039999008
 44%|████▍     | 44/100 [00:37<00:52,  1.07it/s]
Validation Error in epoch 43: 0.08060002327
 45%|████▌     | 45/100 [00:38<00:50,  1.09it/s]
Validation Error in epoch 44: 0.07980000973
 46%|████▌     | 46/100 [00:39<00:48,  1.11it/s]
Validation Error in epoch 45: 0.07940000296
 47%|████▋     | 47/100 [00:40<00:46,  1.13it/s]
Validation Error in epoch 46: 0.07920002937
 48%|████▊     | 48/100 [00:41<00:46,  1.12it/s]
Validation Error in epoch 47: 0.07899999619
 49%|████▉     | 49/100 [00:42<00:44,  1.15it/s]
Validation Error in epoch 48: 0.07899999619
 50%|█████     | 50/100 [00:43<00:43,  1.15it/s]
Validation Error in epoch 49: 0.07800000906
 51%|█████     | 51/100 [00:43<00:41,  1.17it/s]
Validation Error in epoch 50: 0.07800000906
 52%|█████▏    | 52/100 [00:44<00:41,  1.16it/s]
Validation Error in epoch 51: 0.07760000229
 53%|█████▎    | 53/100 [00:45<00:41,  1.13it/s]
Validation Error in epoch 52: 0.07800000906
 54%|█████▍    | 54/100 [00:46<00:42,  1.09it/s]
Validation Error in epoch 53: 0.07800000906
 55%|█████▌    | 55/100 [00:47<00:41,  1.09it/s]
Validation Error in epoch 54: 0.07819998264
 56%|█████▌    | 56/100 [00:48<00:41,  1.07it/s]
Validation Error in epoch 55: 0.07779997587
 57%|█████▋    | 57/100 [00:49<00:40,  1.05it/s]
Validation Error in epoch 56: 0.07740002871
 58%|█████▊    | 58/100 [00:50<00:40,  1.05it/s]
Validation Error in epoch 57: 0.07740002871
 59%|█████▉    | 59/100 [00:51<00:38,  1.05it/s]
Validation Error in epoch 58: 0.07840001583
 60%|██████    | 60/100 [00:52<00:38,  1.03it/s]
Validation Error in epoch 59: 0.07679998875
 61%|██████    | 61/100 [00:53<00:38,  1.02it/s]
Validation Error in epoch 60: 0.07719999552
 62%|██████▏   | 62/100 [00:54<00:36,  1.05it/s]
Validation Error in epoch 61: 0.07660001516
 63%|██████▎   | 63/100 [00:55<00:34,  1.09it/s]
Validation Error in epoch 62: 0.07639998198
 64%|██████▍   | 64/100 [00:56<00:31,  1.13it/s]
Validation Error in epoch 63: 0.07679998875
 65%|██████▌   | 65/100 [00:56<00:30,  1.13it/s]
Validation Error in epoch 64: 0.07639998198
 66%|██████▌   | 66/100 [00:57<00:29,  1.15it/s]
Validation Error in epoch 65: 0.07639998198
 67%|██████▋   | 67/100 [00:58<00:28,  1.15it/s]
Validation Error in epoch 66: 0.07700002193
 68%|██████▊   | 68/100 [00:59<00:26,  1.19it/s]
Validation Error in epoch 67: 0.07599997520
 69%|██████▉   | 69/100 [01:00<00:25,  1.22it/s]
Validation Error in epoch 68: 0.07620000839
 70%|███████   | 70/100 [01:01<00:24,  1.21it/s]
Validation Error in epoch 69: 0.07580000162
 71%|███████   | 71/100 [01:01<00:25,  1.16it/s]
Validation Error in epoch 70: 0.07539999485
 72%|███████▏  | 72/100 [01:02<00:24,  1.16it/s]
Validation Error in epoch 71: 0.07580000162
 73%|███████▎  | 73/100 [01:03<00:23,  1.14it/s]
Validation Error in epoch 72: 0.07620000839
 74%|███████▍  | 74/100 [01:04<00:22,  1.16it/s]
Validation Error in epoch 73: 0.07560002804
 75%|███████▌  | 75/100 [01:05<00:21,  1.17it/s]
Validation Error in epoch 74: 0.07599997520
 76%|███████▌  | 76/100 [01:06<00:20,  1.17it/s]
Validation Error in epoch 75: 0.07580000162
 77%|███████▋  | 77/100 [01:07<00:19,  1.16it/s]
Validation Error in epoch 76: 0.07520002127
 78%|███████▊  | 78/100 [01:07<00:18,  1.17it/s]
Validation Error in epoch 77: 0.07580000162
 79%|███████▉  | 79/100 [01:08<00:17,  1.19it/s]
Validation Error in epoch 78: 0.07499998808
 80%|████████  | 80/100 [01:09<00:16,  1.19it/s]
Validation Error in epoch 79: 0.07560002804
 81%|████████  | 81/100 [01:10<00:16,  1.19it/s]
Validation Error in epoch 80: 0.07560002804
 82%|████████▏ | 82/100 [01:11<00:15,  1.19it/s]
Validation Error in epoch 81: 0.07440000772
Validation Error in epoch 82: 0.07480001450
 84%|████████▍ | 84/100 [01:13<00:14,  1.09it/s]
Validation Error in epoch 83: 0.07539999485
 85%|████████▌ | 85/100 [01:14<00:13,  1.11it/s]
Validation Error in epoch 84: 0.07499998808
 86%|████████▌ | 86/100 [01:14<00:12,  1.14it/s]
Validation Error in epoch 85: 0.07520002127
 87%|████████▋ | 87/100 [01:16<00:12,  1.02it/s]
Validation Error in epoch 86: 0.07520002127
 88%|████████▊ | 88/100 [01:17<00:12,  1.01s/it]
Validation Error in epoch 87: 0.07480001450
 89%|████████▉ | 89/100 [01:18<00:11,  1.00s/it]
Validation Error in epoch 88: 0.07499998808
 90%|█████████ | 90/100 [01:19<00:09,  1.04it/s]
Validation Error in epoch 89: 0.07539999485
 91%|█████████ | 91/100 [01:20<00:08,  1.02it/s]
Validation Error in epoch 90: 0.07560002804
 92%|█████████▏| 92/100 [01:21<00:08,  1.03s/it]
Validation Error in epoch 91: 0.07520002127
 93%|█████████▎| 93/100 [01:22<00:07,  1.11s/it]
Validation Error in epoch 92: 0.07520002127
 94%|█████████▍| 94/100 [01:23<00:06,  1.17s/it]
Validation Error in epoch 93: 0.07459998131
Validation Error in epoch 94: 0.07480001450
 96%|█████████▌| 96/100 [01:26<00:05,  1.26s/it]
Validation Error in epoch 95: 0.07580000162
 97%|█████████▋| 97/100 [01:27<00:03,  1.27s/it]
Validation Error in epoch 96: 0.07480001450
 98%|█████████▊| 98/100 [01:29<00:02,  1.28s/it]
Validation Error in epoch 97: 0.07480001450
 99%|█████████▉| 99/100 [01:30<00:01,  1.27s/it]
Validation Error in epoch 98: 0.07520002127
100%|██████████| 100/100 [01:31<00:00,  1.15s/it]
Validation Error in epoch 99: 0.07560002804
Test Accuracy: 0.922