In [1]:
from sklearn import model_selection
from sklearn import datasets
from sklearn import metrics
import tensorflow as tf
import numpy as np
from tensorflow.contrib.learn.python.learn.estimators.estimator import SKCompat
learn = tf.contrib.learn

1. 自定义softmax回归模型。


In [2]:
def my_model(features, target):
    target = tf.one_hot(target, 3, 1, 0)
    
    # 计算预测值及损失函数。
    logits = tf.contrib.layers.fully_connected(features, 3, tf.nn.softmax)
    loss = tf.contrib.losses.softmax_cross_entropy(logits, target)
    
    # 创建优化步骤。
    train_op = tf.contrib.layers.optimize_loss(
        loss,
        tf.contrib.framework.get_global_step(),
        optimizer='Adam',
        learning_rate=0.01)
    return tf.arg_max(logits, 1), loss, train_op

2. 读取数据并将数据转化成TensorFlow要求的float32格式。


In [3]:
iris = datasets.load_iris()
x_train, x_test, y_train, y_test = model_selection.train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=0)

x_train, x_test = map(np.float32, [x_train, x_test])

3. 封装和训练模型,输出准确率。


In [4]:
classifier = SKCompat(learn.Estimator(model_fn=my_model, model_dir="Models/model_1"))
classifier.fit(x_train, y_train, steps=800)

y_predicted = [i for i in classifier.predict(x_test)]
score = metrics.accuracy_score(y_test, y_predicted)
print('Accuracy: %.2f%%' % (score * 100))


INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'save_summary_steps': 100, '_num_ps_replicas': 0, '_task_type': None, '_environment': 'local', '_is_chief': True, 'save_checkpoints_secs': 600, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x11cc47d90>, 'tf_config': gpu_options {
  per_process_gpu_memory_fraction: 1
}
, '_task_id': 0, 'tf_random_seed': None, 'keep_checkpoint_every_n_hours': 10000, '_evaluation_master': '', 'save_checkpoints_steps': None, '_master': '', 'keep_checkpoint_max': 5}
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:loss = 1.21793, step = 1
INFO:tensorflow:Saving checkpoints for 1 into /tmp/ymx2/model.ckpt.
WARNING:tensorflow:*******************************************************
WARNING:tensorflow:TensorFlow's V1 checkpoint format has been deprecated.
WARNING:tensorflow:Consider switching to the more efficient V2 format:
WARNING:tensorflow:   `tf.train.Saver(write_version=tf.train.SaverDef.V2)`
WARNING:tensorflow:now on by default.
WARNING:tensorflow:*******************************************************
INFO:tensorflow:loss = 0.837075, step = 101
INFO:tensorflow:global_step/sec: 270.689
INFO:tensorflow:loss = 0.758434, step = 201
INFO:tensorflow:global_step/sec: 727.854
INFO:tensorflow:loss = 0.70665, step = 301
INFO:tensorflow:global_step/sec: 656.759
INFO:tensorflow:loss = 0.675011, step = 401
INFO:tensorflow:global_step/sec: 705.442
INFO:tensorflow:loss = 0.654652, step = 501
INFO:tensorflow:global_step/sec: 762.655
INFO:tensorflow:loss = 0.640652, step = 601
INFO:tensorflow:global_step/sec: 737.675
INFO:tensorflow:loss = 0.630499, step = 701
INFO:tensorflow:global_step/sec: 465.224
INFO:tensorflow:Saving checkpoints for 800 into /tmp/ymx2/model.ckpt.
WARNING:tensorflow:*******************************************************
WARNING:tensorflow:TensorFlow's V1 checkpoint format has been deprecated.
WARNING:tensorflow:Consider switching to the more efficient V2 format:
WARNING:tensorflow:   `tf.train.Saver(write_version=tf.train.SaverDef.V2)`
WARNING:tensorflow:now on by default.
WARNING:tensorflow:*******************************************************
INFO:tensorflow:Loss for final step: 0.622899.
INFO:tensorflow:Loading model from checkpoint: /tmp/ymx2/model.ckpt-800-?????-of-00001.
Accuracy: 100.00%