In [145]:
import numpy as np
We will use the famous Iris Data set.
More info on the data set: https://en.wikipedia.org/wiki/Iris_flower_data_set
The dataset ships with scikit-learn, so there's nothing to download. Let's open it up.
In [146]:
from sklearn.datasets import load_iris
In [147]:
iris = load_iris()
In [148]:
type(iris)
Out[148]:
sklearn.utils.Bunch
In [149]:
print(iris.DESCR)
.. _iris_dataset:
Iris plants dataset
--------------------
**Data Set Characteristics:**
:Number of Instances: 150 (50 in each of three classes)
:Number of Attributes: 4 numeric, predictive attributes and the class
:Attribute Information:
- sepal length in cm
- sepal width in cm
- petal length in cm
- petal width in cm
- class:
- Iris-Setosa
- Iris-Versicolour
- Iris-Virginica
:Summary Statistics:
============== ==== ==== ======= ===== ====================
                Min  Max   Mean    SD   Class Correlation
============== ==== ==== ======= ===== ====================
sepal length:   4.3  7.9   5.84   0.83    0.7826
sepal width:    2.0  4.4   3.05   0.43   -0.4194
petal length:   1.0  6.9   3.76   1.76    0.9490  (high!)
petal width:    0.1  2.5   1.20   0.76    0.9565  (high!)
============== ==== ==== ======= ===== ====================
:Missing Attribute Values: None
:Class Distribution: 33.3% for each of 3 classes.
:Creator: R.A. Fisher
:Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)
:Date: July, 1988
The famous Iris database, first used by Sir R.A. Fisher. The dataset is taken
from Fisher's paper. Note that it's the same as in R, but not as in the UCI
Machine Learning Repository, which has two wrong data points.
This is perhaps the best known database to be found in the
pattern recognition literature. Fisher's paper is a classic in the field and
is referenced frequently to this day. (See Duda & Hart, for example.) The
data set contains 3 classes of 50 instances each, where each class refers to a
type of iris plant. One class is linearly separable from the other 2; the
latter are NOT linearly separable from each other.
.. topic:: References
- Fisher, R.A. "The use of multiple measurements in taxonomic problems"
Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions to
Mathematical Statistics" (John Wiley, NY, 1950).
- Duda, R.O., & Hart, P.E. (1973) Pattern Classification and Scene Analysis.
(Q327.D83) John Wiley & Sons. ISBN 0-471-22361-1. See page 218.
- Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New System
Structure and Classification Rule for Recognition in Partially Exposed
Environments". IEEE Transactions on Pattern Analysis and Machine
Intelligence, Vol. PAMI-2, No. 1, 67-71.
- Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule". IEEE Transactions
on Information Theory, May 1972, 431-433.
- See also: 1988 MLC Proceedings, 54-64. Cheeseman et al.'s AUTOCLASS II
conceptual clustering system finds 3 classes in the data.
- Many, many more ...
In [150]:
X = iris.data
In [151]:
X
Out[151]:
array([[5.1, 3.5, 1.4, 0.2],
[4.9, 3. , 1.4, 0.2],
[4.7, 3.2, 1.3, 0.2],
[4.6, 3.1, 1.5, 0.2],
[5. , 3.6, 1.4, 0.2],
[5.4, 3.9, 1.7, 0.4],
[4.6, 3.4, 1.4, 0.3],
[5. , 3.4, 1.5, 0.2],
[4.4, 2.9, 1.4, 0.2],
[4.9, 3.1, 1.5, 0.1],
[5.4, 3.7, 1.5, 0.2],
[4.8, 3.4, 1.6, 0.2],
[4.8, 3. , 1.4, 0.1],
[4.3, 3. , 1.1, 0.1],
[5.8, 4. , 1.2, 0.2],
[5.7, 4.4, 1.5, 0.4],
[5.4, 3.9, 1.3, 0.4],
[5.1, 3.5, 1.4, 0.3],
[5.7, 3.8, 1.7, 0.3],
[5.1, 3.8, 1.5, 0.3],
[5.4, 3.4, 1.7, 0.2],
[5.1, 3.7, 1.5, 0.4],
[4.6, 3.6, 1. , 0.2],
[5.1, 3.3, 1.7, 0.5],
[4.8, 3.4, 1.9, 0.2],
[5. , 3. , 1.6, 0.2],
[5. , 3.4, 1.6, 0.4],
[5.2, 3.5, 1.5, 0.2],
[5.2, 3.4, 1.4, 0.2],
[4.7, 3.2, 1.6, 0.2],
[4.8, 3.1, 1.6, 0.2],
[5.4, 3.4, 1.5, 0.4],
[5.2, 4.1, 1.5, 0.1],
[5.5, 4.2, 1.4, 0.2],
[4.9, 3.1, 1.5, 0.2],
[5. , 3.2, 1.2, 0.2],
[5.5, 3.5, 1.3, 0.2],
[4.9, 3.6, 1.4, 0.1],
[4.4, 3. , 1.3, 0.2],
[5.1, 3.4, 1.5, 0.2],
[5. , 3.5, 1.3, 0.3],
[4.5, 2.3, 1.3, 0.3],
[4.4, 3.2, 1.3, 0.2],
[5. , 3.5, 1.6, 0.6],
[5.1, 3.8, 1.9, 0.4],
[4.8, 3. , 1.4, 0.3],
[5.1, 3.8, 1.6, 0.2],
[4.6, 3.2, 1.4, 0.2],
[5.3, 3.7, 1.5, 0.2],
[5. , 3.3, 1.4, 0.2],
[7. , 3.2, 4.7, 1.4],
[6.4, 3.2, 4.5, 1.5],
[6.9, 3.1, 4.9, 1.5],
[5.5, 2.3, 4. , 1.3],
[6.5, 2.8, 4.6, 1.5],
[5.7, 2.8, 4.5, 1.3],
[6.3, 3.3, 4.7, 1.6],
[4.9, 2.4, 3.3, 1. ],
[6.6, 2.9, 4.6, 1.3],
[5.2, 2.7, 3.9, 1.4],
[5. , 2. , 3.5, 1. ],
[5.9, 3. , 4.2, 1.5],
[6. , 2.2, 4. , 1. ],
[6.1, 2.9, 4.7, 1.4],
[5.6, 2.9, 3.6, 1.3],
[6.7, 3.1, 4.4, 1.4],
[5.6, 3. , 4.5, 1.5],
[5.8, 2.7, 4.1, 1. ],
[6.2, 2.2, 4.5, 1.5],
[5.6, 2.5, 3.9, 1.1],
[5.9, 3.2, 4.8, 1.8],
[6.1, 2.8, 4. , 1.3],
[6.3, 2.5, 4.9, 1.5],
[6.1, 2.8, 4.7, 1.2],
[6.4, 2.9, 4.3, 1.3],
[6.6, 3. , 4.4, 1.4],
[6.8, 2.8, 4.8, 1.4],
[6.7, 3. , 5. , 1.7],
[6. , 2.9, 4.5, 1.5],
[5.7, 2.6, 3.5, 1. ],
[5.5, 2.4, 3.8, 1.1],
[5.5, 2.4, 3.7, 1. ],
[5.8, 2.7, 3.9, 1.2],
[6. , 2.7, 5.1, 1.6],
[5.4, 3. , 4.5, 1.5],
[6. , 3.4, 4.5, 1.6],
[6.7, 3.1, 4.7, 1.5],
[6.3, 2.3, 4.4, 1.3],
[5.6, 3. , 4.1, 1.3],
[5.5, 2.5, 4. , 1.3],
[5.5, 2.6, 4.4, 1.2],
[6.1, 3. , 4.6, 1.4],
[5.8, 2.6, 4. , 1.2],
[5. , 2.3, 3.3, 1. ],
[5.6, 2.7, 4.2, 1.3],
[5.7, 3. , 4.2, 1.2],
[5.7, 2.9, 4.2, 1.3],
[6.2, 2.9, 4.3, 1.3],
[5.1, 2.5, 3. , 1.1],
[5.7, 2.8, 4.1, 1.3],
[6.3, 3.3, 6. , 2.5],
[5.8, 2.7, 5.1, 1.9],
[7.1, 3. , 5.9, 2.1],
[6.3, 2.9, 5.6, 1.8],
[6.5, 3. , 5.8, 2.2],
[7.6, 3. , 6.6, 2.1],
[4.9, 2.5, 4.5, 1.7],
[7.3, 2.9, 6.3, 1.8],
[6.7, 2.5, 5.8, 1.8],
[7.2, 3.6, 6.1, 2.5],
[6.5, 3.2, 5.1, 2. ],
[6.4, 2.7, 5.3, 1.9],
[6.8, 3. , 5.5, 2.1],
[5.7, 2.5, 5. , 2. ],
[5.8, 2.8, 5.1, 2.4],
[6.4, 3.2, 5.3, 2.3],
[6.5, 3. , 5.5, 1.8],
[7.7, 3.8, 6.7, 2.2],
[7.7, 2.6, 6.9, 2.3],
[6. , 2.2, 5. , 1.5],
[6.9, 3.2, 5.7, 2.3],
[5.6, 2.8, 4.9, 2. ],
[7.7, 2.8, 6.7, 2. ],
[6.3, 2.7, 4.9, 1.8],
[6.7, 3.3, 5.7, 2.1],
[7.2, 3.2, 6. , 1.8],
[6.2, 2.8, 4.8, 1.8],
[6.1, 3. , 4.9, 1.8],
[6.4, 2.8, 5.6, 2.1],
[7.2, 3. , 5.8, 1.6],
[7.4, 2.8, 6.1, 1.9],
[7.9, 3.8, 6.4, 2. ],
[6.4, 2.8, 5.6, 2.2],
[6.3, 2.8, 5.1, 1.5],
[6.1, 2.6, 5.6, 1.4],
[7.7, 3. , 6.1, 2.3],
[6.3, 3.4, 5.6, 2.4],
[6.4, 3.1, 5.5, 1.8],
[6. , 3. , 4.8, 1.8],
[6.9, 3.1, 5.4, 2.1],
[6.7, 3.1, 5.6, 2.4],
[6.9, 3.1, 5.1, 2.3],
[5.8, 2.7, 5.1, 1.9],
[6.8, 3.2, 5.9, 2.3],
[6.7, 3.3, 5.7, 2.5],
[6.7, 3. , 5.2, 2.3],
[6.3, 2.5, 5. , 1.9],
[6.5, 3. , 5.2, 2. ],
[6.2, 3.4, 5.4, 2.3],
[5.9, 3. , 5.1, 1.8]])
In [152]:
y = iris.target
In [153]:
y
Out[153]:
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
In [154]:
from keras.utils import to_categorical
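(Depending on your Keras version, this import may live at tensorflow.keras.utils instead; the function behaves the same either way.)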
In [156]:
y = to_categorical(y)
In [157]:
y.shape
Out[157]:
(150, 3)
In [158]:
y
Out[158]:
array([[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.]], dtype=float32)
In [159]:
from sklearn.model_selection import train_test_split
In [160]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)
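Since the three classes are perfectly balanced here, a plain random split works fine. For skewed class distributions you might instead ask train_test_split to preserve the class proportions via its stratify argument; a minimal sketch using the same variables as above:

# Hypothetical variant: keep class proportions equal in train and test.
# y is one-hot encoded at this point, so we stratify on the integer labels.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.33, random_state=42, stratify=y.argmax(axis=1))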
In [111]:
X_train
Out[111]:
array([[5.7, 2.9, 4.2, 1.3],
[7.6, 3. , 6.6, 2.1],
[5.6, 3. , 4.5, 1.5],
[5.1, 3.5, 1.4, 0.2],
[7.7, 2.8, 6.7, 2. ],
[5.8, 2.7, 4.1, 1. ],
[5.2, 3.4, 1.4, 0.2],
[5. , 3.5, 1.3, 0.3],
[5.1, 3.8, 1.9, 0.4],
[5. , 2. , 3.5, 1. ],
[6.3, 2.7, 4.9, 1.8],
[4.8, 3.4, 1.9, 0.2],
[5. , 3. , 1.6, 0.2],
[5.1, 3.3, 1.7, 0.5],
[5.6, 2.7, 4.2, 1.3],
[5.1, 3.4, 1.5, 0.2],
[5.7, 3. , 4.2, 1.2],
[7.7, 3.8, 6.7, 2.2],
[4.6, 3.2, 1.4, 0.2],
[6.2, 2.9, 4.3, 1.3],
[5.7, 2.5, 5. , 2. ],
[5.5, 4.2, 1.4, 0.2],
[6. , 3. , 4.8, 1.8],
[5.8, 2.7, 5.1, 1.9],
[6. , 2.2, 4. , 1. ],
[5.4, 3. , 4.5, 1.5],
[6.2, 3.4, 5.4, 2.3],
[5.5, 2.3, 4. , 1.3],
[5.4, 3.9, 1.7, 0.4],
[5. , 2.3, 3.3, 1. ],
[6.4, 2.7, 5.3, 1.9],
[5. , 3.3, 1.4, 0.2],
[5. , 3.2, 1.2, 0.2],
[5.5, 2.4, 3.8, 1.1],
[6.7, 3. , 5. , 1.7],
[4.9, 3.1, 1.5, 0.2],
[5.8, 2.8, 5.1, 2.4],
[5. , 3.4, 1.5, 0.2],
[5. , 3.5, 1.6, 0.6],
[5.9, 3.2, 4.8, 1.8],
[5.1, 2.5, 3. , 1.1],
[6.9, 3.2, 5.7, 2.3],
[6. , 2.7, 5.1, 1.6],
[6.1, 2.6, 5.6, 1.4],
[7.7, 3. , 6.1, 2.3],
[5.5, 2.5, 4. , 1.3],
[4.4, 2.9, 1.4, 0.2],
[4.3, 3. , 1.1, 0.1],
[6. , 2.2, 5. , 1.5],
[7.2, 3.2, 6. , 1.8],
[4.6, 3.1, 1.5, 0.2],
[5.1, 3.5, 1.4, 0.3],
[4.4, 3. , 1.3, 0.2],
[6.3, 2.5, 4.9, 1.5],
[6.3, 3.4, 5.6, 2.4],
[4.6, 3.4, 1.4, 0.3],
[6.8, 3. , 5.5, 2.1],
[6.3, 3.3, 6. , 2.5],
[4.7, 3.2, 1.3, 0.2],
[6.1, 2.9, 4.7, 1.4],
[6.5, 2.8, 4.6, 1.5],
[6.2, 2.8, 4.8, 1.8],
[7. , 3.2, 4.7, 1.4],
[6.4, 3.2, 5.3, 2.3],
[5.1, 3.8, 1.6, 0.2],
[6.9, 3.1, 5.4, 2.1],
[5.9, 3. , 4.2, 1.5],
[6.5, 3. , 5.2, 2. ],
[5.7, 2.6, 3.5, 1. ],
[5.2, 2.7, 3.9, 1.4],
[6.1, 3. , 4.6, 1.4],
[4.5, 2.3, 1.3, 0.3],
[6.6, 2.9, 4.6, 1.3],
[5.5, 2.6, 4.4, 1.2],
[5.3, 3.7, 1.5, 0.2],
[5.6, 3. , 4.1, 1.3],
[7.3, 2.9, 6.3, 1.8],
[6.7, 3.3, 5.7, 2.1],
[5.1, 3.7, 1.5, 0.4],
[4.9, 2.4, 3.3, 1. ],
[6.7, 3.3, 5.7, 2.5],
[7.2, 3. , 5.8, 1.6],
[4.9, 3.6, 1.4, 0.1],
[6.7, 3.1, 5.6, 2.4],
[4.9, 3. , 1.4, 0.2],
[6.9, 3.1, 4.9, 1.5],
[7.4, 2.8, 6.1, 1.9],
[6.3, 2.9, 5.6, 1.8],
[5.7, 2.8, 4.1, 1.3],
[6.5, 3. , 5.5, 1.8],
[6.3, 2.3, 4.4, 1.3],
[6.4, 2.9, 4.3, 1.3],
[5.6, 2.8, 4.9, 2. ],
[5.9, 3. , 5.1, 1.8],
[5.4, 3.4, 1.7, 0.2],
[6.1, 2.8, 4. , 1.3],
[4.9, 2.5, 4.5, 1.7],
[5.8, 4. , 1.2, 0.2],
[5.8, 2.6, 4. , 1.2],
[7.1, 3. , 5.9, 2.1]])
In [112]:
X_test
Out[112]:
array([[6.1, 2.8, 4.7, 1.2],
[5.7, 3.8, 1.7, 0.3],
[7.7, 2.6, 6.9, 2.3],
[6. , 2.9, 4.5, 1.5],
[6.8, 2.8, 4.8, 1.4],
[5.4, 3.4, 1.5, 0.4],
[5.6, 2.9, 3.6, 1.3],
[6.9, 3.1, 5.1, 2.3],
[6.2, 2.2, 4.5, 1.5],
[5.8, 2.7, 3.9, 1.2],
[6.5, 3.2, 5.1, 2. ],
[4.8, 3. , 1.4, 0.1],
[5.5, 3.5, 1.3, 0.2],
[4.9, 3.1, 1.5, 0.1],
[5.1, 3.8, 1.5, 0.3],
[6.3, 3.3, 4.7, 1.6],
[6.5, 3. , 5.8, 2.2],
[5.6, 2.5, 3.9, 1.1],
[5.7, 2.8, 4.5, 1.3],
[6.4, 2.8, 5.6, 2.2],
[4.7, 3.2, 1.6, 0.2],
[6.1, 3. , 4.9, 1.8],
[5. , 3.4, 1.6, 0.4],
[6.4, 2.8, 5.6, 2.1],
[7.9, 3.8, 6.4, 2. ],
[6.7, 3. , 5.2, 2.3],
[6.7, 2.5, 5.8, 1.8],
[6.8, 3.2, 5.9, 2.3],
[4.8, 3. , 1.4, 0.3],
[4.8, 3.1, 1.6, 0.2],
[4.6, 3.6, 1. , 0.2],
[5.7, 4.4, 1.5, 0.4],
[6.7, 3.1, 4.4, 1.4],
[4.8, 3.4, 1.6, 0.2],
[4.4, 3.2, 1.3, 0.2],
[6.3, 2.5, 5. , 1.9],
[6.4, 3.2, 4.5, 1.5],
[5.2, 3.5, 1.5, 0.2],
[5. , 3.6, 1.4, 0.2],
[5.2, 4.1, 1.5, 0.1],
[5.8, 2.7, 5.1, 1.9],
[6. , 3.4, 4.5, 1.6],
[6.7, 3.1, 4.7, 1.5],
[5.4, 3.9, 1.3, 0.4],
[5.4, 3.7, 1.5, 0.2],
[5.5, 2.4, 3.7, 1. ],
[6.3, 2.8, 5.1, 1.5],
[6.4, 3.1, 5.5, 1.8],
[6.6, 3. , 4.4, 1.4],
[7.2, 3.6, 6.1, 2.5]])
In [113]:
y_train
Out[113]:
array([[0., 1., 0.],
[0., 0., 1.],
[0., 1., 0.],
[1., 0., 0.],
[0., 0., 1.],
[0., 1., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[0., 1., 0.],
[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.],
[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.],
[1., 0., 0.],
[0., 0., 1.],
[0., 0., 1.],
[0., 1., 0.],
[0., 1., 0.],
[0., 0., 1.],
[0., 1., 0.],
[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.],
[1., 0., 0.],
[1., 0., 0.],
[0., 1., 0.],
[0., 1., 0.],
[1., 0., 0.],
[0., 0., 1.],
[1., 0., 0.],
[1., 0., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 0., 1.],
[0., 1., 0.],
[0., 0., 1.],
[0., 0., 1.],
[0., 1., 0.],
[1., 0., 0.],
[1., 0., 0.],
[0., 0., 1.],
[0., 0., 1.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.],
[1., 0., 0.],
[0., 0., 1.],
[0., 0., 1.],
[1., 0., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 0., 1.],
[0., 1., 0.],
[0., 0., 1.],
[1., 0., 0.],
[0., 0., 1.],
[0., 1., 0.],
[0., 0., 1.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[1., 0., 0.],
[0., 1., 0.],
[0., 1., 0.],
[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.],
[0., 0., 1.],
[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.],
[0., 0., 1.],
[1., 0., 0.],
[0., 0., 1.],
[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.],
[0., 0., 1.],
[0., 1., 0.],
[0., 0., 1.],
[0., 1., 0.],
[0., 1., 0.],
[0., 0., 1.],
[0., 0., 1.],
[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.],
[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.]], dtype=float32)
In [114]:
y_test
Out[114]:
array([[0., 1., 0.],
[1., 0., 0.],
[0., 0., 1.],
[0., 1., 0.],
[0., 1., 0.],
[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.],
[0., 1., 0.],
[0., 1., 0.],
[0., 0., 1.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.],
[0., 1., 0.],
[0., 1., 0.],
[0., 0., 1.],
[1., 0., 0.],
[0., 0., 1.],
[1., 0., 0.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[0., 1., 0.],
[1., 0., 0.],
[1., 0., 0.],
[0., 0., 1.],
[0., 1., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[0., 0., 1.],
[0., 1., 0.],
[0., 1., 0.],
[1., 0., 0.],
[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.],
[0., 0., 1.],
[0., 1., 0.],
[0., 0., 1.]], dtype=float32)
Neural networks usually train better when all input features are on a common scale. What we'll do below is min-max normalization: squashing each feature into a fixed range such as 0 to 1. (Strictly speaking, "standardization" usually refers to rescaling to zero mean and unit variance instead; either approach generally helps.)
scikit-learn provides a convenient class for this.
http://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.MinMaxScaler.html
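Under the hood, MinMaxScaler applies a simple per-feature linear map. Here is a minimal sketch of the formula it implements for the default 0-1 range (the real class also handles edge cases such as constant columns):

# Per-feature min-max scaling: x' = (x - min) / (max - min)
col_min = X_train.min(axis=0)
col_max = X_train.max(axis=0)
manually_scaled = (X_train - col_min) / (col_max - col_min)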
In [115]:
from sklearn.preprocessing import MinMaxScaler
In [116]:
scaler_object = MinMaxScaler()
In [117]:
scaler_object.fit(X_train)
Out[117]:
MinMaxScaler(copy=True, feature_range=(0, 1))
In [118]:
scaled_X_train = scaler_object.transform(X_train)
In [119]:
scaled_X_test = scaler_object.transform(X_test)
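Note that we fit the scaler on the training data only and then reuse it to transform the test set. Fitting on the test data (or on the full dataset) would leak information about the test distribution into the preprocessing step.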
Ok, now we have the data scaled!
In [120]:
X_train.max()
Out[120]:
7.7
In [121]:
scaled_X_train.max()
Out[121]:
1.0
In [122]:
X_train
Out[122]:
array([[5.7, 2.9, 4.2, 1.3],
[7.6, 3. , 6.6, 2.1],
[5.6, 3. , 4.5, 1.5],
[5.1, 3.5, 1.4, 0.2],
[7.7, 2.8, 6.7, 2. ],
[5.8, 2.7, 4.1, 1. ],
[5.2, 3.4, 1.4, 0.2],
[5. , 3.5, 1.3, 0.3],
[5.1, 3.8, 1.9, 0.4],
[5. , 2. , 3.5, 1. ],
[6.3, 2.7, 4.9, 1.8],
[4.8, 3.4, 1.9, 0.2],
[5. , 3. , 1.6, 0.2],
[5.1, 3.3, 1.7, 0.5],
[5.6, 2.7, 4.2, 1.3],
[5.1, 3.4, 1.5, 0.2],
[5.7, 3. , 4.2, 1.2],
[7.7, 3.8, 6.7, 2.2],
[4.6, 3.2, 1.4, 0.2],
[6.2, 2.9, 4.3, 1.3],
[5.7, 2.5, 5. , 2. ],
[5.5, 4.2, 1.4, 0.2],
[6. , 3. , 4.8, 1.8],
[5.8, 2.7, 5.1, 1.9],
[6. , 2.2, 4. , 1. ],
[5.4, 3. , 4.5, 1.5],
[6.2, 3.4, 5.4, 2.3],
[5.5, 2.3, 4. , 1.3],
[5.4, 3.9, 1.7, 0.4],
[5. , 2.3, 3.3, 1. ],
[6.4, 2.7, 5.3, 1.9],
[5. , 3.3, 1.4, 0.2],
[5. , 3.2, 1.2, 0.2],
[5.5, 2.4, 3.8, 1.1],
[6.7, 3. , 5. , 1.7],
[4.9, 3.1, 1.5, 0.2],
[5.8, 2.8, 5.1, 2.4],
[5. , 3.4, 1.5, 0.2],
[5. , 3.5, 1.6, 0.6],
[5.9, 3.2, 4.8, 1.8],
[5.1, 2.5, 3. , 1.1],
[6.9, 3.2, 5.7, 2.3],
[6. , 2.7, 5.1, 1.6],
[6.1, 2.6, 5.6, 1.4],
[7.7, 3. , 6.1, 2.3],
[5.5, 2.5, 4. , 1.3],
[4.4, 2.9, 1.4, 0.2],
[4.3, 3. , 1.1, 0.1],
[6. , 2.2, 5. , 1.5],
[7.2, 3.2, 6. , 1.8],
[4.6, 3.1, 1.5, 0.2],
[5.1, 3.5, 1.4, 0.3],
[4.4, 3. , 1.3, 0.2],
[6.3, 2.5, 4.9, 1.5],
[6.3, 3.4, 5.6, 2.4],
[4.6, 3.4, 1.4, 0.3],
[6.8, 3. , 5.5, 2.1],
[6.3, 3.3, 6. , 2.5],
[4.7, 3.2, 1.3, 0.2],
[6.1, 2.9, 4.7, 1.4],
[6.5, 2.8, 4.6, 1.5],
[6.2, 2.8, 4.8, 1.8],
[7. , 3.2, 4.7, 1.4],
[6.4, 3.2, 5.3, 2.3],
[5.1, 3.8, 1.6, 0.2],
[6.9, 3.1, 5.4, 2.1],
[5.9, 3. , 4.2, 1.5],
[6.5, 3. , 5.2, 2. ],
[5.7, 2.6, 3.5, 1. ],
[5.2, 2.7, 3.9, 1.4],
[6.1, 3. , 4.6, 1.4],
[4.5, 2.3, 1.3, 0.3],
[6.6, 2.9, 4.6, 1.3],
[5.5, 2.6, 4.4, 1.2],
[5.3, 3.7, 1.5, 0.2],
[5.6, 3. , 4.1, 1.3],
[7.3, 2.9, 6.3, 1.8],
[6.7, 3.3, 5.7, 2.1],
[5.1, 3.7, 1.5, 0.4],
[4.9, 2.4, 3.3, 1. ],
[6.7, 3.3, 5.7, 2.5],
[7.2, 3. , 5.8, 1.6],
[4.9, 3.6, 1.4, 0.1],
[6.7, 3.1, 5.6, 2.4],
[4.9, 3. , 1.4, 0.2],
[6.9, 3.1, 4.9, 1.5],
[7.4, 2.8, 6.1, 1.9],
[6.3, 2.9, 5.6, 1.8],
[5.7, 2.8, 4.1, 1.3],
[6.5, 3. , 5.5, 1.8],
[6.3, 2.3, 4.4, 1.3],
[6.4, 2.9, 4.3, 1.3],
[5.6, 2.8, 4.9, 2. ],
[5.9, 3. , 5.1, 1.8],
[5.4, 3.4, 1.7, 0.2],
[6.1, 2.8, 4. , 1.3],
[4.9, 2.5, 4.5, 1.7],
[5.8, 4. , 1.2, 0.2],
[5.8, 2.6, 4. , 1.2],
[7.1, 3. , 5.9, 2.1]])
In [123]:
scaled_X_train
Out[123]:
array([[0.41176471, 0.40909091, 0.55357143, 0.5 ],
[0.97058824, 0.45454545, 0.98214286, 0.83333333],
[0.38235294, 0.45454545, 0.60714286, 0.58333333],
[0.23529412, 0.68181818, 0.05357143, 0.04166667],
[1. , 0.36363636, 1. , 0.79166667],
[0.44117647, 0.31818182, 0.53571429, 0.375 ],
[0.26470588, 0.63636364, 0.05357143, 0.04166667],
[0.20588235, 0.68181818, 0.03571429, 0.08333333],
[0.23529412, 0.81818182, 0.14285714, 0.125 ],
[0.20588235, 0. , 0.42857143, 0.375 ],
[0.58823529, 0.31818182, 0.67857143, 0.70833333],
[0.14705882, 0.63636364, 0.14285714, 0.04166667],
[0.20588235, 0.45454545, 0.08928571, 0.04166667],
[0.23529412, 0.59090909, 0.10714286, 0.16666667],
[0.38235294, 0.31818182, 0.55357143, 0.5 ],
[0.23529412, 0.63636364, 0.07142857, 0.04166667],
[0.41176471, 0.45454545, 0.55357143, 0.45833333],
[1. , 0.81818182, 1. , 0.875 ],
[0.08823529, 0.54545455, 0.05357143, 0.04166667],
[0.55882353, 0.40909091, 0.57142857, 0.5 ],
[0.41176471, 0.22727273, 0.69642857, 0.79166667],
[0.35294118, 1. , 0.05357143, 0.04166667],
[0.5 , 0.45454545, 0.66071429, 0.70833333],
[0.44117647, 0.31818182, 0.71428571, 0.75 ],
[0.5 , 0.09090909, 0.51785714, 0.375 ],
[0.32352941, 0.45454545, 0.60714286, 0.58333333],
[0.55882353, 0.63636364, 0.76785714, 0.91666667],
[0.35294118, 0.13636364, 0.51785714, 0.5 ],
[0.32352941, 0.86363636, 0.10714286, 0.125 ],
[0.20588235, 0.13636364, 0.39285714, 0.375 ],
[0.61764706, 0.31818182, 0.75 , 0.75 ],
[0.20588235, 0.59090909, 0.05357143, 0.04166667],
[0.20588235, 0.54545455, 0.01785714, 0.04166667],
[0.35294118, 0.18181818, 0.48214286, 0.41666667],
[0.70588235, 0.45454545, 0.69642857, 0.66666667],
[0.17647059, 0.5 , 0.07142857, 0.04166667],
[0.44117647, 0.36363636, 0.71428571, 0.95833333],
[0.20588235, 0.63636364, 0.07142857, 0.04166667],
[0.20588235, 0.68181818, 0.08928571, 0.20833333],
[0.47058824, 0.54545455, 0.66071429, 0.70833333],
[0.23529412, 0.22727273, 0.33928571, 0.41666667],
[0.76470588, 0.54545455, 0.82142857, 0.91666667],
[0.5 , 0.31818182, 0.71428571, 0.625 ],
[0.52941176, 0.27272727, 0.80357143, 0.54166667],
[1. , 0.45454545, 0.89285714, 0.91666667],
[0.35294118, 0.22727273, 0.51785714, 0.5 ],
[0.02941176, 0.40909091, 0.05357143, 0.04166667],
[0. , 0.45454545, 0. , 0. ],
[0.5 , 0.09090909, 0.69642857, 0.58333333],
[0.85294118, 0.54545455, 0.875 , 0.70833333],
[0.08823529, 0.5 , 0.07142857, 0.04166667],
[0.23529412, 0.68181818, 0.05357143, 0.08333333],
[0.02941176, 0.45454545, 0.03571429, 0.04166667],
[0.58823529, 0.22727273, 0.67857143, 0.58333333],
[0.58823529, 0.63636364, 0.80357143, 0.95833333],
[0.08823529, 0.63636364, 0.05357143, 0.08333333],
[0.73529412, 0.45454545, 0.78571429, 0.83333333],
[0.58823529, 0.59090909, 0.875 , 1. ],
[0.11764706, 0.54545455, 0.03571429, 0.04166667],
[0.52941176, 0.40909091, 0.64285714, 0.54166667],
[0.64705882, 0.36363636, 0.625 , 0.58333333],
[0.55882353, 0.36363636, 0.66071429, 0.70833333],
[0.79411765, 0.54545455, 0.64285714, 0.54166667],
[0.61764706, 0.54545455, 0.75 , 0.91666667],
[0.23529412, 0.81818182, 0.08928571, 0.04166667],
[0.76470588, 0.5 , 0.76785714, 0.83333333],
[0.47058824, 0.45454545, 0.55357143, 0.58333333],
[0.64705882, 0.45454545, 0.73214286, 0.79166667],
[0.41176471, 0.27272727, 0.42857143, 0.375 ],
[0.26470588, 0.31818182, 0.5 , 0.54166667],
[0.52941176, 0.45454545, 0.625 , 0.54166667],
[0.05882353, 0.13636364, 0.03571429, 0.08333333],
[0.67647059, 0.40909091, 0.625 , 0.5 ],
[0.35294118, 0.27272727, 0.58928571, 0.45833333],
[0.29411765, 0.77272727, 0.07142857, 0.04166667],
[0.38235294, 0.45454545, 0.53571429, 0.5 ],
[0.88235294, 0.40909091, 0.92857143, 0.70833333],
[0.70588235, 0.59090909, 0.82142857, 0.83333333],
[0.23529412, 0.77272727, 0.07142857, 0.125 ],
[0.17647059, 0.18181818, 0.39285714, 0.375 ],
[0.70588235, 0.59090909, 0.82142857, 1. ],
[0.85294118, 0.45454545, 0.83928571, 0.625 ],
[0.17647059, 0.72727273, 0.05357143, 0. ],
[0.70588235, 0.5 , 0.80357143, 0.95833333],
[0.17647059, 0.45454545, 0.05357143, 0.04166667],
[0.76470588, 0.5 , 0.67857143, 0.58333333],
[0.91176471, 0.36363636, 0.89285714, 0.75 ],
[0.58823529, 0.40909091, 0.80357143, 0.70833333],
[0.41176471, 0.36363636, 0.53571429, 0.5 ],
[0.64705882, 0.45454545, 0.78571429, 0.70833333],
[0.58823529, 0.13636364, 0.58928571, 0.5 ],
[0.61764706, 0.40909091, 0.57142857, 0.5 ],
[0.38235294, 0.36363636, 0.67857143, 0.79166667],
[0.47058824, 0.45454545, 0.71428571, 0.70833333],
[0.32352941, 0.63636364, 0.10714286, 0.04166667],
[0.52941176, 0.36363636, 0.51785714, 0.5 ],
[0.17647059, 0.22727273, 0.60714286, 0.66666667],
[0.44117647, 0.90909091, 0.01785714, 0.04166667],
[0.44117647, 0.27272727, 0.51785714, 0.45833333],
[0.82352941, 0.45454545, 0.85714286, 0.83333333]])
In [132]:
from keras.models import Sequential
from keras.layers import Dense
In [133]:
model = Sequential()
model.add(Dense(8, input_dim=4, activation='relu'))
model.add(Dense(8, activation='relu'))  # input_dim is only needed on the first layer
model.add(Dense(3, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
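A quick note on the compile choices: categorical_crossentropy matches our one-hot encoded targets and the 3-unit softmax output. Had we left y as integer labels (0, 1, 2), the equivalent setup would use the sparse variant of the loss, roughly:

# Hypothetical alternative if y were kept as integer class labels:
# model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])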
In [134]:
model.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
dense_27 (Dense)             (None, 8)                 40
_________________________________________________________________
dense_28 (Dense)             (None, 8)                 72
_________________________________________________________________
dense_29 (Dense)             (None, 3)                 27
=================================================================
Total params: 139
Trainable params: 139
Non-trainable params: 0
_________________________________________________________________
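The parameter counts follow directly from the layer sizes: a Dense layer has (inputs * units) weights plus one bias per unit.

# Dense params = inputs * units + units (biases)
4 * 8 + 8   # dense_27 -> 40
8 * 8 + 8   # dense_28 -> 72
8 * 3 + 3   # dense_29 -> 27, for 139 total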
In [135]:
# Play around with number of epochs as well!
model.fit(scaled_X_train,y_train,epochs=150, verbose=2)
Epoch 1/150
- 0s - loss: 1.0926 - acc: 0.3400
Epoch 2/150
- 0s - loss: 1.0871 - acc: 0.3400
Epoch 3/150
- 0s - loss: 1.0814 - acc: 0.3400
Epoch 4/150
- 0s - loss: 1.0760 - acc: 0.3400
Epoch 5/150
- 0s - loss: 1.0700 - acc: 0.3400
Epoch 6/150
- 0s - loss: 1.0640 - acc: 0.3500
Epoch 7/150
- 0s - loss: 1.0581 - acc: 0.3700
Epoch 8/150
- 0s - loss: 1.0520 - acc: 0.4200
Epoch 9/150
- 0s - loss: 1.0467 - acc: 0.5100
Epoch 10/150
- 0s - loss: 1.0416 - acc: 0.5700
Epoch 11/150
- 0s - loss: 1.0361 - acc: 0.6300
Epoch 12/150
- 0s - loss: 1.0308 - acc: 0.6300
Epoch 13/150
- 0s - loss: 1.0249 - acc: 0.6300
Epoch 14/150
- 0s - loss: 1.0187 - acc: 0.6200
Epoch 15/150
- 0s - loss: 1.0124 - acc: 0.6300
Epoch 16/150
- 0s - loss: 1.0062 - acc: 0.6300
Epoch 17/150
- 0s - loss: 0.9994 - acc: 0.6400
Epoch 18/150
- 0s - loss: 0.9919 - acc: 0.6400
Epoch 19/150
- 0s - loss: 0.9836 - acc: 0.6400
Epoch 20/150
- 0s - loss: 0.9748 - acc: 0.6400
Epoch 21/150
- 0s - loss: 0.9649 - acc: 0.6400
Epoch 22/150
- 0s - loss: 0.9552 - acc: 0.6400
Epoch 23/150
- 0s - loss: 0.9448 - acc: 0.6500
Epoch 24/150
- 0s - loss: 0.9337 - acc: 0.6500
Epoch 25/150
- 0s - loss: 0.9225 - acc: 0.6500
Epoch 26/150
- 0s - loss: 0.9111 - acc: 0.6500
Epoch 27/150
- 0s - loss: 0.8992 - acc: 0.6500
Epoch 28/150
- 0s - loss: 0.8882 - acc: 0.6500
Epoch 29/150
- 0s - loss: 0.8766 - acc: 0.6500
Epoch 30/150
- 0s - loss: 0.8658 - acc: 0.6500
Epoch 31/150
- 0s - loss: 0.8555 - acc: 0.6500
Epoch 32/150
- 0s - loss: 0.8446 - acc: 0.6500
Epoch 33/150
- 0s - loss: 0.8337 - acc: 0.6500
Epoch 34/150
- 0s - loss: 0.8225 - acc: 0.6500
Epoch 35/150
- 0s - loss: 0.8112 - acc: 0.6500
Epoch 36/150
- 0s - loss: 0.7998 - acc: 0.6500
Epoch 37/150
- 0s - loss: 0.7886 - acc: 0.6500
Epoch 38/150
- 0s - loss: 0.7768 - acc: 0.6500
Epoch 39/150
- 0s - loss: 0.7650 - acc: 0.6500
Epoch 40/150
- 0s - loss: 0.7539 - acc: 0.6500
Epoch 41/150
- 0s - loss: 0.7428 - acc: 0.6500
Epoch 42/150
- 0s - loss: 0.7318 - acc: 0.6500
Epoch 43/150
- 0s - loss: 0.7213 - acc: 0.6500
Epoch 44/150
- 0s - loss: 0.7111 - acc: 0.6500
Epoch 45/150
- 0s - loss: 0.7007 - acc: 0.6500
Epoch 46/150
- 0s - loss: 0.6907 - acc: 0.6500
Epoch 47/150
- 0s - loss: 0.6806 - acc: 0.6500
Epoch 48/150
- 0s - loss: 0.6714 - acc: 0.6500
Epoch 49/150
- 0s - loss: 0.6621 - acc: 0.6500
Epoch 50/150
- 0s - loss: 0.6534 - acc: 0.6500
Epoch 51/150
- 0s - loss: 0.6452 - acc: 0.6500
Epoch 52/150
- 0s - loss: 0.6366 - acc: 0.6500
Epoch 53/150
- 0s - loss: 0.6285 - acc: 0.6500
Epoch 54/150
- 0s - loss: 0.6205 - acc: 0.6500
Epoch 55/150
- 0s - loss: 0.6130 - acc: 0.6500
Epoch 56/150
- 0s - loss: 0.6058 - acc: 0.6500
Epoch 57/150
- 0s - loss: 0.5990 - acc: 0.6500
Epoch 58/150
- 0s - loss: 0.5920 - acc: 0.6500
Epoch 59/150
- 0s - loss: 0.5852 - acc: 0.6700
Epoch 60/150
- 0s - loss: 0.5790 - acc: 0.6700
Epoch 61/150
- 0s - loss: 0.5727 - acc: 0.6900
Epoch 62/150
- 0s - loss: 0.5665 - acc: 0.6900
Epoch 63/150
- 0s - loss: 0.5605 - acc: 0.6900
Epoch 64/150
- 0s - loss: 0.5544 - acc: 0.6900
Epoch 65/150
- 0s - loss: 0.5490 - acc: 0.6900
Epoch 66/150
- 0s - loss: 0.5436 - acc: 0.7000
Epoch 67/150
- 0s - loss: 0.5385 - acc: 0.7000
Epoch 68/150
- 0s - loss: 0.5334 - acc: 0.7000
Epoch 69/150
- 0s - loss: 0.5287 - acc: 0.7100
Epoch 70/150
- 0s - loss: 0.5243 - acc: 0.7200
Epoch 71/150
- 0s - loss: 0.5198 - acc: 0.7300
Epoch 72/150
- 0s - loss: 0.5150 - acc: 0.7300
Epoch 73/150
- 0s - loss: 0.5108 - acc: 0.7400
Epoch 74/150
- 0s - loss: 0.5066 - acc: 0.7400
Epoch 75/150
- 0s - loss: 0.5022 - acc: 0.7600
Epoch 76/150
- 0s - loss: 0.4986 - acc: 0.7500
Epoch 77/150
- 0s - loss: 0.4945 - acc: 0.7400
Epoch 78/150
- 0s - loss: 0.4908 - acc: 0.7500
Epoch 79/150
- 0s - loss: 0.4867 - acc: 0.7700
Epoch 80/150
- 0s - loss: 0.4830 - acc: 0.7700
Epoch 81/150
- 0s - loss: 0.4794 - acc: 0.7900
Epoch 82/150
- 0s - loss: 0.4761 - acc: 0.8500
Epoch 83/150
- 0s - loss: 0.4726 - acc: 0.8500
Epoch 84/150
- 0s - loss: 0.4693 - acc: 0.8500
Epoch 85/150
- 0s - loss: 0.4660 - acc: 0.8500
Epoch 86/150
- 0s - loss: 0.4628 - acc: 0.8500
Epoch 87/150
- 0s - loss: 0.4599 - acc: 0.8200
Epoch 88/150
- 0s - loss: 0.4575 - acc: 0.8000
Epoch 89/150
- 0s - loss: 0.4546 - acc: 0.7800
Epoch 90/150
- 0s - loss: 0.4521 - acc: 0.7800
Epoch 91/150
- 0s - loss: 0.4490 - acc: 0.8100
Epoch 92/150
- 0s - loss: 0.4460 - acc: 0.8300
Epoch 93/150
- 0s - loss: 0.4433 - acc: 0.8500
Epoch 94/150
- 0s - loss: 0.4407 - acc: 0.8500
Epoch 95/150
- 0s - loss: 0.4380 - acc: 0.8500
Epoch 96/150
- 0s - loss: 0.4352 - acc: 0.8600
Epoch 97/150
- 0s - loss: 0.4326 - acc: 0.8600
Epoch 98/150
- 0s - loss: 0.4302 - acc: 0.9000
Epoch 99/150
- 0s - loss: 0.4270 - acc: 0.9100
Epoch 100/150
- 0s - loss: 0.4245 - acc: 0.9200
Epoch 101/150
- 0s - loss: 0.4223 - acc: 0.9200
Epoch 102/150
- 0s - loss: 0.4195 - acc: 0.9200
Epoch 103/150
- 0s - loss: 0.4171 - acc: 0.9200
Epoch 104/150
- 0s - loss: 0.4147 - acc: 0.9200
Epoch 105/150
- 0s - loss: 0.4124 - acc: 0.9000
Epoch 106/150
- 0s - loss: 0.4102 - acc: 0.9000
Epoch 107/150
- 0s - loss: 0.4077 - acc: 0.9200
Epoch 108/150
- 0s - loss: 0.4055 - acc: 0.9200
Epoch 109/150
- 0s - loss: 0.4028 - acc: 0.9200
Epoch 110/150
- 0s - loss: 0.4006 - acc: 0.9200
Epoch 111/150
- 0s - loss: 0.3985 - acc: 0.9200
Epoch 112/150
- 0s - loss: 0.3965 - acc: 0.9200
Epoch 113/150
- 0s - loss: 0.3945 - acc: 0.9200
Epoch 114/150
- 0s - loss: 0.3930 - acc: 0.9200
Epoch 115/150
- 0s - loss: 0.3914 - acc: 0.9000
Epoch 116/150
- 0s - loss: 0.3891 - acc: 0.9100
Epoch 117/150
- 0s - loss: 0.3860 - acc: 0.9200
Epoch 118/150
- 0s - loss: 0.3844 - acc: 0.9200
Epoch 119/150
- 0s - loss: 0.3821 - acc: 0.9500
Epoch 120/150
- 0s - loss: 0.3801 - acc: 0.9500
Epoch 121/150
- 0s - loss: 0.3780 - acc: 0.9500
Epoch 122/150
- 0s - loss: 0.3757 - acc: 0.9500
Epoch 123/150
- 0s - loss: 0.3738 - acc: 0.9500
Epoch 124/150
- 0s - loss: 0.3720 - acc: 0.9500
Epoch 125/150
- 0s - loss: 0.3693 - acc: 0.9500
Epoch 126/150
- 0s - loss: 0.3662 - acc: 0.9500
Epoch 127/150
- 0s - loss: 0.3646 - acc: 0.9300
Epoch 128/150
- 0s - loss: 0.3652 - acc: 0.9200
Epoch 129/150
- 0s - loss: 0.3640 - acc: 0.9100
Epoch 130/150
- 0s - loss: 0.3627 - acc: 0.9100
Epoch 131/150
- 0s - loss: 0.3604 - acc: 0.9100
Epoch 132/150
- 0s - loss: 0.3572 - acc: 0.9300
Epoch 133/150
- 0s - loss: 0.3544 - acc: 0.9300
Epoch 134/150
- 0s - loss: 0.3521 - acc: 0.9300
Epoch 135/150
- 0s - loss: 0.3501 - acc: 0.9500
Epoch 136/150
- 0s - loss: 0.3482 - acc: 0.9500
Epoch 137/150
- 0s - loss: 0.3465 - acc: 0.9500
Epoch 138/150
- 0s - loss: 0.3461 - acc: 0.9500
Epoch 139/150
- 0s - loss: 0.3435 - acc: 0.9300
Epoch 140/150
- 0s - loss: 0.3407 - acc: 0.9500
Epoch 141/150
- 0s - loss: 0.3384 - acc: 0.9500
Epoch 142/150
- 0s - loss: 0.3366 - acc: 0.9500
Epoch 143/150
- 0s - loss: 0.3347 - acc: 0.9500
Epoch 144/150
- 0s - loss: 0.3336 - acc: 0.9400
Epoch 145/150
- 0s - loss: 0.3324 - acc: 0.9300
Epoch 146/150
- 0s - loss: 0.3311 - acc: 0.9300
Epoch 147/150
- 0s - loss: 0.3300 - acc: 0.9300
Epoch 148/150
- 0s - loss: 0.3283 - acc: 0.9300
Epoch 149/150
- 0s - loss: 0.3256 - acc: 0.9400
Epoch 150/150
- 0s - loss: 0.3224 - acc: 0.9500
Out[135]:
<keras.callbacks.History at 0x1fefed36ac8>
In [136]:
scaled_X_test
Out[136]:
array([[ 0.52941176, 0.36363636, 0.64285714, 0.45833333],
[ 0.41176471, 0.81818182, 0.10714286, 0.08333333],
[ 1. , 0.27272727, 1.03571429, 0.91666667],
[ 0.5 , 0.40909091, 0.60714286, 0.58333333],
[ 0.73529412, 0.36363636, 0.66071429, 0.54166667],
[ 0.32352941, 0.63636364, 0.07142857, 0.125 ],
[ 0.38235294, 0.40909091, 0.44642857, 0.5 ],
[ 0.76470588, 0.5 , 0.71428571, 0.91666667],
[ 0.55882353, 0.09090909, 0.60714286, 0.58333333],
[ 0.44117647, 0.31818182, 0.5 , 0.45833333],
[ 0.64705882, 0.54545455, 0.71428571, 0.79166667],
[ 0.14705882, 0.45454545, 0.05357143, 0. ],
[ 0.35294118, 0.68181818, 0.03571429, 0.04166667],
[ 0.17647059, 0.5 , 0.07142857, 0. ],
[ 0.23529412, 0.81818182, 0.07142857, 0.08333333],
[ 0.58823529, 0.59090909, 0.64285714, 0.625 ],
[ 0.64705882, 0.45454545, 0.83928571, 0.875 ],
[ 0.38235294, 0.22727273, 0.5 , 0.41666667],
[ 0.41176471, 0.36363636, 0.60714286, 0.5 ],
[ 0.61764706, 0.36363636, 0.80357143, 0.875 ],
[ 0.11764706, 0.54545455, 0.08928571, 0.04166667],
[ 0.52941176, 0.45454545, 0.67857143, 0.70833333],
[ 0.20588235, 0.63636364, 0.08928571, 0.125 ],
[ 0.61764706, 0.36363636, 0.80357143, 0.83333333],
[ 1.05882353, 0.81818182, 0.94642857, 0.79166667],
[ 0.70588235, 0.45454545, 0.73214286, 0.91666667],
[ 0.70588235, 0.22727273, 0.83928571, 0.70833333],
[ 0.73529412, 0.54545455, 0.85714286, 0.91666667],
[ 0.14705882, 0.45454545, 0.05357143, 0.08333333],
[ 0.14705882, 0.5 , 0.08928571, 0.04166667],
[ 0.08823529, 0.72727273, -0.01785714, 0.04166667],
[ 0.41176471, 1.09090909, 0.07142857, 0.125 ],
[ 0.70588235, 0.5 , 0.58928571, 0.54166667],
[ 0.14705882, 0.63636364, 0.08928571, 0.04166667],
[ 0.02941176, 0.54545455, 0.03571429, 0.04166667],
[ 0.58823529, 0.22727273, 0.69642857, 0.75 ],
[ 0.61764706, 0.54545455, 0.60714286, 0.58333333],
[ 0.26470588, 0.68181818, 0.07142857, 0.04166667],
[ 0.20588235, 0.72727273, 0.05357143, 0.04166667],
[ 0.26470588, 0.95454545, 0.07142857, 0. ],
[ 0.44117647, 0.31818182, 0.71428571, 0.75 ],
[ 0.5 , 0.63636364, 0.60714286, 0.625 ],
[ 0.70588235, 0.5 , 0.64285714, 0.58333333],
[ 0.32352941, 0.86363636, 0.03571429, 0.125 ],
[ 0.32352941, 0.77272727, 0.07142857, 0.04166667],
[ 0.35294118, 0.18181818, 0.46428571, 0.375 ],
[ 0.58823529, 0.36363636, 0.71428571, 0.58333333],
[ 0.61764706, 0.5 , 0.78571429, 0.70833333],
[ 0.67647059, 0.45454545, 0.58928571, 0.54166667],
[ 0.85294118, 0.72727273, 0.89285714, 1. ]])
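Notice that a few test values fall slightly outside the 0-1 range (e.g. 1.0588 and -0.0179). The scaler learned its min and max from the training data alone, so test samples more extreme than anything seen in training map just outside the interval. That's expected and harmless here.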
In [137]:
# Spits out probabilities by default.
# model.predict(scaled_X_test)
In [138]:
model.predict_classes(scaled_X_test)
Out[138]:
array([1, 0, 2, 1, 2, 0, 1, 2, 2, 1, 2, 0, 0, 0, 0, 1, 2, 1, 1, 2, 0, 2,
0, 2, 2, 2, 2, 2, 0, 0, 0, 0, 1, 0, 0, 2, 1, 0, 0, 0, 2, 1, 1, 0,
0, 1, 2, 2, 1, 2], dtype=int64)
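If you are on a newer Keras release where Sequential.predict_classes has been removed, the equivalent is taking the argmax over the predicted probabilities:

# Works on both old and new Keras versions (numpy was imported above)
np.argmax(model.predict(scaled_X_test), axis=1)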
So how well did we do? How do we actually measure "well"? Is 95% accuracy good enough? It all depends on the situation. We also need to take into account metrics like recall and precision. Make sure to watch the video discussion on classification evaluation before running this code!
In [139]:
model.metrics_names
Out[139]:
['loss', 'acc']
In [140]:
model.evaluate(x=scaled_X_test,y=y_test)
50/50 [==============================] - 0s 2ms/step
Out[140]:
[0.2843634402751923, 0.96]
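The two numbers line up with model.metrics_names above: a test loss of about 0.284 and 96% accuracy on the held-out data.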
In [141]:
from sklearn.metrics import confusion_matrix,classification_report
In [142]:
predictions = model.predict_classes(scaled_X_test)
In [161]:
predictions
Out[161]:
array([1, 0, 2, 1, 2, 0, 1, 2, 2, 1, 2, 0, 0, 0, 0, 1, 2, 1, 1, 2, 0, 2,
0, 2, 2, 2, 2, 2, 0, 0, 0, 0, 1, 0, 0, 2, 1, 0, 0, 0, 2, 1, 1, 0,
0, 1, 2, 2, 1, 2], dtype=int64)
In [163]:
y_test.argmax(axis=1)
Out[163]:
array([1, 0, 2, 1, 1, 0, 1, 2, 1, 1, 2, 0, 0, 0, 0, 1, 2, 1, 1, 2, 0, 2,
0, 2, 2, 2, 2, 2, 0, 0, 0, 0, 1, 0, 0, 2, 1, 0, 0, 0, 2, 1, 1, 0,
0, 1, 2, 2, 1, 2], dtype=int64)
In [164]:
confusion_matrix(y_test.argmax(axis=1),predictions)
Out[164]:
array([[19, 0, 0],
[ 0, 13, 2],
[ 0, 0, 16]], dtype=int64)
In [165]:
print(classification_report(y_test.argmax(axis=1),predictions))
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        19
           1       1.00      0.87      0.93        15
           2       0.89      1.00      0.94        16

   micro avg       0.96      0.96      0.96        50
   macro avg       0.96      0.96      0.96        50
weighted avg       0.96      0.96      0.96        50
In [166]:
model.save('myfirstmodel.h5')
In [167]:
from keras.models import load_model
In [168]:
newmodel = load_model('myfirstmodel.h5')
In [169]:
newmodel.predict_classes(X_test)
Out[169]:
array([2, 1, 2, 2, 2, 1, 2, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 2, 1, 2,
1, 2, 2, 2, 2, 2, 1, 1, 1, 1, 2, 1, 1, 2, 2, 1, 1, 1, 2, 2, 2, 1,
1, 2, 2, 2, 2, 2], dtype=int64)
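Careful: we just fed the raw, unscaled X_test to the reloaded model, which is why these predictions look nothing like our earlier ones. The network was trained on min-max scaled inputs, so the loaded model needs the same preprocessing:

newmodel.predict_classes(scaled_X_test)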
Great job! You now know how to preprocess data, train a neural network, and evaluate its classification performance!