Standard usage of TensorFlow with a model class

Typically, three files are used:

  • data_utils.py: data access and batch-generator functions.
  • model.py: the model class; the constructor defines the graph, and methods cover whatever else the model needs.
  • train.py: the parameters; it loads the data, instantiates the model, and trains it. Optionally, add a flag to switch between training and inference.

data_utils.py


In [1]:
#! /usr/bin/env python

# Access to the data
def get_data(data_dir='/tmp/MNIST_data'):
    from tensorflow.examples.tutorials.mnist import input_data
    return input_data.read_data_sets(data_dir, one_hot=True)


# Batch generator
def batch_generator(mnist, batch_size=256, dataset='train'):
    if dataset == 'train':
        return mnist.train.next_batch(batch_size)
    else:
        return mnist.test.next_batch(batch_size)
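
As a quick check, the generator returns an (images, labels) tuple of NumPy arrays. A minimal usage sketch, assuming the cell above is saved as data_utils.py (this is not one of the three files):

# Hypothetical usage sketch for data_utils.py
from data_utils import get_data, batch_generator

mnist = get_data()  # downloads/extracts MNIST into /tmp/MNIST_data
images, labels = batch_generator(mnist, batch_size=32, dataset='train')
print(images.shape)  # (32, 784): flattened 28x28 grayscale images
print(labels.shape)  # (32, 10): one-hot digit labels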

model_mnist_cnn.py


In [2]:
#! /usr/bin/env python

import tensorflow as tf

class mnistCNN(object):
    """
    A NN for mnist classification.
    """
    def __init__(self, dense=500):
    
        # Placeholders for input, output and dropout
        self.input_x = tf.placeholder(tf.float32, [None, 784], name="input_x")
        self.input_y = tf.placeholder(tf.float32, [None, 10], name="input_y")
    
        # First layer
        self.dense_1 = self.dense_layer(self.input_x, input_dim=784, output_dim=dense)

        # Final layer (note: dense_layer applies ReLU, so these logits are non-negative)
        self.dense_2 = self.dense_layer(self.dense_1, input_dim=dense, output_dim=10)

        self.predictions = tf.argmax(self.dense_2, 1, name="predictions")

        # Loss function: softmax cross-entropy requires named arguments in TF 1.x
        self.loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(logits=self.dense_2, labels=self.input_y))
        
        # Accuracy
        correct_predictions = tf.equal(self.predictions, tf.argmax(self.input_y, 1))
        self.accuracy = tf.reduce_mean(tf.cast(correct_predictions, "float"), name="accuracy")
    

    def dense_layer(self, x, input_dim=10, output_dim=10, name='dense'):
        '''
        Dense layer with ReLU activation.
        Inputs:
          x: input tensor
          input_dim: dimension of the input tensor
          output_dim: dimension of the output tensor
          name: layer name
        '''
        W = tf.Variable(tf.truncated_normal([input_dim, output_dim], stddev=0.1), name='W_' + name)
        b = tf.Variable(tf.constant(0.1, shape=[output_dim]), name='b_' + name)
        return tf.nn.relu(tf.matmul(x, W) + b)
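
As a sanity check, instantiating the class is enough to build the whole graph, and the named tensors can be inspected before any training. A minimal smoke-test sketch, assuming the cell above is saved as model_mnist_cnn.py (not one of the three files):

# Hypothetical smoke test for the model class
import tensorflow as tf
from model_mnist_cnn import mnistCNN

with tf.Graph().as_default():
    cnn = mnistCNN(dense=500)
    print(cnn.input_x)      # Tensor("input_x:0", shape=(?, 784), dtype=float32)
    print(cnn.predictions)  # Tensor("predictions:0", shape=(?,), dtype=int64)
    print(cnn.loss)         # scalar cross-entropy loss tensor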

train.py


In [3]:
#! /usr/bin/env python

from __future__ import print_function

import tensorflow as tf

# Defined in the cells above; uncomment when running as a standalone script:
#from data_utils import get_data, batch_generator
#from model_mnist_cnn import mnistCNN


# Parameters
# ==================================================

# Data loading params
tf.flags.DEFINE_string("data_directory", '/tmp/MNIST_data', "Data dir (default /tmp/MNIST_data)")

# Model Hyperparameters
tf.flags.DEFINE_integer("dense_size", 500, "dense_size (default 500)")

# Training parameters
tf.flags.DEFINE_float("learning_rate", 0.001, "learning rate (default: 0.001)")
tf.flags.DEFINE_integer("batch_size", 256, "Batch Size (default: 256)")
tf.flags.DEFINE_integer("num_epochs", 20, "Number of training epochs (default: 20)")

# Misc Parameters
tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices")

FLAGS = tf.flags.FLAGS
FLAGS._parse_flags()
print("\nParameters:")
for attr, value in sorted(FLAGS.__flags.items()):
    print("{}={}".format(attr.upper(), value))
print("")


# Data Preparation
# ==================================================

# Access to the data
mnist_data = get_data(data_dir=FLAGS.data_directory)


# Training
# ==================================================

gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333, allow_growth=True)
with tf.Graph().as_default():
    session_conf = tf.ConfigProto(
        gpu_options=gpu_options,
        log_device_placement=FLAGS.log_device_placement)
    sess = tf.Session(config=session_conf)
    with sess.as_default():
        
        # Create model
        cnn = mnistCNN(dense=FLAGS.dense_size)
        
        # Trainer
        train_op = tf.train.AdamOptimizer(FLAGS.learning_rate).minimize(cnn.loss)

        # Saver
        saver = tf.train.Saver(max_to_keep=1)

        # Initialize all variables
        sess.run(tf.global_variables_initializer())

        # Training loop
        for epoch in range(FLAGS.num_epochs):
            for n_batch in range(int(55000 / FLAGS.batch_size)):  # 55000 = size of the MNIST training set
                batch = batch_generator(mnist_data, batch_size=FLAGS.batch_size, dataset='train')
                _, ce = sess.run([train_op, cnn.loss],
                                 feed_dict={cnn.input_x: batch[0], cnn.input_y: batch[1]})

            print(epoch, ce)  # cross-entropy of the last batch of the epoch
        model_file = saver.save(sess, '/tmp/mnist_model')
        print('Model saved in', model_file)


Parameters:
BATCH_SIZE=256
DATA_DIRECTORY=/tmp/MNIST_data
DENSE_SIZE=500
LEARNING_RATE=0.001
LOG_DEVICE_PLACEMENT=False
NUM_EPOCHS=20

Extracting /tmp/MNIST_data/train-images-idx3-ubyte.gz
Extracting /tmp/MNIST_data/train-labels-idx1-ubyte.gz
Extracting /tmp/MNIST_data/t10k-images-idx3-ubyte.gz
Extracting /tmp/MNIST_data/t10k-labels-idx1-ubyte.gz
0 0.91909
1 0.881748
2 0.777025
3 0.937195
4 0.783779
5 0.353136
6 0.256104
7 0.28514
8 0.277642
9 0.225344
10 0.301154
11 0.249453
12 0.324219
13 0.202852
14 0.244397
15 0.211011
16 0.199964
17 0.246017
18 0.289519
19 0.280718
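
Since train.py saves the checkpoint to /tmp/mnist_model (and TensorFlow writes the matching .meta graph file next to it), inference can run later in a fresh process by restoring the graph and feeding the tensors by name, as suggested in the file list above. A minimal sketch of such an inference step, assuming the checkpoint produced above exists:

# Hypothetical inference script: restore the trained model and evaluate on the test set
import tensorflow as tf

from data_utils import get_data

mnist = get_data()

with tf.Session() as sess:
    # Rebuild the graph from the .meta file, then load the trained weights
    saver = tf.train.import_meta_graph('/tmp/mnist_model.meta')
    saver.restore(sess, '/tmp/mnist_model')

    graph = tf.get_default_graph()
    input_x = graph.get_tensor_by_name('input_x:0')
    input_y = graph.get_tensor_by_name('input_y:0')
    accuracy = graph.get_tensor_by_name('accuracy:0')

    # The MNIST test set is small enough to evaluate in a single feed
    acc = sess.run(accuracy, feed_dict={input_x: mnist.test.images,
                                        input_y: mnist.test.labels})
    print('Test accuracy:', acc)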