This notebook introduces the core concepts of the Trax library through a series of code samples and explanations. The topics covered in the following sections are:

- **Layers**: the basic building blocks and how to combine them into networks
- **Inputs and Outputs**: how data streams flow through the layers
- **Defining New Layer Classes**
- **Testing and Debugging Layer Classes**

```
# Copyright 2018 Google LLC.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
```

```
# Import Trax
! pip install -q -U trax
! pip install -q tensorflow

from trax import fastmath
from trax import layers as tl
from trax import shapes
from trax.fastmath import numpy as jnp  # For use in defining new layer types.
from trax.shapes import ShapeDtype
from trax.shapes import signature
```

```
# Settings and utilities for handling inputs, outputs, and object properties.
np.set_printoptions(precision=3)  # Reduce visual noise from extra digits.

def show_layer_properties(layer_obj, layer_name):
  template = ('{}.n_in: {}\n'
              '{}.n_out: {}\n'
              '{}.sublayers: {}\n'
              '{}.weights: {}\n')
  print(template.format(layer_name, layer_obj.n_in,
                        layer_name, layer_obj.n_out,
                        layer_name, layer_obj.sublayers,
                        layer_name, layer_obj.weights))
```

The Layer class represents Trax's basic building blocks:

```
class Layer:
  """Base class for composable layers in a deep learning network.

  Layers are the basic building blocks for deep learning models. A Trax layer
  computes a function from zero or more inputs to zero or more outputs,
  optionally using trainable weights (common) and non-parameter state (not
  common). Authors of new layer subclasses typically override at most two
  methods of the base `Layer` class:

    `forward(inputs)`:
      Computes this layer's output as part of a forward pass through the model.

    `init_weights_and_state(self, input_signature)`:
      Initializes weights and state for inputs with the given signature.

  ...
```

A layer computes a function from zero or more inputs to zero or more outputs. The inputs and outputs are NumPy arrays or JAX objects behaving as NumPy arrays.

The simplest layers, those with no weights or sublayers, can be used without initialization. You can think of them as (pure) mathematical functions that can be plugged into neural networks.

For ease of testing and interactive exploration, layer objects implement the `__call__` method, so you can call them directly on input data:

`y = my_layer(x)`

Layers are also objects, so you can inspect their properties. For example:

`print(f'Number of inputs expected by this layer: {my_layer.n_in}')`

**Example 1.** tl.Relu $[n_{in} = 1, n_{out} = 1]$

```
relu = tl.Relu()

x = np.array([[-2, -1, 0, 1, 2],
              [-20, -10, 0, 10, 20]])
y = relu(x)

# Show input, output, and two layer properties.
print(f'x:\n{x}\n\n'
      f'relu(x):\n{y}\n\n'
      f'Number of inputs expected by this layer: {relu.n_in}\n'
      f'Number of outputs promised by this layer: {relu.n_out}')
```

**Example 2.** tl.Concatenate $[n_{in} = 2, n_{out} = 1]$

```
concat = tl.Concatenate()

x0 = np.array([[1, 2, 3],
               [4, 5, 6]])
x1 = np.array([[10, 20, 30],
               [40, 50, 60]])
y = concat([x0, x1])

# The printed label names the same inputs that were passed in: x0 and x1.
print(f'x0:\n{x0}\n\n'
      f'x1:\n{x1}\n\n'
      f'concat([x0, x1]):\n{y}\n\n'
      f'Number of inputs expected by this layer: {concat.n_in}\n'
      f'Number of outputs promised by this layer: {concat.n_out}')
```

Many layer types have creation-time parameters for flexibility. The `Concatenate` layer type, for instance, has two optional parameters:

- `axis`: index of axis along which to concatenate the tensors; default value of -1 means to use the last axis.
- `n_items`: number of tensors to join into one by concatenation; default value is 2.

The following example shows `Concatenate` configured for **3** input tensors, and concatenation along the initial $(0^{th})$ axis.

**Example 3.** tl.Concatenate(n_items=3, axis=0)

```
concat3 = tl.Concatenate(n_items=3, axis=0)

x0 = np.array([[1, 2, 3],
               [4, 5, 6]])
x1 = np.array([[10, 20, 30],
               [40, 50, 60]])
x2 = np.array([[100, 200, 300],
               [400, 500, 600]])
y = concat3([x0, x1, x2])

print(f'x0:\n{x0}\n\n'
      f'x1:\n{x1}\n\n'
      f'x2:\n{x2}\n\n'
      f'concat3([x0, x1, x2]):\n{y}')
```

Many layer types include weights that affect the computation of outputs from inputs, and they use back-propagated gradients to update those weights.

🚧🚧 *A very small subset of layer types, such as BatchNorm, also include
modifiable weights (called state) that are updated based on forward-pass
inputs/computation rather than back-propagated gradients.*

**Initialization**

Trainable layers must be initialized before use. Trax can take care of this
as part of the overall training process. In other settings (e.g., in tests or
interactively in a Colab notebook), you need to initialize the
*outermost/topmost* layer explicitly. For this, use `init`:

```
def init(self, input_signature, rng=None, use_cache=False):
  """Initializes weights/state of this layer and its sublayers recursively.

  Initialization creates layer weights and state, for layers that use them.
  It derives the necessary array shapes and data types from the layer's input
  signature, which is itself just shape and data type information.

  For layers without weights or state, this method safely does nothing.

  This method is designed to create weights/state only once for each layer
  instance, even if the same layer instance occurs in multiple places in the
  network. This enables weight sharing to be implemented as layer sharing.

  Args:
    input_signature: `ShapeDtype` instance (if this layer takes one input)
        or list/tuple of `ShapeDtype` instances.
    rng: Single-use random number generator (JAX PRNG key), or `None`;
        if `None`, use a default computed from an integer 0 seed.
    use_cache: If `True`, and if this layer instance has already been
        initialized elsewhere in the network, then return special marker
        values -- tuple `(GET_WEIGHTS_FROM_CACHE, GET_STATE_FROM_CACHE)`.
        Else return this layer's newly initialized weights and state.

  Returns:
    A `(weights, state)` tuple.
  """
```

Input signatures can be built from scratch using `ShapeDtype` objects, or can be derived from data via the `signature` function (in module `shapes`):

```
def signature(obj):
  """Returns a `ShapeDtype` signature for the given `obj`.

  A signature is either a `ShapeDtype` instance or a tuple of `ShapeDtype`
  instances. Note that this function is permissive with respect to its inputs
  (accepts lists or tuples or dicts, and underlying objects can be any type
  as long as they have shape and dtype attributes) and returns the corresponding
  nested structure of `ShapeDtype`.

  Args:
    obj: An object that has `shape` and `dtype` attributes, or a list/tuple/dict
        of such objects.

  Returns:
    A corresponding nested structure of `ShapeDtype` instances.
  """
```
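To make the idea of a signature concrete without depending on Trax, here is a minimal sketch that mimics the behavior described in the docstring above. The names `SketchShapeDtype` and `sketch_signature` are hypothetical stand-ins invented for this illustration, not Trax APIs, and returning a tuple for list inputs is a simplifying assumption.

```python
import collections

import numpy as np

# Hypothetical stand-in for a shape/dtype record (illustration only).
SketchShapeDtype = collections.namedtuple('SketchShapeDtype', ['shape', 'dtype'])


def sketch_signature(obj):
  """Maps an array, or a nested list/tuple/dict of arrays, to shape/dtype records."""
  if isinstance(obj, (list, tuple)):
    return tuple(sketch_signature(x) for x in obj)
  if isinstance(obj, dict):
    return {k: sketch_signature(v) for k, v in obj.items()}
  # Anything with `shape` and `dtype` attributes is accepted here.
  return SketchShapeDtype(shape=tuple(obj.shape), dtype=obj.dtype)


x0 = np.zeros((2, 3), dtype=np.float32)
x1 = np.zeros((5,), dtype=np.int32)

sig_single = sketch_signature(x0)       # One record for a single input.
sig_pair = sketch_signature([x0, x1])   # One record per input.
print(sig_single)
print(sig_pair)
```

The point of the sketch is that a signature carries only shapes and data types, never the data itself, which is all that initialization needs in order to size the weights.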

**Example 4.** tl.LayerNorm $[n_{in} = 1, n_{out} = 1]$

```
layer_norm = tl.LayerNorm()

x = np.array([[-2, -1, 0, 1, 2],
              [1, 2, 3, 4, 5],
              [10, 20, 30, 40, 50]]).astype(np.float32)
layer_norm.init(shapes.signature(x))

y = layer_norm(x)

print(f'x:\n{x}\n\n'
      f'layer_norm(x):\n{y}\n')
print(f'layer_norm.weights:\n{layer_norm.weights}')
```
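To illustrate how weights enter a layer's computation, the following NumPy sketch shows one common formulation of layer normalization with a scale and a bias weight. It is an assumption-laden illustration rather than Trax's implementation: the function name `sketch_layer_norm`, the `epsilon` value, and the initial weights (scale of ones, bias of zeros) are all choices made for this example.

```python
import numpy as np


def sketch_layer_norm(x, scale, bias, epsilon=1e-6):
  """Normalizes over the last axis, then applies the scale and bias weights."""
  mean = np.mean(x, axis=-1, keepdims=True)
  centered = x - mean
  variance = np.mean(centered * centered, axis=-1, keepdims=True)
  return centered / np.sqrt(variance + epsilon) * scale + bias


x = np.array([[-2, -1, 0, 1, 2],
              [1, 2, 3, 4, 5],
              [10, 20, 30, 40, 50]], dtype=np.float32)

# Assumed initial weights: one scale and one bias value per feature.
scale = np.ones(5, dtype=np.float32)
bias = np.zeros(5, dtype=np.float32)

y = sketch_layer_norm(x, scale, bias)
print(y)
```

With these initial weights every row is mapped to approximately the same normalized values, because the three rows differ only by a shift or a positive rescaling. Training would then adjust `scale` and `bias` through gradients.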

The Trax library authors encourage users to build new layers as combinations of
existing layers. Hence, the library provides a small set of *combinator*
layers: layer objects that make a list of layers behave as a single layer.

The new layer, like other layers, can:

- compute outputs from inputs,
- update parameters from gradients, and
- combine with yet more layers.

**Combine with Serial**

The most common way to combine layers is with the `Serial` class:

```
class Serial(base.Layer):
  """Combinator that applies layers serially (by function composition).

  This combinator is commonly used to construct deep networks, e.g., like this::

      mlp = tl.Serial(
        tl.Dense(128),
        tl.Relu(),
        tl.Dense(10),
        tl.LogSoftmax()
      )

  A Serial combinator uses stack semantics to manage data for its sublayers.
  Each sublayer sees only the inputs it needs and returns only the outputs it
  has generated. The sublayers interact via the data stack. For instance, a
  sublayer k, following sublayer j, gets called with the data stack in the
  state left after layer j has applied. The Serial combinator then:

    - takes n_in items off the top of the stack (n_in = k.n_in) and calls
      layer k, passing those items as arguments; and

    - takes layer k's n_out return values (n_out = k.n_out) and pushes
      them onto the data stack.

  A Serial instance with no sublayers acts as a special-case (but useful)
  1-input 1-output no-op.
  """
```

If one layer has the same number of outputs as the next layer has inputs (which is the usual case), the successive layers behave like function composition:

```
# h(.) = g(f(.))
layer_h = Serial(
    layer_f,
    layer_g,
)
```

Note how, inside `Serial`, function composition is expressed naturally as a succession of operations, so that no nested parentheses are needed.
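The stack semantics described in the `Serial` docstring can be sketched in plain Python. This is only an illustration of the pop-and-push behavior, not Trax's implementation: `run_serial` and the `(function, n_in, n_out)` triples are hypothetical constructs invented for this example.

```python
def run_serial(sublayers, stack):
  """Applies (fn, n_in, n_out) triples in order to a data stack.

  The stack is a list whose first element is the stack top.
  """
  stack = list(stack)
  for fn, n_in, n_out in sublayers:
    # Take n_in items off the top of the stack and call the layer on them.
    args, rest = stack[:n_in], stack[n_in:]
    outputs = fn(*args)
    if n_out == 0:
      outputs = ()
    elif n_out == 1:
      outputs = (outputs,)
    # Push the layer's n_out return values back onto the stack.
    stack = list(outputs) + rest
  return stack


# Hypothetical layers, each described as (function, n_in, n_out).
double = (lambda x: 2 * x, 1, 1)
add = (lambda a, b: a + b, 2, 1)
dup = (lambda x: (x, x), 1, 2)

composed = run_serial([double, double], [3])      # double(double(3))
partial = run_serial([add, double], [1, 2, 10])   # 10 is left untouched below the top
split_merge = run_serial([dup, add], [5])         # split into two streams, then merge
print(composed, partial, split_merge)
```

When every sublayer has one input and one output, the loop reduces to ordinary function composition. The second call shows the more general case, where items deeper in the stack pass through untouched.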

**Example 5.** y = layer_norm(relu(x)) $[n_{in} = 1, n_{out} = 1]$

```
layer_block = tl.Serial(
    tl.Relu(),
    tl.LayerNorm(),
)

x = np.array([[-2, -1, 0, 1, 2],
              [-20, -10, 0, 10, 20]]).astype(np.float32)
layer_block.init(shapes.signature(x))
y = layer_block(x)

print(f'x:\n{x}\n\n'
      f'layer_block(x):\n{y}')
```

And we can inspect the block as a whole, as if it were just another layer:

**Example 5'.** Inspecting a `Serial` layer.

```
print(f'layer_block: {layer_block}\n\n'
      f'layer_block.weights: {layer_block.weights}')
```

**Combine with Branch**

The `Branch` combinator arranges layers into parallel computational channels:

```
def Branch(*layers, name='Branch'):
  """Combinator that applies a list of layers in parallel to copies of inputs.

  Each layer in the input list is applied to as many inputs from the stack
  as it needs, and their outputs are successively combined on stack.

  For example, suppose one has three layers:

    - F: 1 input, 1 output
    - G: 3 inputs, 1 output
    - H: 2 inputs, 2 outputs (h1, h2)

  Then Branch(F, G, H) will take 3 inputs and give 4 outputs:

    - inputs: a, b, c
    - outputs: F(a), G(a, b, c), h1, h2 where h1, h2 = H(a, b)

  As an important special case, a None argument to Branch acts as if it takes
  one argument, which it leaves unchanged. (It acts as a one-arg no-op.)

  Args:
    *layers: List of layers.
    name: Descriptive name for this layer.

  Returns:
    A branch layer built from the given sublayers.
  """
```
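The F, G, H example from the docstring can be traced with a small plain-Python sketch. The helper `run_branch`, the `(function, n_in, n_out)` triples, and the concrete arithmetic inside F, G, and H are all made up for this illustration; only the input and output counts come from the docstring.

```python
def run_branch(sublayers, inputs):
  """Applies each (fn, n_in, n_out) triple to copies of the leading inputs."""
  outputs = []
  for fn, n_in, n_out in sublayers:
    result = fn(*inputs[:n_in])  # Each layer sees only the inputs it needs.
    outputs.extend(result if n_out > 1 else [result])
  return outputs


# Hypothetical layers with the input/output counts from the docstring.
F = (lambda a: -a, 1, 1)
G = (lambda a, b, c: a + b + c, 3, 1)
H = (lambda a, b: (a * b, a - b), 2, 2)

sublayers = [F, G, H]
branch_n_in = max(n_in for _, n_in, _ in sublayers)    # Widest layer sets n_in.
branch_n_out = sum(n_out for _, _, n_out in sublayers)  # Outputs accumulate.

branch_outputs = run_branch(sublayers, [1, 2, 3])
print(branch_n_in, branch_n_out, branch_outputs)
```

Under these assumptions the branch takes 3 inputs and gives 4 outputs, matching the counts stated in the docstring: the input count is set by the most demanding sublayer, and the output count is the sum over sublayers.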

Residual blocks, for example, are implemented using `Branch`:

```
def Residual(*layers, shortcut=None):
  """Wraps a series of layers with a residual connection.

  Args:
    *layers: One or more layers, to be applied in series.
    shortcut: If None (the usual case), the Residual layer computes the
        element-wise sum of the stack-top input with the output of the layer
        series. If specified, the `shortcut` layer applies to a copy of the
        inputs and (elementwise) adds its output to the output from the main
        layer series.

  Returns:
    A layer representing a residual connection paired with a layer series.
  """
  layers = _ensure_flat(layers)
  layer = layers[0] if len(layers) == 1 else Serial(layers)
  return Serial(
      Branch(shortcut, layer),
      Add(),
  )
```

Here's a simple code example to highlight the mechanics.

**Example 6.** `Branch`

```
relu = tl.Relu()
times_100 = tl.Fn("Times100", lambda x: x * 100.0)
branch_relu_t100 = tl.Branch(relu, times_100)

x = np.array([[-2, -1, 0, 1, 2],
              [-20, -10, 0, 10, 20]])
branch_relu_t100.init(shapes.signature(x))

y0, y1 = branch_relu_t100(x)

print(f'x:\n{x}\n\n'
      f'y0:\n{y0}\n\n'
      f'y1:\n{y1}')
```

The Trax runtime supports the concept of multiple data streams, which gives individual layers flexibility to:

- process a single data stream ($n_{in} = n_{out} = 1$),
- process multiple parallel data streams ($n_{in} = n_{out} = 2, 3, ... $),
- split or inject data streams ($n_{in} < n_{out}$), or
- merge or remove data streams ($n_{in} > n_{out}$).

We saw in section 1 the example of `Residual`, which involves both a split and a merge:

```
...
  return Serial(
      Branch(shortcut, layer),
      Add(),
  )
```

In other words, layer by layer:

- `Branch(shortcut, layer)`: makes two copies of the single incoming data stream, passes one copy via the shortcut (typically a no-op), and processes the other copy via the given layers (applied in series). [$n_{in} = 1$, $n_{out} = 2$]
- `Add()`: combines the two streams back into one by adding two tensors elementwise. [$n_{in} = 2$, $n_{out} = 1$]
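The split-then-merge pattern can be sketched in NumPy as follows. This is a simplified illustration that assumes the shortcut is a no-op; `sketch_residual` and `relu_fn` are hypothetical names for this example and are not part of Trax.

```python
import numpy as np


def sketch_residual(layer_fn, x):
  """Computes x + layer_fn(x), mirroring a split followed by a merge."""
  # Split: one copy takes the (no-op) shortcut, the other goes through the layer.
  shortcut_out, layer_out = x, layer_fn(x)
  # Merge: add the two streams elementwise, giving a single output stream.
  return shortcut_out + layer_out


def relu_fn(x):
  return np.maximum(x, 0.0)


x = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])
y_residual = sketch_residual(relu_fn, x)
print(y_residual)
```

Negative entries pass through unchanged because the layer contributes zero there, while positive entries are doubled. The block as a whole has one input stream and one output stream, even though it briefly holds two streams internally.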

**The `Fn` layer-creating function**

Many layer types needed in deep learning compute pure functions from inputs to outputs, using neither weights nor randomness. You can use Trax's `Fn` function to define your own pure layer types:

```
def Fn(name, f, n_out=1):  # pylint: disable=invalid-name
  """Returns a layer with no weights that applies the function `f`.

  `f` can take and return any number of arguments, and takes only positional
  arguments -- no default or keyword arguments. It often uses JAX-numpy (`jnp`).
  The following, for example, would create a layer that takes two inputs and
  returns two outputs -- element-wise sums and maxima:

      `Fn('SumAndMax', lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)), n_out=2)`

  The layer's number of inputs (`n_in`) is automatically set to number of
  positional arguments in `f`, but you must explicitly set the number of
  outputs (`n_out`) whenever it's not the default value 1.

  Args:
    name: Class-like name for the resulting layer; for use in debugging.
    f: Pure function from input tensors to output tensors, where each input
        tensor is a separate positional arg, e.g., `f(x0, x1) --> x0 + x1`.
        Output tensors must be packaged as specified in the `Layer` class
        docstring.
    n_out: Number of outputs promised by the layer; default value 1.

  Returns:
    Layer executing the function `f`.
  """
```

**Example 7.** Use `Fn` to define a new layer type:

```
# Define new layer type.
def Gcd():
  """Returns a layer to compute the greatest common divisor, elementwise."""
  return tl.Fn('Gcd', lambda x0, x1: jnp.gcd(x0, x1))


# Use it.
gcd = Gcd()

x0 = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
x1 = np.array([11, 12, 13, 14, 15, 16, 17, 18, 19, 20])
y = gcd((x0, x1))

print(f'x0:\n{x0}\n\n'
      f'x1:\n{x1}\n\n'
      f'gcd((x0, x1)):\n{y}')
```

The `Fn` function infers `n_in` (number of inputs) as the length of `f`'s arg list. `Fn` does not infer `n_out` (number of outputs) though. If your `f` has more than one output, you need to give an explicit value using the `n_out` keyword arg.
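The asymmetry between `n_in` and `n_out` follows from what Python exposes about a function. The sketch below uses the standard library's `inspect.signature` to count positional parameters; this is one plausible way to do such inference and is an assumption for illustration, not a description of Trax's internals. The helper name `count_positional_args` is hypothetical.

```python
import inspect


def count_positional_args(f):
  """Counts the positional parameters declared by `f`."""
  positional_kinds = (inspect.Parameter.POSITIONAL_ONLY,
                      inspect.Parameter.POSITIONAL_OR_KEYWORD)
  params = inspect.signature(f).parameters.values()
  return sum(p.kind in positional_kinds for p in params)


one_output = lambda x0, x1: x0 + x1
two_outputs = lambda x0, x1: (x0 + x1, max(x0, x1))

# The number of inputs is visible in each function's declaration ...
n_in_one = count_positional_args(one_output)
n_in_two = count_positional_args(two_outputs)

# ... but both declarations look identical, so nothing in them reveals
# that the second function returns a pair rather than a single value.
same_declaration = (str(inspect.signature(one_output)) ==
                    str(inspect.signature(two_outputs)))
print(n_in_one, n_in_two, same_declaration)
```

Because the number of return values is only known once the function actually runs, it is reasonable for the caller to state `n_out` explicitly.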

**Example 8.** `Fn` with multiple outputs:

```
# Define new layer type.
def SumAndMax():
  """Returns a layer to compute sums and maxima of two input tensors."""
  return tl.Fn('SumAndMax',
               lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)),
               n_out=2)


# Use it.
sum_and_max = SumAndMax()

x0 = np.array([1, 2, 3, 4, 5])
x1 = np.array([10, 20, 30, 40, 50])
y0, y1 = sum_and_max([x0, x1])

print(f'x0:\n{x0}\n\n'
      f'x1:\n{x1}\n\n'
      f'y0:\n{y0}\n\n'
      f'y1:\n{y1}')
```

**Example 9.** Use `Fn` to define a configurable layer:

```
# Function defined in trax/layers/core.py:

def Flatten(n_axes_to_keep=1):
  """Returns a layer that combines one or more trailing axes of a tensor.

  Flattening keeps all the values of the input tensor, but reshapes it by
  collapsing one or more trailing axes into a single axis. For example, a
  `Flatten(n_axes_to_keep=2)` layer would map a tensor with shape
  `(2, 3, 5, 7, 11)` to the same values with shape `(2, 3, 385)`.

  Args:
    n_axes_to_keep: Number of leading axes to leave unchanged when reshaping;
        collapse only the axes after these.
  """
  layer_name = f'Flatten_keep{n_axes_to_keep}'
  def f(x):
    in_rank = len(x.shape)
    if in_rank <= n_axes_to_keep:
      raise ValueError(f'Input rank ({in_rank}) must exceed the number of '
                       f'axes to keep ({n_axes_to_keep}) after flattening.')
    return jnp.reshape(x, (x.shape[:n_axes_to_keep] + (-1,)))
  return tl.Fn(layer_name, f)


flatten_keep_1_axis = Flatten(n_axes_to_keep=1)
flatten_keep_2_axes = Flatten(n_axes_to_keep=2)

x = np.array([[[1, 2, 3],
               [10, 20, 30],
               [100, 200, 300]],
              [[4, 5, 6],
               [40, 50, 60],
               [400, 500, 600]]])

y1 = flatten_keep_1_axis(x)
y2 = flatten_keep_2_axes(x)

print(f'x:\n{x}\n\n'
      f'flatten_keep_1_axis(x):\n{y1}\n\n'
      f'flatten_keep_2_axes(x):\n{y2}')
```