Imports & Dataset


In [2]:
from numpy import *
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
# Six training films, one row each, two count features per row.
# NOTE(review): the rows labelled Romance hold the large value in column 0,
# so the columns are presumably [kisses, kicks] -- confirm against the source.
dataset = array([
    [104, 3],
    [100, 2],
    [81, 1],
    [10, 101],
    [5, 99],
    [2, 98],
])

In [4]:
# Genre of each dataset row, in row order: three Romance films, then three Action films.
labels = ['Romance'] * 3 + ['Action'] * 3

Helper Functions


In [5]:
def plot_dataset(dataset):
    """Scatter-plot the first two columns of ``dataset`` and show the figure.

    Column 0 is drawn on the x-axis and column 1 on the y-axis.

    Args:
        dataset: 2-D array that supports ``dataset[:, 0]`` / ``dataset[:, 1]``
            indexing (e.g. a numpy array with at least two columns).
    """
    fig = plt.figure()
    ax = fig.add_subplot(111)
    # In this notebook the rows labelled Romance carry the large count in
    # column 0, so column 0 is the kiss count and column 1 the kick count.
    # The axis labels must follow that column order; they were previously
    # swapped, which made Romance films appear to have ~100 kicks.
    ax.set_xlabel('Number of Kisses')
    ax.set_ylabel('Number of Kicks')
    ax.scatter(dataset[:, 0], dataset[:, 1])
    plt.show()

In [6]:
def plot_with_inX(dataset, inX):
    """Plot ``dataset`` together with one extra point ``inX``.

    The rows of ``dataset`` are copied into a new list, ``inX`` is appended as
    the final row, and the combined points are handed to ``plot_dataset``.
    ``dataset`` itself is left unmodified.

    Args:
        dataset: sequence of points (rows) to plot.
        inX: the additional point, given with the same layout as one row.
    """
    # Unpack the rows into a fresh list so the caller's data is untouched.
    combined_points = [*dataset[:], inX]
    plot_dataset(array(combined_points))

In [7]:
def calc_closest_points(dataset, inX):
    """Rank the rows of ``dataset`` by Euclidean distance to ``inX``.

    Args:
        dataset: numpy array with one point per row.
        inX: query point with the same layout as one row of ``dataset``.

    Returns:
        Array of row indices into ``dataset``, ordered from the nearest row
        to the farthest one.
    """
    n_rows = dataset.shape[0]
    # Repeat the query once per row so the subtraction lines up row by row.
    offsets = tile(inX, (n_rows, 1)) - dataset
    # Square, sum across each row, then take the root: Euclidean distance.
    euclidean = (offsets ** 2).sum(axis=1) ** 0.5
    return euclidean.argsort()

In [8]:
def classify(labels, closest_points, k):
    """Return the majority label among the ``k`` nearest neighbours.

    Args:
        labels: sequence of labels, indexable by the entries of
            ``closest_points``. Labels must be hashable.
        closest_points: indices into ``labels`` ordered from nearest to
            farthest (e.g. the result of ``calc_closest_points``).
        k: number of leading entries of ``closest_points`` that get a vote.

    Returns:
        The label with the most votes. When several labels tie, the one
        whose first vote came from the nearest neighbour wins, so the result
        is deterministic.

    Raises:
        ValueError: if no neighbours are selected (e.g. ``k`` is 0 or
            ``closest_points`` is empty).
    """
    # Tally votes in a dict: insertion order records which label was seen
    # first, i.e. which label belongs to the nearest voting neighbour.
    votes = {}
    for idx in closest_points[:k]:
        label = labels[idx]
        votes[label] = votes.get(label, 0) + 1

    if not votes:
        raise ValueError("classify() needs at least one neighbour to vote")

    # max() returns the first maximal item it encounters, so iterating the
    # dict (insertion-ordered) breaks ties in favour of the nearest label.
    # Iterating a set instead would make tied results depend on hash order.
    return max(votes, key=votes.get)

Example: Plotting the Data and Ranking Neighbours for a Query Point


In [9]:
# Scatter-plot the six training rows on their own.
plot_dataset(dataset)



In [10]:
# Same scatter plot with the query point [30, 130] added as an extra point.
plot_with_inX(dataset, [30,130])



In [11]:
# Row indices of `dataset`, ordered from nearest to farthest from the query [30, 130].
calc_closest_points(dataset, [30, 130])


Out[11]:
array([3, 4, 5, 2, 1, 0], dtype=int64)

In [ ]: