flax.training package

Checkpoints

Checkpointing helper functions.

Handles saving and restoring optimizer checkpoints based on the step number or another numerical metric in the filename. Cleans up older or worse-performing checkpoint files.

flax.training.checkpoints.save_checkpoint(ckpt_dir, target, step, prefix='checkpoint_', keep=1, overwrite=False, keep_every_n_steps=None, async_manager=None, orbax_checkpointer=None)

Save a checkpoint of the model. Suitable for single-host.

In this method, every JAX process saves the checkpoint on its own. Do not use it if you have multiple processes and you intend for them to save data to a common directory (e.g., a GCloud bucket). To save multi-process checkpoints to shared storage or to save GlobalDeviceArrays, use save_checkpoint_multiprocess() instead.

Pre-emption safe: the checkpoint is first written to a temporary file, which is then renamed to its final name before past files are cleaned up. However, if async_manager is used, the final commit happens inside an async callback, which can be explicitly waited on by calling async_manager.wait_previous_save().
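The write-to-temporary-then-rename strategy can be shown with the standard library alone. The helper below, atomic_write_bytes, is a hypothetical name. It is a minimal sketch of the general pattern and not Flax's implementation. If the process is pre-empted, it leaves at most a stray temporary file behind, never a half-written checkpoint under the final name.

```python
import os
import tempfile


def atomic_write_bytes(path: str, data: bytes) -> None:
  """Hypothetical helper: write data to path so readers never see a partial file."""
  directory = os.path.dirname(os.path.abspath(path))
  # Create the temporary file in the destination directory so the final
  # rename stays on one filesystem, where os.replace is atomic on POSIX.
  fd, tmp_path = tempfile.mkstemp(dir=directory, prefix='tmp_')
  try:
    with os.fdopen(fd, 'wb') as f:
      f.write(data)
      f.flush()
      os.fsync(f.fileno())  # make sure the bytes reach the disk before renaming
    os.replace(tmp_path, path)  # swap the finished file into place in one step
  except BaseException:
    # Remove the temporary file if anything failed before the rename.
    if os.path.exists(tmp_path):
      os.remove(tmp_path)
    raise
```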

Example usage:

>>> from flax.training import checkpoints
>>> import jax.numpy as jnp
>>> import tempfile

>>> with tempfile.TemporaryDirectory() as dir_path:
...   test_object = {
...     'a': jnp.array([1, 2, 3], jnp.int32),
...     'b': jnp.array([1, 1, 1], jnp.int32),
...   }
...   file_path = checkpoints.save_checkpoint(
...     dir_path, target=test_object, step=0, prefix='test_', keep=1
...   )
...   restored_object = checkpoints.restore_checkpoint(
...     file_path, target=None
...   )
>>> restored_object
{'a': Array([1, 2, 3], dtype=int32), 'b': Array([1, 1, 1], dtype=int32)}

Parameters
  • ckpt_dir – str or pathlib-like path to store checkpoint files in.

  • target – serializable flax object, usually a flax optimizer.

  • step – int or float: training step number or other metric number.

  • prefix – str: checkpoint file name prefix.

  • keep – number of past checkpoint files to keep.

  • overwrite – overwrite existing checkpoint files if a checkpoint at the current or a later step already exists (default: False).

  • keep_every_n_steps – if defined, keep a checkpoint every n steps (in addition to keeping the last ‘keep’ checkpoints).

  • async_manager – if defined, the save will run without blocking the main thread. Only works on a single host. Note that an ongoing save still blocks subsequent saves, so that the overwrite/keep logic works correctly.

  • orbax_checkpointer – if defined, the save will be done by Orbax (orbax.checkpoint, commonly imported as ocp). In the future, all Flax checkpointing features will be migrated to Orbax, so switching to an orbax_checkpointer now is recommended. Please check out the checkpointing guide (https://flax.readthedocs.io/en/latest/guides/training_techniques/use_checkpointing.html#save-checkpoints) for how to use Orbax checkpointers.

Returns

Filename of saved checkpoint.
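The following pure-Python sketch illustrates how keep and keep_every_n_steps interact under one plausible retention rule. The newest keep checkpoints always survive. An older checkpoint is kept only if it lies at least keep_every_n_steps after the previously retained one. The function name checkpoints_to_keep is hypothetical, and Flax's exact spacing rule may differ in edge cases.

```python
from typing import List, Optional, Sequence


def checkpoints_to_keep(steps: Sequence[int], keep: int,
                        keep_every_n_steps: Optional[int] = None) -> List[int]:
  """Hypothetical helper: steps retained under a keep / keep_every_n_steps policy."""
  if keep < 1:
    raise ValueError('this sketch assumes keep >= 1')
  steps = sorted(steps)
  newest, older = steps[-keep:], steps[:-keep]
  retained_older = []
  if keep_every_n_steps:
    last_kept = float('-inf')
    for step in older:  # walk from oldest to newest
      if step - last_kept >= keep_every_n_steps:
        retained_older.append(step)
        last_kept = step
  return retained_older + newest
```

For example, with checkpoints at steps 0, 100, ..., 1000, keep=2 and keep_every_n_steps=300, this rule retains steps 0, 300, 600, 900 and 1000.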

flax.training.checkpoints.save_checkpoint_multiprocess(ckpt_dir, target, step, prefix='checkpoint_', keep=1, overwrite=False, keep_every_n_steps=None, async_manager=None, gda_manager=None, orbax_checkpointer=None)

Save a checkpoint of the model in a multi-process environment.

Use this method to save GlobalDeviceArrays or to save data to a common directory. Only process 0 will save the main checkpoint file and remove old checkpoint files.

Pre-emption safe: the checkpoint is first written to a temporary file, which is then renamed to its final name before past files are cleaned up. However, if async_manager or gda_manager is used, the final commit happens inside an async callback, which can be explicitly waited on by calling async_manager.wait_previous_save() or gda_manager.wait_until_finished().

Parameters
  • ckpt_dir – str or pathlib-like path to store checkpoint files in.

  • target – serializable flax object, usually a flax optimizer.

  • step – int or float: training step number or other metric number.

  • prefix – str: checkpoint file name prefix.

  • keep – number of past checkpoint files to keep.

  • overwrite – overwrite existing checkpoint files if a checkpoint at the current or a later step already exists (default: False).

  • keep_every_n_steps – if defined, keep a checkpoint every n steps (in addition to keeping the last ‘keep’ checkpoints).

  • async_manager – if defined, the save will run without blocking the main thread. Only works on a single host. Note that an ongoing save still blocks subsequent saves, so that the overwrite/keep logic works correctly.

  • gda_manager – required if target contains a JAX GlobalDeviceArray. Saves the GDAs asynchronously to a separate subdirectory with the suffix “_gda”. Like async_manager, this blocks subsequent saves.

  • orbax_checkpointer – if defined, the save will be done by Orbax. In the future, all Flax checkpointing features will be migrated to Orbax, so switching to an orbax_checkpointer now is recommended. Please check out the checkpointing guide (https://flax.readthedocs.io/en/latest/guides/training_techniques/use_checkpointing.html#save-checkpoints) for how to use Orbax checkpointers.

Returns

Filename of saved checkpoint.

flax.training.checkpoints.latest_checkpoint(ckpt_dir, prefix='checkpoint_')

Retrieve the path of the latest checkpoint in a directory.

Parameters
  • ckpt_dir – str: directory of checkpoints to restore from.

  • prefix – str: name prefix of checkpoint files.

Returns

The latest checkpoint path or None if no checkpoints were found.

flax.training.checkpoints.restore_checkpoint(ckpt_dir, target, step=None, prefix='checkpoint_', parallel=True, gda_manager=None, allow_partial_mpa_restoration=False, orbax_checkpointer=None, orbax_transforms=None)

Restore last/best checkpoint from checkpoints in path.

Sorts the checkpoint files naturally, returning the highest-valued file, e.g.:

  • ckpt_1, ckpt_2, ckpt_3 --> ckpt_3

  • ckpt_0.01, ckpt_0.1, ckpt_0.001 --> ckpt_0.1

  • ckpt_-1.0, ckpt_1.0, ckpt_1e5 --> ckpt_1e5
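The selection rule above can be sketched in a few lines of Python by parsing each filename's suffix as a number. This is a simplified, hypothetical helper (highest_valued_checkpoint), not Flax's own natural sort. It assumes that the suffix after the prefix is a plain int or float literal.

```python
from typing import Iterable, Optional


def highest_valued_checkpoint(filenames: Iterable[str],
                              prefix: str = 'ckpt_') -> Optional[str]:
  """Hypothetical helper: return the filename with the largest numeric suffix."""
  candidates = []
  for name in filenames:
    if not name.startswith(prefix):
      continue
    try:
      value = float(name[len(prefix):])  # handles '3', '0.01', '-1.0', '1e5'
    except ValueError:
      continue  # skip files whose suffix is not a number (e.g. temporaries)
    candidates.append((value, name))
  return max(candidates)[1] if candidates else None
```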

Example usage:

>>> from flax.training import checkpoints
>>> import jax.numpy as jnp
>>> import tempfile
...
>>> with tempfile.TemporaryDirectory() as dir_path:
...   test_object = {
...     'a': jnp.array([1, 2, 3], jnp.int32),
...     'b': jnp.array([1, 1, 1], jnp.int32),
...   }
...   file_path = checkpoints.save_checkpoint(
...     dir_path, target=test_object, step=0, prefix='test_', keep=1
...   )
...   restored_object = checkpoints.restore_checkpoint(
...     file_path, target=None
...   )
>>> restored_object
{'a': Array([1, 2, 3], dtype=int32), 'b': Array([1, 1, 1], dtype=int32)}

Parameters
  • ckpt_dir – str: checkpoint file or directory of checkpoints to restore from.

  • target – matching object to rebuild via deserialized state-dict. If None, the deserialized state-dict is returned as-is.

  • step – int or float: step number to load or None to load latest. If specified, ckpt_dir must be a directory.

  • prefix – str: name prefix of checkpoint files.

  • parallel – bool: whether to load seekable checkpoints in parallel, for speed.

  • gda_manager – required if checkpoint contains a multiprocess array (GlobalDeviceArray or jax Array from pjit). Will read the arrays from the separate subdirectory with postfix “_gda”.

  • allow_partial_mpa_restoration – If true, the given target doesn’t have to contain all valid multiprocess arrays. As a result, the restored Pytree may have some MPAs not restored correctly. Use this if you cannot provide a fully valid target and don’t need all the MPAs in the checkpoint to be restored.

  • orbax_checkpointer – the ocp.Checkpointer that handles the underlying restore, if the given checkpoint is saved with ocp.

  • orbax_transforms – the Orbax transformations that will be passed into orbax_checkpointer.restore() call.

Returns

Restored target updated from the checkpoint file or, if no step is specified and no checkpoint files are present, the passed-in target unchanged. If a file path is specified but not found, the passed-in target is also returned. This matches the behavior when a directory path is specified but the directory has not yet been created.

flax.training.checkpoints.convert_pre_linen(params)

Converts a pre-Linen parameter pytree.

In the pre-Linen API, submodules were numbered incrementally, independent of the submodule class. Linen changes this behavior and keeps a separate submodule count per module class.

Consider the following module:

import flax.linen as nn

class Model(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Conv(1, 1)(x)
    x = nn.Dense(1)(x)
    return x

In pre-Linen the resulting params would have had the structure:

{'Conv_0': { ... }, 'Dense_1': { ... } }

With Linen the resulting params would instead have had the structure:

{'Conv_0': { ... }, 'Dense_0': { ... } }

To convert from pre-Linen format to Linen, simply call:

from flax.training.checkpoints import convert_pre_linen
params = convert_pre_linen(pre_linen_params)

You can also use this utility to convert pre-Linen collections, because they follow the same module naming. However, collections were “flat” in pre-Linen and must first be unflattened before they can be passed to this function:

import flax.traverse_util
batch_stats = convert_pre_linen(flax.traverse_util.unflatten_dict({
    tuple(k.split('/')[1:]): v
    for k, v in pre_linen_model_state.as_dict().items()
}))

Then Linen variables can be defined from these converted collections:

variables = {'params': params, 'batch_stats': batch_stats}

Parameters

params – Parameter pytree in pre-Linen format. If the pytree is already in Linen format, then the returned pytree is unchanged (i.e. this function can safely be called on any loaded checkpoint for use with Linen).

Returns

Parameter pytree with Linen submodule naming.
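The renumbering can be sketched in pure Python. Within each dict, keys of the form <Name>_<index> are ordered by their global pre-Linen index and then renumbered from 0 separately for each Name. This is a simplified, hypothetical re-implementation (convert_pre_linen_sketch) that assumes plain dicts with string keys. It is not the library function, but it also leaves Linen-format pytrees unchanged.

```python
import re

# Assumption: any key ending in _<number> names a submodule.
_INDEXED_NAME = re.compile(r'^(.+)_(\d+)$')


def convert_pre_linen_sketch(params):
  """Hypothetical sketch: renumber submodules per class instead of globally."""
  if not isinstance(params, dict):
    return params  # leaves (arrays, scalars) are returned untouched

  def sort_key(name):
    match = _INDEXED_NAME.match(name)
    # Order indexed names by their global index, numerically (Dense_10 > Dense_2).
    return (0, int(match.group(2)), name) if match else (1, 0, name)

  counts = {}
  renamed = {}
  for name in sorted(params, key=sort_key):
    match = _INDEXED_NAME.match(name)
    new_name = name
    if match:
      cls_name = match.group(1)
      new_name = f'{cls_name}_{counts.get(cls_name, 0)}'
      counts[cls_name] = counts.get(cls_name, 0) + 1
    renamed[new_name] = convert_pre_linen_sketch(params[name])
  return renamed
```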

Learning rate schedules

Learning rate schedules used in Flax image classification examples.

Note that with FLIP #1009, learning rate schedules in flax.training are effectively deprecated in favor of Optax schedules. Please refer to Optimizer Schedules for more information.

flax.training.lr_schedule.create_constant_learning_rate_schedule(base_learning_rate, steps_per_epoch, warmup_length=0.0)

Create a constant learning rate schedule with optional warmup.

Note that with FLIP #1009, learning rate schedules in flax.training are effectively deprecated in favor of Optax schedules. Please refer to Optimizer Schedules for more information.

Holds the learning rate constant. This function also offers a learning rate warmup as per https://arxiv.org/abs/1706.02677, for the purpose of training with large mini-batches.

Parameters
  • base_learning_rate – the base learning rate

  • steps_per_epoch – the number of iterations per epoch

  • warmup_length – if > 0, the learning rate will be modulated by a warmup factor that ramps up linearly from 0 to 1 over the first warmup_length epochs

Returns

Function f(step) -> lr that computes the learning rate for a given step.
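Here is a minimal pure-Python sketch of the schedule described above. It assumes the warmup factor is min(1, step / (warmup_length * steps_per_epoch)). The function name constant_lr_with_warmup is hypothetical. For new code, the Optax schedules referred to above are the recommended route.

```python
def constant_lr_with_warmup(base_learning_rate, steps_per_epoch, warmup_length=0.0):
  """Hypothetical sketch: constant learning rate with optional linear warmup."""

  def learning_rate_fn(step):
    lr = base_learning_rate
    if warmup_length > 0.0:
      # Warmup factor ramps linearly from 0 to 1 over warmup_length epochs.
      lr *= min(1.0, step / (warmup_length * steps_per_epoch))
    return lr

  return learning_rate_fn
```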

flax.training.lr_schedule.create_stepped_learning_rate_schedule(base_learning_rate, steps_per_epoch, lr_sched_steps, warmup_length=0.0)

Create a stepped learning rate schedule with optional warmup.

Note that with FLIP #1009, learning rate schedules in flax.training are effectively deprecated in favor of Optax schedules. Please refer to Optimizer Schedules for more information.

A stepped learning rate schedule decreases the learning rate by specified amounts at specified epochs. The steps are given as the lr_sched_steps parameter. A common ImageNet schedule decays the learning rate by a factor of 0.1 at epochs 30, 60 and 80. This would be specified as:

[
  [30, 0.1],
  [60, 0.01],
  [80, 0.001]
]

This function also offers a learning rate warmup as per https://arxiv.org/abs/1706.02677, for the purpose of training with large mini-batches.

Parameters
  • base_learning_rate – the base learning rate

  • steps_per_epoch – the number of iterations per epoch

  • lr_sched_steps – the schedule as a list of steps, each of which is an [epoch, lr_factor] pair; at the given epoch the learning rate is set to base_learning_rate * lr_factor

  • warmup_length – if > 0, the learning rate will be modulated by a warmup factor that ramps up linearly from 0 to 1 over the first warmup_length epochs

Returns

Function f(step) -> lr that computes the learning rate for a given step.
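Here is a pure-Python sketch of the stepped schedule. Each lr_factor is applied relative to base_learning_rate rather than cumulatively, so the ImageNet example above yields 0.1x, 0.01x and 0.001x the base rate. The function name is hypothetical. We assume that a factor takes effect from the first step of its epoch, and that warmup uses the same linear factor as in the constant schedule.

```python
def stepped_lr_with_warmup(base_learning_rate, steps_per_epoch, lr_sched_steps,
                           warmup_length=0.0):
  """Hypothetical sketch: piecewise-constant schedule from [epoch, lr_factor] pairs."""
  # Convert epoch boundaries into step boundaries, sorted by step.
  boundaries = sorted(
      (round(epoch * steps_per_epoch), factor) for epoch, factor in lr_sched_steps)

  def learning_rate_fn(step):
    factor = 1.0
    for boundary_step, boundary_factor in boundaries:
      if step >= boundary_step:  # assumption: inclusive boundary
        factor = boundary_factor  # relative to the base rate, not cumulative
    lr = base_learning_rate * factor
    if warmup_length > 0.0:
      lr *= min(1.0, step / (warmup_length * steps_per_epoch))
    return lr

  return learning_rate_fn
```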

flax.training.lr_schedule.create_cosine_learning_rate_schedule(base_learning_rate, steps_per_epoch, halfcos_epochs, warmup_length=0.0)

Create a cosine learning rate schedule with optional warmup.

Note that with FLIP #1009, learning rate schedules in flax.training are effectively deprecated in favor of Optax schedules. Please refer to Optimizer Schedules for more information.

A cosine learning rate schedule modulates the learning rate with half a cosine wave, gradually scaling it to 0 at the end of training.

This function also offers a learning rate warmup as per https://arxiv.org/abs/1706.02677, for the purpose of training with large mini-batches.

Parameters
  • base_learning_rate – the base learning rate

  • steps_per_epoch – the number of iterations per epoch

  • halfcos_epochs – the number of epochs to complete half a cosine wave; normally the number of epochs used for training

  • warmup_length – if > 0, the learning rate will be modulated by a warmup factor that ramps up linearly from 0 to 1 over the first warmup_length epochs

Returns

Function f(step) -> lr that computes the learning rate for a given step.
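Half a cosine wave corresponds to the scale factor 0.5 * (1 + cos(pi * step / (halfcos_epochs * steps_per_epoch))). It equals 1 at step 0 and 0 after halfcos_epochs epochs. The sketch below is hypothetical. It assumes the warmup factor simply multiplies this scale, as in the other schedules.

```python
import math


def cosine_lr_with_warmup(base_learning_rate, steps_per_epoch, halfcos_epochs,
                          warmup_length=0.0):
  """Hypothetical sketch: half-cosine decay to 0 with optional linear warmup."""
  halfwavelength_steps = halfcos_epochs * steps_per_epoch

  def learning_rate_fn(step):
    # Scale goes from 1 at step 0 down to 0 at halfwavelength_steps.
    scale = 0.5 * (1.0 + math.cos(math.pi * step / halfwavelength_steps))
    lr = base_learning_rate * scale
    if warmup_length > 0.0:
      lr *= min(1.0, step / (warmup_length * steps_per_epoch))
    return lr

  return learning_rate_fn
```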

Train state

class flax.training.train_state.TrainState(step, apply_fn, params, tx, opt_state)

Simple train state for the common case with a single Optax optimizer.

Example usage:

>>> import flax.linen as nn
>>> from flax.training.train_state import TrainState
>>> import jax, jax.numpy as jnp
>>> import optax

>>> x = jnp.ones((1, 2))
>>> y = jnp.ones((1, 2))
>>> model = nn.Dense(2)
>>> variables = model.init(jax.random.key(0), x)
>>> tx = optax.adam(1e-3)

>>> state = TrainState.create(
...     apply_fn=model.apply,
...     params=variables['params'],
...     tx=tx)

>>> def loss_fn(params, x, y):
...   predictions = state.apply_fn({'params': params}, x)
...   loss = optax.l2_loss(predictions=predictions, targets=y).mean()
...   return loss
>>> loss_fn(state.params, x, y)
Array(3.3514676, dtype=float32)

>>> grads = jax.grad(loss_fn)(state.params, x, y)
>>> state = state.apply_gradients(grads=grads)
>>> loss_fn(state.params, x, y)
Array(3.343844, dtype=float32)

You can extend this dataclass by subclassing it to store additional data (e.g. additional variable collections).

For more exotic use cases (e.g. multiple optimizers), it is probably best to fork the class and modify it.

Parameters
  • step – Counter starts at 0 and is incremented by every call to .apply_gradients().

  • apply_fn – Usually set to model.apply. Kept in this dataclass for convenience, to shorten the parameter list of the train_step() function in your training loop.

  • params – The parameters to be updated by tx and used by apply_fn.

  • tx – An Optax gradient transformation.

  • opt_state – The state for tx.

apply_gradients(*, grads, **kwargs)

Returns a copy of the state with step, params, opt_state and any **kwargs updated.

Note that internally this function calls .tx.update() followed by a call to optax.apply_updates() to update params and opt_state.

Parameters
  • grads – Gradients that have the same pytree structure as .params.

  • **kwargs – Additional dataclass attributes that should be .replace()-ed.

Returns

An updated instance of self with step incremented by one, params and opt_state updated by applying grads, and additional attributes replaced as specified by kwargs.

classmethod create(*, apply_fn, params, tx, **kwargs)

Creates a new instance with step=0 and initialized opt_state.

Early Stopping

class flax.training.early_stopping.EarlyStopping(min_delta=0, patience=0, best_metric=inf, patience_count=0, should_stop=False, has_improved=False)

Early stopping to avoid overfitting during training.

The following example stops training early once the loss has failed to improve on the best value seen so far by more than min_delta=1e-3 for more than patience=2 consecutive epochs:

>>> from flax.training.early_stopping import EarlyStopping

>>> def train_epoch(optimizer, train_ds, batch_size, epoch, input_rng):
...   ...
...   loss = [4, 3, 3, 3, 2, 2, 2, 2, 1, 1][epoch]
...   return None, {'loss': loss}

>>> early_stop = EarlyStopping(min_delta=1e-3, patience=2)
>>> optimizer = None
>>> for epoch in range(10):
...   optimizer, train_metrics = train_epoch(
...       optimizer=optimizer, train_ds=None, batch_size=None, epoch=epoch, input_rng=None)
...   early_stop = early_stop.update(train_metrics['loss'])
...   if early_stop.should_stop:
...     print(f'Met early stopping criteria, breaking at epoch {epoch}')
...     break
Met early stopping criteria, breaking at epoch 7
min_delta

Minimum delta between updates to be considered an improvement.

Type

float

patience

Number of steps of no improvement before stopping.

Type

int

best_metric

Current best metric value.

Type

float

patience_count

Number of steps since last improving update.

Type

int

should_stop

Whether the training loop should stop to avoid overfitting.

Type

bool

has_improved

Whether the metric improved by more than min_delta in the last .update() call.

Type

bool

update(metric)

Update the state based on metric.

Returns

The updated EarlyStopping instance. Its .has_improved attribute is True when the metric improved on the previous best_metric by more than min_delta.
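The update rule can be sketched in pure Python as a frozen dataclass. EarlyStoppingSketch is a hypothetical re-implementation based on two assumptions. First, an update counts as an improvement when best_metric - metric > min_delta. Second, a non-improving update stops training once patience_count has already reached patience. With the loss sequence from the example above, it also stops at epoch 7.

```python
import dataclasses
import math


@dataclasses.dataclass(frozen=True)
class EarlyStoppingSketch:
  """Hypothetical sketch of the early-stopping state machine."""
  min_delta: float = 0.0
  patience: int = 0
  best_metric: float = math.inf
  patience_count: int = 0
  should_stop: bool = False
  has_improved: bool = False

  def update(self, metric):
    # Improvement: the first metric seen, or a drop of more than min_delta.
    if math.isinf(self.best_metric) or self.best_metric - metric > self.min_delta:
      return dataclasses.replace(
          self, best_metric=metric, patience_count=0, has_improved=True)
    # No improvement: stop if patience was already used up, else count this step.
    should_stop = self.should_stop or self.patience_count >= self.patience
    return dataclasses.replace(
        self, patience_count=self.patience_count + 1,
        should_stop=should_stop, has_improved=False)


losses = [4, 3, 3, 3, 2, 2, 2, 2, 1, 1]
early_stop = EarlyStoppingSketch(min_delta=1e-3, patience=2)
stopped_at = None
for epoch, loss in enumerate(losses):
  early_stop = early_stop.update(loss)
  if early_stop.should_stop:
    stopped_at = epoch
    break
```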

Common Utilities

flax.training.common_utils.shard(xs)

Helper for pmap to shard a pytree of arrays by local_device_count.

Parameters

xs – a pytree of arrays.

Returns

A matching pytree with arrays’ leading dimensions sharded by the local device count.
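The reshaping that shard performs can be sketched with NumPy. This hypothetical helper makes three simplifications. The device count is passed explicitly instead of being queried. Only flat dicts of arrays are handled. The leading (batch) dimension is assumed to be divisible by the device count.

```python
import numpy as np


def shard_sketch(xs, device_count):
  """Hypothetical sketch: reshape (batch, ...) leaves to (devices, batch // devices, ...)."""
  return {k: v.reshape((device_count, -1) + v.shape[1:]) for k, v in xs.items()}


batch = {'image': np.zeros((8, 32, 32, 3)), 'label': np.arange(8)}
sharded = shard_sketch(batch, device_count=4)
# sharded['image'].shape == (4, 2, 32, 32, 3); each device gets 2 examples.
```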

flax.training.common_utils.shard_prng_key(prng_key)

Helper to shard (aka split) a PRNGKey for use with pmap’d functions.

PRNG keys can be used at train time to drive stochastic modules e.g. Dropout. We would like a different PRNG key for each local device so that we end up with different random numbers on each one, hence we split our PRNG key.

Parameters

prng_key – JAX PRNGKey

Returns

A new array of PRNGKeys with leading dimension equal to local device count.
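Producing one key per local device amounts to splitting the key into jax.local_device_count() pieces. shard_prng_key_sketch is a hypothetical name for this sketch. We assume it matches the documented return value: an array of keys with a leading dimension equal to the local device count.

```python
import jax


def shard_prng_key_sketch(prng_key):
  # One independent key per local device, e.g. to drive Dropout in a pmap'd step.
  return jax.random.split(prng_key, num=jax.local_device_count())


keys = shard_prng_key_sketch(jax.random.key(0))
# Under pmap, device i would receive keys[i] and draw its own random numbers.
```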

flax.training.common_utils.stack_forest(forest)

Helper function to stack the leaves of a sequence of pytrees.

Parameters

forest – a sequence of pytrees (e.g. tuple or list) of matching structure whose leaves are arrays with individually matching shapes.

Returns

A single pytree of the same structure whose leaves are individually stacked arrays.
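For flat dicts, stacking a forest reduces to calling np.stack once per key. The NumPy sketch below (stack_forest_sketch, a hypothetical name) is restricted to that case. For nested pytrees, jax.tree_util.tree_map can apply the same np.stack across several trees at once.

```python
import numpy as np


def stack_forest_sketch(forest):
  """Hypothetical sketch: stack matching leaves of flat dicts along a new axis 0."""
  return {k: np.stack([tree[k] for tree in forest]) for k in forest[0]}


per_step = [
    {'loss': np.float32(0.5), 'w': np.ones((3,))},
    {'loss': np.float32(0.4), 'w': np.zeros((3,))},
]
stacked = stack_forest_sketch(per_step)
# stacked['loss'].shape == (2,), stacked['w'].shape == (2, 3)
```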

flax.training.common_utils.get_metrics(device_metrics)

Helper utility for pmap, gathering replicated timeseries metric data.

Parameters

device_metrics – replicated, device-resident pytree of metric data, whose leaves are presumed to be a sequence of arrays recorded over time.

Returns

A pytree of unreplicated, host-resident, stacked-over-time arrays useful for computing host-local statistics and logging.
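Here is a NumPy sketch of the gathering step. It assumes device_metrics is a list with one entry per step, where each entry is a flat dict of arrays whose leading device axis holds identical, replicated values. get_metrics_sketch is a hypothetical name. It keeps replica 0 of each leaf and then stacks the steps over time.

```python
import numpy as np


def get_metrics_sketch(device_metrics):
  """Hypothetical sketch: unreplicate per-step metrics and stack them over time."""
  # Replicated metrics hold the same value on every device, so replica 0 suffices.
  host_metrics = [{k: np.asarray(v)[0] for k, v in step.items()}
                  for step in device_metrics]
  return {k: np.stack([m[k] for m in host_metrics]) for k in host_metrics[0]}


device_metrics = [
    {'loss': np.array([0.5, 0.5]), 'acc': np.array([0.8, 0.8])},  # step 0, 2 devices
    {'loss': np.array([0.4, 0.4]), 'acc': np.array([0.9, 0.9])},  # step 1
]
metrics = get_metrics_sketch(device_metrics)
```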

flax.training.common_utils.onehot(labels, num_classes, on_value=1.0, off_value=0.0)

Create a dense one-hot version of an indexed array.

NB: consider using the more standard jax.nn.one_hot instead.

Parameters
  • labels – an n-dim JAX array whose last dimension contains integer indices.

  • num_classes – the number of classes, i.e. one more than the maximum possible index.

  • on_value – the “on” value for the one-hot array, defaults to 1.0.

  • off_value – the “off” value for the one-hot array, defaults to 0.0.

Returns

A (n+1)-dim array whose last dimension contains one-hot vectors of length num_classes.
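The one-hot construction can be sketched with NumPy broadcasting. Each index is compared against 0 .. num_classes - 1 on a new trailing axis, and on_value or off_value is selected accordingly. onehot_sketch is a hypothetical name, and the float32 output dtype is an assumption.

```python
import numpy as np


def onehot_sketch(labels, num_classes, on_value=1.0, off_value=0.0):
  """Hypothetical sketch: (n)-dim integer labels -> (n+1)-dim one-hot array."""
  labels = np.asarray(labels)
  # Broadcasting (..., 1) against (num_classes,) yields (..., num_classes).
  hits = labels[..., None] == np.arange(num_classes)
  return np.where(hits, on_value, off_value).astype(np.float32)


out = onehot_sketch([0, 2], num_classes=3)  # [[1, 0, 0], [0, 0, 1]]
smoothed = onehot_sketch(np.array([[1]]), num_classes=2, on_value=0.9, off_value=0.1)
```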