SPMD#
Utilities for working with jit and partitioned models.
This module introduces axis_rules, logical_to_mesh_axes, logical_to_mesh,
and with_logical_constraint for applying jit sharding constraints in terms
of “logical named axes” rather than jit’s default mesh axes.

Additionally, the LogicallyPartitioned metadata wrapper is defined, as well
as the initializer function wrapper with_logical_partitioning, for
introducing logical axis metadata into a model’s variables.
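For orientation, here is a minimal sketch of the logical-axis workflow. The logical names 'embed' and 'hidden' are illustrative, not part of the API:

import flax.linen as nn

class Block(nn.Module):
  features: int

  @nn.compact
  def __call__(self, x):
    # Annotate the kernel with logical axis names instead of mesh axes;
    # the binding to physical mesh axes is deferred to a rules table
    # (see logical_axis_rules below), so one model definition can be
    # resharded without touching the module code.
    dense = nn.Dense(
        self.features,
        kernel_init=nn.with_logical_partitioning(
            nn.initializers.lecun_normal(), ('embed', 'hidden')))
    return dense(x)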
- flax.linen.Partitioned(value, names, mesh=None)[source]#
Wrapper for partitioning metadata.
Partitioned is used to extend variables with partitioning information
required for jax.experimental.pjit.

The easiest way to define Partitioned variables is by using the
with_partitioning wrapper around the variable initializer.

Example:
class MLP(nn.Module):
  hidden_size: int

  @nn.compact
  def __call__(self, x):
    ki = nn.linear.default_kernel_init
    h = nn.Dense(
        self.hidden_size,
        kernel_init=nn.with_partitioning(ki, ('data', 'model')))(x)
    h = nn.relu(h)
    return nn.Dense(
        x.shape[-1],
        kernel_init=nn.with_partitioning(ki, ('model', 'data')))(h)

mlp = MLP(4096)
x = jnp.ones((8 * 1024, 1024))

# use eval_shape to get the Partitioned instances for the variables.
# this way we can determine the PartitionSpecs for the init variables
# before we call the init fn.
var_spec = nn.get_partition_spec(
    jax.eval_shape(mlp.init, random.key(0), x))

init_fn = mesh(pjit(mlp.init,
                    (None, PartitionSpec("data", "model")), var_spec))
variables = init_fn(random.key(0), x)

apply_fn = mesh(pjit(
    mlp.apply,
    (var_spec, PartitionSpec("data", "model")),
    PartitionSpec("data", "model")))

apply_fn(variables, x)
Partitioned values can gain additional axes when using transformations
like nn.vmap and nn.scan. In this case you can specify the name of the
new axis with the metadata_params args in vmap/scan:

class Model(nn.Module):
  @nn.compact
  def __call__(self, x):
    def body(mdl, c):
      c = MLP(4096)(c)
      return c, ()
    c, _ = nn.scan(
        body,
        variable_axes={"params": 0},
        split_rngs={"params": 0},
        length=8,
        metadata_params={nn.meta.PARTITION_NAME: "layers"})(self, x)
    return c
- flax.linen.with_partitioning(fn, names, mesh=None)[source]#
Wraps a function’s return value with Partitioned.
Example:
>>> import flax.linen as nn
>>> kernel_init = nn.with_partitioning(
...     nn.initializers.lecun_normal(), (None, "data"))
>>> partitioned_dense = nn.Dense(features=3, kernel_init=kernel_init)
- Parameters
fn – The function to be wrapped. Typically this is an initializer.
names – The logical axis passed to Partitioned.
mesh – The mesh to use for the partitioning. If None, the global mesh resource is used if available.
- Returns
A function wrapping fn that will return an instance of Partitioned.
- flax.linen.get_partition_spec(tree)[source]#
Extracts a PartitionSpec tree from a PyTree containing Partitioned values.
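Example (a minimal sketch; the Dense model and its ('data', 'model') axis names are illustrative):

import jax
import jax.numpy as jnp
import flax.linen as nn

model = nn.Dense(
    features=4,
    kernel_init=nn.with_partitioning(
        nn.initializers.lecun_normal(), ('data', 'model')))
# eval_shape returns the Partitioned metadata without materializing
# any parameters; get_partition_spec turns it into PartitionSpecs.
abstract_vars = jax.eval_shape(
    model.init, jax.random.key(0), jnp.ones((2, 3)))
specs = nn.get_partition_spec(abstract_vars)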
- flax.linen.get_sharding(tree, mesh)[source]#
Extracts a jax.sharding tree from a PyTree containing Partitioned values and a mesh.
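Example (a minimal sketch using a 1x1 mesh so it runs on a single device; scale the mesh shape to your topology):

import jax
import jax.numpy as jnp
import numpy as np
import flax.linen as nn
from jax.sharding import Mesh

mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1), ('data', 'model'))
model = nn.Dense(
    features=4,
    kernel_init=nn.with_partitioning(
        nn.initializers.lecun_normal(), ('data', 'model')))
abstract_vars = jax.eval_shape(
    model.init, jax.random.key(0), jnp.ones((2, 3)))
# Resolves each Partitioned value to a sharding on the given mesh.
shardings = nn.get_sharding(abstract_vars, mesh)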
- flax.linen.LogicallyPartitioned(value, names, mesh=None, rules=None)[source]#
Wrapper for logical partitioning metadata, extending Partitioned with logical to mesh axis rules.
- flax.linen.logical_axis_rules(rules)[source]#
Context manager for setting the logical to mesh axis bindings.
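Example (a minimal sketch; the logical names 'batch' and 'embed' are illustrative):

import flax.linen as nn

# Inside the context, logical names resolve to the bound mesh axes.
with nn.logical_axis_rules((('batch', 'data'), ('embed', 'model'))):
  spec = nn.logical_to_mesh_axes(('batch', 'embed'))
# spec == PartitionSpec('data', 'model')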
- flax.linen.set_logical_axis_rules(rules)[source]#
Sets the global logical axis to mesh axis binding.
- flax.linen.logical_to_mesh_axes(array_dim_names, rules=None)[source]#
Compute layout for an array.
The rules are in order of precedence, and consist of pairs:
(ArrayDimensionName, MeshDimensionName), meaning that the given array
dimension (if present and unused) should be sharded across the given
mesh dimension (if present and unused).

A Layout of an Array is expressed as a tuple with one element for each
dimension in the Array. The element is either None, or is the name of a
mesh-dimension, meaning that this dimension of the array is sharded
across this dimension of the mesh.
For example, given an array with:
array_dim_names = ('batch', 'length', 'heads', 'features')
and the layout rules are:
rules = (('batch', 'X'), ('features', 'X'), ('heads', 'Y'), ('batch', 'Z'))
then this function will return:
PartitionSpec('X', None, 'Y', None)
- Parameters
array_dim_names – Tuple of array dimension names or None.
rules – Optional logical to mesh rules override. Defaults to using the rules defined in the dynamic context set from the axis_rules function.
- Returns
PartitionSpec for the parameter.
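The example above, as a runnable sketch:

import flax.linen as nn

rules = (('batch', 'X'), ('features', 'X'), ('heads', 'Y'), ('batch', 'Z'))
# 'features' loses to 'batch' for the 'X' mesh axis (rule precedence),
# and 'length' has no rule, so both map to None.
spec = nn.logical_to_mesh_axes(
    ('batch', 'length', 'heads', 'features'), rules=rules)
# spec == PartitionSpec('X', None, 'Y', None)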
- flax.linen.logical_to_mesh(tree, rules=None)[source]#
Applies logical_to_mesh_axes to pytrees of logical PartitionSpecs.
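A minimal sketch, assuming a rules table like the ones above (the names 'batch', 'embed', and 'hidden' are illustrative):

import flax.linen as nn
from jax.sharding import PartitionSpec

rules = (('batch', 'data'), ('embed', 'model'))
logical = {'kernel': PartitionSpec('embed', 'hidden'),
           'bias': PartitionSpec('hidden')}
# 'hidden' has no rule, so it maps to None (unsharded).
mesh_specs = nn.logical_to_mesh(logical, rules=rules)
# {'kernel': PartitionSpec('model', None), 'bias': PartitionSpec(None)}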
- flax.linen.logical_to_mesh_sharding(tree, mesh, rules=None)[source]#
Convert pytrees of logical PartitionSpecs to shardings.
- flax.linen.with_logical_constraint(x, logical_axis_resources, rules=None, mesh=None, fallback=RulesFallback.AXIS_IS_UNSHARDED)[source]#
Version of jit’s with_sharding_constraint that uses logical axis names.
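A minimal sketch on a single device; 'batch' and 'embed' are illustrative logical names:

import jax
import jax.numpy as jnp
import numpy as np
import flax.linen as nn
from jax.sharding import Mesh

mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1), ('data', 'model'))
rules = (('batch', 'data'), ('embed', 'model'))

@jax.jit
def f(x):
  # Constrain x to be sharded along the mesh axes bound to the logical
  # names; names without a matching rule fall back to unsharded.
  return nn.with_logical_constraint(x, ('batch', 'embed'), rules=rules)

with mesh:
  y = f(jnp.ones((8, 16)))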
- flax.linen.with_logical_partitioning(fn, names, mesh=None, rules=None)[source]#
Wraps a function’s return value with LogicallyPartitioned.
Example:
>>> import flax.linen as nn
>>> kernel_init = nn.with_logical_partitioning(
...     nn.initializers.lecun_normal(), (None, "data"))
>>> partitioned_dense = nn.Dense(features=3, kernel_init=kernel_init)
- Parameters
fn – The function to be wrapped. Typically this is an initializer.
names – The logical axis passed to LogicallyPartitioned.
mesh – The mesh to use for the partitioning. If None, the global mesh resource is used if available.
rules – Optional logical to mesh rules to use. If None, the global rules are used if available.
- Returns
A function wrapping fn that will return an instance of LogicallyPartitioned.