In [1]:
%matplotlib inline

In [2]:
# Regression can be used on pockets of feature space.
# Think of the dataset as subject for several data processes
# Regression can be used in the context of clustering.
# Because Regression is a supervised technique, we need to use
# k-Nearest Neighbors instead of k-Means

In [3]:
# Using Iris dataset to predict the petal width for each flower
# clustering by iris species can give us better results.
# We cannot cluster by species but we can work under the
# assumption that the X's will be closer for the same species

In [4]:
from sklearn.datasets import load_iris

In [10]:
iris = load_iris()
X = iris.data
y = iris.target

In [11]:
iris.feature_names


Out[11]:
['sepal length (cm)',
 'sepal width (cm)',
 'petal length (cm)',
 'petal width (cm)']

In [12]:
# try to predict the petal length based on the sepal 
# length and width. Fit a linear regression to see how well the
# k-NN regresion does in comparison.

In [14]:
from sklearn.linear_model import LinearRegression
import numpy as np
lr = LinearRegression()
lr.fit(X, y)


Out[14]:
LinearRegression(copy_X=True, fit_intercept=True, n_jobs=1, normalize=False)

In [16]:
print "The MSE is: {:.2}".format(np.power(y - lr.predict(X),
                                          2).mean())


The MSE is: 0.046

In [17]:
# For k-NN regression:
from sklearn.neighbors import KNeighborsRegressor

In [19]:
knnr = KNeighborsRegressor(n_neighbors=10)
knnr.fit(X, y)
print "The MSE is: {:.2}".format(np.power(y - knnr.predict(X),
                                          2).mean())


The MSE is: 0.022

In [20]:
import matplotlib.pyplot as plt
f, ax = plt.subplots(nrows=2, figsize=(7,10))
ax[0].set_title("Predictions")
ax[0].scatter(X[:, 0], X[:, 1], s=lr.predict(X)*80,
              label='LR Predictions', color='c',
              edgecolors='black')
ax[1].scatter(X[:, 0], X[: ,1], s=knnr.predict(X)*80,
             label='k-NN Predictions', color='m',
             edgecolors='black')
ax[0].legend()
ax[1].legend()


/usr/local/lib/python2.7/site-packages/matplotlib/collections.py:764: RuntimeWarning: invalid value encountered in sqrt
  scale = np.sqrt(self._sizes) * dpi / 72.0 * self._factor
Out[20]:
<matplotlib.legend.Legend at 0x11430c910>

In [24]:
setosa_idx = np.where(iris.target_names=='setosa')
setosa_mask = iris.target == setosa_idx[0]
y[setosa_mask][:5]


(array([0]),)
Out[24]:
array([0, 0, 0, 0, 0])

In [22]:
knnr.predict(X)[setosa_mask][:5]


Out[22]:
array([ 0.,  0.,  0.,  0.,  0.])

In [23]:
lr.predict(X)[setosa_mask][:5]


Out[23]:
array([-0.08265827, -0.03858976, -0.04818969,  0.01260878, -0.07610817])

In [ ]: