In [2]:
from keras.models import Sequential
from keras.callbacks import Callback
from keras.layers import Conv2D, MaxPooling2D, Dropout, Dense, Flatten, UpSampling2D
from keras import backend as K
import random
import glob
import wandb
from wandb.keras import WandbCallback
import subprocess
import os
from PIL import Image
import numpy as np
from matplotlib.pyplot import imshow, figure
In [3]:
# Initialize wandb and download the dataset if a local copy is missing.
hyperparams = {
    "num_epochs": 10,
    "batch_size": 32,
    "height": 96,
    "width": 96,
}
wandb.init(config=hyperparams)
config = wandb.config

val_dir = 'catz/test'
train_dir = 'catz/train'

# Fetch the data automatically when it doesn't exist on disk yet.
if not os.path.exists("catz"):
    print("Downloading catz dataset...")
    subprocess.check_output(
        "curl https://storage.googleapis.com/wandb/catz.tar.gz | tar xz", shell=True)
In [4]:
# generator to loop over train and test images
def my_generator(batch_size, img_dir):
    """Yield (input, output) batches forever.

    input:  (batch, 96, 96, 15) float array — five RGB frames stacked on
            the channel axis, scaled to [0, 1].
    output: (batch, 96, 96, 3) float array — the following frame, [0, 1].
    """
    cat_dirs = glob.glob(img_dir + "/*")
    counter = 0
    while True:
        input_images = np.zeros(
            (batch_size, config.width, config.height, 3 * 5))
        output_images = np.zeros((batch_size, config.width, config.height, 3))
        random.shuffle(cat_dirs)
        # Wrap around (after the reshuffle above) once the split is exhausted.
        if (counter + batch_size >= len(cat_dirs)):
            counter = 0
        for i in range(batch_size):
            # Frames are named cat_0 .. cat_4, so lexical sort gives time order.
            input_paths = sorted(glob.glob(cat_dirs[counter + i] + "/cat_[0-5]*"))
            # Close each PIL image after reading: keeping the lazy file
            # handles open (as the original did) leaks descriptors over a
            # long training run.
            frames = []
            for path in input_paths:
                with Image.open(path) as img:
                    frames.append(np.array(img))
            input_images[i] = np.concatenate(frames, axis=2)
            with Image.open(cat_dirs[counter + i] + "/cat_result.jpg") as result:
                output_images[i] = np.array(result)
            input_images[i] /= 255.
            output_images[i] /= 255.
        yield (input_images, output_images)
        counter += batch_size
# Whole batches available per epoch for the train and validation splits.
def _batches_in(img_dir):
    """Number of full batches of example directories under img_dir."""
    return len(glob.glob(img_dir + "/*")) // config.batch_size

steps_per_epoch = _batches_in(train_dir)
validation_steps = _batches_in(val_dir)
In [5]:
# callback to log the images
class ImageCallback(Callback):
    """Logs a panel of input frames plus (target | prediction) pairs to
    wandb at the end of every epoch."""

    # `logs=None` matches the Keras Callback contract; the original
    # required a positional `logs`, which breaks if Keras omits it.
    def on_epoch_end(self, epoch, logs=None):
        # A fresh generator is built each call; next() pulls one batch of 15.
        validation_X, validation_y = next(
            my_generator(15, val_dir))
        output = self.model.predict(validation_X)
        wandb.log({
            # Split the 15 stacked channels back into 5 RGB frames and tile
            # them horizontally for display.
            "input": [wandb.Image(np.concatenate(np.split(c, 5, axis=2), axis=1)) for c in validation_X],
            # Ground-truth next frame side by side with the prediction.
            "output": [wandb.Image(np.concatenate([validation_y[i], o], axis=1)) for i, o in enumerate(output)]
        }, commit=False)
In [83]:
# Test the generator
gen = my_generator(2, train_dir)
videos, next_frame = next(gen)
# Display both shapes in one tuple: only a cell's final expression is
# rendered, so the original bare `videos[0].shape` line was silently lost.
videos[0].shape, next_frame[0].shape
Out[83]:
In [64]:
# Show every input frame, then the target next frame. The original
# copy-pasted figure/imshow pairs and stopped at channels 9:12, so the
# fifth input frame (channels 12:15) was never displayed.
for start in range(0, 15, 3):
    figure()
    imshow(videos[0][:, :, start:start + 3])
figure()
imshow(next_frame[0][:, :, 0:3])
Out[64]:
In [7]:
# Function for measuring how similar two images are
def perceptual_distance(y_true, y_pred):
    """Red-mean weighted color distance between two images, averaged over
    all pixels (lower is more similar). Inputs are expected in [0, 1] and
    are rescaled to [0, 255] before weighting.
    """
    # Rebind instead of `*=`: augmented assignment mutates the caller's
    # arrays in place when numpy arrays are passed in, corrupting the batch
    # for any later use.
    y_true = y_true * 255.
    y_pred = y_pred * 255.
    rmean = (y_true[:, :, :, 0] + y_pred[:, :, :, 0]) / 2
    r = y_true[:, :, :, 0] - y_pred[:, :, :, 0]
    g = y_true[:, :, :, 1] - y_pred[:, :, :, 1]
    b = y_true[:, :, :, 2] - y_pred[:, :, :, 2]
    # Channel weights shift with the average red level ("redmean" formula).
    return K.mean(K.sqrt((((512 + rmean) * r * r) / 256) + 4 * g * g + (((767 - rmean) * b * b) / 256)))
In [84]:
wandb.init(config=hyperparams)
config = wandb.config

# Single-conv baseline: map the 15 stacked input channels straight to RGB.
model = Sequential([
    Conv2D(3, (3, 3), activation='relu', padding='same',
           input_shape=(config.height, config.width, 5 * 3)),
])
model.compile(optimizer='adam', loss='mse', metrics=[perceptual_distance])
model.fit_generator(
    my_generator(config.batch_size, train_dir),
    steps_per_epoch=steps_per_epoch // 4,
    epochs=config.num_epochs,
    callbacks=[ImageCallback(), WandbCallback()],
    validation_steps=validation_steps // 4,
    validation_data=my_generator(config.batch_size, val_dir))
Out[84]:
In [85]:
# Baseline model - just return the last input frame
from keras.layers import Lambda, Reshape, Permute

def last_frame(x):
    """Select the final frame along the trailing (time) axis of a
    (batch, 96, 96, 3, 5) tensor, yielding (batch, 96, 96, 3)."""
    # Renamed from `slice` to stop shadowing the Python built-in.
    return x[:, :, :, :, -1]

wandb.init(config=hyperparams)
config = wandb.config
model = Sequential()
# (h, w, 15) -> (h, w, 5 frames, 3 channels) -> (h, w, 3, 5): time axis last.
model.add(Reshape((96, 96, 5, 3), input_shape=(config.height, config.width, 5 * 3)))
model.add(Permute((1, 2, 4, 3)))
model.add(Lambda(last_frame, input_shape=(96, 96, 3, 5), output_shape=(96, 96, 3)))
model.compile(optimizer='adam', loss='mse', metrics=[perceptual_distance])
model.fit_generator(my_generator(config.batch_size, train_dir),
                    steps_per_epoch=steps_per_epoch // 4,
                    epochs=config.num_epochs, callbacks=[
                        ImageCallback(), WandbCallback()],
                    validation_steps=validation_steps // 4,
                    validation_data=my_generator(config.batch_size, val_dir))
In [86]:
# Just return the last input frame, functional style
from keras.layers import Lambda, Reshape, Permute, Input
from keras.models import Model

def last_frame(x):
    """Select the final frame from a (batch, 96, 96, 3, 5) tensor."""
    # Renamed from `slice` to stop shadowing the Python built-in.
    return x[:, :, :, :, -1]

wandb.init(config=hyperparams)
config = wandb.config
inp = Input((config.height, config.width, 5 * 3))
# (h, w, 15) -> (h, w, 5 frames, 3 channels) -> (h, w, 3, 5): time axis last.
reshaped = Reshape((96, 96, 5, 3))(inp)
permuted = Permute((1, 2, 4, 3))(reshaped)
last_layer = Lambda(last_frame, input_shape=(96, 96, 3, 5), output_shape=(96, 96, 3))(permuted)
model = Model(inputs=[inp], outputs=[last_layer])
model.compile(optimizer='adam', loss='mse', metrics=[perceptual_distance])
model.fit_generator(my_generator(config.batch_size, train_dir),
                    steps_per_epoch=steps_per_epoch // 4,
                    epochs=config.num_epochs, callbacks=[
                        ImageCallback(), WandbCallback()],
                    validation_steps=validation_steps // 4,
                    validation_data=my_generator(config.batch_size, val_dir))
In [90]:
# Conv3D: residual correction added on top of the last input frame
from keras.layers import Lambda, Reshape, Permute, Input, add, Conv3D
from keras.models import Model

def last_frame(x):
    """Select the final frame from a (batch, 96, 96, 3, 5) tensor."""
    # Renamed from `slice` to stop shadowing the Python built-in.
    return x[:, :, :, :, -1]

# NOTE: this mutates the shared hyperparams dict, so every later cell that
# re-inits from `hyperparams` will also train for 100 epochs.
hyperparams["num_epochs"] = 100
wandb.init(config=hyperparams)
config = wandb.config
inp = Input((config.height, config.width, 5 * 3))
# (h, w, 15) -> (h, w, 5 frames, 3 channels) -> (h, w, 3, 5): time axis last.
reshaped = Reshape((96, 96, 5, 3))(inp)
permuted = Permute((1, 2, 4, 3))(reshaped)
last_layer = Lambda(last_frame, input_shape=(96, 96, 3, 5), output_shape=(96, 96, 3))(permuted)
# Conv3D over (h, w, color, time) predicts a single-channel correction,
# reshaped back to RGB and added to the last frame.
conv_output = Conv3D(1, (3, 3, 3), padding="same")(permuted)
conv_output_reshape = Reshape((96, 96, 3))(conv_output)
combined = add([last_layer, conv_output_reshape])
model = Model(inputs=[inp], outputs=[combined])
model.compile(optimizer='adam', loss='mse', metrics=[perceptual_distance])
model.fit_generator(my_generator(config.batch_size, train_dir),
                    steps_per_epoch=steps_per_epoch // 4,
                    epochs=config.num_epochs, callbacks=[
                        ImageCallback(), WandbCallback()],
                    validation_steps=validation_steps // 4,
                    validation_data=my_generator(config.batch_size, val_dir))
In [ ]:
# Conv3D with Gaussian Noise
from keras.layers import Lambda, Reshape, Permute, Input, add, Conv3D, GaussianNoise
from keras.models import Model

def last_frame(x):
    """Select the final frame from a (batch, 96, 96, 3, 5) tensor."""
    # Renamed from `slice` to stop shadowing the Python built-in.
    return x[:, :, :, :, -1]

# Pass the config explicitly and refresh `config`, matching every other
# training cell; the original called wandb.init() bare and silently reused
# the previous cell's config object.
wandb.init(config=hyperparams)
config = wandb.config
inp = Input((config.height, config.width, 5 * 3))
# (h, w, 15) -> (h, w, 5 frames, 3 channels) -> (h, w, 3, 5): time axis last.
reshaped = Reshape((96, 96, 5, 3))(inp)
permuted = Permute((1, 2, 4, 3))(reshaped)
# Light Gaussian noise on the inputs as regularization (train-time only).
noise = GaussianNoise(0.1)(permuted)
last_layer = Lambda(last_frame, input_shape=(96, 96, 3, 5), output_shape=(96, 96, 3))(noise)
conv_output = Conv3D(1, (3, 3, 3), padding="same")(noise)
conv_output_reshape = Reshape((96, 96, 3))(conv_output)
combined = add([last_layer, conv_output_reshape])
model = Model(inputs=[inp], outputs=[combined])
model.compile(optimizer='adam', loss='mse', metrics=[perceptual_distance])
model.fit_generator(my_generator(config.batch_size, train_dir),
                    steps_per_epoch=steps_per_epoch // 4,
                    epochs=config.num_epochs, callbacks=[
                        ImageCallback(), WandbCallback()],
                    validation_steps=validation_steps // 4,
                    validation_data=my_generator(config.batch_size, val_dir))
In [10]:
# Conv2DLSTM with Gaussian Noise
from keras.layers import Lambda, Reshape, Permute, Input, add, Conv3D, GaussianNoise, ConvLSTM2D
from keras.models import Model

def last_frame(x):
    """Select the final frame from a (batch, 96, 96, 3, 5) tensor."""
    # Renamed from `slice` to stop shadowing the Python built-in.
    return x[:, :, :, :, -1]

wandb.init(config=hyperparams)
config = wandb.config
inp = Input((config.height, config.width, 5 * 3))
# (h, w, 15) -> (h, w, 5 frames, 3 channels) -> (h, w, 3, 5): time axis last.
reshaped = Reshape((96, 96, 5, 3))(inp)
permuted = Permute((1, 2, 4, 3))(reshaped)
noise = GaussianNoise(0.1)(permuted)
last_layer = Lambda(last_frame, input_shape=(96, 96, 3, 5), output_shape=(96, 96, 3))(noise)
# (h, w, 3, 5) -> (5, h, w, 3): time-major layout expected by ConvLSTM2D.
permuted_2 = Permute((4, 1, 2, 3))(noise)
conv_lstm_output_1 = ConvLSTM2D(6, (3, 3), padding='same')(permuted_2)
# Project the 6 ConvLSTM channels down to an RGB correction and add it to
# the last input frame.
conv_output = Conv2D(3, (3, 3), padding="same")(conv_lstm_output_1)
combined = add([last_layer, conv_output])
model = Model(inputs=[inp], outputs=[combined])
model.compile(optimizer='adam', loss='mse', metrics=[perceptual_distance])
model.fit_generator(my_generator(config.batch_size, train_dir),
                    steps_per_epoch=steps_per_epoch // 4,
                    epochs=config.num_epochs, callbacks=[
                        ImageCallback(), WandbCallback()],
                    validation_steps=validation_steps // 4,
                    validation_data=my_generator(config.batch_size, val_dir))
In [11]:
# Conv2DLSTM with Gaussian Noise — deeper encoder/decoder stack with skip
# connections, whose output is concatenated with the last input frame and
# projected back to RGB by a 1x1 convolution.
from keras.layers import Lambda, Reshape, Permute, Input, add, Conv3D, GaussianNoise, concatenate
from keras.layers import ConvLSTM2D, BatchNormalization, TimeDistributed, Add
from keras.models import Model
def slice(x):
    # Select the final frame along the trailing (time) axis.
    # NOTE(review): shadows the built-in `slice`; consider renaming.
    return x[:,:,:,:, -1]
wandb.init(config=hyperparams)
config = wandb.config
c=4  # base filter count; deeper stages use 2*c and 4*c
inp = Input((config.height, config.width, 5 * 3))
# (h, w, 15) -> (h, w, 5 frames, 3 channels)
reshaped = Reshape((96,96,5,3))(inp)
# -> (h, w, 3, 5): time axis last, so `slice` can grab the final frame
permuted = Permute((1,2,4,3))(reshaped)
noise = GaussianNoise(0.1)(permuted)
# Skip path: the (noisy) last input frame, reused at the output.
last_layer = Lambda(slice, input_shape=(96,96,3,5), output_shape=(96,96,3))(noise)
# -> (5, h, w, 3): time-major layout expected by ConvLSTM2D.
x = Permute((4,1,2,3))(noise)
# --- Encoder ---
x =(ConvLSTM2D(filters=c, kernel_size=(3,3),padding='same',name='conv_lstm1', return_sequences=True))(x)
c1=(BatchNormalization())(x)  # saved for the downsampling input below
# NOTE(review): this Dropout output is never used — `x` is reassigned from
# `c1` on the next line, so the layer is disconnected from the graph.
x = Dropout(0.2)(x)
x =(TimeDistributed(MaxPooling2D(pool_size=(2,2))))(c1)
x =(ConvLSTM2D(filters=2*c,kernel_size=(3,3),padding='same',name='conv_lstm3',return_sequences=True))(x)
c2=(BatchNormalization())(x)  # saved for the skip connection below
# NOTE(review): dead Dropout again — `x` is reassigned from `c2` next line.
x = Dropout(0.2)(x)
x =(TimeDistributed(MaxPooling2D(pool_size=(2,2))))(c2)
# --- Bottleneck ---
x =(ConvLSTM2D(filters=4*c,kernel_size=(3,3),padding='same',name='conv_lstm4',return_sequences=True))(x)
# --- Decoder ---
x =(TimeDistributed(UpSampling2D(size=(2, 2))))(x)
x =(ConvLSTM2D(filters=4*c,kernel_size=(3,3),padding='same',name='conv_lstm5',return_sequences=True))(x)
x =(BatchNormalization())(x)
x =(ConvLSTM2D(filters=2*c,kernel_size=(3,3),padding='same',name='conv_lstm6',return_sequences=True))(x)
x =(BatchNormalization())(x)
# Residual skip connection from the encoder at matching resolution.
x = Add()([c2, x])
x = Dropout(0.2)(x)
x =(TimeDistributed(UpSampling2D(size=(2, 2))))(x)
# return_sequences=False collapses the time axis to a single (h, w, c) map.
x =(ConvLSTM2D(filters=c,kernel_size=(3,3),padding='same',name='conv_lstm7',return_sequences=False))(x)
x =(BatchNormalization())(x)
# Fuse learned features with the last input frame, then 1x1-conv to RGB.
combined = concatenate([last_layer, x])
combined = Conv2D(3, (1,1))(combined)
model=Model(inputs=[inp], outputs=[combined])
model.compile(optimizer='adam', loss='mse', metrics=[perceptual_distance])
model.fit_generator(my_generator(config.batch_size, train_dir),
                    steps_per_epoch=steps_per_epoch//4,
                    epochs=config.num_epochs, callbacks=[
                        ImageCallback(), WandbCallback()],
                    validation_steps=validation_steps//4,
                    validation_data=my_generator(config.batch_size, val_dir))
In [ ]: