In [ ]:
import numpy as np
np.random.seed(seed=123)  # fixed global numpy seed; Keras 1.x presumably derives its op seeds from numpy's RNG -- TODO confirm

import os
from keras.models import Model
from keras.layers import Input, Convolution2D, MaxPooling2D, BatchNormalization
from keras.layers import Flatten, Dense, Dropout, ZeroPadding2D, Reshape, UpSampling2D
from keras.layers.local import LocallyConnected1D
from keras.layers.noise import GaussianDropout
from keras.optimizers import SGD
from keras.regularizers import l2
from keras import backend as K
from keras.utils.layer_utils import print_summary

import tensorflow as tf

import cv2
import h5py

import matplotlib.pyplot as plt
%matplotlib inline

In [ ]:
#os.environ["CUDA_VISIBLE_DEVICES"] = "" # uncomment this line to run the code on the CPU

In [ ]:
# Index into the 1000-way VGG output whose score the generator is trained to maximise (470 = candle class).
filter_id = 470
# Length of the generator's input vector; equals 128*5*5 so it can be reshaped to the first (128, 5, 5) feature map.
N = 128 * 5 * 5

In [ ]:
def max_loss(y_true, y_pred):
    """Loss whose minimisation maximises a weighted sum of the model outputs.

    Returns ``1 - sum(y_true * y_pred)`` over the last axis. ``y_true`` acts as a
    per-class weight vector (the training cell gives the target class +1 and
    the other classes a small negative weight) rather than as a one-hot label.
    """
    # Elementwise ``*`` instead of ``tf.mul``: ``tf.mul`` was removed in
    # TensorFlow 1.0 (renamed ``tf.multiply``); ``*`` works on any backend tensor.
    return 1. - K.sum(y_true * y_pred, axis=-1)

def max_metric(y_true, y_pred):
    """Metric counterpart of ``max_loss``: the ``y_true``-weighted sum of ``y_pred``."""
    loss_value = max_loss(y_true, y_pred)
    return 1. - loss_value

def _load_decoder_weights(model, path, enc_size, first_layer):
    """Copy the tail of a saved weight file into consecutive model layers and freeze them.

    File layer ``k`` (for ``k >= enc_size``) is written into
    ``model.layers[k - enc_size + first_layer]``. File layers ``0 .. enc_size-1``
    are skipped (presumably the encoder half of an autoencoder -- TODO confirm).
    """
    # Mode 'r': never create or modify the weight file (h5py < 3 defaulted to 'a',
    # which silently creates an empty file when the path is wrong). The context
    # manager also closes the file if set_weights raises on a shape mismatch.
    with h5py.File(path, 'r') as f:
        for k, layer_name in enumerate(f.attrs['layer_names']):
            if k < enc_size:
                continue
            group = f[layer_name]
            weight_names = [n.decode('utf8') for n in group.attrs['weight_names']]
            weights = [group[weight_name] for weight_name in weight_names]
            layer = model.layers[k - enc_size + first_layer]
            layer.set_weights(weights)
            layer.trainable = False


def _load_vgg_weights(model, path, offset):
    """Load a ``layer_{k}/param_{p}`` HDF5 weight file into ``model.layers[offset:]`` and freeze those layers."""
    # NOTE(review): if this file stores Theano-convention conv kernels while the
    # backend is TensorFlow, the kernels may need flipping (keras convert_kernel) -- confirm.
    with h5py.File(path, 'r') as f:
        for k in range(f.attrs['nb_layers']):
            group = f['layer_{}'.format(k)]
            weights = [group['param_{}'.format(p)] for p in range(group.attrs['nb_params'])]
            layer = model.layers[k + offset]
            layer.set_weights(weights)
            layer.trainable = False


def get_model(decoder_weights_path='decoder_weights.h5', vgg_weights_path='vgg16_weights.h5'):
    """Build and compile the generator + VGG16 model used to synthesise a class-maximising image.

    Generator:
        - The length-``N`` input is reshaped to ``(N, 1)``.
        - It then passes through GaussianDropout and a bias-free LocallyConnected1D,
          which is the only layer left trainable.
        - The result is reshaped to ``(128, 5, 5)`` and upsampled by conv/BatchNorm
          stages to the layer named ``'image'``.
        - All generator layers after the LocallyConnected1D are loaded from
          ``decoder_weights_path`` and frozen.

    Discriminator:
        - A VGG16-shaped classifier applied to that image.
        - Its final 1000-unit Dense layer ``'vgg_class'`` uses a ReLU activation.
        - Its weights are loaded from ``vgg_weights_path`` and frozen.

    Args:
        decoder_weights_path: HDF5 file with the pretrained decoder weights.
        vgg_weights_path: HDF5 file with VGG16 weights in the ``layer_{k}/param_{p}`` layout.

    Returns:
        A compiled Keras model with outputs ``[vgg_class, image]``. Only the first
        output contributes to the loss (``loss_weights=[1., 0.]``).

    h5py raises an error if either weight file is missing or unreadable.
    """
    # generator
    inputs = Input(shape=(N,), name='input')
    
    g0 = Reshape((N,1))(inputs)
    g0 = GaussianDropout(0.05)(g0)
    g1 = LocallyConnected1D(nb_filter=1, filter_length=1,
                            init='one', activation='relu', bias=False,
                            border_mode='valid',W_regularizer=l2(0.1))(g0)
    g2 = Reshape((128,5,5))(g1)
    
    g3 = UpSampling2D(size=(2, 2))(g2) # 10x10
    g3 = Convolution2D(512,2,2,activation='relu',border_mode='valid')(g3) # 9x9
    g3 = BatchNormalization(mode = 0 , axis = 1)(g3)
    g3 = Convolution2D(512,2,2,activation='relu',border_mode='same')(g3) # 9x9
    g3 = BatchNormalization(mode = 0 , axis = 1)(g3)
    
    g4 = UpSampling2D(size=(2, 2))(g3) # 18x18
    g4 = Convolution2D(256,3,3,activation='relu',border_mode='valid')(g4) # 16x16
    g4 = BatchNormalization(mode = 0 , axis = 1)(g4)
    g4 = Convolution2D(256,3,3,activation='relu',border_mode='same')(g4) # 16x16
    g4 = BatchNormalization(mode = 0 , axis = 1)(g4)
    
    g5 = UpSampling2D(size=(2, 2))(g4) # 32x32
    g5 = Convolution2D(256,3,3,activation='relu',border_mode='valid')(g5)# 30x30
    g5 = BatchNormalization(mode = 0 , axis = 1)(g5)
    g5 = Convolution2D(256,3,3,activation='relu',border_mode='same')(g5) # 30x30
    g5 = BatchNormalization(mode = 0 , axis = 1)(g5)
    
    g6 = UpSampling2D(size=(2, 2))(g5) # 60x60
    g6 = Convolution2D(128,3,3,activation='relu',border_mode='valid')(g6) # 58x58
    g6 = BatchNormalization(mode = 0 , axis = 1)(g6)
    g6 = Convolution2D(128,3,3,activation='relu',border_mode='same')(g6) # 58x58
    g6 = BatchNormalization(mode = 0 , axis = 1)(g6)
    
    g7 = UpSampling2D(size=(2, 2))(g6) # 116x116
    g7 = Convolution2D(128,4,4,activation='relu',border_mode='valid')(g7) # 113x113
    g7 = BatchNormalization(mode = 0 , axis = 1)(g7)
    g7 = Convolution2D(128,4,4,activation='relu',border_mode='same')(g7) # 113x113
    g7 = BatchNormalization(mode = 0 , axis = 1)(g7)
    
    g8 = UpSampling2D(size=(2, 2))(g7) # 226x226
    g8 = Convolution2D(64,3,3,activation='relu',border_mode='valid')(g8) # 224x224
    g8 = BatchNormalization(mode = 0 , axis = 1)(g8)
    g8 = Convolution2D(64,3,3,activation='relu',border_mode='same')(g8) # 224x224
    g8 = BatchNormalization(mode = 0 , axis = 1)(g8)
    g8 = Convolution2D(3,3,3,activation='linear',border_mode='same')(g8) # 224x224
    g8 = BatchNormalization(mode = 0, axis = 1, name='image')(g8)
    
    # Throw-away model used only to count the generator layers, so that the
    # discriminator layers start at model.layers[offset] in the combined model.
    temp = Model(input=inputs, output=g8)
    offset = len(temp.layers)
    
    # discriminator  
    vgg1 = ZeroPadding2D((1,1),input_shape=(3,224,224))(g8)
    vgg2 = Convolution2D(64, 3, 3, activation='relu')(vgg1)
    vgg3 = ZeroPadding2D((1,1))(vgg2)
    vgg4 = Convolution2D(64, 3, 3, activation='relu')(vgg3)
    vgg5 = MaxPooling2D((2,2), strides=(2,2))(vgg4)

    vgg6 = ZeroPadding2D((1,1))(vgg5)
    vgg7 = Convolution2D(128, 3, 3, activation='relu')(vgg6)
    vgg8 = ZeroPadding2D((1,1))(vgg7)
    vgg9 = Convolution2D(128, 3, 3, activation='relu')(vgg8)
    vgg10 = MaxPooling2D((2,2), strides=(2,2))(vgg9)

    vgg11 = ZeroPadding2D((1,1))(vgg10)
    vgg12 = Convolution2D(256, 3, 3, activation='relu')(vgg11)
    vgg13 = ZeroPadding2D((1,1))(vgg12)
    vgg14 = Convolution2D(256, 3, 3, activation='relu')(vgg13)
    vgg15 = ZeroPadding2D((1,1))(vgg14)
    vgg16 = Convolution2D(256, 3, 3, activation='relu')(vgg15)
    vgg17 = MaxPooling2D((2,2), strides=(2,2))(vgg16)

    vgg18 = ZeroPadding2D((1,1))(vgg17)
    vgg19 = Convolution2D(512, 3, 3, activation='relu')(vgg18)
    vgg20 = ZeroPadding2D((1,1))(vgg19)
    vgg21 = Convolution2D(512, 3, 3, activation='relu')(vgg20)
    vgg22 = ZeroPadding2D((1,1))(vgg21)
    vgg23 = Convolution2D(512, 3, 3, activation='relu')(vgg22)
    vgg24 = MaxPooling2D((2,2), strides=(2,2))(vgg23)

    vgg25 = ZeroPadding2D((1,1))(vgg24)
    vgg26 = Convolution2D(512, 3, 3, activation='relu')(vgg25)
    vgg27 = ZeroPadding2D((1,1))(vgg26)
    vgg28 = Convolution2D(512, 3, 3, activation='relu')(vgg27)
    vgg29 = ZeroPadding2D((1,1))(vgg28)
    vgg30 = Convolution2D(512, 3, 3, activation='relu')(vgg29)
    vgg31 = MaxPooling2D((2,2), strides=(2,2))(vgg30)

    vgg32 = Flatten()(vgg31)
    vgg33 = Dense(4096, activation='relu')(vgg32)
    vgg34 = Dropout(0.5)(vgg33)
    vgg35 = Dense(4096, activation='relu')(vgg34)
    vgg36 = Dropout(0.5)(vgg35)
    vgg37 = Dense(1000, activation='relu', name='vgg_class')(vgg36)
    
    # create model
    model = Model(input=inputs, output=[vgg37,g8])
    
    # Set and freeze the generator weights. Decoder-file layers from index 30 on
    # map onto model.layers[4:], i.e. everything after the LocallyConnected1D at index 3.
    _load_decoder_weights(model, decoder_weights_path, enc_size=30, first_layer=4)
    
    # Set and freeze the discriminator (VGG) weights, starting right after the generator layers.
    _load_vgg_weights(model, vgg_weights_path, offset)
    
    # The locally connected layer is the only one optimised; the loaders above never
    # touch index 3, but it is set explicitly here to make that intent visible.
    model.layers[3].trainable = True
    
    # compile model (trainable flags must be final before this call)
    sgd = SGD(lr=0.01, decay=0.0, momentum=0.1, nesterov=True)
    model.compile(optimizer=sgd, loss=[max_loss, 'mse'], metrics=['mse'], loss_weights=[1.,0.])

    return model

In [ ]:
# create neural network
model = get_model()
# Print the per-layer table (layer, output shape, parameter count) through the public Model API.
# keras.utils.layer_utils.print_summary's signature is version-dependent.
model.summary()

In [ ]:
def reconstruct_image(im):
    """Add back the per-channel BGR means and return an RGB ``uint8`` image.

    Args:
        im: one channel-first BGR image, e.g. ``(3, H, W)`` or ``(1, 3, H, W)``.
            Singleton axes are squeezed away. The input is not modified.

    Returns:
        ``(H, W, 3)`` ``uint8`` array in RGB order. Values outside [0, 255]
        saturate instead of wrapping around.
    """
    # Float copy: leaves the caller's array untouched and lets integer inputs
    # receive the fractional means (an in-place float add on an int array raises).
    img = np.squeeze(im).astype(np.float64)
    img = img.transpose((1, 2, 0))
    img += np.array([103.939, 116.779, 123.68])  # channels 0, 1, 2 = B, G, R
    # Clip before casting: a bare astype(uint8) wraps e.g. 300 -> 44 and -5 -> 251.
    img = np.clip(img, 0, 255).astype(np.uint8)
    # BGR -> RGB channel swap (same result as cv2.COLOR_BGR2RGB for 3 channels).
    return np.ascontiguousarray(img[:, :, ::-1])

def print_img(model,z=None):
    """Generate an image from ``z``, plot it and return it.

    Args:
        model: model returned by ``get_model``; ``predict`` yields
            ``[class_scores, image]`` for the batch.
        z: generator input of shape ``(batch, N)``. When omitted, a single
            vector is drawn uniformly from [0, 1).

    Returns:
        The first image of the batch as a ``uint8`` array of shape
        ``(height, width, 3)`` in RGB order, min-max rescaled to [0, 255].
        The plot shows it flipped vertically; the returned array is not flipped.
    """
    if z is None:
        z = np.random.uniform(0, 1, size=(1, N))
    out = model.predict(z, batch_size=z.shape[0])

    activ = out[0][0]
    # Float copy so the rescale below never writes into the prediction output.
    img = np.array(out[1][0], dtype=np.float64)

    # Min-max rescale to [0, 255]. The target is 255, the uint8 maximum: a value
    # of 256 would wrap to 0 and turn the brightest pixel black. A constant image
    # has zero range and stays at 0 instead of becoming NaN.
    img -= img.min()
    peak = img.max()
    if peak > 0:
        img *= 255. / peak
    # channel-first BGR -> channel-last RGB
    img = cv2.cvtColor(img.astype('uint8').transpose(1, 2, 0), cv2.COLOR_BGR2RGB)

    plt.figure(figsize=(6, 6))
    plt.imshow(np.flipud(img))
    plt.title('filter activation: ' + str(activ[filter_id]))
    plt.axis('off')
    plt.show()
    return img

preview_img = print_img(model)  # named binding keeps the array and, like any assignment, suppresses the cell repr

In [ ]:
# Training: repeatedly fit the model on a constant input. Only the layers left
# trainable in the model are updated. A preview image is plotted after every epoch.
batch_size = 1
n_samples = 40
n_iterations = 30   # outer iterations: one epoch + one preview image each
n_median = 10       # final image = pixel-wise median of the last n_median previews
vgg_nclasses = 1000

z = np.ones(shape=(n_samples, N))

# Targets are the same for every iteration, so build them once (weights for max_loss).
dummy_labels1 = np.full((n_samples, vgg_nclasses), -10. / vgg_nclasses)  # put a penalty to the other classes
dummy_labels1[:, filter_id] = 1.                                        # give a positive unit weight for the target class
# The image output is compiled with loss weight 0; its target only needs the right shape.
dummy_labels2 = np.zeros(shape=(n_samples, 3, 224, 224))

# print_img returns uint8 RGB images, so store them as uint8.
IMG = np.zeros((n_iterations, 224, 224, 3), dtype=np.uint8)
for k in range(n_iterations):
    out = model.fit(z, [dummy_labels1, dummy_labels2], batch_size=batch_size, nb_epoch=1, verbose=1)
    IMG[k] = print_img(model, z[0:1])

# plotting the median of the last n_median iterations gives a smoother final image
plt.figure()
plt.imshow(np.flipud(np.median(IMG[-n_median:], axis=0).astype('uint8')))
plt.axis('off')
plt.show()

In [ ]: