Keras Basics

Welcome to the section on deep learning! We'll be using Keras with a TensorFlow backend to perform our deep learning operations.

This means we should first get familiar with some Keras basics!

Imports


In [145]:
import numpy as np

Dataset

We will use the famous Iris dataset.


More info on the data set: https://en.wikipedia.org/wiki/Iris_flower_data_set

Reading in the Data Set

The Iris dataset ships with scikit-learn, so there is nothing to download: we can load it directly with load_iris().


In [146]:
from sklearn.datasets import load_iris

In [147]:
iris = load_iris()

In [148]:
type(iris)


Out[148]:
sklearn.utils.Bunch

In [149]:
print(iris.DESCR)


.. _iris_dataset:

Iris plants dataset
--------------------

**Data Set Characteristics:**

    :Number of Instances: 150 (50 in each of three classes)
    :Number of Attributes: 4 numeric, predictive attributes and the class
    :Attribute Information:
        - sepal length in cm
        - sepal width in cm
        - petal length in cm
        - petal width in cm
        - class:
                - Iris-Setosa
                - Iris-Versicolour
                - Iris-Virginica
                
    :Summary Statistics:

    ============== ==== ==== ======= ===== ====================
                    Min  Max   Mean    SD   Class Correlation
    ============== ==== ==== ======= ===== ====================
    sepal length:   4.3  7.9   5.84   0.83    0.7826
    sepal width:    2.0  4.4   3.05   0.43   -0.4194
    petal length:   1.0  6.9   3.76   1.76    0.9490  (high!)
    petal width:    0.1  2.5   1.20   0.76    0.9565  (high!)
    ============== ==== ==== ======= ===== ====================

    :Missing Attribute Values: None
    :Class Distribution: 33.3% for each of 3 classes.
    :Creator: R.A. Fisher
    :Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)
    :Date: July, 1988

The famous Iris database, first used by Sir R.A. Fisher. The dataset is taken
from Fisher's paper. Note that it's the same as in R, but not as in the UCI
Machine Learning Repository, which has two wrong data points.

This is perhaps the best known database to be found in the
pattern recognition literature.  Fisher's paper is a classic in the field and
is referenced frequently to this day.  (See Duda & Hart, for example.)  The
data set contains 3 classes of 50 instances each, where each class refers to a
type of iris plant.  One class is linearly separable from the other 2; the
latter are NOT linearly separable from each other.

.. topic:: References

   - Fisher, R.A. "The use of multiple measurements in taxonomic problems"
     Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions to
     Mathematical Statistics" (John Wiley, NY, 1950).
   - Duda, R.O., & Hart, P.E. (1973) Pattern Classification and Scene Analysis.
     (Q327.D83) John Wiley & Sons.  ISBN 0-471-22361-1.  See page 218.
   - Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New System
     Structure and Classification Rule for Recognition in Partially Exposed
     Environments".  IEEE Transactions on Pattern Analysis and Machine
     Intelligence, Vol. PAMI-2, No. 1, 67-71.
   - Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule".  IEEE Transactions
     on Information Theory, May 1972, 431-433.
   - See also: 1988 MLC Proceedings, 54-64.  Cheeseman et al"s AUTOCLASS II
     conceptual clustering system finds 3 classes in the data.
   - Many, many more ...

In [150]:
X = iris.data

In [151]:
X


Out[151]:
array([[5.1, 3.5, 1.4, 0.2],
       [4.9, 3. , 1.4, 0.2],
       [4.7, 3.2, 1.3, 0.2],
       [4.6, 3.1, 1.5, 0.2],
       [5. , 3.6, 1.4, 0.2],
       [5.4, 3.9, 1.7, 0.4],
       [4.6, 3.4, 1.4, 0.3],
       [5. , 3.4, 1.5, 0.2],
       [4.4, 2.9, 1.4, 0.2],
       [4.9, 3.1, 1.5, 0.1],
       [5.4, 3.7, 1.5, 0.2],
       [4.8, 3.4, 1.6, 0.2],
       [4.8, 3. , 1.4, 0.1],
       [4.3, 3. , 1.1, 0.1],
       [5.8, 4. , 1.2, 0.2],
       [5.7, 4.4, 1.5, 0.4],
       [5.4, 3.9, 1.3, 0.4],
       [5.1, 3.5, 1.4, 0.3],
       [5.7, 3.8, 1.7, 0.3],
       [5.1, 3.8, 1.5, 0.3],
       [5.4, 3.4, 1.7, 0.2],
       [5.1, 3.7, 1.5, 0.4],
       [4.6, 3.6, 1. , 0.2],
       [5.1, 3.3, 1.7, 0.5],
       [4.8, 3.4, 1.9, 0.2],
       [5. , 3. , 1.6, 0.2],
       [5. , 3.4, 1.6, 0.4],
       [5.2, 3.5, 1.5, 0.2],
       [5.2, 3.4, 1.4, 0.2],
       [4.7, 3.2, 1.6, 0.2],
       [4.8, 3.1, 1.6, 0.2],
       [5.4, 3.4, 1.5, 0.4],
       [5.2, 4.1, 1.5, 0.1],
       [5.5, 4.2, 1.4, 0.2],
       [4.9, 3.1, 1.5, 0.2],
       [5. , 3.2, 1.2, 0.2],
       [5.5, 3.5, 1.3, 0.2],
       [4.9, 3.6, 1.4, 0.1],
       [4.4, 3. , 1.3, 0.2],
       [5.1, 3.4, 1.5, 0.2],
       [5. , 3.5, 1.3, 0.3],
       [4.5, 2.3, 1.3, 0.3],
       [4.4, 3.2, 1.3, 0.2],
       [5. , 3.5, 1.6, 0.6],
       [5.1, 3.8, 1.9, 0.4],
       [4.8, 3. , 1.4, 0.3],
       [5.1, 3.8, 1.6, 0.2],
       [4.6, 3.2, 1.4, 0.2],
       [5.3, 3.7, 1.5, 0.2],
       [5. , 3.3, 1.4, 0.2],
       [7. , 3.2, 4.7, 1.4],
       [6.4, 3.2, 4.5, 1.5],
       [6.9, 3.1, 4.9, 1.5],
       [5.5, 2.3, 4. , 1.3],
       [6.5, 2.8, 4.6, 1.5],
       [5.7, 2.8, 4.5, 1.3],
       [6.3, 3.3, 4.7, 1.6],
       [4.9, 2.4, 3.3, 1. ],
       [6.6, 2.9, 4.6, 1.3],
       [5.2, 2.7, 3.9, 1.4],
       [5. , 2. , 3.5, 1. ],
       [5.9, 3. , 4.2, 1.5],
       [6. , 2.2, 4. , 1. ],
       [6.1, 2.9, 4.7, 1.4],
       [5.6, 2.9, 3.6, 1.3],
       [6.7, 3.1, 4.4, 1.4],
       [5.6, 3. , 4.5, 1.5],
       [5.8, 2.7, 4.1, 1. ],
       [6.2, 2.2, 4.5, 1.5],
       [5.6, 2.5, 3.9, 1.1],
       [5.9, 3.2, 4.8, 1.8],
       [6.1, 2.8, 4. , 1.3],
       [6.3, 2.5, 4.9, 1.5],
       [6.1, 2.8, 4.7, 1.2],
       [6.4, 2.9, 4.3, 1.3],
       [6.6, 3. , 4.4, 1.4],
       [6.8, 2.8, 4.8, 1.4],
       [6.7, 3. , 5. , 1.7],
       [6. , 2.9, 4.5, 1.5],
       [5.7, 2.6, 3.5, 1. ],
       [5.5, 2.4, 3.8, 1.1],
       [5.5, 2.4, 3.7, 1. ],
       [5.8, 2.7, 3.9, 1.2],
       [6. , 2.7, 5.1, 1.6],
       [5.4, 3. , 4.5, 1.5],
       [6. , 3.4, 4.5, 1.6],
       [6.7, 3.1, 4.7, 1.5],
       [6.3, 2.3, 4.4, 1.3],
       [5.6, 3. , 4.1, 1.3],
       [5.5, 2.5, 4. , 1.3],
       [5.5, 2.6, 4.4, 1.2],
       [6.1, 3. , 4.6, 1.4],
       [5.8, 2.6, 4. , 1.2],
       [5. , 2.3, 3.3, 1. ],
       [5.6, 2.7, 4.2, 1.3],
       [5.7, 3. , 4.2, 1.2],
       [5.7, 2.9, 4.2, 1.3],
       [6.2, 2.9, 4.3, 1.3],
       [5.1, 2.5, 3. , 1.1],
       [5.7, 2.8, 4.1, 1.3],
       [6.3, 3.3, 6. , 2.5],
       [5.8, 2.7, 5.1, 1.9],
       [7.1, 3. , 5.9, 2.1],
       [6.3, 2.9, 5.6, 1.8],
       [6.5, 3. , 5.8, 2.2],
       [7.6, 3. , 6.6, 2.1],
       [4.9, 2.5, 4.5, 1.7],
       [7.3, 2.9, 6.3, 1.8],
       [6.7, 2.5, 5.8, 1.8],
       [7.2, 3.6, 6.1, 2.5],
       [6.5, 3.2, 5.1, 2. ],
       [6.4, 2.7, 5.3, 1.9],
       [6.8, 3. , 5.5, 2.1],
       [5.7, 2.5, 5. , 2. ],
       [5.8, 2.8, 5.1, 2.4],
       [6.4, 3.2, 5.3, 2.3],
       [6.5, 3. , 5.5, 1.8],
       [7.7, 3.8, 6.7, 2.2],
       [7.7, 2.6, 6.9, 2.3],
       [6. , 2.2, 5. , 1.5],
       [6.9, 3.2, 5.7, 2.3],
       [5.6, 2.8, 4.9, 2. ],
       [7.7, 2.8, 6.7, 2. ],
       [6.3, 2.7, 4.9, 1.8],
       [6.7, 3.3, 5.7, 2.1],
       [7.2, 3.2, 6. , 1.8],
       [6.2, 2.8, 4.8, 1.8],
       [6.1, 3. , 4.9, 1.8],
       [6.4, 2.8, 5.6, 2.1],
       [7.2, 3. , 5.8, 1.6],
       [7.4, 2.8, 6.1, 1.9],
       [7.9, 3.8, 6.4, 2. ],
       [6.4, 2.8, 5.6, 2.2],
       [6.3, 2.8, 5.1, 1.5],
       [6.1, 2.6, 5.6, 1.4],
       [7.7, 3. , 6.1, 2.3],
       [6.3, 3.4, 5.6, 2.4],
       [6.4, 3.1, 5.5, 1.8],
       [6. , 3. , 4.8, 1.8],
       [6.9, 3.1, 5.4, 2.1],
       [6.7, 3.1, 5.6, 2.4],
       [6.9, 3.1, 5.1, 2.3],
       [5.8, 2.7, 5.1, 1.9],
       [6.8, 3.2, 5.9, 2.3],
       [6.7, 3.3, 5.7, 2.5],
       [6.7, 3. , 5.2, 2.3],
       [6.3, 2.5, 5. , 1.9],
       [6.5, 3. , 5.2, 2. ],
       [6.2, 3.4, 5.4, 2.3],
       [5.9, 3. , 5.1, 1.8]])

In [152]:
y = iris.target

In [153]:
y


Out[153]:
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

In [154]:
from keras.utils import to_categorical

In [156]:
y = to_categorical(y)

In [157]:
y.shape


Out[157]:
(150, 3)

In [158]:
y


Out[158]:
array([[1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.]], dtype=float32)
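to_categorical turns each integer label into a row with a single 1 in the column for that class, which is the target format categorical_crossentropy expects. As a rough illustration of the idea (not Keras's actual implementation), the same encoding can be built by indexing into an identity matrix. The helper name one_hot is hypothetical.

```python
import numpy as np

def one_hot(labels, num_classes=None):
    """Hypothetical helper: float32 one-hot matrix for integer labels 0..num_classes-1."""
    labels = np.asarray(labels, dtype=int)
    if num_classes is None:
        num_classes = labels.max() + 1
    # Row i of the identity matrix is the one-hot vector for class i,
    # so fancy-indexing with the labels picks one row per sample.
    return np.eye(num_classes, dtype=np.float32)[labels]

labels = np.array([0, 0, 1, 2, 2])
encoded = one_hot(labels)
print(encoded.shape)           # (5, 3)
# argmax along each row recovers the original integer labels
print(encoded.argmax(axis=1))  # [0 0 1 2 2]
```

The argmax trick is also how a row of predicted class probabilities from the trained network can be turned back into a single class label.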

Split the Data into Training and Test

It's time to split the data into a training set and a test set. Keep in mind that some people like to split three ways: train, validation, and test. We'll keep things simple for now. Remember to check out the video explanation of why we split and what all the parameters mean!
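Conceptually, a shuffled train/test split just permutes the row indices and carves off a fraction of them for testing, applying the same indices to X and y so features and labels stay aligned. The sketch below assumes a hypothetical helper, shuffled_split, built on NumPy's random generator. It rounds the test count up, which matches the 50 test rows train_test_split produces here for test_size=0.33 on 150 samples, but it will not reproduce scikit-learn's exact shuffle.

```python
import numpy as np

def shuffled_split(X, y, test_size=0.33, seed=42):
    """Hypothetical helper: shuffle row indices, then carve off a test set."""
    n_samples = len(X)
    # Round the test count up: ceil(0.33 * 150) = 50 test rows.
    n_test = int(np.ceil(test_size * n_samples))
    rng = np.random.default_rng(seed)
    indices = rng.permutation(n_samples)
    test_idx, train_idx = indices[:n_test], indices[n_test:]
    # Index X and y with the same arrays so each row keeps its label.
    return X[train_idx], X[test_idx], y[train_idx], y[test_idx]

# Stand-in data shaped like Iris: 150 rows, 4 features, 3 classes of 50.
X_demo = np.arange(150 * 4, dtype=float).reshape(150, 4)
y_demo = np.repeat([0, 1, 2], 50)
X_tr, X_te, y_tr, y_te = shuffled_split(X_demo, y_demo)
print(X_tr.shape, X_te.shape)  # (100, 4) (50, 4)
```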


In [159]:
from sklearn.model_selection import train_test_split

In [160]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)

In [111]:
X_train


Out[111]:
array([[5.7, 2.9, 4.2, 1.3],
       [7.6, 3. , 6.6, 2.1],
       [5.6, 3. , 4.5, 1.5],
       [5.1, 3.5, 1.4, 0.2],
       [7.7, 2.8, 6.7, 2. ],
       [5.8, 2.7, 4.1, 1. ],
       [5.2, 3.4, 1.4, 0.2],
       [5. , 3.5, 1.3, 0.3],
       [5.1, 3.8, 1.9, 0.4],
       [5. , 2. , 3.5, 1. ],
       [6.3, 2.7, 4.9, 1.8],
       [4.8, 3.4, 1.9, 0.2],
       [5. , 3. , 1.6, 0.2],
       [5.1, 3.3, 1.7, 0.5],
       [5.6, 2.7, 4.2, 1.3],
       [5.1, 3.4, 1.5, 0.2],
       [5.7, 3. , 4.2, 1.2],
       [7.7, 3.8, 6.7, 2.2],
       [4.6, 3.2, 1.4, 0.2],
       [6.2, 2.9, 4.3, 1.3],
       [5.7, 2.5, 5. , 2. ],
       [5.5, 4.2, 1.4, 0.2],
       [6. , 3. , 4.8, 1.8],
       [5.8, 2.7, 5.1, 1.9],
       [6. , 2.2, 4. , 1. ],
       [5.4, 3. , 4.5, 1.5],
       [6.2, 3.4, 5.4, 2.3],
       [5.5, 2.3, 4. , 1.3],
       [5.4, 3.9, 1.7, 0.4],
       [5. , 2.3, 3.3, 1. ],
       [6.4, 2.7, 5.3, 1.9],
       [5. , 3.3, 1.4, 0.2],
       [5. , 3.2, 1.2, 0.2],
       [5.5, 2.4, 3.8, 1.1],
       [6.7, 3. , 5. , 1.7],
       [4.9, 3.1, 1.5, 0.2],
       [5.8, 2.8, 5.1, 2.4],
       [5. , 3.4, 1.5, 0.2],
       [5. , 3.5, 1.6, 0.6],
       [5.9, 3.2, 4.8, 1.8],
       [5.1, 2.5, 3. , 1.1],
       [6.9, 3.2, 5.7, 2.3],
       [6. , 2.7, 5.1, 1.6],
       [6.1, 2.6, 5.6, 1.4],
       [7.7, 3. , 6.1, 2.3],
       [5.5, 2.5, 4. , 1.3],
       [4.4, 2.9, 1.4, 0.2],
       [4.3, 3. , 1.1, 0.1],
       [6. , 2.2, 5. , 1.5],
       [7.2, 3.2, 6. , 1.8],
       [4.6, 3.1, 1.5, 0.2],
       [5.1, 3.5, 1.4, 0.3],
       [4.4, 3. , 1.3, 0.2],
       [6.3, 2.5, 4.9, 1.5],
       [6.3, 3.4, 5.6, 2.4],
       [4.6, 3.4, 1.4, 0.3],
       [6.8, 3. , 5.5, 2.1],
       [6.3, 3.3, 6. , 2.5],
       [4.7, 3.2, 1.3, 0.2],
       [6.1, 2.9, 4.7, 1.4],
       [6.5, 2.8, 4.6, 1.5],
       [6.2, 2.8, 4.8, 1.8],
       [7. , 3.2, 4.7, 1.4],
       [6.4, 3.2, 5.3, 2.3],
       [5.1, 3.8, 1.6, 0.2],
       [6.9, 3.1, 5.4, 2.1],
       [5.9, 3. , 4.2, 1.5],
       [6.5, 3. , 5.2, 2. ],
       [5.7, 2.6, 3.5, 1. ],
       [5.2, 2.7, 3.9, 1.4],
       [6.1, 3. , 4.6, 1.4],
       [4.5, 2.3, 1.3, 0.3],
       [6.6, 2.9, 4.6, 1.3],
       [5.5, 2.6, 4.4, 1.2],
       [5.3, 3.7, 1.5, 0.2],
       [5.6, 3. , 4.1, 1.3],
       [7.3, 2.9, 6.3, 1.8],
       [6.7, 3.3, 5.7, 2.1],
       [5.1, 3.7, 1.5, 0.4],
       [4.9, 2.4, 3.3, 1. ],
       [6.7, 3.3, 5.7, 2.5],
       [7.2, 3. , 5.8, 1.6],
       [4.9, 3.6, 1.4, 0.1],
       [6.7, 3.1, 5.6, 2.4],
       [4.9, 3. , 1.4, 0.2],
       [6.9, 3.1, 4.9, 1.5],
       [7.4, 2.8, 6.1, 1.9],
       [6.3, 2.9, 5.6, 1.8],
       [5.7, 2.8, 4.1, 1.3],
       [6.5, 3. , 5.5, 1.8],
       [6.3, 2.3, 4.4, 1.3],
       [6.4, 2.9, 4.3, 1.3],
       [5.6, 2.8, 4.9, 2. ],
       [5.9, 3. , 5.1, 1.8],
       [5.4, 3.4, 1.7, 0.2],
       [6.1, 2.8, 4. , 1.3],
       [4.9, 2.5, 4.5, 1.7],
       [5.8, 4. , 1.2, 0.2],
       [5.8, 2.6, 4. , 1.2],
       [7.1, 3. , 5.9, 2.1]])

In [112]:
X_test


Out[112]:
array([[6.1, 2.8, 4.7, 1.2],
       [5.7, 3.8, 1.7, 0.3],
       [7.7, 2.6, 6.9, 2.3],
       [6. , 2.9, 4.5, 1.5],
       [6.8, 2.8, 4.8, 1.4],
       [5.4, 3.4, 1.5, 0.4],
       [5.6, 2.9, 3.6, 1.3],
       [6.9, 3.1, 5.1, 2.3],
       [6.2, 2.2, 4.5, 1.5],
       [5.8, 2.7, 3.9, 1.2],
       [6.5, 3.2, 5.1, 2. ],
       [4.8, 3. , 1.4, 0.1],
       [5.5, 3.5, 1.3, 0.2],
       [4.9, 3.1, 1.5, 0.1],
       [5.1, 3.8, 1.5, 0.3],
       [6.3, 3.3, 4.7, 1.6],
       [6.5, 3. , 5.8, 2.2],
       [5.6, 2.5, 3.9, 1.1],
       [5.7, 2.8, 4.5, 1.3],
       [6.4, 2.8, 5.6, 2.2],
       [4.7, 3.2, 1.6, 0.2],
       [6.1, 3. , 4.9, 1.8],
       [5. , 3.4, 1.6, 0.4],
       [6.4, 2.8, 5.6, 2.1],
       [7.9, 3.8, 6.4, 2. ],
       [6.7, 3. , 5.2, 2.3],
       [6.7, 2.5, 5.8, 1.8],
       [6.8, 3.2, 5.9, 2.3],
       [4.8, 3. , 1.4, 0.3],
       [4.8, 3.1, 1.6, 0.2],
       [4.6, 3.6, 1. , 0.2],
       [5.7, 4.4, 1.5, 0.4],
       [6.7, 3.1, 4.4, 1.4],
       [4.8, 3.4, 1.6, 0.2],
       [4.4, 3.2, 1.3, 0.2],
       [6.3, 2.5, 5. , 1.9],
       [6.4, 3.2, 4.5, 1.5],
       [5.2, 3.5, 1.5, 0.2],
       [5. , 3.6, 1.4, 0.2],
       [5.2, 4.1, 1.5, 0.1],
       [5.8, 2.7, 5.1, 1.9],
       [6. , 3.4, 4.5, 1.6],
       [6.7, 3.1, 4.7, 1.5],
       [5.4, 3.9, 1.3, 0.4],
       [5.4, 3.7, 1.5, 0.2],
       [5.5, 2.4, 3.7, 1. ],
       [6.3, 2.8, 5.1, 1.5],
       [6.4, 3.1, 5.5, 1.8],
       [6.6, 3. , 4.4, 1.4],
       [7.2, 3.6, 6.1, 2.5]])

In [113]:
y_train


Out[113]:
array([[0., 1., 0.],
       [0., 0., 1.],
       [0., 1., 0.],
       [1., 0., 0.],
       [0., 0., 1.],
       [0., 1., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [0., 1., 0.],
       [1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.],
       [1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.],
       [1., 0., 0.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 0., 1.],
       [0., 1., 0.],
       [1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.],
       [1., 0., 0.],
       [1., 0., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [1., 0., 0.],
       [0., 0., 1.],
       [1., 0., 0.],
       [1., 0., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 0., 1.],
       [0., 1., 0.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 1., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [0., 0., 1.],
       [0., 0., 1.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.],
       [1., 0., 0.],
       [0., 0., 1.],
       [0., 0., 1.],
       [1., 0., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 0., 1.],
       [0., 1., 0.],
       [0., 0., 1.],
       [1., 0., 0.],
       [0., 0., 1.],
       [0., 1., 0.],
       [0., 0., 1.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [1., 0., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.],
       [0., 0., 1.],
       [1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.],
       [0., 0., 1.],
       [1., 0., 0.],
       [0., 0., 1.],
       [1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 1., 0.],
       [0., 0., 1.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 0., 1.],
       [0., 0., 1.],
       [1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.],
       [1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)

In [114]:
y_test


Out[114]:
array([[0., 1., 0.],
       [1., 0., 0.],
       [0., 0., 1.],
       [0., 1., 0.],
       [0., 1., 0.],
       [1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 0., 1.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 0., 1.],
       [1., 0., 0.],
       [0., 0., 1.],
       [1., 0., 0.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [0., 1., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [0., 0., 1.],
       [0., 1., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [0., 0., 1.],
       [0., 1., 0.],
       [0., 1., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)

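Scaling the Data

Neural networks usually train better when the input features are on similar scales. Strictly speaking, standardization means rescaling each feature to zero mean and unit variance; here we instead normalize each feature to a fixed range, such as 0 to 1 or -1 to 1, which is what min-max scaling does.

The scikit-learn library provides the MinMaxScaler class for this. We fit the scaler on the training data only and then use it to transform both sets, so no information from the test set leaks into training.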

http://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.MinMaxScaler.html
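Min-max scaling applies x' = (x - min) / (max - min) to each column, where min and max are learned from the training data. Because the test set is transformed with the training statistics, its values can land slightly outside [0, 1]. In this notebook, for example, X_test contains a sepal length of 7.9 while the training range is 4.3 to 7.7, so that value scales to about 1.06. The NumPy sketch below shows the arithmetic; the helper names and the zero-range guard are assumptions, not MinMaxScaler's internals.

```python
import numpy as np

def min_max_fit(X_train):
    """Hypothetical helper: learn per-column minimum and range from training data."""
    col_min = X_train.min(axis=0)
    col_range = X_train.max(axis=0) - col_min
    # Guard against constant columns to avoid dividing by zero.
    col_range[col_range == 0] = 1.0
    return col_min, col_range

def min_max_transform(X, col_min, col_range):
    """Hypothetical helper: rescale each column with the training statistics."""
    return (X - col_min) / col_range

X_train_demo = np.array([[4.3, 2.0], [7.7, 4.4], [6.0, 3.2]])
X_test_demo = np.array([[7.9, 3.8]])
col_min, col_range = min_max_fit(X_train_demo)
scaled_train = min_max_transform(X_train_demo, col_min, col_range)
scaled_test = min_max_transform(X_test_demo, col_min, col_range)
print(scaled_train)
print(scaled_test)  # first column exceeds 1 because 7.9 > training max 7.7
```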


In [115]:
from sklearn.preprocessing import MinMaxScaler

In [116]:
scaler_object = MinMaxScaler()

In [117]:
scaler_object.fit(X_train)


Out[117]:
MinMaxScaler(copy=True, feature_range=(0, 1))

In [118]:
scaled_X_train = scaler_object.transform(X_train)

In [119]:
scaled_X_test = scaler_object.transform(X_test)

Ok, now we have the data scaled!


In [120]:
X_train.max()


Out[120]:
7.7

In [121]:
scaled_X_train.max()


Out[121]:
1.0

In [122]:
X_train


Out[122]:
array([[5.7, 2.9, 4.2, 1.3],
       [7.6, 3. , 6.6, 2.1],
       [5.6, 3. , 4.5, 1.5],
       [5.1, 3.5, 1.4, 0.2],
       [7.7, 2.8, 6.7, 2. ],
       [5.8, 2.7, 4.1, 1. ],
       [5.2, 3.4, 1.4, 0.2],
       [5. , 3.5, 1.3, 0.3],
       [5.1, 3.8, 1.9, 0.4],
       [5. , 2. , 3.5, 1. ],
       [6.3, 2.7, 4.9, 1.8],
       [4.8, 3.4, 1.9, 0.2],
       [5. , 3. , 1.6, 0.2],
       [5.1, 3.3, 1.7, 0.5],
       [5.6, 2.7, 4.2, 1.3],
       [5.1, 3.4, 1.5, 0.2],
       [5.7, 3. , 4.2, 1.2],
       [7.7, 3.8, 6.7, 2.2],
       [4.6, 3.2, 1.4, 0.2],
       [6.2, 2.9, 4.3, 1.3],
       [5.7, 2.5, 5. , 2. ],
       [5.5, 4.2, 1.4, 0.2],
       [6. , 3. , 4.8, 1.8],
       [5.8, 2.7, 5.1, 1.9],
       [6. , 2.2, 4. , 1. ],
       [5.4, 3. , 4.5, 1.5],
       [6.2, 3.4, 5.4, 2.3],
       [5.5, 2.3, 4. , 1.3],
       [5.4, 3.9, 1.7, 0.4],
       [5. , 2.3, 3.3, 1. ],
       [6.4, 2.7, 5.3, 1.9],
       [5. , 3.3, 1.4, 0.2],
       [5. , 3.2, 1.2, 0.2],
       [5.5, 2.4, 3.8, 1.1],
       [6.7, 3. , 5. , 1.7],
       [4.9, 3.1, 1.5, 0.2],
       [5.8, 2.8, 5.1, 2.4],
       [5. , 3.4, 1.5, 0.2],
       [5. , 3.5, 1.6, 0.6],
       [5.9, 3.2, 4.8, 1.8],
       [5.1, 2.5, 3. , 1.1],
       [6.9, 3.2, 5.7, 2.3],
       [6. , 2.7, 5.1, 1.6],
       [6.1, 2.6, 5.6, 1.4],
       [7.7, 3. , 6.1, 2.3],
       [5.5, 2.5, 4. , 1.3],
       [4.4, 2.9, 1.4, 0.2],
       [4.3, 3. , 1.1, 0.1],
       [6. , 2.2, 5. , 1.5],
       [7.2, 3.2, 6. , 1.8],
       [4.6, 3.1, 1.5, 0.2],
       [5.1, 3.5, 1.4, 0.3],
       [4.4, 3. , 1.3, 0.2],
       [6.3, 2.5, 4.9, 1.5],
       [6.3, 3.4, 5.6, 2.4],
       [4.6, 3.4, 1.4, 0.3],
       [6.8, 3. , 5.5, 2.1],
       [6.3, 3.3, 6. , 2.5],
       [4.7, 3.2, 1.3, 0.2],
       [6.1, 2.9, 4.7, 1.4],
       [6.5, 2.8, 4.6, 1.5],
       [6.2, 2.8, 4.8, 1.8],
       [7. , 3.2, 4.7, 1.4],
       [6.4, 3.2, 5.3, 2.3],
       [5.1, 3.8, 1.6, 0.2],
       [6.9, 3.1, 5.4, 2.1],
       [5.9, 3. , 4.2, 1.5],
       [6.5, 3. , 5.2, 2. ],
       [5.7, 2.6, 3.5, 1. ],
       [5.2, 2.7, 3.9, 1.4],
       [6.1, 3. , 4.6, 1.4],
       [4.5, 2.3, 1.3, 0.3],
       [6.6, 2.9, 4.6, 1.3],
       [5.5, 2.6, 4.4, 1.2],
       [5.3, 3.7, 1.5, 0.2],
       [5.6, 3. , 4.1, 1.3],
       [7.3, 2.9, 6.3, 1.8],
       [6.7, 3.3, 5.7, 2.1],
       [5.1, 3.7, 1.5, 0.4],
       [4.9, 2.4, 3.3, 1. ],
       [6.7, 3.3, 5.7, 2.5],
       [7.2, 3. , 5.8, 1.6],
       [4.9, 3.6, 1.4, 0.1],
       [6.7, 3.1, 5.6, 2.4],
       [4.9, 3. , 1.4, 0.2],
       [6.9, 3.1, 4.9, 1.5],
       [7.4, 2.8, 6.1, 1.9],
       [6.3, 2.9, 5.6, 1.8],
       [5.7, 2.8, 4.1, 1.3],
       [6.5, 3. , 5.5, 1.8],
       [6.3, 2.3, 4.4, 1.3],
       [6.4, 2.9, 4.3, 1.3],
       [5.6, 2.8, 4.9, 2. ],
       [5.9, 3. , 5.1, 1.8],
       [5.4, 3.4, 1.7, 0.2],
       [6.1, 2.8, 4. , 1.3],
       [4.9, 2.5, 4.5, 1.7],
       [5.8, 4. , 1.2, 0.2],
       [5.8, 2.6, 4. , 1.2],
       [7.1, 3. , 5.9, 2.1]])

In [123]:
scaled_X_train


Out[123]:
array([[0.41176471, 0.40909091, 0.55357143, 0.5       ],
       [0.97058824, 0.45454545, 0.98214286, 0.83333333],
       [0.38235294, 0.45454545, 0.60714286, 0.58333333],
       [0.23529412, 0.68181818, 0.05357143, 0.04166667],
       [1.        , 0.36363636, 1.        , 0.79166667],
       [0.44117647, 0.31818182, 0.53571429, 0.375     ],
       [0.26470588, 0.63636364, 0.05357143, 0.04166667],
       [0.20588235, 0.68181818, 0.03571429, 0.08333333],
       [0.23529412, 0.81818182, 0.14285714, 0.125     ],
       [0.20588235, 0.        , 0.42857143, 0.375     ],
       [0.58823529, 0.31818182, 0.67857143, 0.70833333],
       [0.14705882, 0.63636364, 0.14285714, 0.04166667],
       [0.20588235, 0.45454545, 0.08928571, 0.04166667],
       [0.23529412, 0.59090909, 0.10714286, 0.16666667],
       [0.38235294, 0.31818182, 0.55357143, 0.5       ],
       [0.23529412, 0.63636364, 0.07142857, 0.04166667],
       [0.41176471, 0.45454545, 0.55357143, 0.45833333],
       [1.        , 0.81818182, 1.        , 0.875     ],
       [0.08823529, 0.54545455, 0.05357143, 0.04166667],
       [0.55882353, 0.40909091, 0.57142857, 0.5       ],
       [0.41176471, 0.22727273, 0.69642857, 0.79166667],
       [0.35294118, 1.        , 0.05357143, 0.04166667],
       [0.5       , 0.45454545, 0.66071429, 0.70833333],
       [0.44117647, 0.31818182, 0.71428571, 0.75      ],
       [0.5       , 0.09090909, 0.51785714, 0.375     ],
       [0.32352941, 0.45454545, 0.60714286, 0.58333333],
       [0.55882353, 0.63636364, 0.76785714, 0.91666667],
       [0.35294118, 0.13636364, 0.51785714, 0.5       ],
       [0.32352941, 0.86363636, 0.10714286, 0.125     ],
       [0.20588235, 0.13636364, 0.39285714, 0.375     ],
       [0.61764706, 0.31818182, 0.75      , 0.75      ],
       [0.20588235, 0.59090909, 0.05357143, 0.04166667],
       [0.20588235, 0.54545455, 0.01785714, 0.04166667],
       [0.35294118, 0.18181818, 0.48214286, 0.41666667],
       [0.70588235, 0.45454545, 0.69642857, 0.66666667],
       [0.17647059, 0.5       , 0.07142857, 0.04166667],
       [0.44117647, 0.36363636, 0.71428571, 0.95833333],
       [0.20588235, 0.63636364, 0.07142857, 0.04166667],
       [0.20588235, 0.68181818, 0.08928571, 0.20833333],
       [0.47058824, 0.54545455, 0.66071429, 0.70833333],
       [0.23529412, 0.22727273, 0.33928571, 0.41666667],
       [0.76470588, 0.54545455, 0.82142857, 0.91666667],
       [0.5       , 0.31818182, 0.71428571, 0.625     ],
       [0.52941176, 0.27272727, 0.80357143, 0.54166667],
       [1.        , 0.45454545, 0.89285714, 0.91666667],
       [0.35294118, 0.22727273, 0.51785714, 0.5       ],
       [0.02941176, 0.40909091, 0.05357143, 0.04166667],
       [0.        , 0.45454545, 0.        , 0.        ],
       [0.5       , 0.09090909, 0.69642857, 0.58333333],
       [0.85294118, 0.54545455, 0.875     , 0.70833333],
       [0.08823529, 0.5       , 0.07142857, 0.04166667],
       [0.23529412, 0.68181818, 0.05357143, 0.08333333],
       [0.02941176, 0.45454545, 0.03571429, 0.04166667],
       [0.58823529, 0.22727273, 0.67857143, 0.58333333],
       [0.58823529, 0.63636364, 0.80357143, 0.95833333],
       [0.08823529, 0.63636364, 0.05357143, 0.08333333],
       [0.73529412, 0.45454545, 0.78571429, 0.83333333],
       [0.58823529, 0.59090909, 0.875     , 1.        ],
       [0.11764706, 0.54545455, 0.03571429, 0.04166667],
       [0.52941176, 0.40909091, 0.64285714, 0.54166667],
       [0.64705882, 0.36363636, 0.625     , 0.58333333],
       [0.55882353, 0.36363636, 0.66071429, 0.70833333],
       [0.79411765, 0.54545455, 0.64285714, 0.54166667],
       [0.61764706, 0.54545455, 0.75      , 0.91666667],
       [0.23529412, 0.81818182, 0.08928571, 0.04166667],
       [0.76470588, 0.5       , 0.76785714, 0.83333333],
       [0.47058824, 0.45454545, 0.55357143, 0.58333333],
       [0.64705882, 0.45454545, 0.73214286, 0.79166667],
       [0.41176471, 0.27272727, 0.42857143, 0.375     ],
       [0.26470588, 0.31818182, 0.5       , 0.54166667],
       [0.52941176, 0.45454545, 0.625     , 0.54166667],
       [0.05882353, 0.13636364, 0.03571429, 0.08333333],
       [0.67647059, 0.40909091, 0.625     , 0.5       ],
       [0.35294118, 0.27272727, 0.58928571, 0.45833333],
       [0.29411765, 0.77272727, 0.07142857, 0.04166667],
       [0.38235294, 0.45454545, 0.53571429, 0.5       ],
       [0.88235294, 0.40909091, 0.92857143, 0.70833333],
       [0.70588235, 0.59090909, 0.82142857, 0.83333333],
       [0.23529412, 0.77272727, 0.07142857, 0.125     ],
       [0.17647059, 0.18181818, 0.39285714, 0.375     ],
       [0.70588235, 0.59090909, 0.82142857, 1.        ],
       [0.85294118, 0.45454545, 0.83928571, 0.625     ],
       [0.17647059, 0.72727273, 0.05357143, 0.        ],
       [0.70588235, 0.5       , 0.80357143, 0.95833333],
       [0.17647059, 0.45454545, 0.05357143, 0.04166667],
       [0.76470588, 0.5       , 0.67857143, 0.58333333],
       [0.91176471, 0.36363636, 0.89285714, 0.75      ],
       [0.58823529, 0.40909091, 0.80357143, 0.70833333],
       [0.41176471, 0.36363636, 0.53571429, 0.5       ],
       [0.64705882, 0.45454545, 0.78571429, 0.70833333],
       [0.58823529, 0.13636364, 0.58928571, 0.5       ],
       [0.61764706, 0.40909091, 0.57142857, 0.5       ],
       [0.38235294, 0.36363636, 0.67857143, 0.79166667],
       [0.47058824, 0.45454545, 0.71428571, 0.70833333],
       [0.32352941, 0.63636364, 0.10714286, 0.04166667],
       [0.52941176, 0.36363636, 0.51785714, 0.5       ],
       [0.17647059, 0.22727273, 0.60714286, 0.66666667],
       [0.44117647, 0.90909091, 0.01785714, 0.04166667],
       [0.44117647, 0.27272727, 0.51785714, 0.45833333],
       [0.82352941, 0.45454545, 0.85714286, 0.83333333]])

Building the Network with Keras

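Let's build a simple neural network: two hidden layers of 8 ReLU units each, followed by a 3-unit softmax output layer with one unit per iris class.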


In [132]:
from keras.models import Sequential
from keras.layers import Dense

In [133]:
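model = Sequential()
# First hidden layer: 4 input features -> 8 units
model.add(Dense(8, input_dim=4, activation='relu'))
# Second hidden layer: its input size (8) is inferred from the previous layer
model.add(Dense(8, activation='relu'))
# Output layer: one probability per class
model.add(Dense(3, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])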
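To make the model definition less of a black box, here is a rough NumPy sketch of the computation it describes: each Dense layer computes x @ W + b, the hidden layers apply ReLU, the output layer applies softmax, and categorical cross-entropy scores the predicted probabilities against one-hot labels. This is not how Keras executes the model. The weights are random (drawn with a Glorot-uniform rule, which is Keras's default Dense kernel initializer, with zero biases), so the outputs are untrained, and the inputs and labels are made-up stand-ins.

```python
import numpy as np

rng = np.random.default_rng(0)

def glorot_uniform(fan_in, fan_out):
    # Uniform in [-limit, limit] with limit = sqrt(6 / (fan_in + fan_out)).
    limit = np.sqrt(6.0 / (fan_in + fan_out))
    return rng.uniform(-limit, limit, size=(fan_in, fan_out))

def relu(z):
    return np.maximum(z, 0.0)

def softmax(z):
    # Subtracting the row max improves numerical stability without changing the result.
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

# Layer sizes mirror the model above: 4 inputs -> 8 -> 8 -> 3 classes.
sizes = [4, 8, 8, 3]
weights = [glorot_uniform(a, b) for a, b in zip(sizes[:-1], sizes[1:])]
biases = [np.zeros(b) for b in sizes[1:]]

def forward(X):
    h = X
    for W, b in zip(weights[:-1], biases[:-1]):
        h = relu(h @ W + b)                          # hidden Dense layers with ReLU
    return softmax(h @ weights[-1] + biases[-1])     # output Dense layer with softmax

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    # Mean over samples of -sum_k y_k * log(p_k); clipping avoids log(0).
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    return -np.mean(np.sum(y_true * np.log(y_pred), axis=1))

X_batch = rng.uniform(0.0, 1.0, size=(5, 4))  # stand-in for scaled features
y_batch = np.eye(3)[[0, 1, 2, 1, 0]]           # stand-in one-hot labels
probs = forward(X_batch)
print(probs.sum(axis=1))                       # each row sums to 1
print(categorical_crossentropy(y_batch, probs))
```

Training, which model.fit does for us, amounts to adjusting weights and biases to push this loss down.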

In [134]:
model.summary()


_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_27 (Dense)             (None, 8)                 40        
_________________________________________________________________
dense_28 (Dense)             (None, 8)                 72        
_________________________________________________________________
dense_29 (Dense)             (None, 3)                 27        
=================================================================
Total params: 139
Trainable params: 139
Non-trainable params: 0
_________________________________________________________________
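The Param # column follows from a simple rule: a Dense layer with n inputs and m units has n * m weights plus m biases. A short check of the numbers in the summary:

```python
def dense_params(n_inputs, n_units):
    # One weight per (input, unit) pair plus one bias per unit.
    return n_inputs * n_units + n_units

layer_sizes = [4, 8, 8, 3]
counts = [dense_params(a, b) for a, b in zip(layer_sizes[:-1], layer_sizes[1:])]
print(counts)       # [40, 72, 27]
print(sum(counts))  # 139
```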

Fit (Train) the Model
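Before reading the training log, it helps to know what "nothing learned yet" looks like. If an untrained network assigned roughly equal probability 1/3 to each of the three classes, the categorical cross-entropy of every sample would be -log(1/3) = log(3), which is close to the first-epoch loss of 1.0926 in the log. This is a back-of-the-envelope baseline under that equal-probability assumption, not something Keras reports.

```python
import math

num_classes = 3
# With equal predicted probability 1/3 per class, each sample's
# cross-entropy is -log(1/3) = log(3).
baseline_loss = -math.log(1.0 / num_classes)
print(round(baseline_loss, 4))  # 1.0986
```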


In [135]:
# Play around with the number of epochs as well!
model.fit(scaled_X_train, y_train, epochs=150, verbose=2)


Epoch 1/150
 - 0s - loss: 1.0926 - acc: 0.3400
Epoch 2/150
 - 0s - loss: 1.0871 - acc: 0.3400
Epoch 3/150
 - 0s - loss: 1.0814 - acc: 0.3400
Epoch 4/150
 - 0s - loss: 1.0760 - acc: 0.3400
Epoch 5/150
 - 0s - loss: 1.0700 - acc: 0.3400
Epoch 6/150
 - 0s - loss: 1.0640 - acc: 0.3500
Epoch 7/150
 - 0s - loss: 1.0581 - acc: 0.3700
Epoch 8/150
 - 0s - loss: 1.0520 - acc: 0.4200
Epoch 9/150
 - 0s - loss: 1.0467 - acc: 0.5100
Epoch 10/150
 - 0s - loss: 1.0416 - acc: 0.5700
Epoch 11/150
 - 0s - loss: 1.0361 - acc: 0.6300
Epoch 12/150
 - 0s - loss: 1.0308 - acc: 0.6300
Epoch 13/150
 - 0s - loss: 1.0249 - acc: 0.6300
Epoch 14/150
 - 0s - loss: 1.0187 - acc: 0.6200
Epoch 15/150
 - 0s - loss: 1.0124 - acc: 0.6300
Epoch 16/150
 - 0s - loss: 1.0062 - acc: 0.6300
Epoch 17/150
 - 0s - loss: 0.9994 - acc: 0.6400
Epoch 18/150
 - 0s - loss: 0.9919 - acc: 0.6400
Epoch 19/150
 - 0s - loss: 0.9836 - acc: 0.6400
Epoch 20/150
 - 0s - loss: 0.9748 - acc: 0.6400
Epoch 21/150
 - 0s - loss: 0.9649 - acc: 0.6400
Epoch 22/150
 - 0s - loss: 0.9552 - acc: 0.6400
Epoch 23/150
 - 0s - loss: 0.9448 - acc: 0.6500
Epoch 24/150
 - 0s - loss: 0.9337 - acc: 0.6500
Epoch 25/150
 - 0s - loss: 0.9225 - acc: 0.6500
Epoch 26/150
 - 0s - loss: 0.9111 - acc: 0.6500
Epoch 27/150
 - 0s - loss: 0.8992 - acc: 0.6500
Epoch 28/150
 - 0s - loss: 0.8882 - acc: 0.6500
Epoch 29/150
 - 0s - loss: 0.8766 - acc: 0.6500
Epoch 30/150
 - 0s - loss: 0.8658 - acc: 0.6500
Epoch 31/150
 - 0s - loss: 0.8555 - acc: 0.6500
Epoch 32/150
 - 0s - loss: 0.8446 - acc: 0.6500
Epoch 33/150
 - 0s - loss: 0.8337 - acc: 0.6500
Epoch 34/150
 - 0s - loss: 0.8225 - acc: 0.6500
Epoch 35/150
 - 0s - loss: 0.8112 - acc: 0.6500
Epoch 36/150
 - 0s - loss: 0.7998 - acc: 0.6500
Epoch 37/150
 - 0s - loss: 0.7886 - acc: 0.6500
Epoch 38/150
 - 0s - loss: 0.7768 - acc: 0.6500
Epoch 39/150
 - 0s - loss: 0.7650 - acc: 0.6500
Epoch 40/150
 - 0s - loss: 0.7539 - acc: 0.6500
Epoch 41/150
 - 0s - loss: 0.7428 - acc: 0.6500
Epoch 42/150
 - 0s - loss: 0.7318 - acc: 0.6500
Epoch 43/150
 - 0s - loss: 0.7213 - acc: 0.6500
Epoch 44/150
 - 0s - loss: 0.7111 - acc: 0.6500
Epoch 45/150
 - 0s - loss: 0.7007 - acc: 0.6500
Epoch 46/150
 - 0s - loss: 0.6907 - acc: 0.6500
Epoch 47/150
 - 0s - loss: 0.6806 - acc: 0.6500
Epoch 48/150
 - 0s - loss: 0.6714 - acc: 0.6500
Epoch 49/150
 - 0s - loss: 0.6621 - acc: 0.6500
Epoch 50/150
 - 0s - loss: 0.6534 - acc: 0.6500
Epoch 51/150
 - 0s - loss: 0.6452 - acc: 0.6500
Epoch 52/150
 - 0s - loss: 0.6366 - acc: 0.6500
Epoch 53/150
 - 0s - loss: 0.6285 - acc: 0.6500
Epoch 54/150
 - 0s - loss: 0.6205 - acc: 0.6500
Epoch 55/150
 - 0s - loss: 0.6130 - acc: 0.6500
Epoch 56/150
 - 0s - loss: 0.6058 - acc: 0.6500
Epoch 57/150
 - 0s - loss: 0.5990 - acc: 0.6500
Epoch 58/150
 - 0s - loss: 0.5920 - acc: 0.6500
Epoch 59/150
 - 0s - loss: 0.5852 - acc: 0.6700
Epoch 60/150
 - 0s - loss: 0.5790 - acc: 0.6700
Epoch 61/150
 - 0s - loss: 0.5727 - acc: 0.6900
Epoch 62/150
 - 0s - loss: 0.5665 - acc: 0.6900
Epoch 63/150
 - 0s - loss: 0.5605 - acc: 0.6900
Epoch 64/150
 - 0s - loss: 0.5544 - acc: 0.6900
Epoch 65/150
 - 0s - loss: 0.5490 - acc: 0.6900
Epoch 66/150
 - 0s - loss: 0.5436 - acc: 0.7000
Epoch 67/150
 - 0s - loss: 0.5385 - acc: 0.7000
Epoch 68/150
 - 0s - loss: 0.5334 - acc: 0.7000
Epoch 69/150
 - 0s - loss: 0.5287 - acc: 0.7100
Epoch 70/150
 - 0s - loss: 0.5243 - acc: 0.7200
Epoch 71/150
 - 0s - loss: 0.5198 - acc: 0.7300
Epoch 72/150
 - 0s - loss: 0.5150 - acc: 0.7300
Epoch 73/150
 - 0s - loss: 0.5108 - acc: 0.7400
Epoch 74/150
 - 0s - loss: 0.5066 - acc: 0.7400
Epoch 75/150
 - 0s - loss: 0.5022 - acc: 0.7600
Epoch 76/150
 - 0s - loss: 0.4986 - acc: 0.7500
Epoch 77/150
 - 0s - loss: 0.4945 - acc: 0.7400
Epoch 78/150
 - 0s - loss: 0.4908 - acc: 0.7500
Epoch 79/150
 - 0s - loss: 0.4867 - acc: 0.7700
Epoch 80/150
 - 0s - loss: 0.4830 - acc: 0.7700
Epoch 81/150
 - 0s - loss: 0.4794 - acc: 0.7900
Epoch 82/150
 - 0s - loss: 0.4761 - acc: 0.8500
Epoch 83/150
 - 0s - loss: 0.4726 - acc: 0.8500
Epoch 84/150
 - 0s - loss: 0.4693 - acc: 0.8500
Epoch 85/150
 - 0s - loss: 0.4660 - acc: 0.8500
Epoch 86/150
 - 0s - loss: 0.4628 - acc: 0.8500
Epoch 87/150
 - 0s - loss: 0.4599 - acc: 0.8200
Epoch 88/150
 - 0s - loss: 0.4575 - acc: 0.8000
Epoch 89/150
 - 0s - loss: 0.4546 - acc: 0.7800
Epoch 90/150
 - 0s - loss: 0.4521 - acc: 0.7800
Epoch 91/150
 - 0s - loss: 0.4490 - acc: 0.8100
Epoch 92/150
 - 0s - loss: 0.4460 - acc: 0.8300
Epoch 93/150
 - 0s - loss: 0.4433 - acc: 0.8500
Epoch 94/150
 - 0s - loss: 0.4407 - acc: 0.8500
Epoch 95/150
 - 0s - loss: 0.4380 - acc: 0.8500
Epoch 96/150
 - 0s - loss: 0.4352 - acc: 0.8600
Epoch 97/150
 - 0s - loss: 0.4326 - acc: 0.8600
Epoch 98/150
 - 0s - loss: 0.4302 - acc: 0.9000
Epoch 99/150
 - 0s - loss: 0.4270 - acc: 0.9100
Epoch 100/150
 - 0s - loss: 0.4245 - acc: 0.9200
Epoch 101/150
 - 0s - loss: 0.4223 - acc: 0.9200
Epoch 102/150
 - 0s - loss: 0.4195 - acc: 0.9200
Epoch 103/150
 - 0s - loss: 0.4171 - acc: 0.9200
Epoch 104/150
 - 0s - loss: 0.4147 - acc: 0.9200
Epoch 105/150
 - 0s - loss: 0.4124 - acc: 0.9000
Epoch 106/150
 - 0s - loss: 0.4102 - acc: 0.9000
Epoch 107/150
 - 0s - loss: 0.4077 - acc: 0.9200
Epoch 108/150
 - 0s - loss: 0.4055 - acc: 0.9200
Epoch 109/150
 - 0s - loss: 0.4028 - acc: 0.9200
Epoch 110/150
 - 0s - loss: 0.4006 - acc: 0.9200
Epoch 111/150
 - 0s - loss: 0.3985 - acc: 0.9200
Epoch 112/150
 - 0s - loss: 0.3965 - acc: 0.9200
Epoch 113/150
 - 0s - loss: 0.3945 - acc: 0.9200
Epoch 114/150
 - 0s - loss: 0.3930 - acc: 0.9200
Epoch 115/150
 - 0s - loss: 0.3914 - acc: 0.9000
Epoch 116/150
 - 0s - loss: 0.3891 - acc: 0.9100
Epoch 117/150
 - 0s - loss: 0.3860 - acc: 0.9200
Epoch 118/150
 - 0s - loss: 0.3844 - acc: 0.9200
Epoch 119/150
 - 0s - loss: 0.3821 - acc: 0.9500
Epoch 120/150
 - 0s - loss: 0.3801 - acc: 0.9500
Epoch 121/150
 - 0s - loss: 0.3780 - acc: 0.9500
Epoch 122/150
 - 0s - loss: 0.3757 - acc: 0.9500
Epoch 123/150
 - 0s - loss: 0.3738 - acc: 0.9500
Epoch 124/150
 - 0s - loss: 0.3720 - acc: 0.9500
Epoch 125/150
 - 0s - loss: 0.3693 - acc: 0.9500
Epoch 126/150
 - 0s - loss: 0.3662 - acc: 0.9500
Epoch 127/150
 - 0s - loss: 0.3646 - acc: 0.9300
Epoch 128/150
 - 0s - loss: 0.3652 - acc: 0.9200
Epoch 129/150
 - 0s - loss: 0.3640 - acc: 0.9100
Epoch 130/150
 - 0s - loss: 0.3627 - acc: 0.9100
Epoch 131/150
 - 0s - loss: 0.3604 - acc: 0.9100
Epoch 132/150
 - 0s - loss: 0.3572 - acc: 0.9300
Epoch 133/150
 - 0s - loss: 0.3544 - acc: 0.9300
Epoch 134/150
 - 0s - loss: 0.3521 - acc: 0.9300
Epoch 135/150
 - 0s - loss: 0.3501 - acc: 0.9500
Epoch 136/150
 - 0s - loss: 0.3482 - acc: 0.9500
Epoch 137/150
 - 0s - loss: 0.3465 - acc: 0.9500
Epoch 138/150
 - 0s - loss: 0.3461 - acc: 0.9500
Epoch 139/150
 - 0s - loss: 0.3435 - acc: 0.9300
Epoch 140/150
 - 0s - loss: 0.3407 - acc: 0.9500
Epoch 141/150
 - 0s - loss: 0.3384 - acc: 0.9500
Epoch 142/150
 - 0s - loss: 0.3366 - acc: 0.9500
Epoch 143/150
 - 0s - loss: 0.3347 - acc: 0.9500
Epoch 144/150
 - 0s - loss: 0.3336 - acc: 0.9400
Epoch 145/150
 - 0s - loss: 0.3324 - acc: 0.9300
Epoch 146/150
 - 0s - loss: 0.3311 - acc: 0.9300
Epoch 147/150
 - 0s - loss: 0.3300 - acc: 0.9300
Epoch 148/150
 - 0s - loss: 0.3283 - acc: 0.9300
Epoch 149/150
 - 0s - loss: 0.3256 - acc: 0.9400
Epoch 150/150
 - 0s - loss: 0.3224 - acc: 0.9500
Out[135]:
<keras.callbacks.History at 0x1fefed36ac8>
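The epoch-1 loss of about 1.09 is close to ln 3 ≈ 1.0986. Categorical cross-entropy gives exactly that value when a 3-class model assigns equal probability to every class. The sketch below assumes the model was compiled with `loss='categorical_crossentropy'`, since the compile step is not part of this section. It computes that loss by hand with NumPy to show why training starts near ln 3 and why the loss falls as the predictions become confident and correct. The labels and probabilities are made-up illustration values.

```python
import numpy as np

def categorical_crossentropy(y_true_one_hot, y_prob):
    """Mean over samples of -sum(y * log(p)), the categorical cross-entropy."""
    return float(np.mean(-np.sum(y_true_one_hot * np.log(y_prob), axis=1)))

# Three hypothetical samples, one per class, encoded as one-hot rows.
y_true = np.eye(3)

# An untrained model that is equally unsure about all 3 classes.
uniform = np.full((3, 3), 1 / 3)
uniform_loss = categorical_crossentropy(y_true, uniform)
print(uniform_loss)  # ln(3), about 1.0986

# A model that puts 0.9 on the correct class gets a much lower loss.
confident = np.array([
    [0.90, 0.05, 0.05],
    [0.05, 0.90, 0.05],
    [0.05, 0.05, 0.90],
])
confident_loss = categorical_crossentropy(y_true, confident)
print(confident_loss)  # -ln(0.9), about 0.1054
```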

Predicting New Unseen Data

Let's see how we did by predicting on new data. Remember, our model has never seen the test data that we scaled previously! This is exactly the process you would use on brand-new data, such as measurements from an iris flower you just collected.


In [136]:
scaled_X_test


Out[136]:
array([[ 0.52941176,  0.36363636,  0.64285714,  0.45833333],
       [ 0.41176471,  0.81818182,  0.10714286,  0.08333333],
       [ 1.        ,  0.27272727,  1.03571429,  0.91666667],
       [ 0.5       ,  0.40909091,  0.60714286,  0.58333333],
       [ 0.73529412,  0.36363636,  0.66071429,  0.54166667],
       [ 0.32352941,  0.63636364,  0.07142857,  0.125     ],
       [ 0.38235294,  0.40909091,  0.44642857,  0.5       ],
       [ 0.76470588,  0.5       ,  0.71428571,  0.91666667],
       [ 0.55882353,  0.09090909,  0.60714286,  0.58333333],
       [ 0.44117647,  0.31818182,  0.5       ,  0.45833333],
       [ 0.64705882,  0.54545455,  0.71428571,  0.79166667],
       [ 0.14705882,  0.45454545,  0.05357143,  0.        ],
       [ 0.35294118,  0.68181818,  0.03571429,  0.04166667],
       [ 0.17647059,  0.5       ,  0.07142857,  0.        ],
       [ 0.23529412,  0.81818182,  0.07142857,  0.08333333],
       [ 0.58823529,  0.59090909,  0.64285714,  0.625     ],
       [ 0.64705882,  0.45454545,  0.83928571,  0.875     ],
       [ 0.38235294,  0.22727273,  0.5       ,  0.41666667],
       [ 0.41176471,  0.36363636,  0.60714286,  0.5       ],
       [ 0.61764706,  0.36363636,  0.80357143,  0.875     ],
       [ 0.11764706,  0.54545455,  0.08928571,  0.04166667],
       [ 0.52941176,  0.45454545,  0.67857143,  0.70833333],
       [ 0.20588235,  0.63636364,  0.08928571,  0.125     ],
       [ 0.61764706,  0.36363636,  0.80357143,  0.83333333],
       [ 1.05882353,  0.81818182,  0.94642857,  0.79166667],
       [ 0.70588235,  0.45454545,  0.73214286,  0.91666667],
       [ 0.70588235,  0.22727273,  0.83928571,  0.70833333],
       [ 0.73529412,  0.54545455,  0.85714286,  0.91666667],
       [ 0.14705882,  0.45454545,  0.05357143,  0.08333333],
       [ 0.14705882,  0.5       ,  0.08928571,  0.04166667],
       [ 0.08823529,  0.72727273, -0.01785714,  0.04166667],
       [ 0.41176471,  1.09090909,  0.07142857,  0.125     ],
       [ 0.70588235,  0.5       ,  0.58928571,  0.54166667],
       [ 0.14705882,  0.63636364,  0.08928571,  0.04166667],
       [ 0.02941176,  0.54545455,  0.03571429,  0.04166667],
       [ 0.58823529,  0.22727273,  0.69642857,  0.75      ],
       [ 0.61764706,  0.54545455,  0.60714286,  0.58333333],
       [ 0.26470588,  0.68181818,  0.07142857,  0.04166667],
       [ 0.20588235,  0.72727273,  0.05357143,  0.04166667],
       [ 0.26470588,  0.95454545,  0.07142857,  0.        ],
       [ 0.44117647,  0.31818182,  0.71428571,  0.75      ],
       [ 0.5       ,  0.63636364,  0.60714286,  0.625     ],
       [ 0.70588235,  0.5       ,  0.64285714,  0.58333333],
       [ 0.32352941,  0.86363636,  0.03571429,  0.125     ],
       [ 0.32352941,  0.77272727,  0.07142857,  0.04166667],
       [ 0.35294118,  0.18181818,  0.46428571,  0.375     ],
       [ 0.58823529,  0.36363636,  0.71428571,  0.58333333],
       [ 0.61764706,  0.5       ,  0.78571429,  0.70833333],
       [ 0.67647059,  0.45454545,  0.58928571,  0.54166667],
       [ 0.85294118,  0.72727273,  0.89285714,  1.        ]])
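Most of the scaled test values above fall in [0, 1], but a few land slightly outside that range (for example 1.0588 and -0.0179). This is what you get when min-max scaling is fitted on the training set only and then applied to unseen data. The sketch below assumes the earlier scaling step, which is not shown in this section, was min-max scaling such as scikit-learn's `MinMaxScaler` with its default range. It shows how a brand-new flower would be scaled with the training statistics before being passed to the model. The training rows and the new measurement are made-up values, and `min_max_scale` is a hypothetical helper.

```python
import numpy as np

# Hypothetical training features: one row per flower, columns are
# sepal length, sepal width, petal length, petal width (cm).
X_train = np.array([
    [5.1, 3.5, 1.4, 0.2],
    [7.0, 3.2, 4.7, 1.4],
    [6.3, 3.3, 6.0, 2.5],
    [4.9, 3.0, 1.4, 0.2],
])

# "Fit": remember the per-feature minimum and maximum of the TRAINING data only.
feature_min = X_train.min(axis=0)
feature_max = X_train.max(axis=0)

def min_max_scale(X):
    """Hypothetical helper: (X - train_min) / (train_max - train_min)."""
    return (X - feature_min) / (feature_max - feature_min)

# Training rows map into [0, 1] exactly.
print(min_max_scale(X_train))

# A new flower reuses the training statistics, so values can exceed 1
# (or drop below 0) if it lies outside the training range.
new_flower = np.array([[7.2, 3.6, 6.1, 2.5]])
scaled_new_flower = min_max_scale(new_flower)
print(scaled_new_flower)
```

Never refit the scaler on test or new data. Doing so would use statistics the model was not trained with.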

In [137]:
# Spits out probabilities by default.
# model.predict(scaled_X_test)

In [138]:
model.predict_classes(scaled_X_test)


Out[138]:
array([1, 0, 2, 1, 2, 0, 1, 2, 2, 1, 2, 0, 0, 0, 0, 1, 2, 1, 1, 2, 0, 2,
       0, 2, 2, 2, 2, 2, 0, 0, 0, 0, 1, 0, 0, 2, 1, 0, 0, 0, 2, 1, 1, 0,
       0, 1, 2, 2, 1, 2], dtype=int64)
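When the output layer has more than one unit, as the 3-unit layer here does, `predict_classes` in this older Keras returns the index of the highest probability in each row of `model.predict`. The NumPy sketch below shows that step. The hypothetical probability array stands in for `model.predict(scaled_X_test)`. Newer TensorFlow/Keras releases removed `predict_classes`, and `np.argmax(model.predict(scaled_X_test), axis=1)` is the usual replacement there.

```python
import numpy as np

# Hypothetical per-class probabilities for three samples
# (one row per sample, one column per class index 0, 1, 2).
probabilities = np.array([
    [0.05, 0.80, 0.15],
    [0.90, 0.08, 0.02],
    [0.01, 0.30, 0.69],
])

# Each row sums to 1, as softmax outputs do.
assert np.allclose(probabilities.sum(axis=1), 1.0)

# Pick the most probable class in each row.
predicted_classes = np.argmax(probabilities, axis=1)
print(predicted_classes)  # [1 0 2]
```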

Evaluating Model Performance

So how well did we do? How do we actually measure "well"? Is 95% accuracy good enough? It depends on the situation. We also need to take into account metrics such as precision and recall. Make sure to watch the video discussion on classification evaluation before running this code!


In [139]:
model.metrics_names


Out[139]:
['loss', 'acc']

In [140]:
model.evaluate(x=scaled_X_test,y=y_test)


50/50 [==============================] - 0s 2ms/step
Out[140]:
[0.2843634402751923, 0.96]
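`model.evaluate` returns its scores in the same order as `model.metrics_names`, so pairing the two lists makes the result easier to read. The small sketch below uses the values printed above, copied by hand so that no model is needed to run it.

```python
# Values copied from Out[139] and Out[140] above.
metrics_names = ['loss', 'acc']
results = [0.2843634402751923, 0.96]

# Pair each score with its name.
scores = dict(zip(metrics_names, results))
print(scores)  # {'loss': 0.2843634402751923, 'acc': 0.96}

# The test set has 50 samples (see "50/50" above), so 0.96 accuracy means 48 correct.
n_test = 50
n_correct = round(scores['acc'] * n_test)
print(n_correct)  # 48
```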

In [141]:
from sklearn.metrics import confusion_matrix,classification_report

In [142]:
predictions = model.predict_classes(scaled_X_test)

In [161]:
predictions


Out[161]:
array([1, 0, 2, 1, 2, 0, 1, 2, 2, 1, 2, 0, 0, 0, 0, 1, 2, 1, 1, 2, 0, 2,
       0, 2, 2, 2, 2, 2, 0, 0, 0, 0, 1, 0, 0, 2, 1, 0, 0, 0, 2, 1, 1, 0,
       0, 1, 2, 2, 1, 2], dtype=int64)

In [163]:
y_test.argmax(axis=1)


Out[163]:
array([1, 0, 2, 1, 1, 0, 1, 2, 1, 1, 2, 0, 0, 0, 0, 1, 2, 1, 1, 2, 0, 2,
       0, 2, 2, 2, 2, 2, 0, 0, 0, 0, 1, 0, 0, 2, 1, 0, 0, 0, 2, 1, 1, 0,
       0, 1, 2, 2, 1, 2], dtype=int64)
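`y_test` is one-hot encoded, so `argmax(axis=1)` returns the position of the single 1 in each row. This puts the true labels in the same form as `predictions`. Comparing the two arrays element by element then shows exactly which test samples were misclassified. The sketch below uses the arrays printed in Out[161] and Out[163].

```python
import numpy as np

# Predicted class indices, copied from Out[161].
predictions = np.array([
    1, 0, 2, 1, 2, 0, 1, 2, 2, 1, 2, 0, 0, 0, 0, 1, 2, 1, 1, 2, 0, 2,
    0, 2, 2, 2, 2, 2, 0, 0, 0, 0, 1, 0, 0, 2, 1, 0, 0, 0, 2, 1, 1, 0,
    0, 1, 2, 2, 1, 2,
])

# True class indices, copied from Out[163] (y_test.argmax(axis=1)).
true_classes = np.array([
    1, 0, 2, 1, 1, 0, 1, 2, 1, 1, 2, 0, 0, 0, 0, 1, 2, 1, 1, 2, 0, 2,
    0, 2, 2, 2, 2, 2, 0, 0, 0, 0, 1, 0, 0, 2, 1, 0, 0, 0, 2, 1, 1, 0,
    0, 1, 2, 2, 1, 2,
])

# argmax on a one-hot row returns the column holding the 1.
one_hot_example = np.array([[0, 1, 0], [0, 0, 1]])
print(one_hot_example.argmax(axis=1))  # [1 2]

# Positions where the prediction disagrees with the true label.
wrong = np.flatnonzero(predictions != true_classes)
print(wrong)                                     # [4 8]
print(true_classes[wrong], predictions[wrong])   # [1 1] [2 2]
```

Both mistakes are class-1 flowers predicted as class 2.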

In [164]:
confusion_matrix(y_test.argmax(axis=1),predictions)


Out[164]:
array([[19,  0,  0],
       [ 0, 13,  2],
       [ 0,  0, 16]], dtype=int64)
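In scikit-learn's `confusion_matrix`, row i counts samples whose true class is i, and column j counts samples predicted as class j. The 2 in row 1, column 2 therefore means two class-1 flowers were predicted as class 2, and the diagonal holds the correct predictions. The minimal sketch below builds the same kind of table by counting (true, predicted) pairs. The six labels are made up for illustration, and `confusion_counts` is a hypothetical helper, not a library function.

```python
import numpy as np

def confusion_counts(y_true, y_pred, n_classes):
    """Hypothetical helper: rows are true classes, columns are predicted classes."""
    matrix = np.zeros((n_classes, n_classes), dtype=int)
    # np.add.at accumulates correctly even when the same cell is hit repeatedly.
    np.add.at(matrix, (y_true, y_pred), 1)
    return matrix

# Made-up labels for six samples.
y_true = np.array([0, 1, 1, 2, 2, 2])
y_pred = np.array([0, 1, 2, 2, 2, 1])

cm = confusion_counts(y_true, y_pred, n_classes=3)
print(cm)
# [[1 0 0]
#  [0 1 1]
#  [0 1 2]]

# The trace (sum of the diagonal) is the number of correct predictions.
print(cm.trace())  # 4
```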

In [165]:
print(classification_report(y_test.argmax(axis=1),predictions))


              precision    recall  f1-score   support

           0       1.00      1.00      1.00        19
           1       1.00      0.87      0.93        15
           2       0.89      1.00      0.94        16

   micro avg       0.96      0.96      0.96        50
   macro avg       0.96      0.96      0.96        50
weighted avg       0.96      0.96      0.96        50
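Every per-class number in the report can be recomputed from the confusion matrix in Out[164]. Precision divides the diagonal by the column sums, which count everything predicted as that class. Recall divides it by the row sums, which count everything that truly belongs to that class. F1 is the harmonic mean of precision and recall, and the "macro avg" row is the unweighted mean over classes. The NumPy sketch below shows the arithmetic.

```python
import numpy as np

# Confusion matrix copied from Out[164]: rows = true class, columns = predicted class.
cm = np.array([
    [19,  0,  0],
    [ 0, 13,  2],
    [ 0,  0, 16],
])

true_positives = np.diag(cm)
precision = true_positives / cm.sum(axis=0)  # column sums: all predictions of each class
recall = true_positives / cm.sum(axis=1)     # row sums: all true members of each class
f1 = 2 * precision * recall / (precision + recall)
accuracy = true_positives.sum() / cm.sum()

# Rounded to 2 decimals these match the report:
# precision 1.00, 1.00, 0.89 / recall 1.00, 0.87, 1.00 / f1 1.00, 0.93, 0.94
print(precision, recall, f1)
print(accuracy)  # 0.96

# Macro averages: plain mean over the three classes (about 0.96 each).
print(precision.mean(), recall.mean(), f1.mean())
```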

Saving and Loading Models

Now that we have a trained model, let's see how to save it to a file and load it back.


In [166]:
model.save('myfirstmodel.h5')

In [167]:
from keras.models import load_model

In [168]:
newmodel = load_model('myfirstmodel.h5')

In [169]:
# Caution: X_test is the unscaled test set; pass scaled_X_test to reproduce Out[138].
newmodel.predict_classes(X_test)


Out[169]:
array([2, 1, 2, 2, 2, 1, 2, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 2, 1, 2,
       1, 2, 2, 2, 2, 2, 1, 1, 1, 1, 2, 1, 1, 2, 2, 1, 1, 1, 2, 2, 2, 1,
       1, 2, 2, 2, 2, 2], dtype=int64)

Great job! You now know how to preprocess data, train a neural network, and evaluate its classification performance!