In [1]:
import os

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
In [2]:
from tensorflow.examples.tutorials.mnist import input_data
# Read the MNIST train/validation/test splits from ../DataSet/mnist with one-hot labels.
# os.path.join keeps the path valid on both Windows and POSIX; a backslash literal
# would be treated as a single file name outside Windows.
MNIST=input_data.read_data_sets(os.path.join("..","DataSet","mnist"),one_hot=True)
print("Number of training samples: {}\nNumber of test samples: {}".format(MNIST.train.num_examples,MNIST.test.num_examples))
In [3]:
def filter2Gray(weight, name=None):
    """
    Convert the kernels of a conv layer into a batch of single-channel images.

    Every 2-D kernel slice weight[:, :, i, o] becomes one image. The pixel values
    are the raw weights; they are NOT rescaled to an image range such as 0-255,
    so any normalization is left to the consumer (e.g. tf.summary.image).

    :param weight: conv kernel, shape [filter_height, filter_width, in_channels, out_channels].
        The two spatial dimensions must be statically known.
    :type weight: Tensor
    :param name: optional name for the returned op.
    :type name: str
    :rtype: Tensor of shape [in_channels * out_channels, filter_height, filter_width, 1];
        image index i * out_channels + o holds weight[:, :, i, o].
    """
    # [h, w, in, out] -> [in, out, h, w]: each (in, out) pair becomes a contiguous h*w slab,
    # so the row-major reshape below yields one image per kernel, ordered by input channel first.
    trans = tf.transpose(a=weight, perm=[2, 3, 0, 1])
    height = int(trans.shape[2])
    width = int(trans.shape[3])
    # Trailing axis of size 1 = a single gray channel, as expected by image summaries.
    return tf.reshape(tensor=trans, shape=[-1, height, width, 1], name=name)
# Build the network graph: input placeholders, two conv/pool layers, a fully
# connected layer, 10 output logits, the loss, the Adam training op and
# TensorBoard summaries.
# Start from an empty graph so re-running this cell does not add a second copy
# of every op (the copy would be named "CNN_1/..." and break Saver/restore).
# Do not run this while a tf.Session on the old graph is still open.
tf.reset_default_graph()
# NOTE(review): indentation was reconstructed from a flattened export. Scope
# nesting fixes every op/variable name (e.g. "CNN/Conv1/Weight"), and
# tf.train.Saver matches checkpoint entries by variable name, so confirm the
# nesting (e.g. of Train, where Adam creates its own variables) against the
# checkpoint. Do not rename any scope or variable here for the same reason.
with tf.name_scope("CNN"):
    with tf.name_scope("Input"):
        # X: flattened 28*28 images; Y: one-hot labels over 10 classes.
        X=tf.placeholder(dtype=tf.float32,shape=[None,784],name="X")
        Y=tf.placeholder(dtype=tf.float32,shape=[None,10],name="Y")
        # NHWC layout with a single (grayscale) channel, as conv2d expects.
        NetIn=tf.reshape(tensor=X,shape=[-1,28,28,1],name="NetIn")
    with tf.name_scope("Conv1"):
        # 5x5 kernels, 1 input channel -> 32 feature maps.
        # NOTE(review): truncated_normal defaults to stddev=1.0 for weights and
        # biases, which is large for training; irrelevant when values are restored.
        W1=tf.Variable(tf.truncated_normal([5,5,1,32]),name="Weight")
        b1=tf.Variable(tf.truncated_normal([32]),name="bias")
        L1Out=tf.nn.conv2d(input=NetIn,filter=W1,strides=[1,1,1,1],padding="SAME",name="conv")
        L1Out=tf.nn.relu(L1Out+b1,name="ReLu")
        # ksize/strides follow NHWC: [batch, height, width, channels]. The 1s at
        # positions 0 and 3 keep pooling within one example and one channel,
        # so this is a 2x2, stride-2 pool over space only: 28x28 -> 14x14.
        L1Out=tf.nn.max_pool(value=L1Out,ksize=[1,2,2,1],strides=[1,2,2,1],padding="SAME",name="L1Out")
    with tf.name_scope("Conv2"):
        # 5x5 kernels, 32 -> 64 feature maps.
        W2=tf.Variable(tf.truncated_normal([5,5,32,64]),name="Weight")
        b2=tf.Variable(tf.truncated_normal([64]),name="bias")
        L2Out=tf.nn.conv2d(input=L1Out,filter=W2,strides=[1,1,1,1],padding="SAME",name="conv")
        L2Out=tf.nn.relu(L2Out+b2,name="ReLu")
        # Same NHWC pooling as in Conv1: 14x14 -> 7x7.
        L2Out=tf.nn.max_pool(value=L2Out,ksize=[1,2,2,1],strides=[1,2,2,1],padding="SAME",name="L2Out")
    with tf.name_scope("FullConnect"):
        # Two stride-2 pools leave 7x7 maps with 64 channels; flatten them per example.
        flatten=tf.reshape(tensor=L2Out,shape=[-1,7*7*64],name="flatten")
        W3=tf.Variable(tf.truncated_normal(dtype=tf.float32,shape=[7*7*64,1024]),name="Weight")
        b3=tf.Variable(tf.truncated_normal(dtype=tf.float32,shape=[1024]),name="bias")
        L3Out=tf.nn.relu(tf.matmul(flatten,W3)+b3,name="L3Out")
    with tf.name_scope("Output"):
        # Raw logits for the 10 classes; the softmax is applied inside the loss.
        W4=tf.Variable(tf.truncated_normal(dtype=tf.float32,shape=[1024,10]),name="Weight")
        # b4 has no explicit name, so it gets TF's default "Variable" under this
        # scope. Giving it a name would change its checkpoint key.
        b4=tf.Variable(tf.truncated_normal(dtype=tf.float32,shape=[10]))
        L4Out=tf.matmul(L3Out,W4)+b4
    with tf.name_scope("Loss"):
        # Per-example softmax cross-entropy, averaged over the batch.
        entropy=tf.nn.softmax_cross_entropy_with_logits(logits=L4Out,labels=Y,name="crossEntropy")
        loss=tf.reduce_mean(input_tensor=entropy,name="loss")
    with tf.name_scope("Train"):
        # `optimizer` holds the update op returned by minimize(), not the
        # AdamOptimizer object itself.
        optimizer=tf.train.AdamOptimizer(learning_rate=0.01,name="optimizer").minimize(loss)
    # Summary ops are registered in the graph's SUMMARIES collection (that is how
    # get_collection finds them below), so the graph keeps them alive; the Python
    # names are not needed to protect them from garbage collection.
    with tf.name_scope("epochSummary") as epochSummary:
        sumLoss=tf.summary.scalar(name="lossSummary",tensor=loss)
        # Merge only the summaries created under this scope.
        summary_op_epoch=tf.summary.merge(inputs=tf.get_collection(key=tf.GraphKeys.SUMMARIES,scope=epochSummary),name="epochSummaryOp")
    with tf.name_scope("finalSummary") as finalSummary:
        # Visualize the conv kernels as images. filter2Gray yields in*out images
        # ordered by input channel first: 32 for W1 and 2048 for W2. max_outputs=20
        # keeps only the first 20 of each (for W2 all of them come from input
        # channel 0). Use max_outputs=int(images.shape[0]) to emit every kernel.
        sumW1=filter2Gray(weight=W1)
        sumW1=tf.summary.image(tensor=sumW1,name="Conv1Weight",max_outputs=20)
        sumW2=filter2Gray(weight=W2)
        sumW2=tf.summary.image(tensor=sumW2,name="Conv2Weight",max_outputs=20)
        summary_op_final=tf.summary.merge(inputs=tf.get_collection(key=tf.GraphKeys.SUMMARIES,scope=finalSummary),name="finalSummaryOp")
In [4]:
# Evaluate the restored model on the complete MNIST test set.
batchSize=20
with tf.Session() as sess:
    # Saver() covers every variable in the graph. Restoring initializes them from
    # the checkpoint, so no initializer op has to run first.
    saver=tf.train.Saver()
    # The "-3000" suffix is presumably the global step the checkpoint was saved at.
    # os.path.join yields the same path as before on Windows and a valid one elsewhere.
    saver.restore(sess=sess,save_path=os.path.join(".","model_checkpoints","MNIST_CNN-"+str(3000)))
    numTest=MNIST.test.num_examples
    acc=0
    # Slice the test arrays directly rather than calling next_batch(), which keeps
    # a cursor between calls (a re-run need not start at example 0). Stepping with
    # range() also covers a trailing partial batch, so every example is counted
    # exactly once.
    for start in range(0,numTest,batchSize):
        x_batch=MNIST.test.images[start:start+batchSize]
        y_batch=MNIST.test.labels[start:start+batchSize]
        pred=sess.run(L4Out,feed_dict={X:x_batch})
        # Count correct predictions in NumPy. Creating tf ops here would add new
        # nodes to the graph on every iteration.
        acc+=int(np.sum(np.argmax(pred,axis=1)==np.argmax(y_batch,axis=1)))
print("Accuracy: {}".format(acc/numTest))