flax.serialization package
Serialization utilities for JAX.
All Flax classes that carry state (e.g., Optimizer) can be converted into a state dict of NumPy arrays for easy serialization.
State dicts
- flax.serialization.from_state_dict(target, state, name='.')
Restores the state of the given target using a state dict.
This function takes the current target as an argument. This lets us know the exact structure of the target and lets us add assertions that shapes and dtypes don’t change.
In practice, none of the leaf values in target are actually used; only the tree structure, shapes, and dtypes matter.
- Parameters:
target – the object whose state should be restored.
state – a dictionary generated by to_state_dict with the desired new state for target.
name – name of the branch taken, used to improve deserialization error messages.
- Returns:
A copy of the object with the restored state.
- flax.serialization.to_state_dict(target)
Returns a dictionary with the state of the given target.
- flax.serialization.register_serialization_state(ty, ty_to_state_dict, ty_from_state_dict, override=False)
Register a type for serialization.
- Parameters:
ty – the type to be registered.
ty_to_state_dict – a function that takes an instance of ty and returns its state as a dictionary.
ty_from_state_dict – a function that takes an instance of ty and a state dict, and returns a copy of the instance with the restored state.
override – override a previously registered serialization handler (default: False).
Serialization with MessagePack
- flax.serialization.msgpack_serialize(pytree, in_place=False)
Save data structure to bytes in msgpack format.
Low-level function that only supports Python trees with array leaves; for custom objects, use to_bytes. It splits arrays larger than MAX_CHUNK_SIZE into multiple chunks.
- Parameters:
pytree – a Python tree of dicts, lists, and tuples with Python primitives and array leaves.
in_place – boolean specifying whether pytree should be modified in place.
- Returns:
msgpack-encoded bytes of pytree.
- flax.serialization.msgpack_restore(encoded_pytree)
Restore data structure from bytes in msgpack format.
Low-level function that only supports Python trees with array leaves; for custom objects, use from_bytes.
- Parameters:
encoded_pytree – msgpack-encoded bytes of a Python tree.
- Returns:
A Python tree of dicts, lists, and tuples with Python primitives and array leaves.
- flax.serialization.to_bytes(target)
Save an optimizer or other object as a msgpack-serialized state dict.
- Parameters:
target – template object with state-dict registrations to be serialized to msgpack format. Typically a Flax model or optimizer.
- Returns:
Bytes of the msgpack-encoded state dict of the target object.
- flax.serialization.from_bytes(target, encoded_bytes)
Restore an optimizer or other object from a msgpack-serialized state dict.
- Parameters:
target – template object with state-dict registrations that matches the structure being deserialized from encoded_bytes.
encoded_bytes – msgpack-serialized object structurally isomorphic to target. Typically a Flax model or optimizer.
- Returns:
A new object structurally isomorphic to target, containing the updated leaf data from the saved data.