In [2]:
# -*- coding: utf-8 -*-
"""
Created on Thu May 28 14:18:39 2015

@author: keithlam
"""
import cv2
import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline

# Feature set containing (x,y) values of 25 known/training data
trainData = np.random.randint(0,100,(25,2)).astype(np.float32)

# Labels each one either Red or Blue with numbers 0 and 1
responses = np.random.randint(0,2,(25,1)).astype(np.float32)

# Take Red families and plot them
red = trainData[responses.ravel()==0]
plt.scatter(red[:,0],red[:,1],80,'r','^')

# Take Blue families and plot them
blue = trainData[responses.ravel()==1]
plt.scatter(blue[:,0],blue[:,1],80,'b','s')

plt.show()


newcomer = np.random.randint(0,100,(1,2)).astype(np.float32)
plt.scatter(newcomer[:,0],newcomer[:,1],80,'g','o')

knn = cv2.KNearest()
knn.train(trainData,responses)
ret, results, neighbours ,dist = knn.find_nearest(newcomer, 3)

print ("result: ", results,"\n")
print ("neighbours: ", neighbours,"\n")
print ("distance: ", dist)

plt.show()

# 10 new comers
newcomers = np.random.randint(0,100,(10,2)).astype(np.float32)
ret, results,neighbours,dist = knn.find_nearest(newcomer, 3)
# The results also will contain 10 labels.

plt.show()


('result: ', array([[ 1.]], dtype=float32), '\n')
('neighbours: ', array([[ 0.,  1.,  1.]], dtype=float32), '\n')
('distance: ', array([[ 10.,  20.,  98.]], dtype=float32))

In [ ]: