Quick Intro to Keras Functional API

Preamble: all models and layers are callable

from keras.layers import Input, Dense
from keras.models import Model

# this returns a tensor
inputs = Input(shape=(784,))

# a layer instance is callable on a tensor, and returns a tensor
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
predictions = Dense(10, activation='softmax')(x)

# this creates a model that includes
# the Input layer and three Dense layers
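model = Model(inputs=inputs, outputs=predictions)
model.compile(optimizer='rmsprop',
              loss='categorical_crossentropy',
              metrics=['accuracy'])
model.fit(data, labels)  # starts training; data: (n, 784) array, labels: one-hot (n, 10) array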
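Because a layer is just a callable, the same pipeline can be mimicked outside Keras. The following NumPy sketch assumes that a Dense layer computes activation(x @ W + b), uses random weights, and defines a hypothetical `ToyDense` class (not part of Keras) that, unlike Keras, needs its input size up front. Chaining three instances reproduces the 784 → 64 → 64 → 10 model above and counts its weights and biases.

```python
import numpy as np

rng = np.random.default_rng(0)

def relu(z):
    return np.maximum(z, 0.0)

def softmax(z):
    # subtract the row maximum for numerical stability
    e = np.exp(z - z.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

class ToyDense:
    """Hypothetical stand-in for keras Dense: activation(x @ W + b)."""

    def __init__(self, in_dim, units, activation):
        self.W = rng.normal(scale=0.05, size=(in_dim, units))
        self.b = np.zeros(units)
        self.activation = activation

    def __call__(self, x):
        # calling the layer on an array returns a new array, like a Keras layer on a tensor
        return self.activation(x @ self.W + self.b)

    @property
    def n_params(self):
        return self.W.size + self.b.size

layers = [ToyDense(784, 64, relu), ToyDense(64, 64, relu), ToyDense(64, 10, softmax)]

def model(x):
    # the "model" is itself a callable that chains the layer callables
    for layer in layers:
        x = layer(x)
    return x

batch = rng.random((5, 784))  # 5 fake flattened 28x28 images
predictions = model(batch)
total_params = sum(layer.n_params for layer in layers)
print(predictions.shape)  # (5, 10)
print(total_params)       # 55050
```

The count, 55,050, should match the total that `model.summary()` reports for the Keras model: each Dense layer contributes in_dim × units weights plus units biases.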

Multi-Input Networks

Keras Merge Layer

Here's a good use case for the functional API: models with multiple inputs and outputs.

The functional API makes it easy to manipulate a large number of intertwined data streams.

Let's consider the following model.

from keras.layers import Dense, Input, concatenate
from keras.models import Model

left_input = Input(shape=(784,), name='left_input')
left_branch = Dense(32, name='left_branch')(left_input)

right_input = Input(shape=(784,), name='right_input')
right_branch = Dense(32, name='right_branch')(right_input)

x = concatenate([left_branch, right_branch])
predictions = Dense(10, activation='softmax', name='main_output')(x)

model = Model(inputs=[left_input, right_input], outputs=predictions)
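A NumPy sketch of the same two-branch forward pass, with random weights and a hypothetical `linear_dense` helper standing in for `Dense`. Since the branch layers set no activation, Keras applies a linear one, which the helper mirrors. It shows that `concatenate` stacks the two 32-unit outputs into 64 features and that the model holds 2 × (784 × 32 + 32) + (64 × 10 + 10) = 50,890 parameters.

```python
import numpy as np

rng = np.random.default_rng(1)

def linear_dense(x, units):
    """Hypothetical stand-in for Dense(units) with the default linear activation."""
    W = rng.normal(scale=0.05, size=(x.shape[1], units))
    b = np.zeros(units)
    return x @ W + b, W.size + b.size

def softmax(z):
    e = np.exp(z - z.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

left_input = rng.random((4, 784))   # batch of 4 samples for the left input
right_input = rng.random((4, 784))  # batch of 4 samples for the right input

left_branch, p_left = linear_dense(left_input, 32)
right_branch, p_right = linear_dense(right_input, 32)

# concatenate joins along the last axis by default: 32 + 32 = 64 features
x = np.concatenate([left_branch, right_branch], axis=-1)

logits, p_main = linear_dense(x, 10)
predictions = softmax(logits)

total_params = p_left + p_right + p_main
print(x.shape)            # (4, 64)
print(predictions.shape)  # (4, 10)
print(total_params)       # 50890
```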

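The resulting network has two parallel branches, one per input; their 32-unit outputs are concatenated and fed to the 10-way softmax output layer.

Such a two-branch model can then be trained, for example, with: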

model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit([input_data_1, input_data_2], targets)  # we pass one data array per model input
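The fit call above assumes one input array per model input and targets in the format the loss expects. Below is a sketch of matching dummy data, where the random values and `n_samples` are placeholders: categorical_crossentropy needs one-hot targets, built here with `np.eye`.

```python
import numpy as np

rng = np.random.default_rng(2)
n_samples = 8  # placeholder dataset size

# one array per model input; the first (sample) axis must match across all of them
input_data_1 = rng.random((n_samples, 784)).astype("float32")
input_data_2 = rng.random((n_samples, 784)).astype("float32")

# categorical_crossentropy expects one-hot targets of shape (n_samples, 10)
labels = rng.integers(0, 10, size=n_samples)
targets = np.eye(10, dtype="float32")[labels]

assert input_data_1.shape[0] == input_data_2.shape[0] == targets.shape[0]
print(targets.shape)  # (8, 10)
```

`keras.utils.to_categorical(labels, 10)` produces the same one-hot matrix; alternatively, `loss='sparse_categorical_crossentropy'` accepts the integer labels directly. Because the inputs were created with `name=...`, Keras also accepts a dict keyed by those names, e.g. `model.fit({'left_input': input_data_1, 'right_input': input_data_2}, targets)`.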

Try it yourself

Step 1: Get Data - MNIST

In [ ]:
# load the MNIST data, as in the earlier exercise on MNIST with fully connected (FC) nets

In [ ]:
# %load ../solutions/sol_52.py
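If the earlier exercise is not at hand, here is a sketch of the usual preprocessing, assuming the data come from `keras.datasets.mnist.load_data()` as 28×28 `uint8` images with integer labels; `preprocess` is a hypothetical helper. It flattens each image into the 784-vector that `Input(shape=(784,))` expects, scales pixels to [0, 1] and one-hot encodes the labels. Synthetic arrays stand in for the real data, which needs a download.

```python
import numpy as np

def preprocess(images, labels, num_classes=10):
    """Hypothetical helper: flatten 28x28 uint8 images to float32 in [0, 1]; one-hot the labels."""
    x = images.reshape(len(images), 28 * 28).astype("float32") / 255.0
    y = np.eye(num_classes, dtype="float32")[labels]
    return x, y

# synthetic stand-in with MNIST's shapes and dtypes
fake_images = np.random.default_rng(3).integers(0, 256, size=(6, 28, 28)).astype(np.uint8)
fake_labels = np.array([0, 1, 2, 3, 4, 5])

x, y = preprocess(fake_images, fake_labels)
print(x.shape, y.shape)  # (6, 784) (6, 10)
```

With the real data this becomes `(x_train, y_train), (x_test, y_test) = mnist.load_data()` followed by `x_train, y_train = preprocess(x_train, y_train)`, where `mnist` is `keras.datasets.mnist`.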

Step 2: Create the Multi-Input Network

In [ ]:
## try it yourself

In [ ]:
## `evaluate` the model on test data
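`model.evaluate` takes the test inputs in the same list form as `fit` and returns the loss followed by each compiled metric, here `[loss, accuracy]`. The NumPy sketch below, on a hypothetical three-class toy batch, shows what those two numbers measure: the mean categorical cross-entropy and the fraction of samples whose highest-scoring class matches the one-hot label.

```python
import numpy as np

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    # clip to avoid log(0); Keras applies a similar epsilon clip
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    return float(np.mean(-np.sum(y_true * np.log(y_pred), axis=1)))

def categorical_accuracy(y_true, y_pred):
    # a prediction is correct when its highest-scoring class matches the label
    return float(np.mean(np.argmax(y_true, axis=1) == np.argmax(y_pred, axis=1)))

y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=float)
y_pred = np.array([[0.7, 0.2, 0.1], [0.3, 0.4, 0.3], [0.5, 0.3, 0.2]])

loss = categorical_crossentropy(y_true, y_pred)
acc = categorical_accuracy(y_true, y_pred)
print(round(loss, 4), round(acc, 4))  # 0.9608 0.6667
```

The third sample is misclassified (class 0 scores highest, the label is class 2), so accuracy is 2/3, and its low probability for the true class (0.2) contributes most of the loss.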

Keras supports different merge strategies, each available as a layer class (e.g. Add) and as a lowercase functional shortcut (e.g. add, like concatenate above):

  • add: element-wise sum
  • concatenate: tensor concatenation. You can specify the concatenation axis via the argument axis (default -1, the last axis).
  • multiply: element-wise multiplication
  • average: element-wise average
  • maximum: element-wise maximum of the inputs.
  • dot: dot product. You can specify which axes to reduce along via the argument axes. If you set normalize=True, each sample is L2-normalised before the dot product, so the output is the cosine proximity between the two samples.
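To see what each operation computes, here is a NumPy sketch on a toy batch of two samples with three features each. These are NumPy stand-ins mirroring the semantics listed above, not the Keras layers themselves; `dot` reduces over the feature axis and keeps a (batch, 1) shape, as Keras does for 2D inputs.

```python
import numpy as np

a = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 0.0]])
b = np.array([[4.0, 0.0, -1.0], [0.0, 2.0, 0.0]])

merged = {
    "add": a + b,
    "concatenate": np.concatenate([a, b], axis=-1),  # default axis=-1
    "multiply": a * b,
    "average": (a + b) / 2.0,
    "maximum": np.maximum(a, b),
    # dot over the feature axis: one scalar per sample, kept as shape (batch, 1)
    "dot": np.sum(a * b, axis=1, keepdims=True),
}

# normalize=True: L2-normalise each sample first, so the dot product is the cosine proximity
a_unit = a / np.linalg.norm(a, axis=1, keepdims=True)
b_unit = b / np.linalg.norm(b, axis=1, keepdims=True)
merged["cosine"] = np.sum(a_unit * b_unit, axis=1, keepdims=True)

for name, value in merged.items():
    print(name, value.shape)
```

Every element-wise mode keeps the (2, 3) shape, concatenation widens it to (2, 6), and both dot variants collapse each sample to a single value.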

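Keras 1 also let you pass a function as the mode argument of its Merge layer, allowing for arbitrary transformations. That Merge layer was deprecated in Keras 2 and has since been removed; wrap the function in a Lambda layer instead (for this particular case, subtract does the same job):

from keras.layers import Lambda
merged = Lambda(lambda x: x[0] - x[1])([left_branch, right_branch])  # element-wise difference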