In [1]:
# simple linear regression
# 2017-03-11 jkang
# Python3.5
# Tensorflow1.0.1
# ref: http://web.stanford.edu/class/cs20si/
#
# input: number of fires
# output: number of thefts

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import xlrd

In [2]:
# Start from an empty default graph so re-running this cell does not pile up
# duplicate placeholders/variables (NumFire_1, Weight_1, ...) in the graph.
tf.reset_default_graph()

data_file = 'fire_theft.xls'

# Row 0 is skipped (presumably a header row); each remaining row is one sample.
book = xlrd.open_workbook(data_file, encoding_override='utf-8')
sheet = book.sheet_by_index(0)
data = np.asarray([sheet.row_values(i) for i in range(1, sheet.nrows)],
                  dtype=np.float64)
n_samples = len(data)
# The training loop unpacks every row as (x, y), so exactly two columns are needed.
assert data.ndim == 2 and data.shape[1] == 2, \
    'expected two numeric columns, got shape {}'.format(data.shape)

# Scalar placeholders: one (fire, theft) sample is fed per optimizer step.
X = tf.placeholder(tf.float64, shape=(), name='NumFire')
Y = tf.placeholder(tf.float64, shape=(), name='NumTheft')

# Parameters are 1-element float64 vectors initialised to zero, so the
# prediction (and the loss built from it) has shape (1,).
w = tf.Variable(np.zeros(1), name='Weight')
b = tf.Variable(np.zeros(1), name='Bias')

# Linear model: Y_predict = X * w + b
Y_predict = tf.add(tf.multiply(X, w), b)
def huber_loss(labels, predictions, delta=1.0):
    """Element-wise Huber loss, which is less sensitive to outliers than squared error.

    With r = |predictions - labels|:
        0.5 * r**2                     if r < delta
        delta * r - 0.5 * delta**2     otherwise

    Args:
        labels: target values (tensor or tensor-convertible value).
        predictions: predicted values; subtracted element-wise from `labels`.
        delta: threshold where the loss switches from quadratic to linear.
            Converted to the residual's dtype, so float32 and float64 inputs
            are both supported.

    Returns:
        A tensor of per-element losses with the residual's shape and dtype.
    """
    residual = tf.abs(predictions - labels)
    # Match delta to the residual's dtype instead of forcing float64; mixing a
    # float64 delta with float32 tensors would raise a dtype mismatch.
    delta = tf.convert_to_tensor(delta, dtype=residual.dtype)
    condition = tf.less(residual, delta)
    small_res = 0.5 * tf.square(residual)
    large_res = delta * residual - 0.5 * tf.square(delta)
    return tf.where(condition, small_res, large_res)
# Training hyper-parameters (tune here).
HUBER_DELTA = 1.0
LEARNING_RATE = 0.01

# Robust Huber objective. The plain squared-error alternative would be:
#   loss = tf.square(tf.subtract(Y, Y_predict), name='loss')
# (TensorFlow 1.0 renamed tf.sub to tf.subtract.)
loss = huber_loss(Y, Y_predict, delta=HUBER_DELTA)

# Plain gradient descent; each sess.run(optimizer) applies one update step.
optimizer = tf.train.GradientDescentOptimizer(
    learning_rate=LEARNING_RATE).minimize(loss)

In [3]:
N_EPOCHS = 100

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Write the graph definition to ./graph for inspection in TensorBoard.
    writer = tf.summary.FileWriter('./graph', sess.graph)
    try:
        # Online training: one optimizer step per (x, y) sample.
        for i in range(N_EPOCHS):
            total_loss = 0
            for x, y in data:
                _, l = sess.run([optimizer, loss], feed_dict={X: x, Y: y})
                total_loss += l
            # Average of the per-sample losses fetched during this epoch
            # (shape (1,), hence the bracketed printout).
            print("Epoch {0}: {1}".format(i, total_loss / n_samples))

        w_value, b_value = sess.run([w, b])
    finally:
        # Close (and flush) the event file even if training raises or is
        # interrupted, so the writer is not leaked.
        writer.close()


Epoch 0: [ 20.78248941]
Epoch 1: [ 17.5085926]
Epoch 2: [ 17.43059022]
Epoch 3: [ 17.35258784]
Epoch 4: [ 17.27458546]
Epoch 5: [ 17.19658309]
Epoch 6: [ 17.11858071]
Epoch 7: [ 17.04057833]
Epoch 8: [ 16.96257595]
Epoch 9: [ 16.88457357]
Epoch 10: [ 16.80657119]
Epoch 11: [ 16.72856881]
Epoch 12: [ 16.65056643]
Epoch 13: [ 16.57189105]
Epoch 14: [ 16.49214522]
Epoch 15: [ 16.41331964]
Epoch 16: [ 16.30321906]
Epoch 17: [ 16.16355141]
Epoch 18: [ 16.11939643]
Epoch 19: [ 15.89864353]
Epoch 20: [ 16.00006652]
Epoch 21: [ 15.55065575]
Epoch 22: [ 15.84791581]
Epoch 23: [ 15.37867564]
Epoch 24: [ 15.71846368]
Epoch 25: [ 15.07982732]
Epoch 26: [ 15.62268616]
Epoch 27: [ 15.03241077]
Epoch 28: [ 15.48340592]
Epoch 29: [ 14.69999218]
Epoch 30: [ 15.40843224]
Epoch 31: [ 14.96007081]
Epoch 32: [ 14.97233815]
Epoch 33: [ 14.66121992]
Epoch 34: [ 15.15022668]
Epoch 35: [ 14.38269019]
Epoch 36: [ 15.06684032]
Epoch 37: [ 14.25294879]
Epoch 38: [ 14.98151321]
Epoch 39: [ 14.43444012]
Epoch 40: [ 14.81153604]
Epoch 41: [ 14.02340102]
Epoch 42: [ 14.76137576]
Epoch 43: [ 13.94231161]
Epoch 44: [ 14.67783121]
Epoch 45: [ 13.90886679]
Epoch 46: [ 14.56982981]
Epoch 47: [ 13.76349129]
Epoch 48: [ 14.48199347]
Epoch 49: [ 13.67545607]
Epoch 50: [ 14.39147976]
Epoch 51: [ 13.59499179]
Epoch 52: [ 14.29828419]
Epoch 53: [ 13.52127125]
Epoch 54: [ 14.20256893]
Epoch 55: [ 13.4554631]
Epoch 56: [ 14.07500369]
Epoch 57: [ 13.35704845]
Epoch 58: [ 13.93325481]
Epoch 59: [ 13.29660827]
Epoch 60: [ 13.78049716]
Epoch 61: [ 13.20453661]
Epoch 62: [ 13.75811527]
Epoch 63: [ 13.14788769]
Epoch 64: [ 13.3960246]
Epoch 65: [ 13.10041458]
Epoch 66: [ 13.55168312]
Epoch 67: [ 13.05141046]
Epoch 68: [ 13.05051798]
Epoch 69: [ 13.24394998]
Epoch 70: [ 13.00572995]
Epoch 71: [ 13.0008947]
Epoch 72: [ 13.00473231]
Epoch 73: [ 12.97760102]
Epoch 74: [ 12.96426423]
Epoch 75: [ 12.94953553]
Epoch 76: [ 12.93409441]
Epoch 77: [ 12.91831815]
Epoch 78: [ 12.90240773]
Epoch 79: [ 12.88646809]
Epoch 80: [ 12.87055339]
Epoch 81: [ 12.85469139]
Epoch 82: [ 12.83889614]
Epoch 83: [ 12.82317469]
Epoch 84: [ 12.79406614]
Epoch 85: [ 12.77535871]
Epoch 86: [ 12.76014887]
Epoch 87: [ 12.74410493]
Epoch 88: [ 12.73003248]
Epoch 89: [ 12.71685299]
Epoch 90: [ 12.70383218]
Epoch 91: [ 12.68992133]
Epoch 92: [ 12.67794208]
Epoch 93: [ 12.66402113]
Epoch 94: [ 12.65329018]
Epoch 95: [ 12.6397949]
Epoch 96: [ 12.62922677]
Epoch 97: [ 12.61670992]
Epoch 98: [ 12.60639625]
Epoch 99: [ 12.594726]

In [4]:
# plot the results
# Use distinct names: rebinding X/Y here would shadow the TF placeholders, and
# re-running the training cell afterwards would then use numpy arrays as
# feed_dict keys (unhashable -> TypeError).
num_fire, num_theft = data[:, 0], data[:, 1]

fig, ax = plt.subplots()
ax.plot(num_fire, num_theft, 'bo', label='Real data')
ax.plot(num_fire, num_fire * w_value + b_value, 'r', label='Predicted data')
ax.set(title='Fire vs. theft: linear fit with Huber loss',
       xlabel='Number of fires', ylabel='Number of thefts')
ax.legend()
plt.show()