Style Transfer Interpolation

Modified version of the https://github.com/fchollet/keras/blob/master/examples/neural_style_transfer.py neural style transfer example. Rather than having one style image, we have two. We first run the optimization for a number of iterations to converge on one of the styles, and then spend the remaining iterations interpolating towards the other.

Interpolation is done by introducing a step variable that controls the relative weight of the two style losses.
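
As a rough, standalone sketch of that schedule (the values mirror the warm_up and iterations constants set further down): step stays at 0 during the warm-up and then ramps linearly to 1, and the two style losses are blended with weights step and 1 - step.


In [ ]:
# Standalone sketch of the interpolation schedule used in the run loop below.
iterations, warm_up = 120, 30
for i in range(iterations):
    if i > warm_up:
        step = float(i - warm_up) / (iterations - warm_up - 1)
    else:
        step = 0.0
    # the combined style loss is then: step * style_loss_1 + (1 - step) * style_loss_2
print(step)  # 1.0 at the last iteration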


In [38]:
from keras.preprocessing.image import load_img, img_to_array
from scipy.misc import imsave
import numpy as np
from scipy.optimize import fmin_l_bfgs_b
import time
import argparse
import os
import imageio
from IPython.display import Image, display, HTML

from keras.applications import vgg16
from keras import backend as K
import requests

Some functions to convert images to and from the internal format used. They make sure we end up with images of the right size and that the color channels are in the position the selected backend expects.


In [25]:
def preprocess_image(image_path):
    img = load_img(image_path, target_size=(img_nrows, img_ncols))
    img = img_to_array(img)
    img = np.expand_dims(img, axis=0)
    img = vgg16.preprocess_input(img)
    return img

# util function to convert a tensor into a valid image


def deprocess_image(x):
    if K.image_data_format() == 'channels_first':
        x = x.reshape((3, img_nrows, img_ncols))
        x = x.transpose((1, 2, 0))
    else:
        x = x.reshape((img_nrows, img_ncols, 3))
    # Remove zero-center by mean pixel
    x[:, :, 0] += 103.939
    x[:, :, 1] += 116.779
    x[:, :, 2] += 123.68
    # 'BGR'->'RGB'
    x = x[:, :, ::-1]
    x = np.clip(x, 0, 255).astype('uint8')
    return x

Fetch the three images we need from the Internet if they aren't already downloaded. The base image is the image we'll render a style-transferred version of; style references 1 and 2 are the style images to interpolate between.


In [26]:
IMAGE_BASE = 'style_transfer/'
os.makedirs(IMAGE_BASE, exist_ok=True)  # make sure the download/output directory exists

def fetch_image(url, fn=None):
    if not fn:
        fn = url.rsplit('/', 1)[-1]
    fn = IMAGE_BASE + fn
    if os.path.isfile(fn):
        return fn
    img = requests.get(url).content
    with open(fn, 'wb') as fout:
        fout.write(img)
    return fn

base_image_path = fetch_image('https://upload.wikimedia.org/wikipedia/commons/0/08/Okerk2.jpg')
style_reference_image_path_2 = fetch_image('https://upload.wikimedia.org/wikipedia/commons/9/99/Jan_van_Goyen_004b.jpg')
style_reference_image_path_1 = fetch_image('https://upload.wikimedia.org/wikipedia/commons/6/66/VanGogh-starry_night_ballance1.jpg')
Image(filename=base_image_path)

Some working variables and constants we are going to need. Change img_nrows to influence the size of the output picture. The various *_weight variables determine how much influence the three loss functions have (see below). iterations is the total number of iterations; the first warm_up of them are used to get to a stable version of the image, the rest are used for the interpolation frames. result_prefix is the prefix used for the intermediate images.


In [16]:
width, height = load_img(base_image_path).size
img_nrows = 400
img_ncols = int(width * img_nrows / height)

total_variation_weight = 1.0
style_weight = 1.0
content_weight = 0.025

iterations = 120
warm_up = 30

result_prefix = 'star_goyen'
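
With the sizes in place, a quick sanity check of the preprocess/deprocess pair defined earlier (just a sketch, not needed for the pipeline): preprocessing subtracts the ImageNet channel means and flips RGB to BGR, and deprocess_image undoes both, so a round trip should give back (approximately) the resized original.


In [ ]:
# Round-trip sanity check of the image conversion helpers.
img = preprocess_image(base_image_path)
print(img.shape)                       # (1, img_nrows, img_ncols, 3) with channels_last
restored = deprocess_image(img.copy())
print(restored.shape, restored.dtype)  # (img_nrows, img_ncols, 3) uint8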

Load the images into three variables and create a placeholder for the resulting combination image. Then create a tensor that contains all four of them next to each other, so we can process them together in the same way we'd process a mini-batch when training.


In [31]:
base_image = K.variable(preprocess_image(base_image_path))
style_reference_image_1 = K.variable(preprocess_image(style_reference_image_path_1))
style_reference_image_2 = K.variable(preprocess_image(style_reference_image_path_2))

# this will contain our generated image
if K.image_data_format() == 'channels_first':
    combination_image = K.placeholder((1, 3, img_nrows, img_ncols))
else:
    combination_image = K.placeholder((1, img_nrows, img_ncols, 3))

# combine the 4 images into a single Keras tensor
input_tensor = K.concatenate([base_image,
                              style_reference_image_1,
                              style_reference_image_2,
                              combination_image], axis=0)
input_tensor


Out[31]:
<tf.Tensor 'concat_1:0' shape=(4, 400, 632, 3) dtype=float32>

Now load the pretrained VGG16 model, using the tensor with the four images as its input. In the batch dimension, index 0 is the base image, indices 1 and 2 are the two style references, and index 3 is the combination image we will optimize.


In [7]:
model = vgg16.VGG16(input_tensor=input_tensor,
                    weights='imagenet', include_top=False)
print('Model loaded.')


Model loaded.

Create a simple helper dictionary to get to the outputs of the layers in the model:


In [16]:
outputs_dict = {layer.name: layer.output for layer in model.layers}
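
As a quick check (the exact list is whatever this VGG16 build exposes), the keys are the layer names that the style and content losses below refer to, e.g. 'block4_conv2' and the various 'block*_conv1' layers:


In [ ]:
# Peek at the available layer names; the loss construction below picks from these.
print(sorted(outputs_dict.keys()))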

The clever bit

Now for the clever bit. We define three loss functions and combine them into one overall loss to optimize. Each of the three tries to control one aspect of the process (a rough sketch of how they are combined follows the list):

  • Style Loss - keep the style of the generated image close to the style images we selected
  • Content Loss - keep the overall image similar to the base image
  • Variation Loss - suppress local variation to keep the image locally coherent
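
Conceptually the three terms are combined into a single weighted sum. The per-term values below are hypothetical, just to show how the *_weight constants enter; the real construction, over selected VGG16 layers, is in the combined-loss cell further down.


In [ ]:
# Illustrative only: hypothetical per-term values to show how the weights combine.
content_term, style_term, variation_term = 2.0e7, 1.5e9, 4.0e8
combined = (content_weight * content_term
            + style_weight * style_term
            + total_variation_weight * variation_term)
print(combined)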

Style loss

We need two functions. gram_matrix computes the Gram matrix of a layer's activations: it flattens the feature maps and takes the outer product, capturing which features tend to activate together. style_loss then sums the squared differences between the Gram matrices of the style image and the combination image, scaled by the number of channels and the size of the feature maps.


In [ ]:
def gram_matrix(x):
    assert K.ndim(x) == 3
    if K.image_data_format() == 'channels_first':
        features = K.batch_flatten(x)
    else:
        features = K.batch_flatten(K.permute_dimensions(x, (2, 0, 1)))
    gram = K.dot(features, K.transpose(features))
    return gram

def style_loss(style, combination):
    assert K.ndim(style) == 3
    assert K.ndim(combination) == 3
    S = gram_matrix(style)
    C = gram_matrix(combination)
    channels = 3
    size = img_nrows * img_ncols
    return K.sum(K.square(S - C)) / (4. * (channels ** 2) * (size ** 2))
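
A quick shape check of gram_matrix on a tiny random tensor (a sketch, assuming the channels_last data format): whatever the spatial size, the result is a channels × channels matrix of feature correlations.


In [ ]:
# For a 4x4 "feature map" with 3 channels the Gram matrix is 3x3.
x = K.constant(np.random.rand(4, 4, 3))
print(K.int_shape(gram_matrix(x)))  # (3, 3)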

Content loss

The content loss is the straightforward sum of squared differences between the base image's features and the combination image's features.


In [ ]:
def content_loss(base, combination):
    return K.sum(K.square(combination - base))

Variation Loss

Minimize local variation by effectively comparing neighboring pixels and penalizing large differences between them. This keeps the resulting image somewhat fuzzy, but it avoids large jumps between adjacent pixels and keeps the image locally coherent.


In [9]:
def total_variation_loss(x):
    assert K.ndim(x) == 4
    if K.image_data_format() == 'channels_first':
        a = K.square(x[:, :, :img_nrows - 1, :img_ncols - 1] - x[:, :, 1:, :img_ncols - 1])
        b = K.square(x[:, :, :img_nrows - 1, :img_ncols - 1] - x[:, :, :img_nrows - 1, 1:])
    else:
        a = K.square(x[:, :img_nrows - 1, :img_ncols - 1, :] - x[:, 1:, :img_ncols - 1, :])
        b = K.square(x[:, :img_nrows - 1, :img_ncols - 1, :] - x[:, :img_nrows - 1, 1:, :])
    return K.sum(K.pow(a + b, 1.25))

The combined loss function

Now combine the three loss functions from above into one that we can optimize for. We use the outputs of selected layers of the network for the style and content losses, while the variation loss is calculated on just the resulting image.

We introduce an extra placeholder variable step that indicates where we are in the transition from one style to the other: for each feature layer we calculate the style loss against both style images and weight the first by step and the second by (1 - step). Finally we take the gradient of the total loss with respect to the combination image.


In [12]:
loss = K.variable(0.)
layer_features = outputs_dict['block4_conv2']
base_image_features = layer_features[0, :, :, :]
combination_features = layer_features[3, :, :, :]
loss += content_weight * content_loss(base_image_features,
                                      combination_features)
step = K.placeholder()
feature_layers = ['block1_conv1', 'block2_conv1',
                  'block3_conv1', 'block4_conv1',
                  'block5_conv1']
for layer_name in feature_layers:
    layer_features = outputs_dict[layer_name]
    style_reference_features_1 = layer_features[1, :, :, :]
    style_reference_features_2 = layer_features[2, :, :, :]
    combination_features = layer_features[3, :, :, :]
    sl_1 = style_loss(style_reference_features_1, combination_features) * step
    sl_2 = style_loss(style_reference_features_2, combination_features) * (1 - step)
    loss += (style_weight / len(feature_layers)) * (sl_1 + sl_2)
loss += total_variation_weight * total_variation_loss(combination_image)

# get the gradients of the generated image wrt the loss
grads = K.gradients(loss, combination_image)

outputs = [loss]
if isinstance(grads, (list, tuple)):
    outputs += grads
else:
    outputs.append(grads)

f_outputs = K.function([combination_image, step], outputs)

Evaluator object

Some extra plumbing to make it possible to use the above in combination with fmin_l_bfgs_b. scipy.optimize requires separate loss and gradient functions, so rather than calculating those twice, we cache the values in an Evaluator object. We use the same object to store how far along we are between the two style images (perc).


In [13]:
def eval_loss_and_grads(x, perc):
    if K.image_data_format() == 'channels_first':
        x = x.reshape((1, 3, img_nrows, img_ncols))
    else:
        x = x.reshape((1, img_nrows, img_ncols, 3))
    outs = f_outputs([x, perc])
    loss_value = outs[0]
    if len(outs[1:]) == 1:
        grad_values = outs[1].flatten().astype('float64')
    else:
        grad_values = np.array(outs[1:]).flatten().astype('float64')
    return loss_value, grad_values

class Evaluator(object):

    def __init__(self):
        self.loss_value = None
        self.grad_values = None
        self.perc = 0

    def loss(self, x):
        assert self.loss_value is None
        loss_value, grad_values = eval_loss_and_grads(x, self.perc)
        self.loss_value = loss_value
        self.grad_values = grad_values
        return self.loss_value

    def grads(self, x):
        assert self.loss_value is not None
        grad_values = np.copy(self.grad_values)
        self.loss_value = None
        self.grad_values = None
        return grad_values

evaluator = Evaluator()
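
To see the calling convention the Evaluator caters to, here is a toy run of fmin_l_bfgs_b on a simple quadratic (a sketch, unrelated to the style transfer itself): SciPy expects separate loss and gradient callbacks and calls them for the same x, which is exactly the duplicated work the cache above avoids.


In [ ]:
# Toy example of the fmin_l_bfgs_b calling convention with separate loss/grad functions.
def toy_loss(v):
    return float(np.sum(v ** 2))

def toy_grads(v):
    return (2 * v).astype('float64')

v_opt, f_opt, info = fmin_l_bfgs_b(toy_loss, np.array([3.0, -4.0]), fprime=toy_grads, maxfun=20)
print(v_opt, f_opt)  # close to [0, 0] and 0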

Run the model

Now run the model. We start from a random image and use the SciPy-based L-BFGS algorithm to optimize its pixels. For the first warm_up iterations we keep step at 0, which puts the full style weight on the second style reference and lets the image settle; after that we increase step linearly towards 1, shifting the weight over to the first style reference. While interpolating, we store the filename of each intermediate image in the frames variable.


In [ ]:
if K.image_data_format() == 'channels_first':
    x = np.random.uniform(0, 255, (1, 3, img_nrows, img_ncols)) - 128.
else:
    x = np.random.uniform(0, 255, (1, img_nrows, img_ncols, 3)) - 128.


frames = []
for i in range(0, iterations):
    start_time = time.time()
    if i > warm_up:
        frames.append(fname)
        evaluator.perc = float(i - warm_up) / (iterations - warm_up - 1)
    else:
        evaluator.perc = 0
    print('Start of iteration', i, evaluator.perc)
    x, min_val, info = fmin_l_bfgs_b(evaluator.loss, x.flatten(),
                                     fprime=evaluator.grads, maxfun=20)
    print('Current loss value:', min_val)
    # save current generated image
    img = deprocess_image(x.copy())
    fname = IMAGE_BASE + result_prefix + '_at_iteration_%d.png' % i
    imsave(fname, img)
    end_time = time.time()
    print('Image saved as', fname)
    print('Iteration %d completed in %ds' % (i, end_time - start_time))


Start of iteration 0 0
Current loss value: 2.77769e+09
Image saved as style_transfer/star_goyen_at_iteration_0.png
Iteration 0 completed in 311s
Start of iteration 1 0
Current loss value: 2.19096e+09
Image saved as style_transfer/star_goyen_at_iteration_1.png
Iteration 1 completed in 291s
Start of iteration 2 0
Current loss value: 1.96724e+09
Image saved as style_transfer/star_goyen_at_iteration_2.png
Iteration 2 completed in 292s
Start of iteration 3 0
Current loss value: 1.84749e+09
Image saved as style_transfer/star_goyen_at_iteration_3.png
Iteration 3 completed in 292s
Start of iteration 4 0
Current loss value: 1.77549e+09
Image saved as style_transfer/star_goyen_at_iteration_4.png
Iteration 4 completed in 292s
Start of iteration 5 0
Current loss value: 1.7281e+09
Image saved as style_transfer/star_goyen_at_iteration_5.png
Iteration 5 completed in 312s
Start of iteration 6 0
Current loss value: 1.69238e+09
Image saved as style_transfer/star_goyen_at_iteration_6.png
Iteration 6 completed in 292s
Start of iteration 7 0
Current loss value: 1.66306e+09
Image saved as style_transfer/star_goyen_at_iteration_7.png
Iteration 7 completed in 293s
Start of iteration 8 0
Current loss value: 1.64026e+09
Image saved as style_transfer/star_goyen_at_iteration_8.png
Iteration 8 completed in 292s
Start of iteration 9 0
Current loss value: 1.62088e+09
Image saved as style_transfer/star_goyen_at_iteration_9.png
Iteration 9 completed in 300s
Start of iteration 10 0
Current loss value: 1.60472e+09
Image saved as style_transfer/star_goyen_at_iteration_10.png
Iteration 10 completed in 301s
Start of iteration 11 0
Current loss value: 1.59089e+09
Image saved as style_transfer/star_goyen_at_iteration_11.png
Iteration 11 completed in 293s
Start of iteration 12 0
Current loss value: 1.5785e+09
Image saved as style_transfer/star_goyen_at_iteration_12.png
Iteration 12 completed in 292s
Start of iteration 13 0
Current loss value: 1.56806e+09
Image saved as style_transfer/star_goyen_at_iteration_13.png
Iteration 13 completed in 292s
Start of iteration 14 0
Current loss value: 1.55861e+09
Image saved as style_transfer/star_goyen_at_iteration_14.png
Iteration 14 completed in 305s
Start of iteration 15 0
Current loss value: 1.55002e+09
Image saved as style_transfer/star_goyen_at_iteration_15.png
Iteration 15 completed in 297s
Start of iteration 16 0
Current loss value: 1.54235e+09
Image saved as style_transfer/star_goyen_at_iteration_16.png
Iteration 16 completed in 293s
Start of iteration 17 0
Current loss value: 1.53529e+09
Image saved as style_transfer/star_goyen_at_iteration_17.png
Iteration 17 completed in 292s
Start of iteration 18 0
Current loss value: 1.52888e+09
Image saved as style_transfer/star_goyen_at_iteration_18.png
Iteration 18 completed in 293s
Start of iteration 19 0
Current loss value: 1.52311e+09
Image saved as style_transfer/star_goyen_at_iteration_19.png
Iteration 19 completed in 306s
Start of iteration 20 0
Current loss value: 1.51797e+09
Image saved as style_transfer/star_goyen_at_iteration_20.png
Iteration 20 completed in 294s
Start of iteration 21 0
Current loss value: 1.51328e+09
Image saved as style_transfer/star_goyen_at_iteration_21.png
Iteration 21 completed in 293s
Start of iteration 22 0
Current loss value: 1.50881e+09
Image saved as style_transfer/star_goyen_at_iteration_22.png
Iteration 22 completed in 292s
Start of iteration 23 0
Current loss value: 1.50453e+09
Image saved as style_transfer/star_goyen_at_iteration_23.png
Iteration 23 completed in 293s
Start of iteration 24 0
Current loss value: 1.50043e+09
Image saved as style_transfer/star_goyen_at_iteration_24.png
Iteration 24 completed in 310s
Start of iteration 25 0
Current loss value: 1.49666e+09
Image saved as style_transfer/star_goyen_at_iteration_25.png
Iteration 25 completed in 292s
Start of iteration 26 0
Current loss value: 1.493e+09
Image saved as style_transfer/star_goyen_at_iteration_26.png
Iteration 26 completed in 293s
Start of iteration 27 0
Current loss value: 1.48948e+09
Image saved as style_transfer/star_goyen_at_iteration_27.png
Iteration 27 completed in 293s
Start of iteration 28 0
Current loss value: 1.48604e+09
Image saved as style_transfer/star_goyen_at_iteration_28.png
Iteration 28 completed in 293s
Start of iteration 29 0
Current loss value: 1.48271e+09
Image saved as style_transfer/star_goyen_at_iteration_29.png
Iteration 29 completed in 307s
Start of iteration 30 0
Current loss value: 1.47949e+09
Image saved as style_transfer/star_goyen_at_iteration_30.png
Iteration 30 completed in 293s
Start of iteration 31 0.011235955056179775
Current loss value: 2.38148e+09
Image saved as style_transfer/star_goyen_at_iteration_31.png
Iteration 31 completed in 293s
Start of iteration 32 0.02247191011235955
Current loss value: 3.26933e+09
Image saved as style_transfer/star_goyen_at_iteration_32.png
Iteration 32 completed in 293s
Start of iteration 33 0.033707865168539325
Current loss value: 4.14178e+09
Image saved as style_transfer/star_goyen_at_iteration_33.png
Iteration 33 completed in 293s
Start of iteration 34 0.0449438202247191
Current loss value: 4.99854e+09
Image saved as style_transfer/star_goyen_at_iteration_34.png
Iteration 34 completed in 309s
Start of iteration 35 0.056179775280898875
Current loss value: 5.83909e+09
Image saved as style_transfer/star_goyen_at_iteration_35.png
Iteration 35 completed in 293s
Start of iteration 36 0.06741573033707865
Current loss value: 6.6628e+09
Image saved as style_transfer/star_goyen_at_iteration_36.png
Iteration 36 completed in 293s
Start of iteration 37 0.07865168539325842
Current loss value: 7.46884e+09
Image saved as style_transfer/star_goyen_at_iteration_37.png
Iteration 37 completed in 293s
Start of iteration 38 0.0898876404494382
Current loss value: 8.25536e+09
Image saved as style_transfer/star_goyen_at_iteration_38.png
Iteration 38 completed in 301s
Start of iteration 39 0.10112359550561797

To display the result as an animated GIF, we build a forward-and-back cycle from the frames variable and then use imageio to write the animation.


In [18]:
cycled = frames + list(reversed(frames[1:-1]))
# save the frames as an animated gif
kwargs = {'duration': 0.1}
imageio.mimsave(IMAGE_BASE + 'animated.gif', [imageio.imread(x) for x in cycled], 'GIF', **kwargs)

HTML('<img src="%s">' % (IMAGE_BASE + 'animated.gif'))
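
For instance (a tiny illustration of the cycling, using dummy frame names), the frames play forward and then back without repeating the end points:


In [ ]:
demo = ['f0', 'f1', 'f2', 'f3']
print(demo + list(reversed(demo[1:-1])))  # ['f0', 'f1', 'f2', 'f3', 'f2', 'f1']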

