In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import scipy.io
import scipy.misc
import tensorflow as tf
import numpy as np
import time
from IPython.display import Image


In [2]:
content_file_name = '2-input.jpg'
style_file_name = '2-style.jpg'
fig = plt.figure()
a=fig.add_subplot(1,2,1)
img_content = mpimg.imread('./images/'+content_file_name)
imgplot = plt.imshow(img_content)
a.set_title('Content')

img_style = mpimg.imread('./images/'+style_file_name)
a=fig.add_subplot(1,2,2)
imgplot = plt.imshow(img_style)
a.set_title('Style')


Out[2]:
<matplotlib.text.Text at 0x1104932d0>

In [3]:
data = scipy.io.loadmat('./imagenet-vgg-verydeep-19.mat')

image_content = scipy.misc.imread('./images/'+content_file_name)
image_content = image_content.astype('float32')
image_content = image_content.reshape((1,) + image_content.shape)  # prepend batch dimension of 1

image_style = scipy.misc.imread('./images/'+style_file_name)
image_style = image_style.astype('float32')
image_style = image_style.reshape((1,) + image_style.shape)  # prepend batch dimension of 1

In [4]:
def _conv_layer(input, weights, bias):
    # convolution (stride 1, SAME padding) followed by a bias add
    conv = tf.nn.conv2d(input, tf.constant(weights), strides=(1, 1, 1, 1),
            padding='SAME')
    return tf.nn.bias_add(conv, bias)

def _pool_layer(input):
    # 2x2 max pooling with stride 2, halving the spatial dimensions
    return tf.nn.max_pool(input, ksize=(1, 2, 2, 1), strides=(1, 2, 2, 1),
            padding='SAME')

def preprocess(image, mean_pixel):
    # VGG expects mean-centered inputs: subtract the per-channel mean pixel
    return (image - mean_pixel).astype('float32')

def unprocess(image, mean_pixel):
    # invert preprocess(): add the mean pixel back for display/saving
    return (image + mean_pixel).astype('float32')
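
A quick sanity check on these helpers (a sketch with made-up values, not part of the pipeline): preprocess() and unprocess() should round-trip within float32 precision.

# hypothetical round-trip check: unprocess(preprocess(x)) == x up to float32 error
dummy = np.random.uniform(0, 255, size=(4, 4, 3)).astype('float32')
dummy_mean = np.array([123.68, 116.779, 103.939], dtype='float32')  # typical VGG mean; the real value comes from the .mat file below
restored = unprocess(preprocess(dummy, dummy_mean), dummy_mean)
print(np.allclose(dummy, restored, atol=1e-3))  # expect: True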

In [5]:
def net(input_image):
    layers = (
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1',

        'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',

        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3',
        'relu3_3', 'conv3_4', 'relu3_4', 'pool3',

        'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3',
        'relu4_3', 'conv4_4', 'relu4_4', 'pool4',

        'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3',
        'relu5_3', 'conv5_4', 'relu5_4'
    )
    weight = data['layers'][0]
    net = {}
    current = input_image
    for i, name in enumerate(layers):
        kind = name[:4]
        if kind == 'conv':
            kernels, bias = weight[i][0][0][0][0]
            # matconvnet: weights are [width, height, in_channels, out_channels]
            # tensorflow: weights are [height, width, in_channels, out_channels]
            kernels = np.transpose(kernels, (1, 0, 2, 3))
            bias = bias.reshape(-1)
            current = _conv_layer(current, kernels, bias)
        elif kind == 'relu':
            current = tf.nn.relu(current)
        elif kind == 'pool':
            current = _pool_layer(current)
        net[name] = current

    assert len(net) == len(layers)
    return net
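
Before relying on net(), it is worth confirming the graph is wired as expected. A minimal sketch, assuming data and image_content from the cells above: build the network once on the content image and print a few activation shapes (each pool layer should halve the spatial dimensions).

# hypothetical sanity check: activations should shrink spatially at each pool
probe_net = net(tf.constant(image_content))
for name in ('conv1_1', 'pool1', 'pool2', 'conv5_4'):
    print(name, probe_net[name].get_shape())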

In [6]:
mean = data['normalization'][0][0][0]
mean_pixel = np.mean(mean, axis=(0, 1))
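
mean_pixel is the per-channel ImageNet mean stored with the matconvnet model; for the standard VGG-19 release it comes out near [123.68, 116.779, 103.939] (R, G, B). Printing it is a cheap check that the .mat file was parsed correctly:

# expect roughly [ 123.68  116.779  103.939 ] for the standard VGG-19 model
print(mean_pixel)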

In [7]:
CONTENT_LAYERS = ('conv1_1', 'conv2_1', 'conv4_1', 'conv4_2')
content_features = {}

with tf.Session() as sess:
    content_pre = preprocess(image_content, mean_pixel)
    content_net = net(content_pre)
    for layer in CONTENT_LAYERS:
        content_features[layer] = content_net[layer].eval()
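
Because content_pre is a constant here, .eval() materializes each target activation as a fixed NumPy array; these targets are computed once and never change during optimization. A quick look at what was captured (sketch):

for layer in CONTENT_LAYERS:
    print(layer, content_features[layer].shape)  # e.g. (1, H', W', channels)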

In [9]:
STYLE_LAYERS = ('conv3_1','conv5_1')
style_features = {}

with tf.Session() as sess:
    style_pre = preprocess(image_style, mean_pixel)
    style_net = net(style_pre)
    for layer in STYLE_LAYERS:
        features = style_net[layer].eval()
        features = np.reshape(features, (-1, features.shape[3]))
        gram = np.matmul(features.T, features) / features.size
        style_features[layer] = gram
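
The Gram matrix G = F^T F / size measures which feature channels co-activate while discarding all spatial layout, which is what makes it a style statistic rather than a content one. A minimal NumPy illustration with a toy activation (hypothetical shapes):

# toy illustration: for a (1, H, W, C) activation, the Gram matrix is (C, C)
# and symmetric; entry (i, j) is the average product of channels i and j
toy = np.random.rand(1, 8, 8, 16).astype('float32')
F = toy.reshape(-1, toy.shape[3])          # (H*W, C)
G = np.matmul(F.T, F) / F.size             # (C, C)
print(G.shape, np.allclose(G, G.T))        # expect: (16, 16) True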

In [10]:
# make the stylized image by backpropagating into the pixels themselves
initial = None
#initial = scipy.misc.imread('./images/cat.jpg')
if initial is None:
    # start from small Gaussian noise (scaled to roughly match preprocessed pixel statistics)
    initial = tf.random_normal(image_content.shape) * 0.256
else:
    # start from a user-supplied image instead of noise
    initial = np.array([preprocess(initial, mean_pixel)])
    initial = initial.astype('float32')

In [11]:
image = tf.Variable(initial)
image_net = net(image)

In [12]:
content_weight = 5e0          # weight on the content term
style_weight = 1e4            # weight on the style term
tv_weight = 1e3               # weight on the total-variation (smoothness) term
learning_rate = 1e0
iterations = 6000
checkpoint_iterations = 200   # save an intermediate image every 200 steps
print_iterations = 100        # log the losses every 100 steps

In [13]:
# content loss
content_loss = 0
content_losses = []
for content_layer in CONTENT_LAYERS:
    content_losses.append(2 * tf.nn.l2_loss(
                          image_net[content_layer] - content_features[content_layer]) / 
                          content_features[content_layer].size)
content_loss += content_weight * reduce(tf.add, content_losses)
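
Since tf.nn.l2_loss(x) computes sum(x**2) / 2, each term above is exactly the mean squared difference between the generated image's activations and the precomputed content targets. A small NumPy check of that identity (made-up values):

# 2 * l2_loss(diff) / diff.size  ==  mean(diff ** 2)
diff = np.random.rand(3, 4).astype('float32')
lhs = 2 * (np.sum(diff ** 2) / 2) / diff.size
rhs = np.mean(diff ** 2)
print(np.allclose(lhs, rhs))  # expect: True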

In [14]:
# style loss
style_loss = 0
style_losses = []
for style_layer in STYLE_LAYERS:
    layer = image_net[style_layer]
    _, height, width, number = map(lambda i: i.value, layer.get_shape())
    size = height * width * number
    feats = tf.reshape(layer, (-1, number))
    gram = tf.matmul(tf.transpose(feats), feats) / size
    style_gram = style_features[style_layer]
    style_losses.append(2 * tf.nn.l2_loss(gram - style_gram) / style_gram.size)
style_loss += style_weight * reduce(tf.add, style_losses)
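
Each style term is the squared Frobenius distance between the generated and target Gram matrices, divided by style_gram.size (i.e. number**2) so that layers with more channels do not dominate. Spelled out in NumPy with toy Gram matrices (hypothetical values):

C = 16
G_gen = np.random.rand(C, C).astype('float32')    # Gram of the generated image
G_sty = np.random.rand(C, C).astype('float32')    # Gram of the style image
term = np.sum((G_gen - G_sty) ** 2) / G_sty.size  # equals 2 * l2_loss / size
print(term)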

In [15]:
def _tensor_size(tensor):
    # product of the tensor's static dimensions (total number of elements)
    from operator import mul
    return reduce(mul, (d.value for d in tensor.get_shape()), 1)

In [16]:
# total variation denoising
tv_y_size = _tensor_size(image[:,1:,:,:])
tv_x_size = _tensor_size(image[:,:,1:,:])
tv_loss = tv_weight * 2 * (
        (tf.nn.l2_loss(image[:,1:,:,:] - image[:,:image_content.shape[1]-1,:,:]) /
            tv_y_size) +
        (tf.nn.l2_loss(image[:,:,1:,:] - image[:,:,:image_content.shape[2]-1,:]) /
            tv_x_size))
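
Total variation regularization penalizes differences between vertically and horizontally adjacent pixels, which suppresses the high-frequency noise that pixel-space optimization tends to produce. A NumPy version of the same quantity on a toy image (hypothetical values; the weight is omitted):

# TV loss on a toy (1, H, W, C) image: mean squared difference of neighbours
img = np.random.rand(1, 5, 5, 3).astype('float32')
dy = img[:, 1:, :, :] - img[:, :-1, :, :]   # vertical neighbour differences
dx = img[:, :, 1:, :] - img[:, :, :-1, :]   # horizontal neighbour differences
tv = np.sum(dy ** 2) / dy.size + np.sum(dx ** 2) / dx.size
print(tv)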

In [17]:
# overall loss
loss = content_loss + style_loss + tv_loss

In [18]:
# optimizer setup
train_step = tf.train.AdamOptimizer(learning_rate).minimize(loss)
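
Nothing in the network itself is trainable: the VGG weights were wrapped in tf.constant, so image is the only tf.Variable in the graph and Adam updates the pixels directly. Making that explicit is equivalent (a sketch, not a change to the cell above):

# equivalent, with the optimized variable spelled out explicitly
train_step = tf.train.AdamOptimizer(learning_rate).minimize(loss, var_list=[image])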

In [19]:
def imsave(path, img):
    # clip to the valid 0-255 range before casting, then write to disk
    img = np.clip(img, 0, 255).astype(np.uint8)
    scipy.misc.imsave(path, img)

In [20]:
# optimization
import os

best_loss = float('inf')
best = None

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    for i in range(iterations):
        train_step.run()

        if i % checkpoint_iterations == 0 or i == iterations - 1:
            this_loss = loss.eval()
            if this_loss < best_loss:
                best_loss = this_loss
                best = image.eval()
            # save a checkpoint image of the best result so far
            checkpoint_dir = './checks/' + content_file_name.split('.')[0]
            try:
                os.makedirs(checkpoint_dir)
            except OSError:
                pass  # directory already exists
            timestr = time.strftime("%Y%m%d_%H%M%S")
            filename_cp = checkpoint_dir + '/' + timestr + '.jpg'
            cp = unprocess(best.reshape(image_content.shape[1:]), mean_pixel)
            imsave(filename_cp, cp)

        if i % print_iterations == 0 or i == iterations - 1:
            print('Iteration %d/%d' % (i + 1, iterations))
            print('  content loss: %g' % content_loss.eval())
            print('    style loss: %g' % style_loss.eval())
            print('       tv loss: %g' % tv_loss.eval())
            print('    total loss: %g' % loss.eval())

    # drop the batch dimension and add the mean pixel back for saving
    output = unprocess(best.reshape(image_content.shape[1:]), mean_pixel)


Iteration 1/6000
  content loss: 1.45456e+07
    style loss: 8.238e+08
       tv loss: 2920.19
    total loss: 8.38349e+08
Iteration 101/6000
  content loss: 1.71363e+07
    style loss: 2.35337e+07
       tv loss: 944722
    total loss: 4.16148e+07
Iteration 201/6000
  content loss: 1.39873e+07
    style loss: 8.80304e+06
       tv loss: 989392
    total loss: 2.37798e+07
Iteration 301/6000
  content loss: 1.21304e+07
    style loss: 5.60181e+06
       tv loss: 977295
    total loss: 1.87095e+07
Iteration 401/6000
  content loss: 1.09504e+07
    style loss: 4.30421e+06
       tv loss: 947938
    total loss: 1.62025e+07
Iteration 501/6000
  content loss: 1.014e+07
    style loss: 3.63718e+06
       tv loss: 914164
    total loss: 1.46914e+07
Iteration 601/6000
  content loss: 9.54293e+06
    style loss: 3.2415e+06
       tv loss: 879743
    total loss: 1.36642e+07
Iteration 701/6000
  content loss: 9.06731e+06
    style loss: 2.98122e+06
       tv loss: 847179
    total loss: 1.28957e+07
Iteration 801/6000
  content loss: 8.69361e+06
    style loss: 2.78917e+06
       tv loss: 817482
    total loss: 1.23003e+07
Iteration 901/6000
  content loss: 8.38562e+06
    style loss: 2.64359e+06
       tv loss: 790488
    total loss: 1.18197e+07
Iteration 1001/6000
  content loss: 8.1257e+06
    style loss: 2.5339e+06
       tv loss: 766354
    total loss: 1.14259e+07
Iteration 1101/6000
  content loss: 7.91101e+06
    style loss: 2.44007e+06
       tv loss: 744842
    total loss: 1.10959e+07
Iteration 1201/6000
  content loss: 7.72257e+06
    style loss: 2.36288e+06
       tv loss: 725272
    total loss: 1.08107e+07
Iteration 1301/6000
  content loss: 7.55557e+06
    style loss: 2.29984e+06
       tv loss: 707536
    total loss: 1.05629e+07
Iteration 1401/6000
  content loss: 7.41258e+06
    style loss: 2.24467e+06
       tv loss: 691483
    total loss: 1.03487e+07
Iteration 1501/6000
  content loss: 7.28559e+06
    style loss: 2.19753e+06
       tv loss: 676879
    total loss: 1.016e+07
Iteration 1601/6000
  content loss: 7.17118e+06
    style loss: 2.15626e+06
       tv loss: 663371
    total loss: 9.99081e+06
Iteration 1701/6000
  content loss: 7.06762e+06
    style loss: 2.12165e+06
       tv loss: 650914
    total loss: 9.84019e+06
Iteration 1801/6000
  content loss: 6.97661e+06
    style loss: 2.08812e+06
       tv loss: 639644
    total loss: 9.70437e+06
Iteration 1901/6000
  content loss: 6.88842e+06
    style loss: 2.06075e+06
       tv loss: 629664
    total loss: 9.57884e+06
Iteration 2001/6000
  content loss: 6.80365e+06
    style loss: 2.04203e+06
       tv loss: 620196
    total loss: 9.46588e+06
Iteration 2101/6000
  content loss: 6.74468e+06
    style loss: 2.0148e+06
       tv loss: 611366
    total loss: 9.37085e+06
Iteration 2201/6000
  content loss: 6.66755e+06
    style loss: 2.00953e+06
       tv loss: 603298
    total loss: 9.28037e+06
Iteration 2301/6000
  content loss: 6.61104e+06
    style loss: 1.99062e+06
       tv loss: 595596
    total loss: 9.19726e+06
Iteration 2401/6000
  content loss: 6.57475e+06
    style loss: 1.96249e+06
       tv loss: 588631
    total loss: 9.12587e+06
Iteration 2501/6000
  content loss: 6.52539e+06
    style loss: 1.94363e+06
       tv loss: 582340
    total loss: 9.05135e+06
Iteration 2601/6000
  content loss: 6.48602e+06
    style loss: 1.93208e+06
       tv loss: 576591
    total loss: 8.9947e+06
Iteration 2701/6000
  content loss: 6.42153e+06
    style loss: 1.943e+06
       tv loss: 571283
    total loss: 8.93582e+06
Iteration 2801/6000
  content loss: 6.38218e+06
    style loss: 1.93778e+06
       tv loss: 566325
    total loss: 8.88629e+06
Iteration 2901/6000
  content loss: 6.36222e+06
    style loss: 1.9044e+06
       tv loss: 561833
    total loss: 8.82845e+06
Iteration 3001/6000
  content loss: 6.33661e+06
    style loss: 1.89787e+06
       tv loss: 558014
    total loss: 8.7925e+06
Iteration 3101/6000
  content loss: 6.30424e+06
    style loss: 1.88487e+06
       tv loss: 554421
    total loss: 8.74353e+06
Iteration 3201/6000
  content loss: 6.2612e+06
    style loss: 1.88327e+06
       tv loss: 551004
    total loss: 8.69547e+06
Iteration 3301/6000
  content loss: 6.24989e+06
    style loss: 1.86532e+06
       tv loss: 547885
    total loss: 8.6631e+06
Iteration 3401/6000
  content loss: 6.21691e+06
    style loss: 1.87775e+06
       tv loss: 544967
    total loss: 8.63962e+06
Iteration 3501/6000
  content loss: 6.19975e+06
    style loss: 1.86389e+06
       tv loss: 542277
    total loss: 8.60591e+06
Iteration 3601/6000
  content loss: 6.18184e+06
    style loss: 1.87553e+06
       tv loss: 540227
    total loss: 8.5976e+06
Iteration 3701/6000
  content loss: 6.16689e+06
    style loss: 1.87111e+06
       tv loss: 537816
    total loss: 8.57582e+06
Iteration 3801/6000
  content loss: 6.12208e+06
    style loss: 1.86431e+06
       tv loss: 536459
    total loss: 8.52285e+06
Iteration 3901/6000
  content loss: 6.10593e+06
    style loss: 1.85141e+06
       tv loss: 535055
    total loss: 8.4924e+06
Iteration 4001/6000
  content loss: 6.11338e+06
    style loss: 1.85987e+06
       tv loss: 534216
    total loss: 8.50746e+06
Iteration 4101/6000
  content loss: 6.07489e+06
    style loss: 1.83417e+06
       tv loss: 532277
    total loss: 8.44134e+06
Iteration 4201/6000
  content loss: 6.08235e+06
    style loss: 1.82851e+06
       tv loss: 530829
    total loss: 8.44169e+06
Iteration 4301/6000
  content loss: 6.04602e+06
    style loss: 1.8395e+06
       tv loss: 530254
    total loss: 8.41578e+06
Iteration 4401/6000
  content loss: 6.02953e+06
    style loss: 1.82771e+06
       tv loss: 528866
    total loss: 8.3861e+06
Iteration 4501/6000
  content loss: 6.01068e+06
    style loss: 1.83505e+06
       tv loss: 527664
    total loss: 8.3734e+06
Iteration 4601/6000
  content loss: 6.01666e+06
    style loss: 1.81668e+06
       tv loss: 526446
    total loss: 8.35978e+06
Iteration 4701/6000
  content loss: 5.99093e+06
    style loss: 1.83917e+06
       tv loss: 526503
    total loss: 8.3566e+06
Iteration 4801/6000
  content loss: 5.99876e+06
    style loss: 1.81112e+06
       tv loss: 525510
    total loss: 8.33539e+06
Iteration 4901/6000
  content loss: 5.97459e+06
    style loss: 1.8081e+06
       tv loss: 524787
    total loss: 8.30747e+06
Iteration 5001/6000
  content loss: 5.95366e+06
    style loss: 1.83287e+06
       tv loss: 523557
    total loss: 8.31008e+06
Iteration 5101/6000
  content loss: 5.94490e+06
    style loss: 1.81389e+06
       tv loss: 522984
    total loss: 8.28178e+06
Iteration 5201/6000
  content loss: 5.95548e+06
    style loss: 1.80234e+06
       tv loss: 521861
    total loss: 8.27969e+06
Iteration 5301/6000
  content loss: 5.92303e+06
    style loss: 1.83088e+06
       tv loss: 521488
    total loss: 8.2754e+06
Iteration 5401/6000
  content loss: 5.93582e+06
    style loss: 1.79992e+06
       tv loss: 521146
    total loss: 8.25688e+06
Iteration 5501/6000
  content loss: 5.9325e+06
    style loss: 1.79246e+06
       tv loss: 520493
    total loss: 8.24546e+06
Iteration 5601/6000
  content loss: 5.89781e+06
    style loss: 1.81681e+06
       tv loss: 519540
    total loss: 8.23416e+06
Iteration 5701/6000
  content loss: 5.92547e+06
    style loss: 1.80645e+06
       tv loss: 518775
    total loss: 8.25069e+06
Iteration 5801/6000
  content loss: 5.92043e+06
    style loss: 1.79625e+06
       tv loss: 518712
    total loss: 8.2354e+06
Iteration 5901/6000
  content loss: 5.88064e+06
    style loss: 1.81199e+06
       tv loss: 518351
    total loss: 8.21098e+06
Iteration 6000/6000
  content loss: 5.87406e+06
    style loss: 1.80675e+06
       tv loss: 517611
    total loss: 8.19842e+06

In [21]:
imsave('./images/output_'+content_file_name, output)

In [22]:
Image(filename = './images/output_'+content_file_name)


Out[22]:
(stylized output image displayed inline)
In [23]:
output_file_name = 'output_'+content_file_name

fig = plt.figure()
a=fig.add_subplot(1,3,1)
img_content = mpimg.imread('./images/'+content_file_name)
imgplot = plt.imshow(img_content)
a.set_title('Content')

img_style = mpimg.imread('./images/'+style_file_name)
a=fig.add_subplot(1,3,2)
imgplot = plt.imshow(img_style)
a.set_title('Style')

img_output = mpimg.imread('./images/'+output_file_name)
a=fig.add_subplot(1,3,3)
imgplot = plt.imshow(img_output)
a.set_title('Output')


Out[23]:
<matplotlib.text.Text at 0x1afee9250>
