In [3]:
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib
%matplotlib inline
import input_data
import numpy

In [4]:
import sys
import input_data
# Load the MNIST data set through the local `input_data` helper, using
# "/tmp/data/" as the data directory. one_hot=True requests one-hot labels,
# which matches the [None, 10] label placeholder defined below.
# NOTE(review): presumably downloads the files on first use -- confirm in input_data.
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
  • 定義 Input 及 Output 暫存變數
  • Input 為 28x28 的點陣圖素
  • Output 為 10 個 Label Array ,分別代表著 0~9 的預測值

In [7]:
# Graph inputs; the leading None lets the batch size vary per feed.
# x: flattened 28x28 images (784 values per row).
# y: one-hot labels, one column per digit 0-9.
x = tf.placeholder(dtype=tf.float32, shape=[None, 784])
y = tf.placeholder(dtype=tf.float32, shape=[None, 10])

In [22]:
# Softmax regression: a single linear layer followed by softmax.

# Trainable parameters, both initialised to zero.
W = tf.Variable(tf.zeros(shape=[784, 10]))  # one weight column per digit class
b = tf.Variable(tf.zeros(shape=[10]))       # one bias per digit class

# Forward pass.
xw = tf.matmul(x, W)   # [batch, 10] weighted pixel sums
r = tf.add(xw, b)      # raw class scores (logits)
a = tf.nn.softmax(r)   # logits normalised to class probabilities

In [33]:
# Cross-entropy between the one-hot labels `y` and the model's predictions,
# summed over every example in the batch (same quantity as
# -sum(y * log(softmax(r)))).
# It is computed from the raw logits `r` with the fused op rather than from
# tf.log(a): when a softmax probability underflows to 0, log(0) = -inf and
# 0 * -inf = NaN would poison the cost and every gradient after it.
cost = tf.reduce_sum(
    tf.nn.softmax_cross_entropy_with_logits(logits=r, labels=y))

In [34]:
# One training step: plain gradient descent on `cost` with step size 0.01.
op = tf.train.GradientDescentOptimizer(
    learning_rate=0.01
).minimize(loss=cost)

以下開始進行實際運算


In [35]:
# Build the op that initialises the graph's variables (W and b above).
# Nothing is executed here; the op is run later with sess.run(init).
init = tf.initialize_all_variables()

In [36]:
# Session used by the following cells to run initialisation, training,
# evaluation and prediction ops.
sess = tf.Session()

In [37]:
# Execute the initialiser so W and b hold their starting values before training.
sess.run(init)

In [38]:
# Mini-batch training.
# NOTE: each iteration consumes ONE batch, so `epochs` is really the number
# of gradient steps, not the number of full passes over the training set.
epochs = 100
batch_size = 200
for _ in range(epochs):  # previously hard-coded to 100, ignoring `epochs`
    input_x, output_y = mnist.train.next_batch(batch_size)
    batch_feed = {x: input_x, y: output_y}

    # Apply one gradient-descent update on this batch ...
    sess.run(op, feed_dict=batch_feed)
    # ... then re-evaluate the summed cost on the same batch, after the update.
    batch_cost = sess.run(cost, feed_dict=batch_feed)

    # `cost` is a sum over the batch, so dividing by batch_size gives the mean
    # per-example cost. %-formatting prints the same text on Python 2 and 3.
    print("avg_cost: %s" % (float(batch_cost) / batch_size,))


avg_cost: 1.53818847656
avg_cost: 4.20006500244
avg_cost: 3.53161804199
avg_cost: 5.75972351074
avg_cost: 6.60235961914
avg_cost: 5.1791998291
avg_cost: 5.26221313477
avg_cost: 1.79604919434
avg_cost: 2.81106994629
avg_cost: 1.26635177612
avg_cost: 0.970408935547
avg_cost: 2.00831665039
avg_cost: 1.75061691284
avg_cost: 2.19973968506
avg_cost: 1.0637865448
avg_cost: 0.93669631958
avg_cost: 1.28613098145
avg_cost: 0.47118812561
avg_cost: 0.464879074097
avg_cost: 0.662971420288
avg_cost: 0.668267059326
avg_cost: 0.584128265381
avg_cost: 0.624622650146
avg_cost: 1.04467071533
avg_cost: 1.78894287109
avg_cost: 0.916661224365
avg_cost: 0.863260498047
avg_cost: 1.38139678955
avg_cost: 0.785594482422
avg_cost: 0.40527130127
avg_cost: 0.748572311401
avg_cost: 1.42575210571
avg_cost: 0.672653656006
avg_cost: 0.317291584015
avg_cost: 0.412251815796
avg_cost: 0.291154708862
avg_cost: 0.362015991211
avg_cost: 0.316969490051
avg_cost: 0.463984222412
avg_cost: 0.530088119507
avg_cost: 0.25291891098
avg_cost: 0.390248565674
avg_cost: 0.436983985901
avg_cost: 0.827095031738
avg_cost: 0.544939117432
avg_cost: 0.449331207275
avg_cost: 0.708355407715
avg_cost: 0.494112892151
avg_cost: 0.382871055603
avg_cost: 0.369703903198
avg_cost: 0.33285484314
avg_cost: 0.377543029785
avg_cost: 0.329020233154
avg_cost: 0.365044708252
avg_cost: 0.286311912537
avg_cost: 0.258263244629
avg_cost: 0.61345123291
avg_cost: 0.575455703735
avg_cost: 0.4953490448
avg_cost: 0.390891075134
avg_cost: 0.308500938416
avg_cost: 0.301270771027
avg_cost: 0.299105205536
avg_cost: 0.257990455627
avg_cost: 0.343915176392
avg_cost: 0.51428276062
avg_cost: 1.20363868713
avg_cost: 0.698873672485
avg_cost: 0.426701049805
avg_cost: 0.70049041748
avg_cost: 0.331801528931
avg_cost: 0.363929672241
avg_cost: 0.415712966919
avg_cost: 0.529532318115
avg_cost: 0.330678100586
avg_cost: 0.571112785339
avg_cost: 0.611205291748
avg_cost: 0.416135940552
avg_cost: 0.585308380127
avg_cost: 0.455072631836
avg_cost: 0.239188995361
avg_cost: 0.294139823914
avg_cost: 0.327356185913
avg_cost: 0.340790977478
avg_cost: 0.271427326202
avg_cost: 0.264783706665
avg_cost: 0.332580070496
avg_cost: 0.343216705322
avg_cost: 0.56881401062
avg_cost: 0.319203109741
avg_cost: 0.40915096283
avg_cost: 0.48857170105
avg_cost: 0.439623680115
avg_cost: 0.351910400391
avg_cost: 0.363420562744
avg_cost: 0.192328529358
avg_cost: 0.349356040955
avg_cost: 0.190961685181
avg_cost: 0.367852096558
avg_cost: 0.541251602173

In [39]:
# Predicted digit: index of the largest softmax probability in each row.
predict = tf.argmax(a, 1)

# True digit: index of the largest entry in each label row.
ans = tf.argmax(y, 1)

# Fraction of test images whose predicted digit matches the label.
# NOTE: despite the variable name, this is accuracy rather than precision;
# the name is kept because the next cell prints it.
preccision = sess.run(
    tf.reduce_mean(tf.cast(tf.equal(predict, ans), "float")),
    feed_dict={x: mnist.test.images, y: mnist.test.labels})

In [40]:
# Show the test-set score computed in the previous cell.
print(preccision)


0.8163

In [43]:
import random

# Pick five training images at random (repeats are possible), display each
# one as a 28x28 greyscale picture and print the digit the model predicts.
samples = [random.choice(mnist.train.images) for _ in range(5)]
for img in samples:
    plt.imshow(img.reshape((28, 28)), cmap=cm.Greys)
    plt.show()
    # The placeholder expects a batch, so wrap the single image in a list
    # and take the first (only) prediction.
    print(sess.run(predict, feed_dict={x: [img]})[0])


9
6
6
3
2

In [79]:



1

In [ ]: