In [1]:
import argparse
import sys
import tempfile

In [2]:
from tensorflow.contrib.learn.python.learn import metric_spec
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.tensor_forest.client import eval_metrics
from tensorflow.contrib.tensor_forest.client import random_forest
from tensorflow.contrib.tensor_forest.python import tensor_forest
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.platform import app

In [3]:
FLAGS = None

In [4]:
def build_estimator(model_dir):
    """Build an estimator"""
    params = tensor_forest.ForestHParams(
        num_classes=10, num_features=784,
        num_trees=FLAGS.num_trees, max_nodes=FLAGS.max_nodes)
    graph_builder_class = tensor_forest.RandomForestGraphs
    if FLAGS.use_training_loss:
        graph_builder_class = tensor_forest.TrainingLossForest
        
    return estimator.SKCompat(random_forest.TensorForestEstimator(
        params, graph_builder_class=graph_builder_class,
        model_dir=model_dir))

In [5]:
def train_and_eval():
    """Train and evaluate the model"""
    model_dir = tempfile.mkdtemp() if not FLAGS.model_dir else FLAGS.model_dir
    print('model dir = %s'% model_dir)
    
    est = build_estimator(model_dir)
    
    mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=False)
    
    est.fit(x=mnist.train.images, y=mnist.train.labels,
           batch_size=FLAGS.batch_size)
    metric_name = 'accuracy'
    metric = {metric_name:
             metric_spec.MetricSpec(
                 eval_metrics.get_metric(metric_name),
                 prediction_key=eval_metrics.get_prediction_key(metric_name))}
    
    results = est.score(x=mnist.test.images, y=mnist.test.labels,
                       batch_size=FLAGS.batch_size,
                       metrics=metric)
    for key in sorted(results):
        print('%s: %s'% (key, results[key]))

In [6]:
def main(_):
    train_and_eval()

In [ ]:
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_dir',
        type=str,
        default='',
        help='Base directory for output models.'
    )
    
    parser.add_argument(
      '--data_dir',
      type=str,
      default='/tmp/data/',
      help='Directory for storing data'
    )
    parser.add_argument(
      '--train_steps',
      type=int,
      default=1000,
      help='Number of training steps.'
    )
    parser.add_argument(
      '--batch_size',
      type=str,
      default=1000,
      help='Number of examples in a training batch.'
      )
    parser.add_argument(
      '--num_trees',
      type=int,
      default=100,
      help='Number of trees in the forest.'
      )
    parser.add_argument(
      '--max_nodes',
      type=int,
      default=1000,
      help='Max total nodes in a single tree.'
      )
    parser.add_argument(
      '--use_training_loss',
      type=bool,
      default=False,
      help='If true, use training loss as termination criteria.'
        )
    FLAGS, unparsed = parser.parse_known_args()
    app.run(main=main, argv=[sys.argv[0]] + unparsed
           )


model dir = /tmp/tmpx57l4pc8
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_environment': 'local', '_keep_checkpoint_max': 5, '_tf_random_seed': None, '_is_chief': True, '_task_id': 0, '_evaluation_master': '', '_tf_config': gpu_options {
  per_process_gpu_memory_fraction: 1.0
}
, '_num_ps_replicas': 0, '_save_checkpoints_steps': None, '_save_summary_steps': 100, '_master': '', '_task_type': None, '_keep_checkpoint_every_n_hours': 10000, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f779149cef0>, '_save_checkpoints_secs': 600}
Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.
Extracting /tmp/data/train-images-idx3-ubyte.gz
Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz
INFO:tensorflow:Constructing forest with params = 
INFO:tensorflow:{'split_initializations_per_input': 3, 'base_random_seed': 0, 'dominate_fraction': 0.99, 'num_splits_to_consider': 784, 'feature_bagging_fraction': 1.0, 'bagged_features': None, 'num_output_columns': 11, 'num_features': 784, 'num_trees': 100, 'dominate_method': 'bootstrap', 'num_classes': 10, 'bagging_fraction': 1.0, 'valid_leaf_threshold': 1, 'split_after_samples': 250, 'min_split_samples': 5, 'regression': False, 'max_nodes': 1000, 'bagged_num_features': 784, 'max_fertile_nodes': 500, 'num_outputs': 1}
INFO:tensorflow:training graph for tree: 0
INFO:tensorflow:training graph for tree: 1
INFO:tensorflow:training graph for tree: 2
INFO:tensorflow:training graph for tree: 3
INFO:tensorflow:training graph for tree: 4
INFO:tensorflow:training graph for tree: 5
INFO:tensorflow:training graph for tree: 6
INFO:tensorflow:training graph for tree: 7
INFO:tensorflow:training graph for tree: 8
INFO:tensorflow:training graph for tree: 9
INFO:tensorflow:training graph for tree: 10
INFO:tensorflow:training graph for tree: 11
INFO:tensorflow:training graph for tree: 12
INFO:tensorflow:training graph for tree: 13
INFO:tensorflow:training graph for tree: 14
INFO:tensorflow:training graph for tree: 15
INFO:tensorflow:training graph for tree: 16
INFO:tensorflow:training graph for tree: 17
INFO:tensorflow:training graph for tree: 18
INFO:tensorflow:training graph for tree: 19
INFO:tensorflow:training graph for tree: 20
INFO:tensorflow:training graph for tree: 21
INFO:tensorflow:training graph for tree: 22
INFO:tensorflow:training graph for tree: 23
INFO:tensorflow:training graph for tree: 24
INFO:tensorflow:training graph for tree: 25
INFO:tensorflow:training graph for tree: 26
INFO:tensorflow:training graph for tree: 27
INFO:tensorflow:training graph for tree: 28
INFO:tensorflow:training graph for tree: 29
INFO:tensorflow:training graph for tree: 30
INFO:tensorflow:training graph for tree: 31
INFO:tensorflow:training graph for tree: 32
INFO:tensorflow:training graph for tree: 33
INFO:tensorflow:training graph for tree: 34
INFO:tensorflow:training graph for tree: 35
INFO:tensorflow:training graph for tree: 36
INFO:tensorflow:training graph for tree: 37
INFO:tensorflow:training graph for tree: 38
INFO:tensorflow:training graph for tree: 39
INFO:tensorflow:training graph for tree: 40
INFO:tensorflow:training graph for tree: 41
INFO:tensorflow:training graph for tree: 42
INFO:tensorflow:training graph for tree: 43
INFO:tensorflow:training graph for tree: 44
INFO:tensorflow:training graph for tree: 45
INFO:tensorflow:training graph for tree: 46
INFO:tensorflow:training graph for tree: 47
INFO:tensorflow:training graph for tree: 48
INFO:tensorflow:training graph for tree: 49
INFO:tensorflow:training graph for tree: 50
INFO:tensorflow:training graph for tree: 51
INFO:tensorflow:training graph for tree: 52
INFO:tensorflow:training graph for tree: 53
INFO:tensorflow:training graph for tree: 54
INFO:tensorflow:training graph for tree: 55
INFO:tensorflow:training graph for tree: 56
INFO:tensorflow:training graph for tree: 57
INFO:tensorflow:training graph for tree: 58
INFO:tensorflow:training graph for tree: 59
INFO:tensorflow:training graph for tree: 60
INFO:tensorflow:training graph for tree: 61
INFO:tensorflow:training graph for tree: 62
INFO:tensorflow:training graph for tree: 63
INFO:tensorflow:training graph for tree: 64
INFO:tensorflow:training graph for tree: 65
INFO:tensorflow:training graph for tree: 66
INFO:tensorflow:training graph for tree: 67
INFO:tensorflow:training graph for tree: 68
INFO:tensorflow:training graph for tree: 69
INFO:tensorflow:training graph for tree: 70
INFO:tensorflow:training graph for tree: 71
INFO:tensorflow:training graph for tree: 72
INFO:tensorflow:training graph for tree: 73
INFO:tensorflow:training graph for tree: 74
INFO:tensorflow:training graph for tree: 75
INFO:tensorflow:training graph for tree: 76
INFO:tensorflow:training graph for tree: 77
INFO:tensorflow:training graph for tree: 78
INFO:tensorflow:training graph for tree: 79
INFO:tensorflow:training graph for tree: 80
INFO:tensorflow:training graph for tree: 81
INFO:tensorflow:training graph for tree: 82
INFO:tensorflow:training graph for tree: 83
INFO:tensorflow:training graph for tree: 84
INFO:tensorflow:training graph for tree: 85
INFO:tensorflow:training graph for tree: 86
INFO:tensorflow:training graph for tree: 87
INFO:tensorflow:training graph for tree: 88
INFO:tensorflow:training graph for tree: 89
INFO:tensorflow:training graph for tree: 90
INFO:tensorflow:training graph for tree: 91
INFO:tensorflow:training graph for tree: 92
INFO:tensorflow:training graph for tree: 93
INFO:tensorflow:training graph for tree: 94
INFO:tensorflow:training graph for tree: 95
INFO:tensorflow:training graph for tree: 96
INFO:tensorflow:training graph for tree: 97
INFO:tensorflow:training graph for tree: 98
INFO:tensorflow:training graph for tree: 99
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Saving checkpoints for 1 into /tmp/tmpx57l4pc8/model.ckpt.
INFO:tensorflow:loss = -0.0, step = 1
INFO:tensorflow:global_step/sec: 1.4717
INFO:tensorflow:loss = -268.5, step = 101
INFO:tensorflow:global_step/sec: 1.27748
INFO:tensorflow:loss = -572.08, step = 201
INFO:tensorflow:global_step/sec: 1.20549
INFO:tensorflow:loss = -882.02, step = 301
INFO:tensorflow:global_step/sec: 1.21635
INFO:tensorflow:loss = -998.0, step = 401
INFO:tensorflow:global_step/sec: 1.15713
INFO:tensorflow:loss = -998.0, step = 501
INFO:tensorflow:global_step/sec: 1.04715
INFO:tensorflow:loss = -998.0, step = 601
INFO:tensorflow:Saving checkpoints for 697 into /tmp/tmpx57l4pc8/model.ckpt.
INFO:tensorflow:global_step/sec: 0.90717
INFO:tensorflow:loss = -998.0, step = 701
INFO:tensorflow:global_step/sec: 1.04329
INFO:tensorflow:loss = -998.0, step = 801
INFO:tensorflow:global_step/sec: 1.04248
INFO:tensorflow:loss = -998.0, step = 901
INFO:tensorflow:global_step/sec: 1.0103
INFO:tensorflow:loss = -998.0, step = 1001
INFO:tensorflow:global_step/sec: 1.00127
INFO:tensorflow:loss = -998.0, step = 1101
INFO:tensorflow:global_step/sec: 1.06198
INFO:tensorflow:loss = -998.0, step = 1201
INFO:tensorflow:global_step/sec: 1.05387
INFO:tensorflow:loss = -998.0, step = 1301
INFO:tensorflow:Saving checkpoints for 1305 into /tmp/tmpx57l4pc8/model.ckpt.
INFO:tensorflow:global_step/sec: 0.854608
INFO:tensorflow:loss = -998.0, step = 1401

In [ ]: