
flax.jax_utils package
========================

.. currentmodule:: flax.jax_utils

.. automodule:: flax.jax_utils


.. autofunction:: partial_eval_by_shape


Multi-device utilities
------------------------

.. autofunction:: replicate

.. autofunction:: unreplicate

.. autofunction:: prefetch_to_device

.. autofunction:: pmean

.. autofunction:: pad_shard_unpad
