In [1]:
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder, StandardScaler
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
import numpy as np
In [3]:
iris = load_iris()
print(type(iris))
<class 'sklearn.utils.Bunch'>
In [4]:
print(iris)
{'data': array([[5.1, 3.5, 1.4, 0.2],
[4.9, 3. , 1.4, 0.2],
[4.7, 3.2, 1.3, 0.2],
[4.6, 3.1, 1.5, 0.2],
[5. , 3.6, 1.4, 0.2],
[5.4, 3.9, 1.7, 0.4],
[4.6, 3.4, 1.4, 0.3],
[5. , 3.4, 1.5, 0.2],
[4.4, 2.9, 1.4, 0.2],
[4.9, 3.1, 1.5, 0.1],
[5.4, 3.7, 1.5, 0.2],
[4.8, 3.4, 1.6, 0.2],
[4.8, 3. , 1.4, 0.1],
[4.3, 3. , 1.1, 0.1],
[5.8, 4. , 1.2, 0.2],
[5.7, 4.4, 1.5, 0.4],
[5.4, 3.9, 1.3, 0.4],
[5.1, 3.5, 1.4, 0.3],
[5.7, 3.8, 1.7, 0.3],
[5.1, 3.8, 1.5, 0.3],
[5.4, 3.4, 1.7, 0.2],
[5.1, 3.7, 1.5, 0.4],
[4.6, 3.6, 1. , 0.2],
[5.1, 3.3, 1.7, 0.5],
[4.8, 3.4, 1.9, 0.2],
[5. , 3. , 1.6, 0.2],
[5. , 3.4, 1.6, 0.4],
[5.2, 3.5, 1.5, 0.2],
[5.2, 3.4, 1.4, 0.2],
[4.7, 3.2, 1.6, 0.2],
[4.8, 3.1, 1.6, 0.2],
[5.4, 3.4, 1.5, 0.4],
[5.2, 4.1, 1.5, 0.1],
[5.5, 4.2, 1.4, 0.2],
[4.9, 3.1, 1.5, 0.2],
[5. , 3.2, 1.2, 0.2],
[5.5, 3.5, 1.3, 0.2],
[4.9, 3.6, 1.4, 0.1],
[4.4, 3. , 1.3, 0.2],
[5.1, 3.4, 1.5, 0.2],
[5. , 3.5, 1.3, 0.3],
[4.5, 2.3, 1.3, 0.3],
[4.4, 3.2, 1.3, 0.2],
[5. , 3.5, 1.6, 0.6],
[5.1, 3.8, 1.9, 0.4],
[4.8, 3. , 1.4, 0.3],
[5.1, 3.8, 1.6, 0.2],
[4.6, 3.2, 1.4, 0.2],
[5.3, 3.7, 1.5, 0.2],
[5. , 3.3, 1.4, 0.2],
[7. , 3.2, 4.7, 1.4],
[6.4, 3.2, 4.5, 1.5],
[6.9, 3.1, 4.9, 1.5],
[5.5, 2.3, 4. , 1.3],
[6.5, 2.8, 4.6, 1.5],
[5.7, 2.8, 4.5, 1.3],
[6.3, 3.3, 4.7, 1.6],
[4.9, 2.4, 3.3, 1. ],
[6.6, 2.9, 4.6, 1.3],
[5.2, 2.7, 3.9, 1.4],
[5. , 2. , 3.5, 1. ],
[5.9, 3. , 4.2, 1.5],
[6. , 2.2, 4. , 1. ],
[6.1, 2.9, 4.7, 1.4],
[5.6, 2.9, 3.6, 1.3],
[6.7, 3.1, 4.4, 1.4],
[5.6, 3. , 4.5, 1.5],
[5.8, 2.7, 4.1, 1. ],
[6.2, 2.2, 4.5, 1.5],
[5.6, 2.5, 3.9, 1.1],
[5.9, 3.2, 4.8, 1.8],
[6.1, 2.8, 4. , 1.3],
[6.3, 2.5, 4.9, 1.5],
[6.1, 2.8, 4.7, 1.2],
[6.4, 2.9, 4.3, 1.3],
[6.6, 3. , 4.4, 1.4],
[6.8, 2.8, 4.8, 1.4],
[6.7, 3. , 5. , 1.7],
[6. , 2.9, 4.5, 1.5],
[5.7, 2.6, 3.5, 1. ],
[5.5, 2.4, 3.8, 1.1],
[5.5, 2.4, 3.7, 1. ],
[5.8, 2.7, 3.9, 1.2],
[6. , 2.7, 5.1, 1.6],
[5.4, 3. , 4.5, 1.5],
[6. , 3.4, 4.5, 1.6],
[6.7, 3.1, 4.7, 1.5],
[6.3, 2.3, 4.4, 1.3],
[5.6, 3. , 4.1, 1.3],
[5.5, 2.5, 4. , 1.3],
[5.5, 2.6, 4.4, 1.2],
[6.1, 3. , 4.6, 1.4],
[5.8, 2.6, 4. , 1.2],
[5. , 2.3, 3.3, 1. ],
[5.6, 2.7, 4.2, 1.3],
[5.7, 3. , 4.2, 1.2],
[5.7, 2.9, 4.2, 1.3],
[6.2, 2.9, 4.3, 1.3],
[5.1, 2.5, 3. , 1.1],
[5.7, 2.8, 4.1, 1.3],
[6.3, 3.3, 6. , 2.5],
[5.8, 2.7, 5.1, 1.9],
[7.1, 3. , 5.9, 2.1],
[6.3, 2.9, 5.6, 1.8],
[6.5, 3. , 5.8, 2.2],
[7.6, 3. , 6.6, 2.1],
[4.9, 2.5, 4.5, 1.7],
[7.3, 2.9, 6.3, 1.8],
[6.7, 2.5, 5.8, 1.8],
[7.2, 3.6, 6.1, 2.5],
[6.5, 3.2, 5.1, 2. ],
[6.4, 2.7, 5.3, 1.9],
[6.8, 3. , 5.5, 2.1],
[5.7, 2.5, 5. , 2. ],
[5.8, 2.8, 5.1, 2.4],
[6.4, 3.2, 5.3, 2.3],
[6.5, 3. , 5.5, 1.8],
[7.7, 3.8, 6.7, 2.2],
[7.7, 2.6, 6.9, 2.3],
[6. , 2.2, 5. , 1.5],
[6.9, 3.2, 5.7, 2.3],
[5.6, 2.8, 4.9, 2. ],
[7.7, 2.8, 6.7, 2. ],
[6.3, 2.7, 4.9, 1.8],
[6.7, 3.3, 5.7, 2.1],
[7.2, 3.2, 6. , 1.8],
[6.2, 2.8, 4.8, 1.8],
[6.1, 3. , 4.9, 1.8],
[6.4, 2.8, 5.6, 2.1],
[7.2, 3. , 5.8, 1.6],
[7.4, 2.8, 6.1, 1.9],
[7.9, 3.8, 6.4, 2. ],
[6.4, 2.8, 5.6, 2.2],
[6.3, 2.8, 5.1, 1.5],
[6.1, 2.6, 5.6, 1.4],
[7.7, 3. , 6.1, 2.3],
[6.3, 3.4, 5.6, 2.4],
[6.4, 3.1, 5.5, 1.8],
[6. , 3. , 4.8, 1.8],
[6.9, 3.1, 5.4, 2.1],
[6.7, 3.1, 5.6, 2.4],
[6.9, 3.1, 5.1, 2.3],
[5.8, 2.7, 5.1, 1.9],
[6.8, 3.2, 5.9, 2.3],
[6.7, 3.3, 5.7, 2.5],
[6.7, 3. , 5.2, 2.3],
[6.3, 2.5, 5. , 1.9],
[6.5, 3. , 5.2, 2. ],
[6.2, 3.4, 5.4, 2.3],
[5.9, 3. , 5.1, 1.8]]), 'target': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]), 'target_names': array(['setosa', 'versicolor', 'virginica'], dtype='<U10'), 'DESCR': '.. _iris_dataset:\n\nIris plants dataset\n--------------------\n\n**Data Set Characteristics:**\n\n :Number of Instances: 150 (50 in each of three classes)\n :Number of Attributes: 4 numeric, predictive attributes and the class\n :Attribute Information:\n - sepal length in cm\n - sepal width in cm\n - petal length in cm\n - petal width in cm\n - class:\n - Iris-Setosa\n - Iris-Versicolour\n - Iris-Virginica\n \n :Summary Statistics:\n\n ============== ==== ==== ======= ===== ====================\n Min Max Mean SD Class Correlation\n ============== ==== ==== ======= ===== ====================\n sepal length: 4.3 7.9 5.84 0.83 0.7826\n sepal width: 2.0 4.4 3.05 0.43 -0.4194\n petal length: 1.0 6.9 3.76 1.76 0.9490 (high!)\n petal width: 0.1 2.5 1.20 0.76 0.9565 (high!)\n ============== ==== ==== ======= ===== ====================\n\n :Missing Attribute Values: None\n :Class Distribution: 33.3% for each of 3 classes.\n :Creator: R.A. Fisher\n :Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)\n :Date: July, 1988\n\nThe famous Iris database, first used by Sir R.A. Fisher. The dataset is taken\nfrom Fisher\'s paper. Note that it\'s the same as in R, but not as in the UCI\nMachine Learning Repository, which has two wrong data points.\n\nThis is perhaps the best known database to be found in the\npattern recognition literature. Fisher\'s paper is a classic in the field and\nis referenced frequently to this day. (See Duda & Hart, for example.) The\ndata set contains 3 classes of 50 instances each, where each class refers to a\ntype of iris plant. One class is linearly separable from the other 2; the\nlatter are NOT linearly separable from each other.\n\n.. topic:: References\n\n - Fisher, R.A. "The use of multiple measurements in taxonomic problems"\n Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions to\n Mathematical Statistics" (John Wiley, NY, 1950).\n - Duda, R.O., & Hart, P.E. (1973) Pattern Classification and Scene Analysis.\n (Q327.D83) John Wiley & Sons. ISBN 0-471-22361-1. See page 218.\n - Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New System\n Structure and Classification Rule for Recognition in Partially Exposed\n Environments". IEEE Transactions on Pattern Analysis and Machine\n Intelligence, Vol. PAMI-2, No. 1, 67-71.\n - Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule". IEEE Transactions\n on Information Theory, May 1972, 431-433.\n - See also: 1988 MLC Proceedings, 54-64. Cheeseman et al"s AUTOCLASS II\n conceptual clustering system finds 3 classes in the data.\n - Many, many more ...', 'feature_names': ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)'], 'filename': '/home/nbuser/anaconda3_501/lib/python3.6/site-packages/sklearn/datasets/data/iris.csv'}
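Printing the whole Bunch is verbose; for a quick overview of the available fields, printing just the keys is usually enough (a short sketch; the keys listed match the fields visible in the dump above):

print(iris.keys())
# dict_keys(['data', 'target', 'target_names', 'DESCR', 'feature_names', 'filename'])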
In [8]:
attribs = iris['data']
classes = iris['target']
names = iris['target_names']
feature_names = iris['feature_names']
print(type(attribs), attribs)
print(type(classes), classes)
print(type(names), names)
print(type(feature_names), feature_names)
<class 'numpy.ndarray'> [[5.1 3.5 1.4 0.2]
 [4.9 3. 1.4 0.2]
 [4.7 3.2 1.3 0.2]
 ... (145 rows omitted; the values duplicate the 'data' array printed above)
 [6.2 3.4 5.4 2.3]
 [5.9 3. 5.1 1.8]]
<class 'numpy.ndarray'> [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2
2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
2 2]
<class 'numpy.ndarray'> ['setosa' 'versicolor' 'virginica']
<class 'list'> ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
In [13]:
encoder = OneHotEncoder(categories='auto')  # 'auto' infers the categories from the unique values and avoids sklearn's FutureWarning for integer inputs
iris_encoded = encoder.fit_transform(classes[:, np.newaxis]).toarray()
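Since the targets are already consecutive integers 0-2, the same one-hot matrix can be built with Keras' own utility instead of scikit-learn; a minimal equivalent sketch:

# equivalent one-hot encoding via Keras (uses the tf import from the first cell)
iris_encoded_alt = tf.keras.utils.to_categorical(classes, num_classes=3)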
In [14]:
print(iris_encoded)
[[1. 0. 0.]
 [1. 0. 0.]
 ... (rows 3-50 are all [1. 0. 0.]: setosa)
 [0. 1. 0.]
 [0. 1. 0.]
 ... (rows 53-100 are all [0. 1. 0.]: versicolor)
 [0. 0. 1.]
 [0. 0. 1.]
 ... (rows 103-149 are all [0. 0. 1.]: virginica)
 [0. 0. 1.]]
In [15]:
X_train, X_test, y_train, y_test = train_test_split(
attribs, iris_encoded, test_size=0.5, random_state=2)
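Note that StandardScaler is imported in the first cell but never applied, and the split is not stratified. A hypothetical variant of this step, sketched for reference (not what this notebook ran; the results below use the raw, unscaled split):

# stratify keeps the 50/50/50 class balance in both halves of the split
X_tr, X_te, y_tr, y_te = train_test_split(
    attribs, iris_encoded, test_size=0.5, random_state=2, stratify=classes)

# fit the scaler on training data only, to avoid leaking test-set statistics
scaler = StandardScaler()
X_tr = scaler.fit_transform(X_tr)
X_te = scaler.transform(X_te)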
In [16]:
print(X_train[:10])
[[5.9 3. 4.2 1.5]
[6.7 3.1 4.7 1.5]
[7.7 2.8 6.7 2. ]
[4.9 3. 1.4 0.2]
[6.3 3.3 4.7 1.6]
[5.1 3.8 1.5 0.3]
[5.8 2.7 3.9 1.2]
[6.9 3.2 5.7 2.3]
[4.9 3.1 1.5 0.1]
[5. 2. 3.5 1. ]]
In [17]:
print(y_train[:10])
[[0. 1. 0.]
[0. 1. 0.]
[0. 0. 1.]
[1. 0. 0.]
[0. 1. 0.]
[1. 0. 0.]
[0. 1. 0.]
[0. 0. 1.]
[1. 0. 0.]
[0. 1. 0.]]
In [20]:
model = Sequential()
model.add(Dense(8, input_dim=4, activation='relu'))
model.add(Dense(3, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
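The network is deliberately small: one hidden layer of 8 ReLU units feeding a 3-way softmax, trained with categorical cross-entropy. As a sanity check on its size (a sketch; the counts follow from the layer shapes, not from notebook output):

# Dense(8) on 4 inputs: 4*8 weights + 8 biases = 40 parameters
# Dense(3) on 8 units:  8*3 weights + 3 biases = 27 parameters
model.summary()  # should report 67 trainable parameters in total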
In [21]:
model.fit(X_train, y_train, epochs=150, batch_size=10)
perda, acuracia = model.evaluate(X_test, y_test)
print('Accuracy: %.2f' % (acuracia*100))
Train on 75 samples
Epoch 1/150
75/75 [==============================] - 3s 37ms/sample - loss: 1.3718 - accuracy: 0.6400
Epoch 2/150
75/75 [==============================] - 0s 765us/sample - loss: 1.2667 - accuracy: 0.6400
... (epochs 3-149 omitted; loss falls steadily to ~0.25 while training accuracy climbs from 0.64 to 0.96)
Epoch 150/150
75/75 [==============================] - 0s 988us/sample - loss: 0.2504 - accuracy: 0.9600
75/75 [==============================] - 0s 3ms/sample - loss: 0.2247 - accuracy: 1.0000
Accuracy: 100.00
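A perfect score on 75 test samples is plausible for Iris but worth treating with caution, given how small the test set is. To train silently and track the held-out score per epoch instead, a hedged sketch of an alternative fit call:

history = model.fit(X_train, y_train, epochs=150, batch_size=10,
                    validation_data=(X_test, y_test), verbose=0)
print(history.history['val_accuracy'][-1])  # final held-out accuracy

Using the test set as validation data conflates model selection with final evaluation; a separate validation split would be cleaner.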
In [22]:
model.save("./iris.h5")
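The HDF5 file stores the architecture, weights, and optimizer state, so the model can be restored later without retraining; a minimal sketch:

from tensorflow.keras.models import load_model
restored = load_model("./iris.h5")  # ready for predict/evaluate immediately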
In [24]:
# Setosa (1,0,0), Versicolor (0,1,0), and Virginica (0,0,1)
new_iris_samples = np.array(
[[5.1, 3.3, 1.7, 0.5],
[5.9, 3.0, 4.2, 1.5],
[6.9, 3.1, 5.4, 2.1]], dtype=np.float32)
In [26]:
predictions = model.predict(new_iris_samples)
rounded = [[np.round(float(i), 0) for i in amostra] for amostra in predictions]
print(rounded)
[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
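Rounding works here only because each softmax output is near-certain; mapping through argmax is more robust for ambiguous inputs and yields readable labels. A short sketch reusing the names array loaded earlier:

# pick the most probable class per sample and translate to species names
pred_idx = np.argmax(predictions, axis=1)
print([names[i] for i in pred_idx])  # ['setosa', 'versicolor', 'virginica'] for these three samples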