Module

The Flax Module system.

class flax.linen.Module

Base class for all neural network modules.

Layers and models should subclass this class.

All Flax Modules are Python 3.7 dataclasses. Since dataclasses take over __init__, you should instead override setup(), which is automatically called to initialize the module.

Modules can contain submodules, and in this way can be nested in a tree structure. Submodules can be assigned as regular attributes inside the setup() method.

You can define arbitrary “forward pass” methods on your Module subclass. While no methods are special-cased, __call__ is a popular choice because it allows you to use module instances as if they are functions:

>>> from flax import linen as nn
>>> from typing import Tuple

>>> class Module(nn.Module):
...   features: Tuple[int, ...] = (16, 4)

...   def setup(self):
...     self.dense1 = nn.Dense(self.features[0])
...     self.dense2 = nn.Dense(self.features[1])

...   def __call__(self, x):
...     return self.dense2(nn.relu(self.dense1(x)))

Optionally, for more concise module implementations where submodule definitions are co-located with their usage, you can use the compact() wrapper.

__setattr__(name, val)

Sets an attribute on this Module.

We overload setattr solely to support pythonic naming via assignment of submodules in the special setup() function:

self.submodule_name = MyModule(...)

We also support lists and other general pytrees, e.g.:

self.submodules = [MyModule0(..), MyModule1(..), ...]
Parameters
  • name – Attribute to set.

  • val – Value of the attribute.

apply(variables, *args, rngs=None, method=None, mutable=False, capture_intermediates=False, **kwargs)

Applies a module method to variables and returns output and modified variables.

Note that method should be set if one would like to call apply on a different class method than __call__. For instance, suppose a Transformer module has a method called encode; then the following calls apply on that method:

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> import numpy as np

>>> class Transformer(nn.Module):
...   def encode(self, x):
...     ...

>>> x = jnp.ones((16, 9))
>>> model = Transformer()
>>> variables = model.init(jax.random.key(0), x, method=Transformer.encode)

>>> encoded = model.apply(variables, x, method=Transformer.encode)

If a function instance is provided, the unbound function is used. For instance, the example below is equivalent to the one above:

>>> encoded = model.apply(variables, x, method=model.encode)

You can also pass the name of a callable attribute of the module as a string. For example, the previous call can be written as:

>>> encoded = model.apply(variables, x, method='encode')

Note that method can also be a function that is not defined in Transformer. In that case, the function should have at least one argument representing an instance of the Module class:

>>> def other_fn(instance, x):
...   # instance.some_module_attr(...)
...   instance.encode
...   ...

>>> model.apply(variables, x, method=other_fn)

If you pass a single PRNGKey, Flax will use it to feed the 'params' RNG stream. If you want to use a different RNG stream or need to use multiple streams, you can pass a dictionary mapping each RNG stream name to its corresponding PRNGKey to apply. If self.make_rng(name) is called on an RNG stream name that isn’t passed by the user, it will default to using the 'params' RNG stream.

Example:

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x, add_noise=False):
...     x = nn.Dense(16)(x)
...     x = nn.relu(x)
...
...     if add_noise:
...       # Add gaussian noise
...       noise_key = self.make_rng('noise')
...       x = x + jax.random.normal(noise_key, x.shape)
...
...     return nn.Dense(1)(x)

>>> x = jnp.empty((1, 7))
>>> module = Foo()
>>> rngs = {'params': jax.random.key(0), 'noise': jax.random.key(1)}
>>> variables = module.init(rngs, x)
>>> out0 = module.apply(variables, x, add_noise=True, rngs=rngs)

>>> rngs['noise'] = jax.random.key(0)
>>> out1 = module.apply(variables, x, add_noise=True, rngs=rngs)
>>> # different output (key(1) vs key(0))
>>> np.testing.assert_raises(AssertionError, np.testing.assert_allclose, out0, out1)

>>> del rngs['noise']
>>> # self.make_rng('noise') will default to using the 'params' RNG stream
>>> out2 = module.apply(variables, x, add_noise=True, rngs=rngs)
>>> # same output (key(0))
>>> np.testing.assert_allclose(out1, out2)

>>> # passing in a single key is equivalent to passing in {'params': key}
>>> out3 = module.apply(variables, x, add_noise=True, rngs=jax.random.key(0))
>>> # same output (key(0))
>>> np.testing.assert_allclose(out2, out3)
Parameters
  • variables – A dictionary containing variables keyed by variable collections. See flax.core.variables for more details about variables.

  • *args – Positional arguments passed to the specified apply method.

  • rngs – a dict of PRNGKeys to initialize the PRNG sequences. The “params” PRNG sequence is used to initialize parameters.

  • method – A function to call apply on. This is generally a function in the module. If provided, applies this method. If not provided, applies the __call__ method of the module. A string can also be provided to specify a method by name.

  • mutable – Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections.

  • capture_intermediates – If True, captures intermediate return values of all Modules inside the “intermediates” collection. By default, only the return values of all __call__ methods are stored. A function can be passed to change the filter behavior. The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored.

  • **kwargs – Keyword arguments passed to the specified apply method.

Returns

If mutable is False, returns output. If any collections are mutable, returns (output, vars), where vars is a dict of the modified collections.

bind(variables, *args, rngs=None, mutable=False)

Creates an interactive Module instance by binding variables and RNGs.

bind provides an “interactive” instance of a Module directly without transforming a function with apply. This is particularly useful for debugging and interactive use cases like notebooks where a function would limit the ability to split up code into different cells.

Once the variables (and optionally RNGs) are bound to a Module it becomes a stateful object. Note that idiomatic JAX is functional and therefore an interactive instance does not mix well with vanilla JAX APIs. bind() should only be used for interactive experimentation, and in all other cases we strongly encourage users to use apply() instead.

Example:

>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn

>>> class AutoEncoder(nn.Module):
...   def setup(self):
...     self.encoder = nn.Dense(3)
...     self.decoder = nn.Dense(5)
...
...   def __call__(self, x):
...     return self.decoder(self.encoder(x))

>>> x = jnp.ones((16, 9))
>>> ae = AutoEncoder()
>>> variables = ae.init(jax.random.key(0), x)
>>> model = ae.bind(variables)
>>> z = model.encoder(x)
>>> x_reconstructed = model.decoder(z)
Parameters
  • variables – A dictionary containing variables keyed by variable collections. See flax.core.variables for more details about variables.

  • *args – Positional arguments (not used).

  • rngs – a dict of PRNGKeys to initialize the PRNG sequences.

  • mutable – Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections.

Returns

A copy of this instance with bound variables and RNGs.

copy(*, parent=<flax.linen.module._Sentinel object>, name=None, **updates)

Creates a copy of this Module, with optionally updated arguments.

Parameters
  • parent – The parent of the copy. By default the current module is taken as parent if not explicitly specified.

  • name – A new name for the copied Module, by default a new automatic name will be given.

  • **updates – Attribute updates.

Returns

A copy of this Module with the updated name, parent, and attributes.

get_variable(col, name, default=None)

Retrieves the value of a Variable.

Parameters
  • col – the variable collection.

  • name – the name of the variable.

  • default – the default value to return if the variable does not exist in this scope.

Returns

The value of the input variable, or the default value if the variable doesn’t exist in this scope.

has_rng(name)

Returns true if a PRNGSequence with name name exists.

has_variable(col, name)

Checks if a variable of given collection and name exists in this Module.

See flax.core.variables for more explanation on variables and collections.

Parameters
  • col – The variable collection name.

  • name – The name of the variable.

Returns

True if the variable exists.

init(rngs, *args, method=None, mutable=DenyList(deny='intermediates'), capture_intermediates=False, **kwargs)

Initializes a module method with variables and returns modified variables.

init takes as first argument either a single PRNGKey, or a dictionary mapping RNG stream names to their PRNGKeys, and will call method (which is the module’s __call__ function by default) passing *args and **kwargs, and returns a dictionary of initialized variables.

Example:

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> import numpy as np

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x, train):
...     x = nn.Dense(16)(x)
...     x = nn.BatchNorm(use_running_average=not train)(x)
...     x = nn.relu(x)
...     return nn.Dense(1)(x)

>>> x = jnp.empty((1, 7))
>>> module = Foo()
>>> key = jax.random.key(0)
>>> variables = module.init(key, x, train=False)

If you pass a single PRNGKey, Flax will use it to feed the 'params' RNG stream. If you want to use a different RNG stream or need to use multiple streams, you can pass a dictionary mapping each RNG stream name to its corresponding PRNGKey to init. If self.make_rng(name) is called on an RNG stream name that isn’t passed by the user, it will default to using the 'params' RNG stream.

Example:

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(16)(x)
...     x = nn.relu(x)
...
...     other_variable = self.variable(
...       'other_collection',
...       'other_variable',
...       lambda x: jax.random.normal(self.make_rng('other_rng'), x.shape),
...       x,
...     )
...     x = x + other_variable.value
...
...     return nn.Dense(1)(x)

>>> module = Foo()
>>> rngs = {'params': jax.random.key(0), 'other_rng': jax.random.key(1)}
>>> variables0 = module.init(rngs, x)

>>> rngs['other_rng'] = jax.random.key(0)
>>> variables1 = module.init(rngs, x)
>>> # equivalent params (key(0))
>>> _ = jax.tree_util.tree_map(
...   np.testing.assert_allclose, variables0['params'], variables1['params']
... )
>>> # different other_variable (key(1) vs key(0))
>>> np.testing.assert_raises(
...   AssertionError,
...   np.testing.assert_allclose,
...   variables0['other_collection']['other_variable'],
...   variables1['other_collection']['other_variable'],
... )

>>> del rngs['other_rng']
>>> # self.make_rng('other_rng') will default to using the 'params' RNG stream
>>> variables2 = module.init(rngs, x)
>>> # equivalent params (key(0))
>>> _ = jax.tree_util.tree_map(
...   np.testing.assert_allclose, variables1['params'], variables2['params']
... )
>>> # equivalent other_variable (key(0))
>>> np.testing.assert_allclose(
...   variables1['other_collection']['other_variable'],
...   variables2['other_collection']['other_variable'],
... )

>>> # passing in a single key is equivalent to passing in {'params': key}
>>> variables3 = module.init(jax.random.key(0), x)
>>> # equivalent params (key(0))
>>> _ = jax.tree_util.tree_map(
...   np.testing.assert_allclose, variables2['params'], variables3['params']
... )
>>> # equivalent other_variable (key(0))
>>> np.testing.assert_allclose(
...   variables2['other_collection']['other_variable'],
...   variables3['other_collection']['other_variable'],
... )

Jitting init initializes a model lazily using only the shapes of the provided arguments, and avoids computing the forward pass with actual values. Example:

>>> module = nn.Dense(1)
>>> init_jit = jax.jit(module.init)
>>> variables = init_jit(jax.random.key(0), x)

init is a light wrapper over apply, so other apply arguments like method, mutable, and capture_intermediates are also available.

Parameters
  • rngs – The rngs for the variable collections.

  • *args – Positional arguments passed to the init function.

  • method – An optional method. If provided, applies this method. If not provided, applies the __call__ method. A string can also be provided to specify a method by name.

  • mutable – Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. By default all collections except “intermediates” are mutable.

  • capture_intermediates – If True, captures intermediate return values of all Modules inside the “intermediates” collection. By default only the return values of all __call__ methods are stored. A function can be passed to change the filter behavior. The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored.

  • **kwargs – Keyword arguments passed to the init function.

Returns

The initialized variable dict.

init_with_output(rngs, *args, method=None, mutable=DenyList(deny='intermediates'), capture_intermediates=False, **kwargs)

Initializes a module method with variables and returns output and modified variables.

Parameters
  • rngs – The rngs for the variable collections.

  • *args – Positional arguments passed to the init function.

  • method – An optional method. If provided, applies this method. If not provided, applies the __call__ method. A string can also be provided to specify a method by name.

  • mutable – Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. By default, all collections except “intermediates” are mutable.

  • capture_intermediates – If True, captures intermediate return values of all Modules inside the “intermediates” collection. By default only the return values of all __call__ methods are stored. A function can be passed to change the filter behavior. The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored.

  • **kwargs – Keyword arguments passed to the init function.

Returns

(output, vars), where vars is a dict of the modified collections.

is_initializing()

Returns True if running under self.init(…) or nn.init(…)().

This is a helper method to handle the common case of simple initialization where we wish to have setup logic occur when only called under module.init or nn.init. For more complicated multi-phase initialization scenarios it is better to test for the mutability of particular variable collections or for the presence of particular variables that potentially need to be initialized.

is_mutable_collection(col)

Returns true if the collection col is mutable.

lazy_init(rngs, *args, method=None, mutable=DenyList(deny='intermediates'), **kwargs)

Initializes a module without computing on an actual input.

lazy_init will initialize the variables without doing unnecessary compute. The input data should be passed as a jax.ShapeDtypeStruct which specifies the shape and dtype of the input but no concrete data.

Example:

>>> model = nn.Dense(features=256)
>>> variables = model.lazy_init(
...     jax.random.key(0), jax.ShapeDtypeStruct((1, 128), jnp.float32))

The args and kwargs passed to lazy_init can be a mix of concrete (jax arrays, scalars, bools) and abstract (ShapeDtypeStruct) values. Concrete values are only necessary for arguments that affect the initialization of variables. For example, the model might expect a keyword arg that enables/disables a subpart of the model. In this case, an explicit value (True/False) should be passed; otherwise lazy_init cannot infer which variables should be initialized.

Parameters
  • rngs – The rngs for the variable collections.

  • *args – Positional arguments passed to the init function.

  • method – An optional method. If provided, applies this method. If not provided, applies the __call__ method.

  • mutable – Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. By default all collections except “intermediates” are mutable.

  • **kwargs – Keyword arguments passed to the init function.

Returns

The initialized variable dict.

make_rng(name='params')

Returns a new RNG key from a given RNG sequence for this Module.

The new RNG key is split from the previous one. Thus, every call to make_rng returns a new RNG key, while still guaranteeing full reproducibility.

Note

If an invalid name is passed (i.e. no RNG key was passed by the user in .init or .apply for this name), then name will default to 'params'.

Example:

>>> import jax
>>> import flax.linen as nn

>>> class ParamsModule(nn.Module):
...   def __call__(self):
...     return self.make_rng('params')
>>> class OtherModule(nn.Module):
...   def __call__(self):
...     return self.make_rng('other')

>>> key = jax.random.key(0)
>>> params_out, _ = ParamsModule().init_with_output({'params': key})
>>> # self.make_rng('other') will default to using the 'params' RNG stream
>>> other_out, _ = OtherModule().init_with_output({'params': key})
>>> assert params_out == other_out

Learn more about RNGs by reading the Flax RNG guide: https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/rng_guide.html

Parameters

name – The RNG sequence name.

Returns

The newly generated RNG key.

module_paths(rngs, *args, show_repeated=False, mutable=DenyList(deny='intermediates'), **kwargs)

Returns a dictionary mapping module paths to module instances.

This method has the same signature and internally calls Module.init, but instead of returning the variables, it returns a dictionary mapping module paths to unbound copies of module instances that were used at runtime. module_paths uses jax.eval_shape to run the forward computation without consuming any FLOPs or allocating memory.

Example:

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     h = nn.Dense(4)(x)
...     return nn.Dense(2)(h)

>>> x = jnp.ones((16, 9))
>>> modules = Foo().module_paths(jax.random.key(0), x)
>>> print({
...     p: type(m).__name__ for p, m in modules.items()
... })
{'': 'Foo', 'Dense_0': 'Dense', 'Dense_1': 'Dense'}
Parameters
  • rngs – The rngs for the variable collections as passed to Module.init.

  • *args – The arguments to the forward computation.

  • show_repeated – If True, repeated calls to the same module will be shown in the table, otherwise only the first call will be shown. Default is False.

  • mutable – Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. By default, all collections except ‘intermediates’ are mutable.

  • **kwargs – keyword arguments to pass to the forward computation.

Returns

A dictionary mapping module paths to module instances.

param(name, init_fn, *init_args, unbox=True, **init_kwargs)

Declares and returns a parameter in this Module.

Parameters are read-only variables in the collection named “params”. See flax.core.variables for more details on variables.

The first argument of init_fn is assumed to be a PRNG key, which is provided automatically and does not have to be passed using init_args or init_kwargs:

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(4)(x)
...     mean = self.param('mean', nn.initializers.lecun_normal(), x.shape)
...     ...
...     return x * mean
>>> variables = Foo().init({'params': jax.random.key(0), 'stats': jax.random.key(1)}, jnp.ones((2, 3)))
>>> jax.tree_util.tree_map(jnp.shape, variables)
{'params': {'Dense_0': {'bias': (4,), 'kernel': (3, 4)}, 'mean': (2, 4)}}

In the example above, the function lecun_normal expects two arguments: key and shape, but only shape has to be provided explicitly; key is set automatically using the PRNG for params that is passed when initializing the module using init().

Parameters
  • name – The parameter name.

  • init_fn – The function that will be called to compute the initial value of this variable. This function will only be called the first time this parameter is used in this module.

  • *init_args – The positional arguments to pass to init_fn.

  • unbox – If True, AxisMetadata instances are replaced by their unboxed value, see flax.nn.meta.unbox (default: True).

  • **init_kwargs – The keyword arguments to pass to init_fn.

Returns

The value of the initialized parameter. Throws an error if the parameter exists already.

property path

Get the path of this Module. Top-level root modules have an empty path (). Note that this method can only be used on bound modules that have a valid scope.

Example usage:

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> class SubModel(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     print(f'SubModel path: {self.path}')
...     return x

>>> class Model(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     print(f'Model path: {self.path}')
...     return SubModel()(x)

>>> model = Model()
>>> variables = model.init(jax.random.key(0), jnp.ones((1, 2)))
Model path: ()
SubModel path: ('SubModel_0',)

perturb(name, value, collection='perturbations')

Add a zero-value variable (‘perturbation’) to the intermediate value.

The gradient of value would be the same as the gradient of this perturbation variable. Therefore, if you define your loss function with both params and perturbations as standalone arguments, you can get the intermediate gradients of value by running jax.grad on the perturbation argument.

Note

This is an experimental API and may be tweaked later for better performance and usability. At its current stage, it creates extra dummy variables that occupy extra memory space. Use it only to debug gradients in training.

Example:

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(3)(x)
...     x = self.perturb('dense3', x)
...     return nn.Dense(2)(x)

>>> def loss(variables, inputs, targets):
...   preds = model.apply(variables, inputs)
...   return jnp.square(preds - targets).mean()

>>> x = jnp.ones((2, 9))
>>> y = jnp.ones((2, 2))
>>> model = Foo()
>>> variables = model.init(jax.random.key(0), x)
>>> intm_grads = jax.grad(loss, argnums=0)(variables, x, y)
>>> print(intm_grads['perturbations']['dense3'])
[[-1.456924   -0.44332537  0.02422847]
 [-1.456924   -0.44332537  0.02422847]]

If perturbations are not passed to apply, perturb behaves like a no-op so you can easily disable the behavior when not needed:

>>> model.apply(variables, x) # works as expected
Array([[-1.0980128 , -0.67961735],
       [-1.0980128 , -0.67961735]], dtype=float32)
>>> model.apply({'params': variables['params']}, x) # behaves like a no-op
Array([[-1.0980128 , -0.67961735],
       [-1.0980128 , -0.67961735]], dtype=float32)
>>> intm_grads = jax.grad(loss, argnums=0)({'params': variables['params']}, x, y)
>>> 'perturbations' not in intm_grads
True

put_variable(col, name, value)

Updates the value of the given variable if it is mutable, or raises an error otherwise.

Parameters
  • col – the variable collection.

  • name – the name of the variable.

  • value – the new value of the variable.

setup()

Initializes a Module lazily (similar to a lazy __init__).

setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed.

This can happen in three cases:

  1. Immediately when invoking apply(), init() or init_with_output().

  2. Once the module is given a name by being assigned to an attribute of another module inside the other module’s setup method (see __setattr__()):

    >>> class MyModule(nn.Module):
    ...   def setup(self):
    ...     submodule = nn.Conv(...)
    
    ...     # Accessing `submodule` attributes does not yet work here.
    
    ...     # The following line invokes `self.__setattr__`, which gives
    ...     # `submodule` the name "conv1".
    ...     self.conv1 = submodule
    
    ...     # Accessing `submodule` attributes or methods is now safe and
    ...     # causes setup() to be called once.
    
  3. Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or a setup-defined attribute is accessed.

sow(col, name, value, reduce_fn=<function <lambda>>, init_fn=<function <lambda>>)

Stores a value in a collection.

Collections can be used to collect intermediate values without the overhead of explicitly passing a container through each Module call.

If the target collection is not mutable, sow behaves like a no-op and returns False.

Example:

>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     h = nn.Dense(4)(x)
...     self.sow('intermediates', 'h', h)
...     return nn.Dense(2)(h)

>>> x = jnp.ones((16, 9))
>>> model = Foo()
>>> variables = model.init(jax.random.key(0), x)
>>> y, state = model.apply(variables, x, mutable=['intermediates'])
>>> jax.tree.map(jnp.shape, state['intermediates'])
{'h': ((16, 4),)}

By default the values are stored in a tuple and each stored value is appended at the end. This way all intermediates can be tracked when the same module is called multiple times. Alternatively, a custom init/reduce function can be passed:

>>> class Foo2(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     init_fn = lambda: 0
...     reduce_fn = lambda a, b: a + b
...     self.sow('intermediates', 'h', x,
...               init_fn=init_fn, reduce_fn=reduce_fn)
...     self.sow('intermediates', 'h', x * 2,
...               init_fn=init_fn, reduce_fn=reduce_fn)
...     return x

>>> x = jnp.ones((1, 1))
>>> model = Foo2()
>>> variables = model.init(jax.random.key(0), x)
>>> y, state = model.apply(
...     variables, x, mutable=['intermediates'])
>>> print(state['intermediates'])
{'h': Array([[3.]], dtype=float32)}
Parameters
  • col – The name of the variable collection.

  • name – The name of the variable.

  • value – The value of the variable.

  • reduce_fn – The function used to combine the existing value with the new value. The default is to append the value to a tuple.

  • init_fn – For the first value stored, reduce_fn will be passed the result of init_fn together with the value to be stored. The default is an empty tuple.

Returns

True if the value has been stored successfully, False otherwise.

tabulate(rngs, *args, depth=None, show_repeated=False, mutable=DenyList(deny='intermediates'), console_kwargs=None, table_kwargs=mappingproxy({}), column_kwargs=mappingproxy({}), compute_flops=False, compute_vjp_flops=False, **kwargs)

Creates a summary of the Module represented as a table.

This method has the same signature and internally calls Module.init, but instead of returning the variables, it returns the string summarizing the Module in a table. tabulate uses jax.eval_shape to run the forward computation without consuming any FLOPs or allocating memory.

Additional arguments can be passed into the console_kwargs argument, for example, {'width': 120}. For a full list of console_kwargs arguments, see: https://rich.readthedocs.io/en/stable/reference/console.html#rich.console.Console

Example:

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     h = nn.Dense(4)(x)
...     return nn.Dense(2)(h)

>>> x = jnp.ones((16, 9))

>>> # print(Foo().tabulate(
>>> #     jax.random.key(0), x, compute_flops=True, compute_vjp_flops=True))

This gives the following output:

                                      Foo Summary
┏━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓
┃ path    ┃ module ┃ inputs        ┃ outputs       ┃ flops ┃ vjp_flops ┃ params          ┃
┡━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩
│         │ Foo    │ float32[16,9] │ float32[16,2] │ 1504  │ 4460      │                 │
├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼─────────────────┤
│ Dense_0 │ Dense  │ float32[16,9] │ float32[16,4] │ 1216  │ 3620      │ bias:           │
│         │        │               │               │       │           │ float32[4]      │
│         │        │               │               │       │           │ kernel:         │
│         │        │               │               │       │           │ float32[9,4]    │
│         │        │               │               │       │           │                 │
│         │        │               │               │       │           │ 40 (160 B)      │
├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼─────────────────┤
│ Dense_1 │ Dense  │ float32[16,4] │ float32[16,2] │ 288   │ 840       │ bias:           │
│         │        │               │               │       │           │ float32[2]      │
│         │        │               │               │       │           │ kernel:         │
│         │        │               │               │       │           │ float32[4,2]    │
│         │        │               │               │       │           │                 │
│         │        │               │               │       │           │ 10 (40 B)       │
├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼─────────────────┤
│         │        │               │               │       │     Total │ 50 (200 B)      │
└─────────┴────────┴───────────────┴───────────────┴───────┴───────────┴─────────────────┘

                              Total Parameters: 50 (200 B)

Note: row order in the table does not represent execution order; instead, it follows the order of keys in variables, which are sorted alphabetically.

Note: vjp_flops returns 0 if the module is not differentiable.

Parameters
  • rngs – The rngs for the variable collections as passed to Module.init.

  • *args – The arguments to the forward computation.

  • depth – Controls how many submodule levels deep the summary can go. By default, it is None, which means no limit. If a submodule is not shown because of the depth limit, its parameter count and bytes will be added to the row of its first shown ancestor such that the sum of all rows always adds up to the total number of parameters of the Module.

  • show_repeated – If True, repeated calls to the same module will be shown in the table, otherwise only the first call will be shown. Default is False.

  • mutable – Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. By default, all collections except ‘intermediates’ are mutable.

  • console_kwargs – An optional dictionary with additional keyword arguments that are passed to rich.console.Console when rendering the table. Default arguments are {'force_terminal': True, 'force_jupyter': False}.

  • table_kwargs – An optional dictionary with additional keyword arguments that are passed to the rich.table.Table constructor.

  • column_kwargs – An optional dictionary with additional keyword arguments that are passed to rich.table.Table.add_column when adding columns to the table.

  • compute_flops – whether to include a flops column in the table listing the estimated FLOPs cost of each module forward pass. Does not incur actual on-device computation / compilation / memory allocation, but still introduces overhead for large modules (e.g. an extra 20 seconds for Stable Diffusion’s UNet, whereas tabulation would otherwise finish in 5 seconds).

  • compute_vjp_flops – whether to include a vjp_flops column in the table listing the estimated FLOPs cost of each module backward pass. Introduces a compute overhead of about 2-3X of compute_flops.

  • **kwargs – keyword arguments to pass to the forward computation.

Returns

A string summarizing the Module.

unbind()[source]#

Returns an unbound copy of a Module and its variables.

unbind helps create a stateless version of a bound Module.

A common use case is extracting a sub-Module defined inside setup(), together with its variables. To do so, 1) temporarily bind the parent Module, and then 2) unbind the desired sub-Module (recall that setup() is only called when the Module is bound):

>>> class Encoder(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     ...
...     return nn.Dense(256)(x)

>>> class Decoder(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     ...
...     return nn.Dense(784)(x)

>>> class AutoEncoder(nn.Module):
...   def setup(self):
...     self.encoder = Encoder()
...     self.decoder = Decoder()
...
...   def __call__(self, x):
...     return self.decoder(self.encoder(x))

>>> module = AutoEncoder()
>>> variables = module.init(jax.random.key(0), jnp.ones((1, 784)))

>>> # Extract the Encoder sub-Module and its variables
>>> encoder, encoder_vars = module.bind(variables).encoder.unbind()

Returns

A tuple with an unbound copy of this Module and its variables.

variable(col, name, init_fn=None, *init_args, unbox=True, **init_kwargs)[source]#

Declares and returns a variable in this Module.

See flax.core.variables for more information. See also param() for a shorthand way to define read-only variables in the “params” collection.

Unlike param(), all arguments to init_fn must be passed explicitly:

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(4)(x)
...     key = self.make_rng('stats')
...     mean = self.variable('stats', 'mean', nn.initializers.lecun_normal(), key, x.shape)
...     ...
...     return x * mean.value
>>> variables = Foo().init({'params': jax.random.key(0), 'stats': jax.random.key(1)}, jnp.ones((2, 3)))
>>> jax.tree_util.tree_map(jnp.shape, variables)
{'params': {'Dense_0': {'bias': (4,), 'kernel': (3, 4)}}, 'stats': {'mean': (2, 4)}}

In the example above, the function lecun_normal expects two arguments: key and shape, and both have to be passed on. The PRNG for stats has to be provided explicitly when calling init() and apply().

Parameters
  • col – The variable collection name.

  • name – The variable name.

  • init_fn – The function that will be called to compute the initial value of this variable. This function will only be called the first time this variable is used in this module. If None, the variable must already be initialized; otherwise an error is raised.

  • *init_args – The positional arguments to pass to init_fn.

  • unbox – If True, AxisMetadata instances are replaced by their unboxed value, see flax.linen.meta.unbox (default: True).

  • **init_kwargs – The keyword arguments to pass to init_fn.

Returns

A flax.core.variables.Variable that can be read or set via its .value attribute. Throws an error if the variable already exists.

property variables#

Returns the variables in this module.

flax.linen.apply(fn, module, mutable=False, capture_intermediates=False)[source]#

Creates an apply function to call fn with a bound module.

Unlike Module.apply, this function returns a new function with the signature (variables, *args, rngs=None, **kwargs) -> T, where T is the return type of fn. If mutable is not False, the return type is a tuple whose second item is a FrozenDict with the mutated variables.

The apply function that is returned can be directly composed with JAX transformations like jax.jit:

>>> class Foo(nn.Module):
...   def encode(self, x):
...     ...
...   def decode(self, x):
...     ...

>>> def f(foo, x):
...   z = foo.encode(x)
...   y = foo.decode(z)
...   # ...
...   return y

>>> variables = {}
>>> foo = Foo()
>>> f_jitted = jax.jit(nn.apply(f, foo))
>>> f_jitted(variables, jnp.ones((1, 3)))

Parameters
  • fn – The function that should be applied. The first argument passed will be a module instance of the module with variables and RNGs bound to it.

  • module – The Module that will be used to bind variables and RNGs to. The Module passed as the first argument to fn will be a clone of module.

  • mutable – Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections.

  • capture_intermediates – If True, captures intermediate return values of all Modules inside the “intermediates” collection. By default, only the return values of all __call__ methods are stored. A function can be passed to change the filter behavior. The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored.

Returns

The apply function wrapping fn.

flax.linen.init(fn, module, mutable=DenyList(deny='intermediates'), capture_intermediates=False)[source]#

Creates an init function to call fn with a bound module.

Unlike Module.init, this function returns a new function with the signature (rngs, *args, **kwargs) -> variables. The rngs can be a dict of PRNGKeys or a single PRNGKey, which is equivalent to passing a dict with one PRNGKey with the name “params”.

The init function that is returned can be directly composed with JAX transformations like jax.jit:

>>> class Foo(nn.Module):
...   def encode(self, x):
...     ...
...   def decode(self, x):
...     ...

>>> def f(foo, x):
...   z = foo.encode(x)
...   y = foo.decode(z)
...   # ...
...   return y

>>> foo = Foo()
>>> f_jitted = jax.jit(nn.init(f, foo))
>>> variables = f_jitted(jax.random.key(0), jnp.ones((1, 3)))

Parameters
  • fn – The function that should be applied. The first argument passed will be a module instance of the module with variables and RNGs bound to it.

  • module – The Module that will be used to bind variables and RNGs to. The Module passed as the first argument to fn will be a clone of module.

  • mutable – Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. By default, all collections except “intermediates” are mutable.

  • capture_intermediates – If True, captures intermediate return values of all Modules inside the “intermediates” collection. By default, only the return values of all __call__ methods are stored. A function can be passed to change the filter behavior. The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored.

Returns

The init function wrapping fn.

flax.linen.init_with_output(fn, module, mutable=DenyList(deny='intermediates'), capture_intermediates=False)[source]#

Creates an init function to call fn with a bound module that also returns the function outputs.

Unlike Module.init_with_output, this function returns a new function with the signature (rngs, *args, **kwargs) -> (T, variables), where T is the return type of fn. The rngs can be a dict of PRNGKeys or a single PRNGKey, which is equivalent to passing a dict with one PRNGKey with the name “params”.

The init function that is returned can be directly composed with JAX transformations like jax.jit:

>>> class Foo(nn.Module):
...   def encode(self, x):
...     ...
...   def decode(self, x):
...     ...

>>> def f(foo, x):
...   z = foo.encode(x)
...   y = foo.decode(z)
...   # ...
...   return y

>>> foo = Foo()
>>> f_jitted = jax.jit(nn.init_with_output(f, foo))
>>> y, variables = f_jitted(jax.random.key(0), jnp.ones((1, 3)))

Parameters
  • fn – The function that should be applied. The first argument passed will be a module instance of the module with variables and RNGs bound to it.

  • module – The Module that will be used to bind variables and RNGs to. The Module passed as the first argument to fn will be a clone of module.

  • mutable – Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. By default, all collections except “intermediates” are mutable.

  • capture_intermediates – If True, captures intermediate return values of all Modules inside the “intermediates” collection. By default, only the return values of all __call__ methods are stored. A function can be passed to change the filter behavior. The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored.

Returns

The init function wrapping fn.

flax.linen.intercept_methods(interceptor)[source]#

Registers a new method interceptor.

Method interceptors allow you to intercept method calls to modules at a distance. They work similarly to decorators: you can modify args/kwargs before calling the underlying method, modify the result it returns, or skip the underlying method entirely and do something different. For example:

>>> import flax.linen as nn
>>> import jax.numpy as jnp
...
>>> class Foo(nn.Module):
...   def __call__(self, x):
...     return x
...
>>> def my_interceptor1(next_fun, args, kwargs, context):
...   print('calling my_interceptor1')
...   return next_fun(*args, **kwargs)
...
>>> foo = Foo()
>>> with nn.intercept_methods(my_interceptor1):
...   _ = foo(jnp.ones([1]))
calling my_interceptor1

You could also register multiple interceptors on the same method. Interceptors will run in order. For example:

>>> def my_interceptor2(next_fun, args, kwargs, context):
...   print('calling my_interceptor2')
...   return next_fun(*args, **kwargs)
...
>>> with nn.intercept_methods(my_interceptor1), \
...      nn.intercept_methods(my_interceptor2):
...   _ = foo(jnp.ones([1]))
calling my_interceptor1
calling my_interceptor2

You could skip the remaining interceptors by calling context.orig_method directly. For example:

>>> def my_interceptor3(next_fun, args, kwargs, context):
...   print('calling my_interceptor3')
...   return context.orig_method(*args, **kwargs)
>>> with nn.intercept_methods(my_interceptor3), \
...      nn.intercept_methods(my_interceptor1), \
...      nn.intercept_methods(my_interceptor2):
...   _ = foo(jnp.ones([1]))
calling my_interceptor3

The following methods cannot be intercepted:

  1. Methods decorated with nn.nowrap.

  2. Dunder methods including __eq__, __repr__, __init__, __hash__, and __post_init__.

  3. Module dataclass fields.

  4. Module descriptors.

Parameters

interceptor – A method interceptor.

flax.linen.share_scope(module, other, /)[source]#

Modifies one of the Modules such that they share the same scope. This is useful when you want to wrap a Module and extend its functionality without changing the parameter structure.

share_scope takes two Modules, module and other. module will use other’s scope if other has a scope and it is not a descendant of module’s scope:

>>> import flax.linen as nn
>>> import jax
>>> from jax import numpy as jnp, random
...
>>> class DenseLoRA(nn.Module):
...   base: nn.Dense
...   rank: int
...
...   def setup(self):
...     nn.share_scope(self, self.base)
...
...   @nn.compact
...   def __call__(self, x: jax.Array):
...     din, dout = x.shape[-1], self.base.features
...     A = self.param('A', nn.zeros_init(), (din, self.rank))
...     B = self.param('B', nn.zeros_init(), (self.rank, dout))
...     return self.base(x) + x @ A @ B
...
>>> class Model(nn.Module):
...   @nn.compact
...   def __call__(self, x: jax.Array):
...     dense = nn.Dense(10) # base scope
...     return DenseLoRA(dense, rank=2)(x) # reuse the base scope
...
>>> model = Model()
...
>>> params = model.init(random.key(0), jnp.ones((1, 5)))['params']
>>> list(params['Dense_0'].keys())
['A', 'B', 'kernel', 'bias']

When other’s scope is a descendant of module’s scope then other will use module’s scope instead:

>>> class DenseLoRA(nn.Module):
...   features: int
...   rank: int
...
...   def setup(self):
...     self.child = nn.Dense(self.features)
...     nn.share_scope(self, self.child)
...
...   @nn.compact
...   def __call__(self, x: jax.Array):
...     din, dout = x.shape[-1], self.features
...     A = self.param('A', nn.zeros_init(), (din, self.rank))
...     B = self.param('B', nn.zeros_init(), (self.rank, dout))
...     return self.child(x) + x @ A @ B
...
>>> class Model(nn.Module):
...   @nn.compact
...   def __call__(self, x: jax.Array):
...     return DenseLoRA(10, rank=2)(x)
...
>>> model = Model()
...
>>> params = model.init(random.key(0), jnp.ones((1, 5)))['params']
>>> list(params['DenseLoRA_0'].keys())
['A', 'B', 'kernel', 'bias']