In [6]:
"""Example of DNNClassifier for Iris plant dataset, with run config.

Requires TensorFlow 1.x: the ``tf.contrib.learn`` API used below was removed
in TensorFlow 2.0.
"""
from sklearn import datasets
from sklearn import metrics
import tensorflow as tf

try:
    # scikit-learn >= 0.18 provides train_test_split here; the old
    # ``sklearn.cross_validation`` module was deprecated in 0.18 and removed
    # in 0.20, so importing it directly fails on current releases.
    from sklearn import model_selection
except ImportError:  # scikit-learn < 0.18 only ships the old module
    from sklearn import cross_validation as model_selection
# Keep the historical name bound so code written against ``cross_validation``
# (e.g. ``cross_validation.train_test_split``) keeps working.
cross_validation = model_selection

print(tf.__version__)
# Fail here with a clear message rather than with an AttributeError on
# ``tf.contrib`` in a later cell.
if not hasattr(tf, 'contrib'):
    raise ImportError(
        'This notebook needs TensorFlow 1.x (tf.contrib); found ' + tf.__version__)
In [7]:
# Train a DNNClassifier on Iris and report held-out accuracy.

# Experiment settings, gathered here so they can be tuned in one place.
TEST_SIZE = 0.2              # fraction of samples held out for evaluation
SEED = 42                    # seeds both the data split and TensorFlow's RNG
HIDDEN_UNITS = [10, 20, 10]  # units per hidden layer
TRAIN_STEPS = 200

try:
    # scikit-learn >= 0.18; ``sklearn.cross_validation`` was removed in 0.20.
    from sklearn.model_selection import train_test_split
except ImportError:  # scikit-learn < 0.18
    from sklearn.cross_validation import train_test_split

# load dataset
iris = datasets.load_iris()
x_train, x_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=TEST_SIZE, random_state=SEED)
assert len(x_train) + len(x_test) == len(iris.data)

# Seeding only the split is not enough: without tf_random_seed the network's
# weight initialisation (and therefore the printed accuracy) differs per run.
run_config = tf.contrib.learn.estimators.RunConfig(
    num_cores=3, gpu_memory_fraction=0.6, tf_random_seed=SEED)
feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input(x_train)
classifier = tf.contrib.learn.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=HIDDEN_UNITS,
    n_classes=3,
    config=run_config)

# fit and predict
classifier.fit(x_train, y_train, steps=TRAIN_STEPS)
# NOTE(review): later TF 1.x releases deprecate ``predict`` in favour of
# ``predict_classes``; kept as ``predict`` for compatibility with older 1.x.
predictions = list(classifier.predict(x_test, as_iterable=True))
score = metrics.accuracy_score(y_test, predictions)
print('Accuracy:{0:f}'.format(score))
In [ ]: