In [1]:
"""
reference paper: http://arxiv.org/pdf/1512.03385.pdf
"""

from collections import namedtuple
from math import sqrt
import os

import tensorflow as tf

In [18]:
# Short aliases for the tf.contrib layer constructors used throughout the
# model builder below (keeps the res_net body readable).
batch_norm = tf.contrib.layers.batch_norm
convolution2d = tf.contrib.layers.convolution2d

In [19]:
def res_net(x, y, activation=tf.nn.relu):
    """Build a bottleneck residual network (ResNet, He et al. 2015).

    Args:
        x: input tensor; either a 4-D image batch ``[batch, h, w, c]`` or a
            flattened 2-D batch ``[batch, n*n]`` which is reshaped into
            square single-channel images.
        y: integer class-label tensor used to build the loss (this is the
            target, not the network output).
        activation: activation function applied after each convolution.

    Returns:
        Tuple ``(predictions, loss)``: softmax class probabilities and the
        softmax cross-entropy loss tensor.
    """
    BottleneckGroup = namedtuple('BottleneckGroup',
                                 ['num_blocks', 'num_filters', 'bottleneck_size'])

    # Four stages of 3 bottleneck blocks each; width doubles per stage.
    groups = [
        BottleneckGroup(3, 128, 32), BottleneckGroup(3, 256, 64),
        BottleneckGroup(3, 512, 128), BottleneckGroup(3, 1024, 256)
    ]

    input_shape = x.get_shape().as_list()

    # Flattened input (e.g. MNIST 784-vectors): reshape to square
    # single-channel images. Assumes the feature count is a perfect square.
    if len(input_shape) == 2:
        ndim = int(sqrt(input_shape[1]))
        x = tf.reshape(x, [-1, ndim, ndim, 1])

    # Stem: 7x7 conv with batch norm, then 3x3 stride-2 max pool.
    with tf.variable_scope('conv_layer1'):
        net = convolution2d(
            x, 64, 7, normalizer_fn=batch_norm, activation_fn=activation)

    net = tf.nn.max_pool(net, [1, 3, 3, 1], strides=[1, 2, 2, 1], padding='SAME')

    # 1x1 projection so the first residual addition has matching channels.
    with tf.variable_scope('conv_layer2'):
        net = convolution2d(net, groups[0].num_filters, 1, padding='VALID')

    for group_i, group in enumerate(groups):
        for block_i in range(group.num_blocks):
            name = 'group_%d/block_%d' % (group_i, block_i)

            # Bottleneck block: 1x1 reduce -> 3x3 -> 1x1 restore.
            with tf.variable_scope(name + '/conv_in'):
                conv = convolution2d(net, group.bottleneck_size, 1, padding='VALID',
                                     activation_fn=activation, normalizer_fn=batch_norm)
            # BUG FIX: the original scope name lacked the '/' separator,
            # producing e.g. 'group_0/block_0conv_bottleneck' instead of
            # matching the '/conv_in' / '/conv_out' naming pattern.
            with tf.variable_scope(name + '/conv_bottleneck'):
                conv = convolution2d(conv, group.bottleneck_size, 3, padding='SAME',
                                     activation_fn=activation, normalizer_fn=batch_norm)
            with tf.variable_scope(name + '/conv_out'):
                input_dim = net.get_shape()[-1].value
                conv = convolution2d(conv, input_dim, 1, padding='VALID',
                                     activation_fn=activation, normalizer_fn=batch_norm)
            # Identity shortcut connection.
            net = conv + net

        # Between stages, a bias-free 1x1 conv widens channels to match the
        # next group; the last group has no successor, hence the IndexError.
        try:
            next_group = groups[group_i + 1]
            with tf.variable_scope('block_%d/conv_upscale' % group_i):
                net = convolution2d(net, next_group.num_filters, 1,
                                    activation_fn=activation,
                                    biases_initializer=None, padding='SAME')
        except IndexError:
            pass  # last group: nothing to upscale

    # Global average pooling over the remaining spatial dimensions.
    # BUG FIX: use the Python ints from net_shape for ksize; the original
    # computed net_shape but then passed TF Dimension objects (net.shape[...]),
    # leaving net_shape dead and relying on implicit Dimension conversion.
    net_shape = net.get_shape().as_list()
    net = tf.nn.avg_pool(net, ksize=[1, net_shape[1], net_shape[2], 1],
                         strides=[1, 1, 1, 1], padding='VALID')

    # Flatten pooled features for the fully connected classifier head.
    net_shape = net.get_shape().as_list()
    net = tf.reshape(net, [-1, net_shape[1] * net_shape[2] * net_shape[3]])

    # 10-way classifier head (MNIST) with softmax cross-entropy loss.
    target = tf.one_hot(y, depth=10, dtype=tf.float32)
    logits = tf.contrib.layers.fully_connected(net, 10, activation_fn=None)
    loss = tf.losses.softmax_cross_entropy(target, logits)
    return tf.nn.softmax(logits), loss

In [20]:
def res_net_mode(x, y):
    """Estimator model_fn wrapping res_net.

    Args:
        x: input feature tensor (images).
        y: integer label tensor.

    Returns:
        Tuple ``(predictions, loss, train_op)`` where predictions is a dict
        with keys 'prob' (softmax probabilities), 'class' (argmax class ids)
        and 'accuracy' (per-example boolean correctness, consumed by the
        MetricSpec in main() via prediction_key='accuracy').
    """
    prediction, loss = res_net(x, y)
    predicted = tf.argmax(prediction, 1)
    accuracy = tf.equal(predicted, tf.cast(y, tf.int64))
    # BUG FIX: key was misspelled 'accurcy'; main() evaluates with
    # prediction_key='accuracy', so the lookup would have failed.
    predictions = {'prob': prediction, 'class': predicted, 'accuracy': accuracy}
    train_op = tf.contrib.layers.optimize_loss(loss, tf.contrib.framework.get_global_step(),
                                               optimizer='Adagrad',
                                               learning_rate=0.001)
    return predictions, loss, train_op

In [21]:
def main():
    """Train and evaluate the residual network on MNIST."""
    mnist = tf.contrib.learn.datasets.load_dataset('mnist')

    classifier = tf.contrib.learn.Estimator(model_fn=res_net_mode)

    # Show training progress (loss per logging step) in the output.
    tf.logging.set_verbosity(tf.logging.INFO)

    classifier.fit(mnist.train.images, mnist.train.labels,
                   batch_size=100, steps=1000)

    # BUG FIX: MetricSpec's keyword is 'metric_fn', not 'metrics_fn';
    # the original would raise a TypeError during evaluation.
    result = classifier.evaluate(
        x=mnist.test.images,
        y=mnist.test.labels,
        metrics={'accuracy': tf.contrib.learn.MetricSpec(
                    metric_fn=tf.contrib.metrics.streaming_accuracy,
                    prediction_key='accuracy')})

    score = result['accuracy']
    print('Accuracy: {0:f}'.format(score))

if __name__ == '__main__':
    main()


Extracting MNIST-data/train-images-idx3-ubyte.gz
Extracting MNIST-data/train-labels-idx1-ubyte.gz
Extracting MNIST-data/t10k-images-idx3-ubyte.gz
Extracting MNIST-data/t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:Using temporary folder as model directory: /var/folders/4n/dvmcx9mx3sgcp8mk0r_3xkhm0000gn/T/tmpujaeu0eg
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_tf_random_seed': None, '_save_checkpoints_secs': 600, '_task_type': None, '_task_id': 0, '_is_chief': True, '_evaluation_master': '', '_keep_checkpoint_max': 5, '_save_checkpoints_steps': None, '_environment': 'local', '_save_summary_steps': 100, '_tf_config': gpu_options {
  per_process_gpu_memory_fraction: 1
}
, '_num_ps_replicas': 0, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x125f08438>, '_keep_checkpoint_every_n_hours': 10000, '_master': ''}
WARNING:tensorflow:From <ipython-input-21-e4260c172c53>:9: calling BaseEstimator.fit (from tensorflow.contrib.learn.python.learn.estimators.estimator) with x is deprecated and will be removed after 2016-12-01.
Instructions for updating:
Estimator is decoupled from Scikit Learn interface by moving into
separate class SKCompat. Arguments x, y and batch_size are only
available in the SKCompat class, Estimator will only accept input_fn.
Example conversion:
  est = Estimator(...) -> est = SKCompat(Estimator(...))
WARNING:tensorflow:From <ipython-input-21-e4260c172c53>:9: calling BaseEstimator.fit (from tensorflow.contrib.learn.python.learn.estimators.estimator) with y is deprecated and will be removed after 2016-12-01.
Instructions for updating:
Estimator is decoupled from Scikit Learn interface by moving into
separate class SKCompat. Arguments x, y and batch_size are only
available in the SKCompat class, Estimator will only accept input_fn.
Example conversion:
  est = Estimator(...) -> est = SKCompat(Estimator(...))
WARNING:tensorflow:From <ipython-input-21-e4260c172c53>:9: calling BaseEstimator.fit (from tensorflow.contrib.learn.python.learn.estimators.estimator) with batch_size is deprecated and will be removed after 2016-12-01.
Instructions for updating:
Estimator is decoupled from Scikit Learn interface by moving into
separate class SKCompat. Arguments x, y and batch_size are only
available in the SKCompat class, Estimator will only accept input_fn.
Example conversion:
  est = Estimator(...) -> est = SKCompat(Estimator(...))
/usr/local/lib/python3.5/site-packages/tensorflow/python/util/deprecation.py:247: FutureWarning: comparison to `None` will result in an elementwise object comparison in the future.
  equality = a == b
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Saving checkpoints for 1 into /var/folders/4n/dvmcx9mx3sgcp8mk0r_3xkhm0000gn/T/tmpujaeu0eg/model.ckpt.
INFO:tensorflow:loss = 5.54514, step = 1
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-21-e4260c172c53> in <module>()
     20 
     21 if __name__ == '__main__':
---> 22     main()

<ipython-input-21-e4260c172c53> in main()
      7 
      8     classifier.fit(mnist.train.images, mnist.train.labels,
----> 9                   batch_size=100, steps=1000)
     10 
     11     result = classifier.evaluate(

/usr/local/lib/python3.5/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
    278             _call_location(), decorator_utils.get_qualified_name(func),
    279             func.__module__, arg_name, date, instructions)
--> 280       return func(*args, **kwargs)
    281     new_func.__doc__ = _add_deprecated_arg_notice_to_docstring(
    282         func.__doc__, date, instructions)

/usr/local/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py in fit(self, x, y, input_fn, steps, batch_size, monitors, max_steps)
    408     _verify_input_args(x, y, input_fn, None, batch_size)
    409     if x is not None:
--> 410       SKCompat(self).fit(x, y, batch_size, steps, max_steps, monitors)
    411       return self
    412 

/usr/local/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py in fit(self, x, y, batch_size, steps, max_steps, monitors)
   1351                         steps=steps,
   1352                         max_steps=max_steps,
-> 1353                         monitors=all_monitors)
   1354     return self
   1355 

/usr/local/lib/python3.5/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
    278             _call_location(), decorator_utils.get_qualified_name(func),
    279             func.__module__, arg_name, date, instructions)
--> 280       return func(*args, **kwargs)
    281     new_func.__doc__ = _add_deprecated_arg_notice_to_docstring(
    282         func.__doc__, date, instructions)

/usr/local/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py in fit(self, x, y, input_fn, steps, batch_size, monitors, max_steps)
    424       hooks.append(basic_session_run_hooks.StopAtStepHook(steps, max_steps))
    425 
--> 426     loss = self._train_model(input_fn=input_fn, hooks=hooks)
    427     logging.info('Loss for final step: %s.', loss)
    428     return self

/usr/local/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py in _train_model(self, input_fn, hooks)
    982         loss = None
    983         while not mon_sess.should_stop():
--> 984           _, loss = mon_sess.run([model_fn_ops.train_op, model_fn_ops.loss])
    985       summary_io.SummaryWriterCache.clear()
    986       return loss

/usr/local/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py in run(self, fetches, feed_dict, options, run_metadata)
    460                           feed_dict=feed_dict,
    461                           options=options,
--> 462                           run_metadata=run_metadata)
    463 
    464   def should_stop(self):

/usr/local/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py in run(self, fetches, feed_dict, options, run_metadata)
    784                               feed_dict=feed_dict,
    785                               options=options,
--> 786                               run_metadata=run_metadata)
    787       except errors.AbortedError:
    788         logging.info('An AbortedError was raised. Closing the current session. '

/usr/local/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py in run(self, *args, **kwargs)
    742 
    743   def run(self, *args, **kwargs):
--> 744     return self._sess.run(*args, **kwargs)
    745 
    746 

/usr/local/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py in run(self, fetches, feed_dict, options, run_metadata)
    889                                   feed_dict=feed_dict,
    890                                   options=options,
--> 891                                   run_metadata=run_metadata)
    892 
    893     for hook in self._hooks:

/usr/local/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py in run(self, *args, **kwargs)
    742 
    743   def run(self, *args, **kwargs):
--> 744     return self._sess.run(*args, **kwargs)
    745 
    746 

/usr/local/lib/python3.5/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
    765     try:
    766       result = self._run(None, fetches, feed_dict, options_ptr,
--> 767                          run_metadata_ptr)
    768       if run_metadata:
    769         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

/usr/local/lib/python3.5/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
    963     if final_fetches or final_targets:
    964       results = self._do_run(handle, final_targets, final_fetches,
--> 965                              feed_dict_string, options, run_metadata)
    966     else:
    967       results = []

/usr/local/lib/python3.5/site-packages/tensorflow/python/client/session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
   1013     if handle is None:
   1014       return self._do_call(_run_fn, self._session, feed_dict, fetch_list,
-> 1015                            target_list, options, run_metadata)
   1016     else:
   1017       return self._do_call(_prun_fn, self._session, handle, feed_dict,

/usr/local/lib/python3.5/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
   1020   def _do_call(self, fn, *args):
   1021     try:
-> 1022       return fn(*args)
   1023     except errors.OpError as e:
   1024       message = compat.as_text(e.message)

/usr/local/lib/python3.5/site-packages/tensorflow/python/client/session.py in _run_fn(session, feed_dict, fetch_list, target_list, options, run_metadata)
   1002         return tf_session.TF_Run(session, options,
   1003                                  feed_dict, fetch_list, target_list,
-> 1004                                  status, run_metadata)
   1005 
   1006     def _prun_fn(session, handle, feed_dict, fetch_list):

KeyboardInterrupt: 

In [ ]: