Extracting intermediate values#
This guide will show you how to extract intermediate values from a module.
Let’s start with this simple CNN that uses nn.compact.
from flax import linen as nn
import jax
import jax.numpy as jnp
from typing import Sequence
class CNN(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    x = nn.log_softmax(x)
    return x
Because this module uses nn.compact, we don’t have direct access to intermediate values. There are a few ways to expose them:
Store intermediate values in a new variable collection#
The CNN can be augmented with calls to sow to store intermediates as follows:
class SowCNN(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    self.sow('intermediates', 'features', x)
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    x = nn.log_softmax(x)
    return x
sow acts as a no-op when the variable collection is not mutable. Therefore, it works perfectly for debugging and optional tracking of intermediates. The 'intermediates' collection is also used by the capture_intermediates API (see the Use capture_intermediates section).
Note that, by default, sow appends values every time it is called. This is necessary because, once instantiated, a module can be called multiple times in its parent module, and we want to capture all the sowed values. Therefore, make sure that you do not feed intermediate values back into variables; otherwise, every call will increase the length of that tuple and trigger a recompile. To override the default append behavior, specify init_fn and reduce_fn - see Module.sow().
class SowCNN2(nn.Module):
  @nn.compact
  def __call__(self, x):
    mod = SowCNN(name='SowCNN')
    return mod(x) + mod(x)  # Calling the same module instance twice.

@jax.jit
def init(key, x):
  variables = SowCNN2().init(key, x)
  # By default the 'intermediates' collection is not mutable during init,
  # so variables will only contain 'params' here.
  return variables

@jax.jit
def predict(variables, x):
  # If mutable='intermediates' is not specified, then .sow() acts as a no-op.
  output, mod_vars = SowCNN2().apply(variables, x, mutable='intermediates')
  features = mod_vars['intermediates']['SowCNN']['features']
  return output, features

batch = jnp.ones((1, 28, 28, 1))
variables = init(jax.random.key(0), batch)
preds, feats = predict(variables, batch)
assert len(feats) == 2  # Tuple with two values since the module was called twice.
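If appending is not what you want, the default behavior can be overridden as mentioned above. The following is a minimal sketch under the assumption that only the most recent value matters (SowLastCNN and its layer sizes are made up for illustration); init_fn provides the initial carry and reduce_fn combines that carry with each new value:

class SowLastCNN(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(features=4)(x)
    # Replace the stored value on every call instead of appending to a tuple.
    self.sow('intermediates', 'features', x,
             init_fn=lambda: None,
             reduce_fn=lambda prev, cur: cur)
    return nn.Dense(features=2)(x)

sow_vars = SowLastCNN().init(jax.random.key(0), jnp.ones((1, 3)))
out, state = SowLastCNN().apply(sow_vars, jnp.ones((1, 3)),
                                mutable='intermediates')
state['intermediates']['features']  # a single array rather than a tuple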
Refactor module into submodules#
This is a useful pattern for cases where it’s clear in what particular way you want to split your submodules. Any submodule you expose in setup can be used directly. In the limit, you can define all submodules in setup and avoid using nn.compact altogether.
class RefactoredCNN(nn.Module):
  def setup(self):
    self.features = Features()
    self.classifier = Classifier()

  def __call__(self, x):
    x = self.features(x)
    x = self.classifier(x)
    return x

class Features(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    return x

class Classifier(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    x = nn.log_softmax(x)
    return x

@jax.jit
def init(key, x):
  variables = RefactoredCNN().init(key, x)
  return variables['params']

@jax.jit
def features(params, x):
  return RefactoredCNN().apply({"params": params}, x,
    method=lambda module, x: module.features(x))

params = init(jax.random.key(0), batch)
features(params, batch)
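Since the classifier is also exposed in setup, the same trick works for the second half of the model. As a small sketch (the classify helper name is ours, not part of the guide), it can be applied to previously extracted features:

@jax.jit
def classify(params, feats):
  # Run only the classifier head on features produced by the features function.
  return RefactoredCNN().apply({"params": params}, feats,
    method=lambda module, x: module.classifier(x))

logits = classify(params, features(params, batch))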
Use capture_intermediates#
Linen supports the capture of intermediate return values from submodules automatically without any code changes. This pattern should be considered the “sledge hammer” approach to capturing intermediates. As a debugging and inspection tool it is very useful, but using the other patterns described in this guide will give you more fine-grained control over what intermediates you want to extract.
In the following code example we check if any intermediate activations are non-finite (NaN or infinite):
@jax.jit
def init(key, x):
  variables = CNN().init(key, x)
  return variables

@jax.jit
def predict(variables, x):
  y, state = CNN().apply(variables, x, capture_intermediates=True, mutable=["intermediates"])
  intermediates = state['intermediates']
  fin = jax.tree_util.tree_map(lambda xs: jnp.all(jnp.isfinite(xs)), intermediates)
  return y, fin

variables = init(jax.random.key(0), batch)
y, is_finite = predict(variables, batch)
all_finite = all(jax.tree_util.tree_leaves(is_finite))
assert all_finite, "non-finite intermediate detected!"
By default only the intermediates of __call__ methods are collected. Alternatively, you can pass a custom filter function based on the Module instance and the method name.
filter_Dense = lambda mdl, method_name: isinstance(mdl, nn.Dense)
filter_encodings = lambda mdl, method_name: method_name == "encode"
y, state = CNN().apply(variables, batch, capture_intermediates=filter_Dense, mutable=["intermediates"])
dense_intermediates = state['intermediates']
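filter_encodings only has an effect on modules that actually define an encode method. As a rough sketch of how it could be used (the AutoEncoder module below is hypothetical and not part of this guide's CNN):

class AutoEncoder(nn.Module):
  def setup(self):
    self.encoder = nn.Dense(3)
    self.decoder = nn.Dense(4)

  def encode(self, x):
    return self.encoder(x)

  def decode(self, z):
    return self.decoder(z)

  def __call__(self, x):
    return self.decode(self.encode(x))

ae = AutoEncoder()
ae_vars = ae.init(jax.random.key(0), jnp.ones((1, 4)))
y, state = ae.apply(ae_vars, jnp.ones((1, 4)),
                    capture_intermediates=filter_encodings,
                    mutable=["intermediates"])
state['intermediates']  # should contain only the output of encode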
Note that capture_intermediates only applies to layers. You can use self.sow to manually store non-layer intermediates, but the filter function won’t be applied to them.
class Model(nn.Module):
  @nn.compact
  def __call__(self, x):
    a = nn.Dense(4)(x)  # Dense_0
    b = nn.Dense(4)(x)  # Dense_1
    c = a + b  # not a Flax layer, so it won't be stored as an intermediate
    d = nn.Dense(4)(c)  # Dense_2
    return d

@jax.jit
def init(key, x):
  variables = Model().init(key, x)
  return variables['params']

@jax.jit
def predict(params, x):
  return Model().apply({"params": params}, x, capture_intermediates=True)

batch = jax.random.uniform(jax.random.key(1), (1, 3))
params = init(jax.random.key(0), batch)
preds, feats = predict(params, batch)
feats  # intermediate c in Model was not stored because it's not a Flax layer
class Model(nn.Module):
  @nn.compact
  def __call__(self, x):
    a = nn.Dense(4)(x)  # Dense_0
    b = nn.Dense(4)(x)  # Dense_1
    c = a + b
    self.sow('intermediates', 'c', c)  # store intermediate c
    d = nn.Dense(4)(c)  # Dense_2
    return d

@jax.jit
def init(key, x):
  variables = Model().init(key, x)
  return variables['params']

@jax.jit
def predict(params, x):
  # filter specifically for only the Dense_0 and Dense_2 layers
  filter_fn = lambda mdl, method_name: isinstance(mdl.name, str) and (mdl.name in {'Dense_0', 'Dense_2'})
  return Model().apply({"params": params}, x, capture_intermediates=filter_fn)

batch = jax.random.uniform(jax.random.key(1), (1, 3))
params = init(jax.random.key(0), batch)
preds, feats = predict(params, batch)
feats  # intermediate c in Model is stored and isn't filtered out by the filter function
To separate the intermediates extracted from self.sow from the intermediates extracted from capture_intermediates, we can either define a separate collection like self.sow('sow_intermediates', 'c', c), or manually filter out the intermediates after calling .apply(). For example:
import flax

flattened_dict = flax.traverse_util.flatten_dict(feats['intermediates'], sep='/')
flattened_dict['c']
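For the first option, putting the manually sown value in its own collection keeps the two kinds of intermediates apart. A minimal sketch of that approach (SowModel is a hypothetical name):

class SowModel(nn.Module):
  @nn.compact
  def __call__(self, x):
    a = nn.Dense(4)(x)  # Dense_0
    b = nn.Dense(4)(x)  # Dense_1
    c = a + b
    # Store c in its own collection so it doesn't mix with 'intermediates'.
    self.sow('sow_intermediates', 'c', c)
    return nn.Dense(4)(c)  # Dense_2

params = SowModel().init(jax.random.key(0), batch)['params']
preds, state = SowModel().apply(
    {'params': params}, batch,
    capture_intermediates=True,
    mutable=['intermediates', 'sow_intermediates'])
state['intermediates']      # outputs of the Dense layers
state['sow_intermediates']  # only the manually sown value c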
In terms of efficiency, as long as everything is jitted, any intermediates you don’t end up using should be optimized away by XLA.
Use Sequential#
You could also define CNN using a simple implementation of a Sequential combinator (this is quite common in more stateful approaches). This may be useful for very simple models and gives you arbitrary model surgery. But it can be very limiting: if you want to add even one conditional, you are forced to refactor away from Sequential and structure your model more explicitly.
class Sequential(nn.Module):
  layers: Sequence[nn.Module]

  def __call__(self, x):
    for layer in self.layers:
      x = layer(x)
    return x

def SeqCNN():
  return Sequential([
    nn.Conv(features=32, kernel_size=(3, 3)),
    nn.relu,
    lambda x: nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)),
    nn.Conv(features=64, kernel_size=(3, 3)),
    nn.relu,
    lambda x: nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)),
    lambda x: x.reshape((x.shape[0], -1)),  # flatten
    nn.Dense(features=256),
    nn.relu,
    nn.Dense(features=10),
    nn.log_softmax,
  ])

@jax.jit
def init(key, x):
  variables = SeqCNN().init(key, x)
  return variables['params']

@jax.jit
def features(params, x):
  return Sequential(SeqCNN().layers[0:7]).apply({"params": params}, x)

batch = jnp.ones((1, 28, 28, 1))
params = init(jax.random.key(0), batch)
features(params, batch)
Extracting gradients of intermediate values#
For debugging purposes, it can be useful to extract the gradients of intermediate values.
This can be done by using the Module.perturb() method on the desired values.
class Model(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.relu(nn.Dense(8)(x))
    x = self.perturb('hidden', x)
    x = nn.Dense(2)(x)
    x = self.perturb('logits', x)
    return x
perturb adds a variable to a perturbations collection by default. It behaves like an identity function, and the gradient of the perturbation matches the gradient of the input. To get the perturbations, just initialize the model:
x = jnp.empty((1, 4))  # dummy input data
y = jnp.empty((1, 2))  # dummy target data
model = Model()
variables = model.init(jax.random.key(1), x)
params, perturbations = variables['params'], variables['perturbations']
Finally, compute the gradients of the loss with respect to the perturbations; these will match the gradients of the intermediates:
def loss_fn(params, perturbations, x, y):
  y_pred = model.apply({'params': params, 'perturbations': perturbations}, x)
  return jnp.mean((y_pred - y) ** 2)

intermediate_grads = jax.grad(loss_fn, argnums=1)(params, perturbations, x, y)
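The resulting pytree mirrors the perturbations collection, so each leaf is the gradient of the loss with respect to the corresponding intermediate. For instance, with the shapes used above:

# One gradient per perturbed intermediate, with the same shape as the
# intermediate itself: 'hidden' should be (1, 8) and 'logits' (1, 2).
jax.tree_util.tree_map(jnp.shape, intermediate_grads)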