Typically we use 3 files:
In [1]:
#! /usr/bin/env python
import tensorflow as tf
# Access to the data
def get_data(data_dir='/tmp/MNIST_data'):
    """Fetch the MNIST dataset with one-hot encoded labels.

    Args:
        data_dir: directory where the raw MNIST files are cached/downloaded.

    Returns:
        The TF tutorial ``Datasets`` object (train / validation / test splits).
    """
    # Local import keeps the tutorial dependency off the module import path.
    from tensorflow.examples.tutorials.mnist import input_data
    dataset = input_data.read_data_sets(data_dir, one_hot=True)
    return dataset
#Batch generator
def batch_generator(mnist, batch_size=256, type='train'):
    """Return one (images, labels) batch from the requested split.

    Args:
        mnist: dataset object exposing ``train`` and ``test`` splits,
            each with a ``next_batch(batch_size)`` method.
        batch_size: number of examples per batch.
        type: ``'train'`` selects the training split; any other value
            selects the test split. (Name shadows the builtin, but it is
            part of the caller-visible keyword interface, so it stays.)
    """
    split = mnist.train if type == 'train' else mnist.test
    return split.next_batch(batch_size)
In [2]:
#! /usr/bin/env python
import tensorflow as tf
class mnistCNN(object):
    """A two-layer fully-connected network for MNIST classification.

    (Despite the class name there are no convolutional layers; this is a
    plain MLP: 784 -> dense -> 10.)
    """

    def __init__(self, dense=500):
        # Placeholders for input images (flattened 28x28) and one-hot labels.
        self.input_x = tf.placeholder(tf.float32, [None, 784], name="input_x")
        self.input_y = tf.placeholder(tf.float32, [None, 10], name="input_y")
        # Hidden layer with ReLU (default activation).
        self.dense_1 = self.dense_layer(self.input_x, input_dim=784, output_dim=dense)
        # Output layer is LINEAR so dense_2 holds raw logits.
        # BUG FIX: the original applied ReLU here, which zeroes every negative
        # logit before the softmax cross-entropy and cripples training.
        self.dense_2 = self.dense_layer(self.dense_1, input_dim=dense, output_dim=10,
                                        name='output', activation=None)
        self.predictions = tf.argmax(self.dense_2, 1, name="predictions")
        # Loss function.
        # BUG FIX: TF 1.x requires keyword arguments for labels/logits; the
        # positional form used originally is not accepted by the current API.
        self.loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(labels=self.input_y,
                                                    logits=self.dense_2))
        # Fraction of the batch where argmax(prediction) == argmax(label).
        correct_predictions = tf.equal(self.predictions, tf.argmax(self.input_y, 1))
        self.accuracy = tf.reduce_mean(tf.cast(correct_predictions, "float"),
                                       name="accuracy")

    def dense_layer(self, x, input_dim=10, output_dim=10, name='dense',
                    activation='relu'):
        '''
        Dense (fully-connected) layer.
        Inputs:
            x: Input tensor of shape [batch, input_dim].
            input_dim: Dimension of the input tensor.
            output_dim: Dimension of the output tensor.
            name: Layer name (suffix for the variable names).
            activation: 'relu' (default, preserves the original behaviour)
                applies a ReLU; None returns the raw affine output (logits).
        Returns:
            Output tensor of shape [batch, output_dim].
        '''
        W = tf.Variable(tf.truncated_normal([input_dim, output_dim], stddev=0.1),
                        name='W_' + name)
        b = tf.Variable(tf.constant(0.1, shape=[output_dim]), name='b_' + name)
        affine = tf.matmul(x, W) + b
        if activation == 'relu':
            return tf.nn.relu(affine)
        return affine
In [3]:
#! /usr/bin/env python
from __future__ import print_function
import tensorflow as tf
#from data_utils import get_data, batch_generator
#from model_mnist_cnn import mnistCNN

# Parameters
# ==================================================

# Data loading params
tf.flags.DEFINE_string("data_directory", '/tmp/MNIST_data', "Data dir (default /tmp/MNIST_data)")

# Model Hyperparameters
tf.flags.DEFINE_integer("dense_size", 500, "dense_size (default 500)")

# Training parameters
tf.flags.DEFINE_float("learning_rate", 0.001, "learning rate (default: 0.001)")
tf.flags.DEFINE_integer("batch_size", 256, "Batch Size (default: 256)")
tf.flags.DEFINE_integer("num_epochs", 20, "Number of training epochs (default: 20)")

# Misc Parameters
tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices")

FLAGS = tf.flags.FLAGS
# NOTE(review): _parse_flags() and __flags are private TF APIs that were
# removed in later 1.x releases -- confirm the pinned TF version supports them.
FLAGS._parse_flags()
print("\nParameters:")
for attr, value in sorted(FLAGS.__flags.items()):
    print("{}={}".format(attr.upper(), value))
print("")

# Data Preparation
# ==================================================
mnist_data = get_data(data_dir=FLAGS.data_directory)

# Training
# ==================================================
# Cap per-process GPU memory so several jobs can share one card.
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333, allow_growth=True)
with tf.Graph().as_default():
    session_conf = tf.ConfigProto(
        gpu_options=gpu_options,
        log_device_placement=FLAGS.log_device_placement)
    sess = tf.Session(config=session_conf)
    with sess.as_default():
        # Create model
        cnn = mnistCNN(dense=FLAGS.dense_size)
        # Trainer
        train_op = tf.train.AdamOptimizer(FLAGS.learning_rate).minimize(cnn.loss)
        # Saver: keep only the most recent checkpoint.
        saver = tf.train.Saver(max_to_keep=1)
        # Initialize all variables
        sess.run(tf.global_variables_initializer())
        # Train process.
        # FIX: derive batches-per-epoch from the dataset instead of the
        # hard-coded 55000 training-set size (and hoist it out of the loop).
        batches_per_epoch = mnist_data.train.num_examples // FLAGS.batch_size
        for epoch in range(FLAGS.num_epochs):
            for n_batch in range(batches_per_epoch):
                batch = batch_generator(mnist_data, batch_size=FLAGS.batch_size, type='train')
                _, ce = sess.run([train_op, cnn.loss],
                                 feed_dict={cnn.input_x: batch[0], cnn.input_y: batch[1]})
            # Report the cross-entropy of the last batch of this epoch.
            # (Indentation was lost in the source; per-epoch reporting assumed.)
            print(epoch, ce)
        model_file = saver.save(sess, '/tmp/mnist_model')
        print('Model saved in', model_file)