SPMD
Utilities for working with jit and partitioned models.
This module introduces logical_axis_rules, logical_to_mesh_axes,
logical_to_mesh, and with_logical_constraint for applying jit sharding
constraints in terms of “logical named axes” rather than the physical mesh
axes that jit sharding specifications normally use.
Additionally, the LogicallyPartitioned metadata wrapper is defined, as
well as the initializer function wrapper ``with_logical_partitioning`` for
introducing logical axis metadata into a model’s variables.
- flax.linen.Partitioned(value, names, mesh=None)
Wrapper for partitioning metadata.
Partitioned is used to extend variables with partitioning information required for jax.experimental.pjit. The easiest way to define Partitioned variables is by using the
with_partitioning wrapper around the variable initializer.
Example:
```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import random
from jax.sharding import Mesh, NamedSharding, PartitionSpec
import flax.linen as nn

class MLP(nn.Module):
  hidden_size: int

  @nn.compact
  def __call__(self, x):
    ki = nn.linear.default_kernel_init
    h = nn.Dense(
        self.hidden_size,
        kernel_init=nn.with_partitioning(ki, ('data', 'model')))(x)
    h = nn.relu(h)
    return nn.Dense(
        x.shape[-1],
        kernel_init=nn.with_partitioning(ki, ('model', 'data')))(h)

# Device mesh with axes 'data' and 'model' (here: every device on 'data').
mesh = Mesh(np.array(jax.devices()).reshape(-1, 1), ('data', 'model'))

mlp = MLP(4096)
x = jnp.ones((8 * 1024, 1024))
# use eval_shape to get the Partitioned instances for the variables.
# this way we can determine the shardings for the init variables
# before we call the init fn.
var_sharding = nn.get_sharding(
    jax.eval_shape(mlp.init, random.key(0), x), mesh)
x_sharding = NamedSharding(mesh, PartitionSpec('data', 'model'))

# pjit has been merged into jax.jit, which takes shardings directly.
init_fn = jax.jit(mlp.init, out_shardings=var_sharding)
variables = init_fn(random.key(0), x)
apply_fn = jax.jit(
    mlp.apply,
    in_shardings=(var_sharding, x_sharding),
    out_shardings=x_sharding)
apply_fn(variables, x)
```
Partitioned values can gain additional axes when using transformations like nn.vmap and nn.scan. In this case, you can specify the name of the new axis with the metadata_params argument of vmap/scan:
```python
# MLP is the module defined in the example above.
class Model(nn.Module):
  @nn.compact
  def __call__(self, x):
    def body(mdl, c):
      c = MLP(4096)(c)
      return c, ()
    c, _ = nn.scan(
        body,
        variable_axes={"params": 0},
        split_rngs={"params": True},  # each scanned layer gets its own init RNG
        length=8,
        metadata_params={nn.meta.PARTITION_NAME: "layers"})(self, x)
    return c
```
- flax.linen.with_partitioning(fn, names, mesh=None)
Wraps a function’s return value with Partitioned.
Example:
```python
>>> import flax.linen as nn
>>> kernel_init = nn.with_partitioning(
...     nn.initializers.lecun_normal(), (None, "data"))
>>> partitioned_dense = nn.Dense(features=3, kernel_init=kernel_init)
```
- Parameters:
fn – The function to be wrapped. Typically this is an initializer.
names – The axis names passed to Partitioned: one mesh axis name (or None for an unsharded dimension) per dimension of the wrapped value.
mesh – The mesh to use for the partitioning. If None, the global mesh resource is used if available.
- Returns:
A function wrapping fn that will return an instance of Partitioned.
- flax.linen.get_partition_spec(tree)
Extracts a PartitionSpec tree from a PyTree containing Partitioned values.
- flax.linen.get_sharding(tree, mesh)
Extracts a tree of jax.sharding objects from a PyTree containing Partitioned values, using the given mesh.
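- flax.linen.LogicallyPartitioned(value: Any, names: tuple[Optional[str], ...], mesh: jax.sharding.Mesh | None = None, rules: collections.abc.Sequence[tuple[str, str | tuple[str, ...] | None]] | None = None)
Wrapper for logical partitioning metadata. It extends Partitioned, but its names are logical axis names, which are mapped to mesh axes using rules (or, if rules is None, the logical axis rules currently in effect).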
- flax.linen.logical_axis_rules(rules)
Context manager for setting the logical to mesh axis bindings.
- flax.linen.set_logical_axis_rules(rules)
Sets the global logical axis to mesh axis binding.
- flax.linen.logical_to_mesh_axes(array_dim_names, rules=None)
Computes the layout for an array.
The rules are in order of precedence, and consist of pairs:
(ArrayDimensionName, MeshDimensionName), meaning that the given array dimension (if present and unused) should be sharded across the given mesh dimension (if present and unused). The layout of an array is expressed as a tuple with one element for each dimension of the array. Each element is either None or the name of a mesh dimension; a mesh-dimension name means that this dimension of the array is sharded across that dimension of the mesh.
For example, given an array with:
array_dim_names = ('batch', 'length', 'heads', 'features')
and the layout rules are:
rules = (('batch', 'X'), ('features', 'X'), ('heads', 'Y'), ('batch', 'Z'))
then this function will return:
PartitionSpec('X', None, 'Y', None)
- Parameters:
array_dim_names – Tuple of array dimension names or None.
rules – Optional logical-to-mesh rules override. Defaults to the rules set in the dynamic context by the logical_axis_rules context manager (or globally by set_logical_axis_rules).
- Returns:
PartitionSpec for the parameter.
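The precedence rules above can be illustrated with a dependency-free Python sketch. resolve_logical_axes is a hypothetical helper written for this page, not part of Flax; it mirrors the behaviour described here (earlier rules win, and neither an array dimension nor a mesh axis is assigned twice), but it is a simplification and may differ from the real implementation in edge cases such as repeated dimension names.

```python
from typing import Optional, Sequence, Tuple, Union

# A mesh assignment: unsharded (None), one mesh axis, or several mesh axes.
MeshAxes = Union[None, str, Tuple[str, ...]]

_UNASSIGNED = object()  # sentinel: no rule has claimed this dimension yet


def resolve_logical_axes(
    array_dim_names: Sequence[Optional[str]],
    rules: Sequence[Tuple[str, MeshAxes]],
) -> Tuple[MeshAxes, ...]:
  """Hypothetical sketch of priority-based logical-to-mesh rule matching."""
  names = list(array_dim_names)
  # Dimensions named None are never sharded; named ones start unassigned.
  result = [None if name is None else _UNASSIGNED for name in names]
  used_mesh_axes = set()
  for logical_name, mesh_axes in rules:  # earlier rules take precedence
    if logical_name not in names:
      continue
    pos = names.index(logical_name)
    if result[pos] is not _UNASSIGNED:
      continue  # this array dimension was already claimed by an earlier rule
    if mesh_axes is None:
      wanted = set()
    elif isinstance(mesh_axes, str):
      wanted = {mesh_axes}
    else:
      wanted = set(mesh_axes)
    if wanted & used_mesh_axes:
      continue  # a requested mesh axis is already used by another dimension
    result[pos] = mesh_axes
    used_mesh_axes |= wanted
  # Dimensions that no rule claimed are left unsharded.
  return tuple(None if r is _UNASSIGNED else r for r in result)


names = ('batch', 'length', 'heads', 'features')
rules = (('batch', 'X'), ('features', 'X'), ('heads', 'Y'), ('batch', 'Z'))
print(resolve_logical_axes(names, rules))  # ('X', None, 'Y', None)
```

Tracing the example: ('batch', 'X') claims dimension 0; ('features', 'X') is skipped because X is already in use; ('heads', 'Y') claims dimension 2; ('batch', 'Z') is skipped because batch is already sharded. 'length' and 'features' match no usable rule and stay unsharded.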
- flax.linen.logical_to_mesh(tree, rules=None)
Applies logical_to_mesh_axes to pytrees of logical PartitionSpecs.
- flax.linen.logical_to_mesh_sharding(tree, mesh, rules=None)
Converts pytrees of logical PartitionSpecs to shardings on the given mesh.
- flax.linen.with_logical_constraint(x, logical_axis_resources, rules=None, mesh=None, fallback=RulesFallback.AXIS_IS_UNSHARDED)
Version of jax.lax.with_sharding_constraint that uses logical axis names.
- flax.linen.with_logical_partitioning(fn, names, mesh=None, rules=None)
Wraps a function’s return value with LogicallyPartitioned.
Example:
```python
>>> import flax.linen as nn
>>> kernel_init = nn.with_logical_partitioning(
...     nn.initializers.lecun_normal(), (None, "data"))
>>> partitioned_dense = nn.Dense(features=3, kernel_init=kernel_init)
```
- Parameters:
fn – The function to be wrapped. Typically this is an initializer.
names – The logical axis names passed to LogicallyPartitioned.
mesh – The mesh to use for the partitioning. If None, the global mesh resource is used if available.
rules – Optional logical-to-mesh rules to use. If None, the global rules are used if available.
- Returns:
A function wrapping fn that will return an instance of LogicallyPartitioned.