RNNCellBase Upgrade Guide

The RNNCellBase API has two key changes intended to improve usability:

  • The initialize_carry method has transitioned from a class method to an instance method, simplifying its application.

  • All necessary metadata is now stored directly within the cell instance, providing a streamlined method signature.

This guide will walk you through these changes, demonstrating how to update your existing code to align with these enhancements.

Basic Usage

Let’s begin by defining some variables and a sample input that represents a batch of sequences:

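import functools

import jax
import jax.numpy as jnp
import flax.linen as nn

batch_size = 32
seq_len = 10
in_features = 64
out_features = 128

# A batch of sequences: (batch, time, features)
x = jnp.ones((batch_size, seq_len, in_features))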

All metadata, including the number of features and the carry initializer, is now stored within the cell instance:

# Legacy API
cell = nn.LSTMCell()

# New API
cell = nn.LSTMCell(features=out_features)

initialize_carry is now an instance method. Because the cell instance holds all metadata, its signature requires only a PRNG key and the shape of a sample input:

# Legacy API
carry = nn.LSTMCell.initialize_carry(jax.random.key(0), (batch_size,), out_features)

# New API
carry = cell.initialize_carry(jax.random.key(0), x[:, 0].shape)

Here, x[:, 0].shape is the shape of the cell's input (without the time dimension). You can also write the input shape directly when it's more convenient:

carry = cell.initialize_carry(jax.random.key(0), (batch_size, in_features))

Upgrade Patterns

The following sections will demonstrate some useful patterns for updating your code to align with the new API.

First, we show how to upgrade a Module that wraps a cell, applies the scan logic in __call__, and has a static initialize_carry method. Here we make the fewest changes needed to get the code working, even though the result is not the most idiomatic:

# Legacy API
class SimpleLSTM(nn.Module):

  @functools.partial(
    nn.transforms.scan,
    variable_broadcast='params',
    in_axes=1, out_axes=1,
    split_rngs={'params': False})
  @nn.compact
  def __call__(self, carry, x):
    return nn.OptimizedLSTMCell()(carry, x)

  @staticmethod
  def initialize_carry(batch_dims, hidden_size):
    return nn.OptimizedLSTMCell.initialize_carry(
      jax.random.key(0), batch_dims, hidden_size)

# New API
class SimpleLSTM(nn.Module):

  @functools.partial(
    nn.transforms.scan,
    variable_broadcast='params',
    in_axes=1, out_axes=1,
    split_rngs={'params': False})
  @nn.compact
  def __call__(self, carry, x):
    # The cell needs its feature count up front; recover it from the carry.
    features = carry[0].shape[-1]
    return nn.OptimizedLSTMCell(features)(carry, x)

  @staticmethod
  def initialize_carry(batch_dims, hidden_size):
    # parent=None keeps this temporary cell detached from any module.
    return nn.OptimizedLSTMCell(hidden_size, parent=None).initialize_carry(
      jax.random.key(0), (*batch_dims, hidden_size))

Notice how in the new version, we have to extract the number of features from the carry during __call__, and use parent=None during initialize_carry to avoid some potential side effects.

Next, we show a more idiomatic way of writing a similar LSTM module. The main change is a new features attribute on the module, which is used in the setup method to initialize an nn.scan-ed version of the cell:

# Legacy API
class SimpleLSTM(nn.Module):

  @functools.partial(
    nn.transforms.scan,
    variable_broadcast='params',
    in_axes=1, out_axes=1,
    split_rngs={'params': False})
  @nn.compact
  def __call__(self, carry, x):
    return nn.OptimizedLSTMCell()(carry, x)

  @staticmethod
  def initialize_carry(batch_dims, hidden_size):
    return nn.OptimizedLSTMCell.initialize_carry(
      jax.random.key(0), batch_dims, hidden_size)

model = SimpleLSTM()
carry = SimpleLSTM.initialize_carry((batch_size,), out_features)
variables = model.init(jax.random.key(0), carry, x)

# New API
class SimpleLSTM(nn.Module):
  features: int

  def setup(self):
    self.scan_cell = nn.transforms.scan(
      nn.OptimizedLSTMCell,
      variable_broadcast='params',
      in_axes=1, out_axes=1,
      split_rngs={'params': False})(self.features)

  @nn.compact
  def __call__(self, x):
    carry = self.scan_cell.initialize_carry(jax.random.key(0), x[:, 0].shape)
    return self.scan_cell(carry, x)[1]  # only return the output

model = SimpleLSTM(features=out_features)
variables = model.init(jax.random.key(0), x)

Because the carry can be easily initialized from the sample input, we can move the call to initialize_carry into the __call__ method, somewhat simplifying the code.

Development Notes

When developing a new cell, consider the following:

  • Include necessary metadata as instance attributes.

  • initialize_carry now requires only a PRNG key and the input shape.

  • A new num_feature_axes property is required to specify the number of feature dimensions.

class LSTMCell(nn.RNNCellBase):
  features: int # ← All metadata is now stored within the cell instance
  ... #              ↓
  carry_init: Initializer

  def initialize_carry(self, rng, input_shape) -> Carry:
    ...

  @property
  def num_feature_axes(self):
    return 1

num_feature_axes is a new API feature that allows code handling arbitrary RNNCellBase instances, such as the RNN Module, to infer the number of batch dimensions and determine the position of the time axis.