Randomness and PRNGs in Flax#
In this guide, you will learn how Flax uses JAX’s explicit pseudorandom number generator (PRNG) keys to emulate randomness, and adds some additional features to make it easier for users to thread PRNG keys through different Flax Modules.
If you are new to JAX PRNG keys or need a refresher, check out the JAX documentation on pseudorandom numbers.
Setup#
Install or upgrade Flax, and then import some necessary dependencies.
Note: This guide uses the --xla_force_host_platform_device_count=8 flag to emulate multiple devices in a CPU environment in a Google Colab/Jupyter Notebook. You don’t need this flag if you are already using a multi-device environment, for example, a Cloud TPU on Google Cloud or a Kaggle VM with a TPU.
!pip install -q flax
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
import flax, flax.linen as nn
import jax, jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
import hashlib
jax.devices()
[CpuDevice(id=0),
CpuDevice(id=1),
CpuDevice(id=2),
CpuDevice(id=3),
CpuDevice(id=4),
CpuDevice(id=5),
CpuDevice(id=6),
CpuDevice(id=7)]
Set the JAX config variable jax_threefry_partitionable to True. This will be the default value in the future and makes the PRNG more efficiently auto-parallelizable under jax.jit. Refer to the JAX discussion for more details.
jax.config.update('jax_threefry_partitionable', True)
assert jax.config.jax_threefry_partitionable == True
assert jax.config.jax_default_prng_impl == 'threefry2x32'
Receiving, manipulating and creating PRNG keys with Module.make_rng#
The primary way Flax receives, manipulates and creates PRNG keys is via the Module method self.make_rng. It is a method that accepts a string name representing an “RNG stream”. Each RNG stream has an initial starting seed PRNG key, which the user passes in as a dictionary argument (i.e. into an .init or .apply function), and the starting seed is used by self.make_rng to generate more PRNG keys for that stream. If self.make_rng is called with a string name that does not have an initial starting seed PRNG key (i.e. the user did not pass a key with the corresponding name into .init or .apply), then self.make_rng will use the 'params' key as the initial starting seed by default.
Note that this method can only be called on bound modules (see The Flax Module lifecycle).
class RNGModule(nn.Module):
@nn.compact
def __call__(self):
print(self.make_rng('rng_stream'))
print(self.make_rng('rng_stream'))
print(self.make_rng('rng_stream'))
rng_module = RNGModule()
variables = rng_module.init({'rng_stream': jax.random.key(0)})
Array((), dtype=key<fry>) overlaying:
[1428664606 3351135085]
Array((), dtype=key<fry>) overlaying:
[3456700291 3873160899]
Array((), dtype=key<fry>) overlaying:
[2411773124 4124888837]
Now if we use a different starting seed PRNG key, we will generate different values (as intended).
variables = rng_module.init({'rng_stream': jax.random.key(1)})
Array((), dtype=key<fry>) overlaying:
[3077990774 2166202870]
Array((), dtype=key<fry>) overlaying:
[3825832496 2886313970]
Array((), dtype=key<fry>) overlaying:
[ 791337683 1373966058]
Calling self.make_rng for one stream will not affect the random values generated from another stream; i.e. the call order doesn’t matter.
class RNGModuleTwoStreams(nn.Module):
@nn.compact
def __call__(self):
# same value as first code snippet above
print(f"rng_stream1: {self.make_rng('rng_stream1')}")
# same value as second code snippet above
print(f"rng_stream2: {self.make_rng('rng_stream2')}")
# same value as first code snippet above
print(f"rng_stream1: {self.make_rng('rng_stream1')}")
# same value as second code snippet above
print(f"rng_stream2: {self.make_rng('rng_stream2')}")
# same value as first code snippet above
print(f"rng_stream1: {self.make_rng('rng_stream1')}")
# same value as second code snippet above
print(f"rng_stream2: {self.make_rng('rng_stream2')}")
rng_module_two_streams = RNGModuleTwoStreams()
variables = rng_module_two_streams.init(
{'rng_stream1': jax.random.key(0), 'rng_stream2': jax.random.key(1)}
)
rng_stream1: Array((), dtype=key<fry>) overlaying:
[1428664606 3351135085]
rng_stream2: Array((), dtype=key<fry>) overlaying:
[3077990774 2166202870]
rng_stream1: Array((), dtype=key<fry>) overlaying:
[3456700291 3873160899]
rng_stream2: Array((), dtype=key<fry>) overlaying:
[3825832496 2886313970]
rng_stream1: Array((), dtype=key<fry>) overlaying:
[2411773124 4124888837]
rng_stream2: Array((), dtype=key<fry>) overlaying:
[ 791337683 1373966058]
Providing the same seed PRNG key will result in the same values being generated (provided that the same operations are used for those keys).
variables = rng_module_two_streams.init(
{'rng_stream1': jax.random.key(0), 'rng_stream2': jax.random.key(0)}
)
rng_stream1: Array((), dtype=key<fry>) overlaying:
[1428664606 3351135085]
rng_stream2: Array((), dtype=key<fry>) overlaying:
[1428664606 3351135085]
rng_stream1: Array((), dtype=key<fry>) overlaying:
[3456700291 3873160899]
rng_stream2: Array((), dtype=key<fry>) overlaying:
[3456700291 3873160899]
rng_stream1: Array((), dtype=key<fry>) overlaying:
[2411773124 4124888837]
rng_stream2: Array((), dtype=key<fry>) overlaying:
[2411773124 4124888837]
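As noted at the start of this section, if self.make_rng is called for a stream name that was never given a seed key, it falls back to the 'params' key. Below is a minimal sketch of that fallback; the module name FallbackModule and the stream name 'noise' are introduced here only for illustration.
class FallbackModule(nn.Module):
  @nn.compact
  def __call__(self):
    # No 'noise' key is passed into .init below, so self.make_rng('noise')
    # falls back to the 'params' seed key.
    print(self.make_rng('noise'))

FallbackModule().init({'params': jax.random.key(0)})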
How self.make_rng works under the hood#
This is what happens when self.make_rng (flax.linen.Module.make_rng) is called:
1. The following data is collected:
- The path of the Module, as provided by self.scope.path (the top-level root module has an empty path ()).
- The self.make_rng call count, i.e. the number of times self.make_rng has been called for this specific stream (including this call). Note: each sub-Module has its own individual call count that is separate from other Modules. For example, a Module that has called self.make_rng('params') twice and contains a sub-Module that has called self.make_rng('params') once will have call counts of 2 and 1 for the 'params' RNG stream, respectively.
2. The data is bundled into a tuple and fed into a hash function, which produces an integer.
3. The generated integer is folded into the RNG stream’s starting seed PRNG key to generate a new, unique PRNG key.
Below is a slightly simplified version of the hash function that Flax uses for self.make_rng:
def produce_hash(data):
m = hashlib.sha1()
for x in data:
if isinstance(x, str):
m.update(x.encode('utf-8'))
elif isinstance(x, int):
m.update(x.to_bytes((x.bit_length() + 7) // 8, byteorder='big'))
else:
raise ValueError(f'Expected int or string, got: {x}')
d = m.digest()
hash_int = int.from_bytes(d[:4], byteorder='big')
return hash_int
And now you can manually reproduce the PRNG keys generated from self.make_rng:
stream_seed = jax.random.key(0)
for call_count in range(1, 4):
hash_int = produce_hash(data=(call_count,))
print(jax.random.fold_in(stream_seed, jnp.uint32(hash_int)))
Array((), dtype=key<fry>) overlaying:
[1428664606 3351135085]
Array((), dtype=key<fry>) overlaying:
[3456700291 3873160899]
Array((), dtype=key<fry>) overlaying:
[2411773124 4124888837]
Compare this with the keys produced by the module itself:
variables = rng_module.init({'rng_stream': jax.random.key(0)})
Array((), dtype=key<fry>) overlaying:
[1428664606 3351135085]
Array((), dtype=key<fry>) overlaying:
[3456700291 3873160899]
Array((), dtype=key<fry>) overlaying:
[2411773124 4124888837]
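To make the equivalence explicit, here is a small sketch (the CheckRNG module name is introduced only for this check) that compares the manually folded keys against the ones returned by self.make_rng, using jax.random.key_data to compare the underlying key arrays:
class CheckRNG(nn.Module):
  @nn.compact
  def __call__(self):
    # Return the first three keys of the 'rng_stream' stream.
    return [self.make_rng('rng_stream') for _ in range(3)]

module_keys = CheckRNG().apply({}, rngs={'rng_stream': jax.random.key(0)})
for call_count, module_key in enumerate(module_keys, start=1):
  manual_key = jax.random.fold_in(
      jax.random.key(0), jnp.uint32(produce_hash(data=(call_count,))))
  # The manually reproduced key matches the key generated by make_rng.
  assert (jax.random.key_data(manual_key) == jax.random.key_data(module_key)).all()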
Sub-Modules and self.make_rng#
This section explores how self.make_rng (flax.linen.Module.make_rng) behaves with sub-Modules.
Consider the following example:
class RNGSubSubModule(nn.Module):
def __call__(self):
print(f"{self.name}, count 1: {self.make_rng('rng_stream')}")
print(f"{self.name}, count 2: {self.make_rng('rng_stream')}")
class RNGSubModule(nn.Module):
@nn.compact
def __call__(self):
print(f"{self.name}, count 1: {self.make_rng('rng_stream')}")
print(f"{self.name}, count 2: {self.make_rng('rng_stream')}")
RNGSubSubModule()()
class RNGModule(nn.Module):
@nn.compact
def __call__(self):
print(f"RNGModule, count 1: {self.make_rng('rng_stream')}")
print(f"RNGModule, count 2: {self.make_rng('rng_stream')}")
RNGSubModule()()
rng_module = RNGModule()
variables = rng_module.init({'rng_stream': jax.random.key(0)})
RNGModule, count 1: Array((), dtype=key<fry>) overlaying:
[1428664606 3351135085]
RNGModule, count 2: Array((), dtype=key<fry>) overlaying:
[3456700291 3873160899]
RNGSubModule_0, count 1: Array((), dtype=key<fry>) overlaying:
[3858825717 2323087578]
RNGSubModule_0, count 2: Array((), dtype=key<fry>) overlaying:
[ 601859108 3782857444]
RNGSubSubModule_0, count 1: Array((), dtype=key<fry>) overlaying:
[ 234240654 1028548813]
RNGSubSubModule_0, count 2: Array((), dtype=key<fry>) overlaying:
[3650462303 2124609379]
As previously discussed, the data that is fed into the Flax hash function consists of:
- the path of the Module, provided by self.scope.path (the top-level root module has an empty path ()); and
- the call count for the specific RNG stream.
In addition, note that each Flax Module and sub-Module has its own individual call count, even for the same RNG stream. The naming convention for sub-Modules is f'{module_name}_{module_number}'. For example, the first Dense sub-Module will be called Dense_0, the second one will be called Dense_1, and so on.
Therefore, the following data will be fed into the hash function:
- For RNGModule: the data is just the call count, such as (1,) and (2,), since the root Module has an empty path.
- For RNGSubModule: the data is ('RNGSubModule_0', 1) and ('RNGSubModule_0', 2).
- For RNGSubSubModule: the data is ('RNGSubModule_0', 'RNGSubSubModule_0', 1) and ('RNGSubModule_0', 'RNGSubSubModule_0', 2).
With this data, you can manually reproduce the PRNG keys generated by the Module and sub-Modules via self.make_rng. For example:
stream_seed = jax.random.key(0)
for initial_data in ((), ('RNGSubModule_0',), ('RNGSubModule_0', 'RNGSubSubModule_0')):
if initial_data:
module_name = initial_data[-1]
else:
module_name = 'RNGModule'
for call_count in (1, 2):
hash_int = produce_hash(data=initial_data+(call_count,))
rng_key = jax.random.fold_in(stream_seed, jnp.uint32(hash_int))
print(f"{module_name}, count {call_count}: {rng_key}")
RNGModule, count 1: Array((), dtype=key<fry>) overlaying:
[1428664606 3351135085]
RNGModule, count 2: Array((), dtype=key<fry>) overlaying:
[3456700291 3873160899]
RNGSubModule_0, count 1: Array((), dtype=key<fry>) overlaying:
[3858825717 2323087578]
RNGSubModule_0, count 2: Array((), dtype=key<fry>) overlaying:
[ 601859108 3782857444]
RNGSubSubModule_0, count 1: Array((), dtype=key<fry>) overlaying:
[ 234240654 1028548813]
RNGSubSubModule_0, count 2: Array((), dtype=key<fry>) overlaying:
[3650462303 2124609379]
If the same sub-Module class is used multiple times, you can increment the suffix of the sub-Module name accordingly, for example: RNGSubModule_0, RNGSubModule_1, and so on.
class RNGSubModule(nn.Module):
@nn.compact
def __call__(self):
print(f"{self.name}, count 1: {self.make_rng('rng_stream')}")
print(f"{self.name}, count 2: {self.make_rng('rng_stream')}")
class RNGModule(nn.Module):
@nn.compact
def __call__(self):
print(f"RNGModule, count 1: {self.make_rng('rng_stream')}")
print(f"RNGModule, count 2: {self.make_rng('rng_stream')}")
RNGSubModule()()
RNGSubModule()()
rng_module = RNGModule()
variables = rng_module.init({'rng_stream': jax.random.key(0)})
RNGModule, count 1: Array((), dtype=key<fry>) overlaying:
[1428664606 3351135085]
RNGModule, count 2: Array((), dtype=key<fry>) overlaying:
[3456700291 3873160899]
RNGSubModule_0, count 1: Array((), dtype=key<fry>) overlaying:
[3858825717 2323087578]
RNGSubModule_0, count 2: Array((), dtype=key<fry>) overlaying:
[ 601859108 3782857444]
RNGSubModule_1, count 1: Array((), dtype=key<fry>) overlaying:
[ 426957352 2006350344]
RNGSubModule_1, count 2: Array((), dtype=key<fry>) overlaying:
[4006253729 4205356731]
stream_seed = jax.random.key(0)
for initial_data in ((), ('RNGSubModule_0',), ('RNGSubModule_1',)):
if initial_data:
module_name = initial_data[-1]
else:
module_name = 'RNGModule'
for call_count in (1, 2):
hash_int = produce_hash(data=initial_data+(call_count,))
rng_key = jax.random.fold_in(stream_seed, jnp.uint32(hash_int))
print(f"{module_name}, count {call_count}: {rng_key}")
RNGModule, count 1: Array((), dtype=key<fry>) overlaying:
[1428664606 3351135085]
RNGModule, count 2: Array((), dtype=key<fry>) overlaying:
[3456700291 3873160899]
RNGSubModule_0, count 1: Array((), dtype=key<fry>) overlaying:
[3858825717 2323087578]
RNGSubModule_0, count 2: Array((), dtype=key<fry>) overlaying:
[ 601859108 3782857444]
RNGSubModule_1, count 1: Array((), dtype=key<fry>) overlaying:
[ 426957352 2006350344]
RNGSubModule_1, count 2: Array((), dtype=key<fry>) overlaying:
[4006253729 4205356731]
Using self.param and self.variable#
Flax users have the option of creating additional parameters and variables in their modules by using the self.param and self.variable Module methods. An init_fn argument must be passed to these methods so that it can generate the initial value of the parameter/variable. self.make_rng is commonly used, implicitly or explicitly, in this init_fn, since many initializer functions are stochastic in nature and require a PRNG key. Refer to the Flax initializers API documentation for the full list.
There are a couple of differences between the two methods that the user should take note of:
- self.param always creates a parameter in the 'params' collection, whereas self.variable creates a variable in whatever collection the user specifies.
- self.param automatically calls self.make_rng('params') and implicitly passes the generated PRNG key to the init_fn of the parameter you instantiated (it is passed in as the first argument), whereas for self.variable the user has to manually specify which RNG stream to call self.make_rng on in its init_fn (it could be 'params' or something different).
Below is an example using both self.param and self.variable:
class Model(nn.Module):
@nn.compact
def __call__(self, x):
# kernel will use 'params' seed, initial data will include 'Dense_0', call count 1
x = nn.Dense(2, kernel_init=jax.random.normal, use_bias=False)(x)
# model_param will use 'params' seed, call count 1
model_param = self.param('model_param', jax.random.normal, x.shape)
# model_variable1 will use 'params' seed, call count 2
model_variable1 = self.variable(
'other_collection',
'model_variable1',
lambda: jax.random.normal(self.make_rng('params'), x.shape),
)
# model_variable2 will use 'other' seed, call count 1
model_variable2 = self.variable(
'other_collection',
'model_variable2',
lambda: jax.random.normal(self.make_rng('other'), x.shape),
)
# kernel will use 'params' seed, initial data will include 'Dense_1', call count 1
# bias will use 'params' seed, initial data will include 'Dense_1', call count 2
x = nn.Dense(2, kernel_init=jax.random.normal, bias_init=jax.random.normal)(
x
)
return x
model = Model()
variables = model.init(
{'params': jax.random.key(0), 'other': jax.random.key(1)}, jnp.ones((2, 2))
)
print(variables['params']['Dense_0']['kernel'])
print(variables['params']['model_param'])
print(variables['other_collection']['model_variable1'])
print(variables['other_collection']['model_variable2'])
print(variables['params']['Dense_1']['kernel'])
print(variables['params']['Dense_1']['bias'])
[[-1.6185919 0.700908 ]
[-1.3146383 -0.79342234]]
[[ 0.0761425 -1.6157459]
[-1.6857724 0.7126891]]
[[ 0.60175574 0.2553228 ]
[ 0.27367848 -2.1975214 ]]
[[1.6249592 0.30813068]
[1.6613585 1.0404155 ]]
[[ 0.0030665 0.29551846]
[ 0.16670242 -0.78252524]]
[1.582462 0.15216611]
Remember:
- There is a separate count for each RNG stream; this is why the count for self.make_rng('other') starts at 1 even though there were earlier calls of self.make_rng('params').
- Each submodule has its own separate count for each RNG stream; this is why each Dense layer has its own separate count for self.make_rng('params'), and why model_param and model_variable1 share the same count (since they are defined within the same top-level parent module).
params_seed = jax.random.key(0)
other_seed = jax.random.key(1)
for initial_data, count, seed, shape in (
(('Dense_0',), 1, params_seed, (2, 2)),
((), 1, params_seed, (2, 2)),
((), 2, params_seed, (2, 2)),
((), 1, other_seed, (2, 2)),
(('Dense_1',), 1, params_seed, (2, 2)),
(('Dense_1',), 2, params_seed, (1, 2)),
):
hash_int = produce_hash(data=(*initial_data, count))
rng_key = jax.random.fold_in(seed, jnp.uint32(hash_int))
print(jax.random.normal(rng_key, shape))
[[-1.6185919 0.700908 ]
[-1.3146383 -0.79342234]]
[[ 0.0761425 -1.6157459]
[-1.6857724 0.7126891]]
[[ 0.60175574 0.2553228 ]
[ 0.27367848 -2.1975214 ]]
[[1.6249592 0.30813068]
[1.6613585 1.0404155 ]]
[[ 0.0030665 0.29551846]
[ 0.16670242 -0.78252524]]
[[1.582462 0.15216611]]
Managing RNG streams inside a training loop#
Below is an example of managing the RNG streams from self.make_rng, self.param, self.variable and nn.Dropout in a training loop (note: nn.Dropout requires a seed PRNG key to be passed in for the 'dropout' RNG stream, since it implicitly calls self.make_rng('dropout')):
class SubModule(nn.Module):
@nn.compact
def __call__(self, x, train):
# variables created using `self.param` will use `self.make_rng('params')`
kernel = self.param('submodule_kernel', jax.random.normal, x.shape)
x = x + kernel
# `nn.Dropout` will use self.make_rng('dropout')
x = nn.Dropout(0.2)(x, deterministic=not train)
# `nn.Dense` will use self.make_rng('params')
x = nn.Dense(3)(x)
return x
class Model(nn.Module):
@nn.compact
def __call__(self, x, train):
# make kernel use `self.make_rng('other')`
kernel = self.variable(
'other_collection',
'module_kernel',
lambda: jax.random.normal(self.make_rng('other'), x.shape),
)
x = (
x + kernel.value
) # `.value` will extract the underlying value of the variable
x = SubModule()(x, train)
# `nn.Dropout` will use self.make_rng('dropout')
x = nn.Dropout(0.2)(x, deterministic=not train)
# `nn.Dense` will use self.make_rng('params')
x = nn.Dense(2)(x)
return x
params_rng, other_rng, train_rng = jax.random.split(jax.random.key(0), 3)
init_rngs = {'params': params_rng, 'other': other_rng}
x = jnp.ones((1, 3))
y = jnp.ones((1, 2))
module = Model()
variables = module.init(init_rngs, x, train=False)
def update(variables, rng):
# we don't need to provide a 'params' or 'other' rng, as only 'dropout' rng will be used during training
# split the rng to get a dropout_rng to be used for this training iteration,
# and to get another rng key to be used for the next training iteration
dropout_rng, next_rng = jax.random.split(rng)
def loss(params):
out = module.apply(
{'params': params, 'other_collection': variables['other_collection']},
x,
train=True,
rngs={'dropout': dropout_rng},
)
return jnp.mean((y - out) ** 2)
grads = jax.grad(loss)(variables['params'])
params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, variables['params'], grads)
return {
'params': params,
'other_collection': variables['other_collection'],
}, next_rng
for _ in range(10):
variables, train_rng = update(variables, train_rng)
out = module.apply(variables, x, train=False)
print(jnp.mean((y - out)**2))
2.518454
2.4859657
2.4171872
2.412684
2.3435805
2.2773488
2.2592616
2.2009292
2.1839895
2.1707344
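As an aside (this pattern is not part of the example above, just a common alternative), you can avoid threading next_rng through the loop by deriving each step's dropout key from a fixed base key with jax.random.fold_in, the same function the guide uses to reproduce make_rng keys:
# Sketch of an alternative: derive a per-step dropout key from a fixed base key.
base_dropout_rng = jax.random.key(42)  # arbitrary seed, for illustration only
for step in range(10):
  step_dropout_rng = jax.random.fold_in(base_dropout_rng, step)
  # ...pass {'dropout': step_dropout_rng} to module.apply, as in update() above...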
🔪 Sharp edge 🔪 - unintentionally generating the same values#
There is an edge case where the same value can be unintentionally generated. See the Flax issue for more details.
class Leaf(nn.Module):
def __call__(self, x):
return x + jax.random.randint(self.make_rng("rng"), (), 0, 100)
class Node(nn.Module):
leaf_name: str
@nn.compact
def __call__(self, x):
return Leaf(name=self.leaf_name)(x)
class Model(nn.Module):
@nn.compact
def __call__(self, x):
return (Node(name="ab", leaf_name="cdef")(x),
Node(name="abc", leaf_name="def")(x),
)
out1, out2 = Model().apply({}, 0, rngs={"rng": jax.random.key(33)})
out1 == out2 # same output, despite having different submodule names
Array(True, dtype=bool)
This occurs because the hash function concatenates the strings together, so the data ('AB', 'C') is equivalent to the data ('A', 'BC') when fed into the hash function, therefore producing the same hash int.
print(produce_hash(data=('A', 'B', 'C', 1)))
print(produce_hash(data=('AB', 'C', 1)))
print(produce_hash(data=('A', 'BC', 1)))
print(produce_hash(data=('ABC', 1)))
947574064
947574064
947574064
947574064
To avoid this edge case, users can flip the flax_fix_rng_separator configuration flag to True.
flax.config.update('flax_fix_rng_separator', True)
out1, out2 = Model().apply({}, 0, rngs={"rng": jax.random.key(33)})
out1 == out2 # different output
Array(False, dtype=bool)
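Conceptually, the flag avoids the collision by separating the path components before they are hashed, so that ('AB', 'C') and ('A', 'BC') no longer produce the same digest. The sketch below illustrates the idea with an explicit separator byte; it is an assumption for illustration, not Flax's exact implementation:
def produce_hash_with_separator(data, separator=b'\x00'):
  # Illustrative only: appending a separator after each component means
  # ('AB', 'C') and ('A', 'BC') no longer hash to the same integer.
  m = hashlib.sha1()
  for x in data:
    if isinstance(x, str):
      m.update(x.encode('utf-8'))
    elif isinstance(x, int):
      m.update(x.to_bytes((x.bit_length() + 7) // 8, byteorder='big'))
    else:
      raise ValueError(f'Expected int or string, got: {x}')
    m.update(separator)
  return int.from_bytes(m.digest()[:4], byteorder='big')

print(produce_hash_with_separator(data=('AB', 'C', 1)))
print(produce_hash_with_separator(data=('A', 'BC', 1)))  # different from the line above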
Managing RNGs on multiple devices#
This section will show examples of how to use jit and shard_map with RNGs in multi-device settings.
Using jax.jit#
When using jax.jit, we can use RNGs as we did before, but we now include the in_shardings and out_shardings arguments to specify how to shard the input and output data.
For more details on training on multiple devices in Flax using jax.jit, see our Scale up Flax Modules on multiple devices guide and the lm1b example.
# Create a mesh and annotate the axis with a name.
device_mesh = mesh_utils.create_device_mesh((8,))
print(device_mesh)
mesh = Mesh(devices=device_mesh, axis_names=('data',))
print(mesh)
data_sharding = NamedSharding(mesh, PartitionSpec('data',))
print(data_sharding)
[CpuDevice(id=0) CpuDevice(id=1) CpuDevice(id=2) CpuDevice(id=3)
CpuDevice(id=4) CpuDevice(id=5) CpuDevice(id=6) CpuDevice(id=7)]
Mesh('data': 8)
NamedSharding(mesh=Mesh('data': 8), spec=PartitionSpec('data',), memory_kind=unpinned_host)
class Model(nn.Module):
@nn.compact
def __call__(self, x, add_noise):
x = nn.Dense(1)(x)
# use jnp.where for control flow; for more details see: https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
return jnp.where(
add_noise, x + jax.random.normal(self.make_rng('params'), x.shape), x
)
module = Model()
init_rng, apply_rng = jax.random.split(jax.random.key(0))
x = jnp.ones((8, 1))
variables = module.init(init_rng, x, False)
# create custom forward function, since jit does not support kwargs when in_shardings is specified
def forward(variables, x, add_noise, rng):
return module.apply(variables, x, add_noise, rngs={'params': rng})
# shard the inputs x across devices
# replicate the variables, add_noise boolean and rng key across devices
# shard the output across devices
jit_forward = jax.jit(
forward,
in_shardings=(None, data_sharding, None, None),
out_shardings=data_sharding,
)
out = jit_forward(variables, x, True, apply_rng)
out
Array([[-2.2187614 ],
[-2.8055234 ],
[-2.5464184 ],
[ 1.027039 ],
[-3.5243359 ],
[-2.2795477 ],
[-0.6504516 ],
[ 0.17373265]], dtype=float32)
The output values differ across rows even though the input rows are identical, meaning the RNG key was used to add noise to the output.
We can also confirm that the output is sharded across devices:
out.addressable_shards
[Shard(device=CpuDevice(id=0), index=(slice(0, 1, None), slice(None, None, None)), replica_id=0, data=[[-2.2187614]]),
Shard(device=CpuDevice(id=1), index=(slice(1, 2, None), slice(None, None, None)), replica_id=0, data=[[-2.8055234]]),
Shard(device=CpuDevice(id=2), index=(slice(2, 3, None), slice(None, None, None)), replica_id=0, data=[[-2.5464184]]),
Shard(device=CpuDevice(id=3), index=(slice(3, 4, None), slice(None, None, None)), replica_id=0, data=[[1.027039]]),
Shard(device=CpuDevice(id=4), index=(slice(4, 5, None), slice(None, None, None)), replica_id=0, data=[[-3.5243359]]),
Shard(device=CpuDevice(id=5), index=(slice(5, 6, None), slice(None, None, None)), replica_id=0, data=[[-2.2795477]]),
Shard(device=CpuDevice(id=6), index=(slice(6, 7, None), slice(None, None, None)), replica_id=0, data=[[-0.6504516]]),
Shard(device=CpuDevice(id=7), index=(slice(7, 8, None), slice(None, None, None)), replica_id=0, data=[[0.17373265]])]
Another way to visualize the output sharding:
jax.debug.visualize_array_sharding(out)
CPU 0 CPU 1 CPU 2 CPU 3 CPU 4 CPU 5 CPU 6 CPU 7
If we choose not to add noise, then the output is the same for every batch element (as expected, since the input is the same for each element):
out = jit_forward(variables, x, False, apply_rng)
out
Array([[-1.2839764],
[-1.2839764],
[-1.2839764],
[-1.2839764],
[-1.2839764],
[-1.2839764],
[-1.2839764],
[-1.2839764]], dtype=float32)
We can confirm the un-jitted function produces the same values, albeit unsharded (note there may be small numerical differences due to compiler optimizations from jitting):
out = forward(variables, x, True, apply_rng)
out
Array([[-2.2187614 ],
[-2.8055234 ],
[-2.5464187 ],
[ 1.0270392 ],
[-3.5243359 ],
[-2.2795477 ],
[-0.6504516 ],
[ 0.17373264]], dtype=float32)
out = forward(variables, x, False, apply_rng)
out
Array([[-1.2839764],
[-1.2839764],
[-1.2839764],
[-1.2839764],
[-1.2839764],
[-1.2839764],
[-1.2839764],
[-1.2839764]], dtype=float32)
Using shard_map#
When using jax.experimental.shard_map.shard_map, the important parts to remember are:
- Split your PRNG key to produce a different key for each device.
- The PRNG keys will be sharded automatically across the devices (provided you use the correct partition specification), but the rank of the original batched PRNG key array will not be reduced; e.g. with a batch of 8 PRNG keys and 8 devices, each device will see a PRNG key batch of size 1 within the shard_map-ed function.
- Therefore, to access the PRNG key itself, we need to index slice into it (see the example below).
def forward(variables, x, add_noise, rng_key_batch):
# rng_key_batch is a batch of size 1 containing 1 PRNG key
# index slice into the rng_key_batch to access the PRNG key
return module.apply(
variables, x, add_noise, rngs={'params': rng_key_batch[0]}
)
# define partition specifications
data_pspec = PartitionSpec('data')
no_pspec = PartitionSpec()
# shard the inputs x and rng keys across devices
# replicate the variables and add_noise boolean across devices
# shard the output across devices
shmap_forward = shard_map(
forward,
mesh=mesh,
in_specs=(no_pspec, data_pspec, no_pspec, data_pspec),
out_specs=data_pspec,
)
# get 8 different rng's that will be used by the 8 devices when doing forward inference
apply_rngs = jax.random.split(apply_rng, 8)
out = shmap_forward(variables, x, True, apply_rngs)
out
Array([[-1.2605132 ],
[-1.2405176 ],
[-0.99350417],
[-1.0277128 ],
[-1.4154483 ],
[-0.3905797 ],
[-2.417677 ],
[ 0.9023453 ]], dtype=float32)
Confirm that the output is sharded across devices:
out.addressable_shards
[Shard(device=CpuDevice(id=0), index=(slice(0, 1, None), slice(None, None, None)), replica_id=0, data=[[-1.2605132]]),
Shard(device=CpuDevice(id=1), index=(slice(1, 2, None), slice(None, None, None)), replica_id=0, data=[[-1.2405176]]),
Shard(device=CpuDevice(id=2), index=(slice(2, 3, None), slice(None, None, None)), replica_id=0, data=[[-0.99350417]]),
Shard(device=CpuDevice(id=3), index=(slice(3, 4, None), slice(None, None, None)), replica_id=0, data=[[-1.0277128]]),
Shard(device=CpuDevice(id=4), index=(slice(4, 5, None), slice(None, None, None)), replica_id=0, data=[[-1.4154483]]),
Shard(device=CpuDevice(id=5), index=(slice(5, 6, None), slice(None, None, None)), replica_id=0, data=[[-0.3905797]]),
Shard(device=CpuDevice(id=6), index=(slice(6, 7, None), slice(None, None, None)), replica_id=0, data=[[-2.417677]]),
Shard(device=CpuDevice(id=7), index=(slice(7, 8, None), slice(None, None, None)), replica_id=0, data=[[0.9023453]])]
jax.debug.visualize_array_sharding(out)
CPU 0 CPU 1 CPU 2 CPU 3 CPU 4 CPU 5 CPU 6 CPU 7
Lifted transforms#
Flax lifted transforms allow you to use JAX transforms with Module arguments. This section will show you how to control how PRNG keys are split in Flax lifted transforms.
Refer to Lifted transformations for more detail.
nn.vmap#
We can use nn.vmap to create a batched Dense layer:
x = jnp.ones((3, 2))
BatchDense = nn.vmap(
nn.Dense,
in_axes=0, out_axes=0,
variable_axes={'params': None},
split_rngs={'params': False})
BatchDense(2).init(jax.random.key(0), x)
{'params': {'bias': Array([0., 0.], dtype=float32),
'kernel': Array([[-1.2488099 , -0.6127134 ],
[-0.07084481, 0.60130936]], dtype=float32)}}
By denoting variable_axes={'params': 0}, we vectorize the params Arrays on the first axis. However, the parameter values generated are all identical to each other:
BatchDense = nn.vmap(
nn.Dense,
in_axes=0, out_axes=0,
variable_axes={'params': 0},
split_rngs={'params': False})
BatchDense(2).init(jax.random.key(0), x)
{'params': {'bias': Array([[0., 0.],
[0., 0.],
[0., 0.]], dtype=float32),
'kernel': Array([[[-1.2488099 , -0.6127134 ],
[-0.07084481, 0.60130936]],
[[-1.2488099 , -0.6127134 ],
[-0.07084481, 0.60130936]],
[[-1.2488099 , -0.6127134 ],
[-0.07084481, 0.60130936]]], dtype=float32)}}
If we also set split_rngs={'params': True}, then the PRNG key we provide is split across the variable axis (in this case, the batch axis 0), and we can generate different parameters for each batch input:
BatchDense = nn.vmap(
nn.Dense,
in_axes=0, out_axes=0,
variable_axes={'params': 0},
split_rngs={'params': True})
BatchDense(2).init(jax.random.key(0), x)
{'params': {'bias': Array([[0., 0.],
[0., 0.],
[0., 0.]], dtype=float32),
'kernel': Array([[[-0.2526208 , -0.15088455],
[-1.1987205 , -0.40843305]],
[[-0.7064888 , -1.108805 ],
[-0.938775 , 1.4812315 ]],
[[-0.59468937, -0.2502723 ],
[-1.33515 , 0.5067442 ]]], dtype=float32)}}
Adding a variable via self.variable is straightforward:
class Model(nn.Module):
@nn.compact
def __call__(self, x):
x = nn.Dense(2)(x)
kernel = self.variable(
'other_collection',
'kernel',
lambda: jax.random.normal(self.make_rng('other'), x.shape),
)
return x + kernel.value
BatchModel = nn.vmap(
Model,
in_axes=0,
out_axes=0,
variable_axes={'params': 0, 'other_collection': 0},
split_rngs={'params': True, 'other': True},
)
BatchModel().init({'params': jax.random.key(0), 'other': jax.random.key(1)}, x)
{'params': {'Dense_0': {'bias': Array([[0., 0.],
[0., 0.],
[0., 0.]], dtype=float32),
'kernel': Array([[[-0.9079084 , 0.76390624],
[-0.01285526, 0.4320353 ]],
[[ 0.12398645, 0.7884565 ],
[ 1.5344163 , 1.3186085 ]],
[[-0.44171348, 0.43430036],
[-0.40732604, 0.29774475]]], dtype=float32)}},
'other_collection': {'kernel': Array([[-0.8193048 , 0.711106 ],
[-0.37802765, -0.66705877],
[-0.44808003, 0.93031347]], dtype=float32)}}
We can control which RNG streams to split; for example, if we only want to split the 'params' RNG stream, then the variables generated from self.variable will be the same for each batch input:
BatchModel = nn.vmap(
Model,
in_axes=0, out_axes=0,
variable_axes={'params': 0, 'other_collection': 0},
split_rngs={'params': True, 'other': False})
BatchModel().init({'params': jax.random.key(0), 'other': jax.random.key(1)}, x)
{'params': {'Dense_0': {'bias': Array([[0., 0.],
[0., 0.],
[0., 0.]], dtype=float32),
'kernel': Array([[[-0.9079084 , 0.76390624],
[-0.01285526, 0.4320353 ]],
[[ 0.12398645, 0.7884565 ],
[ 1.5344163 , 1.3186085 ]],
[[-0.44171348, 0.43430036],
[-0.40732604, 0.29774475]]], dtype=float32)}},
'other_collection': {'kernel': Array([[ 0.44956833, -1.1854612 ],
[ 0.44956833, -1.1854612 ],
[ 0.44956833, -1.1854612 ]], dtype=float32)}}
We can also control which parameters/variables should be generated for each batch input; for example, if we only want 'params' to generate separate parameters for each batch input:
BatchModel = nn.vmap(
Model,
in_axes=0, out_axes=0,
variable_axes={'params': 0, 'other_collection': None},
split_rngs={'params': True, 'other': False})
BatchModel().init({'params': jax.random.key(0), 'other': jax.random.key(1)}, x)
{'params': {'Dense_0': {'bias': Array([[0., 0.],
[0., 0.],
[0., 0.]], dtype=float32),
'kernel': Array([[[-0.9079084 , 0.76390624],
[-0.01285526, 0.4320353 ]],
[[ 0.12398645, 0.7884565 ],
[ 1.5344163 , 1.3186085 ]],
[[-0.44171348, 0.43430036],
[-0.40732604, 0.29774475]]], dtype=float32)}},
'other_collection': {'kernel': Array([ 0.44956833, -1.1854612 ], dtype=float32)}}
nn.scan#
We can use nn.scan to create a scanned Module layer (this is useful for simplifying repetitively stacked submodules):
x = jnp.ones((3, 2))
class ResidualMLPBlock(nn.Module):
@nn.compact
def __call__(self, x, _):
h = nn.Dense(features=2)(x)
h = nn.relu(h)
return x + h, None # return an empty carry
ScanMLP = nn.scan(
ResidualMLPBlock, variable_axes={'params': 0},
variable_broadcast=False, split_rngs={'params': True},
length=3)
ScanMLP().init(jax.random.key(0), x, None) # pass in an empty carry
{'params': {'Dense_0': {'bias': Array([[0., 0.],
[0., 0.],
[0., 0.]], dtype=float32),
'kernel': Array([[[-0.07838312, -0.7422982 ],
[ 0.87488323, 0.13773395]],
[[ 0.97309333, 0.9087693 ],
[-0.12564984, -1.0920651 ]],
[[-0.99055105, 1.1499453 ],
[-0.15721127, -0.62520015]]], dtype=float32)}}}
Similar to before, we can control whether to split the RNG stream or not; for example, if we want all the stacked modules to be initialized to the same parameter values, we can pass in split_rngs={'params': False}:
ScanMLP = nn.scan(
ResidualMLPBlock, variable_axes={'params': 0},
variable_broadcast=False, split_rngs={'params': False},
length=3)
ScanMLP().init(jax.random.key(0), x, None)
{'params': {'Dense_0': {'bias': Array([[0., 0.],
[0., 0.],
[0., 0.]], dtype=float32),
'kernel': Array([[[-0.66715515, -0.0484313 ],
[ 0.9867164 , 0.75408363]],
[[-0.66715515, -0.0484313 ],
[ 0.9867164 , 0.75408363]],
[[-0.66715515, -0.0484313 ],
[ 0.9867164 , 0.75408363]]], dtype=float32)}}}
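As a quick follow-up (a sketch, not part of the original example), the scanned stack can be run forward by applying it with the initialized variables; the first return value is the final carry after the three stacked blocks, and the second is the (empty) scan output:
variables = ScanMLP().init(jax.random.key(0), x, None)
y, _ = ScanMLP().apply(variables, x, None)
print(y.shape)  # (3, 2): x after passing through the 3 stacked residual blocks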