Style transfer and super-resolution implementation in PyTorch and Keras.
In [1]:
%matplotlib inline
import importlib
import utils2; importlib.reload(utils2)
from utils2 import *
from vgg16_avg import VGG16_Avg
from keras import metrics
from scipy.optimize import fmin_l_bfgs_b
from scipy.misc import imsave
In [3]:
import torch, torch.nn as nn, torch.nn.functional as F, torch.optim as optim
from torch.autograd import Variable
from torch.utils.serialization import load_lua
from torch.utils.data import DataLoader
from torchvision import transforms, models, datasets
In [4]:
# path = '/data/jhoward/imagenet/sample/'
# dpath = '/data/jhoward/fast/imagenet/sample/'
path = 'data/imagenet/'
dpath = 'data/imagenet/'
In [5]:
fnames = pickle.load(open(dpath+'fnames.pkl', 'rb'))
n = len(fnames); n
Out[5]:
In [6]:
img=Image.open(fnames[50]); img
Out[6]:
In [7]:
rn_mean = np.array([123.68, 116.779, 103.939], dtype=np.float32).reshape((1,1,1,3))
preproc = lambda x: (x - rn_mean)[:, :, :, ::-1]
In [8]:
img_arr = preproc(np.expand_dims(np.array(img), 0))
shp = img_arr.shape
In [9]:
deproc = lambda x,s: np.clip(x.reshape(s)[:, :, :, ::-1] + rn_mean, 0, 255)
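Note that preproc subtracts the ImageNet channel means and flips RGB to BGR, while deproc undoes both (and clips to [0, 255]), so the pair should round-trip any image exactly. A quick illustrative check:
In [ ]:
check = np.random.randint(0, 256, (1, 8, 8, 3)).astype(np.float32)
assert np.allclose(deproc(preproc(check), check.shape), check)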
In [10]:
def download_convert_vgg16_model():
    # Download Justin Johnson's Lua-Torch VGG16 weights and copy them into a
    # PyTorch module. Note: VGGFeature is not part of torchvision; this
    # assumes a local module whose parameter order matches the Lua model.
    model_url='http://cs.stanford.edu/people/jcjohns/fast-neural-style/models/vgg16.t7'
    # Keras's get_file takes (fname, origin), so pass the URL as the origin.
    file = get_file('vgg16.t7', model_url, cache_subdir='models')
    vgglua = load_lua(file).parameters()
    vgg = models.VGGFeature()
    for (src, dst) in zip(vgglua[0], vgg.parameters()): dst[:] = src[:]
    torch.save(vgg.state_dict(), dpath+'vgg16_feature.pth')
In [11]:
url = 'https://s3-us-west-2.amazonaws.com/jcjohns-models/'
fname = 'vgg16-00b39a1b.pth'
file = get_file(fname, url+fname, cache_subdir='models')
In [12]:
vgg = models.vgg.vgg16()
vgg.load_state_dict(torch.load(file))
optimizer = optim.Adam(vgg.parameters())
In [ ]:
vgg.cuda();
In [ ]:
arr_lr = bcolz.open(dpath+'trn_resized_72.bc')[:]
arr_hr = bcolz.open(dpath+'trn_resized_288.bc')[:]
In [ ]:
arr = bcolz.open(dpath+'trn_resized.bc')[:]
In [ ]:
# Smoke test: push one image through the loaded network (assumes arr holds
# float32 images in C x H x W order, and uses the vgg model loaded above).
x = Variable(torch.from_numpy(arr[:1]).cuda())
y = vgg(x)
In [ ]:
url = 'http://www.platform.ai/models/'
fname = 'imagenet_class_index.json'
fpath = get_file(fname, url+fname, cache_subdir='models')
In [ ]:
class ResidualBlock(nn.Module):
def __init__(self, num):
super(ResidualBlock, self).__init__()
self.c1 = nn.Conv2d(num, num, kernel_size=3, stride=1, padding=1)
self.c2 = nn.Conv2d(num, num, kernel_size=3, stride=1, padding=1)
self.b1 = nn.BatchNorm2d(num)
self.b2 = nn.BatchNorm2d(num)
def forward(self, x):
h = F.relu(self.b1(self.c1(x)))
h = self.b2(self.c2(h))
return h + x
In [ ]:
class FastStyleNet(nn.Module):
    def __init__(self):
        super(FastStyleNet, self).__init__()
        # nn.ModuleList (rather than a plain Python list) registers the layers
        # as submodules, so their parameters are visible to the optimizer.
        self.cs = nn.ModuleList([nn.Conv2d(3, 32, kernel_size=9, stride=1, padding=4)
            ,nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)
            ,nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)])
        self.b1s = nn.ModuleList([nn.BatchNorm2d(i) for i in [32, 64, 128]])
        self.rs = nn.ModuleList([ResidualBlock(128) for i in range(5)])
        self.ds = nn.ModuleList([nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)
            ,nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)])
        self.b2s = nn.ModuleList([nn.BatchNorm2d(i) for i in [64, 32]])
        self.d3 = nn.Conv2d(32, 3, kernel_size=9, stride=1, padding=4)
    def forward(self, h):
        for i in range(3): h = F.relu(self.b1s[i](self.cs[i](h)))
        for r in self.rs: h = r(h)
        for i in range(2): h = F.relu(self.b2s[i](self.ds[i](h)))
        return self.d3(h)
In [ ]:
def gram_matrix(y):
(b, ch, h, w) = y.size()
features = y.view(b, ch, w * h)
return features.bmm(features.transpose(1, 2)) / (ch * h * w)
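A quick illustrative check of gram_matrix: with two identical constant channels, every entry is the co-activation summed over the 16 spatial positions, divided by ch*h*w, i.e. 16/32 = 0.5.
In [ ]:
y = Variable(torch.ones(1, 2, 4, 4))  # batch of 1, two constant 4x4 channels
gram_matrix(y)                        # 1x2x2 result, every entry 0.5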
In [ ]:
def vgg_preprocessing(batch):
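    # Subtract the ImageNet channel means in place (BGR order, matching VGG).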
tensortype = type(batch.data)
mean = tensortype(batch.data.size())
mean[:, 0, :, :] = 103.939
mean[:, 1, :, :] = 116.779
mean[:, 2, :, :] = 123.680
batch -= Variable(mean)
In [ ]:
def save_model(model, filename):
state = model.state_dict()
for key in state: state[key] = state[key].clone().cpu()
torch.save(state, filename)
In [ ]:
def tensor_save_rgbimage(tensor, filename):
img = tensor.clone().cpu().clamp(0, 255).numpy()
img = img.transpose(1, 2, 0).astype('uint8')
img = Image.fromarray(img)
img.save(filename)
In [ ]:
def tensor_save_bgrimage(tensor, filename):
(b, g, r) = torch.chunk(tensor, 3)
tensor = torch.cat((r, g, b))
tensor_save_rgbimage(tensor, filename)
In [ ]:
def tensor_load_rgbimage(filename, size=None):
img = Image.open(filename)
if size is not None: img = img.resize((size, size), Image.ANTIALIAS)
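    # PIL returns H x W x C; transpose to the C x H x W layout torch expects.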
img = np.array(img).transpose(2, 0, 1)
img = torch.from_numpy(img).float()
return img
In [ ]:
def batch_rgb_to_bgr(batch):
batch = batch.transpose(0, 1)
(r, g, b) = torch.chunk(batch, 3)
batch = torch.cat((b, g, r))
batch = batch.transpose(0, 1)
return batch
In [ ]:
def batch_bgr_to_rgb(batch):
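    # The B<->R channel swap is its own inverse, so the same helper works both ways.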
return batch_rgb_to_bgr(batch)
In [ ]:
base = K.variable(img_arr)
gen_img = K.placeholder(shp)
batch = K.concatenate([base, gen_img],0)
In [ ]:
model = VGG16_Avg(input_tensor=batch, include_top=False)
In [ ]:
outputs = {l.name: l.output for l in model.layers}
In [ ]:
layer = outputs['block5_conv1']
In [ ]:
class Evaluator(object):
    """Adapter for scipy's fmin_l_bfgs_b, which wants separate loss and
    gradient callbacks: the Keras function computes both in a single pass,
    so the gradients from the loss call are cached and returned by grads()."""
    def __init__(self, f, shp):
        self.f = f
        self.shp = shp
    def loss(self, x):
        loss_, grads_ = self.f([x.reshape(self.shp)])
        self.grad_values = grads_.flatten().astype(np.float64)
        return loss_.astype(np.float64)
    def grads(self, x): return np.copy(self.grad_values)
In [ ]:
content_loss = lambda base, gen: metrics.mse(gen, base)
loss = content_loss(layer[0], layer[1])
grads = K.gradients(loss, gen_img)
fn = K.function([gen_img], [loss]+grads)
In [ ]:
evaluator = Evaluator(fn, shp)
In [ ]:
rand_img = lambda shape: np.random.uniform(-2.5, 2.5, shape)/100
In [ ]:
def solve_image(eval_obj, niter, x):
for i in range(niter):
x, min_val, info = fmin_l_bfgs_b(eval_obj.loss, x.flatten(),
fprime=eval_obj.grads, maxfun=20)
x = np.clip(x, -127,127)
print('Current loss value:', min_val)
imsave('{}res_at_iteration_{}.png'.format(path, i), deproc(x.copy(), shp)[0])
return x
In [ ]:
iterations=10
x = rand_img(shp)
In [ ]:
x = solve_image(evaluator, iterations, x)
Result using conv 1 of the last block (block 5):
In [ ]:
Image.open(path + 'res_at_iteration_9.png')
Result using conv 1 of the 4th block:
In [ ]:
Image.open(path + 'res_at_iteration_9.png')
In [ ]:
def plot_arr(arr): plt.imshow(deproc(arr,arr.shape)[0].astype('uint8'))
In [ ]:
style = Image.open('data/starry_night.jpg')
style = style.resize(np.divide(style.size,3.5).astype('int32')); style
In [ ]:
style = Image.open('data/bird.jpg')
style = style.resize(np.divide(style.size,2.4).astype('int32')); style
In [ ]:
style = Image.open('data/simpsons.jpg')
style = style.resize(np.divide(style.size,2.7).astype('int32')); style
In [ ]:
w,h = style.size
In [ ]:
src = img_arr[:,:h,:w]
shp = src.shape
style_arr = preproc(np.expand_dims(style,0)[:,:,:,:3])
plot_arr(src)
In [ ]:
base = K.variable(style_arr)
gen_img = K.placeholder(shp)
batch = K.concatenate([base, gen_img],0)
In [ ]:
model = VGG16_Avg(input_tensor=batch, include_top=False)
outputs = {l.name: l.output for l in model.layers}
In [ ]:
layers = [outputs['block{}_conv1'.format(o)] for o in range(1,4)]
In [ ]:
def gram_matrix(x):
features = K.batch_flatten(K.permute_dimensions(x, (2, 0, 1)))
return K.dot(features, K.transpose(features)) / x.get_shape().num_elements()
In [ ]:
def style_loss(x, targ):
    return metrics.mse(gram_matrix(x), gram_matrix(targ))
In [ ]:
loss = sum(style_loss(l[0], l[1]) for l in layers)
grads = K.gradients(loss, gen_img)
style_fn = K.function([gen_img], [loss]+grads)
In [ ]:
evaluator = Evaluator(style_fn, shp)
In [ ]:
iterations=10
x = rand_img(shp)
In [ ]:
x = solve_image(evaluator, iterations, x)
In [ ]:
Image.open(path + 'res_at_iteration_9.png')
In [ ]:
Image.open(path + 'res_at_iteration_9.png')
In [ ]:
def total_variation_loss(x, r, c):
assert K.ndim(x) == 3
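    # Squared differences between vertically (a) and horizontally (b)
    # neighbouring pixels, raised to the 1.25 power: a smoothness prior.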
a = K.square(x[:r - 1, :c - 1, :] - x[1:, :c - 1, :])
b = K.square(x[:r - 1, :c - 1, :] - x[:r - 1, 1:, :])
return K.sum(K.pow(a + b, 1.25))
In [ ]:
base = K.variable(src)
style_v = K.variable(style_arr)
gen_img = K.placeholder(shp)
batch = K.concatenate([base, style_v, gen_img],0)
In [ ]:
model = VGG16_Avg(input_tensor=batch, include_top=False)
outputs = {l.name: l.output for l in model.layers}
In [ ]:
style_layers = [outputs['block{}_conv1'.format(o)] for o in range(1,6)]
In [ ]:
content_name = 'block4_conv2'
In [ ]:
content_layer = outputs[content_name]
In [ ]:
input_layer = model.layers[0].output
In [ ]:
loss = sum(style_loss(l[1], l[2]) for l in style_layers)
loss += content_loss(content_layer[0], content_layer[2])/10.
# loss += total_variation_loss(input_layer[2], h, w)/1e9
grads = K.gradients(loss, gen_img)
transfer_fn = K.function([gen_img], [loss]+grads)
In [ ]:
evaluator = Evaluator(transfer_fn, shp)
In [ ]:
iterations=10
x = rand_img(shp)/10.
In [ ]:
x = solve_image(evaluator, iterations, x)
In [ ]:
Image.open(path + 'res_at_iteration_9.png')
In [ ]:
Image.open(path + 'res_at_iteration_9.png')
In [ ]:
Image.open(path + 'res_at_iteration_9.png')
In [ ]:
inp_shape = (72,72,3)
inp = Input(inp_shape)
In [ ]:
class ReflectionPadding2D(Layer):
def __init__(self, padding=(1, 1), **kwargs):
self.padding = tuple(padding)
self.input_spec = [InputSpec(ndim=4)]
super(ReflectionPadding2D, self).__init__(**kwargs)
def get_output_shape_for(self, s):
return (s[0], s[1] + 2 * self.padding[0], s[2] + 2 * self.padding[1], s[3])
def call(self, x, mask=None):
w_pad,h_pad = self.padding
return tf.pad(x, [[0,0], [h_pad,h_pad], [w_pad,w_pad], [0,0] ], 'REFLECT')
In [ ]:
ref_model = Model(inp, ReflectionPadding2D((60,20))(inp))
ref_model.compile('adam', 'mse')
In [ ]:
p = ref_model.predict(arr_lr[50:51])
In [ ]:
plt.imshow(p[0].astype('uint8'))
In [ ]:
def conv_block(x, filters, size, stride=(2,2), mode='same'):
x = Convolution2D(filters, size, size, subsample=stride, border_mode=mode)(x)
x = BatchNormalization(axis=1, mode=2)(x)
return Activation('relu')(x)
In [ ]:
def res_block(ip, nf=64):
x = conv_block(ip, nf, 3, (1,1))
x = Convolution2D(nf, 3, 3, border_mode='same')(x)
x = BatchNormalization(axis=1, mode=2)(x)
# ip = Lambda(lambda x: x[:, 2:-2, 2:-2])(ip)
return merge([x, ip], mode='sum')
In [ ]:
def deconv_block(x, filters, size, shape, stride=(2,2)):
x = Deconvolution2D(filters, size, size, subsample=stride, border_mode='same',
output_shape=(None,)+shape)(x)
x = BatchNormalization(axis=1, mode=2)(x)
return Activation('relu')(x)
In [ ]:
parms = {'verbose': 0, 'callbacks': [TQDMNotebookCallback(leave_inner=True)]}
In [ ]:
inp=Input(inp_shape)
# x=ReflectionPadding2D((40, 40))(inp)
x=conv_block(inp, 64, 9, (1,1))           # 72x72 input, stride 1
# x=conv_block(x, 64, 3)
# x=conv_block(x, 128, 3)
for i in range(4): x=res_block(x)
x=deconv_block(x, 64, 3, (144, 144, 64))  # upsample 72 -> 144
x=deconv_block(x, 64, 3, (288, 288, 64))  # upsample 144 -> 288 (4x total)
x=Convolution2D(3, 9, 9, activation='tanh', border_mode='same')(x)
outp=Lambda(lambda x: (x+1)*127.5)(x)     # map tanh output [-1,1] to [0,255]
In [ ]:
vgg_l = Lambda(preproc)
outp_l = vgg_l(outp)
In [ ]:
out_shape = (288,288,3)
vgg_inp=Input(out_shape)
vgg= VGG16(include_top=False, input_tensor=vgg_l(vgg_inp))
for l in vgg.layers: l.trainable=False
In [ ]:
vgg_content = Model(vgg_inp, vgg.get_layer('block2_conv2').output)
vgg1 = vgg_content(vgg_inp)
vgg2 = vgg_content(outp)
In [ ]:
loss = Lambda(lambda x: K.sqrt(K.mean((x[0]-x[1])**2, (1,2))))([vgg1, vgg2])
m_final = Model([inp, vgg_inp], loss)
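# The Lambda above already outputs the per-channel feature loss, so the
# training target is simply zeros (one per channel of block2_conv2).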
targ = np.zeros((arr_lr.shape[0], 128))
In [ ]:
m_final.compile('adam', 'mse')
In [ ]:
m_final.evaluate([arr_lr[:10],arr_hr[:10]], targ[:10])
In [ ]:
K.set_value(m_final.optimizer.lr, 1e-3)
In [ ]:
m_final.fit([arr_lr, arr_hr], targ, 8, 2, **parms)
In [ ]:
K.set_value(m_final.optimizer.lr, 1e-4)
In [ ]:
m_final.fit([arr_lr, arr_hr], targ, 16, 2, **parms)
In [ ]:
m_final.save_weights(dpath+'m_final.h5')
In [ ]:
top_model = Model(inp, outp)
In [ ]:
top_model.save_weights(dpath+'top_final.h5')
In [ ]:
p = top_model.predict(arr_lr[:20])
In [ ]:
plt.imshow(arr_lr[10].astype('uint8'));
In [ ]:
plt.imshow(p[10].astype('uint8'))
In [ ]:
plt.imshow(arr_hr[0].astype('uint8'));