In this exercise we will train an end-to-end convolutional neural network for semantic segmentation. The goal of semantic segmentation is to classify the image at the pixel level: for each pixel we want to determine the class of the object to which it belongs. This differs from image classification, which classifies an image as a whole and doesn't tell us where the objects are. This is why semantic segmentation belongs to the category of structured prediction problems: it answers both the 'what' and the 'where' questions, while classification answers only 'what'. By classifying each pixel we infer the structure of the whole scene. Typical examples of an input image and its target labels for this problem are shown below.
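To make the 'what' versus 'where' distinction concrete, here is a minimal NumPy sketch that is not part of the exercise code. The random `logits` array is a stand-in for a network's per-pixel class scores, and averaging the scores over all pixels is just one assumed way an image classifier might reduce them to a single prediction.

```python
import numpy as np

# Stand-in for a network's output: per-class scores (logits) for a tiny
# 2x3 image with C = 3 classes, shape (height, width, num_classes).
rng = np.random.default_rng(0)
logits = rng.normal(size=(2, 3, 3))

# Semantic segmentation: choose the most likely class independently at every
# pixel, giving a label map with the same spatial size ('what' and 'where').
label_map = logits.argmax(axis=-1)               # shape (2, 3)

# Image classification: one score vector for the whole image (here obtained by
# averaging the scores over all pixels), giving a single label ('what' only).
image_label = logits.mean(axis=(0, 1)).argmax()  # a single integer

print(label_map.shape, int(image_label))
```

The segmentation output keeps the spatial axes, so every pixel gets its own label; the classification output discards them.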
[Figure: example input image (left) and the corresponding target label image (right)]
The Cityscapes dataset contains a diverse set of stereo video sequences recorded in street scenes from 50 different cities, with high-quality pixel-level annotations. The dataset contains 2975 training and 500 validation images of size 2048x1024. The test set of 1000 images is evaluated on the server, and the benchmark is available online. In this exercise we will use downsampled images of size 384x160. The original dataset has 19 classes, but we reduced that to 7 by merging similar classes into broader categories. This makes sense because very small objects are barely visible in the downsampled images. There is also an ignore class: its pixels don't belong to any of the classes, so they must be excluded from the loss during training.
ID | Class | Color |
---|---|---|
0 | road | purple |
1 | building | grey |
2 | infrastructure | yellow |
3 | nature | green |
4 | sky | light blue |
5 | person | red |
6 | vehicle | dark blue |
7 | ignore | black |
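When visualizing predictions, the class IDs in the table have to be mapped to colors. Below is a minimal sketch using NumPy indexing; the exact RGB triples are hypothetical values chosen to match the color names above, and the palette actually used by the exercise's `utils` module may differ.

```python
import numpy as np

# Hypothetical RGB values matching the color names in the table above
# (row index = class ID).
PALETTE = np.array([
    [128, 64, 128],   # 0 road - purple
    [70, 70, 70],     # 1 building - grey
    [250, 170, 30],   # 2 infrastructure - yellow
    [107, 142, 35],   # 3 nature - green
    [70, 130, 180],   # 4 sky - light blue
    [220, 20, 60],    # 5 person - red
    [0, 0, 142],      # 6 vehicle - dark blue
    [0, 0, 0],        # 7 ignore - black
], dtype=np.uint8)

def colorize(label_map):
    """Map an integer array of class IDs, shape (..., H, W), to RGB (..., H, W, 3)."""
    return PALETTE[label_map]

labels = np.array([[0, 4],
                   [5, 7]])
rgb = colorize(labels)
print(rgb.shape)  # (2, 2, 3)
```

The same indexing works for a whole batch: a label array of shape (N, H, W) yields an image array of shape (N, H, W, 3).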
In [2]:
%matplotlib inline
import time
from os.path import join
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import utils
from data import Dataset
tf.set_random_seed(31415)
tf.logging.set_verbosity(tf.logging.ERROR)
plt.rcParams["figure.figsize"] = (15, 5)
In [ ]:
batch_size = 10
num_classes = Dataset.num_classes
# create the Dataset for training and validation
train_data = Dataset('train', batch_size)
val_data = Dataset('val', batch_size, shuffle=False)
print('Train shape:', train_data.x.shape)
print('Validation shape:', val_data.x.shape)
#print('mean = ', train_data.x.mean((0,1,2)))
#print('std = ', train_data.x.std((0,1,2)))
In [ ]:
# store the input image dimensions
height = train_data.height
width = train_data.width
channels = train_data.channels
# create placeholders for inputs
def build_inputs():
    ...
Now we can define the computational graph. Here we will make heavy use of the `tf.layers` high-level API, which handles `tf.Variable` creation for us. The main difference compared to the classification model is that this network is fully convolutional, without any fully connected layers. Because every layer is a convolution or a pooling operation, the output remains a spatial map of class scores rather than a single vector. A brief sketch of the model we are going to define is given below.
conv3x3(32) -> 4 x (pool2x2 -> conv3x3(64) -> conv3x3(64)) -> conv1x1(7) -> resize_bilinear -> softmax() -> Loss
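To see what the sketch above does to the tensor sizes, the following plain-Python trace follows a 160x384 input through the network. It assumes 'same' padding for every convolution (so convolutions keep the spatial size), 2x2 pooling with stride 2, and 3 input channels; these are assumptions about the exercise's implementation.

```python
def trace_shapes(height=160, width=384, channels=3, num_classes=7, num_pools=4):
    """Return (layer name, output shape as (H, W, C)) pairs for the model sketch."""
    shapes = [("input", (height, width, channels))]
    h, w = height, width
    shapes.append(("conv3x3(32)", (h, w, 32)))
    for i in range(num_pools):
        h, w = h // 2, w // 2  # each 2x2 pooling halves both spatial dimensions
        shapes.append((f"block{i + 1}: pool + 2x conv3x3(64)", (h, w, 64)))
    shapes.append(("conv1x1(7)", (h, w, num_classes)))
    # bilinear resizing brings the class scores back to the input resolution
    shapes.append(("resize_bilinear", (height, width, num_classes)))
    return shapes

for name, shape in trace_shapes():
    print(f"{name:35s} {shape}")
```

Under these assumptions the class scores are predicted on a 10x24 grid and then upsampled by a factor of 16 in each dimension before the softmax.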
In [ ]:
# helper function which applies conv2d + ReLU with filter size k
def conv(x, num_maps, k=3):
    ...
# helper function for 2x2 max pooling with stride=2
def pool(x):
    ...
# this function takes the input placeholder and the number of classes, builds the model and returns the logits
def build_model(x, num_classes):
    ...
Now we are going to implement the `build_loss` function, which creates the nodes for the loss computation and returns the final `tf.Tensor` representing the scalar loss value.
Because segmentation is just classification at the pixel level, we can again use the cross-entropy loss $L$ between the target one-hot distribution $\mathbf{y}$ and the distribution $\mathbf{s}$ predicted by a softmax layer. Unlike in image classification, however, the loss has to be defined at every pixel. The equations below describe the loss for a single example (in our case, a single pixel), where $\mathbf{x}$ is the vector of $C$ logits at that pixel.
$$
\begin{aligned}
L &= -\sum_{i=1}^{C} y_i \log s_i(\mathbf{x}) \\
s_i(\mathbf{x}) &= \frac{e^{x_i}}{\sum_{j=1}^{C} e^{x_j}}
\end{aligned}
$$
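The per-pixel equations can be checked with a small NumPy sketch that applies them to every pixel of a batch at once. It assumes the ignore class has ID 7 (as in the table) and averages over the non-ignored pixels, which is one common reduction; how the exercise's `build_loss` reduces the per-pixel losses is left open.

```python
import numpy as np

IGNORE_ID = 7  # ID of the ignore class from the table above

def pixelwise_cross_entropy(logits, labels, num_classes=7, ignore_id=IGNORE_ID):
    """Mean cross entropy over all non-ignored pixels.

    logits: (N, H, W, C) float array, labels: (N, H, W) integer array.
    """
    # numerically stable log-softmax: log s_i = x_i - log sum_j exp(x_j)
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

    valid = labels != ignore_id                      # mask out ignore pixels
    safe_labels = np.where(valid, labels, 0)         # placeholder index for masked pixels
    one_hot = np.eye(num_classes)[safe_labels]       # y as one-hot, (N, H, W, C)
    per_pixel = -(one_hot * log_probs).sum(axis=-1)  # L = -sum_i y_i log s_i
    return per_pixel[valid].mean()

# Uniform logits give s_i = 1/7 everywhere, so the loss equals log(7).
logits = np.zeros((1, 2, 2, 7))
labels = np.array([[[0, 7],
                    [3, 7]]])
print(pixelwise_cross_entropy(logits, labels))
```

The two pixels labeled 7 do not contribute to the result, which is exactly what the ignore class requires.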
In [4]:
# this function takes logits and targets (y) and builds the loss subgraph
def build_loss(logits, y):
    ...
In [ ]:
# create inputs
# create model
# create loss
# we will need argmax predictions for IoU
In [ ]:
# this function trains the model
def train(sess, x, y, y_pred, loss, checkpoint_dir):
    num_epochs = 30
    batch_size = 10
    log_dir = 'local/logs'
    utils.clear_dir(log_dir)
    utils.clear_dir(checkpoint_dir)
    learning_rate = 1e-3
    decay_power = 1.0
    global_step = tf.Variable(0, trainable=False)
    decay_steps = num_epochs * train_data.num_batches
    # usually the SGD learning rate is decreased over time, which lets us
    # fine-tune the parameters better when close to the solution
    lr = tf.train.polynomial_decay(learning_rate, global_step, decay_steps,
                                   end_learning_rate=0, power=decay_power)
    ...
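The schedule produced by `tf.train.polynomial_decay` (with the default `cycle=False`) follows a simple documented formula, reimplemented below in plain Python so the values can be inspected. The number of batches per epoch is a hypothetical value, assuming `train_data.num_batches` rounds 2975 / 10 up.

```python
import math

def polynomial_decay(learning_rate, step, decay_steps,
                     end_learning_rate=0.0, power=1.0):
    """Learning rate after `step` updates, following the formula
    documented for tf.train.polynomial_decay with cycle=False."""
    step = min(step, decay_steps)  # the rate stays at end_learning_rate afterwards
    remaining = 1.0 - step / decay_steps
    return (learning_rate - end_learning_rate) * remaining ** power + end_learning_rate

# Hypothetical schedule length: 30 epochs x ceil(2975 / 10) = 298 batches per epoch
decay_steps = 30 * math.ceil(2975 / 10)
for step in (0, decay_steps // 2, decay_steps):
    print(step, polynomial_decay(1e-3, step, decay_steps))
```

With `power=1.0` and `end_learning_rate=0`, as in the training code, the learning rate falls linearly from 1e-3 to zero over the whole run.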
Semantic segmentation results are usually evaluated with the Intersection over Union measure (IoU, also known as the Jaccard index). The accuracy we used for the MNIST image classification problem is a poor measure here because semantic segmentation datasets are often heavily imbalanced: a model can reach high pixel accuracy by getting the large classes right while missing small classes almost entirely. We first compute the IoU for each class in a one-vs-all fashion (shown below) and then take the mean IoU (mIoU) over all classes. By taking the mean, we treat all classes as equally important.
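To compute the IoU, we first run the forward pass on the validation data and collect the confusion matrix. For each class, $TP$ counts the pixels of that class that were predicted correctly, $FN$ the pixels of that class predicted as another class, and $FP$ the pixels of other classes predicted as that class.
$$
\mathrm{IoU} = \frac{TP}{TP + FN + FP}
$$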
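The following NumPy sketch builds the confusion matrix with the ignore pixels skipped, reads $TP$, $FN$ and $FP$ off it, and contrasts mIoU with pixel accuracy on an imbalanced toy example. Rows are ground truth and columns are predictions, the same layout as `sklearn.metrics.confusion_matrix`, which the notebook already imports. Averaging only over classes that occur in the labels or predictions is an assumption made to keep the toy example readable.

```python
import numpy as np

def confusion(labels, preds, num_classes=7, ignore_id=7):
    """Confusion matrix (rows = ground truth, columns = prediction), skipping ignore pixels."""
    valid = labels != ignore_id
    idx = num_classes * labels[valid].astype(np.int64) + preds[valid].astype(np.int64)
    counts = np.bincount(idx, minlength=num_classes ** 2)
    return counts.reshape(num_classes, num_classes)

def per_class_iou(conf_mat):
    tp = np.diag(conf_mat).astype(np.float64)
    fn = conf_mat.sum(axis=1) - tp  # pixels of class c predicted as another class
    fp = conf_mat.sum(axis=0) - tp  # pixels of other classes predicted as c
    denom = tp + fn + fp
    return np.divide(tp, denom, out=np.zeros_like(tp), where=denom > 0)

# Imbalanced toy example: 90 road pixels (ID 0), 10 person pixels (ID 5),
# and a model that predicts "road" everywhere.
labels = np.array([0] * 90 + [5] * 10)
preds = np.zeros(100, dtype=np.int64)

cm = confusion(labels, preds)
iou = per_class_iou(cm)
present = (cm.sum(axis=0) + cm.sum(axis=1)) > 0
miou = iou[present].mean()
accuracy = np.trace(cm) / cm.sum()
print(f"pixel accuracy = {accuracy:.2f}, mIoU = {miou:.2f}")  # 0.90 vs 0.45
```

The model never detects a person, yet pixel accuracy is 90%; the mIoU of 0.45 exposes the failure because the person class contributes an IoU of zero.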
In [ ]:
def validate(sess, data, x, y, y_pred, loss, draw_steps=0):
    print('\nValidation phase:')
    ...
    return utils.print_stats(conf_mat, 'Validation', Dataset.class_info)
In [ ]:
sess = tf.Session()
train(sess, x, y, y_pred, loss, 'local/checkpoint1')
In [ ]:
# restore the checkpoint
...
In [ ]:
# upsampling layer
def upsample(x, skip, num_maps):
    ...
# this function takes the input placeholder and the number of classes, builds the model and returns the logits
def build_model(x, num_classes):
    ...
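The notebook only gives the signature `upsample(x, skip, num_maps)`. Below is a hypothetical NumPy sketch of what such a layer could compute, useful for checking shapes. The choices are assumptions, not the exercise's required design: nearest-neighbour 2x upsampling stands in for bilinear resizing, the skip features are merged by channel concatenation, and a 1x1 convolution (a per-pixel matrix multiply) with random weights stands in for a learned convolution producing `num_maps` channels.

```python
import numpy as np

def upsample_sketch(x, skip, num_maps, rng=None):
    """Hypothetical upsampling step with a skip connection (NumPy, no training).

    x:    (N, h, w, Cx) coarse features
    skip: (N, 2h, 2w, Cs) features saved before the matching pooling layer
    """
    rng = np.random.default_rng(0) if rng is None else rng
    # 1) nearest-neighbour 2x upsampling (stand-in for bilinear resizing)
    up = x.repeat(2, axis=1).repeat(2, axis=2)
    # 2) concatenate with the skip features along the channel axis
    merged = np.concatenate([up, skip], axis=-1)
    # 3) 1x1 convolution = per-pixel linear map to num_maps channels, then ReLU;
    #    random weights stand in for learned ones
    w = rng.normal(scale=0.1, size=(merged.shape[-1], num_maps))
    return np.maximum(merged @ w, 0.0)

# Deepest feature map (10x24) merged with the skip features one level up (20x48)
x = np.ones((1, 10, 24, 64))
skip = np.ones((1, 20, 48, 64))
out = upsample_sketch(x, skip, num_maps=64)
print(out.shape)  # (1, 20, 48, 64)
```

Applying such a step four times would bring the features back to full resolution while reusing the finer detail stored in the skip features.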
In [ ]:
sess.close()
tf.reset_default_graph()
# create inputs
# create model
# create loss
# we are going to need argmax predictions for IoU
In [ ]:
sess = tf.Session()
train(sess, x, y, y_pred, loss, 'local/checkpoint2')
In [ ]:
# restore the checkpoint
...
In [ ]: