Linear Regression Example


In [1]:
"""
Simple linear regression example in TensorFlow
This program tries to predict the number of thefts from 
the number of fires in the city of Chicago
"""

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import xlrd

DATA_FILE = 'data/fire_theft.xls'

1. Check the format of the loaded data


In [2]:
book = xlrd.open_workbook(DATA_FILE, encoding_override="utf-8")
sheet = book.sheet_by_index(0)
print(sheet)
data = np.asarray([sheet.row_values(i) for i in range(1, sheet.nrows)])
n_samples = sheet.nrows - 1
print(data[:5])
print(n_samples)


<xlrd.sheet.Sheet object at 0x000002302A154C50>
[[  6.2  29. ]
 [  9.5  44. ]
 [ 10.5  36. ]
 [  7.7  37. ]
 [  8.6  53. ]]
42

2. Create placeholders


In [3]:
X = tf.placeholder(tf.float32, name='X')
Y = tf.placeholder(tf.float32, name='Y')

3. Create weight and bias variables


In [4]:
w = tf.Variable(0.0, name='weights') # initialized to 0
b = tf.Variable(0.0, name='bias')

4. Build the model that predicts Y


In [5]:
Y_predicted = X * w + b

5. Loss function


In [6]:
loss = tf.square(Y - Y_predicted, name='loss') # use squared error
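The loss above is the squared error for a single (X, Y) pair, since the placeholders are fed one sample at a time; the per-epoch value printed during training below is the mean squared error, obtained by summing these per-sample losses and dividing by the number of samples:

$$
\text{MSE} = \frac{1}{n} \sum_{i=1}^{n} \left( y_i - (w x_i + b) \right)^2
$$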

6. Minimize the loss


In [7]:
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001).minimize(loss) # use a gradient descent optimizer with learning rate 0.001
with tf.Session() as sess:
	# Step 7: initialize the necessary variables, in this case, w and b
	sess.run(tf.global_variables_initializer()) 
	
	writer = tf.summary.FileWriter('./my_graph/01/linear_reg', sess.graph) # tensorboard
	
	# Step 8: train the model
	for i in range(150): # train the model 150 times
		total_loss = 0
		for x, y in data:
			# Session runs train_op and fetch values of loss
			_, l = sess.run([optimizer, loss], feed_dict={X: x, Y:y}) 
			total_loss += l
		print ('Epoch {0}: {1}'.format(i, total_loss/n_samples))

	# close the writer when you're done using it
	print(writer)
	writer.close() 
	
	# Step 9: output the values of w and b
	w_value, b_value = sess.run([w, b]) 

# plot the results
# note: this reuses the names X and Y for the data arrays, shadowing the placeholders defined above
X, Y = data.T[0], data.T[1]
plt.plot(X, Y, 'bo', label='Real data')
plt.plot(X, X * w_value + b_value, 'r', label='Predicted data')
plt.legend()
plt.show()


Epoch 0: 2069.6319333978354
Epoch 1: 2117.0123581953535
Epoch 2: 2092.302723001866
Epoch 3: 2068.5080461938464
Epoch 4: 2045.591184088162
Epoch 5: 2023.5146448101316
Epoch 6: 2002.2447619835536
Epoch 7: 1981.748338803649
Epoch 8: 1961.9944411260742
Epoch 9: 1942.9520116143283
Epoch 10: 1924.5930823644712
Epoch 11: 1906.8898800636332
Epoch 12: 1889.8164505837929
Epoch 13: 1873.347133841543
Epoch 14: 1857.4588400604468
Epoch 15: 1842.1278742424079
Epoch 16: 1827.332495119955
Epoch 17: 1813.0520579712022
Epoch 18: 1799.2660847636982
Epoch 19: 1785.9562132299961
Epoch 20: 1773.1024853109072
Epoch 21: 1760.689129482884
Epoch 22: 1748.6984157081515
Epoch 23: 1737.1138680398553
Epoch 24: 1725.920873066732
Epoch 25: 1715.1046249579008
Epoch 26: 1704.6500954309377
Epoch 27: 1694.5447134910141
Epoch 28: 1684.7746311347667
Epoch 29: 1675.328450968245
Epoch 30: 1666.1935385839038
Epoch 31: 1657.3584002084322
Epoch 32: 1648.8122658529207
Epoch 33: 1640.5440742547091
Epoch 34: 1632.5446836102221
Epoch 35: 1624.8043315147183
Epoch 36: 1617.3126799958602
Epoch 37: 1610.0622532456405
Epoch 38: 1603.0433557207386
Epoch 39: 1596.2479176106197
Epoch 40: 1589.668056331575
Epoch 41: 1583.2965242617897
Epoch 42: 1577.126371285745
Epoch 43: 1571.1501190634
Epoch 44: 1565.360979151513
Epoch 45: 1559.7523780798629
Epoch 46: 1554.3184364555138
Epoch 47: 1549.0529469620615
Epoch 48: 1543.950059985476
Epoch 49: 1539.0050282141283
Epoch 50: 1534.211797797609
Epoch 51: 1529.56534988646
Epoch 52: 1525.0607591186251
Epoch 53: 1520.6934648507852
Epoch 54: 1516.4585935090713
Epoch 55: 1512.3524023861364
Epoch 56: 1508.3695780125756
Epoch 57: 1504.5066588066873
Epoch 58: 1500.7606269073274
Epoch 59: 1497.126336559476
Epoch 60: 1493.600210891061
Epoch 61: 1490.1794991287668
Epoch 62: 1486.8605145300749
Epoch 63: 1483.639419928193
Epoch 64: 1480.5144186365596
Epoch 65: 1477.4811065652452
Epoch 66: 1474.5376660533782
Epoch 67: 1471.6799176652871
Epoch 68: 1468.9063155567717
Epoch 69: 1466.2136880280007
Epoch 70: 1463.5996563179153
Epoch 71: 1461.0614086978492
Epoch 72: 1458.597208841216
Epoch 73: 1456.2043069711044
Epoch 74: 1453.8807724802089
Epoch 75: 1451.6242183893032
Epoch 76: 1449.432753210976
Epoch 77: 1447.3042320180018
Epoch 78: 1445.237068621615
Epoch 79: 1443.228872676177
Epoch 80: 1441.2782130186733
Epoch 81: 1439.3831422174615
Epoch 82: 1437.542224922173
Epoch 83: 1435.7540219968096
Epoch 84: 1434.0160684508405
Epoch 85: 1432.3276573866606
Epoch 86: 1430.687153330871
Epoch 87: 1429.093016880254
Epoch 88: 1427.543719962062
Epoch 89: 1426.038033108981
Epoch 90: 1424.5748210840281
Epoch 91: 1423.1531702368743
Epoch 92: 1421.771026852585
Epoch 93: 1420.4274983895677
Epoch 94: 1419.121967994741
Epoch 95: 1417.85251878131
Epoch 96: 1416.618930517208
Epoch 97: 1415.4196022436731
Epoch 98: 1414.2534379121803
Epoch 99: 1413.1202843011845
Epoch 100: 1412.0180716720365
Epoch 101: 1410.9467244467564
Epoch 102: 1409.9050737058833
Epoch 103: 1408.89239564538
Epoch 104: 1407.9075303567308
Epoch 105: 1406.9502103584152
Epoch 106: 1406.0192385521673
Epoch 107: 1405.1141311767556
Epoch 108: 1404.2334141191982
Epoch 109: 1403.3773468952804
Epoch 110: 1402.5446293339842
Epoch 111: 1401.7349663263276
Epoch 112: 1400.9475360775277
Epoch 113: 1400.1815872419447
Epoch 114: 1399.4366472427334
Epoch 115: 1398.7121980247043
Epoch 116: 1398.0075749322064
Epoch 117: 1397.3223697323174
Epoch 118: 1396.655994380514
Epoch 119: 1396.007601584707
Epoch 120: 1395.3772896343753
Epoch 121: 1394.7637101475682
Epoch 122: 1394.1668273763996
Epoch 123: 1393.5865733389344
Epoch 124: 1393.021867629318
Epoch 125: 1392.472776920313
Epoch 126: 1391.938410587254
Epoch 127: 1391.4184658789918
Epoch 128: 1390.9131654983476
Epoch 129: 1390.4211126998775
Epoch 130: 1389.9426775354714
Epoch 131: 1389.4770981903587
Epoch 132: 1389.0241417735815
Epoch 133: 1388.5833784611452
Epoch 134: 1388.1547465622425
Epoch 135: 1387.7372929602861
Epoch 136: 1387.3316521396239
Epoch 137: 1386.9367964785724
Epoch 138: 1386.5526603758335
Epoch 139: 1386.1787038090683
Epoch 140: 1385.8150548090537
Epoch 141: 1385.4611479015578
Epoch 142: 1385.1165423563548
Epoch 143: 1384.7815280691498
Epoch 144: 1384.4554683289357
Epoch 145: 1384.138334988838
Epoch 146: 1383.8294674903154
Epoch 147: 1383.5289615853912
Epoch 148: 1383.2364399525381
Epoch 149: 1382.9520691292626
<tensorflow.python.summary.writer.writer.FileWriter object at 0x000002302DFDEA90>
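The graph written by the FileWriter above can be inspected by running tensorboard --logdir ./my_graph/01/linear_reg in a terminal and opening the address it prints.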

Using the Huber loss


In [8]:
def huber_loss(labels, predictions, delta=1.0):
    residual = tf.abs(predictions - labels)
    condition = tf.less(residual, delta)
    small_res = 0.5 * tf.square(residual)
    large_res = delta * residual - 0.5 * tf.square(delta)
    return tf.where(condition, small_res, large_res)
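For reference, this function implements the standard Huber loss: it is quadratic while the residual r = |prediction - label| is smaller than delta and linear beyond that,

$$
L_\delta(r) = \begin{cases} \tfrac{1}{2} r^2 & \text{if } r < \delta \\ \delta r - \tfrac{1}{2}\delta^2 & \text{otherwise,} \end{cases}
$$

so outliers in the data pull the fit far less strongly than under the squared error.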

Comparison. Note that the loss values printed below are not directly comparable in scale: the Huber loss grows only linearly for large residuals, so its per-sample values are much smaller than the squared errors.


In [9]:
X = tf.placeholder(tf.float32)
Y = tf.placeholder(tf.float32)
w = tf.Variable(0.0)
b = tf.Variable(0.0)
Y_pred = X * w + b
loss = tf.square(Y_pred - Y)
loss_hb = huber_loss(Y, Y_pred)
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001).minimize(loss)
optimizer_hb = tf.train.GradientDescentOptimizer(learning_rate=0.001).minimize(loss_hb)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(100):
        total_loss = 0
        for x, y in data:
            _, l = sess.run([optimizer, loss], feed_dict={X: x, Y: y})
            total_loss += l
        print('#Epoch {0}: loss = {1}'.format(i, total_loss/n_samples) )
    w_result, b_result = sess.run([w, b])

with tf.Session() as sess2:
    sess2.run(tf.global_variables_initializer())
    for i in range(100):
        total_loss = 0
        for x, y in data:
            _, l = sess2.run([optimizer_hb, loss_hb], feed_dict={X: x, Y: y})
            total_loss += l
        print('#Epoch {0}: huber_loss = {1}'.format(i, total_loss/n_samples) )
    w_result_hb, b_result_hb = sess2.run([w, b])


#Epoch 0: loss = 2069.6319333978354
#Epoch 1: loss = 2117.0123581953535
#Epoch 2: loss = 2092.302723001866
#Epoch 3: loss = 2068.5080461938464
#Epoch 4: loss = 2045.591184088162
#Epoch 5: loss = 2023.5146448101316
#Epoch 6: loss = 2002.2447619835536
#Epoch 7: loss = 1981.748338803649
#Epoch 8: loss = 1961.9944411260742
#Epoch 9: loss = 1942.9520116143283
#Epoch 10: loss = 1924.5930823644712
#Epoch 11: loss = 1906.8898800636332
#Epoch 12: loss = 1889.8164505837929
#Epoch 13: loss = 1873.347133841543
#Epoch 14: loss = 1857.4588400604468
#Epoch 15: loss = 1842.1278742424079
#Epoch 16: loss = 1827.332495119955
#Epoch 17: loss = 1813.0520579712022
#Epoch 18: loss = 1799.2660847636982
#Epoch 19: loss = 1785.9562132299961
#Epoch 20: loss = 1773.1024853109072
#Epoch 21: loss = 1760.689129482884
#Epoch 22: loss = 1748.6984157081515
#Epoch 23: loss = 1737.1138680398553
#Epoch 24: loss = 1725.920873066732
#Epoch 25: loss = 1715.1046249579008
#Epoch 26: loss = 1704.6500954309377
#Epoch 27: loss = 1694.5447134910141
#Epoch 28: loss = 1684.7746311347667
#Epoch 29: loss = 1675.328450968245
#Epoch 30: loss = 1666.1935385839038
#Epoch 31: loss = 1657.3584002084322
#Epoch 32: loss = 1648.8122658529207
#Epoch 33: loss = 1640.5440742547091
#Epoch 34: loss = 1632.5446836102221
#Epoch 35: loss = 1624.8043315147183
#Epoch 36: loss = 1617.3126799958602
#Epoch 37: loss = 1610.0622532456405
#Epoch 38: loss = 1603.0433557207386
#Epoch 39: loss = 1596.2479176106197
#Epoch 40: loss = 1589.668056331575
#Epoch 41: loss = 1583.2965242617897
#Epoch 42: loss = 1577.126371285745
#Epoch 43: loss = 1571.1501190634
#Epoch 44: loss = 1565.360979151513
#Epoch 45: loss = 1559.7523780798629
#Epoch 46: loss = 1554.3184364555138
#Epoch 47: loss = 1549.0529469620615
#Epoch 48: loss = 1543.950059985476
#Epoch 49: loss = 1539.0050282141283
#Epoch 50: loss = 1534.211797797609
#Epoch 51: loss = 1529.56534988646
#Epoch 52: loss = 1525.0607591186251
#Epoch 53: loss = 1520.6934648507852
#Epoch 54: loss = 1516.4585935090713
#Epoch 55: loss = 1512.3524023861364
#Epoch 56: loss = 1508.3695780125756
#Epoch 57: loss = 1504.5066588066873
#Epoch 58: loss = 1500.7606269073274
#Epoch 59: loss = 1497.126336559476
#Epoch 60: loss = 1493.600210891061
#Epoch 61: loss = 1490.1794991287668
#Epoch 62: loss = 1486.8605145300749
#Epoch 63: loss = 1483.639419928193
#Epoch 64: loss = 1480.5144186365596
#Epoch 65: loss = 1477.4811065652452
#Epoch 66: loss = 1474.5376660533782
#Epoch 67: loss = 1471.6799176652871
#Epoch 68: loss = 1468.9063155567717
#Epoch 69: loss = 1466.2136880280007
#Epoch 70: loss = 1463.5996563179153
#Epoch 71: loss = 1461.0614086978492
#Epoch 72: loss = 1458.597208841216
#Epoch 73: loss = 1456.2043069711044
#Epoch 74: loss = 1453.8807724802089
#Epoch 75: loss = 1451.6242183893032
#Epoch 76: loss = 1449.432753210976
#Epoch 77: loss = 1447.3042320180018
#Epoch 78: loss = 1445.237068621615
#Epoch 79: loss = 1443.228872676177
#Epoch 80: loss = 1441.2782130186733
#Epoch 81: loss = 1439.3831422174615
#Epoch 82: loss = 1437.542224922173
#Epoch 83: loss = 1435.7540219968096
#Epoch 84: loss = 1434.0160684508405
#Epoch 85: loss = 1432.3276573866606
#Epoch 86: loss = 1430.687153330871
#Epoch 87: loss = 1429.093016880254
#Epoch 88: loss = 1427.543719962062
#Epoch 89: loss = 1426.038033108981
#Epoch 90: loss = 1424.5748210840281
#Epoch 91: loss = 1423.1531702368743
#Epoch 92: loss = 1421.771026852585
#Epoch 93: loss = 1420.4274983895677
#Epoch 94: loss = 1419.121967994741
#Epoch 95: loss = 1417.85251878131
#Epoch 96: loss = 1416.618930517208
#Epoch 97: loss = 1415.4196022436731
#Epoch 98: loss = 1414.2534379121803
#Epoch 99: loss = 1413.1202843011845
#Epoch 0: huber_loss = 30.231313444319227
#Epoch 1: huber_loss = 24.488211881546746
#Epoch 2: huber_loss = 19.95247097987504
#Epoch 3: huber_loss = 18.415514595407462
#Epoch 4: huber_loss = 17.62403281920013
#Epoch 5: huber_loss = 17.08878264540718
#Epoch 6: huber_loss = 16.827302360641106
#Epoch 7: huber_loss = 16.676950227957043
#Epoch 8: huber_loss = 16.589362557090464
#Epoch 9: huber_loss = 16.547880798134777
#Epoch 10: huber_loss = 16.54316401366322
#Epoch 11: huber_loss = 16.53523531635957
#Epoch 12: huber_loss = 16.527739029466396
#Epoch 13: huber_loss = 16.520207939580793
#Epoch 14: huber_loss = 16.51269836031965
#Epoch 15: huber_loss = 16.505204298932636
#Epoch 16: huber_loss = 16.497726049274206
#Epoch 17: huber_loss = 16.490262628311203
#Epoch 18: huber_loss = 16.48281567916274
#Epoch 19: huber_loss = 16.475384406479343
#Epoch 20: huber_loss = 16.467968803342607
#Epoch 21: huber_loss = 16.460568067573366
#Epoch 22: huber_loss = 16.453183558459084
#Epoch 23: huber_loss = 16.445813909350406
#Epoch 24: huber_loss = 16.438459956220218
#Epoch 25: huber_loss = 16.431121296870213
#Epoch 26: huber_loss = 16.42379860965801
#Epoch 27: huber_loss = 16.41649028893915
#Epoch 28: huber_loss = 16.409198701381683
#Epoch 29: huber_loss = 16.401921452111786
#Epoch 30: huber_loss = 16.394659822185833
#Epoch 31: huber_loss = 16.38741318797249
#Epoch 32: huber_loss = 16.38018172096816
#Epoch 33: huber_loss = 16.37296574961926
#Epoch 34: huber_loss = 16.36576510500163
#Epoch 35: huber_loss = 16.358579074853054
#Epoch 36: huber_loss = 16.351408738110745
#Epoch 37: huber_loss = 16.344253146777017
#Epoch 38: huber_loss = 16.337112197386368
#Epoch 39: huber_loss = 16.32998634916952
#Epoch 40: huber_loss = 16.322875844669483
#Epoch 41: huber_loss = 16.31578067016034
#Epoch 42: huber_loss = 16.308699420098925
#Epoch 43: huber_loss = 16.30163342688632
#Epoch 44: huber_loss = 16.294582396135887
#Epoch 45: huber_loss = 16.28755898626211
#Epoch 46: huber_loss = 16.28047544640001
#Epoch 47: huber_loss = 16.273403990858544
#Epoch 48: huber_loss = 16.26633599188755
#Epoch 49: huber_loss = 16.259274158272007
#Epoch 50: huber_loss = 16.252216737628693
#Epoch 51: huber_loss = 16.245164091033594
#Epoch 52: huber_loss = 16.23811663793666
#Epoch 53: huber_loss = 16.231074809673288
#Epoch 54: huber_loss = 16.224036197488505
#Epoch 55: huber_loss = 16.217004282843497
#Epoch 56: huber_loss = 16.20997659099521
#Epoch 57: huber_loss = 16.202954309684824
#Epoch 58: huber_loss = 16.195937064526202
#Epoch 59: huber_loss = 16.18892476707697
#Epoch 60: huber_loss = 16.181917927275034
#Epoch 61: huber_loss = 16.174915446411994
#Epoch 62: huber_loss = 16.167918239409726
#Epoch 63: huber_loss = 16.16092608089093
#Epoch 64: huber_loss = 16.153939434285608
#Epoch 65: huber_loss = 16.146957043220894
#Epoch 66: huber_loss = 16.13998002153433
#Epoch 67: huber_loss = 16.13300782732854
#Epoch 68: huber_loss = 16.126041477957013
#Epoch 69: huber_loss = 16.119078718775114
#Epoch 70: huber_loss = 16.112122138408484
#Epoch 71: huber_loss = 16.105170149403246
#Epoch 72: huber_loss = 16.09822299353116
#Epoch 73: huber_loss = 16.091280936884385
#Epoch 74: huber_loss = 16.084344242946134
#Epoch 75: huber_loss = 16.07741223408985
#Epoch 76: huber_loss = 16.070485063166068
#Epoch 77: huber_loss = 16.063563120213942
#Epoch 78: huber_loss = 16.056645905668294
#Epoch 79: huber_loss = 16.049733855878003
#Epoch 80: huber_loss = 16.042827161450294
#Epoch 81: huber_loss = 16.035924382297146
#Epoch 82: huber_loss = 16.029027766794787
#Epoch 83: huber_loss = 16.022134729084513
#Epoch 84: huber_loss = 16.015313830764388
#Epoch 85: huber_loss = 16.00842140499957
#Epoch 86: huber_loss = 16.001477149819646
#Epoch 87: huber_loss = 15.994540542408469
#Epoch 88: huber_loss = 15.987603935335452
#Epoch 89: huber_loss = 15.980666862944851
#Epoch 90: huber_loss = 15.973728769092954
#Epoch 91: huber_loss = 15.966792559217927
#Epoch 92: huber_loss = 15.959854714971568
#Epoch 93: huber_loss = 15.952918051132222
#Epoch 94: huber_loss = 15.945980502070771
#Epoch 95: huber_loss = 15.939043304794247
#Epoch 96: huber_loss = 15.932106254943868
#Epoch 97: huber_loss = 15.925168887534667
#Epoch 98: huber_loss = 15.918231542665689
#Epoch 99: huber_loss = 15.91129489017961

In [18]:
data_x, data_y = data.T[0], data.T[1]
plt.plot(data_x, data_y, '.', label='real data')
plt.xlabel('X, number of fires')
plt.ylabel('Y, number of thefts')


Out[18]:
<matplotlib.text.Text at 0x23032b14668>

In [19]:
plt.plot(data_x, data_y, 'k.', label='real data')
plt.plot(data_x, (data_x * w_result + b_result), 'r', label='predicted data')
plt.plot(data_x, (data_x * w_result_hb + b_result_hb), 'b', label='predicted data_huber_loss')
plt.xlabel('X, number of fires')
plt.ylabel('Y, number of thefts')
plt.legend()
plt.show()


Logistic Regression Example


In [10]:
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
import time

Define parameters


In [11]:
# Define parameters for the model
learning_rate = 0.01
batch_size = 128
n_epochs = 30

Read in the data


In [12]:
# Step 1: Read in data
# using TF Learn's built in function to load MNIST data to the folder data/mnist
mnist = input_data.read_data_sets('/data/mnist', one_hot=True)


Extracting /data/mnist\train-images-idx3-ubyte.gz
Extracting /data/mnist\train-labels-idx1-ubyte.gz
Extracting /data/mnist\t10k-images-idx3-ubyte.gz
Extracting /data/mnist\t10k-labels-idx1-ubyte.gz

Create placeholders


In [13]:
# Step 2: create placeholders for features and labels
# each image in the MNIST data is of shape 28*28 = 784
# therefore, each image is represented with a 1x784 tensor
# there are 10 classes, corresponding to digits 0 - 9.
# each label is a one-hot vector.
X = tf.placeholder(tf.float32, [batch_size, 784], name='X_placeholder') 
Y = tf.placeholder(tf.float32, [batch_size, 10], name='Y_placeholder')

Create weights and bias


In [14]:
# Step 3: create weights and bias
# w is initialized to random variables with mean of 0, stddev of 0.01
# b is initialized to 0
# shape of w depends on the dimension of X and Y so that Y = tf.matmul(X, w)
# shape of b depends on Y
w = tf.Variable(tf.random_normal(shape=[784, 10], stddev=0.01), name='weights')
b = tf.Variable(tf.zeros([1, 10]), name="bias")

Define the logits (logits = Xw + b)


In [15]:
# Step 4: build model
# the model that returns the logits.
# this logits will be later passed through softmax layer
logits = tf.matmul(X, w) + b

loss function


In [16]:
# Step 5: define loss function
# use cross entropy of softmax of logits as the loss function
entropy = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=Y)
loss = tf.reduce_mean(entropy) # computes the mean over all the examples in the batch
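For reference, softmax_cross_entropy_with_logits computes, for each example in the batch, the cross entropy between the one-hot label y and the softmax of the logits z, and reduce_mean averages it over the batch of size B:

$$
\text{loss} = -\frac{1}{B} \sum_{b=1}^{B} \sum_{c=1}^{10} y_{b,c} \log \frac{e^{z_{b,c}}}{\sum_{c'} e^{z_{b,c'}}}
$$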

In [17]:
# Step 6: define training op
# using gradient descent with learning rate of 0.01 to minimize loss
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)

with tf.Session() as sess:
	# to visualize using TensorBoard
	writer = tf.summary.FileWriter('./my_graph/03/logistic_reg', sess.graph)

	start_time = time.time()
	sess.run(tf.global_variables_initializer())	
	n_batches = int(mnist.train.num_examples/batch_size)
	for i in range(n_epochs): # train the model n_epochs times
		total_loss = 0

		for _ in range(n_batches):
			X_batch, Y_batch = mnist.train.next_batch(batch_size)
			_, loss_batch = sess.run([optimizer, loss], feed_dict={X: X_batch, Y:Y_batch}) 
			total_loss += loss_batch
		print ('Average loss epoch {0}: {1}'.format(i, total_loss/n_batches))

	print('Total time: {0} seconds'.format(time.time() - start_time))

	print('Optimization Finished!') # should be around 0.35 after 25 epochs

	# test the model
	n_batches = int(mnist.test.num_examples/batch_size)
	total_correct_preds = 0
	for i in range(n_batches):
		X_batch, Y_batch = mnist.test.next_batch(batch_size)
		loss_batch, logits_batch = sess.run([loss, logits], feed_dict={X: X_batch, Y:Y_batch}) # no optimizer here, so the model is not trained on the test set
		preds = tf.nn.softmax(logits_batch)
		correct_preds = tf.equal(tf.argmax(preds, 1), tf.argmax(Y_batch, 1))
		accuracy = tf.reduce_sum(tf.cast(correct_preds, tf.float32)) # need numpy.count_nonzero(boolarr) :(
		total_correct_preds += sess.run(accuracy)	
	
	print('Accuracy {0}'.format(total_correct_preds/mnist.test.num_examples))

	writer.close()


Average loss epoch 0: 1.294086247613102
Average loss epoch 1: 0.7336066665349307
Average loss epoch 2: 0.6008130059792445
Average loss epoch 3: 0.5367587413682249
Average loss epoch 4: 0.4978690439170891
Average loss epoch 5: 0.4712763373807316
Average loss epoch 6: 0.4514035423457761
Average loss epoch 7: 0.43596616241481756
Average loss epoch 8: 0.4234519165990514
Average loss epoch 9: 0.4130490093659132
Average loss epoch 10: 0.4042133788962464
Average loss epoch 11: 0.39632375333414765
Average loss epoch 12: 0.3902723304666839
Average loss epoch 13: 0.3842911572306306
Average loss epoch 14: 0.37923672208280274
Average loss epoch 15: 0.3745601326227188
Average loss epoch 16: 0.3701485982794306
Average loss epoch 17: 0.36639440848138227
Average loss epoch 18: 0.36288573445279004
Average loss epoch 19: 0.3594875372512079
Average loss epoch 20: 0.35651893934586665
Average loss epoch 21: 0.3538673204697651
Average loss epoch 22: 0.3511017171930878
Average loss epoch 23: 0.348721202402126
Average loss epoch 24: 0.3464411957002742
Average loss epoch 25: 0.3443333965081435
Average loss epoch 26: 0.34212910819859493
Average loss epoch 27: 0.3399956278450839
Average loss epoch 28: 0.3384455337877318
Average loss epoch 29: 0.33660205847872443
Total time: 11.691677808761597 seconds
Optimization Finished!
Accuracy 0.9116
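As the comment in the test loop notes, the per-batch accuracy could also be counted with NumPy instead of building new TensorFlow ops (softmax, equal, reduce_sum) on every iteration; a minimal sketch, assuming the same logits_batch and Y_batch arrays as above:

import numpy as np

# argmax over the class axis gives the predicted and true digit for each image
# (softmax is monotone, so argmax of the logits equals argmax of the softmax)
preds = np.argmax(logits_batch, axis=1)
labels = np.argmax(Y_batch, axis=1)
# count matching predictions and accumulate, as in the loop above
total_correct_preds += np.count_nonzero(preds == labels)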

In [ ]: