google的东西,先把版权放上来


In [1]:
# Copyright 2017 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

In [1]:
#导入各种包
import os
import os.path
import shutil
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

In [2]:
#设定目录和标签图,标签用来后面出数据图
LOGDIR = "D:\\temp\\"
LABELS = os.path.join(os.getcwd(), "labels_1024.tsv")
SPRITES =  os.path.join(os.getcwd(), "sprite_1024.png")

In [3]:
mnist =input_data.read_data_sets("/tmp/data/", one_hot=True)


---------------------------------------------------------------------------
ConnectionAbortedError                    Traceback (most recent call last)
F:\Anaconda\lib\urllib\request.py in do_open(self, http_class, req, **http_conn_args)
   1317                 h.request(req.get_method(), req.selector, req.data, headers,
-> 1318                           encode_chunked=req.has_header('Transfer-encoding'))
   1319             except OSError as err: # timeout error

F:\Anaconda\lib\http\client.py in request(self, method, url, body, headers, encode_chunked)
   1238         """Send a complete request to the server."""
-> 1239         self._send_request(method, url, body, headers, encode_chunked)
   1240 

F:\Anaconda\lib\http\client.py in _send_request(self, method, url, body, headers, encode_chunked)
   1284             body = _encode(body, 'body')
-> 1285         self.endheaders(body, encode_chunked=encode_chunked)
   1286 

F:\Anaconda\lib\http\client.py in endheaders(self, message_body, encode_chunked)
   1233             raise CannotSendHeader()
-> 1234         self._send_output(message_body, encode_chunked=encode_chunked)
   1235 

F:\Anaconda\lib\http\client.py in _send_output(self, message_body, encode_chunked)
   1025         del self._buffer[:]
-> 1026         self.send(msg)
   1027 

F:\Anaconda\lib\http\client.py in send(self, data)
    963             if self.auto_open:
--> 964                 self.connect()
    965             else:

F:\Anaconda\lib\http\client.py in connect(self)
   1399             self.sock = self._context.wrap_socket(self.sock,
-> 1400                                                   server_hostname=server_hostname)
   1401             if not self._context.check_hostname and self._check_hostname:

F:\Anaconda\lib\ssl.py in wrap_socket(self, sock, server_side, do_handshake_on_connect, suppress_ragged_eofs, server_hostname, session)
    406                          server_hostname=server_hostname,
--> 407                          _context=self, _session=session)
    408 

F:\Anaconda\lib\ssl.py in __init__(self, sock, keyfile, certfile, server_side, cert_reqs, ssl_version, ca_certs, do_handshake_on_connect, family, type, proto, fileno, suppress_ragged_eofs, npn_protocols, ciphers, server_hostname, _context, _session)
    813                         raise ValueError("do_handshake_on_connect should not be specified for non-blocking sockets")
--> 814                     self.do_handshake()
    815 

F:\Anaconda\lib\ssl.py in do_handshake(self, block)
   1067                 self.settimeout(None)
-> 1068             self._sslobj.do_handshake()
   1069         finally:

F:\Anaconda\lib\ssl.py in do_handshake(self)
    688         """Start the SSL/TLS handshake."""
--> 689         self._sslobj.do_handshake()
    690         if self.context.check_hostname:

ConnectionAbortedError: [WinError 10053] 你的主机中的软件中止了一个已建立的连接。

During handling of the above exception, another exception occurred:

URLError                                  Traceback (most recent call last)
<ipython-input-3-118ccd0cf8c9> in <module>()
----> 1 mnist =input_data.read_data_sets("/tmp/data/", one_hot=True)

F:\Anaconda\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py in read_data_sets(train_dir, fake_data, one_hot, dtype, reshape, validation_size, seed)
    233 
    234   local_file = base.maybe_download(TRAIN_IMAGES, train_dir,
--> 235                                    SOURCE_URL + TRAIN_IMAGES)
    236   with open(local_file, 'rb') as f:
    237     train_images = extract_images(f)

F:\Anaconda\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\base.py in maybe_download(filename, work_directory, source_url)
    206   filepath = os.path.join(work_directory, filename)
    207   if not gfile.Exists(filepath):
--> 208     temp_file_name, _ = urlretrieve_with_retry(source_url)
    209     gfile.Copy(temp_file_name, filepath)
    210     with gfile.GFile(filepath) as f:

F:\Anaconda\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\base.py in wrapped_fn(*args, **kwargs)
    163       for delay in delays():
    164         try:
--> 165           return fn(*args, **kwargs)
    166         except Exception as e:  # pylint: disable=broad-except)
    167           if is_retriable is None:

F:\Anaconda\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\base.py in urlretrieve_with_retry(url, filename)
    188 @retry(initial_delay=1.0, max_delay=16.0, is_retriable=_is_retriable)
    189 def urlretrieve_with_retry(url, filename=None):
--> 190   return urllib.request.urlretrieve(url, filename)
    191 
    192 

F:\Anaconda\lib\urllib\request.py in urlretrieve(url, filename, reporthook, data)
    246     url_type, path = splittype(url)
    247 
--> 248     with contextlib.closing(urlopen(url, data)) as fp:
    249         headers = fp.info()
    250 

F:\Anaconda\lib\urllib\request.py in urlopen(url, data, timeout, cafile, capath, cadefault, context)
    221     else:
    222         opener = _opener
--> 223     return opener.open(url, data, timeout)
    224 
    225 def install_opener(opener):

F:\Anaconda\lib\urllib\request.py in open(self, fullurl, data, timeout)
    524             req = meth(req)
    525 
--> 526         response = self._open(req, data)
    527 
    528         # post-process response

F:\Anaconda\lib\urllib\request.py in _open(self, req, data)
    542         protocol = req.type
    543         result = self._call_chain(self.handle_open, protocol, protocol +
--> 544                                   '_open', req)
    545         if result:
    546             return result

F:\Anaconda\lib\urllib\request.py in _call_chain(self, chain, kind, meth_name, *args)
    502         for handler in handlers:
    503             func = getattr(handler, meth_name)
--> 504             result = func(*args)
    505             if result is not None:
    506                 return result

F:\Anaconda\lib\urllib\request.py in https_open(self, req)
   1359         def https_open(self, req):
   1360             return self.do_open(http.client.HTTPSConnection, req,
-> 1361                 context=self._context, check_hostname=self._check_hostname)
   1362 
   1363         https_request = AbstractHTTPHandler.do_request_

F:\Anaconda\lib\urllib\request.py in do_open(self, http_class, req, **http_conn_args)
   1318                           encode_chunked=req.has_header('Transfer-encoding'))
   1319             except OSError as err: # timeout error
-> 1320                 raise URLError(err)
   1321             r = h.getresponse()
   1322         except:

URLError: <urlopen error [WinError 10053] 你的主机中的软件中止了一个已建立的连接。>

In [5]:
# 定义conv layer
def conv_layer(input, size_in, size_out, name="conv"):
      with tf.name_scope(name):
        w = tf.Variable(tf.truncated_normal([5, 5, size_in, size_out], stddev=0.1), name="W")
        b = tf.Variable(tf.constant(0.1, shape=[size_out]), name="B")
        conv = tf.nn.conv2d(input, w, strides=[1, 1, 1, 1], padding="SAME")
        act = tf.nn.relu(conv + b)
        tf.summary.histogram("weights", w)
        tf.summary.histogram("biases", b)
        tf.summary.histogram("activations", act)
        return tf.nn.max_pool(act, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")

In [6]:
#定义 fc layer
def fc_layer(input, size_in, size_out, name="fc"):
      with tf.name_scope(name):
        w = tf.Variable(tf.truncated_normal([size_in, size_out], stddev=0.1), name="W")
        b = tf.Variable(tf.constant(0.1, shape=[size_out]), name="B")
        act = tf.matmul(input, w) + b
        tf.summary.histogram("weights", w)
        tf.summary.histogram("biases", b)
        tf.summary.histogram("activations", act)
        return act

In [7]:
# 定义 model
def mnist_model(learning_rate, use_two_fc, use_two_conv, hparam):
    tf.reset_default_graph()
    sess = tf.Session()

    # Setup placeholders, and reshape the data
    x = tf.placeholder(tf.float32, shape=[None, 784], name="x")
    x_image = tf.reshape(x, [-1, 28, 28, 1])
    tf.summary.image('input', x_image, 3)
    y = tf.placeholder(tf.float32, shape=[None, 10], name="labels")

    if use_two_conv:
    conv1 = conv_layer(x_image, 1, 32, "conv1")
    conv_out = conv_layer(conv1, 32, 64, "conv2")
    else:
    conv1 = conv_layer(x_image, 1, 64, "conv")
    conv_out = tf.nn.max_pool(conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")

    flattened = tf.reshape(conv_out, [-1, 7 * 7 * 64])


    if use_two_fc:
        fc1 = fc_layer(flattened, 7 * 7 * 64, 1024, "fc1")
        relu = tf.nn.relu(fc1)
        embedding_input = relu
        tf.summary.histogram("fc1/relu", relu)
        embedding_size = 1024
        logits = fc_layer(fc1, 1024, 10, "fc2")
    else:
        embedding_input = flattened
        embedding_size = 7*7*64
        logits = fc_layer(flattened, 7*7*64, 10, "fc")

    with tf.name_scope("xent"):
        xent = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(
                logits=logits, labels=y), name="xent")
        tf.summary.scalar("xent", xent)

    with tf.name_scope("train"):
        train_step = tf.train.AdamOptimizer(learning_rate).minimize(xent)

    with tf.name_scope("accuracy"):
        correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(y, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        tf.summary.scalar("accuracy", accuracy)

    summ = tf.summary.merge_all()


    embedding = tf.Variable(tf.zeros([1024, embedding_size]), name="test_embedding")
    assignment = embedding.assign(embedding_input)
    saver = tf.train.Saver()

    sess.run(tf.global_variables_initializer())
    writer = tf.summary.FileWriter(LOGDIR + hparam)
    writer.add_graph(sess.graph)

    config = tf.contrib.tensorboard.plugins.projector.ProjectorConfig()
    embedding_config = config.embeddings.add()
    embedding_config.tensor_name = embedding.name
    embedding_config.sprite.image_path = SPRITES
    embedding_config.metadata_path = LABELS
    # Specify the width and height of a single thumbnail.
    embedding_config.sprite.single_image_dim.extend([28, 28])
    tf.contrib.tensorboard.plugins.projector.visualize_embeddings(writer, config)

    for i in range(2001):
        batch = mnist.train.next_batch(100)
        if i % 5 == 0:
            [train_accuracy, s] = sess.run([accuracy, summ], feed_dict={x: batch[0], y: batch[1]})
        if i % 500 == 0:
            sess.run(assignment, feed_dict={x: mnist.test.images[:1024], y: mnist.test.labels[:1024]})
            saver.save(sess, os.path.join(LOGDIR, "model.ckpt"), i)
        sess.run(train_step, feed_dict={x: batch[0], y: batch[1]})

In [8]:
# 定义参数组合,用于后面看图
def make_hparam_string(learning_rate, use_two_fc, use_two_conv):
    conv_param = "conv=2" if use_two_conv else "conv=1"
    fc_param = "fc=2" if use_two_fc else "fc=1"
    return "lr_%.0E,%s,%s" % (learning_rate, conv_param, fc_param)

In [9]:
# 定义主函数,主要包括训练等
def main():
  # You can try adding some more learning rates
  for learning_rate in [1E-3,1E-4]:

    # Include "False" as a value to try different model architectures
    for use_two_fc in [True]:
        for use_two_conv in [True,False]:
        # Construct a hyperparameter string for each one (example: "lr_1E-3,fc=2,conv=2")
        hparam = make_hparam_string(learning_rate, use_two_fc, use_two_conv)
        print('Starting run for %s' % hparam)

        # Actually run with the new settings
        mnist_model(learning_rate, use_two_fc, use_two_conv, hparam)
    print('Done training!')
    print('Run `tensorboard --logdir=%s` to see the results.' % LOGDIR)
    print('Running on mac? If you want to get rid of the dialogue asking to give '
        'network permissions to TensorBoard, you can provide this flag: '
        '--host=localhost')

In [10]:
#运行
main()


Starting run for lr_1E-03,conv=2,fc=2
Done training!
Run `tensorboard --logdir=D:\temp\3\` to see the results.
Running on mac? If you want to get rid of the dialogue asking to give network permissions to TensorBoard, you can provide this flag: --host=localhost