Wrap up


In [4]:
# Clear the IPython namespace without prompting (-f) and render figures inline.
%reset -f
%matplotlib inline
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
# Drop any previously-built TF graph so re-running this cell starts from a clean slate.
tf.reset_default_graph()

In [5]:
# Data
# NOTE(review): this cell is an exact duplicate of the later data-loading cell,
# and the `%reset` in the next setup cell discards its variables anyway —
# consider deleting one of the two copies.
from sklearn.datasets import load_digits
data = load_digits()

# Shuffle all sample indices; no seed is set, so the split differs on every run.
idx = np.random.permutation(data.data.shape[0])
idx_train = idx[:-100]   # all but the last 100 shuffled samples for training
idx_test = idx[-100:]    # held-out test set of 100 samples

x_train = data.data[idx_train,:]
y_train = data.target[idx_train]
x_test = data.data[idx_test,:]
y_test = data.target[idx_test]

In [ ]:


In [6]:
%reset
%matplotlib inline
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
tf.reset_default_graph() 


# Network Parameters
n_input = 64
n_hidden_1 = 32

# Parameters
learning_rate = 1e-3


Once deleted, variables cannot be recovered. Proceed (y/[n])? 
Nothing done.

In [19]:
# Graph inputs: flattened digit images in, scalar digit value out (regression).
x = tf.placeholder("float", [None, n_input], name='x')
y = tf.placeholder("float", [None, 1], name='y')

# Store layers weight & bias; c scales the random init to keep weights small.
c = 0.1
weights = {
    'h1': tf.Variable(c*tf.random_normal([n_input, n_hidden_1]), name='W1'),
    # Named for consistency with W1/b1 (the originals were anonymous).
    'out': tf.Variable(c*tf.random_normal([n_hidden_1, 1]), name='W_out')
}
biases = {
    'b1': tf.Variable(c*tf.random_normal([n_hidden_1]), name='b1'),
    'out': tf.Variable(c*tf.random_normal([1]), name='b_out')
}

# One hidden ReLU layer followed by a linear readout unit.
layer_1 = tf.nn.relu(tf.add(tf.matmul(x, weights['h1']), biases['b1']))
output = tf.add(tf.matmul(layer_1, weights['out']), biases['out'])


# Reconstruction loss: mean absolute error between prediction and label.
loss = tf.reduce_mean(tf.abs(output - y))


# Optimizer
opt = tf.train.AdamOptimizer(learning_rate).minimize(loss)

# tf.initialize_all_variables() is deprecated (removed after TF 0.12);
# tf.global_variables_initializer() is the drop-in replacement.
init = tf.global_variables_initializer()

In [20]:
# Data: 8x8 handwritten digits (pixel values 0..16), labels 0..9.
from sklearn.datasets import load_digits
data = load_digits()

# Use a local RandomState so the train/test split is reproducible across
# runs without perturbing the global numpy RNG used by the training loop.
rng = np.random.RandomState(0)
idx = rng.permutation(data.data.shape[0])
idx_train = idx[:-100]   # all but the last 100 shuffled samples for training
idx_test = idx[-100:]    # held-out test set of 100 samples

x_train = data.data[idx_train, :]
y_train = data.target[idx_train]
x_test = data.data[idx_test, :]
y_test = data.target[idx_test]

In [23]:
sess = tf.Session()
sess.run(init)

import tqdm
training_epochs = 10000
display_step = 50

# Training cycle: plain SGD on random mini-batches drawn with replacement.
# `xrange` and the `print` statement were Python-2-only; `range` and
# `print(...)` behave identically here and also run under Python 3.
cost = []
batch_size = 16
for epoch in tqdm.tqdm(range(training_epochs)):
    idxs = np.random.randint(0, x_train.shape[0], batch_size)
    xs = x_train[idxs, :]/16.          # pixels are 0..16; scale into [0, 1]
    ys = y_train[idxs][:, np.newaxis]  # column vector to match placeholder [None, 1]
    sess.run(opt, feed_dict={x: xs, y: ys})
    # NOTE(review): '== 1' records the loss at epochs 1, 51, 101, ...;
    # '== 0' may have been intended — confirm before changing.
    if epoch % display_step == 1:
        cost.append(sess.run(loss, feed_dict={x: xs, y: ys}))

print("Optimization Finished!")


100%|██████████| 10000/10000 [00:09<00:00, 1044.33it/s]
Optimization Finished!

In [24]:
# Test: show predictions next to the true labels, one row per test sample.
# (The original ran sess.run(output, ...) twice; the first call's result was
# discarded, so the redundant forward pass is removed.)
np.c_[sess.run(output, feed_dict={x: x_test/16.}), y_test]


Out[24]:
array([[  6.05617476,   8.        ],
       [  1.22179008,   1.        ],
       [  4.80580568,   6.        ],
       [  5.2840848 ,   6.        ],
       [  3.23546791,   8.        ],
       [  4.35415268,   5.        ],
       [ 11.04371357,   9.        ],
       [  3.88889647,   4.        ],
       [  2.84020424,   3.        ],
       [  5.20426559,   5.        ],
       [  3.11576128,   1.        ],
       [  0.26507246,   0.        ],
       [  2.76579857,   2.        ],
       [  6.78263998,   7.        ],
       [  0.3442443 ,   1.        ],
       [  6.09711409,   7.        ],
       [  1.66482794,   2.        ],
       [  8.44147205,   9.        ],
       [  7.93614244,   9.        ],
       [  3.81157207,   4.        ],
       [  2.11424685,   3.        ],
       [  9.97387028,   9.        ],
       [  7.36736679,   7.        ],
       [  6.32968235,   6.        ],
       [ -0.22298861,   1.        ],
       [  7.92960453,   9.        ],
       [  7.5869875 ,   7.        ],
       [  9.56771183,   9.        ],
       [  6.30264711,   7.        ],
       [ -0.51921821,   0.        ],
       [  9.19419575,   9.        ],
       [  1.73979425,   1.        ],
       [  3.67623186,   4.        ],
       [  7.21883392,   7.        ],
       [  2.72593021,   3.        ],
       [  5.83749008,   7.        ],
       [  0.12014246,   0.        ],
       [  0.8581624 ,   0.        ],
       [  3.84807396,   4.        ],
       [  3.38268709,   6.        ],
       [  1.17133319,   0.        ],
       [  4.4332428 ,   4.        ],
       [  9.69098759,   9.        ],
       [  5.26011038,   5.        ],
       [ -0.25578964,   0.        ],
       [  3.3473897 ,   3.        ],
       [  8.16262722,   9.        ],
       [  6.05749559,   6.        ],
       [  4.35894966,   3.        ],
       [  5.85805988,   6.        ],
       [  6.68923712,   7.        ],
       [  5.09076071,   5.        ],
       [  3.78517628,   3.        ],
       [  1.56442964,   2.        ],
       [  0.50923955,   1.        ],
       [  8.66133881,   9.        ],
       [  2.08299327,   2.        ],
       [  5.78901911,   6.        ],
       [  1.85938251,   2.        ],
       [  2.32824659,   3.        ],
       [  0.43072015,   0.        ],
       [  3.70899773,   5.        ],
       [  0.57841361,   0.        ],
       [  6.85906649,   7.        ],
       [  2.14901376,   2.        ],
       [  1.75575542,   2.        ],
       [  8.63748646,   9.        ],
       [  5.9685955 ,   6.        ],
       [  4.91121817,   5.        ],
       [  6.61487055,   7.        ],
       [  4.37340546,   3.        ],
       [  4.2119379 ,   4.        ],
       [  4.7677207 ,   5.        ],
       [  7.96376133,   8.        ],
       [  5.82450485,   6.        ],
       [  4.95388651,   5.        ],
       [ -0.41276026,   0.        ],
       [  2.61445522,   1.        ],
       [  6.44314814,   8.        ],
       [  2.34214807,   2.        ],
       [  1.06951427,   0.        ],
       [  2.00880289,   1.        ],
       [ -0.05824363,   0.        ],
       [  7.53541327,   7.        ],
       [  6.07271194,   9.        ],
       [  0.83342779,   0.        ],
       [ 10.17685509,   9.        ],
       [  4.30995941,   5.        ],
       [  6.98378897,   8.        ],
       [  4.18541431,   4.        ],
       [  3.99603701,   3.        ],
       [  2.90997171,   1.        ],
       [  4.61576605,   5.        ],
       [  0.57803357,   1.        ],
       [  0.17314243,   0.        ],
       [  4.59220171,   4.        ],
       [  9.68675232,   9.        ],
       [  3.28117609,   3.        ],
       [  9.93737507,   9.        ],
       [  1.86977947,   2.        ]])

In [ ]: