Migrating from Haiku to Flax#
This guide will walk through the process of migrating Haiku models to Flax, and highlight the differences between the two libraries.
Basic Example#
To create custom Modules you subclass from a Module base class in both Haiku and Flax. However, Haiku classes use a regular __init__ method, whereas Flax classes are dataclasses, meaning you define some class attributes that are used to automatically generate a constructor. Also, all Flax Modules accept a name argument without needing to define it, whereas in Haiku name must be explicitly defined in the constructor signature and passed to the superclass constructor.
import haiku as hk

class Block(hk.Module):
  def __init__(self, features: int, name=None):
    super().__init__(name=name)
    self.features = features

  def __call__(self, x, training: bool):
    x = hk.Linear(self.features)(x)
    x = hk.dropout(hk.next_rng_key(), 0.5 if training else 0, x)
    x = jax.nn.relu(x)
    return x

class Model(hk.Module):
  def __init__(self, dmid: int, dout: int, name=None):
    super().__init__(name=name)
    self.dmid = dmid
    self.dout = dout

  def __call__(self, x, training: bool):
    x = Block(self.dmid)(x, training)
    x = hk.Linear(self.dout)(x)
    return x
import flax.linen as nn

class Block(nn.Module):
  features: int

  @nn.compact
  def __call__(self, x, training: bool):
    x = nn.Dense(self.features)(x)
    x = nn.Dropout(0.5, deterministic=not training)(x)
    x = jax.nn.relu(x)
    return x

class Model(nn.Module):
  dmid: int
  dout: int

  @nn.compact
  def __call__(self, x, training: bool):
    x = Block(self.dmid)(x, training)
    x = nn.Dense(self.dout)(x)
    return x
The __call__ method looks very similar in both libraries; however, in Flax you have to use the @nn.compact decorator in order to be able to define submodules inline. In Haiku, this is the default behavior.
Now, a place where Haiku and Flax differ substantially is in how you construct the model. In Haiku, you use hk.transform over a function that calls your Module; transform returns an object with init and apply methods. In Flax, you simply instantiate your Module.
def forward(x, training: bool):
  return Model(256, 10)(x, training)

model = hk.transform(forward)
...
model = Model(256, 10)
To get the model parameters in both libraries you use the init method with a random.key plus some inputs to run the model. The main difference here is that Flax returns a mapping from collection names to nested array dictionaries; params is just one of these possible collections. In Haiku, you get the params structure directly.
sample_x = jax.numpy.ones((1, 784))
params = model.init(
  random.key(0),
  sample_x, training=False  # <== inputs
)
...
sample_x = jax.numpy.ones((1, 784))
variables = model.init(
  random.key(0),
  sample_x, training=False  # <== inputs
)
params = variables["params"]
One very important thing to note is that in Flax the parameters structure is hierarchical, with one level per nested Module and a final level for the parameter name. In Haiku the parameters structure is a Python dictionary with a two-level hierarchy: the fully qualified module name mapping to the parameter name. The module name consists of a /-separated string path of all the nested Modules.
...
{
  'model/block/linear': {
    'b': (256,),
    'w': (784, 256),
  },
  'model/linear': {
    'b': (10,),
    'w': (256, 10),
  }
}
...
FrozenDict({
  Block_0: {
    Dense_0: {
      bias: (256,),
      kernel: (784, 256),
    },
  },
  Dense_0: {
    bias: (10,),
    kernel: (256, 10),
  },
})
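As an aside (this snippet is our addition, not part of the original example), shape summaries like the ones above can be produced by mapping over the parameter pytree; for the Flax params:

# Our sketch: print the shape of every leaf in the Flax params pytree.
import jax
print(jax.tree_util.tree_map(lambda p: p.shape, params))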
During training in both frameworks you pass the parameters structure to the apply method to run the forward pass. Since we are using dropout, in both cases we must provide a key to apply in order to generate the random dropout masks.
def train_step(key, params, inputs, labels):
  def loss_fn(params):
    logits = model.apply(
      params,
      key,
      inputs, training=True  # <== inputs
    )
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = jax.grad(loss_fn)(params)
  params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
  return params
def train_step(key, params, inputs, labels):
  def loss_fn(params):
    logits = model.apply(
      {'params': params},
      inputs, training=True,  # <== inputs
      rngs={'dropout': key}
    )
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = jax.grad(loss_fn)(params)
  params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
  return params
The most notable difference is that in Flax you have to pass the parameters inside a dictionary with a params key, and the key inside a dictionary with a dropout key. This is because in Flax you can have many types of model state and random state. In Haiku, you just pass the parameters and the key directly.
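As a quick check, here is a minimal evaluation sketch for the Flax side (our addition, assuming the model and params defined above); since training=False makes dropout deterministic, no rngs argument is needed:

def eval_step(params, inputs, labels):
  # Dropout is disabled at evaluation time, so no 'dropout' rng stream is needed.
  logits = model.apply(
    {'params': params},
    inputs, training=False  # <== inputs
  )
  return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()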
Handling State#
Now let’s see how mutable state is handled in both libraries. We will take the same model as before, but now we will replace Dropout with BatchNorm.
class Block(hk.Module):
  def __init__(self, features: int, name=None):
    super().__init__(name=name)
    self.features = features

  def __call__(self, x, training: bool):
    x = hk.Linear(self.features)(x)
    x = hk.BatchNorm(
      create_scale=True, create_offset=True, decay_rate=0.99
    )(x, is_training=training)
    x = jax.nn.relu(x)
    return x
class Block(nn.Module):
  features: int

  @nn.compact
  def __call__(self, x, training: bool):
    x = nn.Dense(self.features)(x)
    x = nn.BatchNorm(
      momentum=0.99
    )(x, use_running_average=not training)
    x = jax.nn.relu(x)
    return x
The code is very similar in this case as both libraries provide a BatchNorm layer. The most notable difference is that Haiku uses is_training to control whether or not to update the running statistics, whereas Flax uses use_running_average for the same purpose.
To instantiate a stateful model in Haiku you use hk.transform_with_state, which changes the signature for init and apply to accept and return state. As before, in Flax you construct the Module directly.
def forward(x, training: bool):
  return Model(256, 10)(x, training)

model = hk.transform_with_state(forward)
...
model = Model(256, 10)
To initialize both the parameters and state you just call the init method as before. However, in Haiku you now get state as a second return value, and in Flax you get a new batch_stats collection in the variables dictionary.
Note that since hk.BatchNorm only initializes batch statistics when is_training=True, we must set training=True when initializing parameters of a Haiku model with an hk.BatchNorm layer. In Flax, we can set training=False as usual.
sample_x = jax.numpy.ones((1, 784))
params, state = model.init(
  random.key(0),
  sample_x, training=True  # <== inputs
)
...
sample_x = jax.numpy.ones((1, 784))
variables = model.init(
  random.key(0),
  sample_x, training=False  # <== inputs
)
params, batch_stats = variables["params"], variables["batch_stats"]
In general, in Flax you might find other state collections in the variables dictionary, such as cache for auto-regressive transformer models, intermediates for intermediate values added using Module.sow, or other collection names defined by custom layers. Haiku only makes a distinction between params (variables which do not change while running apply) and state (variables which can change while running apply).
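As an illustration (a minimal sketch of our own, not part of the original example), a Flax layer can record an intermediate activation into the intermediates collection with Module.sow, and you read it back by marking that collection as mutable during apply:

class SowBlock(nn.Module):  # hypothetical example layer
  features: int

  @nn.compact
  def __call__(self, x):
    x = nn.Dense(self.features)(x)
    # Store the pre-activation values in the 'intermediates' collection.
    self.sow('intermediates', 'pre_relu', x)
    return jax.nn.relu(x)

block = SowBlock(16)
params = block.init(random.key(0), jax.numpy.ones((1, 8)))['params']
y, state = block.apply(
  {'params': params}, jax.numpy.ones((1, 8)), mutable=['intermediates'])
pre_relu = state['intermediates']['pre_relu'][0]  # sown values are stored in a tuple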
Now, training looks very similar in both frameworks as you use the same apply method to run the forward pass. In Haiku, you now pass the state as the second argument to apply, and get the new state as the second return value. In Flax, you instead add batch_stats as a new key to the input dictionary, and get the updates variables dictionary as the second return value.
def train_step(params, state, inputs, labels):
  def loss_fn(params):
    logits, new_state = model.apply(
      params, state,
      None,  # <== rng
      inputs, training=True  # <== inputs
    )
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
    return loss, new_state

  grads, new_state = jax.grad(loss_fn, has_aux=True)(params)
  params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
  return params, new_state
def train_step(params, batch_stats, inputs, labels):
  def loss_fn(params):
    logits, updates = model.apply(
      {'params': params, 'batch_stats': batch_stats},
      inputs, training=True,  # <== inputs
      mutable='batch_stats',
    )
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
    return loss, updates["batch_stats"]

  grads, batch_stats = jax.grad(loss_fn, has_aux=True)(params)
  params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
  return params, batch_stats
One major difference is that in Flax a state collection can be mutable or immutable. During init all collections are mutable by default; however, during apply you have to explicitly specify which collections are mutable. In this example, we specify that batch_stats is mutable. Here a single string is passed, but a list can also be given if there are more mutable collections. If this is not done, an error will be raised at runtime when trying to mutate batch_stats. Also, when mutable is anything other than False, the updates dictionary is returned as the second return value of apply; otherwise only the model output is returned.
Haiku makes the mutable/immutable distinction by having params (immutable) and state (mutable), and by using either hk.transform or hk.transform_with_state.
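For example, at evaluation time you typically do not mark anything as mutable; here is a minimal sketch (our addition, reusing the model, params and batch_stats from above) where apply returns only the model output and the running statistics are read but never updated:

def eval_step(params, batch_stats, inputs, labels):
  # mutable defaults to False: only the logits are returned and
  # 'batch_stats' is used in read-only mode.
  logits = model.apply(
    {'params': params, 'batch_stats': batch_stats},
    inputs, training=False  # <== inputs
  )
  return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()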
Using Multiple Methods#
In this section we will take a look at how to use multiple methods in Haiku and Flax. As an example, we will implement an auto-encoder model with three methods: encode, decode, and __call__.
In Haiku, we can just define the submodules that encode and decode need directly in __init__; in this case each will just use a Linear layer. In Flax, we will define an encoder and a decoder Module ahead of time in setup, and use them in encode and decode respectively.
class AutoEncoder(hk.Module):
  def __init__(self, embed_dim: int, output_dim: int, name=None):
    super().__init__(name=name)
    self.encoder = hk.Linear(embed_dim, name="encoder")
    self.decoder = hk.Linear(output_dim, name="decoder")

  def encode(self, x):
    return self.encoder(x)

  def decode(self, x):
    return self.decoder(x)

  def __call__(self, x):
    x = self.encode(x)
    x = self.decode(x)
    return x
class AutoEncoder(nn.Module):
  embed_dim: int
  output_dim: int

  def setup(self):
    self.encoder = nn.Dense(self.embed_dim)
    self.decoder = nn.Dense(self.output_dim)

  def encode(self, x):
    return self.encoder(x)

  def decode(self, x):
    return self.decoder(x)

  def __call__(self, x):
    x = self.encode(x)
    x = self.decode(x)
    return x
Note that in Flax setup doesn't run after __init__; instead, it runs when init or apply is called.
Now, we want to be able to call any method from our AutoEncoder model. In Haiku we can define multiple apply methods for a module through hk.multi_transform. The function passed to multi_transform defines how to initialize the module and which different apply methods to generate.
def forward():
  module = AutoEncoder(256, 784)
  init = lambda x: module(x)
  return init, (module.encode, module.decode)

model = hk.multi_transform(forward)
...
model = AutoEncoder(256, 784)
To initialize the parameters of our model, init can be used to trigger the __call__ method, which uses both the encode and decode methods. This will create all the necessary parameters for the model.
params = model.init(
  random.key(0),
  x=jax.numpy.ones((1, 784)),
)
...
variables = model.init(
  random.key(0),
  x=jax.numpy.ones((1, 784)),
)
params = variables["params"]
This generates the following parameter structure.
{
  'auto_encoder/~/decoder': {
    'b': (784,),
    'w': (256, 784)
  },
  'auto_encoder/~/encoder': {
    'b': (256,),
    'w': (784, 256)
  }
}
FrozenDict({
  decoder: {
    bias: (784,),
    kernel: (256, 784),
  },
  encoder: {
    bias: (256,),
    kernel: (784, 256),
  },
})
Finally, let's explore how we can employ the apply function to invoke the encode method:
encode, decode = model.apply

z = encode(
  params,
  None,  # <== rng
  x=jax.numpy.ones((1, 784)),
)
...
z = model.apply(
  {"params": params},
  x=jax.numpy.ones((1, 784)),
  method="encode",
)
Because the Haiku apply function is generated through hk.multi_transform, it's a tuple of two functions which we can unpack into an encode and a decode function corresponding to the methods on the AutoEncoder module. In Flax, we call the encode method by passing the method name as a string.
Another noteworthy distinction here is that in Haiku, rng needs to be explicitly passed, even though the module does not use any stochastic operations during apply. In Flax this is not necessary (check out Randomness and PRNGs in Flax).
The Haiku rng is set to None here, but you could also use hk.without_apply_rng on the apply function to remove the rng argument.
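On the Flax side, method also accepts the method itself rather than its name, and you can alternatively bind the variables once and call methods directly; a small sketch (our addition, equivalent to the string form above):

# Pass the unbound method instead of its name.
z = model.apply({'params': params}, jax.numpy.ones((1, 784)),
                method=AutoEncoder.encode)

# Or bind the variables once and call methods interactively.
bound_model = model.bind({'params': params})
z = bound_model.encode(jax.numpy.ones((1, 784)))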
Lifted Transforms#
Both Flax and Haiku provide a set of transforms, which we will refer to as lifted transforms, that wrap JAX transformations in such a way that they can be used with Modules and sometimes provide additional functionality. In this section we will take a look at how to use the lifted version of scan in both Flax and Haiku to implement a simple RNN layer.
To begin, we will first define an RNNCell module that will contain the logic for a single step of the RNN. We will also define an initial_state method that will be used to initialize the state (a.k.a. carry) of the RNN. Like with jax.lax.scan, the RNNCell.__call__ method will be a function that takes the carry and input, and returns the new carry and output. In this case, the carry and the output are the same.
class RNNCell(hk.Module):
  def __init__(self, hidden_size: int, name=None):
    super().__init__(name=name)
    self.hidden_size = hidden_size

  def __call__(self, carry, x):
    x = jnp.concatenate([carry, x], axis=-1)
    x = hk.Linear(self.hidden_size)(x)
    x = jax.nn.relu(x)
    return x, x

  def initial_state(self, batch_size: int):
    return jnp.zeros((batch_size, self.hidden_size))
class RNNCell(nn.Module):
  hidden_size: int

  @nn.compact
  def __call__(self, carry, x):
    x = jnp.concatenate([carry, x], axis=-1)
    x = nn.Dense(self.hidden_size)(x)
    x = jax.nn.relu(x)
    return x, x

  def initial_state(self, batch_size: int):
    return jnp.zeros((batch_size, self.hidden_size))
Next, we will define an RNN Module that will contain the logic for the entire RNN. In Haiku, we will first initialize the RNNCell, then use it to construct the carry, and finally use hk.scan to run the RNNCell over the input sequence. In Flax it's done a bit differently: we will use nn.scan to define a new temporary type that wraps RNNCell. During this process we will also instruct nn.scan to broadcast the params collection (all steps share the same parameters) and to not split the params rng stream (so all steps initialize with the same parameters), and finally we will specify that we want scan to run over the second axis of the input and stack the outputs along the second axis as well. We will then use this temporary type immediately to create an instance of the lifted RNNCell and use it to create the carry and run the __call__ method, which will scan over the sequence.
class RNN(hk.Module):
  def __init__(self, hidden_size: int, name=None):
    super().__init__(name=name)
    self.hidden_size = hidden_size

  def __call__(self, x):
    cell = RNNCell(self.hidden_size)
    carry = cell.initial_state(x.shape[0])
    carry, y = hk.scan(cell, carry, jnp.swapaxes(x, 1, 0))
    y = jnp.swapaxes(y, 0, 1)
    return y
class RNN(nn.Module):
  hidden_size: int

  @nn.compact
  def __call__(self, x):
    rnn = nn.scan(RNNCell, variable_broadcast='params', split_rngs={'params': False},
                  in_axes=1, out_axes=1)(self.hidden_size)
    carry = rnn.initial_state(x.shape[0])
    carry, y = rnn(carry, x)
    return y
In general, the main difference between lifted transforms in Flax and Haiku is that in Haiku the lifted transforms don't operate over the state; that is, Haiku handles params and state in such a way that they keep the same shape inside and outside of the transform. In Flax, the lifted transforms can operate over both variable collections and rng streams, and the user must define how different collections are treated by each transform according to the transform's semantics.
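The same pattern applies to other lifted transforms. For instance, here is a minimal sketch (our addition, not part of the original guide) of an ensemble of Dense layers built with nn.vmap, where each ensemble member gets its own parameters:

# Hypothetical ensemble of independent Dense layers via the lifted vmap.
VmapDense = nn.vmap(
  nn.Dense,
  variable_axes={'params': 0},  # give each member its own parameter slice
  split_rngs={'params': True},  # initialize each member with a different rng
  in_axes=0, out_axes=0)        # map over a leading "member" axis of the input
ensemble = VmapDense(features=16)
x = jax.numpy.ones((8, 4, 32))  # (members, batch, features)
variables = ensemble.init(random.key(0), x)
y = ensemble.apply(variables, x)  # shape (8, 4, 16); params gain a leading axis of size 8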
Finally, let's quickly view how the RNN Module would be used in both Haiku and Flax.
def forward(x):
  return RNN(64)(x)

model = hk.without_apply_rng(hk.transform(forward))

params = model.init(
  random.key(0),
  x=jax.numpy.ones((3, 12, 32)),
)

y = model.apply(
  params,
  x=jax.numpy.ones((3, 12, 32)),
)
...
model = RNN(64)

variables = model.init(
  random.key(0),
  x=jax.numpy.ones((3, 12, 32)),
)
params = variables['params']

y = model.apply(
  {'params': params},
  x=jax.numpy.ones((3, 12, 32)),
)
The only notable change with respect to the examples in the previous sections is that this time around we used hk.without_apply_rng in Haiku so we didn't have to pass the rng argument as None to the apply method.
Scan over layers#
One very important application of scan is to apply a sequence of layers iteratively over an input, passing the output of each layer as the input to the next layer. This is very useful to reduce compilation time for big models. As an example we will create a simple Block Module, and then use it inside an MLP Module that will apply the Block Module num_layers times.
In Haiku, we define the Block Module as usual, and then inside MLP we will use hk.experimental.layer_stack over a stack_block function to create a stack of Block Modules. In Flax, the definition of Block is a little different: __call__ will accept and return a second dummy input/output that in both cases will be None. In MLP, we will use nn.scan as in the previous example, but by setting split_rngs={'params': True} and variable_axes={'params': 0} we are telling nn.scan to create different parameters for each step and to slice the params collection along the first axis, effectively implementing a stack of Block Modules as in Haiku.
class Block(hk.Module):
  def __init__(self, features: int, name=None):
    super().__init__(name=name)
    self.features = features

  def __call__(self, x, training: bool):
    x = hk.Linear(self.features)(x)
    x = hk.dropout(hk.next_rng_key(), 0.5 if training else 0, x)
    x = jax.nn.relu(x)
    return x

class MLP(hk.Module):
  def __init__(self, features: int, num_layers: int, name=None):
    super().__init__(name=name)
    self.features = features
    self.num_layers = num_layers

  def __call__(self, x, training: bool):
    @hk.experimental.layer_stack(self.num_layers)
    def stack_block(x):
      return Block(self.features)(x, training)

    return stack_block(x)
class Block(nn.Module):
  features: int
  training: bool

  @nn.compact
  def __call__(self, x, _):
    x = nn.Dense(self.features)(x)
    x = nn.Dropout(0.5)(x, deterministic=not self.training)
    x = jax.nn.relu(x)
    return x, None

class MLP(nn.Module):
  features: int
  num_layers: int

  @nn.compact
  def __call__(self, x, training: bool):
    ScanBlock = nn.scan(
      Block, variable_axes={'params': 0}, split_rngs={'params': True},
      length=self.num_layers)
    y, _ = ScanBlock(self.features, training)(x, None)
    return y
Notice how in Flax we pass None as the second argument to ScanBlock and ignore its second output. These represent the per-step inputs/outputs, but they are None because in this case we don't have any.
Initializing each model is the same as in previous examples. In this case, we will be specifying that we want to use 5 layers each with 64 features.
def forward(x, training: bool):
  return MLP(64, num_layers=5)(x, training)

model = hk.transform(forward)

sample_x = jax.numpy.ones((1, 64))
params = model.init(
  random.key(0),
  sample_x, training=False  # <== inputs
)
...
...
model = MLP(64, num_layers=5)

sample_x = jax.numpy.ones((1, 64))
variables = model.init(
  random.key(0),
  sample_x, training=False  # <== inputs
)
params = variables['params']
When using scan over layers, the one thing you should notice is that all layers are fused into a single layer whose parameters have an extra "layer" dimension on the first axis. In this case, the shape of all parameters will start with (5, ...) as we are using 5 layers.
...
{
  'mlp/__layer_stack_no_per_layer/block/linear': {
    'b': (5, 64),
    'w': (5, 64, 64)
  }
}
...
FrozenDict({
  ScanBlock_0: {
    Dense_0: {
      bias: (5, 64),
      kernel: (5, 64, 64),
    },
  },
})
Top-level Haiku functions vs top-level Flax modules#
In Haiku, it is possible to write the entire model as a single function by using the raw hk.{get,set}_{parameter,state} to define/access model parameters and states; it is very common to write the top-level "Module" as a function instead.
The Flax team recommends a more Module-centric approach that uses __call__ to define the forward function. The corresponding accessors are nn.module.param and nn.module.variable (go to Handling State for an explanation on collections).
def forward(x):
  counter = hk.get_state('counter', shape=[], dtype=jnp.int32, init=jnp.ones)
  multiplier = hk.get_parameter('multiplier', shape=[1,], dtype=x.dtype, init=jnp.ones)
  output = x + multiplier * counter
  hk.set_state("counter", counter + 1)
  return output

model = hk.transform_with_state(forward)

params, state = model.init(random.key(0), jax.numpy.ones((1, 64)))
class FooModule(nn.Module):
  @nn.compact
  def __call__(self, x):
    counter = self.variable('counter', 'count', lambda: jnp.ones((), jnp.int32))
    multiplier = self.param('multiplier', nn.initializers.ones_init(), [1,], x.dtype)
    output = x + multiplier * counter.value
    if not self.is_initializing():  # otherwise model.init() also increases it
      counter.value += 1
    return output

model = FooModule()

variables = model.init(random.key(0), jax.numpy.ones((1, 64)))
params, counter = variables['params'], variables['counter']
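To actually step the counter you run apply with the counter collection marked as mutable; a minimal sketch (our addition) using the Flax model and variables defined above:

y, updates = model.apply(
  {'params': params, 'counter': counter},
  jax.numpy.ones((1, 64)),
  mutable=['counter'],  # allow the 'counter' collection to change
)
counter = updates['counter']  # the 'count' entry has been incremented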