Activation functions#
Activation functions.
- class flax.linen.activation.PReLU(param_dtype=<class 'jax.numpy.float32'>, negative_slope_init=0.01, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
Parametric Rectified Linear Unit (PReLU) activation function.
Note that PReLU is a Flax layer and not a simple activation function, so it needs to be initialized before being called.
- Example usage:
>>> import flax.linen as nn
>>> class MLP(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(2)(x)
...     x = nn.PReLU()(x) # initialized
...     return x
- param_dtype#
the dtype passed to parameter initializers (default: float32).
- Type
Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]
- negative_slope_init#
the value to initialize the negative slope (default 0.01).
- Type
float
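A minimal end-to-end sketch (values chosen here for illustration): because the negative slope is a learnable parameter, a standalone PReLU layer is initialized with init before being applied:
>>> import jax, jax.numpy as jnp
>>> import flax.linen as nn
>>> layer = nn.PReLU(negative_slope_init=0.1)
>>> x = jnp.array([[-1.0, 2.0]])
>>> variables = layer.init(jax.random.PRNGKey(0), x)
>>> y = layer.apply(variables, x)
>>> # negative inputs are scaled by the learned slope, others pass through
>>> bool(jnp.allclose(y, jnp.array([[-0.1, 2.0]])))
True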
- flax.linen.activation.celu(x, alpha=1.0)[source]#
Continuously-differentiable exponential linear unit activation.
Computes the element-wise function:
\[\begin{split}\mathrm{celu}(x) = \begin{cases} x, & x > 0\\ \alpha \left(\exp(\frac{x}{\alpha}) - 1\right), & x \le 0 \end{cases}\end{split}\]
For more information, see Continuously Differentiable Exponential Linear Units.
- Parameters
x – input array
alpha – array or scalar (default: 1.0)
- Returns
An array.
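An illustrative check of the two branches, with example values chosen here:
>>> import jax.numpy as jnp
>>> from flax.linen import activation
>>> x = jnp.array([-1.0, 0.0, 2.0])
>>> # non-positive inputs follow alpha * (exp(x/alpha) - 1); positive inputs pass through
>>> bool(jnp.allclose(activation.celu(x, alpha=1.0), jnp.array([jnp.expm1(-1.0), 0.0, 2.0])))
True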
- flax.linen.activation.elu(x, alpha=1.0)[source]#
Exponential linear unit activation function.
Computes the element-wise function:
\[\begin{split}\mathrm{elu}(x) = \begin{cases} x, & x > 0\\ \alpha \left(\exp(x) - 1\right), & x \le 0 \end{cases}\end{split}\]
- Parameters
x – input array
alpha – scalar or array of alpha values (default: 1.0)
- Returns
An array.
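A similar illustrative sketch for elu, with example values chosen here:
>>> import jax.numpy as jnp
>>> from flax.linen import activation
>>> x = jnp.array([-2.0, 0.0, 3.0])
>>> # alpha * (exp(x) - 1) for x <= 0, identity for x > 0
>>> bool(jnp.allclose(activation.elu(x, alpha=1.0), jnp.array([jnp.expm1(-2.0), 0.0, 3.0])))
True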
- flax.linen.activation.gelu(x, approximate=True)[source]#
Gaussian error linear unit activation function.
If approximate=False, computes the element-wise function:
\[\mathrm{gelu}(x) = \frac{x}{2} \left(\mathrm{erfc} \left( \frac{-x}{\sqrt{2}} \right) \right)\]
If approximate=True, uses the approximate formulation of GELU:
\[\mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left( \sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right)\]
For more information, see Gaussian Error Linear Units (GELUs), section 2.
- Parameters
x – input array
approximate – whether to use the approximate or exact formulation.
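An illustrative comparison of the two formulations on a few example points (the tolerance is chosen loosely for this sketch; the approximation is in fact considerably tighter):
>>> import jax.numpy as jnp
>>> from flax.linen import activation
>>> x = jnp.linspace(-3.0, 3.0, 7)
>>> exact = activation.gelu(x, approximate=False)
>>> approx = activation.gelu(x, approximate=True)
>>> bool(jnp.allclose(exact, approx, atol=1e-2))
True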
- flax.linen.activation.glu(x, axis=-1)[source]#
Gated linear unit activation function.
Computes the function:
\[\mathrm{glu}(x) = x\left[\ldots, 0:\frac{n}{2}, \ldots\right] \cdot \mathrm{sigmoid} \left( x\left[\ldots, \frac{n}{2}:n, \ldots\right] \right)\]
where the array is split into two along axis. The size of the axis dimension must be divisible by two.
- Parameters
x – input array
axis – the axis along which the split should be computed (default: -1)
- Returns
An array.
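A shape-oriented sketch with example dimensions: splitting a size-6 axis in half yields a size-3 output along that axis:
>>> import jax.numpy as jnp
>>> from flax.linen import activation
>>> x = jnp.ones((2, 6))
>>> activation.glu(x, axis=-1).shape
(2, 3)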
- flax.linen.activation.hard_sigmoid(x)[source]#
Hard Sigmoid activation function.
Computes the element-wise function
\[\mathrm{hard\_sigmoid}(x) = \frac{\mathrm{relu6}(x + 3)}{6}\]
- Parameters
x – input array
- Returns
An array.
See also
relu6()
- flax.linen.activation.hard_silu(x)[source]#
Hard SiLU (swish) activation function.
Computes the element-wise function
\[\mathrm{hard\_silu}(x) = x \cdot \mathrm{hard\_sigmoid}(x)\]
Both hard_silu() and hard_swish() are aliases for the same function.
- Parameters
x – input array
- Returns
An array.
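An illustrative check of the definition and of the alias, with example values chosen here:
>>> import jax.numpy as jnp
>>> from flax.linen import activation
>>> x = jnp.array([-4.0, -1.0, 0.0, 2.0, 4.0])
>>> y = activation.hard_silu(x)
>>> bool(jnp.allclose(y, x * activation.hard_sigmoid(x)))
True
>>> bool(jnp.allclose(y, activation.hard_swish(x)))
True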
- flax.linen.activation.hard_swish(x)#
Hard SiLU (swish) activation function.
Computes the element-wise function
\[\mathrm{hard\_silu}(x) = x \cdot \mathrm{hard\_sigmoid}(x)\]
Both hard_silu() and hard_swish() are aliases for the same function.
- Parameters
x – input array
- Returns
An array.
- flax.linen.activation.hard_tanh(x)[source]#
Hard \(\mathrm{tanh}\) activation function.
Computes the element-wise function:
\[\begin{split}\mathrm{hard\_tanh}(x) = \begin{cases} -1, & x < -1\\ x, & -1 \le x \le 1\\ 1, & 1 < x \end{cases}\end{split}\]
- Parameters
x – input array
- Returns
An array.
- flax.linen.activation.leaky_relu(x, negative_slope=0.01)[source]#
Leaky rectified linear unit activation function.
Computes the element-wise function:
\[\begin{split}\mathrm{leaky\_relu}(x) = \begin{cases} x, & x \ge 0\\ \alpha x, & x < 0 \end{cases}\end{split}\]
where \(\alpha\) = negative_slope.
- Parameters
x – input array
negative_slope – array or scalar specifying the negative slope (default: 0.01)
- Returns
An array.
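A small sketch with an illustrative negative_slope:
>>> import jax.numpy as jnp
>>> from flax.linen import activation
>>> x = jnp.array([-2.0, 0.0, 2.0])
>>> # negative inputs are scaled by negative_slope, non-negative inputs pass through
>>> bool(jnp.allclose(activation.leaky_relu(x, negative_slope=0.1), jnp.array([-0.2, 0.0, 2.0])))
True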
- flax.linen.activation.log_sigmoid(x)[source]#
Log-sigmoid activation function.
Computes the element-wise function:
\[\mathrm{log\_sigmoid}(x) = \log(\mathrm{sigmoid}(x)) = -\log(1 + e^{-x})\]
- Parameters
x – input array
- Returns
An array.
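An illustrative consistency check against sigmoid() on example values (log_sigmoid computes the same quantity in a form that is typically more robust for large negative inputs):
>>> import jax.numpy as jnp
>>> from flax.linen import activation
>>> x = jnp.array([-2.0, 0.0, 2.0])
>>> bool(jnp.allclose(activation.log_sigmoid(x), jnp.log(activation.sigmoid(x))))
True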
- flax.linen.activation.log_softmax(x, axis=-1, where=None, initial=_UNSPECIFIED)[source]#
Log-Softmax function.
Computes the logarithm of the softmax function, which rescales elements to the range \([-\infty, 0)\).
\[\mathrm{log\_softmax}(x)_i = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)} \right)\]
- Parameters
x – input array
axis – the axis or axes along which the log_softmax should be computed. Either an integer or a tuple of integers.
where – Elements to include in the log_softmax.
- Returns
An array.
Note
If any input values are +inf, the result will be all NaN: this reflects the fact that inf / inf is not well-defined in the context of floating-point math.
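An illustrative consistency check against softmax() on example values (log_softmax evaluates this more stably than taking the logarithm explicitly):
>>> import jax.numpy as jnp
>>> from flax.linen import activation
>>> x = jnp.array([1.0, 2.0, 3.0])
>>> bool(jnp.allclose(activation.log_softmax(x), jnp.log(activation.softmax(x))))
True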
- flax.linen.activation.logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False, where=None)[source]#
Log-sum-exp reduction.
JAX implementation of scipy.special.logsumexp().
\[\mathrm{logsumexp}(a) = \mathrm{log} \sum_j b \cdot \mathrm{exp}(a_{ij})\]
where the \(j\) indices range over one or more dimensions to be reduced.
- Parameters
a – the input array
axis – the axis or axes over which to reduce. May be either None, an int, or a tuple of ints.
b – scaling factors for \(\mathrm{exp}(a)\). Must be broadcastable to the shape of a.
keepdims – If True, the axes that are reduced are left in the output as dimensions of size 1.
return_sign – If True, the output will be a (result, sign) pair, where sign is the sign of the sums and result contains the logarithms of their absolute values. If False, only result is returned and it will contain NaN values if the sums are negative.
where – Elements to include in the reduction.
- Returns
Either an array result or a pair of arrays (result, sign), depending on the value of the return_sign argument.
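An illustrative sketch on a small example array, compared against the naive formulation (safe here because the entries are small):
>>> import jax.numpy as jnp
>>> from flax.linen import activation
>>> a = jnp.array([[1.0, 2.0], [3.0, 4.0]])
>>> out = activation.logsumexp(a, axis=-1)
>>> out.shape
(2,)
>>> bool(jnp.allclose(out, jnp.log(jnp.sum(jnp.exp(a), axis=-1))))
True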
- flax.linen.activation.one_hot(x, num_classes, *, dtype=<class 'jax.numpy.float64'>, axis=-1)[source]#
One-hot encodes the given indices.
Each index in the input x is encoded as a vector of zeros of length num_classes with the element at index set to one:
>>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
Array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)
Indices outside the range [0, num_classes) will be encoded as zeros:
>>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)
- Parameters
x – A tensor of indices.
num_classes – Number of classes in the one-hot dimension.
dtype – optional, a float dtype for the returned values (default jnp.float_).
axis – the axis or axes along which the function should be computed.
- flax.linen.activation.relu(x)[source]#
Rectified linear unit activation function.
Computes the element-wise function:
\[\mathrm{relu}(x) = \max(x, 0)\]
except under differentiation, we take:
\[\nabla \mathrm{relu}(0) = 0\]
For more information see Numerical influence of ReLU’(0) on backpropagation.
- Parameters
x – input array
- Returns
An array.
Examples
>>> jax.nn.relu(jax.numpy.array([-2., -1., -0.5, 0, 0.5, 1., 2.]))
Array([0. , 0. , 0. , 0. , 0.5, 1. , 2. ], dtype=float32)
See also
relu6()
- flax.linen.activation.selu(x)[source]#
Scaled exponential linear unit activation.
Computes the element-wise function:
\[\begin{split}\mathrm{selu}(x) = \lambda \begin{cases} x, & x > 0\\ \alpha e^x - \alpha, & x \le 0 \end{cases}\end{split}\]
where \(\lambda = 1.0507009873554804934193349852946\) and \(\alpha = 1.6732632423543772848170429916717\).
For more information, see Self-Normalizing Neural Networks.
- Parameters
x – input array
- Returns
An array.
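Given the formulas above, selu(x) equals \(\lambda \cdot \mathrm{elu}(x, \alpha)\) with the stated constants; an illustrative check on example values:
>>> import jax.numpy as jnp
>>> from flax.linen import activation
>>> x = jnp.array([-1.0, 0.0, 1.0])
>>> scale, alpha = 1.0507009873554804934193349852946, 1.6732632423543772848170429916717
>>> bool(jnp.allclose(activation.selu(x), scale * activation.elu(x, alpha=alpha)))
True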
- flax.linen.activation.sigmoid(x)[source]#
Sigmoid activation function.
Computes the element-wise function:
\[\mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}}\]
- Parameters
x – input array
- Returns
An array.
- flax.linen.activation.silu(x)[source]#
SiLU (aka swish) activation function.
Computes the element-wise function:
\[\mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}}\]
swish() and silu() are both aliases for the same function.
- Parameters
x – input array
- Returns
An array.
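An illustrative check of the definition and of the swish() alias, with example values chosen here:
>>> import jax.numpy as jnp
>>> from flax.linen import activation
>>> x = jnp.array([-3.0, 0.0, 3.0])
>>> bool(jnp.allclose(activation.silu(x), x * activation.sigmoid(x)))
True
>>> bool(jnp.allclose(activation.silu(x), activation.swish(x)))
True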
- flax.linen.activation.soft_sign(x)[source]#
Soft-sign activation function.
Computes the element-wise function
\[\mathrm{soft\_sign}(x) = \frac{x}{|x| + 1}\]
- Parameters
x – input array
- flax.linen.activation.softmax(x, axis=-1, where=None, initial=_UNSPECIFIED)[source]#
Softmax function.
Computes the function which rescales elements to the range \([0, 1]\) such that the elements along axis sum to \(1\).
\[\mathrm{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}\]
- Parameters
x – input array
axis – the axis or axes along which the softmax should be computed. The softmax output summed across these dimensions should sum to \(1\). Either an integer or a tuple of integers.
where – Elements to include in the softmax.
- Returns
An array.
Note
If any input values are +inf, the result will be all NaN: this reflects the fact that inf / inf is not well-defined in the context of floating-point math.
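An illustrative sketch on example values, checking that the entries along the reduction axis sum to one:
>>> import jax.numpy as jnp
>>> from flax.linen import activation
>>> x = jnp.array([[1.0, 2.0, 3.0], [1.0, 1.0, 1.0]])
>>> p = activation.softmax(x, axis=-1)
>>> bool(jnp.allclose(p.sum(axis=-1), 1.0))
True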
- flax.linen.activation.softplus(x)[source]#
Softplus activation function.
Computes the element-wise function
\[\mathrm{softplus}(x) = \log(1 + e^x)\]
- Parameters
x – input array
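An illustrative check against the naive formulation on small example values, where the direct form \(\log(1 + e^x)\) is safe to evaluate:
>>> import jax.numpy as jnp
>>> from flax.linen import activation
>>> x = jnp.array([-1.0, 0.0, 1.0])
>>> bool(jnp.allclose(activation.softplus(x), jnp.log1p(jnp.exp(x))))
True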
- flax.linen.activation.standardize(x, axis=-1, mean=None, variance=None, epsilon=1e-05, where=None)[source]#
Normalizes an array by subtracting mean and dividing by \(\sqrt{\mathrm{variance}}\).
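An illustrative sketch on example values (the epsilon in the signature regularizes the division, so the output is standardized only approximately):
>>> import jax.numpy as jnp
>>> from flax.linen import activation
>>> x = jnp.array([[1.0, 2.0, 3.0, 4.0]])
>>> y = activation.standardize(x, axis=-1)
>>> bool(jnp.allclose(y.mean(axis=-1), 0.0, atol=1e-6))
True
>>> bool(jnp.allclose(y.std(axis=-1), 1.0, atol=1e-3))
True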
- flax.linen.activation.swish(x)#
SiLU (aka swish) activation function.
Computes the element-wise function:
\[\mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}}\]
swish() and silu() are both aliases for the same function.
- Parameters
x – input array
- Returns
An array.
- flax.linen.activation.tanh(x, /)[source]#
Calculate element-wise hyperbolic tangent of input.
JAX implementation of numpy.tanh.
The hyperbolic tangent is defined by:
\[\mathrm{tanh}(x) = \frac{\mathrm{sinh}(x)}{\mathrm{cosh}(x)} = \frac{e^x - e^{-x}}{e^x + e^{-x}}\]
- Parameters
x – input array or scalar.
- Returns
An array containing the hyperbolic tangent of each element of x, promoting to inexact dtype.
Note
jnp.tanh is equivalent to computing -1j * jnp.tan(1j * x).
See also
jax.numpy.sinh(): Computes the element-wise hyperbolic sine of the input.
jax.numpy.cosh(): Computes the element-wise hyperbolic cosine of the input.
jax.numpy.arctanh(): Computes the element-wise inverse of hyperbolic tangent of the input.
Examples
>>> x = jnp.array([[-1, 0, 1],
...                [3, -2, 5]])
>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.tanh(x)
Array([[-0.762, 0. , 0.762],
       [ 0.995, -0.964, 1. ]], dtype=float32)
>>> with jnp.printoptions(precision=3, suppress=True):
...   -1j * jnp.tan(1j * x)
Array([[-0.762+0.j, 0. -0.j, 0.762-0.j],
       [ 0.995-0.j, -0.964+0.j, 1. -0.j]], dtype=complex64, weak_type=True)
For complex-valued input:
>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.tanh(2-5j)
Array(1.031+0.021j, dtype=complex64, weak_type=True)
>>> with jnp.printoptions(precision=3, suppress=True):
...   -1j * jnp.tan(1j * (2-5j))
Array(1.031+0.021j, dtype=complex64, weak_type=True)