In [1]:
#
# MNIST classifier: a three-layer MLP (1024 ReLU units per layer) with batch
# normalization, trained with plain SGD. Requires TensorFlow 1.x: tf.contrib
# and the tutorial MNIST reader were removed in TensorFlow 2.
#
%matplotlib inline
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
import os
import time
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow.contrib.layers as layers
def plot(samples, w, h, fw, fh, iw=28, ih=28):
    """Draw `samples` as a w-by-h grid of iw-by-ih grayscale images."""
    fig = plt.figure(figsize=(fw, fh))
    gs = gridspec.GridSpec(w, h)
    gs.update(wspace=0.05, hspace=0.05)
    for i, sample in enumerate(samples):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(iw, ih), cmap='Greys_r')
    return fig
DATA_PATH = "../DATASETS/"
mnist = input_data.read_data_sets(DATA_PATH + "MNIST_TF/", one_hot=True)
X_TOTAL = mnist.train.images.shape[0]
X_DIM = mnist.train.images.shape[1]
Y_DIM = mnist.train.labels.shape[1]
print("# samples {}".format(X_TOTAL))
print("Input's dimension {}".format(X_DIM))
print("Label's dimension {}".format(Y_DIM))
# Input placeholders (a first dimension of None allows a variable batch size).
images_placeholder = tf.placeholder(tf.float32, shape=(None, X_DIM))
# The one-hot labels must be float32: softmax_cross_entropy_with_logits
# requires labels and logits to share the same dtype.
labels_placeholder = tf.placeholder(tf.float32, shape=(None, Y_DIM))
learning_rate_placeholder = tf.placeholder(tf.float32)
is_training_placeholder = tf.placeholder(tf.bool)
# Model: three batch-normalized 1024-unit ReLU layers plus a linear 10-way
# output layer. "updates_collections": None makes batch_norm update its
# moving mean/variance in place during training.
def model_batch_norm(images, is_training=True):
    h1 = layers.fully_connected(
        inputs=images,
        num_outputs=1024,
        activation_fn=tf.nn.relu,
        normalizer_fn=layers.batch_norm,
        normalizer_params={"is_training": is_training, "updates_collections": None},
        scope='d_0'
    )
    h2 = layers.fully_connected(
        inputs=h1,
        num_outputs=1024,
        activation_fn=tf.nn.relu,
        normalizer_fn=layers.batch_norm,
        normalizer_params={"is_training": is_training, "updates_collections": None},
        scope='d_1'
    )
    h3 = layers.fully_connected(
        inputs=h2,
        num_outputs=1024,
        activation_fn=tf.nn.relu,
        normalizer_fn=layers.batch_norm,
        normalizer_params={"is_training": is_training, "updates_collections": None},
        scope='d_2'
    )
    logits = layers.fully_connected(
        inputs=h3,
        num_outputs=10,
        activation_fn=None,  # raw logits; the softmax is applied by the loss
        normalizer_fn=None,  # no batch norm on the output layer
        scope='d_3'
    )
    return logits
output_logits = model_batch_norm(images_placeholder, is_training_placeholder)
cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=labels_placeholder, logits=output_logits))
train_step = tf.train.GradientDescentOptimizer(learning_rate_placeholder).minimize(cross_entropy)
# Accuracy: fraction of samples whose predicted class matches the true one.
y_pred = tf.argmax(input=output_logits, axis=1)
y_true = tf.argmax(input=labels_placeholder, axis=1)
correct_prediction = tf.equal(y_pred, y_true)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
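# Note on batch norm: "updates_collections": None (used above) folds the
# moving-average updates into each training step. The common alternative
# (a sketch only, not wired into this notebook) leaves the updates in
# tf.GraphKeys.UPDATE_OPS and attaches them to the train op:
#
#   update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
#   with tf.control_dependencies(update_ops):
#       train_step = tf.train.GradientDescentOptimizer(
#           learning_rate_placeholder).minimize(cross_entropy)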
In [2]:
# Parameters
LEARNING_RATE = 0.5
BATCH_SIZE = 64
EPOCHS = 10
TOTAL_BATCHES = X_TOTAL // BATCH_SIZE
X_TOTAL_VALID = mnist.validation.images.shape[0]
BATCH_SIZE_VAL = 100
TOTAL_BATCHES_VALIDATION = X_TOTAL_VALID // BATCH_SIZE_VAL
# Initializing the variables
init = tf.global_variables_initializer()
# Model management
MODELS_PATH = "MODELS/"
if not os.path.exists(MODELS_PATH):
    os.makedirs(MODELS_PATH)
MODEL_NAME = "MLP_1024_RELU_x3_10_SOFT_BATCH_NORM.ckpt"
# 'Saver' op to save and restore all the variables
MLP_SAVER = tf.train.Saver()
# A Session opened in a "with" block closes automatically when the block exits.
with tf.Session() as sess:
    sess.run(init)
    training_acc = []
    validation_acc = []
    for epoch in range(EPOCHS):
        # Visit the minibatches in a different random order each epoch.
        batch_indexes = np.random.permutation(TOTAL_BATCHES)
        training_total_acc = 0
        start_time = time.time()
        for minibatch_number, batch_index in enumerate(batch_indexes):
            X_minibatch = mnist.train.images[batch_index*BATCH_SIZE:(batch_index+1)*BATCH_SIZE]
            Y_minibatch = mnist.train.labels[batch_index*BATCH_SIZE:(batch_index+1)*BATCH_SIZE]
            _, minibatch_acc = sess.run([train_step, accuracy],
                                        feed_dict={
                                            images_placeholder: X_minibatch,
                                            labels_placeholder: Y_minibatch,
                                            learning_rate_placeholder: LEARNING_RATE,
                                            is_training_placeholder: True
                                        })
            training_total_acc += minibatch_acc
            if minibatch_number % 500 == 0:
                print("MB INDEX {}".format(minibatch_number))
        print("E {} | TRAINING ACC: {:.4f} | TIME {:.2f} secs".format(
            epoch, training_total_acc/TOTAL_BATCHES, time.time() - start_time))
        training_acc.append(training_total_acc/TOTAL_BATCHES)
        total_minibatch_acc_val = 0
        start_time_val = time.time()
        for minibatch_number_validation in range(TOTAL_BATCHES_VALIDATION):
            X_minibatch = mnist.validation.images[minibatch_number_validation*BATCH_SIZE_VAL:(minibatch_number_validation+1)*BATCH_SIZE_VAL]
            Y_minibatch = mnist.validation.labels[minibatch_number_validation*BATCH_SIZE_VAL:(minibatch_number_validation+1)*BATCH_SIZE_VAL]
            # No train_step runs here, so the learning rate need not be fed;
            # is_training=False makes batch_norm use its moving statistics.
            minibatch_acc_val = sess.run(accuracy,
                                         feed_dict={
                                             images_placeholder: X_minibatch,
                                             labels_placeholder: Y_minibatch,
                                             is_training_placeholder: False
                                         })
            total_minibatch_acc_val += minibatch_acc_val
        print("E {} | VALIDATION ACC: {:.4f} | TIME {:.2f} secs".format(
            epoch, total_minibatch_acc_val/TOTAL_BATCHES_VALIDATION, time.time() - start_time_val))
        validation_acc.append(total_minibatch_acc_val/TOTAL_BATCHES_VALIDATION)
    print("Optimization Finished!")
    # Save the model weights to disk.
    save_path = MLP_SAVER.save(sess, MODELS_PATH + MODEL_NAME)
    print("Model saved in file: {}".format(save_path))
In [3]:
plt.plot(training_acc, label="Training")
plt.plot(validation_acc, label="Validation")
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.legend()
plt.show()
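# Optional follow-up (a sketch): report the best validation epoch from the
# curves recorded above.
best_epoch = int(np.argmax(validation_acc))
print("Best validation accuracy {:.4f} at epoch {}".format(
    validation_acc[best_epoch], best_epoch))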
In [4]:
X_TEST_TOTAL = mnist.test.images.shape[0]
TEST_BATCH_SIZE = 100
TEST_TOTAL_BATCHES = X_TEST_TOTAL // TEST_BATCH_SIZE
with tf.Session() as sess:
    # Restoring overwrites every variable; the init just guarantees the graph
    # starts from a fully initialized state before the checkpoint is loaded.
    sess.run(init)
    MLP_SAVER.restore(sess, save_path)
    print("Model restored from file: {}".format(save_path))
    test_total_acc = 0
    start_time = time.time()
    for test_minibatch_number in range(TEST_TOTAL_BATCHES):
        X_minibatch = mnist.test.images[test_minibatch_number*TEST_BATCH_SIZE:(test_minibatch_number+1)*TEST_BATCH_SIZE]
        Y_minibatch = mnist.test.labels[test_minibatch_number*TEST_BATCH_SIZE:(test_minibatch_number+1)*TEST_BATCH_SIZE]
        minibatch_acc = sess.run(accuracy,
                                 feed_dict={
                                     images_placeholder: X_minibatch,
                                     labels_placeholder: Y_minibatch,
                                     is_training_placeholder: False
                                 })
        test_total_acc += minibatch_acc
        if test_minibatch_number % 10 == 0:
            print("MB INDEX {}".format(test_minibatch_number))
    print("TEST ACC: {:.4f} | TIME {:.2f} secs".format(
        test_total_acc/TEST_TOTAL_BATCHES, time.time() - start_time))