Estimator(Custom)

  • TensorFlow high-level API
  • Official TensorFlow documentation
  • Besides the pre-made models, you can write your own (custom) Estimator
  • You do not need to manage a tf.Session yourself, and you do not need to call tf.global_variables_initializer() or tf.local_variables_initializer()
  • This post covers custom Estimators

Components

  • input_fn(): returns the features and the labels; return the features as a dict.
  • model_fn(features, labels, mode): branches on mode. Train returns the loss and the train op, evaluate returns the loss and metrics such as accuracy computed from the predictions, and predict returns the probability and the class.
  • est = tf.estimator.Estimator(model_fn)
    • est.train(input_fn, steps=500)
    • est.evaluate(input_fn, steps=10)
    • est.predict(input_fn=tf.estimator.inputs.numpy_input_fn({'feature': data}, shuffle=False))
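
The contract between these pieces can be mimicked in plain Python. The sketch below is not TensorFlow code: `toy_model_fn`, its arbitrary scoring rule, and the dicts it returns are hypothetical stand-ins for a real `model_fn` and `tf.estimator.EstimatorSpec`. It only illustrates which outputs each mode is expected to provide. The mode strings match the values of `tf.estimator.ModeKeys` in TF 1.x ('train', 'eval', 'infer').

```python
# Plain-Python mock of the model_fn contract (hypothetical, not TensorFlow).
TRAIN, EVAL, PREDICT = "train", "eval", "infer"


def toy_model_fn(features, labels, mode):
    """Return only the outputs that the given mode needs."""
    # Stand-in for the network: the "probability" is 1.0 when the first value
    # is smaller than the last one, else 0.0 (an arbitrary rule for illustration).
    probs = [1.0 if row[0] < row[-1] else 0.0 for row in features["feature"]]

    if mode == PREDICT:
        # Labels are not available at predict time.
        return {"predictions": {"prob": probs, "class": [round(p) for p in probs]}}

    # Mean squared error as a placeholder loss.
    loss = sum((p - y) ** 2 for p, y in zip(probs, labels)) / len(labels)
    if mode == TRAIN:
        return {"loss": loss, "train_op": "update weights here"}
    if mode == EVAL:
        acc = sum(p == y for p, y in zip(probs, labels)) / len(labels)
        return {"loss": loss, "eval_metric_ops": {"acc": acc}}
    raise ValueError("unknown mode: {}".format(mode))


features = {"feature": [[1, 2, 3], [3, 2, 1]]}
labels = [1.0, 0.0]
print(toy_model_fn(features, labels, TRAIN))
print(toy_model_fn(features, labels, EVAL))
print(toy_model_fn(features, None, PREDICT))
```

Note that the predict branch never touches `labels`, which is why the features are passed as a dict keyed by name: the same key can be reused by a label-free input function at predict time.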

In [43]:
import tensorflow as tf
import numpy as np
BATCH_SIZE = 100

input_fn


In [28]:
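def input_fn():
    '''
    Load the data and return (features, labels).
    The features are returned as a dict so that the same key can also be used at predict time.
    '''
    # Each CSV line holds 10 integers: the label first, then 9 feature values.
    # repeat(999999) is effectively endless; train/evaluate stop after `steps`.
    batch_lines = tf.data.TextLineDataset("./test_data.csv")\
            .batch(2)\
            .repeat(999999)\
            .make_one_shot_iterator()\
            .get_next()

    columns = tf.decode_csv(batch_lines, record_defaults=[[0]]*10)
    feature = tf.stack(columns[1:], axis=1)       # shape: (batch, 9)
    label = tf.expand_dims(columns[0], axis=-1)   # shape: (batch, 1)

    feature = tf.cast(feature, tf.float32)
    label = tf.cast(label, tf.float32)

    return {'feature': feature}, label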
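
To make the shapes concrete, here is a plain-Python sketch of what the `tf.decode_csv`, `tf.stack`, and `tf.expand_dims` steps do to one batch of two lines. The two CSV lines are made-up sample values (the contents of `test_data.csv` are not shown in this post), and `parse_batch` is a hypothetical helper, not a TensorFlow function.

```python
def parse_batch(lines):
    """Split CSV lines into ({'feature': ...}, label), with the label in column 0."""
    rows = [[float(v) for v in line.split(",")] for line in lines]
    feature = [row[1:] for row in rows]   # shape (batch, 9)
    label = [[row[0]] for row in rows]    # shape (batch, 1)
    return {"feature": feature}, label


# Made-up sample lines: 1 label followed by 9 feature values.
sample_lines = ["1,1,2,3,4,5,6,7,8,9", "0,9,8,7,6,5,4,3,2,1"]
features, labels = parse_batch(sample_lines)
print(len(features["feature"]), len(features["feature"][0]))  # 2 9
print(labels)  # [[1.0], [0.0]]
```

Keeping the label as shape (batch, 1) makes it line up with the (batch, 1) output of a final dense layer with a single unit, so the loss compares tensors of the same shape.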

Model


In [29]:
def model_fn(features, labels, mode):
    '''
    Branch on mode: train returns the loss and the train op, evaluate returns the loss and accuracy,
    and predict returns the probability and the class.
    '''
    
    TRAIN = mode == tf.estimator.ModeKeys.TRAIN
    EVAL = mode == tf.estimator.ModeKeys.EVAL
    PRED = mode == tf.estimator.ModeKeys.PREDICT
    
    layer1 = tf.layers.dense(features["feature"], units=9, activation=tf.nn.relu)
    layer2 = tf.layers.dense(layer1, units=9, activation=tf.nn.relu)
    layer3 = tf.layers.dense(layer2, units=9, activation=tf.nn.relu)
    layer4 = tf.layers.dense(layer3, units=9, activation=tf.nn.relu)
    out = tf.layers.dense(layer4, units=1)
    
    if TRAIN:
        global_step = tf.train.get_global_step()
        loss = tf.losses.sigmoid_cross_entropy(labels, out)
        train_op = tf.train.GradientDescentOptimizer(1e-2).minimize(loss, global_step=global_step)
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
        
    elif EVAL:
        loss = tf.losses.sigmoid_cross_entropy(labels, out) # test loss
        pred = tf.nn.sigmoid(out)
        accuracy = tf.metrics.accuracy(labels, tf.round(pred))

        return tf.estimator.EstimatorSpec(mode=mode, loss=loss, eval_metric_ops={'acc': accuracy})
        
    elif PRED:
        prob = tf.nn.sigmoid(out)
        _class = tf.round(prob)
        return tf.estimator.EstimatorSpec(mode=mode, predictions={'prob': prob, 'class': _class})
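
The arithmetic behind the three branches can be checked by hand. `tf.losses.sigmoid_cross_entropy` takes raw logits rather than probabilities, which is why `out` has no activation; the sigmoid is applied separately only where a probability is needed. The sketch below is plain Python with made-up logits, and it assumes the default settings, under which the per-element losses are averaged.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def sigmoid_cross_entropy(label, logit):
    """Numerically stable form: max(z, 0) - z * y + log(1 + exp(-|z|))."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


logits = [4.0, -3.0]   # made-up network outputs for two samples
labels = [1.0, 0.0]

losses = [sigmoid_cross_entropy(y, z) for y, z in zip(labels, logits)]
mean_loss = sum(losses) / len(losses)        # averaged over the batch
probs = [sigmoid(z) for z in logits]         # what the predict branch calls 'prob'
classes = [float(round(p)) for p in probs]   # threshold at 0.5, like tf.round
print(mean_loss, probs, classes)
```

For a label of 1 the loss reduces to `-log(sigmoid(z))`, and for a label of 0 to `-log(1 - sigmoid(z))`, so a confident correct logit yields a loss close to zero.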

In [41]:
if __name__ == '__main__':
    tf.logging.set_verbosity(tf.logging.INFO)
    est = tf.estimator.Estimator(model_fn)
    est.train(input_fn, steps=500)
    est.evaluate(input_fn, steps=10)
    
    data1 = np.array([1,2,3,4,5,6,7,8,9], np.float32)
    data2 = np.array([5,5,5,5,5,5,5,5,5], np.float32)
    data3 = np.array([9-i for i in range(9)], np.float32)
    data = np.stack([data1, data2, data3]) # stack several samples into one input array
    
    pred_input_fn = tf.estimator.inputs.numpy_input_fn({'feature': data}, shuffle=False)
    for d, pred in zip(data, est.predict(pred_input_fn)):
        print('feature: {}, prob: {}, class: {}'.format(d, pred['prob'], pred['class']))


INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /var/folders/f7/lrsclmhd6mx2hgq049xw8dv80000gn/T/tmpnjf2yyyn
INFO:tensorflow:Using config: {'_model_dir': '/var/folders/f7/lrsclmhd6mx2hgq049xw8dv80000gn/T/tmpnjf2yyyn', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x12322c9e8>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 1 into /var/folders/f7/lrsclmhd6mx2hgq049xw8dv80000gn/T/tmpnjf2yyyn/model.ckpt.
INFO:tensorflow:loss = 0.6336597, step = 1
INFO:tensorflow:global_step/sec: 773.971
INFO:tensorflow:loss = 0.063303, step = 101 (0.130 sec)
INFO:tensorflow:global_step/sec: 1717.8
INFO:tensorflow:loss = 0.011939027, step = 201 (0.058 sec)
INFO:tensorflow:global_step/sec: 1655.41
INFO:tensorflow:loss = 0.0056393184, step = 301 (0.060 sec)
INFO:tensorflow:global_step/sec: 1709.7
INFO:tensorflow:loss = 0.0034946266, step = 401 (0.058 sec)
INFO:tensorflow:Saving checkpoints for 500 into /var/folders/f7/lrsclmhd6mx2hgq049xw8dv80000gn/T/tmpnjf2yyyn/model.ckpt.
INFO:tensorflow:Loss for final step: 0.0012561312.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2018-06-29-13:17:26
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /var/folders/f7/lrsclmhd6mx2hgq049xw8dv80000gn/T/tmpnjf2yyyn/model.ckpt-500
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Finished evaluation at 2018-06-29-13:17:27
INFO:tensorflow:Saving dict for global step 500: acc = 1.0, global_step = 500, loss = 0.0018566784
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /var/folders/f7/lrsclmhd6mx2hgq049xw8dv80000gn/T/tmpnjf2yyyn/model.ckpt-500
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
feature: [1. 2. 3. 4. 5. 6. 7. 8. 9.], prob: [0.99949217], class: [1.]
feature: [5. 5. 5. 5. 5. 5. 5. 5. 5.], prob: [0.6270031], class: [1.]
feature: [9. 8. 7. 6. 5. 4. 3. 2. 1.], prob: [0.00072852], class: [0.]
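
The `acc` value in the evaluation log is not the accuracy of a single batch. `tf.metrics.accuracy` is a streaming metric: it keeps a running `total` of correct predictions and a `count` of examples, and the reported value is their ratio over all evaluation steps. A minimal plain-Python sketch of that accumulation, where `StreamingAccuracy` is a hypothetical helper and the labels and predictions are made up:

```python
class StreamingAccuracy:
    """Accumulates correct/total counts across batches (hypothetical helper)."""

    def __init__(self):
        self.total = 0.0   # number of correct predictions so far
        self.count = 0.0   # number of examples seen so far

    def update(self, labels, predictions):
        self.total += sum(1.0 for y, p in zip(labels, predictions) if y == p)
        self.count += len(labels)
        return self.result()

    def result(self):
        return self.total / self.count if self.count else 0.0


metric = StreamingAccuracy()
metric.update([1.0, 0.0], [1.0, 0.0])   # batch 1: 2 of 2 correct
metric.update([1.0, 1.0], [1.0, 0.0])   # batch 2: 1 of 2 correct
print(metric.result())  # 0.75
```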