In [ ]:
import glob
import os
import pickle
import random
import shutil

import numpy as np
import scipy.io
import scipy.misc as misc
import tensorflow as tf

import BatchDatsetReader as dataset

# Reset the default graph and the global tf.flags state so this cell can be
# re-run inside a live notebook kernel. The argparse-based parser behind
# tf.flags would reject re-adding an option that is already registered, so both
# the FlagValues object and the parser are replaced with fresh instances before
# any flag is defined.
# NOTE(review): this pokes at private TF internals (`_FlagValues`,
# `_global_parser`) of the pre-1.0 tf.flags module — TODO confirm against the
# installed TensorFlow version.
import argparse

tf.reset_default_graph()
tf.flags.FLAGS = tf.python.platform.flags._FlagValues()
tf.flags._global_parser = argparse.ArgumentParser()

# Configuration flags. Defaults are written as strings; presumably the
# argparse-based parser converts them with each flag's declared type
# (e.g. "1e-4" -> 1e-4, "True" -> True) — TODO confirm for the installed TF.
FLAGS = tf.flags.FLAGS
tf.flags.DEFINE_integer("batch_size","2","batch size for training")
tf.flags.DEFINE_string("logs_dir","logs/","path to logs directory")
tf.flags.DEFINE_string("data_dir","Data_zoo/MIT_SceneParsing/","path to dataset")
tf.flags.DEFINE_string("pickle_name","MITSceneParsing.pickle","pickle file of the data")
tf.flags.DEFINE_string("data_url","http://sceneparsing.csail.mit.edu/data/ADEChallengeData2016.zip","url of the data")
tf.flags.DEFINE_float("learning_rate","1e-4","learning rate for the optimizer")
tf.flags.DEFINE_string("model_dir","Model_zoo/","path to vgg model mat")
tf.flags.DEFINE_bool("debug","True","Debug mode: True/False")
tf.flags.DEFINE_string("mode","train","Mode: train/ valid")
# Declared as a float so "1e+5" parses; convert with int() before using it as a count.
tf.flags.DEFINE_float("max_iters","1e+5","max training iterations of batches")
tf.flags.DEFINE_integer("num_classes","151","mit_sceneparsing with (150+1) classes")
tf.flags.DEFINE_integer("image_size","224","can be variable indeed")
tf.flags.DEFINE_string("model_weights","http://www.vlfeat.org/matconvnet/models/beta16/imagenet-vgg-verydeep-19.mat","pretrained weights of the CNN in use")
tf.flags.DEFINE_string("full_model","full_model/","trained parameters of the whole network")
# Checkpoint restored only when --load is true.
tf.flags.DEFINE_string("full_model_file"," ","pretrained parameters of the whole network")
tf.flags.DEFINE_bool("load","False","load in pretrained parameters")

# Make sure the directory that caches the pretrained CNN weights exists.
if not os.path.exists(FLAGS.model_dir):
    os.makedirs(FLAGS.model_dir)

# Download the MatConvNet .mat weights on first use. The local file name is the
# last path component of the URL.
weights_file = os.path.join(FLAGS.model_dir,FLAGS.model_weights.split('/')[-1])
if not os.path.exists(weights_file):
    print("\ndownloading "+weights_file+" ...")
    status = os.system("wget "+FLAGS.model_weights+" -P "+FLAGS.model_dir)
    if status != 0 or not os.path.exists(weights_file):
        # Remove a partial download so the next run retries instead of skipping
        # the download and failing inside loadmat() with an obscure error.
        if os.path.exists(weights_file):
            os.remove(weights_file)
        raise IOError("failed to download "+FLAGS.model_weights+" into "+FLAGS.model_dir)
    print("download finished!\n")
else:
    print("\n"+weights_file+" has already been downloaded.\n")

# Load the whole .mat file into a dict of numpy arrays.
print("\nloading pretrained weights from: "+weights_file)
pretrain_weights = scipy.io.loadmat(weights_file)
print("loading finished!\n")

# Mean colour used for input normalisation. The original note says this entry
# has shape (224, 224, 3) — presumably the model's average image; TODO confirm.
# Averaging over the first two axes leaves one mean value per colour channel.
mean = pretrain_weights['normalization'][0][0][0]
mean_pixel = np.mean(mean,axis=(0,1))

# Per-layer parameter records of the pretrained net. The network construction
# indexes this array by the position of each layer name in its layer list.
weights_biases = np.squeeze(pretrain_weights['layers'])

# Graph inputs, all fed explicitly at sess.run() time:
#   images       - float32 RGB batch in NHWC layout; batch, height and width free
#   annotations  - uint8 label maps with a single trailing channel
#   dropout_prob - keep probability handed to the dropout layers
images = tf.placeholder(dtype=tf.float32, shape=[None, None, None, 3], name="input_images")
annotations = tf.placeholder(dtype=tf.uint8, shape=[None, None, None, 1], name="input_annotations")
dropout_prob = tf.placeholder(dtype=tf.float32, name="dropout_probability")

# Centre the input on the pretrained model's per-channel mean colour.
processed_image = images - mean_pixel

# Build the FCN-8s-style segmentation network inside the "semantic_seg"
# variable scope. Variable names such as "semantic_seg/conv1_1_w" or
# "semantic_seg/fcn6_w" are the keys tf.train.Saver writes to and restores from
# checkpoints, so they must stay stable.
with tf.variable_scope("semantic_seg"):
    # ---- VGG-19 convolutional trunk, initialised from the MatConvNet weights ----
    # `net` maps each layer name to its output tensor so the decoder can reuse
    # intermediate feature maps (pool3, pool4) as skip connections.
    net = {}
    layers = [
        'conv1_1','relu1_1','conv1_2','relu1_2','pool1',
        'conv2_1','relu2_1','conv2_2','relu2_2','pool2',
        'conv3_1','relu3_1','conv3_2','relu3_2','conv3_3','relu3_3','conv3_4','relu3_4','pool3',
        'conv4_1','relu4_1','conv4_2','relu4_2','conv4_3','relu4_3','conv4_4','relu4_4','pool4',
        'conv5_1','relu5_1','conv5_2','relu5_2','conv5_3','relu5_3','conv5_4','relu5_4','pool5'
    ]
    current = processed_image
    for i,name in enumerate(layers):
        # The layer kind is the first four characters of its name. Deliberately
        # not called `type`: a module-level `type` would shadow the builtin for
        # every later cell of the notebook.
        layer_type = name[:4]
        if layer_type == 'conv':
            # The code assumes entry i of the MatConvNet layer array lines up with
            # position i of `layers` (relu/pool entries included).
            # matconvnet weights: (width, height, in_channels, out_channels)
            # tensorflow weights: (height, width, in_channels, out_channels)
            weights, biases = weights_biases[i][0][0][0][0]
            weights = np.transpose(weights,(1,0,2,3))
            biases = np.squeeze(biases)

            init = tf.constant_initializer(weights,dtype=tf.float32)
            weights = tf.get_variable(initializer=init,shape=weights.shape,name=name+"_w")

            init = tf.constant_initializer(biases,dtype=tf.float32)
            biases = tf.get_variable(initializer=init,shape=biases.shape,name=name+"_b")

            current = tf.nn.conv2d(current,weights,strides=[1,1,1,1],padding="SAME")
            current = tf.nn.bias_add(current,biases,name=name)
        elif layer_type == 'relu':
            current = tf.nn.relu(current,name=name)
            if FLAGS.debug:
                tf.histogram_summary(current.op.name+"/activation",current)
                tf.scalar_summary(current.op.name+"/sparsity",tf.nn.zero_fraction(current))
        elif layer_type == 'pool':
            # 2x2 windows with stride 2 halve the spatial size. pool1-pool4 use
            # average pooling; only pool5 uses max pooling.
            if name == 'pool5':
                current = tf.nn.max_pool(current,ksize=[1,2,2,1],strides=[1,2,2,1],padding="SAME",name=name)
            else:
                current = tf.nn.avg_pool(current,ksize=[1,2,2,1],strides=[1,2,2,1],padding="SAME",name=name)
        net[name] = current

    # ---- fcn6: 7x7 convolution, 512 -> 4096 channels, randomly initialised ----
    init = tf.truncated_normal(shape=[7,7,512,4096],stddev=0.02)
    fcn6_w = tf.get_variable(initializer=init,name="fcn6_w")

    init = tf.constant(0.0,shape=[4096])
    fcn6_b = tf.get_variable(initializer=init,name="fcn6_b")

    fcn6 = tf.nn.conv2d(current,fcn6_w,strides=[1,1,1,1],padding="SAME")
    fcn6 = tf.nn.bias_add(fcn6,fcn6_b,name="fcn6")

    relu6 = tf.nn.relu(fcn6,name="relu6")
    if FLAGS.debug:
        tf.histogram_summary("relu6/activation", relu6, collections=None, name=None)
        tf.scalar_summary("relu6/sparsity", tf.nn.zero_fraction(relu6), collections=None, name=None)
    # dropout_prob is the keep probability fed at run time.
    dropout6 = tf.nn.dropout(relu6, keep_prob=dropout_prob, noise_shape=None, seed=None, name="dropout6")

    # ---- fcn7: 1x1 convolution, 4096 -> 4096 channels ----
    init = tf.truncated_normal(shape=[1,1,4096,4096],stddev=0.02)
    fcn7_w = tf.get_variable(initializer=init,name="fcn7_w")

    init = tf.constant(0.0,shape=[4096])
    fcn7_b = tf.get_variable(initializer=init,name="fcn7_b")

    fcn7 = tf.nn.conv2d(dropout6, fcn7_w, strides=[1,1,1,1], padding="SAME", use_cudnn_on_gpu=None, data_format=None, name=None)
    fcn7 = tf.nn.bias_add(fcn7, fcn7_b, data_format=None, name="fcn7")

    relu7 = tf.nn.relu(fcn7,name="relu7")
    if FLAGS.debug:
        tf.histogram_summary("relu7/activation", relu7, collections=None, name=None)
        tf.scalar_summary("relu7/sparsity", tf.nn.zero_fraction(relu7), collections=None, name=None)
    dropout7 = tf.nn.dropout(relu7, keep_prob=dropout_prob, noise_shape=None, seed=None, name="dropout7")

    # ---- fcn8: 1x1 convolution producing one score map per class ----
    init = tf.truncated_normal(shape=[1,1,4096,FLAGS.num_classes],stddev=0.02)
    fcn8_w = tf.get_variable(initializer=init,name="fcn8_w")

    init = tf.constant(0.0,shape=[FLAGS.num_classes])
    fcn8_b = tf.get_variable(initializer=init,name="fcn8_b")

    fcn8 = tf.nn.conv2d(dropout7, fcn8_w, strides=[1,1,1,1], padding="SAME", use_cudnn_on_gpu=None, data_format=None, name=None)
    fcn8 = tf.nn.bias_add(fcn8, fcn8_b, data_format=None, name="fcn8")

    # ---- deconv1 + net['pool4']: x32 -> x16 ----
    # Stride-2 transposed convolution to exactly pool4's shape, then element-wise
    # fusion with pool4 (skip connection).
    s = 2
    k = 2*s
    in_channel = FLAGS.num_classes
    out_channel = net['pool4'].get_shape()[3].value
    out_shape = tf.shape(net['pool4'])

    # conv2d_transpose filters are laid out as (height, width, out_channels, in_channels).
    init = tf.truncated_normal(shape=[k,k,out_channel,in_channel],stddev=0.02)
    deconv1_w = tf.get_variable(initializer=init,name="deconv1_w")

    init = tf.constant(0.0,shape=[out_channel])
    deconv1_b = tf.get_variable(initializer=init,name="deconv1_b")

    deconv1 = tf.nn.conv2d_transpose(fcn8, deconv1_w, output_shape=out_shape, strides=[1,s,s,1], padding='SAME', name=None)
    deconv1 = tf.nn.bias_add(deconv1, deconv1_b, data_format=None, name="deconv1")

    fuse1 = tf.add(deconv1, net['pool4'], name="fuse1")

    # ---- deconv2 + net['pool3']: x16 -> x8 ----
    s = 2
    k = 2*s
    in_channel = out_channel
    out_channel = net['pool3'].get_shape()[3].value
    out_shape = tf.shape(net['pool3'])

    init = tf.truncated_normal(shape=[k,k,out_channel,in_channel],stddev=0.02)
    deconv2_w = tf.get_variable(initializer=init,name="deconv2_w")

    init = tf.constant(0.0,shape=[out_channel])
    deconv2_b = tf.get_variable(initializer=init,name="deconv2_b")

    deconv2 = tf.nn.conv2d_transpose(fuse1, deconv2_w, output_shape=out_shape, strides=[1,s,s,1], padding='SAME', name=None)
    deconv2 = tf.nn.bias_add(deconv2, deconv2_b, data_format=None, name="deconv2")

    fuse2 = tf.add(deconv2,net['pool3'],name="fuse2")

    # ---- deconv3: x8 -> input resolution ----
    # Stride-8 transposed convolution back to the input's batch/height/width,
    # with one channel of class logits per class.
    s = 8
    k = 2*s
    in_channel = out_channel
    out_channel = FLAGS.num_classes
    out_shape = tf.pack([tf.shape(processed_image)[0],tf.shape(processed_image)[1],tf.shape(processed_image)[2],out_channel])

    init = tf.truncated_normal(shape=[k,k,out_channel,in_channel],stddev=0.02)
    deconv3_w = tf.get_variable(initializer=init,name="deconv3_w")

    init = tf.constant(0.0,shape=[out_channel])
    deconv3_b = tf.get_variable(initializer=init,name="deconv3_b")

    deconv3 = tf.nn.conv2d_transpose(fuse2, deconv3_w, output_shape=out_shape, strides=[1,s,s,1], padding='SAME', name=None)
    deconv3 = tf.nn.bias_add(deconv3, deconv3_b, data_format=None, name="deconv3")

    # ---- per-pixel prediction ----
    # Arg-max over the class axis, then restore a trailing singleton channel so
    # the prediction has the same layout as `annotations`.
    annotations_pred = tf.argmax(deconv3, dimension=3, name=None)
    annotations_pred = tf.expand_dims(annotations_pred, dim=3, name="prediction")

# TensorBoard image summaries (up to two examples per batch): the raw input,
# the ground-truth label map and the predicted label map, with label ids cast
# to uint8 so they can be rendered as grayscale images.
tf.image_summary("images", images, max_images=2, collections=None, name=None)
tf.image_summary("annotations", tf.cast(annotations,tf.uint8), max_images=2, collections=None, name=None)
tf.image_summary("annotations_pred", tf.cast(annotations_pred,tf.uint8), max_images=2, collections=None, name=None)

# Pixel-wise cross-entropy loss.
# `annotations` holds one integer class id per pixel, shape (N, H, W, 1), while
# `deconv3` holds num_classes logits per pixel, shape (N, H, W, num_classes).
# The dense softmax_cross_entropy_with_logits op would require one-hot float
# labels with exactly the logits' shape. Use the sparse variant instead: it
# takes integer ids of shape (N, H, W). Cast from uint8 (not an accepted label
# dtype) to int32 and drop the singleton channel.
sparse_labels = tf.squeeze(tf.cast(annotations, tf.int32), squeeze_dims=[3])
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=deconv3, labels=sparse_labels, name=None)
loss = tf.reduce_mean(loss, reduction_indices=None, keep_dims=False, name="pixel-wise_cross-entropy_loss")

# log the loss
tf.scalar_summary("pixel-wise_cross-entropy_loss", loss, collections=None, name=None)

# All trainable variables. The (misspelled) name is kept because other cells
# may refer to it.
trainabel_vars = tf.trainable_variables()
if FLAGS.debug:
    for var in trainabel_vars:
        tf.histogram_summary(var.op.name+"/values", var, collections=None, name=None)
        # NOTE(review): this collection is not read anywhere in this notebook;
        # presumably intended for an L2 weight-decay term — TODO confirm or remove.
        tf.add_to_collection("sum(t ** 2) / 2 of all trainable_vars", tf.nn.l2_loss(var))

# Adam optimizer; gradients are computed explicitly so they can be logged.
optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
gradients = optimizer.compute_gradients(loss,trainabel_vars)
if FLAGS.debug:
    for grad, var in gradients:
        # compute_gradients yields None for variables the loss does not depend
        # on; a histogram summary of None cannot be built.
        if grad is not None:
            tf.histogram_summary(var.op.name+"/gradients", grad, collections=None, name=None)
train_op = optimizer.apply_gradients(gradients)

# Create an InteractiveSession, which installs itself as the default session so
# `.run()` works without an explicit session. Then initialise every variable,
# including the VGG layers via their constant initialisers.
print("\nInitializing the variables ...\n")
sess = tf.InteractiveSession()
tf.initialize_all_variables().run()

# Saver over all variables. It can optionally overwrite the freshly
# initialised values with a checkpoint of the whole network.
print("\nSetting up the Saver ...\n")
saver = tf.train.Saver()
if FLAGS.load:
    print("\nLoading pretrain parameters of the whole network ...\n")
    saver.restore(sess, FLAGS.full_model_file)

# Summary writers. Training mode starts from empty "train"/"valid" log dirs;
# the standalone validation mode writes to "complete_valid". Directories are
# managed in Python (os.makedirs / shutil.rmtree) rather than by building
# `mkdir` / `rm -r` shell commands from flag values.
print("\nSetting the summary writers ...\n")
summary_op = tf.merge_all_summaries()
if not os.path.exists(FLAGS.logs_dir):
    os.makedirs(FLAGS.logs_dir)
if FLAGS.mode == 'train':
    train_log_dir = os.path.join(FLAGS.logs_dir, "train")
    valid_log_dir = os.path.join(FLAGS.logs_dir, "valid")
    for stale_dir in (train_log_dir, valid_log_dir):
        if os.path.exists(stale_dir):
            shutil.rmtree(stale_dir)
    train_writer = tf.train.SummaryWriter(train_log_dir,sess.graph)
    valid_writer = tf.train.SummaryWriter(valid_log_dir)
elif FLAGS.mode == 'valid':
    complete_valid_dir = os.path.join(FLAGS.logs_dir, "complete_valid")
    if os.path.exists(complete_valid_dir):
        shutil.rmtree(complete_valid_dir)
    valid_writer = tf.train.SummaryWriter(complete_valid_dir)
else:
    # Fail here rather than with a NameError on an undefined writer much later.
    raise ValueError("unknown mode %r; expected 'train' or 'valid'" % FLAGS.mode)

# Build (once) and load the lists of (image, annotation) file records.
# The directory scan is cached in a pickle, so later runs only load that file.
print("\nReading in and reprocessing all images ...\n")
if not os.path.exists(FLAGS.data_dir):
    os.makedirs(FLAGS.data_dir)
pickle_file = os.path.join(FLAGS.data_dir,FLAGS.pickle_name)
if not os.path.exists(pickle_file):
    zip_file = os.path.join(FLAGS.data_dir,FLAGS.data_url.split('/')[-1])
    # The code assumes the archive extracts into a folder named after the zip
    # (without ".zip"). splitext strips only the extension; split('.') would
    # break on any earlier dot in the path (e.g. "./Data_zoo/...").
    source_datadir = os.path.splitext(zip_file)[0]
    if not os.path.exists(zip_file):
        print("downloading "+zip_file+" ..")
        if os.system("wget "+FLAGS.data_url+" -P "+FLAGS.data_dir) != 0:
            # Remove a partial download so the next run retries it.
            if os.path.exists(zip_file):
                os.remove(zip_file)
            raise IOError("failed to download "+FLAGS.data_url)
        print("download finished!")
    # Unzip whenever the extracted folder is missing. Unzipping only right after
    # a fresh download would leave an existing zip unpacked forever.
    if not os.path.exists(source_datadir):
        print("unzipping "+zip_file+" ..")
        os.system("unzip "+zip_file+" -d "+FLAGS.data_dir)
        print("unzipping finished!")
    if not os.path.exists(source_datadir):
        # Raise instead of exit(): exit() tears down the notebook kernel.
        raise IOError("Error: source_datadir not found: "+source_datadir)

    # Pickle keys are 'train'/'valid'. ADEChallengeData2016 stores these splits
    # in 'training'/'validation' sub-folders; fall back to the key itself for
    # datasets laid out as images/train, images/valid.
    split_dirs = {'train':'training','valid':'validation'}
    data_types = ['train','valid']
    data_list = {}
    for data_type in data_types:
        split_dir = split_dirs[data_type]
        if not os.path.isdir(os.path.join(source_datadir,"images",split_dir)):
            split_dir = data_type
        # find all images
        image_list = glob.glob(os.path.join(source_datadir,"images",split_dir,'*.jpg'))
        if not image_list:
            raise IOError("Error: no images found for "+data_type+"!!!")
        # Keep only images that have an annotation with the same base name.
        data_list[data_type] = []
        for image_path in image_list:
            image_name = os.path.splitext(os.path.basename(image_path))[0]
            annotation_name = os.path.join(source_datadir,"annotations",split_dir,image_name+".png")
            if os.path.exists(annotation_name):
                record = {'image':image_path,'annotation':annotation_name,'filename':image_name}
                data_list[data_type].append(record)
        # shuffle all tuples
        random.shuffle(data_list[data_type])
        print("Number of %s tuples: %d"%(data_type,len(data_list[data_type])))
    print("Packing data into "+pickle_file+" ...")
    with open(pickle_file,'wb') as f:
        pickle.dump(data_list,f,pickle.HIGHEST_PROTOCOL)
    print("pickle finished!!!")

# Load the cached records. NOTE: pickle.load can execute arbitrary code; only
# load a pickle produced by this notebook or another trusted source.
with open(pickle_file,'rb') as f:
    pickle_records = pickle.load(f)
train_records = pickle_records['train']
valid_records = pickle_records['valid']
del pickle_records
    
# Wrap the record lists in batch readers. The options presumably make the
# reader resize every image to image_size x image_size — TODO confirm in
# BatchDatsetReader. The training reader is only needed in 'train' mode.
print("Initializing the data reader...")
# Spelled to match the name used by both reader constructions below; it was
# previously defined as `reader_optiions`, which raised a NameError.
reader_options = {'resize':True,'resize_size':FLAGS.image_size}
if FLAGS.mode == 'train':
    train_reader = dataset.BatchDatset(train_records,reader_options)
valid_reader = dataset.BatchDatset(valid_records,reader_options)

# Directory that receives the periodic checkpoints of the full network.
if not os.path.exists(FLAGS.full_model):
    os.makedirs(FLAGS.full_model)

# Run training or evaluation depending on FLAGS.mode.
#   train: one optimisation step per batch with dropout keep probability 0.85.
#          Every 10 steps log the training loss/summaries; every 100 steps
#          evaluate one random validation batch without dropout; every 500
#          steps write a checkpoint.
#   valid: compute loss/summaries on the reader's get_records() output, then
#          save a random batch of (image, ground truth, prediction) PNGs.
if FLAGS.mode == 'train':
    # max_iters is declared as a float flag ("1e+5"); range() needs an int.
    for itr in range(int(FLAGS.max_iters)):
        # read next batch
        train_images, train_annotations = train_reader.next_batch(FLAGS.batch_size)
        feed_dict = {images:train_images,annotations:train_annotations,dropout_prob:0.85}
        # one optimisation step
        sess.run(train_op,feed_dict=feed_dict)
        # log training info (re-evaluates the same batch after the update)
        if itr % 10 == 0:
            train_loss, train_summary = sess.run([loss,summary_op],feed_dict=feed_dict)
            train_writer.add_summary(train_summary,itr)
            print("Step: %d, train_loss: %f"%(itr,train_loss))
        # log validation info on a random batch with dropout disabled (keep prob 1.0)
        if itr % 100 == 0:
            valid_images, valid_annotations = valid_reader.get_random_batch(FLAGS.batch_size)
            feed_dict = {images:valid_images,annotations:valid_annotations,dropout_prob:1.0}
            valid_loss, valid_summary = sess.run([loss,summary_op],feed_dict=feed_dict)
            valid_writer.add_summary(valid_summary,itr)
            print("==============================")
            print("Step: %d, valid_loss: %f"%(itr,valid_loss))
            print("==============================")
        # save snapshot
        if itr % 500 == 0:
            snapshot_name = os.path.join(FLAGS.full_model,str(itr)+"_model.ckpt")
            saver.save(sess,snapshot_name)
elif FLAGS.mode == 'valid':
    # Summaries are indexed by an integer global step.
    final_step = int(FLAGS.max_iters)
    # quantitative results
    # NOTE(review): everything returned by get_records() is fed in one
    # sess.run. If that is the whole validation split it may not fit in GPU
    # memory — TODO confirm and evaluate in batches if needed.
    valid_images,valid_annotations=valid_reader.get_records()
    feed_dict = {images:valid_images,annotations:valid_annotations,dropout_prob:1.0}
    valid_loss,valid_summary = sess.run([loss,summary_op],feed_dict=feed_dict)
    valid_writer.add_summary(valid_summary,final_step)
    print("==============================")
    print("Step: %d, valid_loss: %f"%(final_step,valid_loss))
    print("==============================")
    # qualitative results
    valid_images,valid_annotations=valid_reader.get_random_batch(FLAGS.batch_size)
    feed_dict = {images:valid_images,annotations:valid_annotations,dropout_prob:1.0}
    annotations_pred_results = sess.run(annotations_pred,feed_dict=feed_dict)

    # Drop the trailing singleton channel so imsave receives 2-D label maps.
    valid_annotations = np.squeeze(valid_annotations,axis=3)
    annotations_pred_results = np.squeeze(annotations_pred_results,axis=3)

    out_dir = os.path.join(FLAGS.logs_dir,"complete_valid")
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    # Save as many tuples as predictions were actually returned.
    num_saved = len(annotations_pred_results)
    print("Saving %d valid tuples for qualitative comparisons..."%num_saved)
    for n in range(num_saved):
        misc.imsave(os.path.join(out_dir,str(n)+"_image.png"),valid_images[n].astype(np.uint8))
        misc.imsave(os.path.join(out_dir,str(n)+"_annotation.png"),valid_annotations[n].astype(np.uint8))
        misc.imsave(os.path.join(out_dir,str(n)+"_prediction.png"),annotations_pred_results[n].astype(np.uint8))
    print("saving finished!!!")


Model_zoo/imagenet-vgg-verydeep-19.mat has already been downloaded.


loading pretrained weights from: Model_zoo/imagenet-vgg-verydeep-19.mat
loading finished!


Initializing the variables ...

In [0]: