In [2]:
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib
%matplotlib inline
import input_data
import numpy

In [3]:
import sys
import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)


Extracting /tmp/data/train-images-idx3-ubyte.gz
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz
/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/gzip.py:275: VisibleDeprecationWarning: converting an array with ndim > 0 to an index will result in an error in the future
  chunk = self.extrabuf[offset: offset + size]
input_data.py:35: VisibleDeprecationWarning: converting an array with ndim > 0 to an index will result in an error in the future
  data = data.reshape(num_images, rows, cols, 1)
  • 定義 Input 及 Output 暫存變數
  • Input 為 28x28 的點陣圖素
  • Output 為 10 個 Label Array ,分別代表著 0~9 的預測值

In [4]:
session = tf.Session()

In [5]:
x = tf.placeholder(tf.float32,shape=[28*28])

In [6]:
allx = tf.placeholder(tf.float32,shape=[None,28**2])

In [7]:
testimage = mnist.test.images[0]

In [8]:
dist = tf.reduce_sum(tf.abs(tf.sub(x,allx)), reduction_indices=1)

In [9]:
l2dist = session.run(tf.reduce_sum(tf.sub(x,allx),1),feed_dict={x:testimage,allx:mnist.train.images})
  • 距離最近的那張圖片及為預測結果

In [10]:
predict = tf.arg_min(dist,0)
predictl2 = tf.arg_min(l2dist,0)

In [11]:
def dopredict(testimage):
    return (session.run(predict,feed_dict={x:testimage,allx:mnist.train.images}),
            session.run(predict,feed_dict={x:testimage,allx:mnist.train.images}))

In [12]:
def draw(img):
    tmp = img
    tmp2 = tmp.reshape((28,28))

    plt.imshow(tmp2, cmap = cm.Greys)
    plt.show()

In [16]:
import random
for img in  list(map(lambda _: random.choice(mnist.test.images), range(1))): #mnist.train.images[50:55]:
    draw(img)
    
    p,p2 = dopredict(img)
#     print p,p2
    print(numpy.argmax(mnist.train.labels[p]))
    print(numpy.argmax(mnist.train.labels[p2]))


5
5

結論

  • knn 實作非常的簡單
  • l1 l2 disstance 在此 dataset 效果差異不大

In [ ]: