Learning Image Colorization with Deep Neural Networks in TensorFlow


Colorization example

Table of contents [TODO polish]

  • Intro
  • History of Colorization
  • (Maybe previous approaches that didn't work?)
  • Deep learning: Results from this approach
  • Note about prerequisites
  • Getting started with architecture
  • Dilated convolutions and why they are necessary
  • [Where to insert ab color space]
  • Loss functions
  • Perceptual losses, temperature softmax sampling
  • Training the model/Loading trained model
  • Playing around with it and conclusion

Intro, history, and previous approaches


Enter deep learning

In 2016, Zhang et al. introduced an amazing result using deep convolutional neural networks. By carefully crafting their architecture and loss function, they were able to avoid the color desaturation that plagues simpler approaches. The pictures produced by this method are so convincing that they fooled human participants on 32% of the trials of a 'colorization Turing test'.

In this article we will explain this technique and demonstrate how to implement it. This article assumes familiarity with TensorFlow and knowledge of convolutional neural networks. If you would like to read more about those, here are some great resources: [TODO : insert]

Starting with the image pipeline and basic modules

To process an image, we take it in as RGB and convert it to the Lab color space. We use Lab rather than RGB because it is a better approximation of how humans perceive color. It is designed so that equal distances between colors correspond roughly to equal perceived differences. It also stores brightness separately from color, which is exactly the split our task needs.

Just like RGB, Lab has three channels: L (lightness), a, and b. The L channel contains brightness information, while the a and b channels describe the color itself: a runs from green to red, and b runs from blue to yellow. A black-and-white image is essentially just the L channel, so our task is to use L to predict the two color channels, a and b.

[insert relevant code]
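The conversion itself is a fixed formula. Below is a minimal sketch that converts a single pixel in pure Python using the standard CIE formulas. It assumes sRGB input with components in [0, 1] and the D65 reference white. The function name `srgb_to_lab` is hypothetical. For whole images, a library routine such as scikit-image's `skimage.color.rgb2lab` is the practical choice.

```python
def srgb_to_lab(r, g, b):
    """Convert one sRGB pixel (components in [0, 1]) to CIE Lab, D65 white point.

    Hypothetical helper for illustration; assumes standard sRGB input.
    """

    def linearize(c):
        # Undo the sRGB gamma curve.
        return c / 12.92 if c <= 0.04045 else ((c + 0.055) / 1.055) ** 2.4

    rl, gl, bl = linearize(r), linearize(g), linearize(b)

    # Linear RGB -> CIE XYZ (sRGB primaries, D65).
    x = 0.4124564 * rl + 0.3575761 * gl + 0.1804375 * bl
    y = 0.2126729 * rl + 0.7151522 * gl + 0.0721750 * bl
    z = 0.0193339 * rl + 0.1191920 * gl + 0.9503041 * bl

    # D65 reference white.
    xn, yn, zn = 0.95047, 1.0, 1.08883

    def f(t):
        delta = 6 / 29
        return t ** (1 / 3) if t > delta ** 3 else t / (3 * delta ** 2) + 4 / 29

    fx, fy, fz = f(x / xn), f(y / yn), f(z / zn)
    lightness = 116 * fy - 16      # L: 0 (black) to 100 (white)
    a = 500 * (fx - fy)            # a: green (-) to red (+)
    b_ = 200 * (fy - fz)           # b: blue (-) to yellow (+)
    return lightness, a, b_


# The network input is L; a and b are what it has to predict.
print(srgb_to_lab(1.0, 1.0, 1.0))  # white: L close to 100, a and b close to 0
print(srgb_to_lab(1.0, 0.0, 0.0))  # red: large positive a
```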

Stacking the layers

Now that we have our image feeding pipeline, it's time to move on to the network layers!

The network is built from 8 blocks of convolutional layers, and the middle two blocks (conv5 and conv6) use dilated convolutions. What's a dilated convolution, you ask? It is exactly the same as a normal convolution, except that it effectively inserts zeros between neighbouring kernel values, 'dilating' the kernel by some factor. With a dilation factor of 2, for example, a 3x3 kernel covers a 5x5 area while still having only 9 weights. These dilated kernels span a larger area, so every unit in the layer sees a larger part of the image (a larger receptive field). This matters because an object's color depends on much more than local texture; the network needs global information. Here is a great explanation of dilated convolutions: http://www.inference.vc/dilated-convolutions-and-kronecker-factorisation/
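To make "inserting zeros" concrete, here is a small 1-D sketch in pure Python; 2-D works the same way along each axis. It checks that a dilated convolution gives the same result as an ordinary convolution with a zero-stuffed kernel. It also computes how fast the receptive field grows. All function names are hypothetical. The `(3, 1, 2)` layers mirror the 3x3, dilation-2 convolutions of conv5/conv6, and the calculation ignores the downsampling done in earlier blocks.

```python
def dilate_kernel(kernel, rate):
    """Insert (rate - 1) zeros between neighbouring kernel taps."""
    dilated = []
    for i, w in enumerate(kernel):
        dilated.append(w)
        if i < len(kernel) - 1:
            dilated.extend([0.0] * (rate - 1))
    return dilated


def conv1d_valid(signal, kernel):
    """Plain 1-D cross-correlation (what deep learning calls convolution), no padding."""
    k = len(kernel)
    return [sum(signal[i + j] * kernel[j] for j in range(k))
            for i in range(len(signal) - k + 1)]


def dilated_conv1d_valid(signal, kernel, rate):
    """Dilated version: the taps read every `rate`-th input instead of neighbours."""
    span = (len(kernel) - 1) * rate + 1
    return [sum(signal[i + j * rate] * kernel[j] for j in range(len(kernel)))
            for i in range(len(signal) - span + 1)]


def receptive_field(layers):
    """Receptive field of a stack of (kernel_size, stride, dilation) layers."""
    rf, jump = 1, 1
    for kernel_size, stride, dilation in layers:
        rf += (kernel_size - 1) * dilation * jump
        jump *= stride
    return rf


signal = [float(v) for v in range(10)]
kernel = [1.0, 2.0, 3.0]
print(dilate_kernel(kernel, 2))  # [1.0, 0.0, 2.0, 0.0, 3.0]
# A dilated convolution equals a normal convolution with the zero-stuffed kernel.
print(dilated_conv1d_valid(signal, kernel, 2) == conv1d_valid(signal, dilate_kernel(kernel, 2)))  # True
print(receptive_field([(3, 1, 1)] * 3))  # 7: three ordinary 3-tap layers
print(receptive_field([(3, 1, 2)] * 3))  # 13: three 3-tap layers with dilation 2
```

The weight count stays at 3 per layer in both stacks. Only the area each output sees grows, which is why dilation is a cheap way to bring in wider context.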

We will now implement all the layers in TensorFlow.

import tensorflow as tf

# deconv2d is used in conv8 below; it is assumed to live in .utils next to conv2d.
from .utils import conv2d, deconv2d


def construct(input_placeholder):
    """
    Constructs the main model architecture in TensorFlow.

    input_placeholder carries only the L (lightness) channel, hence
    input_channels=1 below. The returned tensor has 313 channels: one
    score per color bin for every output pixel. The hard-coded
    output_size in conv8 assumes 256x256 inputs.
    """
    ###############################
    #      MODEL ARCHITECTURE     #
    ###############################
    # First block of convolutions; the stride-2 layer halves the resolution.
    with tf.variable_scope("conv_1"):
        conv_1_1 = conv2d(input_placeholder,
            input_channels=1,
            output_channels=64,
            kernel_size=3,
            pad=1)
        conv_1_2 = conv2d(conv_1_1,
            input_channels=64,
            output_channels=64,
            kernel_size=3,
            pad=1,
            stride=2)
        # TODO batchn
        bn_1 = conv_1_2

    # Second block of convolutions; halves the resolution again.
    with tf.variable_scope("conv2"):
        conv_2_1 = conv2d(bn_1,
            input_channels=64,
            output_channels=128,
            kernel_size=3,
            pad=1)
        conv_2_2 = conv2d(conv_2_1,
            input_channels=128,
            output_channels=128,
            kernel_size=3,
            pad=1,
            stride=2)

        # TODO batchn
        bn_2 = conv_2_2

    # Third block; after it the feature maps are 1/8 of the input size.
    with tf.variable_scope("conv3"):
        conv_3_1 = conv2d(bn_2,
            input_channels=128,
            output_channels=256,
            kernel_size=3,
            pad=1)
        conv_3_2 = conv2d(conv_3_1,
            input_channels=256,
            output_channels=256,
            kernel_size=3,
            pad=1)
        conv_3_3 = conv2d(conv_3_2,
            input_channels=256,
            output_channels=256,
            kernel_size=3,
            pad=1,
            stride=2)
        # TODO batchn
        bn_3 = conv_3_3

    # Fourth block: ordinary (undilated) convolutions at the same resolution.
    with tf.variable_scope("conv4"):
        conv_4_1 = conv2d(bn_3,
            input_channels=256,
            output_channels=512,
            kernel_size=3,
            pad=1,
            dilation=1)
        conv_4_2 = conv2d(conv_4_1,
            input_channels=512,
            output_channels=512,
            kernel_size=3,
            pad=1,
            dilation=1)
        conv_4_3 = conv2d(conv_4_2,
            input_channels=512,
            output_channels=512,
            kernel_size=3,
            pad=1,
            dilation=1)
        # TODO batchn
        bn_4 = conv_4_3

    # DILATED LAYERS: conv5 and conv6 use dilation=2. A dilated 3x3 kernel
    # spans 5x5, so pad=2 keeps the spatial size unchanged.
    with tf.variable_scope("conv5"):
        conv_5_1 = conv2d(bn_4,
            input_channels=512,
            output_channels=512,
            kernel_size=3,
            pad=2,
            dilation=2)
        conv_5_2 = conv2d(conv_5_1,
            input_channels=512,
            output_channels=512,
            kernel_size=3,
            pad=2,
            dilation=2)
        conv_5_3 = conv2d(conv_5_2,
            input_channels=512,
            output_channels=512,
            kernel_size=3,
            pad=2,
            dilation=2)
        # TODO batchn
        bn_5 = conv_5_3

    with tf.variable_scope("conv6"):
        conv_6_1 = conv2d(bn_5,
            input_channels=512,
            output_channels=512,
            kernel_size=3,
            pad=2,
            dilation=2)
        conv_6_2 = conv2d(conv_6_1,
            input_channels=512,
            output_channels=512,
            kernel_size=3,
            pad=2,
            dilation=2)
        conv_6_3 = conv2d(conv_6_2,
            input_channels=512,
            output_channels=512,
            kernel_size=3,
            pad=2,
            dilation=2)
        # TODO batchn
        bn_6 = conv_6_3

    # Seventh block: back to undilated convolutions.
    with tf.variable_scope("conv7"):
        conv_7_1 = conv2d(bn_6,
            input_channels=512,
            output_channels=512,
            kernel_size=3,
            pad=1,
            dilation=1)
        conv_7_2 = conv2d(conv_7_1,
            input_channels=512,
            output_channels=512,
            kernel_size=3,
            pad=1,
            dilation=1)
        conv_7_3 = conv2d(conv_7_2,
            input_channels=512,
            output_channels=512,
            kernel_size=3,
            pad=1,
            dilation=1)
        # TODO batchn
        bn_7 = conv_7_3

    # Eighth block: a transposed convolution upsamples 2x, then the last
    # layer maps the features to 313 channels, one per color bin.
    with tf.variable_scope("conv8"):
        conv_8_1 = deconv2d(bn_7,
            input_channels=512,
            output_size=[None, 64, 64, 256],
            kernel_size=4,
            stride=2,
            pad=1)
        conv_8_2 = conv2d(conv_8_1,
            input_channels=256,
            output_channels=256,
            kernel_size=3,
            pad=1)
        conv_8_3 = conv2d(conv_8_2,
            input_channels=256,
            output_channels=256,
            kernel_size=3,
            pad=1,
            stride=1)
        conv_8_313 = conv2d(conv_8_3,
            input_channels=256,
            output_channels=313,
            kernel_size=3,
            pad=1,
            stride=1)

    return conv_8_313
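A quick way to sanity-check the hard-coded `output_size=[None, 64, 64, 256]` is to trace the spatial size through every layer. The pure-Python sketch below does that. It assumes `conv2d` pads symmetrically by `pad` and follows the standard output-size formula. It also assumes `deconv2d` behaves like a transposed convolution without output padding. Both are assumptions, since the `.utils` helpers are not shown, and the function names here are hypothetical.

```python
def conv_out(n, kernel_size, stride=1, pad=0, dilation=1):
    """Spatial output size of a convolution (standard formula, floor division)."""
    return (n + 2 * pad - dilation * (kernel_size - 1) - 1) // stride + 1


def deconv_out(n, kernel_size, stride=1, pad=0):
    """Spatial output size of a transposed convolution with no output padding."""
    return (n - 1) * stride - 2 * pad + kernel_size


size = 256  # assumed input height and width
trace = {}
# (block, kernel_size, stride, pad, dilation) for every conv2d call in construct()
layers = (
    [("conv1", 3, 1, 1, 1), ("conv1", 3, 2, 1, 1)]
    + [("conv2", 3, 1, 1, 1), ("conv2", 3, 2, 1, 1)]
    + [("conv3", 3, 1, 1, 1)] * 2 + [("conv3", 3, 2, 1, 1)]
    + [("conv4", 3, 1, 1, 1)] * 3
    + [("conv5", 3, 1, 2, 2)] * 3
    + [("conv6", 3, 1, 2, 2)] * 3
    + [("conv7", 3, 1, 1, 1)] * 3
)
for block, k, s, p, d in layers:
    size = conv_out(size, k, s, p, d)
    trace[block] = size  # size at the end of each block

size = deconv_out(size, kernel_size=4, stride=2, pad=1)  # conv8_1
for _ in range(3):                                        # conv8_2, conv8_3, conv8_313
    size = conv_out(size, 3, 1, 1, 1)
trace["conv8"] = size
print(trace)
# {'conv1': 128, 'conv2': 64, 'conv3': 32, 'conv4': 32, 'conv5': 32,
#  'conv6': 32, 'conv7': 32, 'conv8': 64}
```

Under these assumptions the 313-channel output lives on a 64x64 grid, a quarter of the input resolution.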

Some complications

One might notice that the output of this layer stack, the blue box labelled '(a,b) probability distribution', has 313 channels. What is going on there? Don't we just want to predict 2 channels? These convolutional layers are actually designed not to predict a color directly, but rather a probability distribution over colors. The reason for this is simple but very important: color prediction is an under-constrained problem. For most objects, there isn't one 'correct' color but rather many valid ones. For example, if the network tries to predict the exact color of a toy, it will almost always be wrong, since the set of possible colors is huge. This messes up training, because the loss penalizes the network even when it produced a perfectly valid color.

To solve this problem, the authors quantized the ab color plane into bins with a grid size of 10 and kept the 313 bins that correspond to in-gamut (displayable) colors. Instead of predicting a single color, the network predicts the likelihood of each of the 313 color bins for every pixel.
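Training with a distribution as output needs a target distribution per pixel. The simplest choice is a one-hot vector on the nearest bin. The paper instead describes soft-encoding: the 5 nearest bins are weighted with a Gaussian kernel (σ = 5). The pure-Python sketch below illustrates both the grid and the soft encoding. It deliberately skips the gamut check, so it has more than 313 bins. The ±110 extent of the grid and all function names are assumptions for illustration.

```python
import math

GRID = 10  # bin width on the ab plane


def make_bins(limit=110, step=GRID):
    """Centres of a square grid over the ab plane, WITHOUT gamut filtering."""
    centres = range(-limit, limit + 1, step)
    return [(float(a), float(b)) for a in centres for b in centres]


def soft_encode(a, b, bins, k=5, sigma=5.0):
    """Target distribution: the k nearest bins, Gaussian-weighted and normalised to sum to 1."""
    dists = sorted(((a - ba) ** 2 + (b - bb) ** 2, i) for i, (ba, bb) in enumerate(bins))
    nearest = dists[:k]
    weights = [math.exp(-d2 / (2 * sigma ** 2)) for d2, _ in nearest]
    total = sum(weights)
    target = [0.0] * len(bins)
    for (_, i), w in zip(nearest, weights):
        target[i] = w / total
    return target


bins = make_bins()
print(len(bins))  # 529 grid cells here; the real model keeps only the 313 in-gamut ones
target = soft_encode(3.0, -2.0, bins)
best = max(range(len(bins)), key=lambda q: target[q])
print(bins[best])  # (0.0, 0.0): the closest bin receives the largest weight
```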

This is a great step, but as it turns out, it is not enough. The under-constrained nature of image colorization is a fundamental problem. We can't just take any run-of-the-mill neural network, train it on a few million samples, and expect nice results. To get the most vibrant and realistic colors possible, we have to get creative with the loss function and the training method, and that is exactly what the authors of the paper do.

The Loss Function

Let’s start with the loss function. The most popular way of quantifying prediction error is the squared error loss (SQE). It ignores the sign of the error and has nice statistical and mathematical properties, so it is used everywhere in statistics and machine learning, and very frequently as a neural network loss. However, if we use this loss for image colorization, we get very dull, desaturated pictures. This happens because of a peculiar property of SQE: if there are multiple likely answers, the prediction that minimizes the expected SQE is their average. A neural network trained with SQE would not predict a Christmas sweater as bright red or bright green, but rather as a muddy average of the two.
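The averaging effect is easy to verify with a toy example. The pure-Python sketch below uses a hypothetical sweater whose a channel is +60 (reddish) or -60 (greenish) with equal probability. It computes the expected squared error of several candidate predictions. The numbers are made up for illustration.

```python
def expected_sq_error(pred, outcomes):
    """Expected squared error of predicting `pred` when the truth is random."""
    return sum(p * (pred - y) ** 2 for y, p in outcomes)


# Hypothetical sweater: its 'a' value is +60 (red) or -60 (green), 50/50.
outcomes = [(60.0, 0.5), (-60.0, 0.5)]
candidates = [-60.0, -30.0, 0.0, 30.0, 60.0]

for c in candidates:
    print(c, expected_sq_error(c, outcomes))  # 7200, 4500, 3600, 4500, 7200
best = min(candidates, key=lambda c: expected_sq_error(c, outcomes))
print("best prediction:", best)  # 0.0: neither red nor green, a washed-out value
```

Committing to either vivid color costs twice as much as hedging with the average. A network trained with SQE therefore learns to hedge.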

[TODO: insert image?, vs x-entropy?].

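In general, averaging a bunch of vibrant colors gives us desaturated colors, which is unacceptable, so we need a better-suited loss function. The authors used a multinomial cross-entropy loss, the multi-class extension of the commonly used cross-entropy loss, which looks like:

L(\hat{Z}, Z) = -\sum_{h,w} \sum_{q} Z_{h,w,q} \log \hat{Z}_{h,w,q}

Here \hat{Z}_{h,w} is the predicted distribution over color bins at pixel (h, w), and Z_{h,w} is the target distribution derived from the true color. [insert code]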
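Here is why cross-entropy over color bins avoids the averaging trap. The pure-Python sketch below uses three hypothetical bins (red, grey, green) for the same sweater, which is red in half of the training images and green in the other half. A prediction that puts its mass on both vivid bins has a much lower average loss than one concentrated on grey. The bins and the data are made up for illustration.

```python
import math


def cross_entropy(target, pred, eps=1e-12):
    """Multinomial cross-entropy between a target and a predicted distribution."""
    return -sum(t * math.log(p + eps) for t, p in zip(target, pred))


# Three hypothetical color bins: [red, grey, green].
red = [1.0, 0.0, 0.0]
green = [0.0, 0.0, 1.0]
training_targets = [red, green]  # half the sweaters are red, half are green


def average_loss(pred):
    return sum(cross_entropy(t, pred) for t in training_targets) / len(training_targets)


two_peaks = [0.5, 0.0, 0.5]  # "red or green"
muddy = [0.1, 0.8, 0.1]      # mostly grey
print(average_loss(two_peaks))  # about 0.693 (log 2)
print(average_loss(muddy))      # about 2.303 (log 10)
```

A distribution can express "red or green" directly. The loss therefore never pushes the network toward the grey in between.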

It turns out that in practice, even this is not enough. Most pictures that we encounter contain a small number of vibrant pixels, while most of the pixels, like those in skies, ground, and dirt, are quite dull. In fact, dull pixels outnumber vibrant ones by orders of magnitude. A naively trained network will be tempted to always predict dull, desaturated colors, since they are so much more frequent. To even things out, we weight every color by its rarity. We calculate rarity by measuring how frequently each color occurs in a randomly chosen set of images, so rarer colors get larger weights, and we normalize each weight with respect to all the other colors' weights. The new loss function looks like:

L(\hat{Z}, Z) = -\sum_{h=1}^{H} \sum_{w=1}^{W} v(Z_{h,w}) \sum_{q=1}^{Q} Z_{h,w,q} \log \hat{Z}_{h,w,q}

Here, H and W represent the height and the width of the image, and q represents the index of the color bin (Q = 313 bins in total). Z is the target distribution built from the true colors, \hat{Z} is the network's predicted distribution, and v(Z_{h,w}) is the rarity weight of the true color at pixel (h, w).

[insert loss function code here]
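A minimal pure-Python sketch of the rebalanced loss follows. The weight formula is the one described in the paper: w_q is proportional to 1 / ((1 - λ) p_q + λ / Q) with λ = 1/2, and it is normalized so that the expected weight under the color distribution p is 1. Each pixel's weight is taken from the most likely bin of its target distribution. The paper also smooths p with a Gaussian before computing the weights; this sketch skips that step. The bin counts are made up, and the function names are hypothetical.

```python
import math


def rebalancing_weights(bin_counts, lam=0.5):
    """One weight per color bin; rare bins get larger weights."""
    total = sum(bin_counts)
    p = [c / total for c in bin_counts]      # empirical color distribution
    n_bins = len(p)
    w = [1.0 / ((1 - lam) * pq + lam / n_bins) for pq in p]
    norm = sum(pq * wq for pq, wq in zip(p, w))
    return [wq / norm for wq in w]           # now sum_q p_q * w_q == 1


def weighted_cross_entropy(pred, target, weights, eps=1e-12):
    """Sum over pixels of v(Z_hw) times the cross-entropy of that pixel.

    pred, target: one distribution over the bins per pixel.
    v(Z_hw) is the weight of the target's most likely bin.
    """
    loss = 0.0
    for z_hat, z in zip(pred, target):
        q_star = max(range(len(z)), key=lambda i: z[i])
        ce = -sum(zq * math.log(zh + eps) for zq, zh in zip(z, z_hat))
        loss += weights[q_star] * ce
    return loss


# Made-up counts: a dull bin seen 90 times, a vibrant bin seen 10 times.
weights = rebalancing_weights([90, 10])
print(weights)  # about [0.882, 2.059]: the vibrant bin counts more

# Two pixels, both predicted 50/50; one is truly dull, one truly vibrant.
pred = [[0.5, 0.5], [0.5, 0.5]]
target = [[1.0, 0.0], [0.0, 1.0]]
print(weighted_cross_entropy(pred, target, weights))  # about 2.039
```

Mistakes on the rare, vibrant pixel now cost more than twice as much as mistakes on the common, dull one. This counteracts the pull toward dull colors.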

Predicting colors from probabilities

Now that we have an estimate for how likely each color is, what color should we predict? A natural prediction rule would be to choose the most likely color, but in practice it tends to produce images with strange spots all over them.

This happens because picking the most likely color (the argmax) is not a smooth operation: the predicted color can change drastically with the smallest change in the network's predicted likelihoods. So, even if the network predicts most colors' likelihoods well, accidentally overestimating the likelihood of one color, say red, makes the final color jump to red, and we end up with a red spot. Clearly we want a more balanced way to take all the color likelihoods into account.

A natural way to do this is to take a weighted average of all the colors, that is, the mean of the predicted distribution. More likely colors have higher weights, so the final color will be closer to them, but since there are so many colors, overestimating one or two of them will not bias the prediction too badly. It sounds like a great idea in theory, and in practice it does eliminate the spots. However, as we saw before, averaging a bunch of likely colors gives us desaturated pictures. After all our efforts on the loss function, that is the last thing we want.
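The difference in stability between the two rules shows up with three hypothetical colors, described only by their a value. In the pure-Python sketch below, a one-percentage-point change in the predicted likelihoods makes the argmax jump by 60 units. The mean barely moves. The colors and probabilities are made up for illustration.

```python
def mode_color(probs, colors):
    """The most likely color (argmax)."""
    return colors[max(range(len(probs)), key=lambda i: probs[i])]


def mean_color(probs, colors):
    """Probability-weighted average of the colors."""
    return sum(p * c for p, c in zip(probs, colors))


colors = [60.0, 0.0, -60.0]   # hypothetical 'a' values: red, grey, green
before = [0.34, 0.33, 0.33]
after = [0.33, 0.34, 0.33]    # a tiny change in the predicted likelihoods

print(mode_color(before, colors), mode_color(after, colors))  # 60.0 0.0: a big jump
print(mean_color(before, colors), mean_color(after, colors))  # roughly 0.6 and 0.0
```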

Ideally, we want something between the two extremes: we don't want blotches on our pictures, but we also want them to look vibrant. So what should we do? A good way of getting the best of both worlds is to sharpen the predicted distribution with a temperature-scaled softmax and then take the weighted average; the authors call this the annealed mean. Softmax does exactly what it sounds like: it is smooth, but it also acts like a max function. The temperature T is a knob that lets us control how 'sharp' the result is. With T = 1, the distribution is left unchanged and we get the original 'flat' weighted average. As T approaches 0, all the weight moves onto the most likely color and we recover the 'jagged' argmax. Setting the temperature in between gives a good balance; the authors report using T = 0.38.

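So our final prediction rule is:

f_T(z)_q = \frac{\exp(\log(z_q) / T)}{\sum_{q'=1}^{Q} \exp(\log(z_{q'}) / T)}, \qquad \mathcal{H}(\hat{Z}_{h,w}) = \sum_{q=1}^{Q} f_T(\hat{Z}_{h,w})_q \, c_q

Here \hat{Z}_{h,w} is the predicted distribution over the Q = 313 color bins at pixel (h, w), and c_q is the ab value at the center of bin q. [insert code]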
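A minimal pure-Python sketch of the annealed mean for a single pixel follows. The distribution is re-sharpened by dividing the log-probabilities by T and re-normalizing with a softmax. The result is then averaged over the bin centers. The two bins, their centers, and the function name are hypothetical, chosen so the effect of T is easy to see.

```python
import math


def annealed_mean(probs, bin_centres, T=0.38):
    """Re-sharpen a distribution with temperature T, then return its mean (a, b)."""
    # f_T(z)_q = exp(log(z_q) / T) / sum_q' exp(log(z_q') / T)
    logits = [math.log(max(p, 1e-12)) / T for p in probs]
    top = max(logits)                      # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in logits]
    total = sum(exps)
    weights = [e / total for e in exps]
    a = sum(w * c[0] for w, c in zip(weights, bin_centres))
    b = sum(w * c[1] for w, c in zip(weights, bin_centres))
    return a, b


centres = [(60.0, 0.0), (-60.0, 0.0)]  # hypothetical reddish and greenish bins
probs = [0.6, 0.4]
for T in (1.0, 0.38, 0.01):
    print(T, annealed_mean(probs, centres, T))
# T = 1 gives the plain mean (a = 12), T = 0.01 is essentially the mode
# (a close to 60), and T = 0.38 lands in between.
```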

Training the model/Loading trained model

The trainer classes make it easy to train the model: you just instantiate the class. However, training is computationally expensive; even on an NVIDIA K40 GPU, it can take up to 2 weeks. We are providing a pre-trained model that can be loaded as follows:

[insert code]

Conclusion

