Frequently Asked Questions (FAQ)
This is a collection of answers to frequently asked questions (FAQ). You can contribute to the Flax FAQ by starting a new topic in GitHub Discussions.
How to take the derivative with respect to an intermediate value (using Module.perturb)?
To take the derivative(s) or gradient(s) of the output with respect to a hidden/intermediate activation inside a model layer, you can use flax.linen.Module.perturb(). In the forward pass, you call perturb(...) to define a zero-valued “perturbation” variable with the same shape as the intermediate activation; perturb() adds it to the activation. You then define the loss function with 'perturbations' as an added standalone argument and apply a JAX derivative operation such as jax.grad to that perturbation argument. Because the perturbation is added to the activation, the gradient with respect to the perturbation equals the gradient with respect to the activation.
For full examples and detailed documentation, go to the flax.linen.Module.perturb() API docs.
Is Flax Linen remat_scan() the same as scan(remat(...))?
Flax remat_scan() (flax.linen.remat_scan()) and scan(remat(...)) (flax.linen.scan() over flax.linen.remat()) are not the same, and remat_scan() is limited in the cases it supports. Namely, remat_scan() treats the inputs and outputs as carries (hidden states that are carried from one iteration of the scanned loop to the next). We recommend using scan(remat(...)) instead, because you will typically need extra parameters such as in_axes (for input array axes) or out_axes (for output array axes), which flax.linen.remat_scan() does not expose.
What are the recommended training loop libraries?
Consider using CLU (Common Loop Utils) google/CommonLoopUtils. To get started, go to this CLU Synopsis Colab. You can find answers to common questions about CLU with Flax on google/flax GitHub Discussions.
Check out the official google/flax Examples for training loops that use CLU metrics. For example, see the train.py of the Flax ImageNet example.
For computer vision research, consider google-research/scenic. Scenic is a set of shared lightweight libraries that solve commonly encountered tasks when training large-scale vision models (with examples of several projects). Scenic is developed in JAX with Flax. To get started, go to the README page on GitHub.