In [3]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from sklearn import metrics
import tensorflow as tf

In [4]:
layers = tf.contrib.layers
learn = tf.contrib.learn

In [5]:
def max_pool_2x2(tensor_in):
    return tf.nn.max_pool(
        tensor_in, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME')

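The helper above uses a 2x2 window with stride 2 in both spatial dimensions, so each pooling step halves the height and width: 28x28 becomes 14x14 after the first pool and 7x7 after the second, which is why conv_model below flattens to 7*7*64. A quick static shape check (a sketch assuming TensorFlow 1.x graph mode, where shape inference runs without a session):

# Sanity check: a 2x2/stride-2 max-pool halves the spatial dimensions.
x = tf.zeros([1, 28, 28, 32])
print(max_pool_2x2(x).get_shape())  # prints (1, 14, 14, 32)
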
In [6]:
def conv_model(feature, target, mode):
    """2-layer convolution model."""
    target = tf.one_hot(tf.cast(target, tf.int32), 10, 1, 0)
    feature = tf.reshape(feature, [-1, 28, 28, 1])
    with tf.variable_scope('conv_layer1'):
        h_conv1 = layers.convolution2d(
            feature, 32, kernel_size=[5, 5], activation_fn=tf.nn.relu)
        h_pool1 = max_pool_2x2(h_conv1)
    
    with tf.variable_scope('conv_layer2'):
        h_conv2 = layers.convolution2d(
            h_pool1, 64, kernel_size=[5, 5], activation_fn=tf.nn.relu)
        h_pool2 = max_pool_2x2(h_conv2)
        h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
    
    h_fc1 = layers.dropout(
        layers.fully_connected(
            h_pool2_flat, 1024, activation_fn=tf.nn.relu),
        keep_prob=0.5,
        is_training=(mode == tf.contrib.learn.ModeKeys.TRAIN)
        )
    
    logits = layers.fully_connected(h_fc1, 10, activation_fn=None)
    loss = tf.losses.softmax_cross_entropy(target, logits)
    
    train_op = layers.optimize_loss(
        loss,
        tf.contrib.framework.get_global_step(),
        optimizer='SGD',
        learning_rate=0.001)
    
    return tf.argmax(logits, 1), loss, train_op

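conv_model is written with the (feature, target, mode) signature and the (predictions, loss, train_op) return tuple that tf.contrib.learn.Estimator expects from a custom model_fn, but main() below only trains the linear baseline. A hedged sketch of how the convolutional model could be trained the same way (batch_size and steps are illustrative values, not tuned):

# Sketch: wrap conv_model in a contrib.learn Estimator and train it on MNIST.
def train_conv_model(mnist):
    conv_classifier = learn.Estimator(model_fn=conv_model)
    conv_classifier.fit(mnist.train.images,
                        mnist.train.labels.astype(np.int32),
                        batch_size=100,
                        steps=20000)
    predictions = list(conv_classifier.predict(mnist.test.images))
    return metrics.accuracy_score(mnist.test.labels, predictions)
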
In [7]:
def main(unused_args):
    # Download and load MNIST dataset
    mnist = learn.datasets.load_dataset('mnist')
    
    # Linear classifier
    feature_columns = learn.infer_real_valued_columns_from_input(
        mnist.train.images)
    classifier = learn.LinearClassifier(
        feature_columns=feature_columns, n_classes=10)
    classifier.fit(mnist.train.images,
                   mnist.train.labels.astype(np.int32),
                   batch_size=100,
                   steps=1000)
    score = metrics.accuracy_score(mnist.test.labels,
                                   list(classifier.predict(mnist.test.images)))
    
    print('Accuracy: {0:f}'.format(score))

In [8]:
if __name__ == '__main__':
    tf.app.run()


------------------------------------------------------------------------
TimeoutError                           Traceback (most recent call last)
/usr/lib/python3.5/urllib/request.py in do_open(self, http_class, req, **http_conn_args)
   1253             try:
-> 1254                 h.request(req.get_method(), req.selector, req.data, headers)
   1255             except OSError as err: # timeout error

/usr/lib/python3.5/http/client.py in request(self, method, url, body, headers)
   1105         """Send a complete request to the server."""
-> 1106         self._send_request(method, url, body, headers)
   1107 

/usr/lib/python3.5/http/client.py in _send_request(self, method, url, body, headers)
   1150             body = _encode(body, 'body')
-> 1151         self.endheaders(body)
   1152 

/usr/lib/python3.5/http/client.py in endheaders(self, message_body)
   1101             raise CannotSendHeader()
-> 1102         self._send_output(message_body)
   1103 

/usr/lib/python3.5/http/client.py in _send_output(self, message_body)
    933 
--> 934         self.send(msg)
    935         if message_body is not None:

/usr/lib/python3.5/http/client.py in send(self, data)
    876             if self.auto_open:
--> 877                 self.connect()
    878             else:

/usr/lib/python3.5/http/client.py in connect(self)
    848         self.sock = self._create_connection(
--> 849             (self.host,self.port), self.timeout, self.source_address)
    850         self.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

/usr/lib/python3.5/socket.py in create_connection(address, timeout, source_address)
    710     if err is not None:
--> 711         raise err
    712     else:

/usr/lib/python3.5/socket.py in create_connection(address, timeout, source_address)
    701                 sock.bind(source_address)
--> 702             sock.connect(sa)
    703             return sock

TimeoutError: [Errno 110] Connection timed out

During handling of the above exception, another exception occurred:

URLError                               Traceback (most recent call last)
<ipython-input-8-5be245e2ed29> in <module>()
      1 if __name__ == '__main__':
----> 2     tf.app.run()

/home/pageu/.local/lib/python3.5/site-packages/tensorflow/python/platform/app.py in run(main, argv)
     41   # Call the main function, passing through any arguments
     42   # to the final program.
---> 43   sys.exit(main(sys.argv[:1] + flags_passthrough))

<ipython-input-7-1c0d586d02df> in main(unused_args)
      1 def main(unused_args):
      2     # Download and load MNIST dataset
----> 3     mnist = learn.datasets.load_dataset('mnist')
      4 
      5     # Linear classifier

/home/pageu/.local/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/datasets/__init__.py in load_dataset(name, size, test_with_fake_data)
     64     return DATASETS[name](size, test_with_fake_data)
     65   else:
---> 66     return DATASETS[name]()

/home/pageu/.local/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py in load_mnist(train_dir)
    249 
    250 def load_mnist(train_dir='MNIST-data'):
--> 251   return read_data_sets(train_dir)

/home/pageu/.local/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py in read_data_sets(train_dir, fake_data, one_hot, dtype, reshape, validation_size)
    209 
    210   local_file = base.maybe_download(TRAIN_IMAGES, train_dir,
--> 211                                    SOURCE_URL + TRAIN_IMAGES)
    212   with open(local_file, 'rb') as f:
    213     train_images = extract_images(f)

/home/pageu/.local/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py in maybe_download(filename, work_directory, source_url)
    206   filepath = os.path.join(work_directory, filename)
    207   if not gfile.Exists(filepath):
--> 208     temp_file_name, _ = urlretrieve_with_retry(source_url)
    209     gfile.Copy(temp_file_name, filepath)
    210     with gfile.GFile(filepath) as f:

/home/pageu/.local/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py in wrapped_fn(*args, **kwargs)
    163       for delay in delays():
    164         try:
--> 165           return fn(*args, **kwargs)
    166         except Exception as e:  # pylint: disable=broad-except)
    167           if is_retriable is None:

/home/pageu/.local/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py in urlretrieve_with_retry(url, filename)
    188 @retry(initial_delay=1.0, max_delay=16.0, is_retriable=_is_retriable)
    189 def urlretrieve_with_retry(url, filename=None):
--> 190   return urllib.request.urlretrieve(url, filename)
    191 
    192 

/usr/lib/python3.5/urllib/request.py in urlretrieve(url, filename, reporthook, data)
    186     url_type, path = splittype(url)
    187 
--> 188     with contextlib.closing(urlopen(url, data)) as fp:
    189         headers = fp.info()
    190 

/usr/lib/python3.5/urllib/request.py in urlopen(url, data, timeout, cafile, capath, cadefault, context)
    161     else:
    162         opener = _opener
--> 163     return opener.open(url, data, timeout)
    164 
    165 def install_opener(opener):

/usr/lib/python3.5/urllib/request.py in open(self, fullurl, data, timeout)
    464             req = meth(req)
    465 
--> 466         response = self._open(req, data)
    467 
    468         # post-process response

/usr/lib/python3.5/urllib/request.py in _open(self, req, data)
    482         protocol = req.type
    483         result = self._call_chain(self.handle_open, protocol, protocol +
--> 484                                   '_open', req)
    485         if result:
    486             return result

/usr/lib/python3.5/urllib/request.py in _call_chain(self, chain, kind, meth_name, *args)
    442         for handler in handlers:
    443             func = getattr(handler, meth_name)
--> 444             result = func(*args)
    445             if result is not None:
    446                 return result

/usr/lib/python3.5/urllib/request.py in http_open(self, req)
   1280 
   1281     def http_open(self, req):
-> 1282         return self.do_open(http.client.HTTPConnection, req)
   1283 
   1284     http_request = AbstractHTTPHandler.do_request_

/usr/lib/python3.5/urllib/request.py in do_open(self, http_class, req, **http_conn_args)
   1254                 h.request(req.get_method(), req.selector, req.data, headers)
   1255             except OSError as err: # timeout error
-> 1256                 raise URLError(err)
   1257             r = h.getresponse()
   1258         except:

URLError: <urlopen error [Errno 110] Connection timed out>

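The traceback above is a network failure, not a bug in the model code: learn.datasets.load_dataset('mnist') tries to download the MNIST archives and the connection times out. As the frames from datasets/mnist.py and datasets/base.py show, maybe_download skips the download when the files already exist in the target directory, so one workaround is to fetch the four archives by hand (train-images-idx3-ubyte.gz, train-labels-idx1-ubyte.gz, t10k-images-idx3-ubyte.gz, t10k-labels-idx1-ubyte.gz) and point read_data_sets at that directory. A minimal sketch, assuming the files were placed in ./MNIST-data:

# Load MNIST from a local directory; with the .gz archives already present,
# read_data_sets() finds them on disk and never touches the network.
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
mnist = read_data_sets('MNIST-data')
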
In [ ]: