Transformations#
JAX transformations on Modules.
JAX functional transformations operate on pure functions. Flax extends these transformations to also operate on Modules, which have stateful variables and PRNG sequences. We refer to these extended versions as “lifted transformations”.
A lifted transformation can be applied to a Module class or to a function that takes a Module instance as its first argument.
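For illustration, a minimal sketch of both forms (the wrapped layer and the helper function below are arbitrary choices for this sketch, not part of the API):

>>> import flax.linen as nn
...
>>> # Applied to a Module class: returns a transformed Module class.
>>> JitDense = nn.jit(nn.Dense)
...
>>> # Applied to a function that takes a Module instance as its first argument.
>>> def scale_and_apply(mdl, x):
...   return mdl(2.0 * x)
>>> lifted_fn = nn.jit(scale_and_apply)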
- flax.linen.vmap(target, variable_axes=FrozenDict({}), split_rngs=FrozenDict({}), in_axes=0, out_axes=0, axis_size=None, axis_name=None, spmd_axis_name=None, metadata_params={}, methods=None)[source]#
A lifted version of jax.vmap. See jax.vmap for the unlifted batch transform in JAX.

vmap can be used to add a batch axis to a Module. For example, we could create a version of Dense with a batch axis that does not share parameters:

>>> import flax.linen as nn
>>> BatchDense = nn.vmap(
...     nn.Dense,
...     in_axes=0, out_axes=0,
...     variable_axes={'params': 0},
...     split_rngs={'params': True})

By using variable_axes={'params': 0}, we indicate that the parameters themselves are mapped over and therefore not shared along the mapped axis. Consequently, we also split the 'params' RNG; otherwise the parameters would be initialized identically along the mapped axis.

Similarly, vmap could be used to add a batch axis with parameter sharing:

>>> import flax.linen as nn
>>> BatchDense = nn.vmap(
...     nn.Dense,
...     in_axes=0, out_axes=0,
...     variable_axes={'params': None},
...     split_rngs={'params': False})

Here we use variable_axes={'params': None} to indicate that the parameter variables are shared along the mapped axis. Consequently, the 'params' RNG must also be shared.

- Parameters
target – a Module or a function taking a Module as its first argument.
variable_axes – the variable collections that are lifted into the batching transformation. Use None to indicate a broadcasted collection or an integer to map over an axis. For example, passing in variable_axes={'params': None} will indicate that the parameter variables should be shared along the mapped axis.
split_rngs – Split PRNG sequences will be different for each index of the batch dimension. Unsplit PRNGs will be broadcasted.
in_axes – Specifies the mapping of the input arguments (see jax.vmap).
out_axes – Specifies the mapping of the return value (see jax.vmap).
axis_size – Specifies the size of the batch axis. This only needs to be specified if it cannot be derived from the input arguments.
axis_name – Specifies a name for the batch axis. Can be used together with parallel reduction primitives (e.g. jax.lax.pmean, jax.lax.ppermute, etc.). Note that this is only used for pmap and shard map. For SPMD jit, you do not need to manually synchronize: just make sure that the axes are correctly annotated and XLA:SPMD will insert the necessary collectives.
methods – If target is a Module, the methods of Module to vmap over.
spmd_axis_name – Axis name added to any pjit sharding constraints appearing in fn. See also google/flax.
metadata_params – arguments dict passed to AxisMetadata instances in the variable tree.
- Returns
A batched/vectorized version of target, with the same arguments but with extra axes at positions indicated by in_axes, and the same return value, but with extra axes at positions indicated by out_axes.
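As a usage sketch (the batch size and feature sizes below are illustrative assumptions), the non-sharing BatchDense from the first example is initialized and applied like any other Module; the mapped axis shows up both in the output and in the parameters:

>>> import jax, jax.numpy as jnp
>>> import flax.linen as nn
>>> BatchDense = nn.vmap(
...     nn.Dense,
...     in_axes=0, out_axes=0,
...     variable_axes={'params': 0},
...     split_rngs={'params': True})
>>> x = jnp.ones((8, 4))  # a batch of 8 inputs with 4 features each
>>> model = BatchDense(features=3)
>>> variables = model.init(jax.random.key(0), x)
>>> model.apply(variables, x).shape  # one output per batch entry
(8, 3)
>>> variables['params']['kernel'].shape  # a separate kernel per batch entry
(8, 4, 3)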
- flax.linen.scan(target, variable_axes=FrozenDict({}), variable_broadcast=False, variable_carry=False, split_rngs=FrozenDict({}), in_axes=0, out_axes=0, length=None, reverse=False, unroll=1, data_transform=None, metadata_params={}, methods=None, _split_transpose=False)[source]#
A lifted version of jax.lax.scan. See jax.lax.scan for the unlifted scan in JAX.

To improve consistency with vmap, this version of scan uses in_axes and out_axes to determine which arguments are scanned over and along which axis.

scan distinguishes between 3 different types of values inside the loop:

- scan: a value that is iterated over in a loop. All scan values must have the same size in the axis they are scanned over. Scanned outputs will be stacked along the scan axis.
- carry: a carried value that is updated at each loop iteration. It must have the same shape and dtype throughout the loop.
- broadcast: a value that is closed over by the loop. When a variable is broadcast, it is typically initialized inside the loop body but is independent of the loop variables.
The target should have the signature (module, carry, *xs) -> (carry, ys), where xs and ys are the scan values that go in and out of the loop.

Example:

>>> import flax.linen as nn
>>> import jax
>>> import jax.numpy as jnp
...
>>> class LSTM(nn.Module):
...   features: int
...
...   @nn.compact
...   def __call__(self, x):
...     ScanLSTM = nn.scan(
...       nn.LSTMCell, variable_broadcast="params",
...       split_rngs={"params": False}, in_axes=1, out_axes=1)
...
...     lstm = ScanLSTM(self.features)
...     input_shape = x[:, 0].shape
...     carry = lstm.initialize_carry(jax.random.key(0), input_shape)
...     carry, x = lstm(carry, x)
...     return x
...
>>> x = jnp.ones((4, 12, 7))
>>> module = LSTM(features=32)
>>> y, variables = module.init_with_output(jax.random.key(0), x)
Note that when providing a function to nn.scan, the scanning happens over all arguments starting from the third argument, as specified by in_axes. The previous example could also be written using the functional form as:

>>> class LSTM(nn.Module):
...   features: int
...
...   @nn.compact
...   def __call__(self, x):
...
...     cell = nn.LSTMCell(self.features)
...     def body_fn(cell, carry, x):
...       carry, y = cell(carry, x)
...       return carry, y
...     scan = nn.scan(
...       body_fn, variable_broadcast="params",
...       split_rngs={"params": False}, in_axes=1, out_axes=1)
...
...     input_shape = x[:, 0].shape
...     carry = cell.initialize_carry(
...       jax.random.key(0), input_shape)
...     carry, x = scan(cell, carry, x)
...     return x
...
>>> module = LSTM(features=32)
>>> variables = module.init(jax.random.key(0), jnp.ones((4, 12, 7)))
You can also use scan to reduce the compilation time of your JAX program by merging multiple layers into a single scan loop. You can do this when you have a sequence of identical layers that you want to apply iteratively to an input. For example:

>>> class ResidualMLPBlock(nn.Module):
...   @nn.compact
...   def __call__(self, x, _):
...     h = nn.Dense(features=2)(x)
...     h = nn.relu(h)
...     return x + h, None
...
>>> class ResidualMLP(nn.Module):
...   n_layers: int = 4
...
...   @nn.compact
...   def __call__(self, x):
...     ScanMLP = nn.scan(
...       ResidualMLPBlock, variable_axes={'params': 0},
...       variable_broadcast=False, split_rngs={'params': True},
...       length=self.n_layers)
...     x, _ = ScanMLP()(x, None)
...     return x
...
>>> model = ResidualMLP(n_layers=4)
>>> variables = model.init(jax.random.key(42), jnp.ones((1, 2)))
To reduce both compilation and memory usage, you can use remat_scan(), which will in addition checkpoint each layer in the scan loop.

- Parameters
target – a Module or a function taking a Module as its first argument.
variable_axes – the variable collections that are scanned over.
variable_broadcast – Specifies the broadcasted variable collections. A broadcasted variable should not depend on any computation that cannot be lifted out of the loop. This is typically used to define shared parameters inside the fn.
variable_carry – Specifies the variable collections that are carried through the loop. Mutations to these variables are carried to the next iteration and will be preserved when the scan finishes.
split_rngs – Split PRNG sequences will be different for each loop iteration. If split is False the PRNGs will be the same across iterations.
in_axes – Specifies the axis to scan over for the arguments. Should be a prefix tree of the arguments. Use flax.core.broadcast to feed an entire input to each iteration of the scan body.
out_axes – Specifies the axis to scan over for the return value. Should be a prefix tree of the return value.
length – Specifies the number of loop iterations. This only needs to be specified if it cannot be derived from the scan arguments.
reverse – If true, scan from end to start in reverse order.
unroll – how many scan iterations to unroll within a single iteration of a loop (default: 1).
data_transform – optional function to transform raw functional-core variable and rng groups inside lifted scan body_fn, intended for inline SPMD annotations.
metadata_params – arguments dict passed to AxisMetadata instances in the variable tree.
methods – If target is a Module, the methods of Module to scan over.
_split_transpose – An experimental feature to split the transpose of a scan into a scan and a map, backed by an experimental JAX lax.scan() feature.
- Returns
The scan function with the signature (module, carry, *xs) -> (carry, ys), where xs and ys are the scan values that go in and out of the loop.
- flax.linen.jit(target, variables=True, rngs=True, static_argnums=(), static_argnames=(), donate_argnums=(), device=None, backend=None, methods=None)[source]#
Lifted version of jax.jit.
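Example (a minimal sketch; the wrapped layer and input shapes are illustrative assumptions, not mandated by the API):

>>> import jax, jax.numpy as jnp
>>> import flax.linen as nn
...
>>> # Compile the wrapped Module's methods with jax.jit while lifting its variables and RNGs.
>>> JitDense = nn.jit(nn.Dense)
>>> model = JitDense(features=3)
>>> variables = model.init(jax.random.key(0), jnp.ones((1, 4)))
>>> y = model.apply(variables, jnp.ones((1, 4)))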
- Parameters
target – a Module or a function taking a Module as its first argument.
variables – The variable collections that are lifted. By default all collections are lifted.
rngs – The PRNG sequences that are lifted. By default all PRNG sequences are lifted.
static_argnums – An int or collection of ints specifying which positional arguments to treat as static (compile-time constant). Operations that only depend on static arguments will be constant-folded in Python (during tracing), and so the corresponding argument values can be any Python object. Static arguments should be hashable, meaning both __hash__ and __eq__ are implemented, and immutable. Calling the jitted function with different values for these constants will trigger recompilation. If the jitted function is called with fewer positional arguments than indicated by static_argnums then an error is raised. Arguments that are not arrays or containers thereof must be marked as static. Defaults to ().
static_argnames – An optional string or collection of strings specifying which named arguments to treat as static (compile-time constant). See the comment on static_argnums for details. If not provided but static_argnums is set, the default is based on calling inspect.signature(fun) to find corresponding named arguments.
donate_argnums – Specify which arguments are “donated” to the computation. It is safe to donate arguments if you no longer need them once the computation has finished. In some cases XLA can make use of donated buffers to reduce the amount of memory needed to perform a computation, for example recycling one of your input buffers to store a result. You should not reuse buffers that you donate to a computation; JAX will raise an error if you try to.
device – This is an experimental feature and the API is likely to change. Optional, the Device the jitted function will run on. (Available devices can be retrieved via jax.devices().) The default is inherited from XLA’s DeviceAssignment logic and is usually to use jax.devices()[0].
backend – a string representing the XLA backend: 'cpu', 'gpu', or 'tpu'.
methods – If target is a Module, the methods of Module to jit.
- Returns
A wrapped version of target, set up for just-in-time compilation.
- flax.linen.remat(target, variables=True, rngs=True, concrete=False, prevent_cse=True, static_argnums=(), policy=None, methods=None)#
Lifted version of jax.checkpoint.

Checkpointing is a technique for reducing memory usage by recomputing activations during backpropagation. When training large models, it can be helpful to checkpoint parts of the model to trade off memory usage for additional computation.
Example:
>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn
...
>>> class CheckpointedMLP(nn.Module):
...   @nn.checkpoint
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(128)(x)
...     x = nn.relu(x)
...     x = nn.Dense(1)(x)
...     return x
...
>>> model = CheckpointedMLP()
>>> variables = model.init(jax.random.key(0), jnp.ones((1, 16)))
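Because target may also be a Module class, the transform can wrap a layer directly; a minimal sketch (the layer choice and shapes are illustrative assumptions):

>>> RematDense = nn.remat(nn.Dense)
>>> model = RematDense(features=128)
>>> variables = model.init(jax.random.key(1), jnp.ones((1, 16)))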
This function is aliased to remat just like jax.remat.

- Parameters
target – a Module or a function taking a Module as its first argument. Intermediate computations will be re-computed when computing gradients for the target.
variables – The variable collections that are lifted. By default all collections are lifted.
rngs – The PRNG sequences that are lifted. By default all PRNG sequences are lifted.
concrete – Optional, boolean indicating whether fun may involve value-dependent Python control flow (default False). Support for such control flow is optional, and disabled by default, because in some edge-case compositions with jax.jit() it can lead to some extra computation.
prevent_cse – Optional, boolean indicating whether to prevent common subexpression elimination (CSE) optimizations in the HLO generated from differentiation. This CSE prevention has costs because it can foil other optimizations, and because it can incur high overheads on some backends, especially GPU. The default is True because otherwise, under a jit or pmap, CSE can defeat the purpose of this decorator. But in some settings, like when used inside a scan, this CSE prevention mechanism is unnecessary, in which case prevent_cse should be set to False.
static_argnums – Optional, int or sequence of ints, indicating which argument values to specialize on for tracing and caching purposes. Specifying arguments as static can avoid ConcretizationTypeErrors when tracing, but at the cost of more retracing overheads.
policy – Experimental checkpoint policy, see jax.checkpoint.
methods – An optional list of method names that will be lifted. If methods is None (default), only the __call__ method will be lifted. If target is a function, methods is ignored.
- Returns
A wrapped version of target. When computing gradients, intermediate computations will be re-computed on the backward pass.
- flax.linen.remat_scan(target, lengths=(), policy=None, variable_broadcast=False, variable_carry=False, variable_axes=FrozenDict({True: 0}), split_rngs=FrozenDict({True: True}))[source]#
Combines remat and scan for memory efficiency and constant time compilation.
remat_scan allows for constant compile times and sublinear memory usage with respect to model depth, at a small constant penalty. This is typically beneficial for very deep models.

Example:
>>> import flax.linen as nn
>>> class BigModel(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     DenseStack = nn.remat_scan(nn.Dense, lengths=(10, 10))
...     # 100x dense with O(sqrt(N)) memory for gradient computation
...     return DenseStack(8, name="dense_stack")(x)
- Parameters
target – a Module or a function taking a Module as its first argument.
lengths – number of loop iterations at the given level. The total number of iterations is n = prod(lengths). Each loop is rematerialized. This way the memory consumption is proportional to n^(1 / d) where d = len(lengths). Minimal memory consumption requires tuning the lengths such that the same amount of memory is consumed at each level of the nested loop.
policy – Experimental checkpoint policy, see jax.checkpoint.
variable_broadcast – Specifies the broadcasted variable collections. A broadcasted variable should not depend on any computation that cannot be lifted out of the loop. This is typically used to define shared parameters inside the fn.
variable_carry – Specifies the variable collections that are carried through the loop. Mutations to these variables are carried to the next iteration and will be preserved when the scan finishes.
variable_axes – the variable collections that are scanned over. Defaults to {True: 0}.
split_rngs – Split PRNG sequences will be different for each loop iteration. If split is False the PRNGs will be the same across iterations. Defaults to {True: True}.
- Returns
A wrapped version of target that repeats itself prod(lengths) times.
- flax.linen.map_variables(target, mapped_collections=True, trans_in_fn=<function <lambda>>, trans_out_fn=<function <lambda>>, init=False, mutable=False, rngs=True, variables=True, methods=None)[source]#
Map Variables inside a module.
map_variables can be used to transform the variables inside a module both before and after the module is applied. This is useful, among other things, for masking the weights of a module without having to modify the module itself.

Example:
>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn
...
>>> class CausalDense(nn.Module):
...   '''A dense layer that masks the weights such that the output is
...   causal, i.e. output i only depends on input <= i.
...   '''
...   features: int
...
...   def apply_mask(self, variables):
...     return (jax.tree_util.tree_map(jnp.triu, variables)
...             if not self.is_initializing() else variables)
...
...   def setup(self):
...     # temporary class
...     _CausalDense = nn.map_variables(
...       nn.Dense, 'params', self.apply_mask, init=self.is_initializing())
...     self.dense = _CausalDense(features=self.features, use_bias=False)
...
...   def __call__(self, x):
...     return self.dense(x)
...
>>> module = CausalDense(features=5)
>>> variables = module.init(jax.random.key(0), jnp.ones((1, 5)))
- Parameters
target – the module or function to be transformed.
mapped_collections – the collection(s) to be transformed.
trans_in_fn – modifies the variables before applying the module or function.
trans_out_fn – modifies the variables after applying the module or function. It is only applied if either init or mutable is not False.
init – If True, variables are initialized before transformation.
mutable – If True, the mapped variable collections will be mutable.
rngs – PRNGSequences added to the transformed scope (default: all).
variables – Additional Variable collections added to the transformed scope, besides those specified by target (default: all).
methods – If target is a Module, the methods of Module to map variables for.
- Returns
A wrapped version of target that will map the specified collections.
- flax.linen.jvp(fn, mdl, primals, tangents, variable_tangents, variables=True, rngs=True)[source]#
A lifted version of jax.jvp. See jax.jvp for the unlifted Jacobian-vector product (forward gradient).

Note that no tangents are returned for variables. When variable tangents are required, their value should be returned explicitly by fn using Module.variables:

>>> import flax.linen as nn
>>> import jax
>>> import jax.numpy as jnp
>>> class LearnScale(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     p = self.param('test', nn.initializers.zeros_init(), ())
...     return p * x
>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     scale = LearnScale()
...     vars_t = jax.tree_util.tree_map(jnp.ones_like,
...                                     scale.variables.get('params', {}))
...     _, out_t = nn.jvp(
...         lambda mdl, x: mdl(x), scale, (x,), (jnp.zeros_like(x),),
...         variable_tangents={'params': vars_t})
...     return out_t
The equivalent example with the functional core (flax.core.lift) looks like:

>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn
>>> from flax.core import lift
>>> def learn_scale(scope, x):
...   p = scope.param('scale', nn.initializers.zeros_init(), ())
...   return p * x
>>> def f(scope, x):
...   vars_t = jax.tree_util.tree_map(jnp.ones_like,
...                                   scope.variables().get('params', {}))
...   x, out_t = lift.jvp(
...       learn_scale, scope, (x,), (jnp.zeros_like(x),),
...       variable_tangents={'params': vars_t})
...   return out_t
- Parameters
fn – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars. It will receive the scope and primals as arguments.
mdl – The module whose variables will be differentiated.
primals – The primal values at which the Jacobian of fun should be evaluated. Should be either a tuple or a list of arguments, and its length should be equal to the number of positional parameters of fun.
tangents – The tangent vector for which the Jacobian-vector product should be evaluated. Should be either a tuple or a list of tangents, with the same tree structure and array shapes as primals.
variable_tangents – A dict or PyTree of dicts with the same structure as scopes. Each entry in the dict specifies the tangents for a variable collection. Not specifying a collection in variable_tangents is equivalent to passing a zero vector as the tangent.
variables – other variable collections that are available in fn but do not receive a tangent.
rngs – the PRNG sequences that are available inside fn.
- Returns
A (primals_out, tangents_out) pair, where primals_out is fun(*primals), and tangents_out is the Jacobian-vector product of function evaluated at primals with tangents. The tangents_out value has the same Python tree structure and shapes as primals_out.
- flax.linen.vjp(fn, mdl, *primals, has_aux=False, reduce_axes=(), vjp_variables='params', variables=True, rngs=True, multi_scope=False)[source]#
A lifted version of jax.vjp. See jax.vjp for the unlifted vector-Jacobian product (backward gradient).

Note that a gradient is returned for all variables in the collections specified by vjp_variables. However, the backward function only expects a cotangent for the return value of fn. If variables require a cotangent as well, they can be returned from fn using Module.variables.

Example:
>>> import flax.linen as nn
>>> import jax.numpy as jnp
>>> class LearnScale(nn.Module):
...   @nn.compact
...   def __call__(self, x, y):
...     p = self.param('scale', nn.initializers.zeros_init(), ())
...     return p * x * y
>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x, y):
...     z, bwd = nn.vjp(lambda mdl, x, y: mdl(x, y), LearnScale(), x, y)
...     params_grad, x_grad, y_grad = bwd(jnp.ones(z.shape))
...     return z, params_grad, x_grad, y_grad
- Parameters
fn – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars. It will receive the scope and primals as arguments.
mdl – The module whose variables will be differentiated.
*primals – A sequence of primal values at which the Jacobian of fn should be evaluated. The length of primals should be equal to the number of positional parameters to fn. Each primal value should be a tuple of arrays, scalars, or standard Python containers thereof.
has_aux – Optional, bool. Indicates whether fn returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
reduce_axes – Optional, tuple of axis names. If an axis is listed here, and fn implicitly broadcasts a value over that axis, the backward pass will perform a psum of the corresponding gradient. Otherwise, the VJP will be per-example over named axes. For example, if 'batch' is a named batch axis, vjp(f, *args, reduce_axes=('batch',)) will create a VJP function that sums over the batch while vjp(f, *args) will create a per-example VJP.
vjp_variables – The vjpfun will return a cotangent vector for all variable collections specified by this filter.
variables – other variable collections that are available inside fn but do not receive a cotangent.
rngs – the PRNG sequences that are available inside fn.
multi_scope – for Modules containing multiple scopes from outside modules passed in, allow variable gradients to be returned for multiple scopes instead of raising an error.
- Returns
If has_aux is False, returns a (primals_out, vjpfun) pair, where primals_out is fn(*primals). vjpfun is a function from a cotangent vector with the same shape as primals_out to a tuple of cotangent vectors with the same shape as primals, representing the vector-Jacobian product of fn evaluated at primals. If has_aux is True, returns a (primals_out, vjpfun, aux) tuple where aux is the auxiliary data returned by fn.
- flax.linen.custom_vjp(fn, forward_fn, backward_fn, grad_vars='params', nondiff_argnums=())[source]#
Lifted version of jax.custom_vjp.

forward_fn and backward_fn together define a custom vjp for fn. The original fn will run in case a vjp (backward gradient) is not computed.

The forward_fn receives the same arguments as fn but is expected to return a tuple containing the output of fn(mdl, *args) and the residuals that are passed to backward_fn.

The backward_fn receives the nondiff arguments, residuals, and the output tangents. It should return a tuple containing the variable and input tangents.

Note that the vjp function returned by nn.vjp can be passed as residual and used in the backward_fn. The scope is unavailable during the backward pass. If the module is required in backward_fn, a snapshot of the variables can be taken and returned as a residual in the forward_fn.

Example:
>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     def f(mdl, x):
...       return mdl(x)
...
...     def fwd(mdl, x):
...       return nn.vjp(f, mdl, x)
...
...     def bwd(vjp_fn, y_t):
...       params_t, *inputs_t = vjp_fn(y_t)
...       params_t = jax.tree_util.tree_map(jnp.sign, params_t)
...       return (params_t, *inputs_t)
...
...     sign_grad = nn.custom_vjp(
...         f, forward_fn=fwd, backward_fn=bwd)
...     return sign_grad(nn.Dense(1), x).reshape(())
>>> x = jnp.ones((2,))
>>> variables = Foo().init(jax.random.key(0), x)
>>> grad = jax.grad(Foo().apply)(variables, x)
- Parameters
fn – The function to define a custom_vjp for.
forward_fn – A function with the same arguments as fn returning a tuple with the original output and the residuals that will be passed to backward_fn.
backward_fn – arguments are passed as (*nondiff_args, residuals, tangents). The function should return a tuple containing the tangents for the variables in the collections specified by grad_vars and the input arguments (except the module and nondiff args).
grad_vars – The collections for which a vjp will be computed (default: “params”).
nondiff_argnums – arguments for which no vjp is computed.
- Returns
A function with the same signature as fn with the custom vjp.
- flax.linen.while_loop(cond_fn, body_fn, mdl, init, carry_variables=False, broadcast_variables=True, split_rngs=FrozenDict({}))[source]#
Lifted version of jax.lax.while_loop.
The lifted scope is passed to cond_fn and body_fn. Broadcasted variables are immutable. The carry variables are mutable but cannot change shape and dtype. This also means you cannot initialize variables inside the body. Consider calling body_fn once manually before calling while_loop if variable initialization is required.

Example:
>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> class WhileLoopExample(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     def cond_fn(mdl, c):
...       return mdl.variables['state']['acc'] < 10
...     def body_fn(mdl, c):
...       acc = mdl.variable('state', 'acc', lambda: jnp.array(0))
...       acc.value += 1
...       y = nn.Dense(c.shape[-1])(c)
...       return y
...     c = x
...     if self.is_mutable_collection('params'):
...       return body_fn(self, c)
...     else:
...       return nn.while_loop(cond_fn, body_fn, self, c,
...                            carry_variables='state')
>>> k = jax.random.key(0)
>>> x = jnp.ones((2, 2))
>>> initial_vars = WhileLoopExample().init(k, x)
>>> result, state = WhileLoopExample().apply(initial_vars, x, mutable=['state'])
- Parameters
cond_fn – Should return True as long as the loop should continue.
body_fn – The body of the while loop.
mdl – The Module which should be lifted into the loop.
init – The initial state passed to the loop.
carry_variables – collections that are carried through the loop and are therefore mutable (default: none).
broadcast_variables – collections that are closed over and are therefore read-only (default: all collections).
split_rngs – Split PRNG sequences will be different for each loop iteration. If split is False the PRNGs will be the same across iterations.
- Returns
The final state after executing the while loop.
- flax.linen.cond(pred, true_fun, false_fun, mdl, *operands, variables=True, rngs=True)[source]#
Lifted version of jax.lax.cond.

The returned values from true_fun and false_fun must have the same Pytree structure, shapes, and dtypes. The variables created or updated inside the branches must also have the same structure. Note that this constraint is violated when creating variables or submodules in only one branch, because initializing variables in just one branch causes the parameter structure to differ.

Example:
>>> import flax.linen as nn
>>> class CondExample(nn.Module):
...   @nn.compact
...   def __call__(self, x, pred):
...     self.variable('state', 'true_count', lambda: 0)
...     self.variable('state', 'false_count', lambda: 0)
...     def true_fn(mdl, x):
...       mdl.variable('state', 'true_count').value += 1
...       return nn.Dense(2, name='dense')(x)
...     def false_fn(mdl, x):
...       mdl.variable('state', 'false_count').value += 1
...       return -nn.Dense(2, name='dense')(x)
...     return nn.cond(pred, true_fn, false_fn, self, x)
- Parameters
pred – determines if true_fun or false_fun is evaluated.
true_fun – The function evaluated when pred is True. The signature is (module, *operands) -> T.
false_fun – The function evaluated when pred is False. The signature is (module, *operands) -> T.
mdl – A Module target to pass.
*operands – The arguments passed to true_fun and false_fun.
variables – The variable collections passed to the conditional branches (default: all).
rngs – The PRNG sequences passed to the conditionals (default: all).
- Returns
The result of the evaluated branch (true_fun or false_fun).
- flax.linen.switch(index, branches, mdl, *operands, variables=True, rngs=True)[source]#
Lifted version of jax.lax.switch.

The returned values from branches must have the same Pytree structure, shapes, and dtypes. The variables created or updated inside the branches must also have the same structure. Note that this constraint is violated when creating variables or submodules in only one branch, because initializing variables in just one branch causes the parameter structure to differ.

Example:
>>> import flax.linen as nn
>>> class SwitchExample(nn.Module):
...   @nn.compact
...   def __call__(self, x, index):
...     self.variable('state', 'a_count', lambda: 0)
...     self.variable('state', 'b_count', lambda: 0)
...     self.variable('state', 'c_count', lambda: 0)
...     def a_fn(mdl, x):
...       mdl.variable('state', 'a_count').value += 1
...       return nn.Dense(2, name='dense')(x)
...     def b_fn(mdl, x):
...       mdl.variable('state', 'b_count').value += 1
...       return -nn.Dense(2, name='dense')(x)
...     def c_fn(mdl, x):
...       mdl.variable('state', 'c_count').value += 1
...       return nn.Dense(2, name='dense')(x)
...     return nn.switch(index, [a_fn, b_fn, c_fn], self, x)
If you want to have a different parameter structure for each branch, you should run all branches on initialization before calling switch:
>>> class MultiHeadSwitchExample(nn.Module):
...   def setup(self) -> None:
...     self.heads = [
...       nn.Sequential([nn.Dense(10), nn.Dense(7), nn.Dense(5)]),
...       nn.Sequential([nn.Dense(11), nn.Dense(5)]),
...       nn.Dense(5),
...     ]
...
...   @nn.compact
...   def __call__(self, x, index):
...     def head_fn(i):
...       return lambda mdl, x: mdl.heads[i](x)
...     branches = [head_fn(i) for i in range(len(self.heads))]
...
...     # run all branches on init
...     if self.is_mutable_collection('params'):
...       for branch in branches:
...         _ = branch(self, x)
...
...     return nn.switch(index, branches, self, x)
- Parameters
index – Integer scalar type, indicating which branch function to apply.
branches – Sequence of functions to be applied based on index. The signature of each function is (module, *operands) -> T.
mdl – A Module target to pass.
*operands – The arguments passed to the branches.
variables – The variable collections passed to the conditional branches (default: all).
rngs – The PRNG sequences passed to the conditionals (default: all).
- Returns
The result of the evaluated branch.