Upgrading my codebase to Optax#
We proposed replacing flax.optim with Optax in 2021 with FLIP #1009, and the Flax optimizers were removed in v0.6.0. This guide is targeted at flax.optim users to help them update their code to Optax.
See also Optax’s quick start documentation: https://optax.readthedocs.io/en/latest/getting_started.html
Replacing flax.optim with optax#
Optax has drop-in replacements for all of Flax's optimizers. Refer to Common Optimizers in Optax's documentation for API details.
The usage is very similar, with the difference that optax does not keep a copy of the params, so they need to be passed around separately. Flax provides the utility TrainState to store optimizer state, parameters, and other associated data in a single dataclass (not used in the code below; a short TrainState sketch follows the code).
@jax.jit
def train_step(optimizer, batch):
  grads = jax.grad(loss)(optimizer.target, batch)
  return optimizer.apply_gradient(grads)

optimizer_def = flax.optim.Momentum(
    learning_rate, momentum)
optimizer = optimizer_def.create(variables['params'])

for batch in get_ds_train():
  optimizer = train_step(optimizer, batch)
@jax.jit
def train_step(params, opt_state, batch):
  grads = jax.grad(loss)(params, batch)
  updates, opt_state = tx.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
  return params, opt_state

tx = optax.sgd(learning_rate, momentum)
params = variables['params']
opt_state = tx.init(params)

for batch in ds_train:
  params, opt_state = train_step(params, opt_state, batch)
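As a hedged sketch (not part of the original guide), the same Optax training loop can also be written with flax.training.train_state.TrainState, which bundles the apply function, parameters, and optimizer state; the model, loss, and dataset names are placeholders assumed to be defined elsewhere.

from flax.training import train_state

# Assumes a Flax `model`, a `loss(params, batch)` function, and `ds_train`.
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    tx=optax.sgd(learning_rate, momentum),
)

@jax.jit
def train_step(state, batch):
  grads = jax.grad(loss)(state.params, batch)
  # `apply_gradients` calls `tx.update()` and `optax.apply_updates()` internally.
  return state.apply_gradients(grads=grads)

for batch in ds_train:
  state = train_step(state, batch)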
Composable Gradient Transformations#
The function optax.sgd() used in the code snippet above is simply a wrapper for the sequential application of two gradient transformations. Instead of using this alias, it is common to use optax.chain() to combine several of these generic building blocks.
# Note that the aliases follow the convention to use positive
# values for the learning rate by default.
tx = optax.sgd(learning_rate, momentum)

tx = optax.chain(
    # 1. Step: keep a trace of past updates and add to gradients.
    optax.trace(decay=momentum),
    # 2. Step: multiply result from step 1 with negative learning rate.
    # Note that `optax.apply_updates()` simply adds the final updates to the
    # parameters, so we must make sure to flip the sign here for gradient
    # descent.
    optax.scale(-learning_rate),
)
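To see that the alias and the explicit chain are interchangeable, here is a minimal check on a toy pytree (the values and names below are made up for illustration):

import jax.numpy as jnp

params = {'w': jnp.ones((2,))}
grads = {'w': jnp.array([0.5, -1.0])}

tx_alias = optax.sgd(learning_rate=0.1, momentum=0.9)
tx_chain = optax.chain(optax.trace(decay=0.9), optax.scale(-0.1))

updates_alias, _ = tx_alias.update(grads, tx_alias.init(params))
updates_chain, _ = tx_chain.update(grads, tx_chain.init(params))
# Both produce the same update on the first step: -0.1 * grads.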
Weight Decay#
Some of Flax's optimizers also include a weight decay. In Optax, some optimizers also have a weight decay parameter (such as optax.adamw()), and for others the weight decay can be added as another "gradient transformation", optax.add_decayed_weights(), which adds an update derived from the parameters.
optimizer_def = flax.optim.Adam(
    learning_rate, weight_decay=weight_decay)
optimizer = optimizer_def.create(variables['params'])

# (Note that you could also use `optax.adamw()` in this case)
tx = optax.chain(
    optax.scale_by_adam(),
    optax.add_decayed_weights(weight_decay),
    # params -= learning_rate * (adam(grads) + params * weight_decay)
    optax.scale(-learning_rate),
)
# Note that you'll need to specify `params` when computing the updates:
# tx.update(grads, opt_state, params)
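A hedged sketch of the two points mentioned in the comments above: the optax.adamw() alias bundles the same transformations (up to Adam's default hyperparameters), and optax.add_decayed_weights() needs the parameters passed to update():

# Alias form of the chain above.
tx = optax.adamw(learning_rate, weight_decay=weight_decay)

# Because the decayed weights are derived from the parameters, pass `params`
# as the third argument to `update()`.
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)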
Gradient Clipping#
Training can be stabilized by clipping gradients to a global norm (Pascanu et al., 2012). In Flax this is often done by processing the gradients before passing them to the optimizer. With Optax this becomes just another gradient transformation, optax.clip_by_global_norm().
def train_step(optimizer, batch):
  grads = jax.grad(loss)(optimizer.target, batch)
  grads_flat, _ = jax.tree_util.tree_flatten(grads)
  global_l2 = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads_flat]))
  g_factor = jnp.minimum(1.0, grad_clip_norm / global_l2)
  grads = jax.tree_util.tree_map(lambda g: g * g_factor, grads)
  return optimizer.apply_gradient(grads)

tx = optax.chain(
    optax.clip_by_global_norm(grad_clip_norm),
    optax.trace(decay=momentum),
    optax.scale(-learning_rate),
)
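As a quick sanity check (a sketch with made-up values, not part of the original guide), optax.global_norm() can be used to confirm that the clipped gradients never exceed the threshold:

import jax.numpy as jnp

grad_clip_norm = 1.0
clipper = optax.clip_by_global_norm(grad_clip_norm)

grads = {'w': jnp.array([3.0, 4.0])}  # global norm is 5.0
clipped, _ = clipper.update(grads, clipper.init(grads))
assert float(optax.global_norm(clipped)) <= grad_clip_norm + 1e-6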
Learning Rate Schedules#
For learning rate schedules, Flax allows overwriting hyperparameters when applying the gradients. Optax maintains a step counter and provides it as an argument to a function for scaling the updates, added with optax.scale_by_schedule(). Optax also allows specifying a function to inject arbitrary scalar values for other gradient updates via optax.inject_hyperparams().
Read more about learning rate schedules in the lr_schedule guide.
Read more about the schedules defined in Optax under Optimizer Schedules. The standard optimizers (like optax.adam(), optax.sgd(), etc.) also accept a learning rate schedule as the learning_rate parameter.
def train_step(step, optimizer, batch):
  grads = jax.grad(loss)(optimizer.target, batch)
  return step + 1, optimizer.apply_gradient(grads, learning_rate=schedule(step))

tx = optax.chain(
    optax.trace(decay=momentum),
    # Note that we still want a negative value for scaling the updates!
    optax.scale_by_schedule(lambda step: -schedule(step)),
)
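Two hedged variants of the same idea (sketches, assuming schedule is an Optax-style schedule such as optax.exponential_decay()): the schedule can be passed directly to an alias, or the learning rate can be made a runtime-adjustable hyperparameter with optax.inject_hyperparams():

schedule = optax.exponential_decay(
    init_value=learning_rate, transition_steps=1000, decay_rate=0.99)

# The aliases accept a schedule wherever a constant learning rate is expected.
tx = optax.sgd(schedule, momentum)

# Alternatively, make the learning rate part of the optimizer state so it can
# be inspected or overwritten between steps.
tx = optax.inject_hyperparams(optax.sgd)(learning_rate=schedule, momentum=momentum)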
Multiple Optimizers / Updating a Subset of Parameters#
In Flax, traversals are used to specify which parameters should be updated by an optimizer, and you can combine traversals using flax.optim.MultiOptimizer to apply different optimizers to different parameters. The equivalent in Optax is optax.masked() combined with optax.chain().
Note that the example below uses flax.traverse_util to create the boolean masks required by optax.masked() - alternatively, you could create them manually, or use optax.multi_transform(), which takes a multivalent pytree to specify gradient transformations (a sketch of that alternative follows the code below).
Beware that optax.masked() flattens the pytree internally, and the inner gradient transformations will only be called with that partial, flattened view of the params/gradients. This is usually not a problem, but it makes it hard to nest multiple levels of masked gradient transformations (because the inner masks will expect the mask to be defined in terms of the partially flattened view, which is not readily available outside the outer mask).
kernels = flax.traverse_util.ModelParamTraversal(lambda p, _: 'kernel' in p)
biases = flax.traverse_util.ModelParamTraversal(lambda p, _: 'bias' in p)

kernel_opt = flax.optim.Momentum(learning_rate, momentum)
bias_opt = flax.optim.Momentum(learning_rate * 0.1, momentum)
optimizer = flax.optim.MultiOptimizer(
    (kernels, kernel_opt),
    (biases, bias_opt)
).create(variables['params'])

kernels = flax.traverse_util.ModelParamTraversal(lambda p, _: 'kernel' in p)
biases = flax.traverse_util.ModelParamTraversal(lambda p, _: 'bias' in p)

all_false = jax.tree_util.tree_map(lambda _: False, params)
kernels_mask = kernels.update(lambda _: True, all_false)
biases_mask = biases.update(lambda _: True, all_false)

tx = optax.chain(
    optax.trace(decay=momentum),
    optax.masked(optax.scale(-learning_rate), kernels_mask),
    optax.masked(optax.scale(-learning_rate * 0.1), biases_mask),
)
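For reference, here is a hedged sketch of the optax.multi_transform() alternative mentioned above, reusing the traversals to build a pytree of labels. It is only roughly equivalent to the masked chain, and it assumes every parameter is either a kernel or a bias; the label names are arbitrary.

# Label every leaf 'bias', then overwrite the kernel leaves with 'kernel'.
all_bias = jax.tree_util.tree_map(lambda _: 'bias', params)
param_labels = kernels.update(lambda _: 'kernel', all_bias)

tx = optax.multi_transform(
    {'kernel': optax.sgd(learning_rate, momentum),
     'bias': optax.sgd(learning_rate * 0.1, momentum)},
    param_labels)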
Final Words#
All of the above patterns can of course be mixed, and Optax makes it possible to encapsulate all of these transformations in a single place outside the main training loop, which makes testing much easier.