Image generation

In Part 1 of this course, we focused mainly on models that were useful for classification. However, many applications require generating much higher dimensional results, such as images and sentences. Examples include:

  • Text: neural translation, text to speech, image captioning
  • Image: segmentation, artistic filters, image sharpening and cleaning

In [69]:
%matplotlib inline
import importlib
import utils2; importlib.reload(utils2)
from utils2 import *

from scipy.optimize import fmin_l_bfgs_b
from scipy.misc import imsave
from keras import metrics

In [2]:
import vgg16_avg; importlib.reload(vgg16_avg)
from vgg16_avg import VGG16_Avg

In [3]:
# Tell Tensorflow to use no more GPU RAM than necessary
limit_mem()

Data can be downloaded from here. Update the path below to wherever you downloaded the data. Optionally use a second path for fast (e.g. SSD) storage - set both to the same path if using AWS.


In [4]:
path = 'data/imagenet/train/'
dpath = 'data/imagenet/train/'

Neural style transfer

The first use case of an image to image architecture we're going to look at is neural style transfer, using the approach in this paper. This is a fairly popular application of deep learning in which an image is recreated in the style of a work of art, such as Van Gogh's Starry Night. For more information about the use of neural networks in art, see this Scientific American article or Google's Magenta Project.

Setup

Our first step is to list the files we have, and then grab an image.


In [5]:
fnames = glob.glob(path+'**/*.JPEG', recursive=True)
n = len(fnames); n


Out[5]:
19439

In [6]:
idx=60

In [7]:
fn = fnames[idx]; fn


Out[7]:
'data/imagenet/train\\n01491361\\n01491361_2884.JPEG'

In [11]:
img=Image.open(fnames[idx]); img


Out[11]:

That's a nice looking image! Feel free to use any other image that you're interested in playing with.

We'll be using this image with VGG16. Therefore, we need to subtract the mean of each channel of the imagenet data and reverse the channel order from RGB to BGR, since those are the preprocessing steps the VGG authors used - their model won't work properly unless we do the same thing.

We can do this in one step using broadcasting, which is a topic we'll be returning to many times during this course.


In [12]:
img.size


Out[12]:
(500, 331)

In [83]:
np.array(img).shape


Out[83]:
(331, 500, 3)

In [13]:
temp = np.expand_dims(np.array(img),0)
temp.shape


Out[13]:
(1, 331, 500, 3)

In [14]:
rn_mean = np.array([123.68, 116.779, 103.939], dtype=np.float32)
preproc = lambda x: (x - rn_mean)[:, :, :, ::-1] # operates on a 4D (batch, height, width, channels) array; reverses RGB->BGR

Function for undoing the preprocessing for the generated images.


In [15]:
deproc = lambda x,s: np.clip(x.reshape(s)[:, :, :, ::-1] + rn_mean, 0, 255)

In [16]:
img_arr = preproc(np.expand_dims(np.array(img), 0))
shp = img_arr.shape; shp


Out[16]:
(1, 331, 500, 3)

Broadcasting examples


In [12]:
np.array([1,2,3]) - 2


Out[12]:
array([-1,  0,  1])

In [12]:
np.array([2,3]).reshape(1,1,1,2)


Out[12]:
array([[[[2, 3]]]])

In [13]:
np.array([2,3]).reshape(1,1,2,1)


Out[13]:
array([[[[2],
         [3]]]])

In [14]:
a = np.random.randn(5,1,3,2)
b = np.random.randn(2)
(a-b).shape


Out[14]:
(5, 1, 3, 2)
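
This last rule is exactly what makes the one-step preprocessing above work. A minimal sketch, assuming only numpy (available as np via utils2): the 3-element channel mean broadcasts against the trailing axis of a 4D image array.

demo_img = np.zeros((1, 331, 500, 3), dtype=np.float32)              # (batch, height, width, channels)
demo_mean = np.array([123.68, 116.779, 103.939], dtype=np.float32)   # one value per channel
(demo_img - demo_mean).shape                                         # -> (1, 331, 500, 3)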

Recreate input

The first step in style transfer is understanding how to recreate an image from noise based on "content loss", which is the amount of difference between activations in some layer. In earlier layers, content loss is very similar to per-pixel loss, but in later layers it captures the "meaning" of a part of an image, rather than the specific details.

To do this, we first take a CNN and pass an image through it. We then pass a "noise image" (i.e. random pixel values) through the same CNN. At some layer, we take the outputs for both images and compare their activations using MSE.

The interesting part is that now, instead of updating the parameters of the CNN, we update the pixels of the noisy image. In other words, our goal is to alter the noisy image so as to minimize the difference between the original image's output at some convolutional layer and the output of the noisy image at the same layer.

In order to construct this architecture, we're going to be working with keras.backend, which is an abstraction layer that allows us to target both theano and tensorflow with the same code.

The CNN we'll use is VGG16, but with a twist. Previously we've always used Vgg with max pooling, and this was useful for image classification. It's not as useful in this case however, because max pooling loses information about the original input area. Instead we will use average pooling, as this does not throw away as much information.
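
A tiny illustration of the difference (plain numpy, not part of the model): max pooling keeps only the largest value in each window, so most of the window has no influence on the output, while average pooling lets every input pixel contribute.

patch = np.array([[1., 2.],
                  [3., 8.]])
patch.max()    # 8.0 - max pooling: three of the four values have no effect on the output
patch.mean()   # 3.5 - average pooling: every value contributes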


In [17]:
model = VGG16_Avg(include_top=False)

In [18]:
model.summary()


_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         (None, None, None, 3)     0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, None, None, 64)    1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, None, None, 64)    36928     
_________________________________________________________________
block1_pool (AveragePooling2 (None, None, None, 64)    0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, None, None, 128)   73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, None, None, 128)   147584    
_________________________________________________________________
block2_pool (AveragePooling2 (None, None, None, 128)   0         
_________________________________________________________________
block3_conv1 (Conv2D)        (None, None, None, 256)   295168    
_________________________________________________________________
block3_conv2 (Conv2D)        (None, None, None, 256)   590080    
_________________________________________________________________
block3_conv3 (Conv2D)        (None, None, None, 256)   590080    
_________________________________________________________________
block3_pool (AveragePooling2 (None, None, None, 256)   0         
_________________________________________________________________
block4_conv1 (Conv2D)        (None, None, None, 512)   1180160   
_________________________________________________________________
block4_conv2 (Conv2D)        (None, None, None, 512)   2359808   
_________________________________________________________________
block4_conv3 (Conv2D)        (None, None, None, 512)   2359808   
_________________________________________________________________
block4_pool (AveragePooling2 (None, None, None, 512)   0         
_________________________________________________________________
block5_conv1 (Conv2D)        (None, None, None, 512)   2359808   
_________________________________________________________________
block5_conv2 (Conv2D)        (None, None, None, 512)   2359808   
_________________________________________________________________
block5_conv3 (Conv2D)        (None, None, None, 512)   2359808   
_________________________________________________________________
block5_pool (AveragePooling2 (None, None, None, 512)   0         
=================================================================
Total params: 14,714,688
Trainable params: 14,714,688
Non-trainable params: 0
_________________________________________________________________

Here we're grabbing the activations from near the end of the convolutional model.


In [19]:
layer = model.get_layer('block5_conv1').output

And let's calculate the target activations for this layer:

Create a new model using that late conv layer as output and model.input as input.


In [20]:
layer_model = Model(model.input, layer)
targ = K.variable(layer_model.predict(img_arr))  # equivalent to tf.Variable(...)

# targ holds this particular image's activations at that late conv layer
# (in this case, the fish image loaded above)

In our implementation, we need to define an object that will allow us to separately access the loss function and gradients of a function, since that is what scipy's optimizer requires.


In [21]:
# Wraps a Keras function so that scipy's deterministic optimizer can query the loss and gradients separately
class Evaluator(object):
    def __init__(self, f, shp): self.f, self.shp = f, shp
        
    def loss(self, x):
        loss_, self.grad_values = self.f([x.reshape(self.shp)])
        return loss_.astype(np.float64)

    def grads(self, x): return self.grad_values.flatten().astype(np.float64)

We'll define our loss function to calculate the mean squared error between the two outputs at the specified convolutional layer.


In [22]:
# using just metrics.mse(layer, targ) doesn't work: returns a tensor instead of a scalar

# loss = (metrics.mse(layer, targ)) #using mse so we can use much faster convex optimization
# Get a mse loss function between content photo (targ) and generated photo (from layer)
loss = K.mean(metrics.mse(layer, targ))

# layer: a symbolic tensor with no fixed value; it evaluates to whatever that late conv layer outputs for the current input


grads = K.gradients(loss, model.input)
# to optimize the generated image we need the gradient of the loss with respect to that image,
# which is the model's input

fn = K.function([model.input], [loss]+grads)
# the function takes model.input and returns a list containing the loss and the gradients, i.e. [loss, grads]

evaluator = Evaluator(fn, shp)

Now we're going to optimize this loss function with a deterministic approach that uses a line search, which we can implement with scipy's `fmin_l_bfgs_b` function, instead of SGD - there are no mini-batches involved here.
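
As a minimal standalone sketch of the interface that Evaluator is built to match (the names below are made up for illustration): fmin_l_bfgs_b needs a loss function and a gradient function, both taking a flat array.

quad_loss  = lambda v: float(np.sum((v - 3.0)**2))   # toy loss with its minimum at v = 3
quad_grads = lambda v: 2.0*(v - 3.0)                 # gradient of that loss
v_opt, loss_opt, _ = fmin_l_bfgs_b(quad_loss, np.zeros(4), fprime=quad_grads)
v_opt, loss_opt   # -> roughly array([3., 3., 3., 3.]) and ~0.0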


In [23]:
def solve_image(eval_obj, niter, x):
    for i in range(niter):
        # pass to fmin_l_bfgs_b:
        # - the loss function, evaluated at the current point
        # - the starting point x (initially just a random image)
        # - the gradient function, evaluated at the current point
        # It returns x, the flattened image estimated to minimize the loss (our trainable input, initially random)
        x, min_val, info = fmin_l_bfgs_b(eval_obj.loss, x.flatten(),
                                         fprime=eval_obj.grads, maxfun=20)
        
        x = np.clip(x, -127,127)
        print('Current loss value:', min_val)
        imsave('{}/results/res_at_iteration_{}.png'.format(path, i), deproc(x.copy(), shp)[0])
    return x

Next we need to generate a random image.


In [24]:
rand_img = lambda shape: np.random.uniform(-2.5, 2.5, shape)/100
x = rand_img(shp)
plt.imshow(x[0]);shp


Out[24]:
(1, 331, 500, 3)

Now we'll run through this optimization approach twenty times and train the noise image's pixels as desired.


In [20]:
iterations=20

In [22]:
x = solve_image(evaluator, iterations, x)


Current loss value: 31.5032863617
Current loss value: 11.3070430756
Current loss value: 7.29805803299
Current loss value: 5.40987682343
Current loss value: 4.38178062439
Current loss value: 3.77204465866
Current loss value: 3.33651447296
Current loss value: 2.98644351959
Current loss value: 2.72604608536
Current loss value: 2.52122664452
Current loss value: 2.3497467041
Current loss value: 2.21426439285
Current loss value: 2.08880472183
Current loss value: 2.05548334122
Current loss value: 2.04777598381
Current loss value: 2.04778599739
Current loss value: 2.04778599739
Current loss value: 2.04778599739
Current loss value: 2.04778599739
Current loss value: 2.04778599739

In [23]:
x.shape


Out[23]:
(496500,)

Our result, from comparing the output at conv 1 of the last block (5), is fairly amorphous, but the original subject is still easily recognizable. Notice that the things it has reconstructed particularly well are those things that we expect VGG16 to be good at recognizing, such as the eye and the overall outline of the fish.


In [24]:
Image.open(path + 'results/res_at_iteration_0.png')


Out[24]:

Important note:

If instead we optimized against the output of conv 1 of the 4th block, the trained image would look much more like the original (the background is also reconstructed - background details are emphasized just as much as the main object). This makes sense because, comparing at an earlier layer, there are fewer transformations to go through, the receptive field is smaller, and the features capture geometric detail rather than broad features.

Using a later conv layer, details of the object (the shark's fins, head, body color, ...) are highlighted more, as VGG no longer cares what the background looks like.


In [25]:
Image.open(path + 'results/res_at_iteration_9.png')


Out[25]:

In [26]:
Image.open(path + 'results/res_at_iteration_19.png')


Out[26]:

In [43]:
from IPython.display import HTML
from matplotlib import animation, rc

In [44]:
fig, ax = plt.subplots()
def animate(i): ax.imshow(Image.open('{}results/res_at_iteration_{}.png'.format(path, i)))


The optimizer first focuses on the important details of the subject, before trying to match the background.


In [ ]:
anim = animation.FuncAnimation(fig, animate, frames=10, interval=200)
HTML(anim.to_html5_video())

Recreate style

Now that we've learned how to recreate an input image, we'll move on to attempting to recreate style. By "style", we mean the color palette and texture of an image. Unlike recreating based on content, with style we are not concerned about the actual structure of what we're creating; all we care about is that it captures this concept of "style".

Here are some examples of images we can extract style from.


In [25]:
def plot_arr(arr): plt.imshow(deproc(arr,arr.shape)[0].astype('uint8'))

In [26]:
style = Image.open('data/imagenet/starry_night.jpg')
print(style.size)
style = style.resize(img.size); style  # - use this to avoid cropping the original image
#style = style.resize(np.divide(style.size,3.5).astype('int32')); style # - original statement


(1170, 968)
Out[26]:

In [31]:
#style = Image.open('data/imagenet/bird.jpg')
#style = style.resize(img.size); style  # - use this to avoid cropping the original image
# style = style.resize(np.divide(style.size,2.4).astype('int32')); style # - original statement

In [32]:
# style = Image.open('data/imagenet/simpsons.jpg')
# style = style.resize(img.size); style  # - use this to avoid cropping the original image
# style = style.resize(np.divide(style.size,2.7).astype('int32')); style # - original statement

We're going to repeat the same approach as before, but with some differences.


In [27]:
style.size


Out[27]:
(500, 331)

In [28]:
style_arr = preproc(np.expand_dims(style,0)[:,:,:,:3])
shp = style_arr.shape
print(shp)


(1, 331, 500, 3)

In [29]:
model = VGG16_Avg(include_top=False, input_shape=shp[1:])
outputs = {l.name: l.output for l in model.layers}

One thing to notice is that we're actually going to be calculating the loss function over multiple layers, rather than just one. (Note however that there's no reason you couldn't try using multiple layers in your content loss function, if you wanted to try that).


In [30]:
layers = [outputs['block{}_conv1'.format(o)] for o in range(1,4)]

In [31]:
layers_model = Model(model.input, layers)
targs = [K.variable(o) for o in layers_model.predict(style_arr)]

The key difference is our choice of loss function. Whereas before we were calculating mse of the raw convolutional outputs, here we transform them into the "gramian matrix" of their channels (that is, the product of a matrix and its transpose) before taking their mse. It's unclear why this helps us achieve our goal, but it works. One thought is that the gramian shows how our features at that convolutional layer correlate, and completely removes all location information. So matching the gram matrix of channels can only match some type of texture information, not location information.

By taking the dot product of a matrix with its transpose (the Gram matrix), you multiply every row with every other row (including itself), which is like comparing each pair of rows. If two rows are similar, their dot product will be large, so the Gram matrix emphasizes those similarities and acts as a kind of 'fingerprint'.
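
A tiny numeric sketch of that idea (plain numpy, with made-up values): rows that respond at the same locations produce large off-diagonal entries in the Gram matrix, no matter where along the row that activity happens.

feat = np.array([[1., 0., 1., 0.],    # "channel" 0
                 [1., 0., 1., 0.],    # "channel" 1 - fires at the same locations as channel 0
                 [0., 1., 0., 0.]])   # "channel" 2 - unrelated
feat @ feat.T
# array([[2., 2., 0.],
#        [2., 2., 0.],
#        [0., 0., 1.]])   <- the large (0,1) entry marks the correlated pair of channels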


In [28]:
layers[0][0].get_shape()


Out[28]:
TensorShape([Dimension(331), Dimension(500), Dimension(64)])

In [32]:
def gram_matrix(x):
    # We want each row to be a channel, and the columns to be flattened x,y locations
    features = K.batch_flatten(K.permute_dimensions(x, (2, 0, 1))) 
    # permute from height x width x channels to channels x height x width
    # K.batch_flatten flattens everything except the first dimension into a vector,
    # so features is now a 2D matrix of shape channels x (height*width), e.g. 64 x (331*500) = 64 x 165500
    
    # The dot product of this with its transpose shows the correlation 
    # between each pair of channels
    return K.dot(features, K.transpose(features)) / x.get_shape().num_elements()

Note:

By flattening out height x width, location information is completely thrown away. So comparing the Gram matrices of two images with MSE measures how similar the two images' "fingerprints" are.


In [33]:
def style_loss(x, targ): return K.mean(metrics.mse(gram_matrix(x), gram_matrix(targ)))  # - using just metrics.mse(layer, targ) doesn't work: returns a tensor

In [34]:
loss = sum(style_loss(l1[0], l2[0]) for l1,l2 in zip(layers, targs))
grads = K.gradients(loss, model.input)
style_fn = K.function([model.input], [loss]+grads)
evaluator = Evaluator(style_fn, shp)

We then solve as we did before.


In [35]:
rand_img = lambda shape: np.random.uniform(-2.5, 2.5, shape)/1
x = rand_img(shp)
x = scipy.ndimage.filters.gaussian_filter(x, [0,2,2,0])

In [132]:
plt.imshow(x[0]);



In [133]:
iterations=20
x = rand_img(shp)

In [134]:
x = solve_image(evaluator, iterations, x)


Current loss value: 8001.65917969
Current loss value: 663.166870117
Current loss value: 328.129974365
Current loss value: 217.96812439
Current loss value: 152.284332275
Current loss value: 116.942276001
Current loss value: 90.6019592285
Current loss value: 74.2245483398
Current loss value: 58.2962226868
Current loss value: 48.0198745728
Current loss value: 40.1180801392
Current loss value: 34.5200424194
Current loss value: 28.3078517914
Current loss value: 25.1212348938
Current loss value: 21.7250862122
Current loss value: 19.3613853455
Current loss value: 16.7956428528
Current loss value: 15.3070983887
Current loss value: 13.2233715057
Current loss value: 12.0971097946

Our results are stunning. By transforming the convolutional outputs to the gramian, we are somehow able to update the noise pixels to produce an image that captures the raw style of the original image, with absolutely no structure or meaning.


In [135]:
Image.open(path + 'results/res_at_iteration_0.png')


Out[135]:

In [136]:
Image.open(path + 'results/res_at_iteration_9.png')


Out[136]:

In [86]:
Image.open(path + 'results/res_at_iteration_19.png')


Out[86]:

Style transfer

We now know how to reconstruct an image, as well as how to construct an image that captures the style of an original image. The obvious idea may be to just combine these two approaches by weighting and adding the two loss functions.


In [36]:
w,h = style.size
src = img_arr[:,:h,:w]
plot_arr(src)


Like before, we're going to grab a sequence of layer outputs to compute the style loss. However, we still only need one layer output to compute the content loss. How do we know which layer to grab? As we discussed earlier, the lower the layer, the more exact the content reconstruction will be. In merging content reconstruction with style, we might expect that a looser reconstruction of the content will allow more room for the style to have an effect (re: inspiration). Furthermore, a later layer ensures that the image "looks like" the same subject, even if it doesn't have the same details.


In [37]:
style_layers = [outputs['block{}_conv2'.format(o)] for o in range(1,6)]
content_name = 'block4_conv2'
content_layer = outputs[content_name]

In [38]:
style_model = Model(model.input, style_layers)
style_targs = [K.variable(o) for o in style_model.predict(style_arr)]

In [39]:
content_model = Model(model.input, content_layer)
content_targ = K.variable(content_model.predict(src))

Now to actually merge the two approaches is as simple as merging their respective loss functions. Note that as opposed to our previous two functions, this function is producing three separate types of outputs: one for the original image (src), one for the image whose style we're emulating (starry_night), and one for the random image whose pixels we are training.

One way for us to tune how the reconstructions mix is by changing the factor on the content loss, which we have here as 1/6. If we increase that denominator, the style will have a larger effect on the image, and if it's too large the original content of the image will be obscured by unstructured style. Likewise, if it is too small then the image will not have enough style.


In [40]:
style_wgts = [0.05,0.2,0.2,0.25,0.3]

In [41]:
loss = sum(style_loss(l1[0], l2[0])*w
           for l1,l2,w in zip(style_layers, style_targs, style_wgts))
loss += K.mean(metrics.mse(content_layer, content_targ))/6
grads = K.gradients(loss, model.input)
transfer_fn = K.function([model.input], [loss]+grads)

In [42]:
evaluator = Evaluator(transfer_fn, shp)

In [40]:
iterations=40
x = rand_img(shp)

In [41]:
x = solve_image(evaluator, iterations, x)


Current loss value: 10198.8378906
Current loss value: 1240.29614258
Current loss value: 844.29675293
Current loss value: 621.955566406
Current loss value: 509.327545166
Current loss value: 435.025512695
Current loss value: 377.440063477
Current loss value: 339.355651855
Current loss value: 310.541137695
Current loss value: 289.971069336
Current loss value: 272.373718262
Current loss value: 259.668823242
Current loss value: 249.206085205
Current loss value: 241.012161255
Current loss value: 233.965789795
Current loss value: 228.280258179
Current loss value: 222.979919434
Current loss value: 217.907165527
Current loss value: 214.046859741
Current loss value: 210.887832642
Current loss value: 207.251403809
Current loss value: 204.970581055
Current loss value: 202.167160034
Current loss value: 200.219512939
Current loss value: 197.87612915
Current loss value: 196.126373291
Current loss value: 194.142333984
Current loss value: 192.269592285
Current loss value: 190.54510498
Current loss value: 189.137969971
Current loss value: 187.677886963
Current loss value: 186.345306396
Current loss value: 185.296264648
Current loss value: 184.031555176
Current loss value: 182.985061646
Current loss value: 181.882263184
Current loss value: 180.981140137
Current loss value: 180.02935791
Current loss value: 179.178726196
Current loss value: 178.307296753

These results are remarkable. Each does a fantastic job at recreating the original image in the style of the artist.


In [42]:
Image.open(path + 'results/res_at_iteration_5.png')


Out[42]:

In [43]:
Image.open(path + 'results/res_at_iteration_9.png')


Out[43]:

In [44]:
Image.open(path + 'results/res_at_iteration_19.png')


Out[44]:

In [45]:
Image.open(path + 'results/res_at_iteration_29.png')


Out[45]:

In [46]:
Image.open(path + 'results/res_at_iteration_39.png')


Out[46]:

In [150]:
plot_arr(src)



In [9]:
style = Image.open('data/imagenet/starry_night.jpg')
print(style.size)
style = style.resize(img.size); style


(1170, 968)
Out[9]:

There are lots of interesting additional things you could try, such as the ideas shown here: https://github.com/titu1994/Neural-Style-Transfer .

Use content loss to create a super-resolution network

So far we've demonstrated how to achieve successful results in style transfer. However, there is an obvious drawback to our implementation, namely that we're training an image, not a network, and therefore every new image requires us to retrain. It's not a feasible method for any sort of real-time application. Fortunately we can address this issue by using a fully convolutional network (FCN), and in particular we'll look at this implementation of super resolution. We are following the approach in this paper.


In [43]:
#arr_lr = bcolz.open(dpath+'trn_resized_72_r.bc')[:]
#arr_hr = bcolz.open(dpath+'trn_resized_288_r.bc')[:]

# - here, alternatively, using the already loaded image (duplicated to make a batch of two):
# high-res image arrays
img_hr0 = img.resize((288, 288))
arr_hr0 = np.expand_dims(np.array(img_hr0), 0)
img_hr1 = img.resize((288, 288))
arr_hr1 = np.expand_dims(np.array(img_hr1), 0)
arr_hr = np.concatenate((arr_hr0, arr_hr1))
shp_hr = arr_hr.shape
print(shp_hr)
# low-res image
img_lr0 = img.resize((72, 72))
arr_lr0 = np.expand_dims(np.array(img_lr0), 0)
img_lr1 = img.resize((72, 72))
arr_lr1 = np.expand_dims(np.array(img_lr1), 0)
arr_lr = np.vstack((arr_lr0, arr_lr1))  # same result as the above concatenate 
shp_lr = arr_lr.shape
print(shp_lr)


(2, 288, 288, 3)
(2, 72, 72, 3)

In [46]:
plt.imshow(arr_lr[0].astype('uint8'));



In [47]:
plt.imshow(arr_hr[0].astype('uint8'));



In [48]:
parms = {'verbose': 0, 'callbacks': [TQDMNotebookCallback(leave_inner=False)]}  # needed by the **parms fit calls further below

To start we'll define some of the building blocks of our network. In particular recall the residual block (as used in Resnet), which is just a sequence of 2 convolutional layers that is added to the initial block input. We also have a de-convolutional layer (also known as a "transposed convolution" or "fractionally strided convolution"), whose purpose is to learn to "undo" the convolutional function. It does this by padding the smaller image in such a way that applying filters to it produces a larger image.
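
As a quick shape check (an illustrative cell only, assuming the Keras layers imported via utils2): a stride-2 transposed convolution with 'same' padding doubles the spatial dimensions, which is exactly how deconv_block upsamples.

chk_in  = Input((72, 72, 64))
chk_out = Conv2DTranspose(64, kernel_size=3, strides=(2, 2), padding='same')(chk_in)
K.int_shape(chk_out)   # -> (None, 144, 144, 64)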


In [119]:
# Keras 2
def conv_block(x, filters, size, stride=(2,2), mode='same', act=True):
    x = Conv2D(filters, (size, size), strides=stride, padding=mode)(x)  # Keras 2
    x = BatchNormalization()(x)    # Keras 2 takes default axis=-1
    return Activation('relu')(x) if act else x

In [120]:
def res_block(ip, nf=64):
    x = conv_block(ip, nf, 3, (1,1))
    x = conv_block(x, nf, 3, (1,1), act=False)
    return add([x, ip])

In [121]:
# Keras 2
def deconv_block(x, filters, size, stride=(2,2)):
    # stride =2 => it's going to be doubling the size of the image
    # filters will stay at 64
    x = Conv2DTranspose(filters, kernel_size=size, strides=stride, padding='same')(x)
    x = BatchNormalization()(x)
    return Activation('relu')(x)

In [122]:
# Keras 2
def up_block(x, filters, size):
    x = keras.layers.UpSampling2D()(x)
    x = Conv2D(filters, (size, size), padding='same')(x)
    x = BatchNormalization()(x)
    return Activation('relu')(x)

This model here is using the previously defined blocks to encode a low resolution image and then upsample it to match the same image in high resolution.


In [123]:
# - first network segment for training the low-res image array
inp=Input(arr_lr.shape[1:])  # - input to this network segment will be low-res pre-processed image array
x=conv_block(inp, 64, 9, (1,1))
for i in range(4): x=res_block(x)
# x=deconv_block(x, 64, 3)  # - test: if using this comment up_block. Output shape: 144x144x64 (72x2=144)
# x=deconv_block(x, 64, 3)  # - test: if using this comment up_block. Output shape: 288x288x64 (72x2=144)
x=up_block(x, 64, 3)  # - alternative approach - if using this then comment deconv_block
x=up_block(x, 64, 3)  # - alternative approach - if using this then comment deconv_block

# last conv layer: set the filter depth to 3 to get back to 3 channels,
# with a large 9x9 filter size to help reconstruct the content
x=Conv2D(3, (9, 9), activation='tanh', padding='same')(x)
# output shape should be 288x288x3
# Note: the tanh activation on the last Conv2D and the rescaling lambda below can be removed;
# as the author said, the model works just as well without them
# x=Conv2D(3, (9, 9), padding='same')(x)
outp=Lambda(lambda x: (x+1)*127.5)(x)  # - values will be in [0, 255]


C:\Users\anhqu\Anaconda3\lib\site-packages\tensorflow\python\util\tf_inspect.py:45: DeprecationWarning: inspect.getargspec() is deprecated, use inspect.signature() or inspect.getfullargspec()
  if d.decorator_argspec is not None), _inspect.getargspec(target))

The method of training this network is almost exactly the same as training the pixels from our previous implementations. The idea here is we're going to feed two images to Vgg16 and compare their convolutional outputs at some layer. These two images are the target image (which in our case is the same as the original but at higher resolution), and the output of the previous network we just defined, which we hope will learn to output a high resolution image.

The key then is to train this other network to produce an image that minimizes the loss between the outputs of some convolutional layer in Vgg16 (which the paper refers to as "perceptual loss"). In doing so, we are able to train a network that can upsample an image and recreate the higher resolution details.


In [124]:
# - second network segment: its input is the high-res pre-processed image array
vgg_inp=Input(shp_hr[1:])  
vgg= VGG16(include_top=False, input_tensor=Lambda(preproc)(vgg_inp))

Since we only want to learn the "upsampling network", and are just using VGG to calculate the loss function, we set the Vgg layers to not be trainable.


In [125]:
for l in vgg.layers: l.trainable=False

An important difference in training for super resolution is the loss function. We use what's known as a perceptual loss function (which is simply the content loss for some layer).


In [126]:
# - Original version using a group of layers - comment Test version if using this
def get_outp(m, ln): return m.get_layer('block{}_conv1'.format(ln)).output
vgg_content = Model(vgg_inp, [get_outp(vgg, o) for o in [1,2,3]])  # - using a group of layers

vgg1 = vgg_content(vgg_inp)  
# - will be used for the high-res target
# vgg_content is a MODEL, but it can be called like a function;
# vgg_content(vgg_inp) returns vgg_content's output tensors computed on vgg_inp

vgg2 = vgg_content(outp)  # - will be used for the trainable segment output
# similarly, vgg2 holds the output tensors computed on outp

# root-mean-square of the difference, returned as the per-example loss
# diff is the difference between two conv outputs, i.e. a tensor of shape (batch, h, w, channels)
def mean_sqr_b(diff):
    dims = list(range(1,K.ndim(diff))) # range(1,4)
    return K.expand_dims(K.sqrt(K.mean(diff**2, dims)), 0) # shape= (1,)

w=[0.1, 0.8, 0.1]
# add the (weighted) losses from the 3 outputs
def content_fn(x): 
    res = 0; n=len(w)
    for i in range(n): res += mean_sqr_b(x[i]-x[i+n]) * w[i]
    return res

# output = Lambda(content_fn)(vgg1+vgg2)
m_sr = Model([inp, vgg_inp], Lambda(content_fn)(vgg1+vgg2))

# vgg1 and vgg2 are each a list of 3 outputs, so vgg1+vgg2 concatenates the two lists

In [69]:
# - Test version using one layer - comment Original version if using this
# def get_outp(m, ln): return m.get_layer('block{}_conv1'.format(ln)).output
# vgg_content = Model(vgg_inp, get_outp(vgg, 2))  # using a selected early layer

# vgg1 = vgg_content(vgg_inp)  # - will be used for the high-res target
# vgg2 = vgg_content(outp)  # - will be used for the trainable segment output

# dims = list(range(1,K.ndim(x)))
# content_fn_simple = Lambda(lambda x: K.expand_dims(K.sqrt(K.mean((x[0] - x[1])**2, dims)), 0))([vgg1, vgg2])
# m_sr = Model([inp, vgg_inp], content_fn_simple)

In [127]:
targ = np.zeros((arr_hr.shape[0], 1))
targ.shape
# targ is a zero vector, since we want the loss output driven towards zero
# targ.shape is (2, 1) since there are 2 images at each resolution
# the training data effectively looks like this:
# [arr_lr[0], arr_hr[0], 0]
# [arr_lr[1], arr_hr[1], 0]


Out[127]:
(2, 1)

In [92]:
targ


Out[92]:
array([[ 0.],
       [ 0.]])

Finally we compile this chain of models, and we can pass it the original low-resolution images as well as the high-resolution ones to train on. We also define a zero vector as the target, which is a required argument when calling fit on a keras model.
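
To see why fitting against zeros makes sense, here is a minimal sketch of the same "model that outputs its own loss" pattern (all names below are hypothetical, not part of this notebook's pipeline): the final Lambda produces one loss value per sample, so compiling with 'mse' against a zero target makes Keras drive that output towards 0.

demo_in  = Input((4,))
demo_h   = Dense(4)(demo_in)                                                   # something trainable to push towards zero
demo_out = Lambda(lambda t: K.expand_dims(K.sum(t**2, axis=1), -1))(demo_h)    # shape (None, 1): the "loss" output
demo_m   = Model(demo_in, demo_out)
demo_m.compile('adam', 'mse')
demo_m.fit(np.random.randn(8, 4), np.zeros((8, 1)), 4, 2, verbose=0)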


In [128]:
m_sr.compile('adam', loss='mse')
m_sr.fit([arr_lr, arr_hr], targ, 8, 100)


Epoch 1/100
2/2 [==============================] - 2s - loss: 229482.5312
Epoch 2/100
2/2 [==============================] - 0s - loss: 115592.4531
Epoch 3/100
2/2 [==============================] - 0s - loss: 68964.9766
Epoch 4/100
2/2 [==============================] - 0s - loss: 50688.8047
Epoch 5/100
2/2 [==============================] - 0s - loss: 46017.8828
Epoch 6/100
2/2 [==============================] - 0s - loss: 43045.9102
Epoch 7/100
2/2 [==============================] - 0s - loss: 41369.0703
Epoch 8/100
2/2 [==============================] - 0s - loss: 41098.6875
Epoch 9/100
2/2 [==============================] - 0s - loss: 41078.3516
Epoch 10/100
2/2 [==============================] - 0s - loss: 41119.7148
Epoch 11/100
2/2 [==============================] - 0s - loss: 41168.5312
Epoch 12/100
2/2 [==============================] - 0s - loss: 41193.7500
Epoch 13/100
2/2 [==============================] - 0s - loss: 41189.8281
Epoch 14/100
2/2 [==============================] - 0s - loss: 41161.2578
Epoch 15/100
2/2 [==============================] - 0s - loss: 41116.6328
Epoch 16/100
2/2 [==============================] - 0s - loss: 41067.3906
Epoch 17/100
2/2 [==============================] - 0s - loss: 41023.9492
Epoch 18/100
2/2 [==============================] - 0s - loss: 40987.3594
Epoch 19/100
2/2 [==============================] - 0s - loss: 40953.0820
Epoch 20/100
2/2 [==============================] - 0s - loss: 40918.4883
Epoch 21/100
2/2 [==============================] - 0s - loss: 40886.5742
Epoch 22/100
2/2 [==============================] - 0s - loss: 40862.0352
Epoch 23/100
2/2 [==============================] - 0s - loss: 40843.4922
Epoch 24/100
2/2 [==============================] - 0s - loss: 40825.7656
Epoch 25/100
2/2 [==============================] - 0s - loss: 40804.1406
Epoch 26/100
2/2 [==============================] - 0s - loss: 40777.8867
Epoch 27/100
2/2 [==============================] - 0s - loss: 40750.5430
Epoch 28/100
2/2 [==============================] - 0s - loss: 40727.1562
Epoch 29/100
2/2 [==============================] - 0s - loss: 40710.6641
Epoch 30/100
2/2 [==============================] - 0s - loss: 40698.1992
Epoch 31/100
2/2 [==============================] - 0s - loss: 40683.6562
Epoch 32/100
2/2 [==============================] - 0s - loss: 40665.2148
Epoch 33/100
2/2 [==============================] - 0s - loss: 40645.2969
Epoch 34/100
2/2 [==============================] - 0s - loss: 40625.0312
Epoch 35/100
2/2 [==============================] - 0s - loss: 40601.2969
Epoch 36/100
2/2 [==============================] - 0s - loss: 40569.1562
Epoch 37/100
2/2 [==============================] - 0s - loss: 40525.4609
Epoch 38/100
2/2 [==============================] - 0s - loss: 40468.8555
Epoch 39/100
2/2 [==============================] - 0s - loss: 40395.4023
Epoch 40/100
2/2 [==============================] - 0s - loss: 40304.5508
Epoch 41/100
2/2 [==============================] - 0s - loss: 40206.1797
Epoch 42/100
2/2 [==============================] - 0s - loss: 40129.2031
Epoch 43/100
2/2 [==============================] - 0s - loss: 40084.5312
Epoch 44/100
2/2 [==============================] - 0s - loss: 40056.1484
Epoch 45/100
2/2 [==============================] - 0s - loss: 40028.7031
Epoch 46/100
2/2 [==============================] - 0s - loss: 39992.0859
Epoch 47/100
2/2 [==============================] - 0s - loss: 39949.6055
Epoch 48/100
2/2 [==============================] - 0s - loss: 39910.0977
Epoch 49/100
2/2 [==============================] - 0s - loss: 39880.2891
Epoch 50/100
2/2 [==============================] - 0s - loss: 39861.0547
Epoch 51/100
2/2 [==============================] - 0s - loss: 39847.1367
Epoch 52/100
2/2 [==============================] - 0s - loss: 39832.7266
Epoch 53/100
2/2 [==============================] - 0s - loss: 39815.3906
Epoch 54/100
2/2 [==============================] - 0s - loss: 39795.0391
Epoch 55/100
2/2 [==============================] - 0s - loss: 39774.4844
Epoch 56/100
2/2 [==============================] - 0s - loss: 39756.1836
Epoch 57/100
2/2 [==============================] - 0s - loss: 39741.1094
Epoch 58/100
2/2 [==============================] - 0s - loss: 39726.8242
Epoch 59/100
2/2 [==============================] - 0s - loss: 39711.3398
Epoch 60/100
2/2 [==============================] - 0s - loss: 39693.2500
Epoch 61/100
2/2 [==============================] - 0s - loss: 39672.6133
Epoch 62/100
2/2 [==============================] - 0s - loss: 39650.0625
Epoch 63/100
2/2 [==============================] - 0s - loss: 39626.8828
Epoch 64/100
2/2 [==============================] - 0s - loss: 39603.8516
Epoch 65/100
2/2 [==============================] - 0s - loss: 39581.0312
Epoch 66/100
2/2 [==============================] - 0s - loss: 39557.2812
Epoch 67/100
2/2 [==============================] - 0s - loss: 39531.6328
Epoch 68/100
2/2 [==============================] - 0s - loss: 39503.5469
Epoch 69/100
2/2 [==============================] - 0s - loss: 39473.0156
Epoch 70/100
2/2 [==============================] - 0s - loss: 39440.6016
Epoch 71/100
2/2 [==============================] - 0s - loss: 39407.0703
Epoch 72/100
2/2 [==============================] - 0s - loss: 39372.6016
Epoch 73/100
2/2 [==============================] - 0s - loss: 39337.7109
Epoch 74/100
2/2 [==============================] - 0s - loss: 39302.1562
Epoch 75/100
2/2 [==============================] - 0s - loss: 39266.4922
Epoch 76/100
2/2 [==============================] - 0s - loss: 39230.3906
Epoch 77/100
2/2 [==============================] - 0s - loss: 39193.1172
Epoch 78/100
2/2 [==============================] - 0s - loss: 39154.0977
Epoch 79/100
2/2 [==============================] - 0s - loss: 39112.9805
Epoch 80/100
2/2 [==============================] - 0s - loss: 39070.0508
Epoch 81/100
2/2 [==============================] - 0s - loss: 39024.9766
Epoch 82/100
2/2 [==============================] - 0s - loss: 38977.0781
Epoch 83/100
2/2 [==============================] - 0s - loss: 38926.2891
Epoch 84/100
2/2 [==============================] - 0s - loss: 38872.8398
Epoch 85/100
2/2 [==============================] - 0s - loss: 38816.4609
Epoch 86/100
2/2 [==============================] - 0s - loss: 38755.0078
Epoch 87/100
2/2 [==============================] - 0s - loss: 38686.6523
Epoch 88/100
2/2 [==============================] - 0s - loss: 38608.8555
Epoch 89/100
2/2 [==============================] - 0s - loss: 38518.7812
Epoch 90/100
2/2 [==============================] - 0s - loss: 38414.9609
Epoch 91/100
2/2 [==============================] - 0s - loss: 38300.5156
Epoch 92/100
2/2 [==============================] - 0s - loss: 38178.5312
Epoch 93/100
2/2 [==============================] - 0s - loss: 38051.2617
Epoch 94/100
2/2 [==============================] - 0s - loss: 37914.9375
Epoch 95/100
2/2 [==============================] - 0s - loss: 37762.3711
Epoch 96/100
2/2 [==============================] - 0s - loss: 37616.4688
Epoch 97/100
2/2 [==============================] - 0s - loss: 37469.2852
Epoch 98/100
2/2 [==============================] - 0s - loss: 37289.8828
Epoch 99/100
2/2 [==============================] - 0s - loss: 37071.5977
Epoch 100/100
2/2 [==============================] - 0s - loss: 36819.3398
Out[128]:
<keras.callbacks.History at 0x19a9b9b45c0>

We use learning rate annealing to get a better fit.


In [131]:
K.set_value(m_sr.optimizer.lr, 1e-4)
m_sr.fit([arr_lr, arr_hr], targ, 16, 100)


Epoch 1/100
2/2 [==============================] - 0s - loss: 25916.2148
Epoch 2/100
2/2 [==============================] - 0s - loss: 25896.2539
Epoch 3/100
2/2 [==============================] - 0s - loss: 25876.3164
Epoch 4/100
2/2 [==============================] - 0s - loss: 25856.3223
Epoch 5/100
2/2 [==============================] - 0s - loss: 25836.3242
Epoch 6/100
2/2 [==============================] - 0s - loss: 25816.3438
Epoch 7/100
2/2 [==============================] - 0s - loss: 25796.3477
Epoch 8/100
2/2 [==============================] - 0s - loss: 25776.5156
Epoch 9/100
2/2 [==============================] - 0s - loss: 25756.6367
Epoch 10/100
2/2 [==============================] - 0s - loss: 25736.9297
Epoch 11/100
2/2 [==============================] - 0s - loss: 25717.2422
Epoch 12/100
2/2 [==============================] - 0s - loss: 25697.5898
Epoch 13/100
2/2 [==============================] - 0s - loss: 25677.9727
Epoch 14/100
2/2 [==============================] - 0s - loss: 25658.4336
Epoch 15/100
2/2 [==============================] - 0s - loss: 25638.9844
Epoch 16/100
2/2 [==============================] - 0s - loss: 25619.6016
Epoch 17/100
2/2 [==============================] - 0s - loss: 25600.2637
Epoch 18/100
2/2 [==============================] - 0s - loss: 25581.0273
Epoch 19/100
2/2 [==============================] - 0s - loss: 25561.8477
Epoch 20/100
2/2 [==============================] - 0s - loss: 25542.7070
Epoch 21/100
2/2 [==============================] - 0s - loss: 25523.5742
Epoch 22/100
2/2 [==============================] - 0s - loss: 25504.4844
Epoch 23/100
2/2 [==============================] - 0s - loss: 25485.4180
Epoch 24/100
2/2 [==============================] - 0s - loss: 25466.4238
Epoch 25/100
2/2 [==============================] - 0s - loss: 25447.3867
Epoch 26/100
2/2 [==============================] - 0s - loss: 25428.3242
Epoch 27/100
2/2 [==============================] - 0s - loss: 25409.2695
Epoch 28/100
2/2 [==============================] - 0s - loss: 25390.1719
Epoch 29/100
2/2 [==============================] - 0s - loss: 25371.1016
Epoch 30/100
2/2 [==============================] - 0s - loss: 25351.9922
Epoch 31/100
2/2 [==============================] - 0s - loss: 25332.8867
Epoch 32/100
2/2 [==============================] - 0s - loss: 25313.8711
Epoch 33/100
2/2 [==============================] - 0s - loss: 25294.8496
Epoch 34/100
2/2 [==============================] - 0s - loss: 25275.7539
Epoch 35/100
2/2 [==============================] - 0s - loss: 25256.5625
Epoch 36/100
2/2 [==============================] - 0s - loss: 25237.4609
Epoch 37/100
2/2 [==============================] - 0s - loss: 25218.3086
Epoch 38/100
2/2 [==============================] - 0s - loss: 25199.0781
Epoch 39/100
2/2 [==============================] - 0s - loss: 25179.9180
Epoch 40/100
2/2 [==============================] - 0s - loss: 25160.7031
Epoch 41/100
2/2 [==============================] - 0s - loss: 25141.5254
Epoch 42/100
2/2 [==============================] - 0s - loss: 25122.4199
Epoch 43/100
2/2 [==============================] - 0s - loss: 25103.2969
Epoch 44/100
2/2 [==============================] - 0s - loss: 25084.2539
Epoch 45/100
2/2 [==============================] - 0s - loss: 25065.2324
Epoch 46/100
2/2 [==============================] - 0s - loss: 25046.2773
Epoch 47/100
2/2 [==============================] - 0s - loss: 25027.3164
Epoch 48/100
2/2 [==============================] - 0s - loss: 25008.4414
Epoch 49/100
2/2 [==============================] - 0s - loss: 24989.6328
Epoch 50/100
2/2 [==============================] - 0s - loss: 24970.8906
Epoch 51/100
2/2 [==============================] - 0s - loss: 24952.0938
Epoch 52/100
2/2 [==============================] - 0s - loss: 24933.3066
Epoch 53/100
2/2 [==============================] - 0s - loss: 24914.5684
Epoch 54/100
2/2 [==============================] - 0s - loss: 24895.9180
Epoch 55/100
2/2 [==============================] - 0s - loss: 24877.3848
Epoch 56/100
2/2 [==============================] - 0s - loss: 24858.9004
Epoch 57/100
2/2 [==============================] - 0s - loss: 24840.4629
Epoch 58/100
2/2 [==============================] - 0s - loss: 24822.0703
Epoch 59/100
2/2 [==============================] - 0s - loss: 24803.7793
Epoch 60/100
2/2 [==============================] - 0s - loss: 24785.5742
Epoch 61/100
2/2 [==============================] - 0s - loss: 24767.4922
Epoch 62/100
2/2 [==============================] - 0s - loss: 24749.5352
Epoch 63/100
2/2 [==============================] - 0s - loss: 24731.6348
Epoch 64/100
2/2 [==============================] - 0s - loss: 24714.0156
Epoch 65/100
2/2 [==============================] - 0s - loss: 24696.6953
Epoch 66/100
2/2 [==============================] - 0s - loss: 24679.3125
Epoch 67/100
2/2 [==============================] - 0s - loss: 24661.8008
Epoch 68/100
2/2 [==============================] - 0s - loss: 24644.0938
Epoch 69/100
2/2 [==============================] - 0s - loss: 24626.8633
Epoch 70/100
2/2 [==============================] - 0s - loss: 24610.1836
Epoch 71/100
2/2 [==============================] - 0s - loss: 24593.3281
Epoch 72/100
2/2 [==============================] - 0s - loss: 24576.2793
Epoch 73/100
2/2 [==============================] - 0s - loss: 24559.1055
Epoch 74/100
2/2 [==============================] - 0s - loss: 24542.0391
Epoch 75/100
2/2 [==============================] - 0s - loss: 24525.4141
Epoch 76/100
2/2 [==============================] - 0s - loss: 24509.1191
Epoch 77/100
2/2 [==============================] - 0s - loss: 24493.1523
Epoch 78/100
2/2 [==============================] - 0s - loss: 24477.0664
Epoch 79/100
2/2 [==============================] - 0s - loss: 24460.4883
Epoch 80/100
2/2 [==============================] - 0s - loss: 24443.8809
Epoch 81/100
2/2 [==============================] - 0s - loss: 24427.4688
Epoch 82/100
2/2 [==============================] - 0s - loss: 24411.2695
Epoch 83/100
2/2 [==============================] - 0s - loss: 24395.2227
Epoch 84/100
2/2 [==============================] - 0s - loss: 24379.5352
Epoch 85/100
2/2 [==============================] - 0s - loss: 24364.4473
Epoch 86/100
2/2 [==============================] - 0s - loss: 24349.9727
Epoch 87/100
2/2 [==============================] - 0s - loss: 24335.1074
Epoch 88/100
2/2 [==============================] - 0s - loss: 24319.8828
Epoch 89/100
2/2 [==============================] - 0s - loss: 24302.3145
Epoch 90/100
2/2 [==============================] - 0s - loss: 24285.4785
Epoch 91/100
2/2 [==============================] - 0s - loss: 24270.1758
Epoch 92/100
2/2 [==============================] - 0s - loss: 24255.7773
Epoch 93/100
2/2 [==============================] - 0s - loss: 24241.3633
Epoch 94/100
2/2 [==============================] - 0s - loss: 24225.7383
Epoch 95/100
2/2 [==============================] - 0s - loss: 24209.8828
Epoch 96/100
2/2 [==============================] - 0s - loss: 24193.6797
Epoch 97/100
2/2 [==============================] - 0s - loss: 24179.1289
Epoch 98/100
2/2 [==============================] - 0s - loss: 24165.8379
Epoch 99/100
2/2 [==============================] - 0s - loss: 24151.6504
Epoch 100/100
2/2 [==============================] - 0s - loss: 24136.0508
Out[131]:
<keras.callbacks.History at 0x19a9bcbb0f0>

We are only interested in the trained part of the model, which does the actual upsampling.


In [132]:
top_model = Model(inp, outp)

In [133]:
inp, outp, arr_lr.shape


Out[133]:
(<tf.Tensor 'input_17:0' shape=(?, 72, 72, 3) dtype=float32>,
 <tf.Tensor 'lambda_11/mul:0' shape=(?, 288, 288, 3) dtype=float32>,
 (2, 72, 72, 3))

In [134]:
#p = top_model.predict(arr_lr[10:11])
# - here using the alternative image
p = top_model.predict(arr_lr[0:2])

After training for some time, we get some very impressive results! Looking at these two images, we can see that the predicted higher-resolution image has filled in a lot of detail, with noticeably sharper edges and textures than the low-resolution input.


In [135]:
plt.imshow(arr_lr[1].astype('uint8'));
plt.show()
plt.imshow(p[1].astype('uint8'));
plt.show()
#using layers.upscaling2d



In [117]:
plt.imshow(arr_lr[1].astype('uint8'));
plt.show()
plt.imshow(p[1].astype('uint8'));
plt.show()
#using conv2d transpose



In [77]:
p.shape


Out[77]:
(2, 288, 288, 3)

In [79]:
top_model.save_weights(dpath+'top_final.h5')

The important thing to take away here is that, as opposed to our earlier approaches, this approach results in a model that can create the desired image in a single forward pass, without per-image retraining, making it a scalable implementation.

Note that we haven't used a test set here, so we don't know if the above result is due to over-fitting. As part of your homework, you should create a test set, and try to train a model that gets the best result you can on the test set.

Fast style transfer

The original paper showing the above approach to super resolution also used this approach to create a much faster style transfer system (for a specific style). Take a look at the paper and the very helpful supplementary material. Your mission, should you choose to accept it, is to modify the super resolution example above to do fast style transfer based on this paper.

Reflection padding

The supplementary material mentions that they found reflection padding helpful - we have implemented this as a keras layer for you. All the other layers and blocks are already defined above.

(This is also a nice simple example of a custom layer that you can refer to when creating your own custom layers in the future.)
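
For intuition, a tiny 'REFLECT' padding example (assuming tensorflow is in scope as tf, as used by the layer below): values near the border are mirrored rather than zero-filled, which helps avoid border artefacts in the generated image.

small  = tf.constant([[[[1.], [2.], [3.]]]])                   # shape (1, 1, 3, 1)
padded = tf.pad(small, [[0, 0], [0, 0], [2, 2], [0, 0]], 'REFLECT')
K.eval(padded)[0, 0, :, 0]                                     # -> array([3., 2., 1., 2., 3., 2., 1.])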


In [80]:
# - Original version
#class ReflectionPadding2D(Layer):
#    def __init__(self, padding=(1, 1), **kwargs):
#        self.padding = tuple(padding)
#        self.input_spec = [InputSpec(ndim=4)]
#        super(ReflectionPadding2D, self).__init__(**kwargs)
#        
#    def get_output_shape_for(self, s):
#        return (s[0], s[1] + 2 * self.padding[0], s[2] + 2 * self.padding[1], s[3])
#
#    def call(self, x, mask=None):
#        w_pad,h_pad = self.padding
#        return tf.pad(x, [[0,0], [h_pad,h_pad], [w_pad,w_pad], [0,0] ], 'REFLECT')

In [81]:
# Keras 2 version
class ReflectionPadding2D(Layer):
    def __init__(self, output_dim, padding = (1, 1), **kwargs):
        self.padding = padding
        self.output_dim = output_dim
        super(ReflectionPadding2D, self).__init__(**kwargs)

    def build(self, input_shape):
        super(ReflectionPadding2D, self).build(input_shape)

    def call(self, x):
        w_pad,h_pad = self.padding
        return tf.pad(x, [[0,0], [h_pad,h_pad], [w_pad,w_pad], [0,0] ], 'REFLECT')

    def compute_output_shape(self, input_shape):
        return (input_shape[0], 
                input_shape[1] + 2 * self.padding[0], 
                input_shape[2] + 2 * self.padding[1], 
                input_shape[3])

Testing the reflection padding layer:


In [82]:
inp = Input((288,288,3))
ref_model = Model(inp, ReflectionPadding2D((1,288,288,3), padding=(40,10))(inp))
ref_model.compile('adam', 'mse')

In [83]:
p = ref_model.predict(arr_hr[0:1])
p.shape


Out[83]:
(1, 308, 368, 3)

In [84]:
plt.imshow(arr_hr[0].astype('uint8'));  # - source image



In [85]:
plt.imshow(p[0].astype('uint8'));  # - the "REFLECT" effect is visible close to the borders


Main algorithm

This approach is exactly the same as super resolution, except now the loss includes the style loss.


In [86]:
shp = arr_hr.shape[1:]
shp


Out[86]:
(288, 288, 3)

In [87]:
style = Image.open('data/imagenet/starry_night.jpg')
style = style.resize((shp[0], shp[1]));
#style = style.resize(np.divide(style.size,3.5).astype('int32'))
style = np.array(style)[:shp[0], :shp[1], :shp[2]]
plt.imshow(style);



In [88]:
def res_crop_block(ip, nf=64):
    x = conv_block(ip, nf, 3, (1,1), 'valid')
    x = conv_block(x,  nf, 3, (1,1), 'valid', False)
    ip = Lambda(lambda x: x[:, 2:-2, 2:-2])(ip)
    return add([x, ip])

In [89]:
inp=Input(shp)
x=ReflectionPadding2D(((1,shp[0],shp[1],shp[2])), padding=(40, 40))(inp)
x=conv_block(x, 64, 9, (1,1))
x=conv_block(x, 64, 3)
x=conv_block(x, 64, 3)
for i in range(5): x=res_crop_block(x)
x=up_block(x, 64, 3)
x=up_block(x, 64, 3)
x=Conv2D(3, (9, 9), activation='tanh', padding='same')(x)
outp=Lambda(lambda x: (x+1)*127.5)(x)
outp


Out[89]:
<tf.Tensor 'lambda_9/mul:0' shape=(?, 288, 288, 3) dtype=float32>

In [90]:
vgg_inp=Input(shp)
vgg= VGG16(include_top=False, input_tensor=Lambda(preproc)(vgg_inp))
for l in vgg.layers: l.trainable=False

In [91]:
def get_outp(m, ln): return m.get_layer('block{}_conv2'.format(ln)).output
vgg_content = Model(vgg_inp, [get_outp(vgg, o) for o in [2,3,4,5]])

Here we alter the super resolution approach by adding style outputs.


In [92]:
style_targs = [K.variable(o) for o in
               vgg_content.predict(np.expand_dims(style,0))]

In [93]:
[K.eval(K.shape(o)) for o in style_targs]


Out[93]:
[array([  1, 144, 144, 128], dtype=int32),
 array([  1,  72,  72, 256], dtype=int32),
 array([  1,  36,  36, 512], dtype=int32),
 array([  1,  18,  18, 512], dtype=int32)]

In [94]:
vgg1 = vgg_content(vgg_inp)
vgg2 = vgg_content(outp)
vgg1, vgg2


Out[94]:
([<tf.Tensor 'model_9/block2_conv2/Relu:0' shape=(?, 144, 144, 128) dtype=float32>,
  <tf.Tensor 'model_9/block3_conv2/Relu:0' shape=(?, 72, 72, 256) dtype=float32>,
  <tf.Tensor 'model_9/block4_conv2/Relu:0' shape=(?, 36, 36, 512) dtype=float32>,
  <tf.Tensor 'model_9/block5_conv2/Relu:0' shape=(?, 18, 18, 512) dtype=float32>],
 [<tf.Tensor 'model_9_1/block2_conv2/Relu:0' shape=(?, 144, 144, 128) dtype=float32>,
  <tf.Tensor 'model_9_1/block3_conv2/Relu:0' shape=(?, 72, 72, 256) dtype=float32>,
  <tf.Tensor 'model_9_1/block4_conv2/Relu:0' shape=(?, 36, 36, 512) dtype=float32>,
  <tf.Tensor 'model_9_1/block5_conv2/Relu:0' shape=(?, 18, 18, 512) dtype=float32>])

Our loss now includes the MSE for the content loss and the Gram matrix based loss for the style.


In [95]:
def gram_matrix_b(x):
    x = K.permute_dimensions(x, (0, 3, 1, 2))
    s = K.shape(x)
    feat = K.reshape(x, (s[0], s[1], s[2]*s[3]))
    return K.batch_dot(feat, K.permute_dimensions(feat, (0, 2, 1))
                      ) / K.prod(K.cast(s[1:], K.floatx()))

In [96]:
w=[0.1, 0.2, 0.6, 0.1]
def tot_loss(x):
    loss = 0; n = len(style_targs)
    for i in range(n):
        loss += mean_sqr_b(gram_matrix_b(x[i+n]) - gram_matrix_b(style_targs[i])) / 20.
        loss += mean_sqr_b(x[i]-x[i+n]) * w[i]
    return loss

In [97]:
loss = Lambda(tot_loss)(vgg1+vgg2)

In [98]:
m_style = Model([inp, vgg_inp], loss)
targ = np.zeros((arr_hr.shape[0], 1))

In [99]:
m_style.compile('adam', 'mae')  # the model's output is the loss itself, so we drive its absolute value towards 0
m_style.fit([arr_hr, arr_hr], targ, 8, 100, **parms)



Out[99]:
<keras.callbacks.History at 0x7f602e528b38>

In [100]:
K.set_value(m_style.optimizer.lr, 1e-4)
m_style.fit([arr_hr, arr_hr], targ, 16, 100, **parms)



Out[100]:
<keras.callbacks.History at 0x7f602db9a128>

In [101]:
top_model = Model(inp, outp)

Now we can pass any image through this CNN and it will produce it in the style desired!


In [102]:
p = top_model.predict(arr_hr[0:1])
p.shape


Out[102]:
(1, 288, 288, 3)

In [103]:
plt.imshow(np.round(p[0]).astype('uint8'));



In [104]:
top_model.save_weights(dpath+'style_final.h5')

In [105]:
top_model.load_weights(dpath+'style_final.h5')