User Guide on Using FP8#
JAX supports various FP8 formats, including E4M3 (jnp.float8_e4m3fn) and E5M2 (jnp.float8_e5m2). Due to the limited range of FP8 data types, higher-precision data must be scaled to fit within the FP8 representable range, a process known as quantization (Q). Conversely, de-quantization (DQ) rescales the FP8 data back to its original type.
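As a rough illustration of how narrow these ranges are, the largest finite value of each format can be worked out from its bit layout. The sketch below is plain Python and assumes the standard layouts (exponent bias 7 for E4M3, bias 15 for E5M2); in JAX the same numbers are available through jnp.finfo.

```python
# E4M3 ("fn" variant): 4 exponent bits, 3 mantissa bits, bias 7.
# The all-ones exponent still encodes finite numbers; only mantissa 111 is NaN,
# so the largest finite value uses exponent field 15 and mantissa 110.
E4M3_MAX = 2.0 ** (15 - 7) * (1 + 6 / 8)

# E5M2: 5 exponent bits, 2 mantissa bits, bias 15.
# The all-ones exponent is reserved for inf/NaN, so the largest finite value
# uses exponent field 30 and mantissa 11.
E5M2_MAX = 2.0 ** (30 - 15) * (1 + 3 / 4)

print(E4M3_MAX, E5M2_MAX)  # 448.0 57344.0
```

Any tensor whose magnitudes exceed these bounds has to be scaled down before it can be stored in the corresponding format.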
Although jnp.dot supports FP8 inputs, certain limitations make it impractical
for real-world applications. Alternatively, XLA, our compiler, can recognize
patterns such as quantize -> dequantize -> dot and rewrite them as FP8 dot
operations (XLA-FP8). This tutorial will walk you through the basics of how to
use it.
Setting up our environment#
Here, we provide the code necessary to set up the environment for our notebook. Additionally, we define a function to check if the XLA-optimized HLO will indeed call an FP8 dot operation under the hood.
Note: This tutorial relies on the XLA-FP8 feature, which is only supported on NVIDIA Hopper GPUs or later.
import flax
import jax
import re
import pprint
from jax import random
from jax import numpy as jnp
from jax._src import test_util as jtu
from flax import linen as nn
from flax.linen import fp8_ops
e4m3 = jnp.float8_e4m3fn
e5m2 = jnp.float8_e5m2
f32 = jnp.float32
E4M3_MAX = jnp.finfo(e4m3).max.astype(f32)
assert jtu.is_cuda_compute_capability_at_least("9.0")
def check_fp8_call(lowered):
  hlo = lowered.compile()
  if re.search(r"custom-call\(f8e4m3fn.*, f8e4m3fn.*", hlo.as_text()):
    print("Fp8 call detected!")
  else:
    print("No Fp8 call!")
FLAX Low Level API#
The JAX dot operations (e.g. jnp.dot) accept FP8 inputs, so the following call
is valid:
key = random.key(0)
A = random.uniform(key, (16, 32))
B = random.uniform(key, (32, 64))
@jax.jit
def dot_fp8(A, B):
  return jnp.dot(A.astype(e4m3), B.astype(e4m3), preferred_element_type=f32)
check_fp8_call(dot_fp8.lower(A, B))
However, there are two main issues with this approach. Firstly, jnp.dot does
not accept scaling factors for the operands, defaulting to a scaling factor of
1.0. Secondly, it does not support operands of mixed FP8 data types. For
example, when the operands are E5M2 and E4M3, the dot product is performed using
the promoted FP16 data type.
In real-world scenarios, it is essential to specify scaling factors, either from calibration for inference or a user-defined algorithm during training. Additionally, it is common practice to use E5M2 for gradients and E4M3 for activations and kernels. These limitations make this method less practical for real-world applications.
To address these limitations and create a more versatile FP8 dot product, we recommend leveraging XLA-FP8. Let’s begin with a simple scaling strategy.
Current Scaling#
Scaling factors are usually defined as scale = amax(x) / MAX, where amax is
an operation to find the absolute maximum value of the tensor, and MAX is the
maximum value of the representable range of the target dtype. This scaling
approach allows us to derive the scaling factors directly from the current
operand tensors of the dot product.
@jax.jit
def dot_fp8(A, B):
  A_scale = jnp.max(jnp.abs(A)) / E4M3_MAX
  B_scale = jnp.max(jnp.abs(B)) / E4M3_MAX
  A = fp8_ops.quantize_dequantize(A, e4m3, A_scale, f32)
  B = fp8_ops.quantize_dequantize(B, e4m3, B_scale, f32)
  C = jnp.dot(A, B)
  return C
C = dot_fp8(A, B)
check_fp8_call(dot_fp8.lower(A, B))
As shown in the code, we perform fake quantization
(fp8_ops.quantize_dequantize) on the operands of the dot product. Although the
jnp.dot still processes higher-precision inputs, XLA detects this pattern and
rewrites the dot operation as an FP8 dot call (e.g., cublasLt call for GPUs).
This approach effectively mimics the first example but offers greater
flexibility. We can control the input dtypes (both are set to E4M3 here, but we
could use mixed E4M3 and E5M2) and define scaling factors, which XLA can detect
and use in the dot backend.
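To make the effect of the scaling factor concrete, here is a minimal plain-Python sketch of fake quantization. The helpers round_to_e4m3 and fake_quantize are hypothetical names used only for illustration and are not part of fp8_ops. The rounding model is simplified (it ignores subnormals and NaN), so treat the numbers as indicative.

```python
import math

E4M3_MAX = 448.0  # largest finite float8_e4m3fn value


def round_to_e4m3(v):
    """Clip v to the E4M3 range and round it to a 4-bit significand.

    Simplified model (assumption): subnormals and NaN are not handled.
    """
    if v == 0.0:
        return 0.0
    v = max(-E4M3_MAX, min(E4M3_MAX, v))
    mantissa, exponent = math.frexp(v)  # v == mantissa * 2**exponent
    return math.ldexp(round(mantissa * 16) / 16, exponent)


def fake_quantize(xs, scale):
    """Quantize (divide by scale, round), then dequantize (multiply back)."""
    return [round_to_e4m3(x / scale) * scale for x in xs]


xs = [0.5, 300.0, 1000.0]

# Scale fixed at 1.0, as with a plain FP8 cast: 1000.0 saturates at 448.0.
unscaled = fake_quantize(xs, 1.0)

# Current scaling: map the largest magnitude onto E4M3_MAX.
scale = max(abs(x) for x in xs) / E4M3_MAX
scaled = fake_quantize(xs, scale)

print(unscaled)
print(scaled)
```

With the scale fixed at 1.0 the largest entry is clipped to 448.0, whereas with the computed scale it survives the round trip at roughly its original value.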
One major issue with the current scaling method is the overhead introduced by
computing A_scale and B_scale, which requires an additional pass over the
operand tensors. To avoid this overhead, we recommend delayed scaling.
Delayed Scaling#
In delayed scaling, we use a scaling factor associated with an amax history. The scaling factor remains a scalar, but the amax history is a list that stores amax values from recent steps (e.g., 1024 steps). Both tensors are computed from previous steps and maintained in the model parameters.
Fake quantization for delayed scaling is provided by fp8_ops.in_qdq for the
activations and weights, and fp8_ops.out_qdq for the gradients.
a_scale = jnp.array(1.0)
b_scale = jnp.array(1.0)
g_scale = jnp.array(1.0)
a_amax_hist = jnp.zeros((1024,))
b_amax_hist = jnp.zeros((1024,))
g_amax_hist = jnp.zeros((1024,))
@jax.jit
def dot_fp8(a, a_scale, a_amax_hist, b, b_scale, b_amax_hist,
            g_scale, g_amax_hist):
  a = fp8_ops.in_qdq(f32, e4m3, a, a_scale, a_amax_hist)
  b = fp8_ops.in_qdq(f32, e4m3, b, b_scale, b_amax_hist)
  c = jnp.dot(a, b)
  c = fp8_ops.out_qdq(f32, e5m2, c, g_scale, g_amax_hist)
  return c

C = dot_fp8(A, a_scale, a_amax_hist, B, b_scale, b_amax_hist,
            g_scale, g_amax_hist)
check_fp8_call(dot_fp8.lower(A, a_scale, a_amax_hist, B, b_scale, b_amax_hist,
                             g_scale, g_amax_hist))
In this example, we first prepare three pairs of scaling factors and amax
histories, treating them as results computed from previous steps. Then, we apply
fp8_ops.in_qdq to the input operands of jnp.dot, followed by
fp8_ops.out_qdq to the output of jnp.dot. Note that fp8_ops.out_qdq applies
fake quantization to the gradient of the output via custom_vjp functions.
The new scaling factors and amax histories will be returned through their
gradients, which will be covered in the next section.
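The bookkeeping behind delayed scaling can be sketched in plain Python. This is a simplified illustration under stated assumptions, not the fp8_ops implementation: delayed_scaling_step is a hypothetical helper, the next scale is taken as max(history) / E4M3_MAX, the old scale is kept while the history is still all zeros, and the newest amax is written to slot 0 of a rotated history.

```python
E4M3_MAX = 448.0  # largest finite float8_e4m3fn value


def delayed_scaling_step(values, scale, amax_history):
    """Return (scale used now, scale for next step, history for next step)."""
    # The scale applied at this step comes from earlier steps.
    used_scale = scale

    # Assumption: derive the next scale from the largest recorded amax,
    # keeping the old scale if nothing has been recorded yet.
    hist_amax = max(amax_history)
    next_scale = hist_amax / E4M3_MAX if hist_amax > 0 else scale

    # Assumption: rotate the history by one slot, then store the new amax at index 0.
    current_amax = max(abs(v) for v in values)
    rotated = amax_history[1:] + amax_history[:1]
    next_history = [current_amax] + rotated[1:]
    return used_scale, next_scale, next_history


scale, history = 1.0, [0.0] * 4  # short history for readability
used1, scale, history = delayed_scaling_step([2.0, -3.0], scale, history)
used2, scale, history = delayed_scaling_step([1.0, 0.5], scale, history)
print(used1, used2, scale, history)
```

In this sketch the amax observed at a given step never affects the scale applied at that same step, which is what makes the scaling delayed.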
FLAX High Level API#
With the FLAX library, incorporating FP8 operations into existing FLAX layers
is a seamless process. Users don’t need to manipulate the low-level APIs for
quantization. Instead, they can integrate the provided custom FP8 dot
(fp8_ops.Fp8DotGeneralOp) into FLAX layers using a straightforward
“code-injection” approach. This custom operation encapsulates all FP8-related
tasks, including the placement of quantization-dequantization ops, algorithms
for updating scaling factors, and the selection of FP8 dtype combinations for
forward and backward propagation.
Consider the following example:
model = nn.Dense(features=64, dot_general_cls=fp8_ops.Fp8DotGeneralOp)
params = model.init(key, A)
@jax.jit
def train_step(var, a):
  c = model.apply(var, a)
  return jnp.sum(c)
check_fp8_call(train_step.lower(params, A))
In this example, we simply set dot_general_cls=fp8_ops.Fp8DotGeneralOp to
enable the Dense layer to utilize the FP8 dot operation. The usage of the model
remains almost the same as before. The main difference is the addition of a new
category of parameters: the sets of scaling factors and amax history. In the
next section, we will explore how to update these parameters.
Manipulate FP8 params#
Let’s first examine the data structure of params. In the code below, we redact
the parameter values and then display the PyTree structure.
params_structure = flax.core.unfreeze(params).copy()
params_structure = flax.traverse_util.flatten_dict(params_structure, sep='/')
for key, value in params_structure.items():
  params_structure[key] = '*'
params_structure = flax.traverse_util.unflatten_dict(params_structure, sep='/')
pprint.pprint(params_structure)
The output is as follows:
{'_overwrite_with_gradient': {'Fp8DotGeneralOp_0': {'input_amax_history': '*',
                                                    'input_scale': '*',
                                                    'kernel_amax_history': '*',
                                                    'kernel_scale': '*',
                                                    'output_grad_amax_history': '*',
                                                    'output_grad_scale': '*'}},
 'params': {'bias': '*', 'kernel': '*'}}
In addition to the expected params, there is an additional category called
_overwrite_with_gradient. This category includes three pairs of amax_history
and scale for the activation, kernel, and dot gradient, respectively.
Update gradient of FP8 params#
Now, we perform one training step to obtain the gradients and see how to use them to update the parameters.
step_fn = jax.jit(jax.grad(train_step, (0, 1)))
grads = step_fn(params, A)
params = flax.core.unfreeze(params)
params = flax.traverse_util.flatten_dict(params, sep='/')
grads = flax.traverse_util.flatten_dict(grads[0], sep='/')
for key, value in params.items():
  if key.startswith('params'):
    params[key] = value + 0.01 * grads[key]
  if key.startswith('_overwrite_with_gradient'):
    params[key] = grads[key]
params = flax.traverse_util.unflatten_dict(params, sep='/')
params = flax.core.freeze(params)
The above code demonstrates how to update both params and
_overwrite_with_gradient. For params, we use the formula
new_param = old_param + 0.01 * grads, where 0.01 is the learning rate (users
can instead apply any optimizer from optax). For _overwrite_with_gradient, we
simply overwrite the old values with the gradients.
Note that flax.training.train_state.TrainState already supports the
_overwrite_with_gradient category, so users do not need to modify their
scripts unless they use a custom TrainState.
Accumulate gradient of FP8 params#
When the same parameter is used in several branches, the autograd mechanism
sums the gradients from those branches. This is common in scenarios like
pipeline parallelism, where every microbatch of a minibatch shares the same set
of parameters. However, for the _overwrite_with_gradient parameters,
accumulation by addition is not meaningful. Instead, we accumulate by taking
the maximum value.
To address this, we introduce a custom dtype fp8_ops.fp32_max_grad. The basic
usage is demonstrated below:
fmax32 = fp8_ops.fp32_max_grad
def reuse_fp8_param(x, y, scale, amax_history):
  scale = scale.astype(fmax32)
  amax_history = amax_history.astype(fmax32)
  x = fp8_ops.in_qdq(f32, e4m3, x, scale, amax_history)
  y = fp8_ops.in_qdq(f32, e4m3, y, scale, amax_history)
  return x + y
reuse_fp8_param_fn = jax.grad(reuse_fp8_param, (0, 1, 2, 3))
reuse_fp8_param_fn = jax.jit(reuse_fp8_param_fn)
_, _, new_sf, new_ah = reuse_fp8_param_fn(2.0, 3.0, a_scale, a_amax_hist)
print(new_sf, new_ah)
In this example, we first cast the scale and amax_history to
fp8_ops.fp32_max_grad and then call fp8_ops.in_qdq twice using the same pair
of scale and amax_history. During autograd, their gradients from each branch
will be taken as the maximum, giving us the correct results of:
1.0 [3. 0. 0. ... 0. 0. 0.]
If we do not perform the type casting, we get the following result, meaning the gradients of the two branches are added:
2.0 [5. 0. 0. ... 0. 0. 0.]
This casting is already included if users choose to use the high-level APIs.
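The difference between the two accumulation rules can be reproduced with ordinary arithmetic. The sketch below is plain Python rather than JAX autograd, and the per-branch values are assumptions chosen to be consistent with the outputs shown above: each branch reports a scale of 1.0, and the two inputs have magnitudes 2.0 and 3.0.

```python
# Assumed per-branch contributions for the shared scale and for the newest
# amax-history slot, one entry per call to in_qdq.
branch_scale = [1.0, 1.0]
branch_amax = [2.0, 3.0]  # |x| = 2.0 and |y| = 3.0

# Default autograd behaviour: contributions from the branches are added.
summed = (sum(branch_scale), sum(branch_amax))

# Desired behaviour for these parameters: keep the maximum instead.
maxed = (max(branch_scale), max(branch_amax))

print(summed)  # (2.0, 5.0)
print(maxed)   # (1.0, 3.0)
```

Only the maximum yields a scale and an amax that actually occurred in one of the branches; the sum produces values that no tensor ever had.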