Upgrading my codebase to Linen#
As of Flax v0.4.0, flax.nn no longer exists; it has been replaced by the new Linen API at flax.linen. If your codebase is still using the old API, you can use this upgrade guide to migrate it to Linen.
Defining simple Flax Modules#
from flax import nn

class Dense(nn.Module):
  def apply(self,
            inputs,
            features,
            use_bias=True,
            kernel_init=default_kernel_init,
            bias_init=initializers.zeros_init()):
    kernel = self.param('kernel',
                        (inputs.shape[-1], features), kernel_init)
    y = jnp.dot(inputs, kernel)
    if use_bias:
      bias = self.param(
          'bias', (features,), bias_init)
      y = y + bias
    return y
from flax import linen as nn  # [1]

class Dense(nn.Module):
  features: int  # [2]
  use_bias: bool = True
  kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
  bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros_init()

  @nn.compact
  def __call__(self, inputs):  # [3]
    kernel = self.param('kernel',
                        self.kernel_init, (inputs.shape[-1], self.features))  # [4]
    y = jnp.dot(inputs, kernel)
    if self.use_bias:
      bias = self.param(
          'bias', self.bias_init, (self.features,))  # [5]
      y = y + bias
    return y
1. Replace from flax import nn with from flax import linen as nn.
2. Move the arguments of apply into dataclass attributes. Add type annotations (or use the type Any to bypass them).
3. Rename the apply method to __call__ and (optionally) wrap it with @compact. Methods wrapped in @compact can define submodules directly within the method (like in old Flax). You can only wrap a single method with @compact. Alternatively, you can define a setup method. For more details, please see our other HOWTO Should I use setup or nn.compact?.
4. Access dataclass attribute values via self.<attr> inside methods, e.g. self.features.
5. Move the shape to the end of the arguments to self.param (initializer functions can take arbitrary argument lists). A usage sketch of the converted module follows this list.
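Here is that sketch for the Linen Dense module above (the input shape, feature count, and RNG key are arbitrary choices for illustration):

import jax
import jax.numpy as jnp

x = jnp.ones((1, 8))                               # dummy input batch
model = Dense(features=4)                          # construct the module instance
variables = model.init(jax.random.PRNGKey(0), x)   # {'params': {'kernel': ..., 'bias': ...}}
y = model.apply(variables, x)                      # functional application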
Using Flax Modules inside other Modules#
class Encoder(nn.Module):
  def apply(self, x):
    x = nn.Dense(x, 500)
    x = nn.relu(x)
    z = nn.Dense(x, 500, name="latents")
    return z
class Encoder(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(500)(x)  # [1]
    x = nn.relu(x)
    z = nn.Dense(500, name='latents')(x)  # [2]
    return z
1. Module constructors no longer return the outputs. Instead, they work like normal constructors and return module instances. These instances can be shared like in normal Python (instead of using .shared() as in old Flax). Since most modules implement __call__, you can retain the conciseness of old Flax.
2. Names can optionally be passed to all module constructors. A usage sketch of this Encoder follows below.
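From the top level, the module is initialized and applied functionally (the input size of 784 is an arbitrary assumption for illustration):

x = jnp.ones((1, 784))                                 # dummy input batch
variables = Encoder().init(jax.random.PRNGKey(0), x)   # initializes all submodules
z = Encoder().apply(variables, x)                      # runs the encoder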
Sharing submodules and defining multiple methods#
class AutoEncoder(nn.Module):
  def _create_decoder(self):
    return Decoder.shared(name="decoder")

  def apply(self, x, z_rng, latents=20):
    decoder = self._create_decoder()
    z = Encoder(x, latents, name="encoder")
    return decoder(z)

  @nn.module_method
  def generate(self, z, **unused_kwargs):
    decoder = self._create_decoder()
    return nn.sigmoid(decoder(z))
class AutoEncoder(nn.Module):
  latents: int = 20

  def setup(self):  # [1]
    self.encoder = Encoder(self.latents)  # [2]
    self.decoder = Decoder()

  def __call__(self, x):  # [3]
    z = self.encoder(x)
    return self.decoder(z)

  def generate(self, z):  # [4]
    return nn.sigmoid(self.decoder(z))
1. Use setup instead of __init__, which is already defined by the dataclasses library. Flax calls setup right after modules are ready to be used. (You can do this for all modules if you like instead of using @compact, but we like how @compact co-locates where modules are defined and used, especially if you have loops or conditionals.)
2. As in regular Python, share submodules by assigning them to self during initialization. Similar to PyTorch, self.encoder automatically gets the name "encoder".
3. We don't use @compact here because we're not defining any inline submodules (all submodules are defined in setup).
4. Define additional methods just like in regular Python. A sketch of calling such a method from the top level follows this list.
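To call a method other than __call__ from the top level, pass it through the method argument of apply; a minimal sketch, assuming variables came from AutoEncoder().init(...) and z is a batch of latent codes:

reconstruction = AutoEncoder().apply(variables, x)  # calls __call__
samples = AutoEncoder().apply(
    variables, z, method=AutoEncoder.generate)      # calls generate instead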
Module.partial inside other modules#
# no import

class ResNet(nn.Module):
  """ResNetV1."""

  def apply(self, x,
            stage_sizes,
            num_filters=64,
            train=True):
    conv = nn.Conv.partial(bias=False)
    norm = nn.BatchNorm.partial(
        use_running_average=not train,
        momentum=0.9, epsilon=1e-5)
    x = conv(x, num_filters, (7, 7), (2, 2),
             padding=[(3, 3), (3, 3)],
             name='conv_init')
    x = norm(x, name='bn_init')
    # [...]
    return x
from functools import partial

class ResNet(nn.Module):
  """ResNetV1."""
  stage_sizes: Sequence[int]
  num_filters: int = 64
  train: bool = True

  @nn.compact
  def __call__(self, x):
    conv = partial(nn.Conv, use_bias=False)
    norm = partial(nn.BatchNorm,
                   use_running_average=not self.train,
                   momentum=0.9, epsilon=1e-5)
    x = conv(self.num_filters, (7, 7), (2, 2),
             padding=[(3, 3), (3, 3)],
             name='conv_init')(x)
    x = norm(name='bn_init')(x)
    # [...]
    return x
Use normal functools.partial instead of Module.partial. The rest stays the same.
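The point is that functools.partial simply pre-binds constructor keyword arguments, so the partially applied class behaves like the module class itself. A minimal standalone sketch (shapes and sizes are arbitrary):

import jax
import jax.numpy as jnp
from functools import partial
from flax import linen as nn

conv = partial(nn.Conv, use_bias=False)        # pre-bind a constructor argument
layer = conv(features=64, kernel_size=(3, 3))  # same as nn.Conv(64, (3, 3), use_bias=False)
x = jnp.ones((1, 32, 32, 3))                   # dummy NHWC input
variables = layer.init(jax.random.PRNGKey(0), x)
y = layer.apply(variables, x)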
Top-level training code patterns#
def create_model(key):
  _, initial_params = CNN.init_by_shape(
      key, [((1, 28, 28, 1), jnp.float32)])
  model = nn.Model(CNN, initial_params)
  return model

def create_optimizer(model, learning_rate):
  optimizer_def = optim.Momentum(learning_rate=learning_rate)
  optimizer = optimizer_def.create(model)
  return optimizer

def cross_entropy_loss(*, logits, labels):
  one_hot_labels = jax.nn.one_hot(labels, num_classes=10)
  return -jnp.mean(jnp.sum(one_hot_labels * logits, axis=-1))

def loss_fn(model):
  logits = model(batch['image'])
  loss = cross_entropy_loss(logits=logits, labels=batch['label'])
  return loss, logits
def create_train_state(rng, config):  # [1]
  variables = CNN().init(rng, jnp.ones([1, 28, 28, 1]))  # [2]
  params = variables['params']  # [3]
  tx = optax.sgd(config.learning_rate, config.momentum)  # [4]
  return train_state.TrainState.create(
      apply_fn=CNN().apply, params=params, tx=tx)

def loss_fn(params):
  logits = CNN().apply({'params': params}, batch['image'])  # [5]
  one_hot = jax.nn.one_hot(batch['label'], 10)
  loss = jnp.mean(optax.softmax_cross_entropy(logits=logits,
                                              labels=one_hot))
  return loss, logits
1. We no longer use the Model abstraction; instead we pass parameters around directly, usually encapsulated in a TrainState object, which can be passed directly to JAX transformations.
2. To compute the initial parameters, construct a module instance and call init or init_with_output. We haven't ported over init_by_shape because that function did some magic we did not like (it evaluated the function by shape but returned real values anyway). Therefore, you should now pass concrete values to the initializer functions. You can optimize the initialization by wrapping it with jax.jit, which is highly recommended to avoid running a full forward pass.
3. Linen generalizes parameters into variables. Parameters are one "collection" of variables. Variables are nested dicts whose top-level keys reflect the different variable collections, of which "params" is one. See the Variables documentation for more details.
4. We recommend using Optax optimizers. See our separate HOWTO called Upgrading my codebase to Optax for more details.
5. To make predictions with your model, make an instance at the top level (this is free: it is just a wrapper around the constructor attributes) and call the apply method (which calls __call__ internally). A sketch of a full train step built from these pieces follows this list.
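Here is such a sketch; batch is assumed to be a dict with 'image' and 'label' arrays, and CNN, optax, and the TrainState created above are assumed to be in scope:

@jax.jit
def train_step(state, batch):
  def loss_fn(params):
    logits = state.apply_fn({'params': params}, batch['image'])
    one_hot = jax.nn.one_hot(batch['label'], 10)
    loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
    return loss, logits

  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(state.params)
  new_state = state.apply_gradients(grads=grads)  # applies the optax update
  return new_state, loss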
Non-trainable variables (“state”): Use within Modules#
class BatchNorm(nn.Module):
  def apply(self, x):
    # [...]
    ra_mean = self.state(
        'mean', (x.shape[-1],), initializers.zeros_init())
    ra_var = self.state(
        'var', (x.shape[-1],), initializers.ones_init())
    # [...]
class BatchNorm(nn.Module):
  @nn.compact
  def __call__(self, x):
    # [...]
    ra_mean = self.variable(
        'batch_stats', 'mean', initializers.zeros_init(), (x.shape[-1],))
    ra_var = self.variable(
        'batch_stats', 'var', initializers.ones_init(), (x.shape[-1],))
    # [...]
The first argument is the name of the variable collection ("params" is the only variable collection that is always available). Some collections may be treated as mutable and others as immutable in top-level training code (see the next section for details). Flax also lets you treat each variable collection differently when using JAX transformations inside modules.
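The object returned by self.variable is a reference whose .value attribute can be read and, as long as the collection is mutable in the current call, assigned. A minimal sketch of that mechanic (RunningMean is a made-up example module, not part of Flax):

class RunningMean(nn.Module):
  momentum: float = 0.99

  @nn.compact
  def __call__(self, x):
    # initialized on the first call; lives in the 'batch_stats' collection
    ra_mean = self.variable('batch_stats', 'mean',
                            lambda: jnp.zeros(x.shape[-1]))
    # assignment requires applying with mutable=['batch_stats']
    ra_mean.value = (self.momentum * ra_mean.value
                     + (1. - self.momentum) * x.mean(axis=0))
    return x - ra_mean.value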
Non-trainable variables (“state”): Top-level training code patterns#
# initial params and state
def initial_model(key, init_batch):
  with nn.stateful() as initial_state:
    _, initial_params = ResNet.init(key, init_batch)
  model = nn.Model(ResNet, initial_params)
  return model, initial_state

# updates batch statistics during training
def loss_fn(model, model_state):
  with nn.stateful(model_state) as new_model_state:
    logits = model(batch['image'])
  # [...]

# reads immutable batch statistics during evaluation
def eval_step(model, model_state, batch):
  with nn.stateful(model_state, mutable=False):
    logits = model(batch['image'], train=False)
  return compute_metrics(logits, batch['label'])
# initial variables ({"params": ..., "batch_stats": ...})
def initial_variables(key, init_batch):
  return ResNet().init(key, init_batch)  # [1]

# updates batch statistics during training
def loss_fn(params, batch_stats):
  variables = {'params': params, 'batch_stats': batch_stats}  # [2]
  logits, new_variables = ResNet(train=True).apply(
      variables, batch['image'], mutable=['batch_stats'])  # [3]
  new_batch_stats = new_variables['batch_stats']
  # [...]

# reads immutable batch statistics during evaluation
def eval_step(params, batch_stats, batch):
  variables = {'params': params, 'batch_stats': batch_stats}
  logits = ResNet(train=False).apply(
      variables, batch['image'], mutable=False)  # [4]
  return compute_metrics(logits, batch['label'])
1. init returns a variable dict, e.g. {"params": ..., "batch_stats": ...} (see the Variables documentation).
2. Combine the different variable collections into a single variable dict.
3. During training, the batch_stats variable collection changes. Because we list it in the mutable argument, the return value from module.apply becomes an ordered pair of (output, new_variables).
4. During evaluation, we want to raise an error if we accidentally apply Batch Norm in training mode. Passing mutable=False to module.apply enforces that. Since no variables are mutated, the return value is once again just the output. A sketch of carrying batch_stats alongside a TrainState follows this list.
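Here is that sketch: one common pattern is to subclass TrainState so that batch_stats travels together with params and the optimizer state (init_batch and the optimizer hyperparameters below are arbitrary placeholders):

from typing import Any
from flax.training import train_state

class TrainState(train_state.TrainState):
  batch_stats: Any  # carried alongside params and the optimizer state

variables = initial_variables(jax.random.PRNGKey(0), init_batch)
state = TrainState.create(
    apply_fn=ResNet(train=True).apply,
    params=variables['params'],
    batch_stats=variables['batch_stats'],
    tx=optax.sgd(0.01, momentum=0.9))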
Loading pre-Linen checkpoints#
While most Linen modules should be able to use pre-Linen weights without any modification, there is one catch: in the pre-Linen API, submodules were numbered incrementally, independently of the submodule class. In Linen, this behavior has changed so that separate submodule counts are kept per module class.
In pre-Linen, params have the following structure:
{'Conv_0': { ... }, 'Dense_1': { ... } }
In Linen this is instead:
{'Conv_0': { ... }, 'Dense_0': { ... } }
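If you hit this naming mismatch when restoring an old checkpoint, one option is to remap the affected keys before using the parameters. A hypothetical sketch using flax.traverse_util (the rename table below only covers the toy structure shown above):

from flax import traverse_util

rename = {'Dense_1': 'Dense_0'}  # example mapping for the structure above

flat = traverse_util.flatten_dict(old_params)  # {('Conv_0', 'kernel'): ..., ...}
flat = {tuple(rename.get(part, part) for part in path): value
        for path, value in flat.items()}
new_params = traverse_util.unflatten_dict(flat)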
TODO: Add an example here of how to load a new TrainState object.
Randomness#
def dropout(inputs, rate, deterministic=False):
  keep_prob = 1. - rate
  if deterministic:
    return inputs
  else:
    mask = random.bernoulli(
        make_rng(), p=keep_prob, shape=inputs.shape)
    return lax.select(
        mask, inputs / keep_prob, jnp.zeros_like(inputs))

def loss_fn(model, dropout_rng):
  with nn.stochastic(dropout_rng):
    logits = model(inputs)
class Dropout(nn.Module):
  rate: float

  @nn.compact
  def __call__(self, inputs, deterministic=False):
    keep_prob = 1. - self.rate
    if deterministic:
      return inputs
    else:
      mask = random.bernoulli(
          self.make_rng('dropout'), p=keep_prob, shape=inputs.shape)  # [1]
      return lax.select(
          mask, inputs / keep_prob, jnp.zeros_like(inputs))

def loss_fn(params, dropout_rng):
  logits = Transformer().apply(
      {'params': params}, inputs, rngs={'dropout': dropout_rng})  # [2]
1. RNGs in Linen have "kinds", in this case 'dropout'. Different kinds can be treated differently in JAX transformations (for example: do you want the same dropout mask for each timestep in a sequence model, or a different one?).
2. Instead of using the nn.stochastic context manager, you pass RNGs explicitly to module.apply. During evaluation you wouldn't pass any RNGs; then, if you accidentally use dropout in non-deterministic mode, self.make_rng('dropout') raises an error. A usage sketch of the Dropout module above follows this list.
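Here is that sketch (shapes and RNG keys are arbitrary; the bare Dropout module has no parameters, so an empty variable dict is passed to apply):

x = jnp.ones((4, 16))
dropout = Dropout(rate=0.5)

# training: an RNG of kind 'dropout' must be supplied
y_train = dropout.apply({}, x, rngs={'dropout': jax.random.PRNGKey(0)})

# evaluation: no RNG needed, dropout is a no-op
y_eval = dropout.apply({}, x, deterministic=True)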
Lifted transformations#
In Linen, rather than using JAX transformations directly, we use “lifted transforms”: JAX transformations applied to Flax Modules.
For more information, please see the design note on Lifted transformations.
TODO: Give an example of jax.scan_in_dim (pre-Linen) vs. nn.scan (Linen).
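In the meantime, here is a hedged Linen-side sketch of nn.scan applied to a simple hand-written recurrent cell (Cell and ScannedRNN are made-up example modules, not part of Flax):

class Cell(nn.Module):
  features: int

  @nn.compact
  def __call__(self, carry, x):
    h = nn.tanh(nn.Dense(self.features)(jnp.concatenate([carry, x], axis=-1)))
    return h, h  # (new carry, per-step output)

class ScannedRNN(nn.Module):
  features: int

  @nn.compact
  def __call__(self, carry, xs):  # xs: [batch, time, input_features]
    scan = nn.scan(Cell,
                   variable_broadcast='params',  # share parameters across steps
                   split_rngs={'params': False},
                   in_axes=1, out_axes=1)        # scan over the time axis
    return scan(features=self.features)(carry, xs)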