Init/Apply
- flax.linen.apply(fn, module, mutable=False, capture_intermediates=False)
Creates an apply function to call fn with a bound module.

Unlike Module.apply, this function returns a new function with the signature (variables, *args, rngs=None, **kwargs) -> T, where T is the return type of fn. If mutable is not False, the return type is a tuple whose second item is a FrozenDict with the mutated variables.

The returned apply function can be composed directly with JAX transformations such as jax.jit:

>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn
>>> class Foo(nn.Module):
...   def encode(self, x):
...     ...
...   def decode(self, x):
...     ...
>>> def f(foo, x):
...   z = foo.encode(x)
...   y = foo.decode(z)
...   # ...
...   return y
>>> variables = {}
>>> foo = Foo()
>>> f_jitted = jax.jit(nn.apply(f, foo))
>>> f_jitted(variables, jnp.ones((1, 3)))
- Parameters:
  - fn – The function to apply. The first argument passed to fn will be an instance of module with variables and RNGs bound to it.
  - module – The Module that variables and RNGs are bound to. The Module passed as the first argument to fn is a clone of module.
  - mutable – A bool, str, or list specifying which collections are treated as mutable:
    - bool: all or no collections are mutable.
    - str: the name of a single mutable collection.
    - list: a list of names of mutable collections.
  - capture_intermediates – If True, captures the intermediate return values of all Modules in the “intermediates” collection. By default, only the return values of __call__ methods are stored. A function can be passed instead to change the filter behavior: it takes the Module instance and the method name and returns a bool indicating whether the output of that method invocation should be stored.
- Returns:
The apply function wrapping fn.
- flax.linen.init(fn, module, mutable=DenyList(deny='intermediates'), capture_intermediates=False)
Creates an init function to call fn with a bound module.

Unlike Module.init, this function returns a new function with the signature (rngs, *args, **kwargs) -> variables. The rngs can be a dict of PRNGKeys, or a single PRNGKey, which is equivalent to passing a dict with one PRNGKey named “params”.

The returned init function can be composed directly with JAX transformations such as jax.jit:

>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn
>>> class Foo(nn.Module):
...   def encode(self, x):
...     ...
...   def decode(self, x):
...     ...
>>> def f(foo, x):
...   z = foo.encode(x)
...   y = foo.decode(z)
...   # ...
...   return y
>>> foo = Foo()
>>> f_jitted = jax.jit(nn.init(f, foo))
>>> variables = f_jitted(jax.random.key(0), jnp.ones((1, 3)))
- Parameters:
  - fn – The function to apply. The first argument passed to fn will be an instance of module with variables and RNGs bound to it.
  - module – The Module that variables and RNGs are bound to. The Module passed as the first argument to fn is a clone of module.
  - mutable – A bool, str, or list specifying which collections are treated as mutable:
    - bool: all or no collections are mutable.
    - str: the name of a single mutable collection.
    - list: a list of names of mutable collections.
    By default, all collections except “intermediates” are mutable.
  - capture_intermediates – If True, captures the intermediate return values of all Modules in the “intermediates” collection. By default, only the return values of __call__ methods are stored. A function can be passed instead to change the filter behavior: it takes the Module instance and the method name and returns a bool indicating whether the output of that method invocation should be stored.
- Returns:
The init function wrapping fn.
- flax.linen.init_with_output(fn, module, mutable=DenyList(deny='intermediates'), capture_intermediates=False)
Creates an init function to call fn with a bound module that also returns the function outputs.

Unlike Module.init_with_output, this function returns a new function with the signature (rngs, *args, **kwargs) -> (T, variables), where T is the return type of fn. The rngs can be a dict of PRNGKeys, or a single PRNGKey, which is equivalent to passing a dict with one PRNGKey named “params”.

The returned init function can be composed directly with JAX transformations such as jax.jit:

>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn
>>> class Foo(nn.Module):
...   def encode(self, x):
...     ...
...   def decode(self, x):
...     ...
>>> def f(foo, x):
...   z = foo.encode(x)
...   y = foo.decode(z)
...   # ...
...   return y
>>> foo = Foo()
>>> f_jitted = jax.jit(nn.init_with_output(f, foo))
>>> y, variables = f_jitted(jax.random.key(0), jnp.ones((1, 3)))
- Parameters:
  - fn – The function to apply. The first argument passed to fn will be an instance of module with variables and RNGs bound to it.
  - module – The Module that variables and RNGs are bound to. The Module passed as the first argument to fn is a clone of module.
  - mutable – A bool, str, or list specifying which collections are treated as mutable:
    - bool: all or no collections are mutable.
    - str: the name of a single mutable collection.
    - list: a list of names of mutable collections.
    By default, all collections except “intermediates” are mutable.
  - capture_intermediates – If True, captures the intermediate return values of all Modules in the “intermediates” collection. By default, only the return values of __call__ methods are stored. A function can be passed instead to change the filter behavior: it takes the Module instance and the method name and returns a bool indicating whether the output of that method invocation should be stored.
- Returns:
The init function wrapping fn.