A guided tour of Flax

This notebook provides a guided tour of the features of Flax, starting from plain JAX, moving on to the Flax Module abstraction, and ending with more advanced functionality.

In [ ]:
import numpy as np
import matplotlib.pyplot as plt

import jax
from jax import numpy as jnp, random, jit, lax

import flax
from flax import nn, optim

# init jax with some random compute. 
# JAX might complain about not having access to a GPU or TPU.
_ = jnp.square(2.)

Intro to JAX

JAX is a numerical computation library which aims to replicate the NumPy API.

A few important things to know about JAX:

  • It is functional. This means no in-place ops or sliced assignments. Functions should not read inputs from, or write outputs to, global state.
  • JAX works best on functions where the bulk of the computation is in NumPy calls, with Python control flow generally limited to operating on array shapes and non-array data. See JAX - The Sharp Bits.
  • JAX can execute computations on CPUs, GPUs, and TPUs.
  • Functions using the jax.numpy API can be traced for automatic transformations:

    • jit: compile a function using XLA enabling fast execution
    • grad: take the gradient of a function
    • vmap: add a batch dimension to a function
    • pmap: split a computation across devices based on the first dimension of each input argument.
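
As a minimal sketch of the points above, the cell below applies jit, grad, and vmap to one small function and shows a functional array update. The function name sum_of_squares and the example values are ours, chosen only for illustration.

```python
import jax
from jax import numpy as jnp

def sum_of_squares(x):
  # built purely from jax.numpy calls, so JAX can trace it
  return jnp.sum(x ** 2)

x = jnp.array([1., 2., 3.])

# functional update: returns a new array instead of assigning in place
x_updated = x.at[0].set(10.)

fast_fn = jax.jit(sum_of_squares)      # compiled with XLA
grad_fn = jax.grad(sum_of_squares)     # gradient w.r.t. x, i.e. 2 * x
batched_fn = jax.vmap(sum_of_squares)  # maps over a new leading batch axis

value = fast_fn(x)
gradient = grad_fn(x)
batch_values = batched_fn(jnp.stack([x, 2. * x]))
```

Note that x itself is unchanged after the update; only x_updated holds the new value.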

Neural Networks in JAX without Flax

Before we dive into Flax, let's look at what a typical neural network component looks like when written in "native" JAX.

We decompose a learnable linear layer into two parts: an initializer function, which uses a JAX PRNGKey to generate a random kernel and bias, and an apply function, which computes the linear transformation using a set of parameters and some inputs.

In [ ]:
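def dense_init(rng, in_features, out_features,
               # the default initializers here are an assumed, conventional choice
               kernel_init=jax.nn.initializers.lecun_normal(),
               bias_init=jax.nn.initializers.zeros):
  k1, k2 = random.split(rng)
  # init functions take a PRNGKey and a shape tuple and return ndarrays.
  kernel = kernel_init(k1, (in_features, out_features))
  bias = bias_init(k2, (out_features,))
  return kernel, bias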

def dense_apply(params, inputs):
  kernel, bias = params
  return jnp.dot(inputs, kernel) + bias

Functional programming without abstractions naturally results in somewhat verbose but very explicit code.

Note how the random number generators and parameters are passed on explicitly to functions. JAX has no concept of variables so we cannot hide the parameters in variables somewhere. Similarly, there is no global random number generator which updates an internal seed.
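
The explicit handling of random state can be sketched as follows. Reusing a key reproduces the same numbers, so a fresh subkey is split off whenever new randomness is needed. The variable names are ours.

```python
from jax import numpy as jnp, random

key = random.PRNGKey(0)

# split off a subkey for one use; keep `key` for later splits
key, subkey = random.split(key)
a = random.normal(subkey, (3,))
b = random.normal(subkey, (3,))  # same key, so the same numbers as `a`

# a new subkey gives new numbers
key, other_subkey = random.split(key)
c = random.normal(other_subkey, (3,))
```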

In [ ]:
params = dense_init(random.PRNGKey(0), in_features=2, out_features=4)

Once we have generated a set of parameters, it is easy enough to apply them to some inputs.

In [ ]:
x = jnp.ones((1, 2))
dense_apply(params, x)

Because everything is functional we can use the functional transformations that JAX provides to do useful things like taking gradients to optimize the model.

In [ ]:
def loss_fn(params, x):
  y = dense_apply(params, x)
  return jnp.mean(y ** 2)
grad_fn = jax.grad(loss_fn) # by default jax.grad takes the gradient w.r.t. the first argument
grad_fn(params, x)
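
To show how such gradients can be used, here is a minimal sketch of a single gradient descent step on the linear layer. It is self-contained, so it redefines dense_apply and uses a simplified normal initializer; the learning rate of 0.1 is an arbitrary choice for illustration.

```python
import jax
from jax import numpy as jnp, random

def dense_apply(params, inputs):
  kernel, bias = params
  return jnp.dot(inputs, kernel) + bias

def loss_fn(params, x):
  return jnp.mean(dense_apply(params, x) ** 2)

# simplified initialization: normal kernel, zero bias
params = (random.normal(random.PRNGKey(0), (2, 4)), jnp.zeros((4,)))
x = jnp.ones((1, 2))

learning_rate = 0.1
loss_before = loss_fn(params, x)
grads = jax.grad(loss_fn)(params, x)  # same tuple structure as params
# update every leaf of the parameter pytree
new_params = jax.tree_util.tree_map(
    lambda p, g: p - learning_rate * g, params, grads)
loss_after = loss_fn(new_params, x)
```

Because params is immutable, the step produces new_params rather than modifying anything in place.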

Simplifying Neural Networks in JAX: Flax Modules

The core of Flax is the Module abstraction. Modules allow you to write parameterized functions just as if you were writing a normal NumPy function with JAX. The Module API allows you to declare parameters and use them directly with the JAX APIs.

A few things to know about Modules:

  1. A Module is created by defining a subclass of flax.nn.Module and implementing the apply method.
  2. Parameters are declared using self.param(name, shape, init_func), which returns an initialized parameter value.
  3. Dense.init(rng, ...) and Dense.call(params, ...) behave identically to the dense_init and dense_apply implemented earlier.

Now let's redefine the dense layer using Flax Modules.

In [ ]:
class Dense(nn.Module):
  """A learned linear transformation."""
  def apply(self, x, features,
            # init functions are of the form (PrngKey, shape) => init_value
            kernel_init=jax.nn.initializers.lecun_normal(),
            bias_init=jax.nn.initializers.zeros):
    """The main entry point to a Module. Represents the function that,
    given inputs and hyper-parameters, computes an output. The actual arguments
    (inputs and hyper-parameters) are user-controlled and depend on the Module's functionality.
    For this example:
      * `x`: the input, an array of shape `(..., in_features)`.
      * `features`: the number of outputs, an integer.
      * `kernel_init`: the initializer for the kernel.
      * `bias_init`: the initializer for the biases.
    """
    in_features = x.shape[-1]
    kernel_shape = (in_features, features)
    kernel = self.param('kernel', kernel_shape, kernel_init)
    bias = self.param('bias', (features,), bias_init)
    return jnp.dot(x, kernel) + bias

In [ ]:
y, params = Dense.init(random.PRNGKey(0), x, features=4)

In [ ]:
Dense.call(params, x, features=4)

Note that both init and call end up using the same apply function. That is why we must specify all the inputs and hyper-parameters (here, the number of features) in both init and call. Often the hyper-parameters are the same for each call to init and call. For these situations, we can use Module.partial to apply these arguments. partial takes keyword arguments and returns a new Module for which the given arguments are already applied. It can be thought of as the equivalent of functools.partial for Modules.

In [ ]:
module = Dense.partial(features=4) # Module with concrete hyper parameters, ready to be initialized
_, params = module.init(random.PRNGKey(0), x)
module.call(params, x)
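
The analogy with functools.partial can be made concrete with the plain-JAX functions from earlier. This is only a sketch of the idea, not of how Module.partial is implemented; dense_init is redefined here in simplified form so the cell is self-contained.

```python
import functools
from jax import numpy as jnp, random

def dense_init(rng, in_features, out_features):
  # simplified initializer: normal kernel, zero bias
  kernel = random.normal(rng, (in_features, out_features))
  bias = jnp.zeros((out_features,))
  return kernel, bias

# fix the hyper-parameters once; only the PRNGKey remains to be supplied
init_2_to_4 = functools.partial(dense_init, in_features=2, out_features=4)

kernel, bias = init_2_to_4(random.PRNGKey(0))
```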


Modules can be composed to form more complex Modules.

Within a Module's apply function other modules behave just like functions.

In [ ]:
# same as flax.nn.relu
def relu(x):
  return jnp.maximum(0., x)

class MLP(nn.Module):
  """Multi Layer Perceptron."""
  def apply(self, x,
            hidden_features,
            output_features,
            activation_fn):
    z = Dense(x, hidden_features)
    h = activation_fn(z)
    y = Dense(h, output_features)
    return y

module = MLP.partial(hidden_features=8, output_features=4, activation_fn=relu)
y, params = module.init(random.PRNGKey(0), x)

The params returned by init have a nested structure of lists, tuples, dicts, and other types that can contain arrays; we call such a structure a pytree. When we compose Modules as in our example, params is a structure of nested dictionaries. We can use jax.tree_map to apply a function to each leaf of a pytree, e.g., to reveal the params structure of the MLP model (recall that x.shape = (1, 2)).

In [ ]:
jax.tree_map(np.shape, params)
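
To see what mapping over a pytree does without needing a Module, here is a small sketch on a hand-written nested dictionary. The keys and shapes are hypothetical, and the cell uses the jax.tree_util spelling of the tree functions.

```python
import numpy as np
import jax
from jax import numpy as jnp

# hypothetical parameters with the nested-dict layout of a two-layer MLP
params = {
    'hidden': {'kernel': jnp.zeros((2, 8)), 'bias': jnp.zeros((8,))},
    'out': {'kernel': jnp.zeros((8, 4)), 'bias': jnp.zeros((4,))},
}

# replace every array leaf by its shape; the dict structure is preserved
shapes = jax.tree_util.tree_map(np.shape, params)

# the leaves can also be listed directly, e.g. to count parameters
num_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))
```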

Module name

By default Flax will use integers as keys for the parameters of sub Modules. By passing the name argument we can control the parameter structure and make it more meaningful.

In [ ]:
class NamedMLP(nn.Module):
  def apply(self, x,
            hidden_features,
            output_features,
            activation_fn):
    z = Dense(x, hidden_features, name='hidden')
    h = activation_fn(z)
    y = Dense(h, output_features, name='out')
    return y

module = NamedMLP.partial(hidden_features=8, output_features=4, activation_fn=relu)
_, params = module.init(random.PRNGKey(0), x)
jax.tree_map(np.shape, params)

Parameter sharing

Sometimes a Module should be applied to multiple inputs with one set of parameters. We can make a Module for which parameters are shared between calls using Module.shared. Just like with Module.partial we can pass keyword arguments that are fixed for each call to the Module.

In [ ]:
class SimpleRNN(nn.Module):
  def apply(self, x, iterations=3):
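    # `features` must match the input width so the output can be fed back in
    dense = Dense.shared(features=x.shape[-1], name='cell')
    ys = []
    for i in range(iterations):
      x = dense(x)
      ys.append(x)
    return ys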

We call the Dense layer named 'cell' 3 times, but only one set of parameters shows up in the parameter structure due to weight sharing.

In [ ]:
ys, params = SimpleRNN.init(random.PRNGKey(0), x)
jax.tree_map(np.shape, params)
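
In plain JAX, weight sharing amounts to passing the same parameters to every application. The sketch below mirrors the SimpleRNN above with a hand-written apply function; the parameter values are arbitrary and chosen so the outputs are easy to follow.

```python
import jax
from jax import numpy as jnp

def dense_apply(params, x):
  return jnp.dot(x, params['kernel']) + params['bias']

def simple_rnn(params, x, iterations=3):
  ys = []
  for _ in range(iterations):
    x = dense_apply(params['cell'], x)  # same parameters on every iteration
    ys.append(x)
  return ys

params = {'cell': {'kernel': 0.5 * jnp.eye(2), 'bias': jnp.zeros((2,))}}
x = jnp.ones((1, 2))
ys = simple_rnn(params, x)

# the gradient has one entry per shared parameter, accumulated over all uses
grads = jax.grad(lambda p: jnp.sum(simple_rnn(p, x)[-1]))(params)
```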

Shape inference

Previously we initialized the model by passing in some inputs. This is useful because it allows for Modules which automatically infer the shape of parameters based on inputs. It can also help catch errors in the model early, in the initialization phase of a program.

Nonetheless, Module.init includes some unnecessary overhead, because typically we are not interested in the actual output of the model during initialization. Therefore, we can use JAX's built-in lazy evaluation to get the benefits of shape inference without doing any unnecessary compute.

Module.init_by_shape returns only the shape and dtype of outputs but still creates fully initialized parameters. If you want to use initializers that (indirectly) depend on the values (not shape) of the inputs you should keep using Module.init.

In [ ]:
input_spec = [(1, 2)] # the input specification is a list of shape tuples
out_spec, params = SimpleRNN.init_by_shape(random.PRNGKey(0), input_spec)
# TODO: uncomment this line  once __repr__ is fixed in jax
# print('out_spec:', out_spec)
jax.tree_map(np.shape, params)
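
One JAX mechanism for this kind of shape-only evaluation is jax.eval_shape. The sketch below applies it to the plain-JAX dense layer; it illustrates the idea and makes no claim about how init_by_shape is implemented internally.

```python
import jax
from jax import numpy as jnp

def dense_apply(params, inputs):
  kernel, bias = params
  return jnp.dot(inputs, kernel) + bias

params = (jnp.ones((2, 4)), jnp.zeros((4,)))
# describe the input by shape and dtype only; no input values are needed
x_spec = jax.ShapeDtypeStruct((1, 2), jnp.float32)

# traces dense_apply abstractly and returns the output's shape and dtype
out_spec = jax.eval_shape(dense_apply, params, x_spec)
```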


Module makes it easy to keep track of parameters inside a model, but so far we were still required to explicitly keep track of the parameter structure and the init & call functions.

Model is a thin abstraction around a Module and a set of parameters. A Model instance is callable and functional (e.g., changing parameters requires a new model instance).

Using Module.init or Module.init_by_shape will create a newly initialized set of parameters. Then you can wrap the module and the initialized parameters in a Model instance.

In [ ]:
x = jnp.ones((1, 2))
module = Dense.partial(features=4)
ys, initial_params = module.init(random.PRNGKey(0), x)
model = nn.Model(module, initial_params)
jax.tree_map(np.shape, model.params)

In [ ]:
model(x)  # a Model instance can be called directly on inputs

Parameters can be updated using the Model.replace method:

In [ ]:
biased_model = model.replace(params={'kernel': model.params['kernel'], 'bias': model.params['bias'] + 1.})

Model is registered as a JAX pytree container object which means that it can be passed to JAX transformations and jax.tree_map.

For example we can take gradients w.r.t. a Model object. The returned Model object will contain the gradients corresponding to each parameter.

In [ ]:
def loss_fn(model):
  y = model(x)
  return jnp.mean(y ** 2)

model_grad = jax.grad(loss_fn)(model)
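
To illustrate what being registered as a pytree container means, here is a sketch of a hypothetical TinyModel class that bundles an apply function with parameters. It is a stand-in written for this example, not the nn.Model implementation.

```python
import jax
from jax import numpy as jnp
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class TinyModel:
  """Hypothetical stand-in for a Model: an apply function plus parameters."""

  def __init__(self, apply_fn, params):
    self.apply_fn = apply_fn
    self.params = params

  def __call__(self, x):
    return self.apply_fn(self.params, x)

  def tree_flatten(self):
    # children (seen by JAX transformations), then static auxiliary data
    return (self.params,), self.apply_fn

  @classmethod
  def tree_unflatten(cls, apply_fn, children):
    return cls(apply_fn, children[0])

def dense_apply(params, x):
  return jnp.dot(x, params['kernel']) + params['bias']

model = TinyModel(dense_apply,
                  {'kernel': jnp.ones((2, 4)), 'bias': jnp.zeros((4,))})
x = jnp.ones((1, 2))

def loss_fn(model):
  return jnp.mean(model(x) ** 2)

# the result is again a TinyModel, holding gradients instead of parameters
model_grad = jax.grad(loss_fn)(model)
```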


Flax allows stateful operations to happen within a limited scope.

Stateful Modules are defined using the Module.state api. It returns a state object that has a property value that can be assigned to.

A typical use of a stateful Module is BatchNorm, which maintains a moving average of batch statistics (mean, variance). During training the moving averages are updated so that they can be used at test time.

In [ ]:
# simplified version of nn.BatchNorm
class BatchNorm(nn.Module):
  def apply(self, x, red_axis=0, eps=1e-5,
            momentum=0.99, training=False,
            gamma_init=nn.initializers.ones,
            beta_init=nn.initializers.zeros):
    # compute the moments of the input
    mean = x.mean(red_axis, keepdims=True)
    var = jnp.square(x - mean).mean(red_axis, keepdims=True)

    # define the state variables
    ra_mean = self.state('mean', mean.shape, nn.initializers.zeros)
    ra_var = self.state('var', var.shape, nn.initializers.ones)

    if not self.is_initializing():  # during init we ignore the moving averages completely
      if training:
        # during training the moving averages are updated
        alpha = 1. - momentum
        ra_mean.value += alpha * (mean - ra_mean.value)
        ra_var.value += alpha * (var - ra_var.value)
      else:
        # if we are not training we use the moving averages
        mean = ra_mean.value
        var = ra_var.value

    # standardize the input
    y = (x - mean) / jnp.sqrt(var + eps)

    # learn the scale and bias of the output
    gamma = self.param('gamma', mean.shape, gamma_init)
    beta = self.param('beta', mean.shape, beta_init)
    return gamma * y + beta
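
The moving-average arithmetic in the training branch can be checked in isolation. The sketch below writes it as a pure function from an old state to a new state; the helper name update_moving_average and the dict layout are ours.

```python
from jax import numpy as jnp

def update_moving_average(state, batch_mean, batch_var, momentum=0.99):
  # move each running statistic a fraction alpha towards the batch statistic
  alpha = 1. - momentum
  return {
      'mean': state['mean'] + alpha * (batch_mean - state['mean']),
      'var': state['var'] + alpha * (batch_var - state['var']),
  }

x = jnp.array([[1., 2.], [3., 6.]])
batch_mean = x.mean(0, keepdims=True)
batch_var = jnp.square(x - batch_mean).mean(0, keepdims=True)

# initial state: zeros for the mean and ones for the variance
state = {'mean': jnp.zeros((1, 2)), 'var': jnp.ones((1, 2))}
slow_state = update_moving_average(state, batch_mean, batch_var)
fast_state = update_moving_average(state, batch_mean, batch_var, momentum=0.)
```

With momentum=0. the running statistics jump straight to the batch statistics after a single update.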

Stateful modules require special care when used. The nn.stateful context manager defines a scope in which stateful operations are allowed. Outside of this scope the state becomes immutable.

The state is stored in a nn.Collection object which internally stores the state as a dictionary.

nn.stateful takes a Collection containing the current state and returns a new Collection that contains the updated state. By default a new Collection will be created.

When using nn.stateful(state, mutable=False) the state can be read but any updates will raise an error. This is often useful during test time to guarantee that test data does not affect the model.

In [ ]:
class MyModel(nn.Module):

  def apply(self, x, training=False):
    x = Dense(x, features=4)
    x = BatchNorm(x, training=training, momentum=0., name='batch_norm')
    return x

dist_a = lambda rng, shape: random.normal(rng, shape) * jnp.array([[1., 3.]])

x_a = dist_a(random.PRNGKey(1), (1024, 2))
print('std. deviation of input:', x_a.std(0))

with nn.stateful() as init_state:
  y, params = MyModel.init(random.PRNGKey(2), x_a)
print('std. deviation of output (init):', y.std(0))

with nn.stateful(init_state) as new_state:
  y = MyModel.call(params, x_a, training=True)
print('std. deviation of output (training):', y.std(0))

with nn.stateful(new_state, mutable=False):
  y = MyModel.call(params, x_a, training=False)
print('std. deviation of output (testing):', y.std(0))

The state can be inspected using Collection.as_dict().

Each Module has a path-like key into the Collection (e.g., '/some_module/nested_module/dense').

In [ ]:
new_state.as_dict()

The stateful mechanism forces the user to be explicit about stateful operations.

One motivating example for this approach is to enforce that state is not updated at test time.

Another benefit is that it is easier to replace the state when necessary. For example, let's say we want to apply this model to a second input distribution (b) with different statistics.

In [ ]:
dist_b = lambda rng, shape: random.normal(rng, shape) * jnp.array([[2., 5.]])

x_b = dist_b(random.PRNGKey(1), (1024, 2))

with nn.stateful(new_state, mutable=False):
  y = MyModel.call(params, x_b, training=False)
print(y.std(0)) # this will not be properly normalized!

We can solve the skew in statistics by creating a separate state for this alternative input distribution.

In [ ]:
with nn.stateful(init_state) as state_b:
  y = MyModel.call(params, x_b, training=True)
print('std. deviation of output (training):', y.std(0))

with nn.stateful(state_b, mutable=False):
  y = MyModel.call(params, x_b, training=False)
print('std. deviation of output (testing):', y.std(0))


The flax.optim package contains a simple api for optimizing a set of parameters using gradient descent algorithms.

To illustrate the optimizer api let's first define a simple linear regression problem:

In [ ]:
rng = random.PRNGKey(0)
rng, key1, key2 = random.split(rng, 3)
n = 30
x = jnp.linspace(-5., 5.)
X = random.uniform(key1, (n,), minval=-5., maxval=5.)
f = lambda x: 2. * x
Y = f(X) + random.normal(key2, (n,))
plt.plot(x, f(x))
plt.scatter(X, Y)

The model is nothing more than a Dense module with a single feature.

In [ ]:
class LinearRegression(nn.Module):
  def apply(self, x):
    # add a singleton dimension to the input and remove the singleton feature dim from the output
    return nn.Dense(x[..., None], features=1)[..., 0]

rng, key = random.split(rng)
_, initial_params = LinearRegression.init(key, X)
model = nn.Model(LinearRegression, initial_params)

# plot the data together with the line used to generate the data (blue) and the untrained model (orange)
plt.plot(x, f(x))
plt.plot(x, model(x))
plt.scatter(X, Y)

We will use gradient descent with momentum to fit the model to the data. Each optimizer inherits from flax.optim.OptimizerDef. The OptimizerDef class provides init_state and apply_gradient methods, which initialize and update the optimizer state, respectively.

OptimizerDef does not actually maintain the state and optimized parameters. It is simply a collection of functions.

When calling OptimizerDef.create, the optimization target and the optimizer state (e.g., a gradient moving average) are wrapped together with the OptimizerDef in an instance of Optimizer.

Optimizers can optimize any pytree of arrays (nested dicts, Model instances, etc.), as long as the gradient w.r.t. that pytree is computable by jax.grad.

In [ ]:
optimizer_def = optim.Momentum(learning_rate=0.1, beta=0.9)
optimizer = optimizer_def.create(model)
train_steps = 100

def loss_fn(model):
  Y_hat = model(X)
  return jnp.square(Y - Y_hat).mean()

for i in range(train_steps):
  # optimizer.target is passed to the loss_fn
  loss, grad = jax.value_and_grad(loss_fn)(optimizer.target)
  # `apply_gradient` returns a new `Optimizer` instance with the updated target and optimizer state.
  optimizer = optimizer.apply_gradient(grad)
print('mean square error:', loss)

trained_model = optimizer.target
plt.plot(x, f(x))
plt.plot(x, trained_model(x))
plt.scatter(X, Y)
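
For intuition about what a momentum optimizer does to a pytree of parameters, here is a plain-JAX sketch on a noise-free version of the regression problem. It uses one common formulation of momentum (velocity = beta * velocity + grad), which we assume for illustration; it is not claimed to match flax.optim.Momentum in every detail.

```python
import jax
from jax import numpy as jnp

X = jnp.linspace(-5., 5., 30)
Y = 2. * X  # noise-free targets, so the optimum is w = 2, b = 0

def loss_fn(params):
  Y_hat = params['w'] * X + params['b']
  return jnp.square(Y - Y_hat).mean()

def momentum_step(params, velocity, learning_rate=0.1, beta=0.9):
  loss, grads = jax.value_and_grad(loss_fn)(params)
  # the velocity is optimizer state with the same structure as the parameters
  velocity = jax.tree_util.tree_map(
      lambda v, g: beta * v + g, velocity, grads)
  params = jax.tree_util.tree_map(
      lambda p, v: p - learning_rate * v, params, velocity)
  return params, velocity, loss

params = {'w': jnp.zeros(()), 'b': jnp.zeros(())}
velocity = jax.tree_util.tree_map(jnp.zeros_like, params)
initial_loss = loss_fn(params)
for _ in range(100):
  params, velocity, loss = momentum_step(params, velocity)
final_loss = loss_fn(params)
```

As with Optimizer.apply_gradient, each step returns new parameters and new optimizer state instead of mutating the old ones.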


The flax.serialization module provides utilities for extracting the state of optimizers, models, and other structures as a dictionary of arrays. It also provides an integration with MessagePack, which can be used to efficiently serialize the state dictionary in a binary, cross-platform compatible format.

The flax.struct.dataclass decorator can be used to create a Python dataclass which can be passed into jax transformations like jit, grad, and tree_map. It also integrates with the state dict api.

In [ ]:
# Model and Optimizer are also Flax dataclasses.
@flax.struct.dataclass
class TrainState:
  optimizer: optim.Optimizer
  step: int

state = TrainState(optimizer=optimizer, step=5)

Flax dataclasses are immutable. Using the replace method a new instance can be created with a set of updated fields.

In [ ]:
new_state = state.replace(step=6)
print(state.step, new_state.step)
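
The same immutable-update pattern exists in the Python standard library, which may help to see what replace does. The sketch below uses a frozen dataclass with hypothetical fields; unlike flax.struct.dataclass, it is not registered with JAX as a pytree.

```python
import dataclasses

@dataclasses.dataclass(frozen=True)
class TrainState:
  step: int
  learning_rate: float

state = TrainState(step=5, learning_rate=0.1)
# build a new instance with one field changed; `state` itself is untouched
new_state = dataclasses.replace(state, step=6)

try:
  state.step = 7  # frozen dataclasses reject attribute assignment
  mutated = True
except dataclasses.FrozenInstanceError:
  mutated = False
```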

The to_bytes and from_bytes functions are used to convert an object to the MessagePack format and back.

In [ ]:
data = flax.serialization.to_bytes(state)
print('num bytes:', len(data))

In [ ]:
corrupted_state = jax.tree_map(lambda x: 0 * x, state)

In [ ]:
# Restore the state from the serialized state dict stored in data
restored_state = flax.serialization.from_bytes(corrupted_state, data)

Advanced features

Selective Optimization

Sometimes we wish to apply a different optimizer to a subset of the model parameters. To illustrate, we will apply weight decay to the bias value of the linear model that was trained before.

flax.optim.MultiOptimizer takes tuples of traversals and OptimizerDef instances. The traversal is responsible for selecting the subset of parameters that should be optimized. We will use flax.optim.ModelParamTraversal, which allows you to filter parameters based on their path (e.g., '/hidden/dense/kernel').

In [ ]:
slope_opt_def = optim.Momentum(learning_rate=0.1)
# by applying decay to the bias parameter it will end up being closer to zero than before
bias_opt_def = optim.Momentum(learning_rate=0.1, weight_decay=10.)
# select all kernel parameters
slope_traversal = optim.ModelParamTraversal(lambda path, param: 'kernel' in path)
# select all bias parameters
bias_traversal = optim.ModelParamTraversal(lambda path, param: 'bias' in path)
optimizer_def = optim.MultiOptimizer((slope_traversal, slope_opt_def), (bias_traversal, bias_opt_def))

_, initial_params = LinearRegression.init(random.PRNGKey(0), X)
model = nn.Model(LinearRegression, initial_params)
optimizer = optimizer_def.create(model)

train_steps = 100

def loss_fn(model):
  Y_hat = model(X)
  return jnp.square(Y - Y_hat).mean()

for i in range(train_steps):
  loss, grad = jax.value_and_grad(loss_fn)(optimizer.target)
  optimizer = optimizer.apply_gradient(grad)
print('mean square error:', loss)

trained_model = optimizer.target
plt.plot(x, f(x))
plt.plot(x, trained_model(x))
plt.scatter(X, Y)
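
The idea of selecting parameters by path can be sketched without the optimizer classes. Below, a flat dict keyed by hypothetical paths gets a plain SGD update in which only the paths containing 'bias' receive weight decay. We assume the convention that weight decay adds decay * param to the gradient.

```python
from jax import numpy as jnp

def sgd_step(params, grads, learning_rate, decay_for):
  """SGD step where decay_for(path) gives the weight decay for that path."""
  new_params = {}
  for path, param in params.items():
    # assumed convention: weight decay adds decay * param to the gradient
    grad = grads[path] + decay_for(path) * param
    new_params[path] = param - learning_rate * grad
  return new_params

# hypothetical paths and values
params = {'/dense/kernel': jnp.array(2.), '/dense/bias': jnp.array(1.)}
grads = {'/dense/kernel': jnp.array(0.), '/dense/bias': jnp.array(0.)}

# decay only the parameters whose path contains 'bias'
decay_for = lambda path: 10. if 'bias' in path else 0.
new_params = sgd_step(params, grads, learning_rate=0.01, decay_for=decay_for)
```

With zero gradients, the kernel stays where it is while the bias shrinks towards zero, which is the effect the bias optimizer above is meant to have.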

Multi method modules

In [ ]:
class MultiMethodModule(nn.Module):

  def apply(self, x):
    kernel = self.param('kernel', (), lambda _, shape: jnp.full(shape, 2.))
    return x * kernel

  def decode(self, x):
    kernel = self.get_param('kernel')
    return x * kernel

x = 2. ** jnp.arange(5)
y, initial_params = MultiMethodModule.init(random.PRNGKey(0), x)
model = nn.Model(MultiMethodModule, initial_params)
print('target:', x[1:], 'teacher forced decoding:', y[:-1])
print('sequential decoding (one step):', model.decode(1.))

In [ ]:
def body_fn(carry, _):
  y = model.decode(carry)
  new_carry = y  # feed output back
  return new_carry, y

carry, ys = lax.scan(body_fn, 1., (), length=4)
print('carry:', carry)
print('sequential decoding:', ys)
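
For readers new to lax.scan, the following sketch runs the same kind of feedback loop with a fixed scalar kernel instead of a Model and compares it with an ordinary Python loop. The kernel value of 2 matches the initializer used above.

```python
from jax import numpy as jnp, lax

kernel = jnp.full((), 2.)

def body_fn(carry, _):
  y = carry * kernel
  return y, y  # new carry, per-step output

# no per-step inputs (xs=None), so the number of steps is given by `length`
carry, ys = lax.scan(body_fn, jnp.ones(()), None, length=4)

# the equivalent Python loop
loop_carry, loop_ys = jnp.ones(()), []
for _ in range(4):
  loop_carry = loop_carry * kernel
  loop_ys.append(loop_carry)
```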

Transforming sub module parameters

In [ ]:
def add_scale(module):
  class ScaleWrapper(nn.Module):
    """Add a learnable scale to the kernel of a module."""

    def apply(self, *args, **kwargs):
      def init_fn(rng, _):
        _, params = module.init(rng, *args, **kwargs)
        # here we could change the initial parameters of the wrapped module
        return params
      params = self.param('params', None, init_fn)
      # here we can transform the parameters on every call
      assert 'kernel' in params
      kernel = params['kernel']
      features = kernel.shape[-1]
      scale = self.param('scale', (features,), nn.initializers.ones)
      scaled_kernel = kernel * scale
      scaled_params = params.copy()
      scaled_params['kernel'] = scaled_kernel

      return module.call(scaled_params, *args, **kwargs)
  return ScaleWrapper

x = jnp.ones((1, 2))
module = add_scale(Dense).partial(features=4)
y, params = module.init(random.PRNGKey(0), x)
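
Stripped of the Module machinery, the wrapper does the following: copy the wrapped parameters, rescale the kernel, and call the original apply function. The sketch below uses a hand-written dense_apply and hypothetical parameter values.

```python
from jax import numpy as jnp

def dense_apply(params, x):
  return jnp.dot(x, params['kernel']) + params['bias']

def scaled_dense_apply(wrapper_params, x):
  inner = wrapper_params['params']
  scaled = dict(inner)  # shallow copy, so the stored parameters are untouched
  scaled['kernel'] = inner['kernel'] * wrapper_params['scale']
  return dense_apply(scaled, x)

wrapper_params = {
    'params': {'kernel': jnp.ones((2, 4)), 'bias': jnp.zeros((4,))},
    'scale': jnp.array([1., 2., 3., 4.]),  # one scale per output feature
}
y = scaled_dense_apply(wrapper_params, jnp.ones((1, 2)))
```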
