Variable dictionary
A variable dict is a normal Python dictionary that acts as a container for one
or more “variable collections”. Each collection is a nested dictionary whose
leaves are jax.numpy arrays.
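The different variable collections share the same nested tree structure: each is keyed by submodule name, although a collection only contains entries for the submodules that have variables in it.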
For example, consider the following variable dictionary:
{
  "params": {
    "Conv1": { "weight": ..., "bias": ... },
    "BatchNorm1": { "scale": ..., "mean": ... },
    "Conv2": {...}
  },
  "batch_stats": {
    "BatchNorm1": { "moving_mean": ..., "moving_average": ... }
  }
}
In this case, the "BatchNorm1" key lives in both the "params" and "batch_stats" collections. This reflects the fact that the submodule named "BatchNorm1" has both trainable parameters (the "params" collection) and other non-trainable variables (the "batch_stats" collection).
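The layout described above can be explored with ordinary dictionary operations. The following is a minimal sketch in plain Python, not part of the Flax API: the float lists are placeholder leaves standing in for jax.numpy arrays, the elided "Conv2" entry is left out, and the helper names `collections_for` and `split_params` are hypothetical.

```python
# Trimmed-down copy of the variable dict from the example above.
# Assumption: the float lists are placeholders for jax.numpy arrays so that
# the sketch runs without JAX installed.
variables = {
    "params": {
        "Conv1": {"weight": [0.1, 0.2], "bias": [0.0]},
        "BatchNorm1": {"scale": [1.0], "mean": [0.0]},
    },
    "batch_stats": {
        "BatchNorm1": {"moving_mean": [0.0], "moving_average": [1.0]},
    },
}


def collections_for(variables, module_name):
    """Return the names of the collections that hold entries for `module_name`."""
    return [col for col, tree in variables.items() if module_name in tree]


def split_params(variables):
    """Separate the trainable "params" collection from all other collections."""
    params = variables["params"]
    state = {col: tree for col, tree in variables.items() if col != "params"}
    return params, state


# "BatchNorm1" appears in both collections; "Conv1" only has parameters.
print(collections_for(variables, "BatchNorm1"))
print(collections_for(variables, "Conv1"))

params, state = split_params(variables)
print(sorted(params))
print(sorted(state))
```

Splitting the dict this way mirrors how the "params" collection is typically handed to an optimizer while the remaining collections are carried along as non-trainable state.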
TODO: Make “variable dict” design note, and link to it from here.
- class flax.linen.Variable(scope, collection, name, unbox)
A Variable object allows mutable access to a variable in a VariableDict.
Variables are identified by a collection (e.g., “batch_stats”) and a name (e.g., “moving_mean”). The value property gives access to the variable’s content and can be assigned to for mutation.