Dropout#
This guide provides an overview of how to apply dropout using flax.linen.Dropout(). Dropout is a stochastic regularization technique that randomly removes hidden and visible units in a network.
Throughout the guide, you will be able to compare code examples with and without Flax Dropout.
Split the PRNG key#
Since dropout is a random operation, it requires a pseudorandom number generator (PRNG) state. Flax uses JAX’s (splittable) PRNG keys, which have a number of desirable properties for neural networks. To learn more, refer to the Pseudorandom numbers in JAX tutorial.
Note: Recall that JAX has an explicit way of giving you PRNG keys: you can fork the main PRNG state (such as key = jax.random.key(seed=0)) into multiple new PRNG keys with key, subkey = jax.random.split(key). You can refresh your memory in 🔪 JAX - The Sharp Bits 🔪 Randomness and PRNG keys.
Begin by splitting the PRNG key using jax.random.split() into three keys, including one for Flax Linen Dropout.
# Without Flax Dropout: two PRNG keys.
root_key = jax.random.key(seed=0)
main_key, params_key = jax.random.split(key=root_key)

# With Flax Dropout: three PRNG keys, including one for the 'dropout' stream.
root_key = jax.random.key(seed=0)
main_key, params_key, dropout_key = jax.random.split(key=root_key, num=3)
Note: In Flax, you provide PRNG streams with names, so that you can use them later in your flax.linen.Module(). For example, you pass the stream 'params' for initializing parameters, and 'dropout' for applying flax.linen.Dropout().
Define your model with Dropout#
To create a model with dropout:
Subclass flax.linen.Module(), and then use flax.linen.Dropout() to add a dropout layer. Recall that flax.linen.Module() is the base class for all neural network Modules, and all layers and models are subclassed from it.
In flax.linen.Dropout(), the deterministic argument is required to be passed as a keyword argument, either:
When constructing the flax.linen.Module(); or
When calling flax.linen.init() or flax.linen.apply() on a constructed Module. (Refer to flax.linen.module.merge_param() for more details.)
Because deterministic is a boolean:
If it's set to False, the inputs are masked (that is, set to zero) with a probability set by rate, and the remaining inputs are scaled by 1 / (1 - rate), which ensures that the means of the inputs are preserved (see the sketch after this list).
If it's set to True, no mask is applied (the dropout is turned off), and the inputs are returned as-is.
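To see this behavior directly, here is a minimal sketch (not part of the guide's model) that applies flax.linen.Dropout() to an array of ones. It assumes jax, jax.numpy as jnp, and flax.linen as nn are imported; the seed value is arbitrary, while 'dropout' is the PRNG stream name that flax.linen.Dropout() draws from by default.
dropout = nn.Dropout(rate=0.5)
x = jnp.ones((1, 8))
# deterministic=False: roughly half of the entries are zeroed, and the
# surviving entries are scaled by 1 / (1 - 0.5) = 2.0.
y_train = dropout.apply({}, x, deterministic=False, rngs={'dropout': jax.random.key(1)})
# deterministic=True: no mask is applied and the input is returned as-is.
y_eval = dropout.apply({}, x, deterministic=True)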
A common pattern is to accept a training (or train) argument (a boolean) in the parent Flax Module, and use it to enable or disable dropout (as demonstrated in later sections of this guide). In other machine learning frameworks, like PyTorch or TensorFlow (Keras), this is specified via a mutable state or a call flag (for example, in torch.nn.Module.eval or tf.keras.Model by setting the training flag).
Note: Flax provides an implicit way of handling PRNG key streams via flax.linen.Module()'s flax.linen.Module.make_rng() method. This allows you to split off a fresh PRNG key from the PRNG stream inside Flax Modules (or their sub-Modules), and make_rng is guaranteed to provide a unique key each time you call it. Internally, flax.linen.Dropout() makes use of flax.linen.Module.make_rng() to create a key for dropout. You can check out the source code. In short, flax.linen.Module.make_rng() guarantees full reproducibility.
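As a brief aside, here is a minimal sketch (the module name NoisyLayer is hypothetical and not part of this guide) of calling flax.linen.Module.make_rng() directly inside a custom Module to draw a key from the 'dropout' stream; the stream still has to be seeded via rngs at apply time, as shown later in this guide.
class NoisyLayer(nn.Module):
  @nn.compact
  def __call__(self, x):
    # Draw a fresh, unique key from the 'dropout' PRNG stream.
    rng = self.make_rng('dropout')
    # Use the key for a stochastic operation (here, additive noise).
    return x + 0.01 * jax.random.normal(rng, x.shape)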
# Without Flax Dropout:
class MyModel(nn.Module):
  num_neurons: int

  @nn.compact
  def __call__(self, x):
    x = nn.Dense(self.num_neurons)(x)
    return x

# With Flax Dropout:
class MyModel(nn.Module):
  num_neurons: int

  @nn.compact
  def __call__(self, x, training: bool):
    x = nn.Dense(self.num_neurons)(x)
    # Set the dropout layer with a `rate` of 50%.
    # When the `deterministic` flag is `True`, dropout is turned off.
    x = nn.Dropout(rate=0.5, deterministic=not training)(x)
    return x
Initialize the model#
After creating your model:
Instantiate the model.
Then, in the flax.linen.init() call, set training=False.
Finally, extract the params from the variable dictionary.
Here, the main difference between the code without Flax Dropout and with Dropout is that the training (or train) argument must be provided.
# Without Flax Dropout:
my_model = MyModel(num_neurons=3)
x = jnp.empty((3, 4, 4))
variables = my_model.init(params_key, x)
params = variables['params']

# With Flax Dropout:
my_model = MyModel(num_neurons=3)
x = jnp.empty((3, 4, 4))
# Dropout is disabled with `training=False` (that is, `deterministic=True`).
variables = my_model.init(params_key, x, training=False)
params = variables['params']
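Optionally, you can inspect the initialized parameters. This short sketch assumes the MyModel defined above; with an input of shape (3, 4, 4) and Dense(3), it should show a Dense_0 kernel of shape (4, 3) and a bias of shape (3,).
# Print the shape of every parameter leaf in the params tree.
print(jax.tree_util.tree_map(jnp.shape, params))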
Perform the forward pass during training#
When using flax.linen.apply() to run your model:
Pass training=True to flax.linen.apply().
Then, to draw PRNG keys during the forward pass (with dropout), provide a PRNG key to seed the 'dropout' stream when you call flax.linen.apply().
# No need to pass the `training` and `rngs` flags.
y = my_model.apply({'params': params}, x)
# Dropout is enabled with `training=True` (that is, `deterministic=False`).
y = my_model.apply({'params': params}, x, training=True, rngs={'dropout': dropout_key})
Here, the main difference between the code without Flax Dropout and with Dropout is that the training (or train) and rngs arguments must be provided if you need dropout enabled.
During evaluation, call flax.linen.apply() with dropout disabled, that is, training=False (this means you do not have to pass an RNG either), as in the sketch below.
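For example, a minimal evaluation call (assuming the model, params, and x from above) looks like this:
# Dropout is disabled, so no 'dropout' PRNG key is needed.
y_eval = my_model.apply({'params': params}, x, training=False)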
TrainState and the training step#
This section explains how to amend your code inside the training step function if you have dropout enabled.
Note: Recall that Flax has a common pattern where you create a dataclass that represents the whole training state, including parameters and the optimizer state. Then, you can pass a single parameter, state: TrainState, to the training step function. Refer to the flax.training.train_state.TrainState() API docs to learn more.
First, add a key field to a custom flax.training.train_state.TrainState() class.
Then, pass the key value (in this case, the dropout_key) to the train_state.TrainState.create() method.
# Without Flax Dropout:
from flax.training import train_state

state = train_state.TrainState.create(
  apply_fn=my_model.apply,
  params=params,
  tx=optax.adam(1e-3)
)

# With Flax Dropout: add a `key` field to a custom TrainState.
from flax.training import train_state

class TrainState(train_state.TrainState):
  key: jax.Array

state = TrainState.create(
  apply_fn=my_model.apply,
  params=params,
  key=dropout_key,
  tx=optax.adam(1e-3)
)
Next, in the Flax training step function, train_step, generate a new PRNG key from the dropout_key to apply dropout at each step. This can be done with either jax.random.split() or jax.random.fold_in():
Using jax.random.fold_in() is generally faster. When you use jax.random.split(), you split off a PRNG key that can be reused afterwards. However, using jax.random.fold_in() 1) makes sure to fold in unique data (such as the step number); and 2) can result in longer sequences of PRNG streams.
Finally, when performing the forward pass, pass the new PRNG key to state.apply_fn() via the rngs argument.
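For illustration, here is a minimal sketch of the fold_in pattern used in the train_step below: folding the (integer) step counter into a base key yields a distinct, reproducible key for every step without consuming the base key.
step_0_key = jax.random.fold_in(key=dropout_key, data=0)
step_1_key = jax.random.fold_in(key=dropout_key, data=1)  # Different from step_0_key.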
# Without Flax Dropout:
@jax.jit
def train_step(state: train_state.TrainState, batch):
  def loss_fn(params):
    logits = state.apply_fn(
      {'params': params},
      x=batch['image'],
    )
    # Average the per-example losses into a scalar for `jax.value_and_grad`.
    loss = optax.softmax_cross_entropy_with_integer_labels(
      logits=logits, labels=batch['label']).mean()
    return loss, logits
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  return state

# With Flax Dropout:
@jax.jit
def train_step(state: TrainState, batch, dropout_key):
  # Fold the step counter into the base key so each step gets a unique dropout key.
  dropout_train_key = jax.random.fold_in(key=dropout_key, data=state.step)
  def loss_fn(params):
    logits = state.apply_fn(
      {'params': params},
      x=batch['image'],
      training=True,
      rngs={'dropout': dropout_train_key}
    )
    # Average the per-example losses into a scalar for `jax.value_and_grad`.
    loss = optax.softmax_cross_entropy_with_integer_labels(
      logits=logits, labels=batch['label']).mean()
    return loss, logits
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  return state
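As a usage sketch (the iterable train_batches is hypothetical and not defined in this guide), the training loop simply threads the state and the base dropout_key through train_step; the per-step key is derived inside via fold_in.
for batch in train_batches:  # Each batch is a dict with 'image' and 'label' arrays.
  state = train_step(state, batch, dropout_key)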
Flax examples with dropout#
A Transformer-based model trained on the WMT Machine Translation dataset. This example uses dropout and attention dropout.
Applying word dropout to a batch of input IDs in a text classification context. This example uses a custom flax.linen.Dropout() layer.
More Flax examples that use Module make_rng()#
Defining a prediction token in a decoder of a sequence-to-sequence model.