Lesson 9 HW: Perceptual Loss for Real-Time Style Transfer

FADL2 Cutting Edge Deep Learning

SEP 08 2017 - WH Nixalo

NOTE: this more or less follows along with the 'Reflection Padding' section under 'Fast Style Transfer' in the neural-style Lesson 9 notebook.

Imports / Setup


In [1]:
%matplotlib inline
import os, sys; sys.path.insert(1, '../utils')
from utils2 import *
from vgg16_avg import VGG16_Avg
from bcolz_array_iterator import BcolzArrayIterator
from tqdm import tqdm

limit_mem()


Using TensorFlow backend.

In [2]:
path = '../data/'
dirpath = path + 'lesson9/style-test/'

Pre- and de-processing for the VGG model:


In [3]:
rn_mean = np.array([123.68, 116.779, 103.939], dtype=np.float32).reshape((1,1,1,3))  # ImageNet channel means (RGB)
preproc = lambda x: (x - rn_mean)[:, :, :, ::-1]    # subtract means, flip RGB -> BGR for VGG
deproc = lambda x,s: np.clip(x.reshape(s)[:, :, :, ::-1] + rn_mean, 0, 255)  # undo both, clip to [0,255]
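
Quick sanity check (a throwaway example, not from the lesson): since deproc reverses both the mean subtraction and the channel flip, a round trip should recover the original pixels:

import numpy as np
img = np.random.randint(0, 256, (1, 8, 8, 3)).astype(np.float32)
assert np.allclose(deproc(preproc(img), img.shape), img)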

Function Definitions:


In [4]:
def conv_block(x, filters, size, stride=(2,2), mode='same', act=True):
    x = Convolution2D(filters, size, size, subsample=stride, border_mode=mode)(x)
    x = BatchNormalization(mode=2)(x)
    return Activation('relu')(x) if act else x

def res_block(ip, nf=64):
    x = conv_block(ip, nf, 3, (1,1))
    x = conv_block(x,  nf, 3, (1,1), act=False)
    return merge([x, ip], mode='sum')

def up_block(x, filters, size):
    x = keras.layers.UpSampling2D()(x)
    x = Convolution2D(filters, size, size, border_mode='same')(x)
    x = BatchNormalization(mode=2)(x)
    return Activation('relu')(x)

#### The Perceptual (Content) Loss Function ####
def get_outp(m, λn): return m.get_layer(f'block{λn}_conv2').output

def mean_sqr_b(diff):
    # RMS error over all non-batch dims, keeping a leading axis for Keras
    dims = list(range(1, K.ndim(diff)))
    return K.expand_dims(K.sqrt(K.mean(diff**2, dims)), 0)

def content_fn(x):
    # x = [target activations per layer] + [generated activations per layer]
    res = 0; n = len(w)
    for i in range(n): res += mean_sqr_b(x[i]-x[i+n]) * w[i]
    return res
#### #### #### #### #### #### #### #### #### ####

class ReflectionPadding2D(Layer):
    def __init__(self, padding=(1,1), **kwargs):
        self.padding = tuple(padding)
        self.input_spec = [InputSpec(ndim=4)]
        super(ReflectionPadding2D, self).__init__(**kwargs)
    def get_output_shape_for(self, s):
        return (s[0], s[1] + 2 * self.padding[0], s[2] + 2 * self.padding[1], s[3])

    def call(self, x, mask=None):
        # 'REFLECT' mirrors the border without repeating the edge pixel,
        # e.g. [1,2,3] padded by 2 on each side becomes [3,2,1,2,3,2,1]
        w_pad, h_pad = self.padding
        return tf.pad(x, [[0,0], [h_pad, h_pad], [w_pad, w_pad], [0,0]], 'REFLECT')
    
def res_crop_block(ip, nf=64):
    # two 'valid' 3x3 convs shrink the feature map by 4px total, so the
    # identity path is cropped by 2px per side to match before summing
    x = conv_block(ip, nf, 3, (1,1), 'valid')
    x = conv_block(x,  nf, 3, (1,1), 'valid', False)
    ip = Lambda(lambda x: x[:, 2:-2, 2:-2])(ip)
    return merge([x, ip], mode='sum')

# there was talk in the lesson about removing the final tanh activation and scaling Lambda; they're kept here
def get_model(shp):
    inp = Input(shp)
    x = ReflectionPadding2D((40, 40))(inp)
    x = conv_block(x, 64, 9, (1,1))
    x = conv_block(x, 64, 3)
    x = conv_block(x, 64, 3)
    for i in range(5): x = res_crop_block(x)
    x = up_block(x, 64, 3)
    x = up_block(x, 64, 3)
    x = Convolution2D(3, 9, 9, activation='tanh', border_mode='same')(x)
    outp = Lambda(lambda x: (x+1)*127.5)(x)
    return inp, outp
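
# shape sanity check for the 40px reflection pad (derived from the blocks above):
# pad 40 -> H+80; two stride-2 conv_blocks -> (H+80)/4; five res_crop_blocks lose
# 4px each -> (H+80)/4 - 20; two 2x up_blocks -> (H+80) - 80 = H. The output ends
# up the same size as the input without using 'same' padding inside the res blocks.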

def gram_matrix_b(x):
    # batched Gram matrix: channel-by-channel feature correlations,
    # normalized by the number of elements per image (c*h*w)
    x = K.permute_dimensions(x, (0, 3, 1, 2))
    s = K.shape(x)
    feat = K.reshape(x, (s[0], s[1], s[2]*s[3]))
    return K.batch_dot(feat, K.permute_dimensions(feat, (0, 2, 1))
                      ) / K.prod(K.cast(s[1:], K.floatx()))

def tot_loss(x):
    # x = [original-image activations (x[:n])] + [generated activations (x[n:])]
    loss = 0; n = len(style_targs)
    for i in range(n):
        loss += mean_sqr_b(gram_matrix_b(x[i+n]) - gram_matrix_b(style_targs[i])) / 2.  # style
        loss += mean_sqr_b(x[i]-x[i+n]) * w[i]                                          # content
    return loss
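
To see what gram_matrix_b computes, here's a minimal NumPy mirror of it (a throwaway sketch; the shapes are illustrative only): flatten each image's feature maps to (channels, h*w), take the batched outer product, and normalize by c*h*w:

import numpy as np
x = np.random.rand(2, 4, 4, 3).astype('float32')        # (batch, h, w, channels)
f = x.transpose(0, 3, 1, 2).reshape(2, 3, -1)           # (batch, channels, h*w)
g = f @ f.transpose(0, 2, 1) / np.prod(f.shape[1:])     # (batch, channels, channels)
assert g.shape == (2, 3, 3)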

Bcolz -- Batched Data Array Iterating:


In [15]:
# Question: what shape should the dummy target be -- one long zeroed array,
# or a batch-sized one?
#     targ = np.zeros((arr_hr.shape[0], 1))   # full-length: only needed for .fit()
#     targ = np.zeros((bs, 1))                # batch-sized: enough for train_on_batch
# Since the loss is computed inside the model's final Lambda layer, the target
# only has to match the batch dimension, so batch-sized works here.

def train(model, bs, niter=10):
    targ = np.zeros((bs, 1))
    bc = BcolzArrayIterator(arr_hr_c4, arr_hr_c4, batch_size=bs)
    for i in tqdm(range(niter)):
        hr1, hr2 = next(bc)   # two copies of the same high-res batch
        model.train_on_batch([hr1[:bs], hr2[:bs]], targ[:len(hr1)])
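
A quick illustration of the slicing above (throwaway numbers, not from the lesson): even if the iterator returns a short final batch, hr1[:bs] and targ[:len(hr1)] stay aligned:

import numpy as np
bs = 4
targ = np.zeros((bs, 1))
hr1 = np.zeros((3, 288, 288, 3))              # e.g. a short last batch of 3
print(hr1[:bs].shape, targ[:len(hr1)].shape)  # (3, 288, 288, 3) (3, 1)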

Model:


In [16]:
# arr_lr_c6 = bcolz.open(path + 'trn_resized_72_c6.bc')
# arr_hr_c6 = bcolz.open(path + 'trn_resized_288_c6.bc')

# re-chunk the high-res array so its chunk length lines up with the batch
# size used above (BcolzArrayIterator reads whole chunks at a time)
arr_hr_c4 = bcolz.carray(bcolz.open(path+'trn_resized_288_c6.bc'),
                         chunklen=4, rootdir=path+'trn_resized_288_c4.bc')
# arr_hr_c4 = bcolz.open(path + 'trn_resized_288_c4.bc')  # reopen on later runs

# pars = {'verbose': 0, 'callbacks': [TQDMNotebookCallback(leave_inner=True)]}

In [17]:
shp = arr_hr_c4[0].shape # same as: arr_hr_c4.shape[1:]

style_img = Image.open(dirpath + 'alena-aenami-eclipse-1k-square.jpg')
# style_img = style_img.resize(np.divide(style_img.size, 3.5).astype('int32'))
style_img = np.array(style_img)[:shp[0], :shp[1], :shp[2]]   # crop to the training shape

inp, outp = get_model(shp)

vgg_inp = Input(shp)
vgg = VGG16(include_top=False, input_tensor=Lambda(preproc)(vgg_inp))
for λ in vgg.layers: λ.trainable = False   # VGG stays a fixed feature extractor

# block{2..5}_conv2 activations serve as both content and style features
vgg_content = Model(vgg_inp, [get_outp(vgg, o) for o in [2,3,4,5]])

# the super-resolution setup is altered by adding style targets: activations of
# the style image, computed once (Gram matrices are taken inside tot_loss)
style_targs = [K.variable(o) for o in 
               vgg_content.predict(np.expand_dims(style_img,0))]

vgg1 = vgg_content(vgg_inp)   # activations of the original image
vgg2 = vgg_content(outp)      # activations of the transform net's output

# per-layer content-loss weighting
w = [0.1, 0.2, 0.6, 0.1]

# vgg1+vgg2 concatenates the two 4-element lists, so tot_loss sees
# x[0:4] = originals and x[4:8] = generated
loss = Lambda(tot_loss)(vgg1+vgg2)
m_style = Model([inp, vgg_inp], loss)
targ = np.zeros((arr_hr_c4.shape[0], 1))   # full-length target, only for the .fit() call below

m_style.compile('adam', 'mse')
# m_style.fit([arr_hr, arr_hr], targ, 8, 2, **pars)

In [18]:
train(m_style, bs=4, niter=(len(arr_hr_c4)//4 + 1))


100%|██████████| 4860/4860 [2:47:30<00:00,  4.28s/it]

Just remembered: since niter covers the whole array once, this counts as a single epoch.


In [19]:
m_style.save_weights(path + 'lesson9/results/' + 'L9hw_m_style.h5')

In [20]:
niter = len(arr_hr_c4)//4 + 1

K.set_value(m_style.optimizer.lr, 1e-4)
train(m_style, bs=4, niter=niter)

m_style.save_weights(path + 'lesson9/results/' + 'L9hw_m_style.h5')


100%|██████████| 4860/4860 [2:46:45<00:00,  1.91s/it]  

Test:


In [6]:
%ls $dirpath


4cf708972aca0466d1edd5731b5096fa--caucasus-central.jpg
4cf708972aca0466d1edd5731b5096fa--caucasus-central-square.jpg
alena-aenami-eclipse-1k.jpg
alena-aenami-eclipse-1k-square.jpg

In [23]:
# style_img = Image.open(dirpath + 'alena-aenami-eclipse-1k-square.jpg')
targt_img = Image.open(dirpath + '4cf708972aca0466d1edd5731b5096fa--caucasus-central-square.jpg')

In [8]:
style_img


Out[8]: [style image displayed]

In [9]:
targt_img


Out[9]: [target image displayed]

In [21]:
top_model = Model(inp, outp)
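
Since top_model shares its layers with m_style, it already holds the trained weights. To reuse it in a fresh session without retraining, a minimal sketch (assuming the same get_model definition and the weights file saved a few cells below):

inp, outp = get_model(shp)
top_model = Model(inp, outp)
top_model.load_weights(path + 'lesson9/results/' + 'L9hw_top_model.h5')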

In [87]:
# p = top_model.predict([targt_img])

# p = top_model.predict(np.expand_dims(np.array(targ_img), 0))

targ_img = targt_img.resize(np.divide(targt_img.size, 1.5).astype('int32'))
targ_img = np.array(targ_img)[30:30+shp[0], 20:20+shp[1], :shp[2]]
# targ_img = np.expand_dims(np.array(targ_img),0) # this & the line below work the same
targ_img = np.expand_dims(targ_img, 0)
p = top_model.predict(targ_img)
plt.imshow(p[0].astype('uint8'))

# plt.imshow(np.round(p[0]).astype('uint8'))


Out[87]:
<matplotlib.image.AxesImage at 0x7fd9b8d1a4e0>

In [43]:
top_model.save_weights(path + 'lesson9/results/' + 'L9hw_top_model.h5')

In [63]:
targ_img.shape


Out[63]:
(1, 288, 288, 3)

In [67]:
targ_img[0].shape


Out[67]:
(288, 288, 3)

In [71]:
targt_img.shape


---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-71-fe0b56ca08eb> in <module>()
----> 1 targt_img.shape

AttributeError: 'JpegImageFile' object has no attribute 'shape'

In [91]:
plt.figure(figsize=(2,10))
# targ_img = np.expand_dims(targ_img, 0)
plt.imshow(p[0].astype('uint8'))


Out[91]:
<matplotlib.image.AxesImage at 0x7fd9b88cdf98>
