Upgrading my codebase to Optax

We proposed replacing flax.optim with Optax in 2021 with FLIP #1009, and the Flax optimizers were removed in v0.6.0. This guide is for flax.optim users and shows how to update their code to Optax.

See also Optax’s quick start documentation: https://optax.readthedocs.io/en/latest/getting_started.html

Replacing flax.optim with Optax

Optax has drop-in replacements for all of Flax’s optimizers. Refer to the Common Optimizers page of Optax’s documentation for API details.

Usage is very similar, except that Optax does not keep a copy of the parameters, so they need to be passed around separately. Flax provides the utility TrainState (flax.training.train_state.TrainState) to store the optimizer state, parameters, and other associated data in a single dataclass; it is not used in the code below.

# flax.optim (removed in v0.6.0)
def train_step(optimizer, batch):
  grads = jax.grad(loss)(optimizer.target, batch)

  return optimizer.apply_gradient(grads)

optimizer_def = flax.optim.Momentum(
    learning_rate, momentum)
optimizer = optimizer_def.create(variables['params'])

for batch in get_ds_train():
  optimizer = train_step(optimizer, batch)

# Optax
def train_step(params, opt_state, batch):
  grads = jax.grad(loss)(params, batch)
  updates, opt_state = tx.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
  return params, opt_state

tx = optax.sgd(learning_rate, momentum)
params = variables['params']
opt_state = tx.init(params)

for batch in ds_train:
  params, opt_state = train_step(params, opt_state, batch)

Composable Gradient Transformations

The function optax.sgd() used in the code snippet above is simply a wrapper for the sequential application of two gradient transformations. Instead of using this alias, it is common to use optax.chain() to combine several of these generic building blocks.

# Note that the aliases follow the convention to use positive
# values for the learning rate by default.
tx = optax.sgd(learning_rate, momentum)

tx = optax.chain(
    # 1. Step: keep a trace of past updates and add to gradients.
    optax.trace(decay=momentum),
    # 2. Step: multiply result from step 1 with negative learning rate.
    # Note that `optax.apply_updates()` simply adds the final updates to the
    # parameters, so we must make sure to flip the sign here for gradient
    # descent.
    optax.scale(-learning_rate),
)

Weight Decay

Some of Flax’s optimizers also included weight decay. In Optax, some optimizers also have a weight decay parameter (such as optax.adamw()); for the others, weight decay can be added as another gradient transformation, optax.add_decayed_weights(), which adds an update derived from the parameters.

# flax.optim
optimizer_def = flax.optim.Adam(
    learning_rate, weight_decay=weight_decay)
optimizer = optimizer_def.create(variables['params'])

# Optax (note that you could also use `optax.adamw()` in this case)
tx = optax.chain(
    optax.scale_by_adam(),
    optax.add_decayed_weights(weight_decay),
    # params -= learning_rate * (adam(grads) + params * weight_decay)
    optax.scale(-learning_rate),
)
# Note that you'll need to specify `params` when computing the updates:
# tx.update(grads, opt_state, params)

Gradient Clipping

Training can be stabilized by clipping gradients to a global norm (Pascanu et al., 2012). In Flax this was often done by processing the gradients before passing them to the optimizer. With Optax, this becomes just another gradient transformation, optax.clip_by_global_norm().

# flax.optim
def train_step(optimizer, batch):
  grads = jax.grad(loss)(optimizer.target, batch)
  grads_flat, _ = jax.tree_util.tree_flatten(grads)
  global_l2 = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads_flat]))
  g_factor = jnp.minimum(1.0, grad_clip_norm / global_l2)
  grads = jax.tree_util.tree_map(lambda g: g * g_factor, grads)
  return optimizer.apply_gradient(grads)

# Optax
tx = optax.chain(
    optax.clip_by_global_norm(grad_clip_norm),
    optax.trace(decay=momentum),
    optax.scale(-learning_rate),
)

Learning Rate Schedules

For learning rate schedules, Flax allows overwriting hyperparameters when applying the gradients. Optax instead maintains a step counter and passes it to a schedule function whose output scales the updates; this transformation is added with optax.scale_by_schedule(). Optax also allows injecting arbitrary scalar values (constants or schedules) into the hyperparameters of other gradient transformations via optax.inject_hyperparams().

Read more about learning rate schedules in the lr_schedule guide.

Read more about schedules defined in Optax under Optimizer Schedules. The standard optimizers (like optax.adam(), optax.sgd(), etc.) also accept a learning rate schedule as the learning_rate parameter.

# flax.optim
def train_step(step, optimizer, batch):
  grads = jax.grad(loss)(optimizer.target, batch)
  return step + 1, optimizer.apply_gradient(grads, learning_rate=schedule(step))

# Optax
tx = optax.chain(
    optax.trace(decay=momentum),
    # Note that we still want a negative value for scaling the updates!
    optax.scale_by_schedule(lambda step: -schedule(step)),
)

Multiple Optimizers / Updating a Subset of Parameters

In Flax, traversals are used to specify which parameters an optimizer should update, and traversals can be combined with flax.optim.MultiOptimizer to apply different optimizers to different parameters. The Optax equivalents are optax.masked() and optax.chain().

Note that the example below uses flax.traverse_util to create the boolean masks required by optax.masked(). Alternatively, you could create the masks manually, or use optax.multi_transform(), which takes a multivalent pytree (one label per parameter) to select the gradient transformation applied to each parameter.

Beware that optax.masked() flattens the pytree internally, so the inner gradient transformations are only called with that partially flattened view of the params/gradients. This is usually not a problem, but it makes it hard to nest multiple levels of masked gradient transformations, because the inner masks would have to be defined in terms of the partially flattened view, which is not readily available outside the outer mask.

# flax.optim
kernels = flax.traverse_util.ModelParamTraversal(lambda p, _: 'kernel' in p)
biases = flax.traverse_util.ModelParamTraversal(lambda p, _: 'bias' in p)

kernel_opt = flax.optim.Momentum(learning_rate, momentum)
bias_opt = flax.optim.Momentum(learning_rate * 0.1, momentum)

optimizer = flax.optim.MultiOptimizer(
    (kernels, kernel_opt),
    (biases, bias_opt)
).create(variables['params'])

# Optax
kernels = flax.traverse_util.ModelParamTraversal(lambda p, _: 'kernel' in p)
biases = flax.traverse_util.ModelParamTraversal(lambda p, _: 'bias' in p)

all_false = jax.tree_util.tree_map(lambda _: False, params)
kernels_mask = kernels.update(lambda _: True, all_false)
biases_mask = biases.update(lambda _: True, all_false)

tx = optax.chain(
    optax.trace(decay=momentum),
    optax.masked(optax.scale(-learning_rate), kernels_mask),
    optax.masked(optax.scale(-learning_rate * 0.1), biases_mask),
)

Final Words

All of the above patterns can also be combined, and Optax makes it possible to encapsulate all these transformations in a single place outside the main training loop, which makes testing much easier.