Nonlinear Autoencoders
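
This notebook builds a single-hidden-layer autoencoder on MNIST: a ReLU encoder compresses each 784-pixel image to a 128-dimensional code, and a linear decoder reconstructs the image from that code. The network is trained with a mean-squared-error reconstruction loss and the Adagrad optimizer, then evaluated by reconstructing the test set.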


In [19]:
%matplotlib inline
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import seaborn as sns

In [20]:
from tensorflow.examples.tutorials.mnist import input_data

In [21]:
mnist = input_data.read_data_sets('../MNIST_data/',one_hot=True)


Extracting ../MNIST_data/train-images-idx3-ubyte.gz
Extracting ../MNIST_data/train-labels-idx1-ubyte.gz
Extracting ../MNIST_data/t10k-images-idx3-ubyte.gz
Extracting ../MNIST_data/t10k-labels-idx1-ubyte.gz
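
The one_hot=True flag only affects the labels, which this autoencoder never uses; the images arrive already flattened to 784-dimensional float vectors scaled to [0, 1]. A quick sanity check on the shapes (figures assume the standard 55,000 / 10,000 MNIST train/test split):

In [ ]:
print(mnist.train.images.shape)   # (55000, 784)
print(mnist.test.images.shape)    # (10000, 784)
print(mnist.train.labels.shape)   # (55000, 10) one-hot labels, unused below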

In [22]:
plt.imshow(mnist.train.images[0].reshape((28, 28)), cmap='gray', interpolation="nearest")


Out[22]:
<matplotlib.image.AxesImage at 0x12025fe50>

In [47]:
learning_rate = 0.01
n_training_examples, n_features = mnist.train.images.shape   # 55000, 784
n_testing_examples = mnist.test.images.shape[0]              # 10000; used in the test loop below
batch_size = 100
n_epochs = 100

In [48]:
tf.reset_default_graph()

In [49]:
with tf.variable_scope("data") as scope:
    input_image = tf.placeholder(dtype=tf.float32,shape=[None,784],name="input")

In [50]:
with tf.variable_scope("hidden_layer") as scope:
    w = tf.get_variable(name="weights",shape=[784,128],initializer=tf.contrib.layers.xavier_initializer())
    b = tf.get_variable(name="biases",shape=[128],initializer=tf.random_normal_initializer())
    encoding = tf.nn.relu(tf.matmul(input_image,w) + b)
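
The encoder computes encoding = relu(x·W + b), squeezing each image through a 128-unit bottleneck. The ReLU is what makes this autoencoder nonlinear: with a purely linear encoder and decoder, the best the network could do is recover the same subspace as PCA. The 128-dimensional codes themselves can be read out while a trained session is live, e.g. inside the training block further down:

In [ ]:
# Sketch, assuming a live tf.Session with trained variables in scope:
# codes = sess.run(encoding, feed_dict={input_image: x_batch})   # shape (100, 128)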

In [51]:
with tf.variable_scope("output_layer") as scope:
    w = tf.get_variable(name="weights",shape=[128,784],initializer=tf.contrib.layers.xavier_initializer())
    b = tf.get_variable(name="biases",shape=[784],initializer=tf.random_normal_initializer())
    output_image = tf.matmul(encoding,w) + b
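
The decoder mirrors the encoder with its own weights. Note the linear reconstruction is unbounded while the input pixels lie in [0, 1]; a common alternative (not used in this notebook) squashes the output with a sigmoid:

In [ ]:
# Hypothetical alternative: bound reconstructions to [0, 1] to match the
# pixel range (would replace the linear output above).
# output_image = tf.nn.sigmoid(tf.matmul(encoding, w) + b)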

In [52]:
with tf.variable_scope("loss") as scope:
    loss = tf.reduce_mean(tf.squared_difference(input_image,output_image))
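
tf.squared_difference(a, b) computes (a - b)^2 elementwise and tf.reduce_mean averages over every entry, so the loss is the per-pixel mean squared error for the batch. An equivalent formulation, for reference:

In [ ]:
# Equivalent loss, written out explicitly:
# loss = tf.reduce_mean(tf.square(input_image - output_image))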

In [53]:
with tf.variable_scope("optimizer") as scope:
    optimizer = tf.train.AdagradOptimizer(learning_rate).minimize(loss)
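
Adagrad accumulates squared gradients per parameter and scales each step accordingly, which suits the relatively high 0.01 learning rate here. A drop-in alternative worth trying (untested in this notebook):

In [ ]:
# Hypothetical swap: Adam typically converges faster at a smaller step size.
# optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)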

In [54]:
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    writer = tf.summary.FileWriter('graphs/',sess.graph)
    ## Training the network
    for i in range(n_epochs):
        epoch_loss = 0
        for batch in range(n_training_examples/batch_size):
            x_batch,_ = mnist.train.next_batch(batch_size)
            _,l = sess.run([optimizer,loss],feed_dict = {input_image:x_batch})
            epoch_loss += l
        print 'Epoch: {}\t Loss: {}'.format(i+1,epoch_loss)
    
    ## Testing Examples
    print 'Checking the output of network on Test data....'
    output = []
    error = 0
    for i in range(n_testing_examples/batch_size):
        x_batch,_ = mnist.test.next_batch(batch_size)
        l,image = sess.run([loss,output_image],feed_dict = {input_image:x_batch})
        output.append(image)
        error += l
        if (i+1)%10 == 0:
            print 'Error after {} batches: {}'.format(i+1,error)
            error = 0


Epoch: 1	 Loss: 114.051694684
Epoch: 2	 Loss: 44.3769447654
Epoch: 3	 Loss: 42.0630340427
Epoch: 4	 Loss: 40.484398596
Epoch: 5	 Loss: 39.3076668307
Epoch: 6	 Loss: 38.372906208
Epoch: 7	 Loss: 37.5898293555
Epoch: 8	 Loss: 36.9051306956
Epoch: 9	 Loss: 36.2813824564
Epoch: 10	 Loss: 35.6936593466
Epoch: 11	 Loss: 35.1266304404
Epoch: 12	 Loss: 34.5706646107
Epoch: 13	 Loss: 34.0218508951
Epoch: 14	 Loss: 33.4778660126
Epoch: 15	 Loss: 32.9389003366
Epoch: 16	 Loss: 32.4057905935
Epoch: 17	 Loss: 31.879738044
Epoch: 18	 Loss: 31.3617734388
Epoch: 19	 Loss: 30.852421172
Epoch: 20	 Loss: 30.3527168706
Epoch: 21	 Loss: 29.8619098663
Epoch: 22	 Loss: 29.3811007217
Epoch: 23	 Loss: 28.9100803621
Epoch: 24	 Loss: 28.4493103363
Epoch: 25	 Loss: 27.9992557615
Epoch: 26	 Loss: 27.5605178066
Epoch: 27	 Loss: 27.1335849501
Epoch: 28	 Loss: 26.7187242024
Epoch: 29	 Loss: 26.3166321926
Epoch: 30	 Loss: 25.9274149053
Epoch: 31	 Loss: 25.5511623658
Epoch: 32	 Loss: 25.1882650517
Epoch: 33	 Loss: 24.8383147009
Epoch: 34	 Loss: 24.5013127252
Epoch: 35	 Loss: 24.1771068051
Epoch: 36	 Loss: 23.8652042784
Epoch: 37	 Loss: 23.5654122084
Epoch: 38	 Loss: 23.2773666643
Epoch: 39	 Loss: 23.0005765557
Epoch: 40	 Loss: 22.7346823514
Epoch: 41	 Loss: 22.4790600948
Epoch: 42	 Loss: 22.233501792
Epoch: 43	 Loss: 21.9971862994
Epoch: 44	 Loss: 21.7697980218
Epoch: 45	 Loss: 21.5508762673
Epoch: 46	 Loss: 21.3399185799
Epoch: 47	 Loss: 21.1364395879
Epoch: 48	 Loss: 20.9399429671
Epoch: 49	 Loss: 20.7500815205
Epoch: 50	 Loss: 20.566382166
Epoch: 51	 Loss: 20.3882859461
Epoch: 52	 Loss: 20.2156009264
Epoch: 53	 Loss: 20.0478675701
Epoch: 54	 Loss: 19.8848518655
Epoch: 55	 Loss: 19.7262853272
Epoch: 56	 Loss: 19.5717186667
Epoch: 57	 Loss: 19.4210398383
Epoch: 58	 Loss: 19.2739499249
Epoch: 59	 Loss: 19.130183218
Epoch: 60	 Loss: 18.989470467
Epoch: 61	 Loss: 18.8519462626
Epoch: 62	 Loss: 18.7172226906
Epoch: 63	 Loss: 18.5850717463
Epoch: 64	 Loss: 18.4556708131
Epoch: 65	 Loss: 18.3286189809
Epoch: 66	 Loss: 18.2040103208
Epoch: 67	 Loss: 18.0816678908
Epoch: 68	 Loss: 17.961475838
Epoch: 69	 Loss: 17.84344876
Epoch: 70	 Loss: 17.7274723127
Epoch: 71	 Loss: 17.6135662496
Epoch: 72	 Loss: 17.5015043709
Epoch: 73	 Loss: 17.3914026283
Epoch: 74	 Loss: 17.2831214834
Epoch: 75	 Loss: 17.1766564883
Epoch: 76	 Loss: 17.0719708782
Epoch: 77	 Loss: 16.9689774327
Epoch: 78	 Loss: 16.8676336724
Epoch: 79	 Loss: 16.7679895759
Epoch: 80	 Loss: 16.6699523032
Epoch: 81	 Loss: 16.5735381152
Epoch: 82	 Loss: 16.4786082692
Epoch: 83	 Loss: 16.3852623589
Epoch: 84	 Loss: 16.29339098
Epoch: 85	 Loss: 16.2030250896
Epoch: 86	 Loss: 16.1139755975
Epoch: 87	 Loss: 16.0263374038
Epoch: 88	 Loss: 15.9402579069
Epoch: 89	 Loss: 15.855342906
Epoch: 90	 Loss: 15.7718145624
Epoch: 91	 Loss: 15.6896023396
Epoch: 92	 Loss: 15.6086113546
Epoch: 93	 Loss: 15.5288837682
Epoch: 94	 Loss: 15.4503824711
Epoch: 95	 Loss: 15.3730661552
Epoch: 96	 Loss: 15.2969302125
Epoch: 97	 Loss: 15.221866563
Epoch: 98	 Loss: 15.1479590982
Epoch: 99	 Loss: 15.0751519427
Epoch: 100	 Loss: 15.0033744108
Checking the output of network on Test data....
Error after 10 batches: 0.269919754937
Error after 20 batches: 0.27109804377
Error after 30 batches: 0.26735717617
Error after 40 batches: 0.270253889263
Error after 50 batches: 0.264255257323
Error after 60 batches: 0.262490360066
Error after 70 batches: 0.266797050834
Error after 80 batches: 0.268831349909
Error after 90 batches: 0.272368231788
Error after 100 batches: 0.266832135618
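
The training figures above are sums of 550 per-batch losses per epoch (55,000 examples / 100 per batch), so the final epoch corresponds to a per-batch MSE of roughly 15.0 / 550 ≈ 0.027. The test figures are sums over blocks of 10 batches and land around 0.27, i.e. about the same ≈ 0.027 per batch, so the network reconstructs unseen digits about as well as training ones.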

In [55]:
output = np.asarray(output).reshape((-1, 784))  # stack 100 batches of 100 into (10000, 784)
print output.shape


(10000, 784)

In [56]:
plt.imshow(output[0].reshape((28, 28)), cmap="gray", interpolation="nearest")


Out[56]:
<matplotlib.image.AxesImage at 0x120b4bed0>


In [ ]: