In [ ]:
    
import numpy as np
import matplotlib.pyplot as plt
import jax
from jax import numpy as jnp, random, jit, lax
import flax
from flax import nn, optim
# run a trivial computation to initialize JAX.
# JAX might complain about not having access to a GPU or TPU.
_ = jnp.square(2.)
    
JAX is a numerical computation library that aims to replicate the NumPy API.
A few important things to know about JAX:
Functions written with the jax.numpy API can be traced and transformed automatically (e.g. differentiated or jit-compiled), as the sketch below shows.
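Here f is just a stand-in function for illustration, not part of the Flax example:
In [ ]:
    
# a small function written with the jax.numpy API
def f(w):
  return jnp.sum(w ** 2)
# tracing lets JAX transform it: differentiate...
print(jax.grad(f)(3.0))             # 2 * w = 6.0
# ...or compile it with XLA
print(jax.jit(f)(jnp.arange(3.0)))  # 0^2 + 1^2 + 2^2 = 5.0
    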
Before we dive into Flax, let's look at what a typical neural network component looks like when written in "native" JAX.
We decompose a learnable linear layer into two parts: an initializer function, which uses a JAX PRNGKey to generate a random kernel and bias, and an apply function, which computes the linear transformation from a set of parameters and some inputs.
In [ ]:
    
def dense_init(rng, in_features, out_features,
               kernel_init=jax.nn.initializers.lecun_normal(),
               bias_init=jax.nn.initializers.zeros):
  k1, k2 = random.split(rng)
  # init functions take a PRNGKey and a shape tuple and return ndarrays.
  kernel = kernel_init(k1, (in_features, out_features))
  bias = bias_init(k2, (out_features,))
  return kernel, bias
def dense_apply(params, inputs):
  kernel, bias = params
  return jnp.dot(inputs, kernel) + bias
    
Functional programming without abstractions naturally results in somewhat verbose but very explicit code.
Note how the random number generator keys and the parameters are passed to the functions explicitly. JAX has no concept of variables, so we cannot hide the parameters away somewhere. Similarly, there is no global random number generator that updates an internal seed; the sketch below illustrates this.
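A minimal sketch of JAX's explicit PRNG handling: the same key always yields the same values, and fresh randomness requires explicitly splitting the key.
In [ ]:
    
key = random.PRNGKey(42)
# the same key always produces the same values; there is no hidden internal seed
print(random.normal(key, (2,)))
print(random.normal(key, (2,)))  # identical to the line above
# new randomness requires explicitly splitting the key
key, subkey = random.split(key)
print(random.normal(subkey, (2,)))  # different values
    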
In [ ]:
    
params = dense_init(random.PRNGKey(0), in_features=2, out_features=4)
print(params)
    
Once we have generated a set of parameters, it is easy enough to apply them to some inputs.
In [ ]:
    
x = jnp.ones((1, 2))
dense_apply(params, x)
    
Because everything is functional, we can use the functional transformations that JAX provides to do useful things, such as taking gradients to optimize the model.
In [ ]:
    
def loss_fn(params, x):
  y = dense_apply(params, x)
  return jnp.mean(y ** 2)
grad_fn = jax.grad(loss_fn) # by default jax.grad takes the gradient w.r.t. the first argument
grad_fn(params, x)
    
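These transformations compose. For example (a minimal sketch, not required for the rest of the notebook), the gradient function can itself be jit-compiled:
In [ ]:
    
# compile the gradient computation with XLA; the result matches grad_fn(params, x)
fast_grad_fn = jax.jit(jax.grad(loss_fn))
fast_grad_fn(params, x)
    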
The core of Flax is the Module abstraction. Modules allow you to write parameterized functions just as if you were writing a normal numpy function with JAX. The Module API allows you to declare parameters and use them directly with the JAX APIs.
A few things to know about Modules:
A Module is created by subclassing flax.nn.Module and implementing the apply method.
Parameters are declared using self.param(name, shape, init_func), which returns an initialized parameter value.
Dense.init(rng, ...) and Dense.call(params, ...) behave identically to the dense_init and dense_apply functions implemented earlier.
Now let's redefine the dense layer using Flax Modules.
In [ ]:
    
class Dense(nn.Module):
  """A learned linear transformation."""
  def apply(self, x, features,
            # init functions are of the form (PrngKey, shape) => init_value
            kernel_init=jax.nn.initializers.lecun_normal(),
            bias_init=jax.nn.initializers.zeros):
    """The main entry point to a Module. Represents the function that
    given inputs and hyper-parameters computes an output. The actual parameters
    (inputs and parameters) are user-controlled, and depend on the actual Module functionality.
    For this example:
    
      * `x`: the input, an array of shape `(in_features)`.
      * `features`: the number of outputs, an integer.
      * `kernel_init`: the initializer for the kernel.
      * `bias_init`: the initializer for the biases.
    """
    in_features = x.shape[-1]
    kernel_shape = (in_features, features)
    kernel = self.param('kernel', kernel_shape, kernel_init)
    bias = self.param('bias', (features,), bias_init)
    return jnp.dot(x, kernel) + bias
    
In [ ]:
    
y, params = Dense.init(random.PRNGKey(0), x, features=4)
print(params)
    
In [ ]:
    
Dense.call(params, x, features=4)
    
Note that both init and call end up using the same apply function. That is why we must specify all the inputs
and hyper-parameters (here, the number of features) in both init and call. Often the hyper-parameters are the same for each
call to init and call.
For these situations, we can use Module.partial to pre-apply these arguments. partial takes keyword arguments
and returns a new Module for which the given arguments are already applied.
It can be thought of as the equivalent of functools.partial for Modules.
In [ ]:
    
module = Dense.partial(features=4) # Module with concrete hyper parameters, ready to be initialized
_, params = module.init(random.PRNGKey(0), x)
module.call(params, x)
    
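Modules compose simply by calling them inside apply: each call to Dense below creates a sub-module whose parameters are nested under the parent's parameters.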
In [ ]:
    
# same as flax.nn.relu
def relu(x):
  return jnp.maximum(0., x)
class MLP(nn.Module):
  """Multi Layer Perceptron."""
  
  def apply(self, x,
            hidden_features,
            output_features,
            activation_fn):
    z = Dense(x, hidden_features)
    h = activation_fn(z)
    y = Dense(h, output_features)
    return y
module = MLP.partial(hidden_features=8, output_features=4, activation_fn=relu)
y, params = module.init(random.PRNGKey(0), x)
print(y)
    
The params returned by init have a nested structure of lists, tuples, dicts,
and other types that can contain arrays; we call such a structure a pytree.
When we compose Modules as in our example, params is a structure of nested dictionaries.
We can use jax.tree_map to apply a function to each leaf of a pytree, e.g., to reveal
the parameter structure of the MLP model (recall that x.shape = (1, 2)).
In [ ]:
    
jax.tree_map(np.shape, params)
    
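By default, sub-modules receive automatically generated names. Passing the name argument gives them explicit, readable names in the parameter structure.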
In [ ]:
    
class NamedMLP(nn.Module):
  def apply(self, x,
            hidden_features,
            output_features,
            activation_fn):
    z = Dense(x, hidden_features, name='hidden')
    h = activation_fn(z)
    y = Dense(h, output_features, name='out')
    return y
module = NamedMLP.partial(hidden_features=8, output_features=4, activation_fn=relu)
_, params = module.init(random.PRNGKey(0), x)
jax.tree_map(np.shape, params)
    
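Parameters can also be shared: Module.shared returns a Module that creates its parameters once and reuses them on every call. Here we use it to build a simple RNN cell that is applied repeatedly to its own output.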
In [ ]:
    
class SimpleRNN(nn.Module):
  def apply(self, x, iterations=3):
    dense = Dense.shared(
        features=x.shape[-1],
        kernel_init=jax.nn.initializers.orthogonal(),
        name='cell')
    ys = []
    for i in range(iterations):
      x = dense(x)
      ys.append(x)
    return ys
    
We call the Dense layer named 'cell' three times, but only one set of parameters shows up in the parameter structure, due to weight sharing.
In [ ]:
    
ys, params = SimpleRNN.init(random.PRNGKey(0), x)
print(ys)
jax.tree_map(np.shape, params)
    
Previously we initialized the model by passing in some inputs. This is useful because it allows Modules to automatically infer the shapes of parameters based on the inputs. It can also help catch errors in the model early, in the initialization phase of a program.
Nonetheless, Module.init incurs some unnecessary overhead, because typically we are not interested in the actual output of the model during initialization. Therefore, we can use JAX's built-in lazy evaluation to get the benefits of shape inference without doing any unnecessary compute.
Module.init_by_shape returns only the shape and dtype of the outputs but still creates fully initialized parameters. If you want to use initializers that (indirectly) depend on the values (not just the shapes) of the inputs, you should keep using Module.init.
In [ ]:
    
input_spec = [(1, 2)] # the input specification is a list of shape tuples
out_spec, params = SimpleRNN.init_by_shape(random.PRNGKey(0), input_spec)
# TODO: uncomment this line  once __repr__ is fixed in jax
# print('out_spec:', out_spec)
jax.tree_map(np.shape, params)
    
Module makes it easy to keep track of parameters inside a Module, but so far we still had to explicitly keep track of the parameter structure and the init & call functions.
Model is a thin abstraction around a Module and a set of parameters. A Model instance is callable and functional (e.g., changing parameters requires a new Model instance).
Module.init or Module.init_by_shape creates a newly initialized set of parameters. You can then wrap the module and the initialized parameters in a Model instance.
In [ ]:
    
x = jnp.ones((1, 2))
module = Dense.partial(features=4)
ys, initial_params = module.init(random.PRNGKey(0), x)
model = nn.Model(module, initial_params)
jax.tree_map(np.shape, model.params)
    
In [ ]:
    
model(x)
    
In [ ]:
    
model.params
    
Parameters can be updated using the Model.replace method.
In [ ]:
    
biased_model = model.replace(params={'kernel': model.params['kernel'], 'bias': model.params['bias'] + 1.})
biased_model.params
    
Model is registered as a JAX pytree container object, which means that it can be passed to JAX transformations and to jax.tree_map.
For example, we can take gradients w.r.t. a Model object. The returned Model object contains the gradient corresponding to each parameter.
In [ ]:
    
def loss_fn(model):
  y = model(x)
  return jnp.mean(y ** 2)
model_grad = jax.grad(loss_fn)(model)
model_grad.params
    
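Because model and model_grad are pytrees with the same structure, jax.tree_map can combine them leaf by leaf, for example to take a single hand-written gradient-descent step (a minimal sketch; the step size 0.1 is arbitrary):
In [ ]:
    
# combine parameters and gradients leaf by leaf; the result is again a Model
updated_model = jax.tree_map(lambda p, g: p - 0.1 * g, model, model_grad)
jax.tree_map(np.shape, updated_model.params)
    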
Flax allows stateful operations to happen within a limited scope.
Stateful Modules are defined using the Module.state API, which returns a state object whose value property can be read and assigned to.
A typical use of a stateful Module is BatchNorm, which maintains moving averages of batch statistics (mean, variance). During training the moving averages are updated so that they can be used at test time.
In [ ]:
    
# simplified version of nn.BatchNorm
class BatchNorm(nn.Module):
  def apply(self, x, red_axis=0, eps=1e-5,
            momentum=0.99, training=False,
            gamma_init=nn.initializers.ones,
            beta_init=nn.initializers.zeros):
    # compute the moments of the input
    mean = x.mean(red_axis, keepdims=True)
    var = jnp.square(x - mean).mean(red_axis, keepdims=True)
    # define the state variables
    ra_mean = self.state('mean', mean.shape, nn.initializers.zeros)
    ra_var = self.state('var', var.shape, nn.initializers.ones)
    if not self.is_initializing():  # during init we ignore the moving averages completely
      if training:
        # during training the moving averages are updated
        alpha = 1. - momentum
        ra_mean.value += alpha * (mean - ra_mean.value)
        ra_var.value += alpha * (var - ra_var.value)
      else:
        # if we are not training we use the moving averages
        mean = ra_mean.value
        var = ra_var.value
    # standardize the input
    y = (x - mean) / jnp.sqrt(var + eps)
    # learn the scale and bias of the output
    gamma = self.param('gamma', mean.shape, gamma_init)
    beta = self.param('beta', mean.shape, beta_init)
    return gamma * y + beta
    
Stateful Modules require special care when used. The nn.stateful context manager defines a scope in which stateful operations are allowed. Outside of this scope the state becomes immutable.
The state is stored in an nn.Collection object, which internally stores the state as a dictionary.
nn.stateful takes a Collection containing the current state and returns a new Collection containing the updated state. By default a new, empty Collection is created.
When using nn.stateful(state, mutable=False), the state can be read but any update will raise an error. This is often useful at test time, to guarantee that test data does not affect the model.
In [ ]:
    
class MyModel(nn.Module):
  def apply(self, x, training=False):
    x = Dense(x, features=4)
    x = BatchNorm(x, training=training, momentum=0., name='batch_norm')
    return x
dist_a = lambda rng, shape: random.normal(rng, shape) * jnp.array([[1., 3.]])
x_a = dist_a(random.PRNGKey(1), (1024, 2))
print('std. deviation of input:', x_a.std(0))
with nn.stateful() as init_state:
  y, params = MyModel.init(random.PRNGKey(2), x_a)
print('std. deviation of output (init):', y.std(0))
with nn.stateful(init_state) as new_state:
  y = MyModel.call(params, x_a, training=True)
print('std. deviation of output (training):', y.std(0))
with nn.stateful(new_state, mutable=False):
  y = MyModel.call(params, x_a, training=False)
print('std. deviation of output (testing):', y.std(0))
    
The state can be inspected using Collection.as_dict().
Each Module has a path-like key into the Collection (e.g. '/some_module/nested_module/dense').
In [ ]:
    
init_state.as_dict()
    
In [ ]:
    
new_state.as_dict()
    
The stateful mechanism forces the user to be explicit about stateful operations.
One motivating example for this approach is enforcing that the state is not updated at test time.
Another benefit is that it is easier to replace the state when necessary. For example, let's say we want to apply this model to a second input distribution (b) with different statistics.
In [ ]:
    
dist_b = lambda rng, shape: random.normal(rng, shape) * jnp.array([[2., 5.]])
x_b = dist_b(random.PRNGKey(1), (1024, 2))
with nn.stateful(new_state, mutable=False):
  y = MyModel.call(params, x_b, training=False)
print(y.std(0)) # this will not be properly normalized!
    
We can correct for the mismatch in statistics by creating a separate state for this alternative input distribution.
In [ ]:
    
with nn.stateful(init_state) as state_b:
  y = MyModel.call(params, x_b, training=True)
print('std. deviation of output (training):', y.std(0))
with nn.stateful(state_b, mutable=False):
  y = MyModel.call(params, x_b, training=False)
print('std. deviation of output (testing):', y.std(0))
    
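Let's now put these pieces together and train a simple model end to end. First we generate a toy linear-regression dataset: noisy observations of the line y = 2x.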
In [ ]:
    
rng = random.PRNGKey(0)
rng, key1, key2 = random.split(rng, 3)
n = 30
x = jnp.linspace(-5., 5.)
X = random.uniform(key1, (n,), minval=-5., maxval=5.)
f = lambda x: 2. * x
Y = f(X) + random.normal(key2, (n,))
plt.plot(x, f(x))
plt.scatter(X, Y)
plt.show()
    
The model is nothing more than a Dense module with a single output feature.
In [ ]:
    
class LinearRegression(nn.Module):
  def apply(self, x):
    # add a singleton dimension to the input and remove the singleton feature dim from the output
    return nn.Dense(x[..., None], features=1)[..., 0]
rng, key = random.split(rng)
_, initial_params = LinearRegression.init(key, X)
model = nn.Model(LinearRegression, initial_params)
# plot the data together with the line used to generate the data (blue) and the untrained model (orange)
plt.plot(x, f(x))
plt.plot(x, model(x))
plt.scatter(X, Y)
plt.show()
    
We will use gradient descent with momentum to fit the model to the data.
Each optimizer inherits from flax.optim.OptimizerDef.
The OptimizerDef class provides the init_state and apply_gradient methods, which initialize and update the optimizer state, respectively.
OptimizerDef does not actually maintain the state and the optimized parameters; it is simply a collection of functions.
When calling OptimizerDef.create, the optimization target and the optimizer state (e.g. the gradient moving average) are wrapped together with the OptimizerDef in an instance of Optimizer.
Optimizers can optimize any pytree of arrays (nested dicts, Model instances, etc.), as long as the gradient w.r.t. the target can be computed by jax.grad.
In [ ]:
    
optimizer_def = optim.Momentum(learning_rate=0.1, beta=0.9)
optimizer = optimizer_def.create(model)
train_steps = 100
def loss_fn(model):
  Y_hat = model(X)
  return jnp.square(Y - Y_hat).mean()
for i in range(train_steps):
  # optimizer.target is passed to the loss_fn
  loss, grad = jax.value_and_grad(loss_fn)(optimizer.target)
  # `apply_gradient` returns a new `Optimizer` instance with the updated target and optimizer state.
  optimizer = optimizer.apply_gradient(grad)
print('mean square error:', loss)
trained_model = optimizer.target
print(trained_model.params)
plt.plot(x, f(x))
plt.plot(x, trained_model(x))
plt.scatter(X, Y)
plt.show()
    
The flax.serialization module provides utilities for extracting the state of optimizers, models, and other structures as a dictionary of arrays.
It also provides an integration with MessagePack, which can be used to efficiently serialize the state dictionary in a binary, cross-platform-compatible format.
The flax.struct.dataclass decorator can be used to create a Python dataclass that can be passed into JAX transformations like jit, grad, and tree_map. It also integrates with the state dict API.
In [ ]:
    
# Model and Optimizer are also Flax dataclasses.
@flax.struct.dataclass
class TrainState:
  optimizer: optim.Optimizer
  step: int
state = TrainState(optimizer=optimizer, step=5)
flax.serialization.to_state_dict(state)
    
Flax dataclasses are immutable. Using the replace method, a new instance can be created with a set of updated fields.
In [ ]:
    
new_state = state.replace(step=6)
print(state.step, new_state.step)
    
The to_bytes and from_bytes functions convert an object to the MessagePack format and back.
In [ ]:
    
data = flax.serialization.to_bytes(state)
print('num bytes:', len(data))
    
In [ ]:
    
corrupted_state = jax.tree_map(lambda x: 0 * x, state)
flax.serialization.to_state_dict(corrupted_state)
    
In [ ]:
    
# Restore the state from the serialized state dict stored in data
restored_state = flax.serialization.from_bytes(corrupted_state, data)
flax.serialization.to_state_dict(restored_state)
    
Sometimes we wish to apply a different optimizer to a subset of the model parameters. To illustrate, we will apply weight decay to the bias of the linear model that was trained before.
flax.optim.MultiOptimizer takes tuples of traversals and OptimizerDef instances. The traversal is responsible for selecting the subset of parameters that should be optimized by the corresponding optimizer. We will use flax.optim.ModelParamTraversal, which lets you filter parameters based on their path (e.g. '/hidden/dense/kernel').
In [ ]:
    
slope_opt_def = optim.Momentum(learning_rate=0.1)
# by applying decay to the bias parameter it will end up being closer to zero than before
bias_opt_def = optim.Momentum(learning_rate=0.1, weight_decay=10.)
# select all kernel parameters
slope_traversal = optim.ModelParamTraversal(lambda path, param: 'kernel' in path)
# select all bias parameters
bias_traversal = optim.ModelParamTraversal(lambda path, param: 'bias' in path)
optimizer_def = optim.MultiOptimizer((slope_traversal, slope_opt_def), (bias_traversal, bias_opt_def))
_, initial_params = LinearRegression.init(random.PRNGKey(0), X)
model = nn.Model(LinearRegression, initial_params)
optimizer = optimizer_def.create(model)
train_steps = 100
def loss_fn(model):
  Y_hat = model(X)
  return jnp.square(Y - Y_hat).mean()
for i in range(train_steps):
  loss, grad = jax.value_and_grad(loss_fn)(optimizer.target)
  optimizer = optimizer.apply_gradient(grad)
print('mean square error:', loss)
trained_model = optimizer.target
print(trained_model.params)
plt.plot(x, f(x))
plt.plot(x, trained_model(x))
plt.scatter(X, Y)
plt.show()
    
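A Module can expose additional methods besides apply by decorating them with nn.module_method. Such methods reuse the parameters declared in apply, which is handy when, for example, a decoder should be trained with teacher forcing via apply but run one step at a time at test time.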
In [ ]:
    
class MultiMethodModule(nn.Module):
  def apply(self, x):
    # declare a scalar 'kernel' parameter, initialized to 2.
    kernel = self.param('kernel', (), lambda _, shape: jnp.full(shape, 2.))
    return x * kernel

  @nn.module_method
  def decode(self, x):
    # module methods can access the parameters declared in apply
    kernel = self.get_param('kernel')
    return x * kernel
x = 2. ** jnp.arange(5)
y, initial_params = MultiMethodModule.init(random.PRNGKey(0), x)
model = nn.Model(MultiMethodModule, initial_params)
print('target:', x[1:], 'teacher forced decoding:', y[:-1])
print('sequential decoding (one step):', model.decode(1.))
    
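Because decode behaves like an ordinary function of the model, we can feed each output back in as the next input inside lax.scan to decode a whole sequence step by step.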
In [ ]:
    
def body_fn(carry, _):
  y = model.decode(carry)
  new_carry = y  # feed output back
  return new_carry, y
carry, ys = lax.scan(body_fn, 1., (), length=4)
print('carry:', carry)
print('sequential decoding:', ys)
    
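Modules can also be wrapped and transformed. The add_scale wrapper below initializes an inner module, adds a learnable per-feature scale parameter, and rescales the inner module's kernel on every call.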
In [ ]:
    
def add_scale(module):
  class ScaleWrapper(nn.Module):
    """Add a learnable scale to the kernel of a module."""
    def apply(self, *args, **kwargs):
      def init_fn(rng, _):
        _, params = module.init(rng, *args, **kwargs)
        # here we could change the initial parameters of the wrapped module
        return params
      params = self.param('params', None, init_fn)
      # here we can transform the wrapped module's parameters on every call
      assert 'kernel' in params
      kernel = params['kernel']
      features = kernel.shape[-1]
      scale = self.param('scale', (features,), nn.initializers.ones)
      scaled_kernel = kernel * scale
      scaled_params = params.copy()
      scaled_params['kernel'] = scaled_kernel
      return module.call(scaled_params, *args, **kwargs)
  return ScaleWrapper
x = jnp.ones((1, 2))
module = add_scale(Dense).partial(features=4)
y, params = module.init(random.PRNGKey(0), x)
params
    