Introduction

This notebook examines the Iris dataset: we compute descriptive statistics on the data and then apply several machine-learning classification techniques, comparing their accuracy on a held-out test set.

Load the dataset (from scikit-learn)


In [1]:
%time from sklearn.datasets import load_iris
%time iris = load_iris()
print(type(iris))


CPU times: user 307 ms, sys: 72.5 ms, total: 379 ms
Wall time: 380 ms
CPU times: user 2.86 ms, sys: 154 µs, total: 3.01 ms
Wall time: 3.08 ms
<class 'sklearn.datasets.base.Bunch'>

In [2]:
print(iris.DESCR)


Iris Plants Database

Notes
-----
Data Set Characteristics:
    :Number of Instances: 150 (50 in each of three classes)
    :Number of Attributes: 4 numeric, predictive attributes and the class
    :Attribute Information:
        - sepal length in cm
        - sepal width in cm
        - petal length in cm
        - petal width in cm
        - class:
                - Iris-Setosa
                - Iris-Versicolour
                - Iris-Virginica
    :Summary Statistics:
    ============== ==== ==== ======= ===== ====================
                    Min  Max   Mean    SD   Class Correlation
    ============== ==== ==== ======= ===== ====================
    sepal length:   4.3  7.9   5.84   0.83    0.7826
    sepal width:    2.0  4.4   3.05   0.43   -0.4194
    petal length:   1.0  6.9   3.76   1.76    0.9490  (high!)
    petal width:    0.1  2.5   1.20   0.76    0.9565  (high!)
    ============== ==== ==== ======= ===== ====================
    :Missing Attribute Values: None
    :Class Distribution: 33.3% for each of 3 classes.
    :Creator: R.A. Fisher
    :Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)
    :Date: July, 1988

This is a copy of UCI ML iris datasets.
http://archive.ics.uci.edu/ml/datasets/Iris

The famous Iris database, first used by Sir R.A Fisher

This is perhaps the best known database to be found in the
pattern recognition literature.  Fisher's paper is a classic in the field and
is referenced frequently to this day.  (See Duda & Hart, for example.)  The
data set contains 3 classes of 50 instances each, where each class refers to a
type of iris plant.  One class is linearly separable from the other 2; the
latter are NOT linearly separable from each other.

References
----------
   - Fisher,R.A. "The use of multiple measurements in taxonomic problems"
     Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions to
     Mathematical Statistics" (John Wiley, NY, 1950).
   - Duda,R.O., & Hart,P.E. (1973) Pattern Classification and Scene Analysis.
     (Q327.D83) John Wiley & Sons.  ISBN 0-471-22361-1.  See page 218.
   - Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New System
     Structure and Classification Rule for Recognition in Partially Exposed
     Environments".  IEEE Transactions on Pattern Analysis and Machine
     Intelligence, Vol. PAMI-2, No. 1, 67-71.
   - Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule".  IEEE Transactions
     on Information Theory, May 1972, 431-433.
   - See also: 1988 MLC Proceedings, 54-64.  Cheeseman et al"s AUTOCLASS II
     conceptual clustering system finds 3 classes in the data.
   - Many, many more ...

Examine the data


In [3]:
print(type(iris.data))
print(iris.data.shape)
iris.data


<type 'numpy.ndarray'>
(150, 4)
Out[3]:
array([[ 5.1,  3.5,  1.4,  0.2],
       [ 4.9,  3. ,  1.4,  0.2],
       [ 4.7,  3.2,  1.3,  0.2],
       [ 4.6,  3.1,  1.5,  0.2],
       [ 5. ,  3.6,  1.4,  0.2],
       [ 5.4,  3.9,  1.7,  0.4],
       [ 4.6,  3.4,  1.4,  0.3],
       [ 5. ,  3.4,  1.5,  0.2],
       [ 4.4,  2.9,  1.4,  0.2],
       [ 4.9,  3.1,  1.5,  0.1],
       [ 5.4,  3.7,  1.5,  0.2],
       [ 4.8,  3.4,  1.6,  0.2],
       [ 4.8,  3. ,  1.4,  0.1],
       [ 4.3,  3. ,  1.1,  0.1],
       [ 5.8,  4. ,  1.2,  0.2],
       [ 5.7,  4.4,  1.5,  0.4],
       [ 5.4,  3.9,  1.3,  0.4],
       [ 5.1,  3.5,  1.4,  0.3],
       [ 5.7,  3.8,  1.7,  0.3],
       [ 5.1,  3.8,  1.5,  0.3],
       [ 5.4,  3.4,  1.7,  0.2],
       [ 5.1,  3.7,  1.5,  0.4],
       [ 4.6,  3.6,  1. ,  0.2],
       [ 5.1,  3.3,  1.7,  0.5],
       [ 4.8,  3.4,  1.9,  0.2],
       [ 5. ,  3. ,  1.6,  0.2],
       [ 5. ,  3.4,  1.6,  0.4],
       [ 5.2,  3.5,  1.5,  0.2],
       [ 5.2,  3.4,  1.4,  0.2],
       [ 4.7,  3.2,  1.6,  0.2],
       [ 4.8,  3.1,  1.6,  0.2],
       [ 5.4,  3.4,  1.5,  0.4],
       [ 5.2,  4.1,  1.5,  0.1],
       [ 5.5,  4.2,  1.4,  0.2],
       [ 4.9,  3.1,  1.5,  0.1],
       [ 5. ,  3.2,  1.2,  0.2],
       [ 5.5,  3.5,  1.3,  0.2],
       [ 4.9,  3.1,  1.5,  0.1],
       [ 4.4,  3. ,  1.3,  0.2],
       [ 5.1,  3.4,  1.5,  0.2],
       [ 5. ,  3.5,  1.3,  0.3],
       [ 4.5,  2.3,  1.3,  0.3],
       [ 4.4,  3.2,  1.3,  0.2],
       [ 5. ,  3.5,  1.6,  0.6],
       [ 5.1,  3.8,  1.9,  0.4],
       [ 4.8,  3. ,  1.4,  0.3],
       [ 5.1,  3.8,  1.6,  0.2],
       [ 4.6,  3.2,  1.4,  0.2],
       [ 5.3,  3.7,  1.5,  0.2],
       [ 5. ,  3.3,  1.4,  0.2],
       [ 7. ,  3.2,  4.7,  1.4],
       [ 6.4,  3.2,  4.5,  1.5],
       [ 6.9,  3.1,  4.9,  1.5],
       [ 5.5,  2.3,  4. ,  1.3],
       [ 6.5,  2.8,  4.6,  1.5],
       [ 5.7,  2.8,  4.5,  1.3],
       [ 6.3,  3.3,  4.7,  1.6],
       [ 4.9,  2.4,  3.3,  1. ],
       [ 6.6,  2.9,  4.6,  1.3],
       [ 5.2,  2.7,  3.9,  1.4],
       [ 5. ,  2. ,  3.5,  1. ],
       [ 5.9,  3. ,  4.2,  1.5],
       [ 6. ,  2.2,  4. ,  1. ],
       [ 6.1,  2.9,  4.7,  1.4],
       [ 5.6,  2.9,  3.6,  1.3],
       [ 6.7,  3.1,  4.4,  1.4],
       [ 5.6,  3. ,  4.5,  1.5],
       [ 5.8,  2.7,  4.1,  1. ],
       [ 6.2,  2.2,  4.5,  1.5],
       [ 5.6,  2.5,  3.9,  1.1],
       [ 5.9,  3.2,  4.8,  1.8],
       [ 6.1,  2.8,  4. ,  1.3],
       [ 6.3,  2.5,  4.9,  1.5],
       [ 6.1,  2.8,  4.7,  1.2],
       [ 6.4,  2.9,  4.3,  1.3],
       [ 6.6,  3. ,  4.4,  1.4],
       [ 6.8,  2.8,  4.8,  1.4],
       [ 6.7,  3. ,  5. ,  1.7],
       [ 6. ,  2.9,  4.5,  1.5],
       [ 5.7,  2.6,  3.5,  1. ],
       [ 5.5,  2.4,  3.8,  1.1],
       [ 5.5,  2.4,  3.7,  1. ],
       [ 5.8,  2.7,  3.9,  1.2],
       [ 6. ,  2.7,  5.1,  1.6],
       [ 5.4,  3. ,  4.5,  1.5],
       [ 6. ,  3.4,  4.5,  1.6],
       [ 6.7,  3.1,  4.7,  1.5],
       [ 6.3,  2.3,  4.4,  1.3],
       [ 5.6,  3. ,  4.1,  1.3],
       [ 5.5,  2.5,  4. ,  1.3],
       [ 5.5,  2.6,  4.4,  1.2],
       [ 6.1,  3. ,  4.6,  1.4],
       [ 5.8,  2.6,  4. ,  1.2],
       [ 5. ,  2.3,  3.3,  1. ],
       [ 5.6,  2.7,  4.2,  1.3],
       [ 5.7,  3. ,  4.2,  1.2],
       [ 5.7,  2.9,  4.2,  1.3],
       [ 6.2,  2.9,  4.3,  1.3],
       [ 5.1,  2.5,  3. ,  1.1],
       [ 5.7,  2.8,  4.1,  1.3],
       [ 6.3,  3.3,  6. ,  2.5],
       [ 5.8,  2.7,  5.1,  1.9],
       [ 7.1,  3. ,  5.9,  2.1],
       [ 6.3,  2.9,  5.6,  1.8],
       [ 6.5,  3. ,  5.8,  2.2],
       [ 7.6,  3. ,  6.6,  2.1],
       [ 4.9,  2.5,  4.5,  1.7],
       [ 7.3,  2.9,  6.3,  1.8],
       [ 6.7,  2.5,  5.8,  1.8],
       [ 7.2,  3.6,  6.1,  2.5],
       [ 6.5,  3.2,  5.1,  2. ],
       [ 6.4,  2.7,  5.3,  1.9],
       [ 6.8,  3. ,  5.5,  2.1],
       [ 5.7,  2.5,  5. ,  2. ],
       [ 5.8,  2.8,  5.1,  2.4],
       [ 6.4,  3.2,  5.3,  2.3],
       [ 6.5,  3. ,  5.5,  1.8],
       [ 7.7,  3.8,  6.7,  2.2],
       [ 7.7,  2.6,  6.9,  2.3],
       [ 6. ,  2.2,  5. ,  1.5],
       [ 6.9,  3.2,  5.7,  2.3],
       [ 5.6,  2.8,  4.9,  2. ],
       [ 7.7,  2.8,  6.7,  2. ],
       [ 6.3,  2.7,  4.9,  1.8],
       [ 6.7,  3.3,  5.7,  2.1],
       [ 7.2,  3.2,  6. ,  1.8],
       [ 6.2,  2.8,  4.8,  1.8],
       [ 6.1,  3. ,  4.9,  1.8],
       [ 6.4,  2.8,  5.6,  2.1],
       [ 7.2,  3. ,  5.8,  1.6],
       [ 7.4,  2.8,  6.1,  1.9],
       [ 7.9,  3.8,  6.4,  2. ],
       [ 6.4,  2.8,  5.6,  2.2],
       [ 6.3,  2.8,  5.1,  1.5],
       [ 6.1,  2.6,  5.6,  1.4],
       [ 7.7,  3. ,  6.1,  2.3],
       [ 6.3,  3.4,  5.6,  2.4],
       [ 6.4,  3.1,  5.5,  1.8],
       [ 6. ,  3. ,  4.8,  1.8],
       [ 6.9,  3.1,  5.4,  2.1],
       [ 6.7,  3.1,  5.6,  2.4],
       [ 6.9,  3.1,  5.1,  2.3],
       [ 5.8,  2.7,  5.1,  1.9],
       [ 6.8,  3.2,  5.9,  2.3],
       [ 6.7,  3.3,  5.7,  2.5],
       [ 6.7,  3. ,  5.2,  2.3],
       [ 6.3,  2.5,  5. ,  1.9],
       [ 6.5,  3. ,  5.2,  2. ],
       [ 6.2,  3.4,  5.4,  2.3],
       [ 5.9,  3. ,  5.1,  1.8]])

Examine the features


In [4]:
iris.feature_names


Out[4]:
['sepal length (cm)',
 'sepal width (cm)',
 'petal length (cm)',
 'petal width (cm)']
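
As an optional aside (not part of the original run), these feature names pair naturally with iris.data; assuming pandas is available, the two can be combined into a DataFrame for easier inspection. The name irisDF is chosen here purely for illustration:

import pandas as pd

# combine the measurements and feature names into a DataFrame (illustrative only)
irisDF = pd.DataFrame(iris.data, columns=iris.feature_names)
irisDF.head()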

Examine the response (target)


In [5]:
print(type(iris.target))
print(iris.target.shape)
iris.target


<type 'numpy.ndarray'>
(150,)
Out[5]:
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

In [6]:
iris.target_names


Out[6]:
array(['setosa', 'versicolor', 'virginica'], 
      dtype='|S10')
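
The integer codes in iris.target index directly into this array, so the string label for every sample can be recovered with NumPy fancy indexing. A small illustrative sketch (labels is a name chosen here, not from the original notebook):

labels = iris.target_names[iris.target]
print(labels[:3])    # the first three samples are all 'setosa'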

In [7]:
import seaborn as sns
%matplotlib inline

In [8]:
iris2 = sns.load_dataset("iris")
g = sns.pairplot(iris2, hue="species")


[Figure: seaborn pairplot of the four iris features, coloured by species]

In [9]:
iris2.describe()


Out[9]:
       sepal_length  sepal_width  petal_length  petal_width
count    150.000000   150.000000    150.000000   150.000000
mean       5.843333     3.057333      3.758000     1.199333
std        0.828066     0.435866      1.765298     0.762238
min        4.300000     2.000000      1.000000     0.100000
25%        5.100000     2.800000      1.600000     0.300000
50%        5.800000     3.000000      4.350000     1.300000
75%        6.400000     3.300000      5.100000     1.800000
max        7.900000     4.400000      6.900000     2.500000
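
describe() summarises all 150 rows together; since iris2 already carries a species column (it comes from seaborn's bundled copy of the dataset), a per-class view is one groupby away. A minimal sketch, not part of the original run:

# mean of each measurement, broken down by species (illustrative only)
iris2.groupby('species').mean()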

Prepare the data for use with scikit-learn


In [10]:
X = iris.data
y = iris.target
print(X.shape)
print(y.shape)


(150, 4)
(150,)

Split the data into training and test sets


In [11]:
from sklearn.cross_validation import train_test_split  # in newer scikit-learn this lives in sklearn.model_selection
%time XTrain, XTest, yTrain, yTest = train_test_split(X,y,test_size=0.333,random_state=3141)
print(XTrain.shape)
print(yTrain.shape)
print(XTest.shape)
print(yTest.shape)


CPU times: user 965 µs, sys: 269 µs, total: 1.23 ms
Wall time: 888 µs
(100, 4)
(100,)
(50, 4)
(50,)
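
The split here is purely random (no stratify argument), so the 50/50/50 class balance is only approximately preserved in the training and test sets. A quick check, sketched with NumPy and not part of the original run:

import numpy as np

print(np.bincount(yTrain))   # samples per class in the training set
print(np.bincount(yTest))    # samples per class in the test set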

In [12]:
def fitAndPredict(model,XTrain,yTrain,XTest):
    """Fits a model and returns the fitted model along with its predictions on the test data.
       :param model: An sklearn model.
       :param XTrain: Predictor training data.
       :param yTrain: Response training data.
       :param XTest: Predictor test data.
       :return: Tuple of the fitted model and its predictions on the test data.
    """
    model.fit(XTrain,yTrain)
    yPred = model.predict(XTest)
    return model,yPred
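
The helper returns a (model, predictions) tuple, which can be unpacked directly. For example (knn5 and yPred5 are illustrative names, not part of the original notebook):

from sklearn.neighbors import KNeighborsClassifier

# fit a single 5-nearest-neighbour classifier and predict the test set
knn5, yPred5 = fitAndPredict(KNeighborsClassifier(n_neighbors=5), XTrain, yTrain, XTest)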

Modelling and prediction

KNN


In [13]:
%%timeit
from sklearn.neighbors import KNeighborsClassifier
# %%timeit runs the cell body in its own scope, so declare these names global
# to keep the fitted models and the list of K values available in later cells
global ks
global knns
ks = range(1,35)
knns = []
for k in ks:
    knns.append(fitAndPredict( KNeighborsClassifier(n_neighbors=k), XTrain,yTrain,XTest))


10 loops, best of 3: 56.7 ms per loop

Logistic Regression


In [14]:
from sklearn.linear_model import LogisticRegression
%time logisticRegression = fitAndPredict( LogisticRegression(), XTrain,yTrain,XTest)


CPU times: user 2.39 ms, sys: 705 µs, total: 3.09 ms
Wall time: 2.14 ms

Linear SVM


In [15]:
from sklearn import svm
%time linearSVC = fitAndPredict( svm.LinearSVC(), XTrain,yTrain,XTest)


CPU times: user 20 ms, sys: 4.28 ms, total: 24.3 ms
Wall time: 18.7 ms

Naive Bayes


In [16]:
from sklearn.naive_bayes import GaussianNB
%time gaussianNB = fitAndPredict( GaussianNB(), XTrain,yTrain,XTest)


CPU times: user 1.49 ms, sys: 96 µs, total: 1.58 ms
Wall time: 1.59 ms

Evaluate accuracy of models

Classification accuracy

Test set accuracy


In [17]:
from sklearn import metrics
models = {}

In [18]:
%%timeit
i = 0
for k in knns:
    i = i + 1
    # zero-pad the key (knn01 ... knn34) so the entries sort in order of K
    models["knn{:02d}".format(i)] = metrics.accuracy_score(yTest, k[1])


100 loops, best of 3: 6.35 ms per loop

In [19]:
%time models["logReg"] = metrics.accuracy_score(yTest, logisticRegression[1])
%time models["linearSVM"] = metrics.accuracy_score(yTest, linearSVC[1])
%time models["naiveBayes"] = metrics.accuracy_score(yTest, gaussianNB[1])
models


CPU times: user 921 µs, sys: 267 µs, total: 1.19 ms
Wall time: 654 µs
CPU times: user 728 µs, sys: 172 µs, total: 900 µs
Wall time: 549 µs
CPU times: user 775 µs, sys: 180 µs, total: 955 µs
Wall time: 569 µs
Out[19]:
{'knn01': 0.92000000000000004,
 'knn02': 0.92000000000000004,
 'knn03': 0.93999999999999995,
 'knn04': 0.93999999999999995,
 'knn05': 0.93999999999999995,
 'knn06': 0.97999999999999998,
 'knn07': 0.97999999999999998,
 'knn08': 0.97999999999999998,
 'knn09': 0.97999999999999998,
 'knn10': 0.97999999999999998,
 'knn11': 0.97999999999999998,
 'knn12': 0.97999999999999998,
 'knn13': 0.97999999999999998,
 'knn14': 0.97999999999999998,
 'knn15': 0.97999999999999998,
 'knn16': 0.97999999999999998,
 'knn17': 0.95999999999999996,
 'knn18': 0.95999999999999996,
 'knn19': 0.95999999999999996,
 'knn20': 0.93999999999999995,
 'knn21': 0.93999999999999995,
 'knn22': 0.93999999999999995,
 'knn23': 0.95999999999999996,
 'knn24': 0.93999999999999995,
 'knn25': 0.93999999999999995,
 'knn26': 0.93999999999999995,
 'knn27': 0.93999999999999995,
 'knn28': 0.90000000000000002,
 'knn29': 0.90000000000000002,
 'knn30': 0.90000000000000002,
 'knn31': 0.90000000000000002,
 'knn32': 0.90000000000000002,
 'knn33': 0.90000000000000002,
 'knn34': 0.90000000000000002,
 'linearSVM': 0.93999999999999995,
 'logReg': 0.93999999999999995,
 'naiveBayes': 0.93999999999999995}

Visualise the relationship between the KNN parameter K and prediction accuracy


In [20]:
import matplotlib.pyplot as plt
%matplotlib inline
# gather the KNN accuracy scores; the zero-padded keys (knn01 ... knn34) sort in order of K
knnScores = [(key, value) for key, value in models.iteritems() if key.startswith("knn")]
knnlist = sorted(knnScores)
knnlist = [li[1] for li in knnlist]
plt.plot(ks, knnlist)
plt.xlabel('K')
plt.ylabel('Test Data Accuracy')


Out[20]:
<matplotlib.text.Text at 0x7f5595092a50>
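
The accuracy curve peaks over a broad range of K (roughly 6 through 16 in this run). The first best-scoring K can also be read off programmatically from the lists built above; a minimal sketch:

# index of the highest test accuracy, mapped back to its K value
bestIndex = knnlist.index(max(knnlist))
print("best K = %d with test accuracy %.2f" % (ks[bestIndex], knnlist[bestIndex]))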