flax.cursor package
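The Cursor API lets you apply mutation-style updates to pytrees. You write changes as if the object were mutable, and the Cursor produces an updated copy while the original stays untouched. Compared to making many nested dataclasses.replace calls, this is a more ergonomic way to make partial updates to deeply nested immutable data structures.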
To illustrate, consider the example below:
>>> from flax.cursor import cursor
>>> import dataclasses
>>> from typing import Any
>>> @dataclasses.dataclass(frozen=True)
... class A:
...   x: Any
>>> a = A(A(A(A(A(A(A(0)))))))
To replace the int 0 with 1 using dataclasses.replace, we would have to write many nested calls:
>>> a2 = dataclasses.replace(
...     a,
...     x=dataclasses.replace(
...         a.x,
...         x=dataclasses.replace(
...             a.x.x,
...             x=dataclasses.replace(
...                 a.x.x.x,
...                 x=dataclasses.replace(
...                     a.x.x.x.x,
...                     x=dataclasses.replace(
...                         a.x.x.x.x.x,
...                         x=dataclasses.replace(a.x.x.x.x.x.x, x=1),
...                     ),
...                 ),
...             ),
...         ),
...     ),
... )
The equivalent can be achieved much more simply using the Cursor API:
>>> a3 = cursor(a).x.x.x.x.x.x.x.set(1)
>>> assert a2 == a3
The Cursor object keeps track of changes made to it, and when .build is called, it generates a new object with the accumulated changes. Basic usage involves wrapping the object in a Cursor, making changes to the Cursor object, and generating a new copy of the original object with the accumulated changes.
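Conceptually, a single .set call only has to rebuild the objects along the path from the changed leaf back up to the root. The following is a minimal pure-Python sketch of that idea for frozen dataclasses only. replace_at_path is a hypothetical helper (not part of flax.cursor), and the path is assumed to be a list of attribute names. This is not how Flax implements Cursor.

```python
import dataclasses
from typing import Any, Sequence


@dataclasses.dataclass(frozen=True)
class A:
  x: Any


def replace_at_path(obj: Any, path: Sequence[str], value: Any) -> Any:
  """Hypothetical helper: return a copy of obj with the attribute at path set to value.

  Each level on the path is rebuilt with dataclasses.replace, so the
  original object is never mutated.
  """
  if not path:
    # Reached the target position: the new value replaces the old leaf.
    return value
  head, *rest = path
  child = getattr(obj, head)
  return dataclasses.replace(obj, **{head: replace_at_path(child, rest, value)})


a = A(A(A(A(A(A(A(0)))))))
# Equivalent to the seven nested dataclasses.replace calls above.
a3 = replace_at_path(a, ['x'] * 7, 1)
```

Unlike this helper, the Cursor API also handles dicts, lists, tuples and other pytree nodes, and it lets you record several changes before a single .build call.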
- flax.cursor.cursor(obj)
Wrap Cursor over obj and return it. Changes can then be applied to the Cursor object in the following ways:
  - single-line change via the .set method
  - multiple changes, and then calling the .build method
  - multiple changes conditioned on the pytree path and node value via the .apply_update method, and then calling the .build method
.set example:
>>> from flax.cursor import cursor
>>> dict_obj = {'a': 1, 'b': (2, 3), 'c': [4, 5]}
>>> modified_dict_obj = cursor(dict_obj)['b'][0].set(10)
>>> assert modified_dict_obj == {'a': 1, 'b': (10, 3), 'c': [4, 5]}
.build example:
>>> from flax.cursor import cursor
>>> dict_obj = {'a': 1, 'b': (2, 3), 'c': [4, 5]}
>>> c = cursor(dict_obj)
>>> c['b'][0] = 10
>>> c['a'] = (100, 200)
>>> modified_dict_obj = c.build()
>>> assert modified_dict_obj == {'a': (100, 200), 'b': (10, 3), 'c': [4, 5]}
.apply_update example:
>>> from flax.cursor import cursor
>>> from flax.training import train_state
>>> import optax
>>> def update_fn(path, value):
...   '''Replace params with empty dictionary.'''
...   if 'params' in path:
...     return {}
...   return value
>>> state = train_state.TrainState.create(
...     apply_fn=lambda x: x,
...     params={'a': 1, 'b': 2},
...     tx=optax.adam(1e-3),
... )
>>> c = cursor(state)
>>> state2 = c.apply_update(update_fn).build()
>>> assert state2.params == {}
>>> assert state.params == {'a': 1, 'b': 2}  # make sure original params are unchanged
If the underlying obj is a list or tuple, iterating over the Cursor object to get the child Cursors is also possible:
>>> from flax.cursor import cursor
>>> c = cursor(((1, 2), (3, 4)))
>>> for child_c in c:
...   child_c[1] *= -1
>>> assert c.build() == ((1, -2), (3, -4))
View the docstrings for each method to see more examples of their usage.
- Parameters
obj – the object you want to wrap the Cursor in
- Returns
A Cursor object wrapped around obj.
- class flax.cursor.Cursor(obj, parent_key)
- apply_update(update_fn)
Traverse the Cursor object and record conditional changes recursively via an update_fn. The changes are recorded in the Cursor object’s ._changes dictionary. To generate a copy of the original object with the accumulated changes, call the .build method after calling .apply_update.
The update_fn has a function signature of (str, Any) -> Any:
  - The input arguments are the current key path (in the form of a string delimited by '/') and the value at that current key path
  - The output is the new value (either modified by the update_fn or the same as the input value if the condition wasn’t fulfilled)
Note
  - If the update_fn returns a modified value, this method will not recurse any further down that branch to record changes. For example, if we intend to replace an attribute that points to a dictionary with an int, we don’t need to look for further changes inside the dictionary, since the dictionary will be replaced anyway.
  - The is operator is used to determine whether the return value is modified (by comparing it to the input value). Therefore, if the update_fn modifies a mutable container (e.g. lists, dicts, etc.) in place and returns the same container, .apply_update will treat the returned value as unmodified, since it has the same id. To avoid this, return a copy of the modified value.
  - .apply_update WILL NOT call the update_fn on the value at the top-most level of the pytree (i.e. the root node). The update_fn will first be called on the root node’s children, and then the pytree traversal will continue recursively from there.
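The traversal rules in the note above can be sketched in plain Python for nested dicts. This is an illustration under stated assumptions, not Flax’s implementation: apply_update_sketch is a hypothetical name, only dicts are traversed, and key paths are assumed to be the keys joined with '/'. The sketch shows three things. The root is never passed to the update function, modification is detected with is, and a modified branch is not descended into.

```python
from typing import Any, Callable, Dict, List

UpdateFn = Callable[[str, Any], Any]


def apply_update_sketch(tree: Dict[str, Any], update_fn: UpdateFn, prefix: str = '') -> Dict[str, Any]:
  """Hypothetical, dict-only version of the apply_update traversal rules."""
  new_tree = {}
  for key, value in tree.items():
    # Key paths are the keys joined with '/', e.g. 'Dense_0/kernel'.
    path = f'{prefix}/{key}' if prefix else str(key)
    new_value = update_fn(path, value)
    if new_value is not value:
      # Treated as modified (identity check): keep it and skip its subtree.
      new_tree[key] = new_value
    elif isinstance(value, dict):
      # Unmodified container: keep searching inside it.
      new_tree[key] = apply_update_sketch(value, update_fn, path)
    else:
      new_tree[key] = value
  return new_tree


params = {'Dense_0': {'kernel': 3, 'bias': 1}, 'Dense_1': {'kernel': 5, 'bias': 2}}
seen: List[str] = []


def drop_dense_1(path: str, value: Any) -> Any:
  seen.append(path)
  # Returning a new object replaces the whole 'Dense_1' branch.
  return {} if path == 'Dense_1' else value


new_params = apply_update_sketch(params, drop_dense_1)
```

Because drop_dense_1 returns a fresh {} for 'Dense_1', the paths 'Dense_1/kernel' and 'Dense_1/bias' are never visited. The root dict itself never appears in seen.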
Example:
>>> import flax.linen as nn
>>> from flax.cursor import cursor
>>> import jax, jax.numpy as jnp
>>> class Model(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(3)(x)
...     x = nn.relu(x)
...     x = nn.Dense(3)(x)
...     x = nn.relu(x)
...     x = nn.Dense(3)(x)
...     x = nn.relu(x)
...     return x
>>> params = Model().init(jax.random.key(0), jnp.empty((1, 2)))['params']
>>> def update_fn(path, value):
...   '''Multiply all dense kernel params by 2 and add 1.
...   Subtract 1 from the Dense_1 bias param.'''
...   if 'kernel' in path:
...     return value * 2 + 1
...   elif 'Dense_1' in path and 'bias' in path:
...     return value - 1
...   return value
>>> c = cursor(params)
>>> new_params = c.apply_update(update_fn).build()
>>> for layer in ('Dense_0', 'Dense_1', 'Dense_2'):
...   assert (new_params[layer]['kernel'] == 2 * params[layer]['kernel'] + 1).all()
...   if layer == 'Dense_1':
...     assert (new_params[layer]['bias'] == params[layer]['bias'] - 1).all()
...   else:
...     assert (new_params[layer]['bias'] == params[layer]['bias']).all()
>>> assert jax.tree_util.tree_all(
...     jax.tree_util.tree_map(
...         lambda x, y: (x == y).all(),
...         params,
...         Model().init(jax.random.key(0), jnp.empty((1, 2)))[
...             'params'
...         ],
...     )
... )  # make sure original params are unchanged
- Parameters
update_fn – the function that will conditionally record changes to the Cursor object
- Returns
The current Cursor object with the recorded conditional changes specified by the update_fn. To generate a copy of the original object with the accumulated changes, call the .build method after calling .apply_update.
- build()
Create and return a copy of the original object with accumulated changes. This method is to be called after making changes to the Cursor object.
Note
The new object is built bottom-up: the changes are first applied to the leaf nodes, then to their parents, and so on all the way up to the root.
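A post-order recursion is one way to get this bottom-up behavior, because each child subtree is finished before its parent container is reconstructed. The sketch below is hypothetical. build_sketch and the ('set', value) / ('nested', sub_changes) change-record format are assumptions made for illustration, and they do not reflect the layout of the Cursor’s internal ._changes dictionary. The sketch supports only dicts, lists and tuples, and it replays the dict part of the example that follows.

```python
from typing import Any, Dict, Tuple

# Hypothetical change record: child key -> ('set', new_value) or ('nested', sub_changes).
Changes = Dict[Any, Tuple[str, Any]]


def build_sketch(obj: Any, changes: Changes) -> Any:
  """Return a rebuilt copy of obj; children are rebuilt before their parent."""
  if isinstance(obj, dict):
    children = dict(obj)
  elif isinstance(obj, (list, tuple)):
    children = dict(enumerate(obj))
  else:
    raise TypeError(f'unsupported node type: {type(obj).__name__}')

  for key, (kind, payload) in changes.items():
    if kind == 'set':
      # Direct assignment recorded at this level.
      children[key] = payload
    else:
      # Recurse first: the child subtree is finished before the parent is rebuilt.
      children[key] = build_sketch(obj[key], payload)

  if isinstance(obj, dict):
    return children
  # Reconstruct the sequence with its original type (list or tuple).
  return type(obj)(children[i] for i in range(len(obj)))


dict_obj = {'a': 1, 'b': (2, 3), 'c': [4, 5]}
changes = {'b': ('nested', {0: ('set', 10)}), 'a': ('set', (100, 200))}
modified_dict_obj = build_sketch(dict_obj, changes)
```

In this sketch, unchanged children such as the list under 'c' are shared with the original rather than copied.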
Example:
>>> from flax.cursor import cursor
>>> from flax.training import train_state
>>> import optax
>>> dict_obj = {'a': 1, 'b': (2, 3), 'c': [4, 5]}
>>> c = cursor(dict_obj)
>>> c['b'][0] = 10
>>> c['a'] = (100, 200)
>>> modified_dict_obj = c.build()
>>> assert modified_dict_obj == {'a': (100, 200), 'b': (10, 3), 'c': [4, 5]}
>>> state = train_state.TrainState.create(
...     apply_fn=lambda x: x,
...     params=dict_obj,
...     tx=optax.adam(1e-3),
... )
>>> new_fn = lambda x: x + 1
>>> c = cursor(state)
>>> c.params['b'][1] = 10
>>> c.apply_fn = new_fn
>>> modified_state = c.build()
>>> assert modified_state.params == {'a': 1, 'b': (2, 10), 'c': [4, 5]}
>>> assert modified_state.apply_fn == new_fn
- Returns
A copy of the original object with the accumulated changes.
- find(cond_fn)
Traverse the Cursor object and return a child Cursor object that fulfills the condition in the cond_fn. The cond_fn has a function signature of (str, Any) -> bool:
  - The input arguments are the current key path (in the form of a string delimited by '/') and the value at that current key path
  - The output is a boolean, denoting whether to return the child Cursor object at this path
Raises a CursorFindError if no object or more than one object is found that fulfills the condition of the cond_fn. We raise an error because the user should always expect this method to return the only object whose corresponding key path and value fulfill the condition of the cond_fn.
Note
  - If the cond_fn evaluates to True at a particular key path, this method will not recurse any further down that branch; i.e. this method will find and return the “earliest” child node that fulfills the condition in cond_fn in a particular key path.
  - find WILL NOT search the value at the top-most level of the pytree (i.e. the root node). The cond_fn will be evaluated recursively, starting at the root node’s children.
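The matching rules for find can be sketched in plain Python for nested dicts. In this hypothetical sketch, find_sketch and FindError (a stand-in for CursorFindError) are illustrative names, not Flax APIs, and key paths are assumed to be '/'-joined keys. The root is never tested, a matching node’s subtree is skipped, and anything other than exactly one match raises an error.

```python
from typing import Any, Callable, Dict, List, Tuple

CondFn = Callable[[str, Any], bool]


class FindError(Exception):
  """Hypothetical stand-in for CursorFindError."""


def find_sketch(tree: Dict[str, Any], cond_fn: CondFn) -> Tuple[str, Any]:
  """Return the single earliest (path, value) that satisfies cond_fn."""
  matches: List[Tuple[str, Any]] = []

  def visit(node: Dict[str, Any], prefix: str) -> None:
    # The root dict is never tested; only its children and their descendants are.
    for key, value in node.items():
      path = f'{prefix}/{key}' if prefix else str(key)
      if cond_fn(path, value):
        matches.append((path, value))  # match: stop descending into this branch
      elif isinstance(value, dict):
        visit(value, path)

  visit(tree, '')
  if len(matches) != 1:
    raise FindError(f'expected exactly one match, found {len(matches)}')
  return matches[0]


params = {'Dense_0': {'kernel': 1, 'bias': 2}, 'Dense_1': {'kernel': 3, 'bias': 4}}

# 'Dense_1' matches first, so 'Dense_1/kernel' and 'Dense_1/bias' are never tested.
path, value = find_sketch(params, lambda p, v: 'Dense_1' in p)

# Two branches match, so the single-result contract is violated.
try:
  find_sketch(params, lambda p, v: 'Dense' in p)
  ambiguous_raised = False
except FindError:
  ambiguous_raised = True
```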
Example:
>>> import flax.linen as nn
>>> from flax.cursor import cursor
>>> import jax, jax.numpy as jnp
>>> class Model(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(3)(x)
...     x = nn.relu(x)
...     x = nn.Dense(3)(x)
...     x = nn.relu(x)
...     x = nn.Dense(3)(x)
...     x = nn.relu(x)
...     return x
>>> params = Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))['params']
>>> def cond_fn(path, value):
...   '''Find the second dense layer params.'''
...   return 'Dense_1' in path
>>> new_params = cursor(params).find(cond_fn)['bias'].set(params['Dense_1']['bias'] + 1)
>>> for layer in ('Dense_0', 'Dense_1', 'Dense_2'):
...   if layer == 'Dense_1':
...     assert (new_params[layer]['bias'] == params[layer]['bias'] + 1).all()
...   else:
...     assert (new_params[layer]['bias'] == params[layer]['bias']).all()
>>> c = cursor(params)
>>> c2 = c.find(cond_fn)
>>> c2['kernel'] += 2
>>> c2['bias'] += 2
>>> new_params = c.build()
>>> for layer in ('Dense_0', 'Dense_1', 'Dense_2'):
...   if layer == 'Dense_1':
...     assert (new_params[layer]['kernel'] == params[layer]['kernel'] + 2).all()
...     assert (new_params[layer]['bias'] == params[layer]['bias'] + 2).all()
...   else:
...     assert (new_params[layer]['kernel'] == params[layer]['kernel']).all()
...     assert (new_params[layer]['bias'] == params[layer]['bias']).all()
>>> assert jax.tree_util.tree_all(
...     jax.tree_util.tree_map(
...         lambda x, y: (x == y).all(),
...         params,
...         Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))[
...             'params'
...         ],
...     )
... )  # make sure original params are unchanged
- Parameters
cond_fn – the function that will conditionally find child Cursor objects
- Returns
A child Cursor object that fulfills the condition in the cond_fn.
- find_all(cond_fn)
Traverse the Cursor object and return a generator of child Cursor objects that fulfill the condition in the cond_fn. The cond_fn has a function signature of (str, Any) -> bool:
  - The input arguments are the current key path (in the form of a string delimited by '/') and the value at that current key path
  - The output is a boolean, denoting whether to return the child Cursor object at this path
Note
  - If the cond_fn evaluates to True at a particular key path, this method will not recurse any further down that branch; i.e. this method will find and return the “earliest” child nodes that fulfill the condition in cond_fn in a particular key path.
  - find_all WILL NOT search the value at the top-most level of the pytree (i.e. the root node). The cond_fn will be evaluated recursively, starting at the root node’s children.
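find_all follows the same pruning rule as find but returns a generator instead of insisting on a single match. A hypothetical dict-only sketch makes the “earliest node” behavior visible when one match is nested inside another. find_all_sketch is an illustrative name, and '/'-joined key paths are an assumption.

```python
from typing import Any, Callable, Dict, Iterator, Tuple

CondFn = Callable[[str, Any], bool]


def find_all_sketch(tree: Dict[str, Any], cond_fn: CondFn, prefix: str = '') -> Iterator[Tuple[str, Any]]:
  """Lazily yield (path, value) for every earliest match of cond_fn.

  The root is not tested, and once a node matches, its descendants are skipped.
  """
  for key, value in tree.items():
    path = f'{prefix}/{key}' if prefix else str(key)
    if cond_fn(path, value):
      yield path, value
    elif isinstance(value, dict):
      yield from find_all_sketch(value, cond_fn, path)


tree = {
    'Dense_0': {'bias': 1},
    'Block_0': {'Dense_0': {'bias': 2}},
    'Dense_1': {'Dense_2': {'bias': 3}},
}
matches = list(find_all_sketch(tree, lambda p, v: 'Dense' in p))
```

'Dense_1' matches, so the nested 'Dense_1/Dense_2' is never yielded. 'Block_0' does not match, so the search continues inside it and finds 'Block_0/Dense_0'.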
Example:
>>> import flax.linen as nn
>>> from flax.cursor import cursor
>>> import jax, jax.numpy as jnp
>>> class Model(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(3)(x)
...     x = nn.relu(x)
...     x = nn.Dense(3)(x)
...     x = nn.relu(x)
...     x = nn.Dense(3)(x)
...     x = nn.relu(x)
...     return x
>>> params = Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))['params']
>>> def cond_fn(path, value):
...   '''Find all dense layer params.'''
...   return 'Dense' in path
>>> c = cursor(params)
>>> for dense_params in c.find_all(cond_fn):
...   dense_params['bias'] += 1
>>> new_params = c.build()
>>> for layer in ('Dense_0', 'Dense_1', 'Dense_2'):
...   assert (new_params[layer]['bias'] == params[layer]['bias'] + 1).all()
>>> assert jax.tree_util.tree_all(
...     jax.tree_util.tree_map(
...         lambda x, y: (x == y).all(),
...         params,
...         Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))[
...             'params'
...         ],
...     )
... )  # make sure original params are unchanged
- Parameters
cond_fn – the function that will conditionally find child Cursor objects
- Returns
A generator of child Cursor objects that fulfill the condition in the cond_fn.
- set(value)
Set a new value for an attribute, property, element or entry in the Cursor object and return a copy of the original object, containing the new set value.
Example:
>>> from flax.cursor import cursor
>>> from flax.training import train_state
>>> import optax
>>> dict_obj = {'a': 1, 'b': (2, 3), 'c': [4, 5]}
>>> modified_dict_obj = cursor(dict_obj)['b'][0].set(10)
>>> assert modified_dict_obj == {'a': 1, 'b': (10, 3), 'c': [4, 5]}
>>> state = train_state.TrainState.create(
...     apply_fn=lambda x: x,
...     params=dict_obj,
...     tx=optax.adam(1e-3),
... )
>>> modified_state = cursor(state).params['b'][1].set(10)
>>> assert modified_state.params == {'a': 1, 'b': (2, 10), 'c': [4, 5]}
- Parameters
value – the value used to set an attribute, property, element or entry in the Cursor object
- Returns
A copy of the original object with the new set value.