Keras Model for a Simple Linear Function

In this notebook, I've created a simple Keras model to approximate a linear function. I am new to Keras and deep learning in general, and this exercise really helped me understand what's happening inside a Keras neural network model.

Trying to predict a simple linear function like this with a neural network is, of course, overkill. But using a linear function makes it easy to see how different aspects of a Keras model (like the learning rate, weight initialization, stochastic gradient descent, and dataset size) affect the performance of the model, without getting distracted by image processing or the other concepts you'd need to understand most neural network examples.

Plot inline and import all necessary libraries and functions


In [184]:
%matplotlib inline

In [185]:
import pandas as pd
import numpy as np
import seaborn as sns

from keras.layers import Dense
from keras.models import Model, Sequential
from keras import initializers

Create a dataset that approximates a linear function with some noise


In [186]:
## Set the mean, standard deviation, and size of the dataset, respectively
mu, sigma, size = 0, 4, 100

## Set the slope (m) and y-intercept (b), respectively
m, b = 2, 100

## Create a uniformly distributed set of X values between 0 and 10 and store them in a pandas DataFrame
x = np.random.uniform(0, 10, size)
df = pd.DataFrame({'x':x})

## Find the "perfect" y value corresponding to each x value given
df['y_perfect'] = df['x'].apply(lambda x: m*x+b)


## Create some noise and add it to each "perfect" y value to create a realistic y dataset
df['noise'] = np.random.normal(mu, sigma, size=(size,))
df['y'] = df['y_perfect']+df['noise']

## Plot our noisy dataset with a standard linear regression
## (note: seaborn's regplot fits and plots the linear regression by default)
ax1 = sns.regplot(x='x', y='y', data=df)


Create a callback function so we can track the progress of our predictions through epochs


In [187]:
from keras.callbacks import Callback

class PrintAndSaveWeights(Callback):
    """
    Print and save the weights after each epoch.
    """
    
    def on_train_begin(self, logs={}):
        """
        Create our weights history list when we begin training
        """
        self.weights_history = {"m":[], "b":[]}
    
    def on_epoch_end(self, epoch, logs={}):
        """
        At the end of every epoch, save and print our slope and intercept weights
        """
        ## Get the current weights (a Dense layer's weights are [kernel, bias])
        weights = self.model.layers[-1].get_weights()
        current_m = weights[0][0][0]
        current_b = weights[1][0]
        
        ## Save them to our history object
        self.weights_history['m'].append(current_m)
        self.weights_history['b'].append(current_b)
        
        ## Print them after each epoch
        print("\nm=%.2f b=%.2f\n" % (current_m, current_b))

## Initialize our callback function for use in the model later
print_save_weights = PrintAndSaveWeights()

Create our Keras model to approximate our linear function

The goal of our model will be to find the weights that best predict the outputs, given the inputs. In our simple linear example, the weights are the slope (m) and y-intercept (b) of our line.

To do so, we use a single "dense" (or "fully connected") layer with a 'linear' activation function.
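
To make the mechanics concrete, here is a minimal NumPy sketch (illustrative only, not part of the model built below) of what a one-unit dense layer with a linear activation computes:

## What Dense(1, activation='linear') computes, written out by hand:
## a (1, 1) kernel times the input, plus a bias -- exactly y = m*x + b
import numpy as np

kernel = np.array([[2.0]])  # shape (input_dim, units); plays the role of m
bias = np.array([100.0])    # shape (units,); plays the role of b

def dense_linear(x):
    """Forward pass of a one-unit dense layer with a linear activation."""
    return np.dot(x, kernel) + bias

print(dense_linear(np.array([[3.0]])))  # [[ 106.]] because 2*3 + 100 = 106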

To get a feel for how models work I tried a few different things:

  1. I tried running the model with and without explicit kernel initialization (eg Glorot, also known as Xavier, initialization)
  2. I changed the number of epochs
  3. I changed the learning rate
  4. I changed the amount of data (by adjusting the "size" parameter in the dataset creation cell)
  5. I changed the optimizer to 'Adam' (a sketch of a couple of these variations follows this list)
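
Here's a minimal sketch of what a couple of those variations look like. These lines are illustrative only (hypothetical values, not run in this notebook); they would replace the corresponding lines in the model cell below. SGD, Adam, and initializers.RandomNormal are standard parts of the Keras API.

## Hypothetical variations, for reference only:
from keras.optimizers import SGD, Adam

## (3) Pass an explicit learning rate to the optimizer at compile time
model.compile(loss='mse', optimizer=SGD(lr=0.001))

## (5) Use the Adam optimizer instead of plain stochastic gradient descent
model.compile(loss='mse', optimizer=Adam(lr=0.001))

## (1) Use an explicit initializer object instead of the 'glorot_uniform' string
Dense(1, activation='linear', input_shape=(1,),
      kernel_initializer=initializers.RandomNormal(mean=0.0, stddev=0.05))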

In [188]:
## Create our model with a single dense layer, with a linear activation function and Glorot (Xavier) weight initialization
model = Sequential([
        Dense(1, activation='linear', input_shape=(1,), kernel_initializer='glorot_uniform')
    ])

## Compile our model using a mean squared error (mse) loss function (ie least squares)
## and a stochastic gradient descent (sgd) optimizer
model.compile(loss='mse', optimizer='sgd') ## To try our model with an Adam optimizer, simply replace 'sgd' with 'Adam'

## Set our learning rate to 0.001 and print it
model.optimizer.lr.set_value(.001)
print(model.optimizer.lr.get_value())

## Fit our model to the noisy data we created above. Notes:
## The validation_split parameter reserves 20% of our data for validation (ie 80% will be used for training)
## The callbacks parameter is where we tell our model to use the callback function created above
## A batch size of 1 updates the weights after every single sample (pure stochastic gradient descent);
## I don't really know if that's the best choice here
history = model.fit(x=df['x'], y=df['y'], validation_split=0.2, batch_size=1, epochs=100, callbacks=[print_save_weights])

## As the model fits the data, you can watch below and see how our m and b parameters improve

## Save and print our final weights
predicted_m = model.get_weights()[0][0][0]
predicted_b = model.get_weights()[1][0]
print "\nm=%.2f b=%.2f\n" % (predicted_m, predicted_b)


0.0010000000475
Train on 80 samples, validate on 20 samples
Epoch 1/100
61/80 [=====================>........] - ETA: 0s - loss: 3254.8245 
m=17.40 b=5.94

80/80 [==============================] - 0s - loss: 2923.9774 - val_loss: 1940.0523
Epoch 2/100
60/80 [=====================>........] - ETA: 0s - loss: 1969.4222
m=16.03 b=9.02

80/80 [==============================] - 0s - loss: 1932.3340 - val_loss: 1705.4155
Epoch 3/100
61/80 [=====================>........] - ETA: 0s - loss: 1779.0832
m=15.05 b=12.05

80/80 [==============================] - 0s - loss: 1800.0216 - val_loss: 1561.6847
Epoch 4/100
60/80 [=====================>........] - ETA: 0s - loss: 1429.9036
m=15.90 b=15.22

80/80 [==============================] - 0s - loss: 1654.1787 - val_loss: 1598.8304
Epoch 5/100
60/80 [=====================>........] - ETA: 0s - loss: 1601.2826
m=14.87 b=18.01

80/80 [==============================] - 0s - loss: 1560.0334 - val_loss: 1426.6680
Epoch 6/100
60/80 [=====================>........] - ETA: 0s - loss: 1473.5384
m=14.45 b=20.81

80/80 [==============================] - 0s - loss: 1455.8174 - val_loss: 1338.8968
Epoch 7/100
61/80 [=====================>........] - ETA: 0s - loss: 1445.0844
m=12.93 b=23.34

80/80 [==============================] - 0s - loss: 1360.0619 - val_loss: 1194.2299
Epoch 8/100
60/80 [=====================>........] - ETA: 0s - loss: 1242.2576
m=12.86 b=26.01

80/80 [==============================] - 0s - loss: 1274.1579 - val_loss: 1126.5568
Epoch 9/100
60/80 [=====================>........] - ETA: 0s - loss: 1117.3057
m=11.90 b=28.42

80/80 [==============================] - 0s - loss: 1173.1392 - val_loss: 1051.0473
Epoch 10/100
60/80 [=====================>........] - ETA: 0s - loss: 989.9489
m=12.64 b=31.02

80/80 [==============================] - 0s - loss: 1103.4493 - val_loss: 1020.1068
Epoch 11/100
60/80 [=====================>........] - ETA: 0s - loss: 1075.7260
m=11.90 b=33.32

80/80 [==============================] - 0s - loss: 1029.6436 - val_loss: 935.0837
Epoch 12/100
61/80 [=====================>........] - ETA: 0s - loss: 938.6137 
m=11.16 b=35.52

80/80 [==============================] - 0s - loss: 958.6370 - val_loss: 867.6731
Epoch 13/100
60/80 [=====================>........] - ETA: 0s - loss: 788.3477 
m=10.73 b=37.68

80/80 [==============================] - 0s - loss: 892.4741 - val_loss: 814.2150
Epoch 14/100
60/80 [=====================>........] - ETA: 0s - loss: 832.6967 
m=10.84 b=39.86

80/80 [==============================] - 0s - loss: 835.5153 - val_loss: 771.0487
Epoch 15/100
61/80 [=====================>........] - ETA: 0s - loss: 782.6091
m=11.18 b=41.99

80/80 [==============================] - 0s - loss: 778.0384 - val_loss: 763.0698
Epoch 16/100
60/80 [=====================>........] - ETA: 0s - loss: 731.2550
m=10.02 b=43.83

80/80 [==============================] - 0s - loss: 729.3196 - val_loss: 675.7371
Epoch 17/100
60/80 [=====================>........] - ETA: 0s - loss: 707.4904
m=10.19 b=45.79

80/80 [==============================] - 0s - loss: 675.4848 - val_loss: 649.0732
Epoch 18/100
60/80 [=====================>........] - ETA: 0s - loss: 666.5468
m=9.45 b=47.56

80/80 [==============================] - 0s - loss: 634.1966 - val_loss: 596.9651
Epoch 19/100
60/80 [=====================>........] - ETA: 0s - loss: 590.0887
m=10.20 b=49.47

80/80 [==============================] - 0s - loss: 587.2458 - val_loss: 613.4780
Epoch 20/100
60/80 [=====================>........] - ETA: 0s - loss: 528.1474
m=10.11 b=51.20

80/80 [==============================] - 0s - loss: 550.8217 - val_loss: 596.0288
Epoch 21/100
60/80 [=====================>........] - ETA: 0s - loss: 578.7846
m=8.81 b=52.69

80/80 [==============================] - 0s - loss: 515.3784 - val_loss: 498.8729
Epoch 22/100
60/80 [=====================>........] - ETA: 0s - loss: 484.9126
m=8.96 b=54.32

80/80 [==============================] - 0s - loss: 478.1580 - val_loss: 482.7918
Epoch 23/100
59/80 [=====================>........] - ETA: 0s - loss: 460.6974 
m=9.26 b=55.94

80/80 [==============================] - 0s - loss: 447.1203 - val_loss: 492.6456
Epoch 24/100
60/80 [=====================>........] - ETA: 0s - loss: 442.1458
m=8.78 b=57.37

80/80 [==============================] - 0s - loss: 419.6172 - val_loss: 445.5236
Epoch 25/100
59/80 [=====================>........] - ETA: 0s - loss: 394.1116
m=8.01 b=58.72

80/80 [==============================] - 0s - loss: 392.3556 - val_loss: 395.0426
Epoch 26/100
60/80 [=====================>........] - ETA: 0s - loss: 392.2094
m=8.03 b=60.14

80/80 [==============================] - 0s - loss: 367.3931 - val_loss: 380.0095
Epoch 27/100
60/80 [=====================>........] - ETA: 0s - loss: 318.5145
m=7.55 b=61.43

80/80 [==============================] - 0s - loss: 343.5631 - val_loss: 350.4184
Epoch 28/100
60/80 [=====================>........] - ETA: 0s - loss: 330.7573
m=7.53 b=62.74

80/80 [==============================] - 0s - loss: 321.3138 - val_loss: 334.7887
Epoch 29/100
60/80 [=====================>........] - ETA: 0s - loss: 233.7687
m=7.55 b=64.01

80/80 [==============================] - 0s - loss: 299.0947 - val_loss: 324.7155
Epoch 30/100
60/80 [=====================>........] - ETA: 0s - loss: 269.2542
m=7.51 b=65.24

80/80 [==============================] - 0s - loss: 280.4874 - val_loss: 314.2952
Epoch 31/100
60/80 [=====================>........] - ETA: 0s - loss: 273.6415
m=7.49 b=66.41

80/80 [==============================] - 0s - loss: 260.6317 - val_loss: 307.8667
Epoch 32/100
60/80 [=====================>........] - ETA: 0s - loss: 216.4861
m=6.77 b=67.45

80/80 [==============================] - 0s - loss: 247.4432 - val_loss: 266.2519
Epoch 33/100
60/80 [=====================>........] - ETA: 0s - loss: 245.7895
m=6.30 b=68.46

80/80 [==============================] - 0s - loss: 225.8498 - val_loss: 248.8595
Epoch 34/100
60/80 [=====================>........] - ETA: 0s - loss: 202.3596
m=7.44 b=69.68

80/80 [==============================] - 0s - loss: 207.3104 - val_loss: 301.9131
Epoch 35/100
60/80 [=====================>........] - ETA: 0s - loss: 202.4290
m=6.27 b=70.53

80/80 [==============================] - 0s - loss: 202.9010 - val_loss: 225.6619
Epoch 36/100
60/80 [=====================>........] - ETA: 0s - loss: 182.5107
m=5.81 b=71.44

80/80 [==============================] - 0s - loss: 187.0969 - val_loss: 212.3431
Epoch 37/100
60/80 [=====================>........] - ETA: 0s - loss: 169.5725
m=6.64 b=72.51

80/80 [==============================] - 0s - loss: 174.2677 - val_loss: 233.3112
Epoch 38/100
60/80 [=====================>........] - ETA: 0s - loss: 166.1943
m=6.07 b=73.34

80/80 [==============================] - 0s - loss: 167.8694 - val_loss: 198.6565
Epoch 39/100
60/80 [=====================>........] - ETA: 0s - loss: 161.0349
m=5.59 b=74.17

80/80 [==============================] - 0s - loss: 155.5988 - val_loss: 181.5560
Epoch 40/100
60/80 [=====================>........] - ETA: 0s - loss: 138.0464
m=5.71 b=75.06

80/80 [==============================] - 0s - loss: 147.6375 - val_loss: 176.5084
Epoch 41/100
60/80 [=====================>........] - ETA: 0s - loss: 131.3308
m=5.55 b=75.86

80/80 [==============================] - 0s - loss: 137.5344 - val_loss: 166.9056
Epoch 42/100
61/80 [=====================>........] - ETA: 0s - loss: 133.1601
m=5.65 b=76.69

80/80 [==============================] - 0s - loss: 129.0335 - val_loss: 165.9233
Epoch 43/100
60/80 [=====================>........] - ETA: 0s - loss: 127.7980
m=5.48 b=77.43

80/80 [==============================] - 0s - loss: 121.1999 - val_loss: 156.1030
Epoch 44/100
60/80 [=====================>........] - ETA: 0s - loss: 114.3041
m=5.10 b=78.12

80/80 [==============================] - 0s - loss: 114.4090 - val_loss: 142.2615
Epoch 45/100
60/80 [=====================>........] - ETA: 0s - loss: 118.1772
m=5.04 b=78.84

80/80 [==============================] - 0s - loss: 107.1359 - val_loss: 136.3020
Epoch 46/100
60/80 [=====================>........] - ETA: 0s - loss: 101.0124
m=4.82 b=79.50

80/80 [==============================] - 0s - loss: 99.5082 - val_loss: 128.7621
Epoch 47/100
61/80 [=====================>........] - ETA: 0s - loss: 98.4678
m=5.38 b=80.26

80/80 [==============================] - 0s - loss: 94.1012 - val_loss: 144.8276
Epoch 48/100
60/80 [=====================>........] - ETA: 0s - loss: 84.5120
m=4.88 b=80.84

80/80 [==============================] - 0s - loss: 91.0142 - val_loss: 121.4670
Epoch 49/100
60/80 [=====================>........] - ETA: 0s - loss: 83.4029
m=5.05 b=81.49

80/80 [==============================] - 0s - loss: 83.5654 - val_loss: 125.8072
Epoch 50/100
60/80 [=====================>........] - ETA: 0s - loss: 78.2002 
m=4.43 b=81.98

80/80 [==============================] - 0s - loss: 78.5084 - val_loss: 107.4848
Epoch 51/100
60/80 [=====================>........] - ETA: 0s - loss: 82.8408
m=4.96 b=82.64

80/80 [==============================] - 0s - loss: 74.4249 - val_loss: 120.4673
Epoch 52/100
60/80 [=====================>........] - ETA: 0s - loss: 70.7674
m=4.61 b=83.15

80/80 [==============================] - 0s - loss: 72.3379 - val_loss: 104.6928
Epoch 53/100
60/80 [=====================>........] - ETA: 0s - loss: 63.8046
m=4.43 b=83.67

80/80 [==============================] - 0s - loss: 68.0940 - val_loss: 97.8390
Epoch 54/100
60/80 [=====================>........] - ETA: 0s - loss: 65.0319
m=4.45 b=84.21

80/80 [==============================] - 0s - loss: 64.5154 - val_loss: 96.5811
Epoch 55/100
60/80 [=====================>........] - ETA: 0s - loss: 59.7439
m=4.20 b=84.68

80/80 [==============================] - 0s - loss: 60.7724 - val_loss: 88.7473
Epoch 56/100
60/80 [=====================>........] - ETA: 0s - loss: 60.2357
m=4.42 b=85.20

80/80 [==============================] - 0s - loss: 56.4120 - val_loss: 93.6560
Epoch 57/100
60/80 [=====================>........] - ETA: 0s - loss: 49.9111
m=4.23 b=85.65

80/80 [==============================] - 0s - loss: 54.8564 - val_loss: 85.9741
Epoch 58/100
60/80 [=====================>........] - ETA: 0s - loss: 48.9606
m=4.05 b=86.08

80/80 [==============================] - 0s - loss: 52.1978 - val_loss: 80.1632
Epoch 59/100
60/80 [=====================>........] - ETA: 0s - loss: 49.3268 
m=3.89 b=86.50

80/80 [==============================] - 0s - loss: 49.5656 - val_loss: 75.7317
Epoch 60/100
60/80 [=====================>........] - ETA: 0s - loss: 47.1471
m=3.98 b=86.93

80/80 [==============================] - 0s - loss: 47.0151 - val_loss: 75.6728
Epoch 61/100
60/80 [=====================>........] - ETA: 0s - loss: 42.2812
m=3.98 b=87.35

80/80 [==============================] - 0s - loss: 44.4397 - val_loss: 74.9013
Epoch 62/100
60/80 [=====================>........] - ETA: 0s - loss: 42.9502
m=3.93 b=87.74

80/80 [==============================] - 0s - loss: 42.4925 - val_loss: 72.4803
Epoch 63/100
60/80 [=====================>........] - ETA: 0s - loss: 46.7105
m=3.54 b=88.06

80/80 [==============================] - 0s - loss: 40.6862 - val_loss: 65.0597
Epoch 64/100
61/80 [=====================>........] - ETA: 0s - loss: 42.3674 
m=3.71 b=88.46

80/80 [==============================] - 0s - loss: 39.7936 - val_loss: 65.3012
Epoch 65/100
60/80 [=====================>........] - ETA: 0s - loss: 35.7763
m=3.79 b=88.83

80/80 [==============================] - 0s - loss: 37.4369 - val_loss: 66.6885
Epoch 66/100
60/80 [=====================>........] - ETA: 0s - loss: 38.1030
m=3.51 b=89.14

80/80 [==============================] - 0s - loss: 36.0938 - val_loss: 59.9070
Epoch 67/100
60/80 [=====================>........] - ETA: 0s - loss: 35.5782
m=3.44 b=89.46

80/80 [==============================] - 0s - loss: 34.3016 - val_loss: 57.8604
Epoch 68/100
60/80 [=====================>........] - ETA: 0s - loss: 35.1737
m=3.35 b=89.77

80/80 [==============================] - 0s - loss: 33.0113 - val_loss: 55.8120
Epoch 69/100
60/80 [=====================>........] - ETA: 0s - loss: 32.4944
m=3.51 b=90.10

80/80 [==============================] - 0s - loss: 31.8562 - val_loss: 57.2174
Epoch 70/100
60/80 [=====================>........] - ETA: 0s - loss: 26.2677
m=3.49 b=90.40

80/80 [==============================] - 0s - loss: 30.7014 - val_loss: 56.3331
Epoch 71/100
60/80 [=====================>........] - ETA: 0s - loss: 31.0274
m=3.43 b=90.67

80/80 [==============================] - 0s - loss: 29.7172 - val_loss: 54.4210
Epoch 72/100
60/80 [=====================>........] - ETA: 0s - loss: 28.7420
m=3.31 b=90.93

80/80 [==============================] - 0s - loss: 28.5992 - val_loss: 51.4180
Epoch 73/100
60/80 [=====================>........] - ETA: 0s - loss: 28.1412
m=3.24 b=91.19

80/80 [==============================] - 0s - loss: 27.5390 - val_loss: 49.6822
Epoch 74/100
60/80 [=====================>........] - ETA: 0s - loss: 24.8776
m=3.51 b=91.49

80/80 [==============================] - 0s - loss: 25.8587 - val_loss: 56.5398
Epoch 75/100
60/80 [=====================>........] - ETA: 0s - loss: 27.3154
m=3.22 b=91.70

80/80 [==============================] - 0s - loss: 26.2401 - val_loss: 48.0600
Epoch 76/100
58/80 [====================>.........] - ETA: 0s - loss: 24.6736
m=3.36 b=91.97

80/80 [==============================] - 0s - loss: 24.9628 - val_loss: 51.5836
Epoch 77/100
59/80 [=====================>........] - ETA: 0s - loss: 24.4853
m=3.09 b=92.16

80/80 [==============================] - 0s - loss: 24.7051 - val_loss: 45.0907
Epoch 78/100
60/80 [=====================>........] - ETA: 0s - loss: 24.8046
m=3.13 b=92.39

80/80 [==============================] - 0s - loss: 24.1694 - val_loss: 45.1782
Epoch 79/100
60/80 [=====================>........] - ETA: 0s - loss: 24.1394
m=3.07 b=92.61

80/80 [==============================] - 0s - loss: 23.4496 - val_loss: 43.8457
Epoch 80/100
60/80 [=====================>........] - ETA: 0s - loss: 17.7516
m=3.18 b=92.83

80/80 [==============================] - 0s - loss: 22.4935 - val_loss: 46.1222
Epoch 81/100
59/80 [=====================>........] - ETA: 0s - loss: 21.2255
m=2.96 b=93.01

80/80 [==============================] - 0s - loss: 22.0972 - val_loss: 41.2891
Epoch 82/100
60/80 [=====================>........] - ETA: 0s - loss: 20.9124
m=2.98 b=93.21

80/80 [==============================] - 0s - loss: 21.7000 - val_loss: 41.2505
Epoch 83/100
59/80 [=====================>........] - ETA: 0s - loss: 21.6255
m=3.03 b=93.41

80/80 [==============================] - 0s - loss: 21.4863 - val_loss: 41.9179
Epoch 84/100
59/80 [=====================>........] - ETA: 0s - loss: 22.3901
m=2.91 b=93.58

80/80 [==============================] - 0s - loss: 20.9283 - val_loss: 39.4799
Epoch 85/100
60/80 [=====================>........] - ETA: 0s - loss: 20.2719
m=3.05 b=93.77

80/80 [==============================] - 0s - loss: 19.4856 - val_loss: 42.3431
Epoch 86/100
59/80 [=====================>........] - ETA: 0s - loss: 20.9492   
m=3.04 b=93.94

80/80 [==============================] - 0s - loss: 19.9136 - val_loss: 42.2160
Epoch 87/100
59/80 [=====================>........] - ETA: 0s - loss: 21.9033
m=3.03 b=94.10

80/80 [==============================] - 0s - loss: 19.7118 - val_loss: 42.0752
Epoch 88/100
59/80 [=====================>........] - ETA: 0s - loss: 17.6468
m=2.84 b=94.23

80/80 [==============================] - 0s - loss: 18.9890 - val_loss: 37.3732
Epoch 89/100
59/80 [=====================>........] - ETA: 0s - loss: 18.9624
m=2.76 b=94.37

80/80 [==============================] - 0s - loss: 19.1146 - val_loss: 35.9136
Epoch 90/100
60/80 [=====================>........] - ETA: 0s - loss: 18.5635
m=2.84 b=94.53

80/80 [==============================] - 0s - loss: 18.7964 - val_loss: 36.9875
Epoch 91/100
60/80 [=====================>........] - ETA: 0s - loss: 17.4610
m=2.85 b=94.67

80/80 [==============================] - 0s - loss: 18.2107 - val_loss: 37.1862
Epoch 92/100
60/80 [=====================>........] - ETA: 0s - loss: 16.6464
m=2.78 b=94.80

80/80 [==============================] - 0s - loss: 18.1023 - val_loss: 35.6384
Epoch 93/100
60/80 [=====================>........] - ETA: 0s - loss: 17.8673
m=2.84 b=94.94

80/80 [==============================] - 0s - loss: 17.7690 - val_loss: 37.0773
Epoch 94/100
60/80 [=====================>........] - ETA: 0s - loss: 15.0662
m=2.74 b=95.05

80/80 [==============================] - 0s - loss: 17.9816 - val_loss: 34.6711
Epoch 95/100
60/80 [=====================>........] - ETA: 0s - loss: 16.7920
m=2.70 b=95.17

80/80 [==============================] - 0s - loss: 17.5933 - val_loss: 33.9587
Epoch 96/100
60/80 [=====================>........] - ETA: 0s - loss: 15.3865
m=2.79 b=95.30

80/80 [==============================] - 0s - loss: 17.0875 - val_loss: 35.8242
Epoch 97/100
60/80 [=====================>........] - ETA: 0s - loss: 18.0846
m=2.72 b=95.41

80/80 [==============================] - 0s - loss: 17.4079 - val_loss: 34.1651
Epoch 98/100
60/80 [=====================>........] - ETA: 0s - loss: 16.5372
m=2.42 b=95.47

80/80 [==============================] - 0s - loss: 16.2869 - val_loss: 31.2707
Epoch 99/100
60/80 [=====================>........] - ETA: 0s - loss: 17.9679
m=2.54 b=95.60

80/80 [==============================] - 0s - loss: 17.3852 - val_loss: 31.3174
Epoch 100/100
59/80 [=====================>........] - ETA: 0s - loss: 16.1793
m=2.60 b=95.71

80/80 [==============================] - 0s - loss: 17.0351 - val_loss: 31.8478

m=2.60 b=95.71

Plot our model's slope (m) and y-intercept (b) guesses over each epoch

Seeing this plot really helped me understand how the model improves its guesses from epoch to epoch.


In [189]:
import matplotlib.pyplot as plt

plt.plot(print_save_weights.weights_history['m'])
plt.plot(print_save_weights.weights_history['b'])
plt.title('Predicted Weights')
plt.ylabel('weights')
plt.xlabel('epoch')
plt.legend(['m', 'b'], loc='upper left')
plt.show()


Plot our model's loss function over time

Seeing this plot really helped me understand how the model's loss improves from epoch to epoch.


In [190]:
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'valid'], loc='upper right')
plt.show()


Plot our model's prediction over the data and real line


In [191]:
## Create our predicted y's based on the model
df['y_predicted'] = df['x'].apply(lambda x: predicted_m*x + predicted_b)

## Plot the original data with a standard linear regression
ax1 = sns.regplot(x='x', y='y', data=df, label='real')

## Plot our predicted line based on our Keras model's slope and y-intercept
ax2 = sns.regplot(x='x', y='y_predicted', data=df, scatter=False, label='predicted')
ax2.legend(loc="upper left")


Out[191]:
<matplotlib.legend.Legend at 0x7f5f84e2f150>

Conclusion

As we would expect, the standard linear regression does a slightly better job of approximating our linear data (duh). But we've also done something pretty cool with our Keras model: without explicitly telling it that our data was approximately linear, it sort of "learned" that it was. This means we can use the same neural network technique on problems that aren't linear (like image processing or speech recognition).
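
As a quick sanity check (illustrative only, not part of the original run), you could compare the model's final weights directly against the closed-form least-squares fit from NumPy:

## np.polyfit with degree 1 returns the (slope, intercept) of the least-squares line
ls_m, ls_b = np.polyfit(df['x'], df['y'], 1)
print("least squares: m=%.2f b=%.2f" % (ls_m, ls_b))
print("keras model:   m=%.2f b=%.2f" % (predicted_m, predicted_b))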


In [ ]: