Lifted transformations

⚠️ Advanced topic ⚠️

This design note explains the underlying implementation of flax.linen.transforms, which enables JAX transformations inside Flax Modules.

Introduction

JAX uses a functional API meaning that it only guarantees correct behavior when using functions without side effects (JAX docs). Typically, these side effects are the result of mutating an object that lives outside the function.

The functional paradigm has some advantages like the ability to explicitly reason about state and stochasticity. The function output only changes when an input argument changes. Therefore, a function is guaranteed to behave deterministically.

But pure functions offer JAX another big advantage: they enable functional transformations. For example, jax.vmap(f) vectorizes a function f. Because f cannot have side effects, the vectorized/parallel version of f is well-defined. To see why we need this restriction, consider what happens if f incremented a counter or drew a random number. Would f draw the same or a different random number for each item in the vector? Would each item in the batch have its own counter, or is the counter shared among the items? And in what order is the counter incremented if f is computed in parallel? The answer to all these questions is “it depends”. The behavior is ambiguous, and the functional constraint elegantly avoids this problem.
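With pure functions, the caller has to answer the random-number question explicitly. The following is a minimal sketch in plain JAX (the lambdas are illustrative and not part of Flax): passing one key per item gives each item its own random number, while broadcasting a single key with in_axes=None gives every item the same one.

```python
import jax
import jax.numpy as jnp
from jax import random

key = random.key(0)

# One key per batch item: each item draws its own random number.
per_item_keys = random.split(key, 3)
different = jax.vmap(lambda k: random.normal(k))(per_item_keys)

# One shared (broadcast) key: every item draws the same random number.
same = jax.vmap(lambda k, i: random.normal(k), in_axes=(None, 0))(key, jnp.arange(3))

print(different)  # three distinct values
print(same)       # one value repeated three times
```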

Flax introduces a safe way to have limited randomness and stateful variables in a JAX-compatible form. State in Flax is not problematic because it is local: inside a Flax Module there are variables and PRNG sequences, but on the outside there are only JAX Arrays and PRNG keys.

For most use cases, Flax is used to define models in a stateful way. Because a Module behaves like a pure function externally, we can fully utilize JAX with all of its transformations. There are, however, cases where we want the best of both worlds by using transformations and Modules together. This design note explains how we extend JAX’s functional transformations to work on Modules that have internal state and randomness.

Functionalization

Before we jump into the details let’s consider a simple example where we would like to use vmap inside a Module.

First, we define a simple MLP without any transformations:

import jax
from jax import random, numpy as jnp
from flax import linen as nn

class MLP(nn.Module):
  @nn.compact
  def __call__(self, xs):
    h = nn.Dense(4, name='hidden')(xs)
    h = nn.relu(h)
    return nn.Dense(1, name='out')(h)

Now what if we want to have separate MLP parameters for each item in xs? If this were “vanilla JAX”, we could imagine writing something like jax.vmap(apply_mlp)(mlp_params, xs). But doing something like this in Linen will actually fail:

class NaiveVmapMLP(nn.Module):
  @nn.compact
  def __call__(self, xs):
    mlp = MLP()
    return jax.vmap(lambda mlp, x: mlp(x))(mlp, xs)  # fails

JAX will raise an error when vmap is used on mlp because it is not a JAX array or a simple container of arrays. We cannot really blame JAX for refusing to perform this under-specified job. After all, it is not even clear what should happen here. The parameters inside the MLP are not initialized yet, and we need a separate PRNG key for each group of parameters. jax.vmap can only broadcast or map over an axis; it cannot automatically split a PRNG key. Therefore, we have to call jax.random.split manually.

We can fix this problem by first turning MLP into pure init and apply functions. Afterwards, we use the param method to store the parameters:

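class ManualVmapMLP(nn.Module):
  @nn.compact
  def __call__(self, xs):
    # Detached instance: not registered as a submodule of this Module.
    mlp = MLP(parent=None)
    # Pure init, vectorized with one PRNG key per item in xs.
    init_fn = lambda rng, xs: jax.vmap(mlp.init, in_axes=0)(random.split(rng, xs.shape[0]), xs)['params']
    # Pure apply, vectorized over both the stacked parameters and the inputs.
    apply_fn = jax.vmap(mlp.apply, in_axes=0)
    # Store the stacked parameters as a single 'mlp' param of this Module.
    mlp_params = self.param('mlp', init_fn, xs)
    return apply_fn({'params': mlp_params}, xs)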

xs = jnp.ones((3, 4))
variables = ManualVmapMLP().init(random.key(0), xs)
print(jax.tree_util.tree_map(jnp.shape, variables['params']))
"""==>
{
    mlp: {
        hidden: {
            bias: (3, 4),
            kernel: (3, 4, 4),
        },
        out: {
            bias: (3, 1),
            kernel: (3, 4, 1),
        },
    },
}
"""

Here, MLP(parent=None) creates a detached instance of MLP. This avoids reserving a name for the submodule inside the current module. Although not strictly necessary, this also ensures we cannot accidentally use the MLP instance in a stateful way and we are forced to use it through either .init or .apply.

This example is still relatively concise, but it already takes a few extra “bookkeeping” statements to make it work. Moreover, this implementation has a number of limitations:

  1. During initialization, we call the submodule twice, through init_fn and apply_fn. If the submodule used the same trick to apply a functional transformation, we would end up executing a lot of code, as the number of module calls grows like 2^d, where d is the number of nested functional transformations.

  2. The implementation assumes the submodule only requires the parameter RNG sequence.

  3. The implementation assumes we only create variables in the “params” collection during init. It does not support other variable collections, nor creating or updating variables in apply.

Point 3 in particular makes manual functionalization cumbersome. Feel free to try to extend the above example with an nn.BatchNorm layer in the MLP module. This requires dealing with some additional complexity, like storing the updated batch stats and making sure the batch stats are immutable inside vmap when they should be (e.g., in eval mode).

We call the process of transforming a stateful Module into a pure function “functionalization”. By temporarily turning a stateful Module into a function we make it compatible with JAX’s functional transformations.

Lifting

Flax provides an alternative to manual functionalization, which we call lifted transformations. Lifted transformations are defined in flax.core.lift. All the lifted JAX transformations are defined with a single generic lifting API called pack.

A number of decisions had to be made in order to define pack. The implementation of pack controls how variables and PRNGs are lifted and how fine-grained the user's control is. It must also decide whether lifting decisions are made at variable definition or at transformation definition.

Lifting granularity

With the Linen API, users can define arbitrary variable collections and PRNG sequences. Each variable in a collection is lifted in the same way.

Collections are typically given a semantically meaningful name like “params” or “batch_stats” rather than a general purpose name like “state”. Because collections carry semantic meaning we can decide at the transformation level how each collection should be lifted. For example, we want to share all parameter variables when we add a batch dimension to a model.

At the same time, we can write generic code that uses transformations without knowing exactly what kind of variables the submodules will create. Collections thus strike a balance between fine-grained control and generality. We also avoid brittle string-matching code that loops over all variables and tries to split up collections in an ad-hoc way based on naming conventions like “target all variables with the name prefix kernel”. If more fine-grained control is necessary, a user can simply split a set of variables over multiple collections that should be handled differently.

Transformation vs variable control

Lifting behavior could be defined either at the transformation level or during variable definition. We define lifting behavior at the transformation level. The reason for this choice is that there are many different transformations with various behaviors. For example, vmap has broadcasted and vectorized arguments, while scan has scan, carry, and broadcast arguments. A variable would have to define its behavior for all these transformations; otherwise, a Module would not be compatible with them. Alternatively, we would have to make default decisions for how transformations are handled. However, this could lead to silent bugs because the default behavior might not be valid given the user's intent.

The lift package also provides a general-purpose transform, which allows an arbitrary function to transform a variable collection. For example, this can be used to tie the weights in a tied auto-encoder by transposing the weights. It is unclear whether a similar general-purpose transform could be defined if lifting decisions were made at variable definition.

Linen

The lifting module does not know about the Linen Module API. Instead it operates directly on instances of flax.core.Scope. A Scope instance contains the variables and PRNG sequences of a Module. Each Module instance has a Scope instance in the .scope field if it has a parent or it was created using init or apply. Typically, the top-level Module instance — on which you call init or apply — is the only Module instance that does not have a Scope bound to it.

When a Module is transformed, we use the flax.core.lift APIs to lift the scope and use Module.clone() to create a new Module instance with the lifted scope bound to it.

flax.linen.transforms exposes wrappers for the transformations in flax.core.lift. The core lifting APIs operate on functions while the Linen wrappers can transform either a Module class or a Module method.

Thus, lifting is implemented independently from the Linen API. This separation of concerns simplifies the implementation, while potentially allowing alternative Module abstractions to build upon a common core for lifting and state management.

Implementation

The pack(fn, in_vars, out_vars, rngs) API goes through the following stages:

  1. Scope de-duplication

    This stage is only relevant if multiple Scopes are lifted together. In this case we must first find the set of root scopes. A scope is a root if none of its ancestors are in the set of scopes that need to be lifted.

    By only lifting roots we avoid lifting the same variables twice.

    For non-root scopes, we store a reference to their ancestor scope and a path so that we can reconstruct them later (stage 4).

  2. Filter stage

    Variables and PRNG sequences are split up into groups. This way fn can lift each group into the transformation separately. A group is defined by a filter specified as:

    • a list of collections/prng names

    • True (match everything)

    • False (match nothing)

    • DenyList(filter): match everything except the specified collections (e.g., DenyList(['params']) matches everything except the ‘params’ collection).

    A collection or PRNG sequence can only be put into a single group. If a collection matches multiple filters, it is put into the first group with a matching filter. If a collection or PRNG sequence does not match any filter, it is not lifted: it cannot be used inside the transformation, and attempting to do so raises an error. For example, in_vars = (["params"], True) puts the “params” collection in the first group and all other collections in the second group.

    For each PRNG sequence that is matched, we seed a new PRNG sequence by calling make_rng. This avoids the need to update the PRNG state after the lifted transformation completes.

  3. Transform-specific lifting

    fn is called with the variable and PRNG groups. JAX transforms have varying signatures and lifting options. Arguably the cleanest example is vmap: the function arguments, PRNGs, and variable collections are passed into a jax.vmap-wrapped function.

  4. Scope reconstruction

    Now that the variables and PRNGs are lifted inside the transformation, we want to recreate the lifted scopes. pack calls fn with a scope_fn that takes the lifted variables and PRNGs and returns the reconstructed scopes with the lifted variables and PRNG sequences.

  5. Repack stage

    After we have used the lifted scopes, we have to retrieve the updated variables (PRNG sequences can simply be discarded). pack passes the repack_fn to support this. This stage is similar to stage 2, except that only variables are repacked and immutable variables are ignored: they cannot be updated, so they should not be returned from the transformed function.

  6. Commit stage

    pack expects fn to return a pair where the first item will simply be returned from pack and the second item should be the repacked variables. The updated variables are stored in the original/un-lifted scopes such that the mutations that happen inside the transformation survive after the transformation completes.
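The grouping rule of the filter stage (stage 2) can be sketched in a few lines of pure Python. This is an illustration of the first-match semantics described above, not Flax's implementation; Deny, matches, and group are hypothetical names, and Deny is only a local stand-in for DenyList that accepts a tuple of names.

```python
from dataclasses import dataclass
from typing import Sequence, Union


@dataclass(frozen=True)
class Deny:
  """Hypothetical stand-in for DenyList: matches every name not in `names`."""
  names: tuple


Filter = Union[bool, str, Sequence[str], Deny]


def matches(flt: Filter, name: str) -> bool:
  if flt is True:               # match everything
    return True
  if flt is False:              # match nothing
    return False
  if isinstance(flt, Deny):     # match everything except the listed names
    return name not in flt.names
  if isinstance(flt, str):      # a single collection / PRNG name
    return name == flt
  return name in flt            # a list of collection / PRNG names


def group(names, filters):
  """Put each name into the first group whose filter matches it.

  Names that match no filter are returned separately: they are not lifted.
  """
  groups = [[] for _ in filters]
  not_lifted = []
  for name in names:
    for i, flt in enumerate(filters):
      if matches(flt, name):
        groups[i].append(name)
        break
    else:
      not_lifted.append(name)
  return groups, not_lifted


# The in_vars = (["params"], True) case from the filter stage:
print(group(['params', 'batch_stats', 'cache'], (['params'], True)))
# ([['params'], ['batch_stats', 'cache']], [])
print(group(['params', 'batch_stats'], (Deny(('params',)),)))
# ([['batch_stats']], ['params'])
```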

Using pack: an example

A minimal example of using pack to transpose each matrix in a variable collection:

import jax
from jax import random, numpy as jnp
from flax.core import lift
from flax.core import Scope, init, apply, nn as core_nn

def lift_transpose(fn, target='params', variables=True, rngs=True):
  # by default we transpose 'params' and simply pass through all other variables.
  def wrapper(scope_fn, repack_fn, variable_groups, rng_groups, *args):
    # normally we would first call into a JAX transformed function here...
    target_vars, rest = variable_groups
    def trans(x):
      # Transpose matrices; leave other arrays (e.g. biases) unchanged.
      if x.ndim == 2:
        return x.T
      return x
    target_vars = jax.tree_util.tree_map(trans, target_vars)
    variable_groups = (target_vars, rest)
    # Rebuild a scope from the (transformed) variable and PRNG groups.
    scope = scope_fn(variable_groups, rng_groups)
    y = fn(scope, *args)
    # Collect the (possibly updated) variables so that pack can commit them.
    out_variables = repack_fn(scope)
    return y, out_variables
  return lift.pack(
      wrapper,
      in_variable_filters=(target, variables),
      out_variable_filters=(variables,),
      rng_filters=(rngs,))

x = jnp.ones((3, 2))
y, variables = init(lift_transpose(core_nn.dense))(random.key(0), x, 4)

NOTE that most users should not need to interact with pack directly. Please open a GitHub issue when you find a use case that is not supported yet by the existing lifted transformations.

Supported transformations

| JAX transform | Supported in Linen? | Comments |
|---|---|---|
| vmap | | |
| scan | | Carry variables cannot be initialized inside the scan body. |
| remat | | |
| jit | | Current implementation might cause unnecessary recompilation. |
| jvp | | |
| vjp | | |
| custom_vjp | | |
| custom_jvp | | |
| while_loop | | Carry variables cannot be initialized inside the while_loop body. |
| cond | | Variable initialization / mutation must structurally match across branches. |
| switch | | Variable initialization / mutation must structurally match across branches. |
| pmap | | |
| xmap | | |


Linen examples

Going back to our original example, we can now use nn.vmap to simplify our implementation:

class LinenVmapMLP(nn.Module):
  @nn.compact
  def __call__(self, xs):
    VmapMLP = nn.vmap(MLP, variable_axes={'params': 0}, split_rngs={'params': True}, in_axes=0)
    return VmapMLP(name='mlp')(xs)

variables = LinenVmapMLP().init(random.key(0), xs)
print(jax.tree_util.tree_map(jnp.shape, variables['params']))
"""==>
{
    mlp: {
        Dense_0: {
            bias: (3, 4),
            kernel: (3, 2, 4),
        },
        Dense_1: {
            bias: (3, 1),
            kernel: (3, 4, 1),
        },
    },
}
"""

Here we use variable_axes={'params': 0} to indicate that parameters are vectorized rather than shared, and split_rngs={'params': True} to indicate that each set of parameters is initialized independently.

We can also extend the example with some inner state by adding a BatchNorm layer:

class StatefulMLP(nn.Module):
  @nn.compact
  def __call__(self, x, *, train):
    h = nn.Dense(4, name='hidden')(x)
    h = nn.BatchNorm(axis_name='batch')(h, use_running_average=not train)
    h = nn.relu(h)
    return nn.Dense(1, name='out')(h)

class LinenStatefulVmapMLP(nn.Module):
  @nn.compact
  def __call__(self, xs, *, train):
    VmapMLP = nn.vmap(StatefulMLP, variable_axes={'params': 0, 'batch_stats': 0}, split_rngs={'params': True}, in_axes=0, axis_name='batch')
    return VmapMLP(name='mlp')(xs, train=train)

variables = LinenStatefulVmapMLP().init(random.key(0), xs, train=False)

Compared to the stateless version, we only had to add 'batch_stats': 0 to nn.vmap, indicating that the batch stats are vectorized along the first axis rather than shared, and axis_name='batch', which binds the axis name used by BatchNorm to the vmapped axis so that batch statistics can be computed across it in training mode.

Alternatives

Other numerical computation frameworks consider variables a first-class citizen. An alternative to functionalization would be to use a variable system, either integrated into JAX or built on top of it. An advantage of this approach is that per-variable lifting becomes easier. If variables were part of the JAX IR (JAXPR), we could inspect which variables have to be lifted in a certain computation. Optionally, they could be annotated with a collection tag to decide on various lifting options.

The downside of this approach is that a variable system is more complicated. Variables are mutable references and break a core assumption of functional programming (see referential transparency). Other APIs that currently have a functional interface would probably require integration as well (e.g., checkpointing and optimization APIs).