In [4]:
import tensorflow as tf
import sys
import matplotlib.pyplot as plt
import numpy as np
import os
In [5]:
sys.path.append("../");
In [6]:
import sropts
from neural_networks import *
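The neural_networks module is local to this repository and its source is not shown in the notebook. As a reference, here is a minimal sketch of the interface the later cells assume; the constructor arguments (h_dim, fc_dim, block_num, is_train) and the nn_predict / loss methods are taken from the calls below, but the layer structure is only a guess and will differ from the real ConvMNIST implementation.
# Hypothetical stand-in for neural_networks.NeuralNetworks.ConvMNIST (sketch only).
class ConvMNISTSketch(object):
    def __init__(self, h_dim, fc_dim, block_num, is_train):
        self.h_dim = h_dim          # conv filters per block (assumption)
        self.fc_dim = fc_dim        # width of the fully connected layer (assumption)
        self.block_num = block_num  # number of conv blocks to stack (assumption)
        self.is_train = is_train    # bool placeholder toggling train-time behaviour

    def nn_predict(self, x, reuse=False):
        with tf.variable_scope("conv_mnist", reuse=reuse):
            h = x
            for i in range(self.block_num):
                h = tf.layers.conv2d(h, filters=self.h_dim, kernel_size=3,
                                     padding='same', activation=tf.nn.relu,
                                     name='conv%d' % i)
            h = tf.layers.flatten(h)
            h = tf.layers.dense(h, self.fc_dim, activation=tf.nn.relu, name='fc')
            h = tf.layers.dropout(h, rate=0.5, training=self.is_train)
            return tf.layers.dense(h, 10, name='logits')  # class logits

    def loss(self, predict, real):
        # softmax cross-entropy between one-hot labels and logits
        return tf.losses.softmax_cross_entropy(onehot_labels=real, logits=predict)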
In [5]:
# Download / load MNIST data
train_dir = '../data/MNIST/';
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets(train_dir, one_hot=True);
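Note that tensorflow.examples.tutorials.mnist is deprecated and was removed in later TensorFlow releases. If it is unavailable, a rough equivalent (my own sketch, not part of the original notebook) is to load MNIST through tf.keras.datasets and one-hot encode the labels with NumPy:
# Sketch: load MNIST without the deprecated tutorials module.
import numpy as np
import tensorflow as tf

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.astype(np.float32)[..., None] / 255.0   # (60000, 28, 28, 1), scaled to [0, 1]
x_test = x_test.astype(np.float32)[..., None] / 255.0
y_train = np.eye(10, dtype=np.float32)[y_train]            # one-hot, matching one_hot=True above
y_test = np.eye(10, dtype=np.float32)[y_test]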
In [8]:
# Params
batch_size = 64;
image_size = 28;
category = 10;
learning_rate=1e-3;
chkpt_dir = "../chkpt/";
if not os.path.exists(chkpt_dir):
    os.makedirs(chkpt_dir);
In [9]:
# Build network
x = tf.placeholder(dtype=tf.float32,shape=(batch_size, image_size, image_size,1), name='in-img');
y = tf.placeholder(dtype=tf.float32, shape=(batch_size, category), name='in-label');
is_train = tf.placeholder(dtype=tf.bool, shape=[], name='is_train');
m_nn = NeuralNetworks.ConvMNIST(h_dim=32, fc_dim = 256, block_num=8, is_train=is_train);
predict = m_nn.nn_predict(x, reuse=False);
m_loss = m_nn.loss(predict=predict, real=y);
correct_prediction = tf.equal(tf.argmax(predict,1), tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
opt_op = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(m_loss);
init_op = tf.global_variables_initializer();
saver = tf.train.Saver();
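Not in the original notebook, but a quick sanity check after building the graph is to count its trainable parameters:
# Count trainable parameters of the freshly built graph (sanity check only).
total_params = sum(int(np.prod(v.get_shape().as_list())) for v in tf.trainable_variables())
print "trainable parameters:", total_params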
In [ ]:
# Training phase
with tf.Session() as sess:
    sess.run(init_op);
    # Hold out the last 20% of each epoch's training stream for validation
    val_batches = int(0.2 * mnist.train.num_examples / batch_size);
    best_val_acc = 0;
    for epoch in xrange(25):
        count = 0;
        total_acc = []
        while count < (mnist.train.num_examples - val_batches*batch_size):
            count += batch_size;
            [input_x, input_y] = mnist.train.next_batch(batch_size=batch_size);
            input_x = np.reshape(input_x, newshape=(batch_size,28,28,1));
            [comp_loss, comp_acc, __] = sess.run([m_loss, accuracy, opt_op], feed_dict={
                x: input_x,
                y: input_y,
                is_train: True
            });
            total_acc.append(comp_acc);
            if count/batch_size % 10 == 0:
                print "\repoch:", epoch, " image#:", count, " avg acc:", np.mean(total_acc), " loss:", comp_loss, " acc:", comp_acc,
        # Validation pass over the held-out batches
        count = 0;
        real_labels = [];
        pred_labels = [];
        while count < val_batches * batch_size:
            count += batch_size;
            [input_x, input_y] = mnist.train.next_batch(batch_size=batch_size);
            input_x = np.reshape(input_x, newshape=(batch_size,28,28,1));
            [comp_predict_val] = sess.run([predict], feed_dict={
                x: input_x,
                y: input_y,
                is_train: False
            });
            pred_labels.extend(comp_predict_val);
            real_labels.extend(input_y);
        val_corr = np.equal(np.argmax(pred_labels,1), np.argmax(real_labels,1));
        val_acc = np.mean(val_corr.astype(np.float32));
        print "val: ", val_acc,
        # Keep only the checkpoint with the best validation accuracy so far
        if val_acc > best_val_acc:
            saver.save(sess, chkpt_dir+"model.chkpt")
            best_val_acc = val_acc;
            print "model saved!";
        else:
            print "";
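The commented-out restore line in the next cell uses saver.last_checkpoints, which is only populated in the session that did the saving. When restoring in a fresh session, tf.train.latest_checkpoint can locate the newest checkpoint in the directory instead (a small sketch, assuming the checkpoint was written by the training cell above):
# Resolve the most recent checkpoint file under chkpt_dir, if any.
latest = tf.train.latest_checkpoint(chkpt_dir)
print "latest checkpoint:", latest   # None if nothing has been saved yet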
In [ ]:
# Testing phase
with tf.Session() as sess:
    # Restore the best checkpoint saved during training
    saver.restore(sess, chkpt_dir+"model.chkpt");
    # saver.restore(sess, saver.last_checkpoints[-1]);
    count = 0;
    real_labels = [];
    pred_labels = [];
    while count <= mnist.test.num_examples - batch_size:
        count += batch_size;
        [input_x, input_y] = mnist.test.next_batch(batch_size=batch_size);
        input_x = np.reshape(input_x, newshape=(batch_size,28,28,1));
        [comp_predict_test] = sess.run([predict], feed_dict={
            x: input_x,
            y: input_y,
            is_train: False
        });
        pred_labels.extend(comp_predict_test);
        real_labels.extend(input_y);
    print "Forward pass completed";
    test_corr = np.equal(np.argmax(pred_labels,1), np.argmax(real_labels,1));
    test_acc = np.mean(test_corr.astype(np.float32));
    print "average acc:", test_acc;
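Beyond the single accuracy number, a per-class breakdown can be computed from the same pred_labels / real_labels lists collected in the test loop above (a NumPy-only sketch, not part of the original notebook):
# Sketch: 10x10 confusion matrix from the lists gathered during the test pass.
import numpy as np

def confusion_matrix(pred_labels, real_labels, num_classes=10):
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    pred_idx = np.argmax(pred_labels, 1)   # predicted class per example
    real_idx = np.argmax(real_labels, 1)   # true class per example
    for p, r in zip(pred_idx, real_idx):
        cm[r, p] += 1                      # rows: true class, columns: prediction
    return cm

cm = confusion_matrix(pred_labels, real_labels)
print cm
print "per-class acc:", np.diag(cm).astype(np.float32) / cm.sum(axis=1)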
In [40]:
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
In [46]:
# Thin wrapper around the MNIST reader; train exposes the same next_batch interface
class A(object):
    def __init__(self, data_dir):
        self.data_dir = data_dir;
        self.mnist = input_data.read_data_sets(self.data_dir, one_hot=True);
        self.train = A.DataSet(self.mnist.train)
        self.validation = self.mnist.validation;
        self.test = self.mnist.test;

    class DataSet(object):
        def __init__(self, data_set):
            self.data_set = data_set;
            self.images = data_set.images;
            self.labels = data_set.labels;
            self.num_examples = data_set.num_examples;

        def next_batch(self, batch_size):
            return self.data_set.next_batch(batch_size);
In [47]:
m_a = A("../data/MNIST/");
In [48]:
[m_x, m_y] =m_a.train.next_batch(batch_size=2);
In [58]:
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
In [60]:
for im in m_x:
    plt.figure();
    plt.imshow(np.reshape(im, newshape=(28,28)), cmap='gray');
In [51]:
m_x
Out[51]: