flax.serialization package
Serialization utilities for JAX.
All Flax classes that carry state (e.g., Optimizer) can be converted into a state dict of NumPy arrays for easy serialization.
State dicts
- flax.serialization.from_state_dict(target, state, name='.')
Restores the state of the given target using a state dict.
This function takes the current target as an argument. This tells us the exact structure of the target and lets us add assertions that shapes and dtypes don’t change.
In practice, none of the leaf values in target are actually used, only the tree structure, shapes and dtypes.
- Parameters
target – the object whose state should be restored.
state – a dictionary generated by to_state_dict with the desired new state for target.
name – name of the branch taken, used to improve deserialization error messages.
- Returns
A copy of the object with the restored state.
- flax.serialization.to_state_dict(target)
Returns a dictionary with the state of the given target.
- flax.serialization.register_serialization_state(ty, ty_to_state_dict, ty_from_state_dict, override=False)
Register a type for serialization.
- Parameters
ty – the type to be registered.
ty_to_state_dict – a function that takes an instance of ty and returns its state as a dictionary.
ty_from_state_dict – a function that takes an instance of ty and a state dict, and returns a copy of the instance with the restored state.
override – override a previously registered serialization handler (default: False).
Serialization with MessagePack
- flax.serialization.msgpack_serialize(pytree, in_place=False)
Saves a data structure to bytes in msgpack format.
Low-level function that only supports Python trees with array leaves; for custom objects, use to_bytes. It splits arrays above MAX_CHUNK_SIZE into multiple chunks.
- Parameters
pytree – Python tree of dicts, lists and tuples with Python primitives and array leaves.
in_place – boolean specifying whether pytree should be modified in place.
- Returns
msgpack-encoded bytes of pytree.
- flax.serialization.msgpack_restore(encoded_pytree)
Restores a data structure from bytes in msgpack format.
Low-level function that only supports Python trees with array leaves; for custom objects, use from_bytes.
- Parameters
encoded_pytree – msgpack-encoded bytes of a Python tree.
- Returns
Python tree of dicts, lists and tuples with Python primitive and array leaves.
- flax.serialization.to_bytes(target)
Saves an optimizer or other object as a msgpack-serialized state dict.
- Parameters
target – template object with state-dict registrations to be serialized to msgpack format. Typically a Flax model or optimizer.
- Returns
Bytes of the msgpack-encoded state dict of the target object.
- flax.serialization.from_bytes(target, encoded_bytes)
Restores an optimizer or other object from a msgpack-serialized state dict.
- Parameters
target – template object with state-dict registrations that matches the structure being deserialized from encoded_bytes.
encoded_bytes – msgpack-serialized object structurally isomorphic to target. Typically a Flax model or optimizer.
- Returns
A new object structurally isomorphic to target, containing the updated leaf data from the saved data.