A Conceptual, Practical Introduction to Trax Layers

This notebook introduces the core concepts of the Trax library through a series of code samples and explanations. The topics covered in the following sections are:

  1. Layers: the basic building blocks and how to combine them into networks
  2. Inputs and Outputs: how data streams flow through the layers
  3. Defining New Layer Classes
  4. Testing and Debugging Layer Classes

General Setup

Execute the following few cells (once) before running any of the code samples in this notebook.


In [0]:
# Copyright 2018 Google LLC.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

In [0]:
# Import Trax

! pip install -q -U trax
! pip install -q tensorflow

from trax import fastmath
from trax import layers as tl
from trax import shapes
from trax.fastmath import numpy as jnp  # For use in defining new layer types.
from trax.shapes import ShapeDtype
from trax.shapes import signature


In [0]:
# Settings and utilities for handling inputs, outputs, and object properties.

np.set_printoptions(precision=3)  # Reduce visual noise from extra digits.

def show_layer_properties(layer_obj, layer_name):
  template = ('{}.n_in:  {}\n'
              '{}.n_out: {}\n'
              '{}.sublayers: {}\n'
              '{}.weights:    {}\n')
  print(template.format(layer_name, layer_obj.n_in,
                        layer_name, layer_obj.n_out,
                        layer_name, layer_obj.sublayers,
                        layer_name, layer_obj.weights))

1. Layers

The Layer class represents Trax's basic building blocks:

class Layer:
  """Base class for composable layers in a deep learning network.

  Layers are the basic building blocks for deep learning models. A Trax layer
  computes a function from zero or more inputs to zero or more outputs,
  optionally using trainable weights (common) and non-parameter state (not
  common). Authors of new layer subclasses typically override at most two
  methods of the base `Layer` class:

    `forward(inputs)`:
      Computes this layer's output as part of a forward pass through the model.

    `init_weights_and_state(self, input_signature)`:
      Initializes weights and state for inputs with the given signature.

  ...

Layers compute functions.

A layer computes a function from zero or more inputs to zero or more outputs. The inputs and outputs are NumPy arrays or JAX objects behaving as NumPy arrays.

The simplest layers, those with no weights or sublayers, can be used without initialization. You can think of them as (pure) mathematical functions that can be plugged into neural networks.

For ease of testing and interactive exploration, layer objects implement the __call__ method, so you can call them directly on input data:

y = my_layer(x)

Layers are also objects, so you can inspect their properties. For example:

print(f'Number of inputs expected by this layer: {my_layer.n_in}')

Example 1. tl.Relu $[n_{in} = 1, n_{out} = 1]$


In [0]:
relu = tl.Relu()

x = np.array([[-2, -1, 0, 1, 2],
              [-20, -10, 0, 10, 20]])
y = relu(x)

# Show input, output, and two layer properties.
print(f'x:\n{x}\n\n'
      f'relu(x):\n{y}\n\n'
      f'Number of inputs expected by this layer: {relu.n_in}\n'
      f'Number of outputs promised by this layer: {relu.n_out}')


x:
[[ -2  -1   0   1   2]
 [-20 -10   0  10  20]]

relu(x):
[[ 0  0  0  1  2]
 [ 0  0  0 10 20]]

Number of inputs expected by this layer: 1
Number of outputs promised by this layer: 1

Example 2. tl.Concatenate $[n_{in} = 2, n_{out} = 1]$


In [0]:
concat = tl.Concatenate()

x0 = np.array([[1, 2, 3],
               [4, 5, 6]])
x1 = np.array([[10, 20, 30],
               [40, 50, 60]])
y = concat([x0, x1])

print(f'x0:\n{x0}\n\n'
      f'x1:\n{x1}\n\n'
      f'concat([x0, x1]):\n{y}\n\n'
      f'Number of inputs expected by this layer: {concat.n_in}\n'
      f'Number of outputs promised by this layer: {concat.n_out}')


x0:
[[1 2 3]
 [4 5 6]]

x1:
[[10 20 30]
 [40 50 60]]

concat([x0, x1]):
[[ 1  2  3 10 20 30]
 [ 4  5  6 40 50 60]]

Number of inputs expected by this layer: 2
Number of outputs promised by this layer: 1

Layers are configurable.

Many layer types have creation-time parameters for flexibility. The Concatenate layer type, for instance, has two optional parameters:

  • axis: index of axis along which to concatenate the tensors; default value of -1 means to use the last axis.
  • n_items: number of tensors to join into one by concatenation; default value is 2.

The following example shows Concatenate configured for 3 input tensors, and concatenation along the initial $(0^{th})$ axis.

Example 3. tl.Concatenate(n_items=3, axis=0)


In [0]:
concat3 = tl.Concatenate(n_items=3, axis=0)

x0 = np.array([[1, 2, 3],
               [4, 5, 6]])
x1 = np.array([[10, 20, 30],
               [40, 50, 60]])
x2 = np.array([[100, 200, 300],
               [400, 500, 600]])

y = concat3([x0, x1, x2])

print(f'x0:\n{x0}\n\n'
      f'x1:\n{x1}\n\n'
      f'x2:\n{x2}\n\n'
      f'concat3([x0, x1, x2]):\n{y}')


x0:
[[1 2 3]
 [4 5 6]]

x1:
[[10 20 30]
 [40 50 60]]

x2:
[[100 200 300]
 [400 500 600]]

concat3([x0, x1, x2]):
[[  1   2   3]
 [  4   5   6]
 [ 10  20  30]
 [ 40  50  60]
 [100 200 300]
 [400 500 600]]

Layers are trainable.

Many layer types include weights that affect the computation of outputs from inputs, and they use back-propagated gradients to update those weights.

🚧🚧 A very small subset of layer types, such as BatchNorm, also maintains modifiable values (called state) that are updated based on forward-pass inputs/computation rather than back-propagated gradients.

Initialization

Trainable layers must be initialized before use. Trax can take care of this as part of the overall training process. In other settings (e.g., in tests or interactively in a Colab notebook), you need to initialize the outermost/topmost layer explicitly. For this, use init:

  def init(self, input_signature, rng=None, use_cache=False):
    """Initializes weights/state of this layer and its sublayers recursively.

    Initialization creates layer weights and state, for layers that use them.
    It derives the necessary array shapes and data types from the layer's input
    signature, which is itself just shape and data type information.

    For layers without weights or state, this method safely does nothing.

    This method is designed to create weights/state only once for each layer
    instance, even if the same layer instance occurs in multiple places in the
    network. This enables weight sharing to be implemented as layer sharing.

    Args:
      input_signature: `ShapeDtype` instance (if this layer takes one input)
          or list/tuple of `ShapeDtype` instances.
      rng: Single-use random number generator (JAX PRNG key), or `None`;
          if `None`, use a default computed from an integer 0 seed.
      use_cache: If `True`, and if this layer instance has already been
          initialized elsewhere in the network, then return special marker
          values -- tuple `(GET_WEIGHTS_FROM_CACHE, GET_STATE_FROM_CACHE)`.
          Else return this layer's newly initialized weights and state.

    Returns:
      A `(weights, state)` tuple.
    """

Input signatures can be built from scratch using ShapeDtype objects, or can be derived from data via the signature function (in module shapes):

def signature(obj):
  """Returns a `ShapeDtype` signature for the given `obj`.

  A signature is either a `ShapeDtype` instance or a tuple of `ShapeDtype`
  instances. Note that this function is permissive with respect to its inputs
  (accepts lists or tuples or dicts, and underlying objects can be any type
  as long as they have shape and dtype attributes) and returns the corresponding
  nested structure of `ShapeDtype`.

  Args:
    obj: An object that has `shape` and `dtype` attributes, or a list/tuple/dict
        of such objects.

  Returns:
    A corresponding nested structure of `ShapeDtype` instances.
  """

Example 4. tl.LayerNorm $[n_{in} = 1, n_{out} = 1]$


In [0]:
layer_norm = tl.LayerNorm()

x = np.array([[-2, -1, 0, 1, 2],
              [1, 2, 3, 4, 5],
              [10, 20, 30, 40, 50]]).astype(np.float32)
layer_norm.init(shapes.signature(x))

y = layer_norm(x)

print(f'x:\n{x}\n\n'
      f'layer_norm(x):\n{y}\n')
print(f'layer_norm.weights:\n{layer_norm.weights}')


x:
[[-2. -1.  0.  1.  2.]
 [ 1.  2.  3.  4.  5.]
 [10. 20. 30. 40. 50.]]

layer_norm(x):
[[-1.414 -0.707  0.     0.707  1.414]
 [-1.414 -0.707  0.     0.707  1.414]
 [-1.414 -0.707  0.     0.707  1.414]]

layer_norm.weights:
(DeviceArray([1., 1., 1., 1., 1.], dtype=float32), DeviceArray([0., 0., 0., 0., 0.], dtype=float32))
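
To summarize several of these properties at once, you can also use the show_layer_properties utility defined in the General Setup section:

show_layer_properties(layer_norm, 'layer_norm')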

Layers combine into layers.

The Trax library authors encourage users to build new layers as combinations of existing layers. Hence, the library provides a small set of combinator layers: layer objects that make a list of layers behave as a single layer.

The new layer, like other layers, can:

  • compute outputs from inputs,
  • update parameters from gradients, and
  • combine with yet more layers.

Combine with Serial

The most common way to combine layers is with the Serial class:

class Serial(base.Layer):
  """Combinator that applies layers serially (by function composition).

  This combinator is commonly used to construct deep networks, e.g., like this::

      mlp = tl.Serial(
        tl.Dense(128),
        tl.Relu(),
        tl.Dense(10),
        tl.LogSoftmax()
      )

  A Serial combinator uses stack semantics to manage data for its sublayers.
  Each sublayer sees only the inputs it needs and returns only the outputs it
  has generated. The sublayers interact via the data stack. For instance, a
  sublayer k, following sublayer j, gets called with the data stack in the
  state left after layer j has applied. The Serial combinator then:

    - takes n_in items off the top of the stack (n_in = k.n_in) and calls
      layer k, passing those items as arguments; and

    - takes layer k's n_out return values (n_out = k.n_out) and pushes
      them onto the data stack.

  A Serial instance with no sublayers acts as a special-case (but useful)
  1-input 1-output no-op.
  """

If one layer has the same number of outputs as the next layer has inputs (which is the usual case), the successive layers behave like function composition:

#  h(.) = g(f(.))
layer_h = Serial(
    layer_f,
    layer_g,
)

Note how, inside Serial, function composition is expressed naturally as a succession of operations, so that no nested parentheses are needed.

Example 5. y = layernorm(relu(x)) $[n_{in} = 1, n_{out} = 1]$


In [0]:
layer_block = tl.Serial(
    tl.Relu(),
    tl.LayerNorm(),
)

x = np.array([[-2, -1, 0, 1, 2],
              [-20, -10, 0, 10, 20]]).astype(np.float32)
layer_block.init(shapes.signature(x))
y = layer_block(x)

print(f'x:\n{x}\n\n'
      f'layer_block(x):\n{y}')


x:
[[ -2.  -1.   0.   1.   2.]
 [-20. -10.   0.  10.  20.]]

layer_block(x):
[[-0.75 -0.75 -0.75  0.5   1.75]
 [-0.75 -0.75 -0.75  0.5   1.75]]

And we can inspect the block as a whole, as if it were just another layer:

Example 5'. Inspecting a Serial layer.


In [0]:
print(f'layer_block: {layer_block}\n\n'
      f'layer_block.weights: {layer_block.weights}')


layer_block: Serial[
  Relu
  LayerNorm
]

layer_block.weights: [(), (DeviceArray([1., 1., 1., 1., 1.], dtype=float32), DeviceArray([0., 0., 0., 0., 0.], dtype=float32))]

Combine with Branch

The Branch combinator arranges layers into parallel computational channels:

def Branch(*layers, name='Branch'):
  """Combinator that applies a list of layers in parallel to copies of inputs.

  Each layer in the input list is applied to as many inputs from the stack
  as it needs, and their outputs are successively combined on stack.

  For example, suppose one has three layers:

    - F: 1 input, 1 output
    - G: 3 inputs, 1 output
    - H: 2 inputs, 2 outputs (h1, h2)

  Then Branch(F, G, H) will take 3 inputs and give 4 outputs:

    - inputs: a, b, c
    - outputs: F(a), G(a, b, c), h1, h2    where h1, h2 = H(a, b)

  As an important special case, a None argument to Branch acts as if it takes
  one argument, which it leaves unchanged. (It acts as a one-arg no-op.)

  Args:
    *layers: List of layers.
    name: Descriptive name for this layer.

  Returns:
    A branch layer built from the given sublayers.
  """

Residual blocks, for example, are implemented using Branch:

def Residual(*layers, shortcut=None):
  """Wraps a series of layers with a residual connection.

  Args:
    *layers: One or more layers, to be applied in series.
    shortcut: If None (the usual case), the Residual layer computes the
        element-wise sum of the stack-top input with the output of the layer
        series. If specified, the `shortcut` layer applies to a copy of the
        inputs and (elementwise) adds its output to the output from the main
        layer series.

  Returns:
      A layer representing a residual connection paired with a layer series.
  """
  layers = _ensure_flat(layers)
  layer = layers[0] if len(layers) == 1 else Serial(layers)
  return Serial(
      Branch(shortcut, layer),
      Add(),
  )
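
As a minimal sketch of the resulting behavior, a Residual wrapped around a single Relu layer computes x + Relu(x) elementwise:

residual_relu = tl.Residual(tl.Relu())  # Serial(Branch(None, Relu()), Add())

x = np.array([[-2.0, -1.0, 0.0, 1.0, 2.0]])
residual_relu.init(shapes.signature(x))

y = residual_relu(x)  # elementwise x + relu(x): [[-2., -1., 0., 2., 4.]]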

Here's a simple code example to highlight the mechanics of Branch on its own.

Example 6. Branch


In [0]:
relu = tl.Relu()
times_100 = tl.Fn("Times100", lambda x: x * 100.0)
branch_relu_t100 = tl.Branch(relu, times_100)

x = np.array([[-2, -1, 0, 1, 2],
              [-20, -10, 0, 10, 20]])
branch_relu_t100.init(shapes.signature(x))

y0, y1 = branch_relu_t100(x)

print(f'x:\n{x}\n\n'
      f'y0:\n{y0}\n\n'
      f'y1:\n{y1}')


x:
[[ -2  -1   0   1   2]
 [-20 -10   0  10  20]]

y0:
[[ 0  0  0  1  2]
 [ 0  0  0 10 20]]

y1:
[[ -200.  -100.     0.   100.   200.]
 [-2000. -1000.     0.  1000.  2000.]]

2. Inputs and Outputs

The Trax runtime supports the concept of multiple data streams, which gives individual layers flexibility to:

  • process a single data stream ($n_{in} = n_{out} = 1$),
  • process multiple parallel data streams ($n_{in} = n_{out} = 2, 3, ... $),
  • split or inject data streams ($n_{in} < n_{out}$), or
  • merge or remove data streams ($n_{in} > n_{out}$).

We saw in section 1 the example of Residual, which involves both a split and a merge:

  ...
  return Serial(
      Branch(shortcut, layer),
      Add(),
  )

In other words, layer by layer:

  • Branch(shortcut, layer): makes two copies of the single incoming data stream, passes one copy via the shortcut (typically a no-op), and processes the other copy via the given layer (one or more layers applied in series). [$n_{in} = 1$, $n_{out} = 2$]
  • Add(): combines the two streams back into one by adding two tensors elementwise. [$n_{in} = 2$, $n_{out} = 1$]
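
As a small sketch of these arities outside of Residual, the following pairs a pure split with a pure merge, using tl.Dup (which copies the top data stream) and tl.Add:

split_then_merge = tl.Serial(
    tl.Dup(),  # split: 1 input -> 2 outputs (two copies of the stream)
    tl.Add(),  # merge: 2 inputs -> 1 output (elementwise sum)
)

x = np.array([1.0, 2.0, 3.0])
split_then_merge.init(shapes.signature(x))

y = split_then_merge(x)  # equals x + x, i.e. [2., 4., 6.]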

Data Stack

TBD

3. Defining New Layer Classes

With the Fn layer-creating function.

Many layer types needed in deep learning compute pure functions from inputs to outputs, using neither weights nor randomness. You can use Trax's Fn function to define your own pure layer types:

def Fn(name, f, n_out=1):  # pylint: disable=invalid-name
  """Returns a layer with no weights that applies the function `f`.

  `f` can take and return any number of arguments, and takes only positional
  arguments -- no default or keyword arguments. It often uses JAX-numpy (`jnp`).
  The following, for example, would create a layer that takes two inputs and
  returns two outputs -- element-wise sums and maxima:

      `Fn('SumAndMax', lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)), n_out=2)`

  The layer's number of inputs (`n_in`) is automatically set to number of
  positional arguments in `f`, but you must explicitly set the number of
  outputs (`n_out`) whenever it's not the default value 1.

  Args:
    name: Class-like name for the resulting layer; for use in debugging.
    f: Pure function from input tensors to output tensors, where each input
        tensor is a separate positional arg, e.g., `f(x0, x1) --> x0 + x1`.
        Output tensors must be packaged as specified in the `Layer` class
        docstring.
    n_out: Number of outputs promised by the layer; default value 1.

  Returns:
    Layer executing the function `f`.
  """

Example 7. Use Fn to define a new layer type:


In [0]:
# Define new layer type.
def Gcd():
  """Returns a layer to compute the greatest common divisor, elementwise."""
  return tl.Fn('Gcd', lambda x0, x1: jnp.gcd(x0, x1))

# Use it.
gcd = Gcd()

x0 = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
x1 = np.array([11, 12, 13, 14, 15, 16, 17, 18, 19, 20])

y = gcd((x0, x1))

print(f'x0:\n{x0}\n\n'
      f'x1:\n{x1}\n\n'
      f'gcd((x0, x1)):\n{y}')


x0:
[ 1  2  3  4  5  6  7  8  9 10]

x1:
[11 12 13 14 15 16 17 18 19 20]

gcd((x0, x1)):
[ 1  2  1  2  5  2  1  2  1 10]

The Fn function infers n_in (number of inputs) from the length of f's argument list. Fn does not infer n_out (number of outputs), though; if your f has more than one output, you need to give an explicit value using the n_out keyword arg.

Example 8. Fn with multiple outputs:


In [0]:
# Define new layer type.
def SumAndMax():
  """Returns a layer to compute sums and maxima of two input tensors."""
  return tl.Fn('SumAndMax',
               lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)),
               n_out=2)

# Use it.
sum_and_max = SumAndMax()

x0 = np.array([1, 2, 3, 4, 5])
x1 = np.array([10, 20, 30, 40, 50])

y0, y1 = sum_and_max([x0, x1])

print(f'x0:\n{x0}\n\n'
      f'x1:\n{x1}\n\n'
      f'y0:\n{y0}\n\n'
      f'y1:\n{y1}')


x0:
[1 2 3 4 5]

x1:
[10 20 30 40 50]

y0:
[11 22 33 44 55]

y1:
[10 20 30 40 50]

Example 9. Use Fn to define a configurable layer:


In [0]:
# Function defined in trax/layers/core.py:
def Flatten(n_axes_to_keep=1):
  """Returns a layer that combines one or more trailing axes of a tensor.

  Flattening keeps all the values of the input tensor, but reshapes it by
  collapsing one or more trailing axes into a single axis. For example, a
  `Flatten(n_axes_to_keep=2)` layer would map a tensor with shape
  `(2, 3, 5, 7, 11)` to the same values with shape `(2, 3, 385)`.

  Args:
    n_axes_to_keep: Number of leading axes to leave unchanged when reshaping;
        collapse only the axes after these.
  """
  layer_name = f'Flatten_keep{n_axes_to_keep}'
  def f(x):
    in_rank = len(x.shape)
    if in_rank <= n_axes_to_keep:
      raise ValueError(f'Input rank ({in_rank}) must exceed the number of '
                       f'axes to keep ({n_axes_to_keep}) after flattening.')
    return jnp.reshape(x, (x.shape[:n_axes_to_keep] + (-1,)))
  return tl.Fn(layer_name, f)

flatten_keep_1_axis = Flatten(n_axes_to_keep=1)
flatten_keep_2_axes = Flatten(n_axes_to_keep=2)

x = np.array([[[1, 2, 3],
               [10, 20, 30],
               [100, 200, 300]],
              [[4, 5, 6],
               [40, 50, 60],
               [400, 500, 600]]])

y1 = flatten_keep_1_axis(x)
y2 = flatten_keep_2_axes(x)

print(f'x:\n{x}\n\n'
      f'flatten_keep_1_axis(x):\n{y1}\n\n'
      f'flatten_keep_2_axes(x):\n{y2}')


x:
[[[  1   2   3]
  [ 10  20  30]
  [100 200 300]]

 [[  4   5   6]
  [ 40  50  60]
  [400 500 600]]]

flatten_keep_1_axis(x):
[[  1   2   3  10  20  30 100 200 300]
 [  4   5   6  40  50  60 400 500 600]]

flatten_keep_2_axes(x):
[[[  1   2   3]
  [ 10  20  30]
  [100 200 300]]

 [[  4   5   6]
  [ 40  50  60]
  [400 500 600]]]

By defining a Layer subclass

TBD

By defining a Combinator subclass

TBD

4. Testing and Debugging Layer Classes

TBD