In [1]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

Create Some Data


In [2]:
from sklearn.datasets import make_blobs

In [3]:
data = make_blobs(n_samples = 100, 
                  n_features = 3,
                  centers = 2,
                  random_state = 101)

In [4]:
data


Out[4]:
(array([[  0.99429187,   0.87511711,  -9.99909683],
        [  0.65761885,   0.91624771, -10.18458517],
        [ -6.48804352,   3.29034852,   6.73159871],
        [  0.5815206 ,   2.06249948,  -9.07157495],
        [ -7.42442678,   5.30835517,   6.86341624],
        [  2.26790487,   0.40816482, -10.17230518],
        [  1.62073713,  -0.05416228,  -9.92461083],
        [  2.41199108,   1.03683306,  -9.20017913],
        [ -7.62441747,   5.87096075,   5.41069708],
        [ -7.34160752,   3.671108  ,   7.36634568],
        [ -7.56358517,   3.5736335 ,   5.34284552],
        [  2.22972735,   1.6514787 ,  -7.43386319],
        [ -0.75249891,   2.31574949,  -9.26873428],
        [  0.54946324,   0.5581557 ,  -7.88852506],
        [ -0.4308995 ,   0.48011452,  -8.47545896],
        [ -1.19436032,   2.54705473,  -8.90232874],
        [ -0.06318407,   2.4416449 , -11.40312057],
        [ -6.70999871,   3.8246846 ,   7.34365184],
        [  1.32054601,   2.60559237, -10.47729501],
        [  0.15391601,   2.19349145,  -9.8137736 ],
        [  0.32833903,   2.43933589,  -9.58711337],
        [ -7.01543966,   4.3914134 ,   9.46294692],
        [  0.16543782,   1.89916047,  -9.03802666],
        [  0.7185004 ,   1.58025637,  -9.24601361],
        [ -7.65700277,   1.13777271,   7.33896645],
        [ -7.12044288,   4.64758461,   5.7025879 ],
        [  1.93475242,   0.29764177, -10.81589403],
        [ -6.4220804 ,   2.9760733 ,   5.9578275 ],
        [ -7.39003004,   2.67596248,   5.15308172],
        [ -7.03832316,   4.56563667,   6.9080243 ],
        [ -4.07557669,   3.15827975,   7.17466421],
        [ -0.61543385,   1.89810338,  -9.54728879],
        [  0.9339379 ,  -0.60481651,  -8.69039341],
        [ -4.81755245,   4.40108636,   6.83159809],
        [  0.57115017,   2.27751662, -10.99144692],
        [ -1.0016875 ,   1.45481168,  -9.84157087],
        [  1.7353108 ,   2.05515725, -10.33561537],
        [ -8.68467997,   3.42602942,   7.74064767],
        [  0.9790905 ,   1.09403369, -10.27859245],
        [ -5.79759684,   3.99975626,   5.28641819],
        [  2.44569994,   1.61087572,  -7.12752821],
        [  1.81046805,   2.37480989, -11.57172776],
        [ -6.24790729,   4.71286377,   7.00772091],
        [ -8.20249087,   4.507669  ,   5.01040047],
        [  1.00918185,   2.44847681,  -9.46167595],
        [ -6.78887759,   3.65986315,   6.69035824],
        [  0.51876688,   3.39210906,  -6.82454819],
        [ -8.40882774,   4.56913524,   5.60134675],
        [ -0.44335605,   1.52382915, -10.23516698],
        [ -5.24785153,   4.49966814,   8.03764923],
        [ -6.13431395,   4.96544332,   6.2300389 ],
        [ -1.37811338,   0.25423232,  -9.56535619],
        [  0.29639341,   2.06317757,  -7.27566903],
        [ -6.17610612,   3.075033  ,   5.27964757],
        [ -0.53891248,   2.13413934, -10.65359751],
        [ -5.43954882,   4.60333595,   7.00880287],
        [  0.85678605,   0.8243512 ,  -9.24182016],
        [ -0.65480344,   3.64490698, -10.40190804],
        [ -6.05573415,   2.7773348 ,   7.52484115],
        [ -0.80584461,   2.02382964,  -9.04448516],
        [ -0.2822863 ,   0.6580264 ,  -9.77693397],
        [ -5.66674102,   3.16763048,   5.12826615],
        [  1.35278283,   0.48747748,  -7.56765133],
        [ -6.47593889,   4.94635225,   5.58024423],
        [ -6.86800262,   3.73468111,   7.56696814],
        [ -6.65886532,   4.40410604,   5.26881157],
        [ -7.14994986,   4.91456692,   7.87968128],
        [  1.05699629,   0.55026047,  -9.12488396],
        [  0.51509708,   0.68050659, -10.81343557],
        [ -5.53765102,   2.62442136,   6.67758229],
        [ -5.69865986,   5.29135203,   7.60759509],
        [  1.16100188,   2.38907142,  -9.81875417],
        [  0.47499933,   0.9339037 ,  -8.87174606],
        [ -9.30656147,   5.22810174,   6.85594634],
        [  1.1112886 ,   0.704398  ,  -8.84366818],
        [ -6.16198043,   4.51338238,   3.94147281],
        [ -7.80518114,   2.90711167,   7.43205869],
        [  0.0760756 ,   0.83523148,  -9.19351984],
        [  0.52875302,   1.74119648,  -8.75603028],
        [ -5.95090888,   1.65297671,   6.51129114],
        [ -5.96070958,   3.99201262,   6.47214495],
        [ -5.04581642,   4.99576956,   5.19961789],
        [ -7.37241005,   3.45206025,   6.768636  ],
        [ -6.78944842,   3.99833476,   7.51229177],
        [ -7.55471973,   3.55026388,   5.66470924],
        [ -0.66529095,   1.61015124, -10.56716006],
        [ -4.54002269,   2.65918125,   5.37145495],
        [ -1.3503111 ,   1.74732484,  -9.962986  ],
        [  1.01148144,   1.71601719,  -7.73679255],
        [ -8.4775757 ,   3.32543613,   5.01187807],
        [ -1.29337527,   2.09088685,  -9.40440999],
        [ -6.35533268,   3.46502969,   5.25698778],
        [ -0.08134593,   1.75689092,  -9.23424015],
        [  1.13567847,   1.48631141,  -8.79172846],
        [ -6.40192863,   2.93960956,   7.64023642],
        [ -6.85482364,   5.79607878,   7.6819934 ],
        [ -7.63261577,   3.58015883,   5.73234913],
        [ -6.90248768,   2.77693736,   8.39388742],
        [ -6.3033601 ,   4.1181198 ,   7.6667101 ],
        [ -7.19637228,   3.31445067,   8.4214142 ]]),
 array([0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1,
        0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1,
        0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1,
        1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0,
        0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1]))

Scale the Data


In [5]:
from sklearn.preprocessing import MinMaxScaler

In [6]:
scaler = MinMaxScaler()
scaled_data = scaler.fit_transform(data[0])

In [7]:
# data[0] 
data_x = scaled_data[:, 0]
data_y = scaled_data[:, 1]
data_z = scaled_data[:, 2]

In [8]:
from mpl_toolkits.mplot3d import Axes3D

In [9]:
fig = plt.figure(figsize = (10, 10))
ax = fig.add_subplot(111, 
                     projection = '3d')



In [10]:
fig = plt.figure(figsize = (10, 10))
ax = fig.add_subplot(111, 
                     projection = '3d')
ax.scatter(data_x, 
           data_y, 
           data_z,
           c = data[1])


Out[10]:
<mpl_toolkits.mplot3d.art3d.Path3DCollection at 0x205d54447f0>

The Linear Autoencoder


In [11]:
import tensorflow as tf
from tensorflow.contrib.layers import fully_connected


WARNING:tensorflow:From c:\programdata\anaconda3\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Use the retry module or similar alternatives.

In [12]:
num_inputs = 3  # 3 dimensional input
num_hidden = 2  # 2 dimensional representation 
num_outputs = num_inputs # Must be true for an autoencoder!

learning_rate = 0.01

Placeholder

Notice there is no real label here, just X.


In [13]:
X = tf.placeholder(tf.float32, shape = [None, num_inputs])

Layers

Using the fully_connected layers API, we do not provide an activation function!


In [14]:
hidden = fully_connected(X, num_hidden, activation_fn = None)
outputs = fully_connected(hidden, num_outputs, activation_fn = None)

Loss Function


In [15]:
loss = tf.reduce_mean(tf.square(outputs - X))  # MSE

Optimizer


In [16]:
optimizer = tf.train.AdamOptimizer(learning_rate)
train  = optimizer.minimize(loss)

Init


In [17]:
init = tf.global_variables_initializer()

Running the Session


In [18]:
num_steps = 1000

with tf.Session() as sess:
    sess.run(init)
    for iteration in range(num_steps):
        sess.run(train,
                 feed_dict = {X: scaled_data})

    # Now ask for the hidden layer output (the 2 dimensional output)
    output_2d = hidden.eval(feed_dict={X: scaled_data})

In [19]:
output_2d.shape


Out[19]:
(100, 2)

In [20]:
plt.scatter(output_2d[:, 0],
            output_2d[:, 1],
            c = data[1])


Out[20]:
<matplotlib.collections.PathCollection at 0x20675bb1898>

Great Job!