Module#
The Flax Module system.
- class flax.linen.Module[source]#
Base class for all neural network modules.
Layers and models should subclass this class.
All Flax Modules are Python 3.7 dataclasses. Since dataclasses take over __init__, you should instead override setup(), which is automatically called to initialize the module.

Modules can contain submodules, and in this way can be nested in a tree structure. Submodules can be assigned as regular attributes inside the setup() method.

You can define arbitrary "forward pass" methods on your Module subclass. While no methods are special-cased, __call__ is a popular choice because it allows you to use module instances as if they are functions:

>>> from flax import linen as nn
>>> from typing import Tuple

>>> class Module(nn.Module):
...   features: Tuple[int, ...] = (16, 4)
...   def setup(self):
...     self.dense1 = nn.Dense(self.features[0])
...     self.dense2 = nn.Dense(self.features[1])
...   def __call__(self, x):
...     return self.dense2(nn.relu(self.dense1(x)))

Optionally, for more concise module implementations where submodule definitions are co-located with their usage, you can use the compact() wrapper.

- __setattr__(name, val)[source]#
Sets an attribute on this Module.
We overload setattr solely to support pythonic naming via assignment of submodules in the special setup() function:

self.submodule_name = MyModule(...)
We also support lists and other general pytrees, e.g.:
self.submodules = [MyModule0(..), MyModule1(..), ...]
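For instance, a minimal sketch of a setup() that assigns both a single submodule and a list of submodules (the class and attribute names here are purely illustrative):

>>> import flax.linen as nn
>>> class MLP(nn.Module):
...   def setup(self):
...     # Assignment in setup() triggers __setattr__, which names the submodule.
...     self.head = nn.Dense(1)
...     # A list of submodules (a pytree) is also supported.
...     self.blocks = [nn.Dense(8), nn.Dense(4)]
...   def __call__(self, x):
...     for block in self.blocks:
...       x = nn.relu(block(x))
...     return self.head(x)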
- Parameters
name – Attribute to set.
val – Value of the attribute.
- apply(variables, *args, rngs=None, method=None, mutable=False, capture_intermediates=False, **kwargs)[source]#
Applies a module method to variables and returns output and modified variables.
Note that method should be set if one would like to call apply on a different class method than __call__. For instance, suppose a Transformer module has a method called encode, then the following calls apply on that method:

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> import numpy as np

>>> class Transformer(nn.Module):
...   def encode(self, x):
...     ...

>>> x = jnp.ones((16, 9))
>>> model = Transformer()
>>> variables = model.init(jax.random.key(0), x, method=Transformer.encode)

>>> encoded = model.apply(variables, x, method=Transformer.encode)
If a function instance is provided, the unbound function is used. For instance, the example below is equivalent to the one above:
>>> encoded = model.apply(variables, x, method=model.encode)
You can also pass a string naming a callable attribute of the module. For example, the previous can be written as:
>>> encoded = model.apply(variables, x, method='encode')
Note

method can also be a function that is not defined in Transformer. In that case, the function should have at least one argument representing an instance of the Module class:

>>> def other_fn(instance, x):
...   # instance.some_module_attr(...)
...   instance.encode
...   ...

>>> model.apply(variables, x, method=other_fn)
If you pass a single PRNGKey, Flax will use it to feed the 'params' RNG stream. If you want to use a different RNG stream or need to use multiple streams, you can pass a dictionary mapping each RNG stream name to its corresponding PRNGKey to apply. If self.make_rng(name) is called on an RNG stream name that isn't passed by the user, it will default to using the 'params' RNG stream.

Example:
>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x, add_noise=False):
...     x = nn.Dense(16)(x)
...     x = nn.relu(x)
...
...     if add_noise:
...       # Add gaussian noise
...       noise_key = self.make_rng('noise')
...       x = x + jax.random.normal(noise_key, x.shape)
...
...     return nn.Dense(1)(x)

>>> x = jnp.empty((1, 7))
>>> module = Foo()
>>> rngs = {'params': jax.random.key(0), 'noise': jax.random.key(1)}
>>> variables = module.init(rngs, x)
>>> out0 = module.apply(variables, x, add_noise=True, rngs=rngs)

>>> rngs['noise'] = jax.random.key(0)
>>> out1 = module.apply(variables, x, add_noise=True, rngs=rngs)
>>> # different output (key(1) vs key(0))
>>> np.testing.assert_raises(AssertionError, np.testing.assert_allclose, out0, out1)

>>> del rngs['noise']
>>> # self.make_rng('noise') will default to using the 'params' RNG stream
>>> out2 = module.apply(variables, x, add_noise=True, rngs=rngs)
>>> # same output (key(0))
>>> np.testing.assert_allclose(out1, out2)

>>> # passing in a single key is equivalent to passing in {'params': key}
>>> out3 = module.apply(variables, x, add_noise=True, rngs=jax.random.key(0))
>>> # same output (key(0))
>>> np.testing.assert_allclose(out2, out3)
- Parameters
variables – A dictionary containing variables keyed by variable collections. See flax.core.variables for more details about variables.
*args – Positional arguments passed to the specified apply method.
rngs – a dict of PRNGKeys to initialize the PRNG sequences. The "params" PRNG sequence is used to initialize parameters.
method – A function to call apply on. This is generally a function in the module. If provided, applies this method. If not provided, applies the __call__ method of the module. A string can also be provided to specify a method by name.
mutable – Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections.
capture_intermediates – If True, captures intermediate return values of all Modules inside the "intermediates" collection. By default, only the return values of all __call__ methods are stored. A function can be passed to change the filter behavior. The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored.
**kwargs – Keyword arguments passed to the specified apply method.
- Returns
If mutable is False, returns output. If any collections are mutable, returns (output, vars), where vars is a dict of the modified collections.
- bind(variables, *args, rngs=None, mutable=False)[source]#
Creates an interactive Module instance by binding variables and RNGs.
bind provides an "interactive" instance of a Module directly without transforming a function with apply. This is particularly useful for debugging and interactive use cases like notebooks where a function would limit the ability to split up code into different cells.

Once the variables (and optionally RNGs) are bound to a Module, it becomes a stateful object. Note that idiomatic JAX is functional and therefore an interactive instance does not mix well with vanilla JAX APIs. bind() should only be used for interactive experimentation, and in all other cases we strongly encourage users to use apply() instead.

Example:
>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn

>>> class AutoEncoder(nn.Module):
...   def setup(self):
...     self.encoder = nn.Dense(3)
...     self.decoder = nn.Dense(5)
...
...   def __call__(self, x):
...     return self.decoder(self.encoder(x))

>>> x = jnp.ones((16, 9))
>>> ae = AutoEncoder()
>>> variables = ae.init(jax.random.key(0), x)
>>> model = ae.bind(variables)
>>> z = model.encoder(x)
>>> x_reconstructed = model.decoder(z)
- Parameters
variables – A dictionary containing variables keyed by variable collections. See flax.core.variables for more details about variables.
*args – Positional arguments (not used).
rngs – a dict of PRNGKeys to initialize the PRNG sequences.
mutable – Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections.
- Returns
A copy of this instance with bound variables and RNGs.
- copy(*, parent=<flax.linen.module._Sentinel object>, name=None, **updates)[source]#
Creates a copy of this Module, with optionally updated arguments.
- Parameters
parent – The parent of the copy. By default the current module is taken as parent if not explicitly specified.
name – A new name for the copied Module, by default a new automatic name will be given.
**updates – Attribute updates.
- Returns
A copy of this Module with the updated name, parent, and attributes.
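As a rough sketch of how copy() can be used (the feature count chosen here is arbitrary), a single attribute can be overridden while keeping the rest of the configuration:

>>> import flax.linen as nn
>>> dense = nn.Dense(features=16)
>>> # Copy the module, overriding only the `features` attribute.
>>> wider = dense.copy(features=32)
>>> wider.features
32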
- get_variable(col, name, default=None)[source]#
Retrieves the value of a Variable.
- Parameters
col – the variable collection.
name – the name of the variable.
default – the default value to return if the variable does not exist in this scope.
- Returns
The value of the input variable, or the default value if the variable doesn't exist in this scope.
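For illustration, a minimal sketch (the module, collection, and variable names here are made up) that reads a previously stored value, falling back to the default on the first call:

>>> import jax, jax.numpy as jnp
>>> import flax.linen as nn
>>> class Cache(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     # Returns the stored value, or the default if it doesn't exist yet.
...     prev = self.get_variable('cache', 'last_input', default=jnp.zeros_like(x))
...     self.put_variable('cache', 'last_input', x)
...     return x + prev
>>> x = jnp.ones((1, 3))
>>> model = Cache()
>>> variables = model.init(jax.random.key(0), x)
>>> y, state = model.apply(variables, x, mutable=['cache'])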
- has_variable(col, name)[source]#
Checks if a variable of given collection and name exists in this Module.
See flax.core.variables for more explanation on variables and collections.

- Parameters
col – The variable collection name.
name – The name of the variable.
- Returns
True if the variable exists.
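A common pattern is to use has_variable to distinguish the first (initialization) call from later calls, before a stateful variable is mutated. A minimal counter sketch (the collection and variable names here are illustrative):

>>> import jax, jax.numpy as jnp
>>> import flax.linen as nn
>>> class Counter(nn.Module):
...   @nn.compact
...   def __call__(self):
...     # False during init (the variable doesn't exist yet), True afterwards.
...     initialized = self.has_variable('counter', 'count')
...     count = self.variable('counter', 'count', lambda: jnp.zeros((), jnp.int32))
...     if initialized:
...       count.value += 1
...     return count.value
>>> variables = Counter().init(jax.random.key(0))
>>> y, state = Counter().apply(variables, mutable=['counter'])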
- init(rngs, *args, method=None, mutable=DenyList(deny='intermediates'), capture_intermediates=False, **kwargs)[source]#
Initializes a module method with variables and returns modified variables.
init takes as first argument either a single PRNGKey, or a dictionary mapping variable collection names to their PRNGKeys, and will call method (which is the module's __call__ function by default) passing *args and **kwargs, and returns a dictionary of initialized variables.

Example:
>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> import numpy as np

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x, train):
...     x = nn.Dense(16)(x)
...     x = nn.BatchNorm(use_running_average=not train)(x)
...     x = nn.relu(x)
...     return nn.Dense(1)(x)

>>> x = jnp.empty((1, 7))
>>> module = Foo()
>>> key = jax.random.key(0)
>>> variables = module.init(key, x, train=False)
If you pass a single PRNGKey, Flax will use it to feed the 'params' RNG stream. If you want to use a different RNG stream or need to use multiple streams, you can pass a dictionary mapping each RNG stream name to its corresponding PRNGKey to init. If self.make_rng(name) is called on an RNG stream name that isn't passed by the user, it will default to using the 'params' RNG stream.

Example:
>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(16)(x)
...     x = nn.relu(x)
...
...     other_variable = self.variable(
...       'other_collection',
...       'other_variable',
...       lambda x: jax.random.normal(self.make_rng('other_rng'), x.shape),
...       x,
...     )
...     x = x + other_variable.value
...
...     return nn.Dense(1)(x)

>>> module = Foo()
>>> rngs = {'params': jax.random.key(0), 'other_rng': jax.random.key(1)}
>>> variables0 = module.init(rngs, x)

>>> rngs['other_rng'] = jax.random.key(0)
>>> variables1 = module.init(rngs, x)
>>> # equivalent params (key(0))
>>> _ = jax.tree_util.tree_map(
...     np.testing.assert_allclose, variables0['params'], variables1['params']
... )
>>> # different other_variable (key(1) vs key(0))
>>> np.testing.assert_raises(
...     AssertionError,
...     np.testing.assert_allclose,
...     variables0['other_collection']['other_variable'],
...     variables1['other_collection']['other_variable'],
... )

>>> del rngs['other_rng']
>>> # self.make_rng('other_rng') will default to using the 'params' RNG stream
>>> variables2 = module.init(rngs, x)
>>> # equivalent params (key(0))
>>> _ = jax.tree_util.tree_map(
...     np.testing.assert_allclose, variables1['params'], variables2['params']
... )
>>> # equivalent other_variable (key(0))
>>> np.testing.assert_allclose(
...     variables1['other_collection']['other_variable'],
...     variables2['other_collection']['other_variable'],
... )

>>> # passing in a single key is equivalent to passing in {'params': key}
>>> variables3 = module.init(jax.random.key(0), x)
>>> # equivalent params (key(0))
>>> _ = jax.tree_util.tree_map(
...     np.testing.assert_allclose, variables2['params'], variables3['params']
... )
>>> # equivalent other_variable (key(0))
>>> np.testing.assert_allclose(
...     variables2['other_collection']['other_variable'],
...     variables3['other_collection']['other_variable'],
... )
Jitting init initializes a model lazily using only the shapes of the provided arguments, and avoids computing the forward pass with actual values. Example:

>>> module = nn.Dense(1)
>>> init_jit = jax.jit(module.init)
>>> variables = init_jit(jax.random.key(0), x)
init is a light wrapper over apply, so other apply arguments like method, mutable, and capture_intermediates are also available.

- Parameters
rngs – The rngs for the variable collections.
*args – Positional arguments passed to the init function.
method – An optional method. If provided, applies this method. If not provided, applies the __call__ method. A string can also be provided to specify a method by name.
mutable – Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. By default all collections except "intermediates" are mutable.
capture_intermediates – If True, captures intermediate return values of all Modules inside the "intermediates" collection. By default only the return values of all __call__ methods are stored. A function can be passed to change the filter behavior. The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored.
**kwargs – Keyword arguments passed to the init function.
- Returns
The initialized variable dict.
- init_with_output(rngs, *args, method=None, mutable=DenyList(deny='intermediates'), capture_intermediates=False, **kwargs)[source]#
Initializes a module method with variables and returns output and modified variables.
- Parameters
rngs – The rngs for the variable collections.
*args – Positional arguments passed to the init function.
method – An optional method. If provided, applies this method. If not provided, applies the __call__ method. A string can also be provided to specify a method by name.
mutable – Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. By default, all collections except "intermediates" are mutable.
capture_intermediates – If True, captures intermediate return values of all Modules inside the "intermediates" collection. By default only the return values of all __call__ methods are stored. A function can be passed to change the filter behavior. The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored.
**kwargs – Keyword arguments passed to the init function.
- Returns
(output, vars), where vars is a dict of the modified collections.
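For example, a minimal sketch showing that both the forward-pass output and the initialized variables are returned:

>>> import jax, jax.numpy as jnp
>>> import flax.linen as nn
>>> model = nn.Dense(features=2)
>>> x = jnp.ones((1, 3))
>>> # Returns (output, variables) rather than just the variables.
>>> y, variables = model.init_with_output(jax.random.key(0), x)
>>> y.shape
(1, 2)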
- is_initializing()[source]#
Returns True if running under self.init(…) or nn.init(…)().
This is a helper method to handle the common case of simple initialization where we wish to have setup logic occur only when called under module.init or nn.init. For more complicated multi-phase initialization scenarios it is better to test for the mutability of particular variable collections or for the presence of particular variables that potentially need to be initialized.
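A minimal sketch of the pattern (the printed message is only for illustration):

>>> import jax, jax.numpy as jnp
>>> import flax.linen as nn
>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     if self.is_initializing():
...       # Only runs under init(), not under apply().
...       print('initializing')
...     return nn.Dense(2)(x)
>>> x = jnp.ones((1, 3))
>>> variables = Foo().init(jax.random.key(0), x)
initializing
>>> y = Foo().apply(variables, x)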
- lazy_init(rngs, *args, method=None, mutable=DenyList(deny='intermediates'), **kwargs)[source]#
Initializes a module without computing on an actual input.
lazy_init will initialize the variables without doing unnecessary compute. The input data should be passed as a jax.ShapeDtypeStruct, which specifies the shape and dtype of the input but no concrete data.

Example:
>>> model = nn.Dense(features=256)
>>> variables = model.lazy_init(
...     jax.random.key(0), jax.ShapeDtypeStruct((1, 128), jnp.float32))
The args and kwargs passed to lazy_init can be a mix of concrete (jax arrays, scalars, bools) and abstract (ShapeDtypeStruct) values. Concrete values are only necessary for arguments that affect the initialization of variables. For example, the model might expect a keyword arg that enables/disables a subpart of the model. In this case, an explicit value (True/False) should be passed, otherwise lazy_init cannot infer which variables should be initialized.

- Parameters
rngs – The rngs for the variable collections.
*args – arguments passed to the init function.
method – An optional method. If provided, applies this method. If not provided, applies the __call__ method.
mutable – Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. By default all collections except "intermediates" are mutable.
**kwargs – Keyword arguments passed to the init function.
- Returns
The initialized variable dict.
- make_rng(name='params')[source]#
Returns a new RNG key from a given RNG sequence for this Module.
The new RNG key is split from the previous one. Thus, every call to make_rng returns a new RNG key, while still guaranteeing full reproducibility.

Note

If an invalid name is passed (i.e. no RNG key was passed by the user in .init or .apply for this name), then name will default to 'params'.

Example:
>>> import jax
>>> import flax.linen as nn

>>> class ParamsModule(nn.Module):
...   def __call__(self):
...     return self.make_rng('params')
>>> class OtherModule(nn.Module):
...   def __call__(self):
...     return self.make_rng('other')

>>> key = jax.random.key(0)
>>> params_out, _ = ParamsModule().init_with_output({'params': key})
>>> # self.make_rng('other') will default to using the 'params' RNG stream
>>> other_out, _ = OtherModule().init_with_output({'params': key})
>>> assert params_out == other_out
Learn more about RNGs by reading the Flax RNG guide: https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/rng_guide.html
- Parameters
name – The RNG sequence name.
- Returns
The newly generated RNG key.
- module_paths(rngs, *args, show_repeated=False, mutable=DenyList(deny='intermediates'), **kwargs)[source]#
Returns a dictionary mapping module paths to module instances.
This method has the same signature and internally calls Module.init, but instead of returning the variables, it returns a dictionary mapping module paths to unbound copies of the module instances that were used at runtime. module_paths uses jax.eval_shape to run the forward computation without consuming any FLOPs or allocating memory.

Example:
>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     h = nn.Dense(4)(x)
...     return nn.Dense(2)(h)

>>> x = jnp.ones((16, 9))
>>> modules = Foo().module_paths(jax.random.key(0), x)
>>> print({
...     p: type(m).__name__ for p, m in modules.items()
... })
{'': 'Foo', 'Dense_0': 'Dense', 'Dense_1': 'Dense'}
- Parameters
rngs – The rngs for the variable collections as passed to Module.init.
*args – The arguments to the forward computation.
show_repeated – If True, repeated calls to the same module will be shown in the table, otherwise only the first call will be shown. Default is False.
mutable – Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. By default, all collections except 'intermediates' are mutable.
**kwargs – keyword arguments to pass to the forward computation.
- Returns
A dictionary mapping module paths to module instances.
- param(name, init_fn, *init_args, unbox=True, **init_kwargs)[source]#
Declares and returns a parameter in this Module.
Parameters are read-only variables in the collection named "params". See flax.core.variables for more details on variables.

The first argument of init_fn is assumed to be a PRNG key, which is provided automatically and does not have to be passed using init_args or init_kwargs:

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(4)(x)
...     mean = self.param('mean', nn.initializers.lecun_normal(), x.shape)
...     ...
...     return x * mean

>>> variables = Foo().init({'params': jax.random.key(0), 'stats': jax.random.key(1)}, jnp.ones((2, 3)))
>>> jax.tree_util.tree_map(jnp.shape, variables)
{'params': {'Dense_0': {'bias': (4,), 'kernel': (3, 4)}, 'mean': (2, 4)}}
In the example above, the function lecun_normal expects two arguments: key and shape, but only shape has to be provided explicitly; key is set automatically using the PRNG for params that is passed when initializing the module using init().

- Parameters
name – The parameter name.
init_fn – The function that will be called to compute the initial value of this variable. This function will only be called the first time this parameter is used in this module.
*init_args – The positional arguments to pass to init_fn.
unbox – If True, AxisMetadata instances are replaced by their unboxed value, see flax.nn.meta.unbox (default: True).
**init_kwargs – The keyword arguments to pass to init_fn.
- Returns
The value of the initialized parameter. Throws an error if the parameter exists already.
- property path#
Get the path of this Module. Top-level root modules have an empty path (). Note that this method can only be used on bound modules that have a valid scope.

Example usage:
>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> class SubModel(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     print(f'SubModel path: {self.path}')
...     return x

>>> class Model(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     print(f'Model path: {self.path}')
...     return SubModel()(x)

>>> model = Model()
>>> variables = model.init(jax.random.key(0), jnp.ones((1, 2)))
Model path: ()
SubModel path: ('SubModel_0',)
- perturb(name, value, collection='perturbations')[source]#
Add a zero-value variable ('perturbation') to the intermediate value.
The gradient of value would be the same as the gradient of this perturbation variable. Therefore, if you define your loss function with both params and perturbations as standalone arguments, you can get the intermediate gradients of value by running jax.grad on the perturbation argument.

Note

This is an experimental API and may be tweaked later for better performance and usability. At its current stage, it creates extra dummy variables that occupy extra memory space. Use it only to debug gradients in training.
Example:
>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(3)(x)
...     x = self.perturb('dense3', x)
...     return nn.Dense(2)(x)

>>> def loss(variables, inputs, targets):
...   preds = model.apply(variables, inputs)
...   return jnp.square(preds - targets).mean()

>>> x = jnp.ones((2, 9))
>>> y = jnp.ones((2, 2))
>>> model = Foo()
>>> variables = model.init(jax.random.key(0), x)
>>> intm_grads = jax.grad(loss, argnums=0)(variables, x, y)
>>> print(intm_grads['perturbations']['dense3'])
[[-1.456924   -0.44332537  0.02422847]
 [-1.456924   -0.44332537  0.02422847]]
If perturbations are not passed to apply, perturb behaves like a no-op so you can easily disable the behavior when not needed:

>>> model.apply(variables, x)  # works as expected
Array([[-1.0980128 , -0.67961735],
       [-1.0980128 , -0.67961735]], dtype=float32)
>>> model.apply({'params': variables['params']}, x)  # behaves like a no-op
Array([[-1.0980128 , -0.67961735],
       [-1.0980128 , -0.67961735]], dtype=float32)
>>> intm_grads = jax.grad(loss, argnums=0)({'params': variables['params']}, x, y)
>>> 'perturbations' not in intm_grads
True
- put_variable(col, name, value)[source]#
Updates the value of the given variable if it is mutable, or raises an error otherwise.
- Parameters
col – the variable collection.
name – the name of the variable.
value – the new value of the variable.
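A minimal sketch (the collection and variable names here are made up) showing how a value in a mutable collection can be overwritten:

>>> import jax, jax.numpy as jnp
>>> import flax.linen as nn
>>> class Accumulator(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     total = self.get_variable('state', 'total', default=jnp.zeros_like(x))
...     # Overwrites the stored value; the 'state' collection must be mutable.
...     self.put_variable('state', 'total', total + x)
...     return total + x
>>> x = jnp.ones((1, 3))
>>> model = Accumulator()
>>> variables = model.init(jax.random.key(0), x)
>>> y, state = model.apply(variables, x, mutable=['state'])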
- setup()[source]#
Initializes a Module lazily (similar to a lazy __init__).

setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed.

This can happen in three cases:

1. Immediately when invoking apply(), init() or init_with_output().

2. Once the module is given a name by being assigned to an attribute of another module inside the other module's setup method (see __setattr__()):

>>> class MyModule(nn.Module):
...   def setup(self):
...     submodule = nn.Conv(...)
...     # Accessing `submodule` attributes does not yet work here.
...     # The following line invokes `self.__setattr__`, which gives
...     # `submodule` the name "conv1".
...     self.conv1 = submodule
...     # Accessing `submodule` attributes or methods is now safe and
...     # causes setup() to be called once.
3. Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or a setup-defined attribute is accessed.
- sow(col, name, value, reduce_fn=<function <lambda>>, init_fn=<function <lambda>>)[source]#
Stores a value in a collection.
Collections can be used to collect intermediate values without the overhead of explicitly passing a container through each Module call.
If the target collection is not mutable, sow behaves like a no-op and returns False.

Example:
>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     h = nn.Dense(4)(x)
...     self.sow('intermediates', 'h', h)
...     return nn.Dense(2)(h)

>>> x = jnp.ones((16, 9))
>>> model = Foo()
>>> variables = model.init(jax.random.key(0), x)
>>> y, state = model.apply(variables, x, mutable=['intermediates'])
>>> jax.tree.map(jnp.shape, state['intermediates'])
{'h': ((16, 4),)}
By default the values are stored in a tuple and each stored value is appended at the end. This way all intermediates can be tracked when the same module is called multiple times. Alternatively, a custom init/reduce function can be passed:
>>> class Foo2(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     init_fn = lambda: 0
...     reduce_fn = lambda a, b: a + b
...     self.sow('intermediates', 'h', x,
...              init_fn=init_fn, reduce_fn=reduce_fn)
...     self.sow('intermediates', 'h', x * 2,
...              init_fn=init_fn, reduce_fn=reduce_fn)
...     return x

>>> x = jnp.ones((1, 1))
>>> model = Foo2()
>>> variables = model.init(jax.random.key(0), x)
>>> y, state = model.apply(
...     variables, x, mutable=['intermediates'])
>>> print(state['intermediates'])
{'h': Array([[3.]], dtype=float32)}
- Parameters
col – The name of the variable collection.
name – The name of the variable.
value – The value of the variable.
reduce_fn – The function used to combine the existing value with the new value. The default is to append the value to a tuple.
init_fn – For the first value stored, reduce_fn will be passed the result of init_fn together with the value to be stored. The default is an empty tuple.
- Returns
True if the value has been stored successfully, False otherwise.
- tabulate(rngs, *args, depth=None, show_repeated=False, mutable=DenyList(deny='intermediates'), console_kwargs=None, table_kwargs=mappingproxy({}), column_kwargs=mappingproxy({}), compute_flops=False, compute_vjp_flops=False, **kwargs)[source]#
Creates a summary of the Module represented as a table.
This method has the same signature and internally calls Module.init, but instead of returning the variables, it returns the string summarizing the Module in a table. tabulate uses jax.eval_shape to run the forward computation without consuming any FLOPs or allocating memory.

Additional arguments can be passed into the console_kwargs argument, for example, {'width': 120}. For a full list of console_kwargs arguments, see: https://rich.readthedocs.io/en/stable/reference/console.html#rich.console.Console

Example:
>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     h = nn.Dense(4)(x)
...     return nn.Dense(2)(h)

>>> x = jnp.ones((16, 9))
>>> # print(Foo().tabulate(
>>> #     jax.random.key(0), x, compute_flops=True, compute_vjp_flops=True))
This gives the following output:
Foo Summary

path      module  inputs          outputs         flops  vjp_flops  params
          Foo     float32[16,9]   float32[16,2]   1504   4460
Dense_0   Dense   float32[16,9]   float32[16,4]   1216   3620       bias: float32[4]
                                                                    kernel: float32[9,4]
                                                                    40 (160 B)
Dense_1   Dense   float32[16,4]   float32[16,2]   288    840        bias: float32[2]
                                                                    kernel: float32[4,2]
                                                                    10 (40 B)
                                                         Total      50 (200 B)

Total Parameters: 50 (200 B)
Note: the row order in the table does not represent execution order; instead it aligns with the order of keys in variables, which are sorted alphabetically.

Note: vjp_flops returns 0 if the module is not differentiable.

- Parameters
rngs – The rngs for the variable collections as passed to Module.init.
*args – The arguments to the forward computation.
depth – controls how many submodules deep the summary can go. By default, it is None, which means no limit. If a submodule is not shown because of the depth limit, its parameter count and bytes will be added to the row of its first shown ancestor such that the sum of all rows always adds up to the total number of parameters of the Module.
show_repeated – If True, repeated calls to the same module will be shown in the table, otherwise only the first call will be shown. Default is False.
mutable – Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. By default, all collections except 'intermediates' are mutable.
console_kwargs – An optional dictionary with additional keyword arguments that are passed to rich.console.Console when rendering the table. Default arguments are {'force_terminal': True, 'force_jupyter': False}.
table_kwargs – An optional dictionary with additional keyword arguments that are passed to the rich.table.Table constructor.
column_kwargs – An optional dictionary with additional keyword arguments that are passed to rich.table.Table.add_column when adding columns to the table.
compute_flops – whether to include a flops column in the table listing the estimated FLOPs cost of each module forward pass. Does not incur actual on-device computation / compilation / memory allocation, but still introduces overhead for large modules (e.g. an extra 20 seconds for a Stable Diffusion's UNet, whereas otherwise tabulation would finish in 5 seconds).
compute_vjp_flops – whether to include a vjp_flops column in the table listing the estimated FLOPs cost of each module backward pass. Introduces a compute overhead of about 2-3X of compute_flops.
**kwargs – keyword arguments to pass to the forward computation.
- Returns
A string summarizing the Module.
- unbind()[source]#
Returns an unbound copy of a Module and its variables.
unbind helps create a stateless version of a bound Module.

An example of a common use case: to extract a sub-Module defined inside setup() and its corresponding variables: 1) temporarily bind the parent Module; and then 2) unbind the desired sub-Module. (Recall that setup() is only called when the Module is bound.):

>>> class Encoder(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     ...
...     return nn.Dense(256)(x)

>>> class Decoder(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     ...
...     return nn.Dense(784)(x)

>>> class AutoEncoder(nn.Module):
...   def setup(self):
...     self.encoder = Encoder()
...     self.decoder = Decoder()
...
...   def __call__(self, x):
...     return self.decoder(self.encoder(x))

>>> module = AutoEncoder()
>>> variables = module.init(jax.random.key(0), jnp.ones((1, 784)))

>>> # Extract the Encoder sub-Module and its variables
>>> encoder, encoder_vars = module.bind(variables).encoder.unbind()
- Returns
A tuple with an unbound copy of this Module and its variables.
- variable(col, name, init_fn=None, *init_args, unbox=True, **init_kwargs)[source]#
Declares and returns a variable in this Module.
See flax.core.variables for more information. See also param() for a shorthand way to define read-only variables in the "params" collection.

Contrary to param(), all arguments to init_fn should be passed on explicitly:

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(4)(x)
...     key = self.make_rng('stats')
...     mean = self.variable('stats', 'mean', nn.initializers.lecun_normal(), key, x.shape)
...     ...
...     return x * mean.value

>>> variables = Foo().init({'params': jax.random.key(0), 'stats': jax.random.key(1)}, jnp.ones((2, 3)))
>>> jax.tree_util.tree_map(jnp.shape, variables)
{'params': {'Dense_0': {'bias': (4,), 'kernel': (3, 4)}}, 'stats': {'mean': (2, 4)}}
In the example above, the function lecun_normal expects two arguments: key and shape, and both have to be passed on explicitly. The PRNG for stats has to be provided explicitly when calling init() and apply().

- Parameters
col – The variable collection name.
name – The variable name.
init_fn – The function that will be called to compute the initial value of this variable. This function will only be called the first time this variable is used in this module. If None, the variable must already be initialized otherwise an error is raised.
*init_args – The positional arguments to pass to init_fn.
unbox – If True, AxisMetadata instances are replaced by their unboxed value, see flax.nn.meta.unbox (default: True).
**init_kwargs – The keyword arguments to pass to init_fn.
- Returns
A flax.core.variables.Variable that can be read or set via the ".value" attribute. Throws an error if the variable exists already.
- property variables#
Returns the variables in this module.
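For example, a minimal sketch using bind() to obtain a bound module, since the property requires a valid scope:

>>> import jax, jax.numpy as jnp
>>> import flax.linen as nn
>>> model = nn.Dense(features=2)
>>> variables = model.init(jax.random.key(0), jnp.ones((1, 3)))
>>> bound = model.bind(variables)
>>> # Inspect the bound module's variables, e.g. their shapes.
>>> shapes = jax.tree_util.tree_map(jnp.shape, bound.variables)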
- flax.linen.apply(fn, module, mutable=False, capture_intermediates=False)[source]#
Creates an apply function to call fn with a bound module.

Unlike Module.apply this function returns a new function with the signature (variables, *args, rngs=None, **kwargs) -> T where T is the return type of fn. If mutable is not False, the return type is a tuple where the second item is a FrozenDict with the mutated variables.

The apply function that is returned can be directly composed with JAX transformations like jax.jit:

>>> class Foo(nn.Module):
...   def encode(self, x):
...     ...
...
...   def decode(self, x):
...     ...

>>> def f(foo, x):
...   z = foo.encode(x)
...   y = foo.decode(z)
...   # ...
...   return y

>>> variables = {}
>>> foo = Foo()
>>> f_jitted = jax.jit(nn.apply(f, foo))
>>> f_jitted(variables, jnp.ones((1, 3)))
- Parameters
fn – The function that should be applied. The first argument passed will be a module instance of the module with variables and RNGs bound to it.
module – The Module that will be used to bind variables and RNGs to. The Module passed as the first argument to fn will be a clone of module.
mutable – Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections.
capture_intermediates – If True, captures intermediate return values of all Modules inside the "intermediates" collection. By default, only the return values of all __call__ methods are stored. A function can be passed to change the filter behavior. The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored.
- Returns
The apply function wrapping fn.
- flax.linen.init(fn, module, mutable=DenyList(deny='intermediates'), capture_intermediates=False)[source]#
Creates an init function to call fn with a bound module.

Unlike Module.init this function returns a new function with the signature (rngs, *args, **kwargs) -> variables. The rngs can be a dict of PRNGKeys or a single PRNGKey, which is equivalent to passing a dict with one PRNGKey with the name "params".

The init function that is returned can be directly composed with JAX transformations like jax.jit:

>>> class Foo(nn.Module):
...   def encode(self, x):
...     ...
...
...   def decode(self, x):
...     ...

>>> def f(foo, x):
...   z = foo.encode(x)
...   y = foo.decode(z)
...   # ...
...   return y

>>> foo = Foo()
>>> f_jitted = jax.jit(nn.init(f, foo))
>>> variables = f_jitted(jax.random.key(0), jnp.ones((1, 3)))
- Parameters
fn – The function that should be applied. The first argument passed will be a module instance of the module with variables and RNGs bound to it.
module – The Module that will be used to bind variables and RNGs to. The Module passed as the first argument to fn will be a clone of module.
mutable – Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. By default, all collections except "intermediates" are mutable.
capture_intermediates – If True, captures intermediate return values of all Modules inside the "intermediates" collection. By default, only the return values of all __call__ methods are stored. A function can be passed to change the filter behavior. The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored.
- Returns
The init function wrapping fn.
- flax.linen.init_with_output(fn, module, mutable=DenyList(deny='intermediates'), capture_intermediates=False)[source]#
Creates an init function to call fn with a bound module that also returns the function outputs.

Unlike Module.init_with_output this function returns a new function with the signature (rngs, *args, **kwargs) -> (T, variables) where T is the return type of fn. The rngs can be a dict of PRNGKeys or a single PRNGKey, which is equivalent to passing a dict with one PRNGKey with the name "params".

The init function that is returned can be directly composed with JAX transformations like jax.jit:

>>> class Foo(nn.Module):
...   def encode(self, x):
...     ...
...
...   def decode(self, x):
...     ...

>>> def f(foo, x):
...   z = foo.encode(x)
...   y = foo.decode(z)
...   # ...
...   return y

>>> foo = Foo()
>>> f_jitted = jax.jit(nn.init_with_output(f, foo))
>>> y, variables = f_jitted(jax.random.key(0), jnp.ones((1, 3)))
- Parameters
fn – The function that should be applied. The first argument passed will be a module instance of the module with variables and RNGs bound to it.
module – The Module that will be used to bind variables and RNGs to. The Module passed as the first argument to fn will be a clone of module.
mutable – Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. By default, all collections except "intermediates" are mutable.
capture_intermediates – If True, captures intermediate return values of all Modules inside the "intermediates" collection. By default, only the return values of all __call__ methods are stored. A function can be passed to change the filter behavior. The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored.
- Returns
The init function wrapping fn.
- flax.linen.intercept_methods(interceptor)[source]#
Registers a new method interceptor.
Method interceptors allow you to (at a distance) intercept method calls to modules. They work similarly to decorators. You could modify args/kwargs before calling the underlying method and/or modify the result returned from calling the underlying method. Or you could completely skip calling the underlying method and decide to do something differently. For example:
>>> import flax.linen as nn
>>> import jax.numpy as jnp
...
>>> class Foo(nn.Module):
...   def __call__(self, x):
...     return x
...
>>> def my_interceptor1(next_fun, args, kwargs, context):
...   print('calling my_interceptor1')
...   return next_fun(*args, **kwargs)
...
>>> foo = Foo()
>>> with nn.intercept_methods(my_interceptor1):
...   _ = foo(jnp.ones([1]))
calling my_interceptor1
You could also register multiple interceptors on the same method. Interceptors will run in order. For example:
>>> def my_interceptor2(next_fun, args, kwargs, context):
...   print('calling my_interceptor2')
...   return next_fun(*args, **kwargs)
...
>>> with nn.intercept_methods(my_interceptor1), \
...     nn.intercept_methods(my_interceptor2):
...   _ = foo(jnp.ones([1]))
calling my_interceptor1
calling my_interceptor2
You could skip other interceptors by directly calling context.orig_method. For example:

>>> def my_interceptor3(next_fun, args, kwargs, context):
...   print('calling my_interceptor3')
...   return context.orig_method(*args, **kwargs)
>>> with nn.intercept_methods(my_interceptor3), \
...     nn.intercept_methods(my_interceptor1), \
...     nn.intercept_methods(my_interceptor2):
...   _ = foo(jnp.ones([1]))
calling my_interceptor3
The following methods couldn’t be intercepted:
Methods decoratored with
nn.nowrap
.Dunder methods including
__eq__
,__repr__
,__init__
,__hash__
, and__post_init__
.Module dataclass fields.
Module descriptors.
- Parameters
interceptor – A method interceptor.
- flax.linen.share_scope(module, other)[source]#
Modifies one of the Modules such that they share the same scope. This is useful when you want to wrap a Module and extend its functionality without changing the parameter structure.

share_scope takes two Modules, module and other. module will use other's scope if other has a scope and it is not a descendant of module's scope:

>>> import flax.linen as nn
>>> import jax
>>> from jax import numpy as jnp, random
...
>>> class DenseLoRA(nn.Module):
...   base: nn.Dense
...   rank: int
...
...   def setup(self):
...     nn.share_scope(self, self.base)
...
...   @nn.compact
...   def __call__(self, x: jax.Array):
...     din, dout = x.shape[-1], self.base.features
...     A = self.param('A', nn.zeros_init(), (din, self.rank))
...     B = self.param('B', nn.zeros_init(), (self.rank, dout))
...     return self.base(x) + x @ A @ B
...
>>> class Model(nn.Module):
...   @nn.compact
...   def __call__(self, x: jax.Array):
...     dense = nn.Dense(10)  # base scope
...     return DenseLoRA(dense, rank=2)(x)  # reuse the base scope
...
>>> model = Model()
...
>>> params = model.init(random.key(0), jnp.ones((1, 5)))['params']
>>> list(params['Dense_0'].keys())
['A', 'B', 'kernel', 'bias']
When other's scope is a descendant of module's scope, then other will use module's scope instead:

>>> class DenseLoRA(nn.Module):
...   features: int
...   rank: int
...
...   def setup(self):
...     self.child = nn.Dense(self.features)
...     nn.share_scope(self, self.child)
...
...   @nn.compact
...   def __call__(self, x: jax.Array):
...     din, dout = x.shape[-1], self.features
...     A = self.param('A', nn.zeros_init(), (din, self.rank))
...     B = self.param('B', nn.zeros_init(), (self.rank, dout))
...     return self.child(x) + x @ A @ B
...
>>> class Model(nn.Module):
...   @nn.compact
...   def __call__(self, x: jax.Array):
...     return DenseLoRA(10, rank=2)(x)
...
>>> model = Model()
...
>>> params = model.init(random.key(0), jnp.ones((1, 5)))['params']
>>> list(params['DenseLoRA_0'].keys())
['A', 'B', 'kernel', 'bias']