User Guide on Using FP8
JAX supports various FP8 formats, including E4M3 (jnp.float8_e4m3fn) and E5M2 (jnp.float8_e5m2). Due to the limited range of FP8 data types, higher-precision data must be scaled to fit within the FP8 representable range, a process known as quantization (Q). Conversely, de-quantization (DQ) rescales the FP8 data back to its original type.
While jnp.dot supports FP8 inputs directly, proper quantization and dequantization is needed for optimal performance. Flax provides nn.fp8_ops.Fp8DotGeneral and nn.fp8_ops.Fp8Einsum modules that handle this automatically and can be used with existing layers like nn.Dense.
This tutorial walks you through the basics of using them.
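As a quick orientation on why the two formats have such different ranges, the sketch below derives their largest finite values from the bit layout in plain Python. The helper `fp8_max` is a hypothetical name introduced only for this illustration and is not part of JAX or Flax; the layout facts it assumes (E4M3FN: bias 7, only the all-ones exponent with all-ones mantissa is NaN; E5M2: bias 15, the all-ones exponent is reserved for inf/NaN) are the standard definitions of these formats.

```python
def fp8_max(exp_bits, man_bits, bias, top_exp_is_special):
    """Largest finite value of a small float format (illustrative helper)."""
    max_exp_field = (1 << exp_bits) - 1
    max_man_field = (1 << man_bits) - 1
    if top_exp_is_special:
        # IEEE-style (E5M2): the all-ones exponent encodes inf/NaN.
        max_exp_field -= 1
    else:
        # "fn" style (E4M3FN): only exponent and mantissa both all-ones is NaN.
        max_man_field -= 1
    mantissa = 1 + max_man_field / (1 << man_bits)
    return mantissa * 2.0 ** (max_exp_field - bias)


e4m3_max = fp8_max(4, 3, bias=7, top_exp_is_special=False)
e5m2_max = fp8_max(5, 2, bias=15, top_exp_is_special=True)
print(e4m3_max, e5m2_max)  # 448.0 57344.0
```

E4M3 spends one more bit on the mantissa and therefore resolves values more finely, while E5M2 covers a much wider range.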
Setting up our environment
Here, we provide the code necessary to set up the environment for our notebook. Additionally, we define a function to check if the XLA-optimized HLO will indeed call an FP8 dot operation under the hood.
Note: This tutorial relies on the XLA-FP8 feature, which is only supported on NVIDIA Hopper GPUs or later.
import flax
import jax
import re
import pprint
from jax import random
from jax import numpy as jnp
from jax._src import test_util as jtu
from flax import linen as nn
from flax.linen import fp8_ops
e4m3 = jnp.float8_e4m3fn
f32 = jnp.float32
E4M3_MAX = jnp.finfo(e4m3).max.astype(f32)
assert jtu.is_cuda_compute_capability_at_least("9.0")
def check_fp8_call(lowered):
  hlo = lowered.compile()
  if re.search(r"custom-call\(f8e4m3fn.*, f8e4m3fn.*", hlo.as_text()):
    print("Fp8 call detected!")
  else:
    print("No Fp8 call!")
FLAX Low Level API
The JAX dot operations (e.g. jnp.dot) accept FP8 dtype inputs, so the
following call is legal:
k0, k1 = random.split(random.key(0), 2)
a = random.uniform(k0, (16, 32))
b = random.uniform(k1, (32, 64))
@jax.jit
def dot_fp8(a, b):
  return jnp.dot(a.astype(e4m3), b.astype(e4m3), preferred_element_type=f32)
check_fp8_call(dot_fp8.lower(a, b))
However, this approach has two key limitations:
jnp.dot does not support custom scaling factors for operands, defaulting to a scale of 1.0.
Autodiff does not automatically use E5M2 for gradients and E4M3 for activations/weights during training, which is the recommended practice.
To overcome these limitations and implement proper FP8 matrix multiplication, we recommend using the Flax FP8 APIs. Let’s start with a basic scaling approach.
Current Scaling
Scaling factors are usually defined as scale = amax(x) / MAX, where amax is
an operation to find the absolute maximum value of the tensor, and MAX is the
maximum value of the representable range of the target dtype. This scaling
approach allows us to derive the scaling factors directly from the current
operand tensors of the dot product.
@jax.jit
def dot_fp8(a, b):
  # Derive each scale from the current operand: scale = amax(x) / MAX.
  a_scale = jnp.max(jnp.abs(a)) / E4M3_MAX
  b_scale = jnp.max(jnp.abs(b)) / E4M3_MAX
  a = fp8_ops.quantize(a, e4m3, a_scale, f32)
  b = fp8_ops.quantize(b, e4m3, b_scale, f32)
  c = jnp.dot(a, b, preferred_element_type=f32)
  c = fp8_ops.dequantize(c, f32, a_scale * b_scale)
  return c
c = dot_fp8(a, b)
check_fp8_call(dot_fp8.lower(a, b))
As shown in the code, we perform quantization (fp8_ops.quantize) on the
tensors to get the lower precision operands. The jnp.dot processes them and
accumulates the output in high precision (i.e., the preferred_element_type).
After that, we multiply the result by the scaling factors to dequantize back to
the original range (fp8_ops.dequantize). Note that while this example uses
E4M3 for both inputs, it is possible to use different FP8 dtypes like E4M3 and
E5M2 for the inputs. The quantization method and the scaling factors can also be
customized based on application needs.
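To make the effect of the scaling factor concrete without a GPU, here is a minimal plain-Python sketch of current scaling on a short list of numbers. It is an illustration under stated assumptions, not the Flax implementation: `round_to_e4m3`, `quantize_list`, and `dequantize_list` are hypothetical helpers, and the rounding only approximates E4M3 behavior (saturation at 448, three mantissa bits, smallest normal exponent of -6).

```python
import math

E4M3_MAX = 448.0  # largest finite float8_e4m3fn value


def round_to_e4m3(x):
    """Hypothetical helper: emulate saturating round-to-nearest E4M3."""
    if x == 0.0:
        return 0.0
    ax = min(abs(x), E4M3_MAX)   # saturate instead of overflowing
    _, e = math.frexp(ax)        # ax = m * 2**e with 0.5 <= m < 1
    exp = max(e - 1, -6)         # -6 is the smallest normal exponent
    step = 2.0 ** (exp - 3)      # value spacing with 3 mantissa bits
    return math.copysign(round(ax / step) * step, x)


def quantize_list(values, scale):
    """Divide by the scale, then round into the emulated E4M3 grid."""
    return [round_to_e4m3(v / scale) for v in values]


def dequantize_list(values, scale):
    """Multiply by the scale to return to the original range."""
    return [v * scale for v in values]


x = [0.5, -3.0, 100.0, 900.0]
scale = max(abs(v) for v in x) / E4M3_MAX  # scale = amax(x) / MAX
q = quantize_list(x, scale)
x_hat = dequantize_list(q, scale)
rel_err = [abs(a - b) / abs(b) for a, b in zip(x_hat, x)]
print(q)
print(rel_err)
```

The largest element lands exactly on the format maximum, and every other element keeps a relative error of at most 1/16, which is what the scaling factor is meant to achieve.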
One major issue with the current scaling method is the performance overhead
introduced by computing a_scale and b_scale, which requires additional
loading of the operand tensors. To overcome this issue, we recommend delayed
scaling.
Delayed Scaling
In delayed scaling, we use a scaling factor associated with an amax history. The scaling factor remains a scalar, but the amax history is a list that stores amax values from recent steps (e.g., 1024 steps). Both tensors are computed from previous steps and maintained in the model parameters.
The quantization and dequantization operations for delayed scaling are provided
by fp8_ops.in_q and fp8_ops.out_dq respectively. fp8_ops.in_q handles
input quantization and updates the amax history and scaling factor, while
fp8_ops.out_dq performs output dequantization.
a_scale = jnp.array(1.0)
b_scale = jnp.array(1.0)
a_amax_hist = jnp.zeros((1024,))
b_amax_hist = jnp.zeros((1024,))
@jax.jit
def dot_fp8(a, a_scale, a_amax_hist, b, b_scale, b_amax_hist):
  a, a_scale = fp8_ops.in_q(f32, e4m3, a, a_scale, a_amax_hist)
  b, b_scale = fp8_ops.in_q(f32, e4m3, b, b_scale, b_amax_hist)
  c = jnp.dot(a, b, preferred_element_type=f32)
  c = fp8_ops.out_dq(f32, a_scale, b_scale, c)
  return c
c = dot_fp8(a, a_scale, a_amax_hist, b, b_scale, b_amax_hist)
check_fp8_call(dot_fp8.lower(a, a_scale, a_amax_hist, b, b_scale, b_amax_hist))
In this example, we first prepare two pairs of scaling factors and amax
histories, treating them as results computed from previous steps. Then, we apply
fp8_ops.in_q to the input operands of jnp.dot, followed by fp8_ops.out_dq
to the output of jnp.dot.
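The bookkeeping behind delayed scaling can be sketched in plain Python. This is a simplified illustration, not the Flax implementation: `delayed_scaling_step` is a hypothetical helper, the history is treated as a simple first-in-first-out window, and the real fp8_ops.in_q may differ in details such as how it rotates the history or guards against non-finite values.

```python
from collections import deque

E4M3_MAX = 448.0  # largest finite float8_e4m3fn value


def delayed_scaling_step(x, scale, amax_history):
    """Hypothetical helper: one delayed-scaling update for a single operand.

    Returns the scale to use for quantizing `x` now and the updated history.
    """
    hist_amax = max(amax_history)
    # The scale comes from *past* amax values; keep the old scale while the
    # history is still empty (all zeros).
    new_scale = hist_amax / E4M3_MAX if hist_amax > 0.0 else scale
    # Record the current amax for future steps, dropping the oldest entry.
    new_history = deque(amax_history, maxlen=len(amax_history))
    new_history.appendleft(max(abs(v) for v in x))
    return new_scale, list(new_history)


scale, history = 1.0, [0.0] * 4
for step, x in enumerate([[1.0, -896.0], [2.0, -4.0], [0.5, 0.25]]):
    scale, history = delayed_scaling_step(x, scale, history)
    print(step, scale, history)
```

Note that the scale used at a given step never depends on the tensor being quantized at that step, which is what removes the extra pass over the operand that current scaling requires.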
FLAX High Level API
Flax provides high-level operations to seamlessly integrate FP8 quantization into existing layers. Instead of manually handling the bookkeeping of delayed scaling (e.g., the maintenance of the amax history and scaling factors), users can simply use these drop-in replacements:
fp8_ops.Fp8DotGeneral for lax.dot_general operations
fp8_ops.Fp8Einsum for jnp.einsum operations
These operations automatically handle all FP8-related functionality, including quantization/dequantization, scale factor updates, and FP8 dtype selection for both forward and backward passes.
Consider the following example:
model = nn.Dense(features=64, dot_general_cls=fp8_ops.Fp8DotGeneral)
params = model.init(k0, a)
@jax.jit
def train_step(var, a):
  c = model.apply(var, a)
  return jnp.sum(c)
check_fp8_call(train_step.lower(params, a))
By setting dot_general_cls=fp8_ops.Fp8DotGeneral, we replace the
default lax.dot_general operation in nn.Dense with an FP8-enabled version.
The model usage remains similar, but now includes additional parameters for FP8
quantization: scaling factors and amax history values. The next section explains
how to update these FP8-specific parameters.
For models that use jnp.einsum operations, such as Mixture of Experts (MoE)
layers, users can replace them with fp8_ops.Fp8Einsum to enable FP8
quantization. Here’s an example:
from typing import Any
class FooModule(nn.Module):
  einsum: Any = None
  @nn.compact
  def __call__(self, a, b):
    # Use the injected einsum class if one was provided, else fall back to jnp.
    if self.einsum is not None:
      einsum_fn = self.einsum()
    else:
      einsum_fn = jnp.einsum
    c = einsum_fn("mk,kn->mn", a, b)
    return c
model = FooModule(einsum=fp8_ops.Fp8Einsum)
params = model.init(k0, a, b)
@jax.jit
def train_step(var, a, b):
  c = model.apply(var, a, b)
  return jnp.sum(c)
check_fp8_call(train_step.lower(params, a, b))
Manipulate FP8 params
The following sections explain the internal FP8 parameters managed by
fp8_ops.Fp8DotGeneral and fp8_ops.Fp8Einsum. These parameters
include scaling factors and amax history values that control the FP8
quantization process. While most users don’t need to interact with these
directly, understanding them can be valuable for advanced optimization and
debugging.
Let’s first examine the data structure of params. In the code below, we redact
the parameter values and then display the PyTree structure.
params_structure = flax.core.unfreeze(params).copy()
params_structure = flax.traverse_util.flatten_dict(params_structure, sep='/')
for key, value in params_structure.items():
  params_structure[key] = '*'
params_structure = flax.traverse_util.unflatten_dict(params_structure, sep='/')
pprint.pprint(params_structure)
The output is as follows:
{'_overwrite_with_gradient': {'Fp8Einsum_0': {'input_amax_history': '*',
                                              'input_scale': '*',
                                              'kernel_amax_history': '*',
                                              'kernel_scale': '*',
                                              'output_grad_amax_history': '*',
                                              'output_grad_scale': '*'}}}
In addition to the expected params, there is an additional category called
_overwrite_with_gradient. This category includes three pairs of amax_history
and scale for the activation, kernel, and dot gradient, respectively.
Update gradient of FP8 params
Now, we perform one training step to obtain the gradients and see how to use them to update the parameters.
step_fn = jax.jit(jax.grad(train_step, (0, 1)))
grads = step_fn(params, a, b)
params = flax.core.unfreeze(params)
params = flax.traverse_util.flatten_dict(params, sep='/')
grads = flax.traverse_util.flatten_dict(grads[0], sep='/')
for key, value in params.items():
  if key.startswith('params'):
    params[key] = value + 0.01 * grads[key]
  if key.startswith('_overwrite_with_gradient'):
    params[key] = grads[key]
params = flax.traverse_util.unflatten_dict(params, sep='/')
params = flax.core.freeze(params)
The above code demonstrates how to update both params and
_overwrite_with_gradient. For params, we use the formula new_param = old_param + 0.01 * grads, where 0.01 is the learning rate (any optimizer from optax
can be used instead). For _overwrite_with_gradient, we simply use
the gradient to overwrite the old values.
Note that flax.training.train_state.TrainState conveniently supports the
category of _overwrite_with_gradient, so users do not need to modify their
scripts unless they use a custom TrainState.
Accumulate gradient of FP8 params
When the same parameter is used in a branched manner, the autograd mechanism
will add their gradients from these branches. This is common in scenarios like
pipeline parallelism, where each microbatch shares the same set of parameters
for the minibatch. However, for the _overwrite_with_gradient parameters, this
accumulation by addition is not meaningful. Instead, we prefer custom
accumulation by taking the maximum value.
To address this, we introduce a custom dtype fp8_ops.fp32_max_grad. The basic
usage is demonstrated below:
fmax32 = fp8_ops.fp32_max_grad
def reuse_fp8_param(x, y, scale, amax_history):
  scale = scale.astype(fmax32)
  amax_history = amax_history.astype(fmax32)
  x = fp8_ops.in_qdq(f32, e4m3, x, scale, amax_history)
  y = fp8_ops.in_qdq(f32, e4m3, y, scale, amax_history)
  return x + y
reuse_fp8_param_fn = jax.grad(reuse_fp8_param, (0, 1, 2, 3))
reuse_fp8_param_fn = jax.jit(reuse_fp8_param_fn)
# Gradients come back in argument order: x, y, scale, amax_history.
_, _, new_sf, new_ah = reuse_fp8_param_fn(2.0, 3.0, a_scale, a_amax_hist)
print(new_sf, new_ah)
In this example, we first cast the scale and amax_history to
fp8_ops.fp32_max_grad and then call fp8_ops.in_qdq twice using the same pair
of scale and amax_history. During autograd, their gradients from each branch
will be taken as the maximum, giving us the correct results of:
1.0 [3. 0. 0. ... 0. 0. 0.]
If we do not perform the type casting, we get the following result, meaning the gradients of the two branches are added:
2.0 [5. 0. 0. ... 0. 0. 0.]
This casting is already included if users choose to use the high-level APIs.
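The difference between the two accumulation rules can be reproduced in plain Python. The sketch below is only an illustration of the arithmetic, not of how fp8_ops.fp32_max_grad is implemented: `accumulate` is a hypothetical helper, and the per-branch values are chosen to mirror the example above (inputs 2.0 and 3.0 with a scale of 1.0).

```python
def accumulate(branch_grads, combine):
    """Combine per-branch 'gradients' of the FP8 state element-wise."""
    scales = [g["scale"] for g in branch_grads]
    histories = [g["amax_history"] for g in branch_grads]
    return {
        "scale": combine(scales),
        "amax_history": [combine(col) for col in zip(*histories)],
    }


# Each branch reports the scale it used and the amax it observed.
branches = [
    {"scale": 1.0, "amax_history": [2.0, 0.0, 0.0]},
    {"scale": 1.0, "amax_history": [3.0, 0.0, 0.0]},
]
by_sum = accumulate(branches, sum)  # default autograd behavior
by_max = accumulate(branches, max)  # behavior wanted for FP8 state
print(by_sum)
print(by_max)
```

Summation yields a scale of 2.0 and an amax of 5.0, neither of which was ever observed, whereas taking the maximum keeps the scale at 1.0 and records the true largest amax of 3.0.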
Deprecated APIs
Previously, we provided APIs like fp8_ops.quantize_dequantize for current
scaling and fp8_ops.[in|out]_qdq for delayed scaling. These were used with
high precision dot operations, leveraging an XLA-FP8 feature that
pattern-matched QDQ->dot sequences to Q->fp8_cublas_gemm. The corresponding
high-level API was called fp8_ops.Fp8DotGeneralOp. However, this pattern
matching-based solution proved brittle, as the patterns could be easily broken
by other XLA optimizations. We recommend users migrate from these deprecated
APIs to the newer ones described above.
For migration, users should replace:
fp8_ops.quantize_dequantize -> jnp.dot with fp8_ops.quantize -> jnp.dot -> fp8_ops.dequantize
fp8_ops.in_qdq -> jnp.dot -> fp8_ops.out_qdq with fp8_ops.in_q -> jnp.dot -> fp8_ops.out_dq
fp8_ops.Fp8DotGeneralOp with fp8_ops.Fp8DotGeneral
Additionally, we provide an einsum variant through fp8_ops.Fp8Einsum.