Layers#
Linear Modules#
- class flax.linen.Dense(features, use_bias=True, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, promote_dtype=<function promote_dtype>, dot_general=None, dot_general_cls=None, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
A linear transformation applied over the last dimension of the input.
Example usage:
>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> layer = nn.Dense(features=4)
>>> params = layer.init(jax.random.key(0), jnp.ones((1, 3)))
>>> jax.tree_util.tree_map(jnp.shape, params)
{'params': {'bias': (4,), 'kernel': (3, 4)}}
- features#
the number of output features.
- Type:
int
- use_bias#
whether to add a bias to the output (default: True).
- Type:
bool
- dtype#
the dtype of the computation (default: infer from input and params).
- Type:
str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType | Any | None
- param_dtype#
the dtype passed to parameter initializers (default: float32).
- Type:
str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType | Any
- precision#
numerical precision of the computation; see jax.lax.Precision for details.
- Type:
None | str | jax._src.lax.lax.Precision | tuple[str, str] | tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]
- kernel_init#
initializer function for the weight matrix.
- Type:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]
- bias_init#
initializer function for the bias.
- Type:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]
- promote_dtype#
function to promote the dtype of the arrays to the desired dtype. The function should accept a tuple of (inputs, kernel, bias) and a dtype keyword argument, and return a tuple of arrays with the promoted dtype.
- Type:
flax.linen.linear.PromoteDtypeFn
- __call__(inputs)[source]#
Applies a linear transformation to the inputs along the last dimension.
- Parameters:
inputs – The nd-array to be transformed.
- Returns:
The transformed input.
- class flax.linen.DenseGeneral(features, axis=-1, batch_dims=(), use_bias=True, dtype=None, param_dtype=<class 'jax.numpy.float32'>, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, precision=None, promote_dtype=<function promote_dtype>, dot_general=None, dot_general_cls=None, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
A linear transformation with flexible axes.
Example usage:
>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> # equivalent to `nn.Dense(features=4)`
>>> layer = nn.DenseGeneral(features=4)
>>> # output features (4, 5)
>>> layer = nn.DenseGeneral(features=(4, 5))
>>> params = layer.init(jax.random.key(0), jnp.ones((1, 3)))
>>> jax.tree_util.tree_map(jnp.shape, params)
{'params': {'bias': (4, 5), 'kernel': (3, 4, 5)}}
>>> # apply transformation on the second and last axes
>>> layer = nn.DenseGeneral(features=(4, 5), axis=(1, -1))
>>> params = layer.init(jax.random.key(0), jnp.ones((1, 3, 6, 7)))
>>> jax.tree_util.tree_map(jnp.shape, params)
{'params': {'bias': (4, 5), 'kernel': (3, 7, 4, 5)}}
- features#
int or tuple with number of output features.
- Type:
int | collections.abc.Sequence[int]
- axis#
int or tuple with axes to apply the transformation on. For instance, (-2, -1) will apply the transformation to the last two axes.
- Type:
int | collections.abc.Sequence[int]
- batch_dims#
tuple with batch axes.
- Type:
collections.abc.Sequence[int]
- use_bias#
whether to add a bias to the output (default: True).
- Type:
bool
- dtype#
the dtype of the computation (default: infer from input and params).
- Type:
str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType | Any | None
- param_dtype#
the dtype passed to parameter initializers (default: float32).
- Type:
str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType | Any
- kernel_init#
initializer function for the weight matrix.
- Type:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]
- bias_init#
initializer function for the bias.
- Type:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]
- precision#
numerical precision of the computation; see jax.lax.Precision for details.
- Type:
None | str | jax._src.lax.lax.Precision | tuple[str, str] | tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]
- promote_dtype#
function to promote the dtype of the arrays to the desired dtype. The function should accept a tuple of (inputs, kernel, bias) and a dtype keyword argument, and return a tuple of arrays with the promoted dtype.
- Type:
flax.linen.linear.PromoteDtypeFn
- __call__(inputs)[source]#
Applies a linear transformation to the inputs along multiple dimensions.
- Parameters:
inputs – The nd-array to be transformed.
- Returns:
The transformed input.
- class flax.linen.Conv(features, kernel_size, strides=1, padding='SAME', input_dilation=1, kernel_dilation=1, feature_group_count=1, use_bias=True, mask=None, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, promote_dtype=<function promote_dtype>, conv_general_dilated=None, conv_general_dilated_cls=None, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
Convolution Module wrapping lax.conv_general_dilated.
Example usage:
>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> # valid padding
>>> layer = nn.Conv(features=4, kernel_size=(3,), padding='VALID')
>>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 8, 3)))
>>> jax.tree_util.tree_map(jnp.shape, variables)
{'params': {'bias': (4,), 'kernel': (3, 3, 4)}}
>>> out.shape
(1, 6, 4)
>>> # circular padding with stride 2
>>> layer = nn.Conv(features=4, kernel_size=(3, 3), strides=2, padding='CIRCULAR')
>>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 8, 3)))
>>> jax.tree_util.tree_map(jnp.shape, variables)
{'params': {'bias': (4,), 'kernel': (3, 3, 3, 4)}}
>>> out.shape
(1, 4, 4)
>>> # apply lower triangle mask
>>> mask = jnp.tril(jnp.ones((3, 3, 4)))
>>> layer = nn.Conv(features=4, kernel_size=(3,), mask=mask, padding='VALID')
>>> variables = layer.init(jax.random.key(0), jnp.ones((1, 8, 3)))
- features#
number of convolution filters.
- Type:
int
- kernel_size#
shape of the convolutional kernel. An integer will be interpreted as a tuple of the single integer.
- Type:
int | collections.abc.Sequence[int]
- strides#
an integer or a sequence of n integers, representing the inter-window strides (default: 1).
- Type:
None | int | collections.abc.Sequence[int]
- padding#
either the string 'SAME', the string 'VALID', the string 'CIRCULAR' (periodic boundary conditions), or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension. A single int is interpreted as applying the same padding in all dims, and a single int in a sequence causes the same padding to be used on both sides. 'CAUSAL' padding for a 1D convolution will left-pad the convolution axis, resulting in same-sized output.
- Type:
str | int | collections.abc.Sequence[int | tuple[int, int]]
- input_dilation#
an integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of inputs (default: 1). Convolution with input dilation d is equivalent to transposed convolution with stride d.
- Type:
None | int | collections.abc.Sequence[int]
- kernel_dilation#
an integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of the convolution kernel (default: 1). Convolution with kernel dilation is also known as ‘atrous convolution’.
- Type:
None | int | collections.abc.Sequence[int]
- feature_group_count#
integer, default 1. If specified divides the input features into groups.
- Type:
int
- use_bias#
whether to add a bias to the output (default: True).
- Type:
bool
- mask#
Optional mask for the weights during masked convolution. The mask must be the same shape as the convolution weight matrix.
- Type:
jax.Array | Any | None
- dtype#
the dtype of the computation (default: infer from input and params).
- Type:
str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType | Any | None
- param_dtype#
the dtype passed to parameter initializers (default: float32).
- Type:
str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType | Any
- precision#
numerical precision of the computation; see jax.lax.Precision for details.
- Type:
None | str | jax._src.lax.lax.Precision | tuple[str, str] | tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]
- kernel_init#
initializer for the convolutional kernel.
- Type:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]
- bias_init#
initializer for the bias.
- Type:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]
- promote_dtype#
function to promote the dtype of the arrays to the desired dtype. The function should accept a tuple of (inputs, kernel, bias) and a dtype keyword argument, and return a tuple of arrays with the promoted dtype.
- Type:
flax.linen.linear.PromoteDtypeFn
- __call__(inputs)#
Applies a (potentially unshared) convolution to the inputs.
- Parameters:
inputs – input data with dimensions (*batch_dims, spatial_dims..., features). This is the channels-last convention, i.e. NHWC for a 2d convolution and NDHWC for a 3D convolution. Note: this is different from the input convention used by lax.conv_general_dilated, which puts the spatial dimensions last. Note: If the input has more than 1 batch dimension, all batch dimensions are flattened into a single dimension for the convolution and restored before returning. In some cases directly vmap’ing the layer may yield better performance than this default flattening approach. If the input lacks a batch dimension it will be added for the convolution and removed on return, an allowance made to enable writing single-example code.
- Returns:
The convolved data.
- class flax.linen.ConvTranspose(features, kernel_size, strides=None, padding='SAME', kernel_dilation=None, use_bias=True, mask=None, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, transpose_kernel=False, promote_dtype=<function promote_dtype>, preferred_element_type=None, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
Convolution Module wrapping lax.conv_transpose.
Example usage:
>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> # valid padding
>>> layer = nn.ConvTranspose(features=4, kernel_size=(3,), padding='VALID')
>>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 8, 3)))
>>> jax.tree_util.tree_map(jnp.shape, variables)
{'params': {'bias': (4,), 'kernel': (3, 3, 4)}}
>>> out.shape
(1, 10, 4)
>>> # circular padding with stride 2
>>> layer = nn.ConvTranspose(features=4, kernel_size=(6, 6), strides=(2, 2), padding='CIRCULAR', transpose_kernel=True)
>>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 15, 15, 3)))
>>> jax.tree_util.tree_map(jnp.shape, variables)
{'params': {'bias': (4,), 'kernel': (6, 6, 4, 3)}}
>>> out.shape
(1, 30, 30, 4)
>>> # apply lower triangle mask
>>> mask = jnp.tril(jnp.ones((3, 3, 4)))
>>> layer = nn.ConvTranspose(features=4, kernel_size=(3,), mask=mask, padding='VALID')
>>> variables = layer.init(jax.random.key(0), jnp.ones((1, 8, 3)))
- features#
number of convolution filters.
- Type:
int
- kernel_size#
shape of the convolutional kernel. For 1D convolution, the kernel size can be passed as an integer, which will be interpreted as a tuple of the single integer. For all other cases, it must be a sequence of integers.
- Type:
int | collections.abc.Sequence[int]
- strides#
an integer or a sequence of n integers, representing the inter-window strides.
- Type:
collections.abc.Sequence[int] | None
- padding#
either the string ‘SAME’, the string ‘VALID’, the string ‘CIRCULAR’ (periodic boundary conditions), or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension. A single int is interpreted as applying the same padding in all dims, and a single int in a sequence causes the same padding to be used on both sides.
- Type:
str | int | collections.abc.Sequence[int | tuple[int, int]]
- kernel_dilation#
None, or an integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of the convolution kernel. Convolution with kernel dilation is also known as ‘atrous convolution’.
- Type:
collections.abc.Sequence[int] | None
- use_bias#
whether to add a bias to the output (default: True).
- Type:
bool
- mask#
Optional mask for the weights during masked convolution. The mask must be the same shape as the convolution weight matrix.
- Type:
jax.Array | Any | None
- dtype#
the dtype of the computation (default: infer from input and params).
- Type:
str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType | Any | None
- param_dtype#
the dtype passed to parameter initializers (default: float32).
- Type:
str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType | Any
- precision#
numerical precision of the computation; see jax.lax.Precision for details.
- Type:
None | str | jax._src.lax.lax.Precision | tuple[str, str] | tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]
- kernel_init#
initializer for the convolutional kernel.
- Type:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]
- bias_init#
initializer for the bias.
- Type:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]
- transpose_kernel#
if True, flips spatial axes and swaps the input/output channel axes of the kernel.
- Type:
bool
- promote_dtype#
function to promote the dtype of the arrays to the desired dtype. The function should accept a tuple of (inputs, kernel, bias) and a dtype keyword argument, and return a tuple of arrays with the promoted dtype.
- Type:
flax.linen.linear.PromoteDtypeFn
- __call__(inputs)[source]#
Applies a transposed convolution to the inputs.
Behaviour mirrors that of jax.lax.conv_transpose.
- Parameters:
inputs – input data with dimensions (*batch_dims, spatial_dims..., features). This is the channels-last convention, i.e. NHWC for a 2d convolution and NDHWC for a 3D convolution. Note: this is different from the input convention used by lax.conv_general_dilated, which puts the spatial dimensions last. Note: If the input has more than 1 batch dimension, all batch dimensions are flattened into a single dimension for the convolution and restored before returning. In some cases directly vmap’ing the layer may yield better performance than this default flattening approach. If the input lacks a batch dimension it will be added for the convolution and removed on return, an allowance made to enable writing single-example code.
- Returns:
The convolved data.
- class flax.linen.ConvLocal(features, kernel_size, strides=1, padding='SAME', input_dilation=1, kernel_dilation=1, feature_group_count=1, use_bias=True, mask=None, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, promote_dtype=<function promote_dtype>, conv_general_dilated=None, conv_general_dilated_cls=None, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
Local convolution Module wrapping lax.conv_general_dilated_local.
Example usage:
>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> # valid padding
>>> layer = nn.ConvLocal(features=4, kernel_size=(3,), padding='VALID')
>>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 8, 3)))
>>> jax.tree_util.tree_map(jnp.shape, variables)
{'params': {'bias': (6, 4), 'kernel': (6, 9, 4)}}
>>> out.shape
(1, 6, 4)
>>> # circular padding with stride 2
>>> layer = nn.ConvLocal(features=4, kernel_size=(3, 3), strides=2, padding='CIRCULAR')
>>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 8, 3)))
>>> jax.tree_util.tree_map(jnp.shape, variables)
{'params': {'bias': (1, 4, 4), 'kernel': (1, 4, 27, 4)}}
>>> out.shape
(1, 4, 4)
>>> # apply lower triangle mask
>>> mask = jnp.tril(jnp.ones((6, 9, 4)))
>>> layer = nn.ConvLocal(features=4, kernel_size=(3,), mask=mask, padding='VALID')
>>> variables = layer.init(jax.random.key(0), jnp.ones((1, 8, 3)))
- features#
number of convolution filters.
- Type:
int
- kernel_size#
shape of the convolutional kernel. An integer will be interpreted as a tuple of the single integer.
- Type:
int | collections.abc.Sequence[int]
- strides#
an integer or a sequence of n integers, representing the inter-window strides (default: 1).
- Type:
None | int | collections.abc.Sequence[int]
- padding#
either the string 'SAME', the string 'VALID', the string 'CIRCULAR' (periodic boundary conditions), or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension. A single int is interpreted as applying the same padding in all dims, and a single int in a sequence causes the same padding to be used on both sides. 'CAUSAL' padding for a 1D convolution will left-pad the convolution axis, resulting in same-sized output.
- Type:
str | int | collections.abc.Sequence[int | tuple[int, int]]
- input_dilation#
an integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of inputs (default: 1). Convolution with input dilation d is equivalent to transposed convolution with stride d.
- Type:
None | int | collections.abc.Sequence[int]
- kernel_dilation#
an integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of the convolution kernel (default: 1). Convolution with kernel dilation is also known as ‘atrous convolution’.
- Type:
None | int | collections.abc.Sequence[int]
- feature_group_count#
integer, default 1. If specified divides the input features into groups.
- Type:
int
- use_bias#
whether to add a bias to the output (default: True).
- Type:
bool
- mask#
Optional mask for the weights during masked convolution. The mask must be the same shape as the convolution weight matrix.
- Type:
jax.Array | Any | None
- dtype#
the dtype of the computation (default: infer from input and params).
- Type:
str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType | Any | None
- param_dtype#
the dtype passed to parameter initializers (default: float32).
- Type:
str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType | Any
- precision#
numerical precision of the computation; see jax.lax.Precision for details.
- Type:
None | str | jax._src.lax.lax.Precision | tuple[str, str] | tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]
- kernel_init#
initializer for the convolutional kernel.
- Type:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]
- bias_init#
initializer for the bias.
- Type:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]
- promote_dtype#
function to promote the dtype of the arrays to the desired dtype. The function should accept a tuple of (inputs, kernel, bias) and a dtype keyword argument, and return a tuple of arrays with the promoted dtype.
- Type:
flax.linen.linear.PromoteDtypeFn
- __call__(inputs)#
Applies a (potentially unshared) convolution to the inputs.
- Parameters:
inputs – input data with dimensions (*batch_dims, spatial_dims..., features). This is the channels-last convention, i.e. NHWC for a 2d convolution and NDHWC for a 3D convolution. Note: this is different from the input convention used by lax.conv_general_dilated, which puts the spatial dimensions last. Note: If the input has more than 1 batch dimension, all batch dimensions are flattened into a single dimension for the convolution and restored before returning. In some cases directly vmap’ing the layer may yield better performance than this default flattening approach. If the input lacks a batch dimension it will be added for the convolution and removed on return, an allowance made to enable writing single-example code.
- Returns:
The convolved data.
- class flax.linen.Einsum(shape, einsum_str=None, use_bias=True, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, promote_dtype=<function promote_dtype>, preferred_element_type=None, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
An einsum transformation with learnable kernel and bias.
Example usage:
>>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> layer = nn.Einsum((5, 6, 7), 'abc,cde->abde') >>> variables = layer.init(jax.random.key(0), jnp.ones((3, 4, 5))) >>> jax.tree_util.tree_map(jnp.shape, variables) {'params': {'bias': (6, 7), 'kernel': (5, 6, 7)}}
- shape#
the shape of the kernel.
- Type:
collections.abc.Sequence[int]
- einsum_str#
a string to denote the einsum equation. The equation must have exactly two operands, the lhs being the input passed in, and the rhs being the learnable kernel. Exactly one of einsum_str in the constructor argument and call argument must be not None, while the other must be None.
- Type:
str | None
- use_bias#
whether to add a bias to the output (default: True).
- Type:
bool
- dtype#
the dtype of the computation (default: infer from input and params).
- Type:
str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType | Any | None
- param_dtype#
the dtype passed to parameter initializers (default: float32).
- Type:
str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType | Any
- precision#
numerical precision of the computation; see jax.lax.Precision for details.
- Type:
None | str | jax._src.lax.lax.Precision | tuple[str, str] | tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]
- kernel_init#
initializer function for the weight matrix.
- Type:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]
- bias_init#
initializer function for the bias.
- Type:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]
- promote_dtype#
function to promote the dtype of the arrays to the desired dtype. The function should accept a tuple of (inputs, kernel, bias) and a dtype keyword argument, and return a tuple of arrays with the promoted dtype.
- Type:
flax.linen.linear.PromoteDtypeFn
- __call__(inputs, einsum_str=None)[source]#
Applies the einsum transformation to the inputs.
- Parameters:
inputs – The nd-array to be transformed.
einsum_str – a string to denote the einsum equation. The equation must have exactly two operands, the lhs being the input passed in, and the rhs being the learnable kernel. The einsum_str passed into the call method will take precedence over the einsum_str passed into the constructor.
- Returns:
The transformed input.
- class flax.linen.Embed(num_embeddings, features, dtype=None, param_dtype=<class 'jax.numpy.float32'>, embedding_init=<function variance_scaling.<locals>.init>, promote_dtype=<function promote_dtype>, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
Embedding Module.
A parameterized function from integers [0, num_embeddings) to features-dimensional vectors. This Module will create an embedding matrix with shape (num_embeddings, features). When calling this layer, the input values will be used to 0-index into the embedding matrix. Indexing on a value greater than or equal to num_embeddings will result in nan values. When num_embeddings equals 1, it will broadcast the embedding matrix to input shape with features dimension appended.
Example usage:
>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> layer = nn.Embed(num_embeddings=5, features=3)
>>> indices_input = jnp.array([[0, 1, 2], [-1, -2, -3]])
>>> variables = layer.init(jax.random.key(0), indices_input)
>>> variables
{'params': {'embedding': Array([[ 0.04396089, -0.9328513 , -0.97328115],
       [ 0.41147125,  0.66334754,  0.49469155],
       [ 0.09719624,  0.49861377,  0.49519277],
       [-0.13316602,  0.6697022 ,  0.3710195 ],
       [-0.5039532 ,  0.287319  ,  1.4369922 ]], dtype=float32)}}
>>> # get the first three and last three embeddings
>>> layer.apply(variables, indices_input)
Array([[[ 0.04396089, -0.9328513 , -0.97328115],
        [ 0.41147125,  0.66334754,  0.49469155],
        [ 0.09719624,  0.49861377,  0.49519277]],

       [[-0.5039532 ,  0.287319  ,  1.4369922 ],
        [-0.13316602,  0.6697022 ,  0.3710195 ],
        [ 0.09719624,  0.49861377,  0.49519277]]], dtype=float32)
- num_embeddings#
number of embeddings / vocab size.
- Type:
int
- features#
number of feature dimensions for each embedding.
- Type:
int
- dtype#
the dtype of the embedding vectors (default: same as embedding).
- Type:
str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType | Any | None
- param_dtype#
the dtype passed to parameter initializers (default: float32).
- Type:
str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType | Any
- embedding_init#
embedding initializer.
- Type:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]
- promote_dtype#
function to promote the dtype of the arrays to the desired dtype. The function should accept a tuple of (embedding,) during __call__ or (query, embedding) during attend, and a dtype keyword argument, and return a tuple of arrays with the promoted dtype.
- Type:
flax.linen.linear.PromoteDtypeFn
- __call__(inputs)[source]#
Embeds the inputs along the last dimension.
- Parameters:
inputs – input data, all dimensions are considered batch dimensions. Values in the input array must be integers.
- Returns:
Output which is embedded input data. The output shape follows the input, with an additional features dimension appended.
- attend(query)[source]#
Attend over the embedding using a query array.
- Parameters:
query – array with last dimension equal to the feature depth features of the embedding.
- Returns:
An array with final dim num_embeddings corresponding to the batched inner-product of the array of query vectors against each embedding. Commonly used for weight-sharing between embeddings and logit transform in NLP models.
Pooling#
- flax.linen.max_pool(inputs, window_shape, strides=None, padding='VALID')[source]#
Pools the input by taking the maximum of a window slice.
- Parameters:
inputs – input data with dimensions (batch, window dims…, features).
window_shape – a shape tuple defining the window to reduce over.
strides – a sequence of n integers, representing the inter-window strides (default: (1, ..., 1)).
padding – either the string 'SAME', the string 'VALID', or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension (default: 'VALID').
- Returns:
The maximum for each window slice.
- flax.linen.avg_pool(inputs, window_shape, strides=None, padding='VALID', count_include_pad=True)[source]#
Pools the input by taking the average over a window.
- Parameters:
inputs – input data with dimensions (batch, window dims…, features).
window_shape – a shape tuple defining the window to reduce over.
strides – a sequence of n integers, representing the inter-window strides (default: (1, ..., 1)).
padding – either the string 'SAME', the string 'VALID', or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension (default: 'VALID').
count_include_pad – a boolean whether to include padded tokens in the average calculation (default: True).
- Returns:
The average for each window slice.
- flax.linen.pool(inputs, init, reduce_fn, window_shape, strides, padding)[source]#
Helper function to define pooling functions.
Pooling functions are implemented using the ReduceWindow XLA op.
Note
Be aware that pooling is not generally differentiable. That means providing a reduce_fn that is differentiable does not imply that pool is differentiable.
- Parameters:
inputs – input data with dimensions (batch, window dims…, features).
init – the initial value for the reduction
reduce_fn – a reduce function of the form (T, T) -> T.
window_shape – a shape tuple defining the window to reduce over.
strides – a sequence of n integers, representing the inter-window strides (default: (1, ..., 1)).
padding – either the string 'SAME', the string 'VALID', or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension.
- Returns:
The output of the reduction for each window slice.
Normalization#
- class flax.linen.BatchNorm(use_running_average=None, axis=-1, momentum=0.99, epsilon=1e-05, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_bias=True, use_scale=True, bias_init=<function zeros>, scale_init=<function ones>, axis_name=None, axis_index_groups=None, use_fast_variance=True, force_float32_reductions=True, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
BatchNorm Module.
Usage Note: If we define a model with BatchNorm, for example:
>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> BN = nn.BatchNorm(momentum=0.9, epsilon=1e-5, dtype=jnp.float32)
The initialized variables dict will contain, in addition to a ‘params’ collection, a separate ‘batch_stats’ collection that will contain all the running statistics for all the BatchNorm layers in a model:
>>> x = jax.random.normal(jax.random.key(0), (5, 6))
>>> variables = BN.init(jax.random.key(1), x, use_running_average=False)
>>> jax.tree_util.tree_map(jnp.shape, variables)
{'batch_stats': {'mean': (6,), 'var': (6,)}, 'params': {'bias': (6,), 'scale': (6,)}}
We then update the batch_stats during training by specifying that the batch_stats collection is mutable in the apply method for our module:
>>> y, new_batch_stats = BN.apply(variables, x, mutable=['batch_stats'], use_running_average=False)
During eval we would define BN with use_running_average=True and use the batch_stats collection from training to set the statistics. In this case we are not mutating the batch statistics collection, and needn’t mark it mutable:
>>> y = BN.apply(variables, x, use_running_average=True)
- use_running_average#
if True, the statistics stored in batch_stats will be used instead of computing the batch statistics on the input.
- Type:
bool | None
- axis#
the feature or non-batch axis of the input.
- Type:
int
- momentum#
decay rate for the exponential moving average of the batch statistics.
- Type:
float
- epsilon#
a small float added to variance to avoid dividing by zero.
- Type:
float
- dtype#
the dtype of the result (default: infer from input and params).
- Type:
str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType | Any | None
- param_dtype#
the dtype passed to parameter initializers (default: float32).
- Type:
str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType | Any
- use_bias#
if True, bias (beta) is added.
- Type:
bool
- use_scale#
if True, multiply by scale (gamma). When the next layer is linear (also e.g. nn.relu), this can be disabled since the scaling will be done by the next layer.
- Type:
bool
- bias_init#
initializer for bias, by default, zero.
- Type:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]
- scale_init#
initializer for scale, by default, one.
- Type:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]
- axis_name#
the axis name used to combine batch statistics from multiple devices. See jax.pmap for a description of axis names (default: None). Note, this is only used for pmap and shard map. For SPMD jit, you do not need to manually synchronize. Just make sure that the axes are correctly annotated and XLA:SPMD will insert the necessary collectives.
- Type:
str | None
- axis_index_groups#
groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, [[0, 1], [2, 3]] would independently batch-normalize over the examples on the first two and last two devices. See jax.lax.psum for more details. This argument is currently not supported for SPMD jit.
- Type:
Any
- use_fast_variance#
If true, use a faster, but less numerically stable, calculation for the variance.
- Type:
bool
- __call__(x, use_running_average=None, *, mask=None)[source]#
Normalizes the input using batch statistics.
Note
During initialization (when self.is_initializing() is True) the running average of the batch statistics will not be updated. Therefore, the inputs fed during initialization don’t need to match that of the actual input distribution and the reduction axis (set with axis_name) does not have to exist.
- Parameters:
x – the input to be normalized.
use_running_average – if true, the statistics stored in batch_stats will be used instead of computing the batch statistics on the input.
mask – Binary array of shape broadcastable to inputs tensor, indicating the positions for which the mean and variance should be computed.
- Returns:
Normalized inputs (the same shape as inputs).
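The interaction between momentum, use_running_average and use_fast_variance can be sketched in plain NumPy. This is only an illustration, not Flax's implementation: the helper name batch_norm_step is hypothetical, the feature axis is assumed to be the last one, and the update rule running = momentum * running + (1 - momentum) * batch is the usual exponential-moving-average convention assumed here.

```python
import numpy as np

def batch_norm_step(x, running_mean, running_var, momentum=0.99,
                    epsilon=1e-5, use_running_average=False):
    """Hypothetical helper: normalize x over every axis except the last."""
    reduce_axes = tuple(range(x.ndim - 1))
    if use_running_average:
        # Eval mode: reuse the stored statistics and leave them untouched.
        mean, var = running_mean, running_var
    else:
        mean = x.mean(axis=reduce_axes)
        # "Fast variance" form E[x^2] - E[x]^2: cheaper, less numerically stable.
        var = (x ** 2).mean(axis=reduce_axes) - mean ** 2
        # Exponential moving average of the batch statistics (assumed rule).
        running_mean = momentum * running_mean + (1 - momentum) * mean
        running_var = momentum * running_var + (1 - momentum) * var
    y = (x - mean) / np.sqrt(var + epsilon)
    return y, running_mean, running_var

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(8, 4))

# Training step: batch statistics are used and the running averages move.
y, rm, rv = batch_norm_step(x, np.zeros(4), np.ones(4))

# Eval step: stored statistics are used and returned unchanged.
y_eval, rm_eval, rv_eval = batch_norm_step(x, rm, rv, use_running_average=True)
```

With a momentum close to 1 the running averages change slowly, which is why eval-mode outputs differ from training-mode outputs early in training.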
- class flax.linen.LayerNorm(epsilon=1e-06, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_bias=True, use_scale=True, bias_init=<function zeros>, scale_init=<function ones>, reduction_axes=-1, feature_axes=-1, axis_name=None, axis_index_groups=None, use_fast_variance=True, force_float32_reductions=True, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
Layer normalization (https://arxiv.org/abs/1607.06450).
LayerNorm normalizes the activations of the layer for each given example in a batch independently, rather than across a batch like Batch Normalization. i.e. applies a transformation that maintains the mean activation within each example close to 0 and the activation standard deviation close to 1.
Note
This normalization operation is identical to InstanceNorm and GroupNorm; the difference is simply which axes are reduced and the shape of the feature axes (i.e. the shape of the learnable scale and bias parameters).
Example usage:
>>> import flax.linen as nn
>>> import jax
>>> import numpy as np
>>> x = jax.random.normal(jax.random.key(0), (3, 4, 5, 6))
>>> layer = nn.LayerNorm()
>>> variables = layer.init(jax.random.key(1), x)
>>> variables
{'params': {'scale': Array([1., 1., 1., 1., 1., 1.], dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0.], dtype=float32)}}
>>> y = layer.apply(variables, x)
>>> y = nn.LayerNorm(reduction_axes=(1, 2, 3)).apply(variables, x)
>>> y2 = nn.GroupNorm(num_groups=1).apply(variables, x)
>>> np.testing.assert_allclose(y, y2)
>>> y = nn.LayerNorm(reduction_axes=(1, 2), feature_axes=-1).apply(variables, x)
>>> y2 = nn.InstanceNorm(feature_axes=-1).apply(variables, x)
>>> np.testing.assert_allclose(y, y2)
- epsilon#
A small float added to variance to avoid dividing by zero.
- Type:
float
- dtype#
the dtype of the result (default: infer from input and params).
- Type:
str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType | Any | None
- param_dtype#
the dtype passed to parameter initializers (default: float32).
- Type:
str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType | Any
- use_bias#
If True, bias (beta) is added.
- Type:
bool
- use_scale#
If True, multiply by scale (gamma). When the next layer is linear (also e.g. nn.relu), this can be disabled since the scaling will be done by the next layer.
- Type:
bool
- bias_init#
Initializer for bias, by default, zero.
- Type:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]
- scale_init#
Initializer for scale, by default, one.
- Type:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]
- reduction_axes#
Axes for computing normalization statistics.
- Type:
int | collections.abc.Sequence[int]
- feature_axes#
Feature axes for learned bias and scaling.
- Type:
int | collections.abc.Sequence[int]
- axis_name#
the axis name used to combine batch statistics from multiple devices. See jax.pmap for a description of axis names (default: None). This is only needed if the model is subdivided across devices, i.e. the array being normalized is sharded across devices within a pmap or shard map. For SPMD jit, you do not need to manually synchronize. Just make sure that the axes are correctly annotated and XLA:SPMD will insert the necessary collectives.
- Type:
str | None
- axis_index_groups#
groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, [[0, 1], [2, 3]] would independently batch-normalize over the examples on the first two and last two devices. See jax.lax.psum for more details. This argument is currently not supported for SPMD jit.
- Type:
Any
- use_fast_variance#
If true, use a faster, but less numerically stable, calculation for the variance.
- Type:
bool
- __call__(x, *, mask=None)[source]#
Applies layer normalization on the input.
- Parameters:
x – the inputs
mask – Binary array of shape broadcastable to inputs tensor, indicating the positions for which the mean and variance should be computed.
- Returns:
Normalized inputs (the same shape as inputs).
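The effect of the mask argument can be illustrated with a small NumPy sketch. It is not Flax's code: the helper layer_norm is hypothetical, learned scale and bias are omitted, and it assumes that masked-out positions are excluded from the statistics but still normalized with them.

```python
import numpy as np

def layer_norm(x, reduction_axes=-1, epsilon=1e-6, mask=None):
    """Hypothetical helper: per-example normalization over reduction_axes."""
    if mask is None:
        mask = np.ones_like(x, dtype=bool)
    mask = np.broadcast_to(mask, x.shape)
    # Number of positions that contribute to the statistics.
    count = np.maximum(mask.sum(axis=reduction_axes, keepdims=True), 1)
    mean = np.where(mask, x, 0.0).sum(axis=reduction_axes, keepdims=True) / count
    var = np.where(mask, (x - mean) ** 2, 0.0).sum(
        axis=reduction_axes, keepdims=True) / count
    return (x - mean) / np.sqrt(var + epsilon)

x = np.array([[1.0, 2.0, 3.0, 100.0]])
mask = np.array([[True, True, True, False]])  # ignore the outlier at the end

y_masked = layer_norm(x, mask=mask)   # statistics from [1, 2, 3] only
y_plain = layer_norm(x)               # statistics from all four values
```

In the masked call the first three entries end up with zero mean and unit variance, while in the unmasked call the outlier dominates both statistics.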
- class flax.linen.GroupNorm(num_groups=32, group_size=None, epsilon=1e-06, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_bias=True, use_scale=True, bias_init=<function zeros>, scale_init=<function ones>, reduction_axes=None, axis_name=None, axis_index_groups=None, use_fast_variance=True, force_float32_reductions=True, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
Group normalization (arxiv.org/abs/1803.08494).
This op is similar to batch normalization, but statistics are shared across equally-sized groups of channels and not shared across batch dimension. Thus, group normalization does not depend on the batch composition and does not require maintaining internal state for storing statistics. The user should either specify the total number of channel groups or the number of channels per group.
Note
LayerNorm is a special case of GroupNorm where num_groups=1, and InstanceNorm is a special case of GroupNorm where group_size=1.
Example usage:
>>> import flax.linen as nn
>>> import jax
>>> import numpy as np
>>> x = jax.random.normal(jax.random.key(0), (3, 4, 5, 6))
>>> layer = nn.GroupNorm(num_groups=3)
>>> variables = layer.init(jax.random.key(1), x)
>>> variables
{'params': {'scale': Array([1., 1., 1., 1., 1., 1.], dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0.], dtype=float32)}}
>>> y = layer.apply(variables, x)
>>> y = nn.GroupNorm(num_groups=1).apply(variables, x)
>>> y2 = nn.LayerNorm(reduction_axes=(1, 2, 3)).apply(variables, x)
>>> np.testing.assert_allclose(y, y2)
>>> y = nn.GroupNorm(num_groups=None, group_size=1).apply(variables, x)
>>> y2 = nn.InstanceNorm(feature_axes=-1).apply(variables, x)
>>> np.testing.assert_allclose(y, y2)
- num_groups#
the total number of channel groups. The default value of 32 is proposed by the original group normalization paper.
- Type:
int | None
- group_size#
the number of channels in a group.
- Type:
int | None
- epsilon#
A small float added to variance to avoid dividing by zero.
- Type:
float
- dtype#
the dtype of the result (default: infer from input and params).
- Type:
str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType | Any | None
- param_dtype#
the dtype passed to parameter initializers (default: float32).
- Type:
str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType | Any
- use_bias#
If True, bias (beta) is added.
- Type:
bool
- use_scale#
If True, multiply by scale (gamma). When the next layer is linear (also e.g. nn.relu), this can be disabled since the scaling will be done by the next layer.
- Type:
bool
- bias_init#
Initializer for bias, by default, zero.
- Type:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]
- scale_init#
Initializer for scale, by default, one.
- Type:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]
- reduction_axes#
List of axes used for computing normalization statistics. This list must include the final dimension, which is assumed to be the feature axis. Furthermore, if the input used at call time has additional leading axes compared to the data used for initialisation, for example due to batching, then the reduction axes need to be defined explicitly.
- Type:
int | collections.abc.Sequence[int] | None
- axis_name#
the axis name used to combine batch statistics from multiple devices. See jax.pmap for a description of axis names (default: None). This is only needed if the model is subdivided across devices, i.e. the array being normalized is sharded across devices within a pmap or shard map. For SPMD jit, you do not need to manually synchronize. Just make sure that the axes are correctly annotated and XLA:SPMD will insert the necessary collectives.
- Type:
str | None
- axis_index_groups#
groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, [[0, 1], [2, 3]] would independently batch-normalize over the examples on the first two and last two devices. See jax.lax.psum for more details. This argument is currently not supported for SPMD jit.
- Type:
Any
- use_fast_variance#
If true, use a faster, but less numerically stable, calculation for the variance.
- Type:
bool
- __call__(x, *, mask=None)[source]#
Applies group normalization to the input (arxiv.org/abs/1803.08494).
- Parameters:
x – the input of shape ...C where C is a channels dimension and ... represents an arbitrary number of extra dimensions that can be used to accumulate statistics over. If no reduction axes have been specified then all additional dimensions ... will be used to accumulate statistics apart from the leading dimension which is assumed to represent the batch.
mask – Binary array of shape broadcastable to inputs tensor, indicating the positions for which the mean and variance should be computed.
- Returns:
Normalized inputs (the same shape as inputs).
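The grouping of channels, and the way LayerNorm and InstanceNorm fall out as special cases, can be sketched in NumPy. This is an illustration under stated assumptions rather than Flax's implementation: group_norm is a hypothetical helper, the input is channels-last with a leading batch axis, and the learned scale and bias are left out.

```python
import numpy as np

def group_norm(x, num_groups, epsilon=1e-6):
    """Hypothetical helper for channels-last input of shape (batch, ..., C)."""
    channels = x.shape[-1]
    if channels % num_groups != 0:
        raise ValueError("number of channels must be divisible by num_groups")
    group_size = channels // num_groups
    # Split the channel axis into (num_groups, group_size).
    grouped = x.reshape(x.shape[:-1] + (num_groups, group_size))
    # Reduce over everything except the batch axis (0) and the group axis (-2).
    axes = tuple(range(1, grouped.ndim - 2)) + (grouped.ndim - 1,)
    mean = grouped.mean(axis=axes, keepdims=True)
    var = grouped.var(axis=axes, keepdims=True)
    return ((grouped - mean) / np.sqrt(var + epsilon)).reshape(x.shape)

def normalize(x, axes, epsilon=1e-6):
    """Reference: plain normalization over the given axes."""
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)
    return (x - mean) / np.sqrt(var + epsilon)

x = np.random.default_rng(0).normal(size=(3, 4, 5, 6))

as_layer_norm = group_norm(x, num_groups=1)     # one group spanning all channels
as_instance_norm = group_norm(x, num_groups=6)  # one channel per group

layer_norm_ref = normalize(x, axes=(1, 2, 3))
instance_norm_ref = normalize(x, axes=(1, 2))
```

The two reference arrays match the corresponding group_norm outputs, mirroring the assert_allclose checks in the example usage.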
- class flax.linen.RMSNorm(epsilon=1e-06, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_scale=True, scale_init=<function ones>, reduction_axes=-1, feature_axes=-1, axis_name=None, axis_index_groups=None, use_fast_variance=True, force_float32_reductions=True, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
RMS Layer normalization (https://arxiv.org/abs/1910.07467).
RMSNorm normalizes the activations of the layer for each given example in a batch independently, rather than across a batch like Batch Normalization. Unlike LayerNorm which re-centers the mean to be 0 and normalizes by the standard deviation of the activations, RMSNorm does not re-center at all and instead normalizes by the root mean square of the activations.
Example usage:
>>> import flax.linen as nn
>>> import jax
>>> x = jax.random.normal(jax.random.key(0), (5, 6))
>>> layer = nn.RMSNorm()
>>> variables = layer.init(jax.random.key(1), x)
>>> variables
{'params': {'scale': Array([1., 1., 1., 1., 1., 1.], dtype=float32)}}
>>> y = layer.apply(variables, x)
- epsilon#
A small float added to variance to avoid dividing by zero.
- Type:
float
- dtype#
the dtype of the result (default: infer from input and params).
- Type:
str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType | Any | None
- param_dtype#
the dtype passed to parameter initializers (default: float32).
- Type:
str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType | Any
- use_scale#
If True, multiply by scale (gamma). When the next layer is linear (also e.g. nn.relu), this can be disabled since the scaling will be done by the next layer.
- Type:
bool
- scale_init#
Initializer for scale, by default, one.
- Type:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]
- reduction_axes#
Axes for computing normalization statistics.
- Type:
int | collections.abc.Sequence[int]
- feature_axes#
Feature axes for the learned scaling (RMSNorm has no bias parameter).
- Type:
int | collections.abc.Sequence[int]
- axis_name#
the axis name used to combine batch statistics from multiple devices. See jax.pmap for a description of axis names (default: None). This is only needed if the model is subdivided across devices, i.e. the array being normalized is sharded across devices within a pmap or shard map. For SPMD jit, you do not need to manually synchronize. Just make sure that the axes are correctly annotated and XLA:SPMD will insert the necessary collectives.
- Type:
str | None
- axis_index_groups#
groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, [[0, 1], [2, 3]] would independently batch-normalize over the examples on the first two and last two devices. See jax.lax.psum for more details. This argument is currently not supported for SPMD jit.
- Type:
Any
- use_fast_variance#
If true, use a faster, but less numerically stable, calculation for the variance.
- Type:
bool
- __call__(x, *, mask=None)[source]#
Applies RMS layer normalization on the input.
- Parameters:
x – the inputs
mask – Binary array of shape broadcastable to inputs tensor, indicating the positions for which the mean and variance should be computed.
- Returns:
Normalized inputs (the same shape as inputs).
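The difference from LayerNorm (no re-centering, division by the root mean square) can be seen in a few lines of NumPy. This is a sketch only: rms_norm is a hypothetical helper that reduces over the last axis, and placing epsilon inside the square root is an assumption.

```python
import numpy as np

def rms_norm(x, scale=None, epsilon=1e-6):
    """Hypothetical helper: divide by the root mean square over the last axis."""
    mean_square = np.mean(x ** 2, axis=-1, keepdims=True)
    y = x / np.sqrt(mean_square + epsilon)
    # Optional learned per-feature scale (gamma); there is no bias term.
    return y if scale is None else y * scale

x = np.array([[3.0, 4.0], [1.0, 1.0]])
y = rms_norm(x)

# Each row now has a root mean square close to 1, but its mean is not shifted to 0.
row_rms = np.sqrt(np.mean(y ** 2, axis=-1))
row_mean = y.mean(axis=-1)
```

Because the mean is never subtracted, a constant row such as [1, 1] stays (almost) unchanged, whereas LayerNorm would map it to zeros.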
- class flax.linen.InstanceNorm(epsilon=1e-06, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_bias=True, use_scale=True, bias_init=<function zeros>, scale_init=<function ones>, feature_axes=-1, axis_name=None, axis_index_groups=None, use_fast_variance=True, force_float32_reductions=True, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
Instance normalization (https://arxiv.org/abs/1607.08022v3).
InstanceNorm normalizes the activations of the layer for each channel (rather than across all channels like Layer Normalization), and for each given example in a batch independently (rather than across an entire batch like Batch Normalization). i.e. applies a transformation that maintains the mean activation within each channel within each example close to 0 and the activation standard deviation close to 1.
Note
This normalization operation is identical to LayerNorm and GroupNorm; the difference is simply which axes are reduced and the shape of the feature axes (i.e. the shape of the learnable scale and bias parameters).
Example usage:
>>> import flax.linen as nn
>>> import jax
>>> import numpy as np
>>> # dimensions: (batch, height, width, channel)
>>> x = jax.random.normal(jax.random.key(0), (2, 3, 4, 5))
>>> layer = nn.InstanceNorm()
>>> variables = layer.init(jax.random.key(1), x)
>>> variables
{'params': {'scale': Array([1., 1., 1., 1., 1.], dtype=float32), 'bias': Array([0., 0., 0., 0., 0.], dtype=float32)}}
>>> y = layer.apply(variables, x)
>>> # having a channel_axis of -1 in InstanceNorm is identical to reducing all non-batch,
>>> # non-channel axes and using the feature_axes as the feature_axes in LayerNorm
>>> y2 = nn.LayerNorm(reduction_axes=[1, 2], feature_axes=-1).apply(variables, x)
>>> np.testing.assert_allclose(y, y2, atol=1e-7)
>>> y3 = nn.GroupNorm(num_groups=x.shape[-1]).apply(variables, x)
>>> np.testing.assert_allclose(y, y3, atol=1e-7)
- epsilon#
A small float added to variance to avoid dividing by zero.
- Type:
float
- dtype#
the dtype of the result (default: infer from input and params).
- Type:
str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType | Any | None
- param_dtype#
the dtype passed to parameter initializers (default: float32).
- Type:
str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType | Any
- use_bias#
If True, bias (beta) is added.
- Type:
bool
- use_scale#
If True, multiply by scale (gamma). When the next layer is linear (also e.g. nn.relu), this can be disabled since the scaling will be done by the next layer.
- Type:
bool
- bias_init#
Initializer for bias, by default, zero.
- Type:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]
- scale_init#
Initializer for scale, by default, one.
- Type:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]
- feature_axes#
Axes for features. The learned bias and scaling parameters will be in the shape defined by the feature axes. All other axes except the batch axes (which is assumed to be the leading axis) will be reduced.
- Type:
int | collections.abc.Sequence[int]
- axis_name#
the axis name used to combine batch statistics from multiple devices. See jax.pmap for a description of axis names (default: None). This is only needed if the model is subdivided across devices, i.e. the array being normalized is sharded across devices within a pmap or shard map. For SPMD jit, you do not need to manually synchronize. Just make sure that the axes are correctly annotated and XLA:SPMD will insert the necessary collectives.
- Type:
str | None
- axis_index_groups#
groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, [[0, 1], [2, 3]] would independently batch-normalize over the examples on the first two and last two devices. See jax.lax.psum for more details. This argument is currently not supported for SPMD jit.
- Type:
Any
- use_fast_variance#
If true, use a faster, but less numerically stable, calculation for the variance.
- Type:
bool
- __call__(x, *, mask=None)[source]#
Applies instance normalization on the input.
- Parameters:
x – the inputs
mask – Binary array of shape broadcastable to inputs tensor, indicating the positions for which the mean and variance should be computed.
- Returns:
Normalized inputs (the same shape as inputs).
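How feature_axes determines the reduced axes can be sketched as follows. The helper instance_norm is hypothetical and written in NumPy. It assumes a leading batch axis and omits the learned scale and bias, so it only illustrates the axis bookkeeping described for feature_axes.

```python
import numpy as np

def instance_norm(x, feature_axes=-1, epsilon=1e-6):
    """Hypothetical helper: reduce over all non-batch, non-feature axes."""
    if isinstance(feature_axes, int):
        feature_axes = (feature_axes,)
    # Convert negative axis indices to positive ones.
    feature_axes = tuple(a % x.ndim for a in feature_axes)
    # Axis 0 is assumed to be the batch axis and is never reduced.
    reduction_axes = tuple(a for a in range(1, x.ndim) if a not in feature_axes)
    mean = x.mean(axis=reduction_axes, keepdims=True)
    var = x.var(axis=reduction_axes, keepdims=True)
    return (x - mean) / np.sqrt(var + epsilon), reduction_axes

# dimensions: (batch, height, width, channel)
x = np.random.default_rng(0).normal(size=(2, 3, 4, 5))
y, reduction_axes = instance_norm(x)

# Statistics are computed separately for every (example, channel) pair.
per_instance_mean = y.mean(axis=reduction_axes)
```

For this input the reduced axes are height and width, so per_instance_mean has shape (2, 5) and every entry is close to zero.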
- class flax.linen.SpectralNorm(layer_instance, n_steps=1, epsilon=1e-12, dtype=None, param_dtype=<class 'jax.numpy.float32'>, error_on_non_matrix=False, collection_name='batch_stats', parent=<flax.linen.module._Sentinel object>, name=None)[source]#
Spectral normalization.
See:
Spectral normalization normalizes the weight params so that the spectral norm of the matrix is equal to 1. This is implemented as a layer wrapper where each wrapped layer will have its params spectral normalized before computing its __call__ output.
Note
The initialized variables dict will contain, in addition to a ‘params’ collection, a separate ‘batch_stats’ collection that will contain a u vector and sigma value, which are intermediate values used when performing spectral normalization. During training, we pass in update_stats=True and mutable=['batch_stats'] so that u and sigma are updated with the most recently computed values using power iteration. This will help the power iteration method approximate the true singular value more accurately over time. During eval, we pass in update_stats=False to ensure we get deterministic behavior from the model.
Example usage:
>>> import flax, flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> import optax
>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x, train):
...     x = nn.Dense(3)(x)
...     # only spectral normalize the params of the second Dense layer
...     x = nn.SpectralNorm(nn.Dense(4))(x, update_stats=train)
...     x = nn.Dense(5)(x)
...     return x
>>> # init
>>> x = jnp.ones((1, 2))
>>> y = jnp.ones((1, 5))
>>> model = Foo()
>>> variables = model.init(jax.random.PRNGKey(0), x, train=False)
>>> flax.core.freeze(jax.tree_util.tree_map(jnp.shape, variables))
FrozenDict({
    batch_stats: {
        SpectralNorm_0: {
            Dense_1/kernel/sigma: (),
            Dense_1/kernel/u: (1, 4),
        },
    },
    params: {
        Dense_0: {
            bias: (3,),
            kernel: (2, 3),
        },
        Dense_1: {
            bias: (4,),
            kernel: (3, 4),
        },
        Dense_2: {
            bias: (5,),
            kernel: (4, 5),
        },
    },
})
>>> # train
>>> def train_step(variables, x, y):
...   def loss_fn(params):
...     logits, updates = model.apply(
...       {'params': params, 'batch_stats': variables['batch_stats']},
...       x,
...       train=True,
...       mutable=['batch_stats'],
...     )
...     loss = jnp.mean(optax.l2_loss(predictions=logits, targets=y))
...     return loss, updates
...
...   (loss, updates), grads = jax.value_and_grad(loss_fn, has_aux=True)(
...     variables['params']
...   )
...   return {
...     'params': jax.tree_util.tree_map(
...       lambda p, g: p - 0.1 * g, variables['params'], grads
...     ),
...     'batch_stats': updates['batch_stats'],
...   }, loss
>>> for _ in range(10):
...   variables, loss = train_step(variables, x, y)
>>> # inference / eval
>>> out = model.apply(variables, x, train=False)
- layer_instance#
Module instance that is wrapped with SpectralNorm
- Type:
- n_steps#
How many steps of power iteration to perform to approximate the singular value of the weight params.
- Type:
int
- epsilon#
A small float added to l2-normalization to avoid dividing by zero.
- Type:
float
- dtype#
the dtype of the result (default: infer from input and params).
- Type:
str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType | Any | None
- param_dtype#
the dtype passed to parameter initializers (default: float32).
- Type:
str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType | Any
- error_on_non_matrix#
Spectral normalization is only defined on matrices. By default, this module will return scalars unchanged and flatten higher-order tensors in their leading dimensions. Setting this flag to True will instead throw an error if a weight tensor with dimension greater than 2 is used by the layer.
- Type:
bool
- collection_name#
Name of the collection to store intermediate values used when performing spectral normalization.
- Type:
str
- __call__(*args, update_stats, **kwargs)[source]#
Compute the largest singular value of the weights in self.layer_instance using power iteration and normalize the weights using this value before computing the __call__ output.
- Parameters:
*args – positional arguments to be passed into the call method of the underlying layer instance in self.layer_instance.
update_stats – if True, update the internal u vector and sigma value after computing their updated values using power iteration. This will help the power iteration method approximate the true singular value more accurately over time.
**kwargs – keyword arguments to be passed into the call method of the underlying layer instance in self.layer_instance.
- Returns:
Output of the layer using spectral normalized weights.
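The role of the stored u vector and sigma value can be illustrated with a NumPy sketch of power iteration. It is not Flax's implementation: power_iteration is a hypothetical helper, and the shapes (a (3, 4) kernel with a (1, 4) u vector) are chosen to mirror the variable shapes in the example usage.

```python
import numpy as np

def power_iteration(kernel, u, n_steps=1, epsilon=1e-12):
    """Hypothetical helper: refine u and estimate the largest singular value."""
    for _ in range(n_steps):
        v = u @ kernel.T
        v = v / (np.linalg.norm(v) + epsilon)   # l2-normalize, guarded by epsilon
        u = v @ kernel
        u = u / (np.linalg.norm(u) + epsilon)
    sigma = float((v @ kernel @ u.T).item())
    return u, sigma

# A small kernel whose singular values are 3, 1 and 1.
kernel = np.array([[2.0, 0.0, 1.0, 0.0],
                   [0.0, 1.0, 0.0, 0.0],
                   [1.0, 0.0, 2.0, 0.0]])
u = np.ones((1, 4)) / 2.0

# Carrying u across calls (as with update_stats=True) sharpens the estimate.
sigmas = []
for _ in range(20):
    u, sigma = power_iteration(kernel, u, n_steps=1)
    sigmas.append(sigma)

normalized_kernel = kernel / sigma
largest_singular_value = np.linalg.svd(normalized_kernel, compute_uv=False)[0]
```

After enough steps sigma approaches the true largest singular value, so the normalized kernel has a spectral norm close to 1.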
- class flax.linen.WeightNorm(layer_instance, epsilon=1e-12, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_scale=True, scale_init=<function ones>, feature_axes=-1, variable_filter=<factory>, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
L2 weight normalization (https://arxiv.org/abs/1602.07868).
Weight normalization normalizes the weight params so that the l2-norm of the matrix is equal to 1. This is implemented as a layer wrapper where each wrapped layer will have its params l2-normalized before computing its __call__ output.
Example usage:
>>> import flax, flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> class Baz(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     return nn.Dense(2)(x)
>>> class Bar(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = Baz()(x)
...     x = nn.Dense(3)(x)
...     x = Baz()(x)
...     x = nn.Dense(3)(x)
...     return x
>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(3)(x)
...     # l2-normalize all params of the second Dense layer
...     x = nn.WeightNorm(nn.Dense(4), variable_filter=None)(x)
...     x = nn.Dense(5)(x)
...     # l2-normalize all kernels in the Bar submodule and all params in
...     # the Baz submodule
...     x = nn.WeightNorm(Bar(), variable_filter={'kernel', 'Baz'})(x)
...     return x
>>> # init
>>> x = jnp.ones((1, 2))
>>> model = Foo()
>>> variables = model.init(jax.random.key(0), x)
>>> flax.core.freeze(jax.tree_util.tree_map(jnp.shape, variables))
FrozenDict({
    params: {
        Bar_0: {
            Baz_0: {
                Dense_0: {
                    bias: (2,),
                    kernel: (5, 2),
                },
            },
            Baz_1: {
                Dense_0: {
                    bias: (2,),
                    kernel: (3, 2),
                },
            },
            Dense_0: {
                bias: (3,),
                kernel: (2, 3),
            },
            Dense_1: {
                bias: (3,),
                kernel: (2, 3),
            },
        },
        Dense_0: {
            bias: (3,),
            kernel: (2, 3),
        },
        Dense_1: {
            bias: (4,),
            kernel: (3, 4),
        },
        Dense_2: {
            bias: (5,),
            kernel: (4, 5),
        },
        WeightNorm_0: {
            Dense_1/bias/scale: (4,),
            Dense_1/kernel/scale: (4,),
        },
        WeightNorm_1: {
            Bar_0/Baz_0/Dense_0/bias/scale: (2,),
            Bar_0/Baz_0/Dense_0/kernel/scale: (2,),
            Bar_0/Baz_1/Dense_0/bias/scale: (2,),
            Bar_0/Baz_1/Dense_0/kernel/scale: (2,),
            Bar_0/Dense_0/kernel/scale: (3,),
            Bar_0/Dense_1/kernel/scale: (3,),
        },
    },
})
- layer_instance#
Module instance that is wrapped with WeightNorm
- Type:
- epsilon#
A small float added to l2-normalization to avoid dividing by zero.
- Type:
float
- dtype#
the dtype of the result (default: infer from input and params).
- Type:
str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType | Any | None
- param_dtype#
the dtype passed to parameter initializers (default: float32).
- Type:
str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType | Any
- use_scale#
If True, creates a learnable variable scale that is multiplied with the layer_instance variables after l2-normalization.
- Type:
bool
- scale_init#
Initialization function for the scaling function.
- Type:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]
- feature_axes#
The feature axes dimension(s). The l2-norm is calculated by reducing the layer_instance variables over the remaining (non-feature) axes. Therefore a separate l2-norm value is calculated and a separate scale (if use_scale=True) is learned for each specified feature. By default, the trailing dimension is treated as the feature axis.
- Type:
int | collections.abc.Sequence[int] | None
- variable_filter#
An optional iterable that contains string items. The WeightNorm layer will selectively apply l2-normalization to the layer_instance variables whose key path (delimited by ‘/’) has a match with variable_filter. For example, variable_filter={'kernel'} will only apply l2-normalization to variables whose key path contains ‘kernel’. By default, variable_filter={'kernel'}.
- Type:
collections.abc.Iterable | None
- __call__(*args, **kwargs)[source]#
Compute the l2-norm of the weights in self.layer_instance and normalize the weights using this value before computing the __call__ output.
- Parameters:
*args – positional arguments to be passed into the call method of the underlying layer instance in self.layer_instance.
**kwargs – keyword arguments to be passed into the call method of the underlying layer instance in self.layer_instance.
- Returns:
Output of the layer using l2-normalized weights.
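The per-feature normalization controlled by feature_axes and use_scale can be sketched in NumPy. The helper weight_norm is hypothetical, and the exact placement of epsilon is an assumption. The point is only that each feature column gets its own l2-norm and its own scale.

```python
import numpy as np

def weight_norm(kernel, scale=None, feature_axes=-1, epsilon=1e-12):
    """Hypothetical helper: l2-normalize over all non-feature axes."""
    if isinstance(feature_axes, int):
        feature_axes = (feature_axes,)
    feature_axes = tuple(a % kernel.ndim for a in feature_axes)
    reduction_axes = tuple(a for a in range(kernel.ndim) if a not in feature_axes)
    # One norm per feature; keepdims lets it broadcast against the kernel.
    norm = np.sqrt(np.sum(kernel ** 2, axis=reduction_axes, keepdims=True) + epsilon)
    normalized = kernel / norm
    return normalized if scale is None else normalized * scale

kernel = np.array([[3.0, 0.0],
                   [4.0, 2.0]])        # shape (in_features=2, out_features=2)
scale = np.array([1.0, 5.0])           # one learnable scale per output feature

w = weight_norm(kernel, scale)
column_norms = np.linalg.norm(w, axis=0)
```

Each output column of w keeps the direction of the original kernel column, while its length is set entirely by the corresponding scale entry.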
Combinators#
- class flax.linen.Sequential(layers, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
Applies a linear chain of Modules.
Meant to be used only for the simple case of fusing together callables where the input of a particular module/op is the output of the previous one.
Modules will be applied in the order that they are passed in the constructor.
The __call__ method of Sequential accepts any input and forwards it to the first module it contains. It chains the output sequentially to the input of the next module and returns the output of the final module.
Example usage:
>>> import flax.linen as nn
>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     return nn.Sequential([nn.Dense(4),
...                           nn.relu,
...                           nn.Dense(2),
...                           nn.log_softmax])(x)
Since Sequential.__call__ is a compact method, you can also pass functions that construct Modules inline if you need shape inference:
module = nn.Sequential([
  # << more layers
  lambda x: SomeModule(x.shape[-1])(x), # shape inference
  # << more layers
])
This combinator also supports layers that return multiple outputs, provided they are returned as a tuple or a dictionary. If the output of a layer is a tuple it will be expanded as *args in the next layer; if it is a dict it will be expanded as **kwargs.
Example usage:
>>> class CrossAttentionBlock(nn.Module):
...   num_heads: int = 2
...   qkv_features: int = 16
...
...   @nn.compact
...   def __call__(self, query, key_value):
...     output = nn.MultiHeadDotProductAttention(
...       num_heads=self.num_heads, qkv_features=self.qkv_features)(query,
...                                                                  key_value)
...     output = nn.Dense(self.qkv_features)(output)
...     return dict(query=output, key_value=key_value)  # also works for tuples
>>> class CrossAttentionNetwork(nn.Module):
...   num_layers: int
...
...   @nn.compact
...   def __call__(self, query, key_value):
...     return nn.Sequential([CrossAttentionBlock() for _ in
...                           range(self.num_layers)])(query, key_value)
- layers#
A sequence of callables to be applied in order.
- Type:
collections.abc.Sequence[collections.abc.Callable[[…], Any]]
- Raises:
ValueError – If layers is not a sequence.
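The chaining rule, including the tuple and dict expansion, can be sketched in plain Python. The function apply_sequential and the three toy layers are hypothetical stand-ins, not Flax code. They only mimic how each output is handed to the next callable.

```python
def apply_sequential(layers, *args, **kwargs):
    """Hypothetical stand-in for the chaining rule described above."""
    if not layers:
        raise ValueError("expected at least one layer")
    # The first layer receives the call arguments unchanged.
    outputs = layers[0](*args, **kwargs)
    for layer in layers[1:]:
        if isinstance(outputs, tuple):
            outputs = layer(*outputs)      # tuple -> positional arguments
        elif isinstance(outputs, dict):
            outputs = layer(**outputs)     # dict -> keyword arguments
        else:
            outputs = layer(outputs)       # single value -> single argument
    return outputs

def split(x):
    return x, x + 1              # returns a tuple

def combine(a, b):
    return dict(total=a + b)     # returns a dict

def scale(total):
    return total * 10            # returns a plain value

result = apply_sequential([split, combine, scale], 2)
```

Here split(2) yields (2, 3), combine receives them as two positional arguments, and scale receives total=5 as a keyword argument, so result is 50.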
Stochastic#
- class flax.linen.Dropout(rate, broadcast_dims=(), deterministic=None, rng_collection='dropout', parent=<flax.linen.module._Sentinel object>, name=None)[source]#
Create a dropout layer.
Note
When using Module.apply(), make sure to include an RNG seed named 'dropout'. Dropout isn’t necessary for variable initialization.
Example usage:
>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> class MLP(nn.Module):
...   @nn.compact
...   def __call__(self, x, train):
...     x = nn.Dense(4)(x)
...     x = nn.Dropout(0.5, deterministic=not train)(x)
...     return x
>>> model = MLP()
>>> x = jnp.ones((1, 3))
>>> variables = model.init(jax.random.key(0), x, train=False) # don't use dropout
>>> model.apply(variables, x, train=False) # don't use dropout
Array([[-0.17875527, 1.6255447 , -1.2431065 , -0.02554005]], dtype=float32)
>>> model.apply(variables, x, train=True, rngs={'dropout': jax.random.key(1)}) # use dropout
Array([[-0.35751054, 3.2510893 , 0. , 0. ]], dtype=float32)
- rate#
the dropout probability. (_not_ the keep rate!)
- Type:
float
- broadcast_dims#
dimensions that will share the same dropout mask
- Type:
collections.abc.Sequence[int]
- deterministic#
if false the inputs are scaled by 1 / (1 - rate) and masked, whereas if true, no mask is applied and the inputs are returned as is.
- Type:
bool | None
- rng_collection#
the rng collection name to use when requesting an rng key.
- Type:
str
- __call__(inputs, deterministic=None, rng=None)[source]#
Applies a random dropout mask to the input.
- Parameters:
inputs – the inputs that should be randomly masked.
deterministic – if false the inputs are scaled by 1 / (1 - rate) and masked, whereas if true, no mask is applied and the inputs are returned as is.
rng – an optional PRNGKey used as the random key, if not specified, one will be generated using make_rng with the rng_collection name.
- Returns:
The masked inputs reweighted to preserve mean.
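The scaling by 1 / (1 - rate) is what keeps the expected value of the output equal to the input. A NumPy sketch of this "inverted dropout" rule follows. The helper dropout is hypothetical and uses a NumPy random generator in place of a JAX PRNG key.

```python
import numpy as np

def dropout(x, rate, rng, deterministic=False):
    """Hypothetical helper: mask with probability `rate`, rescale the rest."""
    if deterministic or rate == 0.0:
        return x                          # inputs are returned as is
    if rate == 1.0:
        return np.zeros_like(x)           # everything is dropped
    keep_prob = 1.0 - rate
    keep = rng.random(x.shape) < keep_prob
    # Kept entries are scaled up so that the mean is preserved in expectation.
    return np.where(keep, x / keep_prob, 0.0)

x = np.ones((1000, 10))
y = dropout(x, rate=0.5, rng=np.random.default_rng(0))
y_eval = dropout(x, rate=0.5, rng=np.random.default_rng(0), deterministic=True)
```

With rate=0.5 every surviving entry is doubled, which matches the example usage where the non-zero outputs under dropout are twice the deterministic outputs.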
Attention#
- class flax.linen.MultiHeadDotProductAttention(num_heads, dtype=None, param_dtype=<class 'jax.numpy.float32'>, qkv_features=None, out_features=None, broadcast_dropout=True, dropout_rate=0.0, deterministic=None, precision=None, kernel_init=<function variance_scaling.<locals>.init>, out_kernel_init=None, bias_init=<function zeros>, out_bias_init=None, use_bias=True, attention_fn=<function dot_product_attention>, decode=False, normalize_qk=False, force_fp32_for_softmax=False, qkv_dot_general=None, out_dot_general=None, qkv_dot_general_cls=None, out_dot_general_cls=None, qk_attn_weights_einsum_cls=None, attn_weights_value_einsum_cls=None, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
Multi-head dot-product attention.
Example usage:
>>> import flax.linen as nn
>>> import jax
>>> layer = nn.MultiHeadDotProductAttention(num_heads=8, qkv_features=16)
>>> key1, key2, key3, key4, key5, key6 = jax.random.split(jax.random.key(0), 6)
>>> shape = (4, 3, 2, 5)
>>> q, k, v = jax.random.uniform(key1, shape), jax.random.uniform(key2, shape), jax.random.uniform(key3, shape)
>>> variables = layer.init(jax.random.key(0), q)
>>> # different inputs for inputs_q, inputs_k and inputs_v
>>> out = layer.apply(variables, q, k, v)
>>> # equivalent to layer.apply(variables, inputs_q=q, inputs_k=k, inputs_v=k)
>>> out = layer.apply(variables, q, k)
>>> # equivalent to layer.apply(variables, inputs_q=q, inputs_k=q) and layer.apply(variables, inputs_q=q, inputs_k=q, inputs_v=q)
>>> out = layer.apply(variables, q)
>>> attention_kwargs = dict(
...   num_heads=8,
...   qkv_features=16,
...   kernel_init=nn.initializers.ones,
...   bias_init=nn.initializers.zeros,
...   dropout_rate=0.5,
...   deterministic=False,
... )
>>> class Module(nn.Module):
...   attention_kwargs: dict
...
...   @nn.compact
...   def __call__(self, x, dropout_rng=None):
...     out1 = nn.MultiHeadDotProductAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng)
...     out2 = nn.MultiHeadDotProductAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng)
...     return out1, out2
>>> module = Module(attention_kwargs)
>>> variables = module.init({'params': key1, 'dropout': key2}, q)
>>> # out1 and out2 are different.
>>> out1, out2 = module.apply(variables, q, rngs={'dropout': key3})
>>> # out3 and out4 are different.
>>> # out1 and out3 are different. out2 and out4 are different.
>>> out3, out4 = module.apply(variables, q, rngs={'dropout': key4})
>>> # out1 and out2 are the same.
>>> out1, out2 = module.apply(variables, q, dropout_rng=key5)
>>> # out1 and out2 are the same as out3 and out4.
>>> # providing a `dropout_rng` arg will take precedence over the `rngs` arg in `.apply`
>>> out3, out4 = module.apply(variables, q, rngs={'dropout': key6}, dropout_rng=key5)
- num_heads#
Number of attention heads. Features (i.e. inputs_q.shape[-1]) should be divisible by the number of heads.
- Type:
int
- dtype#
The dtype of the computation (default: infer from inputs and params)
- Type:
str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType | Any | None
- param_dtype#
The dtype passed to parameter initializers (default: float32)
- Type:
str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType | Any
- qkv_features#
Dimension of the key, query, and value.
- Type:
int | None
- out_features#
Dimension of the last projection
- Type:
int | None
- broadcast_dropout#
Use a broadcasted dropout along batch dims.
- Type:
bool
- dropout_rate#
Dropout rate.
- Type:
float
- deterministic#
If False, the attention weight is masked randomly using dropout, whereas if True, the attention weights are deterministic.
- Type:
bool | None
- precision#
Numerical precision of the computation; see jax.lax.Precision for details.
- Type:
None | str | jax._src.lax.lax.Precision | tuple[str, str] | tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]
- kernel_init#
Initializer for the kernel of the Dense layers.
- Type:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]
- out_kernel_init#
Optional initializer for the kernel of the output Dense layer. If None, kernel_init will be used.
- Type:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any] | None
- bias_init#
Initializer for the bias of the Dense layers.
- Type:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]
- out_bias_init#
Optional initializer for the bias of the output Dense layer. If None, bias_init will be used.
- Type:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any] | None
- use_bias#
Whether pointwise QKVO dense transforms use bias.
- Type:
bool
- attention_fn#
dot_product_attention or compatible function. Accepts query, key, value, and returns output of shape [bs, dim1, dim2, ..., dimN, num_heads, value_channels].
- Type:
collections.abc.Callable[[…], jax.Array | Any]
- decode#
Whether to prepare and use an autoregressive cache.
- Type:
bool
- normalize_qk#
Should QK normalization be applied (arxiv.org/abs/2302.05442).
- Type:
bool
- qk_attn_weights_einsum_cls#
factory function to create the einsum for computing the attention weights.
- Type:
collections.abc.Callable[[…], collections.abc.Callable[[…], jax.Array | Any]] | None
- attn_weights_value_einsum_cls#
factory function to create the einsum for computing the product of the attention weights and the values.
- Type:
collections.abc.Callable[[…], collections.abc.Callable[[…], jax.Array | Any]] | None
- __call__(inputs_q, inputs_k=None, inputs_v=None, *, inputs_kv=None, mask=None, deterministic=None, dropout_rng=None, sow_weights=False)[source]#
Applies multi-head dot product attention on the input data.
Projects the inputs into multi-headed query, key, and value vectors, applies dot-product attention, and projects the results to an output vector.
If both inputs_k and inputs_v are None, they will both copy the value of inputs_q (self attention). If only inputs_v is None, it will copy the value of inputs_k.
- Parameters:
inputs_q – input queries of shape [batch_sizes..., length, features].
inputs_k – key of shape [batch_sizes..., length, features]. If None, inputs_k will copy the value of inputs_q.
inputs_v – values of shape [batch_sizes..., length, features]. If None, inputs_v will copy the value of inputs_k.
inputs_kv – key/values of shape [batch_sizes..., length, features]. If None, inputs_kv will copy the value of inputs_q. This arg will be deprecated soon. Use inputs_k and inputs_v instead.
mask – attention mask of shape [batch_sizes..., num_heads, query_length, key/value_length]. Attention weights are masked out if their corresponding mask value is False.
deterministic – if false, the attention weight is masked randomly using dropout, whereas if true, the attention weights are deterministic.
dropout_rng – optional rng key to pass to the attention layer's dropout mask. Otherwise, self.make_rng('dropout') is used instead.
sow_weights – if True, the attention weights are sowed into the 'intermediates' collection. Remember to mark 'intermediates' as mutable via mutable=['intermediates'] in order to have that collection returned.
- Returns:
output of shape
[batch_sizes..., length, features].
- class flax.linen.MultiHeadAttention(num_heads, dtype=None, param_dtype=<class 'jax.numpy.float32'>, qkv_features=None, out_features=None, broadcast_dropout=True, dropout_rate=0.0, deterministic=None, precision=None, kernel_init=<function variance_scaling.<locals>.init>, out_kernel_init=None, bias_init=<function zeros>, out_bias_init=None, use_bias=True, attention_fn=<function dot_product_attention>, decode=False, normalize_qk=False, force_fp32_for_softmax=False, qkv_dot_general=None, out_dot_general=None, qkv_dot_general_cls=None, out_dot_general_cls=None, qk_attn_weights_einsum_cls=None, attn_weights_value_einsum_cls=None, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
Multi-head dot-product attention. Alias for MultiHeadDotProductAttention.
NOTE: MultiHeadAttention is a wrapper of MultiHeadDotProductAttention, and so their implementations are identical. However, MultiHeadAttention layers will, by default, be named MultiHeadAttention_{index}, whereas MultiHeadDotProductAttention will be named MultiHeadDotProductAttention_{index}. Therefore, this could affect checkpointing, param collection names and RNG threading (since the layer name is used when generating new RNGs) within the module.
Example usage:
>>> import flax.linen as nn
>>> import jax
>>> layer = nn.MultiHeadAttention(num_heads=8, qkv_features=16)
>>> key1, key2, key3, key4, key5, key6 = jax.random.split(jax.random.key(0), 6)
>>> shape = (4, 3, 2, 5)
>>> q, k, v = jax.random.uniform(key1, shape), jax.random.uniform(key2, shape), jax.random.uniform(key3, shape)
>>> variables = layer.init(jax.random.key(0), q)
>>> # different inputs for inputs_q, inputs_k and inputs_v
>>> out = layer.apply(variables, q, k, v)
>>> # equivalent to layer.apply(variables, inputs_q=q, inputs_k=k, inputs_v=k)
>>> out = layer.apply(variables, q, k)
>>> # equivalent to layer.apply(variables, inputs_q=q, inputs_k=q) and layer.apply(variables, inputs_q=q, inputs_k=q, inputs_v=q)
>>> out = layer.apply(variables, q)
>>> attention_kwargs = dict(
...     num_heads=8,
...     qkv_features=16,
...     kernel_init=nn.initializers.ones,
...     bias_init=nn.initializers.zeros,
...     dropout_rate=0.5,
...     deterministic=False,
... )
>>> class Module(nn.Module):
...   attention_kwargs: dict
...
...   @nn.compact
...   def __call__(self, x, dropout_rng=None):
...     out1 = nn.MultiHeadAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng)
...     out2 = nn.MultiHeadAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng)
...     return out1, out2
>>> module = Module(attention_kwargs)
>>> variables = module.init({'params': key1, 'dropout': key2}, q)
>>> # out1 and out2 are different.
>>> out1, out2 = module.apply(variables, q, rngs={'dropout': key3})
>>> # out3 and out4 are different.
>>> # out1 and out3 are different. out2 and out4 are different.
>>> out3, out4 = module.apply(variables, q, rngs={'dropout': key4})
>>> # out1 and out2 are the same.
>>> out1, out2 = module.apply(variables, q, dropout_rng=key5)
>>> # out1 and out2 are the same as out3 and out4.
>>> # providing a `dropout_rng` arg will take precedence over the `rngs` arg in `.apply`
>>> out3, out4 = module.apply(variables, q, rngs={'dropout': key6}, dropout_rng=key5)
- num_heads#
number of attention heads. Features (i.e. inputs_q.shape[-1]) should be divisible by the number of heads.
- Type:
int
- dtype#
the dtype of the computation (default: infer from inputs and params)
- Type:
str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType | Any | None
- param_dtype#
the dtype passed to parameter initializers (default: float32)
- Type:
str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType | Any
- qkv_features#
dimension of the key, query, and value.
- Type:
int | None
- out_features#
dimension of the last projection
- Type:
int | None
- broadcast_dropout#
bool: use a broadcasted dropout along batch dims.
- Type:
bool
- dropout_rate#
dropout rate
- Type:
float
- deterministic#
if false, the attention weight is masked randomly using dropout, whereas if true, the attention weights are deterministic.
- Type:
bool | None
- precision#
numerical precision of the computation; see jax.lax.Precision for details.
- Type:
None | str | jax._src.lax.lax.Precision | tuple[str, str] | tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]
- kernel_init#
initializer for the kernel of the Dense layers.
- Type:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]
- bias_init#
initializer for the bias of the Dense layers.
- Type:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]
- use_bias#
bool: whether pointwise QKVO dense transforms use bias.
- Type:
bool
- attention_fn#
dot_product_attention or compatible function. Accepts query, key, value, and returns output of shape [bs, dim1, dim2, ..., dimN, num_heads, value_channels].
- Type:
collections.abc.Callable[[…], jax.Array | Any]
- decode#
whether to prepare and use an autoregressive cache.
- Type:
bool
- normalize_qk#
should QK normalization be applied (arxiv.org/abs/2302.05442).
- Type:
bool
- __call__(inputs_q, inputs_k=None, inputs_v=None, *, inputs_kv=None, mask=None, deterministic=None, dropout_rng=None, sow_weights=False)#
Applies multi-head dot product attention on the input data.
Projects the inputs into multi-headed query, key, and value vectors, applies dot-product attention, and projects the results to an output vector.
If both inputs_k and inputs_v are None, they will both copy the value of inputs_q (self attention). If only inputs_v is None, it will copy the value of inputs_k.
- Parameters:
inputs_q – input queries of shape [batch_sizes..., length, features].
inputs_k – key of shape [batch_sizes..., length, features]. If None, inputs_k will copy the value of inputs_q.
inputs_v – values of shape [batch_sizes..., length, features]. If None, inputs_v will copy the value of inputs_k.
inputs_kv – key/values of shape [batch_sizes..., length, features]. If None, inputs_kv will copy the value of inputs_q. This arg will be deprecated soon. Use inputs_k and inputs_v instead.
mask – attention mask of shape [batch_sizes..., num_heads, query_length, key/value_length]. Attention weights are masked out if their corresponding mask value is False.
deterministic – if false, the attention weight is masked randomly using dropout, whereas if true, the attention weights are deterministic.
dropout_rng – optional rng key to pass to the attention layer's dropout mask. Otherwise, self.make_rng('dropout') is used instead.
sow_weights – if True, the attention weights are sowed into the 'intermediates' collection. Remember to mark 'intermediates' as mutable via mutable=['intermediates'] in order to have that collection returned.
- Returns:
output of shape
[batch_sizes..., length, features].
- class flax.linen.SelfAttention(num_heads, dtype=None, param_dtype=<class 'jax.numpy.float32'>, qkv_features=None, out_features=None, broadcast_dropout=True, dropout_rate=0.0, deterministic=None, precision=None, kernel_init=<function variance_scaling.<locals>.init>, out_kernel_init=None, bias_init=<function zeros>, out_bias_init=None, use_bias=True, attention_fn=<function dot_product_attention>, decode=False, normalize_qk=False, force_fp32_for_softmax=False, qkv_dot_general=None, out_dot_general=None, qkv_dot_general_cls=None, out_dot_general_cls=None, qk_attn_weights_einsum_cls=None, attn_weights_value_einsum_cls=None, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
Self-attention special case of multi-head dot-product attention. This layer is deprecated in favor of MultiHeadDotProductAttention.
Example usage:
>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> layer = nn.MultiHeadDotProductAttention(num_heads=8, qkv_features=16)
>>> variables = layer.init(jax.random.key(0), jnp.ones((4, 3, 2, 5)))
- __call__(inputs_q, mask=None, deterministic=None, dropout_rng=None, sow_weights=False)[source]#
Applies multi-head dot product self-attention on the input data.
Projects the inputs into multi-headed query, key, and value vectors, applies dot-product attention, and projects the results to an output vector.
- Parameters:
inputs_q – input queries of shape [batch_sizes..., length, features].
mask – attention mask of shape [batch_sizes..., num_heads, query_length, key/value_length]. Attention weights are masked out if their corresponding mask value is False.
deterministic – if false, the attention weight is masked randomly using dropout, whereas if true, the attention weights are deterministic.
- Returns:
output of shape
[batch_sizes..., length, features].
- flax.linen.dot_product_attention_weights(query, key, bias=None, mask=None, broadcast_dropout=True, dropout_rng=None, dropout_rate=0.0, deterministic=False, dtype=None, precision=None, module=None, force_fp32_for_softmax=False, einsum_dot_general=None, einsum=None)[source]#
Computes dot-product attention weights given query and key.
Used by dot_product_attention(), which is what you'll most likely use. But if you want access to the attention weights for introspection, then you can directly call this function and call einsum yourself.
- Parameters:
query – queries for calculating attention with shape of [batch..., q_length, num_heads, qk_depth_per_head].
key – keys for calculating attention with shape of [batch..., kv_length, num_heads, qk_depth_per_head].
bias – bias for the attention weights. This should be broadcastable to the shape [batch..., num_heads, q_length, kv_length]. This can be used for incorporating causal masks, padding masks, proximity bias, etc.
mask – mask for the attention weights. This should be broadcastable to the shape [batch..., num_heads, q_length, kv_length]. This can be used for incorporating causal masks. Attention weights are masked out if their corresponding mask value is False.
broadcast_dropout – bool: use a broadcasted dropout along batch dims.
dropout_rng – JAX PRNGKey: to be used for dropout
dropout_rate – dropout rate
deterministic – bool, deterministic or not (to apply dropout)
dtype – the dtype of the computation (default: infer from inputs and params)
precision – numerical precision of the computation; see jax.lax.Precision for details.
module – the Module that will sow the attention weights into the 'intermediates' collection. Remember to mark 'intermediates' as mutable via mutable=['intermediates'] in order to have that collection returned. If module is None, the attention weights will not be sowed.
force_fp32_for_softmax – bool, whether to force the softmax to be computed in fp32. This is useful for mixed-precision training where higher precision is desired for numerical stability.
einsum_dot_general – the dot_general to use in einsum.
einsum – If unspecified, default jnp.einsum will be used. This argument is mutually exclusive with precision and einsum_dot_general.
- Raises:
ValueError – if both precision/einsum_dot_general and einsum are specified.
- Returns:
Output of shape
[batch..., num_heads, q_length, kv_length].
- flax.linen.dot_product_attention(query, key, value, bias=None, mask=None, broadcast_dropout=True, dropout_rng=None, dropout_rate=0.0, deterministic=False, dtype=None, precision=None, module=None, force_fp32_for_softmax=False, einsum_dot_general=None, qk_attn_weights_einsum=None, attn_weights_value_einsum=None)[source]#
Computes dot-product attention given query, key, and value.
This is the core function for applying attention based on https://arxiv.org/abs/1706.03762. It calculates the attention weights given query and key and combines the values using the attention weights.
Note
query, key, value needn't have any batch dimensions.
- Parameters:
query – queries for calculating attention with shape of [batch..., q_length, num_heads, qk_depth_per_head].
key – keys for calculating attention with shape of [batch..., kv_length, num_heads, qk_depth_per_head].
value – values to be used in attention with shape of [batch..., kv_length, num_heads, v_depth_per_head].
bias – bias for the attention weights. This should be broadcastable to the shape [batch..., num_heads, q_length, kv_length]. This can be used for incorporating causal masks, padding masks, proximity bias, etc.
mask – mask for the attention weights. This should be broadcastable to the shape [batch..., num_heads, q_length, kv_length]. This can be used for incorporating causal masks. Attention weights are masked out if their corresponding mask value is False.
broadcast_dropout – bool: use a broadcasted dropout along batch dims.
dropout_rng – JAX PRNGKey: to be used for dropout
dropout_rate – dropout rate
deterministic – bool, deterministic or not (to apply dropout)
dtype – the dtype of the computation (default: infer from inputs)
precision – numerical precision of the computation; see jax.lax.Precision for details.
module – the Module that will sow the attention weights into the 'intermediates' collection. Remember to mark 'intermediates' as mutable via mutable=['intermediates'] in order to have that collection returned. If module is None, the attention weights will not be sowed.
force_fp32_for_softmax – bool, whether to force the softmax to be computed in fp32. This is useful for mixed-precision training where higher precision is desired for numerical stability.
einsum_dot_general – the dot_general to use in jnp.einsum.
qk_attn_weights_einsum – the einsum for computing the attention weights. When unspecified, the default jnp.einsum will be used. This argument is mutually exclusive with precision and einsum_dot_general.
attn_weights_value_einsum – the einsum for computing the product of the attention weights and the values. When unspecified, the default jnp.einsum will be used. This argument is mutually exclusive with precision and einsum_dot_general.
- Returns:
Output of shape
[batch..., q_length, num_heads, v_depth_per_head].
- Raises:
ValueError – if both precision/einsum_dot_general and qk_attn_weights_einsum/attn_weights_value_einsum are specified.
- flax.linen.make_attention_mask(query_input, key_input, pairwise_fn=<jnp.ufunc 'multiply'>, extra_batch_dims=0, dtype=<class 'jax.numpy.float32'>)[source]#
Mask-making helper for attention weights.
In case of 1d inputs (i.e., [batch..., len_q], [batch..., len_kv]), the attention weights will be [batch..., heads, len_q, len_kv] and this function will produce [batch..., 1, len_q, len_kv].
- Parameters:
query_input – a batched, flat input of query_length size
key_input – a batched, flat input of key_length size
pairwise_fn – broadcasting elementwise comparison function
extra_batch_dims – number of extra batch dims to add singleton axes for, none by default
dtype – mask return dtype
- Returns:
A [batch..., 1, len_q, len_kv] shaped mask for 1d attention.
- flax.linen.make_causal_mask(x, extra_batch_dims=0, dtype=<class 'jax.numpy.float32'>)[source]#
Make a causal mask for self-attention.
In case of 1d inputs (i.e., [batch..., len]), the self-attention weights will be [batch..., heads, len, len] and this function will produce a causal mask of shape [batch..., 1, len, len].
- Parameters:
x – input array of shape [batch..., len]
extra_batch_dims – number of batch dims to add singleton axes for, none by default
dtype – mask return dtype
- Returns:
A [batch..., 1, len, len] shaped causal mask for 1d attention.
Recurrent#
- class flax.linen.RNNCellBase(parent=<flax.linen.module._Sentinel object>, name=None)[source]#
RNN cell base class.
- __call__(**kwargs)#
Call self as a function.
- initialize_carry(rng, input_shape)[source]#
Initialize the RNN cell carry.
- Parameters:
rng – random number generator passed to the init_fn.
input_shape – a tuple providing the shape of the input to the cell.
- Returns:
An initialized carry for the given RNN cell.
Methods
initialize_carry(rng, input_shape) – Initialize the RNN cell carry.
- class flax.linen.LSTMCell(features, gate_fn=<PjitFunction of <function sigmoid>>, activation_fn=<PjitFunction of <function tanh>>, kernel_init=<function variance_scaling.<locals>.init>, recurrent_kernel_init=<function orthogonal.<locals>.init>, bias_init=<function zeros>, dtype=None, param_dtype=<class 'jax.numpy.float32'>, carry_init=<function zeros>, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
LSTM cell.
The mathematical definition of the cell is as follows
\[\begin{split}\begin{array}{ll}
i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\
f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\
g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\
o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\
c' = f * c + i * g \\
h' = o * \tanh(c') \\
\end{array}\end{split}\]
where x is the input, h is the output of the previous time step, and c is the memory.
Example usage:
>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> x = jax.random.normal(jax.random.key(0), (2, 3))
>>> layer = nn.LSTMCell(features=4)
>>> carry = layer.initialize_carry(jax.random.key(1), x.shape)
>>> variables = layer.init(jax.random.key(2), carry, x)
>>> new_carry, out = layer.apply(variables, carry, x)
- features#
number of output features.
- Type:
int
- gate_fn#
activation function used for gates (default: sigmoid).
- Type:
collections.abc.Callable[[…], Any]
- activation_fn#
activation function used for output and memory update (default: tanh).
- Type:
collections.abc.Callable[[…], Any]
- kernel_init#
initializer function for the kernels that transform the input (default: lecun_normal).
- Type:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]
- recurrent_kernel_init#
initializer function for the kernels that transform the hidden state (default: initializers.orthogonal()).
- Type:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]
- bias_init#
initializer for the bias parameters (default: initializers.zeros_init())
- Type:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]
- dtype#
the dtype of the computation (default: infer from inputs and params).
- Type:
str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType | Any | None
- param_dtype#
the dtype passed to parameter initializers (default: float32).
- Type:
str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType | Any
- __call__(carry, inputs)[source]#
A long short-term memory (LSTM) cell.
- Parameters:
carry – the hidden state of the LSTM cell, initialized using LSTMCell.initialize_carry.
inputs – an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions.
- Returns:
A tuple with the new carry and the output.
- initialize_carry(rng, input_shape)[source]#
Initialize the RNN cell carry.
- Parameters:
rng – random number generator passed to the init_fn.
input_shape – a tuple providing the shape of the input to the cell.
- Returns:
An initialized carry for the given RNN cell.
Methods
initialize_carry(rng, input_shape) – Initialize the RNN cell carry.
- class flax.linen.OptimizedLSTMCell(features, gate_fn=<PjitFunction of <function sigmoid>>, activation_fn=<PjitFunction of <function tanh>>, kernel_init=<function variance_scaling.<locals>.init>, recurrent_kernel_init=<function orthogonal.<locals>.init>, bias_init=<function zeros>, dtype=None, param_dtype=<class 'jax.numpy.float32'>, carry_init=<function zeros>, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
More efficient LSTM Cell that concatenates state components before matmul.
The parameters are compatible with LSTMCell. Note that this cell is often faster than LSTMCell as long as the hidden size is roughly <= 2048 units.
The mathematical definition of the cell is the same as LSTMCell and as follows
\[\begin{split}\begin{array}{ll}
i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\
f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\
g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\
o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\
c' = f * c + i * g \\
h' = o * \tanh(c') \\
\end{array}\end{split}\]
where x is the input, h is the output of the previous time step, and c is the memory.
Example usage:
>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> x = jax.random.normal(jax.random.key(0), (2, 3))
>>> layer = nn.OptimizedLSTMCell(features=4)
>>> carry = layer.initialize_carry(jax.random.key(1), x.shape)
>>> variables = layer.init(jax.random.key(2), carry, x)
>>> new_carry, out = layer.apply(variables, carry, x)
- gate_fn#
activation function used for gates (default: sigmoid).
- Type:
collections.abc.Callable[[…], Any]
- activation_fn#
activation function used for output and memory update (default: tanh).
- Type:
collections.abc.Callable[[…], Any]
- kernel_init#
initializer function for the kernels that transform the input (default: lecun_normal).
- Type:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]
- recurrent_kernel_init#
initializer function for the kernels that transform the hidden state (default: initializers.orthogonal()).
- Type:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]
- bias_init#
initializer for the bias parameters (default: initializers.zeros_init()).
- Type:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]
- dtype#
the dtype of the computation (default: infer from inputs and params).
- Type:
str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType | Any | None
- param_dtype#
the dtype passed to parameter initializers (default: float32).
- Type:
str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType | Any
- __call__(carry, inputs)[source]#
An optimized long short-term memory (LSTM) cell.
- Parameters:
carry – the hidden state of the LSTM cell, initialized using LSTMCell.initialize_carry.
inputs – an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions.
- Returns:
A tuple with the new carry and the output.
- initialize_carry(rng, input_shape)[source]#
Initialize the RNN cell carry.
- Parameters:
rng – random number generator passed to the init_fn.
input_shape – a tuple providing the shape of the input to the cell.
- Returns:
An initialized carry for the given RNN cell.
Methods
initialize_carry(rng, input_shape) – Initialize the RNN cell carry.
- class flax.linen.ConvLSTMCell(features, kernel_size, strides=None, padding='SAME', use_bias=True, dtype=None, param_dtype=<class 'jax.numpy.float32'>, carry_init=<function zeros>, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
A convolutional LSTM cell.
The implementation is based on xingjian2015convolutional. Given x_t and the previous state (h_{t-1}, c_{t-1}) the core computes
\[\begin{split}\begin{array}{ll}
i_t = \sigma(W_{ii} * x_t + W_{hi} * h_{t-1} + b_i) \\
f_t = \sigma(W_{if} * x_t + W_{hf} * h_{t-1} + b_f) \\
g_t = \tanh(W_{ig} * x_t + W_{hg} * h_{t-1} + b_g) \\
o_t = \sigma(W_{io} * x_t + W_{ho} * h_{t-1} + b_o) \\
c_t = f_t c_{t-1} + i_t g_t \\
h_t = o_t \tanh(c_t)
\end{array}\end{split}\]
where * denotes the convolution operator; i_t, f_t, o_t are input, forget and output gate activations, and g_t is a vector of cell updates.
Note
- Forget gate initialization:
Following jozefowicz2015empirical we add 1.0 to b_f after initialization in order to reduce the scale of forgetting in the beginning of the training.
Example usage:
>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> x = jax.random.normal(jax.random.key(0), (3, 5, 5))
>>> layer = nn.ConvLSTMCell(features=4, kernel_size=(2, 2))
>>> carry = layer.initialize_carry(jax.random.key(1), x.shape)
>>> variables = layer.init(jax.random.key(2), carry, x)
>>> new_carry, out = layer.apply(variables, carry, x)
- features#
number of convolution filters.
- Type:
int
- kernel_size#
shape of the convolutional kernel.
- Type:
collections.abc.Sequence[int]
- strides#
a sequence of n integers, representing the inter-window strides.
- Type:
collections.abc.Sequence[int] | None
- padding#
either the string 'SAME', the string 'VALID', or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension.
- Type:
str | collections.abc.Sequence[tuple[int, int]]
- use_bias#
whether to add a bias to the output (default: True).
- dtype#
the dtype of the computation (default: None).
- Type:
str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType | Any | None
- param_dtype#
the dtype passed to parameter initializers (default: float32).
- Type:
str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType | Any
- __call__(carry, inputs)[source]#
Constructs a convolutional LSTM.
- Parameters:
carry – the hidden state of the ConvLSTMCell, initialized using ConvLSTMCell.initialize_carry.
inputs – input data with dimensions (batch, spatial_dims…, features).
- Returns:
A tuple with the new carry and the output.
- initialize_carry(rng, input_shape)[source]#
Initialize the RNN cell carry.
- Parameters:
rng – random number generator passed to the init_fn.
input_shape – a tuple providing the shape of the input to the cell.
- Returns:
An initialized carry for the given RNN cell.
Methods
initialize_carry(rng, input_shape) – Initialize the RNN cell carry.
- class flax.linen.SimpleCell(features, activation_fn=<PjitFunction of <function tanh>>, kernel_init=<function variance_scaling.<locals>.init>, recurrent_kernel_init=<function orthogonal.<locals>.init>, bias_init=<function zeros>, dtype=None, param_dtype=<class 'jax.numpy.float32'>, carry_init=<function zeros>, residual=False, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
Simple cell.
The mathematical definition of the cell is as follows
\[\begin{array}{ll} h' = \tanh(W_i x + b_i + W_h h) \end{array}\]
where x is the input and h is the output of the previous time step.
If residual is True,
\[\begin{array}{ll} h' = \tanh(W_i x + b_i + W_h h + h) \end{array}\]
Example usage:
>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> x = jax.random.normal(jax.random.key(0), (2, 3))
>>> layer = nn.SimpleCell(features=4)
>>> carry = layer.initialize_carry(jax.random.key(1), x.shape)
>>> variables = layer.init(jax.random.key(2), carry, x)
>>> new_carry, out = layer.apply(variables, carry, x)
- features#
number of output features.
- Type:
int
- activation_fn#
activation function used for output and memory update (default: tanh).
- Type:
collections.abc.Callable[[…], Any]
- kernel_init#
initializer function for the kernels that transform the input (default: lecun_normal).
- Type:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]
- recurrent_kernel_init#
initializer function for the kernels that transform the hidden state (default: initializers.orthogonal()).
- Type:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]
- bias_init#
initializer for the bias parameters (default: initializers.zeros_init())
- Type:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]
- dtype#
the dtype of the computation (default: None).
- Type:
str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType | Any | None
- param_dtype#
the dtype passed to parameter initializers (default: float32).
- Type:
str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType | Any
- residual#
pre-activation residual connection (https://arxiv.org/abs/1801.06105).
- Type:
bool
- __call__(carry, inputs)[source]#
Simple cell.
- Parameters:
carry – the hidden state of the Simple cell, initialized using SimpleCell.initialize_carry.
inputs – an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions.
- Returns:
A tuple with the new carry and the output.
- initialize_carry(rng, input_shape)[source]#
Initialize the RNN cell carry.
- Parameters:
rng – random number generator passed to the init_fn.
input_shape – a tuple providing the shape of the input to the cell.
- Returns:
An initialized carry for the given RNN cell.
Methods
initialize_carry(rng, input_shape) – Initialize the RNN cell carry.
- class flax.linen.GRUCell(features, gate_fn=<PjitFunction of <function sigmoid>>, activation_fn=<PjitFunction of <function tanh>>, kernel_init=<function variance_scaling.<locals>.init>, recurrent_kernel_init=<function orthogonal.<locals>.init>, bias_init=<function zeros>, dtype=None, param_dtype=<class 'jax.numpy.float32'>, carry_init=<function zeros>, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
GRU cell.
The mathematical definition of the cell is as follows:
\[\begin{split}\begin{array}{ll}
r = \sigma(W_{ir} x + b_{ir} + W_{hr} h) \\
z = \sigma(W_{iz} x + b_{iz} + W_{hz} h) \\
n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\
h' = (1 - z) * n + z * h \\
\end{array}\end{split}\]
where x is the input and h is the output of the previous time step.
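The update equations above can be sketched directly in jax.numpy. This is an illustrative re-implementation, not the code used by GRUCell: the function names, the parameter dictionary and its keys (named after the symbols in the equations), and the 0.1 weight scale are all assumptions made for the example. Inputs are stored as rows, so W x is written x @ W.

```python
import jax
import jax.numpy as jnp


def gru_step(p, h, x):
    """One GRU update following the equations above (hypothetical helper)."""
    r = jax.nn.sigmoid(x @ p["W_ir"] + p["b_ir"] + h @ p["W_hr"])  # reset gate
    z = jax.nn.sigmoid(x @ p["W_iz"] + p["b_iz"] + h @ p["W_hz"])  # update gate
    n = jnp.tanh(x @ p["W_in"] + p["b_in"] + r * (h @ p["W_hn"] + p["b_hn"]))
    return (1.0 - z) * n + z * h  # interpolate between candidate and old state


def init_gru_params(key, in_features, features):
    """Small random kernels and zero biases; purely illustrative initial values."""
    keys = jax.random.split(key, 6)
    p = {}
    for i, name in enumerate(["W_ir", "W_iz", "W_in"]):
        p[name] = 0.1 * jax.random.normal(keys[i], (in_features, features))
    for i, name in enumerate(["W_hr", "W_hz", "W_hn"]):
        p[name] = 0.1 * jax.random.normal(keys[3 + i], (features, features))
    for name in ["b_ir", "b_iz", "b_in", "b_hn"]:
        p[name] = jnp.zeros((features,))
    return p


x = jax.random.normal(jax.random.key(0), (2, 3))  # (batch, in_features)
h = jnp.zeros((2, 4))                             # (batch, features)
params = init_gru_params(jax.random.key(1), in_features=3, features=4)
h_new = gru_step(params, h, x)
```

With h set to zero the last line reduces to (1 - z) * n, so every entry of h_new lies strictly between -1 and 1.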
Example usage:
>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> x = jax.random.normal(jax.random.key(0), (2, 3))
>>> layer = nn.GRUCell(features=4)
>>> carry = layer.initialize_carry(jax.random.key(1), x.shape)
>>> variables = layer.init(jax.random.key(2), carry, x)
>>> new_carry, out = layer.apply(variables, carry, x)
- features#
number of output features.
- Type:
int
- gate_fn#
activation function used for gates (default: sigmoid).
- Type:
collections.abc.Callable[[…], Any]
- activation_fn#
activation function used for output and memory update (default: tanh).
- Type:
collections.abc.Callable[[…], Any]
- kernel_init#
initializer function for the kernels that transform the input (default: lecun_normal).
- Type:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]
- recurrent_kernel_init#
initializer function for the kernels that transform the hidden state (default: initializers.orthogonal()).
- Type:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]
- bias_init#
initializer for the bias parameters (default: initializers.zeros_init()).
- Type:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]
- dtype#
the dtype of the computation (default: None).
- Type:
str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType | Any | None
- param_dtype#
the dtype passed to parameter initializers (default: float32).
- Type:
str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType | Any
- __call__(carry, inputs)[source]#
Gated recurrent unit (GRU) cell.
- Parameters:
carry – the hidden state of the GRU cell, initialized using GRUCell.initialize_carry.
inputs – an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions.
- Returns:
A tuple with the new carry and the output.
- initialize_carry(rng, input_shape)[source]#
Initialize the RNN cell carry.
- Parameters:
rng – random number generator passed to the init_fn.
input_shape – a tuple providing the shape of the input to the cell.
- Returns:
An initialized carry for the given RNN cell.
Methods
initialize_carry(rng, input_shape) – Initialize the RNN cell carry.
- class flax.linen.MGUCell(features, gate_fn=<PjitFunction of <function sigmoid>>, activation_fn=<PjitFunction of <function tanh>>, kernel_init=<function variance_scaling.<locals>.init>, recurrent_kernel_init=<function orthogonal.<locals>.init>, forget_bias_init=<function ones>, activation_bias_init=<function zeros>, dtype=None, param_dtype=<class 'jax.numpy.float32'>, carry_init=<function zeros>, reset_gate=True, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
MGU cell (https://arxiv.org/pdf/1603.09420.pdf).
The mathematical definition of the cell is as follows:
\[\begin{split}\begin{array}{ll}
f = \sigma(W_{if} x + b_{if} + W_{hf} h) \\
n = \tanh(W_{in} x + b_{in} + f * (W_{hn} h + b_{hn})) \\
h' = (1 - f) * n + f * h \\
\end{array}\end{split}\]
where x is the input and h is the output of the previous time step.
If reset_gate is false, the above becomes
\[\begin{split}\begin{array}{ll}
f = \sigma(W_{if} x + b_{if} + W_{hf} h) \\
n = \tanh(W_{in} x + b_{in} + W_{hn} h) \\
h' = (1 - f) * n + f * h \\
\end{array}\end{split}\]
Example usage:
>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> x = jax.random.normal(jax.random.key(0), (2, 3))
>>> layer = nn.MGUCell(features=4)
>>> carry = layer.initialize_carry(jax.random.key(1), x.shape)
>>> variables = layer.init(jax.random.key(2), carry, x)
>>> new_carry, out = layer.apply(variables, carry, x)
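The two sets of MGU equations above differ only in how the previous state enters the candidate n. The following jax.numpy sketch makes that difference explicit. It is an illustration written from the equations, not MGUCell's implementation: the helper name, the parameter dictionary and the initial values are assumptions, and inputs are stored as rows so W x is written x @ W.

```python
import jax
import jax.numpy as jnp


def mgu_step(p, h, x, reset_gate=True):
    """One MGU update following the equations above (hypothetical helper)."""
    f = jax.nn.sigmoid(x @ p["W_if"] + p["b_if"] + h @ p["W_hf"])  # forget gate
    if reset_gate:
        # The forget gate also scales the recurrent contribution to n.
        n = jnp.tanh(x @ p["W_in"] + p["b_in"] + f * (h @ p["W_hn"] + p["b_hn"]))
    else:
        # No reset gating: the recurrent term enters n unscaled and without b_hn.
        n = jnp.tanh(x @ p["W_in"] + p["b_in"] + h @ p["W_hn"])
    return (1.0 - f) * n + f * h


in_features, features = 3, 4
keys = jax.random.split(jax.random.key(0), 6)
p = {
    "W_if": 0.1 * jax.random.normal(keys[0], (in_features, features)),
    "W_in": 0.1 * jax.random.normal(keys[1], (in_features, features)),
    "W_hf": 0.1 * jax.random.normal(keys[2], (features, features)),
    "W_hn": 0.1 * jax.random.normal(keys[3], (features, features)),
    "b_if": jnp.ones((features,)),   # forget bias of ones, as in forget_bias_init
    "b_in": jnp.zeros((features,)),
    "b_hn": jnp.zeros((features,)),
}
x = jax.random.normal(keys[4], (2, in_features))
h = jax.random.normal(keys[5], (2, features))

h_gated = mgu_step(p, h, x, reset_gate=True)
h_plain = mgu_step(p, h, x, reset_gate=False)
```

With a zero previous state and a zero b_hn the two variants coincide, because the recurrent term vanishes in both.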
- features#
number of output features.
- Type:
int
- gate_fn#
activation function used for gates (default: sigmoid).
- Type:
collections.abc.Callable[[…], Any]
- activation_fn#
activation function used for output and memory update (default: tanh).
- Type:
collections.abc.Callable[[…], Any]
- kernel_init#
initializer function for the kernels that transform the input (default: lecun_normal).
- Type:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]
- recurrent_kernel_init#
initializer function for the kernels that transform the hidden state (default: initializers.orthogonal()).
- Type:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]
- forget_bias_init#
initializer for the bias parameters of the forget gate. The default is set to initializers.ones_init() because this prevents vanishing gradients. See https://proceedings.mlr.press/v37/jozefowicz15.pdf, section 2.2 for more details.
- Type:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]
- activation_bias_init#
initializer for the bias parameters of the activation output (default: initializers.zeros_init()).
- Type:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]
- dtype#
the dtype of the computation (default: None).
- Type:
str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType | Any | None
- param_dtype#
the dtype passed to parameter initializers (default: float32).
- Type:
str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType | Any
- reset_gate#
flag for applying reset gating.
- Type:
bool
- __call__(carry, inputs)[source]#
Minimal gated unit (MGU) cell.
- Parameters:
carry – the hidden state of the MGU cell, initialized using MGUCell.initialize_carry.
inputs – an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions.
- Returns:
A tuple with the new carry and the output.
- initialize_carry(rng, input_shape)[source]#
Initialize the RNN cell carry.
- Parameters:
rng – random number generator passed to the init_fn.
input_shape – a tuple providing the shape of the input to the cell.
- Returns:
An initialized carry for the given RNN cell.
Methods
initialize_carry(rng, input_shape) – Initialize the RNN cell carry.
- class flax.linen.RNN(cell, time_major=False, return_carry=False, reverse=False, keep_order=False, unroll=1, variable_axes=FrozenDict({}), variable_broadcast='params', variable_carry=False, split_rngs=FrozenDict({ params: False, }), parent=<flax.linen.module._Sentinel object>, name=None)[source]#
The RNN module takes any RNNCellBase instance and applies it over a sequence using flax.linen.scan().
Example:
>>> import jax.numpy as jnp
>>> import jax
>>> import flax.linen as nn
>>> x = jnp.ones((10, 50, 32)) # (batch, time, features)
>>> lstm = nn.RNN(nn.LSTMCell(64))
>>> variables = lstm.init(jax.random.key(0), x)
>>> y = lstm.apply(variables, x)
>>> y.shape # (batch, time, cell_size)
(10, 50, 64)
As shown above, the output size is set on the cell (here LSTMCell(64)) rather than on RNN. RNN builds the initial carry by calling the cell’s initialize_carry method with the shape of a single time step of the input. The arguments a cell needs may vary: for example, ConvLSTMCell takes the number of features and a kernel_size of the form (kernel_height, kernel_width), and expects inputs with spatial dimensions:
>>> x = jnp.ones((10, 50, 32, 32, 3)) # (batch, time, height, width, features)
>>> conv_lstm = nn.RNN(nn.ConvLSTMCell(64, kernel_size=(3, 3)))
>>> y, variables = conv_lstm.init_with_output(jax.random.key(0), x)
>>> y.shape # (batch, time, height, width, features)
(10, 50, 32, 32, 64)
By default RNN expects the time dimension after the batch dimension ((*batch, time, *features)). If you set time_major=True, RNN will instead expect the time dimension to be at the beginning ((time, *batch, *features)):
>>> x = jnp.ones((50, 10, 32)) # (time, batch, features)
>>> lstm = nn.RNN(nn.LSTMCell(64), time_major=True)
>>> variables = lstm.init(jax.random.key(0), x)
>>> y = lstm.apply(variables, x)
>>> y.shape # (time, batch, cell_size)
(50, 10, 64)
By default the output is typically an array of shape (*batch, time, *cell_size). If you set return_carry=True, RNN instead returns a tuple of the final carry and the output:
>>> x = jnp.ones((10, 50, 32)) # (batch, time, features)
>>> lstm = nn.RNN(nn.LSTMCell(64), return_carry=True)
>>> variables = lstm.init(jax.random.key(0), x)
>>> carry, y = lstm.apply(variables, x)
>>> jax.tree_util.tree_map(jnp.shape, carry) # ((batch, cell_size), (batch, cell_size))
((10, 64), (10, 64))
>>> y.shape # (batch, time, cell_size)
(10, 50, 64)
To support variable length sequences, you can pass seq_lengths, an integer array of shape (*batch) where each element is the length of the corresponding sequence in the batch. For example:
>>> seq_lengths = jnp.array([3, 2, 5])
The output elements corresponding to padding elements are NOT zeroed out. If return_carry is set to True the carry will be the state of the last valid element of each sequence.
RNN also accepts some of the arguments of flax.linen.scan(). By default they are set to work with cells like LSTMCell and GRUCell, but they can be overridden as needed. Overriding default values to scan looks like this:
>>> lstm = nn.RNN(
...   nn.LSTMCell(64),
...   unroll=1, variable_axes={}, variable_broadcast='params',
...   variable_carry=False, split_rngs={'params': False})
- cell#
an instance of RNNCellBase.
- time_major#
if time_major=False (default) it will expect inputs with shape (*batch, time, *features), else it will expect inputs with shape (time, *batch, *features).
- Type:
bool
- return_carry#
if return_carry=False (default) only the output sequence is returned, else it will return a tuple of the final carry and the output sequence.
- Type:
bool
- reverse#
if reverse=False (default) the sequence is processed from left to right and returned in the original order, else it will be processed from right to left, and returned in reverse order. If seq_lengths is passed, padding will always remain at the end of the sequence.
- Type:
bool
- keep_order#
if keep_order=True and reverse=True, the output will be reversed back to the original order after processing; this is useful to align sequences in bidirectional RNNs. If keep_order=False (default), the output will remain in the order specified by reverse.
- Type:
bool
- unroll#
how many scan iterations to unroll within a single iteration of a loop, defaults to 1. This argument will be passed to nn.scan.
- Type:
int
- variable_axes#
a dictionary mapping each collection to either an integer i (meaning we scan over dimension i) or None (replicate rather than scan). This argument is forwarded to nn.scan.
- Type:
collections.abc.Mapping[bool | str | Collection[str] | DenyList, int | flax.typing.In[int] | flax.typing.Out[int]]
- variable_broadcast#
Specifies the broadcasted variable collections. A broadcasted variable should not depend on any computation that cannot be lifted out of the loop. This is typically used to define shared parameters inside the fn. This argument is forwarded to nn.scan.
- Type:
bool | str | Collection[str] | DenyList
- variable_carry#
Specifies the variable collections that are carried through the loop. Mutations to these variables are carried to the next iteration and will be preserved when the scan finishes. This argument is forwarded to nn.scan.
- Type:
bool | str | Collection[str] | DenyList
- split_rngs#
a mapping from PRNGSequenceFilter to bool specifying whether a collection’s PRNG key should be split such that its values are different at each step, or replicated such that its values remain the same at each step. This argument is forwarded to nn.scan.
- Type:
collections.abc.Mapping[bool | str | Collection[str] | DenyList, bool]
- __call__(inputs, *, initial_carry=None, init_key=None, seq_lengths=None, return_carry=None, time_major=None, reverse=None, keep_order=None)[source]#
Applies the RNN to the inputs.
__call__ allows you to optionally override some attributes like return_carry and time_major defined in the constructor.
- Parameters:
inputs – the input sequence.
initial_carry – the initial carry, if not provided it will be initialized using the cell’s RNNCellBase.initialize_carry() method.
init_key – a PRNG key used to initialize the carry, if not provided jax.random.key(0) will be used. Most cells will ignore this argument.
seq_lengths – an optional integer array of shape (*batch) indicating the length of each sequence. Elements whose index in the time dimension is greater than or equal to the corresponding length are considered padding and are ignored.
return_carry – if return_carry=False (default) only the output sequence is returned, else it will return a tuple of the final carry and the output sequence.
time_major – if time_major=False (default) it will expect inputs with shape (*batch, time, *features), else it will expect inputs with shape (time, *batch, *features).
reverse – overrides the reverse attribute. If reverse=False (default) the sequence is processed from left to right and returned in the original order, else it will be processed from right to left, and returned in reverse order. If seq_lengths is passed, padding will always remain at the end of the sequence.
keep_order – overrides the keep_order attribute. If keep_order=True and reverse=True, the output will be reversed back to the original order after processing; this is useful to align sequences in bidirectional RNNs. If keep_order=False (default), the output will remain in the order specified by reverse.
- Returns:
if return_carry=False (default) only the output sequence is returned, else it will return a tuple of the final carry and the output sequence.
- class flax.linen.Bidirectional(forward_rnn, backward_rnn, merge_fn=<function _concatenate>, time_major=False, return_carry=False, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
Processes the input in both directions and merges the results.
Example usage:
>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> layer = nn.Bidirectional(nn.RNN(nn.GRUCell(4)), nn.RNN(nn.GRUCell(4)))
>>> x = jnp.ones((2, 3))
>>> variables = layer.init(jax.random.key(0), x)
>>> out = layer.apply(variables, x)
- __call__(inputs, *, initial_carry=None, init_key=None, seq_lengths=None, return_carry=None, time_major=None, reverse=None, keep_order=None)[source]#
Call self as a function.
BatchApply#
- class flax.linen.BatchApply(f, num_dims=2)[source]#
Temporarily merges leading dimensions of input tensors.
Merges the leading dimensions of a tensor into a single dimension, runs the given callable, then splits the leading dimension of the result to match the input.
Input arrays whose rank is smaller than the number of dimensions to collapse are passed unmodified.
This may be useful for applying a module to each timestep of e.g. a [Time, Batch, ...] array.
For some fs and platforms, this may be more efficient than jax.vmap(), especially when combined with other transformations like jax.grad().
Example usage:
>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> import numpy as np
>>> a = jax.random.normal(jax.random.key(0), [2, 3, 4])
>>> b = jax.random.normal(jax.random.key(1), [4])
>>> def raises(a, b):
...   if len(a.shape) != 2:
...     raise ValueError("a must be shape 2")
...   if len(b.shape) != 1:
...     raise ValueError("b must be shape 1")
...   return jnp.dot(a, b)
>>> out = nn.BatchApply(raises)(a, b)
>>> expected_merged_leading = raises(a.reshape(2*3, 4), b)
>>> expected = expected_merged_leading.reshape((2, 3) + expected_merged_leading.shape[1:])
>>> np.testing.assert_array_equal(out, expected)
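The merge, apply, split behaviour described above can be written out by hand for a single input array. This is a minimal sketch of the idea, not BatchApply's implementation: merge_apply_split is a hypothetical helper, it handles only one array argument, and it does not pass low-rank inputs through unmodified.

```python
import jax.numpy as jnp


def merge_apply_split(f, x, num_dims=2):
    """Collapse the first num_dims axes, call f, then restore those axes."""
    leading = x.shape[:num_dims]
    merged = x.reshape((-1,) + x.shape[num_dims:])  # (prod(leading), ...)
    out = f(merged)
    return out.reshape(leading + out.shape[1:])     # (*leading, ...)


x = jnp.arange(24.0).reshape(2, 3, 4)  # e.g. (time, batch, features)
w = jnp.ones((4,))

# f only ever sees a rank-2 array of shape (6, 4).
out = merge_apply_split(lambda a: jnp.dot(a, w), x)
```

Here the dot product with a vector of ones sums the feature axis, so out has shape (2, 3) and equals x.sum(axis=-1).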
Methods