TensorNetworks in Neural Networks

Here, we have a small toy example of how to use a tensor network (TN) inside a fully connected neural network.

First off, let's install tensornetwork.


In [4]:
!pip install tensornetwork

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
tf.enable_v2_behavior()
# Import tensornetwork
import tensornetwork as tn
# Set the backend to tensorflow
# (default is numpy)
tn.set_default_backend("tensorflow")


Successfully installed graphviz-0.13.2 h5py-2.10.0 tensornetwork-0.2.1

TensorNetwork layer definition

Here, we define the TensorNetwork layer we will use in place of a fully connected layer. We replace the usual dense weight matrix with a simple two-node Matrix Product Operator (MPO).

We build the contraction with TensorNetwork's Node API; the equivalent one-line ncon call is shown in a comment in the code.
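
Here is a minimal standalone sketch of that ncon contraction (the shapes match the layer we define next; the random tensors are purely illustrative). Positive index labels are contracted, negative labels become the free legs of the output.

import numpy as np
import tensornetwork as tn

x = np.random.randn(32, 32)     # a 1024-dim input vector, reshaped to a matrix
a = np.random.randn(32, 32, 2)  # first MPO core
b = np.random.randn(32, 32, 2)  # second MPO core
# x's two legs pair with the middle legs of a and b,
# and label 3 is the bond between the two cores.
result = tn.ncon([x, a, b], [[1, 2], [-1, 1, 3], [-2, 2, 3]])
print(result.shape)  # (32, 32)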


In [0]:
class TNLayer(tf.keras.layers.Layer):

  def __init__(self):
    super(TNLayer, self).__init__()
    # Create the variables for the layer.
    self.a_var = tf.Variable(tf.random.normal(shape=(32, 32, 2), stddev=1.0/32.0),
                             name="a", trainable=True)
    self.b_var = tf.Variable(tf.random.normal(shape=(32, 32, 2), stddev=1.0/32.0),
                             name="b", trainable=True)
    self.bias = tf.Variable(tf.zeros(shape=(32, 32)), name="bias", trainable=True)

  def call(self, inputs):
    # Define the contraction.
    # We break it out so we can parallelize a batch using
    # tf.vectorized_map (see below).
    def f(input_vec, a_var, b_var, bias_var):
      # Reshape to a matrix instead of a vector.
      input_vec = tf.reshape(input_vec, (32, 32))

      # Now we create the network.
      a = tn.Node(a_var)
      b = tn.Node(b_var)
      x_node = tn.Node(input_vec)
      a[1] ^ x_node[0]
      b[1] ^ x_node[1]
      a[2] ^ b[2]

      # The TN should now look like this
      #   |     |
      #   a --- b
      #    \   /
      #      x

      # Now we begin the contraction.
      c = a @ x_node
      result = (c @ b).tensor

      # To make the code shorter, we also could have used ncon.
      # The contraction above is equivalent to:
      # result = tn.ncon([input_vec, a_var, b_var], [[1, 2], [-1, 1, 3], [-2, 2, 3]])

      # Finally, add bias.
      return result + bias_var
  
    # To deal with a batch of items, we can use the tf.vectorized_map
    # function.
    # https://www.tensorflow.org/api_docs/python/tf/vectorized_map
    result = tf.vectorized_map(
        lambda vec: f(vec, self.a_var, self.b_var, self.bias), inputs)
    return tf.nn.relu(tf.reshape(result, (-1, 1024)))
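
A quick sanity check of the layer's shape contract, as a sketch: a batch of 1024-dimensional inputs comes back as a batch of 1024-dimensional outputs.

layer = TNLayer()
out = layer(tf.zeros((4, 1024)))
print(out.shape)  # (4, 1024)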

Smaller model

These two models are effectively the same, but notice how the TN layer has over 100x fewer parameters than the dense layer it replaces.


In [19]:
Dense = tf.keras.layers.Dense
fc_model = tf.keras.Sequential(
    [
     tf.keras.Input(shape=(2,)),
     Dense(1024, activation=tf.nn.relu),
     Dense(1024, activation=tf.nn.relu),
     Dense(1, activation=None)])
fc_model.summary()


Model: "sequential_6"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_15 (Dense)             (None, 1024)              3072      
_________________________________________________________________
dense_16 (Dense)             (None, 1024)              1049600   
_________________________________________________________________
dense_17 (Dense)             (None, 1)                 1025      
=================================================================
Total params: 1,053,697
Trainable params: 1,053,697
Non-trainable params: 0
_________________________________________________________________

In [27]:
tn_model = tf.keras.Sequential(
    [
     tf.keras.Input(shape=(2,)),
     Dense(1024, activation=tf.nn.relu),
     # Here, we replace the dense layer with our MPS.
     TNLayer(),
     Dense(1, activation=None)])
tn_model.summary()


Model: "sequential_8"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_20 (Dense)             (None, 1024)              3072      
_________________________________________________________________
tn_layer_4 (TNLayer)         (None, 1024)              5120      
_________________________________________________________________
dense_21 (Dense)             (None, 1)                 1025      
=================================================================
Total params: 9,217
Trainable params: 9,217
Non-trainable params: 0
_________________________________________________________________
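
The counts in the two summaries are easy to verify by hand:

# Dense middle layer: a 1024x1024 weight matrix plus a 1024-dim bias.
dense_params = 1024 * 1024 + 1024        # 1,049,600
# TN layer: two (32, 32, 2) cores plus a (32, 32) bias.
tn_params = 2 * (32 * 32 * 2) + 32 * 32  # 5,120
print(dense_params // tn_params)         # 205: the middle layer shrinks ~200x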

Training a model

You can train the TN model just as you would a normal neural network model! Here, we give an example of how to do it in Keras.


In [0]:
# Four Gaussian blobs in an XOR pattern: opposite corners share a label.
X = np.concatenate([np.random.randn(20, 2) + np.array([3, 3]),
                    np.random.randn(20, 2) + np.array([-3, -3]),
                    np.random.randn(20, 2) + np.array([-3, 3]),
                    np.random.randn(20, 2) + np.array([3, -3])])

Y = np.concatenate([np.ones(40), -np.ones(40)])

In [29]:
tn_model.compile(optimizer="adam", loss="mean_squared_error")
tn_model.fit(X, Y, epochs=300, verbose=1)


Train on 80 samples
Epoch 1/300
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/indexed_slices.py:424: UserWarning: Converting sparse IndexedSlices to a dense Tensor of unknown shape. This may consume a large amount of memory.
  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
80/80 [==============================] - 1s 8ms/sample - loss: 0.9842
Epoch 2/300
80/80 [==============================] - 0s 238us/sample - loss: 0.9290
Epoch 3/300
80/80 [==============================] - 0s 260us/sample - loss: 0.8794
...
Epoch 100/300
80/80 [==============================] - 0s 268us/sample - loss: 7.5102e-04
...
Epoch 299/300
80/80 [==============================] - 0s 232us/sample - loss: 2.1783e-06
Epoch 300/300
80/80 [==============================] - 0s 283us/sample - loss: 1.9141e-06
[log truncated: the loss decreases steadily from 0.98 to about 2e-06 over 300 epochs]
Out[29]:
<tensorflow.python.keras.callbacks.History at 0x7f5c4b7f0358>
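
Because TNLayer is an ordinary Keras layer, an explicit low-level training step also works. Here is a minimal sketch with tf.GradientTape (an equivalent manual step, not how the model above was trained):

optimizer = tf.keras.optimizers.Adam()
with tf.GradientTape() as tape:
    preds = tn_model(X.astype(np.float32))
    loss = tf.reduce_mean((preds[:, 0] - Y.astype(np.float32)) ** 2)
grads = tape.gradient(loss, tn_model.trainable_variables)
optimizer.apply_gradients(zip(grads, tn_model.trainable_variables))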

In [30]:
# Plotting code, feel free to ignore.
h = 1.0
x_min, x_max = X[:, 0].min() - 5, X[:, 0].max() + 5
y_min, y_max = X[:, 1].min() - 5, X[:, 1].max() + 5
xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                     np.arange(y_min, y_max, h))

# here "model" is your model's prediction (classification) function
Z = tn_model.predict(np.c_[xx.ravel(), yy.ravel()]) 

# Put the result into a color plot
Z = Z.reshape(xx.shape)
plt.contourf(xx, yy, Z)
plt.axis('off')

# Also plot the training points.
plt.scatter(X[:, 0], X[:, 1], c=Y, cmap=plt.cm.Paired)


Out[30]:
<matplotlib.collections.PathCollection at 0x7f5c4bdcdbe0>

Versus the fully connected model


In [31]:
fc_model.compile(optimizer="adam", loss="mean_squared_error")
fc_model.fit(X, Y, epochs=300, verbose=0)
# Plotting code, feel free to ignore.
h = 1.0
x_min, x_max = X[:, 0].min() - 5, X[:, 0].max() + 5
y_min, y_max = X[:, 1].min() - 5, X[:, 1].max() + 5
xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                     np.arange(y_min, y_max, h))

# here "model" is your model's prediction (classification) function
Z = fc_model.predict(np.c_[xx.ravel(), yy.ravel()]) 

# Put the result into a color plot
Z = Z.reshape(xx.shape)
plt.contourf(xx, yy, Z)
plt.axis('off')

# Also plot the training points.
plt.scatter(X[:, 0], X[:, 1], c=Y, cmap=plt.cm.Paired)


Out[31]:
<matplotlib.collections.PathCollection at 0x7f5c4bbec390>