In [ ]:
import numpy as np
import matplotlib.pyplot as plt
import jax
from jax import numpy as jnp, random, jit, lax
import flax
from flax import nn, optim
# trigger JAX backend initialization with a trivial computation.
# JAX might complain about not having access to a GPU or TPU.
_ = jnp.square(2.)
JAX is a numerical computation library that replicates the numpy API.
A few important things to know about JAX:
Functions written with the jax.numpy API can be traced for automatic transformations such as compilation (jax.jit), differentiation (jax.grad), and vectorization (jax.vmap).
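To make this concrete, here is a small illustrative cell (not part of the original walkthrough) that applies jax.jit, jax.grad, and jax.vmap to a plain jax.numpy function, reusing the imports from the first cell.
In [ ]:
# a tiny pure function written with the jax.numpy API
def f(x):
  return jnp.sum(jnp.sin(x) ** 2)

x_demo = jnp.arange(3.)
print(jax.jit(f)(x_demo))         # compiled version of f
print(jax.grad(f)(x_demo))        # gradient of f w.r.t. its input
print(jax.vmap(jnp.sin)(x_demo))  # map a function over the leading axis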
Before we dive into Flax, let's look at what a typical neural network component looks like when written in "native" JAX.
We decompose a learnable linear layer into two parts: an initializer function, which uses a JAX PRNGKey to generate a random kernel and bias, and an apply function, which computes the linear transformation given a set of parameters and some inputs.
In [ ]:
def dense_init(rng, in_features, out_features,
               kernel_init=jax.nn.initializers.lecun_normal(),
               bias_init=jax.nn.initializers.zeros):
  k1, k2 = random.split(rng)
  # init functions take a PRNGKey and a shape tuple and return ndarrays.
  kernel = kernel_init(k1, (in_features, out_features))
  bias = bias_init(k2, (out_features,))
  return kernel, bias

def dense_apply(params, inputs):
  kernel, bias = params
  return jnp.dot(inputs, kernel) + bias
Functional programming without abstractions naturally results in somewhat verbose but very explicit code.
Note how the random number generator and the parameters are passed explicitly to the functions. JAX has no concept of variables, so we cannot hide the parameters in variables somewhere. Similarly, there is no global random number generator that updates an internal seed.
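As a brief aside (an illustrative sketch, not part of the original notebook), explicit PRNG handling makes randomness reproducible: the same PRNGKey always yields the same values, and random.split derives fresh, independent keys instead of mutating a global seed.
In [ ]:
key = random.PRNGKey(42)
# the same key always produces the same sample
print(random.normal(key, (2,)))
print(random.normal(key, (2,)))
# split the key to obtain new, independent random streams
key_a, key_b = random.split(key)
print(random.normal(key_a, (2,)))
print(random.normal(key_b, (2,)))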
In [ ]:
params = dense_init(random.PRNGKey(0), in_features=2, out_features=4)
print(params)
Once we have generated a set of parameters, it is easy to apply them to some inputs.
In [ ]:
x = jnp.ones((1, 2))
dense_apply(params, x)
Because everything is functional, we can use the transformations that JAX provides to do useful things like taking gradients in order to optimize the model.
In [ ]:
def loss_fn(params, x):
  y = dense_apply(params, x)
  return jnp.mean(y ** 2)
grad_fn = jax.grad(loss_fn) # by default jax.grad takes the gradient w.r.t. the first argument
grad_fn(params, x)
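Because the gradients mirror the (kernel, bias) structure of the parameters, a plain gradient-descent update can be written by hand. The following cell is an illustrative sketch (the learning rate of 0.1 is arbitrary); later sections use flax.optim for this instead.
In [ ]:
grads = grad_fn(params, x)
kernel, bias = params
kernel_grad, bias_grad = grads
learning_rate = 0.1  # arbitrary step size for illustration
new_params = (kernel - learning_rate * kernel_grad,
              bias - learning_rate * bias_grad)
print('loss before:', loss_fn(params, x))
print('loss after: ', loss_fn(new_params, x))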
The core of Flax is the Module abstraction. Modules allow you to write parameterized functions just as if you were writing a normal numpy function with JAX. The Module API allows you to declare parameters and use them directly with the JAX APIs.
A few things to know about Modules:
- A new Module is created by subclassing flax.nn.Module and implementing the apply method.
- Parameters are declared using self.param(name, shape, init_func), which returns an initialized parameter value.
- Dense.init(rng, ...) and Dense.call(params, ...) behave identically to the dense_init and dense_apply functions implemented earlier.
Now let's redefine the dense layer using Flax Modules.
In [ ]:
class Dense(nn.Module):
  """A learned linear transformation."""

  def apply(self, x, features,
            # init functions are of the form (PRNGKey, shape) => init_value
            kernel_init=jax.nn.initializers.lecun_normal(),
            bias_init=jax.nn.initializers.zeros):
    """The main entry point to a Module. Represents the function that,
    given inputs and hyper-parameters, computes an output. The actual
    arguments (inputs and hyper-parameters) are user-controlled and depend
    on the Module's functionality. For this example:

    * `x`: the input, an array of shape `(in_features,)`.
    * `features`: the number of outputs, an integer.
    * `kernel_init`: the initializer for the kernel.
    * `bias_init`: the initializer for the biases.
    """
    in_features = x.shape[-1]
    kernel_shape = (in_features, features)
    kernel = self.param('kernel', kernel_shape, kernel_init)
    bias = self.param('bias', (features,), bias_init)
    return jnp.dot(x, kernel) + bias
In [ ]:
y, params = Dense.init(random.PRNGKey(0), x, features=4)
print(params)
In [ ]:
Dense.call(params, x, features=4)
Note that both init and call end up using the same apply function. That is why we must specify all the inputs and hyper-parameters (the number of features) in both init and call. Often these arguments are the same for each call to init and call.
For these situations, we can use Module.partial to apply these arguments. partial takes keyword arguments and returns a new Module for which the given arguments are already applied. It can be thought of as the equivalent of functools.partial for Modules.
In [ ]:
module = Dense.partial(features=4) # Module with concrete hyper parameters, ready to be initialized
_, params = module.init(random.PRNGKey(0), x)
module.call(params, x)
In [ ]:
# same as flax.nn.relu
def relu(x):
  return jnp.maximum(0., x)

class MLP(nn.Module):
  """Multi Layer Perceptron."""

  def apply(self, x,
            hidden_features,
            output_features,
            activation_fn):
    z = Dense(x, hidden_features)
    h = activation_fn(z)
    y = Dense(h, output_features)
    return y
module = MLP.partial(hidden_features=8, output_features=4, activation_fn=relu)
y, params = module.init(random.PRNGKey(0), x)
print(y)
The params returned by init have a nested structure of lists, tuples, dicts, and other types that can contain arrays; we call such a structure a pytree. When we compose Modules as in our example, the params form a structure of nested dictionaries.
We can use jax.tree_map to apply a function to each leaf of a pytree, e.g., to reveal the params structure of the MLP model (recall that x.shape = (1, 2)).
In [ ]:
jax.tree_map(np.shape, params)
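Because params is just a pytree, the standard JAX tree utilities apply to it as well. As a small illustrative addition, the following cell counts the total number of parameters using jax.tree_util.tree_leaves.
In [ ]:
# sum the sizes of all leaf arrays in the parameter pytree
num_params = sum(p.size for p in jax.tree_util.tree_leaves(params))
print('total number of parameters:', num_params)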
In [ ]:
class NamedMLP(nn.Module):

  def apply(self, x,
            hidden_features,
            output_features,
            activation_fn):
    # submodules can be given an explicit name, which becomes their key
    # in the parameter dictionary
    z = Dense(x, hidden_features, name='hidden')
    h = activation_fn(z)
    y = Dense(h, output_features, name='out')
    return y
module = NamedMLP.partial(hidden_features=8, output_features=4, activation_fn=relu)
_, params = module.init(random.PRNGKey(0), x)
jax.tree_map(np.shape, params)
In [ ]:
class SimpleRNN(nn.Module):

  def apply(self, x, iterations=3):
    # Module.shared returns a Module whose parameters are shared
    # across all of its applications
    dense = Dense.shared(
        features=x.shape[-1],
        kernel_init=jax.nn.initializers.orthogonal(),
        name='cell')
    ys = []
    for i in range(iterations):
      x = dense(x)
      ys.append(x)
    return ys
Note that we call the Dense layer named 'cell' three times, yet only one set of parameters shows up in the parameter structure, due to weight sharing.
In [ ]:
ys, params = SimpleRNN.init(random.PRNGKey(0), x)
print(ys)
jax.tree_map(np.shape, params)
Previously we initialized the model by passing in some inputs. This is useful because it allows Modules to automatically infer the shapes of parameters from the inputs. It can also help catch errors in the model early, in the initialization phase of a program.
Nonetheless, Module.init includes some unnecessary overhead, because typically we are not interested in the actual output of the model during initialization. Therefore, we can use JAX's built-in lazy evaluation to get the benefits of shape inference without doing any unnecessary compute.
Module.init_by_shape returns only the shape and dtype of the outputs but still creates fully initialized parameters. If you want to use initializers that (indirectly) depend on the values (not just the shapes) of the inputs, you should keep using Module.init.
In [ ]:
input_spec = [(1, 2)] # the input specification is a list of shape tuples
out_spec, params = SimpleRNN.init_by_shape(random.PRNGKey(0), input_spec)
# TODO: uncomment this line once __repr__ is fixed in jax
# print('out_spec:', out_spec)
jax.tree_map(np.shape, params)
Modules make it easy to keep track of parameters inside a model, but so far we still had to explicitly keep track of the parameter structure and the init & call functions.
Model is a thin abstraction around a Module and a set of parameters. A Model instance is callable and functional (e.g., changing parameters requires a new model instance).
Using Module.init or Module.init_by_shape creates a newly initialized set of parameters. You can then wrap the module and the initialized parameters in a Model instance.
In [ ]:
x = jnp.ones((1, 2))
module = Dense.partial(features=4)
ys, initial_params = module.init(random.PRNGKey(0), x)
model = nn.Model(module, initial_params)
jax.tree_map(np.shape, model.params)
In [ ]:
model(x)
In [ ]:
model.params
Parameters can be updated using the Model.replace method.
In [ ]:
biased_model = model.replace(params={'kernel': model.params['kernel'], 'bias': model.params['bias'] + 1.})
biased_model.params
Model is registered as a JAX pytree container object, which means that it can be passed to JAX transformations and to jax.tree_map.
For example, we can take gradients w.r.t. a Model object. The returned Model object will contain the gradients corresponding to each parameter.
In [ ]:
def loss_fn(model):
  y = model(x)
  return jnp.mean(y ** 2)
model_grad = jax.grad(loss_fn)(model)
model_grad.params
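Since model and model_grad share the same pytree structure, a gradient-descent step can be written as a single tree-wide map that returns a new Model instance. The following sketch assumes a JAX version where jax.tree_util.tree_map accepts multiple trees (older releases exposed this as jax.tree_multimap); the learning rate is arbitrary.
In [ ]:
learning_rate = 0.1  # arbitrary step size for illustration
# map over the two Models leaf by leaf; the result is again a Model
updated_model = jax.tree_util.tree_map(
    lambda p, g: p - learning_rate * g, model, model_grad)
print('loss before:', loss_fn(model))
print('loss after: ', loss_fn(updated_model))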
Flax allows stateful operations to happen within a limited scope.
Stateful Modules are defined using the Module.state API, which returns a state object whose value property can be read from and assigned to.
A typical use of a stateful Module is BatchNorm, which maintains a moving average of batch statistics (mean, variance). During training the moving averages are updated so that they can be used at test time.
In [ ]:
# simplified version of nn.BatchNorm
class BatchNorm(nn.Module):

  def apply(self, x, red_axis=0, eps=1e-5,
            momentum=0.99, training=False,
            gamma_init=nn.initializers.ones,
            beta_init=nn.initializers.zeros):
    # compute the moments of the input
    mean = x.mean(red_axis, keepdims=True)
    var = jnp.square(x - mean).mean(red_axis, keepdims=True)
    # define the state variables
    ra_mean = self.state('mean', mean.shape, nn.initializers.zeros)
    ra_var = self.state('var', var.shape, nn.initializers.ones)
    if not self.is_initializing():  # during init we ignore the moving averages completely
      if training:
        # during training the moving averages are updated
        alpha = 1. - momentum
        ra_mean.value += alpha * (mean - ra_mean.value)
        ra_var.value += alpha * (var - ra_var.value)
      else:
        # if we are not training we use the moving averages
        mean = ra_mean.value
        var = ra_var.value
    # standardize the input
    y = (x - mean) / jnp.sqrt(var + eps)
    # learn the scale and bias of the output
    gamma = self.param('gamma', mean.shape, gamma_init)
    beta = self.param('beta', mean.shape, beta_init)
    return gamma * y + beta
Stateful modules require special care when used. The nn.stateful context manager defines a scope in which stateful operations are allowed; outside of this scope the state becomes immutable.
The state is stored in an nn.Collection object, which internally stores the state as a dictionary.
nn.stateful takes a Collection containing the current state and returns a new Collection that contains the updated state. By default a new Collection will be created.
When using nn.stateful(state, mutable=False), the state can be read but any update will raise an error. This is often useful at test time, to guarantee that test data does not affect the model.
In [ ]:
class MyModel(nn.Module):

  def apply(self, x, training=False):
    x = Dense(x, features=4)
    x = BatchNorm(x, training=training, momentum=0., name='batch_norm')
    return x

dist_a = lambda rng, shape: random.normal(rng, shape) * jnp.array([[1., 3.]])
x_a = dist_a(random.PRNGKey(1), (1024, 2))
print('std. deviation of input:', x_a.std(0))

with nn.stateful() as init_state:
  y, params = MyModel.init(random.PRNGKey(2), x_a)
print('std. deviation of output (init):', y.std(0))

with nn.stateful(init_state) as new_state:
  y = MyModel.call(params, x_a, training=True)
print('std. deviation of output (training):', y.std(0))

with nn.stateful(new_state, mutable=False):
  y = MyModel.call(params, x_a, training=False)
print('std. deviation of output (testing):', y.std(0))
The state can be inspected using Collection.as_dict().
Each Module has a path-like key into the Collection (e.g. '/some_module/nested_module/dense').
In [ ]:
init_state.as_dict()
In [ ]:
new_state.as_dict()
The stateful mechanism forces the user to be explicit about stateful operations.
One motivating example for this approach is to enforce that state is not updated at test time.
Another benefit is that it is easier to replace the state when necessary. For example, let's say we want to apply this model to a second input distribution (b) with different statistics.
In [ ]:
dist_b = lambda rng, shape: random.normal(rng, shape) * jnp.array([[2., 5.]])
x_b = dist_b(random.PRNGKey(1), (1024, 2))
with nn.stateful(new_state, mutable=False):
  y = MyModel.call(params, x_b, training=False)
print(y.std(0))  # this will not be properly normalized!
We can solve the skew in statistics by creating a separate state for this alternative input distribution.
In [ ]:
with nn.stateful(init_state) as state_b:
  y = MyModel.call(params, x_b, training=True)
print('std. deviation of output (training):', y.std(0))

with nn.stateful(state_b, mutable=False):
  y = MyModel.call(params, x_b, training=False)
print('std. deviation of output (testing):', y.std(0))
In [ ]:
# generate a toy dataset for linear regression: noisy samples around the line y = 2x
rng = random.PRNGKey(0)
rng, key1, key2 = random.split(rng, 3)
n = 30
x = jnp.linspace(-5., 5.)
X = random.uniform(key1, (n,), minval=-5., maxval=5.)
f = lambda x: 2. * x
Y = f(X) + random.normal(key2, (n,))
plt.plot(x, f(x))
plt.scatter(X, Y)
plt.show()
The model is nothing more than a Dense module with a single output feature.
In [ ]:
class LinearRegression(nn.Module):

  def apply(self, x):
    # add a singleton feature dimension to the input and remove the
    # singleton feature dimension from the output
    return nn.Dense(x[..., None], features=1)[..., 0]
rng, key = random.split(rng)
_, initial_params = LinearRegression.init(key, X)
model = nn.Model(LinearRegression, initial_params)
# plot the data together with the line used to generate the data (blue) and the untrained model (orange)
plt.plot(x, f(x))
plt.plot(x, model(x))
plt.scatter(X, Y)
plt.show()
We will use gradient descent with momentum to fit the model to the data.
Each optimizer inherits from flax.optim.OptimizerDef. The OptimizerDef class provides init_state and apply_gradient methods, which initialize and update the optimizer state, respectively. OptimizerDef does not actually maintain the state and the optimized parameters; it is simply a collection of functions.
When calling OptimizerDef.create, the optimization target and the optimizer state (e.g. gradient moving averages) are wrapped together with the OptimizerDef in an instance of Optimizer.
Optimizers can optimize any pytree of arrays (nested dicts, Model instances, etc.), as long as the gradient w.r.t. the target is computable by jax.grad.
In [ ]:
optimizer_def = optim.Momentum(learning_rate=0.1, beta=0.9)
optimizer = optimizer_def.create(model)
train_steps = 100
def loss_fn(model):
  Y_hat = model(X)
  return jnp.square(Y - Y_hat).mean()

for i in range(train_steps):
  # optimizer.target is passed to the loss_fn
  loss, grad = jax.value_and_grad(loss_fn)(optimizer.target)
  # `apply_gradient` returns a new `Optimizer` instance with the updated
  # target and optimizer state.
  optimizer = optimizer.apply_gradient(grad)

print('mean square error:', loss)
trained_model = optimizer.target
print(trained_model.params)
plt.plot(x, f(x))
plt.plot(x, trained_model(x))
plt.scatter(X, Y)
plt.show()
The flax.serialization module provides utilities for extracting the state of optimizers, models, and other structures as a dictionary of arrays.
It also provides an integration with MessagePack, which can be used to efficiently serialize the state dictionary in a binary, cross-platform compatible format.
The flax.struct.dataclass decorator can be used to create a Python dataclass that can be passed into JAX transformations like jit, grad, and tree_map. It also integrates with the state dict API.
In [ ]:
# Model and Optimizer are also Flax dataclasses.
@flax.struct.dataclass
class TrainState:
  optimizer: optim.Optimizer
  step: int

state = TrainState(optimizer=optimizer, step=5)
flax.serialization.to_state_dict(state)
Flax dataclasses are immutable. Using the replace method, a new instance can be created with a set of updated fields.
In [ ]:
new_state = state.replace(step=6)
print(state.step, new_state.step)
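Because a flax.struct.dataclass is registered as a pytree, it can also flow through JAX transformations directly. The cell below is a small sketch showing a jitted function that takes and returns a TrainState.
In [ ]:
@jax.jit
def increment_step(train_state):
  # the dataclass is flattened into its fields, traced, and rebuilt
  return train_state.replace(step=train_state.step + 1)

print(increment_step(state).step)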
The to_bytes and from_bytes functions are used to convert an object to the MessagePack format and back.
In [ ]:
data = flax.serialization.to_bytes(state)
print('num bytes:', len(data))
In [ ]:
# simulate losing the trained values by zeroing out every array in the state
corrupted_state = jax.tree_map(lambda x: 0 * x, state)
flax.serialization.to_state_dict(corrupted_state)
In [ ]:
# restore the state from the serialized bytes in `data`; the first argument provides the target structure
restored_state = flax.serialization.from_bytes(corrupted_state, data)
flax.serialization.to_state_dict(restored_state)
Sometimes we wish to apply a different optimizer to a subset of the model parameters. To illustrate, we will apply weight decay to the bias of the linear model that was trained before.
flax.optim.MultiOptimizer takes tuples of traversals and OptimizerDef instances. A traversal is responsible for selecting the subset of parameters that should be optimized. We will use flax.optim.ModelParamTraversal, which allows you to filter parameters based on their path (e.g. '/hidden/dense/kernel').
In [ ]:
slope_opt_def = optim.Momentum(learning_rate=0.1)
# by applying decay to the bias parameter it will end up being closer to zero than before
bias_opt_def = optim.Momentum(learning_rate=0.1, weight_decay=10.)
# select all kernel parameters
slope_traversal = optim.ModelParamTraversal(lambda path, param: 'kernel' in path)
# select all bias parameters
bias_traversal = optim.ModelParamTraversal(lambda path, param: 'bias' in path)
optimizer_def = optim.MultiOptimizer((slope_traversal, slope_opt_def), (bias_traversal, bias_opt_def))
_, initial_params = LinearRegression.init(random.PRNGKey(0), X)
model = nn.Model(LinearRegression, initial_params)
optimizer = optimizer_def.create(model)
train_steps = 100
def loss_fn(model):
  Y_hat = model(X)
  return jnp.square(Y - Y_hat).mean()

for i in range(train_steps):
  loss, grad = jax.value_and_grad(loss_fn)(optimizer.target)
  optimizer = optimizer.apply_gradient(grad)

print('mean square error:', loss)
trained_model = optimizer.target
print(trained_model.params)
plt.plot(x, f(x))
plt.plot(x, trained_model(x))
plt.scatter(X, Y)
plt.show()
In [ ]:
class MultiMethodModule(nn.Module):

  def apply(self, x):
    kernel = self.param('kernel', (), lambda _, shape: jnp.full(shape, 2.))
    return x * kernel

  @nn.module_method
  def decode(self, x):
    # a module_method defines an additional method on the resulting Model
    # that reuses the parameters declared in `apply`
    kernel = self.get_param('kernel')
    return x * kernel
x = 2. ** jnp.arange(5)
y, initial_params = MultiMethodModule.init(random.PRNGKey(0), x)
model = nn.Model(MultiMethodModule, initial_params)
print('target:', x[1:], 'teacher forced decoding:', y[:-1])
print('sequential decoding (one step):', model.decode(1.))
In [ ]:
def body_fn(carry, _):
  y = model.decode(carry)
  new_carry = y  # feed output back
  return new_carry, y
carry, ys = lax.scan(body_fn, 1., (), length=4)
print('carry:', carry)
print('sequential decoding:', ys)
In [ ]:
def add_scale(module):

  class ScaleWrapper(nn.Module):
    """Add a learnable scale to the kernel of a module."""

    def apply(self, *args, **kwargs):
      def init_fn(rng, _):
        _, params = module.init(rng, *args, **kwargs)
        # here we could change the initial parameters of the wrapped module
        return params
      params = self.param('params', None, init_fn)

      # here we transform the parameters on every call
      assert 'kernel' in params
      kernel = params['kernel']
      features = kernel.shape[-1]
      scale = self.param('scale', (features,), nn.initializers.ones)
      scaled_kernel = kernel * scale
      scaled_params = params.copy()
      scaled_params['kernel'] = scaled_kernel
      return module.call(scaled_params, *args, **kwargs)

  return ScaleWrapper
x = jnp.ones((1, 2))
module = add_scale(Dense).partial(features=4)
y, params = module.init(random.PRNGKey(0), x)
params
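As a quick sanity check (not part of the original), the wrapped module is applied like any other. Because the scale is initialized to ones, the output initially matches the wrapped Dense layer applied with the same kernel and bias.
In [ ]:
y = module.call(params, x)
# with scale == ones this equals the wrapped Dense layer applied directly
print(y)
print(Dense.call(params['params'], x, features=4))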