In [1]:
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
test_data = mnist.test
train_data = mnist.train
valid_data = mnist.validation
In [2]:
epsilon = 1e-3

class FC(object):
    def __init__(self, learning_rate=0.01):
        self.lr = learning_rate
        self.sess = tf.Session()
        self.x = tf.placeholder(tf.float32, [None, 784], 'x')
        self.y_ = tf.placeholder(tf.float32, [None, 10], 'y_')
        self.training = tf.placeholder(tf.bool, name='training')
        self._build_net(self.x, 'FC')
        with tf.variable_scope('Accuracy'):
            self.correct_prediction = tf.equal(tf.argmax(self.y, 1), tf.argmax(self.y_, 1))
            self.accuracy = tf.reduce_mean(tf.cast(self.correct_prediction, tf.float32))
        with tf.variable_scope('Train'):
            self.loss = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(labels=self.y_, logits=self.y))
            # tf.layers.batch_normalization registers its moving-average updates in
            # the UPDATE_OPS collection; they must run together with the training op.
            extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(extra_update_ops):
                self.train_opt = tf.train.GradientDescentOptimizer(self.lr).minimize(self.loss)
        self.sess.run(tf.global_variables_initializer())
    def batch_norm_wrapper(self, inputs, is_training, decay=0.999):
        # Hand-rolled batch norm that keeps population statistics for inference.
        scale = tf.Variable(tf.ones([inputs.get_shape()[-1]]))
        beta = tf.Variable(tf.zeros([inputs.get_shape()[-1]]))
        pop_mean = tf.Variable(tf.zeros([inputs.get_shape()[-1]]), trainable=False)
        pop_var = tf.Variable(tf.ones([inputs.get_shape()[-1]]), trainable=False)

        def in_training():
            batch_mean, batch_var = tf.nn.moments(inputs, [0])
            train_mean = tf.assign(pop_mean,
                                   pop_mean * decay + batch_mean * (1 - decay))
            train_var = tf.assign(pop_var,
                                  pop_var * decay + batch_var * (1 - decay))
            with tf.control_dependencies([train_mean, train_var]):
                return tf.nn.batch_normalization(inputs,
                                                 batch_mean, batch_var, beta, scale, epsilon)

        def in_inference():
            return tf.nn.batch_normalization(inputs,
                                             pop_mean, pop_var, beta, scale, epsilon)

        # is_training is a placeholder, so the branch must be chosen at run time with
        # tf.cond; a Python `if is_training == tf.constant(True)` is never true and
        # would silently fall through to the inference branch.
        return tf.cond(is_training, in_training, in_inference)
    def _build_net(self, x, scope):
        with tf.variable_scope(scope):
            bn = tf.layers.batch_normalization(x, axis=1, training=self.training, name='bn')
            # bn = self.batch_norm_wrapper(x, self.training)
            hidden = tf.layers.dense(bn, 50, activation=tf.nn.relu, name='l1')
            self.y = tf.layers.dense(hidden, 10, name='o')

    def learn(self, x, y):
        loss, _ = self.sess.run([self.loss, self.train_opt],
                                {self.x: x, self.y_: y, self.training: True})
        return loss

    def inference(self, x):
        # Inference feeds training=False so batch norm uses its population statistics.
        return self.sess.run(self.y, {self.x: x, self.training: False})
fc = FC()
In [3]:
OUTPUT_GRAPH = True
if OUTPUT_GRAPH:
    tf.summary.FileWriter("logs/", fc.sess.graph)
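The saved graph can then be inspected by pointing TensorBoard at the log directory, e.g. `tensorboard --logdir logs/` (assuming TensorBoard is installed alongside TensorFlow).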
In [4]:
for i in range(1000):
    batch = train_data.next_batch(100)
    loss = fc.learn(batch[0], batch[1])
    if i % 200 == 0:
        print(loss)
        batch = valid_data.next_batch(5000)
        # Evaluate with training=False so the batch-norm layer normalizes with its
        # population statistics rather than the validation batch's own moments.
        print("validation accuracy: %f" % fc.sess.run(fc.accuracy,
              {fc.x: batch[0], fc.y_: batch[1], fc.training: False}))
In [5]:
nm = 1000
count = 0
for _ in range(nm):
    t = test_data.next_batch(1)
    x = t[0]
    y = fc.inference(x)
    a = np.argmax(y, axis=1)
    b = np.argmax(t[1], axis=1)
    if a == b:
        count += 1
print(count / float(nm))
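As a cross-check (a sketch, assuming the `fc` instance trained above is still in memory), the same `accuracy` node used for validation can score the entire test set in a single run, again feeding `training=False`:

test_acc = fc.sess.run(fc.accuracy,
                       {fc.x: test_data.images,
                        fc.y_: test_data.labels,
                        fc.training: False})
print("test accuracy: %f" % test_acc)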