In [1]:
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
test_data = mnist.test
train_data = mnist.train
valid_data = mnist.validation


Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.
Extracting MNIST_data/train-images-idx3-ubyte.gz
Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz

In [2]:
epsilon = 1e-3

class FC(object):
    def __init__(self, learning_rate=0.01):
        self.lr = learning_rate
        self.sess = tf.Session()
        self.x = tf.placeholder(tf.float32,[None, 784], 'x')
        self.y_ = tf.placeholder(tf.float32, [None, 10], 'y_')
        self.training = tf.placeholder(tf.bool, name='training')
        self._build_net(self.x,'FC')

        with tf.variable_scope('Accuracy'):
            self.correct_prediction = tf.equal(tf.argmax(self.y,1), tf.argmax(self.y_,1))
            self.accuracy = tf.reduce_mean(tf.cast(self.correct_prediction, tf.float32))
        with tf.variable_scope('Train'):
            self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=self.y_, logits=self.y)) 
            extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(extra_update_ops):
                self.train_opt = tf.train.GradientDescentOptimizer(self.lr).minimize(self.loss)
        self.sess.run(tf.global_variables_initializer())
        
    def batch_norm_wrapper(self, inputs, is_training, decay = 0.999):

        scale = tf.Variable(tf.ones([inputs.get_shape()[-1]]))
        beta = tf.Variable(tf.zeros([inputs.get_shape()[-1]]))
        pop_mean = tf.Variable(tf.zeros([inputs.get_shape()[-1]]), trainable=False)
        pop_var = tf.Variable(tf.ones([inputs.get_shape()[-1]]), trainable=False)

        if is_training==tf.constant(True):
            batch_mean, batch_var = tf.nn.moments(inputs,[0])
            train_mean = tf.assign(pop_mean,
                                   pop_mean * decay + batch_mean * (1 - decay))
            train_var = tf.assign(pop_var,
                                  pop_var * decay + batch_var * (1 - decay))
            with tf.control_dependencies([train_mean, train_var]):
                return tf.nn.batch_normalization(inputs,
                    batch_mean, batch_var, beta, scale, epsilon)
        else:
            return tf.nn.batch_normalization(inputs,
                pop_mean, pop_var, beta, scale, epsilon)
        
    def _build_net(self, x, scope):
        with tf.variable_scope(scope):
            bn = tf.layers.batch_normalization(x, axis=1, training=self.training, name = 'bn')
            #bn = self.batch_norm_wrapper(x, self.training)
            hidden = tf.layers.dense(bn, 50, activation=tf.nn.relu, name='l1')
            self.y = tf.layers.dense(hidden, 10, name='o')
    
    def learn(self, x, y):
        loss,_ = self.sess.run([self.loss,self.train_opt],{self.x:x, self.y_:y, self.training:True})
        return loss
    
    def inference(self, x, y=None):
        y = self.sess.run(self.y,{self.x:x, self.training:False})
        #loss,_ = self.sess.run(self.loss,{self.x:x, self.y_:y, self.training:False})
        return y
    
fc = FC()

In [3]:
OUTPUT_GRAPH = True
if OUTPUT_GRAPH:
    tf.summary.FileWriter("logs/", fc.sess.graph)

In [4]:
for i in range(1000):
    batch = train_data.next_batch(100)
    loss = fc.learn(batch[0],batch[1])
    if i%200 == 0:
        print(loss)
    
batch = valid_data.next_batch(5000)
print("validation accuracy: %f" % fc.sess.run(fc.accuracy,{fc.x:batch[0], fc.y_:batch[1], fc.training:True}))


2.70122
0.862728
0.519033
0.401137
0.317988
validation accuracy: 0.917200

In [5]:
nm = 1000
count = 0
for _ in range(nm):
    t = test_data.next_batch(1)
    x = t[0]
    y = fc.inference(x)
    a = np.argmax(y,axis=1)
    b = np.argmax(t[1],axis=1)
    if a==b:
        count += 1
print count/float(nm)


0.914