Simple linear regression

Solution using an MLP and a linear-algebra (pseudo-inverse) method


In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# numpy package
import numpy as np

# keras modules
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.utils import plot_model

# generate x data: 10 evenly spaced samples in [-1, 1), as a column vector
x = np.arange(-1, 1, 0.2)
x = np.reshape(x, [-1, 1])

# generate ground-truth y data from the line y = 2x + 3
y = 2 * x + 3

# True if noise is added to y
is_noisy = False

# add observation noise to y if enabled
# (fix: the original perturbed x, contradicting the comment above;
# regression observation noise belongs on the target y = 2x + 3 + eps)
if is_noisy:
    noise = np.random.uniform(-0.1, 0.1, y.shape)
    y = y + noise

# deep learning method
# Linear regression via a 2-layer MLP. There are no activation
# functions, so the network is an affine map and can fit y = 2x + 3.
model = Sequential([
    # 1st layer: 8 perceptrons on a 1-dim input
    Dense(units=8, input_dim=1),
    # 2nd layer: a single unit producing the 1-dim output
    Dense(units=1),
])
# sanity-check the architecture
model.summary()
# render the network diagram to an image file
plot_model(model, to_file='linear-model.png', show_shapes=True)
# mean squared error loss, optimized with stochastic gradient descent
model.compile(loss='mse', optimizer='sgd')
# train for 100 full passes over the dataset with mini-batches of 4
model.fit(x, y, epochs=100, batch_size=4)
# simple validation: predict on the training inputs
ypred = model.predict(x)

# linear algebra method
# Solve the least-squares problem A k = y for k = [slope, intercept]
# via the Moore-Penrose pseudo-inverse.
# design matrix: each row is [x_i, 1]
A = np.concatenate([x, np.ones_like(x)], axis=1)
k = np.linalg.pinv(A) @ y
print("k (Linear Algebra Method):")
print(k)
# closed-form prediction from the fitted coefficients
yla = A @ k

# side-by-side comparison: ground truth vs closed-form vs MLP
outputs = np.concatenate([y, yla, ypred], axis=1)
print("Ground Truth, Linear Alg Prediction, MLP Prediction")
print(outputs)


Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 8)                 16        
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 9         
=================================================================
Total params: 25
Trainable params: 25
Non-trainable params: 0
_________________________________________________________________
Failed to import pydot. You must install pydot and graphviz for `pydotprint` to work.
Train on 10 samples
Epoch 1/100
10/10 [==============================] - 1s 109ms/sample - loss: 8.4181
Epoch 2/100
10/10 [==============================] - 0s 2ms/sample - loss: 5.7370
Epoch 3/100
10/10 [==============================] - 0s 1ms/sample - loss: 3.8146
Epoch 4/100
10/10 [==============================] - 0s 1ms/sample - loss: 2.6083
Epoch 5/100
10/10 [==============================] - 0s 1ms/sample - loss: 1.8234
Epoch 6/100
10/10 [==============================] - 0s 1ms/sample - loss: 1.3077
Epoch 7/100
10/10 [==============================] - 0s 1ms/sample - loss: 0.9801
Epoch 8/100
10/10 [==============================] - 0s 2ms/sample - loss: 0.6866
Epoch 9/100
10/10 [==============================] - 0s 2ms/sample - loss: 0.4787
Epoch 10/100
10/10 [==============================] - 0s 1ms/sample - loss: 0.3478
Epoch 11/100
10/10 [==============================] - 0s 1ms/sample - loss: 0.2568
Epoch 12/100
10/10 [==============================] - 0s 2ms/sample - loss: 0.1964
Epoch 13/100
10/10 [==============================] - 0s 2ms/sample - loss: 0.1499
Epoch 14/100
10/10 [==============================] - 0s 1ms/sample - loss: 0.1252
Epoch 15/100
10/10 [==============================] - 0s 2ms/sample - loss: 0.0894
Epoch 16/100
10/10 [==============================] - 0s 1ms/sample - loss: 0.0693
Epoch 17/100
10/10 [==============================] - 0s 2ms/sample - loss: 0.0516
Epoch 18/100
10/10 [==============================] - 0s 2ms/sample - loss: 0.0411
Epoch 19/100
10/10 [==============================] - 0s 2ms/sample - loss: 0.0288
Epoch 20/100
10/10 [==============================] - 0s 2ms/sample - loss: 0.0201
Epoch 21/100
10/10 [==============================] - 0s 2ms/sample - loss: 0.0157
Epoch 22/100
10/10 [==============================] - 0s 2ms/sample - loss: 0.0115
Epoch 23/100
10/10 [==============================] - 0s 1ms/sample - loss: 0.0090
Epoch 24/100
10/10 [==============================] - 0s 2ms/sample - loss: 0.0068
Epoch 25/100
10/10 [==============================] - 0s 2ms/sample - loss: 0.0047
Epoch 26/100
10/10 [==============================] - 0s 2ms/sample - loss: 0.0036
Epoch 27/100
10/10 [==============================] - 0s 1ms/sample - loss: 0.0026
Epoch 28/100
10/10 [==============================] - 0s 2ms/sample - loss: 0.0019
Epoch 29/100
10/10 [==============================] - 0s 1ms/sample - loss: 0.0014
Epoch 30/100
10/10 [==============================] - 0s 2ms/sample - loss: 0.0011
Epoch 31/100
10/10 [==============================] - 0s 2ms/sample - loss: 8.3436e-04
Epoch 32/100
10/10 [==============================] - 0s 2ms/sample - loss: 5.4558e-04
Epoch 33/100
10/10 [==============================] - 0s 2ms/sample - loss: 3.7619e-04
Epoch 34/100
10/10 [==============================] - 0s 2ms/sample - loss: 2.7842e-04
Epoch 35/100
10/10 [==============================] - 0s 2ms/sample - loss: 2.1498e-04
Epoch 36/100
10/10 [==============================] - 0s 1ms/sample - loss: 1.5641e-04
Epoch 37/100
10/10 [==============================] - 0s 2ms/sample - loss: 1.1316e-04
Epoch 38/100
10/10 [==============================] - 0s 1ms/sample - loss: 8.4790e-05
Epoch 39/100
10/10 [==============================] - 0s 1ms/sample - loss: 6.0916e-05
Epoch 40/100
10/10 [==============================] - 0s 2ms/sample - loss: 4.6339e-05
Epoch 41/100
10/10 [==============================] - 0s 2ms/sample - loss: 3.5867e-05
Epoch 42/100
10/10 [==============================] - 0s 1ms/sample - loss: 2.6099e-05
Epoch 43/100
10/10 [==============================] - 0s 1ms/sample - loss: 1.9969e-05
Epoch 44/100
10/10 [==============================] - 0s 2ms/sample - loss: 1.3725e-05
Epoch 45/100
10/10 [==============================] - 0s 2ms/sample - loss: 1.0290e-05
Epoch 46/100
10/10 [==============================] - 0s 1ms/sample - loss: 7.2686e-06
Epoch 47/100
10/10 [==============================] - 0s 1ms/sample - loss: 5.5383e-06
Epoch 48/100
10/10 [==============================] - 0s 2ms/sample - loss: 4.2897e-06
Epoch 49/100
10/10 [==============================] - 0s 2ms/sample - loss: 3.2522e-06
Epoch 50/100
10/10 [==============================] - 0s 2ms/sample - loss: 2.3353e-06
Epoch 51/100
10/10 [==============================] - 0s 2ms/sample - loss: 1.7212e-06
Epoch 52/100
10/10 [==============================] - 0s 1ms/sample - loss: 1.2761e-06
Epoch 53/100
10/10 [==============================] - 0s 1ms/sample - loss: 8.6673e-07
Epoch 54/100
10/10 [==============================] - 0s 1ms/sample - loss: 6.3416e-07
Epoch 55/100
10/10 [==============================] - 0s 992us/sample - loss: 4.9319e-07
Epoch 56/100
10/10 [==============================] - 0s 938us/sample - loss: 3.7122e-07
Epoch 57/100
10/10 [==============================] - 0s 1ms/sample - loss: 2.7542e-07
Epoch 58/100
10/10 [==============================] - 0s 942us/sample - loss: 2.0714e-07
Epoch 59/100
10/10 [==============================] - 0s 1ms/sample - loss: 1.4989e-07
Epoch 60/100
10/10 [==============================] - 0s 1ms/sample - loss: 1.1624e-07
Epoch 61/100
10/10 [==============================] - 0s 908us/sample - loss: 8.4808e-08
Epoch 62/100
10/10 [==============================] - 0s 1ms/sample - loss: 6.4244e-08
Epoch 63/100
10/10 [==============================] - 0s 2ms/sample - loss: 4.4389e-08
Epoch 64/100
10/10 [==============================] - 0s 2ms/sample - loss: 3.1544e-08
Epoch 65/100
10/10 [==============================] - 0s 2ms/sample - loss: 2.3182e-08
Epoch 66/100
10/10 [==============================] - 0s 1ms/sample - loss: 1.6740e-08
Epoch 67/100
10/10 [==============================] - 0s 1ms/sample - loss: 1.1735e-08
Epoch 68/100
10/10 [==============================] - 0s 2ms/sample - loss: 8.9919e-09
Epoch 69/100
10/10 [==============================] - 0s 1ms/sample - loss: 6.3282e-09
Epoch 70/100
10/10 [==============================] - 0s 2ms/sample - loss: 4.6842e-09
Epoch 71/100
10/10 [==============================] - 0s 1ms/sample - loss: 3.5122e-09
Epoch 72/100
10/10 [==============================] - 0s 2ms/sample - loss: 2.6531e-09
Epoch 73/100
10/10 [==============================] - 0s 2ms/sample - loss: 1.9445e-09
Epoch 74/100
10/10 [==============================] - 0s 1ms/sample - loss: 1.4403e-09
Epoch 75/100
10/10 [==============================] - 0s 1ms/sample - loss: 1.0559e-09
Epoch 76/100
10/10 [==============================] - 0s 1ms/sample - loss: 7.3598e-10
Epoch 77/100
10/10 [==============================] - 0s 1ms/sample - loss: 5.5197e-10
Epoch 78/100
10/10 [==============================] - 0s 2ms/sample - loss: 4.1046e-10
Epoch 79/100
10/10 [==============================] - 0s 1ms/sample - loss: 3.0287e-10
Epoch 80/100
10/10 [==============================] - 0s 1ms/sample - loss: 2.2959e-10
Epoch 81/100
10/10 [==============================] - 0s 2ms/sample - loss: 1.7719e-10
Epoch 82/100
10/10 [==============================] - 0s 1ms/sample - loss: 1.2162e-10
Epoch 83/100
10/10 [==============================] - 0s 1ms/sample - loss: 8.6062e-11
Epoch 84/100
10/10 [==============================] - 0s 2ms/sample - loss: 6.5083e-11
Epoch 85/100
10/10 [==============================] - 0s 1ms/sample - loss: 4.5488e-11
Epoch 86/100
10/10 [==============================] - 0s 1ms/sample - loss: 3.3987e-11
Epoch 87/100
10/10 [==============================] - 0s 1ms/sample - loss: 2.5058e-11
Epoch 88/100
10/10 [==============================] - 0s 1ms/sample - loss: 1.8902e-11
Epoch 89/100
10/10 [==============================] - 0s 1ms/sample - loss: 1.2513e-11
Epoch 90/100
10/10 [==============================] - 0s 2ms/sample - loss: 9.1873e-12
Epoch 91/100
10/10 [==============================] - 0s 1ms/sample - loss: 7.9140e-12
Epoch 92/100
10/10 [==============================] - 0s 1ms/sample - loss: 6.7146e-12
Epoch 93/100
10/10 [==============================] - 0s 1ms/sample - loss: 6.2997e-12
Epoch 94/100
10/10 [==============================] - 0s 1ms/sample - loss: 5.3575e-12
Epoch 95/100
10/10 [==============================] - 0s 1ms/sample - loss: 4.4082e-12
Epoch 96/100
10/10 [==============================] - 0s 1ms/sample - loss: 4.0487e-12
Epoch 97/100
10/10 [==============================] - 0s 1ms/sample - loss: 3.9037e-12
Epoch 98/100
10/10 [==============================] - 0s 3ms/sample - loss: 3.7019e-12
Epoch 99/100
10/10 [==============================] - 0s 1ms/sample - loss: 3.0823e-12
Epoch 100/100
10/10 [==============================] - 0s 1ms/sample - loss: 3.2315e-12
k (Linear Algebra Method):
[[2.]
 [3.]]
Ground Truth, Linear Alg Prediction, MLP Prediction
[[1.         1.         1.00000286]
 [1.4        1.4        1.400002  ]
 [1.8        1.8        1.80000138]
 [2.2        2.2        2.200001  ]
 [2.6        2.6        2.60000038]
 [3.         3.         2.99999976]
 [3.4        3.4        3.39999938]
 [3.8        3.8        3.79999876]
 [4.2        4.2        4.1999979 ]
 [4.6        4.6        4.59999752]]