flax.traverse_util package
A utility for traversing immutable data structures.
A Traversal can be used to iterate and update complex data structures. Traversals take in an object and return a subset of its contents. For example, a Traversal could select an attribute of an object:
>>> from flax import traverse_util
>>> import dataclasses
>>> @dataclasses.dataclass
... class Foo:
...     foo: int = 0
...     bar: int = 0
...
>>> x = Foo(foo=1)
>>> iterator = traverse_util.TraverseAttr('foo').iterate(x)
>>> list(iterator)
[1]
More complex traversals can be constructed using composition. It is often useful to start from the identity traversal and use a method chain to construct the intended Traversal:
>>> data = [{'foo': 1, 'bar': 2}, {'foo': 3, 'bar': 4}]
>>> traversal = traverse_util.t_identity.each()['foo']
>>> iterator = traversal.iterate(data)
>>> list(iterator)
[1, 3]
Traversals can also be used to make changes using the update method:
>>> data = {'foo': Foo(bar=2)}
>>> traversal = traverse_util.t_identity['foo'].bar
>>> data = traversal.update(lambda x: x + x, data)
>>> data
{'foo': Foo(foo=0, bar=4)}
Traversals never mutate the original data. Therefore, an update essentially returns a copy of the data including the provided updates.
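Because update never mutates its input, every container along the selected path has to be rebuilt instead of modified in place. The sketch below illustrates that idea by hand for the t_identity['foo'].bar example, using dataclasses.replace and a dict copy. The helper update_foo_bar is hypothetical and written only for illustration; it is not how Flax implements update.

```python
import dataclasses


@dataclasses.dataclass
class Foo:
    foo: int = 0
    bar: int = 0


def update_foo_bar(fn, data):
    """Hypothetical helper: apply fn to data['foo'].bar without mutating data."""
    # Build a new Foo with the updated `bar`; the old instance is left untouched.
    new_foo = dataclasses.replace(data['foo'], bar=fn(data['foo'].bar))
    # Shallow-copy the outer dict and swap in the rebuilt value.
    new_data = dict(data)
    new_data['foo'] = new_foo
    return new_data


original = {'foo': Foo(bar=2)}
updated = update_foo_bar(lambda x: x + x, original)
print(updated)   # {'foo': Foo(foo=0, bar=4)}
print(original)  # {'foo': Foo(foo=0, bar=2)}
```

Only the objects on the path to the selected value are copied; anything not on that path can be shared between the old and new structures.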
Traversal objects
- class flax.traverse_util.Traversal(*args, **kwargs)
Base class for all traversals.
- abstract iterate(inputs)
Iterate over the values selected by this Traversal.
- Parameters
inputs – the object that should be traversed.
- Returns
An iterator over the traversed values.
- class flax.traverse_util.TraverseId(*args, **kwargs)
The identity Traversal.
- class flax.traverse_util.TraverseMerge(*args, **kwargs)
Merges the selection from a set of traversals.
- class flax.traverse_util.TraverseCompose(*args, **kwargs)
Compose two traversals.
- class flax.traverse_util.TraverseFilter(*args, **kwargs)
Filter selected values based on a predicate.
- class flax.traverse_util.TraverseAttr(*args, **kwargs)
Traverse the attribute of an object.
- class flax.traverse_util.TraverseItem(*args, **kwargs)
Traverse the item of an object.
- class flax.traverse_util.TraverseEach(*args, **kwargs)
Traverse each item of a container.
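The classes above fit together through two operations: iterate selects values and update rebuilds the structure around them. The following is a minimal pure-Python sketch of that design, written only for illustration. The names Trav, Id, Compose, Merge, Filter, Item, and Each are hypothetical stand-ins for Traversal, TraverseId, TraverseCompose, TraverseMerge, TraverseFilter, TraverseItem, and TraverseEach; attribute traversal is omitted, and container handling is simplified (Item assumes a dict or list, Each assumes a list or tuple). It is not Flax's implementation.

```python
import abc
import copy


class Trav(abc.ABC):
    """Hypothetical minimal traversal: `iterate` selects, `update` rebuilds."""

    @abc.abstractmethod
    def iterate(self, x):
        """Yield the selected values of x."""

    @abc.abstractmethod
    def update(self, fn, x):
        """Return a copy of x with fn applied to every selected value."""

    # Method chaining: each helper composes `self` with a new inner traversal.
    def each(self):
        return Compose(self, Each())

    def filter(self, pred):
        return Compose(self, Filter(pred))

    def __getitem__(self, key):
        return Compose(self, Item(key))


class Id(Trav):
    def iterate(self, x):
        yield x

    def update(self, fn, x):
        return fn(x)


class Compose(Trav):
    def __init__(self, outer, inner):
        self.outer, self.inner = outer, inner

    def iterate(self, x):
        # Every value selected by `outer` is traversed further by `inner`.
        for y in self.outer.iterate(x):
            yield from self.inner.iterate(y)

    def update(self, fn, x):
        # Rebuild each outer selection by updating its inner selection.
        return self.outer.update(lambda y: self.inner.update(fn, y), x)


class Merge(Trav):
    def __init__(self, *travs):
        self.travs = travs

    def iterate(self, x):
        for t in self.travs:
            yield from t.iterate(x)

    def update(self, fn, x):
        # Apply each traversal's update in turn to the progressively updated copy.
        for t in self.travs:
            x = t.update(fn, x)
        return x


class Filter(Trav):
    def __init__(self, pred):
        self.pred = pred

    def iterate(self, x):
        if self.pred(x):
            yield x

    def update(self, fn, x):
        return fn(x) if self.pred(x) else x


class Item(Trav):
    def __init__(self, key):
        self.key = key

    def iterate(self, x):
        yield x[self.key]

    def update(self, fn, x):
        new_x = copy.copy(x)  # shallow copy of a dict or list; never mutate x
        new_x[self.key] = fn(x[self.key])
        return new_x


class Each(Trav):
    def iterate(self, x):
        yield from x

    def update(self, fn, x):
        return type(x)(fn(v) for v in x)  # rebuild the list or tuple


data = [{'foo': 1, 'bar': 2}, {'foo': 3, 'bar': 4}]
foos = Id().each()['foo']
print(list(foos.iterate(data)))  # [1, 3]
print(foos.filter(lambda v: v > 1).update(lambda v: -v, data))
# [{'foo': 1, 'bar': 2}, {'foo': -3, 'bar': 4}]
both = Merge(Id()[0]['foo'], Id()[1]['bar'])
print(list(both.iterate(data)))  # [1, 4]
```

In this sketch, composition does most of the work: a composed update hands the outer traversal a function that updates the inner selection, so each container on the path is copied and the input data is never modified.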
Dict utils
- flax.traverse_util.flatten_dict(xs, keep_empty_nodes=False, is_leaf=None, sep=None)
Flatten a nested dictionary.
The nested keys are flattened to a tuple. See unflatten_dict for how to restore the nested dictionary structure.
Example:
>>> from flax.traverse_util import flatten_dict
>>> xs = {'foo': 1, 'bar': {'a': 2, 'b': {}}}
>>> flat_xs = flatten_dict(xs)
>>> flat_xs
{('foo',): 1, ('bar', 'a'): 2}
Note that empty dictionaries are ignored and will not be restored by unflatten_dict.
- Parameters
xs – a nested dictionary
keep_empty_nodes – if True, replaces empty dictionaries with traverse_util.empty_node instead of dropping them.
is_leaf – an optional function that takes the next nested dictionary and nested keys and returns True if the nested dictionary is a leaf (i.e., should not be flattened further).
sep – if specified, the keys of the returned dictionary will be sep-joined strings (if None, the keys will be tuples).
- Returns
The flattened dictionary.
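To show how keep_empty_nodes and sep change the result, here is a small pure-Python sketch of the flattening rule described above. The function flatten_nested and the EMPTY sentinel are hypothetical stand-ins for flatten_dict and traverse_util.empty_node, is_leaf is not modeled, and this is not Flax's source.

```python
EMPTY = object()  # hypothetical stand-in for traverse_util.empty_node


def flatten_nested(xs, keep_empty_nodes=False, sep=None, prefix=()):
    """Hypothetical sketch of nested-dict flattening; not Flax's source."""
    result = {}
    for key, value in xs.items():
        path = prefix + (key,)
        # Tuple keys by default, or sep-joined string keys when sep is given.
        flat_key = path if sep is None else sep.join(path)
        if isinstance(value, dict) and value:
            # Non-empty dict: recurse and merge the flattened sub-dictionary.
            result.update(flatten_nested(value, keep_empty_nodes, sep, path))
        elif isinstance(value, dict):
            # Empty dict: dropped unless keep_empty_nodes asks for a marker.
            if keep_empty_nodes:
                result[flat_key] = EMPTY
        else:
            result[flat_key] = value
    return result


xs = {'foo': 1, 'bar': {'a': 2, 'b': {}}}
print(flatten_nested(xs))           # {('foo',): 1, ('bar', 'a'): 2}
print(flatten_nested(xs, sep='/'))  # {'foo': 1, 'bar/a': 2}
```

With keep_empty_nodes=True, the empty 'b' dictionary survives as the key ('bar', 'b') mapped to the sentinel, which is what lets an inverse function restore it later.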
- flax.traverse_util.unflatten_dict(xs, sep=None)
Unflatten a dictionary.
See flatten_dict.
Example:
>>> from flax.traverse_util import unflatten_dict
>>> flat_xs = {
...     ('foo',): 1,
...     ('bar', 'a'): 2,
... }
>>> xs = unflatten_dict(flat_xs)
>>> xs
{'foo': 1, 'bar': {'a': 2}}
- Parameters
xs – a flattened dictionary
sep – separator (same as used with flatten_dict()).
- Returns
The nested dictionary.
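Unflattening walks each key path and creates intermediate dictionaries on demand. The sketch below is a hypothetical pure-Python illustration of that inverse step, including splitting sep-joined keys and turning an empty-node marker back into an empty dictionary. The names unflatten_nested and EMPTY are assumptions for illustration, not Flax's API.

```python
EMPTY = object()  # hypothetical stand-in for traverse_util.empty_node


def unflatten_nested(flat, sep=None):
    """Hypothetical inverse of flattening: rebuild nested dicts from paths."""
    result = {}
    for path, value in flat.items():
        if sep is not None:
            path = tuple(path.split(sep))  # 'bar/a' -> ('bar', 'a')
        if value is EMPTY:
            value = {}  # restore an empty dict that was kept as a marker
        cursor = result
        for key in path[:-1]:
            # Create intermediate dicts as needed and descend into them.
            cursor = cursor.setdefault(key, {})
        cursor[path[-1]] = value
    return result


print(unflatten_nested({('foo',): 1, ('bar', 'a'): 2}))
# {'foo': 1, 'bar': {'a': 2}}
print(unflatten_nested({'foo': 1, 'bar/a': 2, 'bar/b': EMPTY}, sep='/'))
# {'foo': 1, 'bar': {'a': 2, 'b': {}}}
```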
- flax.traverse_util.path_aware_map(f, nested_dict)
A map function that operates over nested dictionary structures while taking the path to each leaf into account.
Example:
>>> import jax.numpy as jnp
>>> from flax import traverse_util
>>> params = {'a': {'x': 10, 'y': 3}, 'b': {'x': 20}}
>>> f = lambda path, x: x + 5 if 'x' in path else -x
>>> traverse_util.path_aware_map(f, params)
{'a': {'x': 15, 'y': -3}, 'b': {'x': 25}}
- Parameters
f – A callable that takes in (path, value) arguments and maps them to a new value. Here path is a tuple of strings.
nested_dict – A nested dictionary structure.
- Returns
A new nested dictionary structure with the mapped values.
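The essential idea is a recursive map that threads the key path down to each leaf. The following hypothetical pure-Python sketch (map_with_path is an illustrative name, not Flax's API) reproduces the example above and then builds a boolean mask that marks only leaves named 'kernel', the kind of path-based selection one might use for, say, weight decay.

```python
def map_with_path(f, nested, prefix=()):
    """Hypothetical sketch: apply f(path, leaf) to every non-dict leaf."""
    return {
        key: (map_with_path(f, value, prefix + (key,))
              if isinstance(value, dict)
              else f(prefix + (key,), value))
        for key, value in nested.items()
    }


def f(path, x):
    # Same rule as the example above: shift leaves on an 'x' path, negate others.
    return x + 5 if 'x' in path else -x


params = {'a': {'x': 10, 'y': 3}, 'b': {'x': 20}}
print(map_with_path(f, params))  # {'a': {'x': 15, 'y': -3}, 'b': {'x': 25}}

weights = {'dense': {'kernel': 2.0, 'bias': 1.0}, 'out': {'kernel': 4.0}}
mask = map_with_path(lambda path, _: path[-1] == 'kernel', weights)
print(mask)  # {'dense': {'kernel': True, 'bias': False}, 'out': {'kernel': True}}
```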
Model parameter traversal
- class flax.traverse_util.ModelParamTraversal(*args, **kwargs)
Select model parameters using a name filter.
This traversal operates on a nested dictionary of parameters and selects a subset based on the filter_fn argument.
See flax.optim.MultiOptimizer for an example of how to use ModelParamTraversal to update subsets of the parameter tree with a specific optimizer.
- __init__(filter_fn)
Constructs a new ModelParamTraversal.
- Parameters
filter_fn – a function that takes a parameter’s full name and its value and returns whether this parameter should be selected or not. The name of a parameter is determined by the module hierarchy and the parameter name (for example: ‘/module/sub_module/parameter_name’).
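A filter_fn receives a '/'-separated parameter name and the parameter's value. The hypothetical pure-Python sketch below (select_params is an illustrative helper, not ModelParamTraversal itself) builds such names from a nested parameter dictionary and applies a filter_fn that keeps only kernels; the parameter names in the example data are made up. A function with the same (name, value) shape is what the constructor expects.

```python
def select_params(params, filter_fn, prefix=()):
    """Hypothetical sketch: yield (name, value) pairs that filter_fn selects."""
    for key in sorted(params):  # sorted only to make the order deterministic
        value = params[key]
        path = prefix + (key,)
        if isinstance(value, dict):
            yield from select_params(value, filter_fn, path)
        else:
            name = '/' + '/'.join(path)  # e.g. '/Dense_0/kernel'
            if filter_fn(name, value):
                yield name, value


def kernels_only(name, value):
    # Select a parameter purely by its name; the value is ignored here.
    return name.endswith('/kernel')


params = {'Dense_0': {'kernel': 1.0, 'bias': 0.5},
          'Dense_1': {'kernel': 2.0, 'bias': 0.1}}
kernels = dict(select_params(params, kernels_only))
print(kernels)  # {'/Dense_0/kernel': 1.0, '/Dense_1/kernel': 2.0}
```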