Model surgery#
Usually, Flax modules and optimizers track and update the params for you. But there may be times when you want to do some model surgery and tweak the param tensors yourself. This guide shows you how to do that.
Setup#
!pip install --upgrade -q pip jax jaxlib flax
import functools
import jax
import jax.numpy as jnp
from flax import traverse_util
from flax import linen as nn
from flax.core import freeze
import optax
Surgery with Flax Modules#
Let’s create a small convolutional neural network model for our demo.
As usual, you can run CNN.init(...)['params']
to get the params, which you then pass into the model and update at every step of your training.
class CNN(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    x = nn.log_softmax(x)
    return x

def get_initial_params(key):
  init_shape = jnp.ones((1, 28, 28, 1), jnp.float32)
  initial_params = CNN().init(key, init_shape)['params']
  return initial_params
key = jax.random.key(0)
params = get_initial_params(key)
jax.tree_util.tree_map(jnp.shape, params)
{'Conv_0': {'bias': (32,), 'kernel': (3, 3, 1, 32)},
'Conv_1': {'bias': (64,), 'kernel': (3, 3, 32, 64)},
'Dense_0': {'bias': (256,), 'kernel': (3136, 256)},
'Dense_1': {'bias': (10,), 'kernel': (256, 10)}}
Note that what is returned as params
is a FrozenDict
, which contains the kernel and bias JAX arrays of each layer.
A FrozenDict
is nothing more than a read-only dict, and Flax made it read-only because of the functional nature of JAX: JAX arrays are immutable, so new params
need to replace, rather than mutate, the old params
. Making the dict read-only ensures that no in-place mutation of the dict can happen accidentally during training and updating.
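To make that concrete, here is a minimal sketch of the read-only behavior using a toy FrozenDict (the names toy and mutable_toy are just for illustration and are not part of the model above):
from flax.core import FrozenDict, freeze, unfreeze

toy = FrozenDict({'w': jnp.zeros((3,))})
try:
  toy['w'] = jnp.ones((3,))  # item assignment on a FrozenDict raises an error
except Exception as e:
  print(type(e).__name__, e)

# The only way to "change" it is to build a mutable copy and re-freeze it
# once you are done.
mutable_toy = unfreeze(toy)
mutable_toy['w'] = jnp.ones((3,))
toy = freeze(mutable_toy)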
One way to actually modify the params outside of a Flax module is to explicitly flatten them into a mutable dict. Note that you can pass a separator sep
to join all nested keys into a single string key. If no sep
is given, each key will instead be a tuple of the nested keys (see the sketch after the output below).
# Get a flattened key-value list.
flat_params = traverse_util.flatten_dict(params, sep='/')
jax.tree_util.tree_map(jnp.shape, flat_params)
{'Conv_0/bias': (32,),
'Conv_0/kernel': (3, 3, 1, 32),
'Conv_1/bias': (64,),
'Conv_1/kernel': (3, 3, 32, 64),
'Dense_0/bias': (256,),
'Dense_0/kernel': (3136, 256),
'Dense_1/bias': (10,),
'Dense_1/kernel': (256, 10)}
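And, as noted above, a minimal sketch of flattening without sep, where the keys come back as tuples (the exact key order shown in the comment is only indicative):
# Without a separator, the flattened keys are tuples of the nested keys.
flat_params_tuples = traverse_util.flatten_dict(params)
list(flat_params_tuples.keys())[:2]
# e.g. [('Conv_0', 'bias'), ('Conv_0', 'kernel')]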
Now you can do whatever you want with the params. When you are done, unflatten them back and use them for future training.
# Modify a layer: here, normalize the kernel of the last Dense layer.
dense_kernel = flat_params['Dense_1/kernel']
flat_params['Dense_1/kernel'] = dense_kernel / jnp.linalg.norm(dense_kernel)
# Unflatten.
unflat_params = traverse_util.unflatten_dict(flat_params, sep='/')
# Refreeze.
unflat_params = freeze(unflat_params)
jax.tree_util.tree_map(jnp.shape, unflat_params)
FrozenDict({
Conv_0: {
bias: (32,),
kernel: (3, 3, 1, 32),
},
Conv_1: {
bias: (64,),
kernel: (3, 3, 32, 64),
},
Dense_0: {
bias: (256,),
kernel: (3136, 256),
},
Dense_1: {
bias: (10,),
kernel: (256, 10),
},
})
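As a quick sanity check, the modified params still plug straight back into the module (a minimal sketch; dummy_batch and logits are just illustrative names):
# The surgically modified params still work with the original CNN module.
dummy_batch = jnp.ones((1, 28, 28, 1), jnp.float32)
logits = CNN().apply({'params': unflat_params}, dummy_batch)
logits.shape  # (1, 10)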
Surgery with Optimizers#
When using Optax
as an optimizer, the opt_state
is actually a nested tuple
of the states of the individual gradient transformations that compose the optimizer.
These states contain pytrees that mirror the parameter tree, and they can be modified
the same way: flatten, modify, unflatten, and then recreate an
optimizer state with the same structure as the original.
tx = optax.adam(1.0)
opt_state = tx.init(params)
# The optimizer state is a tuple of gradient transformation states.
jax.tree_util.tree_map(jnp.shape, opt_state)
(ScaleByAdamState(count=(), mu={'Conv_0': {'bias': (32,), 'kernel': (3, 3, 1, 32)}, 'Conv_1': {'bias': (64,), 'kernel': (3, 3, 32, 64)}, 'Dense_0': {'bias': (256,), 'kernel': (3136, 256)}, 'Dense_1': {'bias': (10,), 'kernel': (256, 10)}}, nu={'Conv_0': {'bias': (32,), 'kernel': (3, 3, 1, 32)}, 'Conv_1': {'bias': (64,), 'kernel': (3, 3, 32, 64)}, 'Dense_0': {'bias': (256,), 'kernel': (3136, 256)}, 'Dense_1': {'bias': (10,), 'kernel': (256, 10)}}),
EmptyState())
The pytrees inside the optimizer state follow the same structure as the parameters and can be flattened and modified in exactly the same way.
flat_mu = traverse_util.flatten_dict(opt_state[0].mu, sep='/')
flat_nu = traverse_util.flatten_dict(opt_state[0].nu, sep='/')
jax.tree_util.tree_map(jnp.shape, flat_mu)
{'Conv_0/bias': (32,),
'Conv_0/kernel': (3, 3, 1, 32),
'Conv_1/bias': (64,),
'Conv_1/kernel': (3, 3, 32, 64),
'Dense_0/bias': (256,),
'Dense_0/kernel': (3136, 256),
'Dense_1/bias': (10,),
'Dense_1/kernel': (256, 10)}
After modification, re-create the optimizer state and use it for future training.
opt_state = (
    opt_state[0]._replace(
        mu=traverse_util.unflatten_dict(flat_mu, sep='/'),
        nu=traverse_util.unflatten_dict(flat_nu, sep='/'),
    ),
) + opt_state[1:]
jax.tree_util.tree_map(jnp.shape, opt_state)
(ScaleByAdamState(count=(), mu={'Conv_0': {'bias': (32,), 'kernel': (3, 3, 1, 32)}, 'Conv_1': {'bias': (64,), 'kernel': (3, 3, 32, 64)}, 'Dense_0': {'bias': (256,), 'kernel': (3136, 256)}, 'Dense_1': {'bias': (10,), 'kernel': (256, 10)}}, nu={'Conv_0': {'bias': (32,), 'kernel': (3, 3, 1, 32)}, 'Conv_1': {'bias': (64,), 'kernel': (3, 3, 32, 64)}, 'Dense_0': {'bias': (256,), 'kernel': (3136, 256)}, 'Dense_1': {'bias': (10,), 'kernel': (256, 10)}}),
EmptyState())
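To tie it together, here is a minimal sketch of resuming training with the surgically modified params and optimizer state. loss_fn is a hypothetical scalar loss over the params and is not defined in this guide; everything is kept as plain dicts so that grads, params, and the rebuilt optimizer state share the same pytree structure:
from flax.core import unfreeze

# loss_fn is a hypothetical scalar loss over the params (not defined here).
params_ = unfreeze(unflat_params)
grads = jax.grad(loss_fn)(params_)
updates, opt_state = tx.update(grads, opt_state, params_)
params_ = optax.apply_updates(params_, updates)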