Frequently Asked Questions (FAQ)
This is a collection of answers to frequently asked questions (FAQ). You can contribute to the Flax FAQ by starting a new topic in GitHub Discussions.
How to take the derivative with respect to an intermediate value (using Module.perturb)?
To take the derivative or gradient of the output with respect to a hidden/intermediate activation inside a model layer, you can use flax.linen.Module.perturb(). First, call perturb(...) on the intermediate activation in the forward pass; this adds a zero-valued variable of the same shape as the activation, stored in the 'perturbations' collection. Then define the loss function with 'perturbations' as an added standalone argument, and take the derivative with respect to that argument using jax.grad.
For full examples and detailed documentation, go to:
The flax.linen.Module.perturb() API docs
Is Flax Linen remat_scan() the same as scan(remat(...))?
Flax remat_scan() (flax.linen.remat_scan()) and scan(remat(...)) (flax.linen.scan() over flax.linen.remat()) are not the same, and remat_scan() supports fewer cases. Namely, remat_scan() treats the inputs and outputs as carries (hidden states that are carried through the scan loop). We recommend scan(remat(...)), because you typically need the extra parameters, such as in_axes (input array axes) or out_axes (output array axes), which flax.linen.remat_scan() does not expose.
What are the recommended training loop libraries?
Consider using CLU (Common Loop Utils) google/CommonLoopUtils. To get started, go to this CLU Synopsis Colab. You can find answers to common questions about CLU with Flax on google/flax GitHub Discussions.
Check out the official google/flax Examples for training loops that use CLU metrics. One example is the train.py of the Flax ImageNet example.
For computer vision research, consider google-research/scenic. Scenic is a set of shared light-weight libraries solving commonly encountered tasks when training large-scale vision models (with examples of several projects). Scenic is developed in JAX with Flax. To get started, go to the README page on GitHub.