In [1]:
%matplotlib inline
import matplotlib.pyplot as plt

Load the iris dataset


In [2]:
from sklearn.datasets import load_iris
iris = load_iris()

In [3]:
from IPython.core.display import Image, display
display(Image(filename='images/iris_setosa_2.jpg'))
print("Iris Setosa\n")

display(Image(filename='images/iris_versicolor.jpg'))
print("Iris Versicolor\n")

display(Image(filename='images/iris_virginica_2.jpg'))
print("Iris Virginica")

display(Image(filename='images/iris_petal_sepal_2.png'))
print("Petals ands Sepals")


Iris Setosa

Iris Versicolor

Iris Virginica
Petals ands Sepals

Explore the dataset


In [4]:
iris.keys()


Out[4]:
['target_names', 'data', 'target', 'DESCR', 'feature_names']

In [5]:
iris.feature_names


Out[5]:
['sepal length (cm)',
 'sepal width (cm)',
 'petal length (cm)',
 'petal width (cm)']

In [6]:
iris.target_names


Out[6]:
array(['setosa', 'versicolor', 'virginica'], 
      dtype='|S10')
Sepal L Sepal W Petal L Petal W Target Target Name
5.1 3.5 1.4 0.2 0 setosa
4.7 3.2 1.3 0.2 0 setosa
6.3 2.3 4.4 1.3 1 versicolour
6.2 3.4 5.4 2.3 2 virginica
... ...
  • each row: sample, example, observation, record, instance
  • each column: feature, predictor, attribute, independent variable, input, regressor, covariant
  • target: respose, outcome, label, dependent variable
  • features in the iris dataset:
    1. sepal length in cm
    2. sepal width in cm
    3. petal length in cm
    4. petal width in cm
  • target classes to predict:
    1. iris setosa
    2. iris versicolour
    3. iris virginica

In [7]:
iris.data.shape


Out[7]:
(150L, 4L)

In [8]:
iris.target.shape


Out[8]:
(150L,)

In [9]:
iris.data[0]


Out[9]:
array([ 5.1,  3.5,  1.4,  0.2])

In [10]:
iris.target[0]


Out[10]:
0

In [11]:
iris.data[101]


Out[11]:
array([ 5.8,  2.7,  5.1,  1.9])

In [12]:
iris.target[101]


Out[12]:
2

In [13]:
iris.data


Out[13]:
array([[ 5.1,  3.5,  1.4,  0.2],
       [ 4.9,  3. ,  1.4,  0.2],
       [ 4.7,  3.2,  1.3,  0.2],
       [ 4.6,  3.1,  1.5,  0.2],
       [ 5. ,  3.6,  1.4,  0.2],
       [ 5.4,  3.9,  1.7,  0.4],
       [ 4.6,  3.4,  1.4,  0.3],
       [ 5. ,  3.4,  1.5,  0.2],
       [ 4.4,  2.9,  1.4,  0.2],
       [ 4.9,  3.1,  1.5,  0.1],
       [ 5.4,  3.7,  1.5,  0.2],
       [ 4.8,  3.4,  1.6,  0.2],
       [ 4.8,  3. ,  1.4,  0.1],
       [ 4.3,  3. ,  1.1,  0.1],
       [ 5.8,  4. ,  1.2,  0.2],
       [ 5.7,  4.4,  1.5,  0.4],
       [ 5.4,  3.9,  1.3,  0.4],
       [ 5.1,  3.5,  1.4,  0.3],
       [ 5.7,  3.8,  1.7,  0.3],
       [ 5.1,  3.8,  1.5,  0.3],
       [ 5.4,  3.4,  1.7,  0.2],
       [ 5.1,  3.7,  1.5,  0.4],
       [ 4.6,  3.6,  1. ,  0.2],
       [ 5.1,  3.3,  1.7,  0.5],
       [ 4.8,  3.4,  1.9,  0.2],
       [ 5. ,  3. ,  1.6,  0.2],
       [ 5. ,  3.4,  1.6,  0.4],
       [ 5.2,  3.5,  1.5,  0.2],
       [ 5.2,  3.4,  1.4,  0.2],
       [ 4.7,  3.2,  1.6,  0.2],
       [ 4.8,  3.1,  1.6,  0.2],
       [ 5.4,  3.4,  1.5,  0.4],
       [ 5.2,  4.1,  1.5,  0.1],
       [ 5.5,  4.2,  1.4,  0.2],
       [ 4.9,  3.1,  1.5,  0.1],
       [ 5. ,  3.2,  1.2,  0.2],
       [ 5.5,  3.5,  1.3,  0.2],
       [ 4.9,  3.1,  1.5,  0.1],
       [ 4.4,  3. ,  1.3,  0.2],
       [ 5.1,  3.4,  1.5,  0.2],
       [ 5. ,  3.5,  1.3,  0.3],
       [ 4.5,  2.3,  1.3,  0.3],
       [ 4.4,  3.2,  1.3,  0.2],
       [ 5. ,  3.5,  1.6,  0.6],
       [ 5.1,  3.8,  1.9,  0.4],
       [ 4.8,  3. ,  1.4,  0.3],
       [ 5.1,  3.8,  1.6,  0.2],
       [ 4.6,  3.2,  1.4,  0.2],
       [ 5.3,  3.7,  1.5,  0.2],
       [ 5. ,  3.3,  1.4,  0.2],
       [ 7. ,  3.2,  4.7,  1.4],
       [ 6.4,  3.2,  4.5,  1.5],
       [ 6.9,  3.1,  4.9,  1.5],
       [ 5.5,  2.3,  4. ,  1.3],
       [ 6.5,  2.8,  4.6,  1.5],
       [ 5.7,  2.8,  4.5,  1.3],
       [ 6.3,  3.3,  4.7,  1.6],
       [ 4.9,  2.4,  3.3,  1. ],
       [ 6.6,  2.9,  4.6,  1.3],
       [ 5.2,  2.7,  3.9,  1.4],
       [ 5. ,  2. ,  3.5,  1. ],
       [ 5.9,  3. ,  4.2,  1.5],
       [ 6. ,  2.2,  4. ,  1. ],
       [ 6.1,  2.9,  4.7,  1.4],
       [ 5.6,  2.9,  3.6,  1.3],
       [ 6.7,  3.1,  4.4,  1.4],
       [ 5.6,  3. ,  4.5,  1.5],
       [ 5.8,  2.7,  4.1,  1. ],
       [ 6.2,  2.2,  4.5,  1.5],
       [ 5.6,  2.5,  3.9,  1.1],
       [ 5.9,  3.2,  4.8,  1.8],
       [ 6.1,  2.8,  4. ,  1.3],
       [ 6.3,  2.5,  4.9,  1.5],
       [ 6.1,  2.8,  4.7,  1.2],
       [ 6.4,  2.9,  4.3,  1.3],
       [ 6.6,  3. ,  4.4,  1.4],
       [ 6.8,  2.8,  4.8,  1.4],
       [ 6.7,  3. ,  5. ,  1.7],
       [ 6. ,  2.9,  4.5,  1.5],
       [ 5.7,  2.6,  3.5,  1. ],
       [ 5.5,  2.4,  3.8,  1.1],
       [ 5.5,  2.4,  3.7,  1. ],
       [ 5.8,  2.7,  3.9,  1.2],
       [ 6. ,  2.7,  5.1,  1.6],
       [ 5.4,  3. ,  4.5,  1.5],
       [ 6. ,  3.4,  4.5,  1.6],
       [ 6.7,  3.1,  4.7,  1.5],
       [ 6.3,  2.3,  4.4,  1.3],
       [ 5.6,  3. ,  4.1,  1.3],
       [ 5.5,  2.5,  4. ,  1.3],
       [ 5.5,  2.6,  4.4,  1.2],
       [ 6.1,  3. ,  4.6,  1.4],
       [ 5.8,  2.6,  4. ,  1.2],
       [ 5. ,  2.3,  3.3,  1. ],
       [ 5.6,  2.7,  4.2,  1.3],
       [ 5.7,  3. ,  4.2,  1.2],
       [ 5.7,  2.9,  4.2,  1.3],
       [ 6.2,  2.9,  4.3,  1.3],
       [ 5.1,  2.5,  3. ,  1.1],
       [ 5.7,  2.8,  4.1,  1.3],
       [ 6.3,  3.3,  6. ,  2.5],
       [ 5.8,  2.7,  5.1,  1.9],
       [ 7.1,  3. ,  5.9,  2.1],
       [ 6.3,  2.9,  5.6,  1.8],
       [ 6.5,  3. ,  5.8,  2.2],
       [ 7.6,  3. ,  6.6,  2.1],
       [ 4.9,  2.5,  4.5,  1.7],
       [ 7.3,  2.9,  6.3,  1.8],
       [ 6.7,  2.5,  5.8,  1.8],
       [ 7.2,  3.6,  6.1,  2.5],
       [ 6.5,  3.2,  5.1,  2. ],
       [ 6.4,  2.7,  5.3,  1.9],
       [ 6.8,  3. ,  5.5,  2.1],
       [ 5.7,  2.5,  5. ,  2. ],
       [ 5.8,  2.8,  5.1,  2.4],
       [ 6.4,  3.2,  5.3,  2.3],
       [ 6.5,  3. ,  5.5,  1.8],
       [ 7.7,  3.8,  6.7,  2.2],
       [ 7.7,  2.6,  6.9,  2.3],
       [ 6. ,  2.2,  5. ,  1.5],
       [ 6.9,  3.2,  5.7,  2.3],
       [ 5.6,  2.8,  4.9,  2. ],
       [ 7.7,  2.8,  6.7,  2. ],
       [ 6.3,  2.7,  4.9,  1.8],
       [ 6.7,  3.3,  5.7,  2.1],
       [ 7.2,  3.2,  6. ,  1.8],
       [ 6.2,  2.8,  4.8,  1.8],
       [ 6.1,  3. ,  4.9,  1.8],
       [ 6.4,  2.8,  5.6,  2.1],
       [ 7.2,  3. ,  5.8,  1.6],
       [ 7.4,  2.8,  6.1,  1.9],
       [ 7.9,  3.8,  6.4,  2. ],
       [ 6.4,  2.8,  5.6,  2.2],
       [ 6.3,  2.8,  5.1,  1.5],
       [ 6.1,  2.6,  5.6,  1.4],
       [ 7.7,  3. ,  6.1,  2.3],
       [ 6.3,  3.4,  5.6,  2.4],
       [ 6.4,  3.1,  5.5,  1.8],
       [ 6. ,  3. ,  4.8,  1.8],
       [ 6.9,  3.1,  5.4,  2.1],
       [ 6.7,  3.1,  5.6,  2.4],
       [ 6.9,  3.1,  5.1,  2.3],
       [ 5.8,  2.7,  5.1,  1.9],
       [ 6.8,  3.2,  5.9,  2.3],
       [ 6.7,  3.3,  5.7,  2.5],
       [ 6.7,  3. ,  5.2,  2.3],
       [ 6.3,  2.5,  5. ,  1.9],
       [ 6.5,  3. ,  5.2,  2. ],
       [ 6.2,  3.4,  5.4,  2.3],
       [ 5.9,  3. ,  5.1,  1.8]])

In [14]:
iris.target


Out[14]:
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

Plot

Are the examples linearly separable?


In [15]:
# pick any 2 attributes and plot them in a 2d plot

fig, axes = plt.subplots(2,3)
pairs = [(0,1), (0,2), (0,3), (1,2), (1,3), (2,3)]
samples = iris.data
target = iris.target
for i, (p0, p1) in enumerate(pairs):
    
    ax = axes.flat[i]
    
    for t, marker, c in zip(xrange(3), ">ox", "rgb"):
        # scatter plot 
        a = ax.scatter(samples[target == t, p0], samples[target == t, p1], marker=marker, c=c, label=iris.target_names[t])
        
    #ax.legend(shadow=True, fancybox=True)
    plt.legend(fancybox=True, loc='upper left', framealpha=0.5) 
    
           
    ax.set_xlabel(iris.feature_names[p0])
    ax.set_ylabel(iris.feature_names[p1])
    ax.set_xticks([])
    ax.set_yticks([])
fig.set_size_inches(15, 8)
fig.tight_layout()


C:\Anaconda2\lib\site-packages\matplotlib\axes\_axes.py:519: UserWarning: No labelled objects found. Use label='...' kwarg on individual plots.
  warnings.warn("No labelled objects found. "

Use our ANN


In [16]:
import numpy as np
from sklearn.cross_validation import train_test_split 
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.preprocessing import LabelBinarizer, MinMaxScaler

Load the data


In [17]:
# load the data
iris = load_iris()
X, y = iris.data, iris.target

Transform the data


In [18]:
# scale the matrix with all the 4 features
xsc = MinMaxScaler(feature_range=(0, 1), copy=True)
xsc.fit(X)
# xsc will remember the min and max for each feature


Out[18]:
MinMaxScaler(copy=True, feature_range=(0, 1))

In [19]:
# binarize the targets
# if we have 3 classes:
# 0 setosa ->      [1 0 0]
# 1 versicolour->  [0 1 0]
# 2 virginica->    [0 0 1]

ylb = LabelBinarizer()
ylb.fit(y)
# ylb will remember that [0.2 0.7 0.1] means class 1
# and [0.3 0.1 0.6] means class 2


Out[19]:
LabelBinarizer(neg_label=0, pos_label=1, sparse_output=False)

Training vs Testing split


In [20]:
X_train, X_test, y_train, y_test = train_test_split(xsc.transform(X), y)

# our network uses lists (pure Python, no frameworks) so passing numpy arrays will fail
X_train_l = X_train.tolist()
y_train_l = ylb.transform(y_train).tolist()
X_test_l = X_test.tolist()

Create the ANN


In [21]:
from ann import ANN

# 4 input neurons: sepal length (cm),sepal width (cm), petal length (cm), petal width (cm)
# 10 hidden neurons
# 3 output neurons: setosa [1, 0, 0], versicolor [0, 1, 0], virginica [0, 0, 1]
nn = ANN([4, 10, 3])

# train the ANN
nn.train(X_train_l, y_train_l, 5000)


(0, 31.021428655532773)
(100, 3.3408752289359342)
(200, 2.3240018942521643)
(300, 2.0574224730976693)
(400, 1.9121984952587834)
(500, 1.811618532044788)
(600, 1.7366873158713192)
(700, 1.6782173988013018)
(800, 1.6300163090276092)
(900, 1.5888376278355734)
(1000, 1.5531877836220178)
(1100, 1.521979780353659)
(1200, 1.4942727046828588)
(1300, 1.4693315238479252)
(1400, 1.4466208611138318)
(1500, 1.4257509120710752)
(1600, 1.4064298104168833)
(1700, 1.3884328888348274)
(1800, 1.3715836739319913)
(1900, 1.3557415736091112)
(2000, 1.3407933776831318)
(2100, 1.3266470603424791)
(2200, 1.3132270840328646)
(2300, 1.3004707665392783)
(2400, 1.288325453348148)
(2500, 1.2767463239293908)
(2600, 1.2656947025709107)
(2700, 1.2551367680141159)
(2800, 1.2450425734633193)
(2900, 1.2353853039254794)
(3000, 1.2261407121196717)
(3100, 1.217286687059104)
(3200, 1.2088029204347623)
(3300, 1.2006706449392146)
(3400, 1.1928724257299874)
(3500, 1.1853919915731552)
(3600, 1.178214096128891)
(3700, 1.1713244026455716)
(3800, 1.1647093872888126)
(3900, 1.15835625767256)
(4000, 1.1522528840578685)
(4100, 1.1463877412751795)
(4200, 1.1407498598069175)
(4300, 1.1353287847102598)
(4400, 1.1301145412170812)
(4500, 1.1250976059568856)
(4600, 1.1202688828334129)
(4700, 1.1156196826670373)
(4800, 1.1111417057999418)
(4900, 1.10682702695678)

Predict using the ANN


In [22]:
example_id = 0
pred = nn.predict(X_test_l[example_id])
#pred_class = ylb.inverse_transform(np.array([pred]))[0]
pred_class = np.argmax(np.array(pred))

# our prediction
print('prediction', pred)
# our predicted class
print('predicted class', pred_class)
# the predicted class name
print('predicted class name', iris.target_names[pred_class])
# the correct class name
print('correct class name', iris.target_names[y_test[example_id]])


('prediction', [0.9940208005374855, -0.02538109634867978, 0.005452991564777164])
('predicted class', 0)
('predicted class name', 'setosa')
('correct class name', 'setosa')

In [23]:
preds = np.array([nn.predict(record) for record in X_test_l])
preds


Out[23]:
array([[  9.94020801e-01,  -2.53810963e-02,   5.45299156e-03],
       [  9.94562277e-01,   1.54890612e-02,   2.94772663e-03],
       [ -2.78499450e-03,   2.61752085e-02,   9.99217633e-01],
       [  3.21910676e-03,   2.50451142e-02,   9.98659358e-01],
       [  9.94381310e-01,  -9.08355825e-03,   4.99165009e-03],
       [  9.94064879e-01,  -1.52339496e-02,   6.34811156e-03],
       [ -8.06706541e-05,   9.99999752e-01,  -1.74158703e-02],
       [ -3.29168419e-03,  -6.21464219e-03,   9.91761640e-01],
       [  9.94402676e-01,   2.48699283e-03,   3.35412872e-03],
       [  9.01331378e-04,   4.25853254e-03,   9.98198170e-01],
       [  9.94681805e-01,   2.67417765e-03,   3.96626306e-03],
       [ -1.87669424e-03,   9.21878729e-03,   9.99220085e-01],
       [ -3.90071870e-03,   9.99999756e-01,  -9.33634304e-03],
       [ -3.47345222e-03,   6.61451068e-01,   6.34488213e-01],
       [  4.20843847e-03,   9.99999733e-01,  -9.47402699e-04],
       [ -2.21711306e-02,   9.45797202e-03,   9.96015763e-01],
       [  9.94929716e-01,  -3.57184603e-03,   4.84478165e-03],
       [  1.29932603e-02,   9.99996244e-01,  -2.75545546e-02],
       [  5.12528150e-03,   1.76996356e-02,   9.99081929e-01],
       [  3.41197200e-03,   6.61987304e-03,   9.98789488e-01],
       [  1.57586833e-02,  -1.75096140e-03,   9.99240967e-01],
       [  9.94283277e-01,  -8.29916583e-03,   4.50484437e-03],
       [ -4.45399721e-03,   9.99999319e-01,   5.08690649e-03],
       [  6.85619294e-03,   9.71329969e-01,   4.76572222e-02],
       [  5.15050699e-03,   9.99999701e-01,   3.36802679e-03],
       [  2.20254821e-03,   9.99999509e-01,   7.34395678e-03],
       [  9.94465967e-01,  -7.34802183e-03,   4.88426702e-03],
       [ -2.56778257e-02,   9.99999720e-01,  -9.12676867e-04],
       [  1.11509144e-02,   9.76052058e-01,  -1.64789094e-01],
       [ -7.46472082e-04,  -4.49598320e-02,   9.34411568e-01],
       [  9.94828208e-01,  -2.27354531e-03,   3.95310395e-03],
       [  1.68179672e-02,   9.99997938e-01,  -2.76380424e-02],
       [  9.94233336e-01,  -1.21096943e-02,   4.45702340e-03],
       [  9.94355188e-01,  -8.97977148e-03,   4.27105200e-03],
       [  9.94473754e-01,  -1.70011217e-02,   5.40513632e-03],
       [  1.88331026e-03,   9.99999433e-01,   1.43139092e-02],
       [ -5.36682257e-04,  -5.55623792e-03,   9.98796522e-01],
       [  1.63217335e-02,   2.59845158e-02,   9.99253170e-01]])

Show the confusion matrix


In [24]:
ypred = ylb.inverse_transform(preds)
print(confusion_matrix(y_test, ypred))


[[13  0  0]
 [ 0 12  0]
 [ 0  1 12]]

Show the classification report


In [25]:
print(classification_report(y_test, ypred))


             precision    recall  f1-score   support

          0       1.00      1.00      1.00        13
          1       0.92      1.00      0.96        12
          2       1.00      0.92      0.96        13

avg / total       0.98      0.97      0.97        38


In [ ]: