In [1]:
%matplotlib inline
import os
os.environ['THEANO_FLAGS']='device=gpu0'

import tarfile
import matplotlib
import numpy as np
np.random.seed(123)
import matplotlib.pyplot as plt
import lasagne
import theano
import theano.tensor as T
conv = lasagne.layers.Conv2DLayer
pool = lasagne.layers.MaxPool2DLayer
NUM_EPOCHS = 15
BATCH_SIZE = 256
LEARNING_RATE = 0.001
DIM = 32
NUM_CLASSES = 10

theano.config.optimizer = 'fast_compile'
theano.config.exception_verbosity = 'high'


Using gpu device 0: GeForce GTX 750 Ti (CNMeM is disabled)

Load the CIFAR-10 data (downloaded on first run)


In [2]:
def unpickle(file):
    """Load and return one pickled CIFAR-10 batch file.

    Args:
        file: path to a batch file such as 'data_batch_1' or 'test_batch'.

    Returns:
        The unpickled object (for CIFAR-10 batches, a dict whose 'data' and
        'labels' entries are read by load_data).

    Note:
        Unpickling can execute arbitrary code; only use this on trusted files.
    """
    try:
        import cPickle as pickle  # Python 2: fast C implementation
        load_kwargs = {}
    except ImportError:
        import pickle  # Python 3
        # The CIFAR-10 batches were pickled under Python 2; latin1 decodes
        # their byte strings losslessly.
        load_kwargs = {'encoding': 'latin1'}
    # `with` guarantees the file is closed even if loading raises.
    with open(file, 'rb') as fo:
        return pickle.load(fo, **load_kwargs)

def load_data():
    
    

    if os.path.isdir("cifar1-10-batches-py") == False:
        print 'Downloading CIFAR-10 dataset...'
        !wget -N http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
        print 'Unzipping dataset...'
        t = tarfile.open('cifar-10-python.tar.gz', 'r')
        t.extractall('cifar-10')

    xs = []
    ys = []
    for j in range(5):
      d = unpickle('cifar-10/cifar-10-batches-py/data_batch_'+`j+1`)
      x = d['data']
      y = d['labels']
      xs.append(x)
      ys.append(y)

    d = unpickle('cifar-10/cifar-10-batches-py/test_batch')
    xs.append(d['data'])
    ys.append(d['labels'])

    x = np.concatenate(xs)/np.float32(255)
    y = np.concatenate(ys)

    x = np.dstack((x[:, :1024], x[:, 1024:2048], x[:, 2048:]))
    x = x.reshape((x.shape[0], 3, 32, 32))
    
    X_train=lasagne.utils.floatX(x[0:40000,:])
    X_valid = lasagne.utils.floatX(x[40001:50000,:])
    X_test = lasagne.utils.floatX(x[50001:60000,:])
    
    return dict(
        X_train=X_train,
        y_train=y[0:40000].astype('int32'),
        X_valid=X_valid,
        y_valid = y[40001:50000].astype('int32'),
        X_test=X_test,
        y_test = y[50001:60000].astype('int32'),
        num_examples_train=X_train.shape[0],
        num_examples_valid=X_valid.shape[0],
        num_examples_test=X_test.shape[0],
        input_height=X_train.shape[2],
        input_width=X_train.shape[3],
        output_dim=10,)
    
data = load_data()
print 'Loaded dataset'


Downloading CIFAR-10 dataset...
--2015-11-10 16:11:04--  http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
Resolving www.cs.toronto.edu (www.cs.toronto.edu)... 128.100.3.30
Connecting to www.cs.toronto.edu (www.cs.toronto.edu)|128.100.3.30|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 170498071 (163M) [application/x-gzip]
Server file no newer than local file ‘cifar-10-python.tar.gz’ -- not retrieving.

Unzipping dataset...
Loaded dataset

Check if images are loaded up correctly


In [3]:
#Apply a random color space transformation on an input test image and display it for sanity check.
b = np.zeros((3, 3), dtype='float32')
b[0, 0] = 1
b[1, 1] = 0 
b[2, 2] = 0 

b = np.append(b,b)
    
W = b.flatten()
inputImages = data['X_test'][9997:10000]

num_batch, num_channels, height, width = inputImages.shape
W = T.reshape(W,(-1,3,3))

inputImages = T.reshape(inputImages,(num_batch, height, width, num_channels))
output = T.batched_dot(inputImages, W)


output = T.reshape(output,(num_batch, height, width, num_channels))


img = output[0,:,:,:].eval()
plt.figure(figsize=(8,8))
plt.imshow(img.reshape(DIM, DIM,3), interpolation='none')
plt.title('Random color space transformation - Sanity check that the input data were loaded up correctly')
plt.axis('off')
plt.show()


Create the model: a learned color-space transformation followed by a CNN classifier


In [4]:
theano.config.floatX='float32'
from colortransformationlayer import ColorTransformationLayer

def build_model(input_width, input_height, output_dim,
                batch_size=BATCH_SIZE):
    """Build the color-transformation branch followed by a small CNN.

    A transformation branch (2x2 max-pool -> 50-unit dense -> 9-unit linear
    dense) predicts nine values per image. Together with the input they feed
    ColorTransformationLayer, whose output goes through a
    conv/pool/conv/pool/dense/dropout/softmax classifier.

    Args:
        input_width, input_height: spatial size of the input images.
        output_dim: number of classes (units of the softmax output layer).
        batch_size: unused; kept for backward compatibility.

    Returns:
        (network, color_transformed_image, color_transformation_matrix):
        the softmax output layer, the ColorTransformationLayer, and the
        9-unit layer that predicts the transformation.
    """
    ini = lasagne.init.GlorotUniform()

    # Lasagne expects (batch, channels, rows, columns): height before width.
    inputs = lasagne.layers.InputLayer(shape=(None, 3, input_height, input_width))

    # Color-space transformation branch. The 9-unit layer starts with zero
    # weights and a bias equal to the flattened 3x3 identity, so its initial
    # output is that identity for every image.
    identity = np.eye(3, dtype='float32').flatten()
    tr_pooled = pool(inputs, pool_size=(2, 2))
    tr_hidden = lasagne.layers.DenseLayer(
        tr_pooled, num_units=50, W=lasagne.init.HeUniform('relu'))
    color_transformation_matrix = lasagne.layers.DenseLayer(
        tr_hidden, num_units=9, b=identity, W=lasagne.init.Constant(0.0),
        nonlinearity=lasagne.nonlinearities.identity)

    # External layer from the colortransformationlayer module; presumably
    # applies the predicted 3x3 matrix to each image -- confirm in its source.
    color_transformed_image = ColorTransformationLayer([inputs, color_transformation_matrix])

    # Classification network.
    class_l1 = conv(
        color_transformed_image,
        num_filters=32,
        filter_size=(3, 3),
        nonlinearity=lasagne.nonlinearities.rectify,
        W=ini,
    )
    class_l2 = pool(class_l1, pool_size=(2, 2))

    class_l3 = conv(
        class_l2,
        num_filters=32,
        filter_size=(3, 3),
        nonlinearity=lasagne.nonlinearities.rectify,
        W=ini,
    )
    class_l4 = pool(class_l3, pool_size=(2, 2))

    class_l5 = lasagne.layers.DenseLayer(
        class_l4,
        num_units=256,
        nonlinearity=lasagne.nonlinearities.rectify,
        W=ini,
    )

    # Output layer: one softmax unit per class.
    network = lasagne.layers.DenseLayer(
        lasagne.layers.dropout(class_l5, p=.5),
        num_units=output_dim,
        nonlinearity=lasagne.nonlinearities.softmax)

    return network, color_transformed_image, color_transformation_matrix

model, color_transformed_image, color_transformation_matrix = build_model(DIM, DIM, NUM_CLASSES)
model_params = lasagne.layers.get_all_params(model, trainable=True)

Compile the Theano training and evaluation functions


In [5]:
# Symbolic inputs: a batch of images and their integer class labels.
X = T.tensor4()
y = T.ivector()

# Stochastic (dropout active) network output, used for the training loss.
output_train = lasagne.layers.get_output(model, X, deterministic=False)

# Deterministic outputs for evaluation, plus the transformed image and the
# predicted transformation so they can be plotted later.
eval_layers = [model, color_transformed_image, color_transformation_matrix]
output_eval, color_eval, transformation_eval = lasagne.layers.get_output(
    eval_layers, X, deterministic=True)

# The learning rate is a shared variable so the training loop can decay it.
sh_lr = theano.shared(lasagne.utils.floatX(LEARNING_RATE))
per_example_loss = T.nnet.categorical_crossentropy(output_train, y)
cost = T.mean(per_example_loss)
updates = lasagne.updates.adam(cost, model_params, learning_rate=sh_lr)

compile_kwargs = dict(allow_input_downcast=True)
train = theano.function([X, y], [cost, output_train], updates=updates, **compile_kwargs)
# NOTE: `eval` shadows the Python builtin; the name is kept because later cells call it.
eval = theano.function([X], [output_eval, color_eval, transformation_eval], **compile_kwargs)

In [6]:
def train_epoch(X, y):
    """Run one pass of minibatch training over (X, y).

    Minibatches are BATCH_SIZE consecutive samples taken in stored order
    (no shuffling); the last batch may be smaller.

    Args:
        X: array of inputs, indexed along axis 0.
        y: array of integer labels aligned with X.

    Returns:
        (mean minibatch cost, accuracy over all samples). Accuracy is measured
        on the outputs returned by `train`, i.e. with dropout active.
    """
    num_samples = X.shape[0]
    costs = []
    correct = 0
    for start in range(0, num_samples, BATCH_SIZE):
        # Basic slices are views; indexing with a Python list of indices
        # copied every batch.
        X_batch = X[start:start + BATCH_SIZE]
        y_batch = y[start:start + BATCH_SIZE]
        batch_cost, batch_probs = train(X_batch, y_batch)
        costs.append(batch_cost)
        correct += np.sum(np.argmax(batch_probs, axis=-1) == y_batch)

    return np.mean(costs), correct / float(num_samples)


def eval_epoch(X, y):
    """Evaluate the whole set X in a single call to the compiled `eval`.

    Returns:
        (accuracy of the argmax predictions against y,
         transformed images, predicted transformations) -- the last two are
        passed through unchanged from `eval`.
    """
    class_probs, transformed_images, transforms = eval(X)
    accuracy = np.mean(np.argmax(class_probs, axis=-1) == y)
    return accuracy, transformed_images, transforms

Train the model, reporting validation and test accuracy after each epoch


In [7]:
valid_accs, train_accs, test_accs = [], [], []
try:
    for n in range(NUM_EPOCHS):
        train_cost, train_acc = train_epoch(data['X_train'], data['y_train'])
        valid_acc, valid_transformed_image, valid_color_transformation = eval_epoch(data['X_valid'], data['y_valid'])
        test_acc, test_transformed_image, test_color_transformation = eval_epoch(data['X_test'], data['y_test'])
        valid_accs += [valid_acc]
        test_accs += [test_acc]
        train_accs += [train_acc]



        if (n+1) % 20 == 0:
            new_lr = sh_lr.get_value() * 0.7
            print "New LR:", new_lr
            sh_lr.set_value(lasagne.utils.floatX(new_lr))
          
        print "Epoch {0}: Train cost {1}, Train acc {2}, val acc {3}, test acc {4}".format(
                n, train_cost, train_acc, valid_acc, test_acc)
except KeyboardInterrupt:
    pass


Epoch 0: Train cost 1.76405358315, Train acc 0.35535, val acc 0.447944794479, test acc 0.454645464546
Epoch 1: Train cost 1.45698654652, Train acc 0.477125, val acc 0.489448944894, test acc 0.492649264926
Epoch 2: Train cost 1.32591879368, Train acc 0.53095, val acc 0.552255225523, test acc 0.55095509551
Epoch 3: Train cost 1.23403251171, Train acc 0.56535, val acc 0.578657865787, test acc 0.579757975798
Epoch 4: Train cost 1.15073931217, Train acc 0.59495, val acc 0.621862186219, test acc 0.620062006201
Epoch 5: Train cost 1.10409760475, Train acc 0.610625, val acc 0.6299629963, test acc 0.632363236324
Epoch 6: Train cost 1.05119585991, Train acc 0.632075, val acc 0.644864486449, test acc 0.641664166417
Epoch 7: Train cost 0.991000235081, Train acc 0.6509, val acc 0.653165316532, test acc 0.652465246525
Epoch 8: Train cost 0.956870913506, Train acc 0.66395, val acc 0.659365936594, test acc 0.660566056606
Epoch 9: Train cost 0.913740515709, Train acc 0.675825, val acc 0.664566456646, test acc 0.661666166617
Epoch 10: Train cost 0.892711937428, Train acc 0.685325, val acc 0.667166716672, test acc 0.670267026703
Epoch 11: Train cost 0.847681820393, Train acc 0.700925, val acc 0.674167416742, test acc 0.676767676768
Epoch 12: Train cost 0.848632216454, Train acc 0.699725, val acc 0.669366936694, test acc 0.664566456646
Epoch 13: Train cost 0.8243445158, Train acc 0.7048, val acc 0.675467546755, test acc 0.66896689669
Epoch 14: Train cost 0.798437714577, Train acc 0.716, val acc 0.667566756676, test acc 0.662866286629

Plot training and validation error per epoch


In [8]:
# Error = 1 - accuracy, one point per completed epoch.
fig, ax = plt.subplots(figsize=(9, 9))
ax.plot(1 - np.array(train_accs), label='Training Error')
ax.plot(1 - np.array(valid_accs), label='Validation Error')
ax.legend(fontsize=20)
ax.set_xlabel('Epoch', fontsize=20)
ax.set_ylabel('Error', fontsize=20)
plt.show()


Display test images alongside their color-transformed versions


In [9]:
# Show some test images (top row) above their color-transformed counterparts
# (bottom row) for a few fixed offsets into the test set.
num_images = 10


def show_color_transform_examples(offset, n=num_images):
    """Plot n test images starting at `offset` above their transformed versions.

    Reads data['X_test'] and test_transformed_image, the transformed test set
    returned by eval_epoch in the last completed training epoch.
    """
    plt.figure(figsize=(20, 5))
    for i in range(n):
        original = data['X_test'][offset + i]
        plt.subplot(2, n, i + 1)
        plt.imshow(original.reshape(DIM, DIM, 3), interpolation='none')
        plt.axis('off')

        # Rescale a copy to [0, 1] for display so the stored eval output is
        # not modified in place; skip the division for a constant image.
        transformed = test_transformed_image[offset + i].astype('float32')
        transformed -= transformed.min()
        value_range = transformed.max()
        if value_range > 0:
            transformed /= value_range

        plt.subplot(2, n, n + i + 1)
        # NOTE(review): originals are shown via reshape(DIM, DIM, 3) but the
        # transformed images via transpose(1, 2, 0); the two agree only if
        # ColorTransformationLayer returns a genuinely channel-first array.
        # Confirm the layer's output layout.
        plt.imshow(transformed.transpose(1, 2, 0), interpolation='none')
        plt.axis('off')
    plt.show()


for offset in (550, 1550, 1750, 4100):
    show_color_transform_examples(offset)


Display the predicted color-space transformation matrices


In [10]:
k = 550
for i in range(num_images):
    print test_color_transformation[i+k]


[ 9.9320612  -2.48975015  0.54050916  1.82765472  8.01981735  0.48643482
 -2.06804252 -0.05125558  8.57703018]
[ 9.21115685 -2.2872436   0.53776354  1.6737864   7.47449303  0.48399341
 -1.91919088 -0.0332629   8.00648785]
[ 9.69655895 -2.42810297  0.48228398  1.78268039  7.82260561  0.43640429
 -2.00865245 -0.05101392  8.33857632]
[ 7.82844925 -1.89661467  0.42516413  1.39803803  6.37753344  0.38268745
 -1.58553576 -0.02700713  6.80262566]
[ 8.51116943 -2.08987355  0.4916099   1.53527641  6.91980362  0.44169116
 -1.75117826 -0.03003541  7.40612364]
[ 8.77240944 -2.16049933  0.46481648  1.59330714  7.1096859   0.41810283
 -1.79343259 -0.03872615  7.58387089]
[ 6.47447014 -1.54201698  0.24994679  1.11160684  5.30198669  0.23358606
 -1.28395259 -0.01302335  5.58082199]
[ 7.94951344 -1.93506777  0.4090136   1.42089236  6.46883392  0.36999705
 -1.61265862 -0.03048461  6.88589334]
[ 8.10622692 -1.9722693   0.41850096  1.45667315  6.58969307  0.37711784
 -1.63993299 -0.03098224  7.0145278 ]
[ 11.4046011   -2.90486717   0.61903554   2.13055539   9.16657257
   0.5572809   -2.40567803  -0.06951967   9.81724548]

In [ ]: