Ch 07: Concept 02

Autoencoder with images

Import the Autoencoder class we wrote earlier:


In [1]:
%matplotlib inline
from matplotlib import pyplot as plt
import pickle
import numpy as np
from autoencoder import Autoencoder
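
The rest of the notebook relies on the interface of that class. As a rough sketch inferred from how it is used below (the actual implementation from the earlier concept may differ in detail):

# Assumed interface of the Autoencoder class (a sketch, not the real code):
#   ae = Autoencoder(input_dim, hidden_dim)   # build the encoder/decoder
#   ae.train(data)                    # fit on an (N, input_dim) array, logging loss
#   codes = ae.classify(data, labels) # encode data and print per-class loss
#   img = ae.decode(code)             # map a hidden code back to a 32x32 image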

Define some helper functions to load and preprocess the data:


In [2]:
def unpickle(file):
    # Each CIFAR-10 batch file is a pickled dict; 'latin1' decodes the
    # Python 2 pickles that the dataset ships with.
    with open(file, 'rb') as fo:
        return pickle.load(fo, encoding='latin1')

def grayscale(a):
    # Each row holds a 3 x 32 x 32 color image; averaging over the channel
    # axis yields a flattened 1,024-pixel grayscale image per row.
    return a.reshape(a.shape[0], 3, 32, 32).mean(1).reshape(a.shape[0], -1)
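
As a quick sanity check, each 3,072-value row collapses to 1,024 grayscale values. A minimal sketch with fake data:

fake_batch = np.arange(2 * 3072).reshape(2, 3072)  # two fake CIFAR-10 rows
print(grayscale(fake_batch).shape)  # (2, 1024)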

Download the Python version of the CIFAR-10 dataset from https://www.cs.toronto.edu/~kriz/cifar.html and extract it next to this notebook. Then we can load the training batches with the following code:
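
If you prefer to fetch the archive programmatically, here is a minimal sketch (this assumes the cifar-10-python.tar.gz archive linked from that page, and extracts into the current directory):

import tarfile
import urllib.request

# Download and extract CIFAR-10; this creates ./cifar-10-batches-py/
url = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
urllib.request.urlretrieve(url, 'cifar-10-python.tar.gz')
with tarfile.open('cifar-10-python.tar.gz', 'r:gz') as tar:
    tar.extractall('.')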


In [3]:
names = unpickle('./cifar-10-batches-py/batches.meta')['label_names']
data, labels = [], []
# The training set is split across five batch files; stack them into one array
for i in range(1, 6):
    filename = './cifar-10-batches-py/data_batch_' + str(i)
    batch_data = unpickle(filename)
    if len(data) > 0:
        data = np.vstack((data, batch_data['data']))
        labels = np.hstack((labels, batch_data['labels']))
    else:
        data = batch_data['data']
        labels = batch_data['labels']

data = grayscale(data)
x = np.array(data)
y = np.array(labels)

Train the autoencoder on images of horses:


In [4]:
horse_indices = np.where(y == 7)[0]  # label 7 is 'horse' in CIFAR-10
horse_x = x[horse_indices]
print(np.shape(horse_x))  # (5000, 1024)

print('Some examples of horse images we will feed to the autoencoder for training')
plt.rcParams['figure.figsize'] = (10, 10)
num_examples = 5
for i in range(num_examples):
    horse_img = np.reshape(horse_x[i, :], (32, 32))
    plt.subplot(1, num_examples, i+1)
    plt.imshow(horse_img, cmap='Greys_r')
plt.show()


(5000, 1024)
Some examples of horse images we will feed to the autoencoder for training
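
Now build an autoencoder whose hidden layer squeezes each 1,024-pixel image into a 100-dimensional code, and train it on the horse images: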

In [5]:
input_dim = np.shape(horse_x)[1]  # 1,024 grayscale pixels per image
hidden_dim = 100                  # size of the compressed representation
ae = Autoencoder(input_dim, hidden_dim)
ae.train(horse_x)


epoch 0: loss = 141.4900665283203
epoch 10: loss = 72.48475646972656
epoch 20: loss = 60.29756164550781
epoch 30: loss = 62.54150390625
epoch 40: loss = 58.15482711791992
epoch 50: loss = 56.428043365478516
epoch 60: loss = 57.38883590698242
epoch 70: loss = 58.01189041137695
epoch 80: loss = 55.45024871826172
epoch 90: loss = 53.27738571166992
epoch 100: loss = 55.66946029663086
epoch 110: loss = 52.93635177612305
epoch 120: loss = 52.29307174682617
epoch 130: loss = 51.190956115722656
epoch 140: loss = 52.70167922973633
epoch 150: loss = 51.69100570678711
epoch 160: loss = 52.5902099609375
epoch 170: loss = 54.53725814819336
epoch 180: loss = 50.20492935180664
epoch 190: loss = 50.705711364746094
epoch 200: loss = 52.63679504394531
epoch 210: loss = 47.62162780761719
epoch 220: loss = 51.37104415893555
epoch 230: loss = 50.87934494018555
epoch 240: loss = 47.789520263671875
epoch 250: loss = 48.64930725097656
epoch 260: loss = 49.09609603881836
epoch 270: loss = 52.35578155517578
epoch 280: loss = 49.07854461669922
epoch 290: loss = 50.17341613769531
epoch 300: loss = 50.5720100402832
epoch 310: loss = 51.39663314819336
epoch 320: loss = 48.94121551513672
epoch 330: loss = 48.08163070678711
epoch 340: loss = 51.91111373901367
epoch 350: loss = 48.66571807861328
epoch 360: loss = 53.95515823364258
epoch 370: loss = 51.0589599609375
epoch 380: loss = 47.25322341918945
epoch 390: loss = 48.82767868041992
epoch 400: loss = 48.54008102416992
epoch 410: loss = 48.49171447753906
epoch 420: loss = 49.44685363769531
epoch 430: loss = 48.99972915649414
epoch 440: loss = 48.93858337402344
epoch 450: loss = 47.747283935546875
epoch 460: loss = 49.11021423339844
epoch 470: loss = 52.6614990234375
epoch 480: loss = 47.462528228759766
epoch 490: loss = 48.2142219543457
epoch 500: loss = 47.22665786743164
epoch 510: loss = 46.21189498901367
epoch 520: loss = 48.50703430175781
epoch 530: loss = 46.67418670654297
epoch 540: loss = 49.2231330871582
epoch 550: loss = 46.520503997802734
epoch 560: loss = 51.34899139404297
epoch 570: loss = 45.424476623535156
epoch 580: loss = 50.18787384033203
epoch 590: loss = 46.64382553100586
epoch 600: loss = 48.735843658447266
epoch 610: loss = 48.83089065551758
epoch 620: loss = 47.54549789428711
epoch 630: loss = 50.897132873535156
epoch 640: loss = 50.95079803466797
epoch 650: loss = 47.783199310302734
epoch 660: loss = 51.59523391723633
epoch 670: loss = 48.8479118347168
epoch 680: loss = 50.1485595703125
epoch 690: loss = 46.98124313354492
epoch 700: loss = 47.87333297729492
epoch 710: loss = 49.56991195678711
epoch 720: loss = 49.013526916503906
epoch 730: loss = 48.30108642578125
epoch 740: loss = 49.60373306274414
epoch 750: loss = 49.387760162353516
epoch 760: loss = 49.210697174072266
epoch 770: loss = 50.2661247253418
epoch 780: loss = 48.906288146972656
epoch 790: loss = 49.0813102722168
epoch 800: loss = 49.76530456542969
epoch 810: loss = 47.713905334472656
epoch 820: loss = 47.74589920043945
epoch 830: loss = 49.45218276977539
epoch 840: loss = 47.087501525878906
epoch 850: loss = 49.574161529541016
epoch 860: loss = 48.69647979736328
epoch 870: loss = 50.04840850830078
epoch 880: loss = 47.72985076904297
epoch 890: loss = 47.26101303100586
epoch 900: loss = 48.874332427978516
epoch 910: loss = 47.47987747192383
epoch 920: loss = 50.03631591796875
epoch 930: loss = 45.66521072387695
epoch 940: loss = 49.02825164794922
epoch 950: loss = 47.84540939331055
epoch 960: loss = 48.757850646972656
epoch 970: loss = 46.702640533447266
epoch 980: loss = 49.98514175415039
epoch 990: loss = 48.62955856323242

Test the trained autoencoder on the held-out test batch and compare reconstruction losses for horses versus everything else:


In [6]:
test_data = unpickle('./cifar-10-batches-py/test_batch')
test_x = grayscale(test_data['data'])
test_labels = np.array(test_data['labels'])
encodings = ae.classify(test_x, test_labels)  # encodes the images and prints per-class loss


data (10000, 1024)
reconstructed (1024,)
loss (10000,)
horse 67.4191074286
not horse 65.5469002694
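
Since the autoencoder was trained only on horses, the hope is that horse images reconstruct with lower loss than everything else; here the two averages are close, so reconstruction loss alone makes a weak horse detector. As a minimal sketch of the loss-based comparison, assuming a hypothetical reconstruct(images) helper that returns the autoencoder's outputs (the real classify method may compute its loss differently):

def classify_by_loss(reconstruct, images, labels, target_label=7):
    # Per-image reconstruction loss: root-mean-square error over the pixels
    reconstructed = reconstruct(images)  # shape (N, 1024)
    loss = np.sqrt(np.mean((images - reconstructed) ** 2, axis=1))
    print('horse', np.mean(loss[labels == target_label]))
    print('not horse', np.mean(loss[labels != target_label]))

Finally, plot each test image next to its reconstruction from the 100-dimensional code: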

In [7]:
plt.rcParams['figure.figsize'] = (100, 100)
plt.figure()
for i in range(20):
    plt.subplot(20, 2, i*2 + 1)
    original_img = np.reshape(test_x[i, :], (32, 32))
    plt.imshow(original_img, cmap='Greys_r')
    
    plt.subplot(20, 2, i*2 + 2)
    reconstructed_img = ae.decode([encodings[i]])
    plt.imshow(reconstructed_img, cmap='Greys_r')

plt.show()


