In this exercise we will train an end-to-end convolutional neural network for semantic segmentation. The goal of semantic segmentation is to classify an image at the pixel level: for each pixel we want to determine the class of the object to which it belongs. This differs from image classification, which classifies the image as a whole and tells us nothing about where the objects are. Semantic segmentation therefore falls into the category of structured prediction problems: it answers both the 'what' and the 'where' questions, while classification answers only the 'what'. By classifying each pixel we are inferring the structure of the whole scene. Typical examples of an input image and the corresponding target labels are shown below.
*(Figure: an example input image alongside the corresponding target label image.)*
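Concretely, the input is an RGB tensor and the target is a dense label map of the same spatial size containing one integer class ID per pixel. Below is a tiny NumPy sketch of the shapes involved (illustrative only, assuming 384x160 means width x height):
In [ ]:
import numpy as np

height, width = 160, 384
image = np.zeros((height, width, 3), dtype=np.float32)  # RGB input image
labels = np.zeros((height, width), dtype=np.int32)      # one class ID per pixel
print(image.shape, labels.shape)  # (160, 384, 3) (160, 384)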
The Cityscapes dataset contains a diverse set of stereo video sequences recorded in street scenes from 50 different cities, with high-quality pixel-level annotations. It contains 2975 training and 500 validation images of size 2048x1024, while the test set of 1525 images is evaluated on the evaluation server; the benchmark is available here. Here we will use downsampled images of size 384x160. The original dataset has 19 classes, but we reduced that to 7 by merging similar classes into broader categories; this makes sense because very small objects are barely visible in the downsampled images. We also have an ignore class, which we need to exclude during training because those pixels don't belong to any class.
https://drive.google.com/file/d/0B6NQEJnkignaM2ZPQWUzTTc5Rjg/view?usp=sharing
ID | Class | Color |
---|---|---|
0 | road | purple |
1 | building | grey |
2 | infrastructure | yellow |
3 | nature | green |
4 | sky | light blue |
5 | person | red |
6 | vehicle | dark blue |
7 | ignore | black |
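For reference, the table above might be encoded as a simple Python structure. Note this is only a hypothetical sketch; the actual format of `Dataset.class_info` in data.py may differ:

    # hypothetical encoding of the class table; the real Dataset.class_info may differ
    class_info = [
        (0, 'road',           'purple'),
        (1, 'building',       'grey'),
        (2, 'infrastructure', 'yellow'),
        (3, 'nature',         'green'),
        (4, 'sky',            'light blue'),
        (5, 'person',         'red'),
        (6, 'vehicle',        'dark blue'),
        (7, 'ignore',         'black'),
    ]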
In [ ]:
%matplotlib inline
import time
from os.path import join
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import utils
from data import Dataset
tf.set_random_seed(31415)
tf.logging.set_verbosity(tf.logging.ERROR)
plt.rcParams["figure.figsize"] = (15, 5)
In [ ]:
batch_size = 10
num_classes = Dataset.num_classes
# create the Dataset for training and validation
train_data = Dataset('train', batch_size)
val_data = Dataset('val', batch_size, shuffle=False)
# downsample = 2
# train_data = Dataset('train', batch_size, downsample)
# val_data = Dataset('val', batch_size, downsample, shuffle=False)
print('Train shape:', train_data.x.shape)
print('Validation shape:', val_data.x.shape)
#print('mean = ', train_data.x.mean((0,1,2)))
#print('std = ', train_data.x.std((0,1,2)))
In [ ]:
# store the input image dimensions
height = train_data.height
width = train_data.width
channels = train_data.channels
# create placeholders for inputs
def build_inputs():
    with tf.name_scope('data'):
        x = tf.placeholder(tf.float32, shape=(None, height, width, channels), name='rgb_images')
        y = tf.placeholder(tf.int32, shape=(None, height, width), name='labels')
    return x, y
Now we can define the computational graph. Here we will heavily use the tf.layers high-level API, which handles tf.Variable creation for us. The main difference compared to the classification model is that this network is fully convolutional, without any fully connected layers. A brief sketch of the model we are going to define is given below.
conv3x3(32) -> 4 x (pool2x2 -> conv3x3(64) -> conv3x3(64)) -> conv1x1(7) -> resize_bilinear -> softmax() -> Loss
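Each of the four pool2x2 layers halves the spatial resolution, so the 1x1 convolution predicts logits on a grid $2^4 = 16$ times smaller than the input, which resize_bilinear then upsamples back. A quick sanity check for our 384x160 inputs:
In [ ]:
# each of the 4 pool2x2 layers halves height and width
downsampling_factor = 2 ** 4
print(384 // downsampling_factor, 160 // downsampling_factor)  # -> 24 10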
In [ ]:
# helper function which applies conv2d + ReLU with filter size k
def conv(x, num_maps, k=3):
    x = tf.layers.conv2d(x, num_maps, k, padding='same')
    x = tf.nn.relu(x)
    return x

# helper function for 2x2 max pooling with stride 2
def pool(x):
    return tf.layers.max_pooling2d(x, pool_size=2, strides=2, padding='same')
# this function takes the input placeholder and the number of classes,
# builds the model and returns the logits
def build_model(x, num_classes):
    input_size = x.get_shape().as_list()[1:3]
    block_sizes = [64, 64, 64, 64]
    x = conv(x, 32, k=3)
    for i, size in enumerate(block_sizes):
        with tf.name_scope('block' + str(i)):
            x = pool(x)
            x = conv(x, size)
            x = conv(x, size)
            print(x)
    with tf.name_scope('logits'):
        x = tf.layers.conv2d(x, num_classes, 1, padding='same')
        # question: why is there no ReLU here?
        x = tf.image.resize_bilinear(x, input_size, name='upsample_logits')
    return x
Now we are going to implement the build_loss function, which will create the nodes for loss computation and return the final tf.Tensor representing the scalar loss value.
Because segmentation is just classification at the pixel level, we can again use the cross-entropy loss function $L$ between the target one-hot distribution $\mathbf{y}$ and the predicted distribution $\mathbf{s}$ from the softmax layer. Compared to image classification, here we need to define the loss at each pixel. Below are the equations describing the loss for just one example (one pixel in our case).
$$
L = - \sum_{i=1}^{C} y_i \log s_i(\mathbf{x}) \\
s_i(\mathbf{x}) = \frac{e^{x_i}}{\sum_{j=1}^{C} e^{x_j}}
$$
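As a quick numerical sanity check of these formulas, here is the computation for a single pixel with three illustrative logits (values chosen arbitrarily):
In [ ]:
# softmax + cross entropy for one pixel, C = 3 illustrative classes
logits_np = np.array([2.0, 1.0, 0.1])  # raw scores x
target = 0                             # index of the true class

# s_i = exp(x_i) / sum_j exp(x_j)
s = np.exp(logits_np) / np.exp(logits_np).sum()

# with a one-hot target, the cross entropy reduces to -log of the true-class probability
print(s, -np.log(s[target]))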
In [ ]:
# this function takes logits and targets (y) and builds the loss subgraph
def build_loss(logits, y):
    with tf.name_scope('loss'):
        # vectorize the image
        y = tf.reshape(y, shape=[-1])
        logits = tf.reshape(logits, [-1, num_classes])
        # keep only pixels with a valid class ID (drop the ignore class)
        mask = y < num_classes
        y = tf.boolean_mask(y, mask)
        logits = tf.boolean_mask(logits, mask)
        # define softmax and cross entropy loss
        y_one_hot = tf.one_hot(y, num_classes)
        xent = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y_one_hot)
        # take the mean so the loss doesn't depend on the number of valid pixels in the batch
        xent = tf.reduce_mean(xent)
        tf.summary.scalar('cross_entropy', xent)
    return xent
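As a side note, the explicit one-hot encoding is not strictly necessary: TensorFlow also provides a sparse variant of the same loss that takes integer class labels directly. Inside build_loss, the last few lines could equivalently be written as:

    # equivalent loss computed from integer labels, no y_one_hot needed
    xent = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits)
    xent = tf.reduce_mean(xent)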
In [ ]:
# create inputs
x, y = build_inputs()
# create model
logits = build_model(x, num_classes)
# create loss
loss = build_loss(logits, y)
# we are going to need argmax predictions for IoU
y_pred = tf.argmax(logits, axis=3, output_type=tf.int32)
In [ ]:
# this function trains the model
def train(sess, x, y, y_pred, loss, checkpoint_dir):
    num_epochs = 30
    batch_size = 10
    log_dir = 'local/logs'
    utils.clear_dir(log_dir)
    utils.clear_dir(checkpoint_dir)

    learning_rate = 1e-3
    decay_power = 1.0
    global_step = tf.Variable(0, trainable=False)
    decay_steps = num_epochs * train_data.num_batches
    # the SGD learning rate is usually decayed over time, which enables us
    # to fine-tune the parameters better when we are close to a solution
    lr = tf.train.polynomial_decay(learning_rate, global_step, decay_steps,
                                   end_learning_rate=0, power=decay_power)
    train_step = tf.train.AdamOptimizer(lr).minimize(loss, global_step=global_step)

    saver = tf.train.Saver()
    summary_all = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(join(log_dir, 'train'), sess.graph)
    tf.global_variables_initializer().run(session=sess)

    step = 0
    best_iou = 0
    best_epoch = 0
    exp_start_time = time.time()
    for epoch in range(1, num_epochs + 1):
        # confusion_mat = np.zeros((num_classes, num_classes), dtype=np.uint64)
        print('\nTraining phase:')
        for x_np, y_np, names in train_data:
            start_time = time.time()
            loss_np, summary, _ = sess.run([loss, summary_all, train_step],
                                           feed_dict={x: x_np, y: y_np})
            train_writer.add_summary(summary, step)
            duration = time.time() - start_time
            # confusion_mat += batch_conf_mat.astype(np.uint64)
            if step % 20 == 0:
                string = '%s: epoch %d / %d, iter %05d, loss = %.2f (%.1f images/sec)' % \
                    (utils.get_expired_time(exp_start_time), epoch, num_epochs,
                     step, loss_np, batch_size / duration)
                print(string)
            step += 1
        # utils.print_metrics(confusion_mat, 'Train')
        iou = validate(sess, val_data, x, y, y_pred, loss, draw_steps=5)
        if iou > best_iou:
            best_iou, best_epoch = iou, epoch
            save_path = saver.save(sess, join(checkpoint_dir, 'model.ckpt'))
            print('Model saved in file: ', save_path)
    print('\nBest IoU = %.2f (epoch %d)' % (best_iou, best_epoch))
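To see concretely what tf.train.polynomial_decay computes, here is the schedule written out by hand (with power=1.0 it is simply linear decay from the initial to the end learning rate):
In [ ]:
# the polynomial decay schedule computed by hand:
# lr(t) = (lr0 - end_lr) * (1 - t / T)^power + end_lr, with t clipped to T
def poly_decay(lr0, t, T, end_lr=0.0, power=1.0):
    t = min(t, T)
    return (lr0 - end_lr) * (1.0 - t / T) ** power + end_lr

T = 1000  # illustrative total number of steps
for t in [0, 250, 500, 750, 1000]:
    print(t, poly_decay(1e-3, t, T))  # decays linearly from 1e-3 to 0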
We usually evaluate semantic segmentation results with the Intersection over Union measure (IoU, aka the Jaccard index). Note that the accuracy we used on the MNIST image classification problem is a bad measure in this case, because semantic segmentation datasets are often heavily imbalanced. First we compute the IoU for each class in one-vs-all fashion (shown below) and then take the mean IoU (mIoU) over all classes. By taking the mean we are treating all classes as equally important. $$ \mathit{IoU} = \frac{TP}{TP + FN + FP} $$
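Below is a minimal sketch of how the per-class IoU and the mIoU can be derived from a confusion matrix; the provided utils.print_stats presumably does something similar:
In [ ]:
# per-class IoU and mIoU from a confusion matrix
# (rows = true classes, columns = predicted classes, as in sklearn)
def mean_iou(conf_mat):
    tp = np.diag(conf_mat).astype(np.float64)
    fn = conf_mat.sum(axis=1) - tp  # pixels of this class predicted as something else
    fp = conf_mat.sum(axis=0) - tp  # pixels of other classes predicted as this class
    iou = tp / (tp + fn + fp)
    return iou, iou.mean()

# tiny 2-class example
conf = np.array([[50,  5],
                 [10, 35]], dtype=np.uint64)
print(mean_iou(conf))  # per-class IoU ~ (0.77, 0.70), mIoU ~ 0.73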
In [ ]:
def validate(sess, data, x, y, y_pred, loss, draw_steps=0):
    print('\nValidation phase:')
    conf_mat = np.zeros((num_classes, num_classes), dtype=np.uint64)
    for i, (x_np, y_np, names) in enumerate(data):
        start_time = time.time()
        loss_np, y_pred_np = sess.run([loss, y_pred],
                                      feed_dict={x: x_np, y: y_np})
        duration = time.time() - start_time
        # pass the labels explicitly so the matrix has a fixed shape even if some
        # class is missing from the batch; the last row/column corresponds to the
        # ignore class and is dropped
        batch_conf_mat = confusion_matrix(y_np.reshape(-1), y_pred_np.reshape(-1),
                                          labels=np.arange(num_classes + 1))
        batch_conf_mat = batch_conf_mat[:-1, :-1].astype(np.uint64)
        conf_mat += batch_conf_mat
        for j in range(min(draw_steps, batch_size)):
            img_pred = utils.colorize_labels(y_pred_np[j], Dataset.class_info)
            img_true = utils.colorize_labels(y_np[j], Dataset.class_info)
            img_raw = data.get_img(names[j])
            img = np.concatenate((img_raw, img_true, img_pred), axis=1)
            plt.imshow(img)
            plt.show()
            draw_steps -= 1
        if i % 5 == 0:
            string = 'batch %03d loss = %.2f (%.1f images/sec)' % \
                (i, loss_np, x_np.shape[0] / duration)
            print(string)
    print(conf_mat)
    return utils.print_stats(conf_mat, 'Validation', Dataset.class_info)
In [ ]:
sess = tf.Session()
# question: why is a forward pass (validation) faster than a training step?
train(sess, x, y, y_pred, loss, 'local/checkpoint1')
In [ ]:
# restore the best checkpoint
checkpoint_path = 'local/pretrained1/model.ckpt'
saver = tf.train.Saver()
saver.restore(sess, checkpoint_path)
validate(sess, val_data, x, y, y_pred, loss, draw_steps=10)
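The model above discards a lot of spatial detail: the logits are predicted at 16x lower resolution and then simply upsampled with bilinear interpolation, so thin objects and precise boundaries are hard to recover. The next cell addresses this with skip connections: we store the feature maps produced by each downsampling block and then gradually upsample the deep features, at each step resizing them to the resolution of the corresponding skip features, concatenating the two, and fusing them with a convolution. This is the same idea used in U-Net and FCN style architectures.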
In [ ]:
def upsample(x, skip, num_maps):
    skip_size = skip.get_shape().as_list()[1:3]
    # upsample the deep features to the skip resolution, then fuse them
    x = tf.image.resize_bilinear(x, skip_size)
    x = tf.concat([x, skip], 3)
    return conv(x, num_maps)

# this function takes the input placeholder and the number of classes,
# builds the model and returns the logits
def build_model(x, num_classes):
    input_size = x.get_shape().as_list()[1:3]
    block_sizes = [64, 64, 64, 64]
    skip_layers = []
    x = conv(x, 32, k=3)
    for i, size in enumerate(block_sizes):
        with tf.name_scope('block' + str(i)):
            x = pool(x)
            x = conv(x, size)
            x = conv(x, size)
            # if i < len(block_sizes) - 1:
            skip_layers.append(x)
    for i, skip in reversed(list(enumerate(skip_layers))):
        with tf.name_scope('upsample' + str(i)):
            print(i, x, '\n', skip)
            x = upsample(x, skip, block_sizes[i])
    with tf.name_scope('logits'):
        x = tf.layers.conv2d(x, num_classes, 1, padding='same')
        x = tf.image.resize_bilinear(x, input_size, name='upsample_logits')
    return x
In [ ]:
sess.close()
tf.reset_default_graph()
# create inputs
x, y = build_inputs()
# create model
logits = build_model(x, num_classes)
# create loss
loss = build_loss(logits, y)
# we are going to need argmax predictions for IoU
y_pred = tf.argmax(logits, axis=3, output_type=tf.int32)
sess = tf.Session()
In [ ]:
train(sess, x, y, y_pred, loss, 'local/checkpoint2')
In [ ]:
# restore the best checkpoint
checkpoint_path = 'local/pretrained2/model.ckpt'
saver = tf.train.Saver()
saver.restore(sess, checkpoint_path)
validate(sess, val_data, x, y, y_pred, loss, draw_steps=10)